# Time-TravelRephotography / convert_weight.py
import argparse
import os
import sys
import pickle
import math
import torch
import numpy as np
from torchvision import utils
from model import Generator, Discriminator
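
# Convert a TF "modulated conv" layer to the PyTorch StyledConv layout.
# TF stores conv kernels as (kh, kw, in, out); transposing to (out, in, kh, kw)
# and adding a leading group dimension matches the PyTorch module. mod_bias
# gets +1 because the PyTorch modulation expects styles centered at 1, and the
# scalar noise_strength becomes a one-element noise.weight tensor. flip=True
# mirrors the kernel spatially, which the upsampling convs need because their
# TF counterparts run as transposed convolutions.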
def convert_modconv(vars, source_name, target_name, flip=False):
weight = vars[source_name + "/weight"].value().eval()
mod_weight = vars[source_name + "/mod_weight"].value().eval()
mod_bias = vars[source_name + "/mod_bias"].value().eval()
noise = vars[source_name + "/noise_strength"].value().eval()
bias = vars[source_name + "/bias"].value().eval()
dic = {
"conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
"conv.modulation.weight": mod_weight.transpose((1, 0)),
"conv.modulation.bias": mod_bias + 1,
"noise.weight": np.array([noise]),
"activate.bias": bias,
}
dic_torch = {}
for k, v in dic.items():
dic_torch[target_name + "." + k] = torch.from_numpy(v)
if flip:
dic_torch[target_name + ".conv.weight"] = torch.flip(
dic_torch[target_name + ".conv.weight"], [3, 4]
)
return dic_torch
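
# Convert a plain TF conv layer. `start` is the index of the conv inside the
# target nn.Sequential; the weight lands at `.{start}` and, because the PyTorch
# model keeps the bias in the activation that follows the conv, the bias lands
# at `.{start + 1}`.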
def convert_conv(vars, source_name, target_name, bias=True, start=0):
weight = vars[source_name + "/weight"].value().eval()
dic = {"weight": weight.transpose((3, 2, 0, 1))}
if bias:
dic["bias"] = vars[source_name + "/bias"].value().eval()
dic_torch = {}
dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
if bias:
dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
return dic_torch
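
# Convert a TF ToRGB layer: a modulated 1x1 conv without a noise input. The
# bias is reshaped to (1, 3, 1, 1) so it broadcasts over NCHW image tensors.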
def convert_torgb(vars, source_name, target_name):
weight = vars[source_name + "/weight"].value().eval()
mod_weight = vars[source_name + "/mod_weight"].value().eval()
mod_bias = vars[source_name + "/mod_bias"].value().eval()
bias = vars[source_name + "/bias"].value().eval()
dic = {
"conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
"conv.modulation.weight": mod_weight.transpose((1, 0)),
"conv.modulation.bias": mod_bias + 1,
"bias": bias.reshape((1, 3, 1, 1)),
}
dic_torch = {}
for k, v in dic.items():
dic_torch[target_name + "." + k] = torch.from_numpy(v)
return dic_torch
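
# Convert a TF fully connected layer. TF stores dense weights as (in, out)
# while nn.Linear expects (out, in), hence the transpose.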
def convert_dense(vars, source_name, target_name):
weight = vars[source_name + "/weight"].value().eval()
bias = vars[source_name + "/bias"].value().eval()
dic = {"weight": weight.transpose((1, 0)), "bias": bias}
dic_torch = {}
for k, v in dic.items():
dic_torch[target_name + "." + k] = torch.from_numpy(v)
return dic_torch
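
# Copy converted tensors into the target state dict, failing loudly when a key
# is missing or a shape disagrees, which catches layout mismatches early.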
def update(state_dict, new):
for k, v in new.items():
        if k not in state_dict:
            raise KeyError(f"{k} not found in the target state dict")
        if v.shape != state_dict[k].shape:
            raise ValueError(f"shape mismatch for {k}: {v.shape} vs {state_dict[k].shape}")
state_dict[k] = v
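
# Fill the PyTorch Discriminator state dict, walking resolutions from `size`
# down to 8: FromRGB at full resolution, then one residual block per
# resolution (Conv0, Conv1_down, and the Skip branch), then the final 4x4
# conv and the two dense layers.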
def discriminator_fill_statedict(statedict, vars, size):
log_size = int(math.log(size, 2))
update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
conv_i = 1
for i in range(log_size - 2, 0, -1):
reso = 4 * 2 ** i
update(
statedict,
convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
)
update(
statedict,
convert_conv(
vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
),
)
update(
statedict,
convert_conv(
vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
),
)
conv_i += 1
    update(statedict, convert_conv(vars, "4x4/Conv", "final_conv"))
    update(statedict, convert_dense(vars, "4x4/Dense0", "final_linear.0"))
    update(statedict, convert_dense(vars, "Output", "final_linear.1"))
return statedict
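
# Fill the PyTorch Generator state dict: the mapping-network Dense layers land
# at style.{i+1} (index 0 is the input normalization), followed by the learned
# constant input, the ToRGB layers, the modulated convs (upsampling kernels
# flipped), and the per-resolution noise buffers.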
def fill_statedict(state_dict, vars, size, n_mlp):
log_size = int(math.log(size, 2))
for i in range(n_mlp):
update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))
update(
state_dict,
{
"input.input": torch.from_numpy(
vars["G_synthesis/4x4/Const/const"].value().eval()
)
},
)
update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))
for i in range(log_size - 2):
reso = 4 * 2 ** (i + 1)
update(
state_dict,
convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
)
update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))
conv_i = 0
for i in range(log_size - 2):
reso = 4 * 2 ** (i + 1)
update(
state_dict,
convert_modconv(
vars,
f"G_synthesis/{reso}x{reso}/Conv0_up",
f"convs.{conv_i}",
flip=True,
),
)
update(
state_dict,
convert_modconv(
vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
),
)
conv_i += 2
    for i in range((log_size - 2) * 2 + 1):
update(
state_dict,
{
f"noises.noise_{i}": torch.from_numpy(
vars[f"G_synthesis/noise{i}"].value().eval()
)
},
)
return state_dict
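
# Load the TF pickle of (G, D, G_ema), rebuild matching PyTorch modules,
# transfer the weights, and compare both models on identical latents.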
if __name__ == "__main__":
device = "cuda"
parser = argparse.ArgumentParser(
description="Tensorflow to pytorch model checkpoint converter"
)
parser.add_argument(
"--repo",
type=str,
required=True,
help="path to the offical StyleGAN2 repository with dnnlib/ folder",
)
parser.add_argument(
"--gen", action="store_true", help="convert the generator weights"
)
parser.add_argument(
"--disc", action="store_true", help="convert the discriminator weights"
)
parser.add_argument(
"--channel_multiplier",
type=int,
default=2,
help="channel multiplier factor. config-f = 2, else = 1",
)
parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
args = parser.parse_args()
sys.path.append(args.repo)
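    # dnnlib/ and tflib live in the official StyleGAN2 repo, so the repo must
    # be on sys.path before importing; init_tf() starts the TF session that the
    # .value().eval() calls in the converters rely on.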
import dnnlib
from dnnlib import tflib
tflib.init_tf()
with open(args.path, "rb") as f:
generator, discriminator, g_ema = pickle.load(f)
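    # Infer the output resolution from G_ema's NCHW output shape and count the
    # mapping network's Dense layers so the PyTorch Generator is built to match.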
size = g_ema.output_shape[2]
n_mlp = 0
mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
for layer in mapping_layers_names:
if layer[0].startswith('Dense'):
n_mlp += 1
g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
state_dict = g.state_dict()
state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
g.load_state_dict(state_dict)
latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())
ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
if args.gen:
g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
g_train_state = g_train.state_dict()
g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
ckpt["g"] = g_train_state
if args.disc:
disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
d_state = disc.state_dict()
d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
ckpt["d"] = d_state
name = os.path.splitext(os.path.basename(args.path))[0]
torch.save(ckpt, name + ".pt")
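    # Sanity check: run the converted PyTorch generator and the original TF one
    # on the same fixed-seed latents (truncation 0.5, deterministic noise) and
    # compare the outputs pixel by pixel.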
    n_sample_per_size = {256: 16, 512: 9, 1024: 4}
    n_sample = n_sample_per_size.get(size, 25)
g = g.to(device)
z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
with torch.no_grad():
img_pt, _ = g(
[torch.from_numpy(z).to(device)],
truncation=0.5,
truncation_latent=latent_avg.to(device),
randomize_noise=False,
)
Gs_kwargs = dnnlib.EasyDict()
Gs_kwargs.randomize_noise = False
img_tf = g_ema.run(z, None, **Gs_kwargs)
img_tf = torch.from_numpy(img_tf).to(device)
    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf + 1) / 2).clamp(0.0, 1.0)
img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
    print("max abs diff:", img_diff.abs().max().item())
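    # Save TF output, PyTorch output, and their difference as one grid. Note:
    # newer torchvision releases rename save_image's `range` argument to
    # `value_range`; adjust if your version rejects `range`.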
utils.save_image(
img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
)
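
# Example invocation (the repo and checkpoint paths are illustrative):
#   python convert_weight.py --repo /path/to/stylegan2 --gen --disc stylegan2-ffhq-config-f.pkl
# This writes stylegan2-ffhq-config-f.pt plus a .png grid with one row each for
# the TF output, the PyTorch output, and their difference.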