Spaces:
Sleeping
Sleeping
import argparse
import glob
import os
import re
import shutil
import threading
import time

import gradio as gr
import torch

from main import setup, execute_task
from arguments import parse_args
def _natural_sort_key(path):
    """Sort key that compares digit runs numerically, so "2.png" < "10.png"."""
    # re.split with a capturing group alternates text (even index) and digit
    # runs (odd index), so same-position elements always have the same type.
    parts = re.split(r'(\d+)', os.path.basename(path))
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def list_iter_images(save_dir):
    """Return the paths of the image files directly inside ``save_dir``.

    Files are matched by extension (jpg, jpeg, png, gif, bmp) and grouped in
    that order. Within each extension they are sorted in natural order
    (``0.png, 1.png, 2.png, 10.png``), so per-iteration images appear in
    sequence.

    ``save_dir`` is glob-escaped before the pattern is built. The directory
    name can embed the user prompt, and characters such as ``[``, ``*`` or
    ``?`` would otherwise be read as wildcards and match nothing.

    Args:
        save_dir: Directory to search. The search is not recursive.

    Returns:
        list[str]: Matching file paths. The list is also printed.
    """
    image_extensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp']  # Add more if needed
    escaped_dir = glob.escape(save_dir)
    image_paths = []
    for ext in image_extensions:
        matches = glob.glob(os.path.join(escaped_dir, f'*.{ext}'))
        image_paths.extend(sorted(matches, key=_natural_sort_key))
    print(image_paths)
    return image_paths
def clean_dir(save_dir):
    """Empty ``save_dir`` in place while keeping the directory itself.

    Files and symbolic links are unlinked; sub-directories are removed with
    their contents. A failure on one entry is printed and the remaining
    entries are still processed. A status line is printed in every case, and
    a missing directory is not treated as an error.
    """
    if not os.path.exists(save_dir):
        print(f"{save_dir} does not exist.")
        return

    entries = os.listdir(save_dir)
    if not entries:
        print(f"{save_dir} exists but is empty.")
        return

    for entry in entries:
        target = os.path.join(save_dir, entry)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f"Failed to delete {target}. Reason: {err}")
    print(f"All files in {save_dir} have been deleted.")
def start_over(gallery_state):
    """Reset the UI before a new run.

    The incoming gallery state is discarded unconditionally. The parameter is
    kept only because the event wiring passes it in.

    Returns:
        tuple: ``(None, None, None, update)``. The three ``None`` values reset
        the gallery state, output image and status. ``update`` hides the
        iterations gallery.
    """
    hide_gallery = gr.update(visible=False)
    return None, None, None, hide_gallery
def setup_model(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Build the ReNO run configuration and load the model/trainer for it.

    Args:
        prompt: Text prompt to optimise an image for.
        model: Model key selected in the UI (e.g. ``"sd-turbo"``).
        num_iterations: Number of optimisation iterations. Coerced to ``int``
            because UI sliders may deliver numeric values as floats.
        learning_rate: Optimiser learning rate. Coerced to ``float``.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm bars emitted while ``setup`` runs. It is not
            referenced directly in this function.

    Returns:
        tuple: ``(None, loaded_setup)``. ``None`` clears the image output.
        ``loaded_setup`` is the list ``[args, trainer, device, dtype, shape,
        enable_grad, settings]`` produced by ``setup(args)``.
    """
    # Release cached, unused GPU memory (e.g. from a previous run) before loading.
    torch.cuda.empty_cache()
    # Start from the CLI defaults, then override with the UI choices.
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    args.n_iters = int(num_iterations)
    args.lr = float(learning_rate)
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    # Keep every iteration's image on disk so generate_image can stream them.
    args.save_all_images = True
    args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
    loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings]
    return None, loaded_setup
def _iteration_update(save_dir, step, num_iterations):
    """Return the ``(image, status, gallery)`` tuple streamed for a finished step."""
    # Steps appear to be reported from 1, while images are saved as
    # "{step - 1}.png", so the image index is one less than the step.
    image_path = os.path.join(save_dir, f"{step - 1}.png")
    if os.path.exists(image_path):
        return (image_path, f"Iteration {step}/{num_iterations} - Image saved", None)
    return (None, f"Iteration {step}/{num_iterations} - Image not found", None)


def generate_image(setup_args, num_iterations):
    """Run the ReNO task in a worker thread and stream progress to the UI.

    Args:
        setup_args: The 7-item list produced by ``setup_model``:
            ``[args, trainer, device, dtype, shape, enable_grad, settings]``.
        num_iterations: Iteration count. Used only in status messages.

    Yields:
        tuple: ``(image path or None, status message, gallery paths or None)``.

        - While the task runs: one tuple for each newly reported step.
        - On completion: a final tuple with ``best_image.png`` and the list of
          all saved images.
        - On failure: any exception, including one raised inside the worker
          thread, is reported as an ``"An error occurred: ..."`` status.

        The generator always terminates, even if the task finishes early or
        crashes before reaching ``num_iterations`` steps.
    """
    args = setup_args[0]
    trainer = setup_args[1]
    device = setup_args[2]
    dtype = setup_args[3]
    shape = setup_args[4]
    enable_grad = setup_args[5]
    settings = setup_args[6]
    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
    # Remove files from an earlier run with the same prompt/settings so they
    # are not reported as this run's output.
    clean_dir(save_dir)
    try:
        steps_completed = []
        result_container = {"best_image": None, "total_init_rewards": None, "total_best_rewards": None, "error": None}

        # Invoked by execute_task (in the worker thread) as steps complete.
        def progress_callback(step):
            steps_completed.append(step)

        def run_main():
            try:
                (
                    result_container["best_image"],
                    result_container["total_init_rewards"],
                    result_container["total_best_rewards"],
                ) = execute_task(args, trainer, device, dtype, shape, enable_grad, settings, progress_callback)
            except Exception as exc:
                # Exceptions do not propagate out of a Thread; keep it so the
                # generator can report it after join().
                result_container["error"] = exc

        main_thread = threading.Thread(target=run_main)
        main_thread.start()
        last_step_yielded = 0
        while True:
            # Sample liveness *before* reading the steps. Once the worker is
            # seen as finished, every step it reported is already recorded,
            # so this pass drains the last one before the loop exits.
            worker_alive = main_thread.is_alive()
            if steps_completed and steps_completed[-1] > last_step_yielded:
                last_step_yielded = steps_completed[-1]
                yield _iteration_update(save_dir, last_step_yielded, num_iterations)
            elif worker_alive:
                # Small sleep to prevent busy waiting.
                time.sleep(0.1)
            else:
                break
        main_thread.join()
        if result_container["error"] is not None:
            raise result_container["error"]
        # After the task completes, yield the final image and the full gallery.
        final_image_path = os.path.join(save_dir, "best_image.png")
        if os.path.exists(final_image_path):
            iter_images = list_iter_images(save_dir)
            yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
        else:
            yield (None, "Image generation completed, but no final image was found.", None)
    except Exception as e:
        yield (None, f"An error occurred: {str(e)}", None)
def show_gallery_output(gallery_state):
    """Return a gallery update that shows ``gallery_state`` if set, else hides it.

    When ``gallery_state`` is ``None`` the update carries ``value=None`` and
    ``visible=False``. Otherwise it carries the state and ``visible=True``.
    """
    has_images = gallery_state is not None
    return gr.update(value=gallery_state if has_images else None, visible=has_images)
# Create Gradio interface
# Markdown heading and one-line usage blurb shown at the top of the page.
title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
# Gradio lays components out in creation order within the nested Column/Row
# contexts, so the statement order below defines the page layout.
with gr.Blocks() as demo:
    # Values passed between the chained handlers wired at the bottom:
    # - loaded_model_setup: filled by setup_model, read by generate_image.
    # - gallery_state: filled by generate_image, read by show_gallery_output.
    loaded_model_setup = gr.State()
    gallery_state = gr.State()
    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        # Badge links to the GitHub repository and the arXiv paper.
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/ExplainableML/ReNO'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://arxiv.org/abs/2406.04312v1'>
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
</a>
</div>
""")
        with gr.Row():
            # Left column: prompt, model and optimisation controls.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model", value="sd-turbo")
                    # Hidden and not wired to any event in this file.
                    model_status = gr.Textbox(label="model status", visible=False)
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
                submit_btn = gr.Button("Submit")
                # Clicking an example fills the prompt textbox.
                gr.Examples(
                    examples = [
                        "A minimalist logo design of a reindeer, fully rendered. The reindeer features distinct, complete shapes using bold and flat colors. The design emphasizes simplicity and clarity, suitable for logo use with a sharp outline and white background.",
                        "A blue scooter is parked near a curb in front of a green vintage car",
                        "A impressionistic oil painting: a lone figure walking on a misty beach, a weathered lighthouse on a cliff, seagulls above crashing waves",
                        "A bird with 8 legs",
                        "An orange chair to the right of a black airplane",
                        "A pink elephant and a grey cow",
                    ],
                    inputs = [prompt]
                )
            # Right column: best image, status text and per-iteration gallery.
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                # Starts hidden. start_over hides it again and show_gallery_output
                # reveals it once a run has produced images.
                iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)
    # Submit pipeline. Each step starts after the previous one has finished:
    # 1. start_over: clear gallery state, image and status; hide the gallery.
    # 2. setup_model: build args and load the model; store the setup in state.
    # 3. generate_image: stream per-iteration images and status, then the best
    #    image and the list of all saved images (into gallery_state).
    # 4. show_gallery_output: show the gallery if images were produced.
    # NOTE(review): `.then` also fires when the previous step raised; `.success`
    # would stop the chain when setup_model fails.
    submit_btn.click(
        fn = start_over,
        inputs =[gallery_state],
        outputs = [gallery_state, output_image, status, iter_gallery]
    ).then(
        fn = setup_model,
        inputs = [prompt, chosen_model, n_iter, learning_rate],
        outputs = [output_image, loaded_model_setup]
    ).then(
        fn = generate_image,
        inputs = [loaded_model_setup, n_iter],
        outputs = [output_image, status, gallery_state]
    ).then(
        fn = show_gallery_output,
        inputs = [gallery_state],
        outputs = iter_gallery
    )
# Launch the app. `Blocks.queue()` returns the same Blocks object, so enabling
# request queuing (used for streamed generator updates and gr.Progress) and
# starting the server can be written as two separate statements.
demo.queue()
# show_error surfaces handler exceptions in the UI; show_api hides the API page link.
demo.launch(show_error=True, show_api=False)