// Source: Flowise/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts (TypeScript; original listing reported 262 lines, 10 KiB)

import { DataSourceOptions } from 'typeorm/data-source'
import { DataSource } from 'typeorm'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { PromptTemplate, PromptTemplateInput } from '@langchain/core/prompts'
import { SqlDatabaseChain, SqlDatabaseChainInput, DEFAULT_SQL_DATABASE_PROMPT } from 'langchain/chains/sql_db'
import { SqlDatabase } from 'langchain/sql_db'
import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { getBaseClasses, getInputVariables, transformBracesWithColon } from '../../../src/utils'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
/**
 * Database engines offered by the node's "Database" input.
 * The selected value is passed straight through as the TypeORM `DataSource` `type`.
 */
type DatabaseType =
    | 'sqlite'
    | 'postgres'
    | 'mssql'
    | 'mysql'
/**
 * Splits a comma-separated list of table names taken from a node input.
 *
 * Each entry is trimmed and empty entries are dropped, so the format shown in the
 * UI placeholder ("table1, table2") yields names that match the real table names.
 *
 * @param raw - Raw input value; anything other than a string is treated as "not set".
 * @returns The table names, or `undefined` when none were given (meaning "no filter").
 */
const parseTableNames = (raw: unknown): string[] | undefined => {
    if (typeof raw !== 'string') return undefined
    const names = raw
        .split(',')
        .map((name) => name.trim())
        .filter((name) => name.length > 0)
    return names.length > 0 ? names : undefined
}

/**
 * Reads this node's inputs and builds a fresh SqlDatabaseChain.
 * Every call creates a new TypeORM DataSource, reachable afterwards via `chain.database.appDataSource`.
 */
const buildChainFromNodeData = (nodeData: INodeData) =>
    getSQLDBChain(
        nodeData.inputs?.database as DatabaseType,
        nodeData.inputs?.url as string,
        nodeData.inputs?.model as BaseLanguageModel,
        parseTableNames(nodeData.inputs?.includesTables),
        parseTableNames(nodeData.inputs?.ignoreTables),
        nodeData.inputs?.sampleRowsInTableInfo as number,
        nodeData.inputs?.topK as number,
        nodeData.inputs?.customPrompt as string
    )

/** Flowise node that answers natural-language questions over a SQL database via LangChain's SqlDatabaseChain. */
class SqlDatabaseChain_Chains implements INode {
    label: string
    name: string
    version: number
    type: string
    icon: string
    category: string
    baseClasses: string[]
    description: string
    inputs: INodeParams[]

    constructor() {
        this.label = 'Sql Database Chain'
        this.name = 'sqlDatabaseChain'
        this.version = 5.0
        this.type = 'SqlDatabaseChain'
        this.icon = 'sqlchain.svg'
        this.category = 'Chains'
        this.description = 'Answer questions over a SQL database'
        this.baseClasses = [this.type, ...getBaseClasses(SqlDatabaseChain)]
        this.inputs = [
            {
                label: 'Language Model',
                name: 'model',
                type: 'BaseLanguageModel'
            },
            {
                label: 'Database',
                name: 'database',
                type: 'options',
                options: [
                    {
                        label: 'SQLite',
                        name: 'sqlite'
                    },
                    {
                        label: 'PostgreSQL',
                        name: 'postgres'
                    },
                    {
                        label: 'MSSQL',
                        name: 'mssql'
                    },
                    {
                        label: 'MySQL',
                        name: 'mysql'
                    }
                ],
                default: 'sqlite'
            },
            {
                label: 'Connection string or file path (sqlite only)',
                name: 'url',
                type: 'string',
                placeholder: '127.0.0.1:5432/chinook'
            },
            {
                label: 'Include Tables',
                name: 'includesTables',
                type: 'string',
                description: 'Tables to include for queries, separated by comma. Can only use Include Tables or Ignore Tables',
                placeholder: 'table1, table2',
                additionalParams: true,
                optional: true
            },
            {
                label: 'Ignore Tables',
                name: 'ignoreTables',
                type: 'string',
                description: 'Tables to ignore for queries, separated by comma. Can only use Ignore Tables or Include Tables',
                placeholder: 'table1, table2',
                additionalParams: true,
                optional: true
            },
            {
                label: "Sample table's rows info",
                name: 'sampleRowsInTableInfo',
                type: 'number',
                description: 'Number of sample row for tables to load for info.',
                placeholder: '3',
                additionalParams: true,
                optional: true
            },
            {
                label: 'Top Keys',
                name: 'topK',
                type: 'number',
                description:
                    'If you are querying for several rows of a table you can select the maximum number of results you want to get by using the "top_k" parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily.',
                placeholder: '10',
                additionalParams: true,
                optional: true
            },
            {
                label: 'Custom Prompt',
                name: 'customPrompt',
                type: 'string',
                description:
                    'You can provide custom prompt to the chain. This will override the existing default prompt used. See <a target="_blank" href="https://python.langchain.com/docs/integrations/tools/sqlite#customize-prompt">guide</a>',
                warning:
                    'Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above',
                rows: 4,
                // Show only the default template; appending `templateFormat` left a stray "f-string" at the end.
                placeholder: String(DEFAULT_SQL_DATABASE_PROMPT.template),
                additionalParams: true,
                optional: true
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            }
        ]
    }

    /**
     * Builds the chain from this node's inputs and hands it to the caller.
     * NOTE(review): nothing here closes the chain's DataSource; presumably the flow engine owns the
     * returned chain's lifetime — confirm.
     */
    async init(nodeData: INodeData): Promise<any> {
        return buildChainFromNodeData(nodeData)
    }

    /**
     * Answers `input` against the configured database.
     *
     * Input moderation runs first; if it rejects the input, the rejection message is streamed (when
     * streaming) and returned instead of querying the database. Otherwise a throwaway chain is built,
     * run with the logger and any additional callbacks, and its DataSource is released afterwards.
     */
    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
        const moderations = nodeData.inputs?.inputModeration as Moderation[]
        const shouldStreamResponse = options.shouldStreamResponse
        const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
        const chatId = options.chatId

        if (moderations && moderations.length > 0) {
            try {
                // Use the output of the moderation chain as input for the Sql Database Chain
                input = await checkInputs(moderations, input)
            } catch (e) {
                await new Promise((resolve) => setTimeout(resolve, 500))
                if (shouldStreamResponse) {
                    streamResponse(sseStreamer, chatId, e.message)
                }
                return formatResponse(e.message)
            }
        }

        const chain = await buildChainFromNodeData(nodeData)
        try {
            const loggerHandler = new ConsoleCallbackHandler(options.logger, options?.orgId)
            const callbacks = await additionalCallbacks(nodeData, options)
            if (shouldStreamResponse) {
                const handler = new CustomChainHandler(sseStreamer, chatId, 2)
                return await chain.run(input, [loggerHandler, handler, ...callbacks])
            }
            return await chain.run(input, [loggerHandler, ...callbacks])
        } finally {
            // This chain is discarded after the call, so release its DataSource rather than leaking one per
            // request. Cleanup is best-effort: a failed close must not replace the answer or the real error.
            const { appDataSource } = chain.database
            if (appDataSource.isInitialized) {
                await appDataSource.destroy().catch(() => undefined)
            }
        }
    }
}
/**
 * Creates a TypeORM DataSource for the configured database and wraps it in a LangChain SqlDatabaseChain.
 *
 * @param databaseType - TypeORM driver type. For `sqlite`, `url` is used as the database file path;
 *   for every other type it is passed as the connection URL.
 * @param url - Connection string, or file path for SQLite.
 * @param llm - Language model handed to the chain.
 * @param includesTables - Optional table allow-list, forwarded to `SqlDatabase.fromDataSourceParams`.
 * @param ignoreTables - Optional table deny-list, forwarded to `SqlDatabase.fromDataSourceParams`.
 * @param sampleRowsInTableInfo - Optional sample-row count, forwarded to `SqlDatabase.fromDataSourceParams`.
 * @param topK - Optional `topK` passed to the chain.
 * @param customPrompt - Optional prompt override. Braces are rewritten with `transformBracesWithColon`,
 *   and the template's variables are taken from the rewritten text.
 * @returns A new chain that holds a new DataSource (`chain.database.appDataSource`). Callers that
 *   discard the chain should destroy it.
 * @throws Rethrows any setup failure, after destroying the DataSource if it had been initialized.
 */
const getSQLDBChain = async (
    databaseType: DatabaseType,
    url: string,
    llm: BaseLanguageModel,
    includesTables?: string[],
    ignoreTables?: string[],
    sampleRowsInTableInfo?: number,
    topK?: number,
    customPrompt?: string
) => {
    const datasource = new DataSource(
        databaseType === 'sqlite'
            ? {
                  type: databaseType,
                  database: url
              }
            : ({
                  type: databaseType,
                  url: url
              } as DataSourceOptions)
    )
    try {
        const db = await SqlDatabase.fromDataSourceParams({
            appDataSource: datasource,
            includesTables: includesTables,
            ignoreTables: ignoreTables,
            sampleRowsInTableInfo: sampleRowsInTableInfo
        })
        const obj: SqlDatabaseChainInput = {
            llm,
            database: db,
            verbose: process.env.DEBUG === 'true',
            topK: topK
        }
        if (customPrompt) {
            const template = transformBracesWithColon(customPrompt)
            const promptOptions: PromptTemplateInput = {
                template,
                inputVariables: getInputVariables(template)
            }
            obj.prompt = new PromptTemplate(promptOptions)
        }
        return new SqlDatabaseChain(obj)
    } catch (error) {
        // Setup can fail after the connection is opened (e.g. unknown include/ignore tables), and
        // nobody else holds this DataSource. Close it best-effort so the original error is what surfaces.
        if (datasource.isInitialized) {
            await datasource.destroy().catch(() => undefined)
        }
        throw error
    }
}
module.exports = { nodeClass: SqlDatabaseChain_Chains }