Change MAX_TEXT_FILE_LINES to 10, clean up entity names, modify some UI components
app.py CHANGED
@@ -2,22 +2,21 @@ import csv
 import re
 import unicodedata
 from collections import defaultdict
-from
+from itertools import chain

 import gradio as gr
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import unidic_lite
 from bm25s.hf import BM25HF, TokenizerHF
-from fugashi import GenericTagger
 from transformers import AutoModelForPreTraining, AutoTokenizer


 ALIAS_SEP = "|"
+CATEGORY_ENTITY_PREFIX = "Category:"
 ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
 MAX_TEXT_LENGTH = 800
-MAX_TEXT_FILE_LINES =
+MAX_TEXT_FILE_LINES = 10
 MAX_ENTITY_FILE_LINES = 1000

 repo_id = "studio-ousia/luxe"
@@ -37,32 +36,21 @@ ignore_category_patterns = [
 ]


-
-
-
-
-
-
-
-        outputs = []
-
-        for node in self.tagger(text):
-            word = node.surface.strip()
-            pos = node.feature[0]
-            start = text.index(word, end)
-            end = start + len(word)
-            outputs.append((word, pos, (start, end)))
-
-        return outputs
-
-
-mecab_tokenizer = MecabTokenizer()
+def clean_default_entity_vocab(tokenizer):
+    entity_vocab = {}
+    for entity, entity_id in tokenizer.entity_vocab.items():
+        if entity.startswith("ja:"):
+            entity = entity.removeprefix("ja:")
+        elif entity.startswith("Category:ja:"):
+            entity = "Category:" + entity.removeprefix("Category:ja:")
+
+        entity_vocab[entity] = entity_id
+
+    tokenizer.entity_vocab = entity_vocab


 def normalize_text(text: str) -> str:
-    return unicodedata.normalize("NFKC", text)
+    return unicodedata.normalize("NFKC", text).strip()


 def get_texts_from_file(file_path: str | None):
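The new `clean_default_entity_vocab` ("clean up entity names" in the commit title) strips the language prefix from the default vocabulary keys, mapping `ja:東京` to `東京` and `Category:ja:...` to `Category:...` while keeping each entity's id. A standalone sketch of the same mapping, using a made-up toy vocabulary (not the actual LUXE vocabulary):

```python
# Toy illustration of the prefix cleanup in the diff above.
def clean_entity_name(entity: str) -> str:
    # str.removeprefix requires Python 3.9+
    if entity.startswith("ja:"):
        return entity.removeprefix("ja:")
    if entity.startswith("Category:ja:"):
        return "Category:" + entity.removeprefix("Category:ja:")
    return entity

vocab = {"[PAD]": 0, "ja:東京": 1, "Category:ja:日本の首都": 2}  # hypothetical entries
print({clean_entity_name(e): i for e, i in vocab.items()})
# {'[PAD]': 0, '東京': 1, 'Category:日本の首都': 2}
```

Note that if two keys collide after stripping, the later entry overwrites the earlier id, which is also the behavior of the dict assignment in the committed function.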
@@ -73,36 +61,20 @@ def get_texts_from_file(file_path: str | None):
             reader = csv.DictReader(f, fieldnames=["text"])
             for i, row in enumerate(reader):
                 if i >= MAX_TEXT_FILE_LINES:
-                    gr.Info(f"{MAX_TEXT_FILE_LINES}行目までのデータを読み込みました。")
+                    gr.Info(f"{MAX_TEXT_FILE_LINES}行目までのデータを読み込みました。", duration=5)
                     break

-                text =
-                if text != "":
+                text = row["text"]
+                if text.strip() != "":
                     texts.append(text[:MAX_TEXT_LENGTH])
     except Exception as e:
-        gr.Warning("ファイルを正しく読み込めませんでした。")
+        gr.Warning("ファイルを正しく読み込めませんでした。", duration=5)
         print(e)
         texts = []

     return texts


-def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
-    last_pos = None
-    noun_spans = []
-
-    for word, pos, (start, end) in mecab_tokenizer(text):
-        if pos == "名詞":
-            if len(noun_spans) > 0 and last_pos == "名詞":
-                noun_spans[-1] = (noun_spans[-1][0], end)
-            else:
-                noun_spans.append((start, end))
-
-        last_pos = pos
-
-    return noun_spans
-
-
 def get_token_spans(tokenizer, text: str) -> list[tuple[int, int]]:
     token_spans = []
     end = 0
@@ -147,12 +119,17 @@ def get_predicted_entity_spans(

 def get_topk_entities_from_texts(
     models,
-    texts: list[str],
+    texts: str | list[str],
     k: int = 5,
     entity_span_sensitivity: float = 1.0,
     nayose_coef: float = 1.0,
     entity_replaced_counts: bool = False,
 ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
+    gr.Info("LUXEによる予測を実行しています。", duration=5)
+
+    if isinstance(texts, str):
+        texts = [texts]
+
     model, tokenizer, bm25_tokenizer, bm25_retriever = models

     batch_entity_spans: list[list[tuple[int, int]]] = []
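`get_topk_entities_from_texts` is now wired directly to both the Textbox (a `str`) and the file-upload State (a `list[str]`), so it coerces its input to a list first. The same idiom in isolation:

```python
# Sketch of the str | list[str] coercion used above; the helper name is ours.
def as_text_list(texts: str | list[str]) -> list[str]:
    if isinstance(texts, str):
        texts = [texts]
    return texts

assert as_text_list("東京タワーは港区にある。") == ["東京タワーは港区にある。"]
assert as_text_list(["a", "b"]) == ["a", "b"]
```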
@@ -177,7 +154,12 @@ def get_topk_entities_from_texts(
         and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
     ]

+    entity_k = min(k, len(id2normal_entity))
+    category_k = min(k, len(id2category_entity))
+
     for text in texts:
+        text = normalize_text(text).strip()
+
         tokenized_examples = tokenizer(text, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
         token_spans = get_token_spans(tokenizer, text)
@@ -188,14 +170,14 @@ def get_topk_entities_from_texts(
         model_outputs = model(**tokenized_examples)

         if model_outputs.topic_entity_logits is not None:
-            _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
+            _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(entity_k)
             topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
         else:
             topk_normal_entities.append([])

         if model_outputs.topic_category_logits is not None:
             model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
-            _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
+            _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(category_k)
             topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
         else:
             topk_category_entities.append([])
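`entity_k` and `category_k` clamp the requested k to the vocabulary size, presumably because `torch.Tensor.topk` raises a `RuntimeError` when k exceeds the size of the selected dimension, which can now happen once the vocabulary is replaced with a small custom set. In isolation:

```python
import torch

logits = torch.randn(3)                 # e.g. only 3 entities remain after replacement
k = 5
entity_k = min(k, logits.size(-1))      # clamp as in the diff above
values, ids = logits.topk(entity_k)     # logits.topk(5) would raise a RuntimeError here
print(ids.tolist())
```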
@@ -211,7 +193,7 @@ def get_topk_entities_from_texts(
             )
             span_entity_logits += nayose_coef * nayose_scores

-            _, topk_span_entity_ids = span_entity_logits.topk(k)
+            _, topk_span_entity_ids = span_entity_logits.topk(entity_k)
             topk_span_entities.append(
                 [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
             )
@@ -221,51 +203,6 @@ def get_topk_entities_from_texts(
     return texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities


-def get_selected_entity(evt: gr.SelectData):
-    return evt.value[0]
-
-
-def get_similar_entities(models, query_entity: str, k: int = 10) -> list[str]:
-    model, tokenizer, _, _ = models
-
-    query_entity_id = tokenizer.entity_vocab[query_entity]
-
-    id2normal_entity = {
-        entity_id: entity
-        for entity, entity_id in tokenizer.entity_vocab.items()
-        if entity_id < model.config.num_normal_entities
-    }
-    id2category_entity = {
-        entity_id - model.config.num_normal_entities: entity
-        for entity, entity_id in tokenizer.entity_vocab.items()
-        if entity_id >= model.config.num_normal_entities
-    }
-    ignore_category_entity_ids = [
-        entity_id - model.config.num_normal_entities
-        for entity, entity_id in tokenizer.entity_vocab.items()
-        if entity_id >= model.config.num_normal_entities
-        and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
-    ]
-    entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
-    normal_entity_embeddings = entity_embeddings[: model.config.num_normal_entities]
-    category_entity_embeddings = entity_embeddings[model.config.num_normal_entities :]
-
-    if query_entity_id < model.config.num_normal_entities:
-        topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
-        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
-        topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
-    else:
-        query_entity_id -= model.config.num_normal_entities
-        topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
-
-        topk_entity_scores[ignore_category_entity_ids] = float("-inf")
-
-        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
-        topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
-
-    return topk_entities
-
-
 def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
     new_entity_text_pairs = []
     if file_path is not None:
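The removed `get_similar_entities` ranked neighbors by the dot product of the query row against the whole embedding table, then dropped the top hit because it is the query itself. A self-contained sketch of that trick with random stand-in embeddings (note the `[1:]` relies on the query ranking first, which is guaranteed for normalized embeddings but only typical for unnormalized ones; the removed code made the same assumption):

```python
import torch

embeddings = torch.nn.functional.normalize(torch.randn(100, 32))  # stand-in table
query_id = 7
k = 5

scores = embeddings[query_id] @ embeddings.T     # similarity of the query to every row
topk_ids = scores.topk(k + 1).indices[1:]        # rank 0 is the query itself; skip it
print(topk_ids.tolist())
```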
@@ -274,7 +211,7 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
             reader = csv.DictReader(f, fieldnames=["entity", "text"])
             for i, row in enumerate(reader):
                 if i >= MAX_ENTITY_FILE_LINES:
-                    gr.Info(f"{MAX_ENTITY_FILE_LINES}行目までのデータを読み込みました。")
+                    gr.Info(f"{MAX_ENTITY_FILE_LINES}行目までのデータを読み込みました。", duration=5)
                     break

                 entity = normalize_text(row["entity"]).strip()
@@ -282,7 +219,7 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
                 if entity != "" and text != "":
                     new_entity_text_pairs.append([entity, text])
     except Exception as e:
-        gr.Warning("ファイルを正しく読み込めませんでした。")
+        gr.Warning("ファイルを正しく読み込めませんでした。", duration=5)
         print(e)
         new_entity_text_pairs = []

@@ -290,90 +227,109 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:


 def replace_entities(
-    models,
-
-
-
-    new_entity_counts: list[int] | None = None,
-    new_padding_idx: int = 0,
-) -> True:
-    model, tokenizer, bm25_tokenizer, bm25_retriever = models
-
-    gr.Info("トークナイザのエンティティの語彙を置き換えています...", duration=5)
-    new_entity_tokens = ENTITY_SPECIAL_TOKENS + [entity for entity, _ in new_entity_text_pairs]
-
-    new_entity_vocab = {}
-    for entity in new_entity_tokens:
-        if entity not in new_entity_vocab:
-            new_entity_vocab[entity] = len(new_entity_vocab)
-
-
-    tokenizer
-    tokenizer.entity_pad_token_id = tokenizer.entity_vocab["[PAD]"]
-    tokenizer.entity_unk_token_id = tokenizer.entity_vocab["[UNK]"]
-    tokenizer.entity_mask_token_id = tokenizer.entity_vocab["[MASK]"]
-    tokenizer.entity_mask2_token_id = tokenizer.entity_vocab["[MASK2]"]
-
-    gr.Info("モデルのエンティティの埋め込みを置き換えています...", duration=5)
-    new_entity_embeddings_dict = defaultdict(list)
-
-
-
+    models, new_entity_text_pairs: list[tuple[str, str]], entity_replaced_counts: int, preserve_default_entities: bool
+) -> int:
+    if len(new_entity_text_pairs) == 0:
+        return entity_replaced_counts
+
+    gr.Info("LUXEのモデルとトークナイザのエンティティ語彙を更新しています。完了までお待ちください。", duration=5)
+
+    model, tokenizer, bm25_tokenizer, bm25_retriever = models
+
+    normal_entity_embeddings = defaultdict(list)  # entity -> list of embeddings
+    category_entity_embeddings = defaultdict(list)  # entity -> list of embeddings
+    normal_entity_counts = {}  # entity -> count (int)
+    category_entity_counts = {}  # entity -> count (int)
+
+    for entity, entity_id in sorted(tokenizer.entity_vocab.items(), key=lambda x: x[1]):
+        if entity in ENTITY_SPECIAL_TOKENS or preserve_default_entities:
+            entity_embedding = model.luke.entity_embeddings.entity_embeddings.weight.data[entity_id]
+            if entity.startswith(CATEGORY_ENTITY_PREFIX):
+                category_entity_embeddings[entity].append(entity_embedding)
+                if model.config.entity_counts is not None:
+                    category_entity_counts[entity] = model.config.entity_counts[entity_id]
+                else:
+                    category_entity_counts[entity] = 1
+            else:
+                normal_entity_embeddings[entity].append(entity_embedding)
+                if model.config.entity_counts is not None:
+                    normal_entity_counts[entity] = model.config.entity_counts[entity_id]
+                else:
+                    normal_entity_counts[entity] = 1

     for entity, text in new_entity_text_pairs:
-        entity_id = tokenizer.entity_vocab[entity]
         tokenized_inputs = tokenizer(text[:MAX_TEXT_LENGTH], return_tensors="pt")
         model_outputs = model(**tokenized_inputs)
-
-
-
-
-        raise ValueError("All items in new_entity_counts must be greater than zero")
+        entity_embedding = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])[0]
+        if entity.startswith(CATEGORY_ENTITY_PREFIX):
+            category_entity_embeddings[entity].append(entity_embedding)
+            category_entity_counts.setdefault(entity, 1)
+        else:
+            normal_entity_embeddings[entity].append(entity_embedding)
+            normal_entity_counts.setdefault(entity, 1)
+
+    num_normal_entities = len(normal_entity_embeddings)
+    num_category_entities = len(category_entity_embeddings)
+
+    entity_embeddings = {
+        entity: sum(embeddings) / len(embeddings)
+        for entity, embeddings in chain(normal_entity_embeddings.items(), category_entity_embeddings.items())
+    }
+    entity_vocab = {entity: entity_id for entity_id, entity in enumerate(entity_embeddings.keys())}
+
+    entity_counts = [
+        category_entity_counts[entity] if entity.startswith(CATEGORY_ENTITY_PREFIX) else normal_entity_counts[entity]
+        for entity in entity_vocab.keys()
+    ]
+
+    tokenizer.entity_vocab = entity_vocab
+    tokenizer.entity_pad_token_id = entity_vocab["[PAD]"]
+    tokenizer.entity_unk_token_id = entity_vocab["[UNK]"]
+    tokenizer.entity_mask_token_id = entity_vocab["[MASK]"]
+    tokenizer.entity_mask2_token_id = entity_vocab["[MASK2]"]
+
+    entity_embeddings_tensor = torch.vstack(list(entity_embeddings.values()))

     if model.config.normalize_entity_embeddings:
-
+        entity_embeddings_tensor = F.normalize(entity_embeddings_tensor)

-
-
-
-
+    entity_vocab_size, entity_emb_size = entity_embeddings_tensor.size()
+
+    entity_embeddings_module = nn.Embedding(
+        entity_vocab_size,
+        entity_emb_size,
+        padding_idx=tokenizer.entity_pad_token_id,
         device=model.luke.entity_embeddings.entity_embeddings.weight.device,
         dtype=model.luke.entity_embeddings.entity_embeddings.weight.dtype,
     )
-
-    model.luke.entity_embeddings.entity_embeddings =
+    entity_embeddings_module.weight.data = entity_embeddings_tensor.data
+    model.luke.entity_embeddings.entity_embeddings = entity_embeddings_module

-    model.entity_predictions.decoder =
-    model.entity_predictions.bias = nn.Parameter(torch.zeros(
+    entity_decoder_module = nn.Linear(entity_emb_size, entity_vocab_size, bias=False)
+    model.entity_predictions.decoder = entity_decoder_module
+    model.entity_predictions.bias = nn.Parameter(torch.zeros(entity_vocab_size))
     model.tie_weights()

-    if
-
+    if model.config.entity_counts is not None:
+        total_normal_entity_count = sum(entity_counts[:num_normal_entities])
+        total_category_entity_count = sum(entity_counts[num_normal_entities:])
+
+        entity_counts_tensor = torch.tensor(entity_counts, dtype=model.dtype, device=model.device)
+        total_entity_counts = torch.tensor(
+            [total_normal_entity_count] * num_normal_entities + [total_category_entity_count] * num_category_entities,
+            dtype=model.dtype,
+            device=model.device,
+        )
+        entity_log_probs = torch.log(entity_counts_tensor / total_entity_counts)
+        model.entity_log_probs = entity_log_probs

-    model.config.entity_vocab_size =
-    model.config.num_normal_entities =
-    model.config.num_category_entities =
-    model.config.entity_counts
+    model.config.entity_vocab_size = entity_vocab_size
+    model.config.num_normal_entities = num_normal_entities
+    model.config.num_category_entities = num_category_entities
+    if model.config.entity_counts is not None:
+        model.config.entity_counts = entity_counts

-    gr.Info("
+    gr.Info("LUXEのモデルとトークナイザのエンティティ語彙の更新が完了しました。", duration=5)

     return entity_replaced_counts + 1
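The rewritten `replace_entities` averages one or more description embeddings per entity, stacks the averages into a fresh `nn.Embedding`, and rebuilds the bias-free decoder so that weight tying can share the new table. A minimal standalone sketch of that swap, with made-up entities and 4-dimensional vectors (the comment about tying describes what the diff relies on `model.tie_weights()` to do; it is not LUXE's actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical per-entity embedding lists; "東京" has two description embeddings.
entity_embeddings = {
    "[PAD]": [torch.zeros(4)],
    "東京": [torch.ones(4), torch.full((4,), 3.0)],
}
mean_vectors = {e: sum(vs) / len(vs) for e, vs in entity_embeddings.items()}
table = F.normalize(torch.vstack(list(mean_vectors.values())))  # one row per entity

vocab_size, emb_size = table.size()
embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
embedding.weight.data = table

decoder = nn.Linear(emb_size, vocab_size, bias=False)
decoder.weight = embedding.weight  # tie decoder to the embedding table
print(decoder(torch.ones(emb_size)).shape)  # torch.Size([2]): one logit per entity
```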
@@ -385,14 +341,15 @@ with gr.Blocks() as demo:
     bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
     bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")

+    clean_default_entity_vocab(tokenizer)
+
     # Hint: gr.State に callable を渡すと、それが state の初期値を設定するための関数とみなされて
     # __call__ が引数なしで実行されてしまうため、gr.State の引数に model や tokenizer を単体で渡すとエラーになってしまう。
     # ここでは、モデル一式のタプル(callable でない)を渡すことで、そのようなエラーを回避している。
     # cf. https://www.gradio.app/docs/gradio/state#param-state-value
     models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))

-
-    output_texts = gr.State([])
+    texts_input = gr.State([])

     entity_replaced_counts = gr.State(0)
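The retained comment (in Japanese above) warns that `gr.State` treats a callable initial value as a factory and calls it with no arguments to compute the default, so a model or tokenizer (both callable) must not be passed on its own; wrapping everything in a non-callable tuple sidesteps this. A minimal sketch, with a lambda standing in for the model:

```python
import gradio as gr

model = lambda x: x * 2  # stand-in for a callable model object

with gr.Blocks() as demo:
    # gr.State(model) would invoke model() with no arguments to build the
    # initial value and fail; the (non-callable) tuple wrapper avoids that.
    models = gr.State((model,))
```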
@@ -400,26 +357,59 @@ with gr.Blocks() as demo:
     entity_span_sensitivity = gr.State(1.0)
     nayose_coef = gr.State(1.0)

+    texts = gr.State([])
     batch_entity_spans = gr.State([])
     topk_normal_entities = gr.State([])
     topk_category_entities = gr.State([])
     topk_span_entities = gr.State([])

-
-
-
+    gr.Markdown("# 📝 LUXE Demo (β版)")
+
+    gr.Markdown(
+        """Studio Ousia で開発中の次世代知識強化言語モデル **LUXE** の動作デモです。
+入力されたテキストに対して、テキスト中に出現するエンティティ(事物)と、テキスト全体の主題となるエンティティおよびカテゴリを予測します。
+
+デフォルトのLUXEは、エンティティおよびカテゴリとして、それぞれ日本語 Wikipedia における被リンク数上位50万件および10万件の項目を使用しています。
+予測対象のエンティティを任意のものに置き換えて推論を行うことも可能です(下記「LUXE のエンティティ語彙を置き換える」を参照してください)。""",
+        line_breaks=True,
+    )

     gr.Markdown("## 入力テキスト")

     with gr.Tab(label="直接入力"):
         text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
+        text_submit_button = gr.Button(value="予測実行", variant="huggingface")
     with gr.Tab(label="ファイルアップロード"):
-        gr.Markdown(
+        gr.Markdown(
+            f"""1行1事例のテキストファイル(最大{MAX_TEXT_FILE_LINES}行)をアップロードできます。
+アップロードされたテキストのそれぞれに対して推論が実行されます。""",
+            line_breaks=True,
+        )
         texts_file = gr.File(label="入力テキストファイル")
+        texts_submit_button = gr.Button(value="予測実行", variant="huggingface")
+
+    text_input.submit(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, text_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
+        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+    text_submit_button.click(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, text_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
+        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+
+    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts_input)
+    texts_submit_button.click(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, texts_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
+        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+
+    gr.Markdown("---")

     with gr.Accordion(label="ハイパーパラメータ", open=False):
-        topk_input = gr.Number(5, label="
+        topk_input = gr.Number(5, label="予測するエンティティの件数 (Top K)", interactive=True)
         entity_span_sensitivity_input = gr.Slider(
             minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
         )
@@ -427,25 +417,23 @@ with gr.Blocks() as demo:
             minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
         )

-    text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=input_texts)
-    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=input_texts)
     topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
     entity_span_sensitivity_input.change(
         fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
     )
     nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)

-    with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
+    with gr.Accordion(label="LUXE のエンティティ語彙を置き換える", open=False):
         gr.Markdown(
-            """LUXE
-
+            """LUXE のモデルとトークナイザのエンティティ語彙を任意のエンティティ集合に置き換えます。
+エンティティとともに与えられるエンティティの説明文から、エンティティの埋め込みが計算され、LUXE の推論に利用されます。""",
             line_breaks=True,
         )
         gr.Markdown(
-            f"「エンティティ」と「エンティティの説明文」の2列からなるCSVファイル(最大{MAX_ENTITY_FILE_LINES}行)をアップロードできます。"
+            f"「エンティティ」と「エンティティの説明文」の2列からなる CSV ファイル(最大{MAX_ENTITY_FILE_LINES}行)をアップロードできます。"
         )
-        new_entity_text_pairs_file = gr.File(label="エンティティと説明文のCSVファイル", height="128px")
-        gr.Markdown("CSV
+        new_entity_text_pairs_file = gr.File(label="エンティティと説明文の CSV ファイル", height="128px")
+        gr.Markdown("CSV ファイルから読み込まれた項目が以下の表に表示されます。表の内容を直接編集することも可能です。")
         new_entity_text_pairs_input = gr.Dataframe(
             # value=sample_new_entity_text_pairs,
             headers=["entity", "text"],
@@ -454,41 +442,28 @@ with gr.Blocks() as demo:
             label="エンティティと説明文",
             interactive=True,
         )
+        preserve_default_entities_checkbox = gr.Checkbox(label="既存のエンティティを保持する", value=True)
         replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
-        gr.Markdown("LUXEのモデルのエンティティ語彙は、デモページの再読み込み時にリセットされます。")
+        gr.Markdown("LUXE のモデルのエンティティ語彙は、デモページの再読み込み時にリセットされます。")

     new_entity_text_pairs_file.change(
         fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
     )
     replace_entity_button.click(
         fn=replace_entities,
-        inputs=[models, new_entity_text_pairs_input, entity_replaced_counts],
+        inputs=[models, new_entity_text_pairs_input, entity_replaced_counts, preserve_default_entities_checkbox],
         outputs=entity_replaced_counts,
     )

-    submit_button = gr.Button(value="予測実行", variant="huggingface")
-    submit_button.click(
-        fn=get_topk_entities_from_texts,
-        inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
-        outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
-    )
-    text_input.submit(
-        fn=get_topk_entities_from_texts,
-        inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
-        outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
-    )
-
     gr.Markdown("---")
-    gr.Markdown("##
+    gr.Markdown("## 予測されたエンティティとカテゴリ")

-    @gr.render(
-        inputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities]
-    )
+    @gr.render(inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities])
     def render_topk_entities(
-
+        texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
     ):
         for text, entity_spans, normal_entities, category_entities, span_entities in zip(
-
+            texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
         ):
             highlighted_text_value = []
             cur = 0
@@ -503,7 +478,10 @@ with gr.Blocks() as demo:
                 highlighted_text_value.append((text[cur:], None))

             gr.HighlightedText(
-                value=highlighted_text_value,
+                value=highlighted_text_value,
+                color_map={"Entity": "green"},
+                combine_adjacent=False,
+                label="予測されたエンティティのスパン",
             )

             # gr.Textbox(text, label="Text")
@@ -512,31 +490,22 @@ with gr.Blocks() as demo:
                     label="テキスト全体に関連するエンティティ",
                     components=["text"],
                     samples=[[entity] for entity in normal_entities],
-                )
+                )
             if category_entities:
                 gr.Dataset(
                     label="テキスト全体に関連するカテゴリ",
                     components=["text"],
                     samples=[[entity] for entity in category_entities],
-                )
-
-            span_texts = [text[start:end] for start, end in entity_spans]
-            for span_text, entities in zip(span_texts, span_entities):
-                gr.Dataset(
-                    label=f"「{span_text}」に対応するエンティティ",
-                    components=["text"],
-                    samples=[[entity] for entity in entities],
-                ).select(fn=get_selected_entity, outputs=selected_entity)
-
-            # gr.Markdown("---")
-            # gr.Markdown("## 選択されたエンティティの類似エンティティ")
-
-            # selected_entity.change(fn=get_similar_entities, inputs=[models, selected_entity], outputs=similar_entities)
+                )
+
+            with gr.Accordion(label="テキスト中のスパンに対応するエンティティ", open=len(texts) == 1):
+                span_texts = [text[start:end] for start, end in entity_spans]
+                for span_text, entities in zip(span_texts, span_entities):
+                    gr.Dataset(
+                        label=f"「{span_text}」に対応するエンティティ",
+                        components=["text"],
+                        samples=[[entity] for entity in entities],
+                    )

-
-
-

 demo.launch()
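The `@gr.render` block above re-runs whenever any of its State inputs change, which is how a fresh set of HighlightedText and Dataset components appears after each prediction. A minimal self-contained sketch of the same pattern (component names here are illustrative, not from the app):

```python
import gradio as gr

with gr.Blocks() as demo:
    items = gr.State([])
    box = gr.Textbox(label="Add item")
    box.submit(fn=lambda xs, x: xs + [x], inputs=[items, box], outputs=items)

    @gr.render(inputs=items)
    def render_items(xs):
        # Re-executed whenever `items` changes, rebuilding one Dataset per item,
        # mirroring how render_topk_entities rebuilds its components above.
        for x in xs:
            gr.Dataset(components=["text"], samples=[[x]], label=x)

demo.launch()
```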