---
datasets:
- flaviagiammarino/vqa-rad
base_model:
- vikhyatk/moondream2
tags:
- med
- vqa
- vqarad
- finetune
- vision
- VLM
---
# MoonDream2 Fine-Tuning on Med VQA RAD Dataset

## Description
This project fine-tunes the **MoonDream2** model on the **Med VQA RAD dataset** to improve medical visual question answering (VQA) capabilities. The fine-tuning process optimizes performance by adjusting hyperparameters using **Optuna** and tracks training progress with **Weights & Biases (W&B)**.

## Training Environment
- **Hardware**: NVIDIA GPU (CUDA-enabled)
- **Frameworks**: PyTorch, Hugging Face Transformers
- **Optimizer**: Adam8bit (from bitsandbytes)
- **Batch Processing**: PyTorch DataLoader
- **Hyperparameter Tuning**: Optuna
- **Logging**: Weights & Biases (W&B)
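
A minimal setup sketch for this environment, assuming the stock Hugging Face loading path (pinning a specific model revision is advisable but omitted here):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# moondream2 ships custom modeling code, hence trust_remote_code=True
model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2", trust_remote_code=True
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
model.train()
```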

## Dataset
- **Name**: Med VQA RAD
- **Content**: Medical visual question-answering dataset with radiology images and associated Q&A pairs.
- **Preprocessing**: Images are processed through **MoonDream2's vision encoder**.
- **Tokenization**: Text is tokenized with **Hugging Face's tokenizer**.
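
The dataset named in the metadata above can be pulled straight from the Hub; the column names (`image`, `question`, `answer`) follow the dataset card and are worth verifying against the split you load:

```python
from datasets import load_dataset

ds = load_dataset("flaviagiammarino/vqa-rad")
print(ds)  # expected splits: train / test

sample = ds["train"][0]
print(sample["question"], "->", sample["answer"])  # PIL image in sample["image"]
```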

## Training Parameters
- **Model**: vikhyatk/moondream2
- **Number of Image Tokens**: 729
- **Learning Rate (LR)**: Tuned via Optuna (log-uniform search between **1e-6** and **1e-4**; see the sketch after this list)
- **Batch Size**: 3
- **Gradient Accumulation Steps**: 8 / batch size (accumulating toward an effective batch size of 8)
- **Optimizer**: Adam8bit (betas=(0.9, 0.95), eps=1e-6)
- **Loss Function**: Cross-entropy loss computed on token-level outputs
- **Scheduler**: Cosine Annealing with warm-up (10% of total steps)
- **Epochs**: Tuned via Optuna (search range: 1-2 epochs)
- **Validation Strategy**: Loss-based evaluation on validation set
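
A hedged sketch of how the parameters above could be wired together; `train_one_trial` is a hypothetical helper standing in for the loop described in the next section, and the warm-up/decay factor mirrors the schedule listed above:

```python
import math
import optuna
from bitsandbytes.optim import Adam8bit

BATCH_SIZE = 3
GRAD_ACCUM_STEPS = 8 // BATCH_SIZE  # 8 / batch size, as listed above

def lr_factor(step, total_steps, warmup_frac=0.1):
    """Multiplier for the base LR: warm up from 0.1x over the first 10%
    of steps, then cosine-decay to zero (usable with torch LambdaLR)."""
    warmup_steps = max(int(warmup_frac * total_steps), 1)
    if step < warmup_steps:
        return 0.1 + 0.9 * step / warmup_steps
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return 0.5 * (1 + math.cos(math.pi * progress))

def objective(trial):
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    epochs = trial.suggest_int("epochs", 1, 2)
    optimizer = Adam8bit(model.parameters(), lr=lr,
                         betas=(0.9, 0.95), eps=1e-6)
    # train_one_trial is hypothetical: it runs training and returns the
    # validation loss, reporting intermediate values so Optuna can prune
    return train_one_trial(model, optimizer, lr_factor, epochs, trial)

study = optuna.create_study(direction="minimize",
                            pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=10)
print("best hyperparameters:", study.best_params)
```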

## Training Process
1. **Collate Function** (see the sketch after this list):
   - Prepares image embeddings using **MoonDream2’s vision encoder**.
   - Converts question-answer pairs into tokenized sequences.
   - Pads sequences to ensure uniform input length.
2. **Loss Computation**:
   - Generates text embeddings.
   - Concatenates image and text embeddings.
   - Computes loss using **MoonDream2’s causal language model**.
3. **Learning Rate Scheduling**:
   - Starts at **0.1 × LR** and gradually increases.
   - Uses cosine decay after warm-up.
4. **Hyperparameter Optimization**:
   - Optuna optimizes learning rate and epoch count.
   - Trials are pruned if performance is suboptimal.
5. **Logging & Monitoring**:
   - W&B logs loss, learning rate, and training progress.
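
A hedged sketch of steps 1-2 and 5; the moondream2 internals used here (`vision_encoder`, `text_model`) are assumptions based on its custom modeling code and may differ across revisions:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # step 1: image embeddings via the vision encoder (attribute name assumed)
    img_embeds = torch.stack(
        [model.vision_encoder(s["image"]) for s in batch]
    )  # -> (B, 729, hidden)

    # tokenize Q/A pairs; only answer tokens are supervised
    ids, labels = [], []
    for s in batch:
        q = tokenizer(f"\n\nQuestion: {s['question']}\n\nAnswer:").input_ids
        a = tokenizer(f" {s['answer']}{tokenizer.eos_token}").input_ids
        ids.append(torch.tensor(q + a))
        labels.append(torch.tensor([-100] * len(q) + a))  # mask the question

    # pad to a uniform length
    ids = pad_sequence(ids, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    return img_embeds, ids, labels

def compute_loss(img_embeds, ids, labels):
    # step 2: concatenate image and text embeddings, then run the causal LM
    text_embeds = model.text_model.get_input_embeddings()(ids.to(DEVICE))
    inputs_embeds = torch.cat([img_embeds.to(DEVICE), text_embeds], dim=1)
    img_mask = torch.full(img_embeds.shape[:2], -100,
                          dtype=torch.long, device=DEVICE)
    full_labels = torch.cat([img_mask, labels.to(DEVICE)], dim=1)
    out = model.text_model(inputs_embeds=inputs_embeds, labels=full_labels)
    return out.loss  # token-level cross-entropy

# step 5, inside the training loop (sketch):
#   wandb.log({"train/loss": loss.item(), "train/lr": scheduler.get_last_lr()[0]})
```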

## Results
- **Best Hyperparameters**: Selected via Optuna trials.
- **Final Validation Loss**: Computed and logged.
- **Model Performance**: Evaluated using token-wise accuracy and qualitative assessment.
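
For the qualitative side, a quick spot check can use moondream2's documented inference helpers (`encode_image` / `answer_question`); note that this API has changed across model revisions:

```python
sample = ds["test"][0]
enc = model.encode_image(sample["image"])
pred = model.answer_question(enc, sample["question"], tokenizer)
print("Q:   ", sample["question"])
print("pred:", pred)
print("gold:", sample["answer"])
```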

## References
- [MoonDream2 on Hugging Face](https://huggingface.co/vikhyatk/moondream2)
- [VQA-RAD Dataset on Hugging Face](https://huggingface.co/datasets/flaviagiammarino/vqa-rad)
- [Optuna Documentation](https://optuna.org/)