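# Gradio demo for the studio-ousia/luxe model (Japanese, revision ja-v0.2).
# For each input text it predicts topic entities, topic categories, and
# entities for each noun span, and lets the user browse entities similar to a
# selected one via the entity embeddings.
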
import re
from pathlib import Path

import gradio as gr
import unidic_lite
from fugashi import GenericTagger
from transformers import AutoModelForPreTraining, AutoTokenizer

repo_id = "studio-ousia/luxe"
revision = "ja-v0.2"

ignore_category_patterns = [
    r"\d+年",
    r"楽曲 [ぁ-ん]",
    r"漫画作品 [ぁ-ん]",
    r"アニメ作品 [ぁ-ん]",
    r"の一覧",
    r"各国の",
    r"各年の",
]
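# The patterns above match Wikipedia category names (e.g. "<year>年", per-kana
# song/manga/anime indexes, "～の一覧" list pages) that are filtered out of the
# predicted categories and the similar-entity search below.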

model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)

num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
num_category_entities = model.config.num_category_entities

id2normal_entity = {
    entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
}
id2category_entity = {
    entity_id - num_normal_entities: entity
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id >= num_normal_entities
}

ignore_category_entity_ids = [
    entity_id - num_normal_entities
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id >= num_normal_entities and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
]
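# id2normal_entity and id2category_entity map logit positions back to entity
# names; category entity IDs (and ignore_category_entity_ids) are shifted down
# by num_normal_entities so they line up with the separately indexed category
# logits used below.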

entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
normal_entity_embeddings = entity_embeddings[:num_normal_entities]
category_entity_embeddings = entity_embeddings[num_normal_entities:]
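# The shared entity embedding table stores normal entities first and category
# entities last; the two slices are used for nearest-neighbour lookup in
# get_similar_entities() below.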


class MecabTokenizer:
    def __init__(self):
        unidic_dir = unidic_lite.DICDIR
        mecabrc_file = Path(unidic_dir, "mecabrc")
        mecab_option = f"-d {unidic_dir} -r {mecabrc_file}"
        self.tagger = GenericTagger(mecab_option)

    def __call__(self, text: str) -> list[tuple[str, str, tuple[int, int]]]:
        outputs = []
        end = 0
        for node in self.tagger(text):
            word = node.surface.strip()
            pos = node.feature[0]
            start = text.index(word, end)
            end = start + len(word)
            outputs.append((word, pos, (start, end)))
        return outputs


mecab_tokenizer = MecabTokenizer()
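# Illustrative usage (exact segmentation depends on the UniDic dictionary):
# mecab_tokenizer("東京タワーに行った")
# -> [("東京", "名詞", (0, 2)), ("タワー", "名詞", (2, 5)), ...]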


def get_texts_from_file(file_path):
    texts = []
    with open(file_path) as f:
        for line in f:
            line = line.strip()
            if line:
                texts.append(line)
    return texts
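# gr.File (with its default "filepath" type) is expected to pass a local file
# path here; each non-empty line of the uploaded file becomes one input text.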


def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
    last_pos = None
    noun_spans = []
    for word, pos, (start, end) in mecab_tokenizer(text):
        if pos == "名詞":
            if len(noun_spans) > 0 and last_pos == "名詞":
                noun_spans[-1] = (noun_spans[-1][0], end)
            else:
                noun_spans.append((start, end))
        last_pos = pos
    return noun_spans
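# Consecutive noun tokens are merged into a single span. Illustrative example
# (segmentation depends on the dictionary): "自然言語処理の研究" yields the
# nouns 自然 / 言語 / 処理 / 研究 and the merged spans for 自然言語処理 and 研究.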


def get_topk_entities_from_texts(
    texts: list[str], k: int = 5
) -> tuple[list[list[str]], list[list[str]], list[list[list[str]]]]:
    topk_normal_entities: list[list[str]] = []
    topk_category_entities: list[list[str]] = []
    topk_span_entities: list[list[list[str]]] = []

    for text in texts:
        noun_spans = get_noun_spans_from_text(text)
        tokenized_examples = tokenizer(text, entity_spans=noun_spans, return_tensors="pt")
        model_outputs = model(**tokenized_examples)

        # Mask out the ignored Wikipedia categories before taking the top-k.
        model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")

        _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
        topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])

        _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
        topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])

        # The per-span entity logits are sliced to the first 500,000 entries,
        # presumably restricting the candidates to normal (non-category) entities.
        _, topk_span_entity_ids = model_outputs.entity_logits[0, :, :500000].topk(k)
        topk_span_entities.append([[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()])

    return topk_normal_entities, topk_category_entities, topk_span_entities
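# Illustrative standalone call (the input text is hypothetical):
# normal, categories, spans = get_topk_entities_from_texts(["夏目漱石は『吾輩は猫である』を書いた。"], k=5)
# Per text this returns: top-k topic entities, top-k topic categories, and
# top-k entities for each detected noun span.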


def get_selected_entity(evt: gr.SelectData):
    return evt.value[0]


def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
    query_entity_id = tokenizer.entity_vocab[query_entity]

    if query_entity_id < num_normal_entities:
        topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
        topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
    else:
        query_entity_id -= num_normal_entities
        topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
        topk_entity_scores[ignore_category_entity_ids] = float("-inf")
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
        topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]

    return topk_entities
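# Similarity here is the raw dot product between entity embeddings; the top
# (k + 1) hits are taken and the first one is dropped because it is the query
# entity itself.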


with gr.Blocks() as demo:
    gr.Markdown("## テキスト(直接入力またはファイルアップロード)")  # "Text (direct input or file upload)"

    texts = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
    topk_span_entities = gr.State([])
    selected_entity = gr.State()
    similar_entities = gr.State([])

    text_input = gr.Textbox(label="Input Text")
    texts_file = gr.File(label="Input Texts")

    text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
    texts.change(
        fn=get_topk_entities_from_texts,
        inputs=texts,
        outputs=[topk_normal_entities, topk_category_entities, topk_span_entities],
    )
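    # Editing the textbox or uploading a file updates the `texts` state, which
    # in turn triggers entity prediction for every text.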

    gr.Markdown("---")
    gr.Markdown("## 出力エンティティ")  # "Output entities"

    # Without @gr.render this function would never run; the decorator re-renders
    # the components below whenever the listed state values change.
    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities, topk_span_entities])
    def render_topk_entities(texts, topk_normal_entities, topk_category_entities, topk_span_entities):
        for text, normal_entities, category_entities, span_entities in zip(
            texts, topk_normal_entities, topk_category_entities, topk_span_entities
        ):
            gr.HighlightedText(
                value=[(word, pos if pos == "名詞" else None) for word, pos, _ in mecab_tokenizer(text)],
                color_map={"名詞": "green"},
                show_legend=True,
                combine_adjacent=True,
                adjacent_separator=" ",
                label="Text",
            )
            # gr.Textbox(text, label="Text")
            gr.Dataset(
                label="Topic Entities", components=["text"], samples=[[entity] for entity in normal_entities]
            ).select(fn=get_selected_entity, outputs=selected_entity)
            gr.Dataset(
                label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
            ).select(fn=get_selected_entity, outputs=selected_entity)

            noun_spans = get_noun_spans_from_text(text)
            nouns = [text[start:end] for start, end in noun_spans]
            for noun, entities in zip(nouns, span_entities):
                gr.Dataset(
                    label=f"Span Entities for {noun}", components=["text"], samples=[[entity] for entity in entities]
                ).select(fn=get_selected_entity, outputs=selected_entity)

    gr.Markdown("---")
    gr.Markdown("## 選択されたエンティティの類似エンティティ")  # "Entities similar to the selected entity"
    selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)

    @gr.render(inputs=[selected_entity, similar_entities])
    def render_similar_entities(selected_entity, similar_entities):
        gr.Textbox(selected_entity, label="Selected Entity")
        gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])


demo.launch()