• Stable Baselines/用户向导/处理NaNs和infs

    Stable Baselines官方文档中文版 Github CSDN 尝试翻译官方文档,水平有限,如有错误万望指正

    在指定环境下训练模型的过程中,当遇到输入或者从RL模型中返回的NaNinf时,RL模型有完全崩溃的可能。

  • 原因和方式

    问题出现后,NaNsinfs不会崩溃,而是简单的通过训练传递,直到所有的浮点数收敛到NaNinf。这符合IEEE浮点运算标准(IEEE754),标准指出:

    可能出现的物种异常:

    • 无效的操作符($\sqrt{-1}$, inf$*$1, NaN mod 1, …)返回NaN
    • 除以0:
      • 如果运算对象非零(1/0, -2/0, …)返回 $\pm inf$
      • 如果运算对象是零(0/0)返回NaN
    • 上溢(指数太高而无法表示)返回$\pm inf$
    • 下溢(指数太低而无法表示)返回$0$
    • 不精确(以2为底时不能准确表示,例如1/5)返回四舍五入值(例如:assert (1/5) * 3 == 0.6000000000000001

    只有除以0会报错,其他方式只会静静传递。

    Python中,除以0会报如下错:ZeroDivisionError: float division by zero,其他会忽略。

    Numpy中默认警告:RuntimeWarning: invalid value encountered但不会停止代码。

    最差的情况,Tensorflow不会提示任何信息

    1. import tensorflow as tf
    2. import numpy as np
    3. print("tensorflow test:")
    4. a = tf.constant(1.0)
    5. b = tf.constant(0.0)
    6. c = a / b
    7. sess = tf.Session()
    8. val = sess.run(c) # this will be quiet
    9. print(val)
    10. sess.close()
    11. print("\r\nnumpy test:")
    12. a = np.float64(1.0)
    13. b = np.float64(0.0)
    14. val = a / b # this will warn
    15. print(val)
    16. print("\r\npure python test:")
    17. a = 1.0
    18. b = 0.0
    19. val = a / b # this will raise an exception and halt.
    20. print(val)

    不幸的是,大多数浮点运算都是用TensorflowNumpy处理的,这意味着当无效值出现时,你很可能得不到任何警告。

  • Numpy参数

    Numpy有方便处理无效值的方法:numpy.seterr,它是为Python进程定义的,决定它如何处理浮点型错误。

    1. import numpy as np
    2. np.seterr(all='raise') # define before your code.
    3. print("numpy test:")
    4. a = np.float64(1.0)
    5. b = np.float64(0.0)
    6. val = a / b # this will now raise an exception instead of a warning.
    7. print(val)

    不过这也会避免浮点数的溢出问题:

    1. import numpy as np
    2. np.seterr(all='raise') # define before your code.
    3. print("numpy overflow test:")
    4. a = np.float64(10)
    5. b = np.float64(1000)
    6. val = a ** b # this will now raise an exception
    7. print(val)

    不过无法避免传递问题:

    1. import numpy as np
    2. np.seterr(all='raise') # define before your code.
    3. print("numpy propagation test:")
    4. a = np.float64('NaN')
    5. b = np.float64(1.0)
    6. val = a + b # this will neither warn nor raise anything
    7. print(val)
  • Tensorflow参数

    Tensorflow会增加检查以侦测和处理无效值:tf.add_check_numerics_opstf.check_numerics,然而,他们会增加Tensorflow图表处理,增加运算时间。

    1. import tensorflow as tf
    2. print("tensorflow test:")
    3. a = tf.constant(1.0)
    4. b = tf.constant(0.0)
    5. c = a / b
    6. check_nan = tf.add_check_numerics_ops() # add after your graph definition.
    7. sess = tf.Session()
    8. val, _ = sess.run([c, check_nan]) # this will now raise an exception
    9. print(val)
    10. sess.close()

    这也会避免浮点数溢出问题:

    1. import tensorflow as tf
    2. print("tensorflow overflow test:")
    3. check_nan = [] # the list of check_numerics operations
    4. a = tf.constant(10)
    5. b = tf.constant(1000)
    6. c = a ** b
    7. check_nan.append(tf.check_numerics(c, "")) # check the 'c' operations
    8. sess = tf.Session()
    9. val, _ = sess.run([c] + check_nan) # this will now raise an exception
    10. print(val)
    11. sess.close()

    捕捉传播问题:

    1. import tensorflow as tf
    2. print("tensorflow propagation test:")
    3. check_nan = [] # the list of check_numerics operations
    4. a = tf.constant('NaN')
    5. b = tf.constant(1.0)
    6. c = a + b
    7. check_nan.append(tf.check_numerics(c, "")) # check the 'c' operations
    8. sess = tf.Session()
    9. val, _ = sess.run([c] + check_nan) # this will now raise an exception
    10. print(val)
    11. sess.close()
  • VecChecNan包装器

    为查明无效值源自何时何处,stable-baselines提出VecChecknan包装器。

    它会监控行动、观测、奖励,指明那种行动和观测导致了无效值以及从何处出现。

    1. import gym
    2. from gym import spaces
    3. import numpy as np
    4. from stable_baselines import PPO2
    5. from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan
    6. class NanAndInfEnv(gym.Env):
    7. """Custom Environment that raised NaNs and Infs"""
    8. metadata = {'render.modes': ['human']}
    9. def __init__(self):
    10. super(NanAndInfEnv, self).__init__()
    11. self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
    12. self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
    13. def step(self, _action):
    14. randf = np.random.rand()
    15. if randf > 0.99:
    16. obs = float('NaN')
    17. elif randf > 0.98:
    18. obs = float('inf')
    19. else:
    20. obs = randf
    21. return [obs], 0.0, False, {}
    22. def reset(self):
    23. return [0.0]
    24. def render(self, mode='human', close=False):
    25. pass
    26. # Create environment
    27. env = DummyVecEnv([lambda: NanAndInfEnv()])
    28. env = VecCheckNan(env, raise_exception=True)
    29. # Instantiate the agent
    30. model = PPO2('MlpPolicy', env)
    31. # Train the agent
    32. model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment.
  • RL模型超参数

    依据你的超参数,NaN可能更经常出现。一个极好的例子:https://github.com/hill-a/stable-baselines/issues/340

    要明白,虽然在大多数案例中默认超参数看起来可以跑通,不过在你的环境下很难最优。如果是这样,搞清楚每个超参数如何影响模型,以便你可以调参以得到稳定模型。或者,你可以尝试自动调参(参见RL Zoo)。

  • 数据集中的缺失值

    如果你的环境产生自外部数据集,确保数据集中不含NaNs。因为有时候数据集中会用NaNs代替缺失值。

    这里有一些关于如何查找NaNs的阅读材料:点击链接

    以及用其他方式查找缺失值:点击链接