SivaResearch committed
Commit d9aa2f6 · verified · 1 Parent(s): 52c953c

Create app.py

Files changed (1): app.py (+49, -0)
app.py ADDED
@@ -0,0 +1,49 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load model and tokenizer
+ model_name = "ai4bharat/Airavata"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+ tokenizer.pad_token = tokenizer.eos_token
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
+
+ # Generate a response (gr.ChatInterface passes the message and the chat history)
+ def inference(message, history):
+     prompt = create_prompt_with_chat_format([{"role": "user", "content": message}], add_bos=False)
+     encoding = tokenizer(prompt, return_tensors="pt").to(device)
+     with torch.inference_mode():
+         output = model.generate(encoding.input_ids, do_sample=False, max_new_tokens=250)
+     # Decode only the newly generated tokens, dropping the prompt
+     return tokenizer.decode(output[0][encoding.input_ids.shape[1]:], skip_special_tokens=True)
+
+ def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
+     formatted_text = ""
+     for message in messages:
+         if message["role"] == "system":
+             formatted_text += "<|system|>\n" + message["content"] + "\n"
+         elif message["role"] == "user":
+             formatted_text += "<|user|>\n" + message["content"] + "\n"
+         elif message["role"] == "assistant":
+             formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
+         else:
+             raise ValueError(
+                 "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
+                     message["role"]
+                 )
+             )
+     formatted_text += "<|assistant|>\n"
+     formatted_text = bos + formatted_text if add_bos else formatted_text
+     return formatted_text
+
+ # Create Gradio chat interface
+ iface = gr.ChatInterface(
+     fn=inference,
+     textbox=gr.Textbox(lines=3, label="Ask me anything"),
+     title="Airavata Chatbot",
+     theme="soft",  # Optional: use a built-in light theme
+ )
+
+ iface.launch()
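
Once the app above is running, the chatbot can also be queried programmatically. The following is a minimal sketch, assuming the app is served at Gradio's default local URL and that `gr.ChatInterface` exposes its default `/chat` API endpoint; the question text is purely illustrative.

```python
from gradio_client import Client

# Assumed default local URL; replace with the Space URL when deployed
client = Client("http://127.0.0.1:7860/")

# Send a single user message to the chat endpoint and print the model's reply
reply = client.predict("What is the capital of India?", api_name="/chat")
print(reply)
```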