import os
import json
import torch
import torch.nn as nn
from .model import AudioClassifier
from ..utils.config import dict2cfg, cfg2dict
from huggingface_hub import HfApi, create_repo, hf_hub_download

class HFAudioClassifier(AudioClassifier):
    """Hugging Face compatible AudioClassifier model"""
    
    def __init__(self, config):
        # Accept either a plain dict or an already-built config object
        if isinstance(config, dict):
            config = dict2cfg(config)
        self.config = config
        super().__init__(self.config)

    @classmethod
    def from_pretrained(cls, model_id, cache_dir=None, map_location="cpu", strict=False):
        """Load a model from a local checkpoint directory or the Hugging Face Hub.

        Args:
            model_id (str): Local checkpoint directory or Hub repository ID (e.g., 'username/model-name')
            cache_dir (str, optional): Cache directory for files downloaded from the Hub
            map_location (str, optional): Device the weights are mapped to when loading
            strict (bool, optional): Whether to strictly enforce state_dict key matching
        """
        # Check if model_id is a local checkpoint directory
        is_local = os.path.isdir(model_id)

        if is_local:
            # Load from local checkpoint
            config_file = os.path.join(model_id, "config.json")
            model_file = os.path.join(model_id, "pytorch_model.bin")
        else:
            # Download from HF Hub
            config_file = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
            model_file = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", cache_dir=cache_dir)

        # Read config (required to build the model)
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Config not found at {config_file}")
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Create model
        model = cls(config)

        # Load weights
        if os.path.exists(model_file):
            state_dict = torch.load(model_file, map_location=torch.device(map_location))
            model.load_state_dict(state_dict, strict=strict)
            model.eval()
        else:
            raise FileNotFoundError(f"Model weights not found at {model_file}")

        return model

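    # Illustrative usage sketch (the repo ID and checkpoint directory below are
    # placeholders, not names from this project):
    #
    #     model = HFAudioClassifier.from_pretrained("username/audio-model")    # from the Hub
    #     model = HFAudioClassifier.from_pretrained("path/to/checkpoint_dir")  # from a local directory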

    def push_to_hub(self, repo_id, token=None, commit_message=None, private=False):
        """Push model and config to Hugging Face Hub.

        

        Args:

            repo_id (str): Repository ID on HuggingFace Hub (e.g., 'username/model-name')

            token (str, optional): HuggingFace token. If None, will use token from ~/.huggingface/token

            commit_message (str, optional): Commit message for the push

            private (bool, optional): Whether to make the repository private

        """

        # Create repo if it doesn't exist
        api = HfApi()
        try:
            create_repo(repo_id, private=private, token=token, exist_ok=True)
        except Exception as e:
            print(f"Repository creation failed: {e}")
            return

        # Save config
        config = cfg2dict(self.config)
        with open("config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)

        # Save model weights (remember the original device, since .cpu() moves the model)
        device = next(self.parameters()).device
        torch.save(self.cpu().state_dict(), "pytorch_model.bin")
        self.to(device)  # restore original device

        # Push files to hub
        files_to_push = ["config.json", "pytorch_model.bin"]
        for file in files_to_push:
            api.upload_file(
                path_or_fileobj=file,
                path_in_repo=file,
                repo_id=repo_id,
                token=token,
                commit_message=commit_message or f"Upload {file}"
            )
            os.remove(file)  # Clean up local files
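
    # Illustrative usage sketch (placeholder repo ID; with token=None the call
    # falls back to the locally cached Hugging Face token):
    #
    #     model.push_to_hub("username/audio-model", commit_message="Upload weights")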

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights and configuration to a directory.

        

        Args:

            save_directory (str): Directory to save files in

            **kwargs: Additional arguments passed to save functions

        """
        os.makedirs(save_directory, exist_ok=True)

        # Save config
        config = cfg2dict(self.config)
        config_file = os.path.join(save_directory, "config.json")
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)

        # Save model weights (remember the original device, since .cpu() moves the model)
        device = next(self.parameters()).device
        model_file = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.cpu().state_dict(), model_file)
        self.to(device)  # restore original device
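

# Example save/reload round trip (illustrative only; the export directory is a
# placeholder and is created by save_pretrained if it does not exist):
#
#     model.save_pretrained("./exported-audio-model")
#     reloaded = HFAudioClassifier.from_pretrained("./exported-audio-model")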