|
import { defaultGenerationConfig } from "$lib/components/inference-playground/generation-config-settings.js"; |
|
|
|
|
|
import { showQuotaModal } from "$lib/components/quota-modal.svelte"; |
|
import { createInit } from "$lib/spells/create-init.svelte.js"; |
|
import { |
|
PipelineTag, |
|
type Conversation, |
|
type ConversationMessage, |
|
type DefaultProject, |
|
type Model, |
|
type Project, |
|
type Session, |
|
} from "$lib/types.js"; |
|
import { safeParse } from "$lib/utils/json.js"; |
|
import typia from "typia"; |
|
import { models } from "./models.svelte"; |
|
import { checkpoints } from "./checkpoints.svelte"; |
|
import { handleNonStreamingResponse, handleStreamingResponse } from "$lib/components/inference-playground/utils.js"; |
|
import { AbortManager } from "$lib/spells/abort-manager.svelte"; |
|
import { addToast } from "$lib/components/toaster.svelte.js"; |
|
import { token } from "./token.svelte"; |
|
|
|
/** Key under which the serialized session is read from and written to `localStorage`. */
const LOCAL_STORAGE_KEY = "hf_inference_playground_session";
|
|
|
/** Metrics tracked for a single conversation's generation runs. */
interface GenerationStatistics {
	/** Duration of the most recent run, as a rounded `performance.now()` difference (milliseconds). */
	latency: number;
	/** Running total of completion tokens reported by non-streaming responses. */
	generatedTokensCount: number;
}
|
|
|
/** Template for the empty user turn that opens a fresh conversation. */
const startMessageUser: ConversationMessage = {
	role: "user",
	content: "",
};

/** Template for the blank system prompt attached to a fresh conversation. */
const systemMessage: ConversationMessage = { role: "system", content: "" };
|
|
|
/**
 * Placeholder model with blank identifiers and no provider mappings.
 * `getDefaults` falls back to it when neither trending nor remote models are available.
 */
export const emptyModel: Model = {
	_id: "",
	inferenceProviderMapping: [],
	pipeline_tag: PipelineTag.TextGeneration,
	trendingScore: 0,
	tags: ["text-generation"],
	id: "",
	config: {
		architectures: [] as string[],
		model_type: "",
		tokenizer_config: {},
	},
};
|
|
|
/**
 * Builds a fresh default conversation and the default project that contains it.
 *
 * The model is the first trending model, else the first remote model, else `emptyModel`.
 * Generation config, opening user message and system message are all copied, so callers
 * can mutate the returned conversation without touching the module-level templates or
 * any previously returned conversation.
 *
 * @returns The default project (id `"default"`) and its single conversation.
 */
function getDefaults() {
	const defaultModel = models.trending[0] ?? models.remote[0] ?? emptyModel;

	const defaultConversation: Conversation = {
		model: defaultModel,
		config: { ...defaultGenerationConfig },
		messages: [{ ...startMessageUser }],
		// Copy rather than alias: the shared template must not be reachable from session state.
		systemMessage: { ...systemMessage },
		streaming: true,
	};

	const defaultProject: DefaultProject = {
		name: "Default",
		id: "default",
		conversations: [defaultConversation],
	};

	return { defaultProject, defaultConversation };
}
|
|
|
/**
 * Reactive holder for the playground session: the list of projects, the active project,
 * per-conversation generation statistics and the "generating" flag.
 *
 * The session is loaded from `localStorage` by `init` and written back by an effect
 * whenever it changes after initialization.
 */
class SessionState {
	/** Backing state for the whole session; exposed through the `$` accessor. */
	#value = $state<Session>({} as Session);

	/** One statistics entry per conversation of the active project (one or two entries). */
	generationStats = $state([{ latency: 0, generatedTokensCount: 0 }] as
		| [GenerationStatistics]
		| [GenerationStatistics, GenerationStatistics]);
	/** True while `run` has inference requests in flight. */
	generating = $state(false);

	/** Hands out the abort controllers used by streaming requests. */
	#abortManager = new AbortManager();

	/**
	 * Loads the session from `localStorage` (falling back to defaults when nothing valid is
	 * stored), applies `modelId` / `provider` URL search parameters to the default project,
	 * and resets the generation statistics.
	 */
	init = createInit(() => {
		const { defaultConversation, defaultProject } = getDefaults();

		// Fallback session, used when nothing valid is stored.
		let savedSession: Session = {
			projects: [defaultProject],
			activeProjectId: defaultProject.id,
		};

		const savedData = localStorage.getItem(LOCAL_STORAGE_KEY);
		if (savedData) {
			const parsed = safeParse(savedData);
			const res = typia.validate<Session>(parsed);
			if (res.success) {
				savedSession = parsed;
			} else {
				// Stored data does not match the Session shape: overwrite it with the defaults.
				localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(savedSession));
			}
		}

		const dp = savedSession.projects.find(p => p.id === "default");
		if (typia.is<DefaultProject>(dp)) {
			// Models/providers given in the URL override the default project's conversations.
			const searchParams = new URLSearchParams(window.location.search);
			const searchProviders = searchParams.getAll("provider");
			const searchModelIds = searchParams.getAll("modelId");
			// Ids that do not match a known remote model are dropped.
			const modelsFromSearch = searchModelIds.map(id => models.remote.find(model => model.id === id)).filter(Boolean);

			if (modelsFromSearch.length > 0) {
				savedSession.activeProjectId = "default";

				// Keep as many conversations as the shortest of the three lists, but at least one.
				let min = Math.min(dp.conversations.length, modelsFromSearch.length, searchProviders.length);
				min = Math.max(1, min);
				const convos = dp.conversations.slice(0, min);
				if (typia.is<Project["conversations"]>(convos)) dp.conversations = convos;

				for (let i = 0; i < min; i++) {
					const conversation = dp.conversations[i] ?? defaultConversation;
					dp.conversations[i] = {
						...conversation,
						model: modelsFromSearch[i] ?? conversation.model,
						provider: searchProviders[i] ?? conversation.provider,
					};
				}
			}
		}

		this.$ = savedSession;
		this.generationStats = this.project.conversations.map(_ => ({ latency: 0, generatedTokensCount: 0 })) as
			| [GenerationStatistics]
			| [GenerationStatistics, GenerationStatistics];
		this.#abortManager.init();
	});

	constructor() {
		$effect.root(() => {
			$effect(() => {
				// Nothing is persisted until `init` has been called.
				if (!this.init.called) return;
				const v = $state.snapshot(this.#value);
				try {
					localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(v));
				} catch (e) {
					console.error("Failed to save session to localStorage:", e);
				}
			});
		});
	}

	/** The current session value. */
	get $() {
		return this.#value;
	}

	set $(v: Session) {
		this.#value = v;
	}

	/** Assigns `s` as the session only when it passes the `Session` shape check; otherwise does nothing. */
	#setAnySession(s: unknown) {
		if (typia.is<Session>(s)) this.$ = s;
	}

	/**
	 * Saves the default project's current content as a new named project, makes it active,
	 * and resets the default project to a single fresh conversation.
	 *
	 * @param args.name Name of the new project.
	 * @param args.moveCheckpoints When true, migrates the default project's checkpoints to the new project.
	 */
	saveProject = (args: { name: string; moveCheckpoints?: boolean }) => {
		const defaultProject = this.$.projects.find(p => p.id === "default");
		if (!defaultProject) return;

		const project: Project = {
			...defaultProject,
			name: args.name,
			id: crypto.randomUUID(),
		};

		if (args.moveCheckpoints) {
			checkpoints.migrate(defaultProject.id, project.id);
		}

		// The copy above keeps the old conversations array; the default project gets a new one.
		defaultProject.conversations = [getDefaults().defaultConversation];

		this.addProject(project);
	};

	/** Appends `project` to the session and makes it the active project. */
	addProject = (project: Project) => {
		this.$ = { ...this.$, projects: [...this.$.projects, project], activeProjectId: project.id };
	};

	/**
	 * Removes the project with the given id and clears its checkpoints.
	 * The default project cannot be deleted. If the deleted project was active, the first
	 * remaining project becomes active; if no project remains, the session is reset to defaults.
	 */
	deleteProject = (id: string) => {
		if (id === "default") return;

		const projects = this.$.projects.filter(p => p.id !== id);
		if (projects.length === 0) {
			const { defaultProject } = getDefaults();
			this.#setAnySession({ ...this.$, projects: [defaultProject], activeProjectId: defaultProject.id });
		} else {
			// Only reached when projects remain, so the reset above is never followed by a
			// second update carrying an empty project list.
			const currProject = projects.find(p => p.id === this.$.activeProjectId);
			this.#setAnySession({ ...this.$, projects, activeProjectId: currProject?.id ?? projects[0]?.id });
		}

		checkpoints.clear(id);
	};

	/** Shallow-merges `data` into the project with the given id. */
	updateProject = (id: string, data: Partial<Project>) => {
		const projects = this.$.projects.map(p => (p.id === id ? { ...p, ...data } : p));
		this.#setAnySession({ ...this.$, projects });
	};

	/** The active project, or the first project when the active id matches none. */
	get project() {
		return this.$.projects.find(p => p.id === this.$.activeProjectId) ?? this.$.projects[0];
	}

	/** Replaces the stored project that has the same id as `np`. */
	set project(np: Project) {
		const projects = this.$.projects.map(p => (p.id === np.id ? np : p));
		this.#setAnySession({ ...this.$, projects });
	}

	/**
	 * Runs one inference request for `conversation`, appends the assistant reply to its
	 * messages and records latency (plus the token count for non-streaming responses) in
	 * the matching `generationStats` entry.
	 */
	async #runInference(conversation: Conversation) {
		// -1 when the conversation is not part of the active project; the stats lookups below then find nothing.
		const idx = this.project.conversations.indexOf(conversation);

		const startTime = performance.now();

		if (conversation.streaming) {
			let addedMessage = false;
			const streamingMessage = $state({ role: "assistant", content: "" });

			await handleStreamingResponse(
				conversation,
				content => {
					streamingMessage.content = content;
					// Append the assistant message once, on the first chunk; later chunks update it in place.
					if (!addedMessage) {
						conversation.messages = [...conversation.messages, streamingMessage];
						addedMessage = true;
					}
				},
				this.#abortManager.createController()
			);
		} else {
			const { message: newMessage, completion_tokens: newTokensCount } = await handleNonStreamingResponse(conversation);
			conversation.messages = [...conversation.messages, newMessage];
			const tokenStats = this.generationStats[idx];
			if (tokenStats) tokenStats.generatedTokensCount += newTokensCount;
		}

		const endTime = performance.now();
		const latencyStats = this.generationStats[idx];
		if (latencyStats) latencyStats.latency = Math.round(endTime - startTime);
	}

	/**
	 * Runs inference for the selected conversation(s) of the active project.
	 *
	 * Opens the token modal instead when no token is set, and shows an error toast instead
	 * when a selected conversation already ends with an assistant message. Errors raised
	 * while generating are reported through toasts (abort errors are not reported).
	 *
	 * @param conv `"left"` (first conversation), `"right"` (second), `"both"` (all), or a conversation object.
	 */
	async run(conv: "left" | "right" | "both" | Conversation = "both") {
		if (!token.value) {
			token.showModal = true;
			return;
		}

		const conversations =
			typeof conv === "string"
				? this.project.conversations.filter((_, idx) => conv === "both" || (conv === "left" ? idx === 0 : idx === 1))
				: [conv];

		for (const conversation of conversations) {
			if (conversation.messages.at(-1)?.role !== "assistant") continue;

			let prefix = "";
			if (this.project.conversations.length === 2) {
				// Use the position within the project, not within the filtered list, so that
				// running only the right conversation is not reported as "left".
				const side = this.project.conversations.indexOf(conversation) === 1 ? "right" : "left";
				prefix = `Error on ${side} conversation. `;
			}
			return addToast({
				title: "Failed to run inference",
				description: `${prefix}Messages must alternate between user/assistant roles.`,
				variant: "error",
			});
		}

		// `document.activeElement` can be null, so the blur must be optional.
		(document.activeElement as HTMLElement | null)?.blur();
		this.generating = true;

		try {
			await Promise.all(conversations.map(c => this.#runInference(c)));
		} catch (error) {
			for (const conversation of conversations) {
				// Drop a trailing assistant message that never received any content.
				const last = conversation.messages.at(-1);
				if (last?.role === "assistant" && !last.content?.trim()) {
					conversation.messages.pop();
					conversation.messages = [...conversation.messages];
				}

				// NOTE(review): self-assignment kept from the original; presumably meant to nudge reactivity, confirm whether it is needed.
				this.$ = this.$;
			}

			if (error instanceof Error) {
				const msg = error.message.toLowerCase();
				// "montly" is the historical misspelling and is kept so nothing that matched before stops matching.
				if (["monthly", "montly", "pro"].some(keyword => msg.includes(keyword))) {
					showQuotaModal();
				}

				if (error.message.includes("token seems invalid")) {
					token.reset();
				}

				if (error.name !== "AbortError") {
					addToast({ title: "Error", description: error.message, variant: "error" });
				}
			} else {
				addToast({ title: "Error", description: "An unknown error occurred", variant: "error" });
			}
		} finally {
			this.generating = false;
			this.#abortManager.clear();
		}
	}

	/** Aborts every in-flight request and clears the generating flag. */
	stopGenerating = () => {
		this.#abortManager.abortAll();
		this.generating = false;
	};

	/** Stops generation when one is running; otherwise starts `run` with the given selection. */
	runOrStop = (c?: Parameters<typeof this.run>[0]) => {
		if (this.generating) {
			this.stopGenerating();
		} else {
			this.run(c);
		}
	};
}
|
|
|
/** Module-wide `SessionState` instance exported for use by the rest of the app. */
export const session = new SessionState();
|
|