Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- app.py +67 -0
- requirements.txt +2 -0
- vae_supertux.pth +3 -0
app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
# Define the VAE model
|
8 |
+
class ConvVAE(nn.Module):
|
9 |
+
def __init__(self, input_channels=3, latent_dim=16):
|
10 |
+
super(ConvVAE, self).__init__()
|
11 |
+
self.latent_dim = latent_dim
|
12 |
+
self.enc_conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
|
13 |
+
self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
|
14 |
+
self.enc_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
15 |
+
self.fc_mu = nn.Linear(5120, latent_dim)
|
16 |
+
self.fc_logvar = nn.Linear(5120, latent_dim)
|
17 |
+
self.fc_decode = nn.Linear(latent_dim, 5120)
|
18 |
+
self.dec_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1)
|
19 |
+
self.dec_conv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
|
20 |
+
self.dec_conv3 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=(0,1))
|
21 |
+
|
22 |
+
def reparameterize(self, mu, logvar):
|
23 |
+
std = torch.exp(0.5 * logvar)
|
24 |
+
eps = torch.randn_like(std)
|
25 |
+
return mu + eps * std
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
x = F.relu(self.enc_conv1(x))
|
29 |
+
x = F.relu(self.enc_conv2(x))
|
30 |
+
x = F.relu(self.enc_conv3(x))
|
31 |
+
x = x.view(x.size(0), -1)
|
32 |
+
mu = self.fc_mu(x)
|
33 |
+
logvar = self.fc_logvar(x)
|
34 |
+
z = self.reparameterize(mu, logvar)
|
35 |
+
return self.decode(z)
|
36 |
+
|
37 |
+
def decode(self, z):
|
38 |
+
x = F.relu(self.fc_decode(z))
|
39 |
+
x = x.view(x.size(0), 128, 4, 10)
|
40 |
+
x = F.relu(self.dec_conv1(x))
|
41 |
+
x = F.relu(self.dec_conv2(x))
|
42 |
+
x = self.dec_conv3(x)
|
43 |
+
return F.softmax(x, dim=1)
|
44 |
+
|
45 |
+
# Load model
|
46 |
+
model = ConvVAE()
|
47 |
+
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
|
48 |
+
model.eval()
|
49 |
+
|
50 |
+
def generate_map(seed: int = None):
|
51 |
+
if seed:
|
52 |
+
torch.manual_seed(seed)
|
53 |
+
z = torch.randn(1, model.latent_dim)
|
54 |
+
with torch.no_grad():
|
55 |
+
output = model.decode(z) # Shape: (1, 3, 15, 40)
|
56 |
+
output = output.squeeze(0).argmax(dim=0)
|
57 |
+
grid = output.cpu().numpy()
|
58 |
+
padded_grid = np.vstack([np.zeros((5, grid.shape[1]), dtype=int), grid]) # Append 5 rows of zeros
|
59 |
+
return ["".join(map(str, row)) for row in padded_grid] # Convert each row to a string
|
60 |
+
|
61 |
+
gr.Interface(
|
62 |
+
fn=generate_map,
|
63 |
+
inputs=gr.Number(label="Seed"),
|
64 |
+
outputs=gr.JSON(label="Generated Map Grid"),
|
65 |
+
title="VAE Level Generator",
|
66 |
+
description="Returns a 20x40 grid as a list of strings where 0=air, 1=ground, 2=lava"
|
67 |
+
).launch()
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
gradio
|
vae_supertux.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3e887c368f7ef92170aed65ef3f9eddf719f6f972178dbcdd66b01ad097171f4
|
3 |
+
size 1755214
|