

强化学习算法程序实践(1):通用训练框架 + Q-Learning / Sarsa
从一个可落地 Q-Learning 与 Sarsa、epsilon-greedy、回合训练循环、以及保存与加载的最小实践。
前言#
我在做强化学习实验时,最痛苦的不是推公式,而是:同一个算法换个环境就要重写一堆训练脚手架。后来我索性把常用套路总结成一个“可复用的算法骨架”,再把每个算法的差异点(几乎都集中在 sample 和 update)填进去。
更具体一点说,我踩过的坑基本都集中在“工程细节被忽略”上:同样的算法,有时候不是你公式写错了,而是你训练循环里 terminated/truncated 没处理好、epsilon 衰减太快导致过早收敛、或者你没有把每回合最大步数限制住导致训练数据统计完全不一致。于是到最后你只能对着一条乱七八糟的 reward 曲线发呆,完全不知道是环境问题还是代码问题。
所以这篇我会很刻意地把“程序骨架”写得死一点:有哪些函数、训练/测试循环怎么摆、每一步该打印什么,尽量让你在写第二个、第三个算法时不用再从头搭脚手架。
这篇是系列第一篇,目标很明确:
- 给出一个可复用的 RL 项目代码结构(训练 / 测试 / 环境 / 参数)
- 用它实现经典的 Q-Learning(off-policy)
- 以及和它只有一处关键差别的 Sarsa(on-policy)
后续文章我会按“值函数系(DQN 家族)”和“策略梯度 / Actor-Critic 系”继续拆分:
- 第 2 篇:DQN / Double DQN / Dueling DQN / Noisy DQN / PER-DQN
- 第 3 篇:Policy Gradient(REINFORCE)/ PPO / A2C
- 第 4 篇:DDPG / TD3 / SAC(连续控制三件套)
一个通用的 RL 代码骨架#
不管是表格型方法(Q-table),还是深度强化学习(DQN 及其变体),我习惯把智能体抽象成 4 个核心动作:
sample(state):训练时采样动作(带探索,Exploration)predict(state):测试时输出动作(不探索,Exploitation)update(transition):用交互数据更新策略save()/load():保存与加载(可选,但强烈推荐)
其中最关键的提醒是:对不同算法而言,sample 和 update 的实现差异很大;其它(训练循环、日志、保存/加载)通常大同小异。
训练循环(定义训练)#
一个“回合制(episode-based)”的训练循环,基本就是下面这个顺序:
- 回合开始:
state = env.reset() - 设定每回合最大步数
max_steps(帮助更快收敛,也避免一直跑不结束) - 循环交互直到终止:
action = agent.sample(state)(探索策略)next_state, reward, terminated, truncated, info = env.step(action)- 构造
transition(必要时写入 memory) agent.update(transition)state = next_state- 若
terminated or truncated则结束回合
用几行伪代码把它写死(推荐你直接复制作为项目模板):
def train(env, agent, num_episodes: int, max_steps: int):
for ep in range(num_episodes):
state, info = env.reset()
ep_return = 0.0
for t in range(max_steps):
action = agent.sample(state)
next_state, reward, terminated, truncated, info = env.step(action)
agent.update(state, action, reward, next_state, terminated)
ep_return += reward
state = next_state
if terminated or truncated:
break
print(f"episode={ep} return={ep_return:.1f} eps={agent.epsilon:.3f}")python测试循环(定义测试)#
测试和训练长得很像,但有两点必须改:
- 不要更新:测试只是评估性能
- 不要 sample:用
predict走纯利用(Exploitation)
def evaluate(env, agent, num_episodes: int, max_steps: int):
returns = []
for ep in range(num_episodes):
state, info = env.reset()
ep_return = 0.0
for t in range(max_steps):
action = agent.predict(state)
next_state, reward, terminated, truncated, info = env.step(action)
ep_return += reward
state = next_state
if terminated or truncated:
break
returns.append(ep_return)
print(f"[eval] episode={ep} return={ep_return:.1f}")
return sum(returns) / len(returns)python环境:用 Gym 就够了(需要自定义时只看 reset/step)#
大多数情况下我不会自己造环境:Gym/Gymnasium 已经足够。
如果你必须自定义环境,最关键就是对齐这两个接口:
reset():回合开始,返回初始状态step(action):执行动作,返回(next_state, reward, terminated, truncated, info)
Q-Learning:从“骨架”到可跑的算法#
Q-Learning 的三个核心函数#
Q-Learning(表格法)里我们维护一个 表。把它塞进“骨架”里,你会发现:
predict(state):选sample(state):在predict基础上加探索(epsilon-greedy / UCB etc.)update(...):就是贝尔曼更新(TD)
epsilon-greedy(最常用的探索策略)#
训练过程一般是:前期探索多、后期逐步收敛,也就是让 从大到小。
通常会设三元组:
epsilon_start:初始探索率(常见 0.95)epsilon_end:最小探索率(常见 0.01,留一点探索避免错过更优策略)epsilon_decay:衰减速度(太快容易“过早收敛/过拟合”,太慢收敛会拖)
一个简单衰减:
一份“能直接套用”的 Q-Learning 类#
下面代码是我常用的最小实现(离散状态/动作)。写成类是为了和 DQN 等深度算法保持一致的接口。
import random
from collections import defaultdict
from dataclasses import dataclass
@dataclass
class QLearningConfig:
gamma: float = 0.99
lr: float = 0.1
epsilon_start: float = 0.95
epsilon_end: float = 0.01
epsilon_decay: float = 0.995
class QLearningAgent:
def __init__(self, n_actions: int, cfg: QLearningConfig):
self.n_actions = n_actions
self.cfg = cfg
self.Q = defaultdict(lambda: [0.0 for _ in range(n_actions)])
self.epsilon = cfg.epsilon_start
def sample(self, state):
# Exploration + exploitation
if random.random() < self.epsilon:
return random.randrange(self.n_actions)
return self.predict(state)
def predict(self, state):
q = self.Q[state]
return int(max(range(self.n_actions), key=lambda a: q[a]))
def update(self, state, action, reward, next_state, terminated: bool):
q_sa = self.Q[state][action]
next_q_max = 0.0 if terminated else max(self.Q[next_state])
target = reward + self.cfg.gamma * next_q_max
self.Q[state][action] = q_sa + self.cfg.lr * (target - q_sa)
# decay epsilon once per step (or per episode, both ok; step 更细)
self.epsilon = max(self.cfg.epsilon_end, self.epsilon * self.cfg.epsilon_decay)
def save(self, path: str):
import json
with open(path, 'w', encoding='utf-8') as f:
json.dump({str(k): v for k, v in self.Q.items()}, f, ensure_ascii=False)
def load(self, path: str):
import json
with open(path, 'r', encoding='utf-8') as f:
obj = json.load(f)
self.Q = defaultdict(lambda: [0.0 for _ in range(self.n_actions)])
for k, v in obj.items():
self.Q[k] = vpythonSarsa:和 Q-Learning 只差一个 update#
Sarsa 和 Q-Learning 的核心差别,一句话总结:
- Sarsa:用“实际执行的下一步动作”更新(on-policy)
- Q-Learning:用“假设下一步最优动作”更新(off-policy)
写成更新目标(只看差异就好):
- Q-Learning:
- Sarsa:,其中 是用当前策略在 上采样的动作
Sarsa 的最小实现(只展示 update 差异)#
Sarsa 的代码组织和 Q-Learning 几乎一致,主要差别在 update 需要 next_action:
class SarsaAgent(QLearningAgent):
def update(self, state, action, reward, next_state, next_action, terminated: bool):
q_sa = self.Q[state][action]
next_q = 0.0 if terminated else self.Q[next_state][next_action]
target = reward + self.cfg.gamma * next_q
self.Q[state][action] = q_sa + self.cfg.lr * (target - q_sa)
self.epsilon = max(self.cfg.epsilon_end, self.epsilon * self.cfg.epsilon_decay)python对应训练循环也要改一行:先拿到 next_action 再更新。
action = agent.sample(state)
next_state, reward, terminated, truncated, info = env.step(action)
next_action = agent.sample(next_state)
agent.update(state, action, reward, next_state, next_action, terminated)
state = next_state
action = next_actionpython小结#
你如果只记住一件事,那就是:
- 训练/测试循环可以固定成模板,主要改
sample/predict/update
下一篇我会把这个模板升级到“深度强化学习”版本:引入 Replay Buffer、目标网络,逐步讲清楚 DQN 以及它的几种常见改进(Double / Dueling / Noisy / PER)。
我常用的调参/调试顺序(很实用)#
最后留一段我自己最常用的“排障流程”。强化学习的坏处是:它不会像监督学习那样一眼看出你是不是 overfit,它更像一锅粥,哪一步有问题都可能表现成“reward 不涨”。我一般按下面顺序来定位:
- 先把环境跑通:随机策略能不能结束回合?
terminated/truncated是否合理?reward 的量级大概是多少? - 把日志打印得足够具体:每回合 return、每步/每回合 epsilon、以及“是否提前结束”。如果是 Q 表法,我还会打印
max(Q[state])的量级看有没有爆炸。 - epsilon 先别衰减太快:我最常犯的错误是
epsilon_decay设太狠,导致前几十回合就开始“自信地瞎走”。如果你看到 reward 一开始有波动、很快就僵住,优先怀疑探索不足。 - gamma 和 max_steps 要配套:
gamma大的时候,回报有效视野更长;如果你max_steps又很短,很多环境会变成“怎么都学不出来”。反过来也是:max_steps太长会让训练变慢且方差更大。 - 先在最小环境验证:像 FrozenLake / Taxi 这种可以快速验证“训练循环是否正确”。骨架确认没问题,再搬到复杂环境会省掉很多时间。