re / app.py
SivaResearch's picture
Update app.py
a4eced8 verified
raw
history blame
3.15 kB
import whisper
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# from gradio import inputs # Import the 'inputs' module from 'gradio'
# Choose the accelerator once so the speech model and the chat model agree.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper "base" checkpoint for speech recognition, placed explicitly on `device`
# (previously left to whisper's own default, followed by a no-op `Asr_model.device`).
Asr_model = whisper.load_model("base", device=device)

# Airavata causal LM used to answer the transcribed question.
model_name = "ai4bharat/Airavata"
# Left padding keeps the prompt flush against the tokens generate() appends.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Reuse EOS as the pad token (presumably the tokenizer ships without one — confirm).
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
    """Render a list of chat messages in the Tulu chat template.

    Each message is a mapping with ``"role"`` and ``"content"`` keys. System and
    user turns are emitted verbatim; assistant turns are stripped and closed with
    ``eos``. A trailing ``<|assistant|>`` header is always appended so the model
    continues as the assistant.

    Args:
        messages: Iterable of ``{"role": ..., "content": ...}`` dicts.
        bos: Beginning-of-sequence marker prepended when ``add_bos`` is true.
        eos: End-of-sequence marker appended after each assistant turn.
        add_bos: Whether to prefix the rendered prompt with ``bos``.

    Returns:
        The rendered prompt string.

    Raises:
        ValueError: If a message has a role other than system/user/assistant.
    """
    parts = []
    for msg in messages:
        role = msg["role"]
        if role == "system" or role == "user":
            parts.append("<|" + role + "|>\n" + msg["content"] + "\n")
        elif role == "assistant":
            parts.append("<|assistant|>\n" + msg["content"].strip() + eos + "\n")
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    role
                )
            )
    parts.append("<|assistant|>\n")
    body = "".join(parts)
    if not add_bos:
        return body
    return bos + body
def inference(input_prompt, model, tokenizer):
    """Greedily generate the chat model's reply to a single user message.

    Args:
        input_prompt: Raw user text (here, the Whisper transcript).
        model: Hugging Face causal LM, already moved to the module-level ``device``.
        tokenizer: Tokenizer paired with ``model``.

    Returns:
        Only the newly generated text (up to 250 new tokens), without the prompt.
    """
    prompt = create_prompt_with_chat_format(
        [{"role": "user", "content": input_prompt}], add_bos=False
    )
    encodings = tokenizer(prompt, padding=True, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encodings.input_ids,
            # Pass the mask explicitly so generate() does not have to guess it
            # (pad_token == eos_token here, which makes guessing ambiguous).
            attention_mask=encodings.attention_mask,
            do_sample=False,
            max_new_tokens=250,
        )
    # Strip the prompt by token position rather than by decoded-string length:
    # decoding the prompt on its own is not guaranteed to reproduce the exact
    # prefix of the decoded output, so a character slice could cut into the reply.
    prompt_length = encodings.input_ids.shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
def transcribe(audio):
    """Transcribe (up to) the first 30 seconds of an audio file with Whisper.

    Args:
        audio: Path to an audio file (the Gradio input uses ``type="filepath"``).

    Returns:
        The decoded transcript text.
    """
    # Load the audio and pad/trim it to Whisper's fixed 30-second window.
    samples = whisper.pad_or_trim(whisper.load_audio(audio))
    # All Whisper calls must use the ASR model (Asr_model); the module-level
    # `model` is the Airavata causal LM and has no detect_language()/decode support.
    mel = whisper.log_mel_spectrogram(samples).to(Asr_model.device)
    _, probs = Asr_model.detect_language(mel)
    print(f"Detected language: {max(probs, key=probs.get)}")
    # Half precision only on accelerators; mirrors whisper.transcribe(), which
    # falls back to fp32 on CPU.
    options = whisper.DecodingOptions(fp16=Asr_model.device.type != "cpu")
    result = whisper.decode(Asr_model, mel, options)
    return result.text
def chat_interface(audio):
    """Gradio callback: transcribe the uploaded audio, then answer it with the LLM.

    Args:
        audio: Filepath of the uploaded audio, or ``None`` when the optional
            input was left empty.

    Returns:
        The chat model's reply, or a prompt asking for audio when none was given.
    """
    if audio is None:
        # The Gradio input is declared optional, so this callback can run with
        # no file; whisper.load_audio(None) would otherwise raise.
        return "Please upload an audio file."
    transcript = transcribe(audio)
    return inference(transcript, model, tokenizer)
# Single-input web UI: an uploaded audio file in, the assistant's text reply out.
# NOTE(review): `gr.inputs.Audio(source=..., optional=...)` is the Gradio 3.x API;
# `gr.inputs` was removed in Gradio 4 — confirm the pinned gradio version.
demo = gr.Interface(
    fn=chat_interface,
    inputs=[gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file")],
    outputs=["textbox"],
    title="CAMAI - Centralized Actionable Multimodal Agri Assistant on Edge Intelligence for Farmers",
    theme="darkly",
)
demo.launch()