Last commit not found
import argparse
import json

import cv2
import detectron2
import torch
from detectron2.config import CfgNode, get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
def main(argv=None):
    """Run a Detectron2 model on one image and save a visualization of its predictions.

    The config overrides, model weights and class-name metadata are read from
    files: ``config.json``, ``model.pth`` and ``metadata.json`` in the current
    working directory by default. Each path can be overridden on the command line.

    Args:
        argv: Optional list of command-line arguments, without the program
            name. ``None`` (the default) makes argparse parse ``sys.argv[1:]``,
            as before.

    Raises:
        FileNotFoundError: If the input image cannot be read by OpenCV, or if
            the config / metadata JSON file does not exist.
        OSError: If OpenCV fails to write the output image.
    """
    parser = argparse.ArgumentParser(description="Run inference with Detectron2 model")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--output", default="output.jpg", help="Path to output image")
    parser.add_argument("--config", default="config.json",
                        help="Path to JSON file with config overrides")
    parser.add_argument("--weights", default="model.pth", help="Path to model weights")
    parser.add_argument("--metadata", default="metadata.json",
                        help="Path to JSON file containing 'thing_classes'")
    args = parser.parse_args(argv)

    # Load config: start from detectron2 defaults, then overlay the JSON values.
    # CfgNode has no `merge_from_dict`; wrapping the dict in a CfgNode and using
    # `merge_from_other_cfg` applies the usual key-existence/type checks.
    # NOTE(review): assumes the JSON is nested like the config tree
    # (e.g. {"MODEL": {"ROI_HEADS": {...}}}), not flat dotted keys — confirm.
    cfg = get_cfg()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg_dict = json.load(f)
    cfg.merge_from_other_cfg(CfgNode(cfg_dict))

    # Inference parameters are set after the merge, so they win over the JSON.
    cfg.MODEL.WEIGHTS = args.weights
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    predictor = DefaultPredictor(cfg)

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(args.image)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {args.image}")

    outputs = predictor(image)

    # Class names used to label the drawn instances.
    with open(args.metadata, "r", encoding="utf-8") as f:
        metadata_dict = json.load(f)
    metadata = MetadataCatalog.get("inference")
    metadata.thing_classes = metadata_dict["thing_classes"]

    # OpenCV images are BGR while Visualizer expects RGB: flip channels going
    # in, and flip back before writing with OpenCV.
    v = Visualizer(image[:, :, ::-1], metadata=metadata, scale=1.2)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(args.output, out.get_image()[:, :, ::-1]):
        raise OSError(f"Could not write output image: {args.output}")
    print(f"Saved output to {args.output}")
if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0;
    # any exception raised by main() propagates unchanged.
    raise SystemExit(main())