def dot(inp, matrix, bias=None):
    """
    Decide the right type of dot product depending on the input arguments.

    Dispatch rules, checked in order:

    * Integer ``inp``: treated as row indices into ``matrix`` (an index
      lookup, no multiplication). A 2-D ``inp`` is flattened first, so the
      result has one row per index rather than ``inp``'s own shape.
    * Float ``inp`` with 3 dimensions ``(d0, d1, d2)``: it is reshaped to
      ``(d0 * d1, d2)``, multiplied by ``matrix`` and offset by ``bias`` if
      given. The result is then reshaped to ``(d0, d1, matrix.shape[1])``.
    * Anything else: ``T.dot(inp, matrix)``, plus ``bias`` if given.

    NOTE(review): ``T`` is a module-level name not visible here; presumably
    ``theano.tensor`` -- confirm against the file's imports.

    Parameters
    ----------
    inp : tensor/array with ``dtype`` and ``ndim`` attributes
        Indices (integer dtype) or features (float dtype).
    matrix : tensor/array
        Lookup table for integer input, or right-hand operand of the dot
        product for float input.
    bias : tensor/array or None
        Added to the product (float paths only); ignored for integer input.

    Returns
    -------
    The looked-up rows, or the (optionally biased) dot product.
    """
    # str() makes the check work whether ``dtype`` is a plain string or a
    # NumPy dtype object (the latter does not support ``'int' in dtype``).
    dtype = str(inp.dtype)

    if 'int' in dtype:
        if inp.ndim == 2:
            return matrix[inp.flatten()]
        return matrix[inp]

    if 'float' in dtype and inp.ndim == 3:
        shape0 = inp.shape[0]
        shape1 = inp.shape[1]
        shape2 = inp.shape[2]
        # Collapse the two leading axes so a single 2-D dot product handles
        # every (d0, d1) position at once.
        out = T.dot(inp.reshape((shape0 * shape1, shape2)), matrix)
        # ``is not None`` rather than truthiness: array-valued biases have an
        # ambiguous truth value.
        if bias is not None:
            out = out + bias
        return out.reshape((shape0, shape1, matrix.shape[1]))

    out = T.dot(inp, matrix)
    if bias is not None:
        out = out + bias
    return out
# Numerically stable log(sum(exp(A))); can also be used to implement softmax.
评论列表
文章目录