dmibor committed on
Commit
31211be
·
1 Parent(s): fa01465

some batch doc optimization, added S1000D bike example usage

Browse files
Files changed (2) hide show
  1. app.py +70 -21
  2. search_core.py +89 -37
app.py CHANGED
@@ -15,35 +15,84 @@ st.set_page_config(
15
 
16
  searchStarted= False
17
  qaStarted= False
18
- # Sidebar
 
 
 
 
 
 
 
 
 
19
  with st.sidebar:
20
- tab1, tab2 = st.sidebar.tabs(["Поиск по публикации", "Вопросы-ответы"])
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  with tab1:
23
- st.header("Поиск по публикации")
24
- search_input = st.text_input(label='Введите запрос:', value='аккумуляторная батарея')
25
- searchStarted = st.button('Искать')
 
 
 
 
 
26
 
27
  with tab2:
28
- st.header("Вопросы-ответы")
29
- qa_input = st.text_input(label='Введите вопрос:', value='Какой ресурс до первого ремонта?')
30
- qaStarted = st.button('Узнать ответ')
 
 
 
 
 
 
 
31
 
32
  if searchStarted==True:
33
- st.header("Результаты поиска")
34
- search_result= search_query_all(search_input)
35
- df = pd.DataFrame(pd.json_normalize(search_result))
36
- df.columns=['Параграф модуля данных', 'Код МД']
37
- st.table(df)
 
 
 
 
 
 
 
38
  if qaStarted==True:
39
- st.header("Ответ")
40
- mode_string = 'strict'
41
- model_string = '1'
42
- answer= answer_question(qa_input, mode_string, model_string)
43
- df = pd.DataFrame(pd.json_normalize(answer))
44
- df.columns=['Уверенность', 'Ответ', 'Код МД']
45
- st.table(df)
46
-
 
 
 
 
 
 
 
 
47
 
48
 
49
 
 
15
 
16
  searchStarted= False
17
  qaStarted= False
18
+
19
+ LANGUAGE= 'language'
20
+ if LANGUAGE not in st.session_state:
21
+ st.session_state[LANGUAGE]= 'Русский'
22
+
23
+ if "visibility" not in st.session_state:
24
+ st.session_state.visibility = "visible"
25
+ st.session_state.disabled = False
26
+ st.session_state.horizontal = True
27
+
28
  with st.sidebar:
29
+ st.session_state[LANGUAGE]= st.sidebar.radio(
30
+ "Язык/Language",
31
+ ["Русский", "English"],
32
+ key="Русский",
33
+ label_visibility=st.session_state.visibility,
34
+ disabled=st.session_state.disabled,
35
+ horizontal=st.session_state.horizontal,
36
+ )
37
+ if st.session_state[LANGUAGE]== 'Русский':
38
+ st.sidebar.subheader('Демо-публикация: "Урал 44202-80М", 74 модуля данных, русский язык')
39
+ tab1, tab2 = st.sidebar.tabs(["Поиск по публикации", "Вопросы-ответы"])
40
+ else:
41
+ st.sidebar.subheader('Publication asset: "S1000D release 5.0 bike example", 101 data module, english language')
42
+ tab1, tab2 = st.sidebar.tabs(["Indexed search", "Question answering"])
43
 
44
  with tab1:
45
+ if st.session_state[LANGUAGE]== 'Русский':
46
+ st.header("Поиск по публикации")
47
+ search_input = st.text_input(label='Введите запрос:', value='аккумуляторная батарея')
48
+ searchStarted = st.button('Искать')
49
+ else:
50
+ st.header("Publication content search")
51
+ search_input = st.text_input(label='Enter query:', value='bicycle wheel')
52
+ searchStarted = st.button('Search')
53
 
54
  with tab2:
55
+ if st.session_state[LANGUAGE]== 'Русский':
56
+ st.header("Вопросы-ответы")
57
+ qa_input = st.text_input(label='Введите вопрос:', value='Какой ресурс до первого ремонта?')
58
+ #qa_input = st.text_input(label='Введите вопрос:', value='Что входит в состав системы предпускового подогрева?')
59
+ #qa_input = st.text_input(label='Введите вопрос:', value='Для чего нужен нагреватель с нагнетателем воздуха?')
60
+ qaStarted = st.button('Узнать ответ')
61
+ else:
62
+ st.header("Question answering")
63
+ qa_input = st.text_input(label='Enter question:', value='How many brake pads on the bicycle?')
64
+ qaStarted = st.button('Find out')
65
 
66
  if searchStarted==True:
67
+ if st.session_state[LANGUAGE]== 'Русский':
68
+ st.header("Результаты поиска")
69
+ search_result= search_query_all(search_input, language="ru")
70
+ df = pd.DataFrame(pd.json_normalize(search_result))
71
+ df.columns=['Параграф модуля данных', 'Код МД']
72
+ st.table(df)
73
+ else:
74
+ st.header("Search results")
75
+ search_result= search_query_all(search_input, language="en")
76
+ df = pd.DataFrame(pd.json_normalize(search_result))
77
+ df.columns=['Data module paragraph', 'Data module code']
78
+ st.table(df)
79
  if qaStarted==True:
80
+ if st.session_state[LANGUAGE]== 'Русский':
81
+ st.header("Ответ")
82
+ mode_string = 'strict'
83
+ model_string = '1'
84
+ answer= answer_question(qa_input, mode_string, model_string, language="ru")
85
+ df = pd.DataFrame(pd.json_normalize(answer))
86
+ df.columns=['Уверенность', 'Ответ', 'Код МД']
87
+ st.table(df)
88
+ else:
89
+ st.header("Answer")
90
+ mode_string = 'strict'
91
+ model_string = '1'
92
+ answer= answer_question(qa_input, mode_string, model_string, language="en")
93
+ df = pd.DataFrame(pd.json_normalize(answer))
94
+ df.columns=['Score', 'Answer', 'Data module code']
95
+ st.table(df)
96
 
97
 
98
 
search_core.py CHANGED
@@ -32,6 +32,7 @@ PARSE_PATHS=['//dmodule/content[last()]/procedure[last()]/preliminaryRqmts[last(
32
 
33
  PERSCENTAGE_IN_RATIO=0.5
34
  THRESHOLD=0.1
 
35
 
36
  global nlp, tokenizer_search, tokenizer_qa, device
37
  global search_df, qa_df, SEARCH_DATA
@@ -44,9 +45,15 @@ PUBLICATION_PATH=PUBLICATION_DEMO_RU_PATH
44
  TOKENIZER_SEARCH_FILENAME='tokenizer_search.pickle'
45
  TOKENIZER_QA_FILENAME='tokenizer_qa.pickle'
46
  INDEX_FOLDER= PUBLICATION_PATH+ os.sep+ "index"
 
 
47
  #print('INDEX_FOLDER:', INDEX_FOLDER)
48
  TOKENIZER_SEARCH_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_SEARCH_FILENAME
 
 
49
  TOKENIZER_QA_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_QA_FILENAME
 
 
50
  #print('TOKENIZER_SEARCH_PATH:', TOKENIZER_SEARCH_PATH)
51
  PUBLICATION_LANGUAGE="ru"
52
 
@@ -334,6 +341,7 @@ def convert2list(string):
334
  def load_index_data():
335
  global nlp, tokenizer_search, search_df, index_data_loaded
336
  print('load_index_data!')
 
337
  #spacy
338
  disabled_pipes = [ "parser", "ner"]
339
  if PUBLICATION_LANGUAGE=="ru":
@@ -344,11 +352,18 @@ def load_index_data():
344
  stemmer= Stemmer.Stemmer('en')#english
345
  #print('spacy loaded:', nlp)
346
  #tokenizer
347
- with open(TOKENIZER_SEARCH_PATH, 'rb') as handle:
348
- tokenizer_search = pickle.load(handle)
 
 
 
 
349
  #print('tokenizer loaded:', tokenizer)
350
  #index
351
- search_index_path= INDEX_FOLDER+os.sep+'search_index.csv'
 
 
 
352
  search_df= pd.read_csv(search_index_path, sep=';')
353
  print('index file loaded:', search_df.info())
354
  search_df['tokens']= search_df['tokens'].apply(convert2list)
@@ -368,11 +383,18 @@ def load_index_data_qa():
368
  stemmer= Stemmer.Stemmer('en')#english
369
  print('spacy loaded:', nlp)
370
  #tokenizer
371
- with open(TOKENIZER_QA_PATH, 'rb') as handle:
372
- tokenizer_qa = pickle.load(handle)
 
 
 
 
373
  #print('tokenizer loaded:', tokenizer_qa)
374
  #index
375
- qa_index_path= INDEX_FOLDER+os.sep+'qa_index.csv'
 
 
 
376
  qa_df= pd.read_csv(qa_index_path, sep=';')
377
  #print('index qa file loaded:', qa_df.info())
378
  qa_df['tokens']= qa_df['tokens'].apply(convert2list)
@@ -445,17 +467,20 @@ def search_query_any(query, df=None, tokenizer=None):
445
  result.append({'text': text, 'DMC':dmc})
446
  return result
447
 
448
- def search_query_all(query, df=None, tokenizer=None):
449
- global SEARCH_DATA, search_df, index_data_loaded
450
  print('search_query_all!')
451
  print(f'query: {query}')
 
 
 
452
  SEARCH_DATA= df
453
  if df is None:
454
- if index_data_loaded==False:
455
  load_index_data()
456
  SEARCH_DATA=search_df
457
  print('SEARCH_DATA:', SEARCH_DATA.head())
458
-
459
  print('nlp loaded or not:', nlp)
460
 
461
  doc = nlp(clear_text(query))
@@ -536,8 +561,8 @@ def initialize_qa_model(model):
536
  else:#model==2 (базовая)
537
  qa_model= pipeline("question-answering", "timpal0l/mdeberta-v3-base-squad2", device=device)
538
  print('initialized model number 2!')
539
- if qa_index_data_loaded==False:
540
- load_index_data_qa()
541
  #print('len(qa_df)', len(qa_df))
542
  qa_df= concat_by_DMC(qa_df)
543
  #qa_df.to_csv('concat_index.csv', sep=';', index=False)
@@ -582,33 +607,60 @@ def get_best_and_longest_result(model_results, threshold, mode):
582
  #print('longest_answer:' , longest_answer)
583
  return best_result, longest_result
584
 
585
- def find_answer(question, context, threshold, max_answer_len=1000, top_k=20, verbose=True, mode='strict'):
586
  print('find_answer!')
587
  print('mode:', mode)
588
  found_answer=False
589
  #print('qa_model', qa_model)
590
- model_results= qa_model(question = question, context = context, max_answer_len=max_answer_len, top_k=top_k)
591
- #print(model_result)
592
- best_result, longest_result= get_best_and_longest_result(model_results, threshold, mode)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  #print('longest_result', longest_result)
594
  if best_result['score']>=threshold:
595
  longest_answer= longest_result['answer']
596
  answer_cleaned= re.sub(r"[\W\d_]+$", '', longest_answer).strip()
597
  if verbose==True:
598
  prob_value= round(model_result['score'], 2)
599
- print(f'Ответ (уверенность= {prob_value}): {answer_cleaned}')
600
  longest_result['answer']= answer_cleaned
601
  found_answer=True
602
  if found_answer==False and verbose==True:
603
- print('Ответ не найден!')
604
  model_result= best_result
605
  model_result['answer']= longest_result['answer']
606
  return model_result
607
 
608
- def answer_question(question, mode, model=1):
609
- global qa_model_initialized, qa_model_num, tokenizer_qa
610
  print('answer_question!')
611
- if qa_model_initialized==False or model!= qa_model_num:
 
 
 
612
  initialize_qa_model(model)
613
  print(f'question: {question}')
614
  print(f'mode: {mode}')
@@ -620,21 +672,21 @@ def answer_question(question, mode, model=1):
620
  if len(filtered_index)<1:
621
  filtered_index= search_query_any(question, qa_df, tokenizer_qa)
622
  threshold= THRESHOLD
623
- #print('filtered_index любое слово:', len(filtered_index))
624
-
625
- found_answer=False
626
- best_answer=""
627
- best_score=0
628
- best_DMC=""
629
-
630
- regex = re.compile(r'\([^)]*\)')
631
- for indx in filtered_index:
632
- result= find_answer(question, indx['text'], threshold=threshold, max_answer_len=1000, top_k=20, verbose=False, mode=mode)
633
- if result['score']>best_score:
634
- best_answer= result['answer']
635
- best_score= result['score']
636
- best_DMC= indx['DMC']
637
- best_DMC= re.sub(regex, '', best_DMC)
638
- result= [{'score': best_score, 'answer': best_answer, 'DMC': best_DMC}]
639
  return result
640
 
 
32
 
33
  PERSCENTAGE_IN_RATIO=0.5
34
  THRESHOLD=0.1
35
+ BATCH_SIZE=8
36
 
37
  global nlp, tokenizer_search, tokenizer_qa, device
38
  global search_df, qa_df, SEARCH_DATA
 
45
  TOKENIZER_SEARCH_FILENAME='tokenizer_search.pickle'
46
  TOKENIZER_QA_FILENAME='tokenizer_qa.pickle'
47
  INDEX_FOLDER= PUBLICATION_PATH+ os.sep+ "index"
48
+ INDEX_FOLDER_RU= PUBLICATION_DEMO_RU_PATH+ os.sep+ "index"
49
+ INDEX_FOLDER_EN= PUBLICATION_DEMO_EN_PATH+ os.sep+ "index"
50
  #print('INDEX_FOLDER:', INDEX_FOLDER)
51
  TOKENIZER_SEARCH_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_SEARCH_FILENAME
52
+ TOKENIZER_SEARCH_PATH_RU= INDEX_FOLDER_RU+ os.sep+ TOKENIZER_SEARCH_FILENAME
53
+ TOKENIZER_SEARCH_PATH_EN= INDEX_FOLDER_EN+ os.sep+ TOKENIZER_SEARCH_FILENAME
54
  TOKENIZER_QA_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_QA_FILENAME
55
+ TOKENIZER_QA_PATH_RU= INDEX_FOLDER_RU+ os.sep+ TOKENIZER_QA_FILENAME
56
+ TOKENIZER_QA_PATH_EN= INDEX_FOLDER_EN+ os.sep+ TOKENIZER_QA_FILENAME
57
  #print('TOKENIZER_SEARCH_PATH:', TOKENIZER_SEARCH_PATH)
58
  PUBLICATION_LANGUAGE="ru"
59
 
 
341
  def load_index_data():
342
  global nlp, tokenizer_search, search_df, index_data_loaded
343
  print('load_index_data!')
344
+ print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
345
  #spacy
346
  disabled_pipes = [ "parser", "ner"]
347
  if PUBLICATION_LANGUAGE=="ru":
 
352
  stemmer= Stemmer.Stemmer('en')#english
353
  #print('spacy loaded:', nlp)
354
  #tokenizer
355
+ if PUBLICATION_LANGUAGE=="ru":
356
+ with open(TOKENIZER_SEARCH_PATH_RU, 'rb') as handle:
357
+ tokenizer_search = pickle.load(handle)
358
+ else:
359
+ with open(TOKENIZER_SEARCH_PATH_EN, 'rb') as handle:
360
+ tokenizer_search = pickle.load(handle)
361
  #print('tokenizer loaded:', tokenizer)
362
  #index
363
+ if PUBLICATION_LANGUAGE=="ru":
364
+ search_index_path= INDEX_FOLDER_RU+os.sep+'search_index.csv'
365
+ else:
366
+ search_index_path= INDEX_FOLDER_EN+os.sep+'search_index.csv'
367
  search_df= pd.read_csv(search_index_path, sep=';')
368
  print('index file loaded:', search_df.info())
369
  search_df['tokens']= search_df['tokens'].apply(convert2list)
 
383
  stemmer= Stemmer.Stemmer('en')#english
384
  print('spacy loaded:', nlp)
385
  #tokenizer
386
+ if PUBLICATION_LANGUAGE=="ru":
387
+ with open(TOKENIZER_QA_PATH_RU, 'rb') as handle:
388
+ tokenizer_qa = pickle.load(handle)
389
+ else:
390
+ with open(TOKENIZER_QA_PATH_EN, 'rb') as handle:
391
+ tokenizer_qa = pickle.load(handle)
392
  #print('tokenizer loaded:', tokenizer_qa)
393
  #index
394
+ if PUBLICATION_LANGUAGE=="ru":
395
+ qa_index_path= INDEX_FOLDER_RU+os.sep+'qa_index.csv'
396
+ else:
397
+ qa_index_path= INDEX_FOLDER_EN+os.sep+'qa_index.csv'
398
  qa_df= pd.read_csv(qa_index_path, sep=';')
399
  #print('index qa file loaded:', qa_df.info())
400
  qa_df['tokens']= qa_df['tokens'].apply(convert2list)
 
467
  result.append({'text': text, 'DMC':dmc})
468
  return result
469
 
470
+ def search_query_all(query, df=None, tokenizer=None, language="ru"):
471
+ global SEARCH_DATA, search_df, index_data_loaded, PUBLICATION_LANGUAGE
472
  print('search_query_all!')
473
  print(f'query: {query}')
474
+ old_publication_language= PUBLICATION_LANGUAGE
475
+ PUBLICATION_LANGUAGE= language
476
+ print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
477
  SEARCH_DATA= df
478
  if df is None:
479
+ if index_data_loaded==False or language!=old_publication_language:
480
  load_index_data()
481
  SEARCH_DATA=search_df
482
  print('SEARCH_DATA:', SEARCH_DATA.head())
483
+
484
  print('nlp loaded or not:', nlp)
485
 
486
  doc = nlp(clear_text(query))
 
561
  else:#model==2 (базовая)
562
  qa_model= pipeline("question-answering", "timpal0l/mdeberta-v3-base-squad2", device=device)
563
  print('initialized model number 2!')
564
+ #if qa_index_data_loaded==False:
565
+ load_index_data_qa()
566
  #print('len(qa_df)', len(qa_df))
567
  qa_df= concat_by_DMC(qa_df)
568
  #qa_df.to_csv('concat_index.csv', sep=';', index=False)
 
607
  #print('longest_answer:' , longest_answer)
608
  return best_result, longest_result
609
 
610
+ def find_answer(inputs, threshold, max_answer_len=1000, top_k=20, verbose=True, mode='strict'):
611
  print('find_answer!')
612
  print('mode:', mode)
613
  found_answer=False
614
  #print('qa_model', qa_model)
615
+ model_results= qa_model([{"question": q["question"], "context": q["context"]} for q in inputs], batch_size=BATCH_SIZE, max_answer_len=max_answer_len, top_k=top_k)
616
+ #print('model_results type:', type(model_results))
617
+ if isinstance(model_results, dict):
618
+ tmp= model_results
619
+ model_results= list()
620
+ model_results.append(tmp)
621
+ #print('model_results:', model_results)
622
+ # Добавляем индексы обратно в результаты
623
+ best_score=0
624
+ best_result=None
625
+ longest_result=None
626
+ for i, result in enumerate(model_results):#для каждого документа (модуля данных) свой список результатов
627
+ dmc_value= inputs[i]["DMC"]
628
+ #print('dmc_value:', dmc_value)
629
+ if isinstance(result, dict):
630
+ tmp= result
631
+ result= list()
632
+ result.append(tmp)
633
+ for r in result:#это список результатов для одного модуля данных
634
+ #print('r:', r)
635
+ r["DMC"] = dmc_value
636
+ #print(model_results)
637
+ best_doc_result, longest_doc_result= get_best_and_longest_result(result, threshold, mode)
638
+ if best_doc_result["score"]>best_score:
639
+ best_score= best_doc_result["score"]
640
+ best_result= best_doc_result
641
+ longest_result= longest_doc_result
642
  #print('longest_result', longest_result)
643
  if best_result['score']>=threshold:
644
  longest_answer= longest_result['answer']
645
  answer_cleaned= re.sub(r"[\W\d_]+$", '', longest_answer).strip()
646
  if verbose==True:
647
  prob_value= round(model_result['score'], 2)
648
+ print(f'Answer (score= {prob_value}): {answer_cleaned}')
649
  longest_result['answer']= answer_cleaned
650
  found_answer=True
651
  if found_answer==False and verbose==True:
652
+ print('Answer not found!')
653
  model_result= best_result
654
  model_result['answer']= longest_result['answer']
655
  return model_result
656
 
657
+ def answer_question(question, mode, model=1, language="ru"):
658
+ global qa_model_initialized, qa_model_num, tokenizer_qa, PUBLICATION_LANGUAGE
659
  print('answer_question!')
660
+ old_publication_language= PUBLICATION_LANGUAGE
661
+ PUBLICATION_LANGUAGE= language
662
+ print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
663
+ if qa_model_initialized==False or model!= qa_model_num or old_publication_language!= language:
664
  initialize_qa_model(model)
665
  print(f'question: {question}')
666
  print(f'mode: {mode}')
 
672
  if len(filtered_index)<1:
673
  filtered_index= search_query_any(question, qa_df, tokenizer_qa)
674
  threshold= THRESHOLD
675
+ #print('filtered_index:', filtered_index)
676
+
677
+ inputs = [{"question": question, "context": indx["text"], "DMC": indx["DMC"]} for indx in filtered_index]
678
+ #print('qa model inputs', inputs)
679
+ top_k=1
680
+ if mode!="strict":
681
+ top_k=len(filtered_index)
682
+ result= find_answer(inputs, threshold=threshold, max_answer_len=1000, top_k=top_k, verbose=False, mode=mode)
683
+
684
+ if result!= None:
685
+ best_answer= result['answer']
686
+ best_score= result['score']
687
+ best_DMC= result['DMC']
688
+ regex = re.compile(r'\([^)]*\)')
689
+ best_DMC= re.sub(regex, '', best_DMC)
690
+ result= [{'score': best_score, 'answer': best_answer, 'DMC': best_DMC}]
691
  return result
692