ngxquang committed on
Commit
2e5d83b
·
1 Parent(s): cfe3897

feat: update clip api for deployment

Browse files
.env ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PROJECT INFORMATION
2
+ HOST=0.0.0.0
3
+ PORT=7860
4
+ CORS_HEADERS=["*"]
5
+ CORS_ORIGINS=["http://localhost"]
6
+
7
+ MODEL_NAME="ViT-B/32"
8
+ DEVICE="cpu" # ["cuda", "cpu"]
9
+
10
+ INDEX_FILE_PATH="data/faiss-index/index_clip_L01_to_L20.faiss"
11
+ KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups_L01_to_L20.json"
.env.example ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PROJECT INFORMATION
2
+ HOST=0.0.0.0
3
+ PORT=8000
4
+ CORS_HEADERS=["*"]
5
+ CORS_ORIGINS=["http://localhost"]
6
+
7
+ MODEL_NAME="ViT-B/32"
8
+ DEVICE="cpu" # ["cuda", "cpu"]
9
+
10
+ INDEX_FILE_PATH="data/faiss-index/index.faiss"
11
+ KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups.json"
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.faiss filter=lfs diff=lfs merge=lfs -text
37
+ *.json filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8-slim
2
+
3
+ RUN apt-get update && \
4
+ apt-get install git gsutil -y && \
5
+ apt clean && \
6
+ rm -rf /var/cache/apt/*
7
+
8
+ WORKDIR /code
9
+
10
+ COPY requirements.txt /code/requirements.txt
11
+
12
+ # PYTHONDONTWRITEBYTECODE=1: Disables the creation of .pyc files (compiled bytecode)
13
+ # PYTHONUNBUFFERED=1: Disables buffering of the standard output stream
14
+ # PYTHONIOENCODING: specifies the encoding to be used for the standard input, output, and error streams
15
+ ENV PYTHONDONTWRITEBYTECODE=1 \
16
+ PYTHONUNBUFFERED=1 \
17
+ PYTHONIOENCODING=utf-8
18
+
19
+ RUN pip install -U pip && \
20
+ pip install --no-cache-dir -r /code/requirements.txt
21
+
22
+ RUN useradd -m -u 1000 user
23
+
24
+ USER user
25
+
26
+ ENV HOME=/home/user \
27
+ PATH=/home/user/.local/bin:$PATH
28
+
29
+ WORKDIR $HOME/app
30
+
31
+ COPY --chown=user . $HOME/app
32
+
33
+ # Download index
34
+ # RUN mkdir ./data/faiss-index/ && \
35
+ # gsutil -m cp "gs://thangtd1/faiss-index/index_clip_L01_to_L20.faiss" ./data/faiss-index/
36
+
37
+ CMD python ./src/main.py
data/config/keyframes_groups_L01_to_L20.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8d2f52efda68fed4a80512ecfe30a90e65663da396b61b8de3db11433cd65f3
3
+ size 17780893
data/faiss-index/index_clip_L01_to_L20.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c294320f4b8cb934f57f199500477324dd57a2e6445db375f02937e5a2fcf19
3
+ size 413999149
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.103.1
2
+ uvicorn==0.23.2
3
+ pydantic-settings==2.0.3
4
+
5
+
6
+ # Models
7
+ torch==1.7.1
8
+ torchvision==0.8.2
9
+ ftfy==6.1.1
10
+ regex
11
+ tqdm==4.66.1
12
+ git+https://github.com/openai/CLIP.git@main
13
+
14
+ # Vector Database
15
+ faiss-cpu
16
+
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from pydantic_settings import BaseSettings
4
+
5
# Path of this module as reported by the interpreter.
FILE = Path(__file__)
# Two directory levels above this file (src/config.py -> project root).
ROOT = FILE.parent.parent


class Settings(BaseSettings):
    """Application configuration loaded through pydantic-settings.

    Fields declared without a default have to be provided from outside
    (the ``.env`` file referenced in ``Config`` supplies them in this
    project); fields with a default may be omitted.
    """

    # API SETTINGS
    HOST: str
    PORT: int
    CORS_ORIGINS: list
    CORS_HEADERS: list

    # MODEL SETTINGS
    MODEL_NAME: str = "ViT-B/32"
    DEVICE: str = "cpu"

    # FAISS DATABASE SETTINGS
    INDEX_FILE_PATH: str
    KEYFRAMES_GROUPS_JSON_PATH: str

    class Config:
        # Read values from the .env file that sits in the project root.
        env_file = ROOT / ".env"


# Built once at import time; the rest of the application imports this object.
settings = Settings()
src/itr/__init__.py ADDED
File without changes
src/itr/dtb_cursor.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import lru_cache
3
+
4
+ import faiss
5
+
6
+
7
class DatabaseCursor:
    """Wrap a FAISS index together with the metadata of the indexed keyframes.

    Attributes:
        index: FAISS index read from ``index_file_path``.
        keyframes_group_info: Parsed JSON metadata, looked up with the ids
            returned by the FAISS search (presumably a list ordered like the
            index vectors -- confirm against the shipped JSON file).
    """

    def __init__(self, index_file_path: str, keyframes_groups_json_path: str):
        self._load_index(index_file_path)
        self._load_keyframes_groups_info(keyframes_groups_json_path)

    # NOTE: the loaders are deliberately NOT wrapped in ``lru_cache``. They
    # work purely through side effects on ``self``; caching them would only
    # memoise ``None`` and keep every instance alive inside the cache.
    def _load_index(self, index_file_path):
        """Read the FAISS index from disk into ``self.index``."""
        self.index = faiss.read_index(index_file_path)

    def _load_keyframes_groups_info(self, keyframes_groups_json_path: str):
        """Parse the keyframe metadata JSON into ``self.keyframes_group_info``."""
        with open(keyframes_groups_json_path, encoding="utf-8") as file:
            self.keyframes_group_info = json.load(file)

    def kNN_search(self, query_vector, topk: int = 10):
        """Return metadata for the ``topk`` nearest neighbours of a query.

        Args:
            query_vector: Query embedding(s) in the layout expected by
                ``index.search``; only the first query row is reported.
            topk: Number of neighbours requested from the index.

        Returns:
            A list of dicts, one per hit in the order returned by the index.
            Each dict is a copy of the stored metadata with an extra
            ``"distance"`` key holding the distance as a string.
        """
        distances, ids = self.index.search(query_vector, topk)
        results = []
        for distance, idx in zip(distances[0], ids[0]):
            # FAISS fills the id slots with -1 when it finds fewer than
            # ``topk`` neighbours; indexing with -1 would silently return the
            # last metadata entry, so those slots are skipped.
            if idx < 0:
                continue
            # Copy the entry so the shared metadata is never mutated; writing
            # "distance" into it would leak values between requests.
            frame_detail = dict(self.keyframes_group_info[idx])
            frame_detail["distance"] = str(distance)
            results.append(frame_detail)
        return results
src/itr/router.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, File, status
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+
5
+ from .dtb_cursor import DatabaseCursor
6
+ from .vlm_model import VisionLanguageModel
7
+
8
+
9
class Item(BaseModel):
    """Request body of the image-text retrieval endpoint."""

    # Text query; it is embedded by the vision-language model before searching.
    query_text: str
    # Number of nearest neighbours requested from the vector index.
    topk: int
12
+
13
+
14
# Router that carries the retrieval endpoint; the application includes it.
router = APIRouter()


# Module-level singletons. Both stay ``None`` until init_vectordb() and
# init_model() have been called.
vectordb_cursor = None
vlm_model = None
19
+
20
+
21
def init_vectordb(**kargs):
    """Create the shared ``DatabaseCursor`` on first call; later calls are no-ops.

    All keyword arguments are forwarded unchanged to ``DatabaseCursor``.
    """
    global vectordb_cursor
    # Singleton: an existing cursor is never replaced.
    if vectordb_cursor is not None:
        return
    vectordb_cursor = DatabaseCursor(**kargs)
26
+
27
+
28
def init_model(**kargs):
    """Create the shared ``VisionLanguageModel`` on first call; later calls are no-ops.

    All keyword arguments are forwarded unchanged to ``VisionLanguageModel``.
    """
    global vlm_model
    # Singleton: an existing model is never replaced.
    if vlm_model is not None:
        return
    vlm_model = VisionLanguageModel(**kargs)
33
+
34
+
35
@router.post("/retrieval/image-text")
async def retrieve(item: Item) -> JSONResponse:
    """Embed the text query and return its nearest keyframes.

    Args:
        item: Request body with the query text and the number of results.

    Returns:
        200 with ``{"message": "success", "details": [...]}`` on success,
        422 when ``topk`` is not a positive integer,
        500 with ``{"message": "Search error"}`` when embedding or search fails.
    """
    # Local import: this module's header does not import logging.
    import logging

    # Reject a non-positive result count up front instead of letting it
    # surface as an opaque search failure.
    if item.topk < 1:
        return JSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content={"message": "topk must be a positive integer"},
        )

    try:
        query_vector = vlm_model.get_embedding(input=item.query_text)
        search_results = vectordb_cursor.kNN_search(query_vector, item.topk)
    except Exception:
        # Keep the generic client-facing message but record the real cause,
        # which was previously discarded.
        logging.getLogger(__name__).exception("Image-text retrieval failed")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"message": "Search error"},
        )

    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={"message": "success", "details": search_results},
    )
src/itr/vlm_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Union
3
+
4
+ import clip
5
+ from PIL import Image
6
+
7
+
8
class VisionLanguageModel:
    """CLIP wrapper that maps a text string or a PIL image to a normalised embedding.

    Attributes:
        model: CLIP model returned by ``clip.load``.
        processor: Image preprocessing callable returned by ``clip.load``.
        device: Device string the inputs are moved to before encoding.
    """

    def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
        self.device = device
        self._load_model(model_name, device)

    # NOTE: deliberately not wrapped in ``lru_cache``: the method only sets
    # attributes on ``self``, so caching it would memoise ``None`` and keep
    # the instance (and its model weights) alive inside the cache.
    def _load_model(self, model_name, device: str = "cpu"):
        """Load the CLIP weights and the matching image preprocessing."""
        self.model, self.processor = clip.load(model_name, device=device)

    def get_embedding(self, input: Union[str, Image.Image]):
        """Encode ``input`` and return its L2-normalised embedding.

        Args:
            input: Either a text query or a PIL image.

        Returns:
            A float32 NumPy array for both input kinds, so the result can be
            handed directly to the vector index.

        Raises:
            TypeError: If ``input`` is neither ``str`` nor ``PIL.Image.Image``.
        """
        # Local import: this module's header does not import torch itself.
        import torch

        if isinstance(input, str):
            batch = clip.tokenize(input).to(self.device)
            encode = self.model.encode_text
        elif isinstance(input, Image.Image):
            # The preprocessing callable is stored as ``self.processor``;
            # the former ``self.preprocess`` attribute never existed.
            batch = self.processor(input).unsqueeze(0).to(self.device)
            encode = self.model.encode_image
        else:
            raise TypeError("Invalid input type")

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            vector = encode(batch)
            vector = vector / vector.norm(dim=-1, keepdim=True)
        return vector.cpu().numpy().astype("float32")
src/main.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from config import settings
3
+ from fastapi import FastAPI, Request, status
4
+ from fastapi.exceptions import RequestValidationError
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from fastapi.responses import JSONResponse, RedirectResponse
7
+ from itr.router import init_model, init_vectordb
8
+ from itr.router import router as router
9
+
10
app = FastAPI(title="Text-to-image Retrieval API")


# CORS policy: allowed origins and headers come from the settings object;
# every HTTP method is allowed and credentials are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_headers=settings.CORS_HEADERS,
    allow_credentials=True,
    allow_methods=["*"],
)
20
+
21
+
22
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Convert request-validation failures into a compact JSON error list.

    Args:
        request: Incoming request (unused, required by the handler signature).
        exc: Validation error raised while parsing the request.

    Returns:
        A 422 response whose ``message`` holds one ``{"error": ...}`` entry
        per validation problem, combining its message and location.
    """
    error_details = [
        {"error": f"{error['msg']} {str(error['loc'])}"} for error in exc.errors()
    ]
    # Without an explicit status code JSONResponse answers 200, which would
    # report an invalid request as a success.
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"message": error_details},
    )
31
+
32
+
33
@app.on_event("startup")
async def startup_event():
    """Load the vector index and the embedding model when the app starts."""
    init_vectordb(
        index_file_path=settings.INDEX_FILE_PATH,
        keyframes_groups_json_path=settings.KEYFRAMES_GROUPS_JSON_PATH,
    )

    # CUDA is used only when it is both requested and actually available;
    # the availability probe is skipped when the CPU is configured.
    use_cuda = settings.DEVICE == "cuda" and torch.cuda.is_available()
    if use_cuda:
        device = "cuda"
    else:
        device = "cpu"
    init_model(model_name=settings.MODEL_NAME, device=device)
43
+
44
+
45
@app.get("/", include_in_schema=False)
async def root() -> RedirectResponse:
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse("/docs")
48
+
49
+
50
@app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
async def perform_healthcheck() -> JSONResponse:
    """Liveness probe: always answers ``{"message": "success"}``."""
    return JSONResponse(content={"message": "success"})
53
+
54
+
55
app.include_router(router)


# Start API
# The Dockerfile launches the service with ``python ./src/main.py``, so the
# server has to be started from here; with this guard commented out the
# process would import the module and exit without serving anything.
if __name__ == "__main__":
    import uvicorn

    # The app object is passed directly (no auto-reload) so the module is not
    # imported a second time under the name "main".
    uvicorn.run(app, host=settings.HOST, port=settings.PORT)