def train_breaker(datafilename, sentence_num=1000, puncs=u',?.?!???',
                  RNN=recurrent.GRU, HIDDEN_SIZE=128, EPOCH_SIZE=10, validate=True):
    """Train a per-token "break after this token?" tagger and save it to disk.

    The corpus is read line by line. Punctuation is removed from each line
    and remembered as labels: a token that was directly followed by
    punctuation is labelled 1 (break), every other token 0. A recurrent
    layer followed by a time-distributed 2-way softmax is trained on the
    one-hot encoded tokens, one sentence per batch.

    Args:
        datafilename: Path of the UTF-8 corpus, one sentence per line. It is
            also handed to ``WordTable.parse``.
        sentence_num: Only the first ``sentence_num`` lines are used
            (list-slice semantics, so ``None`` means every line).
        puncs: The characters treated as punctuation. Each character is
            matched literally.
            NOTE(review): the repeated '?' in the default looks like
            non-ASCII punctuation lost to an encoding error -- confirm
            against the original file.
        RNN: Recurrent layer class, called as
            ``RNN(HIDDEN_SIZE, return_sequences=True)``.
        HIDDEN_SIZE: First positional argument given to ``RNN``.
        EPOCH_SIZE: Number of passes over the training sentences.
        validate: When true, the last 10% of the accepted sentences are held
            out of training and a confusion summary is printed for them.

    Side effects:
        Prints progress to stdout and writes ``wordtable_json.txt``,
        ``model_json.txt`` and ``model_weights.h5`` to the current working
        directory.
    """
    wordtable = WordTable()
    wordtable.parse(datafilename, sentence_num)

    X, Y = _load_breaker_samples(datafilename, sentence_num, puncs, wordtable)
    # Label and value are joined by one space, i.e. the same text that
    # ``print label, value`` emits.
    print('{0} {1}'.format('total sentence: ', len(X)))

    if validate:
        # Set apart 10% for validation
        split_at = len(X) - len(X) // 10
        X_train, X_val = X[:split_at], X[split_at:]
        y_train, y_val = Y[:split_at], Y[split_at:]
    else:
        X_train, y_train = X, Y

    model = Graph()
    model.add_input(name='input', input_shape=(None, wordtable.capacity))
    model.add_node(RNN(HIDDEN_SIZE, return_sequences=True), name='forward', input='input')
    model.add_node(TimeDistributedDense(2, activation='softmax'), name='softmax', input='forward')
    model.add_output(name='output', input='softmax')
    model.compile('adam', {'output': 'categorical_crossentropy'})

    for epoch in range(EPOCH_SIZE):
        print('{0} {1}'.format("epoch: ", epoch))
        # Sentences differ in length, so each one is fed as its own batch.
        for idx, (seq, label) in enumerate(zip(X_train, y_train)):
            loss, accuracy = model.train_on_batch(
                {'input': np.array([seq]), 'output': np.array([label])},
                accuracy=True)
            if idx % 20 == 0:
                print("\tidx={0}, loss={1}, accuracy={2}".format(idx, loss, accuracy))
        # NOTE(review): the report is placed inside the epoch loop (one
        # report per epoch); the original indentation was lost, so confirm
        # it was not meant to run only once after training.
        if validate:
            _report_breaker_validation(model, X_val, y_val)

    with open('wordtable_json.txt', 'w') as wordtable_file:
        wordtable_file.write(wordtable.to_json())
    with open('model_json.txt', 'w') as model_file:
        model_file.write(model.to_json())
    model.save_weights('model_weights.h5', overwrite=True)


def _load_breaker_samples(datafilename, sentence_num, puncs, wordtable):
    """Build one-hot (tokens, break-labels) array pairs from the corpus head.

    Returns two parallel lists. Each ``x`` has shape
    ``(len(words), wordtable.capacity)`` and each ``y`` has shape
    ``(len(breaks), 2)``, where column 1 marks "followed by punctuation".
    Only sentences with 30..50 encoded tokens and at least 4 break points
    are kept.
    """
    # Escape the punctuation so characters such as ']', '^', '-' or '\\'
    # are matched literally instead of altering the character class.
    escaped = re.escape(puncs)
    punc_class = u'[{0}]'.format(escaped)
    # The patterns do not depend on the line, so compile them once.
    edge_puncs = re.compile(u'(^{0}+)|({0}+$)'.format(punc_class))
    any_punc = re.compile(punc_class)
    non_punc = re.compile(u'[^{0}]'.format(escaped))
    token_before_punc = re.compile(u'0{0}+'.format(punc_class))

    # Read bytes and decode explicitly; the context manager closes the
    # handle even if decoding a later line fails.
    with open(datafilename, 'rb') as corpus:
        raw_lines = corpus.readlines()[:sentence_num]

    X, Y = [], []
    for raw in raw_lines:
        # Leading/trailing punctuation carries no break information.
        line = edge_puncs.sub(u'', raw.strip().decode('utf-8'))
        words = wordtable.encode(any_punc.sub(u'', line))
        # Every non-punctuation character becomes '0'; a '0' followed by a
        # run of punctuation then collapses into a single '1'. The result
        # has one flag per non-punctuation character.
        breaks = token_before_punc.sub(u'1', non_punc.sub(u'0', line))
        if not (30 <= len(words) <= 50 and breaks.count('1') >= 4):
            continue
        # Builtin ``bool`` instead of the removed ``np.bool`` alias.
        x = np.zeros((len(words), wordtable.capacity), dtype=bool)
        y = np.zeros((len(breaks), 2), dtype=bool)
        for idx, word_id in enumerate(words):
            x[idx, word_id] = True
            y[idx, int(breaks[idx])] = True
        X.append(x)
        Y.append(y)
    return X, Y


def _report_breaker_validation(model, X_val, y_val):
    """Print the four break / no-break confusion counts for held-out data."""
    expected, predicted = [], []
    for seq, label in zip(X_val, y_val):
        expected.extend(label.argmax(axis=-1))
        scores = model.predict({'input': np.array([seq])})['output'][0]
        predicted.extend(scores.argmax(axis=-1))
    expected, predicted = np.array(expected), np.array(predicted)

    summary = (
        ("should break right: ", (predicted == 1) & (expected == 1)),
        ("should break wrong: ", (predicted == 0) & (expected == 1)),
        ("should not break right: ", (predicted == 0) & (expected == 0)),
        ("should not break wrong: ", (predicted == 1) & (expected == 0)),
    )
    for caption, hits in summary:
        print('{0} {1}'.format(caption, hits.sum()))
simple_sentence_breaker.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录