I'm relatively new to Rust and was recently looking for a project to challenge the meager skills I have. I decided to tackle the Cart Pole Problem not because of my extensive experience with Bevy or dfdx, but because of my lack thereof. This post does not show best practices or code that should necessarily be copied; it shows a quick, easy, and fun way to solve the cart pole problem using Bevy and dfdx.
I have never built a game, but I found Bevy, one of Rust's more popular ECS (entity component system) game engines, amazing to work with. If you haven't heard of Bevy, I highly encourage checking it out; it has an awesome community: https://bevyengine.org/
Dfdx is a simple deep learning library in Rust. It has a bunch of awesome features that make deep learning a breeze. Once again, if you haven't seen it before, I would check it out: https://github.com/coreylowman/dfdx
The premise is simple. You control a cart on a track that only allows horizontal movement. On top of the cart is a pole attached by a hinge that limits rotation to the z axis (it can only tip left and right). The goal is to keep the pole upright as long as possible. The modern version of the Cart Pole ends after 500 steps (approximately 10 seconds according to OpenAI).
For those interested in learning more, please check out OpenAI's wiki. Please note that we are playing v1 (CartPole-v1).
Deep Q-Learning
There are a multitude of ways to solve the Cart Pole Problem. From my understanding, most of them use some variation of Reinforcement Learning, where the AI learns from its past attempts. In our case, we will be using Deep Q-Learning, or technically Double Deep Q-Learning, since we will have two models: a target model and one actually playing the game.
While building a deep reinforcement learning model from scratch would make enough content for a post of its own, dfdx provides a number of luxuries that greatly speed up the process. Using dfdx we can abstract away the nitty-gritty of the forward and backward passes and of the training loop. The actual capabilities of the library extend far beyond what it is used for here.
I won't go into the details of Q-Learning, as that is far outside the scope of this post, but for anyone interested I highly recommend Hugging Face's deep reinforcement learning class: https://huggingface.co/blog/deep-rl-dqn
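Still, it helps to keep the core idea in mind: the network is trained to predict, for each action, the reward it will collect plus the discounted value of the best action in the next state, where that next-state value comes from the frozen target network. A minimal, framework-free sketch of that target (my own illustration, not code from this project):
// Q-learning target with a target network:
// y = reward + gamma * max over Q_target(next_state),
// or just the reward when the episode has ended (no next state).
fn q_target(reward: f32, gamma: f32, next_q_values: Option<&[f32]>) -> f32 {
    match next_q_values {
        None => reward,
        Some(qs) => reward + gamma * qs.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
    }
}
You will see exactly this shape again later in the train function, where NEXT_STATE_DISCOUNT plays the role of gamma.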
use bevy::prelude::*;
use bevy::window::PresentMode;

fn main() {
App::new()
.insert_resource(WindowDescriptor {
title: "Cart Pole".to_string(),
present_mode: PresentMode::AutoVsync,
..default()
})
.add_plugins(DefaultPlugins)
.add_startup_system(add_camera)
.add_system(size_scaling)
.add_startup_system(add_cart_pole)
.add_startup_system(add_model.exclusive_system())
.add_system(step)
.run();
}
Important things to note:
DefaultPlugins includes a number of systems. Most importantly for us, it includes the actual game loop and the window so we can see our game. We don't need all of the systems included in DefaultPlugins, but for simplicity's sake, we will leave them.
We also add five systems of our own: add_camera, size_scaling (handles sprite scaling), add_cart_pole, add_model, and step.
The add_camera system adds a camera so we can see the game. The size_scaling system is code taken from one of Bevy's tutorials that scales sprites as the window size changes. Neither of these two is worth viewing here.
The add_cart_pole system adds the cart and pole entities to our world.
let cart_handle = asset_server.load("cart.png");
let pole_handle = asset_server.load("pole.png");
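// Spawn the cart: a sprite with a Velocity component and a Size used by the scaling system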
commands
.spawn_bundle(SpriteBundle {
sprite: Sprite {
custom_size: Some(Vec2::new(1., 1.)),
..default()
},
texture: cart_handle,
transform: Transform {
translation: Vec3::new(0., 0., 0.),
scale: Vec3::new(1., 1., 1.),
..default()
},
..default()
})
.insert(Cart)
.insert(Velocity::default())
.insert(Size {
width: 0.6,
height: 0.3,
});
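// Spawn the pole, anchored at its bottom center so it pivots around the hinge on the cart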
commands
.spawn_bundle(SpriteBundle {
sprite: Sprite {
anchor: sprite::Anchor::BottomCenter,
custom_size: Some(Vec2::new(1., 1.)),
..default()
},
texture: pole_handle,
transform: Transform {
translation: Vec3::new(0., 0., 1.),
scale: Vec3::new(1., 1., 1.),
..default()
},
..default()
})
.insert(Pole)
.insert(Velocity::default())
.insert(Size {
width: 0.1,
height: 1.,
});
This is a pretty standard way to add sprites in Bevy; the only important things to note here are the Velocity components we add to both entities. The Velocity components are used in the step system when calculating the movement of the cart and pole. For those curious, the Size components are used in the size_scaling system, but nowhere else.
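The component definitions aren't shown here, but based on how they are used above, a minimal version would look roughly like this (my assumption, not the project's exact code):
use bevy::prelude::*;

// Marker components used to query for the cart and pole entities.
#[derive(Component)]
struct Cart;

#[derive(Component)]
struct Pole;

// A single f32: linear velocity for the cart, angular velocity for the pole.
#[derive(Component, Default)]
struct Velocity(f32);

// Only read by the size_scaling system to keep sprites proportional to the window.
#[derive(Component)]
struct Size {
    width: f32,
    height: f32,
}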
The last entity we add to our world is the Model. This is the agent that is going to solve the Cart Pole Problem. Note that above, when we pass the add_model function into the app, we pass it as an exclusive system. Dfdx makes use of Rust's Rc type. This type cannot be sent safely between threads, so we mark this system as exclusive to let Bevy know that any system using the model must run on the main thread.
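I won't show add_model in full, but one possible shape for it, storing the Model as a non-send resource (just one way to keep a !Send value pinned to the main thread, and not necessarily how the actual project does it), would be something like the following:
// Sketch only: an exclusive system receives mutable access to the whole World.
fn add_model(world: &mut World) {
    // Stored as a non-send resource so systems that read it (for example step,
    // via NonSendMut<Model>) are scheduled on the main thread.
    // Older Bevy releases name this method insert_non_send.
    world.insert_non_send_resource(Model::default());
}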
The Model is defined by the following code.
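// A small MLP: 4 observation values in, 2 action values out (one per direction the cart can be pushed)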
type Mlp = (
Linear<4, 64>,
(Linear<64, 64>, ReLU),
(Linear<64, 32>, ReLU),
Linear<32, 2>,
);
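// (observation, action, reward, next observation; None when the episode ended)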
type Transition = ([f32; 4], i32, i32, Option<[f32; 4]>);
#[derive(Debug, Default)]
struct Model {
model: Mlp,
target: Mlp,
optimizer: Adam<Mlp>,
steps_since_last_merge: i32,
survived_steps: i32,
episode: i32,
epsilon: f32,
experience: Vec<Transition>,
}
There are probably more effective places to store items like survived_steps and the epsilon for choosing random actions, but for this simple example, I thought it good enough.
The final system to discuss is step. It contains the actual game logic: a single "step" in the game world.
let (mut cart_transform, mut cart_velocity) = q_cart
.get_single_mut()
.expect("Could not get the cart information");
let (mut pole_transform, mut pole_velocity) = q_pole
.get_single_mut()
.expect("Could not get the pole information");
let mut text = q_text
.get_single_mut()
.expect("Could not get the text with the episode info");
let observation = [
cart_transform.translation.x,
cart_velocity.0,
pole_transform.rotation.z,
pole_velocity.0,
];
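// Epsilon-greedy: with probability epsilon pick a random action, otherwise take the model's highest-value action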
let action = match model.epsilon > rand::random::<f32>() {
true => match rand::random::<bool>() {
true => 0,
false => 1,
},
false => {
let tensor_observation: Tensor1D<4> = TensorCreator::new(observation);
let prediction = model.model.forward(tensor_observation);
match prediction.data()[0] > prediction.data()[1] {
true => 0,
false => 1,
}
}
};
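// Decay epsilon each step, keeping a 5% exploration floor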
model.epsilon = (model.epsilon - EPSILON_DECAY).max(0.05);
// These calculations are directly from openai https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
let force = match action {
1 => FORCE_MAG * -1.,
_ => FORCE_MAG,
};
let costheta = pole_transform.rotation.z.cos();
let sintheta = pole_transform.rotation.z.sin();
let temp =
(force + POLEMASS_LENGTH * pole_velocity.0.powi(2) * sintheta) / TOTAL_MASS;
let thetaacc = (GRAVITY * sintheta - costheta * temp)
/ (LENGTH * (4.0 / 3.0 - MASS_POLE * (costheta * costheta) / TOTAL_MASS));
let xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS;
// Apply above calculations
cart_transform.translation.x += TAU * cart_velocity.0 * cart_transform.scale.x;
cart_velocity.0 += TAU * xacc;
pole_transform.rotation.z += TAU * pole_velocity.0;
pole_velocity.0 += TAU * thetaacc;
// Match the pole x to the cart x
pole_transform.translation.x = cart_transform.translation.x;
// Check if the episode is over
if pole_transform.rotation.z > THETA_THRESHOLD_RADIANS
|| pole_transform.rotation.z < -1. * THETA_THRESHOLD_RADIANS
|| (cart_transform.translation.x / cart_transform.scale.x) > X_THRESHOLD
|| (cart_transform.translation.x / cart_transform.scale.x) < -1. * X_THRESHOLD
|| model.survived_steps > 499
{
println!(
"RESETTING Episode: {} SURVIVED: {}",
model.episode, model.survived_steps,
);
// Reset cart and pole variables just like openai does
let mut rng = rand::thread_rng();
cart_velocity.0 = rng.gen_range(-0.05..0.05);
pole_velocity.0 = rng.gen_range(-0.05..0.05);
cart_transform.translation.x = rng.gen_range(-0.05..0.05);
pole_transform.translation.x = cart_transform.translation.x;
pole_transform.rotation.z = rng.gen_range(-0.05..0.05);
// Update the latest episode survived text
text.sections[0].value = format!(
"Episode: {} - Survived: {}",
model.episode, model.survived_steps
);
// Reset the survived_steps, increment episode count, and push_experience
model.survived_steps = 0;
model.episode += 1;
model.push_experience((observation, action, 0, None));
} else {
model.survived_steps += 1;
let next_observation = [
cart_transform.translation.x,
cart_velocity.0,
pole_transform.rotation.z,
pole_velocity.0,
];
model.push_experience((observation, action, 1, Some(next_observation)));
}
// Train if we have the necessary experience
if model.experience.len() > BATCH_SIZE {
model.train();
}
// Merge the target model after a certain number of steps
if model.steps_since_last_merge > 10 {
model.target = model.model.clone();
model.steps_since_last_merge = 0;
} else {
model.steps_since_last_merge += 1;
}
While it seems like a lot (and could be broken down into smaller steps using event emitters), the actual logic is fairly simple, and anyone who has worked with reinforcement learning should recognize the pattern.
The first thing we do is grab the cart's and pole's position and velocity. Using these variables, we create the current step's observation. We then select an action epsilon-greedily and, given that action, calculate the force applied to the cart and pole.
The calculations for the pole's rotation and the cart's position have been taken from OpenAI's cart pole. This project would not have been possible without that already clear code. For those curious, there is a paper outlining how to correctly calculate the force on the pole, but it is far above my head: Correct equations for the dynamics of the cart-pole system
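For reference, this is the form those calculations take (my transcription of OpenAI's implementation, assuming TOTAL_MASS is the cart mass plus the pole mass and POLEMASS_LENGTH is the pole mass times the pole half-length):
\[
\text{temp} = \frac{F + m_p \ell \,\dot\theta^{2} \sin\theta}{m_c + m_p}, \qquad
\ddot\theta = \frac{g\sin\theta - \cos\theta \cdot \text{temp}}{\ell\left(\frac{4}{3} - \frac{m_p\cos^{2}\theta}{m_c + m_p}\right)}, \qquad
\ddot{x} = \text{temp} - \frac{m_p \ell \,\ddot\theta \cos\theta}{m_c + m_p}
\]
The new positions and velocities are then obtained by Euler integration with step size TAU, exactly as in the code above.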
If the episode is over, we reset the cart and pole position, velocity, and rotation, reset survived_steps to 0, increment the episode counter, and push this observation with a reward of 0. If the episode is not over, we increment survived_steps and push the observation with a reward of 1.
After each step, once the experience buffer holds enough transitions to fill a batch, we train the Model.
pub fn train(&mut self) {
// Select the experience batch
let mut rng = rand::thread_rng();
let distribution = rand::distributions::Uniform::from(0..self.experience.len());
let experience: Vec<Transition> = (0..BATCH_SIZE)
.map(|_index| self.experience[distribution.sample(&mut rng)])
.collect();
// Get the models expected rewards
let observations: Vec<_> = experience.iter().map(|x| x.0.to_owned()).collect();
let observations: [[f32; 4]; BATCH_SIZE] = observations.try_into().unwrap();
let observations: Tensor2D<BATCH_SIZE, 4> = TensorCreator::new(observations);
let predictions = self.model.forward(observations.trace());
let actions_indices: Vec<_> = experience.iter().map(|x| x.1 as usize).collect();
let actions_indices: [usize; BATCH_SIZE] = actions_indices.try_into().unwrap();
let predictions: Tensor1D<BATCH_SIZE, dfdx::prelude::OwnedTape> =
predictions.select(&actions_indices);
// Get the targets expected rewards for the next_observation
// This could be optimized but I can't think of a easy way to do it without making this
// code much more gross, and since we are already far faster than we need to be, this is
// fine BUT when not rendering the window, this is the bottleneck in the program
let mut target_predictions: [f32; BATCH_SIZE] = [0.; BATCH_SIZE];
for (i, x) in experience.iter().enumerate() {
let target_prediction = match x.3 {
Some(next_observation) => {
let next_observation: Tensor1D<4> = TensorCreator::new(next_observation);
let target_prediction = self.target.forward(next_observation);
let target_prediction =
target_prediction.data()[0].max(target_prediction.data()[1]);
target_prediction * NEXT_STATE_DISCOUNT + experience[i].2 as f32
}
None => experience[i].2 as f32,
};
target_predictions[i] = target_prediction;
}
let target_predictions: Tensor1D<BATCH_SIZE> = TensorCreator::new(target_predictions);
// Get the loss and train the model
let loss = mse_loss(predictions, &target_predictions);
self.optimizer
.update(&mut self.model, loss.backward())
.expect("Oops, we messed up");
}
The train function is relatively simple: we select a random batch from the experience buffer, get our current model's value predictions for those states, compare them to the target model's predictions on the next_observation plus the current observation's reward, and train on the loss.
As mentioned in the comments above, this is the biggest bottleneck for performance. Specifically, because the next_observation can be None, we cannot run the entire batch of next_observations through the target model as easily as we ran the observations through the acting model. I am sure there are ways to improve this, for example by gathering only the non-terminal next observations (and their indices) into a single batch, as sketched below, but for the purposes of this post the performance is already more than adequate.
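As one illustration of that idea (a sketch under my own assumptions, not code from the project), the gather step could look like this, with the returned indices saying where each batched target prediction should be written back into target_predictions:
// Collect the non-terminal transitions' next observations, plus their positions
// in the sampled batch, so the target network can be run over them in one call.
fn gather_non_terminal(experience: &[Transition]) -> (Vec<usize>, Vec<[f32; 4]>) {
    let mut indices = Vec::new();
    let mut next_observations = Vec::new();
    for (i, transition) in experience.iter().enumerate() {
        if let Some(next_observation) = transition.3 {
            indices.push(i);
            next_observations.push(next_observation);
        }
    }
    (indices, next_observations)
}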
Every 10 steps we copy the current model over to the target model.
The last thing to do is create the sprites for the entities.
I've been having fun using NixOS as my daily system, so instead of creating the sprites in Pixelmator or an Adobe product, I used GIMP, an open source image editor.
The sprites themselves are very simple. They pretty closely match the traditional cart pole, except for the Ferris decal I added to the cart.
All said and done we end up with this:
To keep the video short, I only show it training till it survives for 500 steps the first time.
All code for this is publicly available on my GitHub.
Thanks for reading!