|
import { |
|
animation_duration, |
|
chat, |
|
cleanUpMessage, |
|
event_types, |
|
eventSource, |
|
Generate, |
|
getGeneratingApi, |
|
is_send_press, |
|
isStreamingEnabled, |
|
substituteParamsExtended, |
|
} from '../script.js'; |
|
import { debounce, delay, getStringHash } from './utils.js'; |
|
import { decodeTextTokens, getTokenizerBestMatch } from './tokenizers.js'; |
|
import { power_user } from './power-user.js'; |
|
import { callGenericPopup, POPUP_TYPE } from './popup.js'; |
|
import { t } from './i18n.js'; |
|
|
|
// Number of alternating `logprobs_tint_N` classes cycled through when rendering
// output tokens, so adjacent tokens are visually distinguishable.
const TINTS = 4;

// Upper bound on how many messages' logprob data is kept in `state.messageLogprobs`;
// older entries are evicted when a new one is saved.
const MAX_MESSAGE_LOGPROBS = 100;

// jQuery handle for the #logprobsReroll button, resolved once at module load.
// It is shown only when the active message was a continuation (has a prefix).
const REROLL_BUTTON = $('#logprobsReroll');
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Mutable module state for the logprobs viewer.
const state = {
    // Token entry ({ token, topLogprobs }) whose alternatives are currently shown
    // in the candidate list, or null when no token is selected.
    selectedTokenLogprobs: null,
    // Map from message hash (see getMessageHash) to the saved logprob data for that message.
    messageLogprobs: new Map(),
};
|
|
|
|
|
|
|
|
|
|
|
/**
 * Re-renders the token view (#logprobs_generation_output) for the active message.
 *
 * - Does nothing while the view is hidden.
 * - Clears the current token selection and the candidate list.
 * - Shows an explanatory empty state when no logprobs are available (feature
 *   disabled, smooth streaming, generation running, or no data for the message).
 * - Otherwise renders the continuation prefix (if any) as clickable words that
 *   reroll from that point, followed by one clickable span per generated token.
 */
function renderAlternativeTokensView() {
    const view = $('#logprobs_generation_output');
    if (!view.is(':visible')) {
        return;
    }

    view.empty();
    state.selectedTokenLogprobs = null;
    renderTopLogprobs();

    const { messageLogprobs, continueFrom } = getActiveMessageLogprobData() || {};
    const usingSmoothStreaming = isStreamingEnabled() && power_user.smooth_streaming;
    if (!messageLogprobs?.length || usingSmoothStreaming) {
        // Nothing to reroll from: hide the button so a handler bound during an
        // earlier render (for a different message) cannot be triggered.
        REROLL_BUTTON.hide();

        const emptyState = $('<div></div>');
        const noTokensMsg = !power_user.request_token_probabilities
            ? '<span>Enable <b>Request token probabilities</b> in the User Settings menu to use this feature.</span>'
            : usingSmoothStreaming
                ? t`Token probabilities are not available when using Smooth Streaming.`
                : is_send_press
                    ? t`Generation in progress...`
                    : t`No token probabilities available for the current message.`;
        emptyState.html(noTokensMsg);
        emptyState.addClass('logprobs_empty_state');
        view.append(emptyState);
        return;
    }

    const prefix = continueFrom || '';
    const tokenSpans = [];
    REROLL_BUTTON.toggle(!!prefix);

    if (prefix) {
        REROLL_BUTTON.off('click').on('click', () => onPrefixClicked(prefix.length));
        tokenSpans.push(...createPrefixNodes(prefix));
    }

    messageLogprobs.forEach((tokenData, i) => {
        const { token } = tokenData;
        const span = $('<span></span>');
        span.text(toVisibleWhitespace(token));
        span.addClass('logprobs_output_token');
        span.addClass('logprobs_tint_' + (i % TINTS));
        span.on('click', () => onSelectedTokenChanged(tokenData, span));
        addKeyboardProps(span);
        tokenSpans.push(...withVirtualWhitespace(token, span));
    });

    view.append(tokenSpans);

    // Scroll so the first generated token (just after the prefix) is at the top.
    if (prefix) {
        const element = view.find('.logprobs_output_token').first();
        const scrollOffset = element.offset().top - element.parent().offset().top;
        element.parent().scrollTop(scrollOffset);
    }
}

/**
 * Builds the DOM nodes for a continuation prefix: one clickable span per
 * whitespace-separated word, each followed by its original delimiter (a <br>
 * when the delimiter contains a newline, otherwise a text node).
 * Clicking a word rerolls from the character offset where that word starts.
 *
 * @param {string} prefix - non-empty text the generation continued from.
 * @returns {Array} jQuery spans interleaved with raw DOM nodes, in display order.
 */
function createPrefixNodes(prefix) {
    const nodes = [];
    const words = prefix.split(/\s+/);
    const delimiters = prefix.match(/\s+/g) || [];
    let cumulativeOffset = 0;

    words.forEach((word, i) => {
        const span = $('<span></span>');
        span.text(`${word} `);
        span.addClass('logprobs_output_prefix');
        span.attr('title', t`Reroll from this point`);

        // Capture the offset for this word; cumulativeOffset keeps advancing.
        const offset = cumulativeOffset;
        span.on('click', () => onPrefixClicked(offset));
        addKeyboardProps(span);

        nodes.push(span);
        nodes.push(delimiters[i]?.includes('\n')
            ? document.createElement('br')
            : document.createTextNode(delimiters[i] || ' '),
        );

        cumulativeOffset += word.length + (delimiters[i]?.length || 0);
    });

    return nodes;
}
|
|
|
/**
 * Makes a jQuery element keyboard-activatable as a button: sets role/tabindex
 * and triggers a click when Enter or Space is pressed while it has focus.
 *
 * @param {JQuery} element - element to enhance (span or <button>).
 */
function addKeyboardProps(element) {
    element.attr('role', 'button');
    element.attr('tabindex', '0');
    element.keydown(function (e) {
        if (e.key === 'Enter' || e.key === ' ') {
            // Suppress the browser default: Space would otherwise scroll the page,
            // and on native <button> elements Enter would fire a second click of its own.
            e.preventDefault();
            element.click();
        }
    });
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Renders the candidate list (.logprobs_candidate_list) for the currently
 * selected token: one button per alternative with its probability, plus a
 * disabled "<others>" row holding the remaining probability mass.
 * The row matching the selected token gets the `selected` class; if none
 * matches, the "<others>" row is tinted red instead.
 */
function renderTopLogprobs() {
    $('#logprobs_top_logprobs_hint').hide();
    const view = $('.logprobs_candidate_list');
    view.empty();

    if (!state.selectedTokenLogprobs) {
        return;
    }

    const { token: selectedToken, topLogprobs } = state.selectedTokenLogprobs;
    const candidates = computeTopCandidates(topLogprobs);

    let matched = false;
    const nodes = candidates.map(([token, probability, log]) => {
        const container = $('<button class="flex-container flexFlowColumn logprobs_top_candidate"></button>');
        const tokenString = String(token);
        // Tokenizer word-boundary markers (▁ / Ġ) stand for a leading space.
        const tokenNormalized = tokenString.replace(/^[▁Ġ]/g, ' ');

        if (token === selectedToken || tokenNormalized === selectedToken) {
            matched = true;
            container.addClass('selected');
        }

        const tokenText = $('<span></span>').text(`${toVisibleWhitespace(tokenString)}`);
        const percentText = $('<span></span>').text(`${(+probability * 100).toFixed(2)}%`);
        container.append(tokenText, percentText);
        if (log) {
            container.attr('title', `logarithm: ${log}`);
        }
        addKeyboardProps(container);
        if (token !== '<others>') {
            container.on('click', () => onAlternativeClicked(state.selectedTokenLogprobs, tokenString));
        } else {
            container.prop('disabled', true);
        }
        return container;
    });

    // The selected token was not among the alternatives, so it falls into "<others>".
    if (!matched) {
        nodes[nodes.length - 1].css('background-color', 'rgba(255, 0, 0, 0.1)');
    }

    view.append(nodes);
}

/**
 * Converts top-logprob pairs into display rows sorted by descending value.
 * Values <= 0 are treated as natural-log probabilities and exponentiated;
 * values > 0 are treated as probabilities already. Both contribute to the
 * total, and a final ["<others>", remaining, 0] row carries the leftover mass
 * (clamped at 0 so rounding or over-full inputs never show a negative share).
 * The input array is not modified.
 *
 * @param {Array<[string, number]>} topLogprobs - [token, value] pairs.
 * @returns {Array<[string, number, (number|null)]>} [token, probability, log-or-null] rows.
 */
function computeTopCandidates(topLogprobs) {
    let sum = 0;
    const candidates = [...topLogprobs]
        .sort(([, logA], [, logB]) => logB - logA)
        .map(([text, log]) => {
            if (log <= 0) {
                const probability = Math.exp(log);
                sum += probability;
                return [text, probability, log];
            }
            sum += log;
            return [text, log, null];
        });
    candidates.push(['<others>', Math.max(0, 1 - sum), 0]);
    return candidates;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Toggles selection of an output token.
 * Clicking the already-selected token clears the selection; clicking another
 * token selects it and highlights its span. The candidate list is re-rendered
 * in both cases.
 *
 * @param {object} logprobs - the token's entry from the message logprobs.
 * @param {JQuery|HTMLElement} span - the span rendering that token.
 */
function onSelectedTokenChanged(logprobs, span) {
    $('.logprobs_output_token.selected').removeClass('selected');

    const isDeselecting = state.selectedTokenLogprobs === logprobs;
    state.selectedTokenLogprobs = isDeselecting ? null : logprobs;
    if (!isDeselecting) {
        $(span).addClass('selected');
    }

    renderTopLogprobs();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Rerolls the active message by replacing one token with a chosen alternative.
 * Builds a prompt from the continuation prefix plus every token up to the
 * replaced one (with the alternative substituted), then starts a new swipe
 * that continues from it.
 *
 * Aborts when a generation is already running. With the OpenAI API it shows a
 * popup instead (and returns its promise). Also aborts, logging a warning, when
 * the selected token is not part of the active message's logprobs, e.g. when
 * the view is stale after the chat changed.
 *
 * @param {object} tokenLogprobs - the token entry being replaced.
 * @param {string} alternative - replacement token text (may use ▁/Ġ/Ċ markers).
 */
function onAlternativeClicked(tokenLogprobs, alternative) {
    if (!checkGenerateReady()) {
        return;
    }

    if (getGeneratingApi() === 'openai') {
        const title = t`Feature unavailable`;
        const message = t`Due to API limitations, rerolling a token is not supported with OpenAI. Try switching to a different API.`;
        const content = `<h3>${title}</h3><p>${message}</p>`;
        return callGenericPopup(content, POPUP_TYPE.TEXT);
    }

    const data = getActiveMessageLogprobData();
    const messageLogprobs = data?.messageLogprobs ?? [];
    const replaceIndex = messageLogprobs.indexOf(tokenLogprobs);
    if (replaceIndex === -1) {
        // Without this guard, slice(0, 0) would drop every generated token and
        // the reroll would silently restart from the bare prefix.
        console.warn('Selected token is not part of the active message logprobs; skipping reroll.');
        return;
    }

    const tokens = messageLogprobs.slice(0, replaceIndex + 1).map(({ token }) => token);
    // Map tokenizer markers back to real characters: leading ▁/Ġ → space, Ċ → newline.
    tokens[replaceIndex] = String(alternative).replace(/^[▁Ġ]/g, ' ').replace(/Ċ/g, '\n');

    const prefix = data.continueFrom || '';
    const prompt = prefix + tokens.join('');
    addGeneration(prompt);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Rerolls from a point inside the continuation prefix.
 * Keeps the first `offset` characters of the active message's prefix (the whole
 * prefix when `offset` is undefined) and starts a new generation from them.
 * An empty prompt is used when there is no prefix. Does nothing while a
 * generation is running.
 *
 * @param {number} [offset] - character count of the prefix to keep.
 */
function onPrefixClicked(offset = undefined) {
    if (checkGenerateReady()) {
        const activeData = getActiveMessageLogprobData();
        const continueFrom = activeData?.continueFrom;
        addGeneration(continueFrom ? continueFrom.substring(0, offset) : '');
    }
}
|
|
|
/**
 * Reports whether a new generation may start.
 * Shows a warning toast and returns false while `is_send_press` is set.
 *
 * @returns {boolean} true when no generation is in progress.
 */
function checkGenerateReady() {
    const busy = Boolean(is_send_press);
    if (busy) {
        toastr.warning('Please wait for the current generation to complete.');
    }
    return !busy;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Starts a new swipe on the last chat message.
 * With a non-empty prompt, a swipe containing it is created, the UI swipes
 * right onto it, and a 'continue' generation is kicked off without being awaited.
 * With an empty prompt, it only triggers a swipe right.
 *
 * @param {string} prompt - text the new swipe should start with.
 */
function addGeneration(prompt) {
    const lastMessageId = chat.length - 1;
    const hasPrompt = Boolean(prompt && prompt.length > 0);

    if (hasPrompt) {
        createSwipe(lastMessageId, prompt);
    }
    $('.swipe_right:last').trigger('click');
    if (hasPrompt) {
        // Fire-and-forget: the generation reports progress through the UI.
        void Generate('continue');
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Shows or hides the #logprobsViewer panel with an opacity fade.
 * When showing, it switches the panel to `display: flex`, renders the token
 * view and fades in. When hiding, it fades out and then calls `hide()` after
 * `animation_duration` ms. During the fade the panel carries the `resizing`
 * class, which is removed 50 ms after the transition completes.
 */
function onToggleLogprobsPanel() {
    const logprobsViewer = $('#logprobsViewer');

    const clearResizing = async () => {
        await delay(50);
        logprobsViewer.removeClass('resizing');
    };

    const isHidden = logprobsViewer.css('display') === 'none';
    logprobsViewer.addClass('resizing');

    if (!isHidden) {
        logprobsViewer.transition({
            opacity: 0.0,
            duration: animation_duration,
        }, clearResizing);
        setTimeout(() => {
            logprobsViewer.hide();
        }, animation_duration);
        return;
    }

    logprobsViewer.css('display', 'flex');
    logprobsViewer.css('opacity', 0.0);
    renderAlternativeTokensView();
    logprobsViewer.transition({
        opacity: 1.0,
        duration: animation_duration,
    }, clearResizing);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Appends a new swipe to chat[messageId] seeded with `prompt` and positions
 * `swipe_id` one before it, so a subsequent swipe-right lands on the new swipe.
 *
 * When reasoning auto-parsing is enabled and the message already has parsed
 * reasoning, the cleaned prompt is split on the configured reasoning tags:
 * - prefix and suffix both present: everything up to the end of the first
 *   suffix is dropped, and only the response part is kept;
 * - prefix present without suffix: the reasoning is being rerolled, so the
 *   prefix is removed, the text goes into the new swipe's `extra.reasoning`,
 *   and the swipe text itself is empty.
 *
 * @param {number} messageId - index into `chat`.
 * @param {string} prompt - raw prompt text for the new swipe.
 */
function createSwipe(messageId, prompt) {
    let swipeText = cleanUpMessage({
        getMessage: prompt,
        isImpersonate: false,
        isContinue: false,
        displayIncompleteSentences: true,
    });

    const msg = chat[messageId];
    const reasoningPrefix = substituteParamsExtended(power_user.reasoning.prefix);
    const reasoningSuffix = substituteParamsExtended(power_user.reasoning.suffix);
    const autoParse = power_user.reasoning.auto_parse;
    const hasParsedReasoning = msg.extra?.reasoning?.length > 0;
    let rerollReasoning = false;

    if (autoParse && hasParsedReasoning) {
        console.debug('saw autoparse on with reasoning in message');
        // NOTE(review): an empty prefix/suffix makes includes() trivially true;
        // confirm the settings never allow that while auto_parse is on.
        const containsPrefix = swipeText.includes(reasoningPrefix);
        const containsSuffix = swipeText.includes(reasoningSuffix);

        if (containsPrefix && !containsSuffix) {
            console.debug('..with start tag but no end tag... reroll reasoning');
            rerollReasoning = true;
            console.debug('..no end tag...rerolling reasoning, so removing prefix');
            swipeText = swipeText.replace(reasoningPrefix, '');
        } else if (containsPrefix && containsSuffix) {
            console.debug('...incl. end tag...rerolling response');
            const responseStart = swipeText.indexOf(reasoningSuffix) + reasoningSuffix.length;
            swipeText = swipeText.substring(responseStart);
        }
    }

    console.debug('cleanedPrompt: ', swipeText);

    const newSwipeInfo = {
        send_date: msg.send_date,
        gen_started: msg.gen_started,
        gen_finished: msg.gen_finished,
        extra: { ...structuredClone(msg.extra), from_logprobs: new Date().getTime() },
    };
    if (rerollReasoning) {
        newSwipeInfo.extra.reasoning = swipeText;
    }

    msg.swipes ||= [];
    msg.swipe_info ||= [];
    // A reasoning reroll starts with empty visible text; the prompt lives in extra.reasoning.
    msg.swipes.push(rerollReasoning ? '' : swipeText);
    msg.swipe_info.push(newSwipeInfo);
    msg.swipe_id = Math.max(0, msg.swipes.length - 2);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Makes whitespace in a token visible for display.
 * Spaces and the tokenizer space markers ▁/Ġ become '·'; newlines and the
 * Ċ marker become '↵'.
 *
 * @param {string} input - token text.
 * @returns {string} display text.
 */
function toVisibleWhitespace(input) {
    const substitutions = [
        [/ /g, '·'],
        [/[▁Ġ]/g, '·'],
        [/[Ċ\n]/g, '↵'],
    ];
    return substitutions.reduce((text, [pattern, glyph]) => text.replace(pattern, glyph), input);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Wraps a token span with invisible helper nodes so the browser can wrap lines
 * at token boundaries and newlines render as line breaks.
 *
 * - A zero-width space goes before the span when the token starts with
 *   whitespace or a ▁/Ġ marker, and after it when the token ends with whitespace.
 * - A <br> is added before a token starting with '\n' and/or after one ending
 *   with '\n'. Only a token that both starts and ends with '\n' with at least
 *   one character between them gets both; tokens made only of newlines or with
 *   inner newlines are not fully represented.
 *
 * @param {string} text - raw token text.
 * @param {JQuery} span - rendered span for the token.
 * @returns {Array} nodes to insert, in display order.
 */
function withVirtualWhitespace(text, span) {
    const zeroWidthSpace = '\u200b';

    const leadingGap = Boolean(text.match(/^\s/)) || Boolean(text.match(/^[▁Ġ]/));
    const trailingGap = Boolean(text.match(/\s$/));

    const wrappedInNewlines = Boolean(text.match(/^\n(?:.|\n)+\n$/));
    const startsWithNewline = Boolean(text.match(/^\n/));
    const breakBefore = wrappedInNewlines || startsWithNewline;
    const breakAfter = wrappedInNewlines || (!startsWithNewline && Boolean(text.match(/\n$/)));

    return [
        ...(breakBefore ? [$('<br>')] : []),
        ...(leadingGap ? [document.createTextNode(zeroWidthSpace)] : []),
        span,
        ...(trailingGap ? [$(document.createTextNode(zeroWidthSpace))] : []),
        ...(breakAfter ? [$('<br>')] : []),
    ];
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Stores the logprobs produced for the last chat message so the viewer can
 * render them, keyed by that message's hash.
 *
 * Does nothing when `logprobs` is falsy or the chat has no messages. For
 * NovelAI, token ids in `logprobs` are converted to text in place first.
 * At most MAX_MESSAGE_LOGPROBS entries are kept; the oldest saved entries are
 * evicted first.
 *
 * @param {Array|null|undefined} logprobs - per-token entries ({ token, topLogprobs }).
 * @param {string|null|undefined} continueFrom - text the generation continued from, if any.
 */
export function saveLogprobsForActiveMessage(logprobs, continueFrom) {
    if (!logprobs) {
        return;
    }

    const msgId = chat.length - 1;
    if (msgId < 0) {
        // No message to attach the logprobs to (chat[-1] would be undefined).
        return;
    }

    const api = getGeneratingApi();
    // NovelAI returns token ids rather than token text.
    if (api === 'novel') {
        convertTokenIdLogprobsToText(logprobs);
    }

    const message = chat[msgId];
    const data = {
        created: new Date().getTime(),
        api,
        messageId: msgId,
        swipeId: message.swipe_id,
        messageLogprobs: logprobs,
        continueFrom,
        hash: getMessageHash(message),
    };

    // Delete before set so this entry moves to the end of the Map's insertion
    // order (re-setting an existing key would keep its old position).
    state.messageLogprobs.delete(data.hash);
    state.messageLogprobs.set(data.hash, data);

    // Map iterates in insertion order, so the first keys are the oldest. Unlike
    // sorting by `created`, this never evicts the just-saved entry on a timestamp tie.
    for (const hash of state.messageLogprobs.keys()) {
        if (state.messageLogprobs.size <= MAX_MESSAGE_LOGPROBS) {
            break;
        }
        state.messageLogprobs.delete(hash);
    }
}
|
|
|
/**
 * Computes the key used to store a message's logprobs.
 * The key combines the speaker name, the message's index in `chat`, and its
 * text, so editing or moving a message changes the key.
 *
 * @param {object} message - a chat message (reads `name` and `mes`).
 * @returns {*} the value returned by getStringHash for the serialized fields.
 */
function getMessageHash(message) {
    const { name, mes } = message;
    const serialized = JSON.stringify({
        name,
        mid: chat.indexOf(message),
        text: mes,
    });
    return getStringHash(serialized);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Looks up the stored logprob data for the last chat message.
 *
 * @returns {object|null} the data saved by saveLogprobsForActiveMessage, or
 *   null when the chat is empty or nothing is stored for that message.
 */
function getActiveMessageLogprobData() {
    const lastIndex = chat.length - 1;
    if (lastIndex < 0) {
        return null;
    }
    const key = getMessageHash(chat[lastIndex]);
    return state.messageLogprobs.get(key) || null;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Replaces NovelAI token ids with their decoded text, in place.
 * Every distinct id (chosen tokens and candidates) is decoded in a single
 * decodeTextTokens call. Each entry's `token` and the first element of each
 * `topLogprobs` pair are then swapped for the decoded text.
 *
 * @param {Array<{token: *, topLogprobs: Array<[*, number]>}>} input - entries to convert.
 * @throws {Error} when the active API is not NovelAI.
 */
function convertTokenIdLogprobsToText(input) {
    const api = getGeneratingApi();
    if (api !== 'novel') {
        throw new Error('convertTokenIdLogprobsToText should only be called for NovelAI');
    }

    const tokenizerId = getTokenizerBestMatch(api);

    // Collect distinct ids in first-seen order: candidates, then the chosen token.
    const uniqueIds = new Set();
    for (const entry of input) {
        for (const [candidateId] of entry.topLogprobs) {
            uniqueIds.add(candidateId);
        }
        uniqueIds.add(entry.token);
    }
    const tokenIds = [...uniqueIds];

    const { chunks } = decodeTextTokens(tokenizerId, tokenIds);
    const textById = new Map();
    tokenIds.forEach((id, index) => {
        textById.set(id, chunks[index]);
    });

    for (const entry of input) {
        entry.token = textById.get(entry.token);
        entry.topLogprobs = entry.topLogprobs.map(([id, logprob]) => [textById.get(id), logprob]);
    }
}
|
|
|
/**
 * Wires up the logprobs viewer.
 * Hides the reroll button, binds both toggle controls to
 * onToggleLogprobsPanel, and re-renders the token view (debounced) on chat,
 * message-render, impersonate, delete, edit and swipe events.
 */
export function initLogprobs() {
    REROLL_BUTTON.hide();
    const rerender = debounce(renderAlternativeTokensView);

    const toggleSelectors = ['#logprobsViewerClose', '#option_toggle_logprobs'];
    for (const selector of toggleSelectors) {
        $(selector).on('click', onToggleLogprobsPanel);
    }

    const rerenderEvents = [
        event_types.CHAT_CHANGED,
        event_types.CHARACTER_MESSAGE_RENDERED,
        event_types.IMPERSONATE_READY,
        event_types.MESSAGE_DELETED,
        event_types.MESSAGE_EDITED,
        event_types.MESSAGE_SWIPED,
    ];
    for (const eventName of rerenderEvents) {
        eventSource.on(eventName, rerender);
    }
}
|
|