def predict(self, states):
    """Run the network on the given input states and return its outputs.

    Switches to this predictor's workspace, feeds every entry of ``states``
    as a float32 blob, runs ``self._net`` and fetches each blob named in
    ``self._output_blobs``. The previously current workspace is restored
    afterwards, even if feeding, running or fetching raises.

    :param states: mapping of input blob name -> value (scalar or
        array-like). Each value is wrapped with ``np.atleast_1d`` and cast
        to ``np.float32`` before being fed.
    :returns: dict mapping each output blob name to the fetched value.
    """
    previous_workspace = workspace.CurrentWorkspace()
    workspace.SwitchWorkspace(self._workspace_id)
    try:
        for input_blob, value in states.items():
            # atleast_1d so scalar values are fed as 1-element arrays.
            workspace.FeedBlob(
                input_blob,
                np.atleast_1d(value).astype(np.float32),
            )
        workspace.RunNet(self._net)
        return {
            output: workspace.FetchBlob(output)
            for output in self._output_blobs
        }
    finally:
        # Restore the caller's workspace on every exit path; previously an
        # exception above left the global current workspace switched.
        workspace.SwitchWorkspace(previous_workspace)
评论列表
文章目录