Training DQN models for trading using PyTorch and Stable-Baselines3 (drl4t-05)

Xiaoguang Li
4 min read · Mar 28, 2023

After we created a custom Gym Env for trading in Create custom OpenAI Gym environment for Deep Reinforcement Learning (drl4t-04), it is time to start training our first Deep Reinforcement Learning trading model.

Deep Q-Network (DQN)

Deep Q-Network (DQN) is a reinforcement learning algorithm that uses a deep neural network to approximate the Q-value function. The Q-value function takes a state-action pair and outputs the expected cumulative reward for taking that action in that state and then following the optimal policy.
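
In symbols, the definition above can be written as the Bellman optimality equation. The notation is an assumption for illustration and does not come from the original article: s is the current state, a the action, r the reward, s' the next state, and γ a discount factor between 0 and 1.

```latex
Q^{*}(s, a) = \mathbb{E}\left[\, r + \gamma \max_{a'} Q^{*}(s', a') \;\middle|\; s, a \,\right]
```

DQN trains a network Q_θ so that Q_θ(s, a) moves toward the right-hand side of this equation, estimated from sampled transitions.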

The basic idea behind DQN is to estimate the Q-value function with a deep neural network rather than a table, which is not feasible for large or continuous state spaces. DQN also uses techniques such as experience replay and a separate target network to stabilize the learning process.
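
To make experience replay and the target network more concrete, here is a minimal sketch in plain Python. It is not how Stable-Baselines3 implements DQN internally; the names ReplayBuffer, td_target and toy_target_q are hypothetical, and the "target network" is just a stand-in function that returns one Q-value per action.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size store of past transitions, sampled uniformly at random."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest transition once it is full.
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


def td_target(reward, next_state, done, target_q, gamma=0.99):
    """Return r + gamma * max_a' Q_target(s', a'), or just r at episode end."""
    if done:
        return reward
    return reward + gamma * max(target_q(next_state))


def toy_target_q(state):
    # Hypothetical stand-in for the target network: one Q-value per action.
    return [0.0, 1.0, 0.5]


buffer = ReplayBuffer(capacity=100)
buffer.add(state=0, action=1, reward=2.0, next_state=1, done=False)
buffer.add(state=1, action=0, reward=-1.0, next_state=2, done=True)

# Sample a random batch and compute the regression target for each transition.
batch = buffer.sample(batch_size=2)
targets = [td_target(r, s2, d, toy_target_q, gamma=0.9) for (_, _, r, s2, d) in batch]
```

In a real DQN, the online network would be fitted to these targets, and the target network would be refreshed from the online network only every so often, which keeps the targets from shifting at every update.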

PyTorch

PyTorch is a deep learning framework that provides an easy-to-use interface for building and training neural networks. It allows users to define custom architectures and loss functions, and provides tools for debugging and visualization. PyTorch’s ease of use and flexibility make it a popular choice for both academic research and industrial-scale machine learning projects.

Stable-Baselines3

Stable-Baselines3 is a reinforcement learning library built on top of PyTorch. It provides a set of pre-implemented RL algorithms and simplifies the process of training and evaluating RL models. It also includes a variety of useful tools for monitoring training progress and visualizing model behavior.

We will use PyTorch and Stable-Baselines3 to train a DQN model. First, we need to install the Stable-Baselines3 library.

!pip install stable-baselines3

Train Model

Stable-Baselines3 wraps the agent training process and supports OpenAI Gym environments. A few lines of script are enough to train a model on a Gym environment.

First, we need to create an object of DRL4TEnv, as described in Create custom OpenAI Gym environment for Deep Reinforcement Learning (drl4t-04). For convenience, I packaged it in a file called “drl4t_env.py”.

from drl4t_data import download
from drl4t_env import DRL4TEnv

train_data, test_data = download('nyse.csv')
env = DRL4TEnv(train_data)

Now, we can use stable_baselines3 to train the trading model as simply as this:

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import DQN

model = DQN('MlpPolicy', DummyVecEnv([lambda: env]), learning_rate=0.001, verbose=1)
model.learn(total_timesteps=1000, log_interval=10)
model.save('nyse_dqn_model.pt')

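For clarity, creating and training the model are written as separate lines. The first line creates a new DQN model with several parameters:

  • MlpPolicy: the policy network type, here a multilayer perceptron (fully connected network)
  • DummyVecEnv(): wraps one or more environments in the vectorized interface that Stable-Baselines3 expects and steps them sequentially in the current process; here it wraps a single environment
  • learning_rate: the step size of the optimizer's gradient updates
  • verbose: verbosity level (1 for info messages)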

The second line trains the model with two parameters:

  • total_timesteps: the total number of environment steps to train for
  • log_interval: how often training statistics are logged; for off-policy algorithms such as DQN this is counted in episodes

Retrain Model

At the end of the training script above, we saved the trained model in a file named “nyse_dqn_model.pt”. We can reload the model from this file and continue training it with new data.

model = DQN.load('nyse_dqn_model.pt')

model.set_env(DummyVecEnv([lambda: env]))
model.learn(total_timesteps=1000, log_interval=10)
model.save('nyse_dqn_model.pt')

Validate Model

To validate the trained model, we will have the model make predictions based on each stock’s test-data separately, simulate trades based on the model’s predictions and count the results.

In the script below, the test data of one stock at a time is sent to the env, and the model plays a single episode on it. The info returned from each step is appended to that stock's log.

import pandas as pd

model = DQN.load('nyse_dqn_model.pt')

logs = []
for symbol, data in test_data.items():
    # Build an environment that contains only this stock's test data
    env = DRL4TEnv({ symbol: data })
    model.set_env(DummyVecEnv([lambda: env]))

    obs = env.reset()
    done = False

    # Run one episode, collecting the info dict returned by each step
    log = pd.DataFrame()
    while not done:
        action, _ = model.predict(obs)
        obs, _, done, info = env.step(action)
        log = pd.concat([log, pd.DataFrame(info, index=[info['Date']])])
    logs.append(log)

In the end, we summarize the validation results of all stocks:

val = pd.DataFrame()
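for log in logs:
    # Buy-and-hold value: spend the whole starting balance at the first close
    log['Benchmark'] = env.starting_balance / log['Close'].iloc[0] * log['Close']
    log['Policy'] = log['Total']
    # Sum the balances across stocks, aligned by date
    val = val.add(log[['Benchmark', 'Policy']], fill_value=0)
val.to_csv('nyse_dqn_val.csv')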

Here, the concept of benchmark is introduced: buying the stock with all available cash on the first day and holding it until the last day.
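
As a small worked example of the benchmark, the sketch below applies the same formula to a plain list of prices. The function name buy_and_hold and the prices and starting balance are hypothetical, chosen only for illustration. Like the formula in the script above, it allows fractional shares and ignores trading fees.

```python
def buy_and_hold(close_prices, starting_balance):
    """Value of a buy-and-hold position on each day in close_prices."""
    # All cash is spent at the first closing price.
    shares = starting_balance / close_prices[0]
    return [shares * price for price in close_prices]


# Hypothetical closing prices and starting balance, for illustration only.
closes = [50.0, 55.0, 45.0, 60.0]
benchmark = buy_and_hold(closes, starting_balance=10_000.0)
```

With these numbers, 10,000 buys 200 shares on the first day, so the benchmark value simply follows the price: 10,000, then 11,000, 9,000 and 12,000.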

Let’s check the results of model validation:

val

By comparing the returns of the simulated trades with the benchmark, we can determine whether the trading model is able to consistently make accurate trading decisions.
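
One simple way to make this comparison is to reduce each series to its total return over the test period and look at the difference. The sketch below assumes that approach; the helper total_return and all numbers are hypothetical and are not results from this article.

```python
def total_return(values):
    """Fractional change from the first to the last value of a series."""
    return values[-1] / values[0] - 1.0


# Hypothetical summed balances over three days, for illustration only.
benchmark_values = [210_000.0, 205_000.0, 199_500.0]
policy_values = [210_000.0, 212_000.0, 214_200.0]

benchmark_return = total_return(benchmark_values)  # about -0.05
policy_return = total_return(policy_values)        # about 0.02

# A positive excess return means the policy beat buy-and-hold over this period.
excess_return = policy_return - benchmark_return
```

With the val DataFrame from the script above, the inputs could be val['Benchmark'].tolist() and val['Policy'].tolist(). A single test period says little about consistency, so the same comparison is worth repeating over different periods.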

Data Visualization

Unfortunately, even though this data is very detailed, it is difficult to draw an intuitive conclusion from it. A chart can help here.

Data visualization is the graphical representation of data and information. It is an essential tool for quickly understanding the performance of machine learning trading models and identifying areas that need attention and improvement.

Data visualization can take many forms. Here, to compare the returns of the simulated trades with the benchmark, we can use a line chart.

The following code will read the saved model validation results and display them as a line chart. It also includes steps to normalize the data.

import matplotlib.pyplot as plt

val = pd.read_csv('nyse_dqn_val.csv', parse_dates=True, index_col=0)

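# Normalize both series by the benchmark's first value (Policy first, before Benchmark changes)
val['Policy'] /= val['Benchmark'].iloc[0]
val['Benchmark'] /= val['Benchmark'].iloc[0]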

ax = val[['Policy', 'Benchmark']].plot(title='Normalized Policy vs. Benchmark')
ax.set_xlabel('Date')
ax.set_ylabel('Normalized Balance')
plt.show()

The line chart below shows the summarized results of simulated trades on 21 stocks using a trading model trained over 100,000 time steps versus the benchmark.

The chart shows that over 100 days the combined price of these 21 stocks declined, while the trading model achieved a positive return. Good job!

Xiaoguang Li

Master of Science in Computational Data Analytics from Georgia Tech, Senior IT Consultant at Morgan Stanley