Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
1be3d4d
1
Parent(s):
8af980d
Fix the transformers cache bug in versions >= v4.22.0.
Browse files
app.py
CHANGED
@@ -10,7 +10,7 @@ from typing import Tuple
|
|
10 |
|
11 |
from PIL import Image
|
12 |
from gradio_imageslider import ImageSlider
|
13 |
-
|
14 |
from torchvision import transforms
|
15 |
|
16 |
import requests
|
@@ -18,6 +18,7 @@ from io import BytesIO
|
|
18 |
import zipfile
|
19 |
|
20 |
|
|
|
21 |
torch.set_float32_matmul_precision('high')
|
22 |
torch.jit.script = lambda f: f
|
23 |
|
@@ -89,7 +90,7 @@ usage_to_weights_file = {
|
|
89 |
'General-dynamic': 'BiRefNet_dynamic',
|
90 |
}
|
91 |
|
92 |
-
birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
|
93 |
birefnet.to(device)
|
94 |
birefnet.eval(); birefnet.half()
|
95 |
|
@@ -102,7 +103,7 @@ def predict(images, resolution, weights_file):
|
|
102 |
# Load BiRefNet with chosen weights
|
103 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
104 |
print('Using weights: {}.'.format(_weights_file))
|
105 |
-
birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
|
106 |
birefnet.to(device)
|
107 |
birefnet.eval(); birefnet.half()
|
108 |
|
|
|
10 |
|
11 |
from PIL import Image
|
12 |
from gradio_imageslider import ImageSlider
|
13 |
+
import transformers
|
14 |
from torchvision import transforms
|
15 |
|
16 |
import requests
|
|
|
18 |
import zipfile
|
19 |
|
20 |
|
21 |
+
transformers.utils.move_cache()
|
22 |
torch.set_float32_matmul_precision('high')
|
23 |
torch.jit.script = lambda f: f
|
24 |
|
|
|
90 |
'General-dynamic': 'BiRefNet_dynamic',
|
91 |
}
|
92 |
|
93 |
+
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
|
94 |
birefnet.to(device)
|
95 |
birefnet.eval(); birefnet.half()
|
96 |
|
|
|
103 |
# Load BiRefNet with chosen weights
|
104 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
105 |
print('Using weights: {}.'.format(_weights_file))
|
106 |
+
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
|
107 |
birefnet.to(device)
|
108 |
birefnet.eval(); birefnet.half()
|
109 |
|