Upload 5 files
Browse files- .gitattributes +1 -0
- JMTB_1_rescore_float.csv +3 -0
- modernbert_run_test.ipynb +947 -0
- test_check.png +0 -0
- train_jmtb_test_v6 (コピー).ipynb +853 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
JMTB_1_rescore_float.csv filter=lfs diff=lfs merge=lfs -text
|
JMTB_1_rescore_float.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc332dc75886e5c8679ad098739972d741770591223fe0ada0e74078494487ca
|
3 |
+
size 47064523
|
modernbert_run_test.ipynb
ADDED
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"ename": "ValueError",
|
10 |
+
"evalue": "The config parameter `problem_type` was not understood: received single_label_regression but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid.",
|
11 |
+
"output_type": "error",
|
12 |
+
"traceback": [
|
13 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
14 |
+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
15 |
+
"Cell \u001b[0;32mIn[3], line 18\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# もし学習時のクラスがカスタムクラス ModernBertForScoring なら\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# model = ModernBertForScoring.from_pretrained(MODEL_DIR)\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 15\u001b[0m \n\u001b[1;32m 16\u001b[0m \u001b[38;5;66;03m# 例:カスタムクラス ModernBertForScoring の場合\u001b[39;00m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtrain_jmtb_v6\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ModernBertForScoring\n\u001b[0;32m---> 18\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mModernBertForScoring\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMODEL_DIR\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m AutoTokenizer\u001b[38;5;241m.\u001b[39mfrom_pretrained(MODEL_DIR)\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# GPU利用する場合\u001b[39;00m\n",
|
16 |
+
"File \u001b[0;32m~/anaconda3/envs/vllmtest/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 
3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n",
|
17 |
+
"File \u001b[0;32m~/anaconda3/envs/vllmtest/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
18 |
+
"File \u001b[0;32m~/anaconda3/envs/vllmtest/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n",
|
19 |
+
"File \u001b[0;32m~/anaconda3/envs/vllmtest/lib/python3.12/site-packages/transformers/models/modernbert/configuration_modernbert.py:173\u001b[0m, in \u001b[0;36mModernBertConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, hidden_activation, max_position_embeddings, initializer_range, initializer_cutoff_factor, norm_eps, norm_bias, pad_token_id, eos_token_id, bos_token_id, cls_token_id, sep_token_id, global_rope_theta, attention_bias, attention_dropout, global_attn_every_n_layers, local_attention, local_rope_theta, embedding_dropout, mlp_bias, mlp_dropout, decoder_bias, classifier_pooling, classifier_dropout, classifier_bias, classifier_activation, deterministic_flash_attn, sparse_prediction, sparse_pred_ignore_index, reference_compile, repad_logits_with_grad, **kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 136\u001b[0m vocab_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m50368\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 172\u001b[0m ):\n\u001b[0;32m--> 173\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mpad_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mbos_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbos_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43meos_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meos_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 177\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcls_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcls_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 178\u001b[0m \u001b[43m \u001b[49m\u001b[43msep_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msep_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 179\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 180\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvocab_size \u001b[38;5;241m=\u001b[39m vocab_size\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_position_embeddings \u001b[38;5;241m=\u001b[39m max_position_embeddings\n",
|
20 |
+
"File \u001b[0;32m~/anaconda3/envs/vllmtest/lib/python3.12/site-packages/transformers/configuration_utils.py:286\u001b[0m, in \u001b[0;36mPretrainedConfig.__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 284\u001b[0m allowed_problem_types \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mregression\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msingle_label_classification\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmulti_label_classification\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 285\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproblem_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproblem_type \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m allowed_problem_types:\n\u001b[0;32m--> 286\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 287\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe config parameter `problem_type` was not understood: received \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproblem_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 288\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut only \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mregression\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msingle_label_classification\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmulti_label_classification\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m are 
valid.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 289\u001b[0m )\n\u001b[1;32m 291\u001b[0m \u001b[38;5;66;03m# TPU arguments\u001b[39;00m\n\u001b[1;32m 292\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mxla_device\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
21 |
+
"\u001b[0;31mValueError\u001b[0m: The config parameter `problem_type` was not understood: received single_label_regression but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"source": [
|
26 |
+
"import torch\n",
|
27 |
+
"from transformers import AutoTokenizer\n",
|
28 |
+
"\n",
|
29 |
+
"# カスタムクラスが必要な場合はそちらを import\n",
|
30 |
+
"# from your_module import ModernBertForScoring\n",
|
31 |
+
"\n",
|
32 |
+
"MODEL_DIR = \"./modernbert_jamt_finetune_ckpt_49\" # 実際のパスに置き換えてください\n",
|
33 |
+
"\n",
|
34 |
+
"# もし学習時のクラスがカスタムクラス ModernBertForScoring なら\n",
|
35 |
+
"# model = ModernBertForScoring.from_pretrained(MODEL_DIR)\n",
|
36 |
+
"\n",
|
37 |
+
"# もし学習時に ModernBertForSequenceClassification などを使ったなら(config.jsonを修正済み)\n",
|
38 |
+
"# from transformers import AutoModelForSequenceClassification\n",
|
39 |
+
"# model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)\n",
|
40 |
+
"\n",
|
41 |
+
"# 例:カスタムクラス ModernBertForScoring の場合\n",
|
42 |
+
"from train_jmtb_v6 import ModernBertForScoring\n",
|
43 |
+
"model = ModernBertForScoring.from_pretrained(MODEL_DIR)\n",
|
44 |
+
"\n",
|
45 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)\n",
|
46 |
+
"\n",
|
47 |
+
"# GPU利用する場合\n",
|
48 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
49 |
+
"model.to(device)\n",
|
50 |
+
"model.eval()\n"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "code",
|
55 |
+
"execution_count": null,
|
56 |
+
"metadata": {},
|
57 |
+
"outputs": [
|
58 |
+
{
|
59 |
+
"name": "stdout",
|
60 |
+
"output_type": "stream",
|
61 |
+
"text": [
|
62 |
+
"Predicted score: 0.3452\n"
|
63 |
+
]
|
64 |
+
}
|
65 |
+
],
|
66 |
+
"source": [
|
67 |
+
"def predict_score(text: str, model, tokenizer, device):\n",
|
68 |
+
" \"\"\"\n",
|
69 |
+
" 1つのテキストに対し、学習済みモデルで 0.0~1.0 の推定スコアを返す\n",
|
70 |
+
" (ModernBertForScoring で Sigmoidがかかっている想定)\n",
|
71 |
+
" \"\"\"\n",
|
72 |
+
" # トークナイズ\n",
|
73 |
+
" inputs = tokenizer(\n",
|
74 |
+
" text,\n",
|
75 |
+
" return_tensors=\"pt\",\n",
|
76 |
+
" truncation=True,\n",
|
77 |
+
" max_length=512\n",
|
78 |
+
" )\n",
|
79 |
+
" # GPUへ移動\n",
|
80 |
+
" inputs = {k: v.to(device) for k, v in inputs.items()}\n",
|
81 |
+
"\n",
|
82 |
+
" # 推論\n",
|
83 |
+
" with torch.no_grad():\n",
|
84 |
+
" outputs = model(**inputs)\n",
|
85 |
+
" # ModernBertForScoring なら outputs.logits が [batch_size,1]\n",
|
86 |
+
" score = outputs.logits.squeeze().item() # floatに変換\n",
|
87 |
+
"\n",
|
88 |
+
" return score\n",
|
89 |
+
"\n",
|
90 |
+
"# ------------------------\n",
|
91 |
+
"# 推論テスト\n",
|
92 |
+
"# ------------------------\n",
|
93 |
+
"example_text = \"これはテスト入力です。BERTに対するテストを行います。\"\n",
|
94 |
+
"pred_score = predict_score(example_text, model, tokenizer, device)\n",
|
95 |
+
"print(f\"Predicted score: {pred_score:.4f}\")\n"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"cell_type": "code",
|
100 |
+
"execution_count": null,
|
101 |
+
"metadata": {},
|
102 |
+
"outputs": [],
|
103 |
+
"source": [
|
104 |
+
"import pickle"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": null,
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"# 学習時に保存したデータセットpickle (floatラベル)\n",
|
114 |
+
"with open(r\"/media/kurogane/kioxia1/dataset/sss/pixiv/modernbert_jamt_finetune_ckpt_49/dataset_dict_float.pkl\", \"rb\") as file:\n",
|
115 |
+
" dataset_dict = pickle.load(file)\n",
|
116 |
+
"\n",
|
117 |
+
"# テストセットだけ取り出す (train/validation も必要なら適宜呼び出す)\n",
|
118 |
+
"test_dataset = dataset_dict[\"test\"]\n"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": null,
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [
|
126 |
+
{
|
127 |
+
"data": {
|
128 |
+
"text/plain": [
|
129 |
+
"Dataset({\n",
|
130 |
+
" features: ['input_text', 'label'],\n",
|
131 |
+
" num_rows: 648\n",
|
132 |
+
"})"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
"execution_count": 12,
|
136 |
+
"metadata": {},
|
137 |
+
"output_type": "execute_result"
|
138 |
+
}
|
139 |
+
],
|
140 |
+
"source": [
|
141 |
+
"test_dataset"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "code",
|
146 |
+
"execution_count": null,
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [],
|
149 |
+
"source": [
|
150 |
+
"from tqdm import tqdm"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": null,
|
156 |
+
"metadata": {},
|
157 |
+
"outputs": [
|
158 |
+
{
|
159 |
+
"name": "stderr",
|
160 |
+
"output_type": "stream",
|
161 |
+
"text": [
|
162 |
+
"100%|██████████| 648/648 [00:04<00:00, 133.80it/s]\n"
|
163 |
+
]
|
164 |
+
}
|
165 |
+
],
|
166 |
+
"source": [
|
167 |
+
"l_estimate_scores = []\n",
|
168 |
+
"for i_dataset in tqdm(test_dataset):\n",
|
169 |
+
" # print(i_dataset)\n",
|
170 |
+
" f_estimate_score = predict_score(i_dataset['input_text'], model, tokenizer, device)\n",
|
171 |
+
" l_estimate_scores.append([f_estimate_score, i_dataset[\"label\"]])"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "code",
|
176 |
+
"execution_count": null,
|
177 |
+
"metadata": {},
|
178 |
+
"outputs": [
|
179 |
+
{
|
180 |
+
"data": {
|
181 |
+
"text/plain": [
|
182 |
+
"[[0.9064586758613586, 0.9],\n",
|
183 |
+
" [0.8449122309684753, 0.9],\n",
|
184 |
+
" [0.929304838180542, 0.8],\n",
|
185 |
+
" [0.806448757648468, 0.9],\n",
|
186 |
+
" [0.6466315984725952, 0.6],\n",
|
187 |
+
" [0.9507829546928406, 1.0],\n",
|
188 |
+
" [0.8817955851554871, 1.0],\n",
|
189 |
+
" [0.9700656533241272, 0.9],\n",
|
190 |
+
" [0.8967279195785522, 1.0],\n",
|
191 |
+
" [0.6367055177688599, 0.6],\n",
|
192 |
+
" [0.7786707282066345, 0.9],\n",
|
193 |
+
" [0.8398993611335754, 0.9],\n",
|
194 |
+
" [0.8425216674804688, 0.9],\n",
|
195 |
+
" [0.8936999440193176, 1.0],\n",
|
196 |
+
" [0.8087979555130005, 0.9],\n",
|
197 |
+
" [0.8145586848258972, 0.8],\n",
|
198 |
+
" [0.6279298663139343, 0.4],\n",
|
199 |
+
" [0.4987145960330963, 0.8],\n",
|
200 |
+
" [0.696692705154419, 0.7],\n",
|
201 |
+
" [0.7956013083457947, 0.9],\n",
|
202 |
+
" [0.8720244765281677, 0.9],\n",
|
203 |
+
" [0.8167892694473267, 0.9],\n",
|
204 |
+
" [0.8600430488586426, 0.9],\n",
|
205 |
+
" [0.8366582989692688, 0.7],\n",
|
206 |
+
" [0.8482577800750732, 0.7],\n",
|
207 |
+
" [0.10592726618051529, 0.2],\n",
|
208 |
+
" [0.3639181852340698, 0.2],\n",
|
209 |
+
" [0.45103591680526733, 0.7],\n",
|
210 |
+
" [0.8230735659599304, 0.7],\n",
|
211 |
+
" [0.7876871824264526, 0.8],\n",
|
212 |
+
" [0.8766051530838013, 0.9],\n",
|
213 |
+
" [0.8099154233932495, 0.7],\n",
|
214 |
+
" [0.6839173436164856, 0.8],\n",
|
215 |
+
" [0.8837357759475708, 0.9],\n",
|
216 |
+
" [0.5957882404327393, 0.6],\n",
|
217 |
+
" [0.405498206615448, 0.6],\n",
|
218 |
+
" [0.8267595767974854, 0.9],\n",
|
219 |
+
" [0.9590301513671875, 1.0],\n",
|
220 |
+
" [0.7926787734031677, 0.7],\n",
|
221 |
+
" [0.5048006176948547, 0.2],\n",
|
222 |
+
" [0.872920036315918, 0.9],\n",
|
223 |
+
" [0.4801338016986847, 0.4],\n",
|
224 |
+
" [0.9707834720611572, 1.0],\n",
|
225 |
+
" [0.954249918460846, 0.9],\n",
|
226 |
+
" [0.6119499206542969, 0.8],\n",
|
227 |
+
" [0.804256796836853, 0.3],\n",
|
228 |
+
" [0.9629430174827576, 1.0],\n",
|
229 |
+
" [0.8675076365470886, 0.9],\n",
|
230 |
+
" [0.4841710329055786, 0.2],\n",
|
231 |
+
" [0.7352050542831421, 0.9],\n",
|
232 |
+
" [0.7698368430137634, 0.9],\n",
|
233 |
+
" [0.42692598700523376, 0.2],\n",
|
234 |
+
" [0.7776671051979065, 0.6],\n",
|
235 |
+
" [0.9430829882621765, 1.0],\n",
|
236 |
+
" [0.780847430229187, 0.9],\n",
|
237 |
+
" [0.9405631422996521, 1.0],\n",
|
238 |
+
" [0.254617303609848, 0.4],\n",
|
239 |
+
" [0.8624202013015747, 0.9],\n",
|
240 |
+
" [0.9356356263160706, 1.0],\n",
|
241 |
+
" [0.7910308241844177, 0.9],\n",
|
242 |
+
" [0.2423963099718094, 0.2],\n",
|
243 |
+
" [0.9045445919036865, 0.9],\n",
|
244 |
+
" [0.3448959290981293, 0.4],\n",
|
245 |
+
" [0.8975270390510559, 0.9],\n",
|
246 |
+
" [0.7061187028884888, 0.9],\n",
|
247 |
+
" [0.8589214086532593, 0.9],\n",
|
248 |
+
" [0.7566481232643127, 0.9],\n",
|
249 |
+
" [0.9401050209999084, 1.0],\n",
|
250 |
+
" [0.887227475643158, 0.9],\n",
|
251 |
+
" [0.7086873650550842, 0.9],\n",
|
252 |
+
" [0.9077960252761841, 0.9],\n",
|
253 |
+
" [0.9026407599449158, 1.0],\n",
|
254 |
+
" [0.935111939907074, 0.9],\n",
|
255 |
+
" [0.5277835130691528, 0.4],\n",
|
256 |
+
" [0.7517065405845642, 0.7],\n",
|
257 |
+
" [0.6940519213676453, 0.5],\n",
|
258 |
+
" [0.9113664031028748, 1.0],\n",
|
259 |
+
" [0.4126318097114563, 0.3],\n",
|
260 |
+
" [0.5240322947502136, 0.6],\n",
|
261 |
+
" [0.8750995397567749, 0.9],\n",
|
262 |
+
" [0.9469568729400635, 0.8],\n",
|
263 |
+
" [0.7899268269538879, 0.9],\n",
|
264 |
+
" [0.857871413230896, 0.2],\n",
|
265 |
+
" [0.7683762907981873, 0.7],\n",
|
266 |
+
" [0.8666701912879944, 0.9],\n",
|
267 |
+
" [0.902720034122467, 0.9],\n",
|
268 |
+
" [0.9435014128684998, 1.0],\n",
|
269 |
+
" [0.6808632612228394, 0.8],\n",
|
270 |
+
" [0.9126145839691162, 0.9],\n",
|
271 |
+
" [0.8799282908439636, 0.9],\n",
|
272 |
+
" [0.6882354021072388, 0.7],\n",
|
273 |
+
" [0.8309448957443237, 0.9],\n",
|
274 |
+
" [0.8704410195350647, 0.9],\n",
|
275 |
+
" [0.8138535022735596, 0.9],\n",
|
276 |
+
" [0.6686734557151794, 0.3],\n",
|
277 |
+
" [0.8925440907478333, 1.0],\n",
|
278 |
+
" [0.7934283018112183, 1.0],\n",
|
279 |
+
" [0.9107365012168884, 0.9],\n",
|
280 |
+
" [0.9745094180107117, 1.0],\n",
|
281 |
+
" [0.8090866804122925, 0.9],\n",
|
282 |
+
" [0.9362606406211853, 0.9],\n",
|
283 |
+
" [0.6617568135261536, 0.3],\n",
|
284 |
+
" [0.6281882524490356, 0.9],\n",
|
285 |
+
" [0.6575912833213806, 0.6],\n",
|
286 |
+
" [0.7039993405342102, 0.6],\n",
|
287 |
+
" [0.8477407097816467, 0.9],\n",
|
288 |
+
" [0.8910886645317078, 0.9],\n",
|
289 |
+
" [0.7563804388046265, 0.9],\n",
|
290 |
+
" [0.8112492561340332, 0.9],\n",
|
291 |
+
" [0.7291156053543091, 0.9],\n",
|
292 |
+
" [0.5929954051971436, 0.5],\n",
|
293 |
+
" [0.5142516493797302, 0.5],\n",
|
294 |
+
" [0.6867972016334534, 0.6],\n",
|
295 |
+
" [0.8761500120162964, 0.9],\n",
|
296 |
+
" [0.8619706034660339, 0.9],\n",
|
297 |
+
" [0.897497832775116, 0.9],\n",
|
298 |
+
" [0.8493452668190002, 0.8],\n",
|
299 |
+
" [0.8616324663162231, 0.9],\n",
|
300 |
+
" [0.6340180039405823, 0.4],\n",
|
301 |
+
" [0.7829850912094116, 0.9],\n",
|
302 |
+
" [0.6297580599784851, 0.4],\n",
|
303 |
+
" [0.8162065744400024, 0.9],\n",
|
304 |
+
" [0.7388235330581665, 0.4],\n",
|
305 |
+
" [0.7455839514732361, 0.9],\n",
|
306 |
+
" [0.8802245855331421, 0.6],\n",
|
307 |
+
" [0.7003363966941833, 0.9],\n",
|
308 |
+
" [0.5237756371498108, 0.2],\n",
|
309 |
+
" [0.8556636571884155, 0.6],\n",
|
310 |
+
" [0.851711094379425, 0.8],\n",
|
311 |
+
" [0.8817101716995239, 0.9],\n",
|
312 |
+
" [0.8661450743675232, 0.9],\n",
|
313 |
+
" [0.8317744135856628, 0.8],\n",
|
314 |
+
" [0.3223874866962433, 0.3],\n",
|
315 |
+
" [0.916279137134552, 0.9],\n",
|
316 |
+
" [0.8346007466316223, 0.5],\n",
|
317 |
+
" [0.8453168272972107, 0.9],\n",
|
318 |
+
" [0.37649181485176086, 0.3],\n",
|
319 |
+
" [0.6854564547538757, 0.8],\n",
|
320 |
+
" [0.7912370562553406, 0.9],\n",
|
321 |
+
" [0.38355275988578796, 0.9],\n",
|
322 |
+
" [0.7108463048934937, 0.7],\n",
|
323 |
+
" [0.8513278365135193, 0.9],\n",
|
324 |
+
" [0.9008965492248535, 0.9],\n",
|
325 |
+
" [0.1853475570678711, 0.2],\n",
|
326 |
+
" [0.5783144235610962, 0.5],\n",
|
327 |
+
" [0.7818315029144287, 0.9],\n",
|
328 |
+
" [0.7993879914283752, 0.7],\n",
|
329 |
+
" [0.7314165830612183, 0.8],\n",
|
330 |
+
" [0.9500613808631897, 1.0],\n",
|
331 |
+
" [0.8998763561248779, 1.0],\n",
|
332 |
+
" [0.38737180829048157, 0.4],\n",
|
333 |
+
" [0.8452264666557312, 0.9],\n",
|
334 |
+
" [0.25194141268730164, 0.3],\n",
|
335 |
+
" [0.9476278424263, 0.9],\n",
|
336 |
+
" [0.4460093379020691, 0.3],\n",
|
337 |
+
" [0.8978778719902039, 0.9],\n",
|
338 |
+
" [0.8573941588401794, 0.9],\n",
|
339 |
+
" [0.3037511706352234, 0.3],\n",
|
340 |
+
" [0.7195190787315369, 0.9],\n",
|
341 |
+
" [0.6808164119720459, 0.8],\n",
|
342 |
+
" [0.7646284103393555, 0.9],\n",
|
343 |
+
" [0.9012228846549988, 0.9],\n",
|
344 |
+
" [0.5082786083221436, 0.8],\n",
|
345 |
+
" [0.9199990034103394, 0.9],\n",
|
346 |
+
" [0.7429797053337097, 0.6],\n",
|
347 |
+
" [0.7855229377746582, 0.9],\n",
|
348 |
+
" [0.7403103709220886, 0.9],\n",
|
349 |
+
" [0.856158971786499, 0.9],\n",
|
350 |
+
" [0.7221283316612244, 0.9],\n",
|
351 |
+
" [0.8180127739906311, 0.9],\n",
|
352 |
+
" [0.8110374212265015, 0.6],\n",
|
353 |
+
" [0.8805463314056396, 0.9],\n",
|
354 |
+
" [0.8187531232833862, 0.8],\n",
|
355 |
+
" [0.6386672258377075, 0.6],\n",
|
356 |
+
" [0.9463333487510681, 1.0],\n",
|
357 |
+
" [0.8654801845550537, 0.9],\n",
|
358 |
+
" [0.9553059935569763, 0.9],\n",
|
359 |
+
" [0.7202808260917664, 0.4],\n",
|
360 |
+
" [0.596796452999115, 0.6],\n",
|
361 |
+
" [0.599234938621521, 0.2],\n",
|
362 |
+
" [0.8640603423118591, 0.9],\n",
|
363 |
+
" [0.8499320149421692, 0.7],\n",
|
364 |
+
" [0.8750359416007996, 0.9],\n",
|
365 |
+
" [0.922467827796936, 1.0],\n",
|
366 |
+
" [0.8759791851043701, 1.0],\n",
|
367 |
+
" [0.43951845169067383, 0.3],\n",
|
368 |
+
" [0.9501491189002991, 0.9],\n",
|
369 |
+
" [0.7858310341835022, 0.9],\n",
|
370 |
+
" [0.9279288053512573, 1.0],\n",
|
371 |
+
" [0.8105558753013611, 0.9],\n",
|
372 |
+
" [0.7309414148330688, 0.7],\n",
|
373 |
+
" [0.4521546959877014, 0.3],\n",
|
374 |
+
" [0.8569731116294861, 0.9],\n",
|
375 |
+
" [0.7542720437049866, 0.9],\n",
|
376 |
+
" [0.9578987956047058, 0.9],\n",
|
377 |
+
" [0.9457001090049744, 0.9],\n",
|
378 |
+
" [0.8531457781791687, 0.9],\n",
|
379 |
+
" [0.8666984438896179, 0.9],\n",
|
380 |
+
" [0.48565420508384705, 0.4],\n",
|
381 |
+
" [0.8775691390037537, 0.9],\n",
|
382 |
+
" [0.6819878220558167, 0.4],\n",
|
383 |
+
" [0.9245203137397766, 1.0],\n",
|
384 |
+
" [0.8452584147453308, 1.0],\n",
|
385 |
+
" [0.8809332251548767, 0.9],\n",
|
386 |
+
" [0.7760116457939148, 0.7],\n",
|
387 |
+
" [0.8173214197158813, 0.9],\n",
|
388 |
+
" [0.7378541827201843, 0.9],\n",
|
389 |
+
" [0.5877021551132202, 0.4],\n",
|
390 |
+
" [0.5508979558944702, 0.8],\n",
|
391 |
+
" [0.3678698241710663, 0.5],\n",
|
392 |
+
" [0.30494531989097595, 0.3],\n",
|
393 |
+
" [0.6908549070358276, 0.6],\n",
|
394 |
+
" [0.5437881946563721, 0.6],\n",
|
395 |
+
" [0.8356095552444458, 0.9],\n",
|
396 |
+
" [0.31034883856773376, 0.2],\n",
|
397 |
+
" [0.8924189805984497, 0.8],\n",
|
398 |
+
" [0.6236647963523865, 0.3],\n",
|
399 |
+
" [0.6277945637702942, 0.4],\n",
|
400 |
+
" [0.6978229880332947, 0.1],\n",
|
401 |
+
" [0.8123990893363953, 0.9],\n",
|
402 |
+
" [0.4208259880542755, 0.7],\n",
|
403 |
+
" [0.8291409611701965, 0.9],\n",
|
404 |
+
" [0.8832250237464905, 0.9],\n",
|
405 |
+
" [0.6538210511207581, 0.8],\n",
|
406 |
+
" [0.896472692489624, 0.9],\n",
|
407 |
+
" [0.6764245629310608, 0.4],\n",
|
408 |
+
" [0.8327236175537109, 0.9],\n",
|
409 |
+
" [0.8454877138137817, 0.9],\n",
|
410 |
+
" [0.8654239773750305, 0.8],\n",
|
411 |
+
" [0.6745596528053284, 0.6],\n",
|
412 |
+
" [0.7898547649383545, 0.8],\n",
|
413 |
+
" [0.6550565361976624, 0.4],\n",
|
414 |
+
" [0.6239812970161438, 0.8],\n",
|
415 |
+
" [0.9469243884086609, 0.9],\n",
|
416 |
+
" [0.9485745429992676, 0.9],\n",
|
417 |
+
" [0.6684531569480896, 0.6],\n",
|
418 |
+
" [0.9079251289367676, 0.9],\n",
|
419 |
+
" [0.7882359027862549, 0.6],\n",
|
420 |
+
" [0.7799747586250305, 0.9],\n",
|
421 |
+
" [0.7874063849449158, 0.9],\n",
|
422 |
+
" [0.8244850039482117, 0.9],\n",
|
423 |
+
" [0.6317123174667358, 0.6],\n",
|
424 |
+
" [0.8460860252380371, 0.9],\n",
|
425 |
+
" [0.8276510834693909, 0.9],\n",
|
426 |
+
" [0.38163939118385315, 0.9],\n",
|
427 |
+
" [0.9736513495445251, 1.0],\n",
|
428 |
+
" [0.8883947730064392, 0.9],\n",
|
429 |
+
" [0.7605443596839905, 0.8],\n",
|
430 |
+
" [0.19729329645633698, 0.2],\n",
|
431 |
+
" [0.88736891746521, 1.0],\n",
|
432 |
+
" [0.862339198589325, 0.9],\n",
|
433 |
+
" [0.7687414884567261, 0.9],\n",
|
434 |
+
" [0.7632433176040649, 0.6],\n",
|
435 |
+
" [0.20476382970809937, 0.2],\n",
|
436 |
+
" [0.31666404008865356, 0.9],\n",
|
437 |
+
" [0.8854409456253052, 0.9],\n",
|
438 |
+
" [0.28262194991111755, 0.4],\n",
|
439 |
+
" [0.8240434527397156, 0.1],\n",
|
440 |
+
" [0.8445137143135071, 0.9],\n",
|
441 |
+
" [0.5455150604248047, 0.3],\n",
|
442 |
+
" [0.9618996977806091, 1.0],\n",
|
443 |
+
" [0.8494833111763, 0.8],\n",
|
444 |
+
" [0.4823213815689087, 0.4],\n",
|
445 |
+
" [0.7849555611610413, 0.6],\n",
|
446 |
+
" [0.6141435503959656, 0.7],\n",
|
447 |
+
" [0.7253923416137695, 0.9],\n",
|
448 |
+
" [0.7148001790046692, 0.9],\n",
|
449 |
+
" [0.929366946220398, 0.9],\n",
|
450 |
+
" [0.3592156171798706, 0.7],\n",
|
451 |
+
" [0.3085547983646393, 0.2],\n",
|
452 |
+
" [0.770656943321228, 0.9],\n",
|
453 |
+
" [0.8839257955551147, 0.9],\n",
|
454 |
+
" [0.8835964202880859, 0.9],\n",
|
455 |
+
" [0.3086932301521301, 0.3],\n",
|
456 |
+
" [0.644216775894165, 0.4],\n",
|
457 |
+
" [0.7603057622909546, 0.8],\n",
|
458 |
+
" [0.47372305393218994, 0.3],\n",
|
459 |
+
" [0.8266362547874451, 0.9],\n",
|
460 |
+
" [0.8671748638153076, 0.9],\n",
|
461 |
+
" [0.7935934662818909, 0.9],\n",
|
462 |
+
" [0.338331937789917, 0.4],\n",
|
463 |
+
" [0.5553470253944397, 0.4],\n",
|
464 |
+
" [0.7325969934463501, 0.8],\n",
|
465 |
+
" [0.9176349639892578, 0.9],\n",
|
466 |
+
" [0.5863208770751953, 0.5],\n",
|
467 |
+
" [0.8673837780952454, 0.9],\n",
|
468 |
+
" [0.8770381808280945, 0.9],\n",
|
469 |
+
" [0.6373818516731262, 0.8],\n",
|
470 |
+
" [0.6105970144271851, 0.6],\n",
|
471 |
+
" [0.9128532409667969, 0.9],\n",
|
472 |
+
" [0.6021369099617004, 0.5],\n",
|
473 |
+
" [0.6904911994934082, 0.8],\n",
|
474 |
+
" [0.8588377833366394, 0.9],\n",
|
475 |
+
" [0.4375683069229126, 0.6],\n",
|
476 |
+
" [0.8753483891487122, 0.9],\n",
|
477 |
+
" [0.8913830518722534, 1.0],\n",
|
478 |
+
" [0.7222169637680054, 0.7],\n",
|
479 |
+
" [0.7359307408332825, 0.6],\n",
|
480 |
+
" [0.8244432806968689, 0.9],\n",
|
481 |
+
" [0.6900085210800171, 0.8],\n",
|
482 |
+
" [0.2715027630329132, 0.1],\n",
|
483 |
+
" [0.6896530389785767, 0.6],\n",
|
484 |
+
" [0.6765221953392029, 0.6],\n",
|
485 |
+
" [0.3277732729911804, 0.4],\n",
|
486 |
+
" [0.4515093266963959, 0.4],\n",
|
487 |
+
" [0.8928239941596985, 0.9],\n",
|
488 |
+
" [0.5652311444282532, 0.7],\n",
|
489 |
+
" [0.4977130591869354, 0.2],\n",
|
490 |
+
" [0.74165278673172, 0.9],\n",
|
491 |
+
" [0.48645636439323425, 0.7],\n",
|
492 |
+
" [0.8301733732223511, 0.9],\n",
|
493 |
+
" [0.46485790610313416, 0.6],\n",
|
494 |
+
" [0.9069660902023315, 1.0],\n",
|
495 |
+
" [0.6526173949241638, 0.5],\n",
|
496 |
+
" [0.22337760031223297, 0.1],\n",
|
497 |
+
" [0.8109521865844727, 0.9],\n",
|
498 |
+
" [0.2853657305240631, 0.2],\n",
|
499 |
+
" [0.8568928241729736, 0.9],\n",
|
500 |
+
" [0.5527607202529907, 0.6],\n",
|
501 |
+
" [0.8812926411628723, 0.9],\n",
|
502 |
+
" [0.7154238224029541, 0.6],\n",
|
503 |
+
" [0.9051880836486816, 0.9],\n",
|
504 |
+
" [0.5803526043891907, 0.8],\n",
|
505 |
+
" [0.7091109156608582, 0.8],\n",
|
506 |
+
" [0.5601979494094849, 0.7],\n",
|
507 |
+
" [0.787548840045929, 0.7],\n",
|
508 |
+
" [0.7948053479194641, 0.8],\n",
|
509 |
+
" [0.9312030076980591, 0.9],\n",
|
510 |
+
" [0.8789415955543518, 1.0],\n",
|
511 |
+
" [0.9068158864974976, 1.0],\n",
|
512 |
+
" [0.8658299446105957, 0.9],\n",
|
513 |
+
" [0.9198936820030212, 1.0],\n",
|
514 |
+
" [0.6551686525344849, 0.9],\n",
|
515 |
+
" [0.6174558401107788, 0.3],\n",
|
516 |
+
" [0.8762447237968445, 0.8],\n",
|
517 |
+
" [0.8365645408630371, 0.9],\n",
|
518 |
+
" [0.1843896359205246, 0.1],\n",
|
519 |
+
" [0.583404541015625, 0.9],\n",
|
520 |
+
" [0.8519049882888794, 0.8],\n",
|
521 |
+
" [0.6710367798805237, 0.8],\n",
|
522 |
+
" [0.4004596769809723, 0.3],\n",
|
523 |
+
" [0.9558364748954773, 0.9],\n",
|
524 |
+
" [0.8146979212760925, 0.8],\n",
|
525 |
+
" [0.9368678331375122, 0.9],\n",
|
526 |
+
" [0.9128404259681702, 0.9],\n",
|
527 |
+
" [0.8924294114112854, 0.9],\n",
|
528 |
+
" [0.8706570863723755, 0.9],\n",
|
529 |
+
" [0.36182519793510437, 0.2],\n",
|
530 |
+
" [0.8756670951843262, 0.9],\n",
|
531 |
+
" [0.5055785179138184, 0.3],\n",
|
532 |
+
" [0.7487927079200745, 0.9],\n",
|
533 |
+
" [0.9558643102645874, 1.0],\n",
|
534 |
+
" [0.5944591760635376, 0.6],\n",
|
535 |
+
" [0.6496614813804626, 0.6],\n",
|
536 |
+
" [0.891505241394043, 0.9],\n",
|
537 |
+
" [0.6592487096786499, 0.6],\n",
|
538 |
+
" [0.7970435619354248, 1.0],\n",
|
539 |
+
" [0.7491934299468994, 0.4],\n",
|
540 |
+
" [0.5845210552215576, 0.9],\n",
|
541 |
+
" [0.7628567814826965, 0.9],\n",
|
542 |
+
" [0.40675088763237, 0.5],\n",
|
543 |
+
" [0.627162754535675, 0.9],\n",
|
544 |
+
" [0.5906123518943787, 0.8],\n",
|
545 |
+
" [0.6509009003639221, 0.6],\n",
|
546 |
+
" [0.9104874134063721, 1.0],\n",
|
547 |
+
" [0.8778848648071289, 0.8],\n",
|
548 |
+
" [0.7858289480209351, 0.7],\n",
|
549 |
+
" [0.9646210670471191, 1.0],\n",
|
550 |
+
" [0.49148932099342346, 0.5],\n",
|
551 |
+
" [0.5657476186752319, 0.5],\n",
|
552 |
+
" [0.7989112138748169, 0.5],\n",
|
553 |
+
" [0.896877646446228, 0.9],\n",
|
554 |
+
" [0.8994553089141846, 0.9],\n",
|
555 |
+
" [0.8644108176231384, 0.9],\n",
|
556 |
+
" [0.5436504483222961, 0.3],\n",
|
557 |
+
" [0.38367825746536255, 0.2],\n",
|
558 |
+
" [0.35513395071029663, 0.3],\n",
|
559 |
+
" [0.9275620579719543, 0.9],\n",
|
560 |
+
" [0.854905903339386, 0.8],\n",
|
561 |
+
" [0.5229591727256775, 0.8],\n",
|
562 |
+
" [0.8073667287826538, 0.9],\n",
|
563 |
+
" [0.7266579866409302, 0.6],\n",
|
564 |
+
" [0.23632675409317017, 0.1],\n",
|
565 |
+
" [0.552478551864624, 0.9],\n",
|
566 |
+
" [0.8053351640701294, 0.9],\n",
|
567 |
+
" [0.850672721862793, 0.9],\n",
|
568 |
+
" [0.9100931286811829, 0.9],\n",
|
569 |
+
" [0.8568122982978821, 0.9],\n",
|
570 |
+
" [0.6421248912811279, 0.9],\n",
|
571 |
+
" [0.5956704020500183, 0.3],\n",
|
572 |
+
" [0.3317554295063019, 0.3],\n",
|
573 |
+
" [0.927498996257782, 0.9],\n",
|
574 |
+
" [0.8942874073982239, 1.0],\n",
|
575 |
+
" [0.9104828238487244, 1.0],\n",
|
576 |
+
" [0.37761199474334717, 0.2],\n",
|
577 |
+
" [0.7857874631881714, 0.7],\n",
|
578 |
+
" [0.8570524454116821, 0.3],\n",
|
579 |
+
" [0.8882994651794434, 0.9],\n",
|
580 |
+
" [0.9283419251441956, 0.9],\n",
|
581 |
+
" [0.8294586539268494, 0.9],\n",
|
582 |
+
" [0.3736439049243927, 0.5],\n",
|
583 |
+
" [0.6581687331199646, 0.7],\n",
|
584 |
+
" [0.8052690029144287, 0.9],\n",
|
585 |
+
" [0.8928396701812744, 0.9],\n",
|
586 |
+
" [0.6559609174728394, 0.8],\n",
|
587 |
+
" [0.870569109916687, 0.8],\n",
|
588 |
+
" [0.3797019422054291, 0.4],\n",
|
589 |
+
" [0.8790174126625061, 0.6],\n",
|
590 |
+
" [0.573027491569519, 0.3],\n",
|
591 |
+
" [0.8363456726074219, 0.9],\n",
|
592 |
+
" [0.6144676804542542, 0.7],\n",
|
593 |
+
" [0.8835358023643494, 0.9],\n",
|
594 |
+
" [0.7157717943191528, 0.9],\n",
|
595 |
+
" [0.7214363217353821, 0.1],\n",
|
596 |
+
" [0.7688565850257874, 0.8],\n",
|
597 |
+
" [0.6583333015441895, 0.5],\n",
|
598 |
+
" [0.7756986021995544, 1.0],\n",
|
599 |
+
" [0.18134945631027222, 0.1],\n",
|
600 |
+
" [0.3336744010448456, 0.3],\n",
|
601 |
+
" [0.7706341743469238, 0.9],\n",
|
602 |
+
" [0.734782874584198, 0.9],\n",
|
603 |
+
" [0.9471049308776855, 0.9],\n",
|
604 |
+
" [0.6686676144599915, 0.9],\n",
|
605 |
+
" [0.872651994228363, 0.9],\n",
|
606 |
+
" [0.6990708708763123, 0.6],\n",
|
607 |
+
" [0.4532737135887146, 0.4],\n",
|
608 |
+
" [0.5959187150001526, 0.4],\n",
|
609 |
+
" [0.9041457176208496, 0.9],\n",
|
610 |
+
" [0.9407055377960205, 0.9],\n",
|
611 |
+
" [0.9118932485580444, 0.9],\n",
|
612 |
+
" [0.5548721551895142, 0.2],\n",
|
613 |
+
" [0.9288930892944336, 0.9],\n",
|
614 |
+
" [0.48166006803512573, 0.3],\n",
|
615 |
+
" [0.8659979701042175, 0.9],\n",
|
616 |
+
" [0.7878676652908325, 0.8],\n",
|
617 |
+
" [0.9107018709182739, 1.0],\n",
|
618 |
+
" [0.8593129515647888, 0.8],\n",
|
619 |
+
" [0.6023291945457458, 0.4],\n",
|
620 |
+
" [0.8151740431785583, 0.6],\n",
|
621 |
+
" [0.9689931869506836, 1.0],\n",
|
622 |
+
" [0.32890671491622925, 0.4],\n",
|
623 |
+
" [0.25132861733436584, 0.2],\n",
|
624 |
+
" [0.8355442881584167, 0.9],\n",
|
625 |
+
" [0.5196486711502075, 0.3],\n",
|
626 |
+
" [0.5570502877235413, 0.9],\n",
|
627 |
+
" [0.1721695214509964, 0.2],\n",
|
628 |
+
" [0.9613489508628845, 0.9],\n",
|
629 |
+
" [0.8573195934295654, 0.9],\n",
|
630 |
+
" [0.7359338402748108, 0.9],\n",
|
631 |
+
" [0.6692647933959961, 0.2],\n",
|
632 |
+
" [0.7032365798950195, 0.5],\n",
|
633 |
+
" [0.7604613304138184, 0.8],\n",
|
634 |
+
" [0.3672597110271454, 0.6],\n",
|
635 |
+
" [0.8801596760749817, 1.0],\n",
|
636 |
+
" [0.8825082182884216, 0.9],\n",
|
637 |
+
" [0.9223929643630981, 0.9],\n",
|
638 |
+
" [0.2902374267578125, 0.2],\n",
|
639 |
+
" [0.6596561074256897, 0.9],\n",
|
640 |
+
" [0.5656098127365112, 0.6],\n",
|
641 |
+
" [0.5993040800094604, 0.3],\n",
|
642 |
+
" [0.326254278421402, 0.4],\n",
|
643 |
+
" [0.7709024548530579, 0.9],\n",
|
644 |
+
" [0.56084805727005, 0.8],\n",
|
645 |
+
" [0.8905344009399414, 0.9],\n",
|
646 |
+
" [0.6955621838569641, 0.7],\n",
|
647 |
+
" [0.6527406573295593, 0.6],\n",
|
648 |
+
" [0.8516212105751038, 0.9],\n",
|
649 |
+
" [0.6509021520614624, 0.4],\n",
|
650 |
+
" [0.849424421787262, 0.6],\n",
|
651 |
+
" [0.7331821322441101, 0.9],\n",
|
652 |
+
" [0.3355243504047394, 0.3],\n",
|
653 |
+
" [0.9231580495834351, 1.0],\n",
|
654 |
+
" [0.6610034108161926, 0.7],\n",
|
655 |
+
" [0.18266692757606506, 0.3],\n",
|
656 |
+
" [0.7420765161514282, 0.9],\n",
|
657 |
+
" [0.8482139110565186, 0.7],\n",
|
658 |
+
" [0.7824781537055969, 0.7],\n",
|
659 |
+
" [0.7808020710945129, 0.9],\n",
|
660 |
+
" [0.5461168885231018, 0.3],\n",
|
661 |
+
" [0.7221590280532837, 0.9],\n",
|
662 |
+
" [0.6944252848625183, 0.8],\n",
|
663 |
+
" [0.7390590310096741, 0.9],\n",
|
664 |
+
" [0.42653611302375793, 0.4],\n",
|
665 |
+
" [0.8902081847190857, 1.0],\n",
|
666 |
+
" [0.9005285501480103, 0.9],\n",
|
667 |
+
" [0.9097873568534851, 0.9],\n",
|
668 |
+
" [0.21993842720985413, 0.2],\n",
|
669 |
+
" [0.421650767326355, 0.6],\n",
|
670 |
+
" [0.8794265389442444, 0.9],\n",
|
671 |
+
" [0.9476831555366516, 0.9],\n",
|
672 |
+
" [0.7081140875816345, 0.7],\n",
|
673 |
+
" [0.8922443985939026, 0.9],\n",
|
674 |
+
" [0.3995065987110138, 0.9],\n",
|
675 |
+
" [0.8627650737762451, 0.9],\n",
|
676 |
+
" [0.9611384868621826, 0.9],\n",
|
677 |
+
" [0.8823320865631104, 1.0],\n",
|
678 |
+
" [0.7572306990623474, 0.9],\n",
|
679 |
+
" [0.8725445866584778, 0.9],\n",
|
680 |
+
" [0.9449562430381775, 0.9],\n",
|
681 |
+
" [0.7717506885528564, 0.9],\n",
|
682 |
+
" [0.15474310517311096, 0.1],\n",
|
683 |
+
" [0.9121803045272827, 0.9],\n",
|
684 |
+
" [0.8079653382301331, 0.9],\n",
|
685 |
+
" [0.835227370262146, 0.9],\n",
|
686 |
+
" [0.8131006956100464, 0.9],\n",
|
687 |
+
" [0.8179431557655334, 0.7],\n",
|
688 |
+
" [0.9555372595787048, 0.9],\n",
|
689 |
+
" [0.8693034648895264, 0.9],\n",
|
690 |
+
" [0.8599344491958618, 0.9],\n",
|
691 |
+
" [0.7984340190887451, 1.0],\n",
|
692 |
+
" [0.7487307190895081, 0.8],\n",
|
693 |
+
" [0.9249837398529053, 0.9],\n",
|
694 |
+
" [0.7589347958564758, 0.9],\n",
|
695 |
+
" [0.3615719974040985, 0.2],\n",
|
696 |
+
" [0.3107086420059204, 0.4],\n",
|
697 |
+
" [0.7213683128356934, 0.7],\n",
|
698 |
+
" [0.8425479531288147, 0.9],\n",
|
699 |
+
" [0.7714840769767761, 0.4],\n",
|
700 |
+
" [0.8188011646270752, 0.8],\n",
|
701 |
+
" [0.7286155819892883, 0.6],\n",
|
702 |
+
" [0.4045146107673645, 0.2],\n",
|
703 |
+
" [0.6047667264938354, 0.4],\n",
|
704 |
+
" [0.6913502812385559, 0.4],\n",
|
705 |
+
" [0.6467778086662292, 0.5],\n",
|
706 |
+
" [0.3444978594779968, 0.2],\n",
|
707 |
+
" [0.8719446659088135, 0.9],\n",
|
708 |
+
" [0.7965179085731506, 0.9],\n",
|
709 |
+
" [0.7913227081298828, 0.9],\n",
|
710 |
+
" [0.9410687685012817, 1.0],\n",
|
711 |
+
" [0.44169095158576965, 0.3],\n",
|
712 |
+
" [0.8851080536842346, 0.9],\n",
|
713 |
+
" [0.8913792371749878, 0.9],\n",
|
714 |
+
" [0.8524451851844788, 0.9],\n",
|
715 |
+
" [0.8013086915016174, 0.9],\n",
|
716 |
+
" [0.8113997578620911, 0.8],\n",
|
717 |
+
" [0.8060635328292847, 0.8],\n",
|
718 |
+
" [0.32350343465805054, 0.3],\n",
|
719 |
+
" [0.8984023332595825, 0.9],\n",
|
720 |
+
" [0.587394654750824, 0.6],\n",
|
721 |
+
" [0.5370021462440491, 0.4],\n",
|
722 |
+
" [0.942569375038147, 0.9],\n",
|
723 |
+
" [0.8009219169616699, 0.7],\n",
|
724 |
+
" [0.896619975566864, 0.9],\n",
|
725 |
+
" [0.8658144474029541, 0.9],\n",
|
726 |
+
" [0.7016218900680542, 0.4],\n",
|
727 |
+
" [0.8604831099510193, 0.8],\n",
|
728 |
+
" [0.8699275851249695, 0.9],\n",
|
729 |
+
" [0.45879390835762024, 0.4],\n",
|
730 |
+
" [0.5645806193351746, 0.4],\n",
|
731 |
+
" [0.9470452666282654, 0.9],\n",
|
732 |
+
" [0.870012640953064, 0.9],\n",
|
733 |
+
" [0.8051838278770447, 0.9],\n",
|
734 |
+
" [0.8563840389251709, 0.9],\n",
|
735 |
+
" [0.9484373927116394, 1.0],\n",
|
736 |
+
" [0.9129024147987366, 0.9],\n",
|
737 |
+
" [0.9109873175621033, 0.9],\n",
|
738 |
+
" [0.7702903747558594, 0.9],\n",
|
739 |
+
" [0.23435641825199127, 0.3],\n",
|
740 |
+
" [0.773851215839386, 0.9],\n",
|
741 |
+
" [0.8853207230567932, 0.9],\n",
|
742 |
+
" [0.19917032122612, 0.2],\n",
|
743 |
+
" [0.9677890539169312, 0.9],\n",
|
744 |
+
" [0.636743426322937, 0.4],\n",
|
745 |
+
" [0.6735020875930786, 0.9],\n",
|
746 |
+
" [0.16387033462524414, 0.2],\n",
|
747 |
+
" [0.8149173855781555, 0.8],\n",
|
748 |
+
" [0.7951643466949463, 0.9],\n",
|
749 |
+
" [0.8490980267524719, 0.9],\n",
|
750 |
+
" [0.8448660373687744, 0.9],\n",
|
751 |
+
" [0.7506303787231445, 0.3],\n",
|
752 |
+
" [0.8545787930488586, 0.8],\n",
|
753 |
+
" [0.532014012336731, 0.9],\n",
|
754 |
+
" [0.9296807646751404, 0.1],\n",
|
755 |
+
" [0.8303433656692505, 0.9],\n",
|
756 |
+
" [0.9027467370033264, 0.8],\n",
|
757 |
+
" [0.7876701354980469, 0.9],\n",
|
758 |
+
" [0.4417012333869934, 0.5],\n",
|
759 |
+
" [0.6421471238136292, 0.4],\n",
|
760 |
+
" [0.8930367827415466, 0.9],\n",
|
761 |
+
" [0.7860907912254333, 0.4],\n",
|
762 |
+
" [0.6721752285957336, 0.6],\n",
|
763 |
+
" [0.8049106597900391, 0.9],\n",
|
764 |
+
" [0.8319634199142456, 0.9],\n",
|
765 |
+
" [0.9733098149299622, 0.9],\n",
|
766 |
+
" [0.7663295269012451, 0.9],\n",
|
767 |
+
" [0.9548745155334473, 1.0],\n",
|
768 |
+
" [0.28572025895118713, 0.2],\n",
|
769 |
+
" [0.29264578223228455, 0.2],\n",
|
770 |
+
" [0.08858926594257355, 0.1],\n",
|
771 |
+
" [0.9195079207420349, 0.9],\n",
|
772 |
+
" [0.9017946124076843, 0.9],\n",
|
773 |
+
" [0.8725588321685791, 1.0],\n",
|
774 |
+
" [0.5177860856056213, 0.6],\n",
|
775 |
+
" [0.6396905183792114, 0.8],\n",
|
776 |
+
" [0.8232091069221497, 0.6],\n",
|
777 |
+
" [0.4722677171230316, 0.3],\n",
|
778 |
+
" [0.3547070622444153, 0.6],\n",
|
779 |
+
" [0.5229023098945618, 0.6],\n",
|
780 |
+
" [0.980872392654419, 1.0],\n",
|
781 |
+
" [0.9095045924186707, 1.0],\n",
|
782 |
+
" [0.8521897196769714, 0.9],\n",
|
783 |
+
" [0.9635071158409119, 1.0],\n",
|
784 |
+
" [0.892997145652771, 0.9],\n",
|
785 |
+
" [0.4399847388267517, 0.4],\n",
|
786 |
+
" [0.840275764465332, 0.9],\n",
|
787 |
+
" [0.28466078639030457, 0.8],\n",
|
788 |
+
" [0.9222121834754944, 1.0],\n",
|
789 |
+
" [0.8009138107299805, 0.9],\n",
|
790 |
+
" [0.4688073396682739, 0.3],\n",
|
791 |
+
" [0.788908839225769, 0.5],\n",
|
792 |
+
" [0.4609881043434143, 0.3],\n",
|
793 |
+
" [0.2563250660896301, 0.2],\n",
|
794 |
+
" [0.863552451133728, 0.9],\n",
|
795 |
+
" [0.9009376764297485, 1.0],\n",
|
796 |
+
" [0.8950297236442566, 0.9],\n",
|
797 |
+
" [0.7619693279266357, 0.8],\n",
|
798 |
+
" [0.9539045691490173, 0.9],\n",
|
799 |
+
" [0.857700526714325, 0.9],\n",
|
800 |
+
" [0.917656660079956, 0.9],\n",
|
801 |
+
" [0.4197356402873993, 0.1],\n",
|
802 |
+
" [0.8468145728111267, 0.9],\n",
|
803 |
+
" [0.8413441777229309, 0.6],\n",
|
804 |
+
" [0.8770924806594849, 0.9],\n",
|
805 |
+
" [0.7613767385482788, 0.7],\n",
|
806 |
+
" [0.5931036472320557, 0.7],\n",
|
807 |
+
" [0.7604084014892578, 0.7],\n",
|
808 |
+
" [0.9281649589538574, 1.0],\n",
|
809 |
+
" [0.38664042949676514, 0.3],\n",
|
810 |
+
" [0.9006865620613098, 1.0],\n",
|
811 |
+
" [0.8754125833511353, 0.9],\n",
|
812 |
+
" [0.8797391057014465, 0.9],\n",
|
813 |
+
" [0.7036916613578796, 0.8],\n",
|
814 |
+
" [0.9311502575874329, 0.9],\n",
|
815 |
+
" [0.6805518269538879, 0.8],\n",
|
816 |
+
" [0.7984088063240051, 0.9],\n",
|
817 |
+
" [0.8592762351036072, 0.9],\n",
|
818 |
+
" [0.7293879389762878, 0.8],\n",
|
819 |
+
" [0.7824617624282837, 0.9],\n",
|
820 |
+
" [0.866423487663269, 0.9],\n",
|
821 |
+
" [0.6669572591781616, 0.7],\n",
|
822 |
+
" [0.8584144711494446, 0.9],\n",
|
823 |
+
" [0.1908380538225174, 0.2],\n",
|
824 |
+
" [0.7461979389190674, 0.6],\n",
|
825 |
+
" [0.8193972706794739, 0.9],\n",
|
826 |
+
" [0.7538160085678101, 0.9],\n",
|
827 |
+
" [0.45426538586616516, 0.2],\n",
|
828 |
+
" [0.4462665617465973, 0.4],\n",
|
829 |
+
" [0.8647792935371399, 1.0]]"
|
830 |
+
]
|
831 |
+
},
|
832 |
+
"execution_count": 15,
|
833 |
+
"metadata": {},
|
834 |
+
"output_type": "execute_result"
|
835 |
+
}
|
836 |
+
],
|
837 |
+
"source": [
|
838 |
+
"l_estimate_scores"
|
839 |
+
]
|
840 |
+
},
|
841 |
+
{
|
842 |
+
"cell_type": "code",
|
843 |
+
"execution_count": 2,
|
844 |
+
"metadata": {},
|
845 |
+
"outputs": [
|
846 |
+
{
|
847 |
+
"ename": "NameError",
|
848 |
+
"evalue": "name 'l_estimate_scores' is not defined",
|
849 |
+
"output_type": "error",
|
850 |
+
"traceback": [
|
851 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
852 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
853 |
+
"Cell \u001b[0;32mIn[2], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m mean_squared_error, mean_absolute_error, r2_score\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# 予測値(predicted)、実際値(actual)に分割\u001b[39;00m\n\u001b[0;32m----> 7\u001b[0m predicted \u001b[38;5;241m=\u001b[39m [x[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m \u001b[43ml_estimate_scores\u001b[49m]\n\u001b[1;32m 8\u001b[0m actual \u001b[38;5;241m=\u001b[39m [x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m l_estimate_scores]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# --- 評価指標の計算 ---\u001b[39;00m\n",
|
854 |
+
"\u001b[0;31mNameError\u001b[0m: name 'l_estimate_scores' is not defined"
|
855 |
+
]
|
856 |
+
}
|
857 |
+
],
|
858 |
+
"source": [
|
859 |
+
"import numpy as np\n",
|
860 |
+
"import matplotlib.pyplot as plt\n",
|
861 |
+
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
|
862 |
+
"\n",
|
863 |
+
"\n",
|
864 |
+
"# 予測値(predicted)、実際値(actual)に分割\n",
|
865 |
+
"predicted = [x[0] for x in l_estimate_scores]\n",
|
866 |
+
"actual = [x[1] for x in l_estimate_scores]\n",
|
867 |
+
"\n",
|
868 |
+
"# --- 評価指標の計算 ---\n",
|
869 |
+
"mse = mean_squared_error(actual, predicted)\n",
|
870 |
+
"rmse = np.sqrt(mse)\n",
|
871 |
+
"mae = mean_absolute_error(actual, predicted)\n",
|
872 |
+
"r2 = r2_score(actual, predicted)\n",
|
873 |
+
"\n",
|
874 |
+
"print(\"MSE :\", mse)\n",
|
875 |
+
"print(\"RMSE:\", rmse)\n",
|
876 |
+
"print(\"MAE :\", mae)\n",
|
877 |
+
"print(\"R^2 :\", r2)\n",
|
878 |
+
"\n",
|
879 |
+
"# --- 散布図 (Predicted vs Actual) ---\n",
|
880 |
+
"plt.figure(figsize=(5, 5))\n",
|
881 |
+
"plt.scatter(actual, predicted, color='blue', label='Data Points')\n",
|
882 |
+
"# y = x の目安線\n",
|
883 |
+
"plt.plot([0, 1], [0, 1], 'r--', label='Ideal line (y=x)')\n",
|
884 |
+
"plt.xlabel('Actual')\n",
|
885 |
+
"plt.ylabel('Predicted')\n",
|
886 |
+
"plt.title('Predicted vs Actual')\n",
|
887 |
+
"plt.legend()\n",
|
888 |
+
"plt.show()\n",
|
889 |
+
"\n",
|
890 |
+
"# --- 残差プロット (Residual plot) ---\n",
|
891 |
+
"residuals = [p - a for p, a in zip(predicted, actual)]\n",
|
892 |
+
"\n",
|
893 |
+
"plt.figure(figsize=(5, 5))\n",
|
894 |
+
"plt.scatter(actual, residuals, color='green')\n",
|
895 |
+
"plt.axhline(0, color='red', linestyle='--') # 残差が0となるライン\n",
|
896 |
+
"plt.xlabel('Actual')\n",
|
897 |
+
"plt.ylabel('Residual (Predicted - Actual)')\n",
|
898 |
+
"plt.title('Residual Plot')\n",
|
899 |
+
"plt.show()\n",
|
900 |
+
"\n",
|
901 |
+
"# --- サンプルごとのバー比較 ---\n",
|
902 |
+
"indices = range(len(actual))\n",
|
903 |
+
"bar_width = 0.4\n",
|
904 |
+
"\n",
|
905 |
+
"plt.figure(figsize=(8, 5))\n",
|
906 |
+
"plt.bar(indices, actual, width=bar_width, label='Actual', alpha=0.7)\n",
|
907 |
+
"plt.bar([i + bar_width for i in indices], predicted, width=bar_width, label='Predicted', alpha=0.7)\n",
|
908 |
+
"\n",
|
909 |
+
"plt.xlabel('Sample Index')\n",
|
910 |
+
"plt.ylabel('Score')\n",
|
911 |
+
"plt.title('Actual vs Predicted')\n",
|
912 |
+
"plt.xticks([i + bar_width/2 for i in indices], indices) # 棒の中央にインデックスを合わせる\n",
|
913 |
+
"plt.legend()\n",
|
914 |
+
"plt.tight_layout()\n",
|
915 |
+
"plt.show()\n"
|
916 |
+
]
|
917 |
+
},
|
918 |
+
{
|
919 |
+
"cell_type": "code",
|
920 |
+
"execution_count": null,
|
921 |
+
"metadata": {},
|
922 |
+
"outputs": [],
|
923 |
+
"source": []
|
924 |
+
}
|
925 |
+
],
|
926 |
+
"metadata": {
|
927 |
+
"kernelspec": {
|
928 |
+
"display_name": "vllmtest",
|
929 |
+
"language": "python",
|
930 |
+
"name": "python3"
|
931 |
+
},
|
932 |
+
"language_info": {
|
933 |
+
"codemirror_mode": {
|
934 |
+
"name": "ipython",
|
935 |
+
"version": 3
|
936 |
+
},
|
937 |
+
"file_extension": ".py",
|
938 |
+
"mimetype": "text/x-python",
|
939 |
+
"name": "python",
|
940 |
+
"nbconvert_exporter": "python",
|
941 |
+
"pygments_lexer": "ipython3",
|
942 |
+
"version": "3.12.4"
|
943 |
+
}
|
944 |
+
},
|
945 |
+
"nbformat": 4,
|
946 |
+
"nbformat_minor": 2
|
947 |
+
}
|
test_check.png
ADDED
![]() |
train_jmtb_test_v6 (コピー).ipynb
ADDED
@@ -0,0 +1,853 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"#!/usr/bin/env python\n",
|
10 |
+
"# -*- coding: utf-8 -*-\n",
|
11 |
+
"\n",
|
12 |
+
"\"\"\"\n",
|
13 |
+
"Sample script to finetune ModernBERT-Ja-130M on Japanese MT-bench (0~10 discrete scores).\n",
|
14 |
+
"\"\"\"\n",
|
15 |
+
"\n",
|
16 |
+
"import os\n",
|
17 |
+
"import gc\n",
|
18 |
+
"import re\n",
|
19 |
+
"import glob\n",
|
20 |
+
"import json\n",
|
21 |
+
"import random\n",
|
22 |
+
"import pickle\n",
|
23 |
+
"import numpy as np\n",
|
24 |
+
"import pandas as pd\n",
|
25 |
+
"from typing import Dict, Any\n",
|
26 |
+
"\n",
|
27 |
+
"from matplotlib import pyplot as plt\n",
|
28 |
+
"\n",
|
29 |
+
"import torch\n",
|
30 |
+
"from torch import nn\n",
|
31 |
+
"from datasets import Dataset, DatasetDict\n",
|
32 |
+
"\n",
|
33 |
+
"\n",
|
34 |
+
"\n",
|
35 |
+
"from transformers import (\n",
|
36 |
+
" AutoTokenizer,\n",
|
37 |
+
" AutoConfig,\n",
|
38 |
+
" ModernBertForSequenceClassification,\n",
|
39 |
+
" DataCollatorWithPadding,\n",
|
40 |
+
" Trainer,\n",
|
41 |
+
" TrainingArguments,\n",
|
42 |
+
")\n",
|
43 |
+
"\n",
|
44 |
+
"from transformers.modeling_outputs import SequenceClassifierOutput\n",
|
45 |
+
"\n",
|
46 |
+
"\n",
|
47 |
+
"\n",
|
48 |
+
"\n",
|
49 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
50 |
+
"\n",
|
51 |
+
"# -------------------------------------------------------\n",
|
52 |
+
"# 1. パラメータ設定(適宜変更)\n",
|
53 |
+
"# -------------------------------------------------------\n",
|
54 |
+
"CSV_FILE_PATH = r\"/media/kurogane/kioxia1/dataset/sss/pixiv/JMTB_1_rescore_float.csv\" # Japanese MT benchのCSVファイルパスを指定\n",
|
55 |
+
"MODEL_NAME = \"sbintuitions/modernbert-ja-130m\" # ModernBERT-Ja-130M\n",
|
56 |
+
"NUM_LABELS = 1 # 0~10の11クラス分類とする\n",
|
57 |
+
"SEED = 42\n",
|
58 |
+
"\n",
|
59 |
+
"\n",
|
60 |
+
"BASE_PROMPT = \"AIアシスタントがユーザーの質問に対して提供した回答の質を、公平な立場で評価してください。評価の際は、回答の有用性、関連性、正確性、深さ、創造性、詳細さを考慮してください。評価の前に短い説明を提供し、できるだけ客観的に評価してください。期待される言語は日本語です。日本語以外の言語での回答は、特に要求されない限り減点対象となります。全く日本語を使用しない場合、最低評価となります。ただし、Pythonのスクリプトや計算結果のみを提供する場合、日本語は必須ではありません。評価を0から1.0の範囲で小数点第一位までの数値で示し、floatで記載してください。例:\\\"0.5\\\"。\"\n",
|
61 |
+
"\n",
|
62 |
+
"# -------------------------------------------------------\n",
|
63 |
+
"# 2. CSV読み込み & データ前処理\n",
|
64 |
+
"# -------------------------------------------------------\n",
|
65 |
+
"def load_jmtb_data(csv_path: str) -> pd.DataFrame:\n",
|
66 |
+
" \"\"\"\n",
|
67 |
+
" CSVを読み込んでDataFrameを返す。\n",
|
68 |
+
" CSVの列名例:\n",
|
69 |
+
" ['model_name', 'question_id', 'category', 'question', 'answer', 'judge', 'user_prompt',\n",
|
70 |
+
" 'judgment', 'score', 'turn', 'tstamp', 'sub_category']\n",
|
71 |
+
" \"\"\"\n",
|
72 |
+
" df = pd.read_csv(csv_path)\n",
|
73 |
+
" return df\n",
|
74 |
+
"\n",
|
75 |
+
"\n",
|
76 |
+
"def build_input_text(row: pd.Series, df: pd.DataFrame) -> str:\n",
|
77 |
+
" \"\"\"\n",
|
78 |
+
" turn=1 の場合は「ターン1のみ」のテキストを構築。\n",
|
79 |
+
" turn=2 の場合は「ターン1のQ&A + ターン2のQ&A」を一つに連結したテキストを構築。\n",
|
80 |
+
" \"\"\"\n",
|
81 |
+
" turn = row[\"turn\"]\n",
|
82 |
+
" if turn == 1:\n",
|
83 |
+
" # シングルターン\n",
|
84 |
+
" # text = (\n",
|
85 |
+
" # f\"{BASE_PROMPT}\\n\\n\",\n",
|
86 |
+
" # f\"{row['question']}\",\n",
|
87 |
+
" # f\"{row['answer']}\",\n",
|
88 |
+
" # )\n",
|
89 |
+
" text = f\"<cls>{BASE_PROMPT}<sep>{row['question']}<sep>{row['answer']}<sep>\"\n",
|
90 |
+
" else:\n",
|
91 |
+
" # 2ターン目のscore行なので、同じquestion_idのturn=1を探す\n",
|
92 |
+
" qid = row[\"question_id\"]\n",
|
93 |
+
" # 同じquestion_id & turn=1の行を検索\n",
|
94 |
+
" df_turn1 = df[(df[\"question_id\"] == qid) & (df[\"turn\"] == 1)]\n",
|
95 |
+
" if len(df_turn1) > 0:\n",
|
96 |
+
" # 1行だけのはずだが、複数ある場合はiloc[0]\n",
|
97 |
+
" r1 = df_turn1.iloc[0]\n",
|
98 |
+
" # text = (\n",
|
99 |
+
" # f\"{BASE_PROMPT}\\n\\n\",\n",
|
100 |
+
" # f\"{r1['question']}\\n\\n{row['question']}\",\n",
|
101 |
+
" # f\"{row['answer']}\",\n",
|
102 |
+
" # )\n",
|
103 |
+
" text = f\"<cls>{BASE_PROMPT}<sep>{r1['question']}\\n\\n{row['question']}<sep>{row['answer']}<sep>\"\n",
|
104 |
+
" else:\n",
|
105 |
+
" # turn=1が見当たらない不備データの場合 -> 仕方ないのでターン2だけ\n",
|
106 |
+
" # text = (\n",
|
107 |
+
" # f\"{BASE_PROMPT}\\n\\n\",\n",
|
108 |
+
" # f\"{row['question']}\",\n",
|
109 |
+
" # f\"{row['answer']}\",\n",
|
110 |
+
" # )\n",
|
111 |
+
" text = f\"<cls>{BASE_PROMPT}<sep>{row['question']}<sep>{row['answer']}<sep>\"\n",
|
112 |
+
" return text\n",
|
113 |
+
"\n",
|
114 |
+
"\n",
|
115 |
+
"def create_dataset_from_df(df: pd.DataFrame) -> Dataset:\n",
|
116 |
+
" \"\"\"\n",
|
117 |
+
" pandas DataFrame から [input_text, label] を作り、Hugging Face Datasets の Dataset を返す。\n",
|
118 |
+
" - label は score を float に変換して格納(回帰用ラベル)。\n",
|
119 |
+
" \"\"\"\n",
|
120 |
+
" # 新しい列 input_text と label を作成\n",
|
121 |
+
" # 参照しやすいようにデータフレームをコピー\n",
|
122 |
+
" df2 = df.copy()\n",
|
123 |
+
"\n",
|
124 |
+
" # テキスト列を作成\n",
|
125 |
+
" df2[\"input_text\"] = df2.apply(lambda row: build_input_text(row, df2), axis=1)\n",
|
126 |
+
" # スコアをfloat化(回帰用ラベルとして使用)\n",
|
127 |
+
" df2[\"label\"] = df2[\"score\"].astype(float)\n",
|
128 |
+
"\n",
|
129 |
+
" # 必要な列のみ残す\n",
|
130 |
+
" used_cols = [\"input_text\", \"label\"]\n",
|
131 |
+
" df2 = df2[used_cols]\n",
|
132 |
+
"\n",
|
133 |
+
" # Pandas -> Huggingface Dataset\n",
|
134 |
+
" dataset = Dataset.from_pandas(df2, preserve_index=False)\n",
|
135 |
+
" return dataset\n",
|
136 |
+
"\n",
|
137 |
+
"\n",
|
138 |
+
"# -------------------------------------------------------\n",
|
139 |
+
"# 3. データセットの分割: train/valid/test\n",
|
140 |
+
"# -------------------------------------------------------\n",
|
141 |
+
"def split_dataset(\n",
|
142 |
+
" dataset: Dataset,\n",
|
143 |
+
" split_ratio=(0.8, 0.1, 0.1),\n",
|
144 |
+
" seed=SEED\n",
|
145 |
+
") -> DatasetDict:\n",
|
146 |
+
" \"\"\"\n",
|
147 |
+
" Dataset を train/dev/test に分割 (ランダム).\n",
|
148 |
+
" デフォルトは 8:1:1\n",
|
149 |
+
" \"\"\"\n",
|
150 |
+
" train_ratio, valid_ratio, test_ratio = split_ratio\n",
|
151 |
+
" # assert sum(split_ratio) == 1.0\n",
|
152 |
+
" n_samples = len(dataset)\n",
|
153 |
+
"\n",
|
154 |
+
" # まず shuffle\n",
|
155 |
+
" dataset = dataset.shuffle(seed=seed)\n",
|
156 |
+
"\n",
|
157 |
+
" train_end = int(n_samples * train_ratio)\n",
|
158 |
+
" valid_end = int(n_samples * (train_ratio + valid_ratio))\n",
|
159 |
+
"\n",
|
160 |
+
" train_dataset = dataset.select(range(0, train_end))\n",
|
161 |
+
" valid_dataset = dataset.select(range(train_end, valid_end))\n",
|
162 |
+
" test_dataset = dataset.select(range(valid_end, n_samples))\n",
|
163 |
+
"\n",
|
164 |
+
" return DatasetDict({\n",
|
165 |
+
" \"train\": train_dataset,\n",
|
166 |
+
" \"validation\": valid_dataset,\n",
|
167 |
+
" \"test\": test_dataset\n",
|
168 |
+
" })\n",
|
169 |
+
"\n",
|
170 |
+
"\n",
|
171 |
+
"# -------------------------------------------------------\n",
|
172 |
+
"# 4. トークナイズ関数\n",
|
173 |
+
"# -------------------------------------------------------\n",
|
174 |
+
"def tokenize_function(examples, tokenizer, max_length=None):\n",
|
175 |
+
" \"\"\"\n",
|
176 |
+
" 文章をトークナイズ。max_lengthは適宜設定(Noneの場合は基本無制限、FlashAttention2でpadding無視)\n",
|
177 |
+
" \"\"\"\n",
|
178 |
+
" return tokenizer(\n",
|
179 |
+
" examples[\"input_text\"],\n",
|
180 |
+
" truncation=(max_length is not None),\n",
|
181 |
+
" max_length=max_length,\n",
|
182 |
+
" )\n",
|
183 |
+
"\n",
|
184 |
+
"\n",
|
185 |
+
"# -------------------------------------------------------\n",
|
186 |
+
"# 5. 評価指標: 回帰タスク (MAE と MSE を算出)\n",
|
187 |
+
"# MAE / MSE 以外の指標が必要な場合は追加実装してください。\n",
|
188 |
+
"# -------------------------------------------------------\n",
|
189 |
+
"def compute_metrics_regression(eval_pred):\n",
|
190 |
+
" logits, labels = eval_pred\n",
|
191 |
+
" # logits: shape (batch_size, 1)\n",
|
192 |
+
" predictions = logits.reshape(-1)\n",
|
193 |
+
" mae = mean_absolute_error(labels, predictions)\n",
|
194 |
+
" mse = mean_squared_error(labels, predictions)\n",
|
195 |
+
" return {\n",
|
196 |
+
" \"mae\": mae,\n",
|
197 |
+
" \"mse\": mse\n",
|
198 |
+
" }\n",
|
199 |
+
"\n",
|
200 |
+
"class ModernBertForScoring(ModernBertForSequenceClassification):\n",
|
201 |
+
" \"\"\"\n",
|
202 |
+
" ModernBertForSequenceClassificationを継承し、\n",
|
203 |
+
" 出力層にシグモイドをかけて 0~1 の範囲にマッピングするカスタムクラス。\n",
|
204 |
+
" \"\"\"\n",
|
205 |
+
"\n",
|
206 |
+
" def __init__(self, config):\n",
|
207 |
+
" super().__init__(config)\n",
|
208 |
+
" # num_labels=1 + 回帰タスク想定なので、classification_head は linear + activation とする\n",
|
209 |
+
" # 既存の self.classifier を再利用しつつ、最後にシグモイドを追加するイメージ\n",
|
210 |
+
" self.sigmoid = nn.Sigmoid()\n",
|
211 |
+
" # もし self.classifier が 1 出力以外になっている場合は要調整\n",
|
212 |
+
" # (ModernBertForSequenceClassification の場合は config.num_labels に応じた Linear が作られる想定)\n",
|
213 |
+
"\n",
|
214 |
+
" def forward(\n",
|
215 |
+
" self,\n",
|
216 |
+
" input_ids=None,\n",
|
217 |
+
" attention_mask=None,\n",
|
218 |
+
" token_type_ids=None,\n",
|
219 |
+
" labels=None,\n",
|
220 |
+
" **kwargs,\n",
|
221 |
+
" ):\n",
|
222 |
+
" # 親クラス(ModernBertForSequenceClassification)の forward を実行\n",
|
223 |
+
" # ただし 親クラスは [loss, logits] を返す実装なので、それを受け取り再加工する\n",
|
224 |
+
" outputs = super().forward(\n",
|
225 |
+
" input_ids=input_ids,\n",
|
226 |
+
" attention_mask=attention_mask,\n",
|
227 |
+
" token_type_ids=token_type_ids,\n",
|
228 |
+
" labels=None, # ここでは一旦親の loss 計算を無効化し、自前でやる\n",
|
229 |
+
" **kwargs,\n",
|
230 |
+
" )\n",
|
231 |
+
"\n",
|
232 |
+
" # 親から返される logits は shape = (batch_size, num_labels=1) のはず\n",
|
233 |
+
" logits = outputs.logits # => [B,1]\n",
|
234 |
+
"\n",
|
235 |
+
" # ここでシグモイドをかけて 0~1 に収まるようにする\n",
|
236 |
+
" preds = self.sigmoid(logits) # => [B,1], range(0,1)\n",
|
237 |
+
"\n",
|
238 |
+
" loss = None\n",
|
239 |
+
" if labels is not None:\n",
|
240 |
+
" labels = labels.view(-1, 1).float()\n",
|
241 |
+
" loss_fct = nn.MSELoss()\n",
|
242 |
+
" loss = loss_fct(preds, labels)\n",
|
243 |
+
" \n",
|
244 |
+
" # hidden_states / attentions が None の場合も型的に問題なく格納できる\n",
|
245 |
+
" return SequenceClassifierOutput(\n",
|
246 |
+
" loss=loss,\n",
|
247 |
+
" logits=preds, # シグモイド後の出力 (shape=[B,1])\n",
|
248 |
+
" hidden_states=outputs.hidden_states,\n",
|
249 |
+
" attentions=outputs.attentions,\n",
|
250 |
+
" )\n",
|
251 |
+
"\n",
|
252 |
+
"\n",
|
253 |
+
"from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
|
254 |
+
"\n",
|
255 |
+
"def compute_metrics_regression(eval_pred):\n",
|
256 |
+
" logits, labels = eval_pred\n",
|
257 |
+
" # logits: [batch_size, 1], labels: [batch_size]\n",
|
258 |
+
" preds = logits.reshape(-1)\n",
|
259 |
+
" mse = mean_squared_error(labels, preds)\n",
|
260 |
+
" mae = mean_absolute_error(labels, preds)\n",
|
261 |
+
" return {\n",
|
262 |
+
" \"mse\": mse,\n",
|
263 |
+
" \"mae\": mae\n",
|
264 |
+
" }\n",
|
265 |
+
"\n",
|
266 |
+
"\n"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"execution_count": 3,
|
272 |
+
"metadata": {},
|
273 |
+
"outputs": [
|
274 |
+
{
|
275 |
+
"name": "stdout",
|
276 |
+
"output_type": "stream",
|
277 |
+
"text": [
|
278 |
+
"[Info] Loading CSV from: /media/kurogane/kioxia1/dataset/sss/pixiv/JMTB_1_rescore_float.csv\n",
|
279 |
+
"[Info] CSV loaded: 6480 rows.\n",
|
280 |
+
"[Info] Built dataset with columns: ['input_text', 'label']\n",
|
281 |
+
"DatasetDict({\n",
|
282 |
+
" train: Dataset({\n",
|
283 |
+
" features: ['input_text', 'label'],\n",
|
284 |
+
" num_rows: 5184\n",
|
285 |
+
" })\n",
|
286 |
+
" validation: Dataset({\n",
|
287 |
+
" features: ['input_text', 'label'],\n",
|
288 |
+
" num_rows: 648\n",
|
289 |
+
" })\n",
|
290 |
+
" test: Dataset({\n",
|
291 |
+
" features: ['input_text', 'label'],\n",
|
292 |
+
" num_rows: 648\n",
|
293 |
+
" })\n",
|
294 |
+
"})\n",
|
295 |
+
"[Info] Loading tokenizer for sbintuitions/modernbert-ja-130m\n"
|
296 |
+
]
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"data": {
|
300 |
+
"application/vnd.jupyter.widget-view+json": {
|
301 |
+
"model_id": "851e20ba48b845b995288997e95a58c5",
|
302 |
+
"version_major": 2,
|
303 |
+
"version_minor": 0
|
304 |
+
},
|
305 |
+
"text/plain": [
|
306 |
+
"Map: 0%| | 0/5184 [00:00<?, ? examples/s]"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
"metadata": {},
|
310 |
+
"output_type": "display_data"
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"data": {
|
314 |
+
"application/vnd.jupyter.widget-view+json": {
|
315 |
+
"model_id": "8cbfe3c0e65744e182794a40f80ca1bf",
|
316 |
+
"version_major": 2,
|
317 |
+
"version_minor": 0
|
318 |
+
},
|
319 |
+
"text/plain": [
|
320 |
+
"Map: 0%| | 0/648 [00:00<?, ? examples/s]"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
"metadata": {},
|
324 |
+
"output_type": "display_data"
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"data": {
|
328 |
+
"application/vnd.jupyter.widget-view+json": {
|
329 |
+
"model_id": "13f95d5559f64425aa8e1ffdecc72ac0",
|
330 |
+
"version_major": 2,
|
331 |
+
"version_minor": 0
|
332 |
+
},
|
333 |
+
"text/plain": [
|
334 |
+
"Map: 0%| | 0/648 [00:00<?, ? examples/s]"
|
335 |
+
]
|
336 |
+
},
|
337 |
+
"metadata": {},
|
338 |
+
"output_type": "display_data"
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"name": "stderr",
|
342 |
+
"output_type": "stream",
|
343 |
+
"text": [
|
344 |
+
"Some weights of ModernBertForScoring were not initialized from the model checkpoint at sbintuitions/modernbert-ja-130m and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
345 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
346 |
+
"/home/kurogane/anaconda3/envs/vllmtest/lib/python3.12/site-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
|
347 |
+
" warnings.warn(\n",
|
348 |
+
"/tmp/ipykernel_122664/3661479277.py:98: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
|
349 |
+
" trainer = Trainer(\n"
|
350 |
+
]
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"name": "stdout",
|
354 |
+
"output_type": "stream",
|
355 |
+
"text": [
|
356 |
+
"[Info] Starting training ...\n"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"data": {
|
361 |
+
"text/html": [
|
362 |
+
"\n",
|
363 |
+
" <div>\n",
|
364 |
+
" \n",
|
365 |
+
" <progress value='2592' max='2592' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
366 |
+
" [2592/2592 18:44, Epoch 32/32]\n",
|
367 |
+
" </div>\n",
|
368 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
369 |
+
" <thead>\n",
|
370 |
+
" <tr style=\"text-align: left;\">\n",
|
371 |
+
" <th>Epoch</th>\n",
|
372 |
+
" <th>Training Loss</th>\n",
|
373 |
+
" <th>Validation Loss</th>\n",
|
374 |
+
" <th>Mse</th>\n",
|
375 |
+
" <th>Mae</th>\n",
|
376 |
+
" </tr>\n",
|
377 |
+
" </thead>\n",
|
378 |
+
" <tbody>\n",
|
379 |
+
" <tr>\n",
|
380 |
+
" <td>1</td>\n",
|
381 |
+
" <td>0.121000</td>\n",
|
382 |
+
" <td>0.050717</td>\n",
|
383 |
+
" <td>0.050717</td>\n",
|
384 |
+
" <td>0.172666</td>\n",
|
385 |
+
" </tr>\n",
|
386 |
+
" <tr>\n",
|
387 |
+
" <td>2</td>\n",
|
388 |
+
" <td>0.089100</td>\n",
|
389 |
+
" <td>0.041392</td>\n",
|
390 |
+
" <td>0.041392</td>\n",
|
391 |
+
" <td>0.158076</td>\n",
|
392 |
+
" </tr>\n",
|
393 |
+
" <tr>\n",
|
394 |
+
" <td>3</td>\n",
|
395 |
+
" <td>0.078900</td>\n",
|
396 |
+
" <td>0.029573</td>\n",
|
397 |
+
" <td>0.029573</td>\n",
|
398 |
+
" <td>0.121381</td>\n",
|
399 |
+
" </tr>\n",
|
400 |
+
" <tr>\n",
|
401 |
+
" <td>4</td>\n",
|
402 |
+
" <td>0.104900</td>\n",
|
403 |
+
" <td>0.064899</td>\n",
|
404 |
+
" <td>0.064899</td>\n",
|
405 |
+
" <td>0.209115</td>\n",
|
406 |
+
" </tr>\n",
|
407 |
+
" <tr>\n",
|
408 |
+
" <td>5</td>\n",
|
409 |
+
" <td>0.050500</td>\n",
|
410 |
+
" <td>0.029966</td>\n",
|
411 |
+
" <td>0.029966</td>\n",
|
412 |
+
" <td>0.131566</td>\n",
|
413 |
+
" </tr>\n",
|
414 |
+
" <tr>\n",
|
415 |
+
" <td>6</td>\n",
|
416 |
+
" <td>0.024500</td>\n",
|
417 |
+
" <td>0.068739</td>\n",
|
418 |
+
" <td>0.068739</td>\n",
|
419 |
+
" <td>0.215073</td>\n",
|
420 |
+
" </tr>\n",
|
421 |
+
" <tr>\n",
|
422 |
+
" <td>7</td>\n",
|
423 |
+
" <td>0.017600</td>\n",
|
424 |
+
" <td>0.032628</td>\n",
|
425 |
+
" <td>0.032628</td>\n",
|
426 |
+
" <td>0.140590</td>\n",
|
427 |
+
" </tr>\n",
|
428 |
+
" <tr>\n",
|
429 |
+
" <td>8</td>\n",
|
430 |
+
" <td>0.011500</td>\n",
|
431 |
+
" <td>0.024080</td>\n",
|
432 |
+
" <td>0.024080</td>\n",
|
433 |
+
" <td>0.107284</td>\n",
|
434 |
+
" </tr>\n",
|
435 |
+
" <tr>\n",
|
436 |
+
" <td>9</td>\n",
|
437 |
+
" <td>0.009600</td>\n",
|
438 |
+
" <td>0.023550</td>\n",
|
439 |
+
" <td>0.023550</td>\n",
|
440 |
+
" <td>0.106661</td>\n",
|
441 |
+
" </tr>\n",
|
442 |
+
" <tr>\n",
|
443 |
+
" <td>10</td>\n",
|
444 |
+
" <td>0.008900</td>\n",
|
445 |
+
" <td>0.019672</td>\n",
|
446 |
+
" <td>0.019672</td>\n",
|
447 |
+
" <td>0.098421</td>\n",
|
448 |
+
" </tr>\n",
|
449 |
+
" <tr>\n",
|
450 |
+
" <td>11</td>\n",
|
451 |
+
" <td>0.007900</td>\n",
|
452 |
+
" <td>0.020809</td>\n",
|
453 |
+
" <td>0.020809</td>\n",
|
454 |
+
" <td>0.108778</td>\n",
|
455 |
+
" </tr>\n",
|
456 |
+
" <tr>\n",
|
457 |
+
" <td>12</td>\n",
|
458 |
+
" <td>0.005000</td>\n",
|
459 |
+
" <td>0.018793</td>\n",
|
460 |
+
" <td>0.018793</td>\n",
|
461 |
+
" <td>0.098439</td>\n",
|
462 |
+
" </tr>\n",
|
463 |
+
" <tr>\n",
|
464 |
+
" <td>13</td>\n",
|
465 |
+
" <td>0.003600</td>\n",
|
466 |
+
" <td>0.017699</td>\n",
|
467 |
+
" <td>0.017699</td>\n",
|
468 |
+
" <td>0.098569</td>\n",
|
469 |
+
" </tr>\n",
|
470 |
+
" <tr>\n",
|
471 |
+
" <td>14</td>\n",
|
472 |
+
" <td>0.002900</td>\n",
|
473 |
+
" <td>0.020224</td>\n",
|
474 |
+
" <td>0.020224</td>\n",
|
475 |
+
" <td>0.100133</td>\n",
|
476 |
+
" </tr>\n",
|
477 |
+
" <tr>\n",
|
478 |
+
" <td>15</td>\n",
|
479 |
+
" <td>0.003400</td>\n",
|
480 |
+
" <td>0.017207</td>\n",
|
481 |
+
" <td>0.017207</td>\n",
|
482 |
+
" <td>0.096104</td>\n",
|
483 |
+
" </tr>\n",
|
484 |
+
" <tr>\n",
|
485 |
+
" <td>16</td>\n",
|
486 |
+
" <td>0.001200</td>\n",
|
487 |
+
" <td>0.017720</td>\n",
|
488 |
+
" <td>0.017720</td>\n",
|
489 |
+
" <td>0.095289</td>\n",
|
490 |
+
" </tr>\n",
|
491 |
+
" <tr>\n",
|
492 |
+
" <td>17</td>\n",
|
493 |
+
" <td>0.001500</td>\n",
|
494 |
+
" <td>0.017983</td>\n",
|
495 |
+
" <td>0.017983</td>\n",
|
496 |
+
" <td>0.096090</td>\n",
|
497 |
+
" </tr>\n",
|
498 |
+
" <tr>\n",
|
499 |
+
" <td>18</td>\n",
|
500 |
+
" <td>0.000800</td>\n",
|
501 |
+
" <td>0.017709</td>\n",
|
502 |
+
" <td>0.017709</td>\n",
|
503 |
+
" <td>0.095045</td>\n",
|
504 |
+
" </tr>\n",
|
505 |
+
" <tr>\n",
|
506 |
+
" <td>19</td>\n",
|
507 |
+
" <td>0.000900</td>\n",
|
508 |
+
" <td>0.017456</td>\n",
|
509 |
+
" <td>0.017456</td>\n",
|
510 |
+
" <td>0.094618</td>\n",
|
511 |
+
" </tr>\n",
|
512 |
+
" <tr>\n",
|
513 |
+
" <td>20</td>\n",
|
514 |
+
" <td>0.000300</td>\n",
|
515 |
+
" <td>0.017487</td>\n",
|
516 |
+
" <td>0.017487</td>\n",
|
517 |
+
" <td>0.095387</td>\n",
|
518 |
+
" </tr>\n",
|
519 |
+
" <tr>\n",
|
520 |
+
" <td>21</td>\n",
|
521 |
+
" <td>0.000200</td>\n",
|
522 |
+
" <td>0.017418</td>\n",
|
523 |
+
" <td>0.017418</td>\n",
|
524 |
+
" <td>0.094866</td>\n",
|
525 |
+
" </tr>\n",
|
526 |
+
" <tr>\n",
|
527 |
+
" <td>22</td>\n",
|
528 |
+
" <td>0.000100</td>\n",
|
529 |
+
" <td>0.017375</td>\n",
|
530 |
+
" <td>0.017375</td>\n",
|
531 |
+
" <td>0.095027</td>\n",
|
532 |
+
" </tr>\n",
|
533 |
+
" <tr>\n",
|
534 |
+
" <td>23</td>\n",
|
535 |
+
" <td>0.000100</td>\n",
|
536 |
+
" <td>0.017170</td>\n",
|
537 |
+
" <td>0.017170</td>\n",
|
538 |
+
" <td>0.095647</td>\n",
|
539 |
+
" </tr>\n",
|
540 |
+
" <tr>\n",
|
541 |
+
" <td>24</td>\n",
|
542 |
+
" <td>0.000100</td>\n",
|
543 |
+
" <td>0.017344</td>\n",
|
544 |
+
" <td>0.017344</td>\n",
|
545 |
+
" <td>0.095632</td>\n",
|
546 |
+
" </tr>\n",
|
547 |
+
" <tr>\n",
|
548 |
+
" <td>25</td>\n",
|
549 |
+
" <td>0.000000</td>\n",
|
550 |
+
" <td>0.017127</td>\n",
|
551 |
+
" <td>0.017127</td>\n",
|
552 |
+
" <td>0.095365</td>\n",
|
553 |
+
" </tr>\n",
|
554 |
+
" <tr>\n",
|
555 |
+
" <td>26</td>\n",
|
556 |
+
" <td>0.000000</td>\n",
|
557 |
+
" <td>0.017153</td>\n",
|
558 |
+
" <td>0.017153</td>\n",
|
559 |
+
" <td>0.095548</td>\n",
|
560 |
+
" </tr>\n",
|
561 |
+
" <tr>\n",
|
562 |
+
" <td>27</td>\n",
|
563 |
+
" <td>0.000000</td>\n",
|
564 |
+
" <td>0.017262</td>\n",
|
565 |
+
" <td>0.017262</td>\n",
|
566 |
+
" <td>0.095495</td>\n",
|
567 |
+
" </tr>\n",
|
568 |
+
" <tr>\n",
|
569 |
+
" <td>28</td>\n",
|
570 |
+
" <td>0.000000</td>\n",
|
571 |
+
" <td>0.017204</td>\n",
|
572 |
+
" <td>0.017204</td>\n",
|
573 |
+
" <td>0.095659</td>\n",
|
574 |
+
" </tr>\n",
|
575 |
+
" <tr>\n",
|
576 |
+
" <td>29</td>\n",
|
577 |
+
" <td>0.000000</td>\n",
|
578 |
+
" <td>0.017318</td>\n",
|
579 |
+
" <td>0.017318</td>\n",
|
580 |
+
" <td>0.095501</td>\n",
|
581 |
+
" </tr>\n",
|
582 |
+
" <tr>\n",
|
583 |
+
" <td>30</td>\n",
|
584 |
+
" <td>0.000000</td>\n",
|
585 |
+
" <td>0.017286</td>\n",
|
586 |
+
" <td>0.017286</td>\n",
|
587 |
+
" <td>0.095896</td>\n",
|
588 |
+
" </tr>\n",
|
589 |
+
" <tr>\n",
|
590 |
+
" <td>31</td>\n",
|
591 |
+
" <td>0.000000</td>\n",
|
592 |
+
" <td>0.017363</td>\n",
|
593 |
+
" <td>0.017363</td>\n",
|
594 |
+
" <td>0.095974</td>\n",
|
595 |
+
" </tr>\n",
|
596 |
+
" <tr>\n",
|
597 |
+
" <td>32</td>\n",
|
598 |
+
" <td>0.000000</td>\n",
|
599 |
+
" <td>0.017250</td>\n",
|
600 |
+
" <td>0.017250</td>\n",
|
601 |
+
" <td>0.095597</td>\n",
|
602 |
+
" </tr>\n",
|
603 |
+
" </tbody>\n",
|
604 |
+
"</table><p>"
|
605 |
+
],
|
606 |
+
"text/plain": [
|
607 |
+
"<IPython.core.display.HTML object>"
|
608 |
+
]
|
609 |
+
},
|
610 |
+
"metadata": {},
|
611 |
+
"output_type": "display_data"
|
612 |
+
},
|
613 |
+
{
|
614 |
+
"name": "stdout",
|
615 |
+
"output_type": "stream",
|
616 |
+
"text": [
|
617 |
+
"[Info] Evaluating on test set ...\n"
|
618 |
+
]
|
619 |
+
},
|
620 |
+
{
|
621 |
+
"data": {
|
622 |
+
"text/html": [
|
623 |
+
"\n",
|
624 |
+
" <div>\n",
|
625 |
+
" \n",
|
626 |
+
" <progress value='81' max='81' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
627 |
+
" [81/81 00:01]\n",
|
628 |
+
" </div>\n",
|
629 |
+
" "
|
630 |
+
],
|
631 |
+
"text/plain": [
|
632 |
+
"<IPython.core.display.HTML object>"
|
633 |
+
]
|
634 |
+
},
|
635 |
+
"metadata": {},
|
636 |
+
"output_type": "display_data"
|
637 |
+
},
|
638 |
+
{
|
639 |
+
"name": "stdout",
|
640 |
+
"output_type": "stream",
|
641 |
+
"text": [
|
642 |
+
"Test set metrics: {'eval_loss': 0.022432124242186546, 'eval_mse': 0.022432127967476845, 'eval_mae': 0.10348472744226456, 'eval_runtime': 1.4185, 'eval_samples_per_second': 456.805, 'eval_steps_per_second': 57.101, 'epoch': 32.0}\n",
|
643 |
+
"[Info] Done. Saving final model ...\n",
|
644 |
+
"[Info] Finished.\n"
|
645 |
+
]
|
646 |
+
},
|
647 |
+
{
|
648 |
+
"data": {
|
649 |
+
"image/png": "",
|
650 |
+
"text/plain": [
|
651 |
+
"<Figure size 640x480 with 2 Axes>"
|
652 |
+
]
|
653 |
+
},
|
654 |
+
"metadata": {},
|
655 |
+
"output_type": "display_data"
|
656 |
+
}
|
657 |
+
],
|
658 |
+
"source": [
|
659 |
+
"# -------------------------------------------------------\n",
|
660 |
+
"# 6. 実行メイン\n",
|
661 |
+
"# -------------------------------------------------------\n",
|
662 |
+
"\n",
|
663 |
+
"6\n",
|
664 |
+
"# 学習関連ハイパーパラメータ\n",
|
665 |
+
"TRAIN_EPOCHS = 32\n",
|
666 |
+
"TRAIN_BATCH_SIZE = 64\n",
|
667 |
+
"EVAL_BATCH_SIZE = 8\n",
|
668 |
+
"LEARNING_RATE = 4e-5\n",
|
669 |
+
"\n",
|
670 |
+
"\n",
|
671 |
+
"SAVE_DIR = \"./modernbert_jamt_finetune_ckpt_{:0=2}\".format(len(glob.glob(\"./modernbert_jamt_finetune_ckpt_*\")))\n",
|
672 |
+
"SPLIT_RATIO = (0.8, 0.1, 0.1) # train:valid:test = 8:1:1\n",
|
673 |
+
"\n",
|
674 |
+
"\n",
|
675 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
676 |
+
"random.seed(SEED)\n",
|
677 |
+
"np.random.seed(SEED)\n",
|
678 |
+
"torch.manual_seed(SEED)\n",
|
679 |
+
"\n",
|
680 |
+
"print(f\"[Info] Loading CSV from: {CSV_FILE_PATH}\")\n",
|
681 |
+
"df = load_jmtb_data(CSV_FILE_PATH)\n",
|
682 |
+
"print(f\"[Info] CSV loaded: {len(df)} rows.\")\n",
|
683 |
+
"\n",
|
684 |
+
"# Dataset化\n",
|
685 |
+
"dataset_all = create_dataset_from_df(df)\n",
|
686 |
+
"print(\"[Info] Built dataset with columns:\", dataset_all.column_names)\n",
|
687 |
+
"\n",
|
688 |
+
"# train/dev/test split\n",
|
689 |
+
"dataset_dict = split_dataset(dataset_all, split_ratio=SPLIT_RATIO, seed=SEED)\n",
|
690 |
+
"\n",
|
691 |
+
"# 変数をpickle形式で保存する\n",
|
692 |
+
"with open(\"./dataset_dict_float.pkl\", \"wb\") as file:\n",
|
693 |
+
" pickle.dump(dataset_dict, file)\n",
|
694 |
+
"# # pickle形式で保存された変数を読み込む\n",
|
695 |
+
"# with open(\"./dataset_dict_float.pkl\", \"rb\") as file:\n",
|
696 |
+
"# dataset_dict = pickle.load(file)\n",
|
697 |
+
"\n",
|
698 |
+
"\n",
|
699 |
+
"# dataset_dict = DatasetDict.load_from_disk(\"./jmtb_dataset_splits\")\n",
|
700 |
+
"\n",
|
701 |
+
"print(dataset_dict)\n",
|
702 |
+
"\n",
|
703 |
+
"# トークナイザ準備\n",
|
704 |
+
"print(f\"[Info] Loading tokenizer for {MODEL_NAME}\")\n",
|
705 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
706 |
+
"\n",
|
707 |
+
"def tokenize_fn(examples):\n",
|
708 |
+
" return tokenize_function(examples, tokenizer, max_length=None)\n",
|
709 |
+
"\n",
|
710 |
+
"dataset_dict = dataset_dict.map(tokenize_fn, batched=True)\n",
|
711 |
+
"\n",
|
712 |
+
"# モデルConfigとモデル本体\n",
|
713 |
+
"# num_labels=1 の回帰モデルとして設定 (スコアは 0~1)\n",
|
714 |
+
"config = AutoConfig.from_pretrained(\n",
|
715 |
+
" MODEL_NAME,\n",
|
716 |
+
" num_labels=1,\n",
|
717 |
+
" problem_type=\"single_label_regression\"\n",
|
718 |
+
")\n",
|
719 |
+
"# 注意: AutoConfig で problem_type 指定しても、上書きするのは親クラスの forward.\n",
|
720 |
+
"# ここでは主に「情報として入れておく」ため\n",
|
721 |
+
"\n",
|
722 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
723 |
+
"\n",
|
724 |
+
"model = ModernBertForScoring.from_pretrained(\n",
|
725 |
+
" MODEL_NAME,\n",
|
726 |
+
" config=config\n",
|
727 |
+
")\n",
|
728 |
+
"\n",
|
729 |
+
"\n",
|
730 |
+
"\n",
|
731 |
+
"# 学習データと評価データへ正しく入力されるようにcollator準備\n",
|
732 |
+
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
|
733 |
+
"\n",
|
734 |
+
"# Trainer用の引数設定\n",
|
735 |
+
"training_args = TrainingArguments(\n",
|
736 |
+
" output_dir=SAVE_DIR,\n",
|
737 |
+
" num_train_epochs=TRAIN_EPOCHS,\n",
|
738 |
+
" learning_rate=LEARNING_RATE,\n",
|
739 |
+
" per_device_train_batch_size=TRAIN_BATCH_SIZE,\n",
|
740 |
+
" per_device_eval_batch_size=EVAL_BATCH_SIZE,\n",
|
741 |
+
" evaluation_strategy=\"epoch\",\n",
|
742 |
+
" save_strategy=\"epoch\",\n",
|
743 |
+
" logging_strategy=\"epoch\",\n",
|
744 |
+
" load_best_model_at_end=True,\n",
|
745 |
+
" bf16=True, # Ampere以降のGPUでMixed Precision(BF16)学習\n",
|
746 |
+
" bf16_full_eval=True,\n",
|
747 |
+
" report_to=\"none\", # レポート先をOFF(W&Bなど使わない場合)\n",
|
748 |
+
" seed=SEED,\n",
|
749 |
+
" warmup_ratio=0.1,\n",
|
750 |
+
" lr_scheduler_type=\"cosine\",\n",
|
751 |
+
" weight_decay=0.01,\n",
|
752 |
+
" # logging_dir=SAVE_DIR,\n",
|
753 |
+
")\n",
|
754 |
+
"\n",
|
755 |
+
"# Trainer生成\n",
|
756 |
+
"trainer = Trainer(\n",
|
757 |
+
" model=model,\n",
|
758 |
+
" args=training_args,\n",
|
759 |
+
" train_dataset=dataset_dict[\"train\"],\n",
|
760 |
+
" eval_dataset=dataset_dict[\"validation\"],\n",
|
761 |
+
" tokenizer=tokenizer,\n",
|
762 |
+
" data_collator=data_collator,\n",
|
763 |
+
" compute_metrics=compute_metrics_regression, #compute_metrics_classification,\n",
|
764 |
+
")\n",
|
765 |
+
"\n",
|
766 |
+
"print(\"[Info] Starting training ...\")\n",
|
767 |
+
"trainer.train()\n",
|
768 |
+
"\n",
|
769 |
+
"# 学習完了後、テストセットで評価\n",
|
770 |
+
"print(\"[Info] Evaluating on test set ...\")\n",
|
771 |
+
"metrics_test = trainer.evaluate(dataset_dict[\"test\"])\n",
|
772 |
+
"print(\"Test set metrics:\", metrics_test)\n",
|
773 |
+
"\n",
|
774 |
+
"# 終了処理\n",
|
775 |
+
"print(\"[Info] Done. Saving final model ...\")\n",
|
776 |
+
"trainer.save_model(SAVE_DIR)\n",
|
777 |
+
"print(\"[Info] Finished.\")\n",
|
778 |
+
"\n",
|
779 |
+
"\n",
|
780 |
+
"\n",
|
781 |
+
"# ロスなどの結果を別途保存\n",
|
782 |
+
"dir_checkpoints = glob.glob(os.path.join(SAVE_DIR, \"checkpoint-*\", \"trainer_state.json\"))\n",
|
783 |
+
"def atoi(text):\n",
|
784 |
+
" return int(text) if text.isdigit() else text\n",
|
785 |
+
"\n",
|
786 |
+
"def natural_keys(text):\n",
|
787 |
+
" return [ atoi(c) for c in re.split(r'(\\d+)', text) ]\n",
|
788 |
+
"\n",
|
789 |
+
"dir_checkpoints = sorted(dir_checkpoints,key=natural_keys)\n",
|
790 |
+
"\n",
|
791 |
+
"l_data_eval_mae = []\n",
|
792 |
+
"l_data_eval_mse = []\n",
|
793 |
+
"l_data_eval_loss = []\n",
|
794 |
+
"l_data_loss = []\n",
|
795 |
+
"for i_checkpoint in dir_checkpoints:\n",
|
796 |
+
" with open(i_checkpoint, \"r\", encoding=\"utf-8\") as reader:\n",
|
797 |
+
" data_check = json.load(reader)\n",
|
798 |
+
" l_data_eval_mae.append(data_check[\"log_history\"][-1][\"eval_mae\"])\n",
|
799 |
+
" l_data_eval_mse.append(data_check[\"log_history\"][-1][\"eval_mse\"])\n",
|
800 |
+
" l_data_eval_loss.append(data_check[\"log_history\"][-1][\"eval_loss\"])\n",
|
801 |
+
" l_data_loss.append(data_check[\"log_history\"][-2][\"loss\"])\n",
|
802 |
+
"\n",
|
803 |
+
"d_logs = {\n",
|
804 |
+
" \"eval_mae\": l_data_eval_mae,\n",
|
805 |
+
" \"eval_mse\": l_data_eval_mse,\n",
|
806 |
+
" \"eval_loss\": l_data_eval_loss,\n",
|
807 |
+
" \"loss\": l_data_loss,\n",
|
808 |
+
"}\n",
|
809 |
+
"\n",
|
810 |
+
"with open(os.path.join(SAVE_DIR, \"log_epochs.json\"), \"w\", encoding=\"utf-8\") as writer:\n",
|
811 |
+
" json.dump(d_logs, writer, indent=4, ensure_ascii=False)\n",
|
812 |
+
"\n",
|
813 |
+
"# 可視化\n",
|
814 |
+
"fig, ax = plt.subplots(ncols=2)\n",
|
815 |
+
"\n",
|
816 |
+
"ax[0].plot(l_data_eval_mae, label=\"eval_mae\")\n",
|
817 |
+
"ax[0].plot(l_data_eval_mse, label=\"eval_mse\")\n",
|
818 |
+
"ax[1].plot(l_data_eval_loss, label=\"eval_loss\")\n",
|
819 |
+
"ax[1].plot(l_data_loss, label=\"loss\")\n",
|
820 |
+
"\n",
|
821 |
+
"ax[0].set_xlabel(\"epochs\")\n",
|
822 |
+
"ax[1].set_xlabel(\"epochs\")\n",
|
823 |
+
"\n",
|
824 |
+
"ax[0].legend()\n",
|
825 |
+
"ax[1].legend()\n",
|
826 |
+
"\n",
|
827 |
+
"plt.savefig(os.path.join(SAVE_DIR, \"log_epochs.png\"))\n",
|
828 |
+
"plt.show()"
|
829 |
+
]
|
830 |
+
}
|
831 |
+
],
|
832 |
+
"metadata": {
|
833 |
+
"kernelspec": {
|
834 |
+
"display_name": "vllmtest",
|
835 |
+
"language": "python",
|
836 |
+
"name": "python3"
|
837 |
+
},
|
838 |
+
"language_info": {
|
839 |
+
"codemirror_mode": {
|
840 |
+
"name": "ipython",
|
841 |
+
"version": 3
|
842 |
+
},
|
843 |
+
"file_extension": ".py",
|
844 |
+
"mimetype": "text/x-python",
|
845 |
+
"name": "python",
|
846 |
+
"nbconvert_exporter": "python",
|
847 |
+
"pygments_lexer": "ipython3",
|
848 |
+
"version": "3.12.4"
|
849 |
+
}
|
850 |
+
},
|
851 |
+
"nbformat": 4,
|
852 |
+
"nbformat_minor": 2
|
853 |
+
}
|