ikraamkb committed on
Commit
9173622
·
verified ·
1 Parent(s): 79cce77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -82,24 +82,24 @@ app = FastAPI()
82
# ViLT processor + model fine-tuned for visual question answering (short answers).
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Causal LM (Falcon-RW-1B) used only to rephrase the short VQA answer.
gpt_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
gpt_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b")


def rewrite_answer(question, short_answer):
    """Expand a terse VQA answer into a full sentence with the causal LM.

    Builds a few-shot-style prompt from the question and short answer,
    samples up to 50 new tokens, and returns only the text generated
    after the "Full sentence:" marker.
    """
    # Few-shot style prompt; the trailing marker is what we split on below.
    prompt = f"Question: {question}\nShort Answer: {short_answer}\nFull sentence:"
    encoded = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        generated = gpt_model.generate(
            **encoded,
            max_new_tokens=50,
            do_sample=True,
            top_p=0.95,
            temperature=0.8,
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    decoded = gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
    # Keep only the model's continuation after the prompt marker.
    return decoded.split("Full sentence:")[-1].strip()
103
 
104
  def answer_question_from_image(image, question):
105
  if image is None or not question.strip():
 
82
# ViLT processor + model fine-tuned for visual question answering (short answers).
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Causal LM (Falcon-7B-Instruct) used only to rephrase the short VQA answer.
gpt_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
gpt_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b-instruct")


def rewrite_answer(question, short_answer):
    """Turn a short VQA answer into a natural-language sentence via the LM.

    Prompts the instruct model with the question and short answer, samples
    up to 50 new tokens, and returns only the text generated after the
    "Full Sentence:" marker.
    """
    # Instruction-style prompt; the trailing marker is what we split on below.
    prompt = f"Rewrite the short answer into a natural sentence.\nQuestion: {question}\nAnswer: {short_answer}\nFull Sentence:"
    encoded = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        generated = gpt_model.generate(
            **encoded,
            max_new_tokens=50,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    decoded = gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
    # Keep only the model's continuation after the prompt marker.
    return decoded.split("Full Sentence:")[-1].strip()
103
 
104
  def answer_question_from_image(image, question):
105
  if image is None or not question.strip():