train.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:product-category-classifier 作者: two-tap 项目源码 文件源码
def build_text_model(word_index):
  """Build the convolutional text branch on top of a frozen pretrained embedding.

  Args:
    word_index: mapping of word -> integer row index into the embedding
      matrix. Presumably the ``word_index`` of a fitted Keras ``Tokenizer``
      (1-based indices, row 0 left unused for padding) -- confirm against
      the caller.

  Returns:
    A ``(features, text_input)`` tuple. ``features`` is the output tensor of
    the final 1024-unit ReLU ``Dense`` layer. ``text_input`` is the ``Input``
    placeholder of shape ``(MAX_SEQUENCE_LENGTH,)`` that feeds it.

  Reads the module-level names ``embeddings_index`` (word -> vector),
  ``EMBEDDING_DIM`` and ``MAX_SEQUENCE_LENGTH``.
  """
  text_input = Input(shape=(MAX_SEQUENCE_LENGTH,))

  # One row per indexed word plus one extra row (presumably index 0 / padding).
  embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))

  for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)

    if embedding_vector is not None:
      # Keep only the first EMBEDDING_DIM components of the pretrained vector.
      # Words not found in embeddings_index keep their all-zero row.
      embedding_matrix[i] = embedding_vector[:EMBEDDING_DIM]

  # Freeze the pretrained embeddings. `trainable` must be set on the layer
  # itself: setting it on the layer's output tensor is ignored by Keras and
  # leaves the embedding weights trainable.
  embedding_layer = Embedding(embedding_matrix.shape[0],
                              embedding_matrix.shape[1],
                              weights=[embedding_matrix],
                              input_length=MAX_SEQUENCE_LENGTH,
                              trainable=False)

  x = embedding_layer(text_input)
  # Two Conv1D(128, kernel 5) + MaxPooling1D(5) stages. NOTE(review): this
  # assumes MAX_SEQUENCE_LENGTH is long enough to survive both stages -- confirm.
  x = Conv1D(128, 5, activation='relu')(x)
  x = MaxPooling1D(5)(x)
  x = Conv1D(128, 5, activation='relu')(x)
  x = MaxPooling1D(5)(x)
  x = Flatten()(x)
  x = Dense(1024, activation='relu')(x)

  return x, text_input

##
## Image model
##
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号