import copy
import json
import logging
import math
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
raw_data,
transform,
tokenizer,
slice_config,
llm_type="minicpm",
patch_size=14,
query_nums=64,
batch_vision=False,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
self.tokenizer = tokenizer
self.transform = transform
self.slice_config = slice_config
self.llm_type = llm_type
self.patch_size = patch_size
        self.query_nums = query_nums
self.batch_vision = batch_vision
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
return ret
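# Hedged usage sketch (not part of the original pipeline): build a
# SupervisedDataset from a JSON list of records shaped like
#   {"image": "/path/to/img.jpg", "conversations": [{"role": "user", ...}, ...]},
# which is the layout __getitem__ above expects. The tokenizer and transform are
# placeholders for the ones that come with the actual model; the slice_config
# keys mirror the asserts in preprocess() below and the values are only typical defaults.
def build_dataset_example(data_path, tokenizer, transform):
    with open(data_path, "r") as f:
        raw_data = json.load(f)
    slice_config = {"patch_size": 14, "max_slice_nums": 9, "scale_resolution": 448}
    return SupervisedDataset(
        raw_data,
        transform,
        tokenizer,
        slice_config,
        llm_type="minicpm",
        patch_size=14,
        query_nums=64,
        batch_vision=False,
    )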
def data_collator(examples, padding_value=0, max_length=2048):
    def trim_and_pad(seq, batch_first, padding_value):
        # truncate each sequence to max_length, then pad the batch to a common length
        return pad_sequence(
            [s[:max_length] for s in seq],
            batch_first=batch_first,
            padding_value=padding_value,
        )
input_ids = trim_and_pad(
[example["input_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
position_ids = trim_and_pad(
[example["position_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
targets = trim_and_pad(
[example["labels"] for example in examples],
batch_first=True,
padding_value=-100,
)
attention_mask = trim_and_pad(
[example["attention_mask"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
pixel_values = [example["pixel_values"] for example in examples]
image_bound = [example["image_bound"] for example in examples]
tgt_sizes = [example["tgt_sizes"] for example in examples]
return {
"input_ids": input_ids,
"position_ids": position_ids,
"labels": targets,
"attention_mask": attention_mask,
"image_bound": image_bound,
"tgt_sizes": tgt_sizes,
"pixel_values": pixel_values,
}
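# Hedged sketch: wiring data_collator into a torch.utils.data.DataLoader via
# functools.partial. The dataset argument is assumed to be a SupervisedDataset;
# padding_value should normally be the tokenizer's pad token id (0 is only the
# default used above).
def build_dataloader_example(dataset, pad_token_id=0, batch_size=2, max_length=2048):
    from functools import partial
    from torch.utils.data import DataLoader
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=partial(
            data_collator, padding_value=pad_token_id, max_length=max_length
        ),
    )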
def conversation_to_ids(conversation, tokenizer, llm_type=None):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
{'role': 'assistant', 'content': 'This is a cat.'}]
"""
if llm_type == "llama3":
input_ids, context, raw_msg = conversation_to_ids_llama3(
conversation, tokenizer
)
else:
input_ids, context, raw_msg = conversation_to_ids_minicpm(
conversation, tokenizer
)
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
for i in range(1, len(ids)):
if context[i] == 0:
target[i - 1] = ids[i]
if context[i] == 1 and context[i - 1] == 0:
if hasattr(tokenizer, "eot_id"):
target[i - 1] = tokenizer.eot_id
else:
target[i - 1] = tokenizer.eos_id
# build image bound
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
    if len(image_start_tokens) != len(image_end_tokens):
        print("the number of image start tokens != the number of image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
)
else:
image_bound = []
position_ids = torch.arange(ids.size(0)).long()
return {
"input_ids": ids,
"target": target,
"image_bound": image_bound,
"raw_msg": raw_msg,
"position_ids": position_ids
}
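# Illustration of the label shift performed above: target[i - 1] = ids[i] for
# assistant tokens means position i - 1 is trained to predict token i, the usual
# next-token objective. The ids/context values below are made up for the demo;
# the loop is a copy of the one in conversation_to_ids (minus the eos/eot handling).
def _label_shift_example():
    ids = torch.tensor([10, 11, 12, 13], dtype=torch.int32)
    context = torch.tensor([1, 1, 0, 0], dtype=torch.int8)  # last two tokens are assistant output
    target = torch.full_like(ids, -100)
    for i in range(1, len(ids)):
        if context[i] == 0:
            target[i - 1] = ids[i]
    return target  # tensor([-100, 12, 13, -100], dtype=torch.int32)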
def conversation_to_ids_minicpm(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
for idx, msg in enumerate(conversation):
role = msg["role"]
message = msg["content"]
assert role in ["user", "assistant"]
if role == "user":
prefix = "<用户>"
else:
prefix = "<AI>"
# append eos
if idx == len(conversation) - 1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)[1:] # remove bos
message_ids = tokenizer.encode(message)[1:]
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role == "assistant":
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
raw_msg += prefix + message
return input_ids, context, raw_msg
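# Format illustration: with the two-turn conversation from the docstring of
# conversation_to_ids, the raw string assembled here is
#   "<用户>Describe this image<AI>This is a cat." + tokenizer.eos_token
# and context is 0 (i.e. supervised) only over the tokens of the assistant
# message; the "<用户>"/"<AI>" prefixes themselves stay masked (context == 1).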
def conversation_to_ids_llama3(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
raw_msg = tokenizer.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=False
)
input_ids = tokenizer.apply_chat_template(
conversation, tokenize=True, add_generation_prompt=False
)
input_ids = np.array(input_ids)
start_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>")
)[0]
assistant_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)[0]
end_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>")
)[0]
eot_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0]
context = np.ones_like(input_ids, dtype=np.int8)
    for assistant_idx in assistant_idxs:
        # an "assistant" token belonging to a role header sits exactly midway
        # between its <|start_header_id|> and <|end_header_id|> tokens
        if assistant_idx in set((start_header_idxs + end_header_idxs) // 2):
            st = assistant_idx + 3  # assistant<|end_header_id|>\n\n
for eot_idx in eot_idxs:
if eot_idx > st:
context[st: eot_idx + 1] = 0
break
input_ids = np.hstack(input_ids)
context = np.hstack(context)
return input_ids, context, raw_msg
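# Supervision-mask illustration for the llama3 path: in the chat-template token
# stream "...<|start_header_id|>assistant<|end_header_id|>\n\n{reply}<|eot_id|>",
# context is set to 0 from the first token of {reply} (assistant_idx + 3, per the
# offset comment above) through the closing <|eot_id|>; everything else, including
# the role headers themselves, stays masked with context == 1.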
def preprocess(
image,
conversation,
tokenizer,
transform,
query_nums=64,
slice_config=None,
llm_type=None,
patch_size=14,
batch_vision=False,
):
"""
    single-image preprocessing; the image placeholder is placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
    assert len(conversation) > 1, "conversation must contain at least 2 turns"
assert conversation[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
assert "patch_size" in slice_config
assert "max_slice_nums" in slice_config
assert "scale_resolution" in slice_config
default_image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
)
if slice_config:
images = []
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums)
images = [transform(i) for i in images]
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
if batch_vision:
tgt_sizes = []
reshape_images = []
for image in images:
H, W = image.shape[1:]
reshape_image = reshape_by_patch(image, patch_size)
reshape_images.append(reshape_image)
tgt_sizes.append([H // patch_size, W // patch_size])
if tgt_sizes:
tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)
input_dict["pixel_values"] = reshape_images
input_dict["tgt_sizes"] = tgt_sizes
else:
input_dict["pixel_values"] = images
input_dict["tgt_sizes"] = []
return input_dict
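# Call recipe for preprocess() (values below are placeholders; tokenizer and
# transform come from the actual model). The first user turn may contain an
# "<image>" marker, which is replaced by the expanded placeholder; otherwise the
# placeholder is prepended.
# conversation = [
#     {"role": "user", "content": "<image>\nDescribe this image"},
#     {"role": "assistant", "content": "This is a cat."},
# ]
# out = preprocess(
#     Image.open("example.jpg").convert("RGB"),
#     conversation,
#     tokenizer,
#     transform,
#     query_nums=64,
#     slice_config={"patch_size": 14, "max_slice_nums": 9, "scale_resolution": 448},
#     llm_type="minicpm",
#     patch_size=14,
#     batch_vision=False,
# )
# # out contains: input_ids, target, image_bound, raw_msg, position_ids,
# # pixel_values, tgt_sizes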
def slice_image(
image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
):
original_size = image.size
original_width, original_height = original_size
log_ratio = math.log(original_width / original_height)
ratio = original_width * original_height / \
(scale_resolution * scale_resolution)
multiple = min(math.ceil(ratio), max_slice_nums)
source_image = None
best_grid = None
patches = []
if multiple <= 1 or never_split:
        # no need to slice; just resize (upsampling allowed)
best_size = find_best_resize(
original_size, scale_resolution, patch_size, allow_upscale=True
)
source_image = image.resize(best_size, Image.Resampling.BICUBIC)
else:
candidate_split_grids_nums = []
for i in [multiple - 1, multiple, multiple + 1]:
if i == 1 or i > max_slice_nums:
continue
candidate_split_grids_nums.append(i)
        # down-sample the source image and make its sides divisible by patch_size
best_resize = find_best_resize(
original_size, scale_resolution, patch_size)
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
candidate_grids = []
# find best grid
for split_grids_nums in candidate_split_grids_nums:
m = 1
while m <= split_grids_nums:
if split_grids_nums % m == 0:
candidate_grids.append([m, split_grids_nums // m])
m += 1
best_grid = [1, 1]
min_error = float("inf")
for grid in candidate_grids:
error = abs(log_ratio - math.log(grid[0] / grid[1]))
if error < min_error:
best_grid = grid
min_error = error
refine_size = get_refine_size(
original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
)
refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
patches = split_to_patches(refine_image, best_grid)
return source_image, patches, best_grid
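# Worked example of the slicing logic (pure arithmetic, no model needed): a
# 1344x896 image has area 1344 * 896 / (448 * 448) = 6, so multiple = 6 and the
# candidate grid counts are {5, 6, 7}; the factorization whose aspect ratio is
# closest to 1344 / 896 = 1.5 is [3, 2] (3 columns x 2 rows), giving a 2 x 3
# nested list of 448x448 patches plus the down-sampled source image.
def _slice_example():
    img = Image.new("RGB", (1344, 896))
    source_image, patches, best_grid = slice_image(
        img, max_slice_nums=9, scale_resolution=448, patch_size=14
    )
    return best_grid, len(patches), len(patches[0]) if patches else 0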
def ensure_divide(length, patch_size):
return max(round(length / patch_size) * patch_size, patch_size)
def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
if (width * height > scale_resolution * scale_resolution) or allow_upscale:
r = width / height
height = int(scale_resolution / math.sqrt(r))
width = int(height * r)
best_width = ensure_divide(width, patch_size)
best_height = ensure_divide(height, patch_size)
return (best_width, best_height)
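# Worked example for find_best_resize with the defaults used elsewhere in this
# file (scale_resolution=448, patch_size=14): a 1000x750 image exceeds 448*448
# pixels, so it is shrunk to roughly that area while keeping its 4:3 aspect
# ratio, and both sides are snapped to multiples of 14; this comes out to
# (518, 392).
def _resize_example():
    w, h = find_best_resize((1000, 750), scale_resolution=448, patch_size=14)
    assert w % 14 == 0 and h % 14 == 0
    return w, h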
def get_refine_size(
original_size, grid, scale_resolution, patch_size, allow_upscale=False
):
width, height = original_size
grid_x, grid_y = grid
refine_width = ensure_divide(width, grid_x)
refine_height = ensure_divide(height, grid_y)
grid_width = refine_width / grid_x
grid_height = refine_height / grid_y
best_grid_size = find_best_resize(
(grid_width, grid_height),
scale_resolution,
patch_size,
allow_upscale=allow_upscale,
)
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
return refine_size
def split_to_patches(image, grid):
patches = []
width, height = image.size
grid_x = int(width / grid[0])
grid_y = int(height / grid[1])
for i in range(0, height, grid_y):
images = []
for j in range(0, width, grid_x):
box = (j, i, j + grid_x, i + grid_y)
patch = image.crop(box)
images.append(patch)
patches.append(images)
return patches
def get_grid_placeholder(tokenizer, grid, query_num):
image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
)
cols = grid[0]
rows = grid[1]
slices = []
for i in range(rows):
lines = []
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + \
"\n".join(slices) + tokenizer.slice_end
return slice_placeholder
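# Illustration of the grid placeholder (the attributes im_start, im_end,
# slice_start, slice_end and unk_token are whatever the tokenizer defines; no
# particular token strings are assumed). For a [3, 2] grid the result is two
# newline-separated rows of three image slots each, wrapped by the
# slice_start/slice_end tokens.
def _grid_placeholder_example(tokenizer):
    placeholder = get_grid_placeholder(tokenizer, [3, 2], query_num=64)
    assert placeholder.count(tokenizer.im_start) == 6  # 3 columns x 2 rows
    assert placeholder.count("\n") == 1  # rows are joined with a single newline
    return placeholder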
def reshape_by_patch(image_tensor, patch_size):
"""
:param image_tensor: shape [3, H, W]
    :param patch_size: side length of the square patches
:return: [3, patch_size, HW/patch_size]
"""
patches = torch.nn.functional.unfold(
image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)
)
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(
image_tensor.size(0), patch_size, -1)
return patches
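# Shape check for reshape_by_patch: a [3, H, W] tensor becomes
# [3, patch_size, H * W / patch_size], i.e. every patch_size x patch_size patch
# is laid out side by side along the last dimension. With H = W = 224 and
# patch_size = 14 the result is [3, 14, 3584].
def _reshape_example():
    dummy = torch.randn(3, 224, 224)
    out = reshape_by_patch(dummy, patch_size=14)
    assert out.shape == (3, 14, 224 * 224 // 14)
    return out.shape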