From bbb1a6e5789d3e001f766e2143e091860e44a05e Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Fri, 28 Jun 2024 18:17:27 +0300
Subject: [PATCH] Add huggingface inference as text completion source
---
public/img/huggingface.svg | 40 ++++++++++++++++++++++
public/index.html | 21 ++++++++++--
public/script.js | 6 ++++
public/scripts/secrets.js | 2 ++
public/scripts/textgen-settings.js | 26 +++++++++++++-
src/additional-headers.js | 14 ++++++++
src/constants.js | 1 +
src/endpoints/backends/text-completions.js | 19 +++++++---
src/endpoints/secrets.js | 1 +
9 files changed, 122 insertions(+), 8 deletions(-)
create mode 100644 public/img/huggingface.svg
diff --git a/public/img/huggingface.svg b/public/img/huggingface.svg
new file mode 100644
index 000000000..3acb82a27
--- /dev/null
+++ b/public/img/huggingface.svg
@@ -0,0 +1,40 @@
+
+
diff --git a/public/index.html b/public/index.html
index 2f9bee085..99a53c5fe 100644
--- a/public/index.html
+++ b/public/index.html
@@ -1477,7 +1477,7 @@
-
+
Seed
@@ -2029,6 +2029,7 @@
+
@@ -2211,6 +2212,22 @@
+
+
HuggingFace Token
+
+
+
+
+
+ For privacy reasons, your API key will be hidden after you reload the page.
+
+
+
Endpoint URL
+ Example: https://****.endpoints.huggingface.cloud
+
+
+
-
+
diff --git a/public/script.js b/public/script.js
index ec738648c..ab9543751 100644
--- a/public/script.js
+++ b/public/script.js
@@ -8311,6 +8311,11 @@ const CONNECT_API_MAP = {
button: '#api_button_textgenerationwebui',
type: textgen_types.OPENROUTER,
},
+ 'huggingface': {
+ selected: 'textgenerationwebui',
+ button: '#api_button_textgenerationwebui',
+ type: textgen_types.HUGGINGFACE,
+ },
};
async function selectContextCallback(_, name) {
@@ -9471,6 +9476,7 @@ jQuery(async function () {
{ id: 'api_key_openrouter-tg', secret: SECRET_KEYS.OPENROUTER },
{ id: 'api_key_koboldcpp', secret: SECRET_KEYS.KOBOLDCPP },
{ id: 'api_key_llamacpp', secret: SECRET_KEYS.LLAMACPP },
+ { id: 'api_key_huggingface', secret: SECRET_KEYS.HUGGINGFACE },
];
for (const key of keys) {
diff --git a/public/scripts/secrets.js b/public/scripts/secrets.js
index 8989cdb50..4410ed1ee 100644
--- a/public/scripts/secrets.js
+++ b/public/scripts/secrets.js
@@ -29,6 +29,7 @@ export const SECRET_KEYS = {
GROQ: 'api_key_groq',
AZURE_TTS: 'api_key_azure_tts',
ZEROONEAI: 'api_key_01ai',
+ HUGGINGFACE: 'api_key_huggingface',
};
const INPUT_MAP = {
@@ -58,6 +59,7 @@ const INPUT_MAP = {
[SECRET_KEYS.PERPLEXITY]: '#api_key_perplexity',
[SECRET_KEYS.GROQ]: '#api_key_groq',
[SECRET_KEYS.ZEROONEAI]: '#api_key_01ai',
+ [SECRET_KEYS.HUGGINGFACE]: '#api_key_huggingface',
};
async function clearSecret() {
diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js
index 6623f1d9e..eed2d0ef8 100644
--- a/public/scripts/textgen-settings.js
+++ b/public/scripts/textgen-settings.js
@@ -38,9 +38,24 @@ export const textgen_types = {
INFERMATICAI: 'infermaticai',
DREAMGEN: 'dreamgen',
OPENROUTER: 'openrouter',
+ HUGGINGFACE: 'huggingface',
};
-const { MANCER, VLLM, APHRODITE, TABBY, TOGETHERAI, OOBA, OLLAMA, LLAMACPP, INFERMATICAI, DREAMGEN, OPENROUTER, KOBOLDCPP } = textgen_types;
+const {
+ MANCER,
+ VLLM,
+ APHRODITE,
+ TABBY,
+ TOGETHERAI,
+ OOBA,
+ OLLAMA,
+ LLAMACPP,
+ INFERMATICAI,
+ DREAMGEN,
+ OPENROUTER,
+ KOBOLDCPP,
+ HUGGINGFACE,
+} = textgen_types;
const LLAMACPP_DEFAULT_ORDER = [
'top_k',
@@ -84,6 +99,7 @@ const SERVER_INPUTS = {
[textgen_types.KOBOLDCPP]: '#koboldcpp_api_url_text',
[textgen_types.LLAMACPP]: '#llamacpp_api_url_text',
[textgen_types.OLLAMA]: '#ollama_api_url_text',
+ [textgen_types.HUGGINGFACE]: '#huggingface_api_url_text',
};
const KOBOLDCPP_ORDER = [6, 0, 1, 3, 4, 2, 5];
@@ -1009,6 +1025,8 @@ export function getTextGenModel() {
throw new Error('No Ollama model selected');
}
return settings.ollama_model;
+ case HUGGINGFACE:
+ return 'tgi';
default:
return undefined;
}
@@ -1146,6 +1164,12 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
params.grammar = settings.grammar_string;
}
+ if (settings.type === HUGGINGFACE) {
+ params.top_p = Math.min(Math.max(Number(params.top_p), 0.0), 0.999);
+ params.stop = Array.isArray(params.stop) ? params.stop.slice(0, 4) : [];
+ nonAphroditeParams.seed = settings.seed >= 0 ? settings.seed : undefined;
+ }
+
if (settings.type === MANCER) {
params.n = canMultiSwipe ? settings.n : 1;
params.epsilon_cutoff /= 1000;
diff --git a/src/additional-headers.js b/src/additional-headers.js
index b8a44b390..8a188caa6 100644
--- a/src/additional-headers.js
+++ b/src/additional-headers.js
@@ -147,6 +147,19 @@ function getKoboldCppHeaders(directories) {
}) : {};
}
+/**
+ * Builds the HuggingFace request headers from the secret read under SECRET_KEYS.HUGGINGFACE.
+ * @param {import('./users').UserDirectoryList} directories Passed to readSecret to look up the token
+ * @returns {object} A Bearer `Authorization` header, or an empty object when readSecret yields a falsy value
+ */
+function getHuggingFaceHeaders(directories) {
+ const apiKey = readSecret(directories, SECRET_KEYS.HUGGINGFACE);
+
+ return apiKey ? ({
+ 'Authorization': `Bearer ${apiKey}`,
+ }) : {};
+}
+
function getOverrideHeaders(urlHost) {
const requestOverrides = getConfigValue('requestOverrides', []);
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
@@ -187,6 +200,7 @@ function setAdditionalHeadersByType(requestHeaders, type, server, directories) {
[TEXTGEN_TYPES.OPENROUTER]: getOpenRouterHeaders,
[TEXTGEN_TYPES.KOBOLDCPP]: getKoboldCppHeaders,
[TEXTGEN_TYPES.LLAMACPP]: getLlamaCppHeaders,
+ [TEXTGEN_TYPES.HUGGINGFACE]: getHuggingFaceHeaders,
};
const getHeaders = headerGetters[type];
diff --git a/src/constants.js b/src/constants.js
index 842c35d46..01ba9e32e 100644
--- a/src/constants.js
+++ b/src/constants.js
@@ -216,6 +216,7 @@ const TEXTGEN_TYPES = {
INFERMATICAI: 'infermaticai',
DREAMGEN: 'dreamgen',
OPENROUTER: 'openrouter',
+ HUGGINGFACE: 'huggingface',
};
const INFERMATICAI_KEYS = [
diff --git a/src/endpoints/backends/text-completions.js b/src/endpoints/backends/text-completions.js
index a6c55acbd..8ca61ecf3 100644
--- a/src/endpoints/backends/text-completions.js
+++ b/src/endpoints/backends/text-completions.js
@@ -95,13 +95,14 @@ router.post('/status', jsonParser, async function (request, response) {
setAdditionalHeaders(request, args, baseUrl);
+ const apiType = request.body.api_type;
let url = baseUrl;
let result = '';
if (request.body.legacy_api) {
url += '/v1/model';
} else {
- switch (request.body.api_type) {
+ switch (apiType) {
case TEXTGEN_TYPES.OOBA:
case TEXTGEN_TYPES.VLLM:
case TEXTGEN_TYPES.APHRODITE:
@@ -126,6 +127,9 @@ router.post('/status', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.OLLAMA:
url += '/api/tags';
break;
+ case TEXTGEN_TYPES.HUGGINGFACE:
+ url += '/info';
+ break;
}
}
@@ -144,14 +148,18 @@ router.post('/status', jsonParser, async function (request, response) {
}
// Rewrap to OAI-like response
- if (request.body.api_type === TEXTGEN_TYPES.TOGETHERAI && Array.isArray(data)) {
+ if (apiType === TEXTGEN_TYPES.TOGETHERAI && Array.isArray(data)) {
data = { data: data.map(x => ({ id: x.name, ...x })) };
}
- if (request.body.api_type === TEXTGEN_TYPES.OLLAMA && Array.isArray(data.models)) {
+ if (apiType === TEXTGEN_TYPES.OLLAMA && Array.isArray(data.models)) {
data = { data: data.models.map(x => ({ id: x.name, ...x })) };
}
+ if (apiType === TEXTGEN_TYPES.HUGGINGFACE) {
+ data = { data: [] };
+ }
+
if (!Array.isArray(data.data)) {
console.log('Models response is not an array.');
return response.status(400);
@@ -163,7 +171,7 @@ router.post('/status', jsonParser, async function (request, response) {
// Set result to the first model ID
result = modelIds[0] || 'Valid';
- if (request.body.api_type === TEXTGEN_TYPES.OOBA) {
+ if (apiType === TEXTGEN_TYPES.OOBA) {
try {
const modelInfoUrl = baseUrl + '/v1/internal/model/info';
const modelInfoReply = await fetch(modelInfoUrl, args);
@@ -178,7 +186,7 @@ router.post('/status', jsonParser, async function (request, response) {
} catch (error) {
console.error(`Failed to get Ooba model info: ${error}`);
}
- } else if (request.body.api_type === TEXTGEN_TYPES.TABBY) {
+ } else if (apiType === TEXTGEN_TYPES.TABBY) {
try {
const modelInfoUrl = baseUrl + '/v1/model';
const modelInfoReply = await fetch(modelInfoUrl, args);
@@ -241,6 +249,7 @@ router.post('/generate', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.KOBOLDCPP:
case TEXTGEN_TYPES.TOGETHERAI:
case TEXTGEN_TYPES.INFERMATICAI:
+ case TEXTGEN_TYPES.HUGGINGFACE:
url += '/v1/completions';
break;
case TEXTGEN_TYPES.DREAMGEN:
diff --git a/src/endpoints/secrets.js b/src/endpoints/secrets.js
index 669dd86b4..532fb1a32 100644
--- a/src/endpoints/secrets.js
+++ b/src/endpoints/secrets.js
@@ -41,6 +41,7 @@ const SECRET_KEYS = {
GROQ: 'api_key_groq',
AZURE_TTS: 'api_key_azure_tts',
ZEROONEAI: 'api_key_01ai',
+ HUGGINGFACE: 'api_key_huggingface',
};
// These are the keys that are safe to expose, even if allowKeysExposure is false