import subprocess
from pathlib import Path

import einops
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision.utils import save_image


class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to small images.

    With the default layer stack a ``(N, nz, 1, 1)`` latent batch is upsampled
    1 -> 3 -> 5 -> 12 -> 24 pixels, giving an ``(N, nc, 24, 24)`` output in
    ``(-1, 1)`` (final ``Tanh``).

    Args:
        nc: Number of output channels (4 by default; frames are saved as RGBA).
        nz: Length of each latent vector.
        ngf: Base channel multiplier of the generator.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Return the generated image batch for the latent batch ``input``."""
        return self.network(input)


model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
# NOTE(review): torch.load unpickles the downloaded checkpoint; only load
# weights from a repository you trust.
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
# NOTE(review): the model is never switched to eval(), so its BatchNorm layers
# normalize with per-batch statistics at inference time. Calling model.eval()
# would change the generated images -- confirm which mode the checkpoint expects.


@torch.no_grad()
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8, out_path='out.gif'):
    """Render a looping GIF that morphs between two random grids of punks.

    Two latent batches of ``rows * cols`` vectors are drawn and linearly
    interpolated over ``frames`` steps; the sequence is then appended in
    reverse so the animation returns to its start. Each step is written as a
    PNG grid into ``save_dir`` and the frames are assembled with ImageMagick's
    ``convert``.

    Args:
        save_dir: Directory for the intermediate PNG frames (created if needed).
        frames: Number of interpolation steps in one direction.
        rows: Number of rows in each frame's grid.
        cols: Number of columns in each frame's grid.
        out_path: Where to write the resulting GIF.

    Returns:
        ``out_path``.

    Raises:
        subprocess.CalledProcessError: If ``convert`` exits with an error.
        FileNotFoundError: If the ``convert`` executable is not installed.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    z1 = torch.randn(rows * cols, 100, 1, 1)
    z2 = torch.randn(rows * cols, 100, 1, 1)
    # alpha runs 0, 1/frames, ..., (frames-1)/frames: z2 itself is never
    # rendered, and the turning-point frames appear twice once the reversed
    # sequence is appended.
    zs = [(1 - i / frames) * z1 + (i / frames) * z2 for i in range(frames)]
    zs += zs[::-1]  # also go in reverse order to complete loop

    frame_paths = []
    for i, z in enumerate(zs):
        imgs = model(z)
        # map the Tanh range [-1, 1] to [0, 255] and go channels-last for PIL
        imgs = (imgs + 1) / 2
        imgs = (imgs.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        # tile the batch into a rows x cols grid
        grid = einops.rearrange(imgs, "(b1 b2) h w c -> (b1 h) (b2 w) c", b1=rows, b2=cols)
        frame_path = save_dir / f"{i:03}.png"
        Image.fromarray(grid).save(frame_path)
        frame_paths.append(str(frame_path))

    # Pass exactly the frames written by this call, in order, as separate argv
    # entries: paths are never whitespace-split, stale frames left in save_dir
    # by an earlier call are not swept in, and ordering does not depend on the
    # lexical sort of file names (which breaks past 999 frames).
    subprocess.run(
        [
            "convert",
            "-dispose", "previous",
            "-delay", "10",
            "-loop", "0",
            *frame_paths,
            str(out_path),
        ],
        check=True,
    )
    return out_path


def predict(choice, seed):
    """Gradio callback: generate either a static grid or an interpolation GIF.

    Args:
        choice: ``'interpolation'`` produces the looping GIF; any other value
            (the UI offers ``'image'``) produces a grid of 64 punks.
        seed: Seed for torch's RNG, which draws the latent vectors.

    Returns:
        Path of the generated file (``'out.gif'`` or ``'punks.png'``).
    """
    # NOTE(review): outputs are written to fixed paths, so concurrent requests
    # can overwrite each other's files.
    torch.manual_seed(int(seed))
    if choice == 'interpolation':
        return interpolate()

    with torch.no_grad():  # inference only; no autograd graph needed
        z = torch.randn(64, 100, 1, 1)
        punks = model(z)
    # NOTE(review): normalize=True rescales by this batch's min/max, whereas
    # interpolate() maps the Tanh range [-1, 1] directly, so contrast can differ
    # slightly between the two outputs.
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'


# NOTE(review): gr.inputs.* is the legacy Gradio input API (removed in later
# Gradio releases); keep it in sync with the pinned gradio version.
gr.Interface(
    predict,
    inputs=[
        gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
        gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
    ],
    outputs="image",
    title="Cryptopunks GAN",
    description=(
        "These CryptoPunks do not exist. You have the choice of either generating "
        "random punks, or a gif showing the interpolation between two random punk "
        "grids."
    ),
    # NOTE(review): this text looks like it originally carried HTML links to the
    # DCGAN paper and the GitHub repo; only the link text survived -- restore URLs.
    article=(
        "Unsupervised Representation Learning with Deep Convolutional "
        "Generative Adversarial Networks | Github Repo"
    ),
    examples=[["interpolation", 123], ["interpolation", 42], ["image", 456], ["image", 42]],
).launch(cache_examples=True)