fix: change data source lifecycle on agent memory mysql saver (#3578)
* fix: change data source lifecycle on agent memory mysql saver * Update mysqlSaver.ts * Update pgSaver.ts * linting fix --------- Co-authored-by: Henry Heng <henryheng@flowiseai.com>
This commit is contained in:
parent
371da23986
commit
09d20fa5ad
|
|
@ -1,50 +1,46 @@
|
|||
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
|
||||
import { RunnableConfig } from '@langchain/core/runnables'
|
||||
import { BaseMessage } from '@langchain/core/messages'
|
||||
import { DataSource, QueryRunner } from 'typeorm'
|
||||
import { DataSource } from 'typeorm'
|
||||
import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface'
|
||||
import { IMessage, MemoryMethods } from '../../../src/Interface'
|
||||
import { mapChatMessageToBaseMessage } from '../../../src/utils'
|
||||
|
||||
export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {
|
||||
protected isSetup: boolean
|
||||
|
||||
datasource: DataSource
|
||||
|
||||
queryRunner: QueryRunner
|
||||
|
||||
config: SaverOptions
|
||||
|
||||
threadId: string
|
||||
|
||||
tableName = 'checkpoints'
|
||||
|
||||
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
|
||||
super(serde)
|
||||
this.config = config
|
||||
const { datasourceOptions, threadId } = config
|
||||
const { threadId } = config
|
||||
this.threadId = threadId
|
||||
this.datasource = new DataSource(datasourceOptions)
|
||||
}
|
||||
|
||||
private async setup(): Promise<void> {
|
||||
if (this.isSetup) {
|
||||
return
|
||||
}
|
||||
private async getDataSource(): Promise<DataSource> {
|
||||
const { datasourceOptions } = this.config
|
||||
const dataSource = new DataSource(datasourceOptions)
|
||||
await dataSource.initialize()
|
||||
return dataSource
|
||||
}
|
||||
|
||||
private async setup(dataSource: DataSource): Promise<void> {
|
||||
if (this.isSetup) return
|
||||
|
||||
try {
|
||||
const appDataSource = await this.datasource.initialize()
|
||||
|
||||
this.queryRunner = appDataSource.createQueryRunner()
|
||||
await this.queryRunner.manager.query(`
|
||||
CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
||||
thread_id VARCHAR(255) NOT NULL,
|
||||
checkpoint_id VARCHAR(255) NOT NULL,
|
||||
parent_id VARCHAR(255),
|
||||
checkpoint BLOB,
|
||||
metadata BLOB,
|
||||
PRIMARY KEY (thread_id, checkpoint_id)
|
||||
);`)
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
await queryRunner.manager.query(`
|
||||
CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
||||
thread_id VARCHAR(255) NOT NULL,
|
||||
checkpoint_id VARCHAR(255) NOT NULL,
|
||||
parent_id VARCHAR(255),
|
||||
checkpoint BLOB,
|
||||
metadata BLOB,
|
||||
PRIMARY KEY (thread_id, checkpoint_id)
|
||||
);`)
|
||||
await queryRunner.release()
|
||||
} catch (error) {
|
||||
console.error(`Error creating ${this.tableName} table`, error)
|
||||
throw new Error(`Error creating ${this.tableName} table`)
|
||||
|
|
@ -54,79 +50,67 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
}
|
||||
|
||||
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
|
||||
await this.setup()
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
|
||||
const thread_id = config.configurable?.thread_id || this.threadId
|
||||
const checkpoint_id = config.configurable?.checkpoint_id
|
||||
|
||||
if (checkpoint_id) {
|
||||
try {
|
||||
const keys = [thread_id, checkpoint_id]
|
||||
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
|
||||
try {
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
const sql = checkpoint_id
|
||||
? `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
|
||||
: `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
|
||||
|
||||
const rows = await this.queryRunner.manager.query(sql, keys)
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
return {
|
||||
config,
|
||||
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
|
||||
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
|
||||
parentConfig: rows[0].parent_id
|
||||
? {
|
||||
configurable: {
|
||||
thread_id,
|
||||
checkpoint_id: rows[0].parent_id
|
||||
}
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error retrieving ${this.tableName}`, error)
|
||||
throw new Error(`Error retrieving ${this.tableName}`)
|
||||
}
|
||||
} else {
|
||||
const keys = [thread_id]
|
||||
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
|
||||
|
||||
const rows = await this.queryRunner.manager.query(sql, keys)
|
||||
const rows = await queryRunner.manager.query(sql, checkpoint_id ? [thread_id, checkpoint_id] : [thread_id])
|
||||
await queryRunner.release()
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
const row = rows[0]
|
||||
return {
|
||||
config: {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].checkpoint_id
|
||||
thread_id: row.thread_id || thread_id,
|
||||
checkpoint_id: row.checkpoint_id || checkpoint_id
|
||||
}
|
||||
},
|
||||
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
|
||||
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
|
||||
parentConfig: rows[0].parent_id
|
||||
checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint,
|
||||
metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata,
|
||||
parentConfig: row.parent_id
|
||||
? {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].parent_id
|
||||
thread_id,
|
||||
checkpoint_id: row.parent_id
|
||||
}
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error retrieving ${this.tableName}`, error)
|
||||
throw new Error(`Error retrieving ${this.tableName}`)
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
|
||||
await this.setup()
|
||||
const thread_id = config.configurable?.thread_id || this.threadId
|
||||
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
|
||||
before ? 'AND checkpoint_id < ?' : ''
|
||||
} ORDER BY checkpoint_id DESC`
|
||||
if (limit) {
|
||||
sql += ` LIMIT ${limit}`
|
||||
}
|
||||
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)
|
||||
|
||||
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple, void, unknown> {
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
try {
|
||||
const rows = await this.queryRunner.manager.query(sql, args)
|
||||
const threadId = config.configurable?.thread_id || this.threadId
|
||||
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
|
||||
before ? 'AND checkpoint_id < ?' : ''
|
||||
} ORDER BY checkpoint_id DESC`
|
||||
if (limit) {
|
||||
sql += ` LIMIT ${limit}`
|
||||
}
|
||||
const args = [threadId, before?.configurable?.checkpoint_id].filter(Boolean)
|
||||
|
||||
const rows = await queryRunner.manager.query(sql, args)
|
||||
await queryRunner.release()
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
for (const row of rows) {
|
||||
|
|
@ -151,15 +135,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error listing ${this.tableName}`, error)
|
||||
throw new Error(`Error listing ${this.tableName}`)
|
||||
console.error(`Error listing checkpoints`, error)
|
||||
throw new Error(`Error listing checkpoints`)
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
}
|
||||
|
||||
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
|
||||
await this.setup()
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
|
||||
if (!config.configurable?.checkpoint_id) return {}
|
||||
try {
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
const row = [
|
||||
config.configurable?.thread_id || this.threadId,
|
||||
checkpoint.id,
|
||||
|
|
@ -172,10 +161,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
VALUES (?, ?, ?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)`
|
||||
|
||||
await this.queryRunner.manager.query(query, row)
|
||||
await queryRunner.manager.query(query, row)
|
||||
await queryRunner.release()
|
||||
} catch (error) {
|
||||
console.error('Error saving checkpoint', error)
|
||||
throw new Error('Error saving checkpoint')
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -187,16 +179,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
}
|
||||
|
||||
async delete(threadId: string): Promise<void> {
|
||||
if (!threadId) {
|
||||
return
|
||||
}
|
||||
await this.setup()
|
||||
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
|
||||
if (!threadId) return
|
||||
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
|
||||
try {
|
||||
await this.queryRunner.manager.query(query, [threadId])
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
|
||||
await queryRunner.manager.query(query, [threadId])
|
||||
await queryRunner.release()
|
||||
} catch (error) {
|
||||
console.error(`Error deleting thread_id ${threadId}`, error)
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -232,6 +228,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
type: m.role
|
||||
})
|
||||
}
|
||||
|
||||
return returnIMessages
|
||||
}
|
||||
|
||||
|
|
@ -240,6 +237,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
}
|
||||
|
||||
async clearChatMessages(overrideSessionId = ''): Promise<void> {
|
||||
if (!overrideSessionId) return
|
||||
await this.delete(overrideSessionId)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,42 +1,39 @@
|
|||
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
|
||||
import { RunnableConfig } from '@langchain/core/runnables'
|
||||
import { BaseMessage } from '@langchain/core/messages'
|
||||
import { DataSource, QueryRunner } from 'typeorm'
|
||||
import { DataSource } from 'typeorm'
|
||||
import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface'
|
||||
import { IMessage, MemoryMethods } from '../../../src/Interface'
|
||||
import { mapChatMessageToBaseMessage } from '../../../src/utils'
|
||||
|
||||
export class PostgresSaver extends BaseCheckpointSaver implements MemoryMethods {
|
||||
protected isSetup: boolean
|
||||
|
||||
datasource: DataSource
|
||||
|
||||
queryRunner: QueryRunner
|
||||
|
||||
config: SaverOptions
|
||||
|
||||
threadId: string
|
||||
|
||||
tableName = 'checkpoints'
|
||||
|
||||
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
|
||||
super(serde)
|
||||
this.config = config
|
||||
const { datasourceOptions, threadId } = config
|
||||
const { threadId } = config
|
||||
this.threadId = threadId
|
||||
this.datasource = new DataSource(datasourceOptions)
|
||||
}
|
||||
|
||||
private async setup(): Promise<void> {
|
||||
private async getDataSource(): Promise<DataSource> {
|
||||
const { datasourceOptions } = this.config
|
||||
const dataSource = new DataSource(datasourceOptions)
|
||||
await dataSource.initialize()
|
||||
return dataSource
|
||||
}
|
||||
|
||||
private async setup(dataSource: DataSource): Promise<void> {
|
||||
if (this.isSetup) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const appDataSource = await this.datasource.initialize()
|
||||
|
||||
this.queryRunner = appDataSource.createQueryRunner()
|
||||
await this.queryRunner.manager.query(`
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
await queryRunner.manager.query(`
|
||||
CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
||||
thread_id TEXT NOT NULL,
|
||||
checkpoint_id TEXT NOT NULL,
|
||||
|
|
@ -44,6 +41,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
checkpoint BYTEA,
|
||||
metadata BYTEA,
|
||||
PRIMARY KEY (thread_id, checkpoint_id));`)
|
||||
await queryRunner.release()
|
||||
} catch (error) {
|
||||
console.error(`Error creating ${this.tableName} table`, error)
|
||||
throw new Error(`Error creating ${this.tableName} table`)
|
||||
|
|
@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
}
|
||||
|
||||
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
|
||||
await this.setup()
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
|
||||
const thread_id = config.configurable?.thread_id || this.threadId
|
||||
const checkpoint_id = config.configurable?.checkpoint_id
|
||||
|
||||
if (checkpoint_id) {
|
||||
try {
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
const keys = [thread_id, checkpoint_id]
|
||||
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = $1 AND checkpoint_id = $2`
|
||||
|
||||
const rows = await this.queryRunner.manager.query(sql, keys)
|
||||
const rows = await queryRunner.manager.query(sql, keys)
|
||||
await queryRunner.release()
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
return {
|
||||
|
|
@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
} catch (error) {
|
||||
console.error(`Error retrieving ${this.tableName}`, error)
|
||||
throw new Error(`Error retrieving ${this.tableName}`)
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
} else {
|
||||
const keys = [thread_id]
|
||||
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1 ORDER BY checkpoint_id DESC LIMIT 1`
|
||||
try {
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
const keys = [thread_id]
|
||||
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1 ORDER BY checkpoint_id DESC LIMIT 1`
|
||||
|
||||
const rows = await this.queryRunner.manager.query(sql, keys)
|
||||
const rows = await queryRunner.manager.query(sql, keys)
|
||||
await queryRunner.release()
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
return {
|
||||
config: {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].checkpoint_id
|
||||
}
|
||||
},
|
||||
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
|
||||
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
|
||||
parentConfig: rows[0].parent_id
|
||||
? {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].parent_id
|
||||
if (rows && rows.length > 0) {
|
||||
return {
|
||||
config: {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].checkpoint_id
|
||||
}
|
||||
},
|
||||
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
|
||||
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
|
||||
parentConfig: rows[0].parent_id
|
||||
? {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].parent_id
|
||||
}
|
||||
}
|
||||
}
|
||||
: undefined
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error retrieving ${this.tableName}`, error)
|
||||
throw new Error(`Error retrieving ${this.tableName}`)
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
|
||||
await this.setup()
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
const thread_id = config.configurable?.thread_id || this.threadId
|
||||
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1`
|
||||
const args = [thread_id]
|
||||
|
|
@ -130,7 +146,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
}
|
||||
|
||||
try {
|
||||
const rows = await this.queryRunner.manager.query(sql, args)
|
||||
const rows = await queryRunner.manager.query(sql, args)
|
||||
await queryRunner.release()
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
for (const row of rows) {
|
||||
|
|
@ -157,13 +174,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
} catch (error) {
|
||||
console.error(`Error listing ${this.tableName}`, error)
|
||||
throw new Error(`Error listing ${this.tableName}`)
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
}
|
||||
|
||||
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
|
||||
await this.setup()
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
|
||||
if (!config.configurable?.checkpoint_id) return {}
|
||||
try {
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
const row = [
|
||||
config.configurable?.thread_id || this.threadId,
|
||||
checkpoint.id,
|
||||
|
|
@ -177,10 +199,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
ON CONFLICT (thread_id, checkpoint_id)
|
||||
DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata`
|
||||
|
||||
await this.queryRunner.manager.query(query, row)
|
||||
await queryRunner.manager.query(query, row)
|
||||
await queryRunner.release()
|
||||
} catch (error) {
|
||||
console.error('Error saving checkpoint', error)
|
||||
throw new Error('Error saving checkpoint')
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -195,13 +220,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||
if (!threadId) {
|
||||
return
|
||||
}
|
||||
await this.setup()
|
||||
|
||||
const dataSource = await this.getDataSource()
|
||||
await this.setup(dataSource)
|
||||
|
||||
const query = `DELETE FROM "${this.tableName}" WHERE thread_id = $1;`
|
||||
|
||||
try {
|
||||
await this.queryRunner.manager.query(query, [threadId])
|
||||
const queryRunner = dataSource.createQueryRunner()
|
||||
await queryRunner.manager.query(query, [threadId])
|
||||
await queryRunner.release()
|
||||
} catch (error) {
|
||||
console.error(`Error deleting thread_id ${threadId}`, error)
|
||||
} finally {
|
||||
await dataSource.destroy()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue