# SeeForMe-Live / sample.py
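#
# Interactive demo for the moondream2 vision-language model: caption an image,
# answer a single question about it, or chat about it in an interactive loop.
#
# Example invocations (the image path is a placeholder):
#   python sample.py --image photo.jpg --caption
#   python sample.py --image photo.jpg --prompt "What is in this image?"
#   python sample.py --image photo.jpg
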
import argparse
from queue import Queue
from threading import Thread

import torch
from PIL import Image
from transformers import AutoTokenizer, TextIteratorStreamer

from moondream.hf import LATEST_REVISION, Moondream, detect_device

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=False)
    parser.add_argument("--caption", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    args = parser.parse_args()

    if args.cpu:
        device = torch.device("cpu")
        dtype = torch.float32
    else:
        device, dtype = detect_device()
        if device != torch.device("cpu"):
            print("Using device:", device)
            print("If you run into issues, pass the `--cpu` flag to this script.")
            print()
    image_path = args.image
    prompt = args.prompt

    model_id = "vikhyatk/moondream2"
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
    moondream = Moondream.from_pretrained(
        model_id,
        revision=LATEST_REVISION,
        torch_dtype=dtype,
    ).to(device=device)
    moondream.eval()

    image = Image.open(image_path)
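    # With --caption, print a single caption; otherwise encode the image once and
    # answer questions about it, either interactively or from a single --prompt.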
    if args.caption:
        print(moondream.caption(images=[image], tokenizer=tokenizer)[0])
    else:
        image_embeds = moondream.encode_image(image)

        if prompt is None:
            chat_history = ""

            while True:
                question = input("> ")

                result_queue = Queue()
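                # answer_question runs in a background thread; TextIteratorStreamer
                # yields text chunks as they are generated so the answer can be
                # printed incrementally, and the complete answer is handed back
                # through result_queue once generation finishes.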
                streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

                # Separate direct arguments from keyword arguments
                thread_args = (image_embeds, question, tokenizer, chat_history)
                thread_kwargs = {"streamer": streamer, "result_queue": result_queue}

                thread = Thread(
                    target=moondream.answer_question,
                    args=thread_args,
                    kwargs=thread_kwargs,
                )
                thread.start()

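                # Hold back chunks that end in "<" or "END" (pieces of the
                # end-of-sequence marker) so partial marker text is not flushed
                # mid-stream; whatever remains is printed after the loop.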
                buffer = ""
                for new_text in streamer:
                    buffer += new_text
                    if not new_text.endswith("<") and not new_text.endswith("END"):
                        print(buffer, end="", flush=True)
                        buffer = ""
                print(buffer)

                thread.join()

                answer = result_queue.get()
                chat_history += f"Question: {question}\n\nAnswer: {answer}\n\n"
        else:
            print(">", prompt)
            answer = moondream.answer_question(image_embeds, prompt, tokenizer)
            print(answer)