1-4. Example: A Modeling Pipeline for Time Series Data

The COVID-19 epidemic in China has now lasted more than N months since it was first identified, and this disaster, said to have originated from the consumption of wild game, has affected everyone's life in many ways.
For some the impact has been financial, for some emotional, for some psychological, and for others it has shown up on the bathroom scale.
So when will the epidemic in China end? When will we get our freedom back?
This article uses TensorFlow 2.0 to build a time series RNN model and predict when the COVID-19 epidemic in China will end.


1. Prepare the Data

The dataset used in this article comes from tushare. The method for obtaining it follows the article below (data covering the first three months):
https://zhuanlan.zhihu.com/p/109556102

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import models, layers, losses, metrics, callbacks
```

```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

df = pd.read_csv("./data/covid-19.csv", sep="\t")
df.plot(x="date", y=["confirmed_num", "cured_num", "dead_num"], figsize=(10, 6))
plt.xticks(rotation=60)
```

[Figure: cumulative confirmed_num, cured_num, and dead_num plotted against date]

```python
dfdata = df.set_index("date")
dfdiff = dfdata.diff(periods=1).dropna()
dfdiff = dfdiff.reset_index("date")

dfdiff.plot(x="date", y=["confirmed_num", "cured_num", "dead_num"], figsize=(10, 6))
plt.xticks(rotation=60)
dfdiff = dfdiff.drop("date", axis=1).astype("float32")
```

[Figure: daily new confirmed_num, cured_num, and dead_num plotted against date]
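As a quick illustration of what `diff(periods=1)` does here (the numbers below are invented, not from the dataset): it converts cumulative totals into day-over-day increases, and `dropna()` discards the first row, which has no predecessor.

```python
# Invented numbers, for illustration only: cumulative totals -> daily increases
demo = pd.DataFrame({"confirmed_num": [100.0, 130.0, 170.0]},
                    index=["2020-01-01", "2020-01-02", "2020-01-03"])
print(demo.diff(periods=1).dropna())
#             confirmed_num
# 2020-01-02           30.0
# 2020-01-03           40.0
```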

```python
# Use the data in the 8-day window before a given day as input to predict that day's data
WINDOW_SIZE = 8

def batch_dataset(dataset):
    dataset_batched = dataset.batch(WINDOW_SIZE, drop_remainder=True)
    return dataset_batched

ds_data = tf.data.Dataset.from_tensor_slices(tf.constant(dfdiff.values, dtype=tf.float32)) \
    .window(WINDOW_SIZE, shift=1).flat_map(batch_dataset)

ds_label = tf.data.Dataset.from_tensor_slices(
    tf.constant(dfdiff.values[WINDOW_SIZE:], dtype=tf.float32))

# The dataset is small, so the whole training set can go into a single batch to improve performance
ds_train = tf.data.Dataset.zip((ds_data, ds_label)).batch(38).cache()
```
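As a sanity check on the pairing, each element of `ds_train` should hold one batch of 38 eight-day windows alongside the 38 next-day records they predict (the expected shapes below assume the dataset has 46 rows of daily increases, as in this example):

```python
# Optional sanity check: inspect the shapes of one (window, label) batch
for features, labels in ds_train.take(1):
    print(features.shape)  # expected: (38, 8, 3) -- 38 windows of 8 days x 3 series
    print(labels.shape)    # expected: (38, 3)    -- the day following each window
```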

2. Define the Model

There are three ways to build a model with the Keras API: use Sequential to build the model layer by layer, use the functional API to build a model of arbitrary structure, or subclass the Model base class to build a fully customized model.
Here we choose the functional API.

```python
# Since the daily new confirmed, cured, and dead counts can never be negative,
# the following structure is designed
class Block(layers.Layer):
    def __init__(self, **kwargs):
        super(Block, self).__init__(**kwargs)

    def call(self, x_input, x):
        x_out = tf.maximum((1 + x) * x_input[:, -1, :], 0.0)
        return x_out

    def get_config(self):
        config = super(Block, self).get_config()
        return config
```
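The Block layer treats the network's output x as a relative day-over-day change, applies it to the last day of the input window, and clamps the result at zero so no predicted count can go negative. A minimal check of that arithmetic, with invented tensor values:

```python
# Invented values: the last day of the window is [10, 5, 1]; relative changes of
# [0.2, -0.5, -2.0] give (1+x)*last_day = [12.0, 2.5, -1.0], and the -1.0 is clamped to 0
x_input_demo = tf.constant([[[1.0, 1.0, 1.0],
                             [10.0, 5.0, 1.0]]])  # shape (1, 2, 3)
x_demo = tf.constant([[0.2, -0.5, -2.0]])         # shape (1, 3)
print(Block()(x_input_demo, x_demo))              # [[12.   2.5  0. ]]
```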
```python
tf.keras.backend.clear_session()
x_input = layers.Input(shape=(None, 3), dtype=tf.float32)
x = layers.LSTM(3, return_sequences=True, input_shape=(None, 3))(x_input)
x = layers.LSTM(3, return_sequences=True, input_shape=(None, 3))(x)
x = layers.LSTM(3, return_sequences=True, input_shape=(None, 3))(x)
x = layers.LSTM(3, input_shape=(None, 3))(x)
x = layers.Dense(3)(x)

# Since the daily new confirmed, cured, and dead counts can never be negative,
# the following structure is used:
# x = tf.maximum((1 + x) * x_input[:, -1, :], 0.0)
x = Block()(x_input, x)
model = models.Model(inputs=[x_input], outputs=[x])
model.summary()
```
```
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, None, 3)]         0
_________________________________________________________________
lstm (LSTM)                  (None, None, 3)           84
_________________________________________________________________
lstm_1 (LSTM)                (None, None, 3)           84
_________________________________________________________________
lstm_2 (LSTM)                (None, None, 3)           84
_________________________________________________________________
lstm_3 (LSTM)                (None, 3)                 84
_________________________________________________________________
dense (Dense)                (None, 3)                 12
_________________________________________________________________
block (Block)                (None, 3)                 0
=================================================================
Total params: 348
Trainable params: 348
Non-trainable params: 0
_________________________________________________________________
```
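As a check on these numbers: each LSTM layer has 4 × ((input_dim + units) × units + units) = 4 × ((3 + 3) × 3 + 3) = 84 parameters, where the factor 4 covers the input, forget, output, and cell gates; the Dense layer has 3 × 3 weights plus 3 biases = 12; and the Block layer only rearranges its inputs, so it adds no parameters.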

3. Train the Model

There are generally three ways to train a model: the built-in fit method, the built-in train_on_batch method, or a custom training loop. Here we choose the most common and simplest one, the built-in fit method.
Note: recurrent neural networks can be hard to tune; you may need to try several different learning rates to get good results.

```python
# Custom loss function: squared error divided by the squared prediction target
class MSPE(losses.Loss):
    def call(self, y_true, y_pred):
        err_percent = (y_true - y_pred)**2 / (tf.maximum(y_true**2, 1e-7))
        mean_err_percent = tf.reduce_mean(err_percent)
        return mean_err_percent

    def get_config(self):
        config = super(MSPE, self).get_config()
        return config
```
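To make the loss concrete, here is a small worked check with invented values. Because each squared error is divided by the squared target, an error of 2 on a target of 10 counts the same as an error of 1 on a target of 5:

```python
# Invented values: (10-8)**2/10**2 = 0.04 and (5-6)**2/5**2 = 0.04, so the mean is 0.04
y_true = tf.constant([10.0, 5.0])
y_pred = tf.constant([8.0, 6.0])
print(MSPE(name="MSPE")(y_true, y_pred).numpy())  # 0.04
```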
```python
import os
import datetime

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=optimizer, loss=MSPE(name="MSPE"))

stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('data', 'autograph', stamp)

## In Python 3, pathlib is recommended for building paths that work across operating systems
# from pathlib import Path
# stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# logdir = str(Path('./data/autograph/' + stamp))

tb_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
# If the loss has not improved after 100 epochs, halve the learning rate
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.5, patience=100)
# If the loss has not improved after 200 epochs, stop training early
stop_callback = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=200)
callbacks_list = [tb_callback, lr_callback, stop_callback]

history = model.fit(ds_train, epochs=500, callbacks=callbacks_list)
```
```
Epoch 371/500
1/1 [==============================] - 0s 61ms/step - loss: 0.1184
Epoch 372/500
1/1 [==============================] - 0s 64ms/step - loss: 0.1177
Epoch 373/500
1/1 [==============================] - 0s 56ms/step - loss: 0.1169
...
Epoch 498/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0858
Epoch 499/500
1/1 [==============================] - 0s 54ms/step - loss: 0.0857
Epoch 500/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0858
```

4. Evaluate the Model

Evaluating a model generally requires a validation set or a test set. Since this example has very little data, we only visualize how the loss evolves on the training set.
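For reference, with more data a held-out validation split could be carved out of the windowed dataset before batching. A minimal sketch, assuming a chronological 30/8 split of the 38 (window, label) pairs (the split sizes are arbitrary illustrations, not part of the original pipeline):

```python
# Hypothetical sketch: hold out the most recent windows as a validation set,
# keeping chronological order as time series require
ds_pairs = tf.data.Dataset.zip((ds_data, ds_label))
ds_train_split = ds_pairs.take(30).batch(30).cache()
ds_valid_split = ds_pairs.skip(30).batch(8).cache()
# history = model.fit(ds_train_split, validation_data=ds_valid_split,
#                     epochs=500, callbacks=callbacks_list)
```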

```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt

def plot_metric(history, metric):
    train_metrics = history.history[metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.title('Training ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_" + metric])
    plt.show()
```

```python
plot_metric(history, "loss")
```

[Figure: training loss vs. epochs]

5. Use the Model

Here we use the model to predict when the epidemic will end, i.e. the first day on which the number of newly confirmed cases drops to 0.

```python
# dfresult records the existing data together with the epidemic data predicted from here on
dfresult = dfdiff[["confirmed_num", "cured_num", "dead_num"]].copy()
dfresult.tail()
```


```python
# Predict the daily increases over the next 100 days and append the results to dfresult
for i in range(100):
    arr_predict = model.predict(tf.constant(tf.expand_dims(dfresult.values[-38:, :], axis=0)))
    dfpredict = pd.DataFrame(tf.cast(tf.floor(arr_predict), tf.float32).numpy(),
                             columns=dfresult.columns)
    dfresult = dfresult.append(dfpredict, ignore_index=True)
```
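Note that the loop is autoregressive: each predicted day is appended to dfresult and becomes part of the 38-day input for the next prediction. Since dfresult holds daily increases, cumulative totals can be recovered by adding back the first observed day and taking a running sum; a small sketch (this assumes dfdata from the data-preparation step is still in scope and contains only the three count columns):

```python
# Hypothetical reconstruction of cumulative totals from the daily increases:
# dfdiff was built as dfdata.diff(), so first-day totals + cumsum undoes the differencing
dfcum = dfresult.cumsum() + dfdata.iloc[0].astype("float32")
print(dfcum.tail())
```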
```python
dfresult.query("confirmed_num==0").head()
# New confirmed cases drop to 0 from day 55 onward. Day 45 corresponds to March 10,
# so 10 days later, around March 20, new confirmed cases are predicted to reach 0.
# Note: this prediction is on the optimistic side
```


```python
dfresult.query("cured_num==0").head()
# Newly cured cases drop to 0 from day 164 onward. Day 45 corresponds to March 10,
# so roughly 4 months later, around July 10, everyone would be cured.
# Note: this prediction is on the pessimistic side, and it has a flaw: summing the
# predicted daily cured counts would exceed the cumulative number of confirmed cases.
```


```python
dfresult.query("dead_num==0").head()
# New deaths drop to 0 from day 60 onward. Day 45 corresponds to March 10,
# so about 15 days later, i.e. around 2020-03-25.
# This prediction looks fairly reasonable
```


6. Save the Model

Saving the model in TensorFlow's native SavedModel format is recommended.

```python
model.save('./data/tf_model_savedmodel', save_format="tf")
print('export saved model.')
```

```python
model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel', compile=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model_loaded.compile(optimizer=optimizer, loss=MSPE(name="MSPE"))
model_loaded.predict(ds_train)
```