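"""Gradio demo for studio-ousia/luxe (revision ja-v0.1).

The app predicts the model's top-k topic entities and category entities for
each input text, and lets the user browse entities similar to a selected one.
"""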
import gradio as gr
import torch
from transformers import AutoModelForPreTraining, AutoTokenizer

repo_id = "studio-ousia/luxe"
revision = "ja-v0.1"

# trust_remote_code is required because LUXE ships its model and tokenizer
# code in the Hub repository rather than in the transformers library itself.
model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
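
# The entity vocabulary stores normal entities first, followed by the
# num_category_entities category entities at the end of the ID range.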
num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
num_category_entities = model.config.num_category_entities

id2normal_entity = {
    entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
}
id2category_entity = {
    entity_id - num_normal_entities: entity
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id >= num_normal_entities
}
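
# Split the shared entity embedding table into its normal and category
# blocks, mirroring the vocabulary layout above.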
entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
normal_entity_embeddings = entity_embeddings[:num_normal_entities]
category_entity_embeddings = entity_embeddings[num_normal_entities:]
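
# Read one input text per non-empty line of the uploaded file.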
def get_texts_from_file(file_path):
    texts = []
    with open(file_path) as f:
        for line in f:
            line = line.strip()
            if line:
                texts.append(line)
    return texts
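
# Run the model on each text and collect its top-k normal (topic) entities
# and top-k category entities from the corresponding logits.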
def get_topk_entities_from_texts(texts: list[str], k: int = 5) -> tuple[list[list[str]], list[list[str]]]:
    topk_normal_entities = []
    topk_category_entities = []
    for text in texts:
        tokenized_examples = tokenizer(text, return_tensors="pt")
        with torch.no_grad():  # inference only; no gradients needed
            model_outputs = model(**tokenized_examples)
        _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
        topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
        _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
        topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
    return topk_normal_entities, topk_category_entities
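
# A gr.Dataset sample is a list of component values; the entity name is the
# first (and only) element of the clicked sample.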
def get_selected_entity(evt: gr.SelectData):
    return evt.value[0]
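
# Rank entities by dot-product similarity against the block (normal or
# category) that the query entity belongs to.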
def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
    query_entity_id = tokenizer.entity_vocab[query_entity]
    with torch.no_grad():
        if query_entity_id < num_normal_entities:
            topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
            # Take k + 1 results and drop the first, which is the query entity itself.
            topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
            topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
        else:
            query_entity_id -= num_normal_entities
            topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
            topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
            topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
    return topk_entities
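
# UI: gr.State components hold intermediate results, and the @gr.render
# sections below are rebuilt whenever the state they depend on changes.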
with gr.Blocks() as demo:
    gr.Markdown("## Input texts (typed directly or uploaded as a file)")

    texts = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
    selected_entity = gr.State()
    similar_entities = gr.State([])

    text_input = gr.Textbox(label="Input Text")
    texts_file = gr.File(label="Input texts")

    text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
    texts.change(fn=get_topk_entities_from_texts, inputs=texts, outputs=[topk_normal_entities, topk_category_entities])

    gr.Markdown("---")
    gr.Markdown("## Predicted entities")
    # @gr.render re-runs this function whenever its inputs change, rebuilding
    # one text box plus an entity list per input text.
    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities])
    def render_topk_entities(texts, topk_normal_entities, topk_category_entities):
        for text, normal_entities, category_entities in zip(texts, topk_normal_entities, topk_category_entities):
            gr.Textbox(text, label="Text")
            entities = gr.Dataset(
                label="Entities",
                components=["text"],
                samples=[[entity] for entity in normal_entities + category_entities],
            )
            entities.select(fn=get_selected_entity, outputs=selected_entity)

    gr.Markdown("---")
    gr.Markdown("## Entities similar to the selected entity")

    selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)
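
    # Show the selected entity together with its nearest neighbours in the
    # entity embedding space.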
    @gr.render(inputs=[selected_entity, similar_entities])
    def render_similar_entities(selected_entity, similar_entities):
        gr.Textbox(selected_entity, label="Selected Entity")
        gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])

demo.launch()
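
# Assumed local usage (not part of the original Space): install the
# dependencies and run this file directly, e.g.
#   pip install gradio torch transformers
#   python app.py  # "app.py" is the conventional Space filename, assumed here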