from huggingface_hub import InferenceClient

import base64
import json
import os
import re
import time


def save_video(base64_video: str, output_path: str):
    """Save a base64-encoded video to a file."""
    # Strip the data-URI prefix, if present, before decoding
    if base64_video.startswith('data:video/mp4;base64,'):
        base64_video = base64_video.split('base64,')[1]

    video_bytes = base64.b64decode(base64_video)
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    print(f"Video saved to: {output_path}")


def generate_video(
    prompt: str,
    endpoint_url: str,
    token: str | None = None,
    resolution: str = "1280x720",
    video_length: int = 129,
    num_inference_steps: int = 30,
    seed: int = -1,
    guidance_scale: float = 1.0,
    flow_shift: float = 7.0,
    embedded_guidance_scale: float = 6.0,
    enable_riflex: bool = True,
    tea_cache: float = 0.0,
) -> str:
    """Generate a video using the custom inference endpoint.

    Args:
        prompt: Text prompt describing the video.
        endpoint_url: Full URL of the inference endpoint.
        token: Hugging Face API token for authentication.
        resolution: Video resolution (default: "1280x720").
        video_length: Number of frames to generate (default: 129).
        num_inference_steps: Number of denoising steps (default: 30).
        seed: Random seed; -1 selects a random seed (default: -1).
        guidance_scale: Guidance scale value (default: 1.0).
        flow_shift: Flow-shift value (default: 7.0).
        embedded_guidance_scale: Embedded guidance scale (default: 6.0).
        enable_riflex: Enable RIFLEx positional embedding for long videos
            (default: True).
        tea_cache: TeaCache acceleration threshold; 0.0 disables it, 0.1 gives
            roughly a 1.6x speedup, 0.15 roughly 2.1x (default: 0.0).

    Returns:
        Path to the saved video file.
    """
    client = InferenceClient(model=endpoint_url, token=token)

    print(f'Generating video with prompt: "{prompt}"')
    print(f"Resolution: {resolution}, Length: {video_length} frames")
    print(f"Steps: {num_inference_steps}, Seed: {'random' if seed == -1 else seed}")
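
    # Derive a filesystem-safe filename stem from the prompt: strip
    # punctuation, cap at 50 characters, and replace spaces with underscores.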
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip().replace(' ', '_')
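
    # Build the request payload; the keys follow the schema the custom
    # endpoint handler expects.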
    payload = {
        "inputs": prompt,
        "resolution": resolution,
        "video_length": video_length,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "flow_shift": flow_shift,
        "embedded_guidance_scale": embedded_guidance_scale,
        "enable_riflex": enable_riflex,
        "tea_cache": tea_cache,
    }

    start_time = time.time()
    print("Sending request to endpoint...")

    try:
        # InferenceClient.post() returns the raw response body as bytes,
        # not a requests.Response, so parse it ourselves.
        response = client.post(json=payload)

        # The handler may return a JSON object carrying a "video_base64"
        # field, a JSON string, or a bare base64 body; handle all three.
        try:
            result = json.loads(response)
        except (ValueError, TypeError):
            result = None

        if isinstance(result, dict):
            video_data = result.get("video_base64", "")
            if not video_data:
                raise ValueError("Response JSON has no 'video_base64' field")
        elif isinstance(result, str):
            video_data = result
        else:
            video_data = response.decode() if isinstance(response, bytes) else str(response)

        generation_time = time.time() - start_time
        print(f"Video generated in {generation_time:.2f} seconds")

        timestamp = int(time.time())
        output_path = f"{safe_prompt}_{timestamp}.mp4"

        # save_video() strips a data-URI prefix itself, so bare base64
        # strings and data URIs are both handled.
        save_video(video_data, output_path)

        return output_path

    except Exception as e:
        print(f"Error generating video: {e}")
        raise
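

# On newer huggingface_hub releases, InferenceClient.post() is deprecated or
# removed. If it is unavailable, the same payload can be POSTed directly with
# `requests`. A minimal sketch; the helper name and the 600 s timeout are
# assumptions, not part of the original script:
def post_payload_raw(endpoint_url: str, payload: dict, token: str | None = None) -> bytes:
    import requests  # local import so the main path keeps its dependencies

    headers = {"Authorization": f"Bearer {token}"} if token else {}
    resp = requests.post(endpoint_url, json=payload, headers=headers, timeout=600)
    resp.raise_for_status()  # surface HTTP errors (401, 503, ...) early
    return resp.content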


if __name__ == "__main__":
    hf_api_token = os.environ.get('HF_API_TOKEN', '')
    endpoint_url = os.environ.get('ENDPOINT_URL', '')

    if not endpoint_url:
        print("Please set the ENDPOINT_URL environment variable")
        raise SystemExit(1)
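
    # Example request: a shorter clip (97 frames) and fewer steps (22)
    # than the defaults, which keeps generation time down.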
    video_path = generate_video(
        endpoint_url=endpoint_url,
        token=hf_api_token,
        prompt="A cat walks on the grass, realistic style.",
        resolution="1280x720",
        video_length=97,
        num_inference_steps=22,
        seed=-1,
        guidance_scale=1.0,
        embedded_guidance_scale=6.0,
        flow_shift=7.0,
        enable_riflex=True,
        tea_cache=0.0,
    )
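
# Typical invocation (the endpoint URL shape below is the usual one for
# Hugging Face Inference Endpoints; the script filename is hypothetical):
#   export ENDPOINT_URL="https://<your-endpoint>.endpoints.huggingface.cloud"
#   export HF_API_TOKEN="hf_..."
#   python generate_video.py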