Lecture 08 - Deep Q-Learning for the Breakout game

MachineLearningCourse.Lecture08 (Module)
Lecture08

Deep Q-Learning for the Breakout game

Available Functions

  • demo(): Train an agent for the Breakout game using DQN
  • breakout(): Play the Breakout game using a trained agent

Usage

# Train an agent, save it, then watch it play
using MachineLearningCourse
q_network, logger = Lecture08.demo()
Lecture08.save(q_network, "model.jld2")
Lecture08.breakout("model.jld2")

# Play with the default model
using MachineLearningCourse
Lecture08.breakout()
MachineLearningCourse.Lecture08.EpisodeLogger (Type)
EpisodeLogger(window_size=1000; plot=true)

Callback logger for tracking and visualizing DQN training progress.

Maintains a moving average of episode rewards and optionally displays plots during training. Always stores episode data for later plot generation using create_plot!.

Arguments

  • window_size=1000: Size of circular buffer for moving average calculation
  • plot=true: Whether to display plots during training (every 10 episodes)

Fields

  • reward_buffer: Circular buffer for episode rewards (for moving average)
  • cumulative_reward_sum: Running sum for efficient moving average calculation
  • episode_rewards: All episode rewards for plotting
  • avg_rewards: Moving average rewards for plotting
  • episodes: Episode numbers for plotting
  • total_steps: Cumulative step count across all episodes
  • plot: Current plot object (updated during training or via create_plot!)
  • plot_enabled: Whether to display plots during training
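
The O(1) moving-average bookkeeping implied by reward_buffer and cumulative_reward_sum can be sketched in plain Julia (names and types here are illustrative, not the actual fields):

```julia
# Moving average over the last `window` rewards in O(1) per update:
# keep a running sum and evict the reward that falls out of the window.
function push_reward!(buffer::Vector{Float64}, total::Ref{Float64},
                      window::Int, r::Float64)
    push!(buffer, r)
    total[] += r
    length(buffer) > window && (total[] -= popfirst!(buffer))  # evict oldest
    return total[] / length(buffer)  # current moving average
end

buffer, total = Float64[], Ref(0.0)
push_reward!(buffer, total, 3, 1.0)
push_reward!(buffer, total, 3, 2.0)
push_reward!(buffer, total, 3, 3.0)
avg = push_reward!(buffer, total, 3, 4.0)  # window now holds (2.0, 3.0, 4.0)
```

With window_size=1000 the logger does the same over the last 1000 episode rewards, avoiding a full re-sum on every episode.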

Usage

# With live plotting
logger = EpisodeLogger(500, plot=true)
q_network = DQN(env, callback=logger)

# Without live plotting (generate plot after training)
logger = EpisodeLogger(500, plot=false)
q_network = DQN(env, callback=logger)
create_plot!(logger)  # Generate final plot
display(logger.plot)  # Show the plot
MachineLearningCourse.Lecture08.StateTransition (Type)
StateTransition

Represents a single experience tuple (s, a, r, s', done) for the replay buffer.

This is the fundamental unit of experience that the DQN agent learns from. Each transition contains:

  • state: The current state observation (flattened game state)
  • action: The action taken by the agent
  • reward: The immediate reward received
  • next_state: The resulting state after taking the action
  • terminal: Whether the episode ended (true if game over)
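
A minimal sketch of such a transition type (field names mirror the bullet list above; the concrete element types are assumptions):

```julia
# One (s, a, r, s', done) experience tuple as stored in the replay buffer.
struct Transition                # illustrative stand-in for StateTransition
    state::Vector{Float32}       # flattened game state s
    action::Int                  # action a taken by the agent
    reward::Float32              # immediate reward r
    next_state::Vector{Float32}  # resulting state s'
    terminal::Bool               # true if the episode ended (game over)
end

t = Transition(Float32[0.1, 0.2], 1, 1.0f0, Float32[0.2, 0.3], false)
```
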
MachineLearningCourse.Lecture08.DQN (Method)
DQN(env; kwargs...) -> Flux.Chain

Train an agent using the (Double) DQN algorithm.

Arguments

  • env: Environment implementing CommonRLInterface
  • hidden_layers=[128, 64]: Architecture of hidden layers
  • η=1e-4: Learning rate for Adam optimizer
  • γ=0.99: Discount factor for Bellman equation
  • T=20_000: Maximum steps per episode
  • ε=(0.5, 0.01): Epsilon tuple (initial, final) for exploration
  • Δε=1e-4: Epsilon decay per episode
  • replay_memory_size=1_000_000: Size of experience replay buffer
  • replay_start_size=100_000: Start training after this many experiences
  • batch_size=32: Batch size for training
  • update_frequency=4: How often to perform training steps
  • target_evaluation=ddqn_target_evaluation: Function for target Q-value evaluation
  • target_update_frequency=25_000: How often to update target network
  • max_episodes=100_000: Maximum number of episodes to train
  • callback=EpisodeLogger(): Function called after each episode
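
A sketch of how the ε=(0.5, 0.01) tuple and Δε could combine into a linear exploration schedule (the exact schedule used internally is an assumption):

```julia
# Linear epsilon decay: start at the initial value, subtract Δε per
# episode, and clamp at the final value.
decay_eps(episode; ε=(0.5, 0.01), Δε=1e-4) = max(ε[2], ε[1] - Δε * episode)

decay_eps(0)       # 0.5: fully exploratory start
decay_eps(2_000)   # ≈ 0.3
decay_eps(10_000)  # 0.01: floor reached, mostly greedy from here on
```
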

Returns

  • Trained Q-network (Flux.Chain)

Examples

Basic usage (Double DQN):

env = BreakoutEnv()
q_network = DQN(env)

Simple DQN:

env = BreakoutEnv()
q_network = DQN(env, target_evaluation=dqn_target_evaluation)
MachineLearningCourse.Lecture08.Network (Method)
Network(layers::Vector{Int})

Create a neural network with specified architecture.

Arguments

  • layers: Network architecture (e.g., [inputdim, 64, 32, outputdim])

Returns

  • Neural network (Flux.Chain) with ReLU hidden layers and linear output
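
A dependency-free sketch of the forward pass such a network computes (random weights, purely illustrative; the real constructor returns a Flux.Chain of Dense layers):

```julia
# Forward pass of a [in, hidden..., out] MLP: ReLU on hidden layers,
# linear (identity) output layer, as described above.
function mlp_forward(layers::Vector{Int}, x::Vector{Float64})
    Ws = [0.1 .* randn(layers[i+1], layers[i]) for i in 1:length(layers)-1]
    bs = [zeros(layers[i+1]) for i in 1:length(layers)-1]
    for i in eachindex(Ws)
        x = Ws[i] * x .+ bs[i]
        i < length(Ws) && (x = max.(x, 0.0))  # ReLU on hidden layers only
    end
    return x
end

q = mlp_forward([4, 64, 32, 3], rand(4))  # three Q-values, one per action
```
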

MachineLearningCourse.Lecture08.agent (Method)
agent(game_state::GameState, model) -> Int

Agent providing paddle control for Breakout.

Arguments

  • game_state: Breakout game state
  • model: Trained DQN network (Flux.Chain)

Returns

  • -1: Move paddle left
  • 1: Move paddle right
  • 0: No movement
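
How the network's Q-values could map onto these three commands (the output ordering left/stay/right is an assumption for illustration):

```julia
# Greedy action selection: pick the command whose Q-value is largest,
# assuming the network's three outputs are ordered (left, stay, right).
function greedy_action(q_values::AbstractVector{<:Real})
    actions = (-1, 0, 1)  # move left, no movement, move right
    return actions[argmax(q_values)]
end

greedy_action([0.2, 0.1, 0.9])  # → 1 (move right)
greedy_action([1.5, 0.0, 0.3])  # → -1 (move left)
```
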
MachineLearningCourse.Lecture08.breakout (Function)
breakout(model_path::String; speed=nothing)

Run Breakout game with a trained DQN model.

Arguments

  • model_path: Path to saved DQN model (.jld2 file)
  • speed=nothing: Game speed; nothing keeps the default speed

Example

# Run game with trained model
breakout("agent.jld2")
MachineLearningCourse.Lecture08.create_plot! (Method)
create_plot!(logger::EpisodeLogger)

Create and store a plot from the logger's stored data.

Updates logger.plot with a complete visualization showing episode rewards as scatter points and moving average as a line. The plot includes proper legends and formatting.

Usage

logger = EpisodeLogger(plot=false)  # No live plotting
q_network = DQN(env, callback=logger)
create_plot!(logger)  # Generate final plot
display(logger.plot)  # Show the plot
MachineLearningCourse.Lecture08.ddqn_target_evaluation (Method)
ddqn_target_evaluation(q_network, target_network, next_states, target_q_values, batch_size)

Double DQN target evaluation: use main network to select actions, target network to evaluate them. This reduces overestimation bias in Q-learning.
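
A stdlib-only sketch of the selection/evaluation split, over Q-value matrices with one column per batch entry (the matrix layout is an assumption):

```julia
# Double DQN: the online network *selects* the best next action per column,
# the target network *evaluates* it. Taking the plain max over q_target
# would be the standard DQN target and tends to overestimate.
function ddqn_values(q_online::Matrix{Float64}, q_target::Matrix{Float64})
    return [q_target[argmax(q_online[:, j]), j] for j in 1:size(q_online, 2)]
end

# Two batch entries (columns); the online net prefers action 2 in both.
q_online = [1.0 0.0;
            2.0 5.0]
q_target = [3.0 1.0;
            0.5 0.2]
ddqn_values(q_online, q_target)  # → [0.5, 0.2], not the maxima [3.0, 1.0]
```
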

MachineLearningCourse.Lecture08.demo (Function)
demo(algorithm=:DDQN; max_episodes=100_000, plot=true)

Run a DQN training demo on Breakout environment.

Arguments

  • algorithm=:DDQN: Selected algorithm (use :DQN for standard DQN)
  • max_episodes=100_000: Maximum number of training episodes
  • plot=true: Whether to display live plots during training

Returns

  • (q_network, logger): Trained network and logger

Usage

# With live plotting
q_network, logger = Lecture08.demo(max_episodes=1000, plot=true)

# DQN without live plotting (faster training)
q_network, logger = Lecture08.demo(:DQN, max_episodes=1000, plot=false)
Lecture08.create_plot!(logger)  # Generate final plot
using Plots
plot!(logger.plot, size=(800, 600))  # Resize plot
savefig(logger.plot, "training.png")
MachineLearningCourse.Lecture08.load (Method)
load(filepath::String) -> Flux.Chain

Load a trained DQN network from file.

Arguments

  • filepath: Path to saved network file (.jld2)

Returns

  • Loaded network (Flux.Chain)

Example

network = load("model.jld2")
MachineLearningCourse.Lecture08.sample (Method)
sample(replay_buffer::ReplayBuffer, batch_size::Int) -> Vector{StateTransition}

Randomly sample a batch of experiences from the replay buffer.

This implements the experience replay mechanism from the DQN paper. By randomly sampling past experiences, we break the temporal correlations that would otherwise make learning unstable.

Arguments

  • replay_buffer: Circular buffer containing past experiences
  • batch_size: Number of experiences to sample

Returns

  • Vector of StateTransition objects for training
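
The mechanism amounts to drawing batch_size uniformly random indices into the buffer; a stdlib-only sketch (the actual ReplayBuffer type differs):

```julia
# Uniform sampling (with replacement) from stored experiences: drawing at
# random, rather than taking the most recent steps, decorrelates the batch.
sample_batch(buffer::Vector, batch_size::Int) =
    [buffer[rand(1:length(buffer))] for _ in 1:batch_size]

experiences = collect(1:1_000)         # stand-in for stored transitions
batch = sample_batch(experiences, 32)  # 32 randomly chosen experiences
```
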
MachineLearningCourse.Lecture08.save (Method)
save(network, filepath::String)

Save a trained DQN network to file using JLD2.

Arguments

  • network: Trained network (Flux.Chain)
  • filepath: Path to save file (should end with .jld2)

Example

save(network, "model.jld2")
MachineLearningCourse.Lecture08.train_step! (Method)
train_step!(q_network, target_network, optimizer, replay_buffer, batch_size, discount_factor, env_actions) -> Float32

Perform one training step of the DQN algorithm.

Arguments

  • q_network: Main Q-network being trained
  • target_network: Target network for stable Q-learning targets
  • optimizer: Flux optimizer (typically Adam)
  • replay_buffer: Buffer of past experiences
  • batch_size: Number of experiences to learn from
  • discount_factor: γ (gamma) - how much to value future rewards
  • env_actions: Available actions in the environment

Returns

  • Loss value for monitoring training progress
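
At its core, one such step builds the Bellman targets y = r + γ·max_a Q_target(s′, a), dropping the bootstrap term on terminal transitions, then regresses the online network's Q-values toward them. A stdlib-only sketch of the target computation (shapes are assumptions):

```julia
# Bellman targets for a batch; q_next holds target-network Q-values for
# the next states, one column per transition.
function bellman_targets(rewards::Vector{Float64}, q_next::Matrix{Float64},
                         terminal::Vector{Bool}; γ=0.99)
    return [rewards[j] + (terminal[j] ? 0.0 : γ * maximum(q_next[:, j]))
            for j in eachindex(rewards)]
end

rewards  = [1.0, 0.0]
q_next   = [2.0 4.0;
            3.0 1.0]
terminal = [false, true]                   # second transition ended the game
y = bellman_targets(rewards, q_next, terminal)  # ≈ [1.0 + 0.99 * 3.0, 0.0]
```

The squared error between these targets and the online network's predictions for the actions actually taken is what the optimizer minimizes.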