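# Custom transformers pipeline that decodes a token-classification model
# into Universal Dependencies parses rendered as CoNLL-U. Labels are
# expected to look like "UPOS|o-or-x|(features|)deprel", with "|root",
# "|l-" and "|r-" marking root, left-arc and right-arc relations, and an
# "X|x|r-goeswith" label rejoining words split across tokens.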
import numpy
from transformers import TokenClassificationPipeline

class UniversalDependenciesPipeline(TokenClassificationPipeline):
  def __init__(self,**kwargs):
    super().__init__(**kwargs)
    x=self.model.config.label2id
    # Additive masks over the label set: 0 for labels of the given kind,
    # -inf elsewhere, so that logits+mask restricts argmax to that kind.
    self.root=numpy.full((len(x)),-numpy.inf)
    self.left_arc=numpy.full((len(x)),-numpy.inf)
    self.right_arc=numpy.full((len(x)),-numpy.inf)
    for k,v in x.items():
      if k.endswith("|root"):
        self.root[v]=0
      elif k.find("|l-")>0:
        self.left_arc[v]=0
      elif k.find("|r-")>0:
        self.right_arc[v]=0
  def check_model_type(self,supported_models):
    # Accept any architecture; skip the pipeline's model-type whitelist.
    pass
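  # Turn the forward output for one sentence (or a batch) into CoNLL-U
  # text: a first pass tags every token, a second pass scores head and
  # dependent arcs, and Chu-Liu/Edmonds extracts the best tree.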
  def postprocess(self,model_outputs,**kwargs):
    import torch
    if "logits" not in model_outputs:
      # Batched input: emit one CoNLL-U block per sentence.
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    # First pass: plain token classification over the whole sentence.
    m=model_outputs["logits"][0].cpu().numpy()
    k=numpy.argmax(m,axis=1).tolist()
    # Pre-select tokens whose best first-pass tag has "o" as its middle
    # field; only selected tokens get full arc scoring in the second pass.
    x=[self.model.config.id2label[i].split("|")[1]=="o" for i in k[1:-1]]
    v=model_outputs["input_ids"][0].tolist()
    off=model_outputs["offset_mapping"][0].tolist()
    # Trim whitespace from each token's character span and drop tokens
    # that are pure whitespace (x is shifted by one: it excludes [CLS]).
    for i,(s,e) in reversed(list(enumerate(off))):
      if s<e:
        d=model_outputs["sentence"][s:e]
        j=len(d)-len(d.lstrip())
        if j>0:
          d=d.lstrip()
          off[i][0]+=j
        j=len(d)-len(d.rstrip())
        if j>0:
          d=d.rstrip()
          off[i][1]-=j
        if d.strip()=="":
          off.pop(i)
          v.pop(i)
          x.pop(i-1)
    if len(x)<127:
      # Short sentence: score arcs for every token.
      x=[True]*len(x)
    else:
      # Long sentence: add the least confident tokens (smallest max logit)
      # first, while the second-pass input stays under 8192 subwords.
      w=sum([len(x)-i+1 if b else 0 for i,b in enumerate(x)])+1
      for i in numpy.argsort(numpy.max(m,axis=1)[1:-1]):
        if not x[i] and w+len(x)-i<8192:
          x[i]=True
          w+=len(x)-i+1
    # Second pass: for each selected token i, append the tail of the
    # sentence from token i onward (v[i+1:] ends with [SEP]), so one long
    # sequence scores token i against every later token.
    w=[self.tokenizer.cls_token_id]
    for i,j in enumerate(x):
      if j:
        w+=v[i+1:]
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor([w]).to(self.device))
    m=e.logits[0].cpu().numpy()
    # Unpack the long sequence into an (n,n,labels) tensor: e[i,j] holds
    # the label scores for token i heading token j, and the diagonal
    # e[i,i] the scores for token i being the root.
    w=len(v)-2
    e=numpy.full((w,w,m.shape[-1]),m.min())
    k=1
    for i in range(w):
      if x[i]:
        e[i,i]=m[k]+self.root
        k+=1
        for j in range(1,w-i):
          e[i+j,i]=m[k]+self.left_arc
          e[i,i+j]=m[k]+self.right_arc
          k+=1
        k+=1 # skip the [SEP] that closes this segment
    # "r-goeswith" rejoins the pieces of a single word, so permit it only
    # on the arc to the immediately following token or as the unbroken
    # continuation of such a chain (r==0 marks the permitted cells).
    g=self.model.config.label2id["X|x|r-goeswith"]
    m,r=numpy.max(e,axis=2),numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        r[i,j]=1
        if numpy.argmax(e[i,j-1])==g and numpy.argmax(m[:,j-1])==i:
          r[i,j]=r[i,j-1]
    e[:,:,g]+=numpy.where(r==0,0,-numpy.inf)
    # Decode the highest-scoring dependency tree.
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    # If several tokens ended up as roots, keep the best-scoring one,
    # restrict the other former roots to heads within that set, and
    # decode again.
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    v=[(s,e) for s,e in off if s<e]
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
      for i,j in reversed(list(enumerate(q[1:],1))):
        if j[-1]=="r-goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"r-goeswith"}:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
        elif v[i-1][1]>v[i][0]:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
    t=model_outputs["sentence"].replace("\n"," ")
    u="# text = "+t+"\n"
    for i,(s,e) in enumerate(v):
      u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","_" if len(q[i])<4 else "|".join(q[i][2:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
  def chu_liu_edmonds(self,matrix):
    # Maximum-spanning-tree decoding: matrix[i,j] scores token i heading
    # token j; returns the head of each token (self-headed means root).
    h=numpy.argmax(matrix,axis=0)
    x=[-1 if i==j else j for i,j in enumerate(h)]
    # Look for a cycle among the greedy heads: first prune nodes that head
    # nobody, then follow the remaining head pointers to a fixed point.
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      if max(x)<0:
        # No cycle: the greedy heads already form a tree.
        return h
    # Contract the cycle (y) into a single node, solve the reduced problem
    # recursively, then expand the solution back onto the original nodes.
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    z=matrix-numpy.max(matrix,axis=0)
    m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    # Break the cycle at its best entry point: re-head that node with the
    # contracted node's chosen head (or make it the root).
    i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h
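
# Minimal usage sketch: the checkpoint id below is a placeholder, not a
# real model name; any checkpoint trained with the label scheme this
# pipeline assumes should work (a fast tokenizer is required, since
# postprocess() relies on offset_mapping).
if __name__=="__main__":
  from transformers import AutoModelForTokenClassification,AutoTokenizer
  ckpt="path-or-id-of-a-ud-goeswith-checkpoint" # hypothetical checkpoint
  nlp=UniversalDependenciesPipeline(model=AutoModelForTokenClassification.from_pretrained(ckpt),tokenizer=AutoTokenizer.from_pretrained(ckpt),aggregation_strategy="simple")
  print(nlp("Hello world"))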