Created
December 28, 2023 02:40
-
-
Save gaspardpetit/b51ab30c8de0e638e3917823964be779 to your computer and use it in GitHub Desktop.
Diarization sample using NVIDIA NeMo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import wget | |
| import json | |
| import logging | |
| from omegaconf import OmegaConf | |
| from pyannote.core import Annotation | |
| import torch | |
| from nemo.collections.asr.models import ClusteringDiarizer, NeuralDiarizer | |
| from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object | |
# Module-level logger, named after this module so handlers/levels can target it.
LOG = logging.getLogger(name=__name__)
class DiarizationService:
    """
    Speaker diarization of a single audio file using NVIDIA NeMo's ASR models.

    On construction the service creates ``data/`` and ``outputs/`` under the
    current working directory and builds the inference configuration. If
    ``data/diar_infer_meeting.yaml`` is not present, it is downloaded from the
    NeMo GitHub repository.

    Attributes:
        data_dir (str): Directory for storing data files (manifest, config).
        output_dir (str): Directory where NeMo writes intermediate files and predictions.
        input_manifest_file (str): Path to the input manifest file.
        inference_config_file (str): Path to the inference configuration file.
        config (OmegaConf): Configuration object for diarization.
        predicted_rttm_file (str): Path to the predicted RTTM file of the most
            recent ``diarize`` call (``None`` before the first call).
    """

    def __init__(self):
        root = os.getcwd()
        self.data_dir = os.path.join(root, 'data')
        os.makedirs(self.data_dir, exist_ok=True)
        self.output_dir = os.path.join(root, 'outputs')
        os.makedirs(self.output_dir, exist_ok=True)

        # configure
        self.input_manifest_file = os.path.join(self.data_dir, 'input_manifest.json')
        self.inference_config_file = os.path.join(self.data_dir, 'diar_infer_meeting.yaml')
        self.config = self._configure(self.input_manifest_file, self.inference_config_file, self.output_dir)
        self.predicted_rttm_file = None

    def _create_manifest(self, audio_path: str, num_speakers=None, reference_rttm=None) -> dict:
        """
        Write a single-entry NeMo manifest for ``audio_path``.

        Also records in ``self.predicted_rttm_file`` the path
        ``<output_dir>/pred_rttms/<audio stem>.rttm`` from which the
        prediction is read back afterwards.

        Args:
            audio_path (str): Path to the input audio file.
            num_speakers (int, optional): Known number of speakers. Only used
                when ``oracle_num_speakers`` is enabled in the config.
            reference_rttm (str, optional): Ground-truth RTTM to score the
                prediction against; ``None`` for plain inference.

        Returns:
            dict: The manifest entry that was written.
        """
        audio_file_name_no_ext = os.path.splitext(os.path.basename(audio_path))[0]
        self.predicted_rttm_file = os.path.join(self.output_dir, 'pred_rttms', f'{audio_file_name_no_ext}.rttm')
        LOG.debug("##### create manifest")
        meta = {
            'audio_filepath': audio_path,
            'offset': 0,
            'duration': None,
            'label': 'infer',
            'text': '-',
            'num_speakers': num_speakers,
            # In NeMo's manifest format 'rttm_filepath' is the *reference* annotation
            # used for scoring. It must not point at our own prediction output,
            # otherwise the prediction is scored against itself (or a stale earlier run).
            'rttm_filepath': reference_rttm,
            'uem_filepath': None,
        }
        with open(self.input_manifest_file, 'w') as fp:
            json.dump(meta, fp)
            fp.write('\n')
        return meta

    @staticmethod
    def _get_device() -> torch.device:
        """
        Pick the device for diarization: CUDA when available, otherwise CPU.

        Returns:
            torch.device: Device for diarization.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        LOG.info("Using device: %s", device.type)
        return device

    @staticmethod
    def _load_model_config(model_name, inference_config_file: str) -> OmegaConf:
        """
        Load the inference configuration, downloading it first if it is missing.

        Args:
            model_name (str): File name of the config in NeMo's
                ``examples/speaker_tasks/diarization/conf/inference`` directory.
            inference_config_file (str): Local path of the inference configuration file.

        Returns:
            OmegaConf: Model configuration.
        """
        LOG.debug("##### LOAD MODEL")
        if not os.path.exists(inference_config_file):
            config_url = (
                "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/"
                f"speaker_tasks/diarization/conf/inference/{model_name}"
            )
            # Download straight to the requested path. Downloading into the directory
            # would save under the URL's file name, which matches only by coincidence
            # and would otherwise trigger a re-download on every run.
            inference_config_file = wget.download(config_url, out=inference_config_file)
        config = OmegaConf.load(inference_config_file)
        LOG.debug(OmegaConf.to_yaml(config))
        return config

    @staticmethod
    def _configure(input_manifest_file, inference_config_file, output_dir) -> OmegaConf:
        """
        Build the diarization configuration from the meeting inference config.

        Args:
            input_manifest_file (str): Path to the input manifest file.
            inference_config_file (str): Path to the inference configuration file.
            output_dir (str): Directory for storing intermediate files and prediction outputs.

        Returns:
            OmegaConf: Configuration object for diarization.
        """
        LOG.debug("##### CONFIGURE")
        config = DiarizationService._load_model_config('diar_infer_meeting.yaml', inference_config_file)
        config.device = DiarizationService._get_device().type
        # Default: estimate the speaker count; diarize() enables oracle mode per call.
        config.diarizer.clustering.parameters.oracle_num_speakers = False
        config.diarizer.manifest_filepath = input_manifest_file
        config.diarizer.msdd_model.model_path = 'diar_msdd_telephonic'
        config.diarizer.msdd_model.parameters.sigmoid_threshold = [0.7, 1.0]
        config.diarizer.oracle_vad = False
        config.diarizer.out_dir = output_dir
        config.diarizer.speaker_embeddings.model_path = 'titanet_large'
        config.diarizer.speaker_embeddings.parameters.multiscale_weights = [1, 1, 1, 1, 1]
        config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = [0.75, 0.625, 0.5, 0.375, 0.1]
        config.diarizer.speaker_embeddings.parameters.window_length_in_sec = [1.5, 1.25, 1.0, 0.75, 0.5]
        config.diarizer.vad.model_path = 'vad_multilingual_marblenet'
        config.diarizer.vad.parameters.offset = 0.6
        config.diarizer.vad.parameters.onset = 0.8
        config.diarizer.vad.parameters.pad_offset = -0.05
        config.num_workers = 1
        return config

    def _read_predicted_annotation(self) -> Annotation:
        """Load ``self.predicted_rttm_file`` as a pyannote Annotation."""
        pred_labels = rttm_to_labels(self.predicted_rttm_file)
        return labels_to_pyannote_object(pred_labels)

    def _diarize_cluster(self) -> Annotation:
        """
        Perform diarization using the clustering-based diarizer.

        Returns:
            Annotation: Diarization result.
        """
        sd_model = ClusteringDiarizer(cfg=self.config).to(self.config.device)
        LOG.debug("Clustering Diarizer...")
        sd_model.diarize()
        LOG.debug("Clustering Diarizer... Done")
        hypothesis = self._read_predicted_annotation()
        LOG.debug("Clustering Diarizer Result (RTTM format)\n%s", hypothesis)
        return hypothesis

    def _diarize_neural(self) -> Annotation:
        """
        Perform diarization using the neural (MSDD) diarizer.

        Returns:
            Annotation: Diarization result.
        """
        system_vad_msdd_model = NeuralDiarizer(cfg=self.config).to(self.config.device)
        LOG.debug("Neural Diarizer...")
        system_vad_msdd_model.diarize()
        LOG.debug("Neural Diarizer... Done")
        hypothesis = self._read_predicted_annotation()
        LOG.info("Neural Diarizer Result (RTTM format)\n%s", hypothesis)
        return hypothesis

    def diarize(self, audio_path, *, num_speakers=None, reference_rttm=None, run_clustering=True) -> Annotation:
        """
        Perform diarization on the given audio file.

        Args:
            audio_path (str): Path to the input audio file.
            num_speakers (int, optional): Known number of speakers. When given,
                oracle speaker counting is enabled; otherwise the count is estimated.
            reference_rttm (str, optional): Ground-truth RTTM for NeMo to score against.
            run_clustering (bool): Also run the stand-alone clustering diarizer
                first, as a baseline logged for comparison. Defaults to True.

        Returns:
            Annotation: Neural diarization result, with ``uri`` set to the audio file's stem.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            ValueError: If ``num_speakers`` is given and smaller than 1.
        """
        # Fail fast, before any model is loaded.
        if not os.path.isfile(audio_path):
            raise FileNotFoundError(audio_path)
        if num_speakers is not None and num_speakers < 1:
            raise ValueError(f"num_speakers must be >= 1, got {num_speakers}")

        # Only trust the manifest's speaker count when the caller actually supplied one.
        self.config.diarizer.clustering.parameters.oracle_num_speakers = num_speakers is not None
        self._create_manifest(audio_path, num_speakers=num_speakers, reference_rttm=reference_rttm)

        if run_clustering:
            # Both passes read their result from self.predicted_rttm_file, so the
            # neural pass below replaces this one; it only serves as a logged baseline.
            self._diarize_cluster()
        neural_diarization = self._diarize_neural()

        neural_diarization.uri = os.path.splitext(os.path.basename(audio_path))[0]
        return neural_diarization
def main(audio_file=None):
    """
    Diarize an audio file, log its speaker turns and save them as ``<stem>.rttm``.

    Args:
        audio_file (str, optional): Audio file to process. Defaults to the
            first command-line argument, or ``test.wav`` when none is given.
    """
    import sys
    import time

    # The date format ends in 'Z' (UTC). Render timestamps with gmtime so the
    # suffix is truthful; the default converter uses local time.
    formatter = logging.Formatter(
        fmt='%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d][%(funcName)s] %(message)s',
        datefmt='%Y-%m-%dT%H:%M:%SZ',
    )
    formatter.converter = time.gmtime
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG, handlers=[handler])

    if audio_file is None:
        audio_file = sys.argv[1] if len(sys.argv) > 1 else "test.wav"

    diarization: Annotation = DiarizationService().diarize(audio_file)

    turn_str = "".join(
        f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}\n"
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    )
    LOG.info(turn_str)

    rttm_name = diarization.uri + ".rttm"
    LOG.info("saving to %s", rttm_name)
    with open(rttm_name, "w", encoding="utf-8") as rttm:
        diarization.write_rttm(rttm)
# Script entry point: run the diarization demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment