Update app.py
app.py CHANGED
@@ -68,7 +68,7 @@ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
 import os
 import shutil
 from PIL import Image
-from transformers import ViltProcessor, ViltForQuestionAnswering,
+from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
 from gtts import gTTS
 import torch
 import tempfile
@@ -80,33 +80,37 @@ app = FastAPI()
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
-# Load GPT model
-
+# Load GPT model to rewrite answers
+gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
 
-def rewrite_answer(question
-    prompt = f"
-
-
-
-
-
-
+def rewrite_answer(question):
+    prompt = f"{question}\nAnswer with a full sentence:"
+    inputs = gpt_tokenizer(prompt, return_tensors="pt")
+    with torch.no_grad():
+        outputs = gpt_model.generate(
+            **inputs,
+            max_new_tokens=40,
+            do_sample=False,
+            pad_token_id=gpt_tokenizer.eos_token_id
+        )
+    generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    rewritten = generated.split(":")[-1].strip()
+    return rewritten
 
 def answer_question_from_image(image, question):
     if image is None or not question.strip():
         return "Please upload an image and ask a question.", None
 
-    # Process with model
     inputs = vqa_processor(image, question, return_tensors="pt")
     with torch.no_grad():
         outputs = vqa_model(**inputs)
     predicted_id = outputs.logits.argmax(-1).item()
     short_answer = vqa_model.config.id2label[predicted_id]
 
-    # Rewrite short answer
-    full_answer = rewrite_answer(question
+    # Rewrite short answer to full sentence with GPT-Neo
+    full_answer = rewrite_answer(f"Question: {question}\nAnswer: {short_answer}")
 
-    # Generate TTS audio
     try:
         tts = gTTS(text=full_answer)
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
@@ -132,7 +136,7 @@ gui = gr.Interface(
         gr.Audio(label="Answer (Audio)", type="filepath")
     ],
     title="🧠 Image QA with Voice",
-    description="Upload an image and ask a question. You'll get a
+    description="Upload an image and ask a question. You'll get a full-sentence spoken answer."
 )
 
 app = gr.mount_gradio_app(app, gui, path="/")
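For a quick sanity check of the new rewrite path outside the Space, here is a minimal standalone sketch, assuming the same checkpoint and decoding settings the diff introduces (the question/answer pair at the bottom is a hypothetical example, not part of app.py):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Same checkpoint the Space now loads at startup (several GB on first download).
gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

def rewrite_answer(question):
    # Prompt format from the new app.py: the caller passes a
    # "Question: ...\nAnswer: ..." block, and the cue asks GPT-Neo
    # to restate the short VQA answer as a full sentence.
    prompt = f"{question}\nAnswer with a full sentence:"
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=40,
            do_sample=False,  # greedy decoding, so output is deterministic
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text includes the prompt itself; keep only what
    # follows the last colon, i.e. the generated sentence.
    return generated.split(":")[-1].strip()

# Hypothetical inputs mirroring the call site in answer_question_from_image().
print(rewrite_answer("Question: what color is the cat?\nAnswer: black"))

One caveat on the post-processing: split(":")[-1] keeps only the text after the last colon anywhere in the output, so a colon inside GPT-Neo's continuation would clip the sentence; splitting once on the literal cue "Answer with a full sentence:" would be a more robust variant.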