ikraamkb committed
Commit b0cc6e9 · verified · 1 Parent(s): c330600

Update app.py

Files changed (1):
  1. app.py +20 -16
app.py CHANGED
@@ -68,7 +68,7 @@ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
 import os
 import shutil
 from PIL import Image
-from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
+from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
 from gtts import gTTS
 import torch
 import tempfile
@@ -80,33 +80,37 @@ app = FastAPI()
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
-# Load GPT model for rewriting short answers
-gpt_rewriter = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B")
+# Load GPT model to rewrite answers
+gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
 
-def rewrite_answer(question: str, short_answer: str):
-    prompt = f"Q: {question}\nA: {short_answer}\n\nRespond with a full sentence:"
-    try:
-        result = gpt_rewriter(prompt, max_length=50, do_sample=False)
-        full_sentence = result[0]['generated_text'].split("Respond with a full sentence:")[-1].strip()
-        return full_sentence
-    except Exception as e:
-        return short_answer  # fallback
+def rewrite_answer(question):
+    prompt = f"{question}\nAnswer with a full sentence:"
+    inputs = gpt_tokenizer(prompt, return_tensors="pt")
+    with torch.no_grad():
+        outputs = gpt_model.generate(
+            **inputs,
+            max_new_tokens=40,
+            do_sample=False,
+            pad_token_id=gpt_tokenizer.eos_token_id
+        )
+    generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    rewritten = generated.split(":")[-1].strip()
+    return rewritten
 
 def answer_question_from_image(image, question):
     if image is None or not question.strip():
         return "Please upload an image and ask a question.", None
 
-    # Process with model
     inputs = vqa_processor(image, question, return_tensors="pt")
     with torch.no_grad():
         outputs = vqa_model(**inputs)
     predicted_id = outputs.logits.argmax(-1).item()
     short_answer = vqa_model.config.id2label[predicted_id]
 
-    # Rewrite short answer using GPT
-    full_answer = rewrite_answer(question, short_answer)
+    # Rewrite short answer to full sentence with GPT-Neo
+    full_answer = rewrite_answer(f"Question: {question}\nAnswer: {short_answer}")
 
-    # Generate TTS audio
     try:
         tts = gTTS(text=full_answer)
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
@@ -132,7 +136,7 @@ gui = gr.Interface(
         gr.Audio(label="Answer (Audio)", type="filepath")
     ],
     title="🧠 Image QA with Voice",
-    description="Upload an image and ask a question. You'll get a detailed text + spoken answer."
+    description="Upload an image and ask a question. You'll get a full-sentence spoken answer."
 )
 
 app = gr.mount_gradio_app(app, gui, path="/")
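
Note: for checking the new rewrite path outside the app, a minimal standalone sketch of the committed rewrite_answer is below. It assumes the same EleutherAI/gpt-neo-1.3B checkpoint the commit uses (a multi-gigabyte download on first run); the __main__ guard and the sample question/answer strings are illustrative only and not part of the commit.

# Minimal sketch of the commit's rewrite path, runnable on its own.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

def rewrite_answer(question):
    # Same prompt shape as the committed code: the caller packs the VQA
    # question and short answer into `question`, then we ask for a sentence.
    prompt = f"{question}\nAnswer with a full sentence:"
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=40,                        # cap the continuation length
            do_sample=False,                          # greedy decoding, deterministic
            pad_token_id=gpt_tokenizer.eos_token_id,  # GPT-Neo has no pad token by default
        )
    generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The committed code keeps only the text after the last ':' in the decoded
    # output (prompt plus continuation); a colon in the continuation would clip it.
    return generated.split(":")[-1].strip()

if __name__ == "__main__":
    # Hypothetical VQA output, used only to exercise the function.
    print(rewrite_answer("Question: What animal is in the picture?\nAnswer: cat"))

Moving from the pipeline() helper to an explicit AutoTokenizer/AutoModelForCausalLM pair is what gives the commit direct control over max_new_tokens and pad_token_id; the trade-off is that the decoded text still contains the prompt, which is why the code strips everything up to the last colon.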