import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.log import get_logger
from relik.common.utils import is_package_available
from relik.inference.annotator import Relik
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator

# Tokenizer and window-manager components used by the GERBIL endpoint below.
# These import paths are an assumption and may differ across relik versions.
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager

if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

logger = get_logger(__name__, level=logging.INFO)
VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

app = FastAPI(
    title="ReLiK",
    version=VERSION["VERSION"],
    description="ReLiK REST API",
)

@serve.deployment(
    ray_actor_options={
        "num_gpus": (
            RAY_MANAGER.num_gpus
            if (
                SERVER_MANAGER.retriever_device == "cuda"
                or SERVER_MANAGER.reader_device == "cuda"
            )
            else 0
        )
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class RelikServer:
    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        reader_encoder: Optional[str] = None,
        top_k: int = 100,
        retriever_device: str = "cpu",
        reader_device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.reader_encoder = reader_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.retriever_device = retriever_device
        self.index_device = index_device or retriever_device
        self.reader_device = reader_device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log the parameters for debugging
        logger.info("Initializing RelikServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"READER_ENCODER: {self.reader_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"RETRIEVER_DEVICE: {self.retriever_device}")
        logger.info(f"READER_DEVICE: {self.reader_device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"WINDOW_SIZE: {self.window_size}")
        logger.info(f"WINDOW_STRIDE: {self.window_stride}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.relik = Relik(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            reader=self.reader_encoder,
            retriever_device=self.retriever_device,
            document_index_device=self.index_device,
            reader_device=self.reader_device,
            retriever_precision=self.precision,
            document_index_precision=self.index_precision,
            reader_precision=self.precision,
        )
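
        # The GERBIL endpoint below windows documents on the server side, so
        # it needs a tokenizer and a window manager. This wiring is a sketch:
        # it assumes WindowManager takes no constructor arguments and that the
        # tokenizer choice mirrors the `split_on_spaces` flag; the exact
        # signatures may differ across relik versions.
        if self.split_on_spaces:
            self.tokenizer = WhitespaceTokenizer()
        else:
            self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager = WindowManager()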

    # @serve.batch()  # enable to let Ray Serve batch concurrent requests
    async def handle_batch(self, documents: List[str]) -> List:
        # run the full ReLiK pipeline (retriever + reader) on a batch of documents
        return self.relik(
            documents,
            top_k=self.top_k,
            window_size=self.window_size,
            window_stride=self.window_stride,
            batch_size=self.window_batch_size,
        )
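
    # A minimal sketch of the retriever-only handler used by the GERBIL
    # endpoint. It assumes the underlying `Relik` instance exposes its
    # retriever as `self.relik.retriever` with a `retrieve(...)` method
    # accepting `text_pair` and `k`; check the installed relik version if
    # this differs.
    async def handle_batch_retriever(
        self, text: List[str], text_pair: List[str]
    ) -> List:
        return self.relik.retriever.retrieve(
            list(text), text_pair=list(text_pair), k=self.top_k
        )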
@app.post("/api/entities")
async def entities_endpoint(
self,
documents: Union[str, List[str]],
):
try:
# normalize input
if isinstance(documents, str):
documents = [documents]
if document_topics is not None:
if isinstance(document_topics, str):
document_topics = [document_topics]
assert len(documents) == len(document_topics)
# get predictions for the retriever
return await self.handle_batch(documents, document_topics)
except Exception as e:
# log the entire stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=f"Server Error: {e}")
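
    # Example request (an illustration, not part of the API definition):
    # with Ray Serve's default HTTP port (8000) and a root route prefix,
    # the endpoint accepts a JSON string or list of strings as the body:
    #
    #   curl -X POST http://localhost:8000/api/entities \
    #       -H "Content-Type: application/json" \
    #       -d '["Michael Jordan was one of the best players in the NBA."]'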
@app.post("/api/gerbil")
async def gerbil_endpoint(self, documents: Union[str, List[str]]):
try:
# normalize input
if isinstance(documents, str):
documents = [documents]
# output list
windows_passages = []
# split documents into windows
document_windows = [
window
for doc_id, document in enumerate(documents)
for window in self.window_manager(
self.tokenizer,
document,
window_size=self.window_size,
stride=self.window_stride,
doc_id=doc_id,
)
]
# get text and topic from document windows and create new list
model_inputs = [
(window.text, window.doc_topic) for window in document_windows
]
# batch generator
for batch in batch_generator(
model_inputs, batch_size=self.window_batch_size
):
text, text_pair = zip(*batch)
batch_predictions = await self.handle_batch_retriever(text, text_pair)
windows_passages.extend(
[
[p.label for p in predictions]
for predictions in batch_predictions
]
)
# add passage to document windows
for window, passages in zip(document_windows, windows_passages):
# clean up passages (remove everything after first <def> tag if present)
passages = [c.split(" <def>", 1)[0] for c in passages]
window.window_candidates = passages
# return document windows
return document_windows
except Exception as e:
# log the entire stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=f"Server Error: {e}")

server = RelikServer.bind(**vars(SERVER_MANAGER))
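
# A minimal sketch of how to run the deployment locally. `serve.run` is the
# standard Ray Serve 2.x entry point; the route prefix here is an assumption:
#
#   from ray import serve
#   serve.run(server, route_prefix="/")
#   # the API is then reachable at http://localhost:8000/api/entities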