Frank's little neural network library
import math
from collections import namedtuple
from typing import Any, Callable, Dict, Optional
import json

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models  # VGG and MobileNet backbones used below
#
# CONFIG
#

Config = Dict[str, Any]

registered_types: Dict[str, Callable] = {}

def register_type(name: str, ctor: Callable):
    registered_types[name] = ctor

def object_from_config(config: Config, **kwargs) -> Any:
    if not isinstance(config, dict):
        raise ValueError("Config must be a dictionary.")
    t = config["type"] if "type" in config else kwargs.get("type")
    if t in registered_types:
        new_config = dict(config)
        new_config.update(kwargs)
        new_kwargs = dict(new_config)
        del new_kwargs["type"]
        new_object = registered_types[t](**new_kwargs)
        # Remember the config so save_module can persist it alongside the weights.
        new_object.__config = new_config
        return new_object
    else:
        raise ValueError(f"Unknown config object type: {t}")

def save_config(config: Config, filename: str):
    with open(filename, "w") as f:
        json.dump(config, f, indent=4)

def load_config(filename: str) -> Config:
    with open(filename, "r") as f:
        return json.load(f)

def load_config_object(filename: str) -> Any:
    return object_from_config(load_config(filename))

def save_module(obj: nn.Module, filename: str):
    out_dict = {
        "config": obj.__config,
        "state_dict": obj.state_dict(),
    }
    torch.save(out_dict, filename)

def load_module(filename: str, **kwargs) -> nn.Module:
    in_dict = torch.load(filename)
    obj = object_from_config(in_dict["config"], **kwargs)
    obj.load_state_dict(in_dict["state_dict"])
    return obj
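
# Hedged usage sketch (not part of the library): round-trip an object through
# the config registry and the save/load helpers above. File names are
# placeholders; ResBlock is registered later in this file.
def _demo_config():
    config = {"type": "ResBlock", "in_channels": 8, "out_channels": 16,
              "normalization": "GroupNorm8", "activation": "SiLU"}
    block = object_from_config(config)     # builds any registered type
    save_config(config, "resblock.json")   # config alone as JSON
    block2 = load_config_object("resblock.json")
    save_module(block, "resblock.pt")      # config + state_dict together
    block3 = load_module("resblock.pt")
    return block2, block3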
#
# NN
#

def get_normalization(normalization: str, num_channels: int) -> nn.Module:
    if normalization == "None":
        return nn.Identity()
    if normalization == "BatchNorm1d":
        return nn.BatchNorm1d(num_channels)
    if normalization == "BatchNorm2d":
        return nn.BatchNorm2d(num_channels)
    if normalization == "InstanceNorm1d":
        return nn.InstanceNorm1d(num_channels)
    if normalization == "InstanceNorm2d":
        return nn.InstanceNorm2d(num_channels)
    if normalization == "GroupNorm32":
        return nn.GroupNorm(num_groups=32, num_channels=num_channels)
    if normalization == "GroupNorm24":
        return nn.GroupNorm(num_groups=24, num_channels=num_channels)
    if normalization == "GroupNorm16":
        return nn.GroupNorm(num_groups=16, num_channels=num_channels)
    if normalization == "GroupNorm8":
        return nn.GroupNorm(num_groups=8, num_channels=num_channels)
    raise NotImplementedError(f"Unknown normalization: {normalization}")
def get_activation(activation: str) -> nn.Module:
    if activation is None or activation == "None":
        return nn.Identity()
    if activation == "LeakyReLU":
        return nn.LeakyReLU()
    if activation == "ReLU":
        return nn.ReLU()
    if activation == "SiLU":
        return nn.SiLU()
    if activation == "Tanh" or activation == "tanh":
        return nn.Tanh()
    if activation == "Sigmoid":
        return nn.Sigmoid()
    if activation == "Softmax":
        return nn.Softmax(dim=1)  # softmax over the channel dimension
    if activation == "GELU" or activation == "gelu":
        return nn.GELU()
    raise NotImplementedError(f"Unknown activation: {activation}")
def zero_module(module: nn.Module, should_zero: bool = True):
    """
    Zero out the parameters of a module and return it.
    Zero-initializing the final layer of a residual branch makes the
    enclosing block behave as an identity mapping at the start of training.
    """
    if should_zero:
        for p in module.parameters():
            p.detach().zero_()
    return module
class ResBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        normalization: str,
        activation: str,
        input_kernel_size=3,
        kernel_size=3,
        zero_out=True,
    ):
        super().__init__()
        self.in_layers = nn.Sequential(
            get_normalization(normalization, in_channels),
            get_activation(activation),
            nn.Conv2d(in_channels, out_channels, kernel_size=input_kernel_size, padding=input_kernel_size//2),
        )
        self.out_layers = nn.Sequential(
            get_normalization(normalization, out_channels),
            get_activation(activation),
            zero_module(
                nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2),
                should_zero=zero_out,
            ),
        )
        if out_channels == in_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        h = self.in_layers(x)
        h = self.out_layers(h)
        skip = self.skip_connection(x)
        out = skip + h
        return out

register_type("ResBlock", ResBlock)
class UNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        model_channels: int,
        ch_mult: list[int],
        normalization: str,
        activation: str,
        num_res_blocks: int,
        output_activation: str,
        in_kernel_size: int,
        out_kernel_size: int,
        inject_decoder_noise_channels: int,
    ):
        super().__init__()
        self.dtype = torch.float32
        self.model_channels = model_channels
        self.inject_decoder_noise_channels = inject_decoder_noise_channels
        self.input_downsample_count = len(ch_mult)
        self.input_block = nn.Sequential(
            nn.Conv2d(in_channels, model_channels, kernel_size=in_kernel_size, padding=in_kernel_size//2),
        )
        enc_blocks = []
        dec_blocks = []
        noise_blocks = []
        down_blocks = []
        up_blocks = []
        ch = model_channels

        def res_block(in_ch, out_ch):
            blocks = [ResBlock(in_ch, out_ch, normalization=normalization, activation=activation)]
            for _ in range(num_res_blocks - 1):
                blocks.append(ResBlock(out_ch, out_ch, normalization=normalization, activation=activation))
            return nn.Sequential(*blocks)

        for i, ch_m in enumerate(ch_mult):
            out_ch = model_channels * ch_m
            enc_blocks.append(res_block(ch, out_ch))
            skip_ch = out_ch
            prev_ch = 0 if i == len(ch_mult) - 1 else model_channels * ch_mult[i + 1]
            if i < len(ch_mult) - 1:
                down_blocks.append(nn.AvgPool2d(2))
            dec_blocks.append(res_block(skip_ch + prev_ch + inject_decoder_noise_channels, out_ch))
            if inject_decoder_noise_channels > 0:
                noise_blocks.append(ScaledImageNoise(inject_decoder_noise_channels))
            should_upsample = i > 0
            up_blocks.append(nn.UpsamplingBilinear2d(scale_factor=2) if should_upsample else nn.Identity())
            ch = out_ch
        self.enc_blocks = nn.ModuleList(enc_blocks)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.dec_blocks = nn.ModuleList(dec_blocks)
        self.up_blocks = nn.ModuleList(up_blocks)
        self.noise_blocks = nn.ModuleList(noise_blocks)
        self.output_block = nn.Sequential(
            res_block(model_channels, model_channels),
            get_normalization(normalization, model_channels),
            get_activation(activation),
            zero_module(nn.Conv2d(model_channels, out_channels, kernel_size=out_kernel_size, padding=out_kernel_size//2)),
            get_activation(output_activation),
        )

    def forward(self, x: Tensor):
        h = self.input_block(x)
        res_hs = []
        for i, enc_block in enumerate(self.enc_blocks):
            h = enc_block(h)
            res_hs.append(h)
            if i < len(self.down_blocks):
                h = self.down_blocks[i](h)
        i = len(self.dec_blocks) - 1
        while i >= 0:
            if i < len(self.dec_blocks) - 1:
                items = [h, res_hs[i]]
                if self.inject_decoder_noise_channels > 0:
                    items.append(self.noise_blocks[i](h))
                h = torch.cat(items, dim=1)
                h = self.dec_blocks[i](h)
            else:
                if self.inject_decoder_noise_channels > 0:
                    h = torch.cat([h, self.noise_blocks[i](h)], dim=1)
                h = self.dec_blocks[i](h)
            h = self.up_blocks[i](h)
            i -= 1
        y = self.output_block(h)
        return y

register_type("UNet", UNet)
class ScaledImageNoise(nn.Module):
    """
    Produces fresh Gaussian noise scaled per channel by a learned 1x1 conv.
    The conv is zero-initialized, so the injected noise starts as a no-op.
    """
    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        self.linear = zero_module(nn.Conv2d(out_channels, out_channels, kernel_size=1))

    def forward(self, x: Tensor) -> Tensor:
        b, c, h, w = x.shape
        noise = torch.randn(b, self.out_channels, h, w, device=x.device, dtype=x.dtype)
        noise = self.linear(noise)
        return noise

register_type("ScaledImageNoise", ScaledImageNoise)
#
# LOSS
#

def simple_loss(loss_type, pred, target):
    if loss_type == "mse":
        return torch.nn.functional.mse_loss(pred, target)
    elif loss_type == "mae" or loss_type == "l1":
        return torch.nn.functional.l1_loss(pred, target)
    else:
        raise NotImplementedError(f"Unknown loss_type: {loss_type}")

class UncertaintyLoss(nn.Module):
    """
    u_loss = 1/(2*sigma**2)*loss + log(sigma)
    u_loss = 1 / (2 * torch.exp(2 * log_sigma)) * loss + log_sigma
    u_loss = 0.5 * torch.exp(-2 * log_sigma) * loss + log_sigma
    lambda = 0.5 * torch.exp(-2 * log_sigma)
    log(2 * lambda) = -2 * log_sigma
    log_sigma = -0.5 * log(2 * lambda)
    """
    def __init__(self, init_lambda: float = 1.0, enabled: bool = True):
        super().__init__()
        self.init_lambda = max(init_lambda, 1.0e-9)
        self.enabled = enabled and init_lambda > 0
        init_log_sigma = -0.5 * math.log(2.0 * self.init_lambda)
        self.log_sigma = nn.Parameter(torch.tensor(init_log_sigma))

    @property
    def lambda_value(self) -> Tensor:
        return 0.5 * torch.exp(-2.0 * self.log_sigma)

    def forward(self, loss: Tensor) -> Tensor:
        if self.enabled:
            lambda_ = self.lambda_value
            return lambda_ * loss + self.log_sigma
        else:
            return self.init_lambda * loss
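
# Hedged sketch: weight two task losses with learned uncertainty. Each
# UncertaintyLoss contributes lambda * loss + log_sigma, and the log_sigma
# parameters train alongside the model (register them with the optimizer).
def _demo_uncertainty():
    w_rec = UncertaintyLoss(init_lambda=1.0)
    w_edge = UncertaintyLoss(init_lambda=0.1)
    rec_loss = torch.tensor(0.5)
    edge_loss = torch.tensor(2.0)
    total = w_rec(rec_loss) + w_edge(edge_loss)
    return total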
class SobelEdgeDetection(nn.Module):
    def __init__(self):
        super().__init__()
        # Sobel kernels for x and y directions
        self.register_buffer('kernel_x', torch.tensor([
            [-1.0, 0.0, 1.0],
            [-2.0, 0.0, 2.0],
            [-1.0, 0.0, 1.0]
        ]).view(1, 1, 3, 3))
        self.register_buffer('kernel_y', torch.tensor([
            [-1.0, -2.0, -1.0],
            [0.0, 0.0, 0.0],
            [1.0, 2.0, 1.0]
        ]).view(1, 1, 3, 3))

    def forward(self, x):
        """
        Apply Sobel edge detection to input images, preserving color channels.
        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W) with values in range [-1, 1]
        Returns:
            torch.Tensor: Edge magnitude map of shape (B, C, H, W)
        """
        batch_size, channels, h, w = x.shape
        # Reflect-pad so the output keeps the input's spatial dimensions
        padded = F.pad(x, (1, 1, 1, 1), mode='reflect')
        # Fold channels into the batch dimension so each channel is filtered independently
        grad_x = F.conv2d(
            padded.view(batch_size * channels, 1, h+2, w+2),
            self.kernel_x,
            groups=1
        )
        grad_y = F.conv2d(
            padded.view(batch_size * channels, 1, h+2, w+2),
            self.kernel_y,
            groups=1
        )
        # Compute gradient magnitude per channel
        magnitude = torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-10)  # small epsilon for numerical stability
        # Reshape back to batch, channel, height, width
        magnitude = magnitude.view(batch_size, channels, h, w)
        return magnitude
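
# Hedged sketch: an edge-consistency loss built from the Sobel module,
# comparing edge magnitude maps of a prediction and its target.
def _demo_edge_loss(pred: Tensor, target: Tensor) -> Tensor:
    sobel = SobelEdgeDetection()
    return F.l1_loss(sobel(pred), sobel(target))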
class LPIPS(nn.Module):
    """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self):
        ckpt = "vgg.pth"  # expects a local checkpoint of the learned linear heads
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = "vgg.pth"
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val
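
# Hedged sketch: LPIPS as a perceptual loss term. Construction requires the
# local "vgg.pth" checkpoint for the linear heads (see load_from_pretrained)
# and downloads torchvision's VGG16 weights; inputs are expected in [-1, 1].
def _demo_lpips(pred: Tensor, target: Tensor) -> Tensor:
    lpips = LPIPS().eval()
    return lpips(pred, target).mean()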
class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale

class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
        self.model = nn.Sequential(*layers)

class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward_latents(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out

    def forward(self, X):
        return self.forward_latents(X)

class vgg16_short(nn.Module):
    def __init__(self, requires_grad=True, pretrained=True):
        super(vgg16_short, self).__init__()
        vgg = vgg16(requires_grad=requires_grad, pretrained=pretrained)
        self.features = nn.Sequential(
            vgg.slice1,
            vgg.slice2,
            vgg.slice3,
            vgg.slice4,
            # vgg.slice5,
        )

    def forward(self, x):
        return self.features(x)

def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)

def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
class PatchDiscriminator(nn.Module):
    """
    Discriminator based on a pretrained backbone, outputting a real/fake logit per patch.
    """
    def __init__(
        self,
        backbone,  # name of a built-in backbone, or a Config dict for a registered type
        freeze_backbone: bool,
    ):
        super().__init__()
        self.needs_imagenet_scale = True
        if backbone == "vgg16":
            model = torchvision.models.vgg16(pretrained=True)
            num_features = 512
            self.backbone = model.features
        elif backbone == "vgg16_short":
            num_features = 512
            self.backbone = vgg16_short()
        elif backbone == "mobilenet_v3_small":
            model = torchvision.models.mobilenet_v3_small(pretrained=True)
            num_features = 576
            self.backbone = model.features
        elif backbone == "mobilenet_v3_large":
            model = torchvision.models.mobilenet_v3_large(pretrained=True)
            num_features = 960
            self.backbone = model.features
        else:
            self.needs_imagenet_scale = False
            self.backbone = object_from_config(backbone)
            num_features = self.backbone.out_channels
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
        self.output_layer = nn.Sequential(
            zero_module(nn.Conv2d(num_features, 1, kernel_size=1, padding=0)),
        )
        self.register_buffer("imagenet_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).contiguous())
        self.register_buffer("imagenet_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).contiguous())

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, C, H, W) with values in [-1, 1]
        Returns:
            Tensor of shape (B, 1, H', W') of per-patch logits,
            where higher values indicate real and lower values fake
        """
        if self.needs_imagenet_scale:
            x = (x + 1) / 2
            x = (x - self.imagenet_mean) / self.imagenet_std
        x = self.backbone(x)
        x = self.output_layer(x)
        return x

register_type("PatchDiscriminator", PatchDiscriminator)
class MultiScalePatchDiscriminator(nn.Module):
    """
    Multi-resolution PatchGAN discriminator that operates at multiple scales.
    """
    def __init__(
        self,
        num_scales: int,
        downsample_mode: str,
        backbone: str,
        freeze_backbone: bool,
        num_skip_res: int,
    ):
        super().__init__()
        self.num_scales = num_scales
        self.num_skip_res = num_skip_res
        self.downsample_mode = downsample_mode
        self.discriminators = nn.ModuleList([
            PatchDiscriminator(
                backbone=backbone,
                freeze_backbone=freeze_backbone
            ) for _ in range(num_scales)
        ])

    def downsample(self, x):
        """Downsample input for next scale"""
        if self.downsample_mode == "avgpool":
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        elif self.downsample_mode == "bilinear":
            return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        else:
            raise ValueError(f"Unknown downsample mode: {self.downsample_mode}")

    def forward(self, x):
        """
        Returns predictions from all discriminators at different scales.
        Args:
            x: Input image tensor of shape (B, C, H, W)
        Returns:
            List of prediction maps from each discriminator
        """
        results = []
        input_downsampled = x
        for i, d in enumerate(self.discriminators):
            # skip num_skip_res resolutions before each scale's prediction
            for _ in range(self.num_skip_res):
                input_downsampled = self.downsample(input_downsampled)
            results.append(d(input_downsampled))
            if i < self.num_scales - 1:
                input_downsampled = self.downsample(input_downsampled)
        return results

register_type("MultiScalePatchDiscriminator", MultiScalePatchDiscriminator)
def discriminator_hinge_loss(real_preds, fake_preds):
    """
    Hinge loss for discriminator.
    Args:
        real_preds: List of prediction tensors for real images
        fake_preds: List of prediction tensors for fake images
    """
    loss = 0.0
    for real_pred, fake_pred in zip(real_preds, fake_preds):
        real_loss = torch.mean(F.relu(1.0 - real_pred))
        fake_loss = torch.mean(F.relu(1.0 + fake_pred))
        loss += (real_loss + fake_loss) * 0.5
    return loss / len(real_preds)

def generator_hinge_loss(fake_preds):
    """
    Hinge loss for generator.
    Args:
        fake_preds: List of prediction tensors for fake images
    """
    loss = 0.0
    for fake_pred in fake_preds:
        loss += -torch.mean(fake_pred)
    return loss / len(fake_preds)
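
# Hedged sketch of one GAN update with the pieces above: the discriminator
# sees detached fakes for its hinge loss, while the generator maximizes the
# discriminator's score on fresh fakes. Optimizer steps are omitted.
def _demo_gan_losses(disc: MultiScalePatchDiscriminator, gen: nn.Module,
                     real: Tensor, noisy: Tensor):
    fake = gen(noisy)
    d_loss = discriminator_hinge_loss(disc(real), disc(fake.detach()))
    g_loss = generator_hinge_loss(disc(fake))
    return d_loss, g_loss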
#
# MODEL
#

class DenoiseModel(nn.Module):
    def __init__(
        self,
        input_width: int,
        input_height: int,
        model_channels: int,
        ch_mult: list[int],
        normalization: str,
        activation: str,
        num_res_blocks: int,
        output_upsample_count: int,
        sample_batch_size: int,
        train_batch_size: int,
        train_minibatch_size: int,
        learning_rate: float,
        lpips_lambda: float,
        edge_lambda: float,
        rec_loss_type: str,
        train_dtype: str,
        discr_learning_rate: float = 0.0001,
        train_images: int = 1_000_000,
        gan_lambda: float = 0.1,
        gan_start_images: int = 0,
        gan_loss_start_images: int = 500,
        in_channels: int = 3,
        out_channels: int = 3,
        output_activation: str = "tanh",
        latent_model: Optional[str] = None,
    ):
        super().__init__()
        self.dtype = torch.float32
        self.input_width = input_width
        self.input_height = input_height
        self.model_channels = model_channels
        self.output_upsample_count = output_upsample_count
        self.output_width = input_width * (2 ** output_upsample_count)
        self.output_height = input_height * (2 ** output_upsample_count)
        self.input_downsample_count = len(ch_mult)
        self.train_images = train_images
        self.train_batch_size = train_batch_size
        self.train_minibatch_size = train_minibatch_size
        self.sample_batch_size = sample_batch_size
        self.learning_rate = learning_rate
        self.discr_learning_rate = discr_learning_rate
        self.lpips_lambda = lpips_lambda
        self.edge_lambda = edge_lambda
        self.gan_lambda = gan_lambda
        self.gan_start_images = gan_start_images
        self.gan_loss_start_images = gan_loss_start_images
        self.rec_loss_type = rec_loss_type
        self.input_block = nn.Sequential(
            nn.Conv2d(in_channels, model_channels, kernel_size=5, padding=2),
        )
        enc_blocks = []
        dec_blocks = []
        down_blocks = []
        up_blocks = []
        ch = model_channels

        def res_block(in_ch, out_ch):
            blocks = [ResBlock(in_ch, out_ch, normalization=normalization, activation=activation)]
            for _ in range(num_res_blocks - 1):
                blocks.append(ResBlock(out_ch, out_ch, normalization=normalization, activation=activation))
            return nn.Sequential(*blocks)

        for i, ch_m in enumerate(ch_mult):
            out_ch = model_channels * ch_m
            enc_blocks.append(res_block(ch, out_ch))
            skip_ch = out_ch
            prev_ch = 0 if i == len(ch_mult) - 1 else model_channels * ch_mult[i + 1]
            if i < len(ch_mult) - 1:
                down_blocks.append(nn.AvgPool2d(2))
            dec_blocks.append(res_block(skip_ch + prev_ch, out_ch))
            should_upsample = i > 0 or output_upsample_count > 0
            up_blocks.append(nn.UpsamplingBilinear2d(scale_factor=2) if should_upsample else nn.Identity())
            ch = out_ch
        self.enc_blocks = nn.ModuleList(enc_blocks)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.dec_blocks = nn.ModuleList(dec_blocks)
        self.up_blocks = nn.ModuleList(up_blocks)
        if output_upsample_count not in [0, 1]:
            raise NotImplementedError("Only output_upsample_count=0,1 is supported.")
        self.output_block = nn.Sequential(
            res_block(model_channels, model_channels),
            get_normalization(normalization, model_channels),
            get_activation(activation),
            nn.Conv2d(model_channels, out_channels, kernel_size=3, padding=1),
            nn.Tanh() if output_activation == "tanh" else nn.Identity(),
        )

    def forward(self, x):
        h = self.input_block(x)
        res_hs = []
        for i, enc_block in enumerate(self.enc_blocks):
            h = enc_block(h)
            res_hs.append(h)
            if i < len(self.down_blocks):
                h = self.down_blocks[i](h)
        i = len(self.dec_blocks) - 1
        while i >= 0:
            if i < len(self.dec_blocks) - 1:
                h = self.dec_blocks[i](torch.cat([h, res_hs[i]], dim=1))
            else:
                h = self.dec_blocks[i](h)
            h = self.up_blocks[i](h)
            i -= 1
        y = self.output_block(h)
        return y

register_type("DenoiseModel", DenoiseModel)