Merge pull request #1 from soutrik71/feat/litserve_gpu_gradio
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .dvc/.gitignore +3 -0
- .dvc/config +8 -0
- .dvcignore +3 -0
- .flake8 +27 -0
- .gitattributes +2 -0
- .github/workflows/cd.yaml +65 -0
- .github/workflows/ci.yaml +171 -0
- .github/workflows/hf_deploy.yaml +61 -0
- .github/workflows/test_deploy.yml +62 -0
- .gitignore +32 -0
- .gradio/certificate.pem +31 -0
- .project-root +0 -0
- Dockerfile +76 -0
- app.py +115 -0
- basic_setup.md +419 -0
- client.py +18 -0
- configs/callbacks/default.yaml +24 -0
- configs/callbacks/early_stopping.yaml +15 -0
- configs/callbacks/model_checkpoint.yaml +17 -0
- configs/callbacks/rich_model_summary.yaml +4 -0
- configs/callbacks/rich_progress_bar.yaml +4 -0
- configs/data/catdog.yaml +9 -0
- configs/experiment/catdog_experiment.yaml +62 -0
- configs/experiment/catdog_experiment_resnet.yaml +59 -0
- configs/hydra/default.yaml +19 -0
- configs/infer.yaml +52 -0
- configs/logger/aim.yaml +6 -0
- configs/logger/csv.yaml +7 -0
- configs/logger/default.yaml +5 -0
- configs/logger/mlflow.yaml +9 -0
- configs/logger/tensorboard.yaml +10 -0
- configs/model/catdog_classifier.yaml +22 -0
- configs/model/catdog_classifier_resnet.yaml +13 -0
- configs/paths/catdog.yaml +27 -0
- configs/train.yaml +47 -0
- configs/trainer/default.yaml +20 -0
- data.dvc +6 -0
- docker-compose-old.yaml +74 -0
- docker-compose.yaml +90 -0
- docker_compose_exec.sh +58 -0
- dvc.lock +31 -0
- dvc.yaml +28 -0
- ec2_runner_setup.md +357 -0
- image.jpg +0 -0
- main.py +5 -0
- notebooks/datamodule_lightning.ipynb +301 -0
- notebooks/training_lightning_tests.ipynb +1011 -0
- poetry.lock +0 -0
- pyproject.toml +94 -0
- requirements.txt +28 -0
.dvc/.gitignore
ADDED
@@ -0,0 +1,3 @@
/config.local
/tmp
/cache
.dvc/config
ADDED
@@ -0,0 +1,8 @@
[core]
    autostage = true
    remote = aws_remote
['remote "local_remote"']
    url = /tmp/dvclocalstore
['remote "aws_remote"']
    url = s3://deep-bucket-s3/data
    region = ap-south-1
.dvcignore
ADDED
@@ -0,0 +1,3 @@
# Add patterns of files dvc should ignore, which could improve
# the performance. Learn more at
# https://dvc.org/doc/user-guide/dvcignore
.flake8
ADDED
@@ -0,0 +1,27 @@
[flake8]
max-line-length = 120

# Exclude the virtual environment, notebooks folder, tests folder, and other unnecessary directories
exclude =
    .venv,
    __pycache__,
    .git,
    build,
    dist,
    notebooks,
    tests,
    .ipynb_checkpoints,
    .mypy_cache,
    .pytest_cache,
    pytorch_project

ignore =
    E203,
    W503,
    E501,
    E402,
    F401,
    E401

max-complexity = 10
show-source = True
.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.ckpt filter=lfs diff=lfs merge=lfs -text
checkpoints/*.ckpt filter=lfs diff=lfs merge=lfs -text
.github/workflows/cd.yaml
ADDED
@@ -0,0 +1,65 @@
name: Deploy PyTorch Training to ECR with Docker Compose

on:
  push:
    branches:
      - main
      - feat/pytorch-catdog-setup

jobs:
  deploy:
    runs-on: self-hosted

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ secrets.AWS_REGION }}

      - name: Log in to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@v2

      - name: Create .env file
        run: |
          echo "AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }}" >> .env
          echo "AWS_SECRET_ACCESS_KEY=${{ secrets.AWS_SECRET_ACCESS_KEY }}" >> .env
          echo "AWS_REGION=${{ secrets.AWS_REGION }}" >> .env

      - name: Run Docker Compose for train service
        run: |
          docker-compose stop
          docker-compose build
          docker-compose up -d train
          docker-compose up -d eval
          docker-compose up -d server
          docker-compose up -d client
          docker-compose remove

      - name: Build, tag, and push Docker image to Amazon ECR
        env:
          REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          REPOSITORY: soutrik71/pytorch_catdog
          IMAGE_TAG: ${{ github.sha }}
        run: |
          docker build -t $REGISTRY/$REPOSITORY:$IMAGE_TAG .
          docker push $REGISTRY/$REPOSITORY:$IMAGE_TAG
          docker tag $REGISTRY/$REPOSITORY:$IMAGE_TAG $REGISTRY/$REPOSITORY:latest
          docker push $REGISTRY/$REPOSITORY:latest

      - name: Pull Docker image from ECR and verify
        env:
          REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          REPOSITORY: soutrik71/pytorch_catdog
          IMAGE_TAG: ${{ github.sha }}
        run: |
          docker pull $REGISTRY/$REPOSITORY:$IMAGE_TAG
          docker images | grep "$REGISTRY/$REPOSITORY"
.github/workflows/ci.yaml
ADDED
@@ -0,0 +1,171 @@
name: CI Pipeline

on:
  push:
    branches:
      - main
      # - feat/pytorch-catdog-setup
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  python_basic_test:
    name: Test current codebase and setup Python environment
    runs-on: self-hosted

    strategy:
      matrix:
        python-version: [3.10.15]

    env:
      AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
      AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
      AWS_REGION: ${{ secrets.AWS_REGION }}

    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ secrets.AWS_REGION }}

      - name: Print branch name
        run: echo "Branch name is ${{ github.ref_name }}"

      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Poetry
        run: |
          python -m pip install --upgrade pip
          pip install poetry
          poetry config virtualenvs.in-project true

      - name: Cache Poetry dependencies
        uses: actions/cache@v3
        with:
          path: |
            .venv
            ~/.cache/pypoetry
          key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
          restore-keys: |
            ${{ runner.os }}-poetry-

      - name: Install dependencies
        run: poetry install --no-root --no-interaction

      - name: Check Poetry environment
        run: poetry env info

      - name: Create .env file
        run: |
          echo "AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}" >> .env
          echo "AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}" >> .env
          echo "AWS_REGION=${AWS_REGION}" >> .env
          echo ".env file created"

      - name: Run lint checks
        run: poetry run flake8 . --exclude=.venv,tests,notebooks

      - name: black
        run: poetry run black . --exclude="(\.venv|tests|notebooks)"

  pytorch_code_test:
    name: Test PyTorch code
    runs-on: self-hosted

    env:
      AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
      AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
      AWS_REGION: ${{ secrets.AWS_REGION }}

    needs: python_basic_test

    strategy:
      matrix:
        python-version: [3.10.15]

    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ secrets.AWS_REGION }}

      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Poetry
        run: |
          python -m pip install --upgrade pip
          pip install poetry
          poetry config virtualenvs.in-project true

      - name: Cache Poetry dependencies
        uses: actions/cache@v3
        with:
          path: |
            .venv
            ~/.cache/pypoetry
          key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
          restore-keys: |
            ${{ runner.os }}-poetry-

      - name: Install dependencies
        run: poetry install --no-root --no-interaction

      - name: Check Poetry environment
        run: poetry env info

      - name: Get data from DVC
        run: |
          poetry run dvc pull || echo "No data to pull from DVC"

      - name: Run Train code
        run: |
          echo "Training the model"
          poetry run python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=train ++train=True ++test=False || exit 1
          poetry run python -m src.create_artifacts

      - name: Run Test code
        run: |
          echo "Testing the model"
          poetry run python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=test ++train=False ++test=True || exit 1

      - name: upload model checkpoints
        uses: actions/upload-artifact@v4
        with:
          name: model-checkpoints
          path: ./checkpoints/

      - name: upload logs
        uses: actions/upload-artifact@v4
        with:
          name: logs
          path: ./logs/

      - name: upload configs
        uses: actions/upload-artifact@v4
        with:
          name: configs
          path: ./configs/

      - name: upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: artifacts
          path: ./artifacts/
.github/workflows/hf_deploy.yaml
ADDED
@@ -0,0 +1,61 @@
name: Sync to Hugging Face Hub

on:
  push:
    branches:
      - main
      - feat/litserve_gpu_gradio
jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          lfs: true

      - name: Install Git LFS
        run: |
          curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
          sudo apt-get install git-lfs
          git lfs install
          git lfs pull

      - name: Add remote
        run: |
          git remote add space https://$USER:$HF_TOKEN@huggingface.co/spaces/$USER/$SPACE
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          USER: soutrik
          SPACE: gradio_demo_CatDogClassifier

      - name: Add README.md
        run: |
          cat <<EOF > README.md
          ---
          title: My Gradio App CatDog Classifier
          emoji: 🚀
          colorFrom: blue
          colorTo: green
          sdk: gradio
          sdk_version: "5.7.1"
          app_file: app.py
          pinned: false
          ---
          EOF

      - name: Configure Git identity
        run: |
          git config user.name "soutrik"
          git config user.email "[email protected]"

      - name: Push to hub
        run: |
          git add README.md
          git commit -m "Add README.md"
          git push --force https://$USER:$HF_TOKEN@huggingface.co/spaces/$USER/$SPACE main
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          USER: soutrik
          SPACE: gradio_demo_CatDogClassifier
.github/workflows/test_deploy.yml
ADDED
@@ -0,0 +1,62 @@
name: Deploy to ECR and Run Docker Compose with AWS Actions for GitHub and Docker Buildx

on:
  push:
    branches:
      - main
      - feat/framework-setup

jobs:
  deploy:
    runs-on: self-hosted

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ secrets.AWS_REGION }}

      - name: Log in to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@v2

      - name: Build, tag, and push docker image to Amazon ECR
        env:
          POSTGRES_DB: ${{ secrets.POSTGRES_DB }}
          POSTGRES_USER: ${{ secrets.POSTGRES_USER }}
          POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
          REDIS_PORT: ${{ secrets.REDIS_PORT }}
          REDIS_HOST: ${{ secrets.REDIS_HOST }}
          FLOWER_BASIC_AUTH: ${{ secrets.FLOWER_BASIC_AUTH }}
          REDIS_URL: ${{ secrets.REDIS_URL }}
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
          BROKER_URL: ${{ secrets.BROKER_URL }}
          REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          REPOSITORY: soutrik71/test
          IMAGE_TAG: ${{ github.sha }}
        run: |
          docker build -t $REGISTRY/$REPOSITORY:$IMAGE_TAG .
          docker push $REGISTRY/$REPOSITORY:$IMAGE_TAG

      - name: Run Docker Compose
        env:
          POSTGRES_DB: ${{ secrets.POSTGRES_DB }}
          POSTGRES_USER: ${{ secrets.POSTGRES_USER }}
          POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
          REDIS_PORT: ${{ secrets.REDIS_PORT }}
          REDIS_HOST: ${{ secrets.REDIS_HOST }}
          FLOWER_BASIC_AUTH: ${{ secrets.FLOWER_BASIC_AUTH }}
          REDIS_URL: ${{ secrets.REDIS_URL }}
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
          BROKER_URL: ${{ secrets.BROKER_URL }}
        run: |
          docker-compose up -d --build app
.gitignore
ADDED
@@ -0,0 +1,32 @@
aws/
*.zip
*.tar.gz
*.tar.bz2
.env
*.pyc
*.cpython-*.*
src/__pycache__/
src/*.egg-info/
src/dist/
src/build/
src/.eggs/
src/.pytest_cache/
src/.mypy_cache/
src/.tox/
src/.coverage
src/.vscode/
src/.vscode-test/
app/core/__pycache__/
src/__pycache__/test_infra.cpython-310.pyc
app/core/__pycache__/config.cpython-310.pyc
data/
!configs/data/
checkpoints/
logs/
/data
artifacts/
artifacts/*
*png
*jpg
*jpeg
artifacts/image_prediction.png
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
.project-root
ADDED
File without changes
Dockerfile
ADDED
@@ -0,0 +1,76 @@
# Stage 1: Base image with CUDA 12.2, cuDNN 9, and minimal runtime for PyTorch
FROM nvidia/cuda:12.2.0-runtime-ubuntu20.04 as base

LABEL maintainer="Soutrik [email protected]" \
      description="Base Docker image for running a Python app with Poetry and GPU support."

# Install necessary system dependencies, including Python 3.10
RUN apt-get update && apt-get install -y --no-install-recommends \
    software-properties-common && \
    add-apt-repository ppa:deadsnakes/ppa && \
    apt-get update && apt-get install -y --no-install-recommends \
    python3.10 \
    python3.10-venv \
    python3.10-dev \
    python3-pip \
    curl \
    git \
    build-essential && \
    apt-get clean && rm -rf /var/lib/apt/lists/*

# Set Python 3.10 as the default
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 && \
    python --version

# Install Poetry
RUN curl -sSL https://install.python-poetry.org | python3 - && \
    ln -s /root/.local/bin/poetry /usr/local/bin/poetry

# Configure Poetry environment
ENV POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_IN_PROJECT=1 \
    POETRY_CACHE_DIR=/tmp/poetry_cache

# Set the working directory to /app
WORKDIR /app

# Copy pyproject.toml and poetry.lock to install dependencies
COPY pyproject.toml poetry.lock /app/

# Install Python dependencies without building the app itself
RUN --mount=type=cache,target=/tmp/poetry_cache poetry install --only main --no-root

# Stage 2: Build stage for the application
FROM base as builder

# Copy application source code and necessary files
COPY src /app/src
COPY configs /app/configs
COPY .project-root /app/.project-root
COPY main.py /app/main.py

# Stage 3: Final runtime stage
FROM base as runner

# Copy application source code and dependencies from the builder stage
COPY --from=builder /app/src /app/src
COPY --from=builder /app/configs /app/configs
COPY --from=builder /app/.project-root /app/.project-root
COPY --from=builder /app/main.py /app/main.py
COPY --from=builder /app/.venv /app/.venv

# Copy client files
COPY run_client.sh /app/run_client.sh

# Set the working directory to /app
WORKDIR /app

# Add virtual environment to PATH
ENV PATH="/app/.venv/bin:$PATH"

# Install PyTorch with CUDA 12.2 support (adjusted for compatibility)
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu122

# Default command to run the application
CMD ["python", "-m", "main"]
app.py
ADDED
@@ -0,0 +1,115 @@
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from torchvision import transforms
from src.models.catdog_model_resnet import ResnetClassifier
from src.utils.aws_s3_services import S3Handler
from src.utils.logging_utils import setup_logger
from loguru import logger
import rootutils

# Load environment variables and configure logger
setup_logger(Path("./logs") / "gradio_app.log")
# Setup root directory
root = rootutils.setup_root(__file__, indicator=".project-root")


class ImageClassifier:
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.classes = cfg.labels

        # Download and load model from S3
        logger.info("Downloading model from S3...")
        s3_handler = S3Handler(bucket_name="deep-bucket-s3")
        s3_handler.download_folder("checkpoints", "checkpoints")

        logger.info("Loading model checkpoint...")
        self.model = ResnetClassifier.load_from_checkpoint(
            checkpoint_path=cfg.ckpt_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()

        # Image transform
        self.transform = transforms.Compose(
            [
                transforms.Resize((cfg.data.image_size, cfg.data.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        if image is None:
            return "No image provided.", None

        # Preprocess the image
        logger.info("Processing input image...")
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            output = self.model(img_tensor)
            probabilities = F.softmax(output, dim=1)
            predicted_class_idx = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class_idx].item()

        predicted_label = self.classes[predicted_class_idx]
        logger.info(f"Prediction: {predicted_label} (Confidence: {confidence:.2f})")
        return predicted_label, confidence


def create_gradio_app(cfg):
    classifier = ImageClassifier(cfg)

    def classify_image(image):
        """Gradio interface function."""
        predicted_label, confidence = classifier.predict(image)
        if predicted_label:
            return f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
        return "Error during prediction."

    # Create Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Cat vs Dog Classifier
            Upload an image of a cat or a dog to classify it with confidence.
            """
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image", type="pil", image_mode="RGB"
                )
                predict_button = gr.Button("Classify")
            with gr.Column():
                output_text = gr.Textbox(label="Prediction")

        # Define interaction
        predict_button.click(
            fn=classify_image, inputs=[input_image], outputs=[output_text]
        )

    return demo


# Hydra config wrapper for launching Gradio app
if __name__ == "__main__":
    import hydra
    from omegaconf import DictConfig

    @hydra.main(config_path="configs", config_name="infer", version_base="1.3")
    def main(cfg: DictConfig):
        logger.info("Launching Gradio App...")
        demo = create_gradio_app(cfg)
        demo.launch(share=True, server_name="0.0.0.0", server_port=7860)

    main()
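The Gradio UI above is a thin wrapper around `ImageClassifier`, so the classifier can also be exercised directly. The sketch below is illustrative rather than part of this PR: the config keys mirror `configs/infer.yaml`, the resolved checkpoint path is assumed to be `checkpoints/best-checkpoint.ckpt`, and `ImageClassifier.__init__` first pulls the `checkpoints` folder from the `deep-bucket-s3` bucket, so AWS credentials must be available.

```python
# Minimal sketch (not part of the PR) for running the classifier without Gradio.
from omegaconf import OmegaConf
from PIL import Image

from app import ImageClassifier

cfg = OmegaConf.create(
    {
        "labels": ["cat", "dog"],                         # as in configs/infer.yaml
        "ckpt_path": "checkpoints/best-checkpoint.ckpt",  # assumed resolved path
        "data": {"image_size": 224},
    }
)

classifier = ImageClassifier(cfg)  # downloads checkpoints from S3, then loads the model
label, confidence = classifier.predict(Image.open("image.jpg"))
print(f"{label}: {confidence:.2f}")
```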
basic_setup.md
ADDED
@@ -0,0 +1,419 @@
## __POETRY SETUP__

```bash
# Install poetry
conda create -n poetry_env python=3.10 -y
conda activate poetry_env
pip install poetry
poetry env info
poetry new pytorch_project
cd pytorch_project/
# fill up the pyproject.toml file without pytorch and torchvision
poetry install

# Add dependencies to the project for pytorch and torchvision
poetry source add --priority explicit pytorch_cpu https://download.pytorch.org/whl/cpu
poetry add --source pytorch_cpu torch torchvision
poetry lock
poetry show
poetry install --no-root

# Add dependencies to the project
poetry add matplotlib
poetry add hydra-core
poetry add omegaconf
poetry add hydra_colorlog
poetry add --dev black
poetry lock
poetry show

# Normal dependency: required for the app to run in production -> poetry add <package>
# Development dependency: needed only during development (e.g., testing, linting) -> poetry add --dev <package>

# Add dependencies to the project with specific version
poetry add <package_name>@<version>
```

## __MULTISTAGEDOCKER SETUP__

#### Step-by-Step Guide to Creating Dockerfile and docker-compose.yml for a New Code Repo

If you're new to the project and need to set up Docker and Docker Compose to run the training and inference steps, follow these steps.

---

### 1. Setting Up the Dockerfile

A Dockerfile is a set of instructions that Docker uses to create an image. In this case, we'll use a __multi-stage build__ to make the final image lightweight while managing dependencies with `Poetry`.

#### Step-by-Step Process for Creating the Dockerfile

1. __Choose a Base Image__:
   - We need to choose a Python image that matches the project's required version (e.g., Python 3.10.14).
   - Use the lightweight __`slim`__ version to minimize image size.

   ```Dockerfile
   FROM python:3.10.14-slim as builder
   ```

2. __Install Dependencies in the Build Stage__:
   - We'll use __Poetry__ for dependency management. Install it using `pip`.
   - Next, copy the `pyproject.toml` and `poetry.lock` files to the `/app` directory to install dependencies.

   ```Dockerfile
   RUN pip3 install poetry==1.7.1
   WORKDIR /app
   COPY pytorch_project/pyproject.toml pytorch_project/poetry.lock /app/
   ```

3. __Configure Poetry__:
   - Configure Poetry to install the dependencies in a virtual environment inside the project directory (not globally). This keeps everything contained and avoids conflicts with the system environment.

   ```Dockerfile
   ENV POETRY_NO_INTERACTION=1 \
       POETRY_VIRTUALENVS_IN_PROJECT=1 \
       POETRY_VIRTUALENVS_CREATE=true \
       POETRY_CACHE_DIR=/tmp/poetry_cache
   ```

4. __Install Dependencies__:
   - Use `poetry install --no-root` to install only the dependencies and not the package itself. This is because you typically don't need to install the actual project code at this stage.

   ```Dockerfile
   RUN --mount=type=cache,target=/tmp/poetry_cache poetry install --only main --no-root
   ```

5. __Build the Runtime Stage__:
   - Now, set up the final runtime image. This stage will only include the required application code and the virtual environment created in the first stage.
   - The final image will use the same Python base image but remain small by avoiding the re-installation of dependencies.

   ```Dockerfile
   FROM python:3.10.14-slim as runner
   WORKDIR /app
   COPY src /app/src
   COPY --from=builder /app/.venv /app/.venv
   ```

6. __Set Up the Path to Use the Virtual Environment__:
   - Update the `PATH` environment variable to use the Python binaries from the virtual environment.

   ```Dockerfile
   ENV PATH="/app/.venv/bin:$PATH"
   ```

7. __Set a Default Command__:
   - Finally, set the command that will be executed by default when the container is run. You can change or override this later in the Docker Compose file.

   ```Dockerfile
   CMD ["python", "-m", "src.train"]
   ```

### Final Dockerfile

```Dockerfile
# Stage 1: Build environment with Poetry and dependencies
FROM python:3.10.14-slim as builder
RUN pip3 install poetry==1.7.1
WORKDIR /app
COPY pytorch_project/pyproject.toml pytorch_project/poetry.lock /app/
ENV POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_IN_PROJECT=1 \
    POETRY_VIRTUALENVS_CREATE=true \
    POETRY_CACHE_DIR=/tmp/poetry_cache
RUN --mount=type=cache,target=/tmp/poetry_cache poetry install --only main --no-root

# Stage 2: Runtime environment
FROM python:3.10.14-slim as runner
WORKDIR /app
COPY src /app/src
COPY --from=builder /app/.venv /app/.venv
ENV PATH="/app/.venv/bin:$PATH"
CMD ["python", "-m", "src.train"]
```

---

### 2. Setting Up the docker-compose.yml File

The `docker-compose.yml` file is used to define and run multiple Docker containers as services. In this case, we need two services: one for __training__ and one for __inference__.

### Step-by-Step Process for Creating docker-compose.yml

1. __Define the Version__:
   - Docker Compose uses a versioning system. Use version `3.8`, which is widely supported and offers features such as networking and volume support.

   ```yaml
   version: '3.8'
   ```

2. __Set Up the `train` Service__:
   - The `train` service is responsible for running the training script. It builds the Docker image, runs the training command, and uses volumes to store the data, checkpoints, and artifacts.

   ```yaml
   services:
     train:
       build:
         context: .
       command: python -m src.train
       volumes:
         - data:/app/data
         - checkpoints:/app/checkpoints
         - artifacts:/app/artifacts
       shm_size: '2g'  # Increase shared memory to prevent DataLoader issues
       networks:
         - default
       env_file:
         - .env  # Load environment variables
   ```

3. __Set Up the `inference` Service__:
   - The `inference` service runs after the training has completed. It waits for a file (e.g., `train_done.flag`) to be created by the training process and then runs the inference script.

   ```yaml
   inference:
     build:
       context: .
     command: /bin/bash -c "while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done; python -m src.infer"
     volumes:
       - checkpoints:/app/checkpoints
       - artifacts:/app/artifacts
     shm_size: '2g'
     networks:
       - default
     depends_on:
       - train
     env_file:
       - .env
   ```

4. __Define Shared Volumes__:
   - Volumes allow services to share data. Here, we define three shared volumes:
     - `data`: Stores the input data.
     - `checkpoints`: Stores the model checkpoints and the flag indicating training is complete.
     - `artifacts`: Stores the final model outputs or artifacts.

   ```yaml
   volumes:
     data:
     checkpoints:
     artifacts:
   ```

5. __Set Up Networking__:
   - Use the default network to allow the services to communicate.

   ```yaml
   networks:
     default:
   ```

### Final docker-compose.yml

```yaml
version: '3.8'

services:
  train:
    build:
      context: .
    command: python -m src.train
    volumes:
      - data:/app/data
      - checkpoints:/app/checkpoints
      - artifacts:/app/artifacts
    shm_size: '2g'
    networks:
      - default
    env_file:
      - .env

  inference:
    build:
      context: .
    command: /bin/bash -c "while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done; python -m src.infer"
    volumes:
      - checkpoints:/app/checkpoints
      - artifacts:/app/artifacts
    shm_size: '2g'
    networks:
      - default
    depends_on:
      - train
    env_file:
      - .env

volumes:
  data:
  checkpoints:
  artifacts:

networks:
  default:
```

---

### Summary

1. __Dockerfile__:
   - A multi-stage Dockerfile is used to create a lightweight image where the dependencies are installed with Poetry and the application code is run using a virtual environment.
   - It ensures that all dependencies are isolated in a virtual environment, and the final container only includes what is necessary for the runtime.

2. __docker-compose.yml__:
   - The `docker-compose.yml` file defines two services:
     - __train__: Runs the training script and stores checkpoints.
     - __inference__: Waits for the training to finish and runs inference based on the saved model.
   - Shared volumes ensure that the services can access data, checkpoints, and artifacts.
   - `shm_size` is increased to prevent issues with DataLoader in PyTorch when using multiple workers.

This setup allows for easy management of multiple services using Docker Compose, ensuring reproducibility and simplicity.

## __References__

- <https://stackoverflow.com/questions/53835198/integrating-python-poetry-with-docker>
- <https://github.com/fralik/poetry-with-private-repos/blob/master/Dockerfile>
- <https://medium.com/@albertazzir/blazing-fast-python-docker-builds-with-poetry-a78a66f5aed0>
- <https://www.martinrichards.me/post/python_poetry_docker/>
- <https://gist.github.com/soof-golan/6ebb97a792ccd87816c0bda1e6e8b8c2>

8. ## __DVC SETUP__

First, initialize dvc using the following commands

```bash
dvc init
dvc version
dvc init -f
dvc config core.autostage true
dvc add data
dvc remote add -d myremote /tmp/dvcstore
dvc push
```

Add some more files in the data directory and run the following commands

```bash
dvc add data
dvc push
dvc pull
```

Next go back one commit and run the following commands

```bash
git checkout HEAD~1
dvc checkout
# you will get one file less
```

Next go back to the latest commit and run the following commands

```bash
git checkout -
dvc checkout
dvc pull
dvc commit
```

Next run the following commands to add google drive as a remote

```bash
dvc remote add --default gdrive gdrive://1w2e3r4t5y6u7i8o9p0
dvc remote modify gdrive gdrive_acknowledge_abuse true
dvc remote modify gdrive gdrive_client_id <>
dvc remote modify gdrive gdrive_client_secret <>
# does not work when used from VM and port forwarding to local machine
```

Next run the following commands to add azure-blob as a remote

```bash
dvc remote remove azblob
dvc remote add --default azblob azure://mycontainer/myfolder
dvc remote modify --local azblob connection_string "<>"
dvc remote modify azblob allow_anonymous_login true
dvc push -r azblob
# this works when used and requires no explicit login
```

Next we will add S3 as a remote

```bash
dvc remote add --default aws_remote s3://deep-bucket-s3/data
dvc remote modify --local aws_remote access_key_id <>
dvc remote modify --local aws_remote secret_access_key <>
dvc remote modify --local aws_remote region ap-south-1
dvc remote modify aws_remote region ap-south-1
dvc push -r aws_remote -v
```

9. ## __HYDRA SETUP__

```bash
# Install hydra
pip install hydra-core hydra_colorlog omegaconf
# Fill up the configs folder with the files as per the project
# Run the following commands to run the hydra experiment
# for train
python -m src.hydra_test experiment=catdog_experiment ++task_name=train ++train=True ++test=False
# for eval
python -m src.hydra_test experiment=catdog_experiment ++task_name=eval ++train=False ++test=True
# for both
python -m src.hydra_test experiment=catdog_experiment task_name=train train=True test=True # + means adding a new key-value pair to the existing config and ++ means overriding an existing key-value pair
```

10. ## __LOCAL SETUP__

```bash
python -m src.train experiment=catdog_experiment ++task_name=train ++train=True ++test=False
python -m src.train experiment=catdog_experiment ++task_name=eval ++train=False ++test=True
python -m src.infer experiment=catdog_experiment
```

11. ## _DVC_PIPELINE_SETUP_

```bash
dvc repro
```

12. ## _DVC Experiments_

- To run the dvc experiments, keep different experiment_<>.yaml files in the configs folder under the experiment folder
- Make sure to override the default values in the experiment_<>.yaml file for each parameter that you want to change

13. ## _HYDRA Experiments_

- make sure to declare the config file in yaml format in the configs folder under hparam
- have hparam null in the train and eval config files
- run the following commands to run the hydra experiment

```bash
python -m src.train --multirun experiment=catdog_experiment_convnext ++task_name=train ++train=True ++test=False hparam=catdog_classifier_covnext
python -m src.create_artifacts
```

14. ## __Latest Execution Command__

```bash
python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=train ++train=True ++test=False
python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=test ++train=False ++test=True
python -m src.infer experiment=catdog_experiment
```

15. ## __GPU Setup__

```bash
docker build -t my-gpu-app .
docker run --gpus all my-gpu-app
docker exec -it <container_id> /bin/bash
# pytorch/pytorch:2.2.2-cuda12.1-cudnn8-runtime supports cuda 12.1 and python 3.10.14
```

```bash
# for docker compose, what we need is similar to the following
services:
  test:
    image: nvidia/cuda:12.3.1-base-ubuntu20.04
    command: nvidia-smi
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
```
client.py
ADDED
@@ -0,0 +1,18 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests

response = requests.post("http://127.0.0.1:8080/predict", json={"input": 4.0})
print(f"Status: {response.status_code}\nResponse:\n {response.text}")
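This is the stock Lightning AI example client (it posts `{"input": 4.0}` to `/predict`), and the `server:` block in `configs/infer.yaml` (port 8080, `max_batch_size`, `workers_per_device`, ...) suggests a LitServe service on the other end. The actual server module is not visible in this truncated 50-file view, so the following is only a minimal LitServe sketch that this client would work against, not the repository's implementation.

```python
# Minimal LitServe server sketch matching the stock client above (illustrative only;
# the PR's real server presumably wraps the trained ResnetClassifier).
import litserve as ls


class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Load weights / heavy resources once per worker.
        self.scale = 2

    def decode_request(self, request):
        # The client sends {"input": 4.0}
        return request["input"]

    def predict(self, x):
        return self.scale * x

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    server = ls.LitServer(SimpleLitAPI(), accelerator="auto", workers_per_device=2)
    server.run(port=8080)
```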
configs/callbacks/default.yaml
ADDED
@@ -0,0 +1,24 @@
defaults:
  - model_checkpoint
  - early_stopping
  - rich_model_summary
  - rich_progress_bar
  - _self_

model_checkpoint:
  dirpath: ${paths.ckpt_dir}
  monitor: "val_loss"
  mode: "min"
  save_last: False
  auto_insert_metric_name: False

early_stopping:
  monitor: "val_loss"
  patience: 3
  mode: "min"

rich_model_summary:
  max_depth: -1

rich_progress_bar:
  refresh_rate: 1
configs/callbacks/early_stopping.yaml
ADDED
@@ -0,0 +1,15 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html

early_stopping:
  _target_: lightning.pytorch.callbacks.EarlyStopping
  monitor: val_loss # quantity to be monitored, must be specified !!!
  min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
  patience: 3 # number of checks with no improvement after which training will be stopped
  verbose: False # verbosity mode
  mode: "min" # "max" means higher metric value is better, can be also "min"
  strict: True # whether to crash the training if monitor is not found in the validation metrics
  check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
  stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
  divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
  check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
  # log_rank_zero_only: False # this keyword argument isn't available in stable version
configs/callbacks/model_checkpoint.yaml
ADDED
@@ -0,0 +1,17 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html

model_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  dirpath: null # directory to save the model file
  filename: best-checkpoint # checkpoint filename
  monitor: val_loss # name of the logged metric which determines when model is improving
  verbose: False # verbosity mode
  save_last: False # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  save_top_k: 1 # save k best models (determined by above metric)
  mode: "min" # "max" means higher metric value is better, can be also "min"
  auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
  save_weights_only: False # if True, then only the model’s weights will be saved
  every_n_train_steps: null # number of training steps between checkpoints
  train_time_interval: null # checkpoints are monitored at the specified time interval
  every_n_epochs: null # number of epochs between checkpoints
  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
configs/callbacks/rich_model_summary.yaml
ADDED
@@ -0,0 +1,4 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
rich_model_summary:
  _target_: lightning.pytorch.callbacks.RichModelSummary
  max_depth: 1
configs/callbacks/rich_progress_bar.yaml
ADDED
@@ -0,0 +1,4 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichProgressBar.html
rich_progress_bar:
  _target_: lightning.pytorch.callbacks.RichProgressBar
  refresh_rate: 1
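Each callback config above declares a `_target_` class plus constructor arguments. A sketch of how such configs are typically turned into Lightning callback objects via Hydra is shown below; the project's actual training entrypoint (`src.train_optuna_callbacks`) is not part of this view, so treat the snippet as an illustration of the `_target_` mechanism rather than the repository's code.

```python
# Illustrative only: instantiate a Lightning callback from its Hydra config.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/callbacks/model_checkpoint.yaml")
# instantiate() reads `_target_` and builds lightning.pytorch.callbacks.ModelCheckpoint
checkpoint_cb = instantiate(cfg.model_checkpoint)
print(type(checkpoint_cb))
```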
configs/data/catdog.yaml
ADDED
@@ -0,0 +1,9 @@
_target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
root_dir: ${paths.data_dir}
data_dir: "cats_and_dogs_filtered"
url: ${paths.data_url}
num_workers: 4
batch_size: 32
train_val_split: [0.8, 0.2]
pin_memory: False
image_size: 224
configs/experiment/catdog_experiment.yaml
ADDED
@@ -0,0 +1,62 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=catdog_ex

defaults:
  - override /paths: catdog
  - override /data: catdog
  - override /model: catdog_classifier
  - override /callbacks: default
  - override /logger: default
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 42
name: "catdog_experiment"

data:
  data_dir: "cats_and_dogs_filtered"
  batch_size: 64
  num_workers: 8
  pin_memory: True
  image_size: 224

model:
  lr: 5e-5
  weight_decay: 1e-5
  factor: 0.5
  patience: 5
  min_lr: 1e-6
  num_classes: 2
  patch_size: 16
  embed_dim: 256
  depth: 4
  num_heads: 4
  mlp_ratio: 4

trainer:
  min_epochs: 1
  max_epochs: 5

callbacks:
  model_checkpoint:
    monitor: "val_acc"
    mode: "max"
    save_top_k: 1
    save_last: True
    verbose: True

  early_stopping:
    monitor: "val_acc"
    patience: 10
    mode: "max"
    verbose: True

  rich_model_summary:
    max_depth: 1

  rich_progress_bar:
    refresh_rate: 1
configs/experiment/catdog_experiment_resnet.yaml
ADDED
@@ -0,0 +1,59 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=catdog_ex

defaults:
  - override /paths: catdog
  - override /data: catdog
  - override /model: catdog_classifier_resnet
  - override /callbacks: default
  - override /logger: default
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 42
name: "catdog_experiment_resnet"

# Logger-specific configurations
logger:
  aim:
    experiment: ${name}
  mlflow:
    experiment_name: ${name}
    tags:
      model_type: "timm_classify"

data:
  batch_size: 64
  num_workers: 8
  pin_memory: True
  image_size: 160

model:
  base_model: efficientnet_b0
  pretrained: True
  lr: 1e-3
  weight_decay: 1e-5
  factor: 0.1
  patience: 5
  min_lr: 1e-6
  num_classes: 2

trainer:
  min_epochs: 1
  max_epochs: 5

callbacks:
  model_checkpoint:
    monitor: "val_acc"
    mode: "max"
    save_top_k: 1
    save_last: True

  early_stopping:
    monitor: "val_acc"
    patience: 3
    mode: "max"
configs/hydra/default.yaml
ADDED
@@ -0,0 +1,19 @@
# https://hydra.cc/docs/configure_hydra/intro/

# enable color logging
defaults:
  - override hydra_logging: colorlog
  - override job_logging: colorlog

# output directory, generated dynamically on each run
run:
  dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
  dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
  subdir: ${hydra.job.num}

job_logging:
  handlers:
    file:
      # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
      filename: ${hydra.runtime.output_dir}/${task_name}.log
configs/infer.yaml
ADDED
@@ -0,0 +1,52 @@
# @package _global_

# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
  - _self_
  - data: catdog
  - model: catdog_classifier
  - callbacks: default
  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
  - trainer: default
  - paths: dogbreed
  - hydra: default
  # experiment configs allow for version control of specific hyperparameters
  # e.g. best hyperparameters for given model and datamodule
  - experiment: catdog_experiment
  # debugging config (enable through command line, e.g. `python train.py debug=default`)
  - debug: null

# task name, determines output directory path
task_name: "infer"

# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]

# set False to skip model training
train: False

# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: False

# simply provide checkpoint path to resume training
ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt

# seed for random number generators in pytorch, numpy and python.random
seed: 42

# name of the experiment
name: "catdog_experiment"

server:
  port: 8080
  max_batch_size: 8
  batch_timeout: 0.01
  accelerator: "auto"
  devices: "auto"
  workers_per_device: 2

labels: ["cat", "dog"]
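The `server` block above is consumed by the LitServe-based inference server (see `app.py` and `client.py` in this repo). As a rough illustration only, a request against such a server usually looks like the sketch below; the `/predict` route and the `request` upload field are LitServe defaults assumed here, not details taken from this repo's client code.

```python
# Hypothetical client sketch: assumes the server from configs/infer.yaml is up on
# port 8080 and exposes LitServe's default POST /predict route taking a file upload.
import requests

with open("image.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/predict",   # server.port from configs/infer.yaml
        files={"request": f},              # field name is an assumption
        timeout=30,
    )
resp.raise_for_status()
print(resp.json())  # expected to map onto the labels ["cat", "dog"] above
```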
configs/logger/aim.yaml
ADDED
@@ -0,0 +1,6 @@
aim:
  _target_: aim.pytorch_lightning.AimLogger
  experiment: ${name}
  train_metric_prefix: train_
  test_metric_prefix: test_
  val_metric_prefix: val_
configs/logger/csv.yaml
ADDED
@@ -0,0 +1,7 @@
# csv logger built in lightning

csv:
  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
  save_dir: "${paths.output_dir}"
  name: "csv/"
  prefix: ""
configs/logger/default.yaml
ADDED
@@ -0,0 +1,5 @@
# train with many loggers at once

defaults:
  - csv
  - tensorboard
configs/logger/mlflow.yaml
ADDED
@@ -0,0 +1,9 @@
# MLflow logger configuration

mlflow:
  _target_: lightning.pytorch.loggers.MLFlowLogger
  experiment_name: ${name}
  tracking_uri: file:${paths.log_dir}/mlruns
  save_dir: ${paths.log_dir}/mlruns
  log_model: False
  prefix: ""
configs/logger/tensorboard.yaml
ADDED
@@ -0,0 +1,10 @@
# https://www.tensorflow.org/tensorboard/

tensorboard:
  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  save_dir: "${paths.output_dir}/tensorboard/"
  name: null
  log_graph: False
  default_hp_metric: True
  prefix: ""
  # version: ""
configs/model/catdog_classifier.yaml
ADDED
@@ -0,0 +1,22 @@

# model class
_target_: src.models.catdog_model.ViTTinyClassifier

# model params
img_size: ${data.image_size}
patch_size: 16
num_classes: 2
embed_dim: 128
depth: 6
num_heads: 4
mlp_ratio: 4
pre_norm: False

# optimizer params
lr: 1e-3
weight_decay: 1e-5

# scheduler params
factor: 0.1
patience: 10
min_lr: 1e-6
configs/model/catdog_classifier_resnet.yaml
ADDED
@@ -0,0 +1,13 @@
_target_: src.models.catdog_model_resnet.ResnetClassifier

# model params
base_model: efficientnet_b0
pretrained: True
num_classes: 2
# optimizer params
lr: 1e-3
weight_decay: 1e-5
# scheduler params
factor: 0.1
patience: 10
min_lr: 1e-6
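Both model configs rely on Hydra's `_target_` convention: the config is essentially a constructor call. The sketch below shows how such a config becomes a model object via `hydra.utils.instantiate`, with values copied from `catdog_classifier_resnet.yaml`; the extra keys are forwarded to the class constructor.

```python
# Sketch: turning a _target_-style model config into an object with Hydra.
from hydra.utils import instantiate
from omegaconf import OmegaConf

model_cfg = OmegaConf.create(
    {
        "_target_": "src.models.catdog_model_resnet.ResnetClassifier",
        "base_model": "efficientnet_b0",
        "pretrained": True,
        "num_classes": 2,
        "lr": 1e-3,
        "weight_decay": 1e-5,
    }
)
model = instantiate(model_cfg)  # equivalent to ResnetClassifier(base_model=..., ...)
```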
configs/paths/catdog.yaml
ADDED
@@ -0,0 +1,27 @@
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
root_dir: ${oc.env:PROJECT_ROOT}

# path to data directory
data_dir: ${paths.root_dir}/data/

# path to logging directory
log_dir: ${paths.root_dir}/logs/

# path to checkpoint directory
ckpt_dir: ${paths.root_dir}/checkpoints

# path to artifact directory
artifact_dir: ${paths.root_dir}/artifacts/

# download url for the dataset
data_url: "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip"

# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}

# path to working directory
work_dir: ${hydra:runtime.cwd}
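`root_dir` uses OmegaConf's built-in `oc.env` resolver, so `PROJECT_ROOT` must be set before the config is resolved (the code does this via `rootutils`/`.env`). A minimal sketch of how the interpolations resolve; the path is made up, and the keys are referenced relative to the snippet rather than as `${paths.*}`, since in the repo this file is merged under the `paths` group.

```python
# Sketch of the ${oc.env:...} and path interpolations above; values are illustrative.
import os
from omegaconf import OmegaConf

os.environ["PROJECT_ROOT"] = "/home/ubuntu/pytorch-template-aws"  # normally set by rootutils/.env

paths = OmegaConf.create(
    {
        "root_dir": "${oc.env:PROJECT_ROOT}",
        "data_dir": "${root_dir}/data/",
        "ckpt_dir": "${root_dir}/checkpoints",
    }
)
print(paths.data_dir)   # /home/ubuntu/pytorch-template-aws/data/
print(paths.ckpt_dir)   # /home/ubuntu/pytorch-template-aws/checkpoints
```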
configs/train.yaml
ADDED
@@ -0,0 +1,47 @@
# @package _global_

# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
  - _self_
  - data: catdog
  - model: catdog_classifier
  - callbacks: default
  - logger: default # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
  - trainer: default
  - paths: catdog
  - hydra: default

  - experiment: catdog_experiment
  # debugging config (enable through command line, e.g. `python train.py debug=default`)
  - debug: null

# task name, determines output directory path
task_name: "train"

# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]

# set False to skip model training
train: True

# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: False

# simply provide checkpoint path to resume training
ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt

# seed for random number generators in pytorch, numpy and python.random
seed: 42

# name of the experiment
name: "catdog_experiment"

# optimization metric
optimization_metric: "val_acc"

# optuna hyperparameter optimization
n_trials: 2
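`optimization_metric` and `n_trials` feed the Optuna sweep implemented in `src/train_optuna_callbacks.py`. The sketch below shows only the generic Optuna pattern those two keys imply (maximize `val_acc` over `n_trials` trials); the objective body is a placeholder, not the repo's implementation.

```python
# Generic Optuna pattern behind optimization_metric/n_trials -- a sketch only.
import optuna


def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    # ...build the LightningModule/Trainer with this lr, run trainer.fit(...),
    # then read trainer.callback_metrics["val_acc"]. Placeholder value below:
    val_acc = 1.0 - abs(lr - 1e-3)
    return val_acc


study = optuna.create_study(direction="maximize")  # "maximize" because val_acc
study.optimize(objective, n_trials=2)              # n_trials from configs/train.yaml
print(study.best_params)
```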
configs/trainer/default.yaml
ADDED
@@ -0,0 +1,20 @@

default_root_dir: ${paths.output_dir}
min_epochs: 1
max_epochs: 6

accelerator: auto
devices: auto

# mixed precision for extra speed-up
# precision: 16

# set True to ensure deterministic results; makes training slower but gives more reproducibility than just setting seeds
deterministic: True

# Log every N steps in training and validation
log_every_n_steps: 10
fast_dev_run: False

gradient_clip_val: 1.0
gradient_clip_algorithm: 'norm'
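The keys in `configs/trainer/default.yaml` map directly onto keyword arguments of `lightning.Trainer` (callbacks and loggers are attached separately from their own config groups). A sketch with the same values spelled out inline:

```python
# Sketch: configs/trainer/default.yaml expressed directly as lightning.Trainer kwargs.
import lightning as L

trainer = L.Trainer(
    default_root_dir="outputs/run",   # stands in for ${paths.output_dir}
    min_epochs=1,
    max_epochs=6,
    accelerator="auto",
    devices="auto",
    deterministic=True,
    log_every_n_steps=10,
    fast_dev_run=False,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
)
```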
data.dvc
ADDED
@@ -0,0 +1,6 @@
outs:
- md5: 1a2429ba45778743c46917f7e6b9b542.dir
  size: 97446370
  nfiles: 3002
  hash: md5
  path: data
docker-compose-old.yaml
ADDED
@@ -0,0 +1,74 @@
version: '3.8'

services:
  train:
    build:
      context: .
    command: |
      python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
      python -m src.create_artifacts && \
      touch ./checkpoints/train_done.flag
    volumes:
      - ./data:/app/data
      - ./checkpoints:/app/checkpoints
      - ./artifacts:/app/artifacts
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONPATH=/app
    shm_size: '4g'
    networks:
      - default
    env_file:
      - .env

  eval:
    build:
      context: .
    command: |
      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=test ++train=False ++test=True'
    volumes:
      - ./data:/app/data
      - ./checkpoints:/app/checkpoints
      - ./artifacts:/app/artifacts
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONPATH=/app
    shm_size: '4g'
    networks:
      - default
    env_file:
      - .env
    depends_on:
      - train

  inference:
    build:
      context: .
    command: |
      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.infer experiment=catdog_experiment'
    volumes:
      - ./data:/app/data
      - ./checkpoints:/app/checkpoints
      - ./artifacts:/app/artifacts
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONPATH=/app
    shm_size: '4g'
    networks:
      - default
    env_file:
      - .env
    depends_on:
      - train

volumes:
  data:
  checkpoints:
  artifacts:
  logs:

networks:
  default:
docker-compose.yaml
ADDED
@@ -0,0 +1,90 @@
services:
  train:
    build:
      context: .
    command: |
      python -m src.train_optuna_callbacks experiment=catdog_experiment_resnet ++task_name=train ++train=True ++test=False && \
      python -m src.create_artifacts && \
      touch ./checkpoints/train_done.flag
    volumes:
      - ./data:/app/data
      - ./checkpoints:/app/checkpoints
      - ./artifacts:/app/artifacts
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONPATH=/app
    shm_size: '4g'
    networks:
      - default
    env_file:
      - .env
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  eval:
    build:
      context: .
    command: |
      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment_resnet ++task_name=test ++train=False ++test=True'
    volumes:
      - ./data:/app/data
      - ./checkpoints:/app/checkpoints
      - ./artifacts:/app/artifacts
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONPATH=/app
    shm_size: '4g'
    networks:
      - default
    env_file:
      - .env
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  inference:
    build:
      context: .
    command: |
      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.infer experiment=catdog_experiment_resnet'
    volumes:
      - ./data:/app/data
      - ./checkpoints:/app/checkpoints
      - ./artifacts:/app/artifacts
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONPATH=/app
    shm_size: '4g'
    networks:
      - default
    env_file:
      - .env
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

volumes:
  data:
  checkpoints:
  artifacts:
  logs:

networks:
  default:
docker_compose_exec.sh
ADDED
@@ -0,0 +1,58 @@
#!/bin/bash

# Exit on any error
set -e

# Helper function to wait for a condition
wait_for_condition() {
    local condition=$1
    local description=$2
    echo "Waiting for $description..."
    while ! eval "$condition"; do
        echo "$description not ready. Retrying in 5 seconds..."
        sleep 5
    done
    echo "$description is ready!"
}

# Step 1: Stop and rebuild all containers
echo "Stopping all running services..."
docker-compose stop

echo "Building all services..."
docker-compose build

# Step 2: Start the train service
echo "Starting 'train' service..."
docker-compose up -d train

# Step 3: Wait for train to complete
wait_for_condition "[ -f ./checkpoints/train_done.flag ]" "'train' service to complete"

# Step 4: Start the eval service
echo "Starting 'eval' service..."
docker-compose up -d eval

# Step 5: Start the server service
echo "Starting 'server' service..."
docker-compose up -d server

# Step 6: Wait for the server to be healthy
wait_for_condition "curl -s http://localhost:8080/health" "'server' service to be ready"

# Step 7: Start the client service
echo "Starting 'client' service..."
docker-compose up -d client

# Step 8: Show all running services
echo "All services are up and running:"
docker-compose ps

# Step 9: Stop and remove all containers after completion
echo "Stopping all services..."
docker-compose stop

echo "Removing all stopped containers..."
docker-compose rm -f

echo "Workflow complete!"
dvc.lock
ADDED
@@ -0,0 +1,31 @@
schema: '2.0'
stages:
  train:
    cmd: docker-compose run --rm train
    deps:
    - path: data
      hash: md5
      md5: a372d6faac374b9f988d530864d0d7d5.dir
      size: 97446370
      nfiles: 3002
    - path: docker-compose.yaml
      hash: md5
      md5: 85a64185c917ce60ae28e32c20c70164
      size: 1735
      isexec: true
    - path: src/train.py
      hash: md5
      md5: 86b3871600a12f311e71dc171a2a37b9
      size: 5972
      isexec: true
    outs:
    - path: checkpoints/best-checkpoint.ckpt
      hash: md5
      md5: 6b6dcaa677324992489edaa51fc8b24f
      size: 3755038
      isexec: true
    - path: checkpoints/train_done.flag
      hash: md5
      md5: bfc5d6f6817daa48ad7ae164aa621dbf
      size: 20
      isexec: true
dvc.yaml
ADDED
@@ -0,0 +1,28 @@
stages:
  train:
    cmd: docker-compose run --rm train
    deps:
      - docker-compose.yaml
      - src/train_optuna_callbacks.py
      - src/create_artifacts.py
      - data
    outs:
      - checkpoints/best-checkpoint.ckpt
      - checkpoints/train_done.flag
  # eval:
  #   cmd: docker-compose run --rm eval
  #   deps:
  #     - docker-compose.yaml
  #     - src/train.py
  #     - checkpoints/best-checkpoint.ckpt
  #     - checkpoints/train_done.flag

  # inference:
  #   cmd: docker-compose run --rm inference
  #   deps:
  #     - docker-compose.yaml
  #     - src/infer.py
  #     - checkpoints/best-checkpoint.ckpt
  #     - checkpoints/train_done.flag
  #   outs:
  #     - artifacts/image_prediction.png
ec2_runner_setup.md
ADDED
@@ -0,0 +1,357 @@
**Install docker and docker-compose on Ubuntu 22.04**

__PreRequisites__:

* Have an AWS account with a user that has the necessary permissions
* Have the access key either in env variables or in the GitHub Actions secrets
* Have an EC2 runner instance running/created in the AWS account
* Have an S3 bucket created in the AWS account
* Have an AWS container registry created in the AWS account

__Local VM setup__:
* Install the AWS CLI and configure the access key, secret key and the right zone
```bash
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
unzip awscliv2.zip
sudo ./aws/install
aws configure
```

__Install docker__:
```bash
sudo apt update
sudo apt install -y apt-transport-https ca-certificates curl software-properties-common
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt update
sudo apt install -y docker-ce
sudo systemctl start docker
sudo systemctl enable docker
sudo usermod -aG docker $USER
sudo systemctl restart docker
sudo reboot
docker --version
docker ps
```
__Install docker-compose__:
```bash
sudo rm /usr/local/bin/docker-compose
sudo curl -L "https://github.com/docker/compose/releases/download/v2.30.0/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
docker-compose --version
```

__Github actions self-hosted runner__:
```bash
mkdir actions-runner && cd actions-runner
curl -o actions-runner-linux-x64-2.320.0.tar.gz -L https://github.com/actions/runner/releases/download/v2.320.0/actions-runner-linux-x64-2.320.0.tar.gz
echo "93ac1b7ce743ee85b5d386f5c1787385ef07b3d7c728ff66ce0d3813d5f46900 actions-runner-linux-x64-2.320.0.tar.gz" | shasum -a 256 -c
tar xzf ./actions-runner-linux-x64-2.320.0.tar.gz
./config.sh --url https://github.com/soutrik71/pytorch-template-aws --token <Latest>
# cd actions-runner/
./run.sh
./config.sh remove --token <> # To remove the runner
# https://github.com/soutrik71/pytorch-template-aws/settings/actions/runners/new?arch=x64&os=linux
```
__Activate aws cli__:
```bash
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
sudo apt install unzip
unzip awscliv2.zip
sudo ./aws/install
aws --version
aws configure
```
__S3 bucket operations__:
```bash
aws s3 cp data s3://deep-bucket-s3/data --recursive
aws s3 ls s3://deep-bucket-s3
aws s3 rm s3://deep-bucket-s3/data --recursive
```

__Cuda Update Setup__:
```bash
# if you already have nvidia drivers installed and you have a Tesla T4 GPU
sudo apt update
sudo apt upgrade
sudo reboot

sudo apt --fix-broken install
sudo apt install ubuntu-drivers-common
sudo apt autoremove

nvidia-smi
lsmod | grep nvidia

sudo apt install nvidia-cuda-toolkit
nvcc --version

ls /usr/local/ | grep cuda
ldconfig -p | grep cudnn
lspci | grep -i nvidia
```

Based on the provided details, here is the breakdown of the information about your GPU, CUDA, and environment setup:

---

### **1. GPU Details**
- **Model**: Tesla T4
  - A popular NVIDIA GPU for deep learning and AI workloads.
  - It belongs to the Turing architecture (TU104GL).

- **Memory**: 16 GB
  - Only **2 MiB is currently in use**, indicating minimal GPU activity.

- **Temperature**: 25°C
  - The GPU is operating at a low temperature, suggesting no heavy utilization currently.

- **Power Usage**: 11W / 70W
  - The GPU is in idle or low-performance mode (P8).

- **MIG Mode**: Not enabled.
  - MIG (Multi-Instance GPU) mode is specific to NVIDIA A100 and other GPUs, so it is not applicable here.

---

### **2. Driver and CUDA Version**
- **Driver Version**: 535.216.03
  - Installed NVIDIA driver supports CUDA 12.x.

- **CUDA Runtime Version**: 12.2
  - This is the active runtime version compatible with the driver.

---

### **3. CUDA Toolkit Versions**
From your `nvcc` and file system checks:
- **Default `nvcc` Version**: CUDA 10.1
  - The system's default `nvcc` is pointing to an older CUDA 10.1 installation (`nvcc --version` output shows CUDA 10.1).

- **Installed CUDA Toolkits**:
  - `cuda-12`
  - `cuda-12.2`
  - `cuda` (likely symlinked to `cuda-12.2`)

Multiple CUDA versions are installed. However, the runtime and drivers align with **CUDA 12.2**, while the default compiler (`nvcc`) is still from CUDA 10.1.

---

### **4. cuDNN Version**
From `cudnn_version.h` and `ldconfig`:
- **cuDNN Version**: 9.5.1
  - This cuDNN version is compatible with **CUDA 12.x**.
- **cuDNN Runtime**: The libraries for cuDNN 9 are present under `/lib/x86_64-linux-gnu`.

---

### **5. NVIDIA Software Packages**
From `dpkg`:
- **NVIDIA Drivers**: Driver version 535 is installed.
- **CUDA Toolkit**: Multiple versions installed (`10.1`, `12`, `12.2`).
- **cuDNN**: Versions for CUDA 12 and CUDA 12.6 are installed (`cudnn9-cuda-12`, `cudnn9-cuda-12-6`).

---

### **6. Other Observations**
- **Graphics Settings Issue**:
  - `nvidia-settings` failed due to the lack of a display server connection (`Connection refused`). Likely, this is a headless server without a GUI environment.

- **OpenGL Tools Missing**:
  - `glxinfo` command is missing, indicating the `mesa-utils` package needs to be installed.

---

### **Summary of Setup**
- **GPU**: Tesla T4
- **Driver Version**: 535.216.03
- **CUDA Runtime Version**: 12.2
- **CUDA Toolkit Versions**: 10.1 (default `nvcc`), 12, 12.2
- **cuDNN Version**: 9.5.1 (compatible with CUDA 12.x)
- **Software Packages**: NVIDIA drivers, CUDA, cuDNN installed

__CUDA New Installation__:
```bash
# if you don't have nvidia drivers installed and you have a Tesla T4 GPU
lspci | grep -i nvidia # Check if the GPU is detected
```
To set up the T4 GPU from scratch, starting with no drivers or CUDA tools, and replicating the above configurations and drivers, follow these steps:

---

### **1. Update System**
Ensure the system is updated:
```bash
sudo apt update && sudo apt upgrade -y
sudo reboot
```

---

### **2. Install NVIDIA Driver**
#### **a. Identify Required Driver**
The T4 GPU requires a compatible NVIDIA driver version. Based on your configurations, we will install **Driver 535**.

#### **b. Add NVIDIA Repository**
Add the official NVIDIA driver repository:
```bash
sudo apt install -y software-properties-common
sudo add-apt-repository -y ppa:graphics-drivers/ppa
sudo apt update
```

#### **c. Install Driver**
Install the driver for the T4 GPU:
```bash
sudo apt install -y nvidia-driver-535
```

#### **d. Verify Driver Installation**
Reboot the system and check the driver:
```bash
sudo reboot
nvidia-smi
```
This should display the GPU model and driver version.

---

### **3. Install CUDA Toolkit**
#### **a. Add CUDA Repository**
Download and install the CUDA 12.2 repository for Ubuntu 20.04:
```bash
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin
sudo mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda-repo-ubuntu2004-12-2-local_12.2.0-535.86.10-1_amd64.deb
sudo dpkg -i cuda-repo-ubuntu2004-12-2-local_12.2.0-535.86.10-1_amd64.deb
sudo cp /var/cuda-repo-ubuntu2004-12-2-local/cuda-*-keyring.gpg /usr/share/keyrings/
sudo apt update
```

#### **b. Install CUDA Toolkit**
Install CUDA 12.2:
```bash
sudo apt install -y cuda
```

#### **c. Set Up Environment Variables**
Add CUDA binaries to the PATH and library paths:
```bash
echo 'export PATH=/usr/local/cuda-12.2/bin:$PATH' >> ~/.bashrc
echo 'export LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64:$LD_LIBRARY_PATH' >> ~/.bashrc
source ~/.bashrc
```

#### **d. Verify CUDA Installation**
Check CUDA installation:
```bash
nvcc --version
nvidia-smi
```

---

### **4. Install cuDNN**
#### **a. Download cuDNN**
Download cuDNN 9.5.1 (compatible with CUDA 12.x) from the [NVIDIA cuDNN page](https://developer.nvidia.com/cudnn). You'll need to log in and download the appropriate `.deb` files for Ubuntu 20.04.

#### **b. Install cuDNN**
Install the downloaded `.deb` files:
```bash
sudo dpkg -i libcudnn9*.deb
```

#### **c. Verify cuDNN**
Check the installed version:
```bash
cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
```

---

### **5. Install NCCL and Other Libraries**
Install additional NVIDIA libraries (like NCCL) required for distributed deep learning:
```bash
sudo apt install -y libnccl2 libnccl-dev
```

---

### **6. Install PyTorch**
#### **a. Install Python Environment**
Install Python and `pip` if not already present:
```bash
sudo apt install -y python3 python3-pip
```

#### **b. Install PyTorch with CUDA 12.2**
Install PyTorch with the appropriate CUDA runtime:
```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu122
```

#### **c. Test PyTorch**
Run a quick test:
```python
import torch
print(torch.cuda.is_available())  # Should return True
print(torch.cuda.get_device_name(0))  # Should return "Tesla T4"
```

---

### **7. Optional: Install Nsight Tools**
For debugging and profiling:
```bash
sudo apt install -y nsight-compute nsight-systems
```

---

### **8. Check for OpenGL**
If you need OpenGL utilities (like `glxinfo`):
```bash
sudo apt install -y mesa-utils
glxinfo | grep "OpenGL version"
```

---

### **9. Validate Entire Setup**
Run the NVIDIA sample tests to confirm the configuration:
```bash
cd /usr/local/cuda-12.2/samples/1_Utilities/deviceQuery
make
./deviceQuery
```
If successful, it should show details of the T4 GPU.

---

### **Summary of Installed Components**
- **GPU**: Tesla T4
- **Driver**: 535
- **CUDA Toolkit**: 12.2
- **cuDNN**: 9.5.1
- **PyTorch**: Installed with CUDA 12.2 support

This setup ensures your system is ready for deep learning workloads with the T4 GPU.

Finally, install conda and create a new environment for the project, install PyTorch and torchvision in the new environment, install other dependencies like numpy, pandas and matplotlib, and run the project code in the new environment:
```python
>>> import torch
>>> print(torch.cuda.is_available())
>>> print(torch.cuda.get_device_name(0))
>>> print(torch.version.cuda)
```
__CUDA Docker Setup__:
```bash
# If you are using docker and want to run a container with CUDA support
sudo apt install -y nvidia-container-toolkit
nvidia-ctk --version
sudo systemctl restart docker
sudo systemctl status docker
docker run --rm --gpus all nvidia/cuda:12.2.0-base-ubuntu20.04 nvidia-smi
docker run --rm --gpus all nvidia/cuda:12.2.0-base-ubuntu20.04 nvcc --version
```
image.jpg
ADDED
main.py
ADDED
@@ -0,0 +1,5 @@
import torch

print(torch.cuda.is_available())  # Should return True if our GPU is enabled
print(torch.cuda.get_device_name(0))  # Should return "Tesla T4" if our GPU is enabled
print(torch.version.cuda)  # Should return "12.4" if our GPU is enabled
notebooks/datamodule_lightning.ipynb
ADDED
@@ -0,0 +1,301 @@
Markdown cell:
In this notebook we discuss the PyTorch Lightning DataModule API with images stored in a folder structure whose folder names are the class labels. We use the cats and dogs dataset from Kaggle ([download](https://www.kaggle.com/c/dogs-vs-cats/data)). The dataset contains 25000 images of cats and dogs; 20000 images are used for training and 5000 for validation.

Code cell:
%autosave 300
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
%config Completer.use_jedi = False

Code cell:
import os

os.chdir("..")
print(os.getcwd())

Code cell:
from pathlib import Path
from typing import Union, Tuple, Optional, List
import os
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive
from loguru import logger

Code cell:
class CatDogImageDataModule(L.LightningDataModule):
    """DataModule for Cat and Dog Image Classification using ImageFolder."""

    def __init__(
        self,
        data_root: Union[str, Path] = "data",
        data_dir: Union[str, Path] = "cats_and_dogs_filtered",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: List[float] = [0.8, 0.2],
        pin_memory: bool = False,
        image_size: int = 224,
        url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
    ):
        super().__init__()
        self.data_root = Path(data_root)
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_split = train_val_split
        self.pin_memory = pin_memory
        self.image_size = image_size
        self.url = url

        # Initialize variables for datasets
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download the dataset if it doesn't exist."""
        self.dataset_path = self.data_root / self.data_dir
        if not self.dataset_path.exists():
            logger.info("Downloading and extracting dataset.")
            download_and_extract_archive(
                url=self.url, download_root=self.data_root, remove_finished=True
            )
            logger.info("Download completed.")

    def setup(self, stage: Optional[str] = None):
        """Set up the train, validation, and test datasets."""

        train_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomHorizontalFlip(0.1),
                transforms.RandomRotation(10),
                transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
                transforms.RandomAutocontrast(0.1),
                transforms.RandomAdjustSharpness(2, 0.1),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        test_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        train_path = self.dataset_path / "train"
        test_path = self.dataset_path / "test"

        self.prepare_data()

        if stage == "fit" or stage is None:
            full_train_dataset = ImageFolder(root=train_path, transform=train_transform)
            self.class_names = full_train_dataset.classes
            train_size = int(self.train_val_split[0] * len(full_train_dataset))
            val_size = len(full_train_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_train_dataset, [train_size, val_size]
            )
            logger.info(
                f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images."
            )

        if stage == "test" or stage is None:
            self.test_dataset = ImageFolder(root=test_path, transform=test_transform)
            logger.info(f"Test dataset size: {len(self.test_dataset)} images.")

    def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Helper function to create a DataLoader."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        return self._create_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._create_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        return self._create_dataloader(self.test_dataset)

    def get_class_names(self) -> List[str]:
        return self.class_names

Code cell:
datamodule = CatDogImageDataModule(
    data_root="data",
    data_dir="cats_and_dogs_filtered",
    batch_size=32,
    num_workers=4,
    train_val_split=[0.8, 0.2],
    pin_memory=True,
    image_size=224,
    url="https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
)

Code cell:
datamodule.prepare_data()
datamodule.setup()
class_names = datamodule.get_class_names()
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()
test_dataloader = datamodule.test_dataloader()
Output: Train/Validation split: 2241 train, 561 validation images. Test dataset size: 198 images.

Code cell:
class_names
Output: ['cats', 'dogs']

Notebook metadata: kernel "emlo_env", Python 3.10.15.
notebooks/training_lightning_tests.ipynb
ADDED
@@ -0,0 +1,1011 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"data": {
|
10 |
+
"application/javascript": "IPython.notebook.set_autosave_interval(300000)"
|
11 |
+
},
|
12 |
+
"metadata": {},
|
13 |
+
"output_type": "display_data"
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"name": "stdout",
|
17 |
+
"output_type": "stream",
|
18 |
+
"text": [
|
19 |
+
"Autosaving every 300 seconds\n"
|
20 |
+
]
|
21 |
+
}
|
22 |
+
],
|
23 |
+
"source": [
|
24 |
+
"%autosave 300\n",
|
25 |
+
"%load_ext autoreload\n",
|
26 |
+
"%autoreload 2\n",
|
27 |
+
"%reload_ext autoreload\n",
|
28 |
+
"%config Completer.use_jedi = False"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 2,
|
34 |
+
"metadata": {},
|
35 |
+
"outputs": [
|
36 |
+
{
|
37 |
+
"name": "stdout",
|
38 |
+
"output_type": "stream",
|
39 |
+
"text": [
|
40 |
+
"/mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\n"
|
41 |
+
]
|
42 |
+
}
|
43 |
+
],
|
44 |
+
"source": [
|
45 |
+
"\n",
|
46 |
+
"import os\n",
|
47 |
+
"\n",
|
48 |
+
"os.chdir(\"..\")\n",
|
49 |
+
"print(os.getcwd())"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": 3,
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [
|
57 |
+
{
|
58 |
+
"name": "stderr",
|
59 |
+
"output_type": "stream",
|
60 |
+
"text": [
|
61 |
+
"/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
62 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
63 |
+
]
|
64 |
+
}
|
65 |
+
],
|
66 |
+
"source": [
|
67 |
+
"import os\n",
|
68 |
+
"import shutil\n",
|
69 |
+
"from pathlib import Path\n",
|
70 |
+
"import torch\n",
|
71 |
+
"import lightning as L\n",
|
72 |
+
"from lightning.pytorch.loggers import Logger\n",
|
73 |
+
"from typing import List\n",
|
74 |
+
"from src.datamodules.catdog_datamodule import CatDogImageDataModule\n",
|
75 |
+
"from src.utils.logging_utils import setup_logger, task_wrapper\n",
|
76 |
+
"from loguru import logger\n",
|
77 |
+
"from dotenv import load_dotenv, find_dotenv\n",
|
78 |
+
"import rootutils\n",
|
79 |
+
"import hydra\n",
|
80 |
+
"from omegaconf import DictConfig, OmegaConf\n",
|
81 |
+
"from lightning.pytorch.callbacks import (\n",
|
82 |
+
" ModelCheckpoint,\n",
|
83 |
+
" EarlyStopping,\n",
|
84 |
+
" RichModelSummary,\n",
|
85 |
+
" RichProgressBar,\n",
|
86 |
+
")\n",
|
87 |
+
"from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": 4,
|
93 |
+
"metadata": {},
|
94 |
+
"outputs": [
|
95 |
+
{
|
96 |
+
"name": "stderr",
|
97 |
+
"output_type": "stream",
|
98 |
+
"text": [
|
99 |
+
"\u001b[32m2024-11-08 18:25:17.572\u001b[0m | \u001b[31m\u001b[1mERROR \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m9\u001b[0m - \u001b[31m\u001b[1mname '__file__' is not defined\u001b[0m\n"
|
100 |
+
]
|
101 |
+
}
|
102 |
+
],
|
103 |
+
"source": [
|
104 |
+
"# Load environment variables\n",
|
105 |
+
"load_dotenv(find_dotenv(\".env\"))\n",
|
106 |
+
"\n",
|
107 |
+
"# Setup root directory\n",
|
108 |
+
"try:\n",
|
109 |
+
" root = rootutils.setup_root(__file__, indicator=\".project-root\")\n",
|
110 |
+
"\n",
|
111 |
+
"except Exception as e:\n",
|
112 |
+
" logger.error(e)\n",
|
113 |
+
" root = Path(os.getcwd())\n",
|
114 |
+
" os.environ[\"PROJECT_ROOT\"] = str(root)"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": 5,
|
120 |
+
"metadata": {},
|
121 |
+
"outputs": [],
|
122 |
+
"source": [
|
123 |
+
"def load_checkpoint_if_available(ckpt_path: str) -> str:\n",
|
124 |
+
" \"\"\"Check if the specified checkpoint exists and return the valid checkpoint path.\"\"\"\n",
|
125 |
+
" if ckpt_path and Path(ckpt_path).exists():\n",
|
126 |
+
" logger.info(f\"Checkpoint found: {ckpt_path}\")\n",
|
127 |
+
" return ckpt_path\n",
|
128 |
+
" else:\n",
|
129 |
+
" logger.warning(\n",
|
130 |
+
" f\"No checkpoint found at {ckpt_path}. Using current model weights.\"\n",
|
131 |
+
" )\n",
|
132 |
+
" return None\n",
|
133 |
+
"\n",
|
134 |
+
"\n",
|
135 |
+
"def clear_checkpoint_directory(ckpt_dir: str):\n",
|
136 |
+
" \"\"\"Clear all contents of the checkpoint directory without deleting the directory itself.\"\"\"\n",
|
137 |
+
" ckpt_dir_path = Path(ckpt_dir)\n",
|
138 |
+
" if ckpt_dir_path.exists() and ckpt_dir_path.is_dir():\n",
|
139 |
+
" logger.info(f\"Clearing checkpoint directory: {ckpt_dir}\")\n",
|
140 |
+
" # Iterate over all files and directories in the checkpoint directory and remove them\n",
|
141 |
+
" for item in ckpt_dir_path.iterdir():\n",
|
142 |
+
" try:\n",
|
143 |
+
" if item.is_file() or item.is_symlink():\n",
|
144 |
+
" item.unlink() # Remove file or symlink\n",
|
145 |
+
" elif item.is_dir():\n",
|
146 |
+
" shutil.rmtree(item) # Remove directory\n",
|
147 |
+
" except Exception as e:\n",
|
148 |
+
" logger.error(f\"Failed to delete {item}: {e}\")\n",
|
149 |
+
" logger.info(f\"Checkpoint directory cleared: {ckpt_dir}\")\n",
|
150 |
+
" else:\n",
|
151 |
+
" logger.info(\n",
|
152 |
+
" f\"Checkpoint directory does not exist. Creating directory: {ckpt_dir}\"\n",
|
153 |
+
" )\n",
|
154 |
+
" os.makedirs(ckpt_dir_path, exist_ok=True)\n",
|
155 |
+
"\n",
|
156 |
+
"\n",
|
157 |
+
"@task_wrapper\n",
|
158 |
+
"def train_module(\n",
|
159 |
+
" cfg: DictConfig,\n",
|
160 |
+
" data_module: L.LightningDataModule,\n",
|
161 |
+
" model: L.LightningModule,\n",
|
162 |
+
" trainer: L.Trainer,\n",
|
163 |
+
"):\n",
|
164 |
+
" \"\"\"Train the model using the provided Trainer and DataModule.\"\"\"\n",
|
165 |
+
" logger.info(\"Training the model\")\n",
|
166 |
+
" trainer.fit(model, data_module)\n",
|
167 |
+
" train_metrics = trainer.callback_metrics\n",
|
168 |
+
" try:\n",
|
169 |
+
" logger.info(\n",
|
170 |
+
" f\"Training completed with the following metrics- train_acc: {train_metrics['train_acc'].item()} and val_acc: {train_metrics['val_acc'].item()}\"\n",
|
171 |
+
" )\n",
|
172 |
+
" except KeyError:\n",
|
173 |
+
" logger.info(f\"Training completed with the following metrics:{train_metrics}\")\n",
|
174 |
+
"\n",
|
175 |
+
" return train_metrics\n",
|
176 |
+
"\n",
|
177 |
+
"\n",
|
178 |
+
"@task_wrapper\n",
|
179 |
+
"def run_test_module(\n",
|
180 |
+
" cfg: DictConfig,\n",
|
181 |
+
" datamodule: L.LightningDataModule,\n",
|
182 |
+
" model: L.LightningModule,\n",
|
183 |
+
" trainer: L.Trainer,\n",
|
184 |
+
"):\n",
|
185 |
+
" \"\"\"Test the model using the best checkpoint or the current model weights.\"\"\"\n",
|
186 |
+
" logger.info(\"Testing the model\")\n",
|
187 |
+
" datamodule.setup(stage=\"test\")\n",
|
188 |
+
"\n",
|
189 |
+
" ckpt_path = load_checkpoint_if_available(cfg.ckpt_path)\n",
|
190 |
+
"\n",
|
191 |
+
" # If no checkpoint is available, Lightning will use current model weights\n",
|
192 |
+
" test_metrics = trainer.test(model, datamodule, ckpt_path=ckpt_path)\n",
|
193 |
+
" logger.info(f\"Test metrics:\\n{test_metrics}\")\n",
|
194 |
+
"\n",
|
195 |
+
" return test_metrics[0] if test_metrics else {}"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": 6,
|
201 |
+
"metadata": {},
|
202 |
+
"outputs": [
|
203 |
+
{
|
204 |
+
"name": "stderr",
|
205 |
+
"output_type": "stream",
|
206 |
+
"text": [
|
207 |
+
"/tmp/ipykernel_487789/541470590.py:8: UserWarning: \n",
|
208 |
+
"The version_base parameter is not specified.\n",
|
209 |
+
"Please specify a compatability version level, or None.\n",
|
210 |
+
"Will assume defaults for version 1.1\n",
|
211 |
+
" with hydra.initialize(config_path=\"../configs\"):\n"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"name": "stdout",
|
216 |
+
"output_type": "stream",
|
217 |
+
"text": [
|
218 |
+
"Full Configuration:\n",
|
219 |
+
"task_name: train\n",
|
220 |
+
"tags:\n",
|
221 |
+
"- dev\n",
|
222 |
+
"train: true\n",
|
223 |
+
"test: false\n",
|
224 |
+
"ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt\n",
|
225 |
+
"seed: 42\n",
|
226 |
+
"name: catdog_experiment\n",
|
227 |
+
"data:\n",
|
228 |
+
" _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule\n",
|
229 |
+
" data_dir: ${paths.data_dir}\n",
|
230 |
+
" url: ${paths.data_url}\n",
|
231 |
+
" num_workers: 8\n",
|
232 |
+
" batch_size: 64\n",
|
233 |
+
" train_val_split:\n",
|
234 |
+
" - 0.8\n",
|
235 |
+
" - 0.2\n",
|
236 |
+
" pin_memory: true\n",
|
237 |
+
" image_size: 160\n",
|
238 |
+
"model:\n",
|
239 |
+
" _target_: src.models.catdog_model.ViTTinyClassifier\n",
|
240 |
+
" img_size: 160\n",
|
241 |
+
" patch_size: 16\n",
|
242 |
+
" num_classes: 2\n",
|
243 |
+
" embed_dim: 64\n",
|
244 |
+
" depth: 6\n",
|
245 |
+
" num_heads: 2\n",
|
246 |
+
" mlp_ratio: 3\n",
|
247 |
+
" pre_norm: false\n",
|
248 |
+
" lr: 0.001\n",
|
249 |
+
" weight_decay: 1.0e-05\n",
|
250 |
+
" factor: 0.1\n",
|
251 |
+
" patience: 10\n",
|
252 |
+
" min_lr: 1.0e-06\n",
|
253 |
+
"callbacks:\n",
|
254 |
+
" model_checkpoint:\n",
|
255 |
+
" dirpath: ${paths.ckpt_dir}\n",
|
256 |
+
" filename: best-checkpoint\n",
|
257 |
+
" monitor: val_acc\n",
|
258 |
+
" verbose: true\n",
|
259 |
+
" save_last: true\n",
|
260 |
+
" save_top_k: 1\n",
|
261 |
+
" mode: max\n",
|
262 |
+
" auto_insert_metric_name: false\n",
|
263 |
+
" save_weights_only: false\n",
|
264 |
+
" every_n_train_steps: null\n",
|
265 |
+
" train_time_interval: null\n",
|
266 |
+
" every_n_epochs: null\n",
|
267 |
+
" save_on_train_epoch_end: null\n",
|
268 |
+
" early_stopping:\n",
|
269 |
+
" monitor: val_acc\n",
|
270 |
+
" min_delta: 0.0\n",
|
271 |
+
" patience: 10\n",
|
272 |
+
" verbose: true\n",
|
273 |
+
" mode: max\n",
|
274 |
+
" strict: true\n",
|
275 |
+
" check_finite: true\n",
|
276 |
+
" stopping_threshold: null\n",
|
277 |
+
" divergence_threshold: null\n",
|
278 |
+
" check_on_train_epoch_end: null\n",
|
279 |
+
" rich_model_summary:\n",
|
280 |
+
" max_depth: 1\n",
|
281 |
+
" rich_progress_bar:\n",
|
282 |
+
" refresh_rate: 1\n",
|
283 |
+
"logger:\n",
|
284 |
+
" csv:\n",
|
285 |
+
" save_dir: ${paths.output_dir}\n",
|
286 |
+
" name: csv/\n",
|
287 |
+
" prefix: ''\n",
|
288 |
+
" tensorboard:\n",
|
289 |
+
" save_dir: ${paths.output_dir}/tensorboard/\n",
|
290 |
+
" name: null\n",
|
291 |
+
" log_graph: false\n",
|
292 |
+
" default_hp_metric: true\n",
|
293 |
+
" prefix: ''\n",
|
294 |
+
"trainer:\n",
|
295 |
+
" _target_: lightning.Trainer\n",
|
296 |
+
" default_root_dir: ${paths.output_dir}\n",
|
297 |
+
" min_epochs: 1\n",
|
298 |
+
" max_epochs: 6\n",
|
299 |
+
" accelerator: auto\n",
|
300 |
+
" devices: auto\n",
|
301 |
+
" deterministic: true\n",
|
302 |
+
" log_every_n_steps: 10\n",
|
303 |
+
" fast_dev_run: false\n",
|
304 |
+
"paths:\n",
|
305 |
+
" root_dir: ${oc.env:PROJECT_ROOT}\n",
|
306 |
+
" data_dir: ${paths.root_dir}/data/\n",
|
307 |
+
" log_dir: ${paths.root_dir}/logs/\n",
|
308 |
+
" ckpt_dir: ${paths.root_dir}/checkpoints\n",
|
309 |
+
" artifact_dir: ${paths.root_dir}/artifacts/\n",
|
310 |
+
" data_url: https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\n",
|
311 |
+
" output_dir: ${hydra:runtime.output_dir}\n",
|
312 |
+
" work_dir: ${hydra:runtime.cwd}\n",
|
313 |
+
"\n"
|
314 |
+
]
|
315 |
+
}
|
316 |
+
],
|
317 |
+
"source": [
|
318 |
+
"import hydra\n",
|
319 |
+
"from omegaconf import DictConfig, OmegaConf\n",
|
320 |
+
"\n",
|
321 |
+
"\n",
|
322 |
+
"# Function to load the configuration as an object without using the @hydra.main decorator\n",
|
323 |
+
"def load_config() -> DictConfig:\n",
|
324 |
+
" # Initialize the configuration context (e.g., \"../configs\" directory)\n",
|
325 |
+
" with hydra.initialize(config_path=\"../configs\"):\n",
|
326 |
+
" # Compose the configuration object with a specific config name (e.g., \"train\")\n",
|
327 |
+
" cfg = hydra.compose(config_name=\"train\")\n",
|
328 |
+
" return cfg\n",
|
329 |
+
"\n",
|
330 |
+
"\n",
|
331 |
+
"# Load the configuration\n",
|
332 |
+
"cfg = load_config()\n",
|
333 |
+
"\n",
|
334 |
+
"# Print the entire configuration for reference\n",
|
335 |
+
"print(\"Full Configuration:\")\n",
|
336 |
+
"print(OmegaConf.to_yaml(cfg))"
|
337 |
+
]
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"cell_type": "code",
|
341 |
+
"execution_count": 7,
|
342 |
+
"metadata": {},
|
343 |
+
"outputs": [
|
344 |
+
{
|
345 |
+
"name": "stderr",
|
346 |
+
"output_type": "stream",
|
347 |
+
"text": [
|
348 |
+
"\u001b[32m2024-11-08 18:25:23\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m8\u001b[0m - \u001b[1mWhole Config:\n",
|
349 |
+
"task_name: train\n",
|
350 |
+
"tags:\n",
|
351 |
+
"- dev\n",
|
352 |
+
"train: true\n",
|
353 |
+
"test: false\n",
|
354 |
+
"ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt\n",
|
355 |
+
"seed: 42\n",
|
356 |
+
"name: catdog_experiment\n",
|
357 |
+
"data:\n",
|
358 |
+
" _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule\n",
|
359 |
+
" data_dir: ${paths.data_dir}\n",
|
360 |
+
" url: ${paths.data_url}\n",
|
361 |
+
" num_workers: 8\n",
|
362 |
+
" batch_size: 64\n",
|
363 |
+
" train_val_split:\n",
|
364 |
+
" - 0.8\n",
|
365 |
+
" - 0.2\n",
|
366 |
+
" pin_memory: true\n",
|
367 |
+
" image_size: 160\n",
|
368 |
+
"model:\n",
|
369 |
+
" _target_: src.models.catdog_model.ViTTinyClassifier\n",
|
370 |
+
" img_size: 160\n",
|
371 |
+
" patch_size: 16\n",
|
372 |
+
" num_classes: 2\n",
|
373 |
+
" embed_dim: 64\n",
|
374 |
+
" depth: 6\n",
|
375 |
+
" num_heads: 2\n",
|
376 |
+
" mlp_ratio: 3\n",
|
377 |
+
" pre_norm: false\n",
|
378 |
+
" lr: 0.001\n",
|
379 |
+
" weight_decay: 1.0e-05\n",
|
380 |
+
" factor: 0.1\n",
|
381 |
+
" patience: 10\n",
|
382 |
+
" min_lr: 1.0e-06\n",
|
383 |
+
"callbacks:\n",
|
384 |
+
" model_checkpoint:\n",
|
385 |
+
" dirpath: ${paths.ckpt_dir}\n",
|
386 |
+
" filename: best-checkpoint\n",
|
387 |
+
" monitor: val_acc\n",
|
388 |
+
" verbose: true\n",
|
389 |
+
" save_last: true\n",
|
390 |
+
" save_top_k: 1\n",
|
391 |
+
" mode: max\n",
|
392 |
+
" auto_insert_metric_name: false\n",
|
393 |
+
" save_weights_only: false\n",
|
394 |
+
" every_n_train_steps: null\n",
|
395 |
+
" train_time_interval: null\n",
|
396 |
+
" every_n_epochs: null\n",
|
397 |
+
" save_on_train_epoch_end: null\n",
|
398 |
+
" early_stopping:\n",
|
399 |
+
" monitor: val_acc\n",
|
400 |
+
" min_delta: 0.0\n",
|
401 |
+
" patience: 10\n",
|
402 |
+
" verbose: true\n",
|
403 |
+
" mode: max\n",
|
404 |
+
" strict: true\n",
|
405 |
+
" check_finite: true\n",
|
406 |
+
" stopping_threshold: null\n",
|
407 |
+
" divergence_threshold: null\n",
|
408 |
+
" check_on_train_epoch_end: null\n",
|
409 |
+
" rich_model_summary:\n",
|
410 |
+
" max_depth: 1\n",
|
411 |
+
" rich_progress_bar:\n",
|
412 |
+
" refresh_rate: 1\n",
|
413 |
+
"logger:\n",
|
414 |
+
" csv:\n",
|
415 |
+
" save_dir: ${paths.output_dir}\n",
|
416 |
+
" name: csv/\n",
|
417 |
+
" prefix: ''\n",
|
418 |
+
" tensorboard:\n",
|
419 |
+
" save_dir: ${paths.output_dir}/tensorboard/\n",
|
420 |
+
" name: null\n",
|
421 |
+
" log_graph: false\n",
|
422 |
+
" default_hp_metric: true\n",
|
423 |
+
" prefix: ''\n",
|
424 |
+
"trainer:\n",
|
425 |
+
" _target_: lightning.Trainer\n",
|
426 |
+
" default_root_dir: ${paths.output_dir}\n",
|
427 |
+
" min_epochs: 1\n",
|
428 |
+
" max_epochs: 6\n",
|
429 |
+
" accelerator: auto\n",
|
430 |
+
" devices: auto\n",
|
431 |
+
" deterministic: true\n",
|
432 |
+
" log_every_n_steps: 10\n",
|
433 |
+
" fast_dev_run: false\n",
|
434 |
+
"paths:\n",
|
435 |
+
" root_dir: ${oc.env:PROJECT_ROOT}\n",
|
436 |
+
" data_dir: ${paths.root_dir}/data/\n",
|
437 |
+
" log_dir: ${paths.root_dir}/logs/\n",
|
438 |
+
" ckpt_dir: ${paths.root_dir}/checkpoints\n",
|
439 |
+
" artifact_dir: ${paths.root_dir}/artifacts/\n",
|
440 |
+
" data_url: https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\n",
|
441 |
+
" output_dir: ${hydra:runtime.output_dir}\n",
|
442 |
+
" work_dir: ${hydra:runtime.cwd}\n",
|
443 |
+
"\u001b[0m\n"
|
444 |
+
]
|
445 |
+
}
|
446 |
+
],
|
447 |
+
"source": [
|
448 |
+
"# Initialize logger\n",
|
449 |
+
"if cfg.task_name == \"train\":\n",
|
450 |
+
" log_path = Path(cfg.paths.log_dir) / \"train.log\"\n",
|
451 |
+
"else:\n",
|
452 |
+
" log_path = Path(cfg.paths.log_dir) / \"eval.log\"\n",
|
453 |
+
"setup_logger(log_path)\n",
|
454 |
+
"\n",
|
455 |
+
"logger.info(f\"Whole Config:\\n{OmegaConf.to_yaml(cfg)}\")"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"cell_type": "code",
|
460 |
+
"execution_count": 8,
|
461 |
+
"metadata": {},
|
462 |
+
"outputs": [
|
463 |
+
{
|
464 |
+
"name": "stderr",
|
465 |
+
"output_type": "stream",
|
466 |
+
"text": [
|
467 |
+
"\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m3\u001b[0m - \u001b[1mRoot directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\u001b[0m\n",
|
468 |
+
"\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m5\u001b[0m - \u001b[1mCurrent working directory: ['.dvc', '.dvcignore', '.env', '.git', '.github', '.gitignore', '.project-root', 'aws', 'basic_setup.md', 'configs', 'data', 'data.dvc', 'docker-compose.yaml', 'Dockerfile', 'ec2_runner_setup.md', 'logs', 'main.py', 'notebooks', 'poetry.lock', 'pyproject.toml', 'README.md', 'setup_aws_ci.md', 'src', 'tests', 'todo.md']\u001b[0m\n",
|
469 |
+
"\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m8\u001b[0m - \u001b[1mCheckpoint directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/checkpoints\u001b[0m\n",
|
470 |
+
"\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m12\u001b[0m - \u001b[1mData directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/data/\u001b[0m\n",
|
471 |
+
"\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m16\u001b[0m - \u001b[1mLog directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/logs/\u001b[0m\n",
|
472 |
+
"\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m20\u001b[0m - \u001b[1mArtifact directory: /mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws/artifacts/\u001b[0m\n",
|
473 |
+
"\u001b[32m2024-11-08 18:25:25\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m28\u001b[0m - \u001b[1mExperiment name: catdog_experiment\u001b[0m\n"
|
474 |
+
]
|
475 |
+
}
|
476 |
+
],
|
477 |
+
"source": [
|
478 |
+
"# the path to the checkpoint directory\n",
|
479 |
+
"root_dir = cfg.paths.root_dir\n",
|
480 |
+
"logger.info(f\"Root directory: {root_dir}\")\n",
|
481 |
+
"\n",
|
482 |
+
"logger.info(f\"Current working directory: {os.listdir(root_dir)}\")\n",
|
483 |
+
"\n",
|
484 |
+
"ckpt_dir = cfg.paths.ckpt_dir\n",
|
485 |
+
"logger.info(f\"Checkpoint directory: {ckpt_dir}\")\n",
|
486 |
+
"\n",
|
487 |
+
"# the path to the data directory\n",
|
488 |
+
"data_dir = cfg.paths.data_dir\n",
|
489 |
+
"logger.info(f\"Data directory: {data_dir}\")\n",
|
490 |
+
"\n",
|
491 |
+
"# the path to the log directory\n",
|
492 |
+
"log_dir = cfg.paths.log_dir\n",
|
493 |
+
"logger.info(f\"Log directory: {log_dir}\")\n",
|
494 |
+
"\n",
|
495 |
+
"# the path to the artifact directory\n",
|
496 |
+
"artifact_dir = cfg.paths.artifact_dir\n",
|
497 |
+
"logger.info(f\"Artifact directory: {artifact_dir}\")\n",
|
498 |
+
"\n",
|
499 |
+
"# output directory\n",
|
500 |
+
"# output_dir = cfg.paths.output_dir\n",
|
501 |
+
"# logger.info(f\"Output directory: {output_dir}\")\n",
|
502 |
+
"\n",
|
503 |
+
"# name of the experiment\n",
|
504 |
+
"experiment_name = cfg.name\n",
|
505 |
+
"logger.info(f\"Experiment name: {experiment_name}\")\n"
|
506 |
+
]
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"cell_type": "code",
|
510 |
+
"execution_count": 9,
|
511 |
+
"metadata": {},
|
512 |
+
"outputs": [
|
513 |
+
{
|
514 |
+
"name": "stderr",
|
515 |
+
"output_type": "stream",
|
516 |
+
"text": [
|
517 |
+
"\u001b[32m2024-11-08 18:25:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mInstantiating datamodule <src.datamodules.catdog_datamodule.CatDogImageDataModule>\u001b[0m\n"
|
518 |
+
]
|
519 |
+
}
|
520 |
+
],
|
521 |
+
"source": [
|
522 |
+
"# Initialize DataModule\n",
|
523 |
+
"logger.info(f\"Instantiating datamodule <{cfg.data._target_}>\")\n",
|
524 |
+
"datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)"
|
525 |
+
]
|
526 |
+
},
|
527 |
+
{
|
528 |
+
"cell_type": "code",
|
529 |
+
"execution_count": 10,
|
530 |
+
"metadata": {},
|
531 |
+
"outputs": [
|
532 |
+
{
|
533 |
+
"name": "stderr",
|
534 |
+
"output_type": "stream",
|
535 |
+
"text": [
|
536 |
+
"\u001b[32m2024-11-08 18:25:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mNo GPU available\u001b[0m\n",
|
537 |
+
"Seed set to 42\n"
|
538 |
+
]
|
539 |
+
},
|
540 |
+
{
|
541 |
+
"data": {
|
542 |
+
"text/plain": [
|
543 |
+
"42"
|
544 |
+
]
|
545 |
+
},
|
546 |
+
"execution_count": 10,
|
547 |
+
"metadata": {},
|
548 |
+
"output_type": "execute_result"
|
549 |
+
}
|
550 |
+
],
|
551 |
+
"source": [
|
552 |
+
"# Check for GPU availability\n",
|
553 |
+
"logger.info(\"GPU available\" if torch.cuda.is_available() else \"No GPU available\")\n",
|
554 |
+
"\n",
|
555 |
+
"# Set seed for reproducibility\n",
|
556 |
+
"L.seed_everything(cfg.seed, workers=True)"
|
557 |
+
]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
"cell_type": "code",
|
561 |
+
"execution_count": 11,
|
562 |
+
"metadata": {},
|
563 |
+
"outputs": [
|
564 |
+
{
|
565 |
+
"name": "stderr",
|
566 |
+
"output_type": "stream",
|
567 |
+
"text": [
|
568 |
+
"\u001b[32m2024-11-08 18:25:29\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m2\u001b[0m - \u001b[1mInstantiating model <src.models.catdog_model.ViTTinyClassifier>\u001b[0m\n"
|
569 |
+
]
|
570 |
+
}
|
571 |
+
],
|
572 |
+
"source": [
|
573 |
+
"# Initialize model\n",
|
574 |
+
"logger.info(f\"Instantiating model <{cfg.model._target_}>\")\n",
|
575 |
+
"model: L.LightningModule = hydra.utils.instantiate(cfg.model)"
|
576 |
+
]
|
577 |
+
},
|
578 |
+
{
|
579 |
+
"cell_type": "code",
|
580 |
+
"execution_count": 12,
|
581 |
+
"metadata": {},
|
582 |
+
"outputs": [
|
583 |
+
{
|
584 |
+
"name": "stderr",
|
585 |
+
"output_type": "stream",
|
586 |
+
"text": [
|
587 |
+
"\u001b[32m2024-11-08 18:25:30\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m1\u001b[0m - \u001b[1mModel summary:\n",
|
588 |
+
"ViTTinyClassifier(\n",
|
589 |
+
" (model): VisionTransformer(\n",
|
590 |
+
" (patch_embed): PatchEmbed(\n",
|
591 |
+
" (proj): Conv2d(3, 64, kernel_size=(16, 16), stride=(16, 16))\n",
|
592 |
+
" (norm): Identity()\n",
|
593 |
+
" )\n",
|
594 |
+
" (pos_drop): Dropout(p=0.0, inplace=False)\n",
|
595 |
+
" (patch_drop): Identity()\n",
|
596 |
+
" (norm_pre): Identity()\n",
|
597 |
+
" (blocks): Sequential(\n",
|
598 |
+
" (0): Block(\n",
|
599 |
+
" (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
600 |
+
" (attn): Attention(\n",
|
601 |
+
" (qkv): Linear(in_features=64, out_features=192, bias=False)\n",
|
602 |
+
" (q_norm): Identity()\n",
|
603 |
+
" (k_norm): Identity()\n",
|
604 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
605 |
+
" (proj): Linear(in_features=64, out_features=64, bias=True)\n",
|
606 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
607 |
+
" )\n",
|
608 |
+
" (ls1): Identity()\n",
|
609 |
+
" (drop_path1): Identity()\n",
|
610 |
+
" (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
611 |
+
" (mlp): Mlp(\n",
|
612 |
+
" (fc1): Linear(in_features=64, out_features=192, bias=True)\n",
|
613 |
+
" (act): GELU(approximate='none')\n",
|
614 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
615 |
+
" (norm): Identity()\n",
|
616 |
+
" (fc2): Linear(in_features=192, out_features=64, bias=True)\n",
|
617 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
618 |
+
" )\n",
|
619 |
+
" (ls2): Identity()\n",
|
620 |
+
" (drop_path2): Identity()\n",
|
621 |
+
" )\n",
|
622 |
+
" (1): Block(\n",
|
623 |
+
" (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
624 |
+
" (attn): Attention(\n",
|
625 |
+
" (qkv): Linear(in_features=64, out_features=192, bias=False)\n",
|
626 |
+
" (q_norm): Identity()\n",
|
627 |
+
" (k_norm): Identity()\n",
|
628 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
629 |
+
" (proj): Linear(in_features=64, out_features=64, bias=True)\n",
|
630 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
631 |
+
" )\n",
|
632 |
+
" (ls1): Identity()\n",
|
633 |
+
" (drop_path1): Identity()\n",
|
634 |
+
" (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
635 |
+
" (mlp): Mlp(\n",
|
636 |
+
" (fc1): Linear(in_features=64, out_features=192, bias=True)\n",
|
637 |
+
" (act): GELU(approximate='none')\n",
|
638 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
639 |
+
" (norm): Identity()\n",
|
640 |
+
" (fc2): Linear(in_features=192, out_features=64, bias=True)\n",
|
641 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
642 |
+
" )\n",
|
643 |
+
" (ls2): Identity()\n",
|
644 |
+
" (drop_path2): Identity()\n",
|
645 |
+
" )\n",
|
646 |
+
" (2): Block(\n",
|
647 |
+
" (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
648 |
+
" (attn): Attention(\n",
|
649 |
+
" (qkv): Linear(in_features=64, out_features=192, bias=False)\n",
|
650 |
+
" (q_norm): Identity()\n",
|
651 |
+
" (k_norm): Identity()\n",
|
652 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
653 |
+
" (proj): Linear(in_features=64, out_features=64, bias=True)\n",
|
654 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
655 |
+
" )\n",
|
656 |
+
" (ls1): Identity()\n",
|
657 |
+
" (drop_path1): Identity()\n",
|
658 |
+
" (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
659 |
+
" (mlp): Mlp(\n",
|
660 |
+
" (fc1): Linear(in_features=64, out_features=192, bias=True)\n",
|
661 |
+
" (act): GELU(approximate='none')\n",
|
662 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
663 |
+
" (norm): Identity()\n",
|
664 |
+
" (fc2): Linear(in_features=192, out_features=64, bias=True)\n",
|
665 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
666 |
+
" )\n",
|
667 |
+
" (ls2): Identity()\n",
|
668 |
+
" (drop_path2): Identity()\n",
|
669 |
+
" )\n",
|
670 |
+
" (3): Block(\n",
|
671 |
+
" (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
672 |
+
" (attn): Attention(\n",
|
673 |
+
" (qkv): Linear(in_features=64, out_features=192, bias=False)\n",
|
674 |
+
" (q_norm): Identity()\n",
|
675 |
+
" (k_norm): Identity()\n",
|
676 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
677 |
+
" (proj): Linear(in_features=64, out_features=64, bias=True)\n",
|
678 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
679 |
+
" )\n",
|
680 |
+
" (ls1): Identity()\n",
|
681 |
+
" (drop_path1): Identity()\n",
|
682 |
+
" (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
683 |
+
" (mlp): Mlp(\n",
|
684 |
+
" (fc1): Linear(in_features=64, out_features=192, bias=True)\n",
|
685 |
+
" (act): GELU(approximate='none')\n",
|
686 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
687 |
+
" (norm): Identity()\n",
|
688 |
+
" (fc2): Linear(in_features=192, out_features=64, bias=True)\n",
|
689 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
690 |
+
" )\n",
|
691 |
+
" (ls2): Identity()\n",
|
692 |
+
" (drop_path2): Identity()\n",
|
693 |
+
" )\n",
|
694 |
+
" (4): Block(\n",
|
695 |
+
" (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
696 |
+
" (attn): Attention(\n",
|
697 |
+
" (qkv): Linear(in_features=64, out_features=192, bias=False)\n",
|
698 |
+
" (q_norm): Identity()\n",
|
699 |
+
" (k_norm): Identity()\n",
|
700 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
701 |
+
" (proj): Linear(in_features=64, out_features=64, bias=True)\n",
|
702 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
703 |
+
" )\n",
|
704 |
+
" (ls1): Identity()\n",
|
705 |
+
" (drop_path1): Identity()\n",
|
706 |
+
" (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
707 |
+
" (mlp): Mlp(\n",
|
708 |
+
" (fc1): Linear(in_features=64, out_features=192, bias=True)\n",
|
709 |
+
" (act): GELU(approximate='none')\n",
|
710 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
711 |
+
" (norm): Identity()\n",
|
712 |
+
" (fc2): Linear(in_features=192, out_features=64, bias=True)\n",
|
713 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
714 |
+
" )\n",
|
715 |
+
" (ls2): Identity()\n",
|
716 |
+
" (drop_path2): Identity()\n",
|
717 |
+
" )\n",
|
718 |
+
" (5): Block(\n",
|
719 |
+
" (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
720 |
+
" (attn): Attention(\n",
|
721 |
+
" (qkv): Linear(in_features=64, out_features=192, bias=False)\n",
|
722 |
+
" (q_norm): Identity()\n",
|
723 |
+
" (k_norm): Identity()\n",
|
724 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
725 |
+
" (proj): Linear(in_features=64, out_features=64, bias=True)\n",
|
726 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
727 |
+
" )\n",
|
728 |
+
" (ls1): Identity()\n",
|
729 |
+
" (drop_path1): Identity()\n",
|
730 |
+
" (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
731 |
+
" (mlp): Mlp(\n",
|
732 |
+
" (fc1): Linear(in_features=64, out_features=192, bias=True)\n",
|
733 |
+
" (act): GELU(approximate='none')\n",
|
734 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
735 |
+
" (norm): Identity()\n",
|
736 |
+
" (fc2): Linear(in_features=192, out_features=64, bias=True)\n",
|
737 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
738 |
+
" )\n",
|
739 |
+
" (ls2): Identity()\n",
|
740 |
+
" (drop_path2): Identity()\n",
|
741 |
+
" )\n",
|
742 |
+
" )\n",
|
743 |
+
" (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)\n",
|
744 |
+
" (fc_norm): Identity()\n",
|
745 |
+
" (head_drop): Dropout(p=0.0, inplace=False)\n",
|
746 |
+
" (head): Linear(in_features=64, out_features=2, bias=True)\n",
|
747 |
+
" )\n",
|
748 |
+
" (train_metrics): ModuleDict(\n",
|
749 |
+
" (accuracy): MulticlassAccuracy()\n",
|
750 |
+
" (precision): MulticlassPrecision()\n",
|
751 |
+
" (recall): MulticlassRecall()\n",
|
752 |
+
" (f1): MulticlassF1Score()\n",
|
753 |
+
" )\n",
|
754 |
+
" (val_metrics): ModuleDict(\n",
|
755 |
+
" (accuracy): MulticlassAccuracy()\n",
|
756 |
+
" (precision): MulticlassPrecision()\n",
|
757 |
+
" (recall): MulticlassRecall()\n",
|
758 |
+
" (f1): MulticlassF1Score()\n",
|
759 |
+
" )\n",
|
760 |
+
" (test_metrics): ModuleDict(\n",
|
761 |
+
" (accuracy): MulticlassAccuracy()\n",
|
762 |
+
" (precision): MulticlassPrecision()\n",
|
763 |
+
" (recall): MulticlassRecall()\n",
|
764 |
+
" (f1): MulticlassF1Score()\n",
|
765 |
+
" )\n",
|
766 |
+
" (criterion): CrossEntropyLoss()\n",
|
767 |
+
")\u001b[0m\n"
|
768 |
+
]
|
769 |
+
}
|
770 |
+
],
|
771 |
+
"source": [
|
772 |
+
"logger.info(f\"Model summary:\\n{model}\")"
|
773 |
+
]
|
774 |
+
},
|
775 |
+
{
|
776 |
+
"cell_type": "code",
|
777 |
+
"execution_count": 13,
|
778 |
+
"metadata": {},
|
779 |
+
"outputs": [],
|
780 |
+
"source": [
|
781 |
+
"def initialize_callbacks(cfg: DictConfig) -> List[L.Callback]:\n",
|
782 |
+
" \"\"\"Initialize the callbacks based on the configuration.\"\"\"\n",
|
783 |
+
" if not cfg:\n",
|
784 |
+
" logger.warning(\"No callback configs found! Skipping..\")\n",
|
785 |
+
" return callbacks\n",
|
786 |
+
"\n",
|
787 |
+
" if not isinstance(cfg, DictConfig):\n",
|
788 |
+
" raise TypeError(\"Callbacks config must be a DictConfig!\")\n",
|
789 |
+
" callbacks = []\n",
|
790 |
+
"\n",
|
791 |
+
" # Initialize the model checkpoint callback\n",
|
792 |
+
" model_checkpoint = ModelCheckpoint(**cfg.callbacks.model_checkpoint)\n",
|
793 |
+
" callbacks.append(model_checkpoint)\n",
|
794 |
+
"\n",
|
795 |
+
" # Initialize the early stopping callback\n",
|
796 |
+
" early_stopping = EarlyStopping(**cfg.callbacks.early_stopping)\n",
|
797 |
+
" callbacks.append(early_stopping)\n",
|
798 |
+
"\n",
|
799 |
+
" # Initialize the rich model summary callback\n",
|
800 |
+
" model_summary = RichModelSummary(**cfg.callbacks.rich_model_summary)\n",
|
801 |
+
" callbacks.append(model_summary)\n",
|
802 |
+
"\n",
|
803 |
+
" # Initialize the rich progress bar callback\n",
|
804 |
+
" progress_bar = RichProgressBar(**cfg.callbacks.rich_progress_bar)\n",
|
805 |
+
" callbacks.append(progress_bar)\n",
|
806 |
+
"\n",
|
807 |
+
" return callbacks\n",
|
808 |
+
"\n",
|
809 |
+
"\n",
|
810 |
+
"def initialize_logger(cfg: DictConfig) -> Logger:\n",
|
811 |
+
" \"\"\"Initialize the logger based on the configuration.\"\"\"\n",
|
812 |
+
" if not cfg:\n",
|
813 |
+
" logger.warning(\"No logger configs found! Skipping..\")\n",
|
814 |
+
" return None\n",
|
815 |
+
"\n",
|
816 |
+
" if not isinstance(cfg, DictConfig):\n",
|
817 |
+
" raise TypeError(\"Logger config must be a DictConfig!\")\n",
|
818 |
+
"\n",
|
819 |
+
" loggers = []\n",
|
820 |
+
"\n",
|
821 |
+
" # Initialize the TensorBoard logger\n",
|
822 |
+
" tensorboard_logger = TensorBoardLogger(**cfg.loggers.tensorboard)\n",
|
823 |
+
" loggers.append(tensorboard_logger)\n",
|
824 |
+
"\n",
|
825 |
+
" # Initialize the CSV logger\n",
|
826 |
+
" csv_logger = CSVLogger(**cfg.loggers.csv)\n",
|
827 |
+
" loggers.append(csv_logger)\n",
|
828 |
+
"\n",
|
829 |
+
" return loggers"
|
830 |
+
]
|
831 |
+
},
|
832 |
+
{
|
833 |
+
"cell_type": "code",
|
834 |
+
"execution_count": 1,
|
835 |
+
"metadata": {},
|
836 |
+
"outputs": [
|
837 |
+
{
|
838 |
+
"name": "stderr",
|
839 |
+
"output_type": "stream",
|
840 |
+
"text": [
|
841 |
+
"/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
842 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
843 |
+
]
|
844 |
+
},
|
845 |
+
{
|
846 |
+
"name": "stdout",
|
847 |
+
"output_type": "stream",
|
848 |
+
"text": [
|
849 |
+
"['bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_small', 'coat_tiny', 'coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224', 'coatnext_nano_rw_224', 'convformer_b36', 'convformer_m36', 'convformer_s18', 'convformer_s36', 'convit_base', 'convit_small', 'convit_tiny', 'convmixer_768_32', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', 'convnext_atto', 'convnext_atto_ols', 'convnext_base', 'convnext_femto', 'convnext_femto_ols', 'convnext_large', 'convnext_large_mlp', 'convnext_nano', 'convnext_nano_ols', 'convnext_pico', 'convnext_pico_ols', 'convnext_small', 'convnext_tiny', 'convnext_tiny_hnf', 'convnext_xlarge', 'convnext_xxlarge', 'convnextv2_atto', 'convnextv2_base', 'convnextv2_femto', 'convnextv2_huge', 'convnextv2_large', 'convnextv2_nano', 'convnextv2_pico', 'convnextv2_small', 'convnextv2_tiny', 'crossvit_9_240', 'crossvit_9_dagger_240', 'crossvit_15_240', 'crossvit_15_dagger_240', 'crossvit_15_dagger_408', 'crossvit_18_240', 'crossvit_18_dagger_240', 'crossvit_18_dagger_408', 'crossvit_base_240', 'crossvit_small_240', 'crossvit_tiny_240', 'cs3darknet_focus_l', 'cs3darknet_focus_m', 'cs3darknet_focus_s', 'cs3darknet_focus_x', 'cs3darknet_l', 'cs3darknet_m', 'cs3darknet_s', 'cs3darknet_x', 'cs3edgenet_x', 'cs3se_edgenet_x', 'cs3sedarknet_l', 'cs3sedarknet_x', 'cs3sedarknet_xdw', 'cspdarknet53', 'cspresnet50', 'cspresnet50d', 'cspresnet50w', 'cspresnext50', 'darknet17', 'darknet21', 'darknet53', 'darknetaa53', 'davit_base', 'davit_base_fl', 'davit_giant', 'davit_huge', 'davit_huge_fl', 'davit_large', 'davit_small', 'davit_tiny', 'deit3_base_patch16_224', 'deit3_base_patch16_384', 'deit3_huge_patch14_224', 'deit3_large_patch16_224', 'deit3_large_patch16_384', 'deit3_medium_patch16_224', 'deit3_small_patch16_224', 'deit3_small_patch16_384', 'deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384', 'deit_small_distilled_patch16_224', 'deit_small_patch16_224', 'deit_tiny_distilled_patch16_224', 'deit_tiny_patch16_224', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenet264d', 'densenetblur121d', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60_res2net', 'dla60_res2next', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169', 'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4', 'dm_nfnet_f5', 'dm_nfnet_f6', 'dpn48b', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'eca_botnext26ts_256', 'eca_halonext26ts', 'eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'eca_nfnet_l3', 'eca_resnet33ts', 'eca_resnext26ts', 'eca_vovnet39b', 'ecaresnet26t', 
'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet200d', 'ecaresnet269d', 'ecaresnetlight', 'ecaresnext26t_32x4d', 'ecaresnext50t_32x4d', 'edgenext_base', 'edgenext_small', 'edgenext_small_rw', 'edgenext_x_small', 'edgenext_xx_small', 'efficientformer_l1', 'efficientformer_l3', 'efficientformer_l7', 'efficientformerv2_l', 'efficientformerv2_s0', 'efficientformerv2_s1', 'efficientformerv2_s2', 'efficientnet_b0', 'efficientnet_b0_g8_gn', 'efficientnet_b0_g16_evos', 'efficientnet_b0_gn', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2', 'efficientnet_b2_pruned', 'efficientnet_b3', 'efficientnet_b3_g8_gn', 'efficientnet_b3_gn', 'efficientnet_b3_pruned', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_blur_b0', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'efficientnet_el', 'efficientnet_el_pruned', 'efficientnet_em', 'efficientnet_es', 'efficientnet_es_pruned', 'efficientnet_h_b5', 'efficientnet_l2', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'efficientnet_x_b3', 'efficientnet_x_b5', 'efficientnetv2_l', 'efficientnetv2_m', 'efficientnetv2_rw_m', 'efficientnetv2_rw_s', 'efficientnetv2_rw_t', 'efficientnetv2_s', 'efficientnetv2_xl', 'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3', 'efficientvit_l1', 'efficientvit_l2', 'efficientvit_l3', 'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2', 'efficientvit_m3', 'efficientvit_m4', 'efficientvit_m5', 'ese_vovnet19b_dw', 'ese_vovnet19b_slim', 'ese_vovnet19b_slim_dw', 'ese_vovnet39b', 'ese_vovnet39b_evos', 'ese_vovnet57b', 'ese_vovnet99b', 'eva02_base_patch14_224', 'eva02_base_patch14_448', 'eva02_base_patch16_clip_224', 'eva02_enormous_patch14_clip_224', 'eva02_large_patch14_224', 'eva02_large_patch14_448', 'eva02_large_patch14_clip_224', 'eva02_large_patch14_clip_336', 'eva02_small_patch14_224', 'eva02_small_patch14_336', 'eva02_tiny_patch14_224', 'eva02_tiny_patch14_336', 'eva_giant_patch14_224', 'eva_giant_patch14_336', 'eva_giant_patch14_560', 'eva_giant_patch14_clip_224', 'eva_large_patch14_196', 'eva_large_patch14_336', 'fastvit_ma36', 'fastvit_mci0', 'fastvit_mci1', 'fastvit_mci2', 'fastvit_s12', 'fastvit_sa12', 'fastvit_sa24', 'fastvit_sa36', 'fastvit_t8', 'fastvit_t12', 'fbnetc_100', 'fbnetv3_b', 'fbnetv3_d', 'fbnetv3_g', 'flexivit_base', 'flexivit_large', 'flexivit_small', 'focalnet_base_lrf', 'focalnet_base_srf', 'focalnet_huge_fl3', 'focalnet_huge_fl4', 'focalnet_large_fl3', 'focalnet_large_fl4', 'focalnet_small_lrf', 'focalnet_small_srf', 'focalnet_tiny_lrf', 'focalnet_tiny_srf', 'focalnet_xlarge_fl3', 'focalnet_xlarge_fl4', 'gc_efficientnetv2_rw_t', 'gcresnet33ts', 'gcresnet50t', 'gcresnext26ts', 'gcresnext50ts', 'gcvit_base', 'gcvit_small', 'gcvit_tiny', 'gcvit_xtiny', 'gcvit_xxtiny', 'gernet_l', 'gernet_m', 'gernet_s', 'ghostnet_050', 'ghostnet_100', 'ghostnet_130', 'ghostnetv2_100', 'ghostnetv2_130', 'ghostnetv2_160', 'gmixer_12_224', 'gmixer_24_224', 'gmlp_b16_224', 'gmlp_s16_224', 'gmlp_ti16_224', 'halo2botnet50ts_256', 'halonet26t', 'halonet50ts', 'halonet_h1', 'haloregnetz_b', 'hardcorenas_a', 'hardcorenas_b', 'hardcorenas_c', 'hardcorenas_d', 'hardcorenas_e', 'hardcorenas_f', 'hgnet_base', 'hgnet_small', 'hgnet_tiny', 'hgnetv2_b0', 'hgnetv2_b1', 'hgnetv2_b2', 'hgnetv2_b3', 'hgnetv2_b4', 'hgnetv2_b5', 'hgnetv2_b6', 'hiera_base_224', 'hiera_base_abswin_256', 
'hiera_base_plus_224', 'hiera_huge_224', 'hiera_large_224', 'hiera_small_224', 'hiera_small_abswin_256', 'hiera_tiny_224', 'hieradet_small', 'hrnet_w18', 'hrnet_w18_small', 'hrnet_w18_small_v2', 'hrnet_w18_ssld', 'hrnet_w30', 'hrnet_w32', 'hrnet_w40', 'hrnet_w44', 'hrnet_w48', 'hrnet_w48_ssld', 'hrnet_w64', 'inception_next_base', 'inception_next_small', 'inception_next_tiny', 'inception_resnet_v2', 'inception_v3', 'inception_v4', 'lambda_resnet26rpt_256', 'lambda_resnet26t', 'lambda_resnet50ts', 'lamhalobotnet50ts_256', 'lcnet_035', 'lcnet_050', 'lcnet_075', 'lcnet_100', 'lcnet_150', 'legacy_senet154', 'legacy_seresnet18', 'legacy_seresnet34', 'legacy_seresnet50', 'legacy_seresnet101', 'legacy_seresnet152', 'legacy_seresnext26_32x4d', 'legacy_seresnext50_32x4d', 'legacy_seresnext101_32x4d', 'legacy_xception', 'levit_128', 'levit_128s', 'levit_192', 'levit_256', 'levit_256d', 'levit_384', 'levit_384_s8', 'levit_512', 'levit_512_s8', 'levit_512d', 'levit_conv_128', 'levit_conv_128s', 'levit_conv_192', 'levit_conv_256', 'levit_conv_256d', 'levit_conv_384', 'levit_conv_384_s8', 'levit_conv_512', 'levit_conv_512_s8', 'levit_conv_512d', 'maxvit_base_tf_224', 'maxvit_base_tf_384', 'maxvit_base_tf_512', 'maxvit_large_tf_224', 'maxvit_large_tf_384', 'maxvit_large_tf_512', 'maxvit_nano_rw_256', 'maxvit_pico_rw_256', 'maxvit_rmlp_base_rw_224', 'maxvit_rmlp_base_rw_384', 'maxvit_rmlp_nano_rw_256', 'maxvit_rmlp_pico_rw_256', 'maxvit_rmlp_small_rw_224', 'maxvit_rmlp_small_rw_256', 'maxvit_rmlp_tiny_rw_256', 'maxvit_small_tf_224', 'maxvit_small_tf_384', 'maxvit_small_tf_512', 'maxvit_tiny_pm_256', 'maxvit_tiny_rw_224', 'maxvit_tiny_rw_256', 'maxvit_tiny_tf_224', 'maxvit_tiny_tf_384', 'maxvit_tiny_tf_512', 'maxvit_xlarge_tf_224', 'maxvit_xlarge_tf_384', 'maxvit_xlarge_tf_512', 'maxxvit_rmlp_nano_rw_256', 'maxxvit_rmlp_small_rw_256', 'maxxvit_rmlp_tiny_rw_256', 'maxxvitv2_nano_rw_256', 'maxxvitv2_rmlp_base_rw_224', 'maxxvitv2_rmlp_base_rw_384', 'maxxvitv2_rmlp_large_rw_224', 'mixer_b16_224', 'mixer_b32_224', 'mixer_l16_224', 'mixer_l32_224', 'mixer_s16_224', 'mixer_s32_224', 'mixnet_l', 'mixnet_m', 'mixnet_s', 'mixnet_xl', 'mixnet_xxl', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'mnasnet_small', 'mobilenet_edgetpu_100', 'mobilenet_edgetpu_v2_l', 'mobilenet_edgetpu_v2_m', 'mobilenet_edgetpu_v2_s', 'mobilenet_edgetpu_v2_xs', 'mobilenetv1_100', 'mobilenetv1_100h', 'mobilenetv1_125', 'mobilenetv2_035', 'mobilenetv2_050', 'mobilenetv2_075', 'mobilenetv2_100', 'mobilenetv2_110d', 'mobilenetv2_120d', 'mobilenetv2_140', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_150d', 'mobilenetv3_rw', 'mobilenetv3_small_050', 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_blur_medium', 'mobilenetv4_conv_large', 'mobilenetv4_conv_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_large_075', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_medium_075', 'mobileone_s0', 'mobileone_s1', 'mobileone_s2', 'mobileone_s3', 'mobileone_s4', 'mobilevit_s', 'mobilevit_xs', 'mobilevit_xxs', 'mobilevitv2_050', 'mobilevitv2_075', 'mobilevitv2_100', 'mobilevitv2_125', 'mobilevitv2_150', 'mobilevitv2_175', 'mobilevitv2_200', 'mvitv2_base', 'mvitv2_base_cls', 'mvitv2_huge_cls', 'mvitv2_large', 'mvitv2_large_cls', 'mvitv2_small', 'mvitv2_small_cls', 'mvitv2_tiny', 'nasnetalarge', 'nest_base', 'nest_base_jx', 'nest_small', 'nest_small_jx', 'nest_tiny', 'nest_tiny_jx', 'nextvit_base', 
'nextvit_large', 'nextvit_small', 'nf_ecaresnet26', 'nf_ecaresnet50', 'nf_ecaresnet101', 'nf_regnet_b0', 'nf_regnet_b1', 'nf_regnet_b2', 'nf_regnet_b3', 'nf_regnet_b4', 'nf_regnet_b5', 'nf_resnet26', 'nf_resnet50', 'nf_resnet101', 'nf_seresnet26', 'nf_seresnet50', 'nf_seresnet101', 'nfnet_f0', 'nfnet_f1', 'nfnet_f2', 'nfnet_f3', 'nfnet_f4', 'nfnet_f5', 'nfnet_f6', 'nfnet_f7', 'nfnet_l0', 'pit_b_224', 'pit_b_distilled_224', 'pit_s_224', 'pit_s_distilled_224', 'pit_ti_224', 'pit_ti_distilled_224', 'pit_xs_224', 'pit_xs_distilled_224', 'pnasnet5large', 'poolformer_m36', 'poolformer_m48', 'poolformer_s12', 'poolformer_s24', 'poolformer_s36', 'poolformerv2_m36', 'poolformerv2_m48', 'poolformerv2_s12', 'poolformerv2_s24', 'poolformerv2_s36', 'pvt_v2_b0', 'pvt_v2_b1', 'pvt_v2_b2', 'pvt_v2_b2_li', 'pvt_v2_b3', 'pvt_v2_b4', 'pvt_v2_b5', 'rdnet_base', 'rdnet_large', 'rdnet_small', 'rdnet_tiny', 'regnetv_040', 'regnetv_064', 'regnetx_002', 'regnetx_004', 'regnetx_004_tv', 'regnetx_006', 'regnetx_008', 'regnetx_016', 'regnetx_032', 'regnetx_040', 'regnetx_064', 'regnetx_080', 'regnetx_120', 'regnetx_160', 'regnetx_320', 'regnety_002', 'regnety_004', 'regnety_006', 'regnety_008', 'regnety_008_tv', 'regnety_016', 'regnety_032', 'regnety_040', 'regnety_040_sgn', 'regnety_064', 'regnety_080', 'regnety_080_tv', 'regnety_120', 'regnety_160', 'regnety_320', 'regnety_640', 'regnety_1280', 'regnety_2560', 'regnetz_005', 'regnetz_040', 'regnetz_040_h', 'regnetz_b16', 'regnetz_b16_evos', 'regnetz_c16', 'regnetz_c16_evos', 'regnetz_d8', 'regnetz_d8_evos', 'regnetz_d32', 'regnetz_e8', 'repghostnet_050', 'repghostnet_058', 'repghostnet_080', 'repghostnet_100', 'repghostnet_111', 'repghostnet_130', 'repghostnet_150', 'repghostnet_200', 'repvgg_a0', 'repvgg_a1', 'repvgg_a2', 'repvgg_b0', 'repvgg_b1', 'repvgg_b1g4', 'repvgg_b2', 'repvgg_b2g4', 'repvgg_b3', 'repvgg_b3g4', 'repvgg_d2se', 'repvit_m0_9', 'repvit_m1', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2', 'repvit_m2_3', 'repvit_m3', 'res2net50_14w_8s', 'res2net50_26w_4s', 'res2net50_26w_6s', 'res2net50_26w_8s', 'res2net50_48w_2s', 'res2net50d', 'res2net101_26w_4s', 'res2net101d', 'res2next50', 'resmlp_12_224', 'resmlp_24_224', 'resmlp_36_224', 'resmlp_big_24_224', 'resnest14d', 'resnest26d', 'resnest50d', 'resnest50d_1s4x24d', 'resnest50d_4s2x40d', 'resnest101e', 'resnest200e', 'resnest269e', 'resnet10t', 'resnet14t', 'resnet18', 'resnet18d', 'resnet26', 'resnet26d', 'resnet26t', 'resnet32ts', 'resnet33ts', 'resnet34', 'resnet34d', 'resnet50', 'resnet50_clip', 'resnet50_clip_gap', 'resnet50_gn', 'resnet50_mlp', 'resnet50c', 'resnet50d', 'resnet50s', 'resnet50t', 'resnet50x4_clip', 'resnet50x4_clip_gap', 'resnet50x16_clip', 'resnet50x16_clip_gap', 'resnet50x64_clip', 'resnet50x64_clip_gap', 'resnet51q', 'resnet61q', 'resnet101', 'resnet101_clip', 'resnet101_clip_gap', 'resnet101c', 'resnet101d', 'resnet101s', 'resnet152', 'resnet152c', 'resnet152d', 'resnet152s', 'resnet200', 'resnet200d', 'resnetaa34d', 'resnetaa50', 'resnetaa50d', 'resnetaa101d', 'resnetblur18', 'resnetblur50', 'resnetblur50d', 'resnetblur101d', 'resnetrs50', 'resnetrs101', 'resnetrs152', 'resnetrs200', 'resnetrs270', 'resnetrs350', 'resnetrs420', 'resnetv2_50', 'resnetv2_50d', 'resnetv2_50d_evos', 'resnetv2_50d_frn', 'resnetv2_50d_gn', 'resnetv2_50t', 'resnetv2_50x1_bit', 'resnetv2_50x3_bit', 'resnetv2_101', 'resnetv2_101d', 'resnetv2_101x1_bit', 'resnetv2_101x3_bit', 'resnetv2_152', 'resnetv2_152d', 'resnetv2_152x2_bit', 'resnetv2_152x4_bit', 'resnext26ts', 'resnext50_32x4d', 
'resnext50d_32x4d', 'resnext101_32x4d', 'resnext101_32x8d', 'resnext101_32x16d', 'resnext101_32x32d', 'resnext101_64x4d', 'rexnet_100', 'rexnet_130', 'rexnet_150', 'rexnet_200', 'rexnet_300', 'rexnetr_100', 'rexnetr_130', 'rexnetr_150', 'rexnetr_200', 'rexnetr_300', 'sam2_hiera_base_plus', 'sam2_hiera_large', 'sam2_hiera_small', 'sam2_hiera_tiny', 'samvit_base_patch16', 'samvit_base_patch16_224', 'samvit_huge_patch16', 'samvit_large_patch16', 'sebotnet33ts_256', 'sedarknet21', 'sehalonet33ts', 'selecsls42', 'selecsls42b', 'selecsls60', 'selecsls60b', 'selecsls84', 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'senet154', 'sequencer2d_l', 'sequencer2d_m', 'sequencer2d_s', 'seresnet18', 'seresnet33ts', 'seresnet34', 'seresnet50', 'seresnet50t', 'seresnet101', 'seresnet152', 'seresnet152d', 'seresnet200d', 'seresnet269d', 'seresnetaa50d', 'seresnext26d_32x4d', 'seresnext26t_32x4d', 'seresnext26ts', 'seresnext50_32x4d', 'seresnext101_32x4d', 'seresnext101_32x8d', 'seresnext101_64x4d', 'seresnext101d_32x8d', 'seresnextaa101d_32x8d', 'seresnextaa201d_32x8d', 'skresnet18', 'skresnet34', 'skresnet50', 'skresnet50d', 'skresnext50_32x4d', 'spnasnet_100', 'swin_base_patch4_window7_224', 'swin_base_patch4_window12_384', 'swin_large_patch4_window7_224', 'swin_large_patch4_window12_384', 'swin_s3_base_224', 'swin_s3_small_224', 'swin_s3_tiny_224', 'swin_small_patch4_window7_224', 'swin_tiny_patch4_window7_224', 'swinv2_base_window8_256', 'swinv2_base_window12_192', 'swinv2_base_window12to16_192to256', 'swinv2_base_window12to24_192to384', 'swinv2_base_window16_256', 'swinv2_cr_base_224', 'swinv2_cr_base_384', 'swinv2_cr_base_ns_224', 'swinv2_cr_giant_224', 'swinv2_cr_giant_384', 'swinv2_cr_huge_224', 'swinv2_cr_huge_384', 'swinv2_cr_large_224', 'swinv2_cr_large_384', 'swinv2_cr_small_224', 'swinv2_cr_small_384', 'swinv2_cr_small_ns_224', 'swinv2_cr_small_ns_256', 'swinv2_cr_tiny_224', 'swinv2_cr_tiny_384', 'swinv2_cr_tiny_ns_224', 'swinv2_large_window12_192', 'swinv2_large_window12to16_192to256', 'swinv2_large_window12to24_192to384', 'swinv2_small_window8_256', 'swinv2_small_window16_256', 'swinv2_tiny_window8_256', 'swinv2_tiny_window16_256', 'test_byobnet', 'test_efficientnet', 'test_vit', 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', 'tf_efficientnet_el', 'tf_efficientnet_em', 'tf_efficientnet_es', 'tf_efficientnet_l2', 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', 'tf_efficientnet_lite4', 'tf_efficientnetv2_b0', 'tf_efficientnetv2_b1', 'tf_efficientnetv2_b2', 'tf_efficientnetv2_b3', 'tf_efficientnetv2_l', 'tf_efficientnetv2_m', 'tf_efficientnetv2_s', 'tf_efficientnetv2_xl', 'tf_mixnet_l', 'tf_mixnet_m', 'tf_mixnet_s', 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100', 'tiny_vit_5m_224', 'tiny_vit_11m_224', 'tiny_vit_21m_224', 'tiny_vit_21m_384', 'tiny_vit_21m_512', 'tinynet_a', 'tinynet_b', 'tinynet_c', 'tinynet_d', 'tinynet_e', 'tnt_b_patch16_224', 'tnt_s_patch16_224', 'tresnet_l', 'tresnet_m', 'tresnet_v2_l', 'tresnet_xl', 'twins_pcpvt_base', 'twins_pcpvt_large', 'twins_pcpvt_small', 'twins_svt_base', 'twins_svt_large', 'twins_svt_small', 'vgg11', 'vgg11_bn', 
'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'visformer_small', 'visformer_tiny', 'vit_base_mci_224', 'vit_base_patch8_224', 'vit_base_patch14_dinov2', 'vit_base_patch14_reg4_dinov2', 'vit_base_patch16_18x2_224', 'vit_base_patch16_224', 'vit_base_patch16_224_miil', 'vit_base_patch16_384', 'vit_base_patch16_clip_224', 'vit_base_patch16_clip_384', 'vit_base_patch16_clip_quickgelu_224', 'vit_base_patch16_gap_224', 'vit_base_patch16_plus_240', 'vit_base_patch16_reg4_gap_256', 'vit_base_patch16_rope_reg1_gap_256', 'vit_base_patch16_rpn_224', 'vit_base_patch16_siglip_224', 'vit_base_patch16_siglip_256', 'vit_base_patch16_siglip_384', 'vit_base_patch16_siglip_512', 'vit_base_patch16_siglip_gap_224', 'vit_base_patch16_siglip_gap_256', 'vit_base_patch16_siglip_gap_384', 'vit_base_patch16_siglip_gap_512', 'vit_base_patch16_xp_224', 'vit_base_patch32_224', 'vit_base_patch32_384', 'vit_base_patch32_clip_224', 'vit_base_patch32_clip_256', 'vit_base_patch32_clip_384', 'vit_base_patch32_clip_448', 'vit_base_patch32_clip_quickgelu_224', 'vit_base_patch32_plus_256', 'vit_base_r26_s32_224', 'vit_base_r50_s16_224', 'vit_base_r50_s16_384', 'vit_base_resnet26d_224', 'vit_base_resnet50d_224', 'vit_betwixt_patch16_gap_256', 'vit_betwixt_patch16_reg1_gap_256', 'vit_betwixt_patch16_reg4_gap_256', 'vit_betwixt_patch16_reg4_gap_384', 'vit_betwixt_patch16_rope_reg4_gap_256', 'vit_betwixt_patch32_clip_224', 'vit_giant_patch14_224', 'vit_giant_patch14_clip_224', 'vit_giant_patch14_dinov2', 'vit_giant_patch14_reg4_dinov2', 'vit_giant_patch16_gap_224', 'vit_gigantic_patch14_224', 'vit_gigantic_patch14_clip_224', 'vit_huge_patch14_224', 'vit_huge_patch14_clip_224', 'vit_huge_patch14_clip_336', 'vit_huge_patch14_clip_378', 'vit_huge_patch14_clip_quickgelu_224', 'vit_huge_patch14_clip_quickgelu_378', 'vit_huge_patch14_gap_224', 'vit_huge_patch14_xp_224', 'vit_huge_patch16_gap_448', 'vit_large_patch14_224', 'vit_large_patch14_clip_224', 'vit_large_patch14_clip_336', 'vit_large_patch14_clip_quickgelu_224', 'vit_large_patch14_clip_quickgelu_336', 'vit_large_patch14_dinov2', 'vit_large_patch14_reg4_dinov2', 'vit_large_patch14_xp_224', 'vit_large_patch16_224', 'vit_large_patch16_384', 'vit_large_patch16_siglip_256', 'vit_large_patch16_siglip_384', 'vit_large_patch16_siglip_gap_256', 'vit_large_patch16_siglip_gap_384', 'vit_large_patch32_224', 'vit_large_patch32_384', 'vit_large_r50_s32_224', 'vit_large_r50_s32_384', 'vit_little_patch16_reg1_gap_256', 'vit_little_patch16_reg4_gap_256', 'vit_medium_patch16_clip_224', 'vit_medium_patch16_gap_240', 'vit_medium_patch16_gap_256', 'vit_medium_patch16_gap_384', 'vit_medium_patch16_reg1_gap_256', 'vit_medium_patch16_reg4_gap_256', 'vit_medium_patch16_rope_reg1_gap_256', 'vit_medium_patch32_clip_224', 'vit_mediumd_patch16_reg4_gap_256', 'vit_mediumd_patch16_reg4_gap_384', 'vit_mediumd_patch16_rope_reg1_gap_256', 'vit_pwee_patch16_reg1_gap_256', 'vit_relpos_base_patch16_224', 'vit_relpos_base_patch16_cls_224', 'vit_relpos_base_patch16_clsgap_224', 'vit_relpos_base_patch16_plus_240', 'vit_relpos_base_patch16_rpn_224', 'vit_relpos_base_patch32_plus_rpn_256', 'vit_relpos_medium_patch16_224', 'vit_relpos_medium_patch16_cls_224', 'vit_relpos_medium_patch16_rpn_224', 'vit_relpos_small_patch16_224', 'vit_relpos_small_patch16_rpn_224', 'vit_small_patch8_224', 'vit_small_patch14_dinov2', 'vit_small_patch14_reg4_dinov2', 'vit_small_patch16_18x2_224', 'vit_small_patch16_36x1_224', 'vit_small_patch16_224', 'vit_small_patch16_384', 'vit_small_patch32_224', 'vit_small_patch32_384', 
'vit_small_r26_s32_224', 'vit_small_r26_s32_384', 'vit_small_resnet26d_224', 'vit_small_resnet50d_s16_224', 'vit_so150m_patch16_reg4_gap_256', 'vit_so150m_patch16_reg4_map_256', 'vit_so400m_patch14_siglip_224', 'vit_so400m_patch14_siglip_384', 'vit_so400m_patch14_siglip_gap_224', 'vit_so400m_patch14_siglip_gap_384', 'vit_so400m_patch14_siglip_gap_448', 'vit_so400m_patch14_siglip_gap_896', 'vit_srelpos_medium_patch16_224', 'vit_srelpos_small_patch16_224', 'vit_tiny_patch16_224', 'vit_tiny_patch16_384', 'vit_tiny_r_s16_p8_224', 'vit_tiny_r_s16_p8_384', 'vit_wee_patch16_reg1_gap_256', 'vit_xsmall_patch16_clip_224', 'vitamin_base_224', 'vitamin_large2_224', 'vitamin_large2_256', 'vitamin_large2_336', 'vitamin_large2_384', 'vitamin_large_224', 'vitamin_large_256', 'vitamin_large_336', 'vitamin_large_384', 'vitamin_small_224', 'vitamin_xlarge_256', 'vitamin_xlarge_336', 'vitamin_xlarge_384', 'volo_d1_224', 'volo_d1_384', 'volo_d2_224', 'volo_d2_384', 'volo_d3_224', 'volo_d3_448', 'volo_d4_224', 'volo_d4_448', 'volo_d5_224', 'volo_d5_448', 'volo_d5_512', 'vovnet39a', 'vovnet57a', 'wide_resnet50_2', 'wide_resnet101_2', 'xception41', 'xception41p', 'xception65', 'xception65p', 'xception71', 'xcit_large_24_p8_224', 'xcit_large_24_p8_384', 'xcit_large_24_p16_224', 'xcit_large_24_p16_384', 'xcit_medium_24_p8_224', 'xcit_medium_24_p8_384', 'xcit_medium_24_p16_224', 'xcit_medium_24_p16_384', 'xcit_nano_12_p8_224', 'xcit_nano_12_p8_384', 'xcit_nano_12_p16_224', 'xcit_nano_12_p16_384', 'xcit_small_12_p8_224', 'xcit_small_12_p8_384', 'xcit_small_12_p16_224', 'xcit_small_12_p16_384', 'xcit_small_24_p8_224', 'xcit_small_24_p8_384', 'xcit_small_24_p16_224', 'xcit_small_24_p16_384', 'xcit_tiny_12_p8_224', 'xcit_tiny_12_p8_384', 'xcit_tiny_12_p16_224', 'xcit_tiny_12_p16_384', 'xcit_tiny_24_p8_224', 'xcit_tiny_24_p8_384', 'xcit_tiny_24_p16_224', 'xcit_tiny_24_p16_384']\n"
|
850 |
+
]
|
851 |
+
}
|
852 |
+
],
|
853 |
+
"source": [
|
854 |
+
"import timm\n",
|
855 |
+
"print(timm.list_models())"
|
856 |
+
]
|
857 |
+
},
|
858 |
+
{
|
859 |
+
"cell_type": "markdown",
|
860 |
+
"metadata": {},
|
861 |
+
"source": [
|
862 |
+
"##### testing the litserve model"
|
863 |
+
]
|
864 |
+
},
|
865 |
+
{
|
866 |
+
"cell_type": "code",
|
867 |
+
"execution_count": 2,
|
868 |
+
"metadata": {},
|
869 |
+
"outputs": [],
|
870 |
+
"source": [
|
871 |
+
"import requests\n",
|
872 |
+
"from urllib.request import urlopen\n",
|
873 |
+
"import base64"
|
874 |
+
]
|
875 |
+
},
|
876 |
+
{
|
877 |
+
"cell_type": "code",
|
878 |
+
"execution_count": 33,
|
879 |
+
"metadata": {},
|
880 |
+
"outputs": [
|
881 |
+
{
|
882 |
+
"name": "stdout",
|
883 |
+
"output_type": "stream",
|
884 |
+
"text": [
|
885 |
+
"<class 'bytes'>\n"
|
886 |
+
]
|
887 |
+
}
|
888 |
+
],
|
889 |
+
"source": [
|
890 |
+
"url = \"https://media.istockphoto.com/id/541844008/photo/portland-grand-floral-parade-2016.jpg?s=2048x2048&w=is&k=20&c=ZuvR6oDv5WxwL5dhXKAbevysEXhXV47shJdpzkqen5Y=\"\n",
|
891 |
+
"img_data = urlopen(url).read()\n",
|
892 |
+
"print(type(img_data))"
|
893 |
+
]
|
894 |
+
},
|
895 |
+
{
|
896 |
+
"cell_type": "code",
|
897 |
+
"execution_count": 34,
|
898 |
+
"metadata": {},
|
899 |
+
"outputs": [
|
900 |
+
{
|
901 |
+
"name": "stdout",
|
902 |
+
"output_type": "stream",
|
903 |
+
"text": [
|
904 |
+
"<class 'str'>\n"
|
905 |
+
]
|
906 |
+
}
|
907 |
+
],
|
908 |
+
"source": [
|
909 |
+
"# Convert to base64 string\n",
|
910 |
+
"img_bytes = base64.b64encode(img_data).decode('utf-8')\n",
|
911 |
+
"print(type(img_bytes))"
|
912 |
+
]
|
913 |
+
},
|
914 |
+
{
|
915 |
+
"cell_type": "code",
|
916 |
+
"execution_count": 35,
|
917 |
+
"metadata": {},
|
918 |
+
"outputs": [],
|
919 |
+
"source": [
|
920 |
+
"response = requests.post(\n",
|
921 |
+
" \"http://localhost:8080/predict\", json={\"image\": img_bytes} # image is the key\n",
|
922 |
+
")"
|
923 |
+
]
|
924 |
+
},
|
925 |
+
{
|
926 |
+
"cell_type": "code",
|
927 |
+
"execution_count": 36,
|
928 |
+
"metadata": {},
|
929 |
+
"outputs": [
|
930 |
+
{
|
931 |
+
"name": "stdout",
|
932 |
+
"output_type": "stream",
|
933 |
+
"text": [
|
934 |
+
"\\nTop 5 Predictions:\n",
|
935 |
+
"mountain_bike, all-terrain_bike, off-roader: 82.13%\n",
|
936 |
+
"maillot: 5.09%\n",
|
937 |
+
"crash_helmet: 1.84%\n",
|
938 |
+
"bicycle-built-for-two, tandem_bicycle, tandem: 1.83%\n",
|
939 |
+
"alp: 0.69%\n"
|
940 |
+
]
|
941 |
+
}
|
942 |
+
],
|
943 |
+
"source": [
|
944 |
+
"if response.status_code == 200:\n",
|
945 |
+
" predictions = response.json()[\"predictions\"]\n",
|
946 |
+
" print(\"\\\\nTop 5 Predictions:\")\n",
|
947 |
+
" for pred in predictions:\n",
|
948 |
+
" print(f\"{pred['label']}: {pred['probability']:.2%}\")\n",
|
949 |
+
"else:\n",
|
950 |
+
" print(f\"Error: {response.status_code}\")\n",
|
951 |
+
" print(response.text)"
|
952 |
+
]
|
953 |
+
},
|
954 |
+
{
|
955 |
+
"cell_type": "code",
|
956 |
+
"execution_count": null,
|
957 |
+
"metadata": {},
|
958 |
+
"outputs": [],
|
959 |
+
"source": []
|
960 |
+
},
|
961 |
+
{
|
962 |
+
"cell_type": "code",
|
963 |
+
"execution_count": null,
|
964 |
+
"metadata": {},
|
965 |
+
"outputs": [],
|
966 |
+
"source": []
|
967 |
+
},
|
968 |
+
{
|
969 |
+
"cell_type": "code",
|
970 |
+
"execution_count": null,
|
971 |
+
"metadata": {},
|
972 |
+
"outputs": [],
|
973 |
+
"source": []
|
974 |
+
},
|
975 |
+
{
|
976 |
+
"cell_type": "code",
|
977 |
+
"execution_count": null,
|
978 |
+
"metadata": {},
|
979 |
+
"outputs": [],
|
980 |
+
"source": []
|
981 |
+
},
|
982 |
+
{
|
983 |
+
"cell_type": "markdown",
|
984 |
+
"metadata": {},
|
985 |
+
"source": [
|
986 |
+
"########################################## End of the script ##########################################"
|
987 |
+
]
|
988 |
+
}
|
989 |
+
],
|
990 |
+
"metadata": {
|
991 |
+
"kernelspec": {
|
992 |
+
"display_name": "emlo_env",
|
993 |
+
"language": "python",
|
994 |
+
"name": "python3"
|
995 |
+
},
|
996 |
+
"language_info": {
|
997 |
+
"codemirror_mode": {
|
998 |
+
"name": "ipython",
|
999 |
+
"version": 3
|
1000 |
+
},
|
1001 |
+
"file_extension": ".py",
|
1002 |
+
"mimetype": "text/x-python",
|
1003 |
+
"name": "python",
|
1004 |
+
"nbconvert_exporter": "python",
|
1005 |
+
"pygments_lexer": "ipython3",
|
1006 |
+
"version": "3.10.15"
|
1007 |
+
}
|
1008 |
+
},
|
1009 |
+
"nbformat": 4,
|
1010 |
+
"nbformat_minor": 2
|
1011 |
+
}
|
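The last few code cells above exercise a LitServe endpoint at http://localhost:8080/predict: the client sends a base64-encoded image under the "image" key and prints the returned top-5 predictions. The serving code itself is not part of this notebook diff, so the following is only a minimal sketch of what such a server could look like using litserve's LitAPI interface; the class name, the resnet18 backbone, and the index-as-label response are illustrative assumptions, not the repository's actual implementation.

import base64
import io

import litserve as ls
import timm
import torch
from PIL import Image


class ImageClassifierAPI(ls.LitAPI):
    def setup(self, device):
        # Illustrative backbone; the real server presumably loads the project's own checkpoint.
        self.device = device
        self.model = timm.create_model("resnet18", pretrained=True).to(device).eval()
        data_cfg = timm.data.resolve_data_config({}, model=self.model)
        self.transform = timm.data.create_transform(**data_cfg)

    def decode_request(self, request):
        # Mirrors the client cell above: {"image": <base64-encoded image bytes>}
        img_bytes = base64.b64decode(request["image"])
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return self.transform(image).unsqueeze(0).to(self.device)

    def predict(self, x):
        with torch.inference_mode():
            return torch.softmax(self.model(x), dim=1)[0]

    def encode_response(self, probs):
        top5 = torch.topk(probs, k=5)
        # The real endpoint returns human-readable ImageNet labels; plain class indices are used here.
        return {
            "predictions": [
                {"label": str(idx.item()), "probability": prob.item()}
                for prob, idx in zip(top5.values, top5.indices)
            ]
        }


if __name__ == "__main__":
    server = ls.LitServer(ImageClassifierAPI(), accelerator="auto")
    server.run(port=8080)

With a server of this shape running on port 8080, the request cell above (POST to /predict with the base64 string) should work unchanged.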
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
ADDED
@@ -0,0 +1,94 @@
|
1 |
+
[tool.poetry]
|
2 |
+
name = "pytorch_fastapi_project"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Consolidated PyTorch and FastAPI project for AWS deployment and GHA testing"
|
5 |
+
authors = ["soutrik71 <[email protected]>"]
|
6 |
+
license = "Apache-2.0"
|
7 |
+
readme = "README.md"
|
8 |
+
|
9 |
+
[tool.poetry.dependencies]
|
10 |
+
python = ">=3.10,<3.11"
|
11 |
+
black = "24.8.0"
|
12 |
+
coverage = ">=7.6.1"
|
13 |
+
hydra-colorlog = "1.2.0"
|
14 |
+
hydra-core = "1.3.2"
|
15 |
+
lightning = {version = "2.4.0", extras = ["extra"]}
|
16 |
+
loguru = "0.7.2"
|
17 |
+
pytest = "^8.3.3"
|
18 |
+
rich = "13.8.1"
|
19 |
+
rootutils = "1.0.7"
|
20 |
+
tensorboard = "2.17.1"
|
21 |
+
timm = "1.0.9"
|
22 |
+
pandas = "^2.2.3"
|
23 |
+
numpy = "^1.26.0"
|
24 |
+
ruff = "*"
|
25 |
+
torch = {version = "^2.4.1", source = "pytorch_cuda"}
|
26 |
+
torchvision = {version = "^0.19.1", source = "pytorch_cuda"}
|
27 |
+
torchaudio = {version = "^2.4.1", source = "pytorch_cuda"}
|
28 |
+
seaborn = "^0.13.2"
|
29 |
+
pydantic = "^2.9.2"
|
30 |
+
kaggle = "^1.6.17"
|
31 |
+
pytest-cov = "^5.0.0"
|
32 |
+
pytest-mock = "^3.14.0"
|
33 |
+
flake8 = "^7.1.1"
|
34 |
+
dvc-gdrive = "^3.0.1"
|
35 |
+
dvc-azure = "^3.1.0"
|
36 |
+
transformers = "^4.45.2"
|
37 |
+
fastapi = "^0.115.4"
|
38 |
+
pydantic-settings = "^2.6.1"
|
39 |
+
uvicorn = "^0.32.0"
|
40 |
+
tenacity = "^9.0.0"
|
41 |
+
gunicorn = "^23.0.0"
|
42 |
+
aim = "^3.25.0"
|
43 |
+
mlflow = "^2.17.1"
|
44 |
+
hydra-optuna-sweeper = "^1.2.0"
|
45 |
+
dvc = "^3.56.0"
|
46 |
+
platformdirs = "3.10"
|
47 |
+
fastapi-utils = "^0.7.0"
|
48 |
+
httpx = "^0.27.2"
|
49 |
+
typing-inspect = "^0.9.0"
|
50 |
+
requests = "^2.32.3"
|
51 |
+
fastapi-restful = {extras = ["all"], version = "^0.6.0"}
|
52 |
+
aioredis = "^2.0.1"
|
53 |
+
psycopg2-binary = "^2.9.10"
|
54 |
+
asyncpg = "^0.30.0"
|
55 |
+
confluent-kafka = "^2.6.0"
|
56 |
+
aiokafka = "^0.12.0"
|
57 |
+
azure-servicebus = "^7.12.3"
|
58 |
+
aiohttp = "^3.10.10"
|
59 |
+
aiofiles = "*"
|
60 |
+
aiologger = "^0.7.0"
|
61 |
+
pyyaml = "^6.0.2"
|
62 |
+
sqlalchemy-utils = "^0.41.2"
|
63 |
+
sqlalchemy = "^2.0.36"
|
64 |
+
alembic = "^1.13.3"
|
65 |
+
fastapi-limiter = "^0.1.6"
|
66 |
+
redis = "5.0.8"
|
67 |
+
redisearch = "2.0.0"
|
68 |
+
python-multipart = "*"
|
69 |
+
python-dotenv = "^1.0.1"
|
70 |
+
celery = "^5.4.0"
|
71 |
+
fastapi-cache2 = "^0.2.2"
|
72 |
+
aiocache = "^0.12.3"
|
73 |
+
dvc-s3 = "^3.2.0"
|
74 |
+
litserve = "^0.2.4"
|
75 |
+
gpustat = "^1.1.1"
|
76 |
+
nvitop = "^1.3.2"
|
77 |
+
gradio = "5.7.1"
|
78 |
+
gradio-client = "^1.5.0"
|
79 |
+
accelerate = "^1.1.1"
|
80 |
+
cryptography = "^44.0.0"
|
81 |
+
boto3 = "*"
|
82 |
+
pyopenssl = "^24.3.0"
|
83 |
+
|
84 |
+
[tool.poetry.dev-dependencies]
|
85 |
+
pytest-asyncio = "^0.20.3"
|
86 |
+
|
87 |
+
[[tool.poetry.source]]
|
88 |
+
name = "pytorch_cuda"
|
89 |
+
url = "https://download.pytorch.org/whl/cu124"
|
90 |
+
priority = "explicit"
|
91 |
+
|
92 |
+
[build-system]
|
93 |
+
requires = ["poetry-core"]
|
94 |
+
build-backend = "poetry.core.masonry.api"
|
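Note that torch, torchvision and torchaudio are pinned to the explicit "pytorch_cuda" source above (the cu124 wheel index), so they are resolved from there rather than from PyPI. A quick sanity check after poetry install confirms which build actually landed in the environment; this is only a sketch, and the exact version strings depend on the lock file and the host:

# Verify that the CUDA build of torch was resolved from the custom wheel index.
import torch

print(torch.__version__)           # e.g. "2.4.1+cu124" when the cu124 wheel is installed
print(torch.version.cuda)          # CUDA toolkit version the wheel was built against, or None for CPU builds
print(torch.cuda.is_available())   # False is still expected on CPU-only hosts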
requirements.txt
ADDED
@@ -0,0 +1,28 @@
|
1 |
+
torch==2.4.1
|
2 |
+
torchvision==0.19.1
|
3 |
+
hydra-colorlog==1.2.0
|
4 |
+
hydra-core==1.3.2
|
5 |
+
lightning[extra]==2.4.0
|
6 |
+
loguru==0.7.2
|
7 |
+
rich==13.8.1
|
8 |
+
rootutils==1.0.7
|
9 |
+
tensorboard==2.17.1
|
10 |
+
timm==1.0.9
|
11 |
+
pandas>=2.2.3
|
12 |
+
numpy>=1.26.0
|
13 |
+
transformers>=4.45.2
|
14 |
+
aim>=3.25.0
|
15 |
+
mlflow>=2.17.1
|
16 |
+
hydra-optuna-sweeper>=1.2.0
|
17 |
+
aiologger>=0.7.0
|
18 |
+
pyyaml>=6.0.2
|
19 |
+
dvc-s3>=3.2.0
|
20 |
+
litserve>=0.2.4
|
21 |
+
gpustat>=1.1.1
|
22 |
+
nvitop>=1.3.2
|
23 |
+
gradio==5.7.1
|
24 |
+
gradio-client>=1.5.0
|
25 |
+
accelerate>=1.1.1
|
26 |
+
cryptography>=44.0.0
|
27 |
+
boto3
|
28 |
+
pyopenssl>=24.3.0
|