#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 20 07:48:54 2023

@author: peter
"""

import numpy
import pandas
import tokenizers

class CorpusLoader(object):
    """Loads a CSV corpus and yields tokenized (input, output) sample dicts."""

    def __init__(self,path,
                 tokenizer,
                 text_inputs,
                 text_outputs,
                 label=None):
        """
        Creates the Corpus Loader

        Parameters
        ----------
        path : str
            Path of the CSV file to load the dataset from
        tokenizer : tokenizers.Tokenizer
            Tokenizer used to encode the text columns. Its encodings of
            '<s>' and '</s>' are kept as the document start and end markers
        text_inputs : list[str]
            Columns of the dataset to add to the inputs
        text_outputs : dict[str,tuple[str,str]]
            The columns of the dataset to add to the outputs. The key is the name
            of the column in the original dataset, the first element of the tuple
            is the name that the column prefixed with '<s>' will have in the
            inputs, and the second element of the tuple is the name that the column
            suffixed with '</s>' will have in the outputs
        label : str, optional
            A column of numerical labels to add to the outputs. The default is None.

        Returns
        -------
        None.
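
        Examples
        --------
        A minimal sketch; 'reviews.csv', its columns, and ``tok`` (an
        already trained tokenizers.Tokenizer) are hypothetical::

            loader = CorpusLoader('reviews.csv', tok,
                                  text_inputs=['title'],
                                  text_outputs={'body': ('body_in', 'body_out')},
                                  label='rating')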

        """
        data = pandas.read_csv(path)
        self.n_rows = data.shape[0]
        self.text_inputs = text_inputs
        self.text_outputs = text_outputs
        self.label = label
        self.rng = numpy.random.default_rng()
        columns = list(set(self.text_inputs)|set(self.text_outputs.keys()))
        # Tokenize every text column once up front; the '<s>'/'</s>' markers are
        # added per sample in __iter__, so no special tokens are added here.
        tokenized = {column:tokenizer.encode_batch(data[column].astype(str).tolist(),
                                                   add_special_tokens=False)
                     for column in columns}
        if self.label is not None:
            tokenized[self.label] = data[self.label]
            columns.append(self.label)
        self.dataset = [{column:tokenized[column][i]
                         for column in columns}
                        for i in range(self.n_rows)]
        # Encode the document start/end markers themselves, without letting the
        # tokenizer add any further special tokens around them.
        self.start_doc = tokenizer.encode('<s>', add_special_tokens=False)
        self.end_doc = tokenizer.encode('</s>', add_special_tokens=False)
        
    def __len__(self):
        """
        The length of the corpus

        Returns
        -------
        int
            The number of samples

        """
        return self.n_rows
    
    def __iter__(self):
        """
        Generates samples in a random order

        Yields
        ------
        X : dict
            Inputs for the model
        Y : dict
            Outputs for the model

        """
        self.rng.shuffle(self.dataset)
        for row in self.dataset:
            X = {}
            Y = {}
            for column in self.text_inputs:
                X[column] = row[column]
            for (column, (x_name, y_name)) in self.text_outputs.items():
                # The input copy is prefixed with '<s>'; the output copy is
                # suffixed with '</s>'.
                X[x_name] = tokenizers.Encoding.merge([self.start_doc, row[column]])
                Y[y_name] = tokenizers.Encoding.merge([row[column], self.end_doc])
            if self.label is not None:
                Y[self.label] = row[self.label]
            yield (X, Y)
            
    def max_lengths(self):
        """
        The longest tokenized sequence for each input and output column

        Returns
        -------
        dict[str,int]
            Maps each column name used in the inputs and outputs to the
            maximum number of tokens it contains across the corpus. Output
            columns are one longer than the raw text to allow for the
            prepended '<s>' or appended '</s>' marker.

        """
        result = {column: max(len(row[column]) for row in self.dataset)
                  for column in self.text_inputs}
        for (column, (x_name, y_name)) in self.text_outputs.items():
            n = result[column] if column in result else max(len(row[column])
                                                            for row in self.dataset)
            result[x_name] = n + 1
            result[y_name] = n + 1
        return result
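

# A minimal usage sketch, assuming a hypothetical 'reviews.csv' with 'title',
# 'body' and 'rating' columns and a tokenizer file 'tokenizer.json' whose
# vocabulary includes the '<s>' and '</s>' markers; adapt the paths and column
# names to the real dataset.
if __name__ == '__main__':
    tok = tokenizers.Tokenizer.from_file('tokenizer.json')
    loader = CorpusLoader('reviews.csv', tok,
                          text_inputs=['title'],
                          text_outputs={'body': ('body_in', 'body_out')},
                          label='rating')
    print(len(loader), 'samples')
    print(loader.max_lengths())
    # Take one shuffled (inputs, outputs) pair and show the token count of
    # each input encoding.
    X, Y = next(iter(loader))
    print({name: len(encoding) for name, encoding in X.items()})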