diff --git a/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts b/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts index 4d6ac3c72..c60b9431f 100644 --- a/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts +++ b/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts @@ -1,7 +1,5 @@ import { Redis, RedisOptions } from 'ioredis' -import { isEqual } from 'lodash' import { BufferMemory, BufferMemoryInput } from 'langchain/memory' -import { RedisChatMessageHistory, RedisChatMessageHistoryInput } from '@langchain/community/stores/message/ioredis' import { mapStoredMessageToChatMessage, BaseMessage, AIMessage, HumanMessage } from '@langchain/core/messages' import { INode, INodeData, INodeParams, ICommonObject, MessageType, IMessage, MemoryMethods, FlowiseMemory } from '../../../src/Interface' import { @@ -12,42 +10,6 @@ import { mapChatMessageToBaseMessage } from '../../../src/utils' -let redisClientSingleton: Redis -let redisClientOption: RedisOptions -let redisClientUrl: string - -const getRedisClientbyOption = (option: RedisOptions) => { - if (!redisClientSingleton) { - // if client doesn't exists - redisClientSingleton = new Redis(option) - redisClientOption = option - return redisClientSingleton - } else if (redisClientSingleton && !isEqual(option, redisClientOption)) { - // if client exists but option changed - redisClientSingleton.quit() - redisClientSingleton = new Redis(option) - redisClientOption = option - return redisClientSingleton - } - return redisClientSingleton -} - -const getRedisClientbyUrl = (url: string) => { - if (!redisClientSingleton) { - // if client doesn't exists - redisClientSingleton = new Redis(url) - redisClientUrl = url - return redisClientSingleton - } else if (redisClientSingleton && url !== redisClientUrl) { - // if client exists but option changed - redisClientSingleton.quit() - redisClientSingleton = new Redis(url) - redisClientUrl = url - return redisClientSingleton - } - return redisClientSingleton -} - class RedisBackedChatMemory_Memory implements INode { label: string name: string @@ -114,11 +76,11 @@ class RedisBackedChatMemory_Memory implements INode { } async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { - return await initalizeRedis(nodeData, options) + return await initializeRedis(nodeData, options) } } -const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Promise => { +const initializeRedis = async (nodeData: INodeData, options: ICommonObject): Promise => { const sessionTTL = nodeData.inputs?.sessionTTL as number const memoryKey = nodeData.inputs?.memoryKey as string const sessionId = nodeData.inputs?.sessionId as string @@ -127,73 +89,55 @@ const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Prom const credentialData = await getCredentialData(nodeData.credential ?? '', options) const redisUrl = getCredentialParam('redisUrl', credentialData, nodeData) - let client: Redis - - if (!redisUrl || redisUrl === '') { - const username = getCredentialParam('redisCacheUser', credentialData, nodeData) - const password = getCredentialParam('redisCachePwd', credentialData, nodeData) - const portStr = getCredentialParam('redisCachePort', credentialData, nodeData) - const host = getCredentialParam('redisCacheHost', credentialData, nodeData) - const sslEnabled = getCredentialParam('redisCacheSslEnabled', credentialData, nodeData) - - const tlsOptions = sslEnabled === true ? { tls: { rejectUnauthorized: false } } : {} - - client = getRedisClientbyOption({ - port: portStr ? parseInt(portStr) : 6379, - host, - username, - password, - ...tlsOptions - }) - } else { - client = getRedisClientbyUrl(redisUrl) - } - - let obj: RedisChatMessageHistoryInput = { - sessionId, - client - } - - if (sessionTTL) { - obj = { - ...obj, - sessionTTL - } - } - - const redisChatMessageHistory = new RedisChatMessageHistory(obj) + const redisOptions = redisUrl + ? redisUrl + : ({ + port: parseInt(getCredentialParam('redisCachePort', credentialData, nodeData) || '6379'), + host: getCredentialParam('redisCacheHost', credentialData, nodeData), + username: getCredentialParam('redisCacheUser', credentialData, nodeData), + password: getCredentialParam('redisCachePwd', credentialData, nodeData), + tls: getCredentialParam('redisCacheSslEnabled', credentialData, nodeData) ? { rejectUnauthorized: false } : undefined + } as RedisOptions) const memory = new BufferMemoryExtended({ memoryKey: memoryKey ?? 'chat_history', - chatHistory: redisChatMessageHistory, sessionId, windowSize, - redisClient: client, - sessionTTL + sessionTTL, + redisOptions }) return memory } interface BufferMemoryExtendedInput { - redisClient: Redis sessionId: string windowSize?: number sessionTTL?: number + redisOptions: RedisOptions | string } class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { sessionId = '' - redisClient: Redis windowSize?: number sessionTTL?: number + redisOptions: RedisOptions | string constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) { super(fields) this.sessionId = fields.sessionId - this.redisClient = fields.redisClient this.windowSize = fields.windowSize this.sessionTTL = fields.sessionTTL + this.redisOptions = fields.redisOptions + } + + private async withRedisClient(fn: (client: Redis) => Promise): Promise { + const client = typeof this.redisOptions === 'string' ? new Redis(this.redisOptions) : new Redis(this.redisOptions) + try { + return await fn(client) + } finally { + await client.quit() + } } async getChatMessages( @@ -201,46 +145,46 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { returnBaseMessages = false, prependMessages?: IMessage[] ): Promise { - if (!this.redisClient) return [] - - const id = overrideSessionId ? overrideSessionId : this.sessionId - const rawStoredMessages = await this.redisClient.lrange(id, this.windowSize ? this.windowSize * -1 : 0, -1) - const orderedMessages = rawStoredMessages.reverse().map((message) => JSON.parse(message)) - const baseMessages = orderedMessages.map(mapStoredMessageToChatMessage) - if (prependMessages?.length) { - baseMessages.unshift(...(await mapChatMessageToBaseMessage(prependMessages))) - } - return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages) + return this.withRedisClient(async (client) => { + const id = overrideSessionId ? overrideSessionId : this.sessionId + const rawStoredMessages = await client.lrange(id, this.windowSize ? this.windowSize * -1 : 0, -1) + const orderedMessages = rawStoredMessages.reverse().map((message) => JSON.parse(message)) + const baseMessages = orderedMessages.map(mapStoredMessageToChatMessage) + if (prependMessages?.length) { + baseMessages.unshift(...(await mapChatMessageToBaseMessage(prependMessages))) + } + return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages) + }) } async addChatMessages(msgArray: { text: string; type: MessageType }[], overrideSessionId = ''): Promise { - if (!this.redisClient) return + await this.withRedisClient(async (client) => { + const id = overrideSessionId ? overrideSessionId : this.sessionId + const input = msgArray.find((msg) => msg.type === 'userMessage') + const output = msgArray.find((msg) => msg.type === 'apiMessage') - const id = overrideSessionId ? overrideSessionId : this.sessionId - const input = msgArray.find((msg) => msg.type === 'userMessage') - const output = msgArray.find((msg) => msg.type === 'apiMessage') + if (input) { + const newInputMessage = new HumanMessage(input.text) + const messageToAdd = [newInputMessage].map((msg) => msg.toDict()) + await client.lpush(id, JSON.stringify(messageToAdd[0])) + if (this.sessionTTL) await client.expire(id, this.sessionTTL) + } - if (input) { - const newInputMessage = new HumanMessage(input.text) - const messageToAdd = [newInputMessage].map((msg) => msg.toDict()) - await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0])) - if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL) - } - - if (output) { - const newOutputMessage = new AIMessage(output.text) - const messageToAdd = [newOutputMessage].map((msg) => msg.toDict()) - await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0])) - if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL) - } + if (output) { + const newOutputMessage = new AIMessage(output.text) + const messageToAdd = [newOutputMessage].map((msg) => msg.toDict()) + await client.lpush(id, JSON.stringify(messageToAdd[0])) + if (this.sessionTTL) await client.expire(id, this.sessionTTL) + } + }) } async clearChatMessages(overrideSessionId = ''): Promise { - if (!this.redisClient) return - - const id = overrideSessionId ? overrideSessionId : this.sessionId - await this.redisClient.del(id) - await this.clear() + await this.withRedisClient(async (client) => { + const id = overrideSessionId ? overrideSessionId : this.sessionId + await client.del(id) + await this.clear() + }) } }