Diffusers
Safetensors
shunk031 committed on
Commit
ceb9782
·
verified ·
1 Parent(s): 75182cb

Upload unet/unet_2d_ncsn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. unet/unet_2d_ncsn.py +90 -0
unet/unet_2d_ncsn.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ from diffusers import UNet2DModel
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+
9
+
10
class UNet2DModelForNCSN(UNet2DModel, ModelMixin, ConfigMixin):  # type: ignore[misc]
    """`UNet2DModel` variant that also stores a noise-level schedule.

    In addition to building the parent U-Net, the constructor computes
    ``num_train_timesteps`` noise levels that are evenly spaced in log-space
    from ``sigma_max`` to ``sigma_min`` (i.e. a geometric sequence) and
    registers them as the buffer ``sigmas``.

    Args:
        sigma_min: Last value of the noise-level sequence. Must be positive,
            since its logarithm is taken.
        sigma_max: First value of the noise-level sequence. Must be positive,
            since its logarithm is taken.
        num_train_timesteps: Number of entries in ``sigmas``; also forwarded
            to ``UNet2DModel.__init__`` under the same name.

    All remaining arguments are forwarded unchanged to
    ``UNet2DModel.__init__`` under the same names.
    """

    @register_to_config
    def __init__(
        self,
        sigma_min: float,
        sigma_max: float,
        num_train_timesteps: int,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = (
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str, ...] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ) -> None:
        # Forward by keyword rather than by position: a purely positional call
        # silently binds values to the wrong parameters if the parent
        # signature gains or reorders a parameter.
        # NOTE(review): newer diffusers releases appear to add `mid_block_type`
        # between `down_block_types` and `up_block_types` -- confirm against
        # the installed version.
        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            time_embedding_type=time_embedding_type,
            time_embedding_dim=time_embedding_dim,
            freq_shift=freq_shift,
            flip_sin_to_cos=flip_sin_to_cos,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            mid_block_scale_factor=mid_block_scale_factor,
            downsample_padding=downsample_padding,
            downsample_type=downsample_type,
            upsample_type=upsample_type,
            dropout=dropout,
            act_fn=act_fn,
            attention_head_dim=attention_head_dim,
            norm_num_groups=norm_num_groups,
            attn_norm_num_groups=attn_norm_num_groups,
            norm_eps=norm_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            add_attention=add_attention,
            class_embed_type=class_embed_type,
            num_class_embeds=num_class_embeds,
            num_train_timesteps=num_train_timesteps,
        )

        # Interpolate linearly between log(sigma_max) and log(sigma_min), then
        # exponentiate: the result is a geometric sequence that starts at
        # sigma_max and ends at sigma_min.
        sigmas = torch.exp(
            torch.linspace(
                start=math.log(sigma_max),
                end=math.log(sigma_min),
                steps=num_train_timesteps,
            )
        )
        # Stored as a buffer (not a parameter) under the name "sigmas".
        self.register_buffer("sigmas", sigmas)  # type: ignore