Muthukamalan commited on
Commit
1c49f74
·
1 Parent(s): d3a8b6d

added inference model

Browse files
Files changed (1) hide show
  1. app.py +92 -5
app.py CHANGED
@@ -1,7 +1,94 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import lightning as pl
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from data import LitMNISTDataModule
9
+ from config import CONFIG
10
+ from model import LitMNISTModel
11
+ from timeit import default_timer as timer
12
+
13
+ torch.set_float32_matmul_precision('medium')
14
+ torch.cuda.amp.autocast(enabled=True)
15
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
16
+ torch.set_default_device( device= device )
17
+
18
+
19
+ pl.seed_everything(123, workers=True)
20
+
21
+ TEST_TRANSFORMS = transforms.Compose([
22
+ # transforms.PILToTensor(),
23
+ # transforms.Resize((28, 28)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize((0.1307,), (0.3081,))
26
+ ])
27
+
28
+ ## MNISTDataModule
29
+ dm = LitMNISTDataModule(
30
+ data_dir=CONFIG['data'].get('dir_path','.'),
31
+ batch_size= CONFIG.get('batch_size'),
32
+ num_workers=CONFIG.get('num_workers'),
33
+ test_transform=TEST_TRANSFORMS,
34
+ train_transform=None
35
+ )
36
+ dm.prepare_data()
37
+ dm.setup('test')
38
+
39
+ ## MNISTModel
40
+ # model = LitMNISTModel()
41
+ chkpoint_path = os.path.join( os.path.dirname(__file__),'logs','chkpoints','epoch=13.ckpt' )
42
+ model = LitMNISTModel.load_from_checkpoint(chkpoint_path)
43
+
44
+
45
+ trainer = pl.Trainer(
46
+ fast_dev_run=True,
47
+ precision=32,
48
+ enable_model_summary=False,
49
+ enable_progress_bar=False,
50
+ )
51
+
52
+ # trainer.test(model,datamodule=dm)
53
+
54
+
55
+ # for X,y in dm.test_dataloader():
56
+ # for i in range(X.shape[0]):
57
+ # plt.imsave(
58
+ # fname=os.path.join('numbers',f'img_{i}.png'),
59
+ # arr=np.clip(
60
+ # torch.stack(
61
+ # [X[i,...],X[i,...],X[i,...]],
62
+ # dim=1
63
+ # ).squeeze(0).permute(1,2,0).detach().cpu().contiguous().numpy(),0,1))
64
+ # break
65
+
66
+
67
+ def predict_fn(img:Image):
68
+ start_time = timer()
69
+ try:
70
+ img = np.array(img)
71
+ img = TEST_TRANSFORMS(img)
72
+ img = img.mean(dim=0).unsqueeze(0).unsqueeze(0).to(model.device)
73
+ y_preds = model.predict_step( img)
74
+ res = {f"Title: {y_preds['predict'][0]}": y_preds['prob'][0]}
75
+ pred_time = round(timer() - start_time, 5)
76
+ return(res, pred_time)
77
+ except Exception as e:
78
+ gr.Error("An error occured 💥!", duration=5)
79
+ return ({ f"Title ☠️": 0.0},0.0)
80
+
81
+
82
+
83
+
84
+ gr.Interface(
85
+ fn=predict_fn,
86
+ inputs=gr.Image(type='pil'),
87
+ outputs=[
88
+ gr.Label(num_top_classes=1, label="Predictions"), # what are the outputs?
89
+ gr.Number(label="Prediction time (s)")
90
+ ],
91
+ examples=[ ['numbers/'+i] for i in os.listdir(os.path.join( os.path.dirname(__file__) ,'numbers'))]
92
+ ).launch(share=False,debug=False)
93
 
 
 
94