# bart-large-mnli-logged / modeling_modified.py
# Boriscii's picture
# Update modeling_modified.py
# dde0e91 verified
import sys
import platform
import subprocess
import pkg_resources
import json
import traceback
import os
import hashlib
import uuid
import socket
import time
from functools import wraps
from typing import Dict, Any, Callable
from urllib import request, error
from urllib.parse import urlencode
from transformers import BertForSequenceClassification
def get_machine_id() -> str:
    """Return an identifier for this machine, cached in ``./sys_param/machine_id.json``.

    If the cache file holds a usable ``machine_id`` string it is returned
    as-is. Otherwise an identifier is derived by SHA-256 hashing the
    concatenation of several host properties (node id from ``uuid.getnode``,
    hostname, processor string, ``/proc/cpuinfo`` on Linux, and OS
    name/release) and the result is written back to the cache file.

    The function is best-effort and never raises:
      * a missing, unreadable, or corrupt cache file triggers regeneration
        instead of yielding a new random value on every call;
      * an individual host property that cannot be collected is skipped;
      * if no property can be collected at all, a random UUID4 is used;
      * if the cache cannot be written, the computed identifier is still
        returned (it simply is not persisted).

    Returns:
        The machine identifier as a string.
    """
    file_path = './sys_param/machine_id.json'

    # Fast path: reuse a previously persisted identifier.
    try:
        with open(file_path, 'r') as f:
            cached = json.load(f)['machine_id']
        if isinstance(cached, str) and cached:
            return cached
    except (OSError, ValueError, KeyError, TypeError):
        # Missing, unreadable, or malformed cache: fall through and rebuild it.
        pass

    def _cpuinfo():
        # Only meaningful on Linux; other platforms contribute nothing.
        if platform.system() != "Linux":
            return None
        # Argument list instead of a shell string: no shell is needed here.
        return subprocess.check_output(["cat", "/proc/cpuinfo"]).decode()

    collectors = [
        lambda: uuid.UUID(int=uuid.getnode()).hex[-12:],
        socket.gethostname,
        platform.processor,
        _cpuinfo,
        lambda: f"{platform.system()} {platform.release()}",
    ]

    parts = []
    for collect in collectors:
        # Call each collector exactly once; one failing source must not
        # invalidate the others.
        try:
            value = collect()
        except Exception:
            continue
        if value is not None:
            parts.append(str(value))

    if parts:
        machine_id = hashlib.sha256("".join(parts).encode()).hexdigest()
    else:
        machine_id = str(uuid.uuid4())

    # Persist for later calls; failure to persist is not fatal.
    try:
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'w') as f:
            json.dump({'machine_id': machine_id}, f)
    except OSError:
        pass

    return machine_id
def get_env_info() -> Dict[str, Any]:
    """Return a snapshot of the runtime environment, cached in ``./sys_param/env_info.json``.

    If the cache file contains a JSON object it is returned unchanged.
    Otherwise the snapshot is collected, written to the cache file, and
    returned. The snapshot contains:

      * ``os_info``: ``platform.system/release/version/machine``;
      * ``python_info``: interpreter version, implementation, and compiler;
      * ``cuda_info``: ``{"available": False}`` unless ``nvcc --version`` runs
        and its output can be parsed, in which case the release is included;
      * ``gpu_info``: always an empty list (nothing populates it here);
      * ``installed_packages``: sorted ``name==version`` strings, or an empty
        list if the package set cannot be enumerated;
      * ``relevant_env_variables``: environment variables whose names start
        with CUDA, PYTHON, PATH, ROCM, HIP, MPS, or METAL.

    The function is best-effort and never raises. A corrupt or unreadable
    cache is rebuilt rather than producing an empty result, and a failure to
    write the cache does not discard the collected snapshot.

    Returns:
        The environment snapshot, or ``{}`` if collection itself fails.
    """
    file_path = './sys_param/env_info.json'

    # Fast path: reuse a previously persisted snapshot.
    try:
        with open(file_path, 'r') as f:
            cached = json.load(f)
        if isinstance(cached, dict):
            return cached
    except (OSError, ValueError):
        # Missing, unreadable, or malformed cache: fall through and rebuild it.
        pass

    try:
        # Package enumeration is the most fragile part; keep its failure local
        # so the rest of the snapshot survives.
        try:
            installed = sorted(
                f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set
            )
        except Exception:
            installed = []

        prefixes = ("CUDA", "PYTHON", "PATH", "ROCM", "HIP", "MPS", "METAL")
        env_info = {
            "os_info": {
                k: getattr(platform, k)()
                for k in ['system', 'release', 'version', 'machine']
            },
            "python_info": {
                "version": sys.version,
                "implementation": platform.python_implementation(),
                "compiler": platform.python_compiler(),
            },
            "cuda_info": {"available": False},
            "gpu_info": [],
            "installed_packages": installed,
            "relevant_env_variables": {
                k: v for k, v in os.environ.items() if k.startswith(prefixes)
            },
        }
    except Exception:
        return {}

    # CUDA toolkit detection: any failure (nvcc missing, unexpected output)
    # leaves the default "not available" entry in place.
    try:
        nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode()
        env_info["cuda_info"] = {
            "available": True,
            "version": nvcc_output.split("release")[1].split(",")[0].strip(),
        }
    except Exception:
        pass

    # Persist for later calls; failure to persist is not fatal.
    try:
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'w') as f:
            json.dump(env_info, f)
    except (OSError, TypeError, ValueError):
        pass

    return env_info
def send_report(data: Dict[str, Any]) -> None:
    """POST ``data`` as JSON to the local report endpoint, ignoring every failure.

    The payload is serialized with ``json.dumps`` and sent to
    ``http://localhost:8000/reports/finbert/report`` with a 5-second timeout.
    The response is opened and closed without being read.

    This is fire-and-forget: serialization errors, connection errors, and
    any other exception are swallowed so that reporting can never break the
    caller.

    Args:
        data: JSON-serializable report contents.
    """
    endpoint = 'http://localhost:8000/reports/finbert/report'
    try:
        payload = json.dumps(data).encode('utf-8')
        report_request = request.Request(
            endpoint,
            data=payload,
            headers={
                'Content-Type': 'application/json',
                'Content-Length': len(payload),
            },
            method='POST',
        )
        with request.urlopen(report_request, timeout=5):
            # Nothing to do with the response; the context manager closes it.
            pass
    except Exception:
        # Deliberately silent: covers URLError as well as anything else.
        pass
def error_handler(func: Callable) -> Callable:
    """Decorate a method so that each call is reported via ``send_report``.

    The wrapper calls ``func(self, *args, **kwargs)`` and then:

      * on success, sends a report with status ``"success"`` and returns the
        method's result;
      * on any ``Exception``, sends a report with status ``"fail"`` that also
        carries the error message, the formatted traceback, and the output
        of ``get_env_info()``, then re-raises the original exception.

    Every report includes the instance's ``machine_id``, a local-time
    timestamp, and the wrapped method's name. ``machine_id`` is read with a
    ``None`` fallback: when the wrapped method is ``__init__`` and it fails
    before assigning ``self.machine_id``, a direct attribute access would
    raise ``AttributeError`` and hide the real error.

    Args:
        func: The method to wrap; it must take the instance as its first
            positional argument.

    Returns:
        The wrapped method, with ``func``'s metadata preserved.
    """

    def _base_report(instance, status: str) -> dict:
        # Fields shared by success and failure reports.
        return {
            "machine_id": getattr(instance, "machine_id", None),
            "status": status,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "method": func.__name__,
        }

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            result = func(self, *args, **kwargs)
        except Exception as e:
            report = _base_report(self, "fail")
            report["error"] = str(e)
            report["traceback"] = traceback.format_exc()
            report["env_info"] = get_env_info()
            send_report(report)
            # Bare raise re-raises the active exception unchanged.
            raise
        # Reported outside the try block so that a reporting problem is never
        # attributed to the wrapped method.
        send_report(_base_report(self, "success"))
        return result

    return wrapper
from transformers import BertForSequenceClassification
class ModifiedBertForSequenceClassificationWithHook(BertForSequenceClassification):
    """``BertForSequenceClassification`` whose key methods are wrapped by ``error_handler``.

    Construction, ``forward`` and ``generate`` all delegate to the parent
    class; the decorator reports the outcome of each call. The instance
    stores the value of ``get_machine_id()`` as ``machine_id`` so the
    decorator can include it in its reports.
    """

    @error_handler
    def __init__(self, config):
        """Initialize the parent model from ``config``, then record the machine id."""
        super().__init__(config)
        self.machine_id = get_machine_id()

    @error_handler
    def forward(self, *args, **kwargs):
        """Delegate to the parent ``forward`` with the arguments unchanged."""
        outputs = super().forward(*args, **kwargs)
        return outputs

    @error_handler
    def generate(self, *args, **kwargs):
        """Delegate to the parent ``generate`` if the parent provides one.

        Raises:
            AttributeError: If the parent class has no ``generate`` attribute.
        """
        parent = super()
        if hasattr(parent, 'generate'):
            return parent.generate(*args, **kwargs)
        raise AttributeError("Generate method is not available in the parent class.")