Last active
October 25, 2023 11:14
-
-
Save AndreFCruz/d4708ff4c528b76eb929b0a5bc5aadf5 to your computer and use it in GitHub Desktop.
A minimalist implementation of local caching to prevent re-running time-consuming function calls.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Utils for caching results. | |
| """ | |
| import json | |
| import pickle | |
| import logging | |
| from pathlib import Path | |
| from hashlib import sha256 | |
| import cloudpickle | |
# Directory used for cache files when `LocalCache` receives no `save_dir`.
# NOTE(review): /var/tmp is usually shared between users, and cache files are
# read back with `pickle` (which can execute code) — confirm this default is
# acceptable for the deployment, or pass an explicit private `save_dir`.
DEFAULT_CACHE_DIR: Path = Path("/", "var", "tmp")

# Cache name (prefix of the `<name>.cache.pkl` file) used when neither `name`
# nor `unique_context_details` is given to `LocalCache`.
DEFAULT_CACHE_NAME: str = "local-python-cache"
def _hash_jsonable_object(obj: object, hash_func=sha256) -> str:
    """Return a reproducible hex digest identifying a JSON-serializable object.

    The object is serialized with ``json.dumps(..., sort_keys=True)`` so that
    dictionaries holding the same items in a different insertion order map to
    the same digest. The UTF-8 encoded serialization is then hashed.

    Parameters
    ----------
    obj : object
        Any object accepted by ``json.dumps``; other objects raise
        ``TypeError`` from the json module.
    hash_func : callable, optional
        A ``hashlib``-style constructor taking bytes and returning an object
        with ``hexdigest()``; defaults to ``sha256``.

    Returns
    -------
    str
        The hexadecimal digest of the serialized object.
    """
    serialized = json.dumps(obj, sort_keys=True)
    payload = serialized.encode("utf-8")
    hasher = hash_func(payload)
    return hasher.hexdigest()
| class LocalCache(): | |
| """Class used to cache returned objects to avoid unnecessarily re-computing. | |
| Usage example: | |
| >>> cache = LocalCache(<unique_name>) | |
| >>> with cache: | |
| >>> cache.get_or_set("foo", foo, *args, **kwargs) | |
| """ | |
| UNINITIATED_USAGE_MSG = """ | |
| Trying to use a `LocalCache` object whose context has not been entered. | |
| The `LocalCache` class serves as a context manager that should be used as follows: | |
| >>> with LocalCache(<cache_name>) as cache: | |
| >>> ... | |
| """ | |
| def __init__( | |
| self, | |
| name: str = None, | |
| save_dir: str | Path = None, | |
| overwrite: bool = False, | |
| unique_context_details: dict = None, | |
| ): | |
| """Constructor for the cache object. | |
| Parameters | |
| ---------- | |
| name : str, optional | |
| A *unique* name to identify this cache locally. | |
| save_dir : str | Path, optional | |
| The path to a directory where the cache should be saved. | |
| overwrite : bool, optional | |
| Whether to re-compute and overwrite previously cached artifacts, by | |
| default False. | |
| unique_context_details : dict, optional | |
| A dictionary containing details that uniquely identify this cache | |
| object; can be used as an alternative to a unique name. A unique | |
| identifier will be generated from this dictionary. | |
| """ | |
| self.save_dir = save_dir or DEFAULT_CACHE_DIR | |
| self.overwrite = overwrite | |
| assert name is None or unique_context_details is None, ( | |
| "Cannot provide both `name` and `context_unique_details`, " | |
| "will lead to conflicting naming scheme.") | |
| if unique_context_details is not None: | |
| assert name is None | |
| self.name = _hash_jsonable_object(unique_context_details) | |
| else: | |
| self.name = name or DEFAULT_CACHE_NAME | |
| # Path to the cache file in the local file system | |
| self.cache_path = self.save_dir / f"{self.name}.cache.pkl" | |
| logging.info(f"Cache path: '{self.cache_path}'.") | |
| # Cache variable | |
| self._cache = dict() | |
| self._lifecycle_initiated = False | |
| def _check_lifecycle_has_been_initiated(self, raise_: bool = True): | |
| if not self._lifecycle_initiated: | |
| if raise_: | |
| raise RuntimeError(self.UNINITIATED_USAGE_MSG) | |
| else: | |
| logging.error(self.UNINITIATED_USAGE_MSG) | |
| # Begin life-cycle | |
| def __enter__(self): | |
| assert not self._lifecycle_initiated | |
| self._lifecycle_initiated = True | |
| # Read current cache status (if one exists) | |
| if self.cache_path.exists(): | |
| logging.info(f"Loading cache state from disk at '{self.cache_path}'.") | |
| with open(self.cache_path, "rb") as f_in: | |
| self._cache = pickle.load(f_in) | |
| return self | |
| # End life-cycle | |
| def __exit__(self, _exc_type, _exc_val, _exc_tb) -> bool: | |
| assert self._lifecycle_initiated | |
| self._lifecycle_initiated = False | |
| logging.info(f"Saving cache state to disk at '{self.cache_path}'.") | |
| # Save current cache status to disk | |
| with open(self.cache_path, "wb") as f_out: | |
| cloudpickle.dump(self._cache, f_out) | |
| # Don't suppress any occurred exceptions | |
| return False | |
| # Cache getters and setters | |
| def __getitem__(self, key) -> object: | |
| self._check_lifecycle_has_been_initiated() | |
| return self._cache[key] | |
| def __setitem__(self, key, value): | |
| self._check_lifecycle_has_been_initiated() | |
| logging.info(f"Caching value with key='{key}'.") | |
| if key in self._cache and not self.overwrite: | |
| raise KeyError( | |
| f"Key '{key}' maps to an existing object, and " | |
| f"self.overwrite=={self.overwrite}." | |
| ) | |
| # Set cache value | |
| self._cache[key] = value | |
| def get(self, key: str, default=None) -> object: | |
| self._check_lifecycle_has_been_initiated() | |
| return self._cache.get(key, default) | |
| def get_or_set(self, key: str, func: callable, *args, **kwargs) -> object: | |
| self._check_lifecycle_has_been_initiated() | |
| if key not in self._cache or self.overwrite: | |
| if key in self._cache and self.overwrite: | |
| logging.warning(f"Overwriting previously cached result for '{key}'.") | |
| # Run `func` and cache results | |
| self[key] = func(*args, **kwargs) | |
| return self[key] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment