"""from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
import os
import shutil
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
from gtts import gTTS
import torch
import tempfile
import gradio as gr



app = FastAPI()

# Load VQA Model
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")


def answer_question_from_image(image, question):
    if image is None or not question.strip():
        return "Please upload an image and ask a question.", None

    # Process with model
    inputs = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    answer = vqa_model.config.id2label[predicted_id]

    # Generate TTS audio
    try:
        tts = gTTS(text=answer)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            tts.save(tmp.name)
            audio_path = tmp.name
    except Exception as e:
        return f"Answer: {answer}\n\nโš ๏ธ Audio generation error: {e}", None

    return answer, audio_path


def process_image_question(image: Image.Image, question: str):
    answer, audio_path = answer_question_from_image(image, question)
    return answer, audio_path


gui = gr.Interface(
    fn=process_image_question,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
    ],
    outputs=[
        gr.Textbox(label="Answer", lines=5),
        gr.Audio(label="Answer (Audio)", type="filepath")
    ],
    title="๐Ÿง  Image QA with Voice",
    description="Upload an image and ask a question. You'll get a text + spoken answer."
)

app = gr.mount_gradio_app(app, gui, path="/")

@app.get("/")
def home():
    return RedirectResponse(url="/") """
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
import tempfile
import torch
from PIL import Image
from gtts import gTTS
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

# Load Models
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
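
# Optional sketch (an assumption, not part of the original app): run inference on
# GPU when one is available. Uncommenting this also requires moving each `inputs`
# tensor to the same device inside the functions below, e.g.
# `inputs = {k: v.to(device) for k, v in inputs.items()}`.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# vqa_model.to(device)
# gpt_model.to(device)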

# Rewrite the short VQA label into a complete, human-like sentence with GPT-Neo
def rewrite_answer(question: str, short_answer: str) -> str:
    prompt = f"Question: {question}\nAnswer: {short_answer}\nRewrite the answer into a complete sentence:"
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=40,
            do_sample=False,  # greedy decoding, so sampling knobs like temperature would be ignored
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    result = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the instruction line; greedy decoding often keeps
    # generating, so also cut at the first newline.
    sentence = result.split("Rewrite the answer into a complete sentence:")[-1].strip()
    return sentence.splitlines()[0].strip() if sentence else short_answer
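
# Hypothetical example: rewrite_answer("What color is the car?", "red") would
# typically return something like "The car is red.", though the exact wording
# depends on the GPT-Neo checkpoint and is not guaranteed.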

def answer_question_from_image(image: Image.Image, question: str):
    if image is None or not question.strip():
        return "โ— Please upload an image and type a question.", None

    inputs = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    short_answer = vqa_model.config.id2label[predicted_id]

    full_sentence = rewrite_answer(question, short_answer)

    try:
        tts = gTTS(text=full_sentence)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            tts.save(tmp.name)
            audio_path = tmp.name
    except Exception as e:
        return f"{full_sentence}\n\nโš ๏ธ Audio generation error: {e}", None

    return full_sentence, audio_path
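
# Example usage (hypothetical; assumes a local file "photo.jpg" exists and that
# gTTS can reach Google's TTS endpoint, since it is an online service):
#   img = Image.open("photo.jpg")
#   text, audio = answer_question_from_image(img, "How many people are there?")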

# Gradio Interface
interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[
        gr.Image(type="pil", label="๐Ÿ–ผ๏ธ Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask a question about the image", label="โ“ Question")
    ],
    outputs=[
        gr.Textbox(label="๐Ÿ’ฌ Answer"),
        gr.Audio(label="๐Ÿ”Š Voice Output", type="filepath")
    ],
    title="๐Ÿง  Image QA with Voice (VQA + GPT-Neo)",
    description="Ask a question about an image and get a full sentence answer, including audio!"
)

# Mount the Gradio UI under /gradio and redirect the root path to it. Mounting
# Gradio at "/" would shadow the route below, and redirecting "/" to itself
# would loop forever.
app = gr.mount_gradio_app(app, interface, path="/gradio")

@app.get("/")
def home():
    return RedirectResponse(url="/gradio")
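
# Minimal local launcher (an assumption: most hosts, e.g. Hugging Face Spaces,
# start the ASGI app themselves, so this block is only needed for `python app.py`):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)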