鱼C论坛

 找回密码
 立即注册
查看: 2187|回复: 1

[技术交流] Python实现LSTNet【tensorflow2.0】【时间序列】

[复制链接]
发表于 2021-8-17 22:10:40 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
论文链接:https://arxiv.org/pdf/1703.07015.pdf
参考链接:https://zhuanlan.zhihu.com/p/61795416



  1. class LSTNet(keras.Model):
  2.     def __init__(self, ):
  3.         super(LSTNet, self).__init__()
  4.         self.CnnChannel = 32#CNN输出的channel数
  5.         self.CnnKernelSize = 5#CNN中Kernel的大小
  6.         self.GruChannel = 16#GRU输出的channel数
  7.         self.GruSkipChannel = 16#GRU_Skip的输出channel数
  8.         self.skip = 7#时间跳跃的跨度
  9.         self.hw = 7#AR线性窗口
  10.    
  11.     def build(self, input_shape):
  12.         self.CountryDims = input_shape[2]
  13.         self.TimeStamp = input_shape[1]
  14.         self.IntervalCount = int(self.TimeStamp / self.skip)
  15.         #################非线性层###################
  16.         self.CNN = layers.Conv1D(filters = self.CnnChannel,
  17.                                  kernel_size = self.CnnKernelSize,
  18.                                  activation = "relu", #dropout = 0.5,
  19.                                  padding = "same")
  20.         #非跳跃的RNN层
  21.         self.RNN = layers.GRU(units = self.GruChannel,
  22.                               dropout = 0.5, unroll = True)
  23.         #跳跃的RNN层
  24.         self.RNNSkip = layers.GRU(units = self.GruSkipChannel,
  25.                                   dropout = 0.5, unroll = True)
  26.         self.Dense = layers.Dense(units = self.CountryDims)
  27.         #################线性层######################
  28.         super(LSTNet, self).build(input_shape)
  29.         
  30.     def call(self, x, training = None):
  31.         #[batchsize, timestramp, countrycount] -> [batchsize, timestramp, CnnChannnel]
  32.         cnn_out = self.CNN(x)
  33.         #[batchsize, timestramp, CnnChannnel] -> [batchsize, GruChannel]
  34.         rnn_out = self.RNN(cnn_out)
  35.         ######################跳跃GRU########################
  36.         #[batchsize, timestramp, CnnChannnel] -> [batchsize, int(-p*IntervalCount), CnnChannnel]
  37.         #给定周期跨度下,可能存在原时间戳长度不能整除的情况,所以这里用int(-p*IntervalCount)
  38.         cnn_out_cut = cnn_out[:, int(-self.skip * self.IntervalCount):, :]
  39.         #[batchsize, int(-skip*IntervalCount), CnnChannnel] -> [batchsize, IntervalCount, skip, CnnChannnel]
  40.         cnn_out_stack = tf.reshape(cnn_out_cut, [-1, self.IntervalCount, self.skip, self.CnnChannel])
  41.         #[batchsize, IntervalCount, skip, CnnChannnel] -> [batchsize, skip, IntervalCount, CnnChannnel]
  42.         cnn_out_exchange = tf.transpose(cnn_out_stack, [0, 2, 1, 3])
  43.         #[batchsize, skip, IntervalCount, CnnChannnel] -> [batchsize*skip, IntervalCount, CnnChannnel]
  44.         cnn_out_input = tf.reshape(cnn_out_exchange, [-1, self.IntervalCount, self.CnnChannel])
  45.         #[batchsize*skip, IntervalCount, CnnChannnel]-> [batchsize*skip, GruSkipChannel]
  46.         cnn_skip = self.RNNSkip(cnn_out_input)
  47.         #[batchsize*skip, GruSkipChannel] -> [batchsize, skip*GruSkipChannel]
  48.         cnn_skip_out = tf.reshape(cnn_skip, [-1, self.skip * self.GruSkipChannel])
  49.         #合并RNN和Skip-RNN
  50.         #[batchsize, GruChannel] concate [batchsize, skip*GruSkipChannel] -> [batchsize, GruChannel + skip*GruSkipChannel]
  51.         con_cnn = layers.concatenate([rnn_out, cnn_skip_out], axis = 1)
  52.         #[batchsize, GruChannel + skip*GruSkipChannel] -> [batchsize, CountryDims]
  53.         out = self.Dense(con_cnn)
  54.         #######################线性部分######################
  55.         #线性AR模型
  56.         # highway,模型线性AR
  57.         #[batchsize, timestramp, countrycount] -> [batchsize, hw, countrycount]
  58.         linear = x[:, -self.hw:, :]
  59.         linear = tf.convert_to_tensor(linear)
  60.         #[batchsize, hw, countrycount] -> [batchsize,countrycount, hw]
  61.         linear = tf.transpose(linear, [0, 2, 1])
  62.         #print(linear.shape)
  63.         #[batchsize,countrycount, hw] -> [batchsize*countrycount, hw]
  64.         linear = tf.reshape(linear, [-1, self.hw])
  65.         #[batchsize*countrycount, hw] -> [batchsize*countrycount, 1]
  66.         linear = layers.Dense(1)(linear)
  67.         #[batchsize*countrycount, 1] -> [batchsize, countrycount]
  68.         linear = tf.reshape(linear, (-1, self.CountryDims))
  69.         res = layers.add([out, linear])
  70.         return res
  71.    
  72. model = LSTNet()
  73. model.build((None, 14, 8))
  74. model.summary()      
复制代码

本帖被以下淘专辑推荐:

小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2021-8-18 08:19:06 | 显示全部楼层
厉害
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-26 12:34

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表