# Exported from the Hugging Face file viewer (uploader: akshit-g,
# commit d3cd5c1, "add : files", 6.9 kB). Viewer chrome converted to a
# comment so the module remains valid Python.
import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Callable
import safetensors
import torch
from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
@dataclass
class VisionBlock:
    """Weights for one transformer block of the vision encoder."""
    ln1: LayerNormWeights  # layer norm applied before attention
    attn: AttentionWeights
    ln2: LayerNormWeights  # layer norm applied before the MLP
    mlp: MLPWeights
@dataclass
class VisionModel:
    """Weights for the full vision encoder plus its projection MLP."""
    patch_size: int  # spatial side length of a square image patch
    patch_emb: LinearWeights  # linear patch embedding
    pos_emb: torch.Tensor  # learned positional embedding
    blocks: List[VisionBlock]  # transformer blocks, in order
    post_ln: LayerNormWeights  # final layer norm after the blocks
    proj_mlp: MLPWeights  # projects vision features into the text space
@dataclass
class TextBlock:
    """Weights for one transformer block of the text decoder."""
    ln: LayerNormWeights  # single pre-block layer norm (parallel attn/MLP layout)
    attn: AttentionWeights
    mlp: MLPWeights
@dataclass
class TextModel:
    """Weights for the full text decoder."""
    wte: torch.Tensor  # token embedding table
    blocks: List[TextBlock]  # transformer blocks, in order
    post_ln: LayerNormWeights  # final layer norm before the LM head
    lm_head: LinearWeights  # projects hidden states to vocabulary logits
@dataclass
class MoondreamModel:
    """Container pairing the vision encoder with the text decoder."""
    vision: VisionModel
    text: TextModel
@contextmanager
def safetensors_open(safetensors_file: str):
    """Yield a ``name -> torch.Tensor`` getter for *safetensors_file*.

    Wraps `safetensors.safe_open` so call sites receive a plain callable
    instead of the file handle, eliminating the need to sprinkle
    type-ignore comments wherever `safe_open` is used.
    """
    with safetensors.safe_open(
        safetensors_file, framework="pt"
    ) as handle:  # pyright: ignore

        def _fetch(name: str) -> torch.Tensor:
            return handle.get_tensor(name)

        yield _fetch
def _linear(get_tensor: Callable[[str], torch.Tensor], prefix: str) -> LinearWeights:
    """Load the ``{prefix}.weight`` / ``{prefix}.bias`` pair as LinearWeights."""
    return LinearWeights(
        weight=get_tensor(f"{prefix}.weight"), bias=get_tensor(f"{prefix}.bias")
    )


def _layer_norm(
    get_tensor: Callable[[str], torch.Tensor], prefix: str
) -> LayerNormWeights:
    """Load the ``{prefix}.weight`` / ``{prefix}.bias`` pair as LayerNormWeights."""
    return LayerNormWeights(
        weight=get_tensor(f"{prefix}.weight"), bias=get_tensor(f"{prefix}.bias")
    )


def _load_vision(
    get_tensor: Callable[[str], torch.Tensor], n_blocks: int, n_heads: int
) -> VisionModel:
    """Assemble the vision encoder weights from the checkpoint tensors."""
    root = "vision_encoder.encoder.model.visual"
    patch_emb = _linear(get_tensor, f"{root}.patch_embed.linear")
    # The patch embedding maps patch_size * patch_size * 3 pixel values to the
    # hidden dim, so the spatial patch size is recoverable from its in-features.
    patch_size = int(math.sqrt(patch_emb.weight.shape[1] // 3))
    blocks = []
    for i in range(n_blocks):
        prefix = f"{root}.blocks.{i}"
        blocks.append(
            VisionBlock(
                ln1=_layer_norm(get_tensor, f"{prefix}.norm1"),
                attn=AttentionWeights(
                    qkv=_linear(get_tensor, f"{prefix}.attn.qkv"),
                    proj=_linear(get_tensor, f"{prefix}.attn.proj"),
                    n_heads=n_heads,
                ),
                ln2=_layer_norm(get_tensor, f"{prefix}.norm2"),
                mlp=MLPWeights(
                    fc1=_linear(get_tensor, f"{prefix}.mlp.fc1"),
                    fc2=_linear(get_tensor, f"{prefix}.mlp.fc2"),
                ),
            )
        )
    proj_mlp = MLPWeights(
        fc1=_linear(get_tensor, "vision_encoder.projection.mlp.fc1"),
        fc2=_linear(get_tensor, "vision_encoder.projection.mlp.fc2"),
        act="gelu_approx",
    )
    return VisionModel(
        patch_size=patch_size,
        patch_emb=patch_emb,
        pos_emb=get_tensor(f"{root}.pos_embed"),
        blocks=blocks,
        post_ln=_layer_norm(get_tensor, f"{root}.norm"),
        proj_mlp=proj_mlp,
    )


def _load_text(
    get_tensor: Callable[[str], torch.Tensor], n_blocks: int, n_heads: int
) -> TextModel:
    """Assemble the text decoder weights from the checkpoint tensors."""
    blocks = []
    for i in range(n_blocks):
        prefix = f"text_model.transformer.h.{i}"
        blocks.append(
            TextBlock(
                ln=_layer_norm(get_tensor, f"{prefix}.ln"),
                attn=AttentionWeights(
                    qkv=_linear(get_tensor, f"{prefix}.mixer.Wqkv"),
                    proj=_linear(get_tensor, f"{prefix}.mixer.out_proj"),
                    n_heads=n_heads,
                ),
                mlp=MLPWeights(
                    fc1=_linear(get_tensor, f"{prefix}.mlp.fc1"),
                    fc2=_linear(get_tensor, f"{prefix}.mlp.fc2"),
                    act="gelu_approx",
                ),
            )
        )
    return TextModel(
        wte=get_tensor("text_model.transformer.embd.wte.weight"),
        blocks=blocks,
        post_ln=_layer_norm(get_tensor, "text_model.lm_head.ln"),
        lm_head=_linear(get_tensor, "text_model.lm_head.linear"),
    )


def load_model(
    get_tensor: Callable[[str], torch.Tensor],
    vision_blocks: int = 27,
    text_blocks: int = 24,
    vision_n_heads: int = 16,
    text_n_heads: int = 32,
) -> MoondreamModel:
    """Build a MoondreamModel from a ``name -> tensor`` lookup callable.

    Args:
        get_tensor: Returns the checkpoint tensor stored under a given name.
        vision_blocks: Number of vision encoder transformer blocks to load.
        text_blocks: Number of text decoder transformer blocks to load.
        vision_n_heads: Attention head count recorded on each vision block.
        text_n_heads: Attention head count recorded on each text block.

    Returns:
        The fully populated MoondreamModel (vision encoder + text decoder).
    """
    return MoondreamModel(
        vision=_load_vision(get_tensor, vision_blocks, vision_n_heads),
        text=_load_text(get_tensor, text_blocks, text_n_heads),
    )
def load_from_safetensors(
    safetensors_file: str,
    vision_blocks: int = 27,
    text_blocks: int = 24,
    **kwargs,
) -> MoondreamModel:
    """Build a MoondreamModel directly from a safetensors checkpoint file.

    Extra keyword arguments are forwarded to `load_model` (e.g. head counts).
    """
    with safetensors_open(safetensors_file) as fetch:
        model = load_model(fetch, vision_blocks, text_blocks, **kwargs)
    return model
def load_from_pt(
    pt_file: str,
    vision_blocks: int = 27,
    text_blocks: int = 24,
    **kwargs,
) -> MoondreamModel:
    """Build a MoondreamModel from a PyTorch ``.pt`` state-dict checkpoint.

    Extra keyword arguments are forwarded to `load_model` (e.g. head counts).
    """
    # Resolve whatever device torch currently treats as the default.
    default_device = str(torch.empty(0).device)
    state = torch.load(pt_file, map_location=default_device, weights_only=True)
    # Strip torch.compile's "._orig_mod" name mangling and cast weights to fp16.
    tensors = {
        name.replace("._orig_mod", ""): value.to(dtype=torch.float16)
        for name, value in state.items()
    }
    return load_model(tensors.__getitem__, vision_blocks, text_blocks, **kwargs)
if __name__ == "__main__":
    # Smoke test: load the default checkpoint and dump the weight structure.
    model = load_from_safetensors("model.safetensors")
    print(model)