mirror of https://github.com/ghostfolio/ghostfolio
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
8.7 KiB
255 lines
8.7 KiB
// Schema and format checks
|
|
import { Injectable } from '@nestjs/common';
|
|
|
|
import type {
|
|
OutputValidationCorrection,
|
|
OutputValidationIssue,
|
|
OutputValidationResult,
|
|
VerificationChecker,
|
|
VerificationContext
|
|
} from './verification.interfaces';
|
|
|
|
/**
 * Currency amounts: either a `$`-prefixed number ("$1,234.56") or a decimal
 * number followed by one of the listed ISO codes ("1234.56 USD").
 */
const CURRENCY_RE =
  /\$\s?[\d,]+(?:\.\d+)?|\b\d[\d,]*\.\d+\s?(?:USD|EUR|GBP|CHF|JPY|CAD|AUD)\b/g;

/** Everything that has to be removed from a currency match to leave only digits and the dot. */
const CURRENCY_STRIP_RE = /[$,\s]|USD|EUR|GBP|CHF|JPY|CAD|AUD/g;

/** Percentages such as "12.5%" (one optional whitespace before the sign). */
const PERCENT_RE = /\b\d+(?:\.\d+)?\s?%/g;

/** Ticker candidates: a standalone run of 1-5 uppercase ASCII letters. */
const TICKER_RE = /\b[A-Z]{1,5}\b/g;

/**
 * Uppercase words and acronyms that match TICKER_RE but must not be treated
 * as tickers. Kept as space-delimited rows to stay compact.
 */
// prettier-ignore
const NON_TICKERS = new Set<string>(
  [
    'A I AM AN AND ARE AS AT BE BY DO FOR GO HAS HE',
    'IF IN IS IT ME MY NO NOT OF ON OR OUR SO THE TO',
    'UP US WE ALL ANY BUT CAN DID FEW GOT HAD HER',
    'HIM HIS HOW ITS LET MAY NEW NOR NOW OFF OLD ONE',
    'OUT OWN PUT SAY SET SHE TOO USE WAS WHO WHY YET YOU',
    'ALSO BACK BEEN DOES EACH EVEN FROM GIVE INTO JUST KEEP',
    'LIKE LONG LOOK MADE MAKE MANY MORE MOST MUCH MUST NEXT',
    'ONLY OVER PART SAID SAME SOME SUCH TAKE THAN THAT THEM',
    'THEN THEY THIS VERY WANT WELL WERE WHAT WHEN WILL WITH',
    'YOUR TOTAL PER NET YTD ETF API CEO CFO IPO GDP ROI',
    'EPS NAV YOY QOQ MOM USD EUR GBP CHF JPY CAD AUD'
  ]
    .join(' ')
    .split(' ')
);

/** Bounds (in string length) outside of which a response is flagged. */
const MIN_LENGTH = 50;
const MAX_LENGTH = 5000;
|
|
|
|
/**
 * Verification stage that runs format and consistency checks over the agent's
 * response text and the tool calls that produced it.
 *
 * Every finding is reported as a warning-severity issue; currency and
 * percentage findings additionally carry a suggested corrected string.
 */
@Injectable()
export class OutputValidator implements VerificationChecker {
  public readonly stageName = 'outputValidator';

  /**
   * Runs all checks in a fixed order and aggregates their findings.
   *
   * @param context Supplies `agentResponseText` and `toolCalls`.
   * @returns `passed` is true only when no issue was recorded. The
   *   `corrections` property is present only when at least one correction
   *   was produced.
   */
  public validate(context: VerificationContext): OutputValidationResult {
    const issues: OutputValidationIssue[] = [];
    const corrections: OutputValidationCorrection[] = [];
    const { agentResponseText: text, toolCalls } = context;

    this.checkCurrencyFormatting(text, issues, corrections);
    this.checkPercentageFormatting(text, issues, corrections);
    this.checkSymbolReferences(text, toolCalls, issues);
    this.checkResponseLength(text, issues);
    this.checkDateFormatting(text, issues);
    this.checkResponseCompleteness(text, toolCalls, issues);

    return {
      passed: issues.length === 0,
      issues,
      ...(corrections.length > 0 ? { corrections } : {})
    };
  }

  /**
   * Flags currency amounts whose fractional part is not exactly two digits
   * and records a rounded replacement for each one.
   *
   * Amounts without a decimal point are accepted as-is.
   */
  private checkCurrencyFormatting(
    text: string,
    issues: OutputValidationIssue[],
    corrections: OutputValidationCorrection[]
  ): void {
    // Clone the shared global regex so `lastIndex` state stays local to this call.
    const re = new RegExp(CURRENCY_RE.source, CURRENCY_RE.flags);
    let match: RegExpExecArray | null;
    while ((match = re.exec(text)) !== null) {
      const raw = match[0];
      // Digits and dot only: symbol, commas, whitespace and ISO code removed.
      const numeric = raw.replace(CURRENCY_STRIP_RE, '');
      const dot = numeric.indexOf('.');
      if (dot === -1) continue; // whole amount, acceptable
      const decimals = numeric.length - dot - 1;
      if (decimals === 2) continue;

      issues.push({
        checkId: 'currency_format',
        description: `Currency amount "${raw}" should use exactly 2 decimal places`,
        severity: 'warning'
      });

      // Auto-correct: round to 2 decimal places.
      // The number must be located as it is written in `raw` (commas
      // included); searching for the comma-stripped `numeric` would never
      // match amounts such as "$1,234.567" and the correction would be a no-op.
      const token = /[\d,]+\.\d+/.exec(raw);
      const writtenNumber = token !== null ? token[0] : numeric;
      const corrected = raw.replace(
        writtenNumber,
        this.formatRoundedAmount(
          parseFloat(numeric),
          writtenNumber.includes(',')
        )
      );
      corrections.push({
        original: raw,
        corrected,
        checkId: 'currency_format'
      });
    }
  }

  /**
   * Renders a non-negative amount with exactly two decimals.
   *
   * @param value Amount to render.
   * @param grouped When true, the integer part gets `,` thousands separators.
   */
  private formatRoundedAmount(value: number, grouped: boolean): string {
    const fixed = value.toFixed(2);
    if (!grouped) return fixed;
    const dot = fixed.indexOf('.');
    const whole = fixed.slice(0, dot).replace(/\B(?=(\d{3})+$)/g, ',');
    return whole + fixed.slice(dot);
  }

  /**
   * Flags percentages whose fractional part is not one or two digits and
   * records a replacement rounded to two decimals.
   *
   * Percentages without a decimal point are accepted as-is.
   */
  private checkPercentageFormatting(
    text: string,
    issues: OutputValidationIssue[],
    corrections: OutputValidationCorrection[]
  ): void {
    // Clone the shared global regex so `lastIndex` state stays local to this call.
    const re = new RegExp(PERCENT_RE.source, PERCENT_RE.flags);
    let match: RegExpExecArray | null;
    while ((match = re.exec(text)) !== null) {
      const raw = match[0];
      const numeric = raw.replace(/[%\s]/g, '');
      const dot = numeric.indexOf('.');
      if (dot === -1) continue; // whole percentage, acceptable
      const decimals = numeric.length - dot - 1;
      if (decimals >= 1 && decimals <= 2) continue;

      issues.push({
        checkId: 'percentage_format',
        description: `Percentage "${raw}" should use 1-2 decimal places`,
        severity: 'warning'
      });

      // Auto-correct: round to 2 decimal places. `numeric` occurs verbatim in
      // `raw` here because only the sign and whitespace were stripped.
      const correctedNumeric = parseFloat(numeric).toFixed(2);
      corrections.push({
        original: raw,
        corrected: raw.replace(numeric, correctedNumeric),
        checkId: 'percentage_format'
      });
    }
  }

  /**
   * Flags ticker-like tokens in the response that do not appear anywhere in
   * the output of successful tool calls.
   *
   * The check is skipped entirely when the tool output yields no symbols.
   * Each unknown token is reported once.
   */
  private checkSymbolReferences(
    text: string,
    toolCalls: VerificationContext['toolCalls'],
    issues: OutputValidationIssue[]
  ): void {
    const known = new Set<string>();
    for (const call of toolCalls) {
      if (call.success && call.outputData != null) {
        this.extractSymbols(call.outputData, known);
      }
    }
    if (known.size === 0) return;

    const unknown = new Set<string>();
    const tickerRe = new RegExp(TICKER_RE.source, TICKER_RE.flags);
    let tickerMatch: RegExpExecArray | null;
    while ((tickerMatch = tickerRe.exec(text)) !== null) {
      const candidate = tickerMatch[0];
      // Single letters are too ambiguous to be treated as tickers.
      if (
        candidate.length >= 2 &&
        !NON_TICKERS.has(candidate) &&
        !known.has(candidate)
      ) {
        unknown.add(candidate);
      }
    }

    unknown.forEach((sym) => {
      issues.push({
        checkId: 'symbol_reference',
        description: `Symbol "${sym}" referenced in response was not found in tool result data`,
        severity: 'warning'
      });
    });
  }

  /**
   * Recursively collects symbols from arbitrary tool output into `out`.
   *
   * - Strings are scanned for ticker-like tokens, skipping NON_TICKERS.
   * - Arrays are walked element by element.
   * - Objects contribute the string values of their `symbol`, `ticker`,
   *   `code` and `name` keys verbatim, then every value is walked as well.
   */
  private extractSymbols(data: unknown, out: Set<string>): void {
    if (data == null) return;

    if (typeof data === 'string') {
      const re = new RegExp(TICKER_RE.source, TICKER_RE.flags);
      let m: RegExpExecArray | null;
      while ((m = re.exec(data)) !== null) {
        if (!NON_TICKERS.has(m[0])) out.add(m[0]);
      }
      return;
    }

    if (Array.isArray(data)) {
      for (const item of data) this.extractSymbols(item, out);
      return;
    }

    if (typeof data === 'object') {
      const obj = data as Record<string, unknown>;
      for (const key of ['symbol', 'ticker', 'code', 'name']) {
        const value = obj[key];
        if (typeof value === 'string') out.add(value);
      }
      for (const val of Object.values(obj)) this.extractSymbols(val, out);
    }
  }

  /**
   * Flags responses whose string length is below MIN_LENGTH or above
   * MAX_LENGTH.
   */
  private checkResponseLength(
    text: string,
    issues: OutputValidationIssue[]
  ): void {
    const len = text.length;
    if (len < MIN_LENGTH) {
      issues.push({
        checkId: 'response_length',
        description: `Response length (${len} chars) is below minimum of ${MIN_LENGTH} characters`,
        severity: 'warning'
      });
    } else if (len > MAX_LENGTH) {
      issues.push({
        checkId: 'response_length',
        description: `Response length (${len} chars) exceeds maximum of ${MAX_LENGTH} characters`,
        severity: 'warning'
      });
    }
  }

  /**
   * Flags responses that contain both an ISO-style date (YYYY-MM-DD) and a
   * slash-separated date.
   */
  private checkDateFormatting(
    text: string,
    issues: OutputValidationIssue[]
  ): void {
    // Only presence is tested, so the regexes need no global flag.
    const hasIso = /\b\d{4}-\d{2}-\d{2}\b/.test(text);
    const hasSlash = /\b\d{1,2}\/\d{1,2}\/\d{2,4}\b/.test(text);
    if (hasIso && hasSlash) {
      issues.push({
        checkId: 'date_format',
        description:
          'Mixed date formats detected (ISO and slash-separated). Use consistent formatting.',
        severity: 'warning'
      });
    }
  }

  /**
   * Flags responses that mention only a minority of the items returned by a
   * tool call.
   *
   * Applies to successful calls whose output is an array of 1 to 20 items.
   * Each item is identified by its string `symbol`, falling back to `name`.
   * A warning is raised when at least one but fewer than half of those
   * identifiers occur in the response text.
   */
  private checkResponseCompleteness(
    text: string,
    toolCalls: VerificationContext['toolCalls'],
    issues: OutputValidationIssue[]
  ): void {
    for (const call of toolCalls) {
      if (!call.success || call.outputData == null) continue;
      if (!Array.isArray(call.outputData)) continue;

      const holdingCount = call.outputData.length;
      if (holdingCount === 0 || holdingCount > 20) continue;

      const symbols = new Set<string>();
      for (const item of call.outputData) {
        if (item && typeof item === 'object') {
          const record = item as Record<string, unknown>;
          const sym = record.symbol ?? record.name;
          if (typeof sym === 'string') symbols.add(sym);
        }
      }
      if (symbols.size === 0) continue;

      // Count how many items from the array are referenced
      let referenced = 0;
      symbols.forEach((sym) => {
        if (text.includes(sym)) referenced++;
      });

      if (referenced > 0 && referenced < symbols.size * 0.5) {
        issues.push({
          checkId: 'response_completeness',
          description: `Response references ${referenced} of ${symbols.size} items from tool results. Some items may be missing.`,
          severity: 'warning'
        });
      }
    }
  }
}
|
|
|