mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2026-06-15 23:20:32 +03:00
wip support reranker for all vector dbs
This commit is contained in:
@@ -1,8 +1,5 @@
|
||||
import { useState } from "react";
|
||||
|
||||
// We dont support all vectorDBs yet for reranking due to complexities of how each provider
|
||||
// returns information. We need to normalize the response data so Reranker can be used for each provider.
|
||||
const supportedVectorDBs = ["lancedb"];
|
||||
const hint = {
|
||||
default: {
|
||||
title: "Default",
|
||||
@@ -20,8 +17,7 @@ export default function VectorSearchMode({ workspace, setHasChanges }) {
|
||||
const [selection, setSelection] = useState(
|
||||
workspace?.vectorSearchMode ?? "default"
|
||||
);
|
||||
if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
|
||||
return null;
|
||||
if (!workspace?.vectorDB) return null;
|
||||
|
||||
return (
|
||||
<div>
|
||||
|
||||
@@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
|
||||
const { v4: uuidv4 } = require("uuid");
|
||||
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
|
||||
const sanitizeNamespace = (namespace) => {
|
||||
// If namespace already starts with ns_, don't add it again
|
||||
@@ -301,6 +302,7 @@ const AstraDB = {
|
||||
similarityThreshold = 0.25,
|
||||
topN = 4,
|
||||
filterIdentifiers = [],
|
||||
rerank = false,
|
||||
}) {
|
||||
if (!namespace || !input || !LLMConnector)
|
||||
throw new Error("Invalid request to performSimilaritySearch.");
|
||||
@@ -319,14 +321,24 @@ const AstraDB = {
|
||||
}
|
||||
|
||||
const queryVector = await LLMConnector.embedTextInput(input);
|
||||
const { contextTexts, sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace: sanitizedNamespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const { contextTexts, sourceDocuments } = rerank
|
||||
? await this.rerankedSimilarityResponse({
|
||||
client,
|
||||
namespace: sanitizedNamespace,
|
||||
query: input,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
})
|
||||
: await this.similarityResponse({
|
||||
client,
|
||||
namespace: sanitizedNamespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
const sources = sourceDocuments.map((metadata, i) => {
|
||||
return { ...metadata, text: contextTexts[i] };
|
||||
@@ -378,6 +390,35 @@ const AstraDB = {
|
||||
});
|
||||
return result;
|
||||
},
|
||||
rerankedSimilarityResponse: async function ({
|
||||
client,
|
||||
namespace,
|
||||
query,
|
||||
queryVector,
|
||||
topN = 4,
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const { sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN: searchLimit,
|
||||
filterIdentifiers,
|
||||
});
|
||||
return await rerankDocuments(
|
||||
query,
|
||||
sourceDocuments.map((doc) => ({ ...doc.metadata, score: null })),
|
||||
{
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
}
|
||||
);
|
||||
},
|
||||
allNamespaces: async function (client) {
|
||||
try {
|
||||
let header = new Headers();
|
||||
|
||||
@@ -6,6 +6,7 @@ const { v4: uuidv4 } = require("uuid");
|
||||
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
|
||||
const { parseAuthHeader } = require("../../http");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
const COLLECTION_REGEX = new RegExp(
|
||||
/^(?!\d+\.\d+\.\d+\.\d+$)(?!.*\.\.)(?=^[a-zA-Z0-9][a-zA-Z0-9_-]{1,61}[a-zA-Z0-9]$).{3,63}$/
|
||||
);
|
||||
@@ -150,6 +151,52 @@ const Chroma = {
|
||||
|
||||
return result;
|
||||
},
|
||||
rerankedSimilarityResponse: async function ({
|
||||
client,
|
||||
namespace,
|
||||
query,
|
||||
queryVector,
|
||||
topN = 4,
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const { sourceDocuments, contextTexts } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN: searchLimit,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const documentsForReranking = sourceDocuments.map((metadata, i) => ({
|
||||
...metadata,
|
||||
text: contextTexts[i],
|
||||
}));
|
||||
|
||||
const rerankedDocs = await rerankDocuments(query, documentsForReranking, {
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
// Post-process to fix scores and contextTexts from the generic reranker.
|
||||
const result = {
|
||||
contextTexts: [],
|
||||
sourceDocuments: [],
|
||||
scores: [],
|
||||
};
|
||||
|
||||
rerankedDocs.sourceDocuments.forEach((item) => {
|
||||
if (item.rerank_score < similarityThreshold) return;
|
||||
const { rerank_score, ...rest } = item;
|
||||
result.sourceDocuments.push({ ...rest, score: rerank_score });
|
||||
result.contextTexts.push(item.text);
|
||||
result.scores.push(rerank_score);
|
||||
});
|
||||
return result;
|
||||
},
|
||||
namespace: async function (client, namespace = null) {
|
||||
if (!namespace) throw new Error("No namespace value provided.");
|
||||
const collection = await client
|
||||
@@ -348,12 +395,14 @@ const Chroma = {
|
||||
similarityThreshold = 0.25,
|
||||
topN = 4,
|
||||
filterIdentifiers = [],
|
||||
rerank = false,
|
||||
}) {
|
||||
if (!namespace || !input || !LLMConnector)
|
||||
throw new Error("Invalid request to performSimilaritySearch.");
|
||||
|
||||
const { client } = await this.connect();
|
||||
if (!(await this.namespaceExists(client, this.normalize(namespace)))) {
|
||||
const collectionName = this.normalize(namespace);
|
||||
if (!(await this.namespaceExists(client, collectionName))) {
|
||||
return {
|
||||
contextTexts: [],
|
||||
sources: [],
|
||||
@@ -362,16 +411,26 @@ const Chroma = {
|
||||
}
|
||||
|
||||
const queryVector = await LLMConnector.embedTextInput(input);
|
||||
const { contextTexts, sourceDocuments, scores } =
|
||||
await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const result = rerank
|
||||
? await this.rerankedSimilarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
query: input,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
})
|
||||
: await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
const { contextTexts, sourceDocuments, scores } = result;
|
||||
const sources = sourceDocuments.map((metadata, i) => ({
|
||||
metadata: {
|
||||
...metadata,
|
||||
|
||||
@@ -5,7 +5,7 @@ const { SystemSettings } = require("../../../models/systemSettings");
|
||||
const { storeVectorResult, cachedVectorInformation } = require("../../files");
|
||||
const { v4: uuidv4 } = require("uuid");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
|
||||
/**
|
||||
* LancedDB Client connection object
|
||||
@@ -79,68 +79,24 @@ const LanceDb = {
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const reranker = new NativeEmbeddingReranker();
|
||||
const collection = await client.openTable(namespace);
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const result = {
|
||||
contextTexts: [],
|
||||
sourceDocuments: [],
|
||||
scores: [],
|
||||
};
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const vectorSearchResults = await client
|
||||
.openTable(namespace)
|
||||
.then((tbl) =>
|
||||
tbl
|
||||
.vectorSearch(queryVector)
|
||||
.distanceType("cosine")
|
||||
.limit(searchLimit)
|
||||
.toArray()
|
||||
);
|
||||
|
||||
/**
|
||||
* For reranking, we want to work with a larger number of results than the topN.
|
||||
* This is because the reranker can only rerank the results it it given and we dont auto-expand the results.
|
||||
* We want to give the reranker a larger number of results to work with.
|
||||
*
|
||||
* However, we cannot make this boundless as reranking is expensive and time consuming.
|
||||
* So we limit the number of results to a maximum of 50 and a minimum of 10.
|
||||
* This is a good balance between the number of results to rerank and the cost of reranking
|
||||
* and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base level hardware.
|
||||
*
|
||||
* Benchmarks:
|
||||
* On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
|
||||
*/
|
||||
const searchLimit = Math.max(
|
||||
10,
|
||||
Math.min(50, Math.ceil(totalEmbeddings * 0.1))
|
||||
);
|
||||
const vectorSearchResults = await collection
|
||||
.vectorSearch(queryVector)
|
||||
.distanceType("cosine")
|
||||
.limit(searchLimit)
|
||||
.toArray();
|
||||
|
||||
await reranker
|
||||
.rerank(query, vectorSearchResults, { topK: topN })
|
||||
.then((rerankResults) => {
|
||||
rerankResults.forEach((item) => {
|
||||
if (this.distanceToSimilarity(item._distance) < similarityThreshold)
|
||||
return;
|
||||
const { vector: _, ...rest } = item;
|
||||
if (filterIdentifiers.includes(sourceIdentifier(rest))) {
|
||||
console.log(
|
||||
"LanceDB: A source was filtered from context as it's parent document is pinned."
|
||||
);
|
||||
return;
|
||||
}
|
||||
const score =
|
||||
item?.rerank_score || this.distanceToSimilarity(item._distance);
|
||||
|
||||
result.contextTexts.push(rest.text);
|
||||
result.sourceDocuments.push({
|
||||
...rest,
|
||||
score,
|
||||
});
|
||||
result.scores.push(score);
|
||||
});
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error(e);
|
||||
console.error("LanceDB::rerankedSimilarityResponse", e.message);
|
||||
});
|
||||
|
||||
return result;
|
||||
const reranked = await rerankDocuments(query, vectorSearchResults, {
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
});
|
||||
return reranked;
|
||||
},
|
||||
|
||||
/**
|
||||
@@ -421,6 +377,8 @@ const LanceDb = {
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
console.log("result", result);
|
||||
|
||||
const { contextTexts, sourceDocuments } = result;
|
||||
const sources = sourceDocuments.map((metadata, i) => {
|
||||
return { metadata: { ...metadata, text: contextTexts[i] } };
|
||||
|
||||
@@ -10,6 +10,7 @@ const { v4: uuidv4 } = require("uuid");
|
||||
const { storeVectorResult, cachedVectorInformation } = require("../../files");
|
||||
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
|
||||
const Milvus = {
|
||||
name: "Milvus",
|
||||
@@ -299,6 +300,7 @@ const Milvus = {
|
||||
similarityThreshold = 0.25,
|
||||
topN = 4,
|
||||
filterIdentifiers = [],
|
||||
rerank = false,
|
||||
}) {
|
||||
if (!namespace || !input || !LLMConnector)
|
||||
throw new Error("Invalid request to performSimilaritySearch.");
|
||||
@@ -313,14 +315,24 @@ const Milvus = {
|
||||
}
|
||||
|
||||
const queryVector = await LLMConnector.embedTextInput(input);
|
||||
const { contextTexts, sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const { contextTexts, sourceDocuments } = rerank
|
||||
? await this.rerankedSimilarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
query: input,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
})
|
||||
: await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
const sources = sourceDocuments.map((doc, i) => {
|
||||
return { metadata: doc, text: contextTexts[i] };
|
||||
@@ -368,6 +380,31 @@ const Milvus = {
|
||||
});
|
||||
return result;
|
||||
},
|
||||
rerankedSimilarityResponse: async function ({
|
||||
client,
|
||||
namespace,
|
||||
query,
|
||||
queryVector,
|
||||
topN = 4,
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const { sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN: searchLimit,
|
||||
filterIdentifiers,
|
||||
});
|
||||
return await rerankDocuments(query, sourceDocuments, {
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
});
|
||||
},
|
||||
"namespace-stats": async function (reqBody = {}) {
|
||||
const { namespace = null } = reqBody;
|
||||
if (!namespace) throw new Error("namespace required");
|
||||
|
||||
@@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
|
||||
const { v4: uuidv4 } = require("uuid");
|
||||
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
|
||||
const PineconeDB = {
|
||||
name: "Pinecone",
|
||||
@@ -76,6 +77,31 @@ const PineconeDB = {
|
||||
|
||||
return result;
|
||||
},
|
||||
rerankedSimilarityResponse: async function ({
|
||||
client,
|
||||
namespace,
|
||||
query,
|
||||
queryVector,
|
||||
topN = 4,
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const { sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN: searchLimit,
|
||||
filterIdentifiers,
|
||||
});
|
||||
return await rerankDocuments(query, sourceDocuments, {
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
});
|
||||
},
|
||||
namespace: async function (index, namespace = null) {
|
||||
if (!namespace) throw new Error("No namespace value provided.");
|
||||
const { namespaces } = await index.describeIndexStats();
|
||||
@@ -247,6 +273,7 @@ const PineconeDB = {
|
||||
similarityThreshold = 0.25,
|
||||
topN = 4,
|
||||
filterIdentifiers = [],
|
||||
rerank = false,
|
||||
}) {
|
||||
if (!namespace || !input || !LLMConnector)
|
||||
throw new Error("Invalid request to performSimilaritySearch.");
|
||||
@@ -258,14 +285,24 @@ const PineconeDB = {
|
||||
);
|
||||
|
||||
const queryVector = await LLMConnector.embedTextInput(input);
|
||||
const { contextTexts, sourceDocuments } = await this.similarityResponse({
|
||||
client: pineconeIndex,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const { contextTexts, sourceDocuments } = rerank
|
||||
? await this.rerankedSimilarityResponse({
|
||||
client: pineconeIndex,
|
||||
namespace,
|
||||
query: input,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
})
|
||||
: await this.similarityResponse({
|
||||
client: pineconeIndex,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
const sources = sourceDocuments.map((doc, i) => {
|
||||
return { metadata: doc, text: contextTexts[i] };
|
||||
|
||||
@@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
|
||||
const { v4: uuidv4 } = require("uuid");
|
||||
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
|
||||
const QDrant = {
|
||||
name: "QDrant",
|
||||
@@ -86,6 +87,35 @@ const QDrant = {
|
||||
|
||||
return result;
|
||||
},
|
||||
rerankedSimilarityResponse: async function ({
|
||||
client,
|
||||
namespace,
|
||||
query,
|
||||
queryVector,
|
||||
topN = 4,
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const { sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN: searchLimit,
|
||||
filterIdentifiers,
|
||||
});
|
||||
return await rerankDocuments(
|
||||
query,
|
||||
sourceDocuments.map((doc) => ({ ...doc, score: null })),
|
||||
{
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
}
|
||||
);
|
||||
},
|
||||
namespace: async function (client, namespace = null) {
|
||||
if (!namespace) throw new Error("No namespace value provided.");
|
||||
const collection = await client.getCollection(namespace).catch(() => null);
|
||||
@@ -324,6 +354,7 @@ const QDrant = {
|
||||
similarityThreshold = 0.25,
|
||||
topN = 4,
|
||||
filterIdentifiers = [],
|
||||
rerank = false,
|
||||
}) {
|
||||
if (!namespace || !input || !LLMConnector)
|
||||
throw new Error("Invalid request to performSimilaritySearch.");
|
||||
@@ -338,14 +369,24 @@ const QDrant = {
|
||||
}
|
||||
|
||||
const queryVector = await LLMConnector.embedTextInput(input);
|
||||
const { contextTexts, sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const { contextTexts, sourceDocuments } = rerank
|
||||
? await this.rerankedSimilarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
query: input,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
})
|
||||
: await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
const sources = sourceDocuments.map((metadata, i) => {
|
||||
return { ...metadata, text: contextTexts[i] };
|
||||
|
||||
65
server/utils/vectorDbProviders/rerank.js
Normal file
65
server/utils/vectorDbProviders/rerank.js
Normal file
@@ -0,0 +1,65 @@
|
||||
const { NativeEmbeddingReranker } = require("../EmbeddingRerankers/native");
|
||||
const { sourceIdentifier } = require("../chats");
|
||||
|
||||
async function rerankDocuments(
|
||||
query,
|
||||
documents,
|
||||
options = { topN: 4, similarityThreshold: 0.25, filterIdentifiers: [] }
|
||||
) {
|
||||
const { topN, similarityThreshold, filterIdentifiers } = options;
|
||||
const reranker = new NativeEmbeddingReranker();
|
||||
const result = {
|
||||
contextTexts: [],
|
||||
sourceDocuments: [],
|
||||
scores: [],
|
||||
};
|
||||
|
||||
await reranker
|
||||
.rerank(query, documents, { topK: topN })
|
||||
.then((rerankResults) => {
|
||||
rerankResults.forEach((item) => {
|
||||
if (item.score < similarityThreshold) return;
|
||||
|
||||
const { vector: _, ...rest } = item;
|
||||
if (filterIdentifiers.includes(sourceIdentifier(rest))) {
|
||||
console.log(
|
||||
"A source was filtered from context as it's parent document is pinned."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
result.contextTexts.push(rest.text);
|
||||
result.sourceDocuments.push({
|
||||
...rest,
|
||||
});
|
||||
result.scores.push(item.score);
|
||||
});
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error(e);
|
||||
console.error("rerankDocuments", e.message);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
/**
|
||||
* For reranking, we want to work with a larger number of results than the topN.
|
||||
* This is because the reranker can only rerank the results it it given and we dont auto-expand the results.
|
||||
* We want to give the reranker a larger number of results to work with.
|
||||
*
|
||||
* However, we cannot make this boundless as reranking is expensive and time consuming.
|
||||
* So we limit the number of results to a maximum of 50 and a minimum of 10.
|
||||
* This is a good balance between the number of results to rerank and the cost of reranking
|
||||
* and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base level hardware.
|
||||
*
|
||||
* Benchmarks:
|
||||
* On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
|
||||
*/
|
||||
function getSearchLimit(totalEmbeddings = 0, topN = 4) {
|
||||
return Math.max(10, Math.min(50, Math.ceil(totalEmbeddings * 0.1) || topN));
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
rerankDocuments,
|
||||
getSearchLimit,
|
||||
};
|
||||
@@ -6,6 +6,7 @@ const { v4: uuidv4 } = require("uuid");
|
||||
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
|
||||
const { camelCase } = require("../../helpers/camelcase");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
|
||||
const Weaviate = {
|
||||
name: "Weaviate",
|
||||
@@ -121,6 +122,31 @@ const Weaviate = {
|
||||
|
||||
return result;
|
||||
},
|
||||
rerankedSimilarityResponse: async function ({
|
||||
client,
|
||||
namespace,
|
||||
query,
|
||||
queryVector,
|
||||
topN = 4,
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const { sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN: searchLimit,
|
||||
filterIdentifiers,
|
||||
});
|
||||
return await rerankDocuments(query, sourceDocuments, {
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
});
|
||||
},
|
||||
allNamespaces: async function (client) {
|
||||
try {
|
||||
const { classes = [] } = await client.schema.getter().do();
|
||||
@@ -368,6 +394,7 @@ const Weaviate = {
|
||||
similarityThreshold = 0.25,
|
||||
topN = 4,
|
||||
filterIdentifiers = [],
|
||||
rerank = false,
|
||||
}) {
|
||||
if (!namespace || !input || !LLMConnector)
|
||||
throw new Error("Invalid request to performSimilaritySearch.");
|
||||
@@ -382,14 +409,24 @@ const Weaviate = {
|
||||
}
|
||||
|
||||
const queryVector = await LLMConnector.embedTextInput(input);
|
||||
const { contextTexts, sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const { contextTexts, sourceDocuments } = rerank
|
||||
? await this.rerankedSimilarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
query: input,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
})
|
||||
: await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
const sources = sourceDocuments.map((metadata, i) => {
|
||||
return { ...metadata, text: contextTexts[i] };
|
||||
|
||||
@@ -10,6 +10,7 @@ const { v4: uuidv4 } = require("uuid");
|
||||
const { storeVectorResult, cachedVectorInformation } = require("../../files");
|
||||
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
|
||||
const { sourceIdentifier } = require("../../chats");
|
||||
const { rerankDocuments, getSearchLimit } = require("../rerank");
|
||||
|
||||
// Zilliz is basically a copy of Milvus DB class with a different constructor
|
||||
// to connect to the cloud
|
||||
@@ -292,6 +293,7 @@ const Zilliz = {
|
||||
similarityThreshold = 0.25,
|
||||
topN = 4,
|
||||
filterIdentifiers = [],
|
||||
rerank = false,
|
||||
}) {
|
||||
if (!namespace || !input || !LLMConnector)
|
||||
throw new Error("Invalid request to performSimilaritySearch.");
|
||||
@@ -306,14 +308,24 @@ const Zilliz = {
|
||||
}
|
||||
|
||||
const queryVector = await LLMConnector.embedTextInput(input);
|
||||
const { contextTexts, sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
const { contextTexts, sourceDocuments } = rerank
|
||||
? await this.rerankedSimilarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
query: input,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
})
|
||||
: await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN,
|
||||
filterIdentifiers,
|
||||
});
|
||||
|
||||
const sources = sourceDocuments.map((doc, i) => {
|
||||
return { metadata: doc, text: contextTexts[i] };
|
||||
@@ -359,6 +371,31 @@ const Zilliz = {
|
||||
});
|
||||
return result;
|
||||
},
|
||||
rerankedSimilarityResponse: async function ({
|
||||
client,
|
||||
namespace,
|
||||
query,
|
||||
queryVector,
|
||||
topN = 4,
|
||||
similarityThreshold = 0.25,
|
||||
filterIdentifiers = [],
|
||||
}) {
|
||||
const totalEmbeddings = await this.namespaceCount(namespace);
|
||||
const searchLimit = getSearchLimit(totalEmbeddings, topN);
|
||||
const { sourceDocuments } = await this.similarityResponse({
|
||||
client,
|
||||
namespace,
|
||||
queryVector,
|
||||
similarityThreshold,
|
||||
topN: searchLimit,
|
||||
filterIdentifiers,
|
||||
});
|
||||
return await rerankDocuments(query, sourceDocuments, {
|
||||
topN,
|
||||
similarityThreshold,
|
||||
filterIdentifiers,
|
||||
});
|
||||
},
|
||||
"namespace-stats": async function (reqBody = {}) {
|
||||
const { namespace = null } = reqBody;
|
||||
if (!namespace) throw new Error("namespace required");
|
||||
|
||||
Reference in New Issue
Block a user