ikraamkb committed on
Commit
e59323e
·
verified ·
1 Parent(s): 14e7320

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -22
app.py CHANGED
@@ -70,7 +70,7 @@ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
70
  import os
71
  import shutil
72
  from PIL import Image
73
- from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
74
  from gtts import gTTS
75
  import torch
76
  import tempfile
@@ -82,39 +82,26 @@ app = FastAPI()
82
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
83
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
84
 
85
- # Load GPT model to rewrite answers (Phi-1.5)
86
- gpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
87
- gpt_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
88
 
89
  def rewrite_answer(question, short_answer):
90
- prompt = f"Write a full sentence to answer this:\nQ: {question}\nA: {short_answer}\nFull sentence:"
91
  inputs = gpt_tokenizer(prompt, return_tensors="pt")
92
  with torch.no_grad():
93
  outputs = gpt_model.generate(
94
  **inputs,
95
  max_new_tokens=50,
96
  do_sample=True,
97
- top_k=40,
98
  top_p=0.9,
99
- temperature=0.6,
100
- pad_token_id=gpt_tokenizer.eos_token_id
101
  )
102
 
103
- generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
104
-
105
- # Try to isolate the answer portion
106
- if "Full sentence:" in generated:
107
- rewritten = generated.split("Full sentence:")[-1].strip()
108
- else:
109
- rewritten = generated.strip()
110
-
111
- # Fallback to basic templating if model fails
112
- if not rewritten or len(rewritten.split()) < 3:
113
- rewritten = f"The answer to the question '{question}' is: {short_answer}."
114
-
115
  return rewritten
116
 
117
-
118
  def answer_question_from_image(image, question):
119
  if image is None or not question.strip():
120
  return "Please upload an image and ask a question.", None
@@ -125,7 +112,7 @@ def answer_question_from_image(image, question):
125
  predicted_id = outputs.logits.argmax(-1).item()
126
  short_answer = vqa_model.config.id2label[predicted_id]
127
 
128
- # Rewrite short answer to full sentence with Phi-1.5
129
  full_answer = rewrite_answer(question, short_answer)
130
 
131
  try:
 
70
  import os
71
  import shutil
72
  from PIL import Image
73
+ from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForSeq2SeqLM
74
  from gtts import gTTS
75
  import torch
76
  import tempfile
 
82
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
83
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
84
 
85
+ # Load FLAN-T5 model to rewrite answers (better for CPU)
86
+ gpt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
87
+ gpt_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
88
 
89
  def rewrite_answer(question, short_answer):
90
+ prompt = f"Write a full sentence that answers the question '{question}' using this answer: {short_answer}."
91
  inputs = gpt_tokenizer(prompt, return_tensors="pt")
92
  with torch.no_grad():
93
  outputs = gpt_model.generate(
94
  **inputs,
95
  max_new_tokens=50,
96
  do_sample=True,
 
97
  top_p=0.9,
98
+ temperature=0.7,
99
+ pad_token_id=gpt_tokenizer.pad_token_id
100
  )
101
 
102
+ rewritten = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
103
  return rewritten
104
 
 
105
  def answer_question_from_image(image, question):
106
  if image is None or not question.strip():
107
  return "Please upload an image and ask a question.", None
 
112
  predicted_id = outputs.logits.argmax(-1).item()
113
  short_answer = vqa_model.config.id2label[predicted_id]
114
 
115
+ # Rewrite short answer to full sentence
116
  full_answer = rewrite_answer(question, short_answer)
117
 
118
  try: