def omniglot():
    """Open an interactive TF session and define the ``update_tensor`` helper.

    As visible here, the nested helper is neither called nor returned, so the
    function returns None after creating the session.
    NOTE(review): this snippet looks truncated -- confirm the remainder of the
    function against the originating project.
    """
    # Kept bound: tf.InteractiveSession installs itself as the default session.
    sess = tf.InteractiveSession()

    def update_tensor(V, dim2, val):
        """Replace one entry per row: row i of ``V`` gets ``val[i]`` at column ``dim2[i]``.

        :param V: tensor scanned over dimension 0; its static shape must be
            known because ``V.get_shape().as_list()`` is read.
        :param dim2: per-row column indices (cast to int32 in the scan body).
        :param val: per-row replacement values (cast to ``V.dtype``).
        :return: tensor stacked by ``tf.scan`` with one updated row per row of ``V``.
        """
        val = tf.cast(val, V.dtype)

        def body(_, elems):
            # Tuple parameters were removed in Python 3 (PEP 3113); unpack here.
            v, d2, chg = elems
            d2_int = tf.cast(d2, tf.int32)
            # Splice ``chg`` in at position d2_int, then trim back to the row length.
            # NOTE(review): tf.concat_v2 only exists in TF 0.12; on TF >= 1.0 this
            # is tf.concat(values, axis=0).
            spliced = tf.concat_v2([v[:d2_int], [chg], v[d2_int + 1:]], axis=0)
            return tf.slice(spliced, [0], [v.get_shape().as_list()[0]])

        # The accumulator must share the body's output dtype (V.dtype); the
        # accumulator value itself is ignored by ``body``.
        Z = tf.scan(body, elems=(V, dim2, val),
                    initializer=tf.constant(1, shape=V.get_shape().as_list()[1:], dtype=V.dtype),
                    name="Scan_Update")
        return Z
# Python scan() example source code
def __call__(self, inputs, steps):
    """Apply ``steps`` training transitions of ``self.network`` via ``tf.scan``.

    :param inputs: initial ``(z, v)`` pair used as the scan initializer.
    :param steps: number of transitions; only used to size the dummy elements.
    :return: the stacked ``(z_, v_)`` produced at every step (back_prop=True).
    """
    def fn(zv, x):
        """
        Transition for training, without Metropolis-Hastings.
        `z` is the input state.
        `v` is created as a dummy variable to allow output of v_, for training p(v).
        :param x: variable only for specifying the number of steps
        :return: next state `z_`, and the corresponding auxiliary variable `v_`.
        """
        z, _ = zv
        # The incoming v is discarded; a fresh Gaussian sample is drawn each step.
        noise_shape = tf.stack([tf.shape(z)[0], self.network.v_dim])
        fresh_v = tf.random_normal(shape=noise_shape)
        z_next, v_next = self.network.forward([z, fresh_v])
        return z_next, v_next

    step_markers = tf.zeros([steps])
    return tf.scan(fn, step_markers, inputs, back_prop=True)
def _cumprod(tensor, axis=0):
    """Cumulative product of ``tensor`` along ``axis``, computed with ``tf.scan``.

    A custom version of cumprod to prevent NaN gradients when there are zeros
    in ``tensor``, as reported here:
    https://github.com/tensorflow/tensorflow/issues/3862

    ``tf.scan`` always accumulates over dimension 0. For any other axis, the
    tensor is transposed so that ``axis`` comes first, scanned, and then
    transposed back.

    :param tensor: tf.Tensor with a statically known rank.
    :param axis: dimension to accumulate over; negative values count from the end.
    :return: tf.Tensor with the same shape as ``tensor``.
    """
    n_dim = len(tensor.get_shape())
    if axis < 0:
        axis += n_dim

    transpose_permutation = None
    if n_dim > 1 and axis != 0:
        # Swap dimensions 0 and ``axis``. A single swap is its own inverse, so
        # the same permutation restores the original layout afterwards.
        transpose_permutation = np.arange(n_dim)
        transpose_permutation[0], transpose_permutation[axis] = axis, 0
        tensor = tf.transpose(tensor, transpose_permutation)

    def prod(acc, x):
        return acc * x

    prob = tf.scan(prod, tensor)
    if transpose_permutation is None:
        # tf.transpose(x, None) reverses *all* dimensions, so only undo a
        # transpose that was actually applied.
        return prob
    return tf.transpose(prob, transpose_permutation)
def omniglot():
    """Open an interactive TF session and define the ``update_tensor`` helper.

    As visible here, the nested helper is neither called nor returned, so the
    function returns None after creating the session.
    NOTE(review): this snippet looks truncated -- confirm the remainder of the
    function against the originating project.
    """
    # Kept bound: tf.InteractiveSession installs itself as the default session.
    sess = tf.InteractiveSession()

    def update_tensor(V, dim2, val):
        """Replace one entry per row: row i of ``V`` gets ``val[i]`` at column ``dim2[i]``.

        :param V: tensor scanned over dimension 0; its static shape must be
            known because ``V.get_shape().as_list()`` is read.
        :param dim2: per-row column indices (cast to int32 in the scan body).
        :param val: per-row replacement values (cast to ``V.dtype``).
        :return: tensor stacked by ``tf.scan`` with one updated row per row of ``V``.
        """
        val = tf.cast(val, V.dtype)

        def body(_, elems):
            # Tuple parameters were removed in Python 3 (PEP 3113); unpack here.
            v, d2, chg = elems
            d2_int = tf.cast(d2, tf.int32)
            # Splice ``chg`` in at position d2_int, then trim back to the row length.
            # NOTE(review): tf.concat_v2 only exists in TF 0.12; on TF >= 1.0 this
            # is tf.concat(values, axis=0).
            spliced = tf.concat_v2([v[:d2_int], [chg], v[d2_int + 1:]], axis=0)
            return tf.slice(spliced, [0], [v.get_shape().as_list()[0]])

        # The accumulator must share the body's output dtype (V.dtype); the
        # accumulator value itself is ignored by ``body``.
        Z = tf.scan(body, elems=(V, dim2, val),
                    initializer=tf.constant(1, shape=V.get_shape().as_list()[1:], dtype=V.dtype),
                    name="Scan_Update")
        return Z
def get_output_for(self, input, **kwargs):
    """Unroll ``self.step`` over every time step of ``input`` with ``tf.scan``.

    The input is flattened to ``(batch, time, -1)`` and swapped to time-major
    order for the scan. The hidden states are then swapped back to batch-major.

    Optional keyword arguments:
      * ``recurrent_state`` -- dict; if it has an entry for this layer, that
        tensor is the initial hidden state. Otherwise ``self.h0`` is tiled
        across the batch.
      * ``recurrent_state_output`` -- dict; if given, the batch-major hidden
        states are stored in it under this layer.

    :return: batch-major tensor of hidden states.
    """
    dynamic_shape = tf.shape(input)
    batch_size, time_steps = dynamic_shape[0], dynamic_shape[1]
    flat_input = tf.reshape(input, tf.pack([batch_size, time_steps, -1]))

    has_given_state = 'recurrent_state' in kwargs and self in kwargs['recurrent_state']
    if has_given_state:
        initial_hidden = kwargs['recurrent_state'][self]
    else:
        h0_row = tf.reshape(self.h0, (1, self.num_units))
        initial_hidden = tf.tile(h0_row, (batch_size, 1))

    time_major_input = tf.transpose(flat_input, (1, 0, 2))
    time_major_hidden = tf.scan(self.step, elems=time_major_input, initializer=initial_hidden)
    batch_major_hidden = tf.transpose(time_major_hidden, (1, 0, 2))

    if 'recurrent_state_output' in kwargs:
        kwargs['recurrent_state_output'][self] = batch_major_hidden
    return batch_major_hidden
def get_output_for(self, input, **kwargs):
    """Run ``self.step`` over a batch of sequences with ``tf.scan``.

    ``input`` is reshaped to ``(n_batches, n_steps, -1)`` and moved to
    time-major order, because ``tf.scan`` iterates over dimension 0. The scan
    carry is ``[h, c]`` concatenated on axis 1. It starts from ``self.c0``
    tiled over the batch, with ``h0 = self.nonlinearity(c0)``.

    :param input: tensor whose first two dimensions are (batch, time).
    :param kwargs: accepted for interface compatibility; unused here.
    :return: batch-major tensor holding the first ``self.num_units`` features
        (the hidden half) of every scan output.
    """
    input_shape = tf.shape(input)
    n_batches = input_shape[0]
    n_steps = input_shape[1]
    # flatten extra dimensions
    input = tf.reshape(input, tf.pack([n_batches, n_steps, -1]))
    c0s = tf.tile(
        tf.reshape(self.c0, (1, self.num_units)),
        (n_batches, 1)
    )
    h0s = self.nonlinearity(c0s)
    # tf.scan iterates over dim 0, so switch to (n_steps, n_batches, dim).
    shuffled_input = tf.transpose(input, (1, 0, 2))
    hcs = tf.scan(
        self.step,
        elems=shuffled_input,
        initializer=tf.concat(1, [h0s, c0s])
    )
    shuffled_hcs = tf.transpose(hcs, (1, 0, 2))
    # Only the hidden half of each [h, c] output is returned.
    return shuffled_hcs[:, :, :self.num_units]
def get_output_for(self, input, **kwargs):
    """Unroll ``self.step`` over every time step of ``input`` with ``tf.scan``.

    The input is flattened to ``(batch, time, -1)`` and swapped to time-major
    order for the scan. The hidden states are then swapped back to batch-major.

    Optional keyword arguments:
      * ``recurrent_state`` -- dict; if it has an entry for this layer, that
        tensor is the initial hidden state. Otherwise ``self.h0`` is tiled
        across the batch.
      * ``recurrent_state_output`` -- dict; if given, the batch-major hidden
        states are stored in it under this layer.

    :return: batch-major tensor of hidden states.
    """
    dynamic_shape = tf.shape(input)
    batch_size, time_steps = dynamic_shape[0], dynamic_shape[1]
    flat_input = tf.reshape(input, tf.stack([batch_size, time_steps, -1]))

    has_given_state = 'recurrent_state' in kwargs and self in kwargs['recurrent_state']
    if has_given_state:
        initial_hidden = kwargs['recurrent_state'][self]
    else:
        h0_row = tf.reshape(self.h0, (1, self.num_units))
        initial_hidden = tf.tile(h0_row, (batch_size, 1))

    time_major_input = tf.transpose(flat_input, (1, 0, 2))
    time_major_hidden = tf.scan(self.step, elems=time_major_input, initializer=initial_hidden)
    batch_major_hidden = tf.transpose(time_major_hidden, (1, 0, 2))

    if 'recurrent_state_output' in kwargs:
        kwargs['recurrent_state_output'][self] = batch_major_hidden
    return batch_major_hidden
def get_output_for(self, input, **kwargs):
    """Run ``self.step`` over a batch of sequences with ``tf.scan``.

    ``input`` is reshaped to ``(n_batches, n_steps, -1)`` and moved to
    time-major order, because ``tf.scan`` iterates over dimension 0. The scan
    carry is ``[h, c]`` concatenated on axis 1. It starts from ``self.c0``
    tiled over the batch, with ``h0 = self.nonlinearity(c0)``.

    :param input: tensor whose first two dimensions are (batch, time).
    :param kwargs: accepted for interface compatibility; unused here.
    :return: batch-major tensor holding the first ``self.num_units`` features
        (the hidden half) of every scan output.
    """
    input_shape = tf.shape(input)
    n_batches = input_shape[0]
    n_steps = input_shape[1]
    # flatten extra dimensions
    input = tf.reshape(input, tf.stack([n_batches, n_steps, -1]))
    c0s = tf.tile(
        tf.reshape(self.c0, (1, self.num_units)),
        (n_batches, 1)
    )
    h0s = self.nonlinearity(c0s)
    # tf.scan iterates over dim 0, so switch to (n_steps, n_batches, dim).
    shuffled_input = tf.transpose(input, (1, 0, 2))
    hcs = tf.scan(
        self.step,
        elems=shuffled_input,
        initializer=tf.concat(axis=1, values=[h0s, c0s])
    )
    shuffled_hcs = tf.transpose(hcs, (1, 0, 2))
    # Only the hidden half of each [h, c] output is returned.
    return shuffled_hcs[:, :, :self.num_units]
def seg_prediction(self):
    """Predict tag distributions step by step with a learned transition term.

    At each scan step, the step's features are projected to ``num_class``
    logits. The previous prediction, multiplied by a learned
    ``num_class x num_class`` matrix and weighted by ``focus``, is added, and
    the softmax is taken. The recurrence starts from an all-zero prediction.

    NOTE(review): the scan runs over dimension 0 of ``outputs``; presumably
    that is the time axis -- confirm against the producer of ``self.outputs``.

    :return: scan output with dimensions 0 and 1 swapped.
    """
    outputs, size, batch_size = self.outputs
    num_class = self.config.num_class
    output_w = weight_variable([size, num_class])
    output_b = bias_variable([num_class])
    tag_trans = weight_variable([num_class, num_class])
    focus = 1.  # weight of the transition term

    def transition(prev_pred, step_features):
        logits = tf.matmul(step_features, output_w) + output_b
        logits += tf.matmul(prev_pred, tag_trans) * focus
        return tf.nn.softmax(logits)

    # Recurrent network.
    start = tf.zeros([batch_size, num_class])
    stacked = tf.scan(transition, outputs, initializer=start, parallel_iterations=100)
    return tf.transpose(stacked, [1, 0, 2])
def pos_prediction(self):
    """Predict POS tag distributions with a right-to-left recurrence.

    ``outputs`` is reversed along dimension 0 before scanning, and the scan
    result is reversed back afterwards. At each step:
      1. The previous prediction is shifted so each row's minimum is 0.
      2. It is multiplied by a learned transition matrix and weighted by 0.5.
      3. It is added to a linear projection of the step's features.
      4. The softmax of the sum is taken.

    NOTE(review): ``tf.reverse(x, [True, False, False])`` is the pre-1.0
    boolean-mask signature of tf.reverse.

    :return: scan output (in original step order) with dimensions 0 and 1 swapped.
    """
    outputs, size, batch_size = self.outputs
    num_class = len(POS_tagging['P'])
    output_w = weight_variable([size, num_class])
    output_b = bias_variable([num_class])
    tag_trans = weight_variable([num_class, num_class])
    backwards = tf.reverse(outputs, [True, False, False])
    focus = 0.5  # weight of the transition term

    def transition(previous_pred, x):
        logits = tf.matmul(x, output_w) + output_b
        row_min = tf.expand_dims(tf.reduce_min(previous_pred, reduction_indices=1), 1)
        shifted_prev = previous_pred - tf.tile(row_min, [1, num_class])
        logits += tf.matmul(shifted_prev, tag_trans) * focus
        return tf.nn.softmax(logits)

    # Recurrent network.
    stacked = tf.scan(transition, backwards, initializer=tf.zeros([batch_size, num_class]),
                      parallel_iterations=100)
    forwards = tf.reverse(stacked, [True, False, False])
    return tf.transpose(forwards, [1, 0, 2])
def loss_crf(self):
    """CRF-based loss, averaged over the mini-batch.

    ``self.loss_crf_scan`` is scanned over the aligned triple
    ``(self.prediction, sequence lengths, self.y)`` from a 0.0 accumulator.
    The per-step outputs are summed and divided by the batch size
    (``tf.shape(self.x_tokens)[0]``).

    :return: scalar loss tensor.
    """
    # Give the sequence lengths a trailing axis: [N, ...] -> [N, -1]
    # (presumably [batch, 1] for a 1-D length vector -- TODO confirm).
    lengths = tf.reshape(self.x_tokens_len, [tf.shape(self.x_tokens_len)[0], -1])
    per_step = tf.scan(self.loss_crf_scan,
                       [self.prediction, lengths, self.y],
                       back_prop=True, infer_shape=True, initializer=0.0)
    total = tf.reduce_sum(per_step)
    batch_size = tf.cast(tf.shape(self.x_tokens)[0], dtype=tf.float32)
    return tf.divide(total, batch_size)
def refine_boxes(boxes, num_iters, step, sigma):
    """Iteratively blend each box with a similarity-weighted average of all boxes.

    Each of ``num_iters`` iterations does the following:
      1. Compute ``w = exp(-relu(d))``, where ``d`` is
         ``nnutil.pairwise_distance(state / sigma)``.
      2. Normalize each row of ``w`` by its sum (the "confidence").
      3. Update ``state <- (1 - step) * state + step * (w @ state)``.

    NOTE(review): ``nnutil.pairwise_distance`` is assumed to return an (N, N)
    distance matrix for N boxes -- confirm.

    :param boxes: initial state; also supplies the initial confidence slot
        ``boxes[:, 0:1]``.
    :param num_iters: number of iterations; must be greater than 1.
    :param step: interpolation weight toward the weighted average.
    :param sigma: scale dividing the state before distances are computed.
    :return: ``(state, confidence)`` from the last iteration.
    """
    assert num_iters > 1

    def iteration(carry, _):
        current, _ = carry
        dists = tf.nn.relu(nnutil.pairwise_distance(current / sigma))
        affinity = tf.exp(-dists)
        confidence = tf.reduce_sum(affinity, [1], True)
        averaged = tf.matmul(affinity / confidence, current)
        return (1.0 - step) * current + step * averaged, confidence

    all_states, all_confidences = tf.scan(iteration,
                                          tf.range(0, num_iters),
                                          initializer=(boxes, boxes[:, 0:1]))
    return all_states[-1], all_confidences[-1]
def compute_predictions_scan(self):
    """Run the RNN with ``tf.scan``, then map every state to an output.

    ``self.x`` is transposed with perm [1, 0, 2] before scanning, because scan
    runs over dimension 0. Presumably ``self.x`` is (batch, time, features)
    -- TODO confirm. Both scans run with ``parallel_iterations=1``.

    :return: ``(outputs, states)``:
        * ``outputs`` -- the output scan transposed back with perm [1, 0, 2];
        * ``states`` -- the state scan split along dimension 0 with ``tf.unstack``.
    """
    time_major_x = tf.transpose(self.x, [1, 0, 2])
    states = tf.scan(self.rnn_step_scan, time_major_x,
                     initializer=self.init_state, parallel_iterations=1)
    outputs = tf.scan(self.output_step_scan, states,
                      initializer=tf.zeros([self.N_batch, self.N_out]),
                      parallel_iterations=1)
    return tf.transpose(outputs, [1, 0, 2]), tf.unstack(states)
# fix spectral radius of recurrent matrix
def get_output_for(self, input, **kwargs):
    """Unroll ``self.step`` over every time step of ``input`` with ``tf.scan``.

    The input is flattened to ``(batch, time, -1)`` and swapped to time-major
    order for the scan. The hidden states are then swapped back to batch-major.

    Optional keyword arguments:
      * ``recurrent_state`` -- dict; if it has an entry for this layer, that
        tensor is the initial hidden state. Otherwise ``self.h0`` is tiled
        across the batch.
      * ``recurrent_state_output`` -- dict; if given, the batch-major hidden
        states are stored in it under this layer.

    :return: batch-major tensor of hidden states.
    """
    dynamic_shape = tf.shape(input)
    batch_size, time_steps = dynamic_shape[0], dynamic_shape[1]
    flat_input = tf.reshape(input, tf.pack([batch_size, time_steps, -1]))

    has_given_state = 'recurrent_state' in kwargs and self in kwargs['recurrent_state']
    if has_given_state:
        initial_hidden = kwargs['recurrent_state'][self]
    else:
        h0_row = tf.reshape(self.h0, (1, self.num_units))
        initial_hidden = tf.tile(h0_row, (batch_size, 1))

    time_major_input = tf.transpose(flat_input, (1, 0, 2))
    time_major_hidden = tf.scan(self.step, elems=time_major_input, initializer=initial_hidden)
    batch_major_hidden = tf.transpose(time_major_hidden, (1, 0, 2))

    if 'recurrent_state_output' in kwargs:
        kwargs['recurrent_state_output'][self] = batch_major_hidden
    return batch_major_hidden
def get_output_for(self, input, **kwargs):
    """Run ``self.step`` over a batch of sequences with ``tf.scan``.

    ``input`` is reshaped to ``(n_batches, n_steps, -1)`` and moved to
    time-major order, because ``tf.scan`` iterates over dimension 0. The scan
    carry is ``[h, c]`` concatenated on axis 1. It starts from ``self.c0``
    tiled over the batch, with ``h0 = self.nonlinearity(c0)``.

    :param input: tensor whose first two dimensions are (batch, time).
    :param kwargs: accepted for interface compatibility; unused here.
    :return: batch-major tensor holding the first ``self.num_units`` features
        (the hidden half) of every scan output.
    """
    input_shape = tf.shape(input)
    n_batches = input_shape[0]
    n_steps = input_shape[1]
    # flatten extra dimensions
    input = tf.reshape(input, tf.pack([n_batches, n_steps, -1]))
    c0s = tf.tile(
        tf.reshape(self.c0, (1, self.num_units)),
        (n_batches, 1)
    )
    h0s = self.nonlinearity(c0s)
    # tf.scan iterates over dim 0, so switch to (n_steps, n_batches, dim).
    shuffled_input = tf.transpose(input, (1, 0, 2))
    hcs = tf.scan(
        self.step,
        elems=shuffled_input,
        initializer=tf.concat(1, [h0s, c0s])
    )
    shuffled_hcs = tf.transpose(hcs, (1, 0, 2))
    # Only the hidden half of each [h, c] output is returned.
    return shuffled_hcs[:, :, :self.num_units]
def get_decoder_states(self):
    """Run the T-LSTM decoder over the shifted, time-reversed inputs.

    Preprocessing before the scan:
      1. Both the inputs and the time values are shifted by one step: the first
         step is dropped and an all-zero step is appended.
      2. Both are reversed along the time axis.
      3. The time value is prepended as an extra feature.

    ``self.T_LSTM_Decoder_Unit`` is then scanned, starting from the stacked
    ``(hidden, cell)`` pair returned by ``self.get_representation()``.

    :return: index 0 of axis 1 (the hidden half) of every scan output.
    """
    batch_size = tf.shape(self.input)[0]
    seq_length = tf.shape(self.input)[1]

    # [batch, seq, dim] -> [dim, batch, seq] -> (reverse all) [seq, batch, dim]
    time_major_in = tf.transpose(tf.transpose(self.input, perm=[2, 0, 1]))
    pad_in = tf.zeros([1, batch_size, self.input_dim], dtype=tf.float32)
    shifted_in = tf.slice(tf.concat([time_major_in, pad_in], 0),
                          [1, 0, 0], [seq_length, batch_size, self.input_dim])
    scan_input = tf.reverse(shifted_in, [0])

    time_major_t = tf.transpose(self.time)  # [seq_length x batch_size]
    pad_t = tf.zeros([1, batch_size], dtype=tf.float32)
    shifted_t = tf.slice(tf.concat([time_major_t, pad_t], 0),
                         [1, 0], [seq_length, batch_size])
    scan_time = tf.reverse(shifted_t, [0])

    initial_hidden, initial_cell = self.get_representation()
    start_state = tf.stack([initial_hidden, initial_cell])

    # [seq_length x batch_size] -> [seq_length x batch_size x 1]
    scan_time = tf.reshape(scan_time, [tf.shape(scan_time)[0], tf.shape(scan_time)[1], 1])
    decoder_in = tf.concat([scan_time, scan_input], 2)  # [seq_length x batch_size x input_dim+1]
    packed = tf.scan(self.T_LSTM_Decoder_Unit, decoder_in,
                     initializer=start_state, name='decoder_states')
    return packed[:, 0, :, :]
def __call__(self, inputs, init_state=None):
    """Apply the stacked bidirectional layers in ``self.cells``.

    For layer ``i``:
      1. The cell is run forward over ``inputs``, and again over the inputs
         reversed along axis 0; the backward outputs are then re-reversed.
         Both directions start from ``init_states[i]``.
      2. The two directions are concatenated on axis 2.
      3. Dropout with ``keep_prob=self.dropout`` is applied, and the result
         feeds the next layer.

    :param inputs: layer-0 input, reversed along axis 0 for the backward pass.
    :param init_state: stacked per-layer initial states; defaults to ``self.zero_state``.
    :return: the output of the last layer.
    """
    if init_state is None:
        init_state = self.zero_state
    per_layer_state = tf.unstack(init_state)
    layer_input = inputs
    for idx, cell in enumerate(self.cells):
        with tf.variable_scope('bilstm_%d' % idx):
            with tf.variable_scope('forward'):
                fwd = cell.scan(layer_input, per_layer_state[idx])
            with tf.variable_scope('backward'):
                reversed_input = tf.reverse(layer_input, axis=(0,))
                bwd = tf.reverse(cell.scan(reversed_input, per_layer_state[idx]), axis=(0,))
            both = tf.concat([fwd, bwd], axis=2)
            layer_input = tf.nn.dropout(both, keep_prob=self.dropout)
    return layer_input
def parse_args():
    """Parse the command line for the GRU text-hallucination script.

    Exactly one of ``-g/--generate`` or ``-t/--train`` is required.
    ``-n/--num_words`` optionally gives the number of words to generate.

    :return: dict with keys ``generate``, ``train`` and ``num_words``.
    """
    cli = argparse.ArgumentParser(
        description='Gated Recurrent Unit RNN for Text Hallucination, built with tf.scan')
    mode = cli.add_mutually_exclusive_group(required=True)
    for short_flag, long_flag, text in (('-g', '--generate', 'generate text'),
                                        ('-t', '--train', 'train model')):
        mode.add_argument(short_flag, long_flag, action='store_true', help=text)
    cli.add_argument('-n', '--num_words', required=False, type=int,
                     help='number of words to generate')
    return vars(cli.parse_args())
###
# main function
def parse_args():
    """Parse the command line for the LSTM text-hallucination script.

    Exactly one of ``-g/--generate`` or ``-t/--train`` is required.
    ``-n/--num_words`` optionally gives the number of words to generate.

    :return: dict with keys ``generate``, ``train`` and ``num_words``.
    """
    cli = argparse.ArgumentParser(
        description='Long Short Term Memory RNN for Text Hallucination, built with tf.scan')
    mode = cli.add_mutually_exclusive_group(required=True)
    for short_flag, long_flag, text in (('-g', '--generate', 'generate text'),
                                        ('-t', '--train', 'train model')):
        mode.add_argument(short_flag, long_flag, action='store_true', help=text)
    cli.add_argument('-n', '--num_words', required=False, type=int,
                     help='number of words to generate')
    return vars(cli.parse_args())
###
# main function
def parse_args():
    """Parse the command line for the stacked-LSTM text-hallucination script.

    Exactly one of ``-g/--generate`` or ``-t/--train`` is required.
    ``-n/--num_words`` optionally gives the number of words to generate.

    :return: dict with keys ``generate``, ``train`` and ``num_words``.
    """
    cli = argparse.ArgumentParser(
        description='Stacked Long Short Term Memory RNN for Text Hallucination, built with tf.scan')
    mode = cli.add_mutually_exclusive_group(required=True)
    for short_flag, long_flag, text in (('-g', '--generate', 'generate text'),
                                        ('-t', '--train', 'train model')):
        mode.add_argument(short_flag, long_flag, action='store_true', help=text)
    cli.add_argument('-n', '--num_words', required=False, type=int,
                     help='number of words to generate')
    return vars(cli.parse_args())
###
# main function
def parse_args():
    """Parse the command line for the stacked-GRU text-hallucination script.

    Exactly one of ``-g/--generate`` or ``-t/--train`` is required.
    ``-n/--num_words`` optionally gives the number of words to generate.

    :return: dict with keys ``generate``, ``train`` and ``num_words``.
    """
    cli = argparse.ArgumentParser(
        description='Stacked Gated Recurrent Unit RNN for Text Hallucination, built with tf.scan')
    mode = cli.add_mutually_exclusive_group(required=True)
    for short_flag, long_flag, text in (('-g', '--generate', 'generate text'),
                                        ('-t', '--train', 'train model')):
        mode.add_argument(short_flag, long_flag, action='store_true', help=text)
    cli.add_argument('-n', '--num_words', required=False, type=int,
                     help='number of words to generate')
    return vars(cli.parse_args())
###
# main function
def get_output_for(self, input, **kwargs):
    """Unroll ``self.step`` over every time step of ``input`` with ``tf.scan``.

    The input is flattened to ``(batch, time, -1)`` and swapped to time-major
    order for the scan. The hidden states are then swapped back to batch-major.

    Optional keyword arguments:
      * ``recurrent_state`` -- dict; if it has an entry for this layer, that
        tensor is the initial hidden state. Otherwise ``self.h0`` is tiled
        across the batch.
      * ``recurrent_state_output`` -- dict; if given, the batch-major hidden
        states are stored in it under this layer.

    :return: batch-major tensor of hidden states.
    """
    dynamic_shape = tf.shape(input)
    batch_size, time_steps = dynamic_shape[0], dynamic_shape[1]
    flat_input = tf.reshape(input, tf.stack([batch_size, time_steps, -1]))

    has_given_state = 'recurrent_state' in kwargs and self in kwargs['recurrent_state']
    if has_given_state:
        initial_hidden = kwargs['recurrent_state'][self]
    else:
        h0_row = tf.reshape(self.h0, (1, self.num_units))
        initial_hidden = tf.tile(h0_row, (batch_size, 1))

    time_major_input = tf.transpose(flat_input, (1, 0, 2))
    time_major_hidden = tf.scan(self.step, elems=time_major_input, initializer=initial_hidden)
    batch_major_hidden = tf.transpose(time_major_hidden, (1, 0, 2))

    if 'recurrent_state_output' in kwargs:
        kwargs['recurrent_state_output'][self] = batch_major_hidden
    return batch_major_hidden
def get_output_for(self, input, **kwargs):
    """Run ``self.step`` over a batch of sequences with ``tf.scan``.

    ``input`` is reshaped to ``(n_batches, n_steps, -1)`` and moved to
    time-major order, because ``tf.scan`` iterates over dimension 0. The scan
    carry is ``[h, c]`` concatenated on axis 1. It starts from ``self.c0``
    tiled over the batch, with ``h0 = self.nonlinearity(c0)``.

    :param input: tensor whose first two dimensions are (batch, time).
    :param kwargs: accepted for interface compatibility; unused here.
    :return: batch-major tensor holding the first ``self.num_units`` features
        (the hidden half) of every scan output.
    """
    input_shape = tf.shape(input)
    n_batches = input_shape[0]
    n_steps = input_shape[1]
    # flatten extra dimensions
    input = tf.reshape(input, tf.stack([n_batches, n_steps, -1]))
    c0s = tf.tile(
        tf.reshape(self.c0, (1, self.num_units)),
        (n_batches, 1)
    )
    h0s = self.nonlinearity(c0s)
    # tf.scan iterates over dim 0, so switch to (n_steps, n_batches, dim).
    shuffled_input = tf.transpose(input, (1, 0, 2))
    hcs = tf.scan(
        self.step,
        elems=shuffled_input,
        initializer=tf.concat(axis=1, values=[h0s, c0s])
    )
    shuffled_hcs = tf.transpose(hcs, (1, 0, 2))
    # Only the hidden half of each [h, c] output is returned.
    return shuffled_hcs[:, :, :self.num_units]
def backward_step_fn(self, params, inputs):
    """Backward step over a batch, to be used in tf.scan.

    With ``(mu_back, Sigma_back)`` carried from the previously processed step,
    it computes:
        J_t        = Sigma_filt_t @ A^T @ inv(Sigma_pred_tp1)
        mu_back    = mu_filt_t + J_t @ (mu_back - mu_pred_tp1)
        Sigma_back = Sigma_filt_t + J_t @ (Sigma_back - Sigma_pred_tp1) @ J_t^T
    Here ``A^T`` transposes the last two axes, and ``J_t^T`` is taken via
    ``adjoint_b=True``.

    :param params: ``(mu_back, Sigma_back)`` scan carry.
    :param inputs: ``(mu_pred_tp1, Sigma_pred_tp1, mu_filt_t, Sigma_filt_t, A)``,
        batched along dimension 0.
    :return: updated ``(mu_back, Sigma_back)``.
    """
    mu_carry, Sigma_carry = params
    mu_pred_tp1, Sigma_pred_tp1, mu_filt_t, Sigma_filt_t, A = inputs

    A_transposed = tf.transpose(A, [0, 2, 1])
    pred_precision = tf.matrix_inverse(Sigma_pred_tp1)
    gain = tf.matmul(Sigma_filt_t, tf.matmul(A_transposed, pred_precision))

    mu_smoothed = mu_filt_t + tf.matmul(gain, mu_carry - mu_pred_tp1)
    cov_term = tf.matmul(Sigma_carry - Sigma_pred_tp1, gain, adjoint_b=True)
    Sigma_smoothed = Sigma_filt_t + tf.matmul(gain, cov_term)
    return mu_smoothed, Sigma_smoothed
def compute_forwards(self, reuse=None):
    """Compute the forward pass of the Kalman filter with ``tf.scan``.

    The forward pass is initialized with p(z_1) = N(self.mu, self.Sigma). It
    yields the mean and covariance of:
      * the predictive distribution p(z_t | z_tm1, u_t), t = 2..T+1;
      * the filtering distribution p(z_t | x_1:t, u_1:t), t = 1..T.
    Notation follows Murphy's book, section 18.3.1.

    :param reuse: forwarded to ``self.alpha``.
    :return: the stacked outputs of ``self.forward_step_fn`` (name 'forward').
    """
    # Zero out observations at masked steps so missing values are never used.
    y_masked = tf.multiply(tf.expand_dims(self.mask, 2), self.y)
    step_inputs = tf.concat([y_masked, self.u, tf.expand_dims(self.mask, 2)], axis=2)

    # Previous observation for the first step: self.y_0 repeated per batch row.
    y_prev = tf.tile(tf.expand_dims(self.y_0, 0), (tf.shape(self.mu)[0], 1))  # (batch, dim_y)
    alpha_0, state_0, u_0, buffer_0 = self.alpha(y_prev, self.state, self.u[:, 0],
                                                 init_buffer=True, reuse=reuse)

    # All-ones placeholders filling the A, B and C slots of the scan carry.
    n_rows = self.Sigma.get_shape()[0]
    dummy_init_A = tf.ones([n_rows, self.dim_z, self.dim_z])
    dummy_init_B = tf.ones([n_rows, self.dim_z, self.dim_u])
    dummy_init_C = tf.ones([n_rows, self.dim_y, self.dim_z])

    carry_0 = (self.mu, self.Sigma, self.mu, self.Sigma, alpha_0, u_0, state_0, buffer_0,
               dummy_init_A, dummy_init_B, dummy_init_C)
    return tf.scan(self.forward_step_fn, tf.transpose(step_inputs, [1, 0, 2]),
                   initializer=carry_0, parallel_iterations=1, name='forward')
def decode(self, input):
    """Decode ``input`` through a ReLU hidden layer and a linear output layer.

    :param input: tensor right-multiplied by ``self.weights["decoder1_weights"]``.
    :return: ``relu(input @ W1 + b1) @ W2 + b2``, using the ``decoder1_*`` and
        ``decoder2_*`` entries of ``self.weights``.
    """
    params = self.weights
    pre_activation = tf.matmul(input, params["decoder1_weights"]) + params["decoder1_biases"]
    activated = tf.nn.relu(pre_activation)
    return tf.matmul(activated, params["decoder2_weights"]) + params["decoder2_biases"]
def decode(self, input):
    """Decode ``input`` through a ReLU hidden layer and a linear output layer.

    :param input: tensor right-multiplied by ``self.weights["decoder1_weights"]``.
    :return: ``relu(input @ W1 + b1) @ W2 + b2``, using the ``decoder1_*`` and
        ``decoder2_*`` entries of ``self.weights``.
    """
    params = self.weights
    pre_activation = tf.matmul(input, params["decoder1_weights"]) + params["decoder1_biases"]
    activated = tf.nn.relu(pre_activation)
    return tf.matmul(activated, params["decoder2_weights"]) + params["decoder2_biases"]
def decode(self, input):
    """Decode ``input`` through a ReLU hidden layer and a linear output layer.

    :param input: tensor right-multiplied by ``self.weights["decoder1_weights"]``.
    :return: ``relu(input @ W1 + b1) @ W2 + b2``, using the ``decoder1_*`` and
        ``decoder2_*`` entries of ``self.weights``.
    """
    params = self.weights
    pre_activation = tf.matmul(input, params["decoder1_weights"]) + params["decoder1_biases"]
    activated = tf.nn.relu(pre_activation)
    return tf.matmul(activated, params["decoder2_weights"]) + params["decoder2_biases"]
def decode(self, input):
    """Decode ``input`` through a ReLU hidden layer and a linear output layer.

    :param input: tensor right-multiplied by ``self.weights["decoder1_weights"]``.
    :return: ``relu(input @ W1 + b1) @ W2 + b2``, using the ``decoder1_*`` and
        ``decoder2_*`` entries of ``self.weights``.
    """
    params = self.weights
    pre_activation = tf.matmul(input, params["decoder1_weights"]) + params["decoder1_biases"]
    activated = tf.nn.relu(pre_activation)
    return tf.matmul(activated, params["decoder2_weights"]) + params["decoder2_biases"]
def decode(self, input):
    """Decode ``input`` through a ReLU hidden layer and a linear output layer.

    :param input: tensor right-multiplied by ``self.weights["decoder1_weights"]``.
    :return: ``relu(input @ W1 + b1) @ W2 + b2``, using the ``decoder1_*`` and
        ``decoder2_*`` entries of ``self.weights``.
    """
    params = self.weights
    pre_activation = tf.matmul(input, params["decoder1_weights"]) + params["decoder1_biases"]
    activated = tf.nn.relu(pre_activation)
    return tf.matmul(activated, params["decoder2_weights"]) + params["decoder2_biases"]