import pandas as pd
import json
import os

from src.collect_data import fetch_version_metadata, fetch_registry_data
from assets.text_content import LANG_MAPPING
# Relative path (resolved against the current working directory) of the pricing JSON read by merge_data().
PRICING_PATH = os.path.join('assets', 'pricing.json')

def convert_parameters(param):
    """Parse a parameter-count label into a float expressed in billions.

    Args:
        param: Label such as ``'7B'`` or ``'1.5T'``; plain numbers are
            accepted as well.

    Returns:
        The count in billions (a ``T`` value is multiplied by 1000), or
        ``None`` when the value is missing (NaN/None) or an empty string.

    Raises:
        ValueError: If the text left after removing the suffix is not a number.
    """
    if pd.isna(param) or param == '':
        return None

    label = str(param)
    # Pick the suffix to strip and the factor that brings the value to billions.
    suffix, scale = ('T', 1000) if 'T' in label else ('B', 1)
    return float(label.replace(suffix, '')) * scale

def clean_price(price):
    """Convert a price value such as ``'$2.50'`` into a float.

    Args:
        price: Price as a string (optionally containing ``$``) or as a number
            that was already parsed from JSON.

    Returns:
        The price as a float, or ``None`` when the value is missing
        (NaN/None), empty, or blank once the ``$`` sign is removed.

    Raises:
        ValueError: If the remaining text is not a valid number.
    """
    if pd.isna(price) or price == '':
        return None
    # Go through str() first: numeric values have no .replace(), so calling it
    # directly on them raises AttributeError.
    text = str(price).replace('$', '').strip()
    if not text:
        return None
    return float(text)

def map_languages(languages):
    """Render language codes as a comma-separated string of display names.

    Each code is looked up in ``LANG_MAPPING``; a code without an entry is
    kept unchanged.

    Args:
        languages: A comma-separated string of codes, a list or other
            iterable of codes, or a missing value (``None``/NaN).

    Returns:
        The joined display names, ``None`` for a missing value, or
        ``str(languages)`` when the value is neither a string nor iterable.
    """
    if languages is None or (isinstance(languages, float) and pd.isna(languages)):
        return None

    if isinstance(languages, str):
        codes = [code.strip() for code in languages.split(',')]
    else:
        try:
            # Lists, tuples and arrays all take the same path so every
            # container type gets mapped consistently.
            codes = [str(code) for code in languages]
        except TypeError:
            # Not iterable (e.g. a bare number): fall back to its text form.
            return str(languages)

    return ', '.join(LANG_MAPPING.get(code, code) for code in codes)
    
def get_multimodality_field(model_data, field):
    """Look up one flag under ``model_config -> multimodality`` of a model entry.

    Args:
        model_data: Mapping-like model entry exposing ``.get`` (a dict or a
            DataFrame row).
        field: Name of the multimodality flag, e.g. ``'audio'``.

    Returns:
        The stored value, or ``False`` when any level of the nesting is
        missing or is not a mapping.
    """
    try:
        config = model_data.get('model_config', {})
        multimodality = config.get('multimodality', {})
        return multimodality.get(field, False)
    except (AttributeError, TypeError):
        # A level is None/NaN/another non-mapping value, so there is no usable
        # .get(); report the flag as unset. Other errors are not swallowed.
        return False


def merge_data() -> pd.DataFrame:
    """Build the combined table of benchmark, registry and pricing data.

    Steps, in order:
      1. Fetch the multimodal/text latency and result frames and the registry,
         and load the pricing JSON from ``PRICING_PATH``.
      2. Average latency and clemscore per model across both result sets.
      3. Inner-join those averages with the registry on the model name, so
         only models present on both sides are kept.
      4. Left-join pricing on ``model_id``; missing prices become 0.0.
      5. Derive the display columns (parameters, license link, languages) and
         sort by ``Clemscore`` in descending order.

    Returns:
        The merged DataFrame with display-ready column names.
    """

    mm_latency_df, mm_result_df, text_latency_df, text_result_df = fetch_version_metadata()
    registry_data = fetch_registry_data()
    with open(PRICING_PATH, 'r') as f:
        pricing_data = json.load(f)

    # Normalise headers: 'Unnamed: 0' becomes 'model', '-, clemscore' becomes 'clemscore'.
    # NOTE: inplace=True mutates the frames returned by fetch_version_metadata().
    mm_result_df.rename(columns={'Unnamed: 0': 'model', '-, clemscore': 'clemscore'}, inplace=True)
    text_result_df.rename(columns={'Unnamed: 0': 'model', '-, clemscore': 'clemscore'}, inplace=True)
    # Keep only the part of each model name that precedes the first '-t0.0--'.
    mm_result_df['model'] = mm_result_df['model'].str.split('-t0.0--').str[0]
    text_result_df['model'] = text_result_df['model'].str.split('-t0.0--').str[0]  # Original note: relates to a bug in get_latency.py; split on '-t0.0' rather than '-t', because a bare '-t' also occurs inside names such as gpt-3.5-turbo / gpt-4-turbo.

    # Stack the multimodal and text frames, then average per model.
    avg_latency_df = pd.concat([mm_latency_df, text_latency_df], axis=0).groupby('model')['latency'].mean().reset_index()
    avg_clemscore_df = pd.concat([mm_result_df, text_result_df], axis=0).groupby('model')['clemscore'].mean().reset_index()

    # Outer join: keep models that have only a latency or only a clemscore.
    lat_clem_df = pd.merge(avg_latency_df, avg_clemscore_df, on='model', how='outer')

    # Convert registry_data to DataFrame for easier merging
    registry_df = pd.DataFrame(registry_data)
    
    # Flatten the nested 'license' entry into separate name and URL columns.
    # NOTE(review): x['name'] / x['url'] raise if an entry's license is missing
    # or lacks one of these keys — confirm the registry always provides both.
    registry_df['license_name'] = registry_df['license'].apply(lambda x: x['name'])
    registry_df['license_url'] = registry_df['license'].apply(lambda x: x['url'])

    # Add one column per multimodality flag; get_multimodality_field returns
    # False when the flag cannot be found.
    registry_df['single_image'] = registry_df.apply(lambda x: get_multimodality_field(x, 'single_image'), axis=1)
    registry_df['multiple_images'] = registry_df.apply(lambda x: get_multimodality_field(x, 'multiple_images'), axis=1)
    registry_df['audio'] = registry_df.apply(lambda x: get_multimodality_field(x, 'audio'), axis=1)
    registry_df['video'] = registry_df.apply(lambda x: get_multimodality_field(x, 'video'), axis=1)

    # Keep only the registry columns used downstream (including the new
    # multimodality fields); every other registry column is discarded here.
    registry_df = registry_df[[
        'model_name', 'parameters', 'release_date', 'open_weight',
        'languages', 'context_size', 'license_name', 'license_url',
        'single_image', 'multiple_images', 'audio', 'video'
    ]]
    
    # Inner join on the model name: models missing from the registry, or
    # registry entries without any latency/clemscore row, are dropped.
    merged_df = pd.merge(
        lat_clem_df,
        registry_df,
        left_on='model',
        right_on='model_name',
        how='inner'
    )
    
    # Switch to the display column names used from here on.
    merged_df = merged_df.rename(columns={
        'model': 'Model Name',
        'latency': 'Latency (s)',
        'clemscore': 'Clemscore',
        'parameters': 'Parameters (B)',
        'release_date': 'Release Date',
        'open_weight': 'Open Weight',
        'languages': 'Languages',
        'context_size': 'Context Size (k)',
        'license_name': 'License Name',
        'license_url': 'License URL',
        'single_image': 'Single Image',
        'multiple_images': 'Multiple Images',
        'audio': 'Audio',
        'video': 'Video'
    })
    
    # Convert pricing_data to a DataFrame and turn its 'input'/'output'
    # price values into floats via clean_price.
    pricing_df = pd.DataFrame(pricing_data)
    pricing_df['input'] = pricing_df['input'].apply(clean_price)
    pricing_df['output'] = pricing_df['output'].apply(clean_price)
    
    # Left join: models without a pricing entry are kept (their prices are NaN
    # until filled below).
    merged_df = pd.merge(
        merged_df,
        pricing_df,
        left_on='Model Name',
        right_on='model_id',
        how='left'
    )
    
    # Drop the pricing join key and rename the price columns
    merged_df = merged_df.drop('model_id', axis=1)
    merged_df = merged_df.rename(columns={
        'input': 'Input $/1M tokens',
        'output': 'Output $/1M tokens'
    })
    
    # Fill NaN values with 0.0 for pricing columns, so a model without pricing
    # data ends up with a price of 0.0.
    merged_df['Input $/1M tokens'] = merged_df['Input $/1M tokens'].fillna(0.0)
    merged_df['Output $/1M tokens'] = merged_df['Output $/1M tokens'].fillna(0.0)
    
    # Parse the parameter count for open-weight rows; rows whose 'Open Weight'
    # is falsy get None instead.
    merged_df['Parameters (B)'] = merged_df.apply(
        lambda row: None if not row['Open Weight'] else convert_parameters(row['Parameters (B)']), 
        axis=1
    )

    # Build an HTML anchor for the license. The URL and name are interpolated
    # without escaping.
    merged_df['License'] = merged_df.apply(lambda row: f'<a href="{row["License URL"]}" style="color: blue;">{row["License Name"]}</a>', axis=1)
    # Copy of 'Release Date'; its consumer is not visible in this function.
    merged_df['Temp Date'] = merged_df['Release Date']

    # Replace language codes with a comma-separated string via map_languages.
    merged_df['Languages'] = merged_df['Languages'].apply(map_languages)

    # Sort by Clemscore in descending order
    merged_df = merged_df.sort_values(by='Clemscore', ascending=False)
    
    # Drop the registry join key; after the inner join it equals 'Model Name'.
    merged_df.drop(columns=['model_name'], inplace=True)
    
    return merged_df

if __name__ == '__main__':
    # Build the merged table and save it as CSV (without the index column).
    csv_destination = os.path.join('assets', 'merged_data.csv')
    merge_data().to_csv(csv_destination, index=False)