@shhommychon
Last active September 11, 2025 15:19
Speech upscaling utilities I keep around for my own use
.cursorignore
LICENSE
*.ipynb
.aiexclude
LICENSE
*.ipynb

my-speech-upscale-webui

This project provides an all-in-one speech upscaling Google Colab notebook that anyone can use without a complicated installation process. Even without any coding or AI expertise, you can try state-of-the-art speech upscaling with a single click.

This project was inspired by other successful WebUI projects that maximize user convenience, such as AUTOMATIC1111/stable-diffusion-webui, oobabooga/text-generation-webui, liltom-eth/llama2-webui, jhj0517/Whisper-WebUI, rsxdalv/TTS-WebUI, and gitmylo/audio-webui.

✨ Key Features

  • Easy to use: runs directly in Google Colab, with no separate program installation.
  • Latest models integrated: proven, high-performance speech upscaling models gathered in one place.
  • Automated environment: everything you need is prepared automatically at the press of a button.

🚀 How to Use

  1. Click the Open In Colab badge to open the Google Colab notebook.
  2. Simply run the notebook cells in order from top to bottom.
  3. After a short wait the WebUI appears, and you can upscale your own audio files.
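
For orientation, the first cell of the notebook might look roughly like the sketch below. The package list is an assumption based on the imports that appear later in this gist, and the NU-Wave 2 repository URL is given only as an example.

# Hypothetical first Colab cell (a sketch, not the actual notebook).
# Install the Python packages used by the scripts in this gist.
!pip install -q audiosr pyloudnorm soundfile librosa omegaconf
# NU-Wave 2 is not distributed on PyPI; its repository would be cloned separately.
!git clone https://github.com/maum-ai/nuwave2.git  # repository URL is an assumption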

🧠 Included Models

  • NU-Wave 2 (Seoul Natl. Univ.; 2022.06) (ArXiv, GitHub): a diffusion model that uses STFC and BSFT to handle a variety of input sampling rates with a single model.
  • AudioSR (Univ. of Surrey, UC San Diego, ByteDance; 2023.09) (ArXiv, GitHub): a diffusion-based super-resolution model that can be applied to general audio, including sound effects, music, and speech.
  • FlowHigh (Korea Univ.; 2025.01) (ArXiv, GitHub): a flow-matching-based model that targets efficient audio generation with single-step sampling.
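
Before the full scripts below, here is a minimal sketch of how AudioSR is driven through its Python API. build_model and super_resolution are the same calls used in the predictor script later in this gist; the file paths and parameter values here are placeholders.

# Minimal AudioSR sketch (paths are placeholders; see the full chunked pipeline below).
import soundfile as sf
from audiosr import build_model, super_resolution

model = build_model(model_name="speech", device="auto")  # "speech" or "basic"
out = super_resolution(
    model,
    "input_16khz.wav",        # placeholder input path
    guidance_scale=3.5,
    ddim_steps=50,
    latent_t_per_second=12.8,
)
# AudioSR generates 48 kHz audio; out[0] is the waveform for the first (only) item.
sf.write("output_48khz.wav", out[0].squeeze(), 48000)
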
import gc
import os
import random
import numpy as np
from scipy.signal.windows import hann
import soundfile as sf
import torch
from cog import BasePredictor, Input, Path
import tempfile
import argparse
import librosa
from audiosr import build_model, super_resolution
from scipy import signal
import pyloudnorm as pyln

import warnings
warnings.filterwarnings("ignore")

os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")


def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
    # Trim or zero-pad array_1 so that its length matches array_2 (mono or multi-channel).
    if (len(array_1.shape) == 1) and (len(array_2.shape) == 1):
        if array_1.shape[0] > array_2.shape[0]:
            array_1 = array_1[:array_2.shape[0]]
        elif array_1.shape[0] < array_2.shape[0]:
            array_1 = np.pad(array_1, ((array_2.shape[0] - array_1.shape[0], 0)), 'constant', constant_values=0)
    else:
        if array_1.shape[1] > array_2.shape[1]:
            array_1 = array_1[:, :array_2.shape[1]]
        elif array_1.shape[1] < array_2.shape[1]:
            padding = array_2.shape[1] - array_1.shape[1]
            array_1 = np.pad(array_1, ((0, 0), (0, padding)), 'constant', constant_values=0)
    return array_1


def lr_filter(audio, cutoff, filter_type, order=12, sr=48000):
    # Zero-phase Butterworth low-/high-pass filter, used for the multiband crossover.
    audio = audio.T
    nyquist = 0.5 * sr
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order // 2, normal_cutoff, btype=filter_type, analog=False)
    sos = signal.tf2sos(b, a)
    filtered_audio = signal.sosfiltfilt(sos, audio)
    return filtered_audio.T


class Predictor(BasePredictor):
    def setup(self, model_name="basic", device="auto"):
        self.model_name = model_name
        self.device = device
        self.sr = 48000
        print("Loading Model...")
        self.audiosr = build_model(model_name=self.model_name, device=self.device)
        # print(self.audiosr)
        # exit()
        print("Model loaded!")

    def process_audio(self, input_file, chunk_size=5.12, overlap=0.1, seed=None, guidance_scale=3.5, ddim_steps=50):
        # Note: input_cutoff, multiband_ensemble and crossover_freq are module-level
        # globals set in the __main__ block below.
        audio, sr = librosa.load(input_file, sr=input_cutoff * 2, mono=False)
        audio = audio.T
        sr = input_cutoff * 2

        print(f"audio.shape = {audio.shape}")
        print(f"input cutoff = {input_cutoff}")

        is_stereo = len(audio.shape) == 2
        audio_channels = [audio] if not is_stereo else [audio[:, 0], audio[:, 1]]
        print("Audio is stereo" if is_stereo else "Audio is mono")

        chunk_samples = int(chunk_size * sr)
        overlap_samples = int(overlap * chunk_samples)
        output_chunk_samples = int(chunk_size * self.sr)
        output_overlap_samples = int(overlap * output_chunk_samples)
        enable_overlap = overlap > 0
        print(f"enable_overlap = {enable_overlap}")

        def process_chunks(audio):
            # Split one channel into (optionally overlapping) fixed-size chunks,
            # zero-padding the last chunk and remembering the original lengths.
            chunks = []
            original_lengths = []
            start = 0
            while start < len(audio):
                end = min(start + chunk_samples, len(audio))
                chunk = audio[start:end]
                if len(chunk) < chunk_samples:
                    original_lengths.append(len(chunk))
                    chunk = np.concatenate([chunk, np.zeros(chunk_samples - len(chunk))])
                else:
                    original_lengths.append(chunk_samples)
                chunks.append(chunk)
                start += chunk_samples - overlap_samples if enable_overlap else chunk_samples
            return chunks, original_lengths

        # Process both channels (mono or stereo)
        chunks_per_channel = [process_chunks(channel) for channel in audio_channels]
        sample_rate_ratio = self.sr / sr
        total_length = len(chunks_per_channel[0][0]) * output_chunk_samples - (len(chunks_per_channel[0][0]) - 1) * (output_overlap_samples if enable_overlap else 0)
        reconstructed_channels = [np.zeros((1, total_length)) for _ in audio_channels]

        meter_before = pyln.Meter(sr)
        meter_after = pyln.Meter(self.sr)

        # Process chunks for each channel
        for ch_idx, (chunks, original_lengths) in enumerate(chunks_per_channel):
            for i, chunk in enumerate(chunks):
                loudness_before = meter_before.integrated_loudness(chunk)
                print(f"Processing chunk {i+1} of {len(chunks)} for {'Left/Mono' if ch_idx == 0 else 'Right'} channel")
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
                    sf.write(temp_wav.name, chunk, sr)
                    out_chunk = super_resolution(
                        self.audiosr,
                        temp_wav.name,
                        seed=seed,
                        guidance_scale=guidance_scale,
                        ddim_steps=ddim_steps,
                        latent_t_per_second=12.8
                    )
                    out_chunk = out_chunk[0]
                    num_samples_to_keep = int(original_lengths[i] * sample_rate_ratio)
                    out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()

                    # Match the loudness of the upscaled chunk to the original chunk.
                    loudness_after = meter_after.integrated_loudness(out_chunk)
                    out_chunk = pyln.normalize.loudness(out_chunk, loudness_after, loudness_before)

                    if enable_overlap:
                        # Crossfade chunk boundaries to avoid clicks at the seams.
                        actual_overlap_samples = min(output_overlap_samples, num_samples_to_keep)
                        fade_out = np.linspace(1., 0., actual_overlap_samples)
                        fade_in = np.linspace(0., 1., actual_overlap_samples)

                        if i == 0:
                            out_chunk[-actual_overlap_samples:] *= fade_out
                        elif i < len(chunks) - 1:
                            out_chunk[:actual_overlap_samples] *= fade_in
                            out_chunk[-actual_overlap_samples:] *= fade_out
                        else:
                            out_chunk[:actual_overlap_samples] *= fade_in

                # Overlap-add the processed chunk into the output buffer.
                start = i * (output_chunk_samples - output_overlap_samples if enable_overlap else output_chunk_samples)
                end = start + out_chunk.shape[0]
                reconstructed_channels[ch_idx][0, start:end] += out_chunk.flatten()

        reconstructed_audio = np.stack(reconstructed_channels, axis=-1) if is_stereo else reconstructed_channels[0]

        if multiband_ensemble:
            # Multiband ensemble: keep the original content below the crossover frequency
            # and take only the newly generated content above it.
            low, _ = librosa.load(input_file, sr=48000, mono=False)
            output = match_array_shapes(reconstructed_audio[0].T, low)
            low = lr_filter(low.T, crossover_freq, 'lowpass', order=10)
            high = lr_filter(output.T, crossover_freq, 'highpass', order=10)
            high = lr_filter(high, 23000, 'lowpass', order=2)
            output = low + high
        else:
            output = reconstructed_audio[0]
        # print(output, type(output))
        return output

    def predict(
        self,
        input_file: Path = Input(description="Audio to upsample"),
        ddim_steps: int = Input(description="Number of inference steps", default=50, ge=10, le=500),
        guidance_scale: float = Input(description="Scale for classifier free guidance", default=3.5, ge=1.0, le=20.0),
        overlap: float = Input(description="overlap size", default=0.04),
        chunk_size: float = Input(description="chunk size", default=10.24),
        seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None)
    ) -> Path:
        if seed is None or seed == 0:
            # Treat a blank seed or 0 as "randomize".
            seed = random.randint(0, 2**32 - 1)
        print(f"Setting seed to: {seed}")
        print(f"overlap = {overlap}")
        print(f"guidance_scale = {guidance_scale}")
        print(f"ddim_steps = {ddim_steps}")
        print(f"chunk_size = {chunk_size}")
        print(f"multiband_ensemble = {multiband_ensemble}")
        print(f"input file = {os.path.basename(input_file)}")

        os.makedirs(output_folder, exist_ok=True)
        waveform = self.process_audio(
            input_file,
            chunk_size=chunk_size,
            overlap=overlap,
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps
        )
        filename = os.path.splitext(os.path.basename(input_file))[0]
        sf.write(f"{output_folder}/SR_{filename}.wav", data=waveform, samplerate=48000, subtype="PCM_16")
        print(f"file created: {output_folder}/SR_{filename}.wav")

        # Free the model and intermediate buffers.
        del self.audiosr, waveform
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="AudioSR inference with chunked processing and an optional multiband ensemble.")
    parser.add_argument("--input", help="Path to input audio file")
    parser.add_argument("--output", help="Output folder")
    # Add a --model_name argument so the model can be selected (default: 'speech').
    parser.add_argument("--model_name", help="Model name: speech or basic", type=str, required=False, default="speech", choices=["speech", "basic"])  # ★ added
    parser.add_argument("--ddim_steps", help="Number of ddim steps", type=int, required=False, default=50)
    parser.add_argument("--chunk_size", help="chunk size", type=float, required=False, default=10.24)
    parser.add_argument("--guidance_scale", help="Guidance scale value", type=float, required=False, default=3.5)
    parser.add_argument("--seed", help="Seed value, 0 = random seed", type=int, required=False, default=0)
    parser.add_argument("--overlap", help="overlap value", type=float, required=False, default=0.04)
    # Note: argparse's type=bool treats any non-empty string as True; omit the flag entirely to disable.
    parser.add_argument("--multiband_ensemble", type=bool, help="Use multiband ensemble with input")
    parser.add_argument("--input_cutoff", help="Define the crossover of audio input in the multiband ensemble", type=int, required=False, default=12000)
    args = parser.parse_args()

    # Module-level globals used inside Predictor.process_audio / predict.
    input_file_path = args.input
    output_folder = args.output
    ddim_steps = args.ddim_steps
    chunk_size = args.chunk_size
    guidance_scale = args.guidance_scale
    seed = args.seed
    overlap = args.overlap
    input_cutoff = args.input_cutoff
    multiband_ensemble = args.multiband_ensemble
    crossover_freq = input_cutoff - 1000

    p = Predictor()
    # Pass the model_name from the command line to setup(); the default is the speech model
    # ('haoheliu/audiosr_speech'). (See audiosr/utils.py#L393.)
    p.setup(model_name=args.model_name, device='auto')  # ★ added
    out = p.predict(
        input_file_path,
        ddim_steps=ddim_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        chunk_size=chunk_size,
        overlap=overlap
    )
    del p
    gc.collect()
    torch.cuda.empty_cache()
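
For reference, the script above would typically be invoked from the command line roughly as follows; the script filename audiosr_chunked.py and the file paths are placeholders, not names defined in this gist.

python audiosr_chunked.py --input input_24khz.wav --output results --model_name speech --ddim_steps 50 --guidance_scale 3.5 --chunk_size 10.24 --overlap 0.04 --multiband_ensemble True --input_cutoff 12000

Because --multiband_ensemble is declared with type=bool, argparse turns any non-empty value into True; omit the flag entirely to skip the multiband blend.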

[ASCII banner: GEMINI]

Gemini-CLI Development Guide: Speech Upscaling Colab Notebook

1. Project Goal

The core goal of this project is to develop an all-in-one speech upscaling Google Colab notebook that users without any coding or AI expertise can use with ease.

Users should be able to take advantage of state-of-the-art speech upscaling models simply by running a single Colab file, without any complicated environment setup. This follows the philosophy of other successful WebUI projects that maximize user convenience.

The project focuses on curating and integrating proven, high-performance open-source models and providing a user-friendly interface.

2. Technical Summary

  • Runtime environment: Google Colab
    • Chosen to avoid requiring a local setup and to maximize accessibility.
  • Core models: the following three state-of-the-art speech upscaling models are integrated.
    • NU-Wave 2: a diffusion model that uses STFC and BSFT to handle a variety of input sampling rates with a single model
    • AudioSR: a diffusion-based super-resolution model that can be applied to general audio, including sound effects, music, and speech
    • FlowHigh: a flow-matching-based model that targets efficient audio generation with single-step sampling

3. Implementation Goals in Detail

  • Single notebook deliverable: the final artifact is a single .ipynb file containing all environment setup, model download, and execution logic. A visitor must be able to use every feature with this one file.
  • Automated environment setup: when the user simply runs the notebook cells in order, all required libraries and dependencies are installed automatically.
  • Model integration and code modernization: review and update the NU-Wave 2, AudioSR, and FlowHigh code so that it runs smoothly, without errors, in the current Google Colab Python environment.
  • User-requested features: incorporate detailed features proposed by the user during development, e.g. a naive routine that reads an audio file and guesses the highest frequency it actually contains (a sketch of such a helper follows below).
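
As an illustration of the helper mentioned in the last item, here is a minimal sketch; the function name, FFT size, and dB threshold are assumptions, not part of the project.

import numpy as np
import librosa

def estimate_effective_bandwidth(path, db_drop=40.0):
    """Naively guess the highest frequency that actually carries energy.

    Computes an average magnitude spectrum and returns the highest frequency
    whose level is within `db_drop` dB of the spectral peak. The threshold is
    an arbitrary assumption; real recordings may need tuning.
    """
    y, sr = librosa.load(path, sr=None, mono=True)
    spec = np.abs(librosa.stft(y, n_fft=2048))
    mean_db = librosa.amplitude_to_db(spec.mean(axis=1), ref=np.max)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
    active = np.where(mean_db > -db_drop)[0]
    return freqs[active[-1]] if len(active) else 0.0

# Example: decide whether a file is worth upscaling to 48 kHz.
# cutoff = estimate_effective_bandwidth("voice_sample.wav")  # placeholder path
# print(f"estimated bandwidth ~ {cutoff:.0f} Hz")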

4. Gemini's Role

As the project's AI coding assistant, Gemini takes on the following responsibilities.

  • Code authoring and maintenance: write the core Python code the project needs (file handling, state management, etc.), focusing on scripts optimized for the Colab environment.
  • Model code modernization: take the lead in modifying and updating the source code of the three core models (NU-Wave 2, AudioSR, FlowHigh) so that it is compatible with current libraries and the Colab environment.
  • Feature implementation: implement exactly the features the user explicitly requests (extras, UI logic, etc.).
  • Debugging and troubleshooting: identify runtime errors, platform compatibility issues, and logic bugs, and propose fixes.
  • Work efficiency: to save tokens, all core logic except the UI is developed against my-speech-upscale-webui.py, the script version of the .ipynb. The assistant is not given read access to my-speech-upscale-webui.ipynb.
MOST OF IT IS NOT MY CODE PUBLIC LICENSE
[ASCII art]
Man, I stole your code It's not my code
Version 0.1.0, May 2023
Copyright (c) 2023 Chon, Sung Hyu
Everyone is permitted to copy, distribute, and modify this work as long as the
original licenses of this work are maintained.
Most of the contents of this work are under MIT, GPL, or some other major public
licenses. Use this work at your own risk.
If copyright infringement occurs, the author will pretend that the author does
not know anything, does not intend to defend you, and will claim that it's not
the author's fault.
MOST OF IT IS NOT MY CODE PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
1. You do what the original authors tell you that you can do.
2. If the original license or author is not mentioned in some parts of code,
those codes are probably written by the author. You can use those parts of
code however you would like to, but keep in mind that the author has very
little experience in programming. Do not blame the author when you use those
codes for some serious applications and something goes wrong.
from lightning_model import NuWave2
from omegaconf import OmegaConf as OC
import os
import argparse
import datetime
from glob import glob
import torch
import librosa as rosa
from scipy.io.wavfile import write as swrite
import matplotlib.pyplot as plt
from utils.stft import STFTMag
import numpy as np
from scipy.signal import sosfiltfilt
from scipy.signal import butter, cheby1, cheby2, ellip, bessel
from scipy.signal import resample_poly
import random


def save_stft_mag(wav, fname):
    # Save a log-magnitude spectrogram of `wav` as a PNG (uses the module-level `stft`).
    fig = plt.figure(figsize=(9, 3))
    plt.imshow(rosa.amplitude_to_db(stft(wav[0].detach().cpu()).numpy(),
                                    ref=np.max, top_db=80.),
               aspect='auto',
               origin='lower',
               interpolation='none')
    plt.colorbar()
    plt.xlabel('Frames')
    plt.ylabel('Channels')
    plt.tight_layout()
    fig.savefig(fname, format='png')
    plt.close()
    return


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--checkpoint',
                        type=str,
                        required=True,
                        help="Checkpoint path")
    parser.add_argument('-i',
                        '--wav',
                        type=str,
                        default=None,
                        help="audio")
    parser.add_argument('--sr',
                        type=int,
                        required=True,
                        help="Sampling rate of input audio")
    parser.add_argument('--steps',
                        type=int,
                        required=False,
                        help="Steps for sampling")
    parser.add_argument('--gt', action="store_true",
                        required=False, help="Whether the input audio is 48 kHz ground truth audio.")
    parser.add_argument('--device',
                        type=str,
                        default='cuda',
                        required=False,
                        help="Device, 'cuda' or 'cpu'")
    args = parser.parse_args()
    # torch.backends.cudnn.benchmark = False

    hparams = OC.load('nuwave2/hparameter.yaml')
    os.makedirs(hparams.log.test_result_dir, exist_ok=True)
    if args.steps is None or args.steps == 8:
        args.steps = 8
        noise_schedule = eval(hparams.dpm.infer_schedule)
    else:
        noise_schedule = None

    model = NuWave2(hparams).to(args.device)
    model.eval()
    stft = STFTMag()

    # In newer PyTorch releases torch.load defaults to weights_only=True for security reasons.
    # The official checkpoint is trusted, so set weights_only=False to keep the previous behavior.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)  # ★ weights_only=False made explicit
    model.load_state_dict(ckpt['state_dict'] if not ('EMA' in args.checkpoint) else ckpt)

    highcut = args.sr // 2
    nyq = 0.5 * hparams.audio.sampling_rate
    hi = highcut / nyq

    if args.gt:
        wav, _ = rosa.load(args.wav, sr=hparams.audio.sampling_rate, mono=True)
        wav /= np.max(np.abs(wav))
        wav = wav[:len(wav) - len(wav) % hparams.audio.hop_length]

        order = 8
        sos = cheby1(order, 0.05, hi, btype='lowpass', output='sos')
        wav_l = sosfiltfilt(sos, wav)

        # downsample to the low sampling rate
        wav_l = resample_poly(wav_l, highcut * 2, hparams.audio.sampling_rate)
        # upsample to the original sampling rate
        wav_l = resample_poly(wav_l, hparams.audio.sampling_rate, highcut * 2)

        if len(wav_l) < len(wav):
            # pad the low-band signal (not the reference) so the lengths match
            wav_l = np.pad(wav_l, (0, len(wav) - len(wav_l)), 'constant', constant_values=0)
        elif len(wav_l) > len(wav):
            wav_l = wav_l[:len(wav)]
    else:
        wav, _ = rosa.load(args.wav, sr=args.sr, mono=True)
        wav /= np.max(np.abs(wav))
        # upsample to the original sampling rate
        wav_l = resample_poly(wav, hparams.audio.sampling_rate, args.sr)
        wav_l = wav_l[:len(wav_l) - len(wav_l) % hparams.audio.hop_length]

    fft_size = hparams.audio.filter_length // 2 + 1
    band = torch.zeros(fft_size, dtype=torch.int64)
    band[:int(hi * fft_size)] = 1

    wav = torch.from_numpy(wav).unsqueeze(0).to(args.device)
    wav_l = torch.from_numpy(wav_l.copy()).float().unsqueeze(0).to(args.device)
    band = band.unsqueeze(0).to(args.device)

    wav_recon, wav_list = model.inference(wav_l, band, args.steps, noise_schedule)

    wav = torch.clamp(wav, min=-1, max=1 - torch.finfo(torch.float16).eps)
    save_stft_mag(wav, os.path.join(hparams.log.test_result_dir, f'wav.png'))
    if args.gt:
        swrite(os.path.join(hparams.log.test_result_dir, f'wav.wav'),
               hparams.audio.sampling_rate, wav[0].detach().cpu().numpy())
    else:
        swrite(os.path.join(hparams.log.test_result_dir, f'wav.wav'),
               args.sr, wav[0].detach().cpu().numpy())

    wav_l = torch.clamp(wav_l, min=-1, max=1 - torch.finfo(torch.float16).eps)
    save_stft_mag(wav_l, os.path.join(hparams.log.test_result_dir, f'wav_l.png'))
    swrite(os.path.join(hparams.log.test_result_dir, f'wav_l.wav'),
           hparams.audio.sampling_rate, wav_l[0].detach().cpu().numpy())

    wav_recon = torch.clamp(wav_recon, min=-1, max=1 - torch.finfo(torch.float16).eps)
    save_stft_mag(wav_recon, os.path.join(hparams.log.test_result_dir, f'result.png'))
    swrite(os.path.join(hparams.log.test_result_dir, f'result.wav'),
           hparams.audio.sampling_rate, wav_recon[0].detach().cpu().numpy())

    # for i in range(len(wav_list)):
    #     wav_recon_i = torch.clamp(wav_list[i], min=-1, max=1-torch.finfo(torch.float16).eps)
    #     save_stft_mag(wav_recon_i, os.path.join(hparams.log.test_result_dir, f'result_{i}.png'))
    #     swrite(os.path.join(hparams.log.test_result_dir, f'result_{i}.wav'),
    #            hparams.audio.sampling_rate, wav_recon_i[0].detach().cpu().numpy())
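
For reference, this inference script would typically be run from a directory containing the nuwave2/ checkout (it loads nuwave2/hparameter.yaml relative to the working directory). The script filename and the checkpoint/audio paths below are placeholders:

python nuwave2_inference.py -c checkpoints/nuwave2.ckpt -i input_16khz.wav --sr 16000 --steps 8 --device cuda

Pass --gt only when the input is already 48 kHz ground-truth audio; in that case the script band-limits and resamples it itself before reconstruction.
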
# Some codes are adopted from
# https://github.com/ivanvovk/WaveGrad
# https://github.com/lmnt-com/diffwave
# https://github.com/NVlabs/SPADE
# https://github.com/pkumivision/FFC

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from math import sqrt, log

Linear = nn.Linear
silu = F.silu
relu = F.relu


def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


def Conv2d(*args, **kwargs):
    layer = nn.Conv2d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class DiffusionEmbedding(nn.Module):
    def __init__(self, hparams):
        super().__init__()
        self.n_channels = hparams.dpm.pos_emb_channels
        self.linear_scale = hparams.dpm.pos_emb_scale
        self.out_channels = hparams.arch.pos_emb_dim
        self.projection1 = Linear(self.n_channels, self.out_channels)
        self.projection2 = Linear(self.out_channels, self.out_channels)

    def forward(self, noise_level):
        if len(noise_level.shape) > 1:
            noise_level = noise_level.squeeze(-1)
        half_dim = self.n_channels // 2
        emb = log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32).to(noise_level) * -emb)
        emb = self.linear_scale * noise_level.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        emb = self.projection1(emb)
        emb = silu(emb)
        emb = self.projection2(emb)
        emb = silu(emb)
        return emb


class BSFT(nn.Module):
    def __init__(self, nhidden, out_channels):
        super().__init__()
        self.mlp_shared = nn.Conv1d(2, nhidden, kernel_size=3, padding=1)
        self.mlp_gamma = Conv1d(nhidden, out_channels, kernel_size=3, padding=1)
        self.mlp_beta = Conv1d(nhidden, out_channels, kernel_size=3, padding=1)

    def forward(self, x, band):
        # band: (B, 2, n_fft // 2 + 1)
        actv = silu(self.mlp_shared(band))
        gamma = self.mlp_gamma(actv).unsqueeze(-1)
        beta = self.mlp_beta(actv).unsqueeze(-1)
        # apply scale and bias
        out = x * (1 + gamma) + beta
        return out


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, bsft_channels, filter_length=1024, hop_length=256, win_length=1024,
                 sampling_rate=48000):
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.sampling_rate = sampling_rate
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        hann_window = torch.hann_window(win_length)
        self.register_buffer('hann_window', hann_window)
        self.conv_layer = Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                 kernel_size=1, padding=0, bias=False)
        self.bsft = BSFT(bsft_channels, out_channels * 2)

    def forward(self, x, band):
        batch = x.shape[0]
        x = x.view(-1, x.size()[-1])

        # stft/istft handling updated for compatibility with recent PyTorch versions.
        # 1. torch.stft returns a complex tensor (return_complex=True).
        ffted = torch.stft(
            x, self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=True,
            normalized=True,
            onesided=True,
            return_complex=True,  # ★ changed from False to True
        )
        # 2. Convert the complex tensor to its real representation for the convolution
        #    (adds a trailing dimension of size 2 for the real/imaginary parts).
        ffted = torch.view_as_real(ffted)  # ★ added
        ffted = ffted.permute(0, 3, 1, 2).contiguous()  # (BC, 2, n_fft/2+1, T)
        ffted = ffted.view((batch, -1,) + ffted.size()[2:])  # (B, 2C, n_fft/2+1, T)

        ffted = relu(self.bsft(ffted, band))  # (B, 2C, n_fft/2+1, T)
        ffted = self.conv_layer(ffted)

        # 3. Convert back to a complex tensor for istft.
        ffted = ffted.view((-1, 2,) + ffted.size()[2:]).permute(0, 2, 3, 1).contiguous()  # (BC, n_fft/2+1, T, 2)
        ffted = torch.view_as_complex(ffted)  # ★ added
        output = torch.istft(ffted, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
                             center=True, normalized=True, onesided=True)
        output = output.view(batch, -1, x.size()[-1])
        return output


class SpectralTransform(nn.Module):
    def __init__(self, in_channels, out_channels, bsft_channels, **audio_kwargs):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.conv1 = Conv1d(
            in_channels, out_channels // 2, kernel_size=1, bias=False)
        self.fu = FourierUnit(out_channels // 2, out_channels // 2, bsft_channels, **audio_kwargs)
        self.conv2 = Conv1d(
            out_channels // 2, out_channels, kernel_size=1, bias=False)

    def forward(self, x, band):
        x = silu(self.conv1(x))
        output = self.fu(x, band)
        output = self.conv2(x + output)
        return output


class FFC(nn.Module):  # STFC
    def __init__(self, in_channels, out_channels, bsft_channels, kernel_size=3,
                 ratio_gin=0.5, ratio_gout=0.5, padding=1,
                 **audio_kwargs):
        super(FFC, self).__init__()
        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        self.convl2l = Conv1d(in_cl, out_cl, kernel_size, padding=padding, bias=False)
        self.convl2g = Conv1d(in_cl, out_cg, kernel_size, padding=padding, bias=False)
        self.convg2l = Conv1d(in_cg, out_cl, kernel_size, padding=padding, bias=False)
        self.convg2g = SpectralTransform(in_cg, out_cg, bsft_channels, **audio_kwargs)

    def forward(self, x_l, x_g, band):
        out_xl = self.convl2l(x_l) + self.convg2l(x_g)
        out_xg = self.convl2g(x_l) + self.convg2g(x_g, band)
        return out_xl, out_xg


class ResidualBlock(nn.Module):
    def __init__(self, residual_channels, pos_emb_dim, bsft_channels, **audio_kwargs):
        super().__init__()
        self.ffc1 = FFC(residual_channels, 2 * residual_channels, bsft_channels,
                        kernel_size=3, ratio_gin=0.5, ratio_gout=0.5, padding=1, **audio_kwargs)  # STFC
        self.diffusion_projection = Linear(pos_emb_dim, residual_channels)
        self.output_projection = Conv1d(residual_channels,
                                        2 * residual_channels, 1)

    def forward(self, x, band, noise_level):
        noise_level = self.diffusion_projection(noise_level).unsqueeze(-1)
        y = x + noise_level

        y_l, y_g = torch.split(y, [y.shape[1] - self.ffc1.global_in_num, self.ffc1.global_in_num], dim=1)
        y_l, y_g = self.ffc1(y_l, y_g, band)  # STFC

        gate_l, filter_l = torch.chunk(y_l, 2, dim=1)
        gate_g, filter_g = torch.chunk(y_g, 2, dim=1)
        gate, filter = torch.cat((gate_l, gate_g), dim=1), torch.cat((filter_l, filter_g), dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class NuWave2(nn.Module):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.input_projection = Conv1d(2, hparams.arch.residual_channels, 1)
        self.diffusion_embedding = DiffusionEmbedding(hparams)

        audio_kwargs = dict(filter_length=hparams.audio.filter_length, hop_length=hparams.audio.hop_length,
                            win_length=hparams.audio.win_length, sampling_rate=hparams.audio.sampling_rate)
        self.residual_layers = nn.ModuleList([
            ResidualBlock(hparams.arch.residual_channels,
                          hparams.arch.pos_emb_dim,
                          hparams.arch.bsft_channels,
                          **audio_kwargs)
            for i in range(hparams.arch.residual_layers)
        ])
        self.len_res = len(self.residual_layers)
        self.skip_projection = Conv1d(hparams.arch.residual_channels,
                                      hparams.arch.residual_channels, 1)
        self.output_projection = Conv1d(hparams.arch.residual_channels, 1, 1)

    def forward(self, audio, audio_low, band, noise_level):
        x = torch.stack((audio, audio_low), dim=1)
        x = self.input_projection(x)
        x = silu(x)

        noise_level = self.diffusion_embedding(noise_level)
        band = F.one_hot(band).transpose(1, -1).float()

        # Summing the skip connections in place is faster than stacking them.
        # skip = []
        skip = 0.
        for layer in self.residual_layers:
            x, skip_connection = layer(x, band, noise_level)
            # skip.append(skip_connection)
            skip += skip_connection

        # x = torch.sum(torch.stack(skip), dim=0) / sqrt(self.len_res)
        x = skip / sqrt(self.len_res)
        x = self.skip_projection(x)
        x = silu(x)
        x = self.output_projection(x).squeeze(1)
        return x
import torch
import torch.nn as nn
import torch.nn.functional as F


class STFTMag(nn.Module):
    def __init__(self,
                 nfft=1024,
                 hop=256):
        super().__init__()
        self.nfft = nfft
        self.hop = hop
        self.register_buffer('window', torch.hann_window(nfft), False)

    # x: [B, T] or [T]
    @torch.no_grad()
    def forward(self, x):
        T = x.shape[-1]
        # In recent PyTorch versions torch.stft expects return_complex=True,
        # and the magnitude is then computed with torch.abs().
        # The old path (return_complex=False) used torch.norm(stft, p=2, dim=-1) instead.
        stft = torch.stft(
            x,
            self.nfft,
            self.hop,
            window=self.window,
            return_complex=True,  # ★ return_complex=True made explicit
        )
        mag = torch.abs(stft)  # ★ magnitude via torch.abs()
        return mag
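
A quick usage sketch for the STFTMag module above; the shapes in the comments are what these nfft/hop values produce for one second of 48 kHz audio.

# Assumes torch and the STFTMag class defined above are available.
x = torch.randn(2, 48000)               # [batch, samples], e.g. 1 s of 48 kHz audio
stft_mag = STFTMag(nfft=1024, hop=256)
mag = stft_mag(x)                       # magnitude spectrogram
print(mag.shape)                        # torch.Size([2, 513, 188]): [batch, nfft // 2 + 1, frames]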