Spaces:
Runtime error
Runtime error
File size: 4,622 Bytes
8931c9e 4828471 8931c9e 4828471 8931c9e 6053557 8931c9e 4828471 8931c9e 4828471 8931c9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from pathlib import Path
import requests
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from src.models.catdog_model_resnet import ResnetClassifier
from src.utils.logging_utils import setup_logger, task_wrapper, get_rich_progress
import hydra
from omegaconf import DictConfig, OmegaConf
from dotenv import load_dotenv, find_dotenv
import rootutils
import time
from loguru import logger
from src.utils.aws_s3_services import S3Handler
# Load environment variables from the ".env" file located by find_dotenv.
# NOTE(review): S3Handler below presumably reads its credentials from these
# variables — TODO confirm.
load_dotenv(find_dotenv(".env"))
# Set up the project root; rootutils locates it via the ".project-root"
# indicator file, searching from this file's location.
# NOTE(review): the `src.*` imports at the top of the file execute before this
# call, so they cannot rely on any sys.path changes setup_root makes.
root = rootutils.setup_root(__file__, indicator=".project-root")
@task_wrapper
def load_image(image_path: str, image_size: int):
    """Load an image from disk and build the model input tensor.

    Args:
        image_path: Path to the image file (a ``str`` or ``pathlib.Path``).
        image_size: Side length the image is resized to; the result is square.

    Returns:
        Tuple ``(img, tensor)``:
        - ``img`` is the RGB ``PIL.Image`` (used later for display).
        - ``tensor`` has shape ``(1, 3, image_size, image_size)``. It is
          normalised with the standard ImageNet channel mean/std.
    """
    # Use a context manager so the underlying file handle is released
    # deterministically. ``convert`` returns a new, fully loaded image, so
    # ``img`` stays valid after the source file is closed.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # unsqueeze(0) adds the leading batch dimension expected by the model.
    return img, transform(img).unsqueeze(0)
@task_wrapper
def infer(model: torch.nn.Module, image_tensor: torch.Tensor, classes: list):
    """Classify a preprocessed image and return the top prediction.

    The input must hold exactly one image: ``.item()`` on the argmax result
    raises if the batch dimension is larger than one.

    Args:
        model: Classifier whose output is softmaxed over ``dim=1``. Presumably
            the output is logits of shape ``(batch, num_classes)`` — confirm.
        image_tensor: Batched input tensor, already on the model's device.
        classes: Labels indexable by the predicted class index.

    Returns:
        ``(label, confidence)``: the predicted label and its softmax probability.
    """
    model.eval()
    with torch.no_grad():
        probs = model(image_tensor).softmax(dim=1)
        best = torch.argmax(probs, dim=1).item()
    label = classes[best]
    score = probs[0, best].item()
    return label, score
@task_wrapper
def save_prediction_image(
    image: Image.Image, predicted_label: str, confidence: float, output_path: Path
):
    """Render *image* with its prediction as the title and save it to disk.

    Args:
        image: Image to display.
        predicted_label: Label shown in the title.
        confidence: Probability shown in the title with two decimals.
        output_path: Destination file. Missing parent directories are created.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Use an explicit figure and close exactly that figure in ``finally``.
    # Otherwise a failing savefig would leak it into pyplot's global figure
    # registry; this function is called once per image.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(f"Predicted: {predicted_label} (Confidence: {confidence:.2f})")
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)
@task_wrapper
def download_image(
    cfg: DictConfig,
    url: str = "https://github.com/laxmimerit/dog-cat-full-dataset/raw/master/data/train/dogs/dog.1.jpg",
    timeout: float = 30.0,
):
    """Download an image from the web into ``cfg.paths.root_dir`` for inference.

    Args:
        cfg: Hydra config; only ``cfg.paths.root_dir`` is read.
        url: Image to fetch. Defaults to a sample dog photo from the
            laxmimerit/dog-cat-full-dataset GitHub repository.
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes before raising a timeout error.

    Returns:
        The path the image was written to (``<root_dir>/image.jpg``).
        Returns ``None`` if the server answered with a non-200 status; that
        failure is logged rather than raised.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36",
    }
    # Without a timeout, requests can block indefinitely on a stalled server.
    response = requests.get(url, headers=headers, allow_redirects=True, timeout=timeout)
    if response.status_code != 200:
        logger.error(f"Failed to download image. Status code: {response.status_code}")
        return None

    image_path = Path(cfg.paths.root_dir) / "image.jpg"
    # write_bytes opens, writes and closes the file. The data is therefore
    # complete on disk before any later read, so no sleep is needed.
    image_path.write_bytes(response.content)
    logger.info(f"Image downloaded successfully as {image_path}!")
    return image_path
# File suffixes (compared case-insensitively) treated as inference inputs.
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})


def _list_image_files(directory: Path) -> list:
    """Return the image files directly inside *directory*, sorted by path."""
    # Lower-case the suffix so e.g. "photo.JPG" is not skipped. Sort so the
    # processing order is deterministic across filesystems.
    return sorted(
        f
        for f in directory.iterdir()
        if f.is_file() and f.suffix.lower() in _IMAGE_SUFFIXES
    )


@hydra.main(config_path="../configs", config_name="infer", version_base="1.3")
def main_infer(cfg: DictConfig):
    """Run the inference pipeline.

    Steps:
    1. Fetch checkpoints from S3 and load the model from ``cfg.ckpt_path``.
    2. Download a sample image.
    3. Classify every image in ``cfg.paths.root_dir``.
    4. Write one annotated PNG per image to ``cfg.paths.artifact_dir``.

    The S3 bucket is taken from the optional config key ``s3_bucket``. If
    that key is absent, it falls back to ``"deep-bucket-s3"``.
    """
    # Print the configuration
    logger.info(OmegaConf.to_yaml(cfg))
    setup_logger(Path(cfg.paths.log_dir) / "infer.log")

    # Remove the train_done flag if it exists.
    # NOTE(review): presumably so a fresh training run's completion can be
    # detected later — confirm.
    flag_file = Path(cfg.paths.ckpt_dir) / "train_done.flag"
    flag_file.unlink(missing_ok=True)

    # Download the model checkpoints from S3.
    bucket_name = OmegaConf.select(cfg, "s3_bucket", default="deep-bucket-s3")
    s3_handler = S3Handler(bucket_name=bucket_name)
    s3_handler.download_folder(
        "checkpoints",
        "checkpoints",
    )

    # Load the trained model
    model = ResnetClassifier.load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
    classes = cfg.labels

    # Download an image for inference
    download_image(cfg)

    image_files = _list_image_files(Path(cfg.paths.root_dir))
    if not image_files:
        logger.warning(f"No images found in {cfg.paths.root_dir}; nothing to infer.")
        return

    artifact_dir = Path(cfg.paths.artifact_dir)
    with get_rich_progress() as progress:
        task = progress.add_task("[green]Processing images...", total=len(image_files))
        for image_file in image_files:
            img, img_tensor = load_image(image_file, cfg.data.image_size)
            predicted_label, confidence = infer(
                model, img_tensor.to(model.device), classes
            )
            output_file = artifact_dir / f"{image_file.stem}_prediction.png"
            save_prediction_image(img, predicted_label, confidence, output_file)
            progress.advance(task)
            logger.info(f"Processed {image_file}: {predicted_label} ({confidence:.2f})")
# Script entry point. main_infer is called without arguments; the
# @hydra.main decorator composes `cfg` from ../configs/infer plus any
# command-line overrides.
if __name__ == "__main__":
    main_infer()
|