Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import type { ModelData, WidgetExampleAttribute } from "@huggingface/tasks"; | |
import { parseJSON } from "../../../utils/ViewUtils.js"; | |
import { ComputeType, type ModelLoadInfo, type TableData } from "./types.js"; | |
import { LoadState } from "./types.js"; | |
import { isLoggedIn } from "../stores.js"; | |
import { get } from "svelte/store"; | |
const KEYS_TEXT: WidgetExampleAttribute[] = ["text", "context", "candidate_labels"];
const KEYS_TABLE: WidgetExampleAttribute[] = ["table", "structured_data"];
type QueryParamVal = string | null | boolean | (string | number)[][];

/**
 * Reads a widget example attribute from the current page's query string.
 *
 * - Text keys (`KEYS_TEXT`) and any unrecognised key: the raw string, or `null` when absent.
 * - Table keys (`KEYS_TABLE`): the JSON-encoded column map converted to rows (empty object when unparsable).
 * - `multi_class`: `true` only when the parameter is exactly the string `"true"`.
 */
export function getQueryParamVal(key: WidgetExampleAttribute): QueryParamVal {
	const rawValue = new URL(window.location.href).searchParams.get(key);

	if (KEYS_TEXT.includes(key)) {
		return rawValue;
	}
	if (KEYS_TABLE.includes(key)) {
		const parsedData = (parseJSON(rawValue) as TableData) ?? {};
		return convertDataToTable(parsedData);
	}
	return key === "multi_class" ? rawValue === "true" : rawValue;
}
/**
 * Updates the current URL's search params in place (via `history.replaceState`, no navigation),
 * keeping every key that is not mentioned in `obj` as well as the URL fragment.
 *
 * @param obj - Keys to set; a value of `undefined` removes the key from the query string.
 */
export function updateUrl(obj: Partial<Record<WidgetExampleAttribute, string | undefined>>): void {
	// `!window` would itself throw a ReferenceError where `window` is not declared (SSR / Node);
	// `typeof` is the only safe existence check for an undeclared global.
	if (typeof window === "undefined") {
		return;
	}
	const url = new URL(window.location.href);
	for (const [k, v] of Object.entries(obj)) {
		if (v === undefined) {
			url.searchParams.delete(k);
		} else {
			url.searchParams.set(k, v);
		}
	}
	// Re-append the fragment so updating the query does not silently drop `#...`.
	const path = `${url.pathname}?${url.searchParams.toString()}${url.hash}`;
	window.history.replaceState(null, "", path);
}
/**
 * Routes cross-origin URLs through our own proxy to bypass CORS.
 * URLs on `http://localhost`, on the current page's host, or relative URLs (same-origin by
 * definition) are returned unchanged.
 */
function proxify(url: string): string {
	if (url.startsWith(`http://localhost`)) {
		return url;
	}
	// Resolve against the current page so relative URLs do not make `new URL` throw.
	const { host } = new URL(url, window.location.href);
	if (host === window.location.host) {
		return url;
	}
	// Percent-encode the target: otherwise its own `?`, `&` or `#` would be parsed as part of the
	// proxy's query string and the target URL would be truncated.
	return `https://widgets.hf.co/proxy?url=${encodeURIComponent(url)}`;
}
/**
 * Downloads `url` (through the CORS proxy unless it is local or same-host) and returns the body as a Blob.
 * The HTTP status is not inspected: a non-2xx response body is returned as a Blob as well.
 */
export async function getBlobFromUrl(url: string): Promise<Blob> {
	const response = await fetch(proxify(url));
	return response.blob();
}
/** Successful call: `output` is the parsed result, `outputJson` the pretty-printed body ("" for media). */
interface Success<T> {
	/** `"<x-compute-time> s"` when the server reports it, otherwise `"cached"`. */
	computeTime: string;
	output: T;
	outputJson: string;
	response: Response;
	status: "success";
}
/** The model is still loading; `estimatedTime` is the server-provided `estimated_time`. */
interface LoadingModel {
	error: string;
	estimatedTime: number;
	status: "loading-model";
}
/**
 * Failed call. Deliberately not named `Error`: an `interface Error` in this module would shadow the
 * global `Error` type for every annotation in the file.
 */
interface ErrorResult {
	error: string;
	status: "error";
}
/** An on-load call could not be answered from cache. */
interface CacheNotFound {
	status: "cache not found";
}
/** Outcome of an Inference API call, discriminated by `status`. */
type Result<T> = Success<T> | LoadingModel | ErrorResult | CacheNotFound;
/**
 * Builds the user-facing message for a failed Inference API response.
 * Falls back from the body's `error` field, to its `traceback`, to the HTTP status line. Missing or
 * empty fields are skipped — calling `String()` first would turn `undefined` into the truthy
 * string "undefined" and the fallback would never happen.
 */
function getInferenceErrorMessage(body: Record<string, unknown>, status: number, statusText: string): string {
	if (body["error"]) {
		return String(body["error"]);
	}
	if (body["traceback"]) {
		return String(body["traceback"]);
	}
	return `${status} ${statusText}`;
}

/** Whether a parsed response body is a non-null object whose `error` field is "not loaded yet". */
function isNotLoadedYet(body: unknown): boolean {
	return typeof body === "object" && body !== null && (body as { error?: unknown }).error === "not loaded yet";
}

/**
 * POSTs `requestBody` to `${url}/models/${repoId}` and classifies the response.
 *
 * @param url - Base URL of the inference endpoint.
 * @param repoId - Model repository id.
 * @param requestBody - Sent as JSON, unless its `file` field is a Blob/File, which is then sent as the raw body.
 * @param apiToken - Optional bearer token.
 * @param outputParsingFn - Converts the response body; if it throws, the call is reported as an
 *   "API Implementation Error" (or "cache not found", see `isOnLoadCall`).
 * @param waitForModel - Sends `X-Wait-For-Model: true`.
 * @param includeCredentials - Uses `credentials: "include"` instead of `"same-origin"`.
 * @param isOnLoadCall - Sends `X-Load-Model: 0`; a parse failure with error "not loaded yet" yields "cache not found".
 * @param useCache - When `false` and the user is logged in, sends `X-Use-Cache: false`.
 * @returns A `Result` discriminated by `status`. Network failures and a non-JSON body on a
 *   non-media 2xx response still reject.
 */
export async function callInferenceApi<T>(
	url: string,
	repoId: string,
	requestBody: Record<string, unknown>,
	apiToken = "",
	outputParsingFn: (x: unknown) => T,
	waitForModel = false, // If true, the server will only respond once the model has been loaded on Inference API (serverless)
	includeCredentials = false,
	isOnLoadCall = false, // If true, the server will try to answer from cache and not do anything if not
	useCache = true
): Promise<Result<T>> {
	// `File` is a subtype of `Blob`: checking for instanceof `Blob` also covers `File`.
	const fileField = requestBody["file"];
	const file = fileField instanceof Blob ? fileField : undefined;
	// An empty MIME type ("") deliberately falls back to JSON, hence `||` rather than `??`.
	const contentType = file?.type || "application/json";

	const headers = new Headers();
	headers.set("Content-Type", contentType);
	if (apiToken) {
		headers.set("Authorization", `Bearer ${apiToken}`);
	}
	if (waitForModel) {
		headers.set("X-Wait-For-Model", "true");
	}
	if (useCache === false && get(isLoggedIn)) {
		headers.set("X-Use-Cache", "false");
	}
	// Sent both for on-load calls and for anonymous users.
	if (isOnLoadCall || !get(isLoggedIn)) {
		headers.set("X-Load-Model", "0");
	}

	const reqBody: Blob | string = file ?? JSON.stringify(requestBody);
	const response = await fetch(`${url}/models/${repoId}`, {
		method: "POST",
		body: reqBody,
		headers,
		credentials: includeCredentials ? "include" : "same-origin",
	});

	if (!response.ok) {
		const bodyText = await response.text();
		const body = parseJSON<Record<string, unknown>>(bodyText) ?? {};
		if (
			body["error"] &&
			response.status === 503 &&
			body["estimated_time"] !== null &&
			body["estimated_time"] !== undefined
		) {
			// Model needs loading
			return { error: String(body["error"]), estimatedTime: +body["estimated_time"], status: "loading-model" };
		}
		return { error: getInferenceErrorMessage(body, response.status, response.statusText), status: "error" };
	}

	const computeTime = response.headers.has("x-compute-time")
		? `${response.headers.get("x-compute-time")} s`
		: `cached`;
	const isMediaContent = /^(?:audio|image)/i.test(response.headers.get("content-type") ?? "");
	const body: unknown = isMediaContent ? await response.blob() : await response.json();
	try {
		const output = outputParsingFn(body);
		const outputJson = isMediaContent ? "" : JSON.stringify(body, null, 2);
		return { computeTime, output, outputJson, response, status: "success" };
	} catch (e) {
		// `body` may be a Blob, an array or even `null`, so inspect it defensively.
		if (isOnLoadCall && isNotLoadedYet(body)) {
			return { status: "cache not found" };
		}
		// Invalid output
		const error = `API Implementation Error: ${String(e).replace(/^Error: /, "")}`;
		return { error, status: "error" };
	}
}
/**
 * Fetches the load status of `repoId` from `${url}/status/${repoId}`.
 *
 * @param url - Base URL of the inference endpoint.
 * @param repoId - Model repository id.
 * @param includeCredentials - Uses `credentials: "include"` instead of `"same-origin"`.
 * @returns `{ compute_type, state }` when the endpoint answers 2xx with a JSON object that has a
 *   `loaded` field (`compute_type` is omitted if the server sent none); otherwise
 *   `{ state: LoadState.Error }`, after logging the status and any `error` field via `console.warn`.
 */
export async function getModelLoadInfo(
	url: string,
	repoId: string,
	includeCredentials = false
): Promise<ModelLoadInfo> {
	const response = await fetch(`${url}/status/${repoId}`, {
		credentials: includeCredentials ? "include" : "same-origin",
	});
	let output:
		| {
				state: LoadState;
				compute_type?: ComputeType | Record<ComputeType, { [key in ComputeType]?: string } & { count: number }>;
				loaded?: boolean;
				error?: unknown;
		  }
		| null
		| undefined;
	try {
		output = await response.json();
	} catch {
		// Non-JSON body (e.g. an HTML error page): reported as an error below instead of rejecting.
		output = undefined;
	}

	// `typeof null === "object"`, so null is excluded explicitly before reading fields.
	if (response.ok && typeof output === "object" && output !== null && output.loaded !== undefined) {
		const computeType = output.compute_type;
		if (computeType === undefined) {
			return { state: output.state };
		}
		// eslint-disable-next-line @typescript-eslint/naming-convention
		const compute_type =
			typeof computeType === "string" ? computeType : computeType["gpu"] ? ComputeType.GPU : ComputeType.CPU;
		return { compute_type, state: output.state };
	}
	console.warn(response.status, output?.error);
	return { state: LoadState.Error };
}
/**
 * Extends `requestBody.parameters` in place with the user-supplied `cardData.inference.parameters`
 * of the model card, for Inference API (serverless) calls.
 *
 * Card parameters take precedence over parameters already present on the request. The card's
 * object is always copied, never assigned by reference, so later edits to
 * `requestBody.parameters` cannot mutate the shared `model` data.
 */
export function addInferenceParameters(requestBody: Record<string, unknown>, model: ModelData): void {
	const inference = model?.cardData?.inference;
	// `inference` may also be a boolean flag, which carries no parameters.
	if (typeof inference !== "object") {
		return;
	}
	const inferenceParameters = inference?.parameters;
	if (!inferenceParameters) {
		return;
	}
	requestBody.parameters = requestBody.parameters
		? { ...requestBody.parameters, ...inferenceParameters }
		: { ...inferenceParameters };
}
/**
 * Converts a table from [[Header0, Header1, Header2], [Column0Val0, Column1Val0, Column2Val0], ...]
 * to {Header0: [Column0Val0, ...], Header1: [Column1Val0, ...], Header2: [Column2Val0, ...]}.
 *
 * Every cell is stringified because some models can only handle strings (no numbers). Cells
 * missing from a short row become "" rather than shifting later values into the wrong column.
 * An empty table yields `{}`.
 */
export function convertTableToData(table: (string | number)[][]): TableData {
	const [header = [], ...rows] = table;
	return Object.fromEntries(
		header.map((cell, x): [string | number, string[]] => [cell, rows.map((row) => String(row[x] ?? ""))])
	);
}
/**
 * Converts data from {Header0: [Column0Val0, ...], Header1: [Column1Val0, ...], Header2: [Column2Val0, ...]}
 * to [[Header0, Header1, Header2], [Column0Val0, Column1Val0, Column2Val0], ...].
 *
 * The number of rows follows the longest column; cells missing from shorter columns are "".
 * Non-array values are treated as empty columns, because `data` may come straight from a
 * JSON-encoded URL query parameter. `{}` yields `[[]]` (a single, empty header row).
 */
export function convertDataToTable(data: TableData): (string | number)[][] {
	const columns = Object.entries(data).map(([header, cells]): [string, (string | number)[]] => [
		header,
		Array.isArray(cells) ? cells : [],
	]);
	const nbDataRows = columns.reduce((max, [, cells]) => Math.max(max, cells.length), 0);
	const headerRow = columns.map(([header]) => header);
	const dataRows = Array.from({ length: nbDataRows }, (_, y) => columns.map(([, cells]) => cells[y] ?? ""));
	return [headerRow, ...dataRows];
}