machineuser
Sync widgets demo
76d4920
raw
history blame
7.89 kB
import type { ModelData, WidgetExampleAttribute } from "@huggingface/tasks";
import { parseJSON } from "../../../utils/ViewUtils.js";
import { ComputeType, type ModelLoadInfo, type TableData } from "./types.js";
import { LoadState } from "./types.js";
import { isLoggedIn } from "../stores.js";
import { get } from "svelte/store";
// Widget attributes whose URL query-param value `getQueryParamVal` returns as the raw string.
const KEYS_TEXT: WidgetExampleAttribute[] = ["text", "context", "candidate_labels"];
// Widget attributes whose URL query-param value `getQueryParamVal` parses as JSON and converts to a 2D table.
const KEYS_TABLE: WidgetExampleAttribute[] = ["table", "structured_data"];
// Every shape `getQueryParamVal` can return: the raw string, `null` (param absent),
// a boolean (for `multi_class`), or a 2D table (for the `KEYS_TABLE` attributes).
type QueryParamVal = string | null | boolean | (string | number)[][];
/**
 * Reads a widget attribute from the query string of the current page URL.
 *
 * - Keys listed in `KEYS_TEXT` are returned verbatim (`null` when the param is absent).
 * - Keys listed in `KEYS_TABLE` are JSON-parsed and converted to a 2D table
 *   (an empty object is used when parsing yields nothing).
 * - `multi_class` is returned as a boolean, `true` only for the literal string "true".
 * - Any other key is returned verbatim.
 *
 * @param key - Name of the widget attribute / query parameter to read.
 * @returns The decoded value for `key`, in the shape described above.
 */
export function getQueryParamVal(key: WidgetExampleAttribute): QueryParamVal {
	const rawValue = new URL(window.location.href).searchParams.get(key);

	if (KEYS_TEXT.includes(key)) {
		return rawValue;
	}
	if (KEYS_TABLE.includes(key)) {
		const parsed = parseJSON(rawValue) as TableData;
		return convertDataToTable(parsed ?? {});
	}
	if (key === "multi_class") {
		return rawValue === "true";
	}
	return rawValue;
}
/**
 * Update the current URL's search params in place, keeping existing keys intact.
 *
 * The URL is rewritten with `history.replaceState`, so no navigation happens and no
 * history entry is added. The current path and hash fragment are preserved.
 *
 * @param obj - Params to apply: a string value sets (or overwrites) the key,
 *   `undefined` removes the key from the URL.
 */
export function updateUrl(obj: Partial<Record<WidgetExampleAttribute, string | undefined>>): void {
	// Use `typeof`: evaluating an undeclared `window` (e.g. during SSR) throws a ReferenceError,
	// so a plain `!window` check can never act as a guard outside the browser.
	if (typeof window === "undefined") {
		return;
	}
	const { href, pathname, hash } = window.location;
	const sp = new URL(href).searchParams;
	for (const [k, v] of Object.entries(obj)) {
		if (v === undefined) {
			sp.delete(k);
		} else {
			sp.set(k, v);
		}
	}
	// Re-append the hash so that updating query params does not silently drop the fragment.
	const path = `${pathname}?${sp.toString()}${hash}`;
	window.history.replaceState(null, "", path);
}
/**
 * Run through our own proxy to bypass CORS.
 *
 * URLs that start with `http://localhost`, or whose host equals the current page's host,
 * are returned unchanged. Any other URL is wrapped into a `https://widgets.hf.co/proxy` request.
 *
 * @param url - Absolute URL of the resource to fetch (`new URL(url)` throws on relative URLs).
 * @returns The URL to actually fetch.
 */
function proxify(url: string): string {
	if (url.startsWith(`http://localhost`) || new URL(url).host === window.location.host) {
		return url;
	}
	// Percent-encode the target: left raw, any `&` or `#` inside `url` would be parsed as part of
	// the proxy request itself, truncating the `url` parameter (or dropping the rest as a fragment).
	return `https://widgets.hf.co/proxy?url=${encodeURIComponent(url)}`;
}
/**
 * Downloads `url` (routed through `proxify` first) and resolves with the response body as a `Blob`.
 *
 * The HTTP status is not inspected: whatever body the server sends back is returned.
 *
 * @param url - URL of the resource to download.
 * @returns The response body as a `Blob`.
 */
export async function getBlobFromUrl(url: string): Promise<Blob> {
	const response = await fetch(proxify(url));
	return response.blob();
}
/** The request succeeded and the response body was accepted by the output parser. */
interface Success<T> {
	status: "success";
	/** Parsed output, as returned by the caller-supplied parsing function. */
	output: T;
	/** Pretty-printed JSON of the response body; empty string for audio/image responses. */
	outputJson: string;
	/** `"<x-compute-time header> s"`, or `"cached"` when the header is absent. */
	computeTime: string;
	/** The underlying fetch response. */
	response: Response;
}

/** The server answered 503 with an `estimated_time`: the model is still being loaded. */
interface LoadingModel {
	status: "loading-model";
	error: string;
	estimatedTime: number;
}

/**
 * The request failed, or the response body could not be parsed.
 * NOTE: within this module this declaration shadows the global `Error` type.
 */
interface Error {
	status: "error";
	error: string;
}

/** An on-load call whose response body reported "not loaded yet". */
interface CacheNotFound {
	status: "cache not found";
}

/** Every outcome `callInferenceApi` can resolve with, discriminated by `status`. */
type Result<T> = Success<T> | LoadingModel | Error | CacheNotFound;
/**
 * POSTs `requestBody` to `${url}/models/${repoId}` and normalises the outcome into a `Result<T>`.
 *
 * When `requestBody.file` is a `Blob`, the blob itself is sent as the body (with its MIME type as
 * `Content-Type` when it has one); otherwise the whole `requestBody` is sent as JSON.
 *
 * @param url - Base URL of the Inference API.
 * @param repoId - Model repository id, appended to `${url}/models/`.
 * @param requestBody - Payload to send (see above for the `file` special case).
 * @param apiToken - When non-empty, sent as a `Bearer` token in the `Authorization` header.
 * @param outputParsingFn - Validates/converts the response body; a throw is reported as an error result.
 * @param waitForModel - If true, the server will only respond once the model has been loaded on Inference API (serverless).
 * @param includeCredentials - If true, the request is sent with `credentials: "include"` instead of `"same-origin"`.
 * @param isOnLoadCall - If true, the server will try to answer from cache and not do anything if not.
 * @param useCache - If false (and the user is logged in), the `X-Use-Cache: false` header is sent.
 * @returns
 *   - `status: "success"` when the response is OK and `outputParsingFn` accepts the body;
 *   - `status: "cache not found"` when parsing fails on an on-load call whose body says "not loaded yet";
 *   - `status: "loading-model"` for a 503 response carrying `error` and `estimated_time`;
 *   - `status: "error"` otherwise, with the most specific message available.
 */
export async function callInferenceApi<T>(
	url: string,
	repoId: string,
	requestBody: Record<string, unknown>,
	apiToken = "",
	outputParsingFn: (x: unknown) => T,
	waitForModel = false, // If true, the server will only respond once the model has been loaded on Inference API (serverless)
	includeCredentials = false,
	isOnLoadCall = false, // If true, the server will try to answer from cache and not do anything if not
	useCache = true
): Promise<Result<T>> {
	const contentType =
		"file" in requestBody && requestBody["file"] && requestBody["file"] instanceof Blob && requestBody["file"].type
			? requestBody["file"]["type"]
			: "application/json";

	const headers = new Headers();
	headers.set("Content-Type", contentType);
	if (apiToken) {
		headers.set("Authorization", `Bearer ${apiToken}`);
	}
	if (waitForModel) {
		headers.set("X-Wait-For-Model", "true");
	}
	// Opting out of the cache is only requested for logged-in users
	if (useCache === false && get(isLoggedIn)) {
		headers.set("X-Use-Cache", "false");
	}
	// On-load calls and anonymous users send `X-Load-Model: 0`
	if (isOnLoadCall || !get(isLoggedIn)) {
		headers.set("X-Load-Model", "0");
	}

	// `File` is a subtype of `Blob`: therefore, checking for instanceof `Blob` also checks for instanceof `File`
	const reqBody: Blob | string =
		"file" in requestBody && requestBody["file"] instanceof Blob ? requestBody.file : JSON.stringify(requestBody);

	const response = await fetch(`${url}/models/${repoId}`, {
		method: "POST",
		body: reqBody,
		headers,
		credentials: includeCredentials ? "include" : "same-origin",
	});

	if (response.ok) {
		// Success
		const computeTime = response.headers.has("x-compute-time")
			? `${response.headers.get("x-compute-time")} s`
			: `cached`;
		// Audio and image responses are read as blobs, everything else as JSON
		const isMediaContent = (response.headers.get("content-type")?.search(/^(?:audio|image)/i) ?? -1) !== -1;
		const body = !isMediaContent ? await response.json() : await response.blob();

		try {
			const output = outputParsingFn(body);
			const outputJson = !isMediaContent ? JSON.stringify(body, null, 2) : "";
			return { computeTime, output, outputJson, response, status: "success" };
		} catch (e) {
			if (isOnLoadCall && body.error === "not loaded yet") {
				return { status: "cache not found" };
			}
			// Invalid output
			const error = `API Implementation Error: ${String(e).replace(/^Error: /, "")}`;
			return { error, status: "error" };
		}
	} else {
		// Error
		const bodyText = await response.text();
		const body = parseJSON<Record<string, unknown>>(bodyText) ?? {};

		if (
			body["error"] &&
			response.status === 503 &&
			body["estimated_time"] !== null &&
			body["estimated_time"] !== undefined
		) {
			// Model needs loading
			return {
				error: String(body["error"]),
				estimatedTime: Number(body["estimated_time"]),
				status: "loading-model",
			};
		} else {
			// Other errors
			const { status, statusText } = response;
			// Pick the first available detail BEFORE stringifying: `String(undefined)` is the truthy
			// string "undefined", which would otherwise mask both the traceback and the HTTP status.
			const detail = body["error"] || body["traceback"];
			return {
				error: detail ? String(detail) : `${status} ${statusText}`,
				status: "error",
			};
		}
	}
}
/**
 * Fetches `${url}/status/${repoId}` and summarises the model's load state.
 *
 * @param url - Base URL of the Inference API.
 * @param repoId - Model repository id, appended to `${url}/status/`.
 * @param includeCredentials - If true, the request is sent with `credentials: "include"` instead of `"same-origin"`.
 * @returns The reported `state` and `compute_type` when the response is OK and carries a `loaded` field;
 *   `{ state: LoadState.Error }` (after a `console.warn`) for any other response, including bodies
 *   that are not valid JSON. A failure of `fetch` itself still rejects.
 */
export async function getModelLoadInfo(
	url: string,
	repoId: string,
	includeCredentials = false
): Promise<ModelLoadInfo> {
	const response = await fetch(`${url}/status/${repoId}`, {
		credentials: includeCredentials ? "include" : "same-origin",
	});

	let output: {
		state: LoadState;
		compute_type: ComputeType | Record<ComputeType, { [key in ComputeType]?: string } & { count: number }>;
		loaded: boolean;
		error: Error;
	} | null;
	try {
		output = await response.json();
	} catch (e) {
		// Error pages (e.g. gateway HTML) are not JSON: report the error state instead of rejecting
		console.warn(response.status, e);
		return { state: LoadState.Error };
	}

	// `typeof null === "object"`, so a JSON `null` body has to be excluded explicitly
	if (response.ok && typeof output === "object" && output !== null && output.loaded !== undefined) {
		// `compute_type` is either a plain string, or a record in which a "gpu" entry means GPU
		// eslint-disable-next-line @typescript-eslint/naming-convention
		const compute_type =
			typeof output.compute_type === "string"
				? output.compute_type
				: output.compute_type["gpu"]
				? ComputeType.GPU
				: ComputeType.CPU;
		return { compute_type, state: output.state };
	}

	console.warn(response.status, output?.error);
	return { state: LoadState.Error };
}
/**
 * Extend requestBody with user supplied parameters for Inference API (serverless).
 *
 * Reads `model.cardData.inference.parameters` and merges it into `requestBody.parameters`, in place.
 * On key collisions the model-card value overrides the one already present in the request.
 * Nothing happens when `inference` is not an object or declares no parameters.
 *
 * @param requestBody - Request payload to mutate.
 * @param model - Model whose card data may declare inference parameters.
 */
export function addInferenceParameters(requestBody: Record<string, unknown>, model: ModelData): void {
	const inferenceConfig = model?.cardData?.inference;
	if (typeof inferenceConfig !== "object") {
		return;
	}
	const cardParameters = inferenceConfig?.parameters;
	if (!cardParameters) {
		return;
	}
	const existing = requestBody.parameters;
	requestBody.parameters = existing ? { ...existing, ...cardParameters } : cardParameters;
}
/**
 * Converts table from [[Header0, Header1, Header2], [Column0Val0, Column1Val0, Column2Val0], ...]
 * to {Header0: [ColumnVal0, ...], Header1: [Column1Val0, ...], Header2: [Column2Val0, ...]}
 *
 * All cell values are stringified. An empty table yields `{}`. Each row is indexed by column
 * position, so a row that is shorter than the header row contributes `""` for its missing cells
 * and never shifts the values of other rows.
 *
 * @param table - Row-major table whose first row holds the column headers.
 * @returns A header -> column values mapping.
 */
export function convertTableToData(table: (string | number)[][]): TableData {
	const [headers, ...rows] = table;
	if (headers === undefined) {
		return {};
	}
	return Object.fromEntries(
		headers.map((header, col) => [
			header,
			// some models can only handle strings (no numbers)
			rows.map((row) => String(row[col] ?? "")),
		])
	);
}
/**
 * Converts data from {Header0: [ColumnVal0, ...], Header1: [Column1Val0, ...], Header2: [Column2Val0, ...]}
 * to [[Header0, Header1, Header2], [Column0Val0, Column1Val0, Column2Val0], ...]
 *
 * The first row of the result holds the headers. The number of data rows is taken from the
 * first column only; cell values are copied through untouched.
 *
 * @param data - Header -> column values mapping.
 * @returns Row-major table including the header row (`[[]]` for empty data).
 */
export function convertDataToTable(data: TableData): (string | number)[][] {
	const columns = Object.entries(data); // [header, cell[]][]
	const headerRow = columns.map(([header]) => header);
	const rowCount = columns[0]?.[1]?.length ?? 0;
	const bodyRows = Array.from({ length: rowCount }, (_, rowIndex) => columns.map(([, cells]) => cells[rowIndex]));
	return [headerRow, ...bodyRows];
}