|
|
|
import duckdb |
|
import pandas as pd |
|
import pyarrow as pa |
|
import pyarrow.ipc |
|
from pathlib import Path |
|
import tempfile |
|
import os |
|
import shutil |
|
from typing import Any, AsyncGenerator, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
|
|
from fastapi import FastAPI, HTTPException, Body, Query, BackgroundTasks, Depends |
|
from fastapi.responses import StreamingResponse, FileResponse |
|
from pydantic import BaseModel, Field |
|
|
|
from database_api import DatabaseAPI, DatabaseAPIError, QueryError |
|
|
|
|
|
# Path of the DuckDB database file served by this API.
DUCKDB_API_DB_PATH = os.getenv("DUCKDB_API_DB_PATH", "api_database.db")

# BUG FIX: os.getenv returns a *string* whenever the variable is set, so the
# original `os.getenv(..., False)` made any non-empty value — including the
# string "false" — truthy. Parse the common boolean spellings explicitly.
DUCKDB_API_READ_ONLY = os.getenv("DUCKDB_API_READ_ONLY", "false").strip().lower() in ("1", "true", "yes", "on")

# Extra DuckDB config options passed through to DatabaseAPI (none by default).
DUCKDB_API_CONFIG = {}

# Scratch directory for files produced by the /export/* endpoints.
TEMP_EXPORT_DIR = Path(tempfile.gettempdir()) / "duckdb_api_exports"
# parents=True so a missing intermediate tmp directory cannot break startup.
TEMP_EXPORT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Using temporary directory for exports: {TEMP_EXPORT_DIR}")
|
|
|
|
|
class StatusResponse(BaseModel):
    """Generic outcome payload returned by mutation, extension and health endpoints."""

    # Outcome indicator, e.g. "success" or "ok".
    status: str
    # Optional human-readable detail; None when there is nothing to add.
    message: Optional[str] = None
|
|
|
class ExecuteRequest(BaseModel):
    """Request body for /execute: one SQL statement plus optional bind parameters."""

    sql: str
    # Positional values bound to placeholders in `sql`; None means no parameters.
    parameters: Optional[List[Any]] = None
|
|
|
class QueryRequest(BaseModel):
    """Request body for the /query/* endpoints: a SELECT plus optional bind parameters."""

    sql: str
    # Positional values bound to placeholders in `sql`; None means no parameters.
    parameters: Optional[List[Any]] = None
|
|
|
class DataFrameResponse(BaseModel):
    """JSON-serialized DataFrame: column order plus one dict per row."""

    # Column names in the order they appear in the result set.
    columns: List[str]
    # Row dicts keyed by column name; missing values are serialized as null.
    records: List[Dict[str, Any]]
|
|
|
class InstallRequest(BaseModel):
    """Request body for /extensions/install."""

    extension_name: str
    # Re-download and reinstall even if the extension is already present.
    force_install: bool = False
|
|
|
class LoadRequest(BaseModel):
    """Request body for /extensions/load."""

    extension_name: str
|
|
|
class ExportDataRequest(BaseModel):
    """Request body for the /export/data/* endpooints — what to export and how."""

    source: str = Field(..., description="Table name or SQL SELECT query to export")
    options: Optional[Dict[str, Any]] = Field(None, description="Format-specific export options")
|
|
|
|
|
# FastAPI application object; endpoints below are registered against it.
app = FastAPI(
    title="DuckDB API Wrapper",
    description="Exposes DuckDB functionalities via a RESTful API.",
    version="0.2.1"
)


# Single shared DatabaseAPI handle for the whole process. Populated on
# startup; left as None when initialization failed, in which case
# get_db_api() surfaces a 503 to every request.
db_api_instance: Optional[DatabaseAPI] = None
|
|
|
@app.on_event("startup")
async def startup_event():
    """Create the shared DatabaseAPI handle when the server boots.

    On failure the handle is left as None so that requests receive a 503
    from get_db_api() instead of the process dying at startup.
    """
    global db_api_instance
    print("Starting up DuckDB API...")
    try:
        instance = DatabaseAPI(
            db_path=DUCKDB_API_DB_PATH,
            read_only=DUCKDB_API_READ_ONLY,
            config=DUCKDB_API_CONFIG,
        )
    except DatabaseAPIError as exc:
        print(f"FATAL: Could not initialize DatabaseAPI on startup: {exc}")
        instance = None
    db_api_instance = instance
|
|
|
@app.on_event("shutdown")
def shutdown_event():
    """Close the shared DatabaseAPI handle when the server stops."""
    print("Shutting down DuckDB API...")
    api = db_api_instance
    if api:
        api.close()
|
|
|
|
|
def get_db_api() -> DatabaseAPI:
    """FastAPI dependency: return the shared DatabaseAPI or raise 503.

    Raises HTTPException(503) when startup failed to create the handle or
    when re-establishing the underlying connection fails.
    """
    api = db_api_instance
    if api is None:
        raise HTTPException(status_code=503, detail="Database service is unavailable (failed to initialize).")
    try:
        api._ensure_connection()
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=503, detail=f"Database service error: {exc}")
    return api
|
|
|
|
|
|
|
|
|
@app.post("/execute", response_model=StatusResponse, tags=["CRUD"])
async def execute_statement(request: ExecuteRequest, api: DatabaseAPI = Depends(get_db_api)):
    """Run a single SQL statement (DDL/DML) and report success.

    Maps QueryError to 400 (bad SQL / parameters) and any other
    DatabaseAPIError to 500.
    """
    try:
        api.execute_sql(request.sql, request.parameters)
    except QueryError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "success", "message": None}
|
|
|
@app.post("/query/fetchall", response_model=List[tuple], tags=["Querying"])
async def query_fetchall_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
    """Run a query and return every row as a list of tuples."""
    try:
        rows = api.query_fetchall(request.sql, request.parameters)
    except QueryError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return rows
|
|
|
@app.post("/query/dataframe", response_model=DataFrameResponse, tags=["Querying"])
async def query_dataframe_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
    """Run a query and return the result as JSON columns + row records.

    ImportError (pandas unavailable in the backend) is reported as 400,
    matching QueryError handling.
    """
    try:
        frame = api.query_df(request.sql, request.parameters)
        # Map pandas missing-value sentinels to None so the payload is valid JSON.
        cleaned = frame.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
        return {"columns": cleaned.columns.tolist(), "records": cleaned.to_dict(orient='records')}
    except (QueryError, ImportError) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
|
|
|
async def _stream_arrow_ipc(record_batch_iterator: Iterator[pa.RecordBatch]) -> Generator[bytes, None, None]:
    """Serialize record batches into a single Arrow IPC stream and yield its bytes.

    NOTE(review): despite the name, this does not stream incrementally — every
    batch is written into one in-memory BufferOutputStream and the complete
    payload is yielded exactly once from the `finally` block, so large result
    sets are fully buffered before the client sees any data. Also, this is an
    *async* generator, so the `Generator[...]` annotation is imprecise
    (`AsyncGenerator[bytes, None]` would be accurate) — left as-is here.
    """
    writer = None
    sink = pa.BufferOutputStream()
    try:
        # Pull the first batch eagerly: its schema is required to open the writer.
        first_batch = next(record_batch_iterator)
        writer = pa.ipc.new_stream(sink, first_batch.schema)
        writer.write_batch(first_batch)

        for batch in record_batch_iterator:
            writer.write_batch(batch)

    except StopIteration:
        # Iterator produced no batches at all: no schema is available, so
        # nothing can be written; the `finally` still runs but yields nothing.
        if writer is None:
            print("Warning: Arrow stream iterator was empty.")
        return

    except Exception as e:
        # Best-effort: the HTTP response has already started, so errors are
        # only logged; the client receives a truncated or empty stream.
        print(f"Error during Arrow streaming generator: {e}")

    finally:
        if writer:
            try:
                # Closing the writer appends the IPC end-of-stream marker.
                print("Closing Arrow IPC Stream Writer...")
                writer.close()
                print("Writer closed.")
            except Exception as close_e:
                print(f"Error closing Arrow writer: {close_e}")
        if sink:
            try:
                # getvalue() finalizes the buffer; the single yield below is
                # the only data the client ever receives.
                buffer = sink.getvalue()
                if buffer:
                    print(f"Yielding final Arrow buffer (size: {len(buffer.to_pybytes())})...")
                    yield buffer.to_pybytes()
                else:
                    print("Arrow sink buffer was empty after closing writer.")
                sink.close()
            except Exception as close_e:
                print(f"Error closing or getting value from Arrow sink: {close_e}")
|
|
|
|
|
|
|
@app.post("/query/stream/arrow", tags=["Streaming"])
async def query_stream_arrow_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
    """Executes a SQL query and streams results as Arrow IPC Stream format."""
    try:
        batches = api.stream_query_arrow(request.sql, request.parameters)
    except (QueryError, ImportError) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return StreamingResponse(
        _stream_arrow_ipc(batches),
        media_type="application/vnd.apache.arrow.stream",
    )
|
|
|
|
|
async def _stream_jsonl(dataframe_iterator: Iterator[pd.DataFrame]) -> AsyncGenerator[bytes, None]:
    """Yield each DataFrame chunk as UTF-8 encoded JSON Lines bytes.

    Fixes the return annotation: this is an async generator, so the
    accurate type is AsyncGenerator[bytes, None], not Generator.

    Each chunk is serialized with one JSON object per row; pandas
    missing-value sentinels (NA/NaT/NaN) are emitted as JSON null.
    Errors are only logged and end the stream early — the HTTP response
    has already started, so no error status can reach the client here.
    """
    try:
        for df_chunk in dataframe_iterator:
            # Normalize pandas missing-value markers to None -> JSON null.
            df_serializable = df_chunk.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
            jsonl_string = df_serializable.to_json(orient='records', lines=True, date_format='iso')
            if jsonl_string:
                # pandas does not guarantee a trailing newline; add one so
                # concatenated chunks remain valid JSONL.
                if not jsonl_string.endswith('\n'):
                    jsonl_string += '\n'
                yield jsonl_string.encode('utf-8')
    except Exception as e:
        print(f"Error during JSONL streaming generator: {e}")
|
|
|
@app.post("/query/stream/jsonl", tags=["Streaming"])
async def query_stream_jsonl_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
    """Executes a SQL query and streams results as JSON Lines (JSONL)."""
    try:
        chunks = api.stream_query_df(request.sql, request.parameters)
    except (QueryError, ImportError) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return StreamingResponse(_stream_jsonl(chunks), media_type="application/jsonl")
|
|
|
|
|
|
|
def _cleanup_temp_file(path: Union[str, Path]) -> None:
    """Best-effort removal of a temporary export file.

    Used as a BackgroundTasks callback after a FileResponse has been sent.
    EAFP: unlink directly instead of the original is_file()-then-remove,
    which closes the check/remove race; a file that is already gone is
    silently ignored, any other OS failure is logged and swallowed.
    """
    try:
        Path(path).unlink()
        print(f"Cleaned up temporary file: {path}")
    except FileNotFoundError:
        # Already deleted — nothing to do.
        pass
    except OSError as e:
        print(f"Error cleaning up temporary file {path}: {e}")
|
|
|
async def _create_temp_export(
    api: DatabaseAPI,
    source: str,
    export_format: str,
    options: Optional[Dict[str, Any]] = None,
    suffix: str = ".tmp"
) -> Path:
    """Export `source` (table name or SELECT) into a fresh temp file and return its path.

    The file is reserved with mkstemp inside TEMP_EXPORT_DIR; if the export
    fails for any reason the file is removed before the exception propagates.
    Raises ValueError for an unknown `export_format`.
    """
    handle, raw_path = tempfile.mkstemp(suffix=suffix, dir=TEMP_EXPORT_DIR)
    os.close(handle)
    target = Path(raw_path)

    try:
        print(f"Exporting to temporary file: {target}")
        if export_format == 'csv':
            api.export_data_to_csv(source, target, options)
        elif export_format == 'parquet':
            api.export_data_to_parquet(source, target, options)
        elif export_format == 'json':
            # JSON export wraps all rows in a single top-level array.
            api.export_data_to_json(source, target, array_format=True, options=options)
        elif export_format == 'jsonl':
            api.export_data_to_jsonl(source, target, options=options)
        else:
            raise ValueError(f"Unsupported export format: {export_format}")
    except Exception:
        _cleanup_temp_file(target)
        raise
    return target
|
|
|
@app.post("/export/data/csv", response_class=FileResponse, tags=["Export / Download"])
async def export_csv_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
    """Export a table or SELECT result to CSV and return it as a file download."""
    try:
        path = await _create_temp_export(api, request.source, 'csv', request.options, suffix=".csv")
        # Temp file is deleted in the background once the response is sent.
        background_tasks.add_task(_cleanup_temp_file, path)
        # Download name: sources containing a dot are treated as raw queries.
        stem = 'query' if '.' in request.source else Path(request.source).stem
        return FileResponse(path, media_type='text/csv', filename=f"export_{stem}.csv")
    except (QueryError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected error during CSV export: {exc}")
|
|
|
@app.post("/export/data/parquet", response_class=FileResponse, tags=["Export / Download"])
async def export_parquet_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
    """Export a table or SELECT result to Parquet and return it as a file download."""
    try:
        path = await _create_temp_export(api, request.source, 'parquet', request.options, suffix=".parquet")
        # Temp file is deleted in the background once the response is sent.
        background_tasks.add_task(_cleanup_temp_file, path)
        # Download name: sources containing a dot are treated as raw queries.
        stem = 'query' if '.' in request.source else Path(request.source).stem
        return FileResponse(path, media_type='application/vnd.apache.parquet', filename=f"export_{stem}.parquet")
    except (QueryError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected error during Parquet export: {exc}")
|
|
|
@app.post("/export/data/json", response_class=FileResponse, tags=["Export / Download"])
async def export_json_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
    """Export a table or SELECT result to a JSON array file and return it as a download."""
    try:
        path = await _create_temp_export(api, request.source, 'json', request.options, suffix=".json")
        # Temp file is deleted in the background once the response is sent.
        background_tasks.add_task(_cleanup_temp_file, path)
        # Download name: sources containing a dot are treated as raw queries.
        stem = 'query' if '.' in request.source else Path(request.source).stem
        return FileResponse(path, media_type='application/json', filename=f"export_{stem}.json")
    except (QueryError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected error during JSON export: {exc}")
|
|
|
@app.post("/export/data/jsonl", response_class=FileResponse, tags=["Export / Download"])
async def export_jsonl_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
    """Export a table or SELECT result to JSON Lines and return it as a file download."""
    try:
        path = await _create_temp_export(api, request.source, 'jsonl', request.options, suffix=".jsonl")
        # Temp file is deleted in the background once the response is sent.
        background_tasks.add_task(_cleanup_temp_file, path)
        # Download name: sources containing a dot are treated as raw queries.
        stem = 'query' if '.' in request.source else Path(request.source).stem
        return FileResponse(path, media_type='application/jsonl', filename=f"export_{stem}.jsonl")
    except (QueryError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected error during JSONL export: {exc}")
|
|
|
@app.post("/export/database", response_class=FileResponse, tags=["Export / Download"])
async def export_database_endpoint(background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
    """Export the whole database to a directory, zip it, and return the archive.

    On success the export directory and zip file are removed by background
    tasks after the response is sent; on failure both are removed immediately
    before the HTTP error is raised. DatabaseAPIError maps to 500; other
    expected failures (QueryError/ValueError/OSError) map to 400.
    """
    export_target_dir = Path(tempfile.mkdtemp(dir=TEMP_EXPORT_DIR))
    fd, zip_path_str = tempfile.mkstemp(suffix=".zip", dir=TEMP_EXPORT_DIR)
    os.close(fd)
    zip_file_path = Path(zip_path_str)

    def _discard_artifacts() -> None:
        # Shared failure cleanup (previously duplicated in both except branches).
        shutil.rmtree(export_target_dir, ignore_errors=True)
        _cleanup_temp_file(zip_file_path)

    try:
        print(f"Exporting database to temporary directory: {export_target_dir}")
        api.export_database(export_target_dir)
        print(f"Creating zip archive at: {zip_file_path}")
        # make_archive appends ".zip" itself, so strip the suffix first; the
        # archive then lands exactly on the mkstemp-reserved path.
        shutil.make_archive(str(zip_file_path.with_suffix('')), 'zip', str(export_target_dir))
        print(f"Zip archive created: {zip_file_path}")
        background_tasks.add_task(shutil.rmtree, export_target_dir, ignore_errors=True)
        background_tasks.add_task(_cleanup_temp_file, zip_file_path)
        # NOTE(review): reads the private _db_path attribute of DatabaseAPI —
        # consider exposing a public accessor on DatabaseAPI instead.
        db_name = Path(api._db_path).stem if api._db_path != ':memory:' else 'in_memory_db'
        return FileResponse(zip_file_path, media_type='application/zip', filename=f"{db_name}_export.zip")
    except (QueryError, ValueError, OSError, DatabaseAPIError) as e:
        print(f"Error during database export: {e}")
        _discard_artifacts()
        if isinstance(e, DatabaseAPIError):
            raise HTTPException(status_code=500, detail=str(e))
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        print(f"Unexpected error during database export: {e}")
        _discard_artifacts()
        raise HTTPException(status_code=500, detail=f"Unexpected error during database export: {e}")
|
|
|
|
|
|
|
@app.post("/extensions/install", response_model=StatusResponse, tags=["Extensions"])
async def install_extension_endpoint(request: InstallRequest, api: DatabaseAPI = Depends(get_db_api)):
    """Install a DuckDB extension by name.

    Wrapped DatabaseAPIError maps to 500; raw DuckDB install failures
    (I/O, catalog, invalid input) map to 400, any other DuckDB error to 500.
    """
    try:
        api.install_extension(request.extension_name, request.force_install)
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    except (duckdb.IOException, duckdb.CatalogException, duckdb.InvalidInputException) as exc:
        raise HTTPException(status_code=400, detail=f"DuckDB Error during install: {exc}")
    except duckdb.Error as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during install: {exc}")
    return {"status": "success", "message": f"Extension '{request.extension_name}' installed."}
|
|
|
|
|
@app.post("/extensions/load", response_model=StatusResponse, tags=["Extensions"])
async def load_extension_endpoint(request: LoadRequest, api: DatabaseAPI = Depends(get_db_api)):
    """Loads an installed DuckDB extension."""
    try:
        api.load_extension(request.extension_name)
    except QueryError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except DatabaseAPIError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    except (duckdb.IOException, duckdb.CatalogException) as exc:
        raise HTTPException(status_code=400, detail=f"DuckDB Error during load: {exc}")
    except duckdb.Error as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during load: {exc}")
    return {"status": "success", "message": f"Extension '{request.extension_name}' loaded."}
|
|
|
|
|
|
|
@app.get("/health", response_model=StatusResponse, tags=["Health"])
async def health_check():
    """Basic health check."""
    try:
        # Reuse the dependency: raises HTTPException(503) when the DB is down.
        get_db_api()
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Health check failed unexpectedly: {exc}")
    return {"status": "ok", "message": None}
|
|
|
|
|
# Development entry point: run the API with uvicorn's auto-reloader.
if __name__ == "__main__":
    import uvicorn

    # First banner line had a pointless f-string prefix (no placeholders) — removed.
    print("Starting DuckDB API server...")
    print(f"Database file configured at: {DUCKDB_API_DB_PATH}")
    print(f"Read-only mode: {DUCKDB_API_READ_ONLY}")
    print(f"Temporary export directory: {TEMP_EXPORT_DIR}")
    # reload=True is a development setting; disable it for production serving.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)