|
#include <cuda_runtime.h> |
|
#include <torch/torch.h> |
|
#include <vector> |
|
#include <stdio.h> |
|
|
|
#define SPATIAL_MERGE_SIZE 2 |
|
#define MAX_THREADS_PER_BLOCK 256 |
|
|
|
|
|
|
|
|
|
// Fills `image_positions` with 3D (t, h, w) rope position ids for every
// vision segment.
//
// Launch geometry: one block per vision segment (gridDim.x == num_segments);
// threads inside the block stride over that segment's positions, so any
// blockDim.x works.
//
// image_grid_thw:                [num_segments, 3] int rows of (t, h, w);
//                                h and w are in raw patch units and are merged
//                                down by SPATIAL_MERGE_SIZE here.
// segment_offsets:               [num_segments] first output row of each
//                                segment in `image_positions`.
// vision_segment_lengths_cumsum: [num_segments] scalar offset added to every
//                                coordinate of the segment.
// image_positions:               [total_positions, 3] row-major int output.
//
// Pointers are __restrict__: the buffers never alias, which lets the compiler
// use the read-only data cache for the three input arrays.
__global__ void create_image_positions_kernel(
    const int *__restrict__ image_grid_thw,
    const int *__restrict__ segment_offsets,
    const int *__restrict__ vision_segment_lengths_cumsum,
    int *__restrict__ image_positions)
{
    const int segment_idx = blockIdx.x;

    // Grid dims at the merged-patch resolution.
    const int t = image_grid_thw[segment_idx * 3];
    const int h = image_grid_thw[segment_idx * 3 + 1] / SPATIAL_MERGE_SIZE;
    const int w = image_grid_thw[segment_idx * 3 + 2] / SPATIAL_MERGE_SIZE;
    const int total_length = t * h * w;

    // Where this segment's rows start in the output, and the scalar added to
    // every coordinate of the segment.
    const int pos_offset = segment_offsets[segment_idx];
    const int offset_add = vision_segment_lengths_cumsum[segment_idx];

    // Block-stride loop: thread i handles positions i, i+blockDim.x, ...
    for (int pos_idx = threadIdx.x; pos_idx < total_length; pos_idx += blockDim.x)
    {
        // Decompose the flat index into (t, h, w) coordinates.
        const int t_idx = pos_idx / (h * w);
        const int h_idx = (pos_idx / w) % h;
        const int w_idx = pos_idx % w;

        const int out_index = (pos_offset + pos_idx) * 3;
        image_positions[out_index] = t_idx + offset_add;
        image_positions[out_index + 1] = h_idx + offset_add;
        image_positions[out_index + 2] = w_idx + offset_add;
    }
}
|
|
|
|
|
|
|
|
|
// Computes 3D (t, h, w) rope position ids for a Qwen2-VL-style token sequence
// that interleaves text spans with vision segments delimited by
// VISION_START/VISION_END tokens.
//
// out:            [input_len, 3] CUDA tensor; overwritten with position ids.
// input_ids:      [input_len] CUDA tensor of token ids.
// image_grid_thw: [num_images, 3] CUDA int32 tensor of (t, h, w) per image,
//                 with h/w in raw patch units (merged by SPATIAL_MERGE_SIZE).
void get_position_ids(
    torch::Tensor &out,
    torch::Tensor &input_ids,
    torch::Tensor &image_grid_thw)
{
    TORCH_CHECK(input_ids.device().is_cuda(), "input_ids must be a CUDA tensor");
    TORCH_CHECK(image_grid_thw.device().is_cuda(), "image_grid_thw must be a CUDA tensor");
    TORCH_CHECK(out.device().is_cuda(), "out must be a CUDA tensor");
    // The kernel reads this tensor via data_ptr<int>() and the host loop via
    // accessor<int, 2>, so the dtype and layout must match exactly.
    TORCH_CHECK(image_grid_thw.scalar_type() == torch::kInt,
                "image_grid_thw must be int32");
    TORCH_CHECK(image_grid_thw.dim() == 2 && image_grid_thw.size(1) == 3,
                "image_grid_thw must have shape [num_images, 3]");
    TORCH_CHECK(image_grid_thw.is_contiguous(), "image_grid_thw must be contiguous");

    const int input_len = input_ids.size(0);
    auto options_int = torch::TensorOptions().device(input_ids.device()).dtype(torch::kInt);
    auto options_long = torch::TensorOptions().device(input_ids.device()).dtype(torch::kLong);

    const int VISION_START_TOKEN_ID = 151652;
    const int VISION_END_TOKEN_ID = 151653;

    // Locate the start/end delimiter tokens of every vision segment.
    auto vision_starts_mask = input_ids == VISION_START_TOKEN_ID;
    auto vision_ends_mask = input_ids == VISION_END_TOKEN_ID;

    auto starts = torch::where(vision_starts_mask)[0].to(torch::kInt);
    auto ends = torch::where(vision_ends_mask)[0].to(torch::kInt);

    int actual_segments = starts.size(0);

    // Text-only prompt: no vision delimiters. The general path below would
    // index ends[-1] on an empty tensor and launch a 0-block kernel, so
    // handle it directly: every dimension's position id is just 0..len-1.
    if (actual_segments == 0)
    {
        auto text_positions = torch::arange(input_len, options_long).view({-1, 1}).expand({-1, 3});
        out.copy_(text_positions);
        return;
    }
    TORCH_CHECK(image_grid_thw.size(0) >= actual_segments,
                "image_grid_thw has fewer rows than vision segments in input_ids");

    // End index of the previous segment (0 for the first one).
    auto prev_end = torch::cat({torch::zeros({1}, options_long), ends.slice(0, 0, actual_segments - 1)});

    // Number of text tokens preceding each vision segment (inclusive of the
    // start token itself, hence the +1).
    auto text_lengths_between_vision = starts - prev_end + 1;
    auto zeros = torch::zeros({1}, options_long);
    auto widths = image_grid_thw.slice(0, 0, actual_segments).select(1, 2);
    auto divided_widths = widths / SPATIAL_MERGE_SIZE;
    // Width of the previous image (0 for the first segment); it dominates the
    // position advance contributed by that image.
    auto vision_widths_max = torch::cat({zeros, divided_widths.slice(0, 0, actual_segments - 1)});

    auto vision_segment_lengths = text_lengths_between_vision + vision_widths_max;
    auto vision_segment_lengths_cumsum = vision_segment_lengths.cumsum(0);
    auto text_segment_lengths = vision_segment_lengths_cumsum - text_lengths_between_vision;

    // Per-segment output offsets and total row count for the kernel output,
    // computed on a host copy of the grid.
    std::vector<int> segment_offsets_vec(actual_segments);
    int total_image_positions = 0;

    auto image_grid_cpu = image_grid_thw.to(torch::kCPU);
    auto image_grid_accessor = image_grid_cpu.accessor<int, 2>();
    for (int i = 0; i < actual_segments; i++)
    {
        int t = image_grid_accessor[i][0];
        int h = image_grid_accessor[i][1] / SPATIAL_MERGE_SIZE;
        int w = image_grid_accessor[i][2] / SPATIAL_MERGE_SIZE;
        segment_offsets_vec[i] = total_image_positions;
        total_image_positions += t * h * w;
    }

    auto segment_offsets_tensor = torch::tensor(segment_offsets_vec, options_int);

    // The kernel consumes int32 offsets.
    auto vision_segment_lengths_cumsum_int = vision_segment_lengths_cumsum.to(torch::kInt);

    auto image_positions_tensor = torch::empty({total_image_positions, 3}, options_int);

    // One block per vision segment; threads stride over the segment.
    int threads = MAX_THREADS_PER_BLOCK;
    int blocks = actual_segments;
    create_image_positions_kernel<<<blocks, threads>>>(
        image_grid_thw.data_ptr<int>(),
        segment_offsets_tensor.data_ptr<int>(),
        vision_segment_lengths_cumsum_int.data_ptr<int>(),
        image_positions_tensor.data_ptr<int>());
    // Synchronize so execution errors (not just launch errors) are surfaced
    // by the check below; the host code reads the result right after anyway.
    cudaDeviceSynchronize();
    cudaError_t error = cudaGetLastError();
    TORCH_CHECK(error == cudaSuccess, "CUDA error: ", cudaGetErrorString(error));

    // Each text span gets a constant position id per dimension (its cumulative
    // offset), broadcast over its length.
    std::vector<torch::Tensor> text_positions_list;
    for (int i = 0; i < actual_segments; i++)
    {
        int seq_len = text_lengths_between_vision[i].item<int>();
        auto text_range = torch::zeros({3, seq_len}, options_long) + text_segment_lengths[i];
        text_positions_list.push_back(text_range);
    }

    // Interleave text spans with the kernel-produced image position rows
    // (transposed to [3, seg_length] to match the text spans).
    std::vector<torch::Tensor> full_positions_list;

    for (int i = 0; i < actual_segments; i++)
    {

        full_positions_list.push_back(text_positions_list[i]);

        int start = segment_offsets_vec[i];
        int seg_length = 0;
        if (i == actual_segments - 1)
            seg_length = total_image_positions - segment_offsets_vec[i];
        else
            seg_length = segment_offsets_vec[i + 1] - segment_offsets_vec[i];

        torch::Tensor image_segment = image_positions_tensor.slice(0, start, start + seg_length).t();
        full_positions_list.push_back(image_segment);
    }

    // Trailing text after the last vision segment continues counting from one
    // past the maximum position id emitted so far.
    int full_text_len = input_len - ends[actual_segments - 1].item<int>();
    if (full_text_len > 0)
    {
        int max_s = full_positions_list.back().max().item<int>() + 1;
        auto extra_text = torch::arange(full_text_len, options_long).view({1, -1}).expand({3, -1}) + max_s;
        full_positions_list.push_back(extra_text);
    }

    // Concatenate along the sequence axis, then transpose to [input_len, 3].
    auto full_positions_concatenated = torch::cat(full_positions_list, 1);
    auto full_positions_concatenated_transposed = full_positions_concatenated.t();

    out.copy_(full_positions_concatenated_transposed);
}
|
|