yyk19 commited on
Commit
ca83fd7
·
1 Parent(s): 066645f

update the checkpoints fine-tuned on TextCaps 5K.

Browse files
app.py CHANGED
@@ -82,12 +82,18 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
82
  time.sleep(2)
83
  print("empty the cuda cache")
84
 
85
- if model_ckpt == "LAION-Glyph-1M":
86
- model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
87
- elif model_ckpt == "LAION-Glyph-10M-Epoch-5":
88
  model = load_model_ckpt(model, "laion10M_epoch_5_model_wo_ema.ckpt")
89
  elif model_ckpt == "LAION-Glyph-10M-Epoch-6":
90
  model = load_model_ckpt(model, "laion10M_epoch_6_model_wo_ema.ckpt")
 
 
 
 
 
 
91
 
92
  render_tool = Render_Text(model)
93
  output_str = f"already change the model checkpoint to {model_ckpt}"
@@ -126,7 +132,7 @@ with block:
126
  only_show_rendered_image = gr.Number(value=1, visible=False)
127
  default_width = [0.3, 0.3, 0.3, 0.3]
128
  default_top_left_x = [0.35, 0.15, 0.15, 0.5]
129
- default_top_left_y = [0.5, 0.25, 0.75, 0.75]
130
  with gr.Column():
131
 
132
  with gr.Row():
@@ -154,7 +160,8 @@ with block:
154
  with gr.Accordion("Model Options", open=False):
155
  with gr.Row():
156
  # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M", "Textcaps5K-10"], label="Checkpoint", default = "LAION-Glyph-10M")
157
- model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
 
158
  # load_button = gr.Button(value = "Load Checkpoint")
159
 
160
  with gr.Accordion("Shared Advanced Options", open=False):
 
82
  time.sleep(2)
83
  print("empty the cuda cache")
84
 
85
+ # if model_ckpt == "LAION-Glyph-1M":
86
+ # model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
87
+ if model_ckpt == "LAION-Glyph-10M-Epoch-5":
88
  model = load_model_ckpt(model, "laion10M_epoch_5_model_wo_ema.ckpt")
89
  elif model_ckpt == "LAION-Glyph-10M-Epoch-6":
90
  model = load_model_ckpt(model, "laion10M_epoch_6_model_wo_ema.ckpt")
91
+ elif model_ckpt == "TextCaps-5K-Epoch-10":
92
+ model = load_model_ckpt(model, "textcaps5K_epoch_10_model_wo_ema.ckpt")
93
+ elif model_ckpt == "TextCaps-5K-Epoch-20":
94
+ model = load_model_ckpt(model, "textcaps5K_epoch_20_model_wo_ema.ckpt")
95
+ elif model_ckpt == "TextCaps-5K-Epoch-40":
96
+ model = load_model_ckpt(model, "textcaps5K_epoch_40_model_wo_ema.ckpt")
97
 
98
  render_tool = Render_Text(model)
99
  output_str = f"already change the model checkpoint to {model_ckpt}"
 
132
  only_show_rendered_image = gr.Number(value=1, visible=False)
133
  default_width = [0.3, 0.3, 0.3, 0.3]
134
  default_top_left_x = [0.35, 0.15, 0.15, 0.5]
135
+ default_top_left_y = [0.4, 0.15, 0.65, 0.65]
136
  with gr.Column():
137
 
138
  with gr.Row():
 
160
  with gr.Accordion("Model Options", open=False):
161
  with gr.Row():
162
  # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M", "Textcaps5K-10"], label="Checkpoint", default = "LAION-Glyph-10M")
163
+ # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
164
+ model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "TextCaps-5K-Epoch-10", "TextCaps-5K-Epoch-20", "TextCaps-5K-Epoch-40"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
165
  # load_button = gr.Button(value = "Load Checkpoint")
166
 
167
  with gr.Accordion("Shared Advanced Options", open=False):
cldm/cldm.py CHANGED
@@ -532,13 +532,6 @@ class ControlLDM(LatentDiffusion):
532
  self.freeze_glyph_image_encoder = model.freeze_image_encoder #image_encoder.freeze_model
533
  self.glyph_control_model = model
534
  self.glyph_image_encoder_type = model.image_encoder_type
535
- # self.glyph_control_optim = torch.optim.AdamW([
536
- # {"params": gain_or_bias_params, "weight_decay": 0.}, # "lr": self.glycon_lr},
537
- # {"params": rest_params, "weight_decay": self.glycon_wd} #, "lr": self.glycon_lr},
538
- # ],
539
- # lr = self.glycon_lr
540
- # )
541
- # params += list(model.image_encoder.parameters())
542
 
543
 
544
 
@@ -738,16 +731,6 @@ class ControlLDM(LatentDiffusion):
738
  if decoder_params is not None:
739
  params_wlr.append({"params": decoder_params, "lr": self.decoder_lr})
740
 
741
-
742
- # if not self.sep_lr:
743
- # opt = torch.optim.AdamW(params, lr=lr)
744
- # else:
745
- # opt = torch.optim.AdamW(
746
- # [
747
- # {"params": params},
748
- # {"params": decoder_params, "lr": self.decoder_lr}
749
- # ], lr=lr
750
- # )
751
  if not self.freeze_glyph_image_encoder:
752
  if self.glyph_image_encoder_type == "CLIP":
753
  # assert self.sep_lr
@@ -866,20 +849,6 @@ class ControlLDM(LatentDiffusion):
866
  if p.requires_grad and p.grad is not None:
867
  grad_norm_v = p.grad.cpu().detach().norm().item()
868
  gradnorm_list.append(grad_norm_v)
869
- # for name, p in self.named_parameters():
870
- # if p.requires_grad and p.grad is not None:
871
- # grad_norm_v = p.grad.detach().norm().item()
872
- # gradnorm_list.append(grad_norm_v)
873
- # if "textemb_merge_model" in name:
874
- # self.log("all_gradients/{}_norm".format(name),
875
- # gradnorm_list[-1],
876
- # prog_bar=False, logger=True, on_step=True, on_epoch=False
877
- # )
878
- # # if grad_norm_v > 0.1:
879
- # # print("the norm of gradient w.r.t {} > 0.1: {:.2f}".format
880
- # # (
881
- # # name, grad_norm_v
882
- # # ))
883
  if len(gradnorm_list):
884
  self.log("all_gradients/grad_norm_mean",
885
  np.mean(gradnorm_list),
@@ -943,19 +912,4 @@ class ControlLDM(LatentDiffusion):
943
  prog_bar=False, logger=True, on_step=True, on_epoch=False
944
  )
945
  del gradnorm_list
946
- del zeroconvs
947
-
948
- # def freeze_unet(self):
949
- # # Have some bugs
950
- # self.model.eval()
951
- # # self.model.train = disabled_train
952
- # for param in self.model.parameters():
953
- # param.requires_grad = False
954
-
955
- # if not self.sd_locked:
956
- # self.model.diffusion_model.output_blocks.train()
957
- # self.model.diffusion_model.out.train()
958
- # for param in self.model.diffusion_model.out.parameters():
959
- # param.requires_grad = True
960
- # for param in self.model.diffusion_model.output_blocks.parameters():
961
- # param.requires_grad = True
 
532
  self.freeze_glyph_image_encoder = model.freeze_image_encoder #image_encoder.freeze_model
533
  self.glyph_control_model = model
534
  self.glyph_image_encoder_type = model.image_encoder_type
 
 
 
 
 
 
 
535
 
536
 
537
 
 
731
  if decoder_params is not None:
732
  params_wlr.append({"params": decoder_params, "lr": self.decoder_lr})
733
 
 
 
 
 
 
 
 
 
 
 
734
  if not self.freeze_glyph_image_encoder:
735
  if self.glyph_image_encoder_type == "CLIP":
736
  # assert self.sep_lr
 
849
  if p.requires_grad and p.grad is not None:
850
  grad_norm_v = p.grad.cpu().detach().norm().item()
851
  gradnorm_list.append(grad_norm_v)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852
  if len(gradnorm_list):
853
  self.log("all_gradients/grad_norm_mean",
854
  np.mean(gradnorm_list),
 
912
  prog_bar=False, logger=True, on_step=True, on_epoch=False
913
  )
914
  del gradnorm_list
915
+ del zeroconvs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config_ema_unlock.yaml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-6 #1.0e-5 #1.0e-4
3
+ target: cldm.cldm.ControlLDM
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ control_key: "hint"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false
16
+ conditioning_key: crossattn
17
+ monitor: #val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ only_mid_control: False
20
+ sd_locked: False #True
21
+ use_ema: True #TODO: specify
22
+
23
+ control_stage_config:
24
+ target: cldm.cldm.ControlNet
25
+ params:
26
+ use_checkpoint: True
27
+ image_size: 32 # unused
28
+ in_channels: 4
29
+ hint_channels: 3
30
+ model_channels: 320
31
+ attention_resolutions: [ 4, 2, 1 ]
32
+ num_res_blocks: 2
33
+ channel_mult: [ 1, 2, 4, 4 ]
34
+ num_head_channels: 64 # need to fix for flash-attn
35
+ use_spatial_transformer: True
36
+ use_linear_in_transformer: True
37
+ transformer_depth: 1
38
+ context_dim: 1024
39
+ legacy: False
40
+
41
+ unet_config:
42
+ target: cldm.cldm.ControlledUnetModel
43
+ params:
44
+ use_checkpoint: True
45
+ image_size: 32 # unused
46
+ in_channels: 4
47
+ out_channels: 4
48
+ model_channels: 320
49
+ attention_resolutions: [ 4, 2, 1 ]
50
+ num_res_blocks: 2
51
+ channel_mult: [ 1, 2, 4, 4 ]
52
+ num_head_channels: 64 # need to fix for flash-attn
53
+ use_spatial_transformer: True
54
+ use_linear_in_transformer: True
55
+ transformer_depth: 1
56
+ context_dim: 1024
57
+ legacy: False
58
+
59
+ first_stage_config:
60
+ target: ldm.models.autoencoder.AutoencoderKL
61
+ params:
62
+ embed_dim: 4
63
+ monitor: val/rec_loss
64
+ ddconfig:
65
+ #attn_type: "vanilla-xformers"
66
+ double_z: true
67
+ z_channels: 4
68
+ resolution: 256
69
+ in_channels: 3
70
+ out_ch: 3
71
+ ch: 128
72
+ ch_mult:
73
+ - 1
74
+ - 2
75
+ - 4
76
+ - 4
77
+ num_res_blocks: 2
78
+ attn_resolutions: []
79
+ dropout: 0.0
80
+ lossconfig:
81
+ target: torch.nn.Identity
82
+
83
+ cond_stage_config:
84
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
85
+ params:
86
+ freeze: True
87
+ layer: "penultimate"
88
+ # device: "cpu" #TODO: specify
laion1M_model_wo_ema.ckpt → textcaps5K_epoch_10_model_wo_ema.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0b86b22188bf580e80773a5ae101bf9787eb258349f3f1acf0ae50fd10cb3fec
3
- size 6671922039
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c26cd80dcdd8b5563a68f397f291d0d2d7a4bef7a8c2435fd97a36be32ef61be
3
+ size 6671914001
textcaps5K_epoch_20_model_wo_ema.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85c887aa42db7afbed071629bcf5a07cfccdcdc80216475d8a2536fed75cc600
3
+ size 6671914001
textcaps5K_epoch_40_model_wo_ema.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:511be806f6e44f9c33af75df181adacfb3a0bb71aac8df8b303fff36e8e97dae
3
+ size 6671914001
transfer.py CHANGED
@@ -2,10 +2,13 @@ from omegaconf import OmegaConf
2
  from scripts.rendertext_tool import Render_Text, load_model_from_config
3
  import torch
4
 
5
- cfg = OmegaConf.load("config_ema.yaml")
6
- # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
7
- model = load_model_from_config(cfg, "mp_rank_00_model_states.pt", verbose=True)
8
 
 
 
 
9
 
10
  from pytorch_lightning.callbacks import ModelCheckpoint
11
  with model.ema_scope("store ema weights"):
@@ -18,6 +21,6 @@ with model.ema_scope("store ema weights"):
18
  file_content = {
19
  'state_dict': store_sd
20
  }
21
- torch.save(file_content, "model_wo_ema.ckpt")
22
  print("has stored the transfered ckpt.")
23
  print("trial ends!")
 
2
  from scripts.rendertext_tool import Render_Text, load_model_from_config
3
  import torch
4
 
5
+ # cfg = OmegaConf.load("config_ema.yaml")
6
+ # # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
7
+ # model = load_model_from_config(cfg, "mp_rank_00_model_states.pt", verbose=True)
8
 
9
+ cfg = OmegaConf.load("config_ema_unlock.yaml")
10
+ epoch_idx = 39
11
+ model = load_model_from_config(cfg, "epoch={:0>6d}.ckpt".format(epoch_idx), verbose=True)
12
 
13
  from pytorch_lightning.callbacks import ModelCheckpoint
14
  with model.ema_scope("store ema weights"):
 
21
  file_content = {
22
  'state_dict': store_sd
23
  }
24
+ torch.save(file_content, f"textcaps5K_epoch_{epoch_idx+1}_model_wo_ema.ckpt")
25
  print("has stored the transfered ckpt.")
26
  print("trial ends!")