Haaribo commited on
Commit
0df2aff
·
verified ·
1 Parent(s): 8e3d21b

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +4 -4
visualization.py CHANGED
@@ -131,9 +131,9 @@ def genreate_intepriable_output(input,dataset="CUB2011", arch="resnet50",seed=12
131
  device = torch.device("cpu")
132
  if folder is None:
133
  folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
134
- model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth"))
135
- state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
136
- selection= torch.load(folder / f"SlDD_Selection_50.pt")
137
  state_dict['linear.selection']=selection
138
 
139
  feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
@@ -370,7 +370,7 @@ def direct_inference(input):
370
 
371
  #original
372
 
373
- data_dir=Path.home()/"tmp/Datasets/CUB200/CUB_200_2011/"
374
  classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
375
  output_name=classlist.loc[classlist['cl_id']==output,'class_name'].values[0]
376
  if concatenated_image[0][1]!=[]:
 
131
  device = torch.device("cpu")
132
  if folder is None:
133
  folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
134
+ model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth", map_location=torch.device('cpu')))
135
+ state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth", map_location=torch.device('cpu'))
136
+ selection= torch.load(folder / f"SlDD_Selection_50.pt", map_location=torch.device('cpu'))
137
  state_dict['linear.selection']=selection
138
 
139
  feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
 
370
 
371
  #original
372
 
373
+ data_dir=Path("tmp/Datasets/CUB200/CUB_200_2011/")
374
  classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
375
  output_name=classlist.loc[classlist['cl_id']==output,'class_name'].values[0]
376
  if concatenated_image[0][1]!=[]: