File size: 4,791 Bytes
0bbe86b
 
 
 
 
1352f56
0bbe86b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1352f56
0bbe86b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f0212f
0bbe86b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1352f56
0bbe86b
1352f56
0bbe86b
193150f
 
8f0212f
193150f
0bbe86b
 
 
 
df53c44
0bbe86b
 
df53c44
0bbe86b
 
df53c44
0bbe86b
 
1352f56
0bbe86b
 
 
df53c44
 
0bbe86b
 
 
 
 
df53c44
0bbe86b
 
df53c44
0bbe86b
 
df53c44
0bbe86b
 
8f0212f
0bbe86b
 
1352f56
0bbe86b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import sys
import platform
import subprocess
import pkg_resources
import json
import traceback
import os

def get_os_info() -> dict:
    """Describe the host operating system.

    Returns:
        A dict with the keys ``system``, ``release``, ``version`` and
        ``machine``, each filled from the matching :mod:`platform` call.
    """
    info = {}
    info["system"] = platform.system()
    info["release"] = platform.release()
    info["version"] = platform.version()
    info["machine"] = platform.machine()
    return info

def get_python_info() -> dict:
    """Describe the running Python interpreter.

    Returns:
        A dict holding the full ``sys.version`` string, the interpreter
        implementation name and the compiler the interpreter was built with.
    """
    return dict(
        version=sys.version,
        implementation=platform.python_implementation(),
        compiler=platform.python_compiler(),
    )

def get_cuda_info() -> dict:
    """Report whether the CUDA compiler ``nvcc`` is usable, and its version.

    The version is the text that follows the first occurrence of "release"
    in ``nvcc --version`` output, up to the next comma, stripped of
    whitespace. Any failure -- ``nvcc`` missing, a non-zero exit status, or
    output without "release" -- is reported as ``{"available": False}``
    instead of being raised.
    """
    try:
        banner = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
        release_field = banner.split("release")[1]
        version = release_field.split(",")[0].strip()
    except Exception:
        # Deliberately best-effort: this is a diagnostics probe.
        return {"available": False}
    return {"available": True, "version": version}

def get_gpu_info() -> list:
    """Detect GPUs / accelerators visible to this machine.

    Each probe is best-effort and independent of the others:

    * NVIDIA GPUs via ``nvidia-smi`` -- one entry per GPU with ``name``,
      ``memory`` (memory.total) and ``driver_version``;
    * AMD GPUs via ``rocm-smi`` (Linux only) -- a single entry holding the
      raw tool output under ``details``;
    * Apple Silicon's MPS backend via ``torch`` (macOS on arm64 only).

    A probe whose tool or module is missing, or which fails, contributes
    nothing to the result.

    Returns:
        A list of dicts, each with a ``type`` key ("NVIDIA", "AMD" or "MPS");
        empty when nothing is detected.
    """
    gpu_info = []
    system = platform.system()

    # NVIDIA: one CSV row per GPU with the fields name, memory.total, driver_version.
    try:
        nvidia_smi_output = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=name,memory.total,driver_version",
                "--format=csv,noheader",
            ]
        ).decode("utf-8")
    except Exception:
        nvidia_smi_output = ""  # nvidia-smi absent or failing: report no NVIDIA GPUs
    # Split on real line breaks. splitting on "\\n" (a literal backslash-n)
    # merged every GPU into one row, so only the first GPU was reported.
    for line in nvidia_smi_output.splitlines():
        fields = [field.strip() for field in line.split(",")]
        if len(fields) < 3:
            continue  # blank or unexpected row; keep the other GPUs
        gpu_info.append(
            {
                "type": "NVIDIA",
                "name": fields[0],
                "memory": fields[1],
                "driver_version": fields[2],
            }
        )

    # AMD GPUs on Linux: the raw rocm-smi report is kept as-is.
    if system == "Linux":
        try:
            rocm_smi_output = subprocess.check_output(
                ["rocm-smi", "--showproduct", "--showdriverversion"]
            ).decode("utf-8")
        except Exception:
            rocm_smi_output = ""  # rocm-smi absent or failing
        if "GPU" in rocm_smi_output:
            gpu_info.append({"type": "AMD", "details": rocm_smi_output.strip()})

    # Apple Silicon (MPS) is only checked when torch is importable.
    if system == "Darwin" and platform.machine() == "arm64":
        try:
            import torch

            mps_available = torch.backends.mps.is_available()
        except (ImportError, AttributeError):
            # torch missing, or (presumably) a torch build without torch.backends.mps.
            mps_available = False
        if mps_available:
            gpu_info.append({"type": "MPS", "name": "Apple Silicon", "available": True})

    return gpu_info

def get_installed_packages() -> list:
    """List every distribution in ``pkg_resources.working_set`` as ``key==version``.

    Returns:
        The ``"<dist.key>==<dist.version>"`` strings, sorted as plain strings.
    """
    pinned = [f"{dist.key}=={dist.version}" for dist in pkg_resources.working_set]
    pinned.sort()
    return pinned

def get_env_variables() -> dict:
    """Select environment variables relevant to Python / GPU debugging.

    A variable is kept when its name starts with one of CUDA, PYTHON, PATH,
    ROCM, HIP, MPS or METAL. This is a case-sensitive prefix match, so names
    such as PATHEXT are kept as well.

    Returns:
        A dict mapping each selected variable name to its value.
    """
    wanted_prefixes = ("CUDA", "PYTHON", "PATH", "ROCM", "HIP", "MPS", "METAL")
    selected = {}
    for name, value in os.environ.items():
        if name.startswith(wanted_prefixes):
            selected[name] = value
    return selected

def get_env_info() -> dict:
    """Gather the complete environment report used for debugging.

    Returns:
        A dict with the keys ``os_info``, ``python_info``, ``cuda_info``,
        ``gpu_info``, ``installed_packages`` and ``relevant_env_variables``,
        filled by the corresponding ``get_*`` helpers (called in that order).
    """
    return {
        "os_info": get_os_info(),
        "python_info": get_python_info(),
        "cuda_info": get_cuda_info(),
        "gpu_info": get_gpu_info(),
        "installed_packages": get_installed_packages(),
        "relevant_env_variables": get_env_variables(),
    }
from transformers import BertForSequenceClassification

class ModifiedBertForSequenceClassificationWithHook(BertForSequenceClassification):
    """BertForSequenceClassification that prints progress and failure diagnostics.

    ``forward`` and ``generate`` print a message before and after delegating
    to the parent implementation. If the parent call raises, the traceback
    and a JSON environment report (from ``get_env_info``) are printed, and
    the original exception is re-raised unchanged.
    """

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def _report_failure(stage):
        """Print the active traceback followed by an environment report.

        Must be called from inside an ``except`` block so that
        ``traceback.format_exc()`` sees the exception being handled. A
        failure while building the environment report is printed rather than
        raised, so it cannot replace the caller's original exception.

        Args:
            stage: Phrase naming the failed step, inserted into the message
                (e.g. "the forward pass").
        """
        print(f"An error occurred during {stage}: {traceback.format_exc()}")
        print("Environment info:")
        try:
            print(json.dumps(get_env_info(), indent=2))
        except Exception as diag_error:
            print(f"<could not collect environment info: {diag_error!r}>")

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` with progress logging.

        All positional and keyword arguments are passed through unchanged.

        Returns:
            Whatever the parent ``forward`` returns.

        Raises:
            Any exception from the parent ``forward``, re-raised after its
            traceback and an environment report are printed.
        """
        print("Forward pass initiated.")
        try:
            outputs = super().forward(*args, **kwargs)
        except Exception:
            self._report_failure("the forward pass")
            raise  # bare raise keeps the original exception and traceback intact
        print("Forward pass completed.")
        return outputs

    def generate(self, *args, **kwargs):
        """Run the parent ``generate`` with progress logging, if it exists.

        All positional and keyword arguments are passed through unchanged.

        Returns:
            Whatever the parent ``generate`` returns, or ``None`` (after
            printing a notice) when the parent class has no ``generate``.

        Raises:
            Any exception from the parent ``generate``, re-raised after its
            traceback and an environment report are printed.
        """
        if not hasattr(super(), "generate"):
            print("Generate method is not available in the parent class.")
            return None

        print("Generate method called.")
        try:
            outputs = super().generate(*args, **kwargs)
        except Exception:
            self._report_failure("the generate method")
            raise  # bare raise keeps the original exception and traceback intact
        print("Generate method completed.")
        return outputs