def classification_metrics(y, y_pred, threshold):
    metrics = {}
    metrics['threshold'] = threshold_from_predictions(y, y_pred, 0)
    metrics['np.std(y_pred)'] = np.std(y_pred)
    metrics['positive_frac_batch'] = float(np.count_nonzero(y == True)) / len(y)
    denom = np.count_nonzero(y == False)
    num = np.count_nonzero(np.logical_and(y == False, y_pred >= threshold))
    if denom > 0:
        metrics['fpr'] = float(num) / float(denom)
    if any(y) and not all(y):
        metrics['auc'] = roc_auc_score(y, y_pred)
        y_pred_bool = y_pred >= threshold
        if (any(y_pred_bool) and not all(y_pred_bool)):
            metrics['precision'] = precision_score(np.array(y, dtype=np.float32), y_pred_bool)
            metrics['recall'] = recall_score(y, y_pred_bool)
    return metrics
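
# Toy check of the metric logic above using scikit-learn only. It does not call
# classification_metrics itself, since that relies on the project helper
# threshold_from_predictions; the arrays and threshold here are illustrative.
import numpy as np
from sklearn.metrics import precision_score, recall_score, roc_auc_score

y = np.array([False, False, True, True, False])
y_pred = np.array([0.1, 0.4, 0.35, 0.8, 0.6])
threshold = 0.5
fpr = np.count_nonzero((y == False) & (y_pred >= threshold)) / np.count_nonzero(y == False)
print(fpr)                                   # 1 of 3 negatives crosses the threshold
print(roc_auc_score(y, y_pred))              # 0.666...
print(precision_score(y, y_pred >= threshold), recall_score(y, y_pred >= threshold))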
def preprocess(image):
    """Take an image and apply preprocessing."""
    # resize to the network input size
    image = cv2.resize(image, (data_shape, data_shape))
    # convert BGR to RGB
    image = image[:, :, (2, 1, 0)]
    # convert to float before subtracting the mean
    image = image.astype(np.float32)
    # subtract the per-channel mean
    image -= np.array([123, 117, 104])
    # reorder to [batch-channel-height-width]
    image = np.transpose(image, (2, 0, 1))
    image = image[np.newaxis, :]
    # convert to an MXNet NDArray
    image = nd.array(image)
    return image
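
# Minimal usage sketch for preprocess(), assuming cv2 and mxnet are installed
# and the module-level `data_shape` is defined (512 is only an example value);
# 'street.jpg' is a hypothetical path.
import cv2
import numpy as np
from mxnet import nd

data_shape = 512
img = cv2.imread('street.jpg')   # HWC, BGR, uint8 (None if the path is missing)
if img is not None:
    batch = preprocess(img)      # (1, 3, data_shape, data_shape), RGB, float32
    print(batch.shape)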
def remove_artifacts(self, image):
    """
    Remove the connected components that are not within the parameters
    Operates in place
    :param image: sudoku's thresholded image w/o grid
    :return: None
    """
    labeled, features = label(image, structure=CROSS)
    lbls = np.arange(1, features + 1)
    areas = extract_feature(image, labeled, lbls, np.sum,
                            np.uint32, 0)
    sides = extract_feature(image, labeled, lbls, min_side,
                            np.float32, 0, True)
    diags = extract_feature(image, labeled, lbls, diagonal,
                            np.float32, 0, True)
    for index in lbls:
        area = areas[index - 1] / 255
        side = sides[index - 1]
        diag = diags[index - 1]
        if side < 5 or side > 20 \
                or diag < 15 or diag > 25 \
                or area < 40:
            image[labeled == index] = 0
    return None
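
# Hedged sketch of the connected-component labeling this method relies on.
# CROSS, extract_feature, min_side and diagonal are project helpers; here a
# plus-shaped structuring element (4-connectivity) and plain scipy.ndimage
# calls stand in, with a tiny synthetic mask.
import numpy as np
from scipy.ndimage import label, sum as ndi_sum

CROSS_4CONN = np.array([[0, 1, 0],
                        [1, 1, 1],
                        [0, 1, 0]])
mask = np.zeros((10, 10), np.uint8)
mask[1:4, 1:4] = 255          # one small blob
mask[6:9, 5:9] = 255          # one larger blob
labeled, features = label(mask, structure=CROSS_4CONN)
areas = ndi_sum(mask, labeled, index=np.arange(1, features + 1)) / 255
print(features, areas)        # 2 components with their pixel counts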
def word_list_to_embedding(words, embeddings, embedding_dimension=50):
    '''
    :param words: an n x (2*window_size + 1) matrix from data_to_mat
    :param embeddings: an embedding dictionary where keys are strings and values
        are embeddings; the output from embeddings_to_dict
    :param embedding_dimension: the dimension of the values in embeddings; in this
        assignment, embedding_dimension=50
    :return: an n x ((2*window_size + 1)*embedding_dimension) matrix where each entry of the
        words matrix is replaced with its embedding
    '''
    m, n = words.shape
    words = words.reshape((-1))
    return np.array([embeddings[w] for w in words], dtype=np.float32).reshape(m, n*embedding_dimension)
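
# Toy usage sketch (the names below are illustrative, not from the assignment):
# a 2 x 3 window matrix of tokens becomes a 2 x (3*4) float32 matrix, using a
# fake 4-dimensional embedding dictionary.
import numpy as np

toy_embeddings = {w: np.full(4, i, dtype=np.float32)
                  for i, w in enumerate(['<s>', 'the', 'cat', 'sat', '</s>'])}
toy_words = np.array([['<s>', 'the', 'cat'],
                      ['the', 'cat', 'sat']])
X = word_list_to_embedding(toy_words, toy_embeddings, embedding_dimension=4)
print(X.shape)   # (2, 12)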
#
# End Twitter Helper Functions
#
def put_images_on_grid(images, shape=(16,8)):
    nrof_images = images.shape[0]
    img_size = images.shape[1]
    bw = 3
    img = np.zeros((shape[1]*(img_size+bw)+bw, shape[0]*(img_size+bw)+bw, 3), np.float32)
    for i in range(shape[1]):
        x_start = i*(img_size+bw)+bw
        for j in range(shape[0]):
            img_index = i*shape[0]+j
            if img_index>=nrof_images:
                break
            y_start = j*(img_size+bw)+bw
            img[x_start:x_start+img_size, y_start:y_start+img_size, :] = images[img_index, :, :, :]
        if img_index>=nrof_images:
            break
    return img
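
# Quick usage sketch: tile 20 random 32x32 RGB "images" onto the default 16x8
# grid with a 3-pixel border (the input values are arbitrary test data).
import numpy as np

dummy = np.random.rand(20, 32, 32, 3).astype(np.float32)
grid = put_images_on_grid(dummy)
print(grid.shape)   # (8*(32+3)+3, 16*(32+3)+3, 3) == (283, 563, 3)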
def layout_tree(correlation):
    """Layout tree for visualization with e.g. matplotlib.

    Args:
        correlation: A [V, V]-shaped numpy array of latent correlations.

    Returns:
        A [V, 3]-shaped numpy array of spectral positions of vertices.
    """
    assert len(correlation.shape) == 2
    assert correlation.shape[0] == correlation.shape[1]
    assert correlation.dtype == np.float32
    laplacian = -correlation
    np.fill_diagonal(laplacian, 0)
    np.fill_diagonal(laplacian, -laplacian.sum(axis=0))
    evals, evects = scipy.linalg.eigh(laplacian, eigvals=(1, 3))  # three smallest nontrivial eigenpairs
    assert np.all(evals > 0)
    assert evects.shape[1] == 3
    return evects
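
# Hedged sketch of the spectral layout idea: build the weighted graph Laplacian
# of a tiny 4-variable correlation matrix and take the three smallest
# nontrivial eigenvectors as 3-D coordinates. subset_by_index is the current
# SciPy spelling of the eigenvalue-range selection used above; the matrix is
# made up for illustration.
import numpy as np
import scipy.linalg

corr = np.array([[0.0, 0.9, 0.1, 0.2],
                 [0.9, 0.0, 0.3, 0.1],
                 [0.1, 0.3, 0.0, 0.8],
                 [0.2, 0.1, 0.8, 0.0]], np.float32)
lap = -corr
np.fill_diagonal(lap, corr.sum(axis=0))
evals, evects = scipy.linalg.eigh(lap, subset_by_index=[1, 3])
print(evects.shape)   # (4, 3): one 3-D position per vertex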
def test_quantize_from_probs2(size, resolution):
    set_random_seed(make_seed(size, resolution))
    probs = np.exp(np.random.random(size)).astype(np.float32)
    probs2 = probs.reshape((1, size))
    quantized = quantize_from_probs2(probs2, resolution)
    assert quantized.shape == probs2.shape
    assert quantized.dtype == np.int8
    assert np.all(quantized.sum(axis=1) == resolution)
    # Check that quantized result is closer to target than any other value.
    quantized = quantized.reshape((size, ))
    target = resolution * probs / probs.sum()
    distance = np.abs(quantized - target).sum()
    for combo in itertools.combinations(range(size), resolution):
        other = np.zeros(size, np.int8)
        for i in combo:
            other[i] += 1
        assert other.sum() == resolution
        other_distance = np.abs(other - target).sum()
        assert distance <= other_distance
def sample_tree(self):
    """Samples a random tree.

    Returns:
        A pair (edges, edge_logits), where:
            edges: A list of (vertex, vertex) pairs.
            edge_logits: A [K]-shaped numpy array of edge logits.
    """
    logger.info('TreeCatTrainer.sample_tree given %d rows',
                len(self._added_rows))
    SERIES.sample_tree_num_rows.append(len(self._added_rows))
    complete_grid = self._tree.complete_grid
    edge_logits = self.compute_edge_logits()
    assert edge_logits.shape[0] == complete_grid.shape[1]
    assert edge_logits.dtype == np.float32
    edges = self.get_edges()
    edges = sample_tree(complete_grid, edge_logits, edges)
    return edges, edge_logits
def treecat_add_cell(
        feature_type,
        ragged_index,
        data_row,
        message,
        feat_probs,
        meas_probs,
        v, ):
    if feature_type == TY_MULTINOMIAL:
        beg, end = ragged_index[v:v + 2]
        feat_block = feat_probs[beg:end, :]
        meas_block = meas_probs[v, :]
        for c, count in enumerate(data_row[beg:end]):
            for _ in range(count):
                message *= feat_block[c, :]
                message /= meas_block
                feat_block[c, :] += np.float32(1)
                meas_block += np.float32(1)
    else:
        raise NotImplementedError
def __init__(self, data, tree_prior, config):
    """Initialize a model with an empty subsample.

    Args:
        data: An [N, V]-shaped numpy array of real-valued data.
        tree_prior: A [K]-shaped numpy array of prior edge log odds, where
            K is the number of edges in the complete graph on V vertices.
        config: A global config dict.
    """
    assert isinstance(data, np.ndarray)
    data = np.asarray(data, np.float32)
    assert len(data.shape) == 2
    N, V = data.shape
    D = config['model_latent_dim']
    E = V - 1  # Number of edges in the tree.
    TreeTrainer.__init__(self, N, V, tree_prior, config)
    self._data = data
    self._latent = np.zeros([N, V, D], np.float32)
    # This is symmetric positive definite.
    self._vert_ss = np.zeros([V, D, D], np.float32)
    # This is arbitrary (not necessarily symmetric).
    self._edge_ss = np.zeros([E, D, D], np.float32)
    # This represents (count, mean, covariance).
    self._feat_ss = np.zeros([V, D, 1 + 1 + D], np.float32)
def observed_perplexity(self, counts):
    """Compute perplexity = exp(entropy) of observed variables.

    Perplexity is an information theoretic measure of the number of
    clusters or latent classes. Perplexity is a real number in the range
    [1, M], where M is model_num_clusters.

    Args:
        counts: A [V]-shaped array of multinomial counts.

    Returns:
        A [V]-shaped numpy array of perplexity.
    """
    V, E, M, R = self._VEMR
    if counts is None:  # default to a single count per observed variable
        counts = np.ones(V, dtype=np.int8)
    assert counts.shape == (V, )
    assert counts.dtype == np.int8
    assert np.all(counts > 0)
    observed_entropy = np.empty(V, dtype=np.float32)
    for v in range(V):
        beg, end = self._ragged_index[v:v + 2]
        probs = np.dot(self._feat_cond[beg:end, :], self._vert_probs[v, :])
        observed_entropy[v] = multinomial_entropy(probs, counts[v])
    return np.exp(observed_entropy)
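
# Small numeric sketch of "perplexity = exp(entropy)" for one multinomial
# (illustrative only; multinomial_entropy above is a project helper).
import numpy as np

probs = np.array([0.7, 0.2, 0.1], np.float32)
entropy = -np.sum(probs * np.log(probs))   # entropy for a single draw
print(np.exp(entropy))                     # ~2.23 "effective categories",
                                           # between 1 (certain) and 3 (uniform)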
def generate_model_file(num_rows, num_cols, num_cats=4, rate=1.0):
    """Generate a random model.

    Returns:
        The path to a gzipped pickled model.
    """
    path = os.path.join(DATA, '{}-{}-{}-{:0.1f}.model.pkz'.format(
        num_rows, num_cols, num_cats, rate))
    V = num_cols
    K = V * (V - 1) // 2
    if os.path.exists(path):
        return path
    print('Generating {}'.format(path))
    if not os.path.exists(DATA):
        os.makedirs(DATA)
    dataset_path = generate_dataset_file(num_rows, num_cols, num_cats, rate)
    dataset = pickle_load(dataset_path)
    table = dataset['table']
    tree_prior = np.zeros(K, dtype=np.float32)
    config = make_config(learning_init_epochs=5)
    model = train_model(table, tree_prior, config)
    pickle_dump(model, path)
    return path
def calculate_loss(self, predictions, labels, weights=None, **unused_params):
    with tf.name_scope("loss_xent"):
        epsilon = 10e-6
        if FLAGS.label_smoothing:
            float_labels = smoothing(labels)
        else:
            float_labels = tf.cast(labels, tf.float32)
        cross_entropy_loss = float_labels * tf.log(predictions + epsilon) + (
            1 - float_labels) * tf.log(1 - predictions + epsilon)
        cross_entropy_loss = tf.negative(cross_entropy_loss)
        if weights is not None:
            print(cross_entropy_loss, weights)
            weighted_loss = tf.einsum("ij,i->ij", cross_entropy_loss, weights)
            print("create weighted_loss", weighted_loss)
            return tf.reduce_mean(tf.reduce_sum(weighted_loss, 1))
        else:
            return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1))
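
# Plain-numpy sketch of the loss above (no TensorFlow): per-class binary
# cross-entropy, summed over classes and averaged over the batch; the
# predictions and labels are toy values.
import numpy as np

eps = 10e-6
predictions = np.array([[0.9, 0.2], [0.3, 0.7]], np.float32)
labels = np.array([[1.0, 0.0], [0.0, 1.0]], np.float32)
ce = -(labels * np.log(predictions + eps) +
       (1 - labels) * np.log(1 - predictions + eps))
print(ce.sum(axis=1).mean())   # scalar loss; smaller when predictions match labels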
def calculate_loss(self, predictions, support_predictions, labels, **unused_params):
    """
    support_predictions batch_size x num_models x num_classes
    predictions = tf.reduce_mean(support_predictions, axis=1)
    """
    model_count = tf.shape(support_predictions)[1]
    vocab_size = tf.shape(support_predictions)[2]

    mean_predictions = tf.reduce_mean(support_predictions, axis=1, keep_dims=True)
    support_labels = tf.tile(tf.expand_dims(tf.cast(labels, dtype=tf.float32), axis=1), multiples=[1,model_count,1])
    support_means = tf.stop_gradient(tf.tile(mean_predictions, multiples=[1,model_count,1]))

    support_predictions = tf.reshape(support_predictions, shape=[-1,model_count*vocab_size])
    support_labels = tf.reshape(support_labels, shape=[-1,model_count*vocab_size])
    support_means = tf.reshape(support_means, shape=[-1,model_count*vocab_size])

    ce_loss_fn = CrossEntropyLoss()
    # The cross entropy between predictions and ground truth
    cross_entropy_loss = ce_loss_fn.calculate_loss(support_predictions, support_labels, **unused_params)
    # The cross entropy between predictions and mean predictions
    divergence = ce_loss_fn.calculate_loss(support_predictions, support_means, **unused_params)

    loss = cross_entropy_loss * (1.0 - FLAGS.support_loss_percent) - divergence * FLAGS.support_loss_percent
    return loss
def get_batch_data():
    # Load data
    X, Y = load_data()
    # calc total batch count
    num_batch = len(X) // hp.batch_size
    # Convert to tensor
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.float32)
    # Create Queues
    input_queues = tf.train.slice_input_producer([X, Y])
    # create batch queues
    x, y = tf.train.batch(input_queues,
                          num_threads=8,
                          batch_size=hp.batch_size,
                          capacity=hp.batch_size * 64,
                          allow_smaller_final_batch=False)
    return x, y, num_batch  # (N, T), (N, T), ()
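
# Hedged usage sketch for the TF1 queue pipeline above, assuming TensorFlow 1.x
# and that load_data and the hp hyperparameter module come from the project:
# batches only start flowing once queue runners are launched inside a session.
import tensorflow as tf

x, y, num_batch = get_batch_data()
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    x_val, y_val = sess.run([x, y])   # one (batch_size, T) batch of each
    coord.request_stop()
    coord.join(threads)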
def metrics(self, X, y):
    metrics = {}
    y_pred_pair, loss = self.predict_proba_with_loss(X, y)
    y_pred = y_pred_pair[:,1]  ## From softmax pair to prob of catastrophe
    metrics['loss'] = loss
    threshold = self.threshold_from_data(X, y)
    metrics['threshold'] = threshold
    metrics['np.std(y_pred)'] = np.std(y_pred)
    denom = np.count_nonzero(y == False)
    num = np.count_nonzero(np.logical_and(y == False, y_pred >= threshold))
    metrics['fpr'] = float(num) / float(denom)
    if any(y) and not all(y):
        metrics['auc'] = roc_auc_score(y, y_pred)
        y_pred_bool = y_pred >= threshold
        if (any(y_pred_bool) and not all(y_pred_bool)):
            metrics['precision'] = precision_score(np.array(y, dtype=np.float32), y_pred_bool)
            metrics['recall'] = recall_score(y, y_pred_bool)
    return metrics
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None):
    with tf.variable_scope(name):
        stride_shape = [1, stride[0], stride[1], 1]
        filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]

        # there are "num input feature maps * filter height * filter width"
        # inputs to each hidden unit
        fan_in = np.prod(filter_shape[:3])
        # each unit in the lower layer receives a gradient from:
        # "num output feature maps * filter height * filter width" /
        # pooling size
        fan_out = np.prod(filter_shape[:2]) * num_filters
        # initialize weights with random weights
        w_bound = np.sqrt(6. / (fan_in + fan_out))

        w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
                            collections=collections)
        b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.constant_initializer(0.0),
                            collections=collections)
        return tf.nn.conv2d(x, w, stride_shape, pad) + b
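
# Worked example of the Glorot/Xavier uniform bound used above, for a 3x3
# convolution over 3 input channels producing 32 feature maps (numbers only,
# no TensorFlow graph is built).
import numpy as np

filter_shape = [3, 3, 3, 32]
fan_in = np.prod(filter_shape[:3])            # 3*3*3 = 27 inputs per unit
fan_out = np.prod(filter_shape[:2]) * 32      # 3*3*32 = 288 outgoing weights
w_bound = np.sqrt(6.0 / (fan_in + fan_out))   # ~0.138; weights ~ U(-w_bound, w_bound)
print(fan_in, fan_out, round(float(w_bound), 3))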
def __init__(self, ob_space, ac_space, layers=[256], **kwargs):
    self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

    rank = len(ob_space)
    if rank == 3:  # pixel input
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "c{}".format(i + 1), [3, 3], [2, 2]))
    elif rank == 1:  # plain features
        #x = tf.nn.elu(linear(x, 256, "l1", normalized_columns_initializer(0.01)))
        pass
    else:
        raise TypeError("observation space must have rank 1 or 3, got %d" % rank)

    x = flatten(x)
    for i, layer in enumerate(layers):
        x = tf.nn.elu(linear(x, layer, "l{}".format(i + 1), tf.contrib.layers.xavier_initializer()))

    self.logits = linear(x, ac_space, "action", tf.contrib.layers.xavier_initializer())
    self.vf = tf.reshape(linear(x, 1, "value", tf.contrib.layers.xavier_initializer()), [-1])
    self.sample = categorical_sample(self.logits, ac_space)[0, :]
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    self.state_in = []
def __init__(self, ob_space, ac_space, size=256, **kwargs):
    self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

    for i in range(4):
        x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
    # introduce a "fake" batch dimension of 1 after flatten so that we can do GRU over time dim
    x = tf.expand_dims(flatten(x), 1)

    gru = rnn.GRUCell(size)

    h_init = np.zeros((1, size), np.float32)
    self.state_init = [h_init]
    h_in = tf.placeholder(tf.float32, [1, size])
    self.state_in = [h_in]

    gru_outputs, gru_state = tf.nn.dynamic_rnn(
        gru, x, initial_state=h_in, sequence_length=[size], time_major=True)
    x = tf.reshape(gru_outputs, [-1, size])
    self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
    self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
    self.state_out = [gru_state[:1]]
    self.sample = categorical_sample(self.logits, ac_space)[0, :]
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)