import pandas as pd COLOR_MAP = { "yellow": "background-color: #FFFFCC", # Reasoning models "green": "background-color: #E3FBE9", # Linear attention hybrid "blue": "background-color: #E6F4FF" # SSM hybrid models } def style_zero_context(df): """ Similar approach to style_long_context: 1) color rows based on model name 2) numeric formatting """ import pandas as pd # Example color dict, tweak as needed: color_mapping = { "minimax-text-01": COLOR_MAP["green"], "jamba-1.5-large": COLOR_MAP["blue"], "deepseek-r1": COLOR_MAP["yellow"], "o1-mini": COLOR_MAP["yellow"], "qwq-32b-preview": COLOR_MAP["yellow"], # Add any other special-cased models here # "o1-mini": COLOR_MAP["yellow"], etc. } styler = df.style.apply( lambda row: [color_mapping.get(row["Model"], "")]*len(row), axis=1 ) # # Attach custom tooltips (optional) # tooltips = pd.DataFrame("", index=df.index, columns=df.columns) # if "1st<50% op" in df.columns: # tooltips["1st<50% op"] = "First operation number with accuracy <50%" # if "1st<10% op" in df.columns: # tooltips["1st<10% op"] = "First operation number with accuracy <10%" # if "Avg. Acc op≤30" in df.columns: # tooltips["Avg. Acc op≤30"] = "Average accuracy of first 30 operations" # styler = styler.set_tooltips(tooltips) # Apply numeric formatting styler = styler.format({ "Symbolic": "{:,.2f}", # Format as number with thousands separator and 1 decimal place "Medium": "{:,.2f}", # Format as number with thousands separator and 2 decimal places "Hard": "{:,.2f}", # Format as number with thousands separator and 2 decimal places "1st<50% op": "{:,.0f}", # Format as plain integer (no decimal places) "1st<10% op": "{:,.0f}", # Format as plain integer (no decimal places) "Avg. Acc op≤30": "{:.4f}", # Format with 4 decimal places "Average↑": "{:,.2f}" # Format as number with thousands separator and 2 decimal places }) return styler # Add styling for model types def style_long_context(df): color_mapping = { "minimax-text-01": COLOR_MAP["green"], "jamba-1.5-large": COLOR_MAP["blue"] } return df.style.apply( lambda row: [color_mapping.get(row["Model"], "")]*len(row), axis=1 ).format({ "8K": "{:,.2f}", "16K": "{:,.2f}", "32K": "{:,.2f}", "Average↑": "{:,.2f}" })