def _load_dataset(c_i, max_files=501, max_len=50, vocab_size=128):
    """Load pickled samples from ``data/*.pkl`` and build training arrays.

    Each pickle is read as a mapping with an ``"image"`` entry (anything
    ``np.array`` accepts) and a ``"kana"`` entry (a sequence of characters,
    presumably a ``str`` -- confirm against the data producer).  The kana is
    one-hot encoded into a ``(max_len, vocab_size)`` float matrix using the
    character->index mapping ``c_i``: characters past ``max_len`` are dropped,
    rows past the end of the kana stay all-zero, and characters missing from
    ``c_i`` are reported and left as all-zero rows.

    Args:
        c_i: mapping from character to one-hot column index.
        max_files: maximum number of pickle files to load (the original loop
            broke once the file index exceeded 500, i.e. it loaded 501 files).
        max_len: number of one-hot rows per sample.
        vocab_size: number of one-hot columns per row.

    Returns:
        ``(Xs, Ys)``: the images and one-hot targets, each stacked by
        ``np.array`` along a new leading sample axis.
    """
    xss = []
    yss = []
    # glob() order is arbitrary; sort so the chosen subset is reproducible.
    paths = sorted(glob.glob("data/*.pkl"))[:max_files]
    for path in paths:
        # NOTE: pickle can execute arbitrary code -- only load trusted files.
        with open(path, "rb") as fh:
            sample = pickle.load(fh)
        img = sample["image"]
        kana = sample["kana"]
        print(kana)
        xss.append(np.array(img))
        ys = np.zeros((max_len, vocab_size))
        for pos, ch in enumerate(kana[:max_len]):
            try:
                ys[pos, c_i[ch]] = 1.0
            except KeyError as e:
                # Unknown character: leave its row all-zero and report it.
                print(e)
        yss.append(ys)
    return np.array(xss), np.array(yss)


def train():
    """Train the module-level ``t2i`` model on the pickled kana dataset.

    Loads the character->index mapping from ``c_i.pkl`` and the samples from
    ``data/*.pkl``.  Then it runs 2000 rounds of 20-epoch ``fit`` calls, each
    with an optimizer drawn at random from Adam/SGD/RMSprop.  Every 50th round
    it saves a checkpoint named ``models/<loss>_<round>.h5`` and appends the
    round number, loss and the shell ``date`` to ``loss.log``.

    With ``--resume`` on the command line, weights are first loaded from the
    checkpoint whose file name sorts first (i.e. the lowest recorded loss).

    Relies on module-level names defined elsewhere: ``t2i`` (the model),
    ``callbacks`` (the epoch-end hook) and ``buff`` (a dict read for
    ``"loss"``, presumably filled in by ``callbacks`` -- confirm).

    Raises:
        FileNotFoundError: if no training samples are found, or if
            ``--resume`` is given but no checkpoint exists.
    """
    with open("c_i.pkl", "rb") as fh:
        c_i = pickle.load(fh)

    Xs, Ys = _load_dataset(c_i)
    print(Xs.shape)
    if len(Xs) == 0:
        raise FileNotFoundError("no training samples found in data/*.pkl")

    optims = [Adam(), SGD(), RMSprop()]
    if "--resume" in sys.argv:
        checkpoints = glob.glob("models/*.h5")
        if not checkpoints:
            raise FileNotFoundError(
                "--resume given but no checkpoint matches models/*.h5")
        # Names start with "%9f" of the loss; the space padding makes the
        # lexicographically smallest name the lowest loss (for losses < 100).
        model = min(checkpoints)
        print("loaded model is ", model)
        t2i.load_weights(model)

    # The callback is identical every round, so build it once.
    print_callback = LambdaCallback(on_epoch_end=callbacks)
    batch_size = 8
    for i in range(2000):
        random_optim = random.choice(optims)
        print(random_optim)
        # NOTE(review): Keras usually builds its train function once, using the
        # optimizer given to compile(); reassigning .optimizer afterwards may
        # not take effect without recompiling -- confirm for this Keras version.
        t2i.optimizer = random_optim
        t2i.fit(Xs, Ys, shuffle=True, batch_size=batch_size, epochs=20,
                callbacks=[print_callback])
        if i % 50 == 0:
            t2i.save("models/%9f_%09d.h5" % (buff["loss"], i))
            lossrate = buff["loss"]
            # Only an int and a float are interpolated, so the shell string is safe.
            os.system("echo \"{} {}\" `date` >> loss.log".format(i, lossrate))
            print("saved ..")
        print("logs...", buff)
评论列表
文章目录