StarCycle commited on
Commit
d2d310a
·
1 Parent(s): fde5383
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. epoch1.5_ckpt/config.json +47 -0
  2. epoch1.5_ckpt/configuration_internlm2.py +151 -0
  3. epoch1.5_ckpt/generation_config.json +7 -0
  4. epoch1.5_ckpt/model.safetensors +3 -0
  5. epoch1.5_ckpt/modeling_internlm2.py +1391 -0
  6. epoch1.5_ckpt/projector/config.json +17 -0
  7. epoch1.5_ckpt/projector/configuration_projector.py +23 -0
  8. epoch1.5_ckpt/projector/model.safetensors +3 -0
  9. epoch1.5_ckpt/projector/modeling_projector.py +51 -0
  10. epoch1.5_ckpt/special_tokens_map.json +6 -0
  11. epoch1.5_ckpt/tokenization_internlm2.py +236 -0
  12. epoch1.5_ckpt/tokenization_internlm2_fast.py +214 -0
  13. epoch1.5_ckpt/tokenizer.json +0 -0
  14. epoch1.5_ckpt/tokenizer.model +3 -0
  15. epoch1.5_ckpt/tokenizer_config.json +46 -0
  16. epoch1.5_ckpt/xtuner_config.py +207 -0
  17. epoch1_ckpt/config.json +47 -0
  18. epoch1_ckpt/configuration_internlm2.py +151 -0
  19. epoch1_ckpt/generation_config.json +7 -0
  20. epoch1_ckpt/model.safetensors +3 -0
  21. epoch1_ckpt/modeling_internlm2.py +1391 -0
  22. epoch1_ckpt/projector/config.json +17 -0
  23. epoch1_ckpt/projector/configuration_projector.py +23 -0
  24. epoch1_ckpt/projector/model.safetensors +3 -0
  25. epoch1_ckpt/projector/modeling_projector.py +51 -0
  26. epoch1_ckpt/special_tokens_map.json +6 -0
  27. epoch1_ckpt/tokenization_internlm2.py +236 -0
  28. epoch1_ckpt/tokenization_internlm2_fast.py +214 -0
  29. epoch1_ckpt/tokenizer.json +0 -0
  30. epoch1_ckpt/tokenizer.model +3 -0
  31. epoch1_ckpt/tokenizer_config.json +46 -0
  32. epoch1_ckpt/xtuner_config.py +207 -0
  33. epoch2_ckpt/config.json +47 -0
  34. epoch2_ckpt/configuration_internlm2.py +151 -0
  35. epoch2_ckpt/generation_config.json +7 -0
  36. epoch2_ckpt/model.safetensors +3 -0
  37. epoch2_ckpt/modeling_internlm2.py +1391 -0
  38. epoch2_ckpt/projector/config.json +17 -0
  39. epoch2_ckpt/projector/configuration_projector.py +23 -0
  40. epoch2_ckpt/projector/model.safetensors +3 -0
  41. epoch2_ckpt/projector/modeling_projector.py +51 -0
  42. epoch2_ckpt/special_tokens_map.json +6 -0
  43. epoch2_ckpt/tokenization_internlm2.py +236 -0
  44. epoch2_ckpt/tokenization_internlm2_fast.py +214 -0
  45. epoch2_ckpt/tokenizer.json +0 -0
  46. epoch2_ckpt/tokenizer.model +3 -0
  47. epoch2_ckpt/tokenizer_config.json +46 -0
  48. epoch2_ckpt/xtuner_config.py +221 -0
  49. modified_transformers/src/transformers/models/siglip/modeling_siglip.py +1299 -0
  50. modified_xtuner/xtuner/dataset/huggingface.py +242 -0
epoch1.5_ckpt/config.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "internlm/internlm2-1_8b",
3
+ "architectures": [
4
+ "InternLM2ForCausalLM"
5
+ ],
6
+ "attn_implementation": "eager",
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
9
+ "AutoModel": "internlm/internlm2-1_8b--modeling_internlm2.InternLM2ForCausalLM",
10
+ "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM"
11
+ },
12
+ "bias": false,
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2048,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 8192,
19
+ "max_position_embeddings": 32768,
20
+ "model_type": "internlm2",
21
+ "num_attention_heads": 16,
22
+ "num_hidden_layers": 24,
23
+ "num_key_value_heads": 8,
24
+ "pad_token_id": 2,
25
+ "quantization_config": {
26
+ "_load_in_4bit": true,
27
+ "_load_in_8bit": false,
28
+ "bnb_4bit_compute_dtype": "float16",
29
+ "bnb_4bit_quant_type": "nf4",
30
+ "bnb_4bit_use_double_quant": true,
31
+ "llm_int8_enable_fp32_cpu_offload": false,
32
+ "llm_int8_has_fp16_weight": false,
33
+ "llm_int8_skip_modules": null,
34
+ "llm_int8_threshold": 6.0,
35
+ "load_in_4bit": true,
36
+ "load_in_8bit": false,
37
+ "quant_method": "bitsandbytes"
38
+ },
39
+ "rms_norm_eps": 1e-05,
40
+ "rope_scaling": null,
41
+ "rope_theta": 1000000,
42
+ "tie_word_embeddings": false,
43
+ "torch_dtype": "float16",
44
+ "transformers_version": "4.39.0.dev0",
45
+ "use_cache": false,
46
+ "vocab_size": 92544
47
+ }
epoch1.5_ckpt/configuration_internlm2.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
66
+ The epsilon used by the rms normalization layers.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
69
+ relevant if `config.is_decoder=True`.
70
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
71
+ Whether to tie weight embeddings
72
+ Example:
73
+
74
+ """
75
+ model_type = "internlm2"
76
+ _auto_class = "AutoConfig"
77
+
78
+ def __init__( # pylint: disable=W0102
79
+ self,
80
+ vocab_size=103168,
81
+ hidden_size=4096,
82
+ intermediate_size=11008,
83
+ num_hidden_layers=32,
84
+ num_attention_heads=32,
85
+ num_key_value_heads=None,
86
+ hidden_act="silu",
87
+ max_position_embeddings=2048,
88
+ initializer_range=0.02,
89
+ rms_norm_eps=1e-6,
90
+ use_cache=True,
91
+ pad_token_id=0,
92
+ bos_token_id=1,
93
+ eos_token_id=2,
94
+ tie_word_embeddings=False,
95
+ bias=True,
96
+ rope_theta=10000,
97
+ rope_scaling=None,
98
+ attn_implementation="eager",
99
+ **kwargs,
100
+ ):
101
+ self.vocab_size = vocab_size
102
+ self.max_position_embeddings = max_position_embeddings
103
+ self.hidden_size = hidden_size
104
+ self.intermediate_size = intermediate_size
105
+ self.num_hidden_layers = num_hidden_layers
106
+ self.num_attention_heads = num_attention_heads
107
+ self.bias = bias
108
+
109
+ if num_key_value_heads is None:
110
+ num_key_value_heads = num_attention_heads
111
+ self.num_key_value_heads = num_key_value_heads
112
+
113
+ self.hidden_act = hidden_act
114
+ self.initializer_range = initializer_range
115
+ self.rms_norm_eps = rms_norm_eps
116
+ self.use_cache = use_cache
117
+ self.rope_theta = rope_theta
118
+ self.rope_scaling = rope_scaling
119
+ self._rope_scaling_validation()
120
+
121
+ self.attn_implementation = attn_implementation
122
+ if self.attn_implementation is None:
123
+ self.attn_implementation = "eager"
124
+ super().__init__(
125
+ pad_token_id=pad_token_id,
126
+ bos_token_id=bos_token_id,
127
+ eos_token_id=eos_token_id,
128
+ tie_word_embeddings=tie_word_embeddings,
129
+ **kwargs,
130
+ )
131
+
132
+ def _rope_scaling_validation(self):
133
+ """
134
+ Validate the `rope_scaling` configuration.
135
+ """
136
+ if self.rope_scaling is None:
137
+ return
138
+
139
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
140
+ raise ValueError(
141
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
142
+ f"got {self.rope_scaling}"
143
+ )
144
+ rope_scaling_type = self.rope_scaling.get("type", None)
145
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
146
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
147
+ raise ValueError(
148
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
149
+ )
150
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
151
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
epoch1.5_ckpt/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "transformers_version": "4.39.0.dev0"
7
+ }
epoch1.5_ckpt/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eec6399ec7679fa5bef3e198007653bd38da3c9d9f893ba36eaaef5e6740498
3
+ size 1537498688
epoch1.5_ckpt/modeling_internlm2.py ADDED
@@ -0,0 +1,1391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from einops import rearrange
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ SequenceClassifierOutputWithPast,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+
43
+ try:
44
+ from transformers.generation.streamers import BaseStreamer
45
+ except: # noqa # pylint: disable=bare-except
46
+ BaseStreamer = None
47
+
48
+ from .configuration_internlm2 import InternLM2Config
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "InternLM2Config"
53
+
54
+ flash_attn_func, flash_attn_varlen_func = None, None
55
+ pad_input, index_first_axis, unpad_input = None, None, None
56
+ def _import_flash_attn():
57
+ global flash_attn_func, flash_attn_varlen_func
58
+ global pad_input, index_first_axis, unpad_input
59
+ try:
60
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
61
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
62
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
63
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
64
+ except ImportError:
65
+ raise ImportError("flash_attn is not installed.")
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
+
80
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
81
+ def _make_causal_mask(
82
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
83
+ ):
84
+ """
85
+ Make causal mask used for bi-directional self-attention.
86
+ """
87
+ bsz, tgt_len = input_ids_shape
88
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
89
+ mask_cond = torch.arange(mask.size(-1), device=device)
90
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
91
+ mask = mask.to(dtype)
92
+
93
+ if past_key_values_length > 0:
94
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
95
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
96
+
97
+
98
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
99
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
100
+ """
101
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
102
+ """
103
+ bsz, src_len = mask.size()
104
+ tgt_len = tgt_len if tgt_len is not None else src_len
105
+
106
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
107
+
108
+ inverted_mask = 1.0 - expanded_mask
109
+
110
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
111
+
112
+
113
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
114
+ class InternLM2RMSNorm(nn.Module):
115
+ def __init__(self, hidden_size, eps=1e-6):
116
+ """
117
+ InternLM2RMSNorm is equivalent to T5LayerNorm
118
+ """
119
+ super().__init__()
120
+ self.weight = nn.Parameter(torch.ones(hidden_size))
121
+ self.variance_epsilon = eps
122
+
123
+ def forward(self, hidden_states):
124
+ input_dtype = hidden_states.dtype
125
+ hidden_states = hidden_states.to(torch.float32)
126
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
127
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
128
+ return self.weight * hidden_states.to(input_dtype)
129
+
130
+
131
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
132
+ class InternLM2RotaryEmbedding(nn.Module):
133
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
134
+ super().__init__()
135
+
136
+ self.dim = dim
137
+ self.max_position_embeddings = max_position_embeddings
138
+ self.base = base
139
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
140
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
141
+
142
+ # Build here to make `torch.jit.trace` work.
143
+ self._set_cos_sin_cache(
144
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
145
+ )
146
+
147
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
148
+ self.max_seq_len_cached = seq_len
149
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
150
+
151
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
152
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
153
+ emb = torch.cat((freqs, freqs), dim=-1)
154
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
155
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
156
+
157
+ def forward(self, x, seq_len=None):
158
+ # x: [bs, num_attention_heads, seq_len, head_size]
159
+ if seq_len > self.max_seq_len_cached:
160
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
161
+
162
+ return (
163
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
164
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
165
+ )
166
+
167
+
168
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
169
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
170
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
171
+
172
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
173
+ self.scaling_factor = scaling_factor
174
+ super().__init__(dim, max_position_embeddings, base, device)
175
+
176
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
177
+ self.max_seq_len_cached = seq_len
178
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
179
+ t = t / self.scaling_factor
180
+
181
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
182
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
183
+ emb = torch.cat((freqs, freqs), dim=-1)
184
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
185
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
186
+
187
+
188
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
189
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
190
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
191
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
192
+ """
193
+
194
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
195
+ self.scaling_factor = scaling_factor
196
+ super().__init__(dim, max_position_embeddings, base, device)
197
+
198
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
199
+ self.max_seq_len_cached = seq_len
200
+
201
+ if seq_len > self.max_position_embeddings:
202
+ base = self.base * (
203
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
204
+ ) ** (self.dim / (self.dim - 2))
205
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
206
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
207
+
208
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
209
+
210
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
211
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
212
+ emb = torch.cat((freqs, freqs), dim=-1)
213
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
214
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
215
+
216
+
217
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
218
+ def rotate_half(x):
219
+ """Rotates half the hidden dims of the input."""
220
+ x1 = x[..., : x.shape[-1] // 2]
221
+ x2 = x[..., x.shape[-1] // 2 :]
222
+ return torch.cat((-x2, x1), dim=-1)
223
+
224
+
225
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
226
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
227
+ """Applies Rotary Position Embedding to the query and key tensors."""
228
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
229
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
230
+ q_embed = (q * cos) + (rotate_half(q) * sin)
231
+ k_embed = (k * cos) + (rotate_half(k) * sin)
232
+ return q_embed, k_embed
233
+
234
+
235
+ class InternLM2MLP(nn.Module):
236
+ def __init__(self, config):
237
+ super().__init__()
238
+ self.config = config
239
+ self.hidden_size = config.hidden_size
240
+ self.intermediate_size = config.intermediate_size
241
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
242
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
243
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
244
+ self.act_fn = ACT2FN[config.hidden_act]
245
+
246
+ def forward(self, x):
247
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
248
+
249
+ return down_proj
250
+
251
+
252
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
253
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
254
+ """
255
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
256
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
257
+ """
258
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
259
+ if n_rep == 1:
260
+ return hidden_states
261
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
262
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
263
+
264
+
265
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
266
+ class InternLM2Attention(nn.Module):
267
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
268
+
269
+ def __init__(self, config: InternLM2Config):
270
+ super().__init__()
271
+ self.config = config
272
+ self.hidden_size = config.hidden_size
273
+ self.num_heads = config.num_attention_heads
274
+ self.head_dim = self.hidden_size // self.num_heads
275
+ self.num_key_value_heads = config.num_key_value_heads
276
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
277
+ self.max_position_embeddings = config.max_position_embeddings
278
+ self.is_causal = True
279
+
280
+ if (self.head_dim * self.num_heads) != self.hidden_size:
281
+ raise ValueError(
282
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
283
+ f" and `num_heads`: {self.num_heads})."
284
+ )
285
+
286
+ self.wqkv = nn.Linear(
287
+ self.hidden_size,
288
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
289
+ bias=config.bias,
290
+ )
291
+
292
+ self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
293
+ self._init_rope()
294
+
295
+ def _init_rope(self):
296
+ if self.config.rope_scaling is None:
297
+ self.rotary_emb = InternLM2RotaryEmbedding(
298
+ self.head_dim,
299
+ max_position_embeddings=self.max_position_embeddings,
300
+ base=self.config.rope_theta,
301
+ )
302
+ else:
303
+ scaling_type = self.config.rope_scaling["type"]
304
+ scaling_factor = self.config.rope_scaling["factor"]
305
+ if scaling_type == "dynamic":
306
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
307
+ self.head_dim,
308
+ max_position_embeddings=self.max_position_embeddings,
309
+ base=self.config.rope_theta,
310
+ scaling_factor=scaling_factor,
311
+ )
312
+ elif scaling_type == "linear":
313
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
314
+ self.head_dim,
315
+ max_position_embeddings=self.max_position_embeddings,
316
+ base=self.config.rope_theta,
317
+ scaling_factor=scaling_factor,
318
+ )
319
+ else:
320
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
321
+ return self.rotary_emb
322
+
323
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
324
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
325
+
326
+ def forward(
327
+ self,
328
+ hidden_states: torch.Tensor,
329
+ attention_mask: Optional[torch.Tensor] = None,
330
+ position_ids: Optional[torch.LongTensor] = None,
331
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
332
+ output_attentions: bool = False,
333
+ use_cache: bool = False,
334
+ **kwargs,
335
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
336
+ if "padding_mask" in kwargs:
337
+ warnings.warn(
338
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
339
+ "Please make sure use `attention_mask` instead.`"
340
+ )
341
+
342
+ bsz, q_len, _ = hidden_states.size()
343
+
344
+ qkv_states = self.wqkv(hidden_states)
345
+
346
+ qkv_states = rearrange(
347
+ qkv_states,
348
+ "b q (h gs d) -> b q h gs d",
349
+ gs=2 + self.num_key_value_groups,
350
+ d=self.head_dim,
351
+ )
352
+
353
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
354
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
355
+ key_states = qkv_states[..., -2, :]
356
+ value_states = qkv_states[..., -1, :]
357
+
358
+ query_states = query_states.transpose(1, 2)
359
+ key_states = key_states.transpose(1, 2)
360
+ value_states = value_states.transpose(1, 2)
361
+
362
+ kv_seq_len = key_states.shape[-2]
363
+ if past_key_value is not None:
364
+ kv_seq_len += past_key_value[0].shape[-2]
365
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
366
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
367
+
368
+ if past_key_value is not None:
369
+ # reuse k, v, self_attention
370
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
371
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
372
+
373
+ past_key_value = (key_states, value_states) if use_cache else None
374
+
375
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
376
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
377
+
378
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
379
+
380
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
381
+ raise ValueError(
382
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
383
+ f" {attn_weights.size()}"
384
+ )
385
+
386
+ if attention_mask is not None:
387
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
388
+ raise ValueError(
389
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
390
+ )
391
+ attn_weights = attn_weights + attention_mask
392
+
393
+ # upcast attention to fp32
394
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
395
+ attn_output = torch.matmul(attn_weights, value_states)
396
+
397
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
398
+ raise ValueError(
399
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
400
+ f" {attn_output.size()}"
401
+ )
402
+
403
+ attn_output = attn_output.transpose(1, 2).contiguous()
404
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
405
+
406
+ attn_output = self.wo(attn_output)
407
+
408
+ if not output_attentions:
409
+ attn_weights = None
410
+
411
+ return attn_output, attn_weights, past_key_value
412
+
413
+
414
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
415
+ class InternLM2FlashAttention2(InternLM2Attention):
416
+ """
417
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
418
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
419
+ flash attention and deal with padding tokens in case the input contains any of them.
420
+ """
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.LongTensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ **kwargs,
431
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
432
+ # InternLM2FlashAttention2 attention does not support output_attentions
433
+ if "padding_mask" in kwargs:
434
+ warnings.warn(
435
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
436
+ "Please make sure use `attention_mask` instead.`"
437
+ )
438
+
439
+ # overwrite attention_mask with padding_mask
440
+ attention_mask = kwargs.pop("padding_mask")
441
+
442
+ output_attentions = False
443
+
444
+ bsz, q_len, _ = hidden_states.size()
445
+
446
+ qkv_states = self.wqkv(hidden_states)
447
+
448
+ qkv_states = rearrange(
449
+ qkv_states,
450
+ "b q (h gs d) -> b q h gs d",
451
+ gs=2 + self.num_key_value_groups,
452
+ d=self.head_dim,
453
+ )
454
+
455
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
456
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
457
+ key_states = qkv_states[..., -2, :]
458
+ value_states = qkv_states[..., -1, :]
459
+
460
+ query_states = query_states.transpose(1, 2)
461
+ key_states = key_states.transpose(1, 2)
462
+ value_states = value_states.transpose(1, 2)
463
+
464
+ kv_seq_len = key_states.shape[-2]
465
+ if past_key_value is not None:
466
+ kv_seq_len += past_key_value[0].shape[-2]
467
+
468
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
469
+
470
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
471
+
472
+ if past_key_value is not None:
473
+ # reuse k, v, self_attention
474
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
475
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
476
+
477
+ past_key_value = (key_states, value_states) if use_cache else None
478
+
479
+ query_states = query_states.transpose(1, 2)
480
+ key_states = key_states.transpose(1, 2)
481
+ value_states = value_states.transpose(1, 2)
482
+
483
+ attn_output = self._flash_attention_forward(
484
+ query_states, key_states, value_states, attention_mask, q_len
485
+ )
486
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
487
+ attn_output = self.wo(attn_output)
488
+
489
+ if not output_attentions:
490
+ attn_weights = None
491
+
492
+ return attn_output, attn_weights, past_key_value
493
+
494
+ def _flash_attention_forward(
495
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
496
+ ):
497
+ """
498
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
499
+ first unpad the input, then computes the attention scores and pad the final attention scores.
500
+
501
+ Args:
502
+ query_states (`torch.Tensor`):
503
+ Input query states to be passed to Flash Attention API
504
+ key_states (`torch.Tensor`):
505
+ Input key states to be passed to Flash Attention API
506
+ value_states (`torch.Tensor`):
507
+ Input value states to be passed to Flash Attention API
508
+ attention_mask (`torch.Tensor`):
509
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
510
+ position of padding tokens and 1 for the position of non-padding tokens.
511
+ dropout (`int`, *optional*):
512
+ Attention dropout
513
+ softmax_scale (`float`, *optional*):
514
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
515
+ """
516
+ # Contains at least one padding token in the sequence
517
+ causal = self.is_causal and query_length != 1
518
+ if attention_mask is not None:
519
+ batch_size = query_states.shape[0]
520
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
521
+ query_states, key_states, value_states, attention_mask, query_length
522
+ )
523
+
524
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
525
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
526
+
527
+ attn_output_unpad = flash_attn_varlen_func(
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ cu_seqlens_q=cu_seqlens_q,
532
+ cu_seqlens_k=cu_seqlens_k,
533
+ max_seqlen_q=max_seqlen_in_batch_q,
534
+ max_seqlen_k=max_seqlen_in_batch_k,
535
+ dropout_p=dropout,
536
+ softmax_scale=softmax_scale,
537
+ causal=causal,
538
+ )
539
+
540
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
541
+ else:
542
+ attn_output = flash_attn_func(
543
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
544
+ )
545
+
546
+ return attn_output
547
+
548
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
549
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
550
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
551
+
552
+ key_layer = index_first_axis(
553
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
554
+ )
555
+ value_layer = index_first_axis(
556
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
557
+ )
558
+
559
+ if query_length == kv_seq_len:
560
+ query_layer = index_first_axis(
561
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
562
+ )
563
+ cu_seqlens_q = cu_seqlens_k
564
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
565
+ indices_q = indices_k
566
+ elif query_length == 1:
567
+ max_seqlen_in_batch_q = 1
568
+ cu_seqlens_q = torch.arange(
569
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
570
+ ) # There is a memcpy here, that is very bad.
571
+ indices_q = cu_seqlens_q[:-1]
572
+ query_layer = query_layer.squeeze(1)
573
+ else:
574
+ # The -q_len: slice assumes left padding.
575
+ attention_mask = attention_mask[:, -query_length:]
576
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
577
+
578
+ return (
579
+ query_layer,
580
+ key_layer,
581
+ value_layer,
582
+ indices_q.to(torch.int64),
583
+ (cu_seqlens_q, cu_seqlens_k),
584
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
585
+ )
586
+
587
+ INTERNLM2_ATTENTION_CLASSES = {
588
+ "eager": InternLM2Attention,
589
+ "flash_attention_2": InternLM2FlashAttention2,
590
+ }
591
+
592
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
593
+ class InternLM2DecoderLayer(nn.Module):
594
+ def __init__(self, config: InternLM2Config):
595
+ super().__init__()
596
+ self.hidden_size = config.hidden_size
597
+
598
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
599
+
600
+ self.feed_forward = InternLM2MLP(config)
601
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
603
+
604
+ def forward(
605
+ self,
606
+ hidden_states: torch.Tensor,
607
+ attention_mask: Optional[torch.Tensor] = None,
608
+ position_ids: Optional[torch.LongTensor] = None,
609
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
610
+ output_attentions: Optional[bool] = False,
611
+ use_cache: Optional[bool] = False,
612
+ **kwargs,
613
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
614
+ """
615
+ Args:
616
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
617
+ attention_mask (`torch.FloatTensor`, *optional*):
618
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
619
+ query_sequence_length, key_sequence_length)` if default attention is used.
620
+ output_attentions (`bool`, *optional*):
621
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
622
+ returned tensors for more detail.
623
+ use_cache (`bool`, *optional*):
624
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
625
+ (see `past_key_values`).
626
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
627
+ """
628
+ if "padding_mask" in kwargs:
629
+ warnings.warn(
630
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
631
+ "Please make sure use `attention_mask` instead.`"
632
+ )
633
+
634
+ residual = hidden_states
635
+
636
+ hidden_states = self.attention_norm(hidden_states)
637
+
638
+ # Self Attention
639
+ hidden_states, self_attn_weights, present_key_value = self.attention(
640
+ hidden_states=hidden_states,
641
+ attention_mask=attention_mask,
642
+ position_ids=position_ids,
643
+ past_key_value=past_key_value,
644
+ output_attentions=output_attentions,
645
+ use_cache=use_cache,
646
+ **kwargs,
647
+ )
648
+ hidden_states = residual + hidden_states
649
+
650
+ # Fully Connected
651
+ residual = hidden_states
652
+ hidden_states = self.ffn_norm(hidden_states)
653
+ hidden_states = self.feed_forward(hidden_states)
654
+ hidden_states = residual + hidden_states
655
+
656
+ outputs = (hidden_states,)
657
+
658
+ if output_attentions:
659
+ outputs += (self_attn_weights,)
660
+
661
+ if use_cache:
662
+ outputs += (present_key_value,)
663
+
664
+ return outputs
665
+
666
+
667
+ InternLM2_START_DOCSTRING = r"""
668
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
669
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
670
+ etc.)
671
+
672
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
673
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
674
+ and behavior.
675
+
676
+ Parameters:
677
+ config ([`InternLM2Config`]):
678
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
679
+ load the weights associated with the model, only the configuration. Check out the
680
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
681
+ """
682
+
683
+
684
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
685
+ @add_start_docstrings(
686
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
687
+ InternLM2_START_DOCSTRING,
688
+ )
689
+ class InternLM2PreTrainedModel(PreTrainedModel):
690
+ config_class = InternLM2Config
691
+ base_model_prefix = "model"
692
+ supports_gradient_checkpointing = True
693
+ _no_split_modules = ["InternLM2DecoderLayer"]
694
+ _skip_keys_device_placement = "past_key_values"
695
+
696
+ def _init_weights(self, module):
697
+ std = self.config.initializer_range
698
+ if isinstance(module, nn.Linear):
699
+ module.weight.data.normal_(mean=0.0, std=std)
700
+ if module.bias is not None:
701
+ module.bias.data.zero_()
702
+ elif isinstance(module, nn.Embedding):
703
+ module.weight.data.normal_(mean=0.0, std=std)
704
+ if module.padding_idx is not None:
705
+ module.weight.data[module.padding_idx].zero_()
706
+
707
+
708
+ InternLM2_INPUTS_DOCSTRING = r"""
709
+ Args:
710
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
711
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
712
+ it.
713
+
714
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
715
+ [`PreTrainedTokenizer.__call__`] for details.
716
+
717
+ [What are input IDs?](../glossary#input-ids)
718
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
719
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
720
+
721
+ - 1 for tokens that are **not masked**,
722
+ - 0 for tokens that are **masked**.
723
+
724
+ [What are attention masks?](../glossary#attention-mask)
725
+
726
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
727
+ [`PreTrainedTokenizer.__call__`] for details.
728
+
729
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
730
+ `past_key_values`).
731
+
732
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
733
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
734
+ information on the default strategy.
735
+
736
+ - 1 indicates the head is **not masked**,
737
+ - 0 indicates the head is **masked**.
738
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
739
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
740
+ config.n_positions - 1]`.
741
+
742
+ [What are position IDs?](../glossary#position-ids)
743
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
744
+ when `config.use_cache=True`):
745
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
746
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
747
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
748
+
749
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
750
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
751
+
752
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
753
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
754
+ of shape `(batch_size, sequence_length)`.
755
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
756
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
757
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
758
+ model's internal embedding lookup matrix.
759
+ use_cache (`bool`, *optional*):
760
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
761
+ `past_key_values`).
762
+ output_attentions (`bool`, *optional*):
763
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
764
+ tensors for more detail.
765
+ output_hidden_states (`bool`, *optional*):
766
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
767
+ more detail.
768
+ return_dict (`bool`, *optional*):
769
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
770
+ """
771
+
772
+
773
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
774
+ @add_start_docstrings(
775
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
776
+ InternLM2_START_DOCSTRING,
777
+ )
778
+ class InternLM2Model(InternLM2PreTrainedModel):
779
+ """
780
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
781
+
782
+ Args:
783
+ config: InternLM2Config
784
+ """
785
+
786
+ _auto_class = "AutoModel"
787
+
788
+ def __init__(self, config: InternLM2Config):
789
+ super().__init__(config)
790
+ self.padding_idx = config.pad_token_id
791
+ self.vocab_size = config.vocab_size
792
+ self.config = config
793
+
794
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
795
+
796
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
797
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
798
+
799
+ self.gradient_checkpointing = False
800
+ # Initialize weights and apply final processing
801
+ self.post_init()
802
+
803
+ def get_input_embeddings(self):
804
+ return self.tok_embeddings
805
+
806
+ def set_input_embeddings(self, value):
807
+ self.tok_embeddings = value
808
+
809
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
810
+ # create causal mask
811
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
812
+ combined_attention_mask = None
813
+ if input_shape[-1] > 1:
814
+ combined_attention_mask = _make_causal_mask(
815
+ input_shape,
816
+ inputs_embeds.dtype,
817
+ device=inputs_embeds.device,
818
+ past_key_values_length=past_key_values_length,
819
+ )
820
+
821
+ if attention_mask is not None:
822
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
823
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
824
+ inputs_embeds.device
825
+ )
826
+ combined_attention_mask = (
827
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
828
+ )
829
+
830
+ return combined_attention_mask
831
+
832
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
833
+ def forward(
834
+ self,
835
+ input_ids: torch.LongTensor = None,
836
+ attention_mask: Optional[torch.Tensor] = None,
837
+ position_ids: Optional[torch.LongTensor] = None,
838
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
839
+ inputs_embeds: Optional[torch.FloatTensor] = None,
840
+ use_cache: Optional[bool] = None,
841
+ output_attentions: Optional[bool] = None,
842
+ output_hidden_states: Optional[bool] = None,
843
+ return_dict: Optional[bool] = None,
844
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
845
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
846
+ output_hidden_states = (
847
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
848
+ )
849
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
850
+
851
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
852
+
853
+ if self.config.attn_implementation == "flash_attention_2":
854
+ _import_flash_attn()
855
+
856
+ # retrieve input_ids and inputs_embeds
857
+ if input_ids is not None and inputs_embeds is not None:
858
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
859
+ elif input_ids is not None:
860
+ batch_size, seq_length = input_ids.shape[:2]
861
+ elif inputs_embeds is not None:
862
+ batch_size, seq_length = inputs_embeds.shape[:2]
863
+ else:
864
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
865
+
866
+ seq_length_with_past = seq_length
867
+ past_key_values_length = 0
868
+ if past_key_values is not None:
869
+ past_key_values_length = past_key_values[0][0].shape[2]
870
+ seq_length_with_past = seq_length_with_past + past_key_values_length
871
+
872
+ if position_ids is None:
873
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
874
+ position_ids = torch.arange(
875
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
876
+ )
877
+ position_ids = position_ids.unsqueeze(0)
878
+
879
+ if inputs_embeds is None:
880
+ inputs_embeds = self.tok_embeddings(input_ids)
881
+
882
+ if self.config.attn_implementation == "flash_attention_2":
883
+ # 2d mask is passed through the layers
884
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
885
+ else:
886
+ if attention_mask is None:
887
+ attention_mask = torch.ones(
888
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
889
+ )
890
+ attention_mask = self._prepare_decoder_attention_mask(
891
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
892
+ )
893
+
894
+ # embed positions
895
+ hidden_states = inputs_embeds
896
+
897
+ if self.gradient_checkpointing and self.training:
898
+ if use_cache:
899
+ logger.warning_once(
900
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
901
+ )
902
+ use_cache = False
903
+
904
+ # decoder layers
905
+ all_hidden_states = () if output_hidden_states else None
906
+ all_self_attns = () if output_attentions else None
907
+ next_decoder_cache = () if use_cache else None
908
+
909
+ for idx, decoder_layer in enumerate(self.layers):
910
+ if output_hidden_states:
911
+ all_hidden_states += (hidden_states,)
912
+
913
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
914
+
915
+ if self.gradient_checkpointing and self.training:
916
+
917
+ def create_custom_forward(module):
918
+ def custom_forward(*inputs):
919
+ # None for past_key_value
920
+ return module(*inputs, output_attentions, None)
921
+
922
+ return custom_forward
923
+
924
+ layer_outputs = torch.utils.checkpoint.checkpoint(
925
+ create_custom_forward(decoder_layer),
926
+ hidden_states,
927
+ attention_mask,
928
+ position_ids,
929
+ None,
930
+ )
931
+ else:
932
+ layer_outputs = decoder_layer(
933
+ hidden_states,
934
+ attention_mask=attention_mask,
935
+ position_ids=position_ids,
936
+ past_key_value=past_key_value,
937
+ output_attentions=output_attentions,
938
+ use_cache=use_cache,
939
+ )
940
+
941
+ hidden_states = layer_outputs[0]
942
+
943
+ if use_cache:
944
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
945
+
946
+ if output_attentions:
947
+ all_self_attns += (layer_outputs[1],)
948
+
949
+ hidden_states = self.norm(hidden_states)
950
+
951
+ # add hidden states from the last decoder layer
952
+ if output_hidden_states:
953
+ all_hidden_states += (hidden_states,)
954
+
955
+ next_cache = next_decoder_cache if use_cache else None
956
+ if not return_dict:
957
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
958
+ return BaseModelOutputWithPast(
959
+ last_hidden_state=hidden_states,
960
+ past_key_values=next_cache,
961
+ hidden_states=all_hidden_states,
962
+ attentions=all_self_attns,
963
+ )
964
+
965
+
966
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
967
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
968
+ _auto_class = "AutoModelForCausalLM"
969
+
970
+ _tied_weights_keys = ["output.weight"]
971
+
972
+ def __init__(self, config):
973
+ super().__init__(config)
974
+ self.model = InternLM2Model(config)
975
+ self.vocab_size = config.vocab_size
976
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
977
+
978
+ # Initialize weights and apply final processing
979
+ self.post_init()
980
+
981
+ def get_input_embeddings(self):
982
+ return self.model.tok_embeddings
983
+
984
+ def set_input_embeddings(self, value):
985
+ self.model.tok_embeddings = value
986
+
987
+ def get_output_embeddings(self):
988
+ return self.output
989
+
990
+ def set_output_embeddings(self, new_embeddings):
991
+ self.output = new_embeddings
992
+
993
+ def set_decoder(self, decoder):
994
+ self.model = decoder
995
+
996
+ def get_decoder(self):
997
+ return self.model
998
+
999
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1000
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1001
+ def forward(
1002
+ self,
1003
+ input_ids: torch.LongTensor = None,
1004
+ attention_mask: Optional[torch.Tensor] = None,
1005
+ position_ids: Optional[torch.LongTensor] = None,
1006
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1007
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1008
+ labels: Optional[torch.LongTensor] = None,
1009
+ use_cache: Optional[bool] = None,
1010
+ output_attentions: Optional[bool] = None,
1011
+ output_hidden_states: Optional[bool] = None,
1012
+ return_dict: Optional[bool] = None,
1013
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1014
+ r"""
1015
+ Args:
1016
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1017
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1018
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1019
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1020
+
1021
+ Returns:
1022
+
1023
+ Example:
1024
+
1025
+ ```python
1026
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1027
+
1028
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1029
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1030
+
1031
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1032
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1033
+
1034
+ >>> # Generate
1035
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1036
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1037
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1038
+ ```"""
1039
+
1040
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1041
+ output_hidden_states = (
1042
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1043
+ )
1044
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1045
+
1046
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1047
+ outputs = self.model(
1048
+ input_ids=input_ids,
1049
+ attention_mask=attention_mask,
1050
+ position_ids=position_ids,
1051
+ past_key_values=past_key_values,
1052
+ inputs_embeds=inputs_embeds,
1053
+ use_cache=use_cache,
1054
+ output_attentions=output_attentions,
1055
+ output_hidden_states=output_hidden_states,
1056
+ return_dict=return_dict,
1057
+ )
1058
+
1059
+ hidden_states = outputs[0]
1060
+ logits = self.output(hidden_states)
1061
+ logits = logits.float()
1062
+
1063
+ loss = None
1064
+ if labels is not None:
1065
+ # Shift so that tokens < n predict n
1066
+ shift_logits = logits[..., :-1, :].contiguous()
1067
+ shift_labels = labels[..., 1:].contiguous()
1068
+ # Flatten the tokens
1069
+ loss_fct = CrossEntropyLoss()
1070
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1071
+ shift_labels = shift_labels.view(-1)
1072
+ # Enable model parallelism
1073
+ shift_labels = shift_labels.to(shift_logits.device)
1074
+ loss = loss_fct(shift_logits, shift_labels)
1075
+
1076
+ if not return_dict:
1077
+ output = (logits,) + outputs[1:]
1078
+ return (loss,) + output if loss is not None else output
1079
+
1080
+ return CausalLMOutputWithPast(
1081
+ loss=loss,
1082
+ logits=logits,
1083
+ past_key_values=outputs.past_key_values,
1084
+ hidden_states=outputs.hidden_states,
1085
+ attentions=outputs.attentions,
1086
+ )
1087
+
1088
+ def prepare_inputs_for_generation(
1089
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1090
+ ):
1091
+ if past_key_values is not None:
1092
+ past_length = past_key_values[0][0].shape[2]
1093
+
1094
+ # Some generation methods already pass only the last input ID
1095
+ if input_ids.shape[1] > past_length:
1096
+ remove_prefix_length = past_length
1097
+ else:
1098
+ # Default to old behavior: keep only final ID
1099
+ remove_prefix_length = input_ids.shape[1] - 1
1100
+
1101
+ input_ids = input_ids[:, remove_prefix_length:]
1102
+
1103
+ position_ids = kwargs.get("position_ids", None)
1104
+ if attention_mask is not None and position_ids is None:
1105
+ # create position_ids on the fly for batch generation
1106
+ position_ids = attention_mask.long().cumsum(-1) - 1
1107
+ position_ids.masked_fill_(attention_mask == 0, 1)
1108
+ if past_key_values:
1109
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1110
+
1111
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1112
+ if inputs_embeds is not None and past_key_values is None:
1113
+ model_inputs = {"inputs_embeds": inputs_embeds}
1114
+ else:
1115
+ model_inputs = {"input_ids": input_ids}
1116
+
1117
+ model_inputs.update(
1118
+ {
1119
+ "position_ids": position_ids,
1120
+ "past_key_values": past_key_values,
1121
+ "use_cache": kwargs.get("use_cache"),
1122
+ "attention_mask": attention_mask,
1123
+ }
1124
+ )
1125
+ return model_inputs
1126
+
1127
+ @staticmethod
1128
+ def _reorder_cache(past_key_values, beam_idx):
1129
+ reordered_past = ()
1130
+ for layer_past in past_key_values:
1131
+ reordered_past += (
1132
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1133
+ )
1134
+ return reordered_past
1135
+
1136
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1137
+ if tokenizer.add_bos_token:
1138
+ prompt = ""
1139
+ else:
1140
+ prompt = tokenizer.bos_token
1141
+ if meta_instruction:
1142
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1143
+ for record in history:
1144
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1145
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1146
+ return tokenizer([prompt], return_tensors="pt")
1147
+
1148
+ @torch.no_grad()
1149
+ def chat(
1150
+ self,
1151
+ tokenizer,
1152
+ query: str,
1153
+ history: List[Tuple[str, str]] = [],
1154
+ streamer: Optional[BaseStreamer] = None,
1155
+ max_new_tokens: int = 1024,
1156
+ do_sample: bool = True,
1157
+ temperature: float = 0.8,
1158
+ top_p: float = 0.8,
1159
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1160
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1161
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1162
+ **kwargs,
1163
+ ):
1164
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1165
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1166
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1167
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1168
+ outputs = self.generate(
1169
+ **inputs,
1170
+ streamer=streamer,
1171
+ max_new_tokens=max_new_tokens,
1172
+ do_sample=do_sample,
1173
+ temperature=temperature,
1174
+ top_p=top_p,
1175
+ eos_token_id=eos_token_id,
1176
+ **kwargs,
1177
+ )
1178
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1179
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1180
+ response = response.split("<|im_end|>")[0]
1181
+ history = history + [(query, response)]
1182
+ return response, history
1183
+
1184
+ @torch.no_grad()
1185
+ def stream_chat(
1186
+ self,
1187
+ tokenizer,
1188
+ query: str,
1189
+ history: List[Tuple[str, str]] = [],
1190
+ max_new_tokens: int = 1024,
1191
+ do_sample: bool = True,
1192
+ temperature: float = 0.8,
1193
+ top_p: float = 0.8,
1194
+ **kwargs,
1195
+ ):
1196
+ """
1197
+ Return a generator in format: (response, history)
1198
+ Eg.
1199
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1200
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1201
+ """
1202
+ if BaseStreamer is None:
1203
+ raise ModuleNotFoundError(
1204
+ "The version of `transformers` is too low. Please make sure "
1205
+ "that you have installed `transformers>=4.28.0`."
1206
+ )
1207
+
1208
+ response_queue = queue.Queue(maxsize=20)
1209
+
1210
+ class ChatStreamer(BaseStreamer):
1211
+ def __init__(self, tokenizer) -> None:
1212
+ super().__init__()
1213
+ self.tokenizer = tokenizer
1214
+ self.queue = response_queue
1215
+ self.query = query
1216
+ self.history = history
1217
+ self.response = ""
1218
+ self.cache = []
1219
+ self.received_inputs = False
1220
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1221
+
1222
+ def put(self, value):
1223
+ if len(value.shape) > 1 and value.shape[0] > 1:
1224
+ raise ValueError("ChatStreamer only supports batch size 1")
1225
+ elif len(value.shape) > 1:
1226
+ value = value[0]
1227
+
1228
+ if not self.received_inputs:
1229
+ # The first received value is input_ids, ignore here
1230
+ self.received_inputs = True
1231
+ return
1232
+
1233
+ self.cache.extend(value.tolist())
1234
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1235
+ if token.strip() != "<|im_end|>":
1236
+ self.response = self.response + token
1237
+ history = self.history + [(self.query, self.response)]
1238
+ self.queue.put((self.response, history))
1239
+ self.cache = []
1240
+ else:
1241
+ self.end()
1242
+
1243
+ def end(self):
1244
+ self.queue.put(None)
1245
+
1246
+ def stream_producer():
1247
+ return self.chat(
1248
+ tokenizer=tokenizer,
1249
+ query=query,
1250
+ streamer=ChatStreamer(tokenizer=tokenizer),
1251
+ history=history,
1252
+ max_new_tokens=max_new_tokens,
1253
+ do_sample=do_sample,
1254
+ temperature=temperature,
1255
+ top_p=top_p,
1256
+ **kwargs,
1257
+ )
1258
+
1259
+ def consumer():
1260
+ producer = threading.Thread(target=stream_producer)
1261
+ producer.start()
1262
+ while True:
1263
+ res = response_queue.get()
1264
+ if res is None:
1265
+ return
1266
+ yield res
1267
+
1268
+ return consumer()
1269
+
1270
+
1271
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1272
+ @add_start_docstrings(
1273
+ """
1274
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1275
+
1276
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1277
+ as other causal models (e.g. GPT-2) do.
1278
+
1279
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1280
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1281
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1282
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1283
+ each row of the batch).
1284
+ """,
1285
+ InternLM2_START_DOCSTRING,
1286
+ )
1287
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1288
+ def __init__(self, config):
1289
+ super().__init__(config)
1290
+ self.num_labels = config.num_labels
1291
+ self.model = InternLM2Model(config)
1292
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1293
+
1294
+ # Initialize weights and apply final processing
1295
+ self.post_init()
1296
+
1297
+ def get_input_embeddings(self):
1298
+ return self.model.tok_embeddings
1299
+
1300
+ def set_input_embeddings(self, value):
1301
+ self.model.tok_embeddings = value
1302
+
1303
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1304
+ def forward(
1305
+ self,
1306
+ input_ids: torch.LongTensor = None,
1307
+ attention_mask: Optional[torch.Tensor] = None,
1308
+ position_ids: Optional[torch.LongTensor] = None,
1309
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1310
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1311
+ labels: Optional[torch.LongTensor] = None,
1312
+ use_cache: Optional[bool] = None,
1313
+ output_attentions: Optional[bool] = None,
1314
+ output_hidden_states: Optional[bool] = None,
1315
+ return_dict: Optional[bool] = None,
1316
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1317
+ r"""
1318
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1319
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1320
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1321
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1322
+ """
1323
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1324
+
1325
+ transformer_outputs = self.model(
1326
+ input_ids,
1327
+ attention_mask=attention_mask,
1328
+ position_ids=position_ids,
1329
+ past_key_values=past_key_values,
1330
+ inputs_embeds=inputs_embeds,
1331
+ use_cache=use_cache,
1332
+ output_attentions=output_attentions,
1333
+ output_hidden_states=output_hidden_states,
1334
+ return_dict=return_dict,
1335
+ )
1336
+ hidden_states = transformer_outputs[0]
1337
+ logits = self.score(hidden_states)
1338
+
1339
+ if input_ids is not None:
1340
+ batch_size = input_ids.shape[0]
1341
+ else:
1342
+ batch_size = inputs_embeds.shape[0]
1343
+
1344
+ if self.config.pad_token_id is None and batch_size != 1:
1345
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1346
+ if self.config.pad_token_id is None:
1347
+ sequence_lengths = -1
1348
+ else:
1349
+ if input_ids is not None:
1350
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1351
+ logits.device
1352
+ )
1353
+ else:
1354
+ sequence_lengths = -1
1355
+
1356
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1357
+
1358
+ loss = None
1359
+ if labels is not None:
1360
+ labels = labels.to(logits.device)
1361
+ if self.config.problem_type is None:
1362
+ if self.num_labels == 1:
1363
+ self.config.problem_type = "regression"
1364
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1365
+ self.config.problem_type = "single_label_classification"
1366
+ else:
1367
+ self.config.problem_type = "multi_label_classification"
1368
+
1369
+ if self.config.problem_type == "regression":
1370
+ loss_fct = MSELoss()
1371
+ if self.num_labels == 1:
1372
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1373
+ else:
1374
+ loss = loss_fct(pooled_logits, labels)
1375
+ elif self.config.problem_type == "single_label_classification":
1376
+ loss_fct = CrossEntropyLoss()
1377
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1378
+ elif self.config.problem_type == "multi_label_classification":
1379
+ loss_fct = BCEWithLogitsLoss()
1380
+ loss = loss_fct(pooled_logits, labels)
1381
+ if not return_dict:
1382
+ output = (pooled_logits,) + transformer_outputs[1:]
1383
+ return ((loss,) + output) if loss is not None else output
1384
+
1385
+ return SequenceClassifierOutputWithPast(
1386
+ loss=loss,
1387
+ logits=pooled_logits,
1388
+ past_key_values=transformer_outputs.past_key_values,
1389
+ hidden_states=transformer_outputs.hidden_states,
1390
+ attentions=transformer_outputs.attentions,
1391
+ )
epoch1.5_ckpt/projector/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ProjectorModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_projector.ProjectorConfig",
7
+ "AutoModel": "modeling_projector.ProjectorModel"
8
+ },
9
+ "bias": true,
10
+ "depth": 2,
11
+ "hidden_act": "gelu",
12
+ "llm_hidden_size": 2048,
13
+ "model_type": "projector",
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.39.0.dev0",
16
+ "visual_hidden_size": 2176
17
+ }
epoch1.5_ckpt/projector/configuration_projector.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class ProjectorConfig(PretrainedConfig):
6
+ model_type = 'projector'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ visual_hidden_size=4096,
12
+ llm_hidden_size=4096,
13
+ depth=2,
14
+ hidden_act='gelu',
15
+ bias=True,
16
+ **kwargs,
17
+ ):
18
+ self.visual_hidden_size = visual_hidden_size
19
+ self.llm_hidden_size = llm_hidden_size
20
+ self.depth = depth
21
+ self.hidden_act = hidden_act
22
+ self.bias = bias
23
+ super().__init__(**kwargs)
epoch1.5_ckpt/projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4397cd33de0a81740f5d1cbdc553b9ea0e63129de2c7ee79d545174a751f930f
3
+ size 34619760
epoch1.5_ckpt/projector/modeling_projector.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_projector import ProjectorConfig
8
+
9
+
10
+ class ProjectorModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = ProjectorConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: ProjectorConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ modules = [
21
+ nn.Linear(
22
+ config.visual_hidden_size,
23
+ config.llm_hidden_size,
24
+ bias=config.bias)
25
+ ]
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.model.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, ProjectorModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def forward(self, x):
47
+ if self.gradient_checkpointing and self.training:
48
+ layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
49
+ else:
50
+ layer_outputs = self.model(x)
51
+ return layer_outputs
epoch1.5_ckpt/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "</s>",
5
+ "unk_token": "<unk>"
6
+ }
epoch1.5_ckpt/tokenization_internlm2.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization classes for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import sentencepiece as spm
24
+ from transformers.tokenization_utils import PreTrainedTokenizer
25
+ from transformers.utils import logging
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {}
32
+
33
+
34
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
+ """
37
+ Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
38
+
39
+ Args:
40
+ vocab_file (`str`):
41
+ Path to the vocabulary file.
42
+ """
43
+
44
+ vocab_files_names = VOCAB_FILES_NAMES
45
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
46
+ model_input_names = ["input_ids", "attention_mask"]
47
+ _auto_class = "AutoTokenizer"
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_file,
52
+ unk_token="<unk>",
53
+ bos_token="<s>",
54
+ eos_token="</s>",
55
+ pad_token="</s>",
56
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
57
+ add_bos_token=True,
58
+ add_eos_token=False,
59
+ decode_with_prefix_space=False,
60
+ clean_up_tokenization_spaces=False,
61
+ **kwargs,
62
+ ):
63
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
64
+ self.vocab_file = vocab_file
65
+ self.add_bos_token = add_bos_token
66
+ self.add_eos_token = add_eos_token
67
+ self.decode_with_prefix_space = decode_with_prefix_space
68
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
69
+ self.sp_model.Load(vocab_file)
70
+ self._no_prefix_space_tokens = None
71
+ super().__init__(
72
+ bos_token=bos_token,
73
+ eos_token=eos_token,
74
+ unk_token=unk_token,
75
+ pad_token=pad_token,
76
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
77
+ **kwargs,
78
+ )
79
+
80
+ @property
81
+ def no_prefix_space_tokens(self):
82
+ if self._no_prefix_space_tokens is None:
83
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
84
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
85
+ return self._no_prefix_space_tokens
86
+
87
+ @property
88
+ def vocab_size(self):
89
+ """Returns vocab size"""
90
+ return self.sp_model.get_piece_size()
91
+
92
+ @property
93
+ def bos_token_id(self) -> Optional[int]:
94
+ return self.sp_model.bos_id()
95
+
96
+ @property
97
+ def eos_token_id(self) -> Optional[int]:
98
+ return self.sp_model.eos_id()
99
+
100
+ def get_vocab(self):
101
+ """Returns vocab as a dict"""
102
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
103
+ vocab.update(self.added_tokens_encoder)
104
+ return vocab
105
+
106
+ def _tokenize(self, text):
107
+ """Returns a tokenized string."""
108
+ return self.sp_model.encode(text, out_type=str)
109
+
110
+ def _convert_token_to_id(self, token):
111
+ """Converts a token (str) in an id using the vocab."""
112
+ return self.sp_model.piece_to_id(token)
113
+
114
+ def _convert_id_to_token(self, index):
115
+ """Converts an index (integer) in a token (str) using the vocab."""
116
+ token = self.sp_model.IdToPiece(index)
117
+ return token
118
+
119
+ def _maybe_add_prefix_space(self, tokens, decoded):
120
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
121
+ return " " + decoded
122
+ else:
123
+ return decoded
124
+
125
+ def convert_tokens_to_string(self, tokens):
126
+ """Converts a sequence of tokens (string) in a single string."""
127
+ current_sub_tokens = []
128
+ out_string = ""
129
+ prev_is_special = False
130
+ for token in tokens:
131
+ # make sure that special tokens are not decoded using sentencepiece model
132
+ if token in self.all_special_tokens:
133
+ if not prev_is_special:
134
+ out_string += " "
135
+ out_string += self.sp_model.decode(current_sub_tokens) + token
136
+ prev_is_special = True
137
+ current_sub_tokens = []
138
+ else:
139
+ current_sub_tokens.append(token)
140
+ prev_is_special = False
141
+ out_string += self.sp_model.decode(current_sub_tokens)
142
+ out_string = self.clean_up_tokenization(out_string)
143
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
144
+ return out_string[1:]
145
+
146
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
147
+ """
148
+ Save the vocabulary and special tokens file to a directory.
149
+
150
+ Args:
151
+ save_directory (`str`):
152
+ The directory in which to save the vocabulary.
153
+
154
+ Returns:
155
+ `Tuple(str)`: Paths to the files saved.
156
+ """
157
+ if not os.path.isdir(save_directory):
158
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
159
+ return
160
+ out_vocab_file = os.path.join(
161
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
162
+ )
163
+
164
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
165
+ copyfile(self.vocab_file, out_vocab_file)
166
+ elif not os.path.isfile(self.vocab_file):
167
+ with open(out_vocab_file, "wb") as fi:
168
+ content_spiece_model = self.sp_model.serialized_model_proto()
169
+ fi.write(content_spiece_model)
170
+
171
+ return (out_vocab_file,)
172
+
173
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
174
+ if self.add_bos_token:
175
+ bos_token_ids = [self.bos_token_id]
176
+ else:
177
+ bos_token_ids = []
178
+
179
+ output = bos_token_ids + token_ids_0
180
+
181
+ if token_ids_1 is not None:
182
+ output = output + token_ids_1
183
+
184
+ if self.add_eos_token:
185
+ output = output + [self.eos_token_id]
186
+
187
+ return output
188
+
189
+ def get_special_tokens_mask(
190
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
191
+ ) -> List[int]:
192
+ """
193
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
194
+ special tokens using the tokenizer `prepare_for_model` method.
195
+
196
+ Args:
197
+ token_ids_0 (`List[int]`):
198
+ List of IDs.
199
+ token_ids_1 (`List[int]`, *optional*):
200
+ Optional second list of IDs for sequence pairs.
201
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
202
+ Whether or not the token list is already formatted with special tokens for the model.
203
+
204
+ Returns:
205
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
206
+ """
207
+ if already_has_special_tokens:
208
+ return super().get_special_tokens_mask(
209
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
210
+ )
211
+
212
+ if token_ids_1 is None:
213
+ return [1] + ([0] * len(token_ids_0)) + [1]
214
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
215
+
216
+ def create_token_type_ids_from_sequences(
217
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
218
+ ) -> List[int]:
219
+ """
220
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
221
+ use of token type ids, therefore a list of zeros is returned.
222
+
223
+ Args:
224
+ token_ids_0 (`List[int]`):
225
+ List of IDs.
226
+ token_ids_1 (`List[int]`, *optional*):
227
+ Optional second list of IDs for sequence pairs.
228
+
229
+ Returns:
230
+ `List[int]`: List of zeros.
231
+ """
232
+ eos = [self.eos_token_id]
233
+
234
+ if token_ids_1 is None:
235
+ return len(token_ids_0 + eos) * [0]
236
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
epoch1.5_ckpt/tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization Fast class for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, Optional, Tuple
22
+
23
+ from tokenizers import processors, decoders, Tokenizer, normalizers
24
+ from tokenizers.models import BPE
25
+
26
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
27
+ from transformers.utils import logging
28
+
29
+ from transformers.convert_slow_tokenizer import (
30
+ SLOW_TO_FAST_CONVERTERS,
31
+ SpmConverter,
32
+ SentencePieceExtractor,
33
+ )
34
+
35
+ from .tokenization_internlm2 import InternLM2Tokenizer
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
40
+
41
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
42
+ class InternLM2Converter(SpmConverter):
43
+ handle_byte_fallback = True
44
+
45
+ def vocab(self, proto):
46
+ vocab = [
47
+ ("<unk>", 0.0),
48
+ ("<s>", 0.0),
49
+ ("</s>", 0.0),
50
+ ]
51
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
52
+ return vocab
53
+
54
+ def unk_id(self, proto):
55
+ unk_id = 0
56
+ return unk_id
57
+
58
+ def decoder(self, replacement, add_prefix_space):
59
+ decoders_sequence = [
60
+ decoders.Replace("▁", " "),
61
+ decoders.ByteFallback(),
62
+ decoders.Fuse(),
63
+ ]
64
+ if self.proto.normalizer_spec.add_dummy_prefix:
65
+ decoders_sequence.append(decoders.Strip(content=" ", left=1))
66
+ return decoders.Sequence(decoders_sequence)
67
+
68
+ def tokenizer(self, proto):
69
+ model_type = proto.trainer_spec.model_type
70
+ vocab_scores = self.vocab(proto)
71
+ # special tokens
72
+ added_tokens = self.original_tokenizer.added_tokens_decoder
73
+ for i in range(len(vocab_scores)):
74
+ piece, score = vocab_scores[i]
75
+ if i in added_tokens:
76
+ vocab_scores[i] = (added_tokens[i].content, score)
77
+ if model_type == 1:
78
+ raise RuntimeError("InternLM2 is supposed to be a BPE model!")
79
+
80
+ elif model_type == 2:
81
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
82
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
83
+ tokenizer = Tokenizer(
84
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
85
+ )
86
+ tokenizer.add_special_tokens(
87
+ [ added_token for index, added_token in added_tokens.items()]
88
+ )
89
+ else:
90
+ raise Exception(
91
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
92
+ )
93
+
94
+ return tokenizer
95
+
96
+ def normalizer(self, proto):
97
+ normalizers_list = []
98
+ if proto.normalizer_spec.add_dummy_prefix:
99
+ normalizers_list.append(normalizers.Prepend(prepend="▁"))
100
+ normalizers_list.append(normalizers.Replace(pattern=" ", content="▁"))
101
+ return normalizers.Sequence(normalizers_list)
102
+
103
+ def pre_tokenizer(self, replacement, add_prefix_space):
104
+ return None
105
+
106
+ SLOW_TO_FAST_CONVERTERS["InternLM2Tokenizer"] = InternLM2Converter
107
+
108
+
109
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
110
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
111
+ vocab_files_names = VOCAB_FILES_NAMES
112
+ slow_tokenizer_class = InternLM2Tokenizer
113
+ padding_side = "left"
114
+ model_input_names = ["input_ids", "attention_mask"]
115
+ _auto_class = "AutoTokenizer"
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_file,
120
+ unk_token="<unk>",
121
+ bos_token="<s>",
122
+ eos_token="</s>",
123
+ pad_token="</s>",
124
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
125
+ add_bos_token=True,
126
+ add_eos_token=False,
127
+ decode_with_prefix_space=False,
128
+ clean_up_tokenization_spaces=False,
129
+ **kwargs,
130
+ ):
131
+ super().__init__(
132
+ vocab_file=vocab_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ sp_model_kwargs=sp_model_kwargs,
138
+ add_bos_token=add_bos_token,
139
+ add_eos_token=add_eos_token,
140
+ decode_with_prefix_space=decode_with_prefix_space,
141
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
142
+ **kwargs,
143
+ )
144
+ self._add_bos_token = add_bos_token
145
+ self._add_eos_token = add_eos_token
146
+ self.update_post_processor()
147
+ self.vocab_file = vocab_file
148
+
149
+ @property
150
+ def can_save_slow_tokenizer(self) -> bool:
151
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
152
+
153
+ def update_post_processor(self):
154
+ """
155
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
156
+ """
157
+ bos = self.bos_token
158
+ bos_token_id = self.bos_token_id
159
+ if bos is None and self.add_bos_token:
160
+ raise ValueError("add_bos_token = True but bos_token = None")
161
+
162
+ eos = self.eos_token
163
+ eos_token_id = self.eos_token_id
164
+ if eos is None and self.add_eos_token:
165
+ raise ValueError("add_eos_token = True but eos_token = None")
166
+
167
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
168
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
169
+
170
+ special_tokens = []
171
+ if self.add_bos_token:
172
+ special_tokens.append((bos, bos_token_id))
173
+ if self.add_eos_token:
174
+ special_tokens.append((eos, eos_token_id))
175
+ self._tokenizer.post_processor = processors.TemplateProcessing(
176
+ single=single, pair=pair, special_tokens=special_tokens
177
+ )
178
+
179
+ @property
180
+ def add_eos_token(self):
181
+ return self._add_eos_token
182
+
183
+ @property
184
+ def add_bos_token(self):
185
+ return self._add_bos_token
186
+
187
+ @add_eos_token.setter
188
+ def add_eos_token(self, value):
189
+ self._add_eos_token = value
190
+ self.update_post_processor()
191
+
192
+ @add_bos_token.setter
193
+ def add_bos_token(self, value):
194
+ self._add_bos_token = value
195
+ self.update_post_processor()
196
+
197
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
198
+ if not self.can_save_slow_tokenizer:
199
+ raise ValueError(
200
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
201
+ "tokenizer."
202
+ )
203
+
204
+ if not os.path.isdir(save_directory):
205
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
206
+ return
207
+ out_vocab_file = os.path.join(
208
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
209
+ )
210
+
211
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
212
+ copyfile(self.vocab_file, out_vocab_file)
213
+
214
+ return (out_vocab_file,)
epoch1.5_ckpt/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
epoch1.5_ckpt/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
3
+ size 1477754
epoch1.5_ckpt/tokenizer_config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "auto_map": {
31
+ "AutoTokenizer": [
32
+ "tokenization_internlm2.InternLM2Tokenizer",
33
+ "tokenization_internlm2_fast.InternLM2TokenizerFast"
34
+ ]
35
+ },
36
+ "bos_token": "<s>",
37
+ "clean_up_tokenization_spaces": false,
38
+ "decode_with_prefix_space": false,
39
+ "eos_token": "</s>",
40
+ "model_max_length": 1000000000000000019884624838656,
41
+ "pad_token": "</s>",
42
+ "padding_side": "right",
43
+ "sp_model_kwargs": null,
44
+ "tokenizer_class": "InternLM2Tokenizer",
45
+ "unk_token": "<unk>"
46
+ }
epoch1.5_ckpt/xtuner_config.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SYSTEM = ''
2
+ accumulative_counts = 2
3
+ batch_size = 8
4
+ betas = (
5
+ 0.9,
6
+ 0.999,
7
+ )
8
+ custom_hooks = [
9
+ dict(
10
+ tokenizer=dict(
11
+ padding_side='right',
12
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
13
+ trust_remote_code=True,
14
+ type='transformers.AutoTokenizer.from_pretrained'),
15
+ type='xtuner.engine.hooks.DatasetInfoHook'),
16
+ dict(
17
+ evaluation_images='https://llava-vl.github.io/static/images/view.jpg',
18
+ evaluation_inputs=[
19
+ '请描述一下这张照片',
20
+ 'Please describe this picture',
21
+ ],
22
+ every_n_iters=500,
23
+ image_processor=dict(
24
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
25
+ trust_remote_code=True,
26
+ type='transformers.SiglipImageProcessor.from_pretrained'),
27
+ prompt_template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
28
+ system='',
29
+ tokenizer=dict(
30
+ padding_side='right',
31
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
32
+ trust_remote_code=True,
33
+ type='transformers.AutoTokenizer.from_pretrained'),
34
+ type='xtuner.engine.hooks.EvaluateChatHook'),
35
+ ]
36
+ data_path = './llava_data/llava_v1_5_lrv_mix1008k.json'
37
+ data_root = './llava_data/'
38
+ dataloader_num_workers = 4
39
+ default_hooks = dict(
40
+ checkpoint=dict(
41
+ by_epoch=False,
42
+ interval=500,
43
+ max_keep_ckpts=2,
44
+ type='mmengine.hooks.CheckpointHook'),
45
+ logger=dict(
46
+ interval=10,
47
+ log_metric_by_epoch=False,
48
+ type='mmengine.hooks.LoggerHook'),
49
+ param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
50
+ sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
51
+ timer=dict(type='mmengine.hooks.IterTimerHook'))
52
+ dino_path = 'facebook/dinov2-large'
53
+ env_cfg = dict(
54
+ cudnn_benchmark=False,
55
+ dist_cfg=dict(backend='nccl'),
56
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
57
+ evaluation_freq = 500
58
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
59
+ evaluation_inputs = [
60
+ '请描述一下这张照片',
61
+ 'Please describe this picture',
62
+ ]
63
+ image_folder = './llava_data/llava_images'
64
+ image_processor = dict(
65
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
66
+ trust_remote_code=True,
67
+ type='transformers.SiglipImageProcessor.from_pretrained')
68
+ image_processor_path = 'google/siglip-so400m-patch14-384'
69
+ launcher = 'pytorch'
70
+ llava_dataset = dict(
71
+ data_path='./llava_data/llava_v1_5_lrv_mix1008k.json',
72
+ dataset_map_fn='xtuner.dataset.map_fns.llava_map_fn',
73
+ image_folder='./llava_data/llava_images',
74
+ image_processor=dict(
75
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
76
+ trust_remote_code=True,
77
+ type='transformers.SiglipImageProcessor.from_pretrained'),
78
+ max_length=1472,
79
+ pad_image_to_square=False,
80
+ template_map_fn=dict(
81
+ template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
82
+ type='xtuner.dataset.map_fns.template_map_fn_factory'),
83
+ tokenizer=dict(
84
+ padding_side='right',
85
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
86
+ trust_remote_code=True,
87
+ type='transformers.AutoTokenizer.from_pretrained'),
88
+ type='xtuner.dataset.LLaVADataset')
89
+ llm_name_or_path = 'internlm/internlm2-1_8b'
90
+ load_from = None
91
+ log_level = 'INFO'
92
+ log_processor = dict(by_epoch=False)
93
+ lr = 2e-05
94
+ max_epochs = 2
95
+ max_length = 1472
96
+ max_norm = 1
97
+ model = dict(
98
+ dino=dict(
99
+ pretrained_model_name_or_path='facebook/dinov2-large',
100
+ type='transformers.Dinov2Model.from_pretrained'),
101
+ freeze_llm=False,
102
+ freeze_visual_encoder=True,
103
+ llm=dict(
104
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
105
+ quantization_config=dict(
106
+ bnb_4bit_compute_dtype='torch.float16',
107
+ bnb_4bit_quant_type='nf4',
108
+ bnb_4bit_use_double_quant=True,
109
+ llm_int8_has_fp16_weight=False,
110
+ llm_int8_threshold=6.0,
111
+ load_in_4bit=True,
112
+ load_in_8bit=False,
113
+ type='transformers.BitsAndBytesConfig'),
114
+ torch_dtype='torch.float16',
115
+ trust_remote_code=True,
116
+ type='transformers.AutoModelForCausalLM.from_pretrained'),
117
+ siglip=dict(
118
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
119
+ type='transformers.SiglipVisionModel.from_pretrained'),
120
+ type='xtuner.model.LLaVAModel')
121
+ optim_type = 'torch.optim.AdamW'
122
+ optim_wrapper = dict(
123
+ optimizer=dict(
124
+ betas=(
125
+ 0.9,
126
+ 0.999,
127
+ ),
128
+ lr=2e-05,
129
+ type='torch.optim.AdamW',
130
+ weight_decay=0.1),
131
+ type='DeepSpeedOptimWrapper')
132
+ param_scheduler = [
133
+ dict(
134
+ begin=0,
135
+ by_epoch=True,
136
+ convert_to_iter_based=True,
137
+ end=0.06,
138
+ start_factor=1e-05,
139
+ type='mmengine.optim.LinearLR'),
140
+ dict(
141
+ begin=0.06,
142
+ by_epoch=True,
143
+ convert_to_iter_based=True,
144
+ end=2,
145
+ eta_min=0.0,
146
+ type='mmengine.optim.CosineAnnealingLR'),
147
+ ]
148
+ prompt_template = 'xtuner.utils.PROMPT_TEMPLATE.internlm2_chat'
149
+ randomness = dict(deterministic=False, seed=None)
150
+ resume = False
151
+ runner_type = 'FlexibleRunner'
152
+ save_steps = 500
153
+ save_total_limit = 2
154
+ siglip_path = 'google/siglip-so400m-patch14-384'
155
+ strategy = dict(
156
+ config=dict(
157
+ bf16=dict(enabled=True),
158
+ fp16=dict(enabled=False, initial_scale_power=16),
159
+ gradient_accumulation_steps='auto',
160
+ gradient_clipping='auto',
161
+ train_micro_batch_size_per_gpu='auto',
162
+ zero_allow_untested_optimizer=True,
163
+ zero_force_ds_cpu_optimizer=False,
164
+ zero_optimization=dict(overlap_comm=True, stage=2)),
165
+ exclude_frozen_parameters=True,
166
+ gradient_accumulation_steps=2,
167
+ gradient_clipping=1,
168
+ train_micro_batch_size_per_gpu=8,
169
+ type='xtuner.engine.DeepSpeedStrategy')
170
+ tokenizer = dict(
171
+ padding_side='right',
172
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
173
+ trust_remote_code=True,
174
+ type='transformers.AutoTokenizer.from_pretrained')
175
+ train_cfg = dict(max_epochs=2, type='xtuner.engine.runner.TrainLoop')
176
+ train_dataloader = dict(
177
+ batch_size=8,
178
+ collate_fn=dict(type='xtuner.dataset.collate_fns.default_collate_fn'),
179
+ dataset=dict(
180
+ data_path='./llava_data/llava_v1_5_lrv_mix1008k.json',
181
+ dataset_map_fn='xtuner.dataset.map_fns.llava_map_fn',
182
+ image_folder='./llava_data/llava_images',
183
+ image_processor=dict(
184
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
185
+ trust_remote_code=True,
186
+ type='transformers.SiglipImageProcessor.from_pretrained'),
187
+ max_length=1472,
188
+ pad_image_to_square=False,
189
+ template_map_fn=dict(
190
+ template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
191
+ type='xtuner.dataset.map_fns.template_map_fn_factory'),
192
+ tokenizer=dict(
193
+ padding_side='right',
194
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
195
+ trust_remote_code=True,
196
+ type='transformers.AutoTokenizer.from_pretrained'),
197
+ type='xtuner.dataset.LLaVADataset'),
198
+ num_workers=4,
199
+ sampler=dict(shuffle=True, type='mmengine.dataset.DefaultSampler'))
200
+ visualizer = dict(
201
+ type='mmengine.visualization.Visualizer',
202
+ vis_backends=[
203
+ dict(type='mmengine.visualization.TensorboardVisBackend'),
204
+ ])
205
+ warmup_ratio = 0.03
206
+ weight_decay = 0.1
207
+ work_dir = './work_dirs/train_config'
epoch1_ckpt/config.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "internlm/internlm2-1_8b",
3
+ "architectures": [
4
+ "InternLM2ForCausalLM"
5
+ ],
6
+ "attn_implementation": "eager",
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
9
+ "AutoModel": "internlm/internlm2-1_8b--modeling_internlm2.InternLM2ForCausalLM",
10
+ "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM"
11
+ },
12
+ "bias": false,
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2048,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 8192,
19
+ "max_position_embeddings": 32768,
20
+ "model_type": "internlm2",
21
+ "num_attention_heads": 16,
22
+ "num_hidden_layers": 24,
23
+ "num_key_value_heads": 8,
24
+ "pad_token_id": 2,
25
+ "quantization_config": {
26
+ "_load_in_4bit": true,
27
+ "_load_in_8bit": false,
28
+ "bnb_4bit_compute_dtype": "float16",
29
+ "bnb_4bit_quant_type": "nf4",
30
+ "bnb_4bit_use_double_quant": true,
31
+ "llm_int8_enable_fp32_cpu_offload": false,
32
+ "llm_int8_has_fp16_weight": false,
33
+ "llm_int8_skip_modules": null,
34
+ "llm_int8_threshold": 6.0,
35
+ "load_in_4bit": true,
36
+ "load_in_8bit": false,
37
+ "quant_method": "bitsandbytes"
38
+ },
39
+ "rms_norm_eps": 1e-05,
40
+ "rope_scaling": null,
41
+ "rope_theta": 1000000,
42
+ "tie_word_embeddings": false,
43
+ "torch_dtype": "float16",
44
+ "transformers_version": "4.39.0.dev0",
45
+ "use_cache": false,
46
+ "vocab_size": 92544
47
+ }
epoch1_ckpt/configuration_internlm2.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
66
+ The epsilon used by the rms normalization layers.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
69
+ relevant if `config.is_decoder=True`.
70
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
71
+ Whether to tie weight embeddings
72
+ Example:
73
+
74
+ """
75
+ model_type = "internlm2"
76
+ _auto_class = "AutoConfig"
77
+
78
+ def __init__( # pylint: disable=W0102
79
+ self,
80
+ vocab_size=103168,
81
+ hidden_size=4096,
82
+ intermediate_size=11008,
83
+ num_hidden_layers=32,
84
+ num_attention_heads=32,
85
+ num_key_value_heads=None,
86
+ hidden_act="silu",
87
+ max_position_embeddings=2048,
88
+ initializer_range=0.02,
89
+ rms_norm_eps=1e-6,
90
+ use_cache=True,
91
+ pad_token_id=0,
92
+ bos_token_id=1,
93
+ eos_token_id=2,
94
+ tie_word_embeddings=False,
95
+ bias=True,
96
+ rope_theta=10000,
97
+ rope_scaling=None,
98
+ attn_implementation="eager",
99
+ **kwargs,
100
+ ):
101
+ self.vocab_size = vocab_size
102
+ self.max_position_embeddings = max_position_embeddings
103
+ self.hidden_size = hidden_size
104
+ self.intermediate_size = intermediate_size
105
+ self.num_hidden_layers = num_hidden_layers
106
+ self.num_attention_heads = num_attention_heads
107
+ self.bias = bias
108
+
109
+ if num_key_value_heads is None:
110
+ num_key_value_heads = num_attention_heads
111
+ self.num_key_value_heads = num_key_value_heads
112
+
113
+ self.hidden_act = hidden_act
114
+ self.initializer_range = initializer_range
115
+ self.rms_norm_eps = rms_norm_eps
116
+ self.use_cache = use_cache
117
+ self.rope_theta = rope_theta
118
+ self.rope_scaling = rope_scaling
119
+ self._rope_scaling_validation()
120
+
121
+ self.attn_implementation = attn_implementation
122
+ if self.attn_implementation is None:
123
+ self.attn_implementation = "eager"
124
+ super().__init__(
125
+ pad_token_id=pad_token_id,
126
+ bos_token_id=bos_token_id,
127
+ eos_token_id=eos_token_id,
128
+ tie_word_embeddings=tie_word_embeddings,
129
+ **kwargs,
130
+ )
131
+
132
+ def _rope_scaling_validation(self):
133
+ """
134
+ Validate the `rope_scaling` configuration.
135
+ """
136
+ if self.rope_scaling is None:
137
+ return
138
+
139
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
140
+ raise ValueError(
141
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
142
+ f"got {self.rope_scaling}"
143
+ )
144
+ rope_scaling_type = self.rope_scaling.get("type", None)
145
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
146
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
147
+ raise ValueError(
148
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
149
+ )
150
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
151
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
epoch1_ckpt/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "transformers_version": "4.39.0.dev0"
7
+ }
epoch1_ckpt/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72384fdb9bca2fbc5f93739199d30573be01d8d2362c873f360f450d8425af8f
3
+ size 1537498688
epoch1_ckpt/modeling_internlm2.py ADDED
@@ -0,0 +1,1391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from einops import rearrange
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ SequenceClassifierOutputWithPast,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+
43
+ try:
44
+ from transformers.generation.streamers import BaseStreamer
45
+ except: # noqa # pylint: disable=bare-except
46
+ BaseStreamer = None
47
+
48
+ from .configuration_internlm2 import InternLM2Config
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "InternLM2Config"
53
+
54
+ flash_attn_func, flash_attn_varlen_func = None, None
55
+ pad_input, index_first_axis, unpad_input = None, None, None
56
+ def _import_flash_attn():
57
+ global flash_attn_func, flash_attn_varlen_func
58
+ global pad_input, index_first_axis, unpad_input
59
+ try:
60
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
61
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
62
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
63
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
64
+ except ImportError:
65
+ raise ImportError("flash_attn is not installed.")
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
+
80
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
81
+ def _make_causal_mask(
82
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
83
+ ):
84
+ """
85
+ Make causal mask used for bi-directional self-attention.
86
+ """
87
+ bsz, tgt_len = input_ids_shape
88
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
89
+ mask_cond = torch.arange(mask.size(-1), device=device)
90
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
91
+ mask = mask.to(dtype)
92
+
93
+ if past_key_values_length > 0:
94
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
95
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
96
+
97
+
98
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
99
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
100
+ """
101
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
102
+ """
103
+ bsz, src_len = mask.size()
104
+ tgt_len = tgt_len if tgt_len is not None else src_len
105
+
106
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
107
+
108
+ inverted_mask = 1.0 - expanded_mask
109
+
110
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
111
+
112
+
113
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
114
+ class InternLM2RMSNorm(nn.Module):
115
+ def __init__(self, hidden_size, eps=1e-6):
116
+ """
117
+ InternLM2RMSNorm is equivalent to T5LayerNorm
118
+ """
119
+ super().__init__()
120
+ self.weight = nn.Parameter(torch.ones(hidden_size))
121
+ self.variance_epsilon = eps
122
+
123
+ def forward(self, hidden_states):
124
+ input_dtype = hidden_states.dtype
125
+ hidden_states = hidden_states.to(torch.float32)
126
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
127
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
128
+ return self.weight * hidden_states.to(input_dtype)
129
+
130
+
131
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
132
+ class InternLM2RotaryEmbedding(nn.Module):
133
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
134
+ super().__init__()
135
+
136
+ self.dim = dim
137
+ self.max_position_embeddings = max_position_embeddings
138
+ self.base = base
139
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
140
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
141
+
142
+ # Build here to make `torch.jit.trace` work.
143
+ self._set_cos_sin_cache(
144
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
145
+ )
146
+
147
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
148
+ self.max_seq_len_cached = seq_len
149
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
150
+
151
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
152
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
153
+ emb = torch.cat((freqs, freqs), dim=-1)
154
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
155
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
156
+
157
+ def forward(self, x, seq_len=None):
158
+ # x: [bs, num_attention_heads, seq_len, head_size]
159
+ if seq_len > self.max_seq_len_cached:
160
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
161
+
162
+ return (
163
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
164
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
165
+ )
166
+
167
+
168
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
169
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
170
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
171
+
172
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
173
+ self.scaling_factor = scaling_factor
174
+ super().__init__(dim, max_position_embeddings, base, device)
175
+
176
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
177
+ self.max_seq_len_cached = seq_len
178
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
179
+ t = t / self.scaling_factor
180
+
181
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
182
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
183
+ emb = torch.cat((freqs, freqs), dim=-1)
184
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
185
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
186
+
187
+
188
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
189
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
190
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
191
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
192
+ """
193
+
194
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
195
+ self.scaling_factor = scaling_factor
196
+ super().__init__(dim, max_position_embeddings, base, device)
197
+
198
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
199
+ self.max_seq_len_cached = seq_len
200
+
201
+ if seq_len > self.max_position_embeddings:
202
+ base = self.base * (
203
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
204
+ ) ** (self.dim / (self.dim - 2))
205
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
206
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
207
+
208
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
209
+
210
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
211
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
212
+ emb = torch.cat((freqs, freqs), dim=-1)
213
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
214
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
215
+
216
+
217
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
218
+ def rotate_half(x):
219
+ """Rotates half the hidden dims of the input."""
220
+ x1 = x[..., : x.shape[-1] // 2]
221
+ x2 = x[..., x.shape[-1] // 2 :]
222
+ return torch.cat((-x2, x1), dim=-1)
223
+
224
+
225
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
226
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
227
+ """Applies Rotary Position Embedding to the query and key tensors."""
228
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
229
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
230
+ q_embed = (q * cos) + (rotate_half(q) * sin)
231
+ k_embed = (k * cos) + (rotate_half(k) * sin)
232
+ return q_embed, k_embed
233
+
234
+
235
+ class InternLM2MLP(nn.Module):
236
+ def __init__(self, config):
237
+ super().__init__()
238
+ self.config = config
239
+ self.hidden_size = config.hidden_size
240
+ self.intermediate_size = config.intermediate_size
241
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
242
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
243
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
244
+ self.act_fn = ACT2FN[config.hidden_act]
245
+
246
+ def forward(self, x):
247
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
248
+
249
+ return down_proj
250
+
251
+
252
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
253
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
254
+ """
255
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
256
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
257
+ """
258
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
259
+ if n_rep == 1:
260
+ return hidden_states
261
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
262
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
263
+
264
+
265
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
266
+ class InternLM2Attention(nn.Module):
267
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
268
+
269
+ def __init__(self, config: InternLM2Config):
270
+ super().__init__()
271
+ self.config = config
272
+ self.hidden_size = config.hidden_size
273
+ self.num_heads = config.num_attention_heads
274
+ self.head_dim = self.hidden_size // self.num_heads
275
+ self.num_key_value_heads = config.num_key_value_heads
276
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
277
+ self.max_position_embeddings = config.max_position_embeddings
278
+ self.is_causal = True
279
+
280
+ if (self.head_dim * self.num_heads) != self.hidden_size:
281
+ raise ValueError(
282
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
283
+ f" and `num_heads`: {self.num_heads})."
284
+ )
285
+
286
+ self.wqkv = nn.Linear(
287
+ self.hidden_size,
288
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
289
+ bias=config.bias,
290
+ )
291
+
292
+ self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
293
+ self._init_rope()
294
+
295
+ def _init_rope(self):
296
+ if self.config.rope_scaling is None:
297
+ self.rotary_emb = InternLM2RotaryEmbedding(
298
+ self.head_dim,
299
+ max_position_embeddings=self.max_position_embeddings,
300
+ base=self.config.rope_theta,
301
+ )
302
+ else:
303
+ scaling_type = self.config.rope_scaling["type"]
304
+ scaling_factor = self.config.rope_scaling["factor"]
305
+ if scaling_type == "dynamic":
306
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
307
+ self.head_dim,
308
+ max_position_embeddings=self.max_position_embeddings,
309
+ base=self.config.rope_theta,
310
+ scaling_factor=scaling_factor,
311
+ )
312
+ elif scaling_type == "linear":
313
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
314
+ self.head_dim,
315
+ max_position_embeddings=self.max_position_embeddings,
316
+ base=self.config.rope_theta,
317
+ scaling_factor=scaling_factor,
318
+ )
319
+ else:
320
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
321
+ return self.rotary_emb
322
+
323
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
324
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
325
+
326
+ def forward(
327
+ self,
328
+ hidden_states: torch.Tensor,
329
+ attention_mask: Optional[torch.Tensor] = None,
330
+ position_ids: Optional[torch.LongTensor] = None,
331
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
332
+ output_attentions: bool = False,
333
+ use_cache: bool = False,
334
+ **kwargs,
335
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
336
+ if "padding_mask" in kwargs:
337
+ warnings.warn(
338
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
339
+ "Please make sure use `attention_mask` instead.`"
340
+ )
341
+
342
+ bsz, q_len, _ = hidden_states.size()
343
+
344
+ qkv_states = self.wqkv(hidden_states)
345
+
346
+ qkv_states = rearrange(
347
+ qkv_states,
348
+ "b q (h gs d) -> b q h gs d",
349
+ gs=2 + self.num_key_value_groups,
350
+ d=self.head_dim,
351
+ )
352
+
353
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
354
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
355
+ key_states = qkv_states[..., -2, :]
356
+ value_states = qkv_states[..., -1, :]
357
+
358
+ query_states = query_states.transpose(1, 2)
359
+ key_states = key_states.transpose(1, 2)
360
+ value_states = value_states.transpose(1, 2)
361
+
362
+ kv_seq_len = key_states.shape[-2]
363
+ if past_key_value is not None:
364
+ kv_seq_len += past_key_value[0].shape[-2]
365
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
366
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
367
+
368
+ if past_key_value is not None:
369
+ # reuse k, v, self_attention
370
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
371
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
372
+
373
+ past_key_value = (key_states, value_states) if use_cache else None
374
+
375
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
376
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
377
+
378
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
379
+
380
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
381
+ raise ValueError(
382
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
383
+ f" {attn_weights.size()}"
384
+ )
385
+
386
+ if attention_mask is not None:
387
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
388
+ raise ValueError(
389
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
390
+ )
391
+ attn_weights = attn_weights + attention_mask
392
+
393
+ # upcast attention to fp32
394
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
395
+ attn_output = torch.matmul(attn_weights, value_states)
396
+
397
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
398
+ raise ValueError(
399
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
400
+ f" {attn_output.size()}"
401
+ )
402
+
403
+ attn_output = attn_output.transpose(1, 2).contiguous()
404
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
405
+
406
+ attn_output = self.wo(attn_output)
407
+
408
+ if not output_attentions:
409
+ attn_weights = None
410
+
411
+ return attn_output, attn_weights, past_key_value
412
+
413
+
414
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
415
+ class InternLM2FlashAttention2(InternLM2Attention):
416
+ """
417
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
418
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
419
+ flash attention and deal with padding tokens in case the input contains any of them.
420
+ """
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.LongTensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ **kwargs,
431
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
432
+ # InternLM2FlashAttention2 attention does not support output_attentions
433
+ if "padding_mask" in kwargs:
434
+ warnings.warn(
435
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
436
+ "Please make sure use `attention_mask` instead.`"
437
+ )
438
+
439
+ # overwrite attention_mask with padding_mask
440
+ attention_mask = kwargs.pop("padding_mask")
441
+
442
+ output_attentions = False
443
+
444
+ bsz, q_len, _ = hidden_states.size()
445
+
446
+ qkv_states = self.wqkv(hidden_states)
447
+
448
+ qkv_states = rearrange(
449
+ qkv_states,
450
+ "b q (h gs d) -> b q h gs d",
451
+ gs=2 + self.num_key_value_groups,
452
+ d=self.head_dim,
453
+ )
454
+
455
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
456
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
457
+ key_states = qkv_states[..., -2, :]
458
+ value_states = qkv_states[..., -1, :]
459
+
460
+ query_states = query_states.transpose(1, 2)
461
+ key_states = key_states.transpose(1, 2)
462
+ value_states = value_states.transpose(1, 2)
463
+
464
+ kv_seq_len = key_states.shape[-2]
465
+ if past_key_value is not None:
466
+ kv_seq_len += past_key_value[0].shape[-2]
467
+
468
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
469
+
470
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
471
+
472
+ if past_key_value is not None:
473
+ # reuse k, v, self_attention
474
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
475
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
476
+
477
+ past_key_value = (key_states, value_states) if use_cache else None
478
+
479
+ query_states = query_states.transpose(1, 2)
480
+ key_states = key_states.transpose(1, 2)
481
+ value_states = value_states.transpose(1, 2)
482
+
483
+ attn_output = self._flash_attention_forward(
484
+ query_states, key_states, value_states, attention_mask, q_len
485
+ )
486
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
487
+ attn_output = self.wo(attn_output)
488
+
489
+ if not output_attentions:
490
+ attn_weights = None
491
+
492
+ return attn_output, attn_weights, past_key_value
493
+
494
+ def _flash_attention_forward(
495
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
496
+ ):
497
+ """
498
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
499
+ first unpad the input, then computes the attention scores and pad the final attention scores.
500
+
501
+ Args:
502
+ query_states (`torch.Tensor`):
503
+ Input query states to be passed to Flash Attention API
504
+ key_states (`torch.Tensor`):
505
+ Input key states to be passed to Flash Attention API
506
+ value_states (`torch.Tensor`):
507
+ Input value states to be passed to Flash Attention API
508
+ attention_mask (`torch.Tensor`):
509
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
510
+ position of padding tokens and 1 for the position of non-padding tokens.
511
+ dropout (`int`, *optional*):
512
+ Attention dropout
513
+ softmax_scale (`float`, *optional*):
514
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
515
+ """
516
+ # Contains at least one padding token in the sequence
517
+ causal = self.is_causal and query_length != 1
518
+ if attention_mask is not None:
519
+ batch_size = query_states.shape[0]
520
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
521
+ query_states, key_states, value_states, attention_mask, query_length
522
+ )
523
+
524
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
525
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
526
+
527
+ attn_output_unpad = flash_attn_varlen_func(
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ cu_seqlens_q=cu_seqlens_q,
532
+ cu_seqlens_k=cu_seqlens_k,
533
+ max_seqlen_q=max_seqlen_in_batch_q,
534
+ max_seqlen_k=max_seqlen_in_batch_k,
535
+ dropout_p=dropout,
536
+ softmax_scale=softmax_scale,
537
+ causal=causal,
538
+ )
539
+
540
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
541
+ else:
542
+ attn_output = flash_attn_func(
543
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
544
+ )
545
+
546
+ return attn_output
547
+
548
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
549
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
550
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
551
+
552
+ key_layer = index_first_axis(
553
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
554
+ )
555
+ value_layer = index_first_axis(
556
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
557
+ )
558
+
559
+ if query_length == kv_seq_len:
560
+ query_layer = index_first_axis(
561
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
562
+ )
563
+ cu_seqlens_q = cu_seqlens_k
564
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
565
+ indices_q = indices_k
566
+ elif query_length == 1:
567
+ max_seqlen_in_batch_q = 1
568
+ cu_seqlens_q = torch.arange(
569
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
570
+ ) # There is a memcpy here, that is very bad.
571
+ indices_q = cu_seqlens_q[:-1]
572
+ query_layer = query_layer.squeeze(1)
573
+ else:
574
+ # The -q_len: slice assumes left padding.
575
+ attention_mask = attention_mask[:, -query_length:]
576
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
577
+
578
+ return (
579
+ query_layer,
580
+ key_layer,
581
+ value_layer,
582
+ indices_q.to(torch.int64),
583
+ (cu_seqlens_q, cu_seqlens_k),
584
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
585
+ )
586
+
587
+ INTERNLM2_ATTENTION_CLASSES = {
588
+ "eager": InternLM2Attention,
589
+ "flash_attention_2": InternLM2FlashAttention2,
590
+ }
591
+
592
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
593
+ class InternLM2DecoderLayer(nn.Module):
594
+ def __init__(self, config: InternLM2Config):
595
+ super().__init__()
596
+ self.hidden_size = config.hidden_size
597
+
598
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
599
+
600
+ self.feed_forward = InternLM2MLP(config)
601
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
603
+
604
+ def forward(
605
+ self,
606
+ hidden_states: torch.Tensor,
607
+ attention_mask: Optional[torch.Tensor] = None,
608
+ position_ids: Optional[torch.LongTensor] = None,
609
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
610
+ output_attentions: Optional[bool] = False,
611
+ use_cache: Optional[bool] = False,
612
+ **kwargs,
613
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
614
+ """
615
+ Args:
616
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
617
+ attention_mask (`torch.FloatTensor`, *optional*):
618
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
619
+ query_sequence_length, key_sequence_length)` if default attention is used.
620
+ output_attentions (`bool`, *optional*):
621
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
622
+ returned tensors for more detail.
623
+ use_cache (`bool`, *optional*):
624
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
625
+ (see `past_key_values`).
626
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
627
+ """
628
+ if "padding_mask" in kwargs:
629
+ warnings.warn(
630
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
631
+ "Please make sure use `attention_mask` instead.`"
632
+ )
633
+
634
+ residual = hidden_states
635
+
636
+ hidden_states = self.attention_norm(hidden_states)
637
+
638
+ # Self Attention
639
+ hidden_states, self_attn_weights, present_key_value = self.attention(
640
+ hidden_states=hidden_states,
641
+ attention_mask=attention_mask,
642
+ position_ids=position_ids,
643
+ past_key_value=past_key_value,
644
+ output_attentions=output_attentions,
645
+ use_cache=use_cache,
646
+ **kwargs,
647
+ )
648
+ hidden_states = residual + hidden_states
649
+
650
+ # Fully Connected
651
+ residual = hidden_states
652
+ hidden_states = self.ffn_norm(hidden_states)
653
+ hidden_states = self.feed_forward(hidden_states)
654
+ hidden_states = residual + hidden_states
655
+
656
+ outputs = (hidden_states,)
657
+
658
+ if output_attentions:
659
+ outputs += (self_attn_weights,)
660
+
661
+ if use_cache:
662
+ outputs += (present_key_value,)
663
+
664
+ return outputs
665
+
666
+
667
+ InternLM2_START_DOCSTRING = r"""
668
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
669
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
670
+ etc.)
671
+
672
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
673
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
674
+ and behavior.
675
+
676
+ Parameters:
677
+ config ([`InternLM2Config`]):
678
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
679
+ load the weights associated with the model, only the configuration. Check out the
680
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
681
+ """
682
+
683
+
684
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
685
+ @add_start_docstrings(
686
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
687
+ InternLM2_START_DOCSTRING,
688
+ )
689
+ class InternLM2PreTrainedModel(PreTrainedModel):
690
+ config_class = InternLM2Config
691
+ base_model_prefix = "model"
692
+ supports_gradient_checkpointing = True
693
+ _no_split_modules = ["InternLM2DecoderLayer"]
694
+ _skip_keys_device_placement = "past_key_values"
695
+
696
+ def _init_weights(self, module):
697
+ std = self.config.initializer_range
698
+ if isinstance(module, nn.Linear):
699
+ module.weight.data.normal_(mean=0.0, std=std)
700
+ if module.bias is not None:
701
+ module.bias.data.zero_()
702
+ elif isinstance(module, nn.Embedding):
703
+ module.weight.data.normal_(mean=0.0, std=std)
704
+ if module.padding_idx is not None:
705
+ module.weight.data[module.padding_idx].zero_()
706
+
707
+
708
+ InternLM2_INPUTS_DOCSTRING = r"""
709
+ Args:
710
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
711
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
712
+ it.
713
+
714
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
715
+ [`PreTrainedTokenizer.__call__`] for details.
716
+
717
+ [What are input IDs?](../glossary#input-ids)
718
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
719
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
720
+
721
+ - 1 for tokens that are **not masked**,
722
+ - 0 for tokens that are **masked**.
723
+
724
+ [What are attention masks?](../glossary#attention-mask)
725
+
726
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
727
+ [`PreTrainedTokenizer.__call__`] for details.
728
+
729
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
730
+ `past_key_values`).
731
+
732
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
733
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
734
+ information on the default strategy.
735
+
736
+ - 1 indicates the head is **not masked**,
737
+ - 0 indicates the head is **masked**.
738
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
739
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
740
+ config.n_positions - 1]`.
741
+
742
+ [What are position IDs?](../glossary#position-ids)
743
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
744
+ when `config.use_cache=True`):
745
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
746
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
747
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
748
+
749
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
750
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
751
+
752
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
753
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
754
+ of shape `(batch_size, sequence_length)`.
755
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
756
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
757
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
758
+ model's internal embedding lookup matrix.
759
+ use_cache (`bool`, *optional*):
760
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
761
+ `past_key_values`).
762
+ output_attentions (`bool`, *optional*):
763
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
764
+ tensors for more detail.
765
+ output_hidden_states (`bool`, *optional*):
766
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
767
+ more detail.
768
+ return_dict (`bool`, *optional*):
769
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
770
+ """
771
+
772
+
773
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
774
+ @add_start_docstrings(
775
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
776
+ InternLM2_START_DOCSTRING,
777
+ )
778
+ class InternLM2Model(InternLM2PreTrainedModel):
779
+ """
780
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
781
+
782
+ Args:
783
+ config: InternLM2Config
784
+ """
785
+
786
+ _auto_class = "AutoModel"
787
+
788
+ def __init__(self, config: InternLM2Config):
789
+ super().__init__(config)
790
+ self.padding_idx = config.pad_token_id
791
+ self.vocab_size = config.vocab_size
792
+ self.config = config
793
+
794
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
795
+
796
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
797
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
798
+
799
+ self.gradient_checkpointing = False
800
+ # Initialize weights and apply final processing
801
+ self.post_init()
802
+
803
+ def get_input_embeddings(self):
804
+ return self.tok_embeddings
805
+
806
+ def set_input_embeddings(self, value):
807
+ self.tok_embeddings = value
808
+
809
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
810
+ # create causal mask
811
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
812
+ combined_attention_mask = None
813
+ if input_shape[-1] > 1:
814
+ combined_attention_mask = _make_causal_mask(
815
+ input_shape,
816
+ inputs_embeds.dtype,
817
+ device=inputs_embeds.device,
818
+ past_key_values_length=past_key_values_length,
819
+ )
820
+
821
+ if attention_mask is not None:
822
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
823
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
824
+ inputs_embeds.device
825
+ )
826
+ combined_attention_mask = (
827
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
828
+ )
829
+
830
+ return combined_attention_mask
831
+
832
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
833
+ def forward(
834
+ self,
835
+ input_ids: torch.LongTensor = None,
836
+ attention_mask: Optional[torch.Tensor] = None,
837
+ position_ids: Optional[torch.LongTensor] = None,
838
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
839
+ inputs_embeds: Optional[torch.FloatTensor] = None,
840
+ use_cache: Optional[bool] = None,
841
+ output_attentions: Optional[bool] = None,
842
+ output_hidden_states: Optional[bool] = None,
843
+ return_dict: Optional[bool] = None,
844
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
845
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
846
+ output_hidden_states = (
847
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
848
+ )
849
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
850
+
851
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
852
+
853
+ if self.config.attn_implementation == "flash_attention_2":
854
+ _import_flash_attn()
855
+
856
+ # retrieve input_ids and inputs_embeds
857
+ if input_ids is not None and inputs_embeds is not None:
858
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
859
+ elif input_ids is not None:
860
+ batch_size, seq_length = input_ids.shape[:2]
861
+ elif inputs_embeds is not None:
862
+ batch_size, seq_length = inputs_embeds.shape[:2]
863
+ else:
864
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
865
+
866
+ seq_length_with_past = seq_length
867
+ past_key_values_length = 0
868
+ if past_key_values is not None:
869
+ past_key_values_length = past_key_values[0][0].shape[2]
870
+ seq_length_with_past = seq_length_with_past + past_key_values_length
871
+
872
+ if position_ids is None:
873
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
874
+ position_ids = torch.arange(
875
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
876
+ )
877
+ position_ids = position_ids.unsqueeze(0)
878
+
879
+ if inputs_embeds is None:
880
+ inputs_embeds = self.tok_embeddings(input_ids)
881
+
882
+ if self.config.attn_implementation == "flash_attention_2":
883
+ # 2d mask is passed through the layers
884
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
885
+ else:
886
+ if attention_mask is None:
887
+ attention_mask = torch.ones(
888
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
889
+ )
890
+ attention_mask = self._prepare_decoder_attention_mask(
891
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
892
+ )
893
+
894
+ # embed positions
895
+ hidden_states = inputs_embeds
896
+
897
+ if self.gradient_checkpointing and self.training:
898
+ if use_cache:
899
+ logger.warning_once(
900
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
901
+ )
902
+ use_cache = False
903
+
904
+ # decoder layers
905
+ all_hidden_states = () if output_hidden_states else None
906
+ all_self_attns = () if output_attentions else None
907
+ next_decoder_cache = () if use_cache else None
908
+
909
+ for idx, decoder_layer in enumerate(self.layers):
910
+ if output_hidden_states:
911
+ all_hidden_states += (hidden_states,)
912
+
913
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
914
+
915
+ if self.gradient_checkpointing and self.training:
916
+
917
+ def create_custom_forward(module):
918
+ def custom_forward(*inputs):
919
+ # None for past_key_value
920
+ return module(*inputs, output_attentions, None)
921
+
922
+ return custom_forward
923
+
924
+ layer_outputs = torch.utils.checkpoint.checkpoint(
925
+ create_custom_forward(decoder_layer),
926
+ hidden_states,
927
+ attention_mask,
928
+ position_ids,
929
+ None,
930
+ )
931
+ else:
932
+ layer_outputs = decoder_layer(
933
+ hidden_states,
934
+ attention_mask=attention_mask,
935
+ position_ids=position_ids,
936
+ past_key_value=past_key_value,
937
+ output_attentions=output_attentions,
938
+ use_cache=use_cache,
939
+ )
940
+
941
+ hidden_states = layer_outputs[0]
942
+
943
+ if use_cache:
944
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
945
+
946
+ if output_attentions:
947
+ all_self_attns += (layer_outputs[1],)
948
+
949
+ hidden_states = self.norm(hidden_states)
950
+
951
+ # add hidden states from the last decoder layer
952
+ if output_hidden_states:
953
+ all_hidden_states += (hidden_states,)
954
+
955
+ next_cache = next_decoder_cache if use_cache else None
956
+ if not return_dict:
957
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
958
+ return BaseModelOutputWithPast(
959
+ last_hidden_state=hidden_states,
960
+ past_key_values=next_cache,
961
+ hidden_states=all_hidden_states,
962
+ attentions=all_self_attns,
963
+ )
964
+
965
+
966
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
967
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
968
+ _auto_class = "AutoModelForCausalLM"
969
+
970
+ _tied_weights_keys = ["output.weight"]
971
+
972
+ def __init__(self, config):
973
+ super().__init__(config)
974
+ self.model = InternLM2Model(config)
975
+ self.vocab_size = config.vocab_size
976
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
977
+
978
+ # Initialize weights and apply final processing
979
+ self.post_init()
980
+
981
+ def get_input_embeddings(self):
982
+ return self.model.tok_embeddings
983
+
984
+ def set_input_embeddings(self, value):
985
+ self.model.tok_embeddings = value
986
+
987
+ def get_output_embeddings(self):
988
+ return self.output
989
+
990
+ def set_output_embeddings(self, new_embeddings):
991
+ self.output = new_embeddings
992
+
993
+ def set_decoder(self, decoder):
994
+ self.model = decoder
995
+
996
+ def get_decoder(self):
997
+ return self.model
998
+
999
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1000
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1001
+ def forward(
1002
+ self,
1003
+ input_ids: torch.LongTensor = None,
1004
+ attention_mask: Optional[torch.Tensor] = None,
1005
+ position_ids: Optional[torch.LongTensor] = None,
1006
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1007
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1008
+ labels: Optional[torch.LongTensor] = None,
1009
+ use_cache: Optional[bool] = None,
1010
+ output_attentions: Optional[bool] = None,
1011
+ output_hidden_states: Optional[bool] = None,
1012
+ return_dict: Optional[bool] = None,
1013
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1014
+ r"""
1015
+ Args:
1016
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1017
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1018
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1019
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1020
+
1021
+ Returns:
1022
+
1023
+ Example:
1024
+
1025
+ ```python
1026
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1027
+
1028
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1029
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1030
+
1031
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1032
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1033
+
1034
+ >>> # Generate
1035
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1036
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1037
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1038
+ ```"""
1039
+
1040
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1041
+ output_hidden_states = (
1042
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1043
+ )
1044
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1045
+
1046
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1047
+ outputs = self.model(
1048
+ input_ids=input_ids,
1049
+ attention_mask=attention_mask,
1050
+ position_ids=position_ids,
1051
+ past_key_values=past_key_values,
1052
+ inputs_embeds=inputs_embeds,
1053
+ use_cache=use_cache,
1054
+ output_attentions=output_attentions,
1055
+ output_hidden_states=output_hidden_states,
1056
+ return_dict=return_dict,
1057
+ )
1058
+
1059
+ hidden_states = outputs[0]
1060
+ logits = self.output(hidden_states)
1061
+ logits = logits.float()
1062
+
1063
+ loss = None
1064
+ if labels is not None:
1065
+ # Shift so that tokens < n predict n
1066
+ shift_logits = logits[..., :-1, :].contiguous()
1067
+ shift_labels = labels[..., 1:].contiguous()
1068
+ # Flatten the tokens
1069
+ loss_fct = CrossEntropyLoss()
1070
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1071
+ shift_labels = shift_labels.view(-1)
1072
+ # Enable model parallelism
1073
+ shift_labels = shift_labels.to(shift_logits.device)
1074
+ loss = loss_fct(shift_logits, shift_labels)
1075
+
1076
+ if not return_dict:
1077
+ output = (logits,) + outputs[1:]
1078
+ return (loss,) + output if loss is not None else output
1079
+
1080
+ return CausalLMOutputWithPast(
1081
+ loss=loss,
1082
+ logits=logits,
1083
+ past_key_values=outputs.past_key_values,
1084
+ hidden_states=outputs.hidden_states,
1085
+ attentions=outputs.attentions,
1086
+ )
1087
+
1088
+ def prepare_inputs_for_generation(
1089
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1090
+ ):
1091
+ if past_key_values is not None:
1092
+ past_length = past_key_values[0][0].shape[2]
1093
+
1094
+ # Some generation methods already pass only the last input ID
1095
+ if input_ids.shape[1] > past_length:
1096
+ remove_prefix_length = past_length
1097
+ else:
1098
+ # Default to old behavior: keep only final ID
1099
+ remove_prefix_length = input_ids.shape[1] - 1
1100
+
1101
+ input_ids = input_ids[:, remove_prefix_length:]
1102
+
1103
+ position_ids = kwargs.get("position_ids", None)
1104
+ if attention_mask is not None and position_ids is None:
1105
+ # create position_ids on the fly for batch generation
1106
+ position_ids = attention_mask.long().cumsum(-1) - 1
1107
+ position_ids.masked_fill_(attention_mask == 0, 1)
1108
+ if past_key_values:
1109
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1110
+
1111
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1112
+ if inputs_embeds is not None and past_key_values is None:
1113
+ model_inputs = {"inputs_embeds": inputs_embeds}
1114
+ else:
1115
+ model_inputs = {"input_ids": input_ids}
1116
+
1117
+ model_inputs.update(
1118
+ {
1119
+ "position_ids": position_ids,
1120
+ "past_key_values": past_key_values,
1121
+ "use_cache": kwargs.get("use_cache"),
1122
+ "attention_mask": attention_mask,
1123
+ }
1124
+ )
1125
+ return model_inputs
1126
+
1127
+ @staticmethod
1128
+ def _reorder_cache(past_key_values, beam_idx):
1129
+ reordered_past = ()
1130
+ for layer_past in past_key_values:
1131
+ reordered_past += (
1132
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1133
+ )
1134
+ return reordered_past
1135
+
1136
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1137
+ if tokenizer.add_bos_token:
1138
+ prompt = ""
1139
+ else:
1140
+ prompt = tokenizer.bos_token
1141
+ if meta_instruction:
1142
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1143
+ for record in history:
1144
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1145
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1146
+ return tokenizer([prompt], return_tensors="pt")
1147
+
1148
+ @torch.no_grad()
1149
+ def chat(
1150
+ self,
1151
+ tokenizer,
1152
+ query: str,
1153
+ history: List[Tuple[str, str]] = [],
1154
+ streamer: Optional[BaseStreamer] = None,
1155
+ max_new_tokens: int = 1024,
1156
+ do_sample: bool = True,
1157
+ temperature: float = 0.8,
1158
+ top_p: float = 0.8,
1159
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1160
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1161
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1162
+ **kwargs,
1163
+ ):
1164
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1165
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1166
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1167
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1168
+ outputs = self.generate(
1169
+ **inputs,
1170
+ streamer=streamer,
1171
+ max_new_tokens=max_new_tokens,
1172
+ do_sample=do_sample,
1173
+ temperature=temperature,
1174
+ top_p=top_p,
1175
+ eos_token_id=eos_token_id,
1176
+ **kwargs,
1177
+ )
1178
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1179
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1180
+ response = response.split("<|im_end|>")[0]
1181
+ history = history + [(query, response)]
1182
+ return response, history
1183
+
1184
+ @torch.no_grad()
1185
+ def stream_chat(
1186
+ self,
1187
+ tokenizer,
1188
+ query: str,
1189
+ history: List[Tuple[str, str]] = [],
1190
+ max_new_tokens: int = 1024,
1191
+ do_sample: bool = True,
1192
+ temperature: float = 0.8,
1193
+ top_p: float = 0.8,
1194
+ **kwargs,
1195
+ ):
1196
+ """
1197
+ Return a generator in format: (response, history)
1198
+ Eg.
1199
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1200
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1201
+ """
1202
+ if BaseStreamer is None:
1203
+ raise ModuleNotFoundError(
1204
+ "The version of `transformers` is too low. Please make sure "
1205
+ "that you have installed `transformers>=4.28.0`."
1206
+ )
1207
+
1208
+ response_queue = queue.Queue(maxsize=20)
1209
+
1210
+ class ChatStreamer(BaseStreamer):
1211
+ def __init__(self, tokenizer) -> None:
1212
+ super().__init__()
1213
+ self.tokenizer = tokenizer
1214
+ self.queue = response_queue
1215
+ self.query = query
1216
+ self.history = history
1217
+ self.response = ""
1218
+ self.cache = []
1219
+ self.received_inputs = False
1220
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1221
+
1222
+ def put(self, value):
1223
+ if len(value.shape) > 1 and value.shape[0] > 1:
1224
+ raise ValueError("ChatStreamer only supports batch size 1")
1225
+ elif len(value.shape) > 1:
1226
+ value = value[0]
1227
+
1228
+ if not self.received_inputs:
1229
+ # The first received value is input_ids, ignore here
1230
+ self.received_inputs = True
1231
+ return
1232
+
1233
+ self.cache.extend(value.tolist())
1234
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1235
+ if token.strip() != "<|im_end|>":
1236
+ self.response = self.response + token
1237
+ history = self.history + [(self.query, self.response)]
1238
+ self.queue.put((self.response, history))
1239
+ self.cache = []
1240
+ else:
1241
+ self.end()
1242
+
1243
+ def end(self):
1244
+ self.queue.put(None)
1245
+
1246
+ def stream_producer():
1247
+ return self.chat(
1248
+ tokenizer=tokenizer,
1249
+ query=query,
1250
+ streamer=ChatStreamer(tokenizer=tokenizer),
1251
+ history=history,
1252
+ max_new_tokens=max_new_tokens,
1253
+ do_sample=do_sample,
1254
+ temperature=temperature,
1255
+ top_p=top_p,
1256
+ **kwargs,
1257
+ )
1258
+
1259
+ def consumer():
1260
+ producer = threading.Thread(target=stream_producer)
1261
+ producer.start()
1262
+ while True:
1263
+ res = response_queue.get()
1264
+ if res is None:
1265
+ return
1266
+ yield res
1267
+
1268
+ return consumer()
1269
+
1270
+
1271
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1272
+ @add_start_docstrings(
1273
+ """
1274
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1275
+
1276
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1277
+ as other causal models (e.g. GPT-2) do.
1278
+
1279
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1280
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1281
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1282
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1283
+ each row of the batch).
1284
+ """,
1285
+ InternLM2_START_DOCSTRING,
1286
+ )
1287
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1288
+ def __init__(self, config):
1289
+ super().__init__(config)
1290
+ self.num_labels = config.num_labels
1291
+ self.model = InternLM2Model(config)
1292
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1293
+
1294
+ # Initialize weights and apply final processing
1295
+ self.post_init()
1296
+
1297
+ def get_input_embeddings(self):
1298
+ return self.model.tok_embeddings
1299
+
1300
+ def set_input_embeddings(self, value):
1301
+ self.model.tok_embeddings = value
1302
+
1303
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1304
+ def forward(
1305
+ self,
1306
+ input_ids: torch.LongTensor = None,
1307
+ attention_mask: Optional[torch.Tensor] = None,
1308
+ position_ids: Optional[torch.LongTensor] = None,
1309
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1310
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1311
+ labels: Optional[torch.LongTensor] = None,
1312
+ use_cache: Optional[bool] = None,
1313
+ output_attentions: Optional[bool] = None,
1314
+ output_hidden_states: Optional[bool] = None,
1315
+ return_dict: Optional[bool] = None,
1316
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1317
+ r"""
1318
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1319
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1320
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1321
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1322
+ """
1323
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1324
+
1325
+ transformer_outputs = self.model(
1326
+ input_ids,
1327
+ attention_mask=attention_mask,
1328
+ position_ids=position_ids,
1329
+ past_key_values=past_key_values,
1330
+ inputs_embeds=inputs_embeds,
1331
+ use_cache=use_cache,
1332
+ output_attentions=output_attentions,
1333
+ output_hidden_states=output_hidden_states,
1334
+ return_dict=return_dict,
1335
+ )
1336
+ hidden_states = transformer_outputs[0]
1337
+ logits = self.score(hidden_states)
1338
+
1339
+ if input_ids is not None:
1340
+ batch_size = input_ids.shape[0]
1341
+ else:
1342
+ batch_size = inputs_embeds.shape[0]
1343
+
1344
+ if self.config.pad_token_id is None and batch_size != 1:
1345
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1346
+ if self.config.pad_token_id is None:
1347
+ sequence_lengths = -1
1348
+ else:
1349
+ if input_ids is not None:
1350
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1351
+ logits.device
1352
+ )
1353
+ else:
1354
+ sequence_lengths = -1
1355
+
1356
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1357
+
1358
+ loss = None
1359
+ if labels is not None:
1360
+ labels = labels.to(logits.device)
1361
+ if self.config.problem_type is None:
1362
+ if self.num_labels == 1:
1363
+ self.config.problem_type = "regression"
1364
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1365
+ self.config.problem_type = "single_label_classification"
1366
+ else:
1367
+ self.config.problem_type = "multi_label_classification"
1368
+
1369
+ if self.config.problem_type == "regression":
1370
+ loss_fct = MSELoss()
1371
+ if self.num_labels == 1:
1372
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1373
+ else:
1374
+ loss = loss_fct(pooled_logits, labels)
1375
+ elif self.config.problem_type == "single_label_classification":
1376
+ loss_fct = CrossEntropyLoss()
1377
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1378
+ elif self.config.problem_type == "multi_label_classification":
1379
+ loss_fct = BCEWithLogitsLoss()
1380
+ loss = loss_fct(pooled_logits, labels)
1381
+ if not return_dict:
1382
+ output = (pooled_logits,) + transformer_outputs[1:]
1383
+ return ((loss,) + output) if loss is not None else output
1384
+
1385
+ return SequenceClassifierOutputWithPast(
1386
+ loss=loss,
1387
+ logits=pooled_logits,
1388
+ past_key_values=transformer_outputs.past_key_values,
1389
+ hidden_states=transformer_outputs.hidden_states,
1390
+ attentions=transformer_outputs.attentions,
1391
+ )
epoch1_ckpt/projector/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ProjectorModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_projector.ProjectorConfig",
7
+ "AutoModel": "modeling_projector.ProjectorModel"
8
+ },
9
+ "bias": true,
10
+ "depth": 2,
11
+ "hidden_act": "gelu",
12
+ "llm_hidden_size": 2048,
13
+ "model_type": "projector",
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.39.0.dev0",
16
+ "visual_hidden_size": 2176
17
+ }
epoch1_ckpt/projector/configuration_projector.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class ProjectorConfig(PretrainedConfig):
6
+ model_type = 'projector'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ visual_hidden_size=4096,
12
+ llm_hidden_size=4096,
13
+ depth=2,
14
+ hidden_act='gelu',
15
+ bias=True,
16
+ **kwargs,
17
+ ):
18
+ self.visual_hidden_size = visual_hidden_size
19
+ self.llm_hidden_size = llm_hidden_size
20
+ self.depth = depth
21
+ self.hidden_act = hidden_act
22
+ self.bias = bias
23
+ super().__init__(**kwargs)
epoch1_ckpt/projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5efd8615f6485bc1ac4274eddcb038f7e99d43dcbe2dff49685da212eba051c1
3
+ size 34619760
epoch1_ckpt/projector/modeling_projector.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_projector import ProjectorConfig
8
+
9
+
10
+ class ProjectorModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = ProjectorConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: ProjectorConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ modules = [
21
+ nn.Linear(
22
+ config.visual_hidden_size,
23
+ config.llm_hidden_size,
24
+ bias=config.bias)
25
+ ]
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.model.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, ProjectorModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def forward(self, x):
47
+ if self.gradient_checkpointing and self.training:
48
+ layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
49
+ else:
50
+ layer_outputs = self.model(x)
51
+ return layer_outputs
epoch1_ckpt/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "</s>",
5
+ "unk_token": "<unk>"
6
+ }
epoch1_ckpt/tokenization_internlm2.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization classes for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import sentencepiece as spm
24
+ from transformers.tokenization_utils import PreTrainedTokenizer
25
+ from transformers.utils import logging
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {}
32
+
33
+
34
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
+ """
37
+ Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
38
+
39
+ Args:
40
+ vocab_file (`str`):
41
+ Path to the vocabulary file.
42
+ """
43
+
44
+ vocab_files_names = VOCAB_FILES_NAMES
45
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
46
+ model_input_names = ["input_ids", "attention_mask"]
47
+ _auto_class = "AutoTokenizer"
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_file,
52
+ unk_token="<unk>",
53
+ bos_token="<s>",
54
+ eos_token="</s>",
55
+ pad_token="</s>",
56
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
57
+ add_bos_token=True,
58
+ add_eos_token=False,
59
+ decode_with_prefix_space=False,
60
+ clean_up_tokenization_spaces=False,
61
+ **kwargs,
62
+ ):
63
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
64
+ self.vocab_file = vocab_file
65
+ self.add_bos_token = add_bos_token
66
+ self.add_eos_token = add_eos_token
67
+ self.decode_with_prefix_space = decode_with_prefix_space
68
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
69
+ self.sp_model.Load(vocab_file)
70
+ self._no_prefix_space_tokens = None
71
+ super().__init__(
72
+ bos_token=bos_token,
73
+ eos_token=eos_token,
74
+ unk_token=unk_token,
75
+ pad_token=pad_token,
76
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
77
+ **kwargs,
78
+ )
79
+
80
+ @property
81
+ def no_prefix_space_tokens(self):
82
+ if self._no_prefix_space_tokens is None:
83
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
84
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
85
+ return self._no_prefix_space_tokens
86
+
87
+ @property
88
+ def vocab_size(self):
89
+ """Returns vocab size"""
90
+ return self.sp_model.get_piece_size()
91
+
92
+ @property
93
+ def bos_token_id(self) -> Optional[int]:
94
+ return self.sp_model.bos_id()
95
+
96
+ @property
97
+ def eos_token_id(self) -> Optional[int]:
98
+ return self.sp_model.eos_id()
99
+
100
+ def get_vocab(self):
101
+ """Returns vocab as a dict"""
102
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
103
+ vocab.update(self.added_tokens_encoder)
104
+ return vocab
105
+
106
+ def _tokenize(self, text):
107
+ """Returns a tokenized string."""
108
+ return self.sp_model.encode(text, out_type=str)
109
+
110
+ def _convert_token_to_id(self, token):
111
+ """Converts a token (str) in an id using the vocab."""
112
+ return self.sp_model.piece_to_id(token)
113
+
114
+ def _convert_id_to_token(self, index):
115
+ """Converts an index (integer) in a token (str) using the vocab."""
116
+ token = self.sp_model.IdToPiece(index)
117
+ return token
118
+
119
+ def _maybe_add_prefix_space(self, tokens, decoded):
120
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
121
+ return " " + decoded
122
+ else:
123
+ return decoded
124
+
125
+ def convert_tokens_to_string(self, tokens):
126
+ """Converts a sequence of tokens (string) in a single string."""
127
+ current_sub_tokens = []
128
+ out_string = ""
129
+ prev_is_special = False
130
+ for token in tokens:
131
+ # make sure that special tokens are not decoded using sentencepiece model
132
+ if token in self.all_special_tokens:
133
+ if not prev_is_special:
134
+ out_string += " "
135
+ out_string += self.sp_model.decode(current_sub_tokens) + token
136
+ prev_is_special = True
137
+ current_sub_tokens = []
138
+ else:
139
+ current_sub_tokens.append(token)
140
+ prev_is_special = False
141
+ out_string += self.sp_model.decode(current_sub_tokens)
142
+ out_string = self.clean_up_tokenization(out_string)
143
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
144
+ return out_string[1:]
145
+
146
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
147
+ """
148
+ Save the vocabulary and special tokens file to a directory.
149
+
150
+ Args:
151
+ save_directory (`str`):
152
+ The directory in which to save the vocabulary.
153
+
154
+ Returns:
155
+ `Tuple(str)`: Paths to the files saved.
156
+ """
157
+ if not os.path.isdir(save_directory):
158
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
159
+ return
160
+ out_vocab_file = os.path.join(
161
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
162
+ )
163
+
164
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
165
+ copyfile(self.vocab_file, out_vocab_file)
166
+ elif not os.path.isfile(self.vocab_file):
167
+ with open(out_vocab_file, "wb") as fi:
168
+ content_spiece_model = self.sp_model.serialized_model_proto()
169
+ fi.write(content_spiece_model)
170
+
171
+ return (out_vocab_file,)
172
+
173
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
174
+ if self.add_bos_token:
175
+ bos_token_ids = [self.bos_token_id]
176
+ else:
177
+ bos_token_ids = []
178
+
179
+ output = bos_token_ids + token_ids_0
180
+
181
+ if token_ids_1 is not None:
182
+ output = output + token_ids_1
183
+
184
+ if self.add_eos_token:
185
+ output = output + [self.eos_token_id]
186
+
187
+ return output
188
+
189
+ def get_special_tokens_mask(
190
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
191
+ ) -> List[int]:
192
+ """
193
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
194
+ special tokens using the tokenizer `prepare_for_model` method.
195
+
196
+ Args:
197
+ token_ids_0 (`List[int]`):
198
+ List of IDs.
199
+ token_ids_1 (`List[int]`, *optional*):
200
+ Optional second list of IDs for sequence pairs.
201
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
202
+ Whether or not the token list is already formatted with special tokens for the model.
203
+
204
+ Returns:
205
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
206
+ """
207
+ if already_has_special_tokens:
208
+ return super().get_special_tokens_mask(
209
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
210
+ )
211
+
212
+ if token_ids_1 is None:
213
+ return [1] + ([0] * len(token_ids_0)) + [1]
214
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
215
+
216
+ def create_token_type_ids_from_sequences(
217
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
218
+ ) -> List[int]:
219
+ """
220
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
221
+ use of token type ids, therefore a list of zeros is returned.
222
+
223
+ Args:
224
+ token_ids_0 (`List[int]`):
225
+ List of IDs.
226
+ token_ids_1 (`List[int]`, *optional*):
227
+ Optional second list of IDs for sequence pairs.
228
+
229
+ Returns:
230
+ `List[int]`: List of zeros.
231
+ """
232
+ eos = [self.eos_token_id]
233
+
234
+ if token_ids_1 is None:
235
+ return len(token_ids_0 + eos) * [0]
236
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
epoch1_ckpt/tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization Fast class for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, Optional, Tuple
22
+
23
+ from tokenizers import processors, decoders, Tokenizer, normalizers
24
+ from tokenizers.models import BPE
25
+
26
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
27
+ from transformers.utils import logging
28
+
29
+ from transformers.convert_slow_tokenizer import (
30
+ SLOW_TO_FAST_CONVERTERS,
31
+ SpmConverter,
32
+ SentencePieceExtractor,
33
+ )
34
+
35
+ from .tokenization_internlm2 import InternLM2Tokenizer
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
40
+
41
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
42
+ class InternLM2Converter(SpmConverter):
43
+ handle_byte_fallback = True
44
+
45
+ def vocab(self, proto):
46
+ vocab = [
47
+ ("<unk>", 0.0),
48
+ ("<s>", 0.0),
49
+ ("</s>", 0.0),
50
+ ]
51
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
52
+ return vocab
53
+
54
+ def unk_id(self, proto):
55
+ unk_id = 0
56
+ return unk_id
57
+
58
+ def decoder(self, replacement, add_prefix_space):
59
+ decoders_sequence = [
60
+ decoders.Replace("▁", " "),
61
+ decoders.ByteFallback(),
62
+ decoders.Fuse(),
63
+ ]
64
+ if self.proto.normalizer_spec.add_dummy_prefix:
65
+ decoders_sequence.append(decoders.Strip(content=" ", left=1))
66
+ return decoders.Sequence(decoders_sequence)
67
+
68
+ def tokenizer(self, proto):
69
+ model_type = proto.trainer_spec.model_type
70
+ vocab_scores = self.vocab(proto)
71
+ # special tokens
72
+ added_tokens = self.original_tokenizer.added_tokens_decoder
73
+ for i in range(len(vocab_scores)):
74
+ piece, score = vocab_scores[i]
75
+ if i in added_tokens:
76
+ vocab_scores[i] = (added_tokens[i].content, score)
77
+ if model_type == 1:
78
+ raise RuntimeError("InternLM2 is supposed to be a BPE model!")
79
+
80
+ elif model_type == 2:
81
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
82
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
83
+ tokenizer = Tokenizer(
84
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
85
+ )
86
+ tokenizer.add_special_tokens(
87
+ [ added_token for index, added_token in added_tokens.items()]
88
+ )
89
+ else:
90
+ raise Exception(
91
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
92
+ )
93
+
94
+ return tokenizer
95
+
96
+ def normalizer(self, proto):
97
+ normalizers_list = []
98
+ if proto.normalizer_spec.add_dummy_prefix:
99
+ normalizers_list.append(normalizers.Prepend(prepend="▁"))
100
+ normalizers_list.append(normalizers.Replace(pattern=" ", content="▁"))
101
+ return normalizers.Sequence(normalizers_list)
102
+
103
+ def pre_tokenizer(self, replacement, add_prefix_space):
104
+ return None
105
+
106
+ SLOW_TO_FAST_CONVERTERS["InternLM2Tokenizer"] = InternLM2Converter
107
+
108
+
109
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
110
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
111
+ vocab_files_names = VOCAB_FILES_NAMES
112
+ slow_tokenizer_class = InternLM2Tokenizer
113
+ padding_side = "left"
114
+ model_input_names = ["input_ids", "attention_mask"]
115
+ _auto_class = "AutoTokenizer"
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_file,
120
+ unk_token="<unk>",
121
+ bos_token="<s>",
122
+ eos_token="</s>",
123
+ pad_token="</s>",
124
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
125
+ add_bos_token=True,
126
+ add_eos_token=False,
127
+ decode_with_prefix_space=False,
128
+ clean_up_tokenization_spaces=False,
129
+ **kwargs,
130
+ ):
131
+ super().__init__(
132
+ vocab_file=vocab_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ sp_model_kwargs=sp_model_kwargs,
138
+ add_bos_token=add_bos_token,
139
+ add_eos_token=add_eos_token,
140
+ decode_with_prefix_space=decode_with_prefix_space,
141
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
142
+ **kwargs,
143
+ )
144
+ self._add_bos_token = add_bos_token
145
+ self._add_eos_token = add_eos_token
146
+ self.update_post_processor()
147
+ self.vocab_file = vocab_file
148
+
149
+ @property
150
+ def can_save_slow_tokenizer(self) -> bool:
151
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
152
+
153
+ def update_post_processor(self):
154
+ """
155
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
156
+ """
157
+ bos = self.bos_token
158
+ bos_token_id = self.bos_token_id
159
+ if bos is None and self.add_bos_token:
160
+ raise ValueError("add_bos_token = True but bos_token = None")
161
+
162
+ eos = self.eos_token
163
+ eos_token_id = self.eos_token_id
164
+ if eos is None and self.add_eos_token:
165
+ raise ValueError("add_eos_token = True but eos_token = None")
166
+
167
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
168
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
169
+
170
+ special_tokens = []
171
+ if self.add_bos_token:
172
+ special_tokens.append((bos, bos_token_id))
173
+ if self.add_eos_token:
174
+ special_tokens.append((eos, eos_token_id))
175
+ self._tokenizer.post_processor = processors.TemplateProcessing(
176
+ single=single, pair=pair, special_tokens=special_tokens
177
+ )
178
+
179
+ @property
180
+ def add_eos_token(self):
181
+ return self._add_eos_token
182
+
183
+ @property
184
+ def add_bos_token(self):
185
+ return self._add_bos_token
186
+
187
+ @add_eos_token.setter
188
+ def add_eos_token(self, value):
189
+ self._add_eos_token = value
190
+ self.update_post_processor()
191
+
192
+ @add_bos_token.setter
193
+ def add_bos_token(self, value):
194
+ self._add_bos_token = value
195
+ self.update_post_processor()
196
+
197
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
198
+ if not self.can_save_slow_tokenizer:
199
+ raise ValueError(
200
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
201
+ "tokenizer."
202
+ )
203
+
204
+ if not os.path.isdir(save_directory):
205
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
206
+ return
207
+ out_vocab_file = os.path.join(
208
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
209
+ )
210
+
211
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
212
+ copyfile(self.vocab_file, out_vocab_file)
213
+
214
+ return (out_vocab_file,)
epoch1_ckpt/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
epoch1_ckpt/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
3
+ size 1477754
epoch1_ckpt/tokenizer_config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "auto_map": {
31
+ "AutoTokenizer": [
32
+ "tokenization_internlm2.InternLM2Tokenizer",
33
+ "tokenization_internlm2_fast.InternLM2TokenizerFast"
34
+ ]
35
+ },
36
+ "bos_token": "<s>",
37
+ "clean_up_tokenization_spaces": false,
38
+ "decode_with_prefix_space": false,
39
+ "eos_token": "</s>",
40
+ "model_max_length": 1000000000000000019884624838656,
41
+ "pad_token": "</s>",
42
+ "padding_side": "right",
43
+ "sp_model_kwargs": null,
44
+ "tokenizer_class": "InternLM2Tokenizer",
45
+ "unk_token": "<unk>"
46
+ }
epoch1_ckpt/xtuner_config.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SYSTEM = ''
2
+ accumulative_counts = 2
3
+ batch_size = 8
4
+ betas = (
5
+ 0.9,
6
+ 0.999,
7
+ )
8
+ custom_hooks = [
9
+ dict(
10
+ tokenizer=dict(
11
+ padding_side='right',
12
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
13
+ trust_remote_code=True,
14
+ type='transformers.AutoTokenizer.from_pretrained'),
15
+ type='xtuner.engine.hooks.DatasetInfoHook'),
16
+ dict(
17
+ evaluation_images='https://llava-vl.github.io/static/images/view.jpg',
18
+ evaluation_inputs=[
19
+ '请描述一下这张照片',
20
+ 'Please describe this picture',
21
+ ],
22
+ every_n_iters=500,
23
+ image_processor=dict(
24
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
25
+ trust_remote_code=True,
26
+ type='transformers.SiglipImageProcessor.from_pretrained'),
27
+ prompt_template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
28
+ system='',
29
+ tokenizer=dict(
30
+ padding_side='right',
31
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
32
+ trust_remote_code=True,
33
+ type='transformers.AutoTokenizer.from_pretrained'),
34
+ type='xtuner.engine.hooks.EvaluateChatHook'),
35
+ ]
36
+ data_path = './llava_data/llava_v1_5_lrv_mix1008k.json'
37
+ data_root = './llava_data/'
38
+ dataloader_num_workers = 4
39
+ default_hooks = dict(
40
+ checkpoint=dict(
41
+ by_epoch=False,
42
+ interval=500,
43
+ max_keep_ckpts=2,
44
+ type='mmengine.hooks.CheckpointHook'),
45
+ logger=dict(
46
+ interval=10,
47
+ log_metric_by_epoch=False,
48
+ type='mmengine.hooks.LoggerHook'),
49
+ param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
50
+ sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
51
+ timer=dict(type='mmengine.hooks.IterTimerHook'))
52
+ dino_path = 'facebook/dinov2-large'
53
+ env_cfg = dict(
54
+ cudnn_benchmark=False,
55
+ dist_cfg=dict(backend='nccl'),
56
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
57
+ evaluation_freq = 500
58
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
59
+ evaluation_inputs = [
60
+ '请描述一下这张照片',
61
+ 'Please describe this picture',
62
+ ]
63
+ image_folder = './llava_data/llava_images'
64
+ image_processor = dict(
65
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
66
+ trust_remote_code=True,
67
+ type='transformers.SiglipImageProcessor.from_pretrained')
68
+ image_processor_path = 'google/siglip-so400m-patch14-384'
69
+ launcher = 'pytorch'
70
+ llava_dataset = dict(
71
+ data_path='./llava_data/llava_v1_5_lrv_mix1008k.json',
72
+ dataset_map_fn='xtuner.dataset.map_fns.llava_map_fn',
73
+ image_folder='./llava_data/llava_images',
74
+ image_processor=dict(
75
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
76
+ trust_remote_code=True,
77
+ type='transformers.SiglipImageProcessor.from_pretrained'),
78
+ max_length=1472,
79
+ pad_image_to_square=False,
80
+ template_map_fn=dict(
81
+ template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
82
+ type='xtuner.dataset.map_fns.template_map_fn_factory'),
83
+ tokenizer=dict(
84
+ padding_side='right',
85
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
86
+ trust_remote_code=True,
87
+ type='transformers.AutoTokenizer.from_pretrained'),
88
+ type='xtuner.dataset.LLaVADataset')
89
+ llm_name_or_path = 'internlm/internlm2-1_8b'
90
+ load_from = None
91
+ log_level = 'INFO'
92
+ log_processor = dict(by_epoch=False)
93
+ lr = 2e-05
94
+ max_epochs = 2
95
+ max_length = 1472
96
+ max_norm = 1
97
+ model = dict(
98
+ dino=dict(
99
+ pretrained_model_name_or_path='facebook/dinov2-large',
100
+ type='transformers.Dinov2Model.from_pretrained'),
101
+ freeze_llm=False,
102
+ freeze_visual_encoder=True,
103
+ llm=dict(
104
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
105
+ quantization_config=dict(
106
+ bnb_4bit_compute_dtype='torch.float16',
107
+ bnb_4bit_quant_type='nf4',
108
+ bnb_4bit_use_double_quant=True,
109
+ llm_int8_has_fp16_weight=False,
110
+ llm_int8_threshold=6.0,
111
+ load_in_4bit=True,
112
+ load_in_8bit=False,
113
+ type='transformers.BitsAndBytesConfig'),
114
+ torch_dtype='torch.float16',
115
+ trust_remote_code=True,
116
+ type='transformers.AutoModelForCausalLM.from_pretrained'),
117
+ siglip=dict(
118
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
119
+ type='transformers.SiglipVisionModel.from_pretrained'),
120
+ type='xtuner.model.LLaVAModel')
121
+ optim_type = 'torch.optim.AdamW'
122
+ optim_wrapper = dict(
123
+ optimizer=dict(
124
+ betas=(
125
+ 0.9,
126
+ 0.999,
127
+ ),
128
+ lr=2e-05,
129
+ type='torch.optim.AdamW',
130
+ weight_decay=0.1),
131
+ type='DeepSpeedOptimWrapper')
132
+ param_scheduler = [
133
+ dict(
134
+ begin=0,
135
+ by_epoch=True,
136
+ convert_to_iter_based=True,
137
+ end=0.06,
138
+ start_factor=1e-05,
139
+ type='mmengine.optim.LinearLR'),
140
+ dict(
141
+ begin=0.06,
142
+ by_epoch=True,
143
+ convert_to_iter_based=True,
144
+ end=2,
145
+ eta_min=0.0,
146
+ type='mmengine.optim.CosineAnnealingLR'),
147
+ ]
148
+ prompt_template = 'xtuner.utils.PROMPT_TEMPLATE.internlm2_chat'
149
+ randomness = dict(deterministic=False, seed=None)
150
+ resume = False
151
+ runner_type = 'FlexibleRunner'
152
+ save_steps = 500
153
+ save_total_limit = 2
154
+ siglip_path = 'google/siglip-so400m-patch14-384'
155
+ strategy = dict(
156
+ config=dict(
157
+ bf16=dict(enabled=True),
158
+ fp16=dict(enabled=False, initial_scale_power=16),
159
+ gradient_accumulation_steps='auto',
160
+ gradient_clipping='auto',
161
+ train_micro_batch_size_per_gpu='auto',
162
+ zero_allow_untested_optimizer=True,
163
+ zero_force_ds_cpu_optimizer=False,
164
+ zero_optimization=dict(overlap_comm=True, stage=2)),
165
+ exclude_frozen_parameters=True,
166
+ gradient_accumulation_steps=2,
167
+ gradient_clipping=1,
168
+ train_micro_batch_size_per_gpu=8,
169
+ type='xtuner.engine.DeepSpeedStrategy')
170
+ tokenizer = dict(
171
+ padding_side='right',
172
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
173
+ trust_remote_code=True,
174
+ type='transformers.AutoTokenizer.from_pretrained')
175
+ train_cfg = dict(max_epochs=2, type='xtuner.engine.runner.TrainLoop')
176
+ train_dataloader = dict(
177
+ batch_size=8,
178
+ collate_fn=dict(type='xtuner.dataset.collate_fns.default_collate_fn'),
179
+ dataset=dict(
180
+ data_path='./llava_data/llava_v1_5_lrv_mix1008k.json',
181
+ dataset_map_fn='xtuner.dataset.map_fns.llava_map_fn',
182
+ image_folder='./llava_data/llava_images',
183
+ image_processor=dict(
184
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
185
+ trust_remote_code=True,
186
+ type='transformers.SiglipImageProcessor.from_pretrained'),
187
+ max_length=1472,
188
+ pad_image_to_square=False,
189
+ template_map_fn=dict(
190
+ template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
191
+ type='xtuner.dataset.map_fns.template_map_fn_factory'),
192
+ tokenizer=dict(
193
+ padding_side='right',
194
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
195
+ trust_remote_code=True,
196
+ type='transformers.AutoTokenizer.from_pretrained'),
197
+ type='xtuner.dataset.LLaVADataset'),
198
+ num_workers=4,
199
+ sampler=dict(shuffle=True, type='mmengine.dataset.DefaultSampler'))
200
+ visualizer = dict(
201
+ type='mmengine.visualization.Visualizer',
202
+ vis_backends=[
203
+ dict(type='mmengine.visualization.TensorboardVisBackend'),
204
+ ])
205
+ warmup_ratio = 0.03
206
+ weight_decay = 0.1
207
+ work_dir = './work_dirs/train_config'
epoch2_ckpt/config.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "internlm/internlm2-1_8b",
3
+ "architectures": [
4
+ "InternLM2ForCausalLM"
5
+ ],
6
+ "attn_implementation": "eager",
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
9
+ "AutoModel": "internlm/internlm2-1_8b--modeling_internlm2.InternLM2ForCausalLM",
10
+ "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM"
11
+ },
12
+ "bias": false,
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2048,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 8192,
19
+ "max_position_embeddings": 32768,
20
+ "model_type": "internlm2",
21
+ "num_attention_heads": 16,
22
+ "num_hidden_layers": 24,
23
+ "num_key_value_heads": 8,
24
+ "pad_token_id": 2,
25
+ "quantization_config": {
26
+ "_load_in_4bit": true,
27
+ "_load_in_8bit": false,
28
+ "bnb_4bit_compute_dtype": "float16",
29
+ "bnb_4bit_quant_type": "nf4",
30
+ "bnb_4bit_use_double_quant": true,
31
+ "llm_int8_enable_fp32_cpu_offload": false,
32
+ "llm_int8_has_fp16_weight": false,
33
+ "llm_int8_skip_modules": null,
34
+ "llm_int8_threshold": 6.0,
35
+ "load_in_4bit": true,
36
+ "load_in_8bit": false,
37
+ "quant_method": "bitsandbytes"
38
+ },
39
+ "rms_norm_eps": 1e-05,
40
+ "rope_scaling": null,
41
+ "rope_theta": 1000000,
42
+ "tie_word_embeddings": false,
43
+ "torch_dtype": "float16",
44
+ "transformers_version": "4.39.0.dev0",
45
+ "use_cache": false,
46
+ "vocab_size": 92544
47
+ }
epoch2_ckpt/configuration_internlm2.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
66
+ The epsilon used by the rms normalization layers.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
69
+ relevant if `config.is_decoder=True`.
70
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
71
+ Whether to tie weight embeddings
72
+ Example:
73
+
74
+ """
75
+ model_type = "internlm2"
76
+ _auto_class = "AutoConfig"
77
+
78
+ def __init__( # pylint: disable=W0102
79
+ self,
80
+ vocab_size=103168,
81
+ hidden_size=4096,
82
+ intermediate_size=11008,
83
+ num_hidden_layers=32,
84
+ num_attention_heads=32,
85
+ num_key_value_heads=None,
86
+ hidden_act="silu",
87
+ max_position_embeddings=2048,
88
+ initializer_range=0.02,
89
+ rms_norm_eps=1e-6,
90
+ use_cache=True,
91
+ pad_token_id=0,
92
+ bos_token_id=1,
93
+ eos_token_id=2,
94
+ tie_word_embeddings=False,
95
+ bias=True,
96
+ rope_theta=10000,
97
+ rope_scaling=None,
98
+ attn_implementation="eager",
99
+ **kwargs,
100
+ ):
101
+ self.vocab_size = vocab_size
102
+ self.max_position_embeddings = max_position_embeddings
103
+ self.hidden_size = hidden_size
104
+ self.intermediate_size = intermediate_size
105
+ self.num_hidden_layers = num_hidden_layers
106
+ self.num_attention_heads = num_attention_heads
107
+ self.bias = bias
108
+
109
+ if num_key_value_heads is None:
110
+ num_key_value_heads = num_attention_heads
111
+ self.num_key_value_heads = num_key_value_heads
112
+
113
+ self.hidden_act = hidden_act
114
+ self.initializer_range = initializer_range
115
+ self.rms_norm_eps = rms_norm_eps
116
+ self.use_cache = use_cache
117
+ self.rope_theta = rope_theta
118
+ self.rope_scaling = rope_scaling
119
+ self._rope_scaling_validation()
120
+
121
+ self.attn_implementation = attn_implementation
122
+ if self.attn_implementation is None:
123
+ self.attn_implementation = "eager"
124
+ super().__init__(
125
+ pad_token_id=pad_token_id,
126
+ bos_token_id=bos_token_id,
127
+ eos_token_id=eos_token_id,
128
+ tie_word_embeddings=tie_word_embeddings,
129
+ **kwargs,
130
+ )
131
+
132
+ def _rope_scaling_validation(self):
133
+ """
134
+ Validate the `rope_scaling` configuration.
135
+ """
136
+ if self.rope_scaling is None:
137
+ return
138
+
139
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
140
+ raise ValueError(
141
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
142
+ f"got {self.rope_scaling}"
143
+ )
144
+ rope_scaling_type = self.rope_scaling.get("type", None)
145
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
146
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
147
+ raise ValueError(
148
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
149
+ )
150
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
151
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
epoch2_ckpt/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "transformers_version": "4.39.0.dev0"
7
+ }
epoch2_ckpt/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ea0b161faf8abb64f439392028a33acc6615dba6785536f694a12b6fa498425
3
+ size 1537498688
epoch2_ckpt/modeling_internlm2.py ADDED
@@ -0,0 +1,1391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from einops import rearrange
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ SequenceClassifierOutputWithPast,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+
43
+ try:
44
+ from transformers.generation.streamers import BaseStreamer
45
+ except: # noqa # pylint: disable=bare-except
46
+ BaseStreamer = None
47
+
48
+ from .configuration_internlm2 import InternLM2Config
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "InternLM2Config"
53
+
54
+ flash_attn_func, flash_attn_varlen_func = None, None
55
+ pad_input, index_first_axis, unpad_input = None, None, None
56
+ def _import_flash_attn():
57
+ global flash_attn_func, flash_attn_varlen_func
58
+ global pad_input, index_first_axis, unpad_input
59
+ try:
60
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
61
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
62
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
63
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
64
+ except ImportError:
65
+ raise ImportError("flash_attn is not installed.")
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
+
80
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
81
+ def _make_causal_mask(
82
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
83
+ ):
84
+ """
85
+ Make causal mask used for bi-directional self-attention.
86
+ """
87
+ bsz, tgt_len = input_ids_shape
88
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
89
+ mask_cond = torch.arange(mask.size(-1), device=device)
90
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
91
+ mask = mask.to(dtype)
92
+
93
+ if past_key_values_length > 0:
94
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
95
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
96
+
97
+
98
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
99
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
100
+ """
101
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
102
+ """
103
+ bsz, src_len = mask.size()
104
+ tgt_len = tgt_len if tgt_len is not None else src_len
105
+
106
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
107
+
108
+ inverted_mask = 1.0 - expanded_mask
109
+
110
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
111
+
112
+
113
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
114
+ class InternLM2RMSNorm(nn.Module):
115
+ def __init__(self, hidden_size, eps=1e-6):
116
+ """
117
+ InternLM2RMSNorm is equivalent to T5LayerNorm
118
+ """
119
+ super().__init__()
120
+ self.weight = nn.Parameter(torch.ones(hidden_size))
121
+ self.variance_epsilon = eps
122
+
123
+ def forward(self, hidden_states):
124
+ input_dtype = hidden_states.dtype
125
+ hidden_states = hidden_states.to(torch.float32)
126
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
127
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
128
+ return self.weight * hidden_states.to(input_dtype)
129
+
130
+
131
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
132
+ class InternLM2RotaryEmbedding(nn.Module):
133
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
134
+ super().__init__()
135
+
136
+ self.dim = dim
137
+ self.max_position_embeddings = max_position_embeddings
138
+ self.base = base
139
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
140
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
141
+
142
+ # Build here to make `torch.jit.trace` work.
143
+ self._set_cos_sin_cache(
144
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
145
+ )
146
+
147
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
148
+ self.max_seq_len_cached = seq_len
149
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
150
+
151
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
152
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
153
+ emb = torch.cat((freqs, freqs), dim=-1)
154
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
155
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
156
+
157
+ def forward(self, x, seq_len=None):
158
+ # x: [bs, num_attention_heads, seq_len, head_size]
159
+ if seq_len > self.max_seq_len_cached:
160
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
161
+
162
+ return (
163
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
164
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
165
+ )
166
+
167
+
168
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
169
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
170
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
171
+
172
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
173
+ self.scaling_factor = scaling_factor
174
+ super().__init__(dim, max_position_embeddings, base, device)
175
+
176
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
177
+ self.max_seq_len_cached = seq_len
178
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
179
+ t = t / self.scaling_factor
180
+
181
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
182
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
183
+ emb = torch.cat((freqs, freqs), dim=-1)
184
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
185
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
186
+
187
+
188
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
189
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
190
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
191
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
192
+ """
193
+
194
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
195
+ self.scaling_factor = scaling_factor
196
+ super().__init__(dim, max_position_embeddings, base, device)
197
+
198
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
199
+ self.max_seq_len_cached = seq_len
200
+
201
+ if seq_len > self.max_position_embeddings:
202
+ base = self.base * (
203
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
204
+ ) ** (self.dim / (self.dim - 2))
205
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
206
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
207
+
208
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
209
+
210
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
211
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
212
+ emb = torch.cat((freqs, freqs), dim=-1)
213
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
214
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
215
+
216
+
217
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
218
+ def rotate_half(x):
219
+ """Rotates half the hidden dims of the input."""
220
+ x1 = x[..., : x.shape[-1] // 2]
221
+ x2 = x[..., x.shape[-1] // 2 :]
222
+ return torch.cat((-x2, x1), dim=-1)
223
+
224
+
225
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
226
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
227
+ """Applies Rotary Position Embedding to the query and key tensors."""
228
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
229
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
230
+ q_embed = (q * cos) + (rotate_half(q) * sin)
231
+ k_embed = (k * cos) + (rotate_half(k) * sin)
232
+ return q_embed, k_embed
233
+
234
+
235
+ class InternLM2MLP(nn.Module):
236
+ def __init__(self, config):
237
+ super().__init__()
238
+ self.config = config
239
+ self.hidden_size = config.hidden_size
240
+ self.intermediate_size = config.intermediate_size
241
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
242
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
243
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
244
+ self.act_fn = ACT2FN[config.hidden_act]
245
+
246
+ def forward(self, x):
247
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
248
+
249
+ return down_proj
250
+
251
+
252
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
253
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
254
+ """
255
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
256
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
257
+ """
258
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
259
+ if n_rep == 1:
260
+ return hidden_states
261
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
262
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
263
+
264
+
265
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
266
+ class InternLM2Attention(nn.Module):
267
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
268
+
269
+ def __init__(self, config: InternLM2Config):
270
+ super().__init__()
271
+ self.config = config
272
+ self.hidden_size = config.hidden_size
273
+ self.num_heads = config.num_attention_heads
274
+ self.head_dim = self.hidden_size // self.num_heads
275
+ self.num_key_value_heads = config.num_key_value_heads
276
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
277
+ self.max_position_embeddings = config.max_position_embeddings
278
+ self.is_causal = True
279
+
280
+ if (self.head_dim * self.num_heads) != self.hidden_size:
281
+ raise ValueError(
282
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
283
+ f" and `num_heads`: {self.num_heads})."
284
+ )
285
+
286
+ self.wqkv = nn.Linear(
287
+ self.hidden_size,
288
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
289
+ bias=config.bias,
290
+ )
291
+
292
+ self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
293
+ self._init_rope()
294
+
295
+ def _init_rope(self):
296
+ if self.config.rope_scaling is None:
297
+ self.rotary_emb = InternLM2RotaryEmbedding(
298
+ self.head_dim,
299
+ max_position_embeddings=self.max_position_embeddings,
300
+ base=self.config.rope_theta,
301
+ )
302
+ else:
303
+ scaling_type = self.config.rope_scaling["type"]
304
+ scaling_factor = self.config.rope_scaling["factor"]
305
+ if scaling_type == "dynamic":
306
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
307
+ self.head_dim,
308
+ max_position_embeddings=self.max_position_embeddings,
309
+ base=self.config.rope_theta,
310
+ scaling_factor=scaling_factor,
311
+ )
312
+ elif scaling_type == "linear":
313
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
314
+ self.head_dim,
315
+ max_position_embeddings=self.max_position_embeddings,
316
+ base=self.config.rope_theta,
317
+ scaling_factor=scaling_factor,
318
+ )
319
+ else:
320
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
321
+ return self.rotary_emb
322
+
323
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
324
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
325
+
326
+ def forward(
327
+ self,
328
+ hidden_states: torch.Tensor,
329
+ attention_mask: Optional[torch.Tensor] = None,
330
+ position_ids: Optional[torch.LongTensor] = None,
331
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
332
+ output_attentions: bool = False,
333
+ use_cache: bool = False,
334
+ **kwargs,
335
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
336
+ if "padding_mask" in kwargs:
337
+ warnings.warn(
338
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
339
+ "Please make sure use `attention_mask` instead.`"
340
+ )
341
+
342
+ bsz, q_len, _ = hidden_states.size()
343
+
344
+ qkv_states = self.wqkv(hidden_states)
345
+
346
+ qkv_states = rearrange(
347
+ qkv_states,
348
+ "b q (h gs d) -> b q h gs d",
349
+ gs=2 + self.num_key_value_groups,
350
+ d=self.head_dim,
351
+ )
352
+
353
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
354
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
355
+ key_states = qkv_states[..., -2, :]
356
+ value_states = qkv_states[..., -1, :]
357
+
358
+ query_states = query_states.transpose(1, 2)
359
+ key_states = key_states.transpose(1, 2)
360
+ value_states = value_states.transpose(1, 2)
361
+
362
+ kv_seq_len = key_states.shape[-2]
363
+ if past_key_value is not None:
364
+ kv_seq_len += past_key_value[0].shape[-2]
365
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
366
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
367
+
368
+ if past_key_value is not None:
369
+ # reuse k, v, self_attention
370
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
371
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
372
+
373
+ past_key_value = (key_states, value_states) if use_cache else None
374
+
375
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
376
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
377
+
378
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
379
+
380
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
381
+ raise ValueError(
382
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
383
+ f" {attn_weights.size()}"
384
+ )
385
+
386
+ if attention_mask is not None:
387
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
388
+ raise ValueError(
389
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
390
+ )
391
+ attn_weights = attn_weights + attention_mask
392
+
393
+ # upcast attention to fp32
394
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
395
+ attn_output = torch.matmul(attn_weights, value_states)
396
+
397
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
398
+ raise ValueError(
399
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
400
+ f" {attn_output.size()}"
401
+ )
402
+
403
+ attn_output = attn_output.transpose(1, 2).contiguous()
404
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
405
+
406
+ attn_output = self.wo(attn_output)
407
+
408
+ if not output_attentions:
409
+ attn_weights = None
410
+
411
+ return attn_output, attn_weights, past_key_value
412
+
413
+
414
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
415
+ class InternLM2FlashAttention2(InternLM2Attention):
416
+ """
417
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
418
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
419
+ flash attention and deal with padding tokens in case the input contains any of them.
420
+ """
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.LongTensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ **kwargs,
431
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
432
+ # InternLM2FlashAttention2 attention does not support output_attentions
433
+ if "padding_mask" in kwargs:
434
+ warnings.warn(
435
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
436
+ "Please make sure use `attention_mask` instead.`"
437
+ )
438
+
439
+ # overwrite attention_mask with padding_mask
440
+ attention_mask = kwargs.pop("padding_mask")
441
+
442
+ output_attentions = False
443
+
444
+ bsz, q_len, _ = hidden_states.size()
445
+
446
+ qkv_states = self.wqkv(hidden_states)
447
+
448
+ qkv_states = rearrange(
449
+ qkv_states,
450
+ "b q (h gs d) -> b q h gs d",
451
+ gs=2 + self.num_key_value_groups,
452
+ d=self.head_dim,
453
+ )
454
+
455
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
456
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
457
+ key_states = qkv_states[..., -2, :]
458
+ value_states = qkv_states[..., -1, :]
459
+
460
+ query_states = query_states.transpose(1, 2)
461
+ key_states = key_states.transpose(1, 2)
462
+ value_states = value_states.transpose(1, 2)
463
+
464
+ kv_seq_len = key_states.shape[-2]
465
+ if past_key_value is not None:
466
+ kv_seq_len += past_key_value[0].shape[-2]
467
+
468
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
469
+
470
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
471
+
472
+ if past_key_value is not None:
473
+ # reuse k, v, self_attention
474
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
475
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
476
+
477
+ past_key_value = (key_states, value_states) if use_cache else None
478
+
479
+ query_states = query_states.transpose(1, 2)
480
+ key_states = key_states.transpose(1, 2)
481
+ value_states = value_states.transpose(1, 2)
482
+
483
+ attn_output = self._flash_attention_forward(
484
+ query_states, key_states, value_states, attention_mask, q_len
485
+ )
486
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
487
+ attn_output = self.wo(attn_output)
488
+
489
+ if not output_attentions:
490
+ attn_weights = None
491
+
492
+ return attn_output, attn_weights, past_key_value
493
+
494
+ def _flash_attention_forward(
495
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
496
+ ):
497
+ """
498
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
499
+ first unpad the input, then computes the attention scores and pad the final attention scores.
500
+
501
+ Args:
502
+ query_states (`torch.Tensor`):
503
+ Input query states to be passed to Flash Attention API
504
+ key_states (`torch.Tensor`):
505
+ Input key states to be passed to Flash Attention API
506
+ value_states (`torch.Tensor`):
507
+ Input value states to be passed to Flash Attention API
508
+ attention_mask (`torch.Tensor`):
509
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
510
+ position of padding tokens and 1 for the position of non-padding tokens.
511
+ dropout (`int`, *optional*):
512
+ Attention dropout
513
+ softmax_scale (`float`, *optional*):
514
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
515
+ """
516
+ # Contains at least one padding token in the sequence
517
+ causal = self.is_causal and query_length != 1
518
+ if attention_mask is not None:
519
+ batch_size = query_states.shape[0]
520
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
521
+ query_states, key_states, value_states, attention_mask, query_length
522
+ )
523
+
524
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
525
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
526
+
527
+ attn_output_unpad = flash_attn_varlen_func(
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ cu_seqlens_q=cu_seqlens_q,
532
+ cu_seqlens_k=cu_seqlens_k,
533
+ max_seqlen_q=max_seqlen_in_batch_q,
534
+ max_seqlen_k=max_seqlen_in_batch_k,
535
+ dropout_p=dropout,
536
+ softmax_scale=softmax_scale,
537
+ causal=causal,
538
+ )
539
+
540
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
541
+ else:
542
+ attn_output = flash_attn_func(
543
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
544
+ )
545
+
546
+ return attn_output
547
+
548
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
549
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
550
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
551
+
552
+ key_layer = index_first_axis(
553
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
554
+ )
555
+ value_layer = index_first_axis(
556
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
557
+ )
558
+
559
+ if query_length == kv_seq_len:
560
+ query_layer = index_first_axis(
561
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
562
+ )
563
+ cu_seqlens_q = cu_seqlens_k
564
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
565
+ indices_q = indices_k
566
+ elif query_length == 1:
567
+ max_seqlen_in_batch_q = 1
568
+ cu_seqlens_q = torch.arange(
569
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
570
+ ) # There is a memcpy here, that is very bad.
571
+ indices_q = cu_seqlens_q[:-1]
572
+ query_layer = query_layer.squeeze(1)
573
+ else:
574
+ # The -q_len: slice assumes left padding.
575
+ attention_mask = attention_mask[:, -query_length:]
576
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
577
+
578
+ return (
579
+ query_layer,
580
+ key_layer,
581
+ value_layer,
582
+ indices_q.to(torch.int64),
583
+ (cu_seqlens_q, cu_seqlens_k),
584
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
585
+ )
586
+
587
+ INTERNLM2_ATTENTION_CLASSES = {
588
+ "eager": InternLM2Attention,
589
+ "flash_attention_2": InternLM2FlashAttention2,
590
+ }
591
+
592
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
593
+ class InternLM2DecoderLayer(nn.Module):
594
+ def __init__(self, config: InternLM2Config):
595
+ super().__init__()
596
+ self.hidden_size = config.hidden_size
597
+
598
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
599
+
600
+ self.feed_forward = InternLM2MLP(config)
601
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
603
+
604
+ def forward(
605
+ self,
606
+ hidden_states: torch.Tensor,
607
+ attention_mask: Optional[torch.Tensor] = None,
608
+ position_ids: Optional[torch.LongTensor] = None,
609
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
610
+ output_attentions: Optional[bool] = False,
611
+ use_cache: Optional[bool] = False,
612
+ **kwargs,
613
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
614
+ """
615
+ Args:
616
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
617
+ attention_mask (`torch.FloatTensor`, *optional*):
618
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
619
+ query_sequence_length, key_sequence_length)` if default attention is used.
620
+ output_attentions (`bool`, *optional*):
621
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
622
+ returned tensors for more detail.
623
+ use_cache (`bool`, *optional*):
624
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
625
+ (see `past_key_values`).
626
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
627
+ """
628
+ if "padding_mask" in kwargs:
629
+ warnings.warn(
630
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
631
+ "Please make sure use `attention_mask` instead.`"
632
+ )
633
+
634
+ residual = hidden_states
635
+
636
+ hidden_states = self.attention_norm(hidden_states)
637
+
638
+ # Self Attention
639
+ hidden_states, self_attn_weights, present_key_value = self.attention(
640
+ hidden_states=hidden_states,
641
+ attention_mask=attention_mask,
642
+ position_ids=position_ids,
643
+ past_key_value=past_key_value,
644
+ output_attentions=output_attentions,
645
+ use_cache=use_cache,
646
+ **kwargs,
647
+ )
648
+ hidden_states = residual + hidden_states
649
+
650
+ # Fully Connected
651
+ residual = hidden_states
652
+ hidden_states = self.ffn_norm(hidden_states)
653
+ hidden_states = self.feed_forward(hidden_states)
654
+ hidden_states = residual + hidden_states
655
+
656
+ outputs = (hidden_states,)
657
+
658
+ if output_attentions:
659
+ outputs += (self_attn_weights,)
660
+
661
+ if use_cache:
662
+ outputs += (present_key_value,)
663
+
664
+ return outputs
665
+
666
+
667
+ InternLM2_START_DOCSTRING = r"""
668
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
669
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
670
+ etc.)
671
+
672
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
673
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
674
+ and behavior.
675
+
676
+ Parameters:
677
+ config ([`InternLM2Config`]):
678
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
679
+ load the weights associated with the model, only the configuration. Check out the
680
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
681
+ """
682
+
683
+
684
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
685
+ @add_start_docstrings(
686
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
687
+ InternLM2_START_DOCSTRING,
688
+ )
689
+ class InternLM2PreTrainedModel(PreTrainedModel):
690
+ config_class = InternLM2Config
691
+ base_model_prefix = "model"
692
+ supports_gradient_checkpointing = True
693
+ _no_split_modules = ["InternLM2DecoderLayer"]
694
+ _skip_keys_device_placement = "past_key_values"
695
+
696
+ def _init_weights(self, module):
697
+ std = self.config.initializer_range
698
+ if isinstance(module, nn.Linear):
699
+ module.weight.data.normal_(mean=0.0, std=std)
700
+ if module.bias is not None:
701
+ module.bias.data.zero_()
702
+ elif isinstance(module, nn.Embedding):
703
+ module.weight.data.normal_(mean=0.0, std=std)
704
+ if module.padding_idx is not None:
705
+ module.weight.data[module.padding_idx].zero_()
706
+
707
+
708
+ InternLM2_INPUTS_DOCSTRING = r"""
709
+ Args:
710
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
711
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
712
+ it.
713
+
714
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
715
+ [`PreTrainedTokenizer.__call__`] for details.
716
+
717
+ [What are input IDs?](../glossary#input-ids)
718
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
719
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
720
+
721
+ - 1 for tokens that are **not masked**,
722
+ - 0 for tokens that are **masked**.
723
+
724
+ [What are attention masks?](../glossary#attention-mask)
725
+
726
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
727
+ [`PreTrainedTokenizer.__call__`] for details.
728
+
729
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
730
+ `past_key_values`).
731
+
732
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
733
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
734
+ information on the default strategy.
735
+
736
+ - 1 indicates the head is **not masked**,
737
+ - 0 indicates the head is **masked**.
738
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
739
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
740
+ config.n_positions - 1]`.
741
+
742
+ [What are position IDs?](../glossary#position-ids)
743
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
744
+ when `config.use_cache=True`):
745
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
746
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
747
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
748
+
749
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
750
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
751
+
752
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
753
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
754
+ of shape `(batch_size, sequence_length)`.
755
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
756
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
757
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
758
+ model's internal embedding lookup matrix.
759
+ use_cache (`bool`, *optional*):
760
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
761
+ `past_key_values`).
762
+ output_attentions (`bool`, *optional*):
763
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
764
+ tensors for more detail.
765
+ output_hidden_states (`bool`, *optional*):
766
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
767
+ more detail.
768
+ return_dict (`bool`, *optional*):
769
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
770
+ """
771
+
772
+
773
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
774
+ @add_start_docstrings(
775
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
776
+ InternLM2_START_DOCSTRING,
777
+ )
778
+ class InternLM2Model(InternLM2PreTrainedModel):
779
+ """
780
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
781
+
782
+ Args:
783
+ config: InternLM2Config
784
+ """
785
+
786
+ _auto_class = "AutoModel"
787
+
788
+ def __init__(self, config: InternLM2Config):
789
+ super().__init__(config)
790
+ self.padding_idx = config.pad_token_id
791
+ self.vocab_size = config.vocab_size
792
+ self.config = config
793
+
794
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
795
+
796
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
797
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
798
+
799
+ self.gradient_checkpointing = False
800
+ # Initialize weights and apply final processing
801
+ self.post_init()
802
+
803
+ def get_input_embeddings(self):
804
+ return self.tok_embeddings
805
+
806
+ def set_input_embeddings(self, value):
807
+ self.tok_embeddings = value
808
+
809
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
810
+ # create causal mask
811
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
812
+ combined_attention_mask = None
813
+ if input_shape[-1] > 1:
814
+ combined_attention_mask = _make_causal_mask(
815
+ input_shape,
816
+ inputs_embeds.dtype,
817
+ device=inputs_embeds.device,
818
+ past_key_values_length=past_key_values_length,
819
+ )
820
+
821
+ if attention_mask is not None:
822
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
823
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
824
+ inputs_embeds.device
825
+ )
826
+ combined_attention_mask = (
827
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
828
+ )
829
+
830
+ return combined_attention_mask
831
+
832
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
833
+ def forward(
834
+ self,
835
+ input_ids: torch.LongTensor = None,
836
+ attention_mask: Optional[torch.Tensor] = None,
837
+ position_ids: Optional[torch.LongTensor] = None,
838
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
839
+ inputs_embeds: Optional[torch.FloatTensor] = None,
840
+ use_cache: Optional[bool] = None,
841
+ output_attentions: Optional[bool] = None,
842
+ output_hidden_states: Optional[bool] = None,
843
+ return_dict: Optional[bool] = None,
844
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
845
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
846
+ output_hidden_states = (
847
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
848
+ )
849
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
850
+
851
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
852
+
853
+ if self.config.attn_implementation == "flash_attention_2":
854
+ _import_flash_attn()
855
+
856
+ # retrieve input_ids and inputs_embeds
857
+ if input_ids is not None and inputs_embeds is not None:
858
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
859
+ elif input_ids is not None:
860
+ batch_size, seq_length = input_ids.shape[:2]
861
+ elif inputs_embeds is not None:
862
+ batch_size, seq_length = inputs_embeds.shape[:2]
863
+ else:
864
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
865
+
866
+ seq_length_with_past = seq_length
867
+ past_key_values_length = 0
868
+ if past_key_values is not None:
869
+ past_key_values_length = past_key_values[0][0].shape[2]
870
+ seq_length_with_past = seq_length_with_past + past_key_values_length
871
+
872
+ if position_ids is None:
873
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
874
+ position_ids = torch.arange(
875
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
876
+ )
877
+ position_ids = position_ids.unsqueeze(0)
878
+
879
+ if inputs_embeds is None:
880
+ inputs_embeds = self.tok_embeddings(input_ids)
881
+
882
+ if self.config.attn_implementation == "flash_attention_2":
883
+ # 2d mask is passed through the layers
884
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
885
+ else:
886
+ if attention_mask is None:
887
+ attention_mask = torch.ones(
888
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
889
+ )
890
+ attention_mask = self._prepare_decoder_attention_mask(
891
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
892
+ )
893
+
894
+ # embed positions
895
+ hidden_states = inputs_embeds
896
+
897
+ if self.gradient_checkpointing and self.training:
898
+ if use_cache:
899
+ logger.warning_once(
900
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
901
+ )
902
+ use_cache = False
903
+
904
+ # decoder layers
905
+ all_hidden_states = () if output_hidden_states else None
906
+ all_self_attns = () if output_attentions else None
907
+ next_decoder_cache = () if use_cache else None
908
+
909
+ for idx, decoder_layer in enumerate(self.layers):
910
+ if output_hidden_states:
911
+ all_hidden_states += (hidden_states,)
912
+
913
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
914
+
915
+ if self.gradient_checkpointing and self.training:
916
+
917
+ def create_custom_forward(module):
918
+ def custom_forward(*inputs):
919
+ # None for past_key_value
920
+ return module(*inputs, output_attentions, None)
921
+
922
+ return custom_forward
923
+
924
+ layer_outputs = torch.utils.checkpoint.checkpoint(
925
+ create_custom_forward(decoder_layer),
926
+ hidden_states,
927
+ attention_mask,
928
+ position_ids,
929
+ None,
930
+ )
931
+ else:
932
+ layer_outputs = decoder_layer(
933
+ hidden_states,
934
+ attention_mask=attention_mask,
935
+ position_ids=position_ids,
936
+ past_key_value=past_key_value,
937
+ output_attentions=output_attentions,
938
+ use_cache=use_cache,
939
+ )
940
+
941
+ hidden_states = layer_outputs[0]
942
+
943
+ if use_cache:
944
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
945
+
946
+ if output_attentions:
947
+ all_self_attns += (layer_outputs[1],)
948
+
949
+ hidden_states = self.norm(hidden_states)
950
+
951
+ # add hidden states from the last decoder layer
952
+ if output_hidden_states:
953
+ all_hidden_states += (hidden_states,)
954
+
955
+ next_cache = next_decoder_cache if use_cache else None
956
+ if not return_dict:
957
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
958
+ return BaseModelOutputWithPast(
959
+ last_hidden_state=hidden_states,
960
+ past_key_values=next_cache,
961
+ hidden_states=all_hidden_states,
962
+ attentions=all_self_attns,
963
+ )
964
+
965
+
966
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
967
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
968
+ _auto_class = "AutoModelForCausalLM"
969
+
970
+ _tied_weights_keys = ["output.weight"]
971
+
972
+ def __init__(self, config):
973
+ super().__init__(config)
974
+ self.model = InternLM2Model(config)
975
+ self.vocab_size = config.vocab_size
976
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
977
+
978
+ # Initialize weights and apply final processing
979
+ self.post_init()
980
+
981
+ def get_input_embeddings(self):
982
+ return self.model.tok_embeddings
983
+
984
+ def set_input_embeddings(self, value):
985
+ self.model.tok_embeddings = value
986
+
987
+ def get_output_embeddings(self):
988
+ return self.output
989
+
990
+ def set_output_embeddings(self, new_embeddings):
991
+ self.output = new_embeddings
992
+
993
+ def set_decoder(self, decoder):
994
+ self.model = decoder
995
+
996
+ def get_decoder(self):
997
+ return self.model
998
+
999
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1000
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1001
+ def forward(
1002
+ self,
1003
+ input_ids: torch.LongTensor = None,
1004
+ attention_mask: Optional[torch.Tensor] = None,
1005
+ position_ids: Optional[torch.LongTensor] = None,
1006
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1007
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1008
+ labels: Optional[torch.LongTensor] = None,
1009
+ use_cache: Optional[bool] = None,
1010
+ output_attentions: Optional[bool] = None,
1011
+ output_hidden_states: Optional[bool] = None,
1012
+ return_dict: Optional[bool] = None,
1013
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1014
+ r"""
1015
+ Args:
1016
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1017
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1018
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1019
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1020
+
1021
+ Returns:
1022
+
1023
+ Example:
1024
+
1025
+ ```python
1026
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1027
+
1028
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1029
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1030
+
1031
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1032
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1033
+
1034
+ >>> # Generate
1035
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1036
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1037
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1038
+ ```"""
1039
+
1040
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1041
+ output_hidden_states = (
1042
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1043
+ )
1044
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1045
+
1046
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1047
+ outputs = self.model(
1048
+ input_ids=input_ids,
1049
+ attention_mask=attention_mask,
1050
+ position_ids=position_ids,
1051
+ past_key_values=past_key_values,
1052
+ inputs_embeds=inputs_embeds,
1053
+ use_cache=use_cache,
1054
+ output_attentions=output_attentions,
1055
+ output_hidden_states=output_hidden_states,
1056
+ return_dict=return_dict,
1057
+ )
1058
+
1059
+ hidden_states = outputs[0]
1060
+ logits = self.output(hidden_states)
1061
+ logits = logits.float()
1062
+
1063
+ loss = None
1064
+ if labels is not None:
1065
+ # Shift so that tokens < n predict n
1066
+ shift_logits = logits[..., :-1, :].contiguous()
1067
+ shift_labels = labels[..., 1:].contiguous()
1068
+ # Flatten the tokens
1069
+ loss_fct = CrossEntropyLoss()
1070
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1071
+ shift_labels = shift_labels.view(-1)
1072
+ # Enable model parallelism
1073
+ shift_labels = shift_labels.to(shift_logits.device)
1074
+ loss = loss_fct(shift_logits, shift_labels)
1075
+
1076
+ if not return_dict:
1077
+ output = (logits,) + outputs[1:]
1078
+ return (loss,) + output if loss is not None else output
1079
+
1080
+ return CausalLMOutputWithPast(
1081
+ loss=loss,
1082
+ logits=logits,
1083
+ past_key_values=outputs.past_key_values,
1084
+ hidden_states=outputs.hidden_states,
1085
+ attentions=outputs.attentions,
1086
+ )
1087
+
1088
+ def prepare_inputs_for_generation(
1089
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1090
+ ):
1091
+ if past_key_values is not None:
1092
+ past_length = past_key_values[0][0].shape[2]
1093
+
1094
+ # Some generation methods already pass only the last input ID
1095
+ if input_ids.shape[1] > past_length:
1096
+ remove_prefix_length = past_length
1097
+ else:
1098
+ # Default to old behavior: keep only final ID
1099
+ remove_prefix_length = input_ids.shape[1] - 1
1100
+
1101
+ input_ids = input_ids[:, remove_prefix_length:]
1102
+
1103
+ position_ids = kwargs.get("position_ids", None)
1104
+ if attention_mask is not None and position_ids is None:
1105
+ # create position_ids on the fly for batch generation
1106
+ position_ids = attention_mask.long().cumsum(-1) - 1
1107
+ position_ids.masked_fill_(attention_mask == 0, 1)
1108
+ if past_key_values:
1109
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1110
+
1111
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1112
+ if inputs_embeds is not None and past_key_values is None:
1113
+ model_inputs = {"inputs_embeds": inputs_embeds}
1114
+ else:
1115
+ model_inputs = {"input_ids": input_ids}
1116
+
1117
+ model_inputs.update(
1118
+ {
1119
+ "position_ids": position_ids,
1120
+ "past_key_values": past_key_values,
1121
+ "use_cache": kwargs.get("use_cache"),
1122
+ "attention_mask": attention_mask,
1123
+ }
1124
+ )
1125
+ return model_inputs
1126
+
1127
+ @staticmethod
1128
+ def _reorder_cache(past_key_values, beam_idx):
1129
+ reordered_past = ()
1130
+ for layer_past in past_key_values:
1131
+ reordered_past += (
1132
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1133
+ )
1134
+ return reordered_past
1135
+
1136
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1137
+ if tokenizer.add_bos_token:
1138
+ prompt = ""
1139
+ else:
1140
+ prompt = tokenizer.bos_token
1141
+ if meta_instruction:
1142
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1143
+ for record in history:
1144
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1145
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1146
+ return tokenizer([prompt], return_tensors="pt")
1147
+
1148
+ @torch.no_grad()
1149
+ def chat(
1150
+ self,
1151
+ tokenizer,
1152
+ query: str,
1153
+ history: List[Tuple[str, str]] = [],
1154
+ streamer: Optional[BaseStreamer] = None,
1155
+ max_new_tokens: int = 1024,
1156
+ do_sample: bool = True,
1157
+ temperature: float = 0.8,
1158
+ top_p: float = 0.8,
1159
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1160
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1161
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1162
+ **kwargs,
1163
+ ):
1164
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1165
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1166
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1167
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1168
+ outputs = self.generate(
1169
+ **inputs,
1170
+ streamer=streamer,
1171
+ max_new_tokens=max_new_tokens,
1172
+ do_sample=do_sample,
1173
+ temperature=temperature,
1174
+ top_p=top_p,
1175
+ eos_token_id=eos_token_id,
1176
+ **kwargs,
1177
+ )
1178
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1179
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1180
+ response = response.split("<|im_end|>")[0]
1181
+ history = history + [(query, response)]
1182
+ return response, history
1183
+
1184
+ @torch.no_grad()
1185
+ def stream_chat(
1186
+ self,
1187
+ tokenizer,
1188
+ query: str,
1189
+ history: List[Tuple[str, str]] = [],
1190
+ max_new_tokens: int = 1024,
1191
+ do_sample: bool = True,
1192
+ temperature: float = 0.8,
1193
+ top_p: float = 0.8,
1194
+ **kwargs,
1195
+ ):
1196
+ """
1197
+ Return a generator in format: (response, history)
1198
+ Eg.
1199
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1200
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1201
+ """
1202
+ if BaseStreamer is None:
1203
+ raise ModuleNotFoundError(
1204
+ "The version of `transformers` is too low. Please make sure "
1205
+ "that you have installed `transformers>=4.28.0`."
1206
+ )
1207
+
1208
+ response_queue = queue.Queue(maxsize=20)
1209
+
1210
+ class ChatStreamer(BaseStreamer):
1211
+ def __init__(self, tokenizer) -> None:
1212
+ super().__init__()
1213
+ self.tokenizer = tokenizer
1214
+ self.queue = response_queue
1215
+ self.query = query
1216
+ self.history = history
1217
+ self.response = ""
1218
+ self.cache = []
1219
+ self.received_inputs = False
1220
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1221
+
1222
+ def put(self, value):
1223
+ if len(value.shape) > 1 and value.shape[0] > 1:
1224
+ raise ValueError("ChatStreamer only supports batch size 1")
1225
+ elif len(value.shape) > 1:
1226
+ value = value[0]
1227
+
1228
+ if not self.received_inputs:
1229
+ # The first received value is input_ids, ignore here
1230
+ self.received_inputs = True
1231
+ return
1232
+
1233
+ self.cache.extend(value.tolist())
1234
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1235
+ if token.strip() != "<|im_end|>":
1236
+ self.response = self.response + token
1237
+ history = self.history + [(self.query, self.response)]
1238
+ self.queue.put((self.response, history))
1239
+ self.cache = []
1240
+ else:
1241
+ self.end()
1242
+
1243
+ def end(self):
1244
+ self.queue.put(None)
1245
+
1246
+ def stream_producer():
1247
+ return self.chat(
1248
+ tokenizer=tokenizer,
1249
+ query=query,
1250
+ streamer=ChatStreamer(tokenizer=tokenizer),
1251
+ history=history,
1252
+ max_new_tokens=max_new_tokens,
1253
+ do_sample=do_sample,
1254
+ temperature=temperature,
1255
+ top_p=top_p,
1256
+ **kwargs,
1257
+ )
1258
+
1259
+ def consumer():
1260
+ producer = threading.Thread(target=stream_producer)
1261
+ producer.start()
1262
+ while True:
1263
+ res = response_queue.get()
1264
+ if res is None:
1265
+ return
1266
+ yield res
1267
+
1268
+ return consumer()
1269
+
1270
+
1271
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1272
+ @add_start_docstrings(
1273
+ """
1274
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1275
+
1276
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1277
+ as other causal models (e.g. GPT-2) do.
1278
+
1279
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1280
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1281
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1282
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1283
+ each row of the batch).
1284
+ """,
1285
+ InternLM2_START_DOCSTRING,
1286
+ )
1287
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1288
+ def __init__(self, config):
1289
+ super().__init__(config)
1290
+ self.num_labels = config.num_labels
1291
+ self.model = InternLM2Model(config)
1292
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1293
+
1294
+ # Initialize weights and apply final processing
1295
+ self.post_init()
1296
+
1297
+ def get_input_embeddings(self):
1298
+ return self.model.tok_embeddings
1299
+
1300
+ def set_input_embeddings(self, value):
1301
+ self.model.tok_embeddings = value
1302
+
1303
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1304
+ def forward(
1305
+ self,
1306
+ input_ids: torch.LongTensor = None,
1307
+ attention_mask: Optional[torch.Tensor] = None,
1308
+ position_ids: Optional[torch.LongTensor] = None,
1309
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1310
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1311
+ labels: Optional[torch.LongTensor] = None,
1312
+ use_cache: Optional[bool] = None,
1313
+ output_attentions: Optional[bool] = None,
1314
+ output_hidden_states: Optional[bool] = None,
1315
+ return_dict: Optional[bool] = None,
1316
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1317
+ r"""
1318
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1319
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1320
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1321
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1322
+ """
1323
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1324
+
1325
+ transformer_outputs = self.model(
1326
+ input_ids,
1327
+ attention_mask=attention_mask,
1328
+ position_ids=position_ids,
1329
+ past_key_values=past_key_values,
1330
+ inputs_embeds=inputs_embeds,
1331
+ use_cache=use_cache,
1332
+ output_attentions=output_attentions,
1333
+ output_hidden_states=output_hidden_states,
1334
+ return_dict=return_dict,
1335
+ )
1336
+ hidden_states = transformer_outputs[0]
1337
+ logits = self.score(hidden_states)
1338
+
1339
+ if input_ids is not None:
1340
+ batch_size = input_ids.shape[0]
1341
+ else:
1342
+ batch_size = inputs_embeds.shape[0]
1343
+
1344
+ if self.config.pad_token_id is None and batch_size != 1:
1345
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1346
+ if self.config.pad_token_id is None:
1347
+ sequence_lengths = -1
1348
+ else:
1349
+ if input_ids is not None:
1350
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1351
+ logits.device
1352
+ )
1353
+ else:
1354
+ sequence_lengths = -1
1355
+
1356
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1357
+
1358
+ loss = None
1359
+ if labels is not None:
1360
+ labels = labels.to(logits.device)
1361
+ if self.config.problem_type is None:
1362
+ if self.num_labels == 1:
1363
+ self.config.problem_type = "regression"
1364
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1365
+ self.config.problem_type = "single_label_classification"
1366
+ else:
1367
+ self.config.problem_type = "multi_label_classification"
1368
+
1369
+ if self.config.problem_type == "regression":
1370
+ loss_fct = MSELoss()
1371
+ if self.num_labels == 1:
1372
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1373
+ else:
1374
+ loss = loss_fct(pooled_logits, labels)
1375
+ elif self.config.problem_type == "single_label_classification":
1376
+ loss_fct = CrossEntropyLoss()
1377
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1378
+ elif self.config.problem_type == "multi_label_classification":
1379
+ loss_fct = BCEWithLogitsLoss()
1380
+ loss = loss_fct(pooled_logits, labels)
1381
+ if not return_dict:
1382
+ output = (pooled_logits,) + transformer_outputs[1:]
1383
+ return ((loss,) + output) if loss is not None else output
1384
+
1385
+ return SequenceClassifierOutputWithPast(
1386
+ loss=loss,
1387
+ logits=pooled_logits,
1388
+ past_key_values=transformer_outputs.past_key_values,
1389
+ hidden_states=transformer_outputs.hidden_states,
1390
+ attentions=transformer_outputs.attentions,
1391
+ )
epoch2_ckpt/projector/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ProjectorModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_projector.ProjectorConfig",
7
+ "AutoModel": "modeling_projector.ProjectorModel"
8
+ },
9
+ "bias": true,
10
+ "depth": 2,
11
+ "hidden_act": "gelu",
12
+ "llm_hidden_size": 2048,
13
+ "model_type": "projector",
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.39.0.dev0",
16
+ "visual_hidden_size": 2176
17
+ }
epoch2_ckpt/projector/configuration_projector.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class ProjectorConfig(PretrainedConfig):
6
+ model_type = 'projector'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ visual_hidden_size=4096,
12
+ llm_hidden_size=4096,
13
+ depth=2,
14
+ hidden_act='gelu',
15
+ bias=True,
16
+ **kwargs,
17
+ ):
18
+ self.visual_hidden_size = visual_hidden_size
19
+ self.llm_hidden_size = llm_hidden_size
20
+ self.depth = depth
21
+ self.hidden_act = hidden_act
22
+ self.bias = bias
23
+ super().__init__(**kwargs)
epoch2_ckpt/projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ef282e8d1dbdcdecb9a30a7f69fea39f347d19381cc4c1af673c42a505071b5
3
+ size 34619760
epoch2_ckpt/projector/modeling_projector.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_projector import ProjectorConfig
8
+
9
+
10
+ class ProjectorModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = ProjectorConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: ProjectorConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ modules = [
21
+ nn.Linear(
22
+ config.visual_hidden_size,
23
+ config.llm_hidden_size,
24
+ bias=config.bias)
25
+ ]
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.model.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, ProjectorModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def forward(self, x):
47
+ if self.gradient_checkpointing and self.training:
48
+ layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
49
+ else:
50
+ layer_outputs = self.model(x)
51
+ return layer_outputs
epoch2_ckpt/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "</s>",
5
+ "unk_token": "<unk>"
6
+ }
epoch2_ckpt/tokenization_internlm2.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization classes for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import sentencepiece as spm
24
+ from transformers.tokenization_utils import PreTrainedTokenizer
25
+ from transformers.utils import logging
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {}
32
+
33
+
34
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
+ """
37
+ Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
38
+
39
+ Args:
40
+ vocab_file (`str`):
41
+ Path to the vocabulary file.
42
+ """
43
+
44
+ vocab_files_names = VOCAB_FILES_NAMES
45
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
46
+ model_input_names = ["input_ids", "attention_mask"]
47
+ _auto_class = "AutoTokenizer"
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_file,
52
+ unk_token="<unk>",
53
+ bos_token="<s>",
54
+ eos_token="</s>",
55
+ pad_token="</s>",
56
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
57
+ add_bos_token=True,
58
+ add_eos_token=False,
59
+ decode_with_prefix_space=False,
60
+ clean_up_tokenization_spaces=False,
61
+ **kwargs,
62
+ ):
63
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
64
+ self.vocab_file = vocab_file
65
+ self.add_bos_token = add_bos_token
66
+ self.add_eos_token = add_eos_token
67
+ self.decode_with_prefix_space = decode_with_prefix_space
68
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
69
+ self.sp_model.Load(vocab_file)
70
+ self._no_prefix_space_tokens = None
71
+ super().__init__(
72
+ bos_token=bos_token,
73
+ eos_token=eos_token,
74
+ unk_token=unk_token,
75
+ pad_token=pad_token,
76
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
77
+ **kwargs,
78
+ )
79
+
80
+ @property
81
+ def no_prefix_space_tokens(self):
82
+ if self._no_prefix_space_tokens is None:
83
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
84
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
85
+ return self._no_prefix_space_tokens
86
+
87
+ @property
88
+ def vocab_size(self):
89
+ """Returns vocab size"""
90
+ return self.sp_model.get_piece_size()
91
+
92
+ @property
93
+ def bos_token_id(self) -> Optional[int]:
94
+ return self.sp_model.bos_id()
95
+
96
+ @property
97
+ def eos_token_id(self) -> Optional[int]:
98
+ return self.sp_model.eos_id()
99
+
100
+ def get_vocab(self):
101
+ """Returns vocab as a dict"""
102
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
103
+ vocab.update(self.added_tokens_encoder)
104
+ return vocab
105
+
106
+ def _tokenize(self, text):
107
+ """Returns a tokenized string."""
108
+ return self.sp_model.encode(text, out_type=str)
109
+
110
+ def _convert_token_to_id(self, token):
111
+ """Converts a token (str) in an id using the vocab."""
112
+ return self.sp_model.piece_to_id(token)
113
+
114
+ def _convert_id_to_token(self, index):
115
+ """Converts an index (integer) in a token (str) using the vocab."""
116
+ token = self.sp_model.IdToPiece(index)
117
+ return token
118
+
119
+ def _maybe_add_prefix_space(self, tokens, decoded):
120
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
121
+ return " " + decoded
122
+ else:
123
+ return decoded
124
+
125
+ def convert_tokens_to_string(self, tokens):
126
+ """Converts a sequence of tokens (string) in a single string."""
127
+ current_sub_tokens = []
128
+ out_string = ""
129
+ prev_is_special = False
130
+ for token in tokens:
131
+ # make sure that special tokens are not decoded using sentencepiece model
132
+ if token in self.all_special_tokens:
133
+ if not prev_is_special:
134
+ out_string += " "
135
+ out_string += self.sp_model.decode(current_sub_tokens) + token
136
+ prev_is_special = True
137
+ current_sub_tokens = []
138
+ else:
139
+ current_sub_tokens.append(token)
140
+ prev_is_special = False
141
+ out_string += self.sp_model.decode(current_sub_tokens)
142
+ out_string = self.clean_up_tokenization(out_string)
143
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
144
+ return out_string[1:]
145
+
146
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
147
+ """
148
+ Save the vocabulary and special tokens file to a directory.
149
+
150
+ Args:
151
+ save_directory (`str`):
152
+ The directory in which to save the vocabulary.
153
+
154
+ Returns:
155
+ `Tuple(str)`: Paths to the files saved.
156
+ """
157
+ if not os.path.isdir(save_directory):
158
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
159
+ return
160
+ out_vocab_file = os.path.join(
161
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
162
+ )
163
+
164
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
165
+ copyfile(self.vocab_file, out_vocab_file)
166
+ elif not os.path.isfile(self.vocab_file):
167
+ with open(out_vocab_file, "wb") as fi:
168
+ content_spiece_model = self.sp_model.serialized_model_proto()
169
+ fi.write(content_spiece_model)
170
+
171
+ return (out_vocab_file,)
172
+
173
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
174
+ if self.add_bos_token:
175
+ bos_token_ids = [self.bos_token_id]
176
+ else:
177
+ bos_token_ids = []
178
+
179
+ output = bos_token_ids + token_ids_0
180
+
181
+ if token_ids_1 is not None:
182
+ output = output + token_ids_1
183
+
184
+ if self.add_eos_token:
185
+ output = output + [self.eos_token_id]
186
+
187
+ return output
188
+
189
+ def get_special_tokens_mask(
190
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
191
+ ) -> List[int]:
192
+ """
193
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
194
+ special tokens using the tokenizer `prepare_for_model` method.
195
+
196
+ Args:
197
+ token_ids_0 (`List[int]`):
198
+ List of IDs.
199
+ token_ids_1 (`List[int]`, *optional*):
200
+ Optional second list of IDs for sequence pairs.
201
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
202
+ Whether or not the token list is already formatted with special tokens for the model.
203
+
204
+ Returns:
205
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
206
+ """
207
+ if already_has_special_tokens:
208
+ return super().get_special_tokens_mask(
209
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
210
+ )
211
+
212
+ if token_ids_1 is None:
213
+ return [1] + ([0] * len(token_ids_0)) + [1]
214
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
215
+
216
+ def create_token_type_ids_from_sequences(
217
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
218
+ ) -> List[int]:
219
+ """
220
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
221
+ use of token type ids, therefore a list of zeros is returned.
222
+
223
+ Args:
224
+ token_ids_0 (`List[int]`):
225
+ List of IDs.
226
+ token_ids_1 (`List[int]`, *optional*):
227
+ Optional second list of IDs for sequence pairs.
228
+
229
+ Returns:
230
+ `List[int]`: List of zeros.
231
+ """
232
+ eos = [self.eos_token_id]
233
+
234
+ if token_ids_1 is None:
235
+ return len(token_ids_0 + eos) * [0]
236
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
epoch2_ckpt/tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization Fast class for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, Optional, Tuple
22
+
23
+ from tokenizers import processors, decoders, Tokenizer, normalizers
24
+ from tokenizers.models import BPE
25
+
26
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
27
+ from transformers.utils import logging
28
+
29
+ from transformers.convert_slow_tokenizer import (
30
+ SLOW_TO_FAST_CONVERTERS,
31
+ SpmConverter,
32
+ SentencePieceExtractor,
33
+ )
34
+
35
+ from .tokenization_internlm2 import InternLM2Tokenizer
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
40
+
41
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
42
+ class InternLM2Converter(SpmConverter):
43
+ handle_byte_fallback = True
44
+
45
+ def vocab(self, proto):
46
+ vocab = [
47
+ ("<unk>", 0.0),
48
+ ("<s>", 0.0),
49
+ ("</s>", 0.0),
50
+ ]
51
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
52
+ return vocab
53
+
54
+ def unk_id(self, proto):
55
+ unk_id = 0
56
+ return unk_id
57
+
58
+ def decoder(self, replacement, add_prefix_space):
59
+ decoders_sequence = [
60
+ decoders.Replace("▁", " "),
61
+ decoders.ByteFallback(),
62
+ decoders.Fuse(),
63
+ ]
64
+ if self.proto.normalizer_spec.add_dummy_prefix:
65
+ decoders_sequence.append(decoders.Strip(content=" ", left=1))
66
+ return decoders.Sequence(decoders_sequence)
67
+
68
+ def tokenizer(self, proto):
69
+ model_type = proto.trainer_spec.model_type
70
+ vocab_scores = self.vocab(proto)
71
+ # special tokens
72
+ added_tokens = self.original_tokenizer.added_tokens_decoder
73
+ for i in range(len(vocab_scores)):
74
+ piece, score = vocab_scores[i]
75
+ if i in added_tokens:
76
+ vocab_scores[i] = (added_tokens[i].content, score)
77
+ if model_type == 1:
78
+ raise RuntimeError("InternLM2 is supposed to be a BPE model!")
79
+
80
+ elif model_type == 2:
81
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
82
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
83
+ tokenizer = Tokenizer(
84
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
85
+ )
86
+ tokenizer.add_special_tokens(
87
+ [ added_token for index, added_token in added_tokens.items()]
88
+ )
89
+ else:
90
+ raise Exception(
91
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
92
+ )
93
+
94
+ return tokenizer
95
+
96
+ def normalizer(self, proto):
97
+ normalizers_list = []
98
+ if proto.normalizer_spec.add_dummy_prefix:
99
+ normalizers_list.append(normalizers.Prepend(prepend="▁"))
100
+ normalizers_list.append(normalizers.Replace(pattern=" ", content="▁"))
101
+ return normalizers.Sequence(normalizers_list)
102
+
103
+ def pre_tokenizer(self, replacement, add_prefix_space):
104
+ return None
105
+
106
+ SLOW_TO_FAST_CONVERTERS["InternLM2Tokenizer"] = InternLM2Converter
107
+
108
+
109
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
110
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
111
+ vocab_files_names = VOCAB_FILES_NAMES
112
+ slow_tokenizer_class = InternLM2Tokenizer
113
+ padding_side = "left"
114
+ model_input_names = ["input_ids", "attention_mask"]
115
+ _auto_class = "AutoTokenizer"
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_file,
120
+ unk_token="<unk>",
121
+ bos_token="<s>",
122
+ eos_token="</s>",
123
+ pad_token="</s>",
124
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
125
+ add_bos_token=True,
126
+ add_eos_token=False,
127
+ decode_with_prefix_space=False,
128
+ clean_up_tokenization_spaces=False,
129
+ **kwargs,
130
+ ):
131
+ super().__init__(
132
+ vocab_file=vocab_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ sp_model_kwargs=sp_model_kwargs,
138
+ add_bos_token=add_bos_token,
139
+ add_eos_token=add_eos_token,
140
+ decode_with_prefix_space=decode_with_prefix_space,
141
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
142
+ **kwargs,
143
+ )
144
+ self._add_bos_token = add_bos_token
145
+ self._add_eos_token = add_eos_token
146
+ self.update_post_processor()
147
+ self.vocab_file = vocab_file
148
+
149
+ @property
150
+ def can_save_slow_tokenizer(self) -> bool:
151
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
152
+
153
+ def update_post_processor(self):
154
+ """
155
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
156
+ """
157
+ bos = self.bos_token
158
+ bos_token_id = self.bos_token_id
159
+ if bos is None and self.add_bos_token:
160
+ raise ValueError("add_bos_token = True but bos_token = None")
161
+
162
+ eos = self.eos_token
163
+ eos_token_id = self.eos_token_id
164
+ if eos is None and self.add_eos_token:
165
+ raise ValueError("add_eos_token = True but eos_token = None")
166
+
167
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
168
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
169
+
170
+ special_tokens = []
171
+ if self.add_bos_token:
172
+ special_tokens.append((bos, bos_token_id))
173
+ if self.add_eos_token:
174
+ special_tokens.append((eos, eos_token_id))
175
+ self._tokenizer.post_processor = processors.TemplateProcessing(
176
+ single=single, pair=pair, special_tokens=special_tokens
177
+ )
178
+
179
+ @property
180
+ def add_eos_token(self):
181
+ return self._add_eos_token
182
+
183
+ @property
184
+ def add_bos_token(self):
185
+ return self._add_bos_token
186
+
187
+ @add_eos_token.setter
188
+ def add_eos_token(self, value):
189
+ self._add_eos_token = value
190
+ self.update_post_processor()
191
+
192
+ @add_bos_token.setter
193
+ def add_bos_token(self, value):
194
+ self._add_bos_token = value
195
+ self.update_post_processor()
196
+
197
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
198
+ if not self.can_save_slow_tokenizer:
199
+ raise ValueError(
200
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
201
+ "tokenizer."
202
+ )
203
+
204
+ if not os.path.isdir(save_directory):
205
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
206
+ return
207
+ out_vocab_file = os.path.join(
208
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
209
+ )
210
+
211
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
212
+ copyfile(self.vocab_file, out_vocab_file)
213
+
214
+ return (out_vocab_file,)
epoch2_ckpt/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
epoch2_ckpt/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
3
+ size 1477754
epoch2_ckpt/tokenizer_config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "auto_map": {
31
+ "AutoTokenizer": [
32
+ "tokenization_internlm2.InternLM2Tokenizer",
33
+ "tokenization_internlm2_fast.InternLM2TokenizerFast"
34
+ ]
35
+ },
36
+ "bos_token": "<s>",
37
+ "clean_up_tokenization_spaces": false,
38
+ "decode_with_prefix_space": false,
39
+ "eos_token": "</s>",
40
+ "model_max_length": 1000000000000000019884624838656,
41
+ "pad_token": "</s>",
42
+ "padding_side": "right",
43
+ "sp_model_kwargs": null,
44
+ "tokenizer_class": "InternLM2Tokenizer",
45
+ "unk_token": "<unk>"
46
+ }
epoch2_ckpt/xtuner_config.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, SiglipImageProcessor,
10
+ SiglipVisionModel, Dinov2Model)
11
+ from peft import LoraConfig
12
+
13
+ from xtuner.dataset import LLaVADataset
14
+ from xtuner.dataset.collate_fns import default_collate_fn
15
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
16
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
17
+ from xtuner.engine.runner import TrainLoop
18
+ from xtuner.model import LLaVAModel
19
+ from xtuner.utils import PROMPT_TEMPLATE
20
+
21
+ #######################################################################
22
+ # PART 1 Settings #
23
+ #######################################################################
24
+ # Model
25
+ llm_name_or_path = 'internlm/internlm2-1_8b'
26
+ image_processor_path = 'google/siglip-so400m-patch14-384'
27
+ siglip_path = 'google/siglip-so400m-patch14-384'
28
+ dino_path = 'facebook/dinov2-large'
29
+
30
+ # Data
31
+ data_root = './llava_data/'
32
+ data_path = data_root + 'llava_v1_5_lrv_mix1008k.json'
33
+ image_folder = data_root + 'llava_images'
34
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
35
+ max_length = int(2048 - (336 / 14)**2)
36
+
37
+ # Scheduler & Optimizer
38
+ batch_size = 8 # per_device
39
+ accumulative_counts = 2
40
+ dataloader_num_workers = 4
41
+ max_epochs = 2
42
+ optim_type = AdamW
43
+ lr = 2e-5
44
+ betas = (0.9, 0.999)
45
+ weight_decay = 0.1
46
+ max_norm = 1 # grad clip
47
+ warmup_ratio = 0.03
48
+
49
+ # Save
50
+ save_steps = 500
51
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
52
+
53
+ # Evaluate the generation performance during the training
54
+ evaluation_freq = 500
55
+ SYSTEM = ''
56
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
57
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
58
+
59
+ #######################################################################
60
+ # PART 2 Model & Tokenizer & Image Processor #
61
+ #######################################################################
62
+ tokenizer = dict(
63
+ type=AutoTokenizer.from_pretrained,
64
+ pretrained_model_name_or_path=llm_name_or_path,
65
+ trust_remote_code=True,
66
+ padding_side='right')
67
+
68
+ image_processor = dict(
69
+ type=SiglipImageProcessor.from_pretrained,
70
+ pretrained_model_name_or_path=image_processor_path,
71
+ trust_remote_code=True)
72
+
73
+ model = dict(
74
+ type=LLaVAModel,
75
+ freeze_llm=False,
76
+ freeze_visual_encoder=True,
77
+ llm=dict(
78
+ type=AutoModelForCausalLM.from_pretrained,
79
+ pretrained_model_name_or_path=llm_name_or_path,
80
+ trust_remote_code=True,
81
+ torch_dtype=torch.float16,
82
+ quantization_config=dict(
83
+ type=BitsAndBytesConfig,
84
+ load_in_4bit=True,
85
+ load_in_8bit=False,
86
+ llm_int8_threshold=6.0,
87
+ llm_int8_has_fp16_weight=False,
88
+ bnb_4bit_compute_dtype=torch.float16,
89
+ bnb_4bit_use_double_quant=True,
90
+ bnb_4bit_quant_type='nf4')),
91
+ siglip=dict(
92
+ type=SiglipVisionModel.from_pretrained,
93
+ pretrained_model_name_or_path=siglip_path),
94
+ dino=dict(
95
+ type=Dinov2Model.from_pretrained,
96
+ pretrained_model_name_or_path=dino_path),
97
+ )
98
+
99
+ #######################################################################
100
+ # PART 3 Dataset & Dataloader #
101
+ #######################################################################
102
+ llava_dataset = dict(
103
+ type=LLaVADataset,
104
+ data_path=data_path,
105
+ image_folder=image_folder,
106
+ tokenizer=tokenizer,
107
+ image_processor=image_processor,
108
+ dataset_map_fn=llava_map_fn,
109
+ template_map_fn=dict(
110
+ type=template_map_fn_factory, template=prompt_template),
111
+ max_length=max_length,
112
+ pad_image_to_square=False)
113
+
114
+ train_dataloader = dict(
115
+ batch_size=batch_size,
116
+ num_workers=dataloader_num_workers,
117
+ dataset=llava_dataset,
118
+ sampler=dict(type=DefaultSampler, shuffle=True),
119
+ collate_fn=dict(type=default_collate_fn))
120
+
121
+ #######################################################################
122
+ # PART 4 Scheduler & Optimizer #
123
+ #######################################################################
124
+ # optimizer
125
+ optim_wrapper = dict(
126
+ type=AmpOptimWrapper,
127
+ optimizer=dict(
128
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
129
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
130
+ accumulative_counts=accumulative_counts,
131
+ loss_scale='dynamic',
132
+ dtype='float16')
133
+
134
+ # learning policy
135
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
136
+ param_scheduler = [
137
+ dict(
138
+ type=LinearLR,
139
+ start_factor=1e-5,
140
+ by_epoch=True,
141
+ begin=0,
142
+ end=warmup_ratio * max_epochs,
143
+ convert_to_iter_based=True),
144
+ dict(
145
+ type=CosineAnnealingLR,
146
+ eta_min=0.0,
147
+ by_epoch=True,
148
+ begin=warmup_ratio * max_epochs,
149
+ end=max_epochs,
150
+ convert_to_iter_based=True)
151
+ ]
152
+
153
+ # train, val, test setting
154
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
155
+
156
+ #######################################################################
157
+ # PART 5 Runtime #
158
+ #######################################################################
159
+ # Log the dialogue periodically during the training process, optional
160
+ custom_hooks = [
161
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
162
+ dict(
163
+ type=EvaluateChatHook,
164
+ tokenizer=tokenizer,
165
+ image_processor=image_processor,
166
+ every_n_iters=evaluation_freq,
167
+ evaluation_inputs=evaluation_inputs,
168
+ evaluation_images=evaluation_images,
169
+ system=SYSTEM,
170
+ prompt_template=prompt_template)
171
+ ]
172
+
173
+ # configure default hooks
174
+ default_hooks = dict(
175
+ # record the time of every iteration.
176
+ timer=dict(type=IterTimerHook),
177
+ # print log every 10 iterations.
178
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
179
+ # enable the parameter scheduler.
180
+ param_scheduler=dict(type=ParamSchedulerHook),
181
+ # save checkpoint per `save_steps`.
182
+ checkpoint=dict(
183
+ type=CheckpointHook,
184
+ by_epoch=False,
185
+ interval=save_steps,
186
+ max_keep_ckpts=save_total_limit),
187
+ # set sampler seed in distributed evrionment.
188
+ sampler_seed=dict(type=DistSamplerSeedHook),
189
+ )
190
+
191
+ # configure environment
192
+ env_cfg = dict(
193
+ # whether to enable cudnn benchmark
194
+ cudnn_benchmark=False,
195
+ # set multi process parameters
196
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
197
+ # set distributed parameters
198
+ dist_cfg=dict(backend='nccl'),
199
+ )
200
+
201
+ # set visualizer
202
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
203
+ visualizer = dict(
204
+ type=Visualizer,
205
+ vis_backends=[dict(type=TensorboardVisBackend)]
206
+ )
207
+
208
+ # set log level
209
+ log_level = 'INFO'
210
+
211
+ # load from which checkpoint
212
+ load_from = None
213
+
214
+ # whether to resume training from the loaded checkpoint
215
+ resume = False
216
+
217
+ # Defaults to use random seed and disable `deterministic`
218
+ randomness = dict(seed=None, deterministic=False)
219
+
220
+ # set log processor
221
+ log_processor = dict(by_epoch=False)
modified_transformers/src/transformers/models/siglip/modeling_siglip.py ADDED
@@ -0,0 +1,1299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Siglip model."""
16
+
17
+
18
+ import math
19
+ import warnings
20
+ from dataclasses import dataclass
21
+ from typing import Any, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+ from torch.nn.init import _calculate_fan_in_and_fan_out
29
+
30
+ from ...activations import ACT2FN
31
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
32
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
33
+ from ...modeling_utils import PreTrainedModel
34
+ from ...utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ # General docstring
48
+ _CONFIG_FOR_DOC = "SiglipConfig"
49
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
50
+
51
+ # Image classification docstring
52
+ _IMAGE_CLASS_CHECKPOINT = "google/siglip-base-patch16-224"
53
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_1"
54
+
55
+
56
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "google/siglip-base-patch16-224",
58
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
59
+ ]
60
+
61
+
62
+ def _trunc_normal_(tensor, mean, std, a, b):
63
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
64
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
65
+ def norm_cdf(x):
66
+ # Computes standard normal cumulative distribution function
67
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
68
+
69
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
70
+ warnings.warn(
71
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
72
+ "The distribution of values may be incorrect.",
73
+ stacklevel=2,
74
+ )
75
+
76
+ # Values are generated by using a truncated uniform distribution and
77
+ # then using the inverse CDF for the normal distribution.
78
+ # Get upper and lower cdf values
79
+ l = norm_cdf((a - mean) / std)
80
+ u = norm_cdf((b - mean) / std)
81
+
82
+ # Uniformly fill tensor with values from [l, u], then translate to
83
+ # [2l-1, 2u-1].
84
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
85
+
86
+ # Use inverse cdf transform for normal distribution to get truncated
87
+ # standard normal
88
+ tensor.erfinv_()
89
+
90
+ # Transform to proper mean, std
91
+ tensor.mul_(std * math.sqrt(2.0))
92
+ tensor.add_(mean)
93
+
94
+ # Clamp to ensure it's in the proper range
95
+ tensor.clamp_(min=a, max=b)
96
+
97
+
98
+ def trunc_normal_tf_(
99
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
100
+ ) -> torch.Tensor:
101
+ """Fills the input Tensor with values drawn from a truncated
102
+ normal distribution. The values are effectively drawn from the
103
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
104
+ with values outside :math:`[a, b]` redrawn until they are within
105
+ the bounds. The method used for generating the random values works
106
+ best when :math:`a \\leq \text{mean} \\leq b`.
107
+
108
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
109
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
110
+ and the result is subsquently scaled and shifted by the mean and std args.
111
+
112
+ Args:
113
+ tensor: an n-dimensional `torch.Tensor`
114
+ mean: the mean of the normal distribution
115
+ std: the standard deviation of the normal distribution
116
+ a: the minimum cutoff value
117
+ b: the maximum cutoff value
118
+ """
119
+ with torch.no_grad():
120
+ _trunc_normal_(tensor, 0, 1.0, a, b)
121
+ tensor.mul_(std).add_(mean)
122
+
123
+
124
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
125
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
126
+ if mode == "fan_in":
127
+ denom = fan_in
128
+ elif mode == "fan_out":
129
+ denom = fan_out
130
+ elif mode == "fan_avg":
131
+ denom = (fan_in + fan_out) / 2
132
+
133
+ variance = scale / denom
134
+
135
+ if distribution == "truncated_normal":
136
+ # constant is stddev of standard normal truncated to (-2, 2)
137
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
138
+ elif distribution == "normal":
139
+ with torch.no_grad():
140
+ tensor.normal_(std=math.sqrt(variance))
141
+ elif distribution == "uniform":
142
+ bound = math.sqrt(3 * variance)
143
+ with torch.no_grad():
144
+ tensor.uniform_(-bound, bound)
145
+ else:
146
+ raise ValueError(f"invalid distribution {distribution}")
147
+
148
+
149
+ def lecun_normal_(tensor):
150
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
151
+
152
+
153
+ def default_flax_embed_init(tensor):
154
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
155
+
156
+
157
+ @dataclass
158
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
159
+ class SiglipVisionModelOutput(ModelOutput):
160
+ """
161
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
162
+
163
+ Args:
164
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
165
+ The image embeddings obtained by applying the projection layer to the pooler_output.
166
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
167
+ Sequence of hidden-states at the output of the last layer of the model.
168
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
169
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
170
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
171
+
172
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
173
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
174
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
175
+ sequence_length)`.
176
+
177
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
178
+ heads.
179
+ """
180
+
181
+ image_embeds: Optional[torch.FloatTensor] = None
182
+ last_hidden_state: torch.FloatTensor = None
183
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
184
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
185
+
186
+
187
+ @dataclass
188
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
189
+ class SiglipTextModelOutput(ModelOutput):
190
+ """
191
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
192
+
193
+ Args:
194
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
195
+ The text embeddings obtained by applying the projection layer to the pooler_output.
196
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
197
+ Sequence of hidden-states at the output of the last layer of the model.
198
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
199
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
200
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
201
+
202
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
203
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
204
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
205
+ sequence_length)`.
206
+
207
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
208
+ heads.
209
+ """
210
+
211
+ text_embeds: Optional[torch.FloatTensor] = None
212
+ last_hidden_state: torch.FloatTensor = None
213
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
214
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
215
+
216
+
217
+ @dataclass
218
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
219
+ class SiglipOutput(ModelOutput):
220
+ """
221
+ Args:
222
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
223
+ Contrastive loss for image-text similarity.
224
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
225
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
226
+ similarity scores.
227
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
228
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
229
+ similarity scores.
230
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
231
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
232
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
233
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
234
+ text_model_output(`BaseModelOutputWithPooling`):
235
+ The output of the [`SiglipTextModel`].
236
+ vision_model_output(`BaseModelOutputWithPooling`):
237
+ The output of the [`SiglipVisionModel`].
238
+ """
239
+
240
+ loss: Optional[torch.FloatTensor] = None
241
+ logits_per_image: torch.FloatTensor = None
242
+ logits_per_text: torch.FloatTensor = None
243
+ text_embeds: torch.FloatTensor = None
244
+ image_embeds: torch.FloatTensor = None
245
+ text_model_output: BaseModelOutputWithPooling = None
246
+ vision_model_output: BaseModelOutputWithPooling = None
247
+
248
+ def to_tuple(self) -> Tuple[Any]:
249
+ return tuple(
250
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
251
+ for k in self.keys()
252
+ )
253
+
254
+
255
+ class SiglipVisionEmbeddings(nn.Module):
256
+ def __init__(self, config: SiglipVisionConfig):
257
+ super().__init__()
258
+ self.config = config
259
+ self.embed_dim = config.hidden_size
260
+ self.image_size = config.image_size
261
+ self.patch_size = config.patch_size
262
+
263
+ self.patch_embedding = nn.Conv2d(
264
+ in_channels=config.num_channels,
265
+ out_channels=self.embed_dim,
266
+ kernel_size=self.patch_size,
267
+ stride=self.patch_size,
268
+ padding="valid",
269
+ )
270
+
271
+ self.num_patches = (self.image_size // self.patch_size) ** 2
272
+ self.num_positions = self.num_patches
273
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
274
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
275
+
276
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
277
+ target_dtype = self.patch_embedding.weight.dtype
278
+ patch_embeds = self.patch_embedding(pixel_values.to(target_dtype)) # shape = [*, width, grid, grid]
279
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
280
+
281
+ embeddings = embeddings + self.position_embedding(self.position_ids)
282
+ return embeddings
283
+
284
+
285
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
286
+ class SiglipTextEmbeddings(nn.Module):
287
+ def __init__(self, config: SiglipTextConfig):
288
+ super().__init__()
289
+ embed_dim = config.hidden_size
290
+
291
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
292
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
293
+
294
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
295
+ self.register_buffer(
296
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
297
+ )
298
+
299
+ def forward(
300
+ self,
301
+ input_ids: Optional[torch.LongTensor] = None,
302
+ position_ids: Optional[torch.LongTensor] = None,
303
+ inputs_embeds: Optional[torch.FloatTensor] = None,
304
+ ) -> torch.Tensor:
305
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
306
+
307
+ if position_ids is None:
308
+ position_ids = self.position_ids[:, :seq_length]
309
+
310
+ if inputs_embeds is None:
311
+ inputs_embeds = self.token_embedding(input_ids)
312
+
313
+ position_embeddings = self.position_embedding(position_ids)
314
+ embeddings = inputs_embeds + position_embeddings
315
+
316
+ return embeddings
317
+
318
+
319
+ class SiglipAttention(nn.Module):
320
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
321
+
322
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.config = config
326
+ self.embed_dim = config.hidden_size
327
+ self.num_heads = config.num_attention_heads
328
+ self.head_dim = self.embed_dim // self.num_heads
329
+ if self.head_dim * self.num_heads != self.embed_dim:
330
+ raise ValueError(
331
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
332
+ f" {self.num_heads})."
333
+ )
334
+ self.scale = self.head_dim**-0.5
335
+ self.dropout = config.attention_dropout
336
+
337
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
338
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
339
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
340
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
341
+
342
+ def forward(
343
+ self,
344
+ hidden_states: torch.Tensor,
345
+ attention_mask: Optional[torch.Tensor] = None,
346
+ output_attentions: Optional[bool] = False,
347
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
348
+ """Input shape: Batch x Time x Channel"""
349
+
350
+ batch_size, q_len, _ = hidden_states.size()
351
+
352
+ query_states = self.q_proj(hidden_states)
353
+ key_states = self.k_proj(hidden_states)
354
+ value_states = self.v_proj(hidden_states)
355
+
356
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
357
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
358
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
359
+
360
+ k_v_seq_len = key_states.shape[-2]
361
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
362
+
363
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
364
+ raise ValueError(
365
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
366
+ f" {attn_weights.size()}"
367
+ )
368
+
369
+ if attention_mask is not None:
370
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
371
+ raise ValueError(
372
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
373
+ )
374
+ attn_weights = attn_weights + attention_mask
375
+
376
+ # upcast attention to fp32
377
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
378
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
379
+ attn_output = torch.matmul(attn_weights, value_states)
380
+
381
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
382
+ raise ValueError(
383
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
384
+ f" {attn_output.size()}"
385
+ )
386
+
387
+ attn_output = attn_output.transpose(1, 2).contiguous()
388
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
389
+
390
+ attn_output = self.out_proj(attn_output)
391
+
392
+ return attn_output, attn_weights
393
+
394
+
395
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
396
+ class SiglipMLP(nn.Module):
397
+ def __init__(self, config):
398
+ super().__init__()
399
+ self.config = config
400
+ self.activation_fn = ACT2FN[config.hidden_act]
401
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
402
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
403
+
404
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
405
+ hidden_states = self.fc1(hidden_states)
406
+ hidden_states = self.activation_fn(hidden_states)
407
+ hidden_states = self.fc2(hidden_states)
408
+ return hidden_states
409
+
410
+
411
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
412
+ class SiglipEncoderLayer(nn.Module):
413
+ def __init__(self, config: SiglipConfig):
414
+ super().__init__()
415
+ self.embed_dim = config.hidden_size
416
+ self.self_attn = SiglipAttention(config)
417
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
418
+ self.mlp = SiglipMLP(config)
419
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
420
+
421
+ # Ignore copy
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: torch.Tensor,
426
+ output_attentions: Optional[bool] = False,
427
+ ) -> Tuple[torch.FloatTensor]:
428
+ """
429
+ Args:
430
+ hidden_states (`torch.FloatTensor`):
431
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
432
+ attention_mask (`torch.FloatTensor`):
433
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
434
+ output_attentions (`bool`, *optional*, defaults to `False`):
435
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
436
+ returned tensors for more detail.
437
+ """
438
+ residual = hidden_states
439
+
440
+ hidden_states = self.layer_norm1(hidden_states)
441
+ hidden_states, attn_weights = self.self_attn(
442
+ hidden_states=hidden_states,
443
+ attention_mask=attention_mask,
444
+ output_attentions=output_attentions,
445
+ )
446
+ hidden_states = residual + hidden_states
447
+
448
+ residual = hidden_states
449
+ hidden_states = self.layer_norm2(hidden_states)
450
+ hidden_states = self.mlp(hidden_states)
451
+ hidden_states = residual + hidden_states
452
+
453
+ outputs = (hidden_states,)
454
+
455
+ if output_attentions:
456
+ outputs += (attn_weights,)
457
+
458
+ return outputs
459
+
460
+
461
+ class SiglipPreTrainedModel(PreTrainedModel):
462
+ """
463
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
464
+ models.
465
+ """
466
+
467
+ config_class = SiglipConfig
468
+ base_model_prefix = "siglip"
469
+ supports_gradient_checkpointing = True
470
+
471
+ def _init_weights(self, module):
472
+ """Initialize the weights"""
473
+ if isinstance(module, SiglipVisionEmbeddings):
474
+ width = (
475
+ self.config.vision_config.hidden_size
476
+ if isinstance(self.config, SiglipConfig)
477
+ else self.config.hidden_size
478
+ )
479
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
480
+ elif isinstance(module, nn.Embedding):
481
+ default_flax_embed_init(module.weight)
482
+ elif isinstance(module, SiglipAttention):
483
+ nn.init.xavier_uniform_(module.q_proj.weight)
484
+ nn.init.xavier_uniform_(module.k_proj.weight)
485
+ nn.init.xavier_uniform_(module.v_proj.weight)
486
+ nn.init.xavier_uniform_(module.out_proj.weight)
487
+ nn.init.zeros_(module.q_proj.bias)
488
+ nn.init.zeros_(module.k_proj.bias)
489
+ nn.init.zeros_(module.v_proj.bias)
490
+ nn.init.zeros_(module.out_proj.bias)
491
+ elif isinstance(module, SiglipMLP):
492
+ nn.init.xavier_uniform_(module.fc1.weight)
493
+ nn.init.xavier_uniform_(module.fc2.weight)
494
+ nn.init.normal_(module.fc1.bias, std=1e-6)
495
+ nn.init.normal_(module.fc2.bias, std=1e-6)
496
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
497
+ nn.init.xavier_uniform_(module.probe.data)
498
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
499
+ nn.init.zeros_(module.attention.in_proj_bias.data)
500
+ elif isinstance(module, SiglipModel):
501
+ logit_scale_init = torch.log(torch.tensor(1.0))
502
+ module.logit_scale.data.fill_(logit_scale_init)
503
+ module.logit_bias.data.zero_()
504
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
505
+ lecun_normal_(module.weight)
506
+ if module.bias is not None:
507
+ nn.init.zeros_(module.bias)
508
+ elif isinstance(module, nn.LayerNorm):
509
+ module.bias.data.zero_()
510
+ module.weight.data.fill_(1.0)
511
+
512
+
513
+ SIGLIP_START_DOCSTRING = r"""
514
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
515
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
516
+ etc.)
517
+
518
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
519
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
520
+ and behavior.
521
+
522
+ Parameters:
523
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
524
+ Initializing with a config file does not load the weights associated with the model, only the
525
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
526
+ """
527
+
528
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
529
+ Args:
530
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
531
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
532
+ it.
533
+
534
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
535
+ [`PreTrainedTokenizer.__call__`] for details.
536
+
537
+ [What are input IDs?](../glossary#input-ids)
538
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
539
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
540
+
541
+ - 1 for tokens that are **not masked**,
542
+ - 0 for tokens that are **masked**.
543
+
544
+ [What are attention masks?](../glossary#attention-mask)
545
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
546
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
547
+ config.max_position_embeddings - 1]`.
548
+
549
+ [What are position IDs?](../glossary#position-ids)
550
+ output_attentions (`bool`, *optional*):
551
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
552
+ tensors for more detail.
553
+ output_hidden_states (`bool`, *optional*):
554
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
555
+ more detail.
556
+ return_dict (`bool`, *optional*):
557
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
558
+ """
559
+
560
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
561
+ Args:
562
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
563
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
564
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
565
+ output_attentions (`bool`, *optional*):
566
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
567
+ tensors for more detail.
568
+ output_hidden_states (`bool`, *optional*):
569
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
570
+ more detail.
571
+ return_dict (`bool`, *optional*):
572
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
573
+ """
574
+
575
+ SIGLIP_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
578
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
579
+ it.
580
+
581
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
582
+ [`PreTrainedTokenizer.__call__`] for details.
583
+
584
+ [What are input IDs?](../glossary#input-ids)
585
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
586
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
587
+
588
+ - 1 for tokens that are **not masked**,
589
+ - 0 for tokens that are **masked**.
590
+
591
+ [What are attention masks?](../glossary#attention-mask)
592
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
593
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
594
+ config.max_position_embeddings - 1]`.
595
+
596
+ [What are position IDs?](../glossary#position-ids)
597
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
598
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
599
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
600
+ return_loss (`bool`, *optional*):
601
+ Whether or not to return the contrastive loss.
602
+ output_attentions (`bool`, *optional*):
603
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
604
+ tensors for more detail.
605
+ output_hidden_states (`bool`, *optional*):
606
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
607
+ more detail.
608
+ return_dict (`bool`, *optional*):
609
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
610
+ """
611
+
612
+
613
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
614
+ class SiglipEncoder(nn.Module):
615
+ """
616
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
617
+ [`SiglipEncoderLayer`].
618
+
619
+ Args:
620
+ config: SiglipConfig
621
+ """
622
+
623
+ def __init__(self, config: SiglipConfig):
624
+ super().__init__()
625
+ self.config = config
626
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
627
+ self.gradient_checkpointing = False
628
+
629
+ # Ignore copy
630
+ def forward(
631
+ self,
632
+ inputs_embeds,
633
+ attention_mask: Optional[torch.Tensor] = None,
634
+ output_attentions: Optional[bool] = None,
635
+ output_hidden_states: Optional[bool] = None,
636
+ return_dict: Optional[bool] = None,
637
+ ) -> Union[Tuple, BaseModelOutput]:
638
+ r"""
639
+ Args:
640
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
641
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
642
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
643
+ than the model's internal embedding lookup matrix.
644
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
645
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
646
+
647
+ - 1 for tokens that are **not masked**,
648
+ - 0 for tokens that are **masked**.
649
+
650
+ [What are attention masks?](../glossary#attention-mask)
651
+ output_attentions (`bool`, *optional*):
652
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
653
+ returned tensors for more detail.
654
+ output_hidden_states (`bool`, *optional*):
655
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
656
+ for more detail.
657
+ return_dict (`bool`, *optional*):
658
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
659
+ """
660
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
661
+ output_hidden_states = (
662
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
663
+ )
664
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
665
+
666
+ encoder_states = () if output_hidden_states else None
667
+ all_attentions = () if output_attentions else None
668
+
669
+ hidden_states = inputs_embeds
670
+ for encoder_layer in self.layers:
671
+ if output_hidden_states:
672
+ encoder_states = encoder_states + (hidden_states,)
673
+ if self.gradient_checkpointing and self.training:
674
+ layer_outputs = self._gradient_checkpointing_func(
675
+ encoder_layer.__call__,
676
+ hidden_states,
677
+ attention_mask,
678
+ output_attentions,
679
+ )
680
+ else:
681
+ layer_outputs = encoder_layer(
682
+ hidden_states,
683
+ attention_mask,
684
+ output_attentions=output_attentions,
685
+ )
686
+
687
+ hidden_states = layer_outputs[0]
688
+
689
+ if output_attentions:
690
+ all_attentions = all_attentions + (layer_outputs[1],)
691
+
692
+ if output_hidden_states:
693
+ encoder_states = encoder_states + (hidden_states,)
694
+
695
+ if not return_dict:
696
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
697
+ return BaseModelOutput(
698
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
699
+ )
700
+
701
+
702
+ class SiglipTextTransformer(nn.Module):
703
+ def __init__(self, config: SiglipTextConfig):
704
+ super().__init__()
705
+ self.config = config
706
+ embed_dim = config.hidden_size
707
+ self.embeddings = SiglipTextEmbeddings(config)
708
+ self.encoder = SiglipEncoder(config)
709
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
710
+
711
+ self.head = nn.Linear(embed_dim, embed_dim)
712
+
713
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
714
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
715
+ def forward(
716
+ self,
717
+ input_ids: Optional[torch.Tensor] = None,
718
+ attention_mask: Optional[torch.Tensor] = None,
719
+ position_ids: Optional[torch.Tensor] = None,
720
+ output_attentions: Optional[bool] = None,
721
+ output_hidden_states: Optional[bool] = None,
722
+ return_dict: Optional[bool] = None,
723
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
724
+ r"""
725
+ Returns:
726
+
727
+ """
728
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
729
+ output_hidden_states = (
730
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
731
+ )
732
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
733
+
734
+ if input_ids is None:
735
+ raise ValueError("You have to specify input_ids")
736
+
737
+ input_shape = input_ids.size()
738
+ input_ids = input_ids.view(-1, input_shape[-1])
739
+
740
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
741
+
742
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
743
+ # expand attention_mask
744
+ if attention_mask is not None:
745
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
746
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
747
+
748
+ encoder_outputs = self.encoder(
749
+ inputs_embeds=hidden_states,
750
+ attention_mask=attention_mask,
751
+ output_attentions=output_attentions,
752
+ output_hidden_states=output_hidden_states,
753
+ return_dict=return_dict,
754
+ )
755
+
756
+ last_hidden_state = encoder_outputs[0]
757
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
758
+
759
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
760
+ pooled_output = last_hidden_state[:, -1, :]
761
+ pooled_output = self.head(pooled_output)
762
+
763
+ if not return_dict:
764
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
765
+
766
+ return BaseModelOutputWithPooling(
767
+ last_hidden_state=last_hidden_state,
768
+ pooler_output=pooled_output,
769
+ hidden_states=encoder_outputs.hidden_states,
770
+ attentions=encoder_outputs.attentions,
771
+ )
772
+
773
+
774
+ @add_start_docstrings(
775
+ """The text model from SigLIP without any head or projection on top.""",
776
+ SIGLIP_START_DOCSTRING,
777
+ )
778
+ class SiglipTextModel(SiglipPreTrainedModel):
779
+ config_class = SiglipTextConfig
780
+
781
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
782
+
783
+ def __init__(self, config: SiglipTextConfig):
784
+ super().__init__(config)
785
+ self.text_model = SiglipTextTransformer(config)
786
+ # Initialize weights and apply final processing
787
+ self.post_init()
788
+
789
+ def get_input_embeddings(self) -> nn.Module:
790
+ return self.text_model.embeddings.token_embedding
791
+
792
+ def set_input_embeddings(self, value):
793
+ self.text_model.embeddings.token_embedding = value
794
+
795
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
796
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
797
+ def forward(
798
+ self,
799
+ input_ids: Optional[torch.Tensor] = None,
800
+ attention_mask: Optional[torch.Tensor] = None,
801
+ position_ids: Optional[torch.Tensor] = None,
802
+ output_attentions: Optional[bool] = None,
803
+ output_hidden_states: Optional[bool] = None,
804
+ return_dict: Optional[bool] = None,
805
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
806
+ r"""
807
+ Returns:
808
+
809
+ Examples:
810
+
811
+ ```python
812
+ >>> from transformers import AutoTokenizer, SiglipTextModel
813
+
814
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
815
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
816
+
817
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
818
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
819
+
820
+ >>> outputs = model(**inputs)
821
+ >>> last_hidden_state = outputs.last_hidden_state
822
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
823
+ ```"""
824
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
825
+
826
+ return self.text_model(
827
+ input_ids=input_ids,
828
+ attention_mask=attention_mask,
829
+ position_ids=position_ids,
830
+ output_attentions=output_attentions,
831
+ output_hidden_states=output_hidden_states,
832
+ return_dict=return_dict,
833
+ )
834
+
835
+
836
+ class SiglipVisionTransformer(nn.Module):
837
+ def __init__(self, config: SiglipVisionConfig):
838
+ super().__init__()
839
+ self.config = config
840
+ embed_dim = config.hidden_size
841
+
842
+ self.embeddings = SiglipVisionEmbeddings(config)
843
+ self.encoder = SiglipEncoder(config)
844
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
845
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
846
+
847
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
848
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
849
+ def forward(
850
+ self,
851
+ pixel_values,
852
+ output_attentions: Optional[bool] = None,
853
+ output_hidden_states: Optional[bool] = None,
854
+ return_dict: Optional[bool] = None,
855
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
856
+ r"""
857
+ Returns:
858
+
859
+ """
860
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
861
+ output_hidden_states = (
862
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
863
+ )
864
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
865
+
866
+ hidden_states = self.embeddings(pixel_values)
867
+
868
+ encoder_outputs = self.encoder(
869
+ inputs_embeds=hidden_states,
870
+ output_attentions=output_attentions,
871
+ output_hidden_states=output_hidden_states,
872
+ return_dict=return_dict,
873
+ )
874
+
875
+ last_hidden_state = encoder_outputs[0]
876
+ last_hidden_state = self.post_layernorm(last_hidden_state)
877
+
878
+ pooled_output = self.head(last_hidden_state)
879
+
880
+ if not return_dict:
881
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
882
+
883
+ return BaseModelOutputWithPooling(
884
+ last_hidden_state=last_hidden_state,
885
+ pooler_output=pooled_output,
886
+ hidden_states=encoder_outputs.hidden_states,
887
+ attentions=encoder_outputs.attentions,
888
+ )
889
+
890
+
891
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
892
+ """Multihead Attention Pooling."""
893
+
894
+ def __init__(self, config: SiglipVisionConfig):
895
+ super().__init__()
896
+
897
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
898
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
899
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
900
+ self.mlp = SiglipMLP(config)
901
+
902
+ def forward(self, hidden_state):
903
+ batch_size = hidden_state.shape[0]
904
+ probe = self.probe.repeat(batch_size, 1, 1)
905
+
906
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
907
+
908
+ residual = hidden_state
909
+ hidden_state = self.layernorm(hidden_state)
910
+ hidden_state = residual + self.mlp(hidden_state)
911
+
912
+ return hidden_state[:, 0]
913
+
914
+
915
+ @add_start_docstrings(
916
+ """The vision model from SigLIP without any head or projection on top.""",
917
+ SIGLIP_START_DOCSTRING,
918
+ )
919
+ class SiglipVisionModel(SiglipPreTrainedModel):
920
+ config_class = SiglipVisionConfig
921
+ main_input_name = "pixel_values"
922
+
923
+ def __init__(self, config: SiglipVisionConfig):
924
+ super().__init__(config)
925
+
926
+ self.vision_model = SiglipVisionTransformer(config)
927
+
928
+ # Initialize weights and apply final processing
929
+ self.post_init()
930
+
931
+ def get_input_embeddings(self) -> nn.Module:
932
+ return self.vision_model.embeddings.patch_embedding
933
+
934
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
935
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
936
+ def forward(
937
+ self,
938
+ pixel_values,
939
+ output_attentions: Optional[bool] = None,
940
+ output_hidden_states: Optional[bool] = None,
941
+ return_dict: Optional[bool] = None,
942
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
943
+ r"""
944
+ Returns:
945
+
946
+ Examples:
947
+
948
+ ```python
949
+ >>> from PIL import Image
950
+ >>> import requests
951
+ >>> from transformers import AutoProcessor, SiglipVisionModel
952
+
953
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
954
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
955
+
956
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
957
+ >>> image = Image.open(requests.get(url, stream=True).raw)
958
+
959
+ >>> inputs = processor(images=image, return_tensors="pt")
960
+
961
+ >>> outputs = model(**inputs)
962
+ >>> last_hidden_state = outputs.last_hidden_state
963
+ >>> pooled_output = outputs.pooler_output # pooled features
964
+ ```"""
965
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
966
+
967
+ return self.vision_model(
968
+ pixel_values=pixel_values,
969
+ output_attentions=output_attentions,
970
+ output_hidden_states=output_hidden_states,
971
+ return_dict=return_dict,
972
+ )
973
+
974
+
975
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
976
+ class SiglipModel(SiglipPreTrainedModel):
977
+ config_class = SiglipConfig
978
+
979
+ def __init__(self, config: SiglipConfig):
980
+ super().__init__(config)
981
+
982
+ if not isinstance(config.text_config, SiglipTextConfig):
983
+ raise ValueError(
984
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
985
+ f" {type(config.text_config)}."
986
+ )
987
+
988
+ if not isinstance(config.vision_config, SiglipVisionConfig):
989
+ raise ValueError(
990
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
991
+ f" {type(config.vision_config)}."
992
+ )
993
+
994
+ text_config = config.text_config
995
+ vision_config = config.vision_config
996
+
997
+ self.text_model = SiglipTextTransformer(text_config)
998
+ self.vision_model = SiglipVisionTransformer(vision_config)
999
+
1000
+ self.logit_scale = nn.Parameter(torch.randn(1))
1001
+ self.logit_bias = nn.Parameter(torch.randn(1))
1002
+
1003
+ # Initialize weights and apply final processing
1004
+ self.post_init()
1005
+
1006
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1007
+ def get_text_features(
1008
+ self,
1009
+ input_ids: Optional[torch.Tensor] = None,
1010
+ attention_mask: Optional[torch.Tensor] = None,
1011
+ position_ids: Optional[torch.Tensor] = None,
1012
+ output_attentions: Optional[bool] = None,
1013
+ output_hidden_states: Optional[bool] = None,
1014
+ return_dict: Optional[bool] = None,
1015
+ ) -> torch.FloatTensor:
1016
+ r"""
1017
+ Returns:
1018
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
1019
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1020
+
1021
+ Examples:
1022
+
1023
+ ```python
1024
+ >>> from transformers import AutoTokenizer, AutoModel
1025
+ >>> import torch
1026
+
1027
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1028
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1029
+
1030
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1031
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1032
+ >>> with torch.no_grad():
1033
+ ... text_features = model.get_text_features(**inputs)
1034
+ ```"""
1035
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1036
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1037
+ output_hidden_states = (
1038
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1039
+ )
1040
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1041
+
1042
+ text_outputs = self.text_model(
1043
+ input_ids=input_ids,
1044
+ attention_mask=attention_mask,
1045
+ position_ids=position_ids,
1046
+ output_attentions=output_attentions,
1047
+ output_hidden_states=output_hidden_states,
1048
+ return_dict=return_dict,
1049
+ )
1050
+
1051
+ pooled_output = text_outputs[1]
1052
+
1053
+ return pooled_output
1054
+
1055
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1056
+ def get_image_features(
1057
+ self,
1058
+ pixel_values: Optional[torch.FloatTensor] = None,
1059
+ output_attentions: Optional[bool] = None,
1060
+ output_hidden_states: Optional[bool] = None,
1061
+ return_dict: Optional[bool] = None,
1062
+ ) -> torch.FloatTensor:
1063
+ r"""
1064
+ Returns:
1065
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
1066
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1067
+
1068
+ Examples:
1069
+
1070
+ ```python
1071
+ >>> from PIL import Image
1072
+ >>> import requests
1073
+ >>> from transformers import AutoProcessor, AutoModel
1074
+ >>> import torch
1075
+
1076
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1077
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1078
+
1079
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1080
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1081
+
1082
+ >>> inputs = processor(images=image, return_tensors="pt")
1083
+
1084
+ >>> with torch.no_grad():
1085
+ ... image_features = model.get_image_features(**inputs)
1086
+ ```"""
1087
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1088
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1089
+ output_hidden_states = (
1090
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1091
+ )
1092
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1093
+
1094
+ vision_outputs = self.vision_model(
1095
+ pixel_values=pixel_values,
1096
+ output_attentions=output_attentions,
1097
+ output_hidden_states=output_hidden_states,
1098
+ return_dict=return_dict,
1099
+ )
1100
+
1101
+ pooled_output = vision_outputs[1]
1102
+
1103
+ return pooled_output
1104
+
1105
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1106
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1107
+ def forward(
1108
+ self,
1109
+ input_ids: Optional[torch.LongTensor] = None,
1110
+ pixel_values: Optional[torch.FloatTensor] = None,
1111
+ attention_mask: Optional[torch.Tensor] = None,
1112
+ position_ids: Optional[torch.LongTensor] = None,
1113
+ return_loss: Optional[bool] = None,
1114
+ output_attentions: Optional[bool] = None,
1115
+ output_hidden_states: Optional[bool] = None,
1116
+ return_dict: Optional[bool] = None,
1117
+ ) -> Union[Tuple, SiglipOutput]:
1118
+ r"""
1119
+ Returns:
1120
+
1121
+ Examples:
1122
+
1123
+ ```python
1124
+ >>> from PIL import Image
1125
+ >>> import requests
1126
+ >>> from transformers import AutoProcessor, AutoModel
1127
+ >>> import torch
1128
+
1129
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1130
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1131
+
1132
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1133
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1134
+
1135
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1136
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1137
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1138
+
1139
+ >>> with torch.no_grad():
1140
+ ... outputs = model(**inputs)
1141
+
1142
+ >>> logits_per_image = outputs.logits_per_image
1143
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1144
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1145
+ 31.9% that image 0 is 'a photo of 2 cats'
1146
+ ```"""
1147
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1148
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1149
+ output_hidden_states = (
1150
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1151
+ )
1152
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1153
+
1154
+ vision_outputs = self.vision_model(
1155
+ pixel_values=pixel_values,
1156
+ output_attentions=output_attentions,
1157
+ output_hidden_states=output_hidden_states,
1158
+ return_dict=return_dict,
1159
+ )
1160
+
1161
+ text_outputs = self.text_model(
1162
+ input_ids=input_ids,
1163
+ attention_mask=attention_mask,
1164
+ position_ids=position_ids,
1165
+ output_attentions=output_attentions,
1166
+ output_hidden_states=output_hidden_states,
1167
+ return_dict=return_dict,
1168
+ )
1169
+
1170
+ image_embeds = vision_outputs[1]
1171
+ text_embeds = text_outputs[1]
1172
+
1173
+ # normalized features
1174
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1175
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1176
+
1177
+ # cosine similarity as logits
1178
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1179
+ logits_per_image = logits_per_text.t()
1180
+
1181
+ loss = None
1182
+ if return_loss:
1183
+ raise NotImplementedError("SigLIP loss to be implemented")
1184
+
1185
+ if not return_dict:
1186
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1187
+ return ((loss,) + output) if loss is not None else output
1188
+
1189
+ return SiglipOutput(
1190
+ loss=loss,
1191
+ logits_per_image=logits_per_image,
1192
+ logits_per_text=logits_per_text,
1193
+ text_embeds=text_embeds,
1194
+ image_embeds=image_embeds,
1195
+ text_model_output=text_outputs,
1196
+ vision_model_output=vision_outputs,
1197
+ )
1198
+
1199
+
1200
+ @add_start_docstrings(
1201
+ """
1202
+ SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
1203
+ the patch tokens) e.g. for ImageNet.
1204
+ """,
1205
+ SIGLIP_START_DOCSTRING,
1206
+ )
1207
+ class SiglipForImageClassification(SiglipPreTrainedModel):
1208
+ main_input_name = "pixel_values"
1209
+
1210
+ def __init__(self, config: SiglipConfig) -> None:
1211
+ super().__init__(config)
1212
+
1213
+ self.num_labels = config.num_labels
1214
+ self.vision_model = SiglipVisionTransformer(config.vision_config)
1215
+
1216
+ # Classifier head
1217
+ self.classifier = (
1218
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1219
+ )
1220
+
1221
+ # Initialize weights and apply final processing
1222
+ self.post_init()
1223
+
1224
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1225
+ @add_code_sample_docstrings(
1226
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
1227
+ output_type=ImageClassifierOutput,
1228
+ config_class=_CONFIG_FOR_DOC,
1229
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1230
+ )
1231
+ def forward(
1232
+ self,
1233
+ pixel_values: Optional[torch.Tensor] = None,
1234
+ labels: Optional[torch.Tensor] = None,
1235
+ output_attentions: Optional[bool] = None,
1236
+ output_hidden_states: Optional[bool] = None,
1237
+ return_dict: Optional[bool] = None,
1238
+ ) -> Union[tuple, ImageClassifierOutput]:
1239
+ r"""
1240
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1241
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1242
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1243
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1244
+ """
1245
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1246
+ output_hidden_states = (
1247
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1248
+ )
1249
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1250
+
1251
+ outputs = self.vision_model(
1252
+ pixel_values,
1253
+ output_attentions=output_attentions,
1254
+ output_hidden_states=output_hidden_states,
1255
+ return_dict=return_dict,
1256
+ )
1257
+
1258
+ sequence_output = outputs[0]
1259
+
1260
+ # average pool the patch tokens
1261
+ sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
1262
+ # apply classifier
1263
+ logits = self.classifier(sequence_output)
1264
+
1265
+ loss = None
1266
+ if labels is not None:
1267
+ # move labels to correct device to enable model parallelism
1268
+ labels = labels.to(logits.device)
1269
+ if self.config.problem_type is None:
1270
+ if self.num_labels == 1:
1271
+ self.config.problem_type = "regression"
1272
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1273
+ self.config.problem_type = "single_label_classification"
1274
+ else:
1275
+ self.config.problem_type = "multi_label_classification"
1276
+
1277
+ if self.config.problem_type == "regression":
1278
+ loss_fct = MSELoss()
1279
+ if self.num_labels == 1:
1280
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1281
+ else:
1282
+ loss = loss_fct(logits, labels)
1283
+ elif self.config.problem_type == "single_label_classification":
1284
+ loss_fct = CrossEntropyLoss()
1285
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1286
+ elif self.config.problem_type == "multi_label_classification":
1287
+ loss_fct = BCEWithLogitsLoss()
1288
+ loss = loss_fct(logits, labels)
1289
+
1290
+ if not return_dict:
1291
+ output = (logits,) + outputs[2:]
1292
+ return ((loss,) + output) if loss is not None else output
1293
+
1294
+ return ImageClassifierOutput(
1295
+ loss=loss,
1296
+ logits=logits,
1297
+ hidden_states=outputs.hidden_states,
1298
+ attentions=outputs.attentions,
1299
+ )
modified_xtuner/xtuner/dataset/huggingface.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+ import os
4
+ from datetime import timedelta
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+ from datasets import DatasetDict, concatenate_datasets
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from mmengine.utils.misc import get_object_from_string
12
+ from torch import distributed as dist
13
+
14
+ from xtuner.registry import BUILDER, MAP_FUNC
15
+ from .utils import Packer, encode_fn
16
+
17
+
18
+ def get_lengths(example):
19
+ return {'length': len(example['input_ids'])}
20
+
21
+
22
+ def build_origin_dataset(dataset, split):
23
+ if isinstance(dataset, DatasetDict):
24
+ if split is None:
25
+ dataset = concatenate_datasets(dataset.values())
26
+ else:
27
+ dataset = dataset[split]
28
+ elif isinstance(dataset, dict) or isinstance(
29
+ dataset, Config) or isinstance(dataset, ConfigDict):
30
+ dataset = BUILDER.build(dataset)
31
+ if isinstance(dataset, DatasetDict):
32
+ if split is None:
33
+ dataset = concatenate_datasets(dataset.values())
34
+ else:
35
+ dataset = dataset[split]
36
+ return dataset
37
+
38
+
39
+ def map_dataset(dataset, dataset_map_fn, map_num_proc):
40
+ if isinstance(dataset_map_fn, str):
41
+ map_fn_obj = MAP_FUNC.get(dataset_map_fn) or get_object_from_string(
42
+ dataset_map_fn)
43
+ if map_fn_obj is not None:
44
+ dataset_map_fn = map_fn_obj
45
+ else:
46
+ raise TypeError('dataset_map_fn must be a function or a '
47
+ "registered function's string in MAP_FUNC, "
48
+ f"but got a string of '{dataset_map_fn}'")
49
+
50
+ dataset = dataset.map(dataset_map_fn, num_proc=map_num_proc)
51
+ return dataset
52
+
53
+
54
+ def add_template_to_dataset(dataset, template_map_fn, map_num_proc):
55
+ if isinstance(template_map_fn,
56
+ dict) or isinstance(template_map_fn, Config) or isinstance(
57
+ template_map_fn, ConfigDict):
58
+ template_map_fn = BUILDER.build(template_map_fn)
59
+ dataset = dataset.map(template_map_fn, num_proc=map_num_proc)
60
+ # remove invalid data
61
+ dataset = dataset.filter(
62
+ lambda example: len(example['conversation']) > 0,
63
+ num_proc=map_num_proc)
64
+ return dataset
65
+
66
+
67
+ def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
68
+ input_ids_with_output, remove_unused_columns,
69
+ map_num_proc):
70
+ assert (tokenizer is not None) and (max_length is not None), \
71
+ f'({tokenizer}, {max_length})'
72
+ if isinstance(tokenizer, dict) or isinstance(
73
+ tokenizer, Config) or isinstance(tokenizer, ConfigDict):
74
+ tokenizer = BUILDER.build(tokenizer)
75
+ dataset = dataset.map(
76
+ partial(
77
+ encode_fn,
78
+ tokenizer=tokenizer,
79
+ max_length=max_length,
80
+ with_image_token=with_image_token,
81
+ input_ids_with_output=input_ids_with_output),
82
+ remove_columns=list(dataset.column_names)
83
+ if remove_unused_columns else None,
84
+ num_proc=map_num_proc)
85
+ return dataset
86
+
87
+
88
+ def pack_dataset(dataset, max_length, use_varlen_attn, shuffle_before_pack,
89
+ map_num_proc):
90
+ if shuffle_before_pack:
91
+ dataset = dataset.shuffle()
92
+ dataset = dataset.flatten_indices(num_proc=map_num_proc)
93
+ dataset = dataset.map(
94
+ Packer(max_length, use_varlen_attn=use_varlen_attn),
95
+ batched=True,
96
+ num_proc=map_num_proc)
97
+ return dataset
98
+
99
+
100
+ def process(dataset,
101
+ do_dataset_tokenization=True,
102
+ tokenizer=None,
103
+ max_length=None,
104
+ dataset_map_fn=None,
105
+ template_map_fn=None,
106
+ max_dataset_length=None,
107
+ split='train',
108
+ remove_unused_columns=False,
109
+ rename_maps=[],
110
+ shuffle_before_pack=True,
111
+ pack_to_max_length=True,
112
+ use_varlen_attn=False,
113
+ input_ids_with_output=True,
114
+ with_image_token=False,
115
+ map_num_proc=4):
116
+ """Post-process the dataset loaded from the Hugging Face Hub, or a local
117
+ dataset.
118
+
119
+ Args:
120
+ dataset: The dataset to be post-processed.
121
+ do_dataset_tokenization: Whether the dataset need to be tokenized
122
+ in this function. Default to True.
123
+ tokenizer: The tokenizer processes some raw text as input and outputs
124
+ an Encoding. If `do_dataset_tokenization` is True, this argument
125
+ should not be None. Default to None.
126
+ max_length: Max length of the sequence. If `do_dataset_tokenization`
127
+ or `pack_to_max_length` is True, this argument should not be None.
128
+ Default to None.
129
+ dataset_map_fn: Map the original dataset format to the one defined
130
+ by xTuner.
131
+ template_map_fn: Add the prompt template to the dataset
132
+ max_dataset_length: If the length of the dataset is too long, we can
133
+ randomly extract `max_dataset_length` from it.
134
+ split: Which split of the data to load.
135
+ If `None`, will return a single concatenated dataset with all
136
+ splits (typically `datasets.Split.TRAIN` and
137
+ `datasets.Split.TEST`).
138
+ If given, will return a single Dataset.
139
+ remove_unused_columns: Whether to remove columns from the dataset
140
+ that are not used during training.
141
+ rename_maps: Rename the column name of the dataset.
142
+ shuffle_before_pack: Whether to shuffle the dataset before
143
+ packing them.
144
+ pack_to_max_length: Whether to pack the dataset to the `max_length `.
145
+ This usually improves gpu utilization and therefore reduces
146
+ training time.
147
+ use_varlen_attn: If use_varlen_attn is True, we calculate attention
148
+ the actual length of the sequence rather than the actual length
149
+ of the sequence
150
+ input_ids_with_output: Whether to put the groundtruth output
151
+ corresponding to the question into the dataset. Typically set
152
+ it to True during training and False during testing.
153
+ with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
154
+ IMAGE_TOKEN_INDEX. Typically set it to True during the training
155
+ of VLM.
156
+ map_num_proc: Max number of processes when mapping the dataset.
157
+ """
158
+ if use_varlen_attn:
159
+ assert pack_to_max_length, \
160
+ '`pack_to_max_length` in `process_hf_dataset` should be set to ' \
161
+ 'True if `use_varlen_attn` is True.'
162
+ if pack_to_max_length:
163
+ assert split == 'train' or split is None, \
164
+ ('`split` should be `train` or `None` if `pack_to_max_length` is '
165
+ f'True, but got {split}.')
166
+
167
+ dataset = build_origin_dataset(dataset, split)
168
+
169
+ # sample `max_dataset_length` items from the original dataset to
170
+ # save time consumed by map function
171
+ if max_dataset_length is not None:
172
+ max_dataset_length = min(max_dataset_length, len(dataset))
173
+ indices = np.random.choice(
174
+ len(dataset), max_dataset_length, replace=False)
175
+ dataset = dataset.select(indices)
176
+
177
+ # Extract the useful data for training from the original dataset.
178
+ if dataset_map_fn is not None:
179
+ dataset = map_dataset(dataset, dataset_map_fn, map_num_proc)
180
+
181
+ # Add prompt template, such as <|System|>: xxx <|User|>: xxx <|Bot|>: xxx
182
+ if template_map_fn is not None:
183
+ dataset = add_template_to_dataset(dataset, template_map_fn,
184
+ map_num_proc)
185
+
186
+ for old, new in rename_maps:
187
+ dataset = dataset.rename_column(old, new)
188
+
189
+ # remove unused columns
190
+ if pack_to_max_length and (not remove_unused_columns):
191
+ print_log(
192
+ 'We have to remove unused columns if '
193
+ '`pack_to_max_length` is set to True.',
194
+ logger='current',
195
+ level=logging.WARNING)
196
+ remove_unused_columns = True
197
+
198
+ if do_dataset_tokenization:
199
+ dataset = tokenize_dataset(dataset, tokenizer, max_length,
200
+ with_image_token, input_ids_with_output,
201
+ remove_unused_columns, map_num_proc)
202
+ else:
203
+ assert {'input_ids', 'labels'}.issubset(dataset.column_names)
204
+
205
+ if input_ids_with_output:
206
+ # remove data that does not have the valid labels.
207
+ dataset = dataset.filter(
208
+ lambda example: any(label >= 0 for label in example['labels']),
209
+ num_proc=map_num_proc)
210
+
211
+ # pack to max length
212
+ if pack_to_max_length:
213
+ dataset = pack_dataset(dataset, max_length, use_varlen_attn,
214
+ shuffle_before_pack, map_num_proc)
215
+
216
+ # add 'length'
217
+ dataset = dataset.map(get_lengths, num_proc=map_num_proc)
218
+ setattr(dataset, 'length', dataset['length'])
219
+
220
+ return dataset
221
+
222
+
223
+ def process_hf_dataset(*args, **kwargs):
224
+ if not (dist.is_available() and dist.is_initialized()):
225
+ return process(*args, **kwargs)
226
+
227
+ xtuner_dataset_timeout = timedelta(
228
+ minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=30)))
229
+ print_log(
230
+ f'xtuner_dataset_timeout = {xtuner_dataset_timeout}', logger='current')
231
+ # monitored barrier requires gloo process group to perform host-side sync.
232
+ group_gloo = dist.new_group(backend='gloo', timeout=xtuner_dataset_timeout)
233
+
234
+ if dist.get_rank() == 0:
235
+ dataset = process(*args, **kwargs)
236
+ objects = [dataset]
237
+ else:
238
+ objects = [None]
239
+
240
+ dist.monitored_barrier(group=group_gloo, timeout=xtuner_dataset_timeout)
241
+ dist.broadcast_object_list(objects, src=0)
242
+ return objects[0]