Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from types import SimpleNamespace | |
import torch | |
import torch.nn.functional as F | |
import wandb | |
import streamlit as st | |
from utilities import ContextUnet, setup_ddpm | |
from constants import WANDB_API_KEY | |
def load_model():
    """Load the model from wandb artifacts.

    Downloads the ``teamaditya/model-registry/Feature2Sprite:v0`` model
    artifact, rebuilds a ``ContextUnet`` using the hyperparameters stored in
    the config of the run that logged the artifact, and loads the weights
    from ``context_model.pth``.

    Returns:
        The ``ContextUnet`` in eval mode, moved to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and to ``"cpu"`` otherwise.

    NOTE(review): every call contacts W&B and reloads the weights from disk;
    consider caching the result if this is invoked repeatedly.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Authenticate through the API object rather than a global wandb.login().
    api = wandb.Api(api_key=WANDB_API_KEY)
    artifact = api.artifact(
        "teamaditya/model-registry/Feature2Sprite:v0", type="model"
    )
    model_path = Path(artifact.download())

    # Recover the network hyperparameters from the run that produced the
    # artifact, so the architecture matches the stored weights.
    run_config = artifact.logged_by().config

    # Load onto the CPU first; the assembled model is moved to the target
    # device once at the end.
    model_weights = torch.load(model_path / "context_model.pth",
                               map_location="cpu")

    model = ContextUnet(in_channels=3,
                        n_feat=run_config["n_feat"],
                        n_cfeat=run_config["n_cfeat"],
                        height=run_config["height"])
    model.load_state_dict(model_weights)

    # Put the model in eval mode (inference only).
    model.eval()
    return model.to(device)
def show_image(img):
    """Render a single image tensor in the Streamlit page.

    The tensor's axes are reordered with ``permute(1, 2, 0)`` (presumably
    channels-first to channels-last; confirm against the caller), values are
    clipped to ``[-1, 1]`` and then rescaled to ``[0, 1]`` before being handed
    to ``st.image``.

    Args:
        img: A 3-D torch tensor; may live on any device and may require grad.
    """
    reordered = img.permute(1, 2, 0)
    bounded = reordered.clip(-1, 1)
    pixels = bounded.detach().cpu().numpy()
    # Map [-1, 1] onto [0, 1].
    scaled = (pixels + 1) / 2
    st.image(scaled, clamp=True)
def generate_sprites(feature_vector):
    """Sample sprites conditioned on ``feature_vector`` and display the first.

    Loads the model via ``load_model()``, builds the context sampler with
    ``setup_ddpm``, draws a batch of Gaussian noise images, runs the sampler,
    upscales the result to 256x256 and shows the first sample with
    ``show_image``.

    Args:
        feature_vector: Conditioning features; wrapped in a list and turned
            into a float tensor with a leading dimension of 1. Presumably its
            length matches the model's ``n_cfeat``; confirm against caller.

    Returns:
        None. The image is rendered as a side effect.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Sampling hyperparameters.
    num_samples = 30
    # DDPM sampler hyperparameters.
    timesteps = 500
    beta1 = 1e-4
    beta2 = 0.02
    # Side length of the generated noise images.
    height = 16

    model = load_model()
    # Only the second object returned by setup_ddpm is used here.
    _, context_sampler = setup_ddpm(beta1, beta2, timesteps, device)

    noise = torch.randn(num_samples, 3, height, height)
    noise = noise.to(device)

    context = torch.tensor([feature_vector])
    context = context.to(device).float()

    samples, _ = context_sampler(model, noise, context)

    # Upscale the 16x16 images to 256x256.
    upscaled = F.interpolate(samples, size=(256, 256), mode="bilinear")

    # Only the first sample of the batch is displayed.
    show_image(upscaled[0])