Spaces:
Runtime error
Runtime error
ngxquang
commited on
Commit
·
2e5d83b
1
Parent(s):
cfe3897
feat: update clip api for deployment
Browse files- .env +11 -0
- .env.example +11 -0
- .gitattributes +2 -0
- Dockerfile +37 -0
- data/config/keyframes_groups_L01_to_L20.json +3 -0
- data/faiss-index/index_clip_L01_to_L20.faiss +3 -0
- requirements.txt +16 -0
- src/__init__.py +0 -0
- src/config.py +28 -0
- src/itr/__init__.py +0 -0
- src/itr/dtb_cursor.py +28 -0
- src/itr/router.py +49 -0
- src/itr/vlm_model.py +30 -0
- src/main.py +62 -0
.env
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PROJECT INFORMATION
|
2 |
+
HOST=0.0.0.0
|
3 |
+
PORT=7860
|
4 |
+
CORS_HEADERS=["*"]
|
5 |
+
CORS_ORIGINS=["http://localhost"]
|
6 |
+
|
7 |
+
MODEL_NAME="ViT-B/32"
|
8 |
+
DEVICE="cpu" # ["cuda", "cpu"]
|
9 |
+
|
10 |
+
INDEX_FILE_PATH="data/faiss-index/index_clip_L01_to_L20.faiss"
|
11 |
+
KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups_L01_to_L20.json"
|
.env.example
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PROJECT INFORMATION
|
2 |
+
HOST=0.0.0.0
|
3 |
+
PORT=8000
|
4 |
+
CORS_HEADERS=["*"]
|
5 |
+
CORS_ORIGINS=["http://localhost"]
|
6 |
+
|
7 |
+
MODEL_NAME="ViT-B/32"
|
8 |
+
DEVICE="cpu" # ["cuda", "cpu"]
|
9 |
+
|
10 |
+
INDEX_FILE_PATH="data/faiss-index/index.faiss"
|
11 |
+
KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups.json"
|
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.faiss filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.8-slim
|
2 |
+
|
3 |
+
RUN apt-get update && \
|
4 |
+
apt-get install git gsutil -y && \
|
5 |
+
apt clean && \
|
6 |
+
rm -rf /var/cache/apt/*
|
7 |
+
|
8 |
+
WORKDIR /code
|
9 |
+
|
10 |
+
COPY requirements.txt /code/requirements.txt
|
11 |
+
|
12 |
+
# PYTHONDONTWRITEBYTECODE=1: Disables the creation of .pyc files (compiled bytecode)
|
13 |
+
# PYTHONUNBUFFERED=1: Disables buffering of the standard output stream
|
14 |
+
# PYTHONIOENCODING: specifies the encoding to be used for the standard input, output, and error streams
|
15 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
16 |
+
PYTHONUNBUFFERED=1 \
|
17 |
+
PYTHONIOENCODING=utf-8
|
18 |
+
|
19 |
+
RUN pip install -U pip && \
|
20 |
+
pip install --no-cache-dir -r /code/requirements.txt
|
21 |
+
|
22 |
+
RUN useradd -m -u 1000 user
|
23 |
+
|
24 |
+
USER user
|
25 |
+
|
26 |
+
ENV HOME=/home/user \
|
27 |
+
PATH=/home/user/.local/bin:$PATH
|
28 |
+
|
29 |
+
WORKDIR $HOME/app
|
30 |
+
|
31 |
+
COPY --chown=user . $HOME/app
|
32 |
+
|
33 |
+
# Download index
|
34 |
+
# RUN mkdir ./data/faiss-index/ && \
|
35 |
+
# gsutil -m cp "gs://thangtd1/faiss-index/index_clip_L01_to_L20.faiss" ./data/faiss-index/
|
36 |
+
|
37 |
+
CMD python ./src/main.py
|
data/config/keyframes_groups_L01_to_L20.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8d2f52efda68fed4a80512ecfe30a90e65663da396b61b8de3db11433cd65f3
|
3 |
+
size 17780893
|
data/faiss-index/index_clip_L01_to_L20.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c294320f4b8cb934f57f199500477324dd57a2e6445db375f02937e5a2fcf19
|
3 |
+
size 413999149
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi==0.103.1
|
2 |
+
uvicorn==0.23.2
|
3 |
+
pydantic-settings==2.0.3
|
4 |
+
|
5 |
+
|
6 |
+
# Models
|
7 |
+
torch==1.7.1
|
8 |
+
torchvision==0.8.2
|
9 |
+
ftfy==6.1.1
|
10 |
+
regex
|
11 |
+
tqdm==4.66.1
|
12 |
+
git+https://github.com/openai/CLIP.git@main
|
13 |
+
|
14 |
+
# Vector Database
|
15 |
+
faiss-cpu
|
16 |
+
|
src/__init__.py
ADDED
File without changes
|
src/config.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from pydantic_settings import BaseSettings
|
4 |
+
|
5 |
+
FILE = Path(__file__)
|
6 |
+
ROOT = FILE.parent.parent
|
7 |
+
|
8 |
+
|
9 |
+
class Settings(BaseSettings):
|
10 |
+
# API SETTINGS
|
11 |
+
HOST: str
|
12 |
+
PORT: int
|
13 |
+
CORS_ORIGINS: list
|
14 |
+
CORS_HEADERS: list
|
15 |
+
|
16 |
+
# MODEL SETTINGS
|
17 |
+
MODEL_NAME: str = "ViT-B/32"
|
18 |
+
DEVICE: str = "cpu"
|
19 |
+
|
20 |
+
# FAISS DATABASE SETTINGS
|
21 |
+
INDEX_FILE_PATH: str
|
22 |
+
KEYFRAMES_GROUPS_JSON_PATH: str
|
23 |
+
|
24 |
+
class Config:
|
25 |
+
env_file = ROOT / ".env"
|
26 |
+
|
27 |
+
|
28 |
+
settings = Settings()
|
src/itr/__init__.py
ADDED
File without changes
|
src/itr/dtb_cursor.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from functools import lru_cache
|
3 |
+
|
4 |
+
import faiss
|
5 |
+
|
6 |
+
|
7 |
+
class DatabaseCursor:
|
8 |
+
def __init__(self, index_file_path: str, keyframes_groups_json_path: str):
|
9 |
+
self._load_index(index_file_path)
|
10 |
+
self._load_keyframes_groups_info(keyframes_groups_json_path)
|
11 |
+
|
12 |
+
@lru_cache(maxsize=1)
|
13 |
+
def _load_index(self, index_file_path):
|
14 |
+
self.index = faiss.read_index(index_file_path)
|
15 |
+
|
16 |
+
@lru_cache(maxsize=1)
|
17 |
+
def _load_keyframes_groups_info(self, keyframes_groups_json_path: str):
|
18 |
+
with open(keyframes_groups_json_path) as file:
|
19 |
+
self.keyframes_group_info = json.loads(file.read())
|
20 |
+
|
21 |
+
def kNN_search(self, query_vector: str, topk: int = 10):
|
22 |
+
results = []
|
23 |
+
distances, ids = self.index.search(query_vector, topk)
|
24 |
+
for i in range(len(ids[0])):
|
25 |
+
frame_detail = self.keyframes_group_info[ids[0][i]]
|
26 |
+
frame_detail["distance"] = str(distances[0][i])
|
27 |
+
results.append(frame_detail)
|
28 |
+
return results
|
src/itr/router.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, File, status
|
2 |
+
from fastapi.responses import JSONResponse
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
from .dtb_cursor import DatabaseCursor
|
6 |
+
from .vlm_model import VisionLanguageModel
|
7 |
+
|
8 |
+
|
9 |
+
class Item(BaseModel):
|
10 |
+
query_text: str
|
11 |
+
topk: int
|
12 |
+
|
13 |
+
|
14 |
+
router = APIRouter()
|
15 |
+
|
16 |
+
|
17 |
+
vectordb_cursor = None
|
18 |
+
vlm_model = None
|
19 |
+
|
20 |
+
|
21 |
+
def init_vectordb(**kargs):
|
22 |
+
# Singleton pattern
|
23 |
+
global vectordb_cursor
|
24 |
+
if vectordb_cursor is None:
|
25 |
+
vectordb_cursor = DatabaseCursor(**kargs)
|
26 |
+
|
27 |
+
|
28 |
+
def init_model(**kargs):
|
29 |
+
# Singleton
|
30 |
+
global vlm_model
|
31 |
+
if vlm_model is None:
|
32 |
+
vlm_model = VisionLanguageModel(**kargs)
|
33 |
+
|
34 |
+
|
35 |
+
@router.post("/retrieval/image-text")
|
36 |
+
async def retrieve(item: Item) -> JSONResponse:
|
37 |
+
try:
|
38 |
+
query_vector = vlm_model.get_embedding(input=item.query_text)
|
39 |
+
search_results = vectordb_cursor.kNN_search(query_vector, item.topk)
|
40 |
+
except Exception:
|
41 |
+
return JSONResponse(
|
42 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
43 |
+
content={"message": "Search error"},
|
44 |
+
)
|
45 |
+
|
46 |
+
return JSONResponse(
|
47 |
+
status_code=status.HTTP_200_OK,
|
48 |
+
content={"message": "success", "details": search_results},
|
49 |
+
)
|
src/itr/vlm_model.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import clip
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
class VisionLanguageModel:
|
9 |
+
def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
|
10 |
+
self._load_model(model_name, device)
|
11 |
+
self.device = device
|
12 |
+
|
13 |
+
@lru_cache(maxsize=1)
|
14 |
+
def _load_model(self, model_name, device: str = "cpu"):
|
15 |
+
self.model, self.processor = clip.load(model_name, device=device)
|
16 |
+
|
17 |
+
def get_embedding(self, input: Union[str, Image.Image]):
|
18 |
+
if isinstance(input, str):
|
19 |
+
tokens = clip.tokenize(input).to(self.device)
|
20 |
+
vector = self.model.encode_text(tokens)
|
21 |
+
vector /= vector.norm(dim=-1, keepdim=True)
|
22 |
+
vector = vector.cpu().detach().numpy().astype("float32")
|
23 |
+
return vector
|
24 |
+
elif isinstance(input, Image.Image):
|
25 |
+
image_input = self.preprocess(input).unsqueeze(0).to(self.device)
|
26 |
+
vector = self.model.encode_image(image_input)
|
27 |
+
vector /= vector.norm(dim=-1, keepdim=True)
|
28 |
+
return vector
|
29 |
+
else:
|
30 |
+
raise Exception("Invalid input type")
|
src/main.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from config import settings
|
3 |
+
from fastapi import FastAPI, Request, status
|
4 |
+
from fastapi.exceptions import RequestValidationError
|
5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
6 |
+
from fastapi.responses import JSONResponse, RedirectResponse
|
7 |
+
from itr.router import init_model, init_vectordb
|
8 |
+
from itr.router import router as router
|
9 |
+
|
10 |
+
app = FastAPI(title="Text-to-image Retrieval API")
|
11 |
+
|
12 |
+
|
13 |
+
app.add_middleware(
|
14 |
+
CORSMiddleware,
|
15 |
+
allow_origins=settings.CORS_ORIGINS,
|
16 |
+
allow_headers=settings.CORS_HEADERS,
|
17 |
+
allow_credentials=True,
|
18 |
+
allow_methods=["*"],
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
@app.exception_handler(RequestValidationError)
|
23 |
+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
24 |
+
# Get the original 'detail' list of errors
|
25 |
+
details = exc.errors()
|
26 |
+
error_details = []
|
27 |
+
|
28 |
+
for error in details:
|
29 |
+
error_details.append({"error": f"{error['msg']} {str(error['loc'])}"})
|
30 |
+
return JSONResponse(content={"message": error_details})
|
31 |
+
|
32 |
+
|
33 |
+
@app.on_event("startup")
|
34 |
+
async def startup_event():
|
35 |
+
init_vectordb(
|
36 |
+
index_file_path=settings.INDEX_FILE_PATH,
|
37 |
+
keyframes_groups_json_path=settings.KEYFRAMES_GROUPS_JSON_PATH,
|
38 |
+
)
|
39 |
+
device = (
|
40 |
+
"cuda" if settings.DEVICE == "cuda" and torch.cuda.is_available() else "cpu"
|
41 |
+
)
|
42 |
+
init_model(model_name=settings.MODEL_NAME, device=device)
|
43 |
+
|
44 |
+
|
45 |
+
@app.get("/", include_in_schema=False)
|
46 |
+
async def root() -> None:
|
47 |
+
return RedirectResponse("/docs")
|
48 |
+
|
49 |
+
|
50 |
+
@app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
|
51 |
+
async def perform_healthcheck() -> None:
|
52 |
+
return JSONResponse(content={"message": "success"})
|
53 |
+
|
54 |
+
|
55 |
+
app.include_router(router)
|
56 |
+
|
57 |
+
|
58 |
+
# Start API
|
59 |
+
# if __name__ == "__main__":
|
60 |
+
# import uvicorn
|
61 |
+
|
62 |
+
# uvicorn.run("main:app", host=settings.HOST, port=settings.PORT, reload=True)
|