"""Integration tests for the HTTP API exposed by ``main.app``.

All tests share one module-scoped ``TestClient`` whose application is
reloaded with ``DUCKDB_API_DB_PATH=":memory:"``, so every test class talks to
the same in-memory database.  Tests that create tables therefore either use
``CREATE OR REPLACE`` or drop the table first to stay order-independent.
"""

import json
import os
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import patch

import pytest

# Skip the whole module when the optional data libraries are missing.
pd = pytest.importorskip("pandas")
pa = pytest.importorskip("pyarrow")
pa_ipc = pytest.importorskip("pyarrow.ipc")

from fastapi.testclient import TestClient

import main  # Import main to reload and access config


def _execute_ok(client: TestClient, sql: str) -> None:
    """POST *sql* to ``/execute`` and fail immediately if it is rejected.

    Used for setup/precondition statements so that a broken setup is reported
    at its source instead of as a confusing assertion failure later on.
    """
    response = client.post("/execute", json={"sql": sql})
    assert response.status_code == 200, (
        f"Setup statement failed ({sql!r}): {response.text}"
    )


# --- Test Fixtures ---

@pytest.fixture(scope="module")
def client():
    """Yield a ``TestClient`` bound to a freshly reloaded ``main`` module.

    ``main`` is reloaded while ``DUCKDB_API_DB_PATH`` is patched to
    ``":memory:"`` so the reloaded module sees that value.  After the tests
    finish, everything inside ``main.TEMP_EXPORT_DIR`` is removed on a
    best-effort basis.
    """
    with patch.dict(os.environ, {"DUCKDB_API_DB_PATH": ":memory:"}):
        import importlib

        importlib.reload(main)
        main.TEMP_EXPORT_DIR.mkdir(parents=True, exist_ok=True)
        print(f"TestClient using temp export dir: {main.TEMP_EXPORT_DIR}")
        with TestClient(main.app) as c:
            yield c
        print(f"Cleaning up test export dir: {main.TEMP_EXPORT_DIR}")
        for item in main.TEMP_EXPORT_DIR.iterdir():
            # Cleanup is best-effort: report filesystem errors, keep going.
            try:
                if item.is_file():
                    item.unlink()
                elif item.is_dir():
                    shutil.rmtree(item)
            except OSError as e:
                print(f"Error cleaning up {item}: {e}")


# --- Test Classes ---

class TestHealth:
    """Tests for the ``/health`` endpoint."""

    def test_health_check(self, client: TestClient):
        """``GET /health`` returns 200 with an ``ok`` status payload."""
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok", "message": None}


class TestExecution:
    """Tests for the ``/execute`` endpoint (statements without result sets)."""

    def test_execute_create(self, client: TestClient):
        """Creating a table succeeds once and fails with 400 the second time."""
        # The database is shared across the module; make sure the table does
        # not already exist so this test does not depend on execution order.
        _execute_ok(client, "DROP TABLE IF EXISTS test_table;")

        response = client.post(
            "/execute",
            json={"sql": "CREATE TABLE test_table(id INTEGER, name VARCHAR);"},
        )
        assert response.status_code == 200
        assert response.json() == {"status": "success", "message": None}

        response_fail = client.post(
            "/execute",
            json={"sql": "CREATE TABLE test_table(id INTEGER);"},
        )
        assert response_fail.status_code == 400

    def test_execute_insert(self, client: TestClient):
        """A literal ``INSERT`` adds exactly one row."""
        _execute_ok(
            client,
            "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);",
        )
        response = client.post(
            "/execute",
            json={"sql": "INSERT INTO test_table VALUES (1, 'Alice')"},
        )
        assert response.status_code == 200

        query_response = client.post(
            "/query/fetchall",
            json={"sql": "SELECT COUNT(*) FROM test_table"},
        )
        assert query_response.status_code == 200
        assert query_response.json() == [[1]]

    def test_execute_insert_params(self, client: TestClient):
        """A parameterized ``INSERT`` binds the supplied values."""
        _execute_ok(
            client,
            "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);",
        )
        response = client.post(
            "/execute",
            json={
                "sql": "INSERT INTO test_table VALUES (?, ?)",
                "parameters": [2, "Bob"],
            },
        )
        assert response.status_code == 200

        query_response = client.post(
            "/query/fetchall",
            json={"sql": "SELECT * FROM test_table WHERE id = 2"},
        )
        assert query_response.status_code == 200
        assert query_response.json() == [[2, "Bob"]]

    def test_execute_invalid_sql(self, client: TestClient):
        """Unparseable SQL yields 400 with a parser error in ``detail``."""
        response = client.post("/execute", json={"sql": "INVALID SQL STATEMENT"})
        assert response.status_code == 400
        assert "Parser Error" in response.json()["detail"]


class TestQuerying:
    """Tests for ``/query/fetchall`` and ``/query/dataframe``."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create and populate ``query_test`` with three rows."""
        _execute_ok(
            client,
            "CREATE OR REPLACE TABLE query_test(id INTEGER, val VARCHAR)",
        )
        _execute_ok(
            client,
            "INSERT INTO query_test VALUES (1, 'one'), (2, 'two'), (3, 'three')",
        )

    def test_query_fetchall(self, client: TestClient):
        """All rows come back as a list of lists in the requested order."""
        response = client.post(
            "/query/fetchall",
            json={"sql": "SELECT * FROM query_test ORDER BY id"},
        )
        assert response.status_code == 200
        assert response.json() == [[1, 'one'], [2, 'two'], [3, 'three']]

    def test_query_fetchall_params(self, client: TestClient):
        """Positional parameters are applied to the query."""
        response = client.post(
            "/query/fetchall",
            json={
                "sql": "SELECT * FROM query_test WHERE id > ? ORDER BY id",
                "parameters": [1],
            },
        )
        assert response.status_code == 200
        assert response.json() == [[2, 'two'], [3, 'three']]

    def test_query_fetchall_empty(self, client: TestClient):
        """A query matching nothing returns an empty list."""
        response = client.post(
            "/query/fetchall",
            json={"sql": "SELECT * FROM query_test WHERE id > 100"},
        )
        assert response.status_code == 200
        assert response.json() == []

    def test_query_dataframe(self, client: TestClient):
        """The dataframe endpoint returns column names plus record dicts."""
        response = client.post(
            "/query/dataframe",
            json={"sql": "SELECT * FROM query_test ORDER BY id"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["columns"] == ["id", "val"]
        assert data["records"] == [
            {"id": 1, "val": "one"},
            {"id": 2, "val": "two"},
            {"id": 3, "val": "three"},
        ]

    def test_query_dataframe_invalid_sql(self, client: TestClient):
        """Referencing an unknown column yields 400 with a binder error."""
        response = client.post(
            "/query/dataframe",
            json={"sql": "SELECT non_existent FROM query_test"},
        )
        assert response.status_code == 400
        assert "Binder Error" in response.json()["detail"]


class TestStreaming:
    """Tests for the Arrow and JSONL streaming endpoints."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create ``stream_test`` with ids 0..9 and ``category = id % 5``."""
        _execute_ok(
            client,
            "CREATE OR REPLACE TABLE stream_test AS "
            "SELECT range AS id, range % 5 AS category FROM range(10)",
        )

    def test_stream_arrow(self, client: TestClient):
        """The Arrow stream decodes to a 10-row table with both columns."""
        response = client.post(
            "/query/stream/arrow",
            json={"sql": "SELECT * FROM stream_test"},
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
        if not response.content:
            pytest.fail("Arrow stream response content is empty")

        try:
            reader = pa_ipc.open_stream(response.content)
            table = reader.read_all()
        except pa.ArrowInvalid as e:
            pytest.fail(f"Failed to read Arrow stream: {e}")

        assert table.num_rows == 10
        assert table.column_names == ["id", "category"]
        assert table.column('id').to_pylist() == list(range(10))

    def test_stream_arrow_empty(self, client: TestClient):
        """An empty result is either a zero-row stream or a zero-byte body."""
        response = client.post(
            "/query/stream/arrow",
            json={"sql": "SELECT * FROM stream_test WHERE id < 0"},
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"

        try:
            reader = pa_ipc.open_stream(response.content)
            table = reader.read_all()
            assert table.num_rows == 0
        except pa.ArrowInvalid as e:
            # Only an entirely empty body is an acceptable reason for the
            # stream to be unreadable; a valid zero-row stream still carries
            # schema bytes and is handled by the branch above.
            print(f"Received ArrowInvalid for empty stream, which is acceptable: {e}")
            assert response.content == b''

    def test_stream_jsonl(self, client: TestClient):
        """The JSONL stream yields one JSON object per row."""
        response = client.post(
            "/query/stream/jsonl",
            json={"sql": "SELECT * FROM stream_test ORDER BY id"},
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"

        records = [
            json.loads(line)
            for line in response.text.strip().splitlines()
            if line
        ]
        assert len(records) == 10
        assert records[0] == {"id": 0, "category": 0}
        assert records[9] == {"id": 9, "category": 4}

    def test_stream_jsonl_empty(self, client: TestClient):
        """An empty result produces an empty JSONL body."""
        response = client.post(
            "/query/stream/jsonl",
            json={"sql": "SELECT * FROM stream_test WHERE id < 0"},
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"
        assert response.text.strip() == ""


class TestExportDownload:
    """Tests for the file-download export endpoints."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create and populate ``export_table`` with three priced items."""
        _execute_ok(
            client,
            "CREATE OR REPLACE TABLE export_table"
            "(id INTEGER, name VARCHAR, price DECIMAL(5,2))",
        )
        _execute_ok(
            client,
            "INSERT INTO export_table VALUES "
            "(1, 'Apple', 0.50), (2, 'Banana', 0.30), (3, 'Orange', 0.75)",
        )

    @pytest.mark.parametrize(
        "endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn",
        [
            ("csv", "text/csv", ".csv",
             lambda c: b"id,name,price\n1,Apple,0.50\n" in c),
            ("parquet", "application/vnd.apache.parquet", ".parquet",
             lambda c: c.startswith(b"PAR1")),
            # JSON is checked structurally (a top-level array); JSONL is
            # checked for the first record's fields plus a line separator.
            ("json", "application/json", ".json",
             lambda c: c.strip().startswith(b'[') and c.strip().endswith(b']')),
            ("jsonl", "application/jsonl", ".jsonl",
             lambda c: b'"id":1' in c and b'"name":"Apple"' in c and b'\n' in c),
        ],
    )
    def test_export_data(self, client: TestClient, endpoint_suffix,
                         expected_content_type, expected_filename_ext,
                         validation_fn, tmp_path):
        """Each export format downloads with the right headers and content.

        The export is exercised twice: once with a table name as the source
        and once with a ``SELECT`` statement as the source.
        """
        endpoint = f"/export/data/{endpoint_suffix}"

        # --- Table source ---
        payload = {"source": "export_table"}
        if endpoint_suffix == 'csv':
            payload['options'] = {'HEADER': True}
        response = client.post(endpoint, json=payload)
        assert response.status_code == 200, f"Request to {endpoint} failed: {response.text}"
        assert response.headers["content-type"].startswith(expected_content_type)
        assert "content-disposition" in response.headers
        assert (
            f'filename="export_export_table{expected_filename_ext}"'
            in response.headers["content-disposition"]
        )

        # Persist the download and validate what actually landed on disk.
        downloaded_path = tmp_path / f"downloaded{expected_filename_ext}"
        downloaded_path.write_bytes(response.content)
        assert downloaded_path.exists()
        assert validation_fn(downloaded_path.read_bytes()), (
            f"Validation failed for {endpoint_suffix}"
        )

        # --- Query source ---
        payload = {
            "source": "SELECT id, name FROM export_table WHERE price > 0.40 ORDER BY id"
        }
        response = client.post(endpoint, json=payload)
        assert response.status_code == 200
        assert (
            f'filename="export_query{expected_filename_ext}"'
            in response.headers["content-disposition"]
        )
        assert len(response.content) > 0

    def test_export_database(self, client: TestClient, tmp_path):
        """The database export is a zip holding schema, load script and data."""
        _execute_ok(client, "CREATE TABLE IF NOT EXISTS another_table(x int)")

        response = client.post("/export/database")
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/zip"
        assert "content-disposition" in response.headers
        disposition = response.headers["content-disposition"]
        assert disposition.startswith("attachment; filename=")
        assert 'filename="in_memory_db_export.zip"' in disposition

        zip_path = tmp_path / "db_export.zip"
        zip_path.write_bytes(response.content)
        assert zip_path.exists()

        with zipfile.ZipFile(zip_path, 'r') as z:
            names = z.namelist()
        print(f"Zip contents: {names}")
        assert "schema.sql" in names
        assert "load.sql" in names
        assert any(name.startswith("export_table") for name in names)
        assert any(name.startswith("another_table") for name in names)


class TestExtensions:
    """Tests for the extension install/load endpoints."""

    def test_install_extension_fail(self, client: TestClient):
        """Installing an unknown extension reports an error status."""
        response = client.post(
            "/extensions/install",
            json={"extension_name": "nonexistent_dummy_ext"},
        )
        assert response.status_code >= 400
        detail = response.json()["detail"]
        assert "Error during install" in detail or "Failed to download" in detail

    def test_load_extension_fail(self, client: TestClient):
        """Loading an unknown extension yields 400 with a not-found message."""
        response = client.post(
            "/extensions/load",
            json={"extension_name": "nonexistent_dummy_ext"},
        )
        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "Error loading extension" in detail
        assert "not found" in detail.lower()

    @pytest.mark.skip(reason="Requires httpfs extension to be available for install/load")
    def test_install_and_load_httpfs(self, client: TestClient):
        """Installing and then loading ``httpfs`` both report success."""
        install_response = client.post(
            "/extensions/install",
            json={"extension_name": "httpfs"},
        )
        assert install_response.status_code == 200
        assert install_response.json()["status"] == "success"

        load_response = client.post(
            "/extensions/load",
            json={"extension_name": "httpfs"},
        )
        assert load_response.status_code == 200
        assert load_response.json()["status"] == "success"