Spaces:
Sleeping
Sleeping
fix: load model revision based on input
Browse files
app.py
CHANGED
@@ -6,29 +6,32 @@ import torch
|
|
6 |
import html
|
7 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
8 |
|
9 |
-
pretrained_repo_name = 'ivelin/donut-refexp-click'
|
10 |
-
pretrained_revision = 'main'
|
11 |
-
# revision can be git commit hash, branch or tag
|
12 |
-
# use 'main' for latest revision
|
13 |
-
print(f"Loading model checkpoint: {pretrained_repo_name}")
|
14 |
-
|
15 |
-
processor = DonutProcessor.from_pretrained(
|
16 |
-
pretrained_repo_name, revision=pretrained_revision, use_auth_token="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")
|
17 |
-
processor.image_processor.do_align_long_axis = False
|
18 |
-
# do not manipulate image size and position
|
19 |
-
processor.image_processor.do_resize = False
|
20 |
-
processor.image_processor.do_thumbnail = False
|
21 |
-
processor.image_processor.do_pad = False
|
22 |
-
# processor.image_processor.do_rescale = False
|
23 |
-
processor.image_processor.do_normalize = True
|
24 |
-
print(f'processor image size: {processor.image_processor.size}')
|
25 |
-
|
26 |
-
model = VisionEncoderDecoderModel.from_pretrained(
|
27 |
-
pretrained_repo_name, use_auth_token="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", revision=pretrained_revision)
|
28 |
-
|
29 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
30 |
-
model.to(device)
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def prepare_image_for_encoder(image=None, output_image_size=None):
|
34 |
"""
|
@@ -86,10 +89,12 @@ def translate_point_coords_from_out_to_in(point=None, input_image_size=None, out
|
|
86 |
f"translated point={point}, resized_image_size: {resized_width, resized_height}")
|
87 |
|
88 |
|
89 |
-
def process_refexp(image: Image, prompt: str,
|
90 |
|
91 |
print(f"(image, prompt): {image}, {prompt}")
|
92 |
-
print(f"model checkpoint revision: {
|
|
|
|
|
93 |
|
94 |
# trim prompt to 80 characters and normalize to lowercase
|
95 |
prompt = prompt[:80].lower()
|
|
|
6 |
import html
|
7 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
global model, processor
|
11 |
+
|
12 |
+
|
13 |
+
def load_model(pretrained_revision: str = 'main'):
|
14 |
+
global model, processor
|
15 |
+
pretrained_repo_name = 'ivelin/donut-refexp-click'
|
16 |
+
# revision can be git commit hash, branch or tag
|
17 |
+
# use 'main' for latest revision
|
18 |
+
print(f"Loading model checkpoint from repo: {pretrained_repo_name}, revision: {pretrained_revision}")
|
19 |
+
processor = DonutProcessor.from_pretrained(
|
20 |
+
pretrained_repo_name, revision=pretrained_revision, use_auth_token="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")
|
21 |
+
processor.image_processor.do_align_long_axis = False
|
22 |
+
# do not manipulate image size and position
|
23 |
+
processor.image_processor.do_resize = False
|
24 |
+
processor.image_processor.do_thumbnail = False
|
25 |
+
processor.image_processor.do_pad = False
|
26 |
+
# processor.image_processor.do_rescale = False
|
27 |
+
processor.image_processor.do_normalize = True
|
28 |
+
print(f'processor image size: {processor.image_processor.size}')
|
29 |
+
model = VisionEncoderDecoderModel.from_pretrained(
|
30 |
+
pretrained_repo_name, use_auth_token="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", revision=pretrained_revision)
|
31 |
+
|
32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
+
model.to(device)
|
34 |
+
|
35 |
|
36 |
def prepare_image_for_encoder(image=None, output_image_size=None):
|
37 |
"""
|
|
|
89 |
f"translated point={point}, resized_image_size: {resized_width, resized_height}")
|
90 |
|
91 |
|
92 |
+
def process_refexp(image: Image, prompt: str, model_revision: str = 'main'):
|
93 |
|
94 |
print(f"(image, prompt): {image}, {prompt}")
|
95 |
+
print(f"model checkpoint revision: {model_revision}")
|
96 |
+
|
97 |
+
load_model(model_revision)
|
98 |
|
99 |
# trim prompt to 80 characters and normalize to lowercase
|
100 |
prompt = prompt[:80].lower()
|