Lecture 08 - Deep Q-Learning for the Breakout game
MachineLearningCourse.Lecture08 — Module
Lecture08

Deep Q-Learning for the Breakout game.
Available Functions
demo(): Train an agent for the Breakout game using DQN
breakout(): Play the Breakout game using a trained agent
Usage
using MachineLearningCourse
q_network = Lecture08.demo()
Lecture08.save(q_network,"model.jld2")
Lecture08.breakout("model.jld2")

using MachineLearningCourse
Lecture08.breakout()

MachineLearningCourse.Lecture08.QUIT — Constant
Global reference to track interrupt status.
MachineLearningCourse.Lecture08.EpisodeLogger — Type
EpisodeLogger(window_size=1000; plot=true)

Callback logger for tracking and visualizing DQN training progress.
Maintains a moving average of episode rewards and optionally displays plots during training. Always stores episode data for later plot generation using create_plot!.
Arguments
window_size=1000: Size of circular buffer for moving average calculation
plot=true: Whether to display plots during training (every 10 episodes)
Fields
reward_buffer: Circular buffer for episode rewards (for moving average)
cumulative_reward_sum: Running sum for efficient moving average calculation
episode_rewards: All episode rewards for plotting
avg_rewards: Moving average rewards for plotting
episodes: Episode numbers for plotting
total_steps: Cumulative step count across all episodes
plot: Current plot object (updated during training or via create_plot!)
plot_enabled: Whether to display plots during training
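The reward_buffer and cumulative_reward_sum fields suggest an O(1) moving-average update: a fixed-size circular buffer plus a running sum. A dependency-free sketch of that bookkeeping (all names here are hypothetical, not the logger's actual internals):

```julia
# Hypothetical sketch of an O(1) windowed moving average: evict the
# oldest reward from the running sum before overwriting its slot.
mutable struct MovingAverage
    buffer::Vector{Float64}
    sum::Float64
    next::Int    # next slot to overwrite
    count::Int   # how many slots are filled
end
MovingAverage(n) = MovingAverage(zeros(n), 0.0, 1, 0)

function push_reward!(m::MovingAverage, r)
    if m.count == length(m.buffer)
        m.sum -= m.buffer[m.next]      # evict the oldest reward
    else
        m.count += 1
    end
    m.buffer[m.next] = r
    m.sum += r
    m.next = m.next == length(m.buffer) ? 1 : m.next + 1
    return m.sum / m.count             # current moving average
end

m = MovingAverage(3)
push_reward!(m, 1.0); push_reward!(m, 2.0); push_reward!(m, 3.0)
push_reward!(m, 4.0)  # window now holds 2, 3, 4 → average 3.0
```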
Usage
# With live plotting
logger = EpisodeLogger(500, plot=true)
q_network = DQN(env, callback=logger)
# Without live plotting (generate plot after training)
logger = EpisodeLogger(500, plot=false)
q_network = DQN(env, callback=logger)
create_plot!(logger) # Generate final plot
display(logger.plot) # Show the plot

MachineLearningCourse.Lecture08.StateTransition — Type
StateTransition

Represents a single experience tuple (s, a, r, s', done) for the replay buffer.
This is the fundamental unit of experience that the DQN agent learns from. Each transition contains:
state: The current state observation (flattened game state)
action: The action taken by the agent
reward: The immediate reward received
next_state: The resulting state after taking the action
terminal: Whether the episode ended (true if game over)
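A minimal sketch of what such a struct might look like; the field types here are assumptions for illustration, not the module's actual definitions:

```julia
# Hypothetical sketch of the (s, a, r, s', done) experience tuple;
# concrete field types are assumed, not taken from the module.
struct StateTransition
    state::Vector{Float32}       # flattened game state s
    action::Int                  # action a taken by the agent
    reward::Float32              # immediate reward r
    next_state::Vector{Float32}  # resulting state s'
    terminal::Bool               # true if the episode ended
end

# Example: a transition where action 1 earned one point mid-episode
t = StateTransition(zeros(Float32, 4), 1, 1.0f0, ones(Float32, 4), false)
```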
MachineLearningCourse.Lecture08.DQN — Method
DQN(env; kwargs...) -> Flux.Chain

Train an agent using the (Double) DQN algorithm.
Arguments
env: Environment implementing CommonRLInterface
hidden_layers=[128, 64]: Architecture of hidden layers
η=1e-4: Learning rate for Adam optimizer
γ=0.99: Discount factor for Bellman equation
T=20_000: Maximum steps per episode
ε=(0.5, 0.01): Epsilon tuple (initial, final) for exploration
Δε=1e-4: Epsilon decay per episode
replay_memory_size=1_000_000: Size of experience replay buffer
replay_start_size=100_000: Start training after this many experiences
batch_size=32: Batch size for training
update_frequency=4: How often to perform training steps
target_evaluation=ddqn_target_evaluation: Function for target Q-value evaluation
target_update_frequency=25_000: How often to update target network
max_episodes=100_000: Maximum number of episodes to train
callback=EpisodeLogger(): Function called after each episode
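A sketch of an ε schedule consistent with the ε=(0.5, 0.01) and Δε=1e-4 arguments above: decay by Δε per episode and clamp at the final value. The linear form is an assumption about the implementation, and all names are illustrative:

```julia
# Assumed linear ε-decay: subtract Δε each episode, never below ε_final.
decay_ε(ε, Δε, ε_final) = max(ε - Δε, ε_final)

function run_schedule(ε, Δε, ε_final, episodes)
    for _ in 1:episodes
        ε = decay_ε(ε, Δε, ε_final)
    end
    return ε
end

run_schedule(0.5, 1e-4, 0.01, 10_000)  # long past the schedule: clamped at 0.01
```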
Returns
- Trained Q-network (Flux.Chain)
Examples
Basic usage (Double DQN):
env = BreakoutEnv()
q_network = DQN(env)

Simple DQN:
env = BreakoutEnv()
q_network = DQN(env, target_evaluation=dqn_target_evaluation)

MachineLearningCourse.Lecture08.Network — Method
Network(layers::Vector{Int})

Create a neural network with the specified architecture.
Arguments
layers: Network architecture (e.g., [inputdim, 64, 32, outputdim])
Returns
- Neural network (Flux.Chain) with ReLU hidden layers and linear output
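A dependency-free sketch of the forward pass such an architecture computes (the actual Network returns a Flux.Chain; build_network is a hypothetical name used here for illustration):

```julia
# Sketch of "ReLU hidden layers, linear output": each hidden layer is
# W*x .+ b followed by ReLU; the final layer has no activation.
myrelu(x) = max.(x, 0)

function build_network(layers::Vector{Int})
    Ws = [0.1 .* randn(layers[i+1], layers[i]) for i in 1:length(layers)-1]
    bs = [zeros(layers[i+1]) for i in 1:length(layers)-1]
    function forward(x)
        for i in 1:length(Ws)-1
            x = myrelu(Ws[i] * x .+ bs[i])  # hidden layer with ReLU
        end
        Ws[end] * x .+ bs[end]              # linear output (Q-values)
    end
    return forward
end

net = build_network([8, 64, 32, 3])
length(net(rand(8)))  # 3 outputs, e.g. one Q-value per action
```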
MachineLearningCourse.Lecture08.agent — Method
agent(game_state::GameState, model) -> Int

Agent providing paddle control for Breakout.
Arguments
game_state: Breakout game state
model: Trained DQN network (Flux.Chain)
Returns
-1: Move paddle left
1: Move paddle right
0: No movement
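A hypothetical sketch of how such an agent maps network outputs to the return codes above: take the argmax over the Q-values and translate the winning index into a paddle command. The action ordering is an assumption for illustration:

```julia
# Assumed ordering of the network's three outputs: left, stay, right.
const PADDLE_ACTIONS = (-1, 0, 1)

greedy_action(q_values) = PADDLE_ACTIONS[argmax(q_values)]

greedy_action([0.1, 0.7, 0.2])  # index 2 has the largest Q-value → 0 (no movement)
```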
MachineLearningCourse.Lecture08.breakout — Function
breakout(model_path::String; speed=nothing)

Run the Breakout game with a trained DQN model.
Arguments
model_path: Path to saved DQN model (.jld2 file)
speed=nothing: Game speed
Example
# Run game with trained model
breakout("agent.jld2")

MachineLearningCourse.Lecture08.create_plot! — Method
create_plot!(logger::EpisodeLogger)

Create and store a plot from the logger's stored data.
Updates logger.plot with a complete visualization showing episode rewards as scatter points and moving average as a line. The plot includes proper legends and formatting.
Usage
logger = EpisodeLogger(plot=false) # No live plotting
q_network = DQN(env, callback=logger)
create_plot!(logger) # Generate final plot
display(logger.plot) # Show the plot

MachineLearningCourse.Lecture08.ddqn_target_evaluation — Method
ddqn_target_evaluation(q_network, target_network, next_states, target_q_values, batch_size)

Double DQN target evaluation: use the main network to select actions and the target network to evaluate them. This reduces overestimation bias in Q-learning.
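The Double DQN rule described above can be sketched as follows; function names and data shapes here are illustrative, not the module's actual API:

```julia
# Double DQN: the online network CHOOSES the next action, the target
# network SCORES it, decoupling selection from evaluation.
function ddqn_targets(q_online, q_target, next_states, rewards, terminals, γ)
    map(eachindex(rewards)) do i
        terminals[i] && return rewards[i]              # no bootstrap at episode end
        a = argmax(q_online(next_states[i]))           # select with online net
        rewards[i] + γ * q_target(next_states[i])[a]   # evaluate with target net
    end
end

q_online(s) = [0.0, 1.0]     # toy stand-in networks for illustration
q_target(s) = [0.5, 0.25]
ddqn_targets(q_online, q_target, [[0.0]], [1.0], [false], 0.9)
# online net picks action 2; target net scores it: 1.0 + 0.9 * 0.25 = 1.225
```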
MachineLearningCourse.Lecture08.demo — Function
demo(algorithm=:DDQN; max_episodes=100_000, plot=true)

Run a DQN training demo on the Breakout environment.
Arguments
algorithm=:DDQN: Selected algorithm (use :DQN for standard DQN)
max_episodes=100_000: Maximum number of training episodes
plot=true: Whether to display live plots during training
Returns
(q_network, logger): Trained network and logger
Usage
# With live plotting
q_network, logger = Lecture08.demo(max_episodes=1000, plot=true)
# DQN without live plotting (faster training)
q_network, logger = Lecture08.demo(:DQN, max_episodes=1000, plot=false)
Lecture08.create_plot!(logger) # Generate final plot
using Plots
plot!(logger.plot,size=(800,600)) # Resize plot
savefig(logger.plot,"training.png")

MachineLearningCourse.Lecture08.dqn_target_evaluation — Method
dqn_target_evaluation(q_network, target_network, next_states, target_q_values, batch_size)

Standard DQN target evaluation: use the target network for both action selection and evaluation.
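For contrast with the Double DQN rule, the standard target can be sketched in one line; names are illustrative, not the module's API:

```julia
# Standard DQN: the target network both selects (via maximum) and
# evaluates the next action, which is the coupling Double DQN removes.
dqn_target(q_target, s′, r, done, γ) = done ? r : r + γ * maximum(q_target(s′))

q_target(s) = [0.5, 0.25]    # toy stand-in network for illustration
dqn_target(q_target, [0.0], 1.0, false, 0.9)  # 1.0 + 0.9 * 0.5 = 1.45
dqn_target(q_target, [0.0], 1.0, true, 0.9)   # terminal: just the reward, 1.0
```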
MachineLearningCourse.Lecture08.enable_interrupt — Method
enable_interrupt()

Start monitoring for an ENTER key press. Press ENTER to interrupt training gracefully.
MachineLearningCourse.Lecture08.load — Method
load(filepath::String) -> Flux.Chain

Load a trained DQN network from file.
Arguments
filepath: Path to saved network file (.jld2)
Returns
- Loaded network (Flux.Chain)
Example
network = load("model.jld2")

MachineLearningCourse.Lecture08.sample — Method
sample(replay_buffer::ReplayBuffer, batch_size::Int) -> Vector{StateTransition}

Randomly sample a batch of experiences from the replay buffer.
This implements the experience replay mechanism from the DQN paper. By randomly sampling past experiences, we break the temporal correlations that would otherwise make learning unstable.
Arguments
replay_buffer: Circular buffer containing past experiences
batch_size: Number of experiences to sample
Returns
- Vector of StateTransition objects for training
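Uniform replay sampling as described above can be sketched with a plain Vector standing in for the module's ReplayBuffer; whether the real implementation samples with or without replacement is an assumption here:

```julia
using Random

# Sketch of experience-replay sampling: uniform random indices (here
# with replacement) break temporal correlation within the batch.
function sample_batch(buffer::Vector, batch_size::Int; rng=Random.default_rng())
    idx = rand(rng, 1:length(buffer), batch_size)
    return buffer[idx]
end

batch = sample_batch(collect(1:1_000), 32)
length(batch)  # 32
```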
MachineLearningCourse.Lecture08.save — Method
save(network, filepath::String)

Save a trained DQN network to file using JLD2.
Arguments
network: Trained network (Flux.Chain)
filepath: Path to save file (should end with .jld2)
Example
save(network, "model.jld2")

MachineLearningCourse.Lecture08.train_step! — Method
train_step!(q_network, target_network, optimizer, replay_buffer, batch_size, discount_factor, env_actions) -> Float32

Perform one training step of the DQN algorithm.
Arguments
q_network: Main Q-network being trained
target_network: Target network for stable Q-learning targets
optimizer: Flux optimizer (typically Adam)
replay_buffer: Buffer of past experiences
batch_size: Number of experiences to learn from
discount_factor: γ (gamma) - how much to value future rewards
env_actions: Available actions in the environment
Returns
- Loss value for monitoring training progress
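The quantities one such step computes can be sketched for a single sampled transition: a Bellman target from the target network, and the squared error against the online network's prediction. Toy closures stand in for the Flux networks; all names here are illustrative, and the real step would backpropagate this loss through the optimizer:

```julia
γ = 0.9
q_net(s)  = [0.5, 0.0]   # online Q-values for 2 actions (toy stand-in)
tgt_net(s) = [1.0, 2.0]  # target-network Q-values (toy stand-in)

# one sampled transition: (s, a, r, s′, done)
s, a, r, s′, done = [0.0], 1, 1.0, [0.0], false

target = done ? r : r + γ * maximum(tgt_net(s′))  # 1.0 + 0.9 * 2.0 = 2.8
loss   = abs2(q_net(s)[a] - target)               # (0.5 - 2.8)^2 = 5.29
```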
MachineLearningCourse.Lecture08.update_target_network! — Method
update_target_network!(target_network, source_network)

Copy weights from the main Q-network to the target network.
Arguments
target_network: The target network to update
source_network: The main Q-network to copy weights from
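A hard target update is just an in-place parameter copy. A dependency-free sketch over parameter arrays (hard_update! is a hypothetical name; with actual Flux.Chain networks, Flux.loadmodel!(target, source) performs the equivalent copy):

```julia
# Copy every parameter array from the online network into the target
# network in place, leaving the two networks momentarily identical.
function hard_update!(target_params, source_params)
    for (t, s) in zip(target_params, source_params)
        copyto!(t, s)
    end
    return target_params
end

tp = [zeros(2, 2), zeros(2)]       # toy target "weights" and "biases"
sp = [ones(2, 2), fill(0.5, 2)]    # toy source parameters
hard_update!(tp, sp)
tp[1]  # now all ones, matching the source
```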