import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the ai4bharat/Airavata instruction-tuned Hindi model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Airavata")
model = AutoModelForCausalLM.from_pretrained("ai4bharat/Airavata")
def generate_response(prompt):
    # Truncate the prompt to 50 tokens, then generate with beam search
    input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=50, truncation=True)
    output_ids = model.generate(input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
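
# Optional sanity check (sketch): call generate_response directly before wiring up
# the UI. The prompt below is only illustrative.
# print(generate_response("भारत की राजधानी क्या है?"))  # "What is the capital of India?"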
iface = gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
    live=True,
    title="Airavata LLMs Chatbot",
    description="Ask me anything, and I'll generate a response!",
    theme="light",
)
iface.launch()
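
# The commented-out code below is an alternative implementation, kept for reference:
# it batches prompts, wraps them in the Tulu chat template, and strips the echoed
# prompt from the generated text.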
# import gradio as gr
# import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM
# device = "cuda" if torch.cuda.is_available() else "cpu"
# def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
#     # Build a Tulu-style prompt using <|system|>, <|user|>, and <|assistant|> role markers
#     formatted_text = ""
#     for message in messages:
#         if message["role"] == "system":
#             formatted_text += "<|system|>\n" + message["content"] + "\n"
#         elif message["role"] == "user":
#             formatted_text += "<|user|>\n" + message["content"] + "\n"
#         elif message["role"] == "assistant":
#             formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
#         else:
#             raise ValueError(
#                 "Tulu chat template only supports 'system', 'user', and 'assistant' roles. Invalid role: {}.".format(
#                     message["role"]
#                 )
#             )
#     formatted_text += "<|assistant|>\n"
#     formatted_text = bos + formatted_text if add_bos else formatted_text
#     return formatted_text
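# # For a single user turn with add_bos=False, the template yields, e.g.:
# #   "<|user|>\nनमस्ते\n<|assistant|>\n"
# # leaving the assistant marker open so the model generates the reply.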
# def inference(input_prompts, model, tokenizer):
#     input_prompts = [
#         create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
#         for input_prompt in input_prompts
#     ]
#     encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
#     encodings = encodings.to(device)
#     with torch.no_grad():
#         outputs = model.generate(encodings.input_ids, do_sample=False, max_length=250)
#     output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
#     input_prompts = [
#         tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
#     ]
#     output_texts = [output_text[len(input_prompt):] for input_prompt, output_text in zip(input_prompts, output_texts)]
#     return output_texts
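# # Usage sketch (hypothetical prompt; assumes model and tokenizer are loaded below):
# #   replies = inference(["ताजमहल के बारे में बताएं"], model, tokenizer)  # "Tell me about the Taj Mahal"
# #   print(replies[0])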
# model_name = "ai4bharat/Airavata"
# tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# tokenizer.pad_token = tokenizer.eos_token
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
# # Example prompts shown in the UI (Hindi; English translations in trailing comments)
# examples = [
#     "मुझे अपने करियर के बारे में सुझाव दो",  # Give me advice about my career
#     "मैं कैसे अध्ययन कर सकता हूँ?",  # How can I study?
#     "कृपया मुझे एक कहानी सुनाएं",  # Please tell me a story
#     "ताजमहल के बारे में कुछ बताएं",  # Tell me something about the Taj Mahal
#     "मेरा नाम क्या है?",  # What is my name?
#     "आपका पसंदीदा फिल्म कौन सी है?",  # Which is your favourite film?
# ]
# # gr.ChatInterface expects fn(message, history) -> str
# iface = gr.ChatInterface(
#     fn=lambda message, history: inference([message], model, tokenizer)[0],
#     examples=examples,
#     title="Airavata Chatbot",
#     theme="light",  # Optional: Set a light theme
# )
# iface.launch()