// Flowise/packages/components/evaluation/EvaluationRunTracerLlama.ts

import { ChatMessage, LLMEndEvent, LLMStartEvent, LLMStreamEvent, MessageContentTextDetail, RetrievalEndEvent, Settings } from 'llamaindex'
import { EvaluationRunner } from './EvaluationRunner'
import { additionalCallbacks, ICommonObject, INodeData } from '../src'
import { RetrievalStartEvent } from 'llamaindex/dist/type/llm/types'
import { AgentEndEvent, AgentStartEvent } from 'llamaindex/dist/type/agent/types'
import { encoding_for_model } from '@dqbd/tiktoken'
import { MessageContent } from '@langchain/core/messages'
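
/**
 * Traces LlamaIndex runs during Flowise evaluations. Hooks into the global
 * Settings.callbackManager to record start times, model names, and token counts
 * per evaluation run, then reports the collected metrics via EvaluationRunner.
 * Token counts are estimated with tiktoken when the provider does not report usage.
 */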
export class EvaluationRunTracerLlama {
    evaluationRunId: string
    // listeners are registered on the shared callback manager only once per process
    static cbInit = false
    // `${evalID}_${label}` -> event.timeStamp of the corresponding *-start event
    static startTimes = new Map<string, number>()
    // evalID -> model name captured at llm-start
    static models = new Map<string, string>()
    // `${evalID}_promptTokens` / `${evalID}_outputTokens` -> estimated token counts
    static tokenCounts = new Map<string, number>()

    constructor(id: string) {
        this.evaluationRunId = id
        EvaluationRunTracerLlama.constructCallBacks()
    }
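
    // Registers the LlamaIndex event listeners. Guarded by cbInit so that
    // constructing multiple tracers does not attach duplicate handlers.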
    static constructCallBacks = () => {
        if (!EvaluationRunTracerLlama.cbInit) {
            Settings.callbackManager.on('llm-start', (event: LLMStartEvent) => {
                // the evaluation run id is attached to the caller (or its parent) by injectEvaluationMetadata
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (!evalID) return
                const model = (event as any).reason?.caller?.model
                if (model) {
                    EvaluationRunTracerLlama.models.set(evalID, model)
                    try {
                        const encoding = encoding_for_model(model)
                        if (encoding) {
                            const { messages } = event.detail.payload
                            const tokenCount = messages.reduce((count: number, message: ChatMessage) => {
                                return count + encoding.encode(extractText(message.content)).length
                            }, 0)
                            EvaluationRunTracerLlama.tokenCounts.set(evalID + '_promptTokens', tokenCount)
                            EvaluationRunTracerLlama.tokenCounts.set(evalID + '_outputTokens', 0)
                        }
                    } catch (e) {
                        // model is not known to tiktoken; skip token estimation and continue
                    }
                }
                EvaluationRunTracerLlama.startTimes.set(evalID + '_llm', event.timeStamp)
            })
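            // emit the collected metrics for this run once the LLM call completes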
            Settings.callbackManager.on('llm-end', (event: LLMEndEvent) => {
                this.calculateAndSetMetrics(event, 'llm')
            })
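            // accumulate output tokens as streamed chunks arrive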
            Settings.callbackManager.on('llm-stream', (event: LLMStreamEvent) => {
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (!evalID) return
                const { chunk } = event.detail.payload
                const { delta } = chunk
                const model = (event as any).reason?.caller?.model
                try {
                    const encoding = encoding_for_model(model)
                    if (encoding) {
                        let tokenCount = EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens') || 0
                        tokenCount += encoding.encode(extractText(delta)).length
                        EvaluationRunTracerLlama.tokenCounts.set(evalID + '_outputTokens', tokenCount)
                    }
                } catch (e) {
                    // model is not known to tiktoken; skip token estimation and continue
                }
            })
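            // retriever and agent events only record start times; metrics are emitted on the matching *-end event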
            Settings.callbackManager.on('retrieve-start', (event: RetrievalStartEvent) => {
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (evalID) {
                    EvaluationRunTracerLlama.startTimes.set(evalID + '_retriever', event.timeStamp)
                }
            })
            Settings.callbackManager.on('retrieve-end', (event: RetrievalEndEvent) => {
                this.calculateAndSetMetrics(event, 'retriever')
            })
            Settings.callbackManager.on('agent-start', (event: AgentStartEvent) => {
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (evalID) {
                    EvaluationRunTracerLlama.startTimes.set(evalID + '_agent', event.timeStamp)
                }
            })
            Settings.callbackManager.on('agent-end', (event: AgentEndEvent) => {
                this.calculateAndSetMetrics(event, 'agent')
            })
            EvaluationRunTracerLlama.cbInit = true
        }
    }
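
    /**
     * Emits the metrics for one evaluation run. Prefers token usage reported by
     * the provider (Anthropic/OpenAI-style `usage`, or Bedrock invocation metrics);
     * otherwise falls back to the elapsed time plus the tiktoken estimates
     * collected by the llm-start and llm-stream handlers.
     */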
    private static calculateAndSetMetrics(event: any, label: string) {
        const evalID = event.reason.parent?.caller?.evaluationRunId || event.reason.caller?.evaluationRunId
        if (!evalID) return
        const startTime = EvaluationRunTracerLlama.startTimes.get(evalID + '_' + label) as number
        const model =
            (event as any).reason?.caller?.model || (event as any).reason?.caller?.llm?.model || EvaluationRunTracerLlama.models.get(evalID)
        if (event.detail.payload?.response?.message && model) {
            try {
                const encoding = encoding_for_model(model)
                if (encoding) {
                    let tokenCount = EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens') || 0
                    // message content may be a MessageContentDetail[], so normalize it before encoding
                    tokenCount += encoding.encode(extractText(event.detail.payload.response?.message?.content || '')).length
                    EvaluationRunTracerLlama.tokenCounts.set(evalID + '_outputTokens', tokenCount)
                }
            } catch (e) {
                // model is not known to tiktoken; skip token estimation and continue
            }
        }
        // usage reported directly by the provider (Anthropic-style or OpenAI-style keys)
        if (event.detail?.payload?.response?.raw?.usage) {
            const usage = event.detail.payload.response.raw.usage
            if (usage.output_tokens) {
                const metric = {
                    completionTokens: usage.output_tokens,
                    promptTokens: usage.input_tokens,
                    model: model,
                    totalTokens: usage.input_tokens + usage.output_tokens
                }
                EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
            } else if (usage.completion_tokens) {
                const metric = {
                    completionTokens: usage.completion_tokens,
                    promptTokens: usage.prompt_tokens,
                    model: model,
                    totalTokens: usage.total_tokens
                }
                EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
            }
        } else if (event.detail?.payload?.response?.raw?.['amazon-bedrock-invocationMetrics']) {
            const usage = event.detail.payload.response.raw['amazon-bedrock-invocationMetrics']
            const metric = {
                completionTokens: usage.outputTokenCount,
                promptTokens: usage.inputTokenCount,
                model: event.detail.payload.response.raw.model,
                totalTokens: usage.inputTokenCount + usage.outputTokenCount
            }
            EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
        } else {
            // no provider-reported usage: fall back to elapsed time plus the tiktoken estimates
            const metric = {
                [label]: (event.timeStamp - startTime).toFixed(2),
                completionTokens: EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens'),
                promptTokens: EvaluationRunTracerLlama.tokenCounts.get(evalID + '_promptTokens'),
                model: model || EvaluationRunTracerLlama.models.get(evalID) || '',
                totalTokens:
                    (EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens') || 0) +
                    (EvaluationRunTracerLlama.tokenCounts.get(evalID + '_promptTokens') || 0)
            }
            EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
        }
        // cleanup per-run bookkeeping (token counts live in tokenCounts, not startTimes)
        EvaluationRunTracerLlama.startTimes.delete(evalID + '_' + label)
        EvaluationRunTracerLlama.tokenCounts.delete(evalID + '_outputTokens')
        EvaluationRunTracerLlama.tokenCounts.delete(evalID + '_promptTokens')
        EvaluationRunTracerLlama.models.delete(evalID)
    }
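
    /**
     * Tags a LlamaIndex caller object (LLM, agent, query engine, ...) with the
     * evaluation run id so the listeners above can associate its events with the run.
     *
     * A minimal usage sketch (the `llm` caller object here is illustrative):
     *   if (options.evaluationRunId) {
     *       await EvaluationRunTracerLlama.injectEvaluationMetadata(nodeData, options, llm)
     *   }
     */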
    static async injectEvaluationMetadata(nodeData: INodeData, options: ICommonObject, callerObj: any) {
        if (options.evaluationRunId && callerObj) {
            // these are needed for evaluation runs
            options.llamaIndex = true
            await additionalCallbacks(nodeData, options)
            Object.defineProperty(callerObj, 'evaluationRunId', {
                enumerable: true,
                configurable: true,
                writable: true,
                value: options.evaluationRunId
            })
        }
    }
}

// from https://github.com/run-llama/LlamaIndexTS/blob/main/packages/core/src/llm/utils.ts
export function extractText(message: MessageContent): string {
    if (typeof message !== 'string' && !Array.isArray(message)) {
        console.warn('extractText called with non-MessageContent message, this is likely a bug.')
        return `${message}`
    } else if (typeof message !== 'string' && Array.isArray(message)) {
        // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
        // so we can pass them to the context generator
        return message
            .filter((c): c is MessageContentTextDetail => c.type === 'text')
            .map((c) => c.text)
            .join('\n\n')
    } else {
        return message
    }
}