File size: 4,158 Bytes
e52d1ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import time
import torch
import pytest
import get_position_ids # noqa: E402
from reference import DummyModel

# Test-case table for the position-id kernel. Each entry carries:
#   name      -- label used as the pytest id
#   input_ids -- token stream with vision-start (151652) / vision-end (151653)
#                marker pairs embedded among plain text tokens
#   grid      -- one [t, h, w] triple per vision segment
#
# Cases: a single vision segment, two segments with trailing text, and
# three segments of increasing text padding.
VISION_CONFIGS = [
    {
        "name": "one_segment",
        # 5 text tokens, one vision marker pair, 5 trailing text tokens.
        "input_ids": [10] * 5 + [151652, 151653] + [20] * 5,
        # Grid for the single vision segment.
        "grid": [[2, 4, 6]],
    },
    {
        "name": "two_segments",
        # Two text/vision runs, then 5 trailing text tokens.
        "input_ids": (
            [100] * 5 + [151652, 151653]
            + [101] * 5 + [151652, 151653]
            + [102] * 5
        ),
        # One [t, h, w] grid per segment, in order.
        "grid": [[2, 4, 6], [3, 4, 6]],
    },
    {
        "name": "three_segments",
        # Three text/vision runs (5, 6, 7 text tokens), then 8 trailing tokens.
        "input_ids": (
            [11] * 5 + [151652, 151653]
            + [12] * 6 + [151652, 151653]
            + [13] * 7 + [151652, 151653]
            + [14] * 8
        ),
        # One [t, h, w] grid per segment, in order.
        "grid": [[2, 4, 6], [3, 6, 6], [4, 4, 8]],
    },
]

# Parametrization axes for the test below.
CUDA_DEVICES: list[str] = ["cuda"]  # add further CUDA devices here if needed
SEEDS: list[int] = [42]             # RNG seeds, for reproducibility
DTYPES = [torch.int32]              # tokens and grids are built as int32


@pytest.mark.parametrize("vision_config",
                         VISION_CONFIGS,
                         ids=[cfg["name"] for cfg in VISION_CONFIGS])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_get_position_ids(vision_config, seed, device):
    """Check the get_position_ids CUDA extension against the reference model.

    Runs both implementations on identical token/grid tensors, prints
    wall-clock timings for each, and asserts the outputs match elementwise.

    Args:
        vision_config: One entry of VISION_CONFIGS (name, input_ids, grid).
        seed: RNG seed applied via torch.manual_seed for reproducibility.
        device: CUDA device string the tensors are placed on.
    """
    torch.manual_seed(seed)
    input_ids = torch.tensor(vision_config["input_ids"], dtype=torch.int32, device=device)
    image_grid_thw = torch.tensor(vision_config["grid"], dtype=torch.int32, device=device)

    # Reference implementation from the pure-PyTorch model.
    dummy_model = DummyModel()

    torch.cuda.synchronize()
    start_ref = time.perf_counter()
    pos_ids_ref = dummy_model.get_position_ids(input_ids, image_grid_thw)
    torch.cuda.synchronize()
    end_ref = time.perf_counter()
    ref_time = (end_ref - start_ref) * 1000  # ms
    print(f"\nVision config {vision_config['name']} - Reference time: {ref_time:.2f} ms")
    # The reference returns a float tensor; cast to int32 for an exact comparison.
    pos_ids_ref = pos_ids_ref.to(dtype=torch.int32)

    # Kernel implementation. Allocate the output buffer *before* starting the
    # timer so the measurement covers only the kernel launch, not the
    # allocation — otherwise the extension timing is unfairly inflated
    # relative to the reference path above.
    out = torch.empty(pos_ids_ref.shape, dtype=torch.int32, device=device)
    torch.cuda.synchronize()
    start_ext = time.perf_counter()
    get_position_ids.get_position_ids(out, input_ids, image_grid_thw)
    torch.cuda.synchronize()
    end_ext = time.perf_counter()
    ext_time = (end_ext - start_ext) * 1000  # ms
    print(f"Vision config {vision_config['name']} - Extension time: {ext_time:.2f} ms\n")

    # Verify the results. No clone is needed: assert_close does not mutate
    # its arguments, and the kernel has already finished writing `out`.
    torch.testing.assert_close(out.cpu(), pos_ids_ref.cpu())