Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 5,701 Bytes
b2ecf7d 9d298eb b2ecf7d 9d298eb b2ecf7d 9d298eb b2ecf7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
<script lang="ts">
import { InferenceDisplayability } from "@huggingface/tasks";
import type { WidgetExample, WidgetExampleAttribute } from "@huggingface/tasks";
import type { WidgetProps, ModelLoadInfo, ExampleRunOpts } from "../types.js";
type TWidgetExample = $$Generic<WidgetExample>;
import { onMount } from "svelte";
import IconCross from "../../..//Icons/IconCross.svelte";
import WidgetInputSamples from "../WidgetInputSamples/WidgetInputSamples.svelte";
import WidgetInputSamplesGroup from "../WidgetInputSamplesGroup/WidgetInputSamplesGroup.svelte";
import WidgetFooter from "../WidgetFooter/WidgetFooter.svelte";
import WidgetHeader from "../WidgetHeader/WidgetHeader.svelte";
import WidgetInfo from "../WidgetInfo/WidgetInfo.svelte";
import WidgetModelLoading from "../WidgetModelLoading/WidgetModelLoading.svelte";
import { getModelLoadInfo, getQueryParamVal, getWidgetExample } from "../../..//InferenceWidget/shared/helpers.js";
import { modelLoadStates } from "../../stores.js";
/** Base URL of the inference API; passed to `getModelLoadInfo` on mount to fetch the model's load status. */
export let apiUrl: string;
/** When truthy and no example is given via query params, a random model-card example is run on mount. */
export let callApiOnMount: WidgetProps["callApiOnMount"];
/** Forwarded to `WidgetInfo`. */
export let computeTime: string;
/** Forwarded to `WidgetInfo`. */
export let error: string;
/** Forwarded to the example selectors (`WidgetInputSamplesGroup`, `WidgetInputSamples`). */
export let isLoading = false;
/** The model this widget is for (its id, pipeline tag, inference availability and model-card examples are read). */
export let model: WidgetProps["model"];
/** Passed to `getModelLoadInfo` when fetching the model's load status. */
export let includeCredentials: WidgetProps["includeCredentials"];
/** While `isLoading` is true a `WidgetModelLoading` indicator is shown, using `estimatedTime`. */
export let modelLoading = {
isLoading: false,
estimatedTime: 0,
};
/** Forwarded to `WidgetHeader`. */
export let noTitle = false;
/** Forwarded to `WidgetFooter`. */
export let outputJson: string;
/**
 * Runs an example. Passed to `WidgetInputSamples`, and called on mount with either the
 * query-param example or a random model-card example. Defaults to a no-op.
 */
export let applyInputSample: (sample: TWidgetExample, opts?: ExampleRunOpts) => void = () => {};
/** Type guard that keeps only the `model.widgetData` entries this widget can handle; also passed to `getWidgetExample`. */
export let validateExample: (sample: WidgetExample) => sample is TWidgetExample;
/** Example attributes read from the page's query params on mount; if any is set, that example is run instead of a random one. */
export let exampleQueryParams: WidgetExampleAttribute[] = [];
// Disabled unless `model.inference` is `InferenceDisplayability.Yes` (reinforcement-learning models are exempt);
// additionally set to true on mount when the model's load state is "TooBig".
let isDisabled = model.inference !== InferenceDisplayability.Yes && model.pipeline_tag !== "reinforcement-learning";
// Toggled by `onClickMaximizeBtn`; when true the widget is rendered as a full-screen overlay.
let isMaximized = false;
// Load status fetched on mount; the main widget container stays hidden until this is set.
let modelLoadInfo: ModelLoadInfo | undefined = undefined;
// Group picked in `WidgetInputSamplesGroup` (two-way bound from the template).
let selectedInputGroup: string;
// True when the fetched load state is "TooBig".
let modelTooBig = false;
/**
 * A named bucket of examples. The group names are what the example-group selector offers,
 * and the selected bucket's examples are listed by the example picker.
 */
type ExamplesGroup = {
	/** Name of the group, taken from each example's `group` attribute. */
	group: string;
	/** Examples in this group, in the same order they appear in `inputSamples`. */
	inputSamples: TWidgetExample[];
};
/**
 * Every entry of the model card's `widgetData` that passes `validateExample`, with titled
 * examples ordered before untitled ones. Examples without a usable title receive a positional
 * "Example N" title, and examples without a group are placed in "Group 1".
 */
const allInputSamples = (model.widgetData ?? [])
	.filter(validateExample)
	// Examples with a (truthy) title come before untitled ones.
	.sort((sample1, sample2) => (sample2.example_title ? 1 : 0) - (sample1.example_title ? 1 : 0))
	.map((sample, idx) => ({
		...sample,
		// Defaults are applied AFTER the spread so an explicit null/undefined value cannot clobber them.
		// `||` mirrors the truthiness test used by the sort above: a null or empty title is "untitled".
		example_title: sample.example_title || `Example ${idx + 1}`,
		group: sample.group ?? "Group 1",
	}));
// While disabled, only examples that carry a precomputed `output` are offered.
let inputSamples = isDisabled ? allInputSamples.filter((sample) => sample.output !== undefined) : allInputSamples;
let inputGroups = getExamplesGroups();
// Reactive: a single group is always the selected one; otherwise follow the group chosen in the selector.
$: selectedInputSamples =
	inputGroups.length === 1 ? inputGroups[0] : inputGroups.find(({ group }) => group === selectedInputGroup);
/**
 * Buckets the current `inputSamples` by their `group` attribute.
 *
 * @returns one entry per distinct group name, in order of first appearance; each entry's
 *   `inputSamples` keeps the relative order of `inputSamples`.
 */
function getExamplesGroups(): ExamplesGroup[] {
	// A Map gives O(1) lookup per sample and iterates in insertion order,
	// so groups come out in order of first appearance.
	const groupsByName = new Map<string, ExamplesGroup>();
	for (const inputSample of inputSamples) {
		const groupName = inputSample.group as string;
		let entry = groupsByName.get(groupName);
		if (entry === undefined) {
			entry = { group: groupName, inputSamples: [] };
			groupsByName.set(groupName, entry);
		}
		entry.inputSamples.push(inputSample);
	}
	return Array.from(groupsByName.values());
}
/**
 * Builds an example from the page's query params, reading only the attribute names listed in
 * `exampleQueryParams`. Attributes whose query-param value is missing or empty are left out,
 * so the result has no keys when none of them is set.
 */
function getExampleFromQueryParams(): TWidgetExample {
	const exampleFromQueryParams = {} as TWidgetExample;
	for (const key of exampleQueryParams) {
		const val = getQueryParamVal(key);
		if (val) {
			// @ts-expect-error complicated type
			exampleFromQueryParams[key] = val;
		}
	}
	return exampleFromQueryParams;
}

/**
 * One-time widget initialization, run from `onMount`:
 * 1. fetch the model's load status and record it in the `modelLoadStates` store;
 * 2. if the load state is "TooBig", disable the widget and keep only examples that have an `output`;
 * 3. run the example given by the query params if there is one, otherwise (when `callApiOnMount`)
 *    a random example from the model card.
 */
async function initWidget(): Promise<void> {
	modelLoadInfo = await getModelLoadInfo(apiUrl, model.id, includeCredentials);
	$modelLoadStates[model.id] = modelLoadInfo;
	modelTooBig = modelLoadInfo?.state === "TooBig";
	if (modelTooBig) {
		// disable the widget
		isDisabled = true;
		inputSamples = allInputSamples.filter((sample) => sample.output !== undefined);
		inputGroups = getExamplesGroups();
	}

	const exampleFromQueryParams = getExampleFromQueryParams();
	if (Object.keys(exampleFromQueryParams).length) {
		// run widget example from query params
		applyInputSample(exampleFromQueryParams);
		return;
	}

	// run random widget example
	const example = getWidgetExample<TWidgetExample>(model, validateExample);
	if (callApiOnMount && example) {
		applyInputSample(example, { inferenceOpts: { isOnLoadCall: true } });
	}
}

onMount(() => {
	// Svelte does not await the onMount callback, so attach a rejection handler explicitly;
	// otherwise a throw from `getModelLoadInfo` or `applyInputSample` becomes an unhandled rejection.
	void initWidget().catch((err: unknown) => {
		console.error(`Failed to initialize the inference widget for model "${model.id}"`, err);
	});
});
/** Switches the widget between its inline layout and the full-screen overlay. */
const onClickMaximizeBtn = (): void => {
	isMaximized = !isMaximized;
};
</script>
<!-- Disabled widget with no precomputed example outputs: render only the header and the status info. -->
{#if isDisabled && !inputSamples.length}
<WidgetHeader pipeline={model.pipeline_tag} noTitle={true} />
<WidgetInfo {model} {computeTime} {error} {modelLoadInfo} {modelTooBig} />
{:else}
<!-- Main container: full-screen overlay when maximized; hidden until modelLoadInfo is set on mount. -->
<div
class="flex w-full max-w-full flex-col
{isMaximized ? 'fixed inset-0 z-20 bg-white p-12' : ''}
{!modelLoadInfo ? 'hidden' : ''}"
>
<!-- Close button, only shown in the maximized (full-screen) layout. -->
{#if isMaximized}
<button class="absolute right-12 top-6" on:click={onClickMaximizeBtn}>
<IconCross classNames="text-xl text-gray-500 hover:text-black" />
</button>
{/if}
<WidgetHeader {noTitle} pipeline={model.pipeline_tag} {isDisabled}>
{#if !!inputGroups.length}
<div class="ml-auto flex gap-x-1">
<!-- Show the group selector only when there is more than one example group -->
{#if inputGroups.length > 1}
<WidgetInputSamplesGroup
bind:selectedInputGroup
{isLoading}
inputGroups={inputGroups.map(({ group }) => group)}
/>
{/if}
<WidgetInputSamples
classNames={!selectedInputSamples ? "opacity-50 pointer-events-none" : ""}
{isLoading}
inputSamples={selectedInputSamples?.inputSamples ?? []}
{applyInputSample}
/>
</div>
{/if}
</WidgetHeader>
<!-- Pipeline-specific content is supplied by the parent through the "top" and "bottom" slots. -->
<slot name="top" {isDisabled} />
<WidgetInfo {model} {computeTime} {error} {modelLoadInfo} {modelTooBig} />
{#if modelLoading.isLoading}
<WidgetModelLoading estimatedTime={modelLoading.estimatedTime} />
{/if}
<slot name="bottom" />
<WidgetFooter {onClickMaximizeBtn} {outputJson} {isDisabled} />
</div>
{/if}
|