from huggingface_hub import InferenceClient
import base64
import json
import os
import re
import time

def save_video(base64_video: str, output_path: str):
    """Save base64 encoded video to a file"""
    # Handle data URI format if present
    if base64_video.startswith('data:video/mp4;base64,'):
        base64_video = base64_video.split('base64,')[1]
    video_bytes = base64.b64decode(base64_video)
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    print(f"Video saved to: {output_path}")

def generate_video(
    prompt: str,
    endpoint_url: str,
    token: str = None,
    resolution: str = "1280x720",
    video_length: int = 129,
    num_inference_steps: int = 30,
    seed: int = -1,
    guidance_scale: float = 1.0,
    flow_shift: float = 7.0,
    embedded_guidance_scale: float = 6.0,
    enable_riflex: bool = True,
    tea_cache: float = 0.0
) -> str:
    """Generate a video using the custom inference endpoint.

    Args:
        prompt: Text prompt describing the video
        endpoint_url: Full URL to the inference endpoint
        token: HuggingFace API token for authentication
        resolution: Video resolution (default: "1280x720")
        video_length: Number of frames (default: 129)
        num_inference_steps: Number of inference steps (default: 30)
        seed: Random seed, -1 for random (default: -1)
        guidance_scale: Guidance scale value (default: 1.0)
        flow_shift: Flow shift value (default: 7.0)
        embedded_guidance_scale: Embedded guidance scale (default: 6.0)
        enable_riflex: Enable RIFLEx positional embedding for long videos (default: True)
        tea_cache: TeaCache acceleration threshold, 0.0 to disable, 0.1 for 1.6x speedup, 0.15 for 2.1x speedup (default: 0.0)

    Returns:
        Path to the saved video file
    """
    # Initialize client
    client = InferenceClient(model=endpoint_url, token=token)

    print(f"Generating video with prompt: \"{prompt}\"")
    print(f"Resolution: {resolution}, Length: {video_length} frames")
    print(f"Steps: {num_inference_steps}, Seed: {'random' if seed == -1 else seed}")

    # Sanitize filename from prompt
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip().replace(' ', '_')
    # Prepare payload
    payload = {
        "inputs": prompt,
        "resolution": resolution,
        "video_length": video_length,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "flow_shift": flow_shift,
        "embedded_guidance_scale": embedded_guidance_scale,
        "enable_riflex": enable_riflex,
        "tea_cache": tea_cache
    }

    # Make request
    start_time = time.time()
    print("Sending request to endpoint...")
    try:
        response = client.post(json=payload)

        # InferenceClient.post() returns the raw response body as bytes
        if isinstance(response, bytes):
            response = response.decode("utf-8")

        # The body is either JSON (with a "video_base64" key) or directly the data URI
        try:
            result = json.loads(response)
            video_data = result.get("video_base64", result) if isinstance(result, dict) else result
        except json.JSONDecodeError:
            video_data = response
        generation_time = time.time() - start_time
        print(f"Video generated in {generation_time:.2f} seconds")

        # Save video
        timestamp = int(time.time())
        output_path = f"{safe_prompt}_{timestamp}.mp4"

        # save_video() strips a data URI prefix itself, so any string can be passed as-is
        if isinstance(video_data, str):
            save_video(video_data, output_path)
        else:
            # Assume it's a dictionary with a base64 key
            save_video(video_data.get("video_base64", ""), output_path)
        return output_path

    except Exception as e:
        print(f"Error generating video: {e}")
        raise
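
# For reference: a roughly equivalent call without InferenceClient, using requests
# (already installed as a huggingface_hub dependency). This is an illustrative
# sketch only; the payload shape and Bearer-token auth are assumed to match what
# the endpoint handler above expects. It is not called by the demo below.
def generate_video_raw(prompt: str, endpoint_url: str, token: str = None) -> str:
    """Return the raw response body (JSON or a base64 / data URI string)."""
    import requests

    headers = {"Authorization": f"Bearer {token}"} if token else {}
    r = requests.post(
        endpoint_url,
        headers=headers,
        json={"inputs": prompt, "resolution": "1280x720", "video_length": 129},
        timeout=600,
    )
    r.raise_for_status()
    return r.text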


if __name__ == "__main__":
    hf_api_token = os.environ.get('HF_API_TOKEN', '')
    endpoint_url = os.environ.get('ENDPOINT_URL', '')

    if not endpoint_url:
        print("Please set the ENDPOINT_URL environment variable")
        exit(1)
    video_path = generate_video(
        endpoint_url=endpoint_url,
        token=hf_api_token,
        prompt="A cat walks on the grass, realistic style.",

        # Video configuration
        resolution="1280x720",     # Standard HD resolution
        video_length=97,           # About 4 seconds at 24 fps

        # Generation parameters
        num_inference_steps=22,    # Fewer than the function default of 30
        seed=-1,                   # Random seed

        # Advanced parameters
        guidance_scale=1.0,
        embedded_guidance_scale=6.0,
        flow_shift=7.0,

        # Optimizations
        enable_riflex=True,        # Better for videos longer than 4 seconds
        tea_cache=0.0              # Set to 0.1 or 0.15 for faster generation with slight quality loss
    )
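
# To run this demo (assumes the endpoint serves the text-to-video handler this
# script targets; the token can be omitted for endpoints that allow anonymous access):
#
#   export ENDPOINT_URL="https://<your-endpoint>.endpoints.huggingface.cloud"
#   export HF_API_TOKEN="hf_..."
#   python demo.py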