Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import numpy as np | |
import lightning as pl | |
import gradio as gr | |
from PIL import Image | |
from torchvision import transforms | |
from data import LitMNISTDataModule | |
from config import CONFIG | |
from model import LitMNISTModel | |
from timeit import default_timer as timer | |
# Global runtime setup: matmul precision, default device, and RNG seeding.

# 'medium' lets float32 matmuls use lower-precision internals in exchange
# for speed on hardware that supports it.
torch.set_float32_matmul_precision('medium')

# NOTE: a bare `torch.cuda.amp.autocast(enabled=True)` statement used to sit
# here. autocast is a context manager / decorator; constructing it without
# `with` has no effect, so the dead statement was removed. Wrap the inference
# call in `with torch.autocast(...)` if mixed precision is actually wanted.

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)

# Fixed seed (also forwarded to dataloader workers) for reproducible runs.
pl.seed_everything(123, workers=True)
# Preprocessing pipeline applied to images at test / inference time.
# The PILToTensor() and Resize((28, 28)) steps were commented out in the
# original pipeline and remain disabled here.
# 0.1307 / 0.3081 look like the commonly quoted MNIST mean / std -- TODO confirm.
_TEST_STEPS = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
]
TEST_TRANSFORMS = transforms.Compose(_TEST_STEPS)
## MNISTDataModule
# Build the data module from CONFIG: the data directory comes from the nested
# 'data' section (defaulting to the current directory); batch size and worker
# count come from the top level and are None when the key is absent.
_data_section = CONFIG['data']
dm = LitMNISTDataModule(
    train_transform=None,
    test_transform=TEST_TRANSFORMS,
    num_workers=CONFIG.get('num_workers'),
    batch_size=CONFIG.get('batch_size'),
    data_dir=_data_section.get('dir_path', '.'),
)
# Run the data module hooks for the 'test' stage only.
# NOTE(review): presumably this downloads/prepares MNIST on first start -- confirm.
dm.prepare_data()
dm.setup('test')
## MNISTModel
# Restore the trained weights from the checkpoint stored next to this script.
chkpoint_path = os.path.join(
    os.path.dirname(__file__), 'logs', 'chkpoints', 'epoch=13.ckpt'
)
model = LitMNISTModel.load_from_checkpoint(chkpoint_path)
# The model is called directly (outside a Trainer) for inference, so switch it
# to eval mode explicitly to disable dropout / batch-norm updates.
model.eval()

# Trainer kept for the (currently disabled) `trainer.test(...)` sanity check.
trainer = pl.Trainer(
    fast_dev_run=True,
    precision=32,
    enable_model_summary=False,
    enable_progress_bar=False,
)
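# Disabled debugging helpers, kept for reference: a test-set sanity check and
# a one-off export of the first test batch to numbers/img_<i>.png.
# NOTE: the export needs `plt` (presumably matplotlib.pyplot), which is not
# among the imports visible at the top of this file.
# trainer.test(model,datamodule=dm)
# for X,y in dm.test_dataloader():
#     for i in range(X.shape[0]):
#         plt.imsave(
#             fname=os.path.join('numbers',f'img_{i}.png'),
#             arr=np.clip(
#                 torch.stack(
#                     [X[i,...],X[i,...],X[i,...]],
#                     dim=1
#                 ).squeeze(0).permute(1,2,0).detach().cpu().contiguous().numpy(),0,1))
#     break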
def predict_fn(img: Image.Image):
    """Classify one uploaded image and report how long the prediction took.

    Args:
        img: Image delivered by the ``gr.Image(type='pil')`` input component.

    Returns:
        A ``(scores, seconds)`` tuple. ``scores`` is a one-entry dict mapping
        ``"Title: <prediction>"`` to the associated probability as returned by
        ``model.predict_step``; ``seconds`` is the elapsed time rounded to
        5 decimals. On any failure the error is logged, a warning is shown in
        the UI, and the placeholder ``({"Title ⚠️": 0.0}, 0.0)`` is returned
        instead of raising.
    """
    # Local import keeps this function self-contained with respect to the
    # file's import block.
    import logging

    start_time = timer()
    try:
        pixels = np.array(img)
        tensor = TEST_TRANSFORMS(pixels)
        # Collapse dim 0 (the channel axis produced by ToTensor) by averaging,
        # then add two leading axes (batch and channel) for the model.
        # NOTE(review): no resize is applied here; the model presumably
        # expects MNIST-sized input -- confirm.
        batch = tensor.mean(dim=0).unsqueeze(0).unsqueeze(0).to(model.device)
        # Pure inference: skip autograd bookkeeping.
        with torch.inference_mode():
            y_preds = model.predict_step(batch)
        res = {f"Title: {y_preds['predict'][0]}": y_preds['prob'][0]}
        pred_time = round(timer() - start_time, 5)
        return res, pred_time
    except Exception:
        # Best-effort handler: keep the demo responsive, but leave a trace.
        logging.getLogger(__name__).exception("Prediction failed")
        # gr.Error only has an effect when raised; gr.Warning shows a message
        # in the UI while still letting the placeholder result be returned.
        gr.Warning("An error occured 💥!", duration=5)
        return {"Title ⚠️": 0.0}, 0.0
# Build and launch the Gradio demo.
# Example images are listed AND referenced through the same directory next to
# this script, so they resolve regardless of the current working directory.
# They are sorted because os.listdir() returns entries in arbitrary order.
_numbers_dir = os.path.join(os.path.dirname(__file__), 'numbers')
_example_rows = [
    [os.path.join(_numbers_dir, name)]
    for name in sorted(os.listdir(_numbers_dir))
]

demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type='pil'),
    # Outputs match predict_fn's return tuple: (label scores, elapsed seconds).
    outputs=[
        gr.Label(num_top_classes=1, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=_example_rows,
    title="The Unsolved MNIST 🔢",
    description="CNN-based Architecture for Fast and Accurate MNIST 🔢 Solution with Reproducible Logs",
    article="Created by muthukamalan.m ❤️",
)
demo.launch(share=False, debug=False)