svhn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:mxnet_workshop 作者: NervanaSystems 项目源码 文件源码
def setup_dashboard(self):
        """Build and display the notebook training dashboard.

        The layout places two things side by side:

        * a log-scale "Learning Rate" vs "Epoch" plot with four vertical
          sliders underneath it, each slider controlling one flat segment
          of the plotted line, and
        * an initially empty "Cost" vs "Epoch" figure with one line for
          training cost and one red line for validation cost.

        Attributes set on ``self``:
            fig: the cost figure.
            train_source: ``ColumnDataSource`` behind the training-cost line.
            val_source: ``ColumnDataSource`` behind the validation-cost line.
            fh: the handle returned by ``show(layout, notebook_handle=True)``.
        """
        output_notebook()

        # Eight points forming four flat segments, one per unit interval on
        # the epoch axis: points (k, k + 1) for k in 0, 2, 4, 6 are the two
        # ends of one segment.
        x = [0, 1, 1, 2, 2, 3, 3, 4]
        # 1e-7 matches the sliders' initial value of 3, since the callback
        # maps a slider value f to 10 ** -(10 - f).
        y = 8 * [10 ** -7]

        source = ColumnDataSource(data=dict(x=x, y=y))

        plot = figure(plot_width=300, plot_height=150, y_axis_type="log",
                      y_range=[0.0000000001, 1], x_range=[0, 4],
                      x_axis_label='Epoch', y_axis_label='Learning Rate')
        plot.line('x', 'y', source=source, line_width=3, line_alpha=0.6)

        # JavaScript template run in the browser whenever a slider moves.
        # The three ``{}`` placeholders are, in order: the indices of the two
        # segment end points to update in ``y``, and the index written into
        # ``dashboard.learn_inputs`` on the Python kernel side.
        # NOTE(review): the kernel command assumes a global named
        # ``dashboard`` with an indexable ``learn_inputs`` of length >= 4
        # exists in the notebook namespace; this method does not create it.
        base_code = """
                var data = source.data;
                var f = cb_obj.value
                x = data['x']
                y = data['y']
                y[{}] = Math.pow(10.0, -1.0 * (10-f))
                y[{}] = Math.pow(10.0, -1.0 * (10-f))
                source.trigger('change');

                var command = 'dashboard.learn_inputs[{}] = ' + f;
                var kernel = IPython.notebook.kernel;
                kernel.execute(command)
            """

        # set up figure
        fig = figure(name="cost", y_axis_label="Cost", x_range=(0, 4),
                     x_axis_label="Epoch", plot_width=300, plot_height=300)
        self.fig = fig

        train_source = ColumnDataSource(data=dict(x=[], y=[]))
        fig.line('x', 'y', source=train_source)
        self.train_source = train_source

        val_source = ColumnDataSource(data=dict(x=[], y=[]))
        fig.line('x', 'y', source=val_source, color='red')
        self.val_source = val_source

        # set up sliders and callback
        # Floor division keeps the ``learn_inputs`` index an integer on
        # Python 3; true division would render it as e.g. ``1.0`` and the
        # generated kernel command would fail with a list-index TypeError.
        callbacks = [
            CustomJS(args=dict(source=source),
                     code=base_code.format(k, k + 1, k // 2))
            for k in (0, 2, 4, 6)
        ]
        slider = [
            Slider(start=0.1, end=10, value=3, step=.1, title=None,
                   callback=C, orientation='vertical', width=80, height=50)
            for C in callbacks
        ]

        radio_group = RadioGroup(labels=[""], active=0, width=65)

        # NOTE(review): this helper is never called or registered anywhere
        # in this method, and ``fh`` / ``train_model`` are not defined here
        # (they would have to resolve from an outer scope) -- confirm whether
        # it is still needed.
        def train_model_button(run=True):
            train_model(slider, fig=fig, handle=fh, train_source=train_source, val_source=val_source)

        sliders = row(radio_group, *slider)
        settings = column(plot, sliders)

        layout = gridplot([[settings, fig]], sizing_mode='fixed',
                          merge_tools=True, toolbar_location=None)

        self.fh = show(layout, notebook_handle=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号