ikraamkb committed on
Commit
79cce77
·
verified ·
1 Parent(s): e59323e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -70,7 +70,7 @@ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
70
  import os
71
  import shutil
72
  from PIL import Image
73
- from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForSeq2SeqLM
74
  from gtts import gTTS
75
  import torch
76
  import tempfile
@@ -82,25 +82,24 @@ app = FastAPI()
82
# ViLT fine-tuned on VQAv2: jointly encodes an image and a question and
# predicts a short answer label.
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Load FLAN-T5 model to rewrite answers (better for CPU)
# Seq2seq rewriter used by rewrite_answer() below to turn the terse VQA
# label into a full sentence. Loaded once at import time.
gpt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
gpt_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
88
 
89
def rewrite_answer(question, short_answer):
    """Expand a terse VQA label into a complete sentence.

    Uses the module-level ``gpt_tokenizer``/``gpt_model`` (FLAN-T5 seq2seq
    pair). Sampling is enabled, so the exact wording varies between calls.

    Args:
        question: The user's original question about the image.
        short_answer: The one-or-two-word answer produced by the VQA model.

    Returns:
        The decoded model output as a plain string.
    """
    prompt = f"Write a full sentence that answers the question '{question}' using this answer: {short_answer}."
    encoded = gpt_tokenizer(prompt, return_tensors="pt")
    # Fixed decoding settings: nucleus sampling with a 50-token budget.
    sampling_options = dict(
        max_new_tokens=50,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=gpt_tokenizer.pad_token_id,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = gpt_model.generate(**encoded, **sampling_options)
    return gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
104
 
105
  def answer_question_from_image(image, question):
106
  if image is None or not question.strip():
@@ -147,4 +146,4 @@ app = gr.mount_gradio_app(app, gui, path="/")
147
 
148
@app.get("/")
def home():
    # NOTE(review): this handler redirects "/" to "/", which would loop if it
    # were ever reached; presumably the Gradio app mounted at path="/" above
    # shadows this route entirely — confirm the intended redirect target.
    return RedirectResponse(url="/")
 
70
  import os
71
  import shutil
72
  from PIL import Image
73
+ from transformers import ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
74
  from gtts import gTTS
75
  import torch
76
  import tempfile
 
82
# ViLT fine-tuned on VQAv2: jointly encodes an image and a question and
# predicts a short answer label.
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Load Falcon-RW-1B model to rewrite answers
# Causal-LM rewriter used by rewrite_answer() below. Loaded once at import
# time; NOTE(review): ~1B parameters — noticeably heavier than the previous
# flan-t5-small on CPU; confirm the host has enough memory.
gpt_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
gpt_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b")
88
 
89
def rewrite_answer(question, short_answer):
    """Rewrite a terse VQA label as a full sentence using the causal LM.

    Uses the module-level ``gpt_tokenizer``/``gpt_model`` (Falcon-RW-1B).
    Sampling is enabled, so the exact wording varies between calls.

    Args:
        question: The user's original question about the image.
        short_answer: The one-or-two-word answer produced by the VQA model.

    Returns:
        A single-line sentence answering the question, or ``short_answer``
        unchanged if the model produced no usable text.
    """
    prompt = f"Question: {question}\nShort Answer: {short_answer}\nFull sentence:"  # Few-shot style prompt
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            top_p=0.95,
            temperature=0.8,
            # Falcon has no dedicated pad token; reuse EOS to silence the
            # generate() warning.
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    rewritten = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # A causal LM echoes the prompt and keeps generating after the answer
    # (often fabricating a new "Question:" block). Keep only the text after
    # the "Full sentence:" cue, then cut at the first line break so the
    # rambling continuation is discarded.
    completion = rewritten.split("Full sentence:")[-1].strip()
    if completion:
        completion = completion.splitlines()[0].strip()
    # Fall back to the raw VQA answer rather than returning an empty string.
    return completion if completion else short_answer
103
 
104
  def answer_question_from_image(image, question):
105
  if image is None or not question.strip():
 
146
 
147
@app.get("/")
def home():
    # NOTE(review): this handler redirects "/" to "/", which would loop if it
    # were ever reached; presumably the Gradio app mounted at path="/" above
    # shadows this route entirely — confirm the intended redirect target.
    return RedirectResponse(url="/")