chinnadhurai sankar committed
Commit f396208
1 Parent(s): d7fe0a9

initial commit

elm/infer_elm.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
+# Copyright (c) 2024, SliceX AI, Inc.
 
 from elm.model import *
 from elm.utils import batchify
@@ -129,4 +129,4 @@ def generate_elm_responses(elm_model_path,
                 print(json.dumps({"prompt": prompt, "response": response}, indent=4))
                 print("\n***\n")
     return result
-
+
elm/infer_elm_for_demo_app.py ADDED
@@ -0,0 +1,143 @@
+# Copyright (c) 2024, SliceX AI, Inc.
+
+from elm.model import *
+from elm.utils import batchify
+from transformers import AutoTokenizer
+import json
+
+
+def load_elm_model_and_tokenizer(local_path,
+                                 model_config_dict,
+                                 device="cuda",
+                                 load_partial=True,
+                                 get_num_layers_from_ckpt=True):
+    """Load ELM model and tokenizer from local checkpoint."""
+    model_args = ModelArgs(**model_config_dict)
+    model = load_elm_model_from_ckpt(local_path, device=device, model_args=model_args, load_partial=load_partial, get_num_layers_from_ckpt=get_num_layers_from_ckpt)
+
+    tokenizer = AutoTokenizer.from_pretrained(local_path)
+    tokenizer.padding_side = "left"
+    tokenizer.truncation_side = "left"
+    return model, tokenizer
+
+
+def generate_elm_response_given_model(prompts, model, tokenizer,
+                                      device="cuda",
+                                      max_ctx_word_len=1024,
+                                      max_ctx_token_len=0,
+                                      max_new_tokens=500,
+                                      temperature=0.8,  # set to 0 for greedy decoding
+                                      top_k=200,
+                                      return_tok_cnt=False,
+                                      return_gen_only=False,
+                                      early_stop_on_eos=False):
+    """Generate responses from ELM model given an input list of prompts ([str])."""
+    if max_ctx_token_len > 0:
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_ctx_token_len).to(device)
+    else:
+        prompts = [" ".join(p.split(" ")[-max_ctx_word_len:]) for p in prompts]
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
+
+    results = []
+
+    input_tok_cnt = torch.numel(inputs.input_ids)
+
+    model.eval()
+
+    out_tok_cnt = 0
+    with torch.no_grad():
+        temperature = temperature
+        top_k = top_k
+
+        outputs = model.generate(inputs.input_ids, max_new_tokens, temperature=temperature, top_k=top_k,
+                                 return_gen_only=return_gen_only)
+
+        if return_tok_cnt:
+            out_tok_cnt += torch.numel(outputs)
+
+        if early_stop_on_eos:
+            mod_outputs = []
+            for i in range(len(outputs)):
+                curr_out = outputs[i]
+
+                eos_loc_id = -1
+                for j in range(len(outputs[i])):
+                    tok_id = outputs[i][j]
+                    if tok_id == tokenizer.eos_token_id:
+                        eos_loc_id = j
+                        break
+                if eos_loc_id >= 0:
+                    curr_out = outputs[i][:eos_loc_id]
+                mod_outputs.append(curr_out)
+            outputs = mod_outputs
+        detokenized_output = tokenizer.batch_decode(outputs, skip_special_tokens=False)
+
+        results = detokenized_output
+
+    if return_tok_cnt:
+        return results, (input_tok_cnt, out_tok_cnt)
+
+    return results
+
+def load_elm_model_given_path(elm_model_path, elm_model_config={}, device=None):
+    if not device:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        print(f"Setting device to {device}")
+    model_config_dict = {
+        "hidden_size": elm_model_config.get("hidden_size", 2048),
+        "max_inp_len": elm_model_config.get("max_inp_len", 2048),
+        "num_attention_heads": elm_model_config.get("num_attention_heads", 32),
+        "num_layers": elm_model_config.get("num_layers", 48),
+        "bits": elm_model_config.get("bits", 256),
+        "vocab_size": elm_model_config.get("vocab_size", 50304),
+        "dropout": elm_model_config.get("dropout", 0.1),
+        "use_rotary_embeddings": elm_model_config.get("use_rotary_embeddings", True)
+    }
+
+    model, tokenizer = load_elm_model_and_tokenizer(local_path=elm_model_path, model_config_dict=model_config_dict, device=device, load_partial=True)
+    return {"model": model, "tokenizer": tokenizer}
+
+def generate_elm_responses(elm_model_path,
+                           prompts,
+                           device=None,
+                           elm_model_config={},
+                           eval_batch_size=1,
+                           verbose=True,
+                           model_info=None):
+
+
+    if not device:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        print(f"Setting device to {device}")
+
+    if not model_info:
+        model_info = load_elm_model_given_path(elm_model_path, elm_model_config=elm_model_config, device=device)
+
+    model, tokenizer = model_info["model"], model_info["tokenizer"]
+
+    #prompts = [prompt if "[INST]" in prompt else f"[INST]{prompt}[/INST]" for prompt in prompts]
+    max_new_tokens = 128
+    if "classification" in elm_model_path or "detection" in elm_model_path:
+        max_new_tokens = 12
+    result = []
+    for prompt_batch in batchify(prompts, eval_batch_size):
+        responses, _ = generate_elm_response_given_model(prompt_batch,
+                                                         model,
+                                                         tokenizer,
+                                                         device=device,
+                                                         max_ctx_word_len=1024,
+                                                         max_ctx_token_len=512,
+                                                         max_new_tokens=max_new_tokens,
+                                                         return_tok_cnt=True,
+                                                         return_gen_only=False,
+                                                         temperature=0.0,
+                                                         early_stop_on_eos=True)
+
+        for prompt, response in zip(prompt_batch, responses):
+            response = response.split("[/INST]")[-1].strip()
+            result.append(response)
+            if verbose:
+                print(json.dumps({"prompt": prompt, "response": response}, indent=4))
+                print("\n***\n")
+    return result
+
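For readers who want to try the new demo-app module, here is a minimal usage sketch (not part of this commit): the checkpoint path and prompt are placeholders, and the call pattern simply mirrors the functions added above.

# Hypothetical driver script; "path/to/elm-checkpoint" stands in for a local ELM checkpoint directory.
from elm.infer_elm_for_demo_app import load_elm_model_given_path, generate_elm_responses

elm_model_path = "path/to/elm-checkpoint"
prompts = ["[INST]Classify the sentiment of: the movie was wonderful.[/INST]"]

# Load once and pass model_info so repeated calls skip reloading the checkpoint.
model_info = load_elm_model_given_path(elm_model_path)
responses = generate_elm_responses(elm_model_path, prompts, model_info=model_info, verbose=True)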
elm/model.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
+# Copyright (c) 2024, SliceX AI, Inc.
 
 import copy
 import inspect
@@ -100,15 +100,12 @@ class ELM(torch.nn.Module):
         else:
             x = self.slice_transformer.drop(tok_emb)
 
-        tlayer_id = 0
         ignore_index_id = -100
         loss = torch.zeros(1).to(device)
         loss_denom = 0
 
         for tlayer in self.slice_transformer.h:
             x = tlayer(x, attention_mask=attention_mask)
-
-            tlayer_id += 1
 
         x = self.slice_transformer.ln_f(x)
 
@@ -133,9 +130,8 @@ class ELM(torch.nn.Module):
     def get_num_params(self, non_embedding=True):
         """
         Return the number of parameters in the model.
-        For non-embedding count (default), the position embeddings get subtracted.
-        This assumes parameter tying between input and final layer embeddings. Oherwise
-        If there is no parameter sharing , set the flag to False to include parameters for both layers.
+        For non-embedding count (default), subtract position embeddings if parameter tying applies.
+        If there is no parameter sharing, set the flag to False to include parameters for both input/output layers.
         """
         n_params = sum(p.numel() for p in self.parameters())
         if non_embedding and not self.model_args.use_rotary_embeddings:
@@ -342,6 +338,8 @@ def init_elm_model(model_args=ModelArgs(), device="cuda", model_config_dict=None
         model_args = ModelArgs(**model_config_dict)
 
     dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+    if not torch.cuda.is_available():
+        dtype = torch.bfloat16
 
     model = ELM(model_args=model_args).to(dtype=dtype)
 
@@ -415,4 +413,4 @@ def sample_top_p(probs, threshold):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
 
-    return next_token
+    return next_token
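The final hunk touches only the tail of sample_top_p(probs, threshold). For orientation, a self-contained sketch of a nucleus (top-p) sampler with that signature; everything before the closing multinomial/gather lines is an assumption, since the diff does not show the rest of the function.

import torch

def sample_top_p_sketch(probs: torch.Tensor, threshold: float) -> torch.Tensor:
    """Illustrative nucleus (top-p) sampling; only the last two statements mirror the hunk."""
    # Sort probabilities in descending order and keep the original vocab indices.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds the threshold.
    mask = cumulative - probs_sort > threshold
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # These last lines correspond to the hunk: sample in sorted space, map back to vocab ids.
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token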
elm/positional_embeddings.py CHANGED
@@ -9,8 +9,6 @@ def rotate_half(x):
 
 @torch.jit.script
 def apply_rotary_pos_emb(x, cos, sin):
-    # NOTE: This could probably be moved to Triton
-
     # Handle a possible sequence length mismatch in between q and k
     cos = cos[:, :, : x.shape[-2], :]
     sin = sin[:, :, : x.shape[-2], :]
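The removed lines here are only a stale comment. As a reference for how the sliced cos/sin tables are used, a minimal rotary-embedding sketch; the rotate_half body and the final combine step are the standard formulation but are assumptions, since this diff shows only the slicing.

import torch

def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in two halves and swap them with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Trim cached cos/sin tables to the current sequence length, as in the hunk above.
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]
    # Standard rotary update: x * cos + rotate_half(x) * sin.
    return (x * cos) + (rotate_half_sketch(x) * sin)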
elm/utils.py CHANGED
@@ -1,21 +1,16 @@
-# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
+# Copyright (c) 2024, SliceX AI, Inc.
 
-from prettytable import PrettyTable
 
 def count_parameters(model):
     """Count the number of parameters in the model."""
-    table = PrettyTable(["Modules", "Parameters"])
     total_params = 0
 
     for name, parameter in model.named_parameters():
         if not parameter.requires_grad: continue
         params = parameter.numel()
-        table.add_row([name, params])
         total_params+=params
 
-    print(table)
     print(f"Total Trainable Params: {total_params}")
-
    return total_params
 
 
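With the PrettyTable dependency dropped, count_parameters now just prints the total and returns it. A minimal usage sketch (the toy torch.nn.Linear module is purely illustrative):

import torch
from elm.utils import count_parameters

# Any torch.nn.Module works; a tiny Linear layer keeps the example cheap.
model = torch.nn.Linear(in_features=8, out_features=2)
total = count_parameters(model)   # prints: Total Trainable Params: 18
assert total == 8 * 2 + 2         # 16 weights + 2 biases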