# ttv/app.py
import torch
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_video
from flask import Flask, request, jsonify
from flask_cors import CORS
import base64
import tempfile
import os
import threading
import traceback
import cv2
import numpy as np

app = Flask(__name__)
CORS(app)

pipe = None
app.config['temp_response'] = None
app.config['generation_thread'] = None
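
# The pipeline is created lazily on the first request: download_pipeline() below loads
# the "wangfuyun/AnimateLCM" motion adapter on top of the "emilianJR/epiCRealism" SD1.5
# base model, swaps in an LCM scheduler, and attaches the AnimateLCM LoRA at weight 0.8,
# which is what allows the short sampling runs used later (6 steps, guidance 2.0).
# VAE slicing and model CPU offload keep GPU memory usage down.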
def download_pipeline():
    global pipe
    try:
        print('Downloading the model weights')
        # Download and initialize the animation pipeline
        adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16)
        pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
        pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
        pipe.set_adapters(["lcm-lora"], [0.8])
        pipe.enable_vae_slicing()
        pipe.enable_model_cpu_offload()
        return True
    except Exception as e:
        print(f"Error downloading pipeline: {e}")
        return False

def export_video(frames_list, temp_video_path):
    # Convert PIL images to numpy arrays
    frames_np = [np.array(frame) for frame in frames_list]
    # Determine the dimensions (height, width) of the frames
    height, width, _ = frames_np[0].shape
    # Create a VideoWriter object to write the frames to a video file
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 format
    out = cv2.VideoWriter(temp_video_path, fourcc, 12, (width, height))  # Adjust fps as needed
    # Write each frame to the video file
    for frame_np in frames_np:
        # Convert RGB image (numpy array) to BGR (OpenCV format) for writing
        frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
    # Release the VideoWriter object
    out.release()

def generate_and_export_animation(prompt):
    global pipe
    # Ensure the animation pipeline is initialized
    if pipe is None:
        if not download_pipeline():
            return {'video_base64': '', 'status': 'failed', 'error': 'Failed to initialize animation pipeline'}
    try:
        # Generate animation frames
        print('Generating Video frames')
        output = pipe(
            prompt=prompt,
            negative_prompt="bad quality, worse quality, low resolution, blur",
            num_frames=16,
            guidance_scale=2.0,
            num_inference_steps=6
        )
        print('Video frames generated')
        # Export frames to a temporary video file
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            temp_video_path = temp_file.name
        print('temp_video_path', temp_video_path)
        # export_to_video(output.frames[0], temp_video_path)
        export_video(output.frames[0], temp_video_path)
        with open(temp_video_path, 'rb') as video_file:
            video_binary = video_file.read()
        video_base64 = base64.b64encode(video_binary).decode('utf-8')
        os.remove(temp_video_path)
        response_data = {'video_base64': video_base64, 'status': None}
        print('response_data', response_data)
        return response_data
    except Exception as e:
        print(f"Error generating animation: {e}")
        traceback.print_exc()  # Print exception details to console
        # Return a plain dict (not a Flask response) so the background thread and
        # the /status handler can treat success and failure uniformly
        return {'video_base64': '', 'status': 'failed', 'error': f"Failed to generate animation: {str(e)}"}

def background(prompt):
    with app.app_context():
        temp_response = generate_and_export_animation(prompt)
        app.config['temp_response'] = temp_response

@app.route('/run', methods=['POST'])
def handle_animation_request():
    prompt = request.form.get('prompt')
    if prompt:
        # Clear any result left over from a previous run before starting a new generation
        app.config['temp_response'] = None
        generation_thread = threading.Thread(target=background, args=(prompt,))
        app.config['generation_thread'] = generation_thread
        generation_thread.start()
        response_data = {"message": "Video generation started", "process_id": generation_thread.ident}
        return jsonify(response_data)
    else:
        return jsonify({"message": "Please provide a valid text prompt."}), 400

@app.route('/status', methods=['GET'])
def check_animation_status():
    process_id = request.args.get('process_id', None)
    if process_id:
        generation_thread = app.config.get('generation_thread')
        if generation_thread and generation_thread.is_alive():
            return jsonify({"status": "in_progress"}), 200
        elif app.config.get('temp_response'):
            final_response = app.config['temp_response']
            print('final', final_response)
            # Mark as completed unless the worker already recorded a failure status
            if not final_response.get('status'):
                final_response['status'] = 'completed'
            return jsonify(final_response)
        # The thread has finished (or was never started) but no result is stored yet
        return jsonify({"status": "unknown"}), 404
    return jsonify({"message": "Please provide a valid process_id."}), 400

if __name__ == '__main__':
    app.run(debug=True)
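
# Example client usage (a hedged sketch, assuming the default Flask dev server at
# http://127.0.0.1:5000; adjust host/port for your deployment, and note that the
# process_id and prompt values below are illustrative):
#
#   curl -X POST -F "prompt=a rocket launching into space" http://127.0.0.1:5000/run
#   # -> {"message": "Video generation started", "process_id": 1234567890}
#
#   curl "http://127.0.0.1:5000/status?process_id=1234567890"
#   # -> {"status": "in_progress"} while running, then
#   # -> {"video_base64": "...", "status": "completed"} once finished
#
# The returned base64 string can be written back out as an MP4 file, e.g.:
#   open("video.mp4", "wb").write(base64.b64decode(response["video_base64"]))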