Linoy Tsaban commited on
Commit
d19d91b
·
1 Parent(s): 5d7ba0f

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +15 -15
utils.py CHANGED
@@ -3,7 +3,7 @@ from PIL import Image, ImageDraw ,ImageFont
3
  from matplotlib import pyplot as plt
4
  import torchvision.transforms as T
5
  import os
6
- import torch
7
  import yaml
8
 
9
  def show_torch_img(img):
@@ -20,14 +20,14 @@ def tensor_to_pil(tensor_imgs):
20
  tensor_imgs = torch.cat(tensor_imgs)
21
  tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
22
  to_pil = T.ToPILImage()
23
- pil_imgs = [to_pil(img) for img in tensor_imgs]
24
  return pil_imgs
25
 
26
  def pil_to_tensor(pil_imgs):
27
  to_torch = T.ToTensor()
28
  if type(pil_imgs) == PIL.Image.Image:
29
  tensor_imgs = to_torch(pil_imgs).unsqueeze(0)*2-1
30
- elif type(pil_imgs) == list:
31
  tensor_imgs = torch.cat([to_torch(pil_imgs).unsqueeze(0)*2-1 for img in pil_imgs]).to(device)
32
  else:
33
  raise Exception("Input need to be PIL.Image or list of PIL.Image")
@@ -40,30 +40,30 @@ def pil_to_tensor(pil_imgs):
40
  # num_col = n // num_rows
41
  # num_col = num_col + 1 if n % num_rows else num_col
42
  # num_col
43
- def add_margin(pil_img, top = 0, right = 0, bottom = 0,
44
  left = 0, color = (255,255,255)):
45
  width, height = pil_img.size
46
  new_width = width + right + left
47
  new_height = height + top + bottom
48
  result = Image.new(pil_img.mode, (new_width, new_height), color)
49
-
50
  result.paste(pil_img, (left, top))
51
  return result
52
 
53
- def image_grid(imgs, rows = 1, cols = None,
54
  size = None,
55
  titles = None, text_pos = (0, 0)):
56
  if type(imgs) == list and type(imgs[0]) == torch.Tensor:
57
  imgs = torch.cat(imgs)
58
  if type(imgs) == torch.Tensor:
59
  imgs = tensor_to_pil(imgs)
60
-
61
  if not size is None:
62
  imgs = [img.resize((size,size)) for img in imgs]
63
  if cols is None:
64
  cols = len(imgs)
65
  assert len(imgs) >= rows*cols
66
-
67
  top=20
68
  w, h = imgs[0].size
69
  delta = 0
@@ -71,23 +71,23 @@ def image_grid(imgs, rows = 1, cols = None,
71
  delta = top
72
  h = imgs[1].size[1]
73
  if not titles is None:
74
- font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
75
  size = 20, encoding="unic")
76
- h = top + h
77
- grid = Image.new('RGB', size=(cols*w, rows*h+delta))
78
  for i, img in enumerate(imgs):
79
-
80
  if not titles is None:
81
  img = add_margin(img, top = top, bottom = 0,left=0)
82
  draw = ImageDraw.Draw(img)
83
- draw.text(text_pos, titles[i],(0,0,0),
84
  font = font)
85
  if not delta == 0 and i > 0:
86
  grid.paste(img, box=(i%cols*w, i//cols*h+delta))
87
  else:
88
  grid.paste(img, box=(i%cols*w, i//cols*h))
89
-
90
- return grid
91
 
92
 
93
  """
 
3
  from matplotlib import pyplot as plt
4
  import torchvision.transforms as T
5
  import os
6
+ import torch
7
  import yaml
8
 
9
  def show_torch_img(img):
 
20
  tensor_imgs = torch.cat(tensor_imgs)
21
  tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
22
  to_pil = T.ToPILImage()
23
+ pil_imgs = [to_pil(img) for img in tensor_imgs]
24
  return pil_imgs
25
 
26
  def pil_to_tensor(pil_imgs):
27
  to_torch = T.ToTensor()
28
  if type(pil_imgs) == PIL.Image.Image:
29
  tensor_imgs = to_torch(pil_imgs).unsqueeze(0)*2-1
30
+ elif type(pil_imgs) == list:
31
  tensor_imgs = torch.cat([to_torch(pil_imgs).unsqueeze(0)*2-1 for img in pil_imgs]).to(device)
32
  else:
33
  raise Exception("Input need to be PIL.Image or list of PIL.Image")
 
40
  # num_col = n // num_rows
41
  # num_col = num_col + 1 if n % num_rows else num_col
42
  # num_col
43
+ def add_margin(pil_img, top = 0, right = 0, bottom = 0,
44
  left = 0, color = (255,255,255)):
45
  width, height = pil_img.size
46
  new_width = width + right + left
47
  new_height = height + top + bottom
48
  result = Image.new(pil_img.mode, (new_width, new_height), color)
49
+
50
  result.paste(pil_img, (left, top))
51
  return result
52
 
53
+ def image_grid(imgs, rows = 1, cols = None,
54
  size = None,
55
  titles = None, text_pos = (0, 0)):
56
  if type(imgs) == list and type(imgs[0]) == torch.Tensor:
57
  imgs = torch.cat(imgs)
58
  if type(imgs) == torch.Tensor:
59
  imgs = tensor_to_pil(imgs)
60
+
61
  if not size is None:
62
  imgs = [img.resize((size,size)) for img in imgs]
63
  if cols is None:
64
  cols = len(imgs)
65
  assert len(imgs) >= rows*cols
66
+
67
  top=20
68
  w, h = imgs[0].size
69
  delta = 0
 
71
  delta = top
72
  h = imgs[1].size[1]
73
  if not titles is None:
74
+ font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
75
  size = 20, encoding="unic")
76
+ h = top + h
77
+ grid = Image.new('RGB', size=(cols*w, rows*h+delta))
78
  for i, img in enumerate(imgs):
79
+
80
  if not titles is None:
81
  img = add_margin(img, top = top, bottom = 0,left=0)
82
  draw = ImageDraw.Draw(img)
83
+ draw.text(text_pos, titles[i],(0,0,0),
84
  font = font)
85
  if not delta == 0 and i > 0:
86
  grid.paste(img, box=(i%cols*w, i//cols*h+delta))
87
  else:
88
  grid.paste(img, box=(i%cols*w, i//cols*h))
89
+
90
+ return grid
91
 
92
 
93
  """