|
import os |
|
import json |
|
import random |
|
import glob |
|
import torch |
|
import einops |
|
import torchvision |
|
|
|
import safetensors.torch as sf |
|
|
|
|
|
def write_to_json(data, file_path):
    """Serialize *data* as indented JSON to *file_path* via a temp file.

    The JSON is first written to ``file_path + ".tmp"`` and then moved over
    the destination with ``os.replace``, so the destination is only ever
    swapped in after serialization has fully succeeded.

    If serialization or the final replace fails, the temporary file is
    removed and the original exception is re-raised; an existing
    *file_path* is left as it was.

    Args:
        data: Any object accepted by ``json.dump``.
        file_path: Destination path (string).

    Raises:
        TypeError, ValueError: Propagated from ``json.dump``.
        OSError: Propagated from ``open`` / ``os.replace``.
    """
    temp_file_path = file_path + ".tmp"
    try:
        with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
            json.dump(data, temp_file, indent=4)
        # Only reached once the temp file has been completely written and closed.
        os.replace(temp_file_path, file_path)
    except BaseException:
        # Do not leave a half-written "<file>.tmp" behind on failure.
        try:
            os.remove(temp_file_path)
        except OSError:
            # Temp file was never created (or is already gone); nothing to clean.
            pass
        raise
    return
|
|
|
|
|
def read_from_json(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is opened in text mode with UTF-8 encoding and parsed with
    ``json.load``; whatever object the parser produces is returned.
    """
    with open(file_path, 'rt', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
|
|
|
|
|
def get_active_parameters(m):
    """Return a ``{name: parameter}`` dict of the parameters of *m* that require grad.

    Names and order are those yielded by ``m.named_parameters()``;
    parameters whose ``requires_grad`` is false are left out.
    """
    trainable = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            trainable[name] = param
    return trainable
|
|
|
|
|
def cast_training_params(m, dtype=torch.float32):
    """Cast, in place, every parameter of *m* that requires grad to *dtype*.

    Parameters with ``requires_grad`` false are not touched. The cast is
    applied by reassigning each parameter's ``.data``, so the parameter
    objects themselves stay the same. Returns ``None``.
    """
    trainable = (p for p in m.parameters() if p.requires_grad)
    for p in trainable:
        p.data = p.to(dtype)
    return
|
|
|
|
|
def set_attr_recursive(obj, attr, value):
    """Set a possibly nested attribute given as a dotted path.

    ``attr`` is split on ``"."``; every component except the last is
    resolved with ``getattr`` starting from *obj*, and the last component
    is assigned *value* on the object reached that way. Returns ``None``.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    setattr(target, leaf, value)
    return
|
|
|
|
|
@torch.no_grad()
def batch_mixture(a, b, probability_a=0.5, mask_a=None):
    """Mix two equally shaped tensors sample-by-sample along dim 0.

    For each index along the first dimension the whole slice is taken
    from *a* where the mask is true and from *b* otherwise.

    Args:
        a, b: Tensors with identical shapes.
        probability_a: Used only when *mask_a* is ``None``; each batch
            entry picks *a* when a fresh ``torch.rand`` draw is below
            this value.
        mask_a: Optional per-sample mask with ``a.size(0)`` elements. It
            is moved to ``a.device`` and used as the ``torch.where``
            condition.

    Returns:
        The mixed tensor, same shape as *a*.
    """
    assert a.shape == b.shape, "Tensors must have the same shape"
    n = a.size(0)

    pick_a = (torch.rand(n) < probability_a) if mask_a is None else mask_a

    # Shape (n, 1, ..., 1) so the per-sample choice broadcasts over all other dims.
    broadcast_shape = (n,) + (1,) * (a.dim() - 1)
    pick_a = pick_a.to(a.device).reshape(broadcast_shape)
    return torch.where(pick_a, a, b)
|
|
|
|
|
@torch.no_grad()
def zero_module(module):
    """Fill every parameter of *module* with zeros in place and return *module*.

    The zeroing is done on a detached view of each parameter, under
    ``torch.no_grad()``.
    """
    for weight in module.parameters():
        detached = weight.detach()
        detached.zero_()
    return module
|
|
|
|
|
def load_last_state(model, folder='accelerator_output'):
    """Load the most recently modified ``model.safetensors`` under *folder* into *model*.

    The folder is searched recursively for files named
    ``model.safetensors``. If none exist a message is printed and nothing
    else happens. Otherwise the file with the largest modification time
    is read with ``safetensors.torch.load_file`` and applied with
    ``model.load_state_dict(..., strict=False)``; any missing or
    unexpected keys are printed, followed by the path that was loaded.

    Returns ``None`` in every case.
    """
    candidates = glob.glob(os.path.join(folder, '**', 'model.safetensors'), recursive=True)

    if len(candidates) == 0:
        print("No model.safetensors files found in the specified folder.")
        return

    latest = max(candidates, key=os.path.getmtime)
    weights = sf.load_file(latest)
    missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)

    # Report only the non-empty key lists, missing first.
    for label, keys in (("Missing keys:", missing_keys), ("Unexpected keys:", unexpected_keys)):
        if keys:
            print(label, keys)

    print("Loaded model state from:", latest)
    return
|
|
|
|
|
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of the tags in *tags_str*.

    *tags_str* is split on ``", "``. A target count is drawn uniformly
    from ``[min_length, max_length]`` and capped at the number of tags;
    that many distinct tags are sampled (in random order) and joined back
    with ``", "``.
    """
    all_tags = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    count = wanted if wanted < len(all_tags) else len(all_tags)
    chosen = random.sample(all_tags, k=count)
    return ', '.join(chosen)
|
|
|
|
|
def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Write a 5-D tensor as a single tiled video file.

    *x* is unpacked as ``(b, c, t, h, w)``. The batch entries are tiled
    into a grid with ``per_row`` columns, where ``per_row`` is the first
    of 6, 5, 4, 3, 2 that divides ``b`` (or ``b`` itself if none does).
    Values are clamped to ``[-1, 1]``, mapped to ``[0, 255]`` and
    converted to ``uint8`` on the CPU before being handed to
    ``torchvision.io.write_video`` (h264, ``crf`` option ``'0'``).

    The parent directory of *output_filename* is created if needed.

    Returns:
        The ``uint8`` frame tensor that was written, laid out as
        ``(t, rows * h, per_row * w, c)``.
    """
    b, c, t, h, w = x.shape

    # Widest grid row (at most 6) that tiles the batch evenly; else one single row.
    per_row = next((p for p in (6, 5, 4, 3, 2) if b % p == 0), b)

    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    frames = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    frames = frames.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='h264', options={'crf': '0'})
    return frames
|
|
|
|
|
def save_bcthw_as_png(x, output_filename):
    """Write a ``b c t h w`` tensor as one PNG contact sheet.

    Values are clamped to ``[-1, 1]``, mapped to ``[0, 255]`` and
    converted to ``uint8`` on the CPU. Batch entries are stacked
    vertically and time steps horizontally, giving an image of shape
    ``(c, b * h, t * w)`` that is passed to ``torchvision.io.write_png``.

    The parent directory of *output_filename* is created if needed.

    Returns:
        *output_filename*, unchanged.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    pixels = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    pixels = pixels.detach().cpu().to(torch.uint8)
    sheet = einops.rearrange(pixels, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(sheet, output_filename)
    return output_filename
|
|
|
|
|
def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors, zero-padding each up to their element-wise max shape.

    When the shapes already match this is plain ``tensor1 + tensor2``.
    Otherwise both tensors are copied into the leading corner of a zero
    tensor whose size along every dimension is the larger of the two
    inputs, and those padded tensors are summed.

    The padded buffers use the promoted dtype of the two inputs and live
    on ``tensor1``'s device, so the result has the same dtype/device as
    the equal-shape path would give instead of always being a CPU
    float32 tensor.

    Args:
        tensor1, tensor2: Tensors to add. Dimensions are paired
            positionally with ``zip``, so both are expected to have the
            same number of dimensions.

    Returns:
        The sum, with shape ``max(shape1, shape2)`` per dimension.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape

    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    # Match what `tensor1 + tensor2` would produce rather than torch.zeros' defaults.
    out_dtype = torch.result_type(tensor1, tensor2)
    out_device = tensor1.device

    padded_tensor1 = torch.zeros(new_shape, dtype=out_dtype, device=out_device)
    padded_tensor2 = torch.zeros(new_shape, dtype=out_dtype, device=out_device)

    # Each input occupies the leading corner of its padded buffer.
    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    return padded_tensor1 + padded_tensor2
|
|