Spaces:
Build error
Build error
import streamlit as st | |
import torch | |
from torch import nn | |
from diffusers import DDPMScheduler, UNet2DModel | |
import matplotlib.pyplot as plt | |
from tqdm.auto import tqdm | |
# Reuse your existing model code | |
class ClassConditionedUnet(nn.Module):
    """UNet2DModel wrapper whose input is augmented with a class embedding.

    A class label is looked up in an embedding table, the resulting vector is
    tiled over the two trailing (spatial) dimensions of the input, and the
    tiled planes are concatenated to the input as extra channels before the
    UNet is called.
    """

    def __init__(self, num_classes=3, class_emb_size=12):
        """Create the embedding table and the underlying UNet.

        Args:
            num_classes: Number of rows in the class embedding table.
            class_emb_size: Width of each class embedding; the UNet takes
                ``3 + class_emb_size`` input channels and emits 3.
        """
        super().__init__()

        # Built before the UNet so parameter initialisation order is unchanged.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        encoder_blocks = (
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        )
        decoder_blocks = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        )
        self.model = UNet2DModel(
            sample_size=64,
            in_channels=3 + class_emb_size,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            down_block_types=encoder_blocks,
            up_block_types=decoder_blocks,
        )

    def forward(self, x, t, class_labels):
        """Run the UNet on ``x`` concatenated with tiled class embeddings.

        Args:
            x: 4-D input tensor; dimension 1 is the channel axis the
                embedding planes are appended to.
            t: Timestep forwarded unchanged to the UNet.
            class_labels: Indices passed to the embedding table; the lookup
                must yield one embedding per batch element of ``x``.

        Returns:
            The ``.sample`` attribute of the UNet output.
        """
        # Unpacking four values also enforces that x is 4-D.
        batch, _, dim2, dim3 = x.shape

        emb = self.class_emb(class_labels)
        emb_dim = emb.shape[1]

        # (batch, emb_dim) -> (batch, emb_dim, 1, 1) -> tiled over dims 2 and 3.
        cond_planes = emb.view(batch, emb_dim, 1, 1).expand(
            batch, emb_dim, dim2, dim3
        )

        stacked = torch.cat((x, cond_planes), 1)
        return self.model(stacked, t).sample
def load_model(model_path):
    """Build the network and noise scheduler and restore trained weights.

    Nothing in this function memoises its result: every call constructs a
    fresh network and reads the checkpoint again.

    Args:
        model_path: Checkpoint location handed to ``torch.load``. The loaded
            object must be indexable with the key ``'model_state_dict'``.

    Returns:
        A ``(net, noise_scheduler)`` tuple: the ``ClassConditionedUnet`` with
        the checkpoint weights loaded, and a ``DDPMScheduler``.
    """
    # Inference for the deployed app runs on CPU.
    target_device = 'cpu'

    model = ClassConditionedUnet().to(target_device)
    scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule='squaredcos_cap_v2',
    )

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    saved = torch.load(model_path, map_location='cpu')
    model.load_state_dict(saved['model_state_dict'])

    return model, scheduler
def generate_mixed_faces(net, noise_scheduler, mix_weights, num_images=1):
    """Generate faces conditioned on a weighted blend of class embeddings.

    Starting from Gaussian noise, the image is denoised over every timestep of
    ``noise_scheduler``. Instead of a single class embedding, the UNet is
    conditioned on ``w0 * emb(0) + w1 * emb(1) + w2 * emb(2)``.

    The conditioning planes are built once and fed straight to ``net.model``,
    which is the same computation ``ClassConditionedUnet.forward`` performs
    after its embedding lookup. This avoids temporarily replacing
    ``net.class_emb.forward``: that approach left the model patched if a step
    raised, and mutated a model object that may be shared between sessions.

    Args:
        net: Model exposing ``class_emb`` (label embedding) and ``model``
            (the UNet), as ``ClassConditionedUnet`` does.
        noise_scheduler: Scheduler providing ``timesteps`` and ``step``.
        mix_weights: Indexable with at least three weights, applied to the
            embeddings of class ids 0, 1 and 2 (used by this app as Asian,
            Indian and European, in that order).
        num_images: Number of images to generate in one batch.

    Returns:
        The final sample, clamped to [-1, 1] and rescaled to [0, 1]. It starts
        as a ``(num_images, 3, 64, 64)`` tensor of noise.
    """
    device = next(net.parameters()).device
    net.eval()

    height = width = 64

    with torch.no_grad():
        x = torch.randn(num_images, 3, height, width).to(device)

        # One lookup for all three classes; row i is the embedding of class i.
        class_embs = net.class_emb(torch.arange(3, device=device))
        mixed_emb = (
            mix_weights[0] * class_embs[0]
            + mix_weights[1] * class_embs[1]
            + mix_weights[2] * class_embs[2]
        )

        # The blend is identical for every image and every timestep, so the
        # tiled conditioning planes are computed once, outside the loop.
        emb_dim = mixed_emb.shape[0]
        class_cond = mixed_emb.view(1, emb_dim, 1, 1).expand(
            num_images, emb_dim, height, width
        )

        timesteps = noise_scheduler.timesteps
        total_steps = len(timesteps)

        progress_bar = st.progress(0)
        for idx, t in enumerate(timesteps):
            progress_bar.progress(idx / total_steps)

            net_input = torch.cat((x, class_cond), 1)
            residual = net.model(net_input, t).sample
            x = noise_scheduler.step(residual, t, x).prev_sample
        progress_bar.progress(1.0)

        # Map from the model's [-1, 1] range to [0, 1] for display.
        x = (x.clamp(-1, 1) + 1) / 2

    return x
def main():
    """Render the Streamlit UI: load the model, read the mix, generate a face.

    The page shows three percentage sliders, normalises them into weights that
    sum to 1, and on button press runs ``generate_mixed_faces`` and displays
    the first generated image. Loading and generation errors are reported in
    the page rather than raised.
    """
    st.title("AI Face Generator with Ethnic Features Mixing")

    # Streamlit re-executes this script on every widget interaction, so an
    # uncached load would rebuild the network and re-read the checkpoint each
    # time a slider moves. Wrapping at call time is equivalent to decorating,
    # because decorators are re-applied on every rerun as well. Fall back to
    # the plain loader on Streamlit versions without ``cache_resource``.
    # NOTE: a cached model is shared by all sessions, so code using it must
    # not mutate it in ways that concurrent sessions could observe.
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        cached_load_model = cache_resource(load_model)
    else:
        cached_load_model = load_model

    try:
        net, noise_scheduler = cached_load_model('final_model/final_diffusion_model.pt')
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    # One slider per class, laid out side by side.
    st.subheader("Adjust Ethnicity Mix")
    col1, col2, col3 = st.columns(3)
    with col1:
        asian_pct = st.slider("Asian Features %", 0, 100, 33, 1)
    with col2:
        indian_pct = st.slider("Indian Features %", 0, 100, 33, 1)
    with col3:
        european_pct = st.slider("European Features %", 0, 100, 34, 1)

    # The sliders are independent, so their sum is arbitrary; an all-zero
    # setting cannot be normalised.
    total = asian_pct + indian_pct + european_pct
    if total == 0:
        st.warning("Total percentage cannot be 0%. Please adjust the sliders.")
        return

    # Normalise so the weights sum to 1.
    weights = [asian_pct / total, indian_pct / total, european_pct / total]

    st.write("Current mix (normalized):")
    st.write(f"Asian: {weights[0]:.2%}, Indian: {weights[1]:.2%}, European: {weights[2]:.2%}")

    if st.button("Generate Face"):
        try:
            with st.spinner("Generating face..."):
                generated_images = generate_mixed_faces(net, noise_scheduler, weights)
                # Channels-first tensor -> channels-last numpy array for display.
                img = generated_images[0].permute(1, 2, 0).cpu().numpy()
                st.image(img, caption="Generated Face", use_column_width=True)
        except Exception as e:
            st.error(f"Error generating image: {e}")


# Script entry point (e.g. ``streamlit run``).
if __name__ == "__main__":
    main()