# Spaces: Running
import json
import os
import shutil
import tempfile
import uuid
from pathlib import Path

import torch
import yaml
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download, snapshot_download, whoami
from PIL import Image
# ========== CONFIGURATION ==========
# Hub repository the training images are fetched from (requested with
# repo_type="dataset" by auto_run_lora_from_repo).
REPO_ID = "rahul7star/ohamlab"

# Folder inside REPO_ID that is searched for images.
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"

# Phrase embedded in the generated captions and sample prompts; also passed to
# start_training as the concept sentence.
CONCEPT_SENTENCE = "ohamlab style"

# Name given to the LoRA run.
LORA_NAME = "ohami_filter_autorun"

# ========== FASTAPI APP ==========
# Application object handed to uvicorn when this file is run as a script.
app = FastAPI()
# ========== HELPERS ==========
def create_dataset(images, *captions):
    """Copy images into a fresh dataset folder and write their captions.

    A new folder named ``datasets_<uuid4>`` is created in the current working
    directory. Every image is copied into it and one JSON line of the form
    ``{"file_name": ..., "prompt": ...}`` is appended to ``metadata.jsonl``
    for each image, in input order.

    Args:
        images: Sequence of image paths (``str`` or path-like).
        *captions: One caption per image, in the same order. Surplus captions
            are ignored.

    Returns:
        The relative path of the created dataset folder.

    Raises:
        ValueError: If fewer captions than images are supplied. The check
            runs before anything is written, so no partial dataset is left
            behind.
    """
    images = list(images)
    if len(captions) < len(images):
        raise ValueError(
            f"Expected at least {len(images)} captions, got {len(captions)}"
        )

    destination_folder = f"datasets_{uuid.uuid4()}"
    os.makedirs(destination_folder, exist_ok=True)
    jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")

    with open(jsonl_file_path, "w", encoding="utf-8") as jsonl_file:
        for image, caption in zip(images, captions):
            source = str(image)
            stem, ext = os.path.splitext(os.path.basename(source))

            # Images gathered from different sub-folders may share a basename;
            # pick a free name instead of silently overwriting an earlier copy.
            target = os.path.join(destination_folder, f"{stem}{ext}")
            suffix = 1
            while os.path.exists(target):
                target = os.path.join(destination_folder, f"{stem}_{suffix}{ext}")
                suffix += 1

            shutil.copy(source, target)
            data = {"file_name": os.path.basename(target), "prompt": caption}
            jsonl_file.write(json.dumps(data) + "\n")

    return destination_folder
def recursive_update(d, u):
    """Deep-merge ``u`` into ``d`` in place and return ``d``.

    Non-empty dict values in ``u`` are merged recursively into the matching
    entry of ``d``; every other value (including an empty dict) replaces the
    entry outright.

    Args:
        d: Dict to update; it is mutated.
        u: Dict of overrides. ``None`` or an empty dict leaves ``d`` untouched.

    Returns:
        The same ``d`` object, after merging.
    """
    if not u:
        # Nothing to merge (e.g. an empty YAML document loads as None).
        return d

    for k, v in u.items():
        if isinstance(v, dict) and v:
            existing = d.get(k)
            if not isinstance(existing, dict):
                # A scalar/None/list cannot be merged into; the nested
                # override replaces it instead of raising.
                existing = {}
            d[k] = recursive_update(existing, v)
        else:
            d[k] = v
    return d
def start_training(
    lora_name,
    concept_sentence,
    steps,
    lr,
    rank,
    model_to_train,
    low_vram,
    dataset_folder,
    sample_1,
    sample_2,
    sample_3,
    use_more_advanced_options,
    more_advanced_options,
):
    """Build the LoRA training config, write it to disk as YAML and report it.

    No trainer is invoked here: the config is written to
    ``tmp_configs/<uuid4>_<slug>.yaml``, echoed to stdout, and a status
    string is returned.

    Args:
        lora_name: Display name; lower-cased with spaces replaced by
            underscores to form the config name and Hub repo name.
        concept_sentence: Stored as the ``trigger_word`` of the process.
        steps: Number of training steps; also used as ``sample_every``.
        lr: Learning rate.
        rank: Used for both ``linear`` and ``linear_alpha`` of the network.
        model_to_train: Accepted for interface compatibility but not used;
            the base model is always ``black-forest-labs/FLUX.1-dev``.
        low_vram: Stored as ``model.low_vram``.
        dataset_folder: Folder path placed in the ``datasets`` list.
        sample_1, sample_2, sample_3: Sample prompts; falsy ones are dropped.
        use_more_advanced_options: Whether to apply ``more_advanced_options``.
        more_advanced_options: YAML text whose mapping is deep-merged into
            the first process entry.

    Returns:
        A message containing the path of the written config file.

    Raises:
        ValueError: If ``more_advanced_options`` parses to something other
            than a mapping.
        yaml.YAMLError: If ``more_advanced_options`` is not valid YAML.
    """
    # Pushing to the Hub needs an authenticated user; fall back to a local-only
    # run when the identity lookup fails for any ordinary reason.
    try:
        user = whoami()
        username = user.get("name", "anonymous")
        push_to_hub = True
    except Exception as exc:  # not a bare except: let Ctrl-C / SystemExit through
        print(f"[WARN] Hub identity unavailable ({exc}); push_to_hub disabled")
        username = "anonymous"
        push_to_hub = False

    slugged_lora_name = lora_name.replace(" ", "_").lower()

    # Base config
    config = {
        "config": {
            "name": slugged_lora_name,
            "process": [
                {
                    "model": {
                        "low_vram": low_vram,
                        "is_flux": True,
                        "quantize": True,
                        "name_or_path": "black-forest-labs/FLUX.1-dev",
                    },
                    "network": {
                        "linear": rank,
                        "linear_alpha": rank,
                        "type": "lora",
                    },
                    "train": {
                        "steps": steps,
                        "lr": lr,
                        "skip_first_sample": True,
                        "batch_size": 1,
                        "dtype": "bf16",
                        "gradient_accumulation_steps": 1,
                        "gradient_checkpointing": True,
                        "noise_scheduler": "flowmatch",
                        "optimizer": "adamw8bit",
                        "ema_config": {
                            "use_ema": True,
                            "ema_decay": 0.99,
                        },
                    },
                    "datasets": [
                        {"folder_path": dataset_folder},
                    ],
                    "save": {
                        "dtype": "float16",
                        "save_every": 10000,
                        "push_to_hub": push_to_hub,
                        "hf_repo_id": f"{username}/{slugged_lora_name}",
                        "hf_private": True,
                        "max_step_saves_to_keep": 4,
                    },
                    "sample": {
                        "guidance_scale": 3.5,
                        "sample_every": steps,
                        "sample_steps": 28,
                        "width": 1024,
                        "height": 1024,
                        "walk_seed": True,
                        "seed": 42,
                        "sampler": "flowmatch",
                        "prompts": [p for p in [sample_1, sample_2, sample_3] if p],
                    },
                    "trigger_word": concept_sentence,
                }
            ],
        }
    }

    # Apply advanced YAML overrides if any
    if use_more_advanced_options and more_advanced_options:
        advanced_config = yaml.safe_load(more_advanced_options)
        # Comment-only / blank YAML loads as None: nothing to merge.
        if advanced_config is not None:
            if not isinstance(advanced_config, dict):
                raise ValueError(
                    "more_advanced_options must be a YAML mapping, got "
                    f"{type(advanced_config).__name__}"
                )
            process = config["config"]["process"]
            process[0] = recursive_update(process[0], advanced_config)

    # Save YAML config
    os.makedirs("tmp_configs", exist_ok=True)
    config_path = f"tmp_configs/{uuid.uuid4()}_{slugged_lora_name}.yaml"
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    # Simulate training
    print(f"[INFO] Starting training with config: {config_path}")
    # default=str keeps this debug echo from raising on YAML-native values
    # (e.g. dates) that JSON cannot encode; the YAML file above is unaffected.
    print(json.dumps(config, indent=2, default=str))
    return f"Training started successfully with config: {config_path}"
# ========== MAIN ENDPOINT ==========
def auto_run_lora_from_repo():
    """Download the configured image folder, build a dataset and start a run.

    Images (``.jpg``, ``.jpeg``, ``.png``) under ``FOLDER_IN_REPO`` of the
    dataset repo ``REPO_ID`` are downloaded to a temporary directory, given
    auto-generated captions, copied into a dataset folder via
    ``create_dataset`` and handed to ``start_training``.

    Returns:
        ``{"message": <start_training result>}`` on success; a
        ``JSONResponse`` with status 400 when no images are found, or with
        status 500 carrying the error text when any step raises.
    """
    # NOTE(review): no route decorator is visible on this function, so it is
    # not registered on `app` in the code shown — confirm the intended
    # path/method and register it.
    image_suffixes = (".jpg", ".jpeg", ".png")
    try:
        # The download is only a staging area: create_dataset copies the files
        # out, so the temporary directory is removed as soon as that is done.
        with tempfile.TemporaryDirectory(prefix=f"{LORA_NAME}-") as staging:
            local_dir = Path(staging)
            # snapshot_download (not hf_hub_download, which fetches one named
            # file) is the call that supports pattern-filtered folder downloads.
            snapshot_download(
                repo_id=REPO_ID,
                repo_type="dataset",
                local_dir=local_dir,
                force_download=False,
                etag_timeout=10,
                allow_patterns=[
                    f"{FOLDER_IN_REPO}/*{suffix}" for suffix in image_suffixes
                ],
            )

            image_dir = local_dir / FOLDER_IN_REPO
            # Sorted so caption/file pairing is reproducible across runs.
            image_paths = sorted(
                path
                for suffix in image_suffixes
                for path in image_dir.rglob(f"*{suffix}")
            )
            if not image_paths:
                return JSONResponse(
                    status_code=400,
                    content={"error": "No images found in the HF repo folder."},
                )

            captions = [
                f"Autogenerated caption for {img.stem} in the {CONCEPT_SENTENCE} [trigger]"
                for img in image_paths
            ]
            dataset_path = create_dataset(image_paths, *captions)

        # NOTE(review): the override below adds `training` / `augmentation`
        # keys next to the base config's `train` section rather than
        # overriding it — confirm that is intended.
        result = start_training(
            lora_name=LORA_NAME,
            concept_sentence=CONCEPT_SENTENCE,
            steps=1000,
            lr=4e-4,
            rank=16,
            model_to_train="dev",
            low_vram=True,
            dataset_folder=dataset_path,
            sample_1=f"A stylized portrait using {CONCEPT_SENTENCE}",
            sample_2=f"A cat in the {CONCEPT_SENTENCE}",
            sample_3=f"A selfie processed in {CONCEPT_SENTENCE}",
            use_more_advanced_options=True,
            more_advanced_options="""
            training:
              seed: 42
              precision: bf16
              batch_size: 2
            augmentation:
              flip: true
              color_jitter: true
            """,
        )
        return {"message": result}
    except Exception as e:
        # Endpoint boundary: report the failure to the client instead of raising.
        return JSONResponse(status_code=500, content={"error": str(e)})
# ========== FASTAPI RUNNER ==========
# Script entry point: serve `app` with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    # Imported here so that uvicorn is only needed when run as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)