import streamlit as st
import torch
from torch import nn
from diffusers import DDPMScheduler, UNet2DModel
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# Reuse your existing model code
class ClassConditionedUnet(nn.Module):
    """UNet2DModel conditioned on a class label via extra input channels."""

    def __init__(self, num_classes=3, class_emb_size=12):
        super().__init__()
        # Learnable embedding per class (0 = Asian, 1 = Indian, 2 = European)
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        self.model = UNet2DModel(
            sample_size=64,
            in_channels=3 + class_emb_size,  # RGB + class-embedding channels
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape
        # Look up class embeddings and broadcast them across the spatial dimensions
        class_cond = self.class_emb(class_labels)
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        # Concatenate the conditioning as extra input channels
        net_input = torch.cat((x, class_cond), 1)
        return self.model(net_input, t).sample

@st.cache_resource
def load_model(model_path):
    """Load the model with caching to avoid reloading."""
    device = 'cpu'  # For deployment, we'll use CPU
    net = ClassConditionedUnet().to(device)
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    return net, noise_scheduler
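
# Assumed checkpoint format: a dict with a 'model_state_dict' key, e.g. one saved
# during training with torch.save({'model_state_dict': net.state_dict()}, model_path).
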
def generate_mixed_faces(net, noise_scheduler, mix_weights, num_images=1):
    """Generate faces with mixed ethnic features."""
    device = next(net.parameters()).device
    net.eval()
    with torch.no_grad():
        # Start from pure Gaussian noise
        x = torch.randn(num_images, 3, 64, 64).to(device)

        # Get embeddings for all classes (0 = Asian, 1 = Indian, 2 = European)
        emb_asian = net.class_emb(torch.zeros(num_images).long().to(device))
        emb_indian = net.class_emb(torch.ones(num_images).long().to(device))
        emb_european = net.class_emb(torch.full((num_images,), 2, dtype=torch.long).to(device))

        progress_bar = st.progress(0)
        for idx, t in enumerate(noise_scheduler.timesteps):
            # Update progress bar
            progress_bar.progress(idx / len(noise_scheduler.timesteps))

            # Mix embeddings according to weights
            mixed_emb = (
                mix_weights[0] * emb_asian +
                mix_weights[1] * emb_indian +
                mix_weights[2] * emb_european
            )

            # Temporarily override the embedding layer so the UNet sees the mixed embedding
            original_forward = net.class_emb.forward
            net.class_emb.forward = lambda _: mixed_emb
            residual = net(x, t, torch.zeros(num_images).long().to(device))
            x = noise_scheduler.step(residual, t, x).prev_sample
            # Restore the original embedding layer
            net.class_emb.forward = original_forward

        progress_bar.progress(1.0)
        # Map from [-1, 1] to [0, 1] for display
        x = (x.clamp(-1, 1) + 1) / 2
        return x

def main():
    st.title("AI Face Generator with Ethnic Features Mixing")

    # Load model
    try:
        net, noise_scheduler = load_model('final_model/final_diffusion_model.pt')
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    # Create sliders for ethnicity percentages
    st.subheader("Adjust Ethnicity Mix")
    col1, col2, col3 = st.columns(3)
    with col1:
        asian_pct = st.slider("Asian Features %", 0, 100, 33, 1)
    with col2:
        indian_pct = st.slider("Indian Features %", 0, 100, 33, 1)
    with col3:
        european_pct = st.slider("European Features %", 0, 100, 34, 1)

    # Calculate total and normalize if needed
    total = asian_pct + indian_pct + european_pct
    if total == 0:
        st.warning("Total percentage cannot be 0%. Please adjust the sliders.")
        return

    # Normalize weights to sum to 1
    weights = [asian_pct / total, indian_pct / total, european_pct / total]

    # Display current mix
    st.write("Current mix (normalized):")
    st.write(f"Asian: {weights[0]:.2%}, Indian: {weights[1]:.2%}, European: {weights[2]:.2%}")

    # Generate button
    if st.button("Generate Face"):
        try:
            with st.spinner("Generating face..."):
                # Generate the image
                generated_images = generate_mixed_faces(net, noise_scheduler, weights)
                # Convert to numpy and display
                img = generated_images[0].permute(1, 2, 0).cpu().numpy()
                st.image(img, caption="Generated Face", use_column_width=True)
        except Exception as e:
            st.error(f"Error generating image: {str(e)}")

if __name__ == "__main__":
    main()