File size: 835 Bytes
acc22af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from transformers import PretrainedConfig

class ProbUNetConfig(PretrainedConfig):
    """Hugging Face configuration container for a probabilistic U-Net model.

    Each constructor argument is stored unchanged as an attribute of the same
    name. No validation is performed here. Any extra keyword arguments are
    passed on to ``PretrainedConfig.__init__``.

    Args (meanings inferred from the names -- confirm against the model code):
        dim: presumably the spatial dimensionality of the network (default 2).
        in_channels: presumably the number of input channels.
        out_channels: presumably the number of output channels.
        num_feature_maps: presumably the feature-map count of the first level.
        latent_size: presumably the dimensionality of the latent space.
        latent_distribution: name of the latent distribution (default
            ``"normal"``); accepted values are not checked here.
        depth: presumably the number of U-Net resolution levels.
        no_outact_op: boolean flag, presumably skipping the output activation.
        prob_injection_at: presumably where the latent sample is injected
            (default ``"end"``); accepted values are not checked here.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ProbUNet"

    def __init__(
        self,
        dim=2,
        in_channels=1,
        out_channels=1,
        num_feature_maps=24,
        latent_size=3,
        depth=5,
        latent_distribution="normal",
        no_outact_op=False,
        prob_injection_at="end",
        **kwargs,
    ):
        # Record the model-specific hyperparameters first, in declaration
        # order. Only then let the base class consume the generic kwargs.
        hyperparams = (
            ("dim", dim),
            ("in_channels", in_channels),
            ("out_channels", out_channels),
            ("num_feature_maps", num_feature_maps),
            ("latent_size", latent_size),
            ("depth", depth),
            ("latent_distribution", latent_distribution),
            ("no_outact_op", no_outact_op),
            ("prob_injection_at", prob_injection_at),
        )
        for attr_name, attr_value in hyperparams:
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)