import whisper
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Load the Whisper ASR model for speech-to-text
asr_model = whisper.load_model("base")
device = "cuda" if torch.cuda.is_available() else "cpu"
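# Load the ai4bharat/Airavata LLM and its tokenizer for text generation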
model_name = "ai4bharat/Airavata"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
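# Build a Tulu-style chat prompt with <|system|>/<|user|>/<|assistant|> tags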
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    message["role"]
                )
            )
    formatted_text += "<|assistant|>\n"
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text
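# Generate a response from the LLM for a single user turn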
def inference(input_prompt, model, tokenizer):
    input_prompt = create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
    encodings = tokenizer(input_prompt, padding=True, return_tensors="pt")
    encodings = encodings.to(device)
    with torch.inference_mode():  # disable autograd for faster generation
        outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Strip the echoed prompt so only the newly generated text is returned
    input_prompt = tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True)
    output_text = output_text[len(input_prompt):]
    return output_text
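# Transcribe an audio file with Whisper and return the recognized text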
def transcribe(audio):
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio)
    audio = whisper.pad_or_trim(audio)
    # make log-Mel spectrogram and move to the same device as the ASR model
    mel = whisper.log_mel_spectrogram(audio).to(asr_model.device)
    # detect the spoken language
    _, probs = asr_model.detect_language(mel)
    print(f"Detected language: {max(probs, key=probs.get)}")
    # decode the audio
    options = whisper.DecodingOptions()
    result = whisper.decode(asr_model, mel, options)
    return result.text
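# Pipeline: microphone audio -> Whisper transcript -> Airavata response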
def chat_interface(audio_tuple):
    audio_path = audio_tuple[0] if isinstance(audio_tuple, tuple) else audio_tuple
    message = transcribe(audio_path)
    outputs = inference(message, model, tokenizer)
    return outputs
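# Build and launch the Gradio UI: microphone audio in, generated text out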
gr.Interface(
    title="CAMAI - Centralized Actionable Multimodal Agri Assistant on Edge Intelligence for Farmers",
    fn=chat_interface,
    inputs=[
        gr.Audio(sources=["microphone"], type="filepath")  # return a file path so whisper.load_audio can read it
    ],
    outputs=[
        "textbox"
    ],
    theme="darkly"
).launch()