From 9660aaa2c2d49a502af0176cf2cebb4d554dc390 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Sun, 27 Aug 2023 18:27:34 +0300 Subject: [PATCH] Add NovelAI hypebot plugin --- public/index.html | 2 +- public/lib/gpt-3-tokenizer/array-keyed-map.js | 210 -------------- public/lib/gpt-3-tokenizer/gpt3-tokenizer.js | 271 ------------------ .../lib/gpt-3-tokenizer/gpt3-tokenizer.js.map | 1 - public/scripts/extensions/hypebot/index.js | 195 +++++++++++++ .../scripts/extensions/hypebot/manifest.json | 11 + .../scripts/extensions/hypebot/settings.html | 18 ++ public/scripts/extensions/hypebot/style.css | 17 ++ .../scripts/extensions/quick-reply/style.css | 1 + public/scripts/tokenizers.js | 118 ++++++-- public/scripts/utils.js | 11 + public/style.css | 3 + server.js | 116 ++++++-- src/novelai.js | 92 ++++-- 14 files changed, 505 insertions(+), 561 deletions(-) delete mode 100644 public/lib/gpt-3-tokenizer/array-keyed-map.js delete mode 100644 public/lib/gpt-3-tokenizer/gpt3-tokenizer.js delete mode 100644 public/lib/gpt-3-tokenizer/gpt3-tokenizer.js.map create mode 100644 public/scripts/extensions/hypebot/index.js create mode 100644 public/scripts/extensions/hypebot/manifest.json create mode 100644 public/scripts/extensions/hypebot/settings.html create mode 100644 public/scripts/extensions/hypebot/style.css diff --git a/public/index.html b/public/index.html index 4dcfe7d45..ea91c7adc 100644 --- a/public/index.html +++ b/public/index.html @@ -2277,7 +2277,7 @@ + Enabled + + + + + + diff --git a/public/scripts/extensions/hypebot/style.css b/public/scripts/extensions/hypebot/style.css new file mode 100644 index 000000000..3a696f4df --- /dev/null +++ b/public/scripts/extensions/hypebot/style.css @@ -0,0 +1,17 @@ +#hypeBotBar { + width: 100%; + max-width: 100%; + padding: 0.5em; + white-space: normal; + font-size: calc(var(--mainFontSize) * 0.85); + order: 20; +} + +.hypebot_nokey { + text-align: center; + font-style: italic; +} + +.hypebot_name { + font-weight: 600; +} diff --git a/public/scripts/extensions/quick-reply/style.css b/public/scripts/extensions/quick-reply/style.css index 4f8f1a7b5..c60a60a21 100644 --- a/public/scripts/extensions/quick-reply/style.css +++ b/public/scripts/extensions/quick-reply/style.css @@ -12,6 +12,7 @@ display: none; max-width: 100%; overflow-x: auto; + order: 10; } #quickReplies { diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index 0b5a34668..8c43b8049 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -1,7 +1,6 @@ import { characters, main_api, nai_settings, online_status, this_chid } from "../script.js"; import { power_user } from "./power-user.js"; import { encode } from "../lib/gpt-2-3-tokenizer/mod.js"; -import { GPT3BrowserTokenizer } from "../lib/gpt-3-tokenizer/gpt3-tokenizer.js"; import { chat_completion_sources, oai_settings } from "./openai.js"; import { groups, selected_group } from "./group-chats.js"; import { getStringHash } from "./utils.js"; @@ -12,7 +11,7 @@ const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown'; export const tokenizers = { NONE: 0, - GPT3: 1, + GPT2: 1, CLASSIC: 2, LLAMA: 3, NERD: 4, @@ -22,7 +21,6 @@ export const tokenizers = { }; const objectStore = new localforage.createInstance({ name: "SillyTavern_ChatCompletions" }); -const gpt3 = new GPT3BrowserTokenizer({ type: 'gpt3' }); let tokenCache = {}; @@ -93,6 +91,35 @@ function getTokenizerBestMatch() { return tokenizers.NONE; } +/** + * Calls the underlying tokenizer model to the token count for a string. + * @param {number} type Tokenizer type. + * @param {string} str String to tokenize. + * @param {number} padding Number of padding tokens. + * @returns {number} Token count. + */ +function callTokenizer(type, str, padding) { + switch (type) { + case tokenizers.NONE: + return guesstimate(str) + padding; + case tokenizers.GPT2: + return countTokensRemote('/tokenize_gpt2', str, padding); + case tokenizers.CLASSIC: + return encode(str).length + padding; + case tokenizers.LLAMA: + return countTokensRemote('/tokenize_llama', str, padding); + case tokenizers.NERD: + return countTokensRemote('/tokenize_nerdstash', str, padding); + case tokenizers.NERD2: + return countTokensRemote('/tokenize_nerdstash_v2', str, padding); + case tokenizers.API: + return countTokensRemote('/tokenize_via_api', str, padding); + default: + console.warn("Unknown tokenizer type", type); + return callTokenizer(tokenizers.NONE, str, padding); + } +} + /** * Gets the token count for a string using the current model tokenizer. * @param {string} str String to tokenize @@ -100,33 +127,6 @@ function getTokenizerBestMatch() { * @returns {number} Token count. */ export function getTokenCount(str, padding = undefined) { - /** - * Calculates the token count for a string. - * @param {number} [type] Tokenizer type. - * @returns {number} Token count. - */ - function calculate(type) { - switch (type) { - case tokenizers.NONE: - return guesstimate(str) + padding; - case tokenizers.GPT3: - return gpt3.encode(str).bpe.length + padding; - case tokenizers.CLASSIC: - return encode(str).length + padding; - case tokenizers.LLAMA: - return countTokensRemote('/tokenize_llama', str, padding); - case tokenizers.NERD: - return countTokensRemote('/tokenize_nerdstash', str, padding); - case tokenizers.NERD2: - return countTokensRemote('/tokenize_nerdstash_v2', str, padding); - case tokenizers.API: - return countTokensRemote('/tokenize_via_api', str, padding); - default: - console.warn("Unknown tokenizer type", type); - return calculate(tokenizers.NONE); - } - } - if (typeof str !== 'string' || !str?.length) { return 0; } @@ -159,7 +159,7 @@ export function getTokenCount(str, padding = undefined) { return cacheObject[cacheKey]; } - const result = calculate(tokenizerType); + const result = callTokenizer(tokenizerType, str, padding); if (isNaN(result)) { console.warn("Token count calculation returned NaN"); @@ -350,6 +350,12 @@ function countTokensRemote(endpoint, str, padding) { return tokenCount + padding; } +/** + * Calls the underlying tokenizer model to encode a string to tokens. + * @param {string} endpoint API endpoint. + * @param {string} str String to tokenize. + * @returns {number[]} Array of token ids. + */ function getTextTokensRemote(endpoint, str) { let ids = []; jQuery.ajax({ @@ -366,8 +372,37 @@ function getTextTokensRemote(endpoint, str) { return ids; } +/** + * Calls the underlying tokenizer model to decode token ids to text. + * @param {string} endpoint API endpoint. + * @param {number[]} ids Array of token ids + */ +function decodeTextTokensRemote(endpoint, ids) { + let text = ''; + jQuery.ajax({ + async: false, + type: 'POST', + url: endpoint, + data: JSON.stringify({ ids: ids }), + dataType: "json", + contentType: "application/json", + success: function (data) { + text = data.text; + } + }); + return text; +} + +/** + * Encodes a string to tokens using the remote server API. + * @param {number} tokenizerType Tokenizer type. + * @param {string} str String to tokenize. + * @returns {number[]} Array of token ids. + */ export function getTextTokens(tokenizerType, str) { switch (tokenizerType) { + case tokenizers.GPT2: + return getTextTokensRemote('/tokenize_gpt2', str); case tokenizers.LLAMA: return getTextTokensRemote('/tokenize_llama', str); case tokenizers.NERD: @@ -380,6 +415,27 @@ export function getTextTokens(tokenizerType, str) { } } +/** + * Decodes token ids to text using the remote server API. + * @param {any} tokenizerType Tokenizer type. + * @param {number[]} ids Array of token ids + */ +export function decodeTextTokens(tokenizerType, ids) { + switch (tokenizerType) { + case tokenizers.GPT2: + return decodeTextTokensRemote('/decode_gpt2', ids); + case tokenizers.LLAMA: + return decodeTextTokensRemote('/decode_llama', ids); + case tokenizers.NERD: + return decodeTextTokensRemote('/decode_nerdstash', ids); + case tokenizers.NERD2: + return decodeTextTokensRemote('/decode_nerdstash_v2', ids); + default: + console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType); + return ''; + } +} + jQuery(async () => { await loadTokenCache(); }); diff --git a/public/scripts/utils.js b/public/scripts/utils.js index 69d0d1030..aa74ff3dd 100644 --- a/public/scripts/utils.js +++ b/public/scripts/utils.js @@ -45,6 +45,17 @@ export function getSortableDelay() { return isMobile() ? 750 : 50; } +export async function bufferToBase64(buffer) { + // use a FileReader to generate a base64 data URI: + const base64url = await new Promise(resolve => { + const reader = new FileReader() + reader.onload = () => resolve(reader.result) + reader.readAsDataURL(new Blob([buffer])) + }); + // remove the `data:...;base64,` part from the start + return base64url.slice(base64url.indexOf(',') + 1); +} + /** * Rearranges an array in a random order. * @param {any[]} array The array to shuffle. diff --git a/public/style.css b/public/style.css index 3f09e4c76..3bc531046 100644 --- a/public/style.css +++ b/public/style.css @@ -531,6 +531,7 @@ hr { column-gap: 5px; font-size: var(--bottomFormIconSize); overflow: hidden; + order: 1003; } #send_but_sheld>div { @@ -581,6 +582,7 @@ hr { transition: 0.3s; display: flex; align-items: center; + order: 1001; } .font-family-reset { @@ -904,6 +906,7 @@ select { margin: 0; text-shadow: 0px 0px calc(var(--shadowWidth) * 1px) var(--SmartThemeShadowColor); flex: 1; + order: 1002; } .text_pole::placeholder { diff --git a/server.js b/server.js index e28d2acc1..241951c81 100644 --- a/server.js +++ b/server.js @@ -1894,8 +1894,7 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene const novelai = require('./src/novelai'); const isNewModel = (request.body.model.includes('clio') || request.body.model.includes('kayra')); - const isKrake = request.body.model.includes('krake'); - const badWordsList = (isNewModel ? novelai.badWordsList : (isKrake ? novelai.krakeBadWordsList : novelai.euterpeBadWordsList)).slice(); + const badWordsList = novelai.getBadWordsList(request.body.model); // Add customized bad words for Clio and Kayra if (isNewModel && Array.isArray(request.body.bad_words_ids)) { @@ -1907,7 +1906,7 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene } // Add default biases for dinkus and asterism - const logit_bias_exp = isNewModel ? novelai.logitBiasExp.slice() : null; + const logit_bias_exp = isNewModel ? novelai.logitBiasExp.slice() : []; if (Array.isArray(logit_bias_exp) && Array.isArray(request.body.logit_bias_exp)) { logit_bias_exp.push(...request.body.logit_bias_exp); @@ -1942,7 +1941,7 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene "logit_bias_exp": logit_bias_exp, "generate_until_sentence": request.body.generate_until_sentence, "use_cache": request.body.use_cache, - "use_string": true, + "use_string": request.body.use_string ?? true, "return_full_text": request.body.return_full_text, "prefix": request.body.prefix, "order": request.body.order @@ -3845,22 +3844,87 @@ function getPresetSettingsByAPI(apiId) { } } -function createTokenizationHandler(getTokenizerFn) { +function createSentencepieceEncodingHandler(getTokenizerFn) { return async function (request, response) { - if (!request.body) { - return response.sendStatus(400); - } + try { + if (!request.body) { + return response.sendStatus(400); + } - const text = request.body.text || ''; - const tokenizer = getTokenizerFn(); - const { ids, count } = await countSentencepieceTokens(tokenizer, text); - return response.send({ ids, count }); + const text = request.body.text || ''; + const tokenizer = getTokenizerFn(); + const { ids, count } = await countSentencepieceTokens(tokenizer, text); + return response.send({ ids, count }); + } catch (error) { + console.log(error); + return response.send({ ids: [], count: 0 }); + } }; } -app.post("/tokenize_llama", jsonParser, createTokenizationHandler(() => spp_llama)); -app.post("/tokenize_nerdstash", jsonParser, createTokenizationHandler(() => spp_nerd)); -app.post("/tokenize_nerdstash_v2", jsonParser, createTokenizationHandler(() => spp_nerd_v2)); +function createSentencepieceDecodingHandler(getTokenizerFn) { + return async function (request, response) { + try { + if (!request.body) { + return response.sendStatus(400); + } + + const ids = request.body.ids || []; + const tokenizer = getTokenizerFn(); + const text = await tokenizer.decodeIds(ids); + return response.send({ text }); + } catch (error) { + console.log(error); + return response.send({ text: '' }); + } + }; +} + +function createTiktokenEncodingHandler(modelId) { + return async function (request, response) { + try { + if (!request.body) { + return response.sendStatus(400); + } + + const text = request.body.text || ''; + const tokenizer = getTiktokenTokenizer(modelId); + const tokens = Object.values(tokenizer.encode(text)); + return response.send({ ids: tokens, count: tokens.length }); + } catch (error) { + console.log(error); + return response.send({ ids: [], count: 0 }); + } + } +} + +function createTiktokenDecodingHandler(modelId) { + return async function (request, response) { + try { + if (!request.body) { + return response.sendStatus(400); + } + + const ids = request.body.ids || []; + const tokenizer = getTiktokenTokenizer(modelId); + const textBytes = tokenizer.decode(new Uint32Array(ids)); + const text = new TextDecoder().decode(textBytes); + return response.send({ text }); + } catch (error) { + console.log(error); + return response.send({ text: '' }); + } + } +} + +app.post("/tokenize_llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama)); +app.post("/tokenize_nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd)); +app.post("/tokenize_nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2)); +app.post("/tokenize_gpt2", jsonParser, createTiktokenEncodingHandler('gpt2')); +app.post("/decode_llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama)); +app.post("/decode_nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd)); +app.post("/decode_nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2)); +app.post("/decode_gpt2", jsonParser, createTiktokenDecodingHandler('gpt2')); app.post("/tokenize_via_api", jsonParser, async function (request, response) { if (!request.body) { return response.sendStatus(400); @@ -4350,17 +4414,17 @@ app.post('/libre_translate', jsonParser, async (request, response) => { console.log('Input text: ' + text); try { - const result = await fetch(url, { - method: "POST", - body: JSON.stringify({ - q: text, - source: "auto", - target: lang, - format: "text", - api_key: key - }), - headers: { "Content-Type": "application/json" } - }); + const result = await fetch(url, { + method: "POST", + body: JSON.stringify({ + q: text, + source: "auto", + target: lang, + format: "text", + api_key: key + }), + headers: { "Content-Type": "application/json" } + }); if (!result.ok) { return response.sendStatus(result.status); diff --git a/src/novelai.js b/src/novelai.js index 3a988029a..36dc6fa94 100644 --- a/src/novelai.js +++ b/src/novelai.js @@ -1,18 +1,18 @@ // Ban bracket generation, plus defaults const euterpeBadWordsList = [ - [8162], [17202], [8162], [17202], [8162], [17202], [8162], [17202], [8162], [17202], [46256, 224], [2343, 223, 224], - [46256, 224], [2343, 223, 224], [46256, 224], [2343, 223, 224], [46256, 224], [2343, 223, 224], [46256, 224], - [2343, 223, 224], [58], [60], [90], [92], [685], [1391], [1782], [2361], [3693], [4083], [4357], [4895], [5512], - [5974], [7131], [8183], [8351], [8762], [8964], [8973], [9063], [11208], [11709], [11907], [11919], [12878], [12962], - [13018], [13412], [14631], [14692], [14980], [15090], [15437], [16151], [16410], [16589], [17241], [17414], [17635], - [17816], [17912], [18083], [18161], [18477], [19629], [19779], [19953], [20520], [20598], [20662], [20740], [21476], - [21737], [22133], [22241], [22345], [22935], [23330], [23785], [23834], [23884], [25295], [25597], [25719], [25787], - [25915], [26076], [26358], [26398], [26894], [26933], [27007], [27422], [28013], [29164], [29225], [29342], [29565], - [29795], [30072], [30109], [30138], [30866], [31161], [31478], [32092], [32239], [32509], [33116], [33250], [33761], - [34171], [34758], [34949], [35944], [36338], [36463], [36563], [36786], [36796], [36937], [37250], [37913], [37981], - [38165], [38362], [38381], [38430], [38892], [39850], [39893], [41832], [41888], [42535], [42669], [42785], [42924], - [43839], [44438], [44587], [44926], [45144], [45297], [46110], [46570], [46581], [46956], [47175], [47182], [47527], - [47715], [48600], [48683], [48688], [48874], [48999], [49074], [49082], [49146], [49946], [10221], [4841], [1427], + [8162], [17202], [8162], [17202], [8162], [17202], [8162], [17202], [8162], [17202], [46256, 224], [2343, 223, 224], + [46256, 224], [2343, 223, 224], [46256, 224], [2343, 223, 224], [46256, 224], [2343, 223, 224], [46256, 224], + [2343, 223, 224], [58], [60], [90], [92], [685], [1391], [1782], [2361], [3693], [4083], [4357], [4895], [5512], + [5974], [7131], [8183], [8351], [8762], [8964], [8973], [9063], [11208], [11709], [11907], [11919], [12878], [12962], + [13018], [13412], [14631], [14692], [14980], [15090], [15437], [16151], [16410], [16589], [17241], [17414], [17635], + [17816], [17912], [18083], [18161], [18477], [19629], [19779], [19953], [20520], [20598], [20662], [20740], [21476], + [21737], [22133], [22241], [22345], [22935], [23330], [23785], [23834], [23884], [25295], [25597], [25719], [25787], + [25915], [26076], [26358], [26398], [26894], [26933], [27007], [27422], [28013], [29164], [29225], [29342], [29565], + [29795], [30072], [30109], [30138], [30866], [31161], [31478], [32092], [32239], [32509], [33116], [33250], [33761], + [34171], [34758], [34949], [35944], [36338], [36463], [36563], [36786], [36796], [36937], [37250], [37913], [37981], + [38165], [38362], [38381], [38430], [38892], [39850], [39893], [41832], [41888], [42535], [42669], [42785], [42924], + [43839], [44438], [44587], [44926], [45144], [45297], [46110], [46570], [46581], [46956], [47175], [47182], [47527], + [47715], [48600], [48683], [48688], [48874], [48999], [49074], [49082], [49146], [49946], [10221], [4841], [1427], [2602, 834], [29343], [37405], [35780], [2602], [50256], ] @@ -47,29 +47,79 @@ const krakeBadWordsList = [ // Ban bracket generation, plus defaults const badWordsList = [ - [23], [49209, 23], [23], [49209, 23], [23], [49209, 23], [23], [49209, 23], [23], [49209, 23], [21], [49209, 21], - [21], [49209, 21], [21], [49209, 21], [21], [49209, 21], [21], [49209, 21], [3], [49356], [1431], [31715], [34387], + [23], [49209, 23], [23], [49209, 23], [23], [49209, 23], [23], [49209, 23], [23], [49209, 23], [21], [49209, 21], + [21], [49209, 21], [21], [49209, 21], [21], [49209, 21], [21], [49209, 21], [3], [49356], [1431], [31715], [34387], [20765], [30702], [10691], [49333], [1266], [26523], [41471], [2936], [85, 85], [49332], [7286], [1115] ] +const hypeBotBadWordsList = [ + [58], [60], [90], [92], [685], [1391], [1782], [2361], [3693], [4083], [4357], [4895], + [5512], [5974], [7131], [8183], [8351], [8762], [8964], [8973], [9063], [11208], + [11709], [11907], [11919], [12878], [12962], [13018], [13412], [14631], [14692], + [14980], [15090], [15437], [16151], [16410], [16589], [17241], [17414], [17635], + [17816], [17912], [18083], [18161], [18477], [19629], [19779], [19953], [20520], + [20598], [20662], [20740], [21476], [21737], [22133], [22241], [22345], [22935], + [23330], [23785], [23834], [23884], [25295], [25597], [25719], [25787], [25915], + [26076], [26358], [26398], [26894], [26933], [27007], [27422], [28013], [29164], + [29225], [29342], [29565], [29795], [30072], [30109], [30138], [30866], [31161], + [31478], [32092], [32239], [32509], [33116], [33250], [33761], [34171], [34758], + [34949], [35944], [36338], [36463], [36563], [36786], [36796], [36937], [37250], + [37913], [37981], [38165], [38362], [38381], [38430], [38892], [39850], [39893], + [41832], [41888], [42535], [42669], [42785], [42924], [43839], [44438], [44587], + [44926], [45144], [45297], [46110], [46570], [46581], [46956], [47175], [47182], + [47527], [47715], [48600], [48683], [48688], [48874], [48999], [49074], [49082], + [49146], [49946], [10221], [4841], [1427], [2602, 834], [29343], [37405], [35780], [2602], [50256] +]; + // Used for phrase repetition penalty const repPenaltyAllowList = [ - [49256, 49264, 49231, 49230, 49287, 85, 49255, 49399, 49262, 336, 333, 432, 363, 468, 492, 745, 401, 426, 623, 794, - 1096, 2919, 2072, 7379, 1259, 2110, 620, 526, 487, 16562, 603, 805, 761, 2681, 942, 8917, 653, 3513, 506, 5301, - 562, 5010, 614, 10942, 539, 2976, 462, 5189, 567, 2032, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 588, - 803, 1040, 49209, 4, 5, 6, 7, 8, 9, 10, 11, 12] + [49256, 49264, 49231, 49230, 49287, 85, 49255, 49399, 49262, 336, 333, 432, 363, 468, 492, 745, 401, 426, 623, 794, + 1096, 2919, 2072, 7379, 1259, 2110, 620, 526, 487, 16562, 603, 805, 761, 2681, 942, 8917, 653, 3513, 506, 5301, + 562, 5010, 614, 10942, 539, 2976, 462, 5189, 567, 2032, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 588, + 803, 1040, 49209, 4, 5, 6, 7, 8, 9, 10, 11, 12] ] // Ban the dinkus and asterism const logitBiasExp = [ - { "sequence": [23], "bias": -0.08, "ensure_sequence_finish": false, "generate_once": false }, + { "sequence": [23], "bias": -0.08, "ensure_sequence_finish": false, "generate_once": false }, { "sequence": [21], "bias": -0.08, "ensure_sequence_finish": false, "generate_once": false } ] +const hypeBotLogitBiasExp = [ + { "sequence": [8162], "bias": -0.12, "ensure_sequence_finish": false, "generate_once": false}, + { "sequence": [46256, 224], "bias": -0.12, "ensure_sequence_finish": false, "generate_once": false } +]; + +function getBadWordsList(model) { + let list = [] + + if (model.includes('euterpe')) { + list = euterpeBadWordsList; + } + + if (model.includes('krake')) { + list = krakeBadWordsList; + } + + if (model.includes('hypebot')) { + list = hypeBotBadWordsList; + } + + if (model.includes('clio') || model.includes('kayra')) { + list = badWordsList; + } + + // Clone the list so we don't modify the original + return list.slice(); +} + module.exports = { euterpeBadWordsList, krakeBadWordsList, badWordsList, repPenaltyAllowList, - logitBiasExp + logitBiasExp, + hypeBotBadWordsList, + hypeBotLogitBiasExp, + getBadWordsList, };