File size: 3,014 Bytes
1c49f74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91820d0
 
 
 
1c49f74
58232f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import logging
import os
from timeit import default_timer as timer

import gradio as gr
import lightning as pl
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from config import CONFIG
from data import LitMNISTDataModule
from model import LitMNISTModel

torch.set_float32_matmul_precision('medium')
# Mixed precision is NOT enabled globally: ``autocast`` only takes effect inside
# a ``with`` block, so constructing it at module level would be a no-op.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Tensor factory calls without an explicit ``device=`` now allocate on ``device``.
torch.set_default_device(device=device)


pl.seed_everything(123, workers=True)

# Preprocessing shared by the test DataModule (below) and the Gradio callback.
# The Normalize constants are presumably the usual MNIST mean/std — confirm
# they match what the checkpoint was trained with.
TEST_TRANSFORMS = transforms.Compose([
    # transforms.PILToTensor(),
    # transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

## MNISTDataModule
# Only consumed by the commented-out ``trainer.test`` call below; the Gradio
# callback does not use it.
dm = LitMNISTDataModule(
    data_dir=CONFIG['data'].get('dir_path', '.'),
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    test_transform=TEST_TRANSFORMS,
    train_transform=None,
)
dm.prepare_data()
dm.setup('test')

## MNISTModel
chkpoint_path = os.path.join(os.path.dirname(__file__), 'logs', 'chkpoints', 'epoch=13.ckpt')
# ``map_location`` remaps the stored tensors onto the device chosen above, so a
# checkpoint saved on GPU can still be restored on a CPU-only host.
model = LitMNISTModel.load_from_checkpoint(chkpoint_path, map_location=device)
# Serve in evaluation mode (relevant for layers such as dropout / batch-norm,
# if the model contains any).
model.eval()


# Kept for ad-hoc evaluation via the commented-out ``trainer.test`` below.
trainer = pl.Trainer(
    fast_dev_run=True,
    precision=32,
    enable_model_summary=False,
    enable_progress_bar=False,
)

# trainer.test(model,datamodule=dm)


# Snippet used to export test images into ./numbers
# (requires ``import matplotlib.pyplot as plt``):
# for X,y in dm.test_dataloader():
#     for i in range(X.shape[0]):
#         plt.imsave(
#             fname=os.path.join('numbers',f'img_{i}.png'),
#             arr=np.clip(
#                 torch.stack(
#                     [X[i,...],X[i,...],X[i,...]],
#                     dim=1
#                 ).squeeze(0).permute(1,2,0).detach().cpu().contiguous().numpy(),0,1))
#     break


def predict_fn(img: Image.Image):
    """Classify one digit image for the Gradio interface.

    Args:
        img: Image delivered by the ``gr.Image(type='pil')`` input.

    Returns:
        A ``(label_dict, seconds)`` tuple. ``label_dict`` maps
        ``"Title: <prediction>"`` to the model's ``prob`` output for the first
        batch item. ``seconds`` is the wall-clock time of the call, rounded to
        5 decimals. On any failure the error is logged, a warning toast is
        shown to the user, and ``({"Title ☠️": 0.0}, 0.0)`` is returned.
    """
    start_time = timer()
    try:
        pixels = TEST_TRANSFORMS(np.array(img))
        # Collapse the channel axis, then add batch and channel dims -> (1, 1, H, W).
        batch = pixels.mean(dim=0).unsqueeze(0).unsqueeze(0).to(model.device)
        # Serving needs no autograd bookkeeping.
        with torch.inference_mode():
            y_preds = model.predict_step(batch)
        res = {f"Title: {y_preds['predict'][0]}": y_preds['prob'][0]}
        pred_time = round(timer() - start_time, 5)
        return res, pred_time
    except Exception:
        # UI-callback boundary: keep the traceback for the operator instead of
        # silently discarding it.
        logging.getLogger(__name__).exception("predict_fn failed")
        # ``gr.Error`` only displays when *raised*; ``gr.Warning`` shows a toast
        # while still letting the fallback outputs below be returned.
        gr.Warning("An error occured 💥!", duration=5)
        return ({"Title ☠️": 0.0}, 0.0)




# Example images bundled in ./numbers next to this script. Building the example
# paths from the same directory that is listed makes them independent of the
# current working directory; sorting gives a stable order in the UI.
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'numbers')
EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')
EXAMPLES = [
    [os.path.join(EXAMPLES_DIR, name)]
    for name in sorted(os.listdir(EXAMPLES_DIR))
    # skip non-image entries such as .gitkeep / .DS_Store
    if name.lower().endswith(EXAMPLE_EXTENSIONS)
]

demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type='pil'),
    # One output per element of predict_fn's (label_dict, seconds) tuple.
    outputs=[
        gr.Label(num_top_classes=1, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=EXAMPLES,
    title="The Unsolved MNIST 🔢",
    description="CNN-based Architecture for Fast and Accurate MNIST 🔢 Solution with Reproducible Logs",
    article="Created by muthukamalan.m ❤️",
)
demo.launch(share=False, debug=False)