diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index ea95f4959..f8d00213c 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -99,6 +99,7 @@ const sources = { google: 'google', zai: 'zai', openrouter: 'openrouter', + workersai: 'workersai', }; const comfyTypes = { standard: 'standard', @@ -1747,6 +1748,9 @@ async function loadSamplers() { case sources.openrouter: samplers = ['N/A']; break; + case sources.workersai: + samplers = ['N/A']; + break; } for (const sampler of samplers) { @@ -1997,6 +2001,9 @@ async function loadModels() { case sources.openrouter: models = await loadOpenRouterModels(); break; + case sources.workersai: + models = await loadWorkersAIImageModels(); + break; } if (extension_settings.sd.source === sources.electronhub) { @@ -2124,6 +2131,33 @@ async function loadXAIModels() { ]; } +async function loadWorkersAIImageModels() { + $('#sd_cf_workers_key').toggleClass('success', !!secret_state[SECRET_KEYS.WORKERS_AI]); + + if (!secret_state[SECRET_KEYS.WORKERS_AI]) { + return []; + } + + if (!oai_settings.workers_ai_account_id) { + toastr.warning('Workers AI account ID is required. Save it in the "API Connections" panel.', 'Image Generation'); + return []; + } + + const result = await fetch('/api/sd/workersai/models', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ + account_id: oai_settings.workers_ai_account_id, + }), + }); + + if (result.ok) { + return await result.json(); + } + + return []; +} + async function loadPollinationsModels() { $('#sd_pollinations_key').toggleClass('success', !!secret_state[SECRET_KEYS.POLLINATIONS]); @@ -2609,6 +2643,9 @@ async function loadSchedulers() { case sources.openrouter: schedulers = ['N/A']; break; + case sources.workersai: + schedulers = ['N/A']; + break; } for (const scheduler of schedulers) { @@ -2729,6 +2766,9 @@ async function loadVaes() { case sources.openrouter: vaes = ['N/A']; break; + case sources.workersai: + vaes = ['N/A']; + break; } for (const vae of vaes) { @@ -3432,6 +3472,9 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP case sources.openrouter: result = await generateOpenRouterImage(prefixedPrompt, signal); break; + case sources.workersai: + result = await generateWorkersAIImage(prefixedPrompt, negativePrompt, signal); + break; } if (!result.data) { @@ -4748,6 +4791,33 @@ async function generateOpenRouterImage(prompt, signal) { throw new Error(text); } +async function generateWorkersAIImage(prompt, negativePrompt, signal) { + const result = await fetch('/api/sd/workersai/generate', { + method: 'POST', + headers: getRequestHeaders(), + signal: signal, + body: JSON.stringify({ + prompt: prompt, + negative_prompt: negativePrompt, + model: extension_settings.sd.model, + width: extension_settings.sd.width, + height: extension_settings.sd.height, + steps: extension_settings.sd.steps, + scale: extension_settings.sd.scale, + seed: extension_settings.sd.seed >= 0 ? extension_settings.sd.seed : undefined, + account_id: oai_settings.workers_ai_account_id, + }), + }); + + if (result.ok) { + const data = await result.json(); + return { format: data?.format, data: data?.image }; + } else { + const text = await result.text(); + throw new Error(text); + } +} + async function onComfyOpenWorkflowEditorClick() { let workflow = await (await fetch('/api/sd/comfy/workflow', { method: 'POST', @@ -5120,6 +5190,8 @@ function isValidState() { return secret_state[SECRET_KEYS.ZAI]; case sources.openrouter: return secret_state[SECRET_KEYS.OPENROUTER]; + case sources.workersai: + return !!oai_settings.workers_ai_account_id && secret_state[SECRET_KEYS.WORKERS_AI]; default: return false; } @@ -5879,6 +5951,9 @@ export async function init() { extension_settings.sd.google_duration = Number($(this).val()); saveSettingsDebounced(); }); + $('#sd_models_refresh').on('click', async () => { + await loadModels(); + }); $('#sd_electronhub_quality').on('change', function () { extension_settings.sd.electronhub_quality = String($(this).val()); saveSettingsDebounced(); @@ -5922,6 +5997,7 @@ export async function init() { [sources.aimlapi]: SECRET_KEYS.AIMLAPI, [sources.comfy]: SECRET_KEYS.COMFY_RUNPOD, [sources.pollinations]: SECRET_KEYS.POLLINATIONS, + [sources.workersai]: SECRET_KEYS.WORKERS_AI, }; const shouldReloadOptions = Object.entries(keySourceMap).some(([k, v]) => k === extension_settings.sd.source && v === key); if (!shouldReloadOptions) { diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index ba7e4d00e..a71c4083f 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -40,10 +40,11 @@ Minimal response prompt processing Source - + AI/ML API BFL (Black Forest Labs) Chutes + Cloudflare Workers AI ComfyUI DrawThings HTTP API Electron Hub @@ -123,7 +124,7 @@ Image Quality - + @@ -191,14 +192,14 @@ Image Style - + Vivid Natural Image Quality - + Auto Low Medium @@ -207,7 +208,7 @@ Image Quality - + Standard HD @@ -216,7 +217,7 @@ Duration - + Short (4 seconds) Medium (8 seconds) Long (12 seconds) @@ -226,7 +227,7 @@ Server Type - + Standard Server RunPod Serverless Endpoint @@ -318,7 +319,7 @@ Style Preset - + Anime 3D Model Analog Film @@ -375,6 +376,20 @@ + + Cloudflare Workers AI + + API Key + + + Click to set + + + + Hint: Account ID and API key are pulled from API connections. + + + @@ -394,7 +409,7 @@ Duration (Veo) - + Short (4 seconds) Medium (6 seconds) Long (8 seconds) @@ -405,37 +420,42 @@ - Model - + + Model + + + + + VAE - + Sampling method - + Scheduler - + Resolution - + Upscaler - + diff --git a/public/scripts/extensions/stable-diffusion/style.css b/public/scripts/extensions/stable-diffusion/style.css index e0053f740..c193ee4ab 100644 --- a/public/scripts/extensions/stable-diffusion/style.css +++ b/public/scripts/extensions/stable-diffusion/style.css @@ -1,7 +1,3 @@ -.sd_settings label:not(.checkbox_label) { - display: block; -} - #sd_dropdown { z-index: 30000; backdrop-filter: blur(var(--SmartThemeBlurStrength)); diff --git a/src/endpoints/stable-diffusion.js b/src/endpoints/stable-diffusion.js index 3418cd890..e6af932c9 100644 --- a/src/endpoints/stable-diffusion.js +++ b/src/endpoints/stable-diffusion.js @@ -5,7 +5,6 @@ import express from 'express'; import fetch from 'node-fetch'; import sanitize from 'sanitize-filename'; import { sync as writeFileAtomicSync } from 'write-file-atomic'; -import FormData from 'form-data'; import urlJoin from 'url-join'; import _ from 'lodash'; import mime from 'mime-types'; @@ -2031,6 +2030,145 @@ zai.post('/generate-video', async (request, response) => { } }); +const workersai = express.Router(); + +workersai.post('/models', async (request, response) => { + try { + const key = readSecret(request.user.directories, SECRET_KEYS.WORKERS_AI); + + if (!key) { + console.warn('Cloudflare Workers AI API key not found.'); + return response.sendStatus(400); + } + + const accountId = String(request.body.account_id || '').trim(); + if (!accountId) { + console.warn('Cloudflare Workers AI Account ID not found.'); + return response.sendStatus(400); + } + + const apiUrl = new URL(`https://api.cloudflare.com/client/v4/accounts/${encodeURIComponent(accountId)}/ai/models/search`); + apiUrl.searchParams.set('task', 'Text-to-Image'); + apiUrl.searchParams.set('per_page', '1000'); + const result = await fetch(apiUrl, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${key}`, + }, + }); + + if (!result.ok) { + console.warn('Cloudflare Workers AI returned an error.', result.statusText); + return response.sendStatus(500); + } + + /** @type {any} */ + const data = await result.json(); + + if (!data.success || !Array.isArray(data.result)) { + console.warn('Cloudflare Workers AI returned invalid data.'); + return response.sendStatus(500); + } + + const models = data.result.map(x => ({ value: x.name, text: x.name })); + return response.send(models); + } catch (error) { + console.error(error); + return response.sendStatus(500); + } +}); + +workersai.post('/generate', async (request, response) => { + try { + const key = readSecret(request.user.directories, SECRET_KEYS.WORKERS_AI); + + if (!key) { + console.warn('Cloudflare Workers AI API key not found.'); + return response.sendStatus(400); + } + + const accountId = String(request.body.account_id || '').trim(); + if (!accountId) { + console.warn('Cloudflare Workers AI Account ID not found.'); + return response.sendStatus(400); + } + + const model = String(request.body.model || '').trim(); + if (!model) { + console.warn('Cloudflare Workers AI model not specified.'); + return response.sendStatus(400); + } + + const apiUrl = `https://api.cloudflare.com/client/v4/accounts/${encodeURIComponent(accountId)}/ai/run/${model}`; + + const body = { + prompt: request.body.prompt, + negative_prompt: request.body.negative_prompt || undefined, + width: request.body.width ? Number(request.body.width) : undefined, + height: request.body.height ? Number(request.body.height) : undefined, + num_steps: request.body.steps ? Number(request.body.steps) : undefined, + guidance: request.body.scale ? Number(request.body.scale) : undefined, + seed: request.body.seed >= 0 ? Number(request.body.seed) : undefined, + }; + + // Remove undefined values + for (const prop of Object.keys(body)) { + if (body[prop] === undefined) { + delete body[prop]; + } + } + + console.debug('Cloudflare Workers AI request:', model, body); + + /** @type {import('node-fetch').RequestInit} */ + const apiRequest = { + method: 'POST', + headers: { + 'Authorization': `Bearer ${key}`, + }, + }; + + if (/flux-2/.test(model)) { + const formData = new FormData(); + for (const [key, value] of Object.entries(body)) { + formData.append(key, String(value)); + } + apiRequest.body = formData; + } else { + apiRequest.headers = { ...apiRequest.headers, 'Content-Type': 'application/json' }; + apiRequest.body = JSON.stringify(body); + } + + const result = await fetch(apiUrl, apiRequest); + if (!result.ok) { + const text = await result.text(); + console.warn('Cloudflare Workers AI returned an error.', result.status, result.statusText, text); + return response.status(500).send(text); + } + + const contentType = result.headers.get('content-type') || ''; + + // Partner models return JSON with base64 image + if (contentType.includes('application/json')) { + /** @type {any} */ + const data = await result.json(); + const image = data?.result?.image || data?.image; + if (!image) { + console.warn('Cloudflare Workers AI returned JSON without image data.'); + return response.sendStatus(500); + } + return response.send({ format: 'png', image: image }); + } + + // Non-partner models return raw binary image data + const buffer = await result.arrayBuffer(); + return response.send({ format: 'png', image: Buffer.from(buffer).toString('base64') }); + } catch (error) { + console.error(error); + return response.sendStatus(500); + } +}); + router.use('/comfy', comfy); router.use('/comfyrunpod', comfyRunPod); router.use('/together', together); @@ -2047,3 +2185,4 @@ router.use('/falai', falai); router.use('/xai', xai); router.use('/aimlapi', aimlapi); router.use('/zai', zai); +router.use('/workersai', workersai);