From 5c99d93744031c92dc366a2b7e4d1d0403891c99 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 8 Aug 2024 12:14:00 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=82=20feat:=20Added=20Security=20for?= =?UTF-8?q?=20Conversation=20Access=20(#3588)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛂 feat: Added Security for Conversation Access * refactor: Update concurrentLimiter and convoAccess middleware to use isEnabled function for Redis check * refactor: handle access check even if cache is not available (edge case) --- api/cache/getLogStores.js | 5 +- api/server/middleware/concurrentLimiter.js | 11 +-- api/server/middleware/index.js | 2 + api/server/middleware/validate/convoAccess.js | 73 +++++++++++++++++++ api/server/middleware/validate/index.js | 4 + api/server/routes/ask/index.js | 5 +- api/server/routes/assistants/chatV1.js | 11 ++- api/server/routes/assistants/chatV2.js | 11 ++- api/server/routes/assistants/index.js | 9 +-- api/server/routes/edit/index.js | 3 + packages/data-provider/src/config.ts | 5 ++ 11 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 api/server/middleware/validate/convoAccess.js create mode 100644 api/server/middleware/validate/index.js diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 2b33751a04..1fdaee9006 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -35,11 +35,11 @@ const messages = isEnabled(USE_REDIS) ? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES }) : new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES }); -const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes +const tokenConfig = isEnabled(USE_REDIS) ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES }) : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES }); -const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes +const genTitle = isEnabled(USE_REDIS) ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES }) : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES }); @@ -69,6 +69,7 @@ const namespaces = { registrations: createViolationInstance('registrations'), [ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT), [ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT), + [ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS), [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT), [ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT), [ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance( diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js index 402152eb02..58ff689a0b 100644 --- a/api/server/middleware/concurrentLimiter.js +++ b/api/server/middleware/concurrentLimiter.js @@ -1,5 +1,7 @@ -const clearPendingReq = require('../../cache/clearPendingReq'); -const { logViolation, getLogStores } = require('../../cache'); +const { Time } = require('librechat-data-provider'); +const clearPendingReq = require('~/cache/clearPendingReq'); +const { logViolation, getLogStores } = require('~/cache'); +const { isEnabled } = require('~/server/utils'); const denyRequest = require('./denyRequest'); const { @@ -7,7 +9,6 @@ const { CONCURRENT_MESSAGE_MAX = 1, CONCURRENT_VIOLATION_SCORE: score, } = process.env ?? {}; -const ttl = 1000 * 60 * 1; /** * Middleware to limit concurrent requests for a user. @@ -38,7 +39,7 @@ const concurrentLimiter = async (req, res, next) => { const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1); const type = 'concurrent'; - const key = `${USE_REDIS ? namespace : ''}:${userId}`; + const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`; const pendingRequests = +((await cache.get(key)) ?? 0); if (pendingRequests >= limit) { @@ -51,7 +52,7 @@ const concurrentLimiter = async (req, res, next) => { await logViolation(req, res, type, errorMessage, score); return await denyRequest(req, res, errorMessage); } else { - await cache.set(key, pendingRequests + 1, ttl); + await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE); } // Ensure the requests are removed from the store once the request is done diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 75aab961b5..8d3fff58ff 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -14,6 +14,7 @@ const requireJwtAuth = require('./requireJwtAuth'); const validateModel = require('./validateModel'); const moderateText = require('./moderateText'); const setHeaders = require('./setHeaders'); +const validate = require('./validate'); const limiters = require('./limiters'); const uaParser = require('./uaParser'); const checkBan = require('./checkBan'); @@ -22,6 +23,7 @@ const roles = require('./roles'); module.exports = { ...abortMiddleware, + ...validate, ...limiters, ...roles, noIndex, diff --git a/api/server/middleware/validate/convoAccess.js b/api/server/middleware/validate/convoAccess.js new file mode 100644 index 0000000000..fb48d4475c --- /dev/null +++ b/api/server/middleware/validate/convoAccess.js @@ -0,0 +1,73 @@ +const { Constants, ViolationTypes, Time } = require('librechat-data-provider'); +const denyRequest = require('~/server/middleware/denyRequest'); +const { logViolation, getLogStores } = require('~/cache'); +const { isEnabled } = require('~/server/utils'); +const { getConvo } = require('~/models'); + +const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {}; + +/** + * Middleware to validate user's authorization for a conversation. + * + * This middleware checks if a user has the right to access a specific conversation. + * If the user doesn't have access, an error is returned. If the conversation doesn't exist, + * a not found error is returned. If the access is valid, the middleware allows the request to proceed. + * If the `cache` store is not available, the middleware will skip its logic. + * + * @function + * @param {Express.Request} req - Express request object containing user information. + * @param {Express.Response} res - Express response object. + * @param {function} next - Express next middleware function. + * @throws {Error} Throws an error if the user doesn't have access to the conversation. + */ +const validateConvoAccess = async (req, res, next) => { + const namespace = ViolationTypes.CONVO_ACCESS; + const cache = getLogStores(namespace); + + const conversationId = req.body.conversationId; + + if (!conversationId || conversationId === Constants.NEW_CONVO) { + return next(); + } + + const userId = req.user?.id ?? req.user?._id ?? ''; + const type = ViolationTypes.CONVO_ACCESS; + const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}:${conversationId}`; + + try { + if (cache) { + const cachedAccess = await cache.get(key); + if (cachedAccess === 'authorized') { + return next(); + } + } + + const conversation = await getConvo(userId, conversationId); + + if (!conversation) { + return next(); + } + + if (conversation.user !== userId) { + const errorMessage = { + type, + error: 'User not authorized for this conversation', + }; + + if (cache) { + await logViolation(req, res, type, errorMessage, score); + } + return await denyRequest(req, res, errorMessage); + } + + if (cache) { + await cache.set(key, 'authorized', Time.TEN_MINUTES); + } + next(); + } catch (error) { + console.error('Error validating conversation access:', error); + res.status(500).json({ error: 'Internal server error' }); + } +}; + +module.exports = validateConvoAccess; diff --git a/api/server/middleware/validate/index.js b/api/server/middleware/validate/index.js new file mode 100644 index 0000000000..ce476e747f --- /dev/null +++ b/api/server/middleware/validate/index.js @@ -0,0 +1,4 @@ +const validateConvoAccess = require('./convoAccess'); +module.exports = { + validateConvoAccess, +}; diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js index b5156ed8d1..fb737d3a74 100644 --- a/api/server/routes/ask/index.js +++ b/api/server/routes/ask/index.js @@ -12,9 +12,10 @@ const { uaParser, checkBan, requireJwtAuth, - concurrentLimiter, messageIpLimiter, + concurrentLimiter, messageUserLimiter, + validateConvoAccess, } = require('~/server/middleware'); const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; @@ -37,6 +38,8 @@ if (isEnabled(LIMIT_MESSAGE_USER)) { router.use(messageUserLimiter); } +router.use(validateConvoAccess); + router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI); router.use(`/${EModelEndpoint.chatGPTBrowser}`, askChatGPTBrowser); router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins); diff --git a/api/server/routes/assistants/chatV1.js b/api/server/routes/assistants/chatV1.js index 13386c6c85..36ed6d49e0 100644 --- a/api/server/routes/assistants/chatV1.js +++ b/api/server/routes/assistants/chatV1.js @@ -8,6 +8,7 @@ const { // validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); +const validateConvoAccess = require('~/server/middleware/validate/convoAccess'); const validateAssistant = require('~/server/middleware/assistants/validate'); const chatController = require('~/server/controllers/assistants/chatV1'); @@ -21,6 +22,14 @@ router.post('/abort', handleAbort()); * @param {express.Response} res - The response object, used to send back a response. * @returns {void} */ -router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController); +router.post( + '/', + validateModel, + buildEndpointOption, + validateAssistant, + validateConvoAccess, + setHeaders, + chatController, +); module.exports = router; diff --git a/api/server/routes/assistants/chatV2.js b/api/server/routes/assistants/chatV2.js index 36c29f4bc0..e50994e9bc 100644 --- a/api/server/routes/assistants/chatV2.js +++ b/api/server/routes/assistants/chatV2.js @@ -8,6 +8,7 @@ const { // validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); +const validateConvoAccess = require('~/server/middleware/validate/convoAccess'); const validateAssistant = require('~/server/middleware/assistants/validate'); const chatController = require('~/server/controllers/assistants/chatV2'); @@ -21,6 +22,14 @@ router.post('/abort', handleAbort()); * @param {express.Response} res - The response object, used to send back a response. * @returns {void} */ -router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController); +router.post( + '/', + validateModel, + buildEndpointOption, + validateAssistant, + validateConvoAccess, + setHeaders, + chatController, +); module.exports = router; diff --git a/api/server/routes/assistants/index.js b/api/server/routes/assistants/index.js index 6613177e7b..9640b37b39 100644 --- a/api/server/routes/assistants/index.js +++ b/api/server/routes/assistants/index.js @@ -1,13 +1,6 @@ const express = require('express'); const router = express.Router(); -const { - uaParser, - checkBan, - requireJwtAuth, - // concurrentLimiter, - // messageIpLimiter, - // messageUserLimiter, -} = require('~/server/middleware'); +const { uaParser, checkBan, requireJwtAuth } = require('~/server/middleware'); const v1 = require('./v1'); const chatV1 = require('./chatV1'); diff --git a/api/server/routes/edit/index.js b/api/server/routes/edit/index.js index fa19f9effd..f1d47af3f9 100644 --- a/api/server/routes/edit/index.js +++ b/api/server/routes/edit/index.js @@ -13,6 +13,7 @@ const { messageIpLimiter, concurrentLimiter, messageUserLimiter, + validateConvoAccess, } = require('~/server/middleware'); const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; @@ -35,6 +36,8 @@ if (isEnabled(LIMIT_MESSAGE_USER)) { router.use(messageUserLimiter); } +router.use(validateConvoAccess); + router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI); router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins); router.use(`/${EModelEndpoint.anthropic}`, anthropic); diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 2e83117977..f4c2db609f 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -679,6 +679,7 @@ export enum InfiniteCollections { * Enum for time intervals */ export enum Time { + ONE_HOUR = 3600000, THIRTY_MINUTES = 1800000, TEN_MINUTES = 600000, FIVE_MINUTES = 300000, @@ -799,6 +800,10 @@ export enum ViolationTypes { * Verify Email Limit Violation. */ VERIFY_EMAIL_LIMIT = 'verify_email_limit', + /** + * Verify Conversation Access violation. + */ + CONVO_ACCESS = 'convo_access', } /**