def generator(encoder_inputs, decoder_inputs, target_weights, bucket_id, seq_len):
    """Build the free-running seq2seq generator and return per-step softmax
    distributions over the vocabulary, padded to a fixed length.

    Args:
        encoder_inputs: list of int32 token tensors fed to the encoder.
        decoder_inputs: list of int32 token tensors fed to the decoder.
        target_weights: per-step loss weights for `model_with_buckets`.
        bucket_id: scalar int tensor selecting which bucket's output to use.
        seq_len: unused; kept so the caller-facing signature is unchanged.

    Returns:
        Tensor of shape [max_len, batch_size, num_symbols] holding softmax
        probabilities for every decoding step, zero-logit-padded past the
        bucket's decoder length up to `max_len`.

    Relies on module-level globals: tf, embedding_size, num_layers,
    num_symbols, buckets, batch_size, max_len.
    """

    def seq2seq_f(encoder, decoder):
        # One FRESH cell per layer. The original `[cell] * num_layers`
        # reused a single cell object, which shares weights across layers
        # and raises "trying to reuse RNNCell" errors on TF >= 1.1.
        def single_cell():
            return tf.contrib.rnn.BasicLSTMCell(embedding_size)

        cell = single_cell()
        if num_layers > 1:
            cell = tf.contrib.rnn.MultiRNNCell(
                [single_cell() for _ in range(num_layers)])

        # Output projection: embedding_size -> vocabulary logits.
        w = tf.get_variable("proj_w", [embedding_size, num_symbols])
        b = tf.get_variable("proj_b", [num_symbols])
        output_projection = (w, b)

        # feed_previous=True: decoder consumes its own previous prediction
        # (free-running generation, i.e. GAN-generator / inference mode).
        outputs, state = tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
            encoder, decoder, cell, num_symbols, num_symbols, embedding_size,
            output_projection=output_projection,
            feed_previous=True)

        # Project each RNN output to full-vocabulary logits.
        trans_output = [tf.matmul(output, w) + b for output in outputs]
        return trans_output, state

    # NOTE(review): targets are normally decoder_inputs shifted left by one
    # time step; passing them verbatim reproduces the original code —
    # confirm against the training setup before changing.
    targets = decoder_inputs

    outputs, losses = tf.contrib.legacy_seq2seq.model_with_buckets(
        encoder_inputs, decoder_inputs, targets,
        target_weights, buckets, seq2seq_f,
        softmax_loss_function=None,
        per_example_loss=False, name='model_with_buckets')

    # Zero-logit step used to pad every bucket's output list out to max_len.
    patch = tf.convert_to_tensor([[0.0] * num_symbols] * batch_size)

    def make_pad_fn(idx):
        # Factory binds `idx` now, avoiding the late-binding-closure pitfall
        # that a plain loop over lambdas would hit. Replaces the three
        # copy-pasted f0/f1/f2 bodies and works for any number of buckets.
        def pad_fn():
            for _ in range(max_len - buckets[idx][1]):
                outputs[idx].append(patch)
            return tf.convert_to_tensor(outputs[idx], dtype=tf.float32)
        return pad_fn

    # Select the padded output for the requested bucket; the last bucket is
    # the default branch (same behavior as the original f0/f1/default=f2).
    r = tf.case({tf.equal(bucket_id, i): make_pad_fn(i)
                 for i in range(len(buckets) - 1)},
                default=make_pad_fn(len(buckets) - 1), exclusive=True)

    # [max_len, batch_size, num_symbols]: per-step vocabulary distributions.
    return tf.nn.softmax(tf.reshape(r, [max_len, batch_size, num_symbols]))
# NOTE: removed stray blog-page footer text ("评论列表" / "文章目录" — web-scrape artifact, not code)