ikraamkb committed
Commit c55ca48 · verified · 1 Parent(s): b7dc4fe

Update app.py

Files changed (1)
  1. app.py +14 -17
app.py CHANGED
@@ -65,26 +65,26 @@ app = gr.mount_gradio_app(app, gui, path="/")
 @app.get("/")
 def home():
     return RedirectResponse(url="/") """
-from fastapi import FastAPI
-from fastapi.responses import RedirectResponse
-import tempfile
+from fastapi import FastAPI, UploadFile, Form
+from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
+import os
+import shutil
 from PIL import Image
-import torch
 from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
 from gtts import gTTS
+import torch
+import tempfile
 import gradio as gr
-from transformers import AutoModelForSeq2SeqLM
+
 app = FastAPI()
 
 # Load VQA Model
 vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
-# Load GPT model to rewrite answers
-
-gpt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
-gpt_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
-
+# Load GPT model to rewrite answers (Phi-1.5)
+gpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
+gpt_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
 
 def rewrite_answer(question, short_answer):
     prompt = (
@@ -98,7 +98,7 @@ def rewrite_answer(question, short_answer):
         **inputs,
         max_new_tokens=50,
         do_sample=True,
-        top_p=0.9,
+        top_p=0.95,
         temperature=0.7,
         pad_token_id=gpt_tokenizer.eos_token_id
     )
@@ -122,17 +122,16 @@ def answer_question_from_image(image, question):
     predicted_id = outputs.logits.argmax(-1).item()
     short_answer = vqa_model.config.id2label[predicted_id]
 
-    # Rewrite to human-like sentence
+    # Rewrite short answer to full sentence with Phi-1.5
     full_answer = rewrite_answer(question, short_answer)
 
-    # Convert to speech
     try:
         tts = gTTS(text=full_answer)
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
             tts.save(tmp.name)
            audio_path = tmp.name
     except Exception as e:
-        return f"{full_answer}\n\n⚠️ Audio generation error: {e}", None
+        return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
 
     return full_answer, audio_path
 
@@ -140,7 +139,6 @@ def process_image_question(image: Image.Image, question: str):
     answer, audio_path = answer_question_from_image(image, question)
     return answer, audio_path
 
-# Gradio UI
 gui = gr.Interface(
     fn=process_image_question,
     inputs=[
@@ -152,10 +150,9 @@ gui = gr.Interface(
         gr.Audio(label="Answer (Audio)", type="filepath")
     ],
     title="🧠 Image QA with Voice",
-    description="Upload an image and ask a question. You'll get a human-like spoken answer."
+    description="Upload an image and ask a question. You'll get a full-sentence spoken answer."
 )
 
-# Mount on FastAPI
 app = gr.mount_gradio_app(app, gui, path="/")
 
 @app.get("/")
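
Note: the rewrite_answer hunks show only the two model-loading lines and the argument list passed to gpt_model.generate(). Below is a minimal sketch of how the changed pieces might fit together after the switch to microsoft/phi-1_5; the prompt wording and the decoding step are assumptions, since they fall outside the diff context.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load GPT model to rewrite answers (Phi-1.5), as in this commit.
gpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
gpt_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")

def rewrite_answer(question, short_answer):
    # Hypothetical prompt: the real wording is not part of the diff.
    prompt = (
        f"Question: {question}\n"
        f"Short answer: {short_answer}\n"
        "Rewrite the short answer as one full, natural sentence:"
    )
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = gpt_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            top_p=0.95,            # raised from 0.9 in this commit
            temperature=0.7,
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    # Causal LMs return the prompt plus the completion; keep only the new tokens.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return gpt_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()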
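
The answer_question_from_image hunk changes only two comments and the error-return string, but its context lines outline the whole pipeline: ViLT picks a short label, the label is rewritten into a sentence, and gTTS turns the sentence into an MP3. A sketch of that surrounding flow, reconstructed from the context lines, follows; the processor call and the placeholder rewrite step are assumptions.

import tempfile

import torch
from PIL import Image
from gtts import gTTS
from transformers import ViltProcessor, ViltForQuestionAnswering

vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

def rewrite_answer(question, short_answer):
    # Placeholder for the Phi-1.5 rewriting step sketched above.
    return f"The answer to '{question}' is {short_answer}."

def answer_question_from_image(image: Image.Image, question: str):
    # Encode the image/question pair and take the highest-scoring VQA label
    # (this preprocessing line is assumed; the hunk starts at the argmax).
    encoding = vqa_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**encoding)
    predicted_id = outputs.logits.argmax(-1).item()
    short_answer = vqa_model.config.id2label[predicted_id]

    # Rewrite short answer to full sentence with Phi-1.5
    full_answer = rewrite_answer(question, short_answer)

    # Convert the sentence to speech; on failure, return text only.
    try:
        tts = gTTS(text=full_answer)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            tts.save(tmp.name)
            audio_path = tmp.name
    except Exception as e:
        return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None

    return full_answer, audio_path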
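
The last two hunks only drop comments and reword the Gradio description, but they show how the UI is mounted on FastAPI. A rough sketch of that wiring is below; the input list and the first output component are assumptions because the hunk is cut off at inputs=[, and process_image_question() is stubbed out.

import gradio as gr
from fastapi import FastAPI

app = FastAPI()

def process_image_question(image, question):
    # Stand-in for the real pipeline; see answer_question_from_image() above.
    return f"You asked: {question}", None

gui = gr.Interface(
    fn=process_image_question,
    inputs=[
        gr.Image(type="pil"),                      # assumed input components
        gr.Textbox(label="Question"),
    ],
    outputs=[
        gr.Textbox(label="Answer"),                # assumed first output
        gr.Audio(label="Answer (Audio)", type="filepath"),
    ],
    title="🧠 Image QA with Voice",
    description="Upload an image and ask a question. You'll get a full-sentence spoken answer.",
)

# Serve the Gradio UI at the root path of the FastAPI app.
app = gr.mount_gradio_app(app, gui, path="/")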