Skip to content

Instantly share code, notes, and snippets.

@AndreFCruz
Last active October 25, 2023 11:14
Show Gist options
  • Select an option

  • Save AndreFCruz/d4708ff4c528b76eb929b0a5bc5aadf5 to your computer and use it in GitHub Desktop.

Select an option

Save AndreFCruz/d4708ff4c528b76eb929b0a5bc5aadf5 to your computer and use it in GitHub Desktop.
A minimalist implementation of local caching to prevent re-running time-consuming function calls.
"""Utils for caching results.
"""
import json
import pickle
import logging
from pathlib import Path
from hashlib import sha256
import cloudpickle
# Default directory for cache files; /var/tmp typically persists across
# reboots (unlike /tmp on many systems) — behavior varies by OS, confirm
# for your deployment target.
DEFAULT_CACHE_DIR = Path("/var/tmp")
# Fallback cache file stem used when the caller supplies neither `name`
# nor `unique_context_details`.
DEFAULT_CACHE_NAME = "local-python-cache"
def _hash_jsonable_object(obj: object, hash_func=sha256) -> str:
"""Hash any jsonable object to a reproducible unique identifier.
"""
return hash_func(json.dumps(obj, sort_keys=True).encode("utf-8")).hexdigest()
class LocalCache():
"""Class used to cache returned objects to avoid unnecessarily re-computing.
Usage example:
>>> cache = LocalCache(<unique_name>)
>>> with cache:
>>> cache.get_or_set("foo", foo, *args, **kwargs)
"""
UNINITIATED_USAGE_MSG = """
Trying to use a `LocalCache` object whose context has not been entered.
The `LocalCache` class serves as a context manager that should be used as follows:
>>> with LocalCache(<cache_name>) as cache:
>>> ...
"""
def __init__(
self,
name: str = None,
save_dir: str | Path = None,
overwrite: bool = False,
unique_context_details: dict = None,
):
"""Constructor for the cache object.
Parameters
----------
name : str, optional
A *unique* name to identify this cache locally.
save_dir : str | Path, optional
The path to a directory where the cache should be saved.
overwrite : bool, optional
Whether to re-compute and overwrite previously cached artifacts, by
default False.
unique_context_details : dict, optional
A dictionary containing details that uniquely identify this cache
object; can be used as an alternative to a unique name. A unique
identifier will be generated from this dictionary.
"""
self.save_dir = save_dir or DEFAULT_CACHE_DIR
self.overwrite = overwrite
assert name is None or unique_context_details is None, (
"Cannot provide both `name` and `context_unique_details`, "
"will lead to conflicting naming scheme.")
if unique_context_details is not None:
assert name is None
self.name = _hash_jsonable_object(unique_context_details)
else:
self.name = name or DEFAULT_CACHE_NAME
# Path to the cache file in the local file system
self.cache_path = self.save_dir / f"{self.name}.cache.pkl"
logging.info(f"Cache path: '{self.cache_path}'.")
# Cache variable
self._cache = dict()
self._lifecycle_initiated = False
def _check_lifecycle_has_been_initiated(self, raise_: bool = True):
if not self._lifecycle_initiated:
if raise_:
raise RuntimeError(self.UNINITIATED_USAGE_MSG)
else:
logging.error(self.UNINITIATED_USAGE_MSG)
# Begin life-cycle
def __enter__(self):
assert not self._lifecycle_initiated
self._lifecycle_initiated = True
# Read current cache status (if one exists)
if self.cache_path.exists():
logging.info(f"Loading cache state from disk at '{self.cache_path}'.")
with open(self.cache_path, "rb") as f_in:
self._cache = pickle.load(f_in)
return self
# End life-cycle
def __exit__(self, _exc_type, _exc_val, _exc_tb) -> bool:
assert self._lifecycle_initiated
self._lifecycle_initiated = False
logging.info(f"Saving cache state to disk at '{self.cache_path}'.")
# Save current cache status to disk
with open(self.cache_path, "wb") as f_out:
cloudpickle.dump(self._cache, f_out)
# Don't suppress any occurred exceptions
return False
# Cache getters and setters
def __getitem__(self, key) -> object:
self._check_lifecycle_has_been_initiated()
return self._cache[key]
def __setitem__(self, key, value):
self._check_lifecycle_has_been_initiated()
logging.info(f"Caching value with key='{key}'.")
if key in self._cache and not self.overwrite:
raise KeyError(
f"Key '{key}' maps to an existing object, and "
f"self.overwrite=={self.overwrite}."
)
# Set cache value
self._cache[key] = value
def get(self, key: str, default=None) -> object:
self._check_lifecycle_has_been_initiated()
return self._cache.get(key, default)
def get_or_set(self, key: str, func: callable, *args, **kwargs) -> object:
self._check_lifecycle_has_been_initiated()
if key not in self._cache or self.overwrite:
if key in self._cache and self.overwrite:
logging.warning(f"Overwriting previously cached result for '{key}'.")
# Run `func` and cache results
self[key] = func(*args, **kwargs)
return self[key]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment