# phenotyping_pipeline/bgremover.py
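"""Background removal and phenotyping helpers.

This module bundles the components used by the phenotyping pipeline:
- BackgroundRemover: U2NET-based salient-object segmentation for background removal.
- DamageClassifier: torchvision/ONNX classifiers that score spider-mite damage in four bins.
- ColorCheckerDetector: locates a color-checker mask and rectifies it with a perspective warp.
- BatchProcessor: walks an input directory tree and applies a processing function per image.
- Segmentor: optional GroundingDINO + SAM text-prompted segmentation (extra dependencies required).
"""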
import gc
import glob
import os
import pathlib

import cv2 as cv
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import models
# torchvision.transforms.v2 is used everywhere below; the legacy
# `from torchvision import transforms` import was removed because this
# import shadowed it anyway.
from torchvision.transforms import v2 as transforms

import onnxruntime as ort

from u2net_utils.data_loader import RescaleT
from u2net_utils.data_loader import ToTensor
from u2net_utils.data_loader import ToTensorLab
from u2net_utils.data_loader import SalObjDataset
from u2net_utils.model import U2NET   # full size version 173.6 MB
from u2net_utils.model import U2NETP  # small version u2net 4.7 MB

# Older model locations kept for reference:
# MODEL_PATH = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_gpu\models\u2net.pth"
# MODEL_PATH = r"D:\CIAT\catalogue\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models"
# MODEL_PATH = r"D:\local_mydata\models\spidermites\best_models"
MODEL_PATH = "./models"
#************************
# Optional dependencies for the Segmentor class below. They are commented out
# so the core pipeline can run without them; uncomment them (and install the
# packages) before instantiating Segmentor.
# from loguru import logger
# from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
# import subprocess
# # Grounding DINO
# import GroundingDINO.groundingdino.datasets.transforms as T
# from GroundingDINO.groundingdino.models import build_model
# from GroundingDINO.groundingdino.util import box_ops
# from GroundingDINO.groundingdino.util.slconfig import SLConfig
# from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# from huggingface_hub import hf_hub_download


def clear():
    """Free cached GPU memory between runs."""
    gc.collect()
    torch.cuda.empty_cache()
# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn
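# normPRED is a plain min-max rescale: a map whose values span e.g. [0.1, 0.9]
# is stretched so that 0.1 -> 0.0 and 0.9 -> 1.0. The fixed 0-255 thresholds
# used further down (e.g. 80) assume the mask occupies this full range.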
class BackgroundRemover():
    """Removes image backgrounds using a pretrained U2NET saliency model."""

    def __init__(self):
        # Load model
        #model_dir = "/workspace/u2net.pth"
        #model_dir = "D:/local_mydata/models/u2net.pth"
        #model_dir = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_gpu\models\u2net.pth"
        model_dir = os.path.join(MODEL_PATH, "u2net.pth")
        net = U2NET(3, 1)
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(model_dir))
            net.cuda()
        else:
            net.load_state_dict(torch.load(model_dir, map_location='cpu'))
        net.eval()
        self.net = net
    def remove_background(self, filepath_image):
        """Run U2NET on one image file and return a float saliency map in
        [0, 1], resized back to the original resolution."""
        img_name_list = [filepath_image]
        test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                            lbl_name_list=[],
                                            transform=transforms.Compose([RescaleT(320),
                                                                          ToTensorLab(flag=0)]))
        test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)
        net = self.net
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
            inputs_test = data_test['image']
            inputs_test = inputs_test.type(torch.FloatTensor)
            if torch.cuda.is_available():
                inputs_test = Variable(inputs_test.cuda())
            else:
                inputs_test = Variable(inputs_test)
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
            # normalization: keep the fused output d1 and min-max rescale it
            pred = d1[:, 0, :, :]
            pred = normPRED(pred)
            predict = pred.squeeze()
            predict_np = predict.cpu().data.numpy()
            # resize the 320x320 prediction back to the source image size
            img = cv.imread(filepath_image)
            w = img.shape[1]
            h = img.shape[0]
            # NOTE: interpolation must be passed as a keyword; the third
            # positional argument of cv.resize is `dst`, not `interpolation`.
            imo = cv.resize(predict_np, (w, h), interpolation=cv.INTER_LINEAR)
            return imo
    def remove_background_save(self, path_in, path_out, path_out_mask=None):
        """Remove the background of one image and write the masked result
        (and optionally the raw mask) to disk."""
        print("remove_background_save")
        mask_torch = self.remove_background(path_in)
        mask = (mask_torch * 255).astype(np.uint8)
        img = cv.imread(path_in)
        # binarize the saliency mask; 80 was chosen empirically
        # (other values tried: 127, 200)
        ret, binary_mask = cv.threshold(mask, 80, 255, cv.THRESH_BINARY)
        binary_mask = np.uint8(binary_mask)
        res = cv.bitwise_and(img, img, mask=binary_mask)
        cv.imwrite(path_out, res)
        if path_out_mask is not None:
            cv.imwrite(path_out_mask, mask)
    def remove_background_dir(self, path_in, path_out):
        """Process every *.jpg in path_in, skipping files already present in path_out."""
        img_name_list = glob.glob(os.path.join(path_in, "*.jpg"))
        for img_name in img_name_list:
            img_name_output = img_name.replace(path_in, path_out)
            if not os.path.exists(img_name_output):
                self.remove_background_save(img_name, img_name_output)
            print(img_name_output)
    def remove_background_gradio(self, np_image):
        """Variant of remove_background that takes an in-memory numpy image
        (as provided by Gradio) and returns (mask, masked_image)."""
        w = np_image.shape[1]
        h = np_image.shape[0]
        image = np_image
        imidx = np.array([0])
        # Build the same sample structure SalObjDataset would produce,
        # with an all-zero dummy label.
        label_3 = np.zeros(image.shape)
        label = np.zeros(label_3.shape[0:2])
        if 3 == len(label_3.shape):
            label = label_3[:, :, 0]
        elif 2 == len(label_3.shape):
            label = label_3
        if 3 == len(image.shape) and 2 == len(label.shape):
            label = label[:, :, np.newaxis]
        elif 2 == len(image.shape) and 2 == len(label.shape):
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]
        sample = {'imidx': imidx, 'image': image, 'label': label}
        eval_transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
        sample = eval_transform(sample)
        net = self.net
        inputs_test = sample['image']
        inputs_test = inputs_test.type(torch.FloatTensor)
        inputs_test = inputs_test.unsqueeze(0)  # add the batch dimension
        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        predict = pred.squeeze()
        predict_np = predict.cpu().data.numpy()
        imo = cv.resize(predict_np, (w, h), interpolation=cv.INTER_LINEAR)
        mask = (imo * 255).astype(np.uint8)
        # binarize at the same empirical threshold used elsewhere (80;
        # 127, 200 and 233 were also tried)
        ret, binary_mask = cv.threshold(mask, 80, 255, cv.THRESH_BINARY)
        binary_mask = np.uint8(binary_mask)
        res = cv.bitwise_and(np_image, np_image, mask=binary_mask)
        return mask, res
    def apply_mask(self, input, mask, threshold):
        """Binarize a BGR mask at `threshold` and apply it to `input`.
        Returns (masked_image, binary_mask)."""
        mask = cv.cvtColor(mask, cv.COLOR_BGR2GRAY)
        ret, binary_mask = cv.threshold(mask, threshold, 255, cv.THRESH_BINARY)
        res = cv.bitwise_and(input, input, mask=binary_mask)
        # Alternative: soft alpha blending instead of a hard mask.
        # foreground_alpha = mask.astype(np.float32) / 255.0
        # blended_image = np.zeros_like(input)
        # for channel in range(3):  # loop through BGR channels
        #     blended_image[:, :, channel] = input[:, :, channel] * foreground_alpha
        return res, binary_mask
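# Minimal usage sketch (assumes a local "sample.jpg"; adjust paths as needed):
#
#     remover = BackgroundRemover()
#     remover.remove_background_save("sample.jpg", "sample_nobg.jpg",
#                                    path_out_mask="sample_mask.jpg")
#
# remove_background() can be used directly when only the float saliency map is
# needed, e.g. to pick a custom threshold with apply_mask().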
def get_transform(train=True):
    """Evaluation transform used by DamageClassifier (the `train` flag is kept
    for symmetry; no augmentation is applied here)."""
    transforms_list = []
    #if train:
    #    transforms_list.append(transforms.RandomHorizontalFlip(0.5))
    transforms_list.append(transforms.Resize(256))
    transforms_list.append(transforms.CenterCrop(256))
    transforms_list.append(transforms.ToTensor())
    transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return transforms.Compose(transforms_list)
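# The Normalize() means/stds above are the standard ImageNet statistics; they
# match the 'IMAGENET1K_V1' pretrained weights the classifiers below start from.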
class DamageClassifier():
    """Scores spider-mite damage on an image into four severity bins."""

    def __init__(self):
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model_name = ""

    def initialize(self, model_name):
        """Load the checkpoint matching `model_name` and keep it on self.model.
        Note that the UI names do not all match the underlying architecture:
        "Googlenet" loads a regnet_x_32gf checkpoint and "Regnet32" a resnet18 one.
        """
        model = None
        if model_name == "Resnet18":
            # model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\resnet18_SpidermitesModel.pth"
            model_filepath = os.path.join(MODEL_PATH, "resnet18_SpidermitesModel.pth")
            model = models.resnet18(weights='IMAGENET1K_V1')
        if model_name == "Resnet152":
            # model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet152_SpidermitesModel_44_44.pth"
            model_filepath = os.path.join(MODEL_PATH, "short_resnet152_SpidermitesModel_44_44.pth")
            model = models.resnet152(weights='IMAGENET1K_V1')
        if model_name == "Googlenet":
            # model_filepath = r"\\catalogue.cgiarad.org\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\regnet_x_32gf_SpidermitesModel.pth"
            model_filepath = os.path.join(MODEL_PATH, "regnet_x_32gf_SpidermitesModel.pth")
            model = models.regnet_x_32gf(weights='IMAGENET1K_V1')
        if model_name == "Regnet32":
            # model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet18_SpidermitesModel.pth"
            model_filepath = os.path.join(MODEL_PATH, "short_resnet18_SpidermitesModel.pth")
            model = models.resnet18(weights='IMAGENET1K_V1')
        if model is not None:
            # Replace the final fully connected layer with a 4-class head.
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, 4)
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(model_filepath))
                model.cuda()
            else:
                model.load_state_dict(torch.load(model_filepath, map_location='cpu'))
            model.eval()
            self.model = model
            self.model_name = model_name
        return
    def inference(self, np_image, model_name):
        """Classify one RGB numpy image; returns a dict mapping damage-range
        labels to scores."""
        if model_name == "Regnet":
            # ONNX path: runs regnet_x_32gf through onnxruntime.
            # model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\regnet_x_32gf_SpidermitesModel.onnx"
            model_filepath = os.path.join(MODEL_PATH, "regnet_x_32gf_SpidermitesModel.onnx")
            ort_sess = ort.InferenceSession(model_filepath,
                                            providers=ort.get_available_providers())
            transforms_list = []
            transforms_list.append(transforms.ToTensor())
            transforms_list.append(transforms.Resize(512))
            transforms_list.append(transforms.CenterCrop(512))
            transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
            apply_t = transforms.Compose(transforms_list)
            img = apply_t(np_image)
            imgs = np.array([img.numpy()])  # add the batch dimension
            outputs = ort_sess.run(None, {'input': imgs})
            np_res = outputs[0][0]
            final_res = {'0-(No damage)': np_res[0],
                         '1-3-(Moderately damaged)': np_res[1],
                         '4-7-(Damaged)': np_res[2],
                         '8-10-(Severely damaged)': np_res[3]}
            return final_res
        else:
            if self.model_name != model_name:
                self.initialize(model_name)
            with torch.no_grad():
                pil_image = Image.fromarray(np_image.astype('uint8'))
                data_transforms = get_transform(train=False)
                img = data_transforms(pil_image)
                inputs = img.to(self.device)
                outputs = self.model(inputs.unsqueeze(0))
                _, preds = torch.max(outputs, 1)
                # convert logits to per-class probabilities
                m = nn.Softmax(dim=1)
                res = m(outputs)
                np_res = res[0].cpu().numpy()
                final_res = {'0-(No damage)': np_res[0],
                             '1-3-(Moderately damaged)': np_res[1],
                             '4-7-(Damaged)': np_res[2],
                             '8-10-(Severely damaged)': np_res[3]}
            return final_res
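# Minimal usage sketch (assumes ./models contains the listed checkpoints and
# `np_image` is an RGB uint8 array, e.g. loaded with PIL):
#
#     classifier = DamageClassifier()
#     scores = classifier.inference(np_image, "Resnet18")
#     print(max(scores, key=scores.get))  # most likely damage bin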
class ColorCheckerDetector():
    """Given a mask of the color checker and the source image, draws the
    detected geometry and returns a rotated and a perspective-corrected crop."""

    def __init__(self):
        return

    def process(self, np_image_mask, np_image):
        ret, binary_mask = cv.threshold(np_image_mask, 80, 255, cv.THRESH_BINARY)
        binary_mask_C = cv.cvtColor(binary_mask, cv.COLOR_BGR2GRAY)  # change to single channel
        (contours, hierarchy) = cv.findContours(binary_mask_C, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
        # findContours does not order by size, so pick the largest contour
        # rather than relying on contours[0].
        main_contour = max(contours, key=cv.contourArea)
        # compute the center of the contour
        moments = cv.moments(main_contour)
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        # bounding rect
        bb_x, bb_y, bb_w, bb_h = cv.boundingRect(binary_mask_C)
        # min-area (rotated) bounding rect
        rect = cv.minAreaRect(main_contour)
        box = cv.boxPoints(rect)
        box = np.int64(box)
        # fit a line through the box corners to estimate the checker's tilt
        rows, cols = binary_mask_C.shape[:2]
        #[vx, vy, x, y] = cv.fitLine(main_contour, cv.DIST_L2, 0, 0.01, 0.01)
        [vx, vy, x, y] = cv.fitLine(box, cv.DIST_L2, 0, 0.01, 0.01)
        lefty = int((-x * vy / vx) + y)
        righty = int(((cols - x) * vy / vx) + y)
        angle = np.arctan2(np.abs(righty - lefty), cols)
        # rotate the image about the contour center by the estimated tilt
        M_rot = cv.getRotationMatrix2D((cx, cy), -angle * 180.0 / np.pi, 1.0)
        rotated = cv.warpAffine(np_image, M_rot, (binary_mask.shape[1], binary_mask.shape[0]))
        # perspective transform: map the rotated-rect corners onto a fixed
        # 290x200 canvas
        input_pts = box.astype(np.float32)
        maxHeight = 200
        maxWidth = 290
        output_pts = np.float32([[0, 0],
                                 [maxWidth - 1, 0],
                                 [maxWidth - 1, maxHeight - 1],
                                 [0, maxHeight - 1]])
        M_per = cv.getPerspectiveTransform(input_pts, output_pts)
        corrected = cv.warpPerspective(np_image, M_per, (maxWidth, maxHeight), flags=cv.INTER_LINEAR)
        # drawContours expects a list of contours, hence [main_contour]
        res = cv.drawContours(np_image, [main_contour], -1, (255, 255, 0), 5)
        res = cv.rectangle(res, (bb_x, bb_y), (bb_x + bb_w, bb_y + bb_h), (0, 255, 0), 5)
        res = cv.drawContours(res, [box], 0, (0, 0, 255), 5)
        res = cv.line(res, (cols - 1, righty), (0, lefty), (0, 0, 255), 5)
        return [res, rotated, corrected]
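# NOTE: getPerspectiveTransform assumes input_pts and output_pts list the
# corners in the same order. cv.boxPoints returns the rotated rect's corners in
# a fixed order, but depending on the checker's orientation the corrected crop
# may come out rotated by 90 or 180 degrees; callers should be prepared to
# re-orient it.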
class BatchProcessor():
    """Walks an input directory tree and applies a processing function to every
    matching image, mirroring the folder structure in the output directory."""

    def __init__(self):
        return

    def batch_process(self, input_dir, output_dir, output_suffixes=["output"],
                      format="jpg", pattern='**/*.tiff', processing_fc=None,
                      output_format=None):
        if processing_fc is None:
            print("Processing function is None")
            return
        if output_format is None:
            output_format = format
        # get the list of files in the folder and its subfolders
        pattern = '**/*.' + format
        files = glob.glob(pattern, root_dir=input_dir, recursive=True)  # root_dir needs Python 3.10+
        for file in files:
            filepath = os.path.join(input_dir, file)
            basename = os.path.basename(filepath)
            extra_path = file.replace(basename, "")
            # use a per-file output dir; reassigning output_dir here would
            # nest one level deeper on every iteration
            file_output_dir = os.path.join(output_dir, extra_path)
            # build the list of output filepaths, one per suffix
            output_filepaths = []
            for suffix in output_suffixes:
                output_filepaths.append(os.path.join(
                    file_output_dir,
                    basename.replace("." + format, "_" + suffix + "." + output_format)))
            # process only if the first output file does not exist yet
            if not os.path.exists(output_filepaths[0]):
                if not os.path.exists(file_output_dir):  # create subfolders if necessary
                    pathlib.Path(file_output_dir).mkdir(parents=True, exist_ok=True)
                processing_fc(filepath, output_filepaths)  # process and save the file
                print(file)
                print(output_filepaths[0])
                print("****")
class Segmentor():
    """Text-prompted segmentation with GroundingDINO + SAM. Requires the
    optional imports near the top of this file to be enabled."""

    def __init__(self):
        self.sam_predictor = None
        self.groundingdino_model = None
        #self.sam_checkpoint = './sam_vit_h_4b8939.pth'
        #self.sam_checkpoint = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\sam_vit_h_4b8939.pth"
        self.sam_checkpoint = r"D:\local_mydev\Grounded-Segment-Anything\sam_vit_h_4b8939.pth"
        # self.config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
        self.config_file = r"D:\local_mydev\gsam\GroundingDINO\groundingdino\config\GroundingDINO_SwinT_OGC.py"
        self.ckpt_repo_id = "ShilongLiu/GroundingDINO"
        self.ckpt_filename = "groundingdino_swint_ogc.pth"
        self.device = 'cpu'
        self.load_sam_model(self.device)
        self.load_groundingdino_model(self.device)
        return
    def get_sam_vit_h_4b8939(self):
        # Download disabled; the checkpoint is expected at self.sam_checkpoint.
        return
        # if not os.path.exists('./sam_vit_h_4b8939.pth'):
        #     logger.info(f"get sam_vit_h_4b8939.pth...")
        #     result = subprocess.run(['wget', '-nv', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
        #     print(f'wget sam_vit_h_4b8939.pth result = {result}')

    def load_sam_model(self, device):
        sam_checkpoint = self.sam_checkpoint
        # initialize SAM
        self.get_sam_vit_h_4b8939()
        logger.info("initialize SAM model...")
        sam_model = build_sam(checkpoint=sam_checkpoint).to(device)
        self.sam_predictor = SamPredictor(sam_model)
        self.sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
    def get_grounding_output(self, model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
        """Run GroundingDINO and return (boxes, phrases) above the thresholds."""
        caption = caption.lower().strip()
        if not caption.endswith("."):
            caption = caption + "."
        model = model.to(device)
        image = image.to(device)
        with torch.no_grad():
            outputs = model(image[None], captions=[caption])
        logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
        boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
        # filter output by box confidence
        logits_filt = logits.clone()
        boxes_filt = boxes.clone()
        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
        logits_filt = logits_filt[filt_mask]  # num_filt, 256
        boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
        # get the phrase matched by each surviving box
        tokenizer = model.tokenizer
        tokenized = tokenizer(caption)
        # build predictions
        pred_phrases = []
        for logit, box in zip(logits_filt, boxes_filt):
            pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
            if with_logits:
                pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
            else:
                pred_phrases.append(pred_phrase)
        return boxes_filt, pred_phrases
    def load_model_hf(self, model_config_path, repo_id, filename, device='cpu'):
        args = SLConfig.fromfile(model_config_path)
        model = build_model(args)
        args.device = device
        cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
        checkpoint = torch.load(cache_file, map_location=device)
        log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
        print("Model loaded from {} \n => {}".format(cache_file, log))
        _ = model.eval()
        return model

    def load_groundingdino_model(self, device):
        config_file = self.config_file
        ckpt_repo_id = self.ckpt_repo_id
        ckpt_filename = self.ckpt_filename
        # initialize the groundingdino model
        logger.info("initialize groundingdino model...")
        self.groundingdino_model = self.load_model_hf(config_file, ckpt_repo_id, ckpt_filename, device=device)
        logger.info(f"initialize groundingdino model...{type(self.groundingdino_model)}")
    def show_mask(self, mask, random_color=False):
        """Render a boolean mask as an RGBA image (currently fixed to opaque red)."""
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        color = np.array([1.0, 0, 0, 1.0])  # override: opaque red
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        return mask_image
    def process(self, np_image, text_prompt):
        """Segment `np_image` with boxes from GroundingDINO (driven by
        `text_prompt`) refined into masks by SAM. Returns [image, mask, ...]."""
        results = [np_image]
        sam_predictor = self.sam_predictor
        groundingdino_model = self.groundingdino_model
        image = np_image
        box_threshold = 0.3
        text_threshold = 0.25
        # numpy images are (H, W, C): shape[0] is the height, shape[1] the width
        size = image.shape
        H, W = size[0], size[1]
        # run the GroundingDINO model
        groundingdino_device = 'cpu'
        image_dino = Image.fromarray(image)
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image_dino, _ = transform(image_dino, None)  # 3, h, w
        boxes_filt, pred_phrases = self.get_grounding_output(
            groundingdino_model, image_dino, text_prompt, box_threshold, text_threshold,
            device=groundingdino_device
        )
        if sam_predictor:
            sam_predictor.set_image(image)
            # convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2)
            for i in range(boxes_filt.size(0)):
                boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
                boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
                boxes_filt[i][2:] += boxes_filt[i][:2]
            transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
            masks, _, _, _ = sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            for mask in masks:
                mask_img = self.show_mask(mask.cpu().numpy())
                results.append(mask_img)
        return results
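

# Minimal smoke test (a sketch: assumes ./models/u2net.pth is present and a
# test image path is supplied on the command line).
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        remover = BackgroundRemover()
        saliency = remover.remove_background(sys.argv[1])
        print("saliency map:", saliency.shape, saliency.min(), saliency.max())
    else:
        print("usage: python bgremover.py <image.jpg>")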