# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Arithmetic coder."""
import io
import math
import random
import typing as tp

import torch

from academicodec.binary import BitPacker
from academicodec.binary import BitUnpacker


def build_stable_quantized_cdf(pdf: torch.Tensor,
                               total_range_bits: int,
                               roundoff: float=1e-8,
                               min_range: int=2,
                               check: bool=True) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove difference coming
        from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior is a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
    """
    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    if check:
        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
            ).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
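
# Illustrative sketch (an assumption about typical usage, not executed here): how a
# model's probability estimate could be turned into a quantized CDF. `logits` is a
# made-up placeholder for the model output.
#
#   pdf = torch.softmax(logits, dim=0)            # `[N]` probability vector
#   q_cdf = build_stable_quantized_cdf(pdf, 24)   # monotonically increasing LongTensor `[N]`
#   # q_cdf[-1] <= 2 ** 24 and every bucket is at least `min_range` wide.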


class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we mostly encode likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the more efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written.
        total_range_bits (int): the range described above is `2 ** total_range_bits`.
            Any time the current range width falls under this limit, new bits will
            be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int=24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
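        # Following is for debugging.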
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the same bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2**(self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= (b1 << self.max_bit)
                self.high -= (b1 << self.max_bit)
                assert self.high >= self.low, (self.high, self.low,
                                               self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        while self.delta < 2**self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

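        # Take the symbol's slice of the quantized CDF, rescale it to the current
        # range width, and narrow [self.low, self.high] down to that slice.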
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(
            math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
        effective_high = int(
            math.floor(range_high * (self.delta / (2**self.total_range_bits))))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (effective_low, effective_high, range_low,
                                       range_high)
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        outs = self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit
        return outs

    def flush(self):
        """Flush the remaining information to the stream.
        """
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()


class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having some common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int=24):
        self.total_range_bits = total_range_bits
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(
            bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= (b1 << self.max_bit)
                self.high -= (b1 << self.max_bit)
                self.current -= (b1 << self.max_bit)
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.
        """
        while self.delta < 2**self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search is not just for coding interviews :)
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(
                math.ceil(range_low * (self.delta / (2**self.total_range_bits)
                                       )))
            effective_high = int(
                math.floor(range_high * (self.delta / (2**self.total_range_bits)
                                         )))
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return (mid, low, high, self.current)
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(
            0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))

        return sym


def test():
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        assert decoder.pull(torch.zeros(1)) is None


if __name__ == "__main__":
    test()