dynamic_rnn_estimator_test.py source code

python

Project: lsdc, Author: febert
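def setUp(self):
    # Shared fixtures: an RNN cell, a mock target column, context and
    # sequence feature columns, and input tensors keyed by column name.
    self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
    self.mock_target_column = MockTargetColumn(
        num_label_columns=self.NUM_LABEL_COLUMNS)

    # Context feature: one categorical value per example, one-hot encoded
    # over the three listed keys.
    location = tf.contrib.layers.sparse_column_with_keys(
        'location', keys=['west_side', 'east_side', 'nyc'])
    location_onehot = tf.contrib.layers.one_hot_column(location)
    self.context_feature_columns = [location_onehot]

    # Sequence features: a categorical column embedded into 8 dimensions
    # and a 2-dimensional real-valued column.
    wire_cast = tf.contrib.layers.sparse_column_with_keys(
        'wire_cast', ['marlo', 'omar', 'stringer'])
    wire_cast_embedded = tf.contrib.layers.embedding_column(
        wire_cast, dimension=8)
    measurements = tf.contrib.layers.real_valued_column(
        'measurements', dimension=2)
    self.sequence_feature_columns = [measurements, wire_cast_embedded]

    # A batch of 3 examples; sequence inputs are laid out as
    # [example, time step, ...] with 2 time steps. `shape` is the pre-1.0
    # keyword argument; TensorFlow 1.0 renamed it to `dense_shape`.
    self.columns_to_tensors = {
        'location': tf.SparseTensor(
            indices=[[0, 0], [1, 0], [2, 0]],
            values=['west_side', 'west_side', 'nyc'],
            shape=[3, 1]),
        'wire_cast': tf.SparseTensor(
            # Index triples are [example, time step, value slot];
            # example 1 has two values at time step 1.
            indices=[[0, 0, 0], [0, 1, 0],
                     [1, 0, 0], [1, 1, 0], [1, 1, 1],
                     [2, 0, 0]],
            values=[b'marlo', b'stringer',
                    b'omar', b'stringer', b'marlo',
                    b'marlo'],
            shape=[3, 2, 2]),
        # Dense sequence input: [example, time step, measurement dim].
        'measurements': tf.random_uniform([3, 2, 2])}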
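The fixture packs its categorical inputs as sparse tensors, which can be hard to read from the index lists alone. The following is a minimal pure-Python sketch, not TensorFlow code, assuming the usual COO layout in which each row of `indices` addresses one entry of `values`. It maps the `location` strings to their positions in the key list and one-hot encodes them, then scatters the `wire_cast` entries into a dense [example][time step][value slot] grid. The helper names `lookup_ids`, `one_hot` and `densify` are hypothetical, and the -1 fallback for unknown values is a choice made only for this sketch.

```python
# Pure-Python sketch (not TensorFlow) of how the sparse fixtures in setUp
# decode. All helper names below are hypothetical, chosen for illustration.

LOCATION_KEYS = ['west_side', 'east_side', 'nyc']
# Bytes keys, matching the b'...' values used in the fixture.
WIRE_CAST_KEYS = [b'marlo', b'omar', b'stringer']


def lookup_ids(values, keys):
    """Map each value to its position in `keys`; -1 if it is not a key."""
    index = {key: i for i, key in enumerate(keys)}
    return [index.get(value, -1) for value in values]


def one_hot(ids, depth):
    """Turn integer ids into one-hot rows of length `depth`."""
    return [[1 if j == i else 0 for j in range(depth)] for i in ids]


def densify(indices, values, dense_shape, default=None):
    """Scatter COO-style (indices, values) pairs into nested lists."""

    def build(dims):
        # Recursively allocate fresh nested lists filled with `default`.
        if not dims:
            return default
        return [build(dims[1:]) for _ in range(dims[0])]

    dense = build(list(dense_shape))
    for idx, value in zip(indices, values):
        cell = dense
        for i in idx[:-1]:  # walk down to the innermost list
            cell = cell[i]
        cell[idx[-1]] = value
    return dense


# 'location': one key per example -> ids -> one-hot rows.
location_ids = lookup_ids(['west_side', 'west_side', 'nyc'], LOCATION_KEYS)
location_onehot = one_hot(location_ids, len(LOCATION_KEYS))

# 'wire_cast': dense grid indexed as [example][time step][value slot].
wire_cast = densify(
    indices=[[0, 0, 0], [0, 1, 0],
             [1, 0, 0], [1, 1, 0], [1, 1, 1],
             [2, 0, 0]],
    values=[b'marlo', b'stringer',
            b'omar', b'stringer', b'marlo',
            b'marlo'],
    dense_shape=[3, 2, 2])

# Key ids present at each time step, with empty slots dropped.
wire_cast_ids = [
    [lookup_ids([v for v in step if v is not None], WIRE_CAST_KEYS)
     for step in example]
    for example in wire_cast]

print(location_onehot)  # [[1, 0, 0], [1, 0, 0], [0, 0, 1]]
print(wire_cast_ids)    # [[[0], [2]], [[1], [2, 0]], [[0], []]]
```

Example 1 holds two names at time step 1, so the embedding column has to reduce two embedding vectors for that step (presumably through its combiner), while example 2 has no entries at all at step 1.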