File size: 4,749 Bytes
0ef6f8e
 
 
 
 
 
 
e384b95
 
 
0ef6f8e
 
 
 
 
 
 
 
 
 
e384b95
0ef6f8e
 
 
e384b95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ef6f8e
 
e384b95
 
 
 
 
 
 
 
 
dcef4a2
0ef6f8e
 
dcef4a2
 
 
e384b95
 
0ef6f8e
 
 
e384b95
0ef6f8e
 
e384b95
0ef6f8e
 
 
 
 
 
 
 
 
 
 
 
 
 
e384b95
0ef6f8e
 
e384b95
0ef6f8e
 
 
 
 
 
 
 
 
 
e384b95
 
 
0ef6f8e
e384b95
0ef6f8e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy
from transformers import TokenClassificationPipeline

class UniversalDependenciesPipeline(TokenClassificationPipeline):
  """Token-classification pipeline that outputs a dependency parse as CoNLL-U.

  The logits handed to postprocess (first pass) decide which tokens are
  expanded. A second forward pass is run on a sequence of the form
  CLS + (token i + every later token + SEP) for each expanded token i, and its
  logits are read as root / arc / label scores. A maximum spanning tree over
  those scores (chu_liu_edmonds) gives the heads.

  Labels are "|"-separated strings. As written out below, field 0 goes to the
  UPOS column, fields 2..-2 to FEATS, and the last field to DEPREL ("root", or
  an "l-"/"r-" prefix followed by the relation).
  """
  def __init__(self,**kwargs):
    """Build additive label masks from the model's label2id.

    Each mask has one entry per label id: 0 where the label may be used for
    that kind of arc and -inf elsewhere. Adding a mask to a logit vector
    therefore rules out every other label.
      root:      labels ending in "|root"   (self-loops e[i,i])
      left_arc:  labels containing "|l-"    (head to the right of dependent)
      right_arc: labels containing "|r-"    (head to the left of dependent)
    """
    super().__init__(**kwargs)
    x=self.model.config.label2id
    self.root=numpy.full((len(x)),-numpy.inf)
    self.left_arc=numpy.full((len(x)),-numpy.inf)
    self.right_arc=numpy.full((len(x)),-numpy.inf)
    for k,v in x.items():
      if k.endswith("|root"):
        self.root[v]=0
      elif k.find("|l-")>0:
        self.left_arc[v]=0
      elif k.find("|r-")>0:
        self.right_arc[v]=0
  def check_model_type(self,supported_models):
    """No-op override of the base pipeline's model-type check.

    NOTE(review): presumably here so models not registered for token
    classification are accepted without a warning/error — confirm.
    """
    pass
  def postprocess(self,model_outputs,**kwargs):
    """Turn the model outputs for a sentence (or a list of them) into CoNLL-U.

    Args:
      model_outputs: dict with "logits", "input_ids", "offset_mapping" and
        "sentence" (only batch index [0] is read, so a batch of one is
        assumed), or an iterable of such dicts whose results are concatenated.
      **kwargs: if "aggregation_strategy" is given and is not "none", subword
        tokens are merged into words before output.

    Returns:
      A "# text = ..." comment line, one 10-column tab-separated line per
      token, and a terminating blank line.
    """
    import torch
    if "logits" not in model_outputs:
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    # First pass: argmax label per token. x[i] flags word token i (CLS/SEP
    # excluded) whose label has "o" as its second "|"-field.
    m=model_outputs["logits"][0].cpu().numpy()
    k=numpy.argmax(m,axis=1).tolist()
    x=[self.model.config.id2label[i].split("|")[1]=="o" for i in k[1:-1]]
    v=model_outputs["input_ids"][0].tolist()
    off=model_outputs["offset_mapping"][0].tolist()
    # Trim surrounding whitespace from each token's character span and drop
    # tokens that cover only whitespace. Iterating in reverse keeps the
    # indices still to be visited valid after a pop.
    for i,(s,e) in reversed(list(enumerate(off))):
      if s<e:
        d=model_outputs["sentence"][s:e]
        j=len(d)-len(d.lstrip())
        if j>0:
          d=d.lstrip()
          off[i][0]+=j
        j=len(d)-len(d.rstrip())
        if j>0:
          d=d.rstrip()
          off[i][1]-=j
        if d.strip()=="":
          off.pop(i)
          v.pop(i)
          x.pop(i-1)
          # Keep the first-pass logits row-aligned with v/x. The ranking below
          # uses positions in m to index x. A stale m would select the wrong
          # token, or run past the end of x.
          m=numpy.delete(m,i,axis=0)
    # Choose which tokens to expand in the second pass. Expanding token i
    # costs len(x)-i+1 input positions (token i, every later token, SEP).
    # With at most 126 tokens, expanding all of them needs at most 8128
    # positions. Otherwise, start from the "o"-flagged tokens and add the
    # rest, lowest first-pass maximum logit first, while the total stays
    # within 8192.
    if len(x)<127:
      x=[True]*len(x)
    else:
      w=sum([len(x)-i+1 if b else 0 for i,b in enumerate(x)])+1
      for i in numpy.argsort(numpy.max(m,axis=1)[1:-1]):
        if x[i]==False and w+len(x)-i<8192:
          x[i]=True
          w+=len(x)-i+1
    # Second-pass input: CLS, then v[i+1:] for each expanded token i, i.e.
    # token i followed by all later tokens and the final SEP.
    w=[self.tokenizer.cls_token_id]
    for i,j in enumerate(x):
      if j:
        w+=v[i+1:]
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor([w]).to(self.device))
    m=e.logits[0].cpu().numpy()
    # Scatter the second-pass logits into e[head,dependent,label]:
    # - the position of token i itself scores i as a root (self-loop);
    # - the position of the j-th following token scores head i+j for
    #   dependent i (left arc) and head i for dependent i+j (right arc).
    # The final k+=1 skips the SEP that closes each segment. Pairs of an
    # unexpanded token with later tokens keep the floor value m.min().
    w=len(v)-2
    e=numpy.full((w,w,m.shape[-1]),m.min())
    k=1
    for i in range(w):
      if x[i]:
        e[i,i]=m[k]+self.root
        k+=1
        for j in range(1,w-i):
          e[i+j,i]=m[k]+self.left_arc
          e[i,i+j]=m[k]+self.right_arc
          k+=1
        k+=1
    # Allow "X|x|r-goeswith" only on arcs head i -> dependent j where j==i+1,
    # or where j-1 is already attached to i by goeswith (goeswith is its best
    # label from i, and i is its best head). This gives a contiguous chain to
    # the right of i. r starts lower-triangular (ones on and below the
    # diagonal), so the allowed arcs are the entries left at 0. Every other
    # arc gets -inf on the goeswith label.
    g=self.model.config.label2id["X|x|r-goeswith"]
    m,r=numpy.max(e,axis=2),numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        r[i,j]=1
        if numpy.argmax(e[i,j-1])==g and numpy.argmax(m[:,j-1])==i:
          r[i,j]=r[i,j-1]
    e[:,:,g]+=numpy.where(r==0,0,-numpy.inf)
    # Best label per arc, then a maximum spanning tree over the arc scores.
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    # h[i]==i marks a root. If several roots were chosen:
    # - keep the candidate k with the highest self-loop score;
    # - add a penalty (min-max) to the other candidates' self-loops and to any
    #   arc into a candidate from a non-candidate head;
    # - then solve again.
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    # Character spans of the word tokens (empty spans are dropped; presumably
    # only CLS/SEP have them) and the "|"-split label of each token's arc.
    v=[(s,e) for s,e in off if s<e]
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    # Optional subword merging, right to left. Token i is folded into token
    # i-1 when either:
    # - it is r-goeswith and every token from just after its head through i is
    #   r-goeswith; or
    # - its span overlaps the previous token's span.
    # Heads >= i are decremented, so arcs that pointed at i now point at the
    # merged token i-1.
    if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
      for i,j in reversed(list(enumerate(q[1:],1))):
        if j[-1]=="r-goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"r-goeswith"}:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
        elif v[i-1][1]>v[i][0]:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
    # CoNLL-U columns: ID, FORM, LEMMA(_), UPOS, XPOS(_), FEATS, HEAD (0 for
    # a root), DEPREL (l-/r- prefix stripped), DEPS(_), and MISC. MISC is
    # SpaceAfter=No when no gap follows the token, and always on the last token.
    t=model_outputs["sentence"].replace("\n"," ")
    u="# text = "+t+"\n"
    for i,(s,e) in enumerate(v):
      u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","_" if len(q[i])<4 else "|".join(q[i][2:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
  def chu_liu_edmonds(self,matrix):
    """Maximum spanning arborescence (Chu-Liu/Edmonds) over a square score matrix.

    matrix[head,dependent] is the arc score, and the diagonal scores a node as
    a root. Returns h, where h[i] is the chosen head of node i (h[i]==i for a
    root). h is a numpy array when the greedy choice is already acyclic, and
    a list otherwise.
    """
    # Greedy start: best head for each dependent (column-wise argmax).
    h=numpy.argmax(matrix,axis=0)
    # Look for a cycle among the greedy heads (roots mapped to -1).
    # - Pass 1 repeatedly clears nodes that are nobody's head. If that clears
    #   everything, there is no cycle.
    # - Pass 2 follows head pointers until nothing changes.
    x=[-1 if i==j else j for i,j in enumerate(h)]
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      if max(x)<0:
        return h
    # y: nodes that collapsed onto the largest value (presumably the members
    # of one cycle); x: every other node.
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    # Contract y into a single extra node placed after the x nodes. Scores are
    # normalized by each column's maximum. Then solve the smaller problem
    # recursively.
    z=matrix-numpy.max(matrix,axis=0)
    m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
    # Expand the solution back to original indices:
    # - an arc from the contracted node goes from the cycle member that scores
    #   best for that dependent;
    # - k[-1] is the contracted node's own head (len(x) means it is a root).
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    # Non-cycle nodes take their heads from the expanded solution; cycle
    # members keep their greedy heads for now.
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    # Break the cycle by re-attaching one member to the contracted node's
    # head, or by making it a root when the contracted node was a root.
    i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h