import re
import numpy as np
import pandas as pd
import pymorphy2
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

morph = pymorphy2.MorphAnalyzer()
tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")

def cosine_similarity(embedding1, embedding2):
    embedding1 = np.array(embedding1)
    embedding2 = np.array(embedding2)
    
    dot_product = np.dot(embedding1, embedding2)
    norm_a = np.linalg.norm(embedding1)
    norm_b = np.linalg.norm(embedding2)
    
    return dot_product / (norm_a * norm_b)

def pool(hidden_state, mask, pooling_method="cls"):
    if pooling_method == "mean":
        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
        d = mask.sum(axis=1, keepdim=True).float()
        return s / d
    elif pooling_method == "cls":
        return hidden_state[:, 0]

def text_to_embedding(text, tokenizer, model):
    # Токенизация текста
    tokenized_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokenized_inputs)
    
    embeddings = pool(
        outputs.last_hidden_state, 
        tokenized_inputs["attention_mask"],
        pooling_method="cls" # or try "mean"
    )

    embeddings = F.normalize(embeddings, p=2, dim=1).numpy()
    
    return embeddings

def preprocess_text(text):
    lemmas = []  # Для хранения лемм
    for token in text.split():
        parsed = morph.parse(token)[0]  # Морфологический разбор токена
        
        # Лемматизация
        if parsed.normal_form and parsed.normal_form.strip():
            lemmas.append(parsed.normal_form)  # Добавляем лемму
            
    return " ".join(lemmas) if lemmas else ""

def product_extraction(text):
    lemmas = preprocess_text(text)
    if 'кредитный бизнес-' in lemmas:
        return 'кредитная бизнес-карта'
    elif 'выпустить бизнес-карта' in lemmas:
        return 'бизнес-карта'
    elif ('расчётный счёт' in lemmas) or ('открыть счёт' in lemmas):
        return 'расчетный счет'
    elif 'бизнес-карта' in lemmas:
        return 'бизнес-карта'
    elif 'бизнес-кешбэк' in lemmas:
         return 'cashback'
    elif 'перевод' in lemmas:
         return 'переводы'
    elif 'кредит' in lemmas:
        return 'кредит'
    elif 'эквайринг' in lemmas:
        return 'эквайринг'
    elif 'зарплатный проект' in lemmas:
        return 'зарплатный проект'
    elif 'вклад' in lemmas:
        return 'вклад'
    elif 'депозит' in lemmas:
        return 'депозит'
    return 'прочее'

def best_text_choice(texts, core_df, tokenizer, model, coef=1):
    '''
    Функция для выбора лучшего текста, и оценки его успешности
    '''
    scoring_list = []
    embeddings_df = core_df.copy()
    texts_df = pd.DataFrame(texts, columns=['texts'])
    texts_df['texts_lower'] = texts_df['texts'].apply(lambda x: x.lower())
    texts_df['texts_'] = 'search_query: ' + texts_df['texts_lower']
    texts_df['embeddings'] = texts_df['texts_'].apply(lambda x: text_to_embedding(x, tokenizer, model)[0])
    texts_df['product'] = texts_df['texts'].apply(product_extraction)
    best_text = ''
    score = 0
    for index, row in texts_df.iterrows():
        product = row['product']
        embeddings_df['similarity'] = embeddings_df['embedding'].apply(lambda x: cosine_similarity(x, row['embeddings']))
        embeddings_df['score'] = embeddings_df['value'] * embeddings_df['similarity']
        score_ = np.mean([(embeddings_df
                          .sort_values(by=['product_type', 'score'], ascending=[True, False])
                          .query('product_type == @product')['score'][:3].mean() * coef),
                         embeddings_df
                         .sort_values(by='similarity', ascending=False)
                         .query('product_type != @product')['score'][:3].mean()])
        scoring_list.append([row['texts'], 100*score_  / embeddings_df.query('product_type == @product')['value'].max()])
        if score_ > score:
            score = score_
            best_text = row['texts']
        
    # ratio = score / embeddings_df.query('product_type == @product')['value'].max()
    scoring_df = pd.DataFrame(scoring_list, columns=['text', 'score'])
    scoring_df = scoring_df.sort_values(by='score', ascending=False).reset_index(drop=True)
    scoring_df.index += 1
    return scoring_df.reset_index().rename(columns={'index': 'Место'})[['Место', 'text']]