253 lines
9.8 KiB
TypeScript
253 lines
9.8 KiB
TypeScript
import { DataSourceOptions } from 'typeorm/data-source'
|
|
import { DataSource } from 'typeorm'
|
|
import { BaseLanguageModel } from '@langchain/core/language_models/base'
|
|
import { PromptTemplate, PromptTemplateInput } from '@langchain/core/prompts'
|
|
import { SqlDatabaseChain, SqlDatabaseChainInput, DEFAULT_SQL_DATABASE_PROMPT } from 'langchain/chains/sql_db'
|
|
import { SqlDatabase } from 'langchain/sql_db'
|
|
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
|
|
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
|
|
import { getBaseClasses, getInputVariables } from '../../../src/utils'
|
|
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
|
|
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
|
|
|
|
type DatabaseType = 'sqlite' | 'postgres' | 'mssql' | 'mysql'
|
|
|
|
class SqlDatabaseChain_Chains implements INode {
|
|
label: string
|
|
name: string
|
|
version: number
|
|
type: string
|
|
icon: string
|
|
category: string
|
|
baseClasses: string[]
|
|
description: string
|
|
inputs: INodeParams[]
|
|
|
|
constructor() {
|
|
this.label = 'Sql Database Chain'
|
|
this.name = 'sqlDatabaseChain'
|
|
this.version = 5.0
|
|
this.type = 'SqlDatabaseChain'
|
|
this.icon = 'sqlchain.svg'
|
|
this.category = 'Chains'
|
|
this.description = 'Answer questions over a SQL database'
|
|
this.baseClasses = [this.type, ...getBaseClasses(SqlDatabaseChain)]
|
|
this.inputs = [
|
|
{
|
|
label: 'Language Model',
|
|
name: 'model',
|
|
type: 'BaseLanguageModel'
|
|
},
|
|
{
|
|
label: 'Database',
|
|
name: 'database',
|
|
type: 'options',
|
|
options: [
|
|
{
|
|
label: 'SQLite',
|
|
name: 'sqlite'
|
|
},
|
|
{
|
|
label: 'PostgreSQL',
|
|
name: 'postgres'
|
|
},
|
|
{
|
|
label: 'MSSQL',
|
|
name: 'mssql'
|
|
},
|
|
{
|
|
label: 'MySQL',
|
|
name: 'mysql'
|
|
}
|
|
],
|
|
default: 'sqlite'
|
|
},
|
|
{
|
|
label: 'Connection string or file path (sqlite only)',
|
|
name: 'url',
|
|
type: 'string',
|
|
placeholder: '1270.0.0.1:5432/chinook'
|
|
},
|
|
{
|
|
label: 'Include Tables',
|
|
name: 'includesTables',
|
|
type: 'string',
|
|
description: 'Tables to include for queries, separated by comma. Can only use Include Tables or Ignore Tables',
|
|
placeholder: 'table1, table2',
|
|
additionalParams: true,
|
|
optional: true
|
|
},
|
|
{
|
|
label: 'Ignore Tables',
|
|
name: 'ignoreTables',
|
|
type: 'string',
|
|
description: 'Tables to ignore for queries, separated by comma. Can only use Ignore Tables or Include Tables',
|
|
placeholder: 'table1, table2',
|
|
additionalParams: true,
|
|
optional: true
|
|
},
|
|
{
|
|
label: "Sample table's rows info",
|
|
name: 'sampleRowsInTableInfo',
|
|
type: 'number',
|
|
description: 'Number of sample row for tables to load for info.',
|
|
placeholder: '3',
|
|
additionalParams: true,
|
|
optional: true
|
|
},
|
|
{
|
|
label: 'Top Keys',
|
|
name: 'topK',
|
|
type: 'number',
|
|
description:
|
|
'If you are querying for several rows of a table you can select the maximum number of results you want to get by using the "top_k" parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily.',
|
|
placeholder: '10',
|
|
additionalParams: true,
|
|
optional: true
|
|
},
|
|
{
|
|
label: 'Custom Prompt',
|
|
name: 'customPrompt',
|
|
type: 'string',
|
|
description:
|
|
'You can provide custom prompt to the chain. This will override the existing default prompt used. See <a target="_blank" href="https://python.langchain.com/docs/integrations/tools/sqlite#customize-prompt">guide</a>',
|
|
warning:
|
|
'Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above',
|
|
rows: 4,
|
|
placeholder: DEFAULT_SQL_DATABASE_PROMPT.template + DEFAULT_SQL_DATABASE_PROMPT.templateFormat,
|
|
additionalParams: true,
|
|
optional: true
|
|
},
|
|
{
|
|
label: 'Input Moderation',
|
|
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
|
|
name: 'inputModeration',
|
|
type: 'Moderation',
|
|
optional: true,
|
|
list: true
|
|
}
|
|
]
|
|
}
|
|
|
|
async init(nodeData: INodeData): Promise<any> {
|
|
const databaseType = nodeData.inputs?.database as DatabaseType
|
|
const model = nodeData.inputs?.model as BaseLanguageModel
|
|
const url = nodeData.inputs?.url as string
|
|
const includesTables = nodeData.inputs?.includesTables
|
|
const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',')
|
|
const ignoreTables = nodeData.inputs?.ignoreTables
|
|
const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',')
|
|
const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number
|
|
const topK = nodeData.inputs?.topK as number
|
|
const customPrompt = nodeData.inputs?.customPrompt as string
|
|
|
|
const chain = await getSQLDBChain(
|
|
databaseType,
|
|
url,
|
|
model,
|
|
splittedIncludesTables,
|
|
splittedIgnoreTables,
|
|
sampleRowsInTableInfo,
|
|
topK,
|
|
customPrompt
|
|
)
|
|
return chain
|
|
}
|
|
|
|
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
|
|
const databaseType = nodeData.inputs?.database as DatabaseType
|
|
const model = nodeData.inputs?.model as BaseLanguageModel
|
|
const url = nodeData.inputs?.url as string
|
|
const includesTables = nodeData.inputs?.includesTables
|
|
const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',')
|
|
const ignoreTables = nodeData.inputs?.ignoreTables
|
|
const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',')
|
|
const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number
|
|
const topK = nodeData.inputs?.topK as number
|
|
const customPrompt = nodeData.inputs?.customPrompt as string
|
|
const moderations = nodeData.inputs?.inputModeration as Moderation[]
|
|
if (moderations && moderations.length > 0) {
|
|
try {
|
|
// Use the output of the moderation chain as input for the Sql Database Chain
|
|
input = await checkInputs(moderations, input)
|
|
} catch (e) {
|
|
await new Promise((resolve) => setTimeout(resolve, 500))
|
|
streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId)
|
|
return formatResponse(e.message)
|
|
}
|
|
}
|
|
|
|
const chain = await getSQLDBChain(
|
|
databaseType,
|
|
url,
|
|
model,
|
|
splittedIncludesTables,
|
|
splittedIgnoreTables,
|
|
sampleRowsInTableInfo,
|
|
topK,
|
|
customPrompt
|
|
)
|
|
const loggerHandler = new ConsoleCallbackHandler(options.logger)
|
|
const callbacks = await additionalCallbacks(nodeData, options)
|
|
|
|
if (options.socketIO && options.socketIOClientId) {
|
|
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2)
|
|
const res = await chain.run(input, [loggerHandler, handler, ...callbacks])
|
|
return res
|
|
} else {
|
|
const res = await chain.run(input, [loggerHandler, ...callbacks])
|
|
return res
|
|
}
|
|
}
|
|
}
|
|
|
|
const getSQLDBChain = async (
|
|
databaseType: DatabaseType,
|
|
url: string,
|
|
llm: BaseLanguageModel,
|
|
includesTables?: string[],
|
|
ignoreTables?: string[],
|
|
sampleRowsInTableInfo?: number,
|
|
topK?: number,
|
|
customPrompt?: string
|
|
) => {
|
|
const datasource = new DataSource(
|
|
databaseType === 'sqlite'
|
|
? {
|
|
type: databaseType,
|
|
database: url
|
|
}
|
|
: ({
|
|
type: databaseType,
|
|
url: url
|
|
} as DataSourceOptions)
|
|
)
|
|
|
|
const db = await SqlDatabase.fromDataSourceParams({
|
|
appDataSource: datasource,
|
|
includesTables: includesTables,
|
|
ignoreTables: ignoreTables,
|
|
sampleRowsInTableInfo: sampleRowsInTableInfo
|
|
})
|
|
|
|
const obj: SqlDatabaseChainInput = {
|
|
llm,
|
|
database: db,
|
|
verbose: process.env.DEBUG === 'true' ? true : false,
|
|
topK: topK
|
|
}
|
|
|
|
if (customPrompt) {
|
|
const options: PromptTemplateInput = {
|
|
template: customPrompt,
|
|
inputVariables: getInputVariables(customPrompt)
|
|
}
|
|
obj.prompt = new PromptTemplate(options)
|
|
}
|
|
|
|
const chain = new SqlDatabaseChain(obj)
|
|
return chain
|
|
}
|
|
|
|
module.exports = { nodeClass: SqlDatabaseChain_Chains }
|