import sys
import copy
import warnings
import logging
from typing import Dict, Any, List
import PIL.Image
import torch
from PIL import Image
from transformers import LlamaTokenizer
from ..root import (
    FUNCTIONS,
    IMAGE_PLACEHOLDER,
    BaseImageProcessFunc,
    BaseConvProcessFunc,
    BaseTextProcessFunc,
)
from ...conversation import SeparatorStyle, Conversation
IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = IMAGE_PLACEHOLDER
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)

@FUNCTIONS.register_module()
class ShikraConvProcess(BaseConvProcessFunc):
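    """Expand the image placeholder in a raw conversation into patch tokens.

    Each occurrence of IMAGE_PLACEHOLDER is replaced by `image_token_len`
    copies of DEFAULT_IMAGE_PATCH_TOKEN, optionally wrapped in
    DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN, so the text stream
    reserves one position per visual patch embedding.
    """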
    def __call__(self, raw_conv: List[Dict[str, Any]], preprocessor: Dict[str, Any], conv_template: Conversation) -> List[Dict[str, Any]]:
        conv_processor_cfg = preprocessor['conv']
        image_token_len = conv_processor_cfg['image_token_len']
        sep_image_conv_front = conv_processor_cfg.get('sep_image_conv_front', False)
        use_im_start_end = conv_processor_cfg.get('use_im_start_end', False)
        # assert DEFAULT_IMAGE_PATCH_TOKEN in preprocessor['text'].get_vocab()
        # if use_im_start_end:
        #     assert DEFAULT_IM_START_TOKEN in preprocessor['text'].get_vocab()
        #     assert DEFAULT_IM_END_TOKEN in preprocessor['text'].get_vocab()

        if sep_image_conv_front:
            # move the image placeholder to the front of the first message
            raw_conv[0]['value'] = raw_conv[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
            raw_conv[0]['value'] = DEFAULT_IMAGE_TOKEN + conv_template.sep + conv_template.roles[0] + ": " + raw_conv[0]['value']
        for sentence in raw_conv:
            replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            if use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
        return raw_conv

@FUNCTIONS.register_module()
class ShikraTextProcess(BaseTextProcessFunc):
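    """Tokenize a Conversation for training or evaluation.

    Training samples receive loss labels only on the assistant replies;
    evaluation samples keep the full reference answer as labels. An optional
    `truncation_size` kwarg shortens the sequence while keeping image tokens.
    """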
    def __call__(self, conv: Conversation, preprocessor: Dict[str, Any], mode: str, **tokenize_kwargs) -> Dict[str, Any]:
        tokenizer = preprocessor['text']
        assert isinstance(tokenizer, LlamaTokenizer), "only works for LlamaTokenizer"

        _truncation_size = tokenize_kwargs.pop('truncation_size', None)
        _kwargs = {'return_tensors': 'pt'}
        _kwargs.update(tokenize_kwargs)

        if conv.sep_style == SeparatorStyle.ADD_COLON_TWO:
            if mode in ['train']:
                ret = self.tk_conv_colon_two_train(conv, tokenizer, **_kwargs)
            else:
                ret = self.tk_conv_colon_two_eval(conv, tokenizer, **_kwargs)
        else:
            raise ValueError(f"unrecognized conv_style: {conv.sep_style}.\n the conv is {conv}")

        if _truncation_size is None:
            return ret
        if len(ret['input_ids']) <= _truncation_size:
            return ret

        # truncate to _truncation_size, taking care not to drop image tokens
        origin_len = len(ret['input_ids'])
        ids_to_remove_num = origin_len - _truncation_size
        ids_should_not_remove = list(map(
            tokenizer.convert_tokens_to_ids,
            (DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN)
        ))
        back_no_image = all(ids not in ids_should_not_remove for ids in ret['input_ids'][_truncation_size:])
        if back_no_image:
            # the tail holds no image tokens, so plain right-truncation is safe
            tgt_ids = list(range(_truncation_size))
        else:
            # drop non-image tokens from the back until the target size is reached
            ids_to_remove = set()
            for idx in range(origin_len - 1, -1, -1):
                if ret['input_ids'][idx] not in ids_should_not_remove:
                    ids_to_remove.add(idx)
                    if len(ids_to_remove) >= ids_to_remove_num:
                        break
            tgt_ids = [_ for _ in range(origin_len) if _ not in ids_to_remove]
        logger.warning(f"truncate sample size from {origin_len} to {len(tgt_ids)}.")
        assert len(tgt_ids) == _truncation_size, f"{len(tgt_ids)}, {_truncation_size}, {ret['input_ids'].tolist()}"
        truncated_ret = {k: v[tgt_ids] for k, v in ret.items()}
        return truncated_ret
    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_train(self, conv, tokenizer, **kwargs):
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]
        target = copy.deepcopy(input_ids)
        assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

        # Mask targets: only the assistant replies contribute to the loss
        sep = conv.sep + conv.roles[1] + ": "
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer(rou).input_ids)
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2  # <s> <space>
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                warnings.warn(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored):\n{conversation}")
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )
    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_eval(self, conv, tokenizer, **kwargs):
        assert len(conv.messages) >= 2
        # target = conv.messages[-1][-1]
        target = conv.get_prompt()

        # blank out the last (assistant) message so the prompt ends where generation starts
        conv.messages[-1][-1] = ""
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]
        target = tokenizer([target, ], add_special_tokens=False, **kwargs).input_ids[0]
        target[target == tokenizer.pad_token_id] = IGNORE_INDEX
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )

@FUNCTIONS.register_module()
class ShikraImageProcessor(BaseImageProcessFunc):
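    """Run the HuggingFace image processor on a single PIL image.

    A missing image is replaced by an all-zero tensor matching the
    processor's crop size; multi-image inputs are not supported.
    """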
    def __call__(self, image: Image.Image, preprocessor: Dict[str, Any]) -> Dict[str, Any]:
        image_processor = preprocessor['image']

        if isinstance(image, (list, tuple)):
            raise NotImplementedError('Shikra does not support multi-image input')
        elif isinstance(image, PIL.Image.Image):
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        else:
            # no image: pad with a zero tensor of the processor's crop size
            if hasattr(image_processor, 'crop_size'):
                crop_size = image_processor.crop_size
                height, width = crop_size['height'], crop_size['width']
            else:
                raise ValueError("got an empty image and don't know what size to pad to")
            image = torch.zeros(3, height, width)
        return {'image': image}