diff --git a/agent/graph.py b/agent/graph.py index 63d09c15b..bc0da6e65 100644 --- a/agent/graph.py +++ b/agent/graph.py @@ -169,18 +169,44 @@ def _get_client() -> anthropic.Anthropic: def _extract_ticker(query: str, fallback: str = None) -> str | None: """ Extracts the most likely stock ticker from a query string. - Looks for 1-5 uppercase letters. + Handles typos (APPL→AAPL), company names (APPLE→AAPL), and "share of TICKER" phrasing. Returns fallback (default None) if no ticker found. Pass fallback='SPY' for market queries that require a symbol. """ - words = query.upper().split() + # Common misspellings and aliases + TICKER_CORRECTIONS = { + "APPL": "AAPL", + "APPL.": "AAPL", + "APPLE": "AAPL", + "GOOG": "GOOGL", + "GOOGLE": "GOOGL", + "ALPHABET": "GOOGL", + "AMAZON": "AMZN", + "MICROSOFT": "MSFT", + "NVIDIA": "NVDA", + "TESLA": "TSLA", + "META": "META", + "FACEBOOK": "META", + } + + message = query.strip() + msg_upper = message.upper() + + # Pattern: "share of TICKER" or "shares of TICKER" — check first + share_of_match = re.search(r"share[s]?\s+of\s+([A-Z]{1,5})", msg_upper) + if share_of_match: + candidate = share_of_match.group(1) + return TICKER_CORRECTIONS.get(candidate, candidate) + + words = msg_upper.split() known_tickers = {"AAPL", "MSFT", "NVDA", "TSLA", "GOOGL", "GOOG", "AMZN", - "META", "NFLX", "SPY", "QQQ", "BRK", "BRKB"} + "META", "NFLX", "SPY", "QQQ", "BRK", "BRKB", "VTI"} for word in words: clean = re.sub(r"[^A-Z]", "", word) - if clean in known_tickers: - return clean + corrected = TICKER_CORRECTIONS.get(clean, clean) + if corrected in known_tickers: + return corrected for word in words: clean = re.sub(r"[^A-Z]", "", word) @@ -202,7 +228,7 @@ def _extract_ticker(query: str, fallback: str = None) -> str | None: "COULD", "SHOULD", "MIGHT", "SHALL", "ONLY", "ALSO", "SINCE", "WHILE", "STILL", "AGAIN", "THOSE", "OTHER", }: - return clean + return TICKER_CORRECTIONS.get(clean, clean) return fallback @@ -777,8 +803,10 @@ async def classify_node(state: AgentState) -> AgentState: # Check BEFORE performance/portfolio fallback. User asking about market price of a ticker. stock_price_kws = [ "stock price", "share price", "price of", "current price", + "share of", "shares of", "price for", "stock for", "trading for", + "worth today", "per share", "what is aapl", "what is msft", "what is nvda", "what is tsla", - "what is googl", "what is amzn", "what is meta", + "what is googl", "what is amzn", "what is meta", "what is vti", "trading at", "price today", "how much is", "ticker", "quote", "what's the stock price", "whats the stock price", ]