"""Unix-socket inference server.

Loads the text-to-image pipeline once, then listens on a Unix domain
socket for JSON-encoded ``TextToImageRequest`` messages and replies with
JPEG-encoded image bytes.
"""

from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod, remove
from os.path import abspath, exists
from pathlib import Path

import torch
from PIL.JpegImagePlugin import JpegImageFile

from pipeline import load_pipeline, infer
from pipelines.models import TextToImageRequest

# Socket lives one directory above this file; abspath so Listener gets a
# stable, absolute path regardless of the process's working directory.
SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")


def main():
    """Serve inference requests over the Unix socket until the peer disconnects.

    Protocol (per message, over one accepted connection):
      - receive: UTF-8 JSON bytes parsed as ``TextToImageRequest``
      - respond: JPEG-encoded image bytes

    The per-request seed is applied via ``generator.manual_seed`` so results
    are reproducible for a given request. Returns when the client closes the
    connection (``EOFError`` on receive).
    """
    print("pipeline loading")
    pipeline = load_pipeline()
    # One reusable generator on the pipeline's device; reseeded per request.
    generator = torch.Generator(pipeline.device)

    print(f"loaded the pipeline, creating socket at '{SOCKET}'")

    # Remove a stale socket file left over from a previous run, otherwise
    # Listener would fail to bind.
    if exists(SOCKET):
        remove(SOCKET)

    with Listener(SOCKET) as listener:
        # Open permissions so the (possibly differently-privileged) client
        # process can connect.
        chmod(SOCKET, 0o777)

        print("waiting...")

        # NOTE(review): only one connection is ever accepted; after the client
        # disconnects the server exits — presumably intentional (one client
        # per server lifetime). Confirm against the caller.
        with listener.accept() as connection:
            print("accepted")

            while True:
                try:
                    request = TextToImageRequest.model_validate_json(
                        connection.recv_bytes().decode("utf-8")
                    )
                except EOFError:
                    # Client closed the connection — normal shutdown path.
                    print("Inference socket exiting")
                    return

                image = infer(request, pipeline, generator.manual_seed(request.seed))

                # Encode the PIL image as JPEG in memory and ship the raw bytes.
                data = BytesIO()
                image.save(data, format=JpegImageFile.format)
                connection.send_bytes(data.getvalue())


if __name__ == "__main__":
    main()