|
import logging |
|
import torch.utils.data |
|
from fairseq.data import FairseqDataset |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class OFADataset(FairseqDataset):
    """Dataset base class offering a length and a text-to-token-id helper.

    The methods below read ``self.dataset``, ``self.tgt_dict``, ``self.bpe``,
    ``self.bos_item`` and ``self.eos_item``. None of these are assigned in
    this class body; presumably a subclass or an initializer defined
    elsewhere provides them -- TODO confirm.
    """

    def __len__(self):
        """Return ``len(self.dataset)``."""
        wrapped = self.dataset
        return len(wrapped)

    def encode_text(self, text, length=None, append_bos=False, append_eos=False):
        """Turn ``text`` into a tensor of token ids.

        The text is passed through ``self.bpe.encode`` and the result is
        mapped to ids with ``self.tgt_dict.encode_line``; unknown symbols are
        not added to the dictionary and no end-of-sentence id is appended at
        that stage. The ids are then cast with ``.long()``.

        Args:
            text: Raw text handed to ``self.bpe.encode``.
            length: When not ``None``, keep only the first ``length`` ids.
                The cut happens *before* BOS/EOS are attached, so the
                returned tensor may hold up to ``length + 2`` entries.
            append_bos: If truthy, prepend ``self.bos_item``.
            append_eos: If truthy, append ``self.eos_item``.

        Returns:
            The resulting id tensor.
        """
        # Resolve the dictionary encoder first, then run BPE, mirroring the
        # evaluation order of a single chained call.
        to_ids = self.tgt_dict.encode_line
        bpe_line = self.bpe.encode(text)
        token_ids = to_ids(
            line=bpe_line,
            add_if_not_exist=False,
            append_eos=False,
        )
        token_ids = token_ids.long()

        # Truncate the content tokens only; special tokens are added after.
        if length is not None:
            token_ids = token_ids[:length]

        if append_bos:
            token_ids = torch.cat([self.bos_item, token_ids])

        if append_eos:
            token_ids = torch.cat([token_ids, self.eos_item])

        return token_ids
|
|