imvladikon committed
Commit: f6a94e8
Parent(s): f1983d5

Update modeling_enc_t5.py

Files changed: modeling_enc_t5.py (+2 -112)

modeling_enc_t5.py CHANGED
@@ -1,122 +1,13 @@
 import copy
-from typing import Any, Dict, List, Optional
 
 import torch
 from torch import nn
-from torch.nn import
-from transformers import T5TokenizerFast
-from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import TokenClassifierOutput
 from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
 from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 
 
-class EncT5Tokenizer(T5TokenizerFast):
-    def __init__(
-        self,
-        vocab_file,
-        bos_token="<s>",
-        eos_token="</s>",
-        unk_token="<unk>",
-        pad_token="<pad>",
-        extra_ids=100,
-        additional_special_tokens=None,
-        sp_model_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> None:
-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
-
-        super().__init__(
-            vocab_file=vocab_file,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            unk_token=unk_token,
-            pad_token=pad_token,
-            extra_ids=extra_ids,
-            additional_special_tokens=additional_special_tokens,
-            sp_model_kwargs=sp_model_kwargs,
-            **kwargs,
-        )
-
-    def get_special_tokens_mask(
-        self,
-        token_ids_0: List[int],
-        token_ids_1: Optional[List[int]] = None,
-        already_has_special_tokens: bool = False,
-    ) -> List[int]:
-        """
-        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
-        special tokens using the tokenizer `prepare_for_model` method.
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
-                Whether or not the token list is already formatted with special tokens for the model.
-        Returns:
-            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
-        """
-        if already_has_special_tokens:
-            return super().get_special_tokens_mask(
-                token_ids_0=token_ids_0,
-                token_ids_1=token_ids_1,
-                already_has_special_tokens=True,
-            )
-
-        # normal case: some special tokens
-        if token_ids_1 is None:
-            return [1] + ([0] * len(token_ids_0)) + [1]
-        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
-
-    def create_token_type_ids_from_sequences(
-        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
-    ) -> List[int]:
-        """
-        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
-        use of token type ids, therefore a list of zeros is returned.
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-        Returns:
-            `List[int]`: List of zeros.
-        """
-        bos = [self.bos_token_id]
-        eos = [self.eos_token_id]
-
-        if token_ids_1 is None:
-            return len(bos + token_ids_0 + eos) * [0]
-        return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0]
-
-    def build_inputs_with_special_tokens(
-        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
-    ) -> List[int]:
-        """
-        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
-        adding special tokens. A sequence has the following format:
-        - single sequence: `<s> X </s>`
-        - pair of sequences: `<s> A </s> B </s>`
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs to which the special tokens will be added.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-        Returns:
-            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
-        """
-        if token_ids_1 is None:
-            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
-        else:
-            return (
-                [self.bos_token_id]
-                + token_ids_0
-                + [self.eos_token_id]
-                + token_ids_1
-                + [self.eos_token_id]
-            )
-
-
 class EncT5ForTokenClassification(T5PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
@@ -222,4 +113,3 @@ class EncT5ForTokenClassification(T5PreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
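Note: the removed EncT5Tokenizer carried no model logic of its own; it only changed how special tokens are attached. Single sequences become `<s> X </s>`, pairs become `<s> A </s> B </s>`, token type ids are always zero, and the special-tokens mask marks only the added positions. A minimal standalone sketch of that behavior, using plain functions instead of the tokenizer class; the ids below are placeholders, since the real class resolved bos_token_id/eos_token_id from its vocabulary:

from typing import List, Optional

BOS_ID, EOS_ID = 101, 102  # placeholder ids; the removed class looked these up from its vocab


def build_inputs(token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
    """Mirror of the removed build_inputs_with_special_tokens: <s> X </s> or <s> A </s> B </s>."""
    if token_ids_1 is None:
        return [BOS_ID] + token_ids_0 + [EOS_ID]
    return [BOS_ID] + token_ids_0 + [EOS_ID] + token_ids_1 + [EOS_ID]


def special_tokens_mask(token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
    """Mirror of the removed get_special_tokens_mask: 1 marks a special token, 0 a sequence token."""
    if token_ids_1 is None:
        return [1] + [0] * len(token_ids_0) + [1]
    return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]


print(build_inputs([5, 6, 7]))          # [101, 5, 6, 7, 102]
print(special_tokens_mask([5, 6, 7]))   # [1, 0, 0, 0, 1]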
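After this change the module defines only EncT5ForTokenClassification. A hedged usage sketch, not taken from the repo: it assumes the class keeps the conventional `__init__(config)` constructor and a standard encoder-style `forward(input_ids, attention_mask)`, which is consistent with the TokenClassifierOutput return visible in the diff; the base checkpoint and label count are placeholders.

import torch
from transformers import T5Config

from modeling_enc_t5 import EncT5ForTokenClassification  # the file updated in this commit

# "t5-small" and num_labels=9 are illustrative choices, not taken from this repo.
config = T5Config.from_pretrained("t5-small", num_labels=9)
model = EncT5ForTokenClassification(config)  # randomly initialized head; fine for a shape check
model.eval()

input_ids = torch.tensor([[5, 6, 7, 1]])  # placeholder token ids; 1 is T5's usual </s>
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

print(outputs.logits.shape)  # expected: (1, 4, 9), one logit vector per token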