Spaces:
Sleeping
Sleeping
fix: load model revision based on input
Browse files
app.py
CHANGED
|
@@ -6,29 +6,32 @@ import torch
|
|
| 6 |
import html
|
| 7 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
| 8 |
|
| 9 |
-
pretrained_repo_name = 'ivelin/donut-refexp-click'
|
| 10 |
-
pretrained_revision = 'main'
|
| 11 |
-
# revision can be git commit hash, branch or tag
|
| 12 |
-
# use 'main' for latest revision
|
| 13 |
-
print(f"Loading model checkpoint: {pretrained_repo_name}")
|
| 14 |
-
|
| 15 |
-
processor = DonutProcessor.from_pretrained(
|
| 16 |
-
pretrained_repo_name, revision=pretrained_revision, use_auth_token="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")
|
| 17 |
-
processor.image_processor.do_align_long_axis = False
|
| 18 |
-
# do not manipulate image size and position
|
| 19 |
-
processor.image_processor.do_resize = False
|
| 20 |
-
processor.image_processor.do_thumbnail = False
|
| 21 |
-
processor.image_processor.do_pad = False
|
| 22 |
-
# processor.image_processor.do_rescale = False
|
| 23 |
-
processor.image_processor.do_normalize = True
|
| 24 |
-
print(f'processor image size: {processor.image_processor.size}')
|
| 25 |
-
|
| 26 |
-
model = VisionEncoderDecoderModel.from_pretrained(
|
| 27 |
-
pretrained_repo_name, use_auth_token="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", revision=pretrained_revision)
|
| 28 |
-
|
| 29 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 30 |
-
model.to(device)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def prepare_image_for_encoder(image=None, output_image_size=None):
|
| 34 |
"""
|
|
@@ -86,10 +89,12 @@ def translate_point_coords_from_out_to_in(point=None, input_image_size=None, out
|
|
| 86 |
f"translated point={point}, resized_image_size: {resized_width, resized_height}")
|
| 87 |
|
| 88 |
|
| 89 |
-
def process_refexp(image: Image, prompt: str,
|
| 90 |
|
| 91 |
print(f"(image, prompt): {image}, {prompt}")
|
| 92 |
-
print(f"model checkpoint revision: {
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# trim prompt to 80 characters and normalize to lowercase
|
| 95 |
prompt = prompt[:80].lower()
|
|
|
|
| 6 |
import html
|
| 7 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
global model, processor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_model(pretrained_revision: str = 'main'):
|
| 14 |
+
global model, processor
|
| 15 |
+
pretrained_repo_name = 'ivelin/donut-refexp-click'
|
| 16 |
+
# revision can be git commit hash, branch or tag
|
| 17 |
+
# use 'main' for latest revision
|
| 18 |
+
print(f"Loading model checkpoint from repo: {pretrained_repo_name}, revision: {pretrained_revision}")
|
| 19 |
+
processor = DonutProcessor.from_pretrained(
|
| 20 |
+
pretrained_repo_name, revision=pretrained_revision, use_auth_token="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")
|
| 21 |
+
processor.image_processor.do_align_long_axis = False
|
| 22 |
+
# do not manipulate image size and position
|
| 23 |
+
processor.image_processor.do_resize = False
|
| 24 |
+
processor.image_processor.do_thumbnail = False
|
| 25 |
+
processor.image_processor.do_pad = False
|
| 26 |
+
# processor.image_processor.do_rescale = False
|
| 27 |
+
processor.image_processor.do_normalize = True
|
| 28 |
+
print(f'processor image size: {processor.image_processor.size}')
|
| 29 |
+
model = VisionEncoderDecoderModel.from_pretrained(
|
| 30 |
+
pretrained_repo_name, use_auth_token="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", revision=pretrained_revision)
|
| 31 |
+
|
| 32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
+
model.to(device)
|
| 34 |
+
|
| 35 |
|
| 36 |
def prepare_image_for_encoder(image=None, output_image_size=None):
|
| 37 |
"""
|
|
|
|
| 89 |
f"translated point={point}, resized_image_size: {resized_width, resized_height}")
|
| 90 |
|
| 91 |
|
| 92 |
+
def process_refexp(image: Image, prompt: str, model_revision: str = 'main'):
|
| 93 |
|
| 94 |
print(f"(image, prompt): {image}, {prompt}")
|
| 95 |
+
print(f"model checkpoint revision: {model_revision}")
|
| 96 |
+
|
| 97 |
+
load_model(model_revision)
|
| 98 |
|
| 99 |
# trim prompt to 80 characters and normalize to lowercase
|
| 100 |
prompt = prompt[:80].lower()
|