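# CAMAI - Centralized Actionable Multimodal Agri Assistant on Edge Intelligence for Farmers.
# Speech-to-text with OpenAI Whisper, answers from the ai4bharat/Airavata LLM, served via Gradio.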
import torch
import whisper
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Whisper ASR model for transcribing the spoken query.
asr_model = whisper.load_model("base")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Airavata instruction-tuned LLM and its tokenizer.
model_name = "ai4bharat/Airavata"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
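# Airavata expects the Tulu-style chat format built by the helper below. For a single
# user turn the prompt looks like (illustrative question, add_bos=False as used later):
#   <|user|>
#   When should I irrigate my wheat crop?
#   <|assistant|>
# The trailing "<|assistant|>\n" cues the model to start generating its reply.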
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    message["role"]
                )
            )
    formatted_text += "<|assistant|>\n"
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text
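# Wrap a raw question in the chat template, generate greedily, and strip the echoed prompt
# so only the newly generated answer is returned.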
def inference(input_prompt, model, tokenizer):
    input_prompt = create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
    encodings = tokenizer(input_prompt, padding=True, return_tensors="pt")
    encodings = encodings.to(device)
    # Greedy decoding; inference_mode() disables gradient tracking.
    with torch.inference_mode():
        outputs = model.generate(encodings.input_ids, attention_mask=encodings.attention_mask, do_sample=False, max_new_tokens=250)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded output echoes the prompt; strip it so only the new answer remains.
    input_prompt = tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True)
    output_text = output_text[len(input_prompt):]
    return output_text
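# Transcribe an uploaded audio file with Whisper's low-level API: pad/trim to 30 seconds,
# compute the log-Mel spectrogram, detect the language, then decode.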
def transcribe(audio):
    # Load the audio and pad/trim it to fit Whisper's 30-second window.
    audio = whisper.load_audio(audio)
    audio = whisper.pad_or_trim(audio)
    # Make the log-Mel spectrogram and move it to the same device as the ASR model.
    mel = whisper.log_mel_spectrogram(audio).to(asr_model.device)
    # Detect the spoken language.
    _, probs = asr_model.detect_language(mel)
    print(f"Detected language: {max(probs, key=probs.get)}")
    # Decode the audio.
    options = whisper.DecodingOptions()
    result = whisper.decode(asr_model, mel, options)
    return result.text
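# Glue function for Gradio: transcribe the uploaded audio, then answer with Airavata.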
def chat_interface(audio):
    message = transcribe(audio)
    outputs = inference(message, model, tokenizer)
    return outputs
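# Quick sanity check without the UI (hypothetical local file path):
#   print(chat_interface("sample_query.wav"))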
gr.Interface(
    title="CAMAI - Centralized Actionable Multimodal Agri Assistant on Edge Intelligence for Farmers",
    fn=chat_interface,
    inputs=gr.Audio(sources=["upload"], type="filepath", label="Audio file"),
    outputs="textbox",
    theme="darkly",
).launch()