Summarization / app.py
ikraamkb's picture
Update app.py
ffd57c2 verified
raw
history blame
4.85 kB
"""from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
import os
import shutil
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
from gtts import gTTS
import torch
import tempfile
import gradio as gr
app = FastAPI()
# Load VQA Model
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
def answer_question_from_image(image, question):
if image is None or not question.strip():
return "Please upload an image and ask a question.", None
# Process with model
inputs = vqa_processor(image, question, return_tensors="pt")
with torch.no_grad():
outputs = vqa_model(**inputs)
predicted_id = outputs.logits.argmax(-1).item()
answer = vqa_model.config.id2label[predicted_id]
# Generate TTS audio
try:
tts = gTTS(text=answer)
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
tts.save(tmp.name)
audio_path = tmp.name
except Exception as e:
return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None
return answer, audio_path
def process_image_question(image: Image.Image, question: str):
answer, audio_path = answer_question_from_image(image, question)
return answer, audio_path
gui = gr.Interface(
fn=process_image_question,
inputs=[
gr.Image(type="pil", label="Upload Image"),
gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
],
outputs=[
gr.Textbox(label="Answer", lines=5),
gr.Audio(label="Answer (Audio)", type="filepath")
],
title="🧠 Image QA with Voice",
description="Upload an image and ask a question. You'll get a text + spoken answer."
)
app = gr.mount_gradio_app(app, gui, path="/")
@app.get("/")
def home():
return RedirectResponse(url="/") """
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
import tempfile
import torch
from PIL import Image
from gtts import gTTS
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
# FastAPI application that the Gradio UI is later mounted onto.
app = FastAPI()

# Hugging Face checkpoint identifiers. Every model below is instantiated once,
# at import time, and then shared by all requests.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
_REWRITER_CHECKPOINT = "EleutherAI/gpt-neo-125M"

# Visual question answering: the processor prepares (image, question) pairs,
# the model scores the candidate answer labels.
vqa_processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)

# Causal language model used to turn the short VQA label into a sentence.
gpt_tokenizer = AutoTokenizer.from_pretrained(_REWRITER_CHECKPOINT)
gpt_model = AutoModelForCausalLM.from_pretrained(_REWRITER_CHECKPOINT)
# Rewrite answer to human-like sentence
def rewrite_answer(question: str, short_answer: str) -> str:
    """Expand a terse VQA answer into a complete sentence with GPT-Neo.

    A three-line prompt (question, short answer, rewrite instruction) is fed
    to the language model, up to 40 new tokens are generated greedily, and
    the text following the instruction line is returned.

    Args:
        question: The question the user asked about the image.
        short_answer: The answer label predicted by the VQA model.

    Returns:
        The generated continuation with surrounding whitespace removed, or
        ``short_answer`` unchanged when the model produced no text after the
        instruction line.
    """
    marker = "Rewrite the answer into a complete sentence:"
    prompt = f"Question: {question}\nAnswer: {short_answer}\n{marker}"
    encoded = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        generated = gpt_model.generate(
            **encoded,
            max_new_tokens=40,
            # Greedy decoding. Sampling-only knobs such as ``temperature`` are
            # ignored in this mode and merely trigger a warning, so none is
            # passed here.
            do_sample=False,
            # GPT-Neo defines no pad token; reuse EOS for padding.
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    decoded = gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the last
    # occurrence of the instruction line (the whole text if it is absent).
    rewritten = decoded.rpartition(marker)[2].strip()
    # An empty continuation is useless to the caller (and to text-to-speech),
    # so fall back to the raw VQA answer.
    return rewritten or short_answer
def answer_question_from_image(image: Image.Image, question: str):
    """Answer a question about an image and synthesize the answer as speech.

    The ViLT model predicts a short answer label, ``rewrite_answer`` turns it
    into a sentence, and gTTS converts that sentence into an MP3 file.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.
        question: The user's question; ``None`` or a blank string counts as
            missing.

    Returns:
        A ``(text, audio_path)`` tuple. ``audio_path`` is the path of a
        temporary ``.mp3`` file, or ``None`` when the input was incomplete or
        speech synthesis failed (in which case ``text`` also carries the
        error message).
    """
    import os  # local import: only needed for temp-file cleanup below

    # ``not question`` guards against None before ``.strip()`` is called.
    if image is None or not question or not question.strip():
        return "❗ Please upload an image and type a question.", None

    inputs = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    short_answer = vqa_model.config.id2label[predicted_id]

    # Fall back to the raw label if the rewriter yields nothing, since gTTS
    # refuses empty text.
    full_sentence = rewrite_answer(question, short_answer) or short_answer

    audio_path = None
    try:
        tts = gTTS(text=full_sentence)
        # delete=False: the file must outlive this function so the caller can
        # serve it. Writing through the open handle (instead of re-opening the
        # file by name) also works on platforms that forbid a second open.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            audio_path = tmp.name
            tts.write_to_fp(tmp)
    except Exception as e:
        # Deliberately broad: audio is best-effort and the text answer should
        # still be shown. Remove the partially written file so it is not leaked.
        if audio_path is not None:
            try:
                os.remove(audio_path)
            except OSError:
                pass
        return f"{full_sentence}\n\n⚠️ Audio generation error: {e}", None
    return full_sentence, audio_path
# Gradio Interface
# Two inputs (image + question) feed answer_question_from_image; its
# (text, audio_path) result fills the two outputs in order.
interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[
        # type="pil" hands the callback a PIL image rather than an array/path.
        gr.Image(type="pil", label="🖼️ Upload Image"),
        gr.Textbox(
            lines=2,
            placeholder="Ask a question about the image",
            label="❓ Question",
        ),
    ],
    outputs=[
        gr.Textbox(label="💬 Answer"),
        # type="filepath": the callback returns a path to the MP3 (or None).
        gr.Audio(label="🔊 Voice Output", type="filepath"),
    ],
    title="🧠 Image QA with Voice (VQA + GPT-Neo)",
    description="Ask a question about an image and get a full sentence answer, including audio!",
)
# Attach the Gradio UI to the FastAPI app at the root path. The value returned
# by mount_gradio_app is rebound to ``app``, so the route below is registered
# on that returned object.
app = gr.mount_gradio_app(app, interface, path="/")


# NOTE(review): this route lives on "/" -- the same path the Gradio UI is
# mounted on -- and redirects to "/" itself. Presumably the mount registered
# above answers "/" first, so this handler is never reached; if it were
# reached, the redirect would loop. Confirm against the routing order and
# consider removing it or pointing it at a distinct path.
@app.get("/")
def home():
    """Redirect a GET request for ``/`` to ``/``."""
    return RedirectResponse(url="/")