ikraamkb committed on
Commit
62d4126
·
verified ·
1 Parent(s): 9173622

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -29
app.py CHANGED
@@ -65,12 +65,13 @@ app = gr.mount_gradio_app(app, gui, path="/")
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
- from fastapi import FastAPI, UploadFile, Form
69
- from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
70
- import os
71
- import shutil
72
  from PIL import Image
73
- from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
 
 
 
74
  from gtts import gTTS
75
  import torch
76
  import tempfile
@@ -78,28 +79,20 @@ import gradio as gr
78
 
79
  app = FastAPI()
80
 
81
- # Load VQA Model
82
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
83
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
84
 
85
- # Load Falcon-7B-Instruct model to rewrite answers
86
- gpt_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
87
- gpt_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b-instruct")
88
 
89
  def rewrite_answer(question, short_answer):
90
- prompt = f"Rewrite the short answer into a natural sentence.\nQuestion: {question}\nAnswer: {short_answer}\nFull Sentence:"
91
- inputs = gpt_tokenizer(prompt, return_tensors="pt")
92
  with torch.no_grad():
93
- outputs = gpt_model.generate(
94
- **inputs,
95
- max_new_tokens=50,
96
- do_sample=True,
97
- top_p=0.9,
98
- temperature=0.8,
99
- pad_token_id=gpt_tokenizer.eos_token_id
100
- )
101
- rewritten = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
102
- return rewritten.split("Full Sentence:")[-1].strip()
103
 
104
  def answer_question_from_image(image, question):
105
  if image is None or not question.strip():
@@ -111,23 +104,21 @@ def answer_question_from_image(image, question):
111
  predicted_id = outputs.logits.argmax(-1).item()
112
  short_answer = vqa_model.config.id2label[predicted_id]
113
 
114
- # Rewrite short answer to full sentence
115
  full_answer = rewrite_answer(question, short_answer)
116
 
117
  try:
118
  tts = gTTS(text=full_answer)
119
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
120
  tts.save(tmp.name)
121
- audio_path = tmp.name
122
  except Exception as e:
123
- return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
124
-
125
- return full_answer, audio_path
126
 
127
  def process_image_question(image: Image.Image, question: str):
128
- answer, audio_path = answer_question_from_image(image, question)
129
- return answer, audio_path
130
 
 
131
  gui = gr.Interface(
132
  fn=process_image_question,
133
  inputs=[
@@ -146,4 +137,4 @@ app = gr.mount_gradio_app(app, gui, path="/")
146
 
147
  @app.get("/")
148
  def home():
149
- return RedirectResponse(url="/")
 
65
  @app.get("/")
66
  def home():
67
  return RedirectResponse(url="/") """
68
+ from fastapi import FastAPI
69
+ from fastapi.responses import RedirectResponse
 
 
70
  from PIL import Image
71
+ from transformers import (
72
+ ViltProcessor, ViltForQuestionAnswering,
73
+ T5Tokenizer, T5ForConditionalGeneration
74
+ )
75
  from gtts import gTTS
76
  import torch
77
  import tempfile
 
79
 
80
# FastAPI application object; the Gradio UI is mounted onto it at path="/"
# elsewhere in the file.
app = FastAPI()

# VQA model: ViLT fine-tuned for visual question answering — maps an
# (image, question) pair to a short answer label.
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Rewriter: FLAN-T5-base, used by rewrite_answer() to expand the short VQA
# label into a full sentence.
rewrite_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
rewrite_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
89
 
90
def rewrite_answer(question, short_answer):
    """Expand a terse VQA label into a full-sentence answer.

    Uses the module-level FLAN-T5 rewriter (``rewrite_tokenizer`` /
    ``rewrite_model``) loaded at import time.

    Args:
        question: The user's original question about the image.
        short_answer: The short label predicted by the VQA model.

    Returns:
        The rewriter's decoded output string, special tokens stripped.
    """
    prompt = f"Answer the question '{question}' with a complete sentence using this answer: '{short_answer}'"
    encoded = rewrite_tokenizer(prompt, return_tensors="pt")
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        generated = rewrite_model.generate(**encoded, max_new_tokens=50)
    return rewrite_tokenizer.decode(generated[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
96
 
97
  def answer_question_from_image(image, question):
98
  if image is None or not question.strip():
 
104
  predicted_id = outputs.logits.argmax(-1).item()
105
  short_answer = vqa_model.config.id2label[predicted_id]
106
 
107
+ # Rewrite to full sentence
108
  full_answer = rewrite_answer(question, short_answer)
109
 
110
  try:
111
  tts = gTTS(text=full_answer)
112
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
113
  tts.save(tmp.name)
114
+ return full_answer, tmp.name
115
  except Exception as e:
116
+ return f"{full_answer}\n\n⚠️ Audio generation error: {e}", None
 
 
117
 
118
def process_image_question(image: Image.Image, question: str):
    """Gradio entry point: run the VQA pipeline on one image/question pair.

    Delegates to ``answer_question_from_image`` and passes its
    (answer_text, audio_path) result straight through to the UI.
    """
    answer, audio = answer_question_from_image(image, question)
    return answer, audio
 
120
 
121
+ # Gradio UI
122
  gui = gr.Interface(
123
  fn=process_image_question,
124
  inputs=[
 
137
 
138
  @app.get("/")
139
  def home():
140
+ return RedirectResponse(url="/")