Diffusers
Safetensors
shunk031 commited on
Commit
fb4d378
·
verified ·
1 Parent(s): ca77068

Upload scheduler/scheduling_ncsn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scheduler/scheduling_ncsn.py +129 -0
scheduler/scheduling_ncsn.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import (
7
+ ConfigMixin,
8
+ register_to_config,
9
+ )
10
+ from diffusers.schedulers.scheduling_utils import (
11
+ SchedulerMixin,
12
+ SchedulerOutput,
13
+ )
14
+ from einops import rearrange
15
+
16
+
17
+ @dataclass
18
+ class AnnealedLangevinDynamicOutput(SchedulerOutput):
19
+ """Annealed Langevin Dynamic output class."""
20
+
21
+
22
+ class AnnealedLangevinDynamicScheduler(SchedulerMixin, ConfigMixin): # type: ignore
23
+ order = 1
24
+
25
+ @register_to_config
26
+ def __init__(
27
+ self,
28
+ num_train_timesteps: int,
29
+ num_annealed_steps: int,
30
+ sigma_min: float,
31
+ sigma_max: float,
32
+ sampling_eps: float,
33
+ ) -> None:
34
+ self.num_train_timesteps = num_train_timesteps
35
+ self.num_annealed_steps = num_annealed_steps
36
+
37
+ self._sigma_min = sigma_min
38
+ self._sigma_max = sigma_max
39
+ self._sampling_eps = sampling_eps
40
+
41
+ self._sigmas: Optional[torch.Tensor] = None
42
+ self._step_size: Optional[torch.Tensor] = None
43
+ self._timesteps: Optional[torch.Tensor] = None
44
+
45
+ self.set_sigmas(num_inference_steps=num_train_timesteps)
46
+
47
+ @property
48
+ def sigmas(self) -> torch.Tensor:
49
+ assert self._sigmas is not None
50
+ return self._sigmas
51
+
52
+ @property
53
+ def step_size(self) -> torch.Tensor:
54
+ assert self._step_size is not None
55
+ return self._step_size
56
+
57
+ @property
58
+ def timesteps(self) -> torch.Tensor:
59
+ assert self._timesteps is not None
60
+ return self._timesteps
61
+
62
+ def scale_model_input(
63
+ self, sample: torch.Tensor, timestep: Optional[int] = None
64
+ ) -> torch.Tensor:
65
+ return sample
66
+
67
+ def set_timesteps(
68
+ self,
69
+ num_inference_steps: int,
70
+ sampling_eps: Optional[float] = None,
71
+ device: Optional[Union[str, torch.device]] = None,
72
+ ) -> None:
73
+ sampling_eps = sampling_eps or self._sampling_eps
74
+ self._timesteps = torch.arange(start=0, end=num_inference_steps)
75
+
76
+ def set_sigmas(
77
+ self,
78
+ num_inference_steps: int,
79
+ sigma_min: Optional[float] = None,
80
+ sigma_max: Optional[float] = None,
81
+ sampling_eps: Optional[float] = None,
82
+ ) -> None:
83
+ if self._timesteps is None:
84
+ self.set_timesteps(
85
+ num_inference_steps=num_inference_steps,
86
+ sampling_eps=sampling_eps,
87
+ )
88
+
89
+ sigma_min = sigma_min or self._sigma_min
90
+ sigma_max = sigma_max or self._sigma_max
91
+ self._sigmas = torch.exp(
92
+ torch.linspace(
93
+ start=math.log(sigma_max),
94
+ end=math.log(sigma_min),
95
+ steps=num_inference_steps,
96
+ )
97
+ )
98
+
99
+ sampling_eps = sampling_eps or self._sampling_eps
100
+ self._step_size = sampling_eps * (self.sigmas / self.sigmas[-1]) ** 2
101
+
102
+ def step(
103
+ self,
104
+ model_output: torch.Tensor,
105
+ timestep: int,
106
+ samples: torch.Tensor,
107
+ return_dict: bool = True,
108
+ **kwargs,
109
+ ) -> Union[AnnealedLangevinDynamicOutput, Tuple]:
110
+ z = torch.randn_like(samples)
111
+ step_size = self.step_size[timestep]
112
+ samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
113
+
114
+ if return_dict:
115
+ return AnnealedLangevinDynamicOutput(prev_sample=samples)
116
+ else:
117
+ return (samples,)
118
+
119
+ def add_noise(
120
+ self,
121
+ original_samples: torch.Tensor,
122
+ noise: torch.Tensor,
123
+ timesteps: torch.Tensor,
124
+ ) -> torch.Tensor:
125
+ timesteps = timesteps.to(original_samples.device)
126
+ sigmas = self.sigmas.to(original_samples.device)[timesteps]
127
+ sigmas = rearrange(sigmas, "b -> b 1 1 1")
128
+ noisy_samples = original_samples + noise * sigmas
129
+ return noisy_samples