keishihara committed on
Commit 3d729d6 · verified · 1 Parent(s): 05a7b1d

Upload folder using huggingface_hub

configuration_act_estimator.py ADDED
@@ -0,0 +1,29 @@
+from transformers import PretrainedConfig
+
+
+class ActEstimatorConfig(PretrainedConfig):
+    model_type = "ACT-Estimator"
+
+    def __init__(
+        self,
+        input_shape=(3, 44, 224, 224),
+        num_classes=9,
+        max_seq_len=44,
+        timestamp_dim=1,
+        d_model=512,
+        num_heads=8,
+        dropout=0.1,
+        feature_map_size=4,
+        **kwargs
+    ):
+        self.input_shape = input_shape
+        self.num_classes = num_classes
+        self.max_seq_len = max_seq_len
+        self.timestamp_dim = timestamp_dim
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.feature_map_size = feature_map_size
+        super().__init__(**kwargs)
+
+
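A minimal sketch of how this config is used (not part of the uploaded files; the directory name is illustrative): the defaults describe a 44-frame, 224x224 RGB clip, 9 command classes, and one scalar timestamp per frame, and they round-trip through config.json in the usual transformers way.

from configuration_act_estimator import ActEstimatorConfig

config = ActEstimatorConfig()                    # defaults as defined above
print(config.input_shape)                        # (3, 44, 224, 224)
config.save_pretrained("act_estimator_config")   # writes config.json (path is illustrative)
reloaded = ActEstimatorConfig.from_pretrained("act_estimator_config")
print(reloaded.num_classes)                      # 9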
model.py ADDED
@@ -0,0 +1,648 @@
+import math
+from collections.abc import Sequence
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+from transformers import AutoModel, PreTrainedModel
+
+
+class MaxPool3dSamePadding(nn.MaxPool3d):
+    def compute_pad(self, dim, s):
+        if s % self.stride[dim] == 0:
+            return max(self.kernel_size[dim] - self.stride[dim], 0)
+        else:
+            return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
+
+    def forward(self, x):
+        (batch, channel, t, h, w) = x.size()
+        pad_t = self.compute_pad(0, t)
+        pad_h = self.compute_pad(1, h)
+        pad_w = self.compute_pad(2, w)
+
+        pad_t_f = pad_t // 2
+        pad_t_b = pad_t - pad_t_f
+        pad_h_f = pad_h // 2
+        pad_h_b = pad_h - pad_h_f
+        pad_w_f = pad_w // 2
+        pad_w_b = pad_w - pad_w_f
+
+        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+        x = F.pad(x, pad)
+        return super().forward(x)
+
+
+class Unit3D(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        output_channels,
+        kernel_shape=(1, 1, 1),
+        stride=(1, 1, 1),
+        padding=0,
+        activation_fn=F.relu,
+        use_batch_norm=True,
+        use_bias=False,
+        name="unit_3d",
+    ):
+        """Initializes Unit3D module."""
+        super().__init__()
+
+        self._output_channels = output_channels
+        self._kernel_shape = kernel_shape
+        self._stride = stride
+        self._use_batch_norm = use_batch_norm
+        self._activation_fn = activation_fn
+        self._use_bias = use_bias
+        self.name = name
+        self.padding = padding
+
+        self.conv3d = nn.Conv3d(
+            in_channels=in_channels,
+            out_channels=self._output_channels,
+            kernel_size=self._kernel_shape,
+            stride=self._stride,
+            padding=0,  # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
+            bias=self._use_bias,
+        )
+
+        if self._use_batch_norm:
+            self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)
+
+    def compute_pad(self, dim, s):
+        if s % self._stride[dim] == 0:
+            return max(self._kernel_shape[dim] - self._stride[dim], 0)
+        else:
+            return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
+
+    def forward(self, x):
+        (batch, channel, t, h, w) = x.size()
+        pad_t = self.compute_pad(0, t)
+        pad_h = self.compute_pad(1, h)
+        pad_w = self.compute_pad(2, w)
+
+        pad_t_f = pad_t // 2
+        pad_t_b = pad_t - pad_t_f
+        pad_h_f = pad_h // 2
+        pad_h_b = pad_h - pad_h_f
+        pad_w_f = pad_w // 2
+        pad_w_b = pad_w - pad_w_f
+
+        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+        x = F.pad(x, pad)
+
+        x = self.conv3d(x)
+        if self._use_batch_norm:
+            x = self.bn(x)
+        if self._activation_fn is not None:
+            x = self._activation_fn(x)
+        return x
+
+
+class InceptionModule(nn.Module):
+    def __init__(self, in_channels, out_channels, name):
+        super().__init__()
+
+        self.b0 = Unit3D(
+            in_channels=in_channels,
+            output_channels=out_channels[0],
+            kernel_shape=[1, 1, 1],
+            padding=0,
+            name=name + "/Branch_0/Conv3d_0a_1x1",
+        )
+        self.b1a = Unit3D(
+            in_channels=in_channels,
+            output_channels=out_channels[1],
+            kernel_shape=[1, 1, 1],
+            padding=0,
+            name=name + "/Branch_1/Conv3d_0a_1x1",
+        )
+        self.b1b = Unit3D(
+            in_channels=out_channels[1],
+            output_channels=out_channels[2],
+            kernel_shape=[3, 3, 3],
+            name=name + "/Branch_1/Conv3d_0b_3x3",
+        )
+        self.b2a = Unit3D(
+            in_channels=in_channels,
+            output_channels=out_channels[3],
+            kernel_shape=[1, 1, 1],
+            padding=0,
+            name=name + "/Branch_2/Conv3d_0a_1x1",
+        )
+        self.b2b = Unit3D(
+            in_channels=out_channels[3],
+            output_channels=out_channels[4],
+            kernel_shape=[3, 3, 3],
+            name=name + "/Branch_2/Conv3d_0b_3x3",
+        )
+        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(1, 1, 1), padding=0)
+        self.b3b = Unit3D(
+            in_channels=in_channels,
+            output_channels=out_channels[5],
+            kernel_shape=[1, 1, 1],
+            padding=0,
+            name=name + "/Branch_3/Conv3d_0b_1x1",
+        )
+        self.name = name
+
+    def forward(self, x):
+        b0 = self.b0(x)
+        b1 = self.b1b(self.b1a(x))
+        b2 = self.b2b(self.b2a(x))
+        b3 = self.b3b(self.b3a(x))
+        return torch.cat([b0, b1, b2, b3], dim=1)
+
+
+class InceptionI3d(nn.Module):
+    """Inception-v1 I3D architecture.
+
+    The model is introduced in:
+        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
+        Joao Carreira, Andrew Zisserman
+        https://arxiv.org/pdf/1705.07750v1.pdf.
+
+    See also the Inception architecture, introduced in:
+        Going deeper with convolutions
+        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
+        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
+        http://arxiv.org/pdf/1409.4842v1.pdf.
+    """
+
+    # Endpoints of the model in order. During construction, the network is
+    # built up to a designated `final_endpoint`; the forward pass returns the
+    # output of the last constructed endpoint.
+    VALID_ENDPOINTS = (
+        "Conv3d_1a_7x7",
+        "MaxPool3d_2a_3x3",
+        "Conv3d_2b_1x1",
+        "Conv3d_2c_3x3",
+        "MaxPool3d_3a_3x3",
+        "Mixed_3b",
+        "Mixed_3c",
+        "MaxPool3d_4a_3x3",
+        "Mixed_4b",
+        "Mixed_4c",
+        "Mixed_4d",
+        "Mixed_4e",
+        "Mixed_4f",
+        "MaxPool3d_5a_2x2",
+        "Mixed_5b",
+        "Mixed_5c",
+        "Logits",
+        "Predictions",
+    )
+
+    def __init__(
+        self,
+        time_spatial_squeeze=True,
+        final_endpoint="Logits",
+        name="inception_i3d",
+        in_channels=3,
+    ):
+        """Initializes I3D model instance.
+
+        Args:
+            time_spatial_squeeze: Whether to squeeze the time and spatial
+                dimensions for the logits before returning (default True).
+            final_endpoint: The model contains many possible endpoints.
+                `final_endpoint` specifies the last endpoint for the model to
+                be built up to; the forward pass returns the output of that
+                endpoint. `final_endpoint` must be one of
+                InceptionI3d.VALID_ENDPOINTS (default 'Logits').
+            name: A string (optional). The name of this module.
+            in_channels: Number of input channels (default 3, i.e. RGB).
+
+        Raises:
+            ValueError: if `final_endpoint` is not recognized.
+        """
+
+        if final_endpoint not in self.VALID_ENDPOINTS:
+            raise ValueError(f"Unknown final endpoint {final_endpoint}")
+
+        super().__init__()
+        self._time_spatial_squeeze = time_spatial_squeeze
+        self._final_endpoint = final_endpoint
+        self.logits = None
+
+        if self._final_endpoint not in self.VALID_ENDPOINTS:
+            raise ValueError(f"Unknown final endpoint {self._final_endpoint}")
+
+        self.end_points = {}
+        end_point = "Conv3d_1a_7x7"
+        self.end_points[end_point] = Unit3D(
+            in_channels=in_channels,
+            output_channels=64,
+            kernel_shape=[7, 7, 7],
+            stride=(2, 2, 2),
+            padding=(3, 3, 3),
+            name=name + end_point,
+        )
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "MaxPool3d_2a_3x3"
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Conv3d_2b_1x1"
+        self.end_points[end_point] = Unit3D(
+            in_channels=64,
+            output_channels=64,
+            kernel_shape=[1, 1, 1],
+            padding=0,
+            name=name + end_point,
+        )
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Conv3d_2c_3x3"
+        self.end_points[end_point] = Unit3D(
+            in_channels=64,
+            output_channels=192,
+            kernel_shape=[3, 3, 3],
+            padding=1,
+            name=name + end_point,
+        )
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "MaxPool3d_3a_3x3"
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_3b"
+        self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_3c"
+        self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "MaxPool3d_4a_3x3"
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_4b"
+        self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_4c"
+        self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_4d"
+        self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_4e"
+        self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_4f"
+        self.end_points[end_point] = InceptionModule(
+            112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], name + end_point
+        )
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "MaxPool3d_5a_2x2"
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 2, 2], stride=(1, 2, 2), padding=0)
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_5b"
+        self.end_points[end_point] = InceptionModule(
+            256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], name + end_point
+        )
+        if self._final_endpoint == end_point:
+            return
+
+        end_point = "Mixed_5c"
+        self.end_points[end_point] = InceptionModule(
+            256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], name + end_point
+        )
+
+        if self._final_endpoint == end_point:
+            return
+
+        self.build()
+
+    def build(self):
+        for k in self.end_points.keys():
+            self.add_module(k, self.end_points[k])
+
+    def get_out_size(self, shape: Sequence[int], dim=None):
+        device = next(self.parameters()).device
+        out = self(torch.zeros((1, *shape), device=device))
+        return out.size(dim)
+
+    def forward(self, x):
+        for end_point in self.VALID_ENDPOINTS:
+            if end_point in self.end_points:
+                x = self._modules[end_point](x)  # use _modules to work with dataparallel
+        return x
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model: int, max_len: int = 5000) -> None:
+        super().__init__()
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe = torch.zeros(max_len, d_model)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): shape [batch_size, seq_len, embedding_dim]
+        """
+        x = x + self.pe[:, : x.size(1), :]
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, dim_q, dim_k, dim_v, dim_out, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim_out // num_heads
+        assert dim_out % num_heads == 0, "dim_out must be divisible by num_heads"
+        self.scale = self.head_dim**-0.5
+
+        self.query_proj = nn.Linear(dim_q, dim_out)
+        self.key_proj = nn.Linear(dim_k, dim_out)
+        self.value_proj = nn.Linear(dim_v, dim_out)
+
+        self.out_proj = nn.Linear(dim_out, dim_out)
+
+    def forward(self, query, key, value):
+        # Linear transformation of query, key, and value
+        q = self.query_proj(query)  # shape: (batch_size, query_len, dim_out)
+        k = self.key_proj(key)  # shape: (batch_size, key_len, dim_out)
+        v = self.value_proj(value)  # shape: (batch_size, value_len, dim_out)
+
+        # Split dimensions for multi-head attention, and compute per head
+        # print("q:", q.size(), "k:", k.size(), "v:", v.size())
+        q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Scaled dot-product attention
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        attn_weights = attn_weights.softmax(dim=-1)
+
+        # Multiply attention weights with values
+        attn_output = torch.matmul(attn_weights, v)
+
+        # Concatenate results and return to original dimensions
+        attn_output = attn_output.transpose(1, 2).reshape(v.size(0), -1, self.num_heads * self.head_dim)
+        output = self.out_proj(attn_output)
+
+        return output, attn_weights
+
+
+class FeedForward(nn.Module):
+    def __init__(self, d_model, hidden, drop_prob=0.1):
+        super().__init__()
+        self.linear1 = nn.Linear(d_model, hidden)
+        self.linear2 = nn.Linear(hidden, d_model)
+        self.gelu = nn.GELU()
+        self.dropout = nn.Dropout(p=drop_prob)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.gelu(x)
+        x = self.dropout(x)
+        x = self.linear2(x)
+        return x
+
+
+class PreGRULayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        num_heads,
+        ffn_hidden,
+        dropout: float = 0.1,
+    ) -> None:
+        super().__init__()
+
+        self.pre_norm0 = nn.LayerNorm(d_model)
+        self.self_attention = nn.MultiheadAttention(
+            embed_dim=d_model,
+            num_heads=num_heads,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.dropout0 = nn.Dropout(dropout)
+
+        self.pre_norm1 = nn.LayerNorm(d_model)
+        self.cross_attention = CrossAttention(
+            dim_q=d_model,
+            dim_k=d_model,
+            dim_v=d_model,
+            dim_out=d_model,
+            num_heads=num_heads,
+        )
+        self.dropout1 = nn.Dropout(dropout)
+
+        self.pre_norm2 = nn.LayerNorm(d_model)
+        self.ffn = FeedForward(d_model, ffn_hidden)
+        self.dropout2 = nn.Dropout(dropout)
+
+    def forward(self, q, x) -> torch.Tensor:
+        """
+        Expected shapes:
+            - q: (b, 1, dim_q)
+            - x: (b, seq, dim_kv)
+        Output shape:
+            (b, seq, d_model)
+        """
+
+        # cross attention
+        _x = x
+        x = self.pre_norm1(x)
+        x, _ = self.cross_attention(query=q, key=x, value=x)
+        x = self.dropout1(x)
+        x = x + _x
+
+        # self attention
+        _x = x
+        x = self.pre_norm0(x)
+        x, _ = self.self_attention(query=x, key=x, value=x)
+        x = self.dropout0(x)
+        x = x + _x
+
+        # pairwise feed forward
+        _x = x
+        x = self.pre_norm2(x)
+        x = self.ffn(x)
+        x = self.dropout2(x)
+        x = x + _x
+
+        return x
+
+
+class VariableLengthWaypointPredictor(nn.Module):
+    """Variable-length GRU-based waypoint predictor with optional timestamp inputs."""
+
+    def __init__(
+        self,
+        d_model,
+        memory_seq_len,
+        timestamp_dim=0,
+        waypoint_dim=2,
+        num_heads=4,
+        start_from_origin=True,
+        dropout: float = 0.1,
+    ):
+        super().__init__()
+        self.waypoint_dim = waypoint_dim
+        self.start_from_origin = start_from_origin
+
+        self.hidden_state = nn.Parameter(torch.randn(1, d_model))
+        self.pos_embedding = nn.Parameter(torch.randn(1, memory_seq_len, d_model))
+
+        self.pre_gru_layer = PreGRULayer(
+            d_model=d_model,
+            num_heads=num_heads,
+            ffn_hidden=d_model // 2,
+        )
+        self.gru = nn.GRUCell(
+            input_size=waypoint_dim + d_model + timestamp_dim,
+            hidden_size=d_model,
+        )
+        self.head = nn.Sequential(
+            nn.Linear(d_model, d_model // 2),
+            nn.Dropout(p=dropout),
+            nn.ReLU(),
+            nn.Linear(d_model // 2, waypoint_dim),  # wp_dim
+        )
+
+    def forward(
+        self,
+        memory: Tensor,  # (b, t, c)
+        num_waypoints: int,
+        timestamps: Tensor = None,
+    ) -> Tensor:
+        batch_size = memory.shape[0]
+        dtype = memory.dtype
+
+        wp = memory.new_zeros((batch_size, self.waypoint_dim))
+        h = self.hidden_state.repeat(batch_size, 1).to(dtype)
+        pos_embedding = self.pos_embedding.repeat(batch_size, 1, 1).to(dtype)
+        memory = memory + pos_embedding
+
+        waypoints = []
+        if self.start_from_origin:
+            # add first waypoint as zero origin
+            waypoints.append(memory.new_zeros((batch_size, self.waypoint_dim)))
+            num_waypoints = num_waypoints - 1
+
+        for t in range(num_waypoints):
+            inputs = self.pre_gru_layer(q=h.unsqueeze(1), x=memory)  # (b, t, c)
+            inputs = inputs.mean(1)  # (b, c)
+            inputs = torch.cat([wp, inputs], dim=1)
+
+            if timestamps is not None:
+                inputs = torch.cat([inputs, timestamps[:, t].reshape(batch_size, -1)], dim=1)
+
+            h = self.gru(inputs, h)
+            dx = self.head(h)
+            wp = wp + dx
+            waypoints.append(wp)
+
+        waypoints = torch.stack(waypoints, dim=1)  # (b, n_wps, wp_dim)
+
+        return waypoints
+
+
+class VideoActionEstimator(nn.Module):
+    def __init__(
+        self,
+        input_shape,
+        num_classes,
+        max_seq_len=44,
+        timestamp_dim=0,
+        d_model=512,
+        num_heads=8,
+        dropout=0.1,
+        feature_map_size=4,
+        **kwargs,
+    ):
+        super().__init__()
+        self.max_seq_len = max_seq_len
+        self.timestamp_dim = timestamp_dim
+        assert input_shape[1] == max_seq_len
+
+        self.backbone = InceptionI3d()
+        feature_dim, seq_len = self.backbone.get_out_size(input_shape)[1:3]
+
+        self.avg_pool = nn.AdaptiveAvgPool3d((None, feature_map_size, feature_map_size))
+        memory_seq_len = seq_len * feature_map_size**2
+
+        self.squeeze_linear = nn.Linear(feature_dim, d_model)
+        self.positional_encoding = PositionalEncoding(d_model=d_model, max_len=memory_seq_len)
+        encoder_layer = TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=num_heads,
+            dim_feedforward=512,
+            batch_first=True,
+            activation=F.gelu,
+        )
+        self.self_attn = TransformerEncoder(
+            encoder_layer,
+            num_layers=2,
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.Dropout(p=dropout),
+            nn.GELU(),
+            nn.Linear(d_model, num_classes),
+        )
+        self.visual_odmetry = VariableLengthWaypointPredictor(
+            d_model=d_model,
+            memory_seq_len=memory_seq_len,
+            waypoint_dim=2,  # x, y axes
+            timestamp_dim=timestamp_dim,
+            num_heads=num_heads,
+        )
+
+    def forward(self, frames: Tensor, timestamps: Tensor = None) -> dict[str, Tensor]:
+        x = frames
+        num_frames = x.size(2)  # temporal length, which must be consistent within a batch
+        assert (
+            num_frames <= self.max_seq_len
+        ), f"Input sequence length ({num_frames}) exceeds max_seq_len ({self.max_seq_len})"
+
+        x = self.backbone(x)  # (b, 1024, 11, 7, 7)
+        x = self.avg_pool(x)  # (b, 1024, 11, 4, 4)
+
+        b, c, t, h, w = x.size()
+        x = x.view(b, t * h * w, c)  # (b, 176, 1024)
+        x = self.squeeze_linear(x)  # (b, 176, 512)
+        x = self.positional_encoding(x)
+
+        x = self.self_attn(x)  # (b, 176, 512)
+        latent_tensor = x.mean(1)  # (b, 512)
+        logits = self.classifier(latent_tensor)
+        waypoints = self.visual_odmetry(x, num_frames, timestamps=timestamps)
+
+        return {
+            "command": logits,
+            "waypoints": waypoints,
+        }
+
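A minimal shape-check sketch for VideoActionEstimator (not part of the uploaded files; the hyperparameters mirror the defaults in configuration_act_estimator.py). Note that building the model runs a dummy forward pass through the I3D backbone via get_out_size to infer the feature and token sizes, so instantiation takes a moment even on CPU.

import torch

from model import VideoActionEstimator

model = VideoActionEstimator(
    input_shape=(3, 44, 224, 224),   # (channels, frames, height, width)
    num_classes=9,
    max_seq_len=44,
    timestamp_dim=1,
)
model.eval()

frames = torch.randn(1, 3, 44, 224, 224)   # (batch, channels, time, height, width)
timestamps = torch.randn(1, 44)            # one scalar timestamp per frame
with torch.no_grad():
    out = model(frames, timestamps)
print(out["command"].shape)    # torch.Size([1, 9]), command logits
print(out["waypoints"].shape)  # torch.Size([1, 44, 2]), one (x, y) waypoint per frame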
modeling_act_estimator.py ADDED
@@ -0,0 +1,36 @@
+import torch
+from torch import Tensor
+from transformers import PreTrainedModel
+
+from model import VideoActionEstimator
+from configuration_act_estimator import ActEstimatorConfig
+
+
+class ActEstimator(PreTrainedModel):
+    config_class = ActEstimatorConfig
+
+    def __init__(self, config: ActEstimatorConfig):
+        super().__init__(config)
+        self.model = VideoActionEstimator(**config.to_dict())
+
+    def forward(self, frames: Tensor, timestamps: Tensor = None) -> dict[str, Tensor]:
+        return self.model(frames, timestamps)
+
+
+
+# actestimator_config = ActEstimatorConfig.from_pretrained(".")
+# print(actestimator_config.to_dict())
+
+# actestimator_model = ActEstimator(actestimator_config)
+
+# state_dict = torch.load("ckpt.pth", weights_only=True)
+# actestimator_model.model.load_state_dict(state_dict)
+
+# print(actestimator_model)
+# actestimator_model.save_pretrained(".")
+
+
+
+# model = ActEstimator.from_pretrained(".")
+# print(model(torch.randn(1, 3, 44, 224, 224), torch.randn(1, 44)))
+
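A minimal sketch of the wrapper round-trip, following the commented-out block above (not part of the uploaded files; the export directory is illustrative and the model below is randomly initialized, whereas the real checkpoint would be loaded from ckpt.pth as in the comments).

import torch

from configuration_act_estimator import ActEstimatorConfig
from modeling_act_estimator import ActEstimator

config = ActEstimatorConfig()
model = ActEstimator(config)                     # random init; load a trained state_dict as in the comments above
model.save_pretrained("act_estimator_export")    # path is illustrative
reloaded = ActEstimator.from_pretrained("act_estimator_export")
reloaded.eval()

with torch.no_grad():
    outputs = reloaded(torch.randn(1, 3, 44, 224, 224), torch.randn(1, 44))
print({k: tuple(v.shape) for k, v in outputs.items()})   # {'command': (1, 9), 'waypoints': (1, 44, 2)}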