File size: 2,024 Bytes
f94fa3b
 
 
 
fac31c8
f94fa3b
fac31c8
f94fa3b
fac31c8
f94fa3b
607327a
1e83db4
a74f8b0
f94fa3b
 
 
cf9a79a
 
f94fa3b
 
 
cf9a79a
f94fa3b
 
 
 
 
 
6dfac5c
f94fa3b
 
fac31c8
 
 
 
f94fa3b
 
12d05c0
f94fa3b
 
 
 
 
 
12d05c0
fac31c8
 
 
 
 
 
 
 
f94fa3b
fac31c8
 
 
f94fa3b
5b4fc38
 
fac31c8
3e87c53
5b4fc38
fac31c8
f94fa3b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
import os
import shutil
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
from gtts import gTTS
import torch
import tempfile
import gradio as gr

# ASGI application; the Gradio interface is attached to it further down.
app = FastAPI()

# Visual question answering components, created once at import time so that
# every request reuses the same processor and model objects.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
vqa_processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)


def answer_question_from_image(image, question):
    """Answer a question about an image and render the answer as speech.

    Args:
        image: Image handed to ``vqa_processor`` (the Gradio UI in this file
            supplies a PIL image). ``None`` means nothing was uploaded.
        question: Question text. ``None``, empty, or whitespace-only values
            are treated as "no question asked".

    Returns:
        A ``(text, audio_path)`` tuple:

        * input missing -> a prompt asking for an image and a question, and
          ``None`` for the audio path;
        * success -> the predicted answer label and the path of an ``.mp3``
          temp file holding the spoken answer (the caller owns the file);
        * speech synthesis failed -> the answer plus the error message, and
          ``None`` for the audio path.
    """
    # `question or ""` keeps a None question from crashing on `.strip()`.
    if image is None or not (question or "").strip():
        return "Please upload an image and ask a question.", None

    # Run the VQA model; no gradients are needed for inference.
    inputs = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    answer = vqa_model.config.id2label[predicted_id]

    # Reserve a temp path and close our handle before gTTS opens the file
    # itself; writing to a path that is still held open is not portable.
    fd, audio_path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)

    try:
        gTTS(text=answer).save(audio_path)
    except Exception as e:
        # Best-effort boundary: still return the text answer, but do not
        # leave an empty or partial temp file behind.
        try:
            os.remove(audio_path)
        except OSError:
            pass
        return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None

    return answer, audio_path


def process_image_question(image: Image.Image, question: str):
    """Gradio callback: forward the inputs to ``answer_question_from_image``.

    Returns the same ``(answer_text, audio_file_path)`` pair that the helper
    produces, in that order, to match the interface's two output components.
    """
    result = answer_question_from_image(image, question)
    text_reply, speech_file = result
    return text_reply, speech_file


# Gradio front end: one image plus one question in, text and audio answer out.
gui = gr.Interface(
    fn=process_image_question,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
    ],
    outputs=[
        gr.Textbox(label="Answer", lines=5),
        gr.Audio(label="Answer (Audio)", type="filepath")
    ],
    title="🧠 Image QA with Voice",
    description="Upload an image and ask a question. You'll get a text + spoken answer."
)

# Serve the UI under its own prefix. Mounting it at "/" would claim every
# path, including the root route below, and a root route that redirects to
# "/" would only ever redirect to itself.
app = gr.mount_gradio_app(app, gui, path="/gradio")


@app.get("/")
def home():
    """Redirect the site root to the mounted Gradio UI."""
    return RedirectResponse(url="/gradio")