In [1]:
import csv
import os
import random
import sys

import kornia as K
import lightning as L
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

# Make the project's packages (models/, scripts/) importable from this notebook.
# os.path.abspath already normalises the path, so no separate normpath is needed.
PROJECT_ROOT = os.path.abspath("/home/venom/repo/xray-exp/")
# Guard so re-running this cell does not stack duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

  from .autonotebook import tqdm as notebook_tqdm
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [2]:
from models.model_loader import create_model
from scripts.trainer import XrayReg

In [None]:
class XrayInferenceDataset(Dataset):
    """Unlabelled dataset over the image files directly inside ``root_dir``.

    Each item is ``(image, file_name)``. The image is opened with PIL,
    converted to single-channel grayscale (mode ``"L"``) and passed through
    ``transform`` if one is given. ``file_name`` is the bare name (no
    directory) so predictions can be keyed by it.

    Args:
        root_dir: Directory containing the images. Only regular, non-hidden
            files at the top level are used; subdirectories are ignored.
        transform: Optional callable applied to the grayscale PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # Sorted so item order (and therefore output row order) is stable
        # across runs and machines -- os.listdir order is arbitrary.
        # Subdirectories and dotfiles (e.g. .DS_Store) are skipped because
        # PIL cannot open them and would crash mid-run.
        self.file_names = sorted(
            name
            for name in os.listdir(root_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(root_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        img_path = os.path.join(self.root_dir, file_name)
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as src:
            img = src.convert("L")

        if self.transform:
            img = self.transform(img)

        return img, file_name


class XrayDataInference(L.LightningDataModule):
    """LightningDataModule that serves ``XrayInferenceDataset`` for prediction.

    Images are resized to 224x224 and converted to tensors (no normalisation
    or augmentation). The loader never shuffles, so batches follow the
    dataset's file order.

    Args:
        root_dir: Directory of images to run inference on.
        batch_size: Images per batch.
        num_workers: DataLoader worker processes (defaults to 4, the value
            previously hard-coded).
    """

    common_seed = 42

    @staticmethod
    def seed_worker(worker_id):
        """Seed numpy and ``random`` in each DataLoader worker from torch's per-worker seed."""
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def __init__(self, root_dir, batch_size=32, num_workers=4):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # NOTE: these calls change process-global torch state, not just this module's.
        torch.manual_seed(self.common_seed)
        torch.cuda.manual_seed_all(self.common_seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may select different kernels between runs; PyTorch's
        # reproducibility notes pair deterministic=True with benchmark=False.
        torch.backends.cudnn.benchmark = False

        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        self.inference_dataset = XrayInferenceDataset(
            self.root_dir, transform=self.transform
        )

    def inference_dataloader(self):
        """Return an unshuffled DataLoader over the inference dataset."""
        return DataLoader(
            self.inference_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            worker_init_fn=self.seed_worker,
        )

    def predict_dataloader(self):
        """Lightning hook used by ``Trainer.predict``; same loader as ``inference_dataloader``."""
        return self.inference_dataloader()

In [None]:
# Restore the trained XrayReg LightningModule from the epoch-99 checkpoint of run 912yp4l6.
CHECKPOINT_PATH = (
    "/home/venom/repo/xray-exp/xray_regression_noaug/912yp4l6/checkpoints/"
    "epoch=99-step=5900.ckpt"
)
trainer_config = XrayReg.load_from_checkpoint(CHECKPOINT_PATH)

  model = create_fn(


In [None]:
# Data module over the folder of images to score.
INFERENCE_DIR = "/home/venom/Downloads/CXR AI PNG- FINAL 13-12/"
infer_ds = XrayDataInference(INFERENCE_DIR, batch_size=16)

In [None]:
# Prepare the network for inference: take the wrapped model out of the loaded
# LightningModule, switch it to eval mode (changes the behaviour of layers such
# as dropout / batch-norm, if present) and move it to the GPU.
# The inference loop itself runs in the next cell and requires CUDA.
model = trainer_config.model

model.eval()
model = model.cuda()

In [7]:
# Run inference over every image in infer_ds and write one CSV row per image
# to <RESULTS_DIR>/<RUN_ID>.csv.
# NOTE(review): RUN_ID is typed separately from the checkpoint path used to
# load the model -- keep them in sync or the results file is mislabelled.
RUN_ID = "912yp4l6"
RESULTS_DIR = "/home/venom/repo/xray-exp/inference_results"

os.makedirs(RESULTS_DIR, exist_ok=True)
results_path = os.path.join(RESULTS_DIR, f"{RUN_ID}.csv")

# newline="" is required by the csv module; csv.writer quotes file names that
# contain commas or quotes, which a hand-built f-string row does not.
# lineterminator="\n" keeps the same line endings as before.
with open(results_path, "w", newline="") as f, torch.no_grad():
    writer = csv.writer(f, lineterminator="\n")
    writer.writerow(["file_name", "predicted"])
    for imgs, file_names in infer_ds.inference_dataloader():
        preds = model(imgs.cuda()).cpu().numpy()
        # Only the first output column is recorded; assumes model output is
        # (batch, k) with k >= 1 -- TODO confirm against the model head.
        writer.writerows(zip(file_names, preds[:, 0]))