248 lines
9.0 KiB
TypeScript
248 lines
9.0 KiB
TypeScript
import { Redis } from 'ioredis'
|
|
import { RedisByteStore } from '@langchain/community/storage/ioredis'
|
|
import { Embeddings, EmbeddingsInterface } from '@langchain/core/embeddings'
|
|
import { CacheBackedEmbeddingsFields } from 'langchain/embeddings/cache_backed'
|
|
import { getBaseClasses, getCredentialData, getCredentialParam, ICommonObject, INode, INodeData, INodeParams } from '../../../src'
|
|
import { BaseStore } from '@langchain/core/stores'
|
|
import { insecureHash } from '@langchain/core/utils/hash'
|
|
import { Document } from '@langchain/core/documents'
|
|
|
|
class RedisEmbeddingsCache implements INode {
|
|
label: string
|
|
name: string
|
|
version: number
|
|
description: string
|
|
type: string
|
|
icon: string
|
|
category: string
|
|
baseClasses: string[]
|
|
inputs: INodeParams[]
|
|
credential: INodeParams
|
|
|
|
constructor() {
|
|
this.label = 'Redis Embeddings Cache'
|
|
this.name = 'redisEmbeddingsCache'
|
|
this.version = 1.0
|
|
this.type = 'RedisEmbeddingsCache'
|
|
this.description = 'Cache generated Embeddings in Redis to avoid needing to recompute them.'
|
|
this.icon = 'redis.svg'
|
|
this.category = 'Cache'
|
|
this.baseClasses = [this.type, ...getBaseClasses(CacheBackedEmbeddings)]
|
|
this.credential = {
|
|
label: 'Connect Credential',
|
|
name: 'credential',
|
|
type: 'credential',
|
|
optional: true,
|
|
credentialNames: ['redisCacheApi', 'redisCacheUrlApi']
|
|
}
|
|
this.inputs = [
|
|
{
|
|
label: 'Embeddings',
|
|
name: 'embeddings',
|
|
type: 'Embeddings'
|
|
},
|
|
{
|
|
label: 'Time to Live (ms)',
|
|
name: 'ttl',
|
|
type: 'number',
|
|
step: 10,
|
|
default: 60 * 60,
|
|
optional: true,
|
|
additionalParams: true
|
|
},
|
|
{
|
|
label: 'Namespace',
|
|
name: 'namespace',
|
|
type: 'string',
|
|
optional: true,
|
|
additionalParams: true
|
|
}
|
|
]
|
|
}
|
|
|
|
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
|
|
let ttl = nodeData.inputs?.ttl as string
|
|
const namespace = nodeData.inputs?.namespace as string
|
|
const underlyingEmbeddings = nodeData.inputs?.embeddings as Embeddings
|
|
|
|
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
|
|
const redisUrl = getCredentialParam('redisUrl', credentialData, nodeData)
|
|
|
|
let client: Redis
|
|
if (!redisUrl || redisUrl === '') {
|
|
const username = getCredentialParam('redisCacheUser', credentialData, nodeData)
|
|
const password = getCredentialParam('redisCachePwd', credentialData, nodeData)
|
|
const portStr = getCredentialParam('redisCachePort', credentialData, nodeData)
|
|
const host = getCredentialParam('redisCacheHost', credentialData, nodeData)
|
|
const sslEnabled = getCredentialParam('redisCacheSslEnabled', credentialData, nodeData)
|
|
|
|
const tlsOptions = sslEnabled === true ? { tls: { rejectUnauthorized: false } } : {}
|
|
|
|
client = new Redis({
|
|
port: portStr ? parseInt(portStr) : 6379,
|
|
host,
|
|
username,
|
|
password,
|
|
keepAlive:
|
|
process.env.REDIS_KEEP_ALIVE && !isNaN(parseInt(process.env.REDIS_KEEP_ALIVE, 10))
|
|
? parseInt(process.env.REDIS_KEEP_ALIVE, 10)
|
|
: undefined,
|
|
...tlsOptions
|
|
})
|
|
} else {
|
|
client = new Redis(redisUrl, {
|
|
keepAlive:
|
|
process.env.REDIS_KEEP_ALIVE && !isNaN(parseInt(process.env.REDIS_KEEP_ALIVE, 10))
|
|
? parseInt(process.env.REDIS_KEEP_ALIVE, 10)
|
|
: undefined
|
|
})
|
|
}
|
|
|
|
ttl ??= '3600'
|
|
let ttlNumber = parseInt(ttl, 10)
|
|
const redisStore = new RedisByteStore({
|
|
client: client,
|
|
ttl: ttlNumber
|
|
})
|
|
|
|
const store = CacheBackedEmbeddings.fromBytesStore(underlyingEmbeddings, redisStore, {
|
|
namespace: namespace,
|
|
redisClient: client
|
|
})
|
|
|
|
return store
|
|
}
|
|
}
|
|
|
|
class CacheBackedEmbeddings extends Embeddings {
|
|
protected underlyingEmbeddings: EmbeddingsInterface
|
|
|
|
protected documentEmbeddingStore: BaseStore<string, number[]>
|
|
|
|
protected redisClient?: Redis
|
|
|
|
constructor(fields: CacheBackedEmbeddingsFields & { redisClient?: Redis }) {
|
|
super(fields)
|
|
this.underlyingEmbeddings = fields.underlyingEmbeddings
|
|
this.documentEmbeddingStore = fields.documentEmbeddingStore
|
|
this.redisClient = fields.redisClient
|
|
}
|
|
|
|
async embedQuery(document: string): Promise<number[]> {
|
|
const res = this.underlyingEmbeddings.embedQuery(document)
|
|
this.redisClient?.quit()
|
|
return res
|
|
}
|
|
|
|
async embedDocuments(documents: string[]): Promise<number[][]> {
|
|
const vectors = await this.documentEmbeddingStore.mget(documents)
|
|
const missingIndicies = []
|
|
const missingDocuments = []
|
|
for (let i = 0; i < vectors.length; i += 1) {
|
|
if (vectors[i] === undefined) {
|
|
missingIndicies.push(i)
|
|
missingDocuments.push(documents[i])
|
|
}
|
|
}
|
|
if (missingDocuments.length) {
|
|
const missingVectors = await this.underlyingEmbeddings.embedDocuments(missingDocuments)
|
|
const keyValuePairs: [string, number[]][] = missingDocuments.map((document, i) => [document, missingVectors[i]])
|
|
await this.documentEmbeddingStore.mset(keyValuePairs)
|
|
for (let i = 0; i < missingIndicies.length; i += 1) {
|
|
vectors[missingIndicies[i]] = missingVectors[i]
|
|
}
|
|
}
|
|
this.redisClient?.quit()
|
|
return vectors as number[][]
|
|
}
|
|
|
|
static fromBytesStore(
|
|
underlyingEmbeddings: EmbeddingsInterface,
|
|
documentEmbeddingStore: BaseStore<string, Uint8Array>,
|
|
options?: {
|
|
namespace?: string
|
|
redisClient?: Redis
|
|
}
|
|
) {
|
|
const encoder = new TextEncoder()
|
|
const decoder = new TextDecoder()
|
|
const encoderBackedStore = new EncoderBackedStore<string, number[], Uint8Array>({
|
|
store: documentEmbeddingStore,
|
|
keyEncoder: (key) => (options?.namespace ?? '') + insecureHash(key),
|
|
valueSerializer: (value) => encoder.encode(JSON.stringify(value)),
|
|
valueDeserializer: (serializedValue) => JSON.parse(decoder.decode(serializedValue))
|
|
})
|
|
return new this({
|
|
underlyingEmbeddings,
|
|
documentEmbeddingStore: encoderBackedStore,
|
|
redisClient: options?.redisClient
|
|
})
|
|
}
|
|
}
|
|
|
|
class EncoderBackedStore<K, V, SerializedType = any> extends BaseStore<K, V> {
|
|
lc_namespace = ['langchain', 'storage']
|
|
|
|
store: BaseStore<string, SerializedType>
|
|
|
|
keyEncoder: (key: K) => string
|
|
|
|
valueSerializer: (value: V) => SerializedType
|
|
|
|
valueDeserializer: (value: SerializedType) => V
|
|
|
|
constructor(fields: {
|
|
store: BaseStore<string, SerializedType>
|
|
keyEncoder: (key: K) => string
|
|
valueSerializer: (value: V) => SerializedType
|
|
valueDeserializer: (value: SerializedType) => V
|
|
}) {
|
|
super(fields)
|
|
this.store = fields.store
|
|
this.keyEncoder = fields.keyEncoder
|
|
this.valueSerializer = fields.valueSerializer
|
|
this.valueDeserializer = fields.valueDeserializer
|
|
}
|
|
|
|
async mget(keys: K[]): Promise<(V | undefined)[]> {
|
|
const encodedKeys = keys.map(this.keyEncoder)
|
|
const values = await this.store.mget(encodedKeys)
|
|
return values.map((value) => {
|
|
if (value === undefined) {
|
|
return undefined
|
|
}
|
|
return this.valueDeserializer(value)
|
|
})
|
|
}
|
|
|
|
async mset(keyValuePairs: [K, V][]): Promise<void> {
|
|
const encodedPairs: [string, SerializedType][] = keyValuePairs.map(([key, value]) => [
|
|
this.keyEncoder(key),
|
|
this.valueSerializer(value)
|
|
])
|
|
return this.store.mset(encodedPairs)
|
|
}
|
|
|
|
async mdelete(keys: K[]): Promise<void> {
|
|
const encodedKeys = keys.map(this.keyEncoder)
|
|
return this.store.mdelete(encodedKeys)
|
|
}
|
|
|
|
async *yieldKeys(prefix?: string | undefined): AsyncGenerator<string | K> {
|
|
yield* this.store.yieldKeys(prefix)
|
|
}
|
|
}
|
|
|
|
export function createDocumentStoreFromByteStore(store: BaseStore<string, Uint8Array>) {
|
|
const encoder = new TextEncoder()
|
|
const decoder = new TextDecoder()
|
|
return new EncoderBackedStore({
|
|
store,
|
|
keyEncoder: (key: string) => key,
|
|
valueSerializer: (doc: Document) => encoder.encode(JSON.stringify({ pageContent: doc.pageContent, metadata: doc.metadata })),
|
|
valueDeserializer: (bytes: Uint8Array) => new Document(JSON.parse(decoder.decode(bytes)))
|
|
})
|
|
}
|
|
|
|
module.exports = { nodeClass: RedisEmbeddingsCache }
|