official_tensorflow_phased_lstm.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:tensorflow-phased-lstm 作者: philipperemy 项目源码 文件源码
def run_lstm_mnist(lstm_cell=BasicLSTMCell, hidden_size=32, batch_size=256, steps=1000, log_file='log.tsv'):
    """Train a single-layer recurrent classifier on MNIST and log progress.

    Each image is fed to the RNN as a sequence of ``mnist_img_size`` steps
    with one feature per step. The output of the last time step goes through
    a linear layer producing ``num_classes`` logits, trained with Adam on
    softmax cross-entropy.

    Args:
        lstm_cell: RNN cell class, instantiated as ``lstm_cell(hidden_size)``.
            When it is ``PhasedLSTMCell`` the network input is the pair
            ``(timestamps, pixels)``; otherwise only the pixels are fed.
        hidden_size: Number of units passed to the cell constructor.
        batch_size: Number of samples per training step. It is baked into
            the placeholder shapes.
        steps: Number of training steps to run.
        log_file: Path handed to ``FileLogger``; one row of
            ``step, training_loss, training_accuracy`` is written per step.

    Returns:
        None. Progress is printed to stdout and written to ``log_file``.
    """
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    learning_rate = 0.001
    file_logger = FileLogger(log_file, ['step', 'training_loss', 'training_accuracy'])

    # Everything after the logger is opened runs under try/finally so the
    # log file is closed even if graph construction or training raises.
    try:
        x_ = tf.placeholder(tf.float32, (batch_size, mnist_img_size, 1))
        t_ = tf.placeholder(tf.float32, (batch_size, mnist_img_size, 1))
        y_ = tf.placeholder(tf.float32, (batch_size, num_classes))

        use_phased = lstm_cell == PhasedLSTMCell
        inputs = (t_, x_) if use_phased else x_

        outputs, _ = dynamic_rnn(cell=lstm_cell(hidden_size), inputs=inputs, dtype=tf.float32)
        # The slice already has shape (batch, hidden). No squeeze here: it
        # would drop the batch (or hidden) axis whenever that size is 1.
        rnn_out = outputs[:, -1, :]

        y = slim.fully_connected(inputs=rnn_out,
                                 num_outputs=num_classes,
                                 activation_fn=None)

        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
        grad_update = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # The timestamps are identical for every batch (0 .. mnist_img_size - 1
        # for each sample), so build them once instead of on every step.
        timestamps = np.reshape(np.tile(np.arange(mnist_img_size), (batch_size, 1)),
                                (batch_size, mnist_img_size, 1))

        # The context manager releases the session's resources on exit,
        # including when an exception is raised.
        with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
            sess.run(tf.global_variables_initializer())

            for i in range(steps):
                b = mnist.train.next_batch(batch_size)
                st = time()

                # Add a trailing feature axis so images match the x_ placeholder.
                feed_dict = {x_: np.expand_dims(b[0], axis=2), y_: b[1]}
                if use_phased:
                    feed_dict[t_] = timestamps

                tr_loss, tr_acc, _ = sess.run([cross_entropy, accuracy, grad_update], feed_dict=feed_dict)
                print('steps = {0} | time {1:.2f} | tr_loss = {2:.3f} | tr_acc = {3:.3f}'.format(str(i).zfill(6),
                                                                                                 time() - st,
                                                                                                 tr_loss,
                                                                                                 tr_acc))
                file_logger.write([i, tr_loss, tr_acc])
    finally:
        file_logger.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号