# QARAC/qarac/utils/CoreferenceResolver.py
# PeteBleackley
# Components for managing training corpora
# 4f8366b
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 11 09:46:51 2023
@author: peter
"""
from allennlp.predictors.predictor import Predictor
import pandas
def clean(sentence):
    """
    Make sure a sentence is terminated by a full stop.

    Trailing whitespace is ignored when checking for the full stop, but the
    sentence itself is never stripped: if a full stop has to be added it is
    appended to the sentence exactly as given.

    Parameters
    ----------
    sentence : str
        Sentence to be cleaned

    Returns
    -------
    str
        The sentence unchanged if it already ends with a full stop,
        otherwise the sentence with '.' appended.
    """
    already_terminated = sentence.rstrip().endswith('.')
    if already_terminated:
        return sentence
    return sentence + '.'
class CoreferenceResolver:
    """
    Callable that resolves coreferences across a group of sentences.

    The sentences are concatenated into one whitespace-tokenized document,
    the coreference predictor is run over it, and every mention in a cluster
    is replaced by the longest mention of that cluster. The document is then
    split back into one output string per input sentence.
    """

    # Model archive loaded when the caller does not supply one.
    DEFAULT_MODEL_URL = "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz"

    def __init__(self, model_url=None):
        """
        Creates the Coreference resolver

        Parameters
        ----------
        model_url : str, optional
            Path or URL passed to ``Predictor.from_path``. Defaults to
            ``DEFAULT_MODEL_URL``.

        Returns
        -------
        None.
        """
        if model_url is None:
            model_url = self.DEFAULT_MODEL_URL
        self.predictor = Predictor.from_path(model_url)

    def __call__(self, group):
        """
        Replace every coreferent mention with its cluster's longest mention.

        Parameters
        ----------
        group : pandas.Series
            Sentences on which to perform coreference resolution

        Returns
        -------
        pandas.Series
            Sentences with coreferences resolved, carrying the same index as
            ``group``. If a replaced mention spans one or more complete
            sentences, the replacement text is emitted with the sentence in
            which the mention starts and the swallowed sentences come back
            as empty strings, so the result always has one entry per input
            sentence.
        """
        # Nothing to resolve: skip the predictor entirely for an empty group.
        if len(group) == 0:
            return pandas.Series([], index=group.index, dtype=object)

        tokenized = group.apply(clean).str.split()
        # Cumulative token counts: line_breaks[i] is the document position
        # just past the last token of sentence i. A plain list avoids a
        # pandas lookup for every token in the loop below.
        line_breaks = tokenized.apply(len).cumsum().tolist()
        doc = [token for line in tokenized for token in line]
        clusters = self.predictor.predict_tokenized(doc)

        # Map each mention's start position to (position just past the
        # mention, tokens of the canonical mention to substitute).
        resolutions = {}
        for cluster in clusters['clusters']:
            if not cluster:
                continue
            # Spans are inclusive [start, end] pairs; the longest one is the
            # canonical mention (max keeps the first on ties).
            (first, last) = max(cluster, key=lambda span: span[1] - span[0])
            canonical = doc[first:last + 1]
            for (start_pos, end_pos) in cluster:
                resolutions[start_pos] = (end_pos + 1, canonical)

        results = []
        current = []
        line = 0
        doc_pos = 0
        n_lines = len(line_breaks)
        while doc_pos < len(doc):
            if doc_pos in resolutions:
                (doc_pos, canonical) = resolutions[doc_pos]
                current.extend(canonical)
            else:
                current.append(doc[doc_pos])
                doc_pos += 1
            # A replaced mention may jump over several sentence boundaries at
            # once, so close every sentence that ends at or before doc_pos,
            # not just one; otherwise the result would have fewer entries
            # than group.index.
            while line < n_lines and doc_pos >= line_breaks[line]:
                results.append(' '.join(current))
                current = []
                line += 1
        return pandas.Series(results,
                             index=group.index)