SaiShailesh's picture
Upload 2 files
4c1e086 verified
import streamlit as st
import torch
from torch import nn
from diffusers import DDPMScheduler, UNet2DModel
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# Reuse your existing model code
class ClassConditionedUnet(nn.Module):
    """UNet2DModel conditioned on a class label via extra input channels.

    Each sample's class embedding is broadcast over the two spatial
    dimensions and concatenated to the image channels, so the wrapped UNet
    receives ``3 + class_emb_size`` input channels and produces 3 output
    channels.

    Args:
        num_classes: Number of rows in the class-embedding table.
        class_emb_size: Length of each class embedding (number of extra
            input channels).
    """

    def __init__(self, num_classes=3, class_emb_size=12):
        super().__init__()
        # The attribute names below become the state_dict key prefixes, so they
        # must match the keys stored in the checkpoint that load_state_dict reads.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        self.model = UNet2DModel(
            sample_size=64,
            in_channels=3 + class_emb_size,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            # Two plain blocks followed by two attention blocks on the way down,
            # mirrored on the way up.
            down_block_types=("DownBlock2D",) * 2 + ("AttnDownBlock2D",) * 2,
            up_block_types=("AttnUpBlock2D",) * 2 + ("UpBlock2D",) * 2,
        )

    def forward(self, x, t, class_labels):
        """Run the UNet on ``x`` with per-sample class conditioning.

        Args:
            x: 4-D image batch; dim 0 is the batch, dims 2 and 3 are spatial.
            t: Diffusion timestep(s), passed through to the UNet unchanged.
            class_labels: Integer class labels, one per batch element.

        Returns:
            The ``.sample`` tensor produced by the wrapped UNet2DModel.
        """
        batch_size, _, height, width = x.shape
        emb = self.class_emb(class_labels)
        emb_dim = emb.shape[1]
        # Tile each embedding vector across every spatial position.
        cond_maps = emb.view(batch_size, emb_dim, 1, 1).expand(batch_size, emb_dim, height, width)
        return self.model(torch.cat((x, cond_maps), 1), t).sample
@st.cache_resource
def load_model(model_path, device='cpu'):
    """Load the conditioned UNet and its noise scheduler, cached across reruns.

    ``st.cache_resource`` keeps the returned objects alive and shares them
    between reruns, so the checkpoint is read only once per argument set.

    Args:
        model_path: Path to a ``torch.save`` checkpoint holding the network
            weights under the ``'model_state_dict'`` key.
        device: Device to place the network on. Defaults to ``'cpu'`` for
            deployment, matching the previous hard-coded behavior.

    Returns:
        ``(net, noise_scheduler)``. ``net`` is a ``ClassConditionedUnet`` on
        ``device`` in eval mode; ``noise_scheduler`` is a ``DDPMScheduler``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing,
        unreadable, or mismatched checkpoint; ``KeyError`` if the checkpoint
        lacks ``'model_state_dict'``.
    """
    net = ClassConditionedUnet().to(device)
    # NOTE(review): presumably this must match the schedule used at training
    # time — confirm against the training script.
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
    # map_location places tensors on the same device as the network, even if
    # the checkpoint was saved from another device. torch.load unpickles the
    # file, so only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    # The cached model is shared by every caller, so set inference mode once here.
    net.eval()
    return net, noise_scheduler
def generate_mixed_faces(net, noise_scheduler, mix_weights, num_images=1):
    """Generate faces conditioned on a weighted blend of three class embeddings.

    Starts from Gaussian noise and runs every timestep of ``noise_scheduler``.
    Each denoising step is conditioned on one embedding: the weighted sum of
    the class-0 (Asian), class-1 (Indian) and class-2 (European) embeddings.
    A Streamlit progress bar advances once per timestep.

    ``net`` is never mutated, apart from ``net.eval()``, which is idempotent.
    This matters because the model is shared through ``st.cache_resource``.

    Args:
        net: A ``ClassConditionedUnet`` (anything exposing ``class_emb`` and
            ``model`` attributes in the same way).
        noise_scheduler: Scheduler providing ``timesteps`` and ``step``.
        mix_weights: Indexable weights. Only entries 0, 1 and 2 are used, for
            classes 0, 1 and 2 respectively.
        num_images: Number of images to generate in one batch.

    Returns:
        Tensor with the shape of the initial noise, ``(num_images, 3, 64, 64)``,
        with values clamped and rescaled into ``[0, 1]``.
    """
    device = next(net.parameters()).device
    net.eval()
    with torch.no_grad():
        x = torch.randn(num_images, 3, 64, 64, device=device)

        # Look up the three class embeddings in one call and blend them. The
        # blend does not depend on the timestep, so it is built once.
        class_ids = torch.arange(3, device=device)
        emb_asian, emb_indian, emb_european = net.class_emb(class_ids)
        mixed_emb = (
            mix_weights[0] * emb_asian +
            mix_weights[1] * emb_indian +
            mix_weights[2] * emb_european
        )

        # Build the same conditioning input ClassConditionedUnet.forward would:
        # the embedding tiled over the spatial dims for every sample. It is built
        # here rather than by overriding net.class_emb.forward, because that
        # override mutated the shared cached model. It was also left in place if
        # any step raised, and clashed between concurrent sessions.
        height, width = x.shape[-2:]
        class_cond = mixed_emb.view(1, -1, 1, 1).expand(num_images, -1, height, width)

        timesteps = noise_scheduler.timesteps
        num_steps = len(timesteps)
        progress_bar = st.progress(0)
        for idx, t in enumerate(timesteps):
            progress_bar.progress(idx / num_steps)
            net_input = torch.cat((x, class_cond), 1)
            residual = net.model(net_input, t).sample
            x = noise_scheduler.step(residual, t, x).prev_sample
        progress_bar.progress(1.0)

    # Map model output from [-1, 1] to [0, 1] for display.
    x = (x.clamp(-1, 1) + 1) / 2
    return x
def main():
    """Build the Streamlit page: ethnicity-mix sliders plus a generate button.

    The model is loaded through the cached loader; on failure an error is
    shown and the page stops. Slider percentages are normalized to weights
    summing to 1. When the button is pressed, a single face is generated
    and displayed.
    """
    st.title("AI Face Generator with Ethnic Features Mixing")

    try:
        net, noise_scheduler = load_model('final_model/final_diffusion_model.pt')
    except Exception as e:
        st.error(f"Error loading model: {e!s}")
        return

    st.subheader("Adjust Ethnicity Mix")
    # One (label, default) slider per column, in the order the model's class ids use.
    slider_specs = (
        ("Asian Features %", 33),
        ("Indian Features %", 33),
        ("European Features %", 34),
    )
    percentages = []
    for column, (label, default) in zip(st.columns(3), slider_specs):
        with column:
            percentages.append(st.slider(label, 0, 100, default, 1))

    total = sum(percentages)
    if not total:
        st.warning("Total percentage cannot be 0%. Please adjust the sliders.")
        return

    # Normalize so the weights sum to 1 regardless of the raw slider values.
    weights = [pct / total for pct in percentages]
    asian_w, indian_w, european_w = weights
    st.write("Current mix (normalized):")
    st.write(f"Asian: {asian_w:.2%}, Indian: {indian_w:.2%}, European: {european_w:.2%}")

    if not st.button("Generate Face"):
        return

    try:
        with st.spinner("Generating face..."):
            generated_images = generate_mixed_faces(net, noise_scheduler, weights)
        # CHW tensor -> HWC array, the layout st.image expects.
        first_image = generated_images[0]
        img = first_image.permute(1, 2, 0).cpu().numpy()
        st.image(img, caption="Generated Face", use_column_width=True)
    except Exception as e:
        st.error(f"Error generating image: {e!s}")
# Build and render the app when this module is executed as the main script.
if __name__ == "__main__":
    main()