# The generator relies on the module-level `word_ids`, `doc_ids`, and
# `data_index` built earlier in the post.
import collections
from itertools import compress

import numpy as np


def generate_batch_pvdm(batch_size, window_size):
    '''
    Batch generator for PV-DM (Distributed Memory Model of Paragraph Vectors).
    The returned batch has shape (batch_size, window_size + 1).

    Parameters
    ----------
    batch_size: number of training examples in each mini-batch
    window_size: number of leading words before the target word
    '''
    global data_index
    assert batch_size % window_size == 0
    batch = np.ndarray(shape=(batch_size, window_size + 1), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    span = window_size + 1
    buffer = collections.deque(maxlen=span)      # word ids in the sliding window
    buffer_doc = collections.deque(maxlen=span)  # document ids in the sliding window
    # collect the first window of words
    for _ in range(span):
        buffer.append(word_ids[data_index])
        buffer_doc.append(doc_ids[data_index])
        data_index = (data_index + 1) % len(word_ids)
    # mask selects the leading words and drops the last (target) position
    mask = [1] * span
    mask[-1] = 0
    i = 0
    while i < batch_size:
        # only use windows that stay inside a single document
        if len(set(buffer_doc)) == 1:
            doc_id = buffer_doc[-1]
            # all leading words plus the doc_id
            batch[i, :] = list(compress(buffer, mask)) + [doc_id]
            labels[i, 0] = buffer[-1]  # the last word at the end of the sliding window
            i += 1
        # move the sliding window
        buffer.append(word_ids[data_index])
        buffer_doc.append(doc_ids[data_index])
        data_index = (data_index + 1) % len(word_ids)
    return batch, labels
## examining the batch generator function
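A minimal smoke test along these lines can be used, assuming `word_ids`, `doc_ids`, and `data_index` are the module-level globals built earlier in the post; the two toy documents and the batch_size/window_size values below are made up purely for illustration:

# toy corpus: two documents of ten word ids each (illustrative values only)
word_ids = [0, 1, 2, 3, 4, 5, 0, 1, 2, 3,   # document 0
            6, 7, 8, 9, 6, 7, 8, 9, 6, 7]   # document 1
doc_ids = [0] * 10 + [1] * 10
data_index = 0

batch, labels = generate_batch_pvdm(batch_size=8, window_size=4)
print(batch)   # each row: 4 leading word ids followed by the document id
print(labels)  # the target word that follows each window

Windows that straddle the boundary between the two documents are skipped by the `len(set(buffer_doc)) == 1` check, so every row of `batch` mixes context words from a single document with that document's id.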