{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "d44219ba-b2cf-45c8-98fa-69a43e27f5b0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelCategoryClassPrecisionRecallF1 ScoreAccuracyBalanced Acc.Support
0GPT-4omain-eventcollision0.8245610.0561190.105087NaNNaN1675
1GPT-4omain-eventnear-collision0.6701030.0074730.014781NaNNaN8698
2GPT-4omain-eventOverall (main-event)0.7473320.0317960.0599340.0153280.03179610373
3GPT-4olocationalley0.3387100.0917030.144330NaNNaN229
4GPT-4olocationhighway0.8145290.8778460.845003NaNNaN1801
# %% Load the ground-truth labels and every model's predictions
import os
import pandas as pd
from comparison import ModelEvaluator, ModelComparison


def load_data(directory='results', labels_filename='Labels.csv'):
    """Build a ``ModelComparison`` from the CSV files found in ``directory``.

    The file named ``labels_filename`` is read as the ground-truth labels.
    Every other file whose name ends in ``.csv`` is read as one model's
    predictions; the model name is the file name without its extension.

    Parameters
    ----------
    directory : str
        Folder that holds the labels file and the per-model prediction files.
    labels_filename : str
        Name of the ground-truth CSV inside ``directory``.

    Returns
    -------
    ModelComparison
        Comparison object built from one ``ModelEvaluator`` per prediction
        file, ordered alphabetically by file name.
    """
    labels_path = os.path.join(directory, labels_filename)
    df_labels = pd.read_csv(labels_path)

    # os.listdir() yields entries in arbitrary order. Sorting keeps the
    # evaluator order (and anything indexed positionally downstream)
    # identical from one run or machine to the next.
    prediction_files = sorted(
        filename
        for filename in os.listdir(directory)
        if filename.endswith('.csv') and filename != labels_filename
    )

    evaluators = []
    for filename in prediction_files:
        model_name = os.path.splitext(filename)[0]
        df_model = pd.read_csv(os.path.join(directory, filename))
        evaluators.append(ModelEvaluator(df_labels, df_model, model_name))

    return ModelComparison(evaluators)


model_comparison = load_data()

combined_df = model_comparison.combined_df

combined_df.head()
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelCategoryClassPrecisionRecallF1 ScoreAccuracyBalanced Acc.Support
0GPT-4omain-eventcollision0.8245610.0561190.105087NaNNaN1675
1GPT-4omain-eventnear-collision0.6701030.0074730.014781NaNNaN8698
2GPT-4omain-eventOverall (main-event)0.7473320.0317960.0599340.0153280.03179610373
3GPT-4olocationalley0.3387100.0917030.144330NaNNaN229
4GPT-4olocationhighway0.8145290.8778460.845003NaNNaN1801
5GPT-4olocationmain-road0.9450360.9443430.944689NaNNaN8175
6GPT-4olocationparking-lot0.7014930.8392860.764228NaNNaN168
7GPT-4olocationOverall (location)0.6999420.6882940.6745620.9122720.68829410373
8GPT-4ozonerural0.5116280.0574410.103286NaNNaN383
9GPT-4ozonesuburbs0.4950000.2145950.299395NaNNaN1384
10GPT-4ozoneunknown0.2500000.0021690.004301NaNNaN461
11GPT-4ozoneurban0.8160600.9744630.888255NaNNaN8145
12GPT-4ozoneOverall (zone)0.5181720.3121670.3238090.7960090.31216710373
13GPT-4olight-conditionsdaylight0.9720070.9680890.970044NaNNaN5954
14GPT-4olight-conditionsnight0.9661590.9772670.971681NaNNaN3827
15GPT-4olight-conditionstwilight0.6013990.5810810.591065NaNNaN592
16GPT-4olight-conditionsOverall (light-conditions)0.8465210.8421460.8442630.9493880.84214610373
17GPT-4oweather-conditionsclear-sky0.8580410.9781250.914156NaNNaN7817
18GPT-4oweather-conditionscloudy0.7781200.2764090.407916NaNNaN1827
19GPT-4oweather-conditionsrain0.7982780.8902610.841764NaNNaN729
20GPT-4oweather-conditionsOverall (weather-conditions)0.8114800.7149320.7212790.8483560.71493210373
21GPT-4ovehicles-densityhigh0.8123750.0690530.127287NaNNaN5894
22GPT-4ovehicles-densitylow0.4996610.3156540.386894NaNNaN2338
23GPT-4ovehicles-densitymedium0.2200120.8626810.350607NaNNaN2141
24GPT-4ovehicles-densityOverall (vehicles-density)0.5106830.4157960.2882630.2884410.41579610373
\n", "
" ], "text/plain": [ " Model Category Class Precision \\\n", "0 GPT-4o main-event collision 0.824561 \n", "1 GPT-4o main-event near-collision 0.670103 \n", "2 GPT-4o main-event Overall (main-event) 0.747332 \n", "3 GPT-4o location alley 0.338710 \n", "4 GPT-4o location highway 0.814529 \n", "5 GPT-4o location main-road 0.945036 \n", "6 GPT-4o location parking-lot 0.701493 \n", "7 GPT-4o location Overall (location) 0.699942 \n", "8 GPT-4o zone rural 0.511628 \n", "9 GPT-4o zone suburbs 0.495000 \n", "10 GPT-4o zone unknown 0.250000 \n", "11 GPT-4o zone urban 0.816060 \n", "12 GPT-4o zone Overall (zone) 0.518172 \n", "13 GPT-4o light-conditions daylight 0.972007 \n", "14 GPT-4o light-conditions night 0.966159 \n", "15 GPT-4o light-conditions twilight 0.601399 \n", "16 GPT-4o light-conditions Overall (light-conditions) 0.846521 \n", "17 GPT-4o weather-conditions clear-sky 0.858041 \n", "18 GPT-4o weather-conditions cloudy 0.778120 \n", "19 GPT-4o weather-conditions rain 0.798278 \n", "20 GPT-4o weather-conditions Overall (weather-conditions) 0.811480 \n", "21 GPT-4o vehicles-density high 0.812375 \n", "22 GPT-4o vehicles-density low 0.499661 \n", "23 GPT-4o vehicles-density medium 0.220012 \n", "24 GPT-4o vehicles-density Overall (vehicles-density) 0.510683 \n", "\n", " Recall F1 Score Accuracy Balanced Acc. 
# %% Metrics table held by the first evaluator of the comparison
first_evaluator = model_comparison.evaluators[0]
first_evaluator.metrics_df
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelCategoryClassPrecisionRecallF1 ScoreAccuracyBalanced Acc.Support
0GPT-4omain-eventcollision0.8245610.0561190.105087NaNNaN1675
1GPT-4omain-eventnear-collision0.6701030.0074730.014781NaNNaN8698
2GPT-4omain-eventOverall (main-event)0.7473320.0317960.0599340.0153280.03179610373
3GPT-4olocationalley0.3387100.0917030.144330NaNNaN229
4GPT-4olocationhighway0.8145290.8778460.845003NaNNaN1801
5GPT-4olocationmain-road0.9450360.9443430.944689NaNNaN8175
6GPT-4olocationparking-lot0.7014930.8392860.764228NaNNaN168
7GPT-4olocationOverall (location)0.6999420.6882940.6745620.9122720.68829410373
8GPT-4ozonerural0.5116280.0574410.103286NaNNaN383
9GPT-4ozonesuburbs0.4950000.2145950.299395NaNNaN1384
10GPT-4ozoneunknown0.2500000.0021690.004301NaNNaN461
11GPT-4ozoneurban0.8160600.9744630.888255NaNNaN8145
12GPT-4ozoneOverall (zone)0.5181720.3121670.3238090.7960090.31216710373
13GPT-4olight-conditionsdaylight0.9720070.9680890.970044NaNNaN5954
14GPT-4olight-conditionsnight0.9661590.9772670.971681NaNNaN3827
15GPT-4olight-conditionstwilight0.6013990.5810810.591065NaNNaN592
16GPT-4olight-conditionsOverall (light-conditions)0.8465210.8421460.8442630.9493880.84214610373
17GPT-4oweather-conditionsclear-sky0.8580410.9781250.914156NaNNaN7817
18GPT-4oweather-conditionscloudy0.7781200.2764090.407916NaNNaN1827
19GPT-4oweather-conditionsrain0.7982780.8902610.841764NaNNaN729
20GPT-4oweather-conditionsOverall (weather-conditions)0.8114800.7149320.7212790.8483560.71493210373
21GPT-4ovehicles-densityhigh0.8123750.0690530.127287NaNNaN5894
22GPT-4ovehicles-densitylow0.4996610.3156540.386894NaNNaN2338
23GPT-4ovehicles-densitymedium0.2200120.8626810.350607NaNNaN2141
24GPT-4ovehicles-densityOverall (vehicles-density)0.5106830.4157960.2882630.2884410.41579610373
\n", "
" ], "text/plain": [ " Model Category Class Precision \\\n", "0 GPT-4o main-event collision 0.824561 \n", "1 GPT-4o main-event near-collision 0.670103 \n", "2 GPT-4o main-event Overall (main-event) 0.747332 \n", "3 GPT-4o location alley 0.338710 \n", "4 GPT-4o location highway 0.814529 \n", "5 GPT-4o location main-road 0.945036 \n", "6 GPT-4o location parking-lot 0.701493 \n", "7 GPT-4o location Overall (location) 0.699942 \n", "8 GPT-4o zone rural 0.511628 \n", "9 GPT-4o zone suburbs 0.495000 \n", "10 GPT-4o zone unknown 0.250000 \n", "11 GPT-4o zone urban 0.816060 \n", "12 GPT-4o zone Overall (zone) 0.518172 \n", "13 GPT-4o light-conditions daylight 0.972007 \n", "14 GPT-4o light-conditions night 0.966159 \n", "15 GPT-4o light-conditions twilight 0.601399 \n", "16 GPT-4o light-conditions Overall (light-conditions) 0.846521 \n", "17 GPT-4o weather-conditions clear-sky 0.858041 \n", "18 GPT-4o weather-conditions cloudy 0.778120 \n", "19 GPT-4o weather-conditions rain 0.798278 \n", "20 GPT-4o weather-conditions Overall (weather-conditions) 0.811480 \n", "21 GPT-4o vehicles-density high 0.812375 \n", "22 GPT-4o vehicles-density low 0.499661 \n", "23 GPT-4o vehicles-density medium 0.220012 \n", "24 GPT-4o vehicles-density Overall (vehicles-density) 0.510683 \n", "\n", " Recall F1 Score Accuracy Balanced Acc. 
# %% All metric rows that belong to one model
combined_df.loc[combined_df['Model'] == 'GPT-4o']

# %% Per-class precision / recall / F1 for one (model, category) pair
import seaborn as sns
import matplotlib.pyplot as plt

categories = combined_df["Category"].unique()
models = combined_df["Model"].unique()
metrics = ['F1 Score', 'Precision', 'Recall']

# Select by name rather than by position: positional picks such as
# categories[4] / models[0] silently change meaning whenever the row order of
# combined_df (and therefore the order returned by unique()) changes.
selected_category = "weather-conditions"
selected_model = "GPT-4o"

model_data = combined_df[
    (combined_df["Category"] == selected_category)
    & (combined_df["Model"] == selected_model)
]
if model_data.empty:
    # Fail loudly instead of drawing an empty chart and an empty pivot.
    raise ValueError(
        f"No rows for model {selected_model!r} in category {selected_category!r}"
    )

sns.set(style="whitegrid")

# Long format: one row per (class, metric) so seaborn can group bars by metric.
df_melted = model_data.melt(
    id_vars=["Class"], value_vars=metrics, var_name="Metric", value_name="Score"
)

fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(data=df_melted, x="Class", y="Score", hue="Metric", palette="viridis", ax=ax)
ax.tick_params(axis="x", labelrotation=45)
ax.set_title(f"Performance Metrics by Class for {selected_model} ({selected_category})")
ax.set_ylabel("Score")
ax.set_ylim(0, 1)  # all three metrics are bounded by [0, 1]
ax.legend(title="Metric")
plt.show()

# Wide table (classes as rows, one column per metric/model) shown by the next cell.
performance_table = model_data.pivot(index="Class", columns="Model", values=metrics).round(4)
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
F1 ScorePrecisionRecall
ModelGPT-4oGPT-4oGPT-4o
Class
Overall (weather-conditions)0.72130.81150.7149
clear-sky0.91420.85800.9781
cloudy0.40790.77810.2764
rain0.84180.79830.8903
# %% Rounded per-class table built by the plotting cell
performance_table

# %% Classes the evaluators report for the weather category
is_weather_row = combined_df["Category"] == "weather-conditions"
all_classes_in_category = combined_df.loc[is_weather_row, "Class"].unique()
print("Classes in 'weather-conditions':", all_classes_in_category)

# %% Raw ground truth and raw GPT-4o predictions, for a manual cross-check
df_labels = pd.read_csv('results/Labels.csv')
df_predictions = pd.read_csv('results/GPT-4o.csv')
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ClassF1 Score
0clear-sky0.910311
1cloudy0.404647
2fog0.000000
3rain0.834827
4snow0.000000
5unknown0.000000
# %% Pair every ground-truth row with the prediction that shares its id
from sklearn.metrics import f1_score

# Computed once here and reused by the cells below (the notebook previously
# repeated this exact merge in a second cell).
# NOTE(review): the recorded run shows 10496 merged rows while the evaluator
# tables report a support of 10373 -- worth checking for duplicated ids.
merged_df = pd.merge(df_labels, df_predictions, on='id', suffixes=('_true', '_pred'))

# %% Per-class F1 for weather-conditions, computed directly with scikit-learn
y_true = merged_df["weather-conditions_true"].astype(str)
y_pred = merged_df["weather-conditions_pred"].astype(str)

# Union of the values seen on either side, so a class that appears only in the
# predictions (or only in the labels) still gets its own row.
labels = sorted(set(y_true) | set(y_pred))

class_f1_scores = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)

f1_per_class = pd.DataFrame({"Class": labels, "F1 Score": class_f1_scores})
f1_per_class

# %% Macro F1 over every class observed in labels or predictions
# In the recorded run the zero-scoring classes pull this average well below
# the "Overall (weather-conditions)" F1 in the evaluator table, which lists
# fewer classes -- presumably the evaluator restricts the class set; confirm
# in the comparison module.
overall_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
overall_f1
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idmain-event_truelocation_truezone_truelight-conditions_trueweather-conditions_truevehicles-density_truemain-event_predlocation_predzone_predlight-conditions_predweather-conditions_predvehicles-density_pred
0clgr0y79557jk076o681egs6pnear-collisionmain-roadurbannightclear-skyhighnormal-drivingmain-roadurbannightclear-skymedium
1clgr0y79657lc076o46dt0xjbnear-collisionmain-roadurbantwilightclear-skyhighnormal-drivingmain-roadurbantwilightclear-skyhigh
2clgr0y79657n0076og7w6grt8near-collisionmain-roadurbandaylightcloudymediumnormal-drivingmain-roadurbandaylightrainmedium
3clgr0y7973glw0791gwwz2vd4near-collisionmain-roadurbandaylightclear-skymediumnormal-drivinghighwayurbandaylightclear-skymedium
4clgr0y7973gm407912vru07ydnear-collisionmain-roadurbannightclear-skylownormal-drivingmain-roadurbannightclear-skymedium
..........................................
10491cljex08nl2m93077n9t2jc06nnear-collisionmain-roadurbannightclear-skyhighnormal-drivingmain-roadurbannightclear-skymedium
10492cljex08nn19z507403nkf20fkcollisionhighwayunknowndaylightclear-skyhighnormal-drivinghighwayurbandaylightclear-skymedium
10493cljex08sa1eg507409e539opynear-collisionmain-roadurbandaylightclear-skyhighnormal-drivingmain-roadurbandaylightclear-skymedium
10494cljex09rl26px075u2woa9ziznear-collisionhighwayruraldaylightclear-skymediumnormal-drivinghighwaysuburbsdaylightclear-skymedium
10495cljex09rr276l075u01ec2k67near-collisionmain-roadsuburbstwilightclear-skymediumnormal-drivingmain-roadurbannightclear-skymedium
\n", "

10496 rows × 13 columns

\n", "
" ], "text/plain": [ " id main-event_true location_true zone_true \\\n", "0 clgr0y79557jk076o681egs6p near-collision main-road urban \n", "1 clgr0y79657lc076o46dt0xjb near-collision main-road urban \n", "2 clgr0y79657n0076og7w6grt8 near-collision main-road urban \n", "3 clgr0y7973glw0791gwwz2vd4 near-collision main-road urban \n", "4 clgr0y7973gm407912vru07yd near-collision main-road urban \n", "... ... ... ... ... \n", "10491 cljex08nl2m93077n9t2jc06n near-collision main-road urban \n", "10492 cljex08nn19z507403nkf20fk collision highway unknown \n", "10493 cljex08sa1eg507409e539opy near-collision main-road urban \n", "10494 cljex09rl26px075u2woa9ziz near-collision highway rural \n", "10495 cljex09rr276l075u01ec2k67 near-collision main-road suburbs \n", "\n", " light-conditions_true weather-conditions_true vehicles-density_true \\\n", "0 night clear-sky high \n", "1 twilight clear-sky high \n", "2 daylight cloudy medium \n", "3 daylight clear-sky medium \n", "4 night clear-sky low \n", "... ... ... ... \n", "10491 night clear-sky high \n", "10492 daylight clear-sky high \n", "10493 daylight clear-sky high \n", "10494 daylight clear-sky medium \n", "10495 twilight clear-sky medium \n", "\n", " main-event_pred location_pred zone_pred light-conditions_pred \\\n", "0 normal-driving main-road urban night \n", "1 normal-driving main-road urban twilight \n", "2 normal-driving main-road urban daylight \n", "3 normal-driving highway urban daylight \n", "4 normal-driving main-road urban night \n", "... ... ... ... ... \n", "10491 normal-driving main-road urban night \n", "10492 normal-driving highway urban daylight \n", "10493 normal-driving main-road urban daylight \n", "10494 normal-driving highway suburbs daylight \n", "10495 normal-driving main-road urban night \n", "\n", " weather-conditions_pred vehicles-density_pred \n", "0 clear-sky medium \n", "1 clear-sky high \n", "2 rain medium \n", "3 clear-sky medium \n", "4 clear-sky medium \n", "... ... ... 
# %% Unfiltered join of labels and predictions
merged_df

# %% Keep only rows whose prediction in every category is a value that also
# %% occurs in the ground-truth column of that category
label_categories = [
    'main-event',
    'location',
    'zone',
    'light-conditions',
    'weather-conditions',
    'vehicles-density',
]

# One boolean mask accumulated over all categories. The result goes into a new
# frame instead of overwriting merged_df, so the cell can be re-run safely and
# the unfiltered frame displayed above stays available for comparison.
has_known_predictions = pd.Series(True, index=merged_df.index)
for category in label_categories:
    valid_values = df_labels[category].unique().astype(str)
    has_known_predictions &= merged_df[f"{category}_pred"].astype(str).isin(valid_values)

merged_valid_df = merged_df[has_known_predictions]

# %% Rows that survive the filter
# In the recorded run only 211 rows remain; the unfiltered display shows
# 'normal-driving' as a main-event prediction, which does not appear among the
# ground-truth main-event values visible there.
merged_valid_df