Spaces:
Runtime error
Runtime error
from typing import List, Optional | |
from inference.core.entities.types import DatasetID, WorkspaceID | |
from inference.core.roboflow_api import ( | |
get_roboflow_labeling_batches, | |
get_roboflow_labeling_jobs, | |
) | |
def image_can_be_submitted_to_batch(
    batch_name: str,
    workspace_id: WorkspaceID,
    dataset_id: DatasetID,
    max_batch_images: Optional[int],
    api_key: str,
) -> bool:
    """Decide whether one more image fits into the named labeling batch.

    Args:
        batch_name: Name of the labeling batch to look up.
        workspace_id: Workspace the dataset belongs to.
        dataset_id: Dataset whose labeling batches are inspected.
        max_batch_images: Upper limit on images in the batch. ``None`` means
            no limit; in that case no API calls are made.
        api_key: API key forwarded to the Roboflow API helpers.

    Returns:
        ``True`` when the batch's image count plus the images reported by
        labeling jobs associated with it is strictly below
        ``max_batch_images``, ``False`` otherwise.
    """
    if max_batch_images is None:
        # No limit configured: skip the API round-trips entirely.
        return True
    batches_response = get_roboflow_labeling_batches(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    batch = get_matching_labeling_batch(
        all_labeling_batches=batches_response["batches"],
        batch_name=batch_name,
    )
    if batch is None:
        # No batch with this name exists yet, so it currently holds zero images.
        return max_batch_images > 0
    images_in_jobs = 0
    # Only query labeling jobs when the batch reports at least one job.
    if batch["numJobs"] > 0:
        jobs_response = get_roboflow_labeling_jobs(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
        # NOTE(review): presumably images handed to labeling jobs are not
        # included in batch["images"], hence the addition below; confirm.
        images_in_jobs = get_images_in_labeling_jobs_of_specific_batch(
            all_labeling_jobs=jobs_response["jobs"],
            batch_id=batch["id"],
        )
    return batch["images"] + images_in_jobs < max_batch_images
def get_matching_labeling_batch(
    all_labeling_batches: List[dict],
    batch_name: str,
) -> Optional[dict]:
    """Find the labeling batch with the given name.

    Args:
        all_labeling_batches: Batch descriptions, each carrying a ``"name"`` key.
        batch_name: Name to match exactly.

    Returns:
        The first batch (in list order) whose ``"name"`` equals ``batch_name``,
        or ``None`` when no batch matches.
    """
    # Lazily scan in order and stop at the first hit, falling back to None.
    return next(
        (
            candidate
            for candidate in all_labeling_batches
            if candidate["name"] == batch_name
        ),
        None,
    )
def get_images_in_labeling_jobs_of_specific_batch(
    all_labeling_jobs: List[dict],
    batch_id: str,
) -> int:
    """Count images across the labeling jobs sourced from one batch.

    Args:
        all_labeling_jobs: Job descriptions, each carrying ``"sourceBatch"``
            and ``"numImages"`` keys.
        batch_id: Identifier of the batch to look for.

    Returns:
        Sum of ``"numImages"`` over every job whose ``"sourceBatch"``
        contains ``batch_id``; ``0`` when no job matches.
    """
    # NOTE(review): membership uses `in`, so if "sourceBatch" is a string
    # this is a substring test rather than equality; confirm the schema.
    jobs_for_batch = [
        job for job in all_labeling_jobs if batch_id in job["sourceBatch"]
    ]
    return sum(job["numImages"] for job in jobs_for_batch)