import streamlit as st
import plotly.graph_objects as go
import numpy as np
import pandas as pd

# Hugging Face colors: default fill / outline for the single-model charts.
fillcolor = "#FFD21E"
line_color = "#FF9D00"

# Per-trace palettes for multi-model charts; entry i styles the i-th model.
fill_color_list = [
    fillcolor,
    "#F05998",
    "#40BAF0",
]
line_color_list = [
    line_color,
    "#5E233C",
    "#194A5E",
]

# Trace opacity used by the single-model radar charts.
opacity = 0.75

# Metric columns drawn on the radar axes, in angular order.
categories = [
    "ARC",
    "GSM8K",
    "TruthfulQA",
    "Winogrande",
    "HellaSwag",
    "MMLU",
]

# Column names used when building a DataFrame from row dicts.
columns = [
    "index",
    "model_name",
    "model_dtype",
    "ARC",
    "HellaSwag",
    "TruthfulQA",
    "Winogrande",
    "GSM8K",
    "MMLU",
    "Average",
]


#@st.cache_data
def plot_radar_chart_index(
    dataframe: pd.DataFrame,
    index: int,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
):
    """
    Plot the row labelled ``index`` of ``dataframe`` as a radar chart.

    Arguments:
    dataframe: a pandas DataFrame with a "model_name" column and one
        column per entry of ``categories``
    index: the row label (as used by ``DataFrame.loc``) of the row to plot
    categories: the list of the metrics, in the angular order they are drawn
    fillcolor: a string specifying the color to fill the area
    line_color: a string specifying the color of the lines in the graph

    Returns:
    a plotly Figure with one closed Scatterpolar trace. Values are
    multiplied by 100 and rounded to two decimals; the radial axis is fixed
    to 0-100 and the legend is hidden.
    """
    # Cast to float *before* scaling: on an object-dtype frame whose cells
    # are numeric strings, ``* 100`` would repeat each string instead of
    # scaling it, and the later float cast would then fail.
    data = dataframe.loc[index, categories].to_numpy().astype(float) * 100
    data = data.round(decimals=2)

    # Repeat the first point (and its label) so the polygon is closed.
    data = np.append(data, data[0])
    categories_theta = [*categories, categories[0]]
    model_name = dataframe.loc[index, "model_name"]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=data,
        theta=categories_theta,
        fill='toself',
        fillcolor=fillcolor,
        opacity=opacity,  # module-level default opacity
        line=dict(color=line_color),
        name=model_name,
    ))
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.])),
        showlegend=False,
    )
    return fig

#@st.cache_data
def plot_radar_chart_name(
    dataframe: pd.DataFrame,
    model_name: str,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
):
    """
    Plot the scores of the model called ``model_name`` as a radar chart.

    Arguments:
    dataframe: a pandas DataFrame with a "model_name" column and one
        column per entry of ``categories``
    model_name: a string stating the name of the model; if several rows
        share this name, the first one is plotted
    categories: the list of the metrics, in the angular order they are drawn
    fillcolor: a string specifying the color to fill the area
    line_color: a string specifying the color of the lines in the graph

    Returns:
    a plotly Figure with one closed Scatterpolar trace. Values are
    multiplied by 100 and rounded to two decimals; the radial axis is fixed
    to 0-100 and the legend is hidden.

    Raises:
    IndexError: if no row of ``dataframe`` has ``model_name`` in its
        "model_name" column.
    """
    matches = dataframe.loc[dataframe["model_name"] == model_name, categories]
    if matches.empty:
        raise IndexError(f"no row with model_name == {model_name!r}")

    # ``.iloc[0]`` yields a 1-D vector. Keeping the 2-D boolean-mask
    # selection would make np.append flatten it and close the polygon with
    # a whole row instead of a single point (2*n radii vs n+1 theta labels).
    # Cast to float before scaling so numeric strings are not repeated.
    data = matches.iloc[0].to_numpy().astype(float) * 100
    data = data.round(decimals=2)

    # Repeat the first point (and its label) so the polygon is closed.
    data = np.append(data, data[0])
    categories_theta = [*categories, categories[0]]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=data,
        theta=categories_theta,
        fill='toself',
        fillcolor=fillcolor,
        opacity=opacity,  # module-level default opacity
        line=dict(color=line_color),
        name=model_name,
    ))
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.])),
        showlegend=False,
    )
    return fig


#@st.cache_data
def plot_radar_chart_rows(
    rows: object,
    columns: list = columns,
    categories: list = categories,
    fillcolor_list: list = fill_color_list,
    line_color_list: list = line_color_list,
):
    """
    Plot several models on one radar chart, one trace per row.

    Arguments:
    rows: an iterable whose elements are dicts with ``columns`` as their
        keys (anything ``pd.DataFrame(rows, columns=columns)`` accepts);
        must provide "model_name" and every entry of ``categories``
    columns: the list of the columns to use
    categories: the list of the metrics, in the angular order they are drawn
    fillcolor_list: fill colors, one per trace; reused cyclically when there
        are more rows than colors
    line_color_list: outline colors, one per trace; reused cyclically when
        there are more rows than colors

    Returns:
    a plotly Figure with one closed Scatterpolar trace per row on a fixed
    0-100 radial axis. The legend is shown only when more than one row is
    plotted.
    """
    dataset = pd.DataFrame(rows, columns=columns)
    # NOTE(review): unlike the single-model plotters, values are NOT scaled
    # by 100 here, so rows are presumably already in percent -- confirm with
    # the callers.
    data = dataset[categories].to_numpy().astype(float)

    # Repeat the first column (and its label) so every polygon is closed.
    data = np.hstack([data, data[:, :1]])
    categories_theta = [*categories, categories[0]]

    # Successive traces are drawn more transparent so overlaps stay visible;
    # the floor keeps opacity inside plotly's valid [0, 1] range.
    base_opacity, opacity_step, min_opacity = 0.75, 0.2, 0.1

    fig = go.Figure()
    # Iterate the column positionally: ``dataset.loc[i, ...]`` would break
    # when ``rows`` is a DataFrame carrying a non-default index.
    for i, model_name in enumerate(dataset["model_name"]):
        fig.add_trace(go.Scatterpolar(
            r=data[i],
            theta=categories_theta,
            fill='toself',
            fillcolor=fillcolor_list[i % len(fillcolor_list)],
            opacity=max(base_opacity - opacity_step * i, min_opacity),
            line=dict(color=line_color_list[i % len(line_color_list)]),
            name=model_name,
        ))

    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.])),
        showlegend=len(dataset) > 1,
    )
    return fig