|
import type { Config } from "../types";
|
|
import {
|
|
CONFIG_ERROR_MSG,
|
|
CONFIG_URL,
|
|
INVALID_CREDENTIALS_MSG,
|
|
LOGIN_URL,
|
|
MISSING_CREDENTIALS_MSG,
|
|
SPACE_METADATA_ERROR_MSG,
|
|
UNAUTHORIZED_MSG
|
|
} from "../constants";
|
|
import { Client } from "..";
|
|
import { join_urls, process_endpoint } from "./api_info";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Picks the root URL to use for an app.
 *
 * @param base_url - URL the client was pointed at.
 * @param root_path - Root reported by the app; either a full `http://` /
 * `https://` URL or a fragment to append to `base_url`.
 * @param prioritize_base - When `root_path` is a full URL, return `base_url`
 * instead of `root_path`.
 * @returns `base_url + root_path` when `root_path` is not a full URL;
 * otherwise `base_url` or `root_path` depending on `prioritize_base`.
 */
export function resolve_root(
	base_url: string,
	root_path: string,
	prioritize_base: boolean
): string {
	const root_is_absolute = /^https?:\/\//.test(root_path);

	// Non-absolute roots are simply appended to the base.
	if (!root_is_absolute) {
		return `${base_url}${root_path}`;
	}

	if (prioritize_base) {
		return base_url;
	}
	return root_path;
}
|
|
|
|
/**
 * Requests a JWT for a Hugging Face Space.
 *
 * Sends a request to `https://huggingface.co/api/spaces/<space>/jwt` with the
 * token as a bearer credential and, when provided, the cookies as a `Cookie`
 * header.
 *
 * @param space - Space identifier interpolated into the request URL.
 * @param token - Hugging Face token sent as `Authorization: Bearer <token>`.
 * @param cookies - Optional cookie string forwarded with the request.
 * @returns The `token` field of the JSON response, or `false` when that field
 * is falsy or when the request / JSON parsing throws.
 */
export async function get_jwt(
	space: string,
	token: `hf_${string}`,
	cookies?: string | null
): Promise<string | false> {
	const request_headers: Record<string, string> = {
		Authorization: `Bearer ${token}`
	};
	if (cookies) {
		request_headers.Cookie = cookies;
	}

	try {
		const response = await fetch(
			`https://huggingface.co/api/spaces/${space}/jwt`,
			{ headers: request_headers }
		);
		const payload = await response.json();

		return payload.token || false;
	} catch {
		// Any network or parsing failure is reported as "no JWT available".
		return false;
	}
}
|
|
|
|
/**
 * Builds a lookup from API name to dependency id.
 *
 * Entries whose `api_name` is falsy are skipped. If the same name occurs more
 * than once, the id of the last occurrence is kept.
 *
 * @param fns - Dependency entries, each read for its `api_name` and `id`.
 * @returns An object keyed by `api_name` whose values are the matching `id`s.
 */
export function map_names_to_ids(
	fns: Config["dependencies"]
): Record<string, number> {
	const ids_by_name: Record<string, number> = {};

	for (const dependency of fns) {
		if (!dependency.api_name) continue;
		ids_by_name[dependency.api_name] = dependency.id;
	}

	return ids_by_name;
}
|
|
|
|
/**
 * Resolves the app config, either from a config object already present on
 * `window` or by fetching it from `endpoint`.
 *
 * Must be called with a `Client` as `this`; it reads `this.options.hf_token`,
 * `this.options.auth` and uses `this.fetch` for the request.
 *
 * @param endpoint - Base URL of the app. Used to resolve the embedded config's
 * root, or joined with `CONFIG_URL` to fetch the config.
 * @returns The resolved config. Every path that does not return throws.
 * @throws Error with `MISSING_CREDENTIALS_MSG` when the server answers 401 and
 * no `auth` option is set.
 * @throws Error with `INVALID_CREDENTIALS_MSG` when the server answers 401 and
 * an `auth` option is set.
 * @throws Error with `CONFIG_ERROR_MSG` for any other non-200 response, or
 * when there is neither a usable embedded config nor an `endpoint`.
 */
export async function resolve_config(
	this: Client,
	endpoint: string
): Promise<Config | undefined> {
	const headers: Record<string, string> = this.options.hf_token
		? { Authorization: `Bearer ${this.options.hf_token}` }
		: {};

	headers["Content-Type"] = "application/json";

	if (
		typeof window !== "undefined" &&
		window.gradio_config &&
		location.origin !== "http://localhost:9876" &&
		!window.gradio_config.dev_mode
	) {
		// A config is embedded in the page: keep its original root as `path`
		// and overwrite `root` (on the shared window object) with the resolved one.
		const path = window.gradio_config.root;
		const config = window.gradio_config;
		config.root = resolve_root(endpoint, config.root, false);
		return { ...config, path } as Config;
	} else if (endpoint) {
		const config_url = join_urls(endpoint, CONFIG_URL);
		const response = await this.fetch(config_url, {
			headers,
			credentials: "include"
		});

		// A 401 is fully classified by whether credentials were supplied, so no
		// further 401 handling is needed after this check.
		if (response?.status === 401) {
			throw new Error(
				this.options.auth ? INVALID_CREDENTIALS_MSG : MISSING_CREDENTIALS_MSG
			);
		}

		if (response?.status === 200) {
			let config = await response.json();
			config.path = config.path ?? "";
			config.root = endpoint;
			// Dependencies without an explicit id fall back to their position.
			config.dependencies?.forEach((dep: { id?: number }, i: number) => {
				if (dep.id === undefined) {
					dep.id = i;
				}
			});
			return config;
		}

		throw new Error(CONFIG_ERROR_MSG);
	}

	throw new Error(CONFIG_ERROR_MSG);
}
|
|
|
|
/**
 * Logs in with the client's `auth` credentials (if any) and stores the
 * resulting cookies on the client.
 *
 * Must be called with a `Client` as `this`. The endpoint is always processed
 * first; the login request is only made when `this.options.auth` is set, and
 * `this.set_cookies` is only called when a cookie header was returned.
 *
 * @throws Error carrying the message of whatever the login step threw.
 * Failures of `process_endpoint` propagate unchanged.
 */
export async function resolve_cookies(this: Client): Promise<void> {
	const { http_protocol, host } = await process_endpoint(
		this.app_reference,
		this.options.hf_token
	);

	try {
		if (this.options.auth) {
			const cookie_header = await get_cookie_header(
				http_protocol,
				host,
				this.options.auth,
				this.fetch,
				this.options.hf_token
			);

			if (cookie_header) this.set_cookies(cookie_header);
		}
	} catch (e: unknown) {
		// Thrown values are not guaranteed to be Error instances: read `message`
		// only when it exists, otherwise stringify the value itself so that
		// `null`, `undefined` or string throwables still yield a useful message.
		const message =
			typeof e === "object" && e !== null && "message" in e
				? String((e as { message: unknown }).message)
				: String(e);
		throw Error(message);
	}
}
|
|
|
|
|
|
/**
 * Posts credentials to the app's login URL and returns the cookie header from
 * the response.
 *
 * @param http_protocol - Protocol including the trailing colon; placed before
 * `//` in the request URL.
 * @param host - Host placed between `//` and `/<LOGIN_URL>`.
 * @param auth - `[username, password]` pair sent as form fields.
 * @param _fetch - Fetch implementation used to perform the request.
 * @param hf_token - Optional token sent as `Authorization: Bearer <token>`.
 * @returns The value of the response's `set-cookie` header (may be `null`)
 * when the status is 200.
 * @throws Error with `INVALID_CREDENTIALS_MSG` on a 401 response.
 * @throws Error with `SPACE_METADATA_ERROR_MSG` on any other status.
 */
export async function get_cookie_header(
	http_protocol: string,
	host: string,
	auth: [string, string],
	_fetch: typeof fetch,
	hf_token?: `hf_${string}`
): Promise<string | null> {
	const credentials_form = new FormData();
	credentials_form.append("username", auth?.[0]);
	credentials_form.append("password", auth?.[1]);

	const headers: { Authorization?: string } = hf_token
		? { Authorization: `Bearer ${hf_token}` }
		: {};

	const login_response = await _fetch(
		`${http_protocol}//${host}/${LOGIN_URL}`,
		{
			headers,
			method: "POST",
			body: credentials_form,
			credentials: "include"
		}
	);

	switch (login_response.status) {
		case 200:
			return login_response.headers.get("set-cookie");
		case 401:
			throw new Error(INVALID_CREDENTIALS_MSG);
		default:
			throw new Error(SPACE_METADATA_ERROR_MSG);
	}
}
|
|
|
|
/**
 * Derives the WebSocket protocol, HTTP protocol and host for an endpoint.
 *
 * - Endpoints starting with `http` are parsed as URLs. Hosts ending in
 *   `hf.space` always use `wss` and drop the path; other hosts use `wss` for
 *   `https:` and `ws` otherwise, and keep any non-root path appended to the
 *   host.
 * - Endpoints starting with `file:` map to `ws` / `http:` on `lite.local`.
 * - Anything else is used verbatim as the host with `wss` / `https:`.
 *
 * @param endpoint - URL or bare host to analyse.
 * @returns The `ws_protocol`, `http_protocol` and `host` for the endpoint.
 */
export function determine_protocol(endpoint: string): {
	ws_protocol: "ws" | "wss";
	http_protocol: "http:" | "https:";
	host: string;
} {
	if (endpoint.startsWith("file:")) {
		return {
			ws_protocol: "ws",
			http_protocol: "http:",
			host: "lite.local"
		};
	}

	if (!endpoint.startsWith("http")) {
		// No recognised scheme: treat the whole string as a host on secure protocols.
		return {
			ws_protocol: "wss",
			http_protocol: "https:",
			host: endpoint
		};
	}

	const url = new URL(endpoint);
	const http_protocol = url.protocol as "http:" | "https:";

	if (url.host.endsWith("hf.space")) {
		return {
			ws_protocol: "wss",
			host: url.host,
			http_protocol
		};
	}

	const path_suffix = url.pathname === "/" ? "" : url.pathname;
	return {
		ws_protocol: http_protocol === "https:" ? "wss" : "ws",
		http_protocol,
		host: url.host + path_suffix
	};
}
|
|
|
|
/**
 * Extracts `name=value` pairs from a (possibly combined) `Set-Cookie` header.
 *
 * The header is split on commas that are followed by something shaped like
 * `name=value`, so commas inside attribute values (e.g. an `Expires` date) do
 * not start a new cookie. For each cookie only the part before the first `;`
 * is kept, i.e. attributes are discarded.
 *
 * The pair is split on its FIRST `=` only, so values that themselves contain
 * `=` (such as base64 padding) are preserved in full.
 *
 * @param cookie_header - Raw header value; may contain several cookies.
 * @returns One trimmed `name=value` string per cookie that has both a
 * non-empty name and a non-empty value.
 */
export const parse_and_set_cookies = (cookie_header: string): string[] => {
	const cookies: string[] = [];
	const parts = cookie_header.split(/,(?=\s*[^\s=;]+=[^\s=;]+)/);

	for (const part of parts) {
		const pair = part.split(";")[0] ?? "";
		const separator_index = pair.indexOf("=");
		if (separator_index === -1) continue;

		const cookie_name = pair.slice(0, separator_index).trim();
		const cookie_value = pair.slice(separator_index + 1).trim();
		if (cookie_name && cookie_value) {
			cookies.push(`${cookie_name}=${cookie_value}`);
		}
	}

	return cookies;
};
|
|
|