Spaces:
Paused
Paused
import type { Status } from "../types"; | |
import { | |
HOST_URL, | |
INVALID_URL_MSG, | |
QUEUE_FULL_MSG, | |
SPACE_METADATA_ERROR_MSG | |
} from "../constants"; | |
import type { ApiData, ApiInfo, Config, JsApiData } from "../types"; | |
import { determine_protocol } from "./init_helpers"; | |
/**
 * A bare Hugging Face Space name: two segments made of letters, digits,
 * `_`, `-` or `.`, separated by exactly one `/` (e.g. `gradio/hello_world`).
 */
export const RE_SPACE_NAME: RegExp = /^[a-zA-Z0-9_\-\.]+\/[a-zA-Z0-9_\-\.]+$/;
/**
 * Any string that ends in `hf.space`, optionally followed by a single `/`.
 * Used to recognise direct Space domains such as `owner-space.hf.space`.
 */
export const RE_SPACE_DOMAIN: RegExp = /.*hf\.space\/{0,1}$/;
/**
 * Resolves a Gradio app reference to its host and protocols.
 *
 * The reference is first normalised: surrounding whitespace and one trailing
 * `/` are removed. It is then interpreted as one of:
 *  - a Hugging Face Space name (`owner/space`): the host is fetched from the
 *    Hub Spaces API;
 *  - a direct Space domain (anything ending in `hf.space`): the Space id is the
 *    host with `.hf.space` stripped;
 *  - anything else: passed to `determine_protocol` with no associated Space.
 *
 * @param app_reference - Space name, Space domain, or app URL.
 * @param hf_token - Optional token, sent as a Bearer `Authorization` header
 *   when querying the Hub API.
 * @returns The Space id (or `false` when not a Space), the host, and the
 *   websocket / http protocols to use.
 * @throws Error with `SPACE_METADATA_ERROR_MSG` when the Space metadata lookup
 *   fails (network error, non-2xx response, or a response without a host).
 */
export async function process_endpoint(
	app_reference: string,
	hf_token?: `hf_${string}`
): Promise<{
	space_id: string | false;
	host: string;
	ws_protocol: "ws" | "wss";
	http_protocol: "http:" | "https:";
}> {
	const headers: { Authorization?: string } = {};
	if (hf_token) {
		headers.Authorization = `Bearer ${hf_token}`;
	}

	const _app_reference = app_reference.trim().replace(/\/$/, "");

	if (RE_SPACE_NAME.test(_app_reference)) {
		// app_reference is a HF space name
		try {
			const res = await fetch(
				`https://huggingface.co/api/spaces/${_app_reference}/${HOST_URL}`,
				{ headers }
			);
			if (!res.ok) {
				throw new Error(SPACE_METADATA_ERROR_MSG);
			}
			const _host = (await res.json()).host;
			if (typeof _host !== "string" || _host === "") {
				throw new Error(SPACE_METADATA_ERROR_MSG);
			}
			return {
				// Return the normalised reference so a stray trailing slash or
				// whitespace does not leak into URLs built from the Space id.
				space_id: _app_reference,
				...determine_protocol(_host)
			};
		} catch {
			throw new Error(SPACE_METADATA_ERROR_MSG);
		}
	}

	if (RE_SPACE_DOMAIN.test(_app_reference)) {
		// app_reference is a direct HF space domain
		const { ws_protocol, http_protocol, host } =
			determine_protocol(_app_reference);
		return {
			space_id: host.replace(".hf.space", ""),
			ws_protocol,
			http_protocol,
			host
		};
	}

	return {
		space_id: false,
		...determine_protocol(_app_reference)
	};
}
/**
 * Joins URL segments onto a base URL, collapsing the slashes at each seam.
 *
 * Segments are resolved left to right with the `URL` constructor, and each
 * intermediate result becomes the base for the next segment. A single
 * argument is returned unchanged.
 *
 * @param urls - A base URL followed by zero or more path segments.
 * @returns The joined URL string.
 * @throws Error with `INVALID_URL_MSG` when no URLs are given or any join
 *   cannot be parsed as a URL.
 */
export const join_urls = (...urls: string[]): string => {
	// Resolve `segment` as a child of `base`:
	// ("https://a.co/x/", "/y") -> "https://a.co/x/y".
	const append_segment = (base: string, segment: string): string => {
		const base_dir = `${base.replace(/\/+$/, "")}/`;
		const relative = segment.replace(/^\/+/, "");
		return new URL(relative, base_dir).toString();
	};

	if (urls.length === 0) {
		throw new Error(INVALID_URL_MSG);
	}

	try {
		let joined = urls[0];
		for (const segment of urls.slice(1)) {
			joined = append_segment(joined, segment);
		}
		return joined;
	} catch {
		throw new Error(INVALID_URL_MSG);
	}
};
/**
 * Converts the raw `/info` payload into the client-facing `ApiInfo<JsApiData>`.
 *
 * For each named and unnamed endpoint present in `api_info`, this:
 *  - resolves the matching dependency in `config` by `api_name`, falling back
 *    to `api_map`, and records that dependency's `types` on the result;
 *  - when the dependency's input count differs from the reported parameter
 *    count, inserts a hidden placeholder parameter at the index of every
 *    `state` component, so parameter positions line up with the inputs;
 *  - annotates every parameter and return value with a description and a
 *    TypeScript type string (via `get_description` / `get_type`).
 *
 * @param api_info - Raw `/info` payload. It is not mutated.
 * @param config - App config supplying `dependencies` and `components`.
 * @param api_map - Fallback map from endpoint name (leading `/` removed) to
 *   dependency id.
 * @returns A new `ApiInfo` containing the transformed endpoints.
 */
export function transform_api_info(
	api_info: ApiInfo<ApiData>,
	config: Config,
	api_map: Record<string, number>
): ApiInfo<JsApiData> {
	const transformed_info: ApiInfo<JsApiData> = {
		named_endpoints: {},
		unnamed_endpoints: {}
	};

	// Attach a description and a TS type string to one parameter/return entry.
	const transform_type = (
		data: ApiData,
		component: string,
		serializer: string,
		signature_type: "return" | "parameter"
	): JsApiData => ({
		...data,
		description: get_description(data?.type, serializer),
		type: get_type(data?.type, component, serializer, signature_type) || ""
	});

	for (const category of ["named_endpoints", "unnamed_endpoints"] as const) {
		// Only categories actually present in the payload are rebuilt.
		if (!Object.prototype.hasOwnProperty.call(api_info, category)) {
			continue;
		}
		transformed_info[category] = {};

		for (const [endpoint, { parameters: raw_parameters, returns }] of Object.entries(
			api_info[category]
		)) {
			const api_name = endpoint.replace("/", "");

			// `??` rather than `||`: a dependency id of 0 is valid and must not
			// fall through to the api_map lookup or to -1.
			const dependencyIndex =
				config.dependencies.find(
					(dep) => dep.api_name === endpoint || dep.api_name === api_name
				)?.id ??
				api_map[api_name] ??
				-1;

			const dependency =
				dependencyIndex === -1
					? undefined
					: config.dependencies.find((dep) => dep.id === dependencyIndex);

			const dependencyTypes =
				dependencyIndex === -1
					? { continuous: false, generator: false, cancel: false }
					: dependency?.types;

			// Work on a copy so the caller's api_info is not mutated by splice.
			const parameters = [...raw_parameters];

			if (dependency?.inputs && dependency.inputs.length !== parameters.length) {
				const components = dependency.inputs.map(
					(input) => config.components.find((c) => c.id === input)?.type
				);
				try {
					components.forEach((comp, idx) => {
						if (comp === "state") {
							const new_param = {
								component: "state",
								example: null,
								parameter_default: null,
								parameter_has_default: true,
								parameter_name: null,
								hidden: true
							};
							// @ts-ignore -- placeholder does not carry every ApiData field
							parameters.splice(idx, 0, new_param);
						}
					});
				} catch (e) {
					console.error(e);
				}
			}

			transformed_info[category][endpoint] = {
				parameters: parameters.map((p: ApiData) =>
					transform_type(p, p?.component, p?.serializer, "parameter")
				),
				returns: returns.map((r: ApiData) =>
					transform_type(r, r?.component, r?.serializer, "return")
				),
				type: dependencyTypes
			};
		}
	}

	return transformed_info;
}
/**
 * Builds the TypeScript type string shown in the API docs for one parameter
 * or return value.
 *
 * The primitive `type.type` values `string`, `boolean` and `number` map
 * directly. Otherwise the result is chosen from the serializer (or from the
 * `Image` component). File-like values are accepted as
 * `Blob | File | Buffer` and returned as file-data objects.
 *
 * @param type - Type descriptor; only `type.type` is read here.
 * @param component - Component name, e.g. "Image".
 * @param serializer - Serializer name reported by the `/info` endpoint.
 * @param signature_type - "parameter" for values sent to the app, "return"
 *   for values received from it.
 * @returns The type string, or `undefined` when no rule matches.
 */
export function get_type(
	type: { type: any; description: string },
	component: string,
	serializer: string,
	signature_type: "return" | "parameter"
): string | undefined {
	switch (type?.type) {
		case "string":
			return "string";
		case "boolean":
			return "boolean";
		case "number":
			return "number";
	}

	const is_parameter = signature_type === "parameter";
	// Accepted input shape for any file-like value.
	const file_input = "Blob | File | Buffer";
	// Returned shape of a serialized file.
	const file_data =
		"{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}";

	if (
		serializer === "JSONSerializable" ||
		serializer === "StringSerializable"
	) {
		return "any";
	} else if (serializer === "ListStringSerializable") {
		return "string[]";
	} else if (component === "Image") {
		return is_parameter ? file_input : "string";
	} else if (serializer === "FileSerializable") {
		if (type?.type === "array") {
			return is_parameter ? `(${file_input})[]` : `${file_data}[]`;
		}
		return is_parameter ? file_input : file_data;
	} else if (serializer === "GallerySerializable") {
		// Each gallery entry is a [file, label] tuple; keep the brackets balanced.
		return is_parameter
			? `[(${file_input}), (string | null)][]`
			: `[${file_data}, (string | null)][]`;
	}
	return undefined;
}
/**
 * Returns the human-readable description for a parameter or return value.
 *
 * Gallery, list-of-strings and file serializers get a fixed description.
 * Every other serializer falls back to the description carried by `type`.
 *
 * @param type - Type descriptor; only `type.description` is read.
 * @param serializer - Serializer name reported by the `/info` endpoint.
 * @returns The description string.
 */
export function get_description(
	type: { type: any; description: string },
	serializer: string
): string {
	switch (serializer) {
		case "GallerySerializable":
			return "array of [file, label] tuples";
		case "ListStringSerializable":
			return "array of strings";
		case "FileSerializable":
			return "array of files or single file";
		default:
			return type?.description;
	}
}
/* eslint-disable complexity */
/**
 * Translates one message from the Gradio queue protocol into a client event.
 *
 * @param data - Parsed queue message. `data.msg` selects the branch; the other
 *   fields that are read (`code`, `success`, `rank`, `output`, ...) depend on
 *   the message kind.
 * @param last_status - Stage of the previously emitted status. It is reused
 *   for `estimation` messages so the current stage is kept (falling back to
 *   "pending").
 * @returns The event `type`, an optional `status` update, and, for
 *   `generating` / `complete` / `log` events, the associated `data`.
 *   Unrecognised messages yield `{ type: "none" }` with an error-stage status.
 */
export function handle_message(
	data: any,
	last_status: Status["stage"]
): {
	type:
		| "hash"
		| "data"
		| "update"
		| "complete"
		| "generating"
		| "log"
		| "none"
		| "heartbeat"
		| "unexpected_error";
	data?: any;
	status?: Status;
} {
	const queue = true;
	switch (data.msg) {
		case "send_data":
			return { type: "data" };
		case "send_hash":
			return { type: "hash" };
		case "queue_full":
			return {
				type: "update",
				status: {
					queue,
					message: QUEUE_FULL_MSG,
					stage: "error",
					code: data.code,
					success: data.success
				}
			};
		case "heartbeat":
			return {
				type: "heartbeat"
			};
		case "unexpected_error":
			return {
				type: "unexpected_error",
				status: {
					queue,
					message: data.message,
					stage: "error",
					success: false
				}
			};
		case "estimation":
			return {
				type: "update",
				status: {
					queue,
					stage: last_status || "pending",
					code: data.code,
					size: data.queue_size,
					position: data.rank,
					eta: data.rank_eta,
					success: data.success
				}
			};
		case "progress":
			return {
				type: "update",
				status: {
					queue,
					stage: "pending",
					code: data.code,
					progress_data: data.progress_data,
					success: data.success
				}
			};
		case "log":
			return { type: "log", data: data };
		case "process_generating":
			return {
				type: "generating",
				status: {
					queue,
					// `?.`: a failed message may arrive without an `output` payload.
					message: !data.success ? data.output?.error : null,
					stage: data.success ? "generating" : "error",
					code: data.code,
					progress_data: data.progress_data,
					eta: data.average_duration
				},
				data: data.success ? data.output : null
			};
		case "process_completed": {
			const output = data.output;
			// The `in` operator throws on null/undefined/primitives, so check
			// that `output` is an object before probing for `error`.
			if (output !== null && typeof output === "object" && "error" in output) {
				return {
					type: "update",
					status: {
						queue,
						message: output.error as string,
						stage: "error",
						code: data.code,
						success: data.success
					}
				};
			}
			return {
				type: "complete",
				status: {
					queue,
					message: !data.success ? output?.error : undefined,
					stage: data.success ? "complete" : "error",
					code: data.code,
					progress_data: data.progress_data,
					changed_state_ids: data.success
						? output?.changed_state_ids
						: undefined
				},
				data: data.success ? output : null
			};
		}
		case "process_starts":
			return {
				type: "update",
				status: {
					queue,
					stage: "pending",
					code: data.code,
					size: data.rank,
					position: 0,
					success: data.success,
					eta: data.eta
				}
			};
	}
	return { type: "none", status: { stage: "error", queue } };
}
/* eslint-enable complexity */
/**
 * Maps the provided `data` to the parameters defined by the `/info` endpoint response.
 * This supports both positional and keyword arguments passed to the client and
 * ensures that every parameter is either provided directly or has a default value.
 *
 * @param data - Positional arguments (an array, returned unchanged) or keyword
 *   arguments (an object keyed by `parameter_name`). Defaults to no arguments.
 * @param api_info - `/info` response describing the endpoint parameters.
 * @param endpoint - Optional key of a named (e.g. `"/predict"`) or unnamed
 *   endpoint. When it is present in `api_info`, only that endpoint's parameters
 *   are used. Otherwise the parameters of all named endpoints are concatenated,
 *   which is only correct when the app exposes a single named endpoint.
 * @returns One resolved value per parameter. Omitted keyword arguments take
 *   their `parameter_default`.
 * @throws Error if a keyword argument does not match any parameter, or if a
 *   parameter without a default is missing or explicitly `undefined`.
 *   Positional arguments are not validated: supplying more than there are
 *   parameters only logs a warning.
 */
export const map_data_to_params = (
	data: unknown[] | Record<string, unknown> = [],
	api_info: ApiInfo<JsApiData | ApiData>,
	endpoint?: string
): unknown[] => {
	const endpoint_info =
		endpoint === undefined
			? undefined
			: api_info.named_endpoints[endpoint] ??
				api_info.unnamed_endpoints?.[endpoint];
	const parameters = endpoint_info
		? endpoint_info.parameters
		: Object.values(api_info.named_endpoints).flatMap(
				(values) => values.parameters
			);

	if (Array.isArray(data)) {
		if (data.length > parameters.length) {
			console.warn("Too many arguments provided for the endpoint.");
		}
		return data;
	}

	const kwargs = data;
	const resolved_data: unknown[] = [];

	// 1. Fill every slot from the keyword arguments or the parameter default.
	for (const [index, param] of parameters.entries()) {
		// `hasOwnProperty.call` also works for prototype-less objects and for
		// objects that shadow `hasOwnProperty`.
		if (Object.prototype.hasOwnProperty.call(kwargs, param.parameter_name)) {
			resolved_data[index] = kwargs[param.parameter_name];
		} else if (param.parameter_has_default) {
			resolved_data[index] = param.parameter_default;
		} else {
			throw new Error(
				`No value provided for required parameter: ${param.parameter_name}`
			);
		}
	}

	// 2. Reject keyword arguments that do not name any parameter.
	const known_names = new Set(parameters.map((param) => param.parameter_name));
	for (const key of Object.keys(kwargs)) {
		if (!known_names.has(key)) {
			throw new Error(
				`Parameter \`${key}\` is not a valid keyword argument. Please refer to the API for usage.`
			);
		}
	}

	// 3. An explicitly `undefined` value does not satisfy a required parameter.
	resolved_data.forEach((value, idx) => {
		if (value === undefined && !parameters[idx].parameter_has_default) {
			throw new Error(
				`No value provided for required parameter: ${parameters[idx].parameter_name}`
			);
		}
	});

	return resolved_data;
};