# detectron2_model_hf / inference.py
# Uploaded by sajabdoli — commit 472d78b ("Upload 6 files", verified).
import argparse
import json
from pathlib import Path

import cv2
import detectron2
import torch
from detectron2.config import CfgNode
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
def _parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``image``, ``output``, ``config``, ``weights``,
        ``metadata`` and ``scale`` attributes. The defaults for ``config``,
        ``weights``, ``metadata`` and ``scale`` match the values this script
        previously hard-coded.

    Exits with status 2 (via ``parser.error``) on bad arguments or when
    ``--image`` does not name an existing file.
    """
    parser = argparse.ArgumentParser(description="Run inference with Detectron2 model")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--output", default="output.jpg", help="Path to output image")
    parser.add_argument("--config", default="config.json", help="Path to JSON model config")
    parser.add_argument("--weights", default="model.pth", help="Path to model weights")
    parser.add_argument(
        "--metadata",
        default="metadata.json",
        help="Path to JSON metadata containing 'thing_classes'",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=1.2,
        help="Scale factor applied to the visualized output image",
    )
    args = parser.parse_args(argv)
    # cv2.imread returns None instead of raising for a missing file, so
    # report the problem clearly before any model loading happens.
    if not Path(args.image).is_file():
        parser.error(f"input image not found: {args.image}")
    return args


def main(argv=None):
    """Run Detectron2 inference on one image and save a visualization.

    Merges the JSON config into Detectron2's default config, points it at the
    weights file, runs ``DefaultPredictor`` on the input image, draws the
    predicted instances labelled with ``thing_classes`` from the metadata
    JSON, and writes the annotated image to ``--output``.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``), so the
            script can also be driven programmatically.

    Raises:
        SystemExit: on invalid arguments, a missing or undecodable input
            image, or when the output image cannot be written.
    """
    args = _parse_args(argv)

    # Build the config. CfgNode has no ``merge_from_dict``; wrap the JSON
    # dict in a CfgNode and merge it over the defaults instead.
    cfg = get_cfg()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg_dict = json.load(f)
    cfg.merge_from_other_cfg(CfgNode(cfg_dict))

    # Inference-time overrides.
    cfg.MODEL.WEIGHTS = args.weights
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the image before building the predictor so a bad file fails fast.
    image = cv2.imread(args.image)
    if image is None:
        raise SystemExit(f"error: could not decode image: {args.image}")

    predictor = DefaultPredictor(cfg)
    outputs = predictor(image)

    # Class names used to label the drawn instances.
    with open(args.metadata, "r", encoding="utf-8") as f:
        metadata_dict = json.load(f)
    metadata = MetadataCatalog.get("inference")
    metadata.thing_classes = metadata_dict["thing_classes"]

    # OpenCV images are BGR; the Visualizer is given the channel-reversed
    # (RGB) image, so reverse the channels again before writing with OpenCV.
    visualizer = Visualizer(image[:, :, ::-1], metadata=metadata, scale=args.scale)
    vis = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    if not cv2.imwrite(args.output, vis.get_image()[:, :, ::-1]):
        raise SystemExit(f"error: could not write output image: {args.output}")
    print(f"Saved output to {args.output}")
if __name__ == "__main__":
    # Script entry point; importing this module does not run inference.
    main()