#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import queue
import threading
from pathlib import Path

import numpy as np
import PIL.Image
import torch


def safe_stop_image_writer(func):
    """Stop the dataset's image writer before re-raising if the wrapped function fails."""

    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None) if dataset else None
            if image_writer is not None:
                print("Waiting for image writer to terminate...")
                image_writer.stop()
            raise e

    return wrapper


def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
    # TODO(aliberts): handle 1 channel and 4 for depth images
    if image_array.ndim != 3:
        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")

    if image_array.shape[0] == 3:
        # Transpose from pytorch convention (C, H, W) to (H, W, C)
        image_array = image_array.transpose(1, 2, 0)
    elif image_array.shape[-1] != 3:
        raise NotImplementedError(
            f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
        )

    if image_array.dtype != np.uint8:
        if range_check:
            max_ = image_array.max().item()
            min_ = image_array.min().item()
            if max_ > 1.0 or min_ < 0.0:
                raise ValueError(
                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
                    "provide a uint8 image with values in the range [0, 255]."
                )

        image_array = (image_array * 255).astype(np.uint8)

    return PIL.Image.fromarray(image_array)


def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
    try:
        if isinstance(image, np.ndarray):
            img = image_array_to_pil_image(image)
        elif isinstance(image, PIL.Image.Image):
            img = image
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")
        img.save(fpath)
    except Exception as e:
        print(f"Error writing image {fpath}: {e}")


def worker_thread_loop(queue: queue.Queue):
    while True:
        item = queue.get()
        if item is None:
            # A `None` item is the sentinel telling this worker thread to exit.
            queue.task_done()
            break
        image_array, fpath = item
        write_image(image_array, fpath)
        queue.task_done()


def worker_process(queue: queue.Queue, num_threads: int):
    threads = []
    for _ in range(num_threads):
        t = threading.Thread(target=worker_thread_loop, args=(queue,))
        t.daemon = True
        t.start()
        threads.append(t)
    for t in threads:
        t.join()


class AsyncImageWriter:
    """
    This class abstracts away the initialization of processes and/or threads used to save images
    to disk asynchronously, which is critical to control a robot and record data at a high frame
    rate.

    When `num_processes=0`, it creates a pool of `num_threads` threads.
    When `num_processes>0`, it creates a pool of `num_processes` processes, and each subprocess
    starts its own pool of `num_threads` threads.

    The optimal number of processes and threads depends on your machine's capabilities.
    We advise using 4 threads per camera with 0 processes. If the fps is not stable, try increasing
    or lowering the number of threads. If it is still not stable, try using 1 subprocess, or more.
    See the usage sketch at the bottom of this module.
    """

    def __init__(self, num_processes: int = 0, num_threads: int = 1):
        self.num_processes = num_processes
        self.num_threads = num_threads
        self.queue = None
        self.threads = []
        self.processes = []
        self._stopped = False

        if num_threads <= 0 and num_processes <= 0:
            raise ValueError("Number of threads and processes must be greater than zero.")

        if self.num_processes == 0:
            # Use threading
            self.queue = queue.Queue()
            for _ in range(self.num_threads):
                t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
                t.daemon = True
                t.start()
                self.threads.append(t)
        else:
            # Use multiprocessing
            self.queue = multiprocessing.JoinableQueue()
            for _ in range(self.num_processes):
                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
                p.daemon = True
                p.start()
                self.processes.append(p)

    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
        if isinstance(image, torch.Tensor):
            # Convert tensor to numpy array to minimize main process time
            image = image.cpu().numpy()
        self.queue.put((image, fpath))

    def wait_until_done(self):
        self.queue.join()

    def stop(self):
        if self._stopped:
            return

        if self.num_processes == 0:
            # Send one sentinel per worker thread so that every thread exits its loop.
            for _ in self.threads:
                self.queue.put(None)
            for t in self.threads:
                t.join()
        else:
            # Send one sentinel per worker thread across all subprocesses.
            num_nones = self.num_processes * self.num_threads
            for _ in range(num_nones):
                self.queue.put(None)
            for p in self.processes:
                p.join()
                if p.is_alive():
                    p.terminate()
            self.queue.close()
            self.queue.join_thread()

        self._stopped = True
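

if __name__ == "__main__":
    # Minimal usage sketch of `AsyncImageWriter`, not part of the library API: it shows the
    # thread-based configuration advised in the class docstring (0 processes, 4 threads).
    # The temporary output directory and the random frames used as stand-ins for camera
    # images are assumptions made for this illustration only.
    import tempfile

    output_dir = Path(tempfile.mkdtemp())
    writer = AsyncImageWriter(num_processes=0, num_threads=4)
    try:
        for i in range(10):
            # Fake (H, W, C) uint8 frame standing in for a camera image.
            frame = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
            writer.save_image(frame, output_dir / f"frame_{i:04d}.png")
        # Block until every queued image has been written to disk.
        writer.wait_until_done()
    finally:
        writer.stop()
    print(f"Wrote {len(list(output_dir.glob('*.png')))} images to {output_dir}")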