# Page-header residue from the hosting site's file view:
# author LPX55, commit cb15469 ("Update app.py", verified), file size 3.96 kB.
import atexit
import os
import shutil
import tempfile
import zipfile
from io import BytesIO

import cv2
import gradio as gr
import imageio
import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image, ImageOps
# Directory holding the files generated for download (zip archives and GIFs).
# It is NOT persistent: cleanup_generated_files() deletes it when the process exits.
GENERATED_FILES_DIR = "generated_files"
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(GENERATED_FILES_DIR, exist_ok=True)


def cleanup_generated_files():
    """Delete GENERATED_FILES_DIR and everything inside it.

    Safe to call more than once: a missing directory (or a file that cannot
    be removed) is ignored, since this runs as best-effort shutdown cleanup.
    """
    shutil.rmtree(GENERATED_FILES_DIR, ignore_errors=True)


# Remove the generated files when the interpreter exits normally.
atexit.register(cleanup_generated_files)
def split_image_grid(image, grid_cols, grid_rows):
    """Split an image into a grid of equally sized cells.

    Cells are returned row by row: left to right, then top to bottom. If the
    image size is not divisible by the grid, the leftover pixels on the right
    and bottom edges are dropped.

    Args:
        image: A PIL image, or a NumPy array that is converted with
            ``Image.fromarray``.
        grid_cols: Number of columns. It is coerced with ``int()``, so float
            values (e.g. from a slider) are accepted.
        grid_rows: Number of rows, coerced the same way.

    Returns:
        A list of ``grid_cols * grid_rows`` NumPy arrays, one per cell.

    Raises:
        ValueError: If no image is given, if the grid has fewer than one row
            or column, or if the grid is finer than the image's pixel size.
    """
    if image is None:
        raise ValueError("No input image provided.")
    grid_cols = int(grid_cols)
    grid_rows = int(grid_rows)
    if grid_cols < 1 or grid_rows < 1:
        raise ValueError("Grid columns and rows must both be at least 1.")

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    width, height = image.size
    cell_width = width // grid_cols
    cell_height = height // grid_rows
    # Without this check, a grid finer than the image produces zero-size crops.
    if cell_width == 0 or cell_height == 0:
        raise ValueError(
            f"A {width}x{height} image is too small for a "
            f"{grid_cols}x{grid_rows} grid."
        )

    return [
        np.array(
            image.crop(
                (
                    col * cell_width,
                    row * cell_height,
                    (col + 1) * cell_width,
                    (row + 1) * cell_height,
                )
            )
        )
        for row in range(grid_rows)
        for col in range(grid_cols)
    ]
def interpolate_frames(frames, factor=2):
    """Insert linearly blended frames between each pair of consecutive frames.

    For every adjacent pair ``(a, b)``, the output contains ``a`` followed by
    ``factor - 1`` blends ``a * (1 - t) + b * t`` for
    ``t = 1/factor, 2/factor, ...``. The last input frame is appended at the
    end. A ``factor`` of 1 or less inserts nothing.

    Args:
        frames: A sequence of frames. Frames are assumed to share shape and
            dtype so they can be blended with ``cv2.addWeighted``.
        factor: Number of output steps per input interval.

    Returns:
        A new list of frames. It is empty for empty input; otherwise it has
        ``(len(frames) - 1) * factor + 1`` entries when ``factor >= 1``.
    """
    # Guard first: the final frames[-1] would raise IndexError on empty input.
    if len(frames) == 0:
        return []

    interpolated_frames = []
    for frame1, frame2 in zip(frames, frames[1:]):
        interpolated_frames.append(frame1)
        for j in range(1, factor):
            t = j / factor
            interpolated_frames.append(cv2.addWeighted(frame1, 1 - t, frame2, t, 0))
    interpolated_frames.append(frames[-1])
    return interpolated_frames
def enhance_gif(images):
    """Apply ``ImageOps.autocontrast`` to every frame.

    Each frame is converted to RGB before auto-contrast is applied. Converting
    afterwards, as the code used to do, lets frames in other modes reach
    ``autocontrast`` directly. NOTE(review): in the Pillow versions checked,
    ``autocontrast`` raises for modes such as RGBA or P, so the conversion
    must come first. For L or RGB input the result is the same as before.

    Args:
        images: An iterable of frames accepted by ``Image.fromarray``.

    Returns:
        A list of RGB NumPy arrays, one per input frame.
    """
    return [
        np.array(ImageOps.autocontrast(Image.fromarray(img).convert("RGB")))
        for img in images
    ]
def create_gif_imageio(images, duration=50, loop=0):
    """Write frames to an animated GIF under GENERATED_FILES_DIR and return its path.

    Each call writes into a fresh, uniquely named subdirectory, so concurrent
    requests cannot overwrite each other's output. The file itself keeps the
    name ``output_enhanced.gif``.

    Args:
        images: A non-empty sequence of frame arrays.
        duration: Per-frame duration, forwarded to ``imageio.mimsave``.
            NOTE(review): whether this is seconds or milliseconds depends on
            the installed imageio version / GIF plugin; confirm.
        loop: Forwarded to ``imageio.mimsave``. Presumably 0 means loop
            forever (the GIF convention); confirm.

    Returns:
        The path (str) of the written GIF.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("Cannot create a GIF from zero frames.")
    # Recreate the base directory in case it was removed after import.
    os.makedirs(GENERATED_FILES_DIR, exist_ok=True)
    out_dir = tempfile.mkdtemp(dir=GENERATED_FILES_DIR)
    gif_path = os.path.join(out_dir, "output_enhanced.gif")
    # imageio accepts the arrays directly; the old PIL round-trip added nothing.
    imageio.mimsave(gif_path, images, duration=duration, loop=loop)
    return gif_path
def process_image(image, grid_cols_input, grid_rows_input):
    """Cut *image* into a ``grid_cols_input`` x ``grid_rows_input`` grid and
    return the path of a zip archive containing every cell as a PNG."""
    return zip_images(split_image_grid(image, grid_cols_input, grid_rows_input))
def process_image_to_gif(image, grid_cols_input, grid_rows_input):
    """Turn a sprite-sheet style grid image into an animated GIF.

    The grid cells become frames, in row-major order. One blended frame is
    inserted between neighbours (factor 2), each frame is auto-contrasted,
    and the result is written as a looping GIF. Returns the GIF's path.
    """
    cells = split_image_grid(image, grid_cols_input, grid_rows_input)
    smoothed = interpolate_frames(cells, factor=2)
    return create_gif_imageio(enhance_gif(smoothed), duration=50, loop=0)
def zip_images(images):
    """Encode each frame as PNG and pack all of them into a zip archive.

    Entries are named ``image_0.png``, ``image_1.png``, ... in input order.
    Each call writes into a fresh, uniquely named subdirectory of
    GENERATED_FILES_DIR, so concurrent requests cannot overwrite each other's
    archive. The file itself keeps the name ``output.zip``.

    Args:
        images: An iterable of frames accepted by ``Image.fromarray``.

    Returns:
        The path (str) of the written zip file.
    """
    # Recreate the base directory in case it was removed after import.
    os.makedirs(GENERATED_FILES_DIR, exist_ok=True)
    out_dir = tempfile.mkdtemp(dir=GENERATED_FILES_DIR)
    zip_path = os.path.join(out_dir, "output.zip")
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for idx, img in enumerate(images):
            png_buffer = BytesIO()
            Image.fromarray(img).save(png_buffer, format='PNG')
            # getvalue() returns the whole buffer regardless of position; no seek needed.
            zipf.writestr(f'image_{idx}.png', png_buffer.getvalue())
    return zip_path
# UI layout: an image plus grid-size sliders, two action buttons, two download slots.
with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        grid_cols_input = gr.Slider(1, 10, value=2, step=1, label="Grid Columns")
        grid_rows_input = gr.Slider(1, 10, value=2, step=1, label="Grid Rows")
    with gr.Row():
        zip_button = gr.Button("Create Zip File")
        gif_button = gr.Button("Create GIF")
    with gr.Row():
        zip_output = gr.File(label="Download Zip File")
        gif_output = gr.File(label="Download GIF")

    # Both actions read the same three inputs.
    grid_inputs = [image_input, grid_cols_input, grid_rows_input]
    zip_button.click(process_image, inputs=grid_inputs, outputs=zip_output)
    gif_button.click(process_image_to_gif, inputs=grid_inputs, outputs=gif_output)

# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module to reuse the helpers does not launch the UI.
if __name__ == "__main__":
    demo.launch(show_error=True)