Spaces:
Running
Running
Update visualization.py
Browse files- visualization.py +4 -4
visualization.py
CHANGED
@@ -131,9 +131,9 @@ def genreate_intepriable_output(input,dataset="CUB2011", arch="resnet50",seed=12
|
|
131 |
device = torch.device("cpu")
|
132 |
if folder is None:
|
133 |
folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
|
134 |
-
model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth"))
|
135 |
-
state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
|
136 |
-
selection= torch.load(folder / f"SlDD_Selection_50.pt")
|
137 |
state_dict['linear.selection']=selection
|
138 |
|
139 |
feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
|
@@ -370,7 +370,7 @@ def direct_inference(input):
|
|
370 |
|
371 |
#original
|
372 |
|
373 |
-
data_dir=Path
|
374 |
classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
|
375 |
output_name=classlist.loc[classlist['cl_id']==output,'class_name'].values[0]
|
376 |
if concatenated_image[0][1]!=[]:
|
|
|
131 |
device = torch.device("cpu")
|
132 |
if folder is None:
|
133 |
folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
|
134 |
+
model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth", map_location=torch.device('cpu')))
|
135 |
+
state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth", map_location=torch.device('cpu'))
|
136 |
+
selection= torch.load(folder / f"SlDD_Selection_50.pt", map_location=torch.device('cpu'))
|
137 |
state_dict['linear.selection']=selection
|
138 |
|
139 |
feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
|
|
|
370 |
|
371 |
#original
|
372 |
|
373 |
+
data_dir=Path("tmp/Datasets/CUB200/CUB_200_2011/")
|
374 |
classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
|
375 |
output_name=classlist.loc[classlist['cl_id']==output,'class_name'].values[0]
|
376 |
if concatenated_image[0][1]!=[]:
|