from typing import List, Tuple
import torch
from einops import rearrange
from PIL import Image
from torch.nn import functional as F
from torchvision.transforms.v2 import InterpolationMode
from torchvision.transforms.v2.functional import normalize
from torchvision.transforms.v2.functional import resize as tv_resize
from torchvision.transforms.v2.functional import to_dtype, to_image
from .layers import attn, layer_norm, linear, mlp
from .weights import VisionModel, load_from_safetensors


def im_resize(
    image: Image.Image,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BICUBIC,
) -> Image.Image:
    """
    Thin wrapper around torchvision's `resize`, which has an overly narrow
    type signature: it accepts both PIL images and torch tensors, but the
    annotation only allows tensors.
    """
    return tv_resize(
        image,  # type: ignore
        size,
        interpolation,
    )


def create_patches(
    image: Image.Image, image_patch_size: int = 378
) -> Tuple[List[Image.Image], Tuple[int, int]]:
    """
    Split the given image into a variable number of patches depending upon its
    resolution.
    """
# Start off with the global patch.
patches = [im_resize(image, [image_patch_size, image_patch_size])]
# Find the closest resolution template.
res_templates = [(1, 2), (2, 1), (2, 2)]
im_width, im_height = image.size
max_dim = max(im_width, im_height)
if max_dim < image_patch_size * 1.4:
# If the image is already small, we just do a single patch that is a
# duplicate of the global patch. This creates a small amount of
# redundant computation now, but it is simpler and future-proofs us
# if/when we condition the vision encoder on the patch type.
res_template = (1, 1)
patches.append(patches[0])
else:
aspect_ratio = im_width / im_height
res_template = min(
res_templates, key=lambda size: abs((size[1] / size[0]) - aspect_ratio)
)
# TODO: Actually implement patching... just going to put in the global
# patch for now to make progress on other aspects.
patches.append(patches[0])
return patches, res_template
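
# Example of the current behavior (a sketch): a 1600x800 image exceeds
# 378 * 1.4, so its 2.0 aspect ratio selects the (1, 2) template (read here
# as rows x cols, which is an assumption), and the function returns two
# identical 378x378 global patches alongside that template until real
# sub-image patching is implemented.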


def encode_image(image: Image.Image, weights: VisionModel) -> torch.Tensor:
    """
    Encode a PIL image: split it into patches, normalize them to [-1, 1],
    run the vision encoder, concatenate the patch outputs, and project the
    result through the MLP head.
    """
patches, res_template = create_patches(image.convert("RGB"))
patches = torch.stack(
[
normalize(
to_dtype(to_image(patch), torch.float16, scale=True),
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
)
for patch in patches
]
)
outputs = vision_encoder(patches, weights)
# TODO: Merge sub-image patch outputs properly... for now we'll just assume
# that the global patch is repeated.
    assert outputs.shape[0] == 2, "Expected global patch plus duplicated patch."
outputs = torch.cat([outputs[0], outputs[1]], dim=-1)
return mlp(outputs, weights.proj_mlp)


def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel) -> torch.Tensor:
    """
    Pre-norm ViT forward pass: patchify the image, embed the patches, add
    positional embeddings, run the transformer blocks, and apply a final
    layer norm.
    """
x = rearrange(
input_BCHW,
"b c (h p1) (w p2) -> b (h w) (c p1 p2)",
p1=w.patch_size,
p2=w.patch_size,
) # B3HW -> B(HxW)(3xP1xP2), aka BTC
x = linear(x, w.patch_emb)
x = x + w.pos_emb
for block in w.blocks:
x = x + attn(layer_norm(x, block.ln1), block.attn)
x = x + mlp(layer_norm(x, block.ln2), block.mlp)
x = layer_norm(x, w.post_ln)
return x
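

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. It assumes that
    # load_from_safetensors takes a checkpoint path and returns a VisionModel;
    # both file names below are placeholders. Because this module uses
    # relative imports, run it as a module (python -m <package>.<module>).
    weights = load_from_safetensors("vision.safetensors")
    image = Image.open("example.jpg")
    embedding = encode_image(image, weights)
    print(embedding.shape)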