Skip to content

Instantly share code, notes, and snippets.

@SubhiH
Created February 20, 2026 15:22
Show Gist options
  • Select an option

  • Save SubhiH/739750cc0837a37ba9835ae3c1fa1e30 to your computer and use it in GitHub Desktop.

Select an option

Save SubhiH/739750cc0837a37ba9835ae3c1fa1e30 to your computer and use it in GitHub Desktop.
from typing import Optional, Dict, Any
import hashlib
import json
import os
import logging
from datetime import datetime, timedelta, timezone
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai import types
from google.genai import Client
logger = logging.getLogger(__name__)
from google.adk.agents.context_cache_config import ContextCacheConfig
class GlobalCachingPlugin(BasePlugin):
    """A plugin that implements GLOBAL Context Caching.

    It enforces a stable prefix (a minimal "warmup" turn) so that the
    cacheable part of every request (system instruction + tools + prefix)
    is identical across sessions, then manually creates/reuses explicit
    Gemini caches keyed by a fingerprint of that stable part. The cache
    map is persisted to disk so caches survive process restarts.
    """

    # We use minimal tokens for the warmup to save cost/latency.
    WARMUP_USER_TEXT = "#"
    WARMUP_MODEL_TEXT = "#"
    # Limit the number of cache entries in memory/file to prevent unbounded growth.
    MAX_CACHE_ENTRIES = 50
    CACHE_FILE_PATH = os.path.abspath(os.path.expanduser("./.adk/global_cache.json"))
    # Global cache map: fingerprint -> {'name': str, 'expire_time': datetime}.
    # Class-level on purpose: shared across all instances of the plugin.
    _GLOBAL_CACHE_MAP: Dict[str, Dict[str, Any]] = {}

    def __init__(self, api_key: Optional[str] = None,
                 context_cache_config: Optional[ContextCacheConfig] = None):
        """Initialize the plugin.

        Args:
            api_key: Gemini API key; falls back to the GOOGLE_API_KEY env var.
            context_cache_config: TTL / min-token settings; defaults to a
                600-second TTL when not provided.
        """
        super().__init__(name="global_caching_plugin")
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        if not self.api_key:
            logger.warning("GlobalCachingPlugin initialized without API Key. Cache creation might fail.")
        self.cache_config = context_cache_config or ContextCacheConfig(ttl_seconds=600)
        # Client used for the explicit cache create/update API calls.
        self.client = Client(api_key=self.api_key)
        # Load any previously persisted cache map from disk.
        logger.info(f"[Plugin] Cache File Path: {self.CACHE_FILE_PATH}")
        self._load_from_disk()

    def _load_from_disk(self):
        """Loads the cache map from the JSON file, parsing ISO datetimes.

        Best-effort: a missing or corrupt file just means a cold cache.
        """
        if not os.path.exists(self.CACHE_FILE_PATH):
            return
        try:
            with open(self.CACHE_FILE_PATH, 'r') as f:
                data = json.load(f)
            # Parse datetimes and populate the shared map.
            for fingerprint, entry in data.items():
                if 'expire_time' in entry:
                    # ISO format string back to datetime.
                    entry['expire_time'] = datetime.fromisoformat(entry['expire_time'])
                self._GLOBAL_CACHE_MAP[fingerprint] = entry
            logger.info(f"[Plugin] Loaded {len(self._GLOBAL_CACHE_MAP)} cache entries from {self.CACHE_FILE_PATH}")
        except Exception as e:
            logger.warning(f"[Plugin] Failed to load cache from disk: {e}")

    def _save_to_disk(self):
        """Saves the cache map to the JSON file (datetimes as ISO strings)."""
        try:
            # Ensure directory exists.
            os.makedirs(os.path.dirname(self.CACHE_FILE_PATH), exist_ok=True)
            # Convert datetimes to strings for JSON serialization.
            data_to_save = {}
            for fp, entry in self._GLOBAL_CACHE_MAP.items():
                entry_copy = entry.copy()
                if isinstance(entry_copy.get('expire_time'), datetime):
                    entry_copy['expire_time'] = entry_copy['expire_time'].isoformat()
                data_to_save[fp] = entry_copy
            with open(self.CACHE_FILE_PATH, 'w') as f:
                json.dump(data_to_save, f, indent=2)
            logger.info(f"[Plugin] Saved cache map to: {self.CACHE_FILE_PATH}")
        except Exception as e:
            logger.error(f"[Plugin] Failed to save cache to disk: {e}")

    def _prune_map(self):
        """Removes expired entries and enforces MAX_CACHE_ENTRIES (FIFO).

        Persists to disk only when something was removed; callers that add
        a new entry afterwards are responsible for saving again.
        """
        now = datetime.now(timezone.utc)
        changed = False
        # 1. Collect keys of expired entries (with a 10s safety buffer).
        keys_to_remove = []
        for fp, entry in self._GLOBAL_CACHE_MAP.items():
            expire_time = entry.get('expire_time')
            if expire_time and now > (expire_time - timedelta(seconds=10)):
                keys_to_remove.append(fp)
        if keys_to_remove:
            logger.info(f"[Plugin] Pruning {len(keys_to_remove)} expired cache entries.")
            for k in keys_to_remove:
                del self._GLOBAL_CACHE_MAP[k]
            changed = True
        # 2. If still too big, evict oldest first (dicts preserve insertion order).
        while len(self._GLOBAL_CACHE_MAP) >= self.MAX_CACHE_ENTRIES:
            oldest_key = next(iter(self._GLOBAL_CACHE_MAP))
            del self._GLOBAL_CACHE_MAP[oldest_key]
            logger.info(f"[Plugin] Pruning oldest cache entry to enforce limit {self.MAX_CACHE_ENTRIES}.")
            changed = True
        if changed:
            self._save_to_disk()

    def _estimate_request_tokens(self, llm_request: LlmRequest) -> int:
        """Rough token estimate (chars // 4) of the STABLE, cacheable part.

        Only system instruction + tools + the warmup prefix are counted,
        because only those end up in the cache; the user's dynamic query
        is never cached and therefore irrelevant to the min_tokens check.
        """
        total_chars = 0
        # System instruction. str() keeps this consistent with the f-string
        # used for fingerprinting and tolerant of non-string instruction types.
        if llm_request.config and llm_request.config.system_instruction:
            total_chars += len(str(llm_request.config.system_instruction))
        # Tools: simple string serialization is good enough for an estimate.
        if llm_request.config and llm_request.config.tools:
            for tool in llm_request.config.tools:
                tool_str = str(tool)
                total_chars += len(tool_str)
        # Warmup prefix turns (the only cached 'contents').
        total_chars += len(self.WARMUP_USER_TEXT) + len(self.WARMUP_MODEL_TEXT)
        # ~4 characters per token is a common rough heuristic.
        return total_chars // 4

    async def before_model_callback(
        self, *, callback_context: CallbackContext, llm_request: LlmRequest
    ) -> Optional[LlmResponse]:
        """Intercepts the LLM request to:

        1. Inject the warmup prefix
        2. Calculate the global fingerprint
        3. Find or create the cache
        4. Apply the cache to the request (and strip cached content)
        5. Disable the built-in cache manager

        Always returns None so the request proceeds to the model.
        """
        # 0. Check min_tokens (cost optimization). We only cache the stable
        # part (system + tools + prefix), so only that is estimated.
        estimated_tokens = self._estimate_request_tokens(llm_request)
        if estimated_tokens < self.cache_config.min_tokens:
            logger.info(f"[Plugin] Skipping Global Cache. Estimated stable output {estimated_tokens} < min_tokens {self.cache_config.min_tokens}")
            return None
        # 1. Inject the warmup prefix if not present.
        # Why a prefix at all? The Gemini API forbids creating a cache with
        # ONLY system instructions: 'contents' must contain at least one turn.
        # A global cache must work for ANY user query, so we cannot include
        # the user's actual (ever-changing) first query. The minimal, stable
        # "# / #" turn satisfies the API while keeping the fingerprint stable.
        has_prefix = False
        if len(llm_request.contents) >= 2:
            first = llm_request.contents[0]
            if (first.role == 'user' and first.parts and
                    first.parts[0].text == self.WARMUP_USER_TEXT):
                has_prefix = True
        injected_prefix = False
        if not has_prefix:
            logger.info("[Plugin] Injecting Warmup Prefix")
            warmup_user = types.Content(role='user', parts=[types.Part(text=self.WARMUP_USER_TEXT)])
            warmup_model = types.Content(role='model', parts=[types.Part(text=self.WARMUP_MODEL_TEXT)])
            llm_request.contents.insert(0, warmup_model)
            llm_request.contents.insert(0, warmup_user)
            injected_prefix = True
        # 2. Fingerprint the stable part (system + tools + prefix).
        params_str = ""
        if llm_request.config and llm_request.config.system_instruction:
            params_str += f"SYS:{llm_request.config.system_instruction}|"
        if llm_request.config and llm_request.config.tools:
            for tool in llm_request.config.tools:
                params_str += f"TOOL:{str(tool)}|"
        params_str += f"PREFIX:{self.WARMUP_USER_TEXT}:{self.WARMUP_MODEL_TEXT}"
        fingerprint = hashlib.sha256(params_str.encode("utf-8")).hexdigest()
        # 3. Find or create the cache.
        cache_entry = self._GLOBAL_CACHE_MAP.get(fingerprint)
        cache_name = None
        if cache_entry:
            # Reuse only if not expired (with a 10s safety buffer).
            expire_time = cache_entry.get('expire_time')
            if expire_time and datetime.now(timezone.utc) < (expire_time - timedelta(seconds=10)):
                cache_name = cache_entry.get('name')
                ttl_seconds = self.cache_config.ttl_seconds or 600
                new_expire_time = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
                try:
                    # Extend the TTL server-side on every hit.
                    self.client.caches.update(
                        name=cache_name,
                        config=types.UpdateCachedContentConfig(
                            ttl=self.cache_config.ttl_string
                        )
                    )
                    logger.info(f"[Plugin] Global Cache Hit! ID: {cache_name} - Extended TTL to {new_expire_time.isoformat()}")
                    # Mirror the new expiry locally and persist it.
                    cache_entry['expire_time'] = new_expire_time
                    self._save_to_disk()
                except Exception as e:
                    # Still use the cache even if extension fails.
                    logger.warning(f"[Plugin] Cache Hit, but failed to extend TTL: {e}")
            else:
                # Stale entry; the creation below overwrites it in the map.
                logger.info(f"[Plugin] Global Cache Expired. ID: {cache_entry.get('name')}")
        if not cache_name:
            if not cache_entry:
                logger.info(f"[Plugin] Global Cache Miss. Creating new cache for fingerprint: {fingerprint[:8]}...")
            try:
                # The cache holds only the two prefix turns, plus system + tools.
                cache_contents = llm_request.contents[:2]
                req_cache_config = types.CreateCachedContentConfig(
                    contents=cache_contents,
                    ttl=self.cache_config.ttl_string,
                )
                if llm_request.config.system_instruction:
                    req_cache_config.system_instruction = llm_request.config.system_instruction
                if llm_request.config.tools:
                    req_cache_config.tools = llm_request.config.tools
                # Create the cache.
                model_name = llm_request.model or "gemini-2.5-flash"
                cached_content = self.client.caches.create(
                    model=model_name,
                    config=req_cache_config
                )
                cache_name = cached_content.name
                # Use the configured TTL as the local expiry estimate
                # (fallback 600s when not configured).
                ttl_seconds = self.cache_config.ttl_seconds or 600
                expire_time = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
                # Make room before adding the new entry, then persist.
                self._prune_map()
                self._GLOBAL_CACHE_MAP[fingerprint] = {
                    'name': cache_name,
                    'expire_time': expire_time
                }
                self._save_to_disk()
                logger.info(f"[Plugin] Cache Created: {cache_name} (Expires: {expire_time.isoformat()})")
            except Exception as e:
                logger.error(f"[Plugin] Failed to create cache: {e}")
                # Roll back the warmup prefix we injected so the model does
                # not receive the synthetic "#"/"#" turns uncached. A prefix
                # that was already present before this call is left intact.
                if injected_prefix:
                    llm_request.contents = llm_request.contents[2:]
                return None
        # 4. Apply the cache to the request.
        if cache_name:
            llm_request.config.cached_content = cache_name
            # STRIP the prefix from contents (it lives in the cache now).
            llm_request.contents = llm_request.contents[2:]
            # Remove system instruction & tools from the request (also cached).
            llm_request.config.system_instruction = None
            llm_request.config.tools = None
        # We handled caching manually; disable ADK's built-in cache manager.
        llm_request.cache_config = None
        return None
### How to use it:
# test_agent = Agent(
#     model=MODEL_NAME,
#     name='root_agent',
#     description="Tells the current time in a specified city.",
#     instruction=DYNAMIC_INSTRUCTION,
#     static_instruction=STATIC_SYSTEM_INSTRUCTION,
#     tools=[get_current_time],
# )
# cache_config = ContextCacheConfig(min_tokens=4096, ttl_seconds=600)
# app = App(
#     name="adk_agent",
#     root_agent=test_agent,
#     plugins=[GlobalCachingPlugin(context_cache_config=cache_config)],
# )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment