Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -84,7 +84,7 @@ def format_output(prediction):
|
|
84 |
return prediction.replace('<0x0A>', '\n')
|
85 |
|
86 |
# First model prediction ko-deplot
|
87 |
-
@spaces.GPU(enable_queue=True,duration=
|
88 |
def predict_model1(image):
|
89 |
images = [image]
|
90 |
inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
|
@@ -116,25 +116,8 @@ def replace_unk(text):
|
|
116 |
text = text.replace('<unk>', '')
|
117 |
return text
|
118 |
|
119 |
-
# Second model prediction aihub_deplot
|
120 |
-
@spaces.GPU(enable_queue=True,duration=120)
|
121 |
-
def predict_model2(image):
|
122 |
-
image = image.convert("RGB")
|
123 |
-
inputs = processor2(images=image, return_tensors="pt", max_patches=MAX_PATCHES).to(device)
|
124 |
|
125 |
-
|
126 |
-
attention_mask = inputs.attention_mask.to(device)
|
127 |
-
|
128 |
-
model2.eval()
|
129 |
-
|
130 |
-
with torch.no_grad():
|
131 |
-
deplot_generated_ids = model2.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=1000)
|
132 |
-
generated_datatable = processor2.batch_decode(deplot_generated_ids, skip_special_tokens=False)[0]
|
133 |
-
generated_datatable = generated_datatable.replace("<pad>", "<unk>").replace("</s>", "<unk>")
|
134 |
-
refined_table = replace_unk(generated_datatable)
|
135 |
-
return refined_table
|
136 |
-
|
137 |
-
@spaces.GPU(enable_queue=True,duration=120)
|
138 |
def predict_model3(image):
|
139 |
image=image.convert("RGB")
|
140 |
input_prompt = "<extract_data_table> <s_answer>"
|
|
|
84 |
return prediction.replace('<0x0A>', '\n')
|
85 |
|
86 |
# First model prediction ko-deplot
|
87 |
+
@spaces.GPU(enable_queue=True,duration=100)
|
88 |
def predict_model1(image):
|
89 |
images = [image]
|
90 |
inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
|
|
|
116 |
text = text.replace('<unk>', '')
|
117 |
return text
|
118 |
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
+
@spaces.GPU(enable_queue=True,duration=100)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
def predict_model3(image):
|
122 |
image=image.convert("RGB")
|
123 |
input_prompt = "<extract_data_table> <s_answer>"
|