File size: 5,701 Bytes
b2ecf7d
9d298eb
 
b2ecf7d
 
 
 
 
 
9d298eb
b2ecf7d
 
 
 
 
 
9d298eb
b2ecf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
<script lang="ts">
	import { InferenceDisplayability } from "@huggingface/tasks";
	import type { WidgetExample, WidgetExampleAttribute } from "@huggingface/tasks";
	import type { WidgetProps, ModelLoadInfo, ExampleRunOpts } from "../types.js";

	type TWidgetExample = $$Generic<WidgetExample>;

	import { onMount } from "svelte";

	import IconCross from "../../..//Icons/IconCross.svelte";
	import WidgetInputSamples from "../WidgetInputSamples/WidgetInputSamples.svelte";
	import WidgetInputSamplesGroup from "../WidgetInputSamplesGroup/WidgetInputSamplesGroup.svelte";
	import WidgetFooter from "../WidgetFooter/WidgetFooter.svelte";
	import WidgetHeader from "../WidgetHeader/WidgetHeader.svelte";
	import WidgetInfo from "../WidgetInfo/WidgetInfo.svelte";
	import WidgetModelLoading from "../WidgetModelLoading/WidgetModelLoading.svelte";
	import { getModelLoadInfo, getQueryParamVal, getWidgetExample } from "../../..//InferenceWidget/shared/helpers.js";
	import { modelLoadStates } from "../../stores.js";

	// ---- Props supplied by the concrete task widget that wraps this component ----

	/** Base URL of the inference API; passed to `getModelLoadInfo` on mount. */
	export let apiUrl: string;
	/** When truthy, a widget example is run on mount (only if the URL query params did not supply one). */
	export let callApiOnMount: WidgetProps["callApiOnMount"];
	/** Forwarded to `WidgetInfo`. */
	export let computeTime: string;
	/** Forwarded to `WidgetInfo`. */
	export let error: string;
	/** Forwarded to the sample pickers; presumably true while a request is in flight — TODO confirm with callers. */
	export let isLoading = false;
	/** The model this widget runs; its `inference`, `pipeline_tag`, `id` and `widgetData` fields are read here. */
	export let model: WidgetProps["model"];
	/** Passed to `getModelLoadInfo` on mount. */
	export let includeCredentials: WidgetProps["includeCredentials"];
	/** `WidgetModelLoading` is rendered (with `estimatedTime`) while `isLoading` is true. */
	export let modelLoading = {
		isLoading: false,
		estimatedTime: 0,
	};
	/** Forwarded to `WidgetHeader` in the enabled layout. */
	export let noTitle = false;
	/** Forwarded to `WidgetFooter`. */
	export let outputJson: string;
	/** Runs (or fills in) a selected example; defaults to a no-op. */
	export let applyInputSample: (sample: TWidgetExample, opts?: ExampleRunOpts) => void = () => {};
	/** Type guard selecting which entries of `model.widgetData` are valid examples for this widget. */
	export let validateExample: (sample: WidgetExample) => sample is TWidgetExample;
	/** Example attributes that may be pre-filled from URL query parameters on mount. */
	export let exampleQueryParams: WidgetExampleAttribute[] = [];

	// ---- Internal state ----

	// Disabled unless inference is fully supported; reinforcement-learning models are exempt from this
	// initial check. `onMount` may still disable the widget later (model "TooBig").
	let isDisabled = model.inference !== InferenceDisplayability.Yes && model.pipeline_tag !== "reinforcement-learning";
	// Toggled by `onClickMaximizeBtn`; switches the wrapper to a fixed, viewport-covering overlay.
	let isMaximized = false;
	// Filled in on mount; until then the widget body carries the `hidden` class.
	let modelLoadInfo: ModelLoadInfo | undefined = undefined;
	// Bound to `WidgetInputSamplesGroup` when more than one example group exists.
	let selectedInputGroup: string;
	// Set on mount when the model load state is "TooBig".
	let modelTooBig = false;

	/** A named group of input samples, as offered by the samples-group selector. */
	interface ExamplesGroup {
		group: string;
		inputSamples: TWidgetExample[];
	}

	/**
	 * Every valid example from `model.widgetData`, titled examples first.
	 * Defaults:
	 * - An untitled example gets a positional "Example N" title.
	 * - An example with no group goes into "Group 1".
	 * Fields present on the sample itself override these defaults, because they are spread last.
	 */
	const allInputSamples = (model.widgetData ?? [])
		// `filter` returns a new array, so the in-place `sort` below never mutates `model.widgetData`.
		.filter(validateExample)
		// `Array.prototype.sort` is stable, so titled and untitled examples each keep their relative order.
		.sort((sample1, sample2) => (sample2.example_title ? 1 : 0) - (sample1.example_title ? 1 : 0))
		.map((sample, idx) => ({
			example_title: `Example ${idx + 1}`,
			group: "Group 1",
			...sample,
		}));
	// When disabled, only examples that already carry an `output` are offered
	// (presumably so they can be shown without calling the API — confirm).
	let inputSamples = isDisabled ? allInputSamples.filter((sample) => sample.output !== undefined) : allInputSamples;
	let inputGroups = getExamplesGroups();

	// A single group is selected implicitly; otherwise follow the group picked in the selector.
	$: selectedInputSamples =
		inputGroups.length === 1 ? inputGroups[0] : inputGroups.find(({ group }) => group === selectedInputGroup);

	/**
	 * Partition the current `inputSamples` by their `group` field.
	 *
	 * Groups appear in order of first occurrence, and samples keep their order within each group.
	 * A single `Map` lookup per sample replaces the previous pair of linear scans.
	 *
	 * @returns one `ExamplesGroup` per distinct group name; empty when there are no samples.
	 */
	function getExamplesGroups(): ExamplesGroup[] {
		// Map preserves insertion order, so the resulting group order matches first occurrence.
		const groupsByName = new Map<string, ExamplesGroup>();
		for (const inputSample of inputSamples) {
			const groupName = inputSample.group as string;
			let examplesGroup = groupsByName.get(groupName);
			if (!examplesGroup) {
				examplesGroup = { group: groupName, inputSamples: [] };
				groupsByName.set(groupName, examplesGroup);
			}
			examplesGroup.inputSamples.push(inputSample);
		}
		return Array.from(groupsByName.values());
	}

	/**
	 * Build an example object from the URL query parameters named in `exampleQueryParams`.
	 * Parameters whose value is falsy (e.g. missing or empty) are skipped.
	 */
	function getExampleFromQueryParams(): TWidgetExample {
		const exampleFromQueryParams = {} as TWidgetExample;
		for (const key of exampleQueryParams) {
			const val = getQueryParamVal(key);
			if (val) {
				// @ts-expect-error complicated type
				exampleFromQueryParams[key] = val;
			}
		}
		return exampleFromQueryParams;
	}

	/**
	 * Mount-time initialization, in order:
	 * 1. Fetch the model's load info and publish it to the `modelLoadStates` store.
	 * 2. If the model is "TooBig", disable the widget and keep only examples that carry an `output`.
	 * 3. Run the example built from the URL query params. If there is none and `callApiOnMount` is
	 *    set, run the example chosen by `getWidgetExample` instead.
	 */
	async function initWidget(): Promise<void> {
		modelLoadInfo = await getModelLoadInfo(apiUrl, model.id, includeCredentials);
		$modelLoadStates[model.id] = modelLoadInfo;
		modelTooBig = modelLoadInfo?.state === "TooBig";

		if (modelTooBig) {
			// disable the widget
			isDisabled = true;
			inputSamples = allInputSamples.filter((sample) => sample.output !== undefined);
			inputGroups = getExamplesGroups();
		}

		const exampleFromQueryParams = getExampleFromQueryParams();
		if (Object.keys(exampleFromQueryParams).length) {
			// run widget example from query params
			applyInputSample(exampleFromQueryParams);
			return;
		}

		// run random widget example
		const example = getWidgetExample<TWidgetExample>(model, validateExample);
		if (callApiOnMount && example) {
			applyInputSample(example, { inferenceOpts: { isOnLoadCall: true } });
		}
	}

	onMount(() => {
		// The onMount callback stays synchronous. A failed initialization is logged with context
		// rather than left as an unhandled rejection. Note that the widget body stays hidden while
		// `modelLoadInfo` is unset.
		initWidget().catch((err: unknown) => {
			console.error(`Failed to initialize widget for model "${model.id}":`, err);
		});
	});

	/**
	 * Switch between the inline layout and the maximized, viewport-covering overlay.
	 * Used by both the footer's maximize control and the overlay's close button.
	 */
	function onClickMaximizeBtn() {
		if (isMaximized) {
			isMaximized = false;
		} else {
			isMaximized = true;
		}
	}
</script>

{#if isDisabled && !inputSamples.length}
	<!-- Disabled and no examples with a precomputed output: only the header and model info are shown. -->
	<WidgetHeader pipeline={model.pipeline_tag} noTitle={true} />
	<WidgetInfo {model} {computeTime} {error} {modelLoadInfo} {modelTooBig} />
{:else}
	<!-- Hidden until the model load info has been fetched on mount. -->
	<div
		class="flex w-full max-w-full flex-col
		 {isMaximized ? 'fixed inset-0 z-20 bg-white p-12' : ''}
		 {!modelLoadInfo ? 'hidden' : ''}"
	>
		{#if isMaximized}
			<!-- Icon-only button: aria-label gives it an accessible name for assistive technology. -->
			<button class="absolute right-12 top-6" aria-label="Close maximized view" on:click={onClickMaximizeBtn}>
				<IconCross classNames="text-xl text-gray-500 hover:text-black" />
			</button>
		{/if}
		<WidgetHeader {noTitle} pipeline={model.pipeline_tag} {isDisabled}>
			{#if inputGroups.length}
				<div class="ml-auto flex gap-x-1">
					<!-- Show the group selector only when there is more than one example group -->
					{#if inputGroups.length > 1}
						<WidgetInputSamplesGroup
							bind:selectedInputGroup
							{isLoading}
							inputGroups={inputGroups.map(({ group }) => group)}
						/>
					{/if}
					<!-- Greyed out and non-interactive until a group is selected -->
					<WidgetInputSamples
						classNames={!selectedInputSamples ? "opacity-50 pointer-events-none" : ""}
						{isLoading}
						inputSamples={selectedInputSamples?.inputSamples ?? []}
						{applyInputSample}
					/>
				</div>
			{/if}
		</WidgetHeader>
		<slot name="top" {isDisabled} />
		<WidgetInfo {model} {computeTime} {error} {modelLoadInfo} {modelTooBig} />
		{#if modelLoading.isLoading}
			<WidgetModelLoading estimatedTime={modelLoading.estimatedTime} />
		{/if}
		<slot name="bottom" />
		<WidgetFooter {onClickMaximizeBtn} {outputJson} {isDisabled} />
	</div>
{/if}