AWeirdDev committed on
Commit
de8d704
·
verified ·
1 Parent(s): 1ec82cd

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +81 -0
main.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Start a server with NLP functionality."""
2
+ import bottle
3
+ from sentence_transformers import SentenceTransformer
4
+ from sentence_transformers import util
5
+ from functools import lru_cache
6
+ import typing
7
+ import argparse
8
+
9
# Command-line options. Every flag has a default, so the server can be
# started without passing any arguments.
parser = argparse.ArgumentParser(description="Start an NLP server.")
parser.add_argument("--port", default=8080, type=int, help="Server port")
parser.add_argument(
    "--model", default="all-mpnet-base-v2", type=str, help="Transformer model ID"
)
parser.add_argument(
    "--embed_cache_size",
    default=2048,
    type=int,
    help="Cache size for sentence embeddings",
)
args = parser.parse_args()

# Instantiate the model selected by --model once, at module load time; the
# functions below refer to this single shared instance.
model = SentenceTransformer(args.model)
31
+
32
+
33
@bottle.error(405)
def method_not_allowed(res):
    """Handle 405 responses and answer CORS preflight ``OPTIONS`` requests.

    An ``OPTIONS`` request that ends up here gets a fresh, empty response
    carrying the cross-origin headers, which allows requests from external
    domains to be processed. For any other method, ``OPTIONS`` is added to
    the ``Allow`` header and the application's default error handler
    produces the response.

    Args:
        res: The 405 error response passed in by bottle.

    Returns:
        A new ``bottle.HTTPResponse`` for ``OPTIONS`` requests, otherwise
        whatever the default error handler returns for ``res``.
    """
    if bottle.request.method == 'OPTIONS':
        new_res = bottle.HTTPResponse()
        new_res.set_header('Access-Control-Allow-Origin', '*')
        new_res.set_header('Access-Control-Allow-Headers', 'content-type')
        return new_res
    # A 405 may arrive without an ``Allow`` header (for instance one raised
    # explicitly by a handler); reading the header unconditionally would
    # raise KeyError and turn the 405 into a 500.
    allowed = res.get_header('Allow')
    res.set_header('Allow', allowed + ', OPTIONS' if allowed else 'OPTIONS')
    return bottle.request.app.default_error_handler(res)
44
+
45
+
46
@bottle.hook('after_request')
def enable_cors():
    """Attach permissive CORS headers to the outgoing response.

    Registered as an `after_request` hook. It marks the response as readable
    from any origin (`*`) and allows the `content-type` request header.
    """
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Headers', 'content-type'),
    )
    for header_name, header_value in cors_headers:
        bottle.response.set_header(header_name, header_value)
52
+
53
+
54
@lru_cache(maxsize=args.embed_cache_size)
def no_batch_embed(sentence: str) -> typing.List[float]:
    """Embed a single sentence and return the vector's numbers as a list.

    Results are memoized per sentence by an LRU cache whose capacity is
    `args.embed_cache_size`.
    """
    vector = model.encode(sentence)
    return vector.tolist()
59
+
60
+
61
@bottle.post('/embedding')
def embedding():
    """Returns `{'embeddings': v}` where `v` is a list of vectors with the
    embeddings of each document in `documents`.

    The request body must be a JSON object whose `documents` entry is a list
    of strings. Any other body is rejected with HTTP 400 instead of failing
    with an unhandled exception (HTTP 500).
    """
    payload = bottle.request.json
    # `request.json` is None when the body was not sent as JSON, so the
    # payload shape is checked before it is indexed.
    documents = payload.get("documents") if isinstance(payload, dict) else None
    if not isinstance(documents, list) or not all(
        isinstance(document, str) for document in documents
    ):
        bottle.abort(400, 'Expected a JSON object with a "documents" list of strings.')
    embeddings = [no_batch_embed(document) for document in documents]
    return {"embeddings": embeddings}
68
+
69
+
70
@bottle.post('/semantic_search')
def semantic_search():
    """Returns `{'similarities': v}` where `v` is a list of numbers with the
    similarities of `query` to each document in `documents`.

    The request body must be a JSON object with a string `query` and a
    `documents` list of strings; anything else is rejected with HTTP 400.
    An empty `documents` list yields an empty `similarities` list.
    Similarities are dot products between the query embedding and each
    document embedding.
    """
    payload = bottle.request.json
    # `request.json` is None when the body was not sent as JSON, so the
    # payload shape is checked before it is indexed.
    query = payload.get("query") if isinstance(payload, dict) else None
    documents = payload.get("documents") if isinstance(payload, dict) else None
    if (
        not isinstance(query, str)
        or not isinstance(documents, list)
        or not all(isinstance(document, str) for document in documents)
    ):
        bottle.abort(
            400,
            'Expected a JSON object with a "query" string and a "documents" '
            'list of strings.',
        )
    # With no documents there is nothing to score; an empty embedding list
    # cannot be multiplied against the query vector.
    if not documents:
        return {"similarities": []}
    query_embedding = no_batch_embed(query)
    document_embeddings = [no_batch_embed(document) for document in documents]
    # Flatten the score matrix explicitly rather than with a bare `squeeze()`,
    # which collapses a single-document result to a 0-d tensor that cannot
    # be iterated.
    scores = util.dot_score(query_embedding, document_embeddings).reshape(-1)
    return {"similarities": [float(s) for s in scores]}
80
+
81
# Start serving the registered routes on the configured port, using the
# "cheroot" server backend.
bottle.run(server="cheroot", port=args.port)