"""
    Divide a document collection into N-word/token passage spans (with wrap-around for last passage).
"""

import os
import random

from multiprocessing import Pool
from argparse import ArgumentParser
from colbert.utils.utils import print_message

Format1 = 'docid,text'  # MS MARCO Passages
Format2 = 'docid,text,title'   # DPR Wikipedia
Format3 = 'docid,url,title,text'  # MS MARCO Documents


def process_page(inp):
    """
        Wraps around if we split: make sure last passage isn't too short.
        This is meant to be similar to the DPR preprocessing.
    """

    (nwords, overlap, tokenizer), (title_idx, docid, title, url, content) = inp

    if tokenizer is None:
        words = content.split()
    else:
        words = tokenizer.tokenize(content)

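    # Wrap-around sketch (illustrative numbers): with nwords=100 and overlap=0, a
    # 230-word document gives offsets 0, 100, and 200 over the doubled word list,
    # so the last passage is words 200-229 followed by words 0-69. Every passage
    # is exactly nwords long; a document shorter than nwords stays as one passage.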
    words_ = (words + words) if len(words) > nwords else words
    passages = [words_[offset:offset + nwords] for offset in range(0, len(words) - overlap, nwords - overlap)]

    assert all(len(psg) in [len(words), nwords] for psg in passages), (list(map(len, passages)), len(words))

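    # For WordPiece tokens, rejoining with spaces reintroduces the '##' subword
    # markers; stripping ' ##' glues the subwords back into whole words.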
    if tokenizer is None:
        passages = [' '.join(psg) for psg in passages]
    else:
        passages = [' '.join(psg).replace(' ##', '') for psg in passages]

    if title_idx % 100000 == 0:
        print("#> ", title_idx, '\t\t\t', title)

        for p in passages:
            print("$$$ ", '\t\t', p)
            print()

        print()
        print()
        print()

    return (docid, title, url, passages)


def main(args):
    random.seed(12345)
    print_message("#> Starting...")

    letter = 'w' if not args.use_wordpiece else 't'
    output_path = f'{args.input}.{letter}{args.nwords}_{args.overlap}'
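    # e.g. '/path/to/collection.tsv' with --nwords 100 --overlap 0 becomes
    # '/path/to/collection.tsv.w100_0', or '.t100_0' when --use-wordpiece is set.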
    assert not os.path.exists(output_path)

    RawCollection = []
    Collection = []

    NumIllFormattedLines = 0

    with open(args.input) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (100*1000) == 0:
                print(line_idx, end=' ')

            title, url = None, None

            try:
                line = line.strip().split('\t')

                if args.format == Format1:
                    docid, doc = line
                elif args.format == Format2:
                    docid, doc, title = line
                elif args.format == Format3:
                    docid, url, title, doc = line

                RawCollection.append((line_idx, docid, title, url, doc))
            except ValueError:  # wrong number of tab-separated fields for this --format
                NumIllFormattedLines += 1

                if NumIllFormattedLines % 1000 == 0:
                    print(f'\n[{line_idx}] NumIllFormattedLines = {NumIllFormattedLines}\n')

    print()
    print_message("# of documents is", len(RawCollection), '\n')

    print_message("#> Starting parallel processing...")

    tokenizer = None
    if args.use_wordpiece:
        from transformers import BertTokenizerFast
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

    process_page_params = [(args.nwords, args.overlap, tokenizer)] * len(RawCollection)

    with Pool(args.nthreads) as p:
        Collection = p.map(process_page, zip(process_page_params, RawCollection))
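    # Pool.map returns results in input order, so passages are written out in the
    # original collection order.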

    print_message(f"#> Writing to {output_path} ...")
    with open(output_path, 'w') as f:
        line_idx = 1
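        # Passages are re-numbered sequentially from 1; the original docid is kept
        # only as the last column of the Format3 output.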

        if args.format == Format1:
            f.write('\t'.join(['id', 'text']) + '\n')
        elif args.format == Format2:
            f.write('\t'.join(['id', 'text', 'title']) + '\n')
        elif args.format == Format3:
            f.write('\t'.join(['id', 'text', 'title', 'docid']) + '\n')

        for docid, title, url, passages in Collection:
            for passage in passages:
                if args.format == Format1:
                    f.write('\t'.join([str(line_idx), passage]) + '\n')
                elif args.format == Format2:
                    f.write('\t'.join([str(line_idx), passage, title]) + '\n')
                elif args.format == Format3:
                    f.write('\t'.join([str(line_idx), passage, title, docid]) + '\n')

                line_idx += 1


if __name__ == "__main__":
    parser = ArgumentParser(description="docs2passages.")

    # Input Arguments.
    parser.add_argument('--input', dest='input', required=True)
    parser.add_argument('--format', dest='format', required=True, choices=[Format1, Format2, Format3])

    # Output Arguments.
    parser.add_argument('--use-wordpiece', dest='use_wordpiece', default=False, action='store_true')
    parser.add_argument('--nwords', dest='nwords', default=100, type=int)
    parser.add_argument('--overlap', dest='overlap', default=0, type=int)

    # Other Arguments.
    parser.add_argument('--nthreads', dest='nthreads', default=28, type=int)

    args = parser.parse_args()
    assert args.nwords in range(50, 500)

    main(args)