ikraamkb commited on
Commit
47942ca
·
verified ·
1 Parent(s): ffd57c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -34
app.py CHANGED
@@ -65,43 +65,45 @@ app = gr.mount_gradio_app(app, gui, path="/")
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
- from fastapi import FastAPI
69
- from fastapi.responses import RedirectResponse
70
- import tempfile
71
- import torch
72
  from PIL import Image
 
73
  from gtts import gTTS
 
 
74
  import gradio as gr
75
- from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
76
 
77
  app = FastAPI()
78
 
79
- # Load Models
80
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
81
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
82
 
83
- gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
84
- gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
 
85
 
86
- # Rewrite answer to human-like sentence
87
- def rewrite_answer(question: str, short_answer: str) -> str:
88
- prompt = f"Question: {question}\nAnswer: {short_answer}\nRewrite the answer into a complete sentence:"
89
  inputs = gpt_tokenizer(prompt, return_tensors="pt")
90
  with torch.no_grad():
91
  outputs = gpt_model.generate(
92
  **inputs,
93
- max_new_tokens=40,
94
- do_sample=False,
95
- pad_token_id=gpt_tokenizer.eos_token_id,
96
  temperature=0.7,
 
97
  )
98
- result = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
99
- # Extract only the sentence after the "Rewrite..." line
100
- return result.split("Rewrite the answer into a complete sentence:")[-1].strip()
101
 
102
- def answer_question_from_image(image: Image.Image, question: str):
103
  if image is None or not question.strip():
104
- return "Please upload an image and type a question.", None
105
 
106
  inputs = vqa_processor(image, question, return_tensors="pt")
107
  with torch.no_grad():
@@ -109,35 +111,39 @@ def answer_question_from_image(image: Image.Image, question: str):
109
  predicted_id = outputs.logits.argmax(-1).item()
110
  short_answer = vqa_model.config.id2label[predicted_id]
111
 
112
- full_sentence = rewrite_answer(question, short_answer)
 
113
 
114
  try:
115
- tts = gTTS(text=full_sentence)
116
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
117
  tts.save(tmp.name)
118
  audio_path = tmp.name
119
  except Exception as e:
120
- return f"{full_sentence}\n\n⚠️ Audio generation error: {e}", None
121
 
122
- return full_sentence, audio_path
123
 
124
- # Gradio Interface
125
- interface = gr.Interface(
126
- fn=answer_question_from_image,
 
 
 
127
  inputs=[
128
- gr.Image(type="pil", label="🖼️ Upload Image"),
129
- gr.Textbox(lines=2, placeholder="Ask a question about the image", label="Question")
130
  ],
131
  outputs=[
132
- gr.Textbox(label="💬 Answer"),
133
- gr.Audio(label="🔊 Voice Output", type="filepath")
134
  ],
135
- title="🧠 Image QA with Voice (VQA + GPT-Neo)",
136
- description="Ask a question about an image and get a full sentence answer, including audio!"
137
  )
138
 
139
- app = gr.mount_gradio_app(app, interface, path="/")
140
 
141
  @app.get("/")
142
  def home():
143
- return RedirectResponse(url="/")
 
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
+ from fastapi import FastAPI, UploadFile, Form
69
+ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
70
+ import os
71
+ import shutil
72
  from PIL import Image
73
+ from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
74
  from gtts import gTTS
75
+ import torch
76
+ import tempfile
77
  import gradio as gr
 
78
 
79
  app = FastAPI()
80
 
81
+ # Load VQA Model
82
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
83
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
84
 
85
+ # Load GPT model to rewrite answers (Phi-1.5)
86
+ gpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
87
+ gpt_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
88
 
89
+ def rewrite_answer(question, short_answer):
90
+ prompt = f"Question: {question}\nShort Answer: {short_answer}\nFull Sentence:"
 
91
  inputs = gpt_tokenizer(prompt, return_tensors="pt")
92
  with torch.no_grad():
93
  outputs = gpt_model.generate(
94
  **inputs,
95
+ max_new_tokens=50,
96
+ do_sample=True,
97
+ top_p=0.9,
98
  temperature=0.7,
99
+ pad_token_id=gpt_tokenizer.eos_token_id
100
  )
101
+ generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+ return generated.split("Full Sentence:")[-1].strip()
 
103
 
104
+ def answer_question_from_image(image, question):
105
  if image is None or not question.strip():
106
+ return "Please upload an image and ask a question.", None
107
 
108
  inputs = vqa_processor(image, question, return_tensors="pt")
109
  with torch.no_grad():
 
111
  predicted_id = outputs.logits.argmax(-1).item()
112
  short_answer = vqa_model.config.id2label[predicted_id]
113
 
114
+ # Rewrite short answer to full sentence
115
+ full_answer = rewrite_answer(question, short_answer)
116
 
117
  try:
118
+ tts = gTTS(text=full_answer)
119
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
120
  tts.save(tmp.name)
121
  audio_path = tmp.name
122
  except Exception as e:
123
+ return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
124
 
125
+ return full_answer, audio_path
126
 
127
+ def process_image_question(image: Image.Image, question: str):
128
+ answer, audio_path = answer_question_from_image(image, question)
129
+ return answer, audio_path
130
+
131
+ gui = gr.Interface(
132
+ fn=process_image_question,
133
  inputs=[
134
+ gr.Image(type="pil", label="Upload Image"),
135
+ gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
136
  ],
137
  outputs=[
138
+ gr.Textbox(label="Answer", lines=5),
139
+ gr.Audio(label="Answer (Audio)", type="filepath")
140
  ],
141
+ title="🧠 Image QA with Voice",
142
+ description="Upload an image and ask a question. You'll get a full-sentence spoken answer."
143
  )
144
 
145
+ app = gr.mount_gradio_app(app, gui, path="/")
146
 
147
  @app.get("/")
148
  def home():
149
+ return RedirectResponse(url="/")