bart-large-mnli-logged / modeling_modified.py
Boriscii's picture
Upload modified model with logging
0bbe86b verified
raw
history blame
4.73 kB
import sys
import platform
import subprocess
import pkg_resources
import json
import os
def get_os_info() -> dict:
    """Describe the host operating system.

    Returns:
        A dict with the keys ``system``, ``release``, ``version`` and
        ``machine``; each value is the result of calling the ``platform``
        function of the same name.
    """
    # Every key is also the name of the ``platform`` function that supplies it.
    field_names = ("system", "release", "version", "machine")
    return {field: getattr(platform, field)() for field in field_names}
def get_python_info() -> dict:
    """Describe the running Python interpreter.

    Returns:
        A dict with ``version`` (the full ``sys.version`` string),
        ``implementation`` (from ``platform.python_implementation()``) and
        ``compiler`` (from ``platform.python_compiler()``).
    """
    return dict(
        version=sys.version,
        implementation=platform.python_implementation(),
        compiler=platform.python_compiler(),
    )
def get_cuda_info() -> dict:
    """Report whether the ``nvcc`` compiler can be run and which release it reports.

    Runs ``nvcc --version`` and takes the text that follows the first
    occurrence of the word ``release`` up to the next comma.

    Returns:
        ``{"available": True, "version": <text>}`` when the command runs and
        its output can be parsed; ``{"available": False}`` on any failure
        (command missing, non-zero exit, undecodable or unexpected output).
    """
    try:
        raw_output = subprocess.check_output(["nvcc", "--version"])
        text = raw_output.decode("utf-8")
        # Index 1 raises IndexError when "release" is absent, which is
        # reported as "not available" by the handler below.
        after_release = text.split("release")[1]
        release = after_release.split(",")[0].strip()
    except Exception:
        return {"available": False}
    return {"available": True, "version": release}
def get_gpu_info() -> list:
    """Collect a best-effort description of the GPUs visible on this host.

    Three independent probes are attempted; a probe that fails contributes
    nothing and never raises, because this function is called while another
    error is already being reported.

    Returns:
        A list of dicts, possibly empty. Entries have one of these shapes:

        * ``{"type": "NVIDIA", "name", "memory", "driver_version"}`` - one per
          row printed by ``nvidia-smi``.
        * ``{"type": "AMD", "details"}`` - the raw ``rocm-smi`` output
          (Linux only).
        * ``{"type": "MPS", "name": "Apple Silicon", "available": True}`` -
          when torch reports the MPS backend as available (arm64 macOS only).
    """
    gpu_info = []

    # Check for NVIDIA GPUs
    try:
        nvidia_smi_output = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=name,memory.total,driver_version",
                "--format=csv,noheader",
            ]
        ).decode("utf-8")
    except Exception:
        # Tool missing, driver not loaded, or undecodable output: no NVIDIA data.
        nvidia_smi_output = ""
    for row in nvidia_smi_output.strip().split("\n"):
        # Split from the right so that a device name which itself contains
        # ", " stays in one piece; the last two columns never do.
        fields = row.strip().rsplit(", ", 2)
        if len(fields) != 3:
            # Blank or malformed row: skip it instead of dropping the rest.
            continue
        name, memory, driver_version = fields
        gpu_info.append(
            {
                "type": "NVIDIA",
                "name": name,
                "memory": memory,
                "driver_version": driver_version,
            }
        )

    # Check for AMD GPUs on Linux
    if platform.system() == "Linux":
        try:
            rocm_smi_output = subprocess.check_output(
                ["rocm-smi", "--showproduct", "--showdriverversion"]
            ).decode("utf-8")
            if "GPU" in rocm_smi_output:
                gpu_info.append({"type": "AMD", "details": rocm_smi_output.strip()})
        except Exception:
            pass

    # Check for Apple Silicon (MPS)
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        try:
            import torch

            if torch.backends.mps.is_available():
                gpu_info.append(
                    {"type": "MPS", "name": "Apple Silicon", "available": True}
                )
        except Exception:
            # Not only ImportError: a torch build without ``backends.mps``
            # raises AttributeError, and a broken install can fail in other
            # ways at import time. Like the other probes, stay silent so the
            # error being diagnosed by the caller is not replaced by this one.
            pass

    return gpu_info
def get_installed_packages() -> list:
    """List the installed distributions as sorted ``name==version`` strings.

    Uses the standard-library ``importlib.metadata`` instead of
    ``pkg_resources``, which setuptools has deprecated. Names are normalised
    the way ``pkg_resources`` keys were (runs of characters other than
    letters, digits and ``.`` become ``-``, then lower-cased) so the output
    format is unchanged.

    Returns:
        A sorted list of ``"<name>==<version>"`` strings, one per distinct
        distribution name found on the import path.
    """
    # Imported locally so this function does not depend on any edit to the
    # module-level import block.
    import re
    from importlib import metadata

    versions = {}
    for dist in metadata.distributions():
        try:
            raw_name = dist.metadata["Name"]
            version = dist.version
        except Exception:
            # A damaged *.dist-info / *.egg-info directory must not break a
            # purely diagnostic listing.
            continue
        if not raw_name:
            continue
        key = re.sub(r"[^A-Za-z0-9.]+", "-", raw_name).lower()
        # The first match on the import path wins, i.e. the copy that an
        # ``import`` would actually pick up.
        versions.setdefault(key, version)
    return sorted(f"{key}=={version}" for key, version in versions.items())
def get_env_variables() -> dict:
    """Select the environment variables whose names start with a known prefix.

    Returns:
        A dict mapping name to value for every entry of ``os.environ`` whose
        name begins with one of ``CUDA``, ``PYTHON``, ``PATH``, ``ROCM``,
        ``HIP``, ``MPS`` or ``METAL``.
    """
    relevant_prefixes = ("CUDA", "PYTHON", "PATH", "ROCM", "HIP", "MPS", "METAL")
    selected = {}
    for name, value in os.environ.items():
        # str.startswith accepts a tuple and matches if any prefix matches.
        if name.startswith(relevant_prefixes):
            selected[name] = value
    return selected
def get_env_info() -> dict:
    """Gather every environment report into a single dict.

    Returns:
        A dict with one entry per collector defined in this module, under
        the keys ``os_info``, ``python_info``, ``cuda_info``, ``gpu_info``,
        ``installed_packages`` and ``relevant_env_variables``.
    """
    return {
        # Host operating system
        "os_info": get_os_info(),
        # Interpreter
        "python_info": get_python_info(),
        # Accelerators
        "cuda_info": get_cuda_info(),
        "gpu_info": get_gpu_info(),
        # Installed distributions and selected environment variables
        "installed_packages": get_installed_packages(),
        "relevant_env_variables": get_env_variables(),
    }
from transformers import BartForConditionalGeneration
class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """BART model whose ``forward`` and ``generate`` print progress and diagnostics.

    Both methods delegate to the parent implementation. On success they print
    a start and a completion message; on failure they print the error and a
    JSON dump of ``get_env_info()`` and then re-raise the original exception.
    """

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def _log_failure_with_env(stage: str, error: Exception) -> None:
        """Print the error for *stage* followed by a best-effort environment dump.

        Args:
            stage: Text inserted into the message, e.g. ``"forward pass"``.
            error: The exception being reported.
        """
        print(f"An error occurred during the {stage}: {error}")
        print("Environment info:")
        try:
            print(json.dumps(get_env_info(), indent=2))
        except Exception as env_error:
            # Collecting diagnostics must never replace the exception that is
            # about to be re-raised by the caller.
            print(f"Could not collect environment info: {env_error}")

    def forward(self, *args, **kwargs):
        """Run the parent's ``forward`` with progress and error logging.

        Args:
            *args: Positional arguments passed through unchanged.
            **kwargs: Keyword arguments passed through unchanged.

        Returns:
            Whatever the parent's ``forward`` returns.

        Raises:
            Exception: Any exception raised by the parent call, re-raised
                unchanged after the diagnostics have been printed.
        """
        try:
            # Log input information
            print("Forward pass initiated.")
            # Call the parent's forward method
            outputs = super().forward(*args, **kwargs)
            # Log output information
            print("Forward pass completed.")
            return outputs
        except Exception as e:
            self._log_failure_with_env("forward pass", e)
            # Bare raise keeps the original exception and traceback as they are.
            raise

    def generate(self, *args, **kwargs):
        """Run the parent's ``generate`` with progress and error logging.

        Args:
            *args: Positional arguments passed through unchanged.
            **kwargs: Keyword arguments passed through unchanged.

        Returns:
            Whatever the parent's ``generate`` returns, or ``None`` when the
            parent class exposes no ``generate`` attribute.

        Raises:
            Exception: Any exception raised by the parent call, re-raised
                unchanged after the diagnostics have been printed.
        """
        try:
            if not hasattr(super(), 'generate'):
                print("Generate method is not available in the parent class.")
                return None
            # Log input information
            print("Generate method called.")
            # Call the parent's generate method
            outputs = super().generate(*args, **kwargs)
            # Log output information
            print("Generate method completed.")
            return outputs
        except Exception as e:
            self._log_failure_with_env("generate method", e)
            # Bare raise keeps the original exception and traceback as they are.
            raise