| 
 | 
 
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册  
 
x
 
论文链接:https://arxiv.org/pdf/1703.07015.pdf 
参考链接:https://zhuanlan.zhihu.com/p/61795416 
 
 
 
- class LSTNet(keras.Model):
 
 -     def __init__(self, ):
 
 -         super(LSTNet, self).__init__()
 
 -         self.CnnChannel = 32#CNN输出的channel数
 
 -         self.CnnKernelSize = 5#CNN中Kernel的大小
 
 -         self.GruChannel = 16#GRU输出的channel数
 
 -         self.GruSkipChannel = 16#GRU_Skip的输出channel数
 
 -         self.skip = 7#时间跳跃的跨度
 
 -         self.hw = 7#AR线性窗口
 
 -     
 
 -     def build(self, input_shape):
 
 -         self.CountryDims = input_shape[2]
 
 -         self.TimeStamp = input_shape[1]
 
 -         self.IntervalCount = int(self.TimeStamp / self.skip)
 
 -         #################非线性层###################
 
 -         self.CNN = layers.Conv1D(filters = self.CnnChannel, 
 
 -                                  kernel_size = self.CnnKernelSize, 
 
 -                                  activation = "relu", #dropout = 0.5,
 
 -                                  padding = "same")
 
 -         #非跳跃的RNN层
 
 -         self.RNN = layers.GRU(units = self.GruChannel, 
 
 -                               dropout = 0.5, unroll = True)
 
 -         #跳跃的RNN层
 
 -         self.RNNSkip = layers.GRU(units = self.GruSkipChannel,
 
 -                                   dropout = 0.5, unroll = True)
 
 -         self.Dense = layers.Dense(units = self.CountryDims)
 
 -         #################线性层######################
 
 -         super(LSTNet, self).build(input_shape)
 
 -         
 
 -     def call(self, x, training = None):
 
 -         #[batchsize, timestramp, countrycount] -> [batchsize, timestramp, CnnChannnel]
 
 -         cnn_out = self.CNN(x)
 
 -         #[batchsize, timestramp, CnnChannnel] -> [batchsize, GruChannel]
 
 -         rnn_out = self.RNN(cnn_out)
 
 -         ######################跳跃GRU########################
 
 -         #[batchsize, timestramp, CnnChannnel] -> [batchsize, int(-p*IntervalCount), CnnChannnel]
 
 -         #给定周期跨度下,可能存在原时间戳长度不能整除的情况,所以这里用int(-p*IntervalCount)
 
 -         cnn_out_cut = cnn_out[:, int(-self.skip * self.IntervalCount):, :]
 
 -         #[batchsize, int(-skip*IntervalCount), CnnChannnel] -> [batchsize, IntervalCount, skip, CnnChannnel]
 
 -         cnn_out_stack = tf.reshape(cnn_out_cut, [-1, self.IntervalCount, self.skip, self.CnnChannel])
 
 -         #[batchsize, IntervalCount, skip, CnnChannnel] -> [batchsize, skip, IntervalCount, CnnChannnel]
 
 -         cnn_out_exchange = tf.transpose(cnn_out_stack, [0, 2, 1, 3])
 
 -         #[batchsize, skip, IntervalCount, CnnChannnel] -> [batchsize*skip, IntervalCount, CnnChannnel]
 
 -         cnn_out_input = tf.reshape(cnn_out_exchange, [-1, self.IntervalCount, self.CnnChannel])
 
 -         #[batchsize*skip, IntervalCount, CnnChannnel]-> [batchsize*skip, GruSkipChannel]
 
 -         cnn_skip = self.RNNSkip(cnn_out_input)
 
 -         #[batchsize*skip, GruSkipChannel] -> [batchsize, skip*GruSkipChannel]
 
 -         cnn_skip_out = tf.reshape(cnn_skip, [-1, self.skip * self.GruSkipChannel])
 
 -         #合并RNN和Skip-RNN
 
 -         #[batchsize, GruChannel] concate [batchsize, skip*GruSkipChannel] -> [batchsize, GruChannel + skip*GruSkipChannel]
 
 -         con_cnn = layers.concatenate([rnn_out, cnn_skip_out], axis = 1)
 
 -         #[batchsize, GruChannel + skip*GruSkipChannel] -> [batchsize, CountryDims]
 
 -         out = self.Dense(con_cnn)
 
 -         #######################线性部分######################
 
 -         #线性AR模型
 
 -         # highway,模型线性AR
 
 -         #[batchsize, timestramp, countrycount] -> [batchsize, hw, countrycount]
 
 -         linear = x[:, -self.hw:, :]
 
 -         linear = tf.convert_to_tensor(linear)
 
 -         #[batchsize, hw, countrycount] -> [batchsize,countrycount, hw]
 
 -         linear = tf.transpose(linear, [0, 2, 1])
 
 -         #print(linear.shape)
 
 -         #[batchsize,countrycount, hw] -> [batchsize*countrycount, hw]
 
 -         linear = tf.reshape(linear, [-1, self.hw])
 
 -         #[batchsize*countrycount, hw] -> [batchsize*countrycount, 1]
 
 -         linear = layers.Dense(1)(linear)
 
 -         #[batchsize*countrycount, 1] -> [batchsize, countrycount]
 
 -         linear = tf.reshape(linear, (-1, self.CountryDims))
 
 -         res = layers.add([out, linear])
 
 -         return res
 
 -     
 
 - model = LSTNet()
 
 - model.build((None, 14, 8))
 
 - model.summary()       
 
  复制代码 |   
 
 
 
 |