File size: 4,749 Bytes
86b8c35
 
 
 
 
 
 
1a91052
 
 
86b8c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f52a6db
86b8c35
 
 
 
 
 
 
 
 
 
 
 
 
 
f52a6db
 
 
 
 
 
 
 
 
86b8c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adbaeb5
86b8c35
 
adbaeb5
 
 
1a91052
 
86b8c35
 
 
1a91052
86b8c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a91052
86b8c35
 
 
 
 
 
 
 
 
 
1a91052
 
 
86b8c35
1a91052
86b8c35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy
from transformers import TokenClassificationPipeline

class UniversalDependenciesPipeline(TokenClassificationPipeline):
  def __init__(self,**kwargs):
    """Build the pipeline and precompute one additive mask per arc type.

    All keyword arguments are forwarded unchanged to the parent constructor.
    Afterwards three float vectors with one slot per entry of
    ``self.model.config.label2id`` are created, each holding 0 for labels of
    its own kind and -inf for every other label:

    * ``self.root``      -- labels ending in "|root"
    * ``self.left_arc``  -- labels containing "|l-" (not at position 0)
    * ``self.right_arc`` -- labels containing "|r-" (not at position 0)

    Adding a mask to a logit vector therefore rules out every label of the
    other kinds.  Labels matching none of the patterns stay at -inf in all
    three masks.
    """
    super().__init__(**kwargs)
    label2id=self.model.config.label2id
    size=len(label2id)
    self.root=numpy.full(size,-numpy.inf)
    self.left_arc=numpy.full(size,-numpy.inf)
    self.right_arc=numpy.full(size,-numpy.inf)
    for label,index in label2id.items():
      # The first matching pattern wins, in the order root / left / right.
      if label.endswith("|root"):
        mask=self.root
      elif label.find("|l-")>0:
        mask=self.left_arc
      elif label.find("|r-")>0:
        mask=self.right_arc
      else:
        continue
      mask[index]=0
  def check_model_type(self,supported_models):
    """Accept every model: ``supported_models`` is ignored and nothing is checked.

    NOTE(review): presumably this overrides the parent pipeline's model-type
    validation so that custom architectures load without complaints -- confirm
    against the transformers base class.
    """
    return None
  def postprocess(self,model_outputs,**kwargs):
    """Decode model outputs into a CoNLL-U-style dependency parse string.

    The method works in two passes.  The logits delivered in ``model_outputs``
    only decide which tokens are expanded; the model is then run a second time
    on a sequence made of one token suffix per expanded token, and the
    resulting logits fill a (head, dependent, label) score tensor from which
    the tree is decoded with ``self.chu_liu_edmonds``.

    Args:
      model_outputs: either a mapping with the keys "logits", "input_ids",
        "offset_mapping" and "sentence", or an iterable of such mappings
        (each element is processed on its own and the texts are concatenated).
      **kwargs: only "aggregation_strategy" is consulted; any value other than
        "none" merges "goeswith" continuations and overlapping spans into the
        preceding token.

    Returns:
      A string starting with a "# text = ..." line, followed by one line of
      ten tab-separated columns per token and a terminating blank line.

    Raises:
      KeyError: if the model has no label named "X|x|r-goeswith".
    """
    import torch  # local import: only needed for the second forward pass below
    if "logits" not in model_outputs:
      # Not a single output mapping: decode every element and join the results.
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    m=model_outputs["logits"][0].cpu().numpy()
    k=numpy.argmax(m,axis=1).tolist()
    # x[i] tells whether token i is expanded in the second pass: True when the
    # second "|"-separated field of its first-pass label is "o".  The first and
    # last positions are skipped (presumably CLS/SEP -- TODO confirm).
    x=[self.model.config.id2label[i].split("|")[1]=="o" for i in k[1:-1]]
    # c[i] is the index into the per-token confidences (rows 1..-2 of m) that
    # belongs to x[i].  It is popped together with x so that the confidence
    # ranking below stays aligned after whitespace-only tokens are dropped;
    # ranking the untrimmed logits would yield shifted or out-of-range indices.
    c=list(range(len(x)))
    v=model_outputs["input_ids"][0].tolist()
    off=model_outputs["offset_mapping"][0].tolist()
    # Walk backwards so that popping never shifts an index still to be visited.
    for i,(s,e) in reversed(list(enumerate(off))):
      if s<e:
        d=model_outputs["sentence"][s:e]
        # Shrink the character span so it excludes leading/trailing whitespace.
        j=len(d)-len(d.lstrip())
        if j>0:
          d=d.lstrip()
          off[i][0]+=j
        j=len(d)-len(d.rstrip())
        if j>0:
          d=d.rstrip()
          off[i][1]-=j
        if d.strip()=="":
          # Whitespace-only token: remove it from every parallel list.
          off.pop(i)
          v.pop(i)
          x.pop(i-1)
          c.pop(i-1)
    if len(x)<127:
      # Short input: expand every token.
      x=[True]*len(x)
    else:
      # w is the length of the second-pass input for the current flags: one
      # leading id plus len(x)-i+1 ids for every expanded token i.
      w=sum([len(x)-i+1 if b else 0 for i,b in enumerate(x)])+1
      # Expand further tokens, least confident first-pass prediction first,
      # as long as the second-pass input stays within 8192 ids.
      for i in numpy.argsort(numpy.max(m,axis=1)[1:-1][c]):
        if not x[i] and w+len(x)-i<8192:
          x[i]=True
          w+=len(x)-i+1
    # Second-pass input: the CLS id followed, for every expanded token i, by
    # the id suffix that starts at token i and runs to the end of the sequence.
    w=[self.tokenizer.cls_token_id]
    for i,j in enumerate(x):
      if j:
        w+=v[i+1:]
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor([w]).to(self.device))
    m=e.logits[0].cpu().numpy()
    w=len(v)-2
    # e[head,dependent] holds label scores; cells that are never written keep
    # the global minimum of the second-pass logits.
    e=numpy.full((w,w,m.shape[-1]),m.min())
    k=1  # row 0 of m belongs to the leading CLS id
    for i in range(w):
      if x[i]:
        # First row of the suffix scores token i as root ...
        e[i,i]=m[k]+self.root
        k+=1
        # ... and each following row scores the pair (i, i+j) in both
        # directions, restricted to left-arc / right-arc labels respectively.
        for j in range(1,w-i):
          e[i+j,i]=m[k]+self.left_arc
          e[i,i+j]=m[k]+self.right_arc
          k+=1
        k+=1  # skip the row of the id that ends the suffix
    g=self.model.config.label2id["X|x|r-goeswith"]
    # r marks the cells where the goeswith label is forbidden (non-zero).  It
    # starts as a lower triangle of ones, so r[i,i+1] is always allowed; for
    # j>=i+2 the cell is allowed only when token j-1 is itself best labelled
    # goeswith with best head i and r[i,j-1] was allowed, i.e. an unbroken chain.
    m,r=numpy.max(e,axis=2),numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        r[i,j]=1
        if numpy.argmax(e[i,j-1])==g and numpy.argmax(m[:,j-1])==i:
          r[i,j]=r[i,j-1]
    e[:,:,g]+=numpy.where(r==0,0,-numpy.inf)
    # m: best score per (head,dependent) cell, p: the label achieving it.
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      # More than one self-headed token: keep the one with the best diagonal
      # score (k) as root and penalise, in the columns of all former roots,
      # every head outside z and every self-loop other than k's; then decode again.
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    v=[(s,e) for s,e in off if s<e]
    # q[i]: "|"-separated fields of the label chosen for token i and its head.
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
      # Backwards again, because merged tokens are removed while iterating.
      for i,j in reversed(list(enumerate(q[1:],1))):
        if j[-1]=="r-goeswith" and {t[-1] for t in q[h[i]+1:i+1]}=={"r-goeswith"}:
          # Token i continues a goeswith run: fold it into token i-1 and shift
          # every head index at or beyond i down by one.
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
        elif v[i-1][1]>v[i][0]:
          # Character spans overlap: merge in the same way.
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
    t=model_outputs["sentence"].replace("\n"," ")
    u=["# text = "+t+"\n"]
    for i,(s,e) in enumerate(v):
      # Columns: index, surface form, "_", first label field, "_", middle label
      # fields (or "_"), head index (0 for the root), relation name, "_", and
      # "_" when a gap precedes the next token, otherwise "SpaceAfter=No".
      u.append("\t".join([str(i+1),t[s:e],"_",q[i][0],"_","_" if len(q[i])<4 else "|".join(q[i][2:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n")
    return "".join(u)+"\n"
  def chu_liu_edmonds(self,matrix):
    """Pick one head per node so that the greedy head graph has no cycle.

    ``matrix[i,j]`` is the score of node ``i`` being the head of node ``j``;
    a diagonal entry scores the node as its own head (a root).  Every node
    first takes its best-scoring head.  If that choice contains a cycle, one
    cycle is contracted into a single node, the smaller problem is solved
    recursively, and the cycle is reopened at the node where the contracted
    solution enters it (contraction scheme as suggested by the method name).

    Returns:
      The head index of every node; ``heads[i]==i`` marks a root.  The result
      is the numpy array from ``argmax`` when no cycle had to be resolved and
      a list otherwise.
    """
    heads=numpy.argmax(matrix,axis=0)
    # Working copy of the head pointers with self-headed nodes marked as -1.
    link=[-1 if node==head else head for node,head in enumerate(heads)]
    # Pass 1: repeatedly clear nodes that nobody points to; only nodes lying
    # on a cycle keep a non-negative pointer.  Updates are made in place.
    previous=[]
    while link!=previous:
      previous=list(link)
      for node in range(len(link)):
        if node not in link:
          link[node]=-1
    if max(link)<0:
      return heads
    # Pass 2: follow pointers in place until nothing changes, which leaves a
    # common marker on all members of a cycle.
    previous=[]
    while link!=previous:
      previous=list(link)
      for node,target in enumerate(link):
        link[node]=-1 if target<0 else link[target]
    if max(link)<0:
      return heads
    # Contract the cycle carrying the largest marker; everything else is kept.
    top=max(link)
    cycle=[node for node,mark in enumerate(link) if mark==top]
    rest=[node for node,mark in enumerate(link) if mark<top]
    n_rest=len(rest)
    # Scores relative to the best head of each column.
    scaled=matrix-numpy.max(matrix,axis=0)
    rest_rows=scaled[rest,:]
    cycle_rows=scaled[cycle,:]
    into_cycle=numpy.max(rest_rows[:,cycle],axis=1).reshape(n_rest,1)
    out_of_cycle=numpy.max(cycle_rows[:,rest],axis=0)
    within_cycle=numpy.max(scaled[cycle,cycle])
    # The contracted cycle becomes the last row/column of the smaller problem.
    contracted=numpy.block([[rest_rows[:,rest],into_cycle],[out_of_cycle,within_cycle]])
    resolved=[]
    for position,head in enumerate(self.chu_liu_edmonds(contracted)):
      if position==n_rest:
        # Head of the contracted node, still expressed in contracted indices.
        resolved.append(head)
      elif head<n_rest:
        resolved.append(rest[head])
      else:
        # Headed by the contracted node: use the best-scoring cycle member.
        resolved.append(cycle[numpy.argmax(scaled[cycle,rest[position]])])
    # Cycle members keep their greedy heads; the others take the resolved ones.
    result=list(heads)
    for position,node in enumerate(rest):
      result[node]=resolved[position]
    # Reopen the cycle at exactly one member.
    entry=resolved[-1]
    if entry<n_rest:
      outside=rest[entry]
      opened=cycle[numpy.argmax(scaled[outside,cycle])]
      result[opened]=outside
    else:
      # The contracted node is its own head: one cycle member becomes a root.
      opened=cycle[numpy.argmax(scaled[cycle,cycle])]
      result[opened]=opened
    return result