糖逗 发表于 2021-8-17 22:10:40

Python实现LSTNet【tensorflow2.0】【时间序列】

论文链接:https://arxiv.org/pdf/1703.07015.pdf
参考链接:https://zhuanlan.zhihu.com/p/61795416



class LSTNet(keras.Model):
    def __init__(self, ):
      super(LSTNet, self).__init__()
      self.CnnChannel = 32#CNN输出的channel数
      self.CnnKernelSize = 5#CNN中Kernel的大小
      self.GruChannel = 16#GRU输出的channel数
      self.GruSkipChannel = 16#GRU_Skip的输出channel数
      self.skip = 7#时间跳跃的跨度
      self.hw = 7#AR线性窗口
   
    def build(self, input_shape):
      self.CountryDims = input_shape
      self.TimeStamp = input_shape
      self.IntervalCount = int(self.TimeStamp / self.skip)
      #################非线性层###################
      self.CNN = layers.Conv1D(filters = self.CnnChannel,
                                 kernel_size = self.CnnKernelSize,
                                 activation = "relu", #dropout = 0.5,
                                 padding = "same")
      #非跳跃的RNN层
      self.RNN = layers.GRU(units = self.GruChannel,
                              dropout = 0.5, unroll = True)
      #跳跃的RNN层
      self.RNNSkip = layers.GRU(units = self.GruSkipChannel,
                                  dropout = 0.5, unroll = True)
      self.Dense = layers.Dense(units = self.CountryDims)
      #################线性层######################
      super(LSTNet, self).build(input_shape)
      
    def call(self, x, training = None):
      # ->
      cnn_out = self.CNN(x)
      # ->
      rnn_out = self.RNN(cnn_out)
      ######################跳跃GRU########################
      # ->
      #给定周期跨度下,可能存在原时间戳长度不能整除的情况,所以这里用int(-p*IntervalCount)
      cnn_out_cut = cnn_out[:, int(-self.skip * self.IntervalCount):, :]
      # ->
      cnn_out_stack = tf.reshape(cnn_out_cut, [-1, self.IntervalCount, self.skip, self.CnnChannel])
      # ->
      cnn_out_exchange = tf.transpose(cnn_out_stack, )
      # ->
      cnn_out_input = tf.reshape(cnn_out_exchange, [-1, self.IntervalCount, self.CnnChannel])
      #->
      cnn_skip = self.RNNSkip(cnn_out_input)
      # ->
      cnn_skip_out = tf.reshape(cnn_skip, [-1, self.skip * self.GruSkipChannel])
      #合并RNN和Skip-RNN
      # concate ->
      con_cnn = layers.concatenate(, axis = 1)
      # ->
      out = self.Dense(con_cnn)
      #######################线性部分######################
      #线性AR模型
      # highway,模型线性AR
      # ->
      linear = x[:, -self.hw:, :]
      linear = tf.convert_to_tensor(linear)
      # ->
      linear = tf.transpose(linear, )
      #print(linear.shape)
      # ->
      linear = tf.reshape(linear, [-1, self.hw])
      # ->
      linear = layers.Dense(1)(linear)
      # ->
      linear = tf.reshape(linear, (-1, self.CountryDims))
      res = layers.add()
      return res
   
model = LSTNet()
model.build((None, 14, 8))
model.summary()      

antonybear 发表于 2021-8-18 08:19:06

厉害
页: [1]
查看完整版本: Python实现LSTNet【tensorflow2.0】【时间序列】