"""Gradio demo serving predictions from a LitMNISTModel checkpoint."""
import logging
import os
from timeit import default_timer as timer

import gradio as gr
import lightning as pl
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from config import CONFIG
from data import LitMNISTDataModule
from model import LitMNISTModel

logger = logging.getLogger(__name__)

torch.set_float32_matmul_precision('medium')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_device(device=device)
pl.seed_everything(123, workers=True)

# Spatial size of MNIST digits; uploads of any other size are resized in predict_fn.
MNIST_IMAGE_SIZE = (28, 28)

# Single-channel normalization applied both to the test split and to user uploads.
TEST_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

## MNISTDataModule
dm = LitMNISTDataModule(
    data_dir=CONFIG['data'].get('dir_path', '.'),
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    test_transform=TEST_TRANSFORMS,
    train_transform=None,
)
dm.prepare_data()
dm.setup('test')

## MNISTModel
chkpoint_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'logs', 'chkpoints', 'epoch=13.ckpt'
)
model = LitMNISTModel.load_from_checkpoint(chkpoint_path)
# Serving only: switch any dropout / batch-norm layers to evaluation behaviour.
model.eval()

# Not used by the demo itself; kept for offline evaluation,
# e.g. trainer.test(model, datamodule=dm).
trainer = pl.Trainer(
    fast_dev_run=True,
    precision=32,
    enable_model_summary=False,
    enable_progress_bar=False,
)


def _to_python(value):
    """Unwrap single-element tensors and NumPy scalars into plain Python values.

    Anything else is returned unchanged.
    """
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        return value.item()
    if isinstance(value, np.generic):
        return value.item()
    return value


def _preprocess(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized (1, 1, 28, 28) batch on the model's device."""
    # Collapse to one luminance channel. For grayscale RGB (R == G == B) this
    # equals the channel mean, and it keeps an RGBA alpha channel out of the pixels.
    img = img.convert('L')
    if img.size != MNIST_IMAGE_SIZE:
        img = img.resize(MNIST_IMAGE_SIZE, Image.BILINEAR)
    tensor = TEST_TRANSFORMS(np.array(img))  # (1, 28, 28)
    return tensor.unsqueeze(0).to(model.device)  # add batch dim -> (1, 1, 28, 28)


def predict_fn(img: Image.Image):
    """Classify one uploaded image with the loaded model.

    Args:
        img: PIL image supplied by the Gradio ``Image`` input, or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(label_dict, seconds)`` tuple. ``label_dict`` maps
        ``"Title: <prediction>"`` to its probability for ``gr.Label``, and
        ``seconds`` is the wall-clock prediction time rounded to 5 decimals.
        On missing input or failure, a warning toast is shown and the fallback
        ``({"Title ☠️": 0.0}, 0.0)`` is returned.
    """
    start_time = timer()
    if img is None:
        gr.Warning("Please upload an image first.", duration=5)
        return ({f"Title ☠️": 0.0}, 0.0)
    try:
        batch = _preprocess(img)
        with torch.inference_mode():
            y_preds = model.predict_step(batch)
        # predict_step is assumed to return {'predict': ..., 'prob': ...} indexable
        # by batch position; values may be tensors -- TODO confirm against the model.
        label = _to_python(y_preds['predict'][0])
        confidence = float(y_preds['prob'][0])
    except Exception:
        # Best-effort UI: log the full traceback, notify the user, return a fallback.
        logger.exception("Prediction failed")
        gr.Warning("An error occured 💥!", duration=5)
        return ({f"Title ☠️": 0.0}, 0.0)
    res = {f"Title: {label}": confidence}
    pred_time = round(timer() - start_time, 5)
    return (res, pred_time)
# Example digits shipped next to this script. Absolute paths keep them valid
# regardless of the working directory the app is launched from.
_EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'numbers')
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')
_EXAMPLE_IMAGES = (
    [
        [os.path.join(_EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(_EXAMPLES_DIR))  # stable, deterministic order
        if name.lower().endswith(_IMAGE_EXTENSIONS)  # skip .gitkeep and other stray files
    ]
    if os.path.isdir(_EXAMPLES_DIR)
    else []
)

demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type='pil'),
    outputs=[
        # predict_fn returns (label -> probability dict, elapsed seconds).
        gr.Label(num_top_classes=1, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=_EXAMPLE_IMAGES or None,
    title="The Unsolved MNIST 🔢",
    description="CNN-based Architecture for Fast and Accurate MNIST 🔢 Solution with Reproducible Logs",
    article="Created by muthukamalan.m ❤️",
)
demo.launch(share=False, debug=False)