Spaces:
No application file
No application file
# Copyright (c) Facebook, Inc. and its affiliates.
import pickle | |
from fvcore.common.checkpoint import Checkpointer | |
import detectron2.utils.comm as comm | |
from detectron2.utils.file_io import PathManager | |
from .c2_model_loading import align_and_update_state_dicts | |
class DetectionCheckpointer(Checkpointer):
    """
    A :class:`Checkpointer` that additionally understands checkpoints from the
    Detectron and Detectron2 model zoos, converting legacy (Caffe2-era) weights
    so they can be loaded into a detectron2 model.

    Inputs recognized by ``_load_file``:

    * ``.pkl`` files in Detectron2 model-zoo format (a dict holding both
      ``"model"`` and ``"__author__"``) are used as-is.
    * Any other ``.pkl`` file is treated as a Caffe2 / Detectron1 checkpoint and
      marked for name-matching conversion.
    * All other files are delegated to the base-class loader.
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        """
        Args:
            model: the model to save / restore.
            save_dir (str): directory forwarded to :class:`Checkpointer`.
            save_to_disk (bool or None): whether to write checkpoints. ``None``
                means "use ``comm.is_main_process()``".
            **checkpointables: extra objects forwarded to :class:`Checkpointer`.
        """
        on_main_process = comm.is_main_process()
        if save_to_disk is None:
            save_to_disk = on_main_process
        super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)

        if not hasattr(self, "path_manager"):
            # Older fvcore without a per-instance ``path_manager`` (open-source only):
            # register detectron2's path handlers with fvcore's global PathManager.
            # TODO remove after upgrading fvcore
            from fvcore.common.file_io import PathManager as fvcore_path_manager

            for path_handler in PathManager._path_handlers.values():
                try:
                    fvcore_path_manager.register_handler(path_handler)
                except KeyError:
                    # Deliberately ignored; presumably raised when the handler is
                    # already registered — TODO confirm against fvcore.
                    pass
            return

        # Newer fvcore: route file access through detectron2's PathManager.
        self.path_manager = PathManager

    def _load_file(self, filename):
        """
        Read a checkpoint file and return it as a dict with at least ``"model"``.

        Args:
            filename (str): path understood by ``PathManager``.

        Returns:
            dict: ``.pkl`` files in Detectron2 model-zoo format are returned
            unchanged; other ``.pkl`` files become
            ``{"model": ..., "__author__": "Caffe2", "matching_heuristics": True}``;
            any other file is loaded by the base class and wrapped as
            ``{"model": ...}`` when it has no ``"model"`` entry.
        """
        if not filename.endswith(".pkl"):
            # Native checkpoint: let the base class load it.
            loaded = super()._load_file(filename)
            return loaded if "model" in loaded else {"model": loaded}

        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
        # file; only load .pkl checkpoints from trusted sources.
        # encoding="latin1" allows decoding Python 2 pickles (e.g. NumPy arrays),
        # presumably needed for legacy Caffe2 model-zoo files.
        with PathManager.open(filename, "rb") as f:
            data = pickle.load(f, encoding="latin1")

        if "model" in data and "__author__" in data:
            # Already in Detectron2 model-zoo format.
            self.logger.info("Reading a file from '{}'".format(data["__author__"]))
            return data

        # Otherwise assume a Caffe2 / Detectron1 model-zoo file. Detection models
        # nest their weights under "blobs"; ImageNet models don't.
        blobs = data["blobs"] if "blobs" in data else data
        # Drop entries whose names end in "_momentum" (presumably solver state,
        # not model weights).
        weights = {
            name: value for name, value in blobs.items() if not name.endswith("_momentum")
        }
        return {"model": weights, "__author__": "Caffe2", "matching_heuristics": True}

    def _load_model(self, checkpoint):
        """
        Load ``checkpoint["model"]`` into ``self.model``.

        When ``checkpoint["matching_heuristics"]`` is truthy, the checkpoint's
        arrays are converted to tensors and aligned by name against the model's
        own state dict (with Caffe2-specific conversion when ``"__author__"`` is
        ``"Caffe2"``). ``checkpoint["model"]`` is then replaced by that state dict
        before the standard base-class loading runs.

        Returns:
            The incompatible-keys object from the base class, with ``pixel_mean``
            / ``pixel_std`` removed from its missing keys when they are direct
            buffers of the model; ``None`` if the base class returned ``None``.
        """
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            aligned = self.model.state_dict()
            from_caffe2 = checkpoint.get("__author__", None) == "Caffe2"
            # ``aligned`` becomes the checkpoint's weights afterwards, so this call
            # is expected to update it in place.
            align_and_update_state_dicts(aligned, checkpoint["model"], c2_conversion=from_caffe2)
            checkpoint["model"] = aligned

        incompatible = super()._load_model(checkpoint)
        if incompatible is None:
            # Older fvcore versions return nothing to post-process.
            return None

        # pixel_mean / pixel_std may be absent from old checkpoints, but they are
        # initialized from the config anyway, so don't report them as missing.
        direct_buffers = {name for name, _ in self.model.named_buffers(recurse=False)}
        for key in ("pixel_mean", "pixel_std"):
            if key not in direct_buffers:
                continue
            try:
                incompatible.missing_keys.remove(key)
            except ValueError:
                pass  # not reported as missing; nothing to suppress
        return incompatible