# hallo / scripts / sagemaker.py
# Yohai Rosen
# test
# 2a0635e
import os
import json
import boto3
import torch
import argparse
import time
from omegaconf import OmegaConf
from inference import inference_process # Ensure inference.py is in the same directory or update the import path
def download_from_s3(s3_path, local_path):
    """Download a single S3 object to a local file.

    Args:
        s3_path: Object location, normally ``s3://bucket/key``. A bare
            ``bucket/key`` string (no scheme) is accepted as well.
        local_path: Destination file path. Missing parent directories are
            created before the download starts.

    Raises:
        ValueError: If ``s3_path`` does not contain both a bucket and a key.
    """
    scheme = "s3://"
    # Strip the scheme only at the front; str.replace would also rewrite any
    # "s3://" that happens to occur inside the object key.
    location = s3_path[len(scheme):] if s3_path.startswith(scheme) else s3_path
    bucket, _, key = location.partition("/")
    if not bucket or not key:
        raise ValueError(
            f"Invalid S3 path (expected s3://bucket/key): {s3_path!r}"
        )

    # The download does not create the destination directory on its own.
    parent_dir = os.path.dirname(local_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    boto3.client('s3').download_file(bucket, key, local_path)
def upload_to_s3(local_path, s3_path):
    """Upload a local file to S3.

    Args:
        local_path: Path of the file to upload.
        s3_path: Destination, normally ``s3://bucket/key``. A bare
            ``bucket/key`` string (no scheme) is accepted as well.

    Raises:
        ValueError: If ``s3_path`` does not contain both a bucket and a key.
    """
    scheme = "s3://"
    # Strip the scheme only at the front; str.replace would also rewrite any
    # "s3://" that happens to occur inside the object key.
    location = s3_path[len(scheme):] if s3_path.startswith(scheme) else s3_path
    bucket, _, key = location.partition("/")
    if not bucket or not key:
        raise ValueError(
            f"Invalid S3 path (expected s3://bucket/key): {s3_path!r}"
        )

    boto3.client('s3').upload_file(local_path, bucket, key)
def model_fn(model_dir):
    """Return the model directory unchanged as the "model" object.

    Nothing is loaded here; the value returned is simply ``model_dir``.
    An earlier revision wrote a placeholder ``config.json`` into
    ``model_dir`` when none existed; that step is disabled.

    Args:
        model_dir: Directory path handed to this handler.

    Returns:
        The same ``model_dir`` value that was passed in.
    """
    return model_dir
def input_fn(request_body, content_type='application/json'):
    """Parse a JSON request and stage its input files locally.

    The request must contain ``source_image`` and ``driving_audio``. Either
    may be an ``s3://`` URI, in which case the object is downloaded to a
    fixed path under ``/opt/ml/input/data`` and that local path is used
    instead. Optional keys are ``config`` (a mapping of extra inference
    arguments) and ``output`` (where the result should be uploaded).

    Args:
        request_body: JSON document (``str`` or ``bytes``).
        content_type: Must be ``application/json``.

    Returns:
        A ``(args, s3_output)`` tuple: an ``argparse.Namespace`` built from
        ``config`` plus the resolved ``source_image`` / ``driving_audio``
        paths, and the value of the request's ``output`` key (or ``None``).

    Raises:
        ValueError: If ``content_type`` is not ``application/json``.
        KeyError: If ``source_image`` or ``driving_audio`` is missing.
    """
    if content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(request_body)

    # Read both required keys up front so a malformed request fails before
    # any download is attempted.
    source_image = input_data['source_image']
    driving_audio = input_data['driving_audio']

    local_source_image = "/opt/ml/input/data/source_image.jpg"
    local_driving_audio = "/opt/ml/input/data/driving_audio.wav"

    if source_image.startswith("s3://"):
        download_from_s3(source_image, local_source_image)
        source_image = local_source_image
    if driving_audio.startswith("s3://"):
        download_from_s3(driving_audio, local_driving_audio)
        driving_audio = local_driving_audio

    # Previously the resolved paths were written back into the request dict
    # and then dropped, because the namespace was built from 'config' alone.
    # Pass them through so the inference step can see the staged files.
    # Values given explicitly inside 'config' keep precedence, so requests
    # that already supplied these keys there behave as before.
    # NOTE(review): assumes inference_process reads args.source_image and
    # args.driving_audio -- confirm against inference.py.
    namespace_fields = dict(input_data.get('config') or {})
    namespace_fields.setdefault('source_image', source_image)
    namespace_fields.setdefault('driving_audio', driving_audio)
    args = argparse.Namespace(**namespace_fields)

    return args, input_data.get('output')
def predict_fn(input_data, model):
    """Run the inference process for one parsed request.

    Args:
        input_data: The ``(args, s3_output)`` pair; ``args`` is passed
            straight to ``inference_process``.
        model: Unused by this function.

    Returns:
        A ``(local_output, s3_output)`` tuple, where ``local_output`` is the
        fixed path ``.cache/output.mp4``.
    """
    cli_args, upload_target = input_data

    # Call the inference process
    inference_process(cli_args)

    # NOTE(review): the output location is hard-coded here; presumably this
    # is where inference_process writes its result -- confirm.
    output_video = '.cache/output.mp4'
    return output_video, upload_target
def output_fn(prediction, content_type='application/json'):
    """Wait for the output file, upload it to S3, and build the response.

    Args:
        prediction: The ``(local_output, s3_output)`` pair. ``s3_output``
            may be ``None`` or empty, in which case nothing is uploaded.
        content_type: Accepted for interface compatibility; not inspected.

    Returns:
        A JSON string with ``status`` set to ``"completed"`` and
        ``s3_output`` echoing the requested destination.
    """
    local_output, s3_output = prediction

    # Wait for the output file to be created.
    # NOTE(review): this poll has no timeout; if the file is never written
    # the call blocks indefinitely. Left as-is to preserve behavior.
    while not os.path.exists(local_output):
        time.sleep(1)

    # The upload step was missing before: the response reported the S3
    # destination even though nothing had been sent there.
    if s3_output:
        upload_to_s3(local_output, s3_output)

    return json.dumps({'status': 'completed', 's3_output': s3_output})