import gradio as gr from gradio_imageslider import ImageSlider from PIL import Image import numpy as np from aura_sr import AuraSR import torch import os import time import platform import argparse # Global variable to control batch processing cancellation. stop_batch_flag = False def open_folder(): open_folder_path = os.path.abspath("outputs") if platform.system() == "Windows": os.startfile(open_folder_path) elif platform.system() == "Linux": os.system(f'xdg-open "{open_folder_path}"') def get_placeholder_image(): """ Creates a placeholder image (if not already present) and returns its file path. This placeholder is a blank (white) image that will be used for progress updates. """ placeholder_path = "placeholder.png" if not os.path.exists(placeholder_path): placeholder = Image.new("RGB", (256, 256), (255, 255, 255)) placeholder.save(placeholder_path) return placeholder_path # Force CPU usage torch.set_default_tensor_type(torch.FloatTensor) # Override torch.load to always use CPU original_load = torch.load torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, map_location=torch.device('cpu')) # Initialize the AuraSR model aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2") # Restore original torch.load torch.load = original_load def process_single_image(input_image_path, reduce_seams): if input_image_path is None: raise gr.Error("Please provide an image to upscale.") # Send an initial progress update. # Instead of (None, None), we use the placeholder image file paths. placeholder = get_placeholder_image() yield [(placeholder, placeholder), "Starting upscaling..."] # Load the image. pil_image = Image.open(input_image_path) # Upscale using the chosen method. start_time = time.time() if reduce_seams: print("using reduce seams") upscaled_image = aura_sr.upscale_4x_overlapped(pil_image) else: upscaled_image = aura_sr.upscale_4x(pil_image) processing_time = time.time() - start_time print(f"Processing time: {processing_time:.2f} seconds") # Save the upscaled image. output_folder = "outputs" os.makedirs(output_folder, exist_ok=True) input_filename = os.path.basename(input_image_path) output_filename = os.path.splitext(input_filename)[0] output_path = os.path.join(output_folder, output_filename + ".png") counter = 1 while os.path.exists(output_path): output_path = os.path.join(output_folder, f"{output_filename}_{counter:04d}.png") counter += 1 upscaled_image.save(output_path) # Send the final progress update along with the before/after slider images. yield [(input_image_path, output_path), f"Upscaling complete in {processing_time:.2f} seconds"] def process_batch(input_folder, output_folder=None, reduce_seams=False): global stop_batch_flag # Reset the stop flag for each new batch process. stop_batch_flag = False if not input_folder: raise gr.Error("Please provide an input folder path.") if not output_folder: output_folder = "outputs" os.makedirs(output_folder, exist_ok=True) input_files = [f for f in os.listdir(input_folder) if f.lower().endswith( ('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))] total_files = len(input_files) processed_files = 0 results = [] yield [results, "Starting batch processing..."] for filename in input_files: # Check if the stop flag has been set. if stop_batch_flag: yield [results, "Batch processing cancelled by user."] return input_path = os.path.join(input_folder, filename) pil_image = Image.open(input_path) start_time = time.time() if reduce_seams: upscaled_image = aura_sr.upscale_4x_overlapped(pil_image) else: upscaled_image = aura_sr.upscale_4x(pil_image) processing_time = time.time() - start_time output_filename = os.path.splitext(filename)[0] + ".png" output_path = os.path.join(output_folder, output_filename) counter = 1 while os.path.exists(output_path): output_path = os.path.join(output_folder, f"{os.path.splitext(filename)[0]}_{counter:04d}.png") counter += 1 upscaled_image.save(output_path) processed_files += 1 results.append(output_path) yield [results, f"Processed {processed_files}/{total_files}: {filename} in {processing_time:.2f} seconds"] yield [results, f"Batch processing complete. {processed_files} images processed."] def stop_batch_process(): global stop_batch_flag stop_batch_flag = True return "Stop button clicked. Cancelling batch processing..." title = """