Last active
December 21, 2017 22:53
-
-
Save atremblay/00671abe8b741700aa279c5a189531c3 to your computer and use it in GitHub Desktop.
Keras Batch Profile Callback: measures the time taken to run each batch. Used mostly for monitoring.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
import logging
import timeit

import tensorflow as tf
class BatchProfileCallback(tf.keras.callbacks.Callback):
    """
    BatchProfileCallback

    Log how long each batch takes to run, to help monitor potential slowdowns.

    Each batch produces one INFO record in ``log_file`` of the form
    ``<YYYY-mm-dd HH:MM:SS>; epoch E; batch B; time to run: S.SS seconds``.

    Args:
        log_file: Path of the log file. It is opened with mode ``'w'``, so
            any existing content is truncated when the callback is built.
        name: Name of the ``logging`` logger to write through. Loggers are
            process-wide singletons, so callbacks sharing a ``name`` share a
            logger. Constructing a new callback detaches and closes the file
            handler that an earlier ``BatchProfileCallback`` installed on
            that logger, so records are not duplicated into stale files.
    """

    # Attribute set on handlers created by this class so that a later
    # instance can recognise (and replace) them on the shared logger.
    _HANDLER_MARKER = '_batch_profile_callback_handler'

    def __init__(self, log_file, name='batch_monitor'):
        super(BatchProfileCallback, self).__init__()
        self.log_file = log_file
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        # NOTE: propagation to ancestor loggers is left at logging's default,
        # so records may also reach handlers configured on the root logger.

        # getLogger() returns the same object for the same name. Without this
        # cleanup, re-creating the callback keeps stacking FileHandlers: each
        # record is written once per stale handler and the old file
        # descriptors are never closed.
        for stale in list(self.logger.handlers):
            if getattr(stale, self._HANDLER_MARKER, False):
                self.logger.removeHandler(stale)
                stale.close()

        handler = logging.FileHandler(filename=log_file, mode='w')
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s; %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        ))
        setattr(handler, self._HANDLER_MARKER, True)
        self.logger.addHandler(handler)

        self.epoch = 0
        # Monotonic start time of the current batch; None until the first
        # on_batch_begin call.
        self._batch_start = None

    def on_batch_begin(self, batch, logs=None):
        """Record the start time of ``batch``."""
        # timeit.default_timer is time.perf_counter on Python 3: monotonic and
        # high resolution. datetime.now() follows the wall clock and can jump
        # (NTP sync, DST change) in the middle of a batch.
        self._batch_start = timeit.default_timer()

    def on_batch_end(self, batch, logs=None):
        """Log the time elapsed since the matching ``on_batch_begin``."""
        if self._batch_start is None:
            # No start time has been recorded, so there is nothing to measure.
            return
        time_elapsed = timeit.default_timer() - self._batch_start
        self.logger.info(
            "epoch %s; batch %s; time to run: %.2f seconds",
            self.epoch, batch, time_elapsed,
        )

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the current epoch index for subsequent batch records."""
        self.epoch = epoch
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment