Sangjun2 commited on
Commit
e67d45e
·
verified ·
1 Parent(s): 4859729

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -19
app.py CHANGED
@@ -84,7 +84,7 @@ def format_output(prediction):
84
  return prediction.replace('<0x0A>', '\n')
85
 
86
  # First model prediction ko-deplot
87
- @spaces.GPU(enable_queue=True,duration=120)
88
  def predict_model1(image):
89
  images = [image]
90
  inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
@@ -116,25 +116,8 @@ def replace_unk(text):
116
  text = text.replace('<unk>', '')
117
  return text
118
 
119
- # Second model prediction aihub_deplot
120
- @spaces.GPU(enable_queue=True,duration=120)
121
- def predict_model2(image):
122
- image = image.convert("RGB")
123
- inputs = processor2(images=image, return_tensors="pt", max_patches=MAX_PATCHES).to(device)
124
 
125
- flattened_patches = inputs.flattened_patches.to(device)
126
- attention_mask = inputs.attention_mask.to(device)
127
-
128
- model2.eval()
129
-
130
- with torch.no_grad():
131
- deplot_generated_ids = model2.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=1000)
132
- generated_datatable = processor2.batch_decode(deplot_generated_ids, skip_special_tokens=False)[0]
133
- generated_datatable = generated_datatable.replace("<pad>", "<unk>").replace("</s>", "<unk>")
134
- refined_table = replace_unk(generated_datatable)
135
- return refined_table
136
-
137
- @spaces.GPU(enable_queue=True,duration=120)
138
  def predict_model3(image):
139
  image=image.convert("RGB")
140
  input_prompt = "<extract_data_table> <s_answer>"
 
84
  return prediction.replace('<0x0A>', '\n')
85
 
86
  # First model prediction ko-deplot
87
+ @spaces.GPU(enable_queue=True,duration=100)
88
  def predict_model1(image):
89
  images = [image]
90
  inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
 
116
  text = text.replace('<unk>', '')
117
  return text
118
 
 
 
 
 
 
119
 
120
+ @spaces.GPU(enable_queue=True,duration=100)
 
 
 
 
 
 
 
 
 
 
 
 
121
  def predict_model3(image):
122
  image=image.convert("RGB")
123
  input_prompt = "<extract_data_table> <s_answer>"