from fastapi import FastAPI
from fastapi.responses import RedirectResponse
import numpy as np
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
from gtts import gTTS
import easyocr
import torch
import tempfile
import gradio as gr

app = FastAPI()

# Load VQA Model
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
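# ViLT casts VQA as classification over a fixed answer vocabulary
# (~3,100 answers for this checkpoint); id2label maps logits back to strings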

# Load image captioning model
captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Load EasyOCR reader (English, French, Arabic); use the GPU only when available
reader = easyocr.Reader(['en', 'fr', 'ar'], gpu=torch.cuda.is_available())

def classify_question(question: str):
    question_lower = question.lower()
    if any(word in question_lower for word in ["text", "say", "written", "read"]):
        return "ocr"
    elif any(word in question_lower for word in ["caption", "describe", "what is in the image"]):
        return "caption"
    else:
        return "vqa"

def answer_question_from_image(image, question):
    if image is None or not question.strip():
        return "Please upload an image and ask a question.", None

    mode = classify_question(question)

    if mode == "ocr":
        try:
            # EasyOCR accepts paths/arrays/bytes, not PIL images, so convert first
            result = reader.readtext(np.array(image))
            text = " ".join([entry[1] for entry in result])
            answer = text.strip() or "No readable text found."
        except Exception as e:
            answer = f"OCR Error: {e}"

    elif mode == "caption":
        try:
            answer = captioner(image)[0]['generated_text']
        except Exception as e:
            answer = f"Captioning error: {e}"

    else:
        try:
            # Convert to RGB so RGBA/grayscale uploads don't break the processor
            inputs = vqa_processor(image.convert("RGB"), question, return_tensors="pt")
            with torch.no_grad():
                outputs = vqa_model(**inputs)
            predicted_id = outputs.logits.argmax(-1).item()
            answer = vqa_model.config.id2label[predicted_id]
        except Exception as e:
            answer = f"VQA error: {e}"

    try:
        tts = gTTS(text=answer)
        # Reserve a temp filename, then write after the handle is closed
        # (saving to an open NamedTemporaryFile by name fails on Windows)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            audio_path = tmp.name
        tts.save(audio_path)
    except Exception as e:
        return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None

    return answer, audio_path
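# Example call (assumes a local file "sample.jpg"; not part of the app flow):
#   img = Image.open("sample.jpg").convert("RGB")
#   text, audio = answer_question_from_image(img, "What is written here?")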

def process_image_question(image: Image.Image, question: str):
    answer, audio_path = answer_question_from_image(image, question)
    return answer, audio_path

gui = gr.Interface(
    fn=process_image_question,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
    ],
    outputs=[
        gr.Textbox(label="Answer", lines=5),
        gr.Audio(label="Answer (Audio)", type="filepath")
    ],
    title="🧠 Image QA with Voice",
    description="Upload an image and ask a question. Works for OCR, captioning, and VQA."
)

# Mount the Gradio UI under /gradio so the root route below can redirect to it
# (mounting at "/" and redirecting "/" back to "/" would loop or shadow the route)
app = gr.mount_gradio_app(app, gui, path="/gradio")

@app.get("/")
def home():
    return RedirectResponse(url="/gradio")
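
# To serve the app (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# The Gradio UI lives at /gradio; requests to / are redirected there.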