ikraamkb committed on
Commit
ffd57c2
·
verified ·
1 Parent(s): c8db168

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -39
app.py CHANGED
@@ -65,46 +65,43 @@ app = gr.mount_gradio_app(app, gui, path="/")
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
- from fastapi import FastAPI, UploadFile, Form
69
- from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
70
- import os
71
- import shutil
72
  from PIL import Image
73
- from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
74
  from gtts import gTTS
75
- import torch
76
- import tempfile
77
  import gradio as gr
 
78
 
79
  app = FastAPI()
80
 
81
- # Load VQA Model
82
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
83
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
84
 
85
- # Load GPT model to rewrite answers
86
- # Replacing Falcon-7B-Instruct with GPT-Neo-125M
87
  gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
88
  gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
89
 
90
- def rewrite_answer(question, short_answer):
91
- prompt = f"Question: {question}\nShort Answer: {short_answer}\nFull Sentence:" # Simpler prompt for GPT-Neo
 
92
  inputs = gpt_tokenizer(prompt, return_tensors="pt")
93
  with torch.no_grad():
94
  outputs = gpt_model.generate(
95
  **inputs,
96
- max_new_tokens=50,
97
- do_sample=True,
98
- top_p=0.9,
99
- temperature=0.8,
100
- pad_token_id=gpt_tokenizer.eos_token_id
101
  )
102
- full = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
103
- return full.split("Full Sentence:")[-1].strip()
 
104
 
105
- def answer_question_from_image(image, question):
106
  if image is None or not question.strip():
107
- return "Please upload an image and ask a question.", None
108
 
109
  inputs = vqa_processor(image, question, return_tensors="pt")
110
  with torch.no_grad():
@@ -112,38 +109,34 @@ def answer_question_from_image(image, question):
112
  predicted_id = outputs.logits.argmax(-1).item()
113
  short_answer = vqa_model.config.id2label[predicted_id]
114
 
115
- # Rewrite short answer to full sentence
116
- full_answer = rewrite_answer(question, short_answer)
117
 
118
  try:
119
- tts = gTTS(text=full_answer)
120
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
121
  tts.save(tmp.name)
122
  audio_path = tmp.name
123
  except Exception as e:
124
- return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
125
 
126
- return full_answer, audio_path
127
 
128
- def process_image_question(image: Image.Image, question: str):
129
- answer, audio_path = answer_question_from_image(image, question)
130
- return answer, audio_path
131
-
132
- gui = gr.Interface(
133
- fn=process_image_question,
134
  inputs=[
135
- gr.Image(type="pil", label="Upload Image"),
136
- gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
137
  ],
138
  outputs=[
139
- gr.Textbox(label="Answer", lines=5),
140
- gr.Audio(label="Answer (Audio)", type="filepath")
141
  ],
142
- title="🧠 Image QA with Voice",
143
- description="Upload an image and ask a question. You'll get a full-sentence spoken answer."
144
  )
145
 
146
- app = gr.mount_gradio_app(app, gui, path="/")
147
 
148
  @app.get("/")
149
  def home():
 
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
+ from fastapi import FastAPI
69
+ from fastapi.responses import RedirectResponse
70
+ import tempfile
71
+ import torch
72
  from PIL import Image
 
73
  from gtts import gTTS
 
 
74
  import gradio as gr
75
+ from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
76
 
77
  app = FastAPI()
78
 
79
+ # Load Models
80
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
81
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
82
 
 
 
83
  gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
84
  gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
85
 
86
+ # Rewrite answer to human-like sentence
87
+ def rewrite_answer(question: str, short_answer: str) -> str:
88
+ prompt = f"Question: {question}\nAnswer: {short_answer}\nRewrite the answer into a complete sentence:"
89
  inputs = gpt_tokenizer(prompt, return_tensors="pt")
90
  with torch.no_grad():
91
  outputs = gpt_model.generate(
92
  **inputs,
93
+ max_new_tokens=40,
94
+ do_sample=False,
95
+ pad_token_id=gpt_tokenizer.eos_token_id,
96
+ temperature=0.7,
 
97
  )
98
+ result = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
99
+ # Extract only the sentence after the "Rewrite..." line
100
+ return result.split("Rewrite the answer into a complete sentence:")[-1].strip()
101
 
102
+ def answer_question_from_image(image: Image.Image, question: str):
103
  if image is None or not question.strip():
104
+ return "Please upload an image and type a question.", None
105
 
106
  inputs = vqa_processor(image, question, return_tensors="pt")
107
  with torch.no_grad():
 
109
  predicted_id = outputs.logits.argmax(-1).item()
110
  short_answer = vqa_model.config.id2label[predicted_id]
111
 
112
+ full_sentence = rewrite_answer(question, short_answer)
 
113
 
114
  try:
115
+ tts = gTTS(text=full_sentence)
116
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
117
  tts.save(tmp.name)
118
  audio_path = tmp.name
119
  except Exception as e:
120
+ return f"{full_sentence}\n\n⚠️ Audio generation error: {e}", None
121
 
122
+ return full_sentence, audio_path
123
 
124
+ # Gradio Interface
125
+ interface = gr.Interface(
126
+ fn=answer_question_from_image,
 
 
 
127
  inputs=[
128
+ gr.Image(type="pil", label="🖼️ Upload Image"),
129
+ gr.Textbox(lines=2, placeholder="Ask a question about the image", label="Question")
130
  ],
131
  outputs=[
132
+ gr.Textbox(label="💬 Answer"),
133
+ gr.Audio(label="🔊 Voice Output", type="filepath")
134
  ],
135
+ title="🧠 Image QA with Voice (VQA + GPT-Neo)",
136
+ description="Ask a question about an image and get a full sentence answer, including audio!"
137
  )
138
 
139
+ app = gr.mount_gradio_app(app, interface, path="/")
140
 
141
  @app.get("/")
142
  def home():