# File size: 4,544 Bytes
# 9e7a39a
# Root of the model definition. Each `target` is a dotted class path and the
# sibling `params` mapping carries that class's arguments (presumably consumed
# by the sgm config instantiation helper; confirm against the entry script).
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    # NOTE(review): presumably the batch key holding the training image; confirm.
    input_key: image
    # Same value as the LatentEncoder `scale_factor` under conditioner_config.
    scale_factor: 0.18215
    disable_first_stage_autocast: true

    # Denoiser configuration: selects the denoiser class and the three helper
    # classes (weighting, scaling, discretization) it is built from.
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        # Same value as `num_idx` of sigma_sampler_config under loss_fn_config.
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        # Same discretization class as the one used by the sigma sampler in
        # loss_fn_config.
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    # Backbone network configuration (UNetAddModel).
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetAddModel
      params:
        use_checkpoint: false
        # NOTE(review): 9 presumably = 4 latent + 1 mask + 4 masked-latent
        # channels coming from the concat conditioners; confirm.
        in_channels: 9
        out_channels: 4
        ctrl_channels: 0
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        attn_type: add_attn
        # NOTE(review): presumably the module paths whose attention is exposed
        # to the loss; verify against UNetAddModel.
        attn_layers:
          - output_blocks.6.1
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_spatial_transformer: true
        use_linear_in_transformer: true
        transformer_depth: 1
        # The txt cross-attention embedder is commented out in
        # conditioner_config; 0 here presumably disables that context path
        # (TODO confirm).
        context_dim: 0
        # Equals `emb_dim` of the LabelEncoder in conditioner_config.
        add_context_dim: 2048
        legacy: false

    # Conditioning pipeline: a list of embedders, each reading one input key.
    # All active embedders below are marked non-trainable.
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # --- crossattn conditioning (disabled: text embedder kept for reference) ---
          # - is_trainable: False
          #   input_key: txt
          #   target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
          #   params:
          #     arch: ViT-H-14
          #     version: ./checkpoints/encoders/OpenCLIP/ViT-H-14/open_clip_pytorch_model.bin
          #     layer: penultimate
          # --- additional crossattn conditioning: label encoder ---
          - is_trainable: false
            input_key: label
            target: sgm.modules.encoders.modules.LabelEncoder
            params:
              is_add_embedder: true
              # Same value as `seq_len` in loss_fn_config.
              max_len: 12
              # Same value as `add_context_dim` in network_config.
              emb_dim: 2048
              n_heads: 8
              n_trans_layers: 12
              ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
          # --- concat conditioning: mask and encoded masked image ---
          - is_trainable: false
            input_key: mask
            target: sgm.modules.encoders.modules.IdentityEncoder
          - is_trainable: false
            input_key: masked
            target: sgm.modules.encoders.modules.LatentEncoder
            params:
              # Same value as the top-level model `scale_factor`.
              scale_factor: 0.18215
              # Autoencoder for the `masked` input; its ddconfig has the same
              # values as first_stage_config, but this one also sets a
              # ckpt_path.
              config:
                target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
                params:
                  ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors
                  embed_dim: 4
                  monitor: val/rec_loss
                  ddconfig:
                    attn_type: vanilla-xformers
                    double_z: true
                    z_channels: 4
                    resolution: 256
                    in_channels: 3
                    out_ch: 3
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    num_res_blocks: 2
                    attn_resolutions: []
                    dropout: 0.0
                  # Identity placeholder in the loss slot.
                  lossconfig:
                    target: torch.nn.Identity

    # First-stage autoencoder configuration. The ddconfig values match the
    # autoencoder nested under the LatentEncoder in conditioner_config.
    # NOTE(review): unlike that nested autoencoder, no ckpt_path is set here;
    # presumably the weights arrive with the full model checkpoint (confirm).
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          # Explicit empty list (not null).
          attn_resolutions: []
          dropout: 0.0
        # Identity placeholder in the loss slot.
        lossconfig:
          target: torch.nn.Identity

    # Training loss configuration.
    loss_fn_config:
      # Original note on this line: StandardDiffusionLoss (presumably the loss
      # class that FullLoss replaces; confirm).
      target: sgm.modules.diffusionmodules.loss.FullLoss
      params:
        # Same value as `max_len` of the LabelEncoder in conditioner_config.
        seq_len: 12
        kernel_size: 3
        gaussian_sigma: 0.5
        min_attn_size: 16
        lambda_local_loss: 0.02
        lambda_ocr_loss: 0.001
        # NOTE(review): with this off, lambda_ocr_loss and predictor_config are
        # presumably unused; verify in FullLoss.
        ocr_enabled: false

        predictor_config:
          target: sgm.modules.predictors.model.ParseqPredictor
          params:
            ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt"

        # Sigma sampler; `num_idx` and the discretization class match
        # denoiser_config.
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization