Spaces:
Sleeping
Sleeping
File size: 3,014 Bytes
1c49f74 91820d0 1c49f74 58232f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import os
import torch
import numpy as np
import lightning as pl
import gradio as gr
from PIL import Image
from torchvision import transforms
from data import LitMNISTDataModule
from config import CONFIG
from model import LitMNISTModel
from timeit import default_timer as timer
# Trade some float32 matmul precision for speed ('medium' precision mode).
torch.set_float32_matmul_precision('medium')
# NOTE: a bare `torch.cuda.amp.autocast(enabled=True)` call used to sit here.
# Constructing an autocast object outside a `with` statement never enters it,
# so it had no effect (and newer torch versions warn that the
# `torch.cuda.amp.autocast` spelling is deprecated). If mixed precision is
# wanted, wrap the inference call in `with torch.autocast(...)` instead.
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Make `device` the default for tensor factory calls that don't name a device.
torch.set_default_device(device=device)
# Fix the global random seed via Lightning for reproducible runs.
pl.seed_everything(123, workers=True)
# Per-channel normalisation statistics commonly used for MNIST.
_MNIST_MEAN = (0.1307,)
_MNIST_STD = (0.3081,)

# Inference-time preprocessing: ToTensor turns an HxW(xC) image/array into a
# float CxHxW tensor, which is then standardised with the statistics above.
# No resizing is applied here; inputs keep their original spatial size.
TEST_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
    ]
)
## MNISTDataModule
# Data module configured with the inference transform only; no training
# transform is supplied. Its test stage is prepared eagerly at import time.
# NOTE(review): in this file `dm` is only referenced from commented-out
# evaluation code, yet prepare_data()/setup('test') still run on every start
# (and may download data) — confirm it is needed on the startup path.
_data_dir = CONFIG['data'].get('dir_path', '.')
dm = LitMNISTDataModule(
    data_dir=_data_dir,
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    test_transform=TEST_TRANSFORMS,
    train_transform=None,
)
dm.prepare_data()
dm.setup('test')
## MNISTModel
# Restore the trained model from the checkpoint stored next to this script.
chkpoint_path = os.path.join(os.path.dirname(__file__), 'logs', 'chkpoints', 'epoch=13.ckpt')
model = LitMNISTModel.load_from_checkpoint(chkpoint_path)
# The app calls model.predict_step directly (no Trainer.predict), so nothing
# else switches the module to inference behaviour. eval() turns off dropout and
# makes batch-norm use running statistics; it is a no-op if already in eval mode.
model.eval()
# Trainer kept only for ad-hoc evaluation runs; the Gradio app does not use it.
trainer = pl.Trainer(
    fast_dev_run=True,
    precision=32,
    enable_model_summary=False,
    enable_progress_bar=False,
)
# trainer.test(model,datamodule=dm)
# for X,y in dm.test_dataloader():
# for i in range(X.shape[0]):
# plt.imsave(
# fname=os.path.join('numbers',f'img_{i}.png'),
# arr=np.clip(
# torch.stack(
# [X[i,...],X[i,...],X[i,...]],
# dim=1
# ).squeeze(0).permute(1,2,0).detach().cpu().contiguous().numpy(),0,1))
# break
def predict_fn(img: Image.Image):
    """Classify one uploaded digit image for the Gradio interface.

    Args:
        img: Image delivered by ``gr.Image(type='pil')``; ``None`` when the
            user submits without providing an image.

    Returns:
        A ``(label_dict, seconds)`` pair. ``label_dict`` maps
        ``"Title: <predicted class>"`` to its probability (shown by
        ``gr.Label``); ``seconds`` is the wall-clock prediction time rounded
        to 5 decimal places (shown by ``gr.Number``).

    Raises:
        gr.Error: when no image is given or preprocessing/inference fails.
            Gradio displays the message to the user instead of updating the
            outputs.
    """
    if img is None:
        raise gr.Error("Please upload an image first.", duration=5)
    start_time = timer()
    try:
        x = TEST_TRANSFORMS(np.array(img))
        # Average the colour channels into a single grey channel, then add
        # batch and channel dims -> (1, 1, H, W). No resizing happens here;
        # NOTE(review): assumes the model accepts the uploaded resolution
        # (the bundled examples are presumably 28x28) — confirm.
        x = x.mean(dim=0).unsqueeze(0).unsqueeze(0).to(model.device)
        # Pure inference: don't build an autograd graph.
        with torch.no_grad():
            y_preds = model.predict_step(x)
        # float(): gr.Label expects a plain number; 'prob' entries may be
        # tensors/NumPy scalars (presumably one scalar per sample — verify).
        res = {f"Title: {y_preds['predict'][0]}": float(y_preds['prob'][0])}
    except Exception as e:
        # gr.Error only reaches the UI when it is *raised*; merely
        # constructing it (as before) silently swallowed every failure.
        raise gr.Error("An error occured 💥!", duration=5) from e
    pred_time = round(timer() - start_time, 5)
    return res, pred_time
# Example images live in the `numbers/` folder next to this script. Build
# absolute paths from that same folder (the old code listed the script's
# folder but returned CWD-relative 'numbers/...' paths, which broke when the
# app was started from another directory), sort for a stable order, and keep
# only image files so stray entries (e.g. .gitkeep) don't break example loading.
_EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), 'numbers')
_IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')

demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type='pil'),
    outputs=[
        # predict_fn returns a single {label: probability} entry for this.
        gr.Label(num_top_classes=1, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=[
        [os.path.join(_EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(_EXAMPLES_DIR))
        if name.lower().endswith(_IMAGE_EXTS)
    ],
    title="The Unsolved MNIST 🔢",
    description="CNN-based Architecture for Fast and Accurate MNIST 🔢 Solution with Reproducible Logs",
    article="Created by muthukamalan.m ❤️",
)
demo.launch(share=False, debug=False)
|