from typing import Optional, Dict, Any
import hashlib
import json
import logging
import os
from datetime import datetime, timedelta, timezone

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.genai import Client
from google.genai import types

logger = logging.getLogger(__name__)


class GlobalCachingPlugin(BasePlugin):
    """
    A plugin that implements GLOBAL context caching by enforcing a stable
    prefix (warmup turn) and manually managing cache creation/reuse across
    sessions.
    """

    # Use minimal tokens for the warmup turn to save cost/latency.
    WARMUP_USER_TEXT = "#"
    WARMUP_MODEL_TEXT = "#"

    # Limit the number of cache entries in memory/on disk to prevent unbounded growth.
    MAX_CACHE_ENTRIES = 50
    CACHE_FILE_PATH = os.path.abspath(os.path.expanduser("./.adk/global_cache.json"))

    # Global cache map: fingerprint -> {'name': str, 'expire_time': datetime}.
    # Class-level, so it is shared across all instances of the plugin.
    _GLOBAL_CACHE_MAP: Dict[str, Dict[str, Any]] = {}
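    # Example entry (values illustrative; 'cachedContents/...' is the Gemini
    # cache resource name format returned by client.caches.create):
    #   "3f1a..." -> {'name': 'cachedContents/xyz123',
    #                 'expire_time': datetime(2026, 2, 20, 15, 32, tzinfo=timezone.utc)}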
    def __init__(self, api_key: Optional[str] = None,
                 context_cache_config: Optional[ContextCacheConfig] = None):
        super().__init__(name="global_caching_plugin")
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        if not self.api_key:
            logger.warning(
                "GlobalCachingPlugin initialized without API key. Cache creation might fail.")
        self.cache_config = context_cache_config or ContextCacheConfig(ttl_seconds=600)
        # Initialize the GenAI client used for cache create/update calls.
        self.client = Client(api_key=self.api_key)
        # Load the cache map from disk.
        logger.info(f"[Plugin] Cache file path: {self.CACHE_FILE_PATH}")
        self._load_from_disk()

    def _load_from_disk(self):
        """Loads the cache map from the JSON file."""
        if not os.path.exists(self.CACHE_FILE_PATH):
            return
        try:
            with open(self.CACHE_FILE_PATH, 'r') as f:
                data = json.load(f)
            # Parse datetimes and populate the map.
            for fingerprint, entry in data.items():
                if 'expire_time' in entry:
                    # ISO format string back to datetime.
                    entry['expire_time'] = datetime.fromisoformat(entry['expire_time'])
                self._GLOBAL_CACHE_MAP[fingerprint] = entry
            logger.info(
                f"[Plugin] Loaded {len(self._GLOBAL_CACHE_MAP)} cache entries "
                f"from {self.CACHE_FILE_PATH}")
        except Exception as e:
            logger.warning(f"[Plugin] Failed to load cache from disk: {e}")

    def _save_to_disk(self):
        """Saves the cache map to the JSON file."""
        try:
            # Ensure the directory exists.
            os.makedirs(os.path.dirname(self.CACHE_FILE_PATH), exist_ok=True)
            # Convert datetimes to ISO strings for JSON.
            data_to_save = {}
            for fp, entry in self._GLOBAL_CACHE_MAP.items():
                entry_copy = entry.copy()
                if isinstance(entry_copy.get('expire_time'), datetime):
                    entry_copy['expire_time'] = entry_copy['expire_time'].isoformat()
                data_to_save[fp] = entry_copy
            with open(self.CACHE_FILE_PATH, 'w') as f:
                json.dump(data_to_save, f, indent=2)
            logger.info(f"[Plugin] Saved cache map to: {self.CACHE_FILE_PATH}")
        except Exception as e:
            logger.error(f"[Plugin] Failed to save cache to disk: {e}")

    def _prune_map(self):
        """
        Removes expired entries and enforces MAX_CACHE_ENTRIES.
        Saves to disk if anything was removed.
        """
        now = datetime.now(timezone.utc)
        changed = False
        # 1. Collect expired keys (with a 10s safety buffer).
        keys_to_remove = []
        for fp, entry in self._GLOBAL_CACHE_MAP.items():
            expire_time = entry.get('expire_time')
            if expire_time and now > (expire_time - timedelta(seconds=10)):
                keys_to_remove.append(fp)
        if keys_to_remove:
            logger.info(f"[Plugin] Pruning {len(keys_to_remove)} expired cache entries.")
            for k in keys_to_remove:
                del self._GLOBAL_CACHE_MAP[k]
            changed = True
        # 2. If still too big, remove the oldest entries (FIFO; dicts preserve
        #    insertion order in Python 3.7+). Using `>=` leaves room for the
        #    entry the caller is about to add.
        while len(self._GLOBAL_CACHE_MAP) >= self.MAX_CACHE_ENTRIES:
            oldest_key = next(iter(self._GLOBAL_CACHE_MAP))
            del self._GLOBAL_CACHE_MAP[oldest_key]
            logger.info(
                f"[Plugin] Pruning oldest cache entry to enforce limit {self.MAX_CACHE_ENTRIES}.")
            changed = True
        # 3. Save only if something changed; the caller saves again after adding
        #    a new entry, so an unconditional save here would be redundant.
        if changed:
            self._save_to_disk()
    def _estimate_request_tokens(self, llm_request: LlmRequest) -> int:
        """
        Rough token estimate to check against min_tokens, using a
        ~4-characters-per-token heuristic. Logic adapted from
        GeminiContextCacheManager.

        Only the STABLE part (system instruction + tools + warmup prefix) is
        counted, because that is all that gets cached; the user's dynamic
        query is never part of the cache.
        """
        total_chars = 0
        # System instruction.
        if llm_request.config and llm_request.config.system_instruction:
            total_chars += len(llm_request.config.system_instruction)
        # Tools (simple serialization for estimation).
        if llm_request.config and llm_request.config.tools:
            for tool in llm_request.config.tools:
                total_chars += len(str(tool))
        # Warmup prefix. The caller has not injected it yet at this point, but
        # it is part of the cacheable content, so count it here.
        total_chars += len(self.WARMUP_USER_TEXT) + len(self.WARMUP_MODEL_TEXT)
        return total_chars // 4
    async def before_model_callback(
        self, *, callback_context: CallbackContext, llm_request: LlmRequest
    ) -> Optional[LlmResponse]:
        """
        Intercepts the LLM request to:
        1. Inject the warmup prefix
        2. Calculate the global fingerprint
        3. Find or create the cache
        4. Apply the cache to the request (and strip cached contents)
        5. Disable the built-in cache manager
        """
        # 0. Check min_tokens (cost optimization). Only the stable part
        #    (system + tools + prefix) is estimated, since that is what gets cached.
        estimated_tokens = self._estimate_request_tokens(llm_request)
        if estimated_tokens < self.cache_config.min_tokens:
            logger.info(
                f"[Plugin] Skipping global cache. Estimated stable content "
                f"{estimated_tokens} < min_tokens {self.cache_config.min_tokens}")
            return None

        # 1. Inject the warmup prefix if not present.
        # Q: Why do we need this prefix?
        # A: The Gemini API forbids creating a cache with ONLY system
        #    instructions; it requires the 'contents' list to hold at least one
        #    turn. Since we want a global cache that works for ANY user query,
        #    we cannot include the user's actual first query (which changes
        #    every time). We therefore inject this minimal, stable "warmup
        #    turn" (# / #) to satisfy the API requirement while keeping the
        #    cache fingerprint stable.
        has_prefix = False
        if len(llm_request.contents) >= 2:
            first = llm_request.contents[0]
            if (first.role == 'user' and first.parts and
                    first.parts[0].text == self.WARMUP_USER_TEXT):
                has_prefix = True
        if not has_prefix:
            logger.info("[Plugin] Injecting warmup prefix")
            warmup_user = types.Content(role='user', parts=[types.Part(text=self.WARMUP_USER_TEXT)])
            warmup_model = types.Content(role='model', parts=[types.Part(text=self.WARMUP_MODEL_TEXT)])
            llm_request.contents.insert(0, warmup_model)
            llm_request.contents.insert(0, warmup_user)

        # 2. Calculate the fingerprint over everything that goes into the cache.
        params_str = ""
        # System instruction.
        if llm_request.config and llm_request.config.system_instruction:
            params_str += f"SYS:{llm_request.config.system_instruction}|"
        # Tools.
        if llm_request.config and llm_request.config.tools:
            for tool in llm_request.config.tools:
                params_str += f"TOOL:{str(tool)}|"
        # Warmup prefix.
        params_str += f"PREFIX:{self.WARMUP_USER_TEXT}:{self.WARMUP_MODEL_TEXT}"
        fingerprint = hashlib.sha256(params_str.encode("utf-8")).hexdigest()
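        # Any two requests built from the same system instruction, tool set,
        # and warmup prefix hash to the same fingerprint, so concurrent
        # sessions (and restarted processes, via the on-disk map) share one
        # server-side cache. Note the model name is not part of the hash; if
        # you run different models with identical prompts, consider folding
        # llm_request.model into params_str (a suggested tweak, not part of
        # the original logic).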
        # 3. Find or create the cache.
        cache_entry = self._GLOBAL_CACHE_MAP.get(fingerprint)
        cache_name = None
        if cache_entry:
            # Check expiry (with a 10s buffer).
            expire_time = cache_entry.get('expire_time')
            if expire_time and datetime.now(timezone.utc) < (expire_time - timedelta(seconds=10)):
                cache_name = cache_entry.get('name')
                ttl_seconds = self.cache_config.ttl_seconds or 600
                new_expire_time = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
                try:
                    # Extend the TTL on the API side.
                    self.client.caches.update(
                        name=cache_name,
                        config=types.UpdateCachedContentConfig(
                            ttl=self.cache_config.ttl_string
                        )
                    )
                    logger.info(
                        f"[Plugin] Global cache hit! ID: {cache_name} - "
                        f"Extended TTL to {new_expire_time.isoformat()}")
                    # Update the local map and persist it.
                    cache_entry['expire_time'] = new_expire_time
                    self._save_to_disk()
                except Exception as e:
                    # Still use the cache even if the extension fails.
                    logger.warning(f"[Plugin] Cache hit, but failed to extend TTL: {e}")
            else:
                # The stale entry is left in the map; it is overwritten below
                # once a replacement cache is created.
                logger.info(f"[Plugin] Global cache expired. ID: {cache_entry.get('name')}")
        if not cache_name:
            if not cache_entry:
                logger.info(
                    f"[Plugin] Global cache miss. Creating new cache for "
                    f"fingerprint: {fingerprint[:8]}...")
            try:
                # Content for the cache: the warmup prefix only (system
                # instruction and tools are attached to the config below).
                cache_contents = llm_request.contents[:2]
                req_cache_config = types.CreateCachedContentConfig(
                    contents=cache_contents,
                    ttl=self.cache_config.ttl_string,
                )
                if llm_request.config.system_instruction:
                    req_cache_config.system_instruction = llm_request.config.system_instruction
                if llm_request.config.tools:
                    req_cache_config.tools = llm_request.config.tools
                # Create the cache.
                model_name = llm_request.model or "gemini-2.5-flash"
                cached_content = self.client.caches.create(
                    model=model_name,
                    config=req_cache_config
                )
                cache_name = cached_content.name
                # Compute the expiration locally from the configured TTL
                # (used as a fallback; the response may not carry one).
                ttl_seconds = self.cache_config.ttl_seconds or 600
                expire_time = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
                # Prune the map before adding the new entry.
                self._prune_map()
                # Update the map and persist it.
                self._GLOBAL_CACHE_MAP[fingerprint] = {
                    'name': cache_name,
                    'expire_time': expire_time
                }
                self._save_to_disk()
                logger.info(f"[Plugin] Cache created: {cache_name} (expires: {expire_time.isoformat()})")
            except Exception as e:
                logger.error(f"[Plugin] Failed to create cache: {e}")
                return None

        # 4. Apply the cache to the request.
        if cache_name:
            llm_request.config.cached_content = cache_name
            # Strip the warmup prefix from contents (it lives in the cache).
            llm_request.contents = llm_request.contents[2:]
            # Remove the system instruction and tools from the request (also cached).
            llm_request.config.system_instruction = None
            llm_request.config.tools = None
            # We handled caching manually, so disable ADK's built-in cache manager.
            llm_request.cache_config = None
        return None
### How to use it:
# test_agent = Agent(
#     model=MODEL_NAME,
#     name='root_agent',
#     description="Tells the current time in a specified city.",
#     instruction=DYNAMIC_INSTRUCTION,
#     static_instruction=STATIC_SYSTEM_INSTRUCTION,
#     tools=[get_current_time],
# )
# cache_config = ContextCacheConfig(min_tokens=4096, ttl_seconds=600)
# app = App(
#     name="adk_agent",
#     root_agent=test_agent,
#     plugins=[GlobalCachingPlugin(context_cache_config=cache_config)],
# )