|
import pytest |
|
import os |
|
import shutil |
|
import tempfile |
|
import zipfile |
|
import json |
|
from pathlib import Path |
|
from typing import List, Dict, Any |
|
from unittest.mock import patch |
|
|
|
pd = pytest.importorskip("pandas") |
|
pa = pytest.importorskip("pyarrow") |
|
pa_ipc = pytest.importorskip("pyarrow.ipc") |
|
|
|
from fastapi.testclient import TestClient |
|
import main |
|
|
|
|
|
@pytest.fixture(scope="module")
def client():
    """Yield a ``TestClient`` for a freshly reloaded ``main`` application.

    ``main`` is reloaded while ``DUCKDB_API_DB_PATH`` is patched to
    ``":memory:"`` so the tests in this module run against an in-memory
    database. ``main.TEMP_EXPORT_DIR`` is created before the client starts
    and its contents are removed on teardown (best effort).

    Yields:
        TestClient: a client shared by every test in this module.
    """
    import importlib

    with patch.dict(os.environ, {"DUCKDB_API_DB_PATH": ":memory:"}):
        # Reload so that module-level configuration in ``main`` sees the patched
        # environment (presumably read at import time — confirm in main.py).
        importlib.reload(main)
        # parents=True: the configured export dir may sit under a missing parent.
        main.TEMP_EXPORT_DIR.mkdir(parents=True, exist_ok=True)
        print(f"TestClient using temp export dir: {main.TEMP_EXPORT_DIR}")
        try:
            with TestClient(main.app) as c:
                yield c
        finally:
            # Runs even if the app's shutdown (TestClient.__exit__) raises.
            export_dir = main.TEMP_EXPORT_DIR
            print(f"Cleaning up test export dir: {export_dir}")
            # The directory may already be gone (e.g. removed on app shutdown);
            # iterdir() would raise FileNotFoundError in that case.
            if export_dir.is_dir():
                for item in export_dir.iterdir():
                    try:
                        if item.is_dir() and not item.is_symlink():
                            shutil.rmtree(item)
                        else:
                            # Regular files and symlinks (including broken ones).
                            item.unlink()
                    except OSError as e:
                        # Best effort: a leftover temp file must not fail the run.
                        print(f"Error cleaning up {item}: {e}")
|
|
|
|
|
|
|
class TestHealth:
    """Checks for the service liveness endpoint."""

    def test_health_check(self, client: TestClient):
        """``GET /health`` answers 200 with status ``ok`` and a null message."""
        expected_body = {"status": "ok", "message": None}
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.json() == expected_body
|
|
|
class TestExecution:
    """Tests for ``POST /execute`` (statements run without returning rows)."""

    @staticmethod
    def _execute_ok(client: TestClient, sql: str) -> None:
        """Run a setup statement through ``/execute`` and fail loudly if rejected."""
        response = client.post("/execute", json={"sql": sql})
        assert response.status_code == 200, f"Setup statement failed: {sql!r} -> {response.text}"

    def test_execute_create(self, client: TestClient):
        """CREATE succeeds once; creating the same table again is a 400."""
        # Other tests in this module share ``test_table`` in the same
        # module-scoped database, so start from a known state instead of
        # relying on this test running first.
        self._execute_ok(client, "DROP TABLE IF EXISTS test_table;")

        response = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER, name VARCHAR);"})
        assert response.status_code == 200
        assert response.json() == {"status": "success", "message": None}

        # A plain CREATE of an existing table must be reported as a client error.
        response_fail = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER);"})
        assert response_fail.status_code == 400

    def test_execute_insert(self, client: TestClient):
        """A literal INSERT succeeds and the row is visible to a later query."""
        self._execute_ok(client, "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);")

        response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (1, 'Alice')"})
        assert response.status_code == 200

        query_response = client.post("/query/fetchall", json={"sql": "SELECT COUNT(*) FROM test_table"})
        assert query_response.status_code == 200
        assert query_response.json() == [[1]]

    def test_execute_insert_params(self, client: TestClient):
        """Positional ``?`` placeholders are bound from the ``parameters`` list."""
        self._execute_ok(client, "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);")

        response = client.post(
            "/execute",
            json={"sql": "INSERT INTO test_table VALUES (?, ?)", "parameters": [2, "Bob"]},
        )
        assert response.status_code == 200

        query_response = client.post("/query/fetchall", json={"sql": "SELECT * FROM test_table WHERE id = 2"})
        assert query_response.status_code == 200
        assert query_response.json() == [[2, "Bob"]]

    def test_execute_invalid_sql(self, client: TestClient):
        """Unparseable SQL is a 400 whose detail mentions the parser error."""
        response = client.post("/execute", json={"sql": "INVALID SQL STATEMENT"})
        assert response.status_code == 400
        assert "Parser Error" in response.json()["detail"]
|
|
|
class TestQuerying:
    """Tests for the row-returning ``/query/fetchall`` and ``/query/dataframe`` endpoints."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create ``query_test`` with rows (1,'one'), (2,'two'), (3,'three') once per class."""
        for sql in (
            "CREATE OR REPLACE TABLE query_test(id INTEGER, val VARCHAR)",
            "INSERT INTO query_test VALUES (1, 'one'), (2, 'two'), (3, 'three')",
        ):
            response = client.post("/execute", json={"sql": sql})
            # Fail here with the server's message rather than letting every
            # test below fail later with an unrelated-looking assertion.
            assert response.status_code == 200, f"Setup failed for {sql!r}: {response.text}"

    def test_query_fetchall(self, client: TestClient):
        """``fetchall`` returns rows as lists of column values."""
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test ORDER BY id"})
        assert response.status_code == 200
        assert response.json() == [[1, 'one'], [2, 'two'], [3, 'three']]

    def test_query_fetchall_params(self, client: TestClient):
        """``fetchall`` binds positional parameters."""
        response = client.post(
            "/query/fetchall",
            json={"sql": "SELECT * FROM query_test WHERE id > ? ORDER BY id", "parameters": [1]},
        )
        assert response.status_code == 200
        assert response.json() == [[2, 'two'], [3, 'three']]

    def test_query_fetchall_empty(self, client: TestClient):
        """An empty result set is an empty JSON list, not an error."""
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > 100"})
        assert response.status_code == 200
        assert response.json() == []

    def test_query_dataframe(self, client: TestClient):
        """``dataframe`` returns column names plus one dict per row."""
        response = client.post("/query/dataframe", json={"sql": "SELECT * FROM query_test ORDER BY id"})
        assert response.status_code == 200
        data = response.json()
        assert data["columns"] == ["id", "val"]
        assert data["records"] == [
            {"id": 1, "val": "one"},
            {"id": 2, "val": "two"},
            {"id": 3, "val": "three"},
        ]

    def test_query_dataframe_invalid_sql(self, client: TestClient):
        """Referencing a missing column is a 400 with a binder error."""
        response = client.post("/query/dataframe", json={"sql": "SELECT non_existent FROM query_test"})
        assert response.status_code == 400
        assert "Binder Error" in response.json()["detail"]
|
|
|
class TestStreaming:
    """Tests for the streaming endpoints (Arrow IPC stream and JSON Lines)."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create ``stream_test`` with ``id`` 0..9 and ``category = id % 5``."""
        sql = "CREATE OR REPLACE TABLE stream_test AS SELECT range AS id, range % 5 AS category FROM range(10)"
        response = client.post("/execute", json={"sql": sql})
        assert response.status_code == 200, f"Setup failed for {sql!r}: {response.text}"

    def test_stream_arrow(self, client: TestClient):
        """The Arrow endpoint returns a parseable IPC stream containing every row."""
        # ORDER BY: SQL gives no row-order guarantee without it, and the
        # to_pylist() comparison below depends on order.
        response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test ORDER BY id"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
        if not response.content:
            pytest.fail("Arrow stream response content is empty")

        try:
            table = pa_ipc.open_stream(response.content).read_all()
        except pa.ArrowInvalid as e:
            pytest.fail(f"Failed to read Arrow stream: {e}")

        assert table.num_rows == 10
        assert table.column_names == ["id", "category"]
        assert table.column('id').to_pylist() == list(range(10))

    def test_stream_arrow_empty(self, client: TestClient):
        """An empty result is either a zero-row Arrow stream or an empty body."""
        response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"

        try:
            table = pa_ipc.open_stream(response.content).read_all()
        except pa.ArrowInvalid as e:
            # Also accepted: the server sends no bytes at all, which pyarrow
            # cannot parse as a stream.
            print(f"Received ArrowInvalid for empty stream, which is acceptable: {e}")
            assert response.content == b''
        else:
            assert table.num_rows == 0

    def test_stream_jsonl(self, client: TestClient):
        """The JSONL endpoint emits one JSON object per row, in query order."""
        response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test ORDER BY id"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"

        records = [json.loads(line) for line in response.text.splitlines() if line.strip()]
        assert len(records) == 10
        assert records[0] == {"id": 0, "category": 0}
        assert records[9] == {"id": 9, "category": 4}

    def test_stream_jsonl_empty(self, client: TestClient):
        """An empty result streams no JSON lines."""
        response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"
        assert response.text.strip() == ""
|
|
|
class TestExportDownload:
    """Tests for the file-download endpoints under ``/export``."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create ``export_table`` with three (id, name, price) rows once per class."""
        for sql in (
            "CREATE OR REPLACE TABLE export_table(id INTEGER, name VARCHAR, price DECIMAL(5,2))",
            "INSERT INTO export_table VALUES (1, 'Apple', 0.50), (2, 'Banana', 0.30), (3, 'Orange', 0.75)",
        ):
            response = client.post("/execute", json={"sql": sql})
            # Surface setup failures here instead of as confusing export errors.
            assert response.status_code == 200, f"Setup failed for {sql!r}: {response.text}"

    @pytest.mark.parametrize(
        "endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn",
        [
            ("csv", "text/csv", ".csv", lambda c: b"id,name,price\n1,Apple,0.50\n" in c),
            ("parquet", "application/vnd.apache.parquet", ".parquet", lambda c: c.startswith(b"PAR1")),
            ("json", "application/json", ".json", lambda c: c.strip().startswith(b'[') and c.strip().endswith(b']')),
            ("jsonl", "application/jsonl", ".jsonl", lambda c: b'"id":1' in c and b'"name":"Apple"' in c and b'\n' in c),
        ],
        # Readable test ids instead of the auto-generated "validation_fn0" etc.
        ids=["csv", "parquet", "json", "jsonl"],
    )
    def test_export_data(self, client: TestClient, endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn, tmp_path):
        """Export a table name and then an ad-hoc query in each supported format.

        ``validation_fn`` takes the raw response bytes and returns True when
        they look like the expected format (header line, magic bytes, ...).
        """
        endpoint = f"/export/data/{endpoint_suffix}"

        # 1) Export a whole table referenced by name.
        payload = {"source": "export_table"}
        if endpoint_suffix == 'csv':
            payload['options'] = {'HEADER': True}

        response = client.post(endpoint, json=payload)

        assert response.status_code == 200, f"Request to {endpoint} failed: {response.text}"
        assert response.headers["content-type"].startswith(expected_content_type)
        assert "content-disposition" in response.headers
        assert f'filename="export_export_table{expected_filename_ext}"' in response.headers["content-disposition"]

        # Round-trip the payload through disk as a downloading client would.
        downloaded_path = tmp_path / f"downloaded{expected_filename_ext}"
        downloaded_path.write_bytes(response.content)
        assert downloaded_path.stat().st_size == len(response.content)
        assert validation_fn(response.content), f"Validation failed for {endpoint_suffix}"

        # 2) Export the result of an arbitrary SELECT; the filename changes.
        payload = {"source": "SELECT id, name FROM export_table WHERE price > 0.40 ORDER BY id"}
        response = client.post(endpoint, json=payload)
        assert response.status_code == 200, f"Query export to {endpoint} failed: {response.text}"
        assert "content-disposition" in response.headers
        assert f'filename="export_query{expected_filename_ext}"' in response.headers["content-disposition"]
        assert len(response.content) > 0

    def test_export_database(self, client: TestClient, tmp_path):
        """``POST /export/database`` returns a zip with schema, load script and table data."""
        setup = client.post("/execute", json={"sql": "CREATE TABLE IF NOT EXISTS another_table(x int)"})
        assert setup.status_code == 200, f"Setup failed: {setup.text}"

        response = client.post("/export/database")
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/zip"
        assert "content-disposition" in response.headers
        assert response.headers["content-disposition"].startswith("attachment; filename=")
        assert 'filename="in_memory_db_export.zip"' in response.headers["content-disposition"]

        zip_path = tmp_path / "db_export.zip"
        zip_path.write_bytes(response.content)
        assert zip_path.exists()

        with zipfile.ZipFile(zip_path, 'r') as z:
            names = z.namelist()
        print(f"Zip contents: {names}")
        assert "schema.sql" in names
        assert "load.sql" in names
        assert any(name.startswith("export_table") for name in names)
        assert any(name.startswith("another_table") for name in names)
|
|
|
|
|
class TestExtensions:
    """Tests for the DuckDB extension install/load endpoints."""

    def test_install_extension_fail(self, client: TestClient):
        """Installing an unknown extension yields an error status with a known detail."""
        resp = client.post("/extensions/install", json={"extension_name": "nonexistent_dummy_ext"})
        assert resp.status_code >= 400
        detail = resp.json()["detail"]
        assert any(marker in detail for marker in ("Error during install", "Failed to download"))

    def test_load_extension_fail(self, client: TestClient):
        """Loading an extension that was never installed is a 400 mentioning 'not found'."""
        resp = client.post("/extensions/load", json={"extension_name": "nonexistent_dummy_ext"})
        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "Error loading extension" in detail
        assert "not found" in detail.lower()

    @pytest.mark.skip(reason="Requires httpfs extension to be available for install/load")
    def test_install_and_load_httpfs(self, client: TestClient):
        """Install and then load ``httpfs``; each step must report success."""
        for action in ("install", "load"):
            resp = client.post(f"/extensions/{action}", json={"extension_name": "httpfs"})
            assert resp.status_code == 200
            assert resp.json()["status"] == "success"