Spaces:
Sleeping
Sleeping
import json
import os
import tempfile

import torch
import torch.nn as nn
from huggingface_hub import HfApi, create_repo, hf_hub_download

from ..utils.config import dict2cfg, cfg2dict
from .model import AudioClassifier
class HFAudioClassifier(AudioClassifier):
    """Hugging Face compatible AudioClassifier model.

    Adds ``from_pretrained`` / ``save_pretrained`` / ``push_to_hub`` on top of
    ``AudioClassifier``. A checkpoint is a pair of files: ``config.json``
    (the config serialized with ``cfg2dict``) and ``pytorch_model.bin``
    (the ``state_dict`` written with ``torch.save``).
    """

    # File names shared by the load, save and upload paths.
    _CONFIG_NAME = "config.json"
    _WEIGHTS_NAME = "pytorch_model.bin"

    def __init__(self, config):
        """Build the model from a config dict or an already-built config object.

        Args:
            config: A plain ``dict`` (converted with ``dict2cfg``) or a config
                object that is passed to ``AudioClassifier.__init__`` as-is.
        """
        # Always bind self.config: the original only did so for dict input,
        # so any other config type failed on the next line.
        self.config = dict2cfg(config) if isinstance(config, dict) else config
        super().__init__(self.config)

    @classmethod
    def from_pretrained(cls, model_id, cache_dir=None, map_location="cpu", strict=False):
        """Load a model from a local directory or from the Hugging Face Hub.

        Args:
            model_id (str): An existing local path holding ``config.json`` and
                ``pytorch_model.bin``, or otherwise a Hub repository ID.
            cache_dir (str, optional): Forwarded to ``hf_hub_download``.
            map_location (str): Device string the weights are mapped to on load.
            strict (bool): Forwarded to ``load_state_dict``.

        Returns:
            The constructed model with weights loaded, switched to eval mode.

        Raises:
            FileNotFoundError: If the config file or the weights file is missing.
        """
        if os.path.exists(model_id):
            # Local checkpoint directory.
            config_file = os.path.join(model_id, cls._CONFIG_NAME)
            model_file = os.path.join(model_id, cls._WEIGHTS_NAME)
        else:
            # Anything that is not an existing path is treated as a Hub repo ID.
            config_file = hf_hub_download(
                repo_id=model_id, filename=cls._CONFIG_NAME, cache_dir=cache_dir
            )
            model_file = hf_hub_download(
                repo_id=model_id, filename=cls._WEIGHTS_NAME, cache_dir=cache_dir
            )

        # Fail early with a clear error instead of constructing the model
        # from a missing (None) config.
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Model config not found at {config_file}")
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model weights not found at {model_file}")

        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        model = cls(config)

        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from trusted sources, or consider passing weights_only=True on
        # torch versions that support it.
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict)
        model.eval()
        return model

    def push_to_hub(self, repo_id, token=None, commit_message=None, private=False):
        """Push model and config to Hugging Face Hub.

        The files are staged in a temporary directory, so nothing in the
        current working directory is created, overwritten or deleted.

        Args:
            repo_id (str): Repository ID on HuggingFace Hub (e.g., 'username/model-name')
            token (str, optional): HuggingFace token. If None, huggingface_hub
                falls back to its locally stored credentials.
            commit_message (str, optional): Commit message for the push. Defaults
                to ``"Upload <file>"`` per uploaded file.
            private (bool, optional): Whether to make the repository private
        """
        api = HfApi()
        # Best effort by design: a failed repo creation is reported and the
        # push is abandoned without raising.
        try:
            create_repo(repo_id, private=private, token=token, exist_ok=True)
        except Exception as e:
            print(f"Repository creation failed: {e}")
            return

        with tempfile.TemporaryDirectory() as staging_dir:
            self.save_pretrained(staging_dir)
            for filename in (self._CONFIG_NAME, self._WEIGHTS_NAME):
                api.upload_file(
                    path_or_fileobj=os.path.join(staging_dir, filename),
                    path_in_repo=filename,
                    repo_id=repo_id,
                    token=token,
                    commit_message=commit_message or f"Upload {filename}",
                )
        # The staging directory is removed on exit, even if an upload raised.

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights and configuration to a directory.

        Writes ``config.json`` and ``pytorch_model.bin`` into ``save_directory``,
        creating the directory if needed. The module itself is not moved
        between devices.

        Args:
            save_directory (str): Directory to save files in
            **kwargs: Accepted for call compatibility; currently unused.
        """
        os.makedirs(save_directory, exist_ok=True)

        config_file = os.path.join(save_directory, self._CONFIG_NAME)
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(cfg2dict(self.config), f, indent=2, sort_keys=True)

        model_file = os.path.join(save_directory, self._WEIGHTS_NAME)
        torch.save(self._cpu_state_dict(), model_file)

    def _cpu_state_dict(self):
        """Return the state dict with every tensor on CPU, without moving the module."""
        # state_dict() builds a fresh mapping on each call, so replacing its
        # values in place keeps any attached metadata and leaves the module's
        # own parameters and buffers on their current device.
        state = self.state_dict()
        for key in list(state):
            value = state[key]
            if torch.is_tensor(value):
                state[key] = value.detach().cpu()
        return state