wondervictor commited on
Commit
0370e48
·
verified ·
1 Parent(s): b82061c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -91,16 +91,14 @@ def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
91
  # CLIP模型初始化
92
  if clip_model is None:
93
  clip_model, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
 
94
  print("CLIP model initialized.")
95
 
96
  # Mask Adapter模型初始化
97
  if mask_adapter is None:
98
- mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").cpu()
99
  # 加载Adapter状态字典
100
- adapter_state_dict = torch.load(adapter_pth)
101
- adapter_state_dict = {k.replace('mask_adapter.', '').replace('adapter.', ''): v
102
- for k, v in adapter_state_dict["model"].items()
103
- if k.startswith('adapter') or k.startswith('mask_adapter')}
104
  mask_adapter.load_state_dict(adapter_state_dict)
105
  print("Mask Adapter model initialized.")
106
 
 
91
  # CLIP模型初始化
92
  if clip_model is None:
93
  clip_model, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
94
+ clip_model = clip_model.to("cpu")
95
  print("CLIP model initialized.")
96
 
97
  # Mask Adapter模型初始化
98
  if mask_adapter is None:
99
+ mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").cuda()
100
  # 加载Adapter状态字典
101
+ adapter_state_dict = torch.load(adapter_pth, map_location=torch.device('cpu'))
 
 
 
102
  mask_adapter.load_state_dict(adapter_state_dict)
103
  print("Mask Adapter model initialized.")
104