Spaces:
Sleeping
Sleeping
Commit
·
af3a445
1
Parent(s):
8da8c04
src file added
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- config/__init__.py +6 -0
- config/config.toml +24 -0
- data/__init__.py +1 -0
- data/data.py +51 -0
- hand_made_numbers/number_eight.png +0 -0
- hand_made_numbers/number_five.png +0 -0
- hand_made_numbers/number_four.png +0 -0
- hand_made_numbers/number_nine.png +0 -0
- hand_made_numbers/number_one.png +0 -0
- hand_made_numbers/number_seven.png +0 -0
- hand_made_numbers/number_six.png +0 -0
- hand_made_numbers/number_three.png +0 -0
- hand_made_numbers/number_two.png +0 -0
- hand_made_numbers/number_zero.png +0 -0
- numbers/img_0.png +0 -0
- numbers/img_1.png +0 -0
- numbers/img_10.png +0 -0
- numbers/img_100.png +0 -0
- numbers/img_101.png +0 -0
- numbers/img_102.png +0 -0
- numbers/img_103.png +0 -0
- numbers/img_104.png +0 -0
- numbers/img_105.png +0 -0
- numbers/img_106.png +0 -0
- numbers/img_107.png +0 -0
- numbers/img_108.png +0 -0
- numbers/img_109.png +0 -0
- numbers/img_11.png +0 -0
- numbers/img_110.png +0 -0
- numbers/img_111.png +0 -0
- numbers/img_112.png +0 -0
- numbers/img_113.png +0 -0
- numbers/img_114.png +0 -0
- numbers/img_115.png +0 -0
- numbers/img_116.png +0 -0
- numbers/img_117.png +0 -0
- numbers/img_118.png +0 -0
- numbers/img_119.png +0 -0
- numbers/img_12.png +0 -0
- numbers/img_120.png +0 -0
- numbers/img_121.png +0 -0
- numbers/img_122.png +0 -0
- numbers/img_123.png +0 -0
- numbers/img_124.png +0 -0
- numbers/img_125.png +0 -0
- numbers/img_126.png +0 -0
- numbers/img_127.png +0 -0
- numbers/img_13.png +0 -0
- numbers/img_14.png +0 -0
- numbers/img_15.png +0 -0
config/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import toml
|
2 |
+
import os
|
3 |
+
|
4 |
+
config_file:str = os.path.join(os.path.dirname(__file__),'config.toml')
|
5 |
+
|
6 |
+
CONFIG = toml.load(config_file)
|
config/config.toml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
batch_size = 128
|
2 |
+
shuffle = true
|
3 |
+
num_workers= 0
|
4 |
+
pin_memory = false
|
5 |
+
|
6 |
+
[optimizer]
|
7 |
+
lr = 0.1
|
8 |
+
momentum = 0.9
|
9 |
+
|
10 |
+
[scheduler]
|
11 |
+
step_size = 15
|
12 |
+
gamma = 0.1
|
13 |
+
|
14 |
+
|
15 |
+
[training]
|
16 |
+
num_epochs = 15
|
17 |
+
|
18 |
+
|
19 |
+
[data]
|
20 |
+
dir_path = 'C:\Users\muthu\GitHub\DATA 📁\'
|
21 |
+
|
22 |
+
[model]
|
23 |
+
dropout_rate = 0.01
|
24 |
+
bias = false
|
data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .data import LitMNISTDataModule
|
data/data.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable,List,Any
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from lightning import LightningDataModule
|
5 |
+
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS,EVAL_DATALOADERS
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader,random_split
|
8 |
+
from torchvision import transforms
|
9 |
+
from torchvision.datasets import MNIST
|
10 |
+
|
11 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
12 |
+
|
13 |
+
class LitMNISTDataModule(LightningDataModule):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
data_dir:Path = Path('.'),
|
17 |
+
batch_size:int = 32,
|
18 |
+
num_workers:int = 0,
|
19 |
+
test_transform:Callable = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(.1307,),std=(.3081,))]),
|
20 |
+
train_transform:Callable = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(.1307,),std=(.3081,))])
|
21 |
+
) -> None:
|
22 |
+
super().__init__()
|
23 |
+
self.data_dir:Path = data_dir
|
24 |
+
self.batch_size:int = batch_size
|
25 |
+
self.num_workers:int = num_workers
|
26 |
+
self.test_transform:Callable = test_transform
|
27 |
+
self.train_transform:Callable = train_transform
|
28 |
+
self.save_hyperparameters()
|
29 |
+
|
30 |
+
def prepare_data(self) -> None:
|
31 |
+
MNIST(self.data_dir,train=True,download=True)
|
32 |
+
MNIST(self.data_dir,train=False,download=True)
|
33 |
+
|
34 |
+
def setup(self, stage: str=None) -> None:
|
35 |
+
if stage=="fit" or stage is None:
|
36 |
+
_mnist_full = MNIST(self.data_dir,train=True,transform=self.train_transform)
|
37 |
+
self.mnist_train, self.mnist_val = random_split(_mnist_full,[.9,.1],generator=torch.Generator(device))
|
38 |
+
|
39 |
+
if stage=='test' or stage is None:
|
40 |
+
self.mnist_test = MNIST(self.data_dir,train=False, transform=self.test_transform)
|
41 |
+
|
42 |
+
|
43 |
+
def train_dataloader(self) -> TRAIN_DATALOADERS:
|
44 |
+
return DataLoader(self.mnist_train,batch_size=self.batch_size,num_workers=self.num_workers,collate_fn=None,shuffle=True,generator= torch.Generator(device) )
|
45 |
+
|
46 |
+
def val_dataloader(self) -> EVAL_DATALOADERS:
|
47 |
+
return DataLoader(self.mnist_val,batch_size=self.batch_size,num_workers=self.num_workers,collate_fn=None,shuffle=False,generator= torch.Generator(device))
|
48 |
+
|
49 |
+
def test_dataloader(self) -> EVAL_DATALOADERS:
|
50 |
+
return DataLoader(self.mnist_test,batch_size=self.batch_size,num_workers=self.num_workers,collate_fn=None,shuffle=False,generator= torch.Generator(device))
|
51 |
+
|
hand_made_numbers/number_eight.png
ADDED
![]() |
hand_made_numbers/number_five.png
ADDED
![]() |
hand_made_numbers/number_four.png
ADDED
![]() |
hand_made_numbers/number_nine.png
ADDED
![]() |
hand_made_numbers/number_one.png
ADDED
![]() |
hand_made_numbers/number_seven.png
ADDED
![]() |
hand_made_numbers/number_six.png
ADDED
![]() |
hand_made_numbers/number_three.png
ADDED
![]() |
hand_made_numbers/number_two.png
ADDED
![]() |
hand_made_numbers/number_zero.png
ADDED
![]() |
numbers/img_0.png
ADDED
![]() |
numbers/img_1.png
ADDED
![]() |
numbers/img_10.png
ADDED
![]() |
numbers/img_100.png
ADDED
![]() |
numbers/img_101.png
ADDED
![]() |
numbers/img_102.png
ADDED
![]() |
numbers/img_103.png
ADDED
![]() |
numbers/img_104.png
ADDED
![]() |
numbers/img_105.png
ADDED
![]() |
numbers/img_106.png
ADDED
![]() |
numbers/img_107.png
ADDED
![]() |
numbers/img_108.png
ADDED
![]() |
numbers/img_109.png
ADDED
![]() |
numbers/img_11.png
ADDED
![]() |
numbers/img_110.png
ADDED
![]() |
numbers/img_111.png
ADDED
![]() |
numbers/img_112.png
ADDED
![]() |
numbers/img_113.png
ADDED
![]() |
numbers/img_114.png
ADDED
![]() |
numbers/img_115.png
ADDED
![]() |
numbers/img_116.png
ADDED
![]() |
numbers/img_117.png
ADDED
![]() |
numbers/img_118.png
ADDED
![]() |
numbers/img_119.png
ADDED
![]() |
numbers/img_12.png
ADDED
![]() |
numbers/img_120.png
ADDED
![]() |
numbers/img_121.png
ADDED
![]() |
numbers/img_122.png
ADDED
![]() |
numbers/img_123.png
ADDED
![]() |
numbers/img_124.png
ADDED
![]() |
numbers/img_125.png
ADDED
![]() |
numbers/img_126.png
ADDED
![]() |
numbers/img_127.png
ADDED
![]() |
numbers/img_13.png
ADDED
![]() |
numbers/img_14.png
ADDED
![]() |
numbers/img_15.png
ADDED
![]() |