BatchEncoding wrapper for custom tokenizer output
harim_plus.py +14 -4
harim_plus.py CHANGED
@@ -7,9 +7,12 @@ import torch.nn.functional as F
 from transformers import (AutoModelForSeq2SeqLM,
                           AutoTokenizer,
                           PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+                          PreTrainedTokenizerFast,
+                          )
+from transformers.tokenization_utils_base import BatchEncoding # for custom tokenizers other than huggingface
+
 import pandas as pd
-from tqdm import tqdm
+from tqdm import tqdm
 
 from typing import List, Dict, Union
 from collections import defaultdict
@@ -201,8 +204,15 @@ class Harimplus_Scorer:
         emp_in = self._prep_input( mini_e_, src_or_tgt='src' )
 
 
-        tgt_mask = tgt_in.attention_mask
-
+        tgt_mask = tgt_in.attention_mask # torch.Tensor
+        # if the tokenizer was not loaded from huggingface, .to(device) below might fail
+        if not isinstance(src_in, BatchEncoding):
+            src_in = BatchEncoding(src_in)
+        if not isinstance(emp_in, BatchEncoding):
+            emp_in = BatchEncoding(emp_in)
+        if not isinstance(tgt_in, BatchEncoding):
+            tgt_in = BatchEncoding(tgt_in)
+
         src_in = src_in.to(self._device)
         emp_in = emp_in.to(self._device)
         tgt_in = tgt_in.to(self._device)
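
In short, the commit guards against custom (non-huggingface) tokenizers whose output is a plain dict rather than a BatchEncoding: a dict has no .to(device) method, while BatchEncoding moves every contained tensor in one call. Below is a minimal sketch of the behavior the guard relies on; custom_tokenize and its dummy token ids are hypothetical stand-ins for such a tokenizer, but BatchEncoding and its .to() are the real transformers API.

import torch
from transformers.tokenization_utils_base import BatchEncoding

def custom_tokenize(texts):
    # hypothetical non-huggingface tokenizer: returns a plain dict of tensors
    ids = torch.tensor([[101, 7592, 102]] * len(texts))
    return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

enc = custom_tokenize(["a summary"])
# enc.to("cpu")  # AttributeError: 'dict' object has no attribute 'to'

if not isinstance(enc, BatchEncoding):  # same guard as in the commit
    enc = BatchEncoding(enc)            # shallow wrap; tensors are shared, not copied

enc = enc.to("cpu")                     # moves every tensor in the dict at once
print(enc.attention_mask.shape)         # attribute access also works on BatchEncoding

Since BatchEncoding is a thin UserDict wrapper, wrapping is cheap; the isinstance check simply avoids re-wrapping output that already came from a huggingface tokenizer.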
|