import sys
import torch
import os
import random
import base64
import msgpack
from io import BytesIO
import numpy as np
from transformers import AutoTokenizer
from llava.constants import (
    MM_TOKEN_INDEX,
    DEFAULT_VIDEO_START_TOKEN,
    DEFAULT_VIDEO_END_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
    process_images_v2,
)
from llava.model.builder import load_pretrained_model
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model import LlavaMistralForCausalLM
from transformers import CLIPImageProcessor
from PIL import Image
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading


# Every sample is resampled or padded to exactly this many frames.
NUM_FRAMES = 10

# Checkpoint directory used when EndpointHandler is built without arguments.
DEFAULT_MODEL_PATH = './masp_094_v2'


def select_frames(input_frames, num_segments=10):
    """Pick ``num_segments`` evenly spaced items from ``input_frames``.

    Indices are taken from ``np.linspace(0, len(input_frames) - 1,
    num_segments)`` truncated to int, so the first and last frames are always
    included and frames may repeat when the input is shorter than
    ``num_segments``.
    """
    indices = np.linspace(
        start=0, stop=len(input_frames) - 1, num=num_segments
    ).astype(int)
    return [input_frames[ind] for ind in indices]


def _fit_frame_count(images, num_frames=NUM_FRAMES):
    """Return a list of exactly ``num_frames`` frames built from ``images``.

    Longer inputs are uniformly subsampled with ``select_frames``; shorter
    inputs are padded by repeating the last frame. A new list is returned in
    the padding case so the caller's list is never extended in place.

    Raises:
        IndexError: if ``images`` is empty (there is no last frame to repeat).
    """
    if len(images) > num_frames:
        return select_frames(images, num_frames)
    if len(images) < num_frames:
        return list(images) + [images[-1]] * (num_frames - len(images))
    return images


def _strip_stop_suffix(text, stop_str):
    """Trim whitespace and remove one trailing occurrence of ``stop_str``.

    ``str.rstrip(stop_str)`` must not be used for this: it treats its argument
    as a *set of characters* and would also eat legitimate trailing characters
    of the answer (e.g. a final "s" when the stop string is "</s>").
    """
    text = text.strip()
    if stop_str and text.endswith(stop_str):
        text = text[:-len(stop_str)]
    return text.strip()


def _load_rgb(path):
    """Open the image at ``path`` and return an RGB copy, closing the file."""
    with Image.open(path) as img:
        return img.convert('RGB')


def load_model(model_path, device_map):
    """Load the tokenizer and the LLaVA-Mistral model from ``model_path``.

    The image/video start and end tokens are registered as special tokens and
    the embedding matrix is resized to match; the vision tower is loaded if
    the checkpoint did not already load it.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    kwargs = {
        "device_map": device_map,
        # Differs from the CPU handler: float16 is needed to avoid memory
        # issues on the GPU.
        "torch_dtype": torch.float16,
    }
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs
    )
    tokenizer.add_tokens(
        [
            DEFAULT_IM_START_TOKEN,
            DEFAULT_IM_END_TOKEN,
            DEFAULT_VIDEO_START_TOKEN,
            DEFAULT_VIDEO_END_TOKEN,
        ],
        special_tokens=True,
    )
    model.resize_token_embeddings(len(tokenizer))
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)
    return model, tokenizer


class EndpointHandler:
    """Batched image/video description endpoint backed by a LLaVA-Mistral model."""

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """Load the model, tokenizer and image processor onto ``cuda:0``.

        Args:
            model_path: Checkpoint directory; ``~`` is expanded. Defaults to
                the path that was previously hard-coded.
        """
        disable_torch_init()
        model_path = os.path.expanduser(model_path)
        model, tokenizer = load_model(model_path, device_map={"": 0})
        image_processor = Blip2ImageTrainProcessor(
            image_size=model.config.img_size,
            is_training=False,
        )

        self.tokenizer = tokenizer
        # Another difference with the CPU handler: everything lives on GPU 0.
        self.device = torch.device('cuda:0')
        self.model = model.to(self.device)
        self.image_processor = image_processor
        self.conv_mode = 'v1'

    def inference_frames_batch(self, batch_image_lists, batch_prompts, batch_temperatures):
        """Generate one text answer per sample for a batch of frame lists.

        Args:
            batch_image_lists: One list of PIL images per sample. Each list is
                normalised to ``NUM_FRAMES`` frames; the input lists are not
                modified.
            batch_prompts: One prompt string per sample.
            batch_temperatures: One temperature per sample. Only the first
                value is used for the whole batch.

        Returns:
            A list with one decoded, stripped answer string per sample.
        """
        start_time = time.perf_counter()

        images_tensors_list = []
        input_ids_list = []
        for images, prompt in zip(batch_image_lists, batch_prompts):
            images = _fit_frame_count(images)

            images_tensor = process_images_v2(
                images, self.image_processor, self.model.config
            )
            # Match the model's dtype (float16) and device.
            images_tensors_list.append(images_tensor.half().to(self.device))

            # NOTE(review): after normalisation len(images) is always
            # NUM_FRAMES, so the single-image branch below can never run and
            # single images are prompted as 10-frame videos. Kept as-is
            # because the intended behaviour cannot be determined from this
            # file - confirm and either drop the branch or test the original
            # frame count instead.
            if len(images) == 1:
                qs = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN
                      + DEFAULT_IM_END_TOKEN + '\n' + prompt)
            else:
                qs = (DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN
                      + DEFAULT_VIDEO_END_TOKEN + '\n' + prompt)

            # Build the conversation and tokenize it.
            conv = conv_templates[self.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt_text = conv.get_prompt()
            input_ids = tokenizer_image_token(
                prompt_text, self.tokenizer, MM_TOKEN_INDEX, return_tensors='pt'
            ).squeeze(0)
            input_ids_list.append(input_ids)

        # Pad the prompts to a common length.
        # NOTE(review): pad_sequence needs a numeric padding value, so this
        # relies on the tokenizer defining pad_token_id - confirm for this
        # checkpoint.
        input_ids_padded = torch.nn.utils.rnn.pad_sequence(
            input_ids_list,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        ).to(self.device)

        # images_tensors_list is intentionally left as a list of per-sample
        # tensors rather than stacked into a single tensor.

        # Stop on the conversation template's separator.
        conv = conv_templates[self.conv_mode].copy()
        stop_str = conv.sep if conv.sep2 is None else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria(
            [stop_str], self.tokenizer, input_ids_padded
        )

        # A single temperature is applied to the whole batch; use the first.
        temperature = batch_temperatures[0]

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids_padded,
                images=images_tensors_list,
                temperature=temperature,
                do_sample=True,
                top_p=None,
                num_beams=1,
                no_repeat_ngram_size=3,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        outputs = [
            _strip_stop_suffix(
                self.tokenizer.decode(output_id, skip_special_tokens=True),
                stop_str,
            )
            for output_id in output_ids
        ]

        latency = time.perf_counter() - start_time
        print(f"Latency for this batch inference: {latency:.4f} seconds")

        return outputs

    def __call__(self, request):
        """Decode a batched request and run inference on it.

        Args:
            request: Mapping with
                ``'images'``: list of msgpack-packed lists of encoded image
                bytes, one entry per sample;
                ``'prompt'`` (optional): list of encoded prompt bytes, empty
                meaning "use the default prompt";
                ``'temperature'`` (optional): list of encoded float strings,
                defaulting to ``'0.01'``.

        Returns:
            ``{'output': [answer, ...]}`` with one answer per sample.
        """
        packed_data_list = request['images']
        prompt_list = request.get('prompt', [''.encode()] * len(packed_data_list))
        temperature_list = request.get(
            'temperature', ['0.01'.encode()] * len(packed_data_list)
        )

        all_image_lists = []
        all_prompts = []
        all_temperatures = []

        for packed_data, prompt_encoded, temperature_encoded in zip(
            packed_data_list, prompt_list, temperature_list
        ):
            # Each packed entry is a list of encoded image byte strings.
            unpacked_data = msgpack.unpackb(packed_data, raw=False)
            image_list = [
                Image.open(BytesIO(byte_data)).convert('RGB')
                for byte_data in unpacked_data
            ]
            all_image_lists.append(image_list)

            # An empty prompt selects a default based on the frame count.
            prompt = prompt_encoded.decode()
            if prompt == '':
                if len(image_list) == 1:
                    prompt = "Please describe this image in detail."
                else:
                    prompt = "Describe the following video in detail."
            all_prompts.append(prompt)

            all_temperatures.append(float(temperature_encoded.decode()))

        with torch.no_grad():
            outputs = self.inference_frames_batch(
                all_image_lists, all_prompts, all_temperatures
            )

        return {'output': outputs}


def benchmark_qps_batched(handler, batched_request, num_batches=10):
    """Send ``batched_request`` to ``handler`` ``num_batches`` times and print QPS.

    QPS is the total number of samples processed divided by the wall-clock
    time of the whole loop.
    """
    start_time = time.perf_counter()
    completed_samples = 0

    for _ in range(num_batches):
        handler(batched_request)
        completed_samples += len(batched_request['images'])

    total_time = time.perf_counter() - start_time
    qps = completed_samples / total_time
    print(f"Processed {completed_samples} samples in {total_time:.2f} seconds. QPS: {qps:.2f}")


def main():
    """Build a batched request from a directory of frames and run it once."""
    # 7347652962333773061
    video_dir = './v12044gd0000cl5c6rfog65i2eoqcqig'

    # Frame files are named by their integer index; sort numerically.
    frames = [
        (int(os.path.splitext(item)[0]), os.path.join(video_dir, item))
        for item in os.listdir(video_dir)
    ]
    frames = [item[1] for item in sorted(frames, key=lambda x: x[0])]
    out_frames = [_load_rgb(frame) for frame in frames]

    # Number of samples to include in the batch; adjust to GPU memory.
    batch_size = 4

    # Every sample in the batch uses the same frames, so encode and pack once.
    byte_images = []
    for img in out_frames:
        byte_io = BytesIO()
        img.save(byte_io, format='JPEG')
        byte_images.append(byte_io.getvalue())
    packed_data = msgpack.packb(byte_images)

    batched_request = {
        'images': [packed_data] * batch_size,
        'prompt': [''.encode()] * batch_size,  # or specific prompts
        'temperature': ['0.01'.encode()] * batch_size,
    }

    handler = EndpointHandler()

    response = handler(batched_request)
    print(response)

    # To measure throughput instead, call:
    #   benchmark_qps_batched(handler, batched_request, num_batches=10)


if __name__ == "__main__":
    main()