andrewrreed HF staff Zymrael commited on
Commit
92d683c
·
verified ·
0 Parent(s):

Duplicate from togethercomputer/evo-1-131k-base

Browse files

Co-authored-by: Michael Poli <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - stripedhyena
5
+ - long context
6
+ - deep signal processing
7
+ - hybrid
8
+ - biology
9
+ - genomics
10
+ ---
11
+
12
+
13
+ ## Evo-1 (Phase 2)
14
+
15
+ <p align="center">
16
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/62a1306bbe7fa896d2c8de44/JoEHcvLTUlHoMcgh3mmAz.png" width="70%" />
17
+ </p>
18
+
19
+
20
+ ### About
21
+
22
+ Evo is a biological foundation model capable of long-context modeling and design.
23
+
24
+ Evo uses the [StripedHyena architecture](https://github.com/togethercomputer/stripedhyena) to enable modeling of sequences at a single-nucleotide, byte-level resolution with near-linear scaling of compute and memory relative to context length.
25
+ Evo has 7 billion parameters and is trained on OpenGenome, a prokaryotic whole-genome dataset containing ~300 billion tokens.
26
+
27
+ Technical details about Evo can be found in our preprint and our accompanying blog posts. Evo was collaboratively developed by the [Arc Institute](https://arcinstitute.org/) and TogetherAI.
28
+
29
+ As part of our commitment to open science, we release **weights of 15 intermediate pretraining checkpoints** for phase 1 and phase 2 of pretraining. The checkpoints are available as branches of the corresponding HuggingFace repository.
30
+
31
+ **Evo-1 (Phase 2)** is our **longer context model** in the Evo family, trained at a context length of 131k and tested on generation of sequences of length >650k
32
+
33
+ | Checkpoint Name | Description |
34
+ |----------------------------------------|-------------|
35
+ | `evo-1-8k-base` | A model pretrained with 8,192 context. We use this model as the base model for molecular-scale finetuning tasks. |
36
+ | `evo-1-131k-base` | A model pretrained with 131,072 context using `evo-1-8k-base` as the initialization. We use this model to reason about and generate sequences at the genome scale. |
37
+
38
+ ### Model Architecture
39
+
40
+ StripedHyena is a deep signal processing, hybrid architecture composed of multi-head attention and gated convolutions arranged in [Hyena](https://arxiv.org/abs/2302.10866) blocks, improving over decoder-only Transformers.
41
+
42
+ StripedHyena is designed to leverage the specialization of each of its layer classes, with Hyena layers implementing the bulk of the computation required for sequence processing and attention layers supplementing the ability to perform targeted pattern recall.
43
+
44
+
45
+ Some highlights of the architecture:
46
+ - **Efficient autoregressive generation** via a recurrent mode (>500k generation with a single 80GB GPU)
47
+ - **Significantly faster training and finetuning** at long context (>3x at 131k)
48
+ - **Improved scaling laws over state-of-the-art architectures** (e.g., Transformer++) on both natural language and biological sequences.
49
+ - **Robust to training beyond the compute-optimal frontier** e.g., training way beyond Chinchilla-optimal token amounts (see preprint for details -- more details to come)
50
+
51
+
52
+ ### How to use Evo
53
+
54
+ Example usage is provided in the [standalone repo](https://github.com/evo-design/evo).
55
+
56
+
57
+ #### Parametrization for Inference and Finetuning
58
+
59
+ One of the advantages of deep signal processing models is their flexibility. Different parametrizations of convolutions can be used depending on the memory, expressivity and causality requirements of pretraining, finetuning or inference workloads.
60
+
61
+ The main classes are:
62
+ - Modal canonical: unconstrained poles ([reference](https://arxiv.org/pdf/2203.14343.pdf), [reference](https://arxiv.org/abs/2310.18780)), or constrained poles ([reference](https://arxiv.org/abs/2206.11893), [reference](https://arxiv.org/pdf/2303.06349.pdf)).
63
+ - Companion canonical / rational: TBA.
64
+ - Hypernetworks: hypernetwork ([reference](https://arxiv.org/abs/2102.02611)), modulated hypernetwork ([reference](https://arxiv.org/abs/2302.10866)).
65
+ - Explicit: modulated explicit ([reference](https://arxiv.org/pdf/2210.09298.pdf)).
66
+
67
+ StripedHyena is a mixed precision model. Make sure to keep your `poles` and `residues` in `float32` precision, especially for longer prompts or training.
68
+
69
+
70
+
71
+ ### Disclaimer
72
+
73
+ To use StripedHyena outside of the playground, you will need to install custom kernels. Please follow the instructions from the [standalone repository](https://github.com/togethercomputer/stripedhyena).
74
+
75
+ ## Cite
76
+
77
+ ```
78
+ @article{nguyen2024sequence,
79
+ author = {Eric Nguyen and Michael Poli and Matthew G. Durrant and Armin W. Thomas and Brian Kang and Jeremy Sullivan and Madelena Y. Ng and Ashley Lewis and Aman Patel and Aaron Lou and Stefano Ermon and Stephen A. Baccus and Tina Hernandez-Boussard and Christopher Ré and Patrick D. Hsu and Brian L. Hie},
80
+ journal = {Arc Institute manuscripts},
81
+ title = {Sequence modeling and design from molecular to genome scale with Evo},
82
+ url = {https://arcinstitute.org/manuscripts/Evo},
83
+ year = {2024},
84
+ }
85
+ ```
cache.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ from torch import Tensor
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+
9
+
10
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
11
+ @dataclass
12
+ class InferenceParams:
13
+ """Inference parameters that are passed to the main model in order
14
+ to efficienly calculate and store the context during inference."""
15
+
16
+ max_seqlen: int
17
+ max_batch_size: int
18
+ seqlen_offset: int = 0
19
+ batch_size_offset: int = 0
20
+ key_value_memory_dict: dict = field(default_factory=dict)
21
+ lengths_per_sample: Optional[Tensor] = None
22
+
23
+ def reset(self, max_seqlen, max_batch_size):
24
+ self.max_seqlen = max_seqlen
25
+ self.max_batch_size = max_batch_size
26
+ self.seqlen_offset = 0
27
+ if self.lengths_per_sample is not None:
28
+ self.lengths_per_sample.zero_()
29
+
30
+
31
+ @dataclass
32
+ class RecurrentInferenceParams:
33
+ """Inference parameters passed to blocks with recurrent mode."""
34
+
35
+ fir_filter_length: int = 3
36
+ state_dim: int = 16
37
+ seqlen_offset: int = 0
38
+ fir_state_dict: dict = field(default_factory=dict)
39
+ state_dict: dict = field(default_factory=dict)
40
+
41
+ def reset(self):
42
+ self.fir_filter_length = 3
43
+ self.state_dim = 16
44
+ self.seqlen_offset = 0
config.json ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "togethercomputer/evo-1-131k-base",
4
+ "architectures": [
5
+ "StripedHyenaModelForCausalLM"
6
+ ],
7
+ "attn_layer_idxs": [
8
+ 8,
9
+ 16,
10
+ 24
11
+ ],
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_hyena.StripedHyenaConfig",
14
+ "AutoModelForCausalLM": "modeling_hyena.StripedHyenaModelForCausalLM",
15
+ "AutoTokenizer": [
16
+ "tokenizer.ByteTokenizer",
17
+ null
18
+ ]
19
+ },
20
+ "column_split": false,
21
+ "column_split_hyena": true,
22
+ "eps": 1e-06,
23
+ "final_norm": true,
24
+ "hidden_size": 4096,
25
+ "hyena_filter_groups": 1,
26
+ "hyena_layer_idxs": [
27
+ 0,
28
+ 1,
29
+ 2,
30
+ 3,
31
+ 4,
32
+ 5,
33
+ 6,
34
+ 7,
35
+ 9,
36
+ 10,
37
+ 11,
38
+ 12,
39
+ 13,
40
+ 14,
41
+ 15,
42
+ 17,
43
+ 18,
44
+ 19,
45
+ 20,
46
+ 21,
47
+ 22,
48
+ 23,
49
+ 25,
50
+ 26,
51
+ 27,
52
+ 28,
53
+ 29,
54
+ 30,
55
+ 31
56
+ ],
57
+ "inference_mode": false,
58
+ "inner_mlp_size": 10928,
59
+ "log_intermediate_values": false,
60
+ "make_vocab_size_divisible_by": 8,
61
+ "max_seqlen": 131072,
62
+ "mha_out_proj_bias": true,
63
+ "mlp_activation": "gelu",
64
+ "model_parallel_size": 1,
65
+ "model_type": "stripedhyena",
66
+ "num_attention_heads": 32,
67
+ "num_filters": 4096,
68
+ "num_layers": 32,
69
+ "pipe_parallel_size": 1,
70
+ "prefill_style": "fft",
71
+ "proj_groups": 1,
72
+ "qkv_proj_bias": true,
73
+ "rotary_emb_base": 10000,
74
+ "rotary_emb_scaling_factor": 16,
75
+ "short_filter_bias": true,
76
+ "short_filter_length": 3,
77
+ "smeared_gqa": false,
78
+ "split_k0": true,
79
+ "state_size": 8,
80
+ "tie_embeddings": true,
81
+ "torch_dtype": "bfloat16",
82
+ "transformers_version": null,
83
+ "use_cache": true,
84
+ "use_flash_attention_2": true,
85
+ "use_flash_depthwise": false,
86
+ "use_flash_rmsnorm": false,
87
+ "use_flashfft": false,
88
+ "use_interpolated_rotary_pos_emb": true,
89
+ "vocab_size": 512
90
+ }
configuration_hyena.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import json
3
+
4
+
5
+ class StripedHyenaConfig(PretrainedConfig):
6
+ model_type = "stripedhyena"
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ num_filters=4096,
13
+ inner_mlp_size=14336,
14
+ attn_layer_idxs=[],
15
+ hyena_layer_idxs=[],
16
+ num_layers=32,
17
+ tie_embeddings=False,
18
+ short_filter_length=3,
19
+ num_attention_heads=32,
20
+ proj_groups=4,
21
+ hyena_filter_groups=1,
22
+ split_k0=True,
23
+ column_split_hyena=True,
24
+ column_split=False,
25
+ model_parallel_size=1,
26
+ pipe_parallel_size=1,
27
+ short_filter_bias=True,
28
+ mha_out_proj_bias=False,
29
+ qkv_proj_bias=False,
30
+ final_norm=True,
31
+ use_cache=True,
32
+ use_flash_attention_2=True,
33
+ use_flash_rmsnorm=True,
34
+ use_flash_depthwise=False,
35
+ use_flashfft=False,
36
+ inference_mode=False,
37
+ prefill_style="fft",
38
+ max_seqlen=32768,
39
+ eps=1e-5,
40
+ state_size=2,
41
+ rotary_emb_base=500000,
42
+ smeared_gqa=False,
43
+ make_vocab_size_divisible_by=8,
44
+ log_intermediate_values=False,
45
+ **kwargs,
46
+ ):
47
+ self.vocab_size = vocab_size
48
+ self.hidden_size = hidden_size
49
+ self.num_filters = num_filters
50
+ self.inner_mlp_size = inner_mlp_size
51
+ self.attn_layer_idxs = attn_layer_idxs
52
+ self.hyena_layer_idxs = hyena_layer_idxs
53
+ self.num_layers = num_layers
54
+ self.tie_embeddings = tie_embeddings
55
+ self.short_filter_length = short_filter_length
56
+ self.num_attention_heads = num_attention_heads
57
+ self.proj_groups = proj_groups
58
+ self.hyena_filter_groups = hyena_filter_groups
59
+ self.split_k0 = split_k0
60
+ self.column_split_hyena = column_split_hyena
61
+ self.column_split = column_split
62
+ self.model_parallel_size = model_parallel_size
63
+ self.pipe_parallel_size = pipe_parallel_size
64
+ self.short_filter_bias = short_filter_bias
65
+ self.mha_out_proj_bias = mha_out_proj_bias
66
+ self.qkv_proj_bias = qkv_proj_bias
67
+ self.final_norm = final_norm
68
+ self.use_cache = use_cache
69
+ self.use_flash_attention_2 = use_flash_attention_2
70
+ self.use_flash_rmsnorm = use_flash_rmsnorm
71
+ self.use_flash_depthwise = use_flash_depthwise
72
+ self.use_flashfft = use_flashfft
73
+ self.inference_mode = inference_mode
74
+ self.prefill_style = prefill_style
75
+ self.max_seqlen = max_seqlen
76
+ self.eps = eps
77
+ self.state_size = state_size
78
+ self.rotary_emb_base = rotary_emb_base
79
+ self.smeared_gqa = smeared_gqa
80
+ self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
81
+ self.log_intermediate_values = log_intermediate_values
82
+ super().__init__(**kwargs)
83
+
84
+ def to_dict(self):
85
+ return {attr: getattr(self, attr) for attr in self.__dict__}
86
+
87
+ @classmethod
88
+ def from_original_config(cls, config_path, **kwargs):
89
+ with open(config_path, "r") as f:
90
+ config = json.load(f)
91
+
92
+ return cls(**config, **kwargs)
engine.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ import gc
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ try:
12
+ import conv1d_cpp
13
+ except:
14
+ pass
15
+ from .utils import column_split
16
+
17
+ IIR_PREFILL_MODES = [
18
+ "recurrence",
19
+ "modal-fft",
20
+ "hybrid-modal-recurrence",
21
+ "modal-scan",
22
+ "canonical-fft",
23
+ "iir-fir-caching",
24
+ ]
25
+
26
+
27
+ def canonicalize_modal_system(poles, residues):
28
+ """Canonicalize a modal system.
29
+
30
+ Args:
31
+ poles (Tensor): The poles of the system.
32
+ residues (Tensor): The residues of the system.
33
+
34
+ Returns:
35
+ Tuple[Tensor, Tensor]: The canonicalized poles and residues.
36
+ """
37
+ raise NotImplementedError
38
+
39
+
40
+ def list_tensors(idx):
41
+ for obj in gc.get_objects():
42
+ try:
43
+ if torch.is_tensor(obj) and isinstance(obj, torch.Tensor):
44
+ # dump to log
45
+ print(type(obj), obj.size())
46
+ el = obj[0]
47
+ with open(f"tensors_{idx}.txt", "a") as f:
48
+ f.write(f"{type(obj)} {obj.size()} {el}\n")
49
+ except Exception as e:
50
+ pass
51
+
52
+
53
+ class HyenaInferenceEngine:
54
+ def __init__(
55
+ self,
56
+ fir_fn=None,
57
+ iir_prefill_style="modal-fft",
58
+ layer_idx=None,
59
+ ) -> None:
60
+ self.fir_fn = fir_fn
61
+ assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
62
+ self.iir_prefill_style = iir_prefill_style
63
+ self.layer_idx = layer_idx
64
+ self.low_mem_mode = False
65
+
66
+ def parallel_fir(
67
+ self,
68
+ fir_fn,
69
+ u,
70
+ weight,
71
+ bias,
72
+ L,
73
+ fir_length=3,
74
+ inference_params=None,
75
+ prefill_mode=None,
76
+ padding_mask=None,
77
+ ):
78
+ """Compute the output state of the long convolutional filter."""
79
+ # prepare input layout, dimensions and dispatch to fir kernel
80
+ if fir_fn != torch.nn.functional.conv1d:
81
+ z_pre = fir_fn(u)[:, :L] # B, L, D
82
+ z_pre = z_pre.permute(0, 2, 1)
83
+ else:
84
+ u = u.permute(0, 2, 1) # B, D, L
85
+ z_pre = fir_fn(
86
+ u,
87
+ weight,
88
+ bias=None, # don't pass it here, add manually instead! source of small error
89
+ stride=1,
90
+ padding=fir_length - 1,
91
+ groups=u.shape[1],
92
+ )[..., :L]
93
+
94
+ # add manually instead! source of small error
95
+ z_pre = z_pre + bias[None, :, None]
96
+
97
+ # handle padding post fir, the only place with biases
98
+ if type(padding_mask) == torch.Tensor:
99
+ z_pre = z_pre * padding_mask[:, None]
100
+
101
+ if inference_params is not None:
102
+ # handle seqlen last and dim last cases for `u`
103
+ if fir_fn != torch.nn.functional.conv1d:
104
+ fir_state = u[:, -fir_length + 1 :].permute(0, 2, 1)
105
+ else:
106
+ fir_state = u[..., -fir_length + 1 :]
107
+ else:
108
+ fir_state = None
109
+
110
+ return z_pre, fir_state
111
+
112
+ def parallel_iir(
113
+ self,
114
+ z_pre,
115
+ h,
116
+ D,
117
+ L,
118
+ poles,
119
+ residues,
120
+ t,
121
+ dims,
122
+ layer_idx,
123
+ inference_params=None,
124
+ prefill_style="fft",
125
+ fftconv_fn=None,
126
+ padding_mask=None,
127
+ use_flashfft=False,
128
+ column_split_hyena=False,
129
+ long_fir_threshold=None,
130
+ ):
131
+ """Compute the output state of the short convolutional filter."""
132
+ fft_size = 2 * L
133
+ hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
134
+ # Compatibility with training infra that column splits the projections
135
+ if column_split_hyena:
136
+ z = z_pre.reshape(
137
+ z_pre.shape[0],
138
+ num_attention_heads,
139
+ 3 * hidden_size_per_attention_head,
140
+ z_pre.shape[2],
141
+ )
142
+ x2, x1, v = (
143
+ z[:, :, :hidden_size_per_attention_head],
144
+ z[
145
+ :,
146
+ :,
147
+ hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
148
+ ],
149
+ z[:, :, 2 * hidden_size_per_attention_head :],
150
+ )
151
+ x2, x1, v = (
152
+ x2.reshape(x2.shape[0], -1, x2.shape[-1]),
153
+ x1.reshape(x1.shape[0], -1, x1.shape[-1]),
154
+ v.reshape(v.shape[0], -1, v.shape[-1]),
155
+ )
156
+ else:
157
+ x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)
158
+
159
+ x1v = x1 * v
160
+
161
+ if inference_params is not None and prefill_style == "recurrence":
162
+ y = self.prefill_via_direct_recurrence(
163
+ inference_params=inference_params,
164
+ x1v=x1v,
165
+ L=L,
166
+ poles=poles,
167
+ residues=residues,
168
+ )
169
+
170
+ else:
171
+ if use_flashfft and (L % 2) == 0: # only works with even L
172
+ y = fftconv_fn(
173
+ x1v.to(dtype=torch.bfloat16).contiguous(),
174
+ h.to(dtype=torch.float32),
175
+ )
176
+ X_s = None
177
+
178
+ elif long_fir_threshold is None:
179
+ H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
180
+ X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
181
+ X = X_s[..., : H.shape[-1]]
182
+ if len(z_pre.shape) > 3:
183
+ H = H.unsqueeze(1)
184
+ y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
185
+
186
+ else:
187
+ assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
188
+ h = h[0][:, None] # rearrange to d, 1, l for depthwise conv1d
189
+ h = h[..., :long_fir_threshold]
190
+ y = F.conv1d(
191
+ x1v,
192
+ h.to(dtype=x1v.dtype),
193
+ stride=1,
194
+ groups=x1v.shape[1],
195
+ padding=h.shape[-1] - 1,
196
+ )[..., :L]
197
+
198
+ y = y.to(dtype=x1v.dtype)
199
+ y = (y + x1v * D.unsqueeze(-1)) * x2
200
+
201
+ if inference_params is not None:
202
+ if prefill_style == "fft":
203
+ self.prefill_via_modal_fft(
204
+ inference_params=inference_params,
205
+ x1v=x1v,
206
+ X_s=X_s,
207
+ L=L,
208
+ t=t,
209
+ poles=poles,
210
+ dims=dims,
211
+ layer_idx=layer_idx,
212
+ use_flashfft=use_flashfft,
213
+ fftconv_fn=fftconv_fn,
214
+ )
215
+
216
+ elif prefill_style == "recurrence":
217
+ # recurrent prefill is done before
218
+ pass
219
+ else:
220
+ raise NotImplementedError
221
+ if self.low_mem_mode:
222
+ # TODO: smarter gc
223
+ del z_pre, x2, x1, v, x1v, h, poles, residues
224
+ torch.cuda.empty_cache()
225
+
226
+ return y.permute(0, 2, 1)
227
+
228
+ def step_fir(self, u, fir_state, weight, bias=None):
229
+ """Step the FIR filter.
230
+
231
+ Note:
232
+ `fir_state` contains the last `short_filter_length - 1` elements of `u`: `u_(L-2), u_{L-1), ...`
233
+ We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]` (SISO / multi SISO layout).
234
+ """
235
+ h0, h = weight[..., 0, -1], weight[..., 0, :-1]
236
+ h0, h = h0[None], h[None]
237
+ y = h0 * u + torch.sum(fir_state * h, dim=-1) + bias
238
+
239
+ # update
240
+ fir_state = torch.roll(fir_state, -1, dims=2)
241
+ fir_state[..., -1] = u
242
+ return y, fir_state
243
+
244
+ def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
245
+ x1v = x1 * v
246
+
247
+ residues, poles = (
248
+ torch.view_as_complex(residues.to(torch.float32)),
249
+ torch.view_as_complex(poles.to(torch.float32)),
250
+ )
251
+ # squeeze the dummy seqlen dimension
252
+ # D, state_dim, 1 -> 1, D, state_dim
253
+ residues, poles = residues[..., 0][None], poles[..., 0][None]
254
+ iir_state = poles * iir_state + x1v[..., None]
255
+
256
+ res_state = torch.sum(residues * iir_state, dim=-1).real
257
+
258
+ if iir_groups > 1:
259
+ raise NotImplementedError
260
+ y = x2 * (res_state + D * x1v)
261
+
262
+ return y, iir_state
263
+
264
+ def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
265
+ """Turns the IIR filter into a FIR and uses a cache for decoding."""
266
+ raise NotImplementedError(":)")
267
+
268
+ def prefill_via_direct_recurrence(
269
+ self, inference_params, x1v, L, residues, poles, *args, **kwargs
270
+ ) -> torch.Tensor:
271
+ """
272
+ Compute the IIR state via explicit SSM recurrence (modal form)
273
+
274
+ This is the most memory efficient prefilling method for Hyena filters.
275
+
276
+ Note:
277
+ dtypes: [state: float32, poles: float32, x1v: bfloat16, output: bfloat16]
278
+ """
279
+ state_dim = poles.shape[1]
280
+ x1v_ = x1v[..., None, None] # b, d, l, sdim, reim
281
+ x1v_ = x1v_.repeat(1, 1, 1, state_dim, 2) # b, d, l, sdim, reim
282
+ x1v_[..., 1] = 0
283
+
284
+ state = 0 * x1v_[:, :, 0]
285
+ output = 0 * x1v_[:, :, :, 0, 0] # b, d, l
286
+
287
+ # suppress dummy seqlen dimension
288
+ poles = poles[:, :, 0][None]
289
+ residues = residues[:, :, 0][None].repeat(x1v_.shape[0], 1, 1, 1) # b, d, sdim, reim
290
+
291
+ # state: b, d, sdim, reim
292
+ # poles: 1, d, sdim, reim
293
+ # x1v_: b, d, l, sdim, reim
294
+ for i in range(L):
295
+ state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0]
296
+ state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
297
+ output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0] # .real
298
+
299
+ inference_params.state_dict[self.layer_idx] = torch.view_as_complex(state.to(dtype=torch.float32))
300
+
301
+ return output
302
+
303
+ def prefill_via_hybrid_recurrence(self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs):
304
+ """
305
+ Compute the IIR state via hybrid recurrence-convolution over blocks
306
+ """
307
+ raise NotImplementedError(":)")
308
+
309
+ def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
310
+ raise NotImplementedError
311
+
312
+ def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
313
+ """
314
+ Compute the IIR state via a single FFT with the denominator of the SSM in companion form.
315
+
316
+ This is the most memory efficient "parallelized" prefilling method for Hyena.
317
+
318
+ From: https://arxiv.org/abs/2310.18780
319
+ """
320
+ raise NotImplementedError(":)")
321
+
322
+ def prefill_via_modal_fft(
323
+ self,
324
+ inference_params,
325
+ x1v,
326
+ L,
327
+ poles,
328
+ t,
329
+ dims,
330
+ layer_idx,
331
+ X_s=None,
332
+ use_flashfft=False,
333
+ fftconv_fn=None,
334
+ state_dtype=torch.complex64,
335
+ *args,
336
+ **kwargs,
337
+ ):
338
+ """
339
+ Compute the IIR state via a single FFT, using the poles of the SSM in modal form.
340
+ """
341
+ # When the model has a long convolution derived from a SSM in modal form and prefill_style is "fft",
342
+ # we split the filter into poles and residues and reuse FFT computation on the input.
343
+ # This optimization is currently not supported when using flashfftconv.
344
+ hidden_size, _, _, state_size, hyena_filter_groups = dims
345
+
346
+ if use_flashfft:
347
+ # using real states
348
+ poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]
349
+
350
+ state_s = poles**t
351
+ if hyena_filter_groups > 1:
352
+ raise NotImplementedError
353
+
354
+ x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
355
+ x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
356
+ state_s = state_s[None]
357
+
358
+ state = fftconv_fn(
359
+ x1v.contiguous(),
360
+ state_s.to(dtype=torch.float32),
361
+ )
362
+ state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
363
+ state = torch.view_as_complex(state.contiguous().to(dtype=torch.float32))
364
+ inference_params.state_dict[self.layer_idx] = state
365
+ else:
366
+ assert X_s is not None
367
+ bs = x1v.shape[0]
368
+ fft_size = 2 * L
369
+ poles = torch.view_as_complex(poles.to(torch.float32))
370
+ state_s = poles**t
371
+ state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # B, D, state_dim, 2 * L
372
+ if hyena_filter_groups > 1:
373
+ state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
374
+ state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
375
+ inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)
376
+
377
+ def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
378
+ """
379
+ Compute the IIR state given an input `u` and log_poles of the modal system.
380
+ """
381
+ bs = u.shape[0]
382
+ fft_size = 2 * L
383
+ U = torch.fft.rfft(u.to(torch.float32), n=fft_size)
384
+ fft_size = 2 * L
385
+ x = (log_poles * t).exp()
386
+ # [batch, hidden_size, state_dim, 2 * seqlen]
387
+ X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
388
+ state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
389
+ return state
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.36.2"
4
+ }
layers.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from .utils import grab_first_if_tuple
10
+
11
+ def grab_first_if_tuple(x):
12
+ if x.__class__.__name__ == "tuple":
13
+ return x[0]
14
+ else:
15
+ return x
16
+
17
+ class RMSNorm(torch.nn.Module):
18
+ def __init__(self, config):
19
+ super(RMSNorm, self).__init__()
20
+ self.eps, self.hidden_size = config.eps, config.hidden_size
21
+ self.scale = torch.nn.Parameter(torch.ones(self.hidden_size))
22
+ self.register_parameter("scale", self.scale)
23
+ self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)
24
+
25
+ if self.use_flash_rmsnorm:
26
+ try:
27
+ from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func
28
+
29
+ self.rmsnorm_func = rmsnorm_func
30
+ except:
31
+ raise ImportError(
32
+ "For `use_flash_rmsnorm`: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/layer_norm`"
33
+ )
34
+
35
+ def forward(self, x):
36
+ if self.use_flash_rmsnorm:
37
+ return self.rmsnorm_func(x, self.scale, self.eps)
38
+ else:
39
+ y = x / (x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2) + self.eps)
40
+ return self.scale * y
41
+
42
+
43
+ class ParallelGatedMLP(nn.Module):
44
+ def __init__(
45
+ self,
46
+ config,
47
+ ):
48
+ super().__init__()
49
+
50
+ multiple_of = config.get("inner_size_multiple_of", 64)
51
+ self.act_type = config.get("mlp_activation", "silu")
52
+ if self.act_type == "gelu":
53
+ self.act = F.gelu
54
+ elif self.act_type == "silu":
55
+ self.act = F.silu
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ self.multiple_of = multiple_of * config.model_parallel_size
60
+
61
+ inner_size = int(2 * config.hidden_size * 4 / 3)
62
+ inner_size = self.multiple_of * ((inner_size + self.multiple_of - 1) // self.multiple_of)
63
+ if config.get("inner_mlp_size", None) is not None:
64
+ inner_size = config.inner_mlp_size
65
+
66
+ self.l1 = nn.Linear(
67
+ in_features=config.hidden_size,
68
+ out_features=inner_size,
69
+ bias=False,
70
+ )
71
+ self.l2 = nn.Linear(
72
+ in_features=config.hidden_size,
73
+ out_features=inner_size,
74
+ bias=False,
75
+ )
76
+ self.l3 = nn.Linear(
77
+ in_features=inner_size,
78
+ out_features=config.hidden_size,
79
+ bias=False,
80
+ )
81
+
82
+ def forward(self, z):
83
+ z1, z2 = self.l1(z), self.l2(z)
84
+ z1, z2 = grab_first_if_tuple(z1), grab_first_if_tuple(z2)
85
+ y = self.l3(self.act(z1) * z2)
86
+ return grab_first_if_tuple(y)
87
+
88
+
89
+ class Embedding(nn.Module):
90
+ _train_dtype = "bf16"
91
+
92
+ def __init__(self, config):
93
+ super().__init__()
94
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
95
+
96
+ def embed(self, input_ids, position_ids=None, tokentype_ids=None):
97
+ embeddings = self.word_embeddings(input_ids)
98
+ return embeddings
99
+
100
+ def unembed(self, u):
101
+ weight = self.word_embeddings.weight
102
+ return torch.matmul(u, weight)
103
+
104
+
105
+ class VocabParallelEmbedding(nn.Embedding):
106
+ "Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"
107
+
108
+ def __init__(self, config):
109
+ vocab_size, process_group, padding_idx = (
110
+ config.vocab_size,
111
+ config.get("process_group", None),
112
+ config.get("padding_idx", None),
113
+ )
114
+ self.process_group = process_group
115
+ if process_group is not None:
116
+ world_size = torch.distributed.get_world_size(process_group)
117
+ if vocab_size % world_size != 0:
118
+ raise ValueError(
119
+ f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})"
120
+ )
121
+ if world_size > 1 and padding_idx is not None:
122
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
123
+ else:
124
+ world_size = 1
125
+ super().__init__(
126
+ vocab_size // world_size,
127
+ embedding_dim=config.hidden_size,
128
+ padding_idx=padding_idx,
129
+ )
130
+
131
+ def embed(self, x: Tensor) -> Tensor:
132
+ if self.process_group is None:
133
+ return self.forward(x)
134
+ else:
135
+ rank = torch.distributed.get_rank(self.process_group)
136
+ vocab_size = self.num_embeddings
137
+ vocab_start_index, vocab_end_index = (
138
+ rank * vocab_size,
139
+ (rank + 1) * vocab_size,
140
+ )
141
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
142
+ input_ids_mask = (x < vocab_start_index) | (x >= vocab_end_index)
143
+ x = x - vocab_start_index
144
+ x[input_ids_mask] = 0
145
+ embeddings = self.forward(x)
146
+ embeddings[input_ids_mask] = 0.0
147
+ # Reduce to the global process group
148
+ torch.distributed.all_reduce(embeddings, group=self.process_group)
149
+ return embeddings
150
+
151
+ def unembed(self, u: Tensor) -> Tensor:
152
+ if self.process_group is None:
153
+ return u @ self.weight.T
154
+ else:
155
+ raise NotImplementedError
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc5f3b6258c1a7e513cc9e41a326d8d5e0f32d112408273be54ba69b522b50de
3
+ size 4980059464
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6c6760a34950595555656d00f19cb5b1620e5a47cc3a1a0c56a3a1f057ebfa1
3
+ size 4929849248
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fdeef9c8c68ed48bc97e6ddab502fa95d8327dfe917a33c4db079e0fc29a7267
3
+ size 3003304856
model.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+ # Note: MP and PP utilities are removed for ease of use and editing.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .cache import InferenceParams, RecurrentInferenceParams
11
+ from .engine import HyenaInferenceEngine
12
+ from .layers import ParallelGatedMLP, RMSNorm, VocabParallelEmbedding
13
+ from .utils import column_split, print_rank_0
14
+
15
+ try:
16
+ from flash_attn.modules.mha import MHA
17
+ except ImportError:
18
+ "flash_attn not installed"
19
+
20
+ try:
21
+ from .positional_embeddings import swap_mha_rope
22
+ except ImportError:
23
+ "could not import swap_mha_rope from positional_embeddings.py"
24
+
25
+ # dummy import to force huggingface to bundle the tokenizer
26
+ from .tokenizer import ByteTokenizer
27
+
28
+
29
+ class AttentionBlock(nn.Module):
30
+ def __init__(self, config, layer_idx) -> None:
31
+ super().__init__()
32
+ self.config = config
33
+ self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
34
+ self.layer_idx = layer_idx
35
+ self.proj_groups = config.get("proj_groups", 1)
36
+ dtype = config.get("attn_block_dtype", torch.bfloat16)
37
+ mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
38
+ self.num_attention_heads = config.num_attention_heads
39
+ self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads
40
+
41
+ self.counter = 0
42
+ self.inner_mha_cls = MHA(
43
+ embed_dim=config.hidden_size,
44
+ num_heads=config.num_attention_heads,
45
+ num_heads_kv=config.num_attention_heads // self.proj_groups,
46
+ rotary_emb_dim=config.hidden_size // config.num_attention_heads,
47
+ qkv_proj_bias=config.get("qkv_proj_bias", True),
48
+ rotary_emb_base=config.get("rotary_emb_base", 10000),
49
+ causal=True,
50
+ layer_idx=layer_idx,
51
+ out_proj_bias=config.get("mha_out_proj_bias", True),
52
+ use_flash_attn=self.config.use_flash_attn,
53
+ ).to(dtype=dtype)
54
+
55
+ # check if using interpolated rotary pos emb from config, and swap the rope emb
56
+ if config.get("use_interpolated_rotary_pos_emb", False):
57
+ swap_mha_rope(
58
+ mha=self.inner_mha_cls,
59
+ kwargs_new_rope={'scaling_factor': config.get("rotary_emb_scaling_factor", 1.)},
60
+ )
61
+
62
+ if self.config.get("smeared_gqa", False):
63
+ self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
64
+ self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
65
+
66
+ self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
67
+
68
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
69
+ if (
70
+ type(padding_mask) == torch.Tensor
71
+ ): # workaround for masking bug in FA. This works because Wqkv does not have bias
72
+ # and attention scores will be also automatically zeroed.
73
+ u = u * padding_mask[..., None]
74
+ u = (
75
+ self.inner_mha_cls(
76
+ self.pre_norm(u),
77
+ inference_params=inference_params,
78
+ )
79
+ + u
80
+ )
81
+ if type(padding_mask) == torch.Tensor: # guard against bias
82
+ u = u * padding_mask[..., None]
83
+ u = self.mlp(self.post_norm(u)) + u
84
+ return u, None
85
+
86
+
87
+ class ParallelHyenaFilter(nn.Module):
88
+ def __init__(self, config, layer_idx) -> None:
89
+ super().__init__()
90
+ self.config = config
91
+ self.layer_idx = layer_idx
92
+ self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)
93
+
94
+ self.use_flashfft = config.get("use_flashfft", False)
95
+ self.state_size = config.state_size
96
+ self.hidden_size = config.hidden_size
97
+ self.num_filters = config.num_filters
98
+ self.inference_mode = config.get("inference_mode", True)
99
+ self.counter = 0
100
+ self.column_split_hyena = config.get("column_split_hyena", True)
101
+
102
+ assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size
103
+
104
+ self.D = nn.Parameter(torch.zeros(self.hidden_size))
105
+
106
+ # attention heads are not used except to split post short_filter
107
+ # projections in the same way as the checkpoint
108
+ self.num_attention_heads = config.num_attention_heads
109
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
110
+
111
+ # after preprocessing here we can save the new checkpoint
112
+ self.short_filter_length = config.short_filter_length
113
+ self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
114
+ self.short_filter_bias = (
115
+ nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
116
+ )
117
+
118
+ self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
119
+ self.use_flash_depthwise = config.get("use_flash_depthwise", False)
120
+ self.data_dtype = None
121
+
122
+ if self.use_flash_depthwise:
123
+ self.fir_fn = FlashDepthwiseConv1d(
124
+ channels=3 * self.hidden_size,
125
+ kernel_size=self.short_filter_length,
126
+ padding=self.short_filter_length - 1,
127
+ weights=self.short_filter_weight,
128
+ bias=self.short_filter_bias,
129
+ device=None,
130
+ dtype=self.config.get("depthwise_dtype", torch.bfloat16),
131
+ )
132
+ else:
133
+ self.fir_fn = F.conv1d
134
+
135
+ self.fftconv_fn = None
136
+ self.long_fir_threshold = config.get("long_fir_threshold", None)
137
+ if self.long_fir_threshold is not None:
138
+ assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"
139
+
140
+ self.num_systems = self.hidden_size // self.hyena_filter_groups
141
+
142
+ poles = torch.randn(self.num_systems, self.state_size, 1, 2)
143
+
144
+ # TODO: bring over init from internals
145
+ poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
146
+ poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)
147
+
148
+ self.poles = nn.Parameter(poles)
149
+
150
+ self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
151
+ self.h = None
152
+
153
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
154
+ if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
155
+ return self.sequential_forward(u, inference_params)
156
+
157
+ else:
158
+ return self.parallel_forward(u, inference_params, padding_mask)
159
+
160
+ def parallel_forward(self, u, inference_params=None, padding_mask=None):
161
+ L = u.shape[1]
162
+ z_pre, fir_state = self.engine.parallel_fir(
163
+ self.fir_fn,
164
+ u,
165
+ self.short_filter_weight,
166
+ self.short_filter_bias,
167
+ L,
168
+ fir_length=self.short_filter_length,
169
+ inference_params=inference_params,
170
+ padding_mask=padding_mask,
171
+ )
172
+ if inference_params:
173
+ inference_params.fir_state_dict[self.layer_idx] = fir_state
174
+
175
+ if self.h is None:
176
+ h, filter_dtype, poles, residues = self.compute_filter(L, u.device)
177
+ else:
178
+ h = self.h
179
+ filter_dtype = self.h.dtype
180
+
181
+ if self.hyena_filter_groups > 1:
182
+ h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)
183
+
184
+ # if inference_params is not None, we plan to perform generation:
185
+ # prefilling is handled by the engine.
186
+ dims = (
187
+ self.hidden_size,
188
+ self.num_attention_heads,
189
+ self.hidden_size_per_attention_head,
190
+ self.state_size,
191
+ self.hyena_filter_groups,
192
+ )
193
+ y = self.engine.parallel_iir(
194
+ z_pre,
195
+ h,
196
+ self.D,
197
+ L,
198
+ t=self.t,
199
+ poles=self.poles,
200
+ residues=self.residues,
201
+ dims=dims,
202
+ inference_params=inference_params,
203
+ layer_idx=self.layer_idx,
204
+ prefill_style=self.config.get("prefill_style", "fft"),
205
+ use_flashfft=self.use_flashfft,
206
+ fftconv_fn=self.fftconv_fn,
207
+ column_split_hyena=self.column_split_hyena,
208
+ long_fir_threshold=self.long_fir_threshold,
209
+ padding_mask=padding_mask,
210
+ )
211
+
212
+ return y, inference_params
213
+
214
+ def sequential_forward(self, u, inference_params):
215
+ if self.data_dtype is None:
216
+ self.data_dtype = u.dtype
217
+ if len(u.shape) > 2:
218
+ u = u[:, -1]
219
+
220
+ fir_state, iir_state = (
221
+ inference_params.fir_state_dict[self.layer_idx],
222
+ inference_params.state_dict[self.layer_idx],
223
+ )
224
+
225
+ z_pre, fir_state = self.engine.step_fir(
226
+ u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
227
+ )
228
+ x2, x1, v = (
229
+ column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
230
+ if self.column_split_hyena
231
+ else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
232
+ )
233
+
234
+ y, iir_state = self.engine.step_iir(
235
+ x2,
236
+ x1,
237
+ v,
238
+ self.D,
239
+ self.residues,
240
+ self.poles,
241
+ iir_state,
242
+ iir_groups=self.hyena_filter_groups,
243
+ )
244
+
245
+ inference_params.fir_state_dict[self.layer_idx] = fir_state
246
+ inference_params.state_dict[self.layer_idx] = iir_state
247
+ y = y.to(dtype=self.data_dtype)
248
+ return y[:, None], inference_params
249
+
250
+ def update_time(self, L, device):
251
+ """
252
+ Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
253
+ If L is greater than the length of the previous batch, then the time vector is
254
+ reinitialized. Otherwise, the time vector is truncated from cache.
255
+ """
256
+ if not hasattr(self, "t"):
257
+ self.t = torch.arange(L, device=device)[None, None]
258
+ elif self.t.shape[-1] < L:
259
+ self.t = torch.arange(L, device=device)[None, None]
260
+ else:
261
+ self.t = self.t[..., :L]
262
+
263
+ def compute_filter(self, L, device):
264
+ self.update_time(L, device)
265
+ filter_dtype = torch.float32
266
+ residues, log_poles = (
267
+ torch.view_as_complex(self.residues.to(filter_dtype)),
268
+ torch.view_as_complex(self.poles.to(filter_dtype)).log(),
269
+ )
270
+ h = (residues * (log_poles * self.t).exp()).real.sum(1)[None]
271
+ return h, filter_dtype, log_poles, residues
272
+
273
+
274
+ class ParallelGatedConvBlock(nn.Module):
275
+ def __init__(self, config, layer_idx) -> None:
276
+ super().__init__()
277
+ self.config = config
278
+ self.layer_idx = layer_idx
279
+ self.low_mem_mode = config.get("low_mem_mode", False)
280
+ dtype = config.get("hyena_block_dtype", torch.float32)
281
+ mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
282
+ self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(dtype=dtype)
283
+ self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
284
+ self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
285
+ self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
286
+ self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
287
+
288
+ self.proj_norm_fn = self.proj_norm
289
+ self.res_mlp_norm_fn = self.res_mlp_norm
290
+
291
+ if self.config.get("compile", False):
292
+ self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
293
+ self.res_mlp_norm_fn = torch.compile(
294
+ self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead"
295
+ )
296
+
297
+ def proj_norm(self, x):
298
+ return self.projections(self.pre_norm(x))
299
+
300
+ def res_mlp_norm(self, x):
301
+ return self.mlp(self.post_norm(x)) + x
302
+
303
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
304
+ z = self.proj_norm_fn(u)
305
+
306
+ if type(padding_mask) == torch.Tensor: # guard against bias
307
+ z = z * padding_mask[..., None]
308
+
309
+ z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)
310
+
311
+ z_in = self.out_filter_dense(z) + u
312
+
313
+ if type(padding_mask) == torch.Tensor: # guard against bias
314
+ z_in = z_in * padding_mask[..., None]
315
+
316
+ y = self.res_mlp_norm_fn(z_in)
317
+
318
+ return y, inference_params
319
+
320
+
321
+ def get_block(config, layer_idx, flash_fft=None):
322
+ if layer_idx in config.attn_layer_idxs:
323
+ return AttentionBlock(config, layer_idx)
324
+ elif layer_idx in config.hyena_layer_idxs:
325
+ block = ParallelGatedConvBlock(config, layer_idx)
326
+ if config.get("use_flashfft", "False"):
327
+ block.filter.fftconv_fn = flash_fft
328
+ return block
329
+ else:
330
+ raise NotImplementedError
331
+
332
+
333
+ class StripedHyena(nn.Module):
334
+ def __init__(self, config):
335
+ super().__init__()
336
+ self.config = config
337
+ self.embedding_layer = VocabParallelEmbedding(config)
338
+ self.norm = RMSNorm(config) if config.get("final_norm", True) else None
339
+ self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
340
+
341
+ if config.get("use_flashfft", "False"):
342
+ try:
343
+ from flashfftconv import FlashFFTConv
344
+ except:
345
+ raise ImportError
346
+ self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
347
+ else:
348
+ self.flash_fft = None
349
+
350
+ self.blocks = nn.ModuleList(
351
+ get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
352
+ )
353
+
354
+ def forward(self, x, inference_params_dict=None, padding_mask=None):
355
+ L = x.shape[1]
356
+ x = self.embedding_layer.embed(x)
357
+ if inference_params_dict is not None:
358
+ x, inference_params_dict_out = self.stateful_forward(
359
+ x,
360
+ inference_params_dict=inference_params_dict,
361
+ )
362
+ else:
363
+ x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)
364
+
365
+ x = self.norm(x)
366
+ x = self.unembed.unembed(x)
367
+ return x, inference_params_dict_out
368
+
369
+ def stateful_forward(self, x, inference_params_dict=None):
370
+ for block_idx, block in enumerate(self.blocks):
371
+ block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
372
+ inference_params = inference_params_dict[block_name]
373
+ x, _ = block(x, inference_params=inference_params)
374
+
375
+ return x, inference_params_dict
376
+
377
+ def stateless_forward(self, x, padding_mask=None):
378
+ if type(padding_mask) == torch.Tensor:
379
+ x = x * padding_mask[..., None]
380
+
381
+ for _, block in enumerate(self.blocks):
382
+ x, _ = block(x, inference_params=None, padding_mask=padding_mask)
383
+ return x, None
384
+
385
+ def initialize_inference_params(self):
386
+ print_rank_0("Initializing inference params...")
387
+ inference_params_dict = {
388
+ "mha": InferenceParams(
389
+ max_seqlen=self.config.get("max_seqlen", 8192),
390
+ max_batch_size=self.config.get("max_batch_size", 1),
391
+ seqlen_offset=0,
392
+ ),
393
+ "hyena": RecurrentInferenceParams(
394
+ fir_filter_length=self.config.short_filter_length,
395
+ state_dim=self.config.state_size,
396
+ seqlen_offset=0,
397
+ ),
398
+ }
399
+ return inference_params_dict
400
+
401
+ def precompute_filters(self, L, device):
402
+ for block_idx, block in enumerate(self.blocks):
403
+ if type(block) == ParallelGatedConvBlock:
404
+ if type(block.filter) == ParallelHyenaFilter:
405
+ L = block.filter.long_fir_threshold or L
406
+ print_rank_0(f"Precomputing filters, L={L}...")
407
+
408
+ filter_dtype = torch.float16 if L >= 2048 else torch.float32
409
+
410
+ block.filter._set_time(L, device)
411
+ residues, poles = (
412
+ torch.view_as_complex(block.filter.residues.to(torch.float16)),
413
+ torch.view_as_complex(block.filter.poles.to(torch.float16)),
414
+ )
415
+
416
+ block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
417
+ block.filter.h = block.filter.h.to(dtype=filter_dtype)
418
+
419
+ def load_poles_residues(self, path):
420
+ "Load different poles and residues for each layer."
421
+ for block_idx, block in enumerate(self.blocks):
422
+ if type(block) == ParallelGatedConvBlock:
423
+ if type(block.filter) == ParallelHyenaFilter:
424
+ print(f"Loading poles and residues for block {block_idx}")
425
+ poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
426
+ poles = torch.view_as_real(poles)
427
+ residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
428
+ residues = torch.view_as_real(residues)
429
+ poles = poles.permute(1, 0, 2).unsqueeze(-2)
430
+ residues = residues.permute(1, 0, 2).unsqueeze(-2)
431
+
432
+ block.filter.poles = nn.Parameter(poles)
433
+ block.filter.residues = nn.Parameter(residues)
434
+
435
+ def to_bfloat16_except_poles_residues(self):
436
+ """Convert all parameters to bfloat16 except for the poles and residues.
437
+
438
+ Particularly important for longer prompts.
439
+ """
440
+ for k, p in self.named_parameters():
441
+ if "poles" not in k and "residues" not in k:
442
+ p.data = p.data.to(torch.bfloat16)
443
+
444
+ def load_from_split_converted_state_dict(self, path):
445
+
446
+ print("Loading from split converted state dict")
447
+
448
+ embedding_weight = torch.load(path + "/layer_00.pt")["word_embeddings.weight"]
449
+ self.embedding_layer.weight = nn.Parameter(embedding_weight.to(self.embedding_layer.weight.dtype))
450
+
451
+ print("Loading embedding weight ok")
452
+
453
+ if self.config.get("final_norm", False) is not None:
454
+ idx = len(self.blocks) + 1
455
+ final_norm_scale = torch.load(path + f"/layer_{idx:02d}.pt")["norm.scale"]
456
+ self.norm.scale = nn.Parameter(final_norm_scale.to(self.norm.scale.dtype))
457
+
458
+ print("loading final norm ok")
459
+
460
+ if not self.config.get("tie_embeddings", True):
461
+ idx = len(self.blocks) + 2
462
+ embedding_weight = torch.load(path + f"/layer_{idx:02d}.pt")["word_embeddings.weight"]
463
+ self.unembed.weight = nn.Parameter(embedding_weight.to(self.unembed.weight.dtype))
464
+
465
+ print("loading unembed weight ok")
466
+
467
+ for block_idx, block in enumerate(self.blocks):
468
+ print("loading block {}...".format(block_idx))
469
+ # strict = False if type(block) == ParallelGatedConvBlock else True
470
+ # some blocks (optionally) go through a round of conv distillation on some parameters
471
+ strict = True # safer to be strict and account for every layer
472
+
473
+ loaded_dict = torch.load(path + f"/layer_{block_idx + 1:02d}.pt")
474
+ block.load_state_dict(loaded_dict, strict=strict)
model.safetensors.index.json ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 12913164672
4
+ },
5
+ "weight_map": {
6
+ "backbone.blocks.0.filter.D": "model-00001-of-00003.safetensors",
7
+ "backbone.blocks.0.filter.poles": "model-00001-of-00003.safetensors",
8
+ "backbone.blocks.0.filter.residues": "model-00001-of-00003.safetensors",
9
+ "backbone.blocks.0.filter.short_filter_bias": "model-00001-of-00003.safetensors",
10
+ "backbone.blocks.0.filter.short_filter_weight": "model-00001-of-00003.safetensors",
11
+ "backbone.blocks.0.mlp.l1.weight": "model-00001-of-00003.safetensors",
12
+ "backbone.blocks.0.mlp.l2.weight": "model-00001-of-00003.safetensors",
13
+ "backbone.blocks.0.mlp.l3.weight": "model-00001-of-00003.safetensors",
14
+ "backbone.blocks.0.out_filter_dense.bias": "model-00001-of-00003.safetensors",
15
+ "backbone.blocks.0.out_filter_dense.weight": "model-00001-of-00003.safetensors",
16
+ "backbone.blocks.0.post_norm.scale": "model-00001-of-00003.safetensors",
17
+ "backbone.blocks.0.pre_norm.scale": "model-00001-of-00003.safetensors",
18
+ "backbone.blocks.0.projections.bias": "model-00001-of-00003.safetensors",
19
+ "backbone.blocks.0.projections.weight": "model-00001-of-00003.safetensors",
20
+ "backbone.blocks.1.filter.D": "model-00001-of-00003.safetensors",
21
+ "backbone.blocks.1.filter.poles": "model-00001-of-00003.safetensors",
22
+ "backbone.blocks.1.filter.residues": "model-00001-of-00003.safetensors",
23
+ "backbone.blocks.1.filter.short_filter_bias": "model-00001-of-00003.safetensors",
24
+ "backbone.blocks.1.filter.short_filter_weight": "model-00001-of-00003.safetensors",
25
+ "backbone.blocks.1.mlp.l1.weight": "model-00001-of-00003.safetensors",
26
+ "backbone.blocks.1.mlp.l2.weight": "model-00001-of-00003.safetensors",
27
+ "backbone.blocks.1.mlp.l3.weight": "model-00001-of-00003.safetensors",
28
+ "backbone.blocks.1.out_filter_dense.bias": "model-00001-of-00003.safetensors",
29
+ "backbone.blocks.1.out_filter_dense.weight": "model-00001-of-00003.safetensors",
30
+ "backbone.blocks.1.post_norm.scale": "model-00001-of-00003.safetensors",
31
+ "backbone.blocks.1.pre_norm.scale": "model-00001-of-00003.safetensors",
32
+ "backbone.blocks.1.projections.bias": "model-00001-of-00003.safetensors",
33
+ "backbone.blocks.1.projections.weight": "model-00001-of-00003.safetensors",
34
+ "backbone.blocks.10.filter.D": "model-00001-of-00003.safetensors",
35
+ "backbone.blocks.10.filter.poles": "model-00001-of-00003.safetensors",
36
+ "backbone.blocks.10.filter.residues": "model-00001-of-00003.safetensors",
37
+ "backbone.blocks.10.filter.short_filter_bias": "model-00001-of-00003.safetensors",
38
+ "backbone.blocks.10.filter.short_filter_weight": "model-00001-of-00003.safetensors",
39
+ "backbone.blocks.10.mlp.l1.weight": "model-00001-of-00003.safetensors",
40
+ "backbone.blocks.10.mlp.l2.weight": "model-00001-of-00003.safetensors",
41
+ "backbone.blocks.10.mlp.l3.weight": "model-00001-of-00003.safetensors",
42
+ "backbone.blocks.10.out_filter_dense.bias": "model-00001-of-00003.safetensors",
43
+ "backbone.blocks.10.out_filter_dense.weight": "model-00001-of-00003.safetensors",
44
+ "backbone.blocks.10.post_norm.scale": "model-00001-of-00003.safetensors",
45
+ "backbone.blocks.10.pre_norm.scale": "model-00001-of-00003.safetensors",
46
+ "backbone.blocks.10.projections.bias": "model-00001-of-00003.safetensors",
47
+ "backbone.blocks.10.projections.weight": "model-00001-of-00003.safetensors",
48
+ "backbone.blocks.11.filter.D": "model-00001-of-00003.safetensors",
49
+ "backbone.blocks.11.filter.poles": "model-00001-of-00003.safetensors",
50
+ "backbone.blocks.11.filter.residues": "model-00001-of-00003.safetensors",
51
+ "backbone.blocks.11.filter.short_filter_bias": "model-00001-of-00003.safetensors",
52
+ "backbone.blocks.11.filter.short_filter_weight": "model-00001-of-00003.safetensors",
53
+ "backbone.blocks.11.mlp.l1.weight": "model-00001-of-00003.safetensors",
54
+ "backbone.blocks.11.mlp.l2.weight": "model-00001-of-00003.safetensors",
55
+ "backbone.blocks.11.mlp.l3.weight": "model-00001-of-00003.safetensors",
56
+ "backbone.blocks.11.out_filter_dense.bias": "model-00001-of-00003.safetensors",
57
+ "backbone.blocks.11.out_filter_dense.weight": "model-00001-of-00003.safetensors",
58
+ "backbone.blocks.11.post_norm.scale": "model-00001-of-00003.safetensors",
59
+ "backbone.blocks.11.pre_norm.scale": "model-00001-of-00003.safetensors",
60
+ "backbone.blocks.11.projections.bias": "model-00001-of-00003.safetensors",
61
+ "backbone.blocks.11.projections.weight": "model-00001-of-00003.safetensors",
62
+ "backbone.blocks.12.filter.D": "model-00001-of-00003.safetensors",
63
+ "backbone.blocks.12.filter.poles": "model-00001-of-00003.safetensors",
64
+ "backbone.blocks.12.filter.residues": "model-00001-of-00003.safetensors",
65
+ "backbone.blocks.12.filter.short_filter_bias": "model-00001-of-00003.safetensors",
66
+ "backbone.blocks.12.filter.short_filter_weight": "model-00001-of-00003.safetensors",
67
+ "backbone.blocks.12.mlp.l1.weight": "model-00002-of-00003.safetensors",
68
+ "backbone.blocks.12.mlp.l2.weight": "model-00002-of-00003.safetensors",
69
+ "backbone.blocks.12.mlp.l3.weight": "model-00002-of-00003.safetensors",
70
+ "backbone.blocks.12.out_filter_dense.bias": "model-00001-of-00003.safetensors",
71
+ "backbone.blocks.12.out_filter_dense.weight": "model-00001-of-00003.safetensors",
72
+ "backbone.blocks.12.post_norm.scale": "model-00001-of-00003.safetensors",
73
+ "backbone.blocks.12.pre_norm.scale": "model-00001-of-00003.safetensors",
74
+ "backbone.blocks.12.projections.bias": "model-00001-of-00003.safetensors",
75
+ "backbone.blocks.12.projections.weight": "model-00001-of-00003.safetensors",
76
+ "backbone.blocks.13.filter.D": "model-00002-of-00003.safetensors",
77
+ "backbone.blocks.13.filter.poles": "model-00002-of-00003.safetensors",
78
+ "backbone.blocks.13.filter.residues": "model-00002-of-00003.safetensors",
79
+ "backbone.blocks.13.filter.short_filter_bias": "model-00002-of-00003.safetensors",
80
+ "backbone.blocks.13.filter.short_filter_weight": "model-00002-of-00003.safetensors",
81
+ "backbone.blocks.13.mlp.l1.weight": "model-00002-of-00003.safetensors",
82
+ "backbone.blocks.13.mlp.l2.weight": "model-00002-of-00003.safetensors",
83
+ "backbone.blocks.13.mlp.l3.weight": "model-00002-of-00003.safetensors",
84
+ "backbone.blocks.13.out_filter_dense.bias": "model-00002-of-00003.safetensors",
85
+ "backbone.blocks.13.out_filter_dense.weight": "model-00002-of-00003.safetensors",
86
+ "backbone.blocks.13.post_norm.scale": "model-00002-of-00003.safetensors",
87
+ "backbone.blocks.13.pre_norm.scale": "model-00002-of-00003.safetensors",
88
+ "backbone.blocks.13.projections.bias": "model-00002-of-00003.safetensors",
89
+ "backbone.blocks.13.projections.weight": "model-00002-of-00003.safetensors",
90
+ "backbone.blocks.14.filter.D": "model-00002-of-00003.safetensors",
91
+ "backbone.blocks.14.filter.poles": "model-00002-of-00003.safetensors",
92
+ "backbone.blocks.14.filter.residues": "model-00002-of-00003.safetensors",
93
+ "backbone.blocks.14.filter.short_filter_bias": "model-00002-of-00003.safetensors",
94
+ "backbone.blocks.14.filter.short_filter_weight": "model-00002-of-00003.safetensors",
95
+ "backbone.blocks.14.mlp.l1.weight": "model-00002-of-00003.safetensors",
96
+ "backbone.blocks.14.mlp.l2.weight": "model-00002-of-00003.safetensors",
97
+ "backbone.blocks.14.mlp.l3.weight": "model-00002-of-00003.safetensors",
98
+ "backbone.blocks.14.out_filter_dense.bias": "model-00002-of-00003.safetensors",
99
+ "backbone.blocks.14.out_filter_dense.weight": "model-00002-of-00003.safetensors",
100
+ "backbone.blocks.14.post_norm.scale": "model-00002-of-00003.safetensors",
101
+ "backbone.blocks.14.pre_norm.scale": "model-00002-of-00003.safetensors",
102
+ "backbone.blocks.14.projections.bias": "model-00002-of-00003.safetensors",
103
+ "backbone.blocks.14.projections.weight": "model-00002-of-00003.safetensors",
104
+ "backbone.blocks.15.filter.D": "model-00002-of-00003.safetensors",
105
+ "backbone.blocks.15.filter.poles": "model-00002-of-00003.safetensors",
106
+ "backbone.blocks.15.filter.residues": "model-00002-of-00003.safetensors",
107
+ "backbone.blocks.15.filter.short_filter_bias": "model-00002-of-00003.safetensors",
108
+ "backbone.blocks.15.filter.short_filter_weight": "model-00002-of-00003.safetensors",
109
+ "backbone.blocks.15.mlp.l1.weight": "model-00002-of-00003.safetensors",
110
+ "backbone.blocks.15.mlp.l2.weight": "model-00002-of-00003.safetensors",
111
+ "backbone.blocks.15.mlp.l3.weight": "model-00002-of-00003.safetensors",
112
+ "backbone.blocks.15.out_filter_dense.bias": "model-00002-of-00003.safetensors",
113
+ "backbone.blocks.15.out_filter_dense.weight": "model-00002-of-00003.safetensors",
114
+ "backbone.blocks.15.post_norm.scale": "model-00002-of-00003.safetensors",
115
+ "backbone.blocks.15.pre_norm.scale": "model-00002-of-00003.safetensors",
116
+ "backbone.blocks.15.projections.bias": "model-00002-of-00003.safetensors",
117
+ "backbone.blocks.15.projections.weight": "model-00002-of-00003.safetensors",
118
+ "backbone.blocks.16.inner_mha_cls.Wqkv.bias": "model-00002-of-00003.safetensors",
119
+ "backbone.blocks.16.inner_mha_cls.Wqkv.weight": "model-00002-of-00003.safetensors",
120
+ "backbone.blocks.16.inner_mha_cls.out_proj.bias": "model-00002-of-00003.safetensors",
121
+ "backbone.blocks.16.inner_mha_cls.out_proj.weight": "model-00002-of-00003.safetensors",
122
+ "backbone.blocks.16.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
123
+ "backbone.blocks.16.mlp.l1.weight": "model-00002-of-00003.safetensors",
124
+ "backbone.blocks.16.mlp.l2.weight": "model-00002-of-00003.safetensors",
125
+ "backbone.blocks.16.mlp.l3.weight": "model-00002-of-00003.safetensors",
126
+ "backbone.blocks.16.post_norm.scale": "model-00002-of-00003.safetensors",
127
+ "backbone.blocks.16.pre_norm.scale": "model-00002-of-00003.safetensors",
128
+ "backbone.blocks.17.filter.D": "model-00002-of-00003.safetensors",
129
+ "backbone.blocks.17.filter.poles": "model-00002-of-00003.safetensors",
130
+ "backbone.blocks.17.filter.residues": "model-00002-of-00003.safetensors",
131
+ "backbone.blocks.17.filter.short_filter_bias": "model-00002-of-00003.safetensors",
132
+ "backbone.blocks.17.filter.short_filter_weight": "model-00002-of-00003.safetensors",
133
+ "backbone.blocks.17.mlp.l1.weight": "model-00002-of-00003.safetensors",
134
+ "backbone.blocks.17.mlp.l2.weight": "model-00002-of-00003.safetensors",
135
+ "backbone.blocks.17.mlp.l3.weight": "model-00002-of-00003.safetensors",
136
+ "backbone.blocks.17.out_filter_dense.bias": "model-00002-of-00003.safetensors",
137
+ "backbone.blocks.17.out_filter_dense.weight": "model-00002-of-00003.safetensors",
138
+ "backbone.blocks.17.post_norm.scale": "model-00002-of-00003.safetensors",
139
+ "backbone.blocks.17.pre_norm.scale": "model-00002-of-00003.safetensors",
140
+ "backbone.blocks.17.projections.bias": "model-00002-of-00003.safetensors",
141
+ "backbone.blocks.17.projections.weight": "model-00002-of-00003.safetensors",
142
+ "backbone.blocks.18.filter.D": "model-00002-of-00003.safetensors",
143
+ "backbone.blocks.18.filter.poles": "model-00002-of-00003.safetensors",
144
+ "backbone.blocks.18.filter.residues": "model-00002-of-00003.safetensors",
145
+ "backbone.blocks.18.filter.short_filter_bias": "model-00002-of-00003.safetensors",
146
+ "backbone.blocks.18.filter.short_filter_weight": "model-00002-of-00003.safetensors",
147
+ "backbone.blocks.18.mlp.l1.weight": "model-00002-of-00003.safetensors",
148
+ "backbone.blocks.18.mlp.l2.weight": "model-00002-of-00003.safetensors",
149
+ "backbone.blocks.18.mlp.l3.weight": "model-00002-of-00003.safetensors",
150
+ "backbone.blocks.18.out_filter_dense.bias": "model-00002-of-00003.safetensors",
151
+ "backbone.blocks.18.out_filter_dense.weight": "model-00002-of-00003.safetensors",
152
+ "backbone.blocks.18.post_norm.scale": "model-00002-of-00003.safetensors",
153
+ "backbone.blocks.18.pre_norm.scale": "model-00002-of-00003.safetensors",
154
+ "backbone.blocks.18.projections.bias": "model-00002-of-00003.safetensors",
155
+ "backbone.blocks.18.projections.weight": "model-00002-of-00003.safetensors",
156
+ "backbone.blocks.19.filter.D": "model-00002-of-00003.safetensors",
157
+ "backbone.blocks.19.filter.poles": "model-00002-of-00003.safetensors",
158
+ "backbone.blocks.19.filter.residues": "model-00002-of-00003.safetensors",
159
+ "backbone.blocks.19.filter.short_filter_bias": "model-00002-of-00003.safetensors",
160
+ "backbone.blocks.19.filter.short_filter_weight": "model-00002-of-00003.safetensors",
161
+ "backbone.blocks.19.mlp.l1.weight": "model-00002-of-00003.safetensors",
162
+ "backbone.blocks.19.mlp.l2.weight": "model-00002-of-00003.safetensors",
163
+ "backbone.blocks.19.mlp.l3.weight": "model-00002-of-00003.safetensors",
164
+ "backbone.blocks.19.out_filter_dense.bias": "model-00002-of-00003.safetensors",
165
+ "backbone.blocks.19.out_filter_dense.weight": "model-00002-of-00003.safetensors",
166
+ "backbone.blocks.19.post_norm.scale": "model-00002-of-00003.safetensors",
167
+ "backbone.blocks.19.pre_norm.scale": "model-00002-of-00003.safetensors",
168
+ "backbone.blocks.19.projections.bias": "model-00002-of-00003.safetensors",
169
+ "backbone.blocks.19.projections.weight": "model-00002-of-00003.safetensors",
170
+ "backbone.blocks.2.filter.D": "model-00001-of-00003.safetensors",
171
+ "backbone.blocks.2.filter.poles": "model-00001-of-00003.safetensors",
172
+ "backbone.blocks.2.filter.residues": "model-00001-of-00003.safetensors",
173
+ "backbone.blocks.2.filter.short_filter_bias": "model-00001-of-00003.safetensors",
174
+ "backbone.blocks.2.filter.short_filter_weight": "model-00001-of-00003.safetensors",
175
+ "backbone.blocks.2.mlp.l1.weight": "model-00001-of-00003.safetensors",
176
+ "backbone.blocks.2.mlp.l2.weight": "model-00001-of-00003.safetensors",
177
+ "backbone.blocks.2.mlp.l3.weight": "model-00001-of-00003.safetensors",
178
+ "backbone.blocks.2.out_filter_dense.bias": "model-00001-of-00003.safetensors",
179
+ "backbone.blocks.2.out_filter_dense.weight": "model-00001-of-00003.safetensors",
180
+ "backbone.blocks.2.post_norm.scale": "model-00001-of-00003.safetensors",
181
+ "backbone.blocks.2.pre_norm.scale": "model-00001-of-00003.safetensors",
182
+ "backbone.blocks.2.projections.bias": "model-00001-of-00003.safetensors",
183
+ "backbone.blocks.2.projections.weight": "model-00001-of-00003.safetensors",
184
+ "backbone.blocks.20.filter.D": "model-00002-of-00003.safetensors",
185
+ "backbone.blocks.20.filter.poles": "model-00002-of-00003.safetensors",
186
+ "backbone.blocks.20.filter.residues": "model-00002-of-00003.safetensors",
187
+ "backbone.blocks.20.filter.short_filter_bias": "model-00002-of-00003.safetensors",
188
+ "backbone.blocks.20.filter.short_filter_weight": "model-00002-of-00003.safetensors",
189
+ "backbone.blocks.20.mlp.l1.weight": "model-00002-of-00003.safetensors",
190
+ "backbone.blocks.20.mlp.l2.weight": "model-00002-of-00003.safetensors",
191
+ "backbone.blocks.20.mlp.l3.weight": "model-00002-of-00003.safetensors",
192
+ "backbone.blocks.20.out_filter_dense.bias": "model-00002-of-00003.safetensors",
193
+ "backbone.blocks.20.out_filter_dense.weight": "model-00002-of-00003.safetensors",
194
+ "backbone.blocks.20.post_norm.scale": "model-00002-of-00003.safetensors",
195
+ "backbone.blocks.20.pre_norm.scale": "model-00002-of-00003.safetensors",
196
+ "backbone.blocks.20.projections.bias": "model-00002-of-00003.safetensors",
197
+ "backbone.blocks.20.projections.weight": "model-00002-of-00003.safetensors",
198
+ "backbone.blocks.21.filter.D": "model-00002-of-00003.safetensors",
199
+ "backbone.blocks.21.filter.poles": "model-00002-of-00003.safetensors",
200
+ "backbone.blocks.21.filter.residues": "model-00002-of-00003.safetensors",
201
+ "backbone.blocks.21.filter.short_filter_bias": "model-00002-of-00003.safetensors",
202
+ "backbone.blocks.21.filter.short_filter_weight": "model-00002-of-00003.safetensors",
203
+ "backbone.blocks.21.mlp.l1.weight": "model-00002-of-00003.safetensors",
204
+ "backbone.blocks.21.mlp.l2.weight": "model-00002-of-00003.safetensors",
205
+ "backbone.blocks.21.mlp.l3.weight": "model-00002-of-00003.safetensors",
206
+ "backbone.blocks.21.out_filter_dense.bias": "model-00002-of-00003.safetensors",
207
+ "backbone.blocks.21.out_filter_dense.weight": "model-00002-of-00003.safetensors",
208
+ "backbone.blocks.21.post_norm.scale": "model-00002-of-00003.safetensors",
209
+ "backbone.blocks.21.pre_norm.scale": "model-00002-of-00003.safetensors",
210
+ "backbone.blocks.21.projections.bias": "model-00002-of-00003.safetensors",
211
+ "backbone.blocks.21.projections.weight": "model-00002-of-00003.safetensors",
212
+ "backbone.blocks.22.filter.D": "model-00002-of-00003.safetensors",
213
+ "backbone.blocks.22.filter.poles": "model-00002-of-00003.safetensors",
214
+ "backbone.blocks.22.filter.residues": "model-00002-of-00003.safetensors",
215
+ "backbone.blocks.22.filter.short_filter_bias": "model-00002-of-00003.safetensors",
216
+ "backbone.blocks.22.filter.short_filter_weight": "model-00002-of-00003.safetensors",
217
+ "backbone.blocks.22.mlp.l1.weight": "model-00002-of-00003.safetensors",
218
+ "backbone.blocks.22.mlp.l2.weight": "model-00002-of-00003.safetensors",
219
+ "backbone.blocks.22.mlp.l3.weight": "model-00002-of-00003.safetensors",
220
+ "backbone.blocks.22.out_filter_dense.bias": "model-00002-of-00003.safetensors",
221
+ "backbone.blocks.22.out_filter_dense.weight": "model-00002-of-00003.safetensors",
222
+ "backbone.blocks.22.post_norm.scale": "model-00002-of-00003.safetensors",
223
+ "backbone.blocks.22.pre_norm.scale": "model-00002-of-00003.safetensors",
224
+ "backbone.blocks.22.projections.bias": "model-00002-of-00003.safetensors",
225
+ "backbone.blocks.22.projections.weight": "model-00002-of-00003.safetensors",
226
+ "backbone.blocks.23.filter.D": "model-00002-of-00003.safetensors",
227
+ "backbone.blocks.23.filter.poles": "model-00002-of-00003.safetensors",
228
+ "backbone.blocks.23.filter.residues": "model-00002-of-00003.safetensors",
229
+ "backbone.blocks.23.filter.short_filter_bias": "model-00002-of-00003.safetensors",
230
+ "backbone.blocks.23.filter.short_filter_weight": "model-00002-of-00003.safetensors",
231
+ "backbone.blocks.23.mlp.l1.weight": "model-00002-of-00003.safetensors",
232
+ "backbone.blocks.23.mlp.l2.weight": "model-00002-of-00003.safetensors",
233
+ "backbone.blocks.23.mlp.l3.weight": "model-00002-of-00003.safetensors",
234
+ "backbone.blocks.23.out_filter_dense.bias": "model-00002-of-00003.safetensors",
235
+ "backbone.blocks.23.out_filter_dense.weight": "model-00002-of-00003.safetensors",
236
+ "backbone.blocks.23.post_norm.scale": "model-00002-of-00003.safetensors",
237
+ "backbone.blocks.23.pre_norm.scale": "model-00002-of-00003.safetensors",
238
+ "backbone.blocks.23.projections.bias": "model-00002-of-00003.safetensors",
239
+ "backbone.blocks.23.projections.weight": "model-00002-of-00003.safetensors",
240
+ "backbone.blocks.24.inner_mha_cls.Wqkv.bias": "model-00002-of-00003.safetensors",
241
+ "backbone.blocks.24.inner_mha_cls.Wqkv.weight": "model-00002-of-00003.safetensors",
242
+ "backbone.blocks.24.inner_mha_cls.out_proj.bias": "model-00002-of-00003.safetensors",
243
+ "backbone.blocks.24.inner_mha_cls.out_proj.weight": "model-00002-of-00003.safetensors",
244
+ "backbone.blocks.24.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
245
+ "backbone.blocks.24.mlp.l1.weight": "model-00002-of-00003.safetensors",
246
+ "backbone.blocks.24.mlp.l2.weight": "model-00003-of-00003.safetensors",
247
+ "backbone.blocks.24.mlp.l3.weight": "model-00003-of-00003.safetensors",
248
+ "backbone.blocks.24.post_norm.scale": "model-00002-of-00003.safetensors",
249
+ "backbone.blocks.24.pre_norm.scale": "model-00002-of-00003.safetensors",
250
+ "backbone.blocks.25.filter.D": "model-00003-of-00003.safetensors",
251
+ "backbone.blocks.25.filter.poles": "model-00003-of-00003.safetensors",
252
+ "backbone.blocks.25.filter.residues": "model-00003-of-00003.safetensors",
253
+ "backbone.blocks.25.filter.short_filter_bias": "model-00003-of-00003.safetensors",
254
+ "backbone.blocks.25.filter.short_filter_weight": "model-00003-of-00003.safetensors",
255
+ "backbone.blocks.25.mlp.l1.weight": "model-00003-of-00003.safetensors",
256
+ "backbone.blocks.25.mlp.l2.weight": "model-00003-of-00003.safetensors",
257
+ "backbone.blocks.25.mlp.l3.weight": "model-00003-of-00003.safetensors",
258
+ "backbone.blocks.25.out_filter_dense.bias": "model-00003-of-00003.safetensors",
259
+ "backbone.blocks.25.out_filter_dense.weight": "model-00003-of-00003.safetensors",
260
+ "backbone.blocks.25.post_norm.scale": "model-00003-of-00003.safetensors",
261
+ "backbone.blocks.25.pre_norm.scale": "model-00003-of-00003.safetensors",
262
+ "backbone.blocks.25.projections.bias": "model-00003-of-00003.safetensors",
263
+ "backbone.blocks.25.projections.weight": "model-00003-of-00003.safetensors",
264
+ "backbone.blocks.26.filter.D": "model-00003-of-00003.safetensors",
265
+ "backbone.blocks.26.filter.poles": "model-00003-of-00003.safetensors",
266
+ "backbone.blocks.26.filter.residues": "model-00003-of-00003.safetensors",
267
+ "backbone.blocks.26.filter.short_filter_bias": "model-00003-of-00003.safetensors",
268
+ "backbone.blocks.26.filter.short_filter_weight": "model-00003-of-00003.safetensors",
269
+ "backbone.blocks.26.mlp.l1.weight": "model-00003-of-00003.safetensors",
270
+ "backbone.blocks.26.mlp.l2.weight": "model-00003-of-00003.safetensors",
271
+ "backbone.blocks.26.mlp.l3.weight": "model-00003-of-00003.safetensors",
272
+ "backbone.blocks.26.out_filter_dense.bias": "model-00003-of-00003.safetensors",
273
+ "backbone.blocks.26.out_filter_dense.weight": "model-00003-of-00003.safetensors",
274
+ "backbone.blocks.26.post_norm.scale": "model-00003-of-00003.safetensors",
275
+ "backbone.blocks.26.pre_norm.scale": "model-00003-of-00003.safetensors",
276
+ "backbone.blocks.26.projections.bias": "model-00003-of-00003.safetensors",
277
+ "backbone.blocks.26.projections.weight": "model-00003-of-00003.safetensors",
278
+ "backbone.blocks.27.filter.D": "model-00003-of-00003.safetensors",
279
+ "backbone.blocks.27.filter.poles": "model-00003-of-00003.safetensors",
280
+ "backbone.blocks.27.filter.residues": "model-00003-of-00003.safetensors",
281
+ "backbone.blocks.27.filter.short_filter_bias": "model-00003-of-00003.safetensors",
282
+ "backbone.blocks.27.filter.short_filter_weight": "model-00003-of-00003.safetensors",
283
+ "backbone.blocks.27.mlp.l1.weight": "model-00003-of-00003.safetensors",
284
+ "backbone.blocks.27.mlp.l2.weight": "model-00003-of-00003.safetensors",
285
+ "backbone.blocks.27.mlp.l3.weight": "model-00003-of-00003.safetensors",
286
+ "backbone.blocks.27.out_filter_dense.bias": "model-00003-of-00003.safetensors",
287
+ "backbone.blocks.27.out_filter_dense.weight": "model-00003-of-00003.safetensors",
288
+ "backbone.blocks.27.post_norm.scale": "model-00003-of-00003.safetensors",
289
+ "backbone.blocks.27.pre_norm.scale": "model-00003-of-00003.safetensors",
290
+ "backbone.blocks.27.projections.bias": "model-00003-of-00003.safetensors",
291
+ "backbone.blocks.27.projections.weight": "model-00003-of-00003.safetensors",
292
+ "backbone.blocks.28.filter.D": "model-00003-of-00003.safetensors",
293
+ "backbone.blocks.28.filter.poles": "model-00003-of-00003.safetensors",
294
+ "backbone.blocks.28.filter.residues": "model-00003-of-00003.safetensors",
295
+ "backbone.blocks.28.filter.short_filter_bias": "model-00003-of-00003.safetensors",
296
+ "backbone.blocks.28.filter.short_filter_weight": "model-00003-of-00003.safetensors",
297
+ "backbone.blocks.28.mlp.l1.weight": "model-00003-of-00003.safetensors",
298
+ "backbone.blocks.28.mlp.l2.weight": "model-00003-of-00003.safetensors",
299
+ "backbone.blocks.28.mlp.l3.weight": "model-00003-of-00003.safetensors",
300
+ "backbone.blocks.28.out_filter_dense.bias": "model-00003-of-00003.safetensors",
301
+ "backbone.blocks.28.out_filter_dense.weight": "model-00003-of-00003.safetensors",
302
+ "backbone.blocks.28.post_norm.scale": "model-00003-of-00003.safetensors",
303
+ "backbone.blocks.28.pre_norm.scale": "model-00003-of-00003.safetensors",
304
+ "backbone.blocks.28.projections.bias": "model-00003-of-00003.safetensors",
305
+ "backbone.blocks.28.projections.weight": "model-00003-of-00003.safetensors",
306
+ "backbone.blocks.29.filter.D": "model-00003-of-00003.safetensors",
307
+ "backbone.blocks.29.filter.poles": "model-00003-of-00003.safetensors",
308
+ "backbone.blocks.29.filter.residues": "model-00003-of-00003.safetensors",
309
+ "backbone.blocks.29.filter.short_filter_bias": "model-00003-of-00003.safetensors",
310
+ "backbone.blocks.29.filter.short_filter_weight": "model-00003-of-00003.safetensors",
311
+ "backbone.blocks.29.mlp.l1.weight": "model-00003-of-00003.safetensors",
312
+ "backbone.blocks.29.mlp.l2.weight": "model-00003-of-00003.safetensors",
313
+ "backbone.blocks.29.mlp.l3.weight": "model-00003-of-00003.safetensors",
314
+ "backbone.blocks.29.out_filter_dense.bias": "model-00003-of-00003.safetensors",
315
+ "backbone.blocks.29.out_filter_dense.weight": "model-00003-of-00003.safetensors",
316
+ "backbone.blocks.29.post_norm.scale": "model-00003-of-00003.safetensors",
317
+ "backbone.blocks.29.pre_norm.scale": "model-00003-of-00003.safetensors",
318
+ "backbone.blocks.29.projections.bias": "model-00003-of-00003.safetensors",
319
+ "backbone.blocks.29.projections.weight": "model-00003-of-00003.safetensors",
320
+ "backbone.blocks.3.filter.D": "model-00001-of-00003.safetensors",
321
+ "backbone.blocks.3.filter.poles": "model-00001-of-00003.safetensors",
322
+ "backbone.blocks.3.filter.residues": "model-00001-of-00003.safetensors",
323
+ "backbone.blocks.3.filter.short_filter_bias": "model-00001-of-00003.safetensors",
324
+ "backbone.blocks.3.filter.short_filter_weight": "model-00001-of-00003.safetensors",
325
+ "backbone.blocks.3.mlp.l1.weight": "model-00001-of-00003.safetensors",
326
+ "backbone.blocks.3.mlp.l2.weight": "model-00001-of-00003.safetensors",
327
+ "backbone.blocks.3.mlp.l3.weight": "model-00001-of-00003.safetensors",
328
+ "backbone.blocks.3.out_filter_dense.bias": "model-00001-of-00003.safetensors",
329
+ "backbone.blocks.3.out_filter_dense.weight": "model-00001-of-00003.safetensors",
330
+ "backbone.blocks.3.post_norm.scale": "model-00001-of-00003.safetensors",
331
+ "backbone.blocks.3.pre_norm.scale": "model-00001-of-00003.safetensors",
332
+ "backbone.blocks.3.projections.bias": "model-00001-of-00003.safetensors",
333
+ "backbone.blocks.3.projections.weight": "model-00001-of-00003.safetensors",
334
+ "backbone.blocks.30.filter.D": "model-00003-of-00003.safetensors",
335
+ "backbone.blocks.30.filter.poles": "model-00003-of-00003.safetensors",
336
+ "backbone.blocks.30.filter.residues": "model-00003-of-00003.safetensors",
337
+ "backbone.blocks.30.filter.short_filter_bias": "model-00003-of-00003.safetensors",
338
+ "backbone.blocks.30.filter.short_filter_weight": "model-00003-of-00003.safetensors",
339
+ "backbone.blocks.30.mlp.l1.weight": "model-00003-of-00003.safetensors",
340
+ "backbone.blocks.30.mlp.l2.weight": "model-00003-of-00003.safetensors",
341
+ "backbone.blocks.30.mlp.l3.weight": "model-00003-of-00003.safetensors",
342
+ "backbone.blocks.30.out_filter_dense.bias": "model-00003-of-00003.safetensors",
343
+ "backbone.blocks.30.out_filter_dense.weight": "model-00003-of-00003.safetensors",
344
+ "backbone.blocks.30.post_norm.scale": "model-00003-of-00003.safetensors",
345
+ "backbone.blocks.30.pre_norm.scale": "model-00003-of-00003.safetensors",
346
+ "backbone.blocks.30.projections.bias": "model-00003-of-00003.safetensors",
347
+ "backbone.blocks.30.projections.weight": "model-00003-of-00003.safetensors",
348
+ "backbone.blocks.31.filter.D": "model-00003-of-00003.safetensors",
349
+ "backbone.blocks.31.filter.poles": "model-00003-of-00003.safetensors",
350
+ "backbone.blocks.31.filter.residues": "model-00003-of-00003.safetensors",
351
+ "backbone.blocks.31.filter.short_filter_bias": "model-00003-of-00003.safetensors",
352
+ "backbone.blocks.31.filter.short_filter_weight": "model-00003-of-00003.safetensors",
353
+ "backbone.blocks.31.mlp.l1.weight": "model-00003-of-00003.safetensors",
354
+ "backbone.blocks.31.mlp.l2.weight": "model-00003-of-00003.safetensors",
355
+ "backbone.blocks.31.mlp.l3.weight": "model-00003-of-00003.safetensors",
356
+ "backbone.blocks.31.out_filter_dense.bias": "model-00003-of-00003.safetensors",
357
+ "backbone.blocks.31.out_filter_dense.weight": "model-00003-of-00003.safetensors",
358
+ "backbone.blocks.31.post_norm.scale": "model-00003-of-00003.safetensors",
359
+ "backbone.blocks.31.pre_norm.scale": "model-00003-of-00003.safetensors",
360
+ "backbone.blocks.31.projections.bias": "model-00003-of-00003.safetensors",
361
+ "backbone.blocks.31.projections.weight": "model-00003-of-00003.safetensors",
362
+ "backbone.blocks.4.filter.D": "model-00001-of-00003.safetensors",
363
+ "backbone.blocks.4.filter.poles": "model-00001-of-00003.safetensors",
364
+ "backbone.blocks.4.filter.residues": "model-00001-of-00003.safetensors",
365
+ "backbone.blocks.4.filter.short_filter_bias": "model-00001-of-00003.safetensors",
366
+ "backbone.blocks.4.filter.short_filter_weight": "model-00001-of-00003.safetensors",
367
+ "backbone.blocks.4.mlp.l1.weight": "model-00001-of-00003.safetensors",
368
+ "backbone.blocks.4.mlp.l2.weight": "model-00001-of-00003.safetensors",
369
+ "backbone.blocks.4.mlp.l3.weight": "model-00001-of-00003.safetensors",
370
+ "backbone.blocks.4.out_filter_dense.bias": "model-00001-of-00003.safetensors",
371
+ "backbone.blocks.4.out_filter_dense.weight": "model-00001-of-00003.safetensors",
372
+ "backbone.blocks.4.post_norm.scale": "model-00001-of-00003.safetensors",
373
+ "backbone.blocks.4.pre_norm.scale": "model-00001-of-00003.safetensors",
374
+ "backbone.blocks.4.projections.bias": "model-00001-of-00003.safetensors",
375
+ "backbone.blocks.4.projections.weight": "model-00001-of-00003.safetensors",
376
+ "backbone.blocks.5.filter.D": "model-00001-of-00003.safetensors",
377
+ "backbone.blocks.5.filter.poles": "model-00001-of-00003.safetensors",
378
+ "backbone.blocks.5.filter.residues": "model-00001-of-00003.safetensors",
379
+ "backbone.blocks.5.filter.short_filter_bias": "model-00001-of-00003.safetensors",
380
+ "backbone.blocks.5.filter.short_filter_weight": "model-00001-of-00003.safetensors",
381
+ "backbone.blocks.5.mlp.l1.weight": "model-00001-of-00003.safetensors",
382
+ "backbone.blocks.5.mlp.l2.weight": "model-00001-of-00003.safetensors",
383
+ "backbone.blocks.5.mlp.l3.weight": "model-00001-of-00003.safetensors",
384
+ "backbone.blocks.5.out_filter_dense.bias": "model-00001-of-00003.safetensors",
385
+ "backbone.blocks.5.out_filter_dense.weight": "model-00001-of-00003.safetensors",
386
+ "backbone.blocks.5.post_norm.scale": "model-00001-of-00003.safetensors",
387
+ "backbone.blocks.5.pre_norm.scale": "model-00001-of-00003.safetensors",
388
+ "backbone.blocks.5.projections.bias": "model-00001-of-00003.safetensors",
389
+ "backbone.blocks.5.projections.weight": "model-00001-of-00003.safetensors",
390
+ "backbone.blocks.6.filter.D": "model-00001-of-00003.safetensors",
391
+ "backbone.blocks.6.filter.poles": "model-00001-of-00003.safetensors",
392
+ "backbone.blocks.6.filter.residues": "model-00001-of-00003.safetensors",
393
+ "backbone.blocks.6.filter.short_filter_bias": "model-00001-of-00003.safetensors",
394
+ "backbone.blocks.6.filter.short_filter_weight": "model-00001-of-00003.safetensors",
395
+ "backbone.blocks.6.mlp.l1.weight": "model-00001-of-00003.safetensors",
396
+ "backbone.blocks.6.mlp.l2.weight": "model-00001-of-00003.safetensors",
397
+ "backbone.blocks.6.mlp.l3.weight": "model-00001-of-00003.safetensors",
398
+ "backbone.blocks.6.out_filter_dense.bias": "model-00001-of-00003.safetensors",
399
+ "backbone.blocks.6.out_filter_dense.weight": "model-00001-of-00003.safetensors",
400
+ "backbone.blocks.6.post_norm.scale": "model-00001-of-00003.safetensors",
401
+ "backbone.blocks.6.pre_norm.scale": "model-00001-of-00003.safetensors",
402
+ "backbone.blocks.6.projections.bias": "model-00001-of-00003.safetensors",
403
+ "backbone.blocks.6.projections.weight": "model-00001-of-00003.safetensors",
404
+ "backbone.blocks.7.filter.D": "model-00001-of-00003.safetensors",
405
+ "backbone.blocks.7.filter.poles": "model-00001-of-00003.safetensors",
406
+ "backbone.blocks.7.filter.residues": "model-00001-of-00003.safetensors",
407
+ "backbone.blocks.7.filter.short_filter_bias": "model-00001-of-00003.safetensors",
408
+ "backbone.blocks.7.filter.short_filter_weight": "model-00001-of-00003.safetensors",
409
+ "backbone.blocks.7.mlp.l1.weight": "model-00001-of-00003.safetensors",
410
+ "backbone.blocks.7.mlp.l2.weight": "model-00001-of-00003.safetensors",
411
+ "backbone.blocks.7.mlp.l3.weight": "model-00001-of-00003.safetensors",
412
+ "backbone.blocks.7.out_filter_dense.bias": "model-00001-of-00003.safetensors",
413
+ "backbone.blocks.7.out_filter_dense.weight": "model-00001-of-00003.safetensors",
414
+ "backbone.blocks.7.post_norm.scale": "model-00001-of-00003.safetensors",
415
+ "backbone.blocks.7.pre_norm.scale": "model-00001-of-00003.safetensors",
416
+ "backbone.blocks.7.projections.bias": "model-00001-of-00003.safetensors",
417
+ "backbone.blocks.7.projections.weight": "model-00001-of-00003.safetensors",
418
+ "backbone.blocks.8.inner_mha_cls.Wqkv.bias": "model-00001-of-00003.safetensors",
419
+ "backbone.blocks.8.inner_mha_cls.Wqkv.weight": "model-00001-of-00003.safetensors",
420
+ "backbone.blocks.8.inner_mha_cls.out_proj.bias": "model-00001-of-00003.safetensors",
421
+ "backbone.blocks.8.inner_mha_cls.out_proj.weight": "model-00001-of-00003.safetensors",
422
+ "backbone.blocks.8.inner_mha_cls.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
423
+ "backbone.blocks.8.mlp.l1.weight": "model-00001-of-00003.safetensors",
424
+ "backbone.blocks.8.mlp.l2.weight": "model-00001-of-00003.safetensors",
425
+ "backbone.blocks.8.mlp.l3.weight": "model-00001-of-00003.safetensors",
426
+ "backbone.blocks.8.post_norm.scale": "model-00001-of-00003.safetensors",
427
+ "backbone.blocks.8.pre_norm.scale": "model-00001-of-00003.safetensors",
428
+ "backbone.blocks.9.filter.D": "model-00001-of-00003.safetensors",
429
+ "backbone.blocks.9.filter.poles": "model-00001-of-00003.safetensors",
430
+ "backbone.blocks.9.filter.residues": "model-00001-of-00003.safetensors",
431
+ "backbone.blocks.9.filter.short_filter_bias": "model-00001-of-00003.safetensors",
432
+ "backbone.blocks.9.filter.short_filter_weight": "model-00001-of-00003.safetensors",
433
+ "backbone.blocks.9.mlp.l1.weight": "model-00001-of-00003.safetensors",
434
+ "backbone.blocks.9.mlp.l2.weight": "model-00001-of-00003.safetensors",
435
+ "backbone.blocks.9.mlp.l3.weight": "model-00001-of-00003.safetensors",
436
+ "backbone.blocks.9.out_filter_dense.bias": "model-00001-of-00003.safetensors",
437
+ "backbone.blocks.9.out_filter_dense.weight": "model-00001-of-00003.safetensors",
438
+ "backbone.blocks.9.post_norm.scale": "model-00001-of-00003.safetensors",
439
+ "backbone.blocks.9.pre_norm.scale": "model-00001-of-00003.safetensors",
440
+ "backbone.blocks.9.projections.bias": "model-00001-of-00003.safetensors",
441
+ "backbone.blocks.9.projections.weight": "model-00001-of-00003.safetensors",
442
+ "backbone.embedding_layer.weight": "model-00001-of-00003.safetensors",
443
+ "backbone.norm.scale": "model-00001-of-00003.safetensors"
444
+ }
445
+ }
modeling_hyena.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """StripedHyena custom code port for the Hugging Face Hub"""
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from .configuration_hyena import StripedHyenaConfig
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
9
+ from transformers.utils import logging
10
+ from typing import Optional, Tuple, Union
11
+ from .model import StripedHyena
12
+ from .utils import dotdict
13
+ from .cache import InferenceParams
14
+ from .engine import HyenaInferenceEngine
15
+ from .layers import RMSNorm
16
+ from .utils import dotdict, column_split
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
+ class StripedHyenaPreTrainedModel(PreTrainedModel):
22
+ config_class = StripedHyenaConfig
23
+ base_model_prefix = "sh"
24
+ supports_gradient_checkpointing = False
25
+ _no_split_modules = ["AttentionBlock", "ParallelGatedConvBlock"]
26
+ _skip_keys_device_placement = "past_key_values"
27
+ _keys_to_ignore_on_load_missing = [r"freq"]
28
+ _keys_to_ignore_on_load_unexpected = [r"fftconv", r"twiddle_factors"]
29
+ _supports_flash_attn_2 = True
30
+
31
+
32
+ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
33
+ supports_gradient_checkpointing = True
34
+
35
+ def __init__(self, config, **kwargs):
36
+ super().__init__(config, **kwargs)
37
+ model_config = dotdict(config.to_dict())
38
+ self.backbone = StripedHyena(model_config)
39
+ self.backbone.gradient_checkpointing = False
40
+ self.config = config
41
+ vocab_size = config.vocab_size
42
+ if vocab_size % config.make_vocab_size_divisible_by != 0:
43
+ vocab_size += config.make_vocab_size_divisible_by - (
44
+ vocab_size % config.make_vocab_size_divisible_by
45
+ )
46
+ self.vocab_size = vocab_size
47
+ self.post_init()
48
+ self.force_dtype()
49
+
50
+ def force_dtype(self):
51
+ self.backbone.to_bfloat16_except_poles_residues()
52
+
53
+ def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
54
+ self.backbone.gradient_checkpointing = enable
55
+
56
+ def get_input_embeddings(self):
57
+ return self.backbone.embedding_layer
58
+
59
+ def forward(
60
+ self,
61
+ input_ids: torch.LongTensor = None,
62
+ attention_mask: Optional[torch.LongTensor] = None,
63
+ labels: Optional[torch.LongTensor] = None,
64
+ use_cache: Optional[bool] = None,
65
+ output_attentions: Optional[bool] = None,
66
+ output_hidden_states: Optional[bool] = None,
67
+ past_key_values=None,
68
+ return_dict: Optional[bool] = None,
69
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
70
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
71
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
72
+
73
+ if use_cache:
74
+ if self.backbone.gradient_checkpointing and self.backbone.training:
75
+ logger.warning_once(
76
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
77
+ )
78
+ use_cache = False
79
+ elif labels is not None:
80
+ logger.warning_once(
81
+ "`use_cache=True` is incompatible with loss calculation. Setting `use_cache=False`..."
82
+ )
83
+ use_cache = False
84
+
85
+ inputs = input_ids
86
+ if use_cache:
87
+ if past_key_values is None:
88
+ past_key_values = self.backbone.initialize_inference_params()
89
+
90
+ batch_size = input_ids.shape[0]
91
+ past_key_values["mha"].max_batch_size = batch_size
92
+ past_key_values["hyena"].max_batch_size = batch_size
93
+ else:
94
+ seqlen_offset = past_key_values["mha"].seqlen_offset
95
+ if seqlen_offset == 0:
96
+ # second loop through generate will have prompt_len + 1 as seqlen
97
+ seqlen_offset = input_ids.shape[-1] - 1
98
+ past_key_values["hyena"].seqlen_offset = seqlen_offset
99
+ past_key_values["mha"].seqlen_offset = seqlen_offset
100
+ else:
101
+ past_key_values["mha"].seqlen_offset += 1
102
+ past_key_values["hyena"].seqlen_offset += 1
103
+
104
+ inputs = input_ids[
105
+ :,
106
+ -1:,
107
+ ]
108
+
109
+ logits, past_key_values = self.backbone(
110
+ inputs,
111
+ padding_mask=attention_mask,
112
+ inference_params_dict=past_key_values if use_cache else None,
113
+ )
114
+
115
+ loss = None
116
+ if labels is not None:
117
+ shift_logits = logits[..., :-1, :].contiguous()
118
+ shift_labels = labels[..., 1:].contiguous()
119
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
120
+ shift_labels = shift_labels.view(-1)
121
+ shift_labels = shift_labels.to(shift_logits.device)
122
+ loss = F.cross_entropy(shift_logits, shift_labels)
123
+
124
+ if return_dict:
125
+ return CausalLMOutputWithPast(
126
+ logits=logits,
127
+ hidden_states=None,
128
+ past_key_values=past_key_values if use_cache else None,
129
+ loss=loss,
130
+ )
131
+ else:
132
+ return logits
133
+
134
+ @classmethod
135
+ def can_generate(cls) -> bool:
136
+ return True
137
+
138
+ def prepare_inputs_for_generation(
139
+ self, input_ids, attention_mask=None, past_key_values=None, **kwargs
140
+ ):
141
+ return {
142
+ "input_ids": input_ids,
143
+ "attention_mask": attention_mask,
144
+ "past_key_values": past_key_values,
145
+ }
positional_embeddings.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This software is distributed under the terms of the Apache License, Version 2.0
2
+ # Author: Armin Thomas, Eric Nguyen
3
+
4
+ import torch
5
+ import copy
6
+ from einops import rearrange
7
+ from flash_attn.layers.rotary import RotaryEmbedding
8
+ from flash_attn.modules.mha import MHA
9
+
10
+
11
+ # simple wrapper for flash-attn RoPE with linear scaling:
12
+ class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
13
+ def __init__(
14
+ self,
15
+ dim: int,
16
+ scaling_factor: float=1.,
17
+ base=10000.0,
18
+ interleaved=False,
19
+ scale_base=None,
20
+ pos_idx_in_fp32=True,
21
+ device=None,
22
+ ):
23
+ super().__init__(
24
+ dim=dim,
25
+ base=base,
26
+ interleaved=interleaved,
27
+ scale_base=scale_base,
28
+ pos_idx_in_fp32=pos_idx_in_fp32,
29
+ device=device
30
+ )
31
+ self._linear_scaling_factor = scaling_factor
32
+ # adpated from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
33
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
34
+ # Reset the tables if the sequence length has changed,
35
+ # if we're on a new device (possibly due to tracing for instance),
36
+ # or if we're switching from inference mode to training
37
+ if (
38
+ seqlen > self._seq_len_cached
39
+ or self._cos_cached is None
40
+ or self._cos_cached.device != device
41
+ or self._cos_cached.dtype != dtype
42
+ or (self.training and self._cos_cached.is_inference())
43
+ ):
44
+ self._seq_len_cached = seqlen
45
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
46
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
47
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
48
+ if self.pos_idx_in_fp32:
49
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
50
+ # linear scaling:
51
+ t = t / self._linear_scaling_factor
52
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
53
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
54
+ # cos & sin output to change significantly.
55
+ # We want to recompute self.inv_freq if it was not loaded in fp32
56
+ if self.inv_freq.dtype != torch.float32:
57
+ inv_freq = self._compute_inv_freq(device=device)
58
+ else:
59
+ inv_freq = self.inv_freq
60
+ else:
61
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
62
+ # linear scaling:
63
+ t = t / self._linear_scaling_factor
64
+ inv_freq = self.inv_freq
65
+ # Don't do einsum, it converts fp32 to fp16 under AMP
66
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
67
+ freqs = torch.outer(t, inv_freq)
68
+ if self.scale is None:
69
+ self._cos_cached = torch.cos(freqs).to(dtype)
70
+ self._sin_cached = torch.sin(freqs).to(dtype)
71
+ else:
72
+ power = (
73
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
74
+ - seqlen // 2
75
+ ) / self.scale_base
76
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
77
+ # We want the multiplication by scale to happen in fp32
78
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
79
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
80
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
81
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
82
+
83
+ # swap out RoPE of existing mha:
84
+ def swap_mha_rope(
85
+ mha,
86
+ new_rope: torch.nn.Module=LinearlyScaledRotaryEmbedding,
87
+ kwargs_new_rope: dict=None
88
+ ):
89
+ # determine mha dtype and device:
90
+ dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
91
+ device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
92
+ # determine RoPE settings:
93
+ kwargs_old_rope = dict(
94
+ dim = mha.rotary_emb.dim,
95
+ base = mha.rotary_emb.base,
96
+ interleaved = mha.rotary_emb.interleaved,
97
+ scale_base = mha.rotary_emb.scale_base,
98
+ pos_idx_in_fp32 = mha.rotary_emb.pos_idx_in_fp32,
99
+ device = mha.rotary_emb.inv_freq.device
100
+ )
101
+ # delete old RoPE:
102
+ del mha.rotary_emb
103
+ # create new RoPE:
104
+ kwargs_new_rope = kwargs_new_rope or {'scaling_factor': 1.0}
105
+ scaled_rope = new_rope(
106
+ **kwargs_new_rope,
107
+ **kwargs_old_rope
108
+ ).to(dtype)
109
+ # attach new RoPE to mha:
110
+ mha.rotary_emb = scaled_rope
111
+ # make new sure RoPE is correctly registered:
112
+ assert isinstance(mha.rotary_emb, new_rope)
113
+ return mha
pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7f1a446b4063869fa7a5c7b0e94cd2f234c44f21c6168b0dd8747f0bf33ab46
3
+ size 16814399082
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
streamer.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+
3
+
4
+ class BaseStreamer:
5
+ """
6
+ Base class from which `.generate()` streamers should inherit.
7
+ """
8
+
9
+ def put(self, value):
10
+ """Function that is called by `.generate()` to push new tokens"""
11
+ raise NotImplementedError()
12
+
13
+ def end(self):
14
+ """Function that is called by `.generate()` to signal the end of generation"""
15
+ raise NotImplementedError()
16
+
17
+
18
+ class ByteStreamer(BaseStreamer):
19
+ """
20
+ Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
21
+
22
+ <Tip warning={true}>
23
+
24
+ The API for the streamer classes is still under development and may change in the future.
25
+
26
+ </Tip>
27
+
28
+ Parameters:
29
+ tokenizer (`AutoTokenizer`):
30
+ The tokenized used to decode the tokens.
31
+ skip_prompt (`bool`, *optional*, defaults to `False`):
32
+ Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
33
+ decode_kwargs (`dict`, *optional*):
34
+ Additional keyword arguments to pass to the tokenizer's `decode` method.
35
+
36
+ Examples:
37
+
38
+ ```python
39
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
40
+
41
+ >>> tok = AutoTokenizer.from_pretrained("gpt2")
42
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
43
+ >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
44
+ >>> streamer = TextStreamer(tok)
45
+
46
+ >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
47
+ >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
48
+ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
49
+ ```
50
+ """
51
+
52
+ def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
53
+ self.tokenizer = tokenizer
54
+ self.skip_prompt = skip_prompt
55
+ self.decode_kwargs = decode_kwargs
56
+
57
+ # variables used in the streaming process
58
+ self.token_cache = []
59
+ self.print_len = 0
60
+ self.next_tokens_are_prompt = True
61
+
62
+ def put(self, value):
63
+ """
64
+ Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
65
+ """
66
+ if len(value.shape) > 1 and value.shape[0] > 1:
67
+ raise ValueError("TextStreamer only supports batch size 1")
68
+ elif len(value.shape) > 1:
69
+ value = value[0]
70
+
71
+ if self.skip_prompt and self.next_tokens_are_prompt:
72
+ self.next_tokens_are_prompt = False
73
+ return
74
+
75
+ # Add the new token to the cache and decodes the entire thing.
76
+ self.token_cache.extend(value.tolist())
77
+ text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
78
+
79
+ # After the symbol for a new line, we flush the cache.
80
+ if text.endswith("\n"):
81
+ printable_text = text[self.print_len :]
82
+ self.token_cache = []
83
+ self.print_len = 0
84
+ else:
85
+ printable_text = text[self.print_len : self.print_len + 1]
86
+ self.print_len += len(printable_text)
87
+
88
+ self.on_finalized_text(printable_text)
89
+
90
+ def end(self):
91
+ """Flushes any remaining cache and prints a newline to stdout."""
92
+ # Flush the cache, if it exists
93
+ if len(self.token_cache) > 0:
94
+ text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
95
+ printable_text = text[self.print_len :]
96
+ self.token_cache = []
97
+ self.print_len = 0
98
+ else:
99
+ printable_text = ""
100
+
101
+ self.next_tokens_are_prompt = True
102
+ self.on_finalized_text(printable_text, stream_end=True)
103
+
104
+ def on_finalized_text(self, text: str, stream_end: bool = False):
105
+ """Prints the new text to stdout. If the stream is ending, also prints a newline."""
106
+ print(text, flush=True, end="" if not stream_end else None)
tokenizer.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
2
+ from __future__ import annotations
3
+
4
+ import torch
5
+
6
+ import numpy as np
7
+
8
+ from os import PathLike
9
+ from typing import List, Tuple
10
+
11
+ from tokenizers import Tokenizer
12
+ from transformers.tokenization_utils import PreTrainedTokenizer
13
+ from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
14
+ from transformers.utils.generic import TensorType, PaddingStrategy
15
+
16
+
17
+ EMPTY: str = ""
18
+
19
+
20
+ class ByteTokenizer(PreTrainedTokenizer):
21
+
22
+ """UTF-8 Encoder."""
23
+
24
+ @classmethod
25
+ def from_pretrained(cls, model_id: str | PathLike, **kwargs) -> ByteTokenizer:
26
+
27
+ return cls(**kwargs, byte_level=True)
28
+
29
+ @property
30
+ def vocab_size(self) -> int:
31
+
32
+ return 512
33
+
34
+ @property
35
+ def byte_level(self) -> bool:
36
+
37
+ return self.init_kwargs.get('byte_level', True)
38
+
39
+ def get_vocab(self) -> Dict[str, int]:
40
+
41
+ return {chr(i): i for i in range(self.vocab_size)}
42
+
43
+ def __len__(self) -> int:
44
+
45
+ return self.vocab_size
46
+
47
+ def clamp(self, n: int) -> int:
48
+
49
+ return max(32, min(n, self.vocab_size))
50
+
51
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
52
+
53
+ return list(text)
54
+
55
+ def byte_tokenize(self, text: str) -> np.ndarray:
56
+
57
+ return np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
58
+
59
+ def _convert_token_to_id(self, token: str) -> int:
60
+
61
+ return self.clamp(ord(token))
62
+
63
+ def _convert_id_to_token(self, index: int) -> str:
64
+
65
+ return chr(self.clamp(index))
66
+
67
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
68
+
69
+ return EMPTY.join(tokens)
70
+
71
+ def _decode(self, token_ids: List[int], **kwargs) -> str:
72
+
73
+ indices = np.asarray(token_ids, dtype=np.uint8)
74
+
75
+ return (
76
+ indices.clip(min=32, max=self.vocab_size, out=indices)
77
+ .tobytes()
78
+ .decode('utf-8')
79
+ )
80
+
81
+ def _encode_plus(self, text: str, **kwargs) -> BatchEncoding:
82
+
83
+ first_ids = self.byte_tokenize(text).tolist()
84
+
85
+ return self.prepare_for_model(
86
+ first_ids,
87
+ pair_ids=None,
88
+ add_special_tokens=kwargs.get('add_special_tokens', False),
89
+ padding=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD).value,
90
+ truncation=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE).value,
91
+ max_length=kwargs.get('max_length'),
92
+ stride=kwargs.get('stride', 0),
93
+ pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
94
+ return_tensors=kwargs.get('return_tensors'),
95
+ prepend_batch_axis=True,
96
+ return_attention_mask=kwargs.get('return_attention_mask'),
97
+ return_token_type_ids=kwargs.get('return_token_type_ids'),
98
+ return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
99
+ return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
100
+ return_length=kwargs.get('return_length', False),
101
+ verbose=kwargs.get('verbose', True),
102
+ )
103
+
104
+ def _batch_encode_plus(self, batch_text_or_text_pairs: List[str], **kwargs) -> BatchEncoding:
105
+
106
+ input_ids = [(self.byte_tokenize(text).tolist(), None) for text in batch_text_or_text_pairs]
107
+
108
+ return self._batch_prepare_for_model(
109
+ input_ids,
110
+ add_special_tokens=kwargs.get('add_special_tokens', False),
111
+ padding_strategy=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD),
112
+ truncation_strategy=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE),
113
+ max_length=kwargs.get('max_length'),
114
+ stride=kwargs.get('stride', 0),
115
+ pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
116
+ return_attention_mask=kwargs.get('return_attention_mask'),
117
+ return_token_type_ids=kwargs.get('return_token_type_ids'),
118
+ return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
119
+ return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
120
+ return_length=kwargs.get('return_length', False),
121
+ return_tensors=kwargs.get('return_tensors'),
122
+ verbose=kwargs.get('verbose', True),
123
+ )
124
+
125
+ def _save_pretrained(
126
+ self, save_directory: str | PathLike, file_names: Tuple[str], **kwargs
127
+ ) -> Tuple[str]:
128
+
129
+ return file_names
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenizer.ByteTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "byte_level": true,
10
+ "clean_up_tokenization_spaces": true,
11
+ "model_max_length": 1000000000000000019884624838656,
12
+ "padding_side": "left",
13
+ "truncation_side": "left"
14
+ }
utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def grab_first_if_tuple(x):
5
+ if x.__class__.__name__ == "tuple":
6
+ return x[0]
7
+ else:
8
+ return x
9
+
10
+
11
+ def column_split(x, num_heads, head_size):
12
+ """Split a tensor with `num_heads` alongside the head dimension, instead of
13
+ across heads. Fixed to three projections
14
+ """
15
+
16
+ x_reshaped = x.reshape(
17
+ x.shape[0],
18
+ num_heads,
19
+ 3 * head_size,
20
+ )
21
+
22
+ x2, x1, v = (
23
+ x_reshaped[:, :, :head_size],
24
+ x_reshaped[
25
+ :,
26
+ :,
27
+ head_size : 2 * head_size,
28
+ ],
29
+ x_reshaped[:, :, 2 * head_size :],
30
+ )
31
+ x2, x1, v = (
32
+ x2.reshape(x2.shape[0], -1),
33
+ x1.reshape(x1.shape[0], -1),
34
+ v.reshape(v.shape[0], -1),
35
+ )
36
+ return x2, x1, v
37
+
38
+
39
+ def get_init_from_string(init_str):
40
+ if type(init_str) == str:
41
+ if init_str == "torch.nn.init.zeros_":
42
+ return torch.nn.init.zeros_
43
+ elif init_str == "torch.nn.init.xavier_uniform_":
44
+ return torch.nn.init.xavier_uniform_
45
+ elif init_str == "torch.nn.init.xavier_normal_":
46
+ return torch.nn.init.xavier_normal_
47
+ else:
48
+ raise ValueError(f"Unrecognized init {init_str}")
49
+
50
+
51
+ def print_rank_0(message, debug=False, end="\n"):
52
+ """Print from rank 0 only."""
53
+ if torch.distributed.is_initialized():
54
+ if torch.distributed.get_rank() == 0:
55
+ print(message, flush=True, end=end)
56
+ else:
57
+ print(message, flush=True, end=end)
58
+
59
+
60
+ class dotdict(dict):
61
+ """dot.notation access to dictionary attributes"""
62
+
63
+ __getattr__ = dict.get
64
+ __setattr__ = dict.__setitem__
65
+ __delattr__ = dict.__delitem__
66
+
67
+
68
+ def ensure_divisibility(numerator, denominator):
69
+ """Ensure that numerator is divisible by the denominator."""
70
+ assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
71
+
72
+
73
+ def divide(numerator, denominator):
74
+ """Ensure that numerator is divisible by the denominator and return
75
+ the division value."""
76
+ ensure_divisibility(numerator, denominator)
77
+ return numerator // denominator
78
+
79
+
80
+ class VocabUtility:
81
+ """Split the vocabulary into `world_size` chunks amd return the
82
+ first and last index of the vocabulary belonging to the `rank`
83
+ partition: Note that indices in [first, last]"""
84
+
85
+ @staticmethod
86
+ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
87
+ index_f = rank * per_partition_vocab_size
88
+ index_l = index_f + per_partition_vocab_size
89
+ return index_f, index_l
90
+
91
+ @staticmethod
92
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
93
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
94
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(
95
+ per_partition_vocab_size, rank, world_size
96
+ )