modules.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tensorsoup 作者: ai-guild 项目源码 文件源码
def match(qstates, pstates, d, dropout=None):
    """Run an attention-driven LSTM over the passage, one position at a time.

    At each passage position ``i`` the question states are summarised by
    ``attention`` (conditioned on ``pstates[i]`` and the previous LSTM cell
    state). That summary is concatenated with ``pstates[i]`` and fed to an
    LSTM cell built by ``rcell``. The recurrence uses ``tf.while_loop``, so
    the passage length may be dynamic.

    Args:
        qstates: question encodings. Dim 1 is used as the batch size for the
            cell's zero state. Presumably time-major ``[qlen, batch, feat]``
            — TODO confirm against callers.
        pstates: passage encodings. Dim 0 is the number of steps and is
            indexed once per step (``pstates[i]``). Presumably
            ``[plen, batch, feat]`` — confirm.
        d: size parameter; the LSTM cell is created with ``2*d`` units.
        dropout: forwarded unchanged to ``rcell``.

    Returns:
        A pair ``(outputs, states)``:
          * ``outputs``: the stacked per-step cell outputs, one per passage
            position.
          * ``states``: the stacked per-step LSTM states (the initial zero
            state excluded) passed through
            ``project_lstm_states(..., 4*d, d)``.
    """
    # Only the batch dimension of the question is needed, so index it
    # directly instead of unstacking the whole shape vector.
    batch_size = tf.shape(qstates)[1]
    plen = tf.shape(pstates)[0]

    # LSTM cell driving the match layer.
    cell = rcell('lstm', num_units=2*d, dropout=dropout)

    # Slot 0 holds the initial state and slot i+1 the state after step i,
    # hence plen+1 entries. clear_after_read=False because slots are read
    # inside the loop and the whole array is stacked again afterwards.
    states = tf.TensorArray(dtype=tf.float32, size=plen+1, name='states',
                            clear_after_read=False)
    outputs = tf.TensorArray(dtype=tf.float32, size=plen, name='outputs',
                             clear_after_read=False)

    # Zero initial state. The LSTM (c, h) tuple is written as one packed
    # tensor and unpacked again on read inside mlstm_step.
    states = states.write(0, cell.zero_state(batch_size, tf.float32))

    def mlstm_step(i, states_ta, outputs_ta):
        """Advance the recurrence by one passage position ``i``."""
        # Rebuild the (c, h) state tuple from the packed tensor at slot i.
        prev_c, prev_h = tf.unstack(states_ta.read(i))
        prev_state = tf.contrib.rnn.LSTMStateTuple(prev_c, prev_h)

        # Attention-weighted question summary for this position.
        # NOTE(review): match-LSTM as described by Wang & Jiang conditions
        # attention on the previous *hidden* state h, whereas this passes
        # the cell state c -- confirm this is intentional.
        ci = attention(qstates, pstates[i], prev_c, d)

        # Cell input: current passage state concatenated with the summary.
        input_ = tf.concat([pstates[i], ci], axis=-1)
        output, state = cell(input_, prev_state)

        # Record the new state (shifted by one past the initial state)
        # and this step's output.
        states_ta = states_ta.write(i + 1, state)
        outputs_ta = outputs_ta.write(i, output)
        return i + 1, states_ta, outputs_ta

    def keep_going(i, *_):
        """Continue while passage positions remain to be processed."""
        return tf.less(i, plen)

    _, fstates, foutputs = tf.while_loop(
        keep_going, mlstm_step, [0, states, outputs])

    # Drop slot 0 (the zero initial state) before projecting the states.
    return foutputs.stack(), project_lstm_states(fstates.stack()[1:], 4*d, d)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号