ZhengPeng7 commited on
Commit
1be3d4d
·
1 Parent(s): 8af980d

Fix the transformers cache bug in versions >= v4.22.0.

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -10,7 +10,7 @@ from typing import Tuple
10
 
11
  from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
- from transformers import AutoModelForImageSegmentation
14
  from torchvision import transforms
15
 
16
  import requests
@@ -18,6 +18,7 @@ from io import BytesIO
18
  import zipfile
19
 
20
 
 
21
  torch.set_float32_matmul_precision('high')
22
  torch.jit.script = lambda f: f
23
 
@@ -89,7 +90,7 @@ usage_to_weights_file = {
89
  'General-dynamic': 'BiRefNet_dynamic',
90
  }
91
 
92
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
93
  birefnet.to(device)
94
  birefnet.eval(); birefnet.half()
95
 
@@ -102,7 +103,7 @@ def predict(images, resolution, weights_file):
102
  # Load BiRefNet with chosen weights
103
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
104
  print('Using weights: {}.'.format(_weights_file))
105
- birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
106
  birefnet.to(device)
107
  birefnet.eval(); birefnet.half()
108
 
 
10
 
11
  from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
+ import transformers
14
  from torchvision import transforms
15
 
16
  import requests
 
18
  import zipfile
19
 
20
 
21
+ transformers.utils.move_cache()
22
  torch.set_float32_matmul_precision('high')
23
  torch.jit.script = lambda f: f
24
 
 
90
  'General-dynamic': 'BiRefNet_dynamic',
91
  }
92
 
93
+ birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
94
  birefnet.to(device)
95
  birefnet.eval(); birefnet.half()
96
 
 
103
  # Load BiRefNet with chosen weights
104
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
105
  print('Using weights: {}.'.format(_weights_file))
106
+ birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
107
  birefnet.to(device)
108
  birefnet.eval(); birefnet.half()
109