diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index 1f6d9fddb..99c5a2f63 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -1,5 +1,5 @@ import { INode, INodeData, INodeParams } from '../../../src/Interface' -import { SqlDatabaseChain } from 'langchain/chains' +import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains' import { getBaseClasses } from '../../../src/utils' import { DataSource } from 'typeorm' import { SqlDatabase } from 'langchain/sql_db' @@ -51,31 +51,42 @@ class SqlDatabaseChain_Chains implements INode { } async init(nodeData: INodeData): Promise { - const databaseType = nodeData.inputs?.database + const databaseType = nodeData.inputs?.database as 'sqlite' const llm = nodeData.inputs?.llm as BaseLLM const dbFilePath = nodeData.inputs?.dbFilePath - const datasource = new DataSource({ - type: databaseType, - database: dbFilePath - }) - - const db = await SqlDatabase.fromDataSourceParams({ - appDataSource: datasource - }) - - const chain = new SqlDatabaseChain({ - llm, - database: db - }) + const chain = await getSQLDBChain(databaseType, dbFilePath, llm) return chain } async run(nodeData: INodeData, input: string): Promise { - const chain = nodeData.instance as SqlDatabaseChain + const databaseType = nodeData.inputs?.database as 'sqlite' + const llm = nodeData.inputs?.llm as BaseLLM + const dbFilePath = nodeData.inputs?.dbFilePath + + const chain = await getSQLDBChain(databaseType, dbFilePath, llm) const res = await chain.run(input) return res } } +const getSQLDBChain = async (databaseType: 'sqlite', dbFilePath: string, llm: BaseLLM) => { + const datasource = new DataSource({ + type: databaseType, + database: dbFilePath + }) + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: datasource + }) + + const obj: SqlDatabaseChainInput = { + llm, + database: db + } + + const chain = new SqlDatabaseChain(obj) + return chain +} + module.exports = { nodeClass: SqlDatabaseChain_Chains }