259 lines
10 KiB
TypeScript
259 lines
10 KiB
TypeScript
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface'
|
|
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher'
|
|
import { getBaseClasses } from '../../../src/utils'
|
|
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts'
|
|
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
|
|
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
|
|
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
|
|
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
|
|
|
|
/**
 * Flowise node that builds a LangChain `GraphCypherQAChain`: a language model
 * generates a Cypher statement for the connected Neo4j graph, and the query
 * result is either answered by a QA model or returned directly.
 */
class GraphCypherQA_Chain implements INode {
    label: string
    name: string
    version: number
    type: string
    icon: string
    category: string
    description: string
    baseClasses: string[]
    inputs: INodeParams[]
    sessionId?: string
    outputs: INodeOutputsValue[]

    /**
     * Declares the node metadata, its configurable inputs and its two outputs
     * (the chain itself, or the prediction produced by running it).
     *
     * @param fields Optional settings; only `sessionId` is read.
     */
    constructor(fields?: { sessionId?: string }) {
        this.label = 'Graph Cypher QA Chain'
        this.name = 'graphCypherQAChain'
        this.version = 1.1
        this.type = 'GraphCypherQAChain'
        this.icon = 'graphqa.svg'
        this.category = 'Chains'
        this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements'
        this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)]
        this.sessionId = fields?.sessionId
        this.inputs = [
            {
                label: 'Language Model',
                name: 'model',
                type: 'BaseLanguageModel',
                description: 'Model for generating Cypher queries and answers.'
            },
            {
                label: 'Neo4j Graph',
                name: 'graph',
                type: 'Neo4j'
            },
            {
                label: 'Cypher Generation Prompt',
                name: 'cypherPrompt',
                optional: true,
                type: 'BasePromptTemplate',
                description:
                    'Prompt template for generating Cypher queries. Must include {schema} and {question} variables. If not provided, default prompt will be used.'
            },
            {
                label: 'Cypher Generation Model',
                name: 'cypherModel',
                optional: true,
                type: 'BaseLanguageModel',
                description: 'Model for generating Cypher queries. If not provided, the main model will be used.'
            },
            {
                label: 'QA Prompt',
                name: 'qaPrompt',
                optional: true,
                type: 'BasePromptTemplate',
                description:
                    'Prompt template for generating answers. Must include {context} and {question} variables. If not provided, default prompt will be used.'
            },
            {
                label: 'QA Model',
                name: 'qaModel',
                optional: true,
                type: 'BaseLanguageModel',
                description: 'Model for generating answers. If not provided, the main model will be used.'
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            },
            {
                label: 'Return Direct',
                name: 'returnDirect',
                type: 'boolean',
                default: false,
                optional: true,
                description: 'If true, return the raw query results instead of using the QA chain'
            }
        ]
        this.outputs = [
            {
                label: 'Graph Cypher QA Chain',
                name: 'graphCypherQAChain',
                // Separate array so the output entry does not alias `this.baseClasses`.
                baseClasses: [...this.baseClasses]
            },
            {
                label: 'Output Prediction',
                name: 'outputPrediction',
                baseClasses: ['string', 'json']
            }
        ]
    }

    /**
     * Builds the chain from the node inputs.
     *
     * @param nodeData Node configuration; reads `inputs.model`, `inputs.cypherModel`,
     *   `inputs.qaModel`, `inputs.graph`, `inputs.cypherPrompt`, `inputs.qaPrompt`,
     *   `inputs.returnDirect` and `outputs.output`.
     * @param input Question forwarded to `run` when the selected output is `outputPrediction`.
     * @param options Runtime options forwarded to `run`.
     * @returns The constructed chain, or the result of `run` when the selected
     *   output is `outputPrediction`.
     * @throws Error when no main model is connected, when a `PromptTemplate`
     *   Cypher prompt is supplied without a QA prompt, or when the Cypher prompt
     *   lacks the `schema` / `question` input variables.
     */
    async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
        const model = nodeData.inputs?.model
        const cypherModel = nodeData.inputs?.cypherModel
        const qaModel = nodeData.inputs?.qaModel
        const graph = nodeData.inputs?.graph
        const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined
        const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined
        const returnDirect = nodeData.inputs?.returnDirect as boolean
        const output = nodeData.outputs?.output as string

        if (!model) {
            throw new Error('Language Model is required')
        }

        // Rebuild recognised prompt types as fresh instances; anything else is passed through as-is.
        let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined
        let qaPromptTemplate: PromptTemplate | undefined

        if (cypherPrompt) {
            if (cypherPrompt instanceof PromptTemplate) {
                cypherPromptTemplate = new PromptTemplate({
                    template: cypherPrompt.template as string,
                    inputVariables: cypherPrompt.inputVariables
                })
                if (!qaPrompt) {
                    throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template')
                }
            } else if (cypherPrompt instanceof FewShotPromptTemplate) {
                cypherPromptTemplate = new FewShotPromptTemplate({
                    examples: cypherPrompt.examples,
                    examplePrompt: cypherPrompt.examplePrompt as PromptTemplate,
                    inputVariables: cypherPrompt.inputVariables,
                    prefix: cypherPrompt.prefix,
                    suffix: cypherPrompt.suffix,
                    exampleSeparator: cypherPrompt.exampleSeparator,
                    templateFormat: cypherPrompt.templateFormat
                })
            } else {
                cypherPromptTemplate = cypherPrompt as PromptTemplate
            }
        }

        // Only a PromptTemplate QA prompt is honoured; any other prompt object is
        // ignored here, so the chain falls back to its own default QA prompt.
        if (qaPrompt instanceof PromptTemplate) {
            qaPromptTemplate = new PromptTemplate({
                template: qaPrompt.template as string,
                inputVariables: qaPrompt.inputVariables
            })
        }

        // Validate required variables in the Cypher prompt. The pass-through branch
        // above is an unchecked cast, so guard against a missing variable list.
        if (cypherPromptTemplate) {
            const variables = cypherPromptTemplate.inputVariables ?? []
            if (!variables.includes('schema') || !variables.includes('question')) {
                throw new Error('Cypher Generation Prompt must include {schema} and {question} variables')
            }
        }

        const fromLLMInput: FromLLMInput = {
            graph,
            returnDirect
        }

        // Honour the dedicated models whenever they are connected, independent of
        // whether a custom prompt was supplied for the same role.
        if (cypherModel) {
            fromLLMInput.cypherLLM = cypherModel
        }
        if (qaModel) {
            fromLLMInput.qaLLM = qaModel
        }
        // The main model is only a fallback for a role without a dedicated model, so it
        // is omitted when both roles have one.
        // NOTE(review): GraphCypherQAChain.fromLLM is believed to reject `llm`, `cypherLLM`
        // and `qaLLM` all being set at once - confirm against the installed @langchain/community.
        if (!cypherModel || !qaModel) {
            fromLLMInput.llm = model
        }

        if (cypherPromptTemplate) {
            fromLLMInput.cypherPrompt = cypherPromptTemplate
        }
        if (qaPromptTemplate) {
            fromLLMInput.qaPrompt = qaPromptTemplate
        }

        const chain = GraphCypherQAChain.fromLLM(fromLLMInput)

        if (output === 'outputPrediction') {
            nodeData.instance = chain
            return await this.run(nodeData, input, options)
        }

        // Selected output is the chain itself (or no output was specified).
        return chain
    }

    /**
     * Runs the chain stored in `nodeData.instance` against `input`.
     *
     * Input moderation, when configured, runs first; a moderation failure is
     * reported as the response instead of being thrown. Chain errors are logged,
     * streamed when streaming is enabled, and returned as an `Error: ...` response.
     *
     * @param nodeData Node configuration; reads `instance`, `inputs.inputModeration`
     *   and `inputs.returnDirect`.
     * @param input The user question; may be replaced by the moderated text.
     * @param options Runtime options; reads `shouldStreamResponse`, `sseStreamer`,
     *   `chatId` and `logger`.
     * @returns The value produced by `formatResponse` for the chain result or the error message.
     */
    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
        const chain = nodeData.instance as GraphCypherQAChain
        const moderations = nodeData.inputs?.inputModeration as Moderation[]
        const returnDirect = nodeData.inputs?.returnDirect as boolean

        const shouldStreamResponse = options.shouldStreamResponse
        const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
        const chatId = options.chatId

        // Handle input moderation if configured
        if (moderations && moderations.length > 0) {
            try {
                input = await checkInputs(moderations, input)
            } catch (e) {
                const message = describeError(e)
                await new Promise((resolve) => setTimeout(resolve, 500))
                if (shouldStreamResponse) {
                    streamResponse(sseStreamer, chatId, message)
                }
                return formatResponse(message)
            }
        }

        const loggerHandler = new ConsoleCallbackHandler(options.logger)
        const callbackHandlers = await additionalCallbacks(nodeData, options)
        const callbacks = [loggerHandler, ...callbackHandlers]

        if (process.env.DEBUG === 'true') {
            callbacks.push(new LCConsoleCallbackHandler())
        }

        try {
            // Token streaming is only attached when the QA step produces the answer;
            // with returnDirect the raw result is streamed once after the call.
            if (shouldStreamResponse && !returnDirect) {
                // NOTE(review): the third argument (2) is passed unchanged from the
                // previous implementation; confirm its meaning in CustomChainHandler.
                callbacks.push(new CustomChainHandler(sseStreamer, chatId, 2))
            }

            const response = await chain.invoke({ query: input }, { callbacks })

            if (shouldStreamResponse && returnDirect) {
                const text = renderDirectResult(response?.result)
                if (text) {
                    streamResponse(sseStreamer, chatId, text)
                }
            }

            return formatResponse(response?.result)
        } catch (error) {
            const message = describeError(error)
            console.error('Error in GraphCypherQAChain:', error)
            if (shouldStreamResponse) {
                streamResponse(sseStreamer, chatId, message)
            }
            return formatResponse(`Error: ${message}`)
        }
    }
}

/** Extracts a printable message from a caught value without assuming it is an `Error`. */
function describeError(err: unknown): string {
    if (err instanceof Error) {
        return err.message
    }
    if (typeof err === 'object' && err !== null && 'message' in err) {
        return String((err as { message: unknown }).message)
    }
    return String(err)
}

/**
 * Converts a direct query result into streamable text: objects become a closed
 * fenced JSON block, non-empty strings pass through, anything else yields `undefined`.
 */
function renderDirectResult(result: unknown): string | undefined {
    // `typeof null` is 'object', so exclude it explicitly.
    if (result !== null && typeof result === 'object') {
        return '```json\n' + JSON.stringify(result, null, 2) + '\n```'
    }
    if (typeof result === 'string' && result) {
        return result
    }
    return undefined
}
|
|
|
|
// CommonJS export: exposes the node class under the `nodeClass` key.
module.exports = { nodeClass: GraphCypherQA_Chain }
|