fl399 committed
Commit d43e0b2 · 1 Parent(s): 3808a7a

Update app.py

Files changed (1)
  1. app.py +5 -7
app.py CHANGED
@@ -126,7 +126,7 @@ def evaluate(
     table,
     question,
     llm="alpaca-lora",
-    shot="1-shot",
+    num_shot="1-shot",
     input=None,
     temperature=0.1,
     top_p=0.75,
@@ -138,7 +138,7 @@ def evaluate(
     prompt_0shot = _INSTRUCTION + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
     prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
     if llm == "alpaca-lora":
-        if shot == "1-shot":
+        if num_shot == "1-shot":
             inputs = tokenizer(prompt, return_tensors="pt")
         else:
             inputs = tokenizer(prompt_0shot, return_tensors="pt")
@@ -161,7 +161,7 @@ def evaluate(
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
     elif llm == "flan-ul2":
-        if shot == "1-shot":
+        if num_shot == "1-shot":
             output = query({
                 "inputs": prompt
             })[0]["generated_text"]
@@ -179,18 +179,16 @@ def evaluate(
 model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
 processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
 
-def process_document(llm, image, question):
+def process_document(llm, num_shot, image, question):
     # image = Image.open(image)
     inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(0, torch.bfloat16)
     predictions = model_deplot.generate(**inputs, max_new_tokens=512)
     table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
 
     # send prompt+table to LLM
-    res = evaluate(table, question, llm=llm)
-    #return res + "\n\n" + res.split("A:")[-1]
+    res = evaluate(table, num_shot, question, llm=llm)
     if llm == "alpaca-lora":
         return [table, res.split("A:")[-1]]
-    # return [table, res]
     else:
         return [table, res]
 
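A note on the last hunk: in the new signature, num_shot comes after question and llm, so the positional call evaluate(table, num_shot, question, llm=llm) binds num_shot to the question parameter and then raises a TypeError because llm receives a value twice. A minimal sketch of the same function using a keyword call instead, assuming the evaluate, model_deplot, and processor_deplot definitions from app.py:

import torch

def process_document(llm, num_shot, image, question):
    # DePlot: translate the chart image into a linearized data table.
    inputs = processor_deplot(
        images=image,
        text="Generate the underlying data table for the figure below:",
        return_tensors="pt",
    ).to(0, torch.bfloat16)
    predictions = model_deplot.generate(**inputs, max_new_tokens=512)
    table = processor_deplot.decode(
        predictions[0], skip_special_tokens=True
    ).replace("<0x0A>", "\n")

    # Send prompt + table to the LLM; keyword arguments keep the call
    # aligned with evaluate(table, question, llm=..., num_shot=..., ...).
    res = evaluate(table, question, llm=llm, num_shot=num_shot)

    # alpaca-lora's decoded output includes the prompt, so keep only
    # the text after the final "A:" marker.
    if llm == "alpaca-lora":
        return [table, res.split("A:")[-1]]
    return [table, res]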
 
 
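For completeness, the new num_shot argument also has to be threaded through whatever UI calls process_document. A hypothetical Gradio wiring consistent with the updated signature (the component choices and labels are illustrative assumptions, not part of this commit):

import gradio as gr

demo = gr.Interface(
    fn=process_document,
    inputs=[
        gr.Radio(["alpaca-lora", "flan-ul2"], label="LLM"),   # llm
        gr.Radio(["0-shot", "1-shot"], label="Shots"),        # num_shot
        gr.Image(type="pil", label="Chart"),                  # image
        gr.Textbox(label="Question"),                         # question
    ],
    outputs=[
        gr.Textbox(label="Extracted table"),
        gr.Textbox(label="Answer"),
    ],
)
demo.launch()

Since evaluate() only checks for the exact string "1-shot" and falls back to the zero-shot prompt otherwise, the zero-shot option can carry any label as long as the value passed through differs from "1-shot".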