# vton-002 / vton.py
# Committed by noumanjavaid
# Commit 8d252a1 (verified): "Update vton.py"
import replicate
import requests
import os
import base64
from dotenv import load_dotenv
# Load environment variables from a .env file into os.environ. This runs once,
# at import time, before any function below is called. The __main__ block
# checks for REPLICATE_API_TOKEN afterwards; presumably the replicate client
# reads that token for authentication (TODO confirm against replicate docs).
load_dotenv()
# Replicate model reference ("owner/name:version") for IDM-VTON, pinned to a
# single version hash.
_IDM_VTON_MODEL = (
    "cuuupid/idm-vton:"
    "c871bb9b046607b680449ecbae55fd8c6d945e0a1948644bf2361b3d021d3ff4"
)

# Leading "magic" bytes of common image formats. They are used to label the
# data URL with the real MIME type instead of always claiming PNG.
_IMAGE_SIGNATURES = (
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
)


def _guess_image_mime(img_data):
    """Return the MIME type implied by the header of *img_data*, defaulting to PNG."""
    header = bytes(img_data[:12])
    for signature, mime in _IMAGE_SIGNATURES:
        if header.startswith(signature):
            return mime
    # WebP is a RIFF container: b"RIFF" + 4-byte little-endian size + b"WEBP".
    if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
        return "image/webp"
    # Unknown format: keep the historical "image/png" label.
    return "image/png"


def _bytes_to_data_url(img_data):
    """Encode raw image bytes as a base64 ``data:`` URL with a sniffed MIME type."""
    encoded = base64.b64encode(img_data).decode("ascii")
    return f"data:{_guess_image_mime(img_data)};base64,{encoded}"


def _download(url, output_path, timeout):
    """Fetch *url* over HTTP(S) and write the response body to *output_path*.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within *timeout* seconds.
    """
    # Defensive repair for scheme-less or protocol-relative ("//host/...") URLs.
    if not url.startswith(("http://", "https://")):
        url = "https://" + url.lstrip("/")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with open(output_path, "wb") as fh:
        fh.write(response.content)


def _save_output(item, output_path, timeout):
    """Persist one model output item to *output_path*.

    *item* may be an object with ``save(path)``, a URL string, or a file-like
    object with ``read()``.

    Raises:
        ValueError: if a non-URL item cannot be read or written.
    """
    if hasattr(item, "save"):
        item.save(output_path)
    elif isinstance(item, str):
        _download(item, output_path, timeout)
    else:
        try:
            # Read before opening the file, so a failed read leaves no empty output behind.
            data = item.read()
            with open(output_path, "wb") as fh:
                fh.write(data)
        except Exception as e:
            raise ValueError(f"Unable to process model output: {e}") from e


def virtual_try_on(garm_img_data, human_img_data, garment_des, *,
                   output_path="output.jpg", timeout=60):
    """
    Perform virtual try-on using the IDM-VTON model on Replicate.

    Args:
        garm_img_data (bytes): Binary contents of the garment image file.
        human_img_data (bytes): Binary contents of the human image file.
        garment_des (str): Description of the garment.
        output_path (str): Where to write the result (default "output.jpg").
        timeout (float): Seconds to wait when downloading a URL result.

    Returns:
        str: Path to the saved output image (``output_path``).

    Raises:
        ValueError: if either image is empty, the model returns nothing, or a
            non-URL result cannot be read.
        requests.HTTPError: if downloading a URL result fails.
    """
    # Reject empty inputs before spending an API call on them.
    if not garm_img_data or not human_img_data:
        raise ValueError("Both garment and human image data must be non-empty bytes")

    input_data = {
        "garm_img": _bytes_to_data_url(garm_img_data),
        "human_img": _bytes_to_data_url(human_img_data),
        "garment_des": garment_des,
    }

    print("Sending request to IDM-VTON model...")
    output = replicate.run(_IDM_VTON_MODEL, input=input_data)

    if not output:
        raise ValueError("Empty response received from the model")

    # Some outputs arrive as a sequence; only the first item is used.
    if isinstance(output, (list, tuple)):
        output = output[0]

    _save_output(output, output_path, timeout)
    print(f"Virtual try-on complete! Output saved to {output_path}")
    return output_path
# Command-line entry point:
#   python vton.py <garment_image> <person_image> <garment_description>
# Without arguments it prints usage help for both import and CLI use.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 4:
        print("Please use this module by importing it in your application.")
        print("Example usage:")
        print("from vton import virtual_try_on")
        print("output_path = virtual_try_on(garment_bytes, person_bytes, garment_description)")
        print("\nNote: garment_bytes and person_bytes must be the raw contents of image files,")
        print("e.g. open('garment.jpg', 'rb').read() -- not file paths.")
        print("\nCommand-line usage: python vton.py <garment_image> <person_image> <garment_description>")
        sys.exit(1)

    # load_dotenv() has already run at import time, so a token in .env counts too.
    # Stop here rather than making a Replicate call without a token.
    if "REPLICATE_API_TOKEN" not in os.environ:
        print("Error: REPLICATE_API_TOKEN environment variable is not set.", file=sys.stderr)
        print("Please set it using: export REPLICATE_API_TOKEN='your_token_here'", file=sys.stderr)
        sys.exit(1)

    garment_path, person_path, garment_description = sys.argv[1:]
    try:
        # virtual_try_on expects image bytes, not paths.
        with open(garment_path, "rb") as fh:
            garment_bytes = fh.read()
        with open(person_path, "rb") as fh:
            person_bytes = fh.read()
        output_path = virtual_try_on(garment_bytes, person_bytes, garment_description)
        print(f"Success! Try-on image saved to {output_path}")
    except Exception as e:  # top-level boundary: report any failure, exit non-zero
        print(f"Error occurred: {e}", file=sys.stderr)
        sys.exit(1)