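# Gradio demo for the LUXE model (studio-ousia/luxe): predicts the top-k normal
# and category entities for each input text, and lets the user click an entity
# to browse its nearest neighbors in the entity embedding space.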
import gradio as gr
import torch
from transformers import AutoModelForPreTraining, AutoTokenizer


repo_id = "studio-ousia/luxe"
revision = "ja-v0.1"

# The LUXE checkpoint ships custom modeling code, so trust_remote_code is required.
model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)

# The entity vocabulary stores normal entities first, followed by
# num_category_entities Wikipedia category entities.
num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
num_category_entities = model.config.num_category_entities

# Inverse mappings from id to entity name, one per vocabulary partition.
id2normal_entity = {
    entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
}

id2category_entity = {
    entity_id - num_normal_entities: entity
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id >= num_normal_entities
}

# Slice the shared entity embedding matrix into the two partitions.
entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
normal_entity_embeddings = entity_embeddings[:num_normal_entities]
category_entity_embeddings = entity_embeddings[num_normal_entities:]


def get_texts_from_file(file_path):
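    """Read the uploaded file and return one text per non-empty line."""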
    texts = []
    with open(file_path) as f:
        for line in f:
            line = line.strip()
            if line:
                texts.append(line)

    return texts


def get_topk_entities_from_texts(texts: list[str], k: int = 5) -> tuple[list[list[str]], list[list[str]]]:
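    """Run the model on each text and return its top-k normal and category entities."""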
    topk_normal_entities = []
    topk_category_entities = []

    for text in texts:
        tokenized_examples = tokenizer(text, return_tensors="pt")
        with torch.no_grad():  # inference only; no gradients needed
            model_outputs = model(**tokenized_examples)

        _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
        topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])

        _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
        topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])

    return topk_normal_entities, topk_category_entities


def get_selected_entity(evt: gr.SelectData):
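    """Extract the entity name from the clicked Dataset sample (a one-element list)."""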
    return evt.value[0]


def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
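    """Return the k entities most similar to the query, scored by embedding dot product."""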
    query_entity_id = tokenizer.entity_vocab[query_entity]

    # Score the query against all embeddings in its own partition; topk(k + 1)
    # includes the query itself, which the [1:] slice drops.
    with torch.no_grad():
        if query_entity_id < num_normal_entities:
            topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
            topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
            topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
        else:
            query_entity_id -= num_normal_entities
            topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
            topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
            topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]

    return topk_entities


with gr.Blocks() as demo:
    gr.Markdown("## テキスト(直接入力またはファイルアップロード)")

    texts = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
    selected_entity = gr.State()
    similar_entities = gr.State([])

    text_input = gr.Textbox(label="Input Text")
    texts_file = gr.File(label="Input Text File")

    # Either input path updates `texts`, which in turn triggers entity prediction.
    text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
    texts.change(fn=get_topk_entities_from_texts, inputs=texts, outputs=[topk_normal_entities, topk_category_entities])

    gr.Markdown("---")
    gr.Markdown("## 出力エンティティ")

    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities])
    def render_topk_entities(texts, topk_normal_entities, topk_category_entities):
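        """Render each text with its predicted entities; clicking an entity selects it."""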
        for text, normal_entities, category_entities in zip(texts, topk_normal_entities, topk_category_entities):
            gr.Textbox(text, label="Text")
            entities = gr.Dataset(
                label="Entities",
                components=["text"],
                samples=[[entity] for entity in normal_entities + category_entities],
            )
            entities.select(fn=get_selected_entity, outputs=selected_entity)

        gr.Markdown("---")
        gr.Markdown("## 選択されたエンティティの類似エンティティ")

    selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)

    @gr.render(inputs=[selected_entity, similar_entities])
    def render_similar_entities(selected_entity, similar_entities):
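        """Render the selected entity alongside its most similar entities."""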
        gr.Textbox(selected_entity, label="Selected Entity")
        gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])


demo.launch()