import argparse
import json
import os
import time

import boto3
import torch
from omegaconf import OmegaConf

from inference import inference_process  # Ensure inference.py is in the same directory or update the import path

# Path that predict_fn reports as the generated video.
# NOTE(review): this mirrors the original hard-coded path; confirm it matches
# where inference_process actually writes its output.
DEFAULT_LOCAL_OUTPUT = '.cache/output.mp4'

# Local destinations for inputs that are fetched from S3.
LOCAL_SOURCE_IMAGE = "/opt/ml/input/data/source_image.jpg"
LOCAL_DRIVING_AUDIO = "/opt/ml/input/data/driving_audio.wav"

# Upper bound on how long output_fn waits for the output file to appear.
# The original polling loop had no limit and could spin forever.
OUTPUT_WAIT_TIMEOUT_S = 300
OUTPUT_POLL_INTERVAL_S = 1

_S3_SCHEME = "s3://"


def _parse_s3_uri(s3_path):
    """Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    A path without the ``s3://`` scheme is interpreted as ``bucket/key``,
    matching the previous lenient behaviour.

    Raises:
        ValueError: if no non-empty bucket and key can be extracted.
    """
    # Strip only the leading scheme; str.replace would also remove any
    # "s3://" that happens to appear inside the key.
    if s3_path.startswith(_S3_SCHEME):
        remainder = s3_path[len(_S3_SCHEME):]
    else:
        remainder = s3_path
    bucket, sep, key = remainder.partition("/")
    if not bucket or not sep or not key:
        raise ValueError(f"Invalid S3 URI (expected s3://bucket/key): {s3_path!r}")
    return bucket, key


def download_from_s3(s3_path, local_path):
    """Download the S3 object at *s3_path* to *local_path*.

    The parent directory of *local_path* is created if it does not exist.
    """
    bucket, key = _parse_s3_uri(s3_path)
    parent = os.path.dirname(local_path)
    if parent:
        # The target directory must exist before download_file writes into it.
        os.makedirs(parent, exist_ok=True)
    s3 = boto3.client('s3')
    s3.download_file(bucket, key, local_path)


def upload_to_s3(local_path, s3_path):
    """Upload the local file *local_path* to the S3 location *s3_path*."""
    bucket, key = _parse_s3_uri(s3_path)
    s3 = boto3.client('s3')
    s3.upload_file(local_path, bucket, key)


def model_fn(model_dir):
    """Return *model_dir* unchanged as the "model" object.

    No model is loaded here, and predict_fn ignores the value it receives.
    Presumably inference_process loads its own weights; confirm.
    """
    return model_dir


def _resolve_input(path, local_path):
    """Return a local path for *path*, downloading it to *local_path* if it is an S3 URI."""
    if path.startswith(_S3_SCHEME):
        download_from_s3(path, local_path)
        return local_path
    return path


def input_fn(request_body, content_type='application/json'):
    """Parse a JSON request into ``(args, s3_output)``.

    The request must contain ``source_image`` and ``driving_audio`` (local
    paths or ``s3://`` URIs). It may also contain ``config`` (a dict turned
    into an ``argparse.Namespace``) and ``output`` (the destination returned
    alongside the args). S3 inputs are downloaded to fixed local paths.

    Raises:
        ValueError: if *content_type* is not JSON.
        KeyError: if ``source_image`` or ``driving_audio`` is missing.
    """
    # Accept media-type parameters such as "application/json; charset=utf-8".
    media_type = (content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'application/json':
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(request_body)

    # Download source_image and driving_audio from S3 if necessary.
    input_data['source_image'] = _resolve_input(input_data['source_image'], LOCAL_SOURCE_IMAGE)
    input_data['driving_audio'] = _resolve_input(input_data['driving_audio'], LOCAL_DRIVING_AUDIO)

    # `or {}` also tolerates an explicit "config": null in the request.
    config = dict(input_data.get('config') or {})
    # Forward the resolved (possibly downloaded) paths into the args.
    # Previously they were written to input_data and then discarded.
    # NOTE(review): assumes inference_process reads args.source_image and
    # args.driving_audio; confirm against inference.py.
    config['source_image'] = input_data['source_image']
    config['driving_audio'] = input_data['driving_audio']
    args = argparse.Namespace(**config)

    s3_output = input_data.get('output', None)
    return args, s3_output


def predict_fn(input_data, model):
    """Run inference_process on the parsed args.

    Returns:
        ``(local_output_path, s3_output)``. *model* is unused.
    """
    args, s3_output = input_data
    inference_process(args)
    return DEFAULT_LOCAL_OUTPUT, s3_output


def output_fn(prediction, content_type='application/json'):
    """Wait for the output video, upload it to S3 if requested, and return a JSON status.

    The file is uploaded only when ``s3_output`` is an ``s3://`` URI.

    Raises:
        TimeoutError: if the output file does not appear within
            OUTPUT_WAIT_TIMEOUT_S seconds.
    """
    local_output, s3_output = prediction

    # Wait (bounded) for the output file to be created.
    deadline = time.monotonic() + OUTPUT_WAIT_TIMEOUT_S
    while not os.path.exists(local_output):
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Output file {local_output!r} was not created within "
                f"{OUTPUT_WAIT_TIMEOUT_S} seconds"
            )
        time.sleep(OUTPUT_POLL_INTERVAL_S)

    # Upload the result; the original comment promised this but never did it.
    if isinstance(s3_output, str) and s3_output.startswith(_S3_SCHEME):
        upload_to_s3(local_output, s3_output)

    return json.dumps({'status': 'completed', 's3_output': s3_output})