import streamlit as st
import torch
from diffusers import (
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
)
from PIL import Image

# For demonstration, we assume there is a published SDXL line-art ControlNet on HF.
# Replace with a valid repo if the name below doesn't exist, or adjust to your needs.
LINEART_CONTROLNET_REPO = "lllyasviel/sdxl-controlnet-lineart"  # Example placeholder
SDXL_MODEL_REPO = "RunDiffusion/Juggernaut-XL-v9"  # Or "stabilityai/stable-diffusion-xl-base-1.0"
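
# If you are unsure whether the ControlNet repo above actually exists, you can
# check it up front with huggingface_hub (already a diffusers dependency).
# Sketch only, not wired into the app:
#
#     from huggingface_hub import model_info
#     try:
#         model_info(LINEART_CONTROLNET_REPO)
#     except Exception:
#         raise RuntimeError(
#             f"ControlNet repo '{LINEART_CONTROLNET_REPO}' was not found; "
#             "point LINEART_CONTROLNET_REPO at a valid SDXL line-art ControlNet."
#         )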


@st.cache_resource
def load_pipeline():
    """
    Loads the ControlNet model (line-art) and the main Stable Diffusion XL model
    (Juggernaut XL). Returns a pipeline ready for inference.
    """
    # float16 only makes sense on a CUDA device; fall back to float32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    controlnet = ControlNetModel.from_pretrained(
        LINEART_CONTROLNET_REPO,
        torch_dtype=dtype
    )
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        SDXL_MODEL_REPO,
        controlnet=controlnet,
        torch_dtype=dtype
    )
    # Move to GPU if available
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
    return pipe
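
# Optional memory-saving alternatives (sketch only, not part of the app above):
# on GPUs with limited VRAM you could skip pipe.to("cuda") and let diffusers
# offload submodules on demand, or enable xFormers attention if it is installed.
#
#     pipe.enable_model_cpu_offload()   # requires `accelerate`
#     pipe.enable_xformers_memory_efficient_attention()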

def combine_lineart_and_colormask(lineart: Image.Image, color_mask: Image.Image) -> Image.Image:
    """
    Naive example of combining line-art and a color mask into a single control image.
    Here we just alpha-blend them for demonstration. In practice you might want a more
    sophisticated merge, or use multi-ControlNet if suitable pipelines/models are
    available (see the sketch after this function).
    """
    # Resize color mask to match the line-art size
    color_mask = color_mask.resize(lineart.size)

    # Convert both to RGBA so they can be blended
    lineart_rgba = lineart.convert("RGBA")
    color_mask_rgba = color_mask.convert("RGBA")

    # Simple 50/50 alpha blend for demonstration
    blended = Image.blend(lineart_rgba, color_mask_rgba, alpha=0.5)
    return blended.convert("RGB")
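
# Sketch only (not wired into the app): diffusers also supports passing several
# ControlNets at once, so instead of blending both conditions into one image you
# could drive a line-art ControlNet and a second ControlNet separately. The repo
# names below are placeholders; substitute models that actually exist.
#
#     controlnets = [
#         ControlNetModel.from_pretrained("your/sdxl-lineart-controlnet", torch_dtype=torch.float16),
#         ControlNetModel.from_pretrained("your/sdxl-color-controlnet", torch_dtype=torch.float16),
#     ]
#     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
#         SDXL_MODEL_REPO, controlnet=controlnets, torch_dtype=torch.float16
#     )
#     result = pipe(
#         prompt=prompt,
#         image=[lineart_image, color_mask],         # one control image per ControlNet
#         controlnet_conditioning_scale=[1.0, 0.5],  # per-ControlNet strength
#     ).images[0]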

def main():
    st.title("Line-Art + Color Mask with SDXL ControlNet")
    st.markdown(
        "Upload a **line-art sketch** and a **color mask**, then let "
        "Stable Diffusion XL (Juggernaut XL) + ControlNet (Lineart) do the rest!"
    )

    # Sidebar inputs for text prompt, etc.
    prompt = st.sidebar.text_input(
        "Prompt",
        value="A cute cartoon-style character with vibrant colors"
    )
    negative_prompt = st.sidebar.text_input(
        "Negative Prompt",
        value="ugly, deformed"
    )
    guidance_scale = st.sidebar.slider(
        "Guidance Scale (classifier-free)",
        min_value=1.0,
        max_value=20.0,
        value=9.0
    )
    num_inference_steps = st.sidebar.slider(
        "Number of Inference Steps",
        min_value=10,
        max_value=100,
        value=30
    )

    # Main area for uploading images
    lineart_file = st.file_uploader("Upload Line-Art Sketch (png/jpg)", type=["png", "jpg", "jpeg"])
    color_file = st.file_uploader("Upload Color Mask (png/jpg)", type=["png", "jpg", "jpeg"])

    if lineart_file and color_file:
        lineart_image = Image.open(lineart_file)
        color_mask = Image.open(color_file)

        st.image(lineart_image, caption="Line-Art Preview", width=300)
        st.image(color_mask, caption="Color Mask Preview", width=300)

        # Combine images into a single control image
        combined_control_image = combine_lineart_and_colormask(lineart_image, color_mask)
        st.image(combined_control_image, caption="Combined Control Image", width=300)

        # Button to run inference
        if st.button("Generate"):
            pipe = load_pipeline()
            with st.spinner("Generating image..."):
                result = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    # The SDXL ControlNet pipeline takes the conditioning image
                    # via the `image` argument (not `control_image`).
                    image=combined_control_image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    # SDXL pipelines also accept `prompt_2` / `negative_prompt_2`
                    # for the second text encoder if you want to steer them separately.
                ).images[0]

            st.image(result, caption="Generated Image", width=512)
    else:
        st.warning("Please upload both a line-art sketch and a color mask.")


if __name__ == "__main__":
    main()
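
# Run locally with:
#     streamlit run app.py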