1-4, Time Series Data Modeling Workflow Example
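The COVID-19 epidemic in China has lasted for more than N months since it was first detected, and this disaster, which originated from eating wild game, has affected everyone's life in many ways.
For some people it hit their income, for some their relationships, for some their mental health, and for some their weight.
So when will the epidemic in China end? When will we get our freedom back?
This article uses TensorFlow 2.0 to build a time-series RNN model that predicts when the epidemic in China will end.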

1. Prepare the data
The dataset used here comes from tushare and covers the first three months of the epidemic. The method for obtaining it follows this article:
https://zhuanlan.zhihu.com/p/109556102
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import models, layers, losses, metrics, callbacks

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

df = pd.read_csv("./data/covid-19.csv", sep="\t")
df.plot(x="date", y=["confirmed_num", "cured_num", "dead_num"], figsize=(10, 6))
plt.xticks(rotation=60)

dfdata = df.set_index("date")
# Convert cumulative counts into daily increments (new cases per day)
dfdiff = dfdata.diff(periods=1).dropna()
dfdiff = dfdiff.reset_index("date")

dfdiff.plot(x="date", y=["confirmed_num", "cured_num", "dead_num"], figsize=(10, 6))
plt.xticks(rotation=60)
dfdiff = dfdiff.drop("date", axis=1).astype("float32")
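Differencing turns the cumulative counts into daily increments, and a running sum added to the first observed value recovers the original series. A minimal NumPy sketch of that round trip; the numbers are made up for illustration and are not taken from the dataset:

```python
import numpy as np

# Hypothetical cumulative counts for 5 days (illustrative only)
cumulative = np.array([100.0, 150.0, 230.0, 260.0, 270.0])

# First-order difference, the NumPy analogue of df.diff(periods=1).dropna()
daily_new = np.diff(cumulative)

# Undo the differencing: first value followed by first value + running sum of increments
restored = np.concatenate(([cumulative[0]], cumulative[0] + np.cumsum(daily_new)))

print(daily_new)  # increments of 50, 80, 30, 10
print(restored)   # same values as `cumulative`
```

Note that the first day is lost by differencing (it is what `dropna()` removes), so the last known cumulative totals are needed if the daily forecasts are ever turned back into cumulative figures.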

# Use the 8 days before a given day as input to predict that day
WINDOW_SIZE = 8

def batch_dataset(dataset):
    dataset_batched = dataset.batch(WINDOW_SIZE, drop_remainder=True)
    return dataset_batched

ds_data = tf.data.Dataset.from_tensor_slices(tf.constant(dfdiff.values, dtype=tf.float32)) \
    .window(WINDOW_SIZE, shift=1).flat_map(batch_dataset)

ds_label = tf.data.Dataset.from_tensor_slices(
    tf.constant(dfdiff.values[WINDOW_SIZE:], dtype=tf.float32))

# The dataset is small, so all training data fits into a single batch, which improves performance
ds_train = tf.data.Dataset.zip((ds_data, ds_label)).batch(38).cache()
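The `window(WINDOW_SIZE, shift=1).flat_map(batch_dataset)` pipeline yields overlapping windows of 8 consecutive rows, and `zip` pairs each window with the row that immediately follows it. The NumPy sketch below reproduces that alignment on a toy array so it can be checked by eye; it only illustrates the indexing and is not a replacement for `tf.data`. `make_windows` is a hypothetical helper name.

```python
import numpy as np

def make_windows(values, window_size):
    """Pair each run of `window_size` consecutive rows with the row that follows it.

    Mirrors zipping dataset.window(window_size, shift=1) with values[window_size:]:
    the final window has no following row, so zip drops it.
    """
    n_samples = len(values) - window_size
    x = np.stack([values[i:i + window_size] for i in range(n_samples)])
    y = values[window_size:]
    return x, y

# Toy series: 10 days x 3 features (new confirmed, cured, dead)
toy = np.arange(30, dtype=np.float32).reshape(10, 3)
x, y = make_windows(toy, window_size=8)
print(x.shape, y.shape)  # (2, 8, 3) (2, 3)
```

Sample 0 covers days 0 to 7 and is labeled with day 8; sample 1 covers days 1 to 8 and is labeled with day 9.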
2. Define the model
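There are three ways to build a model with the Keras API: stacking layers in order with Sequential, building an arbitrary architecture with the functional API, or subclassing the Model base class to build a custom model.
Here we build the model with the functional API.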
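# New confirmed, cured and dead counts can never be negative, so the output layer is designed as follows:
# the Dense output x is applied as a relative change to the last day of the input window,
# and negative results are clipped to 0.
class Block(layers.Layer):
    def __init__(self, **kwargs):
        super(Block, self).__init__(**kwargs)

    def call(self, x_input, x):
        x_out = tf.maximum((1 + x) * x_input[:, -1, :], 0.0)
        return x_out

    def get_config(self):
        config = super(Block, self).get_config()
        return config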
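The Block layer computes max((1 + x) * last_day, 0), where last_day is the last row of the input window and x is the Dense output, so in effect the network predicts a relative change over the previous day. The sketch below reproduces that arithmetic in NumPy for a single sample; `block_output` is a hypothetical helper, not part of the original code.

```python
import numpy as np

def block_output(window, x):
    """NumPy version of Block.call for a single sample (no batch dimension).

    window: shape (time_steps, 3), the input window
    x: shape (3,), the Dense output, read as a relative change
    """
    last_day = window[-1, :]              # x_input[:, -1, :] without the batch axis
    return np.maximum((1.0 + x) * last_day, 0.0)

window = np.array([[120.0, 30.0, 5.0],
                   [100.0, 40.0, 4.0]])
x = np.array([-0.25, 0.5, -1.5])          # -25%, +50%, -150%
print(block_output(window, x))            # 75, 60 and 0 (the raw value -2 is clipped)
```

Because any x below -1 is clipped to zero, the model can express "no new cases" exactly, which matters for the end-date predictions later on.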
tf.keras.backend.clear_session()

x_input = layers.Input(shape=(None, 3), dtype=tf.float32)
x = layers.LSTM(3, return_sequences=True)(x_input)
x = layers.LSTM(3, return_sequences=True)(x)
x = layers.LSTM(3, return_sequences=True)(x)
x = layers.LSTM(3)(x)
x = layers.Dense(3)(x)

# New confirmed, cured and dead counts can never be negative, so the custom Block layer is used,
# equivalent to:
# x = tf.maximum((1+x)*x_input[:,-1,:],0.0)
x = Block()(x_input, x)
model = models.Model(inputs=[x_input], outputs=[x])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, None, 3)]         0
_________________________________________________________________
lstm (LSTM)                  (None, None, 3)           84
_________________________________________________________________
lstm_1 (LSTM)                (None, None, 3)           84
_________________________________________________________________
lstm_2 (LSTM)                (None, None, 3)           84
_________________________________________________________________
lstm_3 (LSTM)                (None, 3)                 84
_________________________________________________________________
dense (Dense)                (None, 3)                 12
_________________________________________________________________
block (Block)                (None, 3)                 0
=================================================================
Total params: 348
Trainable params: 348
Non-trainable params: 0
_________________________________________________________________
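The parameter counts in the summary can be checked by hand. An LSTM layer has four gates, each with an input kernel, a recurrent kernel and a bias, so it holds 4 × (units × (input_dim + units) + units) weights; a Dense layer holds input_dim × units + units. A quick check in plain Python (the helper names are illustrative):

```python
def lstm_params(input_dim, units):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias vector
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    # weight matrix plus bias vector
    return input_dim * units + units

lstm_total = sum(lstm_params(3, 3) for _ in range(4))  # four stacked LSTM(3) layers
dense_total = dense_params(3, 3)
print(lstm_params(3, 3), dense_params(3, 3), lstm_total + dense_total)  # 84 12 348
```

The Block layer adds no weights, which is why it shows 0 parameters.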
3. Train the model
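There are usually three ways to train a model: the built-in fit method, the built-in train_on_batch method, and a custom training loop. Here we use fit, the most common and simplest option.
Note: recurrent neural networks are relatively hard to tune; you may need to try several different learning rates to get good results.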
# Custom loss: squared error divided by the squared target (mean squared percentage error)
class MSPE(losses.Loss):
    def call(self, y_true, y_pred):
        err_percent = (y_true - y_pred)**2 / (tf.maximum(y_true**2, 1e-7))
        mean_err_percent = tf.reduce_mean(err_percent)
        return mean_err_percent

    def get_config(self):
        config = super(MSPE, self).get_config()
        return config
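A NumPy version of the same loss makes its behavior easier to inspect: each squared error is divided by the squared target, and the 1e-7 floor only guards against division by zero. One consequence is that a nonzero prediction on a day whose true value is 0 is penalized very heavily. `mspe_numpy` is a hypothetical name used only for this illustration.

```python
import numpy as np

def mspe_numpy(y_true, y_pred, eps=1e-7):
    """Mean squared percentage error, mirroring the MSPE loss above."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    err_percent = (y_true - y_pred) ** 2 / np.maximum(y_true ** 2, eps)
    return err_percent.mean()

print(mspe_numpy([10.0, 20.0], [9.0, 22.0]))  # (1/100 + 4/400) / 2 = 0.01
print(mspe_numpy([0.0], [1.0]))               # 1 / 1e-7, roughly 1e7
```

Relative errors make days with large and small counts contribute comparably, which suits counts that span several orders of magnitude.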
import os
import datetime

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=optimizer, loss=MSPE(name="MSPE"))

stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('data', 'autograph', stamp)

## In Python 3, pathlib is recommended for building paths that work on every operating system
# from pathlib import Path
# stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# logdir = str(Path('./data/autograph/' + stamp))

tb_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
# If the loss has not improved for 100 epochs, halve the learning rate
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.5, patience=100)
# If the loss has not improved for 200 epochs, stop training early
stop_callback = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=200)
callbacks_list = [tb_callback, lr_callback, stop_callback]

history = model.fit(ds_train, epochs=500, callbacks=callbacks_list)
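The two callbacks watch the same training loss with different patience values. The pure-Python simulation below is a simplified sketch of that interaction, not the Keras implementation: it ignores `min_delta` (by default ReduceLROnPlateau only counts an improvement larger than 1e-4), `cooldown` and `min_lr`. `simulate_callbacks` is a hypothetical helper.

```python
def simulate_callbacks(losses, lr=0.01, factor=0.5, lr_patience=100, stop_patience=200):
    """Simplified model of ReduceLROnPlateau + EarlyStopping on a loss curve.

    Returns (learning rate in effect at each epoch, epoch index where training stops or None).
    """
    best = float("inf")
    wait_lr = 0     # epochs since the last improvement, for the LR schedule
    wait_stop = 0   # epochs since the last improvement, for early stopping
    lrs = []
    for epoch, loss in enumerate(losses):
        lrs.append(lr)                   # LR used during this epoch
        if loss < best:                  # improvement: reset both counters
            best = loss
            wait_lr = 0
            wait_stop = 0
            continue
        wait_lr += 1
        wait_stop += 1
        if wait_lr >= lr_patience:       # plateau: reduce the LR for the following epochs
            lr *= factor
            wait_lr = 0
        if wait_stop >= stop_patience:   # no improvement for too long: stop training
            return lrs, epoch
    return lrs, None

# Small patience values so the effect is visible on a short, made-up loss curve
lrs, stop_epoch = simulate_callbacks(
    [1.0, 0.9, 0.95, 0.96, 0.97, 0.98], lr_patience=2, stop_patience=4)
print(lrs, stop_epoch)
```

Choosing a stopping patience that is larger than the LR patience gives the learning-rate reduction a chance to help before training is abandoned.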
Epoch 371/500
1/1 [==============================] - 0s 61ms/step - loss: 0.1184
Epoch 372/500
1/1 [==============================] - 0s 64ms/step - loss: 0.1177
...
Epoch 400/500
1/1 [==============================] - 0s 63ms/step - loss: 0.1013
...
Epoch 450/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0889
...
Epoch 499/500
1/1 [==============================] - 0s 54ms/step - loss: 0.0857
Epoch 500/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0858
4. Evaluate the model
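Evaluating a model normally requires a validation or test set. Because this example has very little data, we only visualize how the loss on the training set evolves over the epochs.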
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt

def plot_metric(history, metric):
    train_metrics = history.history[metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.title('Training ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_" + metric])
    plt.show()

plot_metric(history, "loss")
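With more data, a validation set for a time series should be taken from the end of the series rather than sampled at random, so the model is always evaluated on days later than the ones it was trained on. A minimal sketch, assuming the windowed inputs and labels are NumPy arrays already in time order; `chronological_split` and `val_fraction` are hypothetical names:

```python
import numpy as np

def chronological_split(x, y, val_fraction=0.2):
    """Keep the earliest samples for training and the most recent ones for validation."""
    n_val = max(1, int(len(x) * val_fraction))
    return (x[:-n_val], y[:-n_val]), (x[-n_val:], y[-n_val:])

# Toy arrays standing in for windowed samples ordered by date
x = np.arange(10).reshape(10, 1)
y = np.arange(10)
(x_tr, y_tr), (x_val, y_val) = chronological_split(x, y)
print(y_tr, y_val)  # training labels 0..7, validation labels 8 and 9
```

Shuffling before the split would leak future days into training and make the validation loss look better than a real forecast would be.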

5. Use the model
Here we use the model to predict when the epidemic will end, i.e. the date on which the number of new confirmed cases drops to 0.
# dfresult holds the existing data plus the forecasts that are appended later
dfresult = dfdiff[["confirmed_num", "cured_num", "dead_num"]].copy()
dfresult.tail()

# Forecast the next 100 days and append each prediction to dfresult
for i in range(100):
    arr_predict = model.predict(tf.expand_dims(dfresult.values[-38:, :], axis=0))
    dfpredict = pd.DataFrame(tf.cast(tf.floor(arr_predict), tf.float32).numpy(),
                             columns=dfresult.columns)
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement
    dfresult = pd.concat([dfresult, dfpredict], ignore_index=True)
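This loop is a recursive (autoregressive) forecast: each prediction is appended to the history and becomes part of the next input window, so errors can compound over the 100-day horizon. A stripped-down NumPy sketch of the same pattern, with a hypothetical stand-in predictor `toy_predict` in place of `model.predict`:

```python
import numpy as np

def recursive_forecast(history, predict_fn, steps, window=38):
    """Roll a one-step predictor forward `steps` times.

    history: array of shape (n_days, n_features)
    predict_fn: maps an array of shape (1, <=window, n_features) to (1, n_features)
    """
    result = np.array(history, dtype=np.float32)
    for _ in range(steps):
        inputs = result[-window:][np.newaxis, ...]   # latest window plus a batch axis
        next_day = np.floor(predict_fn(inputs))      # floor, as in the loop above
        result = np.concatenate([result, next_day], axis=0)
    return result

def toy_predict(inputs):
    # Hypothetical stand-in for model.predict: halve the last observed day
    return inputs[:, -1, :] * 0.5

out = recursive_forecast(np.array([[80.0, 40.0, 4.0]]), toy_predict, steps=3)
print(out)  # rows: 80/40/4, 40/20/2, 20/10/1, 10/5/0
```

The floor step is what eventually drives small predictions to exactly 0, which the queries below rely on.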
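dfresult.query("confirmed_num==0").head()

# New confirmed cases drop to 0 from day 55 onward. Day 45 corresponds to March 10,
# so this is 10 days later: new confirmed cases are predicted to reach 0 around March 20.
# Note: this prediction is on the optimistic side.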

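dfresult.query("cured_num==0").head()

# New cured cases drop to 0 from day 164 onward. Day 45 corresponds to March 10, so this is
# about 4 months later: everyone is predicted to have recovered around July 10.
# Note: this prediction is on the pessimistic side and is also flawed: adding up the predicted
# daily new cures gives more than the cumulative number of confirmed cases.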
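The inconsistency noted above can be checked mechanically: adding the predicted daily increments to the last known cumulative totals should never give more cumulative cures than cumulative confirmed cases. A sketch of such a check on made-up numbers; `cumulative_consistent` is a hypothetical helper and the values are not from the real data:

```python
import numpy as np

def cumulative_consistent(last_confirmed, last_cured, daily_confirmed, daily_cured):
    """Return True if cumulative cured never exceeds cumulative confirmed."""
    confirmed = last_confirmed + np.cumsum(daily_confirmed)
    cured = last_cured + np.cumsum(daily_cured)
    return bool(np.all(cured <= confirmed))

# Made-up example: new cases stop while cures keep arriving
print(cumulative_consistent(1000, 900, [10, 0, 0], [50, 50, 50]))  # False
```

Because the model predicts each series independently, nothing in its structure enforces this constraint; a check like this only detects the problem after the fact.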

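dfresult.query("dead_num==0").head()

# New deaths drop to 0 from day 60 onward. Day 45 corresponds to March 10, so this is
# about 15 days later, i.e. 20200325.
# This prediction looks reasonable.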
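The date arithmetic in these comments relies on the stated anchor that day 45 corresponds to March 10, 2020 (the year matches the 20200325 given above). A small sketch that finds the first zero in a series and converts a day index to a calendar date; `day_to_date` and `first_zero_day` are hypothetical helpers:

```python
import datetime
import numpy as np

ANCHOR_DAY = 45                           # row index stated to correspond to March 10
ANCHOR_DATE = datetime.date(2020, 3, 10)

def day_to_date(day):
    """Convert a row index of dfresult into a calendar date."""
    return ANCHOR_DATE + datetime.timedelta(days=day - ANCHOR_DAY)

def first_zero_day(series):
    """Index of the first zero value (like .query(...).index[0]), or None if there is none."""
    zeros = np.flatnonzero(np.asarray(series) == 0)
    return int(zeros[0]) if zeros.size else None

print(day_to_date(55))   # 2020-03-20
print(day_to_date(60))   # 2020-03-25
print(day_to_date(164))  # 2020-07-07
```

Day 164 works out to July 7, consistent with the "around July 10" estimate above.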
6. Save the model
Saving the model in TensorFlow's native (SavedModel) format is recommended.
model.save('./data/tf_model_savedmodel', save_format="tf")
print('export saved model.')

model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel', compile=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model_loaded.compile(optimizer=optimizer, loss=MSPE(name="MSPE"))
model_loaded.predict(ds_train)
