mirror of https://github.com/ghostfolio/ghostfolio
19 changed files with 2849 additions and 1 deletions
@ -0,0 +1,488 @@ |
|||
import { PortfolioService } from '@ghostfolio/api/app/portfolio/portfolio.service'; |
|||
import { RedisCacheService } from '@ghostfolio/api/app/redis-cache/redis-cache.service'; |
|||
import { DataProviderService } from '@ghostfolio/api/services/data-provider/data-provider.service'; |
|||
|
|||
import { DataSource } from '@prisma/client'; |
|||
import ms from 'ms'; |
|||
|
|||
import { |
|||
AiAgentToolCall, |
|||
AiAgentVerificationCheck |
|||
} from './ai-agent.interfaces'; |
|||
import { |
|||
AiAgentMemoryState, |
|||
MarketDataLookupResult, |
|||
PortfolioAnalysisResult, |
|||
RebalancePlanResult, |
|||
RiskAssessmentResult, |
|||
StressTestResult |
|||
} from './ai-agent.chat.interfaces'; |
|||
import { extractSymbolsFromQuery } from './ai-agent.utils'; |
|||
|
|||
const AI_AGENT_MEMORY_TTL = ms('24 hours'); |
|||
|
|||
export const AI_AGENT_MEMORY_MAX_TURNS = 10; |
|||
|
|||
export function addVerificationChecks({ |
|||
marketData, |
|||
portfolioAnalysis, |
|||
rebalancePlan, |
|||
stressTest, |
|||
toolCalls, |
|||
verification |
|||
}: { |
|||
marketData?: MarketDataLookupResult; |
|||
portfolioAnalysis?: PortfolioAnalysisResult; |
|||
rebalancePlan?: RebalancePlanResult; |
|||
stressTest?: StressTestResult; |
|||
toolCalls: AiAgentToolCall[]; |
|||
verification: AiAgentVerificationCheck[]; |
|||
}) { |
|||
if (portfolioAnalysis) { |
|||
const allocationDifference = Math.abs(portfolioAnalysis.allocationSum - 1); |
|||
|
|||
verification.push({ |
|||
check: 'numerical_consistency', |
|||
details: |
|||
allocationDifference <= 0.05 |
|||
? `Allocation sum difference is ${allocationDifference.toFixed(4)}` |
|||
: `Allocation sum difference is ${allocationDifference.toFixed(4)} (can happen with liabilities or leveraged exposure)`, |
|||
status: allocationDifference <= 0.05 ? 'passed' : 'warning' |
|||
}); |
|||
} else { |
|||
verification.push({ |
|||
check: 'numerical_consistency', |
|||
details: 'Portfolio tool did not run', |
|||
status: 'warning' |
|||
}); |
|||
} |
|||
|
|||
if (marketData) { |
|||
const unresolvedSymbols = marketData.symbolsRequested.length - |
|||
marketData.quotes.length; |
|||
|
|||
verification.push({ |
|||
check: 'market_data_coverage', |
|||
details: |
|||
unresolvedSymbols > 0 |
|||
? `${unresolvedSymbols} symbols did not resolve with quote data` |
|||
: 'All requested symbols resolved with quote data', |
|||
status: |
|||
unresolvedSymbols === 0 |
|||
? 'passed' |
|||
: marketData.quotes.length > 0 |
|||
? 'warning' |
|||
: 'failed' |
|||
}); |
|||
} |
|||
|
|||
if (rebalancePlan) { |
|||
verification.push({ |
|||
check: 'rebalance_coverage', |
|||
details: |
|||
rebalancePlan.overweightHoldings.length > 0 || |
|||
rebalancePlan.underweightHoldings.length > 0 |
|||
? `Rebalance plan found ${rebalancePlan.overweightHoldings.length} overweight and ${rebalancePlan.underweightHoldings.length} underweight holdings` |
|||
: 'No rebalance action identified from current holdings', |
|||
status: |
|||
rebalancePlan.overweightHoldings.length > 0 || |
|||
rebalancePlan.underweightHoldings.length > 0 |
|||
? 'passed' |
|||
: 'warning' |
|||
}); |
|||
} |
|||
|
|||
if (stressTest) { |
|||
verification.push({ |
|||
check: 'stress_test_coherence', |
|||
details: `Shock ${(stressTest.shockPercentage * 100).toFixed(1)}% implies drawdown ${stressTest.estimatedDrawdownInBaseCurrency.toFixed(2)}`, |
|||
status: |
|||
stressTest.estimatedDrawdownInBaseCurrency >= 0 && |
|||
stressTest.estimatedPortfolioValueAfterShock >= 0 |
|||
? 'passed' |
|||
: 'failed' |
|||
}); |
|||
} |
|||
|
|||
verification.push({ |
|||
check: 'tool_execution', |
|||
details: `${toolCalls.filter(({ status }) => {
|
|||
return status === 'success'; |
|||
}).length}/${toolCalls.length} tools executed successfully`,
|
|||
status: toolCalls.every(({ status }) => status === 'success') |
|||
? 'passed' |
|||
: 'warning' |
|||
}); |
|||
} |
|||
|
|||
export async function buildAnswer({ |
|||
generateText, |
|||
languageCode, |
|||
marketData, |
|||
memory, |
|||
portfolioAnalysis, |
|||
query, |
|||
rebalancePlan, |
|||
riskAssessment, |
|||
stressTest, |
|||
userCurrency |
|||
}: { |
|||
generateText: ({ prompt }: { prompt: string }) => Promise<{ text?: string }>; |
|||
languageCode: string; |
|||
marketData?: MarketDataLookupResult; |
|||
memory: AiAgentMemoryState; |
|||
portfolioAnalysis?: PortfolioAnalysisResult; |
|||
query: string; |
|||
rebalancePlan?: RebalancePlanResult; |
|||
riskAssessment?: RiskAssessmentResult; |
|||
stressTest?: StressTestResult; |
|||
userCurrency: string; |
|||
}) { |
|||
const fallbackSections: string[] = []; |
|||
const normalizedQuery = query.toLowerCase(); |
|||
const hasInvestmentIntent = [ |
|||
'add', |
|||
'allocat', |
|||
'buy', |
|||
'invest', |
|||
'next', |
|||
'rebalanc', |
|||
'sell', |
|||
'trim' |
|||
].some((keyword) => { |
|||
return normalizedQuery.includes(keyword); |
|||
}); |
|||
|
|||
if (memory.turns.length > 0) { |
|||
fallbackSections.push( |
|||
`Session memory applied from ${memory.turns.length} prior turn(s).` |
|||
); |
|||
} |
|||
|
|||
if (riskAssessment) { |
|||
fallbackSections.push( |
|||
`Risk concentration is ${riskAssessment.concentrationBand}. Top holding allocation is ${(riskAssessment.topHoldingAllocation * 100).toFixed(2)}% with HHI ${riskAssessment.hhi.toFixed(3)}.` |
|||
); |
|||
} |
|||
|
|||
if (rebalancePlan) { |
|||
if (rebalancePlan.overweightHoldings.length > 0) { |
|||
const topOverweight = rebalancePlan.overweightHoldings |
|||
.slice(0, 2) |
|||
.map(({ reductionNeeded, symbol }) => { |
|||
return `${symbol} trim ${(reductionNeeded * 100).toFixed(1)}pp`; |
|||
}) |
|||
.join(', '); |
|||
|
|||
fallbackSections.push(`Rebalance priority: ${topOverweight}.`); |
|||
} else { |
|||
fallbackSections.push( |
|||
'Rebalance check: no holding exceeds the current max-allocation target.' |
|||
); |
|||
} |
|||
} |
|||
|
|||
if (stressTest) { |
|||
fallbackSections.push( |
|||
`Stress test (${(stressTest.shockPercentage * 100).toFixed(0)}% downside): estimated drawdown ${stressTest.estimatedDrawdownInBaseCurrency.toFixed(2)} ${userCurrency}, projected value ${stressTest.estimatedPortfolioValueAfterShock.toFixed(2)} ${userCurrency}.` |
|||
); |
|||
} |
|||
|
|||
if (portfolioAnalysis?.holdings?.length > 0) { |
|||
const longHoldings = portfolioAnalysis.holdings |
|||
.filter(({ valueInBaseCurrency }) => { |
|||
return valueInBaseCurrency > 0; |
|||
}) |
|||
.sort((a, b) => { |
|||
return b.valueInBaseCurrency - a.valueInBaseCurrency; |
|||
}); |
|||
const totalLongValue = longHoldings.reduce((sum, { valueInBaseCurrency }) => { |
|||
return sum + valueInBaseCurrency; |
|||
}, 0); |
|||
|
|||
if (totalLongValue > 0) { |
|||
const topLongHoldingsSummary = longHoldings |
|||
.slice(0, 3) |
|||
.map(({ symbol, valueInBaseCurrency }) => { |
|||
return `${symbol} ${((valueInBaseCurrency / totalLongValue) * 100).toFixed(1)}%`; |
|||
}) |
|||
.join(', '); |
|||
|
|||
fallbackSections.push(`Largest long allocations: ${topLongHoldingsSummary}.`); |
|||
|
|||
if (hasInvestmentIntent) { |
|||
const topLongShare = longHoldings[0].valueInBaseCurrency / totalLongValue; |
|||
|
|||
if (topLongShare >= 0.35) { |
|||
fallbackSections.push( |
|||
'Next-step allocation: direct new capital to positions outside your top holding until concentration falls below 35%.' |
|||
); |
|||
} else { |
|||
fallbackSections.push( |
|||
'Next-step allocation: spread new capital across your smallest high-conviction positions to preserve diversification.' |
|||
); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
if (marketData?.quotes?.length > 0) { |
|||
const quoteSummary = marketData.quotes |
|||
.slice(0, 3) |
|||
.map(({ currency, marketPrice, symbol }) => { |
|||
return `${symbol}: ${marketPrice.toFixed(2)} ${currency}`; |
|||
}) |
|||
.join(', '); |
|||
|
|||
fallbackSections.push(`Market snapshot: ${quoteSummary}.`); |
|||
} else if (marketData?.symbolsRequested?.length > 0) { |
|||
fallbackSections.push( |
|||
`Market data request completed with limited quote coverage for: ${marketData.symbolsRequested.join(', ')}.` |
|||
); |
|||
} |
|||
|
|||
if (fallbackSections.length === 0) { |
|||
fallbackSections.push( |
|||
`Portfolio context is available. Ask about holdings, risk concentration, or symbol prices for deeper analysis.` |
|||
); |
|||
} |
|||
|
|||
const fallbackAnswer = fallbackSections.join('\n'); |
|||
const llmPrompt = [ |
|||
`You are a neutral financial assistant.`, |
|||
`User currency: ${userCurrency}`, |
|||
`Language code: ${languageCode}`, |
|||
`Query: ${query}`, |
|||
`Context summary:`, |
|||
fallbackAnswer, |
|||
`Write a concise response with actionable insight and avoid speculation.` |
|||
].join('\n'); |
|||
|
|||
try { |
|||
const generated = await generateText({ |
|||
prompt: llmPrompt |
|||
}); |
|||
|
|||
if (generated?.text?.trim()) { |
|||
return generated.text.trim(); |
|||
} |
|||
} catch {} |
|||
|
|||
return fallbackAnswer; |
|||
} |
|||
|
|||
export async function getMemory({ |
|||
redisCacheService, |
|||
sessionId, |
|||
userId |
|||
}: { |
|||
redisCacheService: RedisCacheService; |
|||
sessionId: string; |
|||
userId: string; |
|||
}): Promise<AiAgentMemoryState> { |
|||
const rawMemory = await redisCacheService.get( |
|||
getMemoryKey({ sessionId, userId }) |
|||
); |
|||
|
|||
if (!rawMemory) { |
|||
return { |
|||
turns: [] |
|||
}; |
|||
} |
|||
|
|||
try { |
|||
const parsed = JSON.parse(rawMemory) as AiAgentMemoryState; |
|||
|
|||
if (!Array.isArray(parsed?.turns)) { |
|||
return { |
|||
turns: [] |
|||
}; |
|||
} |
|||
|
|||
return parsed; |
|||
} catch { |
|||
return { |
|||
turns: [] |
|||
}; |
|||
} |
|||
} |
|||
|
|||
export function getMemoryKey({ |
|||
sessionId, |
|||
userId |
|||
}: { |
|||
sessionId: string; |
|||
userId: string; |
|||
}) { |
|||
return `ai-agent-memory-${userId}-${sessionId}`; |
|||
} |
|||
|
|||
export function resolveSymbols({ |
|||
portfolioAnalysis, |
|||
query, |
|||
symbols |
|||
}: { |
|||
portfolioAnalysis?: PortfolioAnalysisResult; |
|||
query: string; |
|||
symbols?: string[]; |
|||
}) { |
|||
const explicitSymbols = |
|||
symbols?.map((symbol) => symbol.trim().toUpperCase()).filter(Boolean) ?? []; |
|||
const extractedSymbols = extractSymbolsFromQuery(query); |
|||
|
|||
const derivedSymbols = |
|||
portfolioAnalysis?.holdings.slice(0, 3).map(({ symbol }) => symbol) ?? []; |
|||
|
|||
return Array.from( |
|||
new Set([...explicitSymbols, ...extractedSymbols, ...derivedSymbols]) |
|||
); |
|||
} |
|||
|
|||
export async function runMarketDataLookup({ |
|||
dataProviderService, |
|||
portfolioAnalysis, |
|||
symbols |
|||
}: { |
|||
dataProviderService: DataProviderService; |
|||
portfolioAnalysis?: PortfolioAnalysisResult; |
|||
symbols: string[]; |
|||
}): Promise<MarketDataLookupResult> { |
|||
const holdingsMap = new Map( |
|||
(portfolioAnalysis?.holdings ?? []).map((holding) => { |
|||
return [holding.symbol, holding]; |
|||
}) |
|||
); |
|||
|
|||
const quoteIdentifiers = symbols.map((symbol) => { |
|||
const knownHolding = holdingsMap.get(symbol); |
|||
|
|||
return { |
|||
dataSource: knownHolding?.dataSource ?? DataSource.YAHOO, |
|||
symbol |
|||
}; |
|||
}); |
|||
|
|||
const quotesBySymbol = |
|||
quoteIdentifiers.length > 0 |
|||
? await dataProviderService.getQuotes({ |
|||
items: quoteIdentifiers |
|||
}) |
|||
: {}; |
|||
|
|||
return { |
|||
quotes: symbols |
|||
.filter((symbol) => Boolean(quotesBySymbol[symbol])) |
|||
.map((symbol) => { |
|||
return { |
|||
currency: quotesBySymbol[symbol].currency, |
|||
marketPrice: quotesBySymbol[symbol].marketPrice, |
|||
marketState: quotesBySymbol[symbol].marketState, |
|||
symbol |
|||
}; |
|||
}), |
|||
symbolsRequested: symbols |
|||
}; |
|||
} |
|||
|
|||
export async function runPortfolioAnalysis({ |
|||
portfolioService, |
|||
userId |
|||
}: { |
|||
portfolioService: PortfolioService; |
|||
userId: string; |
|||
}): Promise<PortfolioAnalysisResult> { |
|||
const { holdings } = await portfolioService.getDetails({ |
|||
impersonationId: undefined, |
|||
userId |
|||
}); |
|||
const normalizedHoldings = Object.values(holdings) |
|||
.map((holding) => { |
|||
return { |
|||
allocationInPercentage: holding.allocationInPercentage ?? 0, |
|||
dataSource: holding.dataSource, |
|||
symbol: holding.symbol, |
|||
valueInBaseCurrency: holding.valueInBaseCurrency ?? 0 |
|||
}; |
|||
}) |
|||
.sort((a, b) => { |
|||
return b.valueInBaseCurrency - a.valueInBaseCurrency; |
|||
}); |
|||
|
|||
const totalValueInBaseCurrency = normalizedHoldings.reduce( |
|||
(totalValue, holding) => { |
|||
return totalValue + holding.valueInBaseCurrency; |
|||
}, |
|||
0 |
|||
); |
|||
const allocationSum = normalizedHoldings.reduce((sum, holding) => { |
|||
return sum + holding.allocationInPercentage; |
|||
}, 0); |
|||
|
|||
return { |
|||
allocationSum, |
|||
holdings: normalizedHoldings, |
|||
holdingsCount: normalizedHoldings.length, |
|||
totalValueInBaseCurrency |
|||
}; |
|||
} |
|||
|
|||
export function runRiskAssessment({ |
|||
portfolioAnalysis |
|||
}: { |
|||
portfolioAnalysis: PortfolioAnalysisResult; |
|||
}): RiskAssessmentResult { |
|||
const longExposureValues = portfolioAnalysis.holdings |
|||
.map(({ valueInBaseCurrency }) => { |
|||
return Math.max(valueInBaseCurrency, 0); |
|||
}) |
|||
.filter((value) => value > 0); |
|||
const totalLongExposure = longExposureValues.reduce((sum, value) => { |
|||
return sum + value; |
|||
}, 0); |
|||
const allocations = |
|||
totalLongExposure > 0 |
|||
? longExposureValues.map((value) => { |
|||
return value / totalLongExposure; |
|||
}) |
|||
: portfolioAnalysis.holdings |
|||
.map(({ allocationInPercentage }) => { |
|||
return Math.max(allocationInPercentage, 0); |
|||
}) |
|||
.filter((value) => value > 0); |
|||
const topHoldingAllocation = allocations.length > 0 ? Math.max(...allocations) : 0; |
|||
const hhi = allocations.reduce((sum, allocation) => { |
|||
return sum + allocation * allocation; |
|||
}, 0); |
|||
|
|||
let concentrationBand: RiskAssessmentResult['concentrationBand'] = 'low'; |
|||
|
|||
if (topHoldingAllocation >= 0.35 || hhi >= 0.25) { |
|||
concentrationBand = 'high'; |
|||
} else if (topHoldingAllocation >= 0.2 || hhi >= 0.15) { |
|||
concentrationBand = 'medium'; |
|||
} |
|||
|
|||
return { |
|||
concentrationBand, |
|||
hhi, |
|||
topHoldingAllocation |
|||
}; |
|||
} |
|||
|
|||
export async function setMemory({ |
|||
memory, |
|||
redisCacheService, |
|||
sessionId, |
|||
userId |
|||
}: { |
|||
memory: AiAgentMemoryState; |
|||
redisCacheService: RedisCacheService; |
|||
sessionId: string; |
|||
userId: string; |
|||
}) { |
|||
await redisCacheService.set( |
|||
getMemoryKey({ sessionId, userId }), |
|||
JSON.stringify(memory), |
|||
AI_AGENT_MEMORY_TTL |
|||
); |
|||
} |
|||
@ -0,0 +1,60 @@ |
|||
import { DataSource } from '@prisma/client'; |
|||
|
|||
import { AiAgentToolCall } from './ai-agent.interfaces'; |
|||
|
|||
export interface AiAgentMemoryState { |
|||
turns: { |
|||
answer: string; |
|||
query: string; |
|||
timestamp: string; |
|||
toolCalls: Pick<AiAgentToolCall, 'status' | 'tool'>[]; |
|||
}[]; |
|||
} |
|||
|
|||
export interface PortfolioAnalysisResult { |
|||
allocationSum: number; |
|||
holdings: { |
|||
allocationInPercentage: number; |
|||
dataSource: DataSource; |
|||
symbol: string; |
|||
valueInBaseCurrency: number; |
|||
}[]; |
|||
holdingsCount: number; |
|||
totalValueInBaseCurrency: number; |
|||
} |
|||
|
|||
export interface RiskAssessmentResult { |
|||
concentrationBand: 'high' | 'medium' | 'low'; |
|||
hhi: number; |
|||
topHoldingAllocation: number; |
|||
} |
|||
|
|||
export interface MarketDataLookupResult { |
|||
quotes: { |
|||
currency: string; |
|||
marketPrice: number; |
|||
marketState: string; |
|||
symbol: string; |
|||
}[]; |
|||
symbolsRequested: string[]; |
|||
} |
|||
|
|||
export interface RebalancePlanResult { |
|||
maxAllocationTarget: number; |
|||
overweightHoldings: { |
|||
currentAllocation: number; |
|||
reductionNeeded: number; |
|||
symbol: string; |
|||
}[]; |
|||
underweightHoldings: { |
|||
currentAllocation: number; |
|||
symbol: string; |
|||
}[]; |
|||
} |
|||
|
|||
export interface StressTestResult { |
|||
estimatedDrawdownInBaseCurrency: number; |
|||
estimatedPortfolioValueAfterShock: number; |
|||
longExposureInBaseCurrency: number; |
|||
shockPercentage: number; |
|||
} |
|||
@ -0,0 +1,46 @@ |
|||
export type AiAgentToolName = |
|||
| 'portfolio_analysis' |
|||
| 'risk_assessment' |
|||
| 'market_data_lookup' |
|||
| 'rebalance_plan' |
|||
| 'stress_test'; |
|||
|
|||
export type AiAgentConfidenceBand = 'high' | 'medium' | 'low'; |
|||
|
|||
export interface AiAgentCitation { |
|||
confidence: number; |
|||
snippet: string; |
|||
source: AiAgentToolName; |
|||
} |
|||
|
|||
export interface AiAgentConfidence { |
|||
band: AiAgentConfidenceBand; |
|||
score: number; |
|||
} |
|||
|
|||
export interface AiAgentVerificationCheck { |
|||
check: string; |
|||
details: string; |
|||
status: 'passed' | 'warning' | 'failed'; |
|||
} |
|||
|
|||
export interface AiAgentToolCall { |
|||
input: Record<string, unknown>; |
|||
outputSummary: string; |
|||
status: 'success' | 'failed'; |
|||
tool: AiAgentToolName; |
|||
} |
|||
|
|||
export interface AiAgentMemorySnapshot { |
|||
sessionId: string; |
|||
turns: number; |
|||
} |
|||
|
|||
export interface AiAgentChatResponse { |
|||
answer: string; |
|||
citations: AiAgentCitation[]; |
|||
confidence: AiAgentConfidence; |
|||
memory: AiAgentMemorySnapshot; |
|||
toolCalls: AiAgentToolCall[]; |
|||
verification: AiAgentVerificationCheck[]; |
|||
} |
|||
@ -0,0 +1,84 @@ |
|||
import { |
|||
PortfolioAnalysisResult, |
|||
RebalancePlanResult, |
|||
StressTestResult |
|||
} from './ai-agent.chat.interfaces'; |
|||
|
|||
export function runRebalancePlan({ |
|||
maxAllocationTarget = 0.35, |
|||
portfolioAnalysis |
|||
}: { |
|||
maxAllocationTarget?: number; |
|||
portfolioAnalysis: PortfolioAnalysisResult; |
|||
}): RebalancePlanResult { |
|||
const longExposure = portfolioAnalysis.holdings |
|||
.filter(({ valueInBaseCurrency }) => { |
|||
return valueInBaseCurrency > 0; |
|||
}) |
|||
.sort((a, b) => { |
|||
return b.valueInBaseCurrency - a.valueInBaseCurrency; |
|||
}); |
|||
const totalLongExposure = longExposure.reduce((sum, { valueInBaseCurrency }) => { |
|||
return sum + valueInBaseCurrency; |
|||
}, 0); |
|||
|
|||
if (totalLongExposure === 0) { |
|||
return { |
|||
maxAllocationTarget, |
|||
overweightHoldings: [], |
|||
underweightHoldings: [] |
|||
}; |
|||
} |
|||
|
|||
const withLongAllocation = longExposure.map(({ symbol, valueInBaseCurrency }) => { |
|||
return { |
|||
currentAllocation: valueInBaseCurrency / totalLongExposure, |
|||
symbol |
|||
}; |
|||
}); |
|||
|
|||
return { |
|||
maxAllocationTarget, |
|||
overweightHoldings: withLongAllocation |
|||
.filter(({ currentAllocation }) => { |
|||
return currentAllocation > maxAllocationTarget; |
|||
}) |
|||
.map(({ currentAllocation, symbol }) => { |
|||
return { |
|||
currentAllocation, |
|||
reductionNeeded: currentAllocation - maxAllocationTarget, |
|||
symbol |
|||
}; |
|||
}), |
|||
underweightHoldings: withLongAllocation |
|||
.filter(({ currentAllocation }) => { |
|||
return currentAllocation < maxAllocationTarget * 0.5; |
|||
}) |
|||
.slice(-3) |
|||
}; |
|||
} |
|||
|
|||
export function runStressTest({ |
|||
portfolioAnalysis, |
|||
shockPercentage = 0.1 |
|||
}: { |
|||
portfolioAnalysis: PortfolioAnalysisResult; |
|||
shockPercentage?: number; |
|||
}): StressTestResult { |
|||
const boundedShock = Math.min(Math.max(shockPercentage, 0), 0.8); |
|||
const longExposureInBaseCurrency = portfolioAnalysis.holdings.reduce( |
|||
(sum, { valueInBaseCurrency }) => { |
|||
return sum + Math.max(valueInBaseCurrency, 0); |
|||
}, |
|||
0 |
|||
); |
|||
const estimatedDrawdownInBaseCurrency = longExposureInBaseCurrency * boundedShock; |
|||
|
|||
return { |
|||
estimatedDrawdownInBaseCurrency, |
|||
estimatedPortfolioValueAfterShock: |
|||
portfolioAnalysis.totalValueInBaseCurrency - estimatedDrawdownInBaseCurrency, |
|||
longExposureInBaseCurrency, |
|||
shockPercentage: boundedShock |
|||
}; |
|||
} |
|||
@ -0,0 +1,201 @@ |
|||
import { |
|||
calculateConfidence, |
|||
determineToolPlan, |
|||
extractSymbolsFromQuery |
|||
} from './ai-agent.utils'; |
|||
|
|||
describe('AiAgentUtils', () => { |
|||
it('extracts and deduplicates symbols from query', () => { |
|||
expect(extractSymbolsFromQuery('Check AAPL and TSLA then AAPL')).toEqual([ |
|||
'AAPL', |
|||
'TSLA' |
|||
]); |
|||
}); |
|||
|
|||
it('ignores common uppercase stop words while keeping ticker symbols', () => { |
|||
expect( |
|||
extractSymbolsFromQuery('WHAT IS THE PRICE OF NVDA AND TSLA') |
|||
).toEqual(['NVDA', 'TSLA']); |
|||
}); |
|||
|
|||
it('supports dollar-prefixed lowercase or mixed-case symbol input', () => { |
|||
expect(extractSymbolsFromQuery('Check $nvda and $TsLa')).toEqual([ |
|||
'NVDA', |
|||
'TSLA' |
|||
]); |
|||
}); |
|||
|
|||
it('selects portfolio and risk tools for risk query', () => { |
|||
expect( |
|||
determineToolPlan({ |
|||
query: 'Analyze portfolio concentration risk' |
|||
}) |
|||
).toEqual(['portfolio_analysis', 'risk_assessment']); |
|||
}); |
|||
|
|||
it('selects market tool for quote query', () => { |
|||
expect( |
|||
determineToolPlan({ |
|||
query: 'What is the price for NVDA?', |
|||
symbols: ['NVDA'] |
|||
}) |
|||
).toEqual(['market_data_lookup']); |
|||
}); |
|||
|
|||
it('falls back to portfolio tool when no clear tool keyword exists', () => { |
|||
expect( |
|||
determineToolPlan({ |
|||
query: 'Help me with my account' |
|||
}) |
|||
).toEqual(['portfolio_analysis', 'risk_assessment']); |
|||
}); |
|||
|
|||
it('selects risk reasoning for investment intent queries', () => { |
|||
expect( |
|||
determineToolPlan({ |
|||
query: 'Where should I invest next?' |
|||
}) |
|||
).toEqual(['portfolio_analysis', 'risk_assessment', 'rebalance_plan']); |
|||
}); |
|||
|
|||
it('selects rebalance tool for rebalance-focused prompts', () => { |
|||
expect( |
|||
determineToolPlan({ |
|||
query: 'How should I rebalance overweight positions?' |
|||
}) |
|||
).toEqual(['portfolio_analysis', 'risk_assessment', 'rebalance_plan']); |
|||
}); |
|||
|
|||
it('selects stress test tool for crash scenario prompts', () => { |
|||
expect( |
|||
determineToolPlan({ |
|||
query: 'Run a drawdown stress test on my portfolio' |
|||
}) |
|||
).toEqual(['portfolio_analysis', 'risk_assessment', 'stress_test']); |
|||
}); |
|||
|
|||
it('calculates bounded confidence score and band', () => { |
|||
const confidence = calculateConfidence({ |
|||
toolCalls: [ |
|||
{ |
|||
input: {}, |
|||
outputSummary: 'ok', |
|||
status: 'success', |
|||
tool: 'portfolio_analysis' |
|||
}, |
|||
{ |
|||
input: {}, |
|||
outputSummary: 'ok', |
|||
status: 'success', |
|||
tool: 'risk_assessment' |
|||
}, |
|||
{ |
|||
input: {}, |
|||
outputSummary: 'failed', |
|||
status: 'failed', |
|||
tool: 'market_data_lookup' |
|||
} |
|||
], |
|||
verification: [ |
|||
{ |
|||
check: 'numerical_consistency', |
|||
details: 'ok', |
|||
status: 'passed' |
|||
}, |
|||
{ |
|||
check: 'tool_execution', |
|||
details: 'partial', |
|||
status: 'warning' |
|||
}, |
|||
{ |
|||
check: 'market_data_coverage', |
|||
details: 'missing', |
|||
status: 'failed' |
|||
} |
|||
] |
|||
}); |
|||
|
|||
expect(confidence.score).toBeGreaterThanOrEqual(0); |
|||
expect(confidence.score).toBeLessThanOrEqual(1); |
|||
expect(['high', 'medium', 'low']).toContain(confidence.band); |
|||
}); |
|||
|
|||
it('uses medium band at the 0.6 confidence threshold', () => { |
|||
const confidence = calculateConfidence({ |
|||
toolCalls: [], |
|||
verification: [ |
|||
{ |
|||
check: 'v1', |
|||
details: 'ok', |
|||
status: 'passed' |
|||
}, |
|||
{ |
|||
check: 'v2', |
|||
details: 'ok', |
|||
status: 'passed' |
|||
}, |
|||
{ |
|||
check: 'v3', |
|||
details: 'ok', |
|||
status: 'passed' |
|||
}, |
|||
{ |
|||
check: 'v4', |
|||
details: 'ok', |
|||
status: 'passed' |
|||
}, |
|||
{ |
|||
check: 'v5', |
|||
details: 'warn', |
|||
status: 'warning' |
|||
} |
|||
] |
|||
}); |
|||
|
|||
expect(confidence.score).toBe(0.6); |
|||
expect(confidence.band).toBe('medium'); |
|||
}); |
|||
|
|||
it('uses high band at the 0.8 confidence threshold', () => { |
|||
const confidence = calculateConfidence({ |
|||
toolCalls: [ |
|||
{ |
|||
input: {}, |
|||
outputSummary: 'ok', |
|||
status: 'success', |
|||
tool: 'portfolio_analysis' |
|||
} |
|||
], |
|||
verification: [ |
|||
{ |
|||
check: 'v1', |
|||
details: 'ok', |
|||
status: 'passed' |
|||
}, |
|||
{ |
|||
check: 'v2', |
|||
details: 'warn', |
|||
status: 'warning' |
|||
}, |
|||
{ |
|||
check: 'v3', |
|||
details: 'warn', |
|||
status: 'warning' |
|||
}, |
|||
{ |
|||
check: 'v4', |
|||
details: 'warn', |
|||
status: 'warning' |
|||
}, |
|||
{ |
|||
check: 'v5', |
|||
details: 'warn', |
|||
status: 'warning' |
|||
} |
|||
] |
|||
}); |
|||
|
|||
expect(confidence.score).toBe(0.8); |
|||
expect(confidence.band).toBe('high'); |
|||
}); |
|||
}); |
|||
@ -0,0 +1,205 @@ |
|||
import { |
|||
AiAgentConfidence, |
|||
AiAgentToolCall, |
|||
AiAgentToolName, |
|||
AiAgentVerificationCheck |
|||
} from './ai-agent.interfaces'; |
|||
|
|||
const CANDIDATE_TICKER_PATTERN = /\$?[A-Za-z0-9.]{1,10}/g; |
|||
const NORMALIZED_TICKER_PATTERN = /^(?=.*[A-Z])[A-Z0-9]{1,6}(?:\.[A-Z0-9]{1,4})?$/; |
|||
const SYMBOL_STOP_WORDS = new Set([ |
|||
'AND', |
|||
'FOR', |
|||
'GIVE', |
|||
'HELP', |
|||
'I', |
|||
'IS', |
|||
'MARKET', |
|||
'OF', |
|||
'PLEASE', |
|||
'PORTFOLIO', |
|||
'PRICE', |
|||
'QUOTE', |
|||
'RISK', |
|||
'SHOW', |
|||
'SYMBOL', |
|||
'THE', |
|||
'TICKER', |
|||
'WHAT', |
|||
'WITH' |
|||
]); |
|||
|
|||
const INVESTMENT_INTENT_KEYWORDS = [ |
|||
'add', |
|||
'allocat', |
|||
'buy', |
|||
'invest', |
|||
'next', |
|||
'rebalanc', |
|||
'sell', |
|||
'trim' |
|||
]; |
|||
|
|||
const REBALANCE_KEYWORDS = [ |
|||
'rebalanc', |
|||
'reduce', |
|||
'trim', |
|||
'underweight', |
|||
'overweight' |
|||
]; |
|||
|
|||
const STRESS_TEST_KEYWORDS = ['crash', 'drawdown', 'shock', 'stress']; |
|||
|
|||
function normalizeSymbolCandidate(rawCandidate: string) { |
|||
const hasDollarPrefix = rawCandidate.startsWith('$'); |
|||
const candidate = hasDollarPrefix |
|||
? rawCandidate.slice(1) |
|||
: rawCandidate; |
|||
|
|||
if (!candidate) { |
|||
return null; |
|||
} |
|||
|
|||
const normalized = candidate.toUpperCase(); |
|||
|
|||
if (SYMBOL_STOP_WORDS.has(normalized)) { |
|||
return null; |
|||
} |
|||
|
|||
if (!NORMALIZED_TICKER_PATTERN.test(normalized)) { |
|||
return null; |
|||
} |
|||
|
|||
// Conservative mode for non-prefixed symbols avoids false positives from
|
|||
// natural language words such as WHAT/THE/AND.
|
|||
if (!hasDollarPrefix && candidate !== candidate.toUpperCase()) { |
|||
return null; |
|||
} |
|||
|
|||
return normalized; |
|||
} |
|||
|
|||
export function extractSymbolsFromQuery(query: string) { |
|||
const matches = query.match(CANDIDATE_TICKER_PATTERN) ?? []; |
|||
|
|||
return Array.from( |
|||
new Set( |
|||
matches |
|||
.map((candidate) => normalizeSymbolCandidate(candidate)) |
|||
.filter(Boolean) |
|||
) |
|||
); |
|||
} |
|||
|
|||
export function determineToolPlan({ |
|||
query, |
|||
symbols |
|||
}: { |
|||
query: string; |
|||
symbols?: string[]; |
|||
}): AiAgentToolName[] { |
|||
const normalizedQuery = query.toLowerCase(); |
|||
const selectedTools = new Set<AiAgentToolName>(); |
|||
const extractedSymbols = symbols?.length |
|||
? symbols |
|||
: extractSymbolsFromQuery(query); |
|||
const hasInvestmentIntent = INVESTMENT_INTENT_KEYWORDS.some((keyword) => { |
|||
return normalizedQuery.includes(keyword); |
|||
}); |
|||
const hasRebalanceIntent = REBALANCE_KEYWORDS.some((keyword) => { |
|||
return normalizedQuery.includes(keyword); |
|||
}); |
|||
const hasStressTestIntent = STRESS_TEST_KEYWORDS.some((keyword) => { |
|||
return normalizedQuery.includes(keyword); |
|||
}); |
|||
|
|||
if ( |
|||
normalizedQuery.includes('portfolio') || |
|||
normalizedQuery.includes('holding') || |
|||
normalizedQuery.includes('allocation') || |
|||
normalizedQuery.includes('performance') || |
|||
normalizedQuery.includes('return') |
|||
) { |
|||
selectedTools.add('portfolio_analysis'); |
|||
} |
|||
|
|||
if ( |
|||
normalizedQuery.includes('risk') || |
|||
normalizedQuery.includes('concentration') || |
|||
normalizedQuery.includes('diversif') |
|||
) { |
|||
selectedTools.add('portfolio_analysis'); |
|||
selectedTools.add('risk_assessment'); |
|||
} |
|||
|
|||
if (hasInvestmentIntent || hasRebalanceIntent) { |
|||
selectedTools.add('portfolio_analysis'); |
|||
selectedTools.add('risk_assessment'); |
|||
selectedTools.add('rebalance_plan'); |
|||
} |
|||
|
|||
if (hasStressTestIntent) { |
|||
selectedTools.add('portfolio_analysis'); |
|||
selectedTools.add('risk_assessment'); |
|||
selectedTools.add('stress_test'); |
|||
} |
|||
|
|||
if ( |
|||
normalizedQuery.includes('quote') || |
|||
normalizedQuery.includes('price') || |
|||
normalizedQuery.includes('market') || |
|||
normalizedQuery.includes('ticker') || |
|||
extractedSymbols.length > 0 |
|||
) { |
|||
selectedTools.add('market_data_lookup'); |
|||
} |
|||
|
|||
if (selectedTools.size === 0) { |
|||
selectedTools.add('portfolio_analysis'); |
|||
selectedTools.add('risk_assessment'); |
|||
} |
|||
|
|||
return Array.from(selectedTools); |
|||
} |
|||
|
|||
export function calculateConfidence({ |
|||
toolCalls, |
|||
verification |
|||
}: { |
|||
toolCalls: AiAgentToolCall[]; |
|||
verification: AiAgentVerificationCheck[]; |
|||
}): AiAgentConfidence { |
|||
const successfulToolCalls = toolCalls.filter(({ status }) => { |
|||
return status === 'success'; |
|||
}).length; |
|||
|
|||
const passedVerification = verification.filter(({ status }) => { |
|||
return status === 'passed'; |
|||
}).length; |
|||
|
|||
const failedVerification = verification.filter(({ status }) => { |
|||
return status === 'failed'; |
|||
}).length; |
|||
|
|||
const toolSuccessRate = |
|||
toolCalls.length > 0 ? successfulToolCalls / toolCalls.length : 0; |
|||
const verificationPassRate = |
|||
verification.length > 0 ? passedVerification / verification.length : 0; |
|||
|
|||
let score = 0.4 + toolSuccessRate * 0.35 + verificationPassRate * 0.25; |
|||
score -= failedVerification * 0.1; |
|||
score = Math.max(0, Math.min(1, score)); |
|||
|
|||
let band: AiAgentConfidence['band'] = 'low'; |
|||
|
|||
if (score >= 0.8) { |
|||
band = 'high'; |
|||
} else if (score >= 0.6) { |
|||
band = 'medium'; |
|||
} |
|||
|
|||
return { |
|||
band, |
|||
score: Number(score.toFixed(2)) |
|||
}; |
|||
} |
|||
@ -0,0 +1,17 @@ |
|||
import { IsArray, IsNotEmpty, IsOptional, IsString } from 'class-validator'; |
|||
|
|||
export class AiChatDto { |
|||
@IsString() |
|||
@IsNotEmpty() |
|||
public query: string; |
|||
|
|||
@IsOptional() |
|||
@IsString() |
|||
public sessionId?: string; |
|||
|
|||
@IsOptional() |
|||
@IsArray() |
|||
@IsString({ each: true }) |
|||
public symbols?: string[]; |
|||
} |
|||
|
|||
@ -0,0 +1,123 @@ |
|||
const DEFAULT_GLM_MODEL = 'glm-5'; |
|||
const DEFAULT_MINIMAX_MODEL = 'MiniMax-M2.5'; |
|||
const DEFAULT_REQUEST_TIMEOUT_IN_MS = 15_000; |
|||
|
|||
function extractTextFromResponsePayload(payload: unknown) { |
|||
const firstChoice = (payload as { choices?: unknown[] })?.choices?.[0] as |
|||
| { message?: { content?: unknown } } |
|||
| undefined; |
|||
const content = firstChoice?.message?.content; |
|||
|
|||
if (typeof content === 'string') { |
|||
return content.trim(); |
|||
} |
|||
|
|||
if (Array.isArray(content)) { |
|||
const normalized = content |
|||
.map((item) => { |
|||
if (typeof item === 'string') { |
|||
return item; |
|||
} |
|||
|
|||
if ( |
|||
typeof item === 'object' && |
|||
item !== null && |
|||
'text' in item && |
|||
typeof item.text === 'string' |
|||
) { |
|||
return item.text; |
|||
} |
|||
|
|||
return ''; |
|||
}) |
|||
.join(' ') |
|||
.trim(); |
|||
|
|||
return normalized.length > 0 ? normalized : null; |
|||
} |
|||
|
|||
return null; |
|||
} |
|||
|
|||
async function callChatCompletions({ |
|||
apiKey, |
|||
model, |
|||
prompt, |
|||
url |
|||
}: { |
|||
apiKey: string; |
|||
model: string; |
|||
prompt: string; |
|||
url: string; |
|||
}) { |
|||
const response = await fetch(url, { |
|||
body: JSON.stringify({ |
|||
messages: [ |
|||
{ |
|||
content: 'You are a neutral financial assistant.', |
|||
role: 'system' |
|||
}, |
|||
{ |
|||
content: prompt, |
|||
role: 'user' |
|||
} |
|||
], |
|||
model |
|||
}), |
|||
headers: { |
|||
Authorization: `Bearer ${apiKey}`, |
|||
'Content-Type': 'application/json' |
|||
}, |
|||
method: 'POST', |
|||
signal: AbortSignal.timeout(DEFAULT_REQUEST_TIMEOUT_IN_MS) |
|||
}); |
|||
|
|||
if (!response.ok) { |
|||
throw new Error(`provider request failed with status ${response.status}`); |
|||
} |
|||
|
|||
const payload = (await response.json()) as unknown; |
|||
const text = extractTextFromResponsePayload(payload); |
|||
|
|||
if (!text) { |
|||
throw new Error('provider returned no assistant text'); |
|||
} |
|||
|
|||
return { |
|||
text |
|||
}; |
|||
} |
|||
|
|||
export async function generateTextWithZAiGlm({ |
|||
apiKey, |
|||
model, |
|||
prompt |
|||
}: { |
|||
apiKey: string; |
|||
model?: string; |
|||
prompt: string; |
|||
}) { |
|||
return callChatCompletions({ |
|||
apiKey, |
|||
model: model ?? DEFAULT_GLM_MODEL, |
|||
prompt, |
|||
url: 'https://api.z.ai/api/paas/v4/chat/completions' |
|||
}); |
|||
} |
|||
|
|||
export async function generateTextWithMinimax({ |
|||
apiKey, |
|||
model, |
|||
prompt |
|||
}: { |
|||
apiKey: string; |
|||
model?: string; |
|||
prompt: string; |
|||
}) { |
|||
return callChatCompletions({ |
|||
apiKey, |
|||
model: model ?? DEFAULT_MINIMAX_MODEL, |
|||
prompt, |
|||
url: 'https://api.minimax.io/v1/chat/completions' |
|||
}); |
|||
} |
|||
@ -0,0 +1,116 @@ |
|||
import { REQUEST } from '@nestjs/core'; |
|||
import { Test, TestingModule } from '@nestjs/testing'; |
|||
|
|||
import { ApiService } from '@ghostfolio/api/services/api/api.service'; |
|||
|
|||
import { AiController } from './ai.controller'; |
|||
import { AiChatDto } from './ai-chat.dto'; |
|||
import { AiService } from './ai.service'; |
|||
|
|||
describe('AiController', () => { |
|||
let controller: AiController; |
|||
let aiService: { chat: jest.Mock; getPrompt: jest.Mock }; |
|||
let apiService: { buildFiltersFromQueryParams: jest.Mock }; |
|||
|
|||
beforeEach(async () => { |
|||
aiService = { |
|||
chat: jest.fn(), |
|||
getPrompt: jest.fn() |
|||
}; |
|||
apiService = { |
|||
buildFiltersFromQueryParams: jest.fn() |
|||
}; |
|||
|
|||
const moduleRef: TestingModule = await Test.createTestingModule({ |
|||
controllers: [AiController], |
|||
providers: [ |
|||
{ |
|||
provide: AiService, |
|||
useValue: aiService |
|||
}, |
|||
{ |
|||
provide: ApiService, |
|||
useValue: apiService |
|||
}, |
|||
{ |
|||
provide: REQUEST, |
|||
useValue: { |
|||
user: { |
|||
id: 'user-controller', |
|||
settings: { |
|||
settings: { |
|||
baseCurrency: 'USD', |
|||
language: 'en' |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
] |
|||
}).compile(); |
|||
|
|||
controller = moduleRef.get(AiController); |
|||
}); |
|||
|
|||
it('passes validated chat payload and user context to ai service', async () => { |
|||
const dto: AiChatDto = { |
|||
query: 'Analyze my portfolio', |
|||
sessionId: 'chat-session-1', |
|||
symbols: ['AAPL'] |
|||
}; |
|||
|
|||
aiService.chat.mockResolvedValue({ |
|||
answer: 'ok', |
|||
citations: [], |
|||
confidence: { band: 'medium', score: 0.7 }, |
|||
memory: { sessionId: 'chat-session-1', turns: 1 }, |
|||
toolCalls: [], |
|||
verification: [] |
|||
}); |
|||
|
|||
await controller.chat(dto); |
|||
|
|||
expect(aiService.chat).toHaveBeenCalledWith({ |
|||
languageCode: 'en', |
|||
query: dto.query, |
|||
sessionId: dto.sessionId, |
|||
symbols: dto.symbols, |
|||
userCurrency: 'USD', |
|||
userId: 'user-controller' |
|||
}); |
|||
}); |
|||
|
|||
it('builds filters via api service before calling prompt generation', async () => { |
|||
const filters = [{ key: 'symbol', value: 'AAPL' }]; |
|||
apiService.buildFiltersFromQueryParams.mockReturnValue(filters); |
|||
aiService.getPrompt.mockResolvedValue('prompt-body'); |
|||
|
|||
const response = await controller.getPrompt( |
|||
'portfolio', |
|||
'account-1', |
|||
undefined, |
|||
undefined, |
|||
undefined, |
|||
'tag-1' |
|||
); |
|||
|
|||
expect(apiService.buildFiltersFromQueryParams).toHaveBeenCalledWith({ |
|||
filterByAccounts: 'account-1', |
|||
filterByAssetClasses: undefined, |
|||
filterByDataSource: undefined, |
|||
filterBySymbol: undefined, |
|||
filterByTags: 'tag-1' |
|||
}); |
|||
expect(aiService.getPrompt).toHaveBeenCalledWith({ |
|||
filters, |
|||
impersonationId: undefined, |
|||
languageCode: 'en', |
|||
mode: 'portfolio', |
|||
userCurrency: 'USD', |
|||
userId: 'user-controller' |
|||
}); |
|||
expect(response).toEqual({ |
|||
prompt: 'prompt-body' |
|||
}); |
|||
}); |
|||
}); |
|||
@ -0,0 +1,419 @@ |
|||
import { DataSource } from '@prisma/client'; |
|||
|
|||
import { AiService } from './ai.service'; |
|||
|
|||
describe('AiService', () => { |
|||
let dataProviderService: { getQuotes: jest.Mock }; |
|||
let portfolioService: { getDetails: jest.Mock }; |
|||
let propertyService: { getByKey: jest.Mock }; |
|||
let redisCacheService: { get: jest.Mock; set: jest.Mock }; |
|||
let subject: AiService; |
|||
const originalFetch = global.fetch; |
|||
const originalMinimaxApiKey = process.env.minimax_api_key; |
|||
const originalMinimaxModel = process.env.minimax_model; |
|||
const originalZAiGlmApiKey = process.env.z_ai_glm_api_key; |
|||
const originalZAiGlmModel = process.env.z_ai_glm_model; |
|||
|
|||
beforeEach(() => { |
|||
dataProviderService = { |
|||
getQuotes: jest.fn() |
|||
}; |
|||
portfolioService = { |
|||
getDetails: jest.fn() |
|||
}; |
|||
propertyService = { |
|||
getByKey: jest.fn() |
|||
}; |
|||
redisCacheService = { |
|||
get: jest.fn(), |
|||
set: jest.fn() |
|||
}; |
|||
|
|||
subject = new AiService( |
|||
dataProviderService as never, |
|||
portfolioService as never, |
|||
propertyService as never, |
|||
redisCacheService as never |
|||
); |
|||
|
|||
delete process.env.minimax_api_key; |
|||
delete process.env.minimax_model; |
|||
delete process.env.z_ai_glm_api_key; |
|||
delete process.env.z_ai_glm_model; |
|||
}); |
|||
|
|||
afterAll(() => { |
|||
global.fetch = originalFetch; |
|||
|
|||
if (originalMinimaxApiKey === undefined) { |
|||
delete process.env.minimax_api_key; |
|||
} else { |
|||
process.env.minimax_api_key = originalMinimaxApiKey; |
|||
} |
|||
|
|||
if (originalMinimaxModel === undefined) { |
|||
delete process.env.minimax_model; |
|||
} else { |
|||
process.env.minimax_model = originalMinimaxModel; |
|||
} |
|||
|
|||
if (originalZAiGlmApiKey === undefined) { |
|||
delete process.env.z_ai_glm_api_key; |
|||
} else { |
|||
process.env.z_ai_glm_api_key = originalZAiGlmApiKey; |
|||
} |
|||
|
|||
if (originalZAiGlmModel === undefined) { |
|||
delete process.env.z_ai_glm_model; |
|||
} else { |
|||
process.env.z_ai_glm_model = originalZAiGlmModel; |
|||
} |
|||
}); |
|||
|
|||
it('runs portfolio, risk, and market tools with structured response fields', async () => { |
|||
portfolioService.getDetails.mockResolvedValue({ |
|||
holdings: { |
|||
AAPL: { |
|||
allocationInPercentage: 0.6, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'AAPL', |
|||
valueInBaseCurrency: 6000 |
|||
}, |
|||
MSFT: { |
|||
allocationInPercentage: 0.4, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'MSFT', |
|||
valueInBaseCurrency: 4000 |
|||
} |
|||
} |
|||
}); |
|||
dataProviderService.getQuotes.mockResolvedValue({ |
|||
AAPL: { |
|||
currency: 'USD', |
|||
marketPrice: 210.12, |
|||
marketState: 'REGULAR' |
|||
}, |
|||
MSFT: { |
|||
currency: 'USD', |
|||
marketPrice: 455.9, |
|||
marketState: 'REGULAR' |
|||
} |
|||
}); |
|||
redisCacheService.get.mockResolvedValue(undefined); |
|||
jest.spyOn(subject, 'generateText').mockResolvedValue({ |
|||
text: 'Portfolio risk looks medium with strong concentration controls.' |
|||
} as never); |
|||
|
|||
const result = await subject.chat({ |
|||
languageCode: 'en', |
|||
query: 'Analyze my portfolio risk and price for AAPL', |
|||
sessionId: 'session-1', |
|||
userCurrency: 'USD', |
|||
userId: 'user-1' |
|||
}); |
|||
|
|||
expect(result.answer).toContain('Portfolio risk'); |
|||
expect(result.toolCalls).toEqual( |
|||
expect.arrayContaining([ |
|||
expect.objectContaining({ |
|||
status: 'success', |
|||
tool: 'portfolio_analysis' |
|||
}), |
|||
expect.objectContaining({ |
|||
status: 'success', |
|||
tool: 'risk_assessment' |
|||
}), |
|||
expect.objectContaining({ |
|||
status: 'success', |
|||
tool: 'market_data_lookup' |
|||
}) |
|||
]) |
|||
); |
|||
expect(result.citations.length).toBeGreaterThan(0); |
|||
expect(result.confidence.score).toBeGreaterThanOrEqual(0); |
|||
expect(result.confidence.score).toBeLessThanOrEqual(1); |
|||
expect(result.verification).toEqual( |
|||
expect.arrayContaining([ |
|||
expect.objectContaining({ check: 'numerical_consistency' }), |
|||
expect.objectContaining({ check: 'tool_execution' }), |
|||
expect.objectContaining({ check: 'output_completeness' }), |
|||
expect.objectContaining({ check: 'citation_coverage' }) |
|||
]) |
|||
); |
|||
expect(result.memory).toEqual({ |
|||
sessionId: 'session-1', |
|||
turns: 1 |
|||
}); |
|||
expect(redisCacheService.set).toHaveBeenCalledWith( |
|||
'ai-agent-memory-user-1-session-1', |
|||
expect.any(String), |
|||
expect.any(Number) |
|||
); |
|||
}); |
|||
|
|||
it('keeps memory history and caps turns at the configured limit', async () => { |
|||
const previousTurns = Array.from({ length: 10 }, (_, index) => { |
|||
return { |
|||
answer: `answer-${index}`, |
|||
query: `query-${index}`, |
|||
timestamp: `2026-02-20T00:0${index}:00.000Z`, |
|||
toolCalls: [{ status: 'success', tool: 'portfolio_analysis' }] |
|||
}; |
|||
}); |
|||
|
|||
portfolioService.getDetails.mockResolvedValue({ |
|||
holdings: {} |
|||
}); |
|||
redisCacheService.get.mockResolvedValue( |
|||
JSON.stringify({ |
|||
turns: previousTurns |
|||
}) |
|||
); |
|||
jest.spyOn(subject, 'generateText').mockRejectedValue(new Error('offline')); |
|||
|
|||
const result = await subject.chat({ |
|||
languageCode: 'en', |
|||
query: 'Show my portfolio overview', |
|||
sessionId: 'session-memory', |
|||
userCurrency: 'USD', |
|||
userId: 'user-memory' |
|||
}); |
|||
|
|||
expect(result.memory.turns).toBe(10); |
|||
const [, payload] = redisCacheService.set.mock.calls[0]; |
|||
const persistedMemory = JSON.parse(payload as string); |
|||
expect(persistedMemory.turns).toHaveLength(10); |
|||
expect( |
|||
persistedMemory.turns.find( |
|||
({ query }: { query: string }) => query === 'query-0' |
|||
) |
|||
).toBeUndefined(); |
|||
}); |
|||
|
|||
it('runs rebalance and stress test tools for portfolio scenario prompts', async () => { |
|||
portfolioService.getDetails.mockResolvedValue({ |
|||
holdings: { |
|||
AAPL: { |
|||
allocationInPercentage: 0.6, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'AAPL', |
|||
valueInBaseCurrency: 6000 |
|||
}, |
|||
MSFT: { |
|||
allocationInPercentage: 0.4, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'MSFT', |
|||
valueInBaseCurrency: 4000 |
|||
} |
|||
} |
|||
}); |
|||
redisCacheService.get.mockResolvedValue(undefined); |
|||
jest.spyOn(subject, 'generateText').mockResolvedValue({ |
|||
text: 'Trim AAPL toward target allocation and monitor stress drawdown.' |
|||
} as never); |
|||
|
|||
const result = await subject.chat({ |
|||
languageCode: 'en', |
|||
query: 'Rebalance my portfolio and run a stress test', |
|||
sessionId: 'session-core-tools', |
|||
userCurrency: 'USD', |
|||
userId: 'user-core-tools' |
|||
}); |
|||
|
|||
expect(result.toolCalls).toEqual( |
|||
expect.arrayContaining([ |
|||
expect.objectContaining({ tool: 'portfolio_analysis' }), |
|||
expect.objectContaining({ tool: 'risk_assessment' }), |
|||
expect.objectContaining({ tool: 'rebalance_plan' }), |
|||
expect.objectContaining({ tool: 'stress_test' }) |
|||
]) |
|||
); |
|||
expect(result.verification).toEqual( |
|||
expect.arrayContaining([ |
|||
expect.objectContaining({ |
|||
check: 'rebalance_coverage', |
|||
status: 'passed' |
|||
}), |
|||
expect.objectContaining({ |
|||
check: 'stress_test_coherence', |
|||
status: 'passed' |
|||
}) |
|||
]) |
|||
); |
|||
}); |
|||
|
|||
it('returns graceful failure metadata when a tool execution fails', async () => { |
|||
dataProviderService.getQuotes.mockRejectedValue( |
|||
new Error('market provider unavailable') |
|||
); |
|||
redisCacheService.get.mockResolvedValue(undefined); |
|||
jest.spyOn(subject, 'generateText').mockResolvedValue({ |
|||
text: 'Market data currently has limited availability.' |
|||
} as never); |
|||
|
|||
const result = await subject.chat({ |
|||
languageCode: 'en', |
|||
query: 'What is the current price of NVDA?', |
|||
sessionId: 'session-failure', |
|||
userCurrency: 'USD', |
|||
userId: 'user-failure' |
|||
}); |
|||
|
|||
expect(result.toolCalls).toEqual([ |
|||
expect.objectContaining({ |
|||
outputSummary: 'market provider unavailable', |
|||
status: 'failed', |
|||
tool: 'market_data_lookup' |
|||
}) |
|||
]); |
|||
expect(result.verification).toEqual( |
|||
expect.arrayContaining([ |
|||
expect.objectContaining({ |
|||
check: 'numerical_consistency', |
|||
status: 'warning' |
|||
}), |
|||
expect.objectContaining({ |
|||
check: 'tool_execution', |
|||
status: 'warning' |
|||
}) |
|||
]) |
|||
); |
|||
expect(result.answer).toContain('limited availability'); |
|||
}); |
|||
|
|||
it('flags numerical consistency warning when allocation sum exceeds tolerance', async () => { |
|||
portfolioService.getDetails.mockResolvedValue({ |
|||
holdings: { |
|||
AAPL: { |
|||
allocationInPercentage: 0.8, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'AAPL', |
|||
valueInBaseCurrency: 8000 |
|||
}, |
|||
MSFT: { |
|||
allocationInPercentage: 0.3, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'MSFT', |
|||
valueInBaseCurrency: 3000 |
|||
} |
|||
} |
|||
}); |
|||
redisCacheService.get.mockResolvedValue(undefined); |
|||
jest.spyOn(subject, 'generateText').mockRejectedValue(new Error('offline')); |
|||
|
|||
const result = await subject.chat({ |
|||
languageCode: 'en', |
|||
query: 'Show portfolio allocation', |
|||
sessionId: 'session-allocation-warning', |
|||
userCurrency: 'USD', |
|||
userId: 'user-allocation-warning' |
|||
}); |
|||
|
|||
expect(result.verification).toEqual( |
|||
expect.arrayContaining([ |
|||
expect.objectContaining({ |
|||
check: 'numerical_consistency', |
|||
status: 'warning' |
|||
}) |
|||
]) |
|||
); |
|||
}); |
|||
|
|||
it('flags market data coverage warning when only part of symbols resolve', async () => { |
|||
dataProviderService.getQuotes.mockResolvedValue({ |
|||
AAPL: { |
|||
currency: 'USD', |
|||
marketPrice: 210.12, |
|||
marketState: 'REGULAR' |
|||
} |
|||
}); |
|||
redisCacheService.get.mockResolvedValue(undefined); |
|||
jest.spyOn(subject, 'generateText').mockResolvedValue({ |
|||
text: 'Partial market data was returned.' |
|||
} as never); |
|||
|
|||
const result = await subject.chat({ |
|||
languageCode: 'en', |
|||
query: 'Get market prices for AAPL and TSLA', |
|||
sessionId: 'session-market-coverage-warning', |
|||
symbols: ['AAPL', 'TSLA'], |
|||
userCurrency: 'USD', |
|||
userId: 'user-market-coverage-warning' |
|||
}); |
|||
|
|||
expect(result.verification).toEqual( |
|||
expect.arrayContaining([ |
|||
expect.objectContaining({ |
|||
check: 'market_data_coverage', |
|||
status: 'warning' |
|||
}) |
|||
]) |
|||
); |
|||
}); |
|||
|
|||
it('uses z.ai glm provider when z_ai_glm_api_key is available', async () => { |
|||
process.env.z_ai_glm_api_key = 'zai-key'; |
|||
process.env.z_ai_glm_model = 'glm-5'; |
|||
|
|||
const fetchMock = jest.fn().mockResolvedValue({ |
|||
json: jest.fn().mockResolvedValue({ |
|||
choices: [{ message: { content: 'zai-response' } }] |
|||
}), |
|||
ok: true |
|||
}); |
|||
global.fetch = fetchMock as unknown as typeof fetch; |
|||
|
|||
const result = await subject.generateText({ |
|||
prompt: 'hello' |
|||
}); |
|||
|
|||
expect(fetchMock).toHaveBeenCalledWith( |
|||
'https://api.z.ai/api/paas/v4/chat/completions', |
|||
expect.objectContaining({ |
|||
method: 'POST' |
|||
}) |
|||
); |
|||
expect(result).toEqual({ |
|||
text: 'zai-response' |
|||
}); |
|||
expect(propertyService.getByKey).not.toHaveBeenCalled(); |
|||
}); |
|||
|
|||
it('falls back to minimax when z.ai request fails', async () => { |
|||
process.env.z_ai_glm_api_key = 'zai-key'; |
|||
process.env.minimax_api_key = 'minimax-key'; |
|||
process.env.minimax_model = 'MiniMax-M2.5'; |
|||
|
|||
const fetchMock = jest |
|||
.fn() |
|||
.mockResolvedValueOnce({ |
|||
ok: false, |
|||
status: 500 |
|||
}) |
|||
.mockResolvedValueOnce({ |
|||
json: jest.fn().mockResolvedValue({ |
|||
choices: [{ message: { content: 'minimax-response' } }] |
|||
}), |
|||
ok: true |
|||
}); |
|||
global.fetch = fetchMock as unknown as typeof fetch; |
|||
|
|||
const result = await subject.generateText({ |
|||
prompt: 'fallback test' |
|||
}); |
|||
|
|||
expect(fetchMock).toHaveBeenNthCalledWith( |
|||
1, |
|||
'https://api.z.ai/api/paas/v4/chat/completions', |
|||
expect.any(Object) |
|||
); |
|||
expect(fetchMock).toHaveBeenNthCalledWith( |
|||
2, |
|||
'https://api.minimax.io/v1/chat/completions', |
|||
expect.any(Object) |
|||
); |
|||
expect(result).toEqual({ |
|||
text: 'minimax-response' |
|||
}); |
|||
}); |
|||
}); |
|||
@ -0,0 +1,264 @@ |
|||
import { DataSource } from '@prisma/client'; |
|||
|
|||
import { AiAgentMvpEvalCase } from './mvp-eval.interfaces'; |
|||
|
|||
const DEFAULT_HOLDINGS = { |
|||
AAPL: { |
|||
allocationInPercentage: 0.5, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'AAPL', |
|||
valueInBaseCurrency: 5000 |
|||
}, |
|||
MSFT: { |
|||
allocationInPercentage: 0.3, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'MSFT', |
|||
valueInBaseCurrency: 3000 |
|||
}, |
|||
NVDA: { |
|||
allocationInPercentage: 0.2, |
|||
dataSource: DataSource.YAHOO, |
|||
symbol: 'NVDA', |
|||
valueInBaseCurrency: 2000 |
|||
} |
|||
}; |
|||
|
|||
const DEFAULT_QUOTES = { |
|||
AAPL: { |
|||
currency: 'USD', |
|||
marketPrice: 213.34, |
|||
marketState: 'REGULAR' |
|||
}, |
|||
MSFT: { |
|||
currency: 'USD', |
|||
marketPrice: 462.15, |
|||
marketState: 'REGULAR' |
|||
}, |
|||
NVDA: { |
|||
currency: 'USD', |
|||
marketPrice: 901.22, |
|||
marketState: 'REGULAR' |
|||
} |
|||
}; |
|||
|
|||
export const AI_AGENT_MVP_EVAL_DATASET: AiAgentMvpEvalCase[] = [ |
|||
{ |
|||
expected: { |
|||
minCitations: 1, |
|||
requiredTools: ['portfolio_analysis'], |
|||
verificationChecks: [{ check: 'tool_execution', status: 'passed' }] |
|||
}, |
|||
id: 'mvp-001-portfolio-overview', |
|||
input: { |
|||
query: 'Give me a quick portfolio allocation overview', |
|||
sessionId: 'mvp-eval-session-1', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'portfolio-analysis', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'Your portfolio is diversified with large-cap concentration.', |
|||
quotesBySymbol: DEFAULT_QUOTES |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
minCitations: 2, |
|||
requiredTools: ['portfolio_analysis', 'risk_assessment'], |
|||
verificationChecks: [{ check: 'numerical_consistency', status: 'passed' }] |
|||
}, |
|||
id: 'mvp-002-risk-assessment', |
|||
input: { |
|||
query: 'Analyze my portfolio concentration risk', |
|||
sessionId: 'mvp-eval-session-2', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'risk-assessment', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'Concentration risk sits in the medium range.', |
|||
quotesBySymbol: DEFAULT_QUOTES |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
minCitations: 1, |
|||
requiredToolCalls: [ |
|||
{ status: 'success', tool: 'market_data_lookup' } |
|||
], |
|||
requiredTools: ['market_data_lookup'] |
|||
}, |
|||
id: 'mvp-003-market-quote', |
|||
input: { |
|||
query: 'What is the latest price of NVDA?', |
|||
sessionId: 'mvp-eval-session-3', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'market-data', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'NVDA is currently trading near recent highs.', |
|||
quotesBySymbol: DEFAULT_QUOTES |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
minCitations: 3, |
|||
requiredTools: [ |
|||
'portfolio_analysis', |
|||
'risk_assessment', |
|||
'market_data_lookup' |
|||
], |
|||
verificationChecks: [ |
|||
{ check: 'numerical_consistency', status: 'passed' }, |
|||
{ check: 'citation_coverage', status: 'passed' } |
|||
] |
|||
}, |
|||
id: 'mvp-004-multi-tool-query', |
|||
input: { |
|||
query: 'Analyze portfolio risk and price action for AAPL', |
|||
sessionId: 'mvp-eval-session-4', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'multi-tool', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'Risk is moderate and AAPL supports portfolio momentum.', |
|||
quotesBySymbol: DEFAULT_QUOTES |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
requiredTools: ['portfolio_analysis'], |
|||
verificationChecks: [{ check: 'tool_execution', status: 'passed' }] |
|||
}, |
|||
id: 'mvp-005-default-fallback-tool', |
|||
input: { |
|||
query: 'Help me with my investments this week', |
|||
sessionId: 'mvp-eval-session-5', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'fallback-tool-selection', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'Portfolio context provides the best starting point.', |
|||
quotesBySymbol: DEFAULT_QUOTES |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
answerIncludes: ['Session memory applied from 2 prior turn(s).'], |
|||
memoryTurnsAtLeast: 3, |
|||
requiredTools: ['portfolio_analysis'] |
|||
}, |
|||
id: 'mvp-006-memory-continuity', |
|||
input: { |
|||
query: 'Show my portfolio status again', |
|||
sessionId: 'mvp-eval-session-6', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'memory', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmThrows: true, |
|||
quotesBySymbol: DEFAULT_QUOTES, |
|||
storedMemoryTurns: [ |
|||
{ |
|||
answer: 'Prior answer 1', |
|||
query: 'Initial query', |
|||
timestamp: '2026-02-23T10:00:00.000Z', |
|||
toolCalls: [{ status: 'success', tool: 'portfolio_analysis' }] |
|||
}, |
|||
{ |
|||
answer: 'Prior answer 2', |
|||
query: 'Follow-up query', |
|||
timestamp: '2026-02-23T10:05:00.000Z', |
|||
toolCalls: [{ status: 'success', tool: 'risk_assessment' }] |
|||
} |
|||
] |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
requiredToolCalls: [ |
|||
{ status: 'failed', tool: 'market_data_lookup' } |
|||
], |
|||
requiredTools: ['market_data_lookup'], |
|||
verificationChecks: [{ check: 'tool_execution', status: 'warning' }] |
|||
}, |
|||
id: 'mvp-007-market-tool-graceful-failure', |
|||
input: { |
|||
query: 'Fetch price for NVDA and TSLA', |
|||
sessionId: 'mvp-eval-session-7', |
|||
symbols: ['NVDA', 'TSLA'], |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'tool-failure', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'Market provider has limited availability right now.', |
|||
marketDataErrorMessage: 'market provider unavailable' |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
requiredTools: ['market_data_lookup'], |
|||
verificationChecks: [{ check: 'market_data_coverage', status: 'warning' }] |
|||
}, |
|||
id: 'mvp-008-partial-market-coverage', |
|||
input: { |
|||
query: 'Get market prices for AAPL and UNKNOWN', |
|||
sessionId: 'mvp-eval-session-8', |
|||
symbols: ['AAPL', 'UNKNOWN'], |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'partial-coverage', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'Some symbols resolved while others remained unresolved.', |
|||
quotesBySymbol: { |
|||
AAPL: DEFAULT_QUOTES.AAPL |
|||
} |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
requiredTools: [ |
|||
'portfolio_analysis', |
|||
'risk_assessment', |
|||
'rebalance_plan' |
|||
], |
|||
verificationChecks: [{ check: 'rebalance_coverage', status: 'passed' }] |
|||
}, |
|||
id: 'mvp-009-rebalance-plan', |
|||
input: { |
|||
query: 'Create a rebalance plan for my portfolio', |
|||
sessionId: 'mvp-eval-session-9', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'rebalance', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'AAPL is overweight and should be trimmed toward your target.', |
|||
quotesBySymbol: DEFAULT_QUOTES |
|||
} |
|||
}, |
|||
{ |
|||
expected: { |
|||
requiredTools: ['portfolio_analysis', 'risk_assessment', 'stress_test'], |
|||
verificationChecks: [{ check: 'stress_test_coherence', status: 'passed' }] |
|||
}, |
|||
id: 'mvp-010-stress-test', |
|||
input: { |
|||
query: 'Run a drawdown stress scenario for my portfolio', |
|||
sessionId: 'mvp-eval-session-10', |
|||
userId: 'mvp-user' |
|||
}, |
|||
intent: 'stress-test', |
|||
setup: { |
|||
holdings: DEFAULT_HOLDINGS, |
|||
llmText: 'A ten percent downside shock indicates manageable drawdown.', |
|||
quotesBySymbol: DEFAULT_QUOTES |
|||
} |
|||
} |
|||
]; |
|||
@ -0,0 +1,84 @@ |
|||
import { DataSource } from '@prisma/client'; |
|||
|
|||
import { |
|||
AiAgentChatResponse, |
|||
AiAgentToolName |
|||
} from '../ai-agent.interfaces'; |
|||
|
|||
export interface AiAgentMvpEvalQuote { |
|||
currency: string; |
|||
marketPrice: number; |
|||
marketState: string; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalHolding { |
|||
allocationInPercentage: number; |
|||
dataSource: DataSource; |
|||
symbol: string; |
|||
valueInBaseCurrency: number; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalMemoryTurn { |
|||
answer: string; |
|||
query: string; |
|||
timestamp: string; |
|||
toolCalls: { |
|||
status: 'success' | 'failed'; |
|||
tool: AiAgentToolName; |
|||
}[]; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalCaseInput { |
|||
languageCode?: string; |
|||
query: string; |
|||
sessionId: string; |
|||
symbols?: string[]; |
|||
userCurrency?: string; |
|||
userId: string; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalCaseSetup { |
|||
holdings?: Record<string, AiAgentMvpEvalHolding>; |
|||
llmText?: string; |
|||
llmThrows?: boolean; |
|||
marketDataErrorMessage?: string; |
|||
quotesBySymbol?: Record<string, AiAgentMvpEvalQuote>; |
|||
storedMemoryTurns?: AiAgentMvpEvalMemoryTurn[]; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalToolExpectation { |
|||
status?: 'success' | 'failed'; |
|||
tool: AiAgentToolName; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalVerificationExpectation { |
|||
check: string; |
|||
status?: 'passed' | 'warning' | 'failed'; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalCaseExpected { |
|||
answerIncludes?: string[]; |
|||
confidenceScoreMin?: number; |
|||
forbiddenTools?: AiAgentToolName[]; |
|||
memoryTurnsAtLeast?: number; |
|||
minCitations?: number; |
|||
requiredTools?: AiAgentToolName[]; |
|||
requiredToolCalls?: AiAgentMvpEvalToolExpectation[]; |
|||
verificationChecks?: AiAgentMvpEvalVerificationExpectation[]; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalCase { |
|||
expected: AiAgentMvpEvalCaseExpected; |
|||
id: string; |
|||
input: AiAgentMvpEvalCaseInput; |
|||
intent: string; |
|||
setup: AiAgentMvpEvalCaseSetup; |
|||
} |
|||
|
|||
export interface AiAgentMvpEvalResult { |
|||
durationInMs: number; |
|||
failures: string[]; |
|||
id: string; |
|||
passed: boolean; |
|||
response?: AiAgentChatResponse; |
|||
} |
|||
@ -0,0 +1,109 @@ |
|||
import { DataSource } from '@prisma/client'; |
|||
|
|||
import { AiService } from '../ai.service'; |
|||
|
|||
import { AI_AGENT_MVP_EVAL_DATASET } from './mvp-eval.dataset'; |
|||
import { runMvpEvalSuite } from './mvp-eval.runner'; |
|||
import { AiAgentMvpEvalCase } from './mvp-eval.interfaces'; |
|||
|
|||
function createAiServiceForCase(evalCase: AiAgentMvpEvalCase) { |
|||
const dataProviderService = { |
|||
getQuotes: jest.fn() |
|||
}; |
|||
const portfolioService = { |
|||
getDetails: jest.fn() |
|||
}; |
|||
const propertyService = { |
|||
getByKey: jest.fn() |
|||
}; |
|||
const redisCacheService = { |
|||
get: jest.fn(), |
|||
set: jest.fn() |
|||
}; |
|||
|
|||
portfolioService.getDetails.mockResolvedValue({ |
|||
holdings: |
|||
evalCase.setup.holdings ?? |
|||
({ |
|||
CASH: { |
|||
allocationInPercentage: 1, |
|||
dataSource: DataSource.MANUAL, |
|||
symbol: 'CASH', |
|||
valueInBaseCurrency: 1000 |
|||
} |
|||
} as const) |
|||
}); |
|||
|
|||
dataProviderService.getQuotes.mockImplementation( |
|||
async ({ |
|||
items |
|||
}: { |
|||
items: { dataSource: DataSource; symbol: string }[]; |
|||
}) => { |
|||
if (evalCase.setup.marketDataErrorMessage) { |
|||
throw new Error(evalCase.setup.marketDataErrorMessage); |
|||
} |
|||
|
|||
const quotesBySymbol = evalCase.setup.quotesBySymbol ?? {}; |
|||
|
|||
return items.reduce<Record<string, (typeof quotesBySymbol)[string]>>( |
|||
(result, { symbol }) => { |
|||
if (quotesBySymbol[symbol]) { |
|||
result[symbol] = quotesBySymbol[symbol]; |
|||
} |
|||
|
|||
return result; |
|||
}, |
|||
{} |
|||
); |
|||
} |
|||
); |
|||
|
|||
redisCacheService.get.mockResolvedValue( |
|||
evalCase.setup.storedMemoryTurns |
|||
? JSON.stringify({ |
|||
turns: evalCase.setup.storedMemoryTurns |
|||
}) |
|||
: undefined |
|||
); |
|||
redisCacheService.set.mockResolvedValue(undefined); |
|||
|
|||
const aiService = new AiService( |
|||
dataProviderService as never, |
|||
portfolioService as never, |
|||
propertyService as never, |
|||
redisCacheService as never |
|||
); |
|||
|
|||
if (evalCase.setup.llmThrows) { |
|||
jest.spyOn(aiService, 'generateText').mockRejectedValue(new Error('offline')); |
|||
} else { |
|||
jest.spyOn(aiService, 'generateText').mockResolvedValue({ |
|||
text: evalCase.setup.llmText ?? `Eval response for ${evalCase.id}` |
|||
} as never); |
|||
} |
|||
|
|||
return aiService; |
|||
} |
|||
|
|||
describe('AiAgentMvpEvalSuite', () => { |
|||
it('contains at least five baseline MVP eval cases', () => { |
|||
expect(AI_AGENT_MVP_EVAL_DATASET.length).toBeGreaterThanOrEqual(5); |
|||
}); |
|||
|
|||
it('passes the MVP eval suite with at least 80% success rate', async () => { |
|||
const suiteResult = await runMvpEvalSuite({ |
|||
aiServiceFactory: (evalCase) => createAiServiceForCase(evalCase), |
|||
cases: AI_AGENT_MVP_EVAL_DATASET |
|||
}); |
|||
|
|||
expect(suiteResult.passRate).toBeGreaterThanOrEqual(0.8); |
|||
expect( |
|||
suiteResult.results |
|||
.filter(({ passed }) => !passed) |
|||
.map(({ failures, id }) => { |
|||
return `${id}: ${failures.join(' | ')}`; |
|||
}) |
|||
).toEqual([]); |
|||
}); |
|||
}); |
|||
@ -0,0 +1,183 @@ |
|||
import { AiService } from '../ai.service'; |
|||
|
|||
import { |
|||
AiAgentMvpEvalCase, |
|||
AiAgentMvpEvalResult, |
|||
AiAgentMvpEvalVerificationExpectation |
|||
} from './mvp-eval.interfaces'; |
|||
|
|||
function hasExpectedVerification({ |
|||
actualChecks, |
|||
expectedCheck |
|||
}: { |
|||
actualChecks: { check: string; status: 'passed' | 'warning' | 'failed' }[]; |
|||
expectedCheck: AiAgentMvpEvalVerificationExpectation; |
|||
}) { |
|||
return actualChecks.some(({ check, status }) => { |
|||
if (check !== expectedCheck.check) { |
|||
return false; |
|||
} |
|||
|
|||
if (!expectedCheck.status) { |
|||
return true; |
|||
} |
|||
|
|||
return status === expectedCheck.status; |
|||
}); |
|||
} |
|||
|
|||
function evaluateResponse({ |
|||
evalCase, |
|||
response |
|||
}: { |
|||
evalCase: AiAgentMvpEvalCase; |
|||
response: Awaited<ReturnType<AiService['chat']>>; |
|||
}) { |
|||
const failures: string[] = []; |
|||
const observedTools = response.toolCalls.map(({ tool }) => tool); |
|||
|
|||
for (const requiredTool of evalCase.expected.requiredTools ?? []) { |
|||
if (!observedTools.includes(requiredTool)) { |
|||
failures.push(`Missing required tool: ${requiredTool}`); |
|||
} |
|||
} |
|||
|
|||
for (const forbiddenTool of evalCase.expected.forbiddenTools ?? []) { |
|||
if (observedTools.includes(forbiddenTool)) { |
|||
failures.push(`Forbidden tool executed: ${forbiddenTool}`); |
|||
} |
|||
} |
|||
|
|||
for (const expectedCall of evalCase.expected.requiredToolCalls ?? []) { |
|||
const matched = response.toolCalls.some((toolCall) => { |
|||
return ( |
|||
toolCall.tool === expectedCall.tool && |
|||
(!expectedCall.status || toolCall.status === expectedCall.status) |
|||
); |
|||
}); |
|||
|
|||
if (!matched) { |
|||
failures.push( |
|||
`Missing required tool call: ${expectedCall.tool}${expectedCall.status ? `:${expectedCall.status}` : ''}` |
|||
); |
|||
} |
|||
} |
|||
|
|||
if ( |
|||
typeof evalCase.expected.minCitations === 'number' && |
|||
response.citations.length < evalCase.expected.minCitations |
|||
) { |
|||
failures.push( |
|||
`Expected at least ${evalCase.expected.minCitations} citation(s), got ${response.citations.length}` |
|||
); |
|||
} |
|||
|
|||
if ( |
|||
typeof evalCase.expected.memoryTurnsAtLeast === 'number' && |
|||
response.memory.turns < evalCase.expected.memoryTurnsAtLeast |
|||
) { |
|||
failures.push( |
|||
`Expected memory turns >= ${evalCase.expected.memoryTurnsAtLeast}, got ${response.memory.turns}` |
|||
); |
|||
} |
|||
|
|||
if ( |
|||
typeof evalCase.expected.confidenceScoreMin === 'number' && |
|||
response.confidence.score < evalCase.expected.confidenceScoreMin |
|||
) { |
|||
failures.push( |
|||
`Expected confidence score >= ${evalCase.expected.confidenceScoreMin}, got ${response.confidence.score}` |
|||
); |
|||
} |
|||
|
|||
for (const expectedText of evalCase.expected.answerIncludes ?? []) { |
|||
if (!response.answer.includes(expectedText)) { |
|||
failures.push(`Answer does not include expected text: "${expectedText}"`); |
|||
} |
|||
} |
|||
|
|||
for (const expectedVerification of evalCase.expected.verificationChecks ?? []) { |
|||
if ( |
|||
!hasExpectedVerification({ |
|||
actualChecks: response.verification, |
|||
expectedCheck: expectedVerification |
|||
}) |
|||
) { |
|||
failures.push( |
|||
`Missing verification check: ${expectedVerification.check}${expectedVerification.status ? `:${expectedVerification.status}` : ''}` |
|||
); |
|||
} |
|||
} |
|||
|
|||
return failures; |
|||
} |
|||
|
|||
export async function runMvpEvalCase({ |
|||
aiService, |
|||
evalCase |
|||
}: { |
|||
aiService: AiService; |
|||
evalCase: AiAgentMvpEvalCase; |
|||
}): Promise<AiAgentMvpEvalResult> { |
|||
const startedAt = Date.now(); |
|||
|
|||
try { |
|||
const response = await aiService.chat({ |
|||
languageCode: evalCase.input.languageCode ?? 'en', |
|||
query: evalCase.input.query, |
|||
sessionId: evalCase.input.sessionId, |
|||
symbols: evalCase.input.symbols, |
|||
userCurrency: evalCase.input.userCurrency ?? 'USD', |
|||
userId: evalCase.input.userId |
|||
}); |
|||
|
|||
const failures = evaluateResponse({ |
|||
evalCase, |
|||
response |
|||
}); |
|||
|
|||
return { |
|||
durationInMs: Date.now() - startedAt, |
|||
failures, |
|||
id: evalCase.id, |
|||
passed: failures.length === 0, |
|||
response |
|||
}; |
|||
} catch (error) { |
|||
return { |
|||
durationInMs: Date.now() - startedAt, |
|||
failures: [error instanceof Error ? error.message : 'unknown eval error'], |
|||
id: evalCase.id, |
|||
passed: false |
|||
}; |
|||
} |
|||
} |
|||
|
|||
export async function runMvpEvalSuite({ |
|||
aiServiceFactory, |
|||
cases |
|||
}: { |
|||
aiServiceFactory: (evalCase: AiAgentMvpEvalCase) => AiService; |
|||
cases: AiAgentMvpEvalCase[]; |
|||
}) { |
|||
const results: AiAgentMvpEvalResult[] = []; |
|||
|
|||
for (const evalCase of cases) { |
|||
results.push( |
|||
await runMvpEvalCase({ |
|||
aiService: aiServiceFactory(evalCase), |
|||
evalCase |
|||
}) |
|||
); |
|||
} |
|||
|
|||
const passed = results.filter(({ passed: isPassed }) => isPassed).length; |
|||
const passRate = cases.length > 0 ? passed / cases.length : 0; |
|||
|
|||
return { |
|||
passRate, |
|||
passed, |
|||
results, |
|||
total: cases.length |
|||
}; |
|||
} |
|||
@ -0,0 +1,46 @@ |
|||
export type AiAgentToolName = |
|||
| 'portfolio_analysis' |
|||
| 'risk_assessment' |
|||
| 'market_data_lookup' |
|||
| 'rebalance_plan' |
|||
| 'stress_test'; |
|||
|
|||
export type AiAgentConfidenceBand = 'high' | 'medium' | 'low'; |
|||
|
|||
export interface AiAgentCitation { |
|||
confidence: number; |
|||
snippet: string; |
|||
source: AiAgentToolName; |
|||
} |
|||
|
|||
export interface AiAgentConfidence { |
|||
band: AiAgentConfidenceBand; |
|||
score: number; |
|||
} |
|||
|
|||
export interface AiAgentVerificationCheck { |
|||
check: string; |
|||
details: string; |
|||
status: 'passed' | 'warning' | 'failed'; |
|||
} |
|||
|
|||
export interface AiAgentToolCall { |
|||
input: Record<string, unknown>; |
|||
outputSummary: string; |
|||
status: 'success' | 'failed'; |
|||
tool: AiAgentToolName; |
|||
} |
|||
|
|||
export interface AiAgentMemorySnapshot { |
|||
sessionId: string; |
|||
turns: number; |
|||
} |
|||
|
|||
export interface AiAgentChatResponse { |
|||
answer: string; |
|||
citations: AiAgentCitation[]; |
|||
confidence: AiAgentConfidence; |
|||
memory: AiAgentMemorySnapshot; |
|||
toolCalls: AiAgentToolCall[]; |
|||
verification: AiAgentVerificationCheck[]; |
|||
} |
|||
Loading…
Reference in new issue