ikraamkb committed on
Commit
cb83f1d
·
verified ·
1 Parent(s): 680d4a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -17
app.py CHANGED
@@ -65,15 +65,13 @@ app = gr.mount_gradio_app(app, gui, path="/")
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
- from fastapi import FastAPI, UploadFile, Form
69
- from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
70
- import os
71
- import shutil
72
  from PIL import Image
 
73
  from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
74
  from gtts import gTTS
75
- import torch
76
- import tempfile
77
  import gradio as gr
78
 
79
  app = FastAPI()
@@ -86,19 +84,30 @@ vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetune
86
  gpt_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
87
  gpt_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
88
 
89
-
90
- def rewrite_answer(question):
91
- prompt = f"{question}\nAnswer with a full sentence:"
 
 
 
92
  inputs = gpt_tokenizer(prompt, return_tensors="pt")
93
  with torch.no_grad():
94
  outputs = gpt_model.generate(
95
  **inputs,
96
- max_new_tokens=40,
97
- do_sample=False,
 
 
98
  pad_token_id=gpt_tokenizer.eos_token_id
99
  )
 
100
  generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
101
- rewritten = generated.split(":")[-1].strip()
 
 
 
 
 
102
  return rewritten
103
 
104
  def answer_question_from_image(image, question):
@@ -111,16 +120,17 @@ def answer_question_from_image(image, question):
111
  predicted_id = outputs.logits.argmax(-1).item()
112
  short_answer = vqa_model.config.id2label[predicted_id]
113
 
114
- # Rewrite short answer to full sentence with GPT-Neo
115
- full_answer = rewrite_answer(f"Question: {question}\nAnswer: {short_answer}")
116
 
 
117
  try:
118
  tts = gTTS(text=full_answer)
119
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
120
  tts.save(tmp.name)
121
  audio_path = tmp.name
122
  except Exception as e:
123
- return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
124
 
125
  return full_answer, audio_path
126
 
@@ -128,6 +138,7 @@ def process_image_question(image: Image.Image, question: str):
128
  answer, audio_path = answer_question_from_image(image, question)
129
  return answer, audio_path
130
 
 
131
  gui = gr.Interface(
132
  fn=process_image_question,
133
  inputs=[
@@ -139,11 +150,12 @@ gui = gr.Interface(
139
  gr.Audio(label="Answer (Audio)", type="filepath")
140
  ],
141
  title="🧠 Image QA with Voice",
142
- description="Upload an image and ask a question. You'll get a full-sentence spoken answer."
143
  )
144
 
 
145
  app = gr.mount_gradio_app(app, gui, path="/")
146
 
147
  @app.get("/")
148
  def home():
149
- return RedirectResponse(url="/")
 
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
+ from fastapi import FastAPI
69
+ from fastapi.responses import RedirectResponse
70
+ import tempfile
 
71
  from PIL import Image
72
+ import torch
73
  from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
74
  from gtts import gTTS
 
 
75
  import gradio as gr
76
 
77
  app = FastAPI()
 
84
  gpt_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
85
  gpt_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
86
 
87
def rewrite_answer(question, short_answer):
    """Expand a short VQA answer into a full-sentence answer with distilgpt2.

    Args:
        question: The user's original question about the image.
        short_answer: The short label predicted by the VQA model.

    Returns:
        A full-sentence answer string. Falls back to the short answer if
        the language model produces no usable continuation, so callers
        (e.g. gTTS) never receive an empty string.
    """
    # Single source of truth for the instruction; it is also the split
    # marker used to strip the echoed prompt from the generation below.
    marker = "Now write a full sentence answering the question:"
    prompt = (
        f"Question: {question}\n"
        f"Short Answer: {short_answer}\n"
        f"{marker}"
    )
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            # distilgpt2 has no pad token; reuse EOS as is conventional.
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Decoder-only models echo the prompt; keep only what follows the marker.
    if marker in generated:
        rewritten = generated.split(marker)[-1].strip()
    else:
        rewritten = generated.strip()

    # Sampling can yield an empty continuation; guard so downstream TTS
    # never gets empty text (gTTS raises on "").
    return rewritten or f"{short_answer}."
112
 
113
  def answer_question_from_image(image, question):
 
120
  predicted_id = outputs.logits.argmax(-1).item()
121
  short_answer = vqa_model.config.id2label[predicted_id]
122
 
123
+ # Rewrite to human-like sentence
124
+ full_answer = rewrite_answer(question, short_answer)
125
 
126
+ # Convert to speech
127
  try:
128
  tts = gTTS(text=full_answer)
129
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
130
  tts.save(tmp.name)
131
  audio_path = tmp.name
132
  except Exception as e:
133
+ return f"{full_answer}\n\n⚠️ Audio generation error: {e}", None
134
 
135
  return full_answer, audio_path
136
 
 
138
  answer, audio_path = answer_question_from_image(image, question)
139
  return answer, audio_path
140
 
141
+ # Gradio UI
142
  gui = gr.Interface(
143
  fn=process_image_question,
144
  inputs=[
 
150
  gr.Audio(label="Answer (Audio)", type="filepath")
151
  ],
152
  title="🧠 Image QA with Voice",
153
+ description="Upload an image and ask a question. You'll get a human-like spoken answer."
154
  )
155
 
156
+ # Mount on FastAPI
157
  app = gr.mount_gradio_app(app, gui, path="/")
158
 
159
# Root route declared after gr.mount_gradio_app(app, gui, path="/").
# NOTE(review): the Gradio mount at "/" is registered first, so this
# handler appears to be shadowed and never reached; if it were reached,
# redirecting "/" back to "/" would loop. Presumably kept as a fallback
# — confirm the intended landing URL with the author.
@app.get("/")
def home():
    return RedirectResponse(url="/")