amirulhazym
committed on
Commit 516c7fe · 1 Parent(s): 561035e
Feat(P3L1): Configure TrainingArguments and instantiate Trainer
- 01-FineTuning-QA.ipynb +147 -1
01-FineTuning-QA.ipynb
CHANGED
@@ -800,9 +800,155 @@
| 800 |   },
| 801 |   {
| 802 |    "cell_type": "code",
| 803 | -  "execution_count": null,
| 804 |    "id": "1d26984c-395e-4b7b-a8cf-1c71fb61394b",
| 805 |    "metadata": {},
| 806 |    "outputs": [],
| 807 |    "source": []
| 808 |   }

| 800 |   },
| 801 |   {
| 802 |    "cell_type": "code",
| 803 | +  "execution_count": 13,
| 804 |    "id": "1d26984c-395e-4b7b-a8cf-1c71fb61394b",
| 805 |    "metadata": {},
| 806 | +  "outputs": [
| 807 | +   {
| 808 | +    "name": "stdout",
| 809 | +    "output_type": "stream",
| 810 | +    "text": [
| 811 | +     "\n",
| 812 | +     "Training Arguments defined:\n"
| 813 | +    ]
| 814 | +   }
| 815 | +  ],
| 816 | +  "source": [
| 817 | +   "#Sub-Step 7.2: Define Training Arguments\n",
| 818 | +   "\n",
| 819 | +   "from transformers import TrainingArguments\n",
| 820 | +   "\n",
| 821 | +   "# Define the directory where model checkpoints will be saved during/after training\n",
| 822 | +   "output_directory = \"malay-qa-model-finetuned\"  # You can change this name if you like\n",
| 823 | +   "\n",
| 824 | +   "# Configure the training parameters\n",
| 825 | +   "# These values are starting points; adjust based on performance and GPU memory.\n",
| 826 | +   "training_args = TrainingArguments(\n",
| 827 | +   "    output_dir=output_directory,  # Directory to save model checkpoints and logs\n",
| 828 | +   "\n",
| 829 | +   "    # --- Logging & Saving ---\n",
| 830 | +   "    logging_strategy=\"steps\",  # Log training loss at specified intervals\n",
| 831 | +   "    logging_steps=5,  # Log training loss every 5 steps (adjust based on dataset size)\n",
| 832 | +   "    save_strategy=\"epoch\",  # Save a checkpoint at the end of each epoch\n",
| 833 | +   "    # save_total_limit=2,  # Optional: Only keep the latest 2 checkpoints\n",
| 834 | +   "\n",
| 835 | +   "    # --- Evaluation ---\n",
| 836 | +   "    eval_strategy=\"epoch\",  # Run evaluation at the end of each epoch\n",
| 837 | +   "\n",
| 838 | +   "    # --- Training Hyperparameters ---\n",
| 839 | +   "    learning_rate=2e-5,  # Starting learning rate (AdamW optimizer default: 5e-5) - 2e-5 is common for fine-tuning\n",
| 840 | +   "    num_train_epochs=1,  # Number of times to iterate over the entire training dataset. START SMALL (1-3).\n",
| 841 | +   "    per_device_train_batch_size=4,  # Number of training examples per batch per GPU/CPU. DECREASE if you get \"CUDA out of memory\" errors. Start small (4 or 8).\n",
| 842 | +   "    per_device_eval_batch_size=4,  # Batch size for evaluation. DECREASE if memory errors during eval.\n",
| 843 | +   "    weight_decay=0.01,  # Adds a small penalty to large weights to prevent overfitting\n",
| 844 | +   "\n",
| 845 | +   "    # --- Other Settings ---\n",
| 846 | +   "    push_to_hub=False,  # Set to True later if you want to upload model to Hugging Face Hub\n",
| 847 | +   "    report_to=\"none\",  # Disable integrations like WandB/TensorBoard for simplicity now\n",
| 848 | +   "    # fp16=torch.cuda.is_available(),  # Optional: Enable mixed-precision training if on GPU (can speed up training and save memory) - requires 'accelerate' library\n",
| 849 | +   ")\n",
| 850 | +   "\n",
| 851 | +   "print(\"\\nTraining Arguments defined:\")\n",
| 852 | +   "# Display the arguments to verify\n",
| 853 | +   "# Note: It might show many more default arguments as well\n",
| 854 | +   "# print(training_args)"
| 855 | +  ]
| 856 | + },
| 857 | + {
| 858 | +  "cell_type": "code",
| 859 | +  "execution_count": 14,
| 860 | +  "id": "11dc845d-4f57-4931-b15b-0c7d948fc770",
| 861 | +  "metadata": {},
| 862 | +  "outputs": [
| 863 | +   {
| 864 | +    "name": "stdout",
| 865 | +    "output_type": "stream",
| 866 | +    "text": [
| 867 | +     "\n",
| 868 | +     "Using DefaultDataCollator.\n"
| 869 | +    ]
| 870 | +   }
| 871 | +  ],
| 872 | +  "source": [
| 873 | +   "#Sub-Step 7.3: Define Data Collator\n",
| 874 | +   "\n",
| 875 | +   "from transformers import DefaultDataCollator\n",
| 876 | +   "\n",
| 877 | +   "# DefaultDataCollator is suitable for standard token classification, sequence classification,\n",
| 878 | +   "# and also QA tasks where inputs are already padded to max_length by the tokenizer.\n",
| 879 | +   "# It converts lists of dictionaries into batches of tensors.\n",
| 880 | +   "data_collator = DefaultDataCollator()\n",
| 881 | +   "\n",
| 882 | +   "print(\"\\nUsing DefaultDataCollator.\")"
| 883 | +  ]
| 884 | + },
| 885 | + {
| 886 | +  "cell_type": "code",
| 887 | +  "execution_count": 15,
| 888 | +  "id": "33d37330-080e-4400-bd14-272d5d770780",
| 889 | +  "metadata": {},
| 890 | +  "outputs": [
| 891 | +   {
| 892 | +    "name": "stdout",
| 893 | +    "output_type": "stream",
| 894 | +    "text": [
| 895 | +     "\n",
| 896 | +     "Instantiating Trainer...\n",
| 897 | +     "Trainer instantiated successfully.\n"
| 898 | +    ]
| 899 | +   },
| 900 | +   {
| 901 | +    "name": "stderr",
| 902 | +    "output_type": "stream",
| 903 | +    "text": [
| 904 | +     "C:\\Users\\amiru\\AppData\\Local\\Temp\\ipykernel_14448\\3333510307.py:21: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
| 905 | +     "  trainer = Trainer(\n"
| 906 | +    ]
| 907 | +   }
| 908 | +  ],
| 909 | +  "source": [
| 910 | +   "#Sub-Step 7.4: Instantiate the Trainer\n",
| 911 | +   "\n",
| 912 | +   "from transformers import Trainer\n",
| 913 | +   "import torch  # Ensure torch is imported if checking CUDA\n",
| 914 | +   "\n",
| 915 | +   "# --- Verify required variables exist ---\n",
| 916 | +   "required_vars = ['model', 'training_args', 'tokenized_datasets', 'tokenizer', 'data_collator']\n",
| 917 | +   "for var_name in required_vars:\n",
| 918 | +   "    if not var_name in locals():\n",
| 919 | +   "        print(f\"ERROR: Required variable '{var_name}' not found. Please run previous steps.\")\n",
| 920 | +   "        raise NameError(f\"Variable '{var_name}' not defined.\")\n",
| 921 | +   "\n",
| 922 | +   "# Check dataset structure again (paranoid check)\n",
| 923 | +   "if not isinstance(tokenized_datasets, DatasetDict) or 'train' not in tokenized_datasets or 'eval' not in tokenized_datasets:\n",
| 924 | +   "    print(\"ERROR: 'tokenized_datasets' is not a DatasetDict or lacks 'train'/'eval' splits.\")\n",
| 925 | +   "    raise TypeError(\"'tokenized_datasets' has incorrect type or structure.\")\n",
| 926 | +   "\n",
| 927 | +   "print(\"\\nInstantiating Trainer...\")\n",
| 928 | +   "try:\n",
| 929 | +   "    # Create the Trainer instance\n",
| 930 | +   "    trainer = Trainer(\n",
| 931 | +   "        model=model,  # The pre-trained QA model loaded in Step 6\n",
| 932 | +   "        args=training_args,  # The configuration object defined above\n",
| 933 | +   "        train_dataset=tokenized_datasets[\"train\"],  # The tokenized training data split\n",
| 934 | +   "        eval_dataset=tokenized_datasets[\"eval\"],  # The tokenized evaluation data split\n",
| 935 | +   "        tokenizer=tokenizer,  # The tokenizer (used for saving/padding consistency)\n",
| 936 | +   "        data_collator=data_collator,  # How to create batches from dataset samples\n",
| 937 | +   "        # compute_metrics=compute_metrics,  # We're skipping custom QA metrics (EM/F1) for Level 1 MVP\n",
| 938 | +   "    )\n",
| 939 | +   "    print(\"Trainer instantiated successfully.\")\n",
| 940 | +   "\n",
| 941 | +   "except Exception as e:\n",
| 942 | +   "    print(f\"ERROR: Failed to instantiate Trainer. Check input arguments and configurations.\")\n",
| 943 | +   "    print(f\"Error details: {e}\")\n",
| 944 | +   "    raise"
| 945 | +  ]
| 946 | + },
| 947 | + {
| 948 | +  "cell_type": "code",
| 949 | +  "execution_count": null,
| 950 | +  "id": "078ecf6d-b2d0-47f4-b20b-d57ab02ca498",
| 951 | +  "metadata": {},
| 952 |    "outputs": [],
| 953 |    "source": []
| 954 |   }
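
For context, a minimal sketch of the step that would naturally follow this commit: launching the fine-tuning run and saving the result. This code is not part of the diff above; it assumes the `trainer`, `tokenizer`, and `output_directory` defined in the cells shown.

    # Hedged sketch, NOT part of this commit. Assumes `trainer`, `tokenizer`,
    # and `output_directory` from the cells above have been defined and run.
    train_result = trainer.train()  # Fine-tune using the TrainingArguments configured above

    # TrainOutput.metrics includes e.g. train_loss, train_runtime, train_samples_per_second
    print(train_result.metrics)

    trainer.save_model(output_directory)         # Write final model weights + config
    tokenizer.save_pretrained(output_directory)  # Save the tokenizer alongside for easy reload

Since `save_strategy="epoch"` is set, per-epoch checkpoints also land in `output_directory` automatically; the explicit `save_model` call just ensures the final weights are written to the top-level directory.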