oiTrans / scripts /concat_joint_data.py
harveen
Harveen | Adding code
74fc30d
raw
history blame
3.64 kB
import os
from tqdm import tqdm
import sys
LANGS = [
"as",
"bn",
"gu",
"hi",
"kn",
"ml",
"mr",
"or",
"pa",
"ta",
"te",
#"ur"
]
def add_token(sent, tag_infos):
""" add special tokens specified by tag_infos to each element in list
tag_infos: list of tuples (tag_type,tag)
each tag_info results in a token of the form: __{tag_type}__{tag}__
"""
tokens = []
for tag_type, tag in tag_infos:
token = '__' + tag_type + '__' + tag + '__'
tokens.append(token)
return ' '.join(tokens) + ' ' + sent
def concat_data(data_dir, outdir, lang_pair_list,
out_src_lang='SRC', out_trg_lang='TGT', split='train'):
"""
data_dir: input dir, contains directories for language pairs named l1-l2
"""
os.makedirs(outdir, exist_ok=True)
out_src_fname = '{}/{}.{}'.format(outdir, split, out_src_lang)
out_trg_fname = '{}/{}.{}'.format(outdir, split, out_trg_lang)
# out_meta_fname='{}/metadata.txt'.format(outdir)
print()
print(out_src_fname)
print(out_trg_fname)
# print(out_meta_fname)
# concatenate train data
if os.path.isfile(out_src_fname):
os.unlink(out_src_fname)
if os.path.isfile(out_trg_fname):
os.unlink(out_trg_fname)
# if os.path.isfile(out_meta_fname):
# os.unlink(out_meta_fname)
for src_lang, trg_lang in tqdm(lang_pair_list):
print('src: {}, tgt:{}'.format(src_lang, trg_lang))
in_src_fname = '{}/{}-{}/{}.{}'.format(
data_dir, src_lang, trg_lang, split, src_lang)
in_trg_fname = '{}/{}-{}/{}.{}'.format(
data_dir, src_lang, trg_lang, split, trg_lang)
if not os.path.exists(in_src_fname):
continue
if not os.path.exists(in_trg_fname):
continue
print(in_src_fname)
os.system('cat {} >> {}'.format(in_src_fname, out_src_fname))
print(in_trg_fname)
os.system('cat {} >> {}'.format(in_trg_fname, out_trg_fname))
# with open('{}/lang_pairs.txt'.format(outdir),'w',encoding='utf-8') as lpfile:
# lpfile.write('\n'.join( [ '-'.join(x) for x in lang_pair_list ] ))
corpus_stats(data_dir, outdir, lang_pair_list, split)
def corpus_stats(data_dir, outdir, lang_pair_list, split):
"""
data_dir: input dir, contains directories for language pairs named l1-l2
"""
with open('{}/{}_lang_pairs.txt'.format(outdir, split), 'w', encoding='utf-8') as lpfile:
for src_lang, trg_lang in tqdm(lang_pair_list):
print('src: {}, tgt:{}'.format(src_lang, trg_lang))
in_src_fname = '{}/{}-{}/{}.{}'.format(
data_dir, src_lang, trg_lang, split, src_lang)
# in_trg_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,trg_lang)
if not os.path.exists(in_src_fname):
continue
print(in_src_fname)
corpus_size = 0
with open(in_src_fname, 'r', encoding='utf-8') as infile:
corpus_size = sum(map(lambda x: 1, infile))
lpfile.write('{}\t{}\t{}\n'.format(
src_lang, trg_lang, corpus_size))
if __name__ == '__main__':
in_dir = sys.argv[1]
out_dir = sys.argv[2]
src_lang = sys.argv[3]
tgt_lang = sys.argv[4]
split = sys.argv[5]
lang_pair_list = []
if src_lang == 'en':
for lang in LANGS:
lang_pair_list.append(['en', lang])
else:
for lang in LANGS:
lang_pair_list.append([lang, 'en'])
concat_data(in_dir, out_dir, lang_pair_list, split=split)