Update autorun_lora_gradio.py

autorun_lora_gradio.py  CHANGED  (+188 -70)

@@ -1,59 +1,193 @@

Old version (lines prefixed "-" were removed; the rendering truncates the content of several removed lines):

  import os
  import uuid
- import (truncated)
  from pathlib import Path
- from (truncated)
- from (truncated)

- # (truncated)
  REPO_ID = "rahul7star/ohamlab"
  FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
  CONCEPT_SENTENCE = "ohamlab style"
  LORA_NAME = "ohami_filter_autorun"

  def auto_run_lora_from_repo():
- (old lines 15-55: body of the old handler, 41 removed lines truncated in the rendering)
- more_advanced_options = """\
  training:
    seed: 42
    precision: bf16

@@ -62,30 +196,14 @@ augmentation:

    flip: true
    color_jitter: true
  """

- (old lines 66-67: truncated)
-         lora_name=LORA_NAME,
-         concept_sentence=CONCEPT_SENTENCE,
-         steps=steps,
-         lr=lr,
-         rank=rank,
-         model_to_train=model_to_train,
-         low_vram=low_vram,
-         dataset_folder=dataset_path,
-         sample_1=sample_1,
-         sample_2=sample_2,
-         sample_3=sample_3,
-         use_more_advanced_options=use_more_advanced_options,
-         more_advanced_options=more_advanced_options
-     )
-
- # Gradio UI
- with gr.Blocks(title="LoRA Autorun from HF Repo") as demo:
-     gr.Markdown("# 🚀 Auto Run LoRA from Hugging Face Repo")
-     output = gr.Textbox(label="Training Status", lines=3)
-     run_button = gr.Button("Run Training from HF Repo")
-     run_button.click(fn=auto_run_lora_from_repo, outputs=output)

  if __name__ == "__main__":
- (old line 91: truncated)

New version (full file after this commit):

import os
import uuid
import yaml
import json
import shutil
import torch
from pathlib import Path
from PIL import Image
from fastapi import FastAPI
from fastapi.responses import JSONResponse
# hf_hub_download fetches a single named file and does not accept
# allow_patterns; snapshot_download is the API that matches the usage below.
from huggingface_hub import snapshot_download, whoami

# ========== CONFIGURATION ==========
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"

# ========== FASTAPI APP ==========
app = FastAPI()

# ========== HELPERS ==========
def create_dataset(images, *captions):
    # Copy the images into a fresh folder and write a metadata.jsonl
    # pairing each copied file with its caption.
    destination_folder = f"datasets_{uuid.uuid4()}"
    os.makedirs(destination_folder, exist_ok=True)

    jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")
    with open(jsonl_file_path, "a") as jsonl_file:
        for index, image in enumerate(images):
            new_image_path = shutil.copy(str(image), destination_folder)
            caption = captions[index]
            file_name = os.path.basename(new_image_path)
            data = {"file_name": file_name, "prompt": caption}
            jsonl_file.write(json.dumps(data) + "\n")

    return destination_folder

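For reference, create_dataset() emits one JSON object per image in metadata.jsonl, pairing the copied file with its caption. A sketch of two resulting lines, with hypothetical file names:

{"file_name": "img_001.jpg", "prompt": "Autogenerated caption for img_001 in the ohamlab style [trigger]"}
{"file_name": "img_002.jpg", "prompt": "Autogenerated caption for img_002 in the ohamlab style [trigger]"}
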
def recursive_update(d, u):
    for k, v in u.items():
        if isinstance(v, dict) and v:
            d[k] = recursive_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d

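recursive_update() is a plain recursive deep merge: nested dicts are merged key by key, and any other value overwrites. A quick illustration with hypothetical values:

base = {"train": {"steps": 1000, "lr": 4e-4}, "trigger_word": "ohamlab style"}
override = {"train": {"lr": 1e-4}}
recursive_update(base, override)
# -> {"train": {"steps": 1000, "lr": 1e-4}, "trigger_word": "ohamlab style"}
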
def start_training(
    lora_name,
    concept_sentence,
    steps,
    lr,
    rank,
    model_to_train,
    low_vram,
    dataset_folder,
    sample_1,
    sample_2,
    sample_3,
    use_more_advanced_options,
    more_advanced_options,
):
    # Push to the Hub only when a valid token is available.
    try:
        user = whoami()
        username = user.get("name", "anonymous")
        push_to_hub = True
    except Exception:
        username = "anonymous"
        push_to_hub = False

    slugged_lora_name = lora_name.replace(" ", "_").lower()

    # Base config for FLUX.1-dev LoRA training
    config = {
        "config": {
            "name": slugged_lora_name,
            "process": [
                {
                    "model": {
                        "low_vram": low_vram,
                        "is_flux": True,
                        "quantize": True,
                        "name_or_path": "black-forest-labs/FLUX.1-dev"
                    },
                    "network": {
                        "linear": rank,
                        "linear_alpha": rank,
                        "type": "lora"
                    },
                    "train": {
                        "steps": steps,
                        "lr": lr,
                        "skip_first_sample": True,
                        "batch_size": 1,
                        "dtype": "bf16",
                        "gradient_accumulation_steps": 1,
                        "gradient_checkpointing": True,
                        "noise_scheduler": "flowmatch",
                        "optimizer": "adamw8bit",
                        "ema_config": {
                            "use_ema": True,
                            "ema_decay": 0.99
                        }
                    },
                    "datasets": [
                        {"folder_path": dataset_folder}
                    ],
                    "save": {
                        "dtype": "float16",
                        "save_every": 10000,
                        "push_to_hub": push_to_hub,
                        "hf_repo_id": f"{username}/{slugged_lora_name}",
                        "hf_private": True,
                        "max_step_saves_to_keep": 4
                    },
                    "sample": {
                        "guidance_scale": 3.5,
                        "sample_every": steps,
                        "sample_steps": 28,
                        "width": 1024,
                        "height": 1024,
                        "walk_seed": True,
                        "seed": 42,
                        "sampler": "flowmatch",
                        "prompts": [p for p in [sample_1, sample_2, sample_3] if p]
                    },
                    "trigger_word": concept_sentence
                }
            ]
        }
    }

    # Apply advanced YAML overrides, if any, onto process[0]
    if use_more_advanced_options and more_advanced_options:
        advanced_config = yaml.safe_load(more_advanced_options)
        config["config"]["process"][0] = recursive_update(config["config"]["process"][0], advanced_config)

    # Save YAML config
    os.makedirs("tmp_configs", exist_ok=True)
    config_path = f"tmp_configs/{uuid.uuid4()}_{slugged_lora_name}.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)

    # Simulate training: no trainer is invoked here; the config is only
    # written to disk and echoed.
    print(f"[INFO] Starting training with config: {config_path}")
    print(json.dumps(config, indent=2))
    return f"Training started successfully with config: {config_path}"

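One caveat worth noting: the advanced YAML passed below uses training: and augmentation: as top-level keys, but recursive_update() merges the override into process[0], whose live block is named train:, so those keys land as unused siblings rather than changing the training settings. An override meant to take effect has to match the existing keys; a hypothetical example:

override_yaml = """\
train:
  steps: 500
sample:
  seed: 7
"""
# After recursive_update, process[0]["train"]["steps"] == 500
# and process[0]["sample"]["seed"] == 7.
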
# ========== MAIN ENDPOINT ==========
@app.post("/train-from-hf")
def auto_run_lora_from_repo():
    try:
        local_dir = Path(f"/tmp/{LORA_NAME}-{uuid.uuid4()}")
        os.makedirs(local_dir, exist_ok=True)

        # Pull only the image files from the dataset folder in the repo.
        # snapshot_download keeps the repo's folder layout under local_dir;
        # hf_hub_download cannot be used here because it fetches a single
        # named file and rejects allow_patterns.
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="dataset",
            local_dir=local_dir,
            etag_timeout=10,
            allow_patterns=[
                f"{FOLDER_IN_REPO}/*.jpg",
                f"{FOLDER_IN_REPO}/*.jpeg",
                f"{FOLDER_IN_REPO}/*.png",
            ],
        )

        image_dir = local_dir / FOLDER_IN_REPO
        image_paths = (
            list(image_dir.rglob("*.jpg"))
            + list(image_dir.rglob("*.jpeg"))
            + list(image_dir.rglob("*.png"))
        )

        if not image_paths:
            return JSONResponse(status_code=400, content={"error": "No images found in the HF repo folder."})

        captions = [
            f"Autogenerated caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]" for img in image_paths
        ]

        dataset_path = create_dataset(image_paths, *captions)

        result = start_training(
            lora_name=LORA_NAME,
            concept_sentence=CONCEPT_SENTENCE,
            steps=1000,
            lr=4e-4,
            rank=16,
            model_to_train="dev",
            low_vram=True,
            dataset_folder=dataset_path,
            sample_1=f"A stylized portrait using {CONCEPT_SENTENCE}",
            sample_2=f"A cat in the {CONCEPT_SENTENCE}",
            sample_3=f"A selfie processed in {CONCEPT_SENTENCE}",
            use_more_advanced_options=True,
            more_advanced_options="""
training:
  seed: 42
  precision: bf16

augmentation:
  flip: true
  color_jitter: true
"""
        )

        return {"message": result}

    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

# ========== FASTAPI RUNNER ==========
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
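
With the app served as above (uvicorn on port 8000), the endpoint takes no request body, so it can be exercised with a minimal client sketch, assuming the server is reachable at localhost:

import requests

# The endpoint reads its configuration from the hard-coded constants,
# so an empty POST is all that is needed; the handler runs synchronously,
# hence the generous timeout.
resp = requests.post("http://localhost:8000/train-from-hf", timeout=600)
print(resp.status_code, resp.json())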