Diffusers
Safetensors
File size: 2,916 Bytes
ceb9782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import math
from typing import Optional, Tuple, Union

import torch
from diffusers import UNet2DModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class UNet2DModelForNCSN(UNet2DModel, ModelMixin, ConfigMixin):  # type: ignore[misc]
    """``UNet2DModel`` extended with the noise-level schedule used by NCSN.

    All architecture arguments are forwarded unchanged to ``UNet2DModel``. In
    addition, the model owns a ``sigmas`` buffer holding ``num_train_timesteps``
    noise levels spaced evenly in log-space (i.e. a geometric sequence), starting
    at ``sigma_max`` and ending (up to float rounding) at ``sigma_min``.

    Args:
        sigma_min: Smallest noise level, the last entry of ``sigmas``. Must be > 0.
        sigma_max: Largest noise level, the first entry of ``sigmas``. Must be > 0.
        num_train_timesteps: Number of noise levels, i.e. the length of ``sigmas``.
            It is also forwarded to ``UNet2DModel``.
        sample_size ... num_class_embeds: Forwarded verbatim to
            ``UNet2DModel.__init__``; see its documentation.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is not strictly positive.
    """

    @register_to_config
    def __init__(
        self,
        sigma_min: float,
        sigma_max: float,
        num_train_timesteps: int,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = (
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str, ...] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ) -> None:
        # Validate before building the (large) network; math.log below would
        # otherwise fail late with an opaque "math domain error".
        if sigma_min <= 0 or sigma_max <= 0:
            raise ValueError(
                "sigma_min and sigma_max must be positive, got "
                f"sigma_min={sigma_min}, sigma_max={sigma_max}"
            )

        # Forward by keyword, not by position: UNet2DModel.__init__ may gain or
        # reorder parameters between diffusers releases, and positional
        # forwarding would then silently bind values to the wrong parameters.
        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            time_embedding_type=time_embedding_type,
            time_embedding_dim=time_embedding_dim,
            freq_shift=freq_shift,
            flip_sin_to_cos=flip_sin_to_cos,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            mid_block_scale_factor=mid_block_scale_factor,
            downsample_padding=downsample_padding,
            downsample_type=downsample_type,
            upsample_type=upsample_type,
            dropout=dropout,
            act_fn=act_fn,
            attention_head_dim=attention_head_dim,
            norm_num_groups=norm_num_groups,
            attn_norm_num_groups=attn_norm_num_groups,
            norm_eps=norm_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            add_attention=add_attention,
            class_embed_type=class_embed_type,
            num_class_embeds=num_class_embeds,
            num_train_timesteps=num_train_timesteps,
        )

        # Geometric noise schedule: evenly spaced in log-space, largest sigma first.
        sigmas = torch.exp(
            torch.linspace(
                start=math.log(sigma_max),
                end=math.log(sigma_min),
                steps=num_train_timesteps,
            )
        )
        # A buffer rather than a Parameter: it is not optimized, but it follows the
        # module across ``.to()`` calls and is stored in the state dict (buffers are
        # persistent by default).
        self.register_buffer("sigmas", sigmas)  # type: ignore