|
import type { Config } from "../types"; |
|
import { |
|
CONFIG_ERROR_MSG, |
|
CONFIG_URL, |
|
INVALID_CREDENTIALS_MSG, |
|
LOGIN_URL, |
|
MISSING_CREDENTIALS_MSG, |
|
SPACE_METADATA_ERROR_MSG, |
|
UNAUTHORIZED_MSG |
|
} from "../constants"; |
|
import { Client } from ".."; |
|
import { join_urls, process_endpoint } from "./api_info"; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Resolve the root URL of a Gradio app.
 *
 * If `root_path` is already an absolute `http://` / `https://` URL, the result
 * is `base_url` when `prioritize_base` is true and `root_path` otherwise.
 * Any other `root_path` is treated as relative and appended to `base_url`
 * by plain string concatenation (no slash normalisation is done).
 *
 * @param base_url - URL used as the prefix for relative roots.
 * @param root_path - root reported by the app config.
 * @param prioritize_base - prefer `base_url` over an absolute `root_path`.
 * @returns the resolved root URL.
 */
export function resolve_root(
	base_url: string,
	root_path: string,
	prioritize_base: boolean
): string {
	const absolute_prefixes = ["http://", "https://"];
	const is_absolute = absolute_prefixes.some((prefix) =>
		root_path.startsWith(prefix)
	);

	if (!is_absolute) {
		return base_url + root_path;
	}

	if (prioritize_base) {
		return base_url;
	}
	return root_path;
}
|
|
|
/**
 * Request a JWT for a Hugging Face Space.
 *
 * Sends a GET to `https://huggingface.co/api/spaces/<space>/jwt` using `token`
 * as a bearer token. `cookies`, when non-empty, is forwarded as a `Cookie` header.
 *
 * @returns the `token` field of the JSON response, or `false` when that field
 *   is missing/falsy or when the request or JSON parsing throws.
 */
export async function get_jwt(
	space: string,
	token: `hf_${string}`,
	cookies?: string | null
): Promise<string | false> {
	try {
		const request_headers: Record<string, string> = {
			Authorization: `Bearer ${token}`
		};
		if (cookies) {
			request_headers.Cookie = cookies;
		}

		const url = `https://huggingface.co/api/spaces/${space}/jwt`;
		const response = await fetch(url, { headers: request_headers });
		const payload = await response.json();

		return payload.token || false;
	} catch {
		// Any network or parse failure is reported as "no JWT".
		return false;
	}
}
|
|
|
/**
 * Build a lookup from each dependency's `api_name` to its `id`.
 *
 * Dependencies whose `api_name` is falsy are skipped. If several dependencies
 * share an `api_name`, the one appearing last in `fns` wins.
 *
 * @param fns - the `dependencies` list of an app config.
 * @returns an object mapping api names to dependency ids.
 */
export function map_names_to_ids(
	fns: Config["dependencies"]
): Record<string, number> {
	return fns.reduce<Record<string, number>>((by_name, { api_name, id }) => {
		if (api_name) {
			by_name[api_name] = id;
		}
		return by_name;
	}, {});
}
|
|
|
/**
 * Load the Gradio app config for `endpoint`.
 *
 * When running in a browser page that already embeds `window.gradio_config`
 * (and the origin is not `http://localhost:9876` and the config is not in
 * `dev_mode`), the embedded config is used. Its `root` is resolved against
 * `endpoint` via `resolve_root` and written back onto `window.gradio_config`.
 * The returned copy carries the pre-resolution root as `path`.
 * NOTE(review): presumably localhost:9876 is a local dev/test origin — confirm.
 *
 * Otherwise the config is fetched from `join_urls(endpoint, CONFIG_URL)`.
 * The returned config has `path` defaulted to `""` and `root` set to `endpoint`.
 * Any dependency without an `id` gets its positional index.
 *
 * @param endpoint - base URL of the app.
 * @returns the app config.
 * @throws Error(MISSING_CREDENTIALS_MSG) on HTTP 401 when no `auth` option is set.
 * @throws Error(INVALID_CREDENTIALS_MSG) on HTTP 401 when an `auth` option is set.
 * @throws Error(CONFIG_ERROR_MSG) on any other non-200 response, or when
 *   `endpoint` is empty and no embedded config applies.
 */
export async function resolve_config(
	this: Client,
	endpoint: string
): Promise<Config | undefined> {
	const headers: Record<string, string> = this.options.hf_token
		? { Authorization: `Bearer ${this.options.hf_token}` }
		: {};
	headers["Content-Type"] = "application/json";

	if (
		typeof window !== "undefined" &&
		window.gradio_config &&
		location.origin !== "http://localhost:9876" &&
		!window.gradio_config.dev_mode
	) {
		const config = window.gradio_config;
		// `path` keeps the root exactly as embedded in the page, before resolution.
		const path = config.root;
		// Intentionally mutates the shared window.gradio_config (existing behaviour).
		config.root = resolve_root(endpoint, config.root, false);
		return { ...config, path } as Config;
	}

	if (!endpoint) {
		throw new Error(CONFIG_ERROR_MSG);
	}

	const response = await this.fetch(join_urls(endpoint, CONFIG_URL), {
		headers,
		credentials: "include"
	});

	// Every 401 is an auth failure. Distinguish "no credentials supplied" from
	// "credentials rejected" based on whether the client was given `auth`.
	if (response?.status === 401) {
		throw new Error(
			this.options.auth ? INVALID_CREDENTIALS_MSG : MISSING_CREDENTIALS_MSG
		);
	}

	if (response?.status !== 200) {
		throw new Error(CONFIG_ERROR_MSG);
	}

	const config = await response.json();
	config.path = config.path ?? "";
	config.root = endpoint;
	// Give any dependency lacking an explicit id its position in the list.
	config.dependencies?.forEach((dep: { id?: number }, i: number) => {
		if (dep.id === undefined) {
			dep.id = i;
		}
	});

	return config;
}
|
|
|
/**
 * Log in with the client's `auth` credentials and store the returned cookies.
 *
 * The app endpoint is always resolved via `process_endpoint` first, so its
 * errors propagate even without `auth`. When `this.options.auth` is set,
 * `get_cookie_header` performs the login. A non-empty cookie header is then
 * passed to `this.set_cookies`.
 *
 * @throws whatever `process_endpoint` throws.
 * @throws the original `Error` from login or `set_cookies`, with its type and
 *   stack intact. Non-`Error` throwables are wrapped in an `Error` with their
 *   string form as the message.
 */
export async function resolve_cookies(this: Client): Promise<void> {
	const { http_protocol, host } = await process_endpoint(
		this.app_reference,
		this.options.hf_token
	);

	if (!this.options.auth) return;

	try {
		// NOTE(review): `this.fetch` is passed detached from `this`; confirm
		// Client.fetch does not depend on `this` (or is already bound).
		const cookie_header = await get_cookie_header(
			http_protocol,
			host,
			this.options.auth,
			this.fetch,
			this.options.hf_token
		);

		if (cookie_header) this.set_cookies(cookie_header);
	} catch (e: unknown) {
		// Re-throw Errors as-is (keeps stack and subclass); previously they were
		// re-created from `.message`, and non-Errors produced an empty message.
		throw e instanceof Error ? e : new Error(String(e));
	}
}
|
|
|
|
|
/**
 * POST login credentials to the app's login endpoint and return its cookie header.
 *
 * The credentials are sent as multipart form fields `username` / `password`
 * to `<http_protocol>//<host>/<LOGIN_URL>`, with `credentials: "include"`.
 * `hf_token`, when given, is sent as a bearer `Authorization` header.
 *
 * @param _fetch - fetch implementation used for the request.
 * @returns the response's `set-cookie` header (may be `null`) on HTTP 200.
 * @throws Error(INVALID_CREDENTIALS_MSG) on HTTP 401.
 * @throws Error(SPACE_METADATA_ERROR_MSG) on any other status.
 */
export async function get_cookie_header(
	http_protocol: string,
	host: string,
	auth: [string, string],
	_fetch: typeof fetch,
	hf_token?: `hf_${string}`
): Promise<string | null> {
	const form = new FormData();
	form.append("username", auth?.[0]);
	form.append("password", auth?.[1]);

	const auth_headers: { Authorization?: string } = hf_token
		? { Authorization: `Bearer ${hf_token}` }
		: {};

	const login_url = `${http_protocol}//${host}/${LOGIN_URL}`;
	const response = await _fetch(login_url, {
		headers: auth_headers,
		method: "POST",
		body: form,
		credentials: "include"
	});

	switch (response.status) {
		case 200:
			return response.headers.get("set-cookie");
		case 401:
			throw new Error(INVALID_CREDENTIALS_MSG);
		default:
			throw new Error(SPACE_METADATA_ERROR_MSG);
	}
}
|
|
|
/**
 * Work out the websocket protocol, HTTP protocol and host for an endpoint.
 *
 * - Parseable URLs starting with `http`: the HTTP protocol comes from the URL.
 *   For hosts ending in `hf.space`, `ws_protocol` is always `wss` and the path is
 *   dropped. Otherwise `ws_protocol` is `wss` for `https:` and `ws` for anything
 *   else, and a non-root pathname is appended to the host.
 * - `file:` endpoints map to the fixed host `lite.local` over `http:` / `ws`.
 *   NOTE(review): presumably for pages opened directly from disk (lite/Wasm
 *   builds) — confirm.
 * - Anything else defaults to `https:` / `wss`. The host is the parsed URL
 *   host when the string is a URL with a non-empty host (e.g.
 *   `ws://example.com`), and otherwise the endpoint string unchanged (e.g. a
 *   scheme-less `example.com`). Previously such scheme-less endpoints made
 *   `new URL` throw.
 *
 * @param endpoint - app URL or host string.
 * @returns the protocols and host to connect with.
 */
export function determine_protocol(endpoint: string): {
	ws_protocol: "ws" | "wss";
	http_protocol: "http:" | "https:";
	host: string;
} {
	// `new URL` throws on strings it cannot parse; treat those as "not a URL".
	const try_parse_url = (value: string): URL | null => {
		try {
			return new URL(value);
		} catch {
			return null;
		}
	};

	if (endpoint.startsWith("http")) {
		// A bare host such as "httpbin.org" also starts with "http" but is not a
		// URL; it falls through to the scheme-less default below instead of throwing.
		const url = try_parse_url(endpoint);
		if (url) {
			const { protocol, host, pathname } = url;
			const http_protocol = protocol as "http:" | "https:";

			if (host.endsWith("hf.space")) {
				return { ws_protocol: "wss", host, http_protocol };
			}

			return {
				ws_protocol: protocol === "https:" ? "wss" : "ws",
				http_protocol,
				host: host + (pathname !== "/" ? pathname : "")
			};
		}
	} else if (endpoint.startsWith("file:")) {
		return {
			ws_protocol: "ws",
			http_protocol: "http:",
			host: "lite.local"
		};
	}

	// No usable http(s) or file scheme: default to secure protocols.
	const parsed = try_parse_url(endpoint);
	return {
		ws_protocol: "wss",
		http_protocol: "https:",
		host: parsed?.host || endpoint
	};
}
|
|
|
/**
 * Extract `name=value` pairs from a (possibly comma-joined) `Set-Cookie` header.
 *
 * Several cookies may be folded into one header string separated by commas.
 * A comma counts as a separator only when it is followed by another
 * `name=value` token, so commas inside attributes such as
 * `Expires=Wed, 21 Oct 2015 ...` are not split on. Attributes after the first
 * `;` (`Path`, `HttpOnly`, ...) are dropped. Despite its name, this function
 * only parses: it returns the pairs and stores nothing.
 *
 * @param cookie_header - raw header value.
 * @returns trimmed `name=value` strings. Entries with an empty name or value,
 *   or with no `=`, are skipped.
 */
export const parse_and_set_cookies = (cookie_header: string): string[] => {
	const cookies: string[] = [];

	for (const part of cookie_header.split(/,(?=\s*[^\s=;]+=[^\s=;]+)/)) {
		const pair = part.split(";")[0];
		// Split on the FIRST "=" only: values may themselves contain "="
		// (e.g. base64 padding), which `split("=")` used to truncate.
		const eq = pair.indexOf("=");
		if (eq === -1) continue;

		const name = pair.slice(0, eq).trim();
		const value = pair.slice(eq + 1).trim();
		if (name && value) {
			cookies.push(`${name}=${value}`);
		}
	}

	return cookies;
};
|
|