File size: 3,597 Bytes
be48827
b2fd97c
e2cf2b0
b2fd97c
06fe617
e2cf2b0
 
 
 
 
 
c1a6745
e2cf2b0
ff2c2a9
 
 
e9ceefd
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1a6745
e2cf2b0
9a66c24
e2cf2b0
 
8721178
 
 
 
e2cf2b0
8721178
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
8721178
e2cf2b0
 
 
 
8721178
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124

import os
import pickle
import sys
import subprocess
import imageio
import numpy as np
import scipy.interpolate
import torch
from tqdm import tqdm
import gradio as gr 


# Fetch the official StyleGAN3 repository and put it on sys.path. This must
# happen before the generator pickle is loaded further down; presumably the
# pickle references classes from that repo (dnnlib / torch_utils) — confirm.
STYLEGAN3_DIR = "stylegan3"
if not os.path.isdir(STYLEGAN3_DIR):
    # List-form argv: no shell involved. check=True turns a failed clone into
    # an immediate, explicit error instead of a confusing failure at unpickling.
    subprocess.run(
        ["git", "clone", "https://github.com/NVlabs/stylegan3", STYLEGAN3_DIR],
        check=True,
    )

sys.path.append(STYLEGAN3_DIR)

def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Args:
        img: Tensor of shape (batch, channels, height, width).
        grid_w: Number of grid columns; defaults to ``batch // grid_h``.
        grid_h: Number of grid rows.
        float_to_uint8: If true, map values with ``x * 127.5 + 128``, clamp to
            [0, 255] and cast to uint8 (i.e. treat input as roughly [-1, 1]).
        chw_to_hwc: If true, return channels-last (H, W, C) layout.
        to_numpy: If true, move to CPU and return a NumPy array.

    Returns:
        The grid image, ``grid_h * height`` tall and ``grid_w * width`` wide.
    """
    n, c, h, w = img.shape
    cols = n // grid_h if grid_w is None else grid_w
    assert n == cols * grid_h
    if float_to_uint8:
        img = torch.clamp(img * 127.5 + 128, 0, 255).to(torch.uint8)
    # (rows, cols, C, H, W) -> (C, rows, H, cols, W): the final reshape then
    # places each row's images side by side and stacks the rows vertically.
    tiled = (
        img.reshape(grid_h, cols, c, h, w)
        .permute(2, 0, 3, 1, 4)
        .reshape(c, grid_h * h, cols * w)
    )
    if chw_to_hwc:
        tiled = tiled.permute(1, 2, 0)
    return tiled.cpu().numpy() if to_numpy else tiled




# Trained generator snapshot. NOTE(review): pickle.load can execute arbitrary
# code, so network_pkl must only ever point at a trusted file, never at
# user-supplied data.
network_pkl = 'braingan-400.pkl'
with open(network_pkl, 'rb') as f:
    # The snapshot dict stores the generator used for inference under 'G_ema'.
    G = pickle.load(f)['G_ema']

# Prefer the GPU, but fall back to CPU so the app can still start on hosts
# without CUDA (much slower). On CUDA hosts this matches the original exactly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G.eval()
G.to(device)
# Grid shape (grid_w, grid_h) for each option of the 'Image Grid' radio.
_GRID_SHAPES = {'4x2': (4, 2), '2x1': (2, 1)}


def _make_interpolators(ws, grid_w, grid_h, num_keyframes, wraps, kind):
    """Build a grid_h x grid_w nested list of W-space interpolators.

    ``ws`` must already be reshaped to (grid_h, grid_w, num_keyframes, ...).
    Each cell's keyframes are tiled ``2 * wraps + 1`` times along the time
    axis. The curve evaluated on [0, num_keyframes) is then shaped by
    wrapped-around neighbours at both ends, presumably so the video loops
    smoothly.
    """
    # Same sample positions for every cell: (2*wraps+1)*num_keyframes points,
    # matching the length of the tiled values below.
    x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
    grid = []
    for yi in range(grid_h):
        row = []
        for xi in range(grid_w):
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            row.append(scipy.interpolate.interp1d(x, y, kind=kind, axis=0))
        grid.append(row)
    return grid


def _render_video(path, grid, grid_w, grid_h, num_keyframes, w_frames):
    """Synthesize every frame from the interpolators in ``grid`` and write an MP4 to ``path``."""
    video_out = imageio.get_writer(path, mode='I', fps=60, codec='libx264')
    try:
        for frame_idx in tqdm(range(num_keyframes * w_frames)):
            t = frame_idx / w_frames
            imgs = []
            # Row-major order, matching the layout expected by layout_grid.
            for row in grid:
                for interp in row:
                    w = torch.from_numpy(interp(t)).to(device)
                    imgs.append(G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0])
            video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    finally:
        # Finalize and release the file even if synthesis fails mid-video.
        video_out.close()


def predict(Seed, choices):
    """Render an MP4 of a grid of generated images driven by consecutive seeds.

    Args:
        Seed: Value from the 'Seed' slider. The grid uses the integer seeds
            ``int(Seed) - grid_w*grid_h`` .. ``int(Seed) - 1``, one per cell.
        choices: Grid layout, either '4x2' or '2x1'.

    Returns:
        Path of the written video file ('ex.mp4').

    Raises:
        ValueError: If ``choices`` is not a supported layout.
    """
    try:
        grid_w, grid_h = _GRID_SHAPES[choices]
    except KeyError:
        raise ValueError(f'Unsupported grid layout: {choices!r}') from None

    shuffle_seed = None
    w_frames = 60 * 4  # frames rendered per keyframe
    kind = 'cubic'
    wraps = 2
    psi = 1
    mp4 = 'ex.mp4'

    # The slider may deliver a float; cast once so np.arange yields exactly
    # grid_w*grid_h integer seeds (float bounds can produce an extra element
    # through rounding). int() truncates just like the original int64 cast.
    s1 = int(Seed)
    num_cells = grid_w * grid_h
    seeds = np.arange(s1 - num_cells, s1).tolist()

    if len(seeds) % num_cells != 0:
        raise ValueError('Number of input seeds must be divisible by grid W*H')
    # NOTE(review): with exactly one seed per cell, num_keyframes is 1. Each
    # cell's tiled keyframes are then the same W, the interpolated curve is
    # constant, and every frame of the video is identical. Supply more seeds
    # (a multiple of grid_w*grid_h) if an animation is intended.
    num_keyframes = len(seeds) // num_cells

    # Repeat the seed list cyclically to fill num_keyframes * num_cells slots.
    all_seeds = np.resize(np.asarray(seeds, dtype=np.int64), num_keyframes * num_cells)
    if shuffle_seed is not None:
        np.random.RandomState(seed=shuffle_seed).shuffle(all_seeds)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        zs = torch.from_numpy(
            np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
        ).to(device)
        ws = G.mapping(z=zs, c=None, truncation_psi=psi)
        G.synthesis(ws[:1])  # warm up
        ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

        grid = _make_interpolators(ws, grid_w, grid_h, num_keyframes, wraps, kind)
        _render_video(mp4, grid, grid_w, grid_h, num_keyframes, w_frames)
    return mp4



# Options for the 'Image Grid' radio; predict handles '4x2' and '2x1'.
choices = ['4x2', '2x1']
interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description="",
    article="Author: S.Serdar Helli",
    inputs=[
        # step=1 keeps the slider on whole numbers; without an explicit step
        # it can hand predict fractional seeds.
        gr.inputs.Slider(minimum=16, maximum=2**10, step=1, label='Seed'),
        gr.inputs.Radio(choices=choices, default='4x2', label='Image Grid'),
    ],
    outputs=gr.outputs.Video(label='Video'),
)


interface.launch(debug=True)