import os
import gc
import shutil
import logging
from pathlib import Path
from huggingface_hub import WebhooksServer, WebhookPayload
from datasets import Dataset, load_dataset, disable_caching
from fastapi import BackgroundTasks, Response, status
# Avoid writing reusable datasets cache files to disk
disable_caching()

# Set up the logger
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Set up the console handler with a simple format
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Configuration (only WEBHOOK_SECRET is read from the environment)
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"
TARGET_REPO = "amaye15/object-segmentation-processed"
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")
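# Note: this value must match the secret configured for the webhook in your
# Hub settings (https://huggingface.co/settings/webhooks); WebhooksServer
# rejects requests whose X-Webhook-Secret header does not match.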


def get_data():
    """
    Generator that streams rows from the source dataset.
    """
    ds = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row
        # Collect garbage after each row to keep memory bounded while streaming
        gc.collect()


def process_and_push_data():
    """
    Rebuild the dataset from the streamed source and push it to the target repo.
    """
    # Clear the local cache directory so each run starts from a fresh download
    p = os.path.join(os.getcwd(), DATA_DIR)
    if os.path.exists(p):
        shutil.rmtree(p)
    os.mkdir(p)

    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO)
    logger.info("Data processed and pushed to the hub.")
    gc.collect()


# Initialize the WebhooksServer (a Gradio UI can be attached if needed)
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.

    The heavy lifting runs as a background task so the endpoint can return
    immediately with 202 Accepted.
    """
    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook():
    # The dataset itself is (re)loaded lazily inside process_and_push_data()
    # via the streaming generator, so nothing needs to be loaded here.
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)
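
# To wire this up on the Hub: create a webhook watching the
# amaye15/object-segmentation dataset repo, set its secret to the same value
# as HF_WEBHOOK_SECRET, and point its target URL at this server. WebhooksServer
# serves each registered endpoint under the /webhooks prefix, so the path for
# this handler is /webhooks/dataset_repo.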