Text Generation
ELM
English
dev-slx commited on
Commit
618a701
1 Parent(s): 2666a5d

Upload 9 files

Browse files
Files changed (9) hide show
  1. LICENSE +202 -0
  2. README.md +56 -3
  3. elm/infer_elm.py +132 -0
  4. elm/model.py +418 -0
  5. elm/positional_embeddings.py +86 -0
  6. elm/utils.py +25 -0
  7. models/.gitattributes +2 -0
  8. requirements.txt +2 -0
  9. run.py +24 -0
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,56 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SliceX AI™ ELM (Efficient Language Models)
2
+ This repository contains code to run our ELM models.
3
+
4
+ Models are located in the "models" folder. ELM models in this repository comes in three sizes (elm-1.0, elm-0.75 and elm-0.25) and supports the following use-cases.
5
+ - news_classification
6
+ - toxicity_detection
7
+ - news_content_generation
8
+
9
+ ## Download ELM repo
10
+ ```bash
11
+ git clone [email protected]:slicexai/elm-0.25-v0.1
12
+ sudo apt-get intall git-lfs
13
+ git lfs install
14
+ ```
15
+ (Optional) Installing git-lfs without sudo,
16
+ ```bash
17
+ wget https://github.com/git-lfs/git-lfs/releases/download/v3.2.0/git-lfs-linux-amd64-v3.2.0.tar.gz
18
+ tar -xzf git-lfs-linux-amd64-v3.2.0.tar.gz
19
+ PATH=$PATH:/<absolute-path>/git-lfs-3.2.0/
20
+ git lfs install
21
+ ```
22
+
23
+ ## Download ELM task-specific model checkpoints
24
+ ```bash
25
+ cd elm-0.25-v0.1
26
+ git lfs pull -I models/elm-0.25_news_classification/ckpt.pt
27
+ git lfs pull -I models/elm-0.25_toxicity_detection/ckpt.pt
28
+ git lfs pull -I models/elm-0.25_news_content_generation/ckpt.pt
29
+ ```
30
+
31
+ ## Installation
32
+ ```bash
33
+ pip install -r requirements.txt
34
+ ```
35
+
36
+ ## How to use - Run ELM on a sample task (e.g., news classification)
37
+ ```bash
38
+ python run.py <elm-model-directory>
39
+ E.g. python run.py models/elm-0.25_news_classification
40
+ ```
41
+ Prompts for the specific tasks can be found in the corresponding checkpoint directory. See an example below in the form of `models/elm-0.25_news_classification/example_prompts.json`.
42
+ ```json
43
+ {
44
+ "inputs": ["GM May Close Plant in Europe DETROIT (Reuters) - General Motors Corp. &lt;A HREF=\"http://www.investor.reuters.com/FullQuote.aspx?ticker=GM.N target=/stocks/quickinfo/fullquote\"&gt;GM.N&lt;/A&gt; will likely cut some jobs in Europe and may close a plant there as part of a restructuring plan under development to try to return the region to profitability, the U.S. automaker said on Wednesday."],
45
+ "template": "[INST]Below is a news article. Please classify it under one of the following classes (World, Business, Sports, Sci/Tech). Please format your response as a JSON payload.\n\n### Article: {input}\n\n### JSON Response:[/INST]"
46
+ }
47
+ ```
48
+
49
+ Running the above command returns the following response
50
+
51
+ ```json
52
+ {
53
+ "prompt": "[INST]Below is a news article. Please classify it under one of the following classes (World, Business, Sports, Sci/Tech). Please format your response as a JSON payload.\n\n### Article: GM May Close Plant in Europe DETROIT (Reuters) - General Motors Corp. &lt;A HREF=\"http://www.investor.reuters.com/FullQuote.aspx?ticker=GM.N target=/stocks/quickinfo/fullquote\"&gt;GM.N&lt;/A&gt; will likely cut some jobs in Europe and may close a plant there as part of a restructuring plan under development to try to return the region to profitability, the U.S. automaker said on Wednesday.\n\n### JSON Response:[/INST]",
54
+ "response": "{'text_label': 'Business'}"
55
+ }
56
+ ```
elm/infer_elm.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
2
+
3
+ from elm.model import *
4
+ from elm.utils import batchify
5
+ from transformers import AutoTokenizer
6
+ import json
7
+
8
+
9
+ def load_elm_model_and_tokenizer(local_path,
10
+ model_config_dict,
11
+ device="cuda",
12
+ load_partial=True,
13
+ get_num_layers_from_ckpt=True):
14
+ """Load ELM model and tokenizer from local checkpoint."""
15
+ model_args = ModelArgs(**model_config_dict)
16
+ model = load_elm_model_from_ckpt(local_path, device=device, model_args=model_args, load_partial=load_partial, get_num_layers_from_ckpt=get_num_layers_from_ckpt)
17
+
18
+ tokenizer = AutoTokenizer.from_pretrained(local_path)
19
+ tokenizer.padding_side = "left"
20
+ tokenizer.truncation_side = "left"
21
+ return model, tokenizer
22
+
23
+
24
+ def generate_elm_response_given_model(prompts, model, tokenizer,
25
+ device="cuda",
26
+ max_ctx_word_len=1024,
27
+ max_ctx_token_len=0,
28
+ max_new_tokens=500,
29
+ temperature=0.8, # set to 0 for greedy decoding
30
+ top_k=200,
31
+ return_tok_cnt=False,
32
+ return_gen_only=False,
33
+ early_stop_on_eos=False):
34
+ """Generate responses from ELM model given an input list of prompts ([str])."""
35
+ if max_ctx_token_len > 0:
36
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_ctx_token_len).to(device)
37
+ else:
38
+ prompts = [" ".join(p.split(" ")[-max_ctx_word_len:]) for p in prompts]
39
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
40
+
41
+ results = []
42
+
43
+ input_tok_cnt = torch.numel(inputs.input_ids)
44
+
45
+ model.eval()
46
+
47
+ out_tok_cnt = 0
48
+ with torch.no_grad():
49
+ temperature = temperature
50
+ top_k = top_k
51
+
52
+ outputs = model.generate(inputs.input_ids, max_new_tokens, temperature=temperature, top_k=top_k,
53
+ return_gen_only=return_gen_only)
54
+
55
+ if return_tok_cnt:
56
+ out_tok_cnt += torch.numel(outputs)
57
+
58
+ if early_stop_on_eos:
59
+ mod_outputs = []
60
+ for i in range(len(outputs)):
61
+ curr_out = outputs[i]
62
+
63
+ eos_loc_id = -1
64
+ for j in range(len(outputs[i])):
65
+ tok_id = outputs[i][j]
66
+ if tok_id == tokenizer.eos_token_id:
67
+ eos_loc_id = j
68
+ break
69
+ if eos_loc_id >= 0:
70
+ curr_out = outputs[i][:eos_loc_id]
71
+ mod_outputs.append(curr_out)
72
+ outputs = mod_outputs
73
+ detokenized_output = tokenizer.batch_decode(outputs, skip_special_tokens=False)
74
+
75
+ results = detokenized_output
76
+
77
+ if return_tok_cnt:
78
+ return results, (input_tok_cnt, out_tok_cnt)
79
+
80
+ return results
81
+
82
+ def generate_elm_responses(elm_model_path,
83
+ prompts,
84
+ device=None,
85
+ elm_model_config={},
86
+ eval_batch_size=1,
87
+ verbose=True):
88
+
89
+
90
+ if not device:
91
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
+ print(f"Setting device to {device}")
93
+
94
+ model_config_dict = {
95
+ "hidden_size": elm_model_config.get("hidden_size", 2048),
96
+ "max_inp_len": elm_model_config.get("max_inp_len", 2048),
97
+ "num_attention_heads": elm_model_config.get("num_attention_heads", 32),
98
+ "num_layers": elm_model_config.get("num_layers", 48),
99
+ "bits": elm_model_config.get("bits", 256),
100
+ "vocab_size": elm_model_config.get("vocab_size", 50304),
101
+ "dropout": elm_model_config.get("dropout", 0.1),
102
+ "use_rotary_embeddings": elm_model_config.get("use_rotary_embeddings", True)
103
+ }
104
+
105
+ model, tokenizer = load_elm_model_and_tokenizer(local_path=elm_model_path, model_config_dict=model_config_dict, device=device, load_partial=True)
106
+
107
+ #prompts = [prompt if "[INST]" in prompt else f"[INST]{prompt}[/INST]" for prompt in prompts]
108
+ max_new_tokens = 128
109
+ if "classification" in elm_model_path or "detection" in elm_model_path:
110
+ max_new_tokens = 12
111
+ result = []
112
+ for prompt_batch in batchify(prompts, eval_batch_size):
113
+ responses, _ = generate_elm_response_given_model(prompt_batch,
114
+ model,
115
+ tokenizer,
116
+ device=device,
117
+ max_ctx_word_len=1024,
118
+ max_ctx_token_len=512,
119
+ max_new_tokens=max_new_tokens,
120
+ return_tok_cnt=True,
121
+ return_gen_only=False,
122
+ temperature=0.0,
123
+ early_stop_on_eos=True)
124
+
125
+ for prompt, response in zip(prompt_batch, responses):
126
+ response = response.split("[/INST]")[-1].strip()
127
+ result.append(response)
128
+ if verbose:
129
+ print(json.dumps({"prompt": prompt, "response": response}, indent=4))
130
+ print("\n***\n")
131
+ return result
132
+
elm/model.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
2
+
3
+ import copy
4
+ import inspect
5
+ import math
6
+ import numpy as np
7
+ import os
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from elm.utils import *
17
+ from elm.positional_embeddings import *
18
+
19
+
20
+ def get_elm_model_map(model_name):
21
+ """Map the model type to corresponding class."""
22
+ elm_model_map = {
23
+ "rambutan": RambutanSlice,
24
+ }
25
+
26
+ return elm_model_map.get(model_name, RambutanSlice)
27
+
28
+
29
+ @dataclass
30
+ class ModelArgs:
31
+ """ELM Model Args"""
32
+ model_name_or_path: str = "ELM"
33
+ compile_model: bool = False
34
+ elm_model_class: Optional[str] = "rambutan"
35
+ hidden_size: Optional[int] = 2048
36
+ max_inp_len: Optional[int] = 2048
37
+ attn_window_size: Optional[int] = max_inp_len
38
+ num_attention_heads: Optional[int] = 32
39
+ layernorm_eps: float = 1e-5
40
+ attention_dropout: float = 0.1
41
+ hidden_dropout: float = 0.1
42
+ num_layers: Optional[int] = 16
43
+ bits: Optional[int] = 256
44
+ vocab_size: Optional[int] = 50304
45
+ dropout: Optional[int] = 0.1
46
+ use_rotary_embeddings: Optional[bool] = True
47
+ tokenizer: Optional[str] = None
48
+
49
+
50
+ class ELM(torch.nn.Module):
51
+ """ELM (SliceX GPT) model."""
52
+ def __init__(self,
53
+ model_args: ModelArgs):
54
+ """Initialize an ELM model instance."""
55
+ super().__init__()
56
+
57
+ self.model_args = model_args
58
+
59
+ elm_model_class = model_args.elm_model_class
60
+ hidden_size = model_args.hidden_size
61
+ max_inp_len = model_args.max_inp_len
62
+ num_attention_heads = model_args.num_attention_heads
63
+ layernorm_eps = model_args.layernorm_eps
64
+ attention_dropout = model_args.attention_dropout
65
+ hidden_dropout = model_args.hidden_dropout
66
+ num_layers = model_args.num_layers
67
+ bits = model_args.bits
68
+ vocab_size = model_args.vocab_size
69
+ use_rotary_embeddings = model_args.use_rotary_embeddings
70
+
71
+ layer_class = get_elm_model_map(elm_model_class)
72
+
73
+ self.slice_transformer = torch.nn.ModuleDict(dict(
74
+ temb = torch.nn.Embedding(vocab_size, hidden_size),
75
+ pemb = torch.nn.Embedding(max_inp_len, hidden_size) if not use_rotary_embeddings else None,
76
+ drop = torch.nn.Dropout(hidden_dropout),
77
+ h = torch.nn.ModuleList([ layer_class(model_args=model_args) for _ in range(num_layers) ]),
78
+ ln_f = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps),
79
+ ))
80
+
81
+ self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
82
+
83
+ print("Number of model parameters: %.2fM" % (self.get_num_params(False)/1e6,))
84
+
85
+
86
+ def forward(self,
87
+ x: torch.Tensor,
88
+ attention_mask: Optional[torch.Tensor] = None,
89
+ targets: Optional[torch.Tensor] = None):
90
+ device = x.device
91
+ batch, seqlen = x.size()
92
+
93
+
94
+ tok_emb = self.slice_transformer.temb(x)
95
+
96
+ if not self.model_args.use_rotary_embeddings:
97
+ pos = torch.arange(0, seqlen, dtype=torch.long, device=device)
98
+ pos_emb = self.slice_transformer.pemb(pos)
99
+ x = self.slice_transformer.drop(tok_emb + pos_emb)
100
+ else:
101
+ x = self.slice_transformer.drop(tok_emb)
102
+
103
+ tlayer_id = 0
104
+ ignore_index_id = -100
105
+ loss = torch.zeros(1).to(device)
106
+ loss_denom = 0
107
+
108
+ for tlayer in self.slice_transformer.h:
109
+ x = tlayer(x, attention_mask=attention_mask)
110
+
111
+ tlayer_id += 1
112
+
113
+ x = self.slice_transformer.ln_f(x)
114
+
115
+ if targets is not None:
116
+ logits = self.lm_head(x)
117
+
118
+ shift_logits = logits[..., :-1, :].contiguous()
119
+ shift_targets = targets[..., 1:].contiguous()
120
+ curr_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
121
+ shift_targets.view(-1),
122
+ ignore_index=ignore_index_id)
123
+ loss += curr_loss.float()
124
+ loss_denom += 1
125
+ else:
126
+ logits = self.lm_head(x[:, [-1], :])
127
+
128
+ loss = loss / loss_denom
129
+
130
+ return logits, loss
131
+
132
+
133
+ def get_num_params(self, non_embedding=True):
134
+ """
135
+ Return the number of parameters in the model.
136
+ For non-embedding count (default), the position embeddings get subtracted.
137
+ This assumes parameter tying between input and final layer embeddings. Oherwise
138
+ If there is no parameter sharing , set the flag to False to include parameters for both layers.
139
+ """
140
+ n_params = sum(p.numel() for p in self.parameters())
141
+ if non_embedding and not self.model_args.use_rotary_embeddings:
142
+ n_params -= self.slice_transformer.pemb.weight.numel()
143
+ return n_params
144
+
145
+
146
+ @torch.no_grad()
147
+ def generate(self, x, max_new_tokens, temperature=0.8, top_k=200, top_p=0.9,
148
+ return_gen_only=False):
149
+ max_inp_len = self.model_args.max_inp_len
150
+
151
+ for _ in range(max_new_tokens):
152
+ x_ctxt = x if x.size(1) <= max_inp_len else x[:, -max_inp_len:]
153
+
154
+ logits, _ = self(x_ctxt)
155
+
156
+ next_id = None
157
+
158
+ if temperature <= 0:
159
+ next_id = torch.argmax(logits, dim=-1)
160
+ else:
161
+ logits = logits[:, -1, :] / temperature
162
+
163
+ if top_k is not None:
164
+ v, k = torch.topk(logits, min(top_k, logits.size(-1)))
165
+ logits[logits < v[:, [-1]]] = -float('Inf')
166
+
167
+ probs = F.softmax(logits, dim=-1)
168
+
169
+ if top_p is None:
170
+ next_id = torch.multinomial(probs, num_samples=1)
171
+ else:
172
+ next_id = sample_top_p(probs, top_p)
173
+ x = torch.cat((x, next_id), dim=1)
174
+
175
+ if return_gen_only:
176
+ return x[:,-max_new_tokens:]
177
+
178
+ return x
179
+
180
+
181
+ class RambutanMLP(torch.nn.Module):
182
+ """RambutanMLP version of MLP module used in the ELM (SliceX GPT) Transformer block."""
183
+ def __init__(self, dim=768, bits=32, dropout = 0.0):
184
+ super(RambutanMLP, self).__init__()
185
+ self.dim = dim
186
+ self.bits = bits
187
+
188
+ self.dropout = torch.nn.Dropout(dropout)
189
+
190
+ self.A1_c_w = torch.nn.Linear(self.dim, self.bits, bias=True)
191
+
192
+ self.Hexperts = 4
193
+ self.Hexpertemb = torch.nn.Embedding(self.bits, self.dim)
194
+
195
+ self.expert_aggr = torch.nn.Linear(self.Hexperts, 1)
196
+
197
+
198
+ def forward(self, x):
199
+ h_c = torch.nn.functional.softmax(self.A1_c_w(x), dim=-1)
200
+
201
+ v, i = torch.topk(h_c, self.Hexperts)
202
+
203
+ if len(x.size()) < 3:
204
+ p = v.unsqueeze(-1).expand(-1,-1,self.dim)
205
+ else:
206
+ p = v.unsqueeze(-1).expand(-1,-1,-1,self.dim)
207
+
208
+ h_emb = p * self.Hexpertemb(i)
209
+
210
+ if len(x.size()) < 3:
211
+ out = self.expert_aggr(h_emb.transpose(1,2)).reshape(h_emb.size(0), -1)
212
+ else:
213
+ out = self.expert_aggr(h_emb.transpose(-2,-1)).reshape(x.size())
214
+
215
+ out = x * out
216
+ out = self.dropout(out)
217
+
218
+ return out
219
+
220
+
221
+ class RambutanSlice(torch.nn.Module):
222
+ """Rambutan version of ELM (SliceX GPT) Transformer block."""
223
+ def __init__(self,
224
+ model_args: ModelArgs):
225
+ super().__init__()
226
+
227
+ self.model_args = model_args
228
+
229
+ self.num_attention_heads = model_args.num_attention_heads
230
+ self.kv_channels = model_args.hidden_size // model_args.num_attention_heads
231
+ self.ln1 = torch.nn.LayerNorm(model_args.hidden_size, eps=model_args.layernorm_eps)
232
+ self.ln2 = torch.nn.LayerNorm(model_args.hidden_size, eps=model_args.layernorm_eps)
233
+ self.mlp = RambutanMLP(dim=model_args.hidden_size, bits=model_args.bits)
234
+ self.cattn = RambutanCausalSelfAttention(model_args=model_args)
235
+
236
+
237
+ def forward(self,
238
+ x: torch.Tensor,
239
+ attention_mask: torch.Tensor = None):
240
+ res = x
241
+
242
+ x = self.ln1(x)
243
+ x = self.cattn(x, attention_mask=attention_mask)
244
+
245
+ x = res + x
246
+ res = x
247
+ x = self.ln2(x)
248
+ x = self.mlp(x)
249
+
250
+ return x + res
251
+
252
+
253
+ class RambutanCausalSelfAttention(torch.nn.Module):
254
+ """Rambutan version of self-attention module used in the ELM (SliceX GPT) transformer block."""
255
+
256
+ def __init__(self,
257
+ model_args: ModelArgs):
258
+ super().__init__()
259
+
260
+ self.model_args = model_args
261
+
262
+ n_embd = model_args.hidden_size
263
+ n_head = model_args.num_attention_heads
264
+ bias = False
265
+ dropout = model_args.attention_dropout
266
+
267
+ assert n_embd % n_head == 0
268
+
269
+ self.c_attn = torch.nn.Linear(n_embd, 3 * n_embd, bias=bias)
270
+
271
+ self.c_proj = torch.nn.Linear(n_embd, n_embd, bias=bias)
272
+
273
+ self.attn_dropout = torch.nn.Dropout(dropout)
274
+ self.resid_dropout = torch.nn.Dropout(dropout)
275
+ self.n_head = n_head
276
+ self.n_embd = n_embd
277
+ self.dropout = dropout
278
+
279
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
280
+
281
+ if not self.flash:
282
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
283
+ self.rotary_embeddings = (
284
+ RotaryEmbedding(n_embd // n_head) if model_args.use_rotary_embeddings else None
285
+ )
286
+
287
+
288
+ def forward(self, x, attention_mask: torch.Tensor = None):
289
+ B, T, C = x.size()
290
+ device = x.device
291
+
292
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
293
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
294
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
295
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
296
+
297
+ if self.rotary_embeddings:
298
+ q, k = self.rotary_embeddings(q=q, k=k)
299
+
300
+ is_causal = True
301
+ attn_mask = None
302
+
303
+ if attention_mask is not None:
304
+ att_mask_input = attention_mask
305
+ att_mask_input = att_mask_input.unsqueeze(-1).expand(B, T, T)
306
+
307
+ if is_causal:
308
+ att_mask_causal = torch.tril(torch.ones(T, T)).view(1,T,T).expand(B,T,T).to(device)
309
+ attn_mask = (att_mask_causal * att_mask_input)
310
+ else:
311
+ attn_mask = att_mask_input
312
+
313
+ attn_mask = attn_mask.unsqueeze(1).expand(B, self.n_head, T, T)
314
+ attn_mask.float().to(device)
315
+
316
+
317
+ if self.flash:
318
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0, is_causal=True)
319
+ else:
320
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
321
+
322
+ if is_causal and attn_mask is None:
323
+ attn_mask = torch.tril(torch.ones(T, T)).view(1,T,T).expand(B,T,T).to(device)
324
+ attn_mask = attn_mask.unsqueeze(1).expand(B, self.n_head, T, T)
325
+
326
+ if attn_mask is not None:
327
+ att = att.masked_fill(attn_mask == 0, torch.finfo(att.dtype).min)
328
+
329
+ att = F.softmax(att, dim=-1)
330
+ att = self.attn_dropout(att)
331
+ y = att @ v
332
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
333
+
334
+ y = self.resid_dropout(self.c_proj(y))
335
+
336
+ return y
337
+
338
+
339
+ def init_elm_model(model_args=ModelArgs(), device="cuda", model_config_dict=None):
340
+ """Initialize ELM model using default or model_config parameters."""
341
+ if model_config_dict:
342
+ model_args = ModelArgs(**model_config_dict)
343
+
344
+ dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
345
+
346
+ model = ELM(model_args=model_args).to(dtype=dtype)
347
+
348
+ return model
349
+
350
+ def get_h_layers_in_ckpt(ckpt_state_dict,
351
+ layer_name_template = 'slice_transformer.h.{layer_num}.'):
352
+ num_layers_in_ckpt = 0
353
+ from collections import defaultdict
354
+ layer_wise_dict = defaultdict(lambda: defaultdict(list))
355
+
356
+ layer_num_found = True
357
+ while layer_num_found:
358
+ layer_num_found = False
359
+ for layer_name in ckpt_state_dict.keys():
360
+ if layer_name_template.format(layer_num=num_layers_in_ckpt) in layer_name:
361
+ layer_wise_dict[num_layers_in_ckpt][layer_name] = ckpt_state_dict[layer_name]
362
+ layer_num_found = True
363
+ num_layers_in_ckpt += 1
364
+ return layer_wise_dict
365
+
366
+ def load_elm_model_from_ckpt(ckpt_dir, device='cuda', load_partial=False, model_args=ModelArgs(), get_num_layers_from_ckpt=True):
367
+ """Load ELM model from local checkpoint."""
368
+ print(f"Loading ELM checkpoint from {ckpt_dir}")
369
+ ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')
370
+ checkpoint = torch.load(ckpt_path, map_location=device)
371
+
372
+ if get_num_layers_from_ckpt:
373
+ layer_name_template = 'slice_transformer.h.{layer_num}.'
374
+ ckpt_layer_wise_dict = get_h_layers_in_ckpt(checkpoint['model'],
375
+ layer_name_template = layer_name_template)
376
+ model_args.num_layers = len(ckpt_layer_wise_dict)
377
+ model = init_elm_model(model_args=model_args, device=device)
378
+ ckpt_state_dict = checkpoint['model']
379
+
380
+ unwanted_prefix = '_orig_mod.'
381
+ for k,v in list(ckpt_state_dict.items()):
382
+ if k.startswith(unwanted_prefix):
383
+ ckpt_state_dict[k[len(unwanted_prefix):]] = ckpt_state_dict.pop(k)
384
+
385
+ if load_partial:
386
+ mod_state_dict = model.state_dict()
387
+ for k,v in list(ckpt_state_dict.items()):
388
+ if k in mod_state_dict:
389
+ v_size = v.size()
390
+ mod_size = mod_state_dict[k].size()
391
+
392
+ if v_size == mod_size:
393
+ mod_state_dict[k] = v
394
+ else:
395
+ if len(v_size) == 1:
396
+ mod_state_dict[k][:v_size[-1]] = v
397
+ elif len(v_size) == 2:
398
+ mod_state_dict[k][:v_size[-2], :v_size[-1]] = v
399
+
400
+ ckpt_state_dict = mod_state_dict
401
+ load_status = model.load_state_dict(ckpt_state_dict)
402
+ print(load_status)
403
+ model.to(device)
404
+
405
+ return model
406
+
407
+
408
+ def sample_top_p(probs, threshold):
409
+ """Perform top-p sampling on probability distribution using a threshold."""
410
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
411
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
412
+ mask = probs_sum - probs_sort > threshold
413
+ probs_sort[mask] = 0.0
414
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
415
+ next_token = torch.multinomial(probs_sort, num_samples=1)
416
+ next_token = torch.gather(probs_idx, -1, next_token)
417
+
418
+ return next_token
elm/positional_embeddings.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional, Tuple
3
+
4
+
5
+ def rotate_half(x):
6
+ x1, x2 = x.chunk(2, dim=-1)
7
+ return torch.cat((-x2, x1), dim=-1)
8
+
9
+
10
+ @torch.jit.script
11
+ def apply_rotary_pos_emb(x, cos, sin):
12
+ # NOTE: This could probably be moved to Triton
13
+
14
+ # Handle a possible sequence length mismatch in between q and k
15
+ cos = cos[:, :, : x.shape[-2], :]
16
+ sin = sin[:, :, : x.shape[-2], :]
17
+
18
+ return (x * cos) + (rotate_half(x) * sin)
19
+
20
+
21
+ class RotaryEmbedding(torch.nn.Module):
22
+ """
23
+ Rotary position embeddings from RoFormer (Su et. al, 2021).
24
+ """
25
+
26
+ def __init__(self, dim_model: int, *_, **__):
27
+ super().__init__()
28
+ # Generate and save the inverse frequency buffer (non trainable)
29
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
30
+ self.register_buffer("inv_freq", inv_freq)
31
+
32
+ self._seq_len_cached = None
33
+ self._cos_cached = None
34
+ self._sin_cached = None
35
+
36
+ def update_cos_sin_tables(self, x, seq_dimension=1):
37
+ seq_len = x.shape[seq_dimension]
38
+
39
+ # Reset the tables if the sequence length has changed,
40
+ # or if we're on a new device (possibly due to tracing for instance)
41
+ if (
42
+ seq_len != self._seq_len_cached
43
+ or self._cos_cached.device != x.device
44
+ or self._cos_cached.dtype != x.dtype
45
+ ):
46
+ self._seq_len_cached = seq_len
47
+ t = torch.arange(
48
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
49
+ )
50
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
51
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
52
+
53
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
54
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
55
+
56
+ return self._cos_cached, self._sin_cached
57
+
58
+ def forward(
59
+ self, q: torch.Tensor, k: torch.Tensor
60
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ self._cos_cached, self._sin_cached = self.update_cos_sin_tables(
62
+ k, seq_dimension=-2
63
+ )
64
+
65
+ return (
66
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
67
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
68
+ )
69
+
70
+
71
+ def __test_rope__():
72
+ dtype=torch.float16
73
+ batch=4
74
+ seqlen=2048
75
+ dim=4096
76
+ num_heads=32
77
+ dim_key_head=dim // num_heads
78
+
79
+ x=torch.randn(batch,seqlen,num_heads,dim_key_head).to(dtype=dtype).to('cuda')
80
+
81
+ rpe=RotaryEmbedding(dim_key_head).to(dtype=dtype).to('cuda')
82
+ q,k=rpe(q=x,k=x)
83
+
84
+
85
+ #__test_rope__()
86
+
elm/utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
2
+
3
+ from prettytable import PrettyTable
4
+
5
+ def count_parameters(model):
6
+ """Count the number of parameters in the model."""
7
+ table = PrettyTable(["Modules", "Parameters"])
8
+ total_params = 0
9
+
10
+ for name, parameter in model.named_parameters():
11
+ if not parameter.requires_grad: continue
12
+ params = parameter.numel()
13
+ table.add_row([name, params])
14
+ total_params+=params
15
+
16
+ print(table)
17
+ print(f"Total Trainable Params: {total_params}")
18
+
19
+ return total_params
20
+
21
+
22
+ def batchify(lst, n):
23
+ """Divide a list into chunks of size n."""
24
+ return [lst[i:i + n] for i in range(0, len(lst), n)]
25
+
models/.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ “*.pt” filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ transformers
run.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import json
4
+ from elm.infer_elm import generate_elm_responses
5
+
6
+ parser = argparse.ArgumentParser(description='run prompts with elm model.')
7
+ parser.add_argument('elm_model_path', help='Path to the elm_model_path')
8
+
9
+
10
+ def get_prompt_config_file(elm_model_path):
11
+ return os.path.join(elm_model_path, "example_prompts.json")
12
+
13
+ def run(elm_model_path: str):
14
+ prompt_config_file = get_prompt_config_file(elm_model_path)
15
+
16
+ with open(prompt_config_file, "r") as f:
17
+ prompt_info = json.load(f)
18
+ prompts = [prompt_info["template"].format(input=input) for input in prompt_info["inputs"]]
19
+ print(f"Loaded prompts from: {prompt_config_file}")
20
+ generate_elm_responses(elm_model_path, prompts, verbose=True)
21
+
22
+ if __name__ == "__main__":
23
+ args = parser.parse_args()
24
+ run(args.elm_model_path)