Boriscii commited on
Commit
0bbe86b
·
verified ·
1 Parent(s): df53c44

Upload modified model with logging

Browse files
Files changed (2) hide show
  1. configuring_modified.py +1 -1
  2. modeling_modified.py +115 -24
configuring_modified.py CHANGED
@@ -1 +1 @@
1
- from transformers import BartConfig
 
1
+ from transformers import BartConfig
modeling_modified.py CHANGED
@@ -1,41 +1,132 @@
1
- import logging
2
 
3
- import logging
4
- from transformers import BartForConditionalGeneration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- logger = logging.getLogger("ModelLogger")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
 
9
  def __init__(self, config):
10
  super().__init__(config)
11
 
12
  def forward(self, *args, **kwargs):
13
- # Log input information
14
- input_ids = kwargs.get('input_ids', args[0] if args else None)
15
- logger.info(f"Forward pass initiated with input shape: {input_ids.shape if input_ids is not None else 'None'}")
 
16
 
17
- # Call the parent's forward method
18
- outputs = super().forward(*args, **kwargs)
19
 
20
- # Log output information
21
- logger.info(f"Forward pass completed. Output shape: {outputs.logits.shape if hasattr(outputs, 'logits') else 'N/A'}")
22
 
23
- return outputs
 
 
 
 
 
24
 
25
  def generate(self, *args, **kwargs):
26
- if not hasattr(super(), 'generate'):
27
- logger.warning("Generate method is not available in the parent class.")
28
- return None
29
-
30
- # Log input information
31
- input_ids = kwargs.get('input_ids', args[0] if args else None)
32
- logger.info(f"Generate method called with input shape: {input_ids.shape if input_ids is not None else 'None'}")
33
 
34
- # Call the parent's generate method
35
- outputs = super().generate(*args, **kwargs)
36
 
37
- # Log output information
38
- logger.info(f"Generate method completed. Output shape: {outputs.shape}")
39
 
40
- return outputs
 
41
 
 
 
 
 
 
 
 
 
1
 
2
+ import sys
3
+ import platform
4
+ import subprocess
5
+ import pkg_resources
6
+ import json
7
+ import os
8
+
9
+ def get_os_info() -> dict:
10
+ """Get operating system information."""
11
+ return {
12
+ "system": platform.system(),
13
+ "release": platform.release(),
14
+ "version": platform.version(),
15
+ "machine": platform.machine()
16
+ }
17
+
18
+ def get_python_info() -> dict:
19
+ """Get Python environment information."""
20
+ return {
21
+ "version": sys.version,
22
+ "implementation": platform.python_implementation(),
23
+ "compiler": platform.python_compiler()
24
+ }
25
+
26
+ def get_cuda_info() -> dict:
27
+ """Get CUDA availability and version."""
28
+ try:
29
+ nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
30
+ cuda_version = nvcc_output.split("release")[1].split(",")[0].strip()
31
+ return {"available": True, "version": cuda_version}
32
+ except Exception:
33
+ return {"available": False}
34
+
35
+ def get_gpu_info() -> list:
36
+ """Get GPU information."""
37
+ gpu_info = []
38
+ # Check for NVIDIA GPUs
39
+ try:
40
+ nvidia_smi_output = subprocess.check_output(["nvidia-smi", "--query-gpu=name,memory.total,driver_version", "--format=csv,noheader"]).decode("utf-8")
41
+ gpus = [line.strip().split(", ") for line in nvidia_smi_output.strip().split("\n")]
42
+ for gpu in gpus:
43
+ gpu_info.append({"type": "NVIDIA", "name": gpu[0], "memory": gpu[1], "driver_version": gpu[2]})
44
+ except Exception:
45
+ pass
46
+ # Check for AMD GPUs on Linux
47
+ if platform.system() == "Linux":
48
+ try:
49
+ rocm_smi_output = subprocess.check_output(["rocm-smi", "--showproduct", "--showdriverversion"]).decode("utf-8")
50
+ if "GPU" in rocm_smi_output:
51
+ gpu_info.append({"type": "AMD", "details": rocm_smi_output.strip()})
52
+ except Exception:
53
+ pass
54
+ # Check for Apple Silicon (MPS)
55
+ if platform.system() == "Darwin" and platform.machine() == "arm64":
56
+ try:
57
+ import torch
58
+ if torch.backends.mps.is_available():
59
+ gpu_info.append({"type": "MPS", "name": "Apple Silicon", "available": True})
60
+ except ImportError:
61
+ pass
62
+ return gpu_info
63
 
64
+ def get_installed_packages() -> list:
65
+ """Get a list of installed packages."""
66
+ return sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set])
67
+
68
+ def get_env_variables() -> dict:
69
+ """Get relevant environment variables."""
70
+ relevant_prefixes = ("CUDA", "PYTHON", "PATH", "ROCM", "HIP", "MPS", "METAL")
71
+ return {key: value for key, value in os.environ.items() if any(key.startswith(prefix) for prefix in relevant_prefixes)}
72
+
73
+ def get_env_info() -> dict:
74
+ """Collect all environment information."""
75
+ env_info = {
76
+ "os_info": get_os_info(),
77
+ "python_info": get_python_info(),
78
+ "cuda_info": get_cuda_info(),
79
+ "gpu_info": get_gpu_info(),
80
+ "installed_packages": get_installed_packages(),
81
+ "relevant_env_variables": get_env_variables()
82
+ }
83
+ return env_info
84
+
85
+ from transformers import BartForConditionalGeneration
86
 
87
  class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
88
+ """Modified model class with logging and error handling."""
89
  def __init__(self, config):
90
  super().__init__(config)
91
 
92
  def forward(self, *args, **kwargs):
93
+ """Forward method with logging."""
94
+ try:
95
+ # Log input information
96
+ print("Forward pass initiated.")
97
 
98
+ # Call the parent's forward method
99
+ outputs = super().forward(*args, **kwargs)
100
 
101
+ # Log output information
102
+ print("Forward pass completed.")
103
 
104
+ return outputs
105
+ except Exception as e:
106
+ print(f"An error occurred during the forward pass: {e}")
107
+ print("Environment info:")
108
+ print(json.dumps(get_env_info(), indent=2))
109
+ raise e # Re-raise the exception
110
 
111
  def generate(self, *args, **kwargs):
112
+ """Generate method with logging."""
113
+ try:
114
+ if not hasattr(super(), 'generate'):
115
+ print("Generate method is not available in the parent class.")
116
+ return None
 
 
117
 
118
+ # Log input information
119
+ print("Generate method called.")
120
 
121
+ # Call the parent's generate method
122
+ outputs = super().generate(*args, **kwargs)
123
 
124
+ # Log output information
125
+ print("Generate method completed.")
126
 
127
+ return outputs
128
+ except Exception as e:
129
+ print(f"An error occurred during the generate method: {e}")
130
+ print("Environment info:")
131
+ print(json.dumps(get_env_info(), indent=2))
132
+ raise e # Re-raise the exception