def load_data(data_source):
    """Load the data set and return it split into train and test parts.

    Relies on names defined outside this function: ``imdb``, ``sequence``,
    ``data_helpers``, ``np``, ``max_words`` and ``sequence_length``.

    Args:
        data_source: ``"keras_data_set"`` to load through ``imdb.load_data``,
            or ``"local_dir"`` to load through ``data_helpers.load_data``.

    Returns:
        A tuple ``(x_train, y_train, x_test, y_test, vocabulary_inv)`` where
        ``vocabulary_inv`` is a dict keyed by integer index.

    Raises:
        AssertionError: If ``data_source`` is not one of the two known values.
    """
    if data_source not in ("keras_data_set", "local_dir"):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O`` (which would otherwise let an unknown
        # source fall through to the local-directory branch). The exception
        # type and message are the same as the original assertion's.
        raise AssertionError("Unknown data source")

    if data_source == "keras_data_set":
        (x_train, y_train), (x_test, y_test) = imdb.load_data(
            num_words=max_words,
            start_char=None,
            oov_char=None,
            index_from=None,
        )

        # Bring every sequence to exactly ``sequence_length`` entries,
        # padding and truncating at the end ("post").
        x_train = sequence.pad_sequences(
            x_train, maxlen=sequence_length, padding="post", truncating="post"
        )
        x_test = sequence.pad_sequences(
            x_test, maxlen=sequence_length, padding="post", truncating="post"
        )

        # Invert the word -> index mapping so indices can be looked up.
        vocabulary = imdb.get_word_index()
        vocabulary_inv = {index: word for word, index in vocabulary.items()}
        vocabulary_inv[0] = "<PAD/>"
    else:
        x, y, vocabulary, vocabulary_inv_list = data_helpers.load_data()
        vocabulary_inv = dict(enumerate(vocabulary_inv_list))
        # Presumably ``y`` is one-hot encoded here and this turns it into
        # integer class labels -- confirm against data_helpers.load_data.
        y = y.argmax(axis=1)

        # Shuffle data
        shuffle_indices = np.random.permutation(np.arange(len(y)))
        x = x[shuffle_indices]
        y = y[shuffle_indices]

        # First 90% of the shuffled samples for training, the rest for test.
        train_len = int(len(x) * 0.9)
        x_train = x[:train_len]
        y_train = y[:train_len]
        x_test = x[train_len:]
        y_test = y[train_len:]

    return x_train, y_train, x_test, y_test, vocabulary_inv
# Data Preparation
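# --- Web page residue from the site this snippet was copied from (not code) ---
# sentiment_cnn.py source file
# python
# Views: 20
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents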