"""
Test the memory module.
"""
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2009 Gael Varoquaux
# License: BSD Style, 3 clauses.
import datetime
import functools
import gc
import logging
import os
import os.path
import pathlib
import pickle
import shutil
import sys
import textwrap
import time
import pytest
from joblib._store_backends import FileSystemStoreBackend, StoreBackendBase
from joblib.hashing import hash
from joblib.memory import (
_FUNCTION_HASHES,
_STORE_BACKENDS,
JobLibCollisionWarning,
MemorizedFunc,
MemorizedResult,
Memory,
NotMemorizedFunc,
NotMemorizedResult,
_build_func_identifier,
_store_backend_factory,
expires_after,
register_store_backend,
)
from joblib.parallel import Parallel, delayed
from joblib.test.common import np, with_multiprocessing, with_numpy
from joblib.testing import parametrize, raises, warns
###############################################################################
# Module-level variables for the tests
def f(x, y=1):
"""A module-level function for testing purposes."""
return x**2 + y
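# A minimal usage sketch (illustrative only, not part of the original test
# suite): caching the module-level `f` above means f(2) is computed once and
# replayed from disk on the second call. The temporary directory stands in
# for a real cache location.
def _example_basic_memoization():
    import tempfile

    memory = Memory(location=tempfile.mkdtemp(), verbose=0)
    cached_f = memory.cache(f)
    assert cached_f(2) == 5  # computed and persisted
    assert cached_f(2) == 5  # replayed from the on-disk cache
    memory.clear(warn=False)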
###############################################################################
# Helper function for the tests
def check_identity_lazy(func, accumulator, location):
"""Given a function and an accumulator (a list that grows every
time the function is called), check that the function can be
decorated by memory to be a lazy identity.
"""
# Call each function with several arguments, and check that it is
# evaluated only once per argument.
memory = Memory(location=location, verbose=0)
func = memory.cache(func)
for i in range(3):
for _ in range(2):
assert func(i) == i
assert len(accumulator) == i + 1
def corrupt_single_cache_item(memory):
(single_cache_item,) = memory.store_backend.get_items()
output_filename = os.path.join(single_cache_item.path, "output.pkl")
with open(output_filename, "w") as f:
f.write("garbage")
def monkeypatch_cached_func_warn(func, monkeypatch_fixture):
# Need monkeypatch because pytest does not
# capture stdlib logging output (see
# https://github.com/pytest-dev/pytest/issues/2079)
recorded = []
def append_to_record(item):
recorded.append(item)
monkeypatch_fixture.setattr(func, "warn", append_to_record)
return recorded
###############################################################################
# Tests
def test_memory_integration(tmpdir):
"""Simple test of memory lazy evaluation."""
accumulator = list()
    # Note: this function has the same name as a module-level function,
    # thus it serves as a check that the two are identified
    # as different.
def f(arg):
accumulator.append(1)
return arg
check_identity_lazy(f, accumulator, tmpdir.strpath)
# Now test clearing
for compress in (False, True):
for mmap_mode in ("r", None):
memory = Memory(
location=tmpdir.strpath,
verbose=10,
mmap_mode=mmap_mode,
compress=compress,
)
# First clear the cache directory, to check that our code can
# handle that
# NOTE: this line would raise an exception, as the database file is
# still open; we ignore the error since we want to test what
# happens if the directory disappears
shutil.rmtree(tmpdir.strpath, ignore_errors=True)
g = memory.cache(f)
g(1)
g.clear(warn=False)
current_accumulator = len(accumulator)
out = g(1)
assert len(accumulator) == current_accumulator + 1
# Also, check that Memory.eval works similarly
assert memory.eval(f, 1) == out
assert len(accumulator) == current_accumulator + 1
# Now do a smoke test with a function defined in __main__, as the name
# mangling rules are more complex
f.__module__ = "__main__"
memory = Memory(location=tmpdir.strpath, verbose=0)
memory.cache(f)(1)
@parametrize("call_before_reducing", [True, False])
def test_parallel_call_cached_function_defined_in_jupyter(tmpdir, call_before_reducing):
    # Calling an interactively defined memory.cache()'d function inside a
    # Parallel call used to clear the existing cache related to said
    # function (https://github.com/joblib/joblib/issues/1035)
    # This test checks that this is no longer the case.
    # TODO: test that the cache related to the function persists across
    # ipython sessions (provided that no code changes were made to the
    # function's source)?
    # The first part of the test makes the necessary low-level calls to
    # emulate the definition of a function in a jupyter notebook cell.
    # Joblib has some custom code to treat functions defined specifically in
    # jupyter notebooks/ipython sessions -- we want to test this code, which
    # requires the emulation to be rigorous.
for session_no in [0, 1]:
ipython_cell_source = """
def f(x):
return x
"""
ipython_cell_id = "<ipython-input-{}-000000000000>".format(session_no)
my_locals = {}
exec(
compile(
textwrap.dedent(ipython_cell_source),
filename=ipython_cell_id,
mode="exec",
),
# TODO when Python 3.11 is the minimum supported version, use
# locals=my_locals instead of passing globals and locals in the
# next two lines as positional arguments
None,
my_locals,
)
f = my_locals["f"]
f.__module__ = "__main__"
# Preliminary sanity checks, and tests checking that joblib properly
# identified f as an interactive function defined in a jupyter notebook
assert f(1) == 1
assert f.__code__.co_filename == ipython_cell_id
memory = Memory(location=tmpdir.strpath, verbose=0)
cached_f = memory.cache(f)
assert len(os.listdir(tmpdir / "joblib")) == 1
f_cache_relative_directory = os.listdir(tmpdir / "joblib")[0]
assert "ipython-input" in f_cache_relative_directory
f_cache_directory = tmpdir / "joblib" / f_cache_relative_directory
if session_no == 0:
# The cache should be empty as cached_f has not been called yet.
assert os.listdir(f_cache_directory) == ["f"]
assert os.listdir(f_cache_directory / "f") == []
if call_before_reducing:
cached_f(3)
                # Two entries were just created: func_code.py, and a folder
                # containing the information (input hash/output) of
                # cached_f(3)
assert len(os.listdir(f_cache_directory / "f")) == 2
                # Now, testing #1035: when calling a cached function, joblib
                # used to dynamically inspect the underlying function to
                # extract its source code (to verify it matches the source
                # code of the function as last inspected by joblib) --
                # however, source code introspection fails for dynamic
                # functions sent to child processes, which would eventually
                # make joblib clear the cache associated with f
Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])
else:
# Submit the function to the joblib child processes, although
# the function has never been called in the parent yet. This
# triggers a specific code branch inside
# MemorizedFunc.__reduce__.
Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])
# Ensure the child process has time to close the file.
# Wait up to 5 seconds for slow CI runs
for _ in range(25):
if len(os.listdir(f_cache_directory / "f")) == 3:
break
time.sleep(0.2) # pragma: no cover
assert len(os.listdir(f_cache_directory / "f")) == 3
cached_f(3)
            # Making sure f's cache does not get cleared after the parallel
            # calls, and contains ALL cached function calls (f(1), f(2), f(3))
            # and 'func_code.py'
assert len(os.listdir(f_cache_directory / "f")) == 4
else:
# For the second session, there should be an already existing cache
assert len(os.listdir(f_cache_directory / "f")) == 4
cached_f(3)
# The previous cache should not be invalidated after calling the
# function in a new session
assert len(os.listdir(f_cache_directory / "f")) == 4
def test_no_memory():
"""Test memory with location=None: no memoize"""
accumulator = list()
def ff(arg):
accumulator.append(1)
return arg
memory = Memory(location=None, verbose=0)
gg = memory.cache(ff)
for _ in range(4):
current_accumulator = len(accumulator)
gg(1)
assert len(accumulator) == current_accumulator + 1
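# Illustrative sketch of what the test above relies on: with location=None,
# Memory.cache returns a NotMemorizedFunc that simply calls through to the
# original function, so nothing is persisted and every call re-executes.
def _example_no_location_is_passthrough():
    memory = Memory(location=None, verbose=0)
    cached = memory.cache(f)
    assert isinstance(cached, NotMemorizedFunc)
    assert cached(3) == 10  # recomputed on every call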
def test_memory_kwarg(tmpdir):
"Test memory with a function with keyword arguments."
accumulator = list()
def g(arg1=None, arg2=1):
accumulator.append(1)
return arg1
check_identity_lazy(g, accumulator, tmpdir.strpath)
memory = Memory(location=tmpdir.strpath, verbose=0)
g = memory.cache(g)
# Smoke test with an explicit keyword argument:
assert g(arg1=30, arg2=2) == 30
def test_memory_lambda(tmpdir):
"Test memory with a function with a lambda."
accumulator = list()
def helper(x):
"""A helper function to define l as a lambda."""
accumulator.append(1)
return x
check_identity_lazy(lambda x: helper(x), accumulator, tmpdir.strpath)
def test_memory_name_collision(tmpdir):
"Check that name collisions with functions will raise warnings"
memory = Memory(location=tmpdir.strpath, verbose=0)
@memory.cache
def name_collision(x):
"""A first function called name_collision"""
return x
a = name_collision
@memory.cache
def name_collision(x):
"""A second function called name_collision"""
return x
b = name_collision
with warns(JobLibCollisionWarning) as warninfo:
a(1)
b(1)
assert len(warninfo) == 1
assert "collision" in str(warninfo[0].message)
def test_memory_warning_lambda_collisions(tmpdir):
    # Check that multiple uses of lambda will raise collision warnings
memory = Memory(location=tmpdir.strpath, verbose=0)
a = memory.cache(lambda x: x)
b = memory.cache(lambda x: x + 1)
with warns(JobLibCollisionWarning) as warninfo:
assert a(0) == 0
assert b(1) == 2
assert a(1) == 1
    # In recent Python versions, we can retrieve the code of lambdas, thus
    # collisions between the two lambdas are detected and warnings are raised
assert len(warninfo) == 4
def test_memory_warning_collision_detection(tmpdir):
# Check that collisions impossible to detect will raise appropriate
# warnings.
memory = Memory(location=tmpdir.strpath, verbose=0)
a1 = eval("lambda x: x")
a1 = memory.cache(a1)
b1 = eval("lambda x: x+1")
b1 = memory.cache(b1)
with warns(JobLibCollisionWarning) as warninfo:
a1(1)
b1(1)
a1(0)
assert len(warninfo) == 2
assert "cannot detect" in str(warninfo[0].message).lower()
def test_memory_partial(tmpdir):
"Test memory with functools.partial."
accumulator = list()
def func(x, y):
"""A helper function to define l as a lambda."""
accumulator.append(1)
return y
import functools
function = functools.partial(func, 1)
check_identity_lazy(function, accumulator, tmpdir.strpath)
def test_memory_eval(tmpdir):
"Smoke test memory with a function with a function defined in an eval."
memory = Memory(location=tmpdir.strpath, verbose=0)
m = eval("lambda x: x")
mm = memory.cache(m)
assert mm(1) == 1
def count_and_append(x=[]):
"""A function with a side effect in its arguments.
Return the length of its argument and append one element.
"""
len_x = len(x)
x.append(None)
return len_x
def test_argument_change(tmpdir):
"""Check that if a function has a side effect in its arguments, it
should use the hash of changing arguments.
"""
memory = Memory(location=tmpdir.strpath, verbose=0)
func = memory.cache(count_and_append)
    # call the function for the first time: it should cache it with
    # argument x=[]
    assert func() == 0
    # the second time the argument is x=[None], which is not cached
    # yet, so the function should be called a second time
assert func() == 1
@with_numpy
@parametrize("mmap_mode", [None, "r"])
def test_memory_numpy(tmpdir, mmap_mode):
"Test memory with a function with numpy arrays."
accumulator = list()
def n(arg=None):
accumulator.append(1)
return arg
memory = Memory(location=tmpdir.strpath, mmap_mode=mmap_mode, verbose=0)
cached_n = memory.cache(n)
rnd = np.random.RandomState(0)
for i in range(3):
a = rnd.random_sample((10, 10))
for _ in range(3):
assert np.all(cached_n(a) == a)
assert len(accumulator) == i + 1
@with_numpy
def test_memory_numpy_check_mmap_mode(tmpdir, monkeypatch):
"""Check that mmap_mode is respected even at the first call"""
memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)
@memory.cache()
def twice(a):
return a * 2
a = np.ones(3)
b = twice(a)
c = twice(a)
assert isinstance(c, np.memmap)
assert c.mode == "r"
assert isinstance(b, np.memmap)
assert b.mode == "r"
    # Corrupt the file. Deleting the b and c mmaps
    # is necessary to be able to edit the file.
del b
del c
gc.collect()
corrupt_single_cache_item(memory)
# Make sure that corrupting the file causes recomputation and that
# a warning is issued.
recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
d = twice(a)
assert len(recorded_warnings) == 1
exception_msg = "Exception while loading results"
assert exception_msg in recorded_warnings[0]
# Asserts that the recomputation returns a mmap
assert isinstance(d, np.memmap)
assert d.mode == "r"
def test_memory_exception(tmpdir):
"""Smoketest the exception handling of Memory."""
memory = Memory(location=tmpdir.strpath, verbose=0)
class MyException(Exception):
pass
@memory.cache
def h(exc=0):
if exc:
raise MyException
# Call once, to initialise the cache
h()
for _ in range(3):
# Call 3 times, to be sure that the Exception is always raised
with raises(MyException):
h(1)
def test_memory_ignore(tmpdir):
"Test the ignore feature of memory"
memory = Memory(location=tmpdir.strpath, verbose=0)
accumulator = list()
@memory.cache(ignore=["y"])
def z(x, y=1):
accumulator.append(1)
assert z.ignore == ["y"]
z(0, y=1)
assert len(accumulator) == 1
z(0, y=1)
assert len(accumulator) == 1
z(0, y=2)
assert len(accumulator) == 1
def test_memory_ignore_decorated(tmpdir):
"Test the ignore feature of memory on a decorated function"
memory = Memory(location=tmpdir.strpath, verbose=0)
accumulator = list()
def decorate(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
return f(*args, **kwargs)
return wrapped
@memory.cache(ignore=["y"])
@decorate
def z(x, y=1):
accumulator.append(1)
assert z.ignore == ["y"]
z(0, y=1)
assert len(accumulator) == 1
z(0, y=1)
assert len(accumulator) == 1
z(0, y=2)
assert len(accumulator) == 1
def test_memory_args_as_kwargs(tmpdir):
"""Non-regression test against 0.12.0 changes.
https://github.com/joblib/joblib/pull/751
"""
memory = Memory(location=tmpdir.strpath, verbose=0)
@memory.cache
def plus_one(a):
return a + 1
# It's possible to call a positional arg as a kwarg.
assert plus_one(1) == 2
assert plus_one(a=1) == 2
# However, a positional argument that joblib hadn't seen
# before would cause a failure if it was passed as a kwarg.
assert plus_one(a=2) == 3
@parametrize("ignore, verbose, mmap_mode", [(["x"], 100, "r"), ([], 10, None)])
def test_partial_decoration(tmpdir, ignore, verbose, mmap_mode):
"Check cache may be called with kwargs before decorating"
memory = Memory(location=tmpdir.strpath, verbose=0)
@memory.cache(ignore=ignore, verbose=verbose, mmap_mode=mmap_mode)
def z(x):
pass
assert z.ignore == ignore
assert z._verbose == verbose
assert z.mmap_mode == mmap_mode
def test_func_dir(tmpdir):
# Test the creation of the memory cache directory for the function.
memory = Memory(location=tmpdir.strpath, verbose=0)
path = __name__.split(".")
path.append("f")
path = tmpdir.join("joblib", *path).strpath
g = memory.cache(f)
# Test that the function directory is created on demand
func_id = _build_func_identifier(f)
location = os.path.join(g.store_backend.location, func_id)
assert location == path
assert os.path.exists(path)
assert memory.location == os.path.dirname(g.store_backend.location)
# Test that the code is stored.
# For the following test to be robust to previous execution, we clear
# the in-memory store
_FUNCTION_HASHES.clear()
assert not g._check_previous_func_code()
assert os.path.exists(os.path.join(path, "func_code.py"))
assert g._check_previous_func_code()
# Test the robustness to failure of loading previous results.
args_id = g._get_args_id(1)
output_dir = os.path.join(g.store_backend.location, g.func_id, args_id)
a = g(1)
assert os.path.exists(output_dir)
os.remove(os.path.join(output_dir, "output.pkl"))
assert a == g(1)
def test_persistence(tmpdir):
    # Test that memorized functions can be pickled and restored.
memory = Memory(location=tmpdir.strpath, verbose=0)
g = memory.cache(f)
output = g(1)
h = pickle.loads(pickle.dumps(g))
args_id = h._get_args_id(1)
output_dir = os.path.join(h.store_backend.location, h.func_id, args_id)
assert os.path.exists(output_dir)
assert output == h.store_backend.load_item([h.func_id, args_id])
memory2 = pickle.loads(pickle.dumps(memory))
assert memory.store_backend.location == memory2.store_backend.location
# Smoke test that pickling a memory with location=None works
memory = Memory(location=None, verbose=0)
pickle.loads(pickle.dumps(memory))
g = memory.cache(f)
gp = pickle.loads(pickle.dumps(g))
gp(1)
@pytest.mark.parametrize("consider_cache_valid", [True, False])
def test_check_call_in_cache(tmpdir, consider_cache_valid):
for func in (
MemorizedFunc(
f, tmpdir.strpath, cache_validation_callback=lambda _: consider_cache_valid
),
Memory(location=tmpdir.strpath, verbose=0).cache(
f, cache_validation_callback=lambda _: consider_cache_valid
),
):
result = func.check_call_in_cache(2)
assert isinstance(result, bool)
assert not result
assert func(2) == 5
result = func.check_call_in_cache(2)
assert isinstance(result, bool)
assert result == consider_cache_valid
func.clear()
func = NotMemorizedFunc(f)
assert not func.check_call_in_cache(2)
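# Illustrative sketch: check_call_in_cache answers "is this call already
# cached?" without triggering the computation itself.
def _example_check_call_in_cache():
    import tempfile

    memory = Memory(location=tempfile.mkdtemp(), verbose=0)
    cached_f = memory.cache(f)
    assert not cached_f.check_call_in_cache(2)  # nothing stored yet
    cached_f(2)
    assert cached_f.check_call_in_cache(2)  # now a cache hit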
def test_call_and_shelve(tmpdir):
# Test MemorizedFunc outputting a reference to cache.
for func, Result in zip(
(
MemorizedFunc(f, tmpdir.strpath),
NotMemorizedFunc(f),
Memory(location=tmpdir.strpath, verbose=0).cache(f),
Memory(location=None).cache(f),
),
(MemorizedResult, NotMemorizedResult, MemorizedResult, NotMemorizedResult),
):
assert func(2) == 5
result = func.call_and_shelve(2)
assert isinstance(result, Result)
assert result.get() == 5
result.clear()
with raises(KeyError):
result.get()
result.clear() # Do nothing if there is no cache.
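# Illustrative sketch: call_and_shelve returns a MemorizedResult handle
# instead of the output itself; the stored value is only unpickled when
# .get() is called, and .clear() drops that single cache entry.
def _example_call_and_shelve():
    import tempfile

    memory = Memory(location=tempfile.mkdtemp(), verbose=0)
    cached_f = memory.cache(f)
    shelved = cached_f.call_and_shelve(2)
    assert isinstance(shelved, MemorizedResult)
    assert shelved.get() == 5  # the output is loaded on demand
    shelved.clear()  # removes only this cached result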
def test_call_and_shelve_lazily_load_stored_result(tmpdir):
"""Check call_and_shelve only load stored data if needed."""
test_access_time_file = tmpdir.join("test_access")
test_access_time_file.write("test_access")
test_access_time = os.stat(test_access_time_file.strpath).st_atime
    # check that the file system access time resolution is finer than the
    # test wait timings.
time.sleep(0.5)
assert test_access_time_file.read() == "test_access"
if test_access_time == os.stat(test_access_time_file.strpath).st_atime:
# Skip this test when access time cannot be retrieved with enough
# precision from the file system (e.g. NTFS on windows).
pytest.skip("filesystem does not support fine-grained access time attribute")
memory = Memory(location=tmpdir.strpath, verbose=0)
func = memory.cache(f)
args_id = func._get_args_id(2)
result_path = os.path.join(
memory.store_backend.location, func.func_id, args_id, "output.pkl"
)
assert func(2) == 5
first_access_time = os.stat(result_path).st_atime
time.sleep(1)
# Should not access the stored data
result = func.call_and_shelve(2)
assert isinstance(result, MemorizedResult)
assert os.stat(result_path).st_atime == first_access_time
time.sleep(1)
# Read the stored data => last access time is greater than first_access
assert result.get() == 5
assert os.stat(result_path).st_atime > first_access_time
def test_memorized_pickling(tmpdir):
for func in (MemorizedFunc(f, tmpdir.strpath), NotMemorizedFunc(f)):
filename = tmpdir.join("pickling_test.dat").strpath
result = func.call_and_shelve(2)
with open(filename, "wb") as fp:
pickle.dump(result, fp)
with open(filename, "rb") as fp:
result2 = pickle.load(fp)
assert result2.get() == result.get()
os.remove(filename)
def test_memorized_repr(tmpdir):
func = MemorizedFunc(f, tmpdir.strpath)
result = func.call_and_shelve(2)
func2 = MemorizedFunc(f, tmpdir.strpath)
result2 = func2.call_and_shelve(2)
assert result.get() == result2.get()
assert repr(func) == repr(func2)
# Smoke test with NotMemorizedFunc
func = NotMemorizedFunc(f)
repr(func)
repr(func.call_and_shelve(2))
# Smoke test for message output (increase code coverage)
func = MemorizedFunc(f, tmpdir.strpath, verbose=11, timestamp=time.time())
result = func.call_and_shelve(11)
result.get()
func = MemorizedFunc(f, tmpdir.strpath, verbose=11)
result = func.call_and_shelve(11)
result.get()
func = MemorizedFunc(f, tmpdir.strpath, verbose=5, timestamp=time.time())
result = func.call_and_shelve(11)
result.get()
func = MemorizedFunc(f, tmpdir.strpath, verbose=5)
result = func.call_and_shelve(11)
result.get()
def test_memory_file_modification(capsys, tmpdir, monkeypatch):
    # Test that modifying a Python file after loading it does not lead to
    # recomputation
dir_name = tmpdir.mkdir("tmp_import").strpath
filename = os.path.join(dir_name, "tmp_joblib_.py")
content = "def f(x):\n print(x)\n return x\n"
with open(filename, "w") as module_file:
module_file.write(content)
# Load the module:
monkeypatch.syspath_prepend(dir_name)
import tmp_joblib_ as tmp
memory = Memory(location=tmpdir.strpath, verbose=0)
f = memory.cache(tmp.f)
# First call f a few times
f(1)
f(2)
f(1)
# Now modify the module where f is stored without modifying f
with open(filename, "w") as module_file:
module_file.write("\n\n" + content)
# And call f a couple more times
f(1)
f(1)
# Flush the .pyc files
shutil.rmtree(dir_name)
os.mkdir(dir_name)
# Now modify the module where f is stored, modifying f
content = 'def f(x):\n print("x=%s" % x)\n return x\n'
with open(filename, "w") as module_file:
module_file.write(content)
# And call f more times prior to reloading: the cache should not be
# invalidated at this point as the active function definition has not
# changed in memory yet.
f(1)
f(1)
# Now reload
sys.stdout.write("Reloading\n")
sys.modules.pop("tmp_joblib_")
import tmp_joblib_ as tmp
f = memory.cache(tmp.f)
# And call f more times
f(1)
f(1)
out, err = capsys.readouterr()
assert out == "1\n2\nReloading\nx=1\n"
def _function_to_cache(a, b):
    # Just a placeholder function to be mutated by tests
pass
def _sum(a, b):
return a + b
def _product(a, b):
return a * b
def test_memory_in_memory_function_code_change(tmpdir):
_function_to_cache.__code__ = _sum.__code__
memory = Memory(location=tmpdir.strpath, verbose=0)
f = memory.cache(_function_to_cache)
assert f(1, 2) == 3
assert f(1, 2) == 3
with warns(JobLibCollisionWarning):
# Check that inline function modification triggers a cache invalidation
_function_to_cache.__code__ = _product.__code__
assert f(1, 2) == 2
assert f(1, 2) == 2
def test_clear_memory_with_none_location():
memory = Memory(location=None)
memory.clear()
def func_with_kwonly_args(a, b, *, kw1="kw1", kw2="kw2"):
return a, b, kw1, kw2
def func_with_signature(a: int, b: float) -> float:
return a + b
def test_memory_func_with_kwonly_args(tmpdir):
memory = Memory(location=tmpdir.strpath, verbose=0)
func_cached = memory.cache(func_with_kwonly_args)
assert func_cached(1, 2, kw1=3) == (1, 2, 3, "kw2")
# Making sure that providing a keyword-only argument by
# position raises an exception
with raises(ValueError) as excinfo:
func_cached(1, 2, 3, kw2=4)
excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")
# Keyword-only parameter passed by position with cached call
# should still raise ValueError
func_cached(1, 2, kw1=3, kw2=4)
with raises(ValueError) as excinfo:
func_cached(1, 2, 3, kw2=4)
excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")
# Test 'ignore' parameter
func_cached = memory.cache(func_with_kwonly_args, ignore=["kw2"])
assert func_cached(1, 2, kw1=3, kw2=4) == (1, 2, 3, 4)
assert func_cached(1, 2, kw1=3, kw2="ignored") == (1, 2, 3, 4)
def test_memory_func_with_signature(tmpdir):
memory = Memory(location=tmpdir.strpath, verbose=0)
func_cached = memory.cache(func_with_signature)
assert func_cached(1, 2.0) == 3.0
def _setup_toy_cache(tmpdir, num_inputs=10):
memory = Memory(location=tmpdir.strpath, verbose=0)
@memory.cache()
def get_1000_bytes(arg):
return "a" * 1000
inputs = list(range(num_inputs))
for arg in inputs:
get_1000_bytes(arg)
func_id = _build_func_identifier(get_1000_bytes)
hash_dirnames = [get_1000_bytes._get_args_id(arg) for arg in inputs]
full_hashdirs = [
os.path.join(get_1000_bytes.store_backend.location, func_id, dirname)
for dirname in hash_dirnames
]
return memory, full_hashdirs, get_1000_bytes
def test__get_items(tmpdir):
memory, expected_hash_dirs, _ = _setup_toy_cache(tmpdir)
items = memory.store_backend.get_items()
hash_dirs = [ci.path for ci in items]
assert set(hash_dirs) == set(expected_hash_dirs)
def get_files_size(directory):
full_paths = [os.path.join(directory, fn) for fn in os.listdir(directory)]
return sum(os.path.getsize(fp) for fp in full_paths)
expected_hash_cache_sizes = [get_files_size(hash_dir) for hash_dir in hash_dirs]
hash_cache_sizes = [ci.size for ci in items]
assert hash_cache_sizes == expected_hash_cache_sizes
output_filenames = [os.path.join(hash_dir, "output.pkl") for hash_dir in hash_dirs]
expected_last_accesses = [
datetime.datetime.fromtimestamp(os.path.getatime(fn)) for fn in output_filenames
]
last_accesses = [ci.last_access for ci in items]
assert last_accesses == expected_last_accesses
def test__get_items_to_delete(tmpdir):
# test empty cache
memory, _, _ = _setup_toy_cache(tmpdir, num_inputs=0)
items_to_delete = memory.store_backend._get_items_to_delete("1K")
assert items_to_delete == []
memory, expected_hash_cachedirs, _ = _setup_toy_cache(tmpdir)
items = memory.store_backend.get_items()
# bytes_limit set to keep only one cache item (each hash cache
# folder is about 1000 bytes + metadata)
items_to_delete = memory.store_backend._get_items_to_delete("2K")
nb_hashes = len(expected_hash_cachedirs)
assert set.issubset(set(items_to_delete), set(items))
assert len(items_to_delete) == nb_hashes - 1
# Sanity check bytes_limit=2048 is the same as bytes_limit='2K'
items_to_delete_2048b = memory.store_backend._get_items_to_delete(2048)
assert sorted(items_to_delete) == sorted(items_to_delete_2048b)
# bytes_limit greater than the size of the cache
items_to_delete_empty = memory.store_backend._get_items_to_delete("1M")
assert items_to_delete_empty == []
# All the cache items need to be deleted
bytes_limit_too_small = 500
items_to_delete_500b = memory.store_backend._get_items_to_delete(
bytes_limit_too_small
)
    assert set(items_to_delete_500b) == set(items)
    # Test LRU property: surviving cache items should all have a more
    # recent last_access than the ones that have been deleted
items_to_delete_6000b = memory.store_backend._get_items_to_delete(6000)
surviving_items = set(items).difference(items_to_delete_6000b)
assert max(ci.last_access for ci in items_to_delete_6000b) <= min(
ci.last_access for ci in surviving_items
)
def test_memory_reduce_size_bytes_limit(tmpdir):
memory, _, _ = _setup_toy_cache(tmpdir)
ref_cache_items = memory.store_backend.get_items()
# By default memory.bytes_limit is None and reduce_size is a noop
memory.reduce_size()
cache_items = memory.store_backend.get_items()
assert sorted(ref_cache_items) == sorted(cache_items)
# No cache items deleted if bytes_limit greater than the size of
# the cache
memory.reduce_size(bytes_limit="1M")
cache_items = memory.store_backend.get_items()
assert sorted(ref_cache_items) == sorted(cache_items)
# bytes_limit is set so that only two cache items are kept
memory.reduce_size(bytes_limit="3K")
cache_items = memory.store_backend.get_items()
assert set.issubset(set(cache_items), set(ref_cache_items))
assert len(cache_items) == 2
# bytes_limit set so that no cache item is kept
bytes_limit_too_small = 500
memory.reduce_size(bytes_limit=bytes_limit_too_small)
cache_items = memory.store_backend.get_items()
assert cache_items == []
def test_memory_reduce_size_items_limit(tmpdir):
memory, _, _ = _setup_toy_cache(tmpdir)
ref_cache_items = memory.store_backend.get_items()
# By default reduce_size is a noop
memory.reduce_size()
cache_items = memory.store_backend.get_items()
assert sorted(ref_cache_items) == sorted(cache_items)
# No cache items deleted if items_limit greater than the size of
# the cache
memory.reduce_size(items_limit=10)
cache_items = memory.store_backend.get_items()
assert sorted(ref_cache_items) == sorted(cache_items)
# items_limit is set so that only two cache items are kept
memory.reduce_size(items_limit=2)
cache_items = memory.store_backend.get_items()
assert set.issubset(set(cache_items), set(ref_cache_items))
assert len(cache_items) == 2
    # items_limit set so that no cache item is kept
memory.reduce_size(items_limit=0)
cache_items = memory.store_backend.get_items()
assert cache_items == []
def test_memory_reduce_size_age_limit(tmpdir):
memory, _, put_cache = _setup_toy_cache(tmpdir)
ref_cache_items = memory.store_backend.get_items()
# By default reduce_size is a noop
memory.reduce_size()
cache_items = memory.store_backend.get_items()
assert sorted(ref_cache_items) == sorted(cache_items)
    # No cache items deleted if age_limit is large.
memory.reduce_size(age_limit=datetime.timedelta(days=1))
cache_items = memory.store_backend.get_items()
assert sorted(ref_cache_items) == sorted(cache_items)
# age_limit is set so that only two cache items are kept
time.sleep(1)
put_cache(-1)
put_cache(-2)
memory.reduce_size(age_limit=datetime.timedelta(seconds=1))
cache_items = memory.store_backend.get_items()
assert not set.issubset(set(cache_items), set(ref_cache_items))
assert len(cache_items) == 2
# ensure age_limit is forced to be positive
with pytest.raises(ValueError, match="has to be a positive"):
memory.reduce_size(age_limit=datetime.timedelta(seconds=-1))
# age_limit set so that no cache item is kept
time.sleep(0.001) # make sure the age is different
memory.reduce_size(age_limit=datetime.timedelta(seconds=0))
cache_items = memory.store_backend.get_items()
assert cache_items == []
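# Summary sketch of the three reduce_size knobs exercised above; they are
# all optional keyword arguments and can be combined in a single call.
def _example_reduce_size_knobs():
    import tempfile

    memory = Memory(location=tempfile.mkdtemp(), verbose=0)
    memory.cache(f)(2)
    memory.reduce_size(bytes_limit="1M")  # cap the total on-disk size
    memory.reduce_size(items_limit=10)  # keep at most 10 cached results
    memory.reduce_size(age_limit=datetime.timedelta(days=7))  # drop old items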
def test_memory_clear(tmpdir):
memory, _, g = _setup_toy_cache(tmpdir)
memory.clear()
assert os.listdir(memory.store_backend.location) == []
# Check that the cache for functions hash is also reset.
assert not g._check_previous_func_code(stacklevel=4)
def fast_func_with_complex_output():
complex_obj = ["a" * 1000] * 1000
return complex_obj
def fast_func_with_conditional_complex_output(complex_output=True):
complex_obj = {str(i): i for i in range(int(1e5))}
return complex_obj if complex_output else "simple output"
@with_multiprocessing
def test_cached_function_race_condition_when_persisting_output(tmpdir, capfd):
# Test race condition where multiple processes are writing into
# the same output.pkl. See
# https://github.com/joblib/joblib/issues/490 for more details.
memory = Memory(location=tmpdir.strpath)
func_cached = memory.cache(fast_func_with_complex_output)
Parallel(n_jobs=2)(delayed(func_cached)() for i in range(3))
stdout, stderr = capfd.readouterr()
# Checking both stdout and stderr (ongoing PR #434 may change
# logging destination) to make sure there is no exception while
# loading the results
exception_msg = "Exception while loading results"
assert exception_msg not in stdout
assert exception_msg not in stderr
@with_multiprocessing
def test_cached_function_race_condition_when_persisting_output_2(tmpdir, capfd):
# Test race condition in first attempt at solving
# https://github.com/joblib/joblib/issues/490. The race condition
# was due to the delay between seeing the cache directory created
# (interpreted as the result being cached) and the output.pkl being
# pickled.
memory = Memory(location=tmpdir.strpath)
func_cached = memory.cache(fast_func_with_conditional_complex_output)
Parallel(n_jobs=2)(
delayed(func_cached)(True if i % 2 == 0 else False) for i in range(3)
)
stdout, stderr = capfd.readouterr()
# Checking both stdout and stderr (ongoing PR #434 may change
# logging destination) to make sure there is no exception while
# loading the results
exception_msg = "Exception while loading results"
assert exception_msg not in stdout
assert exception_msg not in stderr
def test_memory_recomputes_after_an_error_while_loading_results(tmpdir, monkeypatch):
memory = Memory(location=tmpdir.strpath)
def func(arg):
        # This makes sure that the timestamps returned by two calls of
        # func are different. This is needed on Windows, where the
        # time.time resolution may not be accurate enough
time.sleep(0.01)
return arg, time.time()
cached_func = memory.cache(func)
input_arg = "arg"
arg, timestamp = cached_func(input_arg)
# Make sure the function is correctly cached
assert arg == input_arg
# Corrupting output.pkl to make sure that an error happens when
# loading the cached result
corrupt_single_cache_item(memory)
# Make sure that corrupting the file causes recomputation and that
# a warning is issued.
recorded_warnings = monkeypatch_cached_func_warn(cached_func, monkeypatch)
recomputed_arg, recomputed_timestamp = cached_func(arg)
assert len(recorded_warnings) == 1
exception_msg = "Exception while loading results"
assert exception_msg in recorded_warnings[0]
assert recomputed_arg == arg
assert recomputed_timestamp > timestamp
# Corrupting output.pkl to make sure that an error happens when
# loading the cached result
corrupt_single_cache_item(memory)
reference = cached_func.call_and_shelve(arg)
try:
reference.get()
        raise AssertionError(
            "It should normally not be possible to load a corrupted"
            " MemorizedResult"
        )
except KeyError as e:
message = "is corrupted"
assert message in str(e.args)
class IncompleteStoreBackend(StoreBackendBase):
"""This backend cannot be instantiated and should raise a TypeError."""
pass
class DummyStoreBackend(StoreBackendBase):
"""A dummy store backend that does nothing."""
def _open_item(self, *args, **kwargs):
"""Open an item on store."""
"Does nothing"
def _item_exists(self, location):
"""Check if an item location exists."""
"Does nothing"
def _move_item(self, src, dst):
"""Move an item from src to dst in store."""
"Does nothing"
def create_location(self, location):
"""Create location on store."""
"Does nothing"
def exists(self, obj):
"""Check if an object exists in the store"""
return False
def clear_location(self, obj):
"""Clear object on store"""
"Does nothing"
def get_items(self):
"""Returns the whole list of items available in cache."""
return []
def configure(self, location, *args, **kwargs):
"""Configure the store"""
"Does nothing"
@parametrize("invalid_prefix", [None, dict(), list()])
def test_register_invalid_store_backends_key(invalid_prefix):
# verify the right exceptions are raised when passing a wrong backend key.
with raises(ValueError) as excinfo:
register_store_backend(invalid_prefix, None)
excinfo.match(r"Store backend name should be a string*")
def test_register_invalid_store_backends_object():
# verify the right exceptions are raised when passing a wrong backend
# object.
with raises(ValueError) as excinfo:
register_store_backend("fs", None)
excinfo.match(r"Store backend should inherit StoreBackendBase*")
def test_memory_default_store_backend():
    # test that an unknown backend name raises a TypeError instead of
    # silently falling back to the default FileSystemStoreBackend
with raises(TypeError) as excinfo:
Memory(location="/tmp/joblib", backend="unknown")
excinfo.match(r"Unknown location*")
def test_warning_on_unknown_location_type():
class NonSupportedLocationClass:
pass
unsupported_location = NonSupportedLocationClass()
with warns(UserWarning) as warninfo:
_store_backend_factory("local", location=unsupported_location)
    expected_message = (
        "Instantiating a backend using a "
        "NonSupportedLocationClass as a location is not "
        "supported by joblib"
    )
    assert expected_message in str(warninfo[0].message)
def test_instanciate_incomplete_store_backend():
# Verify that registering an external incomplete store backend raises an
# exception when one tries to instantiate it.
backend_name = "isb"
register_store_backend(backend_name, IncompleteStoreBackend)
assert (backend_name, IncompleteStoreBackend) in _STORE_BACKENDS.items()
with raises(TypeError) as excinfo:
_store_backend_factory(backend_name, "fake_location")
excinfo.match(
r"Can't instantiate abstract class IncompleteStoreBackend "
"(without an implementation for|with) abstract methods*"
)
def test_dummy_store_backend():
# Verify that registering an external store backend works.
backend_name = "dsb"
register_store_backend(backend_name, DummyStoreBackend)
assert (backend_name, DummyStoreBackend) in _STORE_BACKENDS.items()
backend_obj = _store_backend_factory(backend_name, "dummy_location")
assert isinstance(backend_obj, DummyStoreBackend)
def test_instanciate_store_backend_with_pathlib_path():
# Instantiate a FileSystemStoreBackend using a pathlib.Path object
path = pathlib.Path("some_folder")
backend_obj = _store_backend_factory("local", path)
try:
assert backend_obj.location == "some_folder"
finally: # remove cache folder after test
shutil.rmtree("some_folder", ignore_errors=True)
def test_filesystem_store_backend_repr(tmpdir):
# Verify string representation of a filesystem store backend.
repr_pattern = 'FileSystemStoreBackend(location="{location}")'
backend = FileSystemStoreBackend()
assert backend.location is None
repr(backend) # Should not raise an exception
assert str(backend) == repr_pattern.format(location=None)
# backend location is passed explicitly via the configure method (called
# by the internal _store_backend_factory function)
backend.configure(tmpdir.strpath)
assert str(backend) == repr_pattern.format(location=tmpdir.strpath)
repr(backend) # Should not raise an exception
def test_memory_objects_repr(tmpdir):
# Verify printable reprs of MemorizedResult, MemorizedFunc and Memory.
def my_func(a, b):
return a + b
memory = Memory(location=tmpdir.strpath, verbose=0)
memorized_func = memory.cache(my_func)
memorized_func_repr = "MemorizedFunc(func={func}, location={location})"
assert str(memorized_func) == memorized_func_repr.format(
func=my_func, location=memory.store_backend.location
)
memorized_result = memorized_func.call_and_shelve(42, 42)
memorized_result_repr = (
'MemorizedResult(location="{location}", func="{func}", args_id="{args_id}")'
)
assert str(memorized_result) == memorized_result_repr.format(
location=memory.store_backend.location,
func=memorized_result.func_id,
args_id=memorized_result.args_id,
)
assert str(memory) == "Memory(location={location})".format(
location=memory.store_backend.location
)
def test_memorized_result_pickle(tmpdir):
    # Verify a MemorizedResult object can be pickled/unpickled. Non-regression
    # test introduced following issue
    # https://github.com/joblib/joblib/issues/747
memory = Memory(location=tmpdir.strpath)
@memory.cache
def g(x):
return x**2
memorized_result = g.call_and_shelve(4)
memorized_result_pickle = pickle.dumps(memorized_result)
memorized_result_loads = pickle.loads(memorized_result_pickle)
assert (
memorized_result.store_backend.location
== memorized_result_loads.store_backend.location
)
assert memorized_result.func == memorized_result_loads.func
assert memorized_result.args_id == memorized_result_loads.args_id
assert str(memorized_result) == str(memorized_result_loads)
def compare(left, right, ignored_attrs=None):
if ignored_attrs is None:
ignored_attrs = []
left_vars = vars(left)
right_vars = vars(right)
assert set(left_vars.keys()) == set(right_vars.keys())
for attr in left_vars.keys():
if attr in ignored_attrs:
continue
assert left_vars[attr] == right_vars[attr]
@pytest.mark.parametrize(
"memory_kwargs",
[
{"compress": 3, "verbose": 2},
{"mmap_mode": "r", "verbose": 5, "backend_options": {"parameter": "unused"}},
],
)
def test_memory_pickle_dump_load(tmpdir, memory_kwargs):
memory = Memory(location=tmpdir.strpath, **memory_kwargs)
memory_reloaded = pickle.loads(pickle.dumps(memory))
# Compare Memory instance before and after pickle roundtrip
compare(memory.store_backend, memory_reloaded.store_backend)
compare(
memory,
memory_reloaded,
ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
)
assert hash(memory) == hash(memory_reloaded)
func_cached = memory.cache(f)
func_cached_reloaded = pickle.loads(pickle.dumps(func_cached))
# Compare MemorizedFunc instance before/after pickle roundtrip
compare(func_cached.store_backend, func_cached_reloaded.store_backend)
compare(
func_cached,
func_cached_reloaded,
ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
)
assert hash(func_cached) == hash(func_cached_reloaded)
# Compare MemorizedResult instance before/after pickle roundtrip
memorized_result = func_cached.call_and_shelve(1)
memorized_result_reloaded = pickle.loads(pickle.dumps(memorized_result))
compare(memorized_result.store_backend, memorized_result_reloaded.store_backend)
compare(
memorized_result,
memorized_result_reloaded,
ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
)
assert hash(memorized_result) == hash(memorized_result_reloaded)
def test_info_log(tmpdir, caplog):
caplog.set_level(logging.INFO)
x = 3
memory = Memory(location=tmpdir.strpath, verbose=20)
@memory.cache
def f(x):
return x**2
_ = f(x)
assert "Querying" in caplog.text
caplog.clear()
memory = Memory(location=tmpdir.strpath, verbose=0)
@memory.cache
def f(x):
return x**2
_ = f(x)
assert "Querying" not in caplog.text
caplog.clear()
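# NOTE: the test classes below receive a `memory` argument. This is assumed
# to be a pytest fixture defined elsewhere in the suite (typically in
# conftest.py) that yields a fresh Memory instance backed by a temporary
# directory, along the lines of:
#
#     @pytest.fixture
#     def memory(tmpdir):
#         mem = Memory(location=tmpdir.strpath, verbose=0)
#         yield mem
#         mem.clear(warn=False)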
class TestCacheValidationCallback:
"Tests on parameter `cache_validation_callback`"
def foo(self, x, d, delay=None):
d["run"] = True
if delay is not None:
time.sleep(delay)
return x * 2
def test_invalid_cache_validation_callback(self, memory):
"Test invalid values for `cache_validation_callback"
match = "cache_validation_callback needs to be callable. Got True."
with pytest.raises(ValueError, match=match):
memory.cache(cache_validation_callback=True)
@pytest.mark.parametrize("consider_cache_valid", [True, False])
def test_constant_cache_validation_callback(self, memory, consider_cache_valid):
"Test expiry of old results"
f = memory.cache(
self.foo,
cache_validation_callback=lambda _: consider_cache_valid,
ignore=["d"],
)
d1, d2 = {"run": False}, {"run": False}
assert f(2, d1) == 4
assert f(2, d2) == 4
assert d1["run"]
assert d2["run"] != consider_cache_valid
def test_memory_only_cache_long_run(self, memory):
"Test cache validity based on run duration."
def cache_validation_callback(metadata):
duration = metadata["duration"]
if duration > 0.1:
return True
f = memory.cache(
self.foo, cache_validation_callback=cache_validation_callback, ignore=["d"]
)
        # Short runs are not cached
d1, d2 = {"run": False}, {"run": False}
assert f(2, d1, delay=0) == 4
assert f(2, d2, delay=0) == 4
assert d1["run"]
assert d2["run"]
        # Longer runs are cached
d1, d2 = {"run": False}, {"run": False}
assert f(2, d1, delay=0.2) == 4
assert f(2, d2, delay=0.2) == 4
assert d1["run"]
assert not d2["run"]
def test_memory_expires_after(self, memory):
"Test expiry of old cached results"
f = memory.cache(
self.foo, cache_validation_callback=expires_after(seconds=0.3), ignore=["d"]
)
d1, d2, d3 = {"run": False}, {"run": False}, {"run": False}
assert f(2, d1) == 4
assert f(2, d2) == 4
time.sleep(0.5)
assert f(2, d3) == 4
assert d1["run"]
assert not d2["run"]
assert d3["run"]
class TestMemorizedFunc:
"Tests for the MemorizedFunc and NotMemorizedFunc classes"
@staticmethod
def f(x, counter):
counter[x] = counter.get(x, 0) + 1
return counter[x]
def test_call_method_memorized(self, memory):
"Test calling the function"
f = memory.cache(self.f, ignore=["counter"])
counter = {}
assert f(2, counter) == 1
assert f(2, counter) == 1
x, meta = f.call(2, counter)
assert x == 2, "f has not been called properly"
assert isinstance(meta, dict), (
"Metadata are not returned by MemorizedFunc.call."
)
def test_call_method_not_memorized(self, memory):
"Test calling the function"
f = NotMemorizedFunc(self.f)
counter = {}
assert f(2, counter) == 1
assert f(2, counter) == 2
x, meta = f.call(2, counter)
assert x == 3, "f has not been called properly"
assert isinstance(meta, dict), (
"Metadata are not returned by MemorizedFunc.call."
)
@with_numpy
@parametrize(
"location",
[
"test_cache_dir",
pathlib.Path("test_cache_dir"),
pathlib.Path("test_cache_dir").resolve(),
],
)
def test_memory_creates_gitignore(location):
"""Test that using the memory object automatically creates a `.gitignore` file
within the new cache directory."""
mem = Memory(location)
arr = np.asarray([[1, 2, 3], [4, 5, 6]])
costly_operation = mem.cache(np.square)
costly_operation(arr)
location = pathlib.Path(location)
try:
path_to_gitignore_file = os.path.join(location, ".gitignore")
gitignore_file_content = "# Created by joblib automatically.\n*\n"
with open(path_to_gitignore_file) as f:
assert gitignore_file_content == f.read()
finally: # remove cache folder after test
shutil.rmtree(location, ignore_errors=True)