bc180203823 commited on
Commit
fe407fb
·
verified ·
1 Parent(s): 0a6985b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """DDColor_colab.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_colab.ipynb
8
+ """
9
+
10
+ # Commented out IPython magic to ensure Python compatibility.
11
+ # %cd /content
12
+ !git clone -b dev https://github.com/camenduru/DDColor
13
+
14
+ !apt -y install -qq aria2
15
+ !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/cv_ddcolor_image-colorization/resolve/main/pytorch_model.pt -d /content/DDColor/models -o pytorch_model.pt
16
+
17
+ !wget https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg -O /content/DDColor/in.jpg
18
+ !pip install -q timm
19
+
20
+ # %cd /content/DDColor
21
+
22
+ !sed -i 's/from \.version import __gitsha__, __version__/# from \.version import __gitsha__, __version__/' /content/DDColor/basicsr/__init__.py
23
+
24
+ import argparse
25
+ import cv2
26
+ import numpy as np
27
+ import os
28
+ from tqdm import tqdm
29
+ import torch
30
+ from basicsr.archs.ddcolor_arch import DDColor
31
+ import torch.nn.functional as F
32
+
33
+ class ImageColorizationPipeline(object):
34
+
35
+ def __init__(self, model_path, input_size=256, model_size='large'):
36
+
37
+ self.input_size = input_size
38
+ if torch.cuda.is_available():
39
+ self.device = torch.device('cuda')
40
+ else:
41
+ self.device = torch.device('cpu')
42
+
43
+ if model_size == 'tiny':
44
+ self.encoder_name = 'convnext-t'
45
+ else:
46
+ self.encoder_name = 'convnext-l'
47
+
48
+ self.decoder_type = "MultiScaleColorDecoder"
49
+
50
+ if self.decoder_type == 'MultiScaleColorDecoder':
51
+ self.model = DDColor(
52
+ encoder_name=self.encoder_name,
53
+ decoder_name='MultiScaleColorDecoder',
54
+ input_size=[self.input_size, self.input_size],
55
+ num_output_channels=2,
56
+ last_norm='Spectral',
57
+ do_normalize=False,
58
+ num_queries=100,
59
+ num_scales=3,
60
+ dec_layers=9,
61
+ ).to(self.device)
62
+ else:
63
+ self.model = DDColor(
64
+ encoder_name=self.encoder_name,
65
+ decoder_name='SingleColorDecoder',
66
+ input_size=[self.input_size, self.input_size],
67
+ num_output_channels=2,
68
+ last_norm='Spectral',
69
+ do_normalize=False,
70
+ num_queries=256,
71
+ ).to(self.device)
72
+
73
+ self.model.load_state_dict(
74
+ torch.load(model_path, map_location=torch.device('cpu'))['params'],
75
+ strict=False)
76
+ self.model.eval()
77
+
78
+ @torch.no_grad()
79
+ def process(self, img):
80
+ self.height, self.width = img.shape[:2]
81
+ # print(self.width, self.height)
82
+ # if self.width * self.height < 100000:
83
+ # self.input_size = 256
84
+
85
+ img = (img / 255.0).astype(np.float32)
86
+ orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1] # (h, w, 1)
87
+
88
+ # resize rgb image -> lab -> get grey -> rgb
89
+ img = cv2.resize(img, (self.input_size, self.input_size))
90
+ img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]
91
+ img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)
92
+ img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)
93
+
94
+ tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)
95
+ output_ab = self.model(tensor_gray_rgb).cpu() # (1, 2, self.height, self.width)
96
+
97
+ # resize ab -> concat original l -> rgb
98
+ output_ab_resize = F.interpolate(output_ab, size=(self.height, self.width))[0].float().numpy().transpose(1, 2, 0)
99
+ output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)
100
+ output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)
101
+
102
+ output_img = (output_bgr * 255.0).round().astype(np.uint8)
103
+
104
+ return output_img
105
+
106
+ colorizer = ImageColorizationPipeline(model_path='/content/DDColor/models/pytorch_model.pt', input_size=512)
107
+
108
+ # helper function taken from: https://huggingface.co/blog/stable_diffusion
109
+ from PIL import Image
110
+ def image_grid(imgs, rows, cols):
111
+ assert len(imgs) == rows*cols
112
+
113
+ w, h = imgs[0].size
114
+ grid = Image.new('RGB', size=(cols*w, rows*h))
115
+ grid_w, grid_h = grid.size
116
+
117
+ for i, img in enumerate(imgs):
118
+ grid.paste(img, box=(i%cols*w, i//cols*h))
119
+ return grid
120
+
121
+ image_in = cv2.imread('/content/DDColor/in.jpg')
122
+ image_out = colorizer.process(image_in)
123
+ cv2.imwrite('/content/DDColor/out.jpg', image_out)
124
+ image_in_pil = Image.fromarray(cv2.cvtColor(image_in, cv2.COLOR_BGR2RGB))
125
+ image_out_pil = Image.fromarray(cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB))
126
+ images = [image_in_pil, image_out_pil]
127
+ grid = image_grid(images, rows=1, cols=2)
128
+ grid