Spaces: Paused
/* eslint-disable complexity */ | |
import type { | |
Status, | |
Payload, | |
GradioEvent, | |
JsApiData, | |
EndpointInfo, | |
ApiInfo, | |
Config, | |
Dependency, | |
SubmitIterable | |
} from "../types"; | |
import { skip_queue, post_message, handle_payload } from "../helpers/data"; | |
import { resolve_root } from "../helpers/init_helpers"; | |
import { | |
handle_message, | |
map_data_to_params, | |
process_endpoint | |
} from "../helpers/api_info"; | |
import semiver from "semiver"; | |
import { BROKEN_CONNECTION_MSG, QUEUE_FULL_MSG } from "../constants"; | |
import { apply_diff_stream, close_stream } from "./stream"; | |
import { Client } from "../client"; | |
/**
 * Submit a request to a Gradio endpoint and expose its progress as an async iterable.
 *
 * Transport selection:
 * - if `skip_queue(fn_index, config)` is true, a single POST is sent to `/run/<endpoint>`;
 * - otherwise `config.protocol` (default `"ws"`) selects the websocket flow, the
 *   legacy `"sse"` flow, or the `sse_v1` / `sse_v2` / `sse_v2.1` / `sse_v3` queue flow.
 *
 * Events are only published when `all_events` is true or when their `type` is
 * listed in `this.options.events`.
 *
 * @param endpoint - Named endpoint (e.g. `"/predict"`) or a numeric fn_index.
 * @param data - Input values, positional or keyed by parameter name; mapped
 *   with `map_data_to_params`.
 * @param event_data - Sent with the request payload and echoed on `data` events.
 * @param trigger_id - Sent with the request payload and echoed on `data` events.
 * @param all_events - Publish every event type regardless of `options.events`.
 * @returns An async iterable of events with an extra `cancel()` method. The
 *   iterator is closed once a terminal complete/error status has been published
 *   (or the transport fails), when `cancel()` is called, or when `return()` is called.
 * @throws Error when the client has no API info or config, or when the endpoint
 *   cannot be resolved (the error is logged and re-thrown).
 */
export function submit(
	this: Client,
	endpoint: string | number,
	data: unknown[] | Record<string, unknown>,
	event_data?: unknown,
	trigger_id?: number | null,
	all_events?: boolean
): SubmitIterable<GradioEvent> {
	try {
		const { hf_token } = this.options;
		const {
			fetch,
			app_reference,
			config,
			session_hash,
			api_info,
			api_map,
			stream_status,
			pending_stream_messages,
			pending_diff_streams,
			event_callbacks,
			unclosed_events,
			post_data,
			options
		} = this;

		const that = this;

		if (!api_info) throw new Error("No API found");
		if (!config) throw new Error("Could not resolve app config");

		const { fn_index, endpoint_info, dependency } = get_endpoint_info(
			api_info,
			endpoint,
			api_map,
			config
		);

		const resolved_data = map_data_to_params(data, api_info);

		let websocket: WebSocket;
		let stream: EventSource | null;
		const protocol = config.protocol ?? "ws";

		const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
		let payload: Payload;
		let event_id: string | null = null;
		let complete: Status | undefined | false = false;
		// True once a queued transport was chosen; only used to label errors
		// reported by the top-level failure handler below.
		let queued = false;
		const last_status: Record<string, Status["stage"]> = {};
		const url_params =
			typeof window !== "undefined" && typeof document !== "undefined"
				? new URLSearchParams(window.location.search).toString()
				: "";

		// Lookup table of event types the caller subscribed to via `options.events`.
		const events_to_publish =
			options?.events?.reduce(
				(acc, event) => {
					acc[event] = true;
					return acc;
				},
				{} as Record<string, boolean>
			) || {};

		/** Publish `event` to the iterator if the caller subscribed to its type. */
		function fire_event(event: GradioEvent): void {
			if (all_events || events_to_publish[event.type]) {
				push_event(event);
			}
		}

		/**
		 * Cancel this submission: publish a final "complete" status, tear down
		 * the transport, close the iterator and ask the server to cancel/reset.
		 * Failures of the `/cancel` or `/reset` calls are only logged.
		 */
		async function cancel(): Promise<void> {
			const _status: Status = {
				stage: "complete",
				queue: false,
				time: new Date()
			};
			complete = _status;
			fire_event({
				..._status,
				type: "status",
				endpoint: _endpoint,
				fn_index: fn_index
			});

			let reset_request = {};
			let cancel_request = {};
			if (protocol === "ws") {
				if (websocket && websocket.readyState === 0) {
					// Still connecting: close as soon as the handshake completes.
					websocket.addEventListener("open", () => {
						websocket.close();
					});
				} else if (websocket) {
					// `websocket` is unset if cancel() runs before the socket was
					// created (or the request bypassed the queue).
					websocket.close();
				}
				reset_request = { fn_index, session_hash };
			} else {
				close_stream(stream_status, that.abort_controller);
				reset_request = { event_id };
				cancel_request = { event_id, session_hash, fn_index };
			}
			// Signal `done` to iterator consumers for every transport.
			close();

			try {
				if (!config) {
					throw new Error("Could not resolve app config");
				}

				if ("event_id" in cancel_request) {
					await fetch(`${config.root}/cancel`, {
						headers: { "Content-Type": "application/json" },
						method: "POST",
						body: JSON.stringify(cancel_request)
					});
				}

				await fetch(`${config.root}/reset`, {
					headers: { "Content-Type": "application/json" },
					method: "POST",
					body: JSON.stringify(reset_request)
				});
			} catch (e) {
				console.warn(
					"The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
				);
			}
		}

		const resolve_heartbeat = async (config: Config): Promise<void> => {
			await this._resolve_hearbeat(config);
		};

		/**
		 * Merge a server-sent render config into `config`: components and
		 * dependencies previously rendered under the same `render_id` are
		 * replaced, the heartbeat requirement is recomputed, and a "render"
		 * event is published.
		 */
		async function handle_render_config(render_config: any): Promise<void> {
			if (!config) return;
			const render_id: number = render_config.render_id;
			config.components = [
				...config.components.filter((c) => c.props.rendered_in !== render_id),
				...render_config.components
			];
			config.dependencies = [
				...config.dependencies.filter((d) => d.rendered_in !== render_id),
				...render_config.dependencies
			];
			const any_state = config.components.some((c) => c.type === "state");
			const any_unload = config.dependencies.some((d) =>
				d.targets.some((t) => t[1] === "unload")
			);
			config.connect_heartbeat = any_state || any_unload;
			await resolve_heartbeat(config);
			fire_event({
				type: "render",
				data: render_config,
				endpoint: _endpoint,
				fn_index
			});
		}

		this.handle_blob(config.root, resolved_data, endpoint_info)
			.then(async (_payload) => {
				const input_data = handle_payload(
					_payload,
					dependency,
					config.components,
					"input",
					true
				);
				payload = {
					data: input_data || [],
					event_data,
					fn_index,
					trigger_id
				};
				if (skip_queue(fn_index, config)) {
					fire_event({
						type: "status",
						endpoint: _endpoint,
						stage: "pending",
						queue: false,
						fn_index,
						time: new Date()
					});

					post_data(
						`${config.root}/run${
							_endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
						}${url_params ? "?" + url_params : ""}`,
						{
							...payload,
							session_hash
						}
					)
						.then(async ([output, status_code]: any) => {
							const data = output.data;
							if (status_code == 200) {
								fire_event({
									type: "data",
									endpoint: _endpoint,
									fn_index,
									data: handle_payload(
										data,
										dependency,
										config.components,
										"output",
										options.with_null_state
									),
									time: new Date(),
									event_data,
									trigger_id
								});
								// Awaited so the "render" event precedes the final
								// status and is not dropped by close() below.
								if (output.render_config) {
									await handle_render_config(output.render_config);
								}

								fire_event({
									type: "status",
									endpoint: _endpoint,
									fn_index,
									stage: "complete",
									eta: output.average_duration,
									queue: false,
									time: new Date()
								});
							} else {
								fire_event({
									type: "status",
									stage: "error",
									endpoint: _endpoint,
									fn_index,
									message: output.error,
									queue: false,
									time: new Date()
								});
							}
							close();
						})
						.catch((e) => {
							fire_event({
								type: "status",
								stage: "error",
								message: e.message,
								endpoint: _endpoint,
								fn_index,
								queue: false,
								time: new Date()
							});
							close();
						});
				} else if (protocol == "ws") {
					queued = true;
					const { ws_protocol, host } = await process_endpoint(
						app_reference,
						hf_token
					);

					fire_event({
						type: "status",
						stage: "pending",
						queue: true,
						endpoint: _endpoint,
						fn_index,
						time: new Date()
					});

					const url = new URL(
						`${ws_protocol}://${resolve_root(
							host,
							config.path as string,
							true
						)}/queue/join${url_params ? "?" + url_params : ""}`
					);

					if (this.jwt) {
						url.searchParams.set("__sign", this.jwt);
					}

					websocket = new WebSocket(url);

					websocket.onclose = (evt) => {
						if (!evt.wasClean) {
							fire_event({
								type: "status",
								stage: "error",
								broken: true,
								message: BROKEN_CONNECTION_MSG,
								queue: true,
								endpoint: _endpoint,
								fn_index,
								time: new Date()
							});
							close();
						}
					};

					websocket.onmessage = function (event) {
						const _data = JSON.parse(event.data);
						const { type, status, data } = handle_message(
							_data,
							last_status[fn_index]
						);

						if (type === "update" && status && !complete) {
							// call 'status' listeners
							fire_event({
								type: "status",
								endpoint: _endpoint,
								fn_index,
								time: new Date(),
								...status
							});
							if (status.stage === "error") {
								websocket.close();
								close();
							}
						} else if (type === "hash") {
							websocket.send(JSON.stringify({ fn_index, session_hash }));
							return;
						} else if (type === "data") {
							websocket.send(JSON.stringify({ ...payload, session_hash }));
						} else if (type === "complete") {
							complete = status;
						} else if (type === "log") {
							fire_event({
								type: "log",
								log: data.log,
								level: data.level,
								endpoint: _endpoint,
								fn_index
							});
						} else if (type === "generating") {
							fire_event({
								type: "status",
								time: new Date(),
								...status,
								stage: status?.stage!,
								queue: true,
								endpoint: _endpoint,
								fn_index
							});
						}
						if (data) {
							fire_event({
								type: "data",
								time: new Date(),
								data: handle_payload(
									data.data,
									dependency,
									config.components,
									"output",
									options.with_null_state
								),
								endpoint: _endpoint,
								fn_index,
								event_data,
								trigger_id
							});

							if (complete) {
								fire_event({
									type: "status",
									time: new Date(),
									...complete,
									stage: status?.stage!,
									queue: true,
									endpoint: _endpoint,
									fn_index
								});
								websocket.close();
								close();
							}
						}
					};

					// different ws contract for gradio versions older than 3.6.0
					//@ts-ignore
					if (semiver(config.version || "2.0.0", "3.6") < 0) {
						// Must listen on this socket; the global `addEventListener`
						// never receives the socket's "open" event.
						websocket.addEventListener("open", () =>
							websocket.send(JSON.stringify({ hash: session_hash }))
						);
					}
				} else if (protocol == "sse") {
					queued = true;
					fire_event({
						type: "status",
						stage: "pending",
						queue: true,
						endpoint: _endpoint,
						fn_index,
						time: new Date()
					});
					const params = new URLSearchParams({
						fn_index: fn_index.toString(),
						session_hash: session_hash
					}).toString();
					const url = new URL(
						`${config.root}/queue/join?${
							url_params ? url_params + "&" : ""
						}${params}`
					);

					if (this.jwt) {
						url.searchParams.set("__sign", this.jwt);
					}

					stream = this.stream(url);

					if (!stream) {
						// Reported to the consumer by the chain's .catch() handler.
						throw new Error("Cannot connect to SSE endpoint: " + url.toString());
					}

					stream.onmessage = async function (event: MessageEvent) {
						const _data = JSON.parse(event.data);
						const { type, status, data } = handle_message(
							_data,
							last_status[fn_index]
						);

						if (type === "update" && status && !complete) {
							// call 'status' listeners
							fire_event({
								type: "status",
								endpoint: _endpoint,
								fn_index,
								time: new Date(),
								...status
							});
							if (status.stage === "error") {
								stream?.close();
								close();
							}
						} else if (type === "data") {
							event_id = _data.event_id as string;
							const [, data_status] = await post_data(
								`${config.root}/queue/data`,
								{
									...payload,
									session_hash,
									event_id
								}
							);
							if (data_status !== 200) {
								fire_event({
									type: "status",
									stage: "error",
									message: BROKEN_CONNECTION_MSG,
									queue: true,
									endpoint: _endpoint,
									fn_index,
									time: new Date()
								});
								stream?.close();
								close();
							}
						} else if (type === "complete") {
							complete = status;
						} else if (type === "log") {
							fire_event({
								type: "log",
								log: data.log,
								level: data.level,
								endpoint: _endpoint,
								fn_index
							});
						} else if (type === "generating") {
							fire_event({
								type: "status",
								time: new Date(),
								...status,
								stage: status?.stage!,
								queue: true,
								endpoint: _endpoint,
								fn_index
							});
						}
						if (data) {
							fire_event({
								type: "data",
								time: new Date(),
								data: handle_payload(
									data.data,
									dependency,
									config.components,
									"output",
									options.with_null_state
								),
								endpoint: _endpoint,
								fn_index,
								event_data,
								trigger_id
							});

							if (complete) {
								fire_event({
									type: "status",
									time: new Date(),
									...complete,
									stage: status?.stage!,
									queue: true,
									endpoint: _endpoint,
									fn_index
								});
								stream?.close();
								close();
							}
						}
					};
				} else if (
					protocol == "sse_v1" ||
					protocol == "sse_v2" ||
					protocol == "sse_v2.1" ||
					protocol == "sse_v3"
				) {
					// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
					// v3 only closes the stream when the backend sends the close stream message.
					queued = true;
					fire_event({
						type: "status",
						stage: "pending",
						queue: true,
						endpoint: _endpoint,
						fn_index,
						time: new Date()
					});
					let hostname = "";
					if (
						typeof window !== "undefined" &&
						typeof document !== "undefined"
					) {
						hostname = window?.location?.hostname;
					}

					const hfhubdev = "dev.spaces.huggingface.tech";
					const origin = hostname.includes(".dev.")
						? `https://moon-${hostname.split(".")[1]}.${hfhubdev}`
						: `https://huggingface.co`;

					const is_iframe =
						typeof window !== "undefined" &&
						typeof document !== "undefined" &&
						window.parent != window;
					const is_zerogpu_space = dependency.zerogpu && config.space_id;
					const zerogpu_auth_promise =
						is_iframe && is_zerogpu_space
							? post_message<Headers>("zerogpu-headers", origin)
							: Promise.resolve(null);

					// Awaited (instead of a detached .then) so that any failure in
					// the join request or in opening the stream reaches the
					// chain's .catch() handler rather than becoming unhandled.
					const headers = await zerogpu_auth_promise;
					const [response, status]: any = await post_data(
						`${config.root}/queue/join?${url_params}`,
						{
							...payload,
							session_hash
						},
						headers
					);

					if (status === 503) {
						fire_event({
							type: "status",
							stage: "error",
							message: QUEUE_FULL_MSG,
							queue: true,
							endpoint: _endpoint,
							fn_index,
							time: new Date()
						});
						close();
					} else if (status !== 200) {
						fire_event({
							type: "status",
							stage: "error",
							message: BROKEN_CONNECTION_MSG,
							queue: true,
							endpoint: _endpoint,
							fn_index,
							time: new Date()
						});
						close();
					} else {
						event_id = response.event_id as string;
						// Invoked by the shared stream for every message of this event.
						const callback = async function (_data: object): Promise<void> {
							try {
								const { type, status, data } = handle_message(
									_data,
									last_status[fn_index]
								);

								if (type == "heartbeat") {
									return;
								}

								if (type === "update" && status && !complete) {
									// call 'status' listeners
									fire_event({
										type: "status",
										endpoint: _endpoint,
										fn_index,
										time: new Date(),
										...status
									});
								} else if (type === "complete") {
									complete = status;
								} else if (type == "unexpected_error") {
									console.error("Unexpected error", status?.message);
									fire_event({
										type: "status",
										stage: "error",
										message:
											status?.message || "An Unexpected Error Occurred!",
										queue: true,
										endpoint: _endpoint,
										fn_index,
										time: new Date()
									});
								} else if (type === "log") {
									fire_event({
										type: "log",
										log: data.log,
										level: data.level,
										endpoint: _endpoint,
										fn_index
									});
									return;
								} else if (type === "generating") {
									fire_event({
										type: "status",
										time: new Date(),
										...status,
										stage: status?.stage!,
										queue: true,
										endpoint: _endpoint,
										fn_index
									});
									if (
										data &&
										["sse_v2", "sse_v2.1", "sse_v3"].includes(protocol)
									) {
										apply_diff_stream(pending_diff_streams, event_id!, data);
									}
								}
								if (data) {
									fire_event({
										type: "data",
										time: new Date(),
										data: handle_payload(
											data.data,
											dependency,
											config.components,
											"output",
											options.with_null_state
										),
										endpoint: _endpoint,
										fn_index,
										event_data,
										trigger_id
									});
									if (data.render_config) {
										await handle_render_config(data.render_config);
									}

									if (complete) {
										fire_event({
											type: "status",
											time: new Date(),
											...complete,
											stage: status?.stage!,
											queue: true,
											endpoint: _endpoint,
											fn_index
										});
									}
								}

								if (
									status?.stage === "complete" ||
									status?.stage === "error"
								) {
									if (event_callbacks[event_id!]) {
										delete event_callbacks[event_id!];
									}
									if (event_id! in pending_diff_streams) {
										delete pending_diff_streams[event_id!];
									}
									// Terminal state and its status already published:
									// let iterator consumers finish.
									close();
								}
							} catch (e) {
								console.error("Unexpected client exception", e);
								fire_event({
									type: "status",
									stage: "error",
									message: "An Unexpected Error Occurred!",
									queue: true,
									endpoint: _endpoint,
									fn_index,
									time: new Date()
								});
								if (["sse_v2", "sse_v2.1", "sse_v3"].includes(protocol)) {
									close_stream(stream_status, that.abort_controller);
									stream_status.open = false;
								}
								close();
							}
						};

						// Replay messages that arrived before the callback was registered.
						if (event_id in pending_stream_messages) {
							pending_stream_messages[event_id].forEach((msg) =>
								callback(msg)
							);
							delete pending_stream_messages[event_id];
						}
						// @ts-ignore
						event_callbacks[event_id] = callback;
						unclosed_events.add(event_id);
						if (!stream_status.open) {
							await this.open_stream();
						}
					}
				}
			})
			.catch((e: unknown) => {
				// Any failure while preparing or dispatching the request would
				// otherwise be an unhandled rejection with a silently hanging iterator.
				fire_event({
					type: "status",
					stage: "error",
					message: e instanceof Error ? e.message : String(e),
					queue: queued,
					endpoint: _endpoint,
					fn_index,
					time: new Date()
				});
				close();
			});

		// --- async iterator plumbing --------------------------------------
		// `values` buffers results produced before `next()` asked for them;
		// `resolvers` holds pending `next()` calls waiting for a result.
		let done = false;
		const values: (IteratorResult<GradioEvent> | PromiseLike<never>)[] = [];
		const resolvers: ((
			value: IteratorResult<GradioEvent> | PromiseLike<never>
		) => void)[] = [];

		/** Mark the iterator finished and settle every pending `next()` with `done: true`. */
		function close(): void {
			done = true;
			while (resolvers.length > 0)
				(resolvers.shift() as (typeof resolvers)[0])({
					value: undefined,
					done: true
				});
		}

		/** Hand `data` to a waiting `next()` or buffer it; ignored once closed. */
		function push(
			data: { value: GradioEvent; done: boolean } | PromiseLike<never>
		): void {
			if (done) return;
			if (resolvers.length > 0) {
				(resolvers.shift() as (typeof resolvers)[0])(data);
			} else {
				values.push(data);
			}
		}

		/** Queue a rejection for the next consumer, then close the iterator. */
		function push_error(error: unknown): void {
			push(thenable_reject(error));
			close();
		}

		function push_event(event: GradioEvent): void {
			push({ value: event, done: false });
		}

		/** Buffered values are drained before `done` is reported. */
		function next(): Promise<IteratorResult<GradioEvent, unknown>> {
			if (values.length > 0)
				return Promise.resolve(values.shift() as (typeof values)[0]);
			if (done) return Promise.resolve({ value: undefined, done: true });
			return new Promise((resolve) => resolvers.push(resolve));
		}

		const iterator = {
			[Symbol.asyncIterator]: () => iterator,
			next,
			throw: async (value: unknown) => {
				push_error(value);
				return next();
			},
			return: async () => {
				close();
				return next();
			},
			cancel
		};

		return iterator;
	} catch (error) {
		console.error("Submit function encountered an error:", error);
		throw error;
	}
}
/**
 * Wrap `error` in a minimal thenable that always takes the rejection path.
 *
 * When a promise adopts this object (for example via `Promise.resolve`, or a
 * promise `resolve` function), the engine calls `then(onFulfilled, onRejected)`.
 * The fulfilment callback is never invoked; `error` is passed to the rejection
 * callback, so the adopting promise rejects with it.
 *
 * @param error - The rejection reason to deliver.
 * @returns A `PromiseLike<never>` whose `then` returns whatever the rejection
 *   callback returns.
 */
function thenable_reject<T>(error: T): PromiseLike<never> {
	const reject_with_error = (
		_on_fulfilled: (value: never) => PromiseLike<never>,
		on_rejected: (reason: T) => PromiseLike<never>
	): PromiseLike<never> => on_rejected(error);

	return { then: reject_with_error };
}
/**
 * Resolve an endpoint reference to its function index, API metadata and
 * dependency definition.
 *
 * - Numeric `endpoint`: used directly as the fn_index; metadata comes from
 *   `api_info.unnamed_endpoints`.
 * - String `endpoint`: one leading "/" is stripped to look up the fn_index in
 *   `api_map`. Metadata is read from `api_info.named_endpoints` using the
 *   whitespace-trimmed name, so the slash is kept for that lookup.
 *
 * In both cases the dependency is the first entry of `config.dependencies`
 * whose `id` loosely equals (`==`) the resolved fn_index. It may be
 * `undefined` if none matches; this is not checked here.
 *
 * @throws Error if the resolved fn_index is not a number (e.g. an unknown name).
 */
function get_endpoint_info(
	api_info: ApiInfo<JsApiData>,
	endpoint: string | number,
	api_map: Record<string, number>,
	config: Config
): {
	fn_index: number;
	endpoint_info: EndpointInfo<JsApiData>;
	dependency: Dependency;
} {
	let fn_index: number;
	let endpoint_info: EndpointInfo<JsApiData>;

	if (typeof endpoint === "number") {
		fn_index = endpoint;
		endpoint_info = api_info.unnamed_endpoints[endpoint];
	} else {
		fn_index = api_map[endpoint.replace(/^\//, "")];
		endpoint_info = api_info.named_endpoints[endpoint.trim()];
	}

	// Looked up before validation, matching the original evaluation order.
	const dependency = config.dependencies.find((dep) => dep.id == fn_index)!;

	if (typeof fn_index !== "number") {
		throw new Error(
			"There is no endpoint matching that name of fn_index matching that number."
		);
	}

	return { fn_index, endpoint_info, dependency };
}