Update README.md
Browse files
README.md
CHANGED
@@ -59,23 +59,25 @@ We provide pretrained CHATS checkpoints on SDXL for easy download and evaluation
|
|
59 |
## 🛠️ Quick Start
|
60 |
|
61 |
```python
|
|
|
62 |
from pipeline import ChatsSDXLPipeline
|
63 |
|
64 |
# Load CHATS-SDXL pipeline
|
65 |
pipe = ChatsSDXLPipeline.from_pretrained(
|
66 |
"AIDC-AI/CHATS",
|
67 |
torch_dtype=torch.bfloat16
|
68 |
-
|
|
|
69 |
|
70 |
# Generate images
|
71 |
images = pipe(
|
72 |
-
|
73 |
num_inference_steps=50,
|
74 |
guidance_scale=5,
|
75 |
seed=0
|
76 |
)
|
77 |
|
78 |
-
#
|
79 |
for i, img in enumerate(images):
|
80 |
img.save(f"output_{i}.png")
|
81 |
```
|
@@ -107,8 +109,8 @@ accelerate launch --config_file=config/ac_ds_8gpu_zero0.yaml train.py \
|
|
107 |
### Args:
|
108 |
- config_file: This DeepSpeed parameter allows you to specify the configuration file. If you wish to adjust the number of GPUs used for training, simply change the value of **num_processes** in the ac_ds_xgpu_zero0.yaml file to reflect the desired GPU count.
|
109 |
- pretrained_model_name_or_path: name or patch of unet model to load
|
110 |
-
- pretrained_vae_model_name_or_path:
|
111 |
-
- max_train_steps: max steps to
|
112 |
- output: output dir
|
113 |
- dataset_name: the huggingface sufix of the selected dataset (e.g. OIP)
|
114 |
|
@@ -137,4 +139,4 @@ The project is released under Apache License 2.0 (http://www.apache.org/licenses
|
|
137 |
|
138 |
## 🚨 Disclaimer
|
139 |
|
140 |
-
We used compliance checking algorithms during the training process, to ensure the compliance of the trained model to the best of our ability. Due to complex data and the diversity of language model usage scenarios, we cannot guarantee that the model is completely free of copyright issues or improper content. If you believe anything infringes on your rights or generates improper content, please contact us, and we will promptly address the matter.
|
|
|
59 |
## 🛠️ Quick Start
|
60 |
|
61 |
```python
|
62 |
+
import torch
|
63 |
from pipeline import ChatsSDXLPipeline
|
64 |
|
65 |
# Load CHATS-SDXL pipeline
|
66 |
pipe = ChatsSDXLPipeline.from_pretrained(
|
67 |
"AIDC-AI/CHATS",
|
68 |
torch_dtype=torch.bfloat16
|
69 |
+
)
|
70 |
+
pipe.to("cuda")
|
71 |
|
72 |
# Generate images
|
73 |
images = pipe(
|
74 |
+
prompt=["A serene mountain lake at sunset"],
|
75 |
num_inference_steps=50,
|
76 |
guidance_scale=5,
|
77 |
seed=0
|
78 |
)
|
79 |
|
80 |
+
# Save outputs
|
81 |
for i, img in enumerate(images):
|
82 |
img.save(f"output_{i}.png")
|
83 |
```
|
|
|
109 |
### Args:
|
110 |
- config_file: This DeepSpeed parameter allows you to specify the configuration file. If you wish to adjust the number of GPUs used for training, simply change the value of **num_processes** in the ac_ds_xgpu_zero0.yaml file to reflect the desired GPU count.
|
111 |
- pretrained_model_name_or_path: name or patch of unet model to load
|
112 |
+
- pretrained_vae_model_name_or_path: name or patch of vae model to load
|
113 |
+
- max_train_steps: max steps to train
|
114 |
- output: output dir
|
115 |
- dataset_name: the huggingface sufix of the selected dataset (e.g. OIP)
|
116 |
|
|
|
139 |
|
140 |
## 🚨 Disclaimer
|
141 |
|
142 |
+
We used compliance checking algorithms during the training process, to ensure the compliance of the trained model to the best of our ability. Due to complex data and the diversity of language model usage scenarios, we cannot guarantee that the model is completely free of copyright issues or improper content. If you believe anything infringes on your rights or generates improper content, please contact us, and we will promptly address the matter.
|