File size: 4,896 Bytes
464e64a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import streamlit as st
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET

# ---------------------- MODEL LOAD ---------------------- #
@st.cache_resource
def load_model():
    """Build the U2NET network, load its checkpoint on CPU and return it in eval mode.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the same model object
    across script reruns instead of reloading the checkpoint every time.

    The checkpoint path is relative, so it is resolved against the process's
    current working directory. Every occurrence of ``"module."`` is removed
    from the state-dict keys — presumably to undo an ``nn.DataParallel``
    wrapper used at training time (TODO confirm against the checkpoint).

    Note: ``torch.load`` unpickles the file; only point it at trusted
    checkpoints.
    """
    net = U2NET(3, 1)
    checkpoint = torch.load(
        "cloth_segmentation/networks/u2net.pth",
        map_location=torch.device("cpu"),
    )
    cleaned = {}
    for key, tensor in checkpoint.items():
        cleaned[key.replace("module.", "")] = tensor
    net.load_state_dict(cleaned)
    net.eval()
    return net


model = load_model()

# ---------------------- UTILITY FUNCTIONS ---------------------- #
def refine_mask(mask):
    """Clean up a binary-ish uint8 mask and soften its edges.

    Pipeline: close small holes (5x5), erode one step (3x3) to trim fringe
    pixels, close again (5x5) to re-seal gaps the erosion opened, then apply a
    5x5 Gaussian blur (sigma 1.5). The blur means the result is no longer
    strictly 0/255 near edges.
    """
    big_kernel = np.ones((5, 5), np.uint8)
    small_kernel = np.ones((3, 3), np.uint8)

    closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, big_kernel)
    trimmed = cv2.erode(closed, small_kernel, iterations=1)
    resealed = cv2.morphologyEx(trimmed, cv2.MORPH_CLOSE, big_kernel)
    return cv2.GaussianBlur(resealed, (5, 5), 1.5)

def segment_dress(image_np):
    """Predict a coarse foreground ("dress") mask for an image with U2NET.

    The image is converted to RGB, turned into a [0, 1] tensor and resized to
    320x320. It is then run through the module-level ``model``, and the first
    output's saliency map is min-max normalised to [0, 1]. Pixels above
    ``mean + 0.2`` become 255 and all others 0. The map is resized back to the
    input's height/width with nearest-neighbour interpolation and cleaned with
    ``refine_mask``.

    Edge case: if the normalised map's mean exceeds 0.8, the threshold is
    above every value and the returned mask is empty.

    NOTE(review): the input tensor is not mean/std normalised. If these are
    the original U^2-Net salient-object weights, their reference inference
    normalises with ImageNet statistics — confirm before relying on accuracy.
    """
    height, width = image_np.shape[:2]
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Resize((320, 320))])
    batch = preprocess(Image.fromarray(image_np).convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        side_outputs = model(batch)
        saliency = side_outputs[0][0].squeeze().cpu().numpy()

    low, high = saliency.min(), saliency.max()
    saliency = (saliency - low) / (high - low + 1e-8)
    cutoff = np.mean(saliency) + 0.2

    binary = np.where(saliency > cutoff, 255, 0).astype(np.uint8)
    binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
    return refine_mask(binary)

def apply_grabcut(image_np, dress_mask, *, iter_count=3):
    """Refine a coarse dress mask with OpenCV GrabCut.

    Pixels where ``dress_mask > 0`` are seeded as "probably foreground" and
    every other pixel as "definite background". Definite-background pixels are
    never relabelled by GrabCut, so it can only shrink the coarse mask.

    GrabCut is skipped, and the seed is used as-is, when either class has
    fewer samples than its 5-component colour GMMs need. Calling it in that
    state makes OpenCV raise, e.g. when the coarse mask covers the whole image
    and leaves no background samples.

    Args:
        image_np: 8-bit, 3-channel image (GrabCut's input requirement).
        dress_mask: single-channel uint8 mask with image_np's height/width.
        iter_count: number of GrabCut iterations (default 3, as before).

    Returns:
        The refined 0/255 mask passed through ``refine_mask``.
    """
    seed = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype("uint8")

    fg_count = int(np.count_nonzero(seed == cv2.GC_PR_FGD))
    bg_count = seed.size - fg_count
    # GrabCut models each class with a 5-component GMM (13 * 5 = 65 params,
    # hence the 1x65 model arrays); its k-means init needs >= 5 samples each.
    gmm_components = 5
    if min(fg_count, bg_count) >= gmm_components:
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        # rect is ignored in GC_INIT_WITH_MASK mode but must still be passed.
        rect = cv2.boundingRect(cv2.findNonZero(dress_mask))
        cv2.grabCut(image_np, seed, rect, bgd_model, fgd_model,
                    iter_count, cv2.GC_INIT_WITH_MASK)

    refined_mask = np.where(
        (seed == cv2.GC_FGD) | (seed == cv2.GC_PR_FGD), 255, 0
    ).astype("uint8")
    return refine_mask(refined_mask)

def recolor_dress(image_np: np.ndarray, dress_mask: np.ndarray, target_color) -> np.ndarray:
    """Shift the chroma of the masked region toward ``target_color``.

    Works in OpenCV's 8-bit L*a*b* space. The mean a*/b* of the pixels where
    ``dress_mask > 0`` is moved onto the a*/b* of ``target_color``, with the
    shift weighted per pixel by ``dress_mask / 255``. The L* (lightness)
    channel is never modified, so shading is preserved. As a consequence,
    targets that differ mainly in lightness (e.g. white or black) only
    desaturate the region rather than making it white or black.

    Args:
        image_np: RGB image (converted with ``COLOR_RGB2LAB``); presumably
            uint8 H x W x 3 — TODO confirm for all callers.
        dress_mask: single-channel mask aligned with ``image_np``; its values
            are used as 0..255 blend weights.
        target_color: target colour as a **BGR** triple (converted with
            ``COLOR_BGR2LAB``).

    Returns:
        The blended uint8 image, or ``image_np`` itself when the mask selects
        no pixels.
    """
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    dress_pixels = img_lab[dress_mask > 0]
    if len(dress_pixels) == 0:
        return image_np
    # Average colour of the current dress; only the a*/b* means are used.
    mean_L, mean_A, mean_B = np.mean(dress_pixels, axis=0)
    a_shift = target_color_lab[1] - mean_A
    b_shift = target_color_lab[2] - mean_B
    # Apply the shift proportionally to the mask weight. Assigning the clipped
    # float result back into the uint8 LAB image truncates the fraction.
    img_lab[..., 1] = np.clip(img_lab[..., 1] + (dress_mask / 255.0) * a_shift, 0, 255)
    img_lab[..., 2] = np.clip(img_lab[..., 2] + (dress_mask / 255.0) * b_shift, 0, 255)
    img_recolored = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
    # Wide blur of the mask so the recolour fades out at the dress boundary.
    feathered_mask = cv2.GaussianBlur(dress_mask, (21, 21), 7)
    # Lightness in [0, 1] raised to 0.7: darker pixels get a smaller blend
    # weight, so deep shadows keep more of their original colour.
    lightness_mask = (img_lab[..., 0] / 255.0) ** 0.7
    adaptive_feather = (feathered_mask * lightness_mask).astype(np.uint8)
    # Per-pixel alpha blend: original * (1 - w) + recoloured * w, w = feather / 255.
    return (image_np * (1 - adaptive_feather[..., None] / 255) + img_recolored * (adaptive_feather[..., None] / 255)).astype(np.uint8)

# Selectable target colours by display name, as BGR triples
# (recolor_dress converts the target with COLOR_BGR2LAB).
DRESS_COLORS_BGR = {
    "Red": (0, 0, 255),
    "Blue": (255, 0, 0),
    "Green": (0, 255, 0),
    "Yellow": (0, 255, 255),
    "Purple": (128, 0, 128),
    "Orange": (0, 165, 255),
    "Cyan": (255, 255, 0),
    "Magenta": (255, 0, 255),
    "White": (255, 255, 255),
    "Black": (0, 0, 0),
}

# Detections covering fewer mask pixels than this are treated as "no dress".
_MIN_DRESS_PIXELS = 1000


def change_dress_color(img, color):
    """Recolour the dress in ``img`` to the named ``color``.

    Segments the dress, refines the mask with GrabCut and recolours the masked
    region. Unknown colour names fall back to red.

    Args:
        img: PIL image (the UI passes an RGB image).
        color: a key of ``DRESS_COLORS_BGR``.

    Returns:
        A new PIL image with the dress recoloured. The original ``img`` is
        returned unchanged when no sufficiently large dress region is found
        (a warning is shown) or when processing fails (an error is shown).
    """
    new_color_bgr = DRESS_COLORS_BGR.get(color, DRESS_COLORS_BGR["Red"])
    img_np = np.array(img)
    try:
        dress_mask = segment_dress(img_np)
        # Count mask pixels. The previous np.sum over 0..255 intensities
        # tripped only below ~4 pixels, which made this guard meaningless.
        if np.count_nonzero(dress_mask) < _MIN_DRESS_PIXELS:
            st.warning("No dress region was detected; the image was left unchanged.")
            return img
        dress_mask = apply_grabcut(img_np, dress_mask)
        img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
        return Image.fromarray(img_recolored)
    except Exception as e:
        # UI boundary: report any failure to the user and fall back to the input.
        st.error(f"Error processing image: {e}")
        return img

# ---------------------- STREAMLIT UI ---------------------- #
# Title emoji was mojibake ("πŸ‘—" = UTF-8 bytes of 👗 decoded in the wrong codepage).
st.title("👗 AI Dress Color Changer")
st.markdown("Upload a dress image and select a new color for realistic recoloring")

uploaded_img = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
color_option = st.selectbox("Choose a Color", [
    "Red", "Blue", "Green", "Yellow", "Purple",
    "Orange", "Cyan", "Magenta", "White", "Black"
])

if uploaded_img is not None:
    try:
        # convert() forces a full decode, so truncated files fail here too.
        image = Image.open(uploaded_img).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError is an OSError subclass; oversized images raise
        # DecompressionBombError. Report instead of crashing the script.
        st.error(f"Could not read the uploaded image: {err}")
        st.stop()
    st.image(image, caption="Original Image", use_column_width=True)

    if st.button("Recolor Dress"):
        with st.spinner("Processing..."):
            result = change_dress_color(image, color_option)
            st.image(result, caption="Recolored Dress", use_column_width=True)