File size: 4,393 Bytes
72fcc23 27dfa14 2063543 27dfa14 72fcc23 27dfa14 72fcc23 27dfa14 72fcc23 27dfa14 72fcc23 27dfa14 92e4087 27dfa14 e7fbbfd 27dfa14 92e4087 27dfa14 1bce38c 27dfa14 1bce38c ffcf3f0 1bce38c d80634a 1bce38c 84b67ad cbdb566 84b67ad e7fbbfd 1bce38c 27dfa14 72fcc23 27dfa14 72fcc23 27dfa14 72fcc23 27dfa14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import numpy
from transformers import TokenClassificationPipeline
class UniversalDependenciesPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that renders its result as CoNLL-U-style text.

    ``_forward`` scores every (head candidate, dependent) pair of inner tokens,
    ``postprocess`` constrains those scores, selects one head per token with
    ``chu_liu_edmonds`` and formats ten tab-separated columns per token under
    a ``# text = ...`` header line.
    """

    def _forward(self, model_inputs):
        """Run the model on one masked variant of the input per inner token.

        For every position between the first and the last input id, a row is
        built in which that position is replaced by the tokenizer's mask id
        and the replaced id is appended at the end of the row.

        Returns:
            ``model_inputs`` extended with a ``"logits"`` entry: the model
            logits with the first and the last two sequence positions dropped.
        """
        import torch

        ids = model_inputs["input_ids"][0].tolist()
        rows = []
        for pos, tok in enumerate(ids[1:-1], 1):
            rows.append(ids[:pos] + [self.tokenizer.mask_token_id] + ids[pos + 1:] + [tok])
        with torch.no_grad():
            output = self.model(input_ids=torch.tensor(rows, device=self.device))
        return {"logits": output.logits[:, 1:-2, :], **model_inputs}

    def postprocess(self, model_outputs, **kwargs):
        """Turn the pairwise logits into a CoNLL-U-style string.

        Args:
            model_outputs: either a mapping holding ``"logits"``,
                ``"sentence"``, ``"offset_mapping"`` and ``"input_ids"``, or
                an iterable of such mappings (their renderings are
                concatenated).
            **kwargs: if ``aggregation_strategy`` is present and not equal to
                ``"none"``, ``goeswith`` tokens are merged into the preceding
                token and a lemma is produced; otherwise the lemma column is
                ``"_"``.

        Returns:
            A string starting with ``# text = <sentence>`` followed by one
            line per token (optionally preceded by a range line) and a
            trailing blank line.
        """
        if "logits" not in model_outputs:
            return "".join(self.postprocess(item, **kwargs) for item in model_outputs)

        # NOTE(review): `.numpy()` presumably shares memory with the tensor, so
        # the in-place updates below would also alter model_outputs["logits"]
        # -- confirm against the torch documentation.
        scores = model_outputs["logits"].numpy()
        config = self.model.config

        # Per label: 1 for label id 0, -1 for "...|root" labels, 0 otherwise.
        # Added to the identity matrix, the sum is zero exactly for root
        # labels on the diagonal and for the "0" labels off the diagonal;
        # every other (row, column, label) cell is pushed to -inf.
        label_kind = [
            1 if label_id == 0 else -1 if label.endswith("|root") else 0
            for label_id, label in sorted(config.id2label.items())
        ]
        clash = numpy.add.outer(numpy.identity(scores.shape[0]), label_kind)
        scores += numpy.where(clash == 0, 0, -numpy.inf)

        # Restrict the goeswith label: non-zero cells of `forbid` lose it.
        # Cells on/below the diagonal start at 1, the cell right of the
        # diagonal stays 0, and a cell further right inherits its left
        # neighbour's value only if that neighbour is best labelled goeswith
        # in this row and this row is also its best-scoring row.
        goeswith_id = config.label2id["X|_|goeswith"]
        best = numpy.max(scores, axis=2)
        forbid = numpy.tri(scores.shape[0])
        for row in range(scores.shape[0]):
            for col in range(row + 2, scores.shape[1]):
                forbid[row, col] = 1
                prev_is_goeswith = numpy.argmax(scores[row, col - 1]) == goeswith_id
                if prev_is_goeswith and numpy.argmax(best[:, col - 1]) == row:
                    forbid[row, col] = forbid[row, col - 1]
        scores[:, :, goeswith_id] += numpy.where(forbid == 0, 0, -numpy.inf)

        best = numpy.max(scores, axis=2)
        best_label = numpy.argmax(scores, axis=2)
        heads = self.chu_liu_edmonds(best)

        # A token that is its own head is a root.  With several roots, keep
        # the one with the highest diagonal score and penalise, in the root
        # columns, every row that is not a root and every other self-loop,
        # then select heads again.
        roots = [idx for idx, head in enumerate(heads) if idx == head]
        if len(roots) > 1:
            kept_root = roots[numpy.argmax(best[roots, roots])]
            penalty = numpy.min(best) - numpy.max(best)
            adjustment = []
            for head_cand in range(best.shape[0]):
                adjustment.append([
                    0 if head_cand in roots and (root != head_cand or root == kept_root) else penalty
                    for root in roots
                ])
            best[:, roots] += adjustment
            heads = self.chu_liu_edmonds(best)

        text = model_outputs["sentence"].replace("\n", " ")
        # (start, end, form) for every token with a non-empty character span;
        # unknown tokens take their form from the sentence text.
        tokens = [
            (start, end, piece if piece != self.tokenizer.unk_token else text[start:end])
            for (start, end), piece in zip(
                model_outputs["offset_mapping"][0].tolist(),
                self.tokenizer.convert_ids_to_tokens(model_outputs["input_ids"][0].tolist()),
            )
            if start < end
        ]
        # Label of each token under its selected head, split on "|".
        tags = [config.id2label[best_label[head, dep]].split("|") for dep, head in enumerate(heads)]

        aggregate = "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"] != "none"
        if aggregate:
            # Walk right-to-left so that removals do not disturb the indices
            # still to be visited.
            for idx, tag in reversed(list(enumerate(tags[1:], 1))):
                if tag[-1] != "goeswith":
                    continue
                if {between[-1] for between in tags[heads[idx] + 1:idx + 1]} != {"goeswith"}:
                    continue
                # Drop token idx: renumber heads, glue its span and form onto
                # the preceding token.
                heads = [head if idx > head else head - 1 for dep, head in enumerate(heads) if idx != dep]
                _, merged_end, merged_form = tokens.pop(idx)
                previous = tokens[idx - 1]
                tokens[idx - 1] = (previous[0], merged_end, previous[2] + merged_form)
                tags.pop(idx)

        # Latin letter -> kana used for the boundary character of a token
        # whose character span overlaps a neighbouring token's span.
        kana = {"a":"ア","i":"イ","u":"ウ","e":"エ","o":"オ","k":"ㇰ","s":"ㇱ","t":"ㇳ","n":"ㇴ","h":"ㇷ","m":"ㇺ","r":"ㇽ","p":"ㇷ゚"}
        # Applied to the lemma in exactly this order.
        lemma_rewrites = (
            ("sh", "s"), ("ch", "c"), ("au", "aw"), ("iu", "iw"), ("eu", "ew"), ("uu", "uw"),
            ("ou", "ow"), ("ai", "ay"), ("ui", "uy"), ("ei", "ey"), ("oi", "oy"),
        )

        chunks = []  # collected last token first, reversed when joining
        range_end = -1  # index of the last token of a pending overlap group
        for idx, (start, end, form) in reversed(list(enumerate(tokens))):
            if text[start] == "\u309a":
                # Span starts on U+309A: include the character before it.
                start -= 1
            chars = list(text[start:end])
            range_line = ""
            if idx > 0 and start < tokens[idx - 1][1]:
                # Overlaps the previous token: it belongs to an overlap group.
                chars[0] = kana.get(form[0], "ッ")
                range_end = max(range_end, idx)
            elif range_end > 0:
                # First token of a pending group: emit the "first-last" line.
                space_after = (
                    "_"
                    if range_end + 1 < len(tokens) and tokens[range_end][1] < tokens[range_end + 1][0]
                    else "SpaceAfter=No"
                )
                range_line = "{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t{}\n".format(
                    idx + 1, range_end + 1, text[start:tokens[range_end][1]], space_after
                )
                range_end = -1
            if idx + 1 < len(tokens) and end > tokens[idx + 1][0]:
                chars[-1] = kana.get(form[-1], "ッ")

            tag = tags[idx]
            if aggregate:
                lemma = "".join(chars).replace(" ", "") if max(chars) < "z" else form
                for old, new in lemma_rewrites:
                    lemma = lemma.replace(old, new)
                if tag[1] == "人称接辞" and "=" not in lemma:
                    # "=" goes on the side facing the head.
                    lemma = "=" + lemma if idx > heads[idx] else lemma + "="
            else:
                lemma = "_"

            head_field = str(0 if heads[idx] == idx else heads[idx] + 1)
            misc = "_" if idx + 1 < len(tokens) and end < tokens[idx + 1][0] else "SpaceAfter=No"
            columns = [
                str(idx + 1),
                "".join(chars),
                lemma,
                tag[0],
                "|".join(tag[1:-1]),
                "_",
                head_field,
                tag[-1],
                "_",
                misc,
            ]
            chunks.append(range_line + "\t".join(columns) + "\n")

        return "# text = " + text + "\n" + "".join(reversed(chunks)) + "\n"

    def chu_liu_edmonds(self, matrix):
        """Select one head per column of ``matrix``, resolving cycles recursively.

        Args:
            matrix: square 2-D score array; the best row of each column is
                taken as that column's head, and a column whose best row is
                its own index counts as a root.

        Returns:
            A sequence (numpy array or list) whose element ``j`` is the head
            index chosen for ``j``; ``heads[j] == j`` marks a root.
        """
        heads = numpy.argmax(matrix, axis=0)
        state = [-1 if idx == head else head for idx, head in enumerate(heads)]

        def drop_unreferenced(values, idx, current):
            # A node that nobody points to cannot be on a cycle.
            return -1 if idx not in values else values[idx]

        def follow_pointer(values, idx, current):
            # Jump along the pointer so each cycle collapses onto one node.
            return -1 if current < 0 else values[current]

        for update in (drop_unreferenced, follow_pointer):
            snapshot = []
            while state != snapshot:
                snapshot = list(state)
                # `state` is updated in place while being enumerated on purpose.
                for idx, current in enumerate(state):
                    state[idx] = update(state, idx, current)
            if max(state) < 0:
                # Nothing survived: the greedy choice has no cycle.
                return heads

        # Contract the cycle carrying the largest marker; `rest` is the others.
        marker = max(state)
        cycle = [idx for idx, value in enumerate(state) if value == marker]
        rest = [idx for idx, value in enumerate(state) if value < marker]

        # Scores relative to each column's best, then a matrix over `rest`
        # plus one extra last row/column standing for the contracted cycle.
        relative = matrix - numpy.max(matrix, axis=0)
        contracted = numpy.block([
            [
                relative[rest, :][:, rest],
                numpy.max(relative[rest, :][:, cycle], axis=1).reshape(len(rest), 1),
            ],
            [
                numpy.max(relative[cycle, :][:, rest], axis=0),
                numpy.max(relative[cycle, cycle]),
            ],
        ])

        # Translate the recursive solution back to original indices.  The
        # last entry (the contracted node itself) is kept untranslated.
        resolved = []
        for pos, sub_head in enumerate(self.chu_liu_edmonds(contracted)):
            if pos == len(rest):
                resolved.append(sub_head)
            elif sub_head < len(rest):
                resolved.append(rest[sub_head])
            else:
                resolved.append(cycle[numpy.argmax(relative[cycle, rest[pos]])])

        heads = [head if idx in cycle else resolved[rest.index(idx)] for idx, head in enumerate(heads)]

        # Break the cycle at the member best reached from outside, or at the
        # member with the best self score when the contracted node is a root.
        cycle_head = resolved[-1]
        if cycle_head < len(rest):
            entry = cycle[numpy.argmax(relative[rest[cycle_head], cycle])]
            heads[entry] = rest[cycle_head]
        else:
            entry = cycle[numpy.argmax(relative[cycle, cycle])]
            heads[entry] = entry
        return heads
|