fix: change data source lifecycle on agent memory mysql saver (#3578)

* fix: change data source lifecycle on agent memory mysql saver

* Update mysqlSaver.ts

* Update pgSaver.ts

* linting fix

---------

Co-authored-by: Henry Heng <henryheng@flowiseai.com>
This commit is contained in:
João Paulo 2024-12-06 16:49:49 -03:00 committed by GitHub
parent 371da23986
commit 09d20fa5ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 159 additions and 129 deletions

View File

@ -1,50 +1,46 @@
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph' import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
import { RunnableConfig } from '@langchain/core/runnables' import { RunnableConfig } from '@langchain/core/runnables'
import { BaseMessage } from '@langchain/core/messages' import { BaseMessage } from '@langchain/core/messages'
import { DataSource, QueryRunner } from 'typeorm' import { DataSource } from 'typeorm'
import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface' import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface'
import { IMessage, MemoryMethods } from '../../../src/Interface' import { IMessage, MemoryMethods } from '../../../src/Interface'
import { mapChatMessageToBaseMessage } from '../../../src/utils' import { mapChatMessageToBaseMessage } from '../../../src/utils'
export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {
protected isSetup: boolean protected isSetup: boolean
datasource: DataSource
queryRunner: QueryRunner
config: SaverOptions config: SaverOptions
threadId: string threadId: string
tableName = 'checkpoints' tableName = 'checkpoints'
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) { constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
super(serde) super(serde)
this.config = config this.config = config
const { datasourceOptions, threadId } = config const { threadId } = config
this.threadId = threadId this.threadId = threadId
this.datasource = new DataSource(datasourceOptions)
} }
private async setup(): Promise<void> { private async getDataSource(): Promise<DataSource> {
if (this.isSetup) { const { datasourceOptions } = this.config
return const dataSource = new DataSource(datasourceOptions)
} await dataSource.initialize()
return dataSource
}
private async setup(dataSource: DataSource): Promise<void> {
if (this.isSetup) return
try { try {
const appDataSource = await this.datasource.initialize() const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(`
this.queryRunner = appDataSource.createQueryRunner() CREATE TABLE IF NOT EXISTS ${this.tableName} (
await this.queryRunner.manager.query(` thread_id VARCHAR(255) NOT NULL,
CREATE TABLE IF NOT EXISTS ${this.tableName} ( checkpoint_id VARCHAR(255) NOT NULL,
thread_id VARCHAR(255) NOT NULL, parent_id VARCHAR(255),
checkpoint_id VARCHAR(255) NOT NULL, checkpoint BLOB,
parent_id VARCHAR(255), metadata BLOB,
checkpoint BLOB, PRIMARY KEY (thread_id, checkpoint_id)
metadata BLOB, );`)
PRIMARY KEY (thread_id, checkpoint_id) await queryRunner.release()
);`)
} catch (error) { } catch (error) {
console.error(`Error creating ${this.tableName} table`, error) console.error(`Error creating ${this.tableName} table`, error)
throw new Error(`Error creating ${this.tableName} table`) throw new Error(`Error creating ${this.tableName} table`)
@ -54,79 +50,67 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} }
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> { async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup() const dataSource = await this.getDataSource()
await this.setup(dataSource)
const thread_id = config.configurable?.thread_id || this.threadId const thread_id = config.configurable?.thread_id || this.threadId
const checkpoint_id = config.configurable?.checkpoint_id const checkpoint_id = config.configurable?.checkpoint_id
if (checkpoint_id) { try {
try { const queryRunner = dataSource.createQueryRunner()
const keys = [thread_id, checkpoint_id] const sql = checkpoint_id
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?` ? `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
: `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
const rows = await this.queryRunner.manager.query(sql, keys) const rows = await queryRunner.manager.query(sql, checkpoint_id ? [thread_id, checkpoint_id] : [thread_id])
await queryRunner.release()
if (rows && rows.length > 0) {
return {
config,
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id,
checkpoint_id: rows[0].parent_id
}
}
: undefined
}
}
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
}
} else {
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
const rows = await this.queryRunner.manager.query(sql, keys)
if (rows && rows.length > 0) { if (rows && rows.length > 0) {
const row = rows[0]
return { return {
config: { config: {
configurable: { configurable: {
thread_id: rows[0].thread_id, thread_id: row.thread_id || thread_id,
checkpoint_id: rows[0].checkpoint_id checkpoint_id: row.checkpoint_id || checkpoint_id
} }
}, },
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint, checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata, metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata,
parentConfig: rows[0].parent_id parentConfig: row.parent_id
? { ? {
configurable: { configurable: {
thread_id: rows[0].thread_id, thread_id,
checkpoint_id: rows[0].parent_id checkpoint_id: row.parent_id
} }
} }
: undefined : undefined
} }
} }
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
} }
return undefined return undefined
} }
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> { async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple, void, unknown> {
await this.setup() const dataSource = await this.getDataSource()
const thread_id = config.configurable?.thread_id || this.threadId await this.setup(dataSource)
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${ const queryRunner = dataSource.createQueryRunner()
before ? 'AND checkpoint_id < ?' : ''
} ORDER BY checkpoint_id DESC`
if (limit) {
sql += ` LIMIT ${limit}`
}
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)
try { try {
const rows = await this.queryRunner.manager.query(sql, args) const threadId = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
before ? 'AND checkpoint_id < ?' : ''
} ORDER BY checkpoint_id DESC`
if (limit) {
sql += ` LIMIT ${limit}`
}
const args = [threadId, before?.configurable?.checkpoint_id].filter(Boolean)
const rows = await queryRunner.manager.query(sql, args)
await queryRunner.release()
if (rows && rows.length > 0) { if (rows && rows.length > 0) {
for (const row of rows) { for (const row of rows) {
@ -151,15 +135,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} }
} }
} catch (error) { } catch (error) {
console.error(`Error listing ${this.tableName}`, error) console.error(`Error listing checkpoints`, error)
throw new Error(`Error listing ${this.tableName}`) throw new Error(`Error listing checkpoints`)
} finally {
await dataSource.destroy()
} }
} }
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> { async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
await this.setup() const dataSource = await this.getDataSource()
await this.setup(dataSource)
if (!config.configurable?.checkpoint_id) return {} if (!config.configurable?.checkpoint_id) return {}
try { try {
const queryRunner = dataSource.createQueryRunner()
const row = [ const row = [
config.configurable?.thread_id || this.threadId, config.configurable?.thread_id || this.threadId,
checkpoint.id, checkpoint.id,
@ -172,10 +161,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)` ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)`
await this.queryRunner.manager.query(query, row) await queryRunner.manager.query(query, row)
await queryRunner.release()
} catch (error) { } catch (error) {
console.error('Error saving checkpoint', error) console.error('Error saving checkpoint', error)
throw new Error('Error saving checkpoint') throw new Error('Error saving checkpoint')
} finally {
await dataSource.destroy()
} }
return { return {
@ -187,16 +179,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} }
async delete(threadId: string): Promise<void> { async delete(threadId: string): Promise<void> {
if (!threadId) { if (!threadId) return
return
} const dataSource = await this.getDataSource()
await this.setup() await this.setup(dataSource)
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
try { try {
await this.queryRunner.manager.query(query, [threadId]) const queryRunner = dataSource.createQueryRunner()
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
await queryRunner.manager.query(query, [threadId])
await queryRunner.release()
} catch (error) { } catch (error) {
console.error(`Error deleting thread_id ${threadId}`, error) console.error(`Error deleting thread_id ${threadId}`, error)
} finally {
await dataSource.destroy()
} }
} }
@ -232,6 +228,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
type: m.role type: m.role
}) })
} }
return returnIMessages return returnIMessages
} }
@ -240,6 +237,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} }
async clearChatMessages(overrideSessionId = ''): Promise<void> { async clearChatMessages(overrideSessionId = ''): Promise<void> {
if (!overrideSessionId) return
await this.delete(overrideSessionId) await this.delete(overrideSessionId)
} }
} }

View File

@ -1,42 +1,39 @@
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph' import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
import { RunnableConfig } from '@langchain/core/runnables' import { RunnableConfig } from '@langchain/core/runnables'
import { BaseMessage } from '@langchain/core/messages' import { BaseMessage } from '@langchain/core/messages'
import { DataSource, QueryRunner } from 'typeorm' import { DataSource } from 'typeorm'
import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface' import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface'
import { IMessage, MemoryMethods } from '../../../src/Interface' import { IMessage, MemoryMethods } from '../../../src/Interface'
import { mapChatMessageToBaseMessage } from '../../../src/utils' import { mapChatMessageToBaseMessage } from '../../../src/utils'
export class PostgresSaver extends BaseCheckpointSaver implements MemoryMethods { export class PostgresSaver extends BaseCheckpointSaver implements MemoryMethods {
protected isSetup: boolean protected isSetup: boolean
datasource: DataSource
queryRunner: QueryRunner
config: SaverOptions config: SaverOptions
threadId: string threadId: string
tableName = 'checkpoints' tableName = 'checkpoints'
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) { constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
super(serde) super(serde)
this.config = config this.config = config
const { datasourceOptions, threadId } = config const { threadId } = config
this.threadId = threadId this.threadId = threadId
this.datasource = new DataSource(datasourceOptions)
} }
private async setup(): Promise<void> { private async getDataSource(): Promise<DataSource> {
const { datasourceOptions } = this.config
const dataSource = new DataSource(datasourceOptions)
await dataSource.initialize()
return dataSource
}
private async setup(dataSource: DataSource): Promise<void> {
if (this.isSetup) { if (this.isSetup) {
return return
} }
try { try {
const appDataSource = await this.datasource.initialize() const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(`
this.queryRunner = appDataSource.createQueryRunner()
await this.queryRunner.manager.query(`
CREATE TABLE IF NOT EXISTS ${this.tableName} ( CREATE TABLE IF NOT EXISTS ${this.tableName} (
thread_id TEXT NOT NULL, thread_id TEXT NOT NULL,
checkpoint_id TEXT NOT NULL, checkpoint_id TEXT NOT NULL,
@ -44,6 +41,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
checkpoint BYTEA, checkpoint BYTEA,
metadata BYTEA, metadata BYTEA,
PRIMARY KEY (thread_id, checkpoint_id));`) PRIMARY KEY (thread_id, checkpoint_id));`)
await queryRunner.release()
} catch (error) { } catch (error) {
console.error(`Error creating ${this.tableName} table`, error) console.error(`Error creating ${this.tableName} table`, error)
throw new Error(`Error creating ${this.tableName} table`) throw new Error(`Error creating ${this.tableName} table`)
@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} }
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> { async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup() const dataSource = await this.getDataSource()
await this.setup(dataSource)
const thread_id = config.configurable?.thread_id || this.threadId const thread_id = config.configurable?.thread_id || this.threadId
const checkpoint_id = config.configurable?.checkpoint_id const checkpoint_id = config.configurable?.checkpoint_id
if (checkpoint_id) { if (checkpoint_id) {
try { try {
const queryRunner = dataSource.createQueryRunner()
const keys = [thread_id, checkpoint_id] const keys = [thread_id, checkpoint_id]
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = $1 AND checkpoint_id = $2` const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = $1 AND checkpoint_id = $2`
const rows = await this.queryRunner.manager.query(sql, keys) const rows = await queryRunner.manager.query(sql, keys)
await queryRunner.release()
if (rows && rows.length > 0) { if (rows && rows.length > 0) {
return { return {
@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} catch (error) { } catch (error) {
console.error(`Error retrieving ${this.tableName}`, error) console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`) throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
} }
} else { } else {
const keys = [thread_id] try {
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1 ORDER BY checkpoint_id DESC LIMIT 1` const queryRunner = dataSource.createQueryRunner()
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1 ORDER BY checkpoint_id DESC LIMIT 1`
const rows = await this.queryRunner.manager.query(sql, keys) const rows = await queryRunner.manager.query(sql, keys)
await queryRunner.release()
if (rows && rows.length > 0) { if (rows && rows.length > 0) {
return { return {
config: { config: {
configurable: { configurable: {
thread_id: rows[0].thread_id, thread_id: rows[0].thread_id,
checkpoint_id: rows[0].checkpoint_id checkpoint_id: rows[0].checkpoint_id
} }
}, },
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint, checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata, metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
parentConfig: rows[0].parent_id parentConfig: rows[0].parent_id
? { ? {
configurable: { configurable: {
thread_id: rows[0].thread_id, thread_id: rows[0].thread_id,
checkpoint_id: rows[0].parent_id checkpoint_id: rows[0].parent_id
}
} }
} : undefined
: undefined }
} }
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
} }
} }
return undefined return undefined
} }
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> { async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
await this.setup() const dataSource = await this.getDataSource()
await this.setup(dataSource)
const queryRunner = dataSource.createQueryRunner()
const thread_id = config.configurable?.thread_id || this.threadId const thread_id = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1` let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1`
const args = [thread_id] const args = [thread_id]
@ -130,7 +146,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} }
try { try {
const rows = await this.queryRunner.manager.query(sql, args) const rows = await queryRunner.manager.query(sql, args)
await queryRunner.release()
if (rows && rows.length > 0) { if (rows && rows.length > 0) {
for (const row of rows) { for (const row of rows) {
@ -157,13 +174,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} catch (error) { } catch (error) {
console.error(`Error listing ${this.tableName}`, error) console.error(`Error listing ${this.tableName}`, error)
throw new Error(`Error listing ${this.tableName}`) throw new Error(`Error listing ${this.tableName}`)
} finally {
await dataSource.destroy()
} }
} }
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> { async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
await this.setup() const dataSource = await this.getDataSource()
await this.setup(dataSource)
if (!config.configurable?.checkpoint_id) return {} if (!config.configurable?.checkpoint_id) return {}
try { try {
const queryRunner = dataSource.createQueryRunner()
const row = [ const row = [
config.configurable?.thread_id || this.threadId, config.configurable?.thread_id || this.threadId,
checkpoint.id, checkpoint.id,
@ -177,10 +199,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
ON CONFLICT (thread_id, checkpoint_id) ON CONFLICT (thread_id, checkpoint_id)
DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata` DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata`
await this.queryRunner.manager.query(query, row) await queryRunner.manager.query(query, row)
await queryRunner.release()
} catch (error) { } catch (error) {
console.error('Error saving checkpoint', error) console.error('Error saving checkpoint', error)
throw new Error('Error saving checkpoint') throw new Error('Error saving checkpoint')
} finally {
await dataSource.destroy()
} }
return { return {
@ -195,13 +220,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
if (!threadId) { if (!threadId) {
return return
} }
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)
const query = `DELETE FROM "${this.tableName}" WHERE thread_id = $1;` const query = `DELETE FROM "${this.tableName}" WHERE thread_id = $1;`
try { try {
await this.queryRunner.manager.query(query, [threadId]) const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(query, [threadId])
await queryRunner.release()
} catch (error) { } catch (error) {
console.error(`Error deleting thread_id ${threadId}`, error) console.error(`Error deleting thread_id ${threadId}`, error)
} finally {
await dataSource.destroy()
} }
} }