Spaces:
Runtime error
Runtime error
| import base64 | |
| import torch | |
| from PIL import Image | |
| from io import BytesIO | |
| from transformers import StoppingCriteria | |
| from serve.constants import IMAGE_TOKEN_INDEX | |
| def load_image_from_base64(image): | |
| return Image.open(BytesIO(base64.b64decode(image))) | |
| def expand2square(pil_img, background_color): | |
| width, height = pil_img.size | |
| if width == height: | |
| return pil_img | |
| elif width > height: | |
| result = Image.new(pil_img.mode, (width, width), background_color) | |
| result.paste(pil_img, (0, (width - height) // 2)) | |
| return result | |
| else: | |
| result = Image.new(pil_img.mode, (height, height), background_color) | |
| result.paste(pil_img, ((height - width) // 2, 0)) | |
| return result | |
| def process_images(images, image_processor, model_cfg): | |
| image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) | |
| new_images = [] | |
| if image_aspect_ratio == 'pad': | |
| for image in images: | |
| image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) | |
| image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] | |
| new_images.append(image) | |
| else: | |
| return image_processor(images, return_tensors='pt')['pixel_values'] | |
| if all(x.shape == new_images[0].shape for x in new_images): | |
| new_images = torch.stack(new_images, dim=0) | |
| return new_images | |
| def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): | |
| prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')] | |
| def insert_separator(X, sep): | |
| return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] | |
| input_ids = [] | |
| offset = 0 | |
| if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: | |
| offset = 1 | |
| input_ids.append(prompt_chunks[0][0]) | |
| for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): | |
| input_ids.extend(x[offset:]) | |
| if return_tensors is not None: | |
| if return_tensors == 'pt': | |
| return torch.tensor(input_ids, dtype=torch.long) | |
| raise ValueError(f'Unsupported tensor type: {return_tensors}') | |
| return input_ids | |
| def get_model_name_from_path(model_path): | |
| model_path = model_path.strip("/") | |
| model_paths = model_path.split("/") | |
| if model_paths[-1].startswith('checkpoint-'): | |
| return model_paths[-2] + "_" + model_paths[-1] | |
| else: | |
| return model_paths[-1] | |
| class KeywordsStoppingCriteria(StoppingCriteria): | |
| def __init__(self, keywords, tokenizer, input_ids): | |
| self.keywords = keywords | |
| self.keyword_ids = [] | |
| self.max_keyword_len = 0 | |
| for keyword in keywords: | |
| cur_keyword_ids = tokenizer(keyword).input_ids | |
| if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: | |
| cur_keyword_ids = cur_keyword_ids[1:] | |
| if len(cur_keyword_ids) > self.max_keyword_len: | |
| self.max_keyword_len = len(cur_keyword_ids) | |
| self.keyword_ids.append(torch.tensor(cur_keyword_ids)) | |
| self.tokenizer = tokenizer | |
| self.start_len = input_ids.shape[1] | |
| def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | |
| offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) | |
| self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] | |
| for keyword_id in self.keyword_ids: | |
| truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] | |
| if torch.equal(truncated_output_ids, keyword_id): | |
| return True | |
| outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] | |
| for keyword in self.keywords: | |
| if keyword in outputs: | |
| return True | |
| return False | |
| def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | |
| outputs = [] | |
| for i in range(output_ids.shape[0]): | |
| outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) | |
| return all(outputs) |