From 0c326797dd9472180ef09cdbcb94edaa39221890 Mon Sep 17 00:00:00 2001 From: Danny Avila <110412045+danny-avila@users.noreply.github.com> Date: Sat, 16 Dec 2023 20:45:27 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=B8=20feat:=20Gemini=20vision,=20Impro?= =?UTF-8?q?ved=20Logs=20and=20Multi-modal=20Handling=20(#1368)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add GOOGLE_MODELS env var * feat: add gemini vision support * refactor(GoogleClient): adjust clientOptions handling depending on model * fix(logger): fix redact logic and redact errors only * fix(GoogleClient): do not allow non-multiModal messages when gemini-pro-vision is selected * refactor(OpenAIClient): use `isVisionModel` client property to avoid calling validateVisionModel multiple times * refactor: better debug logging by correctly traversing, redacting sensitive info, and logging condensed versions of long values * refactor(GoogleClient): allow response errors to be thrown/caught above client handling so user receives meaningful error message debug orderedMessages, parentMessageId, and buildMessages result * refactor(AskController): use model from client.modelOptions.model when saving intermediate messages, which requires for the progress callback to be initialized after the client is initialized * feat(useSSE): revert to previous model if the model was auto-switched by backend due to message attachments * docs: update with google updates, notes about Gemini Pro Vision * fix: redis should not be initialized without USE_REDIS and increase max listeners to 20 --- README.md | 2 +- api/app/clients/BaseClient.js | 3 +- api/app/clients/GoogleClient.js | 102 ++++++--- api/app/clients/OpenAIClient.js | 15 +- api/app/clients/PluginsClient.js | 2 +- api/cache/keyvRedis.js | 10 +- api/config/parsers.js | 198 ++++++++++-------- api/config/winston.js | 21 +- api/models/File.js | 6 +- api/models/Transaction.js | 2 +- api/server/controllers/AskController.js | 81 +++---- api/server/middleware/abortMiddleware.js | 3 +- .../services/Config/loadDefaultModels.js | 14 +- api/server/services/Files/images/encode.js | 31 ++- api/server/services/ModelService.js | 20 +- client/src/hooks/useSSE.ts | 7 +- docs/index.md | 2 +- docs/install/apis_and_tokens.md | 32 ++- docs/install/dotenv.md | 9 + package-lock.json | 2 +- packages/data-provider/src/schemas.ts | 4 +- 21 files changed, 356 insertions(+), 210 deletions(-) diff --git a/README.md b/README.md index 1b223f2fd7..d8b2d9ee23 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ # Features - 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and 11-2023 updates - 💬 Multimodal Chat: - - Upload and analyze images with GPT-4-Vision 📸 + - Upload and analyze images with GPT-4 and Gemini Vision 📸 - More filetypes and Assistants API integration in Active Development 🚧 - 🌎 Multilingual UI: - English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro, diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 1ed41b746c..b76883265a 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -357,11 +357,11 @@ class BaseClient { const promptTokens = this.maxContextTokens - remainingContextTokens; - logger.debug('[BaseClient] Payload size:', payload.length); logger.debug('[BaseClient] tokenCountMap:', tokenCountMap); logger.debug('[BaseClient]', { promptTokens, remainingContextTokens, + payloadSize: payload.length, maxContextTokens: this.maxContextTokens, }); @@ -414,7 +414,6 @@ class BaseClient { logger.debug('[BaseClient] tokenCountMap', tokenCountMap); if (tokenCountMap[userMessage.messageId]) { userMessage.tokenCount = tokenCountMap[userMessage.messageId]; - logger.debug('[BaseClient] userMessage.tokenCount', userMessage.tokenCount); logger.debug('[BaseClient] userMessage', userMessage); } diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index d7f887e1ff..950cc8d111 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -4,6 +4,7 @@ const { GoogleVertexAI } = require('langchain/llms/googlevertexai'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai'); const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema'); +const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { getResponseSender, @@ -122,9 +123,18 @@ class GoogleClient extends BaseClient { // stop: modelOptions.stop // no stop method for now }; + if (this.options.attachments) { + this.modelOptions.model = 'gemini-pro-vision'; + } + // TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google this.isGenerativeModel = this.modelOptions.model.includes('gemini'); + this.isVisionModel = validateVisionModel(this.modelOptions.model); const { isGenerativeModel } = this; + if (this.isVisionModel && !this.options.attachments) { + this.modelOptions.model = 'gemini-pro'; + this.isVisionModel = false; + } this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat'); const { isChatModel } = this; this.isTextModel = @@ -216,7 +226,34 @@ class GoogleClient extends BaseClient { })).bind(this); } - buildMessages(messages = [], parentMessageId) { + async buildVisionMessages(messages = [], parentMessageId) { + const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId); + const attachments = await this.options.attachments; + const { files, image_urls } = await encodeAndFormat( + this.options.req, + attachments.filter((file) => file.type.includes('image')), + EModelEndpoint.google, + ); + + const latestMessage = { ...messages[messages.length - 1] }; + + latestMessage.image_urls = image_urls; + this.options.attachments = files; + + latestMessage.text = prompt; + + const payload = { + instances: [ + { + messages: [new HumanMessage(formatMessage({ message: latestMessage }))], + }, + ], + parameters: this.modelOptions, + }; + return { prompt: payload }; + } + + async buildMessages(messages = [], parentMessageId) { if (!this.isGenerativeModel && !this.project_id) { throw new Error( '[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)', @@ -227,17 +264,24 @@ class GoogleClient extends BaseClient { ); } + if (this.options.attachments) { + return this.buildVisionMessages(messages, parentMessageId); + } + if (this.isTextModel) { return this.buildMessagesPrompt(messages, parentMessageId); } - const formattedMessages = messages.map(this.formatMessages()); + let payload = { instances: [ { - messages: formattedMessages, + messages: messages + .map(this.formatMessages()) + .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' })) + .map((message) => formatMessage({ message, langChain: true })), }, ], - parameters: this.options.modelOptions, + parameters: this.modelOptions, }; if (this.options.promptPrefix) { @@ -248,9 +292,7 @@ class GoogleClient extends BaseClient { payload.instances[0].examples = this.options.examples; } - if (this.options.debug) { - logger.debug('GoogleClient buildMessages', payload); - } + logger.debug('[GoogleClient] buildMessages', payload); return { prompt: payload }; } @@ -260,12 +302,11 @@ class GoogleClient extends BaseClient { messages, parentMessageId, }); - if (this.options.debug) { - logger.debug('GoogleClient: orderedMessages, parentMessageId', { - orderedMessages, - parentMessageId, - }); - } + + logger.debug('[GoogleClient]', { + orderedMessages, + parentMessageId, + }); const formattedMessages = orderedMessages.map((message) => ({ author: message.isCreatedByUser ? this.userLabel : this.modelLabel, @@ -394,7 +435,7 @@ class GoogleClient extends BaseClient { context.shift(); } - let prompt = `${promptBody}${promptSuffix}`; + let prompt = `${promptBody}${promptSuffix}`.trim(); // Add 2 tokens for metadata after all messages have been counted. currentTokenCount += 2; @@ -453,20 +494,26 @@ class GoogleClient extends BaseClient { let examples; - let clientOptions = { - authOptions: { + let clientOptions = { ...parameters, maxRetries: 2 }; + + if (!this.isGenerativeModel) { + clientOptions['authOptions'] = { credentials: { ...this.serviceKey, }, projectId: this.project_id, - }, - ...parameters, - }; + }; + } if (!parameters) { clientOptions = { ...clientOptions, ...this.modelOptions }; } + if (this.isGenerativeModel) { + clientOptions.modelName = clientOptions.model; + delete clientOptions.model; + } + if (_examples && _examples.length) { examples = _examples .map((ex) => { @@ -487,13 +534,9 @@ class GoogleClient extends BaseClient { const model = this.createLLM(clientOptions); let reply = ''; - const messages = this.isTextModel - ? _payload.trim() - : _messages - .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' })) - .map((message) => formatMessage({ message, langChain: true })); + const messages = this.isTextModel ? _payload.trim() : _messages; - if (context && messages?.length > 0) { + if (!this.isVisionModel && context && messages?.length > 0) { messages.unshift(new SystemMessage(context)); } @@ -526,14 +569,7 @@ class GoogleClient extends BaseClient { async sendCompletion(payload, opts = {}) { let reply = ''; - try { - reply = await this.getCompletion(payload, opts); - if (this.options.debug) { - logger.debug('GoogleClient sendCompletion', { reply }); - } - } catch (err) { - logger.error('failed to send completion to Google', err); - } + reply = await this.getCompletion(payload, opts); return reply.trim(); } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 5894329fc1..13ff16bcb8 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,7 +1,7 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { getResponseSender, EModelEndpoint } = require('librechat-data-provider'); +const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images'); const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils'); const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts'); @@ -76,11 +76,14 @@ class OpenAIClient extends BaseClient { }; } - if (this.options.attachments && !validateVisionModel(this.modelOptions.model)) { + this.isVisionModel = validateVisionModel(this.modelOptions.model); + + if (this.options.attachments && !this.isVisionModel) { this.modelOptions.model = 'gpt-4-vision-preview'; + this.isVisionModel = true; } - if (validateVisionModel(this.modelOptions.model)) { + if (this.isVisionModel) { delete this.modelOptions.stop; } @@ -152,7 +155,7 @@ class OpenAIClient extends BaseClient { this.setupTokens(); - if (!this.modelOptions.stop && !validateVisionModel(this.modelOptions.model)) { + if (!this.modelOptions.stop && !this.isVisionModel) { const stopTokens = [this.startToken]; if (this.endToken && this.endToken !== this.startToken) { stopTokens.push(this.endToken); @@ -689,7 +692,7 @@ ${convo} } async recordTokenUsage({ promptTokens, completionTokens }) { - logger.debug('[OpenAIClient]', { promptTokens, completionTokens }); + logger.debug('[OpenAIClient] recordTokenUsage:', { promptTokens, completionTokens }); await spendTokens( { user: this.user, @@ -757,7 +760,7 @@ ${convo} opts.httpAgent = new HttpsProxyAgent(this.options.proxy); } - if (validateVisionModel(modelOptions.model)) { + if (this.isVisionModel) { modelOptions.max_tokens = 4000; } diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 509b98ca6f..f26df8a2d1 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -180,7 +180,7 @@ class PluginsClient extends OpenAIClient { logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`); if (errorMessage.length > 0) { - logger.debug('[PluginsClient] Caught error, input:', input); + logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input)); } try { diff --git a/api/cache/keyvRedis.js b/api/cache/keyvRedis.js index a5cbb45f11..f723429ee2 100644 --- a/api/cache/keyvRedis.js +++ b/api/cache/keyvRedis.js @@ -1,15 +1,19 @@ const KeyvRedis = require('@keyv/redis'); const { logger } = require('~/config'); +const { isEnabled } = require('~/server/utils'); -const { REDIS_URI } = process.env; +const { REDIS_URI, USE_REDIS } = process.env; let keyvRedis; -if (REDIS_URI) { +if (REDIS_URI && isEnabled(USE_REDIS)) { keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets: false }); keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err)); + keyvRedis.setMaxListeners(20); } else { - logger.info('REDIS_URI not provided. Redis module will not be initialized.'); + logger.info( + '`REDIS_URI` not provided, or `USE_REDIS` not set. Redis module will not be initialized.', + ); } module.exports = keyvRedis; diff --git a/api/config/parsers.js b/api/config/parsers.js index f94c38da75..4f94d6e4d7 100644 --- a/api/config/parsers.js +++ b/api/config/parsers.js @@ -1,128 +1,160 @@ -const util = require('util'); const winston = require('winston'); const traverse = require('traverse'); const { klona } = require('klona/full'); -const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/, /api-key: \w+/]; +const SPLAT_SYMBOL = Symbol.for('splat'); +const MESSAGE_SYMBOL = Symbol.for('message'); + +const sensitiveKeys = [/^(sk-)[^\s]+/, /(Bearer )[^\s]+/, /(api-key:? )[^\s]+/, /(key=)[^\s]+/]; /** - * Determines if a given key string is sensitive. + * Determines if a given value string is sensitive and returns matching regex patterns. * - * @param {string} keyStr - The key string to check. - * @returns {boolean} True if the key string matches known sensitive key patterns. + * @param {string} valueStr - The value string to check. + * @returns {Array} An array of regex patterns that match the value string. */ -function isSensitiveKey(keyStr) { - if (keyStr) { - return sensitiveKeys.some((regex) => regex.test(keyStr)); +function getMatchingSensitivePatterns(valueStr) { + if (valueStr) { + // Filter and return all regex patterns that match the value string + return sensitiveKeys.filter((regex) => regex.test(valueStr)); } - return false; + return []; } /** - * Recursively redacts sensitive information from an object. + * Redacts sensitive information from a console message. * - * @param {object} obj - The object to traverse and redact. + * @param {string} str - The console message to be redacted. + * @returns {string} - The redacted console message. */ -function redactObject(obj) { - traverse(obj).forEach(function redactor() { - if (isSensitiveKey(this.key)) { - this.update('[REDACTED]'); - } +function redactMessage(str) { + const patterns = getMatchingSensitivePatterns(str); + + if (patterns.length === 0) { + return str; + } + + patterns.forEach((pattern) => { + str = str.replace(pattern, '$1[REDACTED]'); }); + + return str; } /** - * Deep copies and redacts sensitive information from an object. - * - * @param {object} obj - The object to copy and redact. - * @returns {object} The redacted copy of the original object. + * Redacts sensitive information from log messages if the log level is 'error'. + * Note: Intentionally mutates the object. + * @param {Object} info - The log information object. + * @returns {Object} - The modified log information object. */ -function redact(obj) { - const copy = klona(obj); // Making a deep copy to prevent side effects - redactObject(copy); - - const splat = copy[Symbol.for('splat')]; - redactObject(splat); // Specifically redact splat Symbol - - return copy; -} +const redactFormat = winston.format((info) => { + if (info.level === 'error') { + info.message = redactMessage(info.message); + if (info[MESSAGE_SYMBOL]) { + info[MESSAGE_SYMBOL] = redactMessage(info[MESSAGE_SYMBOL]); + } + } + return info; +}); /** * Truncates long strings, especially base64 image data, within log messages. * * @param {any} value - The value to be inspected and potentially truncated. + * @param {number} [length] - The length at which to truncate the value. Default: 100. * @returns {any} - The truncated or original value. */ -const truncateLongStrings = (value) => { +const truncateLongStrings = (value, length = 100) => { if (typeof value === 'string') { - return value.length > 100 ? value.substring(0, 100) + '... [truncated]' : value; + return value.length > length ? value.substring(0, length) + '... [truncated]' : value; } return value; }; -// /** -// * Processes each message in the messages array, specifically looking for and truncating -// * base64 image URLs in the content. If a base64 image URL is found, it replaces the URL -// * with a truncated message. -// * -// * @param {PayloadMessage} message - The payload message object to format. -// * @returns {PayloadMessage} - The processed message object with base64 image URLs truncated. -// */ -// const truncateBase64ImageURLs = (message) => { -// // Create a deep copy of the message -// const messageCopy = JSON.parse(JSON.stringify(message)); - -// if (messageCopy.content && Array.isArray(messageCopy.content)) { -// messageCopy.content = messageCopy.content.map(contentItem => { -// if (contentItem.type === 'image_url' && contentItem.image_url && isBase64String(contentItem.image_url.url)) { -// return { ...contentItem, image_url: { ...contentItem.image_url, url: 'Base64 Image Data... [truncated]' } }; -// } -// return contentItem; -// }); -// } -// return messageCopy; -// }; - -// /** -// * Checks if a string is a base64 image data string. -// * -// * @param {string} str - The string to be checked. -// * @returns {boolean} - True if the string is base64 image data, otherwise false. -// */ -// const isBase64String = (str) => /^data:image\/[a-zA-Z]+;base64,/.test(str); +/** + * An array mapping function that truncates long strings (objects converted to JSON strings). + * @param {any} item - The item to be condensed. + * @returns {any} - The condensed item. + */ +const condenseArray = (item) => { + if (typeof item === 'string') { + return truncateLongStrings(JSON.stringify(item)); + } else if (typeof item === 'object') { + return truncateLongStrings(JSON.stringify(item)); + } + return item; +}; /** - * Custom log format for Winston that handles deep object inspection. - * It specifically truncates long strings and handles nested structures within metadata. + * Formats log messages for debugging purposes. + * - Truncates long strings within log messages. + * - Condenses arrays by truncating long strings and objects as strings within array items. + * - Redacts sensitive information from log messages if the log level is 'error'. + * - Converts log information object to a formatted string. * - * @param {Object} info - Information about the log entry. + * @param {Object} options - The options for formatting log messages. + * @param {string} options.level - The log level. + * @param {string} options.message - The log message. + * @param {string} options.timestamp - The timestamp of the log message. + * @param {Object} options.metadata - Additional metadata associated with the log message. * @returns {string} - The formatted log message. */ -const deepObjectFormat = winston.format.printf(({ level, message, timestamp, ...metadata }) => { - let msg = `${timestamp} ${level}: ${message}`; +const debugTraverse = winston.format.printf(({ level, message, timestamp, ...metadata }) => { + let msg = `${timestamp} ${level}: ${truncateLongStrings(message?.trim(), 150)}`; - if (Object.keys(metadata).length) { - Object.entries(metadata).forEach(([key, value]) => { - let val = value; - if (key === 'modelOptions' && value && Array.isArray(value.messages)) { - // Create a shallow copy of the messages array - // val = { ...value, messages: value.messages.map(truncateBase64ImageURLs) }; - val = { ...value, messages: `${value.messages.length} message(s) in payload` }; - } - // Inspects each metadata value; applies special handling for 'messages' - const inspectedValue = - typeof val === 'string' - ? truncateLongStrings(val) - : util.inspect(val, { depth: null, colors: false }); // Use 'val' here - msg += ` ${key}: ${inspectedValue}`; - }); + if (level !== 'debug') { + return msg; } + if (!metadata) { + return msg; + } + + const debugValue = metadata[SPLAT_SYMBOL]?.[0]; + + if (!debugValue) { + return msg; + } + + if (debugValue && Array.isArray(debugValue)) { + msg += `\n${JSON.stringify(debugValue.map(condenseArray))}`; + return msg; + } + + if (typeof debugValue !== 'object') { + return (msg += ` ${debugValue}`); + } + + msg += '\n{'; + + const copy = klona(metadata); + traverse(copy).forEach(function (value) { + const parent = this.parent; + const parentKey = `${parent && parent.notRoot ? parent.key + '.' : ''}`; + const tabs = `${parent && parent.notRoot ? '\t\t' : '\t'}`; + if (this.isLeaf && typeof value === 'string') { + const truncatedText = truncateLongStrings(value); + msg += `\n${tabs}${parentKey}${this.key}: ${JSON.stringify(truncatedText)},`; + } else if (this.notLeaf && Array.isArray(value) && value.length > 0) { + const currentMessage = `\n${tabs}// ${value.length} ${this.key.replace(/s$/, '')}(s)`; + this.update(currentMessage, true); + msg += currentMessage; + const stringifiedArray = value.map(condenseArray); + msg += `\n${tabs}${parentKey}${this.key}: [${stringifiedArray}],`; + } else if (this.isLeaf && typeof value === 'function') { + msg += `\n${tabs}${parentKey}${this.key}: function,`; + } else if (this.isLeaf) { + msg += `\n${tabs}${parentKey}${this.key}: ${value},`; + } + }); + + msg += '\n}'; return msg; }); module.exports = { - redact, - deepObjectFormat, + redactFormat, + redactMessage, + debugTraverse, }; diff --git a/api/config/winston.js b/api/config/winston.js index 8038b19106..6cba153f16 100644 --- a/api/config/winston.js +++ b/api/config/winston.js @@ -1,7 +1,7 @@ const path = require('path'); const winston = require('winston'); require('winston-daily-rotate-file'); -const { redact, deepObjectFormat } = require('./parsers'); +const { redactFormat, redactMessage, debugTraverse } = require('./parsers'); const logDir = path.join(__dirname, '..', 'logs'); @@ -32,10 +32,11 @@ const level = () => { }; const fileFormat = winston.format.combine( + redactFormat(), winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }), winston.format.errors({ stack: true }), winston.format.splat(), - winston.format((info) => redact(info))(), + // redactErrors(), ); const transports = [ @@ -78,16 +79,24 @@ if ( zippedArchive: true, maxSize: '20m', maxFiles: '14d', - format: winston.format.combine(fileFormat, deepObjectFormat), + format: winston.format.combine(fileFormat, debugTraverse), }), ); } const consoleFormat = winston.format.combine( + redactFormat(), winston.format.colorize({ all: true }), winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }), - winston.format((info) => redact(info))(), - winston.format.printf((info) => `${info.timestamp} ${info.level}: ${info.message}`), + // redactErrors(), + winston.format.printf((info) => { + const message = `${info.timestamp} ${info.level}: ${info.message}`; + if (info.level.includes('error')) { + return redactMessage(message); + } + + return message; + }), ); if ( @@ -97,7 +106,7 @@ if ( transports.push( new winston.transports.Console({ level: 'debug', - format: winston.format.combine(consoleFormat, deepObjectFormat), + format: winston.format.combine(consoleFormat, debugTraverse), }), ); } else { diff --git a/api/models/File.js b/api/models/File.js index 84822a71d7..4c353fd70b 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -24,7 +24,7 @@ const getFiles = async (filter) => { /** * Creates a new file with a TTL of 1 hour. - * @param {Object} data - The file data to be created, must contain file_id. + * @param {MongoFile} data - The file data to be created, must contain file_id. * @returns {Promise} A promise that resolves to the created file document. */ const createFile = async (data) => { @@ -40,7 +40,7 @@ const createFile = async (data) => { /** * Updates a file identified by file_id with new data and removes the TTL. - * @param {Object} data - The data to update, must contain file_id. + * @param {MongoFile} data - The data to update, must contain file_id. * @returns {Promise} A promise that resolves to the updated file document. */ const updateFile = async (data) => { @@ -54,7 +54,7 @@ const updateFile = async (data) => { /** * Increments the usage of a file identified by file_id. - * @param {Object} data - The data to update, must contain file_id and the increment value for usage. + * @param {MongoFile} data - The data to update, must contain file_id and the increment value for usage. * @returns {Promise} A promise that resolves to the updated file document. */ const updateFileUsage = async (data) => { diff --git a/api/models/Transaction.js b/api/models/Transaction.js index 635db45b63..0bc26fc37e 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -39,7 +39,7 @@ transactionSchema.statics.create = async function (transactionData) { { user: transaction.user }, { $inc: { tokenCredits: transaction.tokenValue } }, { upsert: true, new: true }, - ); + ).lean(); }; module.exports = mongoose.model('Transaction', transactionSchema); diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 11be3afd30..d1d9f8f7ad 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -43,46 +43,51 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { } }; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender, - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - error: false, - user, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const getAbortData = () => ({ - sender, - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - promptTokens, - }); - - const { abortController, onStart } = createAbortController(req, res, getAbortData); + let getText; try { const { client } = await initializeClient({ req, res, endpointOption }); + + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender, + conversationId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: partialText, + model: client.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, + user, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + + getText = getPartialText; + + const getAbortData = () => ({ + sender, + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + promptTokens, + }); + + const { abortController, onStart } = createAbortController(req, res, getAbortData); + const messageOptions = { user, parentMessageId, @@ -134,7 +139,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { }); } } catch (error) { - const partialText = getPartialText(); + const partialText = getText && getText(); handleAbortError(res, req, error, { partialText, conversationId, diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 9bf3b54e31..4a109acf8f 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -2,6 +2,7 @@ const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/uti const { saveMessage, getConvo, getConvoTitle } = require('~/models'); const clearPendingReq = require('~/cache/clearPendingReq'); const abortControllers = require('./abortControllers'); +const { redactMessage } = require('~/config/parsers'); const spendTokens = require('~/models/spendTokens'); const { logger } = require('~/config'); @@ -92,7 +93,7 @@ const handleAbortError = async (res, req, error, data) => { messageId, conversationId, parentMessageId, - text: error.message, + text: redactMessage(error.message), shouldSaveMessage: true, user: req.user.id, }; diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index 31780c5f61..665aa71479 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -1,9 +1,10 @@ -const { EModelEndpoint, defaultModels } = require('librechat-data-provider'); +const { EModelEndpoint } = require('librechat-data-provider'); const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config; const { getOpenAIModels, - getChatGPTBrowserModels, + getGoogleModels, getAnthropicModels, + getChatGPTBrowserModels, } = require('~/server/services/ModelService'); const fitlerAssistantModels = (str) => { @@ -11,6 +12,7 @@ const fitlerAssistantModels = (str) => { }; async function loadDefaultModels() { + const google = getGoogleModels(); const openAI = await getOpenAIModels(); const anthropic = getAnthropicModels(); const chatGPTBrowser = getChatGPTBrowserModels(); @@ -19,13 +21,13 @@ async function loadDefaultModels() { return { [EModelEndpoint.openAI]: openAI, + [EModelEndpoint.google]: google, + [EModelEndpoint.anthropic]: anthropic, + [EModelEndpoint.gptPlugins]: gptPlugins, [EModelEndpoint.azureOpenAI]: azureOpenAI, - [EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels), - [EModelEndpoint.google]: defaultModels[EModelEndpoint.google], [EModelEndpoint.bingAI]: ['BingAI', 'Sydney'], [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, - [EModelEndpoint.gptPlugins]: gptPlugins, - [EModelEndpoint.anthropic]: anthropic, + [EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels), }; } diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 90c14c051c..30428ffabc 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -1,7 +1,13 @@ const fs = require('fs'); const path = require('path'); +const { EModelEndpoint } = require('librechat-data-provider'); const { updateFile } = require('~/models'); +/** + * Encodes an image file to base64. + * @param {string} imagePath - The path to the image file. + * @returns {Promise} A promise that resolves with the base64 encoded image data. + */ function encodeImage(imagePath) { return new Promise((resolve, reject) => { fs.readFile(imagePath, (err, data) => { @@ -14,6 +20,12 @@ function encodeImage(imagePath) { }); } +/** + * Updates the file and encodes the image. + * @param {Object} req - The request object. + * @param {Object} file - The file object. + * @returns {Promise<[MongoFile, string]>} - A promise that resolves to an array of results from updateFile and encodeImage. + */ async function updateAndEncode(req, file) { const { publicPath, imageOutput } = req.app.locals.config; const userPath = path.join(imageOutput, req.user.id); @@ -29,7 +41,14 @@ async function updateAndEncode(req, file) { return await Promise.all(promises); } -async function encodeAndFormat(req, files) { +/** + * Encodes and formats the given files. + * @param {Express.Request} req - The request object. + * @param {Array} files - The array of files to encode and format. + * @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image. + * @returns {Promise} - A promise that resolves to the result object containing the encoded images and file details. + */ +async function encodeAndFormat(req, files, endpoint) { const promises = []; for (let file of files) { promises.push(updateAndEncode(req, file)); @@ -46,13 +65,19 @@ async function encodeAndFormat(req, files) { }; for (const [file, base64] of encodedImages) { - result.image_urls.push({ + const imagePart = { type: 'image_url', image_url: { url: `data:image/webp;base64,${base64}`, detail, }, - }); + }; + + if (endpoint && endpoint === EModelEndpoint.google) { + imagePart.image_url = imagePart.image_url.url; + } + + result.image_urls.push(imagePart); result.files.push({ file_id: file.file_id, diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 46e9368173..08c9ae71d2 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -15,8 +15,14 @@ const modelsCache = isEnabled(process.env.USE_REDIS) ? new Keyv({ store: keyvRedis }) : new Keyv({ namespace: 'models' }); -const { OPENROUTER_API_KEY, OPENAI_REVERSE_PROXY, CHATGPT_MODELS, ANTHROPIC_MODELS, PROXY } = - process.env ?? {}; +const { + OPENROUTER_API_KEY, + OPENAI_REVERSE_PROXY, + CHATGPT_MODELS, + ANTHROPIC_MODELS, + GOOGLE_MODELS, + PROXY, +} = process.env ?? {}; const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _models = []) => { let models = _models.slice() ?? []; @@ -126,8 +132,18 @@ const getAnthropicModels = () => { return models; }; +const getGoogleModels = () => { + let models = defaultModels[EModelEndpoint.google]; + if (GOOGLE_MODELS) { + models = String(GOOGLE_MODELS).split(','); + } + + return models; +}; + module.exports = { getOpenAIModels, getChatGPTBrowserModels, getAnthropicModels, + getGoogleModels, }; diff --git a/client/src/hooks/useSSE.ts b/client/src/hooks/useSSE.ts index d5e0742e9d..e41fa3cb10 100644 --- a/client/src/hooks/useSSE.ts +++ b/client/src/hooks/useSSE.ts @@ -172,7 +172,7 @@ export default function useSSE(submission: TSubmission | null, index = 0) { const finalHandler = (data: TResData, submission: TSubmission) => { const { requestMessage, responseMessage, conversation } = data; - const { messages, isRegenerate = false } = submission; + const { messages, conversation: submissionConvo, isRegenerate = false } = submission; // update the messages if (isRegenerate) { @@ -199,6 +199,11 @@ export default function useSSE(submission: TSubmission | null, index = 0) { ...conversation, }; + // Revert to previous model if the model was auto-switched by backend due to message attachments + if (conversation.model?.includes('vision') && !submissionConvo.model?.includes('vision')) { + update.model = submissionConvo?.model; + } + setStorage(update); return update; }); diff --git a/docs/index.md b/docs/index.md index 7dbbd55c80..c4882ffccb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,7 +31,7 @@ # Features - 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and 11-2023 updates - 💬 Multimodal Chat: - - Upload and analyze images with GPT-4-Vision 📸 + - Upload and analyze images with GPT-4 and Gemini Vision 📸 - More filetypes and Assistants API integration in Active Development 🚧 - 🌎 Multilingual UI: - English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro, Русский diff --git a/docs/install/apis_and_tokens.md b/docs/install/apis_and_tokens.md index b738e9b224..1e5a3acd32 100644 --- a/docs/install/apis_and_tokens.md +++ b/docs/install/apis_and_tokens.md @@ -70,10 +70,6 @@ For Vertex AI, you need a Service Account JSON key file, with appropriate access Instructions for both are given below. -Setting `GOOGLE_KEY=user_provided` in your .env file will configure both values to be provided from the client (or frontend) like so: - -![image](https://github.com/danny-avila/LibreChat/assets/110412045/728cbc04-4180-45a8-848c-ae5de2b02996) - ### Generative Language API (Gemini) **60 Gemini requests/minute are currently free until early next year when it enters general availability.** @@ -85,21 +81,22 @@ To use Gemini models, you'll need an API key. If you don't already have one, cre

Get an API key here

-Once you have your key, you can either provide it from the frontend by setting the following: - -```bash -GOOGLE_KEY=user_provided -``` - -Or, provide the key in your .env file, which allows all users of your instance to use it. +Once you have your key, provide the key in your .env file, which allows all users of your instance to use it. ```bash GOOGLE_KEY=mY_SeCreT_w9347w8_kEY ``` -> Notes: -> - As of 12/15/23, Gemini Pro Vision is not yet supported but is planned. -> - PaLM2 and Codey models cannot be accessed through the Generative Language API. +Or, you can make users provide it from the frontend by setting the following: +```bash +GOOGLE_KEY=user_provided +``` + +Note: PaLM2 and Codey models cannot be accessed through the Generative Language API, only through Vertex AI. + +Setting `GOOGLE_KEY=user_provided` in your .env file will configure both the Vertex AI Service Account JSON key file and the Generative Language API key to be provided from the frontend like so: + +![image](https://github.com/danny-avila/LibreChat/assets/110412045/728cbc04-4180-45a8-848c-ae5de2b02996) ### Vertex AI (PaLM 2 & Codey) @@ -132,14 +129,15 @@ You can usually get **$300 starting credit**, which makes this option free for 9 **Saving your JSON key file in the project directory which allows all users of your LibreChat instance to use it.** -Alternatively, Once you have your JSON key file, you can also provide it from the frontend on a user-basis by setting the following: +Alternatively, you can make users provide it from the frontend by setting the following: ```bash +# Note: this configures both the Vertex AI Service Account JSON key file +# and the Generative Language API key to be provided from the frontend. GOOGLE_KEY=user_provided ``` -> Notes: -> - As of 12/15/23, Gemini and Gemini Pro Vision are not yet supported through Vertex AI but are planned. +Note: Using Gemini models through Vertex AI is possible but not yet supported. ## Azure OpenAI diff --git a/docs/install/dotenv.md b/docs/install/dotenv.md index 98ad009b4e..67a30b165a 100644 --- a/docs/install/dotenv.md +++ b/docs/install/dotenv.md @@ -199,6 +199,15 @@ GOOGLE_KEY=user_provided GOOGLE_REVERSE_PROXY= ``` +- Customize the available models, separated by commas, **without spaces**. + - The first will be default. + - Leave it blank or commented out to use internal settings (default: all listed below). + +```bash +# all available models as of 12/16/23 +GOOGLE_MODELS=gemini-pro,gemini-pro-vision,chat-bison,chat-bison-32k,codechat-bison,codechat-bison-32k,text-bison,text-bison-32k,text-unicorn,code-gecko,code-bison,code-bison-32k +``` + ### OpenAI - To get your OpenAI API key, you need to: diff --git a/package-lock.json b/package-lock.json index bdc835628c..b0d68e1714 100644 --- a/package-lock.json +++ b/package-lock.json @@ -25558,7 +25558,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.3.1", + "version": "0.3.2", "license": "ISC", "dependencies": { "axios": "^1.3.4", diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index feac357594..425f3da0aa 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -25,6 +25,7 @@ export const defaultEndpoints: EModelEndpoint[] = [ export const defaultModels = { [EModelEndpoint.google]: [ 'gemini-pro', + 'gemini-pro-vision', 'chat-bison', 'chat-bison-32k', 'codechat-bison', @@ -135,6 +136,7 @@ export const modularEndpoints = new Set([ export const supportsFiles = { [EModelEndpoint.openAI]: true, + [EModelEndpoint.google]: true, [EModelEndpoint.assistant]: true, }; @@ -144,7 +146,7 @@ export const supportsBalanceCheck = { [EModelEndpoint.gptPlugins]: true, }; -export const visionModels = ['gpt-4-vision', 'llava-13b']; +export const visionModels = ['gpt-4-vision', 'llava-13b', 'gemini-pro-vision']; export const eModelEndpointSchema = z.nativeEnum(EModelEndpoint);