KevinHuSh
commited on
Commit
·
3245107
1
Parent(s):
d4fd138
use minio to store uploaded files; build dialog server; (#16)
Browse files* format code
* use minio to store uploaded files; build dialog server;
- python/llm/__init__.py +1 -0
- python/llm/chat_model.py +34 -0
- python/llm/embedding_model.py +3 -2
- python/nlp/huchunk.py +4 -2
- python/nlp/search.py +221 -0
- python/parser/docx_parser.py +2 -1
- python/parser/excel_parser.py +3 -1
- python/parser/pdf_parser.py +2 -1
- python/svr/dialog_svr.py +164 -0
- python/svr/parse_user_docs.py +13 -12
- src/api/doc_info.rs +49 -34
- src/api/tag.rs +0 -58
- src/main.rs +25 -24
python/llm/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1 |
from .embedding_model import HuEmbedding
|
|
|
|
1 |
from .embedding_model import HuEmbedding
|
2 |
+
from .chat_model import GptTurbo
|
python/llm/chat_model.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
import openapi
|
3 |
+
import os
|
4 |
+
|
5 |
+
class Base(ABC):
|
6 |
+
def chat(self, system, history, gen_conf):
|
7 |
+
raise NotImplementedError("Please implement encode method!")
|
8 |
+
|
9 |
+
|
10 |
+
class GptTurbo(Base):
|
11 |
+
def __init__(self):
|
12 |
+
openapi.api_key = os.environ["OPENAPI_KEY"]
|
13 |
+
|
14 |
+
def chat(self, system, history, gen_conf):
|
15 |
+
history.insert(0, {"role": "system", "content": system})
|
16 |
+
res = openapi.ChatCompletion.create(model="gpt-3.5-turbo",
|
17 |
+
messages=history,
|
18 |
+
**gen_conf)
|
19 |
+
return res.choices[0].message.content.strip()
|
20 |
+
|
21 |
+
|
22 |
+
class QWen(Base):
|
23 |
+
def chat(self, system, history, gen_conf):
|
24 |
+
from http import HTTPStatus
|
25 |
+
from dashscope import Generation
|
26 |
+
from dashscope.api_entities.dashscope_response import Role
|
27 |
+
response = Generation.call(
|
28 |
+
Generation.Models.qwen_turbo,
|
29 |
+
messages=messages,
|
30 |
+
result_format='message'
|
31 |
+
)
|
32 |
+
if response.status_code == HTTPStatus.OK:
|
33 |
+
return response.output.choices[0]['message']['content']
|
34 |
+
return response.message
|
python/llm/embedding_model.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from abc import ABC
|
2 |
from FlagEmbedding import FlagModel
|
3 |
import torch
|
|
|
4 |
|
5 |
class Base(ABC):
|
6 |
def encode(self, texts: list, batch_size=32):
|
@@ -27,5 +28,5 @@ class HuEmbedding(Base):
|
|
27 |
def encode(self, texts: list, batch_size=32):
|
28 |
res = []
|
29 |
for i in range(0, len(texts), batch_size):
|
30 |
-
res.extend(self.encode(texts[i:i+batch_size]))
|
31 |
-
return res
|
|
|
1 |
from abc import ABC
|
2 |
from FlagEmbedding import FlagModel
|
3 |
import torch
|
4 |
+
import numpy as np
|
5 |
|
6 |
class Base(ABC):
|
7 |
def encode(self, texts: list, batch_size=32):
|
|
|
28 |
def encode(self, texts: list, batch_size=32):
|
29 |
res = []
|
30 |
for i in range(0, len(texts), batch_size):
|
31 |
+
res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
|
32 |
+
return np.array(res)
|
python/nlp/huchunk.py
CHANGED
@@ -372,7 +372,7 @@ class PptChunker(HuChunker):
|
|
372 |
|
373 |
def __call__(self, fnm):
|
374 |
from pptx import Presentation
|
375 |
-
ppt = Presentation(fnm)
|
376 |
flds = self.Fields()
|
377 |
flds.text_chunks = []
|
378 |
for slide in ppt.slides:
|
@@ -396,7 +396,9 @@ class TextChunker(HuChunker):
|
|
396 |
@staticmethod
|
397 |
def is_binary_file(file_path):
|
398 |
mime = magic.Magic(mime=True)
|
399 |
-
|
|
|
|
|
400 |
if 'text' in file_type:
|
401 |
return False
|
402 |
else:
|
|
|
372 |
|
373 |
def __call__(self, fnm):
|
374 |
from pptx import Presentation
|
375 |
+
ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
|
376 |
flds = self.Fields()
|
377 |
flds.text_chunks = []
|
378 |
for slide in ppt.slides:
|
|
|
396 |
@staticmethod
|
397 |
def is_binary_file(file_path):
|
398 |
mime = magic.Magic(mime=True)
|
399 |
+
if isinstance(file_path, str):
|
400 |
+
file_type = mime.from_file(file_path)
|
401 |
+
else:file_type = mime.from_buffer(file_path)
|
402 |
if 'text' in file_type:
|
403 |
return False
|
404 |
else:
|
python/nlp/search.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from elasticsearch_dsl import Q,Search,A
|
3 |
+
from typing import List, Optional, Tuple,Dict, Union
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from util import setup_logging, rmSpace
|
6 |
+
from nlp import huqie, query
|
7 |
+
from datetime import datetime
|
8 |
+
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
9 |
+
import numpy as np
|
10 |
+
from copy import deepcopy
|
11 |
+
|
12 |
+
class Dealer:
|
13 |
+
def __init__(self, es, emb_mdl):
|
14 |
+
self.qryr = query.EsQueryer(es)
|
15 |
+
self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"]
|
16 |
+
self.es = es
|
17 |
+
self.emb_mdl = emb_mdl
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class SearchResult:
|
21 |
+
total:int
|
22 |
+
ids: List[str]
|
23 |
+
query_vector: List[float] = None
|
24 |
+
field: Optional[Dict] = None
|
25 |
+
highlight: Optional[Dict] = None
|
26 |
+
aggregation: Union[List, Dict, None] = None
|
27 |
+
keywords: Optional[List[str]] = None
|
28 |
+
group_docs: List[List] = None
|
29 |
+
|
30 |
+
def _vector(self, txt, sim=0.8, topk=10):
|
31 |
+
return {
|
32 |
+
"field": "q_vec",
|
33 |
+
"k": topk,
|
34 |
+
"similarity": sim,
|
35 |
+
"num_candidates": 1000,
|
36 |
+
"query_vector": self.emb_mdl.encode_queries(txt)
|
37 |
+
}
|
38 |
+
|
39 |
+
def search(self, req, idxnm, tks_num=3):
|
40 |
+
keywords = []
|
41 |
+
qst = req.get("question", "")
|
42 |
+
|
43 |
+
bqry,keywords = self.qryr.question(qst)
|
44 |
+
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
45 |
+
bqry.filter.append(Q("exists", field="q_tks"))
|
46 |
+
bqry.boost = 0.05
|
47 |
+
print(bqry)
|
48 |
+
|
49 |
+
s = Search()
|
50 |
+
pg = int(req.get("page", 1))-1
|
51 |
+
ps = int(req.get("size", 1000))
|
52 |
+
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
|
53 |
+
"image_id", "doc_id", "q_vec"])
|
54 |
+
|
55 |
+
s = s.query(bqry)[pg*ps:(pg+1)*ps]
|
56 |
+
s = s.highlight("content_ltks")
|
57 |
+
s = s.highlight("title_ltks")
|
58 |
+
if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}})
|
59 |
+
|
60 |
+
s = s.highlight_options(
|
61 |
+
fragment_size = 120,
|
62 |
+
number_of_fragments=5,
|
63 |
+
boundary_scanner_locale="zh-CN",
|
64 |
+
boundary_scanner="SENTENCE",
|
65 |
+
boundary_chars=",./;:\\!(),。?:!……()——、"
|
66 |
+
)
|
67 |
+
s = s.to_dict()
|
68 |
+
q_vec = []
|
69 |
+
if req.get("vector"):
|
70 |
+
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
|
71 |
+
s["knn"]["filter"] = bqry.to_dict()
|
72 |
+
del s["highlight"]
|
73 |
+
q_vec = s["knn"]["query_vector"]
|
74 |
+
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
|
75 |
+
print("TOTAL: ", self.es.getTotal(res))
|
76 |
+
if self.es.getTotal(res) == 0 and "knn" in s:
|
77 |
+
bqry,_ = self.qryr.question(qst, min_match="10%")
|
78 |
+
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
79 |
+
s["query"] = bqry.to_dict()
|
80 |
+
s["knn"]["filter"] = bqry.to_dict()
|
81 |
+
s["knn"]["similarity"] = 0.7
|
82 |
+
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
|
83 |
+
|
84 |
+
kwds = set([])
|
85 |
+
for k in keywords:
|
86 |
+
kwds.add(k)
|
87 |
+
for kk in huqie.qieqie(k).split(" "):
|
88 |
+
if len(kk) < 2:continue
|
89 |
+
if kk in kwds:continue
|
90 |
+
kwds.add(kk)
|
91 |
+
|
92 |
+
aggs = self.getAggregation(res, "docnm_kwd")
|
93 |
+
|
94 |
+
return self.SearchResult(
|
95 |
+
total = self.es.getTotal(res),
|
96 |
+
ids = self.es.getDocIds(res),
|
97 |
+
query_vector = q_vec,
|
98 |
+
aggregation = aggs,
|
99 |
+
highlight = self.getHighlight(res),
|
100 |
+
field = self.getFields(res, ["docnm_kwd", "content_ltks",
|
101 |
+
"kb_id","image_id", "doc_id", "q_vec"]),
|
102 |
+
keywords = list(kwds)
|
103 |
+
)
|
104 |
+
|
105 |
+
def getAggregation(self, res, g):
|
106 |
+
if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:return
|
107 |
+
bkts = res["aggregations"]["aggs_"+g]["buckets"]
|
108 |
+
return [(b["key"], b["doc_count"]) for b in bkts]
|
109 |
+
|
110 |
+
def getHighlight(self, res):
|
111 |
+
def rmspace(line):
|
112 |
+
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
|
113 |
+
r = []
|
114 |
+
for t in line.split(" "):
|
115 |
+
if not t:continue
|
116 |
+
if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ")
|
117 |
+
r.append(t)
|
118 |
+
r = "".join(r)
|
119 |
+
return r
|
120 |
+
|
121 |
+
ans = {}
|
122 |
+
for d in res["hits"]["hits"]:
|
123 |
+
hlts = d.get("highlight")
|
124 |
+
if not hlts:continue
|
125 |
+
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
|
126 |
+
return ans
|
127 |
+
|
128 |
+
def getFields(self, sres, flds):
|
129 |
+
res = {}
|
130 |
+
if not flds:return {}
|
131 |
+
for d in self.es.getSource(sres):
|
132 |
+
m = {n:d.get(n) for n in flds if d.get(n) is not None}
|
133 |
+
for n,v in m.items():
|
134 |
+
if type(v) == type([]):
|
135 |
+
m[n] = "\t".join([str(vv) for vv in v])
|
136 |
+
continue
|
137 |
+
if type(v) != type(""):m[n] = str(m[n])
|
138 |
+
m[n] = rmSpace(m[n])
|
139 |
+
|
140 |
+
if m:res[d["id"]] = m
|
141 |
+
return res
|
142 |
+
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
def trans2floats(txt):
|
146 |
+
return [float(t) for t in txt.split("\t")]
|
147 |
+
|
148 |
+
|
149 |
+
def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"):
|
150 |
+
|
151 |
+
ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx]
|
152 |
+
ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
|
153 |
+
s = 0
|
154 |
+
e = 0
|
155 |
+
res = ""
|
156 |
+
def citeit():
|
157 |
+
nonlocal s, e, ans, res
|
158 |
+
if not ins_embd:return
|
159 |
+
embd = self.emb_mdl.encode(ans[s: e])
|
160 |
+
sim = self.qryr.hybrid_similarity(embd,
|
161 |
+
ins_embd,
|
162 |
+
huqie.qie(ans[s:e]).split(" "),
|
163 |
+
ins_tw)
|
164 |
+
print(ans[s: e], sim)
|
165 |
+
mx = np.max(sim)*0.99
|
166 |
+
if mx < 0.55:return
|
167 |
+
cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4]
|
168 |
+
for i in cita: res += f"@?{i}?@"
|
169 |
+
|
170 |
+
return cita
|
171 |
+
|
172 |
+
punct = set(";。?!!")
|
173 |
+
if not self.qryr.isChinese(ans):
|
174 |
+
punct.add("?")
|
175 |
+
punct.add(".")
|
176 |
+
while e < len(ans):
|
177 |
+
if e - s < 12 or ans[e] not in punct:
|
178 |
+
e += 1
|
179 |
+
continue
|
180 |
+
if ans[e] == "." and e+1<len(ans) and re.match(r"[0-9]", ans[e+1]):
|
181 |
+
e += 1
|
182 |
+
continue
|
183 |
+
if ans[e] == "." and e-2>=0 and ans[e-2] == "\n":
|
184 |
+
e += 1
|
185 |
+
continue
|
186 |
+
res += ans[s: e]
|
187 |
+
citeit()
|
188 |
+
res += ans[e]
|
189 |
+
e += 1
|
190 |
+
s = e
|
191 |
+
|
192 |
+
if s< len(ans):
|
193 |
+
res += ans[s:]
|
194 |
+
citeit()
|
195 |
+
|
196 |
+
return res
|
197 |
+
|
198 |
+
|
199 |
+
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"):
|
200 |
+
ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids]
|
201 |
+
if not ins_embd: return []
|
202 |
+
ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids]
|
203 |
+
#return CosineSimilarity([sres.query_vector], ins_embd)[0]
|
204 |
+
sim = self.qryr.hybrid_similarity(sres.query_vector,
|
205 |
+
ins_embd,
|
206 |
+
huqie.qie(query).split(" "),
|
207 |
+
ins_tw, tkweight, vtweight)
|
208 |
+
return sim
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
if __name__ == "__main__":
|
213 |
+
from util import es_conn
|
214 |
+
SE = Dealer(es_conn.HuEs("infiniflow"))
|
215 |
+
qs = [
|
216 |
+
"胡凯",
|
217 |
+
""
|
218 |
+
]
|
219 |
+
for q in qs:
|
220 |
+
print(">>>>>>>>>>>>>>>>>>>>", q)
|
221 |
+
print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
|
python/parser/docx_parser.py
CHANGED
@@ -3,6 +3,7 @@ import re
|
|
3 |
import pandas as pd
|
4 |
from collections import Counter
|
5 |
from nlp import huqie
|
|
|
6 |
|
7 |
|
8 |
class HuDocxParser:
|
@@ -97,7 +98,7 @@ class HuDocxParser:
|
|
97 |
return ["\n".join(lines)]
|
98 |
|
99 |
def __call__(self, fnm):
|
100 |
-
self.doc = Document(fnm)
|
101 |
secs = [(p.text, p.style.name) for p in self.doc.paragraphs]
|
102 |
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
|
103 |
return secs, tbls
|
|
|
3 |
import pandas as pd
|
4 |
from collections import Counter
|
5 |
from nlp import huqie
|
6 |
+
from io import BytesIO
|
7 |
|
8 |
|
9 |
class HuDocxParser:
|
|
|
98 |
return ["\n".join(lines)]
|
99 |
|
100 |
def __call__(self, fnm):
|
101 |
+
self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm))
|
102 |
secs = [(p.text, p.style.name) for p in self.doc.paragraphs]
|
103 |
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
|
104 |
return secs, tbls
|
python/parser/excel_parser.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
from openpyxl import load_workbook
|
2 |
import sys
|
|
|
3 |
|
4 |
|
5 |
class HuExcelParser:
|
6 |
def __call__(self, fnm):
|
7 |
-
wb = load_workbook(fnm)
|
|
|
8 |
res = []
|
9 |
for sheetname in wb.sheetnames:
|
10 |
ws = wb[sheetname]
|
|
|
1 |
from openpyxl import load_workbook
|
2 |
import sys
|
3 |
+
from io import BytesIO
|
4 |
|
5 |
|
6 |
class HuExcelParser:
|
7 |
def __call__(self, fnm):
|
8 |
+
if isinstance(fnm, str):wb = load_workbook(fnm)
|
9 |
+
else: wb = load_workbook(BytesIO(fnm))
|
10 |
res = []
|
11 |
for sheetname in wb.sheetnames:
|
12 |
ws = wb[sheetname]
|
python/parser/pdf_parser.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import xgboost as xgb
|
|
|
2 |
import torch
|
3 |
import re
|
4 |
import pdfplumber
|
@@ -1525,7 +1526,7 @@ class HuParser:
|
|
1525 |
return "\n\n".join(res)
|
1526 |
|
1527 |
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
1528 |
-
self.pdf = pdfplumber.open(fnm)
|
1529 |
self.lefted_chars = []
|
1530 |
self.mean_height = []
|
1531 |
self.mean_width = []
|
|
|
1 |
import xgboost as xgb
|
2 |
+
from io import BytesIO
|
3 |
import torch
|
4 |
import re
|
5 |
import pdfplumber
|
|
|
1526 |
return "\n\n".join(res)
|
1527 |
|
1528 |
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
1529 |
+
self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
|
1530 |
self.lefted_chars = []
|
1531 |
self.mean_height = []
|
1532 |
self.mean_width = []
|
python/svr/dialog_svr.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#-*- coding:utf-8 -*-
|
2 |
+
import sys, os, re,inspect,json,traceback,logging,argparse, copy
|
3 |
+
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
|
4 |
+
from tornado.web import RequestHandler,Application
|
5 |
+
from tornado.ioloop import IOLoop
|
6 |
+
from tornado.httpserver import HTTPServer
|
7 |
+
from tornado.options import define,options
|
8 |
+
from util import es_conn, setup_logging
|
9 |
+
from svr import sec_search as search
|
10 |
+
from svr.rpc_proxy import RPCProxy
|
11 |
+
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
12 |
+
from nlp import huqie
|
13 |
+
from nlp import query as Query
|
14 |
+
from llm import HuEmbedding, GptTurbo
|
15 |
+
import numpy as np
|
16 |
+
from io import BytesIO
|
17 |
+
from util import config
|
18 |
+
from timeit import default_timer as timer
|
19 |
+
from collections import OrderedDict
|
20 |
+
|
21 |
+
SE = None
|
22 |
+
CFIELD="content_ltks"
|
23 |
+
EMBEDDING = HuEmbedding()
|
24 |
+
LLM = GptTurbo()
|
25 |
+
|
26 |
+
def get_QA_pairs(hists):
|
27 |
+
pa = []
|
28 |
+
for h in hists:
|
29 |
+
for k in ["user", "assistant"]:
|
30 |
+
if h.get(k):
|
31 |
+
pa.append({
|
32 |
+
"content": h[k],
|
33 |
+
"role": k,
|
34 |
+
})
|
35 |
+
|
36 |
+
for p in pa[:-1]: assert len(p) == 2, p
|
37 |
+
return pa
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
def get_instruction(sres, top_i, max_len=8096 fld="content_ltks"):
|
42 |
+
max_len //= len(top_i)
|
43 |
+
# add instruction to prompt
|
44 |
+
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
|
45 |
+
if len(instructions)>2:
|
46 |
+
# Said that LLM is sensitive to the first and the last one, so
|
47 |
+
# rearrange the order of references
|
48 |
+
instructions.append(copy.deepcopy(instructions[1]))
|
49 |
+
instructions.pop(1)
|
50 |
+
|
51 |
+
def token_num(txt):
|
52 |
+
c = 0
|
53 |
+
for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
|
54 |
+
if re.match(r"[a-zA-Z-]+$", tk):
|
55 |
+
c += 1
|
56 |
+
continue
|
57 |
+
c += len(tk)
|
58 |
+
return c
|
59 |
+
|
60 |
+
_inst = ""
|
61 |
+
for ins in instructions:
|
62 |
+
if token_num(_inst) > 4096:
|
63 |
+
_inst += "\n知识库:" + instructions[-1][:max_len]
|
64 |
+
break
|
65 |
+
_inst += "\n知识库:" + ins[:max_len]
|
66 |
+
return _inst
|
67 |
+
|
68 |
+
|
69 |
+
def prompt_and_answer(history, inst):
|
70 |
+
hist = get_QA_pairs(history)
|
71 |
+
chks = []
|
72 |
+
for s in re.split(r"[::;;。\n\r]+", inst):
|
73 |
+
if s: chks.append(s)
|
74 |
+
chks = len(set(chks))/(0.1+len(chks))
|
75 |
+
print("Duplication portion:", chks)
|
76 |
+
|
77 |
+
system = """
|
78 |
+
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
|
79 |
+
以下是知识库:
|
80 |
+
%s
|
81 |
+
以上是知识库。
|
82 |
+
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
|
83 |
+
|
84 |
+
print("【PROMPT】:", system)
|
85 |
+
start = timer()
|
86 |
+
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
|
87 |
+
print("GENERATE: ", timer()-start)
|
88 |
+
print("===>>", response)
|
89 |
+
return response
|
90 |
+
|
91 |
+
|
92 |
+
class Handler(RequestHandler):
|
93 |
+
def post(self):
|
94 |
+
global SE,MUST_TK_NUM
|
95 |
+
param = json.loads(self.request.body.decode('utf-8'))
|
96 |
+
try:
|
97 |
+
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
|
98 |
+
res = SE.search({
|
99 |
+
"question": question,
|
100 |
+
"kb_ids": param.get("kb_ids", []),
|
101 |
+
"size": param.get("topn", 15)
|
102 |
+
})
|
103 |
+
|
104 |
+
sim = SE.rerank(res, question)
|
105 |
+
rk_idx = np.argsort(sim*-1)
|
106 |
+
topidx = [i for i in rk_idx if sim[i] >= aram.get("similarity", 0.5)][:param.get("topn",12)]
|
107 |
+
inst = get_instruction(res, topidx)
|
108 |
+
|
109 |
+
ans, topidx = prompt_and_answer(param["history"], inst)
|
110 |
+
ans = SE.insert_citations(ans, topidx, res)
|
111 |
+
|
112 |
+
refer = OrderedDict()
|
113 |
+
docnms = {}
|
114 |
+
for i in rk_idx:
|
115 |
+
did = res.field[res.ids[i]]["doc_id"])
|
116 |
+
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"])
|
117 |
+
if did not in refer: refer[did] = []
|
118 |
+
refer[did].append({
|
119 |
+
"chunk_id": res.ids[i],
|
120 |
+
"content": res.field[res.ids[i]]["content_ltks"]),
|
121 |
+
"image": ""
|
122 |
+
})
|
123 |
+
|
124 |
+
print("::::::::::::::", ans)
|
125 |
+
self.write(json.dumps({
|
126 |
+
"code":0,
|
127 |
+
"msg":"success",
|
128 |
+
"data":{
|
129 |
+
"uid": param["uid"],
|
130 |
+
"dialog_id": param["dialog_id"],
|
131 |
+
"assistant": ans
|
132 |
+
"refer": [{
|
133 |
+
"did": did,
|
134 |
+
"doc_name": docnms[did],
|
135 |
+
"chunks": chunks
|
136 |
+
} for did, chunks in refer.items()]
|
137 |
+
}
|
138 |
+
}))
|
139 |
+
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
|
140 |
+
|
141 |
+
except Exception as e:
|
142 |
+
logging.error("Request 500: "+str(e))
|
143 |
+
self.write(json.dumps({
|
144 |
+
"code":500,
|
145 |
+
"msg":str(e),
|
146 |
+
"data":{}
|
147 |
+
}))
|
148 |
+
print(traceback.format_exc())
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
parser = argparse.ArgumentParser()
|
153 |
+
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
|
154 |
+
ARGS = parser.parse_args()
|
155 |
+
|
156 |
+
SE = search.ResearchReportSearch(es_conn.HuEs("infiniflow"), EMBEDDING)
|
157 |
+
|
158 |
+
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
|
159 |
+
http_server = HTTPServer(app)
|
160 |
+
http_server.bind(ARGS.port)
|
161 |
+
http_server.start(3)
|
162 |
+
|
163 |
+
IOLoop.current().start()
|
164 |
+
|
python/svr/parse_user_docs.py
CHANGED
@@ -34,18 +34,14 @@ DOC = DocxChunker(DocxParser())
|
|
34 |
EXC = ExcelChunker(ExcelParser())
|
35 |
PPT = PptChunker()
|
36 |
|
37 |
-
|
38 |
-
logging.warning(f"The files are stored in {UPLOAD_LOCATION}, please check it!")
|
39 |
-
|
40 |
-
|
41 |
-
def chuck_doc(name):
|
42 |
suff = os.path.split(name)[-1].lower().split(".")[-1]
|
43 |
-
if suff.find("pdf") >= 0: return PDF(
|
44 |
-
if suff.find("doc") >= 0: return DOC(
|
45 |
-
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(
|
46 |
-
if suff.find("ppt") >= 0: return PPT(
|
47 |
|
48 |
-
return TextChunker()(
|
49 |
|
50 |
|
51 |
def collect(comm, mod, tm):
|
@@ -115,7 +111,7 @@ def build(row):
|
|
115 |
random.seed(time.time())
|
116 |
set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
|
117 |
try:
|
118 |
-
obj = chuck_doc(
|
119 |
except Exception as e:
|
120 |
if re.search("(No such file|not found)", str(e)):
|
121 |
set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
|
@@ -133,9 +129,11 @@ def build(row):
|
|
133 |
doc = {
|
134 |
"doc_id": row["did"],
|
135 |
"kb_id": [str(row["kb_id"])],
|
|
|
136 |
"title_tks": huqie.qie(os.path.split(row["location"])[-1]),
|
137 |
"updated_at": str(row["updated_at"]).replace("T", " ")[:19]
|
138 |
}
|
|
|
139 |
output_buffer = BytesIO()
|
140 |
docs = []
|
141 |
md5 = hashlib.md5()
|
@@ -144,11 +142,14 @@ def build(row):
|
|
144 |
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
145 |
d["_id"] = md5.hexdigest()
|
146 |
d["content_ltks"] = huqie.qie(txt)
|
|
|
147 |
if not img:
|
148 |
docs.append(d)
|
149 |
continue
|
150 |
img.save(output_buffer, format='JPEG')
|
151 |
-
|
|
|
|
|
152 |
docs.append(d)
|
153 |
|
154 |
for arr, img in obj.table_chunks:
|
|
|
34 |
EXC = ExcelChunker(ExcelParser())
|
35 |
PPT = PptChunker()
|
36 |
|
37 |
+
def chuck_doc(name, binary):
|
|
|
|
|
|
|
|
|
38 |
suff = os.path.split(name)[-1].lower().split(".")[-1]
|
39 |
+
if suff.find("pdf") >= 0: return PDF(binary)
|
40 |
+
if suff.find("doc") >= 0: return DOC(binary)
|
41 |
+
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
|
42 |
+
if suff.find("ppt") >= 0: return PPT(binary)
|
43 |
|
44 |
+
return TextChunker()(binary)
|
45 |
|
46 |
|
47 |
def collect(comm, mod, tm):
|
|
|
111 |
random.seed(time.time())
|
112 |
set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
|
113 |
try:
|
114 |
+
obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
|
115 |
except Exception as e:
|
116 |
if re.search("(No such file|not found)", str(e)):
|
117 |
set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
|
|
|
129 |
doc = {
|
130 |
"doc_id": row["did"],
|
131 |
"kb_id": [str(row["kb_id"])],
|
132 |
+
"docnm_kwd": os.path.split(row["location"])[-1],
|
133 |
"title_tks": huqie.qie(os.path.split(row["location"])[-1]),
|
134 |
"updated_at": str(row["updated_at"]).replace("T", " ")[:19]
|
135 |
}
|
136 |
+
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
137 |
output_buffer = BytesIO()
|
138 |
docs = []
|
139 |
md5 = hashlib.md5()
|
|
|
142 |
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
143 |
d["_id"] = md5.hexdigest()
|
144 |
d["content_ltks"] = huqie.qie(txt)
|
145 |
+
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
146 |
if not img:
|
147 |
docs.append(d)
|
148 |
continue
|
149 |
img.save(output_buffer, format='JPEG')
|
150 |
+
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
|
151 |
+
output_buffer.getvalue())
|
152 |
+
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
|
153 |
docs.append(d)
|
154 |
|
155 |
for arr, img in obj.table_chunks:
|
src/api/doc_info.rs
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
use std::collections::HashMap;
|
2 |
-
use std::io::
|
3 |
-
use actix_multipart_extract::{File, Multipart, MultipartForm};
|
4 |
-
use actix_web::{
|
5 |
-
use chrono::{Utc, FixedOffset};
|
6 |
-
use minio::s3::args::{BucketExistsArgs, MakeBucketArgs,
|
7 |
use sea_orm::DbConn;
|
8 |
use crate::api::JsonResponse;
|
9 |
use crate::AppState;
|
@@ -12,9 +12,6 @@ use crate::errors::AppError;
|
|
12 |
use crate::service::doc_info::{ Mutation, Query };
|
13 |
use serde::Deserialize;
|
14 |
|
15 |
-
const BUCKET_NAME: &'static str = "docgpt-upload";
|
16 |
-
|
17 |
-
|
18 |
fn now() -> chrono::DateTime<FixedOffset> {
|
19 |
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
20 |
}
|
@@ -74,53 +71,71 @@ async fn upload(
|
|
74 |
) -> Result<HttpResponse, AppError> {
|
75 |
let uid = payload.uid;
|
76 |
let file_name = payload.file_field.name.as_str();
|
77 |
-
async fn add_number_to_filename(
|
|
|
|
|
|
|
|
|
|
|
78 |
let mut i = 0;
|
79 |
let mut new_file_name = file_name.to_string();
|
80 |
let arr: Vec<&str> = file_name.split(".").collect();
|
81 |
-
let suffix = String::from(arr[arr.len()-1]);
|
82 |
-
let preffix = arr[..arr.len()-1].join(".");
|
83 |
-
let mut docs = Query::find_doc_infos_by_name(
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
85 |
i += 1;
|
86 |
new_file_name = format!("{}_{}.{}", preffix, i, suffix);
|
87 |
-
docs = Query::find_doc_infos_by_name(
|
|
|
|
|
|
|
|
|
|
|
88 |
}
|
89 |
new_file_name
|
90 |
}
|
91 |
let fnm = add_number_to_filename(file_name, &data.conn, uid, payload.did).await;
|
92 |
|
93 |
-
let
|
|
|
94 |
let buckets_exists = s3_client
|
95 |
-
.bucket_exists(&BucketExistsArgs::new(
|
96 |
-
.
|
97 |
if !buckets_exists {
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
101 |
}
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
)
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
114 |
let doc = Mutation::create_doc_info(&data.conn, Model {
|
115 |
-
did:Default::default(),
|
116 |
-
uid:
|
117 |
doc_name: fnm,
|
118 |
size: payload.file_field.bytes.len() as i64,
|
119 |
location,
|
120 |
r#type: "doc".to_string(),
|
121 |
created_at: now(),
|
122 |
updated_at: now(),
|
123 |
-
is_deleted:Default::default(),
|
124 |
}).await?;
|
125 |
|
126 |
let _ = Mutation::place_doc(&data.conn, payload.did, doc.did.unwrap()).await?;
|
|
|
1 |
use std::collections::HashMap;
|
2 |
+
use std::io::BufReader;
|
3 |
+
use actix_multipart_extract::{ File, Multipart, MultipartForm };
|
4 |
+
use actix_web::{ HttpResponse, post, web };
|
5 |
+
use chrono::{ Utc, FixedOffset };
|
6 |
+
use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs };
|
7 |
use sea_orm::DbConn;
|
8 |
use crate::api::JsonResponse;
|
9 |
use crate::AppState;
|
|
|
12 |
use crate::service::doc_info::{ Mutation, Query };
|
13 |
use serde::Deserialize;
|
14 |
|
|
|
|
|
|
|
15 |
fn now() -> chrono::DateTime<FixedOffset> {
|
16 |
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
17 |
}
|
|
|
71 |
) -> Result<HttpResponse, AppError> {
|
72 |
let uid = payload.uid;
|
73 |
let file_name = payload.file_field.name.as_str();
|
74 |
+
async fn add_number_to_filename(
|
75 |
+
file_name: &str,
|
76 |
+
conn: &DbConn,
|
77 |
+
uid: i64,
|
78 |
+
parent_id: i64
|
79 |
+
) -> String {
|
80 |
let mut i = 0;
|
81 |
let mut new_file_name = file_name.to_string();
|
82 |
let arr: Vec<&str> = file_name.split(".").collect();
|
83 |
+
let suffix = String::from(arr[arr.len() - 1]);
|
84 |
+
let preffix = arr[..arr.len() - 1].join(".");
|
85 |
+
let mut docs = Query::find_doc_infos_by_name(
|
86 |
+
conn,
|
87 |
+
uid,
|
88 |
+
&new_file_name,
|
89 |
+
Some(parent_id)
|
90 |
+
).await.unwrap();
|
91 |
+
while docs.len() > 0 {
|
92 |
i += 1;
|
93 |
new_file_name = format!("{}_{}.{}", preffix, i, suffix);
|
94 |
+
docs = Query::find_doc_infos_by_name(
|
95 |
+
conn,
|
96 |
+
uid,
|
97 |
+
&new_file_name,
|
98 |
+
Some(parent_id)
|
99 |
+
).await.unwrap();
|
100 |
}
|
101 |
new_file_name
|
102 |
}
|
103 |
let fnm = add_number_to_filename(file_name, &data.conn, uid, payload.did).await;
|
104 |
|
105 |
+
let bucket_name = format!("{}-upload", payload.uid);
|
106 |
+
let s3_client: &minio::s3::client::Client = &data.s3_client;
|
107 |
let buckets_exists = s3_client
|
108 |
+
.bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await
|
109 |
+
.unwrap();
|
110 |
if !buckets_exists {
|
111 |
+
print!("Create bucket: {}", bucket_name.clone());
|
112 |
+
s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap();
|
113 |
+
} else {
|
114 |
+
print!("Existing bucket: {}", bucket_name.clone());
|
115 |
}
|
116 |
|
117 |
+
let location = format!("/{}/{}", payload.did, fnm);
|
118 |
+
print!("===>{}", location.clone());
|
119 |
+
s3_client.put_object(
|
120 |
+
&mut PutObjectArgs::new(
|
121 |
+
&bucket_name,
|
122 |
+
&location,
|
123 |
+
&mut BufReader::new(payload.file_field.bytes.as_slice()),
|
124 |
+
Some(payload.file_field.bytes.len()),
|
125 |
+
None
|
126 |
+
)?
|
127 |
+
).await?;
|
128 |
+
|
129 |
let doc = Mutation::create_doc_info(&data.conn, Model {
|
130 |
+
did: Default::default(),
|
131 |
+
uid: uid,
|
132 |
doc_name: fnm,
|
133 |
size: payload.file_field.bytes.len() as i64,
|
134 |
location,
|
135 |
r#type: "doc".to_string(),
|
136 |
created_at: now(),
|
137 |
updated_at: now(),
|
138 |
+
is_deleted: Default::default(),
|
139 |
}).await?;
|
140 |
|
141 |
let _ = Mutation::place_doc(&data.conn, payload.did, doc.did.unwrap()).await?;
|
src/api/tag.rs
DELETED
@@ -1,58 +0,0 @@
|
|
1 |
-
use std::collections::HashMap;
|
2 |
-
use actix_web::{get, HttpResponse, post, web};
|
3 |
-
use actix_web::http::Error;
|
4 |
-
use crate::api::JsonResponse;
|
5 |
-
use crate::AppState;
|
6 |
-
use crate::entity::tag_info;
|
7 |
-
use crate::service::tag_info::{Mutation, Query};
|
8 |
-
|
9 |
-
#[post("/v1.0/create_tag")]
|
10 |
-
async fn create(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
|
11 |
-
let model = Mutation::create_tag(&data.conn, model.into_inner()).await.unwrap();
|
12 |
-
|
13 |
-
let mut result = HashMap::new();
|
14 |
-
result.insert("tid", model.tid.unwrap());
|
15 |
-
|
16 |
-
let json_response = JsonResponse {
|
17 |
-
code: 200,
|
18 |
-
err: "".to_owned(),
|
19 |
-
data: result,
|
20 |
-
};
|
21 |
-
|
22 |
-
Ok(HttpResponse::Ok()
|
23 |
-
.content_type("application/json")
|
24 |
-
.body(serde_json::to_string(&json_response).unwrap()))
|
25 |
-
}
|
26 |
-
|
27 |
-
#[post("/v1.0/delete_tag")]
|
28 |
-
async fn delete(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
|
29 |
-
let _ = Mutation::delete_tag(&data.conn, model.tid).await.unwrap();
|
30 |
-
|
31 |
-
let json_response = JsonResponse {
|
32 |
-
code: 200,
|
33 |
-
err: "".to_owned(),
|
34 |
-
data: (),
|
35 |
-
};
|
36 |
-
|
37 |
-
Ok(HttpResponse::Ok()
|
38 |
-
.content_type("application/json")
|
39 |
-
.body(serde_json::to_string(&json_response).unwrap()))
|
40 |
-
}
|
41 |
-
|
42 |
-
#[get("/v1.0/tags")]
|
43 |
-
async fn list(data: web::Data<AppState>) -> Result<HttpResponse, Error> {
|
44 |
-
let tags = Query::find_tag_infos(&data.conn).await.unwrap();
|
45 |
-
|
46 |
-
let mut result = HashMap::new();
|
47 |
-
result.insert("tags", tags);
|
48 |
-
|
49 |
-
let json_response = JsonResponse {
|
50 |
-
code: 200,
|
51 |
-
err: "".to_owned(),
|
52 |
-
data: result,
|
53 |
-
};
|
54 |
-
|
55 |
-
Ok(HttpResponse::Ok()
|
56 |
-
.content_type("application/json")
|
57 |
-
.body(serde_json::to_string(&json_response).unwrap()))
|
58 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/main.rs
CHANGED
@@ -5,9 +5,9 @@ mod errors;
|
|
5 |
|
6 |
use std::env;
|
7 |
use actix_files::Files;
|
8 |
-
use actix_identity::{CookieIdentityPolicy, IdentityService, RequestIdentity};
|
9 |
use actix_session::CookieSession;
|
10 |
-
use actix_web::{web, App, HttpServer, middleware, Error};
|
11 |
use actix_web::cookie::time::Duration;
|
12 |
use actix_web::dev::ServiceRequest;
|
13 |
use actix_web::error::ErrorUnauthorized;
|
@@ -16,9 +16,9 @@ use listenfd::ListenFd;
|
|
16 |
use minio::s3::client::Client;
|
17 |
use minio::s3::creds::StaticProvider;
|
18 |
use minio::s3::http::BaseUrl;
|
19 |
-
use sea_orm::{Database, DatabaseConnection};
|
20 |
-
use migration::{Migrator, MigratorTrait};
|
21 |
-
use crate::errors::{AppError, UserError};
|
22 |
|
23 |
#[derive(Debug, Clone)]
|
24 |
struct AppState {
|
@@ -28,10 +28,10 @@ struct AppState {
|
|
28 |
|
29 |
pub(crate) async fn validator(
|
30 |
req: ServiceRequest,
|
31 |
-
credentials: BearerAuth
|
32 |
) -> Result<ServiceRequest, Error> {
|
33 |
if let Some(token) = req.get_identity() {
|
34 |
-
println!("{}, {}",credentials.token(), token);
|
35 |
(credentials.token() == token)
|
36 |
.then(|| req)
|
37 |
.ok_or(ErrorUnauthorized(UserError::InvalidToken))
|
@@ -52,26 +52,25 @@ async fn main() -> Result<(), AppError> {
|
|
52 |
let port = env::var("PORT").expect("PORT is not set in .env file");
|
53 |
let server_url = format!("{host}:{port}");
|
54 |
|
55 |
-
let s3_base_url = env::var("
|
56 |
-
let s3_access_key = env::var("
|
57 |
-
let s3_secret_key = env::var("
|
|
|
|
|
|
|
58 |
|
59 |
// establish connection to database and apply migrations
|
60 |
// -> create post table if not exists
|
61 |
let conn = Database::connect(&db_url).await.unwrap();
|
62 |
Migrator::up(&conn, None).await.unwrap();
|
63 |
|
64 |
-
let static_provider = StaticProvider::new(
|
65 |
-
s3_access_key.as_str(),
|
66 |
-
s3_secret_key.as_str(),
|
67 |
-
None,
|
68 |
-
);
|
69 |
|
70 |
let s3_client = Client::new(
|
71 |
s3_base_url.parse::<BaseUrl>()?,
|
72 |
Some(Box::new(static_provider)),
|
73 |
None,
|
74 |
-
|
75 |
)?;
|
76 |
|
77 |
let state = AppState { conn, s3_client };
|
@@ -82,18 +81,20 @@ async fn main() -> Result<(), AppError> {
|
|
82 |
App::new()
|
83 |
.service(Files::new("/static", "./static"))
|
84 |
.app_data(web::Data::new(state.clone()))
|
85 |
-
.wrap(
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
91 |
.wrap(
|
92 |
CookieSession::signed(&[0; 32])
|
93 |
.name("session-cookie")
|
94 |
.secure(false)
|
95 |
// WARNING(alex): This uses the `time` crate, not `std::time`!
|
96 |
-
.expires_in_time(Duration::seconds(60))
|
97 |
)
|
98 |
.wrap(middleware::Logger::default())
|
99 |
.configure(init)
|
@@ -137,4 +138,4 @@ fn init(cfg: &mut web::ServiceConfig) {
|
|
137 |
cfg.service(api::user_info::login);
|
138 |
cfg.service(api::user_info::register);
|
139 |
cfg.service(api::user_info::setting);
|
140 |
-
}
|
|
|
5 |
|
6 |
use std::env;
|
7 |
use actix_files::Files;
|
8 |
+
use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity };
|
9 |
use actix_session::CookieSession;
|
10 |
+
use actix_web::{ web, App, HttpServer, middleware, Error };
|
11 |
use actix_web::cookie::time::Duration;
|
12 |
use actix_web::dev::ServiceRequest;
|
13 |
use actix_web::error::ErrorUnauthorized;
|
|
|
16 |
use minio::s3::client::Client;
|
17 |
use minio::s3::creds::StaticProvider;
|
18 |
use minio::s3::http::BaseUrl;
|
19 |
+
use sea_orm::{ Database, DatabaseConnection };
|
20 |
+
use migration::{ Migrator, MigratorTrait };
|
21 |
+
use crate::errors::{ AppError, UserError };
|
22 |
|
23 |
#[derive(Debug, Clone)]
|
24 |
struct AppState {
|
|
|
28 |
|
29 |
pub(crate) async fn validator(
|
30 |
req: ServiceRequest,
|
31 |
+
credentials: BearerAuth
|
32 |
) -> Result<ServiceRequest, Error> {
|
33 |
if let Some(token) = req.get_identity() {
|
34 |
+
println!("{}, {}", credentials.token(), token);
|
35 |
(credentials.token() == token)
|
36 |
.then(|| req)
|
37 |
.ok_or(ErrorUnauthorized(UserError::InvalidToken))
|
|
|
52 |
let port = env::var("PORT").expect("PORT is not set in .env file");
|
53 |
let server_url = format!("{host}:{port}");
|
54 |
|
55 |
+
let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file");
|
56 |
+
let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file");
|
57 |
+
let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file");
|
58 |
+
if s3_base_url.find("http") != Some(0) {
|
59 |
+
s3_base_url = format!("http://{}", s3_base_url);
|
60 |
+
}
|
61 |
|
62 |
// establish connection to database and apply migrations
|
63 |
// -> create post table if not exists
|
64 |
let conn = Database::connect(&db_url).await.unwrap();
|
65 |
Migrator::up(&conn, None).await.unwrap();
|
66 |
|
67 |
+
let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None);
|
|
|
|
|
|
|
|
|
68 |
|
69 |
let s3_client = Client::new(
|
70 |
s3_base_url.parse::<BaseUrl>()?,
|
71 |
Some(Box::new(static_provider)),
|
72 |
None,
|
73 |
+
Some(true)
|
74 |
)?;
|
75 |
|
76 |
let state = AppState { conn, s3_client };
|
|
|
81 |
App::new()
|
82 |
.service(Files::new("/static", "./static"))
|
83 |
.app_data(web::Data::new(state.clone()))
|
84 |
+
.wrap(
|
85 |
+
IdentityService::new(
|
86 |
+
CookieIdentityPolicy::new(&[0; 32])
|
87 |
+
.name("auth-cookie")
|
88 |
+
.login_deadline(Duration::seconds(120))
|
89 |
+
.secure(false)
|
90 |
+
)
|
91 |
+
)
|
92 |
.wrap(
|
93 |
CookieSession::signed(&[0; 32])
|
94 |
.name("session-cookie")
|
95 |
.secure(false)
|
96 |
// WARNING(alex): This uses the `time` crate, not `std::time`!
|
97 |
+
.expires_in_time(Duration::seconds(60))
|
98 |
)
|
99 |
.wrap(middleware::Logger::default())
|
100 |
.configure(init)
|
|
|
138 |
cfg.service(api::user_info::login);
|
139 |
cfg.service(api::user_info::register);
|
140 |
cfg.service(api::user_info::setting);
|
141 |
+
}
|