Skip to content

11 自定义中间件

上一篇文章我们学习了LangChain的内置中间件,它们能覆盖大部分常见场景。但在实际项目中,你往往需要更灵活的控制逻辑,比如:

  • 根据用户身份动态修改系统提示词
  • 记录每次工具调用的耗时
  • 在特定条件下提前终止Agent
  • 动态选择使用哪个模型

这时候就需要自定义中间件了。

一、两种Hook类型

LangChain的中间件提供了两种Hook(钩子)类型:

1.1 Node-style Hook(节点式)

在特定的执行点顺序运行,适合做日志记录、数据校验、状态更新等。

| Hook | 何时运行 |
| --- | --- |
| before_agent | Agent开始前运行一次 |
| before_model | 每次模型调用前运行 |
| after_model | 每次模型调用后运行 |
| after_agent | Agent结束后运行一次 |

1.2 Wrap-style Hook(包裹式)

包裹在每次调用外面,可以控制调用次数(0次=短路,1次=正常,多次=重试),适合做重试、缓存、格式转换等。

| Hook | 何时运行 |
| --- | --- |
| wrap_model_call | 包裹每次模型调用 |
| wrap_tool_call | 包裹每次工具调用 |

二、装饰器方式

对于简单的单Hook中间件,用装饰器方式最简洁。

2.1 before_model - 模型调用前

在每次模型调用前执行,比如记录日志:

python
from langchain.agents.middleware import before_model, AgentState
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from typing import Any


@before_model
def log_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Print the current message count before each model call.

    Returns:
        None, which leaves the agent state unchanged.
    """
    message_count = len(state["messages"])
    print(f"即将调用模型,当前消息数: {message_count}")
    return None


# Register the hook through `middleware`; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[log_before_model],
)

2.2 after_model - 模型调用后

在每次模型调用后执行,比如记录模型输出:

python
from langchain.agents.middleware import after_model, AgentState
from langgraph.runtime import Runtime
from typing import Any


@after_model
def log_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Print the content of the newest message after each model call.

    Returns:
        None, which leaves the agent state unchanged.
    """
    latest_message = state["messages"][-1]
    print(f"模型返回: {latest_message.content}")
    return None


# Register the hook through `middleware`; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[log_after_model],
)

2.3 wrap_model_call - 包裹模型调用

包裹模型调用,可以控制是否调用、调用几次,适合做重试逻辑:

python
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.agents import create_agent
from typing import Callable


@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Call the model with up to three attempts in total.

    The first two failures are caught and reported with ``print``; the third
    attempt runs outside any ``try`` so its exception reaches the caller.

    Args:
        request: The model request forwarded unchanged to ``handler``.
        handler: Callable that performs the actual model call.

    Returns:
        The response of the first successful attempt.

    Raises:
        Exception: Whatever the final attempt raises.
    """
    max_attempts = 3
    for attempt in range(1, max_attempts):
        try:
            return handler(request)
        except Exception:
            print(f"模型调用失败(第{attempt}次),重试中...")
    # Final attempt: deliberately unguarded so a failure propagates.
    return handler(request)


# Register the retry wrapper through `middleware`; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[retry_model],
)

2.4 组合多个装饰器

一个中间件列表里可以放多个装饰器:

python
from langchain.agents import create_agent
from langchain.agents.middleware import before_model, after_model, wrap_model_call, AgentState, ModelRequest, ModelResponse
from langgraph.runtime import Runtime
from typing import Any, Callable


@before_model
def log_before(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Print the message count before each model call; the state is not modified."""
    message_count = len(state["messages"])
    print(f"[开始] 消息数: {message_count}")
    return None


@after_model
def log_after(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Print the first 50 items of the newest message's content; the state is not modified."""
    preview = state["messages"][-1].content[:50]
    print(f"[结束] 模型输出: {preview}...")
    return None


@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Call the model with up to three attempts, printing each caught failure.

    The first two failures are caught and reported; the third attempt is not
    guarded, so its exception reaches the caller.
    """
    for attempt_number in (1, 2):
        try:
            return handler(request)
        except Exception as e:
            print(f"重试 {attempt_number}/3: {e}")
    # Third and last attempt: a failure here propagates.
    return handler(request)


# Several decorator-built middlewares can share one list; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[log_before, retry_model, log_after],
)

三、类方式

当你的中间件比较复杂,需要多个Hook配合、有配置参数、或者需要同时支持同步和异步时,用类方式更合适。

3.1 基本类中间件

python
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from typing import Any


class LoggingMiddleware(AgentMiddleware):
    """Logging middleware that prints a summary around every model call.

    The async hooks delegate to their sync counterparts so the same log lines
    are printed on async code paths instead of being silently skipped.
    """

    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Print how many messages are in the state before the model is called."""
        print(f"[日志] 即将调用模型,消息数: {len(state['messages'])}")
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Print the first 50 items of the newest message's content."""
        print(f"[日志] 模型返回: {state['messages'][-1].content[:50]}...")
        return None

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Async variant of ``before_model``, used in async scenarios such as ``astream``."""
        return self.before_model(state, runtime)

    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Async variant of ``after_model``."""
        return self.after_model(state, runtime)


# Class-based middleware is passed as an instance; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[LoggingMiddleware()],
)

3.2 带配置的中间件

通过__init__传入配置参数:

python
from typing import Any

from langchain.agents.middleware import AgentMiddleware, AgentState, hook_config
from langchain.messages import AIMessage
from langgraph.runtime import Runtime


class MessageLimitMiddleware(AgentMiddleware):
    """Message-limit middleware: ends the agent once the limit is reached."""

    def __init__(self, max_messages: int = 50):
        """Store the limit.

        Args:
            max_messages: Number of messages at which the agent is ended.
        """
        super().__init__()
        self.max_messages = max_messages

    # The jump target must be declared on the hook, otherwise the "jump_to"
    # key returned below has no registered destination.
    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """End the run with a notice when the message count reaches the limit.

        Returns:
            A state update that appends a notice message and jumps to the end
            of the agent when the limit is reached, otherwise None.
        """
        if len(state["messages"]) < self.max_messages:
            return None
        return {
            "messages": [AIMessage("对话轮次已达上限。")],
            "jump_to": "end",
        }


# The limit is configured through the constructor; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[MessageLimitMiddleware(max_messages=20)],
)

3.3 重试中间件(类方式)

python
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
from typing import Callable


class RetryMiddleware(AgentMiddleware):
    """Model-call retry middleware with a configurable attempt budget."""

    def __init__(self, max_retries: int = 3):
        """Store the attempt budget.

        Args:
            max_retries: Total number of attempts per model call; must be >= 1.

        Raises:
            ValueError: If ``max_retries`` is less than 1. Without this check
                the model would never be called and ``None`` would be returned.
        """
        super().__init__()
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        self.max_retries = max_retries

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Call the model, retrying on any exception until the budget is spent.

        Args:
            request: The model request forwarded unchanged to ``handler``.
            handler: Callable that performs the actual model call.

        Returns:
            The response of the first successful attempt.

        Raises:
            Exception: Whatever the final attempt raises.
        """
        for attempt in range(1, self.max_retries):
            try:
                return handler(request)
            except Exception as e:
                print(f"重试 {attempt}/{self.max_retries}: {e}")
        # Final attempt: deliberately unguarded so a failure propagates.
        return handler(request)


# The attempt budget is configured through the constructor; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[RetryMiddleware(max_retries=5)],
)

四、更新状态

中间件可以修改Agent的状态,比如记录调用次数、追踪token使用量等。

4.1 Node-style Hook更新状态

直接返回一个字典,字典的key-value会合并到Agent状态中:

python
from langchain.agents.middleware import after_model, AgentState
from langgraph.runtime import Runtime
from typing import Any
from typing_extensions import NotRequired


class TrackingState(AgentState):
    """Agent state extended with an optional ``model_call_count`` integer."""
    model_call_count: NotRequired[int]


@after_model(state_schema=TrackingState)
def increment_counter(state: TrackingState, runtime: Runtime) -> dict[str, Any] | None:
    """Return a state update that bumps ``model_call_count`` by one.

    A missing counter is treated as 0.
    """
    previous_count = state.get("model_call_count", 0)
    return {"model_call_count": previous_count + 1}

4.2 Wrap-style Hook更新状态

返回ExtendedModelResponse配合Command来更新状态:

python
from typing import Callable
from langchain.agents.middleware import (
    wrap_model_call, ModelRequest, ModelResponse,
    AgentState, ExtendedModelResponse
)
from langgraph.types import Command
from typing_extensions import NotRequired


class UsageTrackingState(AgentState):
    """Agent state extended with an optional ``last_model_call_tokens`` integer."""
    last_model_call_tokens: NotRequired[int]


@wrap_model_call(state_schema=UsageTrackingState)
def track_usage(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
    """Run the model call and record its token usage in the agent state.

    Token usage is summed from the ``usage_metadata["total_tokens"]`` of the
    messages in the response instead of being a fixed placeholder value.
    Messages without usage metadata contribute 0.

    Args:
        request: The model request forwarded unchanged to ``handler``.
        handler: Callable that performs the actual model call.

    Returns:
        The model response wrapped together with a ``Command`` that sets
        ``last_model_call_tokens``.
    """
    response = handler(request)
    # NOTE(review): assumes ModelResponse exposes its messages as `.result`
    # and that providers fill `usage_metadata` - confirm for the model in use.
    total_tokens = 0
    for message in getattr(response, "result", None) or []:
        usage = getattr(message, "usage_metadata", None)
        if usage:
            total_tokens += usage.get("total_tokens", 0)
    return ExtendedModelResponse(
        model_response=response,
        command=Command(update={"last_model_call_tokens": total_tokens}),
    )

五、跳转控制

中间件可以在特定条件下让Agent提前结束、跳到工具节点、或跳到模型节点。

可用的跳转目标:

| 目标 | 说明 |
| --- | --- |
| "end" | 跳到Agent执行结束 |
| "tools" | 跳到工具节点 |
| "model" | 跳到模型节点 |

5.1 提前结束

当检测到敏感内容时,提前结束Agent:

python
from langchain.agents.middleware import after_model, AgentState
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any


@after_model(can_jump_to=["end"])
def check_blocked(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """End the agent early when the newest message contains ``"BLOCKED"``.

    Returns:
        A state update with a refusal message and a jump to the end when the
        marker is present, otherwise None.
    """
    newest = state["messages"][-1]
    if "BLOCKED" not in newest.content:
        return None
    refusal = AIMessage("我无法回答这个问题。")
    return {"messages": [refusal], "jump_to": "end"}


# Register the content check through `middleware`; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[check_blocked],
)

5.2 类方式的跳转

python
from langchain.agents.middleware import AgentMiddleware, hook_config, AgentState
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any


class BlockedContentMiddleware(AgentMiddleware):
    """Class-based middleware that ends the agent when output contains ``"BLOCKED"``."""

    @hook_config(can_jump_to=["end"])
    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Return a refusal message plus a jump to the end when the marker is found, else None."""
        newest = state["messages"][-1]
        if "BLOCKED" not in newest.content:
            return None
        refusal = AIMessage("我无法回答这个问题。")
        return {"messages": [refusal], "jump_to": "end"}

六、实用示例

6.1 动态修改系统提示词

根据用户身份动态注入上下文信息:

python
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.agents import create_agent
from typing import Callable


@wrap_model_call
def inject_user_context(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Swap in a role-specific system prompt before calling the model.

    ``user_role`` is read from ``request.runtime.context``. The context may be
    a dict, an attribute-style object, or ``None``; a missing or unknown role
    falls back to the guest prompt instead of raising.

    Args:
        request: The model request whose system prompt is overridden.
        handler: Callable that performs the actual model call.

    Returns:
        The handler's response for the modified request.
    """
    context = request.runtime.context
    if isinstance(context, dict):
        user_role = context.get("user_role", "guest")
    else:
        # getattr also covers context being None.
        user_role = getattr(context, "user_role", "guest")

    role_prompt = {
        "admin": "你是一个管理员助手,可以执行所有操作。",
        "user": "你是一个普通用户助手,只能查询信息。",
        "guest": "你是一个访客助手,只能查看公开信息。",
    }

    new_prompt = role_prompt.get(user_role, role_prompt["guest"])
    return handler(request.override(system_prompt=new_prompt))


# Register the prompt injector through `middleware`; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[inject_user_context],
)

6.2 动态选择工具

根据用户问题动态过滤工具,减少token消耗:

python
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.agents import create_agent
from typing import Callable


@wrap_model_call
def select_relevant_tools(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Narrow the tool list based on keywords in the latest user question.

    The newest message of type ``"human"`` is used rather than simply the last
    message, because the last message may be a tool result or a model reply.
    If there is no human message, all tools are kept.

    Args:
        request: The model request whose tool list may be overridden.
        handler: Callable that performs the actual model call.

    Returns:
        The handler's response for the request with the selected tools.
    """
    question = ""
    for message in reversed(request.state["messages"]):
        if getattr(message, "type", None) == "human":
            content = message.content
            # Content may be a list of blocks; stringify it for keyword matching.
            question = content if isinstance(content, str) else str(content)
            break

    # Simple keyword matching; real projects can use more elaborate logic.
    if "天气" in question or "气温" in question:
        tools = [get_weather]  # weather-related tools only
    elif "文件" in question or "读取" in question:
        tools = [read_file, write_file]  # file-related tools only
    else:
        tools = request.tools  # keep every tool

    return handler(request.override(tools=tools))


# The full tool set is registered here; the middleware narrows it per model call.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[get_weather, read_file, write_file, send_email, ...],
    middleware=[select_relevant_tools],
)

6.3 工具调用监控

记录每个工具调用的耗时:

python
import time
from langchain.agents.middleware import wrap_tool_call, ToolRequest, ToolResponse
from langchain.agents import create_agent
from typing import Callable


@wrap_tool_call
def monitor_tool_call(
    request: ToolRequest,
    handler: Callable[[ToolRequest], ToolResponse],
) -> ToolResponse:
    """Print how long each tool call takes, on success and on failure.

    Args:
        request: The tool-call request forwarded unchanged to ``handler``.
        handler: Callable that executes the tool.

    Returns:
        The handler's result, unchanged.

    Raises:
        Exception: Re-raises whatever the tool call raises, after logging it.
    """
    # NOTE(review): LangChain's middleware reference names the request type
    # ToolCallRequest and exposes the tool name as request.tool_call["name"];
    # confirm that ToolRequest/ToolResponse exist in the installed version.
    tool_name = request.tool_call["name"]
    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    started = time.perf_counter()
    print(f"[监控] 开始调用工具: {tool_name}")

    try:
        result = handler(request)
    except Exception as e:
        elapsed = time.perf_counter() - started
        print(f"[监控] 工具 {tool_name} 调用失败,耗时: {elapsed:.2f}秒,错误: {e}")
        raise

    elapsed = time.perf_counter() - started
    print(f"[监控] 工具 {tool_name} 调用成功,耗时: {elapsed:.2f}秒")
    return result


# Register the tool monitor through `middleware`; `tools=[...]` stands in for the real tool list.
agent = create_agent(
    model="deepseek-v4-flash",
    tools=[...],
    middleware=[monitor_tool_call],
)

七、最佳实践

  1. 保持单一职责:每个中间件只做一件事,不要把日志、重试、脱敏全塞到一个中间件里
  2. 选择合适的Hook类型
    • 顺序逻辑(日志、校验)用Node-style
    • 控制流逻辑(重试、缓存)用Wrap-style
  3. 注意执行顺序:把关键的中间件放在列表前面
  4. 处理好错误:不要让中间件的错误导致整个Agent崩溃
  5. 同步和异步都要实现:如果中间件需要在 invoke 和 ainvoke 中都能用,记得实现同步和异步两个版本

八、总结

自定义中间件让你可以在Agent执行的各个关键节点插入自己的逻辑:

  • 装饰器方式:简单直接,适合单Hook场景
  • 类方式:功能强大,适合多Hook、有配置、需要异步的场景
  • 六种Hook:before_agent、before_model、after_model、after_agent、wrap_model_call、wrap_tool_call
  • 状态更新:Node-style返回字典,Wrap-style返回ExtendedModelResponse + Command
  • 跳转控制:通过jump_to可以提前结束、跳到工具节点或模型节点

掌握中间件,你就掌握了LangChain中最灵活的扩展机制。