Update app.py
app.py CHANGED
@@ -5,9 +5,6 @@ from PIL import Image
 from diffusers import StableDiffusionPipeline
 import streamlit as st
 from transformers import CLIPTokenizer
-import os
-from io import BytesIO
-from huggingface_hub import login
 
 # Define the device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -30,7 +27,7 @@ class CustomImageDataset(Dataset):
         return image, prompt
 
 # Function to fine-tune the model
-def fine_tune_model(images, prompts, num_epochs=3, model_name="fine_tuned_model", push_to_hub=True, save_dir="fine_tuned_models"):
+def fine_tune_model(images, prompts, num_epochs=3):
     transform = transforms.Compose([
         transforms.Resize((512, 512)),
         transforms.ToTensor(),
@@ -46,14 +43,18 @@ def fine_tune_model(images, prompts, num_epochs=3, model_name="fine_tuned_model"
     vae = pipeline.vae.to(device)
     unet = pipeline.unet.to(device)
     text_encoder = pipeline.text_encoder.to(device)
-    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
-    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # Ensure correct tokenizer is used
+    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)  # Define the optimizer
 
+    # Define timestep range for training
     timesteps = torch.linspace(0, 1, steps=5).to(device)
 
+    # Fine-tuning loop
     for epoch in range(num_epochs):
         for i, (images, prompts) in enumerate(dataloader):
-            images = images.to(device)
+            images = images.to(device)  # Move images to GPU if available
+
+            # Tokenize the prompts
             inputs = tokenizer(list(prompts), padding=True, return_tensors="pt", truncation=True).to(device)
 
             latents = vae.encode(images).latent_dist.sample() * 0.18215
@@ -62,6 +63,7 @@ def fine_tune_model(images, prompts, num_epochs=3, model_name="fine_tuned_model"
             noise = torch.randn_like(latents).to(device)
             noisy_latents = latents + noise
 
+            # Pass text embeddings and timestep to UNet
             timestep = torch.randint(0, len(timesteps), (latents.size(0),), device=device).float()
             pred_noise = unet(noisy_latents, timestep=timestep, encoder_hidden_states=text_embeddings).sample
 
@@ -73,74 +75,27 @@ def fine_tune_model(images, prompts, num_epochs=3, model_name="fine_tuned_model"
             if i % 10 == 0:
                 st.write(f"Epoch {epoch+1}/{num_epochs}, Step {i+1}/{len(dataloader)}, Loss: {loss.item()}")
 
-    [seven deleted lines; content not preserved in the page rendering]
-    else:
-        # Save the fine-tuned model components locally
-        model_path = os.path.join(save_dir, model_name)
-        os.makedirs(model_path, exist_ok=True)
-        vae.save_pretrained(os.path.join(model_path, "vae"))
-        unet.save_pretrained(os.path.join(model_path, "unet"))
-        text_encoder.save_pretrained(os.path.join(model_path, "text_encoder"))
-        tokenizer.save_pretrained(os.path.join(model_path, "tokenizer"))
-        st.success(f"Fine-tuning completed and model saved as {model_name} locally!")
-
-# Function to load fine-tuned model
-def load_fine_tuned_model(model_name, from_hub=False, save_dir="fine_tuned_models"):
-    if from_hub:
-        pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
-        pipeline.vae = pipeline.vae.from_pretrained(model_name + "-vae").to(device)
-        pipeline.unet = pipeline.unet.from_pretrained(model_name + "-unet").to(device)
-        pipeline.text_encoder = pipeline.text_encoder.from_pretrained(model_name + "-text-encoder").to(device)
-        tokenizer = CLIPTokenizer.from_pretrained(model_name + "-tokenizer")
-    else:
-        model_path = os.path.join(save_dir, model_name)
-        if not os.path.exists(model_path):
-            raise OSError(f"Model directory {model_path} does not exist.")
-
-        pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
-        pipeline.vae = pipeline.vae.from_pretrained(os.path.join(model_path, "vae")).to(device)
-        pipeline.unet = pipeline.unet.from_pretrained(os.path.join(model_path, "unet")).to(device)
-        pipeline.text_encoder = pipeline.text_encoder.from_pretrained(os.path.join(model_path, "text_encoder")).to(device)
-        tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
-    return pipeline, tokenizer
-
-# Function to list available fine-tuned models
-def list_available_models(save_dir="fine_tuned_models"):
-    models = []
-    if os.path.exists(save_dir):
-        models = [name for name in os.listdir(save_dir) if os.path.isdir(os.path.join(save_dir, name))]
-    return models
+    st.success("Fine-tuning completed!")
+
+# Function to convert tensor to PIL Image
+def tensor_to_pil(tensor):
+    tensor = tensor.squeeze().cpu().clamp(0, 1)  # Remove batch dimension if necessary
+    tensor = transforms.ToPILImage()(tensor)
+    return tensor
 
 # Function to generate images
 def generate_images(pipeline, prompt):
     with torch.no_grad():
         # Generate image from the prompt
         output = pipeline(prompt)
+
+        # Convert the output to PIL Image
         image = output.images[0]  # Get the first generated image
         return image
 
 # Streamlit app layout
 st.title("Fine-Tune Stable Diffusion with Your Images")
 
-# Hugging Face login
-hf_token = st.text_input("Enter your Hugging Face token", type="password")
-if hf_token:
-    login(token=hf_token)
-    st.success("Logged in to Hugging Face!")
-
-# List available fine-tuned models
-available_models = list_available_models()
-model_choice = st.selectbox(
-    "Select a model to use",
-    options=["Pre-trained Stable Diffusion"] + available_models
-)
-
 # Upload images
 uploaded_files = st.file_uploader("Upload your images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
 
@@ -154,20 +109,9 @@ if uploaded_files:
         prompt = st.text_input(f"Enter a prompt for {file.name}")
         prompts.append(prompt)
 
-    #
+    # Start fine-tuning
     if st.button("Start Fine-Tuning") and uploaded_files and prompts:
-
-        model_name = st.text_input("Enter a name for the fine-tuned model")
-        push_to_hub = st.checkbox("Push to Hugging Face Hub", value=True)
-        if model_name:
-            st.write("Fine-tuning pre-trained model...")
-            fine_tune_model(images, prompts, model_name=model_name, push_to_hub=push_to_hub)
-        else:
-            st.error("Please enter a name for the fine-tuned model.")
-    else:
-        st.write(f"Loading fine-tuned model: {model_choice}")
-        pipeline, tokenizer = load_fine_tuned_model(model_choice)
-        st.write("Model loaded. You can now generate images using this fine-tuned model.")
+        fine_tune_model(images, prompts)
 
 # Generate new images
 st.subheader("Generate New Images")
@@ -175,14 +119,12 @@ new_prompt = st.text_input("Enter a prompt to generate a new image")
 if st.button("Generate Image"):
     if new_prompt:
         with st.spinner("Generating image..."):
-            [nine deleted lines; content not preserved in the page rendering]
-            image_io.seek(0)
-            st.download_button(label="Download Image", data=image_io, file_name="generated_image.png", mime="image/png")
+            # Use the fine-tuned pipeline for generation
+            pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)  # Load the fine-tuned model
+            image = generate_images(pipeline, new_prompt)
+            st.image(image, caption="Generated Image")  # Display the generated image
+
+            # Save the generated image for download
+            image.save("generated_image.png")
+            st.download_button(label="Download Image", data=open("generated_image.png", "rb"), file_name="generated_image.png")
+
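Note: the nine deleted lines ahead of "image_io.seek(0)" were not preserved in the rendered diff, but the removed "from io import BytesIO" import and the surviving "image_io" remnant indicate the old download flow streamed the PNG from an in-memory buffer rather than from disk. A minimal sketch of that pattern, with "offer_png_download" as a hypothetical helper name not taken from the original code:

from io import BytesIO

import streamlit as st
from PIL import Image


def offer_png_download(image: Image.Image) -> None:
    # Hypothetical helper: serialize the PIL image into an in-memory PNG buffer
    image_io = BytesIO()
    image.save(image_io, format="PNG")
    image_io.seek(0)  # Rewind so the download reads from the start of the buffer

    # Streamlit accepts a file-like object as the download payload
    st.download_button(
        label="Download Image",
        data=image_io,
        file_name="generated_image.png",
        mime="image/png",
    )

Compared with the new open("generated_image.png", "rb") approach, the in-memory buffer avoids writing a temporary file into the Space's working directory.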