xiaozaa committed
Commit 1beac4e · 0 parents

first commit
.gitignore ADDED
@@ -0,0 +1,52 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
dist/
build/
*.egg-info/

# Virtual environments
venv/
env/
.env/
.venv/

# IDE specific files
.idea/
.vscode/
*.swp
*.swo

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
coverage.xml
*.cover

# Jupyter Notebook
.ipynb_checkpoints

# Local development settings
.env
.env.local

# Logs
*.log

# Database files
*.db
*.sqlite3

# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
README.md ADDED
@@ -0,0 +1,50 @@
# catvton-flux

An advanced virtual try-on solution that combines [CatVTON](https://arxiv.org/abs/2407.15886) (a concatenation-based, in-context virtual try-on approach) with the FLUX.1 Fill inpainting model for realistic and accurate clothing transfer.
## Showcase
| Original | Result |
|----------|--------|
| ![Original](example/person/1.jpg) | ![Result](example/result/1.png) |
| ![Original](example/person/00008_00.jpg) | ![Result](example/result/2.png) |
| ![Original](example/person/00008_00.jpg) | ![Result](example/result/3.png) |

## Model Weights
The model weights are trained on the [VITON-HD](https://github.com/shadow2496/VITON-HD) dataset.
🤗 [catvton-flux-alpha](https://huggingface.co/xiaozaa/catvton-flux-alpha)

## Prerequisites
```bash
conda create -n flux python=3.10
conda activate flux
pip install -r requirements.txt
```

## Usage

```bash
python tryon_inference.py \
  --image ./example/person/00008_00.jpg \
  --mask ./example/person/00008_00_mask.png \
  --garment ./example/garment/00034_00.jpg \
  --seed 42
```

## TODO
- [ ] Release the FID score
- [ ] Add Gradio demo
- [ ] Release updated weights with better performance

## Citation

```bibtex
@misc{chong2024catvton,
      title={CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models},
      author={Zheng Chong and Xiao Dong and Haoxiang Li and Shiyue Zhang and Wenqing Zhang and Xujie Zhang and Hanqing Zhao and Xiaodan Liang},
      year={2024},
      eprint={2407.15886},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## License
- The code is licensed under the MIT License.
- The model weights are subject to the same licenses as FLUX.1 Fill and VITON-HD.
example/garment/00034_00.jpg ADDED
example/garment/00035_00.jpg ADDED
example/garment/04564_00.jpg ADDED
example/person/00008_00.jpg ADDED
example/person/00008_00_mask.png ADDED
example/person/1.jpg ADDED
example/person/1_mask.png ADDED
example/result/1.png ADDED
example/result/2.png ADDED
example/result/3.png ADDED
requirements.txt ADDED
@@ -0,0 +1,98 @@
accelerate==0.30.1
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
attrs==24.2.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
coloredlogs==15.0.1
contourpy==1.2.1
cycler==0.12.1
datasets==2.21.0
deepspeed==0.14.4
dill==0.3.8
docker-pycreds==0.4.0
einops==0.8.0
filelock==3.15.4
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
hjson==3.1.0
huggingface-hub==0.24.5
humanfriendly==10.0
idna==3.7
importlib_metadata==8.2.0
Jinja2==3.1.4
kiwisolver==1.4.5
MarkupSafe==2.1.5
matplotlib==3.9.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.3
ninja==1.11.1.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.555.43
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
onnxruntime-gpu==1.18.1
opencv-python==4.10.0.84
optimum-quanto==0.2.4
packaging==24.1
pandas==2.2.2
pillow==10.4.0
platformdirs==4.2.2
protobuf==5.27.3
psutil==6.0.0
py-cpuinfo==9.0.0
pyarrow==17.0.0
pydantic==2.8.2
pydantic_core==2.20.1
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.2
regex==2024.7.24
requests==2.32.3
safetensors==0.4.4
sentencepiece==0.2.0
sentry-sdk==2.13.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
sympy==1.13.2
timm==1.0.8
tokenizers==0.19.1
torch==2.4.0
torchvision==0.19.0
tqdm==4.66.5
transformers==4.43.3
triton==3.0.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
wandb==0.17.6
xxhash==3.4.1
yarl==1.9.4
zipp==3.20.0
peft==0.13.2
bitsandbytes==0.44.1
prodigyopt
git+https://github.com/huggingface/diffusers.git
tryon_inference.py ADDED
@@ -0,0 +1,118 @@
import argparse
import torch
from diffusers.utils import load_image, check_min_version
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from torchvision import transforms

def run_inference(
    image_path,
    mask_path,
    garment_path,
    output_garment_path=None,
    output_tryon_path='flux_inpaint_tryon.png',
    size=(576, 768),
    num_steps=50,
    guidance_scale=30,
    seed=42,
    pipe=None
):
    # Build the pipeline unless the caller passes one in (lets callers reuse it)
    if pipe is None:
        transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/catvton-flux-alpha",
            torch_dtype=torch.bfloat16
        )
        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=transformer,
            torch_dtype=torch.bfloat16
        ).to("cuda")
    else:
        pipe.to("cuda")

    pipe.transformer.to(torch.bfloat16)

    # Person/garment images are normalized to [-1, 1]; the mask stays in [0, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Load and resize all inputs to a common size
    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)
    garment = load_image(garment_path).convert("RGB").resize(size)

    image_tensor = transform(image)
    mask_tensor = mask_transform(mask)[:1]  # Take only the first channel
    garment_tensor = transform(garment)

    # Concatenate garment and person along the width; mask only the person half
    # so the garment half is preserved as conditioning
    inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
    garment_mask = torch.zeros_like(mask_tensor)
    extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)

    prompt = "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
             "[IMAGE1] Detailed product shot of a clothing; " \
             "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."

    generator = torch.Generator(device="cuda").manual_seed(seed)

    result = pipe(
        height=size[1],
        width=size[0] * 2,  # Double width: garment on the left, try-on on the right
        image=inpaint_image,
        mask_image=extended_mask,
        num_inference_steps=num_steps,
        generator=generator,
        max_sequence_length=512,
        guidance_scale=guidance_scale,
        prompt=prompt,
    ).images[0]

    # Split the side-by-side output back into its two halves and save them
    width = size[0]
    garment_result = result.crop((0, 0, width, size[1]))
    tryon_result = result.crop((width, 0, width * 2, size[1]))

    if output_garment_path is not None:
        garment_result.save(output_garment_path)
    tryon_result.save(output_tryon_path)
    return garment_result, tryon_result

def main():
    parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
    parser.add_argument('--image', required=True, help='Path to the model image')
    parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
    parser.add_argument('--garment', required=True, help='Path to the garment image')
    parser.add_argument('--output-garment', default='flux_inpaint_garment.png', help='Output path for garment result')
    parser.add_argument('--output-tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
    parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--guidance-scale', type=float, default=30, help='Guidance scale')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')

    args = parser.parse_args()

    check_min_version("0.30.2")

    garment_result, tryon_result = run_inference(
        image_path=args.image,
        mask_path=args.mask,
        garment_path=args.garment,
        output_garment_path=args.output_garment,
        output_tryon_path=args.output_tryon,
        num_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed
    )
    print("Successfully saved garment and try-on images")

if __name__ == "__main__":
    main()