"""Predict an image's rotation angle with a ViT backbone plus a linear head.

A pretrained ``google/vit-base-patch16-224`` backbone (Hugging Face
``TFViTModel``) is topped with a single linear ``Dense`` unit. Weights for
the combined model are loaded from a Keras HDF5 file, and the scalar output
is reported by the CLI as the predicted angle.
"""

import argparse

import keras
import numpy as np
from PIL import Image
from transformers import TFViTModel, ViTImageProcessor

BASE_MODEL = "google/vit-base-patch16-224"
IMAGE_SIZE = 224
DEFAULT_WEIGHTS_PATH = "weights.h5"


class Inference:
    """Hold the ViT regression model and its matching image preprocessor."""

    def __init__(self, weights_path=DEFAULT_WEIGHTS_PATH):
        """Build the model, load its weights and load the preprocessor.

        Args:
            weights_path: Path to the weights file passed to
                ``keras.Model.load_weights`` for the full model (backbone +
                regression head). Defaults to ``"weights.h5"``, which was the
                previously hard-coded location.
        """
        self.vit_model = self._load_vit_model(weights_path)
        self.image_preprocessor = self._load_image_preprocessor()

    def predict_rotation(self, image_path):
        """Return the model's scalar prediction for the image at *image_path*.

        Args:
            image_path: Path to an image file readable by Pillow.

        Returns:
            Element ``[0][0]`` of ``model.predict`` for a batch of one image,
            i.e. the single output of the regression head. The CLI prints it
            as degrees; the actual unit depends on how the weights were
            trained.
        """
        X = self._preprocess(image_path)
        y = self.vit_model.predict(X)[0][0]
        return y

    def _preprocess(self, image_path):
        """Load an image and convert it into a one-image ``pixel_values`` batch.

        Args:
            image_path: Path to an image file readable by Pillow.

        Returns:
            A NumPy array produced by the ViT image processor for a batch
            containing just this image. It is expected to match the model
            input ``(3, IMAGE_SIZE, IMAGE_SIZE)`` per sample.
        """
        # Force 3-channel RGB. Grayscale ("L"), palette ("P") or RGBA images
        # would otherwise produce 1 or 4 channels and not fit the 3-channel
        # model input. The `with` block closes the underlying file handle;
        # convert()/resize() return new, already-loaded images.
        with Image.open(image_path) as img:
            rgb = img.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        pixels = np.array(rgb)
        # The model is TensorFlow/Keras, so request NumPy output. Asking for
        # PyTorch tensors ("pt") needlessly required torch to be installed.
        X_vit = self.image_preprocessor.preprocess(
            images=[pixels], return_tensors="np"
        )["pixel_values"]
        return np.asarray(X_vit)

    def _load_image_preprocessor(self):
        """Load the ViT image processor that matches ``BASE_MODEL``."""
        print("Loading Image Preprocessor")
        return ViTImageProcessor.from_pretrained(BASE_MODEL)

    def _load_vit_model(self, weights_path=DEFAULT_WEIGHTS_PATH):
        """Build the ViT + linear-head Keras model and load its weights.

        Args:
            weights_path: File passed to ``model.load_weights``.

        Returns:
            A ``keras.Model`` mapping a ``(3, IMAGE_SIZE, IMAGE_SIZE)`` input
            to a single linear output unit.
        """
        print("Loading Model")
        vit_base = TFViTModel.from_pretrained(BASE_MODEL)
        img_input = keras.layers.Input(shape=(3, IMAGE_SIZE, IMAGE_SIZE))
        x = vit_base.vit(img_input)
        # NOTE(review): x[-1] is the last element of the backbone output;
        # with the default config this is presumably the pooled output.
        # Confirm before changing, since the saved weights depend on it.
        y = keras.layers.Dense(1, activation="linear")(x[-1])
        model = keras.Model(inputs=img_input, outputs=y)
        # summary() prints the table itself and returns None. Wrapping it in
        # print() emitted a stray "None" line.
        model.summary()
        print("Loading Weights")
        model.load_weights(weights_path)
        return model


def main(argv=None):
    """CLI entry point: print the predicted angle for ``--image-path``."""
    parser = argparse.ArgumentParser(
        description="Predict the rotation angle of an image."
    )
    parser.add_argument("--image-path", type=str, required=True)
    parser.add_argument(
        "--weights-path",
        type=str,
        default=DEFAULT_WEIGHTS_PATH,
        help="Keras weights file for the model (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    inference = Inference(weights_path=args.weights_path)
    expected_angle = inference.predict_rotation(args.image_path)
    print(f"Predicted angle is about '{expected_angle}' degrees")


if __name__ == "__main__":
    main()