From c78ca494f57d1c6c8193d306b7b9b92942b66bfd Mon Sep 17 00:00:00 2001 From: vjsai Date: Thu, 6 Jul 2023 18:28:33 +0530 Subject: [PATCH] Added OpenAPI chain --- .../nodes/chains/ApiChain/OpenAPIChain.ts | 84 +++++++++++++++++++ .../SqlDatabaseChain/SqlDatabaseChain.ts | 2 +- packages/components/package.json | 2 +- 3 files changed, 86 insertions(+), 2 deletions(-) create mode 100644 packages/components/nodes/chains/ApiChain/OpenAPIChain.ts diff --git a/packages/components/nodes/chains/ApiChain/OpenAPIChain.ts b/packages/components/nodes/chains/ApiChain/OpenAPIChain.ts new file mode 100644 index 000000000..f0656f6dd --- /dev/null +++ b/packages/components/nodes/chains/ApiChain/OpenAPIChain.ts @@ -0,0 +1,84 @@ +import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { APIChain, createOpenAPIChain } from 'langchain/chains' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' +import { ChatOpenAI } from 'langchain/chat_models/openai' + +class OpenApiChain_Chains implements INode { + label: string + name: string + type: string + icon: string + category: string + baseClasses: string[] + description: string + inputs: INodeParams[] + + constructor() { + this.label = 'OpenAPI Chain' + this.name = 'openApiChain' + this.type = 'openApiChain' + this.icon = 'apichain.svg' + this.category = 'Chains' + this.description = 'Chain to run queries against OpenAPI' + this.baseClasses = [this.type, ...getBaseClasses(APIChain)] + this.inputs = [ + { + label: 'ChatOpenAI Model', + name: 'model', + type: 'ChatOpenAI' + }, + { + label: 'YAML File', + name: 'yamlFile', + type: 'file', + fileType: '.yaml' + }, + { + label: 'Headers', + name: 'headers', + type: 'json', + additionalParams: true, + optional: true + } + ] + } + + async init(nodeData: INodeData): Promise { + const model = nodeData.inputs?.model as ChatOpenAI + const headers = nodeData.inputs?.headers as Record + const yamlFileBase64 = nodeData.inputs?.yamlFile as string + const splitDataURI = yamlFileBase64.split(',') + splitDataURI.pop() + const bf = Buffer.from(splitDataURI.pop() || '', 'base64') + const utf8String = bf.toString('utf-8') + const chain = await createOpenAPIChain(utf8String, { + llm: model, + headers + }) + return chain + } + + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { + const model = nodeData.inputs?.model as ChatOpenAI + const headers = nodeData.inputs?.headers as Record + const yamlFileBase64 = nodeData.inputs?.yamlFile as string + const splitDataURI = yamlFileBase64.split(',') + splitDataURI.pop() + const bf = Buffer.from(splitDataURI.pop() || '', 'base64') + const utf8String = bf.toString('utf-8') + const chain = await createOpenAPIChain(utf8String, { + llm: model, + headers + }) + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2) + const res = await chain.run(input, [handler]) + return res + } else { + const res = await chain.run(input) + return res + } + } +} + +module.exports = { nodeClass: OpenApiChain_Chains } diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index 27a245e1a..441ec4cd2 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -1,5 +1,5 @@ import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' -import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains' +import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains/sql_db' import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { DataSource } from 'typeorm' import { SqlDatabase } from 'langchain/sql_db' diff --git a/packages/components/package.json b/packages/components/package.json index a5f03e10e..3e5f540b9 100644 --- a/packages/components/package.json +++ b/packages/components/package.json @@ -35,7 +35,7 @@ "form-data": "^4.0.0", "graphql": "^16.6.0", "html-to-text": "^9.0.5", - "langchain": "^0.0.96", + "langchain": "^0.0.103", "linkifyjs": "^4.1.1", "mammoth": "^1.5.1", "moment": "^2.29.3",