Spaces:
Runtime error
๋ฌธ์ ์ง์ ์๋ต(Document Question Answering) [[document_question_answering]]
[[open-in-colab]]
๋ฌธ์ ์๊ฐ์ ์ง์ ์๋ต(Document Visual Question Answering)์ด๋ผ๊ณ ๋ ํ๋ ๋ฌธ์ ์ง์ ์๋ต(Document Question Answering)์ ๋ฌธ์ ์ด๋ฏธ์ง์ ๋ํ ์ง๋ฌธ์ ๋ต๋ณ์ ์ฃผ๋ ํ์คํฌ์ ๋๋ค. ์ด ํ์คํฌ๋ฅผ ์ง์ํ๋ ๋ชจ๋ธ์ ์ ๋ ฅ์ ์ผ๋ฐ์ ์ผ๋ก ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ์กฐํฉ์ด๊ณ , ์ถ๋ ฅ์ ์์ฐ์ด๋ก ๋ ๋ต๋ณ์ ๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ์ ํ ์คํธ, ๋จ์ด์ ์์น(๋ฐ์ด๋ฉ ๋ฐ์ค), ์ด๋ฏธ์ง ๋ฑ ๋ค์ํ ๋ชจ๋ฌ๋ฆฌํฐ๋ฅผ ํ์ฉํฉ๋๋ค.
์ด ๊ฐ์ด๋๋ ๋ค์ ๋ด์ฉ์ ์ค๋ช ํฉ๋๋ค:
- DocVQA dataset์ ์ฌ์ฉํด LayoutLMv2 ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ
- ์ถ๋ก ์ ์ํด ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ
์ด ํํ ๋ฆฌ์ผ์์ ์ค๋ช ํ๋ ํ์คํฌ๋ ๋ค์๊ณผ ๊ฐ์ ๋ชจ๋ธ ์ํคํ ์ฒ์์ ์ง์๋ฉ๋๋ค:
LayoutLM, LayoutLMv2, LayoutLMv3
LayoutLMv2๋ ํ ํฐ์ ๋ง์ง๋ง ์๋์ธต ์์ ์ง์ ์๋ต ํค๋๋ฅผ ์ถ๊ฐํด ๋ต๋ณ์ ์์ ํ ํฐ๊ณผ ๋ ํ ํฐ์ ์์น๋ฅผ ์์ธกํจ์ผ๋ก์จ ๋ฌธ์ ์ง์ ์๋ต ํ์คํฌ๋ฅผ ํด๊ฒฐํฉ๋๋ค. ์ฆ, ๋ฌธ๋งฅ์ด ์ฃผ์ด์ก์ ๋ ์ง๋ฌธ์ ๋ตํ๋ ์ ๋ณด๋ฅผ ์ถ์ถํ๋ ์ถ์ถํ ์ง์ ์๋ต(Extractive question answering)์ผ๋ก ๋ฌธ์ ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. ๋ฌธ๋งฅ์ OCR ์์ง์ ์ถ๋ ฅ์์ ๊ฐ์ ธ์ค๋ฉฐ, ์ฌ๊ธฐ์๋ Google์ Tesseract๋ฅผ ์ฌ์ฉํฉ๋๋ค.
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๋ชจ๋ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์. LayoutLMv2๋ detectron2, torchvision ๋ฐ ํ ์๋ํธ๋ฅผ ํ์๋ก ํฉ๋๋ค.
pip install -q transformers datasets
pip install 'git+https://github.com/facebookresearch/detectron2.git'
pip install torchvision
sudo apt install tesseract-ocr
pip install -q pytesseract
ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ๋ชจ๋ ์ค์นํ ํ ๋ฐํ์์ ๋ค์ ์์ํฉ๋๋ค.
์ปค๋ฎค๋ํฐ์ ๋น์ ์ ๋ชจ๋ธ์ ๊ณต์ ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํด์ ๋ชจ๋ธ์ ๐ค Hub์ ์ ๋ก๋ํ์ธ์. ํ๋กฌํํธ๊ฐ ์คํ๋๋ฉด, ๋ก๊ทธ์ธ์ ์ํด ํ ํฐ์ ์ ๋ ฅํ์ธ์:
>>> from huggingface_hub import notebook_login
>>> notebook_login()
๋ช ๊ฐ์ง ์ ์ญ ๋ณ์๋ฅผ ์ ์ํด ๋ณด๊ฒ ์ต๋๋ค.
>>> model_checkpoint = "microsoft/layoutlmv2-base-uncased"
>>> batch_size = 4
๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ [[load-the-data]]
์ด ๊ฐ์ด๋์์๋ ๐ค Hub์์ ์ฐพ์ ์ ์๋ ์ ์ฒ๋ฆฌ๋ DocVQA์ ์์ ์ํ์ ์ฌ์ฉํฉ๋๋ค. DocVQA์ ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ฌ์ฉํ๊ณ ์ถ๋ค๋ฉด, DocVQA homepage์ ๊ฐ์ ํ ๋ค์ด๋ก๋ ํ ์ ์์ต๋๋ค. ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ค์ด๋ก๋ ํ๋ค๋ฉด, ์ด ๊ฐ์ด๋๋ฅผ ๊ณ์ ์งํํ๊ธฐ ์ํด ๐ค dataset์ ํ์ผ์ ๊ฐ์ ธ์ค๋ ๋ฐฉ๋ฒ์ ํ์ธํ์ธ์.
>>> from datasets import load_dataset
>>> dataset = load_dataset("nielsr/docvqa_1200_examples")
>>> dataset
DatasetDict({
train: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 1000
})
test: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 200
})
})
๋ณด์๋ค์ํผ, ๋ฐ์ดํฐ ์ธํธ๋ ์ด๋ฏธ ํ๋ จ ์ธํธ์ ํ ์คํธ ์ธํธ๋ก ๋๋์ด์ ธ ์์ต๋๋ค. ๋ฌด์์๋ก ์์ ๋ฅผ ์ดํด๋ณด๋ฉด์ ํน์ฑ์ ํ์ธํด๋ณด์ธ์.
>>> dataset["train"].features
๊ฐ ํ๋๊ฐ ๋ํ๋ด๋ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
id
: ์์ ์ idimage
: ๋ฌธ์ ์ด๋ฏธ์ง๋ฅผ ํฌํจํ๋ PIL.Image.Image ๊ฐ์ฒดquery
: ์ง๋ฌธ ๋ฌธ์์ด - ์ฌ๋ฌ ์ธ์ด์ ์์ฐ์ด๋ก ๋ ์ง๋ฌธanswers
: ์ฌ๋์ด ์ฃผ์์ ๋จ ์ ๋ต ๋ฆฌ์คํธwords
andbounding_boxes
: OCR์ ๊ฒฐ๊ณผ๊ฐ๋ค์ด๋ฉฐ ์ด ๊ฐ์ด๋์์๋ ์ฌ์ฉํ์ง ์์ ์์ answer
: ๋ค๋ฅธ ๋ชจ๋ธ๊ณผ ์ผ์นํ๋ ๋ต๋ณ์ด๋ฉฐ ์ด ๊ฐ์ด๋์์๋ ์ฌ์ฉํ์ง ์์ ์์
์์ด๋ก ๋ ์ง๋ฌธ๋ง ๋จ๊ธฐ๊ณ ๋ค๋ฅธ ๋ชจ๋ธ์ ๋ํ ์์ธก์ ํฌํจํ๋ answer
ํน์ฑ์ ์ญ์ ํ๊ฒ ์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ฃผ์ ์์ฑ์๊ฐ ์ ๊ณตํ ๋ฐ์ดํฐ ์ธํธ์์ ์ฒซ ๋ฒ์งธ ๋ต๋ณ์ ๊ฐ์ ธ์ต๋๋ค. ๋๋ ๋ฌด์์๋ก ์ํ์ ์ถ์ถํ ์๋ ์์ต๋๋ค.
>>> updated_dataset = dataset.map(lambda example: {"question": example["query"]["en"]}, remove_columns=["query"])
>>> updated_dataset = updated_dataset.map(
... lambda example: {"answer": example["answers"][0]}, remove_columns=["answer", "answers"]
... )
์ด ๊ฐ์ด๋์์ ์ฌ์ฉํ๋ LayoutLMv2 ์ฒดํฌํฌ์ธํธ๋ max_position_embeddings = 512
๋ก ํ๋ จ๋์์ต๋๋ค(์ด ์ ๋ณด๋ ์ฒดํฌํฌ์ธํธ์ config.json
ํ์ผ์์ ํ์ธํ ์ ์์ต๋๋ค).
๋ฐ๋ก ์์ ๋ฅผ ์๋ผ๋ผ ์๋ ์์ง๋ง, ๊ธด ๋ฌธ์์ ๋์ ๋ต๋ณ์ด ์์ด ์๋ฆฌ๋ ์ํฉ์ ํผํ๊ธฐ ์ํด ์ฌ๊ธฐ์๋ ์๋ฒ ๋ฉ์ด 512๋ณด๋ค ๊ธธ์ด์ง ๊ฐ๋ฅ์ฑ์ด ์๋ ๋ช ๊ฐ์ง ์์ ๋ฅผ ์ ๊ฑฐํ๊ฒ ์ต๋๋ค.
๋ฐ์ดํฐ ์ธํธ์ ์๋ ๋๋ถ๋ถ์ ๋ฌธ์๊ฐ ๊ธด ๊ฒฝ์ฐ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ์ ์์ต๋๋ค - ์์ธํ ๋ด์ฉ์ ํ์ธํ๊ณ ์ถ์ผ๋ฉด ์ด ๋
ธํธ๋ถ์ ํ์ธํ์ธ์.
>>> updated_dataset = updated_dataset.filter(lambda x: len(x["words"]) + len(x["question"].split()) < 512)
์ด ์์ ์์ ์ด ๋ฐ์ดํฐ ์ธํธ์ OCR ํน์ฑ๋ ์ ๊ฑฐํด ๋ณด๊ฒ ์ต๋๋ค. OCR ํน์ฑ์ ๋ค๋ฅธ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ ์ํ ๊ฒ์ผ๋ก, ์ด ๊ฐ์ด๋์์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ์
๋ ฅ ์๊ตฌ ์ฌํญ๊ณผ ์ผ์นํ์ง ์๊ธฐ ๋๋ฌธ์ ์ด ํน์ฑ์ ์ฌ์ฉํ๊ธฐ ์ํด์๋ ์ผ๋ถ ์ฒ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค.
๋์ , ์๋ณธ ๋ฐ์ดํฐ์ [LayoutLMv2Processor
]๋ฅผ ์ฌ์ฉํ์ฌ OCR ๋ฐ ํ ํฐํ๋ฅผ ๋ชจ๋ ์ํํ ์ ์์ต๋๋ค.
์ด๋ ๊ฒ ํ๋ฉด ๋ชจ๋ธ์ด ์๊ตฌํ๋ ์
๋ ฅ์ ์ป์ ์ ์์ต๋๋ค.
์ด๋ฏธ์ง๋ฅผ ์๋์ผ๋ก ์ฒ๋ฆฌํ๋ ค๋ฉด, LayoutLMv2
model documentation์์ ๋ชจ๋ธ์ด ์๊ตฌํ๋ ์
๋ ฅ ํฌ๋งท์ ํ์ธํด๋ณด์ธ์.
>>> updated_dataset = updated_dataset.remove_columns("words")
>>> updated_dataset = updated_dataset.remove_columns("bounding_boxes")
๋ง์ง๋ง์ผ๋ก, ๋ฐ์ดํฐ ํ์์ ์๋ฃํ๊ธฐ ์ํด ์ด๋ฏธ์ง ์์๋ฅผ ์ดํด๋ด ์๋ค.
>>> updated_dataset["train"][11]["image"]

๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ [[preprocess-the-data]]
๋ฌธ์ ์ง์ ์๋ต ํ์คํฌ๋ ๋ฉํฐ๋ชจ๋ฌ ํ์คํฌ์ด๋ฉฐ, ๊ฐ ๋ชจ๋ฌ๋ฆฌํฐ์ ์
๋ ฅ์ด ๋ชจ๋ธ์ ์๊ตฌ์ ๋ง๊ฒ ์ ์ฒ๋ฆฌ ๋์๋์ง ํ์ธํด์ผ ํฉ๋๋ค.
์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์๋ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ํ
์คํธ ๋ฐ์ดํฐ๋ฅผ ์ธ์ฝ๋ฉํ ์ ์๋ ํ ํฌ๋์ด์ ๋ฅผ ๊ฒฐํฉํ [LayoutLMv2Processor
]๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ๋ถํฐ ์์ํด ๋ณด๊ฒ ์ต๋๋ค.
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
๋ฌธ์ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ [[preprocessing-document-images]]
๋จผ์ , ํ๋ก์ธ์์ image_processor
๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ๋ํ ๋ฌธ์ ์ด๋ฏธ์ง๋ฅผ ์ค๋นํด ๋ณด๊ฒ ์ต๋๋ค.
๊ธฐ๋ณธ๊ฐ์ผ๋ก, ์ด๋ฏธ์ง ํ๋ก์ธ์๋ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ 224x224๋ก ์กฐ์ ํ๊ณ ์์ ์ฑ๋์ ์์๊ฐ ์ฌ๋ฐ๋ฅธ์ง ํ์ธํ ํ ๋จ์ด์ ์ ๊ทํ๋ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ฅผ ์ป๊ธฐ ์ํด ํ
์๋ํธ๋ฅผ ์ฌ์ฉํด OCR๋ฅผ ์ ์ฉํฉ๋๋ค.
์ด ํํ ๋ฆฌ์ผ์์ ์ฐ๋ฆฌ๊ฐ ํ์ํ ๊ฒ๊ณผ ๊ธฐ๋ณธ๊ฐ์ ์์ ํ ๋์ผํฉ๋๋ค. ์ด๋ฏธ์ง ๋ฐฐ์น์ ๊ธฐ๋ณธ ์ด๋ฏธ์ง ์ฒ๋ฆฌ๋ฅผ ์ ์ฉํ๊ณ OCR์ ๊ฒฐ๊ณผ๋ฅผ ๋ณํํ๋ ํจ์๋ฅผ ์์ฑํฉ๋๋ค.
>>> image_processor = processor.image_processor
>>> def get_ocr_words_and_boxes(examples):
... images = [image.convert("RGB") for image in examples["image"]]
... encoded_inputs = image_processor(images)
... examples["image"] = encoded_inputs.pixel_values
... examples["words"] = encoded_inputs.words
... examples["boxes"] = encoded_inputs.boxes
... return examples
์ด ์ ์ฒ๋ฆฌ๋ฅผ ๋ฐ์ดํฐ ์ธํธ ์ ์ฒด์ ๋น ๋ฅด๊ฒ ์ ์ฉํ๋ ค๋ฉด [~datasets.Dataset.map
]๋ฅผ ์ฌ์ฉํ์ธ์.
>>> dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2)
ํ ์คํธ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ [[preprocessing-text-data]]
์ด๋ฏธ์ง์ OCR์ ์ ์ฉํ์ผ๋ฉด ๋ฐ์ดํฐ ์ธํธ์ ํ
์คํธ ๋ถ๋ถ์ ๋ชจ๋ธ์ ๋ง๊ฒ ์ธ์ฝ๋ฉํด์ผ ํฉ๋๋ค.
์ด ์ธ์ฝ๋ฉ์๋ ์ด์ ๋จ๊ณ์์ ๊ฐ์ ธ์จ ๋จ์ด์ ๋ฐ์ค๋ฅผ ํ ํฐ ์์ค์ input_ids
, attention_mask
, token_type_ids
๋ฐ bbox
๋ก ๋ณํํ๋ ์์
์ด ํฌํจ๋ฉ๋๋ค.
ํ
์คํธ๋ฅผ ์ ์ฒ๋ฆฌํ๋ ค๋ฉด ํ๋ก์ธ์์ tokenizer
๊ฐ ํ์ํฉ๋๋ค.
>>> tokenizer = processor.tokenizer
์์์ ์ธ๊ธํ ์ ์ฒ๋ฆฌ ์ธ์๋ ๋ชจ๋ธ์ ์ํด ๋ ์ด๋ธ์ ์ถ๊ฐํด์ผ ํฉ๋๋ค. ๐ค Transformers์ xxxForQuestionAnswering
๋ชจ๋ธ์ ๊ฒฝ์ฐ, ๋ ์ด๋ธ์ start_positions
์ end_positions
๋ก ๊ตฌ์ฑ๋๋ฉฐ ์ด๋ค ํ ํฐ์ด ๋ต๋ณ์ ์์๊ณผ ๋์ ์๋์ง๋ฅผ ๋ํ๋
๋๋ค.
๋ ์ด๋ธ ์ถ๊ฐ๋ฅผ ์ํด์, ๋จผ์ ๋ ํฐ ๋ฆฌ์คํธ(๋จ์ด ๋ฆฌ์คํธ)์์ ํ์ ๋ฆฌ์คํธ(๋จ์ด๋ก ๋ถํ ๋ ๋ต๋ณ)์ ์ฐพ์ ์ ์๋ ํฌํผ ํจ์๋ฅผ ์ ์ํฉ๋๋ค.
์ด ํจ์๋ words_list
์ answer_list
, ์ด๋ ๊ฒ ๋ ๋ฆฌ์คํธ๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ์ต๋๋ค.
๊ทธ๋ฐ ๋ค์ words_list
๋ฅผ ๋ฐ๋ณตํ์ฌ words_list
์ ํ์ฌ ๋จ์ด(words_list[i])๊ฐ answer_list
์ ์ฒซ ๋ฒ์งธ ๋จ์ด(answer_list[0])์ ๊ฐ์์ง,
ํ์ฌ ๋จ์ด์์ ์์ํด answer_list
์ ๊ฐ์ ๊ธธ์ด๋งํผ์ words_list
์ ํ์ ๋ฆฌ์คํธ๊ฐ answer_list
์ ์ผ์นํ๋์ง ํ์ธํฉ๋๋ค.
์ด ์กฐ๊ฑด์ด ์ฐธ์ด๋ผ๋ฉด ์ผ์นํ๋ ํญ๋ชฉ์ ๋ฐ๊ฒฌํ์์ ์๋ฏธํ๋ฉฐ, ํจ์๋ ์ผ์น ํญ๋ชฉ, ์์ ์ธ๋ฑ์ค(idx) ๋ฐ ์ข
๋ฃ ์ธ๋ฑ์ค(idx + len(answer_list) - 1)๋ฅผ ๊ธฐ๋กํฉ๋๋ค. ์ผ์นํ๋ ํญ๋ชฉ์ด ๋ ๊ฐ ์ด์ ๋ฐ๊ฒฌ๋๋ฉด ํจ์๋ ์ฒซ ๋ฒ์งธ ํญ๋ชฉ๋ง ๋ฐํํฉ๋๋ค. ์ผ์นํ๋ ํญ๋ชฉ์ด ์๋ค๋ฉด ํจ์๋ (None
, 0, 0)์ ๋ฐํํฉ๋๋ค.
>>> def subfinder(words_list, answer_list):
... matches = []
... start_indices = []
... end_indices = []
... for idx, i in enumerate(range(len(words_list))):
... if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list:
... matches.append(answer_list)
... start_indices.append(idx)
... end_indices.append(idx + len(answer_list) - 1)
... if matches:
... return matches[0], start_indices[0], end_indices[0]
... else:
... return None, 0, 0
์ด ํจ์๊ฐ ์ด๋ป๊ฒ ์ ๋ต์ ์์น๋ฅผ ์ฐพ๋์ง ์ค๋ช ํ๊ธฐ ์ํด ๋ค์ ์์ ์์ ํจ์๋ฅผ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค:
>>> example = dataset_with_ocr["train"][1]
>>> words = [word.lower() for word in example["words"]]
>>> match, word_idx_start, word_idx_end = subfinder(words, example["answer"].lower().split())
>>> print("Question: ", example["question"])
>>> print("Words:", words)
>>> print("Answer: ", example["answer"])
>>> print("start_index", word_idx_start)
>>> print("end_index", word_idx_end)
Question: Who is in cc in this letter?
Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', 'ยซshort', 'cigarette,', 'tobacco', 'section', '30', 'mm.', 'ยซextremely', 'fast', 'buming', 'cigarette.', 'ยซnovel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', 'ยซmore', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498']
Answer: T.F. Riehl
start_index 17
end_index 18
ํํธ, ์ ์์ ๊ฐ ์ธ์ฝ๋ฉ๋๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํ์๋ฉ๋๋ค:
>>> encoding = tokenizer(example["question"], example["words"], example["boxes"])
>>> tokenizer.decode(encoding["input_ids"])
[CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ...
์ด์ ์ธ์ฝ๋ฉ๋ ์ ๋ ฅ์์ ์ ๋ต์ ์์น๋ฅผ ์ฐพ์์ผ ํฉ๋๋ค.
token_type_ids
๋ ์ด๋ค ํ ํฐ์ด ์ง๋ฌธ์ ์ํ๋์ง, ๊ทธ๋ฆฌ๊ณ ์ด๋ค ํ ํฐ์ด ๋ฌธ์์ ๋จ์ด์ ํฌํจ๋๋์ง๋ฅผ ์๋ ค์ค๋๋ค.tokenizer.cls_token_id
์ ๋ ฅ์ ์์ ๋ถ๋ถ์ ์๋ ํน์ ํ ํฐ์ ์ฐพ๋ ๋ฐ ๋์์ ์ค๋๋ค.word_ids
๋ ์๋ณธwords
์์ ์ฐพ์ ๋ต๋ณ์ ์ ์ฒด ์ธ์ฝ๋ฉ๋ ์ ๋ ฅ์ ๋์ผํ ๋ต๊ณผ ์ผ์น์ํค๊ณ ์ธ์ฝ๋ฉ๋ ์ ๋ ฅ์์ ๋ต๋ณ์ ์์/๋ ์์น๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
์ ๋ด์ฉ๋ค์ ์ผ๋์ ๋๊ณ ๋ฐ์ดํฐ ์ธํธ ์์ ์ ๋ฐฐ์น๋ฅผ ์ธ์ฝ๋ฉํ๋ ํจ์๋ฅผ ๋ง๋ค์ด ๋ณด๊ฒ ์ต๋๋ค:
>>> def encode_dataset(examples, max_length=512):
... questions = examples["question"]
... words = examples["words"]
... boxes = examples["boxes"]
... answers = examples["answer"]
... # ์์ ๋ฐฐ์น๋ฅผ ์ธ์ฝ๋ฉํ๊ณ start_positions์ end_positions๋ฅผ ์ด๊ธฐํํฉ๋๋ค
... encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True)
... start_positions = []
... end_positions = []
... # ๋ฐฐ์น์ ์์ ๋ฅผ ๋ฐ๋ณตํฉ๋๋ค
... for i in range(len(questions)):
... cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id)
... # ์์ ์ words์์ ๋ต๋ณ์ ์์น๋ฅผ ์ฐพ์ต๋๋ค
... words_example = [word.lower() for word in words[i]]
... answer = answers[i]
... match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())
... if match:
... # ์ผ์นํ๋ ํญ๋ชฉ์ ๋ฐ๊ฒฌํ๋ฉด, `token_type_ids`๋ฅผ ์ฌ์ฉํด ์ธ์ฝ๋ฉ์์ ๋จ์ด๊ฐ ์์ํ๋ ์์น๋ฅผ ์ฐพ์ต๋๋ค
... token_type_ids = encoding["token_type_ids"][i]
... token_start_index = 0
... while token_type_ids[token_start_index] != 1:
... token_start_index += 1
... token_end_index = len(encoding["input_ids"][i]) - 1
... while token_type_ids[token_end_index] != 1:
... token_end_index -= 1
... word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1]
... start_position = cls_index
... end_position = cls_index
... # words์ ๋ต๋ณ ์์น์ ์ผ์นํ ๋๊น์ง word_ids๋ฅผ ๋ฐ๋ณตํ๊ณ `token_start_index`๋ฅผ ๋๋ฆฝ๋๋ค
... # ์ผ์นํ๋ฉด `token_start_index`๋ฅผ ์ธ์ฝ๋ฉ์์ ๋ต๋ณ์ `start_position`์ผ๋ก ์ ์ฅํฉ๋๋ค
... for id in word_ids:
... if id == word_idx_start:
... start_position = token_start_index
... else:
... token_start_index += 1
... # ๋น์ทํ๊ฒ, ๋์์ ์์ํด `word_ids`๋ฅผ ๋ฐ๋ณตํ๋ฉฐ ๋ต๋ณ์ `end_position`์ ์ฐพ์ต๋๋ค
... for id in word_ids[::-1]:
... if id == word_idx_end:
... end_position = token_end_index
... else:
... token_end_index -= 1
... start_positions.append(start_position)
... end_positions.append(end_position)
... else:
... start_positions.append(cls_index)
... end_positions.append(cls_index)
... encoding["image"] = examples["image"]
... encoding["start_positions"] = start_positions
... encoding["end_positions"] = end_positions
... return encoding
์ด์ ์ด ์ ์ฒ๋ฆฌ ํจ์๊ฐ ์์ผ๋ ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ธ์ฝ๋ฉํ ์ ์์ต๋๋ค:
>>> encoded_train_dataset = dataset_with_ocr["train"].map(
... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names
... )
>>> encoded_test_dataset = dataset_with_ocr["test"].map(
... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names
... )
์ธ์ฝ๋ฉ๋ ๋ฐ์ดํฐ ์ธํธ์ ํน์ฑ์ด ์ด๋ป๊ฒ ์๊ฒผ๋์ง ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค:
>>> encoded_train_dataset.features
{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
'start_positions': Value(dtype='int64', id=None),
'end_positions': Value(dtype='int64', id=None)}
ํ๊ฐ [[evaluation]]
๋ฌธ์ ์ง์ ์๋ต์ ํ๊ฐํ๋ ค๋ฉด ์๋นํ ์์ ํ์ฒ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ์๊ฐ์ด ๋๋ฌด ๋ง์ด ๊ฑธ๋ฆฌ์ง ์๋๋ก ์ด ๊ฐ์ด๋์์๋ ํ๊ฐ ๋จ๊ณ๋ฅผ ์๋ตํฉ๋๋ค.
[Trainer
]๊ฐ ํ๋ จ ๊ณผ์ ์์ ํ๊ฐ ์์ค(evaluation loss)์ ๊ณ์ ๊ณ์ฐํ๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋๋ต์ ์ผ๋ก ์ ์ ์์ต๋๋ค.
์ถ์ถ์ (Extractive) ์ง์ ์๋ต์ ๋ณดํต F1/exact match ๋ฐฉ๋ฒ์ ์ฌ์ฉํด ํ๊ฐ๋ฉ๋๋ค.
์ง์ ๊ตฌํํด๋ณด๊ณ ์ถ์ผ์๋ค๋ฉด, Hugging Face course์ Question Answering chapter์ ์ฐธ๊ณ ํ์ธ์.
ํ๋ จ [[train]]
์ถํํฉ๋๋ค! ์ด ๊ฐ์ด๋์ ๊ฐ์ฅ ์ด๋ ค์ด ๋ถ๋ถ์ ์ฑ๊ณต์ ์ผ๋ก ์ฒ๋ฆฌํ์ผ๋ ์ด์ ๋๋ง์ ๋ชจ๋ธ์ ํ๋ จํ ์ค๋น๊ฐ ๋์์ต๋๋ค. ํ๋ จ์ ๋ค์๊ณผ ๊ฐ์ ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค:
- ์ ์ฒ๋ฆฌ์์์ ๋์ผํ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด [
AutoModelForDocumentQuestionAnswering
]์ผ๋ก ๋ชจ๋ธ์ ๊ฐ์ ธ์ต๋๋ค. - [
TrainingArguments
]๋ก ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ํฉ๋๋ค. - ์์ ๋ฅผ ๋ฐฐ์น ์ฒ๋ฆฌํ๋ ํจ์๋ฅผ ์ ์ํฉ๋๋ค. ์ฌ๊ธฐ์๋ [
DefaultDataCollator
]๊ฐ ์ ๋นํฉ๋๋ค. - ๋ชจ๋ธ, ๋ฐ์ดํฐ ์ธํธ, ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ(Data collator)์ ํจ๊ป [
Trainer
]์ ํ๋ จ ์ธ์๋ค์ ์ ๋ฌํฉ๋๋ค. - [
~Trainer.train
]์ ํธ์ถํด์ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํฉ๋๋ค.
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)
[TrainingArguments
]์์ output_dir
์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ์ง์ ํ๊ณ , ์ ์ ํ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ค์ ํฉ๋๋ค.
๋ชจ๋ธ์ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ๋ ค๋ฉด push_to_hub
๋ฅผ True
๋ก ์ค์ ํ์ธ์ (๋ชจ๋ธ์ ์
๋ก๋ํ๋ ค๋ฉด Hugging Face์ ๋ก๊ทธ์ธํด์ผ ํฉ๋๋ค).
์ด ๊ฒฝ์ฐ output_dir
์ ๋ชจ๋ธ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ํธ์ํ ๋ ํฌ์งํ ๋ฆฌ์ ์ด๋ฆ์ด ๋ฉ๋๋ค.
>>> from transformers import TrainingArguments
>>> # ๋ณธ์ธ์ ๋ ํฌ์งํ ๋ฆฌ ID๋ก ๋ฐ๊พธ์ธ์
>>> repo_id = "MariaK/layoutlmv2-base-uncased_finetuned_docvqa"
>>> training_args = TrainingArguments(
... output_dir=repo_id,
... per_device_train_batch_size=4,
... num_train_epochs=20,
... save_steps=200,
... logging_steps=50,
... evaluation_strategy="steps",
... learning_rate=5e-5,
... save_total_limit=2,
... remove_unused_columns=False,
... push_to_hub=True,
... )
๊ฐ๋จํ ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ๋ฅผ ์ ์ํ์ฌ ์์ ๋ฅผ ํจ๊ป ๋ฐฐ์นํฉ๋๋ค.
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
๋ง์ง๋ง์ผ๋ก, ๋ชจ๋ ๊ฒ์ ํ ๊ณณ์ ๋ชจ์ [~Trainer.train
]์ ํธ์ถํฉ๋๋ค:
>>> from transformers import Trainer
>>> trainer = Trainer(
... model=model,
... args=training_args,
... data_collator=data_collator,
... train_dataset=encoded_train_dataset,
... eval_dataset=encoded_test_dataset,
... tokenizer=processor,
... )
>>> trainer.train()
์ต์ข
๋ชจ๋ธ์ ๐ค Hub์ ์ถ๊ฐํ๋ ค๋ฉด, ๋ชจ๋ธ ์นด๋๋ฅผ ์์ฑํ๊ณ push_to_hub
๋ฅผ ํธ์ถํฉ๋๋ค:
>>> trainer.create_model_card()
>>> trainer.push_to_hub()
์ถ๋ก [[inference]]
์ด์ LayoutLMv2 ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ณ ๐ค Hub์ ์
๋ก๋ํ์ผ๋ ์ถ๋ก ์๋ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ถ๋ก ์ ์ํด ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ณด๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ [Pipeline
]์ ์ฌ์ฉํ๋ ๊ฒ ์
๋๋ค.
์๋ฅผ ๋ค์ด ๋ณด๊ฒ ์ต๋๋ค:
>>> example = dataset["test"][2]
>>> question = example["query"]["en"]
>>> image = example["image"]
>>> print(question)
>>> print(example["answers"])
'Who is โpresidingโ TRRF GENERAL SESSION (PART 1)?'
['TRRF Vice President', 'lee a. waller']
๊ทธ ๋ค์, ๋ชจ๋ธ๋ก ๋ฌธ์ ์ง์ ์๋ต์ ํ๊ธฐ ์ํด ํ์ดํ๋ผ์ธ์ ์ธ์คํด์คํํ๊ณ ์ด๋ฏธ์ง + ์ง๋ฌธ ์กฐํฉ์ ์ ๋ฌํฉ๋๋ค.
>>> from transformers import pipeline
>>> qa_pipeline = pipeline("document-question-answering", model="MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> qa_pipeline(image, question)
[{'score': 0.9949808120727539,
'answer': 'Lee A. Waller',
'start': 55,
'end': 57}]
์ํ๋ค๋ฉด ํ์ดํ๋ผ์ธ์ ๊ฒฐ๊ณผ๋ฅผ ์๋์ผ๋ก ๋ณต์ ํ ์๋ ์์ต๋๋ค:
- ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ๊ฐ์ ธ์ ๋ชจ๋ธ์ ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ๋ง๊ฒ ์ค๋นํฉ๋๋ค.
- ๋ชจ๋ธ์ ํตํด ๊ฒฐ๊ณผ ๋๋ ์ ์ฒ๋ฆฌ๋ฅผ ์ ๋ฌํฉ๋๋ค.
- ๋ชจ๋ธ์ ์ด๋ค ํ ํฐ์ด ๋ต๋ณ์ ์์์ ์๋์ง, ์ด๋ค ํ ํฐ์ด ๋ต๋ณ์ด ๋์ ์๋์ง๋ฅผ ๋ํ๋ด๋
start_logits
์end_logits
๋ฅผ ๋ฐํํฉ๋๋ค. ๋ ๋ค (batch_size, sequence_length) ํํ๋ฅผ ๊ฐ์ต๋๋ค. start_logits
์end_logits
์ ๋ง์ง๋ง ์ฐจ์์ ์ต๋๋ก ๋ง๋๋ ๊ฐ์ ์ฐพ์ ์์start_idx
์end_idx
๋ฅผ ์ป์ต๋๋ค.- ํ ํฌ๋์ด์ ๋ก ๋ต๋ณ์ ๋์ฝ๋ฉํฉ๋๋ค.
>>> import torch
>>> from transformers import AutoProcessor
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> processor = AutoProcessor.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> with torch.no_grad():
... encoding = processor(image.convert("RGB"), question, return_tensors="pt")
... outputs = model(**encoding)
... start_logits = outputs.start_logits
... end_logits = outputs.end_logits
... predicted_start_idx = start_logits.argmax(-1).item()
... predicted_end_idx = end_logits.argmax(-1).item()
>>> processor.tokenizer.decode(encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1])
'lee a. waller'