
fix: accept both query and message field names in /chat endpoint — backwards compatible

Made-with: Cursor
pull/6453/head
Priyanka Punukollu 1 month ago
commit 6ff1dae0e9
main.py: 38 changed lines (+32 / -6)

@@ -16,6 +16,7 @@ from fastapi import FastAPI, Response, Depends, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import RedirectResponse, StreamingResponse, HTMLResponse, JSONResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from typing import Optional
 from pydantic import BaseModel
 from dotenv import load_dotenv
 import httpx
@@ -118,7 +119,9 @@ def log_error(error: Exception, context: dict) -> dict:
 class ChatRequest(BaseModel):
-    query: str
+    query: Optional[str] = None
+    message: Optional[str] = None
     session_id: Optional[str] = None
     history: list[dict] = []
     # Clients must echo back pending_write from the previous response when
     # the user is confirming (or cancelling) a write operation.
@@ -128,6 +131,10 @@ class ChatRequest(BaseModel):
     # on the caller's own portfolio data instead of the shared env-var token.
     bearer_token: str | None = None
 
+    def get_text(self) -> str:
+        """Accept either field name for backwards compatibility."""
+        return (self.query or self.message or "").strip()
+
 
 class FeedbackRequest(BaseModel):
     query: str
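
Taken together, the model changes above let either field carry the user's text, with an empty string falling out when neither is set. A standalone sketch of the resolution behavior (the class is trimmed to the two relevant fields; the assertions are illustrative, not part of the diff):

    from typing import Optional
    from pydantic import BaseModel

    class ChatRequest(BaseModel):
        query: Optional[str] = None
        message: Optional[str] = None

        def get_text(self) -> str:
            """Accept either field name for backwards compatibility."""
            return (self.query or self.message or "").strip()

    assert ChatRequest(query="hi").get_text() == "hi"            # old clients unchanged
    assert ChatRequest(message="hi").get_text() == "hi"          # new field name accepted
    assert ChatRequest(query="a", message="b").get_text() == "a" # query wins in the "or" chain
    assert ChatRequest().get_text() == ""                        # neither -> "", caught by the guard below

One wrinkle worth knowing: because "or" treats an empty string as falsy, an explicit query of "" falls through to message.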
@@ -141,6 +148,23 @@ async def chat(req: ChatRequest, gf_token: str = Depends(require_auth)):
     start_time = time.time()
     trace_id = str(uuid.uuid4())
+    user_text = req.get_text()
+    if not user_text:
+        return {
+            "error": "No message provided",
+            "response": "Please send a message.",
+            "confidence": 0.0,
+            "verified": False,
+            "latency_ms": int((time.time() - start_time) * 1000),
+            "trace_id": trace_id,
+            "tokens": {
+                "estimated_input": 0,
+                "estimated_output": 0,
+                "estimated_total": 0,
+                "estimated_cost_usd": 0.0,
+            },
+        }
+
     # Build conversation history preserving both user AND assistant turns so
     # Claude has full context for follow-up questions.
     history_messages = []
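
Before this change, a request without query failed Pydantic validation with a 422; now that both fields are optional, validation passes and the handler has to reject empty input itself, which is what the guard above does. Note it returns a structured payload with HTTP 200 (the handler returns a plain dict) rather than an error status. A hypothetical client call; the base URL and token are assumptions, not part of the diff:

    import httpx

    resp = httpx.post(
        "http://localhost:8000/chat",                 # assumed local dev server
        json={"message": "   "},                      # whitespace-only strips to ""
        headers={"Authorization": "Bearer <token>"},  # assumed bearer scheme per require_auth
    )
    data = resp.json()
    print(data["error"])     # "No message provided"
    print(data["response"])  # "Please send a message."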
@@ -179,7 +203,7 @@ async def chat(req: ChatRequest, gf_token: str = Depends(require_auth)):
         log_error(
             e,
             {
-                "message": req.query[:200],
+                "message": user_text[:200],
                 "session_id": None,
                 "query_type": "unknown",
             },
@@ -205,7 +229,7 @@ async def chat(req: ChatRequest, gf_token: str = Depends(require_auth)):
     cost_log.append({
         "timestamp": datetime.utcnow().isoformat(),
-        "query": req.query[:80],
+        "query": user_text[:80],
         "estimated_cost_usd": round(COST_PER_REQUEST_USD, 5),
         "latency_seconds": elapsed,
     })
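
The switch from req.query to user_text in the error-logging and cost-logging paths is not cosmetic: query is now Optional, so req.query is None whenever a client sends message instead, and slicing None raises. A quick illustration, reusing the trimmed ChatRequest sketch above:

    req = ChatRequest(message="hello")
    # req.query[:200]           # TypeError: 'NoneType' object is not subscriptable
    print(req.get_text()[:80])  # safe for either field name: "hello"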
@@ -328,6 +352,7 @@ async def chat_stream(req: ChatRequest, gf_token: str = Depends(require_auth)):
     Runs the full graph, then streams the final response word by word so
     the user sees output immediately rather than waiting for the full response.
     """
+    user_text = req.get_text()
     history_messages = []
     for m in req.history:
@@ -339,7 +364,7 @@ async def chat_stream(req: ChatRequest, gf_token: str = Depends(require_auth)):
             history_messages.append(AIMessage(content=content))
 
     initial_state: AgentState = {
-        "user_query": req.query,
+        "user_query": user_text,
         "messages": history_messages,
         "query_type": "",
         "portfolio_snapshot": {},
@@ -607,6 +632,7 @@ async def chat_steps(req: ChatRequest, gf_token: str = Depends(require_auth)):
     then a meta event with final metadata, then token events for the response.
     """
     start = time.time()
+    user_text = req.get_text()
     history_messages = []
     for m in req.history:
@@ -618,7 +644,7 @@ async def chat_steps(req: ChatRequest, gf_token: str = Depends(require_auth)):
             history_messages.append(AIMessage(content=content))
 
     initial_state: AgentState = {
-        "user_query": req.query,
+        "user_query": user_text,
         "messages": history_messages,
         "query_type": "",
         "portfolio_snapshot": {},
@@ -680,7 +706,7 @@ async def chat_steps(req: ChatRequest, gf_token: str = Depends(require_auth)):
     cost_log.append({
         "timestamp": datetime.utcnow().isoformat(),
-        "query": req.query[:80],
+        "query": user_text[:80],
         "estimated_cost_usd": round(COST_PER_REQUEST_USD, 5),
         "latency_seconds": elapsed,
     })
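
End to end, the old and new request shapes should now behave identically against /chat. A hedged smoke test; the base URL and token are assumptions about the deployment:

    import httpx

    BASE = "http://localhost:8000"                 # assumption: local dev server
    HEADERS = {"Authorization": "Bearer <token>"}  # assumption: bearer auth

    old_style = httpx.post(f"{BASE}/chat", json={"query": "What is my largest holding?"}, headers=HEADERS)
    new_style = httpx.post(f"{BASE}/chat", json={"message": "What is my largest holding?"}, headers=HEADERS)

    # Both field names should produce the same successful result.
    assert old_style.status_code == new_style.status_code == 200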
