Spaces:
Runtime error
Runtime error
Update app.py
#1
by
ybelkada
- opened
app.py
CHANGED
|
@@ -148,12 +148,12 @@ def evaluate(
|
|
| 148 |
|
| 149 |
|
| 150 |
## deplot models
|
| 151 |
-
model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
|
| 152 |
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
|
| 153 |
|
| 154 |
def process_document(image, question):
|
| 155 |
# image = Image.open(image)
|
| 156 |
-
inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt")
|
| 157 |
predictions = model_deplot.generate(**inputs, max_new_tokens=512)
|
| 158 |
table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
|
| 159 |
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
## deplot models
|
| 151 |
+
model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
|
| 152 |
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
|
| 153 |
|
| 154 |
def process_document(image, question):
|
| 155 |
# image = Image.open(image)
|
| 156 |
+
inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(torch.bfloat16, 0)
|
| 157 |
predictions = model_deplot.generate(**inputs, max_new_tokens=512)
|
| 158 |
table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
|
| 159 |
|