seq2seq.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tensorsoup 作者: ai-guild 项目源码 文件源码
def naive_decoder(cell, enc_states, targets, start_token, end_token, 
        feed_previous=True, training=True, scope='naive_decoder.0'):
    """Unroll `cell` as a decoder inside a `tf.while_loop`.

    The decoder starts from the last entry of `enc_states` and from
    `start_token`, then repeatedly calls `cell(input, state)`; each step's
    state and output are stored in fixed-size TensorArrays with
    `timesteps + 1` slots, where `timesteps = tf.shape(enc_states)[0]`.

    Args:
        cell: callable `cell(input, state) -> (output, state)`. When
            `is_lstm(cell)` is true, the stored state is unstacked into
            (c, h) and wrapped in an `rnn.LSTMStateTuple` before the call.
        enc_states: encoder states indexed by step on the leading axis;
            `enc_states[-1]` is used as the decoder's initial state.
        targets: rank-3 tensor, transposed with perm `[1, 0, 2]` so that it
            can be indexed per step (presumably [batch, time, dim] --
            confirm against the caller). Only read when `feed_previous` is
            False; may be None otherwise.
        start_token: tensor written as the input of the first step.
        end_token: value compared with `argmax(output, axis=-1)` to stop
            decoding when `training` is False.
        feed_previous: if True, each step consumes the previous step's
            output; otherwise it consumes `targets` at that step.
        training: if True, run exactly `timesteps` steps; otherwise stop at
            the first step whose input slot contains `end_token`, or after
            `timesteps` steps, whichever comes first.
        scope: variable scope under which the loop is built.

    Returns:
        The stacked per-step outputs with slot 0 (the start token) removed.
    """
    init_state = enc_states[-1]
    timesteps = tf.shape(enc_states)[0]

    # Time-major targets are needed only for teacher forcing; skipping the
    # transpose otherwise lets inference callers omit `targets`.
    targets_tm = None if feed_previous else tf.transpose(targets, [1, 0, 2])

    states = tf.TensorArray(dtype=tf.float32, size=timesteps+1, name='states',
                    clear_after_read=False)
    outputs = tf.TensorArray(dtype=tf.float32, size=timesteps+1, name='outputs',
                    clear_after_read=False)

    def step(i, states, outputs):
        """Run one decoder step: read slot i, write results to slot i + 1."""
        state_prev = states.read(i)

        if is_lstm(cell):
            # previous state <tensor> -> <LSTMStateTuple>
            c_state, h_state = tf.unstack(state_prev)
            state_prev = rnn.LSTMStateTuple(c_state, h_state)

        input_ = outputs.read(i) if feed_previous else targets_tm[i]

        output, state = cell(input_, state_prev)
        states = states.write(i+1, state)
        outputs = outputs.write(i+1, output)
        return tf.add(i, 1), states, outputs

    def keep_going(i, _states, outputs_ta):
        """Loop condition: stay within capacity and, at inference, before end_token."""
        # The arrays hold timesteps + 1 slots and step i writes slot i + 1,
        # so i must stay below timesteps or the write goes out of range.
        within_capacity = tf.less(i, timesteps)
        if training:
            return within_capacity
        # True while no entry of the latest output has argmax == end_token.
        no_end_token = tf.reduce_all(
            tf.not_equal(tf.argmax(outputs_ta.read(i), axis=-1), end_token))
        return tf.logical_and(within_capacity, no_end_token)

    with tf.variable_scope(scope):
        # slot 0 holds the initial state and the initial input
        states = states.write(0, init_state)
        outputs = outputs.write(0, start_token)

        i = tf.constant(0)

        _, _, foutputs = tf.while_loop(keep_going, step, [i, states, outputs])

    # NOTE(review): when inference stops early, the remaining slots were never
    # written; confirm that stack() behaves as intended in that case.
    # Only outputs are returned; the per-step states are currently discarded.
    return foutputs.stack()[1:]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号