刘虹雨 commited on
Commit
e37e14f
·
1 Parent(s): 5c92efe

update code

Browse files
Files changed (1) hide show
  1. app.py +25 -31
app.py CHANGED
@@ -60,39 +60,33 @@ import os
60
  import subprocess
61
  import sys
62
 
63
- def auto_set_cuda_home():
64
- """
65
- Automatically detect and set CUDA_HOME environment variable.
66
- """
67
- if "CUDA_HOME" not in os.environ:
68
- potential_paths = [
69
- "/usr/local/cuda",
70
- ]
71
- # Also scan for /usr/local/cuda-*
72
- try:
73
- ls_out = subprocess.check_output("ls /usr/local | grep cuda", shell=True).decode().splitlines()
74
- for line in ls_out:
75
- full_path = os.path.join("/usr/local", line.strip())
76
- if os.path.isdir(full_path):
77
- potential_paths.append(full_path)
78
- except Exception:
79
- pass
80
-
81
- for path in potential_paths:
82
- nvcc_path = os.path.join(path, "bin", "nvcc")
83
- if os.path.exists(nvcc_path):
84
- print(f"[INFO] Detected CUDA at: {path}")
85
- os.environ["CUDA_HOME"] = path
86
- os.environ["PATH"] = f'{os.path.join(path, "bin")}:' + os.environ.get("PATH", "")
87
- os.environ["LD_LIBRARY_PATH"] = f'{os.path.join(path, "lib64")}:' + os.environ.get("LD_LIBRARY_PATH", "")
88
- return
89
-
90
- print("[WARNING] CUDA not found. Some plugins may fail to compile.")
91
- else:
92
- print(f"[INFO] CUDA_HOME is already set to: {os.environ['CUDA_HOME']}")
93
 
94
  # 🔧 Set CUDA_HOME before anything else
95
- auto_set_cuda_home()
96
 
97
  # Configure logging settings
98
  logging.basicConfig(
 
60
  import subprocess
61
  import sys
62
 
63
+
64
+ def install_cuda_toolkit():
65
+ CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
66
+ CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
67
+
68
+ print(f"[INFO] Downloading CUDA Toolkit from {CUDA_TOOLKIT_URL} ...")
69
+ subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
70
+ subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
71
+
72
+ print("[INFO] Installing CUDA Toolkit silently ...")
73
+ subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
74
+
75
+ print("[INFO] Setting CUDA environment variables ...")
76
+ os.environ["CUDA_HOME"] = "/usr/local/cuda"
77
+ os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ.get("PATH", ""))
78
+ os.environ["LD_LIBRARY_PATH"] = "%s/lib64:%s" % (
79
+ os.environ["CUDA_HOME"],
80
+ os.environ.get("LD_LIBRARY_PATH", "")
81
+ )
82
+
83
+ # Optional: set architecture list for compilation (Ampere and Ada)
84
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9"
85
+
86
+ print("[INFO] CUDA 12.1 installation complete. CUDA_HOME set to /usr/local/cuda")
 
 
 
 
 
 
87
 
88
  # 🔧 Set CUDA_HOME before anything else
89
+ install_cuda_toolkit()
90
 
91
  # Configure logging settings
92
  logging.basicConfig(