import json, re, sys, os, hashlib, copy, glob, util, time, random
from util.es_conn import HuEs, Postgres
from util import rmSpace, findMaxDt
from FlagEmbedding import FlagModel
from nlp import huchunk, huqie
import base64
from io import BytesIO
from elasticsearch_dsl import Q
from parser import (
    PdfParser,
    DocxParser,
    ExcelParser
)
from nlp.huchunk import (
    PdfChunker,
    DocxChunker,
    ExcelChunker,
    PptChunker,
    TextChunker
)

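# Shared service handles plus one chunker instance per supported document format.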
ES = HuEs("infiniflow")
BATCH_SIZE = 64
PG = Postgres("infiniflow", "docgpt")

PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()


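# Pick a chunker by file extension; returns None for unsupported types.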
def chuck_doc(name):
    suff = os.path.split(name)[-1].lower().split(".")[-1]
    if suff.find("pdf") >= 0: return PDF(name)
    if suff.find("doc") >= 0: return DOC(name)
    if suff.find("xlsx") >= 0: return EXC(name)
    if suff.find("ppt") >= 0: return PPT(name)
    if re.match(r"(txt|csv)", suff): return TextChunker(name)


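# Fetch up to 1000 unprocessed documents updated since `tm`, sharded across
# workers via MOD(uid, comm) == mod. Returns the rows and the newest timestamp seen.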
def collect(comm, mod, tm):
    sql = f"""
    select 
    did,
    uid,
    doc_name,
    size,
    location,
    updated_at,
    is_deleted
    from docinfo
    where 
    updated_at >= '{tm}' 
    and kb_progress = 0
    and type = 'doc'
    and MOD(uid, {comm}) = {mod}
    order by updated_at asc
    limit 1000
    """
    df = PG.select(sql)
    df = df.fillna("")
    mtm = str(df["updated_at"].max())[:19]
    print("TOTAL:", len(df), "To: ", mtm)
    return df, mtm


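# Record per-document progress (0..1) in Postgres; -1 marks a failure, with
# `msg` stored as the progress message.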
def set_progress(did, prog, msg=""):
    sql = f"""
    update docinfo set kb_progress={prog}, kb_progress_msg='{msg}' where did={did}
    """
    PG.update(sql)


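# Turn one docinfo row into a list of Elasticsearch documents: tokenize the
# title, slice the file into text/table chunks, and base64-encode chunk images.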
def build(row):
    if row["size"] > 256000000:
        set_progress(row["did"], -1, "File size exceeds the limit (<= 256MB).")
        return []
    doc = {
        "doc_id": row["did"],
        "title_tks": huqie.qie(os.path.split(row["location"])[-1]),
        "updated_at": row["updated_at"]
    }
    random.seed(time.time())
    set_progress(row["did"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
    obj = chuck_doc(row["location"])
    if not obj:
        set_progress(row["did"], -1, "Unsupported file type.")
        return []

    set_progress(row["did"], random.randint(20, 60)/100.)

    docs = []
    md5 = hashlib.md5()
    for txt, img in obj.text_chunks:
        d = copy.deepcopy(doc)
        md5.update((txt + str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["content_ltks"] = huqie.qie(txt)
        d["docnm_kwd"] = rmSpace(d["title_tks"])
        if not img:
            docs.append(d)
            continue
        # A fresh buffer per image; reusing one buffer would concatenate images.
        output_buffer = BytesIO()
        img.save(output_buffer, format='JPEG')
        d["img_bin"] = base64.b64encode(output_buffer.getvalue())
        docs.append(d)

    for arr, img in obj.table_chunks:
        for i, txt in enumerate(arr):
            d = copy.deepcopy(doc)
            d["content_ltks"] = huqie.qie(txt)
            md5.update((txt + str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            if not img:
                docs.append(d)
                continue
            output_buffer = BytesIO()
            img.save(output_buffer, format='JPEG')
            d["img_bin"] = base64.b64encode(output_buffer.getvalue())
            docs.append(d)
    set_progress(row["did"], random.randint(60, 70)/100., "Finished slicing. Start embedding the content.")

    return docs


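# One Elasticsearch index per user, created from res/mapping.json on first use.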
def index_name(uid): return f"docgpt_{uid}"

def init_kb(row):
    idxnm = index_name(row["uid"])
    if ES.indexExist(idxnm): return
    return ES.createIdx(idxnm, json.load(open("res/mapping.json", "r")))


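# Encode title and content tokens for every chunk and attach the mixed vector.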
model = None
def embedding(docs):
    global model
    if model is None:
        # Lazy-load the embedding model on first use. The checkpoint name is an
        # assumption; substitute whatever FlagEmbedding model the deployment uses.
        model = FlagModel("BAAI/bge-large-zh", use_fp16=True)
    tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
    cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
    # Weighted mix of title and content vectors is stored as the chunk's query vector.
    vects = 0.1 * tts + 0.9 * cnts
    assert len(vects) == len(docs)
    for i, d in enumerate(docs):
        d["q_vec"] = vects[i].tolist()
    for d in docs:
        set_progress(d["doc_id"], random.randint(70, 95)/100.,
                     "Finished embedding! Start to build index!")


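# Worker loop: resume from the last processed timestamp, collect pending
# documents for this shard, then chunk, embed, and bulk-index each one.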
def main(comm, mod):
    tm_fnm = f"res/{comm}-{mod}.tm"
    tmf = open(tm_fnm, "a+")
    tm = findMaxDt(tm_fnm)
    rows, tm = collect(comm, mod, tm)
    for _, r in rows.iterrows():
        if r["is_deleted"]:
            ES.deleteByQuery(Q("term", doc_id=r["did"]), index_name(r["uid"]))
            continue

        cks = build(r)
        ## TODO: exception handler
        ## set_progress(r["did"], -1, "ERROR: ")
        if not cks:
            tmf.write(str(r["updated_at"]) + "\n")
            continue
        embedding(cks)
        init_kb(r)
        ES.bulk(cks, index_name(r["uid"]))
        tmf.write(str(r["updated_at"]) + "\n")
    tmf.close()


if __name__ == "__main__":
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    # Shard work across MPI ranks: rank i handles uids where MOD(uid, world_size) == i.
    main(comm.Get_size(), comm.Get_rank())