import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
model_name = "ai4bharat/Airavata"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
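# Note: bfloat16 roughly halves memory relative to float32; on hardware without
# bf16 support, torch.float16 (GPU) or torch.float32 (CPU) are common fallbacks.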
# Function for generating responses; gr.ChatInterface calls it with (message, history)
def inference(message, history):
    prompt = create_prompt_with_chat_format([{"role": "user", "content": message}], add_bos=False)
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        output = model.generate(
            encoding.input_ids,
            attention_mask=encoding.attention_mask,
            do_sample=False,
            max_new_tokens=250,
        )
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(output[0][encoding.input_ids.shape[1]:], skip_special_tokens=True)
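# Illustrative local sanity check (the question is just an example; history is unused):
# print(inference("What is the capital of India?", []))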
# Build a Tulu-style chat prompt from a list of {"role", "content"} messages
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    message["role"]
                )
            )
    formatted_text += "<|assistant|>\n"
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text
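# Example (traced from the function above): for a single user turn with add_bos=False,
# create_prompt_with_chat_format([{"role": "user", "content": "Hi"}], add_bos=False)
# returns:
#   <|user|>
#   Hi
#   <|assistant|>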
# Create Gradio chat interface (ChatInterface provides its own chat input/output,
# so the input textbox is customized via the textbox= argument)
iface = gr.ChatInterface(
    fn=inference,
    textbox=gr.Textbox(lines=3, label="Ask me anything"),
    title="Airavata Chatbot",
    theme="soft",  # Optional: a built-in light theme
)
iface.launch()
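# When running locally (outside a Space), iface.launch(share=True) creates a
# temporary public link; on Hugging Face Spaces the app is served automatically.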