import json
import os
import tarfile
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, Generator, List, Optional, TypedDict, Union

import numpy as np
import onnx
import onnxruntime as ort
from PIL import Image
from tokenizers import Tokenizer

from .preprocess import create_patches


@dataclass
class EncodedImage:
    pos: int
    kv_caches: List[np.ndarray]


SamplingSettings = TypedDict(
    "SamplingSettings",
    {"max_tokens": int},
    total=False,
)

CaptionOutput = TypedDict(
    "CaptionOutput", {"caption": Union[str, Generator[str, None, None]]}
)
QueryOutput = TypedDict(
    "QueryOutput", {"answer": Union[str, Generator[str, None, None]]}
)

DEFAULT_MAX_TOKENS = 1024
MIN_SUPPORTED_VERSION = 1
MAX_SUPPORT_VERSION = 1


class Region:
    """Placeholder for an object detection result (not yet populated)."""


class VL:
    def __init__(self, model_path: Optional[str], ort_settings: Dict[str, Any] = {}):
        """
        Initialize the Moondream VL (Vision Language) model.

        Args:
            model_path (str): The path to the model file.
            ort_settings (Dict[str, Any]): Optional keyword arguments forwarded to
                onnxruntime.InferenceSession (e.g. session options or providers).

        Returns:
            None
        """
        if model_path is None or not os.path.isfile(model_path):
            raise ValueError("Model path is invalid or file does not exist.")

        if not tarfile.is_tarfile(model_path):
            raise ValueError(
                "Model format not recognized. You may need to upgrade the moondream"
                " package."
            )

        # Initialize all model components up front so the completeness checks
        # below fail with a clear assertion if any file is missing from the
        # archive.
        self.vision_encoder = None
        self.vision_projection = None
        self.text_encoder = None
        self.text_decoders = []
        self.tokenizer = None
        self.initial_kv_caches = None
        self.config = None

        with tarfile.open(model_path, "r:*") as tar:
            for member in tar.getmembers():
                name = member.name.split("/")[-1]

                f = tar.extractfile(member)
                if f is not None:
                    contents = f.read()
                else:
                    continue

                if name == "vision_encoder.onnx":
                    self.vision_encoder = ort.InferenceSession(contents, **ort_settings)
                elif name == "vision_projection.onnx":
                    self.vision_projection = ort.InferenceSession(
                        contents, **ort_settings
                    )
                elif name == "text_encoder.onnx":
                    self.text_encoder = ort.InferenceSession(contents, **ort_settings)
                elif "text_decoder" in name and name.endswith(".onnx"):
                    self.text_decoders.append(
                        ort.InferenceSession(contents, **ort_settings)
                    )
                elif name == "tokenizer.json":
                    self.tokenizer = Tokenizer.from_buffer(contents)
                elif name == "initial_kv_caches.npy":
                    self.initial_kv_caches = [x for x in np.load(BytesIO(contents))]
                elif name == "config.json":
                    self.config = json.loads(contents)

        assert self.vision_encoder is not None
        assert self.vision_projection is not None
        assert self.text_encoder is not None
        assert len(self.text_decoders) > 0
        assert self.tokenizer is not None
        assert self.initial_kv_caches is not None
        assert self.config is not None

        if not isinstance(self.config, dict) or "model_version" not in self.config:
            raise ValueError("Model format not recognized.")
        if (
            self.config["model_version"] < MIN_SUPPORTED_VERSION
            or self.config["model_version"] > MAX_SUPPORT_VERSION
        ):
            raise ValueError(
                "Model version not supported. You may need to upgrade the moondream"
                " package."
            )

        self.special_tokens = self.config["special_tokens"]
        self.templates = self.config["templates"]

    def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
        """
        Preprocess the image by running it through the model.

        This method is useful if the user wants to make multiple queries with the
        same image. The output is not guaranteed to be backward-compatible across
        version updates, and should not be persisted out of band.

        Args:
            image (Union[Image.Image, EncodedImage]): The input image to be encoded.

        Returns:
            EncodedImage: The encoded representation of the image.
        """
        if isinstance(image, EncodedImage):
            return image

        image_patches = create_patches(image)  # type: ignore

        patch_emb = self.vision_encoder.run(None, {"input": image_patches})[0]
        patch_emb = np.concatenate([patch_emb[0], patch_emb[1]], axis=-1)
        patch_emb = np.expand_dims(patch_emb, axis=0)
        (inputs_embeds,) = self.vision_projection.run(None, {"input": patch_emb})

        # Copy the list so the shared initial KV caches are not mutated between
        # calls; np.concatenate below creates new arrays, leaving the underlying
        # initial cache tensors untouched.
        kv_caches = list(self.initial_kv_caches)
        pos = inputs_embeds.shape[-2] + kv_caches[0].shape[-2]

        for i, decoder in enumerate(self.text_decoders):
            inputs_embeds, kv_cache_update = decoder.run(
                None,
                {
                    "inputs_embeds": inputs_embeds,
                    "kv_cache": kv_caches[i],
                },
            )
            kv_caches[i] = np.concatenate([kv_caches[i], kv_cache_update], axis=-2)

        return EncodedImage(pos=pos, kv_caches=kv_caches)
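
    # Illustrative sketch (not part of the API): when asking several questions
    # about the same image, encode it once and reuse the result. The file names
    # and question strings below are hypothetical placeholders.
    #
    #     model = VL("model.tar")
    #     encoded = model.encode_image(Image.open("photo.jpg"))
    #     model.query(encoded, "What color is the car?")
    #     model.query(encoded, "How many people are visible?")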

    def _generate(
        self, hidden: np.ndarray, encoded_image: EncodedImage, max_tokens: int
    ) -> Generator[str, None, None]:
        # Pre-allocate a fixed 2048-position KV cache buffer for each text
        # decoder, then copy in the caches produced while encoding the image.
        kv_caches = {
            i: np.zeros(
                (
                    *self.initial_kv_caches[0].shape[:-2],
                    2048,
                    self.initial_kv_caches[0].shape[-1],
                ),
                dtype=np.float16,
            )
            for i in range(len(self.text_decoders))
        }
        for i, kv_cache in kv_caches.items():
            kv_cache[:, :, :, :, : encoded_image.pos, :] = encoded_image.kv_caches[i]

        pos = encoded_image.pos
        generated_tokens = 0

        while generated_tokens < max_tokens:
            # Track the original T dimension of the input hidden states, so we can
            # bind the kv cache update accordingly. We can't check it just-in-time
            # because the final 'hidden' output is actually the model's logits.
            og_t = hidden.shape[-2]

            for i, decoder in enumerate(self.text_decoders):
                hidden, kv_cache_update = decoder.run(
                    None,
                    {
                        "inputs_embeds": hidden,
                        "kv_cache": kv_caches[i][:, :, :, :, :pos, :],
                    },
                )
                kv_caches[i][:, :, :, :, pos : pos + og_t, :] = kv_cache_update

            # Greedy decoding: take the highest-scoring token and stop at EOS.
            next_token = np.argmax(hidden, axis=-1)[0]
            if next_token == self.special_tokens["eos"]:
                break

            yield self.tokenizer.decode([next_token])
            generated_tokens += 1
            pos += og_t
            (hidden,) = self.text_encoder.run(None, {"input_ids": [[next_token]]})

    def caption(
        self,
        image: Union[Image.Image, EncodedImage],
        length: str = "normal",
        stream: bool = False,
        settings: Optional[SamplingSettings] = None,
    ) -> CaptionOutput:
        """
        Generate a caption for the input image.

        Args:
            image (Union[Image.Image, EncodedImage]): The input image to be captioned.
            length (str): The caption length to request; must be one of the lengths
                defined in the model's caption templates (defaults to "normal").
            stream (bool): If True, the caption is returned as a generator that
                yields chunks as they are produced.
            settings (Optional[SamplingSettings]): Optional settings for the caption
                generation. If not provided, default settings will be used.

        Returns:
            CaptionOutput: A dict with a "caption" key containing either the full
            caption string or, if stream=True, a generator of caption chunks.
        """
| if "caption" not in self.templates: | |
| raise ValueError("Model does not support captioning.") | |
| if length not in self.templates["caption"]: | |
| raise ValueError(f"Model does not support caption length '{length}'.") | |
| (input_embeds,) = self.text_encoder.run( | |
| None, {"input_ids": [self.templates["caption"][length]]} | |
| ) | |
| if settings is None: | |
| settings = {} | |
| max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS) | |
| encoded_image = self.encode_image(image) | |
| def generator(): | |
| for t in self._generate(input_embeds, encoded_image, max_tokens): | |
| yield t | |
| if stream: | |
| return {"caption": generator()} | |
| else: | |
| out = "" | |
| for t in generator(): | |
| out += t | |
| return {"caption": out} | |

    def query(
        self,
        image: Union[Image.Image, EncodedImage],
        question: str,
        stream: bool = False,
        settings: Optional[SamplingSettings] = None,
    ) -> QueryOutput:
        """
        Generate an answer to the input question about the input image.

        Args:
            image (Union[Image.Image, EncodedImage]): The input image to be queried.
            question (str): The question to be answered.
            stream (bool): If True, the answer is returned as a generator that
                yields chunks as they are produced.
            settings (Optional[SamplingSettings]): Optional settings for answer
                generation. If not provided, default settings will be used.

        Returns:
            QueryOutput: A dict with an "answer" key containing either the full
            answer string or, if stream=True, a generator of answer chunks.
        """
| if "query" not in self.templates: | |
| raise ValueError("Model does not support querying.") | |
| question_toks = ( | |
| self.templates["query"]["prefix"] | |
| + self.tokenizer.encode(question).ids | |
| + self.templates["query"]["suffix"] | |
| ) | |
| (input_embeds,) = self.text_encoder.run(None, {"input_ids": [question_toks]}) | |
| if settings is None: | |
| settings = {} | |
| max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS) | |
| encoded_image = self.encode_image(image) | |
| def generator(): | |
| for t in self._generate(input_embeds, encoded_image, max_tokens): | |
| yield t | |
| if stream: | |
| return {"answer": generator()} | |
| else: | |
| out = "" | |
| for t in generator(): | |
| out += t | |
| return {"answer": out} | |

    def detect(
        self, image: Union[Image.Image, EncodedImage], object: str
    ) -> List[Region]:
        """
        Detect and localize the specified object in the input image.

        Args:
            image (Union[Image.Image, EncodedImage]): The input image to be analyzed.
            object (str): The object to be detected in the image.

        Returns:
            List[Region]: A list of Region objects representing the detected
            instances of the specified object.
        """
        # Detection is not implemented in this client yet; an empty list is
        # returned regardless of the inputs.
        return []
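

if __name__ == "__main__":
    # Minimal usage sketch, assuming a model archive and an image path are passed
    # on the command line; both paths are placeholders supplied by the caller.
    import sys

    model = VL(sys.argv[1])
    image = Image.open(sys.argv[2])

    # Encode once, then reuse the encoded image for both captioning and querying.
    encoded = model.encode_image(image)
    print(model.caption(encoded)["caption"])
    print(model.query(encoded, "What is the subject of this image?")["answer"])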