Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -84,7 +84,7 @@ def format_output(prediction):
|
|
| 84 |
return prediction.replace('<0x0A>', '\n')
|
| 85 |
|
| 86 |
# First model prediction ko-deplot
|
| 87 |
-
@spaces.GPU(enable_queue=True,duration=
|
| 88 |
def predict_model1(image):
|
| 89 |
images = [image]
|
| 90 |
inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
|
|
@@ -116,25 +116,8 @@ def replace_unk(text):
|
|
| 116 |
text = text.replace('<unk>', '')
|
| 117 |
return text
|
| 118 |
|
| 119 |
-
# Second model prediction aihub_deplot
|
| 120 |
-
@spaces.GPU(enable_queue=True,duration=120)
|
| 121 |
-
def predict_model2(image):
|
| 122 |
-
image = image.convert("RGB")
|
| 123 |
-
inputs = processor2(images=image, return_tensors="pt", max_patches=MAX_PATCHES).to(device)
|
| 124 |
|
| 125 |
-
|
| 126 |
-
attention_mask = inputs.attention_mask.to(device)
|
| 127 |
-
|
| 128 |
-
model2.eval()
|
| 129 |
-
|
| 130 |
-
with torch.no_grad():
|
| 131 |
-
deplot_generated_ids = model2.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=1000)
|
| 132 |
-
generated_datatable = processor2.batch_decode(deplot_generated_ids, skip_special_tokens=False)[0]
|
| 133 |
-
generated_datatable = generated_datatable.replace("<pad>", "<unk>").replace("</s>", "<unk>")
|
| 134 |
-
refined_table = replace_unk(generated_datatable)
|
| 135 |
-
return refined_table
|
| 136 |
-
|
| 137 |
-
@spaces.GPU(enable_queue=True,duration=120)
|
| 138 |
def predict_model3(image):
|
| 139 |
image=image.convert("RGB")
|
| 140 |
input_prompt = "<extract_data_table> <s_answer>"
|
|
|
|
| 84 |
return prediction.replace('<0x0A>', '\n')
|
| 85 |
|
| 86 |
# First model prediction ko-deplot
|
| 87 |
+
@spaces.GPU(enable_queue=True,duration=100)
|
| 88 |
def predict_model1(image):
|
| 89 |
images = [image]
|
| 90 |
inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
|
|
|
|
| 116 |
text = text.replace('<unk>', '')
|
| 117 |
return text
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
+
@spaces.GPU(enable_queue=True,duration=100)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def predict_model3(image):
|
| 122 |
image=image.convert("RGB")
|
| 123 |
input_prompt = "<extract_data_table> <s_answer>"
|