|
import json
import logging
import os
import shutil
import tempfile
from datetime import datetime

import gradio as gr
import numpy as np
from gradio_client import Client, handle_file
from PIL import Image
|
|
|
logging.basicConfig(level=logging.INFO)

# Base URL of the remote Gradio app that performs image-to-video inference.
# Can be overridden with the VIDEO_API_URL environment variable; the default is
# the endpoint this app was originally written against.
API_URL = os.environ.get("VIDEO_API_URL", "http://211.233.58.202:7960/")

# NOTE(review): the client is created at import time; gradio_client contacts the
# server when constructing a Client, so an unreachable endpoint presumably makes
# this module fail to import — confirm.
api_client = Client(API_URL)

# Directory holding copies of generated videos, and the JSON index describing them.
GALLERY_DIR = "gallery"
GALLERY_JSON = "gallery.json"

os.makedirs(GALLERY_DIR, exist_ok=True)
|
|
|
def save_to_gallery(video_path, prompt):
    """Copy a generated video into the gallery and record it in the JSON index.

    Args:
        video_path: Path of the video file to copy (the file produced by the
            inference call).
        prompt: Prompt text that produced the video; stored alongside it.

    Returns:
        Path of the copy inside ``GALLERY_DIR``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # The timestamp only has one-second resolution, so two videos finishing in
    # the same second would overwrite each other; append a counter on clashes.
    new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}.mp4")
    counter = 1
    while os.path.exists(new_video_path):
        new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}_{counter}.mp4")
        counter += 1

    # Stream the copy instead of loading the whole video into memory.
    shutil.copyfile(video_path, new_video_path)

    gallery_info = {
        "video": new_video_path,
        "prompt": prompt,
        "timestamp": timestamp,
    }

    if os.path.exists(GALLERY_JSON):
        with open(GALLERY_JSON, "r", encoding="utf-8") as f:
            gallery = json.load(f)
    else:
        gallery = []

    gallery.append(gallery_info)

    # Write the new index to a temp file in the same directory and atomically
    # swap it in, so a crash mid-write cannot leave a truncated gallery.json.
    index_dir = os.path.dirname(os.path.abspath(GALLERY_JSON))
    fd, tmp_path = tempfile.mkstemp(dir=index_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(gallery, f, indent=2)
        os.replace(tmp_path, GALLERY_JSON)
    except BaseException:
        os.unlink(tmp_path)
        raise

    return new_video_path
|
|
|
def load_gallery():
    """Return gallery entries as ``(video_path, prompt)`` pairs, newest first.

    Reads the JSON index at ``GALLERY_JSON``. Entries whose video file no
    longer exists are skipped, and an unreadable or corrupt index is logged
    and treated as empty so the Gallery tab still renders.
    """
    if not os.path.exists(GALLERY_JSON):
        return []

    try:
        with open(GALLERY_JSON, "r", encoding="utf-8") as f:
            gallery = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        logging.error("Could not read gallery index %s: %s", GALLERY_JSON, e)
        return []

    # Entries are appended in creation order, so reversing yields newest first.
    return [
        (item["video"], item["prompt"])
        for item in reversed(gallery)
        if os.path.exists(item["video"])
    ]
|
|
|
|
|
def _image_to_temp_png(image):
    """Write an uploaded image to a new temporary PNG file and return its path.

    The caller is responsible for deleting the returned file.

    Args:
        image: A numpy array (the default ``gr.Image`` value), a PIL image,
            or already-encoded image bytes (written through unchanged).
    """
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        with temp_file:
            if isinstance(image, np.ndarray):
                # convert("RGB") also normalises grayscale / RGBA arrays, which
                # a hard-coded 'RGB' mode in fromarray would mis-handle.
                pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
                pil_image.save(temp_file, format="PNG")
            elif isinstance(image, Image.Image):
                image.convert("RGB").save(temp_file, format="PNG")
            else:
                temp_file.write(image)
    except BaseException:
        os.unlink(temp_file.name)
        raise
    return temp_file.name


def _extract_video_path(result):
    """Return the ``.mp4`` path contained in an ``/infer`` API result.

    Accepts a bare path string or a dict with a ``"video"`` key (presumably
    the shape newer gradio_client versions return for Video outputs —
    verify against the server).

    Raises:
        ValueError: If *result* does not contain an ``.mp4`` path.
    """
    if isinstance(result, dict):
        result = result.get("video")
    if isinstance(result, str) and result.endswith(".mp4"):
        return result
    raise ValueError("Unexpected API response format")


def respond(image, prompt, steps, cfg_scale, eta, fs, seed, video_length):
    """Generate a video from an image and prompt via the remote ``/infer`` API.

    The input image is written to a temporary PNG, uploaded with
    ``handle_file``, and the resulting video is copied into the gallery.

    Returns:
        Path of the video's copy in the gallery (shown by the ``gr.Video`` output).

    Raises:
        gr.Error: If no image was given or generation fails; Gradio shows the
            message to the user instead of treating an error string as a file path.
    """
    logging.info(
        "Received prompt: %s, steps: %s, cfg_scale: %s, eta: %s, fs: %s, seed: %s, video_length: %s",
        prompt, steps, cfg_scale, eta, fs, seed, video_length,
    )

    if image is None:
        raise gr.Error("Please upload an image before generating a video.")

    temp_file_path = None
    try:
        temp_file_path = _image_to_temp_png(image)

        result = api_client.predict(
            image=handle_file(temp_file_path),
            prompt=prompt,
            steps=steps,
            cfg_scale=cfg_scale,
            eta=eta,
            fs=fs,
            seed=seed,
            video_length=video_length,
            api_name="/infer",
        )
        logging.info("API response received: %s", result)

        return save_to_gallery(_extract_video_path(result), prompt)
    except Exception as e:
        logging.exception("Error during API request: %s", e)
        raise gr.Error("Failed to generate video due to an error.") from e
    finally:
        # Always remove the temp upload, including when predict() raises.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
|
|
|
# Custom stylesheet handed to gr.Blocks; it hides Gradio's page footer.
css = "\n".join(
    [
        "",
        "footer {",
        "    visibility: hidden;",
        "}",
        "",
    ]
)
|
|
|
|
|
|
|
# Example (prompt, image file) pairs shown on the Generate tab; each gets a
# "Use this prompt" button. Image paths are relative to the working directory.
examples = [
    [
        "A glamorous young woman with long, wavy blonde hair and smokey eye makeup, "
        "posing in a luxury hotel room. She's wearing a sparkly gold cocktail dress "
        "and holding up a white card with 'openfree.ai' written on it in elegant "
        "calligraphy. Soft, warm lighting creates a luxurious atmosphere. ",
        "q1.webp",
    ],
    [
        "A fantasy map of a fictional world, with detailed terrain and cities.",
        "q19.webp",
    ],
]
|
|
|
def use_prompt(prompt):
    """Hand back the given prompt untouched (identity helper for the example buttons)."""
    selected = prompt
    return selected
|
|
|
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Tab("Generate"):
        with gr.Row():
            input_image = gr.Image(label="Upload an image")
            input_text = gr.Textbox(label="Enter your prompt for video generation")
            output_video = gr.Video(label="Generated Video")

        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=100, step=1, label="Steps", value=30)
            cfg_scale = gr.Slider(minimum=1, maximum=15, step=0.1, label="CFG Scale", value=3.5)
            eta = gr.Slider(minimum=0, maximum=1, step=0.1, label="ETA", value=1)
            fs = gr.Slider(minimum=1, maximum=30, step=1, label="FPS", value=8)
            seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="Seed", value=123)
            video_length = gr.Slider(minimum=1, maximum=10, step=1, label="Video Length (seconds)", value=2)

        # Visible trigger for generation; pressing Enter in the prompt box still works too.
        generate_btn = gr.Button("Generate")

        with gr.Row():
            for prompt, image_file in examples:
                with gr.Column():
                    example_label = prompt if len(prompt) <= 50 else prompt[:50] + "..."
                    gr.Image(image_file, label=example_label)
                    # The handler receives no inputs, so the prompt is bound as a
                    # default argument (this also avoids late binding of the loop
                    # variable, which would give every button the last prompt).
                    gr.Button("Use this prompt").click(
                        fn=lambda text=prompt: use_prompt(text),
                        inputs=None,
                        outputs=input_text,
                        api_name=False,
                    )

    with gr.Tab("Gallery"):
        gallery = gr.Gallery(
            label="Generated Videos",
            show_label=False,
            elem_id="gallery",
            columns=[5],
            rows=[3],
            object_fit="contain",
            height="auto",
        )
        refresh_btn = gr.Button("Refresh Gallery")

    def update_gallery():
        """Reload the (video, prompt) gallery entries, newest first."""
        return load_gallery()

    refresh_btn.click(fn=update_gallery, inputs=None, outputs=gallery)
    demo.load(fn=update_gallery, inputs=None, outputs=gallery)

    generate_inputs = [input_image, input_text, steps, cfg_scale, eta, fs, seed, video_length]

    input_text.submit(
        fn=respond,
        inputs=generate_inputs,
        outputs=output_video,
    ).then(
        fn=update_gallery,
        inputs=None,
        outputs=gallery,
    )

    # api_name=False keeps this second trigger from registering a duplicate
    # public API endpoint next to the one created by the submit event.
    generate_btn.click(
        fn=respond,
        inputs=generate_inputs,
        outputs=output_video,
        api_name=False,
    ).then(
        fn=update_gallery,
        inputs=None,
        outputs=gallery,
    )
|
|
|
def _main():
    """Start the Gradio server for ``demo`` with its default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()
|
|
|
|
|
|
|
|