## LIBRARIES ###
## Data
import numpy as np
import pandas as pd
import torch
import json
from tqdm import tqdm
from math import floor
from datasets import load_dataset
from collections import defaultdict
from transformers import AutoTokenizer
# Display floats with thousands separators and two decimals. No currency
# symbol: the floats shown here are losses and frequencies, not money.
pd.options.display.float_format = '{:,.2f}'.format

# Analysis
# from gensim.models.doc2vec import Doc2Vec
# from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import nltk
from nltk.cluster import KMeansClusterer
import scipy.spatial.distance as sdist
from scipy.spatial import distance_matrix
# nltk.download('punkt') #make sure that punkt is downloaded

# App & Visualization
import streamlit as st
import altair as alt
import plotly.graph_objects as go
from streamlit_vega_lite import altair_component



# utils
from random import sample
from error_analysis import utils as ut


def down_samp(embedding, max_points=1000):
    """Down-sample a data frame so it stays within Altair's row budget.

    Each 'slice' group gets a quota proportional to its largest
    (slice, label) sub-group, scaled so that all quotas sum to at most
    ``max_points``. Quotas are capped at the group's actual size, so small
    frames are returned whole instead of raising from ``DataFrame.sample``.

    Args:
        embedding: DataFrame with at least 'slice', 'label' and 'content'
            columns.
        max_points: Upper bound on the number of rows returned. Defaults to
            1000, the previously hard-coded budget.

    Returns:
        A new DataFrame with a fresh RangeIndex, ordered by slice.
    """
    if embedding.empty:
        return embedding.reset_index(drop=True)

    # Number of non-null 'content' rows per (slice, label) pair.
    total_size = embedding.groupby(['slice', 'label'], as_index=False).count()

    # Rows reserved for user-supplied sentences; that feature is currently
    # disabled, so nothing is reserved.
    user_data = 0

    # Size of the largest label sub-group within each slice.
    max_sample = total_size.groupby('slice').max()['content']
    total = float(max_sample.astype(float).sum())

    # Scale factor that fits the budget while keeping the proportional
    # representation of groups.
    scale = (max_points - user_data) / total if total else 0.0
    quota = max_sample.apply(lambda x: floor(x * scale)).astype(int).to_dict()
    quota['Your Sentences'] = user_data

    # Sample each slice independently. Never request more rows than a group
    # has: sampling without replacement would raise ValueError.
    parts = [
        group.sample(n=min(len(group), quota.get(name, len(group))))
        for name, group in embedding.groupby('slice')
    ]
    return pd.concat(parts).reset_index(drop=True)


def data_comparison(df):
    """Build the error-slice scatter plot plus a clickable cluster/label legend.

    Reads columns 'x', 'y', 'cluster', 'slice', 'label', 'content' and 'pred'.
    High-loss points are coloured by cluster and low-loss points are drawn in
    light gray. Clicking a legend mark raises the opacity of the matching
    (cluster, label) points in the scatter plot and dims the rest.

    Returns:
        The horizontally concatenated Altair chart (scatter | legend).
    """
    # Selection ``fields`` take plain field names. The ':N'/':O' encoding
    # shorthand is not parsed there, so 'cluster:N' would name a field that
    # does not exist in the data.
    selection = alt.selection_multi(fields=['cluster', 'label'])
    cluster_domain = df.cluster.unique().tolist()
    color = alt.condition(
        alt.datum.slice == 'high-loss',
        alt.Color('cluster:N', scale=alt.Scale(domain=cluster_domain)),
        alt.value("lightgray"),
    )
    opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))

    # Main scatter plot of the 2-D coordinates.
    scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
        x=alt.X('x:Q', axis=None),
        y=alt.Y('y:Q', axis=None),
        color=color,
        shape=alt.Shape('label:O', scale=alt.Scale(range=['circle', 'diamond'])),
        tooltip=['cluster:N', 'slice:N', 'content:N', 'label:O', 'pred:O'],
        opacity=opacity
    ).properties(
        width=1000,
        height=800
    ).interactive()

    # Legend: one mark per (label, cluster); this chart owns the selection.
    legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
        x=alt.X("label:O"),
        y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), title=""),
        shape=alt.Shape('label:O', scale=alt.Scale(
            range=['circle', 'diamond']), legend=None),
        color=color,
    ).add_selection(
        selection
    )
    layered = scatter | legend
    layered = layered.configure_axis(
        grid=False
    ).configure_view(
        strokeOpacity=0
    )
    return layered

def quant_panel(embedding_df):
    """Render the quantitative panel: a chart explainer and the error-slice chart.

    ``embedding_df`` is down-sampled with ``down_samp`` and then plotted with
    ``data_comparison``.
    """
    st.warning("**Error slice visualization**")
    with st.expander("How to read this chart:"):
        st.markdown("* Each **point** is an input example.")
        st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
        st.markdown("* The **shape** of each point reflects the label category --  positive (diamond) or negative sentiment (circle).")
    st.altair_chart(data_comparison(down_samp(embedding_df)), use_container_width=True)


def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
    """Rank tokens by how over-represented they are among high-loss examples.

    An example belongs to the error slice when its 'loss' is strictly above
    the ``loss_quantile`` quantile of all losses. Each token id is counted at
    most once per example, and for each one we report:

    * Freq: fraction of all examples containing the token.
    * Freq error slice: fraction of error-slice examples containing it.
    * lrs: (smoothing + Freq error slice) / (smoothing + Freq).

    Args:
        data: DataFrame with 'content' (text) and 'loss' columns.
        tokenizer: Called as ``tokenizer(text, padding=True,
            return_tensors='pt')``. Must return a mapping whose
            'input_ids' is a torch tensor, and must provide ``decode``.
        loss_quantile: Quantile of the loss that defines the error slice.
        top_k: Maximum number of tokens returned.
        smoothing: Additive smoothing applied to both frequencies in lrs.

    Returns:
        DataFrame with string columns 'Token', 'Freq', 'Freq error slice' and
        'lrs', sorted by lrs in descending order.
    """
    columns = ['Token', 'Freq', 'Freq error slice', 'lrs']
    num_examples = len(data)
    if num_examples == 0:
        return pd.DataFrame([], columns=columns)

    # Distinct token ids of every example.
    unique_tokens = []
    for row in tqdm(data['content']):
        tokenized = tokenizer(row, padding=True, return_tensors='pt')
        unique_tokens.append(torch.unique(tokenized['input_ids']))

    losses = data['loss'].astype(float)
    high_loss = losses.quantile(loss_quantile)
    # Use positional numpy arrays so indexing by i is correct even when
    # `data` does not have a 0..n-1 index.
    in_error_slice = (losses > high_loss).to_numpy(dtype=float)
    error_count = in_error_slice.sum()
    # With no example above the quantile, keep zero weights instead of 0/0 = NaN.
    loss_weights = in_error_slice / error_count if error_count > 0 else in_error_slice
    weights_uniform = np.full(num_examples, 1 / num_examples)

    token_frequencies = defaultdict(float)
    token_frequencies_error = defaultdict(float)
    for i in tqdm(range(num_examples)):
        for token in unique_tokens[i]:
            token_id = token.item()
            token_frequencies[token_id] += weights_uniform[i]
            token_frequencies_error[token_id] += loss_weights[i]

    token_lrs = {
        k: (smoothing + token_frequencies_error[k]) / (smoothing + token_frequencies[k])
        for k in token_frequencies
    }
    tokens_sorted = [k for k, _ in sorted(token_lrs.items(), key=lambda x: x[1])[::-1]]

    top_tokens = [
        ['%10s' % (tokenizer.decode(token)),
         '%.4f' % (token_frequencies[token]),
         '%.4f' % (token_frequencies_error[token]),
         '%4.2f' % (token_lrs[token])]
        for token in tokens_sorted[:top_k]
    ]
    return pd.DataFrame(top_tokens, columns=columns)


@st.cache(ttl=600)
def get_data(spotlight, emb):
    """Assemble one row per example: content, label, prediction, loss, x, y.

    The number of rows is taken from ``spotlight.losses``. ``emb`` supplies
    the 'x' and 'y' coordinate columns.

    NOTE(review): this reads the module-level name ``dataset`` and slices it
    as ``dataset[:n]['content']`` / ``['label']``. In this file's
    ``__main__`` block, ``dataset`` is the sidebar selectbox *string*, so
    calling it from there would raise a TypeError. Confirm which dataset
    object is intended. The four stacked columns pass through ``np.vstack``
    and therefore share a single NumPy dtype.
    """
    predictions = spotlight.outputs.numpy()
    per_example_loss = spotlight.losses.numpy()
    coords = pd.DataFrame(emb, columns=['x', 'y'])
    n_rows = len(per_example_loss)
    head = dataset[:n_rows]
    stacked = np.vstack([head['content'], head['label'], predictions, per_example_loss])
    table = pd.DataFrame(stacked.T, columns=['content', 'label', 'pred', 'loss'])
    return pd.concat([table, coords], axis=1)

@st.cache(ttl=600)
def clustering(data,num_clusters):
    """Cluster the rows of ``data`` by their 'embedding' vectors.

    Uses nltk's KMeansClusterer with cosine distance (25 random restarts,
    empty clusters avoided).

    Args:
        data: DataFrame with an 'embedding' column of equal-length vectors.
        num_clusters: Number of clusters to form.

    Returns:
        A tuple ``(frame, assigned_clusters)``. ``frame`` is a copy of
        ``data`` with an int 'cluster' column and a 'centroid' column holding
        each row's cluster mean. ``assigned_clusters`` is the raw list of
        cluster ids returned by nltk.
    """
    # Work on a copy. Callers pass a filtered slice of a larger frame, and
    # assigning columns into it would mutate their data and trigger
    # SettingWithCopyWarning.
    data = data.copy()
    X = np.array(data['embedding'].tolist())
    kclusterer = KMeansClusterer(
        num_clusters, distance=nltk.cluster.util.cosine_distance,
        repeats=25, avoid_empty_clusters=True)
    assigned_clusters = kclusterer.cluster(X, assign_clusters=True)
    data['cluster'] = pd.Series(assigned_clusters, index=data.index).astype('int')
    # Fetch the centroids once rather than once per row.
    means = kclusterer.means()
    data['centroid'] = data['cluster'].apply(lambda c: means[c])
    return data, assigned_clusters

@st.cache(ttl=600)
def kmeans(df, num_clusters=3):
    """Attach cluster ids to every row of ``df``, clustering only the error slice.

    Rows whose 'slice' is 'high-loss' are clustered with ``clustering``.
    All other rows receive the sentinel cluster id ``num_clusters``. The
    returned frame keeps every column of ``df`` and adds 'cluster' (int)
    plus the 'centroid' column from ``clustering``, which is NaN for
    low-loss rows.
    """
    high_loss_rows = df.loc[df['slice'] == 'high-loss']
    clustered, _ = clustering(high_loss_rows, num_clusters)
    combined = df.merge(clustered, how='outer', left_index=True,
                        right_index=True, suffixes=('', '_y'))
    # Columns present on both sides come back with a '_y' suffix from the
    # right frame. They duplicate df's columns, so discard them.
    duplicate_cols = [c for c in combined.columns if str(c).endswith('_y')]
    combined = combined.drop(columns=duplicate_cols)
    combined['cluster'] = combined['cluster'].fillna(num_clusters).astype('int')
    return combined

@st.cache(ttl=600)
def distance_from_centroid(row):
    """Return the Euclidean distance between a row's embedding and its centroid.

    Args:
        row: Mapping or Series with 'embedding' and 'centroid' entries of
            equal length (lists or arrays).

    Returns:
        The 2-norm of ``embedding - centroid``.
    """
    # np.linalg.norm is the public API. ``scipy.spatial.distance.norm`` is
    # only numpy's norm re-exported incidentally by that module.
    # asarray accepts both list- and array-valued cells.
    embedding = np.asarray(row['embedding'], dtype=float)
    centroid = np.asarray(row['centroid'], dtype=float)
    return np.linalg.norm(embedding - centroid)

@st.cache(ttl=600)
def topic_distribution(weights, smoothing=0.01):
    """Compare each topic's overall frequency with its frequency under ``weights``.

    Example ``i`` contributes ``1/N`` to its topic's overall frequency and
    ``weights[i]`` to its error frequency. The topic of example ``i`` is
    ``dataset[i]['title']``.

    NOTE(review): this reads the module-level name ``dataset``. In this
    file's ``__main__`` block, ``dataset`` is the sidebar selectbox string,
    not an indexable dataset. Confirm which object is intended.

    Args:
        weights: Per-example weights (sequence, array or Series), used
            positionally.
        smoothing: Additive smoothing applied to both frequencies in the ratio.

    Returns:
        DataFrame with string columns 'Overall frequency', 'Error frequency',
        'Ratio' and 'Category', sorted by ratio in descending order.
    """
    columns = ['Overall frequency', 'Error frequency', 'Ratio', 'Category']
    # Float array: positional indexing regardless of any Series index. Also
    # avoids np.full_like inheriting an int/bool dtype, which would round the
    # uniform weight 1/N to 0.
    weights = np.asarray(weights, dtype=float)
    num_examples = len(weights)
    if num_examples == 0:
        return pd.DataFrame([], columns=columns)
    uniform_weight = 1 / num_examples

    topic_frequencies = defaultdict(float)
    topic_frequencies_spotlight = defaultdict(float)
    for i in range(num_examples):
        category = dataset[i]['title']
        topic_frequencies[category] += uniform_weight
        topic_frequencies_spotlight[category] += weights[i]

    topic_ratios = {c: (smoothing + topic_frequencies_spotlight[c]) / (
        smoothing + topic_frequencies[c]) for c in topic_frequencies}

    categories_sorted = [c for c, _ in sorted(
        topic_ratios.items(), key=lambda x: x[1], reverse=True)]

    topic_distr = [
        ['%.3f' % topic_frequencies[category],
         '%.3f' % topic_frequencies_spotlight[category],
         '%.2f' % topic_ratios[category],
         '%s' % category]
        for category in categories_sorted
    ]
    return pd.DataFrame(topic_distr, columns=columns)

def populate_session(dataset,model):
    """Load the precomputed frame and the tokenizer for ``dataset``/``model``.

    Reads ``./assets/data/<dataset>_<model>.parquet``. When absent from
    ``st.session_state``, it seeds 'user_data' with that frame and
    'selected_slice' with None.

    Returns:
        The tokenizer, so callers can pass it to ``frequent_tokens``. It was
        previously built and then discarded.
    """
    data_df = pd.read_parquet('./assets/data/' + dataset + '_' + model + '.parquet')
    # The ALBERT Yelp checkpoint is loaded from the 'textattack/' namespace;
    # every other model id is used as-is.
    if model == 'albert-base-v2-yelp-polarity':
        hub_id = 'textattack/' + model
    else:
        hub_id = model
    tokenizer = AutoTokenizer.from_pretrained(hub_id)
    if "user_data" not in st.session_state:
        st.session_state["user_data"] = data_df
    if "selected_slice" not in st.session_state:
        st.session_state["selected_slice"] = None
    return tokenizer


if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="Interactive Error Analysis")

    ut.init_style()

    lcol, rcol = st.columns([2, 2])
    # ******* loading the model and the data
    #st.sidebar.markdown("<h4>Interactive Error Analysis</h4>", unsafe_allow_html=True)

    dataset = st.sidebar.selectbox(
        "Dataset",
        ["amazon_polarity", "yelp_polarity"],
        index = 1
    )

    model = st.sidebar.selectbox(
        "Model",
        ["distilbert-base-uncased-finetuned-sst-2-english",
            "albert-base-v2-yelp-polarity"],
    )

    # Precomputed artifacts live at ./assets/data/<dataset>_<model>*.parquet
    data_prefix = './assets/data/' + dataset + '_' + model

    ### LOAD DATA AND SESSION VARIABLES ###
    ## uncomment the next line to run dynamically and not from file
    ## (the dynamic word-distribution path below also needs a `tokenizer`)
    #populate_session(dataset, model)
    data_df = pd.read_parquet(data_prefix + '.parquet')
    loss_quantile = st.sidebar.slider(
        "Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
    )
    data_df['loss'] = data_df['loss'].astype(float)
    losses = data_df['loss']
    high_loss = losses.quantile(loss_quantile)
    # Rows strictly above the loss quantile form the error ("high-loss") slice.
    data_df['slice'] = 'high-loss'
    data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')

    with rcol:
        with st.spinner(text='loading...'):
            st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
            #uncomment the next line to run dynamically and not from file
            #commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
            commontokens = pd.read_parquet(data_prefix + '_commontokens.parquet')
            with st.expander("How to read the table:"):
                st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
            st.write(commontokens)

    run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)

    num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)

    is_high_loss = data_df['slice'] == 'high-loss'
    # k-means cannot form more clusters than there are points, and a high
    # loss quantile can leave few (or no) rows in the error slice.
    if run_kmeans == 'True' and is_high_loss.sum() >= num_clusters:
        merged = kmeans(data_df, num_clusters=num_clusters)
    else:
        if run_kmeans == 'True':
            st.sidebar.warning("Not enough high-loss examples for that many clusters; showing the error slice as one cluster.")
        # Without clustering, all high-loss rows share cluster 0 and low-loss
        # rows get cluster 1 (the same ids kmeans(..., num_clusters=1) would
        # assign). This keeps `merged` defined for the chart below.
        merged = data_df.copy()
        merged['cluster'] = np.where(is_high_loss, 0, 1)
    with lcol:
        st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
        dataframe = pd.read_parquet(data_prefix + '_error-slices.parquet')
        #uncomment the next line to run dynamically and not from file
        # dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
        #     by=['loss'], ascending=False)
        # table_html = dataframe.to_html(
        #     columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
        # table_html = table_html.replace("<th>", '<th align="left">')  # left-align the headers
        with st.expander("How to read the table:"):
            st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
            st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
            st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
        st.write(dataframe,width=900, height=300)

    quant_panel(merged)