Upload folder using huggingface_hub
- configuration_act_estimator.py +29 -0
- model.py +648 -0
- modeling_act_estimator.py +36 -0
configuration_act_estimator.py
ADDED
@@ -0,0 +1,29 @@
from transformers import PretrainedConfig


class ActEstimatorConfig(PretrainedConfig):
    model_type = "ACT-Estimator"

    def __init__(
        self,
        input_shape=(3, 44, 224, 224),
        num_classes=9,
        max_seq_len=44,
        timestamp_dim=1,
        d_model=512,
        num_heads=8,
        dropout=0.1,
        feature_map_size=4,
        **kwargs
    ):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.max_seq_len = max_seq_len
        self.timestamp_dim = timestamp_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = dropout
        self.feature_map_size = feature_map_size
        super().__init__(**kwargs)
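Usage sketch (not part of the uploaded files): instantiating the config above and round-tripping it through the standard PretrainedConfig API. save_pretrained() writes the config.json that from_pretrained() later reads; the output directory name is only an illustration.

from configuration_act_estimator import ActEstimatorConfig

# Build a config with the defaults above and serialize it;
# save_pretrained() writes config.json into the target directory.
config = ActEstimatorConfig()
config.save_pretrained("act-estimator-checkpoint")  # illustrative path

# Round-trip: reload and inspect the stored hyperparameters.
reloaded = ActEstimatorConfig.from_pretrained("act-estimator-checkpoint")
print(reloaded.model_type, reloaded.d_model, reloaded.num_classes)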
model.py
ADDED
@@ -0,0 +1,648 @@
import math
from collections.abc import Sequence

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from transformers import AutoModel, PreTrainedModel


class MaxPool3dSamePadding(nn.MaxPool3d):
    def compute_pad(self, dim, s):
        if s % self.stride[dim] == 0:
            return max(self.kernel_size[dim] - self.stride[dim], 0)
        else:
            return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)

    def forward(self, x):
        (batch, channel, t, h, w) = x.size()
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)
        return super().forward(x)


class Unit3D(nn.Module):
    def __init__(
        self,
        in_channels,
        output_channels,
        kernel_shape=(1, 1, 1),
        stride=(1, 1, 1),
        padding=0,
        activation_fn=F.relu,
        use_batch_norm=True,
        use_bias=False,
        name="unit_3d",
    ):
        """Initializes Unit3D module."""
        super().__init__()

        self._output_channels = output_channels
        self._kernel_shape = kernel_shape
        self._stride = stride
        self._use_batch_norm = use_batch_norm
        self._activation_fn = activation_fn
        self._use_bias = use_bias
        self.name = name
        self.padding = padding

        self.conv3d = nn.Conv3d(
            in_channels=in_channels,
            out_channels=self._output_channels,
            kernel_size=self._kernel_shape,
            stride=self._stride,
            padding=0,  # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
            bias=self._use_bias,
        )

        if self._use_batch_norm:
            self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)

    def compute_pad(self, dim, s):
        if s % self._stride[dim] == 0:
            return max(self._kernel_shape[dim] - self._stride[dim], 0)
        else:
            return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)

    def forward(self, x):
        (batch, channel, t, h, w) = x.size()
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)

        x = self.conv3d(x)
        if self._use_batch_norm:
            x = self.bn(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x


class InceptionModule(nn.Module):
    def __init__(self, in_channels, out_channels, name):
        super().__init__()

        self.b0 = Unit3D(
            in_channels=in_channels,
            output_channels=out_channels[0],
            kernel_shape=[1, 1, 1],
            padding=0,
            name=name + "/Branch_0/Conv3d_0a_1x1",
        )
        self.b1a = Unit3D(
            in_channels=in_channels,
            output_channels=out_channels[1],
            kernel_shape=[1, 1, 1],
            padding=0,
            name=name + "/Branch_1/Conv3d_0a_1x1",
        )
        self.b1b = Unit3D(
            in_channels=out_channels[1],
            output_channels=out_channels[2],
            kernel_shape=[3, 3, 3],
            name=name + "/Branch_1/Conv3d_0b_3x3",
        )
        self.b2a = Unit3D(
            in_channels=in_channels,
            output_channels=out_channels[3],
            kernel_shape=[1, 1, 1],
            padding=0,
            name=name + "/Branch_2/Conv3d_0a_1x1",
        )
        self.b2b = Unit3D(
            in_channels=out_channels[3],
            output_channels=out_channels[4],
            kernel_shape=[3, 3, 3],
            name=name + "/Branch_2/Conv3d_0b_3x3",
        )
        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(1, 1, 1), padding=0)
        self.b3b = Unit3D(
            in_channels=in_channels,
            output_channels=out_channels[5],
            kernel_shape=[1, 1, 1],
            padding=0,
            name=name + "/Branch_3/Conv3d_0b_1x1",
        )
        self.name = name

    def forward(self, x):
        b0 = self.b0(x)
        b1 = self.b1b(self.b1a(x))
        b2 = self.b2b(self.b2a(x))
        b3 = self.b3b(self.b3a(x))
        return torch.cat([b0, b1, b2, b3], dim=1)


class InceptionI3d(nn.Module):
    """Inception-v1 I3D architecture.
    The model is introduced in:
        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
        Joao Carreira, Andrew Zisserman
        https://arxiv.org/pdf/1705.07750v1.pdf.
    See also the Inception architecture, introduced in:
        Going deeper with convolutions
        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
        http://arxiv.org/pdf/1409.4842v1.pdf.
    """

    # Endpoints of the model in order. During construction, all the endpoints up
    # to a designated `final_endpoint` are returned in a dictionary as the
    # second return value.
    VALID_ENDPOINTS = (
        "Conv3d_1a_7x7",
        "MaxPool3d_2a_3x3",
        "Conv3d_2b_1x1",
        "Conv3d_2c_3x3",
        "MaxPool3d_3a_3x3",
        "Mixed_3b",
        "Mixed_3c",
        "MaxPool3d_4a_3x3",
        "Mixed_4b",
        "Mixed_4c",
        "Mixed_4d",
        "Mixed_4e",
        "Mixed_4f",
        "MaxPool3d_5a_2x2",
        "Mixed_5b",
        "Mixed_5c",
        "Logits",
        "Predictions",
    )

    def __init__(
        self,
        time_spatial_squeeze=True,
        final_endpoint="Logits",
        name="inception_i3d",
        in_channels=3,
    ):
        """Initializes I3D model instance.
        Args:
            num_classes: The number of outputs in the logit layer (default 400, which
                matches the Kinetics dataset).
            spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
                before returning (default True).
            final_endpoint: The model contains many possible endpoints.
                `final_endpoint` specifies the last endpoint for the model to be built
                up to. In addition to the output at `final_endpoint`, all the outputs
                at endpoints up to `final_endpoint` will also be returned, in a
                dictionary. `final_endpoint` must be one of
                InceptionI3d.VALID_ENDPOINTS (default 'Logits').
            name: A string (optional). The name of this module.
        Raises:
            ValueError: if `final_endpoint` is not recognized.
        """

        if final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError(f"Unknown final endpoint {final_endpoint}")

        super().__init__()
        self._time_spatial_squeeze = time_spatial_squeeze
        self._final_endpoint = final_endpoint
        self.logits = None

        if self._final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError(f"Unknown final endpoint {self._final_endpoint}")

        self.end_points = {}
        end_point = "Conv3d_1a_7x7"
        self.end_points[end_point] = Unit3D(
            in_channels=in_channels,
            output_channels=64,
            kernel_shape=[7, 7, 7],
            stride=(2, 2, 2),
            padding=(3, 3, 3),
            name=name + end_point,
        )
        if self._final_endpoint == end_point:
            return

        end_point = "MaxPool3d_2a_3x3"
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
        if self._final_endpoint == end_point:
            return

        end_point = "Conv3d_2b_1x1"
        self.end_points[end_point] = Unit3D(
            in_channels=64,
            output_channels=64,
            kernel_shape=[1, 1, 1],
            padding=0,
            name=name + end_point,
        )
        if self._final_endpoint == end_point:
            return

        end_point = "Conv3d_2c_3x3"
        self.end_points[end_point] = Unit3D(
            in_channels=64,
            output_channels=192,
            kernel_shape=[3, 3, 3],
            padding=1,
            name=name + end_point,
        )
        if self._final_endpoint == end_point:
            return

        end_point = "MaxPool3d_3a_3x3"
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_3b"
        self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_3c"
        self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
        if self._final_endpoint == end_point:
            return

        end_point = "MaxPool3d_4a_3x3"
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_4b"
        self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_4c"
        self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_4d"
        self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_4e"
        self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_4f"
        self.end_points[end_point] = InceptionModule(
            112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], name + end_point
        )
        if self._final_endpoint == end_point:
            return

        end_point = "MaxPool3d_5a_2x2"
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 2, 2], stride=(1, 2, 2), padding=0)
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_5b"
        self.end_points[end_point] = InceptionModule(
            256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], name + end_point
        )
        if self._final_endpoint == end_point:
            return

        end_point = "Mixed_5c"
        self.end_points[end_point] = InceptionModule(
            256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], name + end_point
        )

        if self._final_endpoint == end_point:
            return

        self.build()

    def build(self):
        for k in self.end_points.keys():
            self.add_module(k, self.end_points[k])

    def get_out_size(self, shape: Sequence[int], dim=None) -> int:
        device = next(self.parameters()).device
        out = self(torch.zeros((1, *shape), device=device))
        return out.size(dim)

    def forward(self, x):
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)  # use _modules to work with dataparallel
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[:, : x.size(1), :]
        return x


class CrossAttention(nn.Module):
    def __init__(self, dim_q, dim_k, dim_v, dim_out, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim_out // num_heads
        assert dim_out % num_heads == 0, "dim_out must be divisible by num_heads"
        self.scale = self.head_dim**-0.5

        self.query_proj = nn.Linear(dim_q, dim_out)
        self.key_proj = nn.Linear(dim_k, dim_out)
        self.value_proj = nn.Linear(dim_v, dim_out)

        self.out_proj = nn.Linear(dim_out, dim_out)

    def forward(self, query, key, value):
        # Linear transformation of query, key, and value
        q = self.query_proj(query)  # shape: (batch_size, query_len, dim_out)
        k = self.key_proj(key)  # shape: (batch_size, key_len, dim_out)
        v = self.value_proj(value)  # shape: (batch_size, value_len, dim_out)

        # Split dimensions for multi-head attention, and compute per head
        # print("q:", q.size(), "k:", k.size(), "v:", v.size())
        q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = attn_weights.softmax(dim=-1)

        # Multiply attention weights with values
        attn_output = torch.matmul(attn_weights, v)

        # Concatenate results and return to original dimensions
        attn_output = attn_output.transpose(1, 2).reshape(v.size(0), -1, self.num_heads * self.head_dim)
        output = self.out_proj(attn_output)

        return output, attn_weights


class FeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class PreGRULayer(nn.Module):
    def __init__(
        self,
        d_model,
        num_heads,
        ffn_hidden,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.pre_norm0 = nn.LayerNorm(d_model)
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout0 = nn.Dropout(dropout)

        self.pre_norm1 = nn.LayerNorm(d_model)
        self.cross_attention = CrossAttention(
            dim_q=d_model,
            dim_k=d_model,
            dim_v=d_model,
            dim_out=d_model,
            num_heads=num_heads,
        )
        self.dropout1 = nn.Dropout(dropout)

        self.pre_norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, ffn_hidden)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, q, x) -> torch.Tensor:
        """
        Expected shapes:
            - q: (b, 1, dim_q)
            - x: (b, seq, dim_kv)
        Output shape:
            (b, seq, d_model)
        """

        # cross attention
        _x = x
        x = self.pre_norm1(x)
        x, _ = self.cross_attention(query=q, key=x, value=x)
        x = self.dropout1(x)
        x = x + _x

        # self attention
        _x = x
        x = self.pre_norm0(x)
        x, _ = self.self_attention(query=x, key=x, value=x)
        x = self.dropout0(x)
        x = x + _x

        # pairwise feed forward
        _x = x
        x = self.pre_norm2(x)
        x = self.ffn(x)
        x = self.dropout2(x)
        x = x + _x

        return x


class VariableLengthWaypointPredictor(nn.Module):
    """Variable-length GRU-based waypoint predictor with optional timestamp inputs."""

    def __init__(
        self,
        d_model,
        memory_seq_len,
        timestamp_dim=0,
        waypoint_dim=2,
        num_heads=4,
        start_from_origin=True,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.waypoint_dim = waypoint_dim
        self.start_from_origin = start_from_origin

        self.hidden_state = nn.Parameter(torch.randn(1, d_model))
        self.pos_embedding = nn.Parameter(torch.randn(1, memory_seq_len, d_model))

        self.pre_gru_layer = PreGRULayer(
            d_model=d_model,
            num_heads=num_heads,
            ffn_hidden=d_model // 2,
        )
        self.gru = nn.GRUCell(
            input_size=waypoint_dim + d_model + timestamp_dim,
            hidden_size=d_model,
        )
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.Dropout(p=dropout),
            nn.ReLU(),
            nn.Linear(d_model // 2, waypoint_dim),  # wp_dim
        )

    def forward(
        self,
        memory: Tensor,  # (b, t, c)
        num_waypoints: int,
        timestamps: Tensor = None,
    ) -> Tensor:
        batch_size = memory.shape[0]
        dtype = memory.dtype

        wp = memory.new_zeros((batch_size, self.waypoint_dim))
        h = self.hidden_state.repeat(batch_size, 1).to(dtype)
        pos_embedding = self.pos_embedding.repeat(batch_size, 1, 1).to(dtype)
        memory = memory + pos_embedding

        waypoints = []
        if self.start_from_origin:
            # add first waypoint as zero origin
            waypoints.append(memory.new_zeros((batch_size, self.waypoint_dim)))
            num_waypoints = num_waypoints - 1

        for t in range(num_waypoints):
            inputs = self.pre_gru_layer(q=h.unsqueeze(1), x=memory)  # (b, t, c)
            inputs = inputs.mean(1)  # (b, c)
            inputs = torch.cat([wp, inputs], dim=1)

            if timestamps is not None:
                inputs = torch.cat([inputs, timestamps[:, t].reshape(batch_size, -1)], dim=1)

            h = self.gru(inputs, h)
            dx = self.head(h)
            wp = wp + dx
            waypoints.append(wp)

        waypoints = torch.stack(waypoints, dim=1)  # (b, n_wps, wp_dim)

        return waypoints


class VideoActionEstimator(nn.Module):
    def __init__(
        self,
        input_shape,
        num_classes,
        max_seq_len=44,
        timestamp_dim=0,
        d_model=512,
        num_heads=8,
        dropout=0.1,
        feature_map_size=4,
        **kwargs,
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.timestamp_dim = timestamp_dim
        assert input_shape[1] == max_seq_len

        self.backbone = InceptionI3d()
        feature_dim, seq_len = self.backbone.get_out_size(input_shape)[1:3]

        self.avg_pool = nn.AdaptiveAvgPool3d((None, feature_map_size, feature_map_size))
        memory_seq_len = seq_len * feature_map_size**2

        self.squeeze_linear = nn.Linear(feature_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model=d_model, max_len=memory_seq_len)
        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=512,
            batch_first=True,
            activation=F.gelu,
        )
        self.self_attn = TransformerEncoder(
            encoder_layer,
            num_layers=2,
        )

        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(p=dropout),
            nn.GELU(),
            nn.Linear(d_model, num_classes),
        )
        self.visual_odmetry = VariableLengthWaypointPredictor(
            d_model=d_model,
            memory_seq_len=memory_seq_len,
            waypoint_dim=2,  # x, y axes
            timestamp_dim=timestamp_dim,
            num_heads=num_heads,
        )

    def forward(self, frames: Tensor, timestamps: Tensor = None) -> dict[str, Tensor]:
        x = frames
        num_frames = x.size(2)  # seq which must be consistent in a batch
        assert (
            num_frames <= self.max_seq_len
        ), f"Input sequence length (={num_frames}) exceeds max_seq_len (={self.max_seq_len})"

        x = self.backbone(x)  # (b, 1024, 11, 7, 7)
        x = self.avg_pool(x)  # (b, 1024, 11, 4, 4)

        b, c, t, h, w = x.size()
        x = x.view(b, t * h * w, c)  # (b, 176, 1024)
        x = self.squeeze_linear(x)  # (b, 176, 512)
        x = self.positional_encoding(x)

        x = self.self_attn(x)  # (b, 176, 512)
        latent_tensor = x.mean(1)  # (b, 512)
        logits = self.classifier(latent_tensor)
        waypoints = self.visual_odmetry(x, num_frames, timestamps=timestamps)

        return {
            "command": logits,
            "waypoints": waypoints,
        }
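A quick smoke test of model.py (a sketch, not part of the upload): building VideoActionEstimator with the same values that ActEstimatorConfig uses and running one forward pass on a dummy clip. With a (3, 44, 224, 224) input, the I3D backbone produces a (1024, 11, 7, 7) feature map, so the transformer sees 11 * 4 * 4 = 176 tokens; the outputs are "command" logits of shape (1, 9) and "waypoints" of shape (1, 44, 2).

import torch
from model import VideoActionEstimator

model = VideoActionEstimator(
    input_shape=(3, 44, 224, 224),  # (channels, frames, height, width)
    num_classes=9,
    max_seq_len=44,
    timestamp_dim=1,  # must match the timestamps tensor passed below
)
model.eval()

frames = torch.randn(1, 3, 44, 224, 224)  # dummy video clip
timestamps = torch.randn(1, 44)           # one timestamp per frame
with torch.no_grad():
    out = model(frames, timestamps=timestamps)

print(out["command"].shape)    # torch.Size([1, 9])
print(out["waypoints"].shape)  # torch.Size([1, 44, 2])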
modeling_act_estimator.py
ADDED
@@ -0,0 +1,36 @@
import torch
from torch import Tensor
from transformers import PreTrainedModel

from model import VideoActionEstimator
from configuration_act_estimator import ActEstimatorConfig


class ActEstimator(PreTrainedModel):
    config_class = ActEstimatorConfig

    def __init__(self, config: ActEstimatorConfig):
        super().__init__(config)
        self.model = VideoActionEstimator(**config.to_dict())

    def forward(self, frames: Tensor, timestamps: Tensor = None) -> dict[str, Tensor]:
        return self.model(frames, timestamps)


# actestimator_config = ActEstimatorConfig.from_pretrained(".")
# print(actestimator_config.to_dict())

# actestimator_model = ActEstimator(actestimator_config)

# state_dict = torch.load("ckpt.pth", weights_only=True)
# actestimator_model.model.load_state_dict(state_dict)

# print(actestimator_model)
# actestimator_model.save_pretrained(".")


# model = ActEstimator.from_pretrained(".")
# print(model(torch.randn(1, 3, 44, 224, 224), torch.randn(1, 44)))
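An optional follow-up sketch (not part of the upload, and the repository id and output path are placeholders): registering the custom classes with the transformers Auto* factories so the checkpoint can later be loaded with trust_remote_code=True instead of importing these modules by hand.

from configuration_act_estimator import ActEstimatorConfig
from modeling_act_estimator import ActEstimator

# Register the custom classes with the Auto* factories; save_pretrained()
# then records an "auto_map" entry in config.json so users can load the
# checkpoint without importing these modules manually.
ActEstimatorConfig.register_for_auto_class()
ActEstimator.register_for_auto_class("AutoModel")

config = ActEstimatorConfig()
model = ActEstimator(config)
model.save_pretrained("act-estimator-checkpoint")  # illustrative local path

# From any environment that trusts the repo's custom code (repo id is a placeholder):
# from transformers import AutoModel
# model = AutoModel.from_pretrained("<user>/ACT-Estimator", trust_remote_code=True)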