LanguageBind commited on
Commit
810fa8c
1 Parent(s): 05dac61

Create vae/vae.py

Browse files
opensora/models/ae/imagebase/vae/vae.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ from torch import nn
3
+ from diffusers.models import AutoencoderKL
4
+
5
+
6
+ class HFVAEWrapper(nn.Module):
7
+ def __init__(self, hfvae='mse'):
8
+ super(HFVAEWrapper, self).__init__()
9
+ self.vae = AutoencoderKL.from_pretrained(hfvae, cache_dir='cache_dir')
10
+ def encode(self, x): # b c h w
11
+ t = 0
12
+ if x.ndim == 5:
13
+ b, c, t, h, w = x.shape
14
+ x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous()
15
+ x = self.vae.encode(x).latent_dist.sample().mul_(0.18215)
16
+ if t != 0:
17
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t).contiguous()
18
+ return x
19
+ def decode(self, x):
20
+ t = 0
21
+ if x.ndim == 5:
22
+ b, c, t, h, w = x.shape
23
+ x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous()
24
+ x = self.vae.decode(x / 0.18215).sample
25
+ if t != 0:
26
+ x = rearrange(x, '(b t) c h w -> b t c h w', t=t).contiguous()
27
+ return x
28
+
29
+ class SDVAEWrapper(nn.Module):
30
+ def __init__(self):
31
+ super(SDVAEWrapper, self).__init__()
32
+ raise NotImplementedError
33
+
34
+ def encode(self, x): # b c h w
35
+ raise NotImplementedError
36
+
37
+ def decode(self, x):
38
+ raise NotImplementedError