File size: 3,577 Bytes
2792168
 
 
6cd828a
8d252a1
 
 
 
2792168
714eb56
2792168
 
 
 
714eb56
 
2792168
 
 
 
 
714eb56
 
2792168
 
 
 
 
714eb56
 
2792168
 
 
 
 
 
 
 
 
 
8d252a1
 
 
21380b4
2792168
 
8d252a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2792168
 
 
 
 
 
6cd828a
 
 
 
 
2792168
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import replicate
import requests
import os
import base64
from dotenv import load_dotenv

# Load environment variables from a .env file at module import time, before
# any Replicate call is made.
# NOTE(review): presumably this is how REPLICATE_API_TOKEN is normally
# provided to the replicate client — confirm against deployment setup.
load_dotenv()

def _guess_image_mime(img_data):
    """Return a best-effort MIME type for raw image bytes from their magic header.

    Recognises PNG, JPEG, GIF and WebP signatures; anything else falls back to
    ``image/png``, which is what this module previously used for every image.
    """
    # bytes() so bytearray/memoryview inputs also support startswith().
    header = bytes(img_data[:12])
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if header.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if header.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
        return "image/webp"
    return "image/png"


def _bytes_to_data_url(img_data):
    """Encode raw image bytes as a base64 ``data:`` URL for the model input."""
    encoded = base64.b64encode(img_data).decode("ascii")
    return f"data:{_guess_image_mime(img_data)};base64,{encoded}"


def _save_output(output, output_path, timeout):
    """Write a single model output item to ``output_path``.

    Supported shapes: an object with ``save(path)``, a URL string (downloaded
    with ``requests``), or an object with ``read()`` returning bytes.

    Raises:
        ValueError: If the output has none of the supported shapes, or its
            ``read()`` fails.
        requests.HTTPError: If downloading a URL output returns a bad status.
    """
    if hasattr(output, "save"):
        output.save(output_path)
        return

    if isinstance(output, str):
        url = output
        # Some responses may omit the scheme; assume HTTPS in that case.
        if not url.startswith(("http://", "https://")):
            url = "https://" + url.lstrip("/")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raise an exception for bad status codes
        content = response.content
    elif hasattr(output, "read"):
        # Read fully BEFORE opening the destination, so a failed read does not
        # leave behind an empty/truncated output file.
        try:
            content = output.read()
        except Exception as e:
            raise ValueError(f"Unable to process model output: {str(e)}") from e
    else:
        raise ValueError(
            f"Unable to process model output of type {type(output).__name__}"
        )

    with open(output_path, "wb") as file:
        file.write(content)


def virtual_try_on(garm_img_data, human_img_data, garment_des, *,
                   output_path="output.jpg", timeout=60):
    """
    Perform virtual try-on using the IDM-VTON model on Replicate.

    Args:
        garm_img_data (bytes): Binary data of the garment image
        human_img_data (bytes): Binary data of the human image
        garment_des (str): Description of the garment
        output_path (str): Where to write the resulting image
            (default ``"output.jpg"``).
        timeout (float): Seconds to wait when downloading a URL output
            (default 60); passed to ``requests.get``.

    Returns:
        str: Path to the saved output image

    Raises:
        ValueError: If the model returns an empty or unprocessable output.
        requests.HTTPError: If downloading a URL output fails.
        Any exception raised by ``replicate.run`` is propagated unchanged.
    """
    # Prepare input for the model; images are sent inline as data URLs.
    input_data = {
        "garm_img": _bytes_to_data_url(garm_img_data),
        "human_img": _bytes_to_data_url(human_img_data),
        "garment_des": garment_des,
    }

    print("Sending request to IDM-VTON model...")
    # Run the model on Replicate (pinned model version).
    output = replicate.run(
        "cuuupid/idm-vton:c871bb9b046607b680449ecbae55fd8c6d945e0a1948644bf2361b3d021d3ff4",
        input=input_data,
    )

    # A list/tuple response may hold several outputs; only the first is used.
    # Its element may be a URL string OR a file-like object, so it goes through
    # the same dispatch as a single output.
    if isinstance(output, (list, tuple)):
        output = output[0] if output else None

    if not output:
        raise ValueError("Empty response received from the model")

    _save_output(output, output_path, timeout)

    print(f"Virtual try-on complete! Output saved to {output_path}")
    return output_path

# Example usage / command-line entry point:
#   python vton.py <garment_image_path> <person_image_path> <garment_description>
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 4:
        script = os.path.basename(sys.argv[0])
        print("Please use this module by importing it in your application.")
        print("Example usage:")
        print("from vton import virtual_try_on")
        print("output_path = virtual_try_on(garment_image_bytes, person_image_bytes, garment_description)")
        print("\nNote: both image arguments are the raw bytes of local image files, e.g. open(path, 'rb').read().")
        print(f"\nOr run it directly: python {script} <garment_image_path> <person_image_path> <garment_description>")
        sys.exit(2)

    # Check if REPLICATE_API_TOKEN is set. load_dotenv() has already run at
    # import time, so a token defined in .env is counted here.
    if "REPLICATE_API_TOKEN" not in os.environ:
        print("Warning: REPLICATE_API_TOKEN environment variable is not set.")
        print("Please set it using: export REPLICATE_API_TOKEN='your_token_here'")

    garment_path, person_path, garment_des = sys.argv[1:4]

    # Top-level boundary: report any failure (unreadable file, API or network
    # error) and exit non-zero so calling scripts can detect it.
    try:
        with open(garment_path, "rb") as garment_file:
            garm_img = garment_file.read()
        with open(person_path, "rb") as person_file:
            human_img = person_file.read()
        output_path = virtual_try_on(garm_img, human_img, garment_des)
    except Exception as e:
        print(f"Error occurred: {e}")
        sys.exit(1)

    print(f"Success! Try-on image saved to {output_path}")