Flowise/packages/components/nodes/llms/GPT4All/core.ts

162 lines
4.9 KiB
TypeScript

import { spawn } from 'child_process'
import { BaseLLM, BaseLLMParams } from 'langchain/llms/base'
import { Generation, LLMResult } from 'langchain/schema'
/**
 * LangChain LLM wrapper that drives a local GPT4All chat executable over stdin/stdout.
 *
 * Every prompt spawns a fresh process (`<executablePath> --model <modelPath>`), waits for the
 * binary's `>` input prompt, sends the prompt, collects the reply and then kills the process.
 */
export class GPT4All extends BaseLLM {
    /** Path of the GPT4All chat binary to spawn. */
    executablePath: string

    /** Model file passed to the binary via `--model`. */
    modelPath: string

    /**
     * Idle timeout in milliseconds used while collecting a reply: once output has started, the
     * reply is treated as complete if no new stdout data arrives for this long. Never below 1000.
     */
    promptTimeout: number

    /** The currently running child process, or `null` when none is open. */
    client: ReturnType<typeof spawn> | null = null

    /**
     * @param config - `executablePath` and `modelPath` are required; `promptTimeout` defaults to
     *   1000 ms. The remaining fields are forwarded to `BaseLLM`.
     * @throws Error if `executablePath` or `modelPath` is missing.
     */
    constructor(
        config: {
            executablePath?: string
            modelPath?: string
            promptTimeout?: number
        } & BaseLLMParams
    ) {
        super(config ?? {})
        const executablePath = config.executablePath
        if (!executablePath) {
            throw new Error(`Executable path must be provided`)
        }
        this.executablePath = executablePath
        const modelPath = config.modelPath
        if (!modelPath) {
            throw new Error(`Model path must be provided`)
        }
        this.modelPath = modelPath
        // 1000 ms is a floor, not a ceiling: `Math.min` here would silently discard any longer
        // timeout the user configured. Falsy values (unset, 0, NaN) fall back to the default.
        this.promptTimeout = Math.max(1000, config.promptTimeout || 1000)
    }

    /** Kills the running process (if any) and forgets it. Safe to call repeatedly. */
    close(): void {
        if (this.client !== null) {
            // A process that failed to spawn has no pid, so there is nothing to signal.
            if (this.client.pid !== undefined) {
                this.client.kill()
            }
            this.client = null
        }
    }

    /**
     * Spawns a new GPT4All process (closing any existing one) and resolves once the binary
     * prints its `>` input prompt.
     *
     * @throws Error if the process cannot be spawned, or exits before printing the prompt.
     *   In both cases the process handle is released before rethrowing.
     */
    async open(): Promise<void> {
        if (this.client !== null) {
            this.close()
        }
        const child = spawn(this.executablePath, ['--model', this.modelPath], { stdio: ['pipe', 'pipe', 'ignore'] })
        this.client = child
        try {
            await new Promise<void>((resolve, reject) => {
                // Detach every readiness listener so they do not linger into the prompt exchange.
                const cleanup = (): void => {
                    child.stdout?.removeListener('data', onData)
                    child.removeListener('error', onError)
                    child.removeListener('exit', onExit)
                }
                const onData = (data: Buffer): void => {
                    if (data.toString().includes('>')) {
                        cleanup()
                        resolve()
                    }
                }
                // Without an 'error' listener a failed spawn (e.g. bad path) would crash the host process.
                const onError = (err: Error): void => {
                    cleanup()
                    reject(err)
                }
                // Otherwise an early exit would leave this promise pending forever.
                const onExit = (code: number | null, signal: NodeJS.Signals | null): void => {
                    cleanup()
                    reject(new Error(`GPT4All process exited before becoming ready (code: ${code}, signal: ${signal})`))
                }
                child.stdout?.on('data', onData)
                child.on('error', onError)
                child.on('exit', onExit)
            })
        } catch (err) {
            this.close()
            throw err
        }
    }

    _llmType(): string {
        return 'gpt4all'
    }

    /**
     * Runs a single prompt through a freshly spawned GPT4All process.
     *
     * @param prompt - The prompt to pass into the model.
     * @param _stop - Stop words; currently ignored by this implementation.
     *
     * @returns the model's reply with ANSI escape sequences removed.
     *
     * @example
     * ```ts
     * import { GPT4All } from "./core";
     * const gpt4All = new GPT4All({ executablePath: "/path/to/gpt4all", modelPath: "/path/to/model.bin" });
     * const response = await gpt4All.call("Tell me a joke.")
     * ```
     */
    async _call(prompt: string, _stop?: string[]): Promise<string> {
        try {
            await this.open()
            const response = await sendMessageAndWaitForResult(this.client, prompt, this.promptTimeout)
            // Strip ANSI escape sequences from the raw terminal output.
            // eslint-disable-next-line
            const ansiEscapeSequences = /[\u001b\u009b][[()#;?]*(?:[0-9]{1,4}(?:;[0-9]{0,4})*)?[0-9A-ORZcf-nqry=><]/g
            return response.replace(ansiEscapeSequences, '')
        } finally {
            // Always kill the process, even when spawning or prompting failed.
            this.close()
        }
    }

    /**
     * Runs each prompt through `_call`, one after another.
     *
     * @param prompts - The prompts to pass into the model.
     * @param stop - Stop words, forwarded to `_call` (which currently ignores them).
     *
     * @returns one single-generation list per prompt, in input order.
     *
     * @example
     * ```ts
     * import { GPT4All } from "./core";
     * const gpt4All = new GPT4All({ executablePath: "/path/to/gpt4all", modelPath: "/path/to/model.bin" });
     * const response = await gpt4All.generate(["Tell me a joke."])
     * ```
     */
    async _generate(prompts: Array<string>, stop?: string[]): Promise<LLMResult> {
        const generations: Array<Array<Generation>> = []
        // Sequential on purpose: each prompt spawns (and later kills) its own model process.
        for (const prompt of prompts) {
            const result = await this._call(prompt, stop)
            generations.push([{ text: result }])
        }
        return { generations }
    }
}
/** The subset of a child process needed to exchange one prompt/reply over stdio. */
interface PromptClient {
    stdin: NodeJS.WritableStream | null
    stdout: NodeJS.ReadableStream | null
}

/**
 * Writes `prompt` (plus a newline) to the client's stdin and collects stdout until the reply
 * is complete.
 *
 * The reply is considered complete when any of these happens:
 * - a chunk containing `>` arrives. Text before the `>` in that chunk is kept; the `>` and
 *   anything after it are dropped.
 * - no new stdout data arrives for `promptTimeout` ms, once output has started.
 * - stdout ends.
 *
 * NOTE(review): a `>` produced by the model itself also ends the reply early.
 *
 * @param client - Spawned process (or any compatible stdin/stdout pair); must not be `null`.
 * @param prompt - Text to send.
 * @param promptTimeout - Idle timeout in milliseconds. It is only armed after the first chunk,
 *   so slow time-to-first-token does not cut the reply short.
 * @returns The collected reply text.
 * @throws Error synchronously if `client` is `null` or has no stdin/stdout. The returned
 *   promise rejects on stream errors.
 */
function sendMessageAndWaitForResult(client: PromptClient | null, prompt: string, promptTimeout: number): Promise<string> {
    if (client === null) {
        throw new Error('Client is not initialized.')
    }
    const { stdin, stdout } = client
    if (stdin === null || stdout === null) {
        throw new Error('Client stdin/stdout is not available.')
    }
    return new Promise<string>((resolve, reject) => {
        let response = ''
        let timeoutId: NodeJS.Timeout | undefined

        const clearIdleTimer = (): void => {
            if (timeoutId !== undefined) {
                clearTimeout(timeoutId)
                timeoutId = undefined
            }
        }
        const detach = (): void => {
            clearIdleTimer()
            stdout.removeListener('data', onStdoutData)
            stdout.removeListener('error', onStreamError)
            stdout.removeListener('end', onStdoutEnd)
        }
        const terminateAndResolve = (finalResponse: string): void => {
            detach()
            resolve(finalResponse)
        }
        const onStreamError = (err: Error): void => {
            detach()
            reject(err)
        }
        // Otherwise the promise would hang forever if the process dies mid-reply.
        const onStdoutEnd = (): void => {
            terminateAndResolve(response)
        }
        const onStdoutData = (data: Buffer | string): void => {
            const text = data.toString()
            clearIdleTimer()
            const markerIndex = text.indexOf('>')
            if (markerIndex !== -1) {
                // The prompt marker may share a chunk with the last part of the answer; keep it.
                terminateAndResolve(response + text.slice(0, markerIndex))
                return
            }
            response += text
            timeoutId = setTimeout(() => terminateAndResolve(response), promptTimeout)
        }

        stdout.on('data', onStdoutData)
        stdout.on('error', onStreamError)
        stdout.on('end', onStdoutEnd)
        // Intentionally left attached after settling: a late write error (e.g. EPIPE once the
        // process is killed) must not surface as an uncaught 'error' event. Rejecting an
        // already-settled promise is a no-op.
        stdin.on('error', onStreamError)
        stdin.write(prompt + '\n')
    })
}