annealing.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:crayimage 作者: yandexdataschool 项目源码 文件源码
def sa(inputs, loss, params, outputs = (), srng=None, seed=1122334455, iters=32,
       initial_temperature = 1.0e-1, learning_rate=1.0e-2):
  """Compile the Theano functions for a simulated-annealing search and run it.

  Four functions are built over ``params`` and handed to
  ``simulated_annealing``:

  * ``memorize_inputs(*inputs)`` -- copies the given input values into
    shared variables, so subsequent probes reuse that same data;
  * ``generate_deltas(alpha)`` -- redraws one perturbation (delta) per
    parameter via ``make_uniform(delta, -alpha, alpha, srng)``;
  * ``probe()`` -- evaluates ``[loss] + list(outputs)`` with every parameter
    replaced by ``param + delta`` and the inputs read from the cache;
  * ``set_params()`` -- commits the perturbation: ``param <- param + delta``.

  :param inputs: symbolic input variables that ``loss`` / ``outputs`` depend on.
  :param loss: symbolic loss expression; first value returned by ``probe``.
  :param params: parameters to anneal. They are update targets of
    ``set_params``, so they must be shared variables.
  :param outputs: extra symbolic expressions returned by ``probe`` after the loss.
  :param srng: random stream used to draw the deltas. When ``None``, an
    ``MRG_RandomStreams(seed=seed)`` is created.
  :param seed: seed for the default random stream; unused when ``srng`` is given.
  :param iters: forwarded unchanged to ``simulated_annealing``.
  :param initial_temperature: forwarded unchanged to ``simulated_annealing``.
  :param learning_rate: forwarded unchanged to ``simulated_annealing``.
  :returns: whatever ``simulated_annealing`` returns.
  """
  if srng is None:
    from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
    srng = RandomStreams(seed=seed)

  # Shared copies of the inputs: memorize_inputs writes the current values into
  # them and probe reads them through givens, so every probe between two
  # memorize_inputs calls sees the same data.
  inputs_cached = [to_shared(i) for i in inputs]
  input_setter = OrderedDict(zip(inputs_cached, inputs))

  memorize_inputs = theano.function(inputs, [], updates=input_setter, no_default_updates=True)

  inputs_givens = list(zip(inputs, inputs_cached))

  # One perturbation buffer per parameter.
  deltas = [make_copy(param) for param in params]

  alpha = T.fscalar('learning rate')

  delta_setter = OrderedDict(
    (delta, make_uniform(delta, -alpha, alpha, srng))
    for delta in deltas
  )

  # Default updates are deliberately left enabled (unlike the other functions):
  # they advance the random stream's state, so each call draws fresh deltas.
  generate_deltas = theano.function([alpha], [], updates=delta_setter, no_default_updates=False)

  # The (param, param + delta) pairs are used twice: as givens, so probe
  # evaluates the perturbed point without modifying the parameters, and as
  # updates, so set_params moves the parameters to that point.
  probe_givens = [
    (param, param + delta)
    for param, delta in zip(params, deltas)
  ]

  probe = theano.function(
    [], [loss] + list(outputs),
    givens=probe_givens + inputs_givens,
    no_default_updates=True
  )

  params_setter = OrderedDict(probe_givens)

  set_params = theano.function(
    [], [],
    updates=params_setter,
    no_default_updates=True
  )

  return simulated_annealing(
    probe, memorize_inputs, set_params, generate_deltas,
    iters, initial_temperature, learning_rate
  )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号