def position_wise_feed_forward_fn(self):
    """Apply the position-wise feed-forward sub-layer to ``self.x``.

    Implemented as two 2-D convolutions over ``self.x`` expanded to 4-D:

    1. ``conv1``: filter ``[1, d_model, 1, 1]`` with VALID padding maps the
       d_model features of each position to a single channel,
       giving ``[batch, sequence_length, 1, 1]``.
    2. ``conv2``: filter ``[1, 1, 1, d_model]`` maps that channel back to
       d_model channels, giving ``[batch, sequence_length, 1, d_model]``.

    Reads ``self.x``, ``self.d_model``, ``self.layer_index`` and
    ``self.initializer``; obtains the filters via ``tf.get_variable`` under
    the names ``"filter1<layer_index>"`` and ``"filter2<layer_index>"``.

    x: [batch, sequence_length, d_model]  (rank 3 is required: expand_dims
       below must yield a 4-D conv2d input; the last dim presumably equals
       d_model so VALID conv1 yields width 1 — confirm at call sites)
    :return: [batch, sequence_length, d_model]
    """
    # NOTE(review): unlike the FFN of "Attention Is All You Need" (sec. 3.3,
    # max(0, x W1 + b1) W2 + b2 with an inner size d_ff), this projects
    # d_model -> 1 -> d_model with no bias and no activation, i.e. a rank-1
    # linear map. Left as-is because changing it alters the trainable
    # variable shapes (and existing checkpoints); confirm it is intended.

    # [batch, sequence_length, d_model, 1]: d_model becomes the conv "width"
    # with a single input channel.
    inputs = tf.expand_dims(self.x, axis=3)

    # 1. conv1 -> [batch, sequence_length, 1, 1]
    filter1 = tf.get_variable(
        "filter1" + str(self.layer_index),
        shape=[1, self.d_model, 1, 1],
        initializer=self.initializer,
    )
    output_conv1 = tf.nn.conv2d(
        inputs, filter1, strides=[1, 1, 1, 1], padding="VALID", name="conv1"
    )

    # 2. conv2 -> [batch, sequence_length, 1, d_model]
    filter2 = tf.get_variable(
        "filter2" + str(self.layer_index),
        shape=[1, 1, 1, self.d_model],
        initializer=self.initializer,
    )
    output_conv2 = tf.nn.conv2d(
        output_conv1, filter2, strides=[1, 1, 1, 1], padding="VALID", name="conv2"
    )

    # Squeeze only the singleton width axis. A bare tf.squeeze() would also
    # drop batch / sequence_length / d_model whenever they equal 1, and when
    # any dim is statically unknown (e.g. batch=None) it yields a tensor of
    # unknown rank, losing static shape information downstream.
    return tf.squeeze(output_conv2, axis=2)  # [batch, sequence_length, d_model]
# Test function for position_wise_feed_forward_fn.
# Time spent: OLD VERSION (length=8000): 35.6s; NEW VERSION: 0.03s.
a2_poistion_wise_feed_forward.py 文件源码
python
阅读 34
收藏 0
点赞 0
评论 0
评论列表
文章目录