robinzixuan committed on
Commit 4bd71d4 · verified · 1 Parent(s): b0eaf6f

Update modeling_opt.py

Files changed (1)
  1. modeling_opt.py +2 -174
modeling_opt.py CHANGED
@@ -3,6 +3,7 @@
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
+
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
@@ -315,183 +316,10 @@ class OPTAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-class OPTOutEffHop(OPTAttention):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(
-        self,
-        config: OPTConfig,
-        is_decoder: bool = False,
-        **kwargs,
-    ):
-        super().__init__()
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.dropout = config.attention_dropout
-        self.enable_bias = config.enable_bias
-        self.attention = softmax_1
-        self.head_dim = self.embed_dim // self.num_heads
-        self.is_causal = True
-
-        if (self.head_dim * self.num_heads) != self.embed_dim:
-            raise ValueError(
-                f'''embed_dim must be divisible by num_heads (got `embed_dim`: {
-                self.embed_dim}'''
-                f" and `num_heads`: {self.num_heads})."
-            )
-        self.scaling = self.head_dim**-0.5
-        self.is_decoder = is_decoder
-
-        self.k_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-        self.v_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-        self.q_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-        self.out_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """Input shape: Batch x Time x Channel"""
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, tgt_len, _ = hidden_states.size()
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(
-            query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
-
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
-            raise ValueError(
-                f'''Attention weights should be of size {
-                (bsz * self.num_heads, tgt_len, src_len)}, but is"
-                f" {attn_weights.size()}'''
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                raise ValueError(
-                    f'''Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {
-                    attention_mask.size()}'''
-                )
-            attn_weights = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = torch.max(
-                attn_weights, torch.tensor(torch.finfo(
-                    attn_weights.dtype).min, device=attn_weights.device)
-            )
-            attn_weights = attn_weights.view(
-                bsz * self.num_heads, tgt_len, src_len)
-
-        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
-        if attn_weights.dtype == torch.float16:
-            attn_weights = softmax_1(
-                attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
-        else:
-            attn_weights = softmax_1(attn_weights, dim=-1)
-
-        if layer_head_mask is not None:
-            if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(
-                    f'''Head mask for a single layer should be of size {
-                    (self.num_heads,)}, but is'''
-                    f" {layer_head_mask.size()}"
-                )
-            attn_weights = layer_head_mask.view(
-                1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.view(
-                bsz * self.num_heads, tgt_len, src_len)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(
-                bsz * self.num_heads, tgt_len, src_len)
-        else:
-            attn_weights_reshaped = None
-
-        attn_probs = nn.functional.dropout(
-            attn_weights, p=self.dropout, training=self.training)
-
-        attn_output = torch.bmm(attn_probs, value_states)
-
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
-            raise ValueError(
-                f'''`attn_output` should be of size {
-                (bsz, self.num_heads, tgt_len, self.head_dim)}, but is'''
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.view(
-            bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned aross GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, attn_weights_reshaped, past_key_value
 
 
-class OptFlashAttention2(OPTOutEffHop):
+class OptFlashAttention2(OPTAttention):
     """
     OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
     The only required change would be on the forward pass where it needs to correctly call the public API of flash
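Note: the removed OPTOutEffHop class calls a softmax_1 helper that is defined elsewhere in modeling_opt.py and does not appear in this hunk. For reference only, here is a minimal sketch of such a helper, assuming it is the "+1 in the denominator" softmax variant used by OutEffHop-style attention (not necessarily the exact definition in this repository):

import torch

def softmax_1(x: torch.Tensor, dim: int = -1, dtype=None) -> torch.Tensor:
    # softmax_1(x) = exp(x) / (1 + sum(exp(x))): like softmax, but with an
    # implicit extra logit fixed at 0, so the weights can sum to less than 1.
    if dtype is not None:
        x = x.to(dtype)
    # Shift by max(x, 0) for numerical stability; clamping at 0 keeps the
    # implicit zero logit consistent with the shift.
    x_max = torch.clamp(x.max(dim=dim, keepdim=True).values, min=0.0)
    exp_x = torch.exp(x - x_max)
    return exp_x / (torch.exp(-x_max) + exp_x.sum(dim=dim, keepdim=True))

This signature matches how the removed code calls it, e.g. softmax_1(attn_weights, dim=-1, dtype=torch.float32).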
 
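The OptFlashAttention2 docstring above says the only change relative to OPTAttention is that the forward pass calls the public flash-attention API. As a hypothetical illustration of what such a call typically looks like, assuming the flash-attn package's flash_attn_func is available (this is a sketch, not the forward pass implemented in this file):

import torch
from flash_attn import flash_attn_func  # assumes the flash-attn package is installed

def flash_self_attention(q_proj, k_proj, v_proj, out_proj, hidden_states,
                         num_heads, head_dim, dropout_p=0.0, causal=True):
    # hidden_states: (batch, seq_len, hidden_size), fp16/bf16 on GPU.
    bsz, seq_len, _ = hidden_states.size()
    # flash_attn_func expects inputs shaped (batch, seq_len, num_heads, head_dim).
    q = q_proj(hidden_states).view(bsz, seq_len, num_heads, head_dim)
    k = k_proj(hidden_states).view(bsz, seq_len, num_heads, head_dim)
    v = v_proj(hidden_states).view(bsz, seq_len, num_heads, head_dim)
    # Scaling and causal masking are handled inside the fused kernel.
    attn = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
    return out_proj(attn.reshape(bsz, seq_len, num_heads * head_dim))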