|
"""This test for the LFW require medium-size data downloading and processing |
|
|
|
If the data has not been already downloaded by running the examples, |
|
the tests won't run (skipped). |
|
|
|
If the test are run, the first execution will be long (typically a bit |
|
more than a couple of minutes) but as the dataset loader is leveraging |
|
joblib, successive runs will be fast (less than 200ms). |
|
""" |
|
|
|
import random |
|
from functools import partial |
|
|
|
import numpy as np |
|
import pytest |
|
|
|
from sklearn.datasets import fetch_lfw_pairs, fetch_lfw_people |
|
from sklearn.datasets.tests.test_common import check_return_X_y |
|
from sklearn.utils._testing import assert_array_equal |
|
|
|
FAKE_NAMES = [ |
|
"Abdelatif_Smith", |
|
"Abhati_Kepler", |
|
"Camara_Alvaro", |
|
"Chen_Dupont", |
|
"John_Lee", |
|
"Lin_Bauman", |
|
"Onur_Lopez", |
|
] |
|
|
|
|
|
@pytest.fixture(scope="module") |
|
def mock_empty_data_home(tmp_path_factory): |
|
data_dir = tmp_path_factory.mktemp("scikit_learn_empty_test") |
|
|
|
yield data_dir |
|
|
|
|
|
@pytest.fixture(scope="module") |
|
def mock_data_home(tmp_path_factory): |
|
"""Test fixture run once and common to all tests of this module""" |
|
Image = pytest.importorskip("PIL.Image") |
|
|
|
data_dir = tmp_path_factory.mktemp("scikit_learn_lfw_test") |
|
lfw_home = data_dir / "lfw_home" |
|
lfw_home.mkdir(parents=True, exist_ok=True) |
|
|
|
random_state = random.Random(42) |
|
np_rng = np.random.RandomState(42) |
|
|
|
|
|
counts = {} |
|
for name in FAKE_NAMES: |
|
folder_name = lfw_home / "lfw_funneled" / name |
|
folder_name.mkdir(parents=True, exist_ok=True) |
|
|
|
n_faces = np_rng.randint(1, 5) |
|
counts[name] = n_faces |
|
for i in range(n_faces): |
|
file_path = folder_name / (name + "_%04d.jpg" % i) |
|
uniface = np_rng.randint(0, 255, size=(250, 250, 3)) |
|
img = Image.fromarray(uniface.astype(np.uint8)) |
|
img.save(file_path) |
|
|
|
|
|
(lfw_home / "lfw_funneled" / ".test.swp").write_bytes( |
|
b"Text file to be ignored by the dataset loader." |
|
) |
|
|
|
|
|
with open(lfw_home / "pairsDevTrain.txt", "wb") as f: |
|
f.write(b"10\n") |
|
more_than_two = [name for name, count in counts.items() if count >= 2] |
|
for i in range(5): |
|
name = random_state.choice(more_than_two) |
|
first, second = random_state.sample(range(counts[name]), 2) |
|
f.write(("%s\t%d\t%d\n" % (name, first, second)).encode()) |
|
|
|
for i in range(5): |
|
first_name, second_name = random_state.sample(FAKE_NAMES, 2) |
|
first_index = np_rng.choice(np.arange(counts[first_name])) |
|
second_index = np_rng.choice(np.arange(counts[second_name])) |
|
f.write( |
|
( |
|
"%s\t%d\t%s\t%d\n" |
|
% (first_name, first_index, second_name, second_index) |
|
).encode() |
|
) |
|
|
|
(lfw_home / "pairsDevTest.txt").write_bytes( |
|
b"Fake place holder that won't be tested" |
|
) |
|
(lfw_home / "pairs.txt").write_bytes(b"Fake place holder that won't be tested") |
|
|
|
yield data_dir |
|
|
|
|
|
def test_load_empty_lfw_people(mock_empty_data_home): |
|
with pytest.raises(OSError): |
|
fetch_lfw_people(data_home=mock_empty_data_home, download_if_missing=False) |
|
|
|
|
|
def test_load_fake_lfw_people(mock_data_home): |
|
lfw_people = fetch_lfw_people( |
|
data_home=mock_data_home, min_faces_per_person=3, download_if_missing=False |
|
) |
|
|
|
|
|
|
|
assert lfw_people.images.shape == (10, 62, 47) |
|
assert lfw_people.data.shape == (10, 2914) |
|
|
|
|
|
assert_array_equal(lfw_people.target, [2, 0, 1, 0, 2, 0, 2, 1, 1, 2]) |
|
|
|
|
|
expected_classes = ["Abdelatif Smith", "Abhati Kepler", "Onur Lopez"] |
|
assert_array_equal(lfw_people.target_names, expected_classes) |
|
|
|
|
|
|
|
lfw_people = fetch_lfw_people( |
|
data_home=mock_data_home, |
|
resize=None, |
|
slice_=None, |
|
color=True, |
|
download_if_missing=False, |
|
) |
|
assert lfw_people.images.shape == (17, 250, 250, 3) |
|
assert lfw_people.DESCR.startswith(".. _labeled_faces_in_the_wild_dataset:") |
|
|
|
|
|
assert_array_equal( |
|
lfw_people.target, [0, 0, 1, 6, 5, 6, 3, 6, 0, 3, 6, 1, 2, 4, 5, 1, 2] |
|
) |
|
assert_array_equal( |
|
lfw_people.target_names, |
|
[ |
|
"Abdelatif Smith", |
|
"Abhati Kepler", |
|
"Camara Alvaro", |
|
"Chen Dupont", |
|
"John Lee", |
|
"Lin Bauman", |
|
"Onur Lopez", |
|
], |
|
) |
|
|
|
|
|
fetch_func = partial( |
|
fetch_lfw_people, |
|
data_home=mock_data_home, |
|
resize=None, |
|
slice_=None, |
|
color=True, |
|
download_if_missing=False, |
|
) |
|
check_return_X_y(lfw_people, fetch_func) |
|
|
|
|
|
def test_load_fake_lfw_people_too_restrictive(mock_data_home): |
|
with pytest.raises(ValueError): |
|
fetch_lfw_people( |
|
data_home=mock_data_home, |
|
min_faces_per_person=100, |
|
download_if_missing=False, |
|
) |
|
|
|
|
|
def test_load_empty_lfw_pairs(mock_empty_data_home): |
|
with pytest.raises(OSError): |
|
fetch_lfw_pairs(data_home=mock_empty_data_home, download_if_missing=False) |
|
|
|
|
|
def test_load_fake_lfw_pairs(mock_data_home): |
|
lfw_pairs_train = fetch_lfw_pairs( |
|
data_home=mock_data_home, download_if_missing=False |
|
) |
|
|
|
|
|
|
|
assert lfw_pairs_train.pairs.shape == (10, 2, 62, 47) |
|
|
|
|
|
assert_array_equal(lfw_pairs_train.target, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) |
|
|
|
|
|
expected_classes = ["Different persons", "Same person"] |
|
assert_array_equal(lfw_pairs_train.target_names, expected_classes) |
|
|
|
|
|
|
|
lfw_pairs_train = fetch_lfw_pairs( |
|
data_home=mock_data_home, |
|
resize=None, |
|
slice_=None, |
|
color=True, |
|
download_if_missing=False, |
|
) |
|
assert lfw_pairs_train.pairs.shape == (10, 2, 250, 250, 3) |
|
|
|
|
|
assert_array_equal(lfw_pairs_train.target, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) |
|
assert_array_equal(lfw_pairs_train.target_names, expected_classes) |
|
|
|
assert lfw_pairs_train.DESCR.startswith(".. _labeled_faces_in_the_wild_dataset:") |
|
|
|
|
|
def test_fetch_lfw_people_internal_cropping(mock_data_home): |
|
"""Check that we properly crop the images. |
|
|
|
Non-regression test for: |
|
https://github.com/scikit-learn/scikit-learn/issues/24942 |
|
""" |
|
|
|
|
|
|
|
slice_ = (slice(70, 195), slice(78, 172)) |
|
lfw = fetch_lfw_people( |
|
data_home=mock_data_home, |
|
min_faces_per_person=3, |
|
download_if_missing=False, |
|
resize=None, |
|
slice_=slice_, |
|
) |
|
assert lfw.images[0].shape == ( |
|
slice_[0].stop - slice_[0].start, |
|
slice_[1].stop - slice_[1].start, |
|
) |
|
|