Some ComfyUI nodes for ACE-Steps (1.0 and 1.5)
| # By https://github.com/blepping | |
| # License: Apache2 | |
| # | |
| # Place this file in your custom_nodes directory and it should load automatically. | |
| import itertools | |
| import math | |
| import sys | |
| import nodes | |
| import torch | |
| import yaml | |
| from comfy import model_management, model_patcher, patcher_extension, samplers | |
| try: | |
| from comfy.ldm.ace.ace_step15 import get_silence_latent as get_ace15_silence_latent | |
| HAVE_ACE15_SILENCE = True | |
| except Exception: | |
| HAVE_ACE15_SILENCE = False | |
| # try: | |
| # import mutagen | |
| # HAVE_MUTAGEN = True | |
| # except Exception: | |
| # HAVE_MUTAGEN = False | |
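| # Conversion factors from seconds of audio to latent time frames: | |
| # ACE-Steps 1.0 works out to roughly 10.77 frames per second, ACE-Steps 1.5 to 25. | |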
| LATENT_TIME_MULTIPLIER = 44100 / 512 / 8 | |
| LATENT_TIME_MULTIPLIER_15 = 25.0 # 48000 / 1920 | |
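| # Approximate latent values for silence in ACE-Steps 1.0: 8 channels x 16 frequency bands. | |
| # The [None, ..., None] indexing below reshapes this to (1, 8, 16, 1) so it broadcasts over | |
| # batch and time when added to a zero latent (see SilentLatentNode). | |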
| SILENCE = torch.tensor( | |
| ( | |
| ( | |
| -0.6462, | |
| -1.2132, | |
| -1.3026, | |
| -1.2432, | |
| -1.2455, | |
| -1.2162, | |
| -1.2184, | |
| -1.2114, | |
| -1.2153, | |
| -1.2144, | |
| -1.2130, | |
| -1.2115, | |
| -1.2063, | |
| -1.1918, | |
| -1.1154, | |
| -0.7924, | |
| ), | |
| ( | |
| 0.0473, | |
| -0.3690, | |
| -0.6507, | |
| -0.5677, | |
| -0.6139, | |
| -0.5863, | |
| -0.5783, | |
| -0.5746, | |
| -0.5748, | |
| -0.5763, | |
| -0.5774, | |
| -0.5760, | |
| -0.5714, | |
| -0.5560, | |
| -0.5393, | |
| -0.3263, | |
| ), | |
| ( | |
| -1.3019, | |
| -1.9225, | |
| -2.0812, | |
| -2.1188, | |
| -2.1298, | |
| -2.1227, | |
| -2.1080, | |
| -2.1133, | |
| -2.1096, | |
| -2.1077, | |
| -2.1118, | |
| -2.1141, | |
| -2.1168, | |
| -2.1134, | |
| -2.0720, | |
| -1.7442, | |
| ), | |
| ( | |
| -4.4184, | |
| -5.5253, | |
| -5.7387, | |
| -5.7961, | |
| -5.7819, | |
| -5.7850, | |
| -5.7980, | |
| -5.8083, | |
| -5.8197, | |
| -5.8202, | |
| -5.8231, | |
| -5.8305, | |
| -5.8313, | |
| -5.8153, | |
| -5.6875, | |
| -4.7317, | |
| ), | |
| ( | |
| 1.5986, | |
| 2.0669, | |
| 2.0660, | |
| 2.0476, | |
| 2.0330, | |
| 2.0271, | |
| 2.0252, | |
| 2.0268, | |
| 2.0289, | |
| 2.0260, | |
| 2.0261, | |
| 2.0252, | |
| 2.0240, | |
| 2.0220, | |
| 1.9828, | |
| 1.6429, | |
| ), | |
| ( | |
| -0.4177, | |
| -0.9632, | |
| -1.0095, | |
| -1.0597, | |
| -1.0462, | |
| -1.0640, | |
| -1.0607, | |
| -1.0604, | |
| -1.0641, | |
| -1.0636, | |
| -1.0631, | |
| -1.0594, | |
| -1.0555, | |
| -1.0466, | |
| -1.0139, | |
| -0.8284, | |
| ), | |
| ( | |
| -0.7686, | |
| -1.0507, | |
| -1.3932, | |
| -1.4880, | |
| -1.5199, | |
| -1.5377, | |
| -1.5333, | |
| -1.5320, | |
| -1.5307, | |
| -1.5319, | |
| -1.5360, | |
| -1.5383, | |
| -1.5398, | |
| -1.5381, | |
| -1.4961, | |
| -1.1732, | |
| ), | |
| ( | |
| 0.0199, | |
| -0.0880, | |
| -0.4010, | |
| -0.3936, | |
| -0.4219, | |
| -0.4026, | |
| -0.3907, | |
| -0.3940, | |
| -0.3961, | |
| -0.3947, | |
| -0.3941, | |
| -0.3929, | |
| -0.3889, | |
| -0.3741, | |
| -0.3432, | |
| -0.169, | |
| ), | |
| ), | |
| dtype=torch.float32, | |
| device="cpu", | |
| )[None, ..., None] | |
| BLEND_MODES = None | |
| def _ensure_blend_modes(): | |
| global BLEND_MODES | |
| if BLEND_MODES is None: | |
| bi = sys.modules.get("_blepping_integrations", {}) or getattr( | |
| nodes, | |
| "_blepping_integrations", | |
| {}, | |
| ) | |
| bleh = bi.get("bleh") | |
| if bleh is not None: | |
| BLEND_MODES = bleh.py.latent_utils.BLENDING_MODES | |
| else: | |
| BLEND_MODES = { | |
| "lerp": torch.lerp, | |
| "a_only": lambda a, _b, _t: a, | |
| "b_only": lambda _a, b, _t: b, | |
| "subtract_b": lambda a, b, t: a - b * t, | |
| } | |
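| # Rescales the latent so its minimum maps to target_min and its maximum to target_max, | |
| # with min/max computed independently over the dimensions listed in dim. | |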
| def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)): | |
| min_val, max_val = ( | |
| latent.amin(dim=dim, keepdim=True), | |
| latent.amax(dim=dim, keepdim=True), | |
| ) | |
| normalized = (latent - min_val).div_(max_val - min_val) | |
| return ( | |
| normalized.mul_(target_max - target_min) | |
| .add_(target_min) | |
| .clamp_(target_min, target_max) | |
| ) | |
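| # Normalizes an AUDIO waveform to 3D (batch, channels, samples), optionally moving it to | |
| # the CPU and duplicating a mono channel to stereo. | |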
| def fixup_waveform( | |
| waveform: torch.Tensor, | |
| *, | |
| copy: bool = True, | |
| move_to_cpu: bool = True, | |
| ensure_stereo: bool = False, | |
| ) -> torch.Tensor: | |
| if move_to_cpu: | |
| waveform = waveform.to(device="cpu", copy=copy) | |
| if waveform.ndim == 2: | |
| waveform = waveform[None] | |
| elif waveform.ndim == 1: | |
| waveform = waveform[None, None] | |
| if ensure_stereo and waveform.shape[1] == 1: | |
| waveform = waveform.repeat(1, 2, 1) | |
| return waveform | |
| class SilentLatentNode: | |
| DESCRIPTION = "Creates a latent full of (roughly) silence. This node can work for ACE-Steps 1.5 if you connect a reference latent." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("LATENT",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "seconds": ( | |
| "FLOAT", | |
| { | |
| "default": 120.0, | |
| "min": 1.0, | |
| "max": 1000.0, | |
| "step": 0.1, | |
| "tooltip": "Number of seconds to generate. Ignored if optional latent input is connected.", | |
| }, | |
| ), | |
| "batch_size": ( | |
| "INT", | |
| { | |
| "default": 1, | |
| "min": 1, | |
| "max": 4096, | |
| "tooltip": "Batch size to generate. Ignored if optional latent input is connected.", | |
| }, | |
| ), | |
| }, | |
| "optional": { | |
| "ref_latent_opt": ( | |
| "LATENT", | |
| { | |
| "tooltip": "When connected the other parameters are ignored and the latent output will match the length/batch size of the reference. This needs to be connected to get a ACE-Steps 1.5 silent latent." | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go_ace15(cls, ref_shape: torch.Size) -> tuple[dict]: | |
| if not HAVE_ACE15_SILENCE: | |
| raise RuntimeError("ACE 1.5 silence unavailable. ComfyUI version too old?") | |
| ndim = len(ref_shape) | |
| if ndim == 4 and ref_shape[-2] != 1: | |
| raise ValueError( | |
| "Can't handle 4D ACE 1.5 latent with non-empty dimension -2" | |
| ) | |
| latent = torch.zeros( | |
| ref_shape[0], 64, ref_shape[-1], device="cpu", dtype=torch.float32 | |
| ) | |
| latent += get_ace15_silence_latent(ref_shape[-1], device="cpu").to(latent) | |
| if ndim == 4: | |
| latent = latent.unsqueeze(-2) | |
| return ({"samples": latent, "type": "audio"},) | |
| @classmethod | |
| def go(cls, *, seconds: float, batch_size: int, ref_latent_opt=None) -> tuple[dict]: | |
| if ref_latent_opt is not None: | |
| ref_shape = ref_latent_opt["samples"].shape | |
| if len(ref_shape) in {3, 4} and ref_shape[1] == 64: | |
| return cls.go_ace15(ref_shape=ref_shape) | |
| latent = torch.zeros(ref_shape, device="cpu", dtype=torch.float32) | |
| else: | |
| length = int(seconds * LATENT_TIME_MULTIPLIER) | |
| latent = torch.zeros( | |
| batch_size, 8, 16, length, device="cpu", dtype=torch.float32 | |
| ) | |
| latent += SILENCE | |
| return ({"samples": latent, "type": "audio"},) | |
| class VisualizeLatentNode: | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("IMAGE",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "latent": ("LATENT",), | |
| "scale_secs": ( | |
| "INT", | |
| { | |
| "default": 0, | |
| "min": 0, | |
| "max": 1000, | |
| "tooltip": "Horizontal scale. Number of pixels that corresponds to one second of audio. You can use 0 for no scaling which is roughly 11 pixels per second.", | |
| }, | |
| ), | |
| "scale_vertical": ( | |
| "INT", | |
| { | |
| "default": 1, | |
| "min": 1, | |
| "max": 1024, | |
| "tooltip": "Pixel expansion factor for channels (or frequency bands if you have swap_channels_freqs mode enabled).", | |
| }, | |
| ), | |
| "swap_channels_freqs": ( | |
| "BOOLEAN", | |
| { | |
| "default": False, | |
| "tooltip": "Swaps the order of channels and frequency in the vertical dimension. When enabled, scale_vertical applies to frequency bands.", | |
| }, | |
| ), | |
| "normalize_dims": ( | |
| "STRING", | |
| { | |
| "default": "-1", | |
| "tooltip": "Dimensions the latent scale is normalized using. Must be a comma-separated list. The default setting normalizes the channels and frequency bands independently per batch, you can try -3, -2, -1 if you want to see the relative differences better.", | |
| }, | |
| ), | |
| "mode": ( | |
| ( | |
| "split", | |
| "combined", | |
| "brg", | |
| "rgb", | |
| "bgr", | |
| "split_flip", | |
| "combined_flip", | |
| "brg_flip", | |
| "rgb_flip", | |
| "bgr_flip", | |
| ), | |
| { | |
| "default": "split", | |
| "tooltip": "Split shows a monochrome view of of each channel/freq, combined shows the average. Flip means invert the energy in the channel (i.e. white -> black). The other modes put the latent channels into the RGB channels of the preview image.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| latent, | |
| scale_secs, | |
| scale_vertical, | |
| swap_channels_freqs, | |
| normalize_dims, | |
| mode, | |
| ) -> tuple: | |
| normalize_dims = normalize_dims.strip() | |
| normalize_dims = ( | |
| () | |
| if not normalize_dims | |
| else tuple(int(dim) for dim in normalize_dims.split(",")) | |
| ) | |
| samples = latent["samples"].to(dtype=torch.float32, device="cpu") | |
| if samples.ndim == 3 and samples.shape[1] == 64: | |
| samples = samples.unsqueeze(-2) | |
| temporal_scale_factor = LATENT_TIME_MULTIPLIER_15 | |
| elif samples.ndim == 4: | |
| temporal_scale_factor = LATENT_TIME_MULTIPLIER | |
| else: | |
| raise ValueError( | |
| "Expected an ACE-Steps 1.0 latent with 4 dimensions or an Ace-Step 1.5 latent with 3 dimensions and 64 channels." | |
| ) | |
| color_mode = mode not in {"split", "combined", "split_flip", "combined_flip"} | |
| batch, channels, freqs, temporal = samples.shape | |
| samples = normalize_to_scale(samples, 0.0, 1.0, dim=normalize_dims) | |
| if mode.endswith("_flip"): | |
| samples = 1.0 - samples | |
| if swap_channels_freqs: | |
| samples = samples.movedim(2, 1) | |
| if mode.startswith("combined"): | |
| samples = samples.mean(dim=1, keepdim=True) | |
| if scale_vertical != 1: | |
| samples = samples.repeat_interleave(scale_vertical, dim=2) | |
| if not color_mode: | |
| samples = samples.reshape(batch, -1, temporal) | |
| if scale_secs > 0: | |
| new_temporal = round((temporal / temporal_scale_factor) * scale_secs) | |
| samples = torch.nn.functional.interpolate( | |
| samples.unsqueeze(1) if not color_mode else samples, | |
| size=(samples.shape[-2], new_temporal), | |
| mode="nearest-exact", | |
| ) | |
| if not color_mode: | |
| samples = samples.squeeze(1) | |
| if not color_mode: | |
| return (samples[..., None].expand(*samples.shape, 3),) | |
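| # Color modes: pad the channel count up to a multiple of 3, stack each group of three | |
| # channels vertically, then map the members of each group to the image's RGB channels. | |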
| rgb_count = math.ceil(samples.shape[1] / 3) | |
| channels_pad = rgb_count * 3 - samples.shape[1] | |
| samples = torch.cat( | |
| ( | |
| samples, | |
| samples.new_zeros(samples.shape[0], channels_pad, *samples.shape[-2:]), | |
| ), | |
| dim=1, | |
| ) | |
| samples = torch.cat(samples.chunk(rgb_count, dim=1), dim=2).movedim(1, -1) | |
| if mode.startswith("bgr"): | |
| samples = samples.flip(-1) | |
| elif mode.startswith("brg"): | |
| samples = samples.roll(-1, -1) | |
| return (samples,) | |
| SPLIT_KEYS = ("conditioning_lyrics", "lyrics_strength", "audio_codes") | |
| class SplitOutLyricsNode: | |
| DESCRIPTION = "Allows splitting out lyrics and lyrics strength from ACE-Steps CONDITIONING objects. Note that you will only be able to join it back again if it is the same shape." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("CONDITIONING", "CONDITIONING_ACE_LYRICS") | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "conditioning": ("CONDITIONING",), | |
| "add_fake_pooled": ("BOOLEAN", {"default": True}), | |
| }, | |
| } | |
| @classmethod | |
| def go(cls, *, conditioning, add_fake_pooled) -> tuple: | |
| tags_result, lyrics_result = [], [] | |
| for cond_t, cond_d in conditioning: | |
| cond_d = cond_d.copy() | |
| split_d = {k: cond_d.pop(k) for k in SPLIT_KEYS if k in cond_d} | |
| if add_fake_pooled and cond_d.get("pooled_output") is None: | |
| cond_d["pooled_output"] = cond_t.new_zeros(1, 1) | |
| split_d["pooled_ouput"] = None | |
| tags_result.append([cond_t.clone(), cond_d]) | |
| lyrics_result.append(split_d) | |
| return (tags_result, lyrics_result) | |
| class JoinLyricsNode: | |
| DESCRIPTION = "Allows joining CONDITIONING_ACE_LYRICS back into CONDITIONING. Will overwrite any lyrics that exist. Must be the same shape as the conditioning the lyrics were split from." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("CONDITIONING",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "conditioning_tags": ("CONDITIONING",), | |
| "conditioning_lyrics": ("CONDITIONING_ACE_LYRICS",), | |
| "mode": (("matching", "add_missing"), {"default": "matching"}), | |
| "start_time": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0}), | |
| "end_time": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| conditioning_tags: list, | |
| conditioning_lyrics: list, | |
| mode: str, | |
| start_time: float, | |
| end_time: float, | |
| ) -> tuple[list]: | |
| ct_len, cl_len = len(conditioning_tags), len(conditioning_lyrics) | |
| if mode == "add_missing": | |
| if not cl_len: | |
| raise ValueError("conditioning_lyrics must have at least one item") | |
| conditioning_lyrics = ( | |
| conditioning_lyrics + [conditioning_lyrics[-1]] * (ct_len - cl_len) | |
| )[:ct_len] | |
| elif ct_len != cl_len: | |
| raise ValueError( | |
| f"Different lengths for tags {ct_len} vs conditioning lyrics {cl_len}" | |
| ) | |
| result = [ | |
| [ | |
| cond_t.clone(), | |
| cond_d | |
| | { | |
| k: v.clone() | |
| if isinstance(v, torch.Tensor) | |
| else (v.copy() if hasattr(v, "copy") else v) | |
| for k, v in cond_l.items() | |
| if mode != "add_missing" or k not in cond_d | |
| } | |
| if cond_d.get("start_percent", 0.0) >= start_time | |
| and cond_d.get("end_percent", 1.0) <= end_time | |
| else cond_d.copy(), | |
| ] | |
| for (cond_t, cond_d), cond_l in zip(conditioning_tags, conditioning_lyrics) | |
| ] | |
| return (result,) | |
| class EncodeLyricsNode: | |
| DESCRIPTION = "Encode lyrics for ACE-Steps 1.0" | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("CONDITIONING_ACE_LYRICS",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "clip": ("CLIP",), | |
| "lyrics_strength": ( | |
| "FLOAT", | |
| {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}, | |
| ), | |
| "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}), | |
| }, | |
| } | |
| @classmethod | |
| def go(cls, *, clip, lyrics_strength: float, lyrics: str) -> tuple: | |
| conditioning = clip.encode_from_tokens_scheduled( | |
| clip.tokenize("", lyrics=lyrics) | |
| ) | |
| lyrics_result = [ | |
| { | |
| "conditioning_lyrics": cond[1]["conditioning_lyrics"], | |
| "lyrics_strength": lyrics_strength, | |
| } | |
| for cond in conditioning | |
| ] | |
| return (lyrics_result,) | |
| class SetAudioDtypeNode: | |
| DESCRIPTION = "Advanced node that allows the datatype of the audio waveform. The 16 and 8 bit types are not recommended." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("AUDIO",) | |
| _ALLOWED_DTYPES = ( | |
| "float64", | |
| "float32", | |
| "float16", | |
| "bfloat16", | |
| "float8_e4m3fn", | |
| "float8_e5m2", | |
| ) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "audio": ("AUDIO",), | |
| "dtype": ( | |
| cls._ALLOWED_DTYPES, | |
| {"default": "float64", "tooltip": "TBD"}, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go(cls, *, audio: dict, dtype: str) -> tuple[dict]: | |
| if dtype not in cls._ALLOWED_DTYPES: | |
| raise ValueError("Bad dtype") | |
| waveform = audio["waveform"] | |
| dt = getattr(torch, dtype) | |
| if waveform.dtype == dt: | |
| return (audio,) | |
| return (audio | {"waveform": waveform.to(dtype=dt)},) | |
| class AudioLevelsNode: | |
| DESCRIPTION = "The values in the waveform range for -1 to 1. This node allows you to scale audio to a percentage of that range." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("AUDIO",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "audio": ("AUDIO",), | |
| "scale": ( | |
| "FLOAT", | |
| { | |
| "default": 0.95, | |
| "min": 0.0, | |
| "max": 1.0, | |
| "tooltip": "Percentage where 1.0 indicates 100% of the maximum allowed value in an audio tensor. You can use 1.0 to make it as loud as possible without actually clipping.", | |
| }, | |
| ), | |
| "per_channel": ( | |
| "BOOLEAN", | |
| { | |
| "default": False, | |
| "tooltip": "When enabled, the levels for each channel will be scaled independently. For multi-channel audio (like stereo) enabling this will not preserve the relative levels between the channels so probably should be left disabled most of the time.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go(cls, *, audio: dict, scale: float, per_channel: bool) -> tuple[dict]: | |
| waveform = audio["waveform"].to(device="cpu", copy=True) | |
| if waveform.ndim == 1: | |
| waveform = waveform[None, None, ...] | |
| elif waveform.ndim == 2: | |
| waveform = waveform[None, ...] | |
| elif waveform.ndim != 3: | |
| raise ValueError("Unexpected number of dimensions in waveform!") | |
| max_val = ( | |
| waveform.abs().flatten(start_dim=2 if per_channel else 1).max(dim=-1).values | |
| ) | |
| max_val = max_val[..., None] if per_channel else max_val[..., None, None] | |
| # Max could be 0, multiplying by 0 is fine in that case. | |
| waveform *= (scale / max_val).nan_to_num() | |
| return (audio | {"waveform": waveform.clamp(-1.0, 1.0)},) | |
| class AudioAsLatentNode: | |
| DESCRIPTION = "This node allows you to rearrange AUDIO to look like a LATENT. Can be useful if you want to apply some latent operations to AUDIO. Can be reversed with the ACETricks LatentAsAudio node." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("LATENT",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "audio": ("AUDIO",), | |
| "use_width": ( | |
| "BOOLEAN", | |
| { | |
| "default": True, | |
| "tooltip": "When enabled, you'll get a 4 channel with height 1 and the audio audio data in the width dimension, otherwise the opposite.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go(cls, *, audio: dict, use_width: bool) -> tuple: | |
| waveform = audio["waveform"].to(device="cpu", copy=True) | |
| if waveform.ndim == 1: | |
| waveform = waveform[None, None, ...] | |
| elif waveform.ndim == 2: | |
| waveform = waveform[None, ...] | |
| elif waveform.ndim != 3: | |
| raise ValueError("Unexpected number of dimensions in waveform!") | |
| waveform = waveform.unsqueeze(2) if use_width else waveform[..., None] | |
| return ({"samples": waveform},) | |
| class LatentAsAudioNode: | |
| DESCRIPTION = "This node lets you rearrange a LATENT to look like AUDIO. Mainly useful for getting back after using the ACETricks AudioAsLatent node and performing some operations. If you connect the optional audio input it will use whatever non-waveform parameters exist in it (can be stuff like the sample rate), otherwise it will just add sample_rate: 41000 and the waveform." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("AUDIO",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "latent": ("LATENT",), | |
| "values_mode": ( | |
| ("rescale", "clamp"), | |
| {"default": "rescale"}, | |
| ), | |
| "use_width": ( | |
| "BOOLEAN", | |
| { | |
| "default": True, | |
| "tooltip": "When enabled, takes the audio data from the first item in the width dimension, otherwise height.", | |
| }, | |
| ), | |
| }, | |
| "optional": { | |
| "audio_opt": ( | |
| "AUDIO", | |
| { | |
| "tooltip": "Optional audio to use as a reference for sample rate and possibly other values." | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| latent: dict, | |
| values_mode: str, | |
| use_width: bool, | |
| audio_opt: dict | None = None, | |
| ) -> tuple: | |
| samples = latent["samples"] | |
| if samples.ndim != 4: | |
| raise ValueError("Expected a 4D latent but didn't get one") | |
| samples = (samples[..., 0, :] if use_width else samples[..., 0]).to( | |
| device="cpu", copy=True | |
| ) | |
| if audio_opt is None: | |
| audio_opt = {"sample_rate": 44100} | |
| result = audio_opt | {"waveform": samples} | |
| if values_mode == "clamp": | |
| result["waveform"] = samples.clamp(-1.0, 1.0) | |
| elif torch.any(samples.abs() > 1.0): | |
| return AudioLevelsNode.go(audio=result, per_channel=False, scale=1.0) | |
| return (result,) | |
| class MonoToStereoNode: | |
| DESCRIPTION = "Can convert mono AUDIO to stereo. It will leave AUDIO that's already stereo alone. Note: Always adds a batch dimension if it doesn't exist and moves to the CPU device." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("AUDIO",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return {"required": {"audio": ("AUDIO",)}} | |
| @classmethod | |
| def go(cls, *, audio: dict) -> tuple: | |
| waveform = audio["waveform"].to(device="cpu") | |
| if waveform.ndim == 2: | |
| waveform = waveform[None] | |
| elif waveform.ndim == 1: | |
| waveform = waveform[None, None] | |
| channels = waveform.shape[1] | |
| audio = audio.copy() | |
| if channels == 1: | |
| waveform = waveform.repeat(1, 2, 1) | |
| audio["waveform"] = waveform | |
| return (audio,) | |
| class AudioBlendNode: | |
| DESCRIPTION = "Blends two AUDIO inputs together. If you have ComfyUI-bleh installed you will have access to many additional blend modes." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("AUDIO",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| _ensure_blend_modes() | |
| assert BLEND_MODES is not None # Make static analysis happy. | |
| return { | |
| "required": { | |
| "audio_a": ("AUDIO",), | |
| "audio_b": ("AUDIO",), | |
| "audio_b_strength": ( | |
| "FLOAT", | |
| { | |
| "default": 0.5, | |
| "min": -1000.0, | |
| "max": 1000.0, | |
| }, | |
| ), | |
| "blend_mode": ( | |
| tuple(BLEND_MODES.keys()), | |
| { | |
| "default": "lerp", | |
| }, | |
| ), | |
| "length_mismatch_mode": ( | |
| ("shrink", "blend"), | |
| { | |
| "default": "shrink", | |
| "tooltip": "Shrink mode will return audio matching whatever the shortest input was. Blend will blend up to the shortest input's size and use unblended longer input to fill the rest. Note that this adjustment occurs before blending.", | |
| }, | |
| ), | |
| "normalization_mode": ( | |
| ("clamp", "levels", "levels_per_channel", "none"), | |
| { | |
| "default": "levels", | |
| "tooltip": "Clamp will just clip the result to ensure it is within the permitted range. Levels will rebalance it so the maximum value is the maximum value for the permitted range. Levels per channel is the same, except the maximum value is determined separately per channel. Setting this to none is not recommended unless you are planning to do your own normalization as it may leave invalid values in the audio latent.", | |
| }, | |
| ), | |
| "result_template": ( | |
| ("a", "b"), | |
| { | |
| "default": "a", | |
| "tooltip": "AUDIOs contain metadata like sampling rate. The result will be based on the metadata from the audio input you select here, with the blended result as the waveform in it.", | |
| }, | |
| ), | |
| } | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| audio_a: dict, | |
| audio_b: dict, | |
| audio_b_strength: float, | |
| blend_mode: str, | |
| length_mismatch_mode: str, | |
| normalization_mode: str, | |
| result_template: str, | |
| ) -> tuple: | |
| wa = fixup_waveform(audio_a["waveform"]) | |
| wb = fixup_waveform(audio_b["waveform"]) | |
| if wa.dtype != wb.dtype: | |
| wa = wa.to(dtype=torch.float32) | |
| wb = wb.to(dtype=torch.float32) | |
| if wa.shape[:-1] != wb.shape[:-1]: | |
| errstr = f"Unexpected batch or channels shape mismatch in audio. audio_a has shape {wa.shape}, audio_b has shape {wb.shape}" | |
| raise ValueError(errstr) | |
| assert BLEND_MODES is not None # Make static analysis happy. | |
| blend_function = BLEND_MODES[blend_mode] | |
| walen, wblen = wa.shape[-1], wb.shape[-1] | |
| if walen != wblen: | |
| if length_mismatch_mode == "shrink": | |
| minlen = min(walen, wblen) | |
| wa = wa[..., :minlen] | |
| wb = wb[..., :minlen] | |
| elif walen > wblen: | |
| wb_temp = wa.clone() | |
| wb_temp[..., :wblen] = wb | |
| wb = wb_temp | |
| else: | |
| wa_temp = wb.clone() | |
| wa_temp[..., :walen] = wa | |
| wa = wa_temp | |
| walen = wblen = wa.shape[-1] | |
| result = blend_function(wa, wb, audio_b_strength) | |
| result_audio = audio_a.copy() if result_template == "a" else audio_b.copy() | |
| if normalization_mode == "clamp": | |
| result = result.clamp_(min=-1.0, max=1.0) | |
| elif normalization_mode in {"levels", "levels_per_channel"}: | |
| result = AudioLevelsNode.go( | |
| audio={"waveform": result}, | |
| scale=1.0, | |
| per_channel=normalization_mode == "levels_per_channel", | |
| )[0]["waveform"] | |
| result_audio["waveform"] = result | |
| return (result_audio,) | |
| class AudioFromBatchNode: | |
| DESCRIPTION = "Can be used to extract batch items from AUDIO." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("AUDIO",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "audio": ("AUDIO",), | |
| "start": ( | |
| "INT", | |
| { | |
| "default": 0, | |
| "tooltip": "Start index (zero-based). Negative indexes count from the end.", | |
| }, | |
| ), | |
| "length": ("INT", {"default": 1, "min": 0}), | |
| } | |
| } | |
| @classmethod | |
| def go(cls, *, audio: dict, start: int, length: int) -> tuple: | |
| waveform = audio["waveform"] | |
| if not waveform.ndim == 3: | |
| raise ValueError("Expected 3D waveform") | |
| batch = waveform.shape[0] | |
| if start < 0: | |
| start = batch + start | |
| if start < 0: | |
| raise ValueError("Start index is out of range") | |
| new_waveform = waveform[start : start + length].clone() | |
| return (audio | {"waveform": new_waveform},) | |
| class TimeOffsetNode: | |
| DESCRIPTION = "Can be used to calculate an offset into an ACE-Steps 1.0 latent given a time in seconds." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("INT", "FLOAT") | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "seconds": ("FLOAT", {"default": 0.0, "min": 0.0}), | |
| } | |
| } | |
| @classmethod | |
| def go(cls, *, seconds: float) -> tuple[int, float]: | |
| offset = seconds * LATENT_TIME_MULTIPLIER | |
| return (int(offset), offset) | |
| class MaskNode: | |
| DESCRIPTION = "Can be used to create a mask based on time and frequency bands" | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("MASK",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "seconds": ("FLOAT", {"default": 120.0, "min": 1.0}), | |
| "start_time": ( | |
| "FLOAT", | |
| { | |
| "default": 0.0, | |
| "min": -99999.0, | |
| "tooltip": "Negative values count from the end.", | |
| }, | |
| ), | |
| "end_time": ( | |
| "FLOAT", | |
| { | |
| "default": -1.0, | |
| "min": -99999.0, | |
| "tooltip": "Negative values count from the end.", | |
| }, | |
| ), | |
| "start_freq": ( | |
| "INT", | |
| { | |
| "default": 0, | |
| "min": 0, | |
| "max": 15, | |
| "tooltip": "Frequency bands, 0 is the lowest frequency. Inclusive.", | |
| }, | |
| ), | |
| "end_freq": ( | |
| "INT", | |
| { | |
| "default": 15, | |
| "min": 0, | |
| "max": 15, | |
| "tooltip": "Frequency bands, 0 is the lowest frequency. Inclusive.", | |
| }, | |
| ), | |
| "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), | |
| "base_value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), | |
| } | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| seconds: float, | |
| start_time: float, | |
| end_time: float, | |
| start_freq: int, | |
| end_freq: int, | |
| strength: float, | |
| base_value: float, | |
| ) -> tuple[torch.Tensor]: | |
| time_len = int(seconds * LATENT_TIME_MULTIPLIER) | |
| offs_start = int(start_time * LATENT_TIME_MULTIPLIER) | |
| offs_end = int(end_time * LATENT_TIME_MULTIPLIER) | |
| if offs_start < 0: | |
| offs_start = max(0, time_len + offs_start) | |
| if offs_end < 0: | |
| offs_end = max(0, time_len + offs_end) | |
| offs_start = min(time_len - 1, offs_start) | |
| offs_end = min(time_len - 1, offs_end) | |
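| # The mask covers 16 frequency bands over time_len latent frames; the selected | |
| # time/frequency window is set to strength, everything else keeps base_value. | |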
| mask = torch.full( | |
| (1, 16, time_len), value=base_value, dtype=torch.float32, device="cpu" | |
| ) | |
| mask[:, start_freq : end_freq + 1, offs_start : offs_end + 1] = strength | |
| return (mask,) | |
| class WaveformNode: | |
| DESCRIPTION = "Creates a waveform image from audio." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("IMAGE",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict[str, dict]: | |
| return { | |
| "required": { | |
| "audio": ("AUDIO",), | |
| "width": ("INT", {"default": 800, "min": 1}), | |
| "height": ("INT", {"default": 200, "min": 1}), | |
| "background_rgb": ("STRING", {"default": "000020"}), | |
| "left_rgb": ( | |
| "STRING", | |
| { | |
| "default": "e0a080", | |
| "tooltip": "Used for both channels in the case of mono audio.", | |
| }, | |
| ), | |
| "right_rgb": ("STRING", {"default": "80e0a0"}), | |
| "mode": ( | |
| ("normal", "rescaled", "log", "log_rescaled"), | |
| {"default": "rescaled"}, | |
| ), | |
| "log_factor": ("FLOAT", {"default": 10.0, "min": 0.0, "step": 0.01}), | |
| "oversampling": ("INT", {"default": 4, "min": 1}), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| audio: dict, | |
| width: int, | |
| height: int, | |
| background_rgb: str, | |
| left_rgb: str, | |
| right_rgb: str, | |
| mode: str, | |
| log_factor: float, | |
| oversampling: int, | |
| ) -> tuple: | |
| height = max(1, height // 2) | |
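| # Parse the RRGGBB hex strings into (r, g, b) float tuples in the 0-1 range. | |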
| n = int(background_rgb, 16) | |
| brgb = tuple(((n >> (i * 8)) & 255) / 255 for i in range(2, -1, -1)) | |
| n = int(left_rgb, 16) | |
| lrgb = tuple(((n >> (i * 8)) & 255) / 255 for i in range(2, -1, -1)) | |
| n = int(right_rgb, 16) | |
| rrgb = tuple(((n >> (i * 8)) & 255) / 255 for i in range(2, -1, -1)) | |
| waveform = audio["waveform"] | |
| if waveform.ndim == 1: | |
| waveform = waveform[None, None] | |
| elif waveform.ndim == 2: | |
| waveform = waveform[None] | |
| elif waveform.ndim != 3: | |
| raise ValueError( | |
| f"Unexpected number of dimensions in waveform, expected 1-3, got {waveform.ndim}" | |
| ) | |
| waveform = waveform[:, :2, ...].to(dtype=torch.float64) | |
| waveform = torch.nn.functional.interpolate( | |
| waveform.unsqueeze(2), | |
| size=(height, min(waveform.shape[-1], width * oversampling)), | |
| mode="nearest-exact", | |
| ).movedim(1, -1) | |
| waveform = waveform.abs().clamp(0, 1) | |
| if mode in {"log", "log_rescaled"}: | |
| waveform = ( | |
| (waveform * log_factor).log1p() / math.log1p(log_factor) | |
| ).clamp_(0, 1) | |
| if mode in {"rescaled", "log_rescaled"}: | |
| waveform = (waveform / waveform.max()).nan_to_num_().clamp_(0, 1) | |
| channels = waveform.shape[-1] | |
| hmask = torch.linspace( | |
| 1.0, 0.0, height + 1, dtype=waveform.dtype, device=waveform.device | |
| )[1:].view(1, height, 1, 1) | |
| left_channel = waveform[..., 0:1] | |
| limg = torch.cat( | |
| tuple( | |
| torch.where(left_channel > hmask, fpixval, bpixval) | |
| for fpixval, bpixval in zip(lrgb, brgb) | |
| ), | |
| dim=-1, | |
| ) | |
| if channels >= 2: | |
| right_channel = waveform[..., 1:] | |
| rimg = torch.cat( | |
| tuple( | |
| torch.where(right_channel > hmask, fpixval, bpixval) | |
| for fpixval, bpixval in zip(rrgb, brgb) | |
| ), | |
| dim=-1, | |
| ) | |
| else: | |
| rimg = limg | |
| result = torch.cat((limg, rimg.flip(dims=(1,))), dim=1) | |
| if result.shape[2] != width: | |
| result = torch.nn.functional.interpolate( | |
| result.movedim(-1, 1), size=(height * 2, width), mode="bicubic" | |
| ).movedim(1, -1) | |
| return (result.to(device="cpu", dtype=torch.float32),) | |
| class SqueezeUnsqueezeLatentDimensionNode: | |
| DESCRIPTION = "This node can be used to add or remove an empty dimension from latents. Useful with ACE 1.5 which uses 3D latents while many ComfyUI latent processing nodes expect 4D+ latents. You will also likely need to use the ModelPatchAce15Use4dLatent node." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("LATENT",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "latent": ("LATENT",), | |
| "dimension": ( | |
| "INT", | |
| { | |
| "default": 2, | |
| "min": -9999, | |
| "max": 9999, | |
| "tooltip": "Negative dimensions count from the end.", | |
| }, | |
| ), | |
| "unsqueeze_mode": ( | |
| "BOOLEAN", | |
| { | |
| "default": True, | |
| "tooltip": "When enabled, unsqueezes (adds) an empty dimension at the specified position. When disabled will remove an empty dimension instead.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| latent: dict, | |
| dimension: int, | |
| unsqueeze_mode: bool, | |
| ) -> tuple[dict]: | |
| samples = latent["samples"] | |
| pos_dim = dimension if dimension >= 0 else samples.ndim + dimension | |
| max_dim = samples.ndim if unsqueeze_mode else samples.ndim - 1 | |
| if pos_dim < 0 or pos_dim > max_dim: | |
| errstr = f"Specified dimension {dimension} out of range for latent with shape {samples.shape} ({samples.ndim} dimension(s))" | |
| raise ValueError(errstr) | |
| if unsqueeze_mode: | |
| samples = samples.unsqueeze(dimension) | |
| else: | |
| if samples.shape[dimension] != 1: | |
| errstr = f"Dimension {dimension} in latent with shape {samples.shape} is not empty, has size {samples.shape[dimension]}. This node can only squeeze empty dimensions." | |
| raise ValueError(errstr) | |
| samples = samples.squeeze(dimension) | |
| return (latent | {"samples": samples.clone()},) | |
| class ModelPatchAce15Use4dLatentNode: | |
| DESCRIPTION = "Patches an ACE 1.5 model (or theoretically, any model that uses 3D latents) to handle 4D latent inputs/outputs. The point of doing this is because many existing nodes are not able to handle 3D latents. You will also likely need to use the SqueezeUnsqueezeLatentDimension node on both sides of sampling (unsqueeze before, squeeze after). NOTE: This node is a hacky abomination. There are probably many cases where it where it won't function correctly, it's also quite likely to break if ComfyUI changes stuff. There really just isn't a reliable way to do what this is trying to do." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("MODEL",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "model": ("MODEL",), | |
| }, | |
| "optional": { | |
| "preserve_patch": ( | |
| "BOOLEAN", | |
| { | |
| "default": True, | |
| "tooltip": "Preserves/delegates to an existing patch of the same type (if it exists) instead of replacing it.", | |
| }, | |
| ), | |
| "dimension": ( | |
| "INT", | |
| { | |
| "default": 2, | |
| "min": -9999, | |
| "max": 9999, | |
| "tooltip": "Negative dimensions count from the end.", | |
| }, | |
| ), | |
| "expected_dimensions": ( | |
| "INT", | |
| { | |
| "default": 4, | |
| "min": 2, | |
| "max": 9999, | |
| "tooltip": "Number of dimensions to expect, generally this would be original model dimensions + 1.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| model: model_patcher.ModelPatcher, | |
| preserve_patch: bool = True, | |
| dimension: int = 2, | |
| expected_dimensions: int = 4, | |
| ) -> tuple: | |
| model = model.clone() | |
| old_wrapper = ( | |
| model.model_options.get("sampler_calc_cond_batch_function") | |
| if preserve_patch | |
| else None | |
| ) | |
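| # Squeezes the extra (empty) dimension out of the model input before calling the | |
| # original calc_cond_batch (or an existing patch), then unsqueezes it back into | |
| # each tensor of the result. | |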
| def calc_cond_batch_patch(args: dict) -> list[torch.Tensor | None]: | |
| model = args["model"] | |
| x: torch.Tensor = args["input"] | |
| skip = ( | |
| not isinstance(x, torch.Tensor) | |
| or x.ndim != expected_dimensions | |
| or x.shape[dimension] != 1 | |
| ) | |
| if not skip: | |
| x = x.squeeze(dimension) | |
| if old_wrapper is None: | |
| result = samplers.calc_cond_batch( | |
| model, | |
| args["conds"], | |
| x, | |
| args["sigma"], | |
| args["model_options"], | |
| ) | |
| else: | |
| result = old_wrapper(args | {"input": x}) | |
| if skip: | |
| return result | |
| return [ | |
| t if not isinstance(t, torch.Tensor) else t.unsqueeze(dimension) | |
| for t in result | |
| ] | |
| def outer_sample_wrapper(executor, *args, **kwargs): | |
| noise = args[0] | |
| co = getattr(executor, "class_obj", None) | |
| inner_sample = getattr(co, "inner_sample", None) | |
| if ( | |
| inner_sample is None | |
| or noise.ndim != expected_dimensions | |
| or noise.shape[dimension] != 1 | |
| ): | |
| return executor(*args, **kwargs) | |
| def the_horror_pt1(*args, **kwargs): | |
| # AKA inner sample wrapper, which can't currently be hooked. | |
| inner_noise = args[0] | |
| inner_model = getattr(co, "inner_model", None) | |
| extra_conds = getattr(inner_model, "extra_conds", None) | |
| if ( | |
| extra_conds is None | |
| or inner_noise.ndim != expected_dimensions | |
| or inner_noise.shape[dimension] != 1 | |
| ): | |
| return inner_sample(*args, **kwargs) | |
| def the_horror_pt2(*args, **kwargs): | |
| ec_noise = kwargs.get("noise", None) | |
| if ( | |
| isinstance(ec_noise, torch.Tensor) | |
| and ec_noise.ndim == expected_dimensions | |
| and ec_noise.shape[dimension] == 1 | |
| ): | |
| kwargs["noise"] = ec_noise.squeeze(dimension) | |
| return extra_conds(*args, **kwargs) | |
| inner_model.extra_conds = the_horror_pt2 | |
| try: | |
| return inner_sample(*args, **kwargs) | |
| finally: | |
| if getattr(inner_model, "extra_conds", None) is the_horror_pt2: | |
| inner_model.extra_conds = extra_conds | |
| else: | |
| raise RuntimeError( | |
| "ACETricks: Use4dLatent hack: Can't replace original extra_conds wrapper. You may need to restart ComfyUI.", | |
| ) | |
| # Nice clean implementation, definitely isn't ever going to cause issues. | |
| try: | |
| co.inner_sample = the_horror_pt1 | |
| return executor(*args, **kwargs) | |
| finally: | |
| if co.inner_sample is the_horror_pt1: | |
| co.inner_sample = inner_sample | |
| else: | |
| raise RuntimeError( | |
| "ACETricks: Use4dLatent hack: Can't replace original inner_sample wrapper. You may need to restart ComfyUI.", | |
| ) | |
| model.set_model_sampler_calc_cond_batch_function(calc_cond_batch_patch) | |
| model.add_wrapper_with_key( | |
| patcher_extension.WrappersMP.OUTER_SAMPLE, | |
| "acetricks_use4dlatent_hack", | |
| outer_sample_wrapper, | |
| ) | |
| return (model,) | |
| class EmptyAce15LatentFromConditioningNode: | |
| DESCRIPTION = "Creates an empty latent from ACE-Steps 1.5 conditioning. Calculates the size from the number of LM audio codes in the conditioning, so it's only useful if you are using the LM feature." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("LATENT",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "conditioning": ("CONDITIONING",), | |
| }, | |
| "optional": { | |
| "batch_size": ("INT", {"default": 1, "min": 1, "max": 9999}), | |
| "minimum_duration": ( | |
| "FLOAT", | |
| { | |
| "default": 1.0, | |
| "min": 0.0, | |
| "max": 9999.0, | |
| "tooltip": "Will trigger an error if there aren't enough audio codes to reach the specified duration.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| conditioning: list, | |
| batch_size: int = 1, | |
| minimum_duration: float = 0.0, | |
| ) -> tuple: | |
| n_codes = 0 | |
| for _, d in conditioning: | |
| audio_codes = d.get("audio_codes") | |
| if ( | |
| not isinstance(audio_codes, list) | |
| or not audio_codes | |
| or not len(audio_codes[0]) | |
| ): | |
| continue | |
| n_codes = max(n_codes, *(len(codes) for codes in audio_codes)) | |
| duration = 0.2 * n_codes # 5hz | |
| if duration < minimum_duration: | |
| errstr = f"{n_codes} 5hz code(s) ({duration:.3f} second(s)) doesn't match the minimum of {minimum_duration:.3f}." | |
| raise ValueError(errstr) | |
| device = model_management.intermediate_device() | |
| temporal_length = int(duration * LATENT_TIME_MULTIPLIER_15) | |
| latent = torch.zeros(batch_size, 64, temporal_length, device=device) | |
| return ({"samples": latent, "type": "audio"},) | |
| class Ace15CompressDuplicateAudioCodesNode: | |
| DESCRIPTION = "Creates an empty latent from ACE-Steps 1.5 conditioning. Calculates the size from the number of LM audio codes in the conditioning, so it's only useful if you are using the LM feature." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("CONDITIONING",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "conditioning": ("CONDITIONING",), | |
| }, | |
| "optional": { | |
| "mode": ( | |
| ("start", "end", "start_end", "anywhere"), | |
| { | |
| "default": "end", | |
| "tooltip": "Controls whether the node prunes duplicate audio codes from the beginning, end of the list, both head and tail or has no location restriction.", | |
| }, | |
| ), | |
| "repeat_limit": ( | |
| "INT", | |
| { | |
| "default": 3, | |
| "min": 1, | |
| "max": 9999, | |
| }, | |
| ), | |
| "repeat_replace": ( | |
| "INT", | |
| { | |
| "default": -1, | |
| "min": -1, | |
| "max": 9999, | |
| "tooltip": "When set to -1, will use the repeat_limit value. Otherwise, you can can prune repeated codes by setting this to 0 or possibly even expand them further (sounds like a bad idea).", | |
| }, | |
| ), | |
| }, | |
| } | |
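| # Minimal run-length encoding helpers used to collapse repeated audio codes. | |
| # Illustrative example (not from the original): | |
| # simple_rle((7, 7, 7, 2)) -> ((7, 3), (2, 1)) | |
| # simple_unrle(((7, 3), (2, 1))) -> (7, 7, 7, 2) | |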
| @staticmethod | |
| def simple_rle(items: list | tuple) -> tuple: | |
| n_items = len(items) | |
| if not n_items: | |
| return () | |
| result = [] | |
| idx = 0 | |
| while idx < n_items: | |
| item = items[idx] | |
| run_counter = 0 | |
| for subitem in items[idx + 1 :]: | |
| if subitem != item: | |
| break | |
| run_counter += 1 | |
| run_counter += 1 | |
| result.append((item, run_counter)) | |
| idx += run_counter | |
| return tuple(result) | |
| @staticmethod | |
| def simple_unrle(rle_items: list | tuple) -> tuple: | |
| return tuple(itertools.chain(*((i,) * n for i, n in rle_items))) | |
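| # Applies the repeat limit to the RLE-encoded codes. With replace < 0, runs longer than | |
| # limit are truncated to limit entries; with replace >= 0, such runs become exactly | |
| # replace entries (so replace=0 removes them entirely). | |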
| @classmethod | |
| def apply_limits( | |
| cls, | |
| audio_codes: list | tuple, | |
| *, | |
| apply_any: bool, | |
| apply_head: bool, | |
| apply_tail: bool, | |
| limit: int = 1, | |
| replace: int = -1, | |
| ) -> list | tuple: | |
| if not (apply_any or apply_head or apply_tail): | |
| return audio_codes[:] | |
| ac_rle = cls.simple_rle(audio_codes) | |
| if not ac_rle: | |
| return audio_codes[:] | |
| # print(f"\nAC RLE: {ac_rle}") | |
| def limit_item(item: int, n_reps: int) -> tuple[int, int]: | |
| n_reps = ( | |
| min(limit, n_reps) | |
| if replace < 0 | |
| else (n_reps if n_reps <= limit else replace) | |
| ) | |
| return (item, n_reps) | |
| if apply_any: | |
| ac_rle = tuple(itertools.starmap(limit_item, ac_rle)) | |
| if apply_head: | |
| ac_rle = ( | |
| limit_item(*ac_rle[0]), | |
| *ac_rle[1:], | |
| ) | |
| if apply_tail: | |
| ac_rle = ( | |
| *ac_rle[:-1], | |
| limit_item(*ac_rle[-1]), | |
| ) | |
| return audio_codes.__class__(cls.simple_unrle(ac_rle)) | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| conditioning: list, | |
| mode: str, | |
| repeat_limit: int = 3, | |
| repeat_replace: int = -1, | |
| ) -> tuple: | |
| mode = mode.strip().lower() | |
| repeat_limit = max(1, repeat_limit) | |
| apply_any = mode == "anywhere" | |
| apply_head = mode in {"start", "start_end"} | |
| apply_tail = mode in {"end", "start_end"} | |
| result = [] | |
| for t, d in conditioning: | |
| d = d.copy() | |
| audio_codes = d.get("audio_codes") | |
| if ( | |
| isinstance(audio_codes, (list, tuple)) | |
| and audio_codes | |
| and len(audio_codes[0]) | |
| ): | |
| d["audio_codes"] = audio_codes.__class__( | |
| cls.apply_limits( | |
| codes_item, | |
| apply_any=apply_any, | |
| apply_head=apply_head, | |
| apply_tail=apply_tail, | |
| ) | |
| for codes_item in audio_codes | |
| ) | |
| result.append([t.clone(), d]) | |
| return (result,) | |
| class TextEncodeAce15Node: | |
| DESCRIPTION = "Text encoder for ACE-Steps 1.5 with extended features." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("CONDITIONING",) | |
| _DEFAULT_YAML = """# YAML metadata here. Must be a YAML object. | |
| # Most keys support a _negative suffix, i.e. bpm_negative. | |
| # There is no error checking so if you do "bpm: urmom", all I can say is good luck. | |
| bpm: 120 | |
| timesignature: 4/4 | |
| language: en | |
| keyscale: D major | |
| """ | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "clip": ("CLIP",), | |
| "duration": ("FLOAT", {"default": 120.0, "min": 0.2, "step": 0.2}), | |
| "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}), | |
| "tags": ( | |
| "STRING", | |
| {"defaultInput": True, "dynamicPrompts": False, "multiline": True}, | |
| ), | |
| "lyrics": ( | |
| "STRING", | |
| {"defaultInput": True, "dynamicPrompts": False, "multiline": True}, | |
| ), | |
| "generate_audio_codes": ("BOOLEAN", {"default": True}), | |
| "cfg_scale": ("FLOAT", {"default": 2.0}), | |
| "temperature": ("FLOAT", {"default": 0.85}), | |
| "top_p": ("FLOAT", {"default": 0.9}), | |
| "top_k": ("INT", {"default": 0, "min": 0}), | |
| "yaml_metadata": ( | |
| "STRING", | |
| { | |
| "dynamicPrompts": False, | |
| "multiline": True, | |
| "default": cls._DEFAULT_YAML, | |
| }, | |
| ), | |
| }, | |
| "optional": { | |
| "tags_negative": ( | |
| "STRING", | |
| {"defaultInput": True, "dynamicPrompts": False, "multiline": True}, | |
| ), | |
| "lyrics_negative": ( | |
| "STRING", | |
| {"defaultInput": True, "dynamicPrompts": False, "multiline": True}, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| clip: object, | |
| duration: float, | |
| seed: int, | |
| tags: str, | |
| lyrics: str, | |
| generate_audio_codes: bool, | |
| cfg_scale: float, | |
| temperature: float, | |
| top_p: float, | |
| top_k: int, | |
| yaml_metadata: str, | |
| tags_negative: str | None = None, | |
| lyrics_negative: str | None = None, | |
| ) -> tuple: | |
| parsed_metadata = yaml.safe_load(yaml_metadata) | |
| if parsed_metadata is None: | |
| parsed_metadata = {} | |
| elif not isinstance(parsed_metadata, dict): | |
| raise TypeError("yaml_metadata must be a YAML object") | |
| tkwargs = { | |
| "lyrics": lyrics, | |
| "duration": duration, | |
| "seed": seed, | |
| "top_p": top_p, | |
| "top_k": top_k, | |
| "temperature": temperature, | |
| "generate_audio_codes": generate_audio_codes, | |
| } | |
| if tags_negative is not None: | |
| tkwargs["caption_negative"] = tags_negative | |
| if lyrics_negative is not None: | |
| tkwargs["lyrics_negative"] = lyrics_negative | |
| tkwargs |= parsed_metadata | |
| tokens = clip.tokenize(tags, **tkwargs) | |
| return (clip.encode_from_tokens_scheduled(tokens),) | |
| class RawTextEncodeAce15Node: | |
| DESCRIPTION = "Advanced text encoder for ACE-Steps 1.5." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("CONDITIONING",) | |
| _DEFAULT_YAML = """# YAML metadata here. Must be a YAML object. | |
| # There is absolutely no error checking or type checking for these fields. | |
| lm_metadata: | |
| # 5hz tokens, use desired seconds * 5. | |
| # The end of text token is banned until min_tokens is reached. | |
| min_tokens: 600 | |
| # The model is allowed to generate at most max_tokens. | |
| max_tokens: 1536 | |
| seed: 0 | |
| generate_audio_codes: true | |
| cfg_scale: 2.5 | |
| temperature: 1.25 | |
| top_p: 0.9 | |
| top_k: 0 | |
| # List of strings such as [lyrics_prompt, dit_prompt]. | |
| # Will allow encoding tokens with the normal weight parsing, I.E. (blah:1.2). | |
| # This probably won't work well if at all. | |
| allow_token_weights: [] | |
| # All prompt keys use the same format. | |
| # This is the format for multiline YAML strings. Line breaks will | |
| # be preserved but leading indent (to the initial level) will be stripped. | |
| # Use key: |- instead of just | to avoid a trailing line break. | |
| # NOTE: If you attach the optional prompt string inputs, the | |
| # respective keys in the YAML here will be overwritten. | |
| # Used to encode the lyrics (for DIT sampling) | |
| lyrics_prompt: | | |
| # Languages | |
| en | |
| # Lyric | |
| [Intro] | |
| [Verse] | |
| # Used for sampling. Note that duration is specified with a seconds unit. | |
| # Time signatures ending with /4 should just use the initial value. I.E. 4/4 -> 4 | |
| # Results may vary for other time signatures. | |
| dit_prompt: | | |
| # Instruction | |
| Generate audio semantic tokens based on the given conditions: | |
| # Caption | |
| User caption/tags here. | |
| # Metas | |
| - bpm: 60 | |
| - timesignature: 4 | |
| - keyscale: F minor | |
| - duration: 195 seconds | |
| <|endoftext|> | |
| # Positive prompt used for generating audio codes. | |
| # Note that duration is a raw integer value. | |
| lm_prompt: | | |
| <|im_start|>system | |
| # Instruction | |
| Generate audio semantic tokens based on the given conditions: | |
| <|im_end|> | |
| <|im_start|>user | |
| # Caption | |
| User caption/tags here. | |
| # Lyric | |
| [Intro] | |
| [Verse] | |
| <|im_end|> | |
| <|im_start|>assistant | |
| <think> | |
| bpm: 60 | |
| duration: 195 | |
| keyscale: F minor | |
| language: en | |
| timesignature: 4 | |
| </think> | |
| <|im_end|> | |
| # Note the blank CoT (<think></think> section). | |
| # Caption is not a recommendation, just an example. | |
| lm_prompt_negative: | | |
| <|im_start|>system | |
| # Instruction | |
| Generate audio semantic tokens based on the given conditions: | |
| <|im_end|> | |
| <|im_start|>user | |
| # Caption | |
| AI slop. Low quality, boring, repetitive. | |
| # Lyric | |
| [Intro] | |
| [Verse] | |
| <|im_end|> | |
| <|im_start|>assistant | |
| <think> | |
| </think> | |
| <|im_end|> | |
| # When set, disables generate_audio_codes. Can be specified in two formats: | |
| # A list of integers, I.E. [1, 2, 3] | |
| # A string with audio code tokens, I.E. <|audio_code_36126|><|audio_code_36123|> | |
| audio_codes: null | |
| """ | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "clip": ("CLIP",), | |
| "yaml_text": ( | |
| "STRING", | |
| { | |
| "dynamicPrompts": False, | |
| "multiline": True, | |
| "default": cls._DEFAULT_YAML, | |
| }, | |
| ), | |
| }, | |
| "optional": { | |
| "lyrics_prompt": ( | |
| "STRING", | |
| { | |
| "defaultInput": True, | |
| "dynamicPrompts": False, | |
| "multiline": True, | |
| "tooltip": "Optional input. See example YAML for the correct format. Will overwrite the YAML key when populated.", | |
| }, | |
| ), | |
| "dit_prompt": ( | |
| "STRING", | |
| { | |
| "defaultInput": True, | |
| "dynamicPrompts": False, | |
| "multiline": True, | |
| "tooltip": "Optional input. See example YAML for the correct format. Will overwrite the YAML key when populated.", | |
| }, | |
| ), | |
| "lm_prompt": ( | |
| "STRING", | |
| { | |
| "defaultInput": True, | |
| "dynamicPrompts": False, | |
| "multiline": True, | |
| "tooltip": "Optional input. See example YAML for the correct format. Will overwrite the YAML key when populated.", | |
| }, | |
| ), | |
| "lm_prompt_negative": ( | |
| "STRING", | |
| { | |
| "defaultInput": True, | |
| "dynamicPrompts": False, | |
| "multiline": True, | |
| "tooltip": "Optional input. See example YAML for the correct format. Will overwrite the YAML key when populated.", | |
| }, | |
| ), | |
| "audio_codes": ( | |
| "STRING", | |
| { | |
| "defaultInput": True, | |
| "dynamicPrompts": False, | |
| "multiline": True, | |
| "tooltip": "Optional input. See example YAML for the correct format. Will overwrite the YAML key when populated. When attached, must either be a string of comma-separated integer audio code values or tokens as show in the example.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| clip: object, | |
| yaml_text: str, | |
| lyrics_prompt: str | None = None, | |
| dit_prompt: str | None = None, | |
| lm_prompt: str | None = None, | |
| lm_prompt_negative: str | None = None, | |
| audio_codes: str | None = None, | |
| ) -> tuple: | |
| parsed_metadata = yaml.safe_load(yaml_text) | |
| if not isinstance(parsed_metadata, dict): | |
| raise TypeError("yaml_metadata must be a YAML object") | |
| lm_metadata = parsed_metadata.get("lm_metadata", {}) | |
| if not isinstance(lm_metadata, dict): | |
| raise TypeError("lm_metadata key must be a YAML object") | |
| verbose = parsed_metadata.pop("verbose", False) | |
| lyrics_prompt, dit_prompt, lm_prompt, lm_prompt_negative = ( | |
| v if v is not None else parsed_metadata.pop(k, None) | |
| for k, v in ( | |
| ("lyrics_prompt", lyrics_prompt), | |
| ("dit_prompt", dit_prompt), | |
| ("lm_prompt", lm_prompt), | |
| ("lm_prompt_negative", lm_prompt_negative), | |
| ) | |
| ) | |
| if not all(isinstance(p, str) for p in (dit_prompt, lyrics_prompt)): | |
| raise ValueError( | |
| "At the least, dit_prompt and lyrics_prompt keys must be specified and be type string." | |
| ) | |
| audio_codes = ( | |
| audio_codes | |
| if audio_codes is not None | |
| else lm_metadata.pop("audio_codes", None) | |
| ) | |
| if isinstance(audio_codes, str): | |
| audio_codes = audio_codes.strip() | |
| if isinstance(audio_codes, (tuple, list)): | |
| if not all(isinstance(ac, int) for ac in audio_codes): | |
| raise TypeError( | |
| "When present, audio codes must be a list of integer values." | |
| ) | |
| elif isinstance(audio_codes, str): | |
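| # String audio codes may be given either in token form ("<|audio_code_123|>...") or as | |
| # a comma-separated list of integers; both are parsed into a plain list of ints. | |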
| if audio_codes.startswith("<|audio_code_"): | |
| audio_codes = [ | |
| int(ac.rsplit("_", 1)) | |
| for ac in audio_codes.split("|") | |
| if ac.startswith("audio_code_") | |
| ] | |
| else: | |
| audio_codes = [int(ac) for ac in audio_codes.split(",")] | |
| elif audio_codes is not None: | |
| raise TypeError("Bad format for specified audio codes") | |
| generate_audio_codes = ( | |
| lm_metadata.get("generate_audio_codes", False) | |
| and audio_codes is None | |
| and lm_prompt is not None | |
| and (lm_prompt_negative is not None or lm_metadata.get("cfg", None) == 1.0) | |
| ) | |
| lm_metadata["generate_audio_codes"] = generate_audio_codes | |
| allow_token_weights = frozenset(parsed_metadata.pop("allow_token_weights", ())) | |
| if not generate_audio_codes: | |
| lm_prompt = lm_prompt_negative = None | |
| tokenizer_kwargs = {} | |
| tokenizer_options = getattr(clip, "tokenizer_options", {}) | |
| if tokenizer_options: | |
| tokenizer_kwargs["tokenizer_options"] = tokenizer_options.copy() | |
| key_fixup = {"dit_prompt": "qwen3_06b", "lyrics_prompt": "lyrics"} | |
| prompts = ( | |
| ("lyrics_prompt", lyrics_prompt), | |
| ("dit_prompt", dit_prompt), | |
| ("lm_prompt", lm_prompt), | |
| ("lm_prompt_negative", lm_prompt_negative), | |
| ) | |
| encode_input = { | |
| key_fixup.get(k, k): clip.tokenizer.qwen3_06b.tokenize_with_weights( | |
| v, | |
| False, | |
| disable_weights=k not in allow_token_weights, | |
| **tokenizer_kwargs, | |
| ) | |
| for k, v in prompts | |
| if v is not None | |
| } | |
| if generate_audio_codes and "lm_prompt_negative" not in encode_input: | |
| encode_input["lm_prompt_negative"] = encode_input["lm_prompt"] | |
| encode_input["lm_metadata"] = lm_metadata | |
| if verbose: | |
| print( | |
| f"\n*** ACETricks RawTextEncodeAce15 debug:\n* PROMPTS: {dict(prompts)}\n* ENCODE INPUT: {encode_input}\n* AUDIO_CODES: {audio_codes}\n" | |
| ) | |
| conditioning = clip.encode_from_tokens_scheduled(encode_input) | |
| if audio_codes is not None: | |
| for _, d in conditioning: | |
| d["audio_codes"] = [list(audio_codes)] | |
| return (conditioning,) | |
| class Ace15LatentToAudioCodesNode: | |
| DESCRIPTION = "Extracts audio codes from an ACE-Steps 1.5 latent. This is probably not working correctly at the moment since the extracted codes don't work very well for sampling." | |
| FUNCTION = "go" | |
| CATEGORY = "audio/acetricks" | |
| RETURN_TYPES = ("CONDITIONING_ACE_LYRICS",) | |
| @classmethod | |
| def INPUT_TYPES(cls) -> dict: | |
| return { | |
| "required": { | |
| "model": ("MODEL",), | |
| "latent": ("LATENT",), | |
| }, | |
| "optional": { | |
| "include_upsampled_codes": ( | |
| "BOOLEAN", | |
| { | |
| "default": False, | |
| "tooltip": "There isn't anything that can use this yet.", | |
| }, | |
| ), | |
| }, | |
| } | |
| @staticmethod | |
| @torch.no_grad() | |
| def get_audio_codes( | |
| model: object, | |
| latent: torch.Tensor, | |
| *, | |
| upsample_codes: bool, | |
| ) -> tuple[list[int], torch.Tensor | None]: | |
| audio_codes, indices = model.tokenizer.tokenize(latent.mT) | |
| mask = (indices < 64000) & (indices >= 0) | |
| indices = indices[mask].reshape(latent.shape[0], -1).detach().cpu().tolist() | |
| upsampled_codes = ( | |
| model.detokenizer(audio_codes).mT.detach().float().cpu() | |
| if upsample_codes | |
| else None | |
| ) | |
| # print( | |
| # f"\nGOT:\nCODES={audio_codes}\nUPSAMPLED:{upsampled_codes}\nINDICES: {indices}" | |
| # ) | |
| return indices, upsampled_codes | |
| @classmethod | |
| def go( | |
| cls, | |
| *, | |
| model: object, | |
| latent: dict, | |
| include_upsampled_codes: bool = False, | |
| ) -> tuple[list[dict]]: | |
| samples = latent["samples"] | |
| if samples.ndim == 4 and samples.shape[-2] == 1: | |
| samples = samples.squeeze(-2) | |
| if samples.ndim != 3 or samples.shape[1] != 64: | |
| raise ValueError( | |
| "Incorrect latent format, doesn't appear to be an ACE-Steps 1.5 latent." | |
| ) | |
| model_management.load_model_gpu(model) | |
| mmodel = model.model | |
| device = mmodel.device | |
| dtype = mmodel.get_dtype() | |
| samples = samples.to(device=device, dtype=dtype) | |
| ac, uac = cls.get_audio_codes( | |
| mmodel.diffusion_model, | |
| samples, | |
| upsample_codes=include_upsampled_codes, | |
| ) | |
| result = {"audio_codes": ac} | |
| if uac is not None: | |
| result["upsampled_audio_codes"] = uac | |
| return ([result],) | |
| NODE_CLASS_MAPPINGS = { | |
| "ACETricks SilentLatent": SilentLatentNode, | |
| "ACETricks VisualizeLatent": VisualizeLatentNode, | |
| "ACETricks CondSplitOutLyrics": SplitOutLyricsNode, | |
| "ACETricks CondJoinLyrics": JoinLyricsNode, | |
| "ACETricks EncodeLyrics": EncodeLyricsNode, | |
| "ACETricks SetAudioDtype": SetAudioDtypeNode, | |
| "ACETricks AudioLevels": AudioLevelsNode, | |
| "ACETricks AudioAsLatent": AudioAsLatentNode, | |
| "ACETricks LatentAsAudio": LatentAsAudioNode, | |
| "ACETricks MonoToStereo": MonoToStereoNode, | |
| "ACETricks AudioBlend": AudioBlendNode, | |
| "ACETricks AudioFromBatch": AudioFromBatchNode, | |
| "ACETricks Mask": MaskNode, | |
| "ACETricks Time Offset": TimeOffsetNode, | |
| "ACETricks Waveform Image": WaveformNode, | |
| "ACETricks SqueezeUnsqueezeLatentDimension": SqueezeUnsqueezeLatentDimensionNode, | |
| "ACETricks ModelPatchAce15Use4dLatent": ModelPatchAce15Use4dLatentNode, | |
| "ACETricks EmptyAce15LatentFromConditioning": EmptyAce15LatentFromConditioningNode, | |
| "ACETricks Ace15CompressDuplicateAudioCodes": Ace15CompressDuplicateAudioCodesNode, | |
| "ACETricks TextEncodeAce15": TextEncodeAce15Node, | |
| "ACETricks RawTextEncodeAce15": RawTextEncodeAce15Node, | |
| "ACETricks Ace15LatentToAudioCodes": Ace15LatentToAudioCodesNode, | |
| } |