taka-yamakoshi
committed on
Commit
·
b218eb4
1
Parent(s):
e87e116
fix
Browse files
skeleton_modeling_albert.py
CHANGED
@@ -10,11 +10,16 @@ def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
|
|
10 |
attention_layer = layer.attention
|
11 |
num_heads = attention_layer.num_attention_heads
|
12 |
head_dim = attention_layer.attention_head_size
|
|
|
13 |
|
14 |
qry = attention_layer.query(hidden)
|
15 |
key = attention_layer.key(hidden)
|
16 |
val = attention_layer.value(hidden)
|
17 |
|
|
|
|
|
|
|
|
|
18 |
# swap representations
|
19 |
interv_layer = interventions.pop(layer_id,None)
|
20 |
if interv_layer is not None:
|
@@ -29,8 +34,8 @@ def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
|
|
29 |
if interv_rep is not None:
|
30 |
new_state = reps[rep_type].clone()
|
31 |
for head_id, pos, swap_ids in interv_rep:
|
32 |
-
new_state[swap_ids[0],pos,head_id] = reps[rep_type][swap_ids[1],pos,head_id]
|
33 |
-
new_state[swap_ids[1],pos,head_id] = reps[rep_type][swap_ids[0],pos,head_id]
|
34 |
reps[rep_type] = new_state.clone()
|
35 |
|
36 |
hidden = reps['lay'].clone()
|
|
|
10 |
attention_layer = layer.attention
|
11 |
num_heads = attention_layer.num_attention_heads
|
12 |
head_dim = attention_layer.attention_head_size
|
13 |
+
assert num_heads*head_dim == hidden.shape[2]
|
14 |
|
15 |
qry = attention_layer.query(hidden)
|
16 |
key = attention_layer.key(hidden)
|
17 |
val = attention_layer.value(hidden)
|
18 |
|
19 |
+
assert qry.shape == hidden.shape
|
20 |
+
assert key.shape == hidden.shape
|
21 |
+
assert val.shape == hidden.shape
|
22 |
+
|
23 |
# swap representations
|
24 |
interv_layer = interventions.pop(layer_id,None)
|
25 |
if interv_layer is not None:
|
|
|
34 |
if interv_rep is not None:
|
35 |
new_state = reps[rep_type].clone()
|
36 |
for head_id, pos, swap_ids in interv_rep:
|
37 |
+
new_state[swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)]
|
38 |
+
new_state[swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)]
|
39 |
reps[rep_type] = new_state.clone()
|
40 |
|
41 |
hidden = reps['lay'].clone()
|