diff --git a/agent/graph.py b/agent/graph.py index 9e7c1b744..943ddb93f 100644 --- a/agent/graph.py +++ b/agent/graph.py @@ -827,6 +827,20 @@ async def classify_node(state: AgentState) -> AgentState: if has_overview: return {**state, "query_type": "market_overview"} + # --- "my TICKER stock" = stock price, not portfolio holding --- + # Check BEFORE portfolio_ticker_kws ("my share of" = portfolio) + _TICKER_CORRECTIONS = { + "APPL": "AAPL", "APPL.": "AAPL", "APPLE": "AAPL", + "GOOG": "GOOGL", "GOOGLE": "GOOGL", "ALPHABET": "GOOGL", + "AMAZON": "AMZN", "MICROSOFT": "MSFT", "NVIDIA": "NVDA", + "TESLA": "TSLA", "META": "META", "FACEBOOK": "META", + } + my_stock_match = re.search(r"my\s+([A-Za-z]{1,5})\s+stock", query, re.IGNORECASE) + if my_stock_match: + candidate = my_stock_match.group(1).upper() + corrected = _TICKER_CORRECTIONS.get(candidate, candidate) + return {**state, "query_type": "market"} + # --- Possessive portfolio queries — check BEFORE stock price keywords --- # "my share of AAPL" = portfolio holding, not stock price portfolio_ticker_kws = [ @@ -847,6 +861,9 @@ async def classify_node(state: AgentState) -> AgentState: ] if any(kw in query for kw in portfolio_ticker_kws): return {**state, "query_type": "performance"} + # "my AAPL position" = portfolio holding (regex: my + optional ticker + position) + if re.search(r"my\s+([A-Za-z]{1,5}\s+)?position", query, re.IGNORECASE): + return {**state, "query_type": "performance"} # --- Stock price / market quote queries — MUST route to market_data not portfolio --- # Check BEFORE performance/portfolio fallback. User asking about market price of a ticker.