#### py39_cp_cp
#### the star imports are expected to provide T5_B (the zh question-generation
#### wrapper) and t2t / gen_ask_by_span_en (the English counterpart) used below
from zh_mt5_model import *
from en_t2t_model import *
import os
#### tolerate duplicate OpenMP runtimes (a common torch/numpy conflict)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import spacy
import pandas as pd
import numpy as np
import re
import pathlib
import json
import pickle as pkl
from copy import deepcopy
from tqdm import tqdm
from easynmt import EasyNMT
### https://huggingface.co/svjack/squad_gen_qst_zh_v0
path = "svjack/squad_gen_qst_zh_v0"
#### Chinese question generator: fed "sentence|answer span", it predicts a question
asker_zh = T5_B(path,
    device = "cpu")
#### spaCy pipelines for entity extraction in both languages
zh_nlp = spacy.load("zh_core_web_sm")
en_nlp = spacy.load("en_core_web_sm")
#### EasyNMT model, used below for fastText language detection
trans_model = EasyNMT('opus-mt')
def detect_language(text):
    assert isinstance(text, str)
    #### fastText-based language identification shipped with EasyNMT
    lang = trans_model.language_detection_fasttext(text)
    lang = lang.lower().strip()
    #### collapse everything outside zh/en into a single "others" bucket
    if "zh" in lang:
        lang = "zh"
    elif "en" in lang:
        lang = "en"
    else:
        lang = "others"
    assert lang in ["en", "zh", "others"]
    return lang
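#### Usage sketch (outputs assume typical fastText language-id behaviour):
#### detect_language("今天天气很好")        -> "zh"
#### detect_language("What a nice day")  -> "en"
#### detect_language("Bonjour à tous")   -> "others"  (detected as "fr")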
def drop_duplicates_by_col(df, on_col = "aug_sparql_query"):
    #### keep only the first row for each distinct value of on_col
    assert hasattr(df, "size")
    assert on_col in df.columns.tolist()
    req = []
    set_ = set([])
    for i, r in df.iterrows():
        if r[on_col] not in set_:
            set_.add(r[on_col])
            req.append(r)
    return pd.DataFrame(req)
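#### Essentially df.drop_duplicates(subset = on_col, keep = "first"), e.g.
#### drop_duplicates_by_col(pd.DataFrame({"aug_sparql_query": ["a", "a", "b"]}))
#### keeps the first "a" row and the "b" row.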
def sent_with_ents(sent, en_nlp):
    #### run a spaCy pipeline and pair the sentence with its (text, label) entities
    assert isinstance(sent, str)
    doc = en_nlp(sent)
    return (sent, pd.Series(doc.ents).map(
        lambda span: (span.text, span.label_)
    ).values.tolist())
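#### Usage sketch (exact entities depend on the loaded spaCy model):
#### sent_with_ents("Barack Obama was born in Hawaii in 1961.", en_nlp)
#### -> ("Barack Obama was born in Hawaii in 1961.",
####     [("Barack Obama", "PERSON"), ("Hawaii", "GPE"), ("1961", "DATE")])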
def gen_ask_by_span_zh(asker, sent, span):
    if isinstance(span, str):
        span = [span]
    if not span:
        return []
    #### "|" separates sentence and answer span in the model input, so strip it first
    sent = sent.replace("|", "")
    span = list(map(lambda x: x.replace("|", ""), span))
    x = list(map(lambda x: "{}|{}".format(sent, x), span))
    return list(map(
        lambda y: asker.predict(y)
        , x))
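#### Prompt-format sketch: with sent = "奥巴马出生于夏威夷。" and span = "夏威夷",
#### the model input is "奥巴马出生于夏威夷。|夏威夷", and asker.predict should
#### return a question answered by that span (roughly "奥巴马出生于哪里?").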
#### list return
def gen_ask_by_span(asker, sent, span, lang):
    assert lang in ["en", "zh"]
    if lang == "zh":
        return gen_ask_by_span_zh(asker, sent, span)
    else:
        #### the English branch ignores asker and uses the global t2t model
        #### from en_t2t_model
        return gen_ask_by_span_en(t2t, sent, span)
def filter_ent_cate(ent_list, maintain_cate_list = [
        "DATE", "FAC", "GPE", "LOC", "PERSON"
    ]):
    if not ent_list:
        return []
    return list(filter(lambda t2: t2[1] in maintain_cate_list, ent_list))
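#### Keeps only entities whose label is in maintain_cate_list, e.g.
#### filter_ent_cate([("Obama", "PERSON"), ("Apple", "ORG")]) -> [("Obama", "PERSON")]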
def batch_as_list(a, batch_size = int(100000)):
    #### split a sequence into consecutive chunks of at most batch_size elements
    req = []
    for ele in a:
        if not req:
            req.append([])
        if len(req[-1]) < batch_size:
            req[-1].append(ele)
        else:
            req.append([])
            req[-1].append(ele)
    return req
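#### e.g. batch_as_list([1, 2, 3, 4, 5], batch_size = 2) -> [[1, 2], [3, 4], [5]]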
def gen_qst_to_df(paragraph,
    nlp = zh_nlp,
    asker = asker_zh,
    nlp_input = None,
    maintain_cate_list = [
        "DATE", "FAC", "GPE", "LOC", "PERSON"
    ], limit_ents_size = 10, batch_size = 4
):
    #### generator: yields one DataFrame of (entity, entity_cate, question) per batch
    if limit_ents_size is None:
        limit_ents_size = 10000
    assert isinstance(paragraph, str)
    lang = detect_language(paragraph)
    if lang != "zh":
        lang = "en"
    #### the detected language selects the spaCy pipeline, overriding the nlp argument
    nlp = en_nlp if lang == "en" else zh_nlp
    if nlp_input is None:
        _, entity_list = sent_with_ents(paragraph, nlp)
    else:
        _, entity_list = deepcopy(nlp_input)
    if maintain_cate_list:
        entity_list = filter_ent_cate(entity_list, maintain_cate_list = maintain_cate_list)
    entity_list = entity_list[:limit_ents_size]
    if not entity_list:
        return None
    l = batch_as_list(entity_list, batch_size)
    for ele in tqdm(l):
        ents = list(map(lambda x: x[0], ele))
        ent_cates = list(map(lambda x: x[1], ele))
        questions = gen_ask_by_span(asker, paragraph, ents, lang)
        assert len(ele) == len(ent_cates) == len(questions)
        batch_l = list(map(pd.Series, [ents, ent_cates, questions]))
        batch_df = pd.concat(batch_l, axis = 1)
        batch_df.columns = ["entity", "entity_cate", "question"]
        yield batch_df
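#### Minimal end-to-end sketch, guarded so it only runs when this file is executed
#### directly; the zh path only needs the models loaded above:
if __name__ == "__main__":
    paragraph = "奥巴马1961年出生于夏威夷。"
    batches = list(gen_qst_to_df(paragraph))
    if batches:
        qst_df = pd.concat(batches, axis = 0).reset_index(drop = True)
        print(qst_df)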