✨ Save messages in backend (#31)
- .eslintrc.cjs +3 -0
- src/lib/buildPrompt.ts +25 -0
- src/lib/utils/streamToAsyncIterable.ts +15 -0
- src/lib/utils/sum.ts +3 -0
- src/routes/+page.svelte +1 -1
- src/routes/api/conversation/+server.ts +0 -19
- src/routes/conversation/[id]/+page.svelte +2 -20
- src/routes/conversation/[id]/+server.ts +110 -0
.eslintrc.cjs
CHANGED
@@ -12,6 +12,9 @@ module.exports = {
 		sourceType: 'module',
 		ecmaVersion: 2020
 	},
+	rules: {
+		'no-shadow': ['error']
+	},
 	env: {
 		browser: true,
 		es2017: true,
src/lib/buildPrompt.ts
ADDED
@@ -0,0 +1,25 @@
+import {
+	PUBLIC_ASSISTANT_MESSAGE_TOKEN,
+	PUBLIC_SEP_TOKEN,
+	PUBLIC_USER_MESSAGE_TOKEN
+} from '$env/static/public';
+import type { Message } from './types/Message';
+
+/**
+ * Convert [{user: "assistant", content: "hi"}, {user: "user", content: "hello"}] to:
+ *
+ * <|assistant|>hi<|endoftext|><|prompter|>hello<|endoftext|><|assistant|>
+ */
+export function buildPrompt(messages: Message[]): string {
+	return (
+		messages
+			.map(
+				(m) =>
+					(m.from === 'user'
+						? PUBLIC_USER_MESSAGE_TOKEN + m.content
+						: PUBLIC_ASSISTANT_MESSAGE_TOKEN + m.content) +
+					(m.content.endsWith(PUBLIC_SEP_TOKEN) ? '' : PUBLIC_SEP_TOKEN)
+			)
+			.join('') + PUBLIC_ASSISTANT_MESSAGE_TOKEN
+	);
+}
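
A usage sketch for buildPrompt, assuming the public env tokens hold the OpenAssistant-style values implied by the doc comment (<|prompter|>, <|assistant|>, <|endoftext|>) and that Message is { from: 'user' | 'assistant'; content: string }:

import { buildPrompt } from '$lib/buildPrompt';

const prompt = buildPrompt([
	{ from: 'user', content: 'hello' },
	{ from: 'assistant', content: 'hi' },
	{ from: 'user', content: 'how are you?' }
]);
// With the assumed token values, prompt is:
// <|prompter|>hello<|endoftext|><|assistant|>hi<|endoftext|><|prompter|>how are you?<|endoftext|><|assistant|>
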
src/lib/utils/streamToAsyncIterable.ts
ADDED
@@ -0,0 +1,15 @@
+// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Statements/for-await...of#iterating_over_async_generators
+export async function* streamToAsyncIterable(
+	stream: ReadableStream<Uint8Array>
+): AsyncIterableIterator<Uint8Array> {
+	const reader = stream.getReader();
+	try {
+		while (true) {
+			const { done, value } = await reader.read();
+			if (done) return;
+			yield value;
+		}
+	} finally {
+		reader.releaseLock();
+	}
+}
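
A minimal consumption sketch, assuming a fetch response whose body is non-null (the URL is a hypothetical streaming endpoint, not one from this PR):

import { streamToAsyncIterable } from '$lib/utils/streamToAsyncIterable';

const res = await fetch('https://example.com/stream');
const decoder = new TextDecoder();
for await (const chunk of streamToAsyncIterable(res.body!)) {
	// chunk is a Uint8Array read from the underlying ReadableStream
	console.log(decoder.decode(chunk, { stream: true }));
}
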
src/lib/utils/sum.ts
ADDED
@@ -0,0 +1,3 @@
+export function sum(nums: number[]): number {
+	return nums.reduce((a, b) => a + b, 0);
+}
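
sum is used below in conversation/[id]/+server.ts to size the merged Uint8Array. For illustration:

import { sum } from '$lib/utils/sum';

sum([1, 2, 3]); // 6
sum([]); // 0, from the initial value passed to reduce
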
src/routes/+page.svelte
CHANGED
@@ -1,5 +1,5 @@
 <script lang="ts">
-	import { goto
+	import { goto } from '$app/navigation';
 	import ChatWindow from '$lib/components/chat/ChatWindow.svelte';
 	import { pendingMessage } from '$lib/stores/pendingMessage';
 
src/routes/api/conversation/+server.ts
DELETED
@@ -1,19 +0,0 @@
-import { HF_TOKEN } from '$env/static/private';
-import { PUBLIC_MODEL_ENDPOINT } from '$env/static/public';
-
-export async function POST({ request, fetch }) {
-	const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
-		headers: {
-			'Content-Type': request.headers.get('Content-Type') ?? 'application/json',
-			Authorization: `Basic ${HF_TOKEN}`
-		},
-		method: 'POST',
-		body: await request.text()
-	});
-
-	return new Response(resp.body, {
-		headers: Object.fromEntries(resp.headers.entries()),
-		status: resp.status,
-		statusText: resp.statusText
-	});
-}
src/routes/conversation/[id]/+page.svelte
CHANGED
@@ -4,23 +4,14 @@
 	import { onMount } from 'svelte';
 	import type { PageData } from './$types';
 	import { page } from '$app/stores';
-	import {
-		PUBLIC_ASSISTANT_MESSAGE_TOKEN,
-		PUBLIC_SEP_TOKEN,
-		PUBLIC_USER_MESSAGE_TOKEN
-	} from '$env/static/public';
 	import { HfInference } from '@huggingface/inference';
 
 	export let data: PageData;
 
 	$: messages = data.messages;
 
-	const userToken = PUBLIC_USER_MESSAGE_TOKEN;
-	const assistantToken = PUBLIC_ASSISTANT_MESSAGE_TOKEN;
-	const sepToken = PUBLIC_SEP_TOKEN;
-
 	const hf = new HfInference();
-	const model = hf.endpoint(
+	const model = hf.endpoint($page.url.href);
 
 	let loading = false;
 
@@ -76,16 +67,7 @@
 
 		messages = [...messages, { from: 'user', content: message }];
 
-
-			messages
-				.map(
-					(m) =>
-						(m.from === 'user' ? userToken + m.content : assistantToken + m.content) +
-						(m.content.endsWith(sepToken) ? '' : sepToken)
-				)
-				.join('') + assistantToken;
-
-		await getTextGenerationStream(inputs);
+		await getTextGenerationStream(message);
 	} finally {
 		loading = false;
 	}
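
With this change the page no longer builds the prompt client-side: hf.endpoint($page.url.href) points @huggingface/inference at the conversation's own URL, i.e. the new src/routes/conversation/[id]/+server.ts route below, and only the raw user message is sent. A hedged sketch of the call pattern (the page's getTextGenerationStream helper is not shown in this diff, so the body here is illustrative):

const model = hf.endpoint($page.url.href);

for await (const output of model.textGenerationStream({
	inputs: message // raw user text; the server wraps it with buildPrompt
})) {
	// output.token.text holds the newly generated token
	console.log(output.token.text);
}
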
src/routes/conversation/[id]/+server.ts
ADDED
@@ -0,0 +1,110 @@
+import { HF_TOKEN } from '$env/static/private';
+import { PUBLIC_MODEL_ENDPOINT } from '$env/static/public';
+import { buildPrompt } from '$lib/buildPrompt.js';
+import { collections } from '$lib/server/database.js';
+import type { Message } from '$lib/types/Message.js';
+import { streamToAsyncIterable } from '$lib/utils/streamToAsyncIterable';
+import { sum } from '$lib/utils/sum';
+import { error } from '@sveltejs/kit';
+import { ObjectId } from 'mongodb';
+
+export async function POST({ request, fetch, locals, params }) {
+	// todo: add validation on params.id
+	const convId = new ObjectId(params.id);
+
+	const conv = await collections.conversations.findOne({
+		_id: convId,
+		sessionId: locals.sessionId
+	});
+
+	if (!conv) {
+		throw error(404, 'Conversation not found');
+	}
+
+	// Todo: validate prompt with zod? or aktype
+	const json = await request.json();
+
+	const messages = [...conv.messages, { from: 'user', content: json.inputs }] satisfies Message[];
+
+	json.inputs = buildPrompt(messages);
+
+	const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
+		headers: {
+			'Content-Type': request.headers.get('Content-Type') ?? 'application/json',
+			Authorization: `Basic ${HF_TOKEN}`
+		},
+		method: 'POST',
+		body: JSON.stringify(json)
+	});
+
+	const [stream1, stream2] = resp.body!.tee();
+
+	async function saveMessage() {
+		const generated_text = await parseGeneratedText(stream2);
+
+		messages.push({ from: 'assistant', content: generated_text });
+
+		console.log('updating conversation', convId, messages);
+
+		await collections.conversations.updateOne(
+			{
+				_id: convId
+			},
+			{
+				$set: {
+					messages,
+					updatedAt: new Date()
+				}
+			}
+		);
+	}
+
+	saveMessage().catch(console.error);
+
+	// Todo: maybe we should wait for the message to be saved before ending the response - in case of errors
+	return new Response(stream1, {
+		headers: Object.fromEntries(resp.headers.entries()),
+		status: resp.status,
+		statusText: resp.statusText
+	});
+}
+
+async function parseGeneratedText(stream: ReadableStream): Promise<string> {
+	const inputs: Uint8Array[] = [];
+	for await (const input of streamToAsyncIterable(stream)) {
+		inputs.push(input);
+	}
+
+	// Merge inputs into a single Uint8Array
+	const completeInput = new Uint8Array(sum(inputs.map((input) => input.length)));
+	let offset = 0;
+	for (const input of inputs) {
+		completeInput.set(input, offset);
+		offset += input.length;
+	}
+
+	// Get last line starting with "data:" and parse it as JSON to get the generated text
+	const message = new TextDecoder().decode(completeInput);
+
+	let lastIndex = message.lastIndexOf('\ndata:');
+	if (lastIndex === -1) {
+		lastIndex = message.indexOf('data');
+	}
+
+	if (lastIndex === -1) {
+		console.error('Could not parse in last message');
+	}
+
+	let lastMessage = message.slice(lastIndex).trim().slice('data:'.length);
+	if (lastMessage.includes('\n')) {
+		lastMessage = lastMessage.slice(0, lastMessage.indexOf('\n'));
+	}
+
+	const res = JSON.parse(lastMessage).generated_text;
+
+	if (typeof res !== 'string') {
+		throw new Error('Could not parse generated text');
+	}
+
+	return res;
+}
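
For reference, parseGeneratedText assumes the model endpoint streams Server-Sent Events where each chunk is a "data: {...}" line and the final event carries the full generated_text (the text-generation-inference streaming format). A self-contained sketch of that last-event parsing on a sample transcript (payloads are illustrative, not captured from a real endpoint):

const transcript =
	'data: {"token":{"text":"hi"},"generated_text":null}\n\n' +
	'data: {"token":{"text":"!"},"generated_text":"hi!"}\n\n';

// Same approach as parseGeneratedText: locate the last "data:" line and read generated_text.
let lastIndex = transcript.lastIndexOf('\ndata:');
if (lastIndex === -1) lastIndex = transcript.indexOf('data');

let lastMessage = transcript.slice(lastIndex).trim().slice('data:'.length);
if (lastMessage.includes('\n')) {
	lastMessage = lastMessage.slice(0, lastMessage.indexOf('\n'));
}

console.log(JSON.parse(lastMessage).generated_text); // "hi!"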