From d44ab269e6320e3271a0e6918545f2a999883dba Mon Sep 17 00:00:00 2001 From: Henry Date: Wed, 17 May 2023 14:43:15 +0100 Subject: [PATCH] add baselanguagemodel --- .../nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts | 8 ++++---- .../ConversationalRetrievalQAChain.ts | 12 ++++++------ .../RetrievalQAChain/RetrievalQAChain.ts | 12 ++++++------ .../SqlDatabaseChain/SqlDatabaseChain.ts | 18 +++++++++--------- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts b/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts index c092fc17e..1aa233646 100644 --- a/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts +++ b/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts @@ -1,8 +1,8 @@ import { INode, INodeData, INodeParams } from '../../../src/Interface' import { initializeAgentExecutorWithOptions, AgentExecutor } from 'langchain/agents' import { Tool } from 'langchain/tools' -import { BaseLLM } from 'langchain/llms/base' import { getBaseClasses } from '../../../src/utils' +import { BaseLanguageModel } from 'langchain/base_language' class MRKLAgentLLM_Agents implements INode { label: string @@ -30,15 +30,15 @@ class MRKLAgentLLM_Agents implements INode { list: true }, { - label: 'LLM Model', + label: 'Language Model', name: 'model', - type: 'BaseLLM' + type: 'BaseLanguageModel' } ] } async init(nodeData: INodeData): Promise { - const model = nodeData.inputs?.model as BaseLLM + const model = nodeData.inputs?.model as BaseLanguageModel let tools = nodeData.inputs?.tools as Tool[] tools = tools.flat() diff --git a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts index 616e2a9b9..1c0afaec6 100644 --- a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts +++ b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts @@ -1,7 +1,7 @@ +import { BaseLanguageModel } from 'langchain/base_language' import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { ConversationalRetrievalQAChain } from 'langchain/chains' -import { BaseLLM } from 'langchain/llms/base' import { BaseRetriever } from 'langchain/schema' class ConversationalRetrievalQAChain_Chains implements INode { @@ -24,9 +24,9 @@ class ConversationalRetrievalQAChain_Chains implements INode { this.baseClasses = [this.type, ...getBaseClasses(ConversationalRetrievalQAChain)] this.inputs = [ { - label: 'LLM', - name: 'llm', - type: 'BaseLLM' + label: 'Language Model', + name: 'model', + type: 'BaseLanguageModel' }, { label: 'Vector Store Retriever', @@ -37,10 +37,10 @@ class ConversationalRetrievalQAChain_Chains implements INode { } async init(nodeData: INodeData): Promise { - const llm = nodeData.inputs?.llm as BaseLLM + const model = nodeData.inputs?.model as BaseLanguageModel const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever - const chain = ConversationalRetrievalQAChain.fromLLM(llm, vectorStoreRetriever) + const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever) return chain } diff --git a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts index 2887643af..71a885ec7 100644 --- a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts +++ b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts @@ -1,8 +1,8 @@ import { INode, INodeData, INodeParams } from '../../../src/Interface' import { RetrievalQAChain } from 'langchain/chains' -import { BaseLLM } from 'langchain/llms/base' import { BaseRetriever } from 'langchain/schema' import { getBaseClasses } from '../../../src/utils' +import { BaseLanguageModel } from 'langchain/base_language' class RetrievalQAChain_Chains implements INode { label: string @@ -24,9 +24,9 @@ class RetrievalQAChain_Chains implements INode { this.baseClasses = [this.type, ...getBaseClasses(RetrievalQAChain)] this.inputs = [ { - label: 'LLM', - name: 'llm', - type: 'BaseLLM' + label: 'Language Model', + name: 'model', + type: 'BaseLanguageModel' }, { label: 'Vector Store Retriever', @@ -37,10 +37,10 @@ class RetrievalQAChain_Chains implements INode { } async init(nodeData: INodeData): Promise { - const llm = nodeData.inputs?.llm as BaseLLM + const model = nodeData.inputs?.model as BaseLanguageModel const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever - const chain = RetrievalQAChain.fromLLM(llm, vectorStoreRetriever) + const chain = RetrievalQAChain.fromLLM(model, vectorStoreRetriever) return chain } diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index 99c5a2f63..5cf825578 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -3,7 +3,7 @@ import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains' import { getBaseClasses } from '../../../src/utils' import { DataSource } from 'typeorm' import { SqlDatabase } from 'langchain/sql_db' -import { BaseLLM } from 'langchain/llms/base' +import { BaseLanguageModel } from 'langchain/base_language' class SqlDatabaseChain_Chains implements INode { label: string @@ -25,9 +25,9 @@ class SqlDatabaseChain_Chains implements INode { this.baseClasses = [this.type, ...getBaseClasses(SqlDatabaseChain)] this.inputs = [ { - label: 'LLM', - name: 'llm', - type: 'BaseLLM' + label: 'Language Model', + name: 'model', + type: 'BaseLanguageModel' }, { label: 'Database', @@ -52,25 +52,25 @@ class SqlDatabaseChain_Chains implements INode { async init(nodeData: INodeData): Promise { const databaseType = nodeData.inputs?.database as 'sqlite' - const llm = nodeData.inputs?.llm as BaseLLM + const model = nodeData.inputs?.model as BaseLanguageModel const dbFilePath = nodeData.inputs?.dbFilePath - const chain = await getSQLDBChain(databaseType, dbFilePath, llm) + const chain = await getSQLDBChain(databaseType, dbFilePath, model) return chain } async run(nodeData: INodeData, input: string): Promise { const databaseType = nodeData.inputs?.database as 'sqlite' - const llm = nodeData.inputs?.llm as BaseLLM + const model = nodeData.inputs?.model as BaseLanguageModel const dbFilePath = nodeData.inputs?.dbFilePath - const chain = await getSQLDBChain(databaseType, dbFilePath, llm) + const chain = await getSQLDBChain(databaseType, dbFilePath, model) const res = await chain.run(input) return res } } -const getSQLDBChain = async (databaseType: 'sqlite', dbFilePath: string, llm: BaseLLM) => { +const getSQLDBChain = async (databaseType: 'sqlite', dbFilePath: string, llm: BaseLanguageModel) => { const datasource = new DataSource({ type: databaseType, database: dbFilePath