File size: 2,747 Bytes
d3cd5c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
from queue import Queue
from threading import Thread

import torch
from PIL import Image
from transformers import AutoTokenizer, TextIteratorStreamer

from moondream.hf import LATEST_REVISION, Moondream, detect_device

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=False)
    parser.add_argument("--caption", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    args = parser.parse_args()

    if args.cpu:
        device = torch.device("cpu")
        dtype = torch.float32
    else:
        device, dtype = detect_device()
        if device != torch.device("cpu"):
            print("Using device:", device)
            print("If you run into issues, pass the `--cpu` flag to this script.")
            print()

    image_path = args.image
    prompt = args.prompt

    model_id = "vikhyatk/moondream2"
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
    moondream = Moondream.from_pretrained(
        model_id,
        revision=LATEST_REVISION,
        torch_dtype=dtype,
    ).to(device=device)
    moondream.eval()

    image = Image.open(image_path)

    if args.caption:
        print(moondream.caption(images=[image], tokenizer=tokenizer)[0])
    else:
        image_embeds = moondream.encode_image(image)

        if prompt is None:
            chat_history = ""

            while True:
                question = input("> ")

                result_queue = Queue()

                streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

                # Separate direct arguments from keyword arguments
                thread_args = (image_embeds, question, tokenizer, chat_history)
                thread_kwargs = {"streamer": streamer, "result_queue": result_queue}

                thread = Thread(
                    target=moondream.answer_question,
                    args=thread_args,
                    kwargs=thread_kwargs,
                )
                thread.start()

                buffer = ""
                for new_text in streamer:
                    buffer += new_text
                    if not new_text.endswith("<") and not new_text.endswith("END"):
                        print(buffer, end="", flush=True)
                        buffer = ""
                print(buffer)

                thread.join()

                answer = result_queue.get()
                chat_history += f"Question: {question}\n\nAnswer: {answer}\n\n"
        else:
            print(">", prompt)
            answer = moondream.answer_question(image_embeds, prompt, tokenizer)
            print(answer)