import html
import json
import os
import pprint as pp
from collections import OrderedDict, defaultdict

import diff_viewer
import pandas as pd
import streamlit as st
from datasets import load_dataset, get_dataset_config_names

# Configuration read from Streamlit secrets.
CHECK_DATASET_DIR_PATH_BEFORE_CLEAN_SELECT = st.secrets["CHECK_DATASET_DIR_PATH_BEFORE_CLEAN_SELECT"]
LOGS_DATASET_DIR_PATH_BEFORE_CLEAN_SELECT = st.secrets["LOGS_DATASET_DIR_PATH_BEFORE_CLEAN_SELECT"]
HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
# Markers that start a new operation section in the cleaning logs; the token that
# follows the marker on the same line is used as the operation name.
OPERATION_TYPES = [
    "Applied filter",
    "Applied deduplication function",
    "Applied map function",
]
# Maximum number of examples kept per check dataset (a warning is shown when a check
# dataset has exactly this length). It is compared with a dataset length, so coerce
# it to int: depending on how the secret is defined it may arrive as a string, which
# would make that comparison silently always False.
MAX_LEN_DS_CHECKS = int(st.secrets["MAX_LEN_DS_CHECKS"])


def get_ds(config):
    """Load the examples of one check config as a single, integer-indexable dataset.

    Args:
        config: Name of the check config, of the form
            "<cleaning version>_dsname_<dataset>_operation_<check>".

    Returns:
        The examples of the config. Callers take ``len()`` of it as the number of
        examples and index it as ``ds[i]["text"]``.
    """
    ds = load_dataset(CHECK_DATASET_DIR_PATH_BEFORE_CLEAN_SELECT, config, use_auth_token=HF_API_TOKEN)
    # Without ``split=``, ``load_dataset`` returns a DatasetDict (a dict subclass keyed
    # by split name). Unwrap it so integer indexing and len() refer to examples rather
    # than to splits. Prefer "train", the split the logs dataset is read from.
    if isinstance(ds, dict) and ds:
        ds = ds["train"] if "train" in ds else next(iter(ds.values()))
    return ds


def next_idx(idx: int):
    """Return the index following ``idx``, wrapping to 0 after the last example."""
    dataset_size = len(st.session_state["ds"])
    return (idx + 1) % dataset_size


def previous_idx(idx: int):
    """Return the index preceding ``idx``, wrapping to the last example from 0."""
    dataset_size = len(st.session_state["ds"])
    # Python's % always yields a non-negative result here, so -1 wraps to the end.
    return (idx - 1) % dataset_size


def on_click_next():
    """Advance both displayed example indices by one, wrapping around the dataset."""
    for key in ("idx_1", "idx_2"):
        st.session_state[key] = next_idx(st.session_state[key])


def on_click_previous():
    """Move both displayed example indices back by one, wrapping around the dataset."""
    state = st.session_state
    state["idx_1"], state["idx_2"] = (
        previous_idx(state["idx_1"]),
        previous_idx(state["idx_2"]),
    )


def on_ds_change(config):
    """Load the check config ``config`` and reset the browsing state to its start.

    Stores the loaded examples, the two displayed indices (0 and, when there is more
    than one example, 1), the active config name and the number of examples in the
    Streamlit session state.
    """
    loaded = get_ds(config)
    n_docs = len(loaded)
    state = st.session_state
    state["ds"] = loaded
    state["idx_1"] = 0
    state["idx_2"] = 0 if n_docs <= 1 else 1
    state["ds_check_config"] = config
    state["ds_max_docs"] = n_docs


def _parse_log_metrics(raw_log, metric_names, operation_types):
    """Group the metric values found in ``raw_log`` by the operation they follow.

    A line containing one of ``operation_types`` opens a new operation. The operation
    is named by the first space-separated token after the marker, e.g.
    ``"Applied filter: f1"`` -> ``"f1"``. Later lines containing one of
    ``metric_names`` are attributed to the most recently opened operation, and their
    value is extracted the same way.

    Returns:
        A dict mapping each operation name to
        ``{"Order": <position>, <metric name>: <raw string value>, ...}``,
        keyed in order of first appearance.

    Raises:
        ValueError: if a metric line appears before any operation line, or the same
            metric is reported twice for one operation.
        IndexError: if a marker is not followed by a space-separated value.
    """
    metric_dict = defaultdict(dict)
    # On each line, metric names are matched before operation markers.
    markers = list(metric_names) + list(operation_types)
    operation_name = None
    order = 0
    for line in raw_log.split("\n"):
        for marker in markers:
            if marker not in line:
                continue

            value = line.split(marker)[1].split(" ")[1]

            if marker in operation_types:
                operation_name = value
                metric_dict[operation_name]["Order"] = order
                order += 1
                continue

            if operation_name is None:
                raise ValueError(
                    f"metric {marker!r} found before any operation line: {line!r}"
                )
            if marker in metric_dict[operation_name]:
                raise ValueError(
                    f"operation_name: {operation_name}\n\nvalue: {value}\n\n"
                    f"metric_dict: {pp.pformat(metric_dict)} \n\n"
                    f"metric_name: {marker} \n\nline: {line}"
                )
            metric_dict[operation_name][marker] = value
    return metric_dict


def get_log_stats_df(raw_log, operation_types=None):
    """Parse a cleaning log into a table with one row of statistics per operation.

    Args:
        raw_log: Full text of the cleaning log.
        operation_types: Markers that start an operation section in the log.
            Defaults to the module-level ``OPERATION_TYPES``.

    Returns:
        A DataFrame with the columns "Order", "Name", "Initial number of samples",
        "Final number of samples", "Initial size (GB)", "Final size (GB)",
        "% samples removed" and "Size (GB) % removed". Sizes are converted from the
        bytes reported in the log to GB (10**9 bytes).

    Raises:
        ValueError: if the log is malformed (metric before any operation, or a
            duplicated metric).
        KeyError: if an operation is missing one of the expected metrics.
    """
    if operation_types is None:
        operation_types = OPERATION_TYPES
    bytes_per_gb = 1e9

    columns = [
        "Order",
        "Name",
        "Initial number of samples",
        "Final number of samples",
        "Initial size in bytes",
        "Final size in bytes",
    ]
    # "Order" and "Name" are not read from metric lines; they come from operation lines.
    metric_names = [column for column in columns if column not in ("Order", "Name")]
    metric_dict = _parse_log_metrics(raw_log, metric_names, operation_types)

    data = OrderedDict((column, []) for column in columns)
    for name, data_dict in metric_dict.items():
        for column in columns:
            data[column].append(name if column == "Name" else data_dict[column])
    df = pd.DataFrame(data)

    # The log reports sizes in bytes; convert so the values match the "(GB)" labels.
    df = df.rename(
        columns={
            "Initial size in bytes": "Initial size (GB)",
            "Final size in bytes": "Final size (GB)",
        }
    )
    for column in ("Initial size (GB)", "Final size (GB)"):
        df[column] = df[column].astype(float) / bytes_per_gb

    initial_samples = df["Initial number of samples"].astype(float)
    final_samples = df["Final number of samples"].astype(float)
    df["% samples removed"] = (initial_samples - final_samples) / initial_samples * 100
    df["Size (GB) % removed"] = (
        (df["Initial size (GB)"] - df["Final size (GB)"]) / df["Initial size (GB)"] * 100
    )
    return df


def get_logs_stats(raw_log):
    """Display the per-operation statistics table parsed from ``raw_log``.

    If parsing or display fails for any reason, show the error instead. After it,
    show the "INFO - __main__" lines of the log, minus the example dumps, so the user
    still gets a readable summary.
    """
    try:
        stats_df = get_log_stats_df(raw_log)
        st.dataframe(stats_df)
    except Exception as err:  # deliberate best-effort fallback for the UI
        st.write(err)
        st.write("Subset of the logs:")
        kept_lines = []
        for log_line in raw_log.split("\n"):
            if "INFO - __main__" not in log_line:
                continue
            if "Examples of" in log_line or "Examples n°" in log_line:
                continue
            kept_lines.append(log_line)
        st.write(kept_lines)


def meta_component(idx_key: str = "idx_1"):
    """Show the ``meta`` field of the displayed example in a collapsible expander.

    Args:
        idx_key: Session-state key ("idx_1" or "idx_2") holding the index of the
            example whose meta field should be shown. Nothing is rendered when that
            example has no ``meta`` field.
    """
    # Use the same example for the presence check and the display. The display
    # previously always read "idx_1", so the second column showed the wrong meta.
    example = st.session_state["ds"][st.session_state[idx_key]]
    if "meta" not in example:
        return

    with st.expander("See meta field of the example"):
        st.write(example["meta"])


def filter_page():
    """Show two consecutive filtered-out examples side by side.

    The user picks the index of the first example. The second column shows the
    following one, wrapping to 0 after the last example. Both indices are stored in
    the session state as "idx_1" and "idx_2".
    """
    index_example = st.number_input(
        "Index of the chosen example",
        min_value=0,
        max_value=st.session_state["ds_max_docs"] - 1,
        value=0,
        step=1,
    )
    st.session_state["idx_1"] = index_example
    st.session_state["idx_2"] = next_idx(index_example)
    idx_1 = st.session_state["idx_1"]
    idx_2 = st.session_state["idx_2"]
    text_1 = st.session_state["ds"][idx_1]["text"]
    text_2 = st.session_state["ds"][idx_2]["text"]

    st.markdown(
        "<h1 style='text-align: center'>Some examples of filtered out texts</h1>",
        unsafe_allow_html=True,
    )

    col_1, col_2 = st.columns(2)
    for column, idx_key, idx, text in (
        (col_1, "idx_1", idx_1, text_1),
        (col_2, "idx_2", idx_2, text_2),
    ):
        with column:
            st.subheader(f"Example n°{idx}")
            meta_component(idx_key=idx_key)
            # The text is arbitrary dataset content rendered with unsafe_allow_html:
            # escape it so it shows verbatim instead of being interpreted as HTML,
            # then keep its line breaks visible.
            text_show = html.escape(text).replace("\n", "<br>")
            st.markdown(f"<div>{text_show}</div>", unsafe_allow_html=True)


def dedup_or_cleaning_page():
    """Show how the selected example's text was changed by the operation.

    Renders the following for the chosen example:
    - a diff between its ``old_text`` and ``text`` fields;
    - its meta field, if any;
    - an expander with both full texts side by side.
    """
    index_example = st.number_input(
        "Index of the chosen example",
        min_value=0,
        max_value=st.session_state["ds_max_docs"] - 1,
        value=0,
        step=1,
    )
    st.session_state["idx_1"] = index_example
    st.session_state["idx_2"] = next_idx(index_example)

    example = st.session_state["ds"][st.session_state["idx_1"]]
    text = example["text"]
    old_text = example["old_text"]
    st.markdown(
        "<h2 style='text-align: center'>Changes applied</h2>", unsafe_allow_html=True
    )
    col_text_1, col_text_2 = st.columns(2)
    with col_text_1:
        st.subheader("Old text")
    with col_text_2:
        st.subheader("New text")
    diff_viewer.diff_viewer(old_text=old_text, new_text=text, lang="none")
    meta_component(idx_key="idx_1")

    with st.expander("See full old and new texts of the example"):
        # Texts are arbitrary dataset content rendered with unsafe_allow_html:
        # escape them so they show verbatim instead of being interpreted as HTML,
        # then keep their line breaks visible.
        old_text_show = html.escape(old_text).replace("\n", "<br>")
        text_show = html.escape(text).replace("\n", "<br>")

        col_1, col_2 = st.columns(2)
        with col_1:
            st.subheader("Old text")
            st.markdown(f"<div>{old_text_show}</div>", unsafe_allow_html=True)
        with col_2:
            st.subheader("New text")
            st.markdown(f"<div>{text_show}</div>", unsafe_allow_html=True)


# Streamlit page
st.set_page_config(page_title="Dataset explorer", page_icon=":hugging_face:", layout="wide")
st.write(
    "The purpose of this application is to sequentially view the changes made to a dataset."
)

col_option_clean, col_option_ds = st.columns(2)

# Names of the available check configs, each of the form
# "<cleaning version>_dsname_<dataset>_operation_<check>".
with open("dataset_configs.json", "r", encoding="utf-8") as f:
    CHECK_CONFIGS = json.load(f)

# Index the configs as {cleaning version: {dataset: {checks}}}. Each selector below
# then only offers values that exist for the choices made above it. Offering the
# union of all values would allow combinations that have no matching config and
# therefore cannot be loaded.
AVAILABLE_CHECKS = defaultdict(lambda: defaultdict(set))
for check_config_name in CHECK_CONFIGS:
    cleaning_version, dataset_and_check = check_config_name.split("_dsname_", 1)
    dataset_name, checks_name = dataset_and_check.split("_operation_", 1)
    AVAILABLE_CHECKS[cleaning_version][dataset_name].add(checks_name)

if not AVAILABLE_CHECKS:
    st.error("No check configuration found in dataset_configs.json.")
    st.stop()

# Reverse lexicographic order, so with names like "clean_v1", "clean_v2" the most
# recent version is the default selection.
CLEANING_VERSIONS = sorted(AVAILABLE_CHECKS, reverse=True)
option_clean = col_option_clean.selectbox(
    "Select the cleaning version", CLEANING_VERSIONS
)

dataset_names = sorted(AVAILABLE_CHECKS[option_clean])
option_ds = col_option_ds.selectbox("Select the dataset", dataset_names)

# Summary statistics of the operations applied to the selected dataset.
ds_log = load_dataset(
    LOGS_DATASET_DIR_PATH_BEFORE_CLEAN_SELECT,
    f"{option_clean}_dsname_{option_ds}",
    use_auth_token=HF_API_TOKEN,
)
log = ds_log["train"][0]["log"]
get_logs_stats(raw_log=log)

checks_names = sorted(AVAILABLE_CHECKS[option_clean][option_ds])
option_check = st.selectbox("Select the operation applied to inspect", checks_names)

ds_check_config = f"{option_clean}_dsname_{option_ds}_operation_{option_check}"

# Streamlit reruns this script on every interaction; only reload the examples when
# the selected check config actually changed.
if "ds" not in st.session_state or ds_check_config != st.session_state["ds_check_config"]:
    on_ds_change(ds_check_config)

if len(st.session_state["ds"]) == MAX_LEN_DS_CHECKS:
    st.warning(
        f"Note: only a subset of size {MAX_LEN_DS_CHECKS} of the modified / filtered examples can be shown in this application"
    )
with st.expander("See details of the available checks"):
    st.write(st.session_state["ds"])

# Filter checks show removed examples; other operations show before/after diffs.
if "_filter_" in option_check:
    filter_page()
else:
    dedup_or_cleaning_page()