File size: 4,791 Bytes
0bbe86b 1352f56 0bbe86b 1352f56 0bbe86b 8f0212f 0bbe86b 1352f56 0bbe86b 1352f56 0bbe86b 193150f 8f0212f 193150f 0bbe86b df53c44 0bbe86b df53c44 0bbe86b df53c44 0bbe86b 1352f56 0bbe86b df53c44 0bbe86b df53c44 0bbe86b df53c44 0bbe86b df53c44 0bbe86b 8f0212f 0bbe86b 1352f56 0bbe86b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import sys
import platform
import subprocess
import pkg_resources
import json
import traceback
import os
def get_os_info() -> dict:
    """Describe the host operating system.

    Returns:
        A dict with the keys ``system``, ``release``, ``version`` and
        ``machine`` (in that order), each holding the string returned by the
        :mod:`platform` function of the same name.
    """
    probes = ("system", "release", "version", "machine")
    return {probe: getattr(platform, probe)() for probe in probes}
def get_python_info() -> dict:
    """Describe the running Python interpreter.

    Returns:
        A dict with ``version`` (``sys.version``), ``implementation``
        (e.g. ``"CPython"``) and ``compiler`` (the compiler string the
        interpreter was built with).
    """
    impl_name = platform.python_implementation()
    built_with = platform.python_compiler()
    return dict(version=sys.version, implementation=impl_name, compiler=built_with)
def get_cuda_info(*, timeout: float = 10.0) -> dict:
    """Report whether the CUDA compiler (``nvcc``) can be run, and its release.

    Runs ``nvcc --version`` and takes the text between the first ``release``
    token and the following comma, e.g. ``"12.1"`` from
    ``"Cuda compilation tools, release 12.1, V12.1.105"``.

    This is a best-effort probe: it never raises for a missing or broken
    ``nvcc``, because it is used while reporting other errors.

    Args:
        timeout: Seconds to wait for ``nvcc`` before giving up. Without a
            limit, a hung tool would stall the error-reporting path forever.

    Returns:
        ``{"available": True, "version": <str>}`` when ``nvcc`` ran and its
        output contains ``release``; otherwise ``{"available": False}``.
    """
    try:
        raw = subprocess.check_output(
            ["nvcc", "--version"],
            stderr=subprocess.DEVNULL,  # keep tool noise out of the console
            timeout=timeout,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: nvcc not found / not executable.
        # SubprocessError: non-zero exit (CalledProcessError) or TimeoutExpired.
        return {"available": False}

    nvcc_output = raw.decode("utf-8", errors="replace")
    _, found, tail = nvcc_output.partition("release")
    if not found:
        # Unexpected output format: treat as unavailable instead of raising.
        return {"available": False}
    cuda_version = tail.split(",")[0].strip()
    return {"available": True, "version": cuda_version}
def _query_nvidia_gpus(timeout: float) -> list:
    """Return one dict per GPU listed by ``nvidia-smi`` (empty if unavailable)."""
    try:
        raw = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=name,memory.total,driver_version",
                "--format=csv,noheader",
            ],
            stderr=subprocess.DEVNULL,
            timeout=timeout,
        )
    except (OSError, subprocess.SubprocessError):
        return []

    gpus = []
    # One CSV row per GPU. splitlines() splits on real newlines; the previous
    # split("\\n") looked for a literal backslash-n and never split.
    for line in raw.decode("utf-8", errors="replace").splitlines():
        fields = line.strip().split(", ")
        if len(fields) < 3:
            continue  # skip blank/malformed rows instead of losing every GPU
        gpus.append(
            {
                "type": "NVIDIA",
                "name": fields[0],
                "memory": fields[1],
                "driver_version": fields[2],
            }
        )
    return gpus


def _query_amd_gpus(timeout: float) -> list:
    """Return the raw ``rocm-smi`` report as a single entry (Linux only; empty if unavailable)."""
    if platform.system() != "Linux":
        return []
    try:
        raw = subprocess.check_output(
            ["rocm-smi", "--showproduct", "--showdriverversion"],
            stderr=subprocess.DEVNULL,
            timeout=timeout,
        )
    except (OSError, subprocess.SubprocessError):
        return []
    report = raw.decode("utf-8", errors="replace")
    if "GPU" not in report:
        return []
    return [{"type": "AMD", "details": report.strip()}]


def _query_apple_mps() -> list:
    """Return an MPS entry on Apple-Silicon macOS when torch reports MPS as available."""
    if platform.system() != "Darwin" or platform.machine() != "arm64":
        return []
    try:
        import torch

        mps_ok = torch.backends.mps.is_available()
    except (ImportError, AttributeError):
        # torch not installed, or a build too old to expose torch.backends.mps.
        return []
    if not mps_ok:
        return []
    return [{"type": "MPS", "name": "Apple Silicon", "available": True}]


def get_gpu_info(*, timeout: float = 10.0) -> list:
    """Collect GPU information from every supported vendor probe.

    Probes run in order: NVIDIA (``nvidia-smi``), AMD (``rocm-smi``, Linux
    only), Apple Silicon MPS (via torch, macOS arm64 only). A probe whose tool
    or library is missing contributes nothing rather than raising.

    Args:
        timeout: Seconds to wait for each external tool before skipping it.

    Returns:
        A list of dicts, each with a ``type`` key of ``"NVIDIA"``, ``"AMD"``
        or ``"MPS"`` plus vendor-specific fields.
    """
    return [
        *_query_nvidia_gpus(timeout),
        *_query_amd_gpus(timeout),
        *_query_apple_mps(),
    ]
def get_installed_packages() -> list:
    """Return sorted ``name==version`` strings for every installed distribution.

    Uses the standard-library :mod:`importlib.metadata` instead of the
    deprecated ``pkg_resources`` API. Names are normalised the way
    ``pkg_resources`` keys were: runs of characters other than letters, digits
    and ``.`` become ``-``, then the result is lower-cased. The output format
    is therefore unchanged. If the same distribution is found more than once
    on the search path, only the first one found is reported.
    """
    import re
    from importlib import metadata

    versions_by_key = {}
    for dist in metadata.distributions():
        # An empty/missing metadata object is falsy; fall back to {} so that
        # broken installs are skipped instead of raising.
        name = (dist.metadata or {}).get("Name")
        if not name:
            continue
        key = re.sub(r"[^A-Za-z0-9.]+", "-", name).lower()
        versions_by_key.setdefault(key, dist.version)
    return sorted(f"{key}=={version}" for key, version in versions_by_key.items())
def get_env_variables() -> dict:
    """Select environment variables relevant to Python and GPU runtimes.

    A variable is kept when its name begins (case-sensitively) with one of
    ``CUDA``, ``PYTHON``, ``PATH``, ``ROCM``, ``HIP``, ``MPS`` or ``METAL``.
    For example, ``PYTHONPATH`` is kept via the ``PYTHON`` prefix.

    Returns:
        A new dict mapping the selected names to their current values.
    """
    wanted = ("CUDA", "PYTHON", "PATH", "ROCM", "HIP", "MPS", "METAL")
    selected = {}
    for name, value in os.environ.items():
        if name.startswith(wanted):  # str.startswith accepts a tuple of prefixes
            selected[name] = value
    return selected
def get_env_info() -> dict:
    """Collect every environment probe into one dict.

    Sections, in order: ``os_info``, ``python_info``, ``cuda_info``,
    ``gpu_info``, ``installed_packages``, ``relevant_env_variables``.

    Each section is gathered independently. If one collector raises, its
    entry becomes ``{"error": "<ExceptionType>: <message>"}`` and the
    remaining sections are still reported. This matters because the report
    is built inside ``except`` blocks, where a new exception would otherwise
    replace the one being diagnosed.
    """
    collectors = {
        "os_info": get_os_info,
        "python_info": get_python_info,
        "cuda_info": get_cuda_info,
        "gpu_info": get_gpu_info,
        "installed_packages": get_installed_packages,
        "relevant_env_variables": get_env_variables,
    }
    env_info = {}
    for section, collect in collectors.items():
        try:
            env_info[section] = collect()
        except Exception as exc:  # best effort: one failing probe must not sink the report
            env_info[section] = {"error": f"{type(exc).__name__}: {exc}"}
    return env_info
from transformers import BertForSequenceClassification
class ModifiedBertForSequenceClassificationWithHook(BertForSequenceClassification):
    """``BertForSequenceClassification`` with printed progress and failure reports.

    ``forward`` and ``generate`` print a message before and after delegating
    to the parent class. On failure they print the traceback and an
    environment report, then re-raise the original exception unchanged, so
    callers see the same exception type as with the parent class.
    """

    @staticmethod
    def _report_failure(what: str) -> None:
        """Print the active traceback and an environment report for *what*.

        Must be called from inside an ``except`` block. Errors raised while
        building the environment report are printed, not propagated, so they
        cannot mask the exception being reported.
        """
        print(f"An error occurred during {what}: {traceback.format_exc()}")
        print("Environment info:")
        try:
            print(json.dumps(get_env_info(), indent=2))
        except Exception as report_error:
            print(f"<could not collect environment info: {report_error!r}>")

    def forward(self, *args, **kwargs):
        """Delegate to the parent ``forward``, printing start/finish messages.

        Arguments are passed through untouched; the parent's outputs are
        returned as-is.
        """
        print("Forward pass initiated.")
        try:
            outputs = super().forward(*args, **kwargs)
        except Exception:
            self._report_failure("the forward pass")
            raise  # bare raise keeps the original exception and traceback
        print("Forward pass completed.")
        return outputs

    def generate(self, *args, **kwargs):
        """Delegate to the parent ``generate``, printing start/finish messages.

        If no parent class provides a ``generate`` attribute, a notice is
        printed and ``None`` is returned instead of raising.
        """
        if not hasattr(super(), "generate"):
            print("Generate method is not available in the parent class.")
            return None
        print("Generate method called.")
        try:
            outputs = super().generate(*args, **kwargs)
        except Exception:
            self._report_failure("the generate method")
            raise  # bare raise keeps the original exception and traceback
        print("Generate method completed.")
        return outputs
|