# Hugging Face Space status at capture time: running on CPU Upgrade.
import re | |
import unicodedata | |
from pathlib import Path | |
import gradio as gr | |
import torch | |
import unidic_lite | |
from fugashi import GenericTagger | |
from transformers import AutoModelForPreTraining, AutoTokenizer | |
# Hugging Face Hub repository and revision that the model and tokenizer are loaded from.
repo_id = "studio-ousia/luxe"
revision = "ja-v0.3"

# Regular expressions matched (with ``re.search``) against category entity names;
# any category whose name matches one of them is excluded from the predictions.
# Each pattern is listed exactly once.
ignore_category_patterns = [
    r"\d+年",
    r"楽曲 [ぁ-ん]",
    r"漫画作品 [ぁ-ん]",
    r"アニメ作品 [ぁ-ん]",
    r"の一覧",
    r"各国の",
    r"各年の",
]
# Load the model and its tokenizer from the Hub (custom code is enabled via trust_remote_code).
model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)

# The entity vocabulary holds the normal entities first, followed by the category entities.
num_category_entities = model.config.num_category_entities
num_normal_entities = len(tokenizer.entity_vocab) - num_category_entities


def _build_entity_tables(entity_vocab, first_category_id):
    """Split the entity vocabulary into normal / category lookup tables.

    Returns a tuple ``(id2normal, id2category, ignored_category_ids)`` where
    category ids are re-based so that the first category entity has id 0, and
    ``ignored_category_ids`` lists the re-based ids of the categories whose name
    matches one of ``ignore_category_patterns``.
    """
    id2normal = {}
    id2category = {}
    ignored_category_ids = []
    for name, vocab_id in entity_vocab.items():
        if vocab_id < first_category_id:
            id2normal[vocab_id] = name
            continue
        category_id = vocab_id - first_category_id
        id2category[category_id] = name
        if any(re.search(pattern, name) for pattern in ignore_category_patterns):
            ignored_category_ids.append(category_id)
    return id2normal, id2category, ignored_category_ids


id2normal_entity, id2category_entity, ignore_category_entity_ids = _build_entity_tables(
    tokenizer.entity_vocab, num_normal_entities
)

# Entity embedding matrix, split at the same boundary as the vocabulary.
entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
normal_entity_embeddings = entity_embeddings[:num_normal_entities]
category_entity_embeddings = entity_embeddings[num_normal_entities:]
class MecabTokenizer:
    """Callable wrapper around a fugashi ``GenericTagger`` using the unidic-lite dictionary."""

    def __init__(self):
        # Point MeCab at the dictionary directory shipped with unidic_lite and its mecabrc.
        dic_dir = unidic_lite.DICDIR
        rc_path = Path(dic_dir, "mecabrc")
        self.tagger = GenericTagger(f"-d {dic_dir} -r {rc_path}")

    def __call__(self, text: str) -> list[tuple[str, str, tuple[int, int]]]:
        """Tokenize ``text`` into ``(surface, first_feature, (start, end))`` triples.

        The offsets are character positions in ``text``; each surface is located
        by searching forward from the end of the previous token.

        Raises:
            ValueError: if a (stripped) surface cannot be found in ``text`` at or
                after the end of the previous token.
        """
        tokens = []
        cursor = 0
        for node in self.tagger(text):
            surface = node.surface.strip()
            first_feature = node.feature[0]
            begin = text.index(surface, cursor)
            cursor = begin + len(surface)
            tokens.append((surface, first_feature, (begin, cursor)))
        return tokens
# Module-level tokenizer instance shared by the span helpers; built once at import time.
mecab_tokenizer = MecabTokenizer()
def normalize_text(text: str) -> str:
    """Return ``text`` in Unicode NFKC normal form."""
    normalized = unicodedata.normalize("NFKC", text)
    return normalized
def get_texts_from_file(file_path):
    """Read a text file and return its non-empty lines, normalized.

    Each line is stripped of surrounding whitespace; lines that are empty after
    stripping are dropped, and the remaining ones are passed through
    ``normalize_text``.

    Args:
        file_path: Path of the file to read. A falsy value (``None`` or ``""``,
            e.g. when no file is selected) yields an empty list instead of an error.

    Returns:
        The list of normalized lines, in file order.
    """
    if not file_path:
        return []

    # Decode explicitly as UTF-8 instead of relying on the platform default;
    # "utf-8-sig" also drops a leading byte-order mark if the file has one.
    with open(file_path, encoding="utf-8-sig") as f:
        stripped_lines = (line.strip() for line in f)
        return [normalize_text(line) for line in stripped_lines if line]
def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
    """Return character spans of maximal runs of consecutive noun tokens in ``text``.

    Tokens come from ``mecab_tokenizer``; a token counts as a noun when its
    part-of-speech field equals "名詞". Directly consecutive noun tokens are merged
    into one span running from the first token's start to the last token's end.
    """
    spans = []
    previous_was_noun = False
    for _, pos, (start, end) in mecab_tokenizer(text):
        is_noun = pos == "名詞"
        if is_noun and previous_was_noun:
            # Extend the span opened by the preceding noun token.
            run_start = spans[-1][0]
            spans[-1] = (run_start, end)
        elif is_noun:
            spans.append((start, end))
        previous_was_noun = is_noun
    return spans
def get_token_spans(text: str) -> list[tuple[int, int]]:
    """Return the character span in ``text`` of every token, plus two boundary spans.

    Tokens are produced by ``tokenizer.tokenize``; a leading "##" is removed from a
    token before it is searched for in ``text``, starting at the end of the previous
    token. The result begins with ``(0, 0)`` and ends with a zero-width span at the
    end of the last token, standing in for "[CLS]" and "[SEP]".

    Raises:
        ValueError: if a token's text is not found in ``text`` at or after the end
            of the previous token.
    """
    spans = [(0, 0)]  # placeholder for "[CLS]"
    cursor = 0
    for piece in tokenizer.tokenize(text):
        surface = piece.removeprefix("##")
        begin = text.index(surface, cursor)
        cursor = begin + len(surface)
        spans.append((begin, cursor))
    spans.append((cursor, cursor))  # placeholder for "[SEP]"
    return spans
def get_predicted_entity_spans(
    ner_logits: torch.Tensor, token_spans: list[tuple[int, int]], entity_span_sensitivity: float = 1.0
) -> list[tuple[int, int]]:
    """Greedily select non-overlapping entity spans from a span-scoring matrix.

    Entry ``[i, j]`` of ``ner_logits`` scores the span running from token ``i`` to
    token ``j``; only the upper triangle (``i <= j``) is considered. Candidates are
    visited in order of decreasing sigmoid probability and kept when they are
    non-empty and do not overlap an already selected span.

    Args:
        ner_logits: Square, un-batched ``(length, length)`` tensor of span logits.
        token_spans: Character ``(start, end)`` span of each token index.
        entity_span_sensitivity: Candidates with probability below
            ``10 ** -entity_span_sensitivity`` are discarded, so larger values
            admit more spans.

    Returns:
        The selected character spans, sorted in ascending order.

    Raises:
        ValueError: if ``ner_logits`` is not a square 2-D tensor.
    """
    if ner_logits.dim() != 2 or ner_logits.size(0) != ner_logits.size(1):
        raise ValueError(f"ner_logits must be a square 2-D tensor, got shape {tuple(ner_logits.size())}")
    length = ner_logits.size(-1)

    threshold = 10.0 ** (-1.0 * entity_span_sensitivity)
    # triu() zeroes the lower triangle (spans whose end token precedes the start token).
    ner_probs = torch.sigmoid(ner_logits).triu()
    probs_sorted, sort_idxs = ner_probs.flatten().sort(descending=True)

    predicted_entity_spans: list[tuple[int, int]] = []
    for p, i in zip(probs_sorted, sort_idxs.tolist()):
        if p < threshold:
            break  # probabilities are sorted, so every remaining candidate is below too

        # Recover (row, column) of the flattened index.
        start_idx, end_idx = divmod(i, length)
        start = token_spans[start_idx][0]
        end = token_spans[end_idx][1]

        # Zero-width spans (e.g. a boundary token paired with itself) are never entities.
        if start >= end:
            continue

        overlaps_selected = any(start < ex_end and ex_start < end for ex_start, ex_end in predicted_entity_spans)
        if not overlaps_selected:
            predicted_entity_spans.append((start, end))

    return sorted(predicted_entity_spans)
def get_topk_entities_from_texts(
    texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0
) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
    """Predict entity spans and top-k entities for each text.

    For every text the model is run twice: once without entity spans to obtain
    ``ner_logits`` (from which the entity spans are selected), and once with those
    spans to obtain the topic-entity, topic-category and per-span entity logits.

    Args:
        texts: Input texts.
        k: Number of entities to return per prediction; coerced to ``int``.
        entity_span_sensitivity: Forwarded to ``get_predicted_entity_spans``.

    Returns:
        A 4-tuple of per-text lists:
        the predicted character spans, the top-k topic entities, the top-k topic
        categories (ignored categories excluded), and for each predicted span its
        top-k normal entities (an empty list when the model returns no
        ``entity_logits``).
    """
    k = int(k)  # the value may arrive as a float from a numeric UI widget

    batch_entity_spans: list[list[tuple[int, int]]] = []
    topk_normal_entities: list[list[str]] = []
    topk_category_entities: list[list[str]] = []
    topk_span_entities: list[list[list[str]]] = []

    # Pure inference: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for text in texts:
            # Pass 1: span detection.
            ner_outputs = model(**tokenizer(text, return_tensors="pt"))
            token_spans = get_token_spans(text)
            entity_spans = get_predicted_entity_spans(ner_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
            batch_entity_spans.append(entity_spans)

            # Pass 2: entity prediction, conditioned on the detected spans (None when there are none).
            tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, return_tensors="pt")
            model_outputs = model(**tokenized_examples)

            # Ignored categories can never be selected by topk.
            model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")

            topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k).indices
            topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])

            topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k).indices
            topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])

            if model_outputs.entity_logits is None:
                topk_span_entities.append([])
                continue

            # Restrict per-span predictions to the normal entities, i.e. exactly the ids
            # that id2normal_entity can resolve.
            span_logits = model_outputs.entity_logits[0, :, :num_normal_entities]
            topk_span_entity_ids = span_logits.topk(k).indices
            topk_span_entities.append(
                [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
            )

    return batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
def get_selected_entity(evt: gr.SelectData):
    """Return the first element of the value carried by a select event."""
    selected_value = evt.value
    return selected_value[0]
def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
    """Return the ``k`` entities whose embeddings score highest against ``query_entity``.

    Scores are dot products between the query's embedding and all embeddings of the
    same kind: normal entities are compared with normal entities, category entities
    with category entities. The query itself is always excluded, and for category
    queries the ignored categories are excluded as well.

    Args:
        query_entity: Entity name; must be a key of ``tokenizer.entity_vocab``.
        k: Number of similar entities to return.

    Returns:
        Entity names ordered by decreasing score.

    Raises:
        KeyError: if ``query_entity`` is not in the entity vocabulary.
    """
    query_entity_id = tokenizer.entity_vocab[query_entity]
    is_category = query_entity_id >= num_normal_entities
    if is_category:
        query_entity_id -= num_normal_entities  # re-base to the category table
        embeddings, id2entity = category_entity_embeddings, id2category_entity
    else:
        embeddings, id2entity = normal_entity_embeddings, id2normal_entity

    with torch.no_grad():
        scores = embeddings[query_entity_id] @ embeddings.T
        if is_category:
            scores[ignore_category_entity_ids] = float("-inf")
        # Mask the query explicitly rather than assuming it is ranked first: a dot
        # product does not guarantee that, and a masked (ignored) query never is.
        scores[query_entity_id] = float("-inf")
        topk_entity_ids = scores.topk(int(k)).indices

    return [id2entity[entity_id] for entity_id in topk_entity_ids.tolist()]
with gr.Blocks() as demo:
    gr.Markdown("## テキスト(直接入力またはファイルアップロード)")

    # State: the current inputs, the prediction results, and the entity selection.
    texts = gr.State([])
    topk = gr.State(5)
    entity_span_sensitivity = gr.State(1.0)
    batch_entity_spans = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
    topk_span_entities = gr.State([])
    selected_entity = gr.State()
    similar_entities = gr.State([])

    # Input widgets; each one only copies its value into the matching state.
    text_input = gr.Textbox(label="Input Text")
    text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=texts)
    texts_file = gr.File(label="Input Texts")
    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
    topk_input = gr.Number(5, label="Top K", interactive=True)
    topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
    entity_span_sensitivity_input = gr.Slider(
        minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Entity Span Sensitivity", interactive=True
    )
    entity_span_sensitivity_input.change(
        fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
    )

    # Re-run the prediction whenever any of the three input states changes.
    prediction_inputs = [texts, topk, entity_span_sensitivity]
    prediction_outputs = [batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities]
    for input_state in prediction_inputs:
        input_state.change(
            fn=get_topk_entities_from_texts,
            inputs=prediction_inputs,
            outputs=prediction_outputs,
        )

    gr.Markdown("---")
    gr.Markdown("## 出力エンティティ")

    # Register the renderer so that it is re-run with the current state values;
    # without this decorator the function below is never called.
    @gr.render(inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities])
    def render_topk_entities(
        texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
    ):
        """Render, for each text, its highlighted spans and the predicted entity lists."""
        # zip() stops at the shortest list, so texts without a result yet are skipped.
        for text, entity_spans, normal_entities, category_entities, span_entities in zip(
            texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
        ):
            # Cut the text into alternating plain / "Entity" segments.
            highlighted_text_value = []
            cur = 0
            for start, end in entity_spans:
                if cur < start:
                    highlighted_text_value.append((text[cur:start], None))
                highlighted_text_value.append((text[start:end], "Entity"))
                cur = end
            if cur < len(text):
                highlighted_text_value.append((text[cur:], None))

            gr.HighlightedText(
                value=highlighted_text_value, color_map={"Entity": "green"}, combine_adjacent=False, label="Text"
            )

            # Selecting a sample in any of the datasets stores it in `selected_entity`.
            gr.Dataset(
                label="Topic Entities", components=["text"], samples=[[entity] for entity in normal_entities]
            ).select(fn=get_selected_entity, outputs=selected_entity)
            gr.Dataset(
                label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
            ).select(fn=get_selected_entity, outputs=selected_entity)

            span_texts = [text[start:end] for start, end in entity_spans]
            for span_text, entities in zip(span_texts, span_entities):
                gr.Dataset(
                    label=f"Span Entities for {span_text}",
                    components=["text"],
                    samples=[[entity] for entity in entities],
                ).select(fn=get_selected_entity, outputs=selected_entity)

    # NOTE: the similar-entity section below is disabled and kept only as commented-out code.
    # gr.Markdown("---")
    # gr.Markdown("## 選択されたエンティティの類似エンティティ")
    # selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)
    # @gr.render(inputs=[selected_entity, similar_entities])
    # def render_similar_entities(selected_entity, similar_entities):
    #     gr.Textbox(selected_entity, label="Selected Entity")
    #     gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])

demo.launch()