Skip to content

Instantly share code, notes, and snippets.

@gaspardpetit
Created December 28, 2023 02:40
Show Gist options
  • Select an option

  • Save gaspardpetit/b51ab30c8de0e638e3917823964be779 to your computer and use it in GitHub Desktop.

Select an option

Save gaspardpetit/b51ab30c8de0e638e3917823964be779 to your computer and use it in GitHub Desktop.
Diarization sample using NVIDIA NeMo
import os
import wget
import json
import logging
from omegaconf import OmegaConf
from pyannote.core import Annotation
import torch
from nemo.collections.asr.models import ClusteringDiarizer, NeuralDiarizer
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object
LOG = logging.getLogger(__name__)
class DiarizationService:
    """Speaker diarization using NVIDIA NeMo's ASR models.

    Runs a clustering-based diarization pass followed by a neural (MSDD)
    pass over a single audio file and returns the neural result as a
    pyannote ``Annotation``.

    Attributes:
        data_dir (str): Directory for storing data files (manifest, config).
        output_dir (str): Directory for intermediate files and predictions.
        input_manifest_file (str): Path to the input manifest file.
        inference_config_file (str): Path to the inference configuration file.
        config (OmegaConf): Configuration object for diarization.
        predicted_rttm_file (str | None): Path to the predicted RTTM file;
            set once a manifest has been created for an audio file.
    """

    def __init__(self):
        root = os.getcwd()
        self.data_dir = os.path.join(root, 'data')
        os.makedirs(self.data_dir, exist_ok=True)
        self.output_dir = os.path.join(root, 'outputs')
        os.makedirs(self.output_dir, exist_ok=True)
        # configure
        self.input_manifest_file = os.path.join(self.data_dir, 'input_manifest.json')
        self.inference_config_file = os.path.join(self.data_dir, 'diar_infer_meeting.yaml')
        self.config = self._configure(
            self.input_manifest_file, self.inference_config_file, self.output_dir)
        # Set by _create_manifest() once the target audio file is known.
        self.predicted_rttm_file = None

    def _create_manifest(self, audio_path: str) -> dict:
        """Create and write the NeMo input manifest for one audio file.

        Also records where NeMo will write the predicted RTTM for this file
        (``<output_dir>/pred_rttms/<name>.rttm``) in ``self.predicted_rttm_file``.

        Args:
            audio_path (str): Path to the input audio file.

        Returns:
            dict: The manifest entry that was written (one JSON line).
        """
        audio_file_name_no_ext = os.path.splitext(os.path.basename(audio_path))[0]
        self.predicted_rttm_file = f'{self.output_dir}/pred_rttms/{audio_file_name_no_ext}.rttm'
        LOG.debug("##### create manifest")
        # NOTE(review): num_speakers is hard-coded to 5 here, yet _configure()
        # sets oracle_num_speakers=False, so the clustering should estimate the
        # count itself — confirm whether this field is actually honored.
        meta = {
            'audio_filepath': audio_path,
            'offset': 0,
            'duration': None,  # None lets NeMo use the full file length
            'label': 'infer',
            'text': '-',
            'num_speakers': 5,
            'rttm_filepath': self.predicted_rttm_file,
            'uem_filepath': None
        }
        # NeMo expects a JSON-lines manifest: one JSON object per line.
        with open(self.input_manifest_file, 'w') as fp:
            json.dump(meta, fp)
            fp.write('\n')
        return meta

    @staticmethod
    def _get_device() -> torch.device:
        """Retrieve the device (CPU or CUDA) to use for diarization.

        Returns:
            torch.device: CUDA device when available, CPU otherwise.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        LOG.info("Using device: %s", device.type)
        return device

    @staticmethod
    def _load_model_config(model_name, inference_config_file: str) -> OmegaConf:
        """Load the model configuration, downloading it if not present locally.

        Args:
            model_name (str): File name of the config in the NeMo repository.
            inference_config_file (str): Local path where the config is expected.

        Returns:
            OmegaConf: Parsed model configuration.
        """
        LOG.debug("##### LOAD MODEL")
        if not os.path.exists(inference_config_file):
            inference_config_dir = os.path.dirname(inference_config_file)
            config_url = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{model_name}"
            # wget.download returns the path of the downloaded file.
            inference_config_file = wget.download(config_url, inference_config_dir)
        config = OmegaConf.load(inference_config_file)
        LOG.debug(OmegaConf.to_yaml(config))
        return config

    @staticmethod
    def _configure(input_manifest_file, inference_config_file, output_dir) -> OmegaConf:
        """Build the diarizer configuration used by both diarization passes.

        Args:
            input_manifest_file (str): Path to the input manifest file.
            inference_config_file (str): Path to the inference configuration file.
            output_dir (str): Directory for intermediate files and predictions.

        Returns:
            OmegaConf: Configuration object for diarization.
        """
        LOG.debug("##### CONFIGURE")
        config = DiarizationService._load_model_config('diar_infer_meeting.yaml', inference_config_file)
        config.device = DiarizationService._get_device().type
        # Estimate the speaker count rather than trusting the manifest value.
        config.diarizer.clustering.parameters.oracle_num_speakers = False
        config.diarizer.manifest_filepath = input_manifest_file
        config.diarizer.msdd_model.model_path = 'diar_msdd_telephonic'
        config.diarizer.msdd_model.parameters.sigmoid_threshold = [0.7, 1.0]
        # Use the VAD model below instead of ground-truth speech segments.
        config.diarizer.oracle_vad = False
        config.diarizer.out_dir = output_dir
        config.diarizer.speaker_embeddings.model_path = 'titanet_large'
        # Five equally-weighted scales; windows/shifts are paired per scale.
        config.diarizer.speaker_embeddings.parameters.multiscale_weights = [1, 1, 1, 1, 1]
        config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = [0.75, 0.625, 0.5, 0.375, 0.1]
        config.diarizer.speaker_embeddings.parameters.window_length_in_sec = [1.5, 1.25, 1.0, 0.75, 0.5]
        config.diarizer.vad.model_path = 'vad_multilingual_marblenet'
        config.diarizer.vad.parameters.offset = 0.6
        config.diarizer.vad.parameters.onset = 0.8
        config.diarizer.vad.parameters.pad_offset = -0.05
        config.num_workers = 1
        return config

    def _diarize_cluster(self) -> Annotation:
        """Perform diarization using the clustering-based diarizer.

        Reads the manifest referenced by ``self.config`` and writes its
        prediction to ``self.predicted_rttm_file``.

        Returns:
            Annotation: Diarization result parsed from the predicted RTTM.
        """
        sd_model = ClusteringDiarizer(cfg=self.config).to(self.config.device)
        LOG.debug("Clustering Diarizer...")
        sd_model.diarize()
        LOG.debug("Clustering Diarizer... Done")
        pred_labels_neural = rttm_to_labels(self.predicted_rttm_file)
        hypothesis_neural: Annotation = labels_to_pyannote_object(pred_labels_neural)
        LOG.debug("Clustering Diarizer Result (RTTM format)\n%s", hypothesis_neural)
        return hypothesis_neural

    def _diarize_neural(self) -> Annotation:
        """Perform diarization using the neural (MSDD) diarizer.

        Overwrites ``self.predicted_rttm_file`` with its own prediction.

        Returns:
            Annotation: Diarization result parsed from the predicted RTTM.
        """
        system_vad_msdd_model = NeuralDiarizer(cfg=self.config).to(self.config.device)
        LOG.debug("Neural Diarizer...")
        system_vad_msdd_model.diarize()
        LOG.debug("Neural Diarizer... Done")
        pred_labels_neural = rttm_to_labels(self.predicted_rttm_file)
        hypothesis_neural: Annotation = labels_to_pyannote_object(pred_labels_neural)
        LOG.info("Neural Diarizer Result (RTTM format)\n%s", hypothesis_neural)
        return hypothesis_neural

    def diarize(self, audio_path) -> Annotation:
        """Perform diarization on the given audio file.

        Args:
            audio_path (str): Path to the input audio file.

        Returns:
            Annotation: Neural diarization result, with ``uri`` set to the
            audio file's base name (no extension).
        """
        self._create_manifest(audio_path)
        # NOTE(review): the clustering pass result is discarded and its RTTM
        # is overwritten by the neural pass below; it is kept here because the
        # original sample ran both passes — confirm whether it can be dropped.
        self._diarize_cluster()
        neural_diarization = self._diarize_neural()
        audio_file_name_no_ext = os.path.splitext(os.path.basename(audio_path))[0]
        neural_diarization.uri = audio_file_name_no_ext
        return neural_diarization
def main():
    """Diarize a sample audio file and save the result as an RTTM file."""
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d][%(funcName)s] %(message)s',
        datefmt='%Y-%m-%dT%H:%M:%SZ',
    )
    audio_file = "test.wav"
    diarization: Annotation = DiarizationService().diarize(audio_file)
    # Log one line per speaker turn in start/stop/speaker form.
    turns = [
        f"start={segment.start:.1f}s stop={segment.end:.1f}s speaker_{label}\n"
        for segment, _, label in diarization.itertracks(yield_label=True)
    ]
    LOG.info("".join(turns))
    rttm_name = diarization.uri + ".rttm"
    LOG.info(f"saving to {rttm_name}")
    with open(rttm_name, "w") as rttm:
        diarization.write_rttm(rttm)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment