Lecture 09 - Policy gradient methods for the Breakout game
MachineLearningCourse.Lecture09 — Module
Lecture09 — Policy gradient methods for the Breakout game.
Available Functions
demo_reinforce(): Train agent using REINFORCE algorithm
reinforce_action(): Get action from trained REINFORCE policy
Usage
using MachineLearningCourse
policy, logger = Lecture09.demo_reinforce(max_episodes=500)
MachineLearningCourse.Lecture09.ActorCritic — Method
ActorCritic(env; kwargs...) -> Flux.Chain
Train an agent using the Actor-Critic algorithm.
Actor-Critic combines policy gradients (actor) with value function learning (critic). Updates are performed at each step using TD error as the advantage estimate.
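The TD-error advantage estimate described above can be sketched in a few lines. This is an illustrative helper, not part of the module; the names `td_error`, `V_s`, and `V_s′` are assumptions for the sketch.

```julia
# Sketch of the one-step TD error used as the advantage estimate:
# δ = r + γ·V(s′) − V(s), where V is the critic's value estimate.
# `done` zeroes the bootstrap term at episode termination.
function td_error(r, V_s, V_s′, γ; done=false)
    target = done ? r : r + γ * V_s′
    return target - V_s
end
```

The actor's log-probability gradient is then weighted by this δ, while the critic is regressed toward the same target.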
Arguments
env: Environment implementing CommonRLInterface
hidden_layers=[64, 32]: Architecture of hidden layers for both actor and critic
η=1e-4: Learning rate for actor
η_critic=1e-3: Learning rate for critic
γ=0.99: Discount factor for TD error
T=20_000: Maximum steps per episode
max_episodes=1000: Maximum number of episodes to train
batch_size=32: Number of steps to collect before updating networks
callback=EpisodeLogger(): Function called after each episode
Returns
- Trained policy network (Flux.Chain)
Example
env = BreakoutEnv()
policy = ActorCritic(env, max_episodes=1000)
MachineLearningCourse.Lecture09.REINFORCE — Method
REINFORCE(env; kwargs...) -> Flux.Chain
Train an agent using the REINFORCE policy gradient algorithm.
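REINFORCE weights each step's log-probability gradient by the discounted return from that step onward. A minimal sketch of the return computation, assuming `rewards` is the vector of per-step rewards from one episode (the helper name `discounted_returns` is an assumption, not part of the module):

```julia
# Compute G_t = Σ_{k≥t} γ^(k−t) · r_k for every step t,
# iterating backward so each return reuses the next one.
function discounted_returns(rewards::Vector{Float64}, γ::Float64)
    G = similar(rewards)
    acc = 0.0
    for t in length(rewards):-1:1
        acc = rewards[t] + γ * acc
        G[t] = acc
    end
    return G
end
```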
Arguments
env: Environment implementing CommonRLInterface
hidden_layers=[64, 32]: Architecture of hidden layers
η=1e-3: Learning rate
γ=0.99: Discount factor for returns
T=20_000: Maximum steps per episode
max_episodes=1000: Maximum number of episodes to train
callback=EpisodeLogger(): Function called after each episode
Returns
- Trained policy network (Flux.Chain)
Example
env = BreakoutEnv()
policy = REINFORCE(env, max_episodes=1000)
MachineLearningCourse.Lecture09.demo — Function
demo(algorithm=:REINFORCE; max_episodes=100_000, plot=true)
Run policy gradient training demo on the Breakout environment.
Arguments
algorithm=:REINFORCE: Algorithm to use (:REINFORCE or :ActorCritic)
max_episodes=100_000: Maximum number of training episodes
plot=true: Whether to display live plots during training
Returns
(policy, logger): Trained network and logger
Usage
# REINFORCE (default)
policy, logger = Lecture09.demo(max_episodes=500)
# Actor-critic
policy, logger = Lecture09.demo(algorithm=:ActorCritic, max_episodes=500)
MachineLearningCourse.Lecture09.policy_agent — Method
policy_agent(game_state::GameState, policy) -> Any
Get an action from a discrete policy network for Breakout.
Arguments
game_state: Current Breakout game state
policy: Trained policy network (Flux.Chain with softmax output)
Returns
- Discrete action from Breakout action space
MachineLearningCourse.Lecture09.sample_action — Method
sample_action(probs::Vector)
Sample an action from a probability distribution using the cumulative distribution.
Arguments
probs: Action probabilities (should sum to 1.0)
Returns
- Action index (Int) sampled according to probabilities
Example
probs = [0.1, 0.7, 0.2] # Action probabilities
action_idx = sample_action(probs) # Returns 1, 2, or 3
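The cumulative-distribution sampling mentioned above can be sketched as follows. This is a hypothetical re-implementation for illustration (named `sample_action_sketch` to avoid clashing with the module's own function); it assumes `probs` sums to 1.

```julia
# Draw r ~ Uniform(0, 1) and walk the running (cumulative) sum of
# probabilities; return the first index whose cumulative mass covers r.
function sample_action_sketch(probs::Vector{Float64})
    r = rand()
    c = 0.0
    for (i, p) in enumerate(probs)
        c += p
        r <= c && return i
    end
    return length(probs)  # guard against floating-point round-off
end
```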