Just asking
#1
by DORA1222 - opened
- README.md +15 -16
- modeling_chatglm.py +8 -21
- tokenization_chatglm.py +1 -1
README.md CHANGED

@@ -16,13 +16,7 @@ tags:
 👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
 </p>

-## 更新/Update
-
-- 我们优化了KV Cache的存储方式,减少了显存碎片的产生。基于优化后的代码,模型可以在约**20G显存**的情况下处理32K长度的上下文(FP/BF16格式)。
-- We have optimized the storage method of the KV Cache, reducing the generation of memory fragmentation. Based on the optimized code, the model can process a context length of 32K under approximately **20G** of memory (FP/BF16 format).
-
 ## 介绍
-
 ChatGLM**2**-6B-32K在[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b)的基础上进一步强化了对于长文本的理解能力,能够更好的处理最多32K长度的上下文。具体地,我们基于[位置插值](https://arxiv.org/abs/2306.15595)(Positional Interpolation)的方法对位置编码进行了更新,并在对话阶段使用 32K 的上下文长度训练。在实际的使用中,如果您面临的上下文长度基本在 **8K 以内**,我们推荐使用[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b);如果您需要处理**超过 8K** 的上下文长度,我们推荐使用ChatGLM2-6B-32K。

 ChatGLM**2**-6B-32K是开源中英双语对话模型 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) 的加长版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B-32k 引入了如下新特性:
@@ -86,17 +80,22 @@ For more instructions, including how to run CLI and web demos, and model quantiz

 ## 引用

-
-
-If you find our work helpful, please consider citing the following paper.
+如果你觉得我们的工作有帮助的话,请考虑引用下列论文,ChatGLM2-6B 的论文会在近期公布,敬请期待~

 ```
-@
-
-
-
-
-
-
+@article{zeng2022glm,
+  title={Glm-130b: An open bilingual pre-trained model},
+  author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
+  journal={arXiv preprint arXiv:2210.02414},
+  year={2022}
+}
+```
+```
+@inproceedings{du2022glm,
+  title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
+  author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
+  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
+  pages={320--335},
+  year={2022}
 }
 ```
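The 介绍 paragraph kept by this diff credits the 32K window to Positional Interpolation: position ids beyond the trained range are compressed back into it before the rotary embedding is applied. A minimal sketch of that idea follows; the 8K trained length is an illustrative assumption rather than a value read from the model config, and the repo applies the scaling inside its rotary embedding, not through a helper like this.

```python
import torch

def interpolated_positions(seq_len: int, trained_len: int = 8192) -> torch.Tensor:
    """Scale position ids so a long sequence stays inside the trained range
    (the idea behind Positional Interpolation, arXiv:2306.15595)."""
    positions = torch.arange(seq_len, dtype=torch.float32)
    if seq_len <= trained_len:
        return positions
    # Compress by the extension factor, e.g. 32K positions into the 8K range.
    return positions * (trained_len / seq_len)

print(interpolated_positions(32768)[-1].item())  # 8191.75, still within the trained range
```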
modeling_chatglm.py CHANGED

@@ -413,10 +413,7 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=0)
             value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            if kv_cache is None:
-                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
-            else:
-                kv_cache = (key_layer, value_layer)
+            kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None

@@ -615,8 +612,12 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.training:
-            use_cache = False
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False

         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
@@ -644,15 +645,7 @@
             )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                # token by token decoding, use tuple format
-                if kv_caches[0] is not None:
-                    presents = presents + (kv_cache,)
-                # prefilling in decoding, use tensor format to save cuda memory
-                else:
-                    if len(presents) == 0:
-                        presents = kv_cache
-                    else:
-                        presents = torch.cat((presents, kv_cache), dim=0)
+                presents = presents + (kv_cache,)

             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -837,12 +830,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
-        if presents is not None and type(presents) is torch.Tensor:
-            presents = presents.split(1, dim=0)
-            presents = list(presents)
-            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
-            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
-            presents = tuple(presents)

         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
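The modeling_chatglm.py hunks revert an optimization in which, during prefill, the per-layer (key, value) pairs were packed into a single stacked tensor (to reduce memory fragmentation) and only converted back to per-layer tuples at the end of ChatGLMModel.forward. A self-contained sketch of that round trip, mirroring the removed split/squeeze code; the shapes are illustrative and not taken from the model config:

```python
import torch

# Illustrative shapes only: each layer caches [seq_len, batch, num_heads, head_dim].
num_layers, seq_len, batch, heads, head_dim = 2, 5, 1, 2, 4

# Tuple format: one (key, value) pair per layer.
tuple_cache = tuple(
    (torch.randn(seq_len, batch, heads, head_dim),
     torch.randn(seq_len, batch, heads, head_dim))
    for _ in range(num_layers)
)

# Stacked format used by the removed prefill path:
# one tensor of shape [num_layers, 2, seq_len, batch, num_heads, head_dim].
stacked = torch.stack([torch.stack(kv, dim=0) for kv in tuple_cache], dim=0)

# Back to per-layer tuples, mirroring the removed code in ChatGLMModel.forward.
presents = stacked.split(1, dim=0)                                  # one slice per layer
presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]   # split key / value
presents = tuple(tuple(x.squeeze(0) for x in y) for y in presents)

assert torch.equal(presents[0][0], tuple_cache[0][0])  # the round trip is lossless
```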
tokenization_chatglm.py CHANGED

@@ -66,6 +66,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
     model_input_names = ["input_ids", "attention_mask", "position_ids"]

     def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
+        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
         self.name = "GLMTokenizer"

         self.vocab_file = vocab_file
@@ -75,7 +76,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             "<eos>": self.tokenizer.eos_id,
             "<pad>": self.tokenizer.pad_id
         }
-        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

     def get_command(self, token):
         if token in self.special_tokens:
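The tokenization_chatglm.py hunk only moves the super().__init__ call to the top of ChatGLMTokenizer.__init__. Whether that ordering is safe depends on whether the installed transformers release's base constructor touches the subclass's vocabulary while it runs; a minimal, self-contained sketch of that interaction (Base and Child are hypothetical stand-ins, not transformers classes):

```python
class Base:
    def __init__(self):
        # Stand-in for a base constructor that inspects the subclass while
        # initialising, as some tokenizer base classes do when they register
        # special tokens.
        self.size_at_init = self.vocab_size()

class Child(Base):
    def __init__(self, vocab):
        self._vocab = vocab   # backend state prepared first ...
        super().__init__()    # ... so the base class can already query it

    def vocab_size(self):
        return len(self._vocab)

print(Child(["a", "b", "c"]).size_at_init)  # 3; with super().__init__() first, vocab_size() would raise AttributeError
```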