Spaces:
Runtime error
Runtime error
| import torch | |
| from .model import GLiNER | |
def save_model(current_model, path):
    """Write a model's weights and its config to ``path`` as one checkpoint.

    The checkpoint is a dict with two entries:

    * ``"model_weights"`` -- the result of ``current_model.state_dict()``
    * ``"config"`` -- the model's ``config`` attribute

    Args:
        current_model: Object exposing ``config`` and ``state_dict()``.
        path: Destination handed straight to ``torch.save``.
    """
    model_config = current_model.config
    checkpoint = dict(
        model_weights=current_model.state_dict(),
        config=model_config,
    )
    torch.save(checkpoint, path)
def load_model(path, model_name=None, device=None):
    """Rebuild a GLiNER model from a checkpoint file.

    The checkpoint must be a dict with a ``"config"`` entry and a
    ``"model_weights"`` entry (the layout ``save_model`` writes).

    Args:
        path: Checkpoint location, passed to ``torch.load``.
        model_name: If not ``None``, overrides ``config.model_name`` before
            the model is constructed.
        device: If not ``None``, the loaded model is moved to this device;
            otherwise it stays on CPU, where the checkpoint is deserialized.

    Returns:
        The ``GLiNER`` instance with the stored weights loaded.
    """
    # The checkpoint stores the config object alongside the tensors, so it
    # cannot be read in weights-only mode. PyTorch 2.6 made
    # ``weights_only=True`` the default, so the flag must be set explicitly
    # or loading fails on current releases.
    # NOTE(security): this runs the full pickle machinery, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(
        path,
        map_location=torch.device("cpu"),
        weights_only=False,
    )

    config = checkpoint["config"]
    if model_name is not None:
        config.model_name = model_name

    model = GLiNER(config)
    model.load_state_dict(checkpoint["model_weights"])

    if device is None:
        return model
    return model.to(device)