|
""" |
|
Test the memory module. |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import datetime |
|
import functools |
|
import gc |
|
import logging |
|
import os |
|
import os.path |
|
import pathlib |
|
import pickle |
|
import shutil |
|
import sys |
|
import textwrap |
|
import time |
|
|
|
import pytest |
|
|
|
from joblib._store_backends import FileSystemStoreBackend, StoreBackendBase |
|
from joblib.hashing import hash |
|
from joblib.memory import ( |
|
_FUNCTION_HASHES, |
|
_STORE_BACKENDS, |
|
JobLibCollisionWarning, |
|
MemorizedFunc, |
|
MemorizedResult, |
|
Memory, |
|
NotMemorizedFunc, |
|
NotMemorizedResult, |
|
_build_func_identifier, |
|
_store_backend_factory, |
|
expires_after, |
|
register_store_backend, |
|
) |
|
from joblib.parallel import Parallel, delayed |
|
from joblib.test.common import np, with_multiprocessing, with_numpy |
|
from joblib.testing import parametrize, raises, warns |
|
|
|
|
|
|
|
|
|
def f(x, y=1):
    """Return ``x**2 + y``; module-level function shared by many tests."""
    squared = x**2
    return squared + y
|
|
|
|
|
|
|
|
|
def check_identity_lazy(func, accumulator, location):
    """Check that ``func``, once cached by Memory, is a lazy identity.

    ``accumulator`` is a list that ``func`` appends to every time it really
    runs. For each input value, only the first of two calls may reach the
    underlying function; the second must be served from the cache.
    """
    memory = Memory(location=location, verbose=0)
    cached = memory.cache(func)
    for value in range(3):
        for _attempt in range(2):
            assert cached(value) == value
            # One real execution per distinct value so far.
            assert len(accumulator) == value + 1
|
|
|
|
|
def corrupt_single_cache_item(memory):
    """Overwrite the ``output.pkl`` of the only item in ``memory``'s cache.

    Raises ValueError (from the tuple unpacking) unless the store holds
    exactly one item. The file is replaced with non-pickle text.
    """
    (only_item,) = memory.store_backend.get_items()
    corrupted_path = os.path.join(only_item.path, "output.pkl")
    # Named ``handle`` so the module-level test function ``f`` is not shadowed.
    with open(corrupted_path, "w") as handle:
        handle.write("garbage")
|
|
|
|
|
def monkeypatch_cached_func_warn(func, monkeypatch_fixture):
    """Replace ``func.warn`` with a recorder and return the recorded list.

    The replacement goes through ``monkeypatch_fixture.setattr`` (pytest's
    ``monkeypatch`` fixture in these tests). Every message passed to
    ``func.warn`` afterwards is appended to the returned list.
    """
    recorded = []

    def append_to_record(item):
        recorded.append(item)

    monkeypatch_fixture.setattr(func, "warn", append_to_record)
    return recorded
|
|
|
|
|
|
|
|
|
def test_memory_integration(tmpdir):
    """Simple test of memory lazy evaluation."""
    accumulator = []

    def f(arg):
        accumulator.append(1)
        return arg

    check_identity_lazy(f, accumulator, tmpdir.strpath)

    # Repeat the clear / recompute / eval cycle for every combination of
    # compression and memory mapping.
    for compress in (False, True):
        for mmap_mode in ("r", None):
            memory = Memory(
                location=tmpdir.strpath,
                verbose=10,
                mmap_mode=mmap_mode,
                compress=compress,
            )
            # Start each configuration from an empty cache directory.
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            cached = memory.cache(f)
            cached(1)
            cached.clear(warn=False)

            calls_before = len(accumulator)
            result = cached(1)
            # The cache was just cleared, so ``f`` had to run again...
            assert len(accumulator) == calls_before + 1
            # ...and memory.eval reuses that stored result without running it.
            assert memory.eval(f, 1) == result
            assert len(accumulator) == calls_before + 1

    # A function whose module is ``__main__`` must also be cacheable.
    f.__module__ = "__main__"
    Memory(location=tmpdir.strpath, verbose=0).cache(f)(1)
|
|
|
|
|
@parametrize("call_before_reducing", [True, False])
def test_parallel_call_cached_function_defined_in_jupyter(tmpdir, call_before_reducing):
    # Emulate, with low-level calls, a function ``f`` typed into a
    # Jupyter/IPython cell: it gets an IPython-style cell filename and the
    # ``__main__`` module. Then check that calling its cached version through
    # ``Parallel`` does not clear the cache directory of ``f``. Two
    # consecutive "sessions" are simulated; the second checks that the cache
    # written by the first one is reused.
    for session_no in [0, 1]:
        ipython_cell_source = """

        def f(x):

            return x

        """

        # Filename in the style IPython gives to code typed in a cell.
        ipython_cell_id = "<ipython-input-{}-000000000000>".format(session_no)

        # Compile the cell under that filename; the resulting ``f`` is
        # collected from ``my_locals``.
        my_locals = {}
        exec(
            compile(
                textwrap.dedent(ipython_cell_source),
                filename=ipython_cell_id,
                mode="exec",
            ),
            None,
            my_locals,
        )
        f = my_locals["f"]
        # Interactively defined functions live in ``__main__``.
        f.__module__ = "__main__"

        # Sanity checks: ``f`` works and its code points at the fake cell.
        assert f(1) == 1
        assert f.__code__.co_filename == ipython_cell_id

        memory = Memory(location=tmpdir.strpath, verbose=0)
        cached_f = memory.cache(f)

        # Exactly one cache sub-directory, whose name mentions the cell.
        assert len(os.listdir(tmpdir / "joblib")) == 1
        f_cache_relative_directory = os.listdir(tmpdir / "joblib")[0]
        assert "ipython-input" in f_cache_relative_directory

        f_cache_directory = tmpdir / "joblib" / f_cache_relative_directory

        if session_no == 0:
            # Nothing has been called yet: the directory of ``f`` is empty.
            assert os.listdir(f_cache_directory) == ["f"]
            assert os.listdir(f_cache_directory / "f") == []

            if call_before_reducing:
                cached_f(3)
                # A single call in the parent yields two entries in the
                # cache directory of ``f``.
                assert len(os.listdir(f_cache_directory / "f")) == 2

                # Now ship the already-called cached function to the workers.
                Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])
            else:
                # Ship the cached function to the workers before it was
                # ever called in the parent process.
                Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])

                # Workers may still be finishing their writes: poll for up to
                # 25 * 0.2 = 5 seconds before asserting.
                for _ in range(25):
                    if len(os.listdir(f_cache_directory / "f")) == 3:
                        break
                    time.sleep(0.2)
                assert len(os.listdir(f_cache_directory / "f")) == 3

                cached_f(3)

            # The parallel calls must not have cleared the cache: it holds
            # four entries once 1, 2 and 3 have all been computed.
            assert len(os.listdir(f_cache_directory / "f")) == 4
        else:
            # Second session: the cache written by the first session exists.
            assert len(os.listdir(f_cache_directory / "f")) == 4

            cached_f(3)

            # Calling ``f`` in the new session reuses that cache instead of
            # invalidating it.
            assert len(os.listdir(f_cache_directory / "f")) == 4
|
|
|
|
|
def test_no_memory():
    """Test memory with location=None: no memoize"""
    calls = []

    def ff(arg):
        calls.append(1)
        return arg

    uncached = Memory(location=None, verbose=0).cache(ff)
    # Without a location nothing is stored, so every call must run ``ff``.
    for _ in range(4):
        before = len(calls)
        uncached(1)
        assert len(calls) == before + 1
|
|
|
|
|
def test_memory_kwarg(tmpdir):
    """Test memory with a function with keyword arguments."""
    accumulator = []

    def g(arg1=None, arg2=1):
        accumulator.append(1)
        return arg1

    check_identity_lazy(g, accumulator, tmpdir.strpath)

    # Calling the cached function purely with keywords must also work.
    cached_g = Memory(location=tmpdir.strpath, verbose=0).cache(g)
    assert cached_g(arg1=30, arg2=2) == 30
|
|
|
|
|
def test_memory_lambda(tmpdir):
    """Test memory with a function with a lambda."""
    accumulator = []

    def helper(x):
        """Record the call and return ``x`` unchanged."""
        accumulator.append(1)
        return x

    # The callable handed to Memory is a lambda that forwards to ``helper``.
    check_identity_lazy(lambda x: helper(x), accumulator, tmpdir.strpath)
|
|
|
|
|
def test_memory_name_collision(tmpdir):
    """Check that name collisions with functions will raise warnings"""
    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache
    def name_collision(x):
        """A first function called name_collision"""
        return x

    first = name_collision

    @memory.cache
    def name_collision(x):
        """A second function called name_collision"""
        return x

    second = name_collision

    # Two different functions with the same name: running the second one
    # after the first must emit exactly one collision warning.
    with warns(JobLibCollisionWarning) as warninfo:
        first(1)
        second(1)

    assert len(warninfo) == 1
    assert "collision" in str(warninfo[0].message)
|
|
|
|
|
def test_memory_warning_lambda_collisions(tmpdir):
    # Two different lambdas cached by the same Memory. Both are named
    # ``<lambda>``, so calling them alternately is expected to emit
    # JobLibCollisionWarning.
    memory = Memory(location=tmpdir.strpath, verbose=0)
    a = memory.cache(lambda x: x)
    b = memory.cache(lambda x: x + 1)

    with warns(JobLibCollisionWarning) as warninfo:
        assert a(0) == 0
        assert b(1) == 2
        assert a(1) == 1

    # The three calls above are expected to emit four warnings in total.
    # NOTE(review): the exact breakdown of these warnings comes from joblib
    # internals that are not visible in this file.
    assert len(warninfo) == 4
|
|
|
|
|
def test_memory_warning_collision_detection(tmpdir):
    # Lambdas built with ``eval`` have no source file (presumably why joblib
    # cannot compare their code). Joblib is expected to warn that it
    # cannot detect collisions between them.
    memory = Memory(location=tmpdir.strpath, verbose=0)
    a1 = eval("lambda x: x")
    a1 = memory.cache(a1)
    b1 = eval("lambda x: x+1")
    b1 = memory.cache(b1)

    with warns(JobLibCollisionWarning) as warninfo:
        a1(1)
        b1(1)
        a1(0)

    assert len(warninfo) == 2
    # The first warning says collisions cannot be detected.
    assert "cannot detect" in str(warninfo[0].message).lower()
|
|
|
|
|
def test_memory_partial(tmpdir):
    """Test memory with functools.partial."""
    accumulator = []

    def func(x, y):
        """Record the call and return ``y``; ``x`` is bound by the partial."""
        accumulator.append(1)
        return y

    # ``functools`` is imported at module level. Binding x=1 leaves a
    # one-argument callable that check_identity_lazy calls with ``y``.
    function = functools.partial(func, 1)

    check_identity_lazy(function, accumulator, tmpdir.strpath)
|
|
|
|
|
def test_memory_eval(tmpdir):
    """Smoke test memory with a function defined in an eval."""
    identity = eval("lambda x: x")
    cached_identity = Memory(location=tmpdir.strpath, verbose=0).cache(identity)
    assert cached_identity(1) == 1
|
|
|
|
|
def count_and_append(x=[]):
    """A function with a side effect in its arguments.

    Return the length of its argument and append one element.

    The mutable default list is intentional: it is shared across calls, so
    each call that uses the default sees (and grows) the same list.
    test_argument_change relies on this: with the default mutated, a second
    cached call must return 1 rather than the first result, 0.
    """
    len_x = len(x)
    x.append(None)
    return len_x
|
|
|
|
|
def test_argument_change(tmpdir):
    """Check that if a function has a side effect in its arguments, it
    should use the hash of changing arguments.
    """
    cached = Memory(location=tmpdir.strpath, verbose=0).cache(count_and_append)
    # First call: the shared default list is empty, so the result is 0, and
    # the call leaves one element in that list.
    assert cached() == 0
    # The default list has changed, so the arguments no longer match the
    # first call and the function must run again.
    assert cached() == 1
|
|
|
|
|
@with_numpy
@parametrize("mmap_mode", [None, "r"])
def test_memory_numpy(tmpdir, mmap_mode):
    """Test memory with a function with numpy arrays."""
    accumulator = []

    def n(arg=None):
        accumulator.append(1)
        return arg

    memory = Memory(location=tmpdir.strpath, mmap_mode=mmap_mode, verbose=0)
    cached_n = memory.cache(n)

    rng = np.random.RandomState(0)
    for expected_calls in range(1, 4):
        array = rng.random_sample((10, 10))
        # The first call on a new array runs ``n``; the repeats hit the cache.
        for _ in range(3):
            assert np.all(cached_n(array) == array)
            assert len(accumulator) == expected_calls
|
|
|
|
|
@with_numpy
def test_memory_numpy_check_mmap_mode(tmpdir, monkeypatch):
    """Check that mmap_mode is respected even at the first call"""
    memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)

    @memory.cache()
    def twice(a):
        return a * 2

    a = np.ones(3)

    computed = twice(a)
    reloaded = twice(a)

    assert isinstance(reloaded, np.memmap)
    assert reloaded.mode == "r"
    assert isinstance(computed, np.memmap)
    assert computed.mode == "r"

    # Drop every reference to the memmaps and collect them before
    # overwriting the file they map.
    del computed
    del reloaded
    gc.collect()
    corrupt_single_cache_item(memory)

    # Loading the corrupted item records one warning, and the returned
    # value must still be a read-only memmap.
    recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
    fresh = twice(a)
    assert len(recorded_warnings) == 1
    exception_msg = "Exception while loading results"
    assert exception_msg in recorded_warnings[0]

    assert isinstance(fresh, np.memmap)
    assert fresh.mode == "r"
|
|
|
|
|
def test_memory_exception(tmpdir):
    """Smoketest the exception handling of Memory."""
    memory = Memory(location=tmpdir.strpath, verbose=0)

    class MyException(Exception):
        pass

    @memory.cache
    def h(exc=0):
        if exc:
            raise MyException

    # A successful call goes through the cache normally.
    h()

    # A raising call must keep raising on every attempt.
    for _ in range(3):
        with raises(MyException):
            h(1)
|
|
|
|
|
def test_memory_ignore(tmpdir):
    """Test the ignore feature of memory"""
    memory = Memory(location=tmpdir.strpath, verbose=0)
    accumulator = []

    @memory.cache(ignore=["y"])
    def z(x, y=1):
        accumulator.append(1)

    assert z.ignore == ["y"]

    # ``y`` is ignored, so only the very first call executes ``z``.
    for y in (1, 1, 2):
        z(0, y=y)
        assert len(accumulator) == 1
|
|
|
|
|
def test_memory_ignore_decorated(tmpdir):
    """Test the ignore feature of memory on a decorated function"""
    memory = Memory(location=tmpdir.strpath, verbose=0)
    accumulator = []

    def decorate(func):
        # functools.wraps keeps the wrapped function's metadata, presumably
        # so Memory can still see the parameters ``x`` and ``y``.
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapped

    @memory.cache(ignore=["y"])
    @decorate
    def z(x, y=1):
        accumulator.append(1)

    assert z.ignore == ["y"]

    # ``y`` is ignored, so only the very first call executes ``z``.
    for y in (1, 1, 2):
        z(0, y=y)
        assert len(accumulator) == 1
|
|
|
|
|
def test_memory_args_as_kwargs(tmpdir):
    """Non-regression test against 0.12.0 changes.

    https://github.com/joblib/joblib/pull/751
    """
    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache
    def plus_one(a):
        return a + 1

    # Positional and keyword spellings of the same call both give the
    # right value.
    assert plus_one(1) == 2
    assert plus_one(a=1) == 2

    # A different keyword value must produce its own result, not the a=1 one.
    assert plus_one(a=2) == 3
|
|
|
|
|
@parametrize("ignore, verbose, mmap_mode", [(["x"], 100, "r"), ([], 10, None)])
def test_partial_decoration(tmpdir, ignore, verbose, mmap_mode):
    """Check cache may be called with kwargs before decorating"""
    memory = Memory(location=tmpdir.strpath, verbose=0)
    decorator = memory.cache(ignore=ignore, verbose=verbose, mmap_mode=mmap_mode)

    @decorator
    def z(x):
        pass

    # The options given before decoration end up on the memorized function.
    assert z.ignore == ignore
    assert z._verbose == verbose
    assert z.mmap_mode == mmap_mode
|
|
|
|
|
def test_func_dir(tmpdir):
    # Check where the cache directory of a module-level function lives, that
    # its code is stored, and that a missing output file is recomputed.
    memory = Memory(location=tmpdir.strpath, verbose=0)
    # Expected directory: <tmpdir>/joblib/<this module's dotted path>/f
    path = __name__.split(".")
    path.append("f")
    path = tmpdir.join("joblib", *path).strpath

    g = memory.cache(f)

    # The function directory matches the expected path and already exists.
    func_id = _build_func_identifier(f)
    location = os.path.join(g.store_backend.location, func_id)
    assert location == path
    assert os.path.exists(path)
    assert memory.location == os.path.dirname(g.store_backend.location)

    # Clear the module-level in-memory hash registry so the first check below
    # does not depend on earlier tests. The first check then returns False and
    # writes func_code.py; the second finds the stored code and returns True.
    _FUNCTION_HASHES.clear()
    assert not g._check_previous_func_code()
    assert os.path.exists(os.path.join(path, "func_code.py"))
    assert g._check_previous_func_code()

    # Deleting a stored output.pkl must not break the next call: it still
    # returns the same value.
    args_id = g._get_args_id(1)
    output_dir = os.path.join(g.store_backend.location, g.func_id, args_id)
    a = g(1)
    assert os.path.exists(output_dir)
    os.remove(os.path.join(output_dir, "output.pkl"))
    assert a == g(1)
|
|
|
|
|
def test_persistence(tmpdir):
    """Memory objects and memorized functions survive a pickle round trip."""

    def roundtrip(obj):
        return pickle.loads(pickle.dumps(obj))

    memory = Memory(location=tmpdir.strpath, verbose=0)
    g = memory.cache(f)
    output = g(1)

    # The unpickled function still finds the result stored by the original.
    h = roundtrip(g)
    args_id = h._get_args_id(1)
    output_dir = os.path.join(h.store_backend.location, h.func_id, args_id)
    assert os.path.exists(output_dir)
    assert output == h.store_backend.load_item([h.func_id, args_id])

    memory2 = roundtrip(memory)
    assert memory.store_backend.location == memory2.store_backend.location

    # Without a location, pickling must work and the copy must be callable.
    memory = Memory(location=None, verbose=0)
    roundtrip(memory)
    g = memory.cache(f)
    gp = roundtrip(g)
    gp(1)
|
|
|
|
|
@pytest.mark.parametrize("consider_cache_valid", [True, False])
def test_check_call_in_cache(tmpdir, consider_cache_valid):
    """check_call_in_cache reports stored calls that the callback accepts."""
    candidates = (
        MemorizedFunc(
            f, tmpdir.strpath, cache_validation_callback=lambda _: consider_cache_valid
        ),
        Memory(location=tmpdir.strpath, verbose=0).cache(
            f, cache_validation_callback=lambda _: consider_cache_valid
        ),
    )
    for func in candidates:
        # Nothing computed yet.
        in_cache = func.check_call_in_cache(2)
        assert isinstance(in_cache, bool)
        assert not in_cache

        assert func(2) == 5

        # Once computed, the call is reported as cached only if the
        # validation callback accepts the entry.
        in_cache = func.check_call_in_cache(2)
        assert isinstance(in_cache, bool)
        assert in_cache == consider_cache_valid
        func.clear()

    # A non-memorized function never reports a cached call.
    assert not NotMemorizedFunc(f).check_call_in_cache(2)
|
|
|
|
|
def test_call_and_shelve(tmpdir):
    """call_and_shelve returns the matching result wrapper for each function."""
    cases = [
        (MemorizedFunc(f, tmpdir.strpath), MemorizedResult),
        (NotMemorizedFunc(f), NotMemorizedResult),
        (Memory(location=tmpdir.strpath, verbose=0).cache(f), MemorizedResult),
        (Memory(location=None).cache(f), NotMemorizedResult),
    ]
    for func, expected_type in cases:
        assert func(2) == 5
        shelved = func.call_and_shelve(2)
        assert isinstance(shelved, expected_type)
        assert shelved.get() == 5

        # A cleared result can no longer be fetched, and clearing it a
        # second time must not fail.
        shelved.clear()
        with raises(KeyError):
            shelved.get()
        shelved.clear()
|
|
|
|
|
def test_call_and_shelve_lazily_load_stored_result(tmpdir):
    """Check call_and_shelve only load stored data if needed."""
    # Probe whether this filesystem updates access times finely enough for
    # the timing checks below: write a file, wait, read it back, and compare
    # its st_atime.
    test_access_time_file = tmpdir.join("test_access")
    test_access_time_file.write("test_access")
    test_access_time = os.stat(test_access_time_file.strpath).st_atime

    time.sleep(0.5)
    assert test_access_time_file.read() == "test_access"

    if test_access_time == os.stat(test_access_time_file.strpath).st_atime:
        # Reading did not change st_atime, so access-time based assertions
        # cannot be trusted here.
        pytest.skip("filesystem does not support fine-grained access time attribute")

    memory = Memory(location=tmpdir.strpath, verbose=0)
    func = memory.cache(f)
    args_id = func._get_args_id(2)
    result_path = os.path.join(
        memory.store_backend.location, func.func_id, args_id, "output.pkl"
    )
    assert func(2) == 5
    first_access_time = os.stat(result_path).st_atime
    # Wait so that any later read of output.pkl gives a strictly later atime.
    time.sleep(1)

    # call_and_shelve must not read the stored output: atime is unchanged.
    result = func.call_and_shelve(2)
    assert isinstance(result, MemorizedResult)
    assert os.stat(result_path).st_atime == first_access_time
    time.sleep(1)

    # get() actually reads output.pkl, which moves its access time forward.
    assert result.get() == 5
    assert os.stat(result_path).st_atime > first_access_time
|
|
|
|
|
def test_memorized_pickling(tmpdir):
    """Shelved results can be pickled to a file and loaded back."""
    filename = tmpdir.join("pickling_test.dat").strpath
    for func in (MemorizedFunc(f, tmpdir.strpath), NotMemorizedFunc(f)):
        shelved = func.call_and_shelve(2)
        with open(filename, "wb") as fp:
            pickle.dump(shelved, fp)
        with open(filename, "rb") as fp:
            reloaded = pickle.load(fp)
        assert reloaded.get() == shelved.get()
        os.remove(filename)
|
|
|
|
|
def test_memorized_repr(tmpdir):
    """Smoke-test repr() and the verbose code paths of memorized functions."""
    func = MemorizedFunc(f, tmpdir.strpath)
    result = func.call_and_shelve(2)

    func2 = MemorizedFunc(f, tmpdir.strpath)
    result2 = func2.call_and_shelve(2)
    assert result.get() == result2.get()
    assert repr(func) == repr(func2)

    # NotMemorizedFunc and its shelved result must have a working repr too.
    not_memorized = NotMemorizedFunc(f)
    repr(not_memorized)
    repr(not_memorized.call_and_shelve(2))

    # Run the verbose paths, with and without an explicit timestamp.
    for verbose, with_timestamp in ((11, True), (11, False), (5, True), (5, False)):
        extra = {"timestamp": time.time()} if with_timestamp else {}
        verbose_func = MemorizedFunc(f, tmpdir.strpath, verbose=verbose, **extra)
        verbose_func.call_and_shelve(11).get()
|
|
|
|
|
def test_memory_file_modification(capsys, tmpdir, monkeypatch):
    # Editing the source file of a cached function must not trigger
    # recomputation as long as the in-memory function is unchanged. Only
    # re-importing the modified module (new function object) does. Each
    # real execution of ``f`` prints, so stdout records which calls ran.
    dir_name = tmpdir.mkdir("tmp_import").strpath
    filename = os.path.join(dir_name, "tmp_joblib_.py")
    content = "def f(x):\n print(x)\n return x\n"
    with open(filename, "w") as module_file:
        module_file.write(content)

    # Make the temporary module importable for this test only.
    monkeypatch.syspath_prepend(dir_name)
    import tmp_joblib_ as tmp

    memory = Memory(location=tmpdir.strpath, verbose=0)
    f = memory.cache(tmp.f)

    # Only the first f(1) and f(2) run (and print).
    f(1)
    f(2)
    f(1)

    # Shift ``f`` down two lines in the file without changing its code.
    with open(filename, "w") as module_file:
        module_file.write("\n\n" + content)

    # Still served from the cache: nothing printed.
    f(1)
    f(1)

    # Recreate the directory, presumably so no stale compiled bytecode from
    # the old file is picked up on re-import.
    shutil.rmtree(dir_name)
    os.mkdir(dir_name)

    # Now change the body of ``f`` on disk.
    content = 'def f(x):\n print("x=%s" % x)\n return x\n'
    with open(filename, "w") as module_file:
        module_file.write(content)

    # The module has not been reloaded yet, so the function in memory is the
    # old one and these calls are still cache hits.
    f(1)
    f(1)

    # Force a fresh import of the modified module.
    sys.stdout.write("Reloading\n")
    sys.modules.pop("tmp_joblib_")
    import tmp_joblib_ as tmp

    f = memory.cache(tmp.f)

    # The new code runs once (printing "x=1"), then hits the cache.
    f(1)
    f(1)

    out, err = capsys.readouterr()
    assert out == "1\n2\nReloading\nx=1\n"
|
|
|
|
|
def _function_to_cache(a, b): |
|
|
|
pass |
|
|
|
|
|
def _sum(a, b): |
|
return a + b |
|
|
|
|
|
def _product(a, b): |
|
return a * b |
|
|
|
|
|
def test_memory_in_memory_function_code_change(tmpdir):
    """Swapping a function's ``__code__`` in memory invalidates its cache.

    ``_function_to_cache`` is a module-level function. Its original code
    object is restored in ``finally`` so the test does not leak modified
    module state into the rest of the session.
    """
    original_code = _function_to_cache.__code__
    try:
        _function_to_cache.__code__ = _sum.__code__

        memory = Memory(location=tmpdir.strpath, verbose=0)
        f = memory.cache(_function_to_cache)

        assert f(1, 2) == 3
        assert f(1, 2) == 3

        with warns(JobLibCollisionWarning):
            # Changing the code in memory must be detected: a warning is
            # emitted and the new code's result is returned.
            _function_to_cache.__code__ = _product.__code__
            assert f(1, 2) == 2
            assert f(1, 2) == 2
    finally:
        _function_to_cache.__code__ = original_code
|
|
|
|
|
def test_clear_memory_with_none_location():
    """Clearing a Memory that has no location must not raise."""
    Memory(location=None).clear()
|
|
|
|
|
def func_with_kwonly_args(a, b, *, kw1="kw1", kw2="kw2"):
    """Return every argument, keyword-only ones included, as a tuple."""
    result = (a, b, kw1, kw2)
    return result
|
|
|
|
|
def func_with_signature(a: int, b: float) -> float:
    """Return ``a + b``; an annotated function for caching with type hints."""
    total = a + b
    return total
|
|
|
|
|
def test_memory_func_with_kwonly_args(tmpdir):
    """Memory handles keyword-only arguments (and ignoring them)."""
    memory = Memory(location=tmpdir.strpath, verbose=0)
    func_cached = memory.cache(func_with_kwonly_args)

    assert func_cached(1, 2, kw1=3) == (1, 2, 3, "kw2")

    expected_error = "Keyword-only parameter 'kw1' was passed as positional parameter"

    # Passing a keyword-only argument positionally is rejected...
    with raises(ValueError) as excinfo:
        func_cached(1, 2, 3, kw2=4)
    excinfo.match(expected_error)

    # ...and still rejected after the keyword form of the call is cached.
    func_cached(1, 2, kw1=3, kw2=4)
    with raises(ValueError) as excinfo:
        func_cached(1, 2, 3, kw2=4)
    excinfo.match(expected_error)

    # With kw2 ignored, a different kw2 value returns the cached result.
    ignoring_kw2 = memory.cache(func_with_kwonly_args, ignore=["kw2"])
    assert ignoring_kw2(1, 2, kw1=3, kw2=4) == (1, 2, 3, 4)
    assert ignoring_kw2(1, 2, kw1=3, kw2="ignored") == (1, 2, 3, 4)
|
|
|
|
|
def test_memory_func_with_signature(tmpdir):
    """A function with type annotations can be cached and called."""
    cached = Memory(location=tmpdir.strpath, verbose=0).cache(func_with_signature)
    assert cached(1, 2.0) == 3.0
|
|
|
|
|
def _setup_toy_cache(tmpdir, num_inputs=10):
    """Fill a fresh cache with ``num_inputs`` calls returning 1000 characters.

    Returns ``(memory, full_hashdirs, cached_func)``. ``full_hashdirs`` holds
    the cache directory of each call, in input order.
    """
    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache()
    def get_1000_bytes(arg):
        return "a" * 1000

    inputs = range(num_inputs)
    for arg in inputs:
        get_1000_bytes(arg)

    func_dir = os.path.join(
        get_1000_bytes.store_backend.location,
        _build_func_identifier(get_1000_bytes),
    )
    full_hashdirs = [
        os.path.join(func_dir, get_1000_bytes._get_args_id(arg)) for arg in inputs
    ]
    return memory, full_hashdirs, get_1000_bytes
|
|
|
|
|
def test__get_items(tmpdir):
    """get_items reports path, size and last access of every cached call."""
    memory, expected_hash_dirs, _ = _setup_toy_cache(tmpdir)
    items = memory.store_backend.get_items()

    hash_dirs = [item.path for item in items]
    assert set(hash_dirs) == set(expected_hash_dirs)

    def get_files_size(directory):
        return sum(
            os.path.getsize(os.path.join(directory, name))
            for name in os.listdir(directory)
        )

    # An item's size is the total size of the files in its directory.
    expected_sizes = [get_files_size(directory) for directory in hash_dirs]
    assert [item.size for item in items] == expected_sizes

    # An item's last access is the access time of its output.pkl.
    expected_last_accesses = [
        datetime.datetime.fromtimestamp(
            os.path.getatime(os.path.join(directory, "output.pkl"))
        )
        for directory in hash_dirs
    ]
    assert [item.last_access for item in items] == expected_last_accesses
|
|
|
|
|
def test__get_items_to_delete(tmpdir):
    """Check which cache items _get_items_to_delete selects for a byte limit."""
    # An empty cache never has anything to delete.
    memory, _, _ = _setup_toy_cache(tmpdir, num_inputs=0)
    items_to_delete = memory.store_backend._get_items_to_delete("1K")
    assert items_to_delete == []

    memory, expected_hash_cachedirs, _ = _setup_toy_cache(tmpdir)
    items = memory.store_backend.get_items()

    # Each cached call stores a 1000-character string, so a 2K limit keeps
    # only one of them.
    items_to_delete = memory.store_backend._get_items_to_delete("2K")
    nb_hashes = len(expected_hash_cachedirs)
    assert set.issubset(set(items_to_delete), set(items))
    assert len(items_to_delete) == nb_hashes - 1

    # The limit can also be an integer number of bytes.
    items_to_delete_2048b = memory.store_backend._get_items_to_delete(2048)
    assert sorted(items_to_delete) == sorted(items_to_delete_2048b)

    # A limit larger than the whole cache deletes nothing.
    items_to_delete_empty = memory.store_backend._get_items_to_delete("1M")
    assert items_to_delete_empty == []

    # A limit smaller than any single item deletes every item. This must be
    # an equality check: ``assert a, b`` would only test that ``a`` is
    # non-empty.
    bytes_limit_too_small = 500
    items_to_delete_500b = memory.store_backend._get_items_to_delete(
        bytes_limit_too_small
    )
    assert set(items_to_delete_500b) == set(items)

    # Least recently accessed items go first: every deleted item was
    # accessed no later than any surviving item.
    items_to_delete_6000b = memory.store_backend._get_items_to_delete(6000)
    surviving_items = set(items).difference(items_to_delete_6000b)

    assert max(ci.last_access for ci in items_to_delete_6000b) <= min(
        ci.last_access for ci in surviving_items
    )
|
|
|
|
|
def test_memory_reduce_size_bytes_limit(tmpdir):
    """reduce_size(bytes_limit=...) trims the cache down to the byte limit."""
    memory, _, _ = _setup_toy_cache(tmpdir)
    ref_cache_items = memory.store_backend.get_items()

    def current_items():
        return memory.store_backend.get_items()

    # No limit, or a limit above the cache size: nothing is removed.
    memory.reduce_size()
    assert sorted(current_items()) == sorted(ref_cache_items)

    memory.reduce_size(bytes_limit="1M")
    assert sorted(current_items()) == sorted(ref_cache_items)

    # A 3K limit keeps two of the original items.
    memory.reduce_size(bytes_limit="3K")
    cache_items = current_items()
    assert set.issubset(set(cache_items), set(ref_cache_items))
    assert len(cache_items) == 2

    # A limit smaller than one item empties the cache.
    memory.reduce_size(bytes_limit=500)
    assert current_items() == []
|
|
|
|
|
def test_memory_reduce_size_items_limit(tmpdir):
    """reduce_size(items_limit=...) keeps at most that many cache items."""
    memory, _, _ = _setup_toy_cache(tmpdir)
    ref_cache_items = memory.store_backend.get_items()

    def current_items():
        return memory.store_backend.get_items()

    # No limit, or a limit equal to the number of items: nothing removed.
    memory.reduce_size()
    assert sorted(current_items()) == sorted(ref_cache_items)

    memory.reduce_size(items_limit=10)
    assert sorted(current_items()) == sorted(ref_cache_items)

    # Limit to two items.
    memory.reduce_size(items_limit=2)
    cache_items = current_items()
    assert set.issubset(set(cache_items), set(ref_cache_items))
    assert len(cache_items) == 2

    # A zero limit empties the cache.
    memory.reduce_size(items_limit=0)
    assert current_items() == []
|
|
|
|
|
def test_memory_reduce_size_age_limit(tmpdir):
    """reduce_size(age_limit=...) removes items older than the given age.

    Uses the module-level ``datetime`` and ``time`` imports.
    """
    memory, _, put_cache = _setup_toy_cache(tmpdir)
    ref_cache_items = memory.store_backend.get_items()

    # By default nothing is removed.
    memory.reduce_size()
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # Every item is younger than a day, so none is removed.
    memory.reduce_size(age_limit=datetime.timedelta(days=1))
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # After one second, add two new items: a one-second age limit keeps two
    # items, which are not all from the original set.
    time.sleep(1)
    put_cache(-1)
    put_cache(-2)
    memory.reduce_size(age_limit=datetime.timedelta(seconds=1))
    cache_items = memory.store_backend.get_items()
    assert not set.issubset(set(cache_items), set(ref_cache_items))
    assert len(cache_items) == 2

    # A negative age limit is rejected.
    with pytest.raises(ValueError, match="has to be a positive"):
        memory.reduce_size(age_limit=datetime.timedelta(seconds=-1))

    # A zero age limit (after a tiny sleep) removes everything.
    time.sleep(0.001)
    memory.reduce_size(age_limit=datetime.timedelta(seconds=0))
    cache_items = memory.store_backend.get_items()
    assert cache_items == []
|
|
|
|
|
def test_memory_clear(tmpdir):
    """Memory.clear empties the store, including stored function code."""
    memory, _, g = _setup_toy_cache(tmpdir)
    memory.clear()

    assert os.listdir(memory.store_backend.location) == []

    # The stored code of ``g`` was removed too, so the check fails.
    assert not g._check_previous_func_code(stacklevel=4)
|
|
|
|
|
def fast_func_with_complex_output():
    """Quickly return a large list: 1000 references to one 1000-char string."""
    chunk = "a" * 1000
    return [chunk] * 1000
|
|
|
|
|
def fast_func_with_conditional_complex_output(complex_output=True):
    """Build a 100000-entry ``{str(i): i}`` dict.

    Return it when ``complex_output`` is true, else a short string. The dict
    is built in both cases so the two branches do the same amount of work.
    """
    mapping = {str(key): key for key in range(10**5)}
    if not complex_output:
        return "simple output"
    return mapping
|
|
|
|
|
@with_multiprocessing
def test_cached_function_race_condition_when_persisting_output(tmpdir, capfd):
    """Concurrent calls to one cached function must not report load errors."""
    # Several workers compute and persist the same result concurrently.
    memory = Memory(location=tmpdir.strpath)
    func_cached = memory.cache(fast_func_with_complex_output)

    Parallel(n_jobs=2)(delayed(func_cached)() for _ in range(3))

    stdout, stderr = capfd.readouterr()

    # No worker may fail to load a result persisted by another one.
    exception_msg = "Exception while loading results"
    for stream in (stdout, stderr):
        assert exception_msg not in stream
|
|
|
|
|
@with_multiprocessing
def test_cached_function_race_condition_when_persisting_output_2(tmpdir, capfd):
    """Concurrent calls with mixed arguments must not report load errors."""
    memory = Memory(location=tmpdir.strpath)
    func_cached = memory.cache(fast_func_with_conditional_complex_output)

    # Arguments alternate True, False, True across the three calls.
    Parallel(n_jobs=2)(delayed(func_cached)(i % 2 == 0) for i in range(3))

    stdout, stderr = capfd.readouterr()

    exception_msg = "Exception while loading results"
    for stream in (stdout, stderr):
        assert exception_msg not in stream
|
|
|
|
|
def test_memory_recomputes_after_an_error_while_loading_results(tmpdir, monkeypatch):
    """A corrupted cached result is recomputed with a warning when called,
    and makes ``MemorizedResult.get`` raise KeyError."""
    memory = Memory(location=tmpdir.strpath)

    def func(arg):
        # Sleep briefly so the timestamp of a later call is strictly greater.
        time.sleep(0.01)
        return arg, time.time()

    cached_func = memory.cache(func)
    input_arg = "arg"
    arg, timestamp = cached_func(input_arg)

    # The function ran and its result was cached.
    assert arg == input_arg

    # Corrupt output.pkl so that loading the cached result fails.
    corrupt_single_cache_item(memory)

    # The failed load is reported through ``warn`` and the result recomputed.
    recorded_warnings = monkeypatch_cached_func_warn(cached_func, monkeypatch)
    recomputed_arg, recomputed_timestamp = cached_func(arg)
    assert len(recorded_warnings) == 1
    exception_msg = "Exception while loading results"
    assert exception_msg in recorded_warnings[0]
    assert recomputed_arg == arg
    assert recomputed_timestamp > timestamp

    # A corrupted item reached through call_and_shelve must raise KeyError
    # mentioning the corruption, rather than load.
    corrupt_single_cache_item(memory)
    reference = cached_func.call_and_shelve(arg)
    with raises(KeyError) as excinfo:
        reference.get()
    assert "is corrupted" in str(excinfo.value.args)
|
|
|
|
|
class IncompleteStoreBackend(StoreBackendBase):
    """Backend that implements none of the abstract hooks.

    Instantiating it raises TypeError, as checked by
    test_instanciate_incomplete_store_backend.
    """
|
|
|
|
|
class DummyStoreBackend(StoreBackendBase):
    """A dummy store backend that does nothing.

    Every hook is implemented so the class can be instantiated. ``exists``
    always answers False, ``get_items`` always returns an empty list, and
    every other method is a no-op returning None.
    """

    def _open_item(self, *args, **kwargs):
        """Open an item on store (no-op)."""

    def _item_exists(self, location):
        """Check if an item location exists (no-op)."""

    def _move_item(self, src, dst):
        """Move an item from src to dst in store (no-op)."""

    def create_location(self, location):
        """Create location on store (no-op)."""

    def exists(self, obj):
        """Check if an object exists in the store: never."""
        return False

    def clear_location(self, obj):
        """Clear object on store (no-op)."""

    def get_items(self):
        """Return the items available in cache: always none."""
        return []

    def configure(self, location, *args, **kwargs):
        """Configure the store (no-op)."""
|
|
|
|
|
@parametrize("invalid_prefix", [None, {}, []])
def test_register_invalid_store_backends_key(invalid_prefix):
    """Registering a backend under a non-string name raises ValueError."""
    with raises(ValueError, match=r"Store backend name should be a string*"):
        register_store_backend(invalid_prefix, None)
|
|
|
|
|
def test_register_invalid_store_backends_object():
    """Registering something that is not a StoreBackendBase subclass fails."""
    with raises(ValueError, match=r"Store backend should inherit StoreBackendBase*"):
        register_store_backend("fs", None)
|
|
|
|
|
def test_memory_default_store_backend():
    """An unknown backend name makes Memory raise TypeError."""
    with raises(TypeError, match=r"Unknown location*"):
        Memory(location="/tmp/joblib", backend="unknown")
|
|
|
|
|
def test_warning_on_unknown_location_type():
    """An unsupported location object triggers a UserWarning."""

    class NonSupportedLocationClass:
        pass

    with warns(UserWarning) as warninfo:
        _store_backend_factory("local", location=NonSupportedLocationClass())

    expected_message = (
        "Instantiating a backend using a "
        "NonSupportedLocationClass as a location is not "
        "supported by joblib"
    )
    assert expected_message in str(warninfo[0].message)
|
|
|
|
|
def test_instanciate_incomplete_store_backend():
    """A registered backend missing abstract methods cannot be instantiated."""
    backend_name = "isb"
    register_store_backend(backend_name, IncompleteStoreBackend)
    assert _STORE_BACKENDS.get(backend_name) is IncompleteStoreBackend

    expected = (
        r"Can't instantiate abstract class IncompleteStoreBackend "
        "(without an implementation for|with) abstract methods*"
    )
    with raises(TypeError, match=expected):
        _store_backend_factory(backend_name, "fake_location")
|
|
|
|
|
def test_dummy_store_backend():
    """A registered, complete backend is built by the factory."""
    backend_name = "dsb"
    register_store_backend(backend_name, DummyStoreBackend)
    assert _STORE_BACKENDS.get(backend_name) is DummyStoreBackend

    assert isinstance(
        _store_backend_factory(backend_name, "dummy_location"), DummyStoreBackend
    )
|
|
|
|
|
def test_instanciate_store_backend_with_pathlib_path():
    """A ``pathlib.Path`` location is stored on the backend as a plain string."""
    path = pathlib.Path("some_folder")
    try:
        # Instantiate inside the ``try`` so that the folder the local backend
        # may create in the current working directory is always removed,
        # even when the factory call itself fails.
        backend_obj = _store_backend_factory("local", path)
        assert backend_obj.location == "some_folder"
    finally:
        shutil.rmtree(path, ignore_errors=True)
|
|
|
|
|
def test_filesystem_store_backend_repr(tmpdir):
    """The string form of FileSystemStoreBackend shows its location, both
    before and after ``configure`` is called."""
    template = 'FileSystemStoreBackend(location="{location}")'
    location = tmpdir.strpath

    unconfigured = FileSystemStoreBackend()
    assert unconfigured.location is None
    # Exercise __repr__ on a backend whose location is still unset.
    repr(unconfigured)
    assert str(unconfigured) == template.format(location=None)

    unconfigured.configure(location)
    assert str(unconfigured) == template.format(location=location)
    # ... and again once a location has been configured.
    repr(unconfigured)
|
|
|
|
|
def test_memory_objects_repr(tmpdir):
    """Check the string form of Memory, MemorizedFunc and MemorizedResult."""

    def my_func(a, b):
        return a + b

    memory = Memory(location=tmpdir.strpath, verbose=0)
    memorized_func = memory.cache(my_func)
    location = memory.store_backend.location

    expected_func_str = "MemorizedFunc(func={func}, location={location})".format(
        func=my_func, location=location
    )
    assert str(memorized_func) == expected_func_str

    memorized_result = memorized_func.call_and_shelve(42, 42)
    result_template = (
        'MemorizedResult(location="{location}", func="{func}", args_id="{args_id}")'
    )
    expected_result_str = result_template.format(
        location=location,
        func=memorized_result.func_id,
        args_id=memorized_result.args_id,
    )
    assert str(memorized_result) == expected_result_str

    expected_memory_str = "Memory(location={location})".format(location=location)
    assert str(memory) == expected_memory_str
|
|
|
|
|
def test_memorized_result_pickle(tmpdir):
    """A MemorizedResult can be pickled and unpickled without losing its
    location, function, argument hash or string representation."""
    memory = Memory(location=tmpdir.strpath)

    @memory.cache
    def g(x):
        return x**2

    shelved = g.call_and_shelve(4)
    restored = pickle.loads(pickle.dumps(shelved))

    assert shelved.store_backend.location == restored.store_backend.location
    assert shelved.func == restored.func
    assert shelved.args_id == restored.args_id
    assert str(shelved) == str(restored)
|
|
|
|
|
def compare(left, right, ignored_attrs=None):
    """Assert that two objects carry the same instance attributes.

    Both objects must expose exactly the same attribute names (as reported by
    ``vars``). Every attribute whose name is not in ``ignored_attrs`` must also
    compare equal. Ignored attributes still have to be present on both sides.
    """
    skipped = [] if ignored_attrs is None else ignored_attrs
    left_state, right_state = vars(left), vars(right)

    # dict key views compare like sets, so attribute order is irrelevant.
    assert left_state.keys() == right_state.keys()
    for name, left_value in left_state.items():
        if name not in skipped:
            assert left_value == right_state[name]
|
|
|
|
|
@pytest.mark.parametrize(
    "memory_kwargs",
    [
        {"compress": 3, "verbose": 2},
        {"mmap_mode": "r", "verbose": 5, "backend_options": {"parameter": "unused"}},
    ],
)
def test_memory_pickle_dump_load(tmpdir, memory_kwargs):
    """Memory, MemorizedFunc and MemorizedResult survive a pickle round-trip
    with equal attributes and an equal joblib hash."""
    # store_backend is compared on its own by the helper; the two remaining
    # names are excluded from the attribute comparison (presumably because
    # they are not reproduced exactly by unpickling -- NOTE(review): confirm).
    skipped_attrs = {"store_backend", "timestamp", "_func_code_id"}

    def assert_roundtrip_equal(obj):
        reloaded = pickle.loads(pickle.dumps(obj))
        compare(obj.store_backend, reloaded.store_backend)
        compare(obj, reloaded, ignored_attrs=skipped_attrs)
        assert hash(obj) == hash(reloaded)

    memory = Memory(location=tmpdir.strpath, **memory_kwargs)
    assert_roundtrip_equal(memory)

    func_cached = memory.cache(f)
    assert_roundtrip_equal(func_cached)

    memorized_result = func_cached.call_and_shelve(1)
    assert_roundtrip_equal(memorized_result)
|
|
|
|
|
def test_info_log(tmpdir, caplog):
    """A "Querying" INFO record is logged at verbose=20 but not at verbose=0."""
    caplog.set_level(logging.INFO)
    x = 3

    loud_memory = Memory(location=tmpdir.strpath, verbose=20)

    @loud_memory.cache
    def f(x):
        return x**2

    f(x)
    assert "Querying" in caplog.text
    caplog.clear()

    # Same cache location, but with logging turned off.
    quiet_memory = Memory(location=tmpdir.strpath, verbose=0)

    @quiet_memory.cache
    def f(x):
        return x**2

    f(x)
    assert "Querying" not in caplog.text
    caplog.clear()
|
|
|
|
|
class TestCacheValidationCallback:
    "Tests on parameter `cache_validation_callback`"

    def foo(self, x, d, delay=None):
        """Return ``2 * x``, flagging the call in ``d`` and optionally sleeping.

        ``d["run"]`` is set to True so a test can tell whether the function
        body actually executed or a cached value was returned instead.
        """
        d["run"] = True
        if delay is not None:
            time.sleep(delay)
        return x * 2

    def test_invalid_cache_validation_callback(self, memory):
        "Test invalid values for `cache_validation_callback`"
        match = "cache_validation_callback needs to be callable. Got True."
        with pytest.raises(ValueError, match=match):
            memory.cache(cache_validation_callback=True)

    @pytest.mark.parametrize("consider_cache_valid", [True, False])
    def test_constant_cache_validation_callback(self, memory, consider_cache_valid):
        "Test a callback that always returns the same validity verdict"
        f = memory.cache(
            self.foo,
            cache_validation_callback=lambda _: consider_cache_valid,
            ignore=["d"],
        )

        d1, d2 = {"run": False}, {"run": False}
        assert f(2, d1) == 4
        assert f(2, d2) == 4

        # The first call always executes; the second one is served from the
        # cache only when the callback declares the cached result valid.
        assert d1["run"]
        assert d2["run"] != consider_cache_valid

    def test_memory_only_cache_long_run(self, memory):
        "Test cache validity based on run duration."

        def cache_validation_callback(metadata):
            # Only results that took more than 0.1s to compute are reused.
            return metadata["duration"] > 0.1

        f = memory.cache(
            self.foo, cache_validation_callback=cache_validation_callback, ignore=["d"]
        )

        # Fast runs are not reused: both calls execute the function body.
        d1, d2 = {"run": False}, {"run": False}
        assert f(2, d1, delay=0) == 4
        assert f(2, d2, delay=0) == 4
        assert d1["run"]
        assert d2["run"]

        # Slow runs (a different cache entry, since ``delay`` is part of the
        # key) are reused: the second call returns the stored result.
        d1, d2 = {"run": False}, {"run": False}
        assert f(2, d1, delay=0.2) == 4
        assert f(2, d2, delay=0.2) == 4
        assert d1["run"]
        assert not d2["run"]

    def test_memory_expires_after(self, memory):
        "Test expiry of old cached results"

        f = memory.cache(
            self.foo, cache_validation_callback=expires_after(seconds=0.3), ignore=["d"]
        )

        d1, d2, d3 = {"run": False}, {"run": False}, {"run": False}
        assert f(2, d1) == 4
        # Within the 0.3s validity window the cached value is reused ...
        assert f(2, d2) == 4
        time.sleep(0.5)
        # ... but once it has expired the function runs again.
        assert f(2, d3) == 4

        assert d1["run"]
        assert not d2["run"]
        assert d3["run"]
|
|
|
|
|
class TestMemorizedFunc:
    "Tests for the MemorizedFunc and NotMemorizedFunc classes"

    @staticmethod
    def f(x, counter):
        """Increment and return the number of calls recorded for ``x``."""
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    def test_call_method_memorized(self, memory):
        "Test calling the function through a MemorizedFunc"

        f = memory.cache(self.f, ignore=["counter"])

        counter = {}
        # The second call is answered from the cache, so the count stays at 1.
        assert f(2, counter) == 1
        assert f(2, counter) == 1

        # ``call`` executes the function again despite the cached result.
        x, meta = f.call(2, counter)
        assert x == 2, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by MemorizedFunc.call."
        )

    def test_call_method_not_memorized(self, memory):
        "Test calling the function through a NotMemorizedFunc"

        f = NotMemorizedFunc(self.f)

        counter = {}
        # Without memoization every call executes the function.
        assert f(2, counter) == 1
        assert f(2, counter) == 2

        x, meta = f.call(2, counter)
        assert x == 3, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by NotMemorizedFunc.call."
        )
|
|
|
|
|
@with_numpy
@parametrize(
    "location",
    [
        "test_cache_dir",
        pathlib.Path("test_cache_dir"),
        pathlib.Path("test_cache_dir").resolve(),
    ],
)
def test_memory_creates_gitignore(location):
    """Test that using the memory object automatically creates a `.gitignore` file
    within the new cache directory."""
    cache_dir = pathlib.Path(location)
    try:
        # Everything that may create the cache directory runs inside the
        # ``try`` so the directory is removed even if caching itself fails.
        mem = Memory(location)
        arr = np.asarray([[1, 2, 3], [4, 5, 6]])
        costly_operation = mem.cache(np.square)
        costly_operation(arr)

        gitignore_file_content = "# Created by joblib automatically.\n*\n"
        with open(cache_dir / ".gitignore") as f:
            assert gitignore_file_content == f.read()
    finally:
        shutil.rmtree(cache_dir, ignore_errors=True)
|
|