import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC
DEVICE = 'cpu'

# Which generator variant to run, and where each variant's checkpoint lives.
model_type = 'model_b'
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}
if model_type == 'model_a':
    generator = GA()
elif model_type == 'model_b':
    generator = GB()
elif model_type == 'model_c':
    generator = GC()
else:
    # Fail loudly instead of leaving `generator` undefined.
    raise ValueError(f'Unknown model_type: {model_type}')
# The checkpoint keys evidently carry DataParallel's 'module.' prefix, so the
# model is wrapped before loading and unwrapped with .module afterwards.
generator = torch.nn.DataParallel(generator)
state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)
generator = generator.module.to(DEVICE)
generator.eval()
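# Alternative sketch (not in the original script): the same checkpoint can be
# loaded without the DataParallel detour by stripping the 'module.' prefix
# from every key first (str.removeprefix needs Python 3.9+):
#
#   state_dict = torch.load(modeltype2path[model_type], map_location='cpu')
#   state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}
#   generator.load_state_dict(state_dict)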
# Preprocessing: single-channel, 512x512, scaled to [0, 1] by ToTensor().
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])
input_img = Image.open('demo_imgs/fake.jpg')
# Add a batch dimension; the duplicate .to(DEVICE) in the original is dropped.
torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)
with torch.no_grad():
    output = generator(torch_img)
sr, sr_dem_selected = output[0], output[1]

# Save the super-resolved image.
sr = sr.squeeze(0).cpu()
print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')

# Render the predicted DEM with a colour map and save the figure
# (.detach() is unnecessary here since the output was produced under no_grad).
sr_dem_selected = sr_dem_selected.squeeze().cpu().numpy()
print(sr_dem_selected.shape)
plt.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
plt.colorbar()
plt.savefig('test.png')
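# Optional addition (not in the original script): 'test.png' stores only the
# colour-mapped rendering, so keep the raw elevation values alongside it
# (assumption: a .npy sidecar is an acceptable format here).
np.save('sr_dem.npy', sr_dem_selected)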