# luxe-demo / app.py

import re
from pathlib import Path

import gradio as gr
import torch
import unidic_lite
from fugashi import GenericTagger
from transformers import AutoModelForPreTraining, AutoTokenizer
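
# Hugging Face Hub repository and revision of the LUXE model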
repo_id = "studio-ousia/luxe"
revision = "ja-v0.2"
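
# Category entities whose titles match any of these patterns are excluded
# from topic-category predictions and from similar-entity search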
ignore_category_patterns = [
r"\d+年",
r"楽曲 [ぁ-ん]",
r"漫画作品 [ぁ-ん]",
r"アニメ作品 [ぁ-ん]",
r"アニメ作品 [ぁ-ん]",
r"の一覧",
r"各国の",
r"各年の",
]
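
# Load the model and tokenizer (the repo ships custom code, hence trust_remote_code)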
model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
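
# The entity vocabulary lists normal entities first, followed by category
# entities; IDs below num_normal_entities denote normal entities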
num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
num_category_entities = model.config.num_category_entities
id2normal_entity = {
entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
}
id2category_entity = {
entity_id - num_normal_entities: entity
for entity, entity_id in tokenizer.entity_vocab.items()
if entity_id >= num_normal_entities
}
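
# Category-entity IDs (offset into the category range) matched by the ignore patterns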
ignore_category_entity_ids = [
entity_id - num_normal_entities
for entity, entity_id in tokenizer.entity_vocab.items()
if entity_id >= num_normal_entities and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
]
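
# Split the shared entity embedding matrix into normal and category parts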
entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
normal_entity_embeddings = entity_embeddings[:num_normal_entities]
category_entity_embeddings = entity_embeddings[num_normal_entities:]
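

# Word-level tokenizer built on MeCab (fugashi) with the UniDic Lite dictionary;
# yields (surface, part-of-speech, (start, end)) tuples with character offsets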
class MecabTokenizer:
def __init__(self):
unidic_dir = unidic_lite.DICDIR
mecabrc_file = Path(unidic_dir, "mecabrc")
mecab_option = f"-d {unidic_dir} -r {mecabrc_file}"
self.tagger = GenericTagger(mecab_option)
def __call__(self, text: str) -> list[tuple[str, str, tuple[int, int]]]:
outputs = []
end = 0
for node in self.tagger(text):
word = node.surface.strip()
pos = node.feature[0]
            # Recover character offsets by locating the word in the original text
            start = text.index(word, end)
end = start + len(word)
outputs.append((word, pos, (start, end)))
return outputs


mecab_tokenizer = MecabTokenizer()
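

# Read one input text per non-empty line from an uploaded file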
def get_texts_from_file(file_path):
texts = []
with open(file_path) as f:
for line in f:
line = line.strip()
if line:
texts.append(line)
return texts
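

# Extract noun spans from the text, merging runs of consecutive noun tokens
# into a single span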
def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
last_pos = None
noun_spans = []
for word, pos, (start, end) in mecab_tokenizer(text):
if pos == "名詞":
if len(noun_spans) > 0 and last_pos == "名詞":
noun_spans[-1] = (noun_spans[-1][0], end)
else:
noun_spans.append((start, end))
last_pos = pos
return noun_spans
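

# For each text, return the top-k topic entities, topic categories, and
# entity candidates for every noun span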
def get_topk_entities_from_texts(
texts: list[str], k: int = 5
) -> tuple[list[list[str]], list[list[str]], list[list[list[str]]]]:
topk_normal_entities: list[list[str]] = []
topk_category_entities: list[list[str]] = []
topk_span_entities: list[list[list[str]]] = []
for text in texts:
noun_spans = get_noun_spans_from_text(text)
        tokenized_examples = tokenizer(text, entity_spans=noun_spans, return_tensors="pt")
        with torch.no_grad():  # inference only; gradients are not needed
            model_outputs = model(**tokenized_examples)
        # Mask ignored categories so they never appear in the top-k results
        model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
_, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
_, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
        # Consider only normal (non-category) entities for span predictions
        _, topk_span_entity_ids = model_outputs.entity_logits[0, :, :num_normal_entities].topk(k)
topk_span_entities.append([[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()])
return topk_normal_entities, topk_category_entities, topk_span_entities
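

# A Dataset select event carries the clicked sample ([entity_name]); return the name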
def get_selected_entity(evt: gr.SelectData):
return evt.value[0]
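

# Return the k entities most similar to the query by embedding dot product,
# searching within the query's own space (normal or category entities)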
def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
query_entity_id = tokenizer.entity_vocab[query_entity]
if query_entity_id < num_normal_entities:
        topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
        # Take k+1 and drop the first hit, which is the query entity itself
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
else:
query_entity_id -= num_normal_entities
topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
        topk_entity_scores[ignore_category_entity_ids] = float("-inf")
        # Take k+1 and drop the first hit, which is the query entity itself
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
return topk_entities
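

# Gradio UI: enter or upload texts, inspect predicted entities, and click an
# entity to see its nearest neighbors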
with gr.Blocks() as demo:
gr.Markdown("## テキスト(直接入力またはファイルアップロード)")
texts = gr.State([])
topk_normal_entities = gr.State([])
topk_category_entities = gr.State([])
topk_span_entities = gr.State([])
selected_entity = gr.State()
similar_entities = gr.State([])
text_input = gr.Textbox(label="Input Text")
texts_file = gr.File(label="Input Texts")
text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
texts.change(
fn=get_topk_entities_from_texts,
inputs=texts,
outputs=[topk_normal_entities, topk_category_entities, topk_span_entities],
)
gr.Markdown("---")
gr.Markdown("## 出力エンティティ")
@gr.render(inputs=[texts, topk_normal_entities, topk_category_entities, topk_span_entities])
def render_topk_entities(texts, topk_normal_entities, topk_category_entities, topk_span_entities):
for text, normal_entities, category_entities, span_entities in zip(
texts, topk_normal_entities, topk_category_entities, topk_span_entities
):
gr.HighlightedText(
value=[(word, pos if pos == "名詞" else None) for word, pos, _ in mecab_tokenizer(text)],
color_map={"名詞": "green"},
show_legend=True,
combine_adjacent=True,
adjacent_separator=" ",
label="Text",
)
# gr.Textbox(text, label="Text")
gr.Dataset(
label="Topic Entities", components=["text"], samples=[[entity] for entity in normal_entities]
).select(fn=get_selected_entity, outputs=selected_entity)
gr.Dataset(
label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
).select(fn=get_selected_entity, outputs=selected_entity)
noun_spans = get_noun_spans_from_text(text)
nouns = [text[start:end] for start, end in noun_spans]
for noun, entities in zip(nouns, span_entities):
gr.Dataset(
label=f"Span Entities for {noun}", components=["text"], samples=[[entity] for entity in entities]
).select(fn=get_selected_entity, outputs=selected_entity)
gr.Markdown("---")
gr.Markdown("## 選択されたエンティティの類似エンティティ")
selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)
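
    # Show the selected entity and its nearest neighbors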
@gr.render(inputs=[selected_entity, similar_entities])
def render_similar_entities(selected_entity, similar_entities):
gr.Textbox(selected_entity, label="Selected Entity")
gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])
demo.launch()