add bugfix for SQLDBChain
This commit is contained in:
parent
a3864fbccb
commit
04d0170c6b
|
|
@ -1,5 +1,5 @@
|
|||
import { INode, INodeData, INodeParams } from '../../../src/Interface'
|
||||
import { SqlDatabaseChain } from 'langchain/chains'
|
||||
import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains'
|
||||
import { getBaseClasses } from '../../../src/utils'
|
||||
import { DataSource } from 'typeorm'
|
||||
import { SqlDatabase } from 'langchain/sql_db'
|
||||
|
|
@ -51,10 +51,26 @@ class SqlDatabaseChain_Chains implements INode {
|
|||
}
|
||||
|
||||
async init(nodeData: INodeData): Promise<any> {
|
||||
const databaseType = nodeData.inputs?.database
|
||||
const databaseType = nodeData.inputs?.database as 'sqlite'
|
||||
const llm = nodeData.inputs?.llm as BaseLLM
|
||||
const dbFilePath = nodeData.inputs?.dbFilePath
|
||||
|
||||
const chain = await getSQLDBChain(databaseType, dbFilePath, llm)
|
||||
return chain
|
||||
}
|
||||
|
||||
async run(nodeData: INodeData, input: string): Promise<string> {
|
||||
const databaseType = nodeData.inputs?.database as 'sqlite'
|
||||
const llm = nodeData.inputs?.llm as BaseLLM
|
||||
const dbFilePath = nodeData.inputs?.dbFilePath
|
||||
|
||||
const chain = await getSQLDBChain(databaseType, dbFilePath, llm)
|
||||
const res = await chain.run(input)
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
const getSQLDBChain = async (databaseType: 'sqlite', dbFilePath: string, llm: BaseLLM) => {
|
||||
const datasource = new DataSource({
|
||||
type: databaseType,
|
||||
database: dbFilePath
|
||||
|
|
@ -64,18 +80,13 @@ class SqlDatabaseChain_Chains implements INode {
|
|||
appDataSource: datasource
|
||||
})
|
||||
|
||||
const chain = new SqlDatabaseChain({
|
||||
const obj: SqlDatabaseChainInput = {
|
||||
llm,
|
||||
database: db
|
||||
})
|
||||
}
|
||||
|
||||
const chain = new SqlDatabaseChain(obj)
|
||||
return chain
|
||||
}
|
||||
|
||||
async run(nodeData: INodeData, input: string): Promise<string> {
|
||||
const chain = nodeData.instance as SqlDatabaseChain
|
||||
const res = await chain.run(input)
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { nodeClass: SqlDatabaseChain_Chains }
|
||||
|
|
|
|||
Loading…
Reference in New Issue