KevinHuSh committed on
Commit 3245107 · 1 Parent(s): d4fd138

use minio to store uploaded files; build dialog server; (#16)


* format code

* use minio to store uploaded files; build dialog server;

python/llm/__init__.py CHANGED
@@ -1 +1,2 @@
  from .embedding_model import HuEmbedding
+ from .chat_model import GptTurbo
python/llm/chat_model.py ADDED
@@ -0,0 +1,34 @@
+ from abc import ABC
+ import openai
+ import os
+
+ class Base(ABC):
+     def chat(self, system, history, gen_conf):
+         raise NotImplementedError("Please implement chat method!")
+
+
+ class GptTurbo(Base):
+     def __init__(self):
+         openai.api_key = os.environ["OPENAPI_KEY"]
+
+     def chat(self, system, history, gen_conf):
+         history.insert(0, {"role": "system", "content": system})
+         res = openai.ChatCompletion.create(model="gpt-3.5-turbo",
+                                            messages=history,
+                                            **gen_conf)
+         return res.choices[0].message.content.strip()
+
+
+ class QWen(Base):
+     def chat(self, system, history, gen_conf):
+         from http import HTTPStatus
+         from dashscope import Generation
+         from dashscope.api_entities.dashscope_response import Role
+         history.insert(0, {"role": "system", "content": system})
+         response = Generation.call(
+             Generation.Models.qwen_turbo,
+             messages=history,
+             result_format='message'
+         )
+         if response.status_code == HTTPStatus.OK:
+             return response.output.choices[0]['message']['content']
+         return response.message
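
A minimal sketch of how these wrappers are meant to be called (the API key value and the prompt text below are placeholders; the env var name and `gen_conf` keys mirror the code above):

```python
import os
from llm import GptTurbo

os.environ.setdefault("OPENAPI_KEY", "sk-...")  # placeholder; key name as read in __init__ above

mdl = GptTurbo()
answer = mdl.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Say hello in one sentence."}],
    gen_conf={"temperature": 0.2, "max_tokens": 64},
)
print(answer)
```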
python/llm/embedding_model.py CHANGED
@@ -1,6 +1,7 @@
  from abc import ABC
  from FlagEmbedding import FlagModel
  import torch
+ import numpy as np

  class Base(ABC):
      def encode(self, texts: list, batch_size=32):
@@ -27,5 +28,5 @@ class HuEmbedding(Base):
      def encode(self, texts: list, batch_size=32):
          res = []
          for i in range(0, len(texts), batch_size):
-             res.extend(self.encode(texts[i:i+batch_size]))
-         return res
+             res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
+         return np.array(res)
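
The batching fix delegates to the underlying FlagModel and stacks the per-batch results into one array. A small usage sketch (texts are illustrative, and FlagEmbedding plus its model weights must be available locally):

```python
from llm import HuEmbedding

emb = HuEmbedding()  # loads the FlagEmbedding model configured in embedding_model.py
vecs = emb.encode(["人工智能是什么?", "What is retrieval-augmented generation?"])
print(vecs.shape)  # (2, embedding_dim) numpy array, thanks to the np.array(res) return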
python/nlp/huchunk.py CHANGED
@@ -372,7 +372,7 @@ class PptChunker(HuChunker):

      def __call__(self, fnm):
          from pptx import Presentation
-         ppt = Presentation(fnm)
+         ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
          flds = self.Fields()
          flds.text_chunks = []
          for slide in ppt.slides:
@@ -396,7 +396,9 @@ class TextChunker(HuChunker):
      @staticmethod
      def is_binary_file(file_path):
          mime = magic.Magic(mime=True)
-         file_type = mime.from_file(file_path)
+         if isinstance(file_path, str):
+             file_type = mime.from_file(file_path)
+         else: file_type = mime.from_buffer(file_path)
          if 'text' in file_type:
              return False
          else:
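
With this change the chunkers accept either a filesystem path or raw bytes, which is what lets the parsing worker feed them objects fetched from MinIO. A minimal sketch of the calling pattern (the file name is illustrative and the import path assumes the repo's nlp.huchunk module layout):

```python
from nlp.huchunk import PptChunker, TextChunker

ppt_chunker = PptChunker()

# From a path on disk (old behaviour, still supported).
fields = ppt_chunker("slides.pptx")

# From raw bytes, e.g. an object just downloaded from object storage.
with open("slides.pptx", "rb") as f:
    fields = ppt_chunker(f.read())

# is_binary_file follows the same convention: str means a path, bytes means a buffer.
print(TextChunker.is_binary_file(b"plain text payload"))
```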
python/nlp/search.py ADDED
@@ -0,0 +1,221 @@
+ import re
+ from elasticsearch_dsl import Q, Search, A
+ from typing import List, Optional, Tuple, Dict, Union
+ from dataclasses import dataclass
+ from util import setup_logging, rmSpace
+ from nlp import huqie, query
+ from datetime import datetime
+ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
+ import numpy as np
+ from copy import deepcopy
+
+
+ class Dealer:
+     def __init__(self, es, emb_mdl):
+         self.qryr = query.EsQueryer(es)
+         self.qryr.flds = ["title_tks^10", "title_sm_tks^5",
+                           "content_ltks^2", "content_sm_ltks"]
+         self.es = es
+         self.emb_mdl = emb_mdl
+
+     @dataclass
+     class SearchResult:
+         total: int
+         ids: List[str]
+         query_vector: List[float] = None
+         field: Optional[Dict] = None
+         highlight: Optional[Dict] = None
+         aggregation: Union[List, Dict, None] = None
+         keywords: Optional[List[str]] = None
+         group_docs: List[List] = None
+
+     def _vector(self, txt, sim=0.8, topk=10):
+         return {
+             "field": "q_vec",
+             "k": topk,
+             "similarity": sim,
+             "num_candidates": 1000,
+             "query_vector": self.emb_mdl.encode_queries(txt)
+         }
+
+     def search(self, req, idxnm, tks_num=3):
+         keywords = []
+         qst = req.get("question", "")
+
+         bqry, keywords = self.qryr.question(qst)
+         if req.get("kb_ids"):
+             bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
+         bqry.filter.append(Q("exists", field="q_tks"))
+         bqry.boost = 0.05
+         print(bqry)
+
+         s = Search()
+         pg = int(req.get("page", 1)) - 1
+         ps = int(req.get("size", 1000))
+         src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
+                                 "image_id", "doc_id", "q_vec"])
+
+         s = s.query(bqry)[pg * ps:(pg + 1) * ps]
+         s = s.highlight("content_ltks")
+         s = s.highlight("title_ltks")
+         if not qst:
+             s = s.sort({"create_time": {"order": "desc", "unmapped_type": "date"}})
+
+         s = s.highlight_options(
+             fragment_size=120,
+             number_of_fragments=5,
+             boundary_scanner_locale="zh-CN",
+             boundary_scanner="SENTENCE",
+             boundary_chars=",./;:\\!(),。?:!……()——、"
+         )
+         s = s.to_dict()
+         q_vec = []
+         if req.get("vector"):
+             s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
+             s["knn"]["filter"] = bqry.to_dict()
+             del s["highlight"]
+             q_vec = s["knn"]["query_vector"]
+         res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
+         print("TOTAL: ", self.es.getTotal(res))
+         if self.es.getTotal(res) == 0 and "knn" in s:
+             bqry, _ = self.qryr.question(qst, min_match="10%")
+             if req.get("kb_ids"):
+                 bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
+             s["query"] = bqry.to_dict()
+             s["knn"]["filter"] = bqry.to_dict()
+             s["knn"]["similarity"] = 0.7
+             res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
+
+         kwds = set([])
+         for k in keywords:
+             kwds.add(k)
+             for kk in huqie.qieqie(k).split(" "):
+                 if len(kk) < 2:
+                     continue
+                 if kk in kwds:
+                     continue
+                 kwds.add(kk)
+
+         aggs = self.getAggregation(res, "docnm_kwd")
+
+         return self.SearchResult(
+             total=self.es.getTotal(res),
+             ids=self.es.getDocIds(res),
+             query_vector=q_vec,
+             aggregation=aggs,
+             highlight=self.getHighlight(res),
+             field=self.getFields(res, ["docnm_kwd", "content_ltks",
+                                        "kb_id", "image_id", "doc_id", "q_vec"]),
+             keywords=list(kwds)
+         )
+
+     def getAggregation(self, res, g):
+         if "aggregations" not in res or "aggs_" + g not in res["aggregations"]:
+             return
+         bkts = res["aggregations"]["aggs_" + g]["buckets"]
+         return [(b["key"], b["doc_count"]) for b in bkts]
+
+     def getHighlight(self, res):
+         def rmspace(line):
+             eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
+             r = []
+             for t in line.split(" "):
+                 if not t:
+                     continue
+                 if len(r) > 0 and len(t) > 0 and r[-1][-1] in eng and t[0] in eng:
+                     r.append(" ")
+                 r.append(t)
+             r = "".join(r)
+             return r
+
+         ans = {}
+         for d in res["hits"]["hits"]:
+             hlts = d.get("highlight")
+             if not hlts:
+                 continue
+             ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
+         return ans
+
+     def getFields(self, sres, flds):
+         res = {}
+         if not flds:
+             return {}
+         for d in self.es.getSource(sres):
+             m = {n: d.get(n) for n in flds if d.get(n) is not None}
+             for n, v in m.items():
+                 if type(v) == type([]):
+                     m[n] = "\t".join([str(vv) for vv in v])
+                     continue
+                 if type(v) != type(""):
+                     m[n] = str(m[n])
+                 m[n] = rmSpace(m[n])
+
+             if m:
+                 res[d["id"]] = m
+         return res
+
+     @staticmethod
+     def trans2floats(txt):
+         return [float(t) for t in txt.split("\t")]
+
+     def insert_citations(self, ans, top_idx, sres, vfield="q_vec", cfield="content_ltks"):
+         ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx]
+         ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
+         s = 0
+         e = 0
+         res = ""
+
+         def citeit():
+             nonlocal s, e, ans, res
+             if not ins_embd:
+                 return
+             embd = self.emb_mdl.encode(ans[s: e])
+             sim = self.qryr.hybrid_similarity(embd,
+                                               ins_embd,
+                                               huqie.qie(ans[s:e]).split(" "),
+                                               ins_tw)
+             print(ans[s: e], sim)
+             mx = np.max(sim) * 0.99
+             if mx < 0.55:
+                 return
+             cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
+             for i in cita:
+                 res += f"@?{i}?@"
+
+             return cita
+
+         punct = set(";。?!!")
+         if not self.qryr.isChinese(ans):
+             punct.add("?")
+             punct.add(".")
+         while e < len(ans):
+             if e - s < 12 or ans[e] not in punct:
+                 e += 1
+                 continue
+             if ans[e] == "." and e + 1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
+                 e += 1
+                 continue
+             if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
+                 e += 1
+                 continue
+             res += ans[s: e]
+             citeit()
+             res += ans[e]
+             e += 1
+             s = e
+
+         if s < len(ans):
+             res += ans[s:]
+             citeit()
+
+         return res
+
+     def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"):
+         ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids]
+         if not ins_embd:
+             return []
+         ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
+         # return CosineSimilarity([sres.query_vector], ins_embd)[0]
+         sim = self.qryr.hybrid_similarity(sres.query_vector,
+                                           ins_embd,
+                                           huqie.qie(query).split(" "),
+                                           ins_tw, tkweight, vtweight)
+         return sim
+
+
+ if __name__ == "__main__":
+     from util import es_conn
+     from llm import HuEmbedding
+     SE = Dealer(es_conn.HuEs("infiniflow"), HuEmbedding())
+     qs = [
+         "胡凯",
+         ""
+     ]
+     for q in qs:
+         print(">>>>>>>>>>>>>>>>>>>>", q)
+         print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
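
For orientation, a minimal sketch of how this retriever is wired together (mirroring the file's own __main__ block; the question, knowledge-base id and index pattern are illustrative):

```python
from util import es_conn
from llm import HuEmbedding
from nlp.search import Dealer

dealer = Dealer(es_conn.HuEs("infiniflow"), HuEmbedding())

# Keyword + kNN retrieval against one knowledge base, then rerank the hits
# by the hybrid (token overlap + embedding cosine) similarity.
req = {"question": "How do I reset my password?",
       "kb_ids": ["64f072a75f3b97c865718c4a"],   # illustrative id
       "vector": True, "size": 15}
sres = dealer.search(req, "infiniflow_*")
scores = dealer.rerank(sres, req["question"])
```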
python/parser/docx_parser.py CHANGED
@@ -3,6 +3,7 @@ import re
  import pandas as pd
  from collections import Counter
  from nlp import huqie
+ from io import BytesIO


  class HuDocxParser:
@@ -97,7 +98,7 @@ class HuDocxParser:
          return ["\n".join(lines)]

      def __call__(self, fnm):
+         self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm))
          secs = [(p.text, p.style.name) for p in self.doc.paragraphs]
          tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
          return secs, tbls
python/parser/excel_parser.py CHANGED
@@ -1,10 +1,12 @@
  from openpyxl import load_workbook
  import sys
+ from io import BytesIO


  class HuExcelParser:
      def __call__(self, fnm):
-         wb = load_workbook(fnm)
+         if isinstance(fnm, str): wb = load_workbook(fnm)
+         else: wb = load_workbook(BytesIO(fnm))
          res = []
          for sheetname in wb.sheetnames:
              ws = wb[sheetname]
python/parser/pdf_parser.py CHANGED
@@ -1,4 +1,5 @@
  import xgboost as xgb
+ from io import BytesIO
  import torch
  import re
  import pdfplumber
@@ -1525,7 +1526,7 @@ class HuParser:
          return "\n\n".join(res)

      def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
-         self.pdf = pdfplumber.open(fnm)
+         self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
          self.lefted_chars = []
          self.mean_height = []
          self.mean_width = []
python/svr/dialog_svr.py ADDED
@@ -0,0 +1,164 @@
+ # -*- coding:utf-8 -*-
+ import sys, os, re, inspect, json, traceback, logging, argparse, copy
+ sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe()))) + "/../")
+ from tornado.web import RequestHandler, Application
+ from tornado.ioloop import IOLoop
+ from tornado.httpserver import HTTPServer
+ from tornado.options import define, options
+ from util import es_conn, setup_logging
+ from svr import sec_search as search
+ from svr.rpc_proxy import RPCProxy
+ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
+ from nlp import huqie
+ from nlp import query as Query
+ from llm import HuEmbedding, GptTurbo
+ import numpy as np
+ from io import BytesIO
+ from util import config
+ from timeit import default_timer as timer
+ from collections import OrderedDict
+
+ SE = None
+ CFIELD = "content_ltks"
+ EMBEDDING = HuEmbedding()
+ LLM = GptTurbo()
+
+
+ def get_QA_pairs(hists):
+     pa = []
+     for h in hists:
+         for k in ["user", "assistant"]:
+             if h.get(k):
+                 pa.append({
+                     "content": h[k],
+                     "role": k,
+                 })
+
+     for p in pa[:-1]: assert len(p) == 2, p
+     return pa
+
+
+ def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
+     max_len //= len(top_i)
+     # add the retrieved references to the prompt
+     instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
+     if len(instructions) > 2:
+         # LLMs are said to be most sensitive to the first and last references,
+         # so rearrange the order accordingly.
+         instructions.append(copy.deepcopy(instructions[1]))
+         instructions.pop(1)
+
+     def token_num(txt):
+         c = 0
+         for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
+             if re.match(r"[a-zA-Z-]+$", tk):
+                 c += 1
+                 continue
+             c += len(tk)
+         return c
+
+     _inst = ""
+     for ins in instructions:
+         if token_num(_inst) > 4096:
+             _inst += "\n知识库:" + instructions[-1][:max_len]
+             break
+         _inst += "\n知识库:" + ins[:max_len]
+     return _inst
+
+
+ def prompt_and_answer(history, inst):
+     hist = get_QA_pairs(history)
+     chks = []
+     for s in re.split(r"[::;;。\n\r]+", inst):
+         if s: chks.append(s)
+     chks = len(set(chks)) / (0.1 + len(chks))
+     print("Duplication portion:", chks)
+
+     system = """
+ 你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
+ 以下是知识库:
+ %s
+ 以上是知识库。
+ """ % ((",最好总结成表格" if chks < 0.6 and chks > 0 else ""), inst)
+
+     print("【PROMPT】:", system)
+     start = timer()
+     response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
+     print("GENERATE: ", timer() - start)
+     print("===>>", response)
+     return response
+
+
+ class Handler(RequestHandler):
+     def post(self):
+         global SE, MUST_TK_NUM
+         param = json.loads(self.request.body.decode('utf-8'))
+         try:
+             question = param.get("history", [{"user": "Hi!"}])[-1]["user"]
+             res = SE.search({
+                 "question": question,
+                 "kb_ids": param.get("kb_ids", []),
+                 "size": param.get("topn", 15)
+             })
+
+             sim = SE.rerank(res, question)
+             rk_idx = np.argsort(sim * -1)
+             topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn", 12)]
+             inst = get_instruction(res, topidx)
+
+             ans = prompt_and_answer(param["history"], inst)
+             ans = SE.insert_citations(ans, topidx, res)
+
+             refer = OrderedDict()
+             docnms = {}
+             for i in rk_idx:
+                 did = res.field[res.ids[i]]["doc_id"]
+                 if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
+                 if did not in refer: refer[did] = []
+                 refer[did].append({
+                     "chunk_id": res.ids[i],
+                     "content": res.field[res.ids[i]]["content_ltks"],
+                     "image": ""
+                 })
+
+             print("::::::::::::::", ans)
+             self.write(json.dumps({
+                 "code": 0,
+                 "msg": "success",
+                 "data": {
+                     "uid": param["uid"],
+                     "dialog_id": param["dialog_id"],
+                     "assistant": ans,
+                     "refer": [{
+                         "did": did,
+                         "doc_name": docnms[did],
+                         "chunks": chunks
+                     } for did, chunks in refer.items()]
+                 }
+             }))
+             logging.info("SUCCESS[%d]" % (res.total) + json.dumps(param, ensure_ascii=False))
+
+         except Exception as e:
+             logging.error("Request 500: " + str(e))
+             self.write(json.dumps({
+                 "code": 500,
+                 "msg": str(e),
+                 "data": {}
+             }))
+             print(traceback.format_exc())
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", default=4455, type=int, help="Port used for service")
+     ARGS = parser.parse_args()
+
+     SE = search.ResearchReportSearch(es_conn.HuEs("infiniflow"), EMBEDDING)
+
+     app = Application([(r'/v1/chat/completions', Handler)], debug=False)
+     http_server = HTTPServer(app)
+     http_server.bind(ARGS.port)
+     http_server.start(3)
+
+     IOLoop.current().start()
+
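
A minimal sketch of a request to the new dialog endpoint, using only the JSON fields the handler above actually reads (the uid, dialog_id, knowledge-base id and question values are placeholders):

```python
import requests

payload = {
    "uid": 1,                                   # echoed back in the response
    "dialog_id": 1,
    "kb_ids": ["64f072a75f3b97c865718c4a"],     # illustrative id
    "topn": 12,
    "similarity": 0.5,
    "history": [{"user": "什么是知识库?"}],     # last "user" turn becomes the question
}
r = requests.post("http://localhost:4455/v1/chat/completions", json=payload)
print(r.json()["data"]["assistant"])
```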
python/svr/parse_user_docs.py CHANGED
@@ -34,18 +34,14 @@ DOC = DocxChunker(DocxParser())
  EXC = ExcelChunker(ExcelParser())
  PPT = PptChunker()

- UPLOAD_LOCATION = os.environ.get("UPLOAD_LOCATION", "./")
- logging.warning(f"The files are stored in {UPLOAD_LOCATION}, please check it!")
-
-
- def chuck_doc(name):
+ def chuck_doc(name, binary):
      suff = os.path.split(name)[-1].lower().split(".")[-1]
-     if suff.find("pdf") >= 0: return PDF(name)
-     if suff.find("doc") >= 0: return DOC(name)
-     if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(name)
-     if suff.find("ppt") >= 0: return PPT(name)
+     if suff.find("pdf") >= 0: return PDF(binary)
+     if suff.find("doc") >= 0: return DOC(binary)
+     if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
+     if suff.find("ppt") >= 0: return PPT(binary)

-     return TextChunker()(name)
+     return TextChunker()(binary)


  def collect(comm, mod, tm):
@@ -115,7 +111,7 @@ def build(row):
      random.seed(time.time())
      set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
      try:
-         obj = chuck_doc(os.path.join(UPLOAD_LOCATION, row["location"]))
+         obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
      except Exception as e:
          if re.search("(No such file|not found)", str(e)):
              set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
@@ -133,9 +129,11 @@ def build(row):
      doc = {
          "doc_id": row["did"],
          "kb_id": [str(row["kb_id"])],
+         "docnm_kwd": os.path.split(row["location"])[-1],
          "title_tks": huqie.qie(os.path.split(row["location"])[-1]),
          "updated_at": str(row["updated_at"]).replace("T", " ")[:19]
      }
+     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
      output_buffer = BytesIO()
      docs = []
      md5 = hashlib.md5()
@@ -144,11 +142,14 @@ def build(row):
          md5.update((txt + str(d["doc_id"])).encode("utf-8"))
          d["_id"] = md5.hexdigest()
          d["content_ltks"] = huqie.qie(txt)
+         d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
          if not img:
              docs.append(d)
              continue
          img.save(output_buffer, format='JPEG')
-         d["img_bin"] = str(output_buffer.getvalue())
+         MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
+                   output_buffer.getvalue())
+         d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
          docs.append(d)

      for arr, img in obj.table_chunks:
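
build() now reads the raw document from MinIO and writes chunk images back to a per-knowledge-base bucket. The MINIO helper itself is outside this diff; a minimal sketch of what such a get/put wrapper could look like with the official minio Python SDK (class name, bucket handling and credentials are assumptions, not the project's actual code):

```python
from io import BytesIO
from minio import Minio

class MinioWrapper:
    """Hypothetical thin wrapper matching the MINIO.get/MINIO.put calls above."""
    def __init__(self, host, user, password):
        self.cli = Minio(host, access_key=user, secret_key=password, secure=False)

    def get(self, bucket, object_name):
        resp = self.cli.get_object(bucket, object_name)
        try:
            return resp.read()          # raw bytes fed into chuck_doc()
        finally:
            resp.close()
            resp.release_conn()

    def put(self, bucket, object_name, binary):
        if not self.cli.bucket_exists(bucket):
            self.cli.make_bucket(bucket)
        self.cli.put_object(bucket, object_name, BytesIO(binary), len(binary))
```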
src/api/doc_info.rs CHANGED
@@ -1,9 +1,9 @@
  use std::collections::HashMap;
- use std::io::Write;
- use actix_multipart_extract::{File, Multipart, MultipartForm};
- use actix_web::{get, HttpResponse, post, web};
- use chrono::{Utc, FixedOffset};
- use minio::s3::args::{BucketExistsArgs, MakeBucketArgs, UploadObjectArgs};
+ use std::io::BufReader;
+ use actix_multipart_extract::{ File, Multipart, MultipartForm };
+ use actix_web::{ HttpResponse, post, web };
+ use chrono::{ Utc, FixedOffset };
+ use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs };
  use sea_orm::DbConn;
  use crate::api::JsonResponse;
  use crate::AppState;
@@ -12,9 +12,6 @@ use crate::errors::AppError;
  use crate::service::doc_info::{ Mutation, Query };
  use serde::Deserialize;

- const BUCKET_NAME: &'static str = "docgpt-upload";
-
-
  fn now() -> chrono::DateTime<FixedOffset> {
      Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
  }
@@ -74,53 +71,71 @@ async fn upload(
  ) -> Result<HttpResponse, AppError> {
      let uid = payload.uid;
      let file_name = payload.file_field.name.as_str();
-     async fn add_number_to_filename(file_name: &str, conn: &DbConn, uid: i64, parent_id: i64) -> String {
+     async fn add_number_to_filename(
+         file_name: &str,
+         conn: &DbConn,
+         uid: i64,
+         parent_id: i64
+     ) -> String {
          let mut i = 0;
         let mut new_file_name = file_name.to_string();
          let arr: Vec<&str> = file_name.split(".").collect();
-         let suffix = String::from(arr[arr.len()-1]);
-         let preffix = arr[..arr.len()-1].join(".");
-         let mut docs = Query::find_doc_infos_by_name(conn, uid, &new_file_name, Some(parent_id)).await.unwrap();
-         while docs.len()>0 {
+         let suffix = String::from(arr[arr.len() - 1]);
+         let preffix = arr[..arr.len() - 1].join(".");
+         let mut docs = Query::find_doc_infos_by_name(
+             conn,
+             uid,
+             &new_file_name,
+             Some(parent_id)
+         ).await.unwrap();
+         while docs.len() > 0 {
              i += 1;
              new_file_name = format!("{}_{}.{}", preffix, i, suffix);
-             docs = Query::find_doc_infos_by_name(conn, uid, &new_file_name, Some(parent_id)).await.unwrap();
+             docs = Query::find_doc_infos_by_name(
+                 conn,
+                 uid,
+                 &new_file_name,
+                 Some(parent_id)
+             ).await.unwrap();
          }
          new_file_name
      }
      let fnm = add_number_to_filename(file_name, &data.conn, uid, payload.did).await;

-     let s3_client = &data.s3_client;
+     let bucket_name = format!("{}-upload", payload.uid);
+     let s3_client: &minio::s3::client::Client = &data.s3_client;
      let buckets_exists = s3_client
-         .bucket_exists(&BucketExistsArgs::new(BUCKET_NAME)?)
-         .await?;
+         .bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await
+         .unwrap();
      if !buckets_exists {
-         s3_client
-             .make_bucket(&MakeBucketArgs::new(BUCKET_NAME)?)
-             .await?;
+         print!("Create bucket: {}", bucket_name.clone());
+         s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap();
+     } else {
+         print!("Existing bucket: {}", bucket_name.clone());
      }

-     s3_client
-         .upload_object(
-             &mut UploadObjectArgs::new(
-                 BUCKET_NAME,
-                 fnm.as_str(),
-                 format!("/{}/{}-{}", payload.uid, payload.did, fnm).as_str()
-             )?
-         )
-         .await?;
-
-     let location = format!("/{}/{}", BUCKET_NAME, fnm);
+     let location = format!("/{}/{}", payload.did, fnm);
+     print!("===>{}", location.clone());
+     s3_client.put_object(
+         &mut PutObjectArgs::new(
+             &bucket_name,
+             &location,
+             &mut BufReader::new(payload.file_field.bytes.as_slice()),
+             Some(payload.file_field.bytes.len()),
+             None
+         )?
+     ).await?;
+
      let doc = Mutation::create_doc_info(&data.conn, Model {
-         did:Default::default(),
-         uid: uid,
+         did: Default::default(),
+         uid: uid,
          doc_name: fnm,
          size: payload.file_field.bytes.len() as i64,
          location,
          r#type: "doc".to_string(),
          created_at: now(),
          updated_at: now(),
-         is_deleted:Default::default(),
+         is_deleted: Default::default(),
      }).await?;

      let _ = Mutation::place_doc(&data.conn, payload.did, doc.did.unwrap()).await?;
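
The upload handler now writes each file to a per-user bucket named `{uid}-upload` under the key `/{did}/{filename}`, which is exactly what `parse_user_docs.py` later reads back via `MINIO.get`. A quick way to sanity-check an upload from Python (endpoint, credentials and ids below are placeholders):

```python
from minio import Minio

cli = Minio("127.0.0.1:9000", access_key="minioadmin",
            secret_key="minioadmin", secure=False)   # placeholder credentials

uid, did, fnm = 1, 2, "report.pdf"                   # illustrative ids
obj = cli.get_object(f"{uid}-upload", f"/{did}/{fnm}")
print(len(obj.read()), "bytes stored for", fnm)
```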
src/api/tag.rs DELETED
@@ -1,58 +0,0 @@
- use std::collections::HashMap;
- use actix_web::{get, HttpResponse, post, web};
- use actix_web::http::Error;
- use crate::api::JsonResponse;
- use crate::AppState;
- use crate::entity::tag_info;
- use crate::service::tag_info::{Mutation, Query};
-
- #[post("/v1.0/create_tag")]
- async fn create(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
-     let model = Mutation::create_tag(&data.conn, model.into_inner()).await.unwrap();
-
-     let mut result = HashMap::new();
-     result.insert("tid", model.tid.unwrap());
-
-     let json_response = JsonResponse {
-         code: 200,
-         err: "".to_owned(),
-         data: result,
-     };
-
-     Ok(HttpResponse::Ok()
-         .content_type("application/json")
-         .body(serde_json::to_string(&json_response).unwrap()))
- }
-
- #[post("/v1.0/delete_tag")]
- async fn delete(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
-     let _ = Mutation::delete_tag(&data.conn, model.tid).await.unwrap();
-
-     let json_response = JsonResponse {
-         code: 200,
-         err: "".to_owned(),
-         data: (),
-     };
-
-     Ok(HttpResponse::Ok()
-         .content_type("application/json")
-         .body(serde_json::to_string(&json_response).unwrap()))
- }
-
- #[get("/v1.0/tags")]
- async fn list(data: web::Data<AppState>) -> Result<HttpResponse, Error> {
-     let tags = Query::find_tag_infos(&data.conn).await.unwrap();
-
-     let mut result = HashMap::new();
-     result.insert("tags", tags);
-
-     let json_response = JsonResponse {
-         code: 200,
-         err: "".to_owned(),
-         data: result,
-     };
-
-     Ok(HttpResponse::Ok()
-         .content_type("application/json")
-         .body(serde_json::to_string(&json_response).unwrap()))
- }
src/main.rs CHANGED
@@ -5,9 +5,9 @@ mod errors;

  use std::env;
  use actix_files::Files;
- use actix_identity::{CookieIdentityPolicy, IdentityService, RequestIdentity};
+ use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity };
  use actix_session::CookieSession;
- use actix_web::{web, App, HttpServer, middleware, Error};
+ use actix_web::{ web, App, HttpServer, middleware, Error };
  use actix_web::cookie::time::Duration;
  use actix_web::dev::ServiceRequest;
  use actix_web::error::ErrorUnauthorized;
@@ -16,9 +16,9 @@ use listenfd::ListenFd;
  use minio::s3::client::Client;
  use minio::s3::creds::StaticProvider;
  use minio::s3::http::BaseUrl;
- use sea_orm::{Database, DatabaseConnection};
- use migration::{Migrator, MigratorTrait};
- use crate::errors::{AppError, UserError};
+ use sea_orm::{ Database, DatabaseConnection };
+ use migration::{ Migrator, MigratorTrait };
+ use crate::errors::{ AppError, UserError };

  #[derive(Debug, Clone)]
  struct AppState {
@@ -28,10 +28,10 @@ struct AppState {

  pub(crate) async fn validator(
      req: ServiceRequest,
-     credentials: BearerAuth,
+     credentials: BearerAuth
  ) -> Result<ServiceRequest, Error> {
      if let Some(token) = req.get_identity() {
-         println!("{}, {}",credentials.token(), token);
+         println!("{}, {}", credentials.token(), token);
          (credentials.token() == token)
              .then(|| req)
              .ok_or(ErrorUnauthorized(UserError::InvalidToken))
@@ -52,26 +52,25 @@ async fn main() -> Result<(), AppError> {
      let port = env::var("PORT").expect("PORT is not set in .env file");
      let server_url = format!("{host}:{port}");

-     let s3_base_url = env::var("S3_BASE_URL").expect("S3_BASE_URL is not set in .env file");
-     let s3_access_key = env::var("S3_ACCESS_KEY").expect("S3_ACCESS_KEY is not set in .env file");
-     let s3_secret_key = env::var("S3_SECRET_KEY").expect("S3_SECRET_KEY is not set in .env file");
+     let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file");
+     let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file");
+     let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file");
+     if s3_base_url.find("http") != Some(0) {
+         s3_base_url = format!("http://{}", s3_base_url);
+     }

      // establish connection to database and apply migrations
      // -> create post table if not exists
      let conn = Database::connect(&db_url).await.unwrap();
      Migrator::up(&conn, None).await.unwrap();

-     let static_provider = StaticProvider::new(
-         s3_access_key.as_str(),
-         s3_secret_key.as_str(),
-         None,
-     );
+     let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None);

      let s3_client = Client::new(
          s3_base_url.parse::<BaseUrl>()?,
          Some(Box::new(static_provider)),
          None,
-         None,
+         Some(true)
      )?;

      let state = AppState { conn, s3_client };
@@ -82,18 +81,20 @@ async fn main() -> Result<(), AppError> {
          App::new()
              .service(Files::new("/static", "./static"))
              .app_data(web::Data::new(state.clone()))
-             .wrap(IdentityService::new(
-                 CookieIdentityPolicy::new(&[0; 32])
-                     .name("auth-cookie")
-                     .login_deadline(Duration::seconds(120))
-                     .secure(false),
-             ))
+             .wrap(
+                 IdentityService::new(
+                     CookieIdentityPolicy::new(&[0; 32])
+                         .name("auth-cookie")
+                         .login_deadline(Duration::seconds(120))
+                         .secure(false)
+                 )
+             )
              .wrap(
                  CookieSession::signed(&[0; 32])
                      .name("session-cookie")
                      .secure(false)
                      // WARNING(alex): This uses the `time` crate, not `std::time`!
-                     .expires_in_time(Duration::seconds(60)),
+                     .expires_in_time(Duration::seconds(60))
              )
              .wrap(middleware::Logger::default())
              .configure(init)
@@ -137,4 +138,4 @@ fn init(cfg: &mut web::ServiceConfig) {
      cfg.service(api::user_info::login);
      cfg.service(api::user_info::register);
      cfg.service(api::user_info::setting);
- }
+ }
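
The server now reads its MinIO settings from the environment at startup. A sample of the new entries for the .env file (values are placeholders; the host/database settings used elsewhere in main.rs are unchanged and not shown):

```
# MinIO object storage (new in this change)
MINIO_HOST=127.0.0.1:9000
MINIO_USR=minioadmin
MINIO_PWD=minioadmin

# Existing setting referenced by main.rs
PORT=8000
```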