"""
GeminiImageDirectNode - Direct Gemini API Integration for ComfyUI
Clean implementation using google-genai SDK 1.55.0+ with no fallbacks or retries.
"""
import os
import base64
import time
from io import BytesIO
from typing import Optional, Tuple, Literal
from concurrent.futures import ThreadPoolExecutor
try:
    import google.genai as genai
    from google.genai.types import Content, Part, GenerateContentConfig, ImageConfig, Blob, HttpOptions
except ImportError:
    raise ImportError(
        "google-genai package not found. "
        "Please install it with: pip install google-genai>=1.55.0"
    )
from PIL import Image
import numpy as np
import torch

GEMINI_IMAGE_SYS_PROMPT = (
    "You are an expert image-generation engine. You must ALWAYS produce an image.\n"
    "Interpret all user input—regardless of "
    "format, intent, or abstraction—as literal visual directives for image composition.\n"
    "If a prompt is conversational or lacks specific visual details, "
    "you must creatively invent a concrete visual scenario that depicts the concept.\n"
    "Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
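# This default system prompt is wired into GenerateContentConfig.system_instruction
# in GeminiImageDirectNode.generate() below; it can be overridden per-node via the
# optional `system_prompt` input.
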
def get_number_of_images(images):
    """Get number of images from tensor or list."""
    if isinstance(images, torch.Tensor):
        return images.shape[0] if images.ndim >= 4 else 1
    return len(images) if images else 0

def tensor_to_image_bytes(image_tensor: torch.Tensor, mime_type: str = "image/png") -> bytes:
    """Convert [B, H, W, C] or [H, W, C] tensor to raw image bytes."""
    if len(image_tensor.shape) > 3:
        image_tensor = image_tensor[0]
    # Convert to numpy array (clamp to [0, 255], ensure contiguous uint8);
    # PIL encoding happens on the CPU, so the tensor is moved there directly.
    image_np = (image_tensor.detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    if not image_np.flags['C_CONTIGUOUS']:
        image_np = np.ascontiguousarray(image_np)
    # Convert to PIL Image
    pil_image = Image.fromarray(image_np)
    # Save to bytes
    img_byte_arr = BytesIO()
    pil_format = mime_type.split("/")[-1].upper()
    if pil_format == "JPG":
        pil_format = "JPEG"
    pil_image.save(img_byte_arr, format=pil_format)
    return img_byte_arr.getvalue()

def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
    """Convert image bytes to torch tensor."""
    image = Image.open(image_bytesio)
    image = image.convert(mode)
    image_array = np.array(image).astype(np.float32) / 255.0
    # Ensure contiguous array
    if not image_array.flags['C_CONTIGUOUS']:
        image_array = np.ascontiguousarray(image_array)
    return torch.from_numpy(image_array).unsqueeze(0)

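# Illustrative round trip (not executed): the two converters above are inverses
# up to the alpha channel. Shapes assume ComfyUI's standard IMAGE layout
# [B, H, W, C], float32 in [0, 1]; the tensor below is a dummy stand-in.
#
#     t = torch.rand(1, 64, 64, 3)                   # dummy [B, H, W, C] image
#     png = tensor_to_image_bytes(t, "image/png")    # PNG-encoded bytes
#     back = bytesio_to_image_tensor(BytesIO(png))   # [1, 64, 64, 4] RGBA tensor
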
def get_parts_by_type(response, part_type: Literal["text"] | str):
    """Filter response parts by their type ("text" or an image MIME type)."""
    if not hasattr(response, 'candidates') or not response.candidates:
        if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
            feedback = response.prompt_feedback
            if hasattr(feedback, 'block_reason') and feedback.block_reason:
                raise ValueError(
                    f"Gemini API blocked the request. Reason: {feedback.block_reason}"
                )
        raise ValueError(
            "Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
            "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
        )
    parts = []
    candidate = response.candidates[0]
    if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
        for part in candidate.content.parts:
            if part_type == "text" and hasattr(part, "text") and part.text:
                parts.append(part)
            elif hasattr(part, "inline_data") and part.inline_data:
                mime_type = getattr(part.inline_data, 'mime_type', None)
                if mime_type == part_type:
                    parts.append(part)
    return parts

def get_text_from_response(response) -> str:
    """Extract and concatenate all text parts from the response."""
    parts = get_parts_by_type(response, "text")
    return "\n".join([part.text for part in parts])

def get_image_from_response(response) -> torch.Tensor:
    """Extract images from response and return as tensor."""
    image_tensors = []
    # Try to get image/png parts first
    parts = get_parts_by_type(response, "image/png")
    # If no PNG images, try other image MIME types
    if len(parts) == 0:
        for mime_type in ["image/jpeg", "image/jpg", "image/webp"]:
            parts = get_parts_by_type(response, mime_type)
            if len(parts) > 0:
                break
    for part in parts:
        if hasattr(part, 'inline_data') and part.inline_data:
            image_data = part.inline_data.data
            if isinstance(image_data, str):
                image_bytes = base64.b64decode(image_data)
            else:
                image_bytes = image_data
            try:
                returned_image = bytesio_to_image_tensor(BytesIO(image_bytes))
                image_tensors.append(returned_image)
            except Exception as e:
                print(f"GeminiImageDirectNode - Warning: Failed to decode image part: {e}")
                continue
    if len(image_tensors) == 0:
        print("GeminiImageDirectNode - Warning: No images found in response, returning blank image")
        return torch.zeros((1, 1024, 1024, 4))
    return torch.cat(image_tensors, dim=0)

def validate_string(prompt: str, strip_whitespace: bool = True, min_length: int = 1):
    """Validate prompt string."""
    if prompt is None:
        raise ValueError("Prompt cannot be empty.")
    if strip_whitespace:
        prompt = prompt.strip()
    if len(prompt) < min_length:
        raise ValueError(f"Prompt cannot be shorter than {min_length} characters.")
    return prompt

class GeminiImageDirectNode:
    """
    Direct Gemini API node for ComfyUI.
    Uses GEMINI_API_KEY environment variable for authentication.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "model": ("STRING", {"default": "gemini-2.5-flash-image"}),
                "seed": ("INT", {"default": 42, "min": -2147483648, "max": 2147483647}),
                "aspect_ratio": (["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], {"default": "auto"}),
                "resolution": (["1K", "2K", "4K"], {"default": "2K"}),
                "response_modalities": (["IMAGE+TEXT", "IMAGE"], {"default": "IMAGE+TEXT"}),
            },
            "optional": {
                "images": ("IMAGE", {"forceInput": True}),
                "system_prompt": ("STRING", {"multiline": True, "default": GEMINI_IMAGE_SYS_PROMPT}),
            }
        }

    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("image", "text")
    FUNCTION = "generate"
    CATEGORY = "image/generation"

    def __init__(self):
        self.api_key = None
        self.client = None
        self._initialized = False

    def _initialize_gemini(self):
        """Initialize Gemini client with API key from environment."""
        if self._initialized and self.api_key == os.environ.get("GEMINI_API_KEY"):
            return
        self.api_key = os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "GEMINI_API_KEY environment variable not set. "
                "Please set it before using GeminiImageDirectNode."
            )
        # Initialize client with extended timeout for large image generation.
        # Default: 15 minutes (900000 ms), overridable via GEMINI_API_TIMEOUT_MS.
        # 4K images with pro models may need 20-30 minutes;
        # set GEMINI_API_TIMEOUT_MS=1800000 for 30 minutes if needed.
        default_timeout_ms = 900000  # 15 minutes (increased for large 4K images)
        timeout_ms = int(os.environ.get("GEMINI_API_TIMEOUT_MS", default_timeout_ms))
        http_options = HttpOptions(timeout=timeout_ms)
        self.client = genai.Client(api_key=self.api_key, http_options=http_options)
        self._initialized = True
        sdk_version = getattr(genai, '__version__', 'unknown')
        timeout_seconds = timeout_ms / 1000
        print(f"GeminiImageDirectNode - Initialized (SDK: {sdk_version}, timeout: {timeout_seconds}s)")

    def _create_image_parts(self, images: torch.Tensor) -> list:
        """Create image parts for Gemini API using Blob with raw bytes."""
        image_parts = []
        total_images = get_number_of_images(images)
        if total_images <= 0:
            return image_parts

        # Use ThreadPoolExecutor for parallel encoding (up to 12 workers for 12 vCPUs)
        max_workers = min(12, total_images)

        def encode_image(idx):
            image_tensor = images[idx] if images.ndim >= 4 else images
            image_bytes = tensor_to_image_bytes(image_tensor, mime_type="image/png")
            return Part(inline_data=Blob(mime_type="image/png", data=image_bytes))

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            image_parts = list(executor.map(encode_image, range(total_images)))
        return image_parts

    def generate(
        self,
        prompt: str,
        model: str = "gemini-2.5-flash-image",
        seed: int = 42,
        aspect_ratio: str = "auto",
        resolution: str = "2K",
        response_modalities: str = "IMAGE+TEXT",
        images: Optional[torch.Tensor] = None,
        system_prompt: str = GEMINI_IMAGE_SYS_PROMPT,
    ) -> Tuple[torch.Tensor, str]:
        """Generate image using Gemini API directly."""
        validate_string(prompt, strip_whitespace=True, min_length=1)
        self._initialize_gemini()

        # Build contents - matching SDK test patterns exactly.
        # SDK tests show contents can be:
        #   1. A string: contents='text'
        #   2. A list: contents=['text', {'inline_data': {'data': bytes}}]
        #   3. Content objects: contents=[Content(role='user', parts=[...])]
        # We use Content objects for clarity, matching the test_image_base64 pattern.
        parts = [Part(text=prompt)]
        # Add images - matching SDK test pattern: Part(inline_data=Blob(data=bytes, mime_type='image/png'))
        if images is not None:
            if get_number_of_images(images) > 14:
                raise ValueError("The current maximum number of supported images is 14.")
            parts.extend(self._create_image_parts(images))
        # Build contents list - SDK accepts Content objects or strings
        user_content = Content(role="user", parts=parts)
        contents = [user_content]

        # Build ImageConfig.
        # Note: gemini-3-pro-image-preview supports image_size (1K, 2K, 4K) per SDK docs;
        # gemini-2.5-flash-image may not, so image_size is only added for pro models.
        image_config_obj = None
        image_config_kwargs = {}
        if aspect_ratio != "auto":
            image_config_kwargs["aspect_ratio"] = aspect_ratio
        if resolution and "pro" in model.lower():
            image_config_kwargs["image_size"] = resolution
        if image_config_kwargs:
            image_config_obj = ImageConfig(**image_config_kwargs)

        # Build GenerateContentConfig - matching the SDK test pattern:
        # GenerateContentConfig(response_modalities=['IMAGE'], image_config=ImageConfig(aspect_ratio='16:9'))
        config_kwargs = {}
        # response_modalities: ["IMAGE"] or ["TEXT", "IMAGE"] - required for image generation
        if response_modalities:
            config_kwargs["response_modalities"] = (
                ["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]
            )
        if image_config_obj:
            config_kwargs["image_config"] = image_config_obj
        # system_instruction - SDK accepts string or Content; a plain string is simplest
        # (the SDK handles conversion internally)
        if system_prompt and system_prompt.strip():
            config_kwargs["system_instruction"] = system_prompt.strip()
        # seed - SDK tests show config={'seed': 42} as int; clamp to INT32 range
        if seed != 0:
            INT32_MAX = 2147483647
            INT32_MIN = -2147483648
            clamped_seed = seed
            if seed > INT32_MAX:
                print(f"GeminiImageDirectNode - Warning: Seed {seed} exceeds INT32 max, clamping to {INT32_MAX}")
                clamped_seed = INT32_MAX
            elif seed < INT32_MIN:
                print(f"GeminiImageDirectNode - Warning: Seed {seed} below INT32 min, clamping to {INT32_MIN}")
                clamped_seed = INT32_MIN
            config_kwargs["seed"] = int(clamped_seed)
        generation_config = GenerateContentConfig(**config_kwargs) if config_kwargs else None

        # Call API
        print(f"GeminiImageDirectNode - Calling {model} with {len(parts)} part(s)")
        if generation_config:
            print(f"GeminiImageDirectNode - Config: {generation_config.model_dump(exclude_none=True)}")
        start_time = time.time()
        try:
            if generation_config:
                response = self.client.models.generate_content(
                    model=model,
                    contents=contents,
                    config=generation_config,
                )
            else:
                response = self.client.models.generate_content(
                    model=model,
                    contents=contents,
                )
        except Exception as e:
            # Print detailed error information
            error_type = type(e).__name__
            error_msg = str(e)
            print(f"GeminiImageDirectNode - Error: {error_type}: {error_msg}")
            # Try to get more details from the exception's HTTP response
            if hasattr(e, 'response') and hasattr(e.response, 'json'):
                try:
                    error_json = e.response.json()
                    print(f"GeminiImageDirectNode - Error details: {error_json}")
                except Exception:
                    pass
            # Try to get error details from the exception object itself
            if hasattr(e, 'error'):
                print(f"GeminiImageDirectNode - Exception error attribute: {e.error}")
            # Print config for debugging
            if generation_config:
                try:
                    config_dict = generation_config.model_dump(exclude_none=True)
                    print(f"GeminiImageDirectNode - Config that failed: {config_dict}")
                except Exception as dump_e:
                    print(f"GeminiImageDirectNode - Could not dump config: {dump_e}")
            # Print contents structure for debugging
            try:
                print(f"GeminiImageDirectNode - Contents: {len(contents)} content(s)")
                for i, content in enumerate(contents):
                    print(f"GeminiImageDirectNode - Content {i}: role={content.role}, parts={len(content.parts) if content.parts else 0}")
                    if content.parts:
                        for j, part in enumerate(content.parts):
                            part_type = "text" if hasattr(part, 'text') and part.text else "inline_data" if hasattr(part, 'inline_data') and part.inline_data else "unknown"
                            print(f"GeminiImageDirectNode - Part {j}: type={part_type}")
            except Exception as contents_e:
                print(f"GeminiImageDirectNode - Could not print contents: {contents_e}")
            raise

        api_time = time.time() - start_time

        # Parse response
        output_image = get_image_from_response(response)
        output_text = get_text_from_response(response)

        # Log response details for debugging
        num_images = output_image.shape[0] if isinstance(output_image, torch.Tensor) else 0
        image_shape = output_image.shape if isinstance(output_image, torch.Tensor) else "N/A"
        print(f"GeminiImageDirectNode - Extracted {num_images} image(s), shape: {image_shape}")

        # Add timing info to text output
        timing_info = f"\n\n[API call: {api_time:.2f}s]"
        output_text = output_text + timing_info if output_text else timing_info.strip()
        print(f"GeminiImageDirectNode - Generation complete ({api_time:.2f}s)")
        return (output_image, output_text)

NODE_CLASS_MAPPINGS = {
    "GeminiImageDirectNode": GeminiImageDirectNode
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "GeminiImageDirectNode": "Gemini Image Direct"
}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
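
# Minimal standalone smoke test (a sketch, not part of the node's contract):
# runs the node outside ComfyUI, assuming GEMINI_API_KEY is set in the
# environment and the google-genai package is installed. The prompt text and
# output path below are illustrative placeholders.
if __name__ == "__main__":
    node = GeminiImageDirectNode()
    image, text = node.generate(
        prompt="A lighthouse on a cliff at sunset",  # placeholder prompt
        model="gemini-2.5-flash-image",
        seed=42,
    )
    print(f"Returned tensor shape: {tuple(image.shape)}")  # e.g. (1, H, W, 4)
    print(f"Model text/timing output: {text}")
    # Save the first image for inspection (path is illustrative)
    Image.fromarray((image[0].numpy() * 255).astype(np.uint8)).save("gemini_test.png")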