Updated modeling_gpt2 to use the accelerate API to parallelize the model

modeling_gpt2.py  +45 -206  CHANGED
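In short: the explicit `parallelize()` / `deparallelize()` model-parallel code is removed (or commented out) in favor of loading the model through accelerate's `device_map` support in `from_pretrained`, with a few extra `.to(...)` calls so tensors follow the blocks across devices. A minimal sketch of the intended loading path, assuming `accelerate` is installed alongside `transformers` (the `"gpt2-xl"` checkpoint is only an example):

```python
from transformers import GPT2LMHeadModel

# Old API (removed in this diff): model.parallelize(device_map)
# New API: let accelerate shard the blocks over the visible GPUs at load time.
model = GPT2LMHeadModel.from_pretrained("gpt2-xl", device_map="balanced")

# The resulting module -> device assignment is recorded on the model.
print(model.hf_device_map)
```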
@@ -27,17 +27,17 @@ from torch import nn
 from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
+from ...activations import ACT2FN
+from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from transformers.modeling_utils import PreTrainedModel, SequenceSummary
-from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
-from transformers.utils import (
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -45,7 +45,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+from ...utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_gpt2 import GPT2Config
 
 
@@ -194,7 +194,7 @@ class GPT2Attention(nn.Module):
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(attn_weights.device)
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
@@ -606,56 +606,6 @@ GPT2_INPUTS_DOCSTRING = r"""
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
-PARALLELIZE_DOCSTRING = r"""
-    This is an experimental feature and is a subject to change at a moment's notice.
-
-    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
-    it will evenly distribute blocks across all devices.
-
-    Args:
-        device_map (`Dict[int, list]`, optional, defaults to None):
-            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
-            automatically mapped to the first device (for esoteric reasons). That means that the first device should
-            have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
-            following number of attention modules:
-
-                - gpt2: 12
-                - gpt2-medium: 24
-                - gpt2-large: 36
-                - gpt2-xl: 48
-
-    Example:
-
-    ```python
-    # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
-    model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
-    device_map = {
-        0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
-        1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
-        2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
-        3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
-    }
-    model.parallelize(device_map)
-    ```
-"""
-DEPARALLELIZE_DOCSTRING = r"""
-    Moves the model to cpu from a model parallel state.
-
-    Example:
-
-    ```python
-    # On a 4 GPU machine with gpt2-large:
-    model = GPT2LMHeadModel.from_pretrained("gpt2-large")
-    device_map = {
-        0: [0, 1, 2, 3, 4, 5, 6, 7],
-        1: [8, 9, 10, 11, 12, 13, 14, 15],
-        2: [16, 17, 18, 19, 20, 21, 22, 23],
-        3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
-    }
-    model.parallelize(device_map)  # Splits the model across several devices
-    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
-    ```
-"""
 
 
 @add_start_docstrings(
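The removed PARALLELIZE_DOCSTRING example built a `{device: [block indices]}` map by hand and deliberately kept fewer blocks on the first device. With the accelerate path, a comparable effect comes from the `device_map` and `max_memory` arguments of `from_pretrained`; the sketch below is illustrative and assumes current `transformers`/`accelerate` options such as `"balanced_low_0"` and `max_memory`:

```python
from transformers import GPT2LMHeadModel

# Keep GPU 0 lighter, like the old example that mapped fewer blocks to the first device.
model = GPT2LMHeadModel.from_pretrained("gpt2-xl", device_map="balanced_low_0")

# Or bound the per-device footprint explicitly (the sizes here are only an example).
model = GPT2LMHeadModel.from_pretrained(
    "gpt2-xl",
    device_map="auto",
    max_memory={0: "10GiB", 1: "20GiB", 2: "20GiB", 3: "20GiB"},
)
```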
@@ -676,57 +626,13 @@ class GPT2Model(GPT2PreTrainedModel):
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Model parallel
-        self.model_parallel = False
-        self.device_map = None
+        # self.model_parallel = False
+        # self.device_map = None
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
-    def parallelize(self, device_map=None):
-        # Check validity of device_map
-        warnings.warn(
-            "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
-            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
-            " ...}",
-            FutureWarning,
-        )
-        self.device_map = (
-            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
-        )
-        assert_device_map(self.device_map, len(self.h))
-        self.model_parallel = True
-        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
-        self.wte = self.wte.to(self.first_device)
-        self.wpe = self.wpe.to(self.first_device)
-        # Load onto devices
-        for k, v in self.device_map.items():
-            for block in v:
-                cuda_device = "cuda:" + str(k)
-                self.h[block] = self.h[block].to(cuda_device)
-        # ln_f to last
-        self.ln_f = self.ln_f.to(self.last_device)
-
-    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-    def deparallelize(self):
-        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-            FutureWarning,
-        )
-        self.model_parallel = False
-        self.device_map = None
-        self.first_device = "cpu"
-        self.last_device = "cpu"
-        self.wte = self.wte.to("cpu")
-        self.wpe = self.wpe.to("cpu")
-        for index in range(len(self.h)):
-            self.h[index] = self.h[index].to("cpu")
-        self.ln_f = self.ln_f.to("cpu")
-        torch.cuda.empty_cache()
-
     def get_input_embeddings(self):
         return self.wte
 
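The deprecation message above points at a hand-written `device_map` in module-name form (`{'h.0': 0, 'h.1': 1, ...}`). One way to produce such a dict without typing it out, assuming accelerate's `init_empty_weights` and `infer_auto_device_map` helpers, is sketched here (`"gpt2-large"` is only an example checkpoint):

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, GPT2LMHeadModel

config = AutoConfig.from_pretrained("gpt2-large")

# Build the skeleton without allocating real weights, then let accelerate propose
# a module-name -> device mapping that never splits a transformer block.
with init_empty_weights():
    empty_model = GPT2LMHeadModel(config)
device_map = infer_auto_device_map(empty_model, no_split_module_classes=["GPT2Block"])

# The same dict format can be edited by hand and passed straight to from_pretrained.
model = GPT2LMHeadModel.from_pretrained("gpt2-large", device_map=device_map)
```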
@@ -813,7 +719,7 @@ class GPT2Model(GPT2PreTrainedModel):
             # positions we want to attend and the dtype's smallest value for masked positions.
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)
+            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
 
             # If a 2D or 3D attention mask is provided for the cross-attention
@@ -836,7 +742,7 @@ class GPT2Model(GPT2PreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
+        hidden_states = inputs_embeds.to(position_embeds.device) + position_embeds
 
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
@@ -858,17 +764,14 @@ class GPT2Model(GPT2PreTrainedModel):
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure layer_past is on same device as hidden_states (might not be correct)
-                if layer_past is not None:
-                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
+            # # Ensure layer_past is on same device as hidden_states (might not be correct)
+            if layer_past is not None:
+                layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+            # # Ensure that attention_mask is always on the same device as hidden_states
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(hidden_states.device)
+            if isinstance(head_mask, torch.Tensor):
+                head_mask = head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -905,11 +808,11 @@ class GPT2Model(GPT2PreTrainedModel):
             if self.config.add_cross_attention:
                 all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
 
-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+            # # Model Parallel: If it's the last layer for that device, put things on the next device
+            # if self.model_parallel:
+            #     for k, v in self.device_map.items():
+            #         if i == v[-1] and "cuda:" + str(k) != self.last_device:
+            #             hidden_states = hidden_states.to("cuda:" + str(k + 1))
 
         hidden_states = self.ln_f(hidden_states)
 
@@ -950,43 +853,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         # Model parallel
-        self.model_parallel = False
-        self.device_map = None
+        # self.model_parallel = False
+        # self.device_map = None
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
-    def parallelize(self, device_map=None):
-        warnings.warn(
-            "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
-            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
-            " 0, 'transformer.h.1': 1, ...}",
-            FutureWarning,
-        )
-        self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.transformer.h))
-        self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
-        self.model_parallel = True
-
-    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-    def deparallelize(self):
-        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-            FutureWarning,
-        )
-        self.transformer.deparallelize()
-        self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
-        self.model_parallel = False
-        torch.cuda.empty_cache()
-
     def get_output_embeddings(self):
         return self.lm_head
 
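For completeness, a hedged end-to-end sketch of using the accelerate-parallelized `GPT2LMHeadModel`; the tokenizer class and checkpoint name are just examples, and inputs are placed on the first device in the map before generation:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
model = GPT2LMHeadModel.from_pretrained("gpt2-large", device_map="balanced")

# The token embeddings live on the first device in the map; send the inputs there.
inputs = tokenizer("Model parallelism with accelerate:", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```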
@@ -1089,9 +961,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
+        # if self.model_parallel:
+        #     torch.cuda.set_device(self.transformer.first_device)
+        #     hidden_states = hidden_states.to(self.lm_head.weight.device)
 
         lm_logits = self.lm_head(hidden_states)
 
@@ -1153,46 +1025,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         self.multiple_choice_head = SequenceSummary(config)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
+        # # Model parallel
+        # self.model_parallel = False
+        # self.device_map = None
 
         # Initialize weights and apply final processing
        self.post_init()
 
-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
-    def parallelize(self, device_map=None):
-        warnings.warn(
-            "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
-            " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
-            " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
-            " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
-            FutureWarning,
-        )
-        self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.transformer.h))
-        self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
-        self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
-        self.model_parallel = True
-
-    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-    def deparallelize(self):
-        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-            FutureWarning,
-        )
-        self.transformer.deparallelize()
-        self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
-        self.multiple_choice_head = self.multiple_choice_head.to("cpu")
-        self.model_parallel = False
-        torch.cuda.empty_cache()
-
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -1314,10 +1153,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
 
         hidden_states = transformer_outputs[0]
 
-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
+        # # Set device for model parallelism
+        # if self.model_parallel:
+        #     torch.cuda.set_device(self.transformer.first_device)
+        #     hidden_states = hidden_states.to(self.lm_head.weight.device)
 
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
@@ -1387,9 +1226,9 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
         self.transformer = GPT2Model(config)
         self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
+        # # Model parallel
+        # self.model_parallel = False
+        # self.device_map = None
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1521,9 +1360,9 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
+        # # Model parallel
+        # self.model_parallel = False
+        # self.device_map = None
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1624,9 +1463,9 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
         self.transformer = GPT2Model(config)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
+        # # Model parallel
+        # self.model_parallel = False
+        # self.device_map = None
 
         # Initialize weights and apply final processing
         self.post_init()