Update app.py
Browse files
app.py
CHANGED
@@ -130,13 +130,15 @@ class TransferVAE_Video(nn.Module):
|
|
130 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
131 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
132 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
|
|
133 |
f_mean = self.f_mean(lstm_out_f)
|
134 |
-
|
135 |
-
|
136 |
features, _ = self.z_rnn(lstm_out)
|
|
|
137 |
z_mean = self.z_mean(features)
|
138 |
-
|
139 |
-
|
140 |
return f_post, z_post
|
141 |
|
142 |
|
@@ -150,16 +152,6 @@ class TransferVAE_Video(nn.Module):
|
|
150 |
x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
|
151 |
x_embed = self.encoder(x)[0]
|
152 |
return x_embed.view(x_shape[0], x_shape[1], -1)
|
153 |
-
|
154 |
-
|
155 |
-
def reparameterize(self, mean, logvar, random_sampling=True):
|
156 |
-
if random_sampling is True:
|
157 |
-
eps = torch.randn_like(logvar)
|
158 |
-
std = torch.exp(0.5 * logvar)
|
159 |
-
z = mean + eps * std
|
160 |
-
return z
|
161 |
-
else:
|
162 |
-
return mean
|
163 |
|
164 |
|
165 |
def forward(self, x, beta):
|
@@ -171,9 +163,7 @@ class TransferVAE_Video(nn.Module):
|
|
171 |
zf = torch.cat((z_post, f_expand), dim=2)
|
172 |
recon_x = self.decoder_frame(zf)
|
173 |
return f_post, z_post, recon_x
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
|
178 |
def name2seq(file_name):
|
179 |
images = []
|
@@ -235,7 +225,7 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
|
|
235 |
axs[1, 3].imshow(src_Zf_tar_Zt)
|
236 |
axs[1, 3].axis('off')
|
237 |
|
238 |
-
plt.subplots_adjust(hspace=0.
|
239 |
|
240 |
save_name = 'MyPlot_{}.png'.format(frame_id)
|
241 |
|
|
|
130 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
131 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
132 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
133 |
+
|
134 |
f_mean = self.f_mean(lstm_out_f)
|
135 |
+
f_post = f_mean
|
136 |
+
|
137 |
features, _ = self.z_rnn(lstm_out)
|
138 |
+
|
139 |
z_mean = self.z_mean(features)
|
140 |
+
z_post = z_mean
|
141 |
+
|
142 |
return f_post, z_post
|
143 |
|
144 |
|
|
|
152 |
x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
|
153 |
x_embed = self.encoder(x)[0]
|
154 |
return x_embed.view(x_shape[0], x_shape[1], -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
|
157 |
def forward(self, x, beta):
|
|
|
163 |
zf = torch.cat((z_post, f_expand), dim=2)
|
164 |
recon_x = self.decoder_frame(zf)
|
165 |
return f_post, z_post, recon_x
|
166 |
+
|
|
|
|
|
167 |
|
168 |
def name2seq(file_name):
|
169 |
images = []
|
|
|
225 |
axs[1, 3].imshow(src_Zf_tar_Zt)
|
226 |
axs[1, 3].axis('off')
|
227 |
|
228 |
+
plt.subplots_adjust(hspace=0.0125, wspace=0.0)
|
229 |
|
230 |
save_name = 'MyPlot_{}.png'.format(frame_id)
|
231 |
|