| """ | |
| GeminiImageDirectNode - Direct Gemini API Integration for ComfyUI | |
| Clean implementation using google-genai SDK 1.55.0+ with no fallbacks or retries. | |
| """ | |
| import os | |
| import base64 | |
| import time | |
| import traceback | |
| from io import BytesIO | |
| from typing import Optional, Tuple, Literal | |
| from concurrent.futures import ThreadPoolExecutor | |
try:
    import google.genai as genai
    from google.genai.types import Content, Part, GenerateContentConfig, ImageConfig, Blob, HttpOptions
except ImportError:
    raise ImportError(
        "google-genai package not found. "
        "Please install it with: pip install google-genai>=1.55.0"
    )

from PIL import Image
import numpy as np
import torch

GEMINI_IMAGE_SYS_PROMPT = (
    "You are an expert image-generation engine. You must ALWAYS produce an image.\n"
    "Interpret all user input—regardless of "
    "format, intent, or abstraction—as literal visual directives for image composition.\n"
    "If a prompt is conversational or lacks specific visual details, "
    "you must creatively invent a concrete visual scenario that depicts the concept.\n"
    "Prioritize generating the visual representation above any text, formatting, or conversational requests."
)


def get_number_of_images(images):
    """Get number of images from tensor or list."""
    if isinstance(images, torch.Tensor):
        return images.shape[0] if images.ndim >= 4 else 1
    return len(images) if images else 0


def tensor_to_image_bytes(image_tensor: torch.Tensor, mime_type: str = "image/png") -> bytes:
    """Convert [B, H, W, C] or [H, W, C] tensor to raw image bytes."""
    if len(image_tensor.shape) > 3:
        image_tensor = image_tensor[0]
    # Convert to a uint8 numpy array; clamp to [0, 1] first so out-of-range
    # values don't wrap around during the uint8 cast
    image_np = (image_tensor.detach().cpu().clamp(0.0, 1.0).numpy() * 255).astype(np.uint8)
    if not image_np.flags['C_CONTIGUOUS']:
        image_np = np.ascontiguousarray(image_np)
    # Convert to PIL Image
    pil_image = Image.fromarray(image_np)
    # Save to bytes
    img_byte_arr = BytesIO()
    pil_format = mime_type.split("/")[-1].upper()
    if pil_format == "JPG":
        pil_format = "JPEG"
    pil_image.save(img_byte_arr, format=pil_format)
    return img_byte_arr.getvalue()


def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
    """Convert image bytes to torch tensor."""
    image = Image.open(image_bytesio)
    image = image.convert(mode)
    image_array = np.array(image).astype(np.float32) / 255.0
    # Ensure contiguous array
    if not image_array.flags['C_CONTIGUOUS']:
        image_array = np.ascontiguousarray(image_array)
    return torch.from_numpy(image_array).unsqueeze(0)
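

# --- Illustrative sketch (not part of the original gist) ---
# A minimal round-trip check for the two converters above: encode a random
# [H, W, C] tensor to PNG bytes, then decode it back. `_demo_tensor_roundtrip`
# is a name introduced here purely for illustration; run it manually if you
# want to sanity-check the conversion path.
def _demo_tensor_roundtrip():
    src = torch.rand(64, 64, 3)  # random RGB image with values in [0, 1]
    png_bytes = tensor_to_image_bytes(src, mime_type="image/png")
    decoded = bytesio_to_image_tensor(BytesIO(png_bytes), mode="RGB")
    # Decoding adds a batch dimension: [1, H, W, C]
    assert decoded.shape == (1, 64, 64, 3), decoded.shape

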
def get_parts_by_type(response, part_type: str):
    """Filter response parts by type: "text" or an image MIME type (e.g. "image/png")."""
    if not hasattr(response, 'candidates') or not response.candidates:
        if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
            feedback = response.prompt_feedback
            if hasattr(feedback, 'block_reason') and feedback.block_reason:
                raise ValueError(
                    f"Gemini API blocked the request. Reason: {feedback.block_reason}"
                )
        raise ValueError(
            "Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
            "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
        )
    parts = []
    candidate = response.candidates[0]
    if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
        for part in candidate.content.parts:
            if part_type == "text" and hasattr(part, "text") and part.text:
                parts.append(part)
            elif hasattr(part, "inline_data") and part.inline_data:
                mime_type = getattr(part.inline_data, 'mime_type', None)
                if mime_type == part_type:
                    parts.append(part)
    return parts


def get_text_from_response(response) -> str:
    """Extract and concatenate all text parts from the response."""
    parts = get_parts_by_type(response, "text")
    return "\n".join([part.text for part in parts])


def get_image_from_response(response) -> torch.Tensor:
    """Extract images from response and return as tensor."""
    image_tensors = []
    # Try to get image/png parts first
    parts = get_parts_by_type(response, "image/png")
    # If no PNG images, try other image MIME types
    if len(parts) == 0:
        for mime_type in ["image/jpeg", "image/jpg", "image/webp"]:
            parts = get_parts_by_type(response, mime_type)
            if len(parts) > 0:
                break
    for part in parts:
        if hasattr(part, 'inline_data') and part.inline_data:
            image_data = part.inline_data.data
            if isinstance(image_data, str):
                image_bytes = base64.b64decode(image_data)
            else:
                image_bytes = image_data
            try:
                returned_image = bytesio_to_image_tensor(BytesIO(image_bytes))
                image_tensors.append(returned_image)
            except Exception as e:
                print(f"GeminiImageDirectNode - Warning: Failed to decode image part: {e}")
                continue
    if len(image_tensors) == 0:
        print("GeminiImageDirectNode - Warning: No images found in response, returning blank image")
        return torch.zeros((1, 1024, 1024, 4))
    return torch.cat(image_tensors, dim=0)
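

# --- Illustrative sketch (not part of the original gist) ---
# The parsers above only depend on the `candidates[0].content.parts` shape of
# a Gemini response, so they can be exercised with SimpleNamespace stand-ins.
# `_demo_parse_mock_response` is a name introduced here for illustration; a
# real `generate_content` response exposes the same attributes.
def _demo_parse_mock_response():
    from types import SimpleNamespace
    # Build a tiny real PNG so get_image_from_response has something to decode
    buf = BytesIO()
    Image.new("RGB", (8, 8), color=(255, 0, 0)).save(buf, format="PNG")
    text_part = SimpleNamespace(text="a red square", inline_data=None)
    image_part = SimpleNamespace(
        text=None,
        inline_data=SimpleNamespace(mime_type="image/png", data=buf.getvalue()),
    )
    content = SimpleNamespace(parts=[text_part, image_part])
    response = SimpleNamespace(
        candidates=[SimpleNamespace(content=content)],
        prompt_feedback=None,
    )
    assert get_text_from_response(response) == "a red square"
    assert get_image_from_response(response).shape == (1, 8, 8, 4)

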
def validate_string(prompt: str, strip_whitespace: bool = True, min_length: int = 1):
    """Validate prompt string."""
    if prompt is None:
        raise ValueError("Prompt cannot be empty.")
    if strip_whitespace:
        prompt = prompt.strip()
    if len(prompt) < min_length:
        raise ValueError(f"Prompt cannot be shorter than {min_length} characters.")
    return prompt


class GeminiImageDirectNode:
    """
    Direct Gemini API node for ComfyUI.
    Uses GEMINI_API_KEY environment variable for authentication.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "model": ("STRING", {"default": "gemini-2.5-flash-image"}),
                "seed": ("INT", {"default": 42, "min": -2147483648, "max": 2147483647}),
                "aspect_ratio": (["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], {"default": "auto"}),
                "resolution": (["1K", "2K", "4K"], {"default": "2K"}),
                "response_modalities": (["IMAGE+TEXT", "IMAGE"], {"default": "IMAGE+TEXT"}),
            },
            "optional": {
                "images": ("IMAGE", {"forceInput": True}),
                "system_prompt": ("STRING", {"multiline": True, "default": GEMINI_IMAGE_SYS_PROMPT}),
            }
        }

    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("image", "text")
    FUNCTION = "generate"
    CATEGORY = "image/generation"

    def __init__(self):
        self.api_key = None
        self.client = None
        self._initialized = False

    def _initialize_gemini(self):
        """Initialize Gemini client with API key from environment."""
        if self._initialized and self.api_key == os.environ.get("GEMINI_API_KEY"):
            return
        self.api_key = os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "GEMINI_API_KEY environment variable not set. "
                "Please set it before using GeminiImageDirectNode."
            )
        # Initialize client with an extended timeout for large image generation.
        # Default: 15 minutes (900000 ms), overridable via the GEMINI_API_TIMEOUT_MS
        # env var. 4K images with pro models may need 20-30 minutes; set
        # GEMINI_API_TIMEOUT_MS=1800000 for 30 minutes if needed.
        default_timeout_ms = 900000  # 15 minutes (increased for large 4K images)
        timeout_ms = int(os.environ.get("GEMINI_API_TIMEOUT_MS", default_timeout_ms))
        http_options = HttpOptions(timeout=timeout_ms)
        self.client = genai.Client(api_key=self.api_key, http_options=http_options)
        self._initialized = True
        sdk_version = getattr(genai, '__version__', 'unknown')
        timeout_seconds = timeout_ms / 1000
        print(f"GeminiImageDirectNode - Initialized (SDK: {sdk_version}, timeout: {timeout_seconds}s)")

    def _create_image_parts(self, images: torch.Tensor) -> list:
        """Create image parts for Gemini API using Blob with raw bytes."""
        image_parts = []
        total_images = get_number_of_images(images)
        if total_images <= 0:
            return image_parts
        # Use ThreadPoolExecutor for parallel encoding (up to 12 workers for 12 vCPUs)
        max_workers = min(12, total_images)

        def encode_image(idx):
            image_tensor = images[idx] if images.ndim >= 4 else images
            image_bytes = tensor_to_image_bytes(image_tensor, mime_type="image/png")
            return Part(inline_data=Blob(mime_type="image/png", data=image_bytes))

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            image_parts = list(executor.map(encode_image, range(total_images)))
        return image_parts

    def generate(
        self,
        prompt: str,
        model: str = "gemini-2.5-flash-image",
        seed: int = 42,
        aspect_ratio: str = "auto",
        resolution: str = "2K",
        response_modalities: str = "IMAGE+TEXT",
        images: Optional[torch.Tensor] = None,
        system_prompt: str = GEMINI_IMAGE_SYS_PROMPT,
    ) -> Tuple[torch.Tensor, str]:
        """Generate image using Gemini API directly."""
        validate_string(prompt, strip_whitespace=True, min_length=1)
        self._initialize_gemini()
        # Build contents - matching SDK test patterns exactly. SDK tests show
        # contents can be:
        #   1. A string: contents='text'
        #   2. A list: contents=['text', {'inline_data': {'data': bytes}}]
        #   3. Content objects: contents=[Content(role='user', parts=[...])]
        # We use Content objects for clarity, matching the test_image_base64 pattern.
        parts = [Part(text=prompt)]
        # Add images - matching SDK test pattern: Part(inline_data=Blob(data=bytes, mime_type='image/png'))
        if images is not None:
            if get_number_of_images(images) > 14:
                raise ValueError("The current maximum number of supported images is 14.")
            parts.extend(self._create_image_parts(images))
        # Build contents list - SDK accepts Content objects or strings
        user_content = Content(role="user", parts=parts)
        contents = [user_content]
        # Build ImageConfig.
        # Note: gemini-3-pro-image-preview supports image_size (1K, 2K, 4K) per
        # SDK docs; gemini-2.5-flash-image may not, so image_size is only added
        # for pro models.
        image_config_obj = None
        image_config_kwargs = {}
        if aspect_ratio != "auto":
            image_config_kwargs["aspect_ratio"] = aspect_ratio
        if resolution and "pro" in model.lower():
            image_config_kwargs["image_size"] = resolution
        if image_config_kwargs:
            image_config_obj = ImageConfig(**image_config_kwargs)
        # Build GenerateContentConfig - matching SDK test pattern:
        # GenerateContentConfig(response_modalities=['IMAGE'], image_config=ImageConfig(aspect_ratio='16:9'))
        config_kwargs = {}
        # response_modalities: ["IMAGE"] or ["TEXT", "IMAGE"] - required for image generation
        if response_modalities:
            config_kwargs["response_modalities"] = (
                ["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]
            )
        if image_config_obj:
            config_kwargs["image_config"] = image_config_obj
        # system_instruction - SDK accepts string or Content (tests use Content
        # via t.t_content()); we pass a string and let the SDK convert it
        if system_prompt and system_prompt.strip():
            config_kwargs["system_instruction"] = system_prompt.strip()
        # seed - passed as a plain int (SDK test pattern: config={'seed': 42}),
        # clamped to the INT32 range; a seed of 0 is treated as unset and omitted
        if seed != 0:
            INT32_MAX = 2147483647
            INT32_MIN = -2147483648
            clamped_seed = seed
            if seed > INT32_MAX:
                print(f"GeminiImageDirectNode - Warning: Seed {seed} exceeds INT32 max, clamping to {INT32_MAX}")
                clamped_seed = INT32_MAX
            elif seed < INT32_MIN:
                print(f"GeminiImageDirectNode - Warning: Seed {seed} below INT32 min, clamping to {INT32_MIN}")
                clamped_seed = INT32_MIN
            config_kwargs["seed"] = int(clamped_seed)
        generation_config = GenerateContentConfig(**config_kwargs) if config_kwargs else None
        # Call API
        print(f"GeminiImageDirectNode - Calling {model} with {len(parts)} part(s)")
        if generation_config:
            print(f"GeminiImageDirectNode - Config: {generation_config.model_dump(exclude_none=True)}")
        start_time = time.time()
        try:
            if generation_config:
                response = self.client.models.generate_content(
                    model=model,
                    contents=contents,
                    config=generation_config,
                )
            else:
                response = self.client.models.generate_content(
                    model=model,
                    contents=contents,
                )
        except Exception as e:
            # Print detailed error information
            error_type = type(e).__name__
            error_msg = str(e)
            print(f"GeminiImageDirectNode - Error: {error_type}: {error_msg}")
            # Try to get more details from the exception
            if hasattr(e, 'response') and hasattr(e.response, 'json'):
                try:
                    error_json = e.response.json()
                    print(f"GeminiImageDirectNode - Error details: {error_json}")
                except Exception:
                    pass
            # Try to get error details from the exception object itself
            if hasattr(e, 'error'):
                print(f"GeminiImageDirectNode - Exception error attribute: {e.error}")
            # Print config for debugging
            if generation_config:
                try:
                    config_dict = generation_config.model_dump(exclude_none=True)
                    print(f"GeminiImageDirectNode - Config that failed: {config_dict}")
                except Exception as dump_e:
                    print(f"GeminiImageDirectNode - Could not dump config: {dump_e}")
            # Print contents structure for debugging
            try:
                print(f"GeminiImageDirectNode - Contents: {len(contents)} content(s)")
                for i, content in enumerate(contents):
                    print(f"GeminiImageDirectNode - Content {i}: role={content.role}, parts={len(content.parts) if content.parts else 0}")
                    if content.parts:
                        for j, part in enumerate(content.parts):
                            part_type = (
                                "text" if hasattr(part, 'text') and part.text
                                else "inline_data" if hasattr(part, 'inline_data') and part.inline_data
                                else "unknown"
                            )
                            print(f"GeminiImageDirectNode - Part {j}: type={part_type}")
            except Exception as contents_e:
                print(f"GeminiImageDirectNode - Could not print contents: {contents_e}")
            raise
        api_time = time.time() - start_time
        # Parse response
        output_image = get_image_from_response(response)
        output_text = get_text_from_response(response)
        # Log response details for debugging
        num_images = output_image.shape[0] if isinstance(output_image, torch.Tensor) else 0
        image_shape = output_image.shape if isinstance(output_image, torch.Tensor) else "N/A"
        print(f"GeminiImageDirectNode - Extracted {num_images} image(s), shape: {image_shape}")
        # Add timing info to text output
        timing_info = f"\n\n[API call: {api_time:.2f}s]"
        output_text = (output_text + timing_info) if output_text else timing_info.strip()
        print(f"GeminiImageDirectNode - Generation complete ({api_time:.2f}s)")
        return (output_image, output_text)


NODE_CLASS_MAPPINGS = {
    "GeminiImageDirectNode": GeminiImageDirectNode
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "GeminiImageDirectNode": "Gemini Image Direct"
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]