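"""Builds a pickled DataFrame of sentence embeddings from a CSV of business operations.

Every non-empty line of the 'Хозяйственные операции' column is embedded twice with
a multilingual E5 model: once as plain text and once prefixed with the source
document and table names. The result is saved next to the input CSV as a .pkl file.
"""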
import argparse
from pathlib import Path
from typing import Dict, List

import pandas as pd
from tqdm import tqdm

from common.constants import (
    COLUMN_DOC_NAME,
    COLUMN_EMBEDDING,
    COLUMN_EMBEDDING_FULL,
    COLUMN_LABELS_STR,
    COLUMN_NAMES,
    COLUMN_TABLE_NAME,
    COLUMN_TEXT,
    COLUMN_TYPE_DOC_MAP,
    DEVICE,
    DO_NORMALIZATION,
)
from components.embedding_extraction import EmbeddingExtractor


def get_label(unique_names: List[str]) -> Dict[str, int]:
    """Generate a label for each unique file name.

    Args:
        unique_names: List of unique file names.

    Returns:
        Dictionary mapping each file name to its integer label.
    """
    return {name: ind for ind, name in enumerate(unique_names)}


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file',
                        type=Path,
                        default='../../data/csv/карта_проводок_clear.csv',
                        help='Path to the input CSV file.')
    args = parser.parse_args()
df = pd.read_csv(args.input_file)
df = df.fillna('')
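    # Map every unique table name to an integer class label.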
unique_table_name = df['Название таблицы'].unique()
class_name = get_label(unique_table_name)
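    # multilingual-e5 models expect a 'query: '/'passage: ' prefix on every
    # input string, hence the 'passage: ' prefixes in the loop below.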
global_model_path = 'intfloat/multilingual-e5-base'
model = EmbeddingExtractor(global_model_path, DEVICE)
new_df = pd.DataFrame(columns=[COLUMN_DOC_NAME,
COLUMN_TABLE_NAME,
COLUMN_TEXT,
COLUMN_NAMES,
COLUMN_LABELS_STR,
COLUMN_TYPE_DOC_MAP,
COLUMN_EMBEDDING,
COLUMN_EMBEDDING_FULL
])
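    # Each CSV row may list several operations separated by newlines;
    # every non-empty line becomes its own row in new_df.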
    for _, row in tqdm(df.iterrows(), total=len(df)):
cleaned_text = row['Хозяйственные операции'].split('\n')
doc_name = row['Название файла']
table_name = row['Название таблицы']
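        # 'Columns' should already be a string after fillna, but guard
        # against non-string values just in case.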
try:
column_names = [i.replace('\t', '').strip() for i in row['Columns'].split('\n')]
except AttributeError:
column_names = []
type_docs = row['TypeDocs']
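        # Default document type when the cell is empty.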
if not type_docs:
type_docs = '1C'
for text in cleaned_text:
if text != '':
                # Embedding of the operation text alone.
                query_tokens = model.query_tokenization('passage: ' + text)
                query_embeds = model.query_embed_extraction(query_tokens.to(DEVICE), DO_NORMALIZATION)[0]
                # Embedding of the same text prefixed with the document and table names.
                query_tokens_full = model.query_tokenization(f'passage: {doc_name} {table_name} {text}')
                query_embeds_full = model.query_embed_extraction(query_tokens_full.to(DEVICE), DO_NORMALIZATION)[0]
new_df.loc[len(new_df.index)] = [doc_name,
table_name,
text,
column_names,
class_name[table_name],
type_docs,
query_embeds,
query_embeds_full,
]
    # Save the enriched DataFrame next to the input CSV.
    new_df.to_pickle(args.input_file.with_suffix('.pkl'))