ivelin committed on
Commit
2399c69
·
1 Parent(s): f85a58b

fix: load model revision based on input

Browse files
Files changed (1) hide show
  1. app.py +29 -24
app.py CHANGED
@@ -6,29 +6,32 @@ import torch
6
  import html
7
  from transformers import DonutProcessor, VisionEncoderDecoderModel
8
 
9
- pretrained_repo_name = 'ivelin/donut-refexp-click'
10
- pretrained_revision = 'main'
11
- # revision can be git commit hash, branch or tag
12
- # use 'main' for latest revision
13
- print(f"Loading model checkpoint: {pretrained_repo_name}")
14
-
15
- processor = DonutProcessor.from_pretrained(
16
- pretrained_repo_name, revision=pretrained_revision, use_auth_token="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")
17
- processor.image_processor.do_align_long_axis = False
18
- # do not manipulate image size and position
19
- processor.image_processor.do_resize = False
20
- processor.image_processor.do_thumbnail = False
21
- processor.image_processor.do_pad = False
22
- # processor.image_processor.do_rescale = False
23
- processor.image_processor.do_normalize = True
24
- print(f'processor image size: {processor.image_processor.size}')
25
-
26
- model = VisionEncoderDecoderModel.from_pretrained(
27
- pretrained_repo_name, use_auth_token="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", revision=pretrained_revision)
28
-
29
- device = "cuda" if torch.cuda.is_available() else "cpu"
30
- model.to(device)
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def prepare_image_for_encoder(image=None, output_image_size=None):
34
  """
@@ -86,10 +89,12 @@ def translate_point_coords_from_out_to_in(point=None, input_image_size=None, out
86
  f"translated point={point}, resized_image_size: {resized_width, resized_height}")
87
 
88
 
89
- def process_refexp(image: Image, prompt: str, revision: str = 'main'):
90
 
91
  print(f"(image, prompt): {image}, {prompt}")
92
- print(f"model checkpoint revision: {revision}")
 
 
93
 
94
  # trim prompt to 80 characters and normalize to lowercase
95
  prompt = prompt[:80].lower()
 
6
  import html
7
  from transformers import DonutProcessor, VisionEncoderDecoderModel
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ global model, processor
11
+
12
+
13
+ def load_model(pretrained_revision: str = 'main'):
14
+ global model, processor
15
+ pretrained_repo_name = 'ivelin/donut-refexp-click'
16
+ # revision can be git commit hash, branch or tag
17
+ # use 'main' for latest revision
18
+ print(f"Loading model checkpoint from repo: {pretrained_repo_name}, revision: {pretrained_revision}")
19
+ processor = DonutProcessor.from_pretrained(
20
+ pretrained_repo_name, revision=pretrained_revision, use_auth_token="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")
21
+ processor.image_processor.do_align_long_axis = False
22
+ # do not manipulate image size and position
23
+ processor.image_processor.do_resize = False
24
+ processor.image_processor.do_thumbnail = False
25
+ processor.image_processor.do_pad = False
26
+ # processor.image_processor.do_rescale = False
27
+ processor.image_processor.do_normalize = True
28
+ print(f'processor image size: {processor.image_processor.size}')
29
+ model = VisionEncoderDecoderModel.from_pretrained(
30
+ pretrained_repo_name, use_auth_token="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", revision=pretrained_revision)
31
+
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+ model.to(device)
34
+
35
 
36
  def prepare_image_for_encoder(image=None, output_image_size=None):
37
  """
 
89
  f"translated point={point}, resized_image_size: {resized_width, resized_height}")
90
 
91
 
92
+ def process_refexp(image: Image, prompt: str, model_revision: str = 'main'):
93
 
94
  print(f"(image, prompt): {image}, {prompt}")
95
+ print(f"model checkpoint revision: {model_revision}")
96
+
97
+ load_model(model_revision)
98
 
99
  # trim prompt to 80 characters and normalize to lowercase
100
  prompt = prompt[:80].lower()