Spaces:
Sleeping
Sleeping
"""Start a server with NLP functionality.""" | |
import bottle | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers import util | |
from functools import lru_cache | |
import typing | |
import argparse | |
# Command-line configuration, parsed once at import time.
parser = argparse.ArgumentParser(description="Start an NLP server.")
parser.add_argument("--port", type=int, default=7860, help="Server port")
parser.add_argument(
    "--model", type=str, default="all-mpnet-base-v2",
    help="Transformer model ID",
)
parser.add_argument(
    "--embed_cache_size", type=int, default=2048,
    help="Cache size for sentence embeddings",
)
args = parser.parse_args()

# The sentence-transformer is loaded a single time here; the embedding
# helpers below read this module-level instance.
model = SentenceTransformer(args.model)
@bottle.error(405)
def method_not_allowed(res):
    """Handle HTTP 405 errors, answering CORS preflight requests.

    A cross-origin request with a JSON ``Content-Type`` is preceded by an
    ``OPTIONS`` preflight. When the target route does not accept
    ``OPTIONS``, Bottle reports a 405; this handler turns such requests into
    an empty response carrying the CORS headers, so requests from external
    domains can proceed.

    For any other method, ``OPTIONS`` is appended to the ``Allow`` header (if
    present) and Bottle's default error page is rendered.

    Args:
        res: the ``bottle.HTTPError`` that triggered the 405.

    Returns:
        A ``bottle.HTTPResponse`` for ``OPTIONS`` requests, otherwise the
        output of the application's default error handler.
    """
    if bottle.request.method == 'OPTIONS':
        preflight = bottle.HTTPResponse()
        preflight.set_header('Access-Control-Allow-Origin', '*')
        preflight.set_header('Access-Control-Allow-Headers', 'content-type')
        return preflight
    # A 405 raised without an Allow header (e.g. a manual bottle.abort(405))
    # would make `headers['Allow'] += ...` raise KeyError inside the error
    # handler itself; only extend the header when it exists.
    allowed = res.headers.get('Allow')
    if allowed:
        res.headers['Allow'] = allowed + ', OPTIONS'
    return bottle.request.app.default_error_handler(res)
@bottle.hook('after_request')
def enable_cors():
    """Mark every response as readable from any origin.

    Runs after each request and sets ``Access-Control-Allow-Origin: *`` and
    ``Access-Control-Allow-Headers: content-type`` on the outgoing response,
    so browser clients on other domains may read it.
    """
    bottle.response.set_header('Access-Control-Allow-Origin', '*')
    bottle.response.set_header('Access-Control-Allow-Headers', 'content-type')
@lru_cache(maxsize=args.embed_cache_size)
def _embed_cached(sentence: str) -> typing.Tuple[float, ...]:
    """Embed one sentence, memoized in an LRU cache of `--embed_cache_size`.

    The vector is stored as an immutable tuple so no caller can mutate a
    cached entry.
    """
    return tuple(model.encode(sentence).tolist())


def no_batch_embed(sentence: str) -> typing.List[float]:
    """Return the model's embedding of `sentence` as a list of floats.

    Each sentence is encoded on its own (not batched) so results can be
    cached per sentence. A fresh list is returned on every call, so callers
    may modify it freely without affecting the cache.
    """
    return list(_embed_cached(sentence))
# NOTE(review): route path inferred from the handler name; confirm against
# existing clients.
@bottle.route("/embedding", method="POST")
def embedding():
    """Embed every document in the JSON request body.

    Expects a body of the form ``{"documents": [str, ...]}`` and returns
    ``{"embeddings": v}``, where ``v[i]`` is the embedding (a list of floats)
    of ``documents[i]``.

    Aborts with HTTP 400 when the body is not a JSON object or ``documents``
    is not a list of strings.
    """
    payload = bottle.request.json
    # bottle.request.json is None for non-JSON requests (per Bottle docs);
    # indexing it directly would surface as an opaque 500, not a 400.
    if not isinstance(payload, dict):
        bottle.abort(400, "Expected a JSON object body.")
    documents = payload.get("documents")
    if not isinstance(documents, list) or not all(
        isinstance(document, str) for document in documents
    ):
        bottle.abort(400, "'documents' must be a list of strings.")
    embeddings = [no_batch_embed(document) for document in documents]
    return {"embeddings": embeddings}
# NOTE(review): route path inferred from the handler name; confirm against
# existing clients.
@bottle.route("/semantic_search", method="POST")
def semantic_search():
    """Score how similar a query is to each document in the request body.

    Expects a body of the form ``{"query": str, "documents": [str, ...]}``
    and returns ``{"similarities": v}``, where ``v[i]`` is the dot product of
    the query embedding with the embedding of ``documents[i]``. For models
    that emit unit-length embeddings this equals cosine similarity.

    Aborts with HTTP 400 when the body is not a JSON object, ``query`` is not
    a string, or ``documents`` is not a list of strings.
    """
    payload = bottle.request.json
    # bottle.request.json is None for non-JSON requests (per Bottle docs).
    if not isinstance(payload, dict):
        bottle.abort(400, "Expected a JSON object body.")
    query = payload.get("query")
    documents = payload.get("documents")
    if not isinstance(query, str):
        bottle.abort(400, "'query' must be a string.")
    if not isinstance(documents, list) or not all(
        isinstance(document, str) for document in documents
    ):
        bottle.abort(400, "'documents' must be a list of strings.")
    if not documents:
        # dot_score cannot multiply against an empty document matrix.
        return {"similarities": []}
    query_embedding = no_batch_embed(query)
    document_embeddings = [no_batch_embed(document) for document in documents]
    # dot_score yields a (1, len(documents)) matrix. Take row 0 instead of
    # calling squeeze(): with a single document, squeeze() collapses the
    # result to a 0-d tensor, which cannot be iterated.
    scores = util.dot_score(query_embedding, document_embeddings)[0]
    return {"similarities": scores.tolist()}
# Start the blocking server using the cheroot WSGI backend, listening on all
# network interfaces at the port given by --port.
bottle.run(
    server="cheroot",
    host="0.0.0.0",
    port=args.port,
)