From b2fa6a0afb2bc859bb2f116232f18035cafebb6d Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 1 May 2026 00:09:24 +0300 Subject: [PATCH] Add rate limit to basic auth middleware (#5504) * feat: add rate limiting to basic auth flow * fix: round up retry-after duration * feat: enhance point consume logic * fix: move unauthorized webpage reading inside response function * refactor: move getIpAddress to express-common * fix: check for rate limit before checking creds * fix: use correct rate limit pattern in /recover-step2 * feat: handle CF forwarded IP header in rate limit, whitelist and access logger * feat: add individual config toggles for forwarded headers * feat: enhance IP address retrieval to include forwarded IP for access logging * chore: clean-up diff * fix: don't consume points for missing credentials * feat: log rate limited method and URL Co-authored-by: Copilot * feat: make rate limiter points configurable Co-authored-by: Copilot * feat: implement retry-after header for rate limiting responses Co-authored-by: Copilot --------- Co-authored-by: Copilot --- default/config.yaml | 21 ++++++-- src/endpoints/users-public.js | 33 +++++++----- src/express-common.js | 60 ++++++++++++++++++--- src/middleware/accessLogWriter.js | 4 +- src/middleware/basicAuth.js | 86 +++++++++++++++++++++---------- src/middleware/whitelist.js | 31 ++--------- 6 files changed, 153 insertions(+), 82 deletions(-) diff --git a/default/config.yaml b/default/config.yaml index e78aeba62..e61f783ff 100644 --- a/default/config.yaml +++ b/default/config.yaml @@ -58,7 +58,7 @@ ssl: # -- SECURITY CONFIGURATION -- # Toggle whitelist mode whitelistMode: true -# Whitelist will also verify IP in X-Forwarded-For / X-Real-IP headers +# When enabled, whitelist will also verify IP in headers enabled in `forwardedHeaders` section. enableForwardedWhitelist: true # Whitelist of allowed IP addresses whitelist: @@ -189,9 +189,24 @@ logging: minLogLevel: 0 # -- RATE LIMITING CONFIGURATION -- rateLimiting: - # Use X-Real-IP header instead of socket IP for rate limiting - # Only enable this if you are using a properly configured reverse proxy (like Nginx/traefik/Caddy) + # Use any of the enabled headers in the `forwardedHeaders` section to identify the client IP for rate limiting. + # If disabled, only the socket IP will be used, which may not work correctly if you are behind a reverse proxy. preferRealIpHeader: false + # Set the maximum number of allowed failed basic authentication attempts before rate limiting is applied. Set to 0 to disable rate limiting for basic auth. + basicAuthMaxAttempts: 5 + # Set the maximum number of allowed failed account login attempts before rate limiting is applied. Set to 0 to disable rate limiting for account logins. + accountsLoginMaxAttempts: 5 + # Set the maximum number of allowed failed account recovery attempts before rate limiting is applied. Set to 0 to disable rate limiting for account recovery. + accountsRecoverMaxAttempts: 5 +# Set to true to enable support for real IPs in certain request headers for features like IP whitelisting, rate limiting and access logging. +# Only change if you are sure that you use a correctly configured reverse proxy, otherwise this may lead to IP spoofing. +forwardedHeaders: + # X-Real-IP header (common with Nginx and Caddy) + xRealIp: true + # X-Forwarded-For header (common with many proxies, but may contain multiple IPs - only the first one will be used) + xForwardedFor: true + # CF-Connecting-IP header (used by Cloudflare Tunnels) + cfConnectingIp: false ## BACKUP CONFIGURATION backups: diff --git a/src/endpoints/users-public.js b/src/endpoints/users-public.js index ec677d740..205a3aafc 100644 --- a/src/endpoints/users-public.js +++ b/src/endpoints/users-public.js @@ -3,23 +3,23 @@ import crypto from 'node:crypto'; import storage from 'node-persist'; import express from 'express'; import { RateLimiterMemory, RateLimiterRes } from 'rate-limiter-flexible'; -import { getIpFromRequest, getRealIpFromHeader } from '../express-common.js'; +import { getIpAddress, retryAfter } from '../express-common.js'; import { color, Cache, getConfigValue } from '../util.js'; import { KEY_PREFIX, getUserAvatar, toKey, getPasswordHash, getPasswordSalt } from '../users.js'; const DISCREET_LOGIN = getConfigValue('enableDiscreetLogin', false, 'boolean'); const PREFER_REAL_IP_HEADER = getConfigValue('rateLimiting.preferRealIpHeader', false, 'boolean'); +const LOGIN_POINTS = getConfigValue('rateLimiting.accountsLoginMaxAttempts', 5, 'number'); +const RECOVER_POINTS = getConfigValue('rateLimiting.accountsRecoverMaxAttempts', 5, 'number'); const MFA_CACHE = new Cache(5 * 60 * 1000); -const getIpAddress = (request) => PREFER_REAL_IP_HEADER ? getRealIpFromHeader(request) : getIpFromRequest(request); - export const router = express.Router(); const loginLimiter = new RateLimiterMemory({ - points: 5, + points: LOGIN_POINTS > 0 ? LOGIN_POINTS : Number.MAX_SAFE_INTEGER, duration: 60, }); const recoverLimiter = new RateLimiterMemory({ - points: 5, + points: RECOVER_POINTS > 0 ? RECOVER_POINTS : Number.MAX_SAFE_INTEGER, duration: 300, }); @@ -63,7 +63,7 @@ router.post('/login', async (request, response) => { return response.status(400).json({ error: 'Missing required fields' }); } - const ip = getIpAddress(request); + const ip = getIpAddress(request, PREFER_REAL_IP_HEADER); await loginLimiter.consume(ip); /** @type {import('../users.js').User} */ @@ -95,8 +95,8 @@ router.post('/login', async (request, response) => { return response.json({ handle: user.handle }); } catch (error) { if (error instanceof RateLimiterRes) { - console.error('Login failed: Rate limited from', getIpAddress(request)); - return response.status(429).send({ error: 'Too many attempts. Try again later or recover your password.' }); + console.error('Login failed: Rate limited from', getIpAddress(request, PREFER_REAL_IP_HEADER)); + return retryAfter(response, error).status(429).send({ error: 'Too many attempts. Try again later or recover your password.' }); } console.error('Login failed:', error); @@ -111,7 +111,7 @@ router.post('/recover-step1', async (request, response) => { return response.status(400).json({ error: 'Missing required fields' }); } - const ip = getIpAddress(request); + const ip = getIpAddress(request, PREFER_REAL_IP_HEADER); await recoverLimiter.consume(ip); /** @type {import('../users.js').User} */ @@ -135,8 +135,8 @@ router.post('/recover-step1', async (request, response) => { return response.sendStatus(204); } catch (error) { if (error instanceof RateLimiterRes) { - console.error('Recover step 1 failed: Rate limited from', getIpAddress(request)); - return response.status(429).send({ error: 'Too many attempts. Try again later or contact your admin.' }); + console.error('Recover step 1 failed: Rate limited from', getIpAddress(request, PREFER_REAL_IP_HEADER)); + return retryAfter(response, error).status(429).send({ error: 'Too many attempts. Try again later or contact your admin.' }); } console.error('Recover step 1 failed:', error); @@ -153,7 +153,12 @@ router.post('/recover-step2', async (request, response) => { /** @type {import('../users.js').User} */ const user = await storage.getItem(toKey(request.body.handle)); - const ip = getIpAddress(request); + const ip = getIpAddress(request, PREFER_REAL_IP_HEADER); + const rateLimit = await recoverLimiter.get(ip); + + if (rateLimit !== null && rateLimit.consumedPoints > recoverLimiter.points) { + throw rateLimit; + } if (!user) { console.error('Recover step 2 failed: User', request.body.handle, 'not found'); @@ -189,8 +194,8 @@ router.post('/recover-step2', async (request, response) => { return response.sendStatus(204); } catch (error) { if (error instanceof RateLimiterRes) { - console.error('Recover step 2 failed: Rate limited from', getIpAddress(request)); - return response.status(429).send({ error: 'Too many attempts. Try again later or contact your admin.' }); + console.error('Recover step 2 failed: Rate limited from', getIpAddress(request, PREFER_REAL_IP_HEADER)); + return retryAfter(response, error).status(429).send({ error: 'Too many attempts. Try again later or contact your admin.' }); } console.error('Recover step 2 failed:', error); diff --git a/src/express-common.js b/src/express-common.js index df74cc7b8..3a7050f30 100644 --- a/src/express-common.js +++ b/src/express-common.js @@ -1,5 +1,7 @@ import ipaddr from 'ipaddr.js'; import ipMatching from 'ip-matching'; +import { RateLimiterRes } from 'rate-limiter-flexible'; +import { getConfigValue } from './util.js'; const noopMiddleware = (_req, _res, next) => next(); /** @deprecated Do not use. A global middleware is provided at the application level. */ @@ -29,17 +31,46 @@ export function getIpFromRequest(req) { } /** - * Gets the IP address of the client when behind reverse proxy using x-real-ip header, falls back to socket remote address. - * This function should be used when the application is running behind a reverse proxy (e.g., Nginx, traefik, Caddy...). - * @param {import('express').Request} req Request object - * @returns {string} IP address of the client + * Get the client IP address from the request headers. + * @param {import('express').Request} req Express request object + * @returns {string|undefined} The client IP address */ -export function getRealIpFromHeader(req) { - if (req.headers['x-real-ip']) { +export function getRealOrForwardedIp(req) { + const xRealIpEnabled = !!getConfigValue('forwardedHeaders.xRealIp', true, 'boolean'); + const cfConnectingIpEnabled = !!getConfigValue('forwardedHeaders.cfConnectingIp', false, 'boolean'); + const xForwardedForEnabled = !!getConfigValue('forwardedHeaders.xForwardedFor', true, 'boolean'); + + // Check if X-Real-IP is available + if (req.headers['x-real-ip'] && xRealIpEnabled) { return req.headers['x-real-ip'].toString(); } - return getIpFromRequest(req); + // Check for CF-Connecting-IP (Cloudflare) if available + if (req.headers['cf-connecting-ip'] && cfConnectingIpEnabled) { + return req.headers['cf-connecting-ip'].toString(); + } + + // Check for X-Forwarded-For and parse if available + if (req.headers['x-forwarded-for'] && xForwardedForEnabled) { + const ipList = req.headers['x-forwarded-for'].toString().split(',').map(ip => ip.trim()); + return ipList[0]; + } + + // If none of the headers are available, return undefined + return undefined; +} + +/** + * Gets the IP address of the client, optionally including the real/forwarded IP from headers. + * Most common use cases: key for rate limiter, logging, etc. where you want to have the real client IP if behind a reverse proxy. + * @param {import('express').Request} request Request object + * @param {boolean} includeHeaderIp Whether to include the real/forwarded IP from headers + * @returns {string} IP address of the client (will include "forwarded" info if includeHeaderIp is true and headers are present) + */ +export function getIpAddress(request, includeHeaderIp) { + const socketIp = getIpFromRequest(request); + const forwardedIp = includeHeaderIp && getRealOrForwardedIp(request); + return forwardedIp ? `${socketIp} (forwarded: ${forwardedIp})` : socketIp; } /** @@ -79,3 +110,18 @@ export function filterValidIpPatterns(entries, formatLog) { return validEntries; } + +/** + * Sets the Retry-After header on the response based on the rate limit information. + * @param {import('express').Response} response Express response object + * @param {RateLimiterRes} rateLimit The rate limit information from rate-limiter-flexible + * @returns {import('express').Response} The response object with the Retry-After header set if applicable + */ +export function retryAfter(response, rateLimit) { + if (response.headersSent || !(rateLimit instanceof RateLimiterRes)) { + return response; + } + const retryAfter = Math.ceil(rateLimit.msBeforeNext / 1000); + response.set('Retry-After', retryAfter.toString()); + return response; +} diff --git a/src/middleware/accessLogWriter.js b/src/middleware/accessLogWriter.js index 9451dabc1..2bfef1515 100644 --- a/src/middleware/accessLogWriter.js +++ b/src/middleware/accessLogWriter.js @@ -1,6 +1,6 @@ import path from 'node:path'; import fs from 'node:fs'; -import { getRealIpFromHeader } from '../express-common.js'; +import { getIpAddress } from '../express-common.js'; import { color, getConfigValue } from '../util.js'; const enableAccessLog = getConfigValue('logging.enableAccessLog', true, 'boolean'); @@ -32,7 +32,7 @@ export function migrateAccessLog() { */ export default function accessLoggerMiddleware() { return function (req, res, next) { - const clientIp = getRealIpFromHeader(req); + const clientIp = getIpAddress(req, true); const userAgent = req.headers['user-agent']; if (!knownIPs.has(clientIp)) { diff --git a/src/middleware/basicAuth.js b/src/middleware/basicAuth.js index d0a65960f..e7c8c44de 100644 --- a/src/middleware/basicAuth.js +++ b/src/middleware/basicAuth.js @@ -5,53 +5,83 @@ import { Buffer } from 'node:buffer'; import path from 'node:path'; import storage from 'node-persist'; +import { RateLimiterMemory, RateLimiterRes } from 'rate-limiter-flexible'; import { getAllUserHandles, toKey, getPasswordHash } from '../users.js'; import { getConfigValue, safeReadFileSync } from '../util.js'; +import { getIpAddress, retryAfter } from '../express-common.js'; -const PER_USER_BASIC_AUTH = getConfigValue('perUserBasicAuth', false, 'boolean'); -const ENABLE_ACCOUNTS = getConfigValue('enableUserAccounts', false, 'boolean'); +const PER_USER_BASIC_AUTH = !!getConfigValue('perUserBasicAuth', false, 'boolean'); +const ENABLE_ACCOUNTS = !!getConfigValue('enableUserAccounts', false, 'boolean'); +const PREFER_REAL_IP_HEADER = !!getConfigValue('rateLimiting.preferRealIpHeader', false, 'boolean'); +const BASIC_AUTH_ATTEMPTS = getConfigValue('rateLimiting.basicAuthMaxAttempts', 5, 'number'); + +const basicAuthLimiter = new RateLimiterMemory({ + points: BASIC_AUTH_ATTEMPTS > 0 ? BASIC_AUTH_ATTEMPTS : Number.MAX_SAFE_INTEGER, + duration: 60, +}); const basicAuthMiddleware = async function (request, response, callback) { - const unauthorizedWebpage = safeReadFileSync(path.join(globalThis.DATA_ROOT, '_errors', 'unauthorized.html')) ?? ''; const unauthorizedResponse = (res) => { + const unauthorizedWebpage = safeReadFileSync(path.join(globalThis.DATA_ROOT, '_errors', 'unauthorized.html')) ?? ''; res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"'); return res.status(401).send(unauthorizedWebpage); }; - const basicAuthUserName = getConfigValue('basicAuthUser.username'); - const basicAuthUserPassword = getConfigValue('basicAuthUser.password'); - const authHeader = request.headers.authorization; + try { + const ip = getIpAddress(request, PREFER_REAL_IP_HEADER); - if (!authHeader) { - return unauthorizedResponse(response); - } + const basicAuthUserName = getConfigValue('basicAuthUser.username'); + const basicAuthUserPassword = getConfigValue('basicAuthUser.password'); + const authHeader = request.headers.authorization; - const [scheme, credentials] = authHeader.split(' '); + if (!authHeader) { + return unauthorizedResponse(response); + } - if (scheme !== 'Basic' || !credentials) { - return unauthorizedResponse(response); - } + const [scheme, credentials] = authHeader.split(' '); - const usePerUserAuth = PER_USER_BASIC_AUTH && ENABLE_ACCOUNTS; - const [username, ...passwordParts] = Buffer.from(credentials, 'base64') - .toString('utf8') - .split(':'); - const password = passwordParts.join(':'); + if (scheme !== 'Basic' || !credentials) { + return unauthorizedResponse(response); + } - if (!usePerUserAuth && username === basicAuthUserName && password === basicAuthUserPassword) { - return callback(); - } else if (usePerUserAuth) { - const userHandles = await getAllUserHandles(); - for (const userHandle of userHandles) { - if (username === userHandle) { - const user = await storage.getItem(toKey(userHandle)); - if (user && user.enabled && (user.password && user.password === getPasswordHash(password, user.salt))) { - return callback(); + const rateLimit = await basicAuthLimiter.get(ip); + + if (rateLimit !== null && rateLimit.consumedPoints > basicAuthLimiter.points) { + throw rateLimit; + } + + const usePerUserAuth = PER_USER_BASIC_AUTH && ENABLE_ACCOUNTS; + const [username, ...passwordParts] = Buffer.from(credentials, 'base64') + .toString('utf8') + .split(':'); + const password = passwordParts.join(':'); + + if (!usePerUserAuth && username === basicAuthUserName && password === basicAuthUserPassword) { + await basicAuthLimiter.delete(ip); + return callback(); + } else if (usePerUserAuth) { + const userHandles = await getAllUserHandles(); + for (const userHandle of userHandles) { + if (username === userHandle) { + const user = await storage.getItem(toKey(userHandle)); + if (user && user.enabled && (user.password && user.password === getPasswordHash(password, user.salt))) { + await basicAuthLimiter.delete(ip); + return callback(); + } } } } + + await basicAuthLimiter.consume(ip); + return unauthorizedResponse(response); + } catch (error) { + if (error instanceof RateLimiterRes) { + console.error('Basic auth failed: Rate limited from', getIpAddress(request, PREFER_REAL_IP_HEADER), request.method, request.originalUrl); + return retryAfter(response, error).sendStatus(429); + } + console.error('Basic auth error:', error); + return response.sendStatus(500); } - return unauthorizedResponse(response); }; export default basicAuthMiddleware; diff --git a/src/middleware/whitelist.js b/src/middleware/whitelist.js index 98a9aa473..cdf6527c4 100644 --- a/src/middleware/whitelist.js +++ b/src/middleware/whitelist.js @@ -6,7 +6,7 @@ import Handlebars from 'handlebars'; import ipMatching from 'ip-matching'; import isDocker from 'is-docker'; -import { filterValidIpPatterns, getIpFromRequest } from '../express-common.js'; +import { filterValidIpPatterns, getIpFromRequest, getRealOrForwardedIp } from '../express-common.js'; import { color, getConfigValue, safeReadFileSync } from '../util.js'; const whitelistPath = path.join(process.cwd(), './whitelist.txt'); @@ -28,31 +28,6 @@ if (fs.existsSync(whitelistPath)) { whitelist = filterValidIpPatterns(whitelist, (entry, message) => `${color.red('Warning')}: Ignoring invalid whitelist entry ${color.yellow(entry)} - ${message}`); -/** - * Get the client IP address from the request headers. - * @param {import('express').Request} req Express request object - * @returns {string|undefined} The client IP address - */ -function getForwardedIp(req) { - if (!enableForwardedWhitelist) { - return undefined; - } - - // Check if X-Real-IP is available - if (req.headers['x-real-ip']) { - return req.headers['x-real-ip'].toString(); - } - - // Check for X-Forwarded-For and parse if available - if (req.headers['x-forwarded-for']) { - const ipList = req.headers['x-forwarded-for'].toString().split(',').map(ip => ip.trim()); - return ipList[0]; - } - - // If none of the headers are available, return undefined - return undefined; -} - /** * Resolves the IP addresses of Docker hostnames and adds them to the whitelist. * @returns {Promise} Promise that resolves when the Docker hostnames are resolved @@ -92,7 +67,7 @@ export default async function getWhitelistMiddleware() { return function (req, res, next) { const clientIp = getIpFromRequest(req); - const forwardedIp = getForwardedIp(req); + const forwardedIp = enableForwardedWhitelist && getRealOrForwardedIp(req); const userAgent = req.headers['user-agent']; /** @@ -107,7 +82,7 @@ export default async function getWhitelistMiddleware() { //clientIp = req.connection.remoteAddress.split(':').pop(); if (!isIPInWhitelist(whitelist, clientIp) - || forwardedIp && !isIPInWhitelist(whitelist, forwardedIp) + || (forwardedIp && !isIPInWhitelist(whitelist, forwardedIp)) ) { // Log the connection attempt with real IP address const ipDetails = forwardedIp