File size: 4,238 Bytes
be48827
b2fd97c
e2cf2b0
b2fd97c
06fe617
e2cf2b0
 
 
 
 
 
c1a6745
e2cf2b0
ff2c2a9
 
 
e9ceefd
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1a6745
e2cf2b0
9a66c24
e2cf2b0
 
8721178
fafc0c2
8721178
 
dad6ea4
8721178
e2cf2b0
 
 
 
 
dad6ea4
e2cf2b0
 
 
 
 
dad6ea4
 
 
 
 
 
 
e2cf2b0
 
 
dad6ea4
 
 
 
 
 
 
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcc2206
e2cf2b0
 
 
 
 
 
 
 
 
 
dcc2206
e2cf2b0
 
 
 
dad6ea4
 
e2cf2b0
 
 
dad6ea4
 
 
061565a
dcc2206
061565a
e2cf2b0
 
728f8d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

import os
import pickle
import sys
import subprocess
import imageio
import numpy as np
import scipy.interpolate
import torch
from tqdm import tqdm
import gradio as gr 


# Make the StyleGAN3 repository importable.
# NOTE(review): presumably the pickled generator loaded below references
# classes from this repo, so it must be on sys.path before unpickling.
STYLEGAN3_REPO = "https://github.com/NVlabs/stylegan3"
STYLEGAN3_DIR = "stylegan3"

if not os.path.isdir(STYLEGAN3_DIR):
    # List-form argv avoids a shell. check=True reports a failed clone here
    # rather than as a confusing error later. Skipping an existing checkout
    # avoids re-cloning (which fails) on every restart.
    subprocess.run(["git", "clone", STYLEGAN3_REPO, STYLEGAN3_DIR], check=True)

sys.path.append(STYLEGAN3_DIR)

def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image, filled row by row.

    Args:
        img: Tensor of shape (batch, channels, height, width).
        grid_w: Number of grid columns; defaults to ``batch // grid_h``.
        grid_h: Number of grid rows.
        float_to_uint8: Map values via ``x * 127.5 + 128`` and clamp to uint8,
            so [-1, 1] lands on [0, 255]; out-of-range values saturate.
        chw_to_hwc: Return channels-last instead of channels-first.
        to_numpy: Move the result to the CPU and return a NumPy array.

    Returns:
        ``(grid_h*H, grid_w*W, C)`` if ``chw_to_hwc`` else
        ``(C, grid_h*H, grid_w*W)``.

    Raises:
        ValueError: if the batch size is not exactly ``grid_w * grid_h``.
    """
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        grid_w = batch_size // grid_h
    # An explicit check: `assert` would be stripped under `python -O`.
    if batch_size != grid_w * grid_h:
        raise ValueError(
            f'batch size {batch_size} does not fill a {grid_w}x{grid_h} grid'
        )
    if float_to_uint8:
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) so that the reshape
    # below stacks rows vertically and columns horizontally.
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img




# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pull the generator stored under the 'G_ema' key of the local checkpoint.
# NOTE(review): pickle.load can execute arbitrary code; only ship trusted
# checkpoint files with this app.
network_pkl = 'braingan-400.pkl'
with open(network_pkl, 'rb') as f:
    G = pickle.load(f)['G_ema']

# Put the generator in inference mode on the chosen device.
# Assumes G is a torch.nn.Module, whose eval()/to() act in place.
G.eval()
G.to(device)
# Grid layouts offered in the UI: label -> (grid_w, grid_h).
_GRID_SHAPES = {
    '4x2': (4, 2),
    '2x1': (2, 1),
}

# Video sizes offered in the UI: label -> (num_keyframes, w_frames), where
# w_frames is the number of frames rendered per keyframe (video is 30 fps).
# NOTE(review): with a single keyframe per cell ("Small Video") every tiled
# keyframe is identical, so each cell's latent stays constant for the whole
# clip. Confirm this is intended.
_VIDEO_SIZES = {
    "Large Video": (2, 60 * 4),
    "Small Video": (1, 30 * 4),
}


def _grid_layout(choices, choices2, s1):
    """Translate the UI selections into ``(grid_w, grid_h, seeds, w_frames)``.

    ``seeds`` holds the ``grid_w * grid_h * num_keyframes`` consecutive
    integers ending just below ``s1``.

    Raises:
        ValueError: for an unknown selection, or when ``s1`` is too small to
            keep every seed non-negative.
    """
    try:
        grid_w, grid_h = _GRID_SHAPES[choices]
    except KeyError:
        raise ValueError(f'Unknown image grid: {choices!r}') from None
    try:
        num_keyframes, w_frames = _VIDEO_SIZES[choices2]
    except KeyError:
        raise ValueError(f'Unknown video size: {choices2!r}') from None

    num_seeds = grid_w * grid_h * num_keyframes
    if s1 < num_seeds:
        # np.random.RandomState rejects negative seeds.
        raise ValueError(f'Seed must be at least {num_seeds} for this layout')
    return grid_w, grid_h, list(range(s1 - num_seeds, s1)), w_frames


def predict(Seed, choices, choices2):
    """Render a latent-interpolation video for the Gradio UI.

    Args:
        Seed: Slider value. The grid is seeded with the consecutive integers
            just below it.
        choices: Image grid label, '4x2' or '2x1'.
        choices2: "Large Video" (2 keyframes per cell, 240 frames each) or
            "Small Video" (1 keyframe per cell, 120 frames).

    Returns:
        Path of the rendered MP4 file, always 'ex.mp4'.

    Raises:
        ValueError: for an unknown selection or a seed that is too small.
    """
    kind = 'cubic'  # interp1d kind used between keyframe latents
    wraps = 2       # extra copies of the keyframes on each side of the spline
    psi = 1         # truncation_psi passed to G.mapping
    mp4 = 'ex.mp4'

    # Slider values may arrive as floats; seeds must be whole numbers.
    grid_w, grid_h, seeds, w_frames = _grid_layout(choices, choices2, int(Seed))
    num_keyframes = len(seeds) // (grid_w * grid_h)
    all_seeds = np.asarray(seeds, dtype=np.int64)

    # Inference only: no autograd graph is needed, and the latents are
    # converted with .numpy(), which requires tensors that don't need grad.
    with torch.no_grad():
        zs = torch.from_numpy(np.stack(
            [np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds]
        )).to(device)
        ws = G.mapping(z=zs, c=None, truncation_psi=psi)
        _ = G.synthesis(ws[:1])  # warm up
        # Cell (yi, xi) takes num_keyframes consecutive seeds, row-major.
        ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

        # Interpolation: tile each cell's keyframes `wraps` times on both
        # sides so the spline wraps around smoothly over [0, num_keyframes).
        x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
        grid = [
            [
                scipy.interpolate.interp1d(
                    x,
                    np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]),
                    kind=kind,
                    axis=0,
                )
                for xi in range(grid_w)
            ]
            for yi in range(grid_h)
        ]

        # Render video; always close the writer, even if synthesis fails midway.
        video_out = imageio.get_writer(mp4, mode='I', fps=30, codec='libx264')
        try:
            for frame_idx in tqdm(range(num_keyframes * w_frames)):
                t = frame_idx / w_frames
                imgs = [
                    G.synthesis(
                        ws=torch.from_numpy(interp(t)).to(device).unsqueeze(0),
                        noise_mode='const',
                    )[0]
                    for row in grid
                    for interp in row
                ]
                video_out.append_data(
                    layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)
                )
        finally:
            video_out.close()
    return mp4



# Options shown in the UI; predict() interprets these exact labels.
choices = ['4x2', '2x1']
choices2 = ["Large Video", "Small Video"]

# The original inputs list had a stray leading comma and an extra `]`
# (a SyntaxError); it now holds exactly the three input components.
# minimum=16 keeps every seed non-negative: predict uses up to 16 seeds
# below Seed (the 4x2 "Large Video" layout).
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces;
# confirm the pinned Gradio version still provides them.
interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description="",
    article="Author: S.Serdar Helli",
    inputs=[
        gr.inputs.Slider(minimum=16, maximum=2**10, label='Seed'),
        gr.inputs.Radio(choices=choices, default='4x2', label='Image Grid'),
        gr.inputs.Radio(
            choices=choices2,
            default="Small Video",
            label='Video Size - It depends on usage of cuda',
        ),
    ],
    outputs=gr.outputs.Video(label='Video'),
    live=False,
)

interface.launch(debug=True, show_error=True)