Spaces:
Sleeping
Sleeping
Commit
·
17dff85
1
Parent(s):
61fb4bf
Upload 4 files
Browse files- UTILS/__init__.py +1 -0
- UTILS/more_dim.py +93 -0
- app.py +181 -0
- models/model_v48.pth +3 -0
UTILS/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .more_dim import get_more_dim
|
UTILS/more_dim.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import cv2 as cv
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
|
7 |
+
|
8 |
+
def get_binary_img_(img):
|
9 |
+
gray_img = img
|
10 |
+
if len(img.shape) > 2:
|
11 |
+
gray_img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
|
12 |
+
binary_img = cv.Canny(gray_img, 80, 150)
|
13 |
+
return binary_img
|
14 |
+
|
15 |
+
|
16 |
+
def get_morp_dilate_(binary_img):
|
17 |
+
kernel = cv.getStructuringElement(cv.MORPH_RECT, (3, 3))
|
18 |
+
# morp_dilate = cv.morphologyEx(binaryImg, cv.MORPH_DILATE, kernel=(1, 3), iterations=3)
|
19 |
+
# morp_dilate = cv.morphologyEx(morp_dilate, cv.MORPH_DILATE, kernel=(3, 1), iterations=3)
|
20 |
+
# morp_dilate = cv.morphologyEx(binaryImg, cv.MORPH_DILATE, kernel=(11, 11), iterations=3)
|
21 |
+
morp_dilate = cv.morphologyEx(binary_img, cv.MORPH_DILATE, kernel=kernel, iterations=3)
|
22 |
+
return morp_dilate
|
23 |
+
|
24 |
+
|
25 |
+
def get_water_img_(img, morp_dilate):
|
26 |
+
# 寻找图像轮廓 返回修改后的 图像的轮廓 以及它们的层次
|
27 |
+
# contours, hierarchy = cv.findContours(gray_img, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
|
28 |
+
# contours, hierarchy = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
|
29 |
+
# contours, hierarchy = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
|
30 |
+
contours, hierarchy = cv.findContours(morp_dilate, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
|
31 |
+
# 32位有符号整数类型,
|
32 |
+
marks = np.zeros(morp_dilate.shape[:2], np.int32)
|
33 |
+
# 绘制每一个轮廓
|
34 |
+
for index in range(len(contours)):
|
35 |
+
# 对marks进行标记,对不同区域的轮廓使用不同的亮度绘制,相当于设置注水点,有多少个轮廓,就有多少个轮廓
|
36 |
+
# 图像上不同线条的灰度值是不同的,底部略暗,越往上灰度越高
|
37 |
+
marks = cv.drawContours(marks, contours, index, (index, index, index), 1, 8, hierarchy)
|
38 |
+
|
39 |
+
# 使用分水岭算法
|
40 |
+
# 经过watershed函数的处理,不同区域间的值被置为-1(边界)没有标记清楚的区域被置为0,其他每个区域的值保持不变:1,2,...,contours.size()
|
41 |
+
marks_water = cv.watershed(img, marks)
|
42 |
+
return marks_water
|
43 |
+
|
44 |
+
|
45 |
+
def get_mask_img_(morp_dilate, file_dir):
|
46 |
+
contours, hierarchy = cv.findContours(morp_dilate, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
|
47 |
+
# 32位有符号整数类型,
|
48 |
+
marks = np.zeros(morp_dilate.shape[:2], np.int32)
|
49 |
+
for index in range(len(contours)):
|
50 |
+
dist = cv.pointPolygonTest(contours[index], (marks.shape[0] // 2, marks.shape[1] // 2), True)
|
51 |
+
if dist >= 0:
|
52 |
+
marks = cv.drawContours(marks, contours, contourIdx=index, color=1, thickness=1, lineType=8,
|
53 |
+
hierarchy=hierarchy)
|
54 |
+
|
55 |
+
edges = np.zeros((marks.shape[0] + 2, marks.shape[1] + 2), np.uint8) # 掩码,长短需要加2个像素
|
56 |
+
try:
|
57 |
+
cv.floodFill(marks, edges, (marks.shape[0] // 2, marks.shape[1] // 2), 1, cv.FLOODFILL_MASK_ONLY) # 漫水填充
|
58 |
+
except Exception as e:
|
59 |
+
if file_dir:
|
60 |
+
print(file_dir)
|
61 |
+
print(e)
|
62 |
+
print("=================")
|
63 |
+
print(traceback.format_exc())
|
64 |
+
# raise e
|
65 |
+
marks = np.ones(morp_dilate.shape[:2], np.int32)
|
66 |
+
return marks
|
67 |
+
|
68 |
+
|
69 |
+
def get_binary_img(binary_img, mask):
|
70 |
+
masked_binary_img = cv.bitwise_and(binary_img, binary_img, mask=mask.astype('uint8'))
|
71 |
+
return masked_binary_img
|
72 |
+
|
73 |
+
|
74 |
+
def get_water_img(img, morp_dilate, mask):
|
75 |
+
water_img = get_water_img_(img, morp_dilate)
|
76 |
+
masked_water = cv.bitwise_and(water_img, water_img, mask=mask.astype('uint8'))
|
77 |
+
return masked_water
|
78 |
+
|
79 |
+
|
80 |
+
def get_more_dim(img, file_dir, source_img=None):
|
81 |
+
if source_img is None:
|
82 |
+
source_img = img
|
83 |
+
# img: ndarray: 852, 847, 3
|
84 |
+
binary_img = get_binary_img_(img)
|
85 |
+
morp_dilate = get_morp_dilate_(binary_img)
|
86 |
+
mask = get_mask_img_(morp_dilate, file_dir)
|
87 |
+
|
88 |
+
masked_binary_img = get_binary_img(binary_img, mask)
|
89 |
+
masked_water = get_water_img(source_img, morp_dilate, mask)
|
90 |
+
# print(f"masked_binary_img shape:{masked_binary_img.shape} masked_water shape:{masked_water.shape}")
|
91 |
+
# print(f"type(masked_binary_img):{type(masked_binary_img)} type(masked_water):{type(masked_water)}")
|
92 |
+
# return np.stack((masked_binary_img, mask), axis=0)
|
93 |
+
return masked_binary_img, masked_water
|
app.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import traceback
|
3 |
+
from io import BytesIO
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
from matplotlib.colors import ListedColormap
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
from UTILS import get_more_dim
|
14 |
+
|
15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
|
17 |
+
model_path = 'models/model_v48.pth'
|
18 |
+
model_pic_size = 512
|
19 |
+
model_class_num = 14
|
20 |
+
model = torch.load(model_path, map_location=torch.device('cpu'))
|
21 |
+
model = model.to(device)
|
22 |
+
|
23 |
+
colors = ['Black', 'Silver', 'White', 'Brown', 'LightCoral', 'Tomato', 'LightSalmon', 'Chocolate', 'Tan',
|
24 |
+
'PapayaWhip', 'Gold', 'Ivory', 'GreenYellow', 'Green', 'DarkSeaGreen', 'DarkTurquoise', 'LightBLue',
|
25 |
+
'SteelBlue']
|
26 |
+
mode = 'predict'
|
27 |
+
|
28 |
+
|
29 |
+
def get_predict(origin_img, need_subplot=False):
|
30 |
+
features, pad_width, pad_height = get_features(origin_img, pic_size=model_pic_size)
|
31 |
+
predict_npy, subplot_img = save_predict(model, features, device=device, class_num=model_class_num,
|
32 |
+
need_subplot=need_subplot)
|
33 |
+
return predict_npy, subplot_img, pad_width, pad_height
|
34 |
+
|
35 |
+
|
36 |
+
def save_predict(model, features, device, class_num=14, need_subplot=False):
|
37 |
+
cmap = ListedColormap(colors[:class_num])
|
38 |
+
|
39 |
+
model.eval()
|
40 |
+
with torch.no_grad():
|
41 |
+
features = features.to(device)
|
42 |
+
|
43 |
+
predictions = model(features)
|
44 |
+
|
45 |
+
features = torch.squeeze(features)
|
46 |
+
features = features.detach().cpu()
|
47 |
+
predictions = torch.squeeze(predictions)
|
48 |
+
predictions = predictions.detach().cpu()
|
49 |
+
|
50 |
+
features_len = features.shape[0]
|
51 |
+
|
52 |
+
origin_img = transforms.ToPILImage()(features[:3])
|
53 |
+
binary_img = features[3]
|
54 |
+
water_img = features[4]
|
55 |
+
predict_img = label_to_img(predictions)
|
56 |
+
predict_npy = predict_img.numpy().astype('uint8')
|
57 |
+
|
58 |
+
subplot = None
|
59 |
+
if need_subplot:
|
60 |
+
subplot = save_subplot(features_len, origin_img, predict_img, binary_img, water_img, vmax=class_num,
|
61 |
+
cmap=cmap)
|
62 |
+
|
63 |
+
return predict_npy, subplot
|
64 |
+
|
65 |
+
|
66 |
+
def label_to_img(label):
|
67 |
+
max_label_values, max_label_indices = torch.max(label, dim=0)
|
68 |
+
return max_label_indices
|
69 |
+
|
70 |
+
|
71 |
+
def save_subplot(features_len, origin_img, predict_img, feature_1=None, feature_2=None, vmax=14,
|
72 |
+
cmap=None):
|
73 |
+
plt.clf()
|
74 |
+
plt.close()
|
75 |
+
|
76 |
+
# colorbar 左 下 宽 高 ;设置colorbar位置;
|
77 |
+
rect = [0.92, 0.36, 0.015, 0.99 - 0.37 * 2]
|
78 |
+
|
79 |
+
fig = plt.figure()
|
80 |
+
subplot_num = features_len - 2 + 1
|
81 |
+
subplot_count = 0
|
82 |
+
|
83 |
+
subplot_count += 1
|
84 |
+
plt.subplot(1, subplot_num, subplot_count)
|
85 |
+
plt.imshow(origin_img)
|
86 |
+
if features_len > 3:
|
87 |
+
subplot_count += 1
|
88 |
+
plt.subplot(1, subplot_num, subplot_count)
|
89 |
+
plt.imshow(feature_1)
|
90 |
+
if features_len > 4:
|
91 |
+
subplot_count += 1
|
92 |
+
plt.subplot(1, subplot_num, subplot_count)
|
93 |
+
plt.imshow(feature_2)
|
94 |
+
|
95 |
+
subplot_count += 1
|
96 |
+
plt.subplot(1, subplot_num, subplot_count)
|
97 |
+
im = plt.imshow(predict_img, vmin=-1, vmax=vmax, cmap=cmap)
|
98 |
+
# 前面三个子图的总宽度 为 全部宽度的 0.9;剩下的0.1用来放置colorbar
|
99 |
+
fig.subplots_adjust(right=0.9)
|
100 |
+
cbar_ax = fig.add_axes(rect)
|
101 |
+
plt.colorbar(im, cax=cbar_ax)
|
102 |
+
|
103 |
+
with BytesIO() as out:
|
104 |
+
plt.savefig(out, dpi=300)
|
105 |
+
subplot_bytes = out.getvalue()
|
106 |
+
return subplot_bytes
|
107 |
+
|
108 |
+
|
109 |
+
def get_features(origin_img, pic_size):
|
110 |
+
img = origin_img.convert('RGB')
|
111 |
+
img_np = np.array(img)
|
112 |
+
try:
|
113 |
+
masked_binary_img, masked_water = get_more_dim(img_np, file_dir=None)
|
114 |
+
except Exception as e:
|
115 |
+
logging.error(e)
|
116 |
+
logging.error("=================")
|
117 |
+
logging.error(traceback.format_exc())
|
118 |
+
masked_binary_img = np.zeros(img_np.shape[:2], np.int32)
|
119 |
+
masked_water = np.zeros(img_np.shape[:2], np.int32)
|
120 |
+
img, pad_width, pad_height = transform_pic_shape(img, pic_size)
|
121 |
+
masked_binary_img, _, _ = transform_pic_shape(torch.tensor(masked_binary_img), pic_size)
|
122 |
+
masked_water, _, _ = transform_pic_shape(torch.tensor(masked_water), pic_size)
|
123 |
+
data_mode_dim = torch.stack((masked_binary_img, masked_water), axis=0)
|
124 |
+
img = transforms.ToTensor()(img)
|
125 |
+
featurs = torch.cat((img, data_mode_dim), dim=0)
|
126 |
+
featurs = torch.unsqueeze(featurs, dim=0)
|
127 |
+
return featurs, pad_width, pad_height
|
128 |
+
|
129 |
+
|
130 |
+
def transform_pic_shape(img, pic_size):
|
131 |
+
# 对于RGB图
|
132 |
+
# Image.size为(宽,高)
|
133 |
+
# array.shape为(高,宽,通道数)
|
134 |
+
# array.size为 高x宽x通道数 的总个数
|
135 |
+
height, width = get_image_shape(img)
|
136 |
+
if height > pic_size - 1 or width > pic_size - 1:
|
137 |
+
is_unsqueeze = False
|
138 |
+
if type(img) == torch.Tensor and len(img.shape) == 2:
|
139 |
+
img = torch.unsqueeze(img, dim=0)
|
140 |
+
is_unsqueeze = True
|
141 |
+
img = transforms.Resize(size=pic_size - 1, max_size=pic_size,
|
142 |
+
interpolation=transforms.InterpolationMode.NEAREST)(img)
|
143 |
+
if is_unsqueeze:
|
144 |
+
img = torch.squeeze(img)
|
145 |
+
height, width = get_image_shape(img)
|
146 |
+
|
147 |
+
pad_width = 0
|
148 |
+
pad_height = 0
|
149 |
+
if height < pic_size or width < pic_size:
|
150 |
+
# 当为 a 时,上下左右均填充 a 个像素
|
151 |
+
# 当为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素
|
152 |
+
# 当为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
|
153 |
+
# padding_mode: 填充模式,有 4 种模式,constant、edge、reflect、symmetric
|
154 |
+
pad_width = (pic_size - width) // 2
|
155 |
+
pad_height = (pic_size - height) // 2
|
156 |
+
img = transforms.Pad(
|
157 |
+
padding=[pad_width, pad_height, pic_size - pad_width - width, pic_size - pad_height - height],
|
158 |
+
fill=0)(img)
|
159 |
+
return img, pad_width, pad_height
|
160 |
+
|
161 |
+
|
162 |
+
def get_image_shape(img):
|
163 |
+
if type(img) == Image.Image:
|
164 |
+
width, height = img.size
|
165 |
+
else:
|
166 |
+
if len(img.shape) == 3:
|
167 |
+
channel_num, height, width = img.shape
|
168 |
+
else:
|
169 |
+
height, width = img.shape
|
170 |
+
return height, width
|
171 |
+
|
172 |
+
|
173 |
+
def greet(img):
|
174 |
+
predict_npy, subplot_img, pad_width, pad_height = get_predict(img, need_subplot=False)
|
175 |
+
predict_npy = predict_npy / model_class_num * 255
|
176 |
+
predict_img = Image.fromarray(predict_npy).convert(mode='L')
|
177 |
+
return predict_img
|
178 |
+
|
179 |
+
|
180 |
+
iface = gr.Interface(fn=greet, inputs=gr.Image(type="pil"), outputs="image")
|
181 |
+
iface.launch(server_name="0.0.0.0", share=True)
|
models/model_v48.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1cd93eeccf176fcee883213eef18edfc2098d40733279617b43c070ae73227c9
|
3 |
+
size 183422311
|