Stable baselines example

Welcome to a brief introduction to using gym-DSSAT with stable-baselines3. In this tutorial, we will assume familiarity with reinforcement learning and stable-baselines3. For background or more details about using stable-baselines3 for reinforcement learning, please take a look at the docs.


Running this example using Binder allows you to interact with it and change its behaviour by editing the Python script. Due to resource limitations on Binder, the number of timesteps has been reduced to 10_000.

Alternatively, if you wish to run this tutorial locally, you can copy the code block below into a Python script or download the notebook available here.

If you want to run the evaluation on Binder with a pre-trained PPO agent (600000 steps), you can click below:

From source


For the PPO agent, we need a long training run (here, 400000 timesteps), and to compare the agents, we need many evaluation episodes (here, 400). With this amount of training and evaluation, running this code can take about an hour to complete.

import gym
import gym_dssat_pdi
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# helpers for action normalization
def normalize_action(action_space_limits, action):
    """Normalize the action from [low, high] to [-1, 1].

    Args:
        action_space_limits: ``(low, high)`` pair giving the bounds of the
            original action interval.
        action: scalar, sequence (e.g. a list) or array of action values
            expressed in the original interval.

    Returns:
        The rescaled action. Sequence and array inputs come back as NumPy
        arrays. No guard is applied for ``high == low``.
    """
    low, high = action_space_limits
    # Coerce so that plain lists such as ``[0]`` take part in the arithmetic
    # even when the bounds are plain Python numbers rather than arrays.
    action = np.asarray(action)
    return 2.0 * ((action - low) / (high - low)) - 1.0

def denormalize_action(action_space_limits, action):
    """Denormalize the action from [-1, 1] to [low, high].

    Args:
        action_space_limits: ``(low, high)`` pair giving the bounds of the
            target action interval.
        action: scalar, sequence (e.g. a list) or array of action values
            expressed in ``[-1, 1]``.

    Returns:
        The action mapped linearly onto ``[low, high]``. Sequence and array
        inputs come back as NumPy arrays.
    """
    low, high = action_space_limits
    # Coerce so that plain lists are accepted even when the bounds are plain
    # Python numbers rather than arrays.
    action = np.asarray(action)
    return low + (0.5 * (action + 1.0) * (high - low))

# Wrapper for easy and uniform interfacing with SB3
class GymDssatWrapper(gym.Wrapper):
    def __init__(self, env):
        """Wrap `env` and set up the spaces and fallback state used by this wrapper.

        Caches the bounds of the wrapped env's 'anfer' action, replaces the
        action space with a one-dimensional Box in [-1, 1], and initialises
        the `last_info` / `last_obs` attributes.
        """
        super(GymDssatWrapper, self).__init__(env)

        self.action_low, self.action_high = self._get_action_space_bounds()

        # using a normalized action space
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype="float32")

        # using a vector representation of observations to allow
        # easily using SB3 MlpPolicy
        # NOTE(review): the Box(...) call below is cut off in this copy -- the
        # arguments after `low=0.0,` and the closing parenthesis are missing.
        # Restore them from the original tutorial before running.
        self.observation_space = gym.spaces.Box(low=0.0,

        # to avoid annoying problem with Monitor when episodes end and things are None
        self.last_info = {}
        self.last_obs = None

    def _get_action_space_bounds(self):
        """Return the ``(low, high)`` bounds of the wrapped env's 'anfer' action."""
        anfer_space = self.env.action_space['anfer']
        bounds = (anfer_space.low, anfer_space.high)
        return bounds

    def _format_action(self, action):
        return { 'anfer': action[0] }

    def _format_observation(self, observation):
        """Convert `observation` with the wrapped env's `observation_dict_to_array`."""
        to_array = self.env.observation_dict_to_array
        return to_array(observation)

    def reset(self):
        """Reset the wrapped env and return its observation via `_format_observation`."""
        initial_observation = self.env.reset()
        return self._format_observation(initial_observation)

    def step(self, action):
        """Apply a normalized action to the wrapped env and return the transition.

        Args:
            action: action in the normalized ``[-1, 1]`` space; only its first
                element is forwarded to the wrapped env.

        Returns:
            ``(observation, reward, done, info)`` with the observation passed
            through `_format_observation`. On the done step the last
            observation and info seen before termination are returned and the
            reward is 0.
        """
        # Rescale action from [-1, 1] to original action space interval
        denormalized_action = denormalize_action((self.action_low, self.action_high), action)
        formatted_action = self._format_action(denormalized_action)
        obs, reward, done, info = self.env.step(formatted_action)

        # handle `None`s in obs, reward, and info on done step
        if done:
            obs, reward, info = self.last_obs, 0, self.last_info
        else:
            # Only remember values from non-terminal steps; caching on the
            # done step would merely store the fallback back onto itself.
            self.last_obs = obs
            self.last_info = info

        formatted_observation = self._format_observation(obs)
        return formatted_observation, reward, done, info

    def close(self):
        """Close the wrapped env and hand back whatever its `close` returns."""
        close_result = self.env.close()
        return close_result

    # NOTE(review): the body of `seed` is missing in this copy, so the file
    # does not parse as-is; restore the original body before use.
    def seed(self, seed):

    # NOTE(review): the body of `__del__` is missing in this copy, so the file
    # does not parse as-is; restore the original body before use.
    def __del__(self):

# Create environment
env_args = {
    'mode': 'fertilization',
    'seed': 123,
    'random_weather': True,
}

env = GymDssatWrapper(gym.make('GymDssatPdi-v0', **env_args))

# Training arguments for PPO agent
ppo_args = {
    'gamma': 1,
    'learning_rate': 0.0003,
    'seed': 123,
}

# Create the agent
ppo_agent = PPO('MlpPolicy', env, **ppo_args)

# Train for 400k timesteps
print('Training PPO agent...')
# Without this call the agent evaluated later would be an untrained policy.
ppo_agent.learn(total_timesteps=400000)
print('Training done')

# Baseline agents for comparison
class NullAgent:
    """
    Agent always choosing to do no fertilization
    """

    def __init__(self, env):
        """Keep a reference to `env`, whose action bounds are read in `predict`."""
        self.env = env

    def predict(self, obs, state=None, episode_start=None, deterministic=None):
        """Return the normalized zero-fertilization action.

        `state`, `episode_start` and `deterministic` are accepted but ignored.

        Returns:
            ``(action, obs)``: the action as a float32 array, and `obs`
            passed back unchanged as the second element.
        """
        # The raw amount 0 is mapped into the normalized action space using
        # the bounds stored on the env.
        action = normalize_action((self.env.action_low, self.env.action_high), [0])
        return np.array([action], dtype=np.float32), obs

class ExpertAgent:
    """
    Simple agent using policy of choosing fertilization amount based on days after planting
    """

    # Default schedule: key is matched against int(obs[0][1]), value is the
    # fertilization amount applied on that step.
    fertilization_dic = {
        40: 27,
        45: 35,
        80: 54,
    }

    def __init__(self, env, normalize_action=False, fertilization_dic=None):
        """Store the env and optional overrides.

        Args:
            env: object exposing `action_low` and `action_high`.
            normalize_action: stored on the instance; `predict` does not
                consult it and always normalizes.
            fertilization_dic: optional schedule replacing the class-level
                default for this instance.
        """
        self.env = env
        self.normalize_action = normalize_action
        # Honour a caller-supplied schedule; otherwise the class default is used.
        if fertilization_dic is not None:
            self.fertilization_dic = fertilization_dic

    def _policy(self, obs):
        """Return a one-element list with the scheduled amount, or 0 if none is scheduled."""
        dap = int(obs[0][1])
        return [self.fertilization_dic.get(dap, 0)]

    def predict(self, obs, state=None, episode_start=None, deterministic=None):
        """Return the normalized scheduled action for `obs`.

        `state`, `episode_start` and `deterministic` are accepted but ignored.

        Returns:
            ``(action, obs)``: the action as a float32 array, and `obs`
            passed back unchanged as the second element.
        """
        action = self._policy(obs)
        action = normalize_action((self.env.action_low, self.env.action_high), action)

        return np.array([action], dtype=np.float32), obs

# evaluation and plotting functions
def evaluate(agent, n_episodes=10):
    """Run `agent` for `n_episodes` episodes on a fresh evaluation env.

    Args:
        agent: object accepted by `evaluate_policy` (exposes `predict`).
        n_episodes: number of evaluation episodes.

    Returns:
        The first element returned by `evaluate_policy` when called with
        ``return_episode_rewards=True``.
    """
    # Create eval env
    eval_args = {
        'mode': 'fertilization',
        'seed': 456,
        'random_weather': True,
    }
    env = Monitor(GymDssatWrapper(gym.make('GymDssatPdi-v0', **eval_args)))

    try:
        returns, _ = evaluate_policy(
            agent, env, n_eval_episodes=n_episodes, return_episode_rewards=True)
    finally:
        # Release the evaluation env even if the evaluation raises.
        env.close()

    return returns

def plot_results(labels, returns):
    """Draw one box per label from `returns` and save the figure as a PDF.

    Args:
        labels: names used as column headers, one per entry of `returns`.
        returns: sequences of values, paired with `labels` by position.
    """
    df = pd.DataFrame(dict(zip(labels, returns)))

    ax = sns.boxplot(data=df)
    ax.set_ylabel("evaluation output")
    # Write the file that the message below announces.
    ax.figure.savefig('results_sb3.pdf')
    print("\nThe result is saved in the current working directory as 'results_sb3.pdf'\n")

# Evaluate the three agents over 400 episodes each.
null_agent = NullAgent(env)
print('Evaluating Null agent...')
null_returns = evaluate(null_agent, n_episodes=400)

print('Evaluating PPO agent...')
ppo_returns = evaluate(ppo_agent, n_episodes=400)

expert_agent = ExpertAgent(env)
print('Evaluating Expert agent...')
expert_returns = evaluate(expert_agent, n_episodes=400)

# Plot the per-agent returns side by side.
labels = ['null', 'ppo', 'expert']
returns = [null_returns, ppo_returns, expert_returns]
plot_results(labels, returns)

# Dump the raw returns, one agent per line.
with open("eval_output.txt", 'w') as f:
    f.writelines([
        f"Null Agent : {null_returns}\r\n",
        f"PPO Agent : {ppo_returns}\r\n",
        f"Expert Agent : {expert_returns}\r\n",
    ])

# Cleanup