Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
"use server" | |
import Replicate from "replicate" | |
import { generateSeed } from "../../utils/misc/generateSeed.mts" | |
import { sleep } from "../../utils/misc/sleep.mts" | |
// API token used to authenticate against Replicate; read once at module load.
const replicateToken = process.env.VC_REPLICATE_API_TOKEN ?? ""

// Model identifier and the exact version hash that predictions are created against.
const replicateModel = "wyhsirius/lia"
const replicateModelVersion = "4ce4e4aff5bd28c6958b1e3e7628ea80718be56672d92ea8039039a3a152e67d"

// Fail fast at import time: nothing in this module can work without a token.
if (!replicateToken) {
  throw new Error(`you need to configure your VC_REPLICATE_API_TOKEN`)
}

// Shared Replicate client used to create predictions.
const replicate = new Replicate({ auth: replicateToken })
/**
 * Generate a video through Replicate using the configured model version,
 * then poll the prediction until it completes.
 *
 * The original note states that if nbFrames == 1 the model generates a jpg;
 * this cannot be confirmed from this file alone.
 *
 * NOTE(review): the parameter list was missing from the source and has been
 * reconstructed from the identifiers used in the body. `imageSource` (fed to
 * `img_source`) is a guessed name — confirm against the callers.
 *
 * @param params.imageSource value sent as `img_source`
 * @param params.negativePrompt value sent as `negative_prompt`
 * @param params.huggingFaceLora model name sent as `hf_lora_url` (ignored when `replicateLora` is set)
 * @param params.replicateLora URL sent as `replicate_weights_url` (ignored when `huggingFaceLora` is set)
 * @param params.width value sent as `width`
 * @param params.height value sent as `height`
 * @param params.nbSteps value sent as `steps`
 * @param params.nbFrames value sent as `video_length`
 * @param params.videoDuration value sent as `video_duration` (ms, per the original comment)
 * @param params.seed seed to use; a fresh one is generated when it is missing or not finite
 * @returns the prediction output (the last item when the output is a list)
 * @throws Error when the configuration is missing, when the prediction reports
 *   an error, fails, is canceled, yields no usable output, or when polling
 *   exceeds the maximum number of attempts
 */
export async function generateVideoWithHotshotReplicate({
  imageSource,
  negativePrompt,
  huggingFaceLora,
  replicateLora,
  width,
  height,
  nbSteps,
  nbFrames,
  videoDuration,
  seed,
}: {
  imageSource?: string
  negativePrompt?: string
  huggingFaceLora?: string
  replicateLora?: string
  width?: number
  height?: number
  nbSteps?: number
  nbFrames?: number
  videoDuration?: number
  seed?: number
}): Promise<string> {
  if (!replicateModel) {
    throw new Error(`you need to configure the replicateModel`)
  }
  if (!replicateModelVersion) {
    throw new Error(`you need to configure the replicateModelVersion`)
  }

  const prediction = await replicate.predictions.create({
    version: replicateModelVersion,
    input: {
      img_source: imageSource,
      negative_prompt: negativePrompt,

      // this is not a URL but a model name
      hf_lora_url: replicateLora?.length ? undefined : huggingFaceLora,

      // this is a URL to the .tar (we can get it from the "trainings" page)
      replicate_weights_url: huggingFaceLora?.length ? undefined : replicateLora,

      width,
      height,

      // those are used to create an upsampling or downsampling
      // original_width: width,
      // original_height: height,
      // target_width: width,
      // target_height: height,

      steps: nbSteps,

      video_length: nbFrames, // nb frames
      video_duration: videoDuration, // video duration in ms

      // Number.isFinite rejects undefined, NaN and +/-Infinity in one check
      seed: typeof seed === "number" && Number.isFinite(seed) ? seed : generateSeed(),
    },
  })

  // Initial grace period before the first status check (the original author
  // notes that Replicate needs at least 30 seconds).
  await sleep(30000)

  // To prevent indefinite polling we stop after 30 attempts,
  // i.e. about 2 and a half minutes at one attempt every 5 seconds.
  const maxPollingAttempts = 30

  for (let pollingCount = 0; pollingCount < maxPollingAttempts; pollingCount++) {
    // Check every 5 seconds
    await sleep(5000)

    const res = await fetch(`https://api.replicate.com/v1/predictions/${prediction.id}`, {
      method: "GET",
      headers: {
        Authorization: `Token ${replicateToken}`,
      },
      cache: "no-store",
    })

    // Non-200 answers are treated as transient: just try again next round.
    if (res.status !== 200) {
      continue
    }

    const response = (await res.json()) as {
      status?: string
      error?: unknown
      output?: unknown
    }

    const error =
      typeof response?.error === "string"
        ? response.error
        : response?.error
          ? JSON.stringify(response.error)
          : ""
    if (error) {
      throw new Error(error)
    }

    if (response?.status === "succeeded") {
      // The output may be a single value or a list; in the latter case keep the last item.
      const output: unknown = Array.isArray(response.output)
        ? response.output[response.output.length - 1]
        : response.output
      if (typeof output !== "string" || !output) {
        throw new Error(`the prediction succeeded but returned no usable output`)
      }
      return output
    }

    // Terminal states without an error message would otherwise be polled until timeout.
    if (response?.status === "failed" || response?.status === "canceled") {
      throw new Error(`the prediction ended with status "${response.status}"`)
    }
  }

  throw new Error('Request time out.')
}