File size: 3,892 Bytes
6dd488f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import json
import random
import glob
import torch
import einops
import torchvision

import safetensors.torch as sf


def write_to_json(data, file_path):
    """Serialize *data* as indented JSON to *file_path* via a temporary file.

    The JSON is first written to ``<file_path>.tmp`` and then moved over the
    destination with ``os.replace``, so the destination is only touched once
    the full document has been written successfully.

    Args:
        data: Any object ``json.dump`` can serialize.
        file_path: Destination path (``str`` or ``os.PathLike``).

    Raises:
        TypeError / ValueError: If *data* is not JSON-serializable.
        OSError: If the temporary file cannot be written or moved.
    """
    temp_file_path = os.fspath(file_path) + ".tmp"
    try:
        with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
            json.dump(data, temp_file, indent=4)
        os.replace(temp_file_path, file_path)
    except BaseException:
        # Do not leave a half-written temp file behind when serialization or
        # the final move fails; the original error is re-raised unchanged.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    return


def read_from_json(file_path):
    """Parse the UTF-8 JSON document stored at *file_path* and return it."""
    with open(file_path, 'rt', encoding='utf-8') as handle:
        return json.load(handle)


def get_active_parameters(m):
    """Return a ``{name: parameter}`` dict of the parameters of *m* that require grad."""
    trainable = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            trainable[name] = param
    return trainable


def cast_training_params(m, dtype=torch.float32):
    """Convert the data of every grad-requiring parameter of *m* to *dtype*.

    Parameters with ``requires_grad == False`` are left as they are.
    """
    trainable = (p for p in m.parameters() if p.requires_grad)
    for p in trainable:
        p.data = p.to(dtype)
    return


def set_attr_recursive(obj, attr, value):
    """Assign *value* to a dotted attribute path (e.g. ``"a.b.c"``) starting at *obj*.

    Every component but the last is resolved with ``getattr``; the last one
    is set with ``setattr`` on the object reached that way.
    """
    *parent_names, leaf_name = attr.split(".")
    target = obj
    for parent_name in parent_names:
        target = getattr(target, parent_name)
    setattr(target, leaf_name, value)
    return


@torch.no_grad()
def batch_mixture(a, b, probability_a=0.5, mask_a=None):
    """Build a tensor whose entries along dim 0 are taken from *a* or *b*.

    Args:
        a, b: Tensors of identical shape.
        probability_a: Used only when *mask_a* is ``None``; each index along
            dim 0 picks *a* when a fresh ``torch.rand`` draw is below it.
        mask_a: Optional mask with one element per index along dim 0;
            true selects *a*, false selects *b*.

    Returns:
        The result of ``torch.where`` over the broadcast mask.
    """
    assert a.shape == b.shape, "Tensors must have the same shape"
    n = a.size(0)

    if mask_a is None:
        mask_a = torch.rand(n) < probability_a

    # One mask entry per leading index, broadcast over all remaining dims.
    broadcast_shape = (n,) + (1,) * (a.dim() - 1)
    selector = mask_a.to(a.device).reshape(broadcast_shape)
    return torch.where(selector, a, b)


@torch.no_grad()
def zero_module(module):
    """Fill every parameter of *module* with zeros in place and return *module*."""
    for weight in module.parameters():
        detached = weight.detach()
        detached.zero_()
    return module


def load_last_state(model, folder='accelerator_output'):
    """Load the most recently modified ``model.safetensors`` under *folder* into *model*.

    The folder is searched recursively. When no file is found a message is
    printed and nothing is loaded. The state dict is applied with
    ``strict=False``; missing and unexpected keys are printed when non-empty.
    """
    candidates = glob.glob(os.path.join(folder, '**', 'model.safetensors'), recursive=True)

    if len(candidates) == 0:
        print("No model.safetensors files found in the specified folder.")
        return

    # Pick the checkpoint with the latest modification time.
    latest = max(candidates, key=os.path.getmtime)
    weights = sf.load_file(latest)
    missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)

    for label, keys in (("Missing keys:", missing_keys), ("Unexpected keys:", unexpected_keys)):
        if keys:
            print(label, keys)

    print("Loaded model state from:", latest)
    return


def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Return a random, shuffled subset of the ``', '``-separated tags in *tags_str*.

    The subset size is drawn uniformly from ``[min_length, max_length]`` and
    capped at the number of available tags; the chosen tags are joined back
    with ``', '``.
    """
    candidates = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    count = min(wanted, len(candidates))
    return ', '.join(random.sample(candidates, k=count))


def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Write a 5-D ``(b, c, t, h, w)`` tensor to *output_filename* as a tiled video.

    Values are clamped to ``[-1, 1]`` and mapped to ``uint8``. The batch is
    laid out as a grid with ``per_row`` items per row, where ``per_row`` is
    the largest of 6, 5, 4, 3, 2 that divides ``b`` (or ``b`` itself when
    none does). The parent directory is created if needed.

    Returns:
        The ``uint8`` frame tensor handed to ``torchvision.io.write_video``,
        arranged as ``t (m h) (n w) c``.
    """
    # Unpacking also enforces that x has exactly five dimensions.
    b, _c, _t, _h, _w = x.shape

    per_row = next((p for p in (6, 5, 4, 3, 2) if b % p == 0), b)

    resolved = os.path.abspath(os.path.realpath(output_filename))
    os.makedirs(os.path.dirname(resolved), exist_ok=True)

    frames = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    frames = frames.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='h264', options={'crf': '0'})
    return frames


def save_bcthw_as_png(x, output_filename):
    """Write a ``(b, c, t, h, w)`` tensor to *output_filename* as one PNG sheet.

    Values are clamped to ``[-1, 1]`` and mapped to ``uint8``. Batch items
    are stacked along the image height and time steps along the width. The
    parent directory is created if needed.

    Returns:
        *output_filename*, unchanged.
    """
    resolved = os.path.abspath(os.path.realpath(output_filename))
    os.makedirs(os.path.dirname(resolved), exist_ok=True)

    pixels = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    pixels = pixels.detach().cpu().to(torch.uint8)
    sheet = einops.rearrange(pixels, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(sheet, output_filename)
    return output_filename


def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors, zero-padding each at the end of every dim to a common shape.

    When the shapes already match this is plain ``tensor1 + tensor2``.
    Otherwise the output shape is the per-dimension maximum of both shapes
    and each input occupies the leading corner of that shape.

    The padded result uses the promoted dtype of both inputs and lives on
    ``tensor1``'s device, so it is consistent with the equal-shape branch
    instead of always being a float32 CPU tensor.

    Args:
        tensor1, tensor2: Tensors with the same number of dimensions
            (differing ranks fail in the slice assignment below).

    Returns:
        The element-wise sum over the padded common shape.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape

    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    # A single zero buffer is enough: copy the first operand into its corner,
    # then accumulate the second one into its own corner.
    result = torch.zeros(
        new_shape,
        dtype=torch.promote_types(tensor1.dtype, tensor2.dtype),
        device=tensor1.device,
    )
    result[tuple(slice(0, s) for s in shape1)] = tensor1
    result[tuple(slice(0, s) for s in shape2)] += tensor2.to(result.device)
    return result