Commit: some batch doc optimization, added S1000D bike example usage

Files changed:
- app.py          +70 -21
- search_core.py  +89 -37
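Both search_core.py entry points gain a language switch in this commit, and the QA path moves to batched inference over all candidate data modules (see the diffs below). A minimal sketch of the new call surface, mirroring how app.py invokes it; this assumes search_core is importable outside the Streamlit app, and the result shapes are taken from the diff:

    from search_core import search_query_all, answer_question

    # Indexed search over the English "bike example" publication;
    # returns a list of {'text': ..., 'DMC': ...} rows.
    hits = search_query_all('bicycle wheel', language='en')

    # Batched question answering; returns [{'score': ..., 'answer': ..., 'DMC': ...}].
    answer = answer_question('How many brake pads on the bicycle?', 'strict', '1', language='en')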
app.py CHANGED

@@ -15,35 +15,84 @@ st.set_page_config(
 
 searchStarted= False
 qaStarted= False
-
+
+LANGUAGE= 'language'
+if LANGUAGE not in st.session_state:
+    st.session_state[LANGUAGE]= 'Русский'
+
+if "visibility" not in st.session_state:
+    st.session_state.visibility = "visible"
+    st.session_state.disabled = False
+    st.session_state.horizontal = True
+
 with st.sidebar:
-
+    st.session_state[LANGUAGE]= st.sidebar.radio(
+        "Язык/Language",
+        ["Русский", "English"],
+        key="Русский",
+        label_visibility=st.session_state.visibility,
+        disabled=st.session_state.disabled,
+        horizontal=st.session_state.horizontal,
+    )
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.sidebar.subheader('Демо-публикация: "Урал 44202-80М", 74 модуля данных, русский язык')
+        tab1, tab2 = st.sidebar.tabs(["Поиск по публикации", "Вопросы-ответы"])
+    else:
+        st.sidebar.subheader('Publication asset: "S1000D release 5.0 bike example", 101 data module, english language')
+        tab1, tab2 = st.sidebar.tabs(["Indexed search", "Question answering"])
 
 with tab1:
-    st.
-
-
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Поиск по публикации")
+        search_input = st.text_input(label='Введите запрос:', value='аккумуляторная батарея')
+        searchStarted = st.button('Искать')
+    else:
+        st.header("Publication content search")
+        search_input = st.text_input(label='Enter query:', value='bicycle wheel')
+        searchStarted = st.button('Search')
 
 with tab2:
-    st.
-
-
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Вопросы-ответы")
+        qa_input = st.text_input(label='Введите вопрос:', value='Какой ресурс до первого ремонта?')
+        #qa_input = st.text_input(label='Введите вопрос:', value='Что входит в состав системы предпускового подогрева?')
+        #qa_input = st.text_input(label='Введите вопрос:', value='Для чего нужен нагреватель с нагнетателем воздуха?')
+        qaStarted = st.button('Узнать ответ')
+    else:
+        st.header("Question answering")
+        qa_input = st.text_input(label='Enter question:', value='How many brake pads on the bicycle?')
+        qaStarted = st.button('Find out')
 
 if searchStarted==True:
-    st.
-
-
-
-
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Результаты поиска")
+        search_result= search_query_all(search_input, language="ru")
+        df = pd.DataFrame(pd.json_normalize(search_result))
+        df.columns=['Параграф модуля данных', 'Код МД']
+        st.table(df)
+    else:
+        st.header("Search results")
+        search_result= search_query_all(search_input, language="en")
+        df = pd.DataFrame(pd.json_normalize(search_result))
+        df.columns=['Data module paragraph', 'Data module code']
+        st.table(df)
 if qaStarted==True:
-    st.
-
-
-
-
-
-
-
+    if st.session_state[LANGUAGE]== 'Русский':
+        st.header("Ответ")
+        mode_string = 'strict'
+        model_string = '1'
+        answer= answer_question(qa_input, mode_string, model_string, language="ru")
+        df = pd.DataFrame(pd.json_normalize(answer))
+        df.columns=['Уверенность', 'Ответ', 'Код МД']
+        st.table(df)
+    else:
+        st.header("Answer")
+        mode_string = 'strict'
+        model_string = '1'
+        answer= answer_question(qa_input, mode_string, model_string, language="en")
+        df = pd.DataFrame(pd.json_normalize(answer))
+        df.columns=['Score', 'Answer', 'Data module code']
+        st.table(df)
 
 
 
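For reference, the table rendering above flattens the list of result dicts with pd.json_normalize and then relabels the columns. A standalone sketch with dummy rows; the DMC values are placeholders, not real data module codes:

    import pandas as pd

    search_result = [
        {'text': 'Front wheel - Removal procedure', 'DMC': 'DMC-PLACEHOLDER-001'},
        {'text': 'Rear wheel - Removal procedure',  'DMC': 'DMC-PLACEHOLDER-002'},
    ]
    df = pd.DataFrame(pd.json_normalize(search_result))
    df.columns = ['Data module paragraph', 'Data module code']
    print(df)  # the app passes df to st.table(df) instead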
search_core.py CHANGED

@@ -32,6 +32,7 @@ PARSE_PATHS=['//dmodule/content[last()]/procedure[last()]/preliminaryRqmts[last(
 
 PERSCENTAGE_IN_RATIO=0.5
 THRESHOLD=0.1
+BATCH_SIZE=8
 
 global nlp, tokenizer_search, tokenizer_qa, device
 global search_df, qa_df, SEARCH_DATA

@@ -44,9 +45,15 @@ PUBLICATION_PATH=PUBLICATION_DEMO_RU_PATH
 TOKENIZER_SEARCH_FILENAME='tokenizer_search.pickle'
 TOKENIZER_QA_FILENAME='tokenizer_qa.pickle'
 INDEX_FOLDER= PUBLICATION_PATH+ os.sep+ "index"
+INDEX_FOLDER_RU= PUBLICATION_DEMO_RU_PATH+ os.sep+ "index"
+INDEX_FOLDER_EN= PUBLICATION_DEMO_EN_PATH+ os.sep+ "index"
 #print('INDEX_FOLDER:', INDEX_FOLDER)
 TOKENIZER_SEARCH_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_SEARCH_FILENAME
+TOKENIZER_SEARCH_PATH_RU= INDEX_FOLDER_RU+ os.sep+ TOKENIZER_SEARCH_FILENAME
+TOKENIZER_SEARCH_PATH_EN= INDEX_FOLDER_EN+ os.sep+ TOKENIZER_SEARCH_FILENAME
 TOKENIZER_QA_PATH= INDEX_FOLDER+ os.sep+ TOKENIZER_QA_FILENAME
+TOKENIZER_QA_PATH_RU= INDEX_FOLDER_RU+ os.sep+ TOKENIZER_QA_FILENAME
+TOKENIZER_QA_PATH_EN= INDEX_FOLDER_EN+ os.sep+ TOKENIZER_QA_FILENAME
 #print('TOKENIZER_SEARCH_PATH:', TOKENIZER_SEARCH_PATH)
 PUBLICATION_LANGUAGE="ru"
 

@@ -334,6 +341,7 @@ def convert2list(string):
 def load_index_data():
     global nlp, tokenizer_search, search_df, index_data_loaded
     print('load_index_data!')
+    print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
     #spacy
     disabled_pipes = [ "parser", "ner"]
     if PUBLICATION_LANGUAGE=="ru":

@@ -344,11 +352,18 @@ def load_index_data():
         stemmer= Stemmer.Stemmer('en')#english
     #print('spacy loaded:', nlp)
     #tokenizer
-
-
+    if PUBLICATION_LANGUAGE=="ru":
+        with open(TOKENIZER_SEARCH_PATH_RU, 'rb') as handle:
+            tokenizer_search = pickle.load(handle)
+    else:
+        with open(TOKENIZER_SEARCH_PATH_EN, 'rb') as handle:
+            tokenizer_search = pickle.load(handle)
     #print('tokenizer loaded:', tokenizer)
     #index
-
+    if PUBLICATION_LANGUAGE=="ru":
+        search_index_path= INDEX_FOLDER_RU+os.sep+'search_index.csv'
+    else:
+        search_index_path= INDEX_FOLDER_EN+os.sep+'search_index.csv'
     search_df= pd.read_csv(search_index_path, sep=';')
     print('index file loaded:', search_df.info())
     search_df['tokens']= search_df['tokens'].apply(convert2list)

@@ -368,11 +383,18 @@ def load_index_data_qa():
         stemmer= Stemmer.Stemmer('en')#english
     print('spacy loaded:', nlp)
     #tokenizer
-
-
+    if PUBLICATION_LANGUAGE=="ru":
+        with open(TOKENIZER_QA_PATH_RU, 'rb') as handle:
+            tokenizer_qa = pickle.load(handle)
+    else:
+        with open(TOKENIZER_QA_PATH_EN, 'rb') as handle:
+            tokenizer_qa = pickle.load(handle)
     #print('tokenizer loaded:', tokenizer_qa)
     #index
-
+    if PUBLICATION_LANGUAGE=="ru":
+        qa_index_path= INDEX_FOLDER_RU+os.sep+'qa_index.csv'
+    else:
+        qa_index_path= INDEX_FOLDER_EN+os.sep+'qa_index.csv'
     qa_df= pd.read_csv(qa_index_path, sep=';')
     #print('index qa file loaded:', qa_df.info())
     qa_df['tokens']= qa_df['tokens'].apply(convert2list)

@@ -445,17 +467,20 @@ def search_query_any(query, df=None, tokenizer=None):
         result.append({'text': text, 'DMC':dmc})
     return result
 
-def search_query_all(query, df=None, tokenizer=None):
-    global SEARCH_DATA, search_df, index_data_loaded
+def search_query_all(query, df=None, tokenizer=None, language="ru"):
+    global SEARCH_DATA, search_df, index_data_loaded, PUBLICATION_LANGUAGE
     print('search_query_all!')
     print(f'query: {query}')
+    old_publication_language= PUBLICATION_LANGUAGE
+    PUBLICATION_LANGUAGE= language
+    print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
     SEARCH_DATA= df
     if df is None:
-        if index_data_loaded==False:
+        if index_data_loaded==False or language!=old_publication_language:
            load_index_data()
        SEARCH_DATA=search_df
        print('SEARCH_DATA:', SEARCH_DATA.head())
-
+
     print('nlp loaded or not:', nlp)
 
     doc = nlp(clear_text(query))

@@ -536,8 +561,8 @@ def initialize_qa_model(model):
     else:#model==2 (base)
         qa_model= pipeline("question-answering", "timpal0l/mdeberta-v3-base-squad2", device=device)
         print('initialized model number 2!')
-    if qa_index_data_loaded==False:
-
+    #if qa_index_data_loaded==False:
+    load_index_data_qa()
     #print('len(qa_df)', len(qa_df))
     qa_df= concat_by_DMC(qa_df)
     #qa_df.to_csv('concat_index.csv', sep=';', index=False)

@@ -582,33 +607,60 @@ def get_best_and_longest_result(model_results, threshold, mode):
     #print('longest_answer:' , longest_answer)
     return best_result, longest_result
 
-def find_answer(
+def find_answer(inputs, threshold, max_answer_len=1000, top_k=20, verbose=True, mode='strict'):
     print('find_answer!')
     print('mode:', mode)
     found_answer=False
     #print('qa_model', qa_model)
-    model_results= qa_model(question
-    #print(
-
+    model_results= qa_model([{"question": q["question"], "context": q["context"]} for q in inputs], batch_size=BATCH_SIZE, max_answer_len=max_answer_len, top_k=top_k)
+    #print('model_results type:', type(model_results))
+    if isinstance(model_results, dict):
+        tmp= model_results
+        model_results= list()
+        model_results.append(tmp)
+    #print('model_results:', model_results)
+    # add the indices back into the results
+    best_score=0
+    best_result=None
+    longest_result=None
+    for i, result in enumerate(model_results):#each document (data module) has its own result list
+        dmc_value= inputs[i]["DMC"]
+        #print('dmc_value:', dmc_value)
+        if isinstance(result, dict):
+            tmp= result
+            result= list()
+            result.append(tmp)
+        for r in result:#the result list for a single data module
+            #print('r:', r)
+            r["DMC"] = dmc_value
+        #print(model_results)
+        best_doc_result, longest_doc_result= get_best_and_longest_result(result, threshold, mode)
+        if best_doc_result["score"]>best_score:
+            best_score= best_doc_result["score"]
+            best_result= best_doc_result
+            longest_result= longest_doc_result
     #print('longest_result', longest_result)
     if best_result['score']>=threshold:
         longest_answer= longest_result['answer']
         answer_cleaned= re.sub(r"[\W\d_]+$", '', longest_answer).strip()
         if verbose==True:
             prob_value= round(model_result['score'], 2)
-            print(f'
+            print(f'Answer (score= {prob_value}): {answer_cleaned}')
         longest_result['answer']= answer_cleaned
         found_answer=True
     if found_answer==False and verbose==True:
-        print('
+        print('Answer not found!')
     model_result= best_result
     model_result['answer']= longest_result['answer']
     return model_result
 
-def answer_question(question, mode, model=1):
-    global qa_model_initialized, qa_model_num, tokenizer_qa
+def answer_question(question, mode, model=1, language="ru"):
+    global qa_model_initialized, qa_model_num, tokenizer_qa, PUBLICATION_LANGUAGE
     print('answer_question!')
-
+    old_publication_language= PUBLICATION_LANGUAGE
+    PUBLICATION_LANGUAGE= language
+    print('PUBLICATION_LANGUAGE:', PUBLICATION_LANGUAGE)
+    if qa_model_initialized==False or model!= qa_model_num or old_publication_language!= language:
         initialize_qa_model(model)
     print(f'question: {question}')
     print(f'mode: {mode}')

@@ -620,21 +672,21 @@ def answer_question(question, mode, model=1):
     if len(filtered_index)<1:
         filtered_index= search_query_any(question, qa_df, tokenizer_qa)
         threshold= THRESHOLD
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    #print('filtered_index:', filtered_index)
+
+    inputs = [{"question": question, "context": indx["text"], "DMC": indx["DMC"]} for indx in filtered_index]
+    #print('qa model inputs', inputs)
+    top_k=1
+    if mode!="strict":
+        top_k=len(filtered_index)
+    result= find_answer(inputs, threshold=threshold, max_answer_len=1000, top_k=top_k, verbose=False, mode=mode)
+
+    if result!= None:
+        best_answer= result['answer']
+        best_score= result['score']
+        best_DMC= result['DMC']
+        regex = re.compile(r'\([^)]*\)')
+        best_DMC= re.sub(regex, '', best_DMC)
+        result= [{'score': best_score, 'answer': best_answer, 'DMC': best_DMC}]
     return result
 
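The "batch doc optimization" in find_answer replaces per-document pipeline calls with a single batched call. A self-contained sketch of that pattern, using the same model name as the diff; the contexts and DMC values here are dummy placeholders, and the scores depend on the model:

    from transformers import pipeline

    qa_model = pipeline("question-answering", "timpal0l/mdeberta-v3-base-squad2")

    question = "How many brake pads on the bicycle?"
    inputs = [  # one entry per candidate data module, as built in answer_question
        {"question": question, "context": "Each rim brake carries two brake pads.", "DMC": "DMC-PLACEHOLDER-001"},
        {"question": question, "context": "The bicycle has a front and a rear brake.", "DMC": "DMC-PLACEHOLDER-002"},
    ]

    # One pipeline call over the whole batch instead of a Python loop per document.
    results = qa_model([{"question": q["question"], "context": q["context"]} for q in inputs],
                       batch_size=8, max_answer_len=1000, top_k=1)
    if isinstance(results, dict):  # a single input with top_k=1 comes back as a bare dict
        results = [results]

    # Re-attach the data module code to each result, as find_answer does.
    for inp, res in zip(inputs, results):
        res["DMC"] = inp["DMC"]
        print(res["DMC"], round(res["score"], 2), res["answer"])

With top_k greater than 1 each element of results becomes a list of candidate answers rather than a dict, which is why the diff normalizes dict results into lists before passing them to get_best_and_longest_result.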