def generate_batch_cbow(data, batch_size, num_skips, skip_window):
    '''
    Generate one mini-batch of CBOW (Continuous Bag of Words) training samples.

    A window of ``2 * skip_window + 1`` consecutive words slides over ``data``
    one position per sample. For every window position the centre word is the
    label and all remaining words of the window are the context (the input).

    The read position is kept in the module-level global ``data_index``; it is
    read on entry and advanced (wrapping around at ``len(data)``) by
    ``2 * skip_window + 1 + batch_size`` positions per successful call.

    Parameters
    ----------
    data: sequence of word indices (the corpus)
    batch_size: number of (context, target) samples in each mini-batch
    num_skips: number of context words per sample; must equal 2 * skip_window
        (2: one word before and one word after the target)
    skip_window: number of context words taken on each side of the target

    Returns
    -------
    batch: int32 array of shape (batch_size, num_skips), the context words of
        each sample in corpus order
    labels: int32 array of shape (batch_size, 1), the centre word of each sample

    Raises
    ------
    ValueError: if num_skips != 2 * skip_window. Every sample carries the full
        context window, so any other value cannot fill a row of ``batch``.
    '''
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    # Fail fast, before the global cursor is touched: a mismatched num_skips
    # used to surface only as a NumPy broadcasting error in the middle of the
    # loop, after data_index had already been advanced.
    if num_skips != 2 * skip_window:
        raise ValueError(
            'num_skips (%d) must equal 2 * skip_window (%d): every sample '
            'uses all context words of the window'
            % (num_skips, 2 * skip_window)
        )

    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    batch = np.empty((batch_size, num_skips), dtype=np.int32)
    labels = np.empty((batch_size, 1), dtype=np.int32)

    # Bounded deque: appending a new word silently drops the oldest one,
    # which is exactly a one-step slide of the window.
    window = collections.deque(maxlen=span)

    # Collect the first window of words.
    for _ in range(span):
        window.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    # Emit one sample per window position, then slide the window by one word.
    for row in range(batch_size):
        words = list(window)
        # Everything except the centre word is context.
        batch[row, :] = words[:skip_window] + words[skip_window + 1:]
        labels[row, 0] = words[skip_window]
        # The slide is also performed after the last sample so that the final
        # value of data_index stays the same as it has always been.
        window.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    return batch, labels
评论列表
文章目录