Sunday, October 23, 2022

I'm relatively new to Rust and was recently looking for a new project to challenge the meager skills I have. I decided to tackle the Cart Pole Problem not because of my extensive experience with Bevy or dfdx, but because of my lack thereof. This post does not show best practices or code that should be copied; it shows a quick, easy, and fun way to solve the Cart Pole Problem using Bevy and dfdx.

I have never built a game, but I found Bevy, one of Rust's more popular ECS (entity component system) game engines, amazing to work with. If you haven't heard of Bevy, I highly encourage checking it out; it has an awesome community: https://bevyengine.org/

Dfdx is a simple deep learning library in Rust. It has a bunch of awesome features that make deep learning a breeze. Once again, if you haven't seen it before, I would check it out: https://github.com/coreylowman/dfdx

What We Are Building

====================================================================================================================================================================================================================================================================================================================================

The Cart Pole Problem

====================================================================================================================================================================================================================================================================================================================================

The premise is simple. You control a cart on a track that only allows horizontal movement. On top of the cart is a pole, attached by a hinge that limits its rotation to the z axis (it can tip left and right). The goal is to keep the pole upright for as long as possible. The modern version of Cart Pole ends after 500 steps (approximately 10 seconds, according to OpenAI).
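The rules for when an episode ends can be made concrete. This std-only sketch uses the thresholds from OpenAI Gym's `cartpole.py` (the pole may lean at most 12 degrees and the cart may drift at most 2.4 units from the center) plus the 500-step limit of v1. `episode_over` and `MAX_STEPS` are illustrative names; the two threshold names match the constants the `step` system uses later, and `theta` is assumed to be the pole angle in radians.

```rust
/// Gym's cart-pole fails the episode once the pole leans more than 12 degrees...
pub const THETA_THRESHOLD_RADIANS: f32 = 12.0 * 2.0 * std::f32::consts::PI / 360.0;
/// ...or the cart moves more than 2.4 units from the center of the track.
pub const X_THRESHOLD: f32 = 2.4;
/// CartPole-v1 also truncates every episode after 500 steps.
pub const MAX_STEPS: u32 = 500;

/// Illustrative helper: true when the episode should end because the pole fell,
/// the cart left the track, or the step limit was reached.
pub fn episode_over(x: f32, theta: f32, steps: u32) -> bool {
    theta.abs() > THETA_THRESHOLD_RADIANS || x.abs() > X_THRESHOLD || steps >= MAX_STEPS
}
```

Using `abs()` folds each pair of "too far left / too far right" checks into a single comparison.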

For those interested in learning more, please check out OpenAI's wiki. Note that we are playing v1.

Deep Q-Learning

There are a multitude of ways to solve the Cart Pole Problem. From my understanding, most of these use some variation of Reinforcement Learning, where the AI learns from its past attempts. In our case, we will be using Deep Q-Learning with two models: a target model, and one that actually plays the game. This two-model setup is sometimes called Double Deep Q-Learning, but strictly speaking Double DQN also changes how the target is computed (the acting model picks the next action and the target model scores it); the code below uses the standard target-network form.
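For reference, here is where the second model enters the learning target in the two standard formulations. The notation is an assumption for illustration: $Q_\theta$ is the acting model, $Q_{\theta^-}$ the target model, $r$ the reward, $\gamma$ the discount (`NEXT_STATE_DISCOUNT` in the code), $s'$ the next observation, and $B$ the batch size. The `train` function later in this post computes the first form.

```latex
% DQN with a separate target model
y = r + \gamma \max_{a'} Q_{\theta^-}(s', a')

% Double DQN: the acting model chooses a', the target model evaluates it
y = r + \gamma \, Q_{\theta^-}\big(s', \arg\max_{a'} Q_{\theta}(s', a')\big)

% In both cases y = r when the episode ended, and the acting model minimizes
L(\theta) = \frac{1}{B} \sum_{i=1}^{B} \big(Q_{\theta}(s_i, a_i) - y_i\big)^2
```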

While building a deep reinforcement learning model from scratch would make enough content for a post of its own, dfdx provides a number of luxuries that greatly speed up this process. It abstracts away the nitty-gritty of the forward and backward passes through the network and of the training process. The library's capabilities extend far beyond what it is used for here.

I won't go into the details of Q-Learning, as that is far outside the scope of this post, but for anyone interested I highly recommend Hugging Face's deep reinforcement learning class: https://huggingface.co/blog/deep-rl-dqn

Some Code

====================================================================================================================================================================================================================================================================================================================================

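```rust
// Written against the Bevy 0.8 API (`WindowDescriptor` resource, `exclusive_system`)
use bevy::prelude::*;
use bevy::window::PresentMode;

fn main() {
    App::new()
        .insert_resource(WindowDescriptor {
            title: "Cart Pole".to_string(),
            present_mode: PresentMode::AutoVsync,
            ..default()
        })
        .add_plugins(DefaultPlugins)
        .add_startup_system(add_camera)
        .add_system(size_scaling)
        .add_startup_system(add_cart_pole)
        // Exclusive: the dfdx model uses `Rc` and is not thread-safe
        .add_startup_system(add_model.exclusive_system())
        .add_system(step)
        .run();
}
```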

Important things to note:

- Bevy's `DefaultPlugins` include a number of systems. Most importantly for us, they include the actual game loop and the window, so we can see our game. We don't need all of the systems included in `DefaultPlugins`, but for simplicity's sake, we will leave them.
- We add 5 custom systems: `add_camera`, `size_scaling` (handles sprite scaling), `add_cart_pole`, `add_model`, and `step`.

The `add_camera` system adds a camera so we can see the game. The `size_scaling` system is code taken from one of Bevy's tutorials that scales sprites as the window size changes. Neither is worth viewing here.

The `add_cart_pole` system adds the `cart` and `pole` entities to our world.

```rust
// Body of the `add_cart_pole` startup system (`commands` and `asset_server` are its parameters)
let cart_handle = asset_server.load("cart.png");
let pole_handle = asset_server.load("pole.png");
// The cart: tagged with `Cart`, given a `Velocity`, and a `Size` read by `size_scaling`
commands
    .spawn_bundle(SpriteBundle {
        sprite: Sprite {
            custom_size: Some(Vec2::new(1., 1.)),
            ..default()
        },
        texture: cart_handle,
        transform: Transform {
            translation: Vec3::new(0., 0., 0.),
            scale: Vec3::new(1., 1., 1.),
            ..default()
        },
        ..default()
    })
    .insert(Cart)
    .insert(Velocity::default())
    .insert(Size {
        width: 0.6,
        height: 0.3,
    });
// The pole: anchored at its bottom so it pivots at the hinge, and drawn above the cart (z = 1)
commands
    .spawn_bundle(SpriteBundle {
        sprite: Sprite {
            anchor: sprite::Anchor::BottomCenter,
            custom_size: Some(Vec2::new(1., 1.)),
            ..default()
        },
        texture: pole_handle,
        transform: Transform {
            translation: Vec3::new(0., 0., 1.),
            scale: Vec3::new(1., 1., 1.),
            ..default()
        },
        ..default()
    })
    .insert(Pole)
    .insert(Velocity::default())
    .insert(Size {
        width: 0.1,
        height: 1.,
    });
```

This is a pretty standard way to add sprites in Bevy; the only important things to note here are the `Velocity` components we add to both entities. They are used in the `step` system when calculating the movement of the cart and pole. For those curious, the `Size` components are used in the `size_scaling` system, but nowhere else.

The last thing we add to our world is the `Model`. This is the agent that is going to solve the Cart Pole Problem. Note that when we pass the `add_model` function to the app above, we pass it as an exclusive system. Dfdx makes use of Rust's `Rc` type, which cannot be sent safely between threads, so we register this system as exclusive to let Bevy know it must run on the main thread.
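A small std-only illustration of the Rust rule behind this: `std::thread::spawn` requires its closure, and everything it captures, to be `Send`. An `Arc` can cross that boundary because its reference count is updated atomically, while an `Rc` cannot. This sketches the language rule only, not how Bevy schedules systems; both function names are illustrative.

```rust
use std::rc::Rc;
use std::sync::Arc;
use std::thread;

/// Sums `values` on a worker thread. This compiles because `Arc<Vec<f32>>`
/// is `Send`: its reference count is updated atomically.
pub fn sum_on_worker(values: Vec<f32>) -> f32 {
    let shared = Arc::new(values);
    let for_worker = Arc::clone(&shared);
    let handle = thread::spawn(move || for_worker.iter().sum::<f32>());
    handle.join().expect("worker thread panicked")
}

/// The same data behind an `Rc` can only be used on the thread that owns it.
/// Moving an `Rc` into `thread::spawn` fails to compile with error E0277
/// ("`Rc<Vec<f32>>` cannot be sent between threads safely").
pub fn sum_on_current_thread(values: Vec<f32>) -> f32 {
    let local = Rc::new(values);
    let alias = Rc::clone(&local); // non-atomic count bump, fine on a single thread
    alias.iter().sum::<f32>()
}
```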

The `Model` is defined by the following code.

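```rust
// Q-network: 4 observation values in, one estimated value per action (left/right) out
type Mlp = (
    Linear<4, 64>,
    (Linear<64, 64>, ReLU),
    (Linear<64, 32>, ReLU),
    Linear<32, 2>,
);
// (observation, action, reward, next_observation); `None` means the episode ended
type Transition = ([f32; 4], i32, i32, Option<[f32; 4]>);

#[derive(Debug, Default)]
struct Model {
    model: Mlp,  // acting model, trained every step
    target: Mlp, // periodically synced copy used to compute targets
    optimizer: Adam<Mlp>,
    steps_since_last_merge: i32,
    survived_steps: i32,
    episode: i32,
    epsilon: f32, // probability of taking a random action
    experience: Vec<Transition>,
}
```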

There are probably more effective places to store items like `survived_steps` and the `epsilon` used for choosing random actions, but for this simple example, I thought this was good enough.

The Game Logic

The final system to discuss is `step`. It contains the actual game logic: one "step" in the game world.

```rust
// Body of the `step` system; `q_cart`, `q_pole`, `q_text`, and `model` are its parameters
use rand::Rng; // for `gen_range`

let (mut cart_transform, mut cart_velocity) = q_cart
    .get_single_mut()
    .expect("Could not get the cart information");
let (mut pole_transform, mut pole_velocity) = q_pole
    .get_single_mut()
    .expect("Could not get the pole information");
let mut text = q_text
    .get_single_mut()
    .expect("Could not get the text with the episode info");
// [cart position, cart velocity, pole angle, pole angular velocity]
let observation = [
    cart_transform.translation.x,
    cart_velocity.0,
    pole_transform.rotation.z,
    pole_velocity.0,
];
// Epsilon-greedy: explore with probability epsilon, otherwise act on the model's prediction
let action = if model.epsilon > rand::random::<f32>() {
    if rand::random::<bool>() { 0 } else { 1 }
} else {
    let tensor_observation: Tensor1D<4> = TensorCreator::new(observation);
    let prediction = model.model.forward(tensor_observation);
    if prediction.data()[0] > prediction.data()[1] { 0 } else { 1 }
};
model.epsilon = (model.epsilon - EPSILON_DECAY).max(0.05);
// These calculations are directly from OpenAI: https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
let force = match action {
    1 => -FORCE_MAG,
    _ => FORCE_MAG,
};
let costheta = pole_transform.rotation.z.cos();
let sintheta = pole_transform.rotation.z.sin();
// As in Gym, the centripetal term uses the squared angular velocity (theta_dot)
let temp = (force + POLEMASS_LENGTH * pole_velocity.0.powi(2) * sintheta) / TOTAL_MASS;
let thetaacc = (GRAVITY * sintheta - costheta * temp)
    / (LENGTH * (4.0 / 3.0 - MASS_POLE * (costheta * costheta) / TOTAL_MASS));
let xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS;
// Apply the calculations above (Euler integration)
cart_transform.translation.x += TAU * cart_velocity.0 * cart_transform.scale.x;
cart_velocity.0 += TAU * xacc;
pole_transform.rotation.z += TAU * pole_velocity.0;
pole_velocity.0 += TAU * thetaacc;
// Match the pole x to the cart x
pole_transform.translation.x = cart_transform.translation.x;
// Check if the episode is over
let cart_x = cart_transform.translation.x / cart_transform.scale.x;
if pole_transform.rotation.z.abs() > THETA_THRESHOLD_RADIANS
    || cart_x.abs() > X_THRESHOLD
    || model.survived_steps > 499
{
    println!(
        "RESETTING Episode: {} SURVIVED: {}",
        model.episode, model.survived_steps,
    );
    // Reset cart and pole variables just like OpenAI does
    let mut rng = rand::thread_rng();
    cart_velocity.0 = rng.gen_range(-0.05..0.05);
    pole_velocity.0 = rng.gen_range(-0.05..0.05);
    cart_transform.translation.x = rng.gen_range(-0.05..0.05);
    pole_transform.translation.x = cart_transform.translation.x;
    pole_transform.rotation.z = rng.gen_range(-0.05..0.05);
    // Update the text showing the latest episode and how long it survived
    text.sections[0].value = format!(
        "Episode: {} - Survived: {}",
        model.episode, model.survived_steps
    );
    // Reset survived_steps, increment the episode count, and push the terminal transition
    model.survived_steps = 0;
    model.episode += 1;
    model.push_experience((observation, action, 0, None));
} else {
    model.survived_steps += 1;
    let next_observation = [
        cart_transform.translation.x,
        cart_velocity.0,
        pole_transform.rotation.z,
        pole_velocity.0,
    ];
    model.push_experience((observation, action, 1, Some(next_observation)));
}
// Train once we have more than a batch of experience
if model.experience.len() > BATCH_SIZE {
    model.train();
}
// Copy the acting model into the target model every few steps
if model.steps_since_last_merge > 10 {
    model.target = model.model.clone();
    model.steps_since_last_merge = 0;
} else {
    model.steps_since_last_merge += 1;
}
```

While it seems like a lot (and could be broken into smaller systems using events), the actual logic is fairly simple, and anyone who has worked with reinforcement learning should recognize the pattern.

The first thing we do is grab the cart and pole `position` and `velocity`. Using these variables, we create the current step's `observation`. We then choose an action epsilon-greedily and, given that `action`, calculate the `force` applied to the cart.
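The action choice and the epsilon decay from `step` can be sketched on their own, without Bevy or dfdx. To keep the sketch deterministic and dependency-free, the random draws are passed in as arguments (`explore_draw`, a uniform sample in `[0, 1)`, and `coin`) instead of calling the `rand` crate. `EPSILON_DECAY` here is a placeholder value, since the post does not state the one it used; the 0.05 floor matches the `step` code.

```rust
/// Hypothetical decay per step (the post's actual value is not stated).
pub const EPSILON_DECAY: f32 = 0.0001;
/// Exploration never drops below this, matching `.max(0.05)` in `step`.
pub const EPSILON_MIN: f32 = 0.05;

/// Epsilon-greedy choice between the two cart-pole actions.
/// `explore_draw` is a uniform sample in [0, 1) and `coin` a fair coin flip,
/// both supplied by the caller.
pub fn choose_action(q_values: [f32; 2], epsilon: f32, explore_draw: f32, coin: bool) -> usize {
    if explore_draw < epsilon {
        // Explore: pick either action at random.
        if coin { 0 } else { 1 }
    } else if q_values[0] > q_values[1] {
        // Exploit: take the action with the higher predicted value.
        0
    } else {
        1
    }
}

/// Linear decay with a floor, applied once per step.
pub fn decay_epsilon(epsilon: f32) -> f32 {
    (epsilon - EPSILON_DECAY).max(EPSILON_MIN)
}
```

Early on, `epsilon` near 1 makes almost every action random; as it decays toward the floor, the agent mostly follows its own predictions while still exploring occasionally.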

The calculations for the pole's `rotation` and the cart's `position` are taken from OpenAI's cart pole implementation. This project would not have been possible without that already clear code. For those curious, there is a paper outlining how to correctly calculate the forces on the pole, but it is far above my head: "Correct equations for the dynamics of the cart-pole system".
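For readers who want the dynamics without the Bevy transforms, here is a std-only sketch of the same Euler update, using the constants and equations from OpenAI Gym's `cartpole.py`. The state is `[x, x_dot, theta, theta_dot]`, and `physics_step` is an illustrative name. Two details worth noticing: the centripetal term uses the squared angular velocity `theta_dot`, and in Gym action `1` pushes the cart right.

```rust
// Constants from OpenAI Gym's cartpole.py.
pub const GRAVITY: f32 = 9.8;
pub const MASS_CART: f32 = 1.0;
pub const MASS_POLE: f32 = 0.1;
pub const TOTAL_MASS: f32 = MASS_CART + MASS_POLE;
pub const LENGTH: f32 = 0.5; // half the pole's length
pub const POLEMASS_LENGTH: f32 = MASS_POLE * LENGTH;
pub const FORCE_MAG: f32 = 10.0;
pub const TAU: f32 = 0.02; // seconds between state updates

/// Advances `[x, x_dot, theta, theta_dot]` by one explicit Euler step.
/// Action 1 pushes the cart right, action 0 pushes it left (Gym's convention).
pub fn physics_step(state: [f32; 4], action: usize) -> [f32; 4] {
    let [x, x_dot, theta, theta_dot] = state;
    let force = if action == 1 { FORCE_MAG } else { -FORCE_MAG };
    let (sintheta, costheta) = theta.sin_cos();

    // Squared angular velocity (theta_dot), not the angle itself.
    let temp = (force + POLEMASS_LENGTH * theta_dot * theta_dot * sintheta) / TOTAL_MASS;
    let thetaacc = (GRAVITY * sintheta - costheta * temp)
        / (LENGTH * (4.0 / 3.0 - MASS_POLE * costheta * costheta / TOTAL_MASS));
    let xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS;

    // Positions advance with the old velocities, then velocities with the accelerations.
    [
        x + TAU * x_dot,
        x_dot + TAU * xacc,
        theta + TAU * theta_dot,
        theta_dot + TAU * thetaacc,
    ]
}
```

Pushing the cart one way makes the pole start tipping the other way, which is exactly the reaction the agent has to learn to exploit.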

If the `episode` is over, we reset the cart and pole `position`, `velocity`, and `rotation`, reset `survived_steps` to 0, increment the `episode` counter, and push this `observation` with a reward of 0 and no next observation. If the episode is not over, we increment `survived_steps` and push the `observation` with a reward of 1.
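`push_experience` is called in `step` but not shown in the post, where experience lives in a plain `Vec`. A common alternative is a bounded buffer that drops the oldest transition once full, so the training samples stay reasonably recent. The sketch below assumes that design; `ReplayBuffer` and its `capacity` are hypothetical, and the tuple layout matches the post's `Transition`.

```rust
use std::collections::VecDeque;

/// (observation, action, reward, next_observation); `None` marks the end of an episode.
pub type Transition = ([f32; 4], i32, i32, Option<[f32; 4]>);

/// Hypothetical bounded replay buffer that evicts the oldest transition when full.
pub struct ReplayBuffer {
    capacity: usize,
    items: VecDeque<Transition>,
}

impl ReplayBuffer {
    /// Creates an empty buffer holding at most `capacity` transitions.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            items: VecDeque::with_capacity(capacity),
        }
    }

    /// Stores a transition, dropping the oldest one if the buffer is already full.
    pub fn push_experience(&mut self, transition: Transition) {
        if self.capacity == 0 {
            return;
        }
        if self.items.len() == self.capacity {
            self.items.pop_front();
        }
        self.items.push_back(transition);
    }

    pub fn len(&self) -> usize {
        self.items.len()
    }

    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Returns the transition at `index` (0 is the oldest), e.g. for random sampling.
    pub fn get(&self, index: usize) -> Option<&Transition> {
        self.items.get(index)
    }
}
```

A `VecDeque` makes evicting from the front O(1), whereas `Vec::remove(0)` would shift every remaining element.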

Training The Model

After each step, as long as we have collected more than `BATCH_SIZE` transitions, we train the `Model`.

```rust
pub fn train(&mut self) {
    use rand::distributions::{Distribution, Uniform};
    // Select a random batch of experience (sampled with replacement)
    let mut rng = rand::thread_rng();
    let distribution = Uniform::from(0..self.experience.len());
    let experience: Vec<Transition> = (0..BATCH_SIZE)
        .map(|_| self.experience[distribution.sample(&mut rng)])
        .collect();
    // Get the acting model's predicted values for the actions that were taken
    let observations: Vec<_> = experience.iter().map(|x| x.0).collect();
    let observations: [[f32; 4]; BATCH_SIZE] = observations.try_into().unwrap();
    let observations: Tensor2D<BATCH_SIZE, 4> = TensorCreator::new(observations);
    let predictions = self.model.forward(observations.trace());
    let actions_indices: Vec<_> = experience.iter().map(|x| x.1 as usize).collect();
    let actions_indices: [usize; BATCH_SIZE] = actions_indices.try_into().unwrap();
    let predictions: Tensor1D<BATCH_SIZE, dfdx::prelude::OwnedTape> =
        predictions.select(&actions_indices);
    // Get the target model's expected rewards for each next_observation.
    // This could be optimized, but I can't think of an easy way to do it without making this
    // code much more gross, and since we are already far faster than we need to be, this is
    // fine. BUT when not rendering the window, this is the bottleneck in the program.
    let mut target_predictions: [f32; BATCH_SIZE] = [0.; BATCH_SIZE];
    for (i, x) in experience.iter().enumerate() {
        target_predictions[i] = match x.3 {
            Some(next_observation) => {
                let next_observation: Tensor1D<4> = TensorCreator::new(next_observation);
                let target_prediction = self.target.forward(next_observation);
                let best_next = target_prediction.data()[0].max(target_prediction.data()[1]);
                best_next * NEXT_STATE_DISCOUNT + x.2 as f32
            }
            // The episode ended, so the target is just the reward
            None => x.2 as f32,
        };
    }
    let target_predictions: Tensor1D<BATCH_SIZE> = TensorCreator::new(target_predictions);
    // Get the loss and train the model
    let loss = mse_loss(predictions, &target_predictions);
    self.optimizer
        .update(&mut self.model, loss.backward())
        .expect("Oops, we messed up");
}
```

The `train` function is relatively simple: we select a random batch from the `experience` buffer, get the current model's value `predictions` for the actions taken in those states, and compare them to a target: the reward plus the discounted best value the target model predicts for `next_observation` (or just the reward when the episode ended). We then train on the mean squared error between the two.

As mentioned in the comments above, this is the biggest performance bottleneck. Because `next_observation` can be `None`, we cannot run the whole batch of next observations through the target model as easily as we did with the acting model, so each one goes through separately. I am sure there are many ways to improve this, such as collecting only the non-`None` observations along with their indices, but for the purposes of this post, the performance is already more than adequate.
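One way to remove that per-transition loop, sketched without dfdx: gather the indices whose `next_observation` is `Some`, evaluate only those in a single batched call, then scatter the results back. The batched forward pass is represented by a caller-supplied closure (`batch_q`, hypothetical), because the post's `Tensor2D<BATCH_SIZE, 4>` has a compile-time batch size; feeding it a variable number of rows would need padding or whatever dynamic-shape support your dfdx version offers. `batched_targets` is likewise an illustrative name.

```rust
/// Computes reward + discount * max Q_target(next) for every transition,
/// calling `batch_q` once on only the non-terminal next observations.
/// `batch_q` stands in for a batched forward pass of the target model.
pub fn batched_targets<F>(
    rewards: &[f32],
    next_observations: &[Option<[f32; 4]>],
    discount: f32,
    batch_q: F,
) -> Vec<f32>
where
    F: Fn(&[[f32; 4]]) -> Vec<[f32; 2]>,
{
    assert_eq!(rewards.len(), next_observations.len(), "one entry per transition");

    // Gather: keep each non-terminal observation together with its original index.
    let (indices, batch): (Vec<usize>, Vec<[f32; 4]>) = next_observations
        .iter()
        .enumerate()
        .filter_map(|(i, &obs)| obs.map(|o| (i, o)))
        .unzip();

    // Terminal transitions keep just their reward.
    let mut targets = rewards.to_vec();

    if !batch.is_empty() {
        let q_values = batch_q(&batch);
        assert_eq!(q_values.len(), batch.len(), "one Q row per observation");
        // Scatter: add the discounted best next value back at the original index.
        for (&i, [q0, q1]) in indices.iter().zip(q_values) {
            targets[i] += discount * q0.max(q1);
        }
    }
    targets
}
```

The target model then runs once per batch instead of once per transition, at the cost of two small index vectors.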

Every few steps (once `steps_since_last_merge` exceeds 10), we copy the current `model` into `target`.

The Result

The last thing to do is create the sprites for the entities.

I've been having fun using NixOS as my daily system, so instead of creating the sprites in Pixelmator or Adobe, I used GIMP, an open source image editor.

The sprites themselves were very simple. They closely match the traditional cart pole, except for the Ferris decal I added to the cart.

All said and done, we end up with this:

To keep the video short, I only show it training until it survives for 500 steps the first time.

All code for this is publicly available on my GitHub.

Thanks for reading!

----------------------------------------


© 2024 Silas Marvin. No tracking, no cookies, just plain HTML and CSS.