# Spaces:
# Running
# Running
# NOTE(review): The triple-quoted string below holds an earlier version of
# this app (ViLT answer + gTTS audio, without the GPT-Neo sentence rewrite).
# It is a bare string expression, so none of it executes; the live
# implementation follows after the closing quotes. If it were ever revived,
# its `@app.get("/")` handler would redirect "/" to "/" (itself) while Gradio
# is also mounted at "/". Consider deleting it and relying on version control.
# NOTE(review): the trailing ` | |` on each line looks like residue from
# copying the file out of a web viewer — confirm and strip.
"""from fastapi import FastAPI, UploadFile, Form | |
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse | |
import os | |
import shutil | |
from PIL import Image | |
from transformers import ViltProcessor, ViltForQuestionAnswering | |
from gtts import gTTS | |
import torch | |
import tempfile | |
import gradio as gr | |
app = FastAPI() | |
# Load VQA Model | |
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa") | |
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa") | |
def answer_question_from_image(image, question): | |
if image is None or not question.strip(): | |
return "Please upload an image and ask a question.", None | |
# Process with model | |
inputs = vqa_processor(image, question, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = vqa_model(**inputs) | |
predicted_id = outputs.logits.argmax(-1).item() | |
answer = vqa_model.config.id2label[predicted_id] | |
# Generate TTS audio | |
try: | |
tts = gTTS(text=answer) | |
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp: | |
tts.save(tmp.name) | |
audio_path = tmp.name | |
except Exception as e: | |
return f"Answer: {answer}\n\nโ ๏ธ Audio generation error: {e}", None | |
return answer, audio_path | |
def process_image_question(image: Image.Image, question: str): | |
answer, audio_path = answer_question_from_image(image, question) | |
return answer, audio_path | |
gui = gr.Interface( | |
fn=process_image_question, | |
inputs=[ | |
gr.Image(type="pil", label="Upload Image"), | |
gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question") | |
], | |
outputs=[ | |
gr.Textbox(label="Answer", lines=5), | |
gr.Audio(label="Answer (Audio)", type="filepath") | |
], | |
title="๐ง Image QA with Voice", | |
description="Upload an image and ask a question. You'll get a text + spoken answer." | |
) | |
app = gr.mount_gradio_app(app, gui, path="/") | |
@app.get("/") | |
def home(): | |
return RedirectResponse(url="/") """ | |
from fastapi import FastAPI | |
from fastapi.responses import RedirectResponse | |
import tempfile | |
import torch | |
from PIL import Image | |
from gtts import gTTS | |
import gradio as gr | |
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM | |
app = FastAPI()

# Hugging Face Hub checkpoint ids, named once instead of repeating the literal
# for each processor/model pair.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
_REWRITER_CHECKPOINT = "EleutherAI/gpt-neo-125M"

# Models are loaded once at import time; the request handlers below read these
# module-level objects instead of reloading per call.
# ViLT: predicts a short answer label for an (image, question) pair.
vqa_processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
# GPT-Neo: causal LM used to expand the short label into a sentence.
gpt_tokenizer = AutoTokenizer.from_pretrained(_REWRITER_CHECKPOINT)
gpt_model = AutoModelForCausalLM.from_pretrained(_REWRITER_CHECKPOINT)
# Rewrite answer to human-like sentence
def rewrite_answer(question: str, short_answer: str) -> str:
    """Expand the VQA model's short answer into a complete sentence with GPT-Neo.

    Args:
        question: The user's question about the image.
        short_answer: The answer label predicted by the VQA model.

    Returns:
        The first non-empty line of text GPT-Neo generates after the prompt,
        stripped of surrounding whitespace; ``short_answer`` unchanged when the
        model generates nothing usable (so downstream text-to-speech always
        receives non-empty text).
    """
    prompt = (
        f"Question: {question}\n"
        f"Answer: {short_answer}\n"
        "Rewrite the answer into a complete sentence:"
    )
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=40,
            # Greedy decoding. `temperature` only applies when sampling, so it
            # is not passed (passing it with do_sample=False is ignored and
            # makes transformers emit a warning).
            do_sample=False,
            pad_token_id=gpt_tokenizer.eos_token_id,  # use EOS as the padding id
        )

    # The generated sequence starts with the prompt tokens; decode only the
    # continuation instead of splitting the full text on the instruction
    # string (which misfires if the question itself contains that text).
    prompt_length = inputs["input_ids"].shape[-1]
    continuation = gpt_tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )

    # Greedy generation runs for up to 40 new tokens and may continue past the
    # sentence (e.g. into another "Question:" line); keep the first line only.
    for line in continuation.splitlines():
        sentence = line.strip()
        if sentence:
            return sentence
    return short_answer
def _predict_short_answer(image, question):
    """Return the ViLT answer label with the highest logit for (image, question)."""
    inputs = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        logits = vqa_model(**inputs).logits
    return vqa_model.config.id2label[logits.argmax(-1).item()]


def _synthesize_speech(text):
    """Speak ``text`` with gTTS into a new ``.mp3`` temp file and return its path.

    The file is created with ``delete=False`` so the path stays valid after
    this function returns (the Gradio Audio output uses ``type="filepath"``).
    If synthesis raises, the temp file is removed before the error propagates.
    """
    import contextlib  # local imports: the file header does not import these
    import os

    tts = gTTS(text=text)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
        audio_path = tmp.name
    try:
        tts.save(audio_path)
    except Exception:
        # Best-effort cleanup so failed requests don't leave empty files behind.
        with contextlib.suppress(OSError):
            os.remove(audio_path)
        raise
    return audio_path


def answer_question_from_image(image: Image.Image, question: str):
    """Answer ``question`` about ``image`` as text and speech.

    Pipeline: ViLT predicts a short answer label, ``rewrite_answer`` expands it
    into a sentence, and gTTS renders that sentence to an MP3 file.

    Returns:
        ``(message, audio_path)``. On success, ``message`` is the sentence and
        ``audio_path`` the MP3 path. ``audio_path`` is ``None`` when the image
        or question is missing (``message`` asks for them) or when speech
        synthesis raises (``message`` is the sentence plus the error text).
    """
    # `not question` also covers None, which would crash on `.strip()`.
    if image is None or not question or not question.strip():
        return "❌ Please upload an image and type a question.", None

    # Normalize to 3-channel RGB so RGBA / palette / grayscale images from
    # direct callers reach the processor in a standard mode.
    short_answer = _predict_short_answer(image.convert("RGB"), question)
    full_sentence = rewrite_answer(question, short_answer)

    try:
        audio_path = _synthesize_speech(full_sentence)
    except Exception as e:
        # Degrade to a text-only answer rather than failing the request.
        return f"{full_sentence}\n\n⚠️ Audio generation error: {e}", None
    return full_sentence, audio_path
# Gradio Interface: image + question in, sentence + spoken audio out.
interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[
        gr.Image(type="pil", label="🖼️ Upload Image"),
        gr.Textbox(
            lines=2,
            placeholder="Ask a question about the image",
            label="❓ Question",
        ),
    ],
    outputs=[
        gr.Textbox(label="💬 Answer"),
        # type="filepath": the handler returns the path of an MP3 it wrote.
        gr.Audio(label="🔊 Voice Output", type="filepath"),
    ],
    title="🧠 Image QA with Voice (VQA + GPT-Neo)",
    description="Ask a question about an image and get a full sentence answer, including audio!",
)

# Serve the Gradio UI at the FastAPI app's root path.
app = gr.mount_gradio_app(app, interface, path="/")
def home():
    """Build a redirect response that points at the site root ("/").

    NOTE(review): no route decorator is attached, so FastAPI never calls this
    function. The Gradio UI is mounted at "/" via ``gr.mount_gradio_app``;
    registering this handler on "/" would make it redirect to itself.
    """
    root_url = "/"
    return RedirectResponse(url=root_url)