osanseviero commited on
Commit
bee801c
·
1 Parent(s): dde7894

Add model and demo

Browse files
Files changed (7) hide show
  1. Procfile +1 -0
  2. app.py +91 -0
  3. autoencoder_model.png +0 -0
  4. model-final.pth +3 -0
  5. predict.py +79 -0
  6. prediction.ipynb +0 -0
  7. requirements.txt +8 -0
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: sh setup.sh && streamlit run app.py
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import torch
3
+ import torch.nn as nn
4
+ import cv2
5
+ from skimage.color import lab2rgb, rgb2lab, rgb2gray
6
+ from skimage import io
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+
10
+ class ColorizationNet(nn.Module):
11
+ def __init__(self, input_size=128):
12
+ super(ColorizationNet, self).__init__()
13
+
14
+ MIDLEVEL_FEATURE_SIZE = 128
15
+ resnet=models.resnet18(pretrained=True)
16
+ resnet.conv1.weight=nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))
17
+
18
+ self.midlevel_resnet =nn.Sequential(*list(resnet.children())[0:6])
19
+
20
+ self.upsample = nn.Sequential(
21
+ nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
22
+ nn.BatchNorm2d(128),
23
+ nn.ReLU(),
24
+ nn.Upsample(scale_factor=2),
25
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
26
+ nn.BatchNorm2d(64),
27
+ nn.ReLU(),
28
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
29
+ nn.BatchNorm2d(64),
30
+ nn.ReLU(),
31
+ nn.Upsample(scale_factor=2),
32
+ nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
33
+ nn.BatchNorm2d(32),
34
+ nn.ReLU(),
35
+ nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
36
+ nn.Upsample(scale_factor=2)
37
+ )
38
+
39
+ def forward(self, input):
40
+
41
+ # Pass input through ResNet-gray to extract features
42
+ midlevel_features = self.midlevel_resnet(input)
43
+
44
+ # Upsample to get colors
45
+ output = self.upsample(midlevel_features)
46
+ return output
47
+
48
+
49
+
50
+ def show_output(grayscale_input, ab_input):
51
+ '''Show/save rgb image from grayscale and ab channels
52
+ Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
53
+ color_image = torch.cat((grayscale_input, ab_input), 0).detach().numpy() # combine channels
54
+ color_image = color_image.transpose((1, 2, 0)) # rescale for matplotlib
55
+ color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
56
+ color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
57
+ color_image = lab2rgb(color_image.astype(np.float64))
58
+ grayscale_input = grayscale_input.squeeze().numpy()
59
+ # plt.imshow(grayscale_input)
60
+ # plt.imshow(color_image)
61
+ return color_image
62
+
63
+ def colorize(img,print_img=True):
64
+ # img=cv2.imread(img)
65
+ img=cv2.resize(img,(224,224))
66
+ grayscale_input= torch.Tensor(rgb2gray(img))
67
+ ab_input=model(grayscale_input.unsqueeze(0).unsqueeze(0)).squeeze(0)
68
+ predicted=show_output(grayscale_input.unsqueeze(0), ab_input)
69
+ if print_img:
70
+ plt.imshow(predicted)
71
+ return predicted
72
+
73
+ # device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ # torch.load with map_location=torch.device('cpu')
75
+ model=torch.load("model-final.pth",map_location ='cpu')
76
+
77
+
78
+ import streamlit as st
79
+ st.title("Image Colorizer")
80
+
81
+ file=st.file_uploader("Please upload the B/W image",type=["jpg","jpeg","png"])
82
+ print(file)
83
+ if file is None:
84
+ st.text("Please Upload an image")
85
+ else:
86
+ file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
87
+ opencv_image = cv2.imdecode(file_bytes, 1)
88
+ im=colorize(opencv_image)
89
+ st.image(im)
90
+ st.text("Colorized!!")
91
+ # st.image(file)
autoencoder_model.png ADDED
model-final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6268c0b73c7bc3fefd3918d113fb74976f9780f4737bf6e4c088811a1a6872ec
3
+ size 3867929
predict.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.insert(0, './WordLM')
3
+
4
+ import PIL
5
+ import torch
6
+ import torch.nn as nn
7
+ import cv2
8
+ from skimage.color import lab2rgb, rgb2lab, rgb2gray
9
+ from skimage import io
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+
13
+ class ColorizationNet(nn.Module):
14
+ def __init__(self, input_size=128):
15
+ super(ColorizationNet, self).__init__()
16
+
17
+ MIDLEVEL_FEATURE_SIZE = 128
18
+ resnet=models.resnet18(pretrained=True)
19
+ resnet.conv1.weight=nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))
20
+
21
+ self.midlevel_resnet =nn.Sequential(*list(resnet.children())[0:6])
22
+
23
+ self.upsample = nn.Sequential(
24
+ nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
25
+ nn.BatchNorm2d(128),
26
+ nn.ReLU(),
27
+ nn.Upsample(scale_factor=2),
28
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
29
+ nn.BatchNorm2d(64),
30
+ nn.ReLU(),
31
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
32
+ nn.BatchNorm2d(64),
33
+ nn.ReLU(),
34
+ nn.Upsample(scale_factor=2),
35
+ nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
36
+ nn.BatchNorm2d(32),
37
+ nn.ReLU(),
38
+ nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
39
+ nn.Upsample(scale_factor=2)
40
+ )
41
+
42
+ def forward(self, input):
43
+
44
+ # Pass input through ResNet-gray to extract features
45
+ midlevel_features = self.midlevel_resnet(input)
46
+
47
+ # Upsample to get colors
48
+ output = self.upsample(midlevel_features)
49
+ return output
50
+
51
+
52
+
53
+ def show_output(grayscale_input, ab_input):
54
+ '''Show/save rgb image from grayscale and ab channels
55
+ Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
56
+ color_image = torch.cat((grayscale_input, ab_input), 0).detach().numpy() # combine channels
57
+ color_image = color_image.transpose((1, 2, 0)) # rescale for matplotlib
58
+ color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
59
+ color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
60
+ color_image = lab2rgb(color_image.astype(np.float64))
61
+ grayscale_input = grayscale_input.squeeze().numpy()
62
+ # plt.imshow(grayscale_input)
63
+ # plt.imshow(color_image)
64
+ return color_image
65
+
66
+ model=torch.load("model-final.pth")
67
+
68
+ def colorize(img_path,print_img=True):
69
+ img=cv2.imread(img_path)
70
+ img=cv2.resize(img,(224,224))
71
+ grayscale_input= torch.Tensor(rgb2gray(img))
72
+ ab_input=model(grayscale_input.unsqueeze(0).unsqueeze(0)).squeeze(0)
73
+ predicted=show_output(grayscale_input.unsqueeze(0), ab_input)
74
+ if print_img:
75
+ plt.imshow(predicted)
76
+ return predicted
77
+
78
+ # out=colorize("download.png")
79
+ # print(out)
prediction.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/torch_stable.html
2
+ torch==1.7.1+cpu
3
+ torchvision==0.9.1+cpu
4
+ numpy==1.18.5
5
+ opencv-python-headless==4.4.0.46
6
+ matplotlib==3.4.2
7
+ scikit-image==0.18.1
8
+ streamlit==0.81.1