Select the Claude model per query type in agent/graph.py.

Adds FAST_MODEL / SMART_MODEL constants, the COMPLEX_QUERY_TYPES set and
get_model_for_query(), and replaces the hard-coded
"claude-sonnet-4-20250514" in two client.messages.create() calls inside
format_node with the model chosen from state["query_type"] (falling back to
"portfolio" when the key is absent).

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. The indentation of the format_node context
lines and the placement of blank context lines are reconstructed -- verify
them against agent/graph.py before applying.

diff --git a/agent/graph.py b/agent/graph.py
index bc0da6e65..072cb42c5 100644
--- a/agent/graph.py
+++ b/agent/graph.py
@@ -76,6 +76,33 @@ try:
 except ImportError:
     _RE_STRATEGY_AVAILABLE = False
 
+# Model selection constants
+FAST_MODEL = "claude-haiku-4-5-20251001"
+SMART_MODEL = "claude-sonnet-4-20250514"
+
+# Query types that need Sonnet for quality
+COMPLEX_QUERY_TYPES = {
+    "life_decision",
+    "family_planner",
+    "wealth_gap",
+    "wealth_down_payment",
+    "wealth_job_offer",
+    "wealth_global_city",
+    "wealth_portfolio_summary",
+    "equity_unlock",
+    "real_estate_detail",
+    "real_estate_snapshot",
+    "real_estate_search",
+    "real_estate_compare",
+}
+
+def get_model_for_query(query_type: str) -> str:
+    """Returns appropriate model based on query complexity."""
+    if query_type in COMPLEX_QUERY_TYPES:
+        return SMART_MODEL
+    return FAST_MODEL
+
+
 SYSTEM_PROMPT = """You are a portfolio analysis assistant integrated with Ghostfolio wealth management software.
 
 REASONING PROTOCOL — silently reason through these four steps BEFORE writing your response.
@@ -2393,8 +2420,10 @@ async def format_node(state: AgentState) -> AgentState:
             ),
         })
         try:
+            _qt = state.get("query_type", "portfolio")
+            _model = get_model_for_query(_qt)
             response_obj = client.messages.create(
-                model="claude-sonnet-4-20250514",
+                model=_model,
                 max_tokens=800,
                 system=SYSTEM_PROMPT,
                 messages=api_messages_ctx,
@@ -2531,8 +2560,10 @@ async def format_node(state: AgentState) -> AgentState:
     actual_input_tokens: int | None = None
     actual_output_tokens: int | None = None
     try:
+        _qt = state.get("query_type", "portfolio")
+        _model = get_model_for_query(_qt)
         response_obj = client.messages.create(
-            model="claude-sonnet-4-20250514",
+            model=_model,
             max_tokens=800,
             system=SYSTEM_PROMPT,
             messages=api_messages,