# Source: Unet/src/run/unet/inference.py
# Uploaded by dawn17 ("Upload 35 files", commit bcc0f94, 3.54 kB).
import os
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import yaml
from albumentations.pytorch import ToTensorV2
from easydict import EasyDict
from PIL import Image
from src.models.unet.resunet import UNet as Model
class ResUnetInfer:
    """Single-image inference wrapper around the ResUNet segmentation model.

    The wrapper loads a YAML config and a checkpoint, builds the
    resize/normalise/to-tensor preprocessing pipeline, and exposes
    :meth:`infer`, which returns the input image blended with the predicted
    mask drawn in the red channel.
    """

    def __init__(self, model_path, config_path):
        """Load the config, the model and the preprocessing pipeline.

        Args:
            model_path: Path to a checkpoint holding ``decoder_state_dict``
                and ``output_state_dict`` entries.
            config_path: Path to the YAML configuration file. The config is
                read for ``input_size``, ``mean``, ``std``,
                ``decoder_config`` and ``nclasses``.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # The config must be loaded first: load_model reads self.config.
        self.config = self.load_config(config_path=config_path)
        self.model = self.load_model(model_path=model_path)
        self.transform = A.Compose(
            [
                A.Resize(self.config.input_size[0], self.config.input_size[1]),
                A.Normalize(
                    mean=self.config.mean,
                    std=self.config.std,
                    max_pixel_value=255,
                ),
                ToTensorV2(),
            ]
        )

    def load_model(self, model_path):
        """Build the network on ``self.device`` and load checkpoint weights.

        Only the ``decoder`` and ``output`` sub-modules are restored from the
        checkpoint, both with ``strict=False``.

        Args:
            model_path: Path to the checkpoint file.

        Returns:
            The model, moved to ``self.device``. If ``model_path`` is not an
            existing file, a ``RuntimeWarning`` is emitted and the model is
            returned without any checkpoint weights applied.
        """
        model = Model(
            decoder_config=self.config.decoder_config,
            nclasses=self.config.nclasses,
        ).to(self.device)

        if not os.path.isfile(model_path):
            # Previously this case was silent, which made predictions from a
            # model without checkpoint weights indistinguishable from real ones.
            import warnings

            warnings.warn(
                f"Checkpoint not found at {model_path!r}; "
                "no checkpoint weights were loaded into the model.",
                RuntimeWarning,
                stacklevel=2,
            )
            return model

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        model.decoder.load_state_dict(
            checkpoint["decoder_state_dict"], strict=False
        )
        model.output.load_state_dict(
            checkpoint["output_state_dict"], strict=False
        )
        return model

    def load_config(self, config_path):
        """Read the YAML file at ``config_path`` and return it as an EasyDict.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            An ``EasyDict`` wrapping the parsed YAML content.
        """
        with open(config_path, "r", encoding="utf-8") as file:
            yaml_data = yaml.safe_load(file)
        return EasyDict(yaml_data)

    def infer(self, image, image_weight=0.01):
        """Predict a mask for ``image`` and return it as a red overlay.

        Args:
            image: Input picture as a NumPy array. The blend below expects an
                ``(H, W, 3)`` array such as the one returned by
                :meth:`load_image_as_array`.
            image_weight: Weight of the original image in the blend; the mask
                gets ``1 - image_weight``.

        Returns:
            A ``uint8`` array with the same height and width as ``image`` and
            three channels, rescaled so that its brightest value is 255. When
            the blended result is entirely zero, an all-zero array is
            returned instead of dividing by zero.
        """
        self.model.eval()
        height, width = image.shape[0], image.shape[1]
        input_tensor = self.transform(image=image)["image"].unsqueeze(0)

        # get mask
        with torch.no_grad():
            # The model is expected to emit a single-channel map of shape
            # [batch, 1, h, w] with batch = 1 (e.g. [1, 1, 224, 224]).
            output_tensor = self.model(input_tensor.to(self.device))
            mask = torch.sigmoid(output_tensor)
            # Same computation as the deprecated nn.UpsamplingBilinear2d.
            mask = nn.functional.interpolate(
                mask,
                size=(height, width),
                mode="bilinear",
                align_corners=True,
            )

        mask = mask.squeeze(0)

        # add zeros for green and blue channels
        # our mask will be red in colour
        zero_channels = torch.zeros((2, height, width), device=self.device)
        mask = torch.cat([mask, zero_channels], dim=0)
        mask = mask.permute(1, 2, 0).cpu().numpy()
        mask = np.uint8(255 * mask)

        # overlap image and mask
        overlay = (1 - image_weight) * mask + image_weight * image
        peak = np.max(overlay)
        if peak == 0:
            # Nothing to rescale: avoid 0/0 producing NaNs cast to uint8.
            return np.zeros(overlay.shape, dtype=np.uint8)
        overlay = overlay / peak
        return np.uint8(255 * overlay)

    @staticmethod
    def load_image_as_array(image_path):
        """Load the image at ``image_path`` as an RGB NumPy array.

        Args:
            image_path: Path to the image file.

        Returns:
            The image converted to RGB, as a NumPy array.
        """
        # The context manager closes the underlying file once the pixel data
        # has been copied into the array.
        with Image.open(image_path) as pil_image:
            return np.array(pil_image.convert("RGB"))

    @staticmethod
    def plot_array(array: np.ndarray, figsize=(10, 10)):
        """Display ``array`` with matplotlib in a figure of size ``figsize``.

        Args:
            array: The array to show with ``plt.imshow``.
            figsize: Figure size passed to ``plt.figure``.
        """
        plt.figure(figsize=figsize)
        plt.imshow(array)
        plt.show()

    @staticmethod
    def save_numpy_as_image(numpy_array, image_path):
        """Save a NumPy array as an image.

        Args:
            numpy_array (numpy.ndarray): The NumPy array to be saved as an image.
            image_path (str): The path where the image will be saved.
        """
        # Convert the NumPy array to a PIL image
        image = Image.fromarray(numpy_array)
        # Save the PIL image to the specified path
        image.save(image_path)