# ofai-it2v2 / app.py
# Latest commit 4bd528b by ginipick: "Update app.py" (verified), 6.92 kB.
import gradio as gr
from gradio_client import Client
import os
import logging
import json
from datetime import datetime
import tempfile
import numpy as np
from PIL import Image
import shutil
import httpx
import time
import base64
from gradio_client import Client, handle_file
import cv2
from moviepy.editor import VideoFileClip
# Logging configuration.
logging.basicConfig(level=logging.INFO)

# Shared httpx client with the timeout raised to 30 seconds.
httpx_client = httpx.Client(timeout=30.0)
max_retries = 3
retry_delay = 5  # seconds to wait between connection attempts


def _connect_api_client():
    """Connect to the remote video-generation Gradio app, retrying on transport errors.

    Makes up to ``max_retries`` attempts, sleeping ``retry_delay`` seconds between
    them, and re-raises the last error once every attempt has failed.

    Returns:
        A connected ``gradio_client.Client``.
    """
    for attempt in range(1, max_retries + 1):
        try:
            client = Client("http://211.233.58.202:7960/")
        # TransportError covers ReadTimeout (the only case retried before) as well
        # as ConnectTimeout / ConnectError, e.g. while the server is still starting.
        except httpx.TransportError:
            if attempt == max_retries:
                logging.error("Failed to connect after multiple attempts.")
                raise
            logging.warning("Connection failed. Retrying in %s seconds...", retry_delay)
            time.sleep(retry_delay)
        else:
            # NOTE(review): gradio_client may not read this attribute; the 30 s
            # timeout might need to be passed via Client(httpx_kwargs=...) — confirm.
            client.httpx_client = httpx_client
            return client


api_client = _connect_api_client()
# Gallery storage: the directory holding videos/thumbnails, and the JSON index file.
GALLERY_DIR, GALLERY_JSON = "gallery", "gallery.json"

# Create the gallery directory up front (no error if it already exists).
os.makedirs(name=GALLERY_DIR, exist_ok=True)
def save_to_gallery(video_path, prompt):
    """Copy a generated video into the gallery, save a thumbnail, and record it in the index.

    Args:
        video_path: Path of the generated video file to copy.
        prompt: Prompt text stored alongside the gallery entry.

    Returns:
        A ``(new_video_path, thumbnail_path)`` tuple of paths inside ``GALLERY_DIR``.
    """
    # Microsecond precision keeps two saves within the same second from
    # overwriting each other's video and thumbnail files.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}.mp4")
    thumbnail_path = os.path.join(GALLERY_DIR, f"{timestamp}_thumb.jpg")

    # Copy the video file into the gallery directory.
    shutil.copy2(video_path, new_video_path)

    # Save the first frame as the thumbnail. Always close the clip so the
    # reader it opened is released, even if saving the frame fails.
    video = VideoFileClip(new_video_path)
    try:
        video.save_frame(thumbnail_path, t=0)
    finally:
        video.close()

    # Append this entry to the JSON gallery index (created on first save).
    gallery_info = {
        "video": new_video_path,
        "thumbnail": thumbnail_path,
        "prompt": prompt,
        "timestamp": timestamp,
    }
    if os.path.exists(GALLERY_JSON):
        with open(GALLERY_JSON, "r", encoding="utf-8") as f:
            gallery = json.load(f)
    else:
        gallery = []
    gallery.append(gallery_info)
    with open(GALLERY_JSON, "w", encoding="utf-8") as f:
        json.dump(gallery, f, indent=2)

    return new_video_path, thumbnail_path
def load_gallery():
    """Read the gallery index and list its entries, newest first.

    Returns:
        A list of ``(thumbnail, prompt, video)`` tuples, or ``[]`` when the
        ``GALLERY_JSON`` index file does not exist yet.
    """
    if not os.path.exists(GALLERY_JSON):
        return []
    with open(GALLERY_JSON, "r") as index_file:
        entries = json.load(index_file)
    newest_first = reversed(entries)
    return [
        (entry["thumbnail"], entry["prompt"], entry["video"])
        for entry in newest_first
    ]
def respond(image, prompt, steps, cfg_scale, eta, fs, seed, video_length):
    """Request a video from the remote ``/infer`` endpoint and store it in the gallery.

    Args:
        image: Filepath of the uploaded image, or None if no image was given.
        prompt: Text prompt for video generation.
        steps, cfg_scale, eta, fs, seed, video_length: Generation settings
            forwarded unchanged to the remote ``/infer`` endpoint.

    Returns:
        Path of the gallery copy of the generated video, for the ``gr.Video`` output.

    Raises:
        gr.Error: If the request fails, the response has an unexpected format,
            or saving to the gallery fails; Gradio shows the message in the UI.
    """
    logging.info(
        "Received prompt: %s, steps: %s, cfg_scale: %s, eta: %s, fs: %s, seed: %s, video_length: %s",
        prompt, steps, cfg_scale, eta, fs, seed, video_length,
    )
    try:
        # Wrap the local image path for upload; the image input is optional.
        image_file = handle_file(image) if image is not None else None

        # Video generation request.
        result = api_client.predict(
            image=image_file,
            prompt=prompt,
            steps=steps,
            cfg_scale=cfg_scale,
            eta=eta,
            fs=fs,
            seed=seed,
            video_length=video_length,
            api_name="/infer",
        )
        logging.info("API response received: %s", result)

        if not (isinstance(result, dict) and "video" in result):
            raise ValueError("Unexpected API response format")

        # save_to_gallery returns (video_path, thumbnail_path). gr.Video would read
        # a 2-tuple as (video, subtitles) and reject the .jpg, so return only the video.
        saved_video_path, _thumbnail_path = save_to_gallery(result["video"], prompt)
        return saved_video_path
    except Exception as e:
        logging.exception("Error during API request: %s", str(e))
        # Returning a message string to a Video output would be treated as a file
        # path; raising gr.Error surfaces the message to the user instead.
        raise gr.Error("Failed to generate video due to an error.") from e
# Custom stylesheet passed to gr.Blocks: hides the page footer.
css = "\n".join(("", "footer {", "visibility: hidden;", "}", ""))
# ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ์œ„ํ•œ ์˜ˆ์ œ ํ”„๋กฌํ”„ํŠธ
examples = [
["A glamorous young woman with long, wavy blonde hair and smokey eye makeup, posing in a luxury hotel room. Sheโ€™s wearing a sparkly gold cocktail dress and holding up a white card with 'openfree.ai' written on it in elegant calligraphy. Soft, warm lighting creates a luxurious atmosphere. ", "q1.webp"],
["A fantasy map of a fictional world, with detailed terrain and cities.", "q19.webp"]
]
def use_prompt_and_image(prompt, image):
    """Hand an example's prompt and image back unchanged, to fill the input widgets."""
    pair = (prompt, image)
    return pair
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Tab("Generate"):
        with gr.Row():
            input_image = gr.Image(label="Upload an image", type="filepath")
            input_text = gr.Textbox(label="Enter your prompt for video generation")
        output_video = gr.Video(label="Generated Video")
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=100, step=1, label="Steps", value=30)
            cfg_scale = gr.Slider(minimum=1, maximum=15, step=0.1, label="CFG Scale", value=3.5)
            eta = gr.Slider(minimum=0, maximum=1, step=0.1, label="ETA", value=1)
            fs = gr.Slider(minimum=1, maximum=30, step=1, label="FPS", value=8)
            seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="Seed", value=123)
            video_length = gr.Slider(minimum=1, maximum=10, step=1, label="Video Length (seconds)", value=2)
        with gr.Row():
            for prompt, image_file in examples:
                with gr.Column():
                    gr.Image(image_file, label=prompt[:50] + "...")
                    # The button has no inputs, so this example's values are bound
                    # as lambda defaults (avoids late binding of the loop variables).
                    # Before, the handler called use_prompt_and_image() with zero
                    # arguments, which raised TypeError.
                    gr.Button("Use this example").click(
                        fn=lambda p=prompt, i=image_file: use_prompt_and_image(p, i),
                        inputs=[],
                        outputs=[input_text, input_image],
                        api_name=False,
                    )

    with gr.Tab("Gallery"):
        gallery = gr.Gallery(
            label="Generated Videos",
            show_label=False,
            elem_id="gallery",
            columns=[5],
            rows=[3],
            object_fit="contain",
            height="auto",
        )
        selected_video = gr.Video(label="Selected Video")
        refresh_btn = gr.Button("Refresh Gallery")
        # Per-session list of video paths in the same order as the displayed
        # thumbnails, so a selected index maps to the video the user actually saw
        # even if the gallery index has changed since this view was rendered.
        gallery_videos = gr.State([])

        def update_gallery():
            """Return the gallery items and their matching video paths.

            gr.Gallery takes (image, caption) pairs, so each video path from
            load_gallery()'s (thumbnail, prompt, video) entries is split off
            into a separate list, stored in ``gallery_videos``.
            """
            entries = load_gallery()
            items = [(thumbnail, caption) for thumbnail, caption, _ in entries]
            videos = [video for _, _, video in entries]
            return items, videos

        def show_video(videos, evt: gr.SelectData):
            """Return the video path for the selected thumbnail, or None if out of range."""
            index = evt.index
            if isinstance(index, int) and 0 <= index < len(videos):
                return videos[index]
            return None

        refresh_btn.click(fn=update_gallery, inputs=None, outputs=[gallery, gallery_videos])
        demo.load(fn=update_gallery, inputs=None, outputs=[gallery, gallery_videos])
        gallery.select(show_video, inputs=[gallery_videos], outputs=selected_video)

    # Pressing Enter in the prompt box generates a video, then refreshes the gallery.
    input_text.submit(
        fn=respond,
        inputs=[input_image, input_text, steps, cfg_scale, eta, fs, seed, video_length],
        outputs=output_video,
    ).then(
        fn=update_gallery,
        inputs=None,
        outputs=[gallery, gallery_videos],
    )

if __name__ == "__main__":
    demo.launch()