File size: 1,237 Bytes
810fa8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from einops import rearrange
from torch import nn
from diffusers.models import AutoencoderKL


class HFVAEWrapper(nn.Module):
    """Wrap a diffusers ``AutoencoderKL`` so it accepts 4D or 5D batches.

    4D inputs are handed to the VAE unchanged.  5D inputs are read as
    ``b c t h w``: the ``t`` axis is merged into the batch axis for the VAE
    call and split out again afterwards.

    Note that the two directions do not return the same 5D layout:
    ``encode`` yields ``b c t h w`` while ``decode`` yields ``b t c h w``.

    Args:
        hfvae: Forwarded as the first argument of
            ``AutoencoderKL.from_pretrained``.
        scaling_factor: Multiplier applied to sampled latents in ``encode``
            and divided out again in ``decode``.
        cache_dir: Forwarded to ``from_pretrained`` as ``cache_dir``.
    """

    # Class-level default so the scale is defined even on instances that were
    # created without running __init__.
    # NOTE(review): 0.18215 is presumably the latent scale of the checkpoint
    # this wrapper was written for -- confirm against the VAE's own config.
    scaling_factor = 0.18215

    def __init__(self, hfvae='mse', scaling_factor=0.18215,
                 cache_dir='cache_dir'):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.vae = AutoencoderKL.from_pretrained(hfvae, cache_dir=cache_dir)

    @staticmethod
    def _fold_frames(x):
        """Merge the ``t`` axis of a 5D ``b c t h w`` tensor into the batch.

        Returns:
            ``(x, t)`` where ``t`` is the size of the folded axis, or ``0``
            when ``x`` is not 5D (in which case ``x`` is returned untouched).
        """
        if x.ndim != 5:
            return x, 0
        t = x.shape[2]
        return rearrange(x, 'b c t h w -> (b t) c h w').contiguous(), t

    @staticmethod
    def _unfold_frames(x, t, pattern):
        """Split the batch axis back into ``b`` and ``t`` using ``pattern``.

        A ``t`` of ``0`` means nothing was folded, so ``x`` is returned as is.
        """
        if t == 0:
            return x
        return rearrange(x, pattern, t=t).contiguous()

    def encode(self, x):  # b c h w  or  b c t h w
        """Encode ``x`` with the VAE and return scaled sampled latents.

        The latent is drawn with ``latent_dist.sample()`` and multiplied in
        place by ``scaling_factor``.

        Returns:
            A 4D tensor for 4D input, or a ``b c t h w`` tensor for 5D input.
        """
        x, t = self._fold_frames(x)
        x = self.vae.encode(x).latent_dist.sample().mul_(self.scaling_factor)
        return self._unfold_frames(x, t, '(b t) c h w -> b c t h w')

    def decode(self, x):
        """Decode latents ``x`` with the VAE after undoing the latent scale.

        Returns:
            A 4D tensor for 4D input, or a ``b t c h w`` tensor for 5D input
            (``t`` precedes the channel axis, unlike ``encode``'s output).
        """
        x, t = self._fold_frames(x)
        x = self.vae.decode(x / self.scaling_factor).sample
        return self._unfold_frames(x, t, '(b t) c h w -> b t c h w')

class SDVAEWrapper(nn.Module):
    """Placeholder VAE wrapper; every entry point raises NotImplementedError."""

    def __init__(self):
        """Initialise the ``nn.Module`` base, then refuse construction."""
        super().__init__()
        raise NotImplementedError()

    def encode(self, x):  # b c h w
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def decode(self, x):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()