# DuckDB-UI / test_api.py
# File-view header (scraped): author amaye15, last commit c26b6eb ("Deploy"), 12.7 kB.
import pytest
import os
import shutil
import tempfile
import zipfile
import json
from pathlib import Path
from typing import List, Dict, Any
from unittest.mock import patch
pd = pytest.importorskip("pandas")
pa = pytest.importorskip("pyarrow")
pa_ipc = pytest.importorskip("pyarrow.ipc")
from fastapi.testclient import TestClient
import main # Import main to reload and access config
# --- Test Fixtures ---
@pytest.fixture(scope="module")
def client():
    """Yield a ``TestClient`` for a freshly reloaded ``main`` app backed by ``:memory:``.

    ``main`` is reloaded while ``DUCKDB_API_DB_PATH`` is patched to ``":memory:"``,
    so configuration read at import time (presumably including the DB path) sees
    the patched environment. On teardown, every entry in ``main.TEMP_EXPORT_DIR``
    is removed on a best-effort basis; failures are printed, not raised.
    """
    import importlib

    with patch.dict(os.environ, {"DUCKDB_API_DB_PATH": ":memory:"}):
        # Re-import main so it picks up the patched environment.
        importlib.reload(main)
        # parents=True so a nested export directory does not break setup.
        main.TEMP_EXPORT_DIR.mkdir(parents=True, exist_ok=True)
        print(f"TestClient using temp export dir: {main.TEMP_EXPORT_DIR}")
        try:
            with TestClient(main.app) as c:
                yield c
        finally:
            # Runs even if app startup/shutdown raises, so exported files never leak.
            print(f"Cleaning up test export dir: {main.TEMP_EXPORT_DIR}")
            for item in main.TEMP_EXPORT_DIR.iterdir():
                try:
                    # rmtree refuses symlinks, so only recurse into real directories.
                    if item.is_dir() and not item.is_symlink():
                        shutil.rmtree(item)
                    else:
                        item.unlink()
                except OSError as e:
                    # Best-effort cleanup: report and continue with the next entry.
                    print(f"Error cleaning up {item}: {e}")
# --- Test Classes ---
class TestHealth:
    """Liveness endpoint checks."""

    def test_health_check(self, client: TestClient):
        """``GET /health`` answers 200 with ``status == "ok"`` and no message."""
        resp = client.get("/health")
        expected_body = {"status": "ok", "message": None}
        assert resp.status_code == 200
        assert resp.json() == expected_body
class TestExecution:
    """Tests for ``/execute`` (statements run without returning rows)."""

    def test_execute_create(self, client: TestClient):
        """A first CREATE succeeds; a duplicate CREATE of the same table returns 400."""
        # Start from a known state so the test does not depend on execution order.
        drop = client.post("/execute", json={"sql": "DROP TABLE IF EXISTS test_table;"})
        assert drop.status_code == 200, f"Setup failed: {drop.text}"
        response = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER, name VARCHAR);"})
        assert response.status_code == 200
        assert response.json() == {"status": "success", "message": None}
        response_fail = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER);"})
        assert response_fail.status_code == 400

    def test_execute_insert(self, client: TestClient):
        """A literal INSERT is visible to a subsequent ``/query/fetchall``."""
        setup = client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
        assert setup.status_code == 200, f"Setup failed: {setup.text}"
        response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (1, 'Alice')"})
        assert response.status_code == 200
        query_response = client.post("/query/fetchall", json={"sql": "SELECT COUNT(*) FROM test_table"})
        assert query_response.status_code == 200
        assert query_response.json() == [[1]]

    def test_execute_insert_params(self, client: TestClient):
        """Positional ``?`` parameters are bound from the ``parameters`` list."""
        setup = client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
        assert setup.status_code == 200, f"Setup failed: {setup.text}"
        response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (?, ?)", "parameters": [2, "Bob"]})
        assert response.status_code == 200
        query_response = client.post("/query/fetchall", json={"sql": "SELECT * FROM test_table WHERE id = 2"})
        assert query_response.status_code == 200
        assert query_response.json() == [[2, "Bob"]]

    def test_execute_invalid_sql(self, client: TestClient):
        """Unparseable SQL returns 400 with the engine's parser error in ``detail``."""
        response = client.post("/execute", json={"sql": "INVALID SQL STATEMENT"})
        assert response.status_code == 400
        assert "Parser Error" in response.json()["detail"]
class TestQuerying:
    """Tests for the row-returning query endpoints (``fetchall`` and ``dataframe``)."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create and populate ``query_test`` once for every test in this class."""
        for sql in (
            "CREATE OR REPLACE TABLE query_test(id INTEGER, val VARCHAR)",
            "INSERT INTO query_test VALUES (1, 'one'), (2, 'two'), (3, 'three')",
        ):
            response = client.post("/execute", json={"sql": sql})
            # Fail here, not later with a confusing data mismatch.
            assert response.status_code == 200, f"Setup failed for {sql!r}: {response.text}"

    def test_query_fetchall(self, client: TestClient):
        """``fetchall`` returns rows as lists in query order."""
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test ORDER BY id"})
        assert response.status_code == 200
        assert response.json() == [[1, 'one'], [2, 'two'], [3, 'three']]

    def test_query_fetchall_params(self, client: TestClient):
        """Positional parameters filter the result set."""
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > ? ORDER BY id", "parameters": [1]})
        assert response.status_code == 200
        assert response.json() == [[2, 'two'], [3, 'three']]

    def test_query_fetchall_empty(self, client: TestClient):
        """A query matching no rows returns an empty JSON list."""
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > 100"})
        assert response.status_code == 200
        assert response.json() == []

    def test_query_dataframe(self, client: TestClient):
        """``dataframe`` returns column names plus one dict per row."""
        response = client.post("/query/dataframe", json={"sql": "SELECT * FROM query_test ORDER BY id"})
        assert response.status_code == 200
        data = response.json()
        assert data["columns"] == ["id", "val"]
        assert data["records"] == [
            {"id": 1, "val": "one"},
            {"id": 2, "val": "two"},
            {"id": 3, "val": "three"},
        ]

    def test_query_dataframe_invalid_sql(self, client: TestClient):
        """Referencing an unknown column returns 400 with a binder error."""
        response = client.post("/query/dataframe", json={"sql": "SELECT non_existent FROM query_test"})
        assert response.status_code == 400
        assert "Binder Error" in response.json()["detail"]
class TestStreaming:
    """Tests for the streaming endpoints: Arrow IPC stream and JSON Lines."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create ``stream_test`` with ``id`` 0..9 and ``category = id % 5``."""
        response = client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE stream_test AS SELECT range AS id, range % 5 AS category FROM range(10)"})
        # Fail here rather than with an opaque stream mismatch later.
        assert response.status_code == 200, f"Setup failed: {response.text}"

    def test_stream_arrow(self, client: TestClient):
        """The Arrow endpoint returns a readable IPC stream holding all 10 rows."""
        # ORDER BY is required: without it, row order is not guaranteed, and the
        # column check below compares against an ordered list.
        response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test ORDER BY id"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
        if not response.content:
            pytest.fail("Arrow stream response content is empty")
        try:
            reader = pa_ipc.open_stream(response.content)
            table = reader.read_all()
        except pa.ArrowInvalid as e:
            pytest.fail(f"Failed to read Arrow stream: {e}")
        assert table.num_rows == 10
        assert table.column_names == ["id", "category"]
        assert table.column('id').to_pylist() == list(range(10))

    def test_stream_arrow_empty(self, client: TestClient):
        """An empty result is either a zero-row Arrow stream or an empty body."""
        response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
        try:
            reader = pa_ipc.open_stream(response.content)
            table = reader.read_all()
            assert table.num_rows == 0
        except pa.ArrowInvalid as e:
            # Unreadable as a stream is only acceptable when the body is empty.
            print(f"Received ArrowInvalid for empty stream, which is acceptable: {e}")
            assert response.content == b''

    def test_stream_jsonl(self, client: TestClient):
        """The JSONL endpoint emits one JSON object per row, in query order."""
        response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test ORDER BY id"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"
        records = [json.loads(line) for line in response.text.splitlines() if line.strip()]
        assert len(records) == 10
        assert records[0] == {"id": 0, "category": 0}
        assert records[9] == {"id": 9, "category": 4}

    def test_stream_jsonl_empty(self, client: TestClient):
        """An empty result produces an empty (whitespace-only) JSONL body."""
        response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"
        assert response.text.strip() == ""
class TestExportDownload:
    """Tests for the file-download endpoints: table/query exports and full-database export."""

    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        """Create and populate ``export_table`` once for every test in this class."""
        for sql in (
            "CREATE OR REPLACE TABLE export_table(id INTEGER, name VARCHAR, price DECIMAL(5,2))",
            "INSERT INTO export_table VALUES (1, 'Apple', 0.50), (2, 'Banana', 0.30), (3, 'Orange', 0.75)",
        ):
            response = client.post("/execute", json={"sql": sql})
            # Fail here, not later with a confusing export mismatch.
            assert response.status_code == 200, f"Setup failed for {sql!r}: {response.text}"

    @pytest.mark.parametrize(
        "endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn",
        [
            ("csv", "text/csv", ".csv", lambda c: b"id,name,price\n1,Apple,0.50\n" in c),
            ("parquet", "application/vnd.apache.parquet", ".parquet", lambda c: c.startswith(b"PAR1")),
            # JSON is checked as a single bracketed array; JSONL as newline-separated objects.
            ("json", "application/json", ".json", lambda c: c.strip().startswith(b'[') and c.strip().endswith(b']')),
            ("jsonl", "application/jsonl", ".jsonl", lambda c: b'"id":1' in c and b'"name":"Apple"' in c and b'\n' in c),
        ],
    )
    def test_export_data(self, client: TestClient, endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn, tmp_path):
        """Each export format works for both a table-name source and a SQL-query source.

        ``validation_fn`` receives the downloaded bytes and returns whether they
        look like the expected format.
        """
        endpoint = f"/export/data/{endpoint_suffix}"

        # 1) Export a whole table by name.
        payload = {"source": "export_table"}
        if endpoint_suffix == 'csv':
            payload['options'] = {'HEADER': True}
        response = client.post(endpoint, json=payload)
        assert response.status_code == 200, f"Request to {endpoint} failed: {response.text}"
        assert response.headers["content-type"].startswith(expected_content_type)
        assert "content-disposition" in response.headers
        assert f'filename="export_export_table{expected_filename_ext}"' in response.headers["content-disposition"]
        # Save the download and validate what actually landed on disk.
        downloaded_path = tmp_path / f"downloaded{expected_filename_ext}"
        downloaded_path.write_bytes(response.content)
        assert validation_fn(downloaded_path.read_bytes()), f"Validation failed for {endpoint_suffix}"

        # 2) Export the result of an arbitrary query.
        payload = {"source": "SELECT id, name FROM export_table WHERE price > 0.40 ORDER BY id"}
        response = client.post(endpoint, json=payload)
        assert response.status_code == 200, f"Query export to {endpoint} failed: {response.text}"
        assert response.headers["content-type"].startswith(expected_content_type)
        assert f'filename="export_query{expected_filename_ext}"' in response.headers["content-disposition"]
        assert len(response.content) > 0

    def test_export_database(self, client: TestClient, tmp_path):
        """``/export/database`` returns a zip with schema.sql, load.sql and per-table files."""
        setup = client.post("/execute", json={"sql": "CREATE TABLE IF NOT EXISTS another_table(x int)"})
        assert setup.status_code == 200, f"Setup failed: {setup.text}"
        response = client.post("/export/database")
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/zip"
        assert "content-disposition" in response.headers
        disposition = response.headers["content-disposition"]
        assert disposition.startswith("attachment; filename=")
        assert 'filename="in_memory_db_export.zip"' in disposition
        zip_path = tmp_path / "db_export.zip"
        zip_path.write_bytes(response.content)
        with zipfile.ZipFile(zip_path, 'r') as z:
            names = z.namelist()
        print(f"Zip contents: {names}")
        assert "schema.sql" in names
        assert "load.sql" in names
        assert any(name.startswith("export_table") for name in names)
        assert any(name.startswith("another_table") for name in names)
class TestExtensions:
    """Tests for the extension install/load endpoints."""

    def test_install_extension_fail(self, client: TestClient):
        """Installing an unknown extension yields an error status and a diagnostic detail."""
        resp = client.post("/extensions/install", json={"extension_name": "nonexistent_dummy_ext"})
        assert resp.status_code >= 400
        detail = resp.json()["detail"]
        assert any(marker in detail for marker in ("Error during install", "Failed to download"))

    def test_load_extension_fail(self, client: TestClient):
        """Loading an unknown extension returns 400 with a 'not found' explanation."""
        resp = client.post("/extensions/load", json={"extension_name": "nonexistent_dummy_ext"})
        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "Error loading extension" in detail
        assert "not found" in detail.lower()

    @pytest.mark.skip(reason="Requires httpfs extension to be available for install/load")
    def test_install_and_load_httpfs(self, client: TestClient):
        """httpfs installs and then loads successfully (network/extension repo required)."""
        for endpoint in ("/extensions/install", "/extensions/load"):
            resp = client.post(endpoint, json={"extension_name": "httpfs"})
            assert resp.status_code == 200
            assert resp.json()["status"] == "success"