In [1]:
__import__('pysqlite3')
import sys
import os
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
os.environ['ALLOW_RESET'] = 'True'

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
import numpy as np
from tqdm import tqdm

import chromadb

### Подготавливаем базу данных

In [12]:
client = chromadb.PersistentClient(path='db')
client.reset()

collection = client.create_collection(
    name="administrative_codex",
    metadata={"hnsw:space": "cosine"}
)

### Открываем и предобрабатываем КоАП

In [3]:
with open('docs/КоАП РФ.txt', encoding='utf-8') as r:
    raw_text = r.read().split('\n\n')

### Делим документ по частям статей, исключаем лишнее

In [9]:
paragraphs = []
index = 0

while index != len(raw_text):
    if raw_text[index].startswith('Статья'):
        article = ' '.join(raw_text[index].strip().split()[:2])
        article_points = raw_text[index + 1].split('\n')

        cur_point = ''
        for i in range(len(article_points)):
            cur_point_part = article_points[i].strip()
            
            if 'КонсультантПлюс' in article_points[i] + article_points[i - 1]:
                continue
            elif cur_point_part.split()[0].strip().replace('.', '').isnumeric() or cur_point_part.startswith('Примечание. '):
                if cur_point:
                    if cur_point.startswith('Примечание. '):
                        paragraphs.append([cur_point, article, 'Примечание.'])
                    elif cur_point[0].isnumeric():
                        paragraphs.append([' '.join(cur_point.split()[1:]), article, f'Часть {cur_point.split()[0]}'])
                    else:
                        paragraphs.append([cur_point, article, ''])
                        
                cur_point = cur_point_part
            elif cur_point_part[0] != '(' and cur_point_part[-1] != ')' and 'утратил силу' not in cur_point_part[:20].lower():
                cur_point += ' ' + cur_point_part
            
        index += 2
    else:
        index += 1

### Получаем эмбеддинги из извлеченных фрагментов и сохраняем их в базу данных

In [5]:
checkpoint = 'sentence-transformers/LaBSE'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint, device_map='cuda:0')

In [6]:
def encode(docs):
    if type(docs) == str:
        docs = [docs]

    encoded_input = tokenizer(
        docs,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    )
    
    with torch.no_grad():
        model_output = model(**encoded_input.to('cuda'))
        
    embeddings = model_output.pooler_output
    embeddings = torch.nn.functional.normalize(embeddings)
    return embeddings.detach().cpu().tolist()

In [10]:
BATCH_SIZE = 128
loader = DataLoader(paragraphs, batch_size=BATCH_SIZE)

In [13]:
for i, docs in enumerate(tqdm(loader)):
    embeddings = encode(docs[0])
    collection.add(
        documents=docs[0],
        metadatas=[{'doc': 'КоАП РФ', 'article': a, 'point': p} for a, p in zip(docs[1], docs[2])],
        embeddings=embeddings,
        ids=[f'id{i * BATCH_SIZE + j}' for j in range(len(docs[0]))],
    )

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:11<00:00,  1.48it/s]
