# Draken007's picture
# Upload 7228 files
# 2a0bc63 verified
import platform
from ctypes import CDLL
from typing import List, Optional
from pathlib import Path
from .logger import logger
def find_library(path: Optional[str] = None, gpu: bool = False) -> str:
    """Return the filesystem path of the ctransformers shared library to load.

    Args:
        path: Either the name of a variant sub-directory inside this package's
            ``lib`` directory, or any other string, which is treated as a
            caller-supplied library path and returned unchanged. When empty
            or ``None`` a variant is selected automatically.
        gpu: When true, automatic selection picks ``"metal"`` on Darwin and
            ``"cuda"`` on every other system (a ``lib/local`` directory, if
            present, still takes priority).

    Returns:
        ``path`` itself when it does not name a bundled variant, otherwise the
        path of the platform-specific library file inside the chosen variant
        directory.

    Raises:
        OSError: If the resolved library file does not exist.
    """
    lib_directory = Path(__file__).parent.resolve() / "lib"

    if path:
        # A missing ``lib`` directory simply means there are no bundled
        # variants; an explicit path must still be honoured instead of
        # failing inside ``iterdir()``.
        if lib_directory.is_dir():
            subdirs = {d.name for d in lib_directory.iterdir() if d.is_dir()}
        else:
            subdirs = set()
        if path not in subdirs:
            return path

    system = platform.system()
    metal = gpu and system == "Darwin"
    cuda = gpu and not metal

    if not path:
        if (lib_directory / "local").is_dir():
            path = "local"
        elif cuda:
            path = "cuda"
        elif metal:
            path = "metal"
        elif platform.processor() == "arm":
            # Apple silicon doesn't support AVX/AVX2.
            # On other ARM systems the empty variant makes the lookup fall
            # through to a library placed directly in ``lib``.
            path = "basic" if system == "Darwin" else ""
        else:
            from cpuinfo import get_cpu_info

            try:
                flags = get_cpu_info()["flags"]
            except Exception:
                # Detection is best-effort: fall back to the most portable
                # build, but let KeyboardInterrupt/SystemExit propagate.
                logger.warning(
                    "Unable to detect CPU features. "
                    "Please report at https://github.com/marella/ctransformers/issues"
                )
                flags = []

            if "avx2" in flags:
                path = "avx2"
            elif "avx" in flags and "f16c" in flags:
                path = "avx"
            else:
                path = "basic"

    name = "ctransformers"
    if system == "Linux":
        name = f"lib{name}.so"
    elif system == "Windows":
        name = f"{name}.dll"
    elif system == "Darwin":
        name = f"lib{name}.dylib"
    else:
        # Unknown system: the empty name can never be a file, so the check
        # below raises the "not available" error.
        name = ""

    lib_path = lib_directory / path / name
    if not lib_path.is_file():
        if cuda:
            env = "CT_CUBLAS=1 "
        elif metal:
            env = "CT_METAL=1 "
        else:
            env = ""
        raise OSError(
            "Precompiled binaries are not available for the current platform. "
            "Please reinstall from source using:\n\n"
            "  pip uninstall ctransformers --yes\n"
            f"  {env}pip install ctransformers --no-binary ctransformers\n\n"
        )

    return str(lib_path)
def load_cuda() -> bool:
    """Load the CUDA 12 runtime and cuBLAS libraries found under ``nvidia``.

    Returns:
        ``False`` when the ``nvidia`` package cannot be imported or when
        either expected library file is missing; ``True`` once both
        libraries have been loaded with ``ctypes.CDLL``.
    """
    try:
        import nvidia
    except ImportError:
        return False

    root = Path(nvidia.__path__[0])

    # Only the sub-directory and the file names differ between platforms.
    if platform.system() == "Windows":
        subdir = "bin"
        runtime_name = "cudart64_12.dll"
        cublas_name = "cublas64_12.dll"
    else:
        subdir = "lib"
        runtime_name = "libcudart.so.12"
        cublas_name = "libcublas.so.12"

    candidates = (
        root / "cuda_runtime" / subdir / runtime_name,
        root / "cublas" / subdir / cublas_name,
    )

    # Verify that every file exists before loading anything.
    if not all(candidate.is_file() for candidate in candidates):
        return False

    resolved = [str(candidate.resolve()) for candidate in candidates]
    for library in resolved:
        CDLL(library)
    return True