seq2seq_autoencoder_model.py source code

python

Project: tensorflow-seq2seq-autoencoder  Author: qixiang109
```python
import numpy as np

import data_utils  # project module defining PAD_ID, GO_ID, EOS_ID

def get_batch(self, data_set, batch_size, random=True):  # method of the model class
    '''Get a batch of data from data_set and do all the preprocessing needed
    to make it usable by the model defined above. Returns time-major arrays.'''
    if random:
        # np.random.choice needs a 1-D array and can fail on a list of token lists,
        # so sample indices instead.
        idx = np.random.randint(0, len(data_set), size=batch_size)
        seqs = [data_set[j] for j in idx]
    else:
        seqs = data_set[0:batch_size]
    encoder_inputs = np.zeros((batch_size, self.max_seq_length), dtype=int)
    decoder_inputs = np.zeros((batch_size, self.max_seq_length + 2), dtype=int)
    encoder_lengths = np.zeros(batch_size, dtype=int)
    decoder_weights = np.zeros((batch_size, self.max_seq_length + 2), dtype=float)
    for i, seq in enumerate(seqs):
        pad = [data_utils.PAD_ID] * (self.max_seq_length - len(seq))
        encoder_inputs[i] = list(reversed(seq)) + pad  # reversed source, then padding
        # GO + sequence + EOS + padding: max_seq_length + 2 tokens
        decoder_inputs[i] = [data_utils.GO_ID] + list(seq) + [data_utils.EOS_ID] + pad
        encoder_lengths[i] = len(seq)
        # mask: 1.0 on the first len(seq) + 1 decoder steps, 0.0 on the padded tail
        decoder_weights[i, 0:len(seq) + 1] = 1.0
    return (np.transpose(encoder_inputs), np.transpose(decoder_inputs),
            encoder_lengths, np.transpose(decoder_weights))
```
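To see the layout that get_batch produces, here is a minimal standalone sketch of the same preprocessing. It is not the project's code: `make_batch` is a hypothetical helper, and `PAD_ID = 0`, `GO_ID = 1`, `EOS_ID = 2` are assumed values (the real ones come from `data_utils`). It also assumes the usual convention for GO-prefixed decoder inputs, namely that the target at step t is the decoder input at step t + 1, which shows why exactly `len(seq) + 1` weights are set to 1.0.

```python
import numpy as np

# Hypothetical token IDs; the project takes these from its data_utils module.
PAD_ID, GO_ID, EOS_ID = 0, 1, 2


def make_batch(seqs, max_seq_length):
    """Hypothetical helper mirroring get_batch's preprocessing (no sampling)."""
    batch_size = len(seqs)
    encoder_inputs = np.full((batch_size, max_seq_length), PAD_ID, dtype=int)
    decoder_inputs = np.full((batch_size, max_seq_length + 2), PAD_ID, dtype=int)
    decoder_weights = np.zeros((batch_size, max_seq_length + 2), dtype=float)
    encoder_lengths = np.zeros(batch_size, dtype=int)
    for i, seq in enumerate(seqs):
        n = len(seq)
        if n > max_seq_length:
            raise ValueError("sequence longer than max_seq_length")
        encoder_inputs[i, :n] = list(reversed(seq))           # reversed source
        decoder_inputs[i, :n + 2] = [GO_ID] + list(seq) + [EOS_ID]
        decoder_weights[i, :n + 1] = 1.0                      # tokens + EOS
        encoder_lengths[i] = n
    # Time-major layout: [time, batch]
    return encoder_inputs.T, decoder_inputs.T, encoder_lengths, decoder_weights.T


enc, dec, lengths, weights = make_batch([[5, 6, 7], [8, 9]], max_seq_length=4)

# Assumed convention: the target at step t is the decoder input at step t + 1.
targets = np.vstack([dec[1:], np.full((1, dec.shape[1]), PAD_ID)])

print(enc.T.tolist())                     # [[7, 6, 5, 0], [9, 8, 0, 0]]
print(dec.T.tolist())                     # [[1, 5, 6, 7, 2, 0], [1, 8, 9, 2, 0, 0]]
print(targets.T[weights.T > 0].tolist())  # [5, 6, 7, 2, 8, 9, 2]
```

Under these assumptions the weighted targets are exactly each sequence's tokens followed by EOS, so padding positions contribute nothing to the loss.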