---
datasets:
- flaviagiammarino/vqa-rad
base_model:
- vikhyatk/moondream2
tags:
- med
- vqa
- vqarad
- finetune
- vision
- VLM
---
# MoonDream2 Fine-Tuning on the Med VQA RAD Dataset

## Description

This project fine-tunes the **MoonDream2** model on the **Med VQA RAD** dataset to improve its medical visual question answering (VQA) capabilities. Hyperparameters are tuned with **Optuna**, and training runs are tracked with **Weights & Biases (W&B)**.

## Training Environment

- **Hardware**: NVIDIA GPU (CUDA-enabled)
- **Frameworks**: PyTorch, Hugging Face Transformers (see the setup sketch below)
- **Optimizer**: Adam8bit (from bitsandbytes)
- **Batch Processing**: PyTorch `DataLoader`
- **Hyperparameter Tuning**: Optuna
- **Logging**: Weights & Biases (W&B)
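
A minimal setup sketch, assuming the standard `transformers` loading path for this checkpoint (the dtype choice is an assumption, not pinned by this project):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = "cuda"  # training assumes a CUDA-enabled GPU

# moondream2 ships custom modeling code, so trust_remote_code is required;
# pinning a specific revision is recommended for reproducibility
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    trust_remote_code=True,
    torch_dtype=torch.float16,  # assumption: half precision on the GPU
).to(DEVICE)
model.train()
```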

## Dataset

- **Name**: Med VQA RAD, published on the Hugging Face Hub as `flaviagiammarino/vqa-rad` (loading sketch below)
- **Content**: Medical visual question-answering dataset of radiology images paired with Q&A pairs.
- **Preprocessing**: Images are encoded with **MoonDream2's vision encoder**.
- **Tokenization**: Text is tokenized with **Hugging Face's tokenizer**.
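
The dataset can be pulled straight from the Hub; the split and column names below follow the dataset card and are worth verifying:

```python
from datasets import load_dataset

# VQA-RAD as published on the Hugging Face Hub
dataset = load_dataset("flaviagiammarino/vqa-rad")

sample = dataset["train"][0]
print(sample["question"])   # free-text radiology question
print(sample["answer"])     # short free-text answer
image = sample["image"]     # PIL image handed to the vision encoder
```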

## Training Parameters

- **Model**: vikhyatk/moondream2
- **Number of Image Tokens**: 729
- **Learning Rate (LR)**: Tuned via Optuna (log-uniform search between **1e-6** and **1e-4**)
- **Batch Size**: 3
- **Gradient Accumulation Steps**: 8 / batch size
- **Optimizer**: Adam8bit (betas=(0.9, 0.95), eps=1e-6)
- **Loss Function**: Cross-entropy loss computed on token-level outputs
- **Scheduler**: Cosine annealing with warm-up over 10% of total steps (see the sketch below)
- **Epochs**: Tuned via Optuna (search range: 1-2)
- **Validation Strategy**: Loss-based evaluation on a held-out validation set
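
Putting the optimizer and schedule together might look like the sketch below, reusing the `model` from the setup sketch above. The decay floor of 0.1 × LR mirrors the warm-up start described under Training Process and is an assumption, as are the placeholder step counts:

```python
import math
import bitsandbytes as bnb

LR = 3e-5                      # example value from the Optuna search range
BATCH_SIZE = 3
GRAD_ACCUM = 8 // BATCH_SIZE   # "8 / batch size", integer division assumed
TOTAL_STEPS = 1_000            # placeholder: (len(loader) * epochs) // GRAD_ACCUM
WARMUP = int(0.1 * TOTAL_STEPS)

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=LR, betas=(0.9, 0.95), eps=1e-6)

def lr_at(step: int) -> float:
    """Linear warm-up from 0.1 * LR, then cosine decay back toward 0.1 * LR."""
    if step < WARMUP:
        return 0.1 * LR + 0.9 * LR * (step / WARMUP)
    progress = (step - WARMUP) / max(1, TOTAL_STEPS - WARMUP)
    return 0.1 * LR + 0.9 * LR * 0.5 * (1 + math.cos(math.pi * progress))

# inside the training loop, before optimizer.step():
#   for group in optimizer.param_groups:
#       group["lr"] = lr_at(global_step)
```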

## Training Process

1. **Collate Function** (sketched after this list):
   - Prepares image embeddings using **MoonDream2's vision encoder**.
   - Converts question-answer pairs into tokenized sequences.
   - Pads sequences to a uniform input length.
2. **Loss Computation** (also in the sketch below):
   - Generates text embeddings.
   - Concatenates image and text embeddings.
   - Computes cross-entropy loss with **MoonDream2's causal language model**.
3. **Learning Rate Scheduling** (see the schedule sketch under Training Parameters):
   - Starts at **0.1 × LR** and increases linearly during warm-up.
   - Follows cosine decay after warm-up.
4. **Hyperparameter Optimization** (see the Optuna sketch below):
   - Optuna searches over the learning rate and epoch count.
   - Underperforming trials are pruned early.
5. **Logging & Monitoring**:
   - W&B logs loss, learning rate, and training progress.
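
A rough sketch of steps 1-2, reusing `model`, `tokenizer`, and `DEVICE` from the setup sketch. The module names (`vision_encoder`, `text_model`) and the prompt format are illustrative assumptions based on the description above, not a confirmed MoonDream2 API:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Encode images, tokenize Q&A text, and pad to a uniform length."""
    with torch.no_grad():  # the vision encoder is kept frozen in this sketch
        image_embeds = model.vision_encoder([s["image"] for s in batch])  # (B, 729, D)

    token_ids = [
        torch.tensor(
            tokenizer(
                f"\n\nQuestion: {s['question']}\n\nAnswer: {s['answer']}{tokenizer.eos_token}"
            ).input_ids
        )
        for s in batch
    ]
    tokens = pad_sequence(token_ids, batch_first=True, padding_value=tokenizer.eos_token_id)
    return image_embeds, tokens

def compute_loss(image_embeds, tokens):
    """Concatenate image and text embeddings and score with the causal LM."""
    tokens = tokens.to(DEVICE)
    text_embeds = model.text_model.get_input_embeddings()(tokens)
    inputs = torch.cat([image_embeds.to(DEVICE), text_embeds], dim=1)

    logits = model.text_model(inputs_embeds=inputs).logits
    n_img = image_embeds.size(1)
    # positions n_img .. end-1 predict the next text token
    shift_logits = logits[:, n_img:-1, :]
    shift_targets = tokens[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)), shift_targets.reshape(-1)
    )
```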
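Steps 4-5 can be wired together roughly as follows. This is a sketch: the W&B project name and the `train_one_epoch` / `evaluate_validation` helpers are hypothetical stand-ins for the training loop described above.

```python
import optuna
import wandb

def objective(trial: optuna.Trial) -> float:
    # Search spaces mirror the Training Parameters section
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    epochs = trial.suggest_int("epochs", 1, 2)
    wandb.init(project="moondream2-vqa-rad", config={"lr": lr, "epochs": epochs}, reinit=True)

    val_loss = float("inf")
    for epoch in range(epochs):
        train_one_epoch(lr)               # hypothetical: one pass over the training DataLoader
        val_loss = evaluate_validation()  # hypothetical: mean loss on the validation set
        wandb.log({"epoch": epoch, "val_loss": val_loss, "lr": lr})

        # Prune trials whose intermediate validation loss is suboptimal
        trial.report(val_loss, step=epoch)
        if trial.should_prune():
            wandb.finish()
            raise optuna.TrialPruned()

    wandb.finish()
    return val_loss

study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=10)
print("Best hyperparameters:", study.best_params)
```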
## Results

- **Best Hyperparameters**: Selected from the Optuna trials.
- **Final Validation Loss**: Computed on the validation set and logged to W&B.
- **Model Performance**: Evaluated using token-level accuracy and qualitative assessment of generated answers.

## References

- [MoonDream2 on Hugging Face](https://huggingface.co/vikhyatk/moondream2)
- [Med VQA RAD Dataset](https://github.com/med-vqa)
- [Optuna Documentation](https://optuna.org/)