diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index a39974e16..8e506ec65 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -372,7 +372,7 @@ async function sendClaudeRequest(request, response) { if (request.body.stream) { // Pipe remote SSE stream to Express response - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const generateResponseText = await generateResponse.text(); @@ -682,7 +682,7 @@ async function sendMakerSuiteRequest(request, response) { if (stream) { try { // Pipe remote SSE stream to Express response - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } catch (error) { console.error('Error forwarding streaming response:', error); if (!response.headersSent) { @@ -793,7 +793,7 @@ async function sendAI21Request(request, response) { try { const generateResponse = await fetch(API_AI21 + '/chat/completions', options); if (request.body.stream) { - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const errorText = await generateResponse.text(); @@ -883,7 +883,7 @@ async function sendMistralAIRequest(request, response) { const generateResponse = await fetch(apiUrl + '/chat/completions', config); if (request.body.stream) { - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const errorText = await generateResponse.text(); @@ -982,7 +982,7 @@ async function sendCohereRequest(request, response) { if (request.body.stream) { const stream = await fetch(apiUrl, config); - forwardFetchResponse(stream, response); + await forwardFetchResponse(stream, response); } else { const generateResponse = await fetch(apiUrl, config); if (!generateResponse.ok) { @@ -1093,7 +1093,7 @@ async function sendDeepSeekRequest(request, response) { const generateResponse = await fetch(apiUrl + '/chat/completions', config); if (request.body.stream) { - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const errorText = await generateResponse.text(); @@ -1199,7 +1199,7 @@ async function sendXaiRequest(request, response) { const generateResponse = await fetch(apiUrl + '/chat/completions', config); if (request.body.stream) { - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const errorText = await generateResponse.text(); @@ -1304,7 +1304,7 @@ async function sendAimlapiRequest(request, response) { const generateResponse = await fetch(apiUrl + '/chat/completions', config); if (request.body.stream) { - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const errorText = await generateResponse.text(); @@ -1416,7 +1416,7 @@ async function sendElectronHubRequest(request, response) { const generateResponse = await fetch(apiUrl + '/chat/completions', config); if (request.body.stream) { - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const errorText = await generateResponse.text(); @@ -1517,7 +1517,7 @@ async function sendChutesRequest(request, response) { const generateResponse = await fetch(apiUrl + '/chat/completions', config); if (request.body.stream) { - forwardFetchResponse(generateResponse, response); + await forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { const errorText = await generateResponse.text(); @@ -1612,7 +1612,7 @@ async function sendAzureOpenAIRequest(request, response) { const fetchResponse = await fetch(endpointUrl, config); if (request.body.stream) { - return forwardFetchResponse(fetchResponse, response); + return await forwardFetchResponse(fetchResponse, response); } if (fetchResponse.ok) { @@ -2411,7 +2411,7 @@ router.post('/generate', async function (request, response) { if (request.body.stream) { console.info('Streaming request in progress'); - return forwardFetchResponse(fetchResponse, response); + return await forwardFetchResponse(fetchResponse, response); } if (fetchResponse.ok) { diff --git a/src/endpoints/backends/kobold.js b/src/endpoints/backends/kobold.js index 5968a7cf8..895263199 100644 --- a/src/endpoints/backends/kobold.js +++ b/src/endpoints/backends/kobold.js @@ -99,7 +99,7 @@ router.post('/generate', async function (request, response_generate) { if (request.body.streaming) { // Pipe remote SSE stream to Express response - forwardFetchResponse(response, response_generate); + await forwardFetchResponse(response, response_generate); return; } else { if (!response.ok) { diff --git a/src/endpoints/backends/text-completions.js b/src/endpoints/backends/text-completions.js index 5821af87b..c92fc3a14 100644 --- a/src/endpoints/backends/text-completions.js +++ b/src/endpoints/backends/text-completions.js @@ -404,7 +404,7 @@ router.post('/generate', async function (request, response) { } else if (request.body.stream) { const completionsStream = await fetch(url, args); // Pipe remote SSE stream to Express response - forwardFetchResponse(completionsStream, response); + await forwardFetchResponse(completionsStream, response); } else { const completionsReply = await fetch(url, args); diff --git a/src/endpoints/novelai.js b/src/endpoints/novelai.js index 831eb3ae3..39f65e751 100644 --- a/src/endpoints/novelai.js +++ b/src/endpoints/novelai.js @@ -270,7 +270,7 @@ router.post('/generate', async function (req, res) { if (req.body.streaming) { // Pipe remote SSE stream to Express response - forwardFetchResponse(response, res); + await forwardFetchResponse(response, res); } else { if (!response.ok) { const text = await response.text(); diff --git a/src/endpoints/speech.js b/src/endpoints/speech.js index af7027e41..0140532cd 100644 --- a/src/endpoints/speech.js +++ b/src/endpoints/speech.js @@ -264,7 +264,7 @@ elevenlabs.post('/synthesize', async (req, res) => { } res.set('Content-Type', 'audio/mpeg'); - forwardFetchResponse(response, res); + await forwardFetchResponse(response, res); } catch (error) { console.error(error); return res.sendStatus(500); @@ -328,7 +328,7 @@ elevenlabs.post('/history-audio', async (req, res) => { } res.set('Content-Type', 'audio/mpeg'); - forwardFetchResponse(response, res); + await forwardFetchResponse(response, res); } catch (error) { console.error(error); return res.sendStatus(500); diff --git a/src/middleware/corsProxy.js b/src/middleware/corsProxy.js index 3b9f1fa9b..3235ae92f 100644 --- a/src/middleware/corsProxy.js +++ b/src/middleware/corsProxy.js @@ -35,7 +35,7 @@ export default async function corsProxyMiddleware(req, res) { }); // Copy over relevant response params to the proxy response - forwardFetchResponse(response, res); + await forwardFetchResponse(response, res); } catch (error) { res.status(500).send('Error occurred while trying to proxy to: ' + url + ' ' + error); } diff --git a/src/util.js b/src/util.js index 21a1ffc6c..ea3e50b21 100644 --- a/src/util.js +++ b/src/util.js @@ -704,15 +704,12 @@ export function getImages(directoryPath, sortBy = 'name', type = MEDIA_REQUEST_T * Pipe a fetch() response to an Express.js Response, including status code. * @param {import('node-fetch').Response} from The Fetch API response to pipe from. * @param {import('express').Response} to The Express response to pipe to. + * @returns {Promise} */ -export function forwardFetchResponse(from, to) { +export async function forwardFetchResponse(from, to) { let statusCode = from.status; let statusText = from.statusText; - if (!from.ok) { - console.warn(`Streaming request failed with status ${statusCode} ${statusText}`); - } - // Avoid sending 401 responses as they reset the client Basic auth. // This can produce an interesting artifact as "400 Unauthorized", but it's not out of spec. // https://www.rfc-editor.org/rfc/rfc9110.html#name-overview-of-status-codes @@ -725,6 +722,21 @@ export function forwardFetchResponse(from, to) { to.statusCode = statusCode; to.statusMessage = statusText; + if (!from.ok) { + try { + const rawErrorText = await from.text(); + const detail = rawErrorText || 'Unknown error occurred'; + + console.warn(`Streaming request failed with status ${from.status} ${statusText}: ${detail}`); + to.end(rawErrorText, 'utf-8'); + } catch { + console.warn(`Streaming request failed with status ${from.status} ${statusText}: Unknown error occurred`); + to.end(); + } + + return; + } + if (from.body && to.socket) { from.body.pipe(to); diff --git a/tests/util.test.js b/tests/util.test.js index 8ff86ffb9..47af1248b 100644 --- a/tests/util.test.js +++ b/tests/util.test.js @@ -1,6 +1,31 @@ -import { describe, test, expect } from '@jest/globals'; +import { afterEach, describe, test, expect, jest } from '@jest/globals'; +import { once } from 'node:events'; +import { PassThrough } from 'node:stream'; +import { Response } from 'node-fetch'; import { CHAT_COMPLETION_SOURCES } from '../src/constants'; -import { flattenSchema } from '../src/util'; +import { flattenSchema, forwardFetchResponse } from '../src/util'; + +function createMockExpressResponse() { + const response = new PassThrough(); + response.statusCode = 200; + response.statusMessage = ''; + + return response; +} + +async function collectResponseBody(response) { + const chunks = []; + + response.on('data', chunk => chunks.push(Buffer.from(chunk))); + + await once(response, 'finish'); + + return Buffer.concat(chunks).toString('utf8'); +} + +afterEach(() => { + jest.restoreAllMocks(); +}); describe('flattenSchema', () => { test('should return the schema if it is not an object', () => { @@ -105,3 +130,37 @@ describe('flattenSchema', () => { expect(flattenSchema(schema, 'some-other-api')).toEqual(expected); }); }); + +describe('forwardFetchResponse', () => { + test('should log JSON error bodies and return the original body for non-2xx streaming responses', async () => { + const warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => undefined); + const body = JSON.stringify({ error: { message: 'Forbidden by upstream policy' }, detail: 'policy_denied' }); + const response = createMockExpressResponse(); + const bodyPromise = collectResponseBody(response); + + await forwardFetchResponse(new Response(body, { + status: 403, + statusText: 'Forbidden', + }), response); + + expect(await bodyPromise).toBe(body); + expect(response.statusCode).toBe(403); + expect(warnSpy).toHaveBeenCalledWith(`Streaming request failed with status 403 Forbidden: ${body}`); + }); + + test('should log plain text error bodies and return the original body for non-2xx streaming responses', async () => { + const warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => undefined); + const body = 'Plain text upstream failure'; + const response = createMockExpressResponse(); + const bodyPromise = collectResponseBody(response); + + await forwardFetchResponse(new Response(body, { + status: 502, + statusText: 'Bad Gateway', + }), response); + + expect(await bodyPromise).toBe(body); + expect(response.statusCode).toBe(502); + expect(warnSpy).toHaveBeenCalledWith(`Streaming request failed with status 502 Bad Gateway: ${body}`); + }); +});