File size: 4,872 Bytes
0bbe86b
 
 
 
 
1352f56
0bbe86b
df363ce
 
 
 
 
 
 
 
 
0bbe86b
df363ce
 
0bbe86b
df363ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bbe86b
df363ce
0bbe86b
df363ce
 
0bbe86b
df363ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bbe86b
df363ce
8f0212f
df363ce
 
 
 
 
 
 
dde0e91
df363ce
 
 
 
 
 
0bbe86b
df363ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bbe86b
1352f56
0bbe86b
1352f56
803da51
193150f
 
df363ce
8f0212f
df363ce
193150f
df363ce
df53c44
df363ce
df53c44
df363ce
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import sys
import platform
import subprocess
import pkg_resources
import json
import traceback
import os
import hashlib
import uuid
import socket
import time
from functools import wraps
from typing import Dict, Any, Callable
from urllib import request, error
from urllib.parse import urlencode
from transformers import BertForSequenceClassification

def get_machine_id(file_path: str = './sys_param/machine_id.json') -> str:
    """Return an identifier for this machine, cached on disk as JSON.

    If *file_path* holds ``{"machine_id": ...}`` that value is returned.
    Otherwise (missing, unreadable or corrupt cache) a SHA-256 hex digest is
    computed over several host identifiers and written back to *file_path*:
    the node id from ``uuid.getnode()``, the hostname, ``platform.processor()``,
    the contents of ``/proc/cpuinfo`` on Linux, and the OS name/release.

    Args:
        file_path: JSON cache location. Defaults to the original hard-coded
            path, which is relative to the current working directory.

    Returns:
        The cached or newly computed id. If no identifier at all could be
        collected, a random UUID4 string is used (and cached) instead. If the
        cache cannot be written, the computed id is still returned.
    """

    def _read_cpuinfo():
        """Return /proc/cpuinfo on Linux, else None (read directly, no shell)."""
        if platform.system() != "Linux":
            return None
        with open("/proc/cpuinfo", "r", encoding="utf-8", errors="replace") as fh:
            return fh.read()

    # 1) Cached value, if the cache exists and is well-formed.
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)['machine_id']
    except (OSError, ValueError, KeyError, TypeError):
        pass  # missing or corrupt cache: regenerate below

    # 2) Collect identifiers. Each is called exactly once, and a failing
    #    source is skipped instead of discarding all the others.
    collectors = (
        lambda: uuid.UUID(int=uuid.getnode()).hex[-12:],
        socket.gethostname,
        platform.processor,
        _read_cpuinfo,
        lambda: f"{platform.system()} {platform.release()}",
    )
    parts = []
    for collect in collectors:
        try:
            value = collect()
        except Exception:
            continue
        if value is not None:
            parts.append(str(value))

    if parts:
        machine_id = hashlib.sha256("".join(parts).encode()).hexdigest()
    else:
        machine_id = str(uuid.uuid4())

    # 3) Persist (best-effort) so later calls return the same id.
    try:
        directory = os.path.dirname(file_path)
        if directory:  # os.makedirs('') would raise for a bare filename
            os.makedirs(directory, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump({'machine_id': machine_id}, f)
    except OSError:
        pass  # caching is optional; the computed id is still valid
    return machine_id

def _installed_packages() -> list:
    """Return sorted ``"name==version"`` strings for installed distributions.

    Uses ``importlib.metadata``, the stdlib replacement for the deprecated
    ``pkg_resources``. Names are normalised the same way ``pkg_resources``'
    ``Distribution.key`` does, so report contents stay comparable with
    earlier snapshots: runs of characters other than ASCII letters, digits
    and ``.`` become ``-``, and the result is lower-cased.
    """
    import importlib.metadata
    import re

    versions: Dict[str, str] = {}
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name")
        version = dist.metadata.get("Version")
        if not name or not version:
            continue  # skip broken/partial installs lacking core metadata
        key = re.sub(r"[^A-Za-z0-9.]+", "-", name).lower()
        # First distribution found for a name wins, following sys.path order.
        versions.setdefault(key, version)
    return sorted(f"{key}=={version}" for key, version in versions.items())


def get_env_info(file_path: str = './sys_param/env_info.json') -> Dict[str, Any]:
    """Return a description of the runtime environment, cached on disk as JSON.

    If *file_path* holds a readable JSON document, it is returned as-is. The
    snapshot therefore reflects the environment at the time it was first
    taken. Otherwise a new snapshot is built with these keys:

    - ``os_info``: ``platform.system/release/version/machine()``.
    - ``python_info``: ``sys.version`` plus interpreter implementation and
      compiler strings.
    - ``cuda_info``: ``{"available": True, "version": ...}`` when
      ``nvcc --version`` runs and parses, else ``{"available": False}``.
    - ``gpu_info``: always an empty list (nothing populates it).
    - ``installed_packages``: sorted ``"name==version"`` strings.
    - ``relevant_env_variables``: environment variables whose names start with
      CUDA, PYTHON, PATH, ROCM, HIP, MPS or METAL.

    The new snapshot is then written to *file_path* (best-effort).

    NOTE(review): PATH/PYTHONPATH values can reveal usernames and directory
    layout; confirm this is acceptable wherever the data ends up.

    Args:
        file_path: JSON cache location. Defaults to the original hard-coded
            path, which is relative to the current working directory.

    Returns:
        The environment dict, or ``{}`` if the snapshot could not be built.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError):
        pass  # missing or corrupt cache: rebuild it below

    try:
        env_info = {
            "os_info": {k: getattr(platform, k)() for k in ('system', 'release', 'version', 'machine')},
            "python_info": {
                "version": sys.version,
                "implementation": platform.python_implementation(),
                "compiler": platform.python_compiler(),
            },
            "cuda_info": {"available": False},
            "gpu_info": [],
            "installed_packages": _installed_packages(),
            "relevant_env_variables": {
                k: v for k, v in os.environ.items()
                if k.startswith(("CUDA", "PYTHON", "PATH", "ROCM", "HIP", "MPS", "METAL"))
            },
        }
    except Exception:
        return {}  # best-effort: environment probing must not break callers

    try:
        nvcc_output = subprocess.check_output(
            ["nvcc", "--version"], stderr=subprocess.DEVNULL, timeout=30
        ).decode()
        # e.g. "... Cuda compilation tools, release 12.1, V12.1.105" -> "12.1"
        cuda_version = nvcc_output.split("release")[1].split(",")[0].strip()
    except (OSError, ValueError, IndexError, subprocess.SubprocessError):
        pass  # nvcc absent, failing, hanging, or in an unexpected format
    else:
        env_info["cuda_info"] = {"available": True, "version": cuda_version}

    try:
        directory = os.path.dirname(file_path)
        if directory:  # os.makedirs('') would raise for a bare filename
            os.makedirs(directory, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(env_info, f)
    except OSError:
        pass  # caching is optional; the snapshot is still returned
    return env_info

def send_report(data: Dict[str, Any]) -> None:
    """POST *data* as JSON to the local report endpoint, best-effort.

    Reporting must never break the caller. Every failure is swallowed and
    only logged at DEBUG level, including:

    - a payload that is not JSON-serialisable;
    - no server listening;
    - an HTTP error status;
    - the 5-second timeout.

    Args:
        data: JSON-serialisable payload.
    """
    import logging

    try:
        payload = json.dumps(data).encode('utf-8')
        # urllib adds Content-Length itself when ``data`` is bytes.
        req = request.Request(
            'http://localhost:8000/reports/finbert/report',
            data=payload,
            headers={'Content-Type': 'application/json'},
            method='POST',
        )
        with request.urlopen(req, timeout=5):
            pass  # response body is not needed
    except Exception:  # deliberate catch-all: telemetry is best-effort
        logging.getLogger(__name__).debug("send_report failed", exc_info=True)

def error_handler(func: Callable) -> Callable:
    """Wrap a method so every call is reported through ``send_report``.

    On success, a ``"success"`` report is sent and the result is returned
    unchanged. The report carries the machine id, local timestamp and method
    name.

    On failure, a ``"fail"`` report is sent. In addition to the fields above,
    it carries the error message, the formatted traceback and
    ``get_env_info()``. The original exception is then re-raised.

    NOTE(review): this makes one synchronous ``send_report`` call per
    invocation, so decorating hot methods such as ``forward`` adds an HTTP
    round-trip to every call.

    Args:
        func: Method to wrap; its first positional argument must be ``self``.

    Returns:
        The wrapping function, with ``func``'s metadata copied via ``wraps``.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # ``machine_id`` is read defensively. When wrapping ``__init__``, the
        # attribute does not exist yet if construction fails, and an
        # AttributeError here would mask the real exception.
        try:
            result = func(self, *args, **kwargs)
        except Exception as e:
            send_report({
                "machine_id": getattr(self, "machine_id", None),
                "status": "fail",
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "method": func.__name__,
                "error": str(e),
                "traceback": traceback.format_exc(),
                "env_info": get_env_info(),
            })
            raise  # bare raise re-raises the active exception as-is
        # Reported outside the ``try``, so a reporting problem can never be
        # mislabelled as a failure of ``func`` itself.
        send_report({
            "machine_id": getattr(self, "machine_id", None),
            "status": "success",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "method": func.__name__,
        })
        return result
    return wrapper

from transformers import BertForSequenceClassification

class ModifiedBertForSequenceClassificationWithHook(BertForSequenceClassification):
    """``BertForSequenceClassification`` whose entry points are reported.

    Construction, ``forward`` and ``generate`` are wrapped with
    ``error_handler``, which sends a success/failure report for each call.

    NOTE(review): that is one synchronous HTTP request per ``forward`` call,
    which may be noticeable in tight inference loops.
    """

    @error_handler
    def __init__(self, config):
        # Build the underlying BERT model first, then tag the instance with
        # the machine id that error_handler puts into every report.
        super().__init__(config)
        self.machine_id = get_machine_id()

    @error_handler
    def forward(self, *args, **kwargs):
        """Delegate unchanged to the parent class's ``forward``."""
        parent = super()
        return parent.forward(*args, **kwargs)

    @error_handler
    def generate(self, *args, **kwargs):
        """Delegate to the parent class's ``generate`` when one exists.

        Raises:
            AttributeError: if no base class provides ``generate``.
        """
        parent_generate = getattr(super(), 'generate', None)
        if parent_generate is None:
            raise AttributeError("Generate method is not available in the parent class.")
        return parent_generate(*args, **kwargs)