"""from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
import os
import shutil
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
from gtts import gTTS
import torch
import tempfile
import gradio as gr
app = FastAPI()
# Load VQA Model
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
def answer_question_from_image(image, question):
if image is None or not question.strip():
return "Please upload an image and ask a question.", None
# Process with model
inputs = vqa_processor(image, question, return_tensors="pt")
with torch.no_grad():
outputs = vqa_model(**inputs)
predicted_id = outputs.logits.argmax(-1).item()
answer = vqa_model.config.id2label[predicted_id]
# Generate TTS audio
try:
tts = gTTS(text=answer)
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
tts.save(tmp.name)
audio_path = tmp.name
except Exception as e:
return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None
return answer, audio_path
def process_image_question(image: Image.Image, question: str):
answer, audio_path = answer_question_from_image(image, question)
return answer, audio_path
gui = gr.Interface(
fn=process_image_question,
inputs=[
gr.Image(type="pil", label="Upload Image"),
gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
],
outputs=[
gr.Textbox(label="Answer", lines=5),
gr.Audio(label="Answer (Audio)", type="filepath")
],
title="🧠 Image QA with Voice",
description="Upload an image and ask a question. You'll get a text + spoken answer."
)
app = gr.mount_gradio_app(app, gui, path="/")
@app.get("/")
def home():
return RedirectResponse(url="/") """
from fastapi import FastAPI
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
from gtts import gTTS
import torch
import tempfile
import gradio as gr

app = FastAPI()

# Load VQA Model
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
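# Note: ViLT-VQA is a classifier over a fixed vocabulary of short answers
# (config.id2label maps logit indices to strings like "yes", "2", "red"),
# so its raw prediction is typically only one or two words.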
# Load GPT model to rewrite answers (Phi-1.5)
gpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
gpt_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
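# Note: phi-1_5 is a ~1.3B-parameter causal LM; both models load on CPU in
# float32 by default here, so startup and the first request can be slow.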
def rewrite_answer(question, short_answer):
    prompt = f"Write a full sentence to answer this:\nQ: {question}\nA: {short_answer}\nFull sentence:"
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            top_k=40,
            top_p=0.9,
            temperature=0.6,
            pad_token_id=gpt_tokenizer.eos_token_id
        )
    generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Try to isolate the answer portion
    if "Full sentence:" in generated:
        rewritten = generated.split("Full sentence:")[-1].strip()
    else:
        rewritten = generated.strip()
    # Fall back to basic templating if the model output is empty or too short
    if not rewritten or len(rewritten.split()) < 3:
        rewritten = f"The answer to the question '{question}' is: {short_answer}."
    return rewritten
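# Illustrative behaviour of rewrite_answer (exact wording varies between runs
# because sampling is enabled):
#   rewrite_answer("What color is the cat?", "black")
#   -> e.g. "The cat in the picture is black."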
def answer_question_from_image(image, question):
    if image is None or not question.strip():
        return "Please upload an image and ask a question.", None

    inputs = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    short_answer = vqa_model.config.id2label[predicted_id]

    # Rewrite short answer to full sentence with Phi-1.5
    full_answer = rewrite_answer(question, short_answer)

    try:
        tts = gTTS(text=full_answer)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            tts.save(tmp.name)
            audio_path = tmp.name
    except Exception as e:
        return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None

    return full_answer, audio_path

def process_image_question(image: Image.Image, question: str):
    answer, audio_path = answer_question_from_image(image, question)
    return answer, audio_path
gui = gr.Interface(
    fn=process_image_question,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
    ],
    outputs=[
        gr.Textbox(label="Answer", lines=5),
        gr.Audio(label="Answer (Audio)", type="filepath")
    ],
    title="🧠 Image QA with Voice",
    description="Upload an image and ask a question. You'll get a full-sentence spoken answer."
)
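# gr.Audio(type="filepath") expects a path string, which matches the temporary
# .mp3 path that answer_question_from_image returns.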
# Gradio is mounted at the root path and serves "/" directly; a separate
# @app.get("/") handler redirecting to "/" (as in the previous version) would
# be shadowed by the mount or loop back to itself, so none is defined.
app = gr.mount_gradio_app(app, gui, path="/")
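# Local entry point (a minimal sketch, not part of the original file): Hugging
# Face Spaces conventionally serves FastAPI apps with uvicorn on port 7860;
# adjust host/port for other environments.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)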