Add transformers library and text-generation pipeline to model card

#1
by nielsr (HF staff) - opened
Files changed (1)
  1. README.md +78 -75
README.md CHANGED
@@ -1,75 +1,78 @@
- ---
- language:
- - en
- tags:
- - pytorch
- - causal-lm
- - muddformer
- license: mit
- ---
- MUDDFormer-2.8B is a language model pretrained on the Pile with 300B tokens. It uses a simple yet effective method to address the limitations of residual connections and enhance cross-layer information flow in Transformers. Please see downstream evaluations and more details in the paper [MUDDFormer: Breaking Residual Bottlenecks in Transformers via Multiway Dynamic Dense Connections](https://arxiv.org/abs/2502.12170). In addition, we open-source the JAX training code on [GitHub](https://github.com/Caiyun-AI/MUDDFormer/).
-
- We recommend the <strong>compiled version</strong> of MUDDFormer with *torch.compile* for inference acceleration. Please refer to the Generation section below for the compile implementation.
-
- # Usage
-
- ## Env
-
- ```
- pip install transformers==4.40.2 torch==2.5.1 einops==0.8.0
- ```
-
- ## Generation
-
- ```
- import time
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
-
- import os
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-
- device = torch.device('cuda:0')
- dtype = torch.bfloat16
- MAX_BATCH_SIZE = 1
- MAX_SEQ_LENGTH = 2048
- NUM_TOKENS_TO_GENERATE = 10
- COMPILE = True
- OPTIMIZED_COMPILE = False
-
- if OPTIMIZED_COMPILE:
-     import torch._dynamo.config
-     import torch._inductor.config
-     torch._dynamo.config.cache_size_limit = 64
-     torch._inductor.config.coordinate_descent_tuning = True
-     torch._inductor.config.triton.unique_kernel_names = True
-     torch._inductor.config.fx_graph_cache = True
-
- tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDFormer-2.8B")
- model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/MUDDFormer-2.8B", trust_remote_code=True)
-
- _ = model.to(device=device, dtype=dtype)
- with torch.device(device):
-     model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, dtype=dtype)
-
- def decode_one_token(model, cur_token, input_pos):
-     logits = model(cur_token, input_pos=input_pos, return_tensor=True)
-     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
-     return new_token
-
- prompt = "Beijing is the capital of China. London is the capital of"
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
-
- compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None
-
- print('Start generating tokens, but it will take a few minutes to compile at the first time.')
- for i in range(10):
-     t0 = time.time()
-     with torch.no_grad():
-         generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
-     text = tokenizer.decode(generated_ids[0])
-     if i == 0:
-         print(f'Generated text: {text}')
-     t1 = time.time()
-     print(f'Time consumed at iteration {i}: {t1-t0}s')
- ```
 
 
 
 
+ ---
+ language:
+ - en
+ license: mit
+ library_name: transformers
+ pipeline_tag: text-generation
+ tags:
+ - pytorch
+ - causal-lm
+ - muddformer
+ ---
+
+ MUDDFormer-2.8B is a language model pretrained on the Pile with 300B tokens. It uses a simple yet effective method to address the limitations of residual connections and enhance cross-layer information flow in Transformers. Please see downstream evaluations and more details in the paper [MUDDFormer: Breaking Residual Bottlenecks in Transformers via Multiway Dynamic Dense Connections](https://arxiv.org/abs/2502.12170). In addition, we open-source the JAX training code on [GitHub](https://github.com/Caiyun-AI/MUDDFormer/).
+
+ We recommend the <strong>compiled version</strong> of MUDDFormer with *torch.compile* for inference acceleration. Please refer to the Generation section below for the compile implementation.
+
+ # Usage
+
+ ## Env
+
+ ```
+ pip install transformers==4.40.2 torch==2.5.1 einops==0.8.0
+ ```
+
+ ## Generation
+
+ ```python
+ import time
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ import os
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ device = torch.device('cuda:0')
+ dtype = torch.bfloat16
+ MAX_BATCH_SIZE = 1
+ MAX_SEQ_LENGTH = 2048
+ NUM_TOKENS_TO_GENERATE = 10
+ COMPILE = True
+ OPTIMIZED_COMPILE = False
+
+ # Optional dynamo/inductor settings that can further speed up compiled decoding
+ if OPTIMIZED_COMPILE:
+     import torch._dynamo.config
+     import torch._inductor.config
+     torch._dynamo.config.cache_size_limit = 64
+     torch._inductor.config.coordinate_descent_tuning = True
+     torch._inductor.config.triton.unique_kernel_names = True
+     torch._inductor.config.fx_graph_cache = True
+
+ tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDFormer-2.8B")
+ model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/MUDDFormer-2.8B", trust_remote_code=True)
+
+ _ = model.to(device=device, dtype=dtype)
+ # Pre-allocate the KV caches used by the decoding loop
+ with torch.device(device):
+     model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, dtype=dtype)
+
+ # Greedy single-step decoding; this is the function that gets compiled
+ def decode_one_token(model, cur_token, input_pos):
+     logits = model(cur_token, input_pos=input_pos, return_tensor=True)
+     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+     return new_token
+
+ prompt = "Beijing is the capital of China. London is the capital of"
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+ compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None
+
+ print('Start generating tokens, but it will take a few minutes to compile the first time.')
+ for i in range(10):
+     t0 = time.time()
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
+     text = tokenizer.decode(generated_ids[0])
+     if i == 0:
+         print(f'Generated text: {text}')
+     t1 = time.time()
+     print(f'Time consumed at iteration {i}: {t1-t0}s')
+ ```
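
Note on the metadata added in this PR: with `library_name: transformers` and `pipeline_tag: text-generation` in the frontmatter, the checkpoint can in principle also be loaded through the high-level `pipeline` API. The sketch below is an assumption rather than a tested path; it presumes the repository's remote code accepts the standard `generate()` arguments, and the explicit decoding loop in the Generation section above remains the reference usage.

```python
# Hedged sketch: pipeline-based loading enabled by the new metadata.
# Assumes the custom modeling code is compatible with the standard
# generate() interface; otherwise use the explicit loop shown above.
import torch
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="Caiyun-AI/MUDDFormer-2.8B",
    trust_remote_code=True,      # the architecture ships as custom code
    torch_dtype=torch.bfloat16,
    device=0,
)

print(generator("Beijing is the capital of China. London is the capital of",
                max_new_tokens=10))
```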