import numpy
from transformers import TokenClassificationPipeline

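# Pipeline that turns a token-classification model's UD-style tags into full
# dependency trees: head candidates are scored by masking tokens one at a time,
# a tree is decoded with Chu-Liu-Edmonds, and the result is emitted as CoNLL-U.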
class UniversalDependenciesPipeline(TokenClassificationPipeline):
  def __init__(self,**kwargs):
    super().__init__(**kwargs)
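    # Build additive masks over the label set: `root` keeps only labels ending
    # in "|[root]", `arc` keeps the other bracketed dependency labels; every
    # remaining label stays at -inf so it can never be selected.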
    x=self.model.config.label2id
    self.root=numpy.full((len(x)),-numpy.inf)
    self.arc=numpy.full((len(x)),-numpy.inf)
    for k,v in x.items():
      if k.endswith("|[root]"):
        self.root[v]=0
      elif k.endswith("]"):
        self.arc[v]=0
  def _forward(self,model_inputs):
    import torch
    v=model_inputs["input_ids"][0].tolist()
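    # Decide which tokens get a head-probing masked copy: short inputs (fewer
    # than 91 ids) probe every token; longer inputs run a preliminary pass,
    # select tokens whose predicted label's POS field lacks a trailing ".",
    # then add the lowest-confidence tokens while the concatenated second-pass
    # input stays within 8,192 ids.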
    if len(v)<91:
      x=[True]*(len(v)-2)
    else:
      with torch.no_grad():
        e=self.model(input_ids=torch.tensor([v]).to(self.device))
      m=e.logits[0].cpu().numpy()
      e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
      z=e/e.sum(axis=-1,keepdims=True)
      k=numpy.argmax(m,axis=1).tolist()
      x=[not self.model.config.id2label[p].split("|")[0].endswith(".") for p in k[1:-1]]
      w=(sum([1 for b in x if b])+1)*(len(x)+1)+1
      for i in numpy.argsort([z[i+1,k[i+1]] for i in range(len(x))]):
        if w+len(x)>8191:
          break
        if not x[i]:
          x[i]=True
          w+=len(x)+1
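    # For every selected token, append a copy of the sentence (without its
    # leading special token) in which that token is replaced by the mask token;
    # one forward pass over the concatenation scores all candidate heads at once.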
    ids=list(v)
    for i in range(len(x)):
      if x[i]:
        ids+=v[1:i+1]+[self.tokenizer.mask_token_id]+v[i+2:]
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor([ids]).to(self.device))
    return {"logits":e.logits,"thin_out":x,**model_inputs}
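  # Disable the base pipeline's model-type check so any token-classification
  # model can be wrapped.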
  def check_model_type(self,supported_models):
    pass
  def postprocess(self,model_outputs,**kwargs):
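    # Without "logits", model_outputs is an iterable of per-chunk outputs:
    # recurse over it and concatenate the resulting CoNLL-U strings.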
    if "logits" not in model_outputs:
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    m=model_outputs["logits"][0].cpu().numpy()
    x=model_outputs["thin_out"]
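    # Rebuild an (n, n, labels) score tensor from the concatenated logits:
    # the first len(x)+2 positions belong to the unmasked sentence, then each
    # masked copy contributes len(x)+1 positions. e[i,j] takes the logits of
    # token j read from the copy in which token i was masked; the diagonal is
    # restricted to [root] labels and off-diagonal cells to the other
    # dependency labels via the masks built in __init__.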
    e=numpy.full((len(x),len(x),m.shape[-1]),m.min())
    k=len(x)+2
    for i in range(len(x)):
      if x[i]:
        for j in range(len(x)):
          if i==j:
            e[i,i]=m[k]+self.root
          else:
            e[i,j]=m[k]+self.arc
          k+=1
        k+=1
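    # Restrict the goeswith relation: it may only link a token to the one
    # immediately after it, or extend further right along an unbroken chain
    # whose previous link is already best-labelled goeswith with the same head;
    # everywhere else it is masked out.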
    g=self.model.config.label2id["X.|[goeswith]"]
    m,r=numpy.max(e,axis=2),numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        r[i,j]=1
        if numpy.argmax(e[i,j-1])==g:
          if numpy.argmax(m[:,j-1])==i:
            r[i,j]=r[i,j-1]
    e[:,:,g]+=numpy.where(r==0,0,-numpy.inf)
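    # Decode the best head for every token with Chu-Liu-Edmonds (h[i]==i marks
    # the root); if several tokens end up self-headed, keep only the
    # highest-scoring candidate as root, demote the competitors and decode again.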
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
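    # Unless aggregation_strategy is "none", merge goeswith dependents and
    # overlapping subword spans into the preceding token, re-indexing the
    # remaining heads.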
    if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
      for i,j in reversed(list(enumerate(q[1:],1))):
        if j[-1]=="[goeswith]" and set([t[-1] for t in q[h[i]+1:i+1]])=={"[goeswith]"}:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
        elif v[i-1][1]>v[i][0]:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
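    # Trim leading/trailing whitespace from each token span and drop spans that
    # become empty, again re-indexing the heads.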
    t=model_outputs["sentence"].replace("\n"," ")
    for i,(s,e) in reversed(list(enumerate(v))):
      d=t[s:e]
      j=len(d)-len(d.lstrip())
      if j>0:
        d=d.lstrip()
        v[i]=(v[i][0]+j,v[i][1])
      j=len(d)-len(d.rstrip())
      if j>0:
        d=d.rstrip()
        v[i]=(v[i][0],v[i][1]-j)
      if d.strip()=="":
        h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
        v.pop(i)
        q.pop(i)
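    # Emit the CoNLL-U block: a "# text" comment line followed by one
    # ten-column row per token.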
    u="# text = "+t+"\n"
    for i,(s,e) in enumerate(v):
      u+="\t".join([str(i+1),t[s:e],"_",q[i][0].replace(".",""),"_","_" if len(q[i])<3 else "|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1][1:-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
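  # Recursive Chu-Liu-Edmonds maximum spanning arborescence over the arc-score
  # matrix: returns, for every node, the index of its head (a self-loop marks
  # the root).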
  def chu_liu_edmonds(self,matrix):
    import numpy
    h=numpy.argmax(matrix,axis=0)
    x=[-1 if i==j else j for i,j in enumerate(h)]
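    # Two fixed-point passes prune nodes that head nothing and propagate head
    # pointers; if everything collapses to -1 no cycle remains and the greedy
    # per-column choice h is already a tree.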
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      if max(x)<0:
        return h
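    # Otherwise y gathers the nodes of the detected cycle (those sharing the
    # maximal marker) and x the rest: contract the cycle into one extra node,
    # recurse on the reduced score matrix, then map the heads back and break
    # the cycle at its best entering arc (or make it the root).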
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    z=matrix-numpy.max(matrix,axis=0)
    m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h