File size: 2,607 Bytes
e149b0f
 
 
 
 
 
 
 
 
 
 
 
4f8366b
 
 
 
 
 
 
 
 
 
 
 
 
 
e149b0f
 
 
 
 
4f8366b
 
 
 
 
 
 
 
e149b0f
 
 
 
4f8366b
 
 
 
 
 
 
 
 
 
 
 
 
 
e149b0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 11 09:46:51 2023

@author: peter
"""

from allennlp.predictors.predictor import Predictor
import pandas

def clean(sentence):
    """
    Append a terminating full stop to a sentence when it lacks one.

    Trailing whitespace is ignored when checking for an existing full
    stop, but the sentence itself is returned unmodified in that case.

    Parameters
    ----------
    sentence : str
        Sentence to be cleaned

    Returns
    -------
    str
        Sentence guaranteed to end with a full stop (possibly followed
        by the original trailing whitespace).

    """
    if sentence.strip().endswith('.'):
        return sentence
    return sentence + '.'

class CoreferenceResolver(object):

    def __init__(self):
        """
        Build the resolver by loading the pretrained SpanBERT
        coreference model from the public AllenNLP model store.

        Returns
        -------
        None.

        """
        model_url = "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz"
        self.predictor = Predictor.from_path(model_url)

    def __call__(self, group):
        """
        Resolve coreferences across a group of sentences.

        Every mention in a coreference cluster is replaced by the
        cluster's longest mention (its canonical form), and the
        sentences are returned re-assembled on their original index.

        Parameters
        ----------
        group : pandas.Series
            Sentences on which to perform coreference resolution

        Returns
        -------
        pandas.Series
            Sentences with coreferences resolved

        """
        # Terminate each sentence with a full stop, then split on
        # whitespace so all sentences can be concatenated into a single
        # token document for the predictor.
        token_lists = group.apply(
            lambda s: s if s.strip().endswith('.') else s + '.').str.split()
        # Cumulative token counts mark where each sentence ends inside
        # the flattened document.
        boundaries = token_lists.apply(len).cumsum()
        flat = []
        for tokens in token_lists:
            flat.extend(tokens)
        prediction = self.predictor.predict_tokenized(flat)
        # Map each mention's start position to its (exclusive) end and
        # to the canonical (longest) mention of its cluster.
        spans = {}
        for cluster in prediction['clusters']:
            mention_starts = []
            longest_len = -1
            canonical_tokens = None
            for start, end in cluster:
                spans[start] = {'end': end + 1}
                mention_starts.append(start)
                # Strictly greater keeps the first of equally long
                # mentions as the canonical form.
                if end - start > longest_len:
                    longest_len = end - start
                    canonical_tokens = flat[start:end + 1]
            for start in mention_starts:
                spans[start]['canonical'] = canonical_tokens
        # Walk the document, substituting each mention with its
        # canonical form and cutting the stream back into sentences at
        # the recorded boundaries.
        position = 0
        sentence_idx = 0
        resolved = []
        buffer = []
        while position < len(flat):
            if position in spans:
                buffer.extend(spans[position]['canonical'])
                position = spans[position]['end']
            else:
                buffer.append(flat[position])
                position += 1
            if position >= boundaries.iloc[sentence_idx]:
                resolved.append(' '.join(buffer))
                sentence_idx += 1
                buffer = []
        return pandas.Series(resolved,
                             index=group.index)