Feature/Add Neo4j GraphRag support (#3686)

* added: Neo4j database connectivity, Neo4j credentials, supports the usage of the GraphCypherQaChain node and modifies the FewShotPromptTemplate node to handle variables from the prefix field.

* Merge branch 'main' of github.com:FlowiseAI/Flowise into feature/graphragsupport

* revert pnpm-lock.yaml

* add: neo4j package

* Refactor GraphCypherQAChain: Update version to 1.0, remove memory input, and enhance prompt handling

- Changed version from 2.0 to 1.0.
- Removed the 'Memory' input parameter from the GraphCypherQAChain.
- Made 'cypherPrompt' optional and improved error handling for prompt validation.
- Updated the 'init' and 'run' methods to streamline input processing and response handling.
- Enhanced streaming response logic based on the 'returnDirect' flag.

* Refactor GraphCypherQAChain: Simplify imports and update init method signature

- Consolidated import statements for better readability.
- Removed the 'input' and 'options' parameters from the 'init' method, streamlining its signature to only accept 'nodeData'.

* add output, format final response, fix optional inputs

---------

Co-authored-by: Henry <hzj94@hotmail.com>
This commit is contained in:
Anthony Bryan Gavilan Vinces 2024-12-22 20:35:53 -05:00 committed by GitHub
parent 93f3a5d98a
commit a7c1ab881c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 34468 additions and 34040 deletions

View File

@ -0,0 +1,39 @@
import { INodeParams, INodeCredential } from '../src/Interface'
class Neo4jApi implements INodeCredential {
label: string
name: string
version: number
description: string
inputs: INodeParams[]
constructor() {
this.label = 'Neo4j API'
this.name = 'neo4jApi'
this.version = 1.0
this.description =
'Refer to <a target="_blank" href="https://neo4j.com/docs/operations-manual/current/authentication-authorization/">official guide</a> on Neo4j authentication'
this.inputs = [
{
label: 'Neo4j URL',
name: 'url',
type: 'string',
description: 'Your Neo4j instance URL (e.g., neo4j://localhost:7687)'
},
{
label: 'Username',
name: 'username',
type: 'string',
description: 'Neo4j database username'
},
{
label: 'Password',
name: 'password',
type: 'password',
description: 'Neo4j database password'
}
]
}
}
module.exports = { credClass: Neo4jApi }

View File

@ -0,0 +1,256 @@
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface'
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher'
import { getBaseClasses } from '../../../src/utils'
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
class GraphCypherQA_Chain implements INode {
label: string
name: string
version: number
type: string
icon: string
category: string
description: string
baseClasses: string[]
inputs: INodeParams[]
sessionId?: string
outputs: INodeOutputsValue[]
constructor(fields?: { sessionId?: string }) {
this.label = 'Graph Cypher QA Chain'
this.name = 'graphCypherQAChain'
this.version = 1.0
this.type = 'GraphCypherQAChain'
this.icon = 'graphqa.svg'
this.category = 'Chains'
this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements'
this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)]
this.sessionId = fields?.sessionId
this.inputs = [
{
label: 'Language Model',
name: 'model',
type: 'BaseLanguageModel',
description: 'Model for generating Cypher queries and answers.'
},
{
label: 'Neo4j Graph',
name: 'graph',
type: 'Neo4j'
},
{
label: 'Cypher Generation Prompt',
name: 'cypherPrompt',
optional: true,
type: 'BasePromptTemplate',
description: 'Prompt template for generating Cypher queries. Must include {schema} and {question} variables'
},
{
label: 'Cypher Generation Model',
name: 'cypherModel',
optional: true,
type: 'BaseLanguageModel',
description: 'Model for generating Cypher queries. If not provided, the main model will be used.'
},
{
label: 'QA Prompt',
name: 'qaPrompt',
optional: true,
type: 'BasePromptTemplate',
description: 'Prompt template for generating answers. Must include {context} and {question} variables'
},
{
label: 'QA Model',
name: 'qaModel',
optional: true,
type: 'BaseLanguageModel',
description: 'Model for generating answers. If not provided, the main model will be used.'
},
{
label: 'Input Moderation',
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
name: 'inputModeration',
type: 'Moderation',
optional: true,
list: true
},
{
label: 'Return Direct',
name: 'returnDirect',
type: 'boolean',
default: false,
optional: true,
description: 'If true, return the raw query results instead of using the QA chain'
}
]
this.outputs = [
{
label: 'Graph Cypher QA Chain',
name: 'graphCypherQAChain',
baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)]
},
{
label: 'Output Prediction',
name: 'outputPrediction',
baseClasses: ['string', 'json']
}
]
}
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
const model = nodeData.inputs?.model
const cypherModel = nodeData.inputs?.cypherModel
const qaModel = nodeData.inputs?.qaModel
const graph = nodeData.inputs?.graph
const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined
const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined
const returnDirect = nodeData.inputs?.returnDirect as boolean
const output = nodeData.outputs?.output as string
// Handle prompt values if they exist
let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined
let qaPromptTemplate: PromptTemplate | undefined
if (cypherPrompt) {
if (cypherPrompt instanceof PromptTemplate) {
cypherPromptTemplate = new PromptTemplate({
template: cypherPrompt.template as string,
inputVariables: cypherPrompt.inputVariables
})
if (!qaPrompt) {
throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template')
}
} else if (cypherPrompt instanceof FewShotPromptTemplate) {
const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate
cypherPromptTemplate = new FewShotPromptTemplate({
examples: cypherPrompt.examples,
examplePrompt: examplePrompt,
inputVariables: cypherPrompt.inputVariables,
prefix: cypherPrompt.prefix,
suffix: cypherPrompt.suffix,
exampleSeparator: cypherPrompt.exampleSeparator,
templateFormat: cypherPrompt.templateFormat
})
} else {
cypherPromptTemplate = cypherPrompt as PromptTemplate
}
}
if (qaPrompt instanceof PromptTemplate) {
qaPromptTemplate = new PromptTemplate({
template: qaPrompt.template as string,
inputVariables: qaPrompt.inputVariables
})
}
if ((!cypherModel || !qaModel) && !model) {
throw new Error('Language Model is required when Cypher Model or QA Model are not provided')
}
// Validate required variables in prompts
if (
cypherPromptTemplate &&
(!cypherPromptTemplate?.inputVariables.includes('schema') || !cypherPromptTemplate?.inputVariables.includes('question'))
) {
throw new Error('Cypher Generation Prompt must include {schema} and {question} variables')
}
const fromLLMInput: FromLLMInput = {
llm: model,
graph,
returnDirect
}
if (cypherModel && cypherPromptTemplate) {
fromLLMInput['cypherLLM'] = cypherModel
fromLLMInput['cypherPrompt'] = cypherPromptTemplate
}
if (qaModel && qaPromptTemplate) {
fromLLMInput['qaLLM'] = qaModel
fromLLMInput['qaPrompt'] = qaPromptTemplate
}
const chain = GraphCypherQAChain.fromLLM(fromLLMInput)
if (output === this.name) {
return chain
} else if (output === 'outputPrediction') {
nodeData.instance = chain
return await this.run(nodeData, input, options)
}
return chain
}
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
const chain = nodeData.instance as GraphCypherQAChain
const moderations = nodeData.inputs?.inputModeration as Moderation[]
const returnDirect = nodeData.inputs?.returnDirect as boolean
const shouldStreamResponse = options.shouldStreamResponse
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
const chatId = options.chatId
// Handle input moderation if configured
if (moderations && moderations.length > 0) {
try {
input = await checkInputs(moderations, input)
} catch (e) {
await new Promise((resolve) => setTimeout(resolve, 500))
if (shouldStreamResponse) {
streamResponse(sseStreamer, chatId, e.message)
}
return formatResponse(e.message)
}
}
const obj = {
query: input
}
const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbackHandlers = await additionalCallbacks(nodeData, options)
let callbacks = [loggerHandler, ...callbackHandlers]
if (process.env.DEBUG === 'true') {
callbacks.push(new LCConsoleCallbackHandler())
}
try {
let response
if (shouldStreamResponse) {
if (returnDirect) {
response = await chain.invoke(obj, { callbacks })
let result = response?.result
if (typeof result === 'object') {
result = '```json\n' + JSON.stringify(result, null, 2)
}
if (result && typeof result === 'string') {
streamResponse(sseStreamer, chatId, result)
}
} else {
const handler = new CustomChainHandler(sseStreamer, chatId, 2)
callbacks.push(handler)
response = await chain.invoke(obj, { callbacks })
}
} else {
response = await chain.invoke(obj, { callbacks })
}
return formatResponse(response?.result)
} catch (error) {
console.error('Error in GraphCypherQAChain:', error)
if (shouldStreamResponse) {
streamResponse(sseStreamer, chatId, error.message)
}
return formatResponse(`Error: ${error.message}`)
}
}
}
module.exports = { nodeClass: GraphCypherQA_Chain }

View File

@ -0,0 +1,22 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="24px" height="24px" viewBox="0 0 24 24" version="1.1" xmlns="http://www.w3.org/2000/svg">
<g stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<!-- Graph Nodes and Edges -->
<circle fill="#4CAF50" cx="6" cy="6" r="2.5"/>
<circle fill="#4CAF50" cx="18" cy="6" r="2.5"/>
<circle fill="#4CAF50" cx="6" cy="18" r="2.5"/>
<circle fill="#4CAF50" cx="18" cy="18" r="2.5"/>
<!-- Graph Connections -->
<line x1="6" y1="6" x2="18" y2="6" stroke="#4CAF50" stroke-width="1.5"/>
<line x1="6" y1="6" x2="6" y2="18" stroke="#4CAF50" stroke-width="1.5"/>
<line x1="18" y1="6" x2="18" y2="18" stroke="#4CAF50" stroke-width="1.5"/>
<line x1="6" y1="18" x2="18" y2="18" stroke="#4CAF50" stroke-width="1.5"/>
<!-- Question Mark -->
<path d="M12,8 C13.1045695,8 14,8.8954305 14,10 C14,10.7403567 13.5978014,11.3866184 13,11.7324555 L13,13 C13,13.5522847 12.5522847,14 12,14 C11.4477153,14 11,13.5522847 11,13 L11,11 C11,10.4477153 11.4477153,10 12,10 C12.5522847,10 13,10.4477153 13,11 C13,11.5522847 12.5522847,12 12,12"
fill="#2196F3"
fill-rule="nonzero"/>
<circle fill="#2196F3" cx="12" cy="16" r="1"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -0,0 +1,80 @@
import { getBaseClasses, getCredentialData } from '../../../src/utils'
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { Neo4jGraph } from '@langchain/community/graphs/neo4j_graph'
class Neo4j_Graphs implements INode {
label: string
name: string
version: number
type: string
icon: string
category: string
description: string
baseClasses: string[]
credential: INodeParams
inputs: INodeParams[]
constructor() {
this.label = 'Neo4j'
this.name = 'Neo4j'
this.version = 1.0
this.type = 'Neo4j'
this.icon = 'neo4j.svg'
this.category = 'Graph'
this.description = 'Connect with Neo4j graph database'
this.baseClasses = [this.type, ...getBaseClasses(Neo4jGraph)]
this.credential = {
label: 'Connect Credential',
name: 'credential',
type: 'credential',
credentialNames: ['neo4jApi']
}
this.inputs = [
{
label: 'Database',
name: 'database',
type: 'string',
placeholder: 'neo4j',
optional: true
},
{
label: 'Timeout (ms)',
name: 'timeoutMs',
type: 'number',
default: 5000,
optional: true
},
{
label: 'Enhanced Schema',
name: 'enhancedSchema',
type: 'boolean',
default: false,
optional: true
}
]
}
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
const database = nodeData.inputs?.database as string
const timeoutMs = nodeData.inputs?.timeoutMs as number
const enhancedSchema = nodeData.inputs?.enhancedSchema as boolean
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const neo4jConfig = {
url: credentialData?.url,
username: credentialData?.username,
password: credentialData?.password
}
const neo4jGraph = await Neo4jGraph.initialize({
...neo4jConfig,
...(database && { database }),
...(timeoutMs && { timeoutMs }),
...(enhancedSchema && { enhancedSchema })
})
return neo4jGraph
}
}
module.exports = { nodeClass: Neo4j_Graphs }

View File

@ -0,0 +1 @@
<?xml version="1.0" encoding="UTF-8"?> <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" id="Layer_1" x="0px" y="0px" viewBox="0 0 677.5 242.4" style="enable-background:new 0 0 677.5 242.4;" xml:space="preserve"> <style type="text/css"> .st0{fill:#231F20;} .st1{fill:#014063;} </style> <g> <path class="st0" d="M137.8,61.9c-35.3,0-58.9,20.5-58.9,60.4v28.4c3.5-1.7,7.3-2.6,11.4-2.6s8,1,11.5,2.7v-28.5 c0-25.8,14.2-39.1,36-39.1s36,13.3,36,39.1v62.1h22.9v-62.1C196.7,82.2,173,61.9,137.8,61.9L137.8,61.9z"></path> <path class="st0" d="M209.2,124.7c0-36.2,26.6-62.8,64.2-62.8s63.8,26.6,63.8,62.8v8.5H233.3c3.4,21.3,19.3,33.1,40.1,33.1 c15.5,0,26.3-4.8,33.3-15.2h25.4c-9.2,22.2-30.9,36.5-58.7,36.5C235.7,187.5,209.2,161,209.2,124.7L209.2,124.7z M313,112.7 c-4.6-19.1-20.3-29.5-39.6-29.5s-34.8,10.6-39.4,29.5H313z"></path> <path class="st0" d="M349.5,124.7c0-36.2,26.6-62.8,64.2-62.8s64.2,26.6,64.2,62.8c0,36.2-26.6,62.8-64.2,62.8 S349.5,161,349.5,124.7z M454.7,124.7c0-24.2-16.4-41.5-41.1-41.5s-41.1,17.4-41.1,41.5c0,24.1,16.4,41.5,41.1,41.5 S454.7,148.9,454.7,124.7z"></path> <path class="st0" d="M609.1,208.1h2.7c14.7,0,20.3-6.5,20.3-23.4V65.9H655v117.3c0,29.5-11.6,44.7-41.1,44.7h-4.8L609.1,208.1 L609.1,208.1z"></path> <path class="st0" d="M597.6,195.9h-22.9v-28.3h-58.2c-11.6,0-21.7-5.7-26.3-14.8c-4.3-8.6-3.1-18.7,3.1-27.2L545.6,57 c7.5-10.1,20.2-14.2,32.2-10.2c12,4,19.8,14.7,19.8,27.3v73.1h17.2v20.4h-17.2V195.9L597.6,195.9z M512.6,138.1 c-0.7,0.9-1.1,2.1-1.1,3.4c0,3.2,2.6,5.8,5.8,5.8h57.5V73.6c0-3.8-2.8-5.2-4-5.6c-0.5-0.1-1.2-0.3-2.1-0.3c-1.4,0-3.1,0.5-4.6,2.4 L512.6,138.1L512.6,138.1L512.6,138.1z"></path> <path class="st1" d="M24.6,125.8c-3,1.5-5.8,4-7.8,7.3c-2,3.3-2.8,6.8-2.5,10.3c0.3,6.3,3.8,12.1,9.5,15.3c5.3,3,11.3,2.3,17,1 c7-1.8,13-2.5,19.3,1.3c0,0,0,0,0.3,0c10.8,6.3,10.8,22.1,0,28.4c0,0,0,0-0.3,0c-6.3,3.8-12.3,3-19.3,1.3c-5.5-1.5-11.5-2.3-17,1 c-5.8,3.3-9,9.3-9.5,15.3c-0.3,3.5,0.5,7,2.5,10.3c2,3.3,4.5,5.8,7.8,7.3c5.5,2.8,12.3,2.8,18-0.5c5.3-3,7.8-8.8,9.3-14.3 c2-7,4.3-12.6,10.8-16.1c6.3-3.8,12.3-3,19.3-1.3c5.5,1.5,11.5,2.3,17-1c5.8-3.3,9-9.3,9.5-15.3c0-0.5,0-0.8,0-1.3 c0-0.5,0-0.8,0-1.3c-0.3-6.3-3.8-12.1-9.5-15.3c-5.3-3-11.3-2.3-17-1c-7,1.8-13,2.5-19.3-1.3c-6.3-3.8-8.8-9.1-10.8-16.1 c-1.5-5.5-4-11.1-9.3-14.3C36.8,123.1,30.1,123.1,24.6,125.8z"></path> <path class="st1" d="M643.6,17.2c-10.8,0-19.6,8.8-19.6,19.6s8.8,19.6,19.6,19.6c10.8,0,19.6-8.8,19.6-19.6S654.4,17.2,643.6,17.2z "></path> </g> </svg>

After

Width:  |  Height:  |  Size: 2.4 KiB

View File

@ -86,7 +86,7 @@ class FewShotPromptTemplate_Prompts implements INode {
const templateFormat = nodeData.inputs?.templateFormat as TemplateFormat const templateFormat = nodeData.inputs?.templateFormat as TemplateFormat
const examplePrompt = nodeData.inputs?.examplePrompt as PromptTemplate const examplePrompt = nodeData.inputs?.examplePrompt as PromptTemplate
const inputVariables = getInputVariables(suffix) const inputVariables = [...new Set([...getInputVariables(suffix), ...getInputVariables(prefix)])]
let examples: Example[] = [] let examples: Example[] = []
if (examplesStr) { if (examplesStr) {

View File

@ -107,6 +107,7 @@
"moment": "^2.29.3", "moment": "^2.29.3",
"mongodb": "6.3.0", "mongodb": "6.3.0",
"mysql2": "^3.11.3", "mysql2": "^3.11.3",
"neo4j-driver": "^5.26.0",
"node-fetch": "^2.6.11", "node-fetch": "^2.6.11",
"node-html-markdown": "^1.3.0", "node-html-markdown": "^1.3.0",
"notion-to-md": "^3.1.1", "notion-to-md": "^3.1.1",

File diff suppressed because it is too large Load Diff