paragraph_vector.py source code

Language: python

Project: tensorflow-playground    Author: wangz10
import collections
from itertools import compress

import numpy as np

# the generator relies on module-level globals built from the corpus:
# word_ids (token id of every word), doc_ids (document id of every word)
# and data_index (current position in the corpus)


def generate_batch_pvdm(batch_size, window_size):
    '''
    Batch generator for PV-DM (Distributed Memory Model of Paragraph Vectors).
    batch has shape (batch_size, window_size + 1): each row holds the
    window_size context words followed by the id of the document they come from.

    Parameters
    ----------
    batch_size: number of examples (rows) in each mini-batch
    window_size: number of leading context words before the target word
    '''
    global data_index
    assert batch_size % window_size == 0
    batch = np.ndarray(shape=(batch_size, window_size + 1), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    span = window_size + 1
    buffer = collections.deque(maxlen=span) # used for collecting word_ids[data_index] in the sliding window
    buffer_doc = collections.deque(maxlen=span) # collecting id of documents in the sliding window
    # collect the first window of words
    for _ in range(span):
        buffer.append(word_ids[data_index])
        buffer_doc.append(doc_ids[data_index])
        data_index = (data_index + 1) % len(word_ids)

    # mask selects the first window_size entries of the buffer (the context
    # words); the last entry is held out as the prediction target
    mask = [1] * span
    mask[-1] = 0
    i = 0
    while i < batch_size:
        if len(set(buffer_doc)) == 1:
            doc_id = buffer_doc[-1]
            # all leading (context) words followed by the doc_id form one input row
            batch[i, :] = list(compress(buffer, mask)) + [doc_id]
            labels[i, 0] = buffer[-1]  # the last word in the window is the prediction target
            i += 1
        # slide the window forward by one word
        buffer.append(word_ids[data_index])
        buffer_doc.append(doc_ids[data_index])
        data_index = (data_index + 1) % len(word_ids)

    return batch, labels

## examining the batch generator function
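
A minimal sketch of how the generator could be examined on a toy corpus. The word_ids, doc_ids and data_index values below are made-up example data, not part of the original project, and assume the function lives in the same module as these globals.

# toy corpus: two documents of six tokens each (assumed example data)
word_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]   # token id of every word in the corpus
doc_ids = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]      # document id of every word
data_index = 0

batch, labels = generate_batch_pvdm(batch_size=9, window_size=3)
for row, target in zip(batch, labels[:, 0]):
    # first window_size columns are context word ids, the last column is the doc id
    print('context:', row[:-1], 'doc id:', row[-1], '-> target word:', target)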