From 8c22bb1d3de87c97ab8d43b7bd94f2c170672a3d Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 20 Apr 2024 15:02:56 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20fix(Azure/Assistants):?= =?UTF-8?q?=20Handle=20Long=20Domain=20Names=20&=20Other=20Minor=20chores?= =?UTF-8?q?=20(#2475)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: replace violation cache accessors with enum * chore: fix test * chore(fileSchema): index timestamps * fix(ActionService): use encoding/caching strategy for handling assistant function character length limit * refactor(actions): async `domainParser` also resolve retrieved model (which is deployment name) to user-defined model * style(AssistantAction): add `whitespace-nowrap` for ellipsis * refactor(ActionService): if domain is less than or equal to encoded domain fixed length, return domain with replacement of separator * refactor(actions): use sessions/transactions for updating Assistant Action database records * chore: remove TTL from ENCODED_DOMAINS cache * refactor(domainParser): minor optimization and add tests * fix(spendTokens): use txData.user for token usage logging * refactor(actions): add helper function `withSession` for database operations with sessions/transactions * fix(PluginsClient): logger debug `message` field edge case --- api/app/clients/PluginsClient.js | 2 +- api/cache/banViolation.js | 7 +- api/cache/banViolation.spec.js | 3 +- api/cache/getLogStores.js | 12 +- api/models/Action.js | 23 +- api/models/Assistant.js | 11 +- api/models/schema/fileSchema.js | 2 + api/models/spendTokens.js | 2 +- api/server/middleware/checkBan.js | 15 +- api/server/routes/assistants/actions.js | 30 ++- api/server/services/ActionService.js | 53 ++++- api/server/services/ActionService.spec.js | 196 ++++++++++++++++++ api/server/services/ToolService.js | 19 +- api/server/utils/index.js | 2 + api/server/utils/mongoose.js | 25 +++ .../Chat/Messages/Content/ToolCall.tsx | 4 +- .../Messages/Content/MessageContent.tsx | 3 +- .../SidePanel/Builder/AssistantAction.tsx | 2 +- packages/data-provider/src/config.ts | 17 ++ 19 files changed, 365 insertions(+), 63 deletions(-) create mode 100644 api/server/services/ActionService.spec.js create mode 100644 api/server/utils/mongoose.js diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 033c122664..7ba530b885 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -244,7 +244,7 @@ class PluginsClient extends OpenAIClient { this.setOptions(opts); return super.sendMessage(message, opts); } - logger.debug('[PluginsClient] sendMessage', { message, opts }); + logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts }); const { user, isEdited, diff --git a/api/cache/banViolation.js b/api/cache/banViolation.js index 3d67e57872..1d86007638 100644 --- a/api/cache/banViolation.js +++ b/api/cache/banViolation.js @@ -1,6 +1,7 @@ -const Session = require('~/models/Session'); -const getLogStores = require('./getLogStores'); +const { ViolationTypes } = require('librechat-data-provider'); const { isEnabled, math, removePorts } = require('~/server/utils'); +const getLogStores = require('./getLogStores'); +const Session = require('~/models/Session'); const { logger } = require('~/config'); const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {}; @@ -48,7 +49,7 @@ const banViolation = async (req, res, errorMessage) => { await Session.deleteAllUserSessions(user_id); res.clearCookie('refreshToken'); - const banLogs = getLogStores('ban'); + const banLogs = getLogStores(ViolationTypes.BAN); const duration = errorMessage.duration || banLogs.opts.ttl; if (duration <= 0) { diff --git a/api/cache/banViolation.spec.js b/api/cache/banViolation.spec.js index ba8e78a1ed..8fef16920f 100644 --- a/api/cache/banViolation.spec.js +++ b/api/cache/banViolation.spec.js @@ -6,6 +6,7 @@ jest.mock('../models/Session'); jest.mock('./getLogStores', () => { return jest.fn().mockImplementation(() => { const EventEmitter = require('events'); + const { CacheKeys } = require('librechat-data-provider'); const math = require('../server/utils/math'); const mockGet = jest.fn(); const mockSet = jest.fn(); @@ -33,7 +34,7 @@ jest.mock('./getLogStores', () => { } return new KeyvMongo('', { - namespace: 'bans', + namespace: CacheKeys.BANS, ttl: math(process.env.BAN_DURATION, 7200000), }); }); diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 786bb1f1f7..0d9b662e4e 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -6,6 +6,7 @@ const keyvRedis = require('./keyvRedis'); const keyvMongo = require('./keyvMongo'); const { BAN_DURATION, USE_REDIS } = process.env ?? {}; +const THIRTY_MINUTES = 1800000; const duration = math(BAN_DURATION, 7200000); @@ -24,8 +25,8 @@ const config = isEnabled(USE_REDIS) : new Keyv({ namespace: CacheKeys.CONFIG_STORE }); const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes - ? new Keyv({ store: keyvRedis, ttl: 1800000 }) - : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: 1800000 }); + ? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES }) + : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES }); const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes ? new Keyv({ store: keyvRedis, ttl: 120000 }) @@ -42,7 +43,12 @@ const abortKeys = isEnabled(USE_REDIS) const namespaces = { [CacheKeys.CONFIG_STORE]: config, pending_req, - ban: new Keyv({ store: keyvMongo, namespace: 'bans', ttl: duration }), + [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }), + [CacheKeys.ENCODED_DOMAINS]: new Keyv({ + store: keyvMongo, + namespace: CacheKeys.ENCODED_DOMAINS, + ttl: 0, + }), general: new Keyv({ store: logFile, namespace: 'violations' }), concurrent: createViolationInstance('concurrent'), non_browser: createViolationInstance('non_browser'), diff --git a/api/models/Action.js b/api/models/Action.js index 5141569c10..9acac078b9 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -5,19 +5,18 @@ const Action = mongoose.model('action', actionSchema); /** * Update an action with new data without overwriting existing properties, - * or create a new action if it doesn't exist. + * or create a new action if it doesn't exist, within a transaction session if provided. * * @param {Object} searchParams - The search parameters to find the action to update. * @param {string} searchParams.action_id - The ID of the action to update. * @param {string} searchParams.user - The user ID of the action's author. * @param {Object} updateData - An object containing the properties to update. + * @param {mongoose.ClientSession} [session] - The transaction session to use. * @returns {Promise} The updated or newly created action document as a plain object. */ -const updateAction = async (searchParams, updateData) => { - return await Action.findOneAndUpdate(searchParams, updateData, { - new: true, - upsert: true, - }).lean(); +const updateAction = async (searchParams, updateData, session = null) => { + const options = { new: true, upsert: true, session }; + return await Action.findOneAndUpdate(searchParams, updateData, options).lean(); }; /** @@ -50,15 +49,17 @@ const getActions = async (searchParams, includeSensitive = false) => { }; /** - * Deletes an action by its ID. + * Deletes an action by params, within a transaction session if provided. * - * @param {Object} searchParams - The search parameters to find the action to update. - * @param {string} searchParams.action_id - The ID of the action to update. + * @param {Object} searchParams - The search parameters to find the action to delete. + * @param {string} searchParams.action_id - The ID of the action to delete. * @param {string} searchParams.user - The user ID of the action's author. + * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} A promise that resolves to the deleted action document as a plain object, or null if no document was found. */ -const deleteAction = async (searchParams) => { - return await Action.findOneAndDelete(searchParams).lean(); +const deleteAction = async (searchParams, session = null) => { + const options = session ? { session } : {}; + return await Action.findOneAndDelete(searchParams, options).lean(); }; module.exports = { diff --git a/api/models/Assistant.js b/api/models/Assistant.js index fa6192eee9..17e4077220 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -5,19 +5,18 @@ const Assistant = mongoose.model('assistant', assistantSchema); /** * Update an assistant with new data without overwriting existing properties, - * or create a new assistant if it doesn't exist. + * or create a new assistant if it doesn't exist, within a transaction session if provided. * * @param {Object} searchParams - The search parameters to find the assistant to update. * @param {string} searchParams.assistant_id - The ID of the assistant to update. * @param {string} searchParams.user - The user ID of the assistant's author. * @param {Object} updateData - An object containing the properties to update. + * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} The updated or newly created assistant document as a plain object. */ -const updateAssistant = async (searchParams, updateData) => { - return await Assistant.findOneAndUpdate(searchParams, updateData, { - new: true, - upsert: true, - }).lean(); +const updateAssistant = async (searchParams, updateData, session = null) => { + const options = { new: true, upsert: true, session }; + return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean(); }; /** diff --git a/api/models/schema/fileSchema.js b/api/models/schema/fileSchema.js index 93a8815e53..2075538b1d 100644 --- a/api/models/schema/fileSchema.js +++ b/api/models/schema/fileSchema.js @@ -99,4 +99,6 @@ const fileSchema = mongoose.Schema( }, ); +fileSchema.index({ createdAt: 1, updatedAt: 1 }); + module.exports = fileSchema; diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index e37aa41d0c..830cda2075 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -54,7 +54,7 @@ const spendTokens = async (txData, tokenUsage) => { prompt && completion && logger.debug('[spendTokens] Transaction data record against balance:', { - user: prompt.user, + user: txData.user, prompt: prompt.prompt, promptRate: prompt.rate, completion: completion.completion, diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js index a7eab87bdf..aa322cd1c2 100644 --- a/api/server/middleware/checkBan.js +++ b/api/server/middleware/checkBan.js @@ -1,14 +1,15 @@ const Keyv = require('keyv'); const uap = require('ua-parser-js'); -const denyRequest = require('./denyRequest'); -const { getLogStores } = require('../../cache'); +const { ViolationTypes } = require('librechat-data-provider'); const { isEnabled, removePorts } = require('../utils'); -const keyvRedis = require('../../cache/keyvRedis'); -const User = require('../../models/User'); +const keyvRedis = require('~/cache/keyvRedis'); +const denyRequest = require('./denyRequest'); +const { getLogStores } = require('~/cache'); +const User = require('~/models/User'); const banCache = isEnabled(process.env.USE_REDIS) ? new Keyv({ store: keyvRedis }) - : new Keyv({ namespace: 'bans', ttl: 0 }); + : new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 }); const message = 'Your account has been temporarily banned due to violations of our service.'; /** @@ -28,7 +29,7 @@ const banResponse = async (req, res) => { if (!ua.browser.name) { return res.status(403).json({ message }); } else if (baseUrl === '/api/ask' || baseUrl === '/api/edit') { - return await denyRequest(req, res, { type: 'ban' }); + return await denyRequest(req, res, { type: ViolationTypes.BAN }); } return res.status(403).json({ message }); @@ -87,7 +88,7 @@ const checkBan = async (req, res, next = () => {}) => { return await banResponse(req, res); } - const banLogs = getLogStores('ban'); + const banLogs = getLogStores(ViolationTypes.BAN); const duration = banLogs.opts.ttl; if (duration <= 0) { diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index 33db6ce803..711e224c6e 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -1,10 +1,11 @@ const { v4 } = require('uuid'); const express = require('express'); -const { actionDelimiter } = require('librechat-data-provider'); -const { initializeClient } = require('~/server/services/Endpoints/assistants'); const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider'); +const { initializeClient } = require('~/server/services/Endpoints/assistants'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAssistant, getAssistant } = require('~/models/Assistant'); +const { withSession } = require('~/server/utils'); const { logger } = require('~/config'); const router = express.Router(); @@ -46,7 +47,7 @@ router.post('/:assistant_id', async (req, res) => { let { domain } = metadata; /* Azure doesn't support periods in function names */ - domain = domainParser(req, domain, true); + domain = await domainParser(req, domain, true); if (!domain) { return res.status(400).json({ message: 'No domain provided' }); @@ -110,7 +111,8 @@ router.post('/:assistant_id', async (req, res) => { const promises = []; promises.push( - updateAssistant( + withSession( + updateAssistant, { assistant_id }, { actions, @@ -119,7 +121,9 @@ router.post('/:assistant_id', async (req, res) => { ), ); promises.push(openai.beta.assistants.update(assistant_id, { tools })); - promises.push(updateAction({ action_id }, { metadata, assistant_id, user: req.user.id })); + promises.push( + withSession(updateAction, { action_id }, { metadata, assistant_id, user: req.user.id }), + ); /** @type {[AssistantDocument, Assistant, Action]} */ const resolved = await Promise.all(promises); @@ -129,6 +133,15 @@ router.post('/:assistant_id', async (req, res) => { delete resolved[2].metadata[field]; } } + + /* Map Azure OpenAI model to the assistant as defined by config */ + if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) { + resolved[1] = { + ...resolved[1], + model: req.body.model, + }; + } + res.json(resolved); } catch (error) { const message = 'Trouble updating the Assistant Action'; @@ -171,7 +184,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { return true; }); - domain = domainParser(req, domain, true); + domain = await domainParser(req, domain, true); const updatedTools = tools.filter( (tool) => !(tool.function && tool.function.name.includes(domain)), @@ -179,7 +192,8 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { const promises = []; promises.push( - updateAssistant( + withSession( + updateAssistant, { assistant_id }, { actions: updatedActions, @@ -188,7 +202,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { ), ); promises.push(openai.beta.assistants.update(assistant_id, { tools: updatedTools })); - promises.push(deleteAction({ action_id })); + promises.push(withSession(deleteAction, { action_id })); await Promise.all(promises); res.status(200).json({ message: 'Action deleted successfully' }); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 22770f1550..344a6570ba 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -1,20 +1,27 @@ -const { AuthTypeEnum, EModelEndpoint, actionDomainSeparator } = require('librechat-data-provider'); +const { + AuthTypeEnum, + EModelEndpoint, + actionDomainSeparator, + CacheKeys, + Constants, +} = require('librechat-data-provider'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); const { getActions } = require('~/models/Action'); +const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); /** - * Parses the domain for an action. + * Encodes or decodes a domain name to/from base64, or replacing periods with a custom separator. * - * Azure OpenAI Assistants API doesn't support periods in function - * names due to `[a-zA-Z0-9_-]*` Regex Validation. + * Necessary because Azure OpenAI Assistants API doesn't support periods in function + * names due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum. * - * @param {Express.Request} req - Express Request object - * @param {string} domain - The domain for the actoin - * @param {boolean} inverse - If true, replaces periods with `actionDomainSeparator` - * @returns {string} The parsed domain + * @param {Express.Request} req - The Express Request object. + * @param {string} domain - The domain name to encode/decode. + * @param {boolean} inverse - False to decode from base64, true to encode to base64. + * @returns {Promise} Encoded or decoded domain string. */ -function domainParser(req, domain, inverse = false) { +async function domainParser(req, domain, inverse = false) { if (!domain) { return; } @@ -23,11 +30,35 @@ function domainParser(req, domain, inverse = false) { return domain; } - if (inverse) { + const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); + const cachedDomain = await domainsCache.get(domain); + if (inverse && cachedDomain) { + return domain; + } + + if (inverse && domain.length <= Constants.ENCODED_DOMAIN_LENGTH) { return domain.replace(/\./g, actionDomainSeparator); } - return domain.replace(actionDomainSeparator, '.'); + if (inverse) { + const modifiedDomain = Buffer.from(domain).toString('base64'); + const key = modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH); + await domainsCache.set(key, modifiedDomain); + return key; + } + + const replaceSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + + if (!cachedDomain) { + return domain.replace(replaceSeparatorRegex, '.'); + } + + try { + return Buffer.from(cachedDomain, 'base64').toString('utf-8'); + } catch (error) { + logger.error(`Failed to parse domain (possibly not base64): ${domain}`, error); + return domain; + } } /** diff --git a/api/server/services/ActionService.spec.js b/api/server/services/ActionService.spec.js new file mode 100644 index 0000000000..57f9988961 --- /dev/null +++ b/api/server/services/ActionService.spec.js @@ -0,0 +1,196 @@ +const { Constants, EModelEndpoint, actionDomainSeparator } = require('librechat-data-provider'); +const { domainParser } = require('./ActionService'); + +jest.mock('keyv'); + +const globalCache = {}; +jest.mock('~/cache/getLogStores', () => { + return jest.fn().mockImplementation(() => { + const EventEmitter = require('events'); + const { CacheKeys } = require('librechat-data-provider'); + + class KeyvMongo extends EventEmitter { + constructor(url = 'mongodb://127.0.0.1:27017', options) { + super(); + this.ttlSupport = false; + url = url ?? {}; + if (typeof url === 'string') { + url = { url }; + } + if (url.uri) { + url = { url: url.uri, ...url }; + } + this.opts = { + url, + collection: 'keyv', + ...url, + ...options, + }; + } + + get = async (key) => { + return new Promise((resolve) => { + resolve(globalCache[key] || null); + }); + }; + + set = async (key, value) => { + return new Promise((resolve) => { + globalCache[key] = value; + resolve(true); + }); + }; + } + + return new KeyvMongo('', { + namespace: CacheKeys.ENCODED_DOMAINS, + ttl: 0, + }); + }); +}); + +describe('domainParser', () => { + const req = { + app: { + locals: { + [EModelEndpoint.azureOpenAI]: { + assistants: true, + }, + }, + }, + }; + + const reqNoAzure = { + app: { + locals: { + [EModelEndpoint.azureOpenAI]: { + assistants: false, + }, + }, + }, + }; + + const TLD = '.com'; + + // Non-azure request + it('returns domain as is if not azure', async () => { + const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`; + const result1 = await domainParser(reqNoAzure, domain, false); + const result2 = await domainParser(reqNoAzure, domain, true); + expect(result1).toEqual(domain); + expect(result2).toEqual(domain); + }); + + // Test for Empty or Null Inputs + it('returns undefined for null domain input', async () => { + const result = await domainParser(req, null, true); + expect(result).toBeUndefined(); + }); + + it('returns undefined for empty domain input', async () => { + const result = await domainParser(req, '', true); + expect(result).toBeUndefined(); + }); + + // Verify Correct Caching Behavior + it('caches encoded domain correctly', async () => { + const domain = 'longdomainname.com'; + const encodedDomain = Buffer.from(domain) + .toString('base64') + .substring(0, Constants.ENCODED_DOMAIN_LENGTH); + + await domainParser(req, domain, true); + + const cachedValue = await globalCache[encodedDomain]; + expect(cachedValue).toEqual(Buffer.from(domain).toString('base64')); + }); + + // Test for Edge Cases Around Length Threshold + it('encodes domain exactly at threshold without modification', async () => { + const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD; + const expected = domain.replace(/\./g, actionDomainSeparator); + const result = await domainParser(req, domain, true); + expect(result).toEqual(expected); + }); + + it('encodes domain just below threshold without modification', async () => { + const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD; + const expected = domain.replace(/\./g, actionDomainSeparator); + const result = await domainParser(req, domain, true); + expect(result).toEqual(expected); + }); + + // Test for Unicode Domain Names + it('handles unicode characters in domain names correctly when encoding', async () => { + const unicodeDomain = 'täst.example.com'; + const encodedDomain = Buffer.from(unicodeDomain) + .toString('base64') + .substring(0, Constants.ENCODED_DOMAIN_LENGTH); + const result = await domainParser(req, unicodeDomain, true); + expect(result).toEqual(encodedDomain); + }); + + it('decodes unicode domain names correctly', async () => { + const unicodeDomain = 'täst.example.com'; + const encodedDomain = Buffer.from(unicodeDomain).toString('base64'); + globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching + + const result = await domainParser( + req, + encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH), + false, + ); + expect(result).toEqual(unicodeDomain); + }); + + // Core Functionality Tests + it('returns domain with replaced separators if no cached domain exists', async () => { + const domain = 'example.com'; + const withSeparator = domain.replace(/\./g, actionDomainSeparator); + const result = await domainParser(req, withSeparator, false); + expect(result).toEqual(domain); + }); + + it('returns domain with replaced separators when inverse is false and under encoding length', async () => { + const domain = 'examp.com'; + const withSeparator = domain.replace(/\./g, actionDomainSeparator); + const result = await domainParser(req, withSeparator, false); + expect(result).toEqual(domain); + }); + + it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => { + const domain = 'examp.com'; + const expected = domain.replace(/\./g, actionDomainSeparator); + const result = await domainParser(req, domain, true); + expect(result).toEqual(expected); + }); + + it('encodes domain when length is above threshold and inverse is true', async () => { + const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com'); + const result = await domainParser(req, domain, true); + expect(result).not.toEqual(domain); + expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH); + }); + + it('returns encoded value if no encoded value is cached, and inverse is false', async () => { + const originalDomain = 'example.com'; + const encodedDomain = Buffer.from( + originalDomain.replace(/\./g, actionDomainSeparator), + ).toString('base64'); + const result = await domainParser(req, encodedDomain, false); + expect(result).toEqual(encodedDomain); + }); + + it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => { + const originalDomain = 'example.com'; + const encodedDomain = await domainParser(req, originalDomain, true); + const result = await domainParser(req, encodedDomain, false); + expect(result).toEqual(originalDomain); + }); + + it('handles invalid base64 encoded values gracefully', async () => { + const invalidBase64Domain = 'not_base64_encoded'; + const result = await domainParser(req, invalidBase64Domain, false); + expect(result).toEqual(invalidBase64Domain); + }); +}); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 81c6ca4283..5b131acfd7 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -274,9 +274,16 @@ async function processRequiredActions(client, requiredActions) { })) ?? []; } - const actionSet = actionSets.find((action) => - currentAction.tool.includes(domainParser(client.req, action.metadata.domain, true)), - ); + let actionSet = null; + let currentDomain = ''; + for (let action of actionSets) { + const domain = await domainParser(client.req, action.metadata.domain, true); + if (currentAction.tool.includes(domain)) { + currentDomain = domain; + actionSet = action; + break; + } + } if (!actionSet) { // TODO: try `function` if no action set is found @@ -298,10 +305,8 @@ async function processRequiredActions(client, requiredActions) { builders = requestBuilders; } - const functionName = currentAction.tool.replace( - `${actionDelimiter}${domainParser(client.req, actionSet.metadata.domain, true)}`, - '', - ); + const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, ''); + const requestBuilder = builders[functionName]; if (!requestBuilder) { diff --git a/api/server/utils/index.js b/api/server/utils/index.js index e87a4680fc..d4e39a3ae5 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -5,6 +5,7 @@ const handleText = require('./handleText'); const cryptoUtils = require('./crypto'); const citations = require('./citations'); const sendEmail = require('./sendEmail'); +const mongoose = require('./mongoose'); const queue = require('./queue'); const files = require('./files'); const math = require('./math'); @@ -14,6 +15,7 @@ module.exports = { ...cryptoUtils, ...handleText, ...citations, + ...mongoose, countTokens, removePorts, sendEmail, diff --git a/api/server/utils/mongoose.js b/api/server/utils/mongoose.js new file mode 100644 index 0000000000..652a01c77a --- /dev/null +++ b/api/server/utils/mongoose.js @@ -0,0 +1,25 @@ +const mongoose = require('mongoose'); +/** + * Executes a database operation within a session. + * @param {() => Promise} method - The method to execute. This method must accept a session as its first argument. + * @param {...any} args - Additional arguments to pass to the method. + * @returns {Promise} - The result of the executed method. + */ +async function withSession(method, ...args) { + const session = await mongoose.startSession(); + session.startTransaction(); + try { + const result = await method(...args, session); + await session.commitTransaction(); + return result; + } catch (error) { + if (session.inTransaction()) { + await session.abortTransaction(); + } + throw error; + } finally { + await session.endSession(); + } +} + +module.exports = { withSession }; diff --git a/client/src/components/Chat/Messages/Content/ToolCall.tsx b/client/src/components/Chat/Messages/Content/ToolCall.tsx index 29e0984ca8..fc1da37fbe 100644 --- a/client/src/components/Chat/Messages/Content/ToolCall.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCall.tsx @@ -1,5 +1,5 @@ // import { useState, useEffect } from 'react'; -import { actionDelimiter, actionDomainSeparator } from 'librechat-data-provider'; +import { actionDelimiter, actionDomainSeparator, Constants } from 'librechat-data-provider'; import * as Popover from '@radix-ui/react-popover'; import useLocalize from '~/hooks/useLocalize'; import ProgressCircle from './ProgressCircle'; @@ -63,7 +63,7 @@ export default function ToolCall({ onClick={() => ({})} inProgressText={localize('com_assistants_running_action')} finishedText={ - domain + domain && domain.length !== Constants.ENCODED_DOMAIN_LENGTH ? localize('com_assistants_completed_action', domain) : localize('com_assistants_completed_function', function_name) } diff --git a/client/src/components/Messages/Content/MessageContent.tsx b/client/src/components/Messages/Content/MessageContent.tsx index 07c92f009b..479ecac358 100644 --- a/client/src/components/Messages/Content/MessageContent.tsx +++ b/client/src/components/Messages/Content/MessageContent.tsx @@ -1,4 +1,5 @@ import { Fragment } from 'react'; +import { ViolationTypes } from 'librechat-data-provider'; import type { TResPlugin } from 'librechat-data-provider'; import type { TMessageContentProps, TText, TDisplayProps } from '~/common'; import { useAuthContext } from '~/hooks'; @@ -12,7 +13,7 @@ import Error from './Error'; const ErrorMessage = ({ text }: TText) => { const { logout } = useAuthContext(); - if (text.includes('ban')) { + if (text.includes(ViolationTypes.BAN)) { logout(); return null; } diff --git a/client/src/components/SidePanel/Builder/AssistantAction.tsx b/client/src/components/SidePanel/Builder/AssistantAction.tsx index acf5232b21..114dfdc21b 100644 --- a/client/src/components/SidePanel/Builder/AssistantAction.tsx +++ b/client/src/components/SidePanel/Builder/AssistantAction.tsx @@ -15,7 +15,7 @@ export default function AssistantAction({ className="border-token-border-medium flex w-full rounded-lg border text-sm hover:cursor-pointer" >
{action.metadata.domain} diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 6f58412375..8076504cda 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -482,6 +482,15 @@ export enum CacheKeys { * Key for the override config cache. */ OVERRIDE_CONFIG = 'overrideConfig', + /** + * Key for the bans cache. + */ + BANS = 'bans', + /** + * Key for the encoded domains cache. + * Used by Azure OpenAI Assistants. + */ + ENCODED_DOMAINS = 'encoded_domains', } /** @@ -500,6 +509,10 @@ export enum ViolationTypes { * Token Limit Violation. */ TOKEN_BALANCE = 'token_balance', + /** + * An issued ban. + */ + BAN = 'ban', } /** @@ -580,6 +593,10 @@ export enum Constants { * Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */ NO_PARENT = '00000000-0000-0000-0000-000000000000', + /** + * Fixed, encoded domain length for Azure OpenAI Assistants Function name parsing. + */ + ENCODED_DOMAIN_LENGTH = 10, } /**