// Flowise/packages/components/nodes/chains/GraphCypherQAChain/GraphCypherQAChain.ts
//
// 259 lines
// 10 KiB
// TypeScript

import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface'
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher'
import { getBaseClasses } from '../../../src/utils'
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
/**
 * Flowise node that builds a LangChain `GraphCypherQAChain`: a language model
 * generates a Cypher statement for the connected Neo4j graph and (unless
 * "Return Direct" is enabled) a second LLM step turns the query result into an answer.
 *
 * The node can either hand the configured chain to a downstream node or run it
 * immediately and emit the prediction, depending on the selected output.
 */
class GraphCypherQA_Chain implements INode {
    label: string
    name: string
    version: number
    type: string
    icon: string
    category: string
    description: string
    baseClasses: string[]
    inputs: INodeParams[]
    sessionId?: string
    outputs: INodeOutputsValue[]

    /**
     * Declares the node metadata, its input parameters and its two outputs.
     *
     * @param fields Optional settings; only `sessionId` is read and stored on the instance.
     */
    constructor(fields?: { sessionId?: string }) {
        this.label = 'Graph Cypher QA Chain'
        this.name = 'graphCypherQAChain'
        this.version = 1.1
        this.type = 'GraphCypherQAChain'
        this.icon = 'graphqa.svg'
        this.category = 'Chains'
        this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements'
        this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)]
        this.sessionId = fields?.sessionId
        this.inputs = [
            {
                label: 'Language Model',
                name: 'model',
                type: 'BaseLanguageModel',
                description: 'Model for generating Cypher queries and answers.'
            },
            {
                label: 'Neo4j Graph',
                name: 'graph',
                type: 'Neo4j'
            },
            {
                label: 'Cypher Generation Prompt',
                name: 'cypherPrompt',
                optional: true,
                type: 'BasePromptTemplate',
                description:
                    'Prompt template for generating Cypher queries. Must include {schema} and {question} variables. If not provided, default prompt will be used.'
            },
            {
                label: 'Cypher Generation Model',
                name: 'cypherModel',
                optional: true,
                type: 'BaseLanguageModel',
                description: 'Model for generating Cypher queries. If not provided, the main model will be used.'
            },
            {
                label: 'QA Prompt',
                name: 'qaPrompt',
                optional: true,
                type: 'BasePromptTemplate',
                description:
                    'Prompt template for generating answers. Must include {context} and {question} variables. If not provided, default prompt will be used.'
            },
            {
                label: 'QA Model',
                name: 'qaModel',
                optional: true,
                type: 'BaseLanguageModel',
                description: 'Model for generating answers. If not provided, the main model will be used.'
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            },
            {
                label: 'Return Direct',
                name: 'returnDirect',
                type: 'boolean',
                default: false,
                optional: true,
                description: 'If true, return the raw query results instead of using the QA chain'
            }
        ]
        this.outputs = [
            {
                label: 'Graph Cypher QA Chain',
                name: 'graphCypherQAChain',
                baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)]
            },
            {
                label: 'Output Prediction',
                name: 'outputPrediction',
                baseClasses: ['string', 'json']
            }
        ]
    }

    /**
     * Extracts a printable message from a caught value, which is not guaranteed
     * to be an `Error` instance.
     */
    private static errorMessage(err: unknown): string {
        if (err instanceof Error) return err.message
        if (typeof err === 'object' && err !== null && 'message' in err) {
            return String((err as { message: unknown }).message)
        }
        return String(err)
    }

    /**
     * Builds the chain from the node inputs.
     *
     * @param nodeData Node configuration; reads `inputs` (models, graph, prompts, returnDirect)
     *                 and `outputs.output` (which output anchor is in use).
     * @param input    User question, only used when the `outputPrediction` output is selected.
     * @param options  Runtime options forwarded to {@link run}.
     * @returns The configured chain, or the result of {@link run} when `outputPrediction` is selected.
     * @throws Error when the language model or graph is missing, when a `PromptTemplate`
     *         Cypher prompt is given without a QA prompt, or when the Cypher prompt lacks
     *         the `schema` / `question` input variables.
     */
    async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
        const model = nodeData.inputs?.model
        const cypherModel = nodeData.inputs?.cypherModel
        const qaModel = nodeData.inputs?.qaModel
        const graph = nodeData.inputs?.graph
        const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined
        const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined
        const returnDirect = nodeData.inputs?.returnDirect as boolean
        const output = nodeData.outputs?.output as string

        if (!model) {
            throw new Error('Language Model is required')
        }
        // Fail early with a clear message instead of an opaque error once the chain
        // first touches the graph.
        if (!graph) {
            throw new Error('Neo4j Graph is required')
        }

        // Rebuild known prompt classes from their public fields; anything else is passed through as-is.
        let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined
        let qaPromptTemplate: BasePromptTemplate | undefined

        if (cypherPrompt) {
            if (cypherPrompt instanceof PromptTemplate) {
                cypherPromptTemplate = new PromptTemplate({
                    template: cypherPrompt.template as string,
                    inputVariables: cypherPrompt.inputVariables
                })
                if (!qaPrompt) {
                    throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template')
                }
            } else if (cypherPrompt instanceof FewShotPromptTemplate) {
                const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate
                cypherPromptTemplate = new FewShotPromptTemplate({
                    examples: cypherPrompt.examples,
                    examplePrompt: examplePrompt,
                    inputVariables: cypherPrompt.inputVariables,
                    prefix: cypherPrompt.prefix,
                    suffix: cypherPrompt.suffix,
                    exampleSeparator: cypherPrompt.exampleSeparator,
                    templateFormat: cypherPrompt.templateFormat
                })
            } else {
                cypherPromptTemplate = cypherPrompt as PromptTemplate
            }
        }

        if (qaPrompt instanceof PromptTemplate) {
            qaPromptTemplate = new PromptTemplate({
                template: qaPrompt.template as string,
                inputVariables: qaPrompt.inputVariables
            })
        } else if (qaPrompt) {
            // Mirror the Cypher prompt handling: a prompt of another class is forwarded
            // unchanged rather than being silently replaced by the default prompt.
            qaPromptTemplate = qaPrompt
        }

        // The Cypher prompt must expose both variables named in the error below.
        if (cypherPromptTemplate) {
            const declaredVariables = cypherPromptTemplate.inputVariables
            if (!declaredVariables.includes('schema') || !declaredVariables.includes('question')) {
                throw new Error('Cypher Generation Prompt must include {schema} and {question} variables')
            }
        }

        // Resolve one model per stage up front so that a dedicated Cypher/QA model is
        // honoured even when the default prompt is used, falling back to the main model.
        // NOTE(review): `llm` is intentionally omitted. LangChain's `fromLLM` is understood to
        // reject `llm`, `cypherLLM` and `qaLLM` all being set at once, and the two explicit
        // models already cover both stages — confirm against the installed @langchain/community.
        const fromLLMInput: FromLLMInput = {
            graph,
            returnDirect,
            cypherLLM: cypherModel ?? model,
            qaLLM: qaModel ?? model
        }
        if (cypherPromptTemplate) {
            fromLLMInput.cypherPrompt = cypherPromptTemplate
        }
        if (qaPromptTemplate) {
            fromLLMInput.qaPrompt = qaPromptTemplate
        }

        const chain = GraphCypherQAChain.fromLLM(fromLLMInput)

        if (output === 'outputPrediction') {
            nodeData.instance = chain
            return await this.run(nodeData, input, options)
        }
        // Chain output selected (or no output specified): hand the chain itself downstream.
        return chain
    }

    /**
     * Runs the chain stored on `nodeData.instance` for a single question.
     *
     * @param nodeData Node configuration; reads `instance`, `inputs.inputModeration`
     *                 and `inputs.returnDirect`.
     * @param input    User question; may be rewritten by the moderation check.
     * @param options  Reads `shouldStreamResponse`, `sseStreamer`, `chatId` and `logger`.
     * @returns The chain result passed through `formatResponse`. Failures are not thrown:
     *          a moderation failure returns its message, any other failure returns
     *          `Error: <message>`.
     */
    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
        const chain = nodeData.instance as GraphCypherQAChain
        const moderations = nodeData.inputs?.inputModeration as Moderation[]
        const returnDirect = nodeData.inputs?.returnDirect as boolean
        const shouldStreamResponse = options.shouldStreamResponse
        const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
        const chatId = options.chatId

        // Input moderation, when configured, may replace the input or reject it outright.
        if (moderations && moderations.length > 0) {
            try {
                input = await checkInputs(moderations, input)
            } catch (e) {
                const message = GraphCypherQA_Chain.errorMessage(e)
                // Short pause kept from the original implementation before replying.
                await new Promise((resolve) => setTimeout(resolve, 500))
                if (shouldStreamResponse) {
                    streamResponse(sseStreamer, chatId, message)
                }
                return formatResponse(message)
            }
        }

        const obj = {
            query: input
        }

        const loggerHandler = new ConsoleCallbackHandler(options.logger)
        const callbackHandlers = await additionalCallbacks(nodeData, options)
        const callbacks = [loggerHandler, ...callbackHandlers]
        if (process.env.DEBUG === 'true') {
            callbacks.push(new LCConsoleCallbackHandler())
        }

        try {
            const streamDirectResult = Boolean(shouldStreamResponse && returnDirect)
            if (shouldStreamResponse && !returnDirect) {
                // Token streaming is delegated to the handler.
                // NOTE(review): the trailing `2` is forwarded unchanged; presumably it makes the
                // handler skip the Cypher-generation step's tokens — confirm in CustomChainHandler.
                callbacks.push(new CustomChainHandler(sseStreamer, chatId, 2))
            }

            const response = await chain.invoke(obj, { callbacks })

            if (streamDirectResult) {
                // With returnDirect there is no QA step to stream, so push the final result in one go.
                let result = response?.result
                if (typeof result === 'object' && result !== null) {
                    // Render structured results as a complete (closed) fenced JSON block.
                    result = '```json\n' + JSON.stringify(result, null, 2) + '\n```'
                }
                if (result && typeof result === 'string') {
                    streamResponse(sseStreamer, chatId, result)
                }
            }

            return formatResponse(response?.result)
        } catch (error) {
            const message = GraphCypherQA_Chain.errorMessage(error)
            console.error('Error in GraphCypherQAChain:', error)
            if (shouldStreamResponse) {
                streamResponse(sseStreamer, chatId, message)
            }
            return formatResponse(`Error: ${message}`)
        }
    }
}
// CommonJS export: exposes the node class under the `nodeClass` key.
// NOTE(review): presumably this is the key the Flowise node loader looks up — confirm against the loader.
module.exports = { nodeClass: GraphCypherQA_Chain }