Spaces:
Runtime error
์๊ฐ์ ์ง์์๋ต (Visual Question Answering)
[[open-in-colab]]
์๊ฐ์ ์ง์์๋ต(VQA)์ ์ด๋ฏธ์ง๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ๋ฐฉํ ์ง๋ฌธ์ ๋์ํ๋ ์์ ์ ๋๋ค. ์ด ์์ ์ ์ง์ํ๋ ๋ชจ๋ธ์ ์ ๋ ฅ์ ๋๋ถ๋ถ ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ์กฐํฉ์ด๋ฉฐ, ์ถ๋ ฅ์ ์์ฐ์ด๋ก ๋ ๋ต๋ณ์ ๋๋ค.
VQA์ ์ฃผ์ ์ฌ์ฉ ์ฌ๋ก๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์๊ฐ ์ฅ์ ์ธ์ ์ํ ์ ๊ทผ์ฑ ์ ํ๋ฆฌ์ผ์ด์ ์ ๊ตฌ์ถํ ์ ์์ต๋๋ค.
- ๊ต์ก: ๊ฐ์๋ ๊ต๊ณผ์์ ๋์จ ์๊ฐ ์๋ฃ์ ๋ํ ์ง๋ฌธ์ ๋ตํ ์ ์์ต๋๋ค. ๋ํ ์ฒดํํ ์ ์์ ์ ์ ๋ฑ์์๋ VQA๋ฅผ ํ์ฉํ ์ ์์ต๋๋ค.
- ๊ณ ๊ฐ ์๋น์ค ๋ฐ ์ ์์๊ฑฐ๋: VQA๋ ์ฌ์ฉ์๊ฐ ์ ํ์ ๋ํด ์ง๋ฌธํ ์ ์๊ฒ ํจ์ผ๋ก์จ ์ฌ์ฉ์ ๊ฒฝํ์ ํฅ์์ํฌ ์ ์์ต๋๋ค.
- ์ด๋ฏธ์ง ๊ฒ์: VQA ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ํ๋ ํน์ฑ์ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ๊ฒ์ํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ์ฌ์ฉ์๋ "๊ฐ์์ง๊ฐ ์์ด?"๋ผ๊ณ ๋ฌผ์ด๋ด์ ์ฃผ์ด์ง ์ด๋ฏธ์ง ๋ฌถ์์์ ๊ฐ์์ง๊ฐ ์๋ ๋ชจ๋ ์ด๋ฏธ์ง๋ฅผ ๋ฐ์๋ณผ ์ ์์ต๋๋ค.
์ด ๊ฐ์ด๋์์ ํ์ตํ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- VQA ๋ชจ๋ธ ์ค ํ๋์ธ ViLT๋ฅผ
Graphcore/vqa
๋ฐ์ดํฐ์ ์์ ๋ฏธ์ธ์กฐ์ ํ๋ ๋ฐฉ๋ฒ - ๋ฏธ์ธ์กฐ์ ๋ ViLT ๋ชจ๋ธ๋ก ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ
- BLIP-2 ๊ฐ์ ์์ฑ ๋ชจ๋ธ๋ก ์ ๋ก์ท VQA ์ถ๋ก ์ ์คํํ๋ ๋ฐฉ๋ฒ
ViLT ๋ฏธ์ธ ์กฐ์ [[finetuning-vilt]]
ViLT๋ Vision Transformer (ViT) ๋ด์ ํ
์คํธ ์๋ฒ ๋ฉ์ ํฌํจํ์ฌ ๋น์ /์์ฐ์ด ์ฌ์ ํ๋ จ(VLP; Vision-and-Language Pretraining)์ ์ํ ๊ธฐ๋ณธ ๋์์ธ์ ์ ๊ณตํฉ๋๋ค.
ViLT ๋ชจ๋ธ์ ๋น์ ํธ๋์คํฌ๋จธ(ViT)์ ํ
์คํธ ์๋ฒ ๋ฉ์ ๋ฃ์ด ๋น์ /์ธ์ด ์ฌ์ ํ๋ จ(VLP; Vision-and-Language Pre-training)์ ์ํ ๊ธฐ๋ณธ์ ์ธ ๋์์ธ์ ๊ฐ์ท์ต๋๋ค. ์ด ๋ชจ๋ธ์ ์ฌ๋ฌ ๋ค์ด์คํธ๋ฆผ ์์
์ ์ฌ์ฉํ ์ ์์ต๋๋ค. VQA ํ์คํฌ์์๋ ([CLS]
ํ ํฐ์ ์ต์ข
์๋ ์ํ ์์ ์ ํ ๋ ์ด์ด์ธ) ๋ถ๋ฅ ํค๋๊ฐ ์์ผ๋ฉฐ ๋ฌด์์๋ก ์ด๊ธฐํ๋ฉ๋๋ค.
๋ฐ๋ผ์ ์ฌ๊ธฐ์์ ์๊ฐ์ ์ง์์๋ต์ ๋ถ๋ฅ ๋ฌธ์ ๋ก ์ทจ๊ธ๋ฉ๋๋ค.
์ต๊ทผ์ BLIP, BLIP-2, InstructBLIP์ ๊ฐ์ ๋ชจ๋ธ๋ค์ VQA๋ฅผ ์์ฑํ ์์ ์ผ๋ก ๊ฐ์ฃผํฉ๋๋ค. ๊ฐ์ด๋์ ํ๋ฐ๋ถ์์๋ ์ด๋ฐ ๋ชจ๋ธ๋ค์ ์ฌ์ฉํ์ฌ ์ ๋ก์ท VQA ์ถ๋ก ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช ํ๊ฒ ์ต๋๋ค.
์์ํ๊ธฐ ์ ํ์ํ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํ๋์ง ํ์ธํ์ธ์.
pip install -q transformers datasets
์ปค๋ฎค๋ํฐ์ ๋ชจ๋ธ์ ๊ณต์ ํ๋ ๊ฒ์ ๊ถ์ฅ ๋๋ฆฝ๋๋ค. Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ์ฌ ๐ค Hub์ ์ ๋ก๋ํ ์ ์์ต๋๋ค. ๋ฉ์์ง๊ฐ ๋ํ๋๋ฉด ๋ก๊ทธ์ธํ ํ ํฐ์ ์ ๋ ฅํ์ธ์:
>>> from huggingface_hub import notebook_login
>>> notebook_login()
๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ญ ๋ณ์๋ก ์ ์ธํ์ธ์.
>>> model_checkpoint = "dandelin/vilt-b32-mlm"
๋ฐ์ดํฐ ๊ฐ์ ธ์ค๊ธฐ [[load-the-data]]
์ด ๊ฐ์ด๋์์๋ Graphcore/vqa
๋ฐ์ดํฐ์ธํธ์ ์์ ์ํ์ ์ฌ์ฉํฉ๋๋ค. ์ ์ฒด ๋ฐ์ดํฐ์ธํธ๋ ๐ค Hub ์์ ํ์ธํ ์ ์์ต๋๋ค.
Graphcore/vqa
๋ฐ์ดํฐ์ธํธ ์ ๋์์ผ๋ก ๊ณต์ VQA ๋ฐ์ดํฐ์ธํธ ํ์ด์ง ์์ ๋์ผํ ๋ฐ์ดํฐ๋ฅผ ์๋์ผ๋ก ๋ค์ด๋ก๋ํ ์ ์์ต๋๋ค. ์ง์ ๊ณต์ํ ๋ฐ์ดํฐ๋ก ํํ ๋ฆฌ์ผ์ ๋ฐ๋ฅด๊ณ ์ถ๋ค๋ฉด ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ธํธ ๋ง๋ค๊ธฐ ๋ผ๋
๐ค Datasets ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์.
๊ฒ์ฆ ๋ฐ์ดํฐ์ ์ฒซ 200๊ฐ ํญ๋ชฉ์ ๋ถ๋ฌ์ ๋ฐ์ดํฐ์ธํธ์ ํน์ฑ์ ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค:
>>> from datasets import load_dataset
>>> dataset = load_dataset("Graphcore/vqa", split="validation[:200]")
>>> dataset
Dataset({
features: ['question', 'question_type', 'question_id', 'image_id', 'answer_type', 'label'],
num_rows: 200
})
์์ ๋ฅผ ํ๋ ๋ฝ์ ๋ฐ์ดํฐ์ธํธ์ ํน์ฑ์ ์ดํดํด ๋ณด๊ฒ ์ต๋๋ค.
>>> dataset[0]
{'question': 'Where is he looking?',
'question_type': 'none of the above',
'question_id': 262148000,
'image_id': '/root/.cache/huggingface/datasets/downloads/extracted/ca733e0e000fb2d7a09fbcc94dbfe7b5a30750681d0e965f8e0a23b1c2f98c75/val2014/COCO_val2014_000000262148.jpg',
'answer_type': 'other',
'label': {'ids': ['at table', 'down', 'skateboard', 'table'],
'weights': [0.30000001192092896,
1.0,
0.30000001192092896,
0.30000001192092896]}}
๋ฐ์ดํฐ์ธํธ์๋ ๋ค์๊ณผ ๊ฐ์ ํน์ฑ์ด ํฌํจ๋์ด ์์ต๋๋ค:
question
: ์ด๋ฏธ์ง์ ๋ํ ์ง๋ฌธimage_id
: ์ง๋ฌธ๊ณผ ๊ด๋ จ๋ ์ด๋ฏธ์ง์ ๊ฒฝ๋กlabel
: ๋ฐ์ดํฐ์ ๋ ์ด๋ธ (annotations)
๋๋จธ์ง ํน์ฑ๋ค์ ํ์ํ์ง ์๊ธฐ ๋๋ฌธ์ ์ญ์ ํด๋ ๋ฉ๋๋ค:
>>> dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type'])
๋ณด์๋ค์ํผ label
ํน์ฑ์ ๊ฐ์ ์ง๋ฌธ๋ง๋ค ๋ต๋ณ์ด ์ฌ๋ฌ ๊ฐ ์์ ์ ์์ต๋๋ค. ๋ชจ๋ ๋ค๋ฅธ ๋ฐ์ดํฐ ๋ผ๋ฒจ๋ฌ๋ค๋ก๋ถํฐ ์์ง๋์๊ธฐ ๋๋ฌธ์ธ๋ฐ์. ์ง๋ฌธ์ ๋ต๋ณ์ ์ฃผ๊ด์ ์ผ ์ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ ์ง๋ฌธ์ "๊ทธ๋ ์ด๋๋ฅผ ๋ณด๊ณ ์๋์?" ์์ง๋ง, ์ด๋ค ์ฌ๋๋ค์ "์๋"๋ก ๋ ์ด๋ธ์ ๋ฌ์๊ณ , ๋ค๋ฅธ ์ฌ๋๋ค์ "ํ
์ด๋ธ" ๋๋ "์ค์ผ์ดํธ๋ณด๋" ๋ฑ์ผ๋ก ์ฃผ์์ ๋ฌ์์ต๋๋ค.
์๋์ ์ด๋ฏธ์ง๋ฅผ ๋ณด๊ณ ์ด๋ค ๋ต๋ณ์ ์ ํํ ๊ฒ์ธ์ง ์๊ฐํด ๋ณด์ธ์:
>>> from PIL import Image
>>> image = Image.open(dataset[0]['image_id'])
>>> image

์ง๋ฌธ๊ณผ ๋ต๋ณ์ ๋ชจํธ์ฑ์ผ๋ก ์ธํด ์ด๋ฌํ ๋ฐ์ดํฐ์ธํธ๋ ์ฌ๋ฌ ๊ฐ์ ๋ต๋ณ์ด ๊ฐ๋ฅํ๋ฏ๋ก ๋ค์ค ๋ ์ด๋ธ ๋ถ๋ฅ ๋ฌธ์ ๋ก ์ฒ๋ฆฌ๋ฉ๋๋ค. ๊ฒ๋ค๊ฐ, ์ํซ(one-hot) ์ธ์ฝ๋ฉ ๋ฒกํฐ๋ฅผ ์์ฑํ๊ธฐ๋ณด๋ค๋ ๋ ์ด๋ธ์์ ํน์ ๋ต๋ณ์ด ๋ํ๋๋ ํ์๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ํํธ ์ธ์ฝ๋ฉ์ ์์ฑํฉ๋๋ค.
์์ ์์์์ "์๋"๋ผ๋ ๋ต๋ณ์ด ๋ค๋ฅธ ๋ต๋ณ๋ณด๋ค ํจ์ฌ ๋ ์์ฃผ ์ ํ๋์๊ธฐ ๋๋ฌธ์ ๋ฐ์ดํฐ์ธํธ์์ weight
๋ผ๊ณ ๋ถ๋ฆฌ๋ ์ ์๋ก 1.0์ ๊ฐ์ง๋ฉฐ, ๋๋จธ์ง ๋ต๋ณ๋ค์ 1.0 ๋ฏธ๋ง์ ์ ์๋ฅผ ๊ฐ์ง๋๋ค.
์ ์ ํ ๋ถ๋ฅ ํค๋๋ก ๋ชจ๋ธ์ ๋์ค์ ์ธ์คํด์คํํ๊ธฐ ์ํด ๋ ์ด๋ธ์ ์ ์๋ก ๋งคํํ ๋์ ๋๋ฆฌ ํ๋, ๋ฐ๋๋ก ์ ์๋ฅผ ๋ ์ด๋ธ๋ก ๋งคํํ ๋์ ๋๋ฆฌ ํ๋ ์ด 2๊ฐ์ ๋์ ๋๋ฆฌ๋ฅผ ์์ฑํ์ธ์:
>>> import itertools
>>> labels = [item['ids'] for item in dataset['label']]
>>> flattened_labels = list(itertools.chain(*labels))
>>> unique_labels = list(set(flattened_labels))
>>> label2id = {label: idx for idx, label in enumerate(unique_labels)}
>>> id2label = {idx: label for label, idx in label2id.items()}
์ด์ ๋งคํ์ด ์๋ฃ๋์์ผ๋ฏ๋ก ๋ฌธ์์ด ๋ต๋ณ์ ํด๋น id๋ก ๊ต์ฒดํ๊ณ , ๋ฐ์ดํฐ์ธํธ์ ๋ ํธ๋ฆฌํ ํ์ฒ๋ฆฌ๋ฅผ ์ํด ํธํํ ํ ์ ์์ต๋๋ค.
>>> def replace_ids(inputs):
... inputs["label"]["ids"] = [label2id[x] for x in inputs["label"]["ids"]]
... return inputs
>>> dataset = dataset.map(replace_ids)
>>> flat_dataset = dataset.flatten()
>>> flat_dataset.features
{'question': Value(dtype='string', id=None),
'image_id': Value(dtype='string', id=None),
'label.ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
'label.weights': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)}
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ [[preprocessing-data]]
๋ค์ ๋จ๊ณ๋ ๋ชจ๋ธ์ ์ํด ์ด๋ฏธ์ง์ ํ
์คํธ ๋ฐ์ดํฐ๋ฅผ ์ค๋นํ๊ธฐ ์ํด ViLT ํ๋ก์ธ์๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ์
๋๋ค.
[ViltProcessor
]๋ BERT ํ ํฌ๋์ด์ ์ ViLT ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ํธ๋ฆฌํ๊ฒ ํ๋์ ํ๋ก์ธ์๋ก ๋ฌถ์ต๋๋ค:
>>> from transformers import ViltProcessor
>>> processor = ViltProcessor.from_pretrained(model_checkpoint)
๋ฐ์ดํฐ๋ฅผ ์ ์ฒ๋ฆฌํ๋ ค๋ฉด ์ด๋ฏธ์ง์ ์ง๋ฌธ์ [ViltProcessor
]๋ก ์ธ์ฝ๋ฉํด์ผ ํฉ๋๋ค. ํ๋ก์ธ์๋ [BertTokenizerFast
]๋ก ํ
์คํธ๋ฅผ ํ ํฌ๋์ด์ฆํ๊ณ ํ
์คํธ ๋ฐ์ดํฐ๋ฅผ ์ํด input_ids
, attention_mask
๋ฐ token_type_ids
๋ฅผ ์์ฑํฉ๋๋ค.
์ด๋ฏธ์ง๋ [ViltImageProcessor
]๋ก ์ด๋ฏธ์ง๋ฅผ ํฌ๊ธฐ ์กฐ์ ํ๊ณ ์ ๊ทํํ๋ฉฐ, pixel_values
์ pixel_mask
๋ฅผ ์์ฑํฉ๋๋ค.
์ด๋ฐ ์ ์ฒ๋ฆฌ ๋จ๊ณ๋ ๋ชจ๋ ๋ด๋ถ์์ ์ด๋ฃจ์ด์ง๋ฏ๋ก, processor
๋ฅผ ํธ์ถํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค. ํ์ง๋ง ์์ง ํ๊ฒ ๋ ์ด๋ธ์ด ์์ฑ๋์ง ์์์ต๋๋ค. ํ๊ฒ์ ํํ์์ ๊ฐ ์์๋ ๊ฐ๋ฅํ ๋ต๋ณ(๋ ์ด๋ธ)์ ํด๋นํฉ๋๋ค. ์ ํํ ๋ต๋ณ์ ์์๋ ํด๋น ์ ์(weight)๋ฅผ ์ ์ง์ํค๊ณ ๋๋จธ์ง ์์๋ 0์ผ๋ก ์ค์ ํด์ผ ํฉ๋๋ค.
์๋ ํจ์๊ฐ ์์์ ์ค๋ช
ํ๋๋ก ์ด๋ฏธ์ง์ ์ง๋ฌธ์ processor
๋ฅผ ์ ์ฉํ๊ณ ๋ ์ด๋ธ์ ํ์์ ๋ง์ถฅ๋๋ค:
>>> import torch
>>> def preprocess_data(examples):
... image_paths = examples['image_id']
... images = [Image.open(image_path) for image_path in image_paths]
... texts = examples['question']
... encoding = processor(images, texts, padding="max_length", truncation=True, return_tensors="pt")
... for k, v in encoding.items():
... encoding[k] = v.squeeze()
... targets = []
... for labels, scores in zip(examples['label.ids'], examples['label.weights']):
... target = torch.zeros(len(id2label))
... for label, score in zip(labels, scores):
... target[label] = score
... targets.append(target)
... encoding["labels"] = targets
... return encoding
์ ์ฒด ๋ฐ์ดํฐ์ธํธ์ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ์ ์ฉํ๋ ค๋ฉด ๐ค Datasets์ [~datasets.map
] ํจ์๋ฅผ ์ฌ์ฉํ์ญ์์ค. batched=True
๋ฅผ ์ค์ ํ์ฌ ๋ฐ์ดํฐ์ธํธ์ ์ฌ๋ฌ ์์๋ฅผ ํ ๋ฒ์ ์ฒ๋ฆฌํจ์ผ๋ก์จ map
์ ๋ ๋น ๋ฅด๊ฒ ํ ์ ์์ต๋๋ค. ์ด ์์ ์์ ํ์ํ์ง ์์ ์ด์ ์ ๊ฑฐํ์ธ์.
>>> processed_dataset = flat_dataset.map(preprocess_data, batched=True, remove_columns=['question','question_type', 'question_id', 'image_id', 'answer_type', 'label.ids', 'label.weights'])
>>> processed_dataset
Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'],
num_rows: 200
})
๋ง์ง๋ง ๋จ๊ณ๋ก, [DefaultDataCollator
]๋ฅผ ์ฌ์ฉํ์ฌ ์์ ๋ก ์ธ ๋ฐฐ์น๋ฅผ ์์ฑํ์ธ์:
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
๋ชจ๋ธ ํ๋ จ [[train-the-model]]
์ด์ ๋ชจ๋ธ์ ํ๋ จํ๊ธฐ ์ํด ์ค๋น๋์์ต๋๋ค! [ViltForQuestionAnswering
]์ผ๋ก ViLT๋ฅผ ๊ฐ์ ธ์ฌ ์ฐจ๋ก์
๋๋ค. ๋ ์ด๋ธ์ ์์ ๋ ์ด๋ธ ๋งคํ์ ์ง์ ํ์ธ์:
>>> from transformers import ViltForQuestionAnswering
>>> model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)
์ด ์์ ์์๋ ๋ค์ ์ธ ๋จ๊ณ๋ง ๋จ์์ต๋๋ค:
- [
TrainingArguments
]์์ ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ํ์ธ์:
>>> from transformers import TrainingArguments
>>> repo_id = "MariaK/vilt_finetuned_200"
>>> training_args = TrainingArguments(
... output_dir=repo_id,
... per_device_train_batch_size=4,
... num_train_epochs=20,
... save_steps=200,
... logging_steps=50,
... learning_rate=5e-5,
... save_total_limit=2,
... remove_unused_columns=False,
... push_to_hub=True,
... )
- ๋ชจ๋ธ, ๋ฐ์ดํฐ์ธํธ, ํ๋ก์ธ์, ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ์ ํจ๊ป ํ๋ จ ์ธ์๋ฅผ [
Trainer
]์ ์ ๋ฌํ์ธ์:
>>> from transformers import Trainer
>>> trainer = Trainer(
... model=model,
... args=training_args,
... data_collator=data_collator,
... train_dataset=processed_dataset,
... tokenizer=processor,
... )
- [
~Trainer.train
]์ ํธ์ถํ์ฌ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ์ธ์:
>>> trainer.train()
ํ๋ จ์ด ์๋ฃ๋๋ฉด, [~Trainer.push_to_hub
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๐ค Hub์ ๋ชจ๋ธ์ ๊ณต์ ํ์ธ์:
>>> trainer.push_to_hub()
์ถ๋ก [[inference]]
ViLT ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ณ ๐ค Hub์ ์
๋ก๋ํ๋ค๋ฉด ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ถ๋ก ์ ์ฌ์ฉํด๋ณด๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ [Pipeline
]์์ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค.
>>> from transformers import pipeline
>>> pipe = pipeline("visual-question-answering", model="MariaK/vilt_finetuned_200")
์ด ๊ฐ์ด๋์ ๋ชจ๋ธ์ 200๊ฐ์ ์์ ์์๋ง ํ๋ จ๋์์ผ๋ฏ๋ก ๊ทธ๋ค์ง ๋ง์ ๊ฒ์ ๊ธฐ๋ํ ์๋ ์์ต๋๋ค. ๋ฐ์ดํฐ์ธํธ์ ์ฒซ ๋ฒ์งธ ์์ ๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ค๋ช ํด๋ณด๊ฒ ์ต๋๋ค:
>>> example = dataset[0]
>>> image = Image.open(example['image_id'])
>>> question = example['question']
>>> print(question)
>>> pipe(image, question, top_k=1)
"Where is he looking?"
[{'score': 0.5498199462890625, 'answer': 'down'}]
๋น๋ก ํ์ ์ ๋ณ๋ก ์์ง๋ง, ๋ชจ๋ธ์ ์ค์ ๋ก ๋ฌด์ธ๊ฐ๋ฅผ ๋ฐฐ์ ์ต๋๋ค. ๋ ๋ง์ ์์ ์ ๋ ๊ธด ํ๋ จ ๊ธฐ๊ฐ์ด ์ฃผ์ด์ง๋ค๋ฉด ๋ถ๋ช ๋ ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์์ ๊ฒ์ ๋๋ค!
์ํ๋ค๋ฉด ํ์ดํ๋ผ์ธ์ ๊ฒฐ๊ณผ๋ฅผ ์๋์ผ๋ก ๋ณต์ ํ ์๋ ์์ต๋๋ค:
- ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ๊ฐ์ ธ์์ ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ค๋นํฉ๋๋ค.
- ์ ์ฒ๋ฆฌ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ชจ๋ธ์ ์ ๋ฌํฉ๋๋ค.
- ๋ก์ง์์ ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ ์๋ ๋ต๋ณ์ id๋ฅผ ๊ฐ์ ธ์์
id2label
์์ ์ค์ ๋ต๋ณ์ ์ฐพ์ต๋๋ค.
>>> processor = ViltProcessor.from_pretrained("MariaK/vilt_finetuned_200")
>>> image = Image.open(example['image_id'])
>>> question = example['question']
>>> # prepare inputs
>>> inputs = processor(image, question, return_tensors="pt")
>>> model = ViltForQuestionAnswering.from_pretrained("MariaK/vilt_finetuned_200")
>>> # forward pass
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> logits = outputs.logits
>>> idx = logits.argmax(-1).item()
>>> print("Predicted answer:", model.config.id2label[idx])
Predicted answer: down
์ ๋ก์ท VQA [[zeroshot-vqa]]
์ด์ ๋ชจ๋ธ์ VQA๋ฅผ ๋ถ๋ฅ ๋ฌธ์ ๋ก ์ฒ๋ฆฌํ์ต๋๋ค. BLIP, BLIP-2 ๋ฐ InstructBLIP์ ๊ฐ์ ์ต๊ทผ์ ๋ชจ๋ธ์ VQA๋ฅผ ์์ฑ ์์ ์ผ๋ก ์ ๊ทผํฉ๋๋ค. BLIP-2๋ฅผ ์๋ก ๋ค์ด ๋ณด๊ฒ ์ต๋๋ค. ์ด ๋ชจ๋ธ์ ์ฌ์ ํ๋ จ๋ ๋น์ ์ธ์ฝ๋์ LLM์ ๋ชจ๋ ์กฐํฉ์ ์ฌ์ฉํ ์ ์๋ ์๋ก์ด ๋น์ -์์ฐ์ด ์ฌ์ ํ์ต ํจ๋ฌ๋ค์์ ๋์ ํ์ต๋๋ค. (BLIP-2 ๋ธ๋ก๊ทธ ํฌ์คํธ๋ฅผ ํตํด ๋ ์์ธํ ์์๋ณผ ์ ์์ด์) ์ด๋ฅผ ํตํด ์๊ฐ์ ์ง์์๋ต์ ํฌํจํ ์ฌ๋ฌ ๋น์ -์์ฐ์ด ์์ ์์ SOTA๋ฅผ ๋ฌ์ฑํ ์ ์์์ต๋๋ค.
์ด ๋ชจ๋ธ์ ์ด๋ป๊ฒ VQA์ ์ฌ์ฉํ ์ ์๋์ง ์ค๋ช
ํด ๋ณด๊ฒ ์ต๋๋ค. ๋จผ์ ๋ชจ๋ธ์ ๊ฐ์ ธ์ ๋ณด๊ฒ ์ต๋๋ค. ์ฌ๊ธฐ์ GPU๊ฐ ์ฌ์ฉ ๊ฐ๋ฅํ ๊ฒฝ์ฐ ๋ชจ๋ธ์ ๋ช
์์ ์ผ๋ก GPU๋ก ์ ์กํ ๊ฒ์
๋๋ค. ์ด์ ์๋ ํ๋ จํ ๋ ์ฐ์ง ์์ ์ด์ ๋ [Trainer
]๊ฐ ์ด ๋ถ๋ถ์ ์๋์ผ๋ก ์ฒ๋ฆฌํ๊ธฐ ๋๋ฌธ์
๋๋ค:
>>> from transformers import AutoProcessor, Blip2ForConditionalGeneration
>>> import torch
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model.to(device)
๋ชจ๋ธ์ ์ด๋ฏธ์ง์ ํ ์คํธ๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์ผ๋ฏ๋ก, VQA ๋ฐ์ดํฐ์ธํธ์ ์ฒซ ๋ฒ์งธ ์์ ์์์ ๋์ผํ ์ด๋ฏธ์ง/์ง๋ฌธ ์์ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค:
>>> example = dataset[0]
>>> image = Image.open(example['image_id'])
>>> question = example['question']
BLIP-2๋ฅผ ์๊ฐ์ ์ง์์๋ต ์์
์ ์ฌ์ฉํ๋ ค๋ฉด ํ
์คํธ ํ๋กฌํํธ๊ฐ Question: {} Answer:
ํ์์ ๋ฐ๋ผ์ผ ํฉ๋๋ค.
>>> prompt = f"Question: {question} Answer:"
์ด์ ๋ชจ๋ธ์ ํ๋ก์ธ์๋ก ์ด๋ฏธ์ง/ํ๋กฌํํธ๋ฅผ ์ ์ฒ๋ฆฌํ๊ณ , ์ฒ๋ฆฌ๋ ์ ๋ ฅ์ ๋ชจ๋ธ์ ํตํด ์ ๋ฌํ๊ณ , ์ถ๋ ฅ์ ๋์ฝ๋ํด์ผ ํฉ๋๋ค:
>>> inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
>>> generated_ids = model.generate(**inputs, max_new_tokens=10)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
>>> print(generated_text)
"He is looking at the crowd"
๋ณด์๋ค์ํผ ๋ชจ๋ธ์ ๊ตฐ์ค์ ์ธ์ํ๊ณ , ์ผ๊ตด์ ๋ฐฉํฅ(์๋์ชฝ์ ๋ณด๊ณ ์์)์ ์ธ์ํ์ง๋ง, ๊ตฐ์ค์ด ์ค์ผ์ดํฐ ๋ค์ ์๋ค๋ ์ฌ์ค์ ๋์ณค์ต๋๋ค. ๊ทธ๋ฌ๋ ์ฌ๋์ด ์ง์ ๋ผ๋ฒจ๋งํ ๋ฐ์ดํฐ์ ์ ์ป์ ์ ์๋ ๊ฒฝ์ฐ์, ์ด ์ ๊ทผ๋ฒ์ ๋น ๋ฅด๊ฒ ์ ์ฉํ ๊ฒฐ๊ณผ๋ฅผ ์์ฑํ ์ ์์ต๋๋ค.