Python examples of torch.unsqueeze()
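torch.unsqueeze(input, dim) returns a view of input with a size-one dimension inserted at position dim; every snippet below uses it to line up tensor shapes for broadcasting, batching, or concatenation. A minimal shape check (toy tensor, recent PyTorch assumed):

import torch

x = torch.arange(6).view(2, 3)          # shape (2, 3)
print(torch.unsqueeze(x, 0).shape)      # torch.Size([1, 2, 3])
print(torch.unsqueeze(x, 1).shape)      # torch.Size([2, 1, 3])
print(x.unsqueeze(-1).shape)            # torch.Size([2, 3, 1]); negative dims count from the end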

MPNN.py (project: nmp_qc, author: priba)
def forward(self, g, h_in, e):

        h = []

        # Padding to some larger dimension d
        h_t = torch.cat([h_in, Variable(
            torch.zeros(h_in.size(0), h_in.size(1), self.args['out'] - h_in.size(2)).type_as(h_in.data))], 2)

        h.append(h_t.clone())

        # Layer
        for t in range(0, self.n_layers):
            e_aux = e.view(-1, e.size(3))

            h_aux = h[t].view(-1, h[t].size(2))

            m = self.m[0].forward(h[t], h_aux, e_aux)
            m = m.view(h[0].size(0), h[0].size(1), -1, m.size(1))

            # Nodes without edge set message to 0
            m = torch.unsqueeze(g, 3).expand_as(m) * m

            m = torch.squeeze(torch.sum(m, 1))

            h_t = self.u[0].forward(h[t], m)

            # Delete virtual nodes
            h_t = (torch.sum(h_in, 2).expand_as(h_t) > 0).type_as(h_t) * h_t
            h.append(h_t)

        # Readout
        res = self.r.forward(h)

        if self.type == 'classification':
            res = nn.LogSoftmax()(res)
        return res
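In the loop above, `torch.unsqueeze(g, 3).expand_as(m) * m` zeroes out messages along edges that do not exist by broadcasting the adjacency tensor over the message dimension. A minimal sketch of that masking step with toy shapes (the names and sizes are illustrative, not the ones nmp_qc actually uses):

import torch

batch, nodes, msg_dim = 2, 4, 8
g = (torch.rand(batch, nodes, nodes) > 0.5).float()   # adjacency: 1.0 where an edge exists
m = torch.randn(batch, nodes, nodes, msg_dim)          # one message per (receiver, sender) pair

masked = torch.unsqueeze(g, 3).expand_as(m) * m        # (B, N, N) -> (B, N, N, 1) -> (B, N, N, D)
aggregated = masked.sum(dim=1)                         # sum over neighbours -> (B, N, D)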
MessageFunction.py (project: nmp_qc, author: priba)
def m_mpnn(self, h_v, h_w, e_vw, opt={}):
        # Matrices for each edge
        edge_output = self.learn_modules[0](e_vw)
        edge_output = edge_output.view(-1, self.args['out'], self.args['in'])

        h_w_rows = h_w[..., None].expand(h_w.size(0), h_v.size(1), h_w.size(1)).contiguous()

        h_w_rows = h_w_rows.view(-1, self.args['in'])

        h_multiply = torch.bmm(edge_output, torch.unsqueeze(h_w_rows,2))

        m_new = torch.squeeze(h_multiply)

        return m_new
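This message function turns each edge feature into an (out x in) matrix and applies it to the sender's hidden state with a batched matrix product; unsqueeze(h_w_rows, 2) reshapes each hidden vector into a column so torch.bmm accepts it. A rough shape sketch (dimensions chosen only for illustration):

import torch

n_edges, d_in, d_out = 5, 3, 4
edge_matrices = torch.randn(n_edges, d_out, d_in)      # stands in for the learned per-edge matrices
h_w_rows = torch.randn(n_edges, d_in)                  # sender hidden states, one row per edge

m = torch.bmm(edge_matrices, torch.unsqueeze(h_w_rows, 2))   # (E, out, in) @ (E, in, 1) -> (E, out, 1)
m_new = torch.squeeze(m, 2)                                  # -> (E, out)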
torch_backend.py (project: ktorch, author: farizrahman4u)
def batch_dot(x, y, axes=None):
    if type(axes) is int:
        axes = (axes, axes)
    def _dot(X):
        x, y = X
        x_shape = x.size()
        y_shape = y.size()
        x_ndim = len(x_shape)
        y_ndim = len(y_shape)
        if x_ndim <= 3 and y_ndim <= 3:
            if x_ndim < 3:
                x_diff = 3 - x_ndim
                for i in range(x_diff):
                    x = torch.unsqueeze(x, x_ndim + i)
            else:
                x_diff = 0
            if y_ndim < 3:
                y_diff = 3 - y_ndim
                for i in range(y_diff):
                    y = torch.unsqueeze(y, y_ndim + i)
            else:
                y_diff = 0
            if axes[0] == 1:
                x = torch.transpose(x, 1, 2)
            elif axes[0] == 2:
                pass
            else:
                raise Exception('Invalid axis : ' + str(axes[0]))
            if axes[1] == 2:
                x = torch.transpose(x, 1, 2)
            # -------TODO--------------#
torch_backend.py (project: ktorch, author: farizrahman4u)
def expand_dims(x, axis=-1):
    def _expand_dims(x, axis=axis):
        return torch.unsqueeze(x, axis)

    def _compute_output_shape(x, axis=axis):
        shape = list(_get_shape(x))
        shape.insert(axis, 1)
        return shape

    return get_op(_expand_dims, output_shape=_compute_output_shape, arguments=[axis])(x)
matrix.py (project: paysage, author: drckf)
def scatter_(mat: T.Tensor, inds: T.LongTensor, val: T.Scalar) -> T.Tensor:
    """
    Assign a value at specific points in a matrix.
    Iterates along the rows of mat,
    successively assigning val to column indices given by inds.

    Note:
        Modifies mat in place.

    Args:
        mat: A tensor.
        inds: The column index to assign for each row.
        val: The value to insert.
    """
    return mat.scatter_(1, inds.unsqueeze(1), val)
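A quick sketch of what this helper does: inds holds one column index per row, inds.unsqueeze(1) makes it a 2-D index tensor, and scatter_ writes val into that column of each row (the values below are illustrative):

import torch

mat = torch.zeros(3, 4)
inds = torch.tensor([2, 0, 3])              # target column for each row
mat.scatter_(1, inds.unsqueeze(1), 1.0)     # rows 0, 1, 2 get 1.0 in columns 2, 0, 3
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 0., 0., 1.]])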
matrix.py (project: paysage, author: drckf)
def unsqueeze(tensor: T.Tensor, axis: int) -> T.Tensor:
    """
    Return tensor with a new axis inserted.

    Args:
        tensor: A tensor.
        axis: The desired axis.

    Returns:
        tensor: A tensor with the new axis inserted.

    """
    return torch.unsqueeze(tensor, axis)
matrix.py (project: paysage, author: drckf)
def broadcast(vec: T.FloatTensor, matrix: T.FloatTensor) -> T.FloatTensor:
    """
    Broadcasts vec into the shape of matrix following numpy rules:

    vec ~ (N, 1) broadcasts to matrix ~ (N, M)
    vec ~ (1, N) and (N,) broadcast to matrix ~ (M, N)

    Args:
        vec: A vector (either flat, row, or column).
        matrix: A matrix (i.e., a 2D tensor).

    Returns:
        tensor: A tensor of the same size as matrix containing the elements
                of the vector.

    Raises:
        BroadcastError

    """
    try:
        if ndim(vec) == 1:
            if ndim(matrix) == 1:
                return vec
            return vec.unsqueeze(0).expand(matrix.size(0), matrix.size(1))
        else:
            return vec.expand(matrix.size(0), matrix.size(1))
    except ValueError:
        raise BroadcastError(
            'cannot broadcast vector of dimension {} onto matrix of '
            'dimension {}'.format(shape(vec), shape(matrix)))
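The flat-vector branch is where unsqueeze does the work: a length-N vector becomes a (1, N) row and is then expanded across the rows of the matrix. The same steps with plain torch calls (toy values):

import torch

vec = torch.tensor([1., 2., 3.])            # shape (3,)
matrix = torch.zeros(4, 3)                  # shape (4, 3)

rows = vec.unsqueeze(0).expand(matrix.size(0), matrix.size(1))
# rows has shape (4, 3) and every row equals [1., 2., 3.]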
matrix.py (project: paysage, author: drckf)
def repeat(tensor: T.FloatTensor, n: int) -> T.FloatTensor:
    """
    Repeat tensor n times along specified axis.

    Args:
        tensor: A vector (i.e., 1D tensor).
        n: The number of repeats.

    Returns:
        tensor: A vector created from many repeats of the input tensor.

    """
    # current implementation only works for vectors
    assert ndim(tensor) == 1
    return flatten(tensor.unsqueeze(1).repeat(1, n))
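unsqueeze(1).repeat(1, n) repeats each element in place rather than tiling the whole vector, which is what the final flatten relies on. A small sketch using plain torch calls (torch.Tensor.flatten stands in for the module's own flatten helper):

import torch

t = torch.tensor([1, 2, 3])
out = t.unsqueeze(1).repeat(1, 2).flatten()   # (3,) -> (3, 1) -> (3, 2) -> (6,)
# tensor([1, 1, 2, 2, 3, 3]); torch.cat([t, t]) would instead give 1, 2, 3, 1, 2, 3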
image_pool.py (project: GAN_Liveness_Detection, author: yunfan0621)
def query(self, images):
        # images: torch.Variable of size [batch_size, channel * 2, w, h]

        if self.pool_size == 0:
            return images

        return_images = []
        for image in images.data: # traverse data in batch dimension
            image = torch.unsqueeze(image, 0)

            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                # randomly substitute
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size-1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)

        return_images = Variable(torch.cat(return_images, 0))

        return return_images
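Iterating over images.data strips the batch dimension, so torch.unsqueeze(image, 0) restores it before each image is stored in the pool or concatenated back into a batch. The round trip in isolation (shapes are illustrative):

import torch

batch = torch.randn(4, 6, 32, 32)                       # [batch_size, channel * 2, w, h]
singles = [torch.unsqueeze(img, 0) for img in batch]    # each becomes (1, 6, 32, 32)
rebuilt = torch.cat(singles, 0)                         # back to (4, 6, 32, 32)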
memory.py (project: LSH_Memory, author: RUSH-LAB)
def index(batch_size, x):
    idx = torch.arange(0, batch_size).long() 
    idx = torch.unsqueeze(idx, -1)
    return torch.cat((idx, x), dim=1)
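unsqueeze(idx, -1) turns the flat batch indices into a column so they can be concatenated with x along dim 1, giving one (batch_index, ...) row per example. For instance (x here is a made-up index column):

import torch

batch_size = 3
x = torch.tensor([[7], [5], [9]])                               # one index column per example
idx = torch.unsqueeze(torch.arange(0, batch_size).long(), -1)   # (3,) -> (3, 1)
print(torch.cat((idx, x), dim=1))
# tensor([[0, 7],
#         [1, 5],
#         [2, 9]])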
memory.py (project: LSH_Memory, author: RUSH-LAB)
def update(self, query, y, y_hat, y_hat_indices):
        batch_size, dims = query.size()

        # 1) Untouched: Increment memory by 1
        self.age += 1

        # Divide batch by correctness
        result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y.data, dim=1))).float()
        incorrect_examples = torch.squeeze(torch.nonzero(1-result))
        correct_examples = torch.squeeze(torch.nonzero(result))

        incorrect = len(incorrect_examples.size()) > 0
        correct = len(correct_examples.size()) > 0

        # 2) Correct: if V[n1] = v
        # Update Key k[n1] <- normalize(q + K[n1]), Reset Age A[n1] <- 0
        if correct:
            correct_indices = y_hat_indices[correct_examples]
            correct_keys = self.keys[correct_indices]
            correct_query = query.data[correct_examples]

            new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
            self.keys[correct_indices] = new_correct_keys
            self.age[correct_indices] = 0

        # 3) Incorrect: if V[n1] != v
        # Select item with oldest age, Add random offset - n' = argmax_i(A[i]) + r_i 
        # K[n'] <- q, V[n'] <- v, A[n'] <- 0
        if incorrect:
            incorrect_size = incorrect_examples.size()[0]
            incorrect_query = query.data[incorrect_examples]
            incorrect_values = y.data[incorrect_examples]

            age_with_noise = self.age + random_uniform((self.memory_size, 1), -self.age_noise, self.age_noise, cuda=True)
            topk_values, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
            oldest_indices = torch.squeeze(topk_indices)

            self.keys[oldest_indices] = incorrect_query
            self.values[oldest_indices] = incorrect_values
            self.age[oldest_indices] = 0
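The split at the top of update() works by broadcasting: y.unsqueeze(1) makes the true labels a column, torch.eq compares them against the retrieved neighbours, and squeeze collapses the result when only the top-1 neighbour is kept. With made-up labels:

import torch

y = torch.tensor([3, 1, 4])                  # true labels, shape (3,)
y_hat = torch.tensor([[3], [2], [4]])        # top-1 predictions, shape (3, 1)

result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y, dim=1))).float()   # tensor([1., 0., 1.])
incorrect_examples = torch.squeeze(torch.nonzero(1 - result))                # tensor(1)
correct_examples = torch.squeeze(torch.nonzero(result))                      # tensor([0, 2])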
Conv2Conv.py (project: OpenNMT-py, author: OpenNMT)
def shape_transform(x):
    """ Tranform the size of the tensors to fit for conv input. """
    return torch.unsqueeze(torch.transpose(x, 1, 2), 3)
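That is, a (batch, length, dim) sequence becomes the (batch, dim, length, 1) layout that 2-D convolutions expect. For example (sizes illustrative):

import torch

x = torch.randn(8, 20, 512)                          # (batch, seq_len, model_dim)
y = torch.unsqueeze(torch.transpose(x, 1, 2), 3)
print(y.shape)                                       # torch.Size([8, 512, 20, 1])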
model.py (project: relational-networks, author: kimhc6028)
def forward(self, img, qst):
        x = self.conv(img) ## x = (64 x 24 x 5 x 5)

        """g"""
        mb = x.size()[0]
        n_channels = x.size()[1]
        d = x.size()[2]
        # x_flat = (64 x 25 x 24)
        x_flat = x.view(mb,n_channels,d*d).permute(0,2,1)

        # add coordinates
        x_flat = torch.cat([x_flat, self.coord_tensor],2)

        # add question everywhere
        qst = torch.unsqueeze(qst, 1)
        qst = qst.repeat(1,25,1)
        qst = torch.unsqueeze(qst, 2)

        # cast all pairs against each other
        x_i = torch.unsqueeze(x_flat,1) # (64x1x25x26+11)
        x_i = x_i.repeat(1,25,1,1) # (64x25x25x26+11)
        x_j = torch.unsqueeze(x_flat,2) # (64x25x1x26+11)
        x_j = torch.cat([x_j,qst],3)
        x_j = x_j.repeat(1,1,25,1) # (64x25x25x26+11)

        # concatenate all together
        x_full = torch.cat([x_i,x_j],3) # (64x25x25x2*26+11)

        # reshape for passing through network
        x_ = x_full.view(mb*d*d*d*d,63)
        x_ = self.g_fc1(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc2(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc3(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc4(x_)
        x_ = F.relu(x_)

        # reshape again and sum
        x_g = x_.view(mb,d*d*d*d,256)
        x_g = x_g.sum(1).squeeze()

        """f"""
        x_f = self.f_fc1(x_g)
        x_f = F.relu(x_f)

        return self.fcout(x_f)
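The pair construction in the middle is the usual unsqueeze + repeat trick: unsqueeze(x_flat, 1) tiles the objects along one axis, unsqueeze(x_flat, 2) along the other, and concatenating the two yields every ordered object pair. Stripped down to the shapes (object and feature counts are arbitrary here):

import torch

mb, n_obj, feat = 2, 4, 5
x_flat = torch.randn(mb, n_obj, feat)

x_i = torch.unsqueeze(x_flat, 1).repeat(1, n_obj, 1, 1)    # (mb, n_obj, n_obj, feat)
x_j = torch.unsqueeze(x_flat, 2).repeat(1, 1, n_obj, 1)    # (mb, n_obj, n_obj, feat)
pairs = torch.cat([x_i, x_j], 3)                           # (mb, n_obj, n_obj, 2 * feat)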
OneShotMiniImageNetBuilder.py (project: MatchingNetworks, author: gitabcworld)
def run_validation_epoch(self):
        """
        Runs one validation epoch
        :param total_val_batches: Number of batches to train on
        :return: mean_validation_categorical_crossentropy_loss and mean_validation_accuracy
        """
        total_val_c_loss = 0.
        total_val_accuracy = 0.
        total_val_batches = len(self.val_loader)
        pbar = tqdm(enumerate(self.val_loader))
        for batch_idx, (x_support_set, y_support_set, x_target, target_y) in pbar:

                x_support_set = Variable(x_support_set).float()
                y_support_set = Variable(y_support_set,requires_grad=False).long()
                x_target = Variable(x_target.squeeze()).float()
                y_target = Variable(target_y.squeeze(),requires_grad=False).long()

                # y_support_set: Add extra dimension for the one_hot
                y_support_set = torch.unsqueeze(y_support_set, 2)
                sequence_length = y_support_set.size()[1]
                batch_size = y_support_set.size()[0]
                y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                          self.classes_per_set).zero_()
                y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
                y_support_set_one_hot = Variable(y_support_set_one_hot)

                if self.isCudaAvailable:
                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                         x_target.cuda(), y_target.cuda())
                else:
                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                         x_target, y_target)

                iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                pbar.set_description(iter_out)
                pbar.update(1)

                total_val_c_loss += c_loss_value.data[0]
                total_val_accuracy += acc.data[0]

        total_val_c_loss = total_val_c_loss / total_val_batches
        total_val_accuracy = total_val_accuracy / total_val_batches

        return total_val_c_loss, total_val_accuracy
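The one-hot block in the middle (and repeated in the epochs below) is the standard unsqueeze + scatter_ idiom: the label tensor gains a trailing dimension so scatter_ can write a 1 at each label position. In isolation, with a made-up class count:

import torch

classes_per_set = 5
y_support_set = torch.tensor([[0, 3], [4, 1]])                # (batch_size, sequence_length)

y = torch.unsqueeze(y_support_set, 2)                         # (batch_size, sequence_length, 1)
one_hot = torch.zeros(y.size(0), y.size(1), classes_per_set)
one_hot.scatter_(2, y, 1)
# one_hot[0, 1] == tensor([0., 0., 0., 1., 0.])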
OneShotMiniImageNetBuilder.py (project: MatchingNetworks, author: gitabcworld)
def run_testing_epoch(self):
        """
        Runs one testing epoch
        :param total_test_batches: Number of batches to train on
        :param sess: Session object
        :return: mean_testing_categorical_crossentropy_loss and mean_testing_accuracy
        """
        total_test_c_loss = 0.
        total_test_accuracy = 0.
        total_test_batches = len(self.test_loader)
        pbar = tqdm(enumerate(self.test_loader))
        for batch_idx, (x_support_set, y_support_set, x_target, target_y) in pbar:

                x_support_set = Variable(x_support_set).float()
                y_support_set = Variable(y_support_set,requires_grad=False).long()
                x_target = Variable(x_target.squeeze()).float()
                y_target = Variable(target_y.squeeze(),requires_grad=False).long()

                # y_support_set: Add extra dimension for the one_hot
                y_support_set = torch.unsqueeze(y_support_set, 2)
                sequence_length = y_support_set.size()[1]
                batch_size = y_support_set.size()[0]
                y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                          self.classes_per_set).zero_()
                y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
                y_support_set_one_hot = Variable(y_support_set_one_hot)

                if self.isCudaAvailable:
                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                         x_target.cuda(), y_target.cuda())
                else:
                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                         x_target, y_target)

                iter_out = "test_loss: {}, test_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                pbar.set_description(iter_out)
                pbar.update(1)

                total_test_c_loss += c_loss_value.data[0]
                total_test_accuracy += acc.data[0]

        total_test_c_loss = total_test_c_loss / total_test_batches
        total_test_accuracy = total_test_accuracy / total_test_batches
        return total_test_c_loss, total_test_accuracy
OneShotBuilder.py (project: MatchingNetworks, author: gitabcworld)
def run_validation_epoch(self, total_val_batches):
        """
        Runs one validation epoch
        :param total_val_batches: Number of batches to train on
        :return: mean_validation_categorical_crossentropy_loss and mean_validation_accuracy
        """
        total_val_c_loss = 0.
        total_val_accuracy = 0.

        with tqdm.tqdm(total=total_val_batches) as pbar:
            for i in range(total_val_batches):  # validation epoch
                x_support_set, y_support_set, x_target, y_target = \
                    self.data.get_batch(str_type='val', rotate_flag=False)

                x_support_set = Variable(torch.from_numpy(x_support_set), volatile=True).float()
                y_support_set = Variable(torch.from_numpy(y_support_set), volatile=True).long()
                x_target = Variable(torch.from_numpy(x_target), volatile=True).float()
                y_target = Variable(torch.from_numpy(y_target), volatile=True).long()

                # y_support_set: Add extra dimension for the one_hot
                y_support_set = torch.unsqueeze(y_support_set, 2)
                sequence_length = y_support_set.size()[1]
                batch_size = y_support_set.size()[0]
                y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                          self.classes_per_set).zero_()
                y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
                y_support_set_one_hot = Variable(y_support_set_one_hot)

                # Reshape channels
                size = x_support_set.size()
                x_support_set = x_support_set.view(size[0], size[1], size[4], size[2], size[3])
                size = x_target.size()
                x_target = x_target.view(size[0],size[1],size[4],size[2],size[3])
                if self.isCudaAvailable:
                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                         x_target.cuda(), y_target.cuda())
                else:
                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                         x_target, y_target)

                iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                pbar.set_description(iter_out)
                pbar.update(1)

                total_val_c_loss += c_loss_value.data[0]
                total_val_accuracy += acc.data[0]

        total_val_c_loss = total_val_c_loss / total_val_batches
        total_val_accuracy = total_val_accuracy / total_val_batches

        return total_val_c_loss, total_val_accuracy
OneShotBuilder.py (project: MatchingNetworks, author: gitabcworld)
def run_testing_epoch(self, total_test_batches):
        """
        Runs one testing epoch
        :param total_test_batches: Number of batches to train on
        :param sess: Session object
        :return: mean_testing_categorical_crossentropy_loss and mean_testing_accuracy
        """
        total_test_c_loss = 0.
        total_test_accuracy = 0.
        with tqdm.tqdm(total=total_test_batches) as pbar:
            for i in range(total_test_batches):
                x_support_set, y_support_set, x_target, y_target = \
                    self.data.get_batch(str_type='test', rotate_flag=False)

                x_support_set = Variable(torch.from_numpy(x_support_set), volatile=True).float()
                y_support_set = Variable(torch.from_numpy(y_support_set), volatile=True).long()
                x_target = Variable(torch.from_numpy(x_target), volatile=True).float()
                y_target = Variable(torch.from_numpy(y_target), volatile=True).long()

                # y_support_set: Add extra dimension for the one_hot
                y_support_set = torch.unsqueeze(y_support_set, 2)
                sequence_length = y_support_set.size()[1]
                batch_size = y_support_set.size()[0]
                y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                          self.classes_per_set).zero_()
                y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
                y_support_set_one_hot = Variable(y_support_set_one_hot)

                # Reshape channels
                size = x_support_set.size()
                x_support_set = x_support_set.view(size[0], size[1], size[4], size[2], size[3])
                size = x_target.size()
                x_target = x_target.view(size[0],size[1],size[4],size[2],size[3])
                if self.isCudaAvailable:
                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                         x_target.cuda(), y_target.cuda())
                else:
                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                         x_target, y_target)

                iter_out = "test_loss: {}, test_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                pbar.set_description(iter_out)
                pbar.update(1)

                total_test_c_loss += c_loss_value.data[0]
                total_test_accuracy += acc.data[0]
            total_test_c_loss = total_test_c_loss / total_test_batches
            total_test_accuracy = total_test_accuracy / total_test_batches
        return total_test_c_loss, total_test_accuracy
eew_rnn_cuda.py (project: EarlyWarning, author: wjlei1990)
def main():
    outputdir = "output.disp.abs"
    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    waveforms, magnitudes = load_data()
    data_split = split_data(waveforms, magnitudes, train_percentage=0.9)
    print("dimension of train x and y: ", data_split["train_x"].shape,
          data_split["train_y"].shape)
    print("dimension of test x and y: ", data_split["test_x"].shape,
          data_split["test_y"].shape)
    train_loader = make_dataloader(data_split["train_x"],
                                   data_split["train_y"])

    rnn = RNN(input_size, hidden_size, num_layers)
    rnn.cuda()
    print(rnn)

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    loss_func = nn.MSELoss()

    # train
    ntest = data_split["train_x"].shape[0]
    all_loss = {}
    for epoch in range(3):
        loss_epoch = []
        for step, (batch_x, batch_y) in enumerate(train_loader):
            x = torch.unsqueeze(batch_x[0, :, :].t(), dim=1)
            if step % int((ntest/100) + 1) == 1:
                print('Epoch: ', epoch, '| Step: %d/%d' % (step, ntest),
                      "| Loss: %f" % np.mean(loss_epoch))
            if CUDA_FLAG:
                x = Variable(x).cuda()
                y = Variable(torch.Tensor([batch_y.numpy(), ])).cuda()
            else:
                x = Variable(x)
                y = Variable(torch.Tensor([batch_y.numpy(), ]))
            prediction = rnn(x)
            loss = loss_func(prediction, y)
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()
            loss_epoch.append(loss.data[0])
        all_loss["epoch_%d" % epoch] = loss_epoch

        outputfn = os.path.join(outputdir, "loss.epoch_%d.json" % epoch)
        print("epoch loss file: %s" % outputfn)
        dump_json(loss_epoch, outputfn)

    # test
    pred_y = predict_on_test(rnn, data_split["test_x"])
    test_y = data_split["test_y"]
    _mse = mean_squared_error(test_y, pred_y)
    _std = np.std(test_y - pred_y)
    print("MSE and error std: %f, %f" % (_mse, _std))

    outputfn = os.path.join(outputdir, "prediction.json")
    print("output file: %s" % outputfn)
    data = {"test_y": list(test_y), "test_y_pred": list(pred_y),
            "epoch_loss": all_loss, "mse": _mse, "err_std": _std}
    dump_json(data, outputfn)
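The unsqueeze(..., dim=1) in the training loop reshapes a single (channels, time) waveform into the (seq_len, batch=1, input_size) layout that an nn.RNN with batch_first=False expects. The reshape on its own (sizes illustrative):

import torch

batch_x = torch.randn(1, 3, 1000)                    # (loader batch of 1, channels, time steps)
x = torch.unsqueeze(batch_x[0, :, :].t(), dim=1)     # (1000, 3) -> (1000, 1, 3)
print(x.shape)                                       # torch.Size([1000, 1, 3])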
urnn.py (project: URNN-PyTorch, author: jingli9111)
def _EUNN(self, hx, thetaA, thetaB):

        L = self.capacity
        N = self.hidden_size

        sinA = torch.sin(self.thetaA)
        cosA = torch.cos(self.thetaA)
        sinB = torch.sin(self.thetaB)
        cosB = torch.cos(self.thetaB)

        I = Variable(torch.ones((L//2, 1)))
        O = Variable(torch.zeros((L//2, 1)))

        diagA = torch.stack((cosA, cosA), 2)
        offA = torch.stack((-sinA, sinA), 2)
        diagB = torch.stack((cosB, cosB), 2)
        offB = torch.stack((-sinB, sinB), 2)

        diagA = diagA.view(L//2, N)
        offA = offA.view(L//2, N)
        diagB = diagB.view(L//2, N-2)
        offB = offB.view(L//2, N-2)

        diagB = torch.cat((I, diagB, I), 1)
        offB = torch.cat((O, offB, O), 1)

        batch_size = hx.size()[0]
        x = hx
        for i in range(L//2):
#           # A
            y = x.view(batch_size, N//2, 2)
            y = torch.stack((y[:,:,1], y[:,:,0]), 2)
            y = y.view(batch_size, N)

            x = torch.mul(x, diagA[i].expand_as(x))
            y = torch.mul(y, offA[i].expand_as(x))

            x = x + y

            # B
            x_top = x[:,0]
            x_mid = x[:,1:-1].contiguous()
            x_bot = x[:,-1]
            y = x_mid.view(batch_size, N//2-1, 2)
            y = torch.stack((y[:, :, 1], y[:, :, 0]), 1)
            y = y.view(batch_size, N-2)
            x_top = torch.unsqueeze(x_top, 1)
            x_bot = torch.unsqueeze(x_bot, 1)
            # print x_top.size(), y.size(), x_bot.size()
            y = torch.cat((x_top, y, x_bot), 1)

            x = x * diagB[i].expand(batch_size, N)
            y = y * offB[i].expand(batch_size, N)

            x = x + y
        return x
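In the B step, the first and last entries are sliced off as 1-D tensors, so torch.unsqueeze restores them to column shape before they are concatenated around the permuted interior. That boundary handling in isolation:

import torch

batch_size, N = 2, 6
x = torch.randn(batch_size, N)

x_top = torch.unsqueeze(x[:, 0], 1)       # (batch_size, 1)
x_mid = x[:, 1:-1]                        # (batch_size, N - 2)
x_bot = torch.unsqueeze(x[:, -1], 1)      # (batch_size, 1)
y = torch.cat((x_top, x_mid, x_bot), 1)   # back to (batch_size, N)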
goru.py (project: URNN-PyTorch, author: jingli9111)
def _EUNN(self, hx, thetaA, thetaB):

        L = self.capacity
        N = self.hidden_size

        sinA = torch.sin(self.thetaA)
        cosA = torch.cos(self.thetaA)
        sinB = torch.sin(self.thetaB)
        cosB = torch.cos(self.thetaB)

        I = Variable(torch.ones((L//2, 1)))
        O = Variable(torch.zeros((L//2, 1)))

        diagA = torch.stack((cosA, cosA), 2)
        offA = torch.stack((-sinA, sinA), 2)
        diagB = torch.stack((cosB, cosB), 2)
        offB = torch.stack((-sinB, sinB), 2)

        diagA = diagA.view(L//2, N)
        offA = offA.view(L//2, N)
        diagB = diagB.view(L//2, N-2)
        offB = offB.view(L//2, N-2)

        diagB = torch.cat((I, diagB, I), 1)
        offB = torch.cat((O, offB, O), 1)

        batch_size = hx.size()[0]
        x = hx
        for i in range(L//2):
#           # A
            y = x.view(batch_size, N//2, 2)
            y = torch.stack((y[:,:,1], y[:,:,0]), 2)
            y = y.view(batch_size, N)

            x = torch.mul(x, diagA[i].expand_as(x))
            y = torch.mul(y, offA[i].expand_as(x))

            x = x + y

            # B
            x_top = x[:,0]
            x_mid = x[:,1:-1].contiguous()
            x_bot = x[:,-1]
            y = x_mid.view(batch_size, N//2-1, 2)
            y = torch.stack((y[:, :, 1], y[:, :, 0]), 1)
            y = y.view(batch_size, N-2)
            x_top = torch.unsqueeze(x_top, 1)
            x_bot = torch.unsqueeze(x_bot, 1)
            # print x_top.size(), y.size(), x_bot.size()
            y = torch.cat((x_top, y, x_bot), 1)

            x = x * diagB[i].expand(batch_size, N)
            y = y * offB[i].expand(batch_size, N)

            x = x + y
        return x

