# Spaces:
# Running
# Running
import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Iterator, List

import safetensors
import torch

from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
@dataclass
class VisionBlock:
    """
    Weights of a single vision encoder block.

    Declared as a dataclass so that `load_model` can build instances with
    keyword arguments.
    """

    # Layer norm loaded from the block's `norm1` tensors.
    ln1: LayerNormWeights
    # Attention weights loaded from `attn.qkv` / `attn.proj`, plus head count.
    attn: AttentionWeights
    # Layer norm loaded from the block's `norm2` tensors.
    ln2: LayerNormWeights
    # MLP weights loaded from `mlp.fc1` / `mlp.fc2`.
    mlp: MLPWeights
@dataclass
class VisionModel:
    """
    Weights of the vision encoder and its projection MLP.

    Declared as a dataclass so that `load_model` can build instances with
    keyword arguments.
    """

    # Patch edge length; `load_model` derives it from the patch embedding shape.
    patch_size: int
    # Linear patch embedding (`...visual.patch_embed.linear`).
    patch_emb: LinearWeights
    # Positional embedding tensor (`...visual.pos_embed`).
    pos_emb: torch.Tensor
    # Encoder blocks, in index order.
    blocks: List[VisionBlock]
    # Layer norm loaded from `...visual.norm`.
    post_ln: LayerNormWeights
    # Projection MLP loaded from `vision_encoder.projection.mlp`.
    proj_mlp: MLPWeights
@dataclass
class TextBlock:
    """
    Weights of a single text decoder block.

    Declared as a dataclass so that `load_model` can build instances with
    keyword arguments.
    """

    # Layer norm loaded from the block's `ln` tensors.
    ln: LayerNormWeights
    # Attention weights loaded from `mixer.Wqkv` / `mixer.out_proj`, plus head count.
    attn: AttentionWeights
    # MLP weights loaded from `mlp.fc1` / `mlp.fc2`.
    mlp: MLPWeights
@dataclass
class TextModel:
    """
    Weights of the text decoder.

    Declared as a dataclass so that `load_model` can build instances with
    keyword arguments.
    """

    # Token embedding tensor (`text_model.transformer.embd.wte.weight`).
    wte: torch.Tensor
    # Decoder blocks, in index order.
    blocks: List[TextBlock]
    # Layer norm loaded from `text_model.lm_head.ln`.
    post_ln: LayerNormWeights
    # Output head loaded from `text_model.lm_head.linear`.
    lm_head: LinearWeights
@dataclass
class MoondreamModel:
    """
    Top-level container pairing the vision and text weights.

    Declared as a dataclass so that `load_model` can build instances with
    keyword arguments.
    """

    # Vision encoder weights.
    vision: VisionModel
    # Text decoder weights.
    text: TextModel
@contextmanager
def safetensors_open(
    safetensors_file: str,
) -> Iterator[Callable[[str], torch.Tensor]]:
    """
    Simplify interfacing with safetensors files. Eliminates the need to ignore
    type errors when using the `safe_open` function.

    Must be used as a context manager (``with safetensors_open(path) as
    get_tensor:``); the file is opened with `safetensors.safe_open` and stays
    open only for the duration of the ``with`` body.

    Args:
        safetensors_file: Path of the safetensors file to open.

    Yields:
        A function mapping a tensor name to the tensor stored under it.
    """
    with safetensors.safe_open(
        safetensors_file, framework="pt"
    ) as st:  # pyright: ignore

        def get_tensor(name: str) -> torch.Tensor:
            # Typed wrapper so callers get a plain `str -> Tensor` callable.
            return st.get_tensor(name)

        yield get_tensor
def load_model(
    get_tensor: Callable[[str], torch.Tensor],
    vision_blocks: int = 27,
    text_blocks: int = 24,
    vision_n_heads: int = 16,
    text_n_heads: int = 32,
) -> MoondreamModel:
    """
    Assemble a `MoondreamModel` by looking up every weight tensor by name.

    Args:
        get_tensor: Callable that maps a tensor name to its `torch.Tensor`.
        vision_blocks: Number of vision encoder blocks to load.
        text_blocks: Number of text decoder blocks to load.
        vision_n_heads: Head count stored on each vision attention entry.
        text_n_heads: Head count stored on each text attention entry.

    Returns:
        The populated vision and text weight containers.
    """

    def layer_norm(name: str) -> LayerNormWeights:
        # Fetch the `<name>.weight` / `<name>.bias` pair of a layer norm.
        return LayerNormWeights(
            weight=get_tensor(f"{name}.weight"), bias=get_tensor(f"{name}.bias")
        )

    def linear(name: str) -> LinearWeights:
        # Fetch the `<name>.weight` / `<name>.bias` pair of a linear layer.
        return LinearWeights(
            weight=get_tensor(f"{name}.weight"), bias=get_tensor(f"{name}.bias")
        )

    def vision_block(index: int) -> VisionBlock:
        root = f"vision_encoder.encoder.model.visual.blocks.{index}"
        return VisionBlock(
            ln1=layer_norm(f"{root}.norm1"),
            attn=AttentionWeights(
                qkv=linear(f"{root}.attn.qkv"),
                proj=linear(f"{root}.attn.proj"),
                n_heads=vision_n_heads,
            ),
            ln2=layer_norm(f"{root}.norm2"),
            # No explicit `act` here, unlike the projection and text MLPs.
            mlp=MLPWeights(
                fc1=linear(f"{root}.mlp.fc1"),
                fc2=linear(f"{root}.mlp.fc2"),
            ),
        )

    def text_block(index: int) -> TextBlock:
        root = f"text_model.transformer.h.{index}"
        return TextBlock(
            ln=layer_norm(f"{root}.ln"),
            attn=AttentionWeights(
                qkv=linear(f"{root}.mixer.Wqkv"),
                proj=linear(f"{root}.mixer.out_proj"),
                n_heads=text_n_heads,
            ),
            mlp=MLPWeights(
                fc1=linear(f"{root}.mlp.fc1"),
                fc2=linear(f"{root}.mlp.fc2"),
                act="gelu_approx",
            ),
        )

    ## Vision encoder
    visual = "vision_encoder.encoder.model.visual"
    patch_emb = linear(f"{visual}.patch_embed.linear")
    # Patch edge length from the embedding's second dimension; presumably the
    # input is a flattened 3-channel square patch -- TODO confirm.
    patch_size = int(math.sqrt(patch_emb.weight.shape[1] // 3))
    pos_emb = get_tensor(f"{visual}.pos_embed")
    vision_post_ln = layer_norm(f"{visual}.norm")
    vision_layers = [vision_block(i) for i in range(vision_blocks)]
    projection = MLPWeights(
        fc1=linear("vision_encoder.projection.mlp.fc1"),
        fc2=linear("vision_encoder.projection.mlp.fc2"),
        act="gelu_approx",
    )
    vision = VisionModel(
        patch_size=patch_size,
        patch_emb=patch_emb,
        pos_emb=pos_emb,
        blocks=vision_layers,
        post_ln=vision_post_ln,
        proj_mlp=projection,
    )

    ## Text decoder model
    wte = get_tensor("text_model.transformer.embd.wte.weight")
    text_post_ln = layer_norm("text_model.lm_head.ln")
    lm_head = linear("text_model.lm_head.linear")
    text_layers = [text_block(i) for i in range(text_blocks)]
    text = TextModel(
        wte=wte, blocks=text_layers, post_ln=text_post_ln, lm_head=lm_head
    )

    return MoondreamModel(vision=vision, text=text)
def load_from_safetensors(
    safetensors_file: str,
    vision_blocks: int = 27,
    text_blocks: int = 24,
    **kwargs,
) -> MoondreamModel:
    """
    Load model weights from a safetensors file.

    Args:
        safetensors_file: Path of the safetensors file to read.
        vision_blocks: Number of vision encoder blocks to load.
        text_blocks: Number of text decoder blocks to load.
        **kwargs: Extra keyword arguments forwarded to `load_model`.

    Returns:
        The model produced by `load_model`.
    """
    # Every tensor is read while the file is open; the model is returned
    # after the `with` block has closed it.
    with safetensors_open(safetensors_file) as read_tensor:
        model = load_model(read_tensor, vision_blocks, text_blocks, **kwargs)
    return model
def load_from_pt(
    pt_file: str,
    vision_blocks: int = 27,
    text_blocks: int = 24,
    **kwargs,
) -> MoondreamModel:
    """
    Load model weights from a file readable by `torch.load`.

    Args:
        pt_file: Path of the checkpoint file.
        vision_blocks: Number of vision encoder blocks to load.
        text_blocks: Number of text decoder blocks to load.
        **kwargs: Extra keyword arguments forwarded to `load_model`.

    Returns:
        The model produced by `load_model`.
    """
    # Map the checkpoint onto whatever device a freshly created tensor gets.
    target_device = str(torch.empty(0).device)
    checkpoint = torch.load(pt_file, map_location=target_device, weights_only=True)

    # Strip every "._orig_mod" fragment from the key names and cast each
    # tensor to float16.
    half_weights = {}
    for name, value in checkpoint.items():
        half_weights[name.replace("._orig_mod", "")] = value.to(dtype=torch.float16)

    return load_model(
        half_weights.__getitem__, vision_blocks, text_blocks, **kwargs
    )
if __name__ == "__main__":
    # Manual smoke check: load ./model.safetensors and print the result.
    print(load_from_safetensors("model.safetensors"))