rahul7star commited on
Commit
65f8095
·
verified ·
1 Parent(s): b287482

Update autorun_lora_gradio.py

Browse files
Files changed (1) hide show
  1. autorun_lora_gradio.py +188 -70
autorun_lora_gradio.py CHANGED
@@ -1,59 +1,193 @@
1
  import os
2
  import uuid
3
- import gradio as gr
 
 
 
4
  from pathlib import Path
5
- from huggingface_hub import hf_hub_download
6
- from flux_train_ui import create_dataset, start_training # <-- update this import as needed
 
 
7
 
8
- # Constants
9
  REPO_ID = "rahul7star/ohamlab"
10
  FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
11
  CONCEPT_SENTENCE = "ohamlab style"
12
  LORA_NAME = "ohami_filter_autorun"
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def auto_run_lora_from_repo():
15
- local_dir = Path(f"/tmp/{LORA_NAME}-{uuid.uuid4()}")
16
- os.makedirs(local_dir, exist_ok=True)
17
-
18
- # Download at least one file to force HF to pull full folder
19
- hf_hub_download(
20
- repo_id=REPO_ID,
21
- repo_type="dataset",
22
- subfolder=FOLDER_IN_REPO,
23
- local_dir=local_dir,
24
- local_dir_use_symlinks=False,
25
- force_download=False,
26
- etag_timeout=10,
27
- allow_patterns=["*.jpg", "*.png", "*.jpeg"],
28
- )
29
-
30
- image_dir = local_dir / FOLDER_IN_REPO
31
- image_paths = list(image_dir.rglob("*.jpg")) + list(image_dir.rglob("*.jpeg")) + list(image_dir.rglob("*.png"))
32
-
33
- if not image_paths:
34
- raise gr.Error("No images found in the Hugging Face repo folder.")
35
-
36
- # Captions
37
- captions = [
38
- f"Generated image caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]" for img in image_paths
39
- ]
40
-
41
- # Create dataset
42
- dataset_path = create_dataset(image_paths, *captions)
43
-
44
- # Static prompts
45
- sample_1 = f"A stylized portrait using {CONCEPT_SENTENCE}"
46
- sample_2 = f"A cat in the {CONCEPT_SENTENCE}"
47
- sample_3 = f"A selfie processed in {CONCEPT_SENTENCE}"
48
-
49
- # Config
50
- steps = 1000
51
- lr = 4e-4
52
- rank = 16
53
- model_to_train = "dev"
54
- low_vram = True
55
- use_more_advanced_options = True
56
- more_advanced_options = """\
57
  training:
58
  seed: 42
59
  precision: bf16
@@ -62,30 +196,14 @@ augmentation:
62
  flip: true
63
  color_jitter: true
64
  """
 
 
 
65
 
66
- # Train
67
- return start_training(
68
- lora_name=LORA_NAME,
69
- concept_sentence=CONCEPT_SENTENCE,
70
- steps=steps,
71
- lr=lr,
72
- rank=rank,
73
- model_to_train=model_to_train,
74
- low_vram=low_vram,
75
- dataset_folder=dataset_path,
76
- sample_1=sample_1,
77
- sample_2=sample_2,
78
- sample_3=sample_3,
79
- use_more_advanced_options=use_more_advanced_options,
80
- more_advanced_options=more_advanced_options
81
- )
82
-
83
- # Gradio UI
84
- with gr.Blocks(title="LoRA Autorun from HF Repo") as demo:
85
- gr.Markdown("# 🚀 Auto Run LoRA from Hugging Face Repo")
86
- output = gr.Textbox(label="Training Status", lines=3)
87
- run_button = gr.Button("Run Training from HF Repo")
88
- run_button.click(fn=auto_run_lora_from_repo, outputs=output)
89
 
 
90
  if __name__ == "__main__":
91
- demo.launch(share=True)
 
 
1
  import os
2
  import uuid
3
+ import yaml
4
+ import json
5
+ import shutil
6
+ import torch
7
  from pathlib import Path
8
+ from PIL import Image
9
+ from fastapi import FastAPI
10
+ from fastapi.responses import JSONResponse
11
+ from huggingface_hub import hf_hub_download, whoami
12
 
13
+ # ========== CONFIGURATION ==========
14
  REPO_ID = "rahul7star/ohamlab"
15
  FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
16
  CONCEPT_SENTENCE = "ohamlab style"
17
  LORA_NAME = "ohami_filter_autorun"
18
 
19
+ # ========== FASTAPI APP ==========
20
+ app = FastAPI()
21
+
22
+ # ========== HELPERS ==========
23
+ def create_dataset(images, *captions):
24
+ destination_folder = f"datasets_{uuid.uuid4()}"
25
+ os.makedirs(destination_folder, exist_ok=True)
26
+
27
+ jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")
28
+ with open(jsonl_file_path, "a") as jsonl_file:
29
+ for index, image in enumerate(images):
30
+ new_image_path = shutil.copy(str(image), destination_folder)
31
+ caption = captions[index]
32
+ file_name = os.path.basename(new_image_path)
33
+ data = {"file_name": file_name, "prompt": caption}
34
+ jsonl_file.write(json.dumps(data) + "\n")
35
+
36
+ return destination_folder
37
+
38
+ def recursive_update(d, u):
39
+ for k, v in u.items():
40
+ if isinstance(v, dict) and v:
41
+ d[k] = recursive_update(d.get(k, {}), v)
42
+ else:
43
+ d[k] = v
44
+ return d
45
+
46
+ def start_training(
47
+ lora_name,
48
+ concept_sentence,
49
+ steps,
50
+ lr,
51
+ rank,
52
+ model_to_train,
53
+ low_vram,
54
+ dataset_folder,
55
+ sample_1,
56
+ sample_2,
57
+ sample_3,
58
+ use_more_advanced_options,
59
+ more_advanced_options,
60
+ ):
61
+ try:
62
+ user = whoami()
63
+ username = user.get("name", "anonymous")
64
+ push_to_hub = True
65
+ except:
66
+ username = "anonymous"
67
+ push_to_hub = False
68
+
69
+ slugged_lora_name = lora_name.replace(" ", "_").lower()
70
+
71
+ # Load base config
72
+ config = {
73
+ "config": {
74
+ "name": slugged_lora_name,
75
+ "process": [
76
+ {
77
+ "model": {
78
+ "low_vram": low_vram,
79
+ "is_flux": True,
80
+ "quantize": True,
81
+ "name_or_path": "black-forest-labs/FLUX.1-dev"
82
+ },
83
+ "network": {
84
+ "linear": rank,
85
+ "linear_alpha": rank,
86
+ "type": "lora"
87
+ },
88
+ "train": {
89
+ "steps": steps,
90
+ "lr": lr,
91
+ "skip_first_sample": True,
92
+ "batch_size": 1,
93
+ "dtype": "bf16",
94
+ "gradient_accumulation_steps": 1,
95
+ "gradient_checkpointing": True,
96
+ "noise_scheduler": "flowmatch",
97
+ "optimizer": "adamw8bit",
98
+ "ema_config": {
99
+ "use_ema": True,
100
+ "ema_decay": 0.99
101
+ }
102
+ },
103
+ "datasets": [
104
+ {"folder_path": dataset_folder}
105
+ ],
106
+ "save": {
107
+ "dtype": "float16",
108
+ "save_every": 10000,
109
+ "push_to_hub": push_to_hub,
110
+ "hf_repo_id": f"{username}/{slugged_lora_name}",
111
+ "hf_private": True,
112
+ "max_step_saves_to_keep": 4
113
+ },
114
+ "sample": {
115
+ "guidance_scale": 3.5,
116
+ "sample_every": steps,
117
+ "sample_steps": 28,
118
+ "width": 1024,
119
+ "height": 1024,
120
+ "walk_seed": True,
121
+ "seed": 42,
122
+ "sampler": "flowmatch",
123
+ "prompts": [p for p in [sample_1, sample_2, sample_3] if p]
124
+ },
125
+ "trigger_word": concept_sentence
126
+ }
127
+ ]
128
+ }
129
+ }
130
+
131
+ # Apply advanced YAML overrides if any
132
+ if use_more_advanced_options and more_advanced_options:
133
+ advanced_config = yaml.safe_load(more_advanced_options)
134
+ config["config"]["process"][0] = recursive_update(config["config"]["process"][0], advanced_config)
135
+
136
+ # Save YAML config
137
+ os.makedirs("tmp_configs", exist_ok=True)
138
+ config_path = f"tmp_configs/{uuid.uuid4()}_{slugged_lora_name}.yaml"
139
+ with open(config_path, "w") as f:
140
+ yaml.dump(config, f)
141
+
142
+ # Simulate training
143
+ print(f"[INFO] Starting training with config: {config_path}")
144
+ print(json.dumps(config, indent=2))
145
+ return f"Training started successfully with config: {config_path}"
146
+
147
+ # ========== MAIN ENDPOINT ==========
148
+ @app.post("/train-from-hf")
149
  def auto_run_lora_from_repo():
150
+ try:
151
+ local_dir = Path(f"/tmp/{LORA_NAME}-{uuid.uuid4()}")
152
+ os.makedirs(local_dir, exist_ok=True)
153
+
154
+ hf_hub_download(
155
+ repo_id=REPO_ID,
156
+ repo_type="dataset",
157
+ subfolder=FOLDER_IN_REPO,
158
+ local_dir=local_dir,
159
+ local_dir_use_symlinks=False,
160
+ force_download=False,
161
+ etag_timeout=10,
162
+ allow_patterns=["*.jpg", "*.png", "*.jpeg"],
163
+ )
164
+
165
+ image_dir = local_dir / FOLDER_IN_REPO
166
+ image_paths = list(image_dir.rglob("*.jpg")) + list(image_dir.rglob("*.jpeg")) + list(image_dir.rglob("*.png"))
167
+
168
+ if not image_paths:
169
+ return JSONResponse(status_code=400, content={"error": "No images found in the HF repo folder."})
170
+
171
+ captions = [
172
+ f"Autogenerated caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]" for img in image_paths
173
+ ]
174
+
175
+ dataset_path = create_dataset(image_paths, *captions)
176
+
177
+ result = start_training(
178
+ lora_name=LORA_NAME,
179
+ concept_sentence=CONCEPT_SENTENCE,
180
+ steps=1000,
181
+ lr=4e-4,
182
+ rank=16,
183
+ model_to_train="dev",
184
+ low_vram=True,
185
+ dataset_folder=dataset_path,
186
+ sample_1=f"A stylized portrait using {CONCEPT_SENTENCE}",
187
+ sample_2=f"A cat in the {CONCEPT_SENTENCE}",
188
+ sample_3=f"A selfie processed in {CONCEPT_SENTENCE}",
189
+ use_more_advanced_options=True,
190
+ more_advanced_options="""
 
191
  training:
192
  seed: 42
193
  precision: bf16
 
196
  flip: true
197
  color_jitter: true
198
  """
199
+ )
200
+
201
+ return {"message": result}
202
 
203
+ except Exception as e:
204
+ return JSONResponse(status_code=500, content={"error": str(e)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ # ========== FASTAPI RUNNER ==========
207
  if __name__ == "__main__":
208
+ import uvicorn
209
+ uvicorn.run(app, host="0.0.0.0", port=8000)