from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
from skimage.morphology import dilation, square
import cv2
import torch

from .file_util import node_path, checkpoints_path
from .utils import *


# Resolve the local folder holding the segformer_b2_clothes checkpoint and
# load the processor and model once at import time.
model_folder_path = checkpoints_path("ComfyUI_Seg_VITON", "segformer_b2_clothes")

processor = SegformerImageProcessor.from_pretrained(model_folder_path)
model = AutoModelForSemanticSegmentation.from_pretrained(model_folder_path)

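
# Note: the class indices predicted by this model can be inspected via
# model.config.id2label. The functions below only assume that index 0 is the
# background class; any other indices passed to seg_mask_by_label() should be
# checked against that mapping for the installed segformer_b2_clothes checkpoint.
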
def get_segmentation(tensor_image):
    # Convert the ComfyUI image tensor to PIL and run SegFormer on it.
    cloth = tensor2pil(tensor_image)

    inputs = processor(images=cloth, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    logits = outputs.logits.cpu()
    # Upsample the low-resolution logits back to the input size
    # (PIL size is (W, H); interpolate expects (H, W)).
    upsampled_logits = nn.functional.interpolate(logits, size=cloth.size[::-1], mode="bilinear", align_corners=False)
    # Per-pixel class-index map of shape (H, W).
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
    return pred_seg, cloth

def seg_mask(tensor_image):
    # Segmented (non-background) pixels become black (0), background becomes white (255).
    pred_seg, cloth = get_segmentation(tensor_image)

    mask = (pred_seg != 0).astype(np.uint8)
    mask = np.where(mask == 1, 0, 255)

    cloth_mask = Image.fromarray(np.uint8(mask))
    cloth_mask = cloth_mask.convert("RGB")
    return cloth_mask, cloth

def seg_reverse_mask(file):
    # Inverse of seg_mask: keep only the background label (0), so the background
    # becomes black and every segmented region becomes white.
    return seg_mask_by_label(file, [0])

def seg_mask_by_label(file, array_label):
    # Mask an explicit list of class indices: pixels matching any label in
    # array_label become black (0), all other pixels white (255).
    pred_seg, cloth = get_segmentation(file)

    labels_to_keep = array_label
    mask = np.isin(pred_seg, labels_to_keep).astype(np.uint8)
    mask = np.where(mask == 1, 0, 255)

    cloth_mask = Image.fromarray(np.uint8(mask))
    cloth_mask = cloth_mask.convert("RGB")
    return cloth_mask, cloth

def seg_show(pred_seg):
    # Debug helper: display the class-index map with a matplotlib colormap.
    plt.imshow(pred_seg, cmap="viridis")
    plt.axis('off')
    plt.show()
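
# Usage sketch (illustrative only): inside a ComfyUI node these helpers are
# called with an IMAGE tensor. The label index 4 (upper clothes) used below is
# an assumption based on common segformer_b2_clothes releases; confirm it via
# model.config.id2label before relying on it.
#
#   mask_img, cloth_img = seg_mask_by_label(image_tensor, [4])
#   mask_img.save("upper_clothes_mask.png")
#
#   pred_seg, _ = get_segmentation(image_tensor)
#   seg_show(pred_seg)  # visualise the raw segmentation map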