Ahsen Khaliq committed
Commit 46c8e4c
1 Parent(s): 4d6f95f

Update app.py

Files changed (1): app.py  +16 -56
app.py CHANGED
@@ -9,25 +9,20 @@ import math
 import gradio as gr
 from torchvision import transforms
 import torchtext
-
+from stat import ST_CTIME
+from datetime import datetime, timedelta
+import shutil
 torch.hub.download_url_to_file('https://i.imgur.com/WEHmKef.jpg', 'gpu.jpg')
-
 # Images
 torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2021/08/04/14/16/tower-6521842_1280.jpg', 'tower.jpg')
 torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2017/08/31/05/36/buildings-2699520_1280.jpg', 'city.jpg')
-
 idx = 0
-
 torchtext.utils.download_from_url("https://drive.google.com/uc?id=1NDD54BLligyr8tzo8QGI5eihZisXK1nq", root=".")
-
-
 def to_PIL_img(img):
     result = Image.fromarray((img.data.cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
     return result
 def save_img(img, output_path):
     to_PIL_img(img).save(output_path)
-
-
 def param2stroke(param, H, W, meta_brushes):
     """
     Input a set of stroke parameters and output its corresponding foregrounds and alpha maps.
@@ -38,7 +33,6 @@ def param2stroke(param, H, W, meta_brushes):
         W: output width.
         meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
             The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
-
     Returns:
         foregrounds: a tensor with shape n_strokes x 3 x H x W, containing color information.
         alphas: a tensor with shape n_strokes x 3 x H x W,
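Note on how these outputs are consumed downstream (a sketch for orientation, not code from this commit): each stroke's foreground is typically alpha-blended onto the canvas, gated by its stroke decision. Assuming `foreground`, `alpha`, `decision`, and `canvas` are placeholders shaped as in the docstring above:

    # Hedged sketch of the blending step; all names are placeholders.
    canvas = foreground * alpha * decision + canvas * (1 - alpha * decision)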
@@ -61,7 +55,6 @@ def param2stroke(param, H, W, meta_brushes):
     index[h > w] = 0
     index[h <= w] = 1
     brush = meta_brushes_resize[index.long()]
-
     # Calculate warp matrix according to the rules defined by pytorch, in order for warping.
     warp_00 = cos_theta / w
     warp_01 = sin_theta * H / (W * w)
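The `warp_*` entries computed here plausibly fill per-stroke 2x3 affine matrices in the convention expected by PyTorch's grid-based warping. A minimal, self-contained sketch of that API (shapes and values are illustrative, not taken from this file):

    import torch
    import torch.nn.functional as F

    n = 4                             # placeholder number of strokes
    theta = torch.zeros(n, 2, 3)      # rows assembled from the warp_* entries
    theta[:, 0, 0] = 1.0              # identity warp, for demonstration
    theta[:, 1, 1] = 1.0
    brush = torch.rand(n, 3, 64, 64)  # stand-in for the resized meta brushes
    grid = F.affine_grid(theta, brush.size(), align_corners=False)
    warped = F.grid_sample(brush, grid, align_corners=False)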
@@ -87,8 +80,6 @@ def param2stroke(param, H, W, meta_brushes):
     foreground = morphology.dilation(foreground)
     alphas = morphology.erosion(alphas)
     return foreground, alphas
-
-
 def param2img_serial(
         param, decision, meta_brushes, cur_canvas, frame_dir, has_border=False, original_h=None, original_w=None, *, all_frames):
     """
@@ -111,7 +102,6 @@ def param2img_serial(
             on the border before saving, or there would be a black border.
         original_h: to indicate the original height for cropping when saving intermediate results.
         original_w: to indicate the original width for cropping when saving intermediate results.
-
     Returns:
         cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results.
     """
@@ -133,7 +123,6 @@ def param2img_serial(
     odd_y_even_x_coord_y, odd_y_even_x_coord_x = torch.meshgrid([odd_idx_y, even_idx_x])
     cur_canvas = F.pad(cur_canvas, [patch_size_x // 4, patch_size_x // 4,
                                     patch_size_y // 4, patch_size_y // 4, 0, 0, 0, 0])
-
     def partial_render(this_canvas, patch_coord_y, patch_coord_x, stroke_id):
         canvas_patch = F.unfold(this_canvas, (patch_size_y, patch_size_x),
                                 stride=(patch_size_y // 2, patch_size_x // 2))
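`F.unfold` here slices the padded canvas into overlapping patches at half-patch stride. A self-contained sketch of that call with placeholder sizes:

    import torch
    import torch.nn.functional as F

    canvas = torch.rand(1, 3, 64, 64)              # placeholder canvas
    patches = F.unfold(canvas, (32, 32), stride=(16, 16))
    # patches: 1 x (3 * 32 * 32) x 9, i.e. the "b, 3 * py * px, h * w" layout
    # referenced in the comments below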
@@ -161,17 +150,14 @@ def param2img_serial(
         this_canvas = this_canvas.view(b, 3, selected_h * patch_size_y, selected_w * patch_size_x).contiguous()
         # this_canvas: b, 3, selected_h * py, selected_w * px
         return this_canvas
-
     global idx
     if has_border:
         factor = 2
     else:
         factor = 4
-
     def store_frame(img):
         all_frames.append(to_PIL_img(img))
 
-
     if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
         for i in range(s):
             canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x, i)
@@ -186,7 +172,6 @@ def param2img_serial(
                                         patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
                 store_frame(frame[0])
-
     if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
         for i in range(s):
             canvas = partial_render(cur_canvas, odd_y_odd_x_coord_y, odd_y_odd_x_coord_x, i)
@@ -203,7 +188,6 @@ def param2img_serial(
                                         patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
                 store_frame(frame[0])
-
     if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
         for i in range(s):
             canvas = partial_render(cur_canvas, odd_y_even_x_coord_y, odd_y_even_x_coord_x, i)
@@ -219,7 +203,6 @@ def param2img_serial(
                                         patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
                 store_frame(frame[0])
-
     if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
         for i in range(s):
             canvas = partial_render(cur_canvas, even_y_odd_x_coord_y, even_y_odd_x_coord_x, i)
@@ -235,12 +218,8 @@ def param2img_serial(
                                         patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
                 store_frame(frame[0])
-
     cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
-
     return cur_canvas
-
-
 def param2img_parallel(param, decision, meta_brushes, cur_canvas):
     """
     Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
@@ -255,7 +234,6 @@ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
             The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
         cur_canvas: a tensor with shape batch size x 3 x H x W,
             where H and W denote height and width of padded results of original images.
-
     Returns:
         cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results.
     """
@@ -289,11 +267,8 @@ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
     alphas = alphas.view(-1, h, w, s, 3, patch_size_y, patch_size_x).contiguous()
     # foreground, alpha: b, h, w, stroke_per_patch, 3, render_size_y, render_size_x
     decision = decision.view(-1, h, w, s, 1, 1, 1).contiguous()
-
     # decision: b, h, w, stroke_per_patch, 1, 1, 1
-
     def partial_render(this_canvas, patch_coord_y, patch_coord_x):
-
         canvas_patch = F.unfold(this_canvas, (patch_size_y, patch_size_x),
                                 stride=(patch_size_y // 2, patch_size_x // 2))
         # canvas_patch: b, 3 * py * px, h * w
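The reshape of `decision` to `b, h, w, s, 1, 1, 1` relies on broadcasting: the trailing singleton dimensions let one scalar per stroke gate an entire 3 x py x px patch. A minimal sketch with placeholder sizes:

    import torch

    b, h, w, s, py, px = 1, 2, 2, 4, 32, 32        # placeholders
    foreground = torch.rand(b, h, w, s, 3, py, px)
    decision = torch.rand(b, h, w, s, 1, 1, 1)
    gated = foreground * decision                  # broadcasts over colour and pixels
    # gated keeps shape b, h, w, s, 3, py, px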
@@ -317,7 +292,6 @@ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
         this_canvas = this_canvas.view(b, 3, h_half * patch_size_y, w_half * patch_size_x).contiguous()
         # this_canvas: b, 3, h_half * py, w_half * px
         return this_canvas
-
     if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
         canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x)
         if not is_odd_y:
@@ -325,7 +299,6 @@ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
         if not is_odd_x:
             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
         cur_canvas = canvas
-
     if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
         canvas = partial_render(cur_canvas, odd_y_odd_x_coord_y, odd_y_odd_x_coord_x)
         canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, -canvas.shape[3]:], canvas], dim=2)
@@ -335,7 +308,6 @@ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
         if is_odd_x:
             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
         cur_canvas = canvas
-
     if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
         canvas = partial_render(cur_canvas, odd_y_even_x_coord_y, odd_y_even_x_coord_x)
         canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, :canvas.shape[3]], canvas], dim=2)
@@ -344,7 +316,6 @@ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
         if not is_odd_x:
             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
         cur_canvas = canvas
-
     if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
         canvas = partial_render(cur_canvas, even_y_odd_x_coord_y, even_y_odd_x_coord_x)
         canvas = torch.cat([cur_canvas[:, :, :canvas.shape[2], :patch_size_x // 2], canvas], dim=3)
@@ -353,12 +324,8 @@ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
         if is_odd_x:
             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
         cur_canvas = canvas
-
     cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
-
     return cur_canvas
-
-
 def read_img(img_path, img_type='RGB', h=None, w=None):
     img = Image.open(img_path).convert(img_type)
     if h is not None and w is not None:
@@ -369,8 +336,6 @@ def read_img(img_path, img_type='RGB', h=None, w=None):
     img = img.transpose((2, 0, 1))
     img = torch.from_numpy(img).unsqueeze(0).float() / 255.
     return img
-
-
 def pad(img, H, W):
     b, c, h, w = img.shape
     pad_h = (H - h) // 2
@@ -382,8 +347,6 @@ def pad(img, H, W):
     img = torch.cat([torch.zeros((b, c, H, pad_w), device=img.device), img,
                      torch.zeros((b, c, H, pad_w + remainder_w), device=img.device)], dim=-1)
     return img
-
-
 def crop(img, h, w):
     H, W = img.shape[-2:]
     pad_h = (H - h) // 2
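`pad` and `crop` are inverses by construction: `pad` centres an image in an H x W zero canvas, and `crop` removes the same margins. A quick round-trip check, assuming the two functions as defined in this file:

    import torch

    x = torch.rand(1, 3, 100, 120)                 # placeholder image tensor
    # pad to 128 x 128, then crop back to the original size
    assert torch.equal(crop(pad(x, 128, 128), 100, 120), x)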
@@ -392,11 +355,21 @@ def crop(img, h, w):
     remainder_w = (W - w) % 2
     img = img[:, :, pad_h:H - pad_h - remainder_h, pad_w:W - pad_w - remainder_w]
     return img
-
-
 def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
     if not os.path.exists(output_dir):
         os.mkdir(output_dir)
+
+    for entry in os.listdir(output_dir):
+        path = os.path.join(output_dir, entry)
+        stats = os.stat(path)
+        created_time = datetime.fromtimestamp(stats[ST_CTIME])
+        if created_time < datetime.now() - timedelta(minutes=10):
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            else:
+                os.remove(path)
+
+
     input_name = os.path.basename(input_path)
     output_path = os.path.join(output_dir, input_name)
     frame_dir = None
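The block added above prunes anything in `output_dir` older than ten minutes, keyed on `ST_CTIME`. Worth noting: on Unix `st_ctime` is metadata-change time rather than creation time (it is creation time only on Windows), so a modification-time variant may be closer to the intent. A hedged alternative sketch, not the committed code:

    import os
    import shutil
    import time

    def prune_older_than(directory, max_age_seconds=600):
        # remove stale files and directories based on last modification time
        cutoff = time.time() - max_age_seconds
        for entry in os.listdir(directory):
            path = os.path.join(directory, entry)
            if os.stat(path).st_mtime < cutoff:
                (shutil.rmtree if os.path.isdir(path) else os.remove)(path)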
@@ -415,12 +388,10 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
     net_g.eval()
     for param in net_g.parameters():
         param.requires_grad = False
-
     brush_large_vertical = read_img('brush/brush_large_vertical.png', 'L').to(device)
     brush_large_horizontal = read_img('brush/brush_large_horizontal.png', 'L').to(device)
     meta_brushes = torch.cat(
         [brush_large_vertical, brush_large_horizontal], dim=0)
-
     with torch.no_grad():
         original_img = read_img(input_path, 'RGB', resize_h, resize_w).to(device)
         original_h, original_w = original_img.shape[-2:]
@@ -438,14 +409,12 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
                                     stride=(patch_size, patch_size))
             # There are patch_num * patch_num patches in total
             patch_num = (layer_size - patch_size) // patch_size + 1
-
             # img_patch, result_patch: b, 3 * output_size * output_size, h * w
             img_patch = img_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
             result_patch = result_patch.permute(0, 2, 1).contiguous().view(
                 -1, 3, patch_size, patch_size).contiguous()
             shape_param, stroke_decision = net_g(img_patch, result_patch)
             stroke_decision = network.SignWithSigmoidGrad.apply(stroke_decision)
-
             grid = shape_param[:, :, :2].view(img_patch.shape[0] * stroke_num, 1, 1, 2).contiguous()
             img_temp = img_patch.unsqueeze(1).contiguous().repeat(1, stroke_num, 1, 1, 1).view(
                 img_patch.shape[0] * stroke_num, 3, patch_size, patch_size).contiguous()
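The `patch_num` formula counts non-overlapping patches along one side. A worked example with illustrative sizes (not values from this run): layer_size = 1024 and patch_size = 32 give (1024 - 32) // 32 + 1 = 32 patches per side, hence 32 * 32 = 1024 patches in total.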
@@ -465,7 +434,6 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
                                                   frame_dir, False, original_h, original_w, all_frames=all_frames)
             else:
                 final_result = param2img_parallel(param, decision, meta_brushes, final_result)
-
         border_size = original_img_pad_size // (2 * patch_num)
         img = F.interpolate(original_img_pad, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
         result = F.interpolate(final_result, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
@@ -482,7 +450,6 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
             img_patch = img_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
             result_patch = result_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
             shape_param, stroke_decision = net_g(img_patch, result_patch)
-
             grid = shape_param[:, :, :2].view(img_patch.shape[0] * stroke_num, 1, 1, 2).contiguous()
             img_temp = img_patch.unsqueeze(1).contiguous().repeat(1, stroke_num, 1, 1, 1).view(
                 img_patch.shape[0] * stroke_num, 3, patch_size, patch_size).contiguous()
@@ -503,17 +470,13 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
             else:
                 final_result = param2img_parallel(param, decision, meta_brushes, final_result)
             final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
-
         final_result = crop(final_result, original_h, original_w)
         save_img(final_result[0], output_path)
         tensor_to_pil = transforms.ToPILImage()(final_result[0].squeeze_(0))
         #return tensor_to_pil
-
         all_frames[0].save(os.path.join(frame_dir, 'animation.gif'),
                            save_all=True, append_images=all_frames[1:], optimize=False, duration=40, loop=0)
         return os.path.join(frame_dir, "animation.gif"), tensor_to_pil
-
-
 
 def gradio_inference(image):
     return main(input_path=image.name,
@@ -523,7 +486,6 @@ def gradio_inference(image):
                 resize_h=400,  # resize original input to this size. None means do not resize.
                 resize_w=400,  # resize original input to this size. None means do not resize.
                 serial=True)  # if need animation, serial must be True.
-
 inferences_running = 0
 def throttled_inference(image):
     global inferences_running
@@ -538,11 +500,9 @@ def throttled_inference(image):
     finally:
         print("Inference finished")
         inferences_running -= 1
-
 title = "Paint Transformer"
 description = "Gradio demo for Paint Transformer: Feed Forward Neural Painting with Stroke Prediction. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.03798'>Paint Transformer: Feed Forward Neural Painting with Stroke Prediction</a> | <a href='https://github.com/Huage001/PaintTransformer'>Github Repo</a></p>"
-
 gr.Interface(
     throttled_inference,
     gr.inputs.Image(type="file", label="Input"),
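`throttled_inference` guards the demo with a bare global counter, which is not atomic under a threaded server. A thread-safe alternative sketch (the limit and all names here are placeholders, not from this commit):

    import threading

    MAX_CONCURRENT = 3                    # placeholder limit
    _slots = threading.Semaphore(MAX_CONCURRENT)

    def throttled(fn, *args):
        # refuse the request immediately if all slots are taken
        if not _slots.acquire(blocking=False):
            raise RuntimeError("Too many inferences running, try again soon")
        try:
            return fn(*args)
        finally:
            _slots.release()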
@@ -555,4 +515,4 @@ gr.Interface(
         ['city.jpg'],
         ['tower.jpg']
     ]
-).launch(debug=True)
+).launch(debug=True)