some batch doc optimization, added S1000D bike example usage
app.py +70 -21
search_core.py +89 -37
app.py
CHANGED
@@ -15,35 +15,84 @@ st.set_page_config(
[deleted lines are collapsed in the Space's diff view and are not recoverable; only context and added lines are shown]

 searchStarted= False
 qaStarted= False
+
+LANGUAGE= 'language'
+if LANGUAGE not in st.session_state:
+    st.session_state[LANGUAGE]= 'Русский'
+
+if "visibility" not in st.session_state:
+    st.session_state.visibility = "visible"
+    st.session_state.disabled = False
+    st.session_state.horizontal = True
+
 with st.sidebar:
+    st.session_state[LANGUAGE]= st.sidebar.radio(
+        "Язык/Language",
+        ["Русский", "English"],
+        key="Русский",
+        label_visibility=st.session_state.visibility,
+        disabled=st.session_state.disabled,
+        horizontal=st.session_state.horizontal,
+    )
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.sidebar.subheader('Демо-публикация: "Урал 44202-80М", 74 модуля данных, русский язык')
+        tab1, tab2 = st.sidebar.tabs(["Поиск по публикации", "Вопросы-ответы"])
+    else:
+        st.sidebar.subheader('Publication asset: "S1000D release 5.0 bike example", 101 data module, english language')
+        tab1, tab2 = st.sidebar.tabs(["Indexed search", "Question answering"])

 with tab1:
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Поиск по публикации")
+        search_input = st.text_input(label='Введите запрос:', value='аккумуляторная батарея')
+        searchStarted = st.button('Искать')
+    else:
+        st.header("Publication content search")
+        search_input = st.text_input(label='Enter query:', value='bicycle wheel')
+        searchStarted = st.button('Search')

 with tab2:
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Вопросы-ответы")
+        qa_input = st.text_input(label='Введите вопрос:', value='Какой ресурс до первого ремонта?')
+        #qa_input = st.text_input(label='Введите вопрос:', value='Что входит в состав системы предпускового подогрева?')
+        #qa_input = st.text_input(label='Введите вопрос:', value='Для чего нужен нагреватель с нагнетателем воздуха?')
+        qaStarted = st.button('Узнать ответ')
+    else:
+        st.header("Question answering")
+        qa_input = st.text_input(label='Enter question:', value='How many brake pads on the bicycle?')
+        qaStarted = st.button('Find out')

 if searchStarted==True:
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Результаты поиска")
+        search_result= search_query_all(search_input, language="ru")
+        df = pd.DataFrame(pd.json_normalize(search_result))
+        df.columns=['Параграф модуля данных', 'Код МД']
+        st.table(df)
+    else:
+        st.header("Search results")
+        search_result= search_query_all(search_input, language="en")
+        df = pd.DataFrame(pd.json_normalize(search_result))
+        df.columns=['Data module paragraph', 'Data module code']
+        st.table(df)
 if qaStarted==True:
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Ответ")
+        mode_string = 'strict'
+        model_string = '1'
+        answer= answer_question(qa_input, mode_string, model_string, language="ru")
+        df = pd.DataFrame(pd.json_normalize(answer))
+        df.columns=['Уверенность', 'Ответ', 'Код МД']
+        st.table(df)
+    else:
+        st.header("Answer")
+        mode_string = 'strict'
+        model_string = '1'
+        answer= answer_question(qa_input, mode_string, model_string, language="en")
+        df = pd.DataFrame(pd.json_normalize(answer))
+        df.columns=['Score', 'Answer', 'Data module code']
+        st.table(df)
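Note on the pattern above: Streamlit reruns the whole script on every widget interaction, so the language choice only survives reruns because it is written into st.session_state. A minimal self-contained sketch of the same toggle (the 'lang' key and the headers are illustrative, not from the commit):

import streamlit as st

# Streamlit re-executes the script top to bottom on each interaction,
# so defaults must be seeded into session_state exactly once.
if 'lang' not in st.session_state:
    st.session_state['lang'] = 'Русский'

with st.sidebar:
    # st.radio inside a `with st.sidebar:` block is equivalent to the
    # st.sidebar.radio call used in the commit.
    st.session_state['lang'] = st.radio(
        "Язык/Language", ["Русский", "English"], horizontal=True)

# Every later block can branch on the persisted choice.
if st.session_state['lang'] == 'Русский':
    st.header("Поиск по публикации")
else:
    st.header("Publication content search")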
search_core.py
CHANGED
@@ -32,6 +32,7 @@ PARSE_PATHS=['//dmodule/content[last()]/procedure[last()]/preliminaryRqmts[last(

 PERSCENTAGE_IN_RATIO=0.5
 THRESHOLD=0.1
+BATCH_SIZE=8

 global nlp, tokenizer_search, tokenizer_qa, device
 global search_df, qa_df, SEARCH_DATA
@@ -44,9 +45,15 @@ PUBLICATION_PATH=PUBLICATION_DEMO_RU_PATH
 TOKENIZER_SEARCH_FILENAME='tokenizer_search.pickle'
 TOKENIZER_QA_FILENAME='tokenizer_qa.pickle'
 INDEX_FOLDER= PUBLICATION_PATH+ os.sep+ "index"
+INDEX_FOLDER_RU= PUBLICATION_DEMO_RU_PATH+ os.sep+ "index"
+INDEX_FOLDER_EN= PUBLICATION_DEMO_EN_PATH+ os.sep+ "index"
 #print('INDEX_FOLDER:', INDEX_FOLDER)
 TOKENIZER_SEARCH_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_SEARCH_FILENAME
+TOKENIZER_SEARCH_PATH_RU= INDEX_FOLDER_RU+ os.sep+ TOKENIZER_SEARCH_FILENAME
+TOKENIZER_SEARCH_PATH_EN= INDEX_FOLDER_EN+ os.sep+ TOKENIZER_SEARCH_FILENAME
 TOKENIZER_QA_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_QA_FILENAME
+TOKENIZER_QA_PATH_RU= INDEX_FOLDER_RU+ os.sep+ TOKENIZER_QA_FILENAME
+TOKENIZER_QA_PATH_EN= INDEX_FOLDER_EN+ os.sep+ TOKENIZER_QA_FILENAME
 #print('TOKENIZER_SEARCH_PATH:', TOKENIZER_SEARCH_PATH)
 PUBLICATION_LANGUAGE="ru"
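The hunk above multiplies the path constants by language. A hypothetical consolidation (the names INDEX_FOLDERS and index_path are not in the commit, and the publication roots are stand-ins) would carry the same information as one mapping:

import os

# Stand-ins for the module's existing PUBLICATION_DEMO_*_PATH constants.
PUBLICATION_DEMO_RU_PATH = 'demo_ru'
PUBLICATION_DEMO_EN_PATH = 'demo_en'

# One mapping from language code to index folder could replace the
# parallel *_RU/*_EN constant pairs.
INDEX_FOLDERS = {
    'ru': PUBLICATION_DEMO_RU_PATH + os.sep + 'index',
    'en': PUBLICATION_DEMO_EN_PATH + os.sep + 'index',
}

def index_path(language, filename):
    # e.g. index_path('en', 'search_index.csv') -> 'demo_en/index/search_index.csv'
    return INDEX_FOLDERS[language] + os.sep + filename

This would also collapse the repeated if PUBLICATION_LANGUAGE=="ru" branches that the next hunks add to load_index_data and load_index_data_qa.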
@@ -334,6 +341,7 @@ def convert2list(string):
 def load_index_data():
     global nlp, tokenizer_search, search_df, index_data_loaded
     print('load_index_data!')
+    print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
     #spacy
     disabled_pipes = [ "parser", "ner"]
     if PUBLICATION_LANGUAGE=="ru":
@@ -344,11 +352,18 @@ def load_index_data():
         stemmer= Stemmer.Stemmer('en')#english
     #print('spacy loaded:', nlp)
     #tokenizer
-    with open(TOKENIZER_SEARCH_PATH, 'rb') as handle:
-        tokenizer_search = pickle.load(handle)
+    if PUBLICATION_LANGUAGE=="ru":
+        with open(TOKENIZER_SEARCH_PATH_RU, 'rb') as handle:
+            tokenizer_search = pickle.load(handle)
+    else:
+        with open(TOKENIZER_SEARCH_PATH_EN, 'rb') as handle:
+            tokenizer_search = pickle.load(handle)
     #print('tokenizer loaded:', tokenizer)
     #index
-    search_index_path= INDEX_FOLDER+os.sep+'search_index.csv'
+    if PUBLICATION_LANGUAGE=="ru":
+        search_index_path= INDEX_FOLDER_RU+os.sep+'search_index.csv'
+    else:
+        search_index_path= INDEX_FOLDER_EN+os.sep+'search_index.csv'
     search_df= pd.read_csv(search_index_path, sep=';')
     print('index file loaded:', search_df.info())
     search_df['tokens']= search_df['tokens'].apply(convert2list)
@@ -368,11 +383,18 @@ def load_index_data_qa():
         stemmer= Stemmer.Stemmer('en')#english
     print('spacy loaded:', nlp)
     #tokenizer
-    with open(TOKENIZER_QA_PATH, 'rb') as handle:
-        tokenizer_qa = pickle.load(handle)
+    if PUBLICATION_LANGUAGE=="ru":
+        with open(TOKENIZER_QA_PATH_RU, 'rb') as handle:
+            tokenizer_qa = pickle.load(handle)
+    else:
+        with open(TOKENIZER_QA_PATH_EN, 'rb') as handle:
+            tokenizer_qa = pickle.load(handle)
     #print('tokenizer loaded:', tokenizer_qa)
     #index
-    qa_index_path= INDEX_FOLDER+os.sep+'qa_index.csv'
+    if PUBLICATION_LANGUAGE=="ru":
+        qa_index_path= INDEX_FOLDER_RU+os.sep+'qa_index.csv'
+    else:
+        qa_index_path= INDEX_FOLDER_EN+os.sep+'qa_index.csv'
     qa_df= pd.read_csv(qa_index_path, sep=';')
     #print('index qa file loaded:', qa_df.info())
     qa_df['tokens']= qa_df['tokens'].apply(convert2list)
@@ -445,17 +467,20 @@ def search_query_any(query, df=None, tokenizer=None):
         result.append({'text': text, 'DMC':dmc})
     return result

-def search_query_all(query, df=None, tokenizer=None):
-    global SEARCH_DATA, search_df, index_data_loaded
+def search_query_all(query, df=None, tokenizer=None, language="ru"):
+    global SEARCH_DATA, search_df, index_data_loaded, PUBLICATION_LANGUAGE
     print('search_query_all!')
     print(f'query: {query}')
+    old_publication_language= PUBLICATION_LANGUAGE
+    PUBLICATION_LANGUAGE= language
+    print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
     SEARCH_DATA= df
     if df is None:
-        if index_data_loaded==False:
+        if index_data_loaded==False or language!=old_publication_language:
            load_index_data()
            SEARCH_DATA=search_df
            print('SEARCH_DATA:', SEARCH_DATA.head())

     print('nlp loaded or not:', nlp)

     doc = nlp(clear_text(query))
@@ -536,8 +561,8 @@ def initialize_qa_model(model):
     else:#model==2 (base model)
         qa_model= pipeline("question-answering", "timpal0l/mdeberta-v3-base-squad2", device=device)
         print('initialized model number 2!')
-    if qa_index_data_loaded==False:
-        load_index_data_qa()
+    #if qa_index_data_loaded==False:
+    load_index_data_qa()
     #print('len(qa_df)', len(qa_df))
     qa_df= concat_by_DMC(qa_df)
     #qa_df.to_csv('concat_index.csv', sep=';', index=False)
@@ -582,33 +607,60 @@ def get_best_and_longest_result(model_results, threshold, mode):
     #print('longest_answer:' , longest_answer)
     return best_result, longest_result

-def find_answer(
+def find_answer(inputs, threshold, max_answer_len=1000, top_k=20, verbose=True, mode='strict'):
     print('find_answer!')
     print('mode:', mode)
     found_answer=False
     #print('qa_model', qa_model)
-    model_results= qa_model(question
-    #print(
-    [deleted block collapsed in the diff view]
+    model_results= qa_model([{"question": q["question"], "context": q["context"]} for q in inputs], batch_size=BATCH_SIZE, max_answer_len=max_answer_len, top_k=top_k)
+    #print('model_results type:', type(model_results))
+    if isinstance(model_results, dict):
+        tmp= model_results
+        model_results= list()
+        model_results.append(tmp)
+    #print('model_results:', model_results)
+    # add the indices back into the results
+    best_score=0
+    best_result=None
+    longest_result=None
+    for i, result in enumerate(model_results):# each document (data module) has its own list of results
+        dmc_value= inputs[i]["DMC"]
+        #print('dmc_value:', dmc_value)
+        if isinstance(result, dict):
+            tmp= result
+            result= list()
+            result.append(tmp)
+        for r in result:# the list of results for one data module
+            #print('r:', r)
+            r["DMC"] = dmc_value
+        #print(model_results)
+        best_doc_result, longest_doc_result= get_best_and_longest_result(result, threshold, mode)
+        if best_doc_result["score"]>best_score:
+            best_score= best_doc_result["score"]
+            best_result= best_doc_result
+            longest_result= longest_doc_result
     #print('longest_result', longest_result)
     if best_result['score']>=threshold:
         longest_answer= longest_result['answer']
         answer_cleaned= re.sub(r"[\W\d_]+$", '', longest_answer).strip()
         if verbose==True:
             prob_value= round(model_result['score'], 2)
-            print(f'
+            print(f'Answer (score= {prob_value}): {answer_cleaned}')
         longest_result['answer']= answer_cleaned
         found_answer=True
     if found_answer==False and verbose==True:
-        print('
+        print('Answer not found!')
     model_result= best_result
     model_result['answer']= longest_result['answer']
     return model_result

-def answer_question(question, mode, model=1):
-    global qa_model_initialized, qa_model_num, tokenizer_qa
+def answer_question(question, mode, model=1, language="ru"):
+    global qa_model_initialized, qa_model_num, tokenizer_qa, PUBLICATION_LANGUAGE
     print('answer_question!')
-    if qa_model_initialized==False or model!= qa_model_num:
+    old_publication_language= PUBLICATION_LANGUAGE
+    PUBLICATION_LANGUAGE= language
+    print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
+    if qa_model_initialized==False or model!= qa_model_num or old_publication_language!= language:
        initialize_qa_model(model)
     print(f'question: {question}')
     print(f'mode: {mode}')
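The hunk above is the heart of the commit's "batch doc optimization": instead of querying the model once per candidate data module, all question/context pairs go through the transformers pipeline in a single call with batch_size. The isinstance checks exist because a question-answering pipeline returns a bare dict for a single input with top_k=1 and nested lists otherwise. A standalone sketch of that behavior (the inputs are illustrative; the model is the one the Space uses):

from transformers import pipeline

qa = pipeline("question-answering", "timpal0l/mdeberta-v3-base-squad2")

inputs = [
    {"question": "How many wheels?", "context": "The bicycle has two wheels."},
    {"question": "How many wheels?", "context": "Brake pads act on both wheels."},
]

# One batched call over every candidate context.
results = qa(inputs, batch_size=8, top_k=2, max_answer_len=1000)

# Normalize the shape before iterating, as find_answer does above:
# a bare dict means a single result rather than a list of them.
if isinstance(results, dict):
    results = [results]
for i, per_doc in enumerate(results):  # one entry per input document
    if isinstance(per_doc, dict):
        per_doc = [per_doc]            # single answer for this document
    for r in per_doc:
        print(i, round(r["score"], 2), r["answer"])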
@@ -620,21 +672,21 @@ def answer_question(question, mode, model=1):
     if len(filtered_index)<1:
         filtered_index= search_query_any(question, qa_df, tokenizer_qa)
     threshold= THRESHOLD
-    [deleted block collapsed in the diff view]
+    #print('filtered_index:', filtered_index)
+
+    inputs = [{"question": question, "context": indx["text"], "DMC": indx["DMC"]} for indx in filtered_index]
+    #print('qa model inputs', inputs)
+    top_k=1
+    if mode!="strict":
+        top_k=len(filtered_index)
+    result= find_answer(inputs, threshold=threshold, max_answer_len=1000, top_k=top_k, verbose=False, mode=mode)
+
+    if result!= None:
+        best_answer= result['answer']
+        best_score= result['score']
+        best_DMC= result['DMC']
+        regex = re.compile(r'\([^)]*\)')
+        best_DMC= re.sub(regex, '', best_DMC)
+        result= [{'score': best_score, 'answer': best_answer, 'DMC': best_DMC}]
     return result
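Taken together, app.py now selects the publication purely through the new language keyword. A hypothetical console session against the updated module (the return values are sketched from the result shapes in the hunks above, not actual output):

from search_core import search_query_all, answer_question

# English publication: the S1000D release 5.0 bike example.
hits = search_query_all('bicycle wheel', language='en')
# hits is a list of {'text': ..., 'DMC': ...} dicts, one per matching paragraph

answer = answer_question('How many brake pads on the bicycle?',
                         'strict', model=1, language='en')
# answer is [{'score': ..., 'answer': ..., 'DMC': ...}] once find_answer succeeds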