File size: 4,279 Bytes
2637637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import streamlit as st
import torch
from diffusers import (
    StableDiffusionXLControlNetPipeline,
    ControlNetModel
)
from PIL import Image

# Hugging Face Hub repository IDs passed to `from_pretrained` in load_pipeline().
# NOTE: the ControlNet ID below is only a placeholder for demonstration and may
# not exist on the Hub; replace it with a valid SDXL line-art ControlNet repo.
LINEART_CONTROLNET_REPO = "lllyasviel/sdxl-controlnet-lineart"  # Example placeholder
SDXL_MODEL_REPO = "RunDiffusion/Juggernaut-XL-v9"               # Or "stabilityai/stable-diffusion-xl-base-1.0"

@st.cache_resource
def load_pipeline():
    """
    Build the SDXL + line-art ControlNet inference pipeline.

    Loads the ControlNet weights from ``LINEART_CONTROLNET_REPO`` and the
    base SDXL model from ``SDXL_MODEL_REPO``, then moves the pipeline to the
    GPU when one is available. The function is decorated with
    ``st.cache_resource`` so the pipeline object is reused instead of being
    rebuilt on every call.

    Precision follows the device: float16 on CUDA, float32 on CPU, because
    half-precision inference is generally unsupported or very slow on CPU.

    Returns:
        A ``StableDiffusionXLControlNetPipeline`` ready for inference.
    """
    use_cuda = torch.cuda.is_available()
    # Pick the dtype up front so both models are loaded with matching precision.
    dtype = torch.float16 if use_cuda else torch.float32

    controlnet = ControlNetModel.from_pretrained(
        LINEART_CONTROLNET_REPO,
        torch_dtype=dtype,
    )

    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        SDXL_MODEL_REPO,
        controlnet=controlnet,
        torch_dtype=dtype,
    )

    # Move to GPU if available; otherwise the pipeline stays on CPU.
    if use_cuda:
        pipe.to("cuda")

    return pipe

def combine_lineart_and_colormask(lineart: Image.Image, color_mask: Image.Image) -> Image.Image:
    """
    Merge a line-art image and a color mask into one RGB control image.

    The color mask is first resized to the line-art's dimensions, both
    images are converted to RGBA, and the two are blended with equal weight.
    The blended result is returned in RGB mode.

    This is a deliberately simple merge for demonstration purposes; a more
    elaborate combination (or a multi-ControlNet setup) may work better in
    practice.
    """
    # Bring the mask to the line-art's size, then put both layers in RGBA
    # so that Image.blend receives images of identical size and mode.
    base_layer, mask_layer = (
        layer.convert("RGBA")
        for layer in (lineart, color_mask.resize(lineart.size))
    )

    # Equal-weight blend, flattened back to RGB for the pipeline.
    return Image.blend(base_layer, mask_layer, alpha=0.5).convert("RGB")

def main():
    """
    Render the Streamlit UI and run generation on request.

    The sidebar collects the prompt, negative prompt, guidance scale and
    number of inference steps. The main area takes two uploads (a line-art
    sketch and a color mask), previews them along with their blended control
    image, and, when "Generate" is pressed, runs the cached SDXL ControlNet
    pipeline and displays the first generated image. A warning is shown
    until both files have been uploaded.
    """
    st.title("Line-Art + Color Mask with SDXL ControlNet")
    st.markdown(
        "Upload a **line-art sketch** and a **color mask**, then let "
        "Stable Diffusion XL (Juggernaut XL) + ControlNet (Lineart) do the rest!"
    )

    # Sidebar inputs for text prompt, etc.
    prompt = st.sidebar.text_input(
        "Prompt",
        value="A cute cartoon-style character with vibrant colors"
    )
    negative_prompt = st.sidebar.text_input(
        "Negative Prompt",
        value="ugly, deformed"
    )
    guidance_scale = st.sidebar.slider(
        "Guidance Scale (classifier-free)",
        min_value=1.0,
        max_value=20.0,
        value=9.0
    )
    num_inference_steps = st.sidebar.slider(
        "Number of Inference Steps",
        min_value=10,
        max_value=100,
        value=30
    )

    # Main area for uploading images
    lineart_file = st.file_uploader("Upload Line-Art Sketch (png/jpg)", type=["png", "jpg", "jpeg"])
    color_file = st.file_uploader("Upload Color Mask (png/jpg)", type=["png", "jpg", "jpeg"])

    if lineart_file and color_file:
        lineart_image = Image.open(lineart_file)
        color_mask = Image.open(color_file)

        st.image(lineart_image, caption="Line-Art Preview", width=300)
        st.image(color_mask, caption="Color Mask Preview", width=300)

        # Combine images into a single control image
        combined_control_image = combine_lineart_and_colormask(lineart_image, color_mask)
        st.image(combined_control_image, caption="Combined Control Image", width=300)

        # Button to run inference
        if st.button("Generate"):
            pipe = load_pipeline()

            with st.spinner("Generating image..."):
                result = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    # The text-to-image SDXL ControlNet pipeline takes its
                    # conditioning image via `image`; `control_image` is the
                    # name used by the img2img/inpaint variants and would be
                    # ignored here, leaving the pipeline without a condition.
                    image=combined_control_image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                ).images[0]

            st.image(result, caption="Generated Image", width=512)
    else:
        st.warning("Please upload both a line-art sketch and a color mask.")

# Start the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()