import { DataSourceOptions } from 'typeorm/data-source'
import { DataSource } from 'typeorm'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { PromptTemplate, PromptTemplateInput } from '@langchain/core/prompts'
import { SqlDatabaseChain, SqlDatabaseChainInput, DEFAULT_SQL_DATABASE_PROMPT } from 'langchain/chains/sql_db'
import { SqlDatabase } from 'langchain/sql_db'
import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { getBaseClasses, getInputVariables, transformBracesWithColon } from '../../../src/utils'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'

type DatabaseType = 'sqlite' | 'postgres' | 'mssql' | 'mysql'

/**
 * Parse a comma-separated table list entered in the node UI.
 *
 * Each entry is trimmed (the UI placeholder is `table1, table2`, so names typically
 * carry a leading space) and empty entries are dropped. Returns `undefined` when the
 * input is missing, not a string, or contains no names, so the option is not applied.
 */
const parseTableList = (raw: unknown): string[] | undefined => {
    if (typeof raw !== 'string') return undefined
    const tables = raw
        .split(',')
        .map((table) => table.trim())
        .filter((table) => table.length > 0)
    return tables.length > 0 ? tables : undefined
}

/**
 * Best-effort shutdown of a TypeORM data source.
 *
 * Does nothing if the data source was never initialized. Errors raised while closing
 * are handed to `onError` (if given) instead of being thrown, so they never mask the
 * result or the error of the operation that used the connection.
 */
const releaseDataSource = async (dataSource: DataSource, onError?: (error: unknown) => void): Promise<void> => {
    if (!dataSource.isInitialized) return
    try {
        await dataSource.destroy()
    } catch (error: unknown) {
        onError?.(error)
    }
}

/**
 * Flowise node that answers natural-language questions over a SQL database
 * using LangChain's `SqlDatabaseChain`.
 */
class SqlDatabaseChain_Chains implements INode {
    label: string
    name: string
    version: number
    type: string
    icon: string
    category: string
    baseClasses: string[]
    description: string
    inputs: INodeParams[]

    constructor() {
        this.label = 'Sql Database Chain'
        this.name = 'sqlDatabaseChain'
        this.version = 5.0
        this.type = 'SqlDatabaseChain'
        this.icon = 'sqlchain.svg'
        this.category = 'Chains'
        this.description = 'Answer questions over a SQL database'
        this.baseClasses = [this.type, ...getBaseClasses(SqlDatabaseChain)]
        this.inputs = [
            {
                label: 'Language Model',
                name: 'model',
                type: 'BaseLanguageModel'
            },
            {
                label: 'Database',
                name: 'database',
                type: 'options',
                options: [
                    { label: 'SQLite', name: 'sqlite' },
                    { label: 'PostgreSQL', name: 'postgres' },
                    { label: 'MSSQL', name: 'mssql' },
                    { label: 'MySQL', name: 'mysql' }
                ],
                default: 'sqlite'
            },
            {
                label: 'Connection string or file path (sqlite only)',
                name: 'url',
                type: 'string',
                placeholder: '127.0.0.1:5432/chinook'
            },
            {
                label: 'Include Tables',
                name: 'includesTables',
                type: 'string',
                description: 'Tables to include for queries, separated by comma. Can only use Include Tables or Ignore Tables',
                placeholder: 'table1, table2',
                additionalParams: true,
                optional: true
            },
            {
                label: 'Ignore Tables',
                name: 'ignoreTables',
                type: 'string',
                description: 'Tables to ignore for queries, separated by comma. Can only use Ignore Tables or Include Tables',
                placeholder: 'table1, table2',
                additionalParams: true,
                optional: true
            },
            {
                label: "Sample table's rows info",
                name: 'sampleRowsInTableInfo',
                type: 'number',
                description: 'Number of sample row for tables to load for info.',
                placeholder: '3',
                additionalParams: true,
                optional: true
            },
            {
                label: 'Top Keys',
                name: 'topK',
                type: 'number',
                description:
                    'If you are querying for several rows of a table you can select the maximum number of results you want to get by using the "top_k" parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily.',
                placeholder: '10',
                additionalParams: true,
                optional: true
            },
            {
                label: 'Custom Prompt',
                name: 'customPrompt',
                type: 'string',
                description:
                    'You can provide custom prompt to the chain. This will override the existing default prompt used. See guide',
                warning:
                    'Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above',
                rows: 4,
                placeholder: DEFAULT_SQL_DATABASE_PROMPT.template + DEFAULT_SQL_DATABASE_PROMPT.templateFormat,
                additionalParams: true,
                optional: true
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            }
        ]
    }

    /**
     * Build the chain from the node's inputs and return it.
     *
     * This node does not close the chain's data source here; the returned chain
     * keeps whatever connection `getSQLDBChain` opened.
     */
    async init(nodeData: INodeData): Promise<SqlDatabaseChain> {
        return getSQLDBChain_fromNode(nodeData)
    }

    /**
     * Run a single question through a freshly built chain.
     *
     * Optional input moderation runs first; if it rejects the input, the moderation
     * message is returned (and streamed when streaming is enabled) without touching the
     * database. The chain built for this call is released afterwards so repeated runs
     * do not accumulate open connections.
     */
    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
        const moderations = nodeData.inputs?.inputModeration as Moderation[]
        const shouldStreamResponse = options.shouldStreamResponse
        const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
        const chatId = options.chatId

        if (moderations && moderations.length > 0) {
            try {
                // Use the output of the moderation chain as input for the Sql Database Chain
                input = await checkInputs(moderations, input)
            } catch (e: unknown) {
                // NOTE(review): fixed 500 ms pause before replying with the moderation message;
                // the reason is not visible here (presumably lets the client/stream settle).
                await new Promise((resolve) => setTimeout(resolve, 500))
                const message = e instanceof Error ? e.message : String(e)
                if (shouldStreamResponse) {
                    streamResponse(sseStreamer, chatId, message)
                }
                return formatResponse(message)
            }
        }

        const chain = await getSQLDBChain_fromNode(nodeData)
        try {
            const loggerHandler = new ConsoleCallbackHandler(options.logger, options?.orgId)
            const callbacks = await additionalCallbacks(nodeData, options)

            // `return await` (not plain `return`) so `finally` only runs once the query has finished.
            if (shouldStreamResponse) {
                const handler = new CustomChainHandler(sseStreamer, chatId, 2)
                return await chain.run(input, [loggerHandler, handler, ...callbacks])
            }
            return await chain.run(input, [loggerHandler, ...callbacks])
        } finally {
            // The chain was built only for this call; close its connection so it does not leak.
            await releaseDataSource(chain.database.appDataSource, (error: unknown) => {
                options.logger?.error?.(
                    `[sqlDatabaseChain] Failed to close database connection: ${error instanceof Error ? error.message : String(error)}`
                )
            })
        }
    }
}

/**
 * Create a `SqlDatabaseChain` connected to the given database.
 *
 * @param databaseType - TypeORM driver to use.
 * @param url - File path for SQLite (passed as `database`); connection URL for the other drivers (passed as `url`).
 * @param llm - Language model that writes the SQL and the final answer.
 * @param includesTables - Only these tables are exposed to the model, if given.
 * @param ignoreTables - These tables are hidden from the model, if given.
 * @param sampleRowsInTableInfo - Number of sample rows per table included in the table info.
 * @param topK - Passed to the chain as `topK` (maximum number of rows the model is asked to fetch).
 * @param customPrompt - Optional prompt template that replaces the chain's default prompt.
 * @returns The configured chain.
 * @throws Rethrows any error from `SqlDatabase.fromDataSourceParams`, after releasing the data source.
 */
const getSQLDBChain = async (
    databaseType: DatabaseType,
    url: string,
    llm: BaseLanguageModel,
    includesTables?: string[],
    ignoreTables?: string[],
    sampleRowsInTableInfo?: number,
    topK?: number,
    customPrompt?: string
): Promise<SqlDatabaseChain> => {
    const datasource = new DataSource(
        databaseType === 'sqlite' ? { type: databaseType, database: url } : ({ type: databaseType, url: url } as DataSourceOptions)
    )

    let db: SqlDatabase
    try {
        db = await SqlDatabase.fromDataSourceParams({
            appDataSource: datasource,
            includesTables: includesTables,
            ignoreTables: ignoreTables,
            sampleRowsInTableInfo: sampleRowsInTableInfo
        })
    } catch (error: unknown) {
        // If the data source was opened before the failure (e.g. unknown include table), close it.
        await releaseDataSource(datasource)
        throw error
    }

    const obj: SqlDatabaseChainInput = {
        llm,
        database: db,
        verbose: process.env.DEBUG === 'true',
        topK: topK
    }

    if (customPrompt) {
        // Pre-process the user template with the project's helper before extracting its input variables.
        const template = transformBracesWithColon(customPrompt)
        const promptOptions: PromptTemplateInput = {
            template,
            inputVariables: getInputVariables(template)
        }
        obj.prompt = new PromptTemplate(promptOptions)
    }

    return new SqlDatabaseChain(obj)
}

/** Read this node's UI inputs and build the chain; shared by `init` and `run`. */
const getSQLDBChain_fromNode = (nodeData: INodeData): Promise<SqlDatabaseChain> => {
    const inputs = nodeData.inputs
    return getSQLDBChain(
        inputs?.database as DatabaseType,
        inputs?.url as string,
        inputs?.model as BaseLanguageModel,
        parseTableList(inputs?.includesTables),
        parseTableList(inputs?.ignoreTables),
        inputs?.sampleRowsInTableInfo as number,
        inputs?.topK as number,
        inputs?.customPrompt as string
    )
}

module.exports = { nodeClass: SqlDatabaseChain_Chains }