Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,201 Bytes
d6bc023 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
import math
from typing import Callable, Tuple
import torch
def do_nothing(x, mode=None):
    """Identity stand-in for a merge/unmerge function.

    Returns ``x`` unchanged. ``mode`` is accepted only so that this can be
    called with the same signature as a real merge function; it is ignored.
    """
    del mode  # unused; kept for signature compatibility
    result = x
    return result
def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
    class_token: bool = False,
    distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with a balanced matching set (50%, 50%).

    Tokens are split by alternating position: even indices form the source
    set ``a`` and odd indices form the destination set ``b``. Each source
    token is matched to its most similar destination token (cosine
    similarity on ``metric``), and the ``r`` source tokens with the highest
    match scores are merged into their match.

    Input size is [batch, tokens, channels].
    r indicates the number of tokens to remove (max 50% of tokens).

    Extra args:
     - class_token: Whether or not there's a class token (expected at index 0).
     - distill_token: Whether or not there's also a distillation token
       (expected at index 1).

    When enabled, the class token and distillation tokens won't get merged.

    Returns:
        ``(merge, unmerge)`` closures. ``merge(x, mode)`` reduces ``x`` from
        ``tokens`` to ``tokens - r`` entries along dim 1, combining merged
        tokens with ``scatter_reduce`` using ``mode``. ``unmerge(x)`` maps a
        merged tensor back to ``tokens`` entries, giving every merged-away
        position the value of the destination it was merged into.
    """
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        if class_token:
            # Row 0 of ``a`` is token 0: a -inf score ranks it last, so it
            # always stays in the unmerged set.
            scores[..., 0, :] = -math.inf
        if distill_token:
            # Column 0 of ``b`` is token 1: no source may be merged into it.
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            # Keep class token (unm[0]) and distillation token (dst[0]) in front.
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        if distill_token:
            # merge() emitted [unm0, dst0, unm1.., dst1..]; restore the
            # [unm.., dst..] order so the split below lines up with the indices.
            x = torch.cat(
                [x[:, :1], x[:, 2 : unm_len + 1], x[:, 1:2], x[:, unm_len + 1 :]],
                dim=1,
            )
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        # Each merged-away source takes the value of the destination it joined.
        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, t, c, device=x.device, dtype=x.dtype)
        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)
        return out

    return merge, unmerge
def kth_bipartite_soft_matching(
    metric: torch.Tensor, k: int
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with the two sets as (every kth element, the rest).
    If n is the number of tokens, resulting number of tokens will be n // k.
    Input size is [batch, tokens, channels].

    k is the window size: tokens are grouped into consecutive windows of k,
    the last token of each window is a destination, and the other k - 1
    tokens are sources. Every source is merged into its most similar
    destination (cosine similarity on ``metric``). k <= 1 disables merging.
    k = 2 is equivalent to regular bipartite_soft_matching with r = 0.5 * N.

    Tokens past the last complete window (n % k of them) are dropped by
    ``merge`` and are not produced by ``unmerge``.
    """
    if k <= 1:
        return do_nothing, do_nothing

    def split(x):
        # Drop the incomplete trailing window, then view as [B, windows, k, C].
        t_rnd = (x.shape[1] // k) * k
        x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2])
        a, b = (
            x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]),
            x[:, :, (k - 1), :],
        )
        return a, b

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        r = a.shape[1]
        scores = a @ b.transpose(-1, -2)

        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, _, c = src.shape
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
        return dst

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        n, _, c = x.shape
        dst = x
        # Every source position takes the value of the destination it merged into.
        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))
        # Re-interleave as [k - 1 sources, 1 destination] per window.
        src = src.view(n, -1, (k - 1), c)
        dst = dst.view(n, -1, 1, c)
        out = torch.cat([src, dst], dim=-2)
        out = out.contiguous().view(n, -1, c)
        return out

    return merge, unmerge
def random_bipartite_soft_matching(
    metric: torch.Tensor, r: int
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with the two sets as (r chosen randomly, the rest).
    Input size is [batch, tokens, channels].
    This will reduce the number of tokens by r.

    For each batch element, r tokens are drawn at random as sources and the
    remaining tokens are destinations. Every source is merged into its most
    similar destination (cosine similarity on ``metric``). ``r`` is clamped
    to ``tokens - 1`` so at least one destination remains; when the
    effective ``r`` is <= 0 no merging is done.
    """
    if r <= 0:
        return do_nothing, do_nothing

    B, N, _ = metric.shape
    # Keep at least one destination token: with an empty destination set the
    # max over similarity scores below would have nothing to reduce over.
    r = min(r, N - 1)
    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)
        a_idx = rand_idx[:, :r, :]
        b_idx = rand_idx[:, r:, :]

        def split(x):
            C = x.shape[-1]
            a = x.gather(dim=1, index=a_idx.expand(B, r, C))
            b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
            return a, b

        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        C = src.shape[-1]
        dst = dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode)
        return dst

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        C = x.shape[-1]
        dst = x
        # Each source position takes the value of the destination it joined.
        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, C))

        out = torch.zeros(B, N, C, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=a_idx.expand(B, r, C), src=src)
        out.scatter_(dim=-2, index=b_idx.expand(B, N - r, C), src=dst)
        return out

    return merge, unmerge
def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge tokens with a size-weighted average.

    Each token's features are scaled by its size, summed through ``merge``
    (called with ``mode="sum"``), and divided by the summed sizes. When
    ``size`` is None every token starts with a size of 1.

    Returns:
        ``(merged_x, merged_size)``.
    """
    weights = torch.ones_like(x[..., 0, None]) if size is None else size
    weighted_total = merge(x * weights, mode="sum")
    total_size = merge(weights, mode="sum")
    return weighted_total / total_size, total_size
def merge_source(
    merge: Callable, x: torch.Tensor, source: torch.Tensor = None
) -> torch.Tensor:
    """
    For source tracking. Source is an adjacency matrix between the initial
    tokens and final merged groups.

    x is used to find out how many tokens there are in case the source is
    None; an identity matrix (one group per token) is then used as the
    starting point for every batch element.

    The source rows are combined with ``merge(..., mode="amax")``, so a merged
    group's row marks every initial token that contributed to it.
    """
    if source is None:
        n, t, _ = x.shape
        source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)

    source = merge(source, mode="amax")
    return source