"""
Divide a document collection into N-word/token passage spans (with wrap-around for last passage).
"""
import os
import math
import ujson
import random
from multiprocessing import Pool
from argparse import ArgumentParser
from colbert.utils.utils import print_message
Format1 = 'docid,text' # MS MARCO Passages
Format2 = 'docid,text,title' # DPR Wikipedia
Format3 = 'docid,url,title,text' # MS MARCO Documents
def process_page(inp):
    """
    Split one document into fixed-size passage spans.

    Wraps around if we split: make sure last passage isn't too short.
    This is meant to be similar to the DPR preprocessing.
    """
    (nwords, overlap, tokenizer), (title_idx, docid, title, url, content) = inp

    # Whitespace tokenization by default; WordPiece when a tokenizer is supplied.
    tokens = content.split() if tokenizer is None else tokenizer.tokenize(content)

    # Double the token stream when the document overflows a single passage so
    # the final span can wrap around to the start instead of coming out short.
    source = tokens * 2 if len(tokens) > nwords else tokens

    step = nwords - overlap
    passages = [source[start:start + nwords] for start in range(0, len(tokens) - overlap, step)]

    # Every span is either a full nwords-sized passage or the whole (short) document.
    assert all(len(psg) in [len(tokens), nwords] for psg in passages), (list(map(len, passages)), len(tokens))

    if tokenizer is None:
        passages = [' '.join(psg) for psg in passages]
    else:
        # Undo WordPiece's '##' continuation markers when detokenizing.
        passages = [' '.join(psg).replace(' ##', '') for psg in passages]

    # Periodic debug output: dump one sampled document and its passages.
    if title_idx % 100000 == 0:
        print("#> ", title_idx, '\t\t\t', title)

        for p in passages:
            print("$$$ ", '\t\t', p)

        print()
        print()
        print()
        print()

    return (docid, title, url, passages)
def main(args):
    """
    Read the input collection, split every document into passage spans in
    parallel, and write the passage collection to
    '<input>.<w|t><nwords>_<overlap>' (one passage per line, TSV).

    Raises AssertionError if the output path already exists.
    """
    random.seed(12345)

    print_message("#> Starting...")

    # 'w' = whitespace words, 't' = WordPiece tokens.
    letter = 'w' if not args.use_wordpiece else 't'
    output_path = f'{args.input}.{letter}{args.nwords}_{args.overlap}'
    assert not os.path.exists(output_path)

    RawCollection = []
    Collection = []

    NumIllFormattedLines = 0

    with open(args.input) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (100*1000) == 0:
                print(line_idx, end=' ')

            title, url = None, None

            try:
                line = line.strip().split('\t')

                if args.format == Format1:
                    docid, doc = line
                elif args.format == Format2:
                    docid, doc, title = line
                elif args.format == Format3:
                    docid, url, title, doc = line

                RawCollection.append((line_idx, docid, title, url, doc))
            # FIX: was a bare `except:`, which silently swallowed everything
            # (including KeyboardInterrupt and genuine bugs). A wrong number of
            # tab-separated fields raises ValueError from the tuple unpacking.
            except ValueError:
                NumIllFormattedLines += 1

                if NumIllFormattedLines % 1000 == 0:
                    print(f'\n[{line_idx}] NumIllFormattedLines = {NumIllFormattedLines}\n')

    print()
    print_message("# of documents is", len(RawCollection), '\n')

    print_message("#> Starting parallel processing...")

    tokenizer = None
    if args.use_wordpiece:
        from transformers import BertTokenizerFast
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

    # Same per-document parameters for every worker invocation.
    process_page_params = [(args.nwords, args.overlap, tokenizer)] * len(RawCollection)

    # FIX: the Pool was never closed/joined; the context manager terminates the
    # worker processes even if an exception occurs mid-map.
    with Pool(args.nthreads) as pool:
        Collection = pool.map(process_page, zip(process_page_params, RawCollection))

    print_message(f"#> Writing to {output_path} ...")

    with open(output_path, 'w') as f:
        line_idx = 1

        # Header row matches the columns written per format below.
        if args.format == Format1:
            f.write('\t'.join(['id', 'text']) + '\n')
        elif args.format == Format2:
            f.write('\t'.join(['id', 'text', 'title']) + '\n')
        elif args.format == Format3:
            f.write('\t'.join(['id', 'text', 'title', 'docid']) + '\n')

        for docid, title, url, passages in Collection:
            for passage in passages:
                if args.format == Format1:
                    f.write('\t'.join([str(line_idx), passage]) + '\n')
                elif args.format == Format2:
                    f.write('\t'.join([str(line_idx), passage, title]) + '\n')
                elif args.format == Format3:
                    f.write('\t'.join([str(line_idx), passage, title, docid]) + '\n')

                # Passage ids are sequential across the whole output file.
                line_idx += 1
if __name__ == "__main__":
    parser = ArgumentParser(description="docs2passages.")

    # Input Arguments.
    parser.add_argument('--input', dest='input', required=True)
    parser.add_argument('--format', dest='format', required=True, choices=[Format1, Format2, Format3])

    # Output Arguments.
    parser.add_argument('--use-wordpiece', dest='use_wordpiece', default=False, action='store_true')
    parser.add_argument('--nwords', dest='nwords', default=100, type=int)
    parser.add_argument('--overlap', dest='overlap', default=0, type=int)

    # Other Arguments.
    parser.add_argument('--nthreads', dest='nthreads', default=28, type=int)

    args = parser.parse_args()

    assert args.nwords in range(50, 500)
    # FIX: validate overlap up front. overlap >= nwords would make the slicing
    # step in process_page non-positive and crash (opaquely) inside the worker
    # pool; a negative overlap would silently produce gaps between passages.
    assert 0 <= args.overlap < args.nwords, "overlap must be in [0, nwords)"

    main(args)