davda54 committed on
Commit
beb8435
·
verified ·
1 Parent(s): 5cc6955

Delete convert_to_safetensors.py

Browse files
Files changed (1) hide show
  1. convert_to_safetensors.py +0 -68
convert_to_safetensors.py DELETED
@@ -1,68 +0,0 @@
1
import argparse
import os
import random
from statistics import mean, stdev
from typing import List

import torch
import torchmetrics
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
10
-
11
-
12
def parse_args(argv=None):
    """Parse command-line arguments for the conversion script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, preserving the original CLI
            behavior; passing a list makes the parser unit-testable.

    Returns:
        argparse.Namespace with a single ``model_name_or_path`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="/scratch/project_465000144/dasamuel/normistral/normistral-11b-masked-post-hf-60000",
        help="Path to the pre-trained model",
    )
    args = parser.parse_args(argv)

    return args
23
-
24
-
25
def load_model(model_path: str):
    """Load a causal LM checkpoint plus generation metadata.

    Args:
        model_path: Local path or hub id of the pre-trained model.

    Returns:
        dict with keys:
            name: last path component of ``model_path``.
            tokenizer: the loaded tokenizer.
            model: the model, moved to CUDA and set to eval mode (bfloat16).
            eos_token_ids: every vocab id whose decoded text contains a
                newline (presumably used as stop tokens downstream).
            max_length: context size inferred from the model config,
                falling back to 4096 when no known attribute is present.
    """
    # SECURITY: the original hard-coded a Hugging Face access token here
    # (a leaked secret). Read it from the environment instead; None is
    # fine for public or purely local checkpoints.
    hf_token = os.environ.get("HF_TOKEN")

    # Load the pre-trained model and tokenizer. (torch_dtype was dropped
    # from the tokenizer call — tokenizers do not take that argument.)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, cache_dir=".", token=hf_token
    )
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_path, cache_dir=".", token=hf_token, torch_dtype=torch.bfloat16
        )
        .cuda()
        .eval()
    )

    # Collect every token id whose decoded surface form contains a newline.
    # NOTE(review): this decodes the entire vocabulary, which is slow but
    # matches the original behavior.
    eos_token_ids = [
        token_id
        for token_id in range(tokenizer.vocab_size)
        if "\n" in tokenizer.decode([token_id])
    ]

    # Infer the maximum context length from whichever config attribute
    # exists, in the same priority order as the original if/elif chain.
    for attr in ("n_positions", "max_position_embeddings", "max_length", "n_ctx"):
        if hasattr(model.config, attr):
            max_length = getattr(model.config, attr)
            break
    else:
        max_length = 4096  # Default value when the config is silent

    return {
        "name": model_path.split("/")[-1],
        "tokenizer": tokenizer,
        "model": model,
        "eos_token_ids": eos_token_ids,
        "max_length": max_length,
    }
54
-
55
-
56
def main():
    """CLI entry point: load the checkpoint and re-save it in shards."""
    args = parse_args()
    bundle = load_model(args.model_name_or_path)
    # Overwrite the checkpoint in place, sharded at ~4.7 GB per file
    # (keeps each shard under the hub's 5 GB upload limit).
    bundle["model"].save_pretrained(args.model_name_or_path, max_shard_size="4.7GB")
65
-
66
-
67
- if __name__ == "__main__":
68
- main()