zifei9 committed
Commit da7b5ba
1 parent: d9cc98d

Updated modeling_gpt2.py to use the accelerate API to parallelize the model
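For context, a minimal sketch of how a model that ships this custom modeling file can be spread across GPUs once the legacy `parallelize()` API is removed. The repo id is hypothetical and `accelerate` is assumed to be installed; the `device_map` values follow the removed deprecation warnings in the diff below.

```python
# Minimal sketch (assumptions: hypothetical repo id "your-org/gpt2-custom",
# `accelerate` installed, and this modeling_gpt2.py served as remote code).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/gpt2-custom"  # hypothetical; replace with the actual repo
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Instead of model.parallelize(device_map), let accelerate shard the blocks:
# "auto"/"balanced" spread transformer.h across all visible GPUs.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # or "balanced", as the removed warnings suggest
    torch_dtype=torch.float16,  # optional; halves memory per GPU
    trust_remote_code=True,     # needed to pick up this modeling_gpt2.py
)

print(model.hf_device_map)  # shows which device each submodule landed on

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```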

Files changed (1)
  1. modeling_gpt2.py +45 -206
modeling_gpt2.py CHANGED
@@ -27,17 +27,17 @@ from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

- from transformers.activations import ACT2FN
- from transformers.modeling_outputs import (
+ from ...activations import ACT2FN
+ from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
- from transformers.modeling_utils import PreTrainedModel, SequenceSummary
- from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
- from transformers.utils import (
+ from ...modeling_utils import PreTrainedModel, SequenceSummary
+ from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+ from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -45,7 +45,7 @@ from transformers.utils import (
    logging,
    replace_return_docstrings,
)
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+ from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gpt2 import GPT2Config

@@ -194,7 +194,7 @@ class GPT2Attention(nn.Module):
        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
-             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(attn_weights.device)
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
@@ -606,56 +606,6 @@ GPT2_INPUTS_DOCSTRING = r"""
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
- PARALLELIZE_DOCSTRING = r"""
-     This is an experimental feature and is a subject to change at a moment's notice.
-
-     Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
-     it will evenly distribute blocks across all devices.
-
-     Args:
-         device_map (`Dict[int, list]`, optional, defaults to None):
-             A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
-             automatically mapped to the first device (for esoteric reasons). That means that the first device should
-             have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
-             following number of attention modules:
-
-                 - gpt2: 12
-                 - gpt2-medium: 24
-                 - gpt2-large: 36
-                 - gpt2-xl: 48
-
-     Example:
-
-     ```python
-     # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
-     model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
-     device_map = {
-         0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
-         1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
-         2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
-         3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
-     }
-     model.parallelize(device_map)
-     ```
- """
- DEPARALLELIZE_DOCSTRING = r"""
-     Moves the model to cpu from a model parallel state.
-
-     Example:
-
-     ```python
-     # On a 4 GPU machine with gpt2-large:
-     model = GPT2LMHeadModel.from_pretrained("gpt2-large")
-     device_map = {
-         0: [0, 1, 2, 3, 4, 5, 6, 7],
-         1: [8, 9, 10, 11, 12, 13, 14, 15],
-         2: [16, 17, 18, 19, 20, 21, 22, 23],
-         3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
-     }
-     model.parallelize(device_map)  # Splits the model across several devices
-     model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
-     ```
- """


@add_start_docstrings(
@@ -676,57 +626,13 @@ class GPT2Model(GPT2PreTrainedModel):
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
-         self.model_parallel = False
-         self.device_map = None
+         # self.model_parallel = False
+         # self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

-     @add_start_docstrings(PARALLELIZE_DOCSTRING)
-     def parallelize(self, device_map=None):
-         # Check validity of device_map
-         warnings.warn(
-             "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
-             " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-             " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
-             " ...}",
-             FutureWarning,
-         )
-         self.device_map = (
-             get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
-         )
-         assert_device_map(self.device_map, len(self.h))
-         self.model_parallel = True
-         self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
-         self.last_device = "cuda:" + str(max(self.device_map.keys()))
-         self.wte = self.wte.to(self.first_device)
-         self.wpe = self.wpe.to(self.first_device)
-         # Load onto devices
-         for k, v in self.device_map.items():
-             for block in v:
-                 cuda_device = "cuda:" + str(k)
-                 self.h[block] = self.h[block].to(cuda_device)
-         # ln_f to last
-         self.ln_f = self.ln_f.to(self.last_device)
-
-     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-     def deparallelize(self):
-         warnings.warn(
-             "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-             FutureWarning,
-         )
-         self.model_parallel = False
-         self.device_map = None
-         self.first_device = "cpu"
-         self.last_device = "cpu"
-         self.wte = self.wte.to("cpu")
-         self.wpe = self.wpe.to("cpu")
-         for index in range(len(self.h)):
-             self.h[index] = self.h[index].to("cpu")
-         self.ln_f = self.ln_f.to("cpu")
-         torch.cuda.empty_cache()
-
    def get_input_embeddings(self):
        return self.wte

@@ -813,7 +719,7 @@ class GPT2Model(GPT2PreTrainedModel):
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
-             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # If a 2D or 3D attention mask is provided for the cross-attention
@@ -836,7 +742,7 @@ class GPT2Model(GPT2PreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
-         hidden_states = inputs_embeds + position_embeds
+         hidden_states = inputs_embeds.to(position_embeds.device) + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
@@ -858,17 +764,14 @@ class GPT2Model(GPT2PreTrainedModel):
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-             # Model parallel
-             if self.model_parallel:
-                 torch.cuda.set_device(hidden_states.device)
-                 # Ensure layer_past is on same device as hidden_states (might not be correct)
-                 if layer_past is not None:
-                     layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
-                 # Ensure that attention_mask is always on the same device as hidden_states
-                 if attention_mask is not None:
-                     attention_mask = attention_mask.to(hidden_states.device)
-                 if isinstance(head_mask, torch.Tensor):
-                     head_mask = head_mask.to(hidden_states.device)
+             # # Ensure layer_past is on same device as hidden_states (might not be correct)
+             if layer_past is not None:
+                 layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+             # # Ensure that attention_mask is always on the same device as hidden_states
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(hidden_states.device)
+             if isinstance(head_mask, torch.Tensor):
+                 head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

@@ -905,11 +808,11 @@ class GPT2Model(GPT2PreTrainedModel):
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

-             # Model Parallel: If it's the last layer for that device, put things on the next device
-             if self.model_parallel:
-                 for k, v in self.device_map.items():
-                     if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                         hidden_states = hidden_states.to("cuda:" + str(k + 1))
+             # # Model Parallel: If it's the last layer for that device, put things on the next device
+             # if self.model_parallel:
+             #     for k, v in self.device_map.items():
+             #         if i == v[-1] and "cuda:" + str(k) != self.last_device:
+             #             hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

@@ -950,43 +853,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model parallel
-         self.model_parallel = False
-         self.device_map = None
+         # self.model_parallel = False
+         # self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

-     @add_start_docstrings(PARALLELIZE_DOCSTRING)
-     def parallelize(self, device_map=None):
-         warnings.warn(
-             "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
-             " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-             " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
-             " 0, 'transformer.h.1': 1, ...}",
-             FutureWarning,
-         )
-         self.device_map = (
-             get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-             if device_map is None
-             else device_map
-         )
-         assert_device_map(self.device_map, len(self.transformer.h))
-         self.transformer.parallelize(self.device_map)
-         self.lm_head = self.lm_head.to(self.transformer.first_device)
-         self.model_parallel = True
-
-     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-     def deparallelize(self):
-         warnings.warn(
-             "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-             FutureWarning,
-         )
-         self.transformer.deparallelize()
-         self.transformer = self.transformer.to("cpu")
-         self.lm_head = self.lm_head.to("cpu")
-         self.model_parallel = False
-         torch.cuda.empty_cache()
-
    def get_output_embeddings(self):
        return self.lm_head

@@ -1089,9 +961,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
-         if self.model_parallel:
-             torch.cuda.set_device(self.transformer.first_device)
-             hidden_states = hidden_states.to(self.lm_head.weight.device)
+         # if self.model_parallel:
+         #     torch.cuda.set_device(self.transformer.first_device)
+         #     hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

@@ -1153,46 +1025,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

-         # Model parallel
-         self.model_parallel = False
-         self.device_map = None
+         # # Model parallel
+         # self.model_parallel = False
+         # self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

-     @add_start_docstrings(PARALLELIZE_DOCSTRING)
-     def parallelize(self, device_map=None):
-         warnings.warn(
-             "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
-             " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
-             " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
-             " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
-             FutureWarning,
-         )
-         self.device_map = (
-             get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-             if device_map is None
-             else device_map
-         )
-         assert_device_map(self.device_map, len(self.transformer.h))
-         self.transformer.parallelize(self.device_map)
-         self.lm_head = self.lm_head.to(self.transformer.first_device)
-         self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
-         self.model_parallel = True
-
-     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-     def deparallelize(self):
-         warnings.warn(
-             "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-             FutureWarning,
-         )
-         self.transformer.deparallelize()
-         self.transformer = self.transformer.to("cpu")
-         self.lm_head = self.lm_head.to("cpu")
-         self.multiple_choice_head = self.multiple_choice_head.to("cpu")
-         self.model_parallel = False
-         torch.cuda.empty_cache()
-
    def get_output_embeddings(self):
        return self.lm_head

@@ -1314,10 +1153,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):

        hidden_states = transformer_outputs[0]

-         # Set device for model parallelism
-         if self.model_parallel:
-             torch.cuda.set_device(self.transformer.first_device)
-             hidden_states = hidden_states.to(self.lm_head.weight.device)
+         # # Set device for model parallelism
+         # if self.model_parallel:
+         #     torch.cuda.set_device(self.transformer.first_device)
+         #     hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
@@ -1387,9 +1226,9 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
        self.transformer = GPT2Model(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

-         # Model parallel
-         self.model_parallel = False
-         self.device_map = None
+         # # Model parallel
+         # self.model_parallel = False
+         # self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()
@@ -1521,9 +1360,9 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

-         # Model parallel
-         self.model_parallel = False
-         self.device_map = None
+         # # Model parallel
+         # self.model_parallel = False
+         # self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()
@@ -1624,9 +1463,9 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
        self.transformer = GPT2Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

-         # Model parallel
-         self.model_parallel = False
-         self.device_map = None
+         # # Model parallel
+         # self.model_parallel = False
+         # self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()
 
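If an explicit split is still wanted after this change, the removed deprecation warnings point to the replacement: a plain module-name-to-device dictionary handed to `from_pretrained`. A minimal sketch under that assumption (hypothetical repo id; 12 blocks as listed for `gpt2` in the removed docstring):

```python
# Sketch of an explicit device_map replacing the old block-index map passed to
# parallelize(); module names follow the removed deprecation warnings
# ({'transformer.h.0': 0, 'transformer.h.1': 1, ...}). Repo id is hypothetical.
from transformers import AutoModelForCausalLM

device_map = {
    "transformer.wte": 0,
    "transformer.wpe": 0,
    "transformer.drop": 0,
    **{f"transformer.h.{i}": 0 if i < 6 else 1 for i in range(12)},  # gpt2 has 12 blocks
    "transformer.ln_f": 1,
    "lm_head": 0,  # keep the tied lm_head on the same device as the input embeddings
}

model = AutoModelForCausalLM.from_pretrained(
    "your-org/gpt2-custom",  # hypothetical; replace with the actual repo
    device_map=device_map,
    trust_remote_code=True,
)
```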