File size: 3,584 Bytes
3f96a16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
"""GPT Blocks used for the GPT Model."""
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from .attention import ATTN_CLASS_REGISTRY
from .ffn import FFN_CLASS_REGISTRY, build_ffn
from .norm import NORM_CLASS_REGISTRY
class MPTBlock(nn.Module):
    """Pre-norm transformer block: an attention sub-layer followed by an FFN sub-layer.

    Each sub-layer reads a normalized copy of the residual stream, and its output
    is passed through dropout and added back onto the stream. The norm in front of
    the FFN is omitted when the selected FFN class declares a truthy ``_has_norm``
    attribute (presumably because that FFN normalizes its own input — confirm
    against ``.ffn``).

    Args:
        d_model: Hidden size passed to the norms, the attention class and
            ``build_ffn``.
        n_heads: Number of attention heads, forwarded to the attention class.
        expansion_ratio: Forwarded to ``build_ffn``.
        attn_config: Attention settings. ``"attn_type"`` selects the class from
            ``ATTN_CLASS_REGISTRY``. The keys ``attn_type``, ``prefix_lm``,
            ``alibi``, ``attn_uses_sequence_id`` and ``alibi_bias_max`` are NOT
            forwarded to that class; every other key is passed to it as a keyword
            argument. When ``None``, a built-in multihead-attention config is used.
        ffn_config: FFN settings. ``"ffn_type"`` selects the class in
            ``FFN_CLASS_REGISTRY`` and the whole dict is forwarded to
            ``build_ffn``. When ``None``, ``{"ffn_type": "mptmlp"}`` is used.
        resid_pdrop: Dropout probability applied to each sub-layer output before
            it is added to the residual stream.
        norm_type: Key into ``NORM_CLASS_REGISTRY``; lower-cased before lookup.
        fc_type: Forwarded to the attention class.
        device: Passed to the norms, the attention class and ``build_ffn``.
        no_bias: When true, the attention class and FFN are built with
            ``bias=False``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        TypeError: If ``attn_config["attn_type"]`` is not a ``str``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        ffn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        fc_type: str = "torch",
        device: Optional[str] = None,
        no_bias: bool = False,
        **kwargs: Any
    ):
        # Defaults are built fresh on every call, so instances never share a
        # mutable default dict.
        if attn_config is None:
            attn_config = {
                "attn_type": "multihead_attention",
                "attn_pdrop": 0.0,
                "attn_impl": "triton",
                "qk_ln": False,
                "clip_qkv": None,
                "softmax_scale": None,
                "prefix_lm": False,
                "attn_uses_sequence_id": False,
                "alibi": False,
                "alibi_bias_max": 8,
            }
        if ffn_config is None:
            ffn_config = {"ffn_type": "mptmlp"}
        del kwargs  # extra keyword arguments are intentionally ignored
        super().__init__()
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_type = attn_config["attn_type"]
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if not isinstance(attn_type, str):
            raise TypeError(
                "attn_config['attn_type'] must be a str, "
                f"got {type(attn_type).__name__}"
            )
        attn_class = ATTN_CLASS_REGISTRY[attn_type]
        # These keys are kept out of the attention-class constructor. NOTE(review):
        # presumably they are consumed elsewhere (e.g. by the enclosing model when
        # building the attention bias) — confirm against the caller.
        args_to_exclude_in_attn_class = {
            "attn_type",
            "prefix_lm",
            "alibi",
            "attn_uses_sequence_id",
            "alibi_bias_max",
        }
        attn_config_subset_for_attn_class = {
            k: v
            for (k, v) in attn_config.items()
            if k not in args_to_exclude_in_attn_class
        }
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            d_model=d_model,
            n_heads=n_heads,
            fc_type=fc_type,
            device=device,
            **attn_config_subset_for_attn_class,
            bias=not no_bias
        )
        # Only add a pre-FFN norm when the FFN class does not flag ``_has_norm``.
        self.norm_2 = None
        if not getattr(FFN_CLASS_REGISTRY[ffn_config["ffn_type"]], "_has_norm", False):
            self.norm_2 = norm_class(d_model, device=device)
        self.ffn = build_ffn(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
            bias=not no_bias,
            **ffn_config
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        output_attentions: bool = False,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[Tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Run attention then FFN, each wrapped in norm -> sub-layer -> dropout -> residual add.

        Args:
            x: Input hidden states (presumably ``(batch, seq, d_model)`` — confirm
                against the caller).
            past_key_value: Cached key/value pair forwarded to the attention layer.
            attn_bias: Additive attention bias forwarded to the attention layer.
            attention_mask: Mask forwarded to the attention layer.
            is_causal: Forwarded to the attention layer.
            output_attentions: Forwarded to the attention layer as
                ``needs_weights``.

        Returns:
            ``(x, attn_weights, past_key_value)``: the updated hidden states plus
            the attention weights and key/value cache exactly as returned by the
            attention layer.
        """
        a = self.norm_1(x)
        (b, attn_weights, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
            needs_weights=output_attentions,
        )
        x = x + self.resid_attn_dropout(b)
        # Feed the FFN the raw residual when the FFN does its own normalization.
        m = x
        if self.norm_2 is not None:
            m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)
|