amaye15 committed on
Commit
3f75720
·
1 Parent(s): b74cd7b

Feat - Model, Requirements & Handler

Files changed (7)
  1. .gitignore +2 -0
  2. config.json +26 -0
  3. download.py +25 -0
  4. handler.py +131 -0
  5. model.safetensors +3 -0
  6. preprocessor_config.json +27 -0
  7. requirements.txt +27 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ *env*
+ *.DS_Store*
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "_name_or_path": "apple/aimv2-large-patch14-native",
+   "architectures": [
+     "AIMv2Model"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "apple/aimv2-large-patch14-native--configuration_aimv2.AIMv2Config",
+     "AutoModel": "apple/aimv2-large-patch14-native--modeling_aimv2.AIMv2Model",
+     "FlaxAutoModel": "apple/aimv2-large-patch14-native--modeling_flax_aimv2.FlaxAIMv2Model"
+   },
+   "hidden_size": 1024,
+   "intermediate_size": 2816,
+   "model_type": "aimv2",
+   "num_attention_heads": 8,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "num_queries": 256,
+   "patch_size": 14,
+   "projection_dropout": 0.0,
+   "qkv_bias": false,
+   "rms_norm_eps": 1e-05,
+   "torch_dtype": "float32",
+   "transformers_version": "4.48.1",
+   "use_bias": false
+ }
download.py ADDED
@@ -0,0 +1,25 @@
+ # Download the AIMv2 processor and model from the Hub and save local
+ # copies to the repository root, where handler.py expects to find them.
+
+ # import requests
+ # from PIL import Image
+ from transformers import AutoImageProcessor, AutoModel
+
+ # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ # image = Image.open(requests.get(url, stream=True).raw)
+
+ PATH = "."
+ MODEL_NAME = "apple/aimv2-large-patch14-native"
+
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
+ processor.save_pretrained(PATH)
+
+ model = AutoModel.from_pretrained(
+     MODEL_NAME,
+     trust_remote_code=True,  # AIMv2 uses custom modeling code from the Hub
+ )
+ model.save_pretrained(PATH)
+
+ # Optional smoke test:
+ # inputs = processor(images=image, return_tensors="pt")
+ # outputs = model(**inputs)
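
Running download.py materializes config.json, preprocessor_config.json, and model.safetensors in the repository root, which is why handler.py below loads from os.getcwd() rather than from the Hub. A minimal post-download sanity check, as a sketch (reloading from "." locally is an assumption about how you run it, not part of the commit):

# Sanity check after running download.py: the saved artifacts should
# exist locally and reload without fetching the weights again.
import os
from transformers import AutoImageProcessor, AutoModel

for name in ("config.json", "preprocessor_config.json", "model.safetensors"):
    assert os.path.exists(name), f"missing {name}; run download.py first"

processor = AutoImageProcessor.from_pretrained(".")
# trust_remote_code is still required: auto_map in config.json points at
# custom AIMv2 modeling code hosted on the Hub.
model = AutoModel.from_pretrained(".", trust_remote_code=True)
print(type(processor).__name__, type(model).__name__)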
handler.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ from typing import Dict, Any, List
+ from PIL import Image
+ import base64
+ from io import BytesIO
+ import logging
+ from transformers import AutoImageProcessor, AutoModel
+ import os
+ from dataclasses import dataclass
+
+
+ # Define a dataclass for the results
+ @dataclass
+ class ImageEncodingResult:
+     image_encoded: List[List[float]]  # Full encoded embeddings
+     image_encoded_average: List[float]  # Average of the embeddings
+
+
+ class EndpointHandler:
+     """
+     A handler class for processing images and generating embeddings using a pre-trained model.
+     Attributes:
+         processor: The pre-trained image processor.
+         model: The pre-trained model for generating embeddings.
+         device: The device (CPU or CUDA) used to run model inference.
+     """
+
+     def __init__(self):
+         """
+         Initializes the EndpointHandler with the model and processor from the current directory.
+         """
+         # Initialize logging
+         logging.basicConfig(level=logging.INFO)
+         self.logger = logging.getLogger(__name__)
+
+         # Determine the device (CPU or CUDA)
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.logger.info(f"Using device: {self.device}")
+
+         # Load the model and processor from the current directory
+         self.logger.info("Loading model and processor from the current directory.")
+         try:
+             self.processor = AutoImageProcessor.from_pretrained(os.getcwd())
+             self.model = AutoModel.from_pretrained(
+                 os.getcwd(), trust_remote_code=True
+             ).to(self.device)
+             self.logger.info("Model and processor loaded successfully.")
+         except Exception as e:
+             self.logger.error(f"Failed to load model or processor: {e}")
+             raise
+
+     def _resize_image_if_large(
+         self, image: Image.Image, max_size: int = 1080
+     ) -> Image.Image:
+         """
+         Resizes an image if its dimensions exceed the specified maximum size.
+         Args:
+             image (Image.Image): Input image.
+             max_size (int): Maximum size for the image dimensions.
+         Returns:
+             Image.Image: Resized image.
+         """
+         width, height = image.size
+         if width > max_size or height > max_size:
+             scale = max_size / max(width, height)
+             new_width = int(width * scale)
+             new_height = int(height * scale)
+             image = image.resize((new_width, new_height), resample=Image.BILINEAR)
+         return image
+
+     def _encode_image(self, image: Image.Image) -> ImageEncodingResult:
+         """
+         Encodes an image into embeddings using the model.
+         Args:
+             image (Image.Image): Input image.
+         Returns:
+             ImageEncodingResult: Dataclass containing the encoded embeddings and their average.
+         """
+         try:
+             # Resize the image if necessary
+             image = self._resize_image_if_large(image)
+
+             # Process the image and generate embeddings
+             inputs = self.processor(images=image, return_tensors="pt").to(self.device)
+             with torch.inference_mode():
+                 outputs = self.model(**inputs)
+             last_hidden_state = outputs.last_hidden_state
+             image_encoded = last_hidden_state.squeeze().tolist()
+             image_encoded_average = last_hidden_state.mean(dim=1).squeeze().tolist()
+
+             return ImageEncodingResult(
+                 image_encoded=image_encoded,
+                 image_encoded_average=image_encoded_average,
+             )
+         except Exception as e:
+             self.logger.error(f"Error encoding image: {e}")
+             raise
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Processes input data containing base64-encoded images and generates embeddings.
+         Args:
+             data (Dict[str, Any]): Dictionary containing input images.
+         Returns:
+             Dict[str, Any]: Dictionary containing encoded embeddings or error messages.
+         """
+         images_data = data.get("images", [])
+
+         if not images_data:
+             return {"error": "No image data provided."}
+
+         results = []
+         for img_data in images_data:
+             if isinstance(img_data, str):
+                 try:
+                     # Decode the base64-encoded image
+                     image_bytes = base64.b64decode(img_data)
+                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
+
+                     # Encode the image
+                     encoded_image = self._encode_image(image)
+                     results.append(encoded_image)
+                 except Exception as e:
+                     self.logger.error(f"Invalid image data: {e}")
+                     return {"error": f"Invalid image data: {e}"}
+             else:
+                 self.logger.error("Images should be base64-encoded strings.")
+                 return {"error": "Images should be base64-encoded strings."}
+
+         # Convert the results to a dictionary for JSON serialization
+         return {"results": [result.__dict__ for result in results]}
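
The handler expects a payload of the form {"images": [<base64-encoded image>, ...]} and returns, per image, the full patch-token embeddings plus their mean. A minimal local driver for testing, as a sketch (example.jpg is a placeholder path; any RGB image works):

# Local test of EndpointHandler outside an Inference Endpoint.
import base64
from handler import EndpointHandler

with open("example.jpg", "rb") as f:  # placeholder image path
    payload = {"images": [base64.b64encode(f.read()).decode("utf-8")]}

handler = EndpointHandler()
out = handler(payload)
emb = out["results"][0]
print(len(emb["image_encoded"]))          # number of patch tokens
print(len(emb["image_encoded_average"]))  # hidden size (1024 per config.json)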
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6cdc4c4ea6f2a477edebb482cc36ba021409a313eabdf3e6be62eb722771e7d1
+ size 1235760720
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "crop_size": {
+     "height": 224,
+     "width": 224
+   },
+   "do_center_crop": false,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": false,
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_processor_type": "CLIPImageProcessor",
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "shortest_edge": 224
+   }
+ }
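
The image_mean/image_std values are the standard CLIP normalization constants and rescale_factor is exactly 1/255. With do_resize and do_center_crop disabled, images reach the model at whatever resolution they arrive in (the patch14-native variant is built for native-resolution inputs), which is why handler.py applies its own 1080-pixel cap. A hand-written equivalent of this pixel transform, as a sketch (example.jpg is a placeholder):

# What CLIPImageProcessor does with this config, written out by hand:
# rescale by 1/255, normalize with the CLIP mean/std, convert HWC -> NCHW.
# No resize or crop is applied (do_resize / do_center_crop are false).
import numpy as np
from PIL import Image

MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

img = np.asarray(Image.open("example.jpg").convert("RGB"), dtype=np.float32)
pixels = (img / 255.0 - MEAN) / STD       # HWC, float32
pixels = pixels.transpose(2, 0, 1)[None]  # -> 1 x C x H x W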
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ accelerate==1.3.0
+ certifi==2024.12.14
+ charset-normalizer==3.4.1
+ filelock==3.17.0
+ fsspec==2024.12.0
+ huggingface-hub==0.27.1
+ idna==3.10
+ Jinja2==3.1.5
+ MarkupSafe==3.0.2
+ mpmath==1.3.0
+ networkx==3.4.2
+ numpy==2.2.2
+ packaging==24.2
+ pillow==11.1.0
+ psutil==6.1.1
+ PyYAML==6.0.2
+ regex==2024.11.6
+ requests==2.32.3
+ safetensors==0.5.2
+ setuptools==75.8.0
+ sympy==1.13.1
+ tokenizers==0.21.0
+ torch==2.5.1
+ tqdm==4.67.1
+ transformers==4.48.1
+ typing_extensions==4.12.2
+ urllib3==2.3.0
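
These pins install with pip install -r requirements.txt. Note that transformers==4.48.1 matches the transformers_version recorded in config.json, while torch, pillow, and safetensors cover the imports in handler.py and the LFS weight file.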