from typing import List, Tuple

import torch
from einops import rearrange
from PIL import Image
from torch.nn import functional as F
from torchvision.transforms.v2 import InterpolationMode
from torchvision.transforms.v2.functional import normalize
from torchvision.transforms.v2.functional import resize as tv_resize
from torchvision.transforms.v2.functional import to_dtype, to_image

from .layers import attn, layer_norm, linear, mlp
from .weights import VisionModel, load_from_safetensors


def im_resize(
    image: Image.Image,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BICUBIC,
) -> Image.Image:
    """
    The 'resize' function from torchvision has bad type signatures.
    it accepts both PIL images and torch tensors,  but the type signature
    only allows tensors.
    """
    return tv_resize(
        image,  # type: ignore
        size,
        InterpolationMode.BICUBIC,
    )


def create_patches(
    image: Image.Image, image_patch_size=378
) -> Tuple[List[Image.Image], Tuple[int, int]]:
    """
    Split the given image into a variable number of patches depending upon its
    resolution.
    """
    # Start off with the global patch.
    patches = [im_resize(image, [image_patch_size, image_patch_size])]

    # Find the closest resolution template.
    res_templates = [(1, 2), (2, 1), (2, 2)]
    im_width, im_height = image.size
    max_dim = max(im_width, im_height)
    if max_dim < image_patch_size * 1.4:
        # If the image is already small, we just do a single patch that is a
        # duplicate of the global patch. This creates a small amount of
        # redundant computation now, but it is simpler and future-proofs us
        # if/when we condition the vision encoder on the patch type.
        res_template = (1, 1)
        patches.append(patches[0])
    else:
        aspect_ratio = im_width / im_height
        res_template = min(
            res_templates, key=lambda size: abs((size[1] / size[0]) - aspect_ratio)
        )
        # TODO: Actually implement patching... just going to put in the global
        # patch for now to make progress on other aspects.
        patches.append(patches[0])

    return patches, res_template


def encode_image(image: Image.Image, weights: VisionModel) -> torch.Tensor:
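    """
    Encode a PIL image into a sequence of embeddings: build the image patches,
    normalize them, run the vision encoder, concatenate the global and local
    patch features, and project the result through the projection MLP.
    """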
    patches, res_template = create_patches(image.convert("RGB"))
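    # Convert each patch to a float16 tensor in [0, 1], then normalize to
    # [-1, 1] with a per-channel mean/std of 0.5.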
    patches = torch.stack(
        [
            normalize(
                to_dtype(to_image(patch), torch.float16, scale=True),
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
            )
            for patch in patches
        ]
    )

    outputs = vision_encoder(patches, weights)

    # TODO: Merge sub-image patch outputs properly... for now we'll just assume
    # that the global patch is repeated.
    assert outputs.shape[0] == 2, "Expected global patch plus one local patch."
    outputs = torch.cat([outputs[0], outputs[1]], dim=-1)

    return mlp(outputs, weights.proj_mlp)


def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel) -> torch.Tensor:
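    """
    Run a ViT-style encoder over a batch of image patches: tokenize each image
    into patch_size x patch_size pieces, embed them, add position embeddings,
    then apply pre-norm residual transformer blocks and a final layer norm.
    """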
    x = rearrange(
        input_BCHW,
        "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
        p1=w.patch_size,
        p2=w.patch_size,
    )  # B3HW -> B(HxW)(3xP1xP2), aka BTC

    # Linear patch embedding plus learned absolute position embeddings.
    x = linear(x, w.patch_emb)
    x = x + w.pos_emb
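    # Pre-norm residual transformer blocks: attention, then MLP.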
    for block in w.blocks:
        x = x + attn(layer_norm(x, block.ln1), block.attn)
        x = x + mlp(layer_norm(x, block.ln2), block.mlp)
    x = layer_norm(x, w.post_ln)

    return x
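

# Usage sketch (not part of the module API): how encode_image might be called.
# The weight path, the image path, and the exact signature of
# load_from_safetensors (assumed here to return a VisionModel given a
# .safetensors path) are illustrative assumptions, not guaranteed by this
# module.
#
#   weights = load_from_safetensors("vision.safetensors")
#   image = Image.open("example.jpg")
#   features = encode_image(image, weights)  # (tokens, dim) embeddings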