diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js index 58483bb7f7..c1ae54fdf0 100644 --- a/api/app/clients/ChatGPTClient.js +++ b/api/app/clients/ChatGPTClient.js @@ -166,6 +166,12 @@ class ChatGPTClient extends BaseClient { console.debug(modelOptions); console.debug(); } + + if (this.azure || this.options.azure) { + // Azure does not accept `model` in the body, so we need to remove it. + delete modelOptions.model; + } + const opts = { method: 'POST', headers: { diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 57ffd58346..5894329fc1 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -104,7 +104,7 @@ class OpenAIClient extends BaseClient { const { model } = this.modelOptions; - this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt-'); + this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt'); this.isChatGptModel = this.isChatCompletion; if ( model.includes('text-davinci') || diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index 6e9b383de7..889499fbc2 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -1,19 +1,33 @@ const { initializeFakeClient } = require('./FakeClient'); jest.mock('../../../lib/db/connectDb'); -jest.mock('../../../models', () => { - return function () { - return { - save: jest.fn(), - deleteConvos: jest.fn(), - getConvo: jest.fn(), - getMessages: jest.fn(), - saveMessage: jest.fn(), - updateMessage: jest.fn(), - saveConvo: jest.fn(), - }; - }; -}); +jest.mock('~/models', () => ({ + User: jest.fn(), + Key: jest.fn(), + Session: jest.fn(), + Balance: jest.fn(), + Transaction: jest.fn(), + getMessages: jest.fn().mockResolvedValue([]), + saveMessage: jest.fn(), + updateMessage: jest.fn(), + deleteMessagesSince: jest.fn(), + deleteMessages: jest.fn(), + getConvoTitle: jest.fn(), + getConvo: jest.fn(), + saveConvo: jest.fn(), + deleteConvos: jest.fn(), + getPreset: jest.fn(), + getPresets: jest.fn(), + savePreset: jest.fn(), + deletePresets: jest.fn(), + findFileById: jest.fn(), + createFile: jest.fn(), + updateFile: jest.fn(), + deleteFile: jest.fn(), + deleteFiles: jest.fn(), + getFiles: jest.fn(), + updateFileUsage: jest.fn(), +})); jest.mock('langchain/chat_models/openai', () => { return { diff --git a/api/app/clients/specs/FakeClient.js b/api/app/clients/specs/FakeClient.js index cc6a54d3db..a5915adcf2 100644 --- a/api/app/clients/specs/FakeClient.js +++ b/api/app/clients/specs/FakeClient.js @@ -42,7 +42,6 @@ class FakeClient extends BaseClient { this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 4097; } - getCompletion() {} buildMessages() {} getTokenCount(str) { return str.length; @@ -86,6 +85,19 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { return 'Mock response text'; }); + // eslint-disable-next-line no-unused-vars + TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => { + return { + choices: [ + { + message: { + content: 'Mock response text', + }, + }, + ], + }; + }); + TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => { const orderedMessages = TestClient.constructor.getMessagesForConversation({ messages, diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 3fbf75f574..a4f06631f2 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -1,8 +1,46 @@ require('dotenv').config(); +const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); +const { genAzureChatCompletion } = require('~/utils/azureUtils'); const OpenAIClient = require('../OpenAIClient'); - jest.mock('meilisearch'); +jest.mock('~/lib/db/connectDb'); +jest.mock('~/models', () => ({ + User: jest.fn(), + Key: jest.fn(), + Session: jest.fn(), + Balance: jest.fn(), + Transaction: jest.fn(), + getMessages: jest.fn().mockResolvedValue([]), + saveMessage: jest.fn(), + updateMessage: jest.fn(), + deleteMessagesSince: jest.fn(), + deleteMessages: jest.fn(), + getConvoTitle: jest.fn(), + getConvo: jest.fn(), + saveConvo: jest.fn(), + deleteConvos: jest.fn(), + getPreset: jest.fn(), + getPresets: jest.fn(), + savePreset: jest.fn(), + deletePresets: jest.fn(), + findFileById: jest.fn(), + createFile: jest.fn(), + updateFile: jest.fn(), + deleteFile: jest.fn(), + deleteFiles: jest.fn(), + getFiles: jest.fn(), + updateFileUsage: jest.fn(), +})); + +jest.mock('langchain/chat_models/openai', () => { + return { + ChatOpenAI: jest.fn().mockImplementation(() => { + return {}; + }), + }; +}); + describe('OpenAIClient', () => { let client, client2; const model = 'gpt-4'; @@ -12,6 +50,21 @@ describe('OpenAIClient', () => { { role: 'assistant', sender: 'Assistant', text: 'Hi', messageId: '2' }, ]; + const defaultOptions = { + // debug: true, + openaiApiKey: 'new-api-key', + modelOptions: { + model, + temperature: 0.7, + }, + }; + + const defaultAzureOptions = { + azureOpenAIApiInstanceName: 'your-instance-name', + azureOpenAIApiDeploymentName: 'your-deployment-name', + azureOpenAIApiVersion: '2020-07-01-preview', + }; + beforeAll(() => { jest.spyOn(console, 'warn').mockImplementation(() => {}); }); @@ -21,14 +74,7 @@ describe('OpenAIClient', () => { }); beforeEach(() => { - const options = { - // debug: true, - openaiApiKey: 'new-api-key', - modelOptions: { - model, - temperature: 0.7, - }, - }; + const options = { ...defaultOptions }; client = new OpenAIClient('test-api-key', options); client2 = new OpenAIClient('test-api-key', options); client.summarizeMessages = jest.fn().mockResolvedValue({ @@ -40,6 +86,7 @@ describe('OpenAIClient', () => { .fn() .mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') }); client.constructor.freeAndResetAllEncoders(); + client.getMessages = jest.fn().mockResolvedValue([]); }); describe('setOptions', () => { @@ -408,4 +455,46 @@ describe('OpenAIClient', () => { }); }); }); + + describe('sendMessage/getCompletion', () => { + afterEach(() => { + delete process.env.AZURE_OPENAI_DEFAULT_MODEL; + delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME; + }); + + it('[Azure OpenAI] should call getCompletion and fetchEventSource with correct args', async () => { + // Set a default model + process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt4-turbo'; + + const onProgress = jest.fn().mockImplementation(() => ({})); + client.azure = defaultAzureOptions; + const getCompletion = jest.spyOn(client, 'getCompletion'); + await client.sendMessage('Hi mom!', { + replaceOptions: true, + ...defaultOptions, + onProgress, + azure: defaultAzureOptions, + }); + + expect(getCompletion).toHaveBeenCalled(); + expect(getCompletion.mock.calls.length).toBe(1); + expect(getCompletion.mock.calls[0][0][0].role).toBe('user'); + expect(getCompletion.mock.calls[0][0][0].content).toBe('Hi mom!'); + + expect(fetchEventSource).toHaveBeenCalled(); + expect(fetchEventSource.mock.calls.length).toBe(1); + + // Check if the first argument (url) is correct + const expectedURL = genAzureChatCompletion(defaultAzureOptions); + const firstCallArgs = fetchEventSource.mock.calls[0]; + + expect(firstCallArgs[0]).toBe(expectedURL); + // Should not have model in the deployment name + expect(firstCallArgs[0]).not.toContain('gpt4-turbo'); + + // Should not include the model in request body + const requestBody = JSON.parse(firstCallArgs[1].body); + expect(requestBody).not.toHaveProperty('model'); + }); + }); }); diff --git a/api/config/parsers.js b/api/config/parsers.js index e9b3c99448..f94c38da75 100644 --- a/api/config/parsers.js +++ b/api/config/parsers.js @@ -3,7 +3,7 @@ const winston = require('winston'); const traverse = require('traverse'); const { klona } = require('klona/full'); -const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/]; +const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/, /api-key: \w+/]; /** * Determines if a given key string is sensitive. diff --git a/api/jest.config.js b/api/jest.config.js index f060c0c977..ec44bd7f56 100644 --- a/api/jest.config.js +++ b/api/jest.config.js @@ -7,6 +7,7 @@ module.exports = { './test/jestSetup.js', './test/__mocks__/KeyvMongo.js', './test/__mocks__/logger.js', + './test/__mocks__/fetchEventSource.js', ], moduleNameMapper: { '~/(.*)': '/$1', diff --git a/api/server/services/Endpoints/openAI/addTitle.js b/api/server/services/Endpoints/openAI/addTitle.js index 3197374a68..f630638643 100644 --- a/api/server/services/Endpoints/openAI/addTitle.js +++ b/api/server/services/Endpoints/openAI/addTitle.js @@ -7,8 +7,8 @@ const addTitle = async (req, { text, response, client }) => { return; } - // If the request was aborted, don't generate the title. - if (client.abortController.signal.aborted) { + // If the request was aborted and is not azure, don't generate the title. + if (!client.azure && client.abortController.signal.aborted) { return; } diff --git a/api/test/__mocks__/fetchEventSource.js b/api/test/__mocks__/fetchEventSource.js new file mode 100644 index 0000000000..8f6d3cc575 --- /dev/null +++ b/api/test/__mocks__/fetchEventSource.js @@ -0,0 +1,27 @@ +jest.mock('@waylaidwanderer/fetch-event-source', () => ({ + fetchEventSource: jest + .fn() + .mockImplementation((url, { onopen, onmessage, onclose, onerror, error }) => { + // Simulating the onopen event + onopen && onopen({ status: 200 }); + + // Simulating a few onmessage events + onmessage && + onmessage({ data: JSON.stringify({ message: 'First message' }), event: 'message' }); + onmessage && + onmessage({ data: JSON.stringify({ message: 'Second message' }), event: 'message' }); + onmessage && + onmessage({ data: JSON.stringify({ message: 'Third message' }), event: 'message' }); + + // Simulate the onclose event + onclose && onclose(); + + if (error) { + // Simulate the onerror event + onerror && onerror({ status: 500 }); + } + + // Return a Promise that resolves to simulate async behavior + return Promise.resolve(); + }), +})); diff --git a/api/utils/azureUtils.js b/api/utils/azureUtils.js index 3c4a891bec..a735a6b4f7 100644 --- a/api/utils/azureUtils.js +++ b/api/utils/azureUtils.js @@ -6,7 +6,7 @@ * @property {string} azureOpenAIApiVersion - The Azure OpenAI API version. */ -const { isEnabled } = require('../server/utils'); +const { isEnabled } = require('~/server/utils'); /** * Sanitizes the model name to be used in the URL by removing or replacing disallowed characters. diff --git a/client/src/components/Chat/Menus/EndpointsMenu.tsx b/client/src/components/Chat/Menus/EndpointsMenu.tsx index 5a3a61b713..a4b5ed439a 100644 --- a/client/src/components/Chat/Menus/EndpointsMenu.tsx +++ b/client/src/components/Chat/Menus/EndpointsMenu.tsx @@ -1,5 +1,5 @@ -import { Content, Portal, Root } from '@radix-ui/react-popover'; import { alternateName } from 'librechat-data-provider'; +import { Content, Portal, Root } from '@radix-ui/react-popover'; import { useGetEndpointsQuery } from 'librechat-data-provider/react-query'; import type { FC } from 'react'; import EndpointItems from './Endpoints/MenuItems'; @@ -14,9 +14,14 @@ const EndpointsMenu: FC = () => { const { conversation } = useChatContext(); const selected = conversation?.endpoint ?? ''; + + if (!selected) { + console.warn('No endpoint selected'); + return null; + } return ( - +
dataService.login(payload), { onMutate: () => { queryClient.removeQueries(); + localStorage.removeItem('lastConversationSetup'); + localStorage.removeItem('lastSelectedModel'); + localStorage.removeItem('lastSelectedTools'); + localStorage.removeItem('filesToDelete'); + localStorage.removeItem('lastAssistant'); }, }); }; @@ -375,11 +380,6 @@ export const useRefreshTokenMutation = (): UseMutationResult< return useMutation(() => request.refreshToken(), { onMutate: () => { queryClient.removeQueries(); - localStorage.removeItem('lastConversationSetup'); - localStorage.removeItem('lastSelectedModel'); - localStorage.removeItem('lastSelectedTools'); - localStorage.removeItem('filesToDelete'); - localStorage.removeItem('lastAssistant'); }, }); };