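"""Three-stage asyncio pipeline demo.

Jobs flow Node1 -> Node2 -> Node3 through asyncio queues. Node2 runs a pool of
workers with a random delay, so jobs can reach Node3 out of order; Node3 is a
sequential node that buffers them and emits results in job-id order.
"""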
import asyncio
import random
import time


class Job:
    def __init__(self, id, data):
        self.id = id
        self.data = data

class Node:
    def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False):
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.buffer = {}
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        self.next_i = 0
        self._jobs_dequeued = 0
        self._jobs_processed = 0
        # a job_sync list only makes sense on a sequential (ordered) node
        if self.job_sync is not None and not self.sequential_node:
            raise ValueError('job_sync requires sequential_node=True')

    async def run(self):
        while True:
            job: Job = await self.input_queue.get()
            self._jobs_dequeued += 1
            if not self.sequential_node:
                await self.process_job(job)
                await self._forward(job)
            else:
                # buffer out-of-order arrivals and release them strictly in job-id order
                self.buffer[job.id] = job
                while self.next_i in self.buffer:
                    ordered_job = self.buffer.pop(self.next_i)
                    await self.process_job(ordered_job)
                    await self._forward(ordered_job)
                    self.next_i += 1

    async def _forward(self, job: Job):
        # hand a finished job to the next stage and/or the completion list
        if self.output_queue is not None:
            await self.output_queue.put(job)
        if self.job_sync is not None:
            self.job_sync.append(job)
        self._jobs_processed += 1

    async def process_job(self, job: Job):
        raise NotImplementedError

class Node1(Node):
    async def process_job(self, job: Job):
        job.data += f' (processed by node 1, worker {self.worker_id})'


class Node2(Node):
    async def process_job(self, job: Job):
        # simulate variable-length work so jobs leave this stage out of order
        sleep_duration = 0.8 + 0.4 * random.random()
        await asyncio.sleep(sleep_duration)
        job.data += f' (processed by node 2, worker {self.worker_id})'


class Node3(Node):
    async def process_job(self, job: Job):
        job.data += f' (processed by node 3, worker {self.worker_id})'
        print(f'{job.id} - {job.data}')

async def main():
    node1_queue = asyncio.Queue()
    node2_queue = asyncio.Queue()
    node3_queue = asyncio.Queue()
    num_jobs = 100
    job_source = [Job(i, "") for i in range(num_jobs)]
    job_sync = []

    # create the workers: one Node1, a pool of Node2 workers, one sequential Node3
    num_workers = 5
    node1_workers = [Node1(i + 1, node1_queue, node2_queue) for i in range(1)]
    node2_workers = [Node2(i + 1, node2_queue, node3_queue) for i in range(num_workers)]
    node3_workers = [Node3(i + 1, node3_queue, job_sync=job_sync, sequential_node=True) for i in range(1)]

    # create tasks for the workers
    tasks1 = [asyncio.create_task(worker.run()) for worker in node1_workers]
    tasks2 = [asyncio.create_task(worker.run()) for worker in node2_workers]
    tasks3 = [asyncio.create_task(worker.run()) for worker in node3_workers]

    # feed every job into the first stage
    for job in job_source:
        await node1_queue.put(job)

    try:
        # wait until the final, sequential node has forwarded every job
        while len(job_sync) < num_jobs:
            # print(f"Waiting for jobs to finish... Job sync size: {len(job_sync)}, node1_queue size: {node1_queue.qsize()}, node2_queue size: {node2_queue.qsize()}, node3_queue size: {node3_queue.qsize()}")
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print("Pipeline cancelled")

    # the worker loops never exit on their own, so cancel them and let them unwind
    for task in tasks1 + tasks2 + tasks3:
        task.cancel()
    await asyncio.gather(*tasks1, *tasks2, *tasks3, return_exceptions=True)

if __name__ == "__main__":
    start_time = time.time()
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Pipeline interrupted by user")
    end_time = time.time()
    print(f"Pipeline processed in {end_time - start_time:.2f} seconds.")