ikraamkb committed on
Commit
c8db168
·
verified ·
1 Parent(s): 62d4126

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -19
app.py CHANGED
@@ -65,13 +65,12 @@ app = gr.mount_gradio_app(app, gui, path="/")
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
- from fastapi import FastAPI
69
- from fastapi.responses import RedirectResponse
 
 
70
  from PIL import Image
71
- from transformers import (
72
- ViltProcessor, ViltForQuestionAnswering,
73
- T5Tokenizer, T5ForConditionalGeneration
74
- )
75
  from gtts import gTTS
76
  import torch
77
  import tempfile
@@ -79,20 +78,29 @@ import gradio as gr
79
 
80
  app = FastAPI()
81
 
82
- # VQA Model
83
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
84
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
85
 
86
- # Text Rewriter (FLAN-T5-base)
87
- rewrite_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
88
- rewrite_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
 
89
 
90
def rewrite_answer(question, short_answer):
    """Expand a short VQA label into a complete-sentence answer.

    Builds an instruction prompt and runs it through the module-level
    FLAN-T5 pair (``rewrite_tokenizer`` / ``rewrite_model``).

    Args:
        question: The user's original question.
        short_answer: The one-or-two-word label predicted by the VQA model.

    Returns:
        The decoded full-sentence answer as a string.
    """
    prompt = f"Answer the question '{question}' with a complete sentence using this answer: '{short_answer}'"
    encoded = rewrite_tokenizer(prompt, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = rewrite_model.generate(**encoded, max_new_tokens=50)
    decoded = rewrite_tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded
 
 
 
 
 
 
 
 
96
 
97
  def answer_question_from_image(image, question):
98
  if image is None or not question.strip():
@@ -104,21 +112,23 @@ def answer_question_from_image(image, question):
104
  predicted_id = outputs.logits.argmax(-1).item()
105
  short_answer = vqa_model.config.id2label[predicted_id]
106
 
107
- # Rewrite to full sentence
108
  full_answer = rewrite_answer(question, short_answer)
109
 
110
  try:
111
  tts = gTTS(text=full_answer)
112
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
113
  tts.save(tmp.name)
114
- return full_answer, tmp.name
115
  except Exception as e:
116
- return f"{full_answer}\n\n⚠️ Audio generation error: {e}", None
 
 
117
 
118
def process_image_question(image: Image.Image, question: str):
    """Gradio callback: delegate an (image, question) pair to the VQA pipeline.

    Returns the same ``(answer_text, audio_path)`` pair produced by
    ``answer_question_from_image``.
    """
    answer_text, audio_file = answer_question_from_image(image, question)
    return answer_text, audio_file
 
120
 
121
- # Gradio UI
122
  gui = gr.Interface(
123
  fn=process_image_question,
124
  inputs=[
 
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
+ from fastapi import FastAPI, UploadFile, Form
69
+ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
70
+ import os
71
+ import shutil
72
  from PIL import Image
73
+ from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
 
 
 
74
  from gtts import gTTS
75
  import torch
76
  import tempfile
 
78
 
79
  app = FastAPI()
80
 
81
+ # Load VQA Model
82
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
83
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
84
 
85
+ # Load GPT model to rewrite answers
86
+ # Replacing Falcon-7B-Instruct with GPT-Neo-125M
87
+ gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
88
+ gpt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
89
 
90
def rewrite_answer(question, short_answer):
    """Rewrite a short VQA label into a full-sentence answer with GPT-Neo-125M.

    Uses the module-level ``gpt_tokenizer`` / ``gpt_model``. Sampling is
    enabled (``do_sample=True``), so the exact wording may vary between calls.

    Args:
        question: The user's original question.
        short_answer: The short label predicted by the VQA model.

    Returns:
        A single-line full-sentence answer (empty string if the model
        produced no continuation).
    """
    prompt = f"Question: {question}\nShort Answer: {short_answer}\nFull Sentence:"  # Simpler prompt for GPT-Neo
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            # GPT-Neo has no dedicated pad token; reuse EOS to avoid warnings.
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    full = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Causal LMs echo the prompt; keep only the text after the marker.
    continuation = full.split("Full Sentence:")[-1].strip()
    # BUG FIX: the model often keeps generating hallucinated follow-up
    # "Question: ..." turns after the sentence. Keep only the first line of
    # the continuation; also guard the empty-continuation case.
    return continuation.splitlines()[0].strip() if continuation else continuation
104
 
105
  def answer_question_from_image(image, question):
106
  if image is None or not question.strip():
 
112
  predicted_id = outputs.logits.argmax(-1).item()
113
  short_answer = vqa_model.config.id2label[predicted_id]
114
 
115
+ # Rewrite short answer to full sentence
116
  full_answer = rewrite_answer(question, short_answer)
117
 
118
  try:
119
  tts = gTTS(text=full_answer)
120
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
121
  tts.save(tmp.name)
122
+ audio_path = tmp.name
123
  except Exception as e:
124
+ return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
125
+
126
+ return full_answer, audio_path
127
 
128
def process_image_question(image: Image.Image, question: str):
    """Gradio adapter: forward the image/question pair to the VQA pipeline.

    Simply passes through the ``(answer_text, audio_path)`` pair returned
    by ``answer_question_from_image``.
    """
    return answer_question_from_image(image, question)
131
 
 
132
  gui = gr.Interface(
133
  fn=process_image_question,
134
  inputs=[