ldkong commited on
Commit
9432cb2
·
1 Parent(s): 1b79989

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -18
app.py CHANGED
@@ -130,13 +130,15 @@ class TransferVAE_Video(nn.Module):
130
  backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
131
  frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
132
  lstm_out_f = torch.cat((frontal, backward), dim=1)
 
133
  f_mean = self.f_mean(lstm_out_f)
134
- f_logvar = self.f_logvar(lstm_out_f)
135
- f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
136
  features, _ = self.z_rnn(lstm_out)
 
137
  z_mean = self.z_mean(features)
138
- z_logvar = self.z_logvar(features)
139
- z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)
140
  return f_post, z_post
141
 
142
 
@@ -150,16 +152,6 @@ class TransferVAE_Video(nn.Module):
150
  x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
151
  x_embed = self.encoder(x)[0]
152
  return x_embed.view(x_shape[0], x_shape[1], -1)
153
-
154
-
155
- def reparameterize(self, mean, logvar, random_sampling=True):
156
- if random_sampling is True:
157
- eps = torch.randn_like(logvar)
158
- std = torch.exp(0.5 * logvar)
159
- z = mean + eps * std
160
- return z
161
- else:
162
- return mean
163
 
164
 
165
  def forward(self, x, beta):
@@ -171,9 +163,7 @@ class TransferVAE_Video(nn.Module):
171
  zf = torch.cat((z_post, f_expand), dim=2)
172
  recon_x = self.decoder_frame(zf)
173
  return f_post, z_post, recon_x
174
-
175
-
176
-
177
 
178
  def name2seq(file_name):
179
  images = []
@@ -235,7 +225,7 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
235
  axs[1, 3].imshow(src_Zf_tar_Zt)
236
  axs[1, 3].axis('off')
237
 
238
- plt.subplots_adjust(hspace=0.01, wspace=0.0)
239
 
240
  save_name = 'MyPlot_{}.png'.format(frame_id)
241
 
 
130
  backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
131
  frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
132
  lstm_out_f = torch.cat((frontal, backward), dim=1)
133
+
134
  f_mean = self.f_mean(lstm_out_f)
135
+ f_post = f_mean
136
+
137
  features, _ = self.z_rnn(lstm_out)
138
+
139
  z_mean = self.z_mean(features)
140
+ z_post = z_mean
141
+
142
  return f_post, z_post
143
 
144
 
 
152
  x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
153
  x_embed = self.encoder(x)[0]
154
  return x_embed.view(x_shape[0], x_shape[1], -1)
 
 
 
 
 
 
 
 
 
 
155
 
156
 
157
  def forward(self, x, beta):
 
163
  zf = torch.cat((z_post, f_expand), dim=2)
164
  recon_x = self.decoder_frame(zf)
165
  return f_post, z_post, recon_x
166
+
 
 
167
 
168
  def name2seq(file_name):
169
  images = []
 
225
  axs[1, 3].imshow(src_Zf_tar_Zt)
226
  axs[1, 3].axis('off')
227
 
228
+ plt.subplots_adjust(hspace=0.0125, wspace=0.0)
229
 
230
  save_name = 'MyPlot_{}.png'.format(frame_id)
231