# Hugging Face Spaces status banner ("Spaces: Sleeping") captured during
# extraction — kept as a comment so the module remains valid Python.
# --- Imports & optional-dependency guards -----------------------------------
import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.utils import is_package_available
from relik.inference.annotator import Relik

# Fail fast with an actionable install hint *before* attempting the imports
# that would otherwise crash with a bare ModuleNotFoundError.
if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

from relik.common.log import get_logger
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator

logger = get_logger(__name__, level=logging.INFO)

# Load the package version by exec-ing version.py into a plain dict, which
# exposes VERSION["VERSION"] without importing the (possibly uninstalled)
# package itself.
VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

# FastAPI app used by the serving layer for metadata/routing.
app = FastAPI(
    title="ReLiK",
    version=VERSION["VERSION"],
    description="ReLiK REST API",
)
class RelikServer:
    """Serving wrapper around a :class:`Relik` annotator.

    Stores the retriever/reader configuration passed in by the deployment
    manager, eagerly builds the ``Relik`` pipeline once at startup, and
    exposes async endpoints for entity linking and GERBIL-style evaluation.
    """

    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        reader_encoder: Optional[str] = None,
        top_k: int = 100,
        # NOTE(review): "retriver" is a typo for "retriever", but the name is
        # part of the public interface (callers pass it by keyword, e.g. via
        # **vars(SERVER_MANAGER)), so it is kept unchanged.
        retriver_device: str = "cpu",
        reader_device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        """Create the server and instantiate the underlying ``Relik`` pipeline.

        ``index_device`` and ``index_precision`` fall back to the retriever's
        device/precision when not given.
        """
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.reader_encoder = reader_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.retriver_device = retriver_device
        self.index_device = index_device or retriver_device
        self.reader_device = reader_device
        self.precision = precision
        self.index_precision = index_precision or precision
        # NOTE(review): use_faiss is stored but never forwarded to Relik
        # below — confirm whether the pipeline is expected to consume it.
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log stuff for debugging
        logger.info("Initializing RelikServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"READER_ENCODER: {self.reader_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
        logger.info(f"READER_DEVICE: {self.reader_device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        # these two were configured but previously never logged
        logger.info(f"WINDOW_SIZE: {self.window_size}")
        logger.info(f"WINDOW_STRIDE: {self.window_stride}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.relik = Relik(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            reader=self.reader_encoder,
            retriever_device=self.retriver_device,
            document_index_device=self.index_device,
            reader_device=self.reader_device,
            retriever_precision=self.precision,
            document_index_precision=self.index_precision,
            reader_precision=self.precision,
        )

    # @serve.batch()
    async def handle_batch(self, documents: List[str]) -> List:
        """Run the ``Relik`` pipeline over a batch of raw documents."""
        return self.relik(
            documents,
            top_k=self.top_k,
            window_size=self.window_size,
            window_stride=self.window_stride,
            batch_size=self.window_batch_size,
        )

    async def entities_endpoint(
        self,
        documents: Union[str, List[str]],
    ):
        """Annotate one document (str) or a batch of documents (list of str).

        Raises:
            HTTPException: with status 500 on any pipeline failure.
        """
        try:
            # normalize input: a single string becomes a one-element batch
            if isinstance(documents, str):
                documents = [documents]
            # BUGFIX: the original referenced an undefined `document_topics`
            # here (NameError on every request, masked by the broad except
            # into a 500) and then passed it as a second positional argument
            # that handle_batch does not accept. The topic handling was
            # removed until the topic-aware signature is restored upstream.
            return await self.handle_batch(documents)
        except Exception as e:
            # log the entire stack trace, then surface a generic 500
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")

    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
        """Split documents into windows and attach retrieved passage labels.

        Returns the list of document windows with ``window_candidates`` set.

        NOTE(review): this method uses ``self.window_manager``,
        ``self.tokenizer`` and ``self.handle_batch_retriever``, none of which
        are initialized in ``__init__`` — as written it raises
        AttributeError (reported as a 500). Confirm where these components
        are supposed to come from.
        """
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]
            # output list
            windows_passages = []
            # split documents into windows
            document_windows = [
                window
                for doc_id, document in enumerate(documents)
                for window in self.window_manager(
                    self.tokenizer,
                    document,
                    window_size=self.window_size,
                    stride=self.window_stride,
                    doc_id=doc_id,
                )
            ]
            # get text and topic from document windows and create new list
            model_inputs = [
                (window.text, window.doc_topic) for window in document_windows
            ]
            # batch generator
            for batch in batch_generator(
                model_inputs, batch_size=self.window_batch_size
            ):
                text, text_pair = zip(*batch)
                batch_predictions = await self.handle_batch_retriever(text, text_pair)
                windows_passages.extend(
                    [
                        [p.label for p in predictions]
                        for predictions in batch_predictions
                    ]
                )
            # add passage to document windows
            for window, passages in zip(document_windows, windows_passages):
                # clean up passages (remove everything after first <def> tag if present)
                passages = [c.split(" <def>", 1)[0] for c in passages]
                window.window_candidates = passages
            # return document windows
            return document_windows
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")
# BUGFIX: `.bind` exists only on a Ray Serve Deployment, but `RelikServer` is
# a plain class here (no `@serve.deployment` decorator is visible), so the
# original line raised AttributeError at import. Wrap the class explicitly.
# NOTE(review): RAY_MANAGER is constructed above but unused — its options
# (e.g. actor resources / autoscaling) probably belong in this deployment;
# confirm the intended `serve.deployment(...)` configuration.
server = serve.deployment(RelikServer).bind(**vars(SERVER_MANAGER))