inference-playground / src /lib /state /session.svelte.ts
Thomas G. Lopes
Enhancements (#79)
f36471e unverified
import { defaultGenerationConfig } from "$lib/components/inference-playground/generation-config-settings.js";
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore - Svelte imports are broken in TS files
import { showQuotaModal } from "$lib/components/quota-modal.svelte";
import { createInit } from "$lib/spells/create-init.svelte.js";
import {
PipelineTag,
type Conversation,
type ConversationMessage,
type DefaultProject,
type Model,
type Project,
type Session,
} from "$lib/types.js";
import { safeParse } from "$lib/utils/json.js";
import typia from "typia";
import { models } from "./models.svelte";
import { checkpoints } from "./checkpoints.svelte";
import { handleNonStreamingResponse, handleStreamingResponse } from "$lib/components/inference-playground/utils.js";
import { AbortManager } from "$lib/spells/abort-manager.svelte";
import { addToast } from "$lib/components/toaster.svelte.js";
import { token } from "./token.svelte";
/** `localStorage` key under which the serialized session is read and written. */
const LOCAL_STORAGE_KEY = "hf_inference_playground_session";
/** Per-conversation statistics about the most recent inference run(s). */
interface GenerationStatistics {
/** Rounded `performance.now()` delta (milliseconds) of the last completed run. */
latency: number;
/** Running total of completion tokens reported by non-streaming responses. */
generatedTokensCount: number;
}
/** Template for the empty user message that opens a default conversation. */
const startMessageUser: ConversationMessage = { role: "user", content: "" };
/** Template for the empty system message attached to a default conversation. */
const systemMessage: ConversationMessage = {
role: "system",
content: "",
};
/**
 * Placeholder text-generation model with empty identifiers and no provider
 * mappings. Used as the last-resort default when neither the trending nor the
 * remote model list has an entry.
 */
export const emptyModel: Model = {
_id: "",
inferenceProviderMapping: [],
pipeline_tag: PipelineTag.TextGeneration,
trendingScore: 0,
tags: ["text-generation"],
id: "",
config: {
architectures: [] as string[],
model_type: "",
tokenizer_config: {},
},
};
/**
 * Builds a fresh default conversation and the default project that wraps it.
 *
 * The model is the first trending model, falling back to the first remote
 * model and finally to {@link emptyModel}. The generation config, the opening
 * user message and the system message are all shallow-copied from their
 * templates so that no two conversations share (and accidentally co-mutate)
 * the same object.
 *
 * @returns The default project (id `"default"`) and its single conversation.
 */
function getDefaults() {
	const defaultModel = models.trending[0] ?? models.remote[0] ?? emptyModel;

	const defaultConversation: Conversation = {
		model: defaultModel,
		config: { ...defaultGenerationConfig },
		messages: [{ ...startMessageUser }],
		// Copy instead of sharing the module-level template, like `messages` above.
		systemMessage: { ...systemMessage },
		streaming: true,
	};

	const defaultProject: DefaultProject = {
		name: "Default",
		id: "default",
		conversations: [defaultConversation],
	};

	return { defaultProject, defaultConversation };
}
/**
 * Reactive store for the playground session (projects, conversations and
 * generation status), persisted to `localStorage`.
 *
 * `init` must be called once (in the layout) before the store is used; after
 * that, every change to the session is written back to `localStorage` by an
 * effect registered in the constructor.
 */
class SessionState {
	/** Backing reactive value; exposed through the `$` accessor. */
	#value = $state<Session>({} as Session);

	/** One statistics entry per conversation of the active project (1 or 2). */
	generationStats = $state([{ latency: 0, generatedTokensCount: 0 }] as
		| [GenerationStatistics]
		| [GenerationStatistics, GenerationStatistics]);

	/** `true` while at least one inference request started by `run` is pending. */
	generating = $state(false);

	/** Hands out abort controllers for in-flight requests so they can be cancelled. */
	#abortManager = new AbortManager();

	/**
	 * Loads the session from `localStorage` (falling back to defaults), applies
	 * `modelId` / `provider` URL query parameters to the default project and
	 * resets the generation statistics.
	 *
	 * Call once in layout.
	 */
	init = createInit(() => {
		const { defaultConversation, defaultProject } = getDefaults();

		// Start from defaults; replaced below if a valid saved session exists.
		let savedSession: Session = {
			projects: [defaultProject],
			activeProjectId: defaultProject.id,
		};

		const savedData = localStorage.getItem(LOCAL_STORAGE_KEY);
		if (savedData) {
			const parsed = safeParse(savedData);
			const res = typia.validate<Session>(parsed);
			if (res.success) {
				// Use the validated (typed) payload rather than the raw parse result.
				savedSession = res.data;
			} else {
				// Stored data does not match the current schema: overwrite it with defaults.
				localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(savedSession));
			}
		}

		// Merge query params into the default project. Models and providers from
		// the query string take precedence over the saved ones. The number of
		// conversations kept is the minimum of (saved conversations, models found,
		// providers given), clamped to at least 1.
		// NOTE(review): an earlier comment described this count as the "maximum
		// between the two"; the code has always used the clamped minimum — confirm
		// which one is intended.
		const dp = savedSession.projects.find(p => p.id === "default");
		if (typia.is<DefaultProject>(dp)) {
			const searchParams = new URLSearchParams(window.location.search);
			const searchProviders = searchParams.getAll("provider");
			const searchModelIds = searchParams.getAll("modelId");
			// Unknown model ids are dropped.
			const modelsFromSearch = searchModelIds.map(id => models.remote.find(model => model.id === id)).filter(Boolean);

			if (modelsFromSearch.length > 0) {
				savedSession.activeProjectId = "default";

				let min = Math.min(dp.conversations.length, modelsFromSearch.length, searchProviders.length);
				min = Math.max(1, min);

				const convos = dp.conversations.slice(0, min);
				if (typia.is<Project["conversations"]>(convos)) dp.conversations = convos;

				for (let i = 0; i < min; i++) {
					const conversation = dp.conversations[i] ?? defaultConversation;
					dp.conversations[i] = {
						...conversation,
						model: modelsFromSearch[i] ?? conversation.model,
						provider: searchProviders[i] ?? conversation.provider,
					};
				}
			}
		}

		this.$ = savedSession;

		this.generationStats = this.project.conversations.map(_ => ({ latency: 0, generatedTokensCount: 0 })) as
			| [GenerationStatistics]
			| [GenerationStatistics, GenerationStatistics];

		this.#abortManager.init();
	});

	/** Registers the effect that mirrors the session into `localStorage`. */
	constructor() {
		$effect.root(() => {
			$effect(() => {
				// Do not persist the placeholder value that exists before `init` ran.
				if (!this.init.called) return;
				const v = $state.snapshot(this.#value);
				try {
					localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(v));
				} catch (e) {
					console.error("Failed to save session to localStorage:", e);
				}
			});
		});
	}

	/** The current session value. */
	get $() {
		return this.#value;
	}

	set $(v: Session) {
		this.#value = v;
	}

	/** Assigns `s` as the session only if it passes the `Session` type check. */
	#setAnySession(s: unknown) {
		if (typia.is<Session>(s)) this.$ = s;
	}

	/**
	 * Saves the current default project as a new named project and resets the
	 * default project to a single fresh conversation.
	 *
	 * @param args.name Name for the new project.
	 * @param args.moveCheckpoints When set, checkpoints of the default project
	 *   are migrated to the new project.
	 */
	saveProject = (args: { name: string; moveCheckpoints?: boolean }) => {
		const defaultProject = this.$.projects.find(p => p.id === "default");
		if (!defaultProject) return;

		const project: Project = {
			...defaultProject,
			name: args.name,
			id: crypto.randomUUID(),
		};

		if (args.moveCheckpoints) {
			checkpoints.migrate(defaultProject.id, project.id);
		}

		// Reassigning (not mutating) keeps the copied project's conversations intact.
		defaultProject.conversations = [getDefaults().defaultConversation];

		this.addProject(project);
	};

	/** Appends `project` to the session and makes it the active project. */
	addProject = (project: Project) => {
		this.$ = { ...this.$, projects: [...this.$.projects, project], activeProjectId: project.id };
	};

	/**
	 * Removes the project with the given id and clears its checkpoints.
	 * The default project cannot be deleted. If the active project was removed,
	 * the first remaining project becomes active.
	 */
	deleteProject = (id: string) => {
		// Can't delete default project!
		if (id === "default") return;

		const projects = this.$.projects.filter(p => p.id !== id);

		if (projects.length === 0) {
			// Nothing left: fall back to a fresh default project.
			const { defaultProject } = getDefaults();
			this.#setAnySession({ ...this.$, projects: [defaultProject], activeProjectId: defaultProject.id });
		} else {
			const currProject = projects.find(p => p.id === this.$.activeProjectId);
			this.#setAnySession({ ...this.$, projects, activeProjectId: currProject?.id ?? projects[0]?.id });
		}

		checkpoints.clear(id);
	};

	/** Shallow-merges `data` into the project with the given id. */
	updateProject = (id: string, data: Partial<Project>) => {
		const projects = this.$.projects.map(p => (p.id === id ? { ...p, ...data } : p));
		this.#setAnySession({ ...this.$, projects });
	};

	/** The active project, or the first project if the active id matches none. */
	get project() {
		return this.$.projects.find(p => p.id === this.$.activeProjectId) ?? this.$.projects[0];
	}

	/** Replaces the project whose id equals `np.id`. */
	set project(np: Project) {
		const projects = this.$.projects.map(p => (p.id === np.id ? np : p));
		this.#setAnySession({ ...this.$, projects });
	}

	/**
	 * Runs one inference request for `conversation`, appends the assistant
	 * reply to its messages and records the latency (and, for non-streaming
	 * responses, the generated token count) in `generationStats`.
	 *
	 * Errors from the response handlers propagate to the caller.
	 */
	async #runInference(conversation: Conversation) {
		const idx = this.project.conversations.indexOf(conversation);
		const startTime = performance.now();

		if (conversation.streaming) {
			let addedMessage = false;
			const streamingMessage = $state({ role: "assistant", content: "" });

			await handleStreamingResponse(
				conversation,
				content => {
					streamingMessage.content = content;
					// Attach the message on the first chunk only; later chunks update it in place.
					if (!addedMessage) {
						conversation.messages = [...conversation.messages, streamingMessage];
						addedMessage = true;
					}
				},
				this.#abortManager.createController()
			);
		} else {
			const { message: newMessage, completion_tokens: newTokensCount } = await handleNonStreamingResponse(conversation);
			conversation.messages = [...conversation.messages, newMessage];
			const stats = this.generationStats[idx];
			if (stats) stats.generatedTokensCount += newTokensCount;
		}

		const endTime = performance.now();
		const stats = this.generationStats[idx];
		if (stats) stats.latency = Math.round(endTime - startTime);
	}

	/**
	 * Runs inference on the selected conversation(s) of the active project.
	 *
	 * Opens the token modal and returns if no token is set. Shows an error toast
	 * and returns without running if any selected conversation already ends with
	 * an assistant message. Failures during inference are reported via toasts
	 * (and the quota modal / token reset where applicable) instead of throwing.
	 *
	 * @param conv `"left"` / `"right"` select the first / second conversation,
	 *   `"both"` (default) selects all, or pass a specific conversation.
	 */
	async run(conv: "left" | "right" | "both" | Conversation = "both") {
		if (!token.value) {
			token.showModal = true;
			return;
		}

		const projectConversations = this.project.conversations;
		const conversations =
			typeof conv === "string"
				? projectConversations.filter((_, idx) => conv === "both" || (conv === "left" ? idx === 0 : idx === 1))
				: [conv];

		for (const conversation of conversations) {
			if (conversation.messages.at(-1)?.role !== "assistant") continue;

			let prefix = "";
			// Label by position in the project, not in the filtered selection,
			// so that `run("right")` does not report the "left" conversation.
			const position = projectConversations.indexOf(conversation);
			if (projectConversations.length === 2 && position !== -1) {
				prefix = `Error on ${position === 0 ? "left" : "right"} conversation. `;
			}
			return addToast({
				title: "Failed to run inference",
				description: `${prefix}Messages must alternate between user/assistant roles.`,
				variant: "error",
			});
		}

		// `activeElement` can be null; only blur when something is focused.
		(document.activeElement as HTMLElement | null)?.blur();
		this.generating = true;

		try {
			await Promise.all(conversations.map(c => this.#runInference(c)));
		} catch (error) {
			for (const conversation of conversations) {
				// Drop an empty assistant message left behind by an interrupted stream.
				if (conversation.messages.at(-1)?.role === "assistant" && !conversation.messages.at(-1)?.content?.trim()) {
					conversation.messages.pop();
					conversation.messages = [...conversation.messages];
				}
				// eslint-disable-next-line no-self-assign
				this.$ = this.$;
			}

			if (error instanceof Error) {
				const msg = error.message;
				const lowered = msg.toLowerCase();
				// NOTE(review): the "pro" substring also matches words such as
				// "provider"; confirm whether a stricter match is wanted.
				if (lowered.includes("monthly") || lowered.includes("pro")) {
					showQuotaModal();
				}

				if (msg.includes("token seems invalid")) {
					token.reset();
				}

				// User-initiated aborts are not errors worth surfacing.
				if (error.name !== "AbortError") {
					addToast({ title: "Error", description: msg, variant: "error" });
				}
			} else {
				addToast({ title: "Error", description: "An unknown error occurred", variant: "error" });
			}
		} finally {
			this.generating = false;
			this.#abortManager.clear();
		}
	}

	/** Aborts every in-flight request and clears the generating flag. */
	stopGenerating = () => {
		this.#abortManager.abortAll();
		this.generating = false;
	};

	/** Stops generation if one is in progress, otherwise starts `run(c)`. */
	runOrStop = (c?: Parameters<typeof this.run>[0]) => {
		if (this.generating) {
			this.stopGenerating();
		} else {
			// Fire-and-forget: `run` reports its own inference failures via toasts.
			void this.run(c);
		}
	};
}
/** Shared `SessionState` instance used across the app. */
export const session = new SessionState();