ikraamkb committed (verified)
Commit: 974f8bb · 1 Parent(s): bdeddcb

Update app.py

Files changed (1): app.py (+22 -5)
app.py CHANGED
@@ -70,7 +70,7 @@ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
 import os
 import shutil
 from PIL import Image
-from transformers import ViltProcessor, ViltForQuestionAnswering
+from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
 from gtts import gTTS
 import torch
 import tempfile
@@ -82,25 +82,42 @@ app = FastAPI()
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
+# Load GPT model for rewriting short answers
+gpt_rewriter = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B")
+
+def rewrite_answer(question: str, short_answer: str):
+    prompt = f"Q: {question}\nA: {short_answer}\n\nRespond with a full sentence:"
+    try:
+        result = gpt_rewriter(prompt, max_length=50, do_sample=False)
+        full_sentence = result[0]['generated_text'].split("Respond with a full sentence:")[-1].strip()
+        return full_sentence
+    except Exception as e:
+        return short_answer  # fallback
+
 def answer_question_from_image(image, question):
     if image is None or not question.strip():
         return "Please upload an image and ask a question.", None
 
+    # Process with model
     inputs = vqa_processor(image, question, return_tensors="pt")
     with torch.no_grad():
         outputs = vqa_model(**inputs)
     predicted_id = outputs.logits.argmax(-1).item()
     short_answer = vqa_model.config.id2label[predicted_id]
 
+    # Rewrite short answer using GPT
+    full_answer = rewrite_answer(question, short_answer)
+
+    # Generate TTS audio
     try:
-        tts = gTTS(text=short_answer)
+        tts = gTTS(text=full_answer)
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
             tts.save(tmp.name)
             audio_path = tmp.name
     except Exception as e:
-        return f"Answer: {short_answer}\n\n⚠️ Audio generation error: {e}", None
+        return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
 
-    return short_answer, audio_path
+    return full_answer, audio_path
 
 def process_image_question(image: Image.Image, question: str):
     answer, audio_path = answer_question_from_image(image, question)
@@ -117,7 +134,7 @@ gui = gr.Interface(
         gr.Audio(label="Answer (Audio)", type="filepath")
     ],
     title="🧠 Image QA with Voice",
-    description="Upload an image and ask a question. You'll get an answer spoken out loud."
+    description="Upload an image and ask a question. You'll get a detailed text + spoken answer."
 )
 
 app = gr.mount_gradio_app(app, gui, path="/")
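
For a quick local check of the new answer flow outside the FastAPI/Gradio stack, a sketch along these lines could be used. The module name `app`, the file `sample.jpg`, and the example question are assumptions for illustration, not part of the commit; importing the module loads both the ViLT and GPT-Neo weights as a side effect.

# Minimal local smoke test of the updated flow (illustrative sketch only).
# Assumptions not in the commit: app.py is importable as `app` from the
# working directory, and a local image `sample.jpg` exists.
from PIL import Image

import app

image = Image.open("sample.jpg").convert("RGB")
question = "What color is the car?"

# answer_question_from_image() now runs ViLT, expands the one-word label
# via rewrite_answer() (falling back to the short label on error), and
# returns the sentence plus a path to the gTTS mp3 (or None on TTS failure).
answer, audio_path = app.answer_question_from_image(image, question)
print("Answer:", answer)
print("Audio:", audio_path)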