File size: 1,542 Bytes
eba1c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC



# --- Configuration -------------------------------------------------------
DEVICE = 'cpu'
model_type = 'model_b'  # one of: 'model_a', 'model_b', 'model_c'

# Checkpoint path for each generator variant.
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}

# --- Model construction ---------------------------------------------------
# Dispatch table replaces the original if-chain, which left `generator`
# undefined (opaque NameError later) when model_type was unrecognized.
modeltype2class = {'model_a': GA, 'model_b': GB, 'model_c': GC}
try:
    generator = modeltype2class[model_type]()
except KeyError:
    raise ValueError(f'Unknown model_type: {model_type!r}; '
                     f'expected one of {sorted(modeltype2class)}') from None

# The checkpoints were saved from a DataParallel-wrapped model (keys are
# prefixed with "module."), so wrap before loading the state dict, then
# unwrap the bare module for single-device inference.
generator = torch.nn.DataParallel(generator)
state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)
generator = generator.module.to(DEVICE)
generator.eval()

# --- Preprocessing --------------------------------------------------------
# Grayscale -> fixed 512x512 -> float tensor in [0, 1].
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])
input_img = Image.open('demo_imgs/fake.jpg')
# Add a batch dimension and move once to DEVICE (original called
# .to(DEVICE) twice on the same tensor).
torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)

# --- Inference ------------------------------------------------------------
with torch.no_grad():
    output = generator(torch_img)
# The generator returns (at least) two outputs: the super-resolved image
# and the selected DEM map.
sr, sr_dem_selected = output[0], output[1]

# Save the super-resolved image (drop the batch dimension first).
sr = sr.squeeze(0).cpu()
print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')

# Save the selected DEM output as a color-mapped figure with a colorbar.
sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
print(sr_dem_selected.shape)
plt.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
plt.colorbar()
plt.savefig('test.png')