# app.py by prithivMLmods — commit a592e13 ("Update app.py", verified), 6.14 kB.
import asyncio
import os
import time
from threading import Thread

import edge_tts
import gradio as gr
import spaces  # must be imported before torch initialises CUDA on ZeroGPU Spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from transformers.image_utils import load_image
# Generation limits used by the UI sliders and by prompt trimming in generate().
# MAX_INPUT_TOKEN_LENGTH can be overridden via the environment so a deployment
# can trade context length for memory.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Extra CSS handed to the Gradio interface; None means "use the theme as-is".
css = None

# Text-only chat model (used when no images are attached).
MODEL_ID = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16).eval()

# Vision-language model for multimodal OCR processing (used when images are attached).
OCR_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(OCR_MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16).to("cuda").eval()
# Edge TTS voices selectable with a leading "@tts<N>" tag in the prompt, where
# N is the 1-based position in this list ("@tts1" -> Jenny, ..., "@tts8" -> Tony).
TTS_VOICES = [
    f"en-US-{speaker}Neural"
    for speaker in ("Jenny", "Guy", "Aria", "Davis", "Jane", "Jason", "Nancy", "Tony")
]
async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Synthesize *text* with the Edge TTS *voice* and save it as an MP3.

    The audio is written to *output_file*, and that same path is returned so
    the caller can pass it straight to an audio component.
    """
    await edge_tts.Communicate(text, voice).save(output_file)
    return output_file
_TTS_PREFIX = "@tts"


def _split_tts_command(text):
    """Separate an optional leading ``@tts<N>`` voice tag from the prompt.

    Returns ``(prompt, voice)``. ``voice`` is ``TTS_VOICES[N - 1]`` when *text*
    starts (after surrounding whitespace, case-insensitively) with ``@tts<N>``
    for ``1 <= N <= len(TTS_VOICES)``; otherwise it is ``None``.
    """
    stripped = text.strip()
    lowered = stripped.lower()
    # Check higher indices first so a future "@tts10" is not matched as "@tts1".
    for index in range(len(TTS_VOICES), 0, -1):
        tag = f"{_TTS_PREFIX}{index}"
        if lowered.startswith(tag):
            # Remove only the leading tag, whatever its letter case was.
            return stripped[len(tag):].strip(), TTS_VOICES[index - 1]
    # No valid voice tag: drop any stray "@tts" marker and answer as text.
    return text.replace(_TTS_PREFIX, "").strip(), None


def _history_to_messages(history):
    """Normalise the ``history`` input into a list of chat-message dicts.

    The interface supplies history from a free-text ``gr.Textbox``. Unpacking
    that string would produce one bogus "message" per character, so non-blank
    text is passed as a single earlier user turn. A sequence of
    ``{"role": ..., "content": ...}`` dicts is used as given.
    """
    if not history:
        return []
    if isinstance(history, str):
        context = history.strip()
        return [{"role": "user", "content": context}] if context else []
    return list(history)


def _prepare_vision_generation(text, images):
    """Return ``(generate_fn, model_inputs, decoder)`` for the OCR/vision model."""
    messages = [
        {
            "role": "user",
            "content": [
                *({"type": "image", "image": image} for image in images),
                {"type": "text", "text": text},
            ],
        }
    ]
    prompt = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = ocr_processor(
        text=[prompt],
        images=images,
        return_tensors="pt",
        padding=True,
    ).to(ocr_model.device)
    # The streamer must decode with the OCR model's own tokenizer.
    return ocr_model.generate, dict(inputs), ocr_processor.tokenizer


def _prepare_text_generation(text, history):
    """Return ``(generate_fn, model_inputs, decoder)`` for the text-only model."""
    conversation = [*_history_to_messages(history), {"role": "user", "content": text}]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest context is dropped first.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return model.generate, {"input_ids": input_ids.to(model.device)}, tokenizer


@spaces.GPU
def generate(
    input_dict,
    history,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """Stream a chatbot reply, optionally followed by synthesized speech.

    Args:
        input_dict: ``gr.MultimodalTextbox`` value with ``"text"`` and
            ``"files"`` (image paths). When images are present, the OCR
            vision model answers; otherwise the text chat model does.
        history: Prior conversation. Either free text from the "Chat History"
            textbox or a list of role/content message dicts. It is only used
            by the text-only model.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings forwarded to ``generate``.

    Yields:
        ``(text_so_far, audio)`` pairs matching the interface's
        ``["text", "audio"]`` outputs. ``audio`` is ``None`` while streaming
        and for plain replies. For ``@tts<N>`` requests the final pair carries
        a ``gr.Audio`` that plays the spoken reply.
    """
    text, voice = _split_tts_command(input_dict.get("text", "") or "")
    files = input_dict.get("files") or []
    images = [load_image(image) for image in files]

    if images:
        generate_fn, model_inputs, decoder = _prepare_vision_generation(text, images)
    else:
        generate_fn, model_inputs, decoder = _prepare_text_generation(text, history)

    streamer = TextIteratorStreamer(decoder, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "do_sample": True,
        "top_p": top_p,
        "top_k": int(top_k),
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks until finished, so run it in a worker thread and
    # consume decoded text from the streamer as it arrives.
    worker = Thread(target=generate_fn, kwargs=generate_kwargs)
    worker.start()

    chunks = []
    for chunk in streamer:
        chunks.append(chunk)
        yield "".join(chunks), None
    worker.join()
    final_response = "".join(chunks)

    # Only synthesize speech when a voice was requested and there is text to speak.
    if voice and final_response.strip():
        audio_path = asyncio.run(text_to_speech(final_response, voice))
        yield final_response, gr.Audio(audio_path, autoplay=True)
    else:
        yield final_response, None
# Gradio interface: prompts shown as clickable examples (a leading "@tts<N>"
# asks for a spoken answer with the N-th voice).
_EXAMPLE_PROMPTS = [
    "@tts1 Who is Nikola Tesla, and why did he die?",
    "A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?",
    "Write a Python function to check if a number is prime.",
    "@tts2 What causes rainbows to form?",
    "Rewrite the following sentence in passive voice: 'The dog chased the cat.'",
    "@tts5 What is the capital of France?",
]

# Input widgets, in the positional order generate() expects its arguments.
_INPUT_COMPONENTS = [
    gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    gr.Textbox(label="Chat History", value="", placeholder="Previous conversation history"),
    gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
]

demo = gr.Interface(
    fn=generate,
    inputs=_INPUT_COMPONENTS,
    outputs=["text", "audio"],
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    stop_btn="Stop Generation",
    description="QwQ Edge: A Chatbot with Text-to-Speech and Multimodal Support",
    css=css,
    fill_height=True,
)

if __name__ == "__main__":
    demo.launch()