Skip to content

Instantly share code, notes, and snippets.

@atremblay
Last active December 21, 2017 22:53
Show Gist options
  • Select an option

  • Save atremblay/00671abe8b741700aa279c5a189531c3 to your computer and use it in GitHub Desktop.

Select an option

Save atremblay/00671abe8b741700aa279c5a189531c3 to your computer and use it in GitHub Desktop.
Keras batch-profiling callback: measures how long each batch takes to run. Intended mainly for monitoring training for slow-downs.
import datetime
import logging
import time

import tensorflow as tf
class BatchProfileCallback(tf.keras.callbacks.Callback):
    """Log how long each batch takes, to help spot slow-downs.

    The callback times the interval between the ``on_batch_begin`` and
    ``on_batch_end`` hooks. For each batch it writes one line to ``log_file``
    of the form::

        <YYYY-mm-dd HH:MM:SS>; epoch <e>; batch <b>; time to run: <s> seconds

    Args:
        log_file: Path of the log file. The file is truncated (``mode='w'``)
            when the callback is constructed.
        name: Name of the :mod:`logging` logger to write through. Loggers
            are shared process-wide by name. Constructing a new callback with
            the same name replaces the file handler of the previous one. Give
            callbacks that must log concurrently distinct names.
    """

    def __init__(self, log_file, name='batch_monitor'):
        super(BatchProfileCallback, self).__init__()
        self.log_file = log_file
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)

        # logging.getLogger returns the same object for the same name. Without
        # this cleanup, every new instance would stack another FileHandler on
        # the logger. Each message would then go to every file ever opened
        # under this name, and the old file handles would never be closed.
        # Only handlers tagged by this class are touched.
        for old_handler in list(self.logger.handlers):
            if getattr(old_handler, '_batch_profile_handler', False):
                self.logger.removeHandler(old_handler)
                old_handler.close()

        handler = logging.FileHandler(filename=log_file, mode='w')
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s; %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        ))
        # Tag the handler so a later instance can find and replace it.
        handler._batch_profile_handler = True
        self.logger.addHandler(handler)

        # Epoch index last passed to on_epoch_begin; included in each line.
        self.epoch = 0
        # time.perf_counter() value when the current batch started. It is
        # None when no batch is in flight.
        self._batch_start = None

    def on_batch_begin(self, batch, logs=None):
        """Record the start time of batch ``batch``."""
        # perf_counter is monotonic and high-resolution. datetime.now() follows
        # the local wall clock, which can jump (NTP sync, DST) and would
        # produce wrong or negative durations.
        self._batch_start = time.perf_counter()

    def on_batch_end(self, batch, logs=None):
        """Log the elapsed time, in seconds, of batch ``batch``."""
        if self._batch_start is None:
            # No matching on_batch_begin was seen, so there is nothing to time.
            return
        elapsed = time.perf_counter() - self._batch_start
        self._batch_start = None
        # Lazy %-args: logging formats the message only if it is emitted.
        self.logger.info(
            "epoch %s; batch %s; time to run: %.2f seconds",
            self.epoch, batch, elapsed,
        )

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the current epoch index so batch lines can report it."""
        self.epoch = epoch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment