fix tensor shape error when torch version less than 2 (#4)
Browse files- Update modeling_chatglm.py (5853f34804e757ba1bb3496331d108e2674088cc)
Co-authored-by: Yaowei Zheng <[email protected]>
- modeling_chatglm.py +4 -7
modeling_chatglm.py
CHANGED
@@ -247,15 +247,12 @@ class CoreAttention(torch.nn.Module):
|
|
247 |
# This is actually dropping out entire tokens to attend to, which might
|
248 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
249 |
attention_probs = self.attention_dropout(attention_probs)
|
250 |
-
# =========================
|
251 |
-
# Context layer. [sq, b, hp]
|
252 |
-
# =========================
|
253 |
-
|
254 |
-
# value_layer -> context layer.
|
255 |
-
# [sk, b, np, hn] --> [b, np, sq, hn]
|
256 |
|
|
|
|
|
|
|
257 |
# context layer shape: [b, np, sq, hn]
|
258 |
-
output_size = (value_layer.size(
|
259 |
# change view [b * np, sk, hn]
|
260 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
261 |
# change view [b * np, sq, sk]
|
|
|
247 |
# This is actually dropping out entire tokens to attend to, which might
|
248 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
249 |
attention_probs = self.attention_dropout(attention_probs)
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
+
# query layer shape: [b * np, sq, hn]
|
252 |
+
# value layer shape: [b, np, sk, hn]
|
253 |
+
# attention shape: [b, np, sq, sk]
|
254 |
# context layer shape: [b, np, sq, hn]
|
255 |
+
output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
|
256 |
# change view [b * np, sk, hn]
|
257 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
258 |
# change view [b * np, sq, sk]
|