|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import contextlib |
|
import sys |
|
from collections import Counter |
|
from multiprocessing import Pool |
|
|
|
from fairseq.data.encoders.gpt2_bpe import get_encoder |
|
|
|
|
|
def main():
    """
    Helper script to encode raw text with the GPT-2 BPE using multiple processes.

    The encoder.json and vocab.bpe files can be obtained here:
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe

    Options are read from the command line. Each path given to ``--inputs``
    is paired positionally with a path given to ``--outputs``; ``-`` selects
    stdin / stdout respectively. The inputs are read in lockstep, so line
    *i* of every input forms one group that is encoded together. A group is
    written (one encoded line per output) only when the encoder reports
    ``"PASS"``; otherwise it is counted under the returned filter reason.
    Progress and the per-reason filter counts are printed to stderr.

    Exits with a usage error (status 2) when the number of input and output
    paths differ.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help="path to encoder.json",
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help="path to vocab.bpe",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=["-"],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=["-"],
        help="path to save encoded outputs",
    )
    parser.add_argument(
        "--keep-empty",
        action="store_true",
        help="keep empty lines",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    # Validate explicitly rather than with `assert`: assertions are removed
    # under `python -O`, which would let mismatched paths through silently.
    if len(args.inputs) != len(args.outputs):
        parser.error("number of input and output paths should match")

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(in_path, "r", encoding="utf-8"))
            if in_path != "-"
            else sys.stdin
            for in_path in args.inputs
        ]
        outputs = [
            stack.enter_context(open(out_path, "w", encoding="utf-8"))
            if out_path != "-"
            else sys.stdout
            for out_path in args.outputs
        ]

        encoder = MultiprocessingEncoder(args)
        # Register the pool on the stack so the worker processes are torn
        # down even if reading, encoding or writing raises. The stack unwinds
        # in reverse order, so the pool is stopped before the files close.
        pool = stack.enter_context(
            Pool(args.workers, initializer=encoder.initializer)
        )
        # zip(*inputs) yields one tuple of parallel lines per step and stops
        # at the shortest input; 100 is the imap chunk size.
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)

        stats = Counter()
        for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
            if filt == "PASS":
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(enc_line, file=output_h)
            else:
                stats["num_filtered_" + filt] += 1
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
|
|
|
|
|
class MultiprocessingEncoder(object):
    """
    Per-line GPT-2 BPE encode/decode helpers built around parsed CLI args.

    The BPE object is not stored on the instance; ``initializer`` places it
    in the module-level global ``bpe``, which ``encode`` and ``decode`` read.
    """

    def __init__(self, args):
        """Keep ``args``; the methods read ``encoder_json``, ``vocab_bpe``
        and ``keep_empty`` from it."""
        self.args = args

    def initializer(self):
        """Build the BPE encoder from the configured paths and publish it as
        the module-level global ``bpe``."""
        global bpe
        args = self.args
        bpe = get_encoder(args.encoder_json, args.vocab_bpe)

    def encode(self, line):
        """Return the BPE ids of ``line`` as a list of strings."""
        return [str(token_id) for token_id in bpe.encode(line)]

    def decode(self, tokens):
        """Return whatever ``bpe.decode`` produces for ``tokens``."""
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a group of lines as a single unit.

        Every line is stripped first. If a stripped line is empty and
        ``keep_empty`` is not set, the whole group is rejected with
        ``["EMPTY", None]``. Otherwise the result is ``["PASS", encoded]``,
        where each entry of ``encoded`` is the space-joined ids of one line.
        """
        encoded = []
        for raw in lines:
            text = raw.strip()
            if not text and not self.args.keep_empty:
                return ["EMPTY", None]
            encoded.append(" ".join(self.encode(text)))
        return ["PASS", encoded]

    def decode_lines(self, lines):
        """
        Decode a group of lines of whitespace-separated integer ids.

        Returns ``["PASS", decoded]`` with one decoded entry per input line.
        """
        decoded = [self.decode(map(int, raw.strip().split())) for raw in lines]
        return ["PASS", decoded]
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|