SivaResearch committed
Commit c4807ae · verified · 1 parent: b5f8d32

updated with examples

Files changed (1): app.py (+40, -24)
app.py CHANGED
@@ -4,44 +4,60 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load model and tokenizer
-model_name = "ai4bharat/Airavata"
-tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-tokenizer.pad_token = tokenizer.eos_token
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
-
-# Function for generating responses
-def inference(message):
-    prompt = create_prompt_with_chat_format([{"role": "user", "content": message}], add_bos=False)
-    encoding = tokenizer(prompt, return_tensors="pt").to(device)
-    with torch.inference_mode():
-        output = model.generate(encoding.input_ids, do_sample=False, max_new_tokens=250)
-    return tokenizer.decode(output[0], skip_special_tokens=True)[len(message) :]
-
 def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
     formatted_text = ""
     for message in messages:
         if message["role"] == "system":
-            formatted_text += "<|system|>\n" + message["content"] + "\n"
+            formatted_text += "<|system|>\n" + message["content"] + "\n"
         elif message["role"] == "user":
-            formatted_text += "<|user|>\n" + message["content"] + "\n"
+            formatted_text += "<|user|>\n" + message["content"] + "\n"
         elif message["role"] == "assistant":
-            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
+            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
         else:
             raise ValueError(
-                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
+                "Tulu chat template only supports 'system', 'user', and 'assistant' roles. Invalid role: {}.".format(
                     message["role"]
                 )
             )
-    formatted_text += "<|assistant|>\n"
+    formatted_text += "<|assistant|>\n"
     formatted_text = bos + formatted_text if add_bos else formatted_text
     return formatted_text
 
-# Create Gradio chat interface
-iface = gr.ChatInterface(
-    fn=inference,
-    inputs=[gr.Textbox(lines=3, label="Ask me anything")],
-    outputs=gr.Textbox(label="Response", live=True),
+def inference(input_prompts, model, tokenizer):
+    input_prompts = [
+        create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
+        for input_prompt in input_prompts
+    ]
+
+    encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
+    encodings = encodings.to(device)
+
+    with torch.no_grad():
+        outputs = model.generate(encodings.input_ids, do_sample=False, max_length=250)
+
+    output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
+
+    input_prompts = [
+        tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
+    ]
+    output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
+    return output_texts
+
+model_name = "ai4bharat/Airavata"
+tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+tokenizer.pad_token = tokenizer.eos_token
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
+examples = [
+    ["मुझे अपने करियर के बारे में सुझाव दो", "मैं कैसे अध्ययन कर सकता हूँ?"],
+    ["कृपया मुझे एक कहानी सुनाएं", "ताजमहल के बारे में कुछ बताएं"],
+    ["मेरा नाम क्या है?", "आपका पसंदीदा फिल्म कौन सी है?"],
+]
+
+iface = gr.Chat(
+    model_fn=lambda input_prompts: inference(input_prompts, model, tokenizer),
+    inputs=["text"],
+    outputs="text",
+    examples=examples,
     title="Airavata Chatbot",
     theme="light",  # Optional: Set a light theme
 )
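
For reference, create_prompt_with_chat_format wraps a single user turn in the Tulu-style template, so a one-turn prompt comes out as "<|user|>\n{message}\n<|assistant|>\n". A minimal sketch of what it produces (the Hindi prompt is taken from the examples list in this commit):

prompt = create_prompt_with_chat_format(
    [{"role": "user", "content": "ताजमहल के बारे में कुछ बताएं"}],  # "Tell me something about the Taj Mahal"
    add_bos=False,
)
print(prompt)
# <|user|>
# ताजमहल के बारे में कुछ बताएं
# <|assistant|>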
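The rewritten inference helper now takes a batch of prompts and returns a batch of completions, stripping each echoed prompt from the decoded output. A hedged sketch of calling it directly, assuming model and tokenizer are loaded as in this commit:

prompts = [
    "मुझे अपने करियर के बारे में सुझाव दो",  # "Give me suggestions about my career"
    "मैं कैसे अध्ययन कर सकता हूँ?",  # "How can I study?"
]
responses = inference(prompts, model, tokenizer)
for prompt, response in zip(prompts, responses):
    print(prompt, "->", response)

Note that the new call passes max_length=250, which caps prompt and completion together, whereas the removed version's max_new_tokens=250 capped only the completion; long prompts therefore leave less room for the answer under the new code.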
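As committed, the interface block may not run as-is: gr.Chat is not a public Gradio component, and gr.ChatInterface (which the removed code used) expects a callback of the form fn(message, history) rather than model_fn/inputs/outputs keywords. A hedged sketch of equivalent wiring through gr.ChatInterface; the adapter name chat_fn is illustrative, not from this commit:

import gradio as gr

def chat_fn(message, history):
    # ChatInterface passes the latest message plus the chat history;
    # inference() expects a batch, so wrap the message and unwrap the result.
    return inference([message], model, tokenizer)[0]

iface = gr.ChatInterface(
    fn=chat_fn,
    examples=["मुझे अपने करियर के बारे में सुझाव दो", "कृपया मुझे एक कहानी सुनाएं"],
    title="Airavata Chatbot",
)
iface.launch()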