Tonylin52 committed on
Commit 9214dde · 1 parent: 91ce4f3

Delete modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py +0 -1261
modeling_chatglm.py DELETED
@@ -1,1261 +0,0 @@
1
- """ PyTorch ChatGLM model. """
2
-
3
- import math
4
- import copy
5
- import os
6
- import warnings
7
- import re
8
-
9
- import torch
10
- import torch.utils.checkpoint
11
- import torch.nn.functional as F
12
- from torch import nn
13
- from torch.nn import CrossEntropyLoss, LayerNorm
14
- from torch.nn.utils import skip_init
15
- from typing import Optional, Tuple, Union, List, Callable
16
-
17
- from transformers.utils import (
18
- add_code_sample_docstrings,
19
- add_start_docstrings,
20
- add_start_docstrings_to_model_forward,
21
- )
22
- from transformers.modeling_outputs import (
23
- BaseModelOutputWithPast,
24
- CausalLMOutputWithPast,
25
- BaseModelOutputWithPastAndCrossAttentions,
26
- )
27
- from transformers.modeling_utils import PreTrainedModel
28
- from transformers.utils import logging
29
- from transformers.generation.logits_process import LogitsProcessor
30
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
31
-
32
- from .configuration_chatglm import ChatGLMConfig
33
-
34
- # flags required to enable jit fusion kernels
35
- torch._C._jit_set_profiling_mode(False)
36
- torch._C._jit_set_profiling_executor(False)
37
- torch._C._jit_override_can_fuse_on_cpu(True)
38
- torch._C._jit_override_can_fuse_on_gpu(True)
39
-
40
- logger = logging.get_logger(__name__)
41
-
42
- _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
43
- _CONFIG_FOR_DOC = "ChatGLM6BConfig"
44
-
45
- CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
46
- "THUDM/chatglm-6b",
47
- # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
48
- ]
49
-
50
-
51
- class InvalidScoreLogitsProcessor(LogitsProcessor):
52
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
53
- if torch.isnan(scores).any() or torch.isinf(scores).any():
54
- scores.zero_()
55
- scores[..., 20005] = 5e4
56
- return scores
57
-
58
-
59
- def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
60
- """Load tf checkpoints in a pytorch model."""
61
- try:
62
- import re
63
-
64
- import numpy as np
65
- import tensorflow as tf
66
- except ImportError:
67
- logger.error(
68
- "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
69
- "https://www.tensorflow.org/install/ for installation instructions."
70
- )
71
- raise
72
- tf_path = os.path.abspath(tf_checkpoint_path)
73
- logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
74
- # Load weights from TF model
75
- init_vars = tf.train.list_variables(tf_path)
76
- names = []
77
- arrays = []
78
- for name, shape in init_vars:
79
- logger.info(f"Loading TF weight {name} with shape {shape}")
80
- array = tf.train.load_variable(tf_path, name)
81
- names.append(name)
82
- arrays.append(array)
83
-
84
- for name, array in zip(names, arrays):
85
- name = name.split("/")
86
- # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
87
- # which are not required for using pretrained model
88
- if any(
89
- n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
90
- for n in name
91
- ):
92
- logger.info(f"Skipping {'/'.join(name)}")
93
- continue
94
- pointer = model
95
- for m_name in name:
96
- if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
97
- scope_names = re.split(r"_(\d+)", m_name)
98
- else:
99
- scope_names = [m_name]
100
- if scope_names[0] == "kernel" or scope_names[0] == "gamma":
101
- pointer = getattr(pointer, "weight")
102
- elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
103
- pointer = getattr(pointer, "bias")
104
- elif scope_names[0] == "output_weights":
105
- pointer = getattr(pointer, "weight")
106
- elif scope_names[0] == "squad":
107
- pointer = getattr(pointer, "classifier")
108
- else:
109
- try:
110
- pointer = getattr(pointer, scope_names[0])
111
- except AttributeError:
112
- logger.info(f"Skipping {'/'.join(name)}")
113
- continue
114
- if len(scope_names) >= 2:
115
- num = int(scope_names[1])
116
- pointer = pointer[num]
117
- if m_name[-11:] == "_embeddings":
118
- pointer = getattr(pointer, "weight")
119
- elif m_name == "kernel":
120
- array = np.transpose(array)
121
- try:
122
- assert (
123
- pointer.shape == array.shape
124
- ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
125
- except AssertionError as e:
126
- e.args += (pointer.shape, array.shape)
127
- raise
128
- logger.info(f"Initialize PyTorch weight {name}")
129
- pointer.data = torch.from_numpy(array)
130
- return model
131
-
132
-
133
- @torch.jit.script
134
- def gelu_impl(x):
135
- """OpenAI's gelu implementation."""
136
- return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
137
- (1.0 + 0.044715 * x * x)))
138
-
139
-
140
- def gelu(x):
141
- return gelu_impl(x)
142
-
143
-
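For reference, the hand-written approximation above is the standard tanh GELU; a minimal sanity check (a sketch, assuming PyTorch >= 1.12 for the approximate= keyword and that the gelu defined above is importable):

import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=101)
# 0.7978845608028654 is sqrt(2/pi), so gelu() should match the built-in tanh approximation
assert torch.allclose(gelu(x), F.gelu(x, approximate="tanh"), atol=1e-4)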
144
- class RotaryEmbedding(torch.nn.Module):
145
- def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
146
- super().__init__()
147
- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
148
- inv_freq = inv_freq.half()
149
- self.learnable = learnable
150
- if learnable:
151
- self.inv_freq = torch.nn.Parameter(inv_freq)
152
- self.max_seq_len_cached = None
153
- else:
154
- self.register_buffer('inv_freq', inv_freq)
155
- self.max_seq_len_cached = None
156
- self.cos_cached = None
157
- self.sin_cached = None
158
- self.precision = precision
159
-
160
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
161
- error_msgs):
162
- pass
163
-
164
- def forward(self, x, seq_dim=1, seq_len=None):
165
- if seq_len is None:
166
- seq_len = x.shape[seq_dim]
167
- if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
168
- self.max_seq_len_cached = None if self.learnable else seq_len
169
- t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
170
- freqs = torch.einsum('i,j->ij', t, self.inv_freq)
171
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
172
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
173
- if self.precision == torch.bfloat16:
174
- emb = emb.float()
175
-
176
- # [sx, 1 (b * np), hn]
177
- cos_cached = emb.cos()[:, None, :]
178
- sin_cached = emb.sin()[:, None, :]
179
- if self.precision == torch.bfloat16:
180
- cos_cached = cos_cached.bfloat16()
181
- sin_cached = sin_cached.bfloat16()
182
- if self.learnable:
183
- return cos_cached, sin_cached
184
- self.cos_cached, self.sin_cached = cos_cached, sin_cached
185
- return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
186
-
187
-
188
- def rotate_half(x):
189
- x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
190
- return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
191
-
192
-
193
- @torch.jit.script
194
- def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
195
- # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
196
- cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
197
- F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
198
- q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
199
- return q, k
200
-
201
-
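A minimal sketch of how the two rotary helpers above fit together (toy shapes; the cos/sin tables are built inline rather than through RotaryEmbedding, whose buffers are hard-coded to float16, and apply_rotary_pos_emb_index is assumed importable):

import torch

sq, b, nh, hn = 4, 1, 2, 8                                 # seq_len, batch, heads, head_dim
q = torch.randn(sq, b, nh, hn)
k = torch.randn(sq, b, nh, hn)

inv_freq = 1.0 / (10000 ** (torch.arange(0, hn, 2).float() / hn))
freqs = torch.einsum("i,j->ij", torch.arange(sq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]    # [sq, 1, hn]

position_ids = torch.arange(sq).unsqueeze(1)               # [sq, b]
q_rot, k_rot = apply_rotary_pos_emb_index(q, k, cos, sin, position_ids)
print(q_rot.shape)                                         # torch.Size([4, 1, 2, 8])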
202
- def attention_fn(
203
- self,
204
- query_layer,
205
- key_layer,
206
- value_layer,
207
- attention_mask,
208
- hidden_size_per_partition,
209
- layer_id,
210
- layer_past=None,
211
- scaling_attention_score=True,
212
- use_cache=False,
213
- ):
214
- if layer_past is not None:
215
- past_key, past_value = layer_past
216
- key_layer = torch.cat((past_key, key_layer), dim=0)
217
- value_layer = torch.cat((past_value, value_layer), dim=0)
218
-
219
- # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
220
- seq_len, b, nh, hidden_size = key_layer.shape
221
-
222
- if use_cache:
223
- present = (key_layer, value_layer)
224
- else:
225
- present = None
226
-
227
- query_key_layer_scaling_coeff = float(layer_id + 1)
228
- if scaling_attention_score:
229
- query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
230
-
231
- # ===================================
232
- # Raw attention scores. [b, np, s, s]
233
- # ===================================
234
-
235
- # [b, np, sq, sk]
236
- output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
237
-
238
- # [sq, b, np, hn] -> [sq, b * np, hn]
239
- query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
240
- # [sk, b, np, hn] -> [sk, b * np, hn]
241
- key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
242
-
243
- matmul_result = torch.empty(
244
- output_size[0] * output_size[1],
245
- output_size[2],
246
- output_size[3],
247
- dtype=query_layer.dtype,
248
- device=query_layer.device,
249
- )
250
-
251
- matmul_result = torch.baddbmm(
252
- matmul_result,
253
- query_layer.transpose(0, 1), # [b * np, sq, hn]
254
- key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
255
- beta=0.0,
256
- alpha=1.0,
257
- )
258
-
259
- # change view to [b, np, sq, sk]
260
- attention_scores = matmul_result.view(*output_size)
261
-
262
- if self.scale_mask_softmax:
263
- self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
264
- attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
265
- else:
266
- if not (attention_mask == 0).all():
267
- # if auto-regressive, skip
268
- attention_scores.masked_fill_(attention_mask, -10000.0)
269
- dtype = attention_scores.type()
270
- attention_scores = attention_scores.float()
271
- attention_scores = attention_scores * query_key_layer_scaling_coeff
272
-
273
- attention_probs = F.softmax(attention_scores, dim=-1)
274
-
275
- attention_probs = attention_probs.type(dtype)
276
-
277
- # =========================
278
- # Context layer. [sq, b, hp]
279
- # =========================
280
-
281
- # value_layer -> context layer.
282
- # [sk, b, np, hn] --> [b, np, sq, hn]
283
-
284
- # context layer shape: [b, np, sq, hn]
285
- output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
286
-
287
- # change view [sk, b * np, hn]
288
- value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
289
-
290
- # change view [b * np, sq, sk]
291
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
292
-
293
- # matmul: [b * np, sq, hn]
294
- context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
295
-
296
- # change view [b, np, sq, hn]
297
- context_layer = context_layer.view(*output_size)
298
-
299
- # [b, np, sq, hn] --> [sq, b, np, hn]
300
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
301
-
302
- # [sq, b, np, hn] --> [sq, b, hp]
303
- new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
304
- context_layer = context_layer.view(*new_context_layer_shape)
305
-
306
- outputs = (context_layer, present, attention_probs)
307
-
308
- return outputs
309
-
310
-
311
- class SelfAttention(torch.nn.Module):
312
- def __init__(self, hidden_size, num_attention_heads,
313
- layer_id, hidden_size_per_attention_head=None, bias=True,
314
- params_dtype=torch.float, position_encoding_2d=True):
315
- super(SelfAttention, self).__init__()
316
-
317
- self.layer_id = layer_id
318
- self.hidden_size = hidden_size
319
- self.hidden_size_per_partition = hidden_size
320
- self.num_attention_heads = num_attention_heads
321
- self.num_attention_heads_per_partition = num_attention_heads
322
- self.position_encoding_2d = position_encoding_2d
323
- self.rotary_emb = RotaryEmbedding(
324
- self.hidden_size // (self.num_attention_heads * 2)
325
- if position_encoding_2d
326
- else self.hidden_size // self.num_attention_heads,
327
- base=10000,
328
- precision=torch.half,
329
- learnable=False,
330
- )
331
-
332
- self.scale_mask_softmax = None
333
-
334
- if hidden_size_per_attention_head is None:
335
- self.hidden_size_per_attention_head = hidden_size // num_attention_heads
336
- else:
337
- self.hidden_size_per_attention_head = hidden_size_per_attention_head
338
-
339
- self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
340
-
341
- # Strided linear layer.
342
- self.query_key_value = skip_init(
343
- torch.nn.Linear,
344
- hidden_size,
345
- 3 * self.inner_hidden_size,
346
- bias=bias,
347
- dtype=params_dtype,
348
- )
349
-
350
- self.dense = skip_init(
351
- torch.nn.Linear,
352
- self.inner_hidden_size,
353
- hidden_size,
354
- bias=bias,
355
- dtype=params_dtype,
356
- )
357
-
358
- @staticmethod
359
- def attention_mask_func(attention_scores, attention_mask):
360
- attention_scores.masked_fill_(attention_mask, -10000.0)
361
- return attention_scores
362
-
363
- def split_tensor_along_last_dim(self, tensor, num_partitions,
364
- contiguous_split_chunks=False):
365
- """Split a tensor along its last dimension.
366
- Arguments:
367
- tensor: input tensor.
368
- num_partitions: number of partitions to split the tensor
369
- contiguous_split_chunks: If True, make each chunk contiguous
370
- in memory.
371
- """
372
- # Get the size and dimension.
373
- last_dim = tensor.dim() - 1
374
- last_dim_size = tensor.size()[last_dim] // num_partitions
375
- # Split.
376
- tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
377
- # Note: torch.split does not create contiguous tensors by default.
378
- if contiguous_split_chunks:
379
- return tuple(chunk.contiguous() for chunk in tensor_list)
380
-
381
- return tensor_list
382
-
383
- def forward(
384
- self,
385
- hidden_states: torch.Tensor,
386
- position_ids,
387
- attention_mask: torch.Tensor,
388
- layer_id,
389
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
390
- use_cache: bool = False,
391
- output_attentions: bool = False,
392
- ):
393
- """
394
- hidden_states: [seq_len, batch, hidden_size]
395
- attention_mask: [(1, 1), seq_len, seq_len]
396
- """
397
-
398
- # [seq_len, batch, 3 * hidden_size]
399
- mixed_raw_layer = self.query_key_value(hidden_states)
400
-
401
- # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
402
- new_tensor_shape = mixed_raw_layer.size()[:-1] + (
403
- self.num_attention_heads_per_partition,
404
- 3 * self.hidden_size_per_attention_head,
405
- )
406
- mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)
407
-
408
- # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
409
- (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)
410
-
411
- if self.position_encoding_2d:
412
- q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
413
- k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
414
- cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
415
- position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
416
- position_ids[:, 1, :].transpose(0, 1).contiguous()
417
- q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
418
- q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
419
- query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
420
- key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
421
- else:
422
- position_ids = position_ids.transpose(0, 1)
423
- cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
424
- # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
425
- query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)
426
-
427
- # [seq_len, batch, hidden_size]
428
- context_layer, present, attention_probs = attention_fn(
429
- self=self,
430
- query_layer=query_layer,
431
- key_layer=key_layer,
432
- value_layer=value_layer,
433
- attention_mask=attention_mask,
434
- hidden_size_per_partition=self.hidden_size_per_partition,
435
- layer_id=layer_id,
436
- layer_past=layer_past,
437
- use_cache=use_cache
438
- )
439
-
440
- output = self.dense(context_layer)
441
-
442
- outputs = (output, present)
443
-
444
- if output_attentions:
445
- outputs += (attention_probs,)
446
-
447
- return outputs # output, present, attention_probs
448
-
449
-
450
- class GEGLU(torch.nn.Module):
451
- def __init__(self):
452
- super().__init__()
453
- self.activation_fn = F.gelu
454
-
455
- def forward(self, x):
456
- # dim=-1 breaks in jit for pt<1.10
457
- x1, x2 = x.chunk(2, dim=(x.ndim - 1))
458
- return x1 * self.activation_fn(x2)
459
-
460
-
461
- class GLU(torch.nn.Module):
462
- def __init__(self, hidden_size, inner_hidden_size=None,
463
- layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
464
- super(GLU, self).__init__()
465
- self.layer_id = layer_id
466
- self.activation_func = activation_func
467
-
468
- # Project to 4h.
469
- self.hidden_size = hidden_size
470
- if inner_hidden_size is None:
471
- inner_hidden_size = 4 * hidden_size
472
- self.inner_hidden_size = inner_hidden_size
473
- self.dense_h_to_4h = skip_init(
474
- torch.nn.Linear,
475
- self.hidden_size,
476
- self.inner_hidden_size,
477
- bias=bias,
478
- dtype=params_dtype,
479
- )
480
- # Project back to h.
481
- self.dense_4h_to_h = skip_init(
482
- torch.nn.Linear,
483
- self.inner_hidden_size,
484
- self.hidden_size,
485
- bias=bias,
486
- dtype=params_dtype,
487
- )
488
-
489
- def forward(self, hidden_states):
490
- """
491
- hidden_states: [seq_len, batch, hidden_size]
492
- """
493
-
494
- # [seq_len, batch, inner_hidden_size]
495
- intermediate_parallel = self.dense_h_to_4h(hidden_states)
496
-
497
- intermediate_parallel = self.activation_func(intermediate_parallel)
498
-
499
- output = self.dense_4h_to_h(intermediate_parallel)
500
-
501
- return output
502
-
503
-
504
- class GLMBlock(torch.nn.Module):
505
- def __init__(
506
- self,
507
- hidden_size,
508
- num_attention_heads,
509
- layernorm_epsilon,
510
- layer_id,
511
- inner_hidden_size=None,
512
- hidden_size_per_attention_head=None,
513
- layernorm=LayerNorm,
514
- use_bias=True,
515
- params_dtype=torch.float,
516
- num_layers=28,
517
- position_encoding_2d=True
518
- ):
519
- super(GLMBlock, self).__init__()
520
- # Set output layer initialization if not provided.
521
-
522
- self.layer_id = layer_id
523
-
524
- # Layernorm on the input data.
525
- self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
526
-
527
- self.position_encoding_2d = position_encoding_2d
528
-
529
- # Self attention.
530
- self.attention = SelfAttention(
531
- hidden_size,
532
- num_attention_heads,
533
- layer_id,
534
- hidden_size_per_attention_head=hidden_size_per_attention_head,
535
- bias=use_bias,
536
- params_dtype=params_dtype,
537
- position_encoding_2d=self.position_encoding_2d
538
- )
539
-
540
- # Layernorm after the self-attention.
541
- self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
542
-
543
- self.num_layers = num_layers
544
-
545
- # GLU
546
- self.mlp = GLU(
547
- hidden_size,
548
- inner_hidden_size=inner_hidden_size,
549
- bias=use_bias,
550
- layer_id=layer_id,
551
- params_dtype=params_dtype,
552
- )
553
-
554
- def forward(
555
- self,
556
- hidden_states: torch.Tensor,
557
- position_ids,
558
- attention_mask: torch.Tensor,
559
- layer_id,
560
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
561
- use_cache: bool = False,
562
- output_attentions: bool = False,
563
- ):
564
- """
565
- hidden_states: [seq_len, batch, hidden_size]
566
- attention_mask: [(1, 1), seq_len, seq_len]
567
- """
568
-
569
- # Layer norm at the beginning of the transformer layer.
570
- # [seq_len, batch, hidden_size]
571
- attention_input = self.input_layernorm(hidden_states)
572
-
573
- # Self attention.
574
- attention_outputs = self.attention(
575
- attention_input,
576
- position_ids,
577
- attention_mask=attention_mask,
578
- layer_id=layer_id,
579
- layer_past=layer_past,
580
- use_cache=use_cache,
581
- output_attentions=output_attentions
582
- )
583
-
584
- attention_output = attention_outputs[0]
585
-
586
- outputs = attention_outputs[1:]
587
-
588
- # Residual connection.
589
- alpha = (2 * self.num_layers) ** 0.5
590
- hidden_states = attention_input * alpha + attention_output
591
-
592
- mlp_input = self.post_attention_layernorm(hidden_states)
593
-
594
- # MLP.
595
- mlp_output = self.mlp(mlp_input)
596
-
597
- # Second residual connection.
598
- output = mlp_input * alpha + mlp_output
599
-
600
- if use_cache:
601
- outputs = (output,) + outputs
602
- else:
603
- outputs = (output,) + outputs[1:]
604
-
605
- return outputs # hidden_states, present, attentions
606
-
607
-
608
- class ChatGLMPreTrainedModel(PreTrainedModel):
609
- """
610
- An abstract class to handle weights initialization and
611
- a simple interface for downloading and loading pretrained models.
612
- """
613
-
614
- is_parallelizable = False
615
- supports_gradient_checkpointing = False
616
- config_class = ChatGLMConfig
617
- base_model_prefix = "transformer"
618
- _no_split_modules = ["GLMBlock"]
619
-
620
- def __init__(self, *inputs, **kwargs):
621
- super().__init__(*inputs, **kwargs)
622
-
623
- def _init_weights(self, module: nn.Module):
624
- """Initialize the weights."""
625
- return
626
-
627
-
628
- CHATGLM_6B_START_DOCSTRING = r"""
629
- This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
630
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
631
- usage and behavior.
632
-
633
- Parameters:
634
- config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
635
- Initializing with a config file does not load the weights associated with the model, only the configuration.
636
- Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
637
- """
638
-
639
- CHATGLM_6B_INPUTS_DOCSTRING = r"""
640
- Args:
641
- input_ids (`torch.LongTensor` of shape `({0})`):
642
- Indices of input sequence tokens in the vocabulary.
643
-
644
- Indices can be obtained using [`ChatGLM6BTokenizer`].
645
- See [`PreTrainedTokenizer.encode`] and
646
- [`PreTrainedTokenizer.__call__`] for details.
647
-
648
- [What are input IDs?](../glossary#input-ids)
649
- attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
650
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
651
-
652
- - 1 for tokens that are **not masked**,
653
- - 0 for tokens that are **masked**.
654
-
655
- [What are attention masks?](../glossary#attention-mask)
656
- token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
657
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
658
-
659
- - 0 corresponds to a *sentence A* token,
660
- - 1 corresponds to a *sentence B* token.
661
-
662
- [What are token type IDs?](../glossary#token-type-ids)
663
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
664
- Indices of positions of each input sequence tokens in the position embeddings.
665
- Selected in the range `[0, config.max_position_embeddings - 1]`.
666
-
667
- [What are position IDs?](../glossary#position-ids)
668
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
669
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
670
-
671
- - 1 indicates the head is **not masked**,
672
- - 0 indicates the head is **masked**.
673
-
674
- inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
675
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
676
- This is useful if you want more control over how to convert *input_ids* indices into associated vectors
677
- than the model's internal embedding lookup matrix.
678
- output_attentions (`bool`, *optional*):
679
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
680
- tensors for more detail.
681
- output_hidden_states (`bool`, *optional*):
682
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
683
- more detail.
684
- return_dict (`bool`, *optional*):
685
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
686
- """
687
-
688
-
689
- @add_start_docstrings(
690
- "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
691
- CHATGLM_6B_START_DOCSTRING,
692
- )
693
- class ChatGLMModel(ChatGLMPreTrainedModel):
694
- """
695
-
696
- The model can behave as an encoder (with only self-attention) as well
697
- as a decoder, in which case a layer of cross-attention is added between
698
- the self-attention layers, following the architecture described in [Attention is
699
- all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
700
- Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
701
-
702
- To behave as a decoder the model needs to be initialized with the
703
- `is_decoder` argument of the configuration set to `True`.
704
- To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder`
705
- argument and `add_cross_attention` set to `True`; an
706
- `encoder_hidden_states` is then expected as an input to the forward pass.
707
- """
708
-
709
- def __init__(self, config: ChatGLMConfig):
710
- super().__init__(config)
711
-
712
- # recording parameters
713
- self.max_sequence_length = config.max_sequence_length
714
- self.hidden_size = config.hidden_size
715
- self.params_dtype = torch.half
716
- self.num_attention_heads = config.num_attention_heads
717
- self.vocab_size = config.vocab_size
718
- self.num_layers = config.num_layers
719
- self.layernorm_epsilon = config.layernorm_epsilon
720
- self.inner_hidden_size = config.inner_hidden_size
721
- self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
722
- self.position_encoding_2d = config.position_encoding_2d
723
-
724
- self.word_embeddings = skip_init(
725
- torch.nn.Embedding,
726
- num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
727
- dtype=self.params_dtype
728
- )
729
-
730
- def get_layer(layer_id):
731
- return GLMBlock(
732
- self.hidden_size,
733
- self.num_attention_heads,
734
- self.layernorm_epsilon,
735
- layer_id,
736
- inner_hidden_size=self.inner_hidden_size,
737
- hidden_size_per_attention_head=self.hidden_size_per_attention_head,
738
- layernorm=LayerNorm,
739
- use_bias=True,
740
- params_dtype=self.params_dtype,
741
- position_encoding_2d=self.position_encoding_2d,
742
- )
743
-
744
- self.layers = torch.nn.ModuleList(
745
- [get_layer(layer_id) for layer_id in range(self.num_layers)]
746
- )
747
-
748
- # Final layer norm before output.
749
- self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
750
-
751
- def get_input_embeddings(self):
752
- return self.word_embeddings
753
-
754
- def set_input_embeddings(self, new_embeddings: torch.Tensor):
755
- self.word_embeddings = new_embeddings
756
-
757
- def get_masks(self, seq, device):
758
- context_length = seq.index(self.config.bos_token_id) + 1
759
-
760
- attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
761
- attention_mask.tril_()
762
- attention_mask[..., :context_length - 1] = 1
763
- attention_mask.unsqueeze_(1)
764
- attention_mask = (attention_mask < 0.5).bool()
765
-
766
- return attention_mask
767
-
768
- def get_position_ids(self, seq, mask_position, device, gmask=False):
769
- context_length = seq.index(self.config.bos_token_id) + 1
770
- if self.position_encoding_2d:
771
- seq_length = seq.index(self.config.bos_token_id)
772
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
773
- if not gmask:
774
- position_ids[seq_length:] = mask_position
775
- block_position_ids = torch.cat((
776
- torch.zeros(seq_length, dtype=torch.long, device=device),
777
- torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
778
- ))
779
- position_ids = torch.stack((position_ids, block_position_ids), dim=0)
780
- else:
781
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
782
- if not gmask:
783
- position_ids[context_length - 1:] = mask_position
784
-
785
- position_ids = position_ids.unsqueeze(0)
786
-
787
- return position_ids
788
-
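To make the prefix attention pattern above concrete, here is a toy walk-through of get_masks (the token ids and the bos id are placeholders; the real values come from ChatGLMConfig). Rows are query positions and columns are key positions; prompt columns before the bos token stay visible to every row, later columns are causal, and True marks positions that attention_fn fills with -10000:

import torch

seq = [101, 102, 103, 150001, 150004, 7, 8]   # hypothetical prompt, [gMASK], <bos>, two generated ids
bos_token_id = 150004                          # placeholder for config.bos_token_id
context_length = seq.index(bos_token_id) + 1   # 5

mask = torch.ones((1, len(seq), len(seq)))
mask.tril_()
mask[..., :context_length - 1] = 1             # prompt columns visible from every row
mask.unsqueeze_(1)
mask = (mask < 0.5).bool()                     # True = masked out
print(mask[0, 0].int())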
789
- @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
790
- @add_code_sample_docstrings(
791
- checkpoint=_CHECKPOINT_FOR_DOC,
792
- output_type=BaseModelOutputWithPastAndCrossAttentions,
793
- config_class=_CONFIG_FOR_DOC,
794
- )
795
- def forward(
796
- self,
797
- input_ids: Optional[torch.LongTensor] = None,
798
- position_ids: Optional[torch.LongTensor] = None,
799
- attention_mask: Optional[torch.Tensor] = None,
800
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
801
- inputs_embeds: Optional[torch.LongTensor] = None,
802
- use_cache: Optional[bool] = None,
803
- output_attentions: Optional[bool] = None,
804
- output_hidden_states: Optional[bool] = None,
805
- return_dict: Optional[bool] = None,
806
- ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
807
-
808
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
809
- output_hidden_states = (
810
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
811
- )
812
- use_cache = use_cache if use_cache is not None else self.config.use_cache
813
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
814
-
815
- if input_ids is not None and inputs_embeds is not None:
816
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
817
- elif input_ids is not None:
818
- batch_size, seq_length = input_ids.shape[:2]
819
- elif inputs_embeds is not None:
820
- batch_size, seq_length, _ = inputs_embeds.shape[:2]
821
- else:
822
- raise ValueError("You have to specify either input_ids or inputs_embeds")
823
-
824
- if past_key_values is None:
825
- past_key_values = tuple([None] * len(self.layers))
826
- seq = input_ids[0].tolist()
827
-
828
- if attention_mask is None:
829
- attention_mask = self.get_masks(
830
- seq=seq,
831
- device=input_ids.device
832
- )
833
-
834
- if position_ids is None:
835
- MASK, gMASK = 150000, 150001
836
- mask_token = MASK if MASK in input_ids else gMASK
837
- use_gmask = False if MASK in input_ids else True
838
-
839
- mask_position = seq.index(mask_token)
840
- position_ids = self.get_position_ids(
841
- seq=seq,
842
- mask_position=mask_position,
843
- device=input_ids.device,
844
- gmask=use_gmask
845
- )
846
-
847
- if inputs_embeds is None:
848
- inputs_embeds = self.word_embeddings(input_ids)
849
-
850
- # [seq_len, batch, hidden_size]
851
- hidden_states = inputs_embeds.transpose(0, 1)
852
-
853
- presents = () if use_cache else None
854
- all_self_attentions = () if output_attentions else None
855
- all_hidden_states = () if output_hidden_states else None
856
-
857
- seq_length_with_past = seq_length
858
- past_key_values_length = 0
859
- if past_key_values[0] is not None:
860
- past_key_values_length = past_key_values[0][0].shape[0]
861
- seq_length_with_past = seq_length_with_past + past_key_values_length
862
- if attention_mask is None:
863
- attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
864
-
865
- else:
866
- attention_mask = attention_mask.to(input_ids.device)
867
-
868
- for i, layer in enumerate(self.layers):
869
-
870
- if output_hidden_states:
871
- all_hidden_states = all_hidden_states + (hidden_states,)
872
-
873
- layer_ret = layer(
874
- hidden_states,
875
- position_ids=position_ids,
876
- attention_mask=attention_mask,
877
- layer_id=torch.tensor(i),
878
- layer_past=past_key_values[i],
879
- use_cache=use_cache,
880
- output_attentions=output_attentions
881
- )
882
-
883
- hidden_states = layer_ret[0]
884
-
885
- if use_cache:
886
- presents = presents + (layer_ret[1],)
887
-
888
- if output_attentions:
889
- all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
890
-
891
- # Final layer norm.
892
- hidden_states = self.final_layernorm(hidden_states)
893
-
894
- if output_hidden_states:
895
- all_hidden_states = all_hidden_states + (hidden_states,)
896
-
897
- if not return_dict:
898
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
899
-
900
- return BaseModelOutputWithPast(
901
- last_hidden_state=hidden_states,
902
- past_key_values=presents,
903
- hidden_states=all_hidden_states,
904
- attentions=all_self_attentions,
905
- )
906
-
907
-
908
- class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
909
- def __init__(self, config):
910
- super().__init__(config)
911
-
912
- # self.hidden_size = config.hidden_size
913
- # self.params_dtype = torch.half
914
- # self.vocab_size = config.vocab_size
915
- self.max_sequence_length = config.max_sequence_length
916
-
917
- self.position_encoding_2d = config.position_encoding_2d
918
-
919
- self.transformer = ChatGLMModel(config)
920
-
921
- self.lm_head = skip_init(
922
- nn.Linear,
923
- config.hidden_size,
924
- config.vocab_size,
925
- bias=False,
926
- dtype=torch.half
927
- )
928
-
929
- def get_output_embeddings(self):
930
- return self.lm_head
931
-
932
- def set_output_embeddings(self, new_embeddings):
933
- self.lm_head = new_embeddings
934
-
935
- def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
936
- attention_mask = torch.ones((1, context_length, context_length), device=device)
937
- attention_mask.tril_()
938
- attention_mask[..., :context_length - 1] = 1
939
- attention_mask.unsqueeze_(1)
940
- attention_mask = (attention_mask < 0.5).bool()
941
-
942
- if self.position_encoding_2d:
943
- seq_length = seq.index(self.config.bos_token_id)
944
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
945
- if not gmask:
946
- position_ids[seq_length:] = mask_position
947
- block_position_ids = torch.cat((
948
- torch.zeros(seq_length, dtype=torch.long, device=device),
949
- torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
950
- ))
951
- position_ids = torch.stack((position_ids, block_position_ids), dim=0)
952
- else:
953
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
954
- if not gmask:
955
- position_ids[context_length - 1:] = mask_position
956
-
957
- position_ids = position_ids.unsqueeze(0)
958
-
959
- return attention_mask, position_ids
960
-
961
- def prepare_inputs_for_generation(
962
- self,
963
- input_ids: torch.LongTensor,
964
- past: Optional[torch.Tensor] = None,
965
- past_key_values: Optional[torch.Tensor] = None,
966
- attention_mask: Optional[torch.Tensor] = None,
967
- **kwargs
968
- ) -> dict:
969
-
970
- MASK, gMASK = 150000, 150001
971
- mask_token = MASK if MASK in input_ids else gMASK
972
- use_gmask = False if MASK in input_ids else True
973
- seq = input_ids[0].tolist()
974
- if mask_token not in seq:
975
- raise ValueError("You have to add either [MASK] or [gMASK] in your input")
976
-
977
- mask_position = seq.index(mask_token)
978
-
979
- # only last token for input_ids if past is not None
980
- if past is not None or past_key_values is not None:
981
- context_length = seq.index(self.config.bos_token_id)
982
- last_token = input_ids[:, -1].unsqueeze(-1)
983
- if self.position_encoding_2d:
984
- position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
985
- device=input_ids.device)
986
- else:
987
- position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
988
-
989
- if past is None:
990
- past = past_key_values
991
- return {
992
- "input_ids": last_token,
993
- "past_key_values": past,
994
- "position_ids": position_ids,
995
- }
996
- else:
997
- attention_mask, position_ids = self.get_masks_and_position_ids(
998
- seq=seq,
999
- mask_position=mask_position,
1000
- context_length=len(seq),
1001
- device=input_ids.device,
1002
- gmask=use_gmask
1003
- )
1004
-
1005
- return {
1006
- "input_ids": input_ids,
1007
- "past_key_values": past,
1008
- "position_ids": position_ids,
1009
- "attention_mask": attention_mask
1010
- }
1011
-
1012
- def forward(
1013
- self,
1014
- input_ids: Optional[torch.Tensor] = None,
1015
- position_ids: Optional[torch.Tensor] = None,
1016
- attention_mask: Optional[torch.Tensor] = None,
1017
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
1018
- inputs_embeds: Optional[torch.Tensor] = None,
1019
- labels: Optional[torch.Tensor] = None,
1020
- use_cache: Optional[bool] = None,
1021
- output_attentions: Optional[bool] = None,
1022
- output_hidden_states: Optional[bool] = None,
1023
- return_dict: Optional[bool] = None,
1024
- ):
1025
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1026
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1027
-
1028
- transformer_outputs = self.transformer(
1029
- input_ids=input_ids,
1030
- position_ids=position_ids,
1031
- attention_mask=attention_mask,
1032
- past_key_values=past_key_values,
1033
- inputs_embeds=inputs_embeds,
1034
- use_cache=use_cache,
1035
- output_attentions=output_attentions,
1036
- output_hidden_states=output_hidden_states,
1037
- return_dict=return_dict,
1038
- )
1039
-
1040
- hidden_states = transformer_outputs[0]
1041
-
1042
- lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()
1043
-
1044
- loss = None
1045
- if labels is not None:
1046
- lm_logits = lm_logits.to(torch.float32)
1047
-
1048
- # Shift so that tokens < n predict n
1049
- shift_logits = lm_logits[..., :-1, :].contiguous()
1050
- shift_labels = labels[..., 1:].contiguous()
1051
- # Flatten the tokens
1052
- loss_fct = CrossEntropyLoss()
1053
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1054
-
1055
- lm_logits = lm_logits.to(hidden_states.dtype)
1056
- loss = loss.to(hidden_states.dtype)
1057
-
1058
- if not return_dict:
1059
- output = (lm_logits,) + transformer_outputs[1:]
1060
- return ((loss,) + output) if loss is not None else output
1061
-
1062
- return CausalLMOutputWithPast(
1063
- loss=loss,
1064
- logits=lm_logits,
1065
- past_key_values=transformer_outputs.past_key_values,
1066
- hidden_states=transformer_outputs.hidden_states,
1067
- attentions=transformer_outputs.attentions,
1068
- )
1069
-
1070
- @staticmethod
1071
- def _reorder_cache(
1072
- past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1073
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1074
- """
1075
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1076
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1077
- beam_idx at every generation step.
1078
-
1079
- Output shares the same memory storage as `past`.
1080
- """
1081
- return tuple(
1082
- (
1083
- layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
1084
- layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1085
- )
1086
- for layer_past in past
1087
- )
1088
-
1089
- def process_response(self, response):
1090
- response = response.strip()
1091
- response = response.replace("[[训练时间]]", "2023年")
1092
- punkts = [
1093
- [",", ","],
1094
- ["!", "!"],
1095
- [":", ":"],
1096
- [";", ";"],
1097
- ["\?", "?"],
1098
- ]
1099
- for item in punkts:
1100
- response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
1101
- response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
1102
- return response
1103
-
1104
- @torch.no_grad()
1105
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
1106
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
1107
- if history is None:
1108
- history = []
1109
- if logits_processor is None:
1110
- logits_processor = LogitsProcessorList()
1111
- logits_processor.append(InvalidScoreLogitsProcessor())
1112
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1113
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1114
- if not history:
1115
- prompt = query
1116
- else:
1117
- prompt = ""
1118
- for i, (old_query, response) in enumerate(history):
1119
- prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1120
- prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1121
- input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1122
- input_ids = input_ids.to(self.device)
1123
- outputs = self.generate(**input_ids, **gen_kwargs)
1124
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1125
- response = tokenizer.decode(outputs)
1126
- response = self.process_response(response)
1127
- history = history + [(query, response)]
1128
- return response, history
1129
-
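The chat() helper above was the main entry point of this revision; a typical usage sketch (assuming a CUDA device and the companion tokenizer, both loaded with trust_remote_code):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

response, history = model.chat(tokenizer, "你好", history=[])   # "Hello"
print(response)

# stream_chat (defined below) yields partial responses while generation is still running
for response, history in model.stream_chat(tokenizer, "Introduce yourself", history=history):
    pass
print(response)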
1130
- @torch.no_grad()
1131
- def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
1132
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
1133
- if history is None:
1134
- history = []
1135
- if logits_processor is None:
1136
- logits_processor = LogitsProcessorList()
1137
- logits_processor.append(InvalidScoreLogitsProcessor())
1138
- gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1139
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1140
- if not history:
1141
- prompt = query
1142
- else:
1143
- prompt = ""
1144
- for i, (old_query, response) in enumerate(history):
1145
- prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1146
- prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1147
- input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1148
- input_ids = input_ids.to(self.device)
1149
- for outputs in self.stream_generate(**input_ids, **gen_kwargs):
1150
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1151
- response = tokenizer.decode(outputs)
1152
- response = self.process_response(response)
1153
- new_history = history + [(query, response)]
1154
- yield response, new_history
1155
-
1156
- @torch.no_grad()
1157
- def stream_generate(
1158
- self,
1159
- input_ids,
1160
- generation_config: Optional[GenerationConfig] = None,
1161
- logits_processor: Optional[LogitsProcessorList] = None,
1162
- stopping_criteria: Optional[StoppingCriteriaList] = None,
1163
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1164
- **kwargs,
1165
- ):
1166
- batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1167
-
1168
- if generation_config is None:
1169
- generation_config = self.generation_config
1170
- generation_config = copy.deepcopy(generation_config)
1171
- model_kwargs = generation_config.update(**kwargs)
1172
- bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1173
-
1174
- if isinstance(eos_token_id, int):
1175
- eos_token_id = [eos_token_id]
1176
-
1177
- has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1178
- if has_default_max_length and generation_config.max_new_tokens is None:
1179
- warnings.warn(
1180
- f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1181
- "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1182
- " recommend using `max_new_tokens` to control the maximum length of the generation.",
1183
- UserWarning,
1184
- )
1185
- elif generation_config.max_new_tokens is not None:
1186
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1187
- if not has_default_max_length:
1188
- warnings.warn(
1189
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1190
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1191
- "Please refer to the documentation for more information. "
1192
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1193
- UserWarning,
1194
- )
1195
-
1196
- if input_ids_seq_length >= generation_config.max_length:
1197
- input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1198
- logger.warning(
1199
- f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1200
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1201
- " increasing `max_new_tokens`."
1202
- )
1203
-
1204
- # 2. Set generation parameters if not already defined
1205
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1206
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1207
-
1208
- logits_processor = self._get_logits_processor(
1209
- generation_config=generation_config,
1210
- input_ids_seq_length=input_ids_seq_length,
1211
- encoder_input_ids=input_ids,
1212
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1213
- logits_processor=logits_processor,
1214
- )
1215
-
1216
- stopping_criteria = self._get_stopping_criteria(
1217
- generation_config=generation_config, stopping_criteria=stopping_criteria
1218
- )
1219
- logits_warper = self._get_logits_warper(generation_config)
1220
-
1221
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1222
- scores = None
1223
- while True:
1224
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1225
- # forward pass to get next token
1226
- outputs = self(
1227
- **model_inputs,
1228
- return_dict=True,
1229
- output_attentions=False,
1230
- output_hidden_states=False,
1231
- )
1232
-
1233
- next_token_logits = outputs.logits[:, -1, :]
1234
-
1235
- # pre-process distribution
1236
- next_token_scores = logits_processor(input_ids, next_token_logits)
1237
- next_token_scores = logits_warper(input_ids, next_token_scores)
1238
-
1239
- # sample
1240
- probs = nn.functional.softmax(next_token_scores, dim=-1)
1241
- if generation_config.do_sample:
1242
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1243
- else:
1244
- next_tokens = torch.argmax(probs, dim=-1)
1245
-
1246
- # update generated ids, model inputs, and length for next step
1247
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1248
- model_kwargs = self._update_model_kwargs_for_generation(
1249
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1250
- )
1251
- unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
1252
-
1253
- # stop when each sentence is finished, or if we exceed the maximum length
1254
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1255
- break
1256
- yield input_ids
1257
-
1258
- def quantize(self, bits: int):
1259
- from .quantization import quantize
1260
- self.transformer = quantize(self.transformer, bits)
1261
- return self
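Finally, the quantize() hook above delegates to the repo's companion quantization.py; a sketch of the intended call pattern (bit width and device placement are assumptions based on the published usage of this checkpoint):

from transformers import AutoModel

# INT4 weight-only quantization to cut GPU memory; quantize(8) follows the same path
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half()
model = model.quantize(4).cuda().eval()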