Xiaohei's Blog
headpicBlur image

前言#

我在做强化学习实验时,最痛苦的不是推公式,而是:同一个算法换个环境就要重写一堆训练脚手架。后来我索性把常用套路总结成一个“可复用的算法骨架”,再把每个算法的差异点(几乎都集中在 sampleupdate)填进去。

更具体一点说,我踩过的坑基本都集中在“工程细节被忽略”上:同样的算法,有时候不是你公式写错了,而是你训练循环里 terminated/truncated 没处理好、epsilon 衰减太快导致过早收敛、或者你没有把每回合最大步数限制住导致训练数据统计完全不一致。于是到最后你只能对着一条乱七八糟的 reward 曲线发呆,完全不知道是环境问题还是代码问题。

所以这篇我会很刻意地把“程序骨架”写得死一点:有哪些函数、训练/测试循环怎么摆、每一步该打印什么,尽量让你在写第二个、第三个算法时不用再从头搭脚手架。

这篇是系列第一篇,目标很明确:

  • 给出一个可复用的 RL 项目代码结构(训练 / 测试 / 环境 / 参数)
  • 用它实现经典的 Q-Learning(off-policy)
  • 以及和它只有一处关键差别的 Sarsa(on-policy)

后续文章我会按“值函数系(DQN 家族)”和“策略梯度 / Actor-Critic 系”继续拆分:

  • 第 2 篇:DQN / Double DQN / Dueling DQN / Noisy DQN / PER-DQN
  • 第 3 篇:Policy Gradient(REINFORCE)/ PPO / A2C
  • 第 4 篇:DDPG / TD3 / SAC(连续控制三件套)

一个通用的 RL 代码骨架#

不管是表格型方法(Q-table),还是深度强化学习(DQN 及其变体),我习惯把智能体抽象成 4 个核心动作:

  • sample(state):训练时采样动作(带探索,Exploration)
  • predict(state):测试时输出动作(不探索,Exploitation)
  • update(transition):用交互数据更新策略
  • save()/load():保存与加载(可选,但强烈推荐)

其中最关键的提醒是:对不同算法而言,sampleupdate 的实现差异很大;其它(训练循环、日志、保存/加载)通常大同小异。

训练循环(定义训练)#

一个“回合制(episode-based)”的训练循环,基本就是下面这个顺序:

  1. 回合开始:state = env.reset()
  2. 设定每回合最大步数 max_steps(帮助更快收敛,也避免一直跑不结束)
  3. 循环交互直到终止:
    • action = agent.sample(state)(探索策略)
    • next_state, reward, terminated, truncated, info = env.step(action)
    • 构造 transition(必要时写入 memory)
    • agent.update(transition)
    • state = next_state
    • terminated or truncated 则结束回合

用几行伪代码把它写死(推荐你直接复制作为项目模板):

def train(env, agent, num_episodes: int, max_steps: int):
	for ep in range(num_episodes):
		state, info = env.reset()
		ep_return = 0.0

		for t in range(max_steps):
			action = agent.sample(state)
			next_state, reward, terminated, truncated, info = env.step(action)

			agent.update(state, action, reward, next_state, terminated)

			ep_return += reward
			state = next_state
			if terminated or truncated:
				break

		print(f"episode={ep} return={ep_return:.1f} eps={agent.epsilon:.3f}")
python

测试循环(定义测试)#

测试和训练长得很像,但有两点必须改:

  1. 不要更新:测试只是评估性能
  2. 不要 sample:用 predict 走纯利用(Exploitation)
def evaluate(env, agent, num_episodes: int, max_steps: int):
	returns = []
	for ep in range(num_episodes):
		state, info = env.reset()
		ep_return = 0.0
		for t in range(max_steps):
			action = agent.predict(state)
			next_state, reward, terminated, truncated, info = env.step(action)
			ep_return += reward
			state = next_state
			if terminated or truncated:
				break
		returns.append(ep_return)
		print(f"[eval] episode={ep} return={ep_return:.1f}")
	return sum(returns) / len(returns)
python

环境:用 Gym 就够了(需要自定义时只看 reset/step)#

大多数情况下我不会自己造环境:Gym/Gymnasium 已经足够。

如果你必须自定义环境,最关键就是对齐这两个接口:

  • reset():回合开始,返回初始状态
  • step(action):执行动作,返回 (next_state, reward, terminated, truncated, info)

Q-Learning:从“骨架”到可跑的算法#

Q-Learning 的三个核心函数#

Q-Learning(表格法)里我们维护一个 Q(s,a)Q(s,a) 表。把它塞进“骨架”里,你会发现:

  • predict(state):选 argmaxaQ(s,a)\arg\max_a Q(s,a)
  • sample(state):在 predict 基础上加探索(epsilon-greedy / UCB etc.)
  • update(...):就是贝尔曼更新(TD)

epsilon-greedy(最常用的探索策略)#

训练过程一般是:前期探索多、后期逐步收敛,也就是让 ϵ\epsilon 从大到小。

通常会设三元组:

  • epsilon_start:初始探索率(常见 0.95)
  • epsilon_end:最小探索率(常见 0.01,留一点探索避免错过更优策略)
  • epsilon_decay:衰减速度(太快容易“过早收敛/过拟合”,太慢收敛会拖)

一个简单衰减:

ϵmax(ϵend,ϵϵdecay)\epsilon \leftarrow \max(\epsilon_{end}, \epsilon \cdot \epsilon_{decay})

一份“能直接套用”的 Q-Learning 类#

下面代码是我常用的最小实现(离散状态/动作)。写成类是为了和 DQN 等深度算法保持一致的接口。

import random
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class QLearningConfig:
	gamma: float = 0.99
	lr: float = 0.1
	epsilon_start: float = 0.95
	epsilon_end: float = 0.01
	epsilon_decay: float = 0.995


class QLearningAgent:
	def __init__(self, n_actions: int, cfg: QLearningConfig):
		self.n_actions = n_actions
		self.cfg = cfg

		self.Q = defaultdict(lambda: [0.0 for _ in range(n_actions)])
		self.epsilon = cfg.epsilon_start

	def sample(self, state):
		# Exploration + exploitation
		if random.random() < self.epsilon:
			return random.randrange(self.n_actions)
		return self.predict(state)

	def predict(self, state):
		q = self.Q[state]
		return int(max(range(self.n_actions), key=lambda a: q[a]))

	def update(self, state, action, reward, next_state, terminated: bool):
		q_sa = self.Q[state][action]
		next_q_max = 0.0 if terminated else max(self.Q[next_state])
		target = reward + self.cfg.gamma * next_q_max
		self.Q[state][action] = q_sa + self.cfg.lr * (target - q_sa)

		# decay epsilon once per step (or per episode, both ok; step 更细)
		self.epsilon = max(self.cfg.epsilon_end, self.epsilon * self.cfg.epsilon_decay)

	def save(self, path: str):
		import json
		with open(path, 'w', encoding='utf-8') as f:
			json.dump({str(k): v for k, v in self.Q.items()}, f, ensure_ascii=False)

	def load(self, path: str):
		import json
		with open(path, 'r', encoding='utf-8') as f:
			obj = json.load(f)
		self.Q = defaultdict(lambda: [0.0 for _ in range(self.n_actions)])
		for k, v in obj.items():
			self.Q[k] = v
python

Sarsa:和 Q-Learning 只差一个 update#

Sarsa 和 Q-Learning 的核心差别,一句话总结:

  • Sarsa:用“实际执行的下一步动作”更新(on-policy)
  • Q-Learning:用“假设下一步最优动作”更新(off-policy)

写成更新目标(只看差异就好):

  • Q-Learning:r+γmaxaQ(s,a)r + \gamma \max_{a'} Q(s', a')
  • Sarsa:r+γQ(s,anext)r + \gamma Q(s', a_{next}),其中 anexta_{next} 是用当前策略在 ss' 上采样的动作

Sarsa 的最小实现(只展示 update 差异)#

Sarsa 的代码组织和 Q-Learning 几乎一致,主要差别在 update 需要 next_action

class SarsaAgent(QLearningAgent):
	def update(self, state, action, reward, next_state, next_action, terminated: bool):
		q_sa = self.Q[state][action]
		next_q = 0.0 if terminated else self.Q[next_state][next_action]
		target = reward + self.cfg.gamma * next_q
		self.Q[state][action] = q_sa + self.cfg.lr * (target - q_sa)
		self.epsilon = max(self.cfg.epsilon_end, self.epsilon * self.cfg.epsilon_decay)
python

对应训练循环也要改一行:先拿到 next_action 再更新。

action = agent.sample(state)
next_state, reward, terminated, truncated, info = env.step(action)
next_action = agent.sample(next_state)
agent.update(state, action, reward, next_state, next_action, terminated)
state = next_state
action = next_action
python

小结#

你如果只记住一件事,那就是:

  • 训练/测试循环可以固定成模板,主要改 sample/predict/update

下一篇我会把这个模板升级到“深度强化学习”版本:引入 Replay Buffer、目标网络,逐步讲清楚 DQN 以及它的几种常见改进(Double / Dueling / Noisy / PER)。

我常用的调参/调试顺序(很实用)#

最后留一段我自己最常用的“排障流程”。强化学习的坏处是:它不会像监督学习那样一眼看出你是不是 overfit,它更像一锅粥,哪一步有问题都可能表现成“reward 不涨”。我一般按下面顺序来定位:

  1. 先把环境跑通:随机策略能不能结束回合?terminated/truncated 是否合理?reward 的量级大概是多少?
  2. 把日志打印得足够具体:每回合 return、每步/每回合 epsilon、以及“是否提前结束”。如果是 Q 表法,我还会打印 max(Q[state]) 的量级看有没有爆炸。
  3. epsilon 先别衰减太快:我最常犯的错误是 epsilon_decay 设太狠,导致前几十回合就开始“自信地瞎走”。如果你看到 reward 一开始有波动、很快就僵住,优先怀疑探索不足。
  4. gamma 和 max_steps 要配套gamma 大的时候,回报有效视野更长;如果你 max_steps 又很短,很多环境会变成“怎么都学不出来”。反过来也是:max_steps 太长会让训练变慢且方差更大。
  5. 先在最小环境验证:像 FrozenLake / Taxi 这种可以快速验证“训练循环是否正确”。骨架确认没问题,再搬到复杂环境会省掉很多时间。
强化学习算法程序实践(1):通用训练框架 + Q-Learning / Sarsa
https://xiaohei-blog.vercel.app/blog/rl-algorithm-1
Author 红鼻子小黑
Published at May 1, 2025
Comment seems to stuck. Try to refresh?✨