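"""Webhook listener that rebuilds a processed dataset: whenever the source
dataset repo changes, the handler re-streams it and pushes the result to the
target repo on the Hub."""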
import os
import gc
import shutil
import logging
from huggingface_hub import WebhooksServer, WebhookPayload
from datasets import Dataset, load_dataset, disable_caching
from fastapi import BackgroundTasks, Response, status

disable_caching()

# Set up the logger
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Set up the console handler with a simple format
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Configuration (the webhook secret is read from the HF_WEBHOOK_SECRET environment variable)
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"
TARGET_REPO = "amaye15/object-segmentation-processed"
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")


def get_data():
    """
    Generator function to stream data from the dataset.
    """
    ds = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row
    gc.collect()


def process_and_push_data():
    """
    Function to process and push new data to the target repository.
    """
    p = os.path.join(os.getcwd(), DATA_DIR)

    # Start from a clean local cache directory so the dataset is re-streamed
    # from scratch on every run.
    if os.path.exists(p):
        shutil.rmtree(p)
    os.mkdir(p)

    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO)
    logger.info("Data processed and pushed to the hub.")
    gc.collect()


# Initialize the WebhooksServer (it exposes a small Gradio UI and, when a secret
# is set, validates incoming requests against it)
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.
    """
    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
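    # Hand the heavy work to a background task so the webhook is acknowledged
    # immediately instead of blocking until the dataset rebuild finishes.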
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook():
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)
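
# Rough smoke test (a sketch, not a verified command): with the server running
# locally and HF_WEBHOOK_SECRET exported, the endpoint can be exercised directly.
# WebhooksServer registers routes under the /webhooks prefix and, when a secret
# is configured, rejects requests whose X-Webhook-Secret header does not match.
# The JSON body must follow the huggingface_hub.WebhookPayload schema; replaying
# a payload previously delivered by the Hub (saved here as the hypothetical
# sample_payload.json) is the simplest option:
#
#   curl -X POST http://localhost:7860/webhooks/dataset_repo \
#        -H "Content-Type: application/json" \
#        -H "X-Webhook-Secret: $HF_WEBHOOK_SECRET" \
#        -d @sample_payload.json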