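11 Custom Middleware
The previous article covered LangChain's built-in middleware, which handles most common scenarios. Real projects, however, often need more flexible control logic, for example:
- Modifying the system prompt dynamically based on the user's identity
- Recording how long each tool call takes
- Terminating the agent early under specific conditions
- Choosing which model to use at runtime
That is where custom middleware comes in.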
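1. Two Hook Types
LangChain middleware offers two types of hooks:
1.1 Node-style hooks
These run sequentially at specific points in the execution, which makes them a good fit for logging, data validation, and state updates.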
| Hook | When it runs |
|---|---|
| before_agent | Once, before the agent starts |
| before_model | Before every model call |
| after_model | After every model call |
| after_agent | Once, after the agent finishes |
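To make the timing in the table concrete, here is a minimal, framework-free sketch of an agent loop. It is a toy model and not LangChain's implementation: `run_toy_agent`, the `tool:` reply convention, and the list-of-strings state are all assumptions made for illustration.

```python
from typing import Callable

Hook = Callable[[dict], None]


def run_toy_agent(hooks: dict[str, Hook], model_replies: list[str]) -> list[str]:
    """Simulate one agent run; each entry in model_replies is one model call."""
    trace: list[str] = []
    state: dict = {"messages": ["user: what's the weather?"]}

    def fire(name: str) -> None:
        # Record the hook point, then run the user's hook if one is registered.
        trace.append(name)
        hook = hooks.get(name)
        if hook is not None:
            hook(state)

    fire("before_agent")  # once per run
    for reply in model_replies:
        fire("before_model")  # before every model call
        state["messages"].append(reply)
        fire("after_model")  # after every model call
        if not reply.startswith("tool:"):
            break  # a plain answer (no tool request) ends the loop
    fire("after_agent")  # once per run
    return trace


# A hypothetical before_model hook that records the message count it sees.
message_counts: list[int] = []
trace = run_toy_agent(
    {"before_model": lambda state: message_counts.append(len(state["messages"]))},
    ["tool: get_weather", "It is sunny."],
)
print(trace)
print(message_counts)
```

With two model calls, the `before_model`/`after_model` pair fires twice, while `before_agent` and `after_agent` each fire once.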
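1.2 Wrap-style hooks
These wrap each call, so they control how many times the underlying handler runs (zero times = short-circuit, once = normal flow, several times = retry). They are a good fit for retries, caching, and format conversion.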
| Hook | When it runs |
|---|---|
| wrap_model_call | Wraps every model call |
| wrap_tool_call | Wraps every tool call |
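The "zero, one, or many calls" idea can be sketched without any framework. The snippet below is only an illustration of the wrapping pattern: `with_cache`, `with_retry`, and `flaky_model` are hypothetical names, and requests are plain strings rather than LangChain's request objects.

```python
from typing import Callable

Handler = Callable[[str], str]
Wrapper = Callable[[str, Handler], str]


def with_cache(cache: dict[str, str]) -> Wrapper:
    def wrapper(request: str, handler: Handler) -> str:
        if request in cache:
            return cache[request]  # zero handler calls: short-circuit
        cache[request] = handler(request)  # one handler call: normal flow
        return cache[request]

    return wrapper


def with_retry(max_attempts: int) -> Wrapper:
    def wrapper(request: str, handler: Handler) -> str:
        for attempt in range(1, max_attempts + 1):
            try:
                return handler(request)  # may run several times: retry
            except RuntimeError:
                if attempt == max_attempts:
                    raise  # out of attempts: surface the last error
        raise ValueError("max_attempts must be at least 1")

    return wrapper


calls = {"count": 0}


def flaky_model(request: str) -> str:
    """Stand-in for a model call that fails on its first invocation."""
    calls["count"] += 1
    if calls["count"] == 1:
        raise RuntimeError("transient failure")
    return f"echo: {request}"


cached = with_cache({})
retry = with_retry(3)
# The cache wraps the retry logic, which in turn wraps the model call.
first = cached("hello", lambda req: retry(req, flaky_model))
second = cached("hello", lambda req: retry(req, flaky_model))
print(first, second, calls["count"])
```

The first request reaches the model twice (one failure, one success); the second is answered from the cache and never reaches the handler.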
2. The Decorator Approach
For simple middleware with a single hook, decorators are the most concise option.
2.1 before_model: before the model call
Runs before every model call, for example to write a log line:
```python
from typing import Any

from langchain.agents import create_agent
from langchain.agents.middleware import AgentState, before_model
from langgraph.runtime import Runtime


@before_model
def log_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    print(f"About to call the model; current message count: {len(state['messages'])}")
    return None  # Returning None leaves the state unchanged


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[log_before_model],
    tools=[...],
)
```
2.2 after_model: after the model call
Runs after every model call, for example to log the model's output:
```python
from typing import Any

from langchain.agents import create_agent
from langchain.agents.middleware import AgentState, after_model
from langgraph.runtime import Runtime


@after_model
def log_after_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    # The last message in the state is the response the model just produced
    print(f"Model returned: {state['messages'][-1].content}")
    return None


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[log_after_model],
    tools=[...],
)
```
2.3 wrap_model_call: wrapping the model call
Wraps the model call, so it decides whether the call happens and how many times. This suits retry logic:
```python
from typing import Callable

from langchain.agents import create_agent
from langchain.agents.middleware import ModelRequest, ModelResponse, wrap_model_call


@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Try the model call up to 3 times in total before giving up."""
    for attempt in range(3):
        try:
            return handler(request)
        except Exception:
            if attempt == 2:
                raise  # Last attempt failed: re-raise the error
            print(f"Model call failed (attempt {attempt + 1}), retrying...")
    # Never reached at runtime; present only so type checkers see every path end
    raise AssertionError("unreachable")


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[retry_model],
    tools=[...],
)
```
2.4 Combining multiple decorators
A single middleware list can hold several decorated functions:
```python
from typing import Any, Callable

from langchain.agents import create_agent
from langchain.agents.middleware import (
    AgentState,
    ModelRequest,
    ModelResponse,
    after_model,
    before_model,
    wrap_model_call,
)
from langgraph.runtime import Runtime


@before_model
def log_before(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    print(f"[start] message count: {len(state['messages'])}")
    return None


@after_model
def log_after(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    print(f"[end] model output: {state['messages'][-1].content[:50]}...")
    return None


@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    for attempt in range(3):
        try:
            return handler(request)
        except Exception as e:
            if attempt == 2:
                raise
            print(f"Retry {attempt + 1}/3: {e}")


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[log_before, retry_model, log_after],
    tools=[...],
)
```
3. The Class Approach
When middleware gets more complex (several hooks working together, configuration parameters, or support for both sync and async execution), a class is the better fit.
3.1 A basic class-based middleware
```python
from typing import Any

from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime


class LoggingMiddleware(AgentMiddleware):
    """Logging middleware: records the input and output of each model call."""

    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        print(f"[log] about to call the model; message count: {len(state['messages'])}")
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        print(f"[log] model returned: {state['messages'][-1].content[:50]}...")
        return None

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # Async version, used on async paths such as ainvoke and astream
        return None

    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        return None


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[LoggingMiddleware()],
    tools=[...],
)
```
3.2 Middleware with configuration
Pass configuration parameters through `__init__`:
```python
from typing import Any

from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, AgentState, hook_config
from langchain.messages import AIMessage
from langgraph.runtime import Runtime


class MessageLimitMiddleware(AgentMiddleware):
    """Message limit middleware: ends the run once the limit is reached."""

    def __init__(self, max_messages: int = 50):
        super().__init__()
        self.max_messages = max_messages

    # Declare the jump target, as in the jump examples in section 5
    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if len(state["messages"]) >= self.max_messages:
            return {
                "messages": [AIMessage("The conversation has reached its message limit.")],
                "jump_to": "end",  # Jump to the end of the agent run
            }
        return None


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[MessageLimitMiddleware(max_messages=20)],
    tools=[...],
)
```
3.3 Retry middleware (class-based)
```python
from typing import Callable

from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse


class RetryMiddleware(AgentMiddleware):
    """Retry middleware for model calls."""

    def __init__(self, max_retries: int = 3):
        super().__init__()
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        # Note: this is the total number of attempts, including the first one
        self.max_retries = max_retries

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        for attempt in range(self.max_retries):
            try:
                return handler(request)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise  # Out of attempts: re-raise the last error
                print(f"Retry {attempt + 1}/{self.max_retries}: {e}")


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[RetryMiddleware(max_retries=5)],
    tools=[...],
)
```
4. Updating State
Middleware can modify the agent's state, for example to count calls or track token usage.
4.1 Updating state from a node-style hook
Return a dictionary; its key-value pairs are merged into the agent state:
```python
from typing import Any

from langchain.agents.middleware import AgentState, after_model
from langgraph.runtime import Runtime
from typing_extensions import NotRequired


class TrackingState(AgentState):
    # Extra state field; NotRequired because it is absent before the first call
    model_call_count: NotRequired[int]


@after_model(state_schema=TrackingState)
def increment_counter(state: TrackingState, runtime: Runtime) -> dict[str, Any] | None:
    return {"model_call_count": state.get("model_call_count", 0) + 1}
```
4.2 Updating state from a wrap-style hook
Return an `ExtendedModelResponse` together with a `Command` to update the state:
```python
from typing import Callable

from langchain.agents.middleware import (
    AgentState,
    ExtendedModelResponse,
    ModelRequest,
    ModelResponse,
    wrap_model_call,
)
from langgraph.types import Command
from typing_extensions import NotRequired


class UsageTrackingState(AgentState):
    last_model_call_tokens: NotRequired[int]


@wrap_model_call(state_schema=UsageTrackingState)
def track_usage(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
    response = handler(request)
    # Placeholder value: a real implementation would read the token usage from `response`
    return ExtendedModelResponse(
        model_response=response,
        command=Command(update={"last_model_call_tokens": 150}),
    )
```
5. Jump Control
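Under specific conditions, middleware can make the agent finish early, jump to the tools node, or jump to the model node.
Available jump targets:
| Target | Description |
|---|---|
| "end" | Jump to the end of the agent run |
| "tools" | Jump to the tools node |
| "model" | Jump to the model node |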
5.1 Ending early
End the agent run early when sensitive content is detected:
```python
from typing import Any

from langchain.agents import create_agent
from langchain.agents.middleware import AgentState, after_model
from langchain.messages import AIMessage
from langgraph.runtime import Runtime


@after_model(can_jump_to=["end"])
def check_blocked(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    last_message = state["messages"][-1]
    if "BLOCKED" in last_message.content:
        return {
            "messages": [AIMessage("I can't answer that question.")],
            "jump_to": "end",
        }
    return None


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[check_blocked],
    tools=[...],
)
```
5.2 Jumping from a class-based middleware
```python
from typing import Any

from langchain.agents.middleware import AgentMiddleware, AgentState, hook_config
from langchain.messages import AIMessage
from langgraph.runtime import Runtime


class BlockedContentMiddleware(AgentMiddleware):
    # In a class, jump targets are declared with the hook_config decorator
    @hook_config(can_jump_to=["end"])
    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        last_message = state["messages"][-1]
        if "BLOCKED" in last_message.content:
            return {
                "messages": [AIMessage("I can't answer that question.")],
                "jump_to": "end",
            }
        return None
```
6. Practical Examples
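6.1 Modifying the system prompt dynamically
Inject context based on the user's identity:
```python
from typing import Callable

from langchain.agents import create_agent
from langchain.agents.middleware import ModelRequest, ModelResponse, wrap_model_call


@wrap_model_call
def inject_user_context(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Modify the system prompt dynamically based on the user's identity."""
    # Read the user role from the runtime context (assumed here to be a dict);
    # fall back to an empty dict when no context was passed
    context = request.runtime.context or {}
    user_role = context.get("user_role", "guest")
    # System prompt for each role
    role_prompt = {
        "admin": "You are an administrator assistant and may perform any operation.",
        "user": "You are an assistant for regular users and may only look up information.",
        "guest": "You are a guest assistant and may only view public information.",
    }
    # Override the system prompt; unknown roles fall back to the guest prompt
    new_prompt = role_prompt.get(user_role, role_prompt["guest"])
    modified_request = request.override(system_prompt=new_prompt)
    return handler(modified_request)


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[inject_user_context],
    tools=[...],
)
```
6.2 Selecting tools dynamically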
Filter the tools based on the user's question to reduce token consumption:
```python
from typing import Callable

from langchain.agents import create_agent
from langchain.agents.middleware import ModelRequest, ModelResponse, wrap_model_call

# get_weather, read_file, write_file and send_email are assumed to be tools defined elsewhere


@wrap_model_call
def select_relevant_tools(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Select the relevant tools based on the user's question."""
    last_message = request.state["messages"][-1].content
    # Simple keyword matching; a real project can use more sophisticated logic
    if "weather" in last_message or "temperature" in last_message:
        tools = [get_weather]  # Keep only the weather tool
    elif "file" in last_message or "read" in last_message:
        tools = [read_file, write_file]  # Keep only the file tools
    else:
        tools = request.tools  # Use all tools
    return handler(request.override(tools=tools))


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[select_relevant_tools],
    tools=[get_weather, read_file, write_file, send_email, ...],
)
```
6.3 Monitoring tool calls
Record how long each tool call takes:
```python
import time
from typing import Callable

from langchain.agents import create_agent
from langchain.agents.middleware import wrap_tool_call
from langchain.messages import ToolMessage
from langchain.tools.tool_node import ToolCallRequest  # import path may vary by version
from langgraph.types import Command


@wrap_tool_call
def monitor_tool_call(
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    """Record how long each tool call takes."""
    tool_name = request.tool_call["name"]
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    print(f"[monitor] calling tool: {tool_name}")
    try:
        result = handler(request)
        elapsed = time.perf_counter() - start_time
        print(f"[monitor] tool {tool_name} succeeded in {elapsed:.2f}s")
        return result
    except Exception as e:
        elapsed = time.perf_counter() - start_time
        print(f"[monitor] tool {tool_name} failed after {elapsed:.2f}s, error: {e}")
        raise


agent = create_agent(
    model="deepseek-v4-flash",
    middleware=[monitor_tool_call],
    tools=[...],
)
```
7. Best Practices
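- Keep a single responsibility: each middleware should do one thing. Don't cram logging, retries, and data masking into the same middleware.
- Choose the right hook type:
  - Use node-style hooks for sequential logic (logging, validation)
  - Use wrap-style hooks for control-flow logic (retries, caching)
- Mind the execution order: before_* hooks run in list order, after_* hooks run in reverse list order, and wrap_* hooks nest with the first middleware outermost, so put critical middleware at the front of the list.
- Handle errors properly: don't let an error inside a middleware crash the whole agent.
- Implement both sync and async: if the middleware has to work under both invoke and ainvoke, remember to implement both versions.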
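The ordering point is easy to get wrong, so here is a small framework-free sketch of how a middleware list could be composed around one model call. It is a toy model of the ordering only (before hooks in list order, wrappers nested with the first entry outermost, after hooks in reverse); `model_step_trace` and the middleware names are hypothetical.

```python
from typing import Callable


def model_step_trace(middleware_names: list[str]) -> list[str]:
    """Return the order in which hooks fire around a single model call."""
    trace: list[str] = []

    # before_model hooks: first to last
    for name in middleware_names:
        trace.append(f"{name}.before_model")

    def model_call() -> None:
        trace.append("model")

    def wrap(name: str, inner: Callable[[], None]) -> Callable[[], None]:
        def wrapped() -> None:
            trace.append(f"{name}.wrap_model_call:enter")
            inner()
            trace.append(f"{name}.wrap_model_call:exit")

        return wrapped

    # Wrap from the last middleware inward so the first one ends up outermost
    handler: Callable[[], None] = model_call
    for name in reversed(middleware_names):
        handler = wrap(name, handler)
    handler()

    # after_model hooks: last to first
    for name in reversed(middleware_names):
        trace.append(f"{name}.after_model")
    return trace


order = model_step_trace(["auth", "retry"])
for step in order:
    print(step)
```

In this sketch a retry wrapper placed after `auth` retries only the model call itself, while one placed first would also re-run every wrapper nested inside it.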
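8. Summary
Custom middleware lets you insert your own logic at each key point of the agent's execution:
- Decorator approach: simple and direct, suited to single-hook cases
- Class approach: more powerful, suited to multiple hooks, configuration, and async support
- Six hooks: before_agent, before_model, after_model, after_agent, wrap_model_call, wrap_tool_call
- State updates: node-style hooks return a dictionary; wrap-style hooks return ExtendedModelResponse + Command
- Jump control: jump_to can end the run early or jump to the tools node or the model node
Once you have mastered middleware, you have mastered the most flexible extension mechanism in LangChain.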