from dataclasses import dataclass
from typing import List, Tuple, Union, Optional

import torch
from diffusers import DDPMScheduler
from diffusers.utils import BaseOutput


@dataclass
class DMDSchedulerOutput(BaseOutput):
    """Output of `DMDScheduler.step`: the model's predicted clean sample `x_0`."""

    pred_original_sample: Optional[torch.FloatTensor] = None


class DMDScheduler(DDPMScheduler):
    """One-step scheduler for DMD-style distilled generators.

    Inference always uses the final training timestep, so `set_timesteps`
    ignores `num_inference_steps` and `timesteps` and yields a single step.
    """

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
    ):
        # DMD samples in a single step, always at the last training timestep.
        self.timesteps = torch.tensor([self.config.num_train_timesteps - 1]).long().to(device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DMDSchedulerOutput, Tuple]:
        """Recover the predicted clean sample `x_0` from a single denoising step."""
        # DMD always denoises from the final training timestep, regardless of `timestep`.
        t = self.config.num_train_timesteps - 1

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        beta_prod_t = 1 - alpha_prod_t

        if self.config.prediction_type == "epsilon":
            # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} is not supported;"
                " DMDScheduler only implements `epsilon` prediction."
            )

        if not return_dict:
            return (pred_original_sample,)

        return DMDSchedulerOutput(pred_original_sample=pred_original_sample)
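

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows how the one-step
# scheduler is typically driven. The tensor shape and the placeholder noise
# prediction are assumptions; in a real pipeline `model_output` would come from
# the distilled denoising model, e.g. `unet(sample, t).sample`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    scheduler = DMDScheduler(num_train_timesteps=1000, prediction_type="epsilon")
    scheduler.set_timesteps(device="cpu")  # pins a single timestep: num_train_timesteps - 1

    # Start from pure Gaussian noise; one step maps it to the predicted clean sample.
    sample = torch.randn(1, 4, 64, 64)
    t = scheduler.timesteps[0]
    model_output = torch.randn_like(sample)  # placeholder for the model's epsilon prediction
    out = scheduler.step(model_output, t, sample)
    print(out.pred_original_sample.shape)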