# recurrence.py — source file
#
# Language: python
# Views: 21 · Favorites: 0 · Likes: 0 · Comments: 0
#
# Project: tensorsoup · Author: ai-guild · project source · file source
def uni_net_dynamic(cell, inputs, proj_dim=None, init_state=None, scope='uni_net_d0'):
    """Run ``cell`` over ``inputs`` step by step inside a ``tf.while_loop``.

    Unidirectional dynamic RNN: the number of timesteps and the batch size are
    read from the runtime shape of ``inputs``, so neither needs to be static.

    Args:
        cell: RNN cell, called as ``cell(x_t, state) -> (output, new_state)``.
            When ``is_lstm(cell)`` is true, its state is an ``LSTMStateTuple``
            and ``cell.state_size.c`` must be available.
        inputs: rank-3 tensor, transposed with perm ``[1, 0, 2]`` to time-major
            before the loop. Presumably (batch, timesteps, dim); confirm with
            callers.
        proj_dim: LSTM only. Target size passed to ``project_lstm_states``.
            Defaults to ``cell.state_size.c`` when None.
        init_state: initial state. When None, ``cell.zero_state(batch_size,
            tf.float32)`` is used.
        scope: variable scope in which the loop, and therefore the cell call,
            is built.

    Returns:
        ``(outputs, states)``:
        - ``outputs`` is the time-major stack of per-step cell outputs.
        - ``states`` is the stack of per-step states after each step; the
          initial state is excluded.
        For LSTM cells, the packed ``(c, h)`` states are first passed through
        ``project_lstm_states(states, 2 * state_size.c, proj_dim)``.
    """
    # evaluated once; used both while tracing the loop body and afterwards
    lstm = is_lstm(cell)

    # transpose to time major
    inputs_tm = tf.transpose(inputs, [1, 0, 2], name='inputs_tm')

    # infer timesteps and batch_size from the runtime shape
    timesteps, batch_size, _ = tf.unstack(tf.shape(inputs_tm))

    if init_state is None:
        init_state = cell.zero_state(batch_size, tf.float32)

    # Slot 0 of `states` holds init_state and slot t+1 the state after step t.
    # clear_after_read=False because slots are read inside the loop and then
    # read again by stack() at the end.
    states = tf.TensorArray(dtype=tf.float32, size=timesteps + 1, name='states',
                            clear_after_read=False)
    outputs = tf.TensorArray(dtype=tf.float32, size=timesteps, name='outputs',
                             clear_after_read=False)

    def step(i, states, outputs):
        """Advance the cell by one timestep and record output/state."""
        state_prev = states.read(i)
        if lstm:
            # TensorArray stores the LSTM state packed as one tensor;
            # unpack it back into an LSTMStateTuple for the cell.
            c, h = tf.unstack(state_prev)
            state_prev = rnn.LSTMStateTuple(c, h)

        output, state = cell(inputs_tm[i], state_prev)
        states = states.write(i + 1, state)
        outputs = outputs.write(i, output)
        return tf.add(i, 1), states, outputs

    def not_done(i, _states, _outputs):
        """Loop condition: keep going while i < timesteps."""
        return tf.less(i, timesteps)

    with tf.variable_scope(scope):
        states = states.write(0, init_state)
        i = tf.constant(0)
        _, fstates, foutputs = tf.while_loop(not_done, step, [i, states, outputs])

        if lstm:
            d1 = 2 * cell.state_size.c
            # `is not None` so an explicit proj_dim is honored
            # rather than treated as "use the default".
            d2 = proj_dim if proj_dim is not None else d1 // 2
            return foutputs.stack(), project_lstm_states(fstates.stack()[1:], d1, d2)

    # drop slot 0 (init_state) so states align one-to-one with outputs
    return foutputs.stack(), fstates.stack()[1:]
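# Comments
# Table of contents
#
#
# Questions
#
#
# Interview experiences
#
#
# Articles
#
# WeChat
# Official account
#
# Scan the QR code to follow the official account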