Chuckame commited on
Commit
181e20c
·
verified ·
1 Parent(s): 380a3cb

Create infer-rotation.py

Browse files
Files changed (1) hide show
  1. infer-rotation.py +57 -0
infer-rotation.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ from transformers import ViTImageProcessor, TFViTModel
5
+ import keras
6
+ import argparse
7
+
8
+ VIT_WEIGHTS_PATH = "model-vit-ang-loss.h5"
9
+ BASE_MODEL = "google/vit-base-patch16-224"
10
+ IMAGE_SIZE = 224
11
+
12
+ class Inference:
13
+ def __init__(self):
14
+ self.vit_model = self._load_vit_model()
15
+ self.image_preprocessor = self._load_image_preprocessor()
16
+
17
+ def predict_rotation(self, image_path):
18
+ X = self._preprocess(image_path)
19
+ y = self.vit_model.predict(X)[0][0]
20
+ return y
21
+
22
+ def _preprocess(self, image_path):
23
+ img = Image.open(image_path)
24
+ img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
25
+ img = np.array(img)
26
+
27
+ X_vit = self.image_preprocessor.preprocess(images=[img], return_tensors="pt")["pixel_values"]
28
+ return np.array(X_vit)
29
+
30
+ def _load_image_preprocessor(self):
31
+ print("Loading Image Preprocessor")
32
+ return ViTImageProcessor.from_pretrained(BASE_MODEL)
33
+
34
+ def _load_vit_model(self):
35
+ print("Loading Model")
36
+ vit_base = TFViTModel.from_pretrained(BASE_MODEL)
37
+
38
+ img_input = keras.layers.Input(shape=(3,IMAGE_SIZE, IMAGE_SIZE))
39
+ x = vit_base.vit(img_input)
40
+ y = keras.layers.Dense(1, activation="linear")(x[-1])
41
+
42
+ model = keras.Model(inputs=img_input, outputs=y)
43
+ print(model.summary())
44
+
45
+ print("Loading Weights")
46
+ model.load_weights(VIT_WEIGHTS_PATH)
47
+
48
+ return model
49
+
50
+ if __name__=="__main__":
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument("--image-path", type=str, required=True)
53
+ args = parser.parse_args()
54
+
55
+ model = Inference()
56
+ expected_angle = model.predict_rotation(args.image_path)
57
+ print(f"Predicted angle is about '{expected_angle}' degrees")