Appearance
LangChain教程 - 10 LangChain会话记忆
前面在使用大模型进行对话的时候,为了让大模型能够有上下文的记忆,需要手动填充对话记录。
现在我们来使用 LangChain 中的会话记忆功能,让它能自动填充之前的会话内容,实现有上下文的聊天。
实现会话记忆有两种方式:
- 临时会话:将记忆放在内存中,只对当前运行有效,程序停止重新启动,则数据丢失;
- 长期会话:将记忆保存到文件中,持久保存,重新运行程序,依然有效。
10.1 临时会话
1 临时会话的使用
临时会话,这里用到两个类:
- RunnableWithMessageHistory:这个是对原有链的封装,创建带有历史记录功能的新的链;
- InMemoryChatMessageHistory:为历史记录提供内存存储功能。
下面看一下如何使用:
- 首先还是创建一个执行链,然后使用
RunnableWithMessageHistory对执行链进行包装,形成具有记忆功能的新链; - 根据
session_id获取一个InMemoryChatMessageHistory对象,传递给RunnableWithMessageHistory。
python
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import InMemoryChatMessageHistory
# 1. 创建模型
model = ChatOllama(model="qwen3:1.7b")
# 2. 创建聊天提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个带有历史记忆的聊天助手,回答要简单"),
("human", "{input}")
])
# 3. 构建链条
chain = prompt | model
# 4. 创建内存存储容器,是一个字典,根据session_id存储会话历史记录
chat_history_store = {}
# 从chat_history_store中,根据session_id获取会话历史记录
def get_session_history(session_id: str):
if session_id not in chat_history_store:
chat_history_store[session_id] = InMemoryChatMessageHistory()
return chat_history_store[session_id]
# 5. 包装成带记忆的链
with_memory_chain = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="input"
)
# 下面的代码就是 chain 的调用了,只是可以通过键盘重复聊天而已,主要看上面
if __name__ == "__main__":
# 6. 持续聊天
session_id = "user1"
print("开始聊天,输入 exit 退出。")
while True:
user_input = input("你:")
if user_input.lower() == "exit":
print("聊天结束。")
break
response = with_memory_chain.stream(
{"input": user_input},
config={"configurable": {"session_id": session_id}}
)
has_print_title = False # 用来标识打印
for chunk in response:
if not has_print_title:
has_print_title = True
print("AI:", end="", flush=True) # print后面默认end="\n",改成"",那么后面不会换行
print(chunk.content, end="", flush=True)
print()- 通过调用
get_session_history函数,通过session_id获取到InMemoryChatMessageHistory对象传递给RunnableWithMessageHistory;每个session_id对应一个InMemoryChatMessageHistory对象
执行结果,可以不停的通过键盘输入进行聊天,聊天具有上下文功能:
python
开始聊天,输入 exit 退出。
你:小明有一只狗
AI:好的,小明有一只狗。
你:小明有两只猫
AI:好的,小明有一只狗和两只猫。
你:小明一共有多少只宠物
AI:小明一共有 3 只宠物:1 只狗和 2 只猫。2 打印提示词
我们可以添加一个链接组件,打印一下会话提示词:
python
# 打印提示词
def print_prompt(full_prompt: ChatPromptTemplate):
print("提示词:", full_prompt.to_string())
return full_prompt
# 3. 构建链条
chain = prompt | print_prompt | model- 方法的参数就是提示词对象,返回的也是提示词对象,可以直接加入到链条。
这样就可以看到提示词信息了,可以看到 LangChain 自动将提示词发送给模型:
开始聊天,输入 exit 退出。
你:小明有一只狗
提示词: System: 你是一个带有历史记忆的聊天助手,回答要简单
Human: [HumanMessage(content='小明有一只狗', additional_kwargs={}, response_metadata={})]
AI:好的,小明有一只狗。需要我记住这个信息吗?
你:小明有两只猫
提示词: System: 你是一个带有历史记忆的聊天助手,回答要简单
Human: [HumanMessage(content='小明有一只狗', additional_kwargs={}, response_metadata={}), AIMessageChunk(content='好的,小明有一只狗。需要我记住这个信息吗?', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'request_id': '4966546f-28b4-4f18-bc76-518a7bc0a5b2', 'token_usage': {'input_tokens': 44, 'output_tokens': 15, 'total_tokens': 59, 'prompt_tokens_details': {'cached_tokens': 0}}}, id='run--019c8388-ea99-7a32-998f-ee99c73b058b'), HumanMessage(content='小明有两只猫', additional_kwargs={}, response_metadata={})]
AI:好的,小明有一只狗和两只猫。需要我记住这些信息吗?
你:小明一共有多少只宠物
提示词: System: 你是一个带有历史记忆的聊天助手,回答要简单
Human: [HumanMessage(content='小明有一只狗', additional_kwargs={}, response_metadata={}), AIMessageChunk(content='好的,小明有一只狗。需要我记住这个信息吗?', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'request_id': '4966546f-28b4-4f18-bc76-518a7bc0a5b2', 'token_usage': {'input_tokens': 44, 'output_tokens': 15, 'total_tokens': 59, 'prompt_tokens_details': {'cached_tokens': 0}}}, id='run--019c8388-ea99-7a32-998f-ee99c73b058b'), HumanMessage(content='小明有两只猫', additional_kwargs={}, response_metadata={}), AIMessageChunk(content='好的,小明有一只狗和两只猫。需要我记住这些信息吗?', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'request_id': '880bdc3d-0dbb-4268-b19e-facd14ca993a', 'token_usage': {'input_tokens': 214, 'output_tokens': 18, 'total_tokens': 232, 'prompt_tokens_details': {'cached_tokens': 0}}}, id='run--019c8388-fc97-73e0-956f-344b1d7a08e0'), HumanMessage(content='小明一共有多少只宠物', additional_kwargs={}, response_metadata={})]
AI:小明一共有 3 只宠物:1 只狗和 2 只猫。3 MessagesPlaceholder
在上面使用会话记录的时候,RunnableWithMessageHistory 会自动把历史消息插入到最前面。
我们也可以手动通过 MessagesPlaceholder 来写:
python
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# 创建聊天提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个带有历史记忆的聊天助手,回答要简单"),
MessagesPlaceholder("chat_history"), # 历史记录会添加到这个位置
("human", "{input}")
])
# 包装成带记忆的链
with_memory_chain = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history" # 把历史消息,传给 prompt 里哪个变量名
)- 上面是手动控制历史记录插入的位置和占位符的名称,灵活性更高一些。
临时会话记忆是放在内存中的,程序退出记忆就没了,所以一般生产环境也不会使用。
10.2 长期会话
想要实现长期会话,我们可以将会话记录保存到文件中,这样即使程序重启,会话历史也不会丢失。
1 使用文件存储会话记录
LangChain 提供了多种持久化存储会话历史的方式,其中最常见的是使用文件存储。
我们可以使用 FileChatMessageHistory 类来实现,和上面的内存临时会话是类似的,只是将每个 session_id 的会话内容保存到单独的文件中,其他的是一样的。
python
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import FileChatMessageHistory
import os
# 1. 创建模型
model = ChatOllama(model="qwen3:1.7b")
# 2. 创建聊天提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个带有历史记忆的聊天助手,回答要简单"),
MessagesPlaceholder("chat_history"), # 历史记录会添加到这个位置
("human", "{input}")
])
# 3. 构建链条
chain = prompt | model
# 4. 创建会话历史存储目录
if not os.path.exists("chat_history"):
os.makedirs("chat_history")
# 从文件中获取会话历史记录
def get_session_history(session_id: str):
# 每个session_id对应一个文件,保存在当前目录chat_history下
file_path = f"./chat_history/{session_id}.json"
return FileChatMessageHistory(file_path=file_path)
# 5. 包装成带记忆的链
with_memory_chain = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history"
)
# 下面的代码就是 chain 的调用了,只是可以通过键盘重复聊天而已,主要看上面
if __name__ == "__main__":
# 6. 持续聊天
session_id = "user1"
print("开始聊天,输入 exit 退出。")
while True:
user_input = input("你:")
if user_input.lower() == "exit":
print("聊天结束。")
break
response = with_memory_chain.stream(
{"input": user_input},
config={"configurable": {"session_id": session_id}}
)
has_print_title = False
for chunk in response:
if not has_print_title:
has_print_title = True
print("AI:", end="", flush=True)
print(chunk.content, end="", flush=True)
print()- 运行这段代码后,会话历史会被保存到
chat_history目录下的user1.json文件中 - 当程序重启后,再次运行时会加载之前的会话历史。
所以我们可以先运行:
开始聊天,输入 exit 退出。
你:小明有一只狗
AI:好的,小明有一只狗。
你:小明有两只猫
AI:好的,小明有一只狗和两只猫。然后关掉程序,重新运行:
开始聊天,输入 exit 退出。
你:小明一共有多少只宠物
AI:小明一共有 3 只宠物:1 只狗和 2 只猫。- AI 可以继续之前的回答啦!
2 使用数据库存储会话记录
除了文件存储,我们还可以使用数据库来存储会话历史,这样更加适合生产环境。LangChain 支持多种数据库存储方式,比如 SQLite、PostgreSQL 等。
下面介绍一下使用 SQLite 存储会话记录,使用 SQLChatMessageHistory 类:
python
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import SQLChatMessageHistory
from sqlalchemy import create_engine # 导入sqlalchemy核心函数
import os
# 1. 创建模型
model = ChatOllama(model="qwen3:1.7b")
# 2. 创建聊天提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个带有历史记忆的聊天助手,回答要简单"),
MessagesPlaceholder("chat_history"),
("human", "{input}")
])
# 3. 构建链条
chain = prompt | model
# 4. 创建会话历史存储目录,否则后面连接数据库创建文件会报错
if not os.path.exists("chat_history"):
os.makedirs("chat_history")
# 全局创建数据库引擎(关键:只创建一次,所有会话复用)
DB_ENGINE = create_engine(
"sqlite:///chat_history/chat_history.db",
pool_size=5, # 连接池常驻连接数
max_overflow=10, # 超出pool_size的临时连接数
pool_recycle=3600, # 连接回收时间(避免长连接失效)
connect_args={
"timeout": 30, # 数据库连接超时
"check_same_thread": False # SQLite多线程必配
},
echo=False # 调试时设为True,可查看执行的SQL
)
# 从SQLite数据库中获取会话历史记录(带高级配置)
def get_session_history(session_id: str):
# 使用新参数connection传入引擎
return SQLChatMessageHistory(
session_id=session_id,
connection=DB_ENGINE
)
# 5. 包装成带记忆的链
with_memory_chain = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history"
)
# 下面的代码就是 chain 的调用了,只是可以通过键盘重复聊天而已,主要看上面
if __name__ == "__main__":
# 6. 持续聊天(代码与之前相同)
session_id = "user1"
print("开始聊天,输入 exit 退出。")
while True:
user_input = input("你:")
if user_input.lower() == "exit":
print("聊天结束。")
break
response = with_memory_chain.stream(
{"input": user_input},
config={"configurable": {"session_id": session_id}}
)
has_print_title = False
for chunk in response:
if not has_print_title:
has_print_title = True
print("AI:", end="", flush=True)
print(chunk.content, end="", flush=True)
print()- 在连接数据库的时候,使用了
sqlalchemy创建连接池,如果没有安装,可以使用pip install sqlalchemy安装一下; - 然后在
get_session_history的时候,使用SQLChatMessageHistory作为聊天记录的存储方式; - 运行这段代码后,会话历史会被保存到
chat_history/chat_history.db的 SQLite 数据库文件中,当程序重启后,再次运行时会加载之前的会话历史。
10.3 会话记忆的长度
每个大模型都有一个 最大上下文长度(context window)。
比如:8K、16K、32K 等,这个数字代表模型一次最多能“看到”的 token 数量。
如果不停的聊天,上下文的长度会越来越长,而如果将所有的聊天记录传递给大模型,会超过了模型的最大上下文长度,那么会报错、被自动截断、或者丢失前面的内容。
为了避免会话历史过长导致模型输入超过限制,我们可以控制会话历史的长度,例如在历史消息数量超过指定数量的时候,将前面的消息裁剪掉。
举个栗子:
我们可以自定义一个 LimitedFileChatMessageHistory 类继承 BaseChatMessageHistory ,可以点击 FileChatMessageHistory 类,参考 FileChatMessageHistory 的实现,在添加消息的时候,判断当前消息的数量,如果消息数量超过设置的阈值,就将前面的消息裁剪掉,然后再保存到文件。
重写三个方法就可以了:
def messages(self) -> List[BaseMessage]:def add_message(self, message: BaseMessage) -> None:def clear(self) -> None:
python
from pathlib import Path
from typing import List
from langchain_ollama import ChatOllama
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import BaseMessage, messages_to_dict, messages_from_dict
import json, os
# 1. 创建模型
model = ChatOllama(model="qwen3:1.7b")
# 2. 创建聊天提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个带有历史记忆的聊天助手,回答要简单"),
MessagesPlaceholder("chat_history"),
("human", "{input}")
])
# 3. 构建链条
chain = prompt | model
# 4. 完全自定义的会话历史类(只继承BaseChatMessageHistory)
class LimitedFileChatMessageHistory(BaseChatMessageHistory):
"""自定义文件存储的会话历史,支持消息数量限制"""
def __init__(
self,
file_path: str,
max_messages: int = 10
) -> None:
self.file_path = Path(file_path)
self.max_messages = max_messages
# 初始化文件(不存在则创建空文件)
self._init_file()
# 初始化时裁剪历史(确保符合数量限制)
self._trim_history()
def _init_file(self) -> None:
"""确保文件存在,不存在则创建并写入空列表"""
os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
if not self.file_path.exists():
self.file_path.touch()
self._write_messages([])
# 用来将列表数据写入到文件
def _write_messages(self, messages: List[BaseMessage]) -> None:
# 将[BaseMessage,BaseMessage,...]转换为[dict,dict,...]保存到文件
serialized = messages_to_dict(messages)
self.file_path.write_text(
json.dumps(serialized, ensure_ascii=False, indent=2),
encoding="utf-8"
)
@property # 使用@property确保只读访问,防止外部修改
def messages(self) -> List[BaseMessage]:
"""读取文件中的消息列表"""
try:
items = json.loads(self.file_path.read_text(encoding="utf-8"))
# 将[dict,dict,...]转换为[BaseMessage,BaseMessage,...]
return messages_from_dict(items)
except (json.JSONDecodeError, FileNotFoundError):
# 兼容文件损坏/丢失的情况
return []
def add_message(self, message: BaseMessage) -> None:
"""添加消息并自动裁剪"""
# 1. 获取当前所有消息
current_messages = self.messages
# 2. 添加新消息
current_messages.append(message)
# 3. 如果达到最大数量,则进行裁剪
if len(current_messages) > self.max_messages:
print(f"消息数量超过 {self.max_messages},进行裁剪")
current_messages = current_messages[-self.max_messages:]
print(f"当前消息数量:{len(current_messages)}")
# 4. 一次性写入文件
self._write_messages(current_messages)
def _trim_history(self) -> None:
"""裁剪历史消息到指定数量"""
current_messages = self.messages
if len(current_messages) > self.max_messages:
# 获取列表 current_messages 中倒数第十个元素到最后一个元素,保存到文件
self._write_messages(current_messages[-self.max_messages:])
def clear(self) -> None:
"""清空所有历史消息"""
self._write_messages([])
# 5. 从文件中获取会话历史记录
def get_session_history(session_id: str):
file_path = f"chat_history/{session_id}.json"
return LimitedFileChatMessageHistory(file_path=file_path, max_messages=10)
# 6. 包装成带记忆的链
with_memory = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history"
)
# 下面的代码就是 chain 的调用了,只是可以通过键盘重复聊天而已,主要看上面
if __name__ == "__main__":
# 7. 持续聊天(代码与之前相同)
session_id = "user1"
print("开始聊天,输入 exit 退出。")
while True:
user_input = input("你:")
if user_input.lower() == "exit":
print("聊天结束。")
break
response = with_memory.stream(
{"input": user_input},
config={"configurable": {"session_id": session_id}}
)
has_print_title = False
for chunk in response:
if not has_print_title:
has_print_title = True
print("AI:", end="", flush=True)
print(chunk.content, end="", flush=True)
print()- 在保存数据的时候,需要将
[BaseMessage,BaseMessage,...]转换为[dict,dict,...]保存到文件; - 在加载数据的时候,需要将
[dict,dict,...]转换为[BaseMessage,BaseMessage,...]; - 当历史消息超过
max_messages时,会自动裁剪,只保留最近的消息。
运行一下:
开始聊天,输入 exit 退出。
你:我说你记,小明有1只狗,记住只回复好的,我问你问题你再回答
AI:好的当前消息数量:1
当前消息数量:2
你:小明有2只猫
AI:好的当前消息数量:3
当前消息数量:4
你:小明有3只猪
AI:好的当前消息数量:5
当前消息数量:6
你:小明有4只鸡
AI:好的当前消息数量:7
当前消息数量:8
你:小明有5只鸭
AI:好的当前消息数量:9
当前消息数量:10
你:小明有6只鹅
AI:好的消息数量超过 10,进行裁剪
当前消息数量:10
消息数量超过 10,进行裁剪
当前消息数量:10
你:小明有几只猪
AI:小明有3只猪。消息数量超过 10,进行裁剪
当前消息数量:10
消息数量超过 10,进行裁剪
当前消息数量:10
你:小明有几只狗
AI:目前你还没有告诉过我小明有几只狗,所以我不知道。消息数量超过 10,进行裁剪
当前消息数量:10
消息数量超过 10,进行裁剪
当前消息数量:10- 这里需要注意,AI 的回答也会加入聊天记录,所以如果 AI 的聊天记录中增加了问题的信息,也会被后面的回答用来参考,所以这里我叫他回答好的,就没有参考数量了。
上面只是最简单的会话记忆裁剪,还有根据token个数裁剪、或者将早期的历史会话使用大模型进行汇总压缩等方案。
10.4 会话记忆的清理
在某些情况下,我们可能需要清理会话记忆,比如用户注销或者会话过期。
在上面的代码中,直接调用就可以了:
python
history = get_session_history(session_id)
history.clear() # 👉 这里会触发你定义的clear()