from fastapi import FastAPI, Form, Request, BackgroundTasks |
from fastapi.responses import HTMLResponse |
from fastapi.templating import Jinja2Templates |
from uuid import uuid4 |
import time |
import asyncio |
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas |
from models import WebhookPayload, WebhookPayloadRepo, WebhookPayloadEvent |
# ASGI application; the route/startup decorators in this module register on it.
app: FastAPI = FastAPI()
# In-memory task registry: uuid4-string task id -> status dict.
# The /status endpoint returns these dicts directly to clients.
tasks: dict = dict()
# Jinja2 templates are resolved relative to the "templates" directory.
templates: Jinja2Templates = Jinja2Templates(directory="templates")
def upload_atlas_task(task_id, dataset_name, project_name="atlas-space-test"):
    """Build an Atlas map for ``dataset_name`` and record the result in ``tasks``.

    This is a plain (synchronous) function; the endpoints schedule it through
    FastAPI ``BackgroundTasks`` after registering ``tasks[task_id]`` as running.

    Args:
        task_id: Key of an entry already present in ``tasks``.
        dataset_name: Dataset identifier passed to ``load_dataset_and_metadata``.
        project_name: Atlas project to upload into. Defaults to the value
            that used to be hard-coded, so existing callers are unaffected.

    On success the entry becomes ``status='done'`` with the map ``url``.
    On failure it becomes ``status='error'`` and the exception is re-raised,
    so the traceback still reaches the server log. Both outcomes stamp
    ``finish_time``, the timestamp ``cleanup_tasks`` uses to age out
    finished entries.

    Raises:
        KeyError: if ``task_id`` is not registered in ``tasks``.
        Exception: whatever the load/upload helpers raise, after the failure
            has been recorded.
    """
    try:
        dataset_dict = load_dataset_and_metadata(dataset_name)
        map_url = upload_dataset_to_atlas(dataset_dict, project_name=project_name)
    except Exception:
        # Without this the task would report 'running' forever after a failure.
        tasks[task_id] = {
            **tasks[task_id],
            'status': 'error',
            'finish_time': time.time(),
        }
        raise
    # Publish a brand-new dict instead of mutating the existing one in place.
    # Starlette runs sync background tasks in a threadpool, while /status may
    # be serializing the old dict on the event loop. Swapping the whole value
    # also guarantees no reader ever sees 'done' without 'url'.
    tasks[task_id] = {
        **tasks[task_id],
        'status': 'done',
        'url': map_url,
        'finish_time': time.time(),
    }
@app.on_event("startup")
async def startup_event():
    """Launch the periodic ``cleanup_tasks`` loop when the application starts."""
    # The event loop keeps only weak references to tasks made by create_task.
    # Store a strong reference on app.state so the loop can't be
    # garbage-collected mid-execution.
    app.state.cleanup_task = asyncio.create_task(cleanup_tasks())
def _purge_expired_tasks(store, now, max_age):
    """Delete entries from ``store`` that finished more than ``max_age`` seconds ago.

    An entry counts as finished once it has a ``'finish_time'`` timestamp, whatever
    its status ('done', 'error', ...). Entries without one (still running)
    are never removed.

    Returns:
        The number of entries removed.
    """
    expired = [
        task_id
        for task_id, task in store.items()
        if 'finish_time' in task and now - task['finish_time'] > max_age
    ]
    for task_id in expired:
        del store[task_id]
    return len(expired)


async def cleanup_tasks(max_age=1800, interval=1800):
    """Forever: evict finished entries from ``tasks``, then sleep ``interval`` seconds.

    Args:
        max_age: Seconds after ``finish_time`` before an entry is evicted.
        interval: Seconds to sleep between sweeps.

    Because sweeps are ``interval`` apart, an entry may survive for up to
    ``max_age + interval`` seconds after it finishes.
    """
    while True:
        _purge_expired_tasks(tasks, time.time(), max_age)
        await asyncio.sleep(interval)
@app.get("/", response_class=HTMLResponse)
async def read_form(request: Request):
    """Render the dataset submission page from the ``form.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("form.html", context)
@app.post("/submit_form")
async def form_post(background_tasks: BackgroundTasks, dataset_name: str = Form(...)):
    """Queue an Atlas upload for the dataset named in the submitted form.

    Registers a new 'running' entry in ``tasks`` and schedules
    ``upload_atlas_task``. Returns the new task id so the client can poll
    ``/status/{task_id}``.
    """
    new_id = f"{uuid4()}"
    tasks[new_id] = dict(status='running')
    background_tasks.add_task(upload_atlas_task, new_id, dataset_name)
    return dict(task_id=new_id)
@app.get("/status/{task_id}")
async def read_task(task_id: str):
    """Return the status dict stored for ``task_id``, or ``{'status': 'not found'}``."""
    return tasks.get(task_id, {'status': 'not found'})
@app.post("/webhook")
async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload):
    """Rebuild the Atlas map when a webhook reports a dataset content update.

    Work is scheduled only when all of these hold:
    - the event action is ``"update"``;
    - the event scope starts with ``"repo.content"``;
    - the repo type is ``"dataset"``.

    In that case the new task id is returned. Any other event is acknowledged
    with ``{"processed": False}``.
    """
    # NOTE(review): nothing here authenticates the sender, so any client that
    # can reach /webhook can trigger an upload. Consider validating a shared
    # secret header before scheduling work.
    event = payload.event
    is_dataset_content_update = (
        event.action == "update"
        and event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
    )
    if is_dataset_content_update:
        job_id = str(uuid4())
        tasks[job_id] = {'status': 'running'}
        background_tasks.add_task(upload_atlas_task, job_id, payload.repo.name)
        return {'task_id': job_id}
    return {"processed": False}