ZhengPeng7 commited on
Commit
681c14f
·
1 Parent(s): b0bc43c

Upgrade the weights loading method to avoid duplicated loading.

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -49,23 +49,23 @@ weights_path = 'General'
49
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file[weights_path])), trust_remote_code=True)
50
  birefnet.to(device)
51
  birefnet.eval()
52
- birefnet.weights_path = weights_path
53
 
54
 
55
  @spaces.GPU
56
  def predict(image, resolution, weights_file):
57
- global birefnet
58
- if birefnet.weights_path != weights_file:
59
  print('*' * 10)
60
- print('\t1: ', weights_file, birefnet.weights_path)
61
  # Load BiRefNet with chosen weights
62
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
63
  print('Change weights to:', _weights_file)
64
  birefnet = birefnet.from_pretrained(_weights_file)
65
  birefnet.to(device)
66
  birefnet.eval()
67
- birefnet.weights_path = weights_file
68
- print('\t2: ', weights_file, birefnet.weights_path)
69
 
70
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
71
  # Image is a RGB numpy array.
 
49
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file[weights_path])), trust_remote_code=True)
50
  birefnet.to(device)
51
  birefnet.eval()
52
+ weights_path = weights_path
53
 
54
 
55
  @spaces.GPU
56
  def predict(image, resolution, weights_file):
57
+ global weights_path
58
+ if weights_path != weights_file:
59
  print('*' * 10)
60
+ print('\t1: ', weights_file, weights_path)
61
  # Load BiRefNet with chosen weights
62
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
63
  print('Change weights to:', _weights_file)
64
  birefnet = birefnet.from_pretrained(_weights_file)
65
  birefnet.to(device)
66
  birefnet.eval()
67
+ weights_path = weights_file
68
+ print('\t2: ', weights_file, weights_path)
69
 
70
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
71
  # Image is a RGB numpy array.