import re
import unicodedata
from pathlib import Path

import gradio as gr
import torch
import unidic_lite
from bm25s.hf import BM25HF, TokenizerHF
from fugashi import GenericTagger
from transformers import AutoModelForPreTraining, AutoTokenizer

ALIAS_SEP = "|"

repo_id = "studio-ousia/luxe"
revision = "ja-v0.3.1"
nayose_repo_id = "studio-ousia/luxe-nayose-bm25"

# Category entities matching any of these patterns (year categories, list/index
# categories, etc.) are excluded from the topic category predictions.
ignore_category_patterns = [
    r"\d+年",
    r"楽曲 [ぁ-ん]",
    r"漫画作品 [ぁ-ん]",
    r"アニメ作品 [ぁ-ん]",
    r"の一覧",
    r"各国の",
    r"各年の",
]

model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)

num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
num_category_entities = model.config.num_category_entities

id2normal_entity = {
    entity_id: entity
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id < num_normal_entities
}
id2category_entity = {
    entity_id - num_normal_entities: entity
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id >= num_normal_entities
}

ignore_category_entity_ids = [
    entity_id - num_normal_entities
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id >= num_normal_entities
    and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
]

entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
normal_entity_embeddings = entity_embeddings[:num_normal_entities]
category_entity_embeddings = entity_embeddings[num_normal_entities:]


class MecabTokenizer:
    def __init__(self):
        unidic_dir = unidic_lite.DICDIR
        mecabrc_file = Path(unidic_dir, "mecabrc")
        mecab_option = f"-d {unidic_dir} -r {mecabrc_file}"
        self.tagger = GenericTagger(mecab_option)

    def __call__(self, text: str) -> list[tuple[str, str, tuple[int, int]]]:
        outputs = []
        end = 0
        for node in self.tagger(text):
            word = node.surface.strip()
            pos = node.feature[0]
            start = text.index(word, end)
            end = start + len(word)
            outputs.append((word, pos, (start, end)))

        return outputs


mecab_tokenizer = MecabTokenizer()


def normalize_text(text: str) -> str:
    return unicodedata.normalize("NFKC", text)


bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
bm25_tokenizer.load_vocab_from_hub(nayose_repo_id)
bm25_retriever = BM25HF.load_from_hub(nayose_repo_id)


def get_texts_from_file(file_path):
    texts = []
    with open(file_path) as f:
        for line in f:
            line = line.strip()
            if line:
                texts.append(normalize_text(line))

    return texts


def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
    last_pos = None
    noun_spans = []
    for word, pos, (start, end) in mecab_tokenizer(text):
        if pos == "名詞":
            if len(noun_spans) > 0 and last_pos == "名詞":
                # Merge runs of consecutive noun tokens into a single span
                noun_spans[-1] = (noun_spans[-1][0], end)
            else:
                noun_spans.append((start, end))

        last_pos = pos

    return noun_spans


def get_token_spans(text: str) -> list[tuple[int, int]]:
    token_spans = []
    end = 0
    for token in tokenizer.tokenize(text):
        token = token.removeprefix("##")
        start = text.index(token, end)
        end = start + len(token)
        token_spans.append((start, end))

    return [(0, 0)] + token_spans + [(end, end)]  # account for "[CLS]" and "[SEP]"
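# Quick sanity check for the span helpers above (illustrative only; the exact
# character offsets depend on the UniDic segmentation and the LUXE tokenizer,
# so no concrete output is claimed here):
#
#     sample = normalize_text("東京タワーは日本の電波塔です。")
#     for start, end in get_noun_spans_from_text(sample):
#         print(sample[start:end])      # consecutive noun tokens come out merged
#     print(get_token_spans(sample))    # per-subword character spans, padded with
#                                       # (0, 0) / (end, end) for [CLS] / [SEP]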
def get_predicted_entity_spans(
    ner_logits: torch.Tensor, token_spans: list[tuple[int, int]], entity_span_sensitivity: float = 1.0
) -> list[tuple[int, int]]:
    length = ner_logits.size(-1)
    assert ner_logits.size() == (length, length)  # not batched

    ner_probs = torch.sigmoid(ner_logits).triu()
    probs_sorted, sort_idxs = ner_probs.flatten().sort(descending=True)

    predicted_entity_spans = []
    for p, i in zip(probs_sorted, sort_idxs.tolist()):
        # Keep spans whose probability exceeds 10^(-sensitivity); a higher
        # sensitivity therefore lowers the threshold and detects more spans.
        if p < 10.0 ** (-1.0 * entity_span_sensitivity):
            break

        start_idx = i // length
        end_idx = i % length

        start = token_spans[start_idx][0]
        end = token_spans[end_idx][1]

        # Skip spans that overlap with an already accepted span
        for ex_start, ex_end in predicted_entity_spans:
            if not (start < end <= ex_start or ex_end <= start < end):
                break
        else:
            predicted_entity_spans.append((start, end))

    return sorted(predicted_entity_spans)


def get_topk_entities_from_texts(
    texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0, nayose_coef: float = 1.0
) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
    batch_entity_spans: list[list[tuple[int, int]]] = []
    topk_normal_entities: list[list[str]] = []
    topk_category_entities: list[list[str]] = []
    topk_span_entities: list[list[list[str]]] = []

    for text in texts:
        # First pass: detect entity spans from the NER logits
        tokenized_examples = tokenizer(text, return_tensors="pt")
        model_outputs = model(**tokenized_examples)
        token_spans = get_token_spans(text)
        entity_spans = get_predicted_entity_spans(model_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
        batch_entity_spans.append(entity_spans)

        # Second pass: feed the detected spans back in to obtain entity logits
        tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, return_tensors="pt")
        model_outputs = model(**tokenized_examples)
        model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")

        _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
        topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])

        _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
        topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])

        if model_outputs.entity_logits is not None:
            # Restrict candidates to the normal-entity part of the vocabulary
            span_entity_logits = model_outputs.entity_logits[0, :, :500000]

            if nayose_coef > 0.0:
                # Mix in BM25 scores over entity names (nayose = name consolidation)
                nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
                nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
                nayose_scores = torch.vstack(
                    [torch.from_numpy(bm25_retriever.get_scores(tokens)) for tokens in nayose_query_tokens]
                )
                span_entity_logits += nayose_coef * nayose_scores

            _, topk_span_entity_ids = span_entity_logits.topk(k)
            topk_span_entities.append(
                [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
            )
        else:
            topk_span_entities.append([])

    return batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities


def get_selected_entity(evt: gr.SelectData):
    return evt.value[0]


def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
    query_entity_id = tokenizer.entity_vocab[query_entity]

    if query_entity_id < num_normal_entities:
        topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
        # Drop the top hit, which is the query entity itself
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
        topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
    else:
        query_entity_id -= num_normal_entities
        topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
        topk_entity_scores[ignore_category_entity_ids] = float("-inf")
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
        topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]

    return topk_entities
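# The Gradio UI below only wires these functions together; they can also be
# called directly. A minimal sketch (illustrative; the returned entity names
# are titles from the LUXE entity vocabulary and are not reproduced here):
#
#     spans, topics, categories, span_entities = get_topk_entities_from_texts(
#         [normalize_text("夏目漱石は『吾輩は猫である』を執筆した。")],
#         k=5,
#         entity_span_sensitivity=1.0,
#         nayose_coef=1.0,
#     )
#     print(spans[0])          # detected (start, end) character spans
#     print(topics[0])         # top-k topic entities for the whole text
#     print(categories[0])     # top-k topic categories (ignored patterns excluded)
#     print(span_entities[0])  # top-k entity candidates per detected span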
with gr.Blocks() as demo:
    gr.Markdown("# 📝 LUXE Demo")
    gr.Markdown("## 入力テキスト")

    # Shared application state
    texts = gr.State([])
    topk = gr.State(5)
    entity_span_sensitivity = gr.State(1.0)
    nayose_coef = gr.State(1.0)
    batch_entity_spans = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
    topk_span_entities = gr.State([])
    selected_entity = gr.State()
    similar_entities = gr.State([])

    with gr.Tab(label="直接入力"):
        text_input = gr.Textbox(label="入力テキスト")

    with gr.Tab(label="ファイルアップロード"):
        texts_file = gr.File(label="入力テキストファイル")

    with gr.Accordion(label="ハイパーパラメータ", open=False):
        topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
        entity_span_sensitivity_input = gr.Slider(
            minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
        )
        nayose_coef_input = gr.Slider(
            minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
        )

    # Propagate UI inputs into the shared state
    text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=texts)
    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
    topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
    entity_span_sensitivity_input.change(
        fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
    )
    nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)

    # Recompute the predictions whenever any of the state values changes
    texts.change(
        fn=get_topk_entities_from_texts,
        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
    )
    topk.change(
        fn=get_topk_entities_from_texts,
        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
    )
    entity_span_sensitivity.change(
        fn=get_topk_entities_from_texts,
        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
    )
    nayose_coef.change(
        fn=get_topk_entities_from_texts,
        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
    )

    gr.Markdown("---")
    gr.Markdown("## 出力エンティティ")

    @gr.render(inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities])
    def render_topk_entities(
        texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
    ):
        for text, entity_spans, normal_entities, category_entities, span_entities in zip(
            texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
        ):
            # Highlight the detected entity spans within the input text
            highlighted_text_value = []
            cur = 0
            for start, end in entity_spans:
                if cur < start:
                    highlighted_text_value.append((text[cur:start], None))

                highlighted_text_value.append((text[start:end], "Entity"))
                cur = end

            if cur < len(text):
                highlighted_text_value.append((text[cur:], None))

            gr.HighlightedText(
                value=highlighted_text_value, color_map={"Entity": "green"}, combine_adjacent=False, label="Text"
            )
            # gr.Textbox(text, label="Text")

            gr.Dataset(
                label="Topic Entities", components=["text"], samples=[[entity] for entity in normal_entities]
            ).select(fn=get_selected_entity, outputs=selected_entity)
            gr.Dataset(
                label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
            ).select(fn=get_selected_entity, outputs=selected_entity)

            span_texts = [text[start:end] for start, end in entity_spans]
            for span_text, entities in zip(span_texts, span_entities):
                gr.Dataset(
                    label=f"Span Entities for {span_text}",
                    components=["text"],
                    samples=[[entity] for entity in entities],
                ).select(fn=get_selected_entity, outputs=selected_entity)
    # The similar-entities panel below is currently disabled; uncommenting it
    # shows the nearest entities (by embedding dot product) for a selected entity.
    # gr.Markdown("---")
    # gr.Markdown("## 選択されたエンティティの類似エンティティ")

    # selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)

    # @gr.render(inputs=[selected_entity, similar_entities])
    # def render_similar_entities(selected_entity, similar_entities):
    #     gr.Textbox(selected_entity, label="Selected Entity")
    #     gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])


demo.launch()
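# Running the demo (a sketch; the file name is an assumption, not part of the
# original script): save this module as app.py and start it with
#
#     python app.py
#
# Gradio serves the interface on http://127.0.0.1:7860 by default; options such
# as demo.launch(server_name="0.0.0.0", share=True) expose it on the network or
# create a temporary public link.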