#### py39_cp_cp
from zh_mt5_model import *  # provides T5_B, the Chinese question-generation wrapper
from en_t2t_model import *  # provides t2t and gen_ask_by_span_en, the English counterparts
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
import spacy
import pandas as pd
import numpy as np
import re
import pathlib
import json
import pickle as pkl
from copy import deepcopy
from tqdm import tqdm
from easynmt import EasyNMT

### https://huggingface.co/svjack/squad_gen_qst_zh_v0
path = "svjack/squad_gen_qst_zh_v0"
asker_zh = T5_B(path, device="cpu")

zh_nlp = spacy.load("zh_core_web_sm")
en_nlp = spacy.load("en_core_web_sm")
trans_model = EasyNMT("opus-mt")

def detect_language(text):
    """Classify text as "zh", "en", or "others" via EasyNMT's fastText detector."""
    assert isinstance(text, str)
    lang = trans_model.language_detection_fasttext(text)
    lang = lang.lower().strip()
    if "zh" not in lang and "en" not in lang:
        lang = "others"
    if "zh" in lang:
        lang = "zh"
    if "en" in lang:
        lang = "en"
    assert lang in ["en", "zh", "others"]
    return lang

def drop_duplicates_by_col(df, on_col="aug_sparql_query"):
    """Keep the first row for each distinct value of on_col, preserving row order."""
    assert hasattr(df, "size")
    assert on_col in df.columns.tolist()
    req = []
    seen = set()
    for _, r in df.iterrows():
        if r[on_col] not in seen:
            seen.add(r[on_col])
            req.append(r)
    return pd.DataFrame(req)

def sent_with_ents(sent, nlp):
    """Return (sent, [(entity_text, entity_label), ...]) using the given spaCy pipeline."""
    assert isinstance(sent, str)
    doc = nlp(sent)
    return (sent, [(span.text, span.label_) for span in doc.ents])

def gen_ask_by_span_zh(asker, sent, span):
    """Generate one Chinese question per answer span.

    The model expects inputs of the form "sentence|answer_span", so "|" is
    stripped from both sides before formatting. Returns a list of questions.
    """
    if isinstance(span, str):
        span = [span]
    if not span:
        return []
    sent = sent.replace("|", "")
    span = [s.replace("|", "") for s in span]
    inputs = ["{}|{}".format(sent, s) for s in span]
    return [asker.predict(x) for x in inputs]

def gen_ask_by_span(asker, sent, span, lang):
    """Dispatch question generation by language.

    For English, the asker argument is ignored and the module-level t2t model
    from en_t2t_model is used instead.
    """
    assert lang in ["en", "zh"]
    if lang == "zh":
        return gen_ask_by_span_zh(asker, sent, span)
    return gen_ask_by_span_en(t2t, sent, span)

def filter_ent_cate(ent_list, maintain_cate_list=["DATE", "FAC", "GPE", "LOC", "PERSON"]):
    """Keep only (text, label) pairs whose label is in maintain_cate_list."""
    if not ent_list:
        return []
    return [t2 for t2 in ent_list if t2[1] in maintain_cate_list]

def batch_as_list(a, batch_size=100000):
    """Split an iterable into consecutive chunks of at most batch_size items."""
    req = []
    for ele in a:
        if not req or len(req[-1]) >= batch_size:
            req.append([])
        req[-1].append(ele)
    return req

def gen_qst_to_df(
    paragraph,
    nlp=zh_nlp,
    asker=asker_zh,
    nlp_input=None,
    maintain_cate_list=["DATE", "FAC", "GPE", "LOC", "PERSON"],
    limit_ents_size=10,
    batch_size=4,
):
    """Yield DataFrames of (entity, entity_cate, question) for one paragraph.

    The language is auto-detected (anything that is not Chinese is treated as
    English) and the matching spaCy pipeline is selected, overriding the nlp
    default. Entities are optionally filtered by category, truncated to
    limit_ents_size, and questions are generated batch by batch. Because this
    is a generator, the early return below yields an empty generator (not
    None) when no entities survive filtering.
    """
    if limit_ents_size is None:
        limit_ents_size = 10000
    assert isinstance(paragraph, str)
    lang = detect_language(paragraph)
    if lang != "zh":
        lang = "en"
    nlp = en_nlp if lang == "en" else zh_nlp
    if nlp_input is None:
        _, entity_list = sent_with_ents(paragraph, nlp)
    else:
        _, entity_list = deepcopy(nlp_input)
    if maintain_cate_list:
        entity_list = filter_ent_cate(entity_list, maintain_cate_list=maintain_cate_list)
    entity_list = entity_list[:limit_ents_size]
    if not entity_list:
        return
    for ele in tqdm(batch_as_list(entity_list, batch_size)):
        ents = [x[0] for x in ele]
        ent_cates = [x[1] for x in ele]
        questions = gen_ask_by_span(asker, paragraph, ents, lang)
        assert len(ele) == len(ent_cates) == len(questions)
        batch_df = pd.concat(list(map(pd.Series, [ents, ent_cates, questions])), axis=1)
        batch_df.columns = ["entity", "entity_cate", "question"]
        yield batch_df
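
### Usage sketch, not part of the original script: runs the pipeline end to
### end on one paragraph. The sample text is made up for illustration, and it
### assumes zh_mt5_model / en_t2t_model are importable so the models above
### load successfully.
if __name__ == "__main__":
    sample = "苹果公司由史蒂夫·乔布斯等人于1976年4月1日在加利福尼亚州创立。"
    # gen_qst_to_df is a generator of per-batch DataFrames; collect and stack them.
    batches = list(gen_qst_to_df(sample, limit_ents_size=10, batch_size=4))
    if batches:
        qst_df = pd.concat(batches, axis=0).reset_index(drop=True)
        # Reuse the helper above to drop repeated questions across batches.
        qst_df = drop_duplicates_by_col(qst_df, on_col="question")
        print(qst_df)
    else:
        print("no entities found in the paragraph")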