File size: 2,662 Bytes
2a0635e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import json
import boto3
import torch
import argparse
import time
from omegaconf import OmegaConf

from inference import inference_process  # Ensure inference.py is in the same directory or update the import path

def download_from_s3(s3_path, local_path):
    """Download the S3 object at ``s3_path`` to ``local_path``.

    Args:
        s3_path: Object location as ``s3://bucket/key`` (a bare
            ``bucket/key`` is also accepted).
        local_path: Destination file path. Missing parent directories are
            created before downloading.

    Raises:
        ValueError: If ``s3_path`` does not name both a bucket and a key.
    """
    # Strip only the leading scheme; str.replace would also mangle any
    # "s3://" that happens to appear inside the key.
    path = s3_path[len("s3://"):] if s3_path.startswith("s3://") else s3_path
    bucket, _, key = path.partition("/")
    if not bucket or not key:
        raise ValueError(f"Invalid S3 path (expected s3://bucket/key): {s3_path!r}")

    # download_file writes into the target directory but does not create it.
    parent = os.path.dirname(local_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    s3 = boto3.client('s3')
    s3.download_file(bucket, key, local_path)

def upload_to_s3(local_path, s3_path):
    """Upload the file at ``local_path`` to the S3 object ``s3_path``.

    Args:
        local_path: Path of the local file to upload.
        s3_path: Destination as ``s3://bucket/key`` (a bare ``bucket/key`` is
            also accepted).

    Raises:
        ValueError: If ``s3_path`` does not name both a bucket and a key.
    """
    # Strip only the leading scheme; str.replace would also mangle any
    # "s3://" that happens to appear inside the key.
    path = s3_path[len("s3://"):] if s3_path.startswith("s3://") else s3_path
    bucket, _, key = path.partition("/")
    if not bucket or not key:
        raise ValueError(f"Invalid S3 path (expected s3://bucket/key): {s3_path!r}")

    s3 = boto3.client('s3')
    s3.upload_file(local_path, bucket, key)

def model_fn(model_dir):
    """Model-loading hook (SageMaker-style ``model_fn``).

    No weights are loaded here: the model directory path is returned as-is
    and passed through to ``predict_fn`` as its ``model`` argument.

    Args:
        model_dir: Directory containing the model artifacts.

    Returns:
        ``model_dir`` unchanged.
    """
    return model_dir

def input_fn(request_body, content_type='application/json'):
    """Parse a JSON inference request into arguments for ``inference_process``.

    The request is expected to contain ``source_image`` and ``driving_audio``
    (local paths or ``s3://`` URIs). It may also contain a ``config`` mapping
    whose entries become attributes of the returned namespace, and an
    ``output`` value that is passed through as the S3 destination. S3 inputs
    are downloaded to fixed local paths under ``/opt/ml/input/data``.

    Args:
        request_body: JSON-encoded request body.
        content_type: MIME type of ``request_body``; only
            ``'application/json'`` is accepted.

    Returns:
        ``(args, s3_output)``. ``args`` is an ``argparse.Namespace`` built from
        ``config``, with ``source_image`` / ``driving_audio`` set to the
        resolved local paths. ``s3_output`` is the request's ``output`` value,
        or ``None``.

    Raises:
        ValueError: If ``content_type`` is not ``'application/json'``.
        KeyError: If ``source_image`` or ``driving_audio`` is missing.
    """
    if content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(request_body)

    source_image_path = input_data['source_image']
    driving_audio_path = input_data['driving_audio']

    local_source_image = "/opt/ml/input/data/source_image.jpg"
    local_driving_audio = "/opt/ml/input/data/driving_audio.wav"

    # Download S3 inputs so the pipeline only ever sees local files.
    if source_image_path.startswith("s3://"):
        download_from_s3(source_image_path, local_source_image)
        source_image_path = local_source_image
    if driving_audio_path.startswith("s3://"):
        download_from_s3(driving_audio_path, local_driving_audio)
        driving_audio_path = local_driving_audio

    # The resolved paths must be part of the namespace itself; otherwise the
    # downloaded files are never seen by inference_process.
    config = dict(input_data.get('config') or {})
    config['source_image'] = source_image_path
    config['driving_audio'] = driving_audio_path
    args = argparse.Namespace(**config)

    s3_output = input_data.get('output', None)
    return args, s3_output

def predict_fn(input_data, model):
    """Run the inference pipeline for one parsed request.

    Args:
        input_data: ``(args, s3_output)`` pair as produced by ``input_fn``.
        model: Value returned by ``model_fn``; not used here.

    Returns:
        ``('.cache/output.mp4', s3_output)``: the local path this handler
        expects the rendered video at, plus the S3 destination unchanged.
    """
    cli_args, destination = input_data

    inference_process(cli_args)

    # NOTE(review): this path is hard-coded; confirm it matches where
    # inference_process actually writes its video for the given args.
    local_video = '.cache/output.mp4'
    return local_video, destination

def output_fn(prediction, content_type='application/json', timeout=None, poll_interval=1.0):
    """Wait for the rendered video, upload it to S3 if requested, and report status.

    Args:
        prediction: ``(local_output, s3_output)`` pair as returned by
            ``predict_fn``. A falsy ``s3_output`` (e.g. ``None``) skips the
            upload.
        content_type: Requested response content type. Currently unused; the
            response is always JSON.
        timeout: Maximum number of seconds to wait for ``local_output`` to
            appear. ``None`` (the default) waits indefinitely.
        poll_interval: Seconds to sleep between existence checks.

    Returns:
        JSON string ``{"status": "completed", "s3_output": s3_output}``.

    Raises:
        TimeoutError: If ``timeout`` elapses before ``local_output`` exists.
    """
    local_output, s3_output = prediction

    deadline = None if timeout is None else time.monotonic() + timeout
    while not os.path.exists(local_output):
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Output file {local_output!r} was not created within {timeout} seconds"
            )
        time.sleep(poll_interval)

    # Upload before reporting success so the returned s3_output actually
    # refers to an object containing the rendered video.
    if s3_output:
        upload_to_s3(local_output, s3_output)

    return json.dumps({'status': 'completed', 's3_output': s3_output})