File size: 4,798 Bytes
249b27c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import json, re, sys, os, hashlib, copy, glob, util, time, random
from util.es_conn import HuEs, Postgres
from util import rmSpace, findMaxDt
from FlagEmbedding import FlagModel
from nlp import huchunk, huqie
import base64, hashlib
from io import BytesIO
from elasticsearch_dsl import Q
from parser import (
PdfParser,
DocxParser,
ExcelParser
)
from nlp.huchunk import (
PdfChunker,
DocxChunker,
ExcelChunker,
PptChunker,
TextChunker
)
# Search-index client (util.es_conn.HuEs); used below for index existence
# checks, index creation, bulk writes and delete-by-query.
ES = HuEs("infiniflow")
# NOTE(review): BATCH_SIZE is not referenced anywhere in the visible code.
BATCH_SIZE = 64
# Database helper for the docinfo table (select / update with raw SQL strings).
PG = Postgres("infiniflow", "docgpt")
# Pre-built chunker instances, one per supported file family; chuck_doc()
# picks one by file extension. Text files use TextChunker, which has no
# pre-built instance here.
PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
def chuck_doc(name):
    """Chunk the file at path *name* with the chunker matching its extension.

    Dispatch is on the lower-cased text after the last "." of the file name.
    Returns whatever the selected chunker returns, or None when the
    extension is not supported.
    """
    # Keep the path in `name`: the chunkers need the file, not its extension.
    suff = os.path.split(name)[-1].lower().split(".")[-1]
    if suff.find("pdf") >= 0:
        return PDF(name)
    if suff.find("doc") >= 0:  # also matches "docx"
        return DOC(name)
    if suff.find("xlsx") >= 0:
        return EXC(name)
    if suff.find("ppt") >= 0:  # also matches "pptx"
        return PPT(name)
    if re.match(r"(txt|csv)", suff):
        # TextChunker is a class, unlike the pre-built instances above, so it
        # must be instantiated before being called.
        # NOTE(review): assumes a no-argument constructor, like PptChunker().
        return TextChunker()(name)
    return None
def collect(comm, mod, tm):
    """Fetch up to 1000 pending documents assigned to this worker.

    :param comm: number of workers; used as the MOD divisor on ``uid``.
    :param mod: the remainder selecting this worker's share of users.
    :param tm: timestamp string; only rows with ``updated_at >= tm`` are read.
    :returns: ``(df, mtm)`` -- the rows ordered by ``updated_at`` (NaNs
        replaced by ""), and the latest ``updated_at`` among them truncated
        to 19 characters, or *tm* unchanged when no row matched.
    """
    # `size` and `is_deleted` are read by build() and main(), so they must be
    # part of the projection or those lookups raise KeyError.
    sql = f"""
    select
    did,
    uid,
    doc_name,
    location,
    size,
    is_deleted,
    updated_at
    from docinfo
    where
    updated_at >= '{tm}'
    and kb_progress = 0
    and type = 'doc'
    and MOD(uid, {comm}) = {mod}
    order by updated_at asc
    limit 1000
    """
    df = PG.select(sql)
    df = df.fillna("")
    # max() over an empty column is NaN, which would stringify to "nan".
    mtm = tm if df.empty else str(df["updated_at"].max())[:19]
    print("TOTAL:", len(df), "To: ", mtm)
    return df, mtm
def set_progress(did, prog, msg="Processing..."):
    """Record parsing progress and a status message for document *did*.

    :param did: document id, interpolated unquoted into the SQL.
    :param prog: progress value; callers in this file pass a fraction in
        [0, 1) while working, or -1 on failure.
    :param msg: status text stored in ``kb_progress_msg``; optional so
        callers can report bare progress.
    """
    # PG.update takes a raw SQL string: double single quotes so a message
    # such as "can't parse" does not terminate the literal early.
    msg = str(msg).replace("'", "''")
    sql = f"""
    update docinfo set kb_progress={prog}, kb_progress_msg='{msg}' where did={did}
    """
    PG.update(sql)
def _make_chunk(doc, txt, img, md5):
    """Return a copy of *doc* holding chunk text *txt* and optional image *img*.

    *md5* is the digest shared by all chunks of one document. It is updated
    in place, so each chunk id depends on the text of every earlier chunk.
    """
    d = copy.deepcopy(doc)
    md5.update((txt + str(d["doc_id"])).encode("utf-8"))
    d["_id"] = md5.hexdigest()
    d["content_ltks"] = huqie.qie(txt)
    if img:
        # A fresh buffer per image: a shared, never-reset buffer would make
        # each chunk carry the JPEG bytes of every image saved before it.
        buf = BytesIO()
        img.save(buf, format='JPEG')
        d["img_bin"] = base64.b64encode(buf.getvalue())
    return d


def build(row):
    """Slice the document described by *row* into index-ready chunk dicts.

    Reports progress via set_progress() along the way. Returns an empty list,
    after recording progress -1, when the file is too large or its type is
    not supported.
    """
    if row["size"] > 256000000:
        set_progress(row["did"], -1, "File size exceeds( <= 256Mb )")
        return []
    title_tks = huqie.qie(os.path.split(row["location"])[-1])
    doc = {
        "doc_id": row["did"],
        "title_tks": title_tks,
        # Keyword form of the file name, stored on every chunk (text and table).
        "docnm_kwd": rmSpace(title_tks),
        "updated_at": row["updated_at"]
    }
    random.seed(time.time())
    set_progress(row["did"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
    obj = chuck_doc(row["location"])
    if not obj:
        set_progress(row["did"], -1, "Unsuported file type.")
        return []
    set_progress(row["did"], random.randint(20, 60)/100., "Processing...")

    docs = []
    md5 = hashlib.md5()
    # A chunker may leave either list as None (e.g. nothing extracted).
    for txt, img in obj.text_chunks or []:
        docs.append(_make_chunk(doc, txt, img, md5))
    for arr, img in obj.table_chunks or []:
        # Each table row becomes its own chunk; all rows share the table image.
        for txt in arr:
            docs.append(_make_chunk(doc, txt, img, md5))
    set_progress(row["did"], random.randint(60, 70)/100., "Finished slicing. Start to embedding the content.")
    return docs
def index_name(uid):
    """Return the name of the search index that holds user *uid*'s chunks."""
    return "docgpt_{}".format(uid)
def init_kb(row):
    """Create the search index for ``row["uid"]`` unless it already exists.

    The index mapping is loaded from res/mapping.json. Returns None when the
    index already exists, otherwise whatever ES.createIdx returns.
    """
    idxnm = index_name(row["uid"])
    if ES.indexExist(idxnm):
        return None
    # Use a context manager so the mapping file is closed instead of leaked;
    # JSON text is UTF-8, independent of the process locale.
    with open("res/mapping.json", "r", encoding="utf-8") as f:
        mapping = json.load(f)
    return ES.createIdx(idxnm, mapping)
# Embedding model used by embedding(). It must be assigned before embedding()
# is called with any chunks; presumably a FlagModel instance (imported above).
model = None


def embedding(docs):
    """Attach a "q_vec" list to every chunk dict in *docs*, in place.

    The vector is ``0.1 * title embedding + 0.9 * content embedding``, both
    computed with the module-level ``model``. Afterwards progress is advanced
    once per distinct document among the chunks.

    :raises RuntimeError: if chunks are given but ``model`` was never set.
    """
    if not docs:
        # build() returns [] for rejected files; there is nothing to encode.
        return None
    if model is None:
        raise RuntimeError("embedding model is not initialised; assign the module-level `model` first")
    tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
    cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
    vects = 0.1 * tts + 0.9 * cnts
    assert len(vects) == len(docs)
    for d, vec in zip(docs, vects):
        d["q_vec"] = vec.tolist()
    # One UPDATE per document rather than one per chunk (ordered de-dup).
    for did in dict.fromkeys(d["doc_id"] for d in docs):
        set_progress(did, random.randint(70, 95)/100.,
                     "Finished embedding! Start to build index!")
    return None
def main(comm, mod):
    """Index one batch of pending documents for worker *mod* of *comm* workers.

    Rows come from collect(comm, mod, ...). Deleted documents have their
    chunks removed from the user's index. Every other row is chunked,
    embedded and bulk-indexed, and its ``updated_at`` is appended to the
    checkpoint file res/{comm}-{mod}.tm. A failure on one row marks that
    document with progress -1 instead of aborting the whole batch.
    """
    import traceback

    tm_fnm = f"res/{comm}-{mod}.tm"
    # Opened before findMaxDt, presumably so the checkpoint file exists on
    # the first run; the context manager closes it even if a row fails.
    with open(tm_fnm, "a+") as tmf:
        tm = findMaxDt(tm_fnm)
        rows, _ = collect(comm, mod, tm)
        # Iterating a DataFrame directly yields column labels, not rows.
        for _, r in rows.iterrows():
            if r["is_deleted"]:
                # Chunks are indexed with their document id under "doc_id".
                ES.deleteByQuery(Q("term", doc_id=r["did"]), index_name(r["uid"]))
                continue
            try:
                cks = build(r)
                # build() returns [] for rejected files: nothing to index.
                if cks:
                    embedding(cks)
                    init_kb(r)
                    ES.bulk(cks, index_name(r["uid"]))
            except Exception as e:
                # Per-row boundary: log, mark the document failed, and move on.
                traceback.print_exc()
                # Exception class names contain no quotes, so they are safe to
                # embed in set_progress()'s string-built SQL.
                set_progress(r["did"], -1, "ERROR: " + type(e).__name__)
            tmf.write(str(r["updated_at"]) + "\n")
if __name__ == "__main__":
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    # main() interpolates its first argument into SQL (MOD(uid, comm)) and
    # into the checkpoint file name, so pass the worker count rather than
    # the communicator object itself.
    main(comm.Get_size(), comm.Get_rank())
|