utils.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def make_initializer(
        linear={'type': 'uniform', 'args': {'a': -0.05, 'b': 0.05}},
        linear_bias={'type': 'constant', 'args': {'val': 0.}},
        rnn={'type': 'xavier_uniform', 'args': {'gain': 1.}},
        rnn_bias={'type': 'constant', 'args': {'val': 0.}},
        cnn_bias={'type': 'constant', 'args': {'val': 0.}},
        emb={'type': 'normal', 'args': {'mean': 0, 'std': 1}},
        default={'type': 'uniform', 'args': {'a': -0.05, 'b': 0.05}}):
    """Build a function that initializes a module's parameters by module type.

    Every argument is a spec dict with two keys: ``'type'``, the name of a
    function looked up on ``init``, and ``'args'``, the keyword arguments
    passed to that function together with the parameter.

    Args:
        linear: spec for non-bias parameters of ``torch.nn.Linear``.
        linear_bias: spec for bias parameters of ``torch.nn.Linear``.
        rnn: spec for non-bias parameters of the recurrent module types.
        rnn_bias: spec for bias parameters of the recurrent module types.
        cnn_bias: spec for bias parameters of ``Conv1d``/``Conv2d``.
            Convolution weights always go through ``init.xavier_normal``.
        emb: spec for every parameter of ``torch.nn.Embedding``.
        default: accepted but not used by this function.

    Returns:
        A function ``initializer(m)`` that initializes the parameters of
        ``m`` in place when ``m`` is one of the handled module types and
        does nothing otherwise. Parameters carrying a ``custom`` attribute
        are always skipped.
        NOTE(review): presumably meant to be passed to ``Module.apply`` --
        confirm against callers.
    """
    recurrent_types = (torch.nn.LSTM, torch.nn.GRU,
                       torch.nn.LSTMCell, torch.nn.GRUCell,
                       StackedGRU, StackedLSTM, NormalizedGRU,
                       NormalizedGRUCell, StackedNormalizedGRU)
    conv_types = (torch.nn.Conv1d, torch.nn.Conv2d)

    # Convolution weights are not configurable: they always use
    # xavier_normal with no extra arguments.
    # Alternative considered by the original author (Karpathy,
    # http://cs231n.github.io/neural-networks-2/#init): scale the weight
    # vector by the square root of its fan-in.
    conv_weight = {'type': 'xavier_normal', 'args': {}}

    # (module types, weight spec, bias spec, treat biases separately?)
    # Order matters: the first matching rule wins.
    rules = (
        (recurrent_types, rnn, rnn_bias, True),
        (torch.nn.Linear, linear, linear_bias, True),
        (torch.nn.Embedding, emb, None, False),
        (conv_types, conv_weight, cnn_bias, True),
    )

    def run_spec(spec, param):
        # The init function is resolved lazily, only when a parameter
        # actually needs it.
        fn = getattr(init, spec['type'])
        fn(param, **spec['args'])

    def init_module(module, weight_spec, bias_spec, split_bias):
        for param_name, param in module.named_parameters():
            if hasattr(param, 'custom'):
                # Parameter was flagged as custom: leave it untouched.
                continue
            if split_bias and is_bias(param_name):
                run_spec(bias_spec, param)
            else:
                run_spec(weight_spec, param)

    def initializer(m):
        for module_types, weight_spec, bias_spec, split_bias in rules:
            if isinstance(m, module_types):
                init_module(m, weight_spec, bias_spec, split_bias)
                return

    return initializer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号