import { ChatMessage, LLMEndEvent, LLMStartEvent, LLMStreamEvent, MessageContentTextDetail, RetrievalEndEvent, Settings } from 'llamaindex'
import { EvaluationRunner } from './EvaluationRunner'
import { additionalCallbacks, ICommonObject, INodeData } from '../src'
import { RetrievalStartEvent } from 'llamaindex/dist/type/llm/types'
import { AgentEndEvent, AgentStartEvent } from 'llamaindex/dist/type/agent/types'
import { encoding_for_model } from '@dqbd/tiktoken'
import { MessageContent } from '@langchain/core/messages'
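
/**
 * Collects latency and token-usage metrics for LlamaIndex-based evaluation runs.
 * The class subscribes once to the global Settings.callbackManager events (llm, retrieve, agent),
 * attributes each event to an evaluation run via the `evaluationRunId` injected on the caller
 * object (see injectEvaluationMetadata), and reports metrics through EvaluationRunner.addMetrics().
 */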
export class EvaluationRunTracerLlama {
    evaluationRunId: string
    static cbInit = false
    static startTimes = new Map<string, number>()
    static models = new Map<string, string>()
    static tokenCounts = new Map<string, number>()

    constructor(id: string) {
        this.evaluationRunId = id
        EvaluationRunTracerLlama.constructCallBacks()
    }
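
    // Register the callback handlers exactly once (guarded by cbInit); every tracer instance
    // shares the same global handlers, and runs are told apart by the evaluationRunId on the caller.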
    static constructCallBacks = () => {
        if (!EvaluationRunTracerLlama.cbInit) {
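            // llm-start: remember the model, pre-count prompt tokens with tiktoken,
            // reset the output-token counter and record the start time for this run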
            Settings.callbackManager.on('llm-start', (event: LLMStartEvent) => {
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (!evalID) return
                const model = (event as any).reason?.caller?.model
                if (model) {
                    EvaluationRunTracerLlama.models.set(evalID, model)
                    try {
                        const encoding = encoding_for_model(model)
                        if (encoding) {
                            const { messages } = event.detail.payload
                            let tokenCount = messages.reduce((count: number, message: ChatMessage) => {
                                return count + encoding.encode(extractText(message.content)).length
                            }, 0)
                            EvaluationRunTracerLlama.tokenCounts.set(evalID + '_promptTokens', tokenCount)
                            EvaluationRunTracerLlama.tokenCounts.set(evalID + '_outputTokens', 0)
                        }
                    } catch (e) {
                        // ignore tokenizer errors (e.g. model unknown to tiktoken) and carry on
                    }
                }
                EvaluationRunTracerLlama.startTimes.set(evalID + '_llm', event.timeStamp)
            })
            Settings.callbackManager.on('llm-end', (event: LLMEndEvent) => {
                this.calculateAndSetMetrics(event, 'llm')
            })
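            // llm-stream: accumulate output tokens from every streamed delta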
            Settings.callbackManager.on('llm-stream', (event: LLMStreamEvent) => {
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (!evalID) return
                const { chunk } = event.detail.payload
                const { delta } = chunk
                const model = (event as any).reason?.caller?.model
                try {
                    const encoding = encoding_for_model(model)
                    if (encoding) {
                        let tokenCount = EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens') || 0
                        tokenCount += encoding.encode(extractText(delta)).length
                        EvaluationRunTracerLlama.tokenCounts.set(evalID + '_outputTokens', tokenCount)
                    }
                } catch (e) {
                    // ignore tokenizer errors (e.g. model unknown to tiktoken) and carry on
                }
            })
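            // retrieve/agent handlers only record start times here; the matching end events
            // compute the metrics in calculateAndSetMetrics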
            Settings.callbackManager.on('retrieve-start', (event: RetrievalStartEvent) => {
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (evalID) {
                    EvaluationRunTracerLlama.startTimes.set(evalID + '_retriever', event.timeStamp)
                }
            })
            Settings.callbackManager.on('retrieve-end', (event: RetrievalEndEvent) => {
                this.calculateAndSetMetrics(event, 'retriever')
            })
            Settings.callbackManager.on('agent-start', (event: AgentStartEvent) => {
                const evalID = (event as any).reason.parent?.caller?.evaluationRunId || (event as any).reason.caller?.evaluationRunId
                if (evalID) {
                    EvaluationRunTracerLlama.startTimes.set(evalID + '_agent', event.timeStamp)
                }
            })
            Settings.callbackManager.on('agent-end', (event: AgentEndEvent) => {
                this.calculateAndSetMetrics(event, 'agent')
            })
            EvaluationRunTracerLlama.cbInit = true
        }
    }
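
    // Resolve metrics for an end event in order of preference: provider-reported usage
    // (Anthropic/OpenAI-style `usage`, or Bedrock invocation metrics), otherwise the tiktoken
    // counts gathered during llm-start/llm-stream plus the elapsed time for `label`.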
    private static calculateAndSetMetrics(event: any, label: string) {
        const evalID = event.reason.parent?.caller?.evaluationRunId || event.reason.caller?.evaluationRunId
        if (!evalID) return
        const startTime = EvaluationRunTracerLlama.startTimes.get(evalID + '_' + label) as number
        let model =
            (event as any).reason?.caller?.model || (event as any).reason?.caller?.llm?.model || EvaluationRunTracerLlama.models.get(evalID)

        if (event.detail.payload?.response?.message && model) {
            try {
                const encoding = encoding_for_model(model)
                if (encoding) {
                    let tokenCount = EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens') || 0
                    tokenCount += encoding.encode(event.detail.payload.response?.message?.content || '').length
                    EvaluationRunTracerLlama.tokenCounts.set(evalID + '_outputTokens', tokenCount)
                }
            } catch (e) {
                // ignore tokenizer errors (e.g. model unknown to tiktoken) and carry on
            }
        }

        // Provider-reported usage: Anthropic-style (output_tokens/input_tokens) or OpenAI-style (completion_tokens/prompt_tokens)
        if (event.detail?.payload?.response?.raw?.usage) {
            const usage = event.detail.payload.response.raw.usage
            if (usage.output_tokens) {
                const metric = {
                    completionTokens: usage.output_tokens,
                    promptTokens: usage.input_tokens,
                    model: model,
                    totalTokens: usage.input_tokens + usage.output_tokens
                }
                EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
            } else if (usage.completion_tokens) {
                const metric = {
                    completionTokens: usage.completion_tokens,
                    promptTokens: usage.prompt_tokens,
                    model: model,
                    totalTokens: usage.total_tokens
                }
                EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
            }
        } else if (event.detail?.payload?.response?.raw['amazon-bedrock-invocationMetrics']) {
            // AWS Bedrock reports its own invocation metrics
            const usage = event.detail?.payload?.response?.raw['amazon-bedrock-invocationMetrics']
            const metric = {
                completionTokens: usage.outputTokenCount,
                promptTokens: usage.inputTokenCount,
                model: event.detail?.payload?.response?.raw.model,
                totalTokens: usage.inputTokenCount + usage.outputTokenCount
            }
            EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
        } else {
            // No provider usage: fall back to tiktoken counts and elapsed time
            const metric = {
                [label]: (event.timeStamp - startTime).toFixed(2),
                completionTokens: EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens'),
                promptTokens: EvaluationRunTracerLlama.tokenCounts.get(evalID + '_promptTokens'),
                model: model || EvaluationRunTracerLlama.models.get(evalID) || '',
                totalTokens:
                    (EvaluationRunTracerLlama.tokenCounts.get(evalID + '_outputTokens') || 0) +
                    (EvaluationRunTracerLlama.tokenCounts.get(evalID + '_promptTokens') || 0)
            }
            EvaluationRunner.addMetrics(evalID, JSON.stringify(metric))
        }

        // cleanup: drop the per-run entries (the token counters live in tokenCounts, not startTimes)
        EvaluationRunTracerLlama.startTimes.delete(evalID + '_' + label)
        EvaluationRunTracerLlama.tokenCounts.delete(evalID + '_outputTokens')
        EvaluationRunTracerLlama.tokenCounts.delete(evalID + '_promptTokens')
        EvaluationRunTracerLlama.models.delete(evalID)
    }
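
    // Stamp the LlamaIndex caller (LLM/agent) with the evaluationRunId so the handlers
    // registered above can attribute its events to this evaluation run.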
    static async injectEvaluationMetadata(nodeData: INodeData, options: ICommonObject, callerObj: any) {
        if (options.evaluationRunId && callerObj) {
            // these are needed for evaluation runs
            options.llamaIndex = true
            await additionalCallbacks(nodeData, options)
            Object.defineProperty(callerObj, 'evaluationRunId', {
                enumerable: true,
                configurable: true,
                writable: true,
                value: options.evaluationRunId
            })
        }
    }
}
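
// Usage sketch (hypothetical call site; the real wiring lives in the evaluation flow, and
// `llm` here stands for any LlamaIndex LLM/agent instance):
//   const tracer = new EvaluationRunTracerLlama(options.evaluationRunId)
//   await EvaluationRunTracerLlama.injectEvaluationMetadata(nodeData, options, llm)
// From then on, every llm/retrieve/agent event raised by `llm` carries the run id and is
// converted into a metrics entry via EvaluationRunner.addMetrics().
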
// from https://github.com/run-llama/LlamaIndexTS/blob/main/packages/core/src/llm/utils.ts
export function extractText(message: MessageContent): string {
    if (typeof message !== 'string' && !Array.isArray(message)) {
        console.warn('extractText called with non-MessageContent message, this is likely a bug.')
        return `${message}`
    } else if (typeof message !== 'string' && Array.isArray(message)) {
        // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
        // so we can pass them to the context generator
        return message
            .filter((c): c is MessageContentTextDetail => c.type === 'text')
            .map((c) => c.text)
            .join('\n\n')
    } else {
        return message
    }
}