import streamlit as st
from joblib import load
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import time
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tok = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
model_checkpoint = 'cointegrated/rubert-tiny-toxicity'
toxicity_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
clf = load('my_model_filename.pkl')
vectorizer = load('tfidf_vectorizer.pkl')
scaler = load('scaler.joblib')
tukinazor = load('tokenizer.pkl')
rnn_model = load_model('path_to_my_model.h5')
bert_model = BertForSequenceClassification.from_pretrained('my_bert_model')
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model = bert_model.to(device)
model_finetuned = GPT2LMHeadModel.from_pretrained('GPT_2')
model_finetuned.eval()

labels = ["не токсичный", "оскорбляющий", "непристойный", "угрожающий", "опасный"]
def text2toxicity(text, aggregate=True):
    """ Calculate toxicity of a text (if aggregate=True) or a vector of toxicity aspects (if aggregate=False)"""
    with torch.no_grad():
        inputs = toxicity_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(toxicity_model.device)
        proba = torch.sigmoid(toxicity_model(**inputs).logits).cpu().numpy()
    
    if isinstance(text, str):
        proba = proba[0]
    
    if aggregate:
        return 1 - proba.T[0] * (1 - proba.T[-1])
    else:
        result = {}
        for label, prob in zip(labels, proba):
            result[label] = prob
        return result

        
def predict_text(text):
    sequences = tukinazor.texts_to_sequences([text])
    padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=200, padding='post', truncating='post')
    predictions = rnn_model.predict(padded_sequences)
    predicted_class = tf.argmax(predictions, axis=-1).numpy()[0]
    return predicted_class


def generate_text(model, prompt, max_length=150, temperature=1.0):
    input_ids = tok.encode(prompt, return_tensors='pt')
    output = sber.generate(
        input_ids=input_ids,
        max_length=max_length + len(input_ids[0]),
        temperature=temperature,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )
    generated_text = tok.decode(output[0], skip_special_tokens=True)
    return generated_text


def page_reviews_classification():
    st.title("Модель классификации отзывов")
    st.image("ramsey.jpg", caption="finally some good food", use_column_width=True)
    
    user_input = st.text_area("Введите текст отзыва:")

    if st.button("Классифицировать"):
        start_time = time.time()
        user_input_vec = vectorizer.transform([user_input])
        sentence_vector_scaled = scaler.transform(user_input_vec)
        prediction = clf.predict(
            sentence_vector_scaled)
        elapsed_time = time.time() - start_time
        st.write(f"Прогнозируемый класс: {prediction[0]}")
        st.write(f"Время вычисления: {elapsed_time:.2f} сек.")

    user_input_rnn = st.text_area("Введите текст отзыва для RNN модели:")

    if st.button("Классифицировать с RNN"):
        start_time = time.time()
        prediction_rnn = predict_text(user_input_rnn)
        elapsed_time = time.time() - start_time
        st.write(f"Прогнозируемый класс с RNN: {prediction_rnn}")
        st.write(f"Время вычисления: {elapsed_time:.2f} сек.")

    user_input_bert = st.text_area("Введите текст отзыва для BERT:")

    if st.button("Классифицировать (BERT)"):
        start_time = time.time()
        encoding = tokenizer.encode_plus(
            user_input_bert,
            add_special_tokens=True,
            max_length=200,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=1)
            elapsed_time = time.time() - start_time
            st.write(f"Прогнозируемый класс (BERT): {predictions.item() + 1}")
            st.write(f"Время вычисления: {elapsed_time:.2f} сек.")

def page_toxicity_analysis():
    user_input_toxicity = st.text_area("Введите текст для оценки токсичности:")
    if st.button("Оценить токсичность"):
        start_time = time.time()
        probs = text2toxicity(user_input_toxicity, aggregate=False)
        elapsed_time = time.time() - start_time
        for label, prob in probs.items():
            st.write(f"Вероятность того что комментарий {label}: {prob:.4f}")


def page_gpt_generation():
    st.title("Генерация текста с помощью GPT-модели")
    
    user_prompt = st.text_area("Введите ваш текст:")
    sequence_length = st.slider("Длина последовательности:", min_value=10, max_value=1000, value=150, step=10)
    num_generations = st.slider("Число генераций:", min_value=1, max_value=10, value=1)
    temperature = st.slider("Температура:", min_value=0.1, max_value=3.0, value=1.0, step=0.1)

    if st.button("Генерировать"):
        for _ in range(num_generations):
            generated_text = generate_text(model_finetuned, user_prompt, sequence_length, temperature)
            st.text(generated_text)

def main():
    page_selection = st.sidebar.selectbox("Выберите страницу:", ["Классификация отзывов", "Анализ токсичности"])
    
    if page_selection == "Классификация отзывов":
        page_reviews_classification()
    elif page_selection == "Анализ токсичности":
        page_toxicity_analysis()
    elif page_selection == "Генерация текста Noize MC":
        page_gpt_generation()

if __name__ == "__main__":
    main()