zhangshengdong commited on
Commit
17dff85
·
1 Parent(s): 61fb4bf

Upload 4 files

Browse files
Files changed (4) hide show
  1. UTILS/__init__.py +1 -0
  2. UTILS/more_dim.py +93 -0
  3. app.py +181 -0
  4. models/model_v48.pth +3 -0
UTILS/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .more_dim import get_more_dim
UTILS/more_dim.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+
3
+ import numpy as np
4
+ import cv2 as cv
5
+ from matplotlib import pyplot as plt
6
+
7
+
8
+ def get_binary_img_(img):
9
+ gray_img = img
10
+ if len(img.shape) > 2:
11
+ gray_img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
12
+ binary_img = cv.Canny(gray_img, 80, 150)
13
+ return binary_img
14
+
15
+
16
+ def get_morp_dilate_(binary_img):
17
+ kernel = cv.getStructuringElement(cv.MORPH_RECT, (3, 3))
18
+ # morp_dilate = cv.morphologyEx(binaryImg, cv.MORPH_DILATE, kernel=(1, 3), iterations=3)
19
+ # morp_dilate = cv.morphologyEx(morp_dilate, cv.MORPH_DILATE, kernel=(3, 1), iterations=3)
20
+ # morp_dilate = cv.morphologyEx(binaryImg, cv.MORPH_DILATE, kernel=(11, 11), iterations=3)
21
+ morp_dilate = cv.morphologyEx(binary_img, cv.MORPH_DILATE, kernel=kernel, iterations=3)
22
+ return morp_dilate
23
+
24
+
25
+ def get_water_img_(img, morp_dilate):
26
+ # 寻找图像轮廓 返回修改后的 图像的轮廓 以及它们的层次
27
+ # contours, hierarchy = cv.findContours(gray_img, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
28
+ # contours, hierarchy = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
29
+ # contours, hierarchy = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
30
+ contours, hierarchy = cv.findContours(morp_dilate, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
31
+ # 32位有符号整数类型,
32
+ marks = np.zeros(morp_dilate.shape[:2], np.int32)
33
+ # 绘制每一个轮廓
34
+ for index in range(len(contours)):
35
+ # 对marks进行标记,对不同区域的轮廓使用不同的亮度绘制,相当于设置注水点,有多少个轮廓,就有多少个轮廓
36
+ # 图像上不同线条的灰度值是不同的,底部略暗,越往上灰度越高
37
+ marks = cv.drawContours(marks, contours, index, (index, index, index), 1, 8, hierarchy)
38
+
39
+ # 使用分水岭算法
40
+ # 经过watershed函数的处理,不同区域间的值被置为-1(边界)没有标记清楚的区域被置为0,其他每个区域的值保持不变:1,2,...,contours.size()
41
+ marks_water = cv.watershed(img, marks)
42
+ return marks_water
43
+
44
+
45
+ def get_mask_img_(morp_dilate, file_dir):
46
+ contours, hierarchy = cv.findContours(morp_dilate, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
47
+ # 32位有符号整数类型,
48
+ marks = np.zeros(morp_dilate.shape[:2], np.int32)
49
+ for index in range(len(contours)):
50
+ dist = cv.pointPolygonTest(contours[index], (marks.shape[0] // 2, marks.shape[1] // 2), True)
51
+ if dist >= 0:
52
+ marks = cv.drawContours(marks, contours, contourIdx=index, color=1, thickness=1, lineType=8,
53
+ hierarchy=hierarchy)
54
+
55
+ edges = np.zeros((marks.shape[0] + 2, marks.shape[1] + 2), np.uint8) # 掩码,长短需要加2个像素
56
+ try:
57
+ cv.floodFill(marks, edges, (marks.shape[0] // 2, marks.shape[1] // 2), 1, cv.FLOODFILL_MASK_ONLY) # 漫水填充
58
+ except Exception as e:
59
+ if file_dir:
60
+ print(file_dir)
61
+ print(e)
62
+ print("=================")
63
+ print(traceback.format_exc())
64
+ # raise e
65
+ marks = np.ones(morp_dilate.shape[:2], np.int32)
66
+ return marks
67
+
68
+
69
+ def get_binary_img(binary_img, mask):
70
+ masked_binary_img = cv.bitwise_and(binary_img, binary_img, mask=mask.astype('uint8'))
71
+ return masked_binary_img
72
+
73
+
74
+ def get_water_img(img, morp_dilate, mask):
75
+ water_img = get_water_img_(img, morp_dilate)
76
+ masked_water = cv.bitwise_and(water_img, water_img, mask=mask.astype('uint8'))
77
+ return masked_water
78
+
79
+
80
+ def get_more_dim(img, file_dir, source_img=None):
81
+ if source_img is None:
82
+ source_img = img
83
+ # img: ndarray: 852, 847, 3
84
+ binary_img = get_binary_img_(img)
85
+ morp_dilate = get_morp_dilate_(binary_img)
86
+ mask = get_mask_img_(morp_dilate, file_dir)
87
+
88
+ masked_binary_img = get_binary_img(binary_img, mask)
89
+ masked_water = get_water_img(source_img, morp_dilate, mask)
90
+ # print(f"masked_binary_img shape:{masked_binary_img.shape} masked_water shape:{masked_water.shape}")
91
+ # print(f"type(masked_binary_img):{type(masked_binary_img)} type(masked_water):{type(masked_water)}")
92
+ # return np.stack((masked_binary_img, mask), axis=0)
93
+ return masked_binary_img, masked_water
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import traceback
3
+ from io import BytesIO
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+ from matplotlib import pyplot as plt
10
+ from matplotlib.colors import ListedColormap
11
+ from torchvision import transforms
12
+
13
+ from UTILS import get_more_dim
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ model_path = 'models/model_v48.pth'
18
+ model_pic_size = 512
19
+ model_class_num = 14
20
+ model = torch.load(model_path, map_location=torch.device('cpu'))
21
+ model = model.to(device)
22
+
23
+ colors = ['Black', 'Silver', 'White', 'Brown', 'LightCoral', 'Tomato', 'LightSalmon', 'Chocolate', 'Tan',
24
+ 'PapayaWhip', 'Gold', 'Ivory', 'GreenYellow', 'Green', 'DarkSeaGreen', 'DarkTurquoise', 'LightBLue',
25
+ 'SteelBlue']
26
+ mode = 'predict'
27
+
28
+
29
+ def get_predict(origin_img, need_subplot=False):
30
+ features, pad_width, pad_height = get_features(origin_img, pic_size=model_pic_size)
31
+ predict_npy, subplot_img = save_predict(model, features, device=device, class_num=model_class_num,
32
+ need_subplot=need_subplot)
33
+ return predict_npy, subplot_img, pad_width, pad_height
34
+
35
+
36
+ def save_predict(model, features, device, class_num=14, need_subplot=False):
37
+ cmap = ListedColormap(colors[:class_num])
38
+
39
+ model.eval()
40
+ with torch.no_grad():
41
+ features = features.to(device)
42
+
43
+ predictions = model(features)
44
+
45
+ features = torch.squeeze(features)
46
+ features = features.detach().cpu()
47
+ predictions = torch.squeeze(predictions)
48
+ predictions = predictions.detach().cpu()
49
+
50
+ features_len = features.shape[0]
51
+
52
+ origin_img = transforms.ToPILImage()(features[:3])
53
+ binary_img = features[3]
54
+ water_img = features[4]
55
+ predict_img = label_to_img(predictions)
56
+ predict_npy = predict_img.numpy().astype('uint8')
57
+
58
+ subplot = None
59
+ if need_subplot:
60
+ subplot = save_subplot(features_len, origin_img, predict_img, binary_img, water_img, vmax=class_num,
61
+ cmap=cmap)
62
+
63
+ return predict_npy, subplot
64
+
65
+
66
+ def label_to_img(label):
67
+ max_label_values, max_label_indices = torch.max(label, dim=0)
68
+ return max_label_indices
69
+
70
+
71
+ def save_subplot(features_len, origin_img, predict_img, feature_1=None, feature_2=None, vmax=14,
72
+ cmap=None):
73
+ plt.clf()
74
+ plt.close()
75
+
76
+ # colorbar 左 下 宽 高 ;设置colorbar位置;
77
+ rect = [0.92, 0.36, 0.015, 0.99 - 0.37 * 2]
78
+
79
+ fig = plt.figure()
80
+ subplot_num = features_len - 2 + 1
81
+ subplot_count = 0
82
+
83
+ subplot_count += 1
84
+ plt.subplot(1, subplot_num, subplot_count)
85
+ plt.imshow(origin_img)
86
+ if features_len > 3:
87
+ subplot_count += 1
88
+ plt.subplot(1, subplot_num, subplot_count)
89
+ plt.imshow(feature_1)
90
+ if features_len > 4:
91
+ subplot_count += 1
92
+ plt.subplot(1, subplot_num, subplot_count)
93
+ plt.imshow(feature_2)
94
+
95
+ subplot_count += 1
96
+ plt.subplot(1, subplot_num, subplot_count)
97
+ im = plt.imshow(predict_img, vmin=-1, vmax=vmax, cmap=cmap)
98
+ # 前面三个子图的总宽度 为 全部宽度的 0.9;剩下的0.1用来放置colorbar
99
+ fig.subplots_adjust(right=0.9)
100
+ cbar_ax = fig.add_axes(rect)
101
+ plt.colorbar(im, cax=cbar_ax)
102
+
103
+ with BytesIO() as out:
104
+ plt.savefig(out, dpi=300)
105
+ subplot_bytes = out.getvalue()
106
+ return subplot_bytes
107
+
108
+
109
+ def get_features(origin_img, pic_size):
110
+ img = origin_img.convert('RGB')
111
+ img_np = np.array(img)
112
+ try:
113
+ masked_binary_img, masked_water = get_more_dim(img_np, file_dir=None)
114
+ except Exception as e:
115
+ logging.error(e)
116
+ logging.error("=================")
117
+ logging.error(traceback.format_exc())
118
+ masked_binary_img = np.zeros(img_np.shape[:2], np.int32)
119
+ masked_water = np.zeros(img_np.shape[:2], np.int32)
120
+ img, pad_width, pad_height = transform_pic_shape(img, pic_size)
121
+ masked_binary_img, _, _ = transform_pic_shape(torch.tensor(masked_binary_img), pic_size)
122
+ masked_water, _, _ = transform_pic_shape(torch.tensor(masked_water), pic_size)
123
+ data_mode_dim = torch.stack((masked_binary_img, masked_water), axis=0)
124
+ img = transforms.ToTensor()(img)
125
+ featurs = torch.cat((img, data_mode_dim), dim=0)
126
+ featurs = torch.unsqueeze(featurs, dim=0)
127
+ return featurs, pad_width, pad_height
128
+
129
+
130
+ def transform_pic_shape(img, pic_size):
131
+ # 对于RGB图
132
+ # Image.size为(宽,高)
133
+ # array.shape为(高,宽,通道数)
134
+ # array.size为 高x宽x通道数 的总个数
135
+ height, width = get_image_shape(img)
136
+ if height > pic_size - 1 or width > pic_size - 1:
137
+ is_unsqueeze = False
138
+ if type(img) == torch.Tensor and len(img.shape) == 2:
139
+ img = torch.unsqueeze(img, dim=0)
140
+ is_unsqueeze = True
141
+ img = transforms.Resize(size=pic_size - 1, max_size=pic_size,
142
+ interpolation=transforms.InterpolationMode.NEAREST)(img)
143
+ if is_unsqueeze:
144
+ img = torch.squeeze(img)
145
+ height, width = get_image_shape(img)
146
+
147
+ pad_width = 0
148
+ pad_height = 0
149
+ if height < pic_size or width < pic_size:
150
+ # 当为 a 时,上下左右均填充 a 个像素
151
+ # 当为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素
152
+ # 当为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
153
+ # padding_mode: 填充模式,有 4 种模式,constant、edge、reflect、symmetric
154
+ pad_width = (pic_size - width) // 2
155
+ pad_height = (pic_size - height) // 2
156
+ img = transforms.Pad(
157
+ padding=[pad_width, pad_height, pic_size - pad_width - width, pic_size - pad_height - height],
158
+ fill=0)(img)
159
+ return img, pad_width, pad_height
160
+
161
+
162
+ def get_image_shape(img):
163
+ if type(img) == Image.Image:
164
+ width, height = img.size
165
+ else:
166
+ if len(img.shape) == 3:
167
+ channel_num, height, width = img.shape
168
+ else:
169
+ height, width = img.shape
170
+ return height, width
171
+
172
+
173
+ def greet(img):
174
+ predict_npy, subplot_img, pad_width, pad_height = get_predict(img, need_subplot=False)
175
+ predict_npy = predict_npy / model_class_num * 255
176
+ predict_img = Image.fromarray(predict_npy).convert(mode='L')
177
+ return predict_img
178
+
179
+
180
+ iface = gr.Interface(fn=greet, inputs=gr.Image(type="pil"), outputs="image")
181
+ iface.launch(server_name="0.0.0.0", share=True)
models/model_v48.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cd93eeccf176fcee883213eef18edfc2098d40733279617b43c070ae73227c9
3
+ size 183422311