imvladikon committed
Commit f6a94e8
Parent: f1983d5

Update modeling_enc_t5.py

Files changed (1):
  modeling_enc_t5.py (+2, -112)
modeling_enc_t5.py CHANGED
@@ -1,122 +1,13 @@
 import copy
-from typing import Any, Dict, List, Optional
 
 import torch
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers import T5TokenizerFast, T5Config
-from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import TokenClassifierOutput
 from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
 from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 
 
-class EncT5Tokenizer(T5TokenizerFast):
-    def __init__(
-        self,
-        vocab_file,
-        bos_token="<s>",
-        eos_token="</s>",
-        unk_token="<unk>",
-        pad_token="<pad>",
-        extra_ids=100,
-        additional_special_tokens=None,
-        sp_model_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> None:
-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
-
-        super().__init__(
-            vocab_file=vocab_file,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            unk_token=unk_token,
-            pad_token=pad_token,
-            extra_ids=extra_ids,
-            additional_special_tokens=additional_special_tokens,
-            sp_model_kwargs=sp_model_kwargs,
-            **kwargs,
-        )
-
-    def get_special_tokens_mask(
-        self,
-        token_ids_0: List[int],
-        token_ids_1: Optional[List[int]] = None,
-        already_has_special_tokens: bool = False,
-    ) -> List[int]:
-        """
-        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
-        special tokens using the tokenizer `prepare_for_model` method.
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
-                Whether or not the token list is already formatted with special tokens for the model.
-        Returns:
-            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
-        """
-        if already_has_special_tokens:
-            return super().get_special_tokens_mask(
-                token_ids_0=token_ids_0,
-                token_ids_1=token_ids_1,
-                already_has_special_tokens=True,
-            )
-
-        # normal case: some special tokens
-        if token_ids_1 is None:
-            return [1] + ([0] * len(token_ids_0)) + [1]
-        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
-
-    def create_token_type_ids_from_sequences(
-        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
-    ) -> List[int]:
-        """
-        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
-        use of token type ids, therefore a list of zeros is returned.
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-        Returns:
-            `List[int]`: List of zeros.
-        """
-        bos = [self.bos_token_id]
-        eos = [self.eos_token_id]
-
-        if token_ids_1 is None:
-            return len(bos + token_ids_0 + eos) * [0]
-        return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0]
-
-    def build_inputs_with_special_tokens(
-        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
-    ) -> List[int]:
-        """
-        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
-        adding special tokens. A sequence has the following format:
-        - single sequence: `<s> X </s>`
-        - pair of sequences: `<s> A </s> B </s>`
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs to which the special tokens will be added.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-        Returns:
-            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
-        """
-        if token_ids_1 is None:
-            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
-        else:
-            return (
-                [self.bos_token_id]
-                + token_ids_0
-                + [self.eos_token_id]
-                + token_ids_1
-                + [self.eos_token_id]
-            )
-
-
 class EncT5ForTokenClassification(T5PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
@@ -222,4 +113,3 @@ class EncT5ForTokenClassification(T5PreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
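
Note: with EncT5Tokenizer removed, this module now ships only EncT5ForTokenClassification, and inputs are produced with the stock T5TokenizerFast. A minimal usage sketch, not part of the commit — the t5-base checkpoint and num_labels=2 are placeholder assumptions, and modeling_enc_t5 is assumed importable from the working directory:

import torch
from transformers import T5TokenizerFast

from modeling_enc_t5 import EncT5ForTokenClassification

# Placeholder checkpoint and label count; substitute the real repo id and label schema.
model = EncT5ForTokenClassification.from_pretrained("t5-base", num_labels=2)
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

inputs = tokenizer("EncT5 keeps only the T5 encoder.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch, seq_len, num_labels)
print(logits.argmax(dim=-1))  # per-token predicted label ids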