def setUp(self):
    """Build the fixtures shared by the tests in this case.

    Sets on ``self``:
      rnn_cell: a ``BasicRNNCell`` with ``NUM_RNN_CELL_UNITS`` units.
      mock_target_column: a ``MockTargetColumn`` configured with
        ``NUM_LABEL_COLUMNS`` label columns.
      context_feature_columns: a one-hot encoding of the ``location``
        sparse column.
      sequence_feature_columns: the 2-dimensional ``measurements``
        real-valued column and an 8-dimensional embedding of the
        ``wire_cast`` sparse column.
      columns_to_tensors: input tensors keyed by column name. Every entry
        has a leading dimension of 3 (presumably the batch size; the
        second dimension of the rank-3 inputs presumably indexes sequence
        steps — confirm against the tests).
    """
    layers = tf.contrib.layers

    self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
    self.mock_target_column = MockTargetColumn(
        num_label_columns=self.NUM_LABEL_COLUMNS)

    # Context features: the keyed `location` column, one-hot encoded.
    location_col = layers.sparse_column_with_keys(
        'location', keys=['west_side', 'east_side', 'nyc'])
    self.context_feature_columns = [layers.one_hot_column(location_col)]

    # Sequence features: raw measurements plus an embedded `wire_cast`.
    cast_col = layers.sparse_column_with_keys(
        'wire_cast', keys=['marlo', 'omar', 'stringer'])
    cast_embedding = layers.embedding_column(cast_col, dimension=8)
    measurement_col = layers.real_valued_column('measurements', dimension=2)
    self.sequence_feature_columns = [measurement_col, cast_embedding]

    # NOTE(review): `shape=` is the pre-1.0 SparseTensor keyword (renamed
    # `dense_shape` in TF 1.0); presumably this file targets an older TF.
    # NOTE(review): `location` values are str while `wire_cast` values are
    # bytes; both presumably become tf.string — confirm.
    location_input = tf.SparseTensor(
        indices=[[0, 0], [1, 0], [2, 0]],
        values=['west_side', 'west_side', 'nyc'],
        shape=[3, 1])
    # Position [1, 1] carries two values; row 2 only populates position 0.
    cast_input = tf.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0],
                 [1, 0, 0], [1, 1, 0], [1, 1, 1],
                 [2, 0, 0]],
        values=[b'marlo', b'stringer',
                b'omar', b'stringer', b'marlo',
                b'marlo'],
        shape=[3, 2, 2])
    measurement_input = tf.random_uniform([3, 2, 2])

    self.columns_to_tensors = {
        'location': location_input,
        'wire_cast': cast_input,
        'measurements': measurement_input,
    }
评论列表
文章目录