
์‹œ๊ฐ์  ์งˆ์˜์‘๋‹ต (Visual Question Answering)

[[open-in-colab]]

์‹œ๊ฐ์  ์งˆ์˜์‘๋‹ต(VQA)์€ ์ด๋ฏธ์ง€๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ๊ฐœ๋ฐฉํ˜• ์งˆ๋ฌธ์— ๋Œ€์‘ํ•˜๋Š” ์ž‘์—…์ž…๋‹ˆ๋‹ค. ์ด ์ž‘์—…์„ ์ง€์›ํ•˜๋Š” ๋ชจ๋ธ์˜ ์ž…๋ ฅ์€ ๋Œ€๋ถ€๋ถ„ ์ด๋ฏธ์ง€์™€ ์งˆ๋ฌธ์˜ ์กฐํ•ฉ์ด๋ฉฐ, ์ถœ๋ ฅ์€ ์ž์—ฐ์–ด๋กœ ๋œ ๋‹ต๋ณ€์ž…๋‹ˆ๋‹ค.

VQA์˜ ์ฃผ์š” ์‚ฌ์šฉ ์‚ฌ๋ก€๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

  • ์‹œ๊ฐ ์žฅ์• ์ธ์„ ์œ„ํ•œ ์ ‘๊ทผ์„ฑ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜์„ ๊ตฌ์ถ•ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ๊ต์œก: ๊ฐ•์˜๋‚˜ ๊ต๊ณผ์„œ์— ๋‚˜์˜จ ์‹œ๊ฐ ์ž๋ฃŒ์— ๋Œ€ํ•œ ์งˆ๋ฌธ์— ๋‹ตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋˜ํ•œ ์ฒดํ—˜ํ˜• ์ „์‹œ์™€ ์œ ์  ๋“ฑ์—์„œ๋„ VQA๋ฅผ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ๊ณ ๊ฐ ์„œ๋น„์Šค ๋ฐ ์ „์ž์ƒ๊ฑฐ๋ž˜: VQA๋Š” ์‚ฌ์šฉ์ž๊ฐ€ ์ œํ’ˆ์— ๋Œ€ํ•ด ์งˆ๋ฌธํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•จ์œผ๋กœ์จ ์‚ฌ์šฉ์ž ๊ฒฝํ—˜์„ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์ด๋ฏธ์ง€ ๊ฒ€์ƒ‰: VQA ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์›ํ•˜๋Š” ํŠน์„ฑ์„ ๊ฐ€์ง„ ์ด๋ฏธ์ง€๋ฅผ ๊ฒ€์ƒ‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ์‚ฌ์šฉ์ž๋Š” "๊ฐ•์•„์ง€๊ฐ€ ์žˆ์–ด?"๋ผ๊ณ  ๋ฌผ์–ด๋ด์„œ ์ฃผ์–ด์ง„ ์ด๋ฏธ์ง€ ๋ฌถ์Œ์—์„œ ๊ฐ•์•„์ง€๊ฐ€ ์žˆ๋Š” ๋ชจ๋“  ์ด๋ฏธ์ง€๋ฅผ ๋ฐ›์•„๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ ํ•™์Šตํ•  ๋‚ด์šฉ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

  • VQA ๋ชจ๋ธ ์ค‘ ํ•˜๋‚˜์ธ ViLT๋ฅผ Graphcore/vqa ๋ฐ์ดํ„ฐ์…‹ ์—์„œ ๋ฏธ์„ธ์กฐ์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•
  • ๋ฏธ์„ธ์กฐ์ •๋œ ViLT ๋ชจ๋ธ๋กœ ์ถ”๋ก ํ•˜๋Š” ๋ฐฉ๋ฒ•
  • BLIP-2 ๊ฐ™์€ ์ƒ์„ฑ ๋ชจ๋ธ๋กœ ์ œ๋กœ์ƒท VQA ์ถ”๋ก ์„ ์‹คํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•

Fine-tuning ViLT [[finetuning-vilt]]

ViLT incorporates text embeddings into a Vision Transformer (ViT), giving it a minimal design for Vision-and-Language Pre-training (VLP). The model can be used for several downstream tasks. For the VQA task, a classification head (a linear layer on top of the final hidden state of the [CLS] token) is added and randomly initialized. Visual Question Answering is thus treated here as a classification problem.

More recent models, such as BLIP, BLIP-2, and InstructBLIP, treat VQA as a generative task. Later in this guide we illustrate how to use them for zero-shot VQA inference.

Before you begin, make sure you have all the necessary libraries installed.

pip install -q transformers datasets

We encourage you to share your model with the community. Log in to your Hugging Face account to upload it to the 🤗 Hub. When prompted, enter your token to log in:

>>> from huggingface_hub import notebook_login

>>> notebook_login()

๋ชจ๋ธ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ „์—ญ ๋ณ€์ˆ˜๋กœ ์„ ์–ธํ•˜์„ธ์š”.

>>> model_checkpoint = "dandelin/vilt-b32-mlm"

๋ฐ์ดํ„ฐ ๊ฐ€์ ธ์˜ค๊ธฐ [[load-the-data]]

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” Graphcore/vqa ๋ฐ์ดํ„ฐ์„ธํŠธ์˜ ์ž‘์€ ์ƒ˜ํ”Œ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ „์ฒด ๋ฐ์ดํ„ฐ์„ธํŠธ๋Š” ๐Ÿค— Hub ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Graphcore/vqa ๋ฐ์ดํ„ฐ์„ธํŠธ ์˜ ๋Œ€์•ˆ์œผ๋กœ ๊ณต์‹ VQA ๋ฐ์ดํ„ฐ์„ธํŠธ ํŽ˜์ด์ง€ ์—์„œ ๋™์ผํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์ˆ˜๋™์œผ๋กœ ๋‹ค์šด๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ง์ ‘ ๊ณต์ˆ˜ํ•œ ๋ฐ์ดํ„ฐ๋กœ ํŠœํ† ๋ฆฌ์–ผ์„ ๋”ฐ๋ฅด๊ณ  ์‹ถ๋‹ค๋ฉด ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์„ธํŠธ ๋งŒ๋“ค๊ธฐ ๋ผ๋Š” ๐Ÿค— Datasets ๋ฌธ์„œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

๊ฒ€์ฆ ๋ฐ์ดํ„ฐ์˜ ์ฒซ 200๊ฐœ ํ•ญ๋ชฉ์„ ๋ถˆ๋Ÿฌ์™€ ๋ฐ์ดํ„ฐ์„ธํŠธ์˜ ํŠน์„ฑ์„ ํ™•์ธํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

>>> from datasets import load_dataset

>>> dataset = load_dataset("Graphcore/vqa", split="validation[:200]")
>>> dataset
Dataset({
    features: ['question', 'question_type', 'question_id', 'image_id', 'answer_type', 'label'],
    num_rows: 200
})

Let's take a look at an example to understand the dataset's features:

>>> dataset[0]
{'question': 'Where is he looking?',
 'question_type': 'none of the above',
 'question_id': 262148000,
 'image_id': '/root/.cache/huggingface/datasets/downloads/extracted/ca733e0e000fb2d7a09fbcc94dbfe7b5a30750681d0e965f8e0a23b1c2f98c75/val2014/COCO_val2014_000000262148.jpg',
 'answer_type': 'other',
 'label': {'ids': ['at table', 'down', 'skateboard', 'table'],
  'weights': [0.30000001192092896,
   1.0,
   0.30000001192092896,
   0.30000001192092896]}}

๋ฐ์ดํ„ฐ์„ธํŠธ์—๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ํŠน์„ฑ์ด ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค:

  • question: ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ์งˆ๋ฌธ
  • image_id: ์งˆ๋ฌธ๊ณผ ๊ด€๋ จ๋œ ์ด๋ฏธ์ง€์˜ ๊ฒฝ๋กœ
  • label: ๋ฐ์ดํ„ฐ์˜ ๋ ˆ์ด๋ธ” (annotations)

๋‚˜๋จธ์ง€ ํŠน์„ฑ๋“ค์€ ํ•„์š”ํ•˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ์‚ญ์ œํ•ด๋„ ๋ฉ๋‹ˆ๋‹ค:

>>> dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type'])

As you can see, the label feature contains several answers to the same question (called ids here) collected by different human annotators. This is because the answer to a question can be subjective. In this case, the question was "Where is he looking?"; some people annotated it with "down", others with "at table" or "skateboard", etc.

Take a look at the image below and consider which answer you would give:

>>> from PIL import Image

>>> image = Image.open(dataset[0]['image_id'])
>>> image
[Image: VQA example]

์งˆ๋ฌธ๊ณผ ๋‹ต๋ณ€์˜ ๋ชจํ˜ธ์„ฑ์œผ๋กœ ์ธํ•ด ์ด๋Ÿฌํ•œ ๋ฐ์ดํ„ฐ์„ธํŠธ๋Š” ์—ฌ๋Ÿฌ ๊ฐœ์˜ ๋‹ต๋ณ€์ด ๊ฐ€๋Šฅํ•˜๋ฏ€๋กœ ๋‹ค์ค‘ ๋ ˆ์ด๋ธ” ๋ถ„๋ฅ˜ ๋ฌธ์ œ๋กœ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค. ๊ฒŒ๋‹ค๊ฐ€, ์›ํ•ซ(one-hot) ์ธ์ฝ”๋”ฉ ๋ฒกํ„ฐ๋ฅผ ์ƒ์„ฑํ•˜๊ธฐ๋ณด๋‹ค๋Š” ๋ ˆ์ด๋ธ”์—์„œ ํŠน์ • ๋‹ต๋ณ€์ด ๋‚˜ํƒ€๋‚˜๋Š” ํšŸ์ˆ˜๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์†Œํ”„ํŠธ ์ธ์ฝ”๋”ฉ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.

์œ„์˜ ์˜ˆ์‹œ์—์„œ "์•„๋ž˜"๋ผ๋Š” ๋‹ต๋ณ€์ด ๋‹ค๋ฅธ ๋‹ต๋ณ€๋ณด๋‹ค ํ›จ์”ฌ ๋” ์ž์ฃผ ์„ ํƒ๋˜์—ˆ๊ธฐ ๋•Œ๋ฌธ์— ๋ฐ์ดํ„ฐ์„ธํŠธ์—์„œ weight๋ผ๊ณ  ๋ถˆ๋ฆฌ๋Š” ์ ์ˆ˜๋กœ 1.0์„ ๊ฐ€์ง€๋ฉฐ, ๋‚˜๋จธ์ง€ ๋‹ต๋ณ€๋“ค์€ 1.0 ๋ฏธ๋งŒ์˜ ์ ์ˆ˜๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค.

์ ์ ˆํ•œ ๋ถ„๋ฅ˜ ํ—ค๋”๋กœ ๋ชจ๋ธ์„ ๋‚˜์ค‘์— ์ธ์Šคํ„ด์Šคํ™”ํ•˜๊ธฐ ์œ„ํ•ด ๋ ˆ์ด๋ธ”์„ ์ •์ˆ˜๋กœ ๋งคํ•‘ํ•œ ๋”•์…”๋„ˆ๋ฆฌ ํ•˜๋‚˜, ๋ฐ˜๋Œ€๋กœ ์ •์ˆ˜๋ฅผ ๋ ˆ์ด๋ธ”๋กœ ๋งคํ•‘ํ•œ ๋”•์…”๋„ˆ๋ฆฌ ํ•˜๋‚˜ ์ด 2๊ฐœ์˜ ๋”•์…”๋„ˆ๋ฆฌ๋ฅผ ์ƒ์„ฑํ•˜์„ธ์š”:

>>> import itertools

>>> labels = [item['ids'] for item in dataset['label']]
>>> flattened_labels = list(itertools.chain(*labels))
>>> unique_labels = list(set(flattened_labels))

>>> label2id = {label: idx for idx, label in enumerate(unique_labels)}
>>> id2label = {idx: label for label, idx in label2id.items()} 
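
As a quick optional check (not part of the original recipe, and assuming torch is installed), you can build the soft-encoded target for the first example by hand using the mappings above. The preprocessing function later in this guide does the same thing for every example:

>>> import torch

>>> # Map each annotated answer to its integer id and store its weight at that position.
>>> first_label = dataset[0]["label"]
>>> target = torch.zeros(len(id2label))
>>> for answer, weight in zip(first_label["ids"], first_label["weights"]):
...     target[label2id[answer]] = weight

>>> int((target > 0).sum())  # the four annotated answers get a non-zero score
4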

์ด์ œ ๋งคํ•‘์ด ์™„๋ฃŒ๋˜์—ˆ์œผ๋ฏ€๋กœ ๋ฌธ์ž์—ด ๋‹ต๋ณ€์„ ํ•ด๋‹น id๋กœ ๊ต์ฒดํ•˜๊ณ , ๋ฐ์ดํ„ฐ์„ธํŠธ์˜ ๋” ํŽธ๋ฆฌํ•œ ํ›„์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•ด ํŽธํ‰ํ™” ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

>>> def replace_ids(inputs):
...   inputs["label"]["ids"] = [label2id[x] for x in inputs["label"]["ids"]]
...   return inputs


>>> dataset = dataset.map(replace_ids)
>>> flat_dataset = dataset.flatten()
>>> flat_dataset.features
{'question': Value(dtype='string', id=None),
 'image_id': Value(dtype='string', id=None),
 'label.ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'label.weights': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)}

๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ [[preprocessing-data]]

๋‹ค์Œ ๋‹จ๊ณ„๋Š” ๋ชจ๋ธ์„ ์œ„ํ•ด ์ด๋ฏธ์ง€์™€ ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ์ค€๋น„ํ•˜๊ธฐ ์œ„ํ•ด ViLT ํ”„๋กœ์„ธ์„œ๋ฅผ ๊ฐ€์ ธ์˜ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. [ViltProcessor]๋Š” BERT ํ† ํฌ๋‚˜์ด์ €์™€ ViLT ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๋ฅผ ํŽธ๋ฆฌํ•˜๊ฒŒ ํ•˜๋‚˜์˜ ํ”„๋กœ์„ธ์„œ๋กœ ๋ฌถ์Šต๋‹ˆ๋‹ค:

>>> from transformers import ViltProcessor

>>> processor = ViltProcessor.from_pretrained(model_checkpoint)

๋ฐ์ดํ„ฐ๋ฅผ ์ „์ฒ˜๋ฆฌํ•˜๋ ค๋ฉด ์ด๋ฏธ์ง€์™€ ์งˆ๋ฌธ์„ [ViltProcessor]๋กœ ์ธ์ฝ”๋”ฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ํ”„๋กœ์„ธ์„œ๋Š” [BertTokenizerFast]๋กœ ํ…์ŠคํŠธ๋ฅผ ํ† ํฌ๋‚˜์ด์ฆˆํ•˜๊ณ  ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ์œ„ํ•ด input_ids, attention_mask ๋ฐ token_type_ids๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€๋Š” [ViltImageProcessor]๋กœ ์ด๋ฏธ์ง€๋ฅผ ํฌ๊ธฐ ์กฐ์ •ํ•˜๊ณ  ์ •๊ทœํ™”ํ•˜๋ฉฐ, pixel_values์™€ pixel_mask๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.

์ด๋Ÿฐ ์ „์ฒ˜๋ฆฌ ๋‹จ๊ณ„๋Š” ๋ชจ๋‘ ๋‚ด๋ถ€์—์„œ ์ด๋ฃจ์–ด์ง€๋ฏ€๋กœ, processor๋ฅผ ํ˜ธ์ถœํ•˜๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์•„์ง ํƒ€๊ฒŸ ๋ ˆ์ด๋ธ”์ด ์™„์„ฑ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ํƒ€๊ฒŸ์˜ ํ‘œํ˜„์—์„œ ๊ฐ ์š”์†Œ๋Š” ๊ฐ€๋Šฅํ•œ ๋‹ต๋ณ€(๋ ˆ์ด๋ธ”)์— ํ•ด๋‹นํ•ฉ๋‹ˆ๋‹ค. ์ •ํ™•ํ•œ ๋‹ต๋ณ€์˜ ์š”์†Œ๋Š” ํ•ด๋‹น ์ ์ˆ˜(weight)๋ฅผ ์œ ์ง€์‹œํ‚ค๊ณ  ๋‚˜๋จธ์ง€ ์š”์†Œ๋Š” 0์œผ๋กœ ์„ค์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

์•„๋ž˜ ํ•จ์ˆ˜๊ฐ€ ์œ„์—์„œ ์„ค๋ช…ํ•œ๋Œ€๋กœ ์ด๋ฏธ์ง€์™€ ์งˆ๋ฌธ์— processor๋ฅผ ์ ์šฉํ•˜๊ณ  ๋ ˆ์ด๋ธ”์„ ํ˜•์‹์— ๋งž์ถฅ๋‹ˆ๋‹ค:

>>> import torch

>>> def preprocess_data(examples):
...     image_paths = examples['image_id']
...     images = [Image.open(image_path) for image_path in image_paths]
...     texts = examples['question']    

...     encoding = processor(images, texts, padding="max_length", truncation=True, return_tensors="pt")

...     for k, v in encoding.items():
...           encoding[k] = v.squeeze()
    
...     targets = []

...     for labels, scores in zip(examples['label.ids'], examples['label.weights']):
...         target = torch.zeros(len(id2label))

...         for label, score in zip(labels, scores):
...             target[label] = score
      
...         targets.append(target)

...     encoding["labels"] = targets
    
...     return encoding

์ „์ฒด ๋ฐ์ดํ„ฐ์„ธํŠธ์— ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋ฅผ ์ ์šฉํ•˜๋ ค๋ฉด ๐Ÿค— Datasets์˜ [~datasets.map] ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์‹ญ์‹œ์˜ค. batched=True๋ฅผ ์„ค์ •ํ•˜์—ฌ ๋ฐ์ดํ„ฐ์„ธํŠธ์˜ ์—ฌ๋Ÿฌ ์š”์†Œ๋ฅผ ํ•œ ๋ฒˆ์— ์ฒ˜๋ฆฌํ•จ์œผ๋กœ์จ map์„ ๋” ๋น ๋ฅด๊ฒŒ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ์‹œ์ ์—์„œ ํ•„์š”ํ•˜์ง€ ์•Š์€ ์—ด์€ ์ œ๊ฑฐํ•˜์„ธ์š”.

>>> processed_dataset = flat_dataset.map(preprocess_data, batched=True, remove_columns=['question', 'image_id', 'label.ids', 'label.weights'])
>>> processed_dataset
Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'],
    num_rows: 200
})
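
If you want to verify that everything lines up (an optional sanity check, not from the original guide), each processed example should now carry one score per possible answer:

>>> # The soft-label vector has exactly as many entries as there are unique answers.
>>> len(processed_dataset[0]["labels"]) == len(id2label)
True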

As a final step, create a batch of examples using [DefaultDataCollator]:

>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator()
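
The collator simply stacks the preprocessed examples into tensors. As a rough sketch (the exact keys and shapes depend on your processor settings), you can peek at a batch with a plain PyTorch DataLoader before handing everything to the Trainer:

>>> from torch.utils.data import DataLoader

>>> # Build a small DataLoader only to inspect what the collator produces.
>>> dataloader = DataLoader(processed_dataset, batch_size=4, collate_fn=data_collator)
>>> batch = next(iter(dataloader))
>>> {k: v.shape for k, v in batch.items()}  # input_ids, pixel_values, labels, ...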

๋ชจ๋ธ ํ›ˆ๋ จ [[train-the-model]]

์ด์ œ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•˜๊ธฐ ์œ„ํ•ด ์ค€๋น„๋˜์—ˆ์Šต๋‹ˆ๋‹ค! [ViltForQuestionAnswering]์œผ๋กœ ViLT๋ฅผ ๊ฐ€์ ธ์˜ฌ ์ฐจ๋ก€์ž…๋‹ˆ๋‹ค. ๋ ˆ์ด๋ธ”์˜ ์ˆ˜์™€ ๋ ˆ์ด๋ธ” ๋งคํ•‘์„ ์ง€์ •ํ•˜์„ธ์š”:

>>> from transformers import ViltForQuestionAnswering

>>> model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)

์ด ์‹œ์ ์—์„œ๋Š” ๋‹ค์Œ ์„ธ ๋‹จ๊ณ„๋งŒ ๋‚จ์•˜์Šต๋‹ˆ๋‹ค:

  1. [TrainingArguments]์—์„œ ํ›ˆ๋ จ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ •์˜ํ•˜์„ธ์š”:
>>> from transformers import TrainingArguments

>>> repo_id = "MariaK/vilt_finetuned_200"

>>> training_args = TrainingArguments(
...     output_dir=repo_id,
...     per_device_train_batch_size=4,
...     num_train_epochs=20,
...     save_steps=200,
...     logging_steps=50,
...     learning_rate=5e-5,
...     save_total_limit=2,
...     remove_unused_columns=False,
...     push_to_hub=True,
... )
  1. ๋ชจ๋ธ, ๋ฐ์ดํ„ฐ์„ธํŠธ, ํ”„๋กœ์„ธ์„œ, ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ์™€ ํ•จ๊ป˜ ํ›ˆ๋ จ ์ธ์ˆ˜๋ฅผ [Trainer]์— ์ „๋‹ฌํ•˜์„ธ์š”:
>>> from transformers import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=processed_dataset,
...     tokenizer=processor,
... )
  3. Call [~Trainer.train] to fine-tune your model:
>>> trainer.train() 

ํ›ˆ๋ จ์ด ์™„๋ฃŒ๋˜๋ฉด, [~Trainer.push_to_hub] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๐Ÿค— Hub์— ๋ชจ๋ธ์„ ๊ณต์œ ํ•˜์„ธ์š”:

>>> trainer.push_to_hub()

Inference [[inference]]

Now that you have fine-tuned a ViLT model and uploaded it to the 🤗 Hub, you can use it for inference. The simplest way to try out your fine-tuned model for inference is to use it in a [Pipeline].

>>> from transformers import pipeline

>>> pipe = pipeline("visual-question-answering", model="MariaK/vilt_finetuned_200")

์ด ๊ฐ€์ด๋“œ์˜ ๋ชจ๋ธ์€ 200๊ฐœ์˜ ์˜ˆ์ œ์—์„œ๋งŒ ํ›ˆ๋ จ๋˜์—ˆ์œผ๋ฏ€๋กœ ๊ทธ๋‹ค์ง€ ๋งŽ์€ ๊ฒƒ์„ ๊ธฐ๋Œ€ํ•  ์ˆ˜๋Š” ์—†์Šต๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ์„ธํŠธ์˜ ์ฒซ ๋ฒˆ์งธ ์˜ˆ์ œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ถ”๋ก  ๊ฒฐ๊ณผ๋ฅผ ์„ค๋ช…ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

>>> example = dataset[0]
>>> image = Image.open(example['image_id'])
>>> question = example['question']
>>> print(question)
>>> pipe(image, question, top_k=1)
"Where is he looking?"
[{'score': 0.5498199462890625, 'answer': 'down'}]

๋น„๋ก ํ™•์‹ ์€ ๋ณ„๋กœ ์—†์ง€๋งŒ, ๋ชจ๋ธ์€ ์‹ค์ œ๋กœ ๋ฌด์–ธ๊ฐ€๋ฅผ ๋ฐฐ์› ์Šต๋‹ˆ๋‹ค. ๋” ๋งŽ์€ ์˜ˆ์ œ์™€ ๋” ๊ธด ํ›ˆ๋ จ ๊ธฐ๊ฐ„์ด ์ฃผ์–ด์ง„๋‹ค๋ฉด ๋ถ„๋ช… ๋” ๋‚˜์€ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค!

์›ํ•œ๋‹ค๋ฉด ํŒŒ์ดํ”„๋ผ์ธ์˜ ๊ฒฐ๊ณผ๋ฅผ ์ˆ˜๋™์œผ๋กœ ๋ณต์ œํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค:

  1. ์ด๋ฏธ์ง€์™€ ์งˆ๋ฌธ์„ ๊ฐ€์ ธ์™€์„œ ํ”„๋กœ์„ธ์„œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์— ์ค€๋น„ํ•ฉ๋‹ˆ๋‹ค.
  2. ์ „์ฒ˜๋ฆฌ๋œ ๊ฒฐ๊ณผ๋ฅผ ๋ชจ๋ธ์— ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
  3. ๋กœ์ง“์—์„œ ๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ ์žˆ๋Š” ๋‹ต๋ณ€์˜ id๋ฅผ ๊ฐ€์ ธ์™€์„œ id2label์—์„œ ์‹ค์ œ ๋‹ต๋ณ€์„ ์ฐพ์Šต๋‹ˆ๋‹ค.
>>> processor = ViltProcessor.from_pretrained("MariaK/vilt_finetuned_200")

>>> image = Image.open(example['image_id'])
>>> question = example['question']

>>> # prepare inputs
>>> inputs = processor(image, question, return_tensors="pt")

>>> model = ViltForQuestionAnswering.from_pretrained("MariaK/vilt_finetuned_200")

>>> # forward pass
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> logits = outputs.logits
>>> idx = logits.argmax(-1).item()
>>> print("Predicted answer:", model.config.id2label[idx])
Predicted answer: down
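
The pipeline's score can also be recovered from the same logits. ViLT's VQA head is trained as a multi-label classifier, so (as a sketch, reusing the logits and idx from above) applying a sigmoid yields the per-answer probability the pipeline reported:

>>> # Per-answer probabilities; the value at `idx` should match the pipeline's score.
>>> probs = torch.sigmoid(logits)
>>> round(probs[0, idx].item(), 4)
0.5498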

Zero-shot VQA [[zeroshot-vqa]]

The previous model treated VQA as a classification task. Some recent models, such as BLIP, BLIP-2, and InstructBLIP, approach VQA as a generative task instead. Let's take BLIP-2 as an example. It introduced a new visual-language pre-training paradigm in which any combination of pre-trained vision encoder and LLM can be used (learn more in the BLIP-2 blog post). This enables state-of-the-art results on multiple visual-language tasks, including visual question answering.

Let's illustrate how you can use this model for VQA. First, load the model. Here we explicitly send the model to the GPU if one is available, which we didn't need to do earlier when training, because [Trainer] handles that automatically:

>>> from transformers import AutoProcessor, Blip2ForConditionalGeneration
>>> import torch

>>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model.to(device)

๋ชจ๋ธ์€ ์ด๋ฏธ์ง€์™€ ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์œผ๋ฏ€๋กœ, VQA ๋ฐ์ดํ„ฐ์„ธํŠธ์˜ ์ฒซ ๋ฒˆ์งธ ์˜ˆ์ œ์—์„œ์™€ ๋™์ผํ•œ ์ด๋ฏธ์ง€/์งˆ๋ฌธ ์Œ์„ ์‚ฌ์šฉํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

>>> example = dataset[0]
>>> image = Image.open(example['image_id'])
>>> question = example['question']

BLIP-2๋ฅผ ์‹œ๊ฐ์  ์งˆ์˜์‘๋‹ต ์ž‘์—…์— ์‚ฌ์šฉํ•˜๋ ค๋ฉด ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๊ฐ€ Question: {} Answer: ํ˜•์‹์„ ๋”ฐ๋ผ์•ผ ํ•ฉ๋‹ˆ๋‹ค.

>>> prompt = f"Question: {question} Answer:" 

์ด์ œ ๋ชจ๋ธ์˜ ํ”„๋กœ์„ธ์„œ๋กœ ์ด๋ฏธ์ง€/ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ „์ฒ˜๋ฆฌํ•˜๊ณ , ์ฒ˜๋ฆฌ๋œ ์ž…๋ ฅ์„ ๋ชจ๋ธ์„ ํ†ตํ•ด ์ „๋‹ฌํ•˜๊ณ , ์ถœ๋ ฅ์„ ๋””์ฝ”๋“œํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

>>> inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

>>> generated_ids = model.generate(**inputs, max_new_tokens=10)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
>>> print(generated_text)
"He is looking at the crowd" 

๋ณด์‹œ๋‹ค์‹œํ”ผ ๋ชจ๋ธ์€ ๊ตฐ์ค‘์„ ์ธ์‹ํ•˜๊ณ , ์–ผ๊ตด์˜ ๋ฐฉํ–ฅ(์•„๋ž˜์ชฝ์„ ๋ณด๊ณ  ์žˆ์Œ)์„ ์ธ์‹ํ–ˆ์ง€๋งŒ, ๊ตฐ์ค‘์ด ์Šค์ผ€์ดํ„ฐ ๋’ค์— ์žˆ๋‹ค๋Š” ์‚ฌ์‹ค์„ ๋†“์ณค์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์‚ฌ๋žŒ์ด ์ง์ ‘ ๋ผ๋ฒจ๋งํ•œ ๋ฐ์ดํ„ฐ์…‹์„ ์–ป์„ ์ˆ˜ ์—†๋Š” ๊ฒฝ์šฐ์—, ์ด ์ ‘๊ทผ๋ฒ•์€ ๋น ๋ฅด๊ฒŒ ์œ ์šฉํ•œ ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.