"""LangGraph chat pipeline that routes between general chat and code tasks.

The graph classifies the latest user message with keyword heuristics, then
either answers it as general chat or runs a generate -> classify -> repair
loop for coding requests.
"""

import re
from typing import Annotated, Any, Iterable, Iterator, Literal, Optional, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages

from scripts.load_model import get_model
from scripts.prompts import (
    GENERAL_CHAT_SYSTEM_PROMPT,
    CODE_AGENT_SYSTEM_PROMPT,
    REPAIR_SYSTEM_PROMPT,
)

# =========================================================
# LLM
# =========================================================

llm = get_model()

# Number of repair passes after which route_output stops retrying.
MAX_REPAIR_ATTEMPTS = 2

# =========================================================
# STATE
# =========================================================


class ChatState(TypedDict):
    """State shared by every node of the graph."""

    # Conversation messages; merged through the add_messages reducer.
    user_input: Annotated[list[BaseMessage], add_messages]
    # Set by classify_intent.
    intent: Optional[Literal["coding", "general_chat"]]
    # Set by classify_task (coding branch only).
    task_type: Optional[Literal["code_generation", "debugging", "explanation"]]
    # Set by detect_language (coding branch only).
    language: Optional[Literal["python", "javascript", "cpp", "sql", "java", "unknown"]]
    # Text of the latest model response.
    generated_output: Optional[str]
    # Set by classify_output: whether generated_output looks like code.
    is_code: Optional[bool]
    # Number of repair passes performed so far.
    retry_count: int


# =========================================================
# HELPERS
# =========================================================


def _content_to_text(content: Any) -> str:
    """Flatten message content into a plain string.

    Message content may be a plain string or a list of content blocks
    (strings or dicts carrying a ``"text"`` entry). Non-text blocks are
    skipped so that callers can always apply ``str`` methods to the result.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return "" if content is None else str(content)


def _mentions_any(text: str, keywords: Iterable[str]) -> bool:
    """Return True if any keyword occurs in *text* at the start of a word.

    A negative lookbehind is used instead of ``\\b`` so that keywords which
    begin with punctuation (``#include``) still match. Requiring a word start
    stops hits buried inside unrelated words ("capital" no longer matches
    "api") while still accepting inflected forms ("functions", "debugging").
    """
    return any(
        re.search(rf"(?<!\w){re.escape(keyword)}", text) for keyword in keywords
    )


def get_last_user_message(state: ChatState) -> str:
    """Return the text of the last message in ``user_input`` ('' if none)."""
    messages = state["user_input"]
    if not messages:
        return ""
    return _content_to_text(messages[-1].content)


# =========================================================
# INTENT CLASSIFIER
# =========================================================


def classify_intent(state: ChatState) -> dict:
    """Label the latest message as ``"coding"`` or ``"general_chat"``.

    Returns a partial state update ``{"intent": ...}``.
    """
    query = get_last_user_message(state).lower()

    coding_keywords = [
        "code", "python", "javascript", "java", "c++", "sql",
        "bug", "debug", "fix", "function", "algorithm",
        "class", "implement", "build", "api", "query",
        "optimize", "refactor", "script", "program",
    ]

    intent = "coding" if _mentions_any(query, coding_keywords) else "general_chat"
    return {"intent": intent}


def route_intent(state: ChatState) -> str:
    """Pick the branch name that follows ``classify_intent``."""
    if state["intent"] == "general_chat":
        return "general_chat"
    return "coding"


# =========================================================
# GENERAL CHAT
# =========================================================


def handle_general_chat(state: ChatState) -> dict:
    """Answer the latest message with the general-chat system prompt.

    Only the latest message is sent to the model. Returns a partial state
    update ``{"generated_output": ...}``.
    """
    query = get_last_user_message(state)

    messages = [
        SystemMessage(content=GENERAL_CHAT_SYSTEM_PROMPT),
        HumanMessage(content=query),
    ]
    response = llm.invoke(messages)

    return {"generated_output": _content_to_text(response.content)}


# =========================================================
# TASK CLASSIFIER
# =========================================================


def classify_task(state: ChatState) -> dict:
    """Classify a coding request as debugging, explanation or code generation.

    Debugging keywords take precedence over explanation keywords; anything
    else is treated as code generation. Returns ``{"task_type": ...}``.
    """
    query = get_last_user_message(state).lower()

    debugging_keywords = [
        "fix", "debug", "error", "broken", "issue",
        "not working", "exception", "traceback",
    ]
    explanation_keywords = [
        "explain", "what is", "how does", "why", "difference between",
    ]

    if _mentions_any(query, debugging_keywords):
        task = "debugging"
    elif _mentions_any(query, explanation_keywords):
        task = "explanation"
    else:
        task = "code_generation"

    return {"task_type": task}


# =========================================================
# LANGUAGE DETECTOR
# =========================================================


def detect_language(state: ChatState) -> dict:
    """Guess the programming language the latest message refers to.

    Languages are tested in a fixed order (python, javascript, cpp, java,
    sql) and the first match wins. Returns ``{"language": ...}`` with
    ``"unknown"`` when nothing matches.
    """
    query = get_last_user_message(state).lower()

    # "js" must end a word so that "json" does not select JavaScript, while
    # "js", "node.js" and "nodejs" still do.
    mentions_js = re.search(r"js(?!\w)", query) is not None

    if _mentions_any(query, ["python", "def ", "import ", "print("]):
        lang = "python"
    elif mentions_js or _mentions_any(
        query, ["javascript", "function ", "console.log", "const ", "let "]
    ):
        lang = "javascript"
    elif _mentions_any(query, ["c++", "#include", "std::", "cout"]):
        lang = "cpp"
    elif _mentions_any(query, ["java", "public class", "system.out"]):
        lang = "java"
    elif _mentions_any(
        query, ["sql", "select ", "insert ", "update ", "delete ", "join "]
    ):
        lang = "sql"
    else:
        lang = "unknown"

    return {"language": lang}


# =========================================================
# GENERATOR
# =========================================================


def generate_code(state: ChatState) -> dict:
    """Ask the model to fulfil the coding request.

    The prompt carries the detected task type and language alongside the
    user's request. Returns ``{"generated_output": ...}``.
    """
    query = get_last_user_message(state)
    task = state["task_type"]
    language = state["language"]

    messages = [
        SystemMessage(content=CODE_AGENT_SYSTEM_PROMPT),
        HumanMessage(
            content=(
                f"Task Type: {task}\n"
                f"Requested Language: {language}\n"
                "\n"
                "User Request:\n"
                f"{query}\n"
            )
        ),
    ]
    response = llm.invoke(messages)

    return {"generated_output": _content_to_text(response.content)}


# =========================================================
# OUTPUT CLASSIFIER
# =========================================================


def classify_output(state: ChatState) -> dict:
    """Flag whether ``generated_output`` contains any known code marker.

    Returns ``{"is_code": bool}``. A missing output is treated as empty text.
    """
    output = (state["generated_output"] or "").lower()

    code_markers = [
        "def ", "class ", "function ", "const ", "let ",
        "#include", "std::", "public class", "system.out",
        "select ", "insert ", "update ", "delete ",
        "console.log", "print(", "fn ",
    ]

    # Plain substring search is kept here: the text is model output, where
    # markers may directly follow other identifier characters.
    is_code = any(marker in output for marker in code_markers)
    return {"is_code": is_code}


# =========================================================
# ROUTER
# =========================================================


def route_output(state: ChatState) -> str:
    """Decide whether to finish or send the output through ``repair``.

    Finishes when the task is an explanation, when the output already looks
    like code, or when the repair budget is exhausted.
    """
    if state["task_type"] == "explanation":
        return "final"
    if state["is_code"]:
        return "final"
    if state["retry_count"] >= MAX_REPAIR_ATTEMPTS:
        return "final"
    return "repair"


# =========================================================
# REPAIR
# =========================================================


def repair_code(state: ChatState) -> dict:
    """Ask the model to turn a non-code response into executable code.

    Returns the new ``generated_output`` and ``retry_count`` incremented by
    one.
    """
    bad_output = state["generated_output"] or ""
    language = state["language"]

    messages = [
        SystemMessage(content=REPAIR_SYSTEM_PROMPT),
        HumanMessage(
            content=(
                "The previous response was expected to be executable code "
                "but was not.\n"
                "\n"
                f"Requested language: {language}\n"
                "\n"
                "Bad output:\n"
                f"{bad_output}\n"
                "\n"
                "Return ONLY executable code.\n"
                "No explanation.\n"
                "No markdown.\n"
            )
        ),
    ]
    response = llm.invoke(messages)

    return {
        "generated_output": _content_to_text(response.content),
        "retry_count": state["retry_count"] + 1,
    }


# =========================================================
# GRAPH
# =========================================================

checkpointer = MemorySaver()

builder = StateGraph(ChatState)

builder.add_node("classify_intent", classify_intent)
builder.add_node("general_chat", handle_general_chat)
builder.add_node("classify_task", classify_task)
builder.add_node("detect_language", detect_language)
builder.add_node("generate_code", generate_code)
builder.add_node("classify_output", classify_output)
builder.add_node("repair", repair_code)

builder.set_entry_point("classify_intent")

builder.add_conditional_edges(
    "classify_intent",
    route_intent,
    {
        "general_chat": "general_chat",
        "coding": "classify_task",
    },
)

builder.add_edge("general_chat", END)

builder.add_edge("classify_task", "detect_language")
builder.add_edge("detect_language", "generate_code")
builder.add_edge("generate_code", "classify_output")

builder.add_conditional_edges(
    "classify_output",
    route_output,
    {
        "final": END,
        "repair": "repair",
    },
)

# A repaired response is re-checked by classify_output before finishing.
builder.add_edge("repair", "classify_output")

graph = builder.compile(checkpointer=checkpointer)


# =========================================================
# STREAM
# =========================================================


def stream_chat_response(user_message: str, thread_id: str) -> Iterator[Any]:
    """Run the graph for one user message and yield its streamed chunks.

    Args:
        user_message: Text of the new user message.
        thread_id: Identifier passed to the graph as the configurable
            ``thread_id``.

    Yields:
        Whatever ``graph.stream`` produces in ``stream_mode="messages"``.
    """
    config = {"configurable": {"thread_id": thread_id}}

    # Every per-turn field is given an explicit starting value; in
    # particular retry_count restarts at 0 for each call.
    initial_state = {
        "user_input": [HumanMessage(content=user_message)],
        "intent": None,
        "task_type": None,
        "language": None,
        "generated_output": None,
        "is_code": None,
        "retry_count": 0,
    }

    yield from graph.stream(initial_state, config=config, stream_mode="messages")


def get_chat_metadata(thread_id: str) -> dict:
    """Return the classification fields stored for *thread_id*.

    Fields absent from the stored state are reported as ``None``.
    """
    config = {"configurable": {"thread_id": thread_id}}
    values = graph.get_state(config).values

    return {
        "intent": values.get("intent"),
        "task_type": values.get("task_type"),
        "language": values.get("language"),
        "is_code": values.get("is_code"),
        "retry_count": values.get("retry_count"),
    }