File size: 5,275 Bytes
f59d332
 
 
 
 
 
 
 
ffd4066
f59d332
 
 
 
 
 
 
 
ffd4066
 
f59d332
 
 
 
ffd4066
 
f59d332
ffd4066
f59d332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffd4066
 
 
f59d332
 
 
 
 
 
 
 
 
ffd4066
f59d332
 
ffd4066
 
f59d332
 
 
 
 
 
 
 
 
f3f71bf
f59d332
f3f71bf
f59d332
f3f71bf
 
 
f59d332
 
 
 
ffd4066
f59d332
 
 
ffd4066
f59d332
 
 
 
 
 
fcb09d2
f59d332
 
ffd4066
f59d332
 
 
 
 
 
 
 
 
 
ffd4066
 
 
1b47855
ffd4066
1b47855
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy
from transformers import TokenClassificationPipeline

class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
  """Token classification with constrained decoding over a label-transition mask.

  Instead of taking the per-token argmax, the label sequence is chosen so that
  every consecutive pair of labels is permitted by ``self.transition``
  (0 = allowed, -inf = forbidden).
  """
  def __init__(self,**kwargs):
    """Initialise the base pipeline and build ``self.transition``.

    ``self.transition[a,b]`` is 0 when label id ``b`` may directly follow label
    id ``a`` and ``-inf`` otherwise:

    * after ``B-X`` only ``I-X`` is allowed;
    * after ``I-X`` either ``I-X`` again or any span-opening label;
    * after any other label, any span-opening label.

    Raises:
      KeyError: if a ``B-X`` label has no matching ``I-X`` in ``label2id``.
    """
    super().__init__(**kwargs)
    label2id=self.model.config.label2id
    # Span-opening labels: every "B-" label, plus every label that is neither an
    # "I-" continuation nor marked with "|root", "|l-" or "|r-".
    starters=[k for k in label2id if k.startswith("B-") or not (k.startswith("I-") or k.endswith("|root") or k.find("|l-")>0 or k.find("|r-")>0)]
    self.transition=numpy.full((len(label2id),len(label2id)),-numpy.inf)
    for k,v in label2id.items():
      if k.startswith("B-"):
        followers=["I-"+k[2:]]
      elif k.startswith("I-"):
        followers=[k]+starters
      else:
        followers=starters
      for j in followers:
        self.transition[v,label2id[j]]=0
  def check_model_type(self,supported_models):
    """Accept any model: this override deliberately performs no check."""
    pass
  def postprocess(self,model_outputs,**kwargs):
    """Decode ``model_outputs`` into a list of labelled spans.

    When ``model_outputs`` has no ``"logits"`` entry it is assumed to be a
    sequence of outputs and only its first element is processed.
    """
    if "logits" not in model_outputs:
      return self.postprocess(model_outputs[0],**kwargs)
    return self.bellman_ford_token_classification(model_outputs,**kwargs)
  def bellman_ford_token_classification(self,model_outputs,**kwargs):
    """Pick the best label path allowed by ``self.transition`` and build span dicts.

    Args:
      model_outputs: mapping providing ``"logits"``, ``"offset_mapping"`` and
        ``"sentence"``; only the first item of logits/offsets is used.
      **kwargs: ``aggregation_strategy`` other than ``"none"`` merges ``I-``
        tokens into the preceding span.

    Returns:
      A list of dicts with ``start``, ``end``, ``score`` and ``text``, plus
      ``entity`` (no aggregation) or ``entity_group`` (aggregated). ``score``
      is the softmax probability of the chosen label; for merged spans it is
      the minimum over the merged tokens.
    """
    # Copy: .numpy() shares memory with the tensor and the backward pass below
    # accumulates in place, which would otherwise overwrite the caller's logits.
    m=model_outputs["logits"][0].numpy().copy()
    # Numerically stable softmax of the raw logits, used only for reporting scores.
    e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
    z=e/e.sum(axis=-1,keepdims=True)
    # Backward pass: m[i-1,a] gains the best achievable score of positions i.. given
    # that position i-1 carries label a.
    for i in range(m.shape[0]-1,0,-1):
      m[i-1]+=numpy.max(m[i]+self.transition,axis=1)
    # Forward pass: the first position is constrained by row 0 of the transition
    # matrix, every later position by the row of the previously chosen label.
    k=[numpy.argmax(m[0]+self.transition[0])]
    for i in range(1,m.shape[0]):
      k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
    # Tokens with an empty character span (start>=end) are dropped.
    w=[{"entity":self.model.config.id2label[j],"start":s,"end":t,"score":z[i,j]} for i,((s,t),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<t]
    if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
      # Walk backwards so that popping entry i never shifts an index still to visit.
      for i in range(len(w)-1,-1,-1):
        t=w[i]
        p=t.pop("entity")
        if p.startswith("I-") and i>0:
          w[i-1]["score"]=min(w[i-1]["score"],t["score"])
          w[i-1]["end"]=w.pop(i)["end"]
        elif p.startswith(("B-","I-")):
          # A leading "I-" entry has no predecessor to merge into (w[i-1] would wrap
          # around to the last entry), so it is kept as a group of its own.
          t["entity_group"]=p[2:]
        else:
          t["entity_group"]=p
    for t in w:
      t["text"]=model_outputs["sentence"][t["start"]:t["end"]]
    return w

class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline):
  """Pipeline that segments a sentence into spans, picks one head per span and renders rows.

  ``postprocess`` returns a string: a ``# text = ...`` header followed by one
  line of ten tab-separated fields per span (the layout looks like CoNLL-U;
  confirm against consumers before relying on that).
  """
  def __init__(self,**kwargs):
    """Force ``aggregation_strategy="simple"`` and build three label masks.

    ``self.root``, ``self.left_arc`` and ``self.right_arc`` are vectors over
    label ids holding 0 for labels that end with ``"|root"``, contain
    ``"|l-"`` or contain ``"|r-"`` respectively (checked in that order), and
    ``-inf`` everywhere else.
    """
    kwargs["aggregation_strategy"]="simple"
    super().__init__(**kwargs)
    x=self.model.config.label2id
    # Additive masks: adding one to a score vector keeps only the matching labels.
    self.root=numpy.full((len(x)),-numpy.inf)
    self.left_arc=numpy.full((len(x)),-numpy.inf)
    self.right_arc=numpy.full((len(x)),-numpy.inf)
    for k,v in x.items():
      if k.endswith("|root"):
        self.root[v]=0
      elif k.find("|l-")>0:
        self.left_arc[v]=0
      elif k.find("|r-")>0:
        self.right_arc[v]=0
  def postprocess(self,model_outputs,**kwargs):
    """Turn model outputs into a ``# text = ...`` block with one row per span.

    Steps, as visible in the code below:
      1. decode spans with ``bellman_ford_token_classification``;
      2. embed each span text as the sum of its input-token embeddings;
      3. run the model once per span on an ``inputs_embeds`` sequence and keep
         the logits of the last ``len(d)`` positions;
      4. re-index those logits into a square score array where row = head,
         column = dependent, and the diagonal means "is the root";
      5. pick heads with ``chu_liu_edmonds`` and, if several spans ended up as
         their own head, penalise all but the best and pick again;
      6. render the rows.

    Returns:
      str: the header line, one tab-separated row per span, and a trailing
      blank line.
    """
    import torch
    kwargs["aggregation_strategy"]="simple"
    # No "logits" entry: treat model_outputs as a sequence and use its first element.
    if "logits" not in model_outputs:
      return self.postprocess(model_outputs[0],**kwargs)
    w=self.bellman_ford_token_classification(model_outputs,**kwargs)
    d=[t["text"] for t in w]
    # Re-tokenize each span text on its own, without special tokens.
    v=self.tokenizer(d,add_special_tokens=False)
    e=self.model.get_input_embeddings().weight
    m=[]
    for x in v["input_ids"]:
      # A span that tokenizes to nothing falls back to the unknown-token id.
      if x==[]:
        x=[self.tokenizer.unk_token_id]
      m.append(e[x,:].sum(axis=0))
    # Rows len(d) and len(d)+1 of m hold the sep-token and pad-token embeddings.
    m.append(e[self.tokenizer.sep_token_id,:])
    m.append(e[self.tokenizer.pad_token_id,:])
    m=torch.stack(m).to(self.device)
    # Indices of every span vector followed by the sep row.
    k=list(range(len(d)+1))
    e=[]
    with torch.no_grad():
      for i in range(len(d)):
        # Input i: all spans, sep, spans i.., then i copies of the last row of m (pad),
        # so every input has length 2*len(d)+1. Only the last len(d) positions are kept.
        e.append(self.model(inputs_embeds=torch.unsqueeze(m[k+list(range(i,len(d)))+[-1]*i,:],0)).logits[0,-len(d):,:])
    # e has one (len(d) x label) score table per run: shape (len(d), len(d), labels).
    e=torch.stack(e).cpu().numpy()
    # In-place re-indexing. Afterwards entries with row>column only admit "|l-" labels,
    # entries with row<column only "|r-" labels, and the diagonal only "|root" labels.
    # Both sides of each pair come from the same source cell; the tuple assignment
    # evaluates the right-hand side before writing.
    # NOTE(review): reads and writes share one array, so do not reorder these loops.
    for i in range(len(d)):
      for j in range(i):
        e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
      e[-i-1,-i-1]=e[-i-1,0]+self.root
    # Best score and best label id for every (row, column) pair.
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    # Spans that are their own head, i.e. roots.
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      # k: the root with the highest diagonal score; h: a non-positive penalty
      # (global min minus global max of m).
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      # Only the columns of the former roots are touched. A cell keeps its score when its
      # row is also a former root and it is either off-diagonal or the diagonal of k;
      # every other cell in those columns gets the penalty.
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    # Label chosen for each span given its head, split into its "|"-separated parts.
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    t=model_outputs["sentence"].replace("\n"," ")
    u="# text = "+t+"\n"
    for i,j in enumerate(d):
      # Fields: 1-based index, span text (twice), first label part, "_", the middle label
      # parts joined by "|" (or "_" when the label has fewer than three parts), head index
      # (0 for a root, else 1-based), last label part ("root" as is, otherwise with its
      # first two characters removed), "_", and "_" when a gap separates this span from the
      # next one, else "SpaceAfter=No" (which therefore also applies to the last span).
      u+="\t".join([str(i+1),j,j,q[i][0],"_","_" if len(q[i])<3 else "|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(d) and w[i]["end"]<w[i+1]["start"] else "SpaceAfter=No"])+"\n"
    return u+"\n"
  def chu_liu_edmonds(self,matrix):
    """Choose a head (row index) for every column of ``matrix``.

    ``matrix[j,i]`` is the score of row ``j`` as head of column ``i``; picking
    ``i`` itself marks ``i`` as a root (that is how ``postprocess`` reads the
    result). The greedy per-column argmax is returned when the cycle check
    below finds nothing; otherwise one group of nodes is contracted into a
    single extra node, the method recurses on the smaller matrix, and the
    result is expanded back.

    Returns:
      The per-column argmax array when no contraction is needed, otherwise a
      list of head indices.
    """
    # Greedy choice: best-scoring row for every column.
    h=numpy.argmax(matrix,axis=0)
    # Head pointers with self-heads replaced by -1.
    x=[-1 if i==j else j for i,j in enumerate(h)]
    # Two passes, each repeated until x stops changing:
    #   1st: clear the pointer of any node that no remaining pointer refers to;
    #   2nd: replace each remaining pointer by the pointer of the node it refers to.
    # NOTE(review): this appears to leave non-negative entries only on cycle members and
    # then give members of one cycle a shared value -- confirm before modifying.
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        # x is updated in place while being enumerated, so later entries see earlier updates.
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      # Nothing left with a non-negative pointer: accept the greedy choice.
      if max(x)<0:
        return h
    # y: nodes sharing the largest remaining value (the group to contract); x: all others.
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    # Scores relative to the best score in each column.
    z=matrix-numpy.max(matrix,axis=0)
    # Contracted matrix: the x-by-x block, plus one extra last column and last row that
    # stand for the group y, each filled with maxima taken over the members of y.
    m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
    # Solve the contracted problem and translate its indices back: a head inside x maps
    # through x; the extra node as head maps to the best-scoring member of y for that column.
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    # Members of y keep their greedy head for now; every other node takes the translated head.
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    # k[-1] is the head chosen for the extra node. Pick the member of y that receives it:
    # the best-scoring one under that outside head, or by diagonal score when the extra
    # node is its own head.
    i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    # Re-attach that member to the outside head, or make it a root.
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h