jamino30 commited on
Commit
b0b9200
·
verified ·
1 Parent(s): 873d9c6

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +67 -0
  2. requirements.txt +2 -0
  3. vae_supertux.pth +3 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import gradio as gr
6
+
7
+ # Define the VAE model
8
+ class ConvVAE(nn.Module):
9
+ def __init__(self, input_channels=3, latent_dim=16):
10
+ super(ConvVAE, self).__init__()
11
+ self.latent_dim = latent_dim
12
+ self.enc_conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
13
+ self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
14
+ self.enc_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
15
+ self.fc_mu = nn.Linear(5120, latent_dim)
16
+ self.fc_logvar = nn.Linear(5120, latent_dim)
17
+ self.fc_decode = nn.Linear(latent_dim, 5120)
18
+ self.dec_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1)
19
+ self.dec_conv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
20
+ self.dec_conv3 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=(0,1))
21
+
22
+ def reparameterize(self, mu, logvar):
23
+ std = torch.exp(0.5 * logvar)
24
+ eps = torch.randn_like(std)
25
+ return mu + eps * std
26
+
27
+ def forward(self, x):
28
+ x = F.relu(self.enc_conv1(x))
29
+ x = F.relu(self.enc_conv2(x))
30
+ x = F.relu(self.enc_conv3(x))
31
+ x = x.view(x.size(0), -1)
32
+ mu = self.fc_mu(x)
33
+ logvar = self.fc_logvar(x)
34
+ z = self.reparameterize(mu, logvar)
35
+ return self.decode(z)
36
+
37
+ def decode(self, z):
38
+ x = F.relu(self.fc_decode(z))
39
+ x = x.view(x.size(0), 128, 4, 10)
40
+ x = F.relu(self.dec_conv1(x))
41
+ x = F.relu(self.dec_conv2(x))
42
+ x = self.dec_conv3(x)
43
+ return F.softmax(x, dim=1)
44
+
45
+ # Load model
46
+ model = ConvVAE()
47
+ model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
48
+ model.eval()
49
+
50
+ def generate_map(seed: int = None):
51
+ if seed:
52
+ torch.manual_seed(seed)
53
+ z = torch.randn(1, model.latent_dim)
54
+ with torch.no_grad():
55
+ output = model.decode(z) # Shape: (1, 3, 15, 40)
56
+ output = output.squeeze(0).argmax(dim=0)
57
+ grid = output.cpu().numpy()
58
+ padded_grid = np.vstack([np.zeros((5, grid.shape[1]), dtype=int), grid]) # Append 5 rows of zeros
59
+ return ["".join(map(str, row)) for row in padded_grid] # Convert each row to a string
60
+
61
+ gr.Interface(
62
+ fn=generate_map,
63
+ inputs=gr.Number(label="Seed"),
64
+ outputs=gr.JSON(label="Generated Map Grid"),
65
+ title="VAE Level Generator",
66
+ description="Returns a 20x40 grid as a list of strings where 0=air, 1=ground, 2=lava"
67
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ gradio
vae_supertux.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e887c368f7ef92170aed65ef3f9eddf719f6f972178dbcdd66b01ad097171f4
3
+ size 1755214