def DiagonalLSTM(name, input_dim, inputs):
    """
    Diagonal LSTM layer (PixelRNN-style): runs an LSTM along the diagonals
    of a 2D feature map by skewing the input, scanning column-by-column,
    and unskewing the result.

    inputs.shape: (batch size, height, width, input_dim)
    outputs.shape: (batch size, height, width, DIM)

    `name` namespaces the learned parameters (conv weights, initial states).
    NOTE(review): relies on module-level DIM / HEIGHT constants and the
    Skew/Unskew/Conv2D/Conv1D/lib.param helpers defined elsewhere in this file.
    """
    # Skew so that each "column" of the skewed map is one diagonal of the
    # original map; the scan below then processes one diagonal per step.
    inputs = Skew(inputs)

    # Precompute the input contribution to all four gates for every position.
    # mask_type='b' keeps the autoregressive masking of later layers.
    input_to_state = Conv2D(name+'.InputToState', input_dim, 4*DIM, 1, inputs, mask_type='b')

    batch_size = inputs.shape[0]

    # Learned initial cell/hidden states, shared across the batch via T.alloc
    # (broadcast copies of the (HEIGHT, DIM) parameter).
    c0_unbatched = lib.param(
        name + '.c0',
        numpy.zeros((HEIGHT, DIM), dtype=theano.config.floatX)
    )
    c0 = T.alloc(c0_unbatched, batch_size, HEIGHT, DIM)

    h0_unbatched = lib.param(
        name + '.h0',
        numpy.zeros((HEIGHT, DIM), dtype=theano.config.floatX)
    )
    h0 = T.alloc(h0_unbatched, batch_size, HEIGHT, DIM)

    def step_fn(current_input_to_state, prev_c, prev_h):
        # all args have shape (batch size, height, DIM)

        # Zero-pad prev_h at the top so the width-2 Conv1D below lets each
        # row see its own and its upper neighbor's previous hidden state.
        # TODO consider learning this padding
        prev_h = T.concatenate([
            T.zeros((batch_size, 1, DIM), theano.config.floatX),
            prev_h
        ], axis=1)

        # State contribution to the gates; biases already come from the
        # input-to-state conv, so they are disabled here.
        state_to_state = Conv1D(name+'.StateToState', DIM, 4*DIM, 2, prev_h, apply_biases=False)

        gates = current_input_to_state + state_to_state

        # Gate layout along the channel axis: [output | forget | input | candidate]
        o_f_i = T.nnet.sigmoid(gates[:,:,:3*DIM])
        o = o_f_i[:,:,0*DIM:1*DIM]
        f = o_f_i[:,:,1*DIM:2*DIM]
        i = o_f_i[:,:,2*DIM:3*DIM]
        g = T.tanh(gates[:,:,3*DIM:4*DIM])

        new_c = (f * prev_c) + (i * g)
        new_h = o * T.tanh(new_c)

        return (new_c, new_h)

    # Scan over the width axis of the skewed map: dimshuffle moves width to
    # the front so each scan step consumes one (batch, height, 4*DIM) slice.
    outputs, _ = theano.scan(
        step_fn,
        sequences=input_to_state.dimshuffle(2,0,1,3),
        outputs_info=[c0, h0]
    )

    # outputs[0] holds the cell states, which are not needed downstream;
    # only the hidden states are unskewed and returned.
    all_hs = outputs[1].dimshuffle(1,2,0,3)

    return Unskew(all_hs)
# (removed web-page scrape residue: "评论列表" / "文章目录" — comment list / article table of contents)