File size: 3,619 Bytes
d3cd5c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import json
import os

import torch
from PIL import Image
from transformers import AutoTokenizer

from .rope import precompute_freqs_cis
from .text import lm_head, text_decoder, text_encoder
from .vision import encode_image
from .weights import load_from_pt, load_from_safetensors

def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for single-image question answering."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", "-i", type=str, required=True)
    parser.add_argument("--prompt", "-p", type=str, required=True)
    parser.add_argument("--model", "-m", type=str, required=True)
    parser.add_argument("--config", "-c", type=str, default="{}")
    parser.add_argument("--max-tokens", "-t", type=int, default=200)
    # Reject an unknown sampler immediately, instead of after the model has
    # been loaded and a full forward pass has already run.
    parser.add_argument(
        "--sampler",
        "-s",
        type=str,
        default="greedy",
        choices=["greedy", "multinomial"],
    )
    return parser.parse_args()


def _select_default_device() -> None:
    """Make CUDA (else MPS) the default device for new tensors; otherwise stay on CPU."""
    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")


def _load_model(model_path: str, config: dict):
    """Load weights from a ``.pt`` or ``.safetensors`` file.

    ``config`` is forwarded as keyword arguments to the loader.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        ValueError: if the file extension is neither ``.pt`` nor ``.safetensors``.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")
    if model_path.endswith(".pt"):
        return load_from_pt(model_path, **config)
    if model_path.endswith(".safetensors"):
        return load_from_safetensors(model_path, **config)
    raise ValueError(f"Invalid model format: {model_path}")


def _load_image(image_path: str):
    """Open the image at ``image_path`` and resize it to 378x378.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")
    # Image.open loads lazily and keeps the file handle open; resize() forces
    # the pixel data to be read, so the handle can be closed on exit.
    with Image.open(image_path) as img:
        return img.resize((378, 378))


def _embed_prompt(prompt: str, tokenizer, image_tensor, text_model):
    """Build the decoder input: embeddings of [EOS, image, question template].

    The result is concatenated along dim 1 (the sequence axis), with the image
    embeddings spliced in directly after the leading EOS token.
    """
    templated = f"\n\nQuestion: {prompt}\n\nAnswer:"
    input_ids = tokenizer(templated, return_tensors="pt")["input_ids"]
    input_ids = torch.cat(
        [torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1
    )
    token_embeds = text_encoder(input_ids, text_model)
    return torch.cat(
        [
            token_embeds[:, 0:1, :],
            image_tensor.unsqueeze(0),
            token_embeds[:, 1:, :],
        ],
        dim=1,
    )


def _sample(logits, sampler: str):
    """Pick the next token id from ``logits`` using ``sampler``.

    NOTE(review): assumes ``logits`` is (1, vocab) for the last position, so
    both branches return a one-element tensor — confirm against ``lm_head``.

    Raises:
        ValueError: if ``sampler`` is not "greedy" or "multinomial".
    """
    if sampler == "multinomial":
        return torch.multinomial(
            torch.softmax(logits, dim=-1), num_samples=1
        ).squeeze(0)
    if sampler == "greedy":
        return torch.argmax(logits, dim=-1)
    raise ValueError(f"Invalid sampler: {sampler}")


def _generate(inputs_embeds, model, tokenizer, kv_cache, freqs_cis, max_tokens, sampler):
    """Decode up to ``max_tokens`` tokens autoregressively, streaming text to stdout.

    ``kv_cache`` is written in place along its second-to-last (sequence) axis.
    Generation stops at EOS, after ``max_tokens`` steps, or when the cache has
    no room left for the next step. A trailing newline is always printed.

    Raises:
        ValueError: if the prompt (including image embeddings) alone does not
            fit in ``kv_cache``.
    """
    capacity = kv_cache.size(-2)
    pos = 0
    generated = []  # token ids produced so far (excluding EOS)
    printed = ""  # decoded text already written to stdout

    for _ in range(max_tokens):
        n_new = inputs_embeds.size(1)
        if pos + n_new > capacity:
            if pos == 0:
                raise ValueError(
                    f"Prompt needs {n_new} positions but the KV cache "
                    f"holds only {capacity}"
                )
            break  # context window is full: stop as if max_tokens were reached

        hidden, kv_cache_update = text_decoder(
            inputs_embeds, model.text, kv_cache[:, :, :, :, :pos, :], freqs_cis
        )
        logits = lm_head(hidden, model.text)
        n_written = kv_cache_update.size(-2)
        kv_cache[:, :, :, :, pos : pos + n_written, :] = kv_cache_update
        pos += n_written

        next_token = _sample(logits, sampler)
        if next_token == tokenizer.eos_token_id:
            break

        generated.append(next_token.item())
        # Decode the whole generation rather than one token at a time: with
        # byte-level BPE a single UTF-8 character can span several tokens, and
        # decoding them separately yields U+FFFD replacement characters. Hold
        # output back while the tail is still an incomplete character.
        text = tokenizer.decode(generated)
        if not text.endswith("\ufffd"):
            print(text[len(printed):], end="", flush=True)
            printed = text

        inputs_embeds = text_encoder(next_token.unsqueeze(0), model.text)

    # Flush any text held back because generation ended mid-character.
    if generated:
        text = tokenizer.decode(generated)
        print(text[len(printed):], end="")
    print()


def main() -> None:
    """Answer ``--prompt`` about ``--image`` with the model stored at ``--model``."""
    args = _parse_args()
    _select_default_device()

    config = json.loads(args.config)
    text_n_heads = config.get("text_n_heads", 32)

    model = _load_model(args.model, config)
    image = _load_image(args.image)

    # Everything below is pure inference: keep autograd off for the image
    # encoder and prompt embedding as well, not only for the decode loop.
    with torch.no_grad():
        image_tensor = encode_image(image, model.vision)

        tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
        inputs_embeds = _embed_prompt(args.prompt, tokenizer, image_tensor, model.text)

        # Maximum number of positions held by the KV cache and rotary table.
        context_len = 2048
        # NOTE(review): 24 / 2 / 64 are presumably layers, K-and-V, and per-head
        # dim, and 32 the rotary dim — these were hard-coded in the original;
        # confirm they match the loaded checkpoint.
        kv_cache = torch.empty(
            24, 2, 1, text_n_heads, context_len, 64, dtype=torch.float16
        )
        freqs_cis = precompute_freqs_cis(32, context_len)

        _generate(
            inputs_embeds,
            model,
            tokenizer,
            kv_cache,
            freqs_cis,
            args.max_tokens,
            args.sampler,
        )


if __name__ == "__main__":
    main()