146 lines
4.7 KiB
TypeScript
146 lines
4.7 KiB
TypeScript
import { LLM, type BaseLLMParams } from '@langchain/core/language_models/llms'
|
|
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
|
|
import { GenerationChunk } from '@langchain/core/outputs'
|
|
|
|
import type ReplicateInstance from 'replicate'
|
|
|
|
export interface ReplicateInput {
|
|
model: `${string}/${string}` | `${string}/${string}:${string}`
|
|
input?: {
|
|
// different models accept different inputs
|
|
[key: string]: string | number | boolean
|
|
}
|
|
apiKey?: string
|
|
promptKey?: string
|
|
}
|
|
|
|
export class Replicate extends LLM implements ReplicateInput {
|
|
lc_serializable = true
|
|
|
|
model: ReplicateInput['model']
|
|
|
|
input: ReplicateInput['input']
|
|
|
|
apiKey: string
|
|
|
|
promptKey?: string
|
|
|
|
constructor(fields: ReplicateInput & BaseLLMParams) {
|
|
super(fields)
|
|
|
|
const apiKey = fields?.apiKey
|
|
|
|
if (!apiKey) {
|
|
throw new Error('Please set the REPLICATE_API_TOKEN')
|
|
}
|
|
|
|
this.apiKey = apiKey
|
|
this.model = fields.model
|
|
this.input = fields.input ?? {}
|
|
this.promptKey = fields.promptKey
|
|
}
|
|
|
|
_llmType() {
|
|
return 'replicate'
|
|
}
|
|
|
|
/** @ignore */
|
|
async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
|
|
const replicate = await this._prepareReplicate()
|
|
const input = await this._getReplicateInput(replicate, prompt)
|
|
|
|
const output = await this.caller.callWithOptions({ signal: options.signal }, () =>
|
|
replicate.run(this.model, {
|
|
input
|
|
})
|
|
)
|
|
|
|
if (typeof output === 'string') {
|
|
return output
|
|
} else if (Array.isArray(output)) {
|
|
return output.join('')
|
|
} else {
|
|
// Note this is a little odd, but the output format is not consistent
|
|
// across models, so it makes some amount of sense.
|
|
return String(output)
|
|
}
|
|
}
|
|
|
|
async *_streamResponseChunks(
|
|
prompt: string,
|
|
options: this['ParsedCallOptions'],
|
|
runManager?: CallbackManagerForLLMRun
|
|
): AsyncGenerator<GenerationChunk> {
|
|
const replicate = await this._prepareReplicate()
|
|
const input = await this._getReplicateInput(replicate, prompt)
|
|
|
|
const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () =>
|
|
replicate.stream(this.model, {
|
|
input
|
|
})
|
|
)
|
|
for await (const chunk of stream) {
|
|
if (chunk.event === 'output') {
|
|
yield new GenerationChunk({ text: chunk.data, generationInfo: chunk })
|
|
await runManager?.handleLLMNewToken(chunk.data ?? '')
|
|
}
|
|
|
|
// stream is done
|
|
if (chunk.event === 'done')
|
|
yield new GenerationChunk({
|
|
text: '',
|
|
generationInfo: { finished: true }
|
|
})
|
|
}
|
|
}
|
|
|
|
/** @ignore */
|
|
static async imports(): Promise<{
|
|
Replicate: typeof ReplicateInstance
|
|
}> {
|
|
try {
|
|
const { default: Replicate } = await import('replicate')
|
|
return { Replicate }
|
|
} catch (e) {
|
|
throw new Error('Please install replicate as a dependency with, e.g. `yarn add replicate`')
|
|
}
|
|
}
|
|
|
|
private async _prepareReplicate(): Promise<ReplicateInstance> {
|
|
const imports = await Replicate.imports()
|
|
|
|
return new imports.Replicate({
|
|
userAgent: 'flowise',
|
|
auth: this.apiKey
|
|
})
|
|
}
|
|
|
|
private async _getReplicateInput(replicate: ReplicateInstance, prompt: string) {
|
|
if (this.promptKey === undefined) {
|
|
const [modelString, versionString] = this.model.split(':')
|
|
if (versionString) {
|
|
const version = await replicate.models.versions.get(modelString.split('/')[0], modelString.split('/')[1], versionString)
|
|
const openapiSchema = version.openapi_schema
|
|
const inputProperties: { 'x-order': number | undefined }[] = (openapiSchema as any)?.components?.schemas?.Input?.properties
|
|
if (inputProperties === undefined) {
|
|
this.promptKey = 'prompt'
|
|
} else {
|
|
const sortedInputProperties = Object.entries(inputProperties).sort(([_keyA, valueA], [_keyB, valueB]) => {
|
|
const orderA = valueA['x-order'] || 0
|
|
const orderB = valueB['x-order'] || 0
|
|
return orderA - orderB
|
|
})
|
|
this.promptKey = sortedInputProperties[0][0] ?? 'prompt'
|
|
}
|
|
} else {
|
|
this.promptKey = 'prompt'
|
|
}
|
|
}
|
|
|
|
return {
|
|
[this.promptKey!]: prompt,
|
|
...this.input
|
|
}
|
|
}
|
|
}
|