File size: 6,668 Bytes
7ef93e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# this code is adapted from the script contributed by anon from /h/
# modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py

import io
import pickle
import collections
import sys
import traceback

import torch
import numpy
import _codecs
import zipfile
import re


# torch.storage calls this class TypedStorage from PyTorch 1.13 onwards; earlier releases name it _TypedStorage
if hasattr(torch.storage, 'TypedStorage'):
    TypedStorage = torch.storage.TypedStorage
else:
    TypedStorage = torch.storage._TypedStorage


def encode(*args):
    """Forward all positional arguments to `_codecs.encode` and return its result."""
    return _codecs.encode(*args)


class RestrictedUnpickler(pickle.Unpickler):
    """
    Unpickler that resolves only an explicit allowlist of globals.

    find_class is the hook pickle calls for every global a stream references; any
    (module, name) pair that is neither accepted by extra_handler nor listed below
    raises, which aborts the load.
    """

    # Optional callable taking (module, name). When it returns something other than None,
    # that value is used as the resolved global and the built-in allowlist is skipped.
    extra_handler = None

    def persistent_load(self, saved_id):
        """
        Resolve a persistent id; only ids whose first element is 'storage' are accepted.

        Returns a TypedStorage() constructed with no arguments (no data is read for it here).
        Raises AssertionError for any other kind of persistent id.
        """
        # Raised explicitly instead of using `assert`: assert statements are removed under
        # `python -O`, which would silently drop this check on untrusted input.
        # AssertionError is kept so the exception type seen by callers does not change.
        if saved_id[0] != 'storage':
            raise AssertionError(f"unexpected persistent id {saved_id[0]!r}, expected 'storage'")

        return TypedStorage()

    def find_class(self, module, name):
        """
        Return the object a pickle stream refers to as `module`.`name`, if it is allowed.

        extra_handler (when set) is asked first; otherwise the pair must match one of the
        entries below. Raises Exception for every global that is not allowed.
        """
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            # hand out this module's own encode() wrapper rather than _codecs.encode itself
            return encode
        # pytorch_lightning is imported lazily, only when a file actually references it
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        # '__builtin__' is the Python 2 name of the builtins module
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'.
# The patterns end in \Z rather than $: with .match(), $ also matches just before a trailing
# newline, so a member named 'dirname/data.pkl\n' would otherwise pass the check.
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))\Z")
# Matches only the pickle member itself: '<directory name>/data.pkl'
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl\Z")

def check_zip_filenames(filename, names):
    """
    Raise Exception on the first entry of `names` that allowed_zip_names_re does not match.

    `filename` is only used in the error message. Returns None when every name is accepted.
    """
    for member in names:
        if not allowed_zip_names_re.match(member):
            raise Exception(f"bad file inside {filename}: {member}")


def check_pt(filename, extra_handler):
    """
    Run the pickled contents of `filename` through RestrictedUnpickler.

    `extra_handler` is assigned to the unpickler's extra_handler attribute before loading.
    Returns None; any failure surfaces as an exception raised by the checks or the unpickler.
    """
    try:
        # the new pytorch format is a zip archive
        with zipfile.ZipFile(filename) as archive:
            members = archive.namelist()
            check_zip_filenames(filename, members)

            # locate the single '<directory name>/data.pkl' member
            pickle_members = [member for member in members if data_pkl_re.match(member)]
            if len(pickle_members) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(pickle_members) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")

            with archive.open(pickle_members[0]) as stream:
                checker = RestrictedUnpickler(stream)
                checker.extra_handler = extra_handler
                checker.load()

    except zipfile.BadZipfile:
        # not a zip file, so it is the old pytorch format: five objects pickled back to back
        with open(filename, "rb") as stream:
            checker = RestrictedUnpickler(stream)
            checker.extra_handler = extra_handler
            for _ in range(5):
                checker.load()


def load(filename, *args, **kwargs):
    """
    Verify `filename` with the restricted unpickler, then load it with the original torch.load.

    This module assigns this function to torch.load. The handler currently stored in
    global_extra_handler is used as the extra handler; every other positional and keyword
    argument is forwarded unchanged. Returns whatever load_with_extra returns, which is
    None when verification fails.
    """
    # The handler is passed positionally on purpose. Passing it as extra_handler=... together
    # with *args made the first extra positional argument (e.g. map_location) bind to
    # load_with_extra's extra_handler parameter as well, raising
    # "TypeError: got multiple values for argument 'extra_handler'".
    return load_with_extra(filename, global_extra_handler, *args, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    Verify a checkpoint with check_pt and, if it passes, load it with unsafe_torch_load.

    Intended for extensions whose models contain extra classes that the default
    restricted unpickler would reject. Pass `extra_handler`, a function that receives the
    module and the field name as text and returns the field's value, or None to decline:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    Remaining positional and keyword arguments are forwarded to unsafe_torch_load.
    If verification raises, the error is reported on stderr and None is returned.

    The alternative is to call safe.unsafe_torch_load('model.pt') directly, which, as the
    name says, performs no checking at all.
    """

    try:
        check_pt(filename, extra_handler)
    except Exception as err:
        report = [
            f"Error verifying pickled file from {filename}:",
            traceback.format_exc(),
        ]
        if isinstance(err, pickle.UnpicklingError):
            # the pickle stream itself could not be parsed
            report.append("The file is most likely corrupted.")
        else:
            report.append("\nThe file may be malicious, so the program is not going to read it.")
            report.append("You can skip this check with --disable-safe-unpickle commandline argument.\n\n")

        for text in report:
            print(text, file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    Context manager that installs `handler` as the module-wide extra handler for the
    duration of a `with` block.

    Meant for cases where the torch.load call is made by code you do not control, so
    load_with_extra cannot be called explicitly. Example:

        import torch
        from modules import safe

        def handler(module, name):
            if module == 'torch' and name in ['float64', 'float16']:
                return getattr(torch, name)

            return None

        with safe.Extra(handler):
            x = torch.load('model.pt')

    Blocks cannot be nested: entering while a handler is already installed fails an assertion.
    """

    def __init__(self, handler):
        # stored only; nothing is installed until the block is entered
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        # refuse to overwrite a handler installed by an enclosing block
        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        # clear the handler unconditionally; returning None lets any exception propagate
        global_extra_handler = None


# Keep a reference to the original torch.load; load_with_extra calls it once a file has passed verification.
unsafe_torch_load = torch.load
# Replace torch.load with the checking wrapper, so code that looks up torch.load from here on gets load().
torch.load = load
# Handler that load() passes to load_with_extra; None until an Extra() block installs one.
global_extra_handler = None