shuvom committed
Commit 7ec0f0f · 1 Parent(s): 704f750

add app.py

Files changed (1)
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, TextStreamer
+ from threading import Thread
+ import gradio as gr
+ from peft import PeftModel
+
+ model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
+ peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"
+ # load the base model quantized to 4-bit and let accelerate place it on the available device(s)
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_4bit=True, device_map="auto")
+
+ # tokenizer.chat_template = chat_template
+ # use the tokenizer from the fine-tuned repo; its chat template is applied below
+ tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
+ # make embedding resizing configurable?
+ model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
+
+ # attach the fine-tuned PEFT adapter on top of the quantized base model
+ model = PeftModel.from_pretrained(model, peft_model_id)
+
+ class ChatCompletion:
+     def __init__(self, model, tokenizer, system_prompt=None):
+         self.model = model
+         self.tokenizer = tokenizer
+         # the iterator streamer feeds tokens to the Gradio chat; the plain streamer prints to stdout
+         self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
+         self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
+         # set the model in inference mode
+         self.model.eval()
+         self.system_prompt = system_prompt
+
+     def get_completion(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
+         # sampling breaks with temperature == 0, so clamp it to a small positive value
+         if temperature < 1e-2:
+             temperature = 1e-2
+         messages = []
+         if message_history is not None:
+             messages.extend(message_history)
+         elif system_prompt or self.system_prompt:
+             system_prompt = system_prompt or self.system_prompt
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": prompt})
+         chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+         inputs = self.tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False)
+         # Generation runs synchronously here; TextStreamer prints tokens to stdout as they are produced.
+         generation_kwargs = dict(max_new_tokens=max_new_tokens,
+                                  temperature=temperature,
+                                  top_p=0.95,
+                                  do_sample=True,
+                                  eos_token_id=self.tokenizer.eos_token_id,
+                                  repetition_penalty=1.2)
+         # note: generate() returns token ids, not decoded text
+         generated_text = self.model.generate(**inputs, streamer=self.print_streamer, **generation_kwargs)
+         return generated_text
+
+     def get_chat_completion(self, message, history):
+         messages = []
+         if self.system_prompt:
+             messages.append({"role": "system", "content": self.system_prompt})
+         # rebuild the conversation from Gradio's (user, assistant) history pairs
+         for user_message, assistant_message in history:
+             messages.append({"role": "user", "content": user_message})
+             messages.append({"role": "assistant", "content": assistant_message})
+         messages.append({"role": "user", "content": message})
+         chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+         inputs = self.tokenizer(chat_prompt, return_tensors="pt")
+         # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
+         generation_kwargs = dict(inputs,
+                                  streamer=self.streamer,
+                                  max_new_tokens=2048,
+                                  temperature=0.2,
+                                  top_p=0.95,
+                                  eos_token_id=self.tokenizer.eos_token_id,
+                                  do_sample=True,
+                                  repetition_penalty=1.2)
+         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+         thread.start()
+         generated_text = ""
+         # yield the accumulated response so the Gradio UI updates as new tokens arrive
+         for new_text in self.streamer:
+             generated_text += new_text.replace(self.tokenizer.eos_token, "")
+             yield generated_text
+         thread.join()
+         return generated_text
+
+     def get_completion_without_streaming(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
+         if temperature < 1e-2:
+             temperature = 1e-2
+         messages = []
+         if message_history is not None:
+             messages.extend(message_history)
+         elif system_prompt or self.system_prompt:
+             system_prompt = system_prompt or self.system_prompt
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": prompt})
+         chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+         inputs = self.tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False)
+         # Generation runs synchronously; decode the full output once it has finished.
+         generation_kwargs = dict(max_new_tokens=max_new_tokens,
+                                  temperature=temperature,
+                                  top_p=0.95,
+                                  do_sample=True,
+                                  repetition_penalty=1.1)
+         outputs = self.model.generate(**inputs, **generation_kwargs)
+         generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return generated_text
+
+ text_generator = ChatCompletion(model, tokenizer, system_prompt="You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish.")
+
+ gr.ChatInterface(text_generator.get_chat_completion).queue().launch(debug=True)