|
import sys |
|
import platform |
|
import subprocess |
|
import pkg_resources |
|
import json |
|
import traceback |
|
import os |
|
|
|
def get_os_info() -> dict:
    """Describe the host operating system.

    Returns:
        A dict mapping ``"system"``, ``"release"``, ``"version"`` and
        ``"machine"`` to the string returned by the :mod:`platform` function
        of the same name.
    """
    fields = ("system", "release", "version", "machine")
    # Each key is also the name of the platform function that produces it.
    return {field: getattr(platform, field)() for field in fields}
|
|
|
def get_python_info() -> dict:
    """Describe the running Python interpreter.

    Returns:
        A dict with the full ``sys.version`` string (``"version"``), the
        interpreter implementation name (``"implementation"``) and the
        compiler string the interpreter was built with (``"compiler"``).
    """
    implementation = platform.python_implementation()
    compiler = platform.python_compiler()
    return dict(version=sys.version, implementation=implementation, compiler=compiler)
|
|
|
def get_cuda_info(*, timeout: float = 10.0) -> dict:
    """Report whether the CUDA compiler ``nvcc`` is usable, and its release.

    This is a best-effort probe and never raises.

    Args:
        timeout: Seconds to wait for ``nvcc --version`` before giving up, so a
            hung toolchain cannot block diagnostics (which are collected from
            inside exception handlers).

    Returns:
        ``{"available": True, "version": <release string>}`` when
        ``nvcc --version`` succeeds and its output can be parsed, otherwise
        ``{"available": False}``.
    """
    try:
        nvcc_output = subprocess.check_output(
            ["nvcc", "--version"],
            stderr=subprocess.DEVNULL,  # keep tool noise out of the user's output
            timeout=timeout,
        ).decode("utf-8", errors="replace")
        # Parsing assumes the banner contains "... release <version>, ...".
        cuda_version = nvcc_output.split("release")[1].split(",")[0].strip()
    except Exception:
        # Missing binary, non-zero exit, timeout or unexpected output all mean
        # "no usable CUDA toolkit found" for diagnostic purposes.
        return {"available": False}
    return {"available": True, "version": cuda_version}
|
|
|
def get_gpu_info(*, timeout: float = 10.0) -> list:
    """Detect GPUs / GPU backends visible to this process, best-effort.

    Probes, in order:

    * NVIDIA GPUs via ``nvidia-smi`` (any OS),
    * AMD GPUs via ``rocm-smi`` (Linux only),
    * the Apple MPS backend via ``torch`` (macOS on ``arm64`` only).

    A probe whose tool is missing, fails, times out or prints unexpected
    output simply contributes nothing; this function does not raise for that.

    Args:
        timeout: Seconds to wait for each external command.

    Returns:
        A list of dicts, one per detected device/backend, each carrying a
        ``"type"`` key (``"NVIDIA"``, ``"AMD"`` or ``"MPS"``) plus
        type-specific fields.
    """
    system = platform.system()
    gpu_info = _probe_nvidia_gpus(timeout)
    if system == "Linux":
        gpu_info.extend(_probe_amd_gpus(timeout))
    if system == "Darwin" and platform.machine() == "arm64":
        gpu_info.extend(_probe_mps_backend())
    return gpu_info


def _run_probe_command(command: list, timeout: float) -> str:
    """Run *command* and return its stdout as text; raises on any failure."""
    return subprocess.check_output(
        command, stderr=subprocess.DEVNULL, timeout=timeout
    ).decode("utf-8", errors="replace")


def _probe_nvidia_gpus(timeout: float) -> list:
    """Return one entry per GPU row printed by ``nvidia-smi``; [] if unavailable."""
    try:
        output = _run_probe_command(
            [
                "nvidia-smi",
                "--query-gpu=name,memory.total,driver_version",
                "--format=csv,noheader",
            ],
            timeout,
        )
    except Exception:
        return []
    gpus = []
    # splitlines() splits on real line breaks (incl. "\r\n"), one row per GPU.
    for line in output.splitlines():
        fields = [field.strip() for field in line.split(",")]
        if len(fields) < 3:
            # Skip blank/malformed rows instead of losing every GPU to IndexError.
            continue
        gpus.append(
            {
                "type": "NVIDIA",
                "name": fields[0],
                "memory": fields[1],
                "driver_version": fields[2],
            }
        )
    return gpus


def _probe_amd_gpus(timeout: float) -> list:
    """Return a single entry with raw ``rocm-smi`` output if it mentions a GPU."""
    try:
        output = _run_probe_command(
            ["rocm-smi", "--showproduct", "--showdriverversion"], timeout
        )
    except Exception:
        return []
    if "GPU" not in output:
        return []
    return [{"type": "AMD", "details": output.strip()}]


def _probe_mps_backend() -> list:
    """Return an MPS entry when torch reports the MPS backend as available."""
    try:
        import torch

        available = torch.backends.mps.is_available()
    except Exception:
        # torch missing, too old to have torch.backends.mps, or failing to load.
        return []
    if not available:
        return []
    return [{"type": "MPS", "name": "Apple Silicon", "available": True}]
|
|
|
def get_installed_packages() -> list:
    """Return sorted ``"name==version"`` strings for installed distributions.

    Uses the standard-library :mod:`importlib.metadata` instead of the
    deprecated ``pkg_resources``. Names are normalized the way
    ``pkg_resources`` keyed them, so the output format is unchanged: runs of
    characters other than ASCII letters, digits and ``.`` collapse to ``-``,
    then the name is lower-cased. When a distribution is found on several
    ``sys.path`` entries, only the first one is reported. Distributions with
    missing name/version metadata are skipped.
    """
    import re
    from importlib import metadata

    versions = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        version = dist.version
        if not name or version is None:
            continue  # broken / partially-removed install
        key = re.sub(r"[^A-Za-z0-9.]+", "-", name).lower()
        versions.setdefault(key, version)  # first entry on sys.path wins
    return sorted(f"{key}=={version}" for key, version in versions.items())
|
|
|
def get_env_variables() -> dict:
    """Return the environment variables relevant to Python and GPU tooling.

    A variable is kept when its name starts with one of ``CUDA``, ``PYTHON``,
    ``PATH``, ``ROCM``, ``HIP``, ``MPS`` or ``METAL``. This is a plain prefix
    test, so names that merely contain one of these strings (for example
    ``LD_LIBRARY_PATH``) are not included.
    """
    prefixes = ("CUDA", "PYTHON", "PATH", "ROCM", "HIP", "MPS", "METAL")
    selected = {}
    for name, value in os.environ.items():
        # str.startswith accepts a tuple and matches any of its prefixes.
        if name.startswith(prefixes):
            selected[name] = value
    return selected
|
|
|
def get_env_info() -> dict:
    """Collect every environment-diagnostics section into one dict.

    Sections are gathered in this order: ``os_info``, ``python_info``,
    ``cuda_info``, ``gpu_info``, ``installed_packages``,
    ``relevant_env_variables``. This is usually called while another error is
    being reported, so a failing collector must not raise and mask that
    error. Instead its section is replaced by
    ``{"error": "<ExceptionType>: <message>"}`` and collection continues.

    Returns:
        A dict keyed by section name; on success the values are exactly what
        the corresponding ``get_*`` helper returned.
    """
    collectors = (
        ("os_info", get_os_info),
        ("python_info", get_python_info),
        ("cuda_info", get_cuda_info),
        ("gpu_info", get_gpu_info),
        ("installed_packages", get_installed_packages),
        ("relevant_env_variables", get_env_variables),
    )
    env_info = {}
    for section, collect in collectors:
        try:
            env_info[section] = collect()
        except Exception as exc:
            env_info[section] = {"error": f"{type(exc).__name__}: {exc}"}
    return env_info
|
from transformers import BertForSequenceClassification |
|
|
|
class ModifiedBertForSequenceClassificationWithHook(BertForSequenceClassification):
    """``BertForSequenceClassification`` whose ``forward``/``generate`` calls are logged.

    Each call prints a start and a completion message to stdout. If the
    wrapped call raises, the traceback and a JSON dump of ``get_env_info()``
    are printed, and the original exception is re-raised with its original
    traceback. The constructor is inherited unchanged from the parent class.
    """

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` with start/completion/error logging.

        All arguments are passed through unchanged. Returns whatever the parent
        returns; any exception it raises is reported and re-raised.
        """
        try:
            print("Forward pass initiated.")
            outputs = super().forward(*args, **kwargs)
            print("Forward pass completed.")
            return outputs
        except Exception:
            self._report_failure("forward pass")
            raise  # bare raise preserves the original traceback

    def generate(self, *args, **kwargs):
        """Run the parent ``generate`` with start/completion/error logging.

        Returns ``None`` (after printing a notice) when no ``generate`` is
        reachable through ``super()``. Otherwise returns the parent's result;
        any exception it raises is reported and re-raised.
        """
        try:
            if not hasattr(super(), "generate"):
                print("Generate method is not available in the parent class.")
                return None
            print("Generate method called.")
            outputs = super().generate(*args, **kwargs)
            print("Generate method completed.")
            return outputs
        except Exception:
            self._report_failure("generate method")
            raise  # bare raise preserves the original traceback

    @staticmethod
    def _report_failure(stage: str) -> None:
        """Print the exception currently being handled plus environment info.

        Must be called from inside an ``except`` block. Failures while
        gathering or serializing the environment info are caught and printed
        rather than raised, so they cannot replace the caller's exception.
        """
        print(f"An error occurred during the {stage}: {traceback.format_exc()}")
        print("Environment info:")
        try:
            env_dump = json.dumps(get_env_info(), indent=2)
        except Exception as env_error:
            env_dump = f"<unavailable: {type(env_error).__name__}: {env_error}>"
        print(env_dump)
|
|