dmibor committed
Commit 61cffc1 · 1 Parent(s): cd39c4a

Main search and question-answering code

Files changed (3)
  1. __pycache__/search_core.cpython-312.pyc +0 -0
  2. app.py +43 -71
  3. search_core.py +632 -0
__pycache__/search_core.cpython-312.pyc ADDED
Binary file (28.2 kB).
app.py CHANGED
@@ -4,87 +4,59 @@ import torch
  import numpy as np
  import os
  import glob
- from pathlib import Path
- from PIL import Image
- import chromadb
- import boto3
- import botocore
- from io import BytesIO
+ from search_core import make_search_index, make_search_index_qa, search_query_all, answer_question
+ import pandas as pd
+ import json

+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model= None
+ #try:
+     #model, preprocess = clip.load("ViT-B/32", device=device)
+ #except:
+     #st.write("Exception loading model")
+
+
+ # Setting page layout
  st.set_page_config(
-     page_title="Поиск изображений Google Open Images по текстовому запросу",
+     page_title="Поиск по публикации/вопросы-ответы",
      page_icon="🤖",
      layout="wide",
      initial_sidebar_state="expanded"
  )
-
- BUCKET_NAME = 'open-images-dataset'
- DOWNLOAD_FOLDER='ds_download'
- SPLIT='validation'
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model= None
- preprocess = None
- try:
-     model, preprocess = clip.load("ViT-B/32", device=device)
- except:
-     st.write("Exception loading model")
-
- collection_path = "image_embeddings_collection.chroma"
- chroma_client = chromadb.PersistentClient(path=collection_path)
- image_collection= None
- try:
-     image_collection = chroma_client.get_or_create_collection("image" , metadata={"hnsw:space": "cosine"})
- except:
-     st.write("Exception loading collection")
-
- #if os.path.isdir(DOWNLOAD_FOLDER)==False:
-     #os.mkdir(DOWNLOAD_FOLDER)
-
- num_embeddings = image_collection.count()
-
  # Main page heading
- st.title("Поиск изображений Google Open Images по текстовому запросу")
+ #st.title("Поиск изображений Google Open Images по текстовому запросу")

+ searchStarted= False
+ qaStarted= False
  # Sidebar
- st.sidebar.header("Настройки поиска")
+ with st.sidebar:
+     tab1, tab2 = st.sidebar.tabs(["Поиск по публикации", "Вопросы-ответы"])
+
+     with tab1:
+         st.header("Поиск по публикации")
+         search_input = st.text_input(label='Введите запрос:', value='аккумуляторная батарея')
+         searchStarted = st.button('Искать')
+
+     with tab2:
+         st.header("Вопросы-ответы")
+         qa_input = st.text_input(label='Введите вопрос:', value='Какой ресурс до первого ремонта?')
+         qaStarted = st.button('Узнать ответ')
+
+ if searchStarted==True:
+     st.header("Результаты поиска")
+     search_result= search_query_all(search_input)
+     df = pd.DataFrame(pd.json_normalize(search_result))
+     df.columns=['Параграф модуля данных', 'Код МД']
+     st.table(df)
+ if qaStarted==True:
+     st.header("Ответ")
+     mode_string = 'strict'
+     model_string = '1'
+     answer= answer_question(qa_input, mode_string, model_string)
+     df = pd.DataFrame(pd.json_normalize(answer))
+     df.columns=['Уверенность', 'Ответ', 'Код МД']
+     st.table(df)


- st.sidebar.write(f"Количество изображений в БД: {num_embeddings}")
- text_input = st.sidebar.text_input(label='Введите запрос:', value='kite in the sky')
- search_files_cnt = int(st.sidebar.slider(label="Количество изображений", min_value=1, max_value=10, value=2))
- searchStarted = st.sidebar.button('Искать')

- col1, col2 = st.columns(2)
- if searchStarted==True:
-     text_embedding = clip.tokenize(text_input).to(device)
-     text_features = model.encode_text(text_embedding).detach().cpu().numpy()
-     result = image_collection.query(text_features, n_results=search_files_cnt)
-     bucket = boto3.resource('s3',
-                             config=botocore.config.Config(
-                                 signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)
-     cnt=0
-     for i in result['metadatas'][0]:
-         try:
-             filename= Path(i['name'])
-             image_id= filename.with_suffix('')
-             filepath= i['path']
-             #down_file_path= os.path.join(DOWNLOAD_FOLDER, f'{image_id}.jpg')
-             #bucket.download_file(f'{SPLIT}/{image_id}.jpg', down_file_path)
-             #img = Image.open(down_file_path)
-             object_key= f'{SPLIT}/{image_id}.jpg'
-             image_data = BytesIO()
-             bucket.download_fileobj(object_key, image_data)
-             image_data.seek(0)
-             img = Image.open(image_data)
-             col_ref= col1
-             if ((cnt+1) % 2) == 0:
-                 col_ref= col2
-             with col_ref:
-                 st.write('image_id:', image_id)
-                 st.write('distance:', result['distances'][0][cnt])
-                 st.image(img, use_column_width=True)
-         except botocore.exceptions.ClientError as exception:
-             st.write(str(exception))
-         cnt=cnt+1

 
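For context, search_query_all and answer_question return plain lists of dicts, which app.py flattens with pd.json_normalize before handing them to st.table. A minimal sketch of the shapes involved, using placeholder values rather than real data from the publication:

import pandas as pd

# Placeholder return values, shaped like the dicts built in search_core.py.
search_result = [{"text": "<data module paragraph>", "DMC": "DMC-<data module code>"}]
answer = [{"score": 0.87, "answer": "<extracted answer>", "DMC": "DMC-<data module code>"}]

# The same normalization app.py applies before st.table().
search_table = pd.DataFrame(pd.json_normalize(search_result))
search_table.columns = ["Параграф модуля данных", "Код МД"]
answer_table = pd.DataFrame(pd.json_normalize(answer))
answer_table.columns = ["Уверенность", "Ответ", "Код МД"]
print(search_table, answer_table, sep="\n")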
search_core.py ADDED
@@ -0,0 +1,632 @@
+ import os
+ import lxml.etree as ET
+ import pandas as pd
+
+ import numpy as np
+ from tensorflow.keras.preprocessing.text import Tokenizer
+ import spacy
+ from tqdm import tqdm
+ import configparser
+ import pickle
+ import re
+ from transformers import pipeline
+ import torch
+ from tqdm import tqdm
+
+ import Stemmer
+ #stemmer= Stemmer.Stemmer('ru')#russian
+ stemmer= Stemmer.Stemmer('en')#english
+ import json
+
+ #exclude_tags=['graphic', 'figure']
+ #include_tags=['note', 'notePara', 'para']
+
+ exclude_tags=['graphic']
+ include_tags=['note', 'notePara', 'para', 'title', 'warningAndCautionPara', 'techName', 'infoName']
+ add_colon_tags=['title', 'techName']
+ make_lower_parent_tags=['listItemDefinition']
+ PARSE_PATHS=['//dmodule/content[last()]/procedure[last()]/preliminaryRqmts[last()]',
+              '//dmodule/content[last()]/procedure[last()]/mainProcedure[last()]',
+              '//dmodule/content[last()]/description[last()]',
+              '//dmodule/content[last()]/crew[last()]/crewRefCard[last()]/crewDrill[last()]',
+              '//dmodule/identAndStatusSection[last()]/dmAddress[last()]/dmAddressItems[last()]/dmTitle[last()]']
+
+ PERSCENTAGE_IN_RATIO=0.5
+ THRESHOLD=0.1
+
+ global nlp, tokenizer_search, tokenizer_qa, device
+ global search_df, qa_df, SEARCH_DATA
+ global index_data_loaded, qa_index_data_loaded, qa_model_initialized
+ global qa_model, qa_model_num
+
+ PUBLICATION_DEMO_RU_PATH="publications/Demo publication in Russian"
+ PUBLICATION_DEMO_EN_PATH="publications/Bike Data Set for Release number 5.0"
+ PUBLICATION_PATH=PUBLICATION_DEMO_RU_PATH
+ TOKENIZER_SEARCH_FILENAME='tokenizer_search.pickle'
+ TOKENIZER_QA_FILENAME='tokenizer_qa.pickle'
+ INDEX_FOLDER= PUBLICATION_PATH+ os.sep+ "index"
+ #print('INDEX_FOLDER:', INDEX_FOLDER)
+ TOKENIZER_SEARCH_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_SEARCH_FILENAME
+ TOKENIZER_QA_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_QA_FILENAME
+ #print('TOKENIZER_SEARCH_PATH:', TOKENIZER_SEARCH_PATH)
+ PUBLICATION_LANGUAGE="ru"
+
+ nlp=None
+ search_df=None
+ qa_df=None
+ index_data_loaded=False
+ qa_index_data_loaded=False
+ SEARCH_DATA= None
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ qa_model_initialized=False
+
+
+ def get_xpath_one(tree, xpath):
+     res = tree.xpath(xpath)
+     if res:
+         return res[0]
+
+ def get_dmc(doc):
+     dmc=""
+     node= get_xpath_one(doc, '//dmCode')
+     dmc='DMC-'+'-'.join([node.get('modelIdentCode'), \
+                          node.get('itemLocationCode'), \
+                          node.get('systemCode'), \
+                          node.get('subSystemCode')+node.get('subSubSystemCode'), \
+                          node.get('assyCode'),\
+                          node.get('disassyCode')+node.get('disassyCodeVariant'),\
+                          node.get('infoCode')+node.get('infoCodeVariant'),\
+                          node.get('systemDiffCode')])
+
+     #print('dmc: ', dmc)
+     return dmc
+
+ def is_float(string):
+     if string.replace(".", "").replace(",", "").replace("+", "").replace("-", "").isnumeric():
+         return True
+     else:
+         return False
+
+
+ def stringify_children(node, texts, pis, excludeDigits=True):
+     s = node.text
+     if (s != None) and (s.isspace()==False):
+         if excludeDigits:
+             if is_float(s)==False:
+                 texts.add(s)
+         else:
+             texts.add(s)
+     for child in node:
+         if child.tag not in exclude_tags:
+             if child not in pis:
+                 stringify_children(child, texts, pis)
+     return
+
+ def stringify_children_incl(node, texts, pis, make_lower=False):
+     ET.strip_tags(node, 'internalRef')
+     ET.strip_tags(node, 'emphasis')
+     s = node.text
+     if s and make_lower==True:
+         s= s.lower()
+     if s and node.tag in add_colon_tags:
+         s=s+':'
+     #print('s', s)
+     clear_s= clear_text(s)
+     if (s != None) and (s.isspace()==False) and (clear_s!='') and (clear_s):
+         print('s:', s)
+         print('clear_text(s):', clear_text(s))
+         texts.append(s)
+
+     for child in node:
+         #print('child.tag:', child.tag)
+         if (len(child.getchildren())>0) or (child.tag in include_tags):
+             if (child not in pis) and (child.tag not in exclude_tags):
+                 make_lower=False
+                 if node.tag in make_lower_parent_tags:
+                     make_lower=True
+                 stringify_children_incl(child, texts, pis, make_lower)
+     return
+
+ def clear_text(text):
+     #print('clear_text!')
+     clean_text = re.sub(r'(?:(?!\u0301)[\W\d_])+', ' ', str(text).lower())
+     return clean_text
+
+ def lemmatize_and_stemm(df_r):
+     global nlp
+     #print('lemmatize_and_stemm!')
+     disabled_pipes = [ "parser", "ner"]
+     if PUBLICATION_LANGUAGE=="ru":
+         nlp = spacy.load('ru_core_news_sm', disable=disabled_pipes)#english - en_core_web_sm
+     else:
+         nlp = spacy.load('en_core_web_sm', disable=disabled_pipes)#russian - ru_core_news_sm
+
+     lemm_texts = []
+     stem_texts=[]
+
+     for doc in tqdm(nlp.pipe(df_r['lemm_text'].values, disable = disabled_pipes), total=df_r.shape[0]):
+         lemm_text = " ".join([i.lemma_ for i in doc])
+         lemm_texts.append(lemm_text)
+         stem_text = " ".join([stemmer.stemWord(i.text) for i in doc])
+         stem_texts.append(stem_text)
+
+     df_r['lemm_text']= lemm_texts
+     df_r['stem_text']= stem_texts
+     df_r=df_r.drop_duplicates()
+     #print('lemmatization and stemming success!')
+     return
+
+ def tokenize_text(df_r, save_filename):
+     #global tokenizer_search
+     #print('tokenize_text!')
+
+     #try:
+         #with open('tokenizer.pickle', 'rb') as handle:
+             #tokenizer = pickle.load(handle)
+         #print('tokenizer loaded from file')
+     #except Exception as e:
+     tokenizer = Tokenizer(oov_token='<oov>')
+     print('tokenizer created')
+
+     texts= pd.concat([df_r['lemm_text'],df_r['stem_text']])
+     tokenizer.fit_on_texts(texts)
+     total_words = len(tokenizer.word_index) + 1
+     print("Total number of words: ", total_words)
+     with open(save_filename, 'wb') as handle:
+         pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
+     return tokenizer
+
+ def make_final_index(df_r, tokenizer, index_filename='search_index.csv', includePlainText=True):
+     print('make_final_index!')
+     tokens=[]
+     labels=[]
+     dmcs=[]
+     texts=[]
+     for index, row in tqdm(df_r.iterrows()):
+         #print('row:', row)
+         text= row['text']
+         lemm_token= tokenizer.texts_to_sequences([row['lemm_text']])[0]
+         stem_token= tokenizer.texts_to_sequences([row['stem_text']])[0]
+         dmc= row['DMC']
+         #print(str(row['label_enc'])+':'+dmc)
+         tokens.append(lemm_token)
+         labels.append(row['label_enc'])
+         dmcs.append(dmc)
+         texts.append(text)
+         tokens.append(stem_token)
+         labels.append(row['label_enc'])
+         dmcs.append(dmc)
+         texts.append(text)
+     columns= ['tokens', 'labels', 'DMC']
+     data= {'tokens': tokens, 'labels': labels, 'DMC': dmcs}
+     if includePlainText==True:
+         columns= ['tokens', 'labels', 'text', 'DMC']
+         data= {'tokens': tokens, 'labels': labels, 'text': texts, 'DMC': dmcs}
+     s_df= pd.DataFrame(columns=columns, data= data)
+     s_df= s_df.loc[s_df.astype(str).drop_duplicates().index]
+     print('final index info:')
+     print(s_df.info())
+     s_df.to_csv(index_filename, sep=';', index=False)
+     #print(f'results saved to {index_filename}')
+     return s_df
+
+ def make_search_index(path):
+     global nlp, tokenizer_search, search_df, index_data_loaded
+     #print('make_search_index!')
+     directory= path.replace('"', '')
+     #print(f'path: {directory}')
+     df_r= pd.DataFrame(columns=['text'])
+
+     for file in os.listdir(directory):
+         filename = file#os.fsdecode(file)
+         if 'PMC' in filename:
+             continue
+         #print('filename: ', filename)
+         if filename.lower().endswith(".xml")==False:
+             continue
+         filepath= directory+ os.sep+ filename
+         print('filepath:', filepath)
+
+         doc= ET.parse(filepath)
+         dmc= get_dmc(doc)
+
+         file_texts=set()
+         pis = doc.xpath("//processing-instruction()")
+         for node in doc.xpath('//dmodule'):
+             stringify_children(node, file_texts, pis)
+
+         #print('file_texts:', file_texts)
+         df= pd.DataFrame(columns=['text'], data= file_texts)
+         df['DMC']= dmc
+         df_r= pd.concat([df_r, df], ignore_index=True)
+     print('parsing results:')
+     print(df_r.info())
+     #PARSING_INDEX_FILENAME='strings_with_DMC.csv'
+     #print(f'parsing results saved to: {PARSING_INDEX_FILENAME}')
+     #df_r.to_csv(PARSING_INDEX_FILENAME, index=False, sep = ';')
+
+     df_r['lemm_text']=df_r['text'].apply(clear_text)
+     lemmatize_and_stemm(df_r)
+     df_r= df_r.reset_index(drop=True)
+     df_r['label_enc']= df_r.index
+     tokenizer_search= tokenize_text(df_r, TOKENIZER_SEARCH_PATH)
+     #print('tokenizer before make_final_index:', tokenizer_search)
+     search_df= make_final_index(df_r, tokenizer_search)
+     index_data_loaded= True
+     return len(search_df)
+
+ def make_search_index_qa(path):
+     global nlp, tokenizer_qa, qa_df, qa_index_data_loaded
+     #print('make_search_index_qa!')
+     directory= path.replace('"', '')
+     #print(f'path: {directory}')
+     df_r= pd.DataFrame(columns=['text'])
+
+     for file in os.listdir(directory):
+         filename = file#os.fsdecode(file)
+         if 'PMC' in filename:
+             continue
+         #print('filename: ', filename)
+         if filename.lower().endswith(".xml")==False:
+             continue
+         filepath= directory+ os.sep+ filename
+         #print('filepath:', filepath)
+
+         doc= ET.parse(filepath)
+         dmc= get_dmc(doc)
+
+         paths= PARSE_PATHS
+
+         pis = doc.xpath("//processing-instruction()")
+         for pi in pis:
+             if pi.getparent()!=None:
+                 ET.strip_tags(pi.getparent(), pi.tag)
+
+         cntr=1
+         for expr in paths:
+             try:
+                 x_path_result = doc.xpath(expr)
+             except ET.XPathEvalError:
+                 continue
+
+             if not x_path_result:
+                 continue
+             file_texts=[]
+             dmc_with_chapter= f'{dmc}({cntr})'
+             for node in x_path_result:#doc.xpath(expr):
+                 stringify_children_incl(node, file_texts, pis)
+             cntr=cntr+1
+             #print('file_texts:',file_texts)
+             #print('file_texts len:',len(file_texts))
+             if len(file_texts)==0:
+                 continue
+             concat_texts=[' \n '.join(file_texts)]
+             #print('file_texts:', file_texts)
+
+             #df= pd.DataFrame(columns=['text'], data= file_texts)
+             df= pd.DataFrame(columns=['text'], data= concat_texts)
+             df['DMC']= dmc_with_chapter
+             df_r= pd.concat([df_r, df], ignore_index=True)
+     #print('parsing results:')
+     #print(df_r.info())
+     #PARSING_INDEX_FILENAME='strings_with_DMC.csv'
+     #print('parsing results saved to: {PARSING_INDEX_FILENAME}')
+     #df_r.to_csv(PARSING_INDEX_FILENAME, index=False, sep = ';')
+
+     df_r['lemm_text']=df_r['text'].apply(clear_text)
+     lemmatize_and_stemm(df_r)
+     df_r= df_r.reset_index(drop=True)
+     df_r['label_enc']= df_r.index
+     tokenizer_qa= tokenize_text(df_r, TOKENIZER_QA_PATH)
+     qa_df= make_final_index(df_r, tokenizer_qa, index_filename='qa_index.csv')
+     qa_index_data_loaded= True
+     return len(qa_df)
+
+ def convert2list(string):
+     x = json.loads(string)
+     lst=[]
+     for n in x:
+         #print(x)
+         lst.append(int(n))
+     return lst
+
+ def load_index_data():
+     global nlp, tokenizer_search, search_df, index_data_loaded
+     print('load_index_data!')
+     #spacy
+     disabled_pipes = [ "parser", "ner"]
+     if PUBLICATION_LANGUAGE=="ru":
+         nlp = spacy.load('ru_core_news_sm', disable=disabled_pipes)#english - en_core_web_sm
+     else:
+         nlp = spacy.load('en_core_web_sm', disable=disabled_pipes)#russian - ru_core_news_sm
+     #print('spacy loaded:', nlp)
+     #tokenizer
+     with open(TOKENIZER_SEARCH_PATH, 'rb') as handle:
+         tokenizer_search = pickle.load(handle)
+     #print('tokenizer loaded:', tokenizer)
+     #index
+     search_index_path= INDEX_FOLDER+os.sep+'search_index.csv'
+     search_df= pd.read_csv(search_index_path, sep=';')
+     print('index file loaded:', search_df.info())
+     search_df['tokens']= search_df['tokens'].apply(convert2list)
+     index_data_loaded= True
+     return nlp, tokenizer_search, search_df
+
+ def load_index_data_qa():
+     global nlp, tokenizer_qa, qa_df, qa_index_data_loaded
+     #print('load_index_data_qa!')
+     #spacy
+     disabled_pipes = [ "parser", "ner"]
+     if PUBLICATION_LANGUAGE=="ru":
+         nlp = spacy.load('ru_core_news_sm', disable=disabled_pipes)#english - en_core_web_sm
+     else:
+         nlp = spacy.load('en_core_web_sm', disable=disabled_pipes)#russian - ru_core_news_sm
+     print('spacy loaded:', nlp)
+     #tokenizer
+     with open(TOKENIZER_QA_PATH, 'rb') as handle:
+         tokenizer_qa = pickle.load(handle)
+     #print('tokenizer loaded:', tokenizer_qa)
+     #index
+     qa_index_path= INDEX_FOLDER+os.sep+'qa_index.csv'
+     qa_df= pd.read_csv(qa_index_path, sep=';')
+     #print('index qa file loaded:', qa_df.info())
+     qa_df['tokens']= qa_df['tokens'].apply(convert2list)
+     qa_index_data_loaded= True
+     return nlp, tokenizer_qa, qa_df
+
+ def customIsIn(x , tokens):
+     result= False
+     cnt_in=0
+     for val in x:
+         if val in tokens:
+             cnt_in+=1
+     PERSCENTAGE_IN= cnt_in/len(tokens)
+     if PERSCENTAGE_IN>=PERSCENTAGE_IN_RATIO:
+         return True
+     return result
+
+ def get_lemmed_stemmed_text(text):
+     global nlp
+     #print('nlp loaded or not:', nlp)
+     if PUBLICATION_LANGUAGE=="ru":
+         spacy_stopwords = spacy.lang.ru.stop_words.STOP_WORDS #russian
+     else:
+         spacy_stopwords = nlp.Defaults.stop_words #english
+     #print('spacy_stopwords:', spacy_stopwords)
+     doc = nlp(clear_text(text))
+     # Remove stop words
+     doc_cleared = [token for token in doc if not token.is_stop]
+     #print('doc_cleared:', doc_cleared)
+     lemm_text = " ".join([i.lemma_ for i in doc_cleared if not i.lemma_ in spacy_stopwords])
+     print(f'lemm_text: {lemm_text}')
+     stem_text = " ".join([stemmer.stemWord(i.text) for i in doc_cleared if not stemmer.stemWord(i.text) in spacy_stopwords])
+     print(f'stem_text: {stem_text}')
+     return lemm_text, stem_text
+
+ def search_query_any(query, df=None, tokenizer=None):
+     global SEARCH_DATA, search_df, index_data_loaded
+     print('search_query_any!')
+     print(f'query: {query}')
+     if index_data_loaded==False:
+         load_index_data()
+     SEARCH_DATA= df
+     if df is None:
+         if index_data_loaded==False:
+             load_index_data()
+         SEARCH_DATA=search_df
+     lemm_text, stem_text= get_lemmed_stemmed_text(query)
+     if tokenizer==None:
+         tokenizer= tokenizer_search
+     token_list = tokenizer.texts_to_sequences([lemm_text])[0]
+     #print(f'token_list: {token_list}')
+     token_list_stem = tokenizer.texts_to_sequences([stem_text])[0]
+     #print(f'token_list stem: {token_list_stem}')
+
+     mask1 = SEARCH_DATA.tokens.apply(lambda x: customIsIn(x, token_list))
+     indexes1= SEARCH_DATA[mask1]['labels'].unique()
+     mask2= SEARCH_DATA.tokens.apply(lambda x: customIsIn(x, token_list_stem))
+     indexes2= SEARCH_DATA[mask2]['labels'].unique()
+     indexes= np.concatenate((indexes1, indexes2), axis=None)
+     results_df= SEARCH_DATA[SEARCH_DATA['labels'].isin(indexes)].drop(['tokens', 'labels'], axis=1)
+     results_df= results_df.drop_duplicates()
+     result=[]
+     regex = re.compile(r'\([^)]*\)')
+     for index, row in results_df.iterrows():
+         text= row['text']
+         dmc= row['DMC']
+         dmc= re.sub(regex, '', dmc)
+         result.append({'text': text, 'DMC':dmc})
+     return result
+
+ def search_query_all(query, df=None, tokenizer=None):
+     global SEARCH_DATA, search_df, index_data_loaded
+     print('search_query_all!')
+     print(f'query: {query}')
+     SEARCH_DATA= df
+     if df is None:
+         if index_data_loaded==False:
+             load_index_data()
+         SEARCH_DATA=search_df
+     print('SEARCH_DATA:', SEARCH_DATA.head())
+
+     print('nlp loaded or not:', nlp)
+
+     doc = nlp(clear_text(query))
+     lemm_text, stem_text= get_lemmed_stemmed_text(query)
+     if tokenizer==None:
+         tokenizer= tokenizer_search
+     token_list = tokenizer.texts_to_sequences([lemm_text])[0]
+     print(f'token_list: {token_list}')
+     token_list_stem = tokenizer.texts_to_sequences([stem_text])[0]
+     print(f'token_list stem: {token_list_stem}')
+
+     mask1= SEARCH_DATA['tokens'].map(set(token_list).issubset)
+     mask2= SEARCH_DATA['tokens'].map(set(token_list_stem).issubset)
+     indexes1= SEARCH_DATA[mask1]['labels'].unique()
+     indexes2= SEARCH_DATA[mask2]['labels'].unique()
+     indexes= np.concatenate((indexes1, indexes2), axis=None)
+     results_df= SEARCH_DATA[SEARCH_DATA['labels'].isin(indexes)].drop(['tokens', 'labels'], axis=1)
+     results_df= results_df.drop_duplicates()
+     result=[]
+     regex = re.compile(r'\([^)]*\)')
+     for index, row in results_df.iterrows():
+         text= row['text']
+         dmc= row['DMC']
+         dmc= re.sub(regex, '', dmc)
+         result.append({'text': text, 'DMC':dmc})
+     return result
+
+ def concat_by_DMC(s_df):
+     #print('concat_by_DMC!')
+     #print(s_df.head())
+     # merge the lemmatized and stemmed halves of the dataset
+     concat_tokens=[]
+     for label in s_df['labels'].unique():
+         tokens_lists= s_df[s_df['labels']==label]['tokens'].to_list()
+         joined_lst=[]
+         for lst in tokens_lists:
+             joined_lst+= lst
+         concat_tokens.append(joined_lst)
+     #print(concat_tokens[:5])
+     df= s_df.drop('tokens', axis=1)
+     df= df.drop_duplicates()
+     df['tokens']=concat_tokens
+
+     # merge texts and tokens by DMC
+     concat_tokens=[]
+     DMCs=[]
+     texts=[]
+     for dmc_code in df['DMC'].unique():
+         DMCs.append(dmc_code)
+         # merge the token lists belonging to one data module (DMC)
+         tokens_lists= df[df['DMC']==dmc_code]['tokens'].to_list()
+         joined_token_lst=[]
+         for lst in tokens_lists:
+             joined_token_lst+= lst
+         concat_tokens.append(joined_token_lst)
+         # merge the texts
+         text_list= df[df['DMC']==dmc_code]['text'].to_list()
+         concat_text=' \n '.join(str(txt) for txt in text_list)
+         texts.append(concat_text)
+     #print('concat_tokens',len(concat_tokens))
+     #print('DMCs',len(DMCs))
+     #print('texts',len(texts))
+     df= pd.DataFrame(columns=['DMC'], data=DMCs)
+     df['text']= texts
+     df['tokens']= concat_tokens
+     df['labels']= df.index
+     #print(df.head())
+     return df
+
+
+ def initialize_qa_model(model):
+     global qa_df, qa_model, qa_model_num, qa_model_initialized
+     qa_model_num= model
+     print('initialize_qa_model!')
+     if model==1 or str(model)=="1":
+         qa_model= pipeline("question-answering", "dmibor/ietm_search_and_qa", device=device)
+         print('initialized model number 1!')
+     else:  # model==2 (the base model)
+         qa_model= pipeline("question-answering", "timpal0l/mdeberta-v3-base-squad2", device=device)
+         print('initialized model number 2!')
+     if qa_index_data_loaded==False:
+         load_index_data_qa()
+     #print('len(qa_df)', len(qa_df))
+     qa_df= concat_by_DMC(qa_df)
+     #qa_df.to_csv('concat_index.csv', sep=';', index=False)
+     #print('concat_by_DMC len(qa_df)', len(qa_df))
+     qa_model_initialized=True
+
+ def get_best_and_longest_result(model_results, threshold, mode):
+     print('get_best_and_longest_result!')
+     print('mode:', mode)
+     best_result=None
+     longest_result=None
+     if(type(model_results)!= list):
+         return best_result, longest_result
+     best_result= model_results[0]
+     best_result_answer= best_result['answer']
+     print('best_result_answer: ',best_result_answer)
+     best_answer_cleaned= (re.sub(r"[\W\d_]+$", "", best_result_answer)).strip()
+     print('best_answer_cleaned: ',best_answer_cleaned)
+     longest_answer=''
+     longest_answer_len= len(best_answer_cleaned)
+     longest_result= best_result
+     print("type(mode)", type(mode))
+     print("mode=='strict'", mode=='strict')
+     print("mode==\"strict\"", mode=="strict")
+     if mode=='strict':
+         return best_result, longest_result
+     if best_result['score']>=threshold:
+         print('best_result_answer: ',best_answer_cleaned)
+         print('best_result score:', best_result['score'])
+         for result in model_results:
+             answer= result['answer']
+             answer_cleaned= re.sub(r"[\W\d_]+$", "", answer).strip()
+             #print('answer_cleaned: ',answer_cleaned)
+             if best_answer_cleaned in answer_cleaned:
+                 if len(answer_cleaned)>longest_answer_len:
+                     print('new longest answer: ',answer_cleaned)
+                     print('longest score:', result['score'])
+                     print()
+                     longest_answer= answer_cleaned
+                     longest_answer_len= len(answer_cleaned)
+                     longest_result= result
+     #print('longest_answer:' , longest_answer)
+     return best_result, longest_result
+
+ def find_answer(question, context, threshold, max_answer_len=1000, top_k=20, verbose=True, mode='strict'):
+     print('find_answer!')
+     print('mode:', mode)
+     found_answer=False
+     #print('qa_model', qa_model)
+     model_results= qa_model(question = question, context = context, max_answer_len=max_answer_len, top_k=top_k)
+     #print(model_result)
+     best_result, longest_result= get_best_and_longest_result(model_results, threshold, mode)
+     #print('longest_result', longest_result)
+     if best_result['score']>=threshold:
+         longest_answer= longest_result['answer']
+         answer_cleaned= re.sub(r"[\W\d_]+$", "", longest_answer).strip()
+         if verbose==True:
+             print(f"Ответ (уверенность= {round(best_result['score'], 2)}): {answer_cleaned}")
+         longest_result['answer']= answer_cleaned
+         found_answer=True
+     if found_answer==False and verbose==True:
+         print('Ответ не найден!')
+     model_result= best_result
+     model_result['answer']= longest_result['answer']
+     return model_result
+
+ def answer_question(question, mode, model=1):
+     global qa_model_initialized, qa_model_num, tokenizer_qa
+     print('answer_question!')
+     if qa_model_initialized==False or model!= qa_model_num:
+         initialize_qa_model(model)
+     print(f'question: {question}')
+     print(f'mode: {mode}')
+     print(f'model: {qa_model}')
+
+     filtered_index= search_query_all(question, qa_df, tokenizer_qa)
+     threshold= THRESHOLD
+     #print('filtered_index все слова:', len(filtered_index))
+     if len(filtered_index)<1:
+         filtered_index= search_query_any(question, qa_df, tokenizer_qa)
+         threshold= THRESHOLD
+         #print('filtered_index любое слово:', len(filtered_index))
+
+     found_answer=False
+     best_answer=""
+     best_score=0
+     best_DMC=""
+
+     regex = re.compile(r'\([^)]*\)')
+     for indx in filtered_index:
+         result= find_answer(question, indx['text'], threshold=threshold, max_answer_len=1000, top_k=20, verbose=False, mode=mode)
+         if result['score']>best_score:
+             best_answer= result['answer']
+             best_score= result['score']
+             best_DMC= indx['DMC']
+             best_DMC= re.sub(regex, '', best_DMC)
+     result= [{'score': best_score, 'answer': best_answer, 'DMC': best_DMC}]
+     return result
+
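For reference, a minimal sketch (not part of the commit) of driving the new search_core API outside of Streamlit. It assumes the demo publication directory that PUBLICATION_PATH points at, an existing index/ subfolder for the pickled tokenizers, and the module's heavyweight dependencies (spaCy ru_core_news_sm, TensorFlow, transformers, PyStemmer) being installed:

from search_core import make_search_index, make_search_index_qa, search_query_all, answer_question

# Assumed layout: the S1000D data modules (*.xml) live directly in this folder.
publication_dir = "publications/Demo publication in Russian"

# One-off indexing: parses the data modules, lemmatizes/stems the text, fits the
# tokenizers, writes search_index.csv / qa_index.csv (working directory) and the
# pickled tokenizers (the publication's index/ folder), and fills the in-memory globals.
print(make_search_index(publication_dir))     # rows in the paragraph-level search index
print(make_search_index_qa(publication_dir))  # rows in the chapter-level QA index

# Full-text search: every query token must occur in a paragraph's token list.
for hit in search_query_all("аккумуляторная батарея"):
    print(hit["DMC"], hit["text"][:80])

# Extractive QA over the concatenated data-module texts.
print(answer_question("Какой ресурс до первого ремонта?", mode="strict", model=1))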