update the checkpoints fine-tuned on TextCaps 5K.
- app.py +12 -5
- cldm/cldm.py +1 -47
- config_ema_unlock.yaml +88 -0
- laion1M_model_wo_ema.ckpt → textcaps5K_epoch_10_model_wo_ema.ckpt +2 -2
- textcaps5K_epoch_20_model_wo_ema.ckpt +3 -0
- textcaps5K_epoch_40_model_wo_ema.ckpt +3 -0
- transfer.py +7 -4
app.py
CHANGED
@@ -82,12 +82,18 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
     time.sleep(2)
     print("empty the cuda cache")
 
-    if model_ckpt == "LAION-Glyph-1M":
-        model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
-    elif model_ckpt == "LAION-Glyph-10M-Epoch-5":
+    # if model_ckpt == "LAION-Glyph-1M":
+    #     model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
+    if model_ckpt == "LAION-Glyph-10M-Epoch-5":
         model = load_model_ckpt(model, "laion10M_epoch_5_model_wo_ema.ckpt")
     elif model_ckpt == "LAION-Glyph-10M-Epoch-6":
         model = load_model_ckpt(model, "laion10M_epoch_6_model_wo_ema.ckpt")
+    elif model_ckpt == "TextCaps-5K-Epoch-10":
+        model = load_model_ckpt(model, "textcaps5K_epoch_10_model_wo_ema.ckpt")
+    elif model_ckpt == "TextCaps-5K-Epoch-20":
+        model = load_model_ckpt(model, "textcaps5K_epoch_20_model_wo_ema.ckpt")
+    elif model_ckpt == "TextCaps-5K-Epoch-40":
+        model = load_model_ckpt(model, "textcaps5K_epoch_40_model_wo_ema.ckpt")
 
     render_tool = Render_Text(model)
     output_str = f"already change the model checkpoint to {model_ckpt}"
@@ -126,7 +132,7 @@ with block:
     only_show_rendered_image = gr.Number(value=1, visible=False)
     default_width = [0.3, 0.3, 0.3, 0.3]
     default_top_left_x = [0.35, 0.15, 0.15, 0.5]
-    default_top_left_y = [0.
+    default_top_left_y = [0.4, 0.15, 0.65, 0.65]
     with gr.Column():
 
         with gr.Row():
@@ -154,7 +160,8 @@ with block:
     with gr.Accordion("Model Options", open=False):
         with gr.Row():
             # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M", "Textcaps5K-10"], label="Checkpoint", default = "LAION-Glyph-10M")
-            model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
+            # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
+            model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "TextCaps-5K-Epoch-10", "TextCaps-5K-Epoch-20", "TextCaps-5K-Epoch-40"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
             # load_button = gr.Button(value = "Load Checkpoint")
 
     with gr.Accordion("Shared Advanced Options", open=False):
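Note: load_ckpt above maps each dropdown name to a checkpoint file through a growing if/elif chain. A dictionary lookup is a common, more compact alternative; the sketch below is illustrative only (CKPT_FILES and pick_ckpt_file are hypothetical names, not part of this commit) and assumes the same load_model_ckpt helper and filenames shown in the diff.

# Hypothetical alternative to the if/elif chain in load_ckpt (not part of this commit).
CKPT_FILES = {
    "LAION-Glyph-10M-Epoch-5": "laion10M_epoch_5_model_wo_ema.ckpt",
    "LAION-Glyph-10M-Epoch-6": "laion10M_epoch_6_model_wo_ema.ckpt",
    "TextCaps-5K-Epoch-10": "textcaps5K_epoch_10_model_wo_ema.ckpt",
    "TextCaps-5K-Epoch-20": "textcaps5K_epoch_20_model_wo_ema.ckpt",
    "TextCaps-5K-Epoch-40": "textcaps5K_epoch_40_model_wo_ema.ckpt",
}

def pick_ckpt_file(model_ckpt):
    # Fall back to the default checkpoint if an unknown name is passed.
    return CKPT_FILES.get(model_ckpt, "laion10M_epoch_6_model_wo_ema.ckpt")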
cldm/cldm.py
CHANGED
@@ -532,13 +532,6 @@ class ControlLDM(LatentDiffusion):
         self.freeze_glyph_image_encoder = model.freeze_image_encoder #image_encoder.freeze_model
         self.glyph_control_model = model
         self.glyph_image_encoder_type = model.image_encoder_type
-        # self.glyph_control_optim = torch.optim.AdamW([
-        #     {"params": gain_or_bias_params, "weight_decay": 0.}, # "lr": self.glycon_lr},
-        #     {"params": rest_params, "weight_decay": self.glycon_wd} #, "lr": self.glycon_lr},
-        #     ],
-        #     lr = self.glycon_lr
-        # )
-        # params += list(model.image_encoder.parameters())
 
 
 
@@ -738,16 +731,6 @@ class ControlLDM(LatentDiffusion):
         if decoder_params is not None:
             params_wlr.append({"params": decoder_params, "lr": self.decoder_lr})
 
-
-        # if not self.sep_lr:
-        #     opt = torch.optim.AdamW(params, lr=lr)
-        # else:
-        #     opt = torch.optim.AdamW(
-        #         [
-        #             {"params": params},
-        #             {"params": decoder_params, "lr": self.decoder_lr}
-        #         ], lr=lr
-        #     )
         if not self.freeze_glyph_image_encoder:
             if self.glyph_image_encoder_type == "CLIP":
                 # assert self.sep_lr
@@ -866,20 +849,6 @@ class ControlLDM(LatentDiffusion):
             if p.requires_grad and p.grad is not None:
                 grad_norm_v = p.grad.cpu().detach().norm().item()
                 gradnorm_list.append(grad_norm_v)
-        # for name, p in self.named_parameters():
-        #     if p.requires_grad and p.grad is not None:
-        #         grad_norm_v = p.grad.detach().norm().item()
-        #         gradnorm_list.append(grad_norm_v)
-        #         if "textemb_merge_model" in name:
-        #             self.log("all_gradients/{}_norm".format(name),
-        #                 gradnorm_list[-1],
-        #                 prog_bar=False, logger=True, on_step=True, on_epoch=False
-        #             )
-        #         # if grad_norm_v > 0.1:
-        #         #     print("the norm of gradient w.r.t {} > 0.1: {:.2f}".format
-        #         #         (
-        #         #             name, grad_norm_v
-        #         #         ))
         if len(gradnorm_list):
             self.log("all_gradients/grad_norm_mean",
                 np.mean(gradnorm_list),
@@ -943,19 +912,4 @@ class ControlLDM(LatentDiffusion):
                 prog_bar=False, logger=True, on_step=True, on_epoch=False
             )
         del gradnorm_list
-        del zeroconvs
-
-    # def freeze_unet(self):
-    #     # Have some bugs
-    #     self.model.eval()
-    #     # self.model.train = disabled_train
-    #     for param in self.model.parameters():
-    #         param.requires_grad = False
-
-    #     if not self.sd_locked:
-    #         self.model.diffusion_model.output_blocks.train()
-    #         self.model.diffusion_model.out.train()
-    #         for param in self.model.diffusion_model.out.parameters():
-    #             param.requires_grad = True
-    #         for param in self.model.diffusion_model.output_blocks.parameters():
-    #             param.requires_grad = True
+        del zeroconvs
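The cldm/cldm.py changes only delete commented-out code (an old AdamW setup for the glyph image encoder, an alternative optimizer block, per-parameter gradient logging, and a buggy freeze_unet helper); the live configure_optimizers path still builds params_wlr so the decoder parameters can get their own learning rate. A minimal sketch of that per-group AdamW pattern follows; build_optimizer, base_params, decoder_params, base_lr, and decoder_lr are placeholder names for illustration, not identifiers from cldm.py.

import torch

# Minimal sketch of per-group learning rates with AdamW, mirroring the
# "params_wlr.append(...)" context line in the diff above.
def build_optimizer(base_params, decoder_params, base_lr, decoder_lr):
    params_wlr = [{"params": base_params}]  # groups without "lr" use the optimizer-level lr
    if decoder_params is not None:
        params_wlr.append({"params": decoder_params, "lr": decoder_lr})
    return torch.optim.AdamW(params_wlr, lr=base_lr)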
config_ema_unlock.yaml
ADDED
@@ -0,0 +1,88 @@
+model:
+  base_learning_rate: 1.0e-6 #1.0e-5 #1.0e-4
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    control_key: "hint"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: #val/loss_simple_ema
+    scale_factor: 0.18215
+    only_mid_control: False
+    sd_locked: False #True
+    use_ema: True #TODO: specify
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        hint_channels: 3
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        # device: "cpu" #TODO: specify
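config_ema_unlock.yaml follows the usual ControlLDM layout but sets sd_locked: False (the SD UNet decoder blocks are unlocked for fine-tuning) and use_ema: True, which is what transfer.py below relies on when it enters model.ema_scope(). A hedged sketch of instantiating the model described by this config, assuming the instantiate_from_config helper from the ldm codebase this repo builds on:

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Sketch only: build the ControlLDM described by config_ema_unlock.yaml.
# Weight loading is handled elsewhere (e.g. load_model_from_config in transfer.py).
cfg = OmegaConf.load("config_ema_unlock.yaml")
model = instantiate_from_config(cfg.model)
print(type(model))  # expected: cldm.cldm.ControlLDM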
laion1M_model_wo_ema.ckpt → textcaps5K_epoch_10_model_wo_ema.ckpt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c26cd80dcdd8b5563a68f397f291d0d2d7a4bef7a8c2435fd97a36be32ef61be
+size 6671914001
textcaps5K_epoch_20_model_wo_ema.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85c887aa42db7afbed071629bcf5a07cfccdcdc80216475d8a2536fed75cc600
+size 6671914001
textcaps5K_epoch_40_model_wo_ema.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:511be806f6e44f9c33af75df181adacfb3a0bb71aac8df8b303fff36e8e97dae
+size 6671914001
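The three .ckpt entries above are Git LFS pointer files: the repository stores only the sha256 and byte size (about 6.7 GB each), and the actual weights are fetched through LFS. A small sketch for checking a downloaded file against the recorded hash (filename and hash taken from the epoch-40 pointer above):

import hashlib

# Verify a downloaded checkpoint against the sha256 recorded in its LFS pointer.
def sha256_of(path, chunk=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk), b""):
            h.update(block)
    return h.hexdigest()

expected = "511be806f6e44f9c33af75df181adacfb3a0bb71aac8df8b303fff36e8e97dae"
print(sha256_of("textcaps5K_epoch_40_model_wo_ema.ckpt") == expected)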
transfer.py
CHANGED
@@ -2,10 +2,13 @@ from omegaconf import OmegaConf
 from scripts.rendertext_tool import Render_Text, load_model_from_config
 import torch
 
-cfg = OmegaConf.load("config_ema.yaml")
-# model = load_model_from_config(cfg, "model_states.pt", verbose=True)
-model = load_model_from_config(cfg, "mp_rank_00_model_states.pt", verbose=True)
+# cfg = OmegaConf.load("config_ema.yaml")
+# # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
+# model = load_model_from_config(cfg, "mp_rank_00_model_states.pt", verbose=True)
 
+cfg = OmegaConf.load("config_ema_unlock.yaml")
+epoch_idx = 39
+model = load_model_from_config(cfg, "epoch={:0>6d}.ckpt".format(epoch_idx), verbose=True)
 
 from pytorch_lightning.callbacks import ModelCheckpoint
 with model.ema_scope("store ema weights"):
@@ -18,6 +21,6 @@ with model.ema_scope("store ema weights"):
     file_content = {
         'state_dict': store_sd
    }
-    torch.save(file_content, "
+    torch.save(file_content, f"textcaps5K_epoch_{epoch_idx+1}_model_wo_ema.ckpt")
     print("has stored the transfered ckpt.")
     print("trial ends!")
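transfer.py now reads the unlocked EMA config and a Lightning checkpoint named epoch=000039.ckpt, then saves textcaps5K_epoch_40_model_wo_ema.ckpt (epoch_idx + 1). Inside model.ema_scope(...) the EMA shadow weights are temporarily copied into the live parameters, so a state dict captured there already carries the EMA weights; the elided store_sd step presumably also drops the model_ema.* entries, which would explain the "wo_ema" label. A rough sketch of that filtering, under that assumption (it reuses the model loaded earlier in the script and is not the verbatim code):

import torch

# Assumption, not the verbatim script: snapshot the EMA-swapped weights and
# drop the EMA shadow copies before saving.
with model.ema_scope("store ema weights"):
    store_sd = {
        k: v.cpu().clone()
        for k, v in model.state_dict().items()
        if not k.startswith("model_ema.")
    }

torch.save({"state_dict": store_sd}, "textcaps5K_epoch_40_model_wo_ema.ckpt")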