# Hugging Face Spaces app (Space status shown when this file was captured: Sleeping).
import logging
import traceback
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from torchvision import transforms

from UTILS import get_more_dim
# Inference device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration shared by the preprocessing and prediction helpers below.
model_path = 'models/model_v48.pth'
model_pic_size = 512  # images are resized/padded to a square of this side length
model_class_num = 14  # number of classes (sizes the colormap and the grayscale scaling)

# NOTE(review): the loaded object is used directly as a module (.to / .eval / call),
# so this file presumably holds a whole pickled model rather than a state_dict.
# Unpickling can execute arbitrary code - only load trusted files - and on PyTorch
# versions where weights_only defaults to True this call likely needs
# weights_only=False; confirm against the pinned torch version.
# Mapping to CPU first lets a checkpoint saved on a GPU load on CPU-only hosts;
# the model is then moved to the selected device.
model = torch.load(model_path, map_location=torch.device('cpu')).to(device)

# Per-class colours for the optional comparison plot; only the first
# model_class_num entries are used. Matplotlib resolves colour names
# case-insensitively, so 'LightBLue' still resolves.
colors = [
    'Black', 'Silver', 'White', 'Brown', 'LightCoral', 'Tomato', 'LightSalmon',
    'Chocolate', 'Tan', 'PapayaWhip', 'Gold', 'Ivory', 'GreenYellow', 'Green',
    'DarkSeaGreen', 'DarkTurquoise', 'LightBLue', 'SteelBlue',
]

# NOTE(review): not read anywhere in the visible code.
mode = 'predict'
def get_predict(origin_img, need_subplot=False):
    """Run the module-level segmentation model on a single PIL image.

    Args:
        origin_img: the input PIL image.
        need_subplot: when True, also render the input/prediction comparison figure.

    Returns:
        ``(predict_npy, subplot_img, pad_width, pad_height)``: the uint8 class-index
        map, the rendered figure bytes (or None), and the left / top padding that
        preprocessing added to the image.
    """
    model_input, left_pad, top_pad = get_features(origin_img, pic_size=model_pic_size)
    class_map, figure_bytes = save_predict(
        model,
        model_input,
        device=device,
        class_num=model_class_num,
        need_subplot=need_subplot,
    )
    return class_map, figure_bytes, left_pad, top_pad
def save_predict(model, features, device, class_num=14, need_subplot=False):
    """Run ``model`` on ``features`` and return the per-pixel class map.

    Args:
        model: torch module applied to the batched ``features``. It presumably
            returns scores shaped ``(1, class_num, H, W)`` - confirm against the model.
        features: a batch of one sample, as built by ``get_features``. The channels
            are RGB (3, from ``ToTensor``), then optionally the binary mask and the
            water mask.
        device: device to run inference on.
        class_num: number of classes; sizes the colormap and the colour range of
            the optional comparison figure.
        need_subplot: when True, also render the comparison figure via ``save_subplot``.

    Returns:
        ``(predict_npy, subplot)``. ``predict_npy`` is a uint8 numpy array holding,
        per pixel, the argmax over the class axis. ``subplot`` is the encoded figure
        from ``save_subplot``, or None when ``need_subplot`` is False.
    """
    model.eval()
    with torch.no_grad():
        predictions = model(features.to(device))
    # Drop only the batch axis so that the class axis stays at dim 0 for label_to_img.
    # A bare squeeze() would also drop any other axis of size 1.
    predictions = torch.squeeze(predictions, 0).detach().cpu()
    predict_img = label_to_img(predictions)
    predict_npy = predict_img.numpy().astype('uint8')

    if not need_subplot:
        # Everything below only feeds the visualisation; skip it on the common path.
        return predict_npy, None

    features = torch.squeeze(features, 0).detach().cpu()
    features_len = features.shape[0]
    origin_img = transforms.ToPILImage()(features[:3])
    # The mask channels are optional, matching save_subplot, which draws them only
    # when features_len > 3 / > 4.
    binary_img = features[3] if features_len > 3 else None
    water_img = features[4] if features_len > 4 else None
    cmap = ListedColormap(colors[:class_num])
    subplot = save_subplot(features_len, origin_img, predict_img, binary_img, water_img, vmax=class_num,
                           cmap=cmap)
    return predict_npy, subplot
def label_to_img(label):
    """Collapse per-class scores into a map of winning class indices.

    Args:
        label: score tensor whose dim 0 is the class axis.

    Returns:
        An int64 tensor with dim 0 removed. Each entry is the index of the
        largest score along dim 0, with ties resolved as ``torch.max`` does.
    """
    return torch.max(label, dim=0).indices
def save_subplot(features_len, origin_img, predict_img, feature_1=None, feature_2=None, vmax=14,
                 cmap=None):
    """Draw the input panels and the prediction side by side; return the encoded figure.

    Args:
        features_len: number of input channels. ``feature_1`` is drawn when it is
            > 3 and ``feature_2`` when it is > 4.
        origin_img: the RGB input image.
        predict_img: 2-D map of integer class indices.
        feature_1, feature_2: extra single-channel panels (``save_predict`` passes
            the binary mask and the water mask).
        vmax: number of classes; upper bound of the colour range.
        cmap: colormap for the prediction panel.

    Returns:
        The figure encoded by ``plt.savefig`` at 300 dpi. The format follows
        ``rcParams['savefig.format']``, which is PNG by default.
    """
    # Close whatever figure pyplot currently considers "current" (e.g. a leftover one).
    plt.clf()
    plt.close()
    # Colorbar axes rectangle in figure coordinates: [left, bottom, width, height].
    rect = [0.92, 0.36, 0.015, 0.99 - 0.37 * 2]
    fig = plt.figure()
    try:
        # Input panels: the RGB image plus up to two extra channels.
        panels = [origin_img]
        if features_len > 3:
            panels.append(feature_1)
        if features_len > 4:
            panels.append(feature_2)
        # One slot per drawn panel, plus one for the prediction.
        subplot_num = len(panels) + 1
        for position, panel in enumerate(panels, start=1):
            plt.subplot(1, subplot_num, position)
            plt.imshow(panel)
        plt.subplot(1, subplot_num, subplot_num)
        # With vmin=-1, vmax=N and a ListedColormap of exactly N colours (as
        # save_predict builds it), each integer label 0..N-1 lands in its own colour bin.
        im = plt.imshow(predict_img, vmin=-1, vmax=vmax, cmap=cmap)
        # The subplots take the left 90% of the width; the rest holds the colorbar.
        fig.subplots_adjust(right=0.9)
        cbar_ax = fig.add_axes(rect)
        plt.colorbar(im, cax=cbar_ax)
        with BytesIO() as out:
            plt.savefig(out, dpi=300)
            return out.getvalue()
    finally:
        # Release the figure even if plotting fails, so pyplot does not keep it alive.
        plt.close(fig)
def get_features(origin_img, pic_size):
    """Assemble the batched model input for a single PIL image.

    The image is converted to RGB, and ``get_more_dim`` derives two auxiliary masks
    from it. ``transform_pic_shape`` then brings the image and both masks to a
    ``pic_size`` square. If ``get_more_dim`` raises, the error and its traceback are
    logged and both masks are replaced by int32 zeros, so inference still runs with
    blank mask channels.

    Args:
        origin_img: input PIL image.
        pic_size: side length of the square model input.

    Returns:
        ``(features, pad_width, pad_height)``.
        - ``features`` concatenates the RGB channels (``ToTensor``) with the binary
          mask and the water mask, under a leading batch axis of size 1.
        - ``pad_width`` / ``pad_height`` are the left / top padding applied to the
          RGB image.
    """
    rgb = origin_img.convert('RGB')
    rgb_array = np.array(rgb)
    try:
        binary_mask, water_mask = get_more_dim(rgb_array, file_dir=None)
    except Exception as err:
        # Best effort: log the failure and fall back to empty masks rather than
        # failing the whole request.
        logging.error(err)
        logging.error("=================")
        logging.error(traceback.format_exc())
        mask_shape = rgb_array.shape[:2]
        binary_mask = np.zeros(mask_shape, np.int32)
        water_mask = np.zeros(mask_shape, np.int32)

    # NOTE(review): presumably get_more_dim returns 2-D (H, W) masks matching the
    # image; otherwise the channel stacking below will not line up - confirm.
    rgb, pad_width, pad_height = transform_pic_shape(rgb, pic_size)
    mask_channels = [
        transform_pic_shape(torch.tensor(mask), pic_size)[0]
        for mask in (binary_mask, water_mask)
    ]
    rgb_tensor = transforms.ToTensor()(rgb)
    features = torch.cat((rgb_tensor, torch.stack(mask_channels, dim=0)), dim=0)
    return torch.unsqueeze(features, dim=0), pad_width, pad_height
def transform_pic_shape(img, pic_size):
    """Fit an image or mask into a ``pic_size`` x ``pic_size`` square.

    Resizing happens only when either side is >= ``pic_size``; smaller inputs are
    never upscaled. In that case the input is downscaled with nearest-neighbour
    interpolation, keeping the aspect ratio. Per torchvision ``Resize(size,
    max_size)``, the shorter side becomes ``pic_size - 1`` unless the longer side
    would then exceed ``pic_size``, in which case the longer side is capped at
    ``pic_size``. The result is then zero-padded, roughly centred, to exactly
    ``pic_size`` x ``pic_size``.

    Args:
        img: a PIL image, or a tensor shaped ``(H, W)`` or ``(C, H, W)``.
        pic_size: target side length in pixels.

    Returns:
        ``(img, pad_width, pad_height)``. ``img`` has the same kind as the input
        (PIL image or tensor). ``pad_width`` / ``pad_height`` are the number of
        pixels added on the left / top edge.
    """
    # PIL reports size as (width, height) while tensors are (..., height, width);
    # get_image_shape hides that difference.
    height, width = get_image_shape(img)
    if height >= pic_size or width >= pic_size:
        # Resize works on [..., H, W] tensors, so give a bare 2-D mask a temporary
        # leading axis.
        added_axis = isinstance(img, torch.Tensor) and img.dim() == 2
        if added_axis:
            img = torch.unsqueeze(img, dim=0)
        # Resize requires max_size > size, hence pic_size - 1 for the shorter side.
        # NEAREST keeps mask values discrete (it is applied to RGB images as well).
        img = transforms.Resize(size=pic_size - 1, max_size=pic_size,
                                interpolation=transforms.InterpolationMode.NEAREST)(img)
        if added_axis:
            # Remove only the axis added above. A bare squeeze() would also drop a
            # spatial axis that the resize shrank to 1 pixel, and the shape
            # unpacking below would then fail.
            img = torch.squeeze(img, dim=0)
        height, width = get_image_shape(img)
    pad_width = 0
    pad_height = 0
    if height < pic_size or width < pic_size:
        pad_width = (pic_size - width) // 2
        pad_height = (pic_size - height) // 2
        # Padding order is [left, top, right, bottom]; an odd remainder goes to the
        # right / bottom edge. fill=0 pads with zeros.
        img = transforms.Pad(
            padding=[pad_width, pad_height, pic_size - pad_width - width, pic_size - pad_height - height],
            fill=0,
        )(img)
    return img, pad_width, pad_height
def get_image_shape(img):
    """Return ``(height, width)`` for a PIL image or a channel-first tensor/array.

    PIL reports its size as ``(width, height)``, whereas tensors/arrays store the
    spatial axes last, as ``(..., height, width)``. This helper normalizes both.
    NOTE(review): a channel-last numpy array ``(H, W, C)`` would be misread here;
    callers in this file only pass PIL images and channel-first tensors.

    Raises:
        ValueError: if ``img`` has fewer than two dimensions.
    """
    # isinstance, not an exact type check, so that PIL subclasses such as
    # PngImageFile / JpegImageFile returned by Image.open() are also recognised.
    if isinstance(img, Image.Image):
        width, height = img.size
        return height, width
    # The last two axes are spatial for (H, W), (C, H, W) and (..., H, W) inputs.
    *_, height, width = img.shape
    return height, width
def greet(img):
    """Gradio callback: segment ``img`` and return the class map as a grayscale image.

    Class index ``k`` is rendered as gray level ``k / model_class_num * 255``, so
    class 0 is black. The padding added during preprocessing is not cropped
    (``pad_width`` / ``pad_height`` are ignored). The output size therefore follows
    the model's output resolution, presumably a ``model_pic_size`` square - confirm.

    Args:
        img: the uploaded PIL image, or None when the input component is empty.

    Returns:
        A mode ``"L"`` PIL image, or None when no image was provided.
    """
    # Gradio passes None when the user submits without an image. Without this guard,
    # get_features would fail with AttributeError on None.convert(...).
    if img is None:
        return None
    predict_npy, _subplot_img, _pad_width, _pad_height = get_predict(img, need_subplot=False)
    # Spread class indices over the 0..255 range so that classes are distinguishable.
    # This yields a float array, which is converted to 8-bit grayscale.
    gray_levels = predict_npy / model_class_num * 255
    return Image.fromarray(gray_levels).convert(mode='L')
# Gradio UI: upload one image (delivered to greet as a PIL image), get the grayscale class map back.
iface = gr.Interface(fn=greet, inputs=gr.Image(type="pil"), outputs="image")

if __name__ == "__main__":
    # Launch only when run as a script (e.g. `python app.py`), not as an import side effect.
    # server_name="0.0.0.0" listens on all interfaces so the app is reachable from
    # outside the container; also pass share=True for a temporary public gradio.live link.
    iface.launch(server_name="0.0.0.0")