"""Copy matching entries from an XTTS checkpoint into Tortoise checkpoints.

The script loads ``./models/xtts/model.pth``, takes its ``'model'`` mapping,
and for every key that starts with ``gpt.`` or ``diffusion_decoder.`` strips
that prefix and, when the stripped key already exists in the corresponding
Tortoise checkpoint, overwrites the Tortoise entry with the XTTS value.

Side effects:
    * writes ``<original path>.bkp`` next to each Tortoise checkpoint,
    * writes the patched checkpoints as ``*.xtts.pth`` alongside the originals,
    * prints one line per overwritten key.

NOTE(review): ``torch.load`` unpickles the files it reads; only run this on
checkpoints you trust.
NOTE(review): values are copied without comparing shape or dtype against the
entry they replace -- confirm the target model configuration matches.
"""

import re

import torch


# Source of the weights to transplant; only the 'model' entry is used.
src = torch.load("./models/xtts/model.pth", map_location="cpu")['model']

# Destination checkpoints, keyed by a short model tag that is shared with
# `regexes` below.
paths = {
    "ar": "./models/tortoise/autoregressive.pth",
    "df": "./models/tortoise/diffusion_decoder.pth",
}

# Load every destination checkpoint and immediately re-save it as a backup.
# The path table is kept separate from the loaded data so that neither dict
# changes meaning half-way through the script.
dst = {}
for model, path in paths.items():
    dst[model] = torch.load(path, map_location="cpu")
    torch.save(dst[model], f'{path}.bkp')

# Key prefix in the source checkpoint that selects each destination model.
regexes = {
    "ar": r'^gpt\.',
    "df": r'^diffusion_decoder\.',
}

# Compile once; the patterns are reused for every source key.
patterns = {model: re.compile(regex) for model, regex in regexes.items()}

for k, v in src.items():
    for model, pattern in patterns.items():
        match = pattern.match(k)
        if match is None:
            continue

        # The patterns are anchored at the start, so dropping the matched
        # span is the same as substituting the prefix with "".
        key = k[match.end():]

        # Only overwrite entries the destination already has; never add keys.
        if key not in dst[model]:
            continue

        print(f"Writing {k} into {key}")
        dst[model][key] = v
        break

# Write each patched checkpoint next to its original, with ".pth" replaced by
# ".xtts.pth" (e.g. autoregressive.pth -> autoregressive.xtts.pth).
for model, path in paths.items():
    torch.save(dst[model], re.sub(r'\.pth$', '.xtts.pth', path))