add bugfix for SQLDBChain

This commit is contained in:
Henry 2023-05-01 13:54:14 +01:00
parent a3864fbccb
commit 04d0170c6b
1 changed files with 27 additions and 16 deletions

View File

@ -1,5 +1,5 @@
import { INode, INodeData, INodeParams } from '../../../src/Interface' import { INode, INodeData, INodeParams } from '../../../src/Interface'
import { SqlDatabaseChain } from 'langchain/chains' import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains'
import { getBaseClasses } from '../../../src/utils' import { getBaseClasses } from '../../../src/utils'
import { DataSource } from 'typeorm' import { DataSource } from 'typeorm'
import { SqlDatabase } from 'langchain/sql_db' import { SqlDatabase } from 'langchain/sql_db'
@ -51,31 +51,42 @@ class SqlDatabaseChain_Chains implements INode {
} }
async init(nodeData: INodeData): Promise<any> { async init(nodeData: INodeData): Promise<any> {
const databaseType = nodeData.inputs?.database const databaseType = nodeData.inputs?.database as 'sqlite'
const llm = nodeData.inputs?.llm as BaseLLM const llm = nodeData.inputs?.llm as BaseLLM
const dbFilePath = nodeData.inputs?.dbFilePath const dbFilePath = nodeData.inputs?.dbFilePath
const datasource = new DataSource({ const chain = await getSQLDBChain(databaseType, dbFilePath, llm)
type: databaseType,
database: dbFilePath
})
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource
})
const chain = new SqlDatabaseChain({
llm,
database: db
})
return chain return chain
} }
async run(nodeData: INodeData, input: string): Promise<string> { async run(nodeData: INodeData, input: string): Promise<string> {
const chain = nodeData.instance as SqlDatabaseChain const databaseType = nodeData.inputs?.database as 'sqlite'
const llm = nodeData.inputs?.llm as BaseLLM
const dbFilePath = nodeData.inputs?.dbFilePath
const chain = await getSQLDBChain(databaseType, dbFilePath, llm)
const res = await chain.run(input) const res = await chain.run(input)
return res return res
} }
} }
const getSQLDBChain = async (databaseType: 'sqlite', dbFilePath: string, llm: BaseLLM) => {
const datasource = new DataSource({
type: databaseType,
database: dbFilePath
})
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource
})
const obj: SqlDatabaseChainInput = {
llm,
database: db
}
const chain = new SqlDatabaseChain(obj)
return chain
}
module.exports = { nodeClass: SqlDatabaseChain_Chains } module.exports = { nodeClass: SqlDatabaseChain_Chains }