Created
December 27, 2023 03:48
-
-
Save gaspardpetit/e2af3728d922239e0a6ec80e53fb5f58 to your computer and use it in GitHub Desktop.
diarization using pyannote
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import torch | |
| import torchaudio | |
| import logging | |
| from pyannote.audio import Pipeline | |
| from pyannote.core import Annotation | |
| from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization | |
| from pyannote.audio.pipelines.utils.hook import ProgressHook | |
| # Environment variable to use for setting the HuggingFace Token | |
| ENV_HUGGINGFACE_TOKEN : str = "HUGGINGFACE_TOKEN" | |
| LOG: logging.Logger = logging.getLogger(__name__) | |
class DiarizationPipeline(object):
    """
    DiarizationPipeline is a singleton class that represents the diarization pipeline.

    The first successful call to ``DiarizationPipeline()`` selects a device and
    loads the pretrained pipeline; every later call returns that same instance.
    If the first initialization raises, nothing is cached, so the next call
    retries instead of returning a half-built object.

    Attributes:
    - device (torch.device): The device (CPU or CUDA) to use for diarization.
    - pipeline (Pipeline): The diarization pipeline.

    Methods:
    - _get_huggingface_token(): Retrieves the Hugging Face token from the environment variables.
    - _get_device(): Retrieves the device to use for diarization.
    - _load_pipeline(device): Loads the diarization pipeline.
    - _init_once(): Initializes the DiarizationPipeline singleton instance.
    """

    __instance = None

    def __new__(cls):
        """Return the shared instance, creating and initializing it on first use."""
        if cls.__instance is None:
            instance = super(DiarizationPipeline, cls).__new__(cls)
            instance._init_once()
            # Publish the instance only after _init_once() succeeded; caching it
            # earlier would make a failed load permanent for the whole process.
            cls.__instance = instance
        return cls.__instance

    @staticmethod
    def _get_huggingface_token() -> str:
        """
        Retrieves the Hugging Face token from the environment variables.

        Raises:
        - EnvironmentError: If the variable is unset or empty.
        """
        token = os.getenv(ENV_HUGGINGFACE_TOKEN)
        if not token:
            raise EnvironmentError(f"{ENV_HUGGINGFACE_TOKEN} environment variable is not set")
        return token

    @staticmethod
    def _get_device() -> torch.device:
        """Retrieves the device to use for diarization: CUDA when available, otherwise CPU."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        LOG.info("Using device: %s", device.type)
        return device

    @staticmethod
    def _load_pipeline(device: torch.device) -> Pipeline:
        """
        Loads the pretrained diarization pipeline and moves it to *device*.

        Raises:
        - EnvironmentError: If the Hugging Face token is not set.
        - RuntimeError: If the pretrained pipeline could not be obtained.
        """
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=DiarizationPipeline._get_huggingface_token(),
        )
        # NOTE(review): pyannote.audio 3.x appears to print a hint and return
        # None (rather than raise) when the model is gated or the token is
        # rejected -- confirm against the installed version. Without this guard
        # the failure surfaces as an AttributeError on `.to`.
        if pipeline is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization-3.1'; check that "
                f"{ENV_HUGGINGFACE_TOKEN} is valid and the model's user conditions are accepted"
            )
        return pipeline.to(device)

    def _init_once(self):
        """Initializes the DiarizationPipeline singleton instance."""
        self.device: torch.device = DiarizationPipeline._get_device()
        self.pipeline: Pipeline = DiarizationPipeline._load_pipeline(self.device)
class DiarizationService:
    """
    DiarizationService performs speaker diarization on audio files.

    Attributes:
    - diarization (DiarizationPipeline): The diarization pipeline instance.

    Methods:
    - _load_audio(audio_path): Loads the audio waveform and sample rate from an audio file.
    - diarize(audio_path): Performs diarization on the specified audio file.
    """

    def __init__(self):
        self.diarization: DiarizationPipeline = DiarizationPipeline()

    @staticmethod
    def _load_audio(audio_path: str) -> tuple:
        """Reads *audio_path* with torchaudio and returns ``(waveform, sample_rate)``."""
        loaded_waveform, loaded_rate = torchaudio.load(audio_path)
        return loaded_waveform, loaded_rate

    def diarize(self, audio_path: str) -> Annotation:
        """
        Runs the diarization pipeline on one audio file.

        The result's ``uri`` is set to the URL-quoted file name without its
        extension.

        Returns:
        - The pipeline's annotation for the file.
        """
        samples, rate = DiarizationService._load_audio(audio_path)
        audio_input = {"waveform": samples, "sample_rate": rate}

        with ProgressHook() as progress:
            result: Annotation = self.diarization.pipeline(audio_input, hook=progress)

        # Set the diarization file id
        from urllib.parse import quote

        file_name = os.path.basename(audio_path)
        stem, _extension = os.path.splitext(file_name)
        result.uri = quote(stem)
        return result
def main(audio_file: str = "test.wav"):
    """
    Diarizes an audio file, logs the speaker turns and writes an RTTM file.

    Parameters:
    - audio_file (str): Path of the audio file to diarize. Defaults to "test.wav".

    The RTTM file is written to the current working directory and is named
    after the diarization URI with a ".rttm" suffix.
    """
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d][%(funcName)s] %(message)s', datefmt='%Y-%m-%dT%H:%M:%SZ')

    if ENV_HUGGINGFACE_TOKEN not in os.environ:
        # Placeholder kept so the script can be edited in place; it is not a
        # working credential, so make the substitution visible in the log.
        LOG.warning("%s is not set; falling back to the placeholder token", ENV_HUGGINGFACE_TOKEN)
        os.environ[ENV_HUGGINGFACE_TOKEN] = '<huggingface_token>'

    diarization = DiarizationService().diarize(audio_file)

    # One line per speaker turn, each newline-terminated, built in a single join.
    turn_str = "".join(
        f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}\n"
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    )
    LOG.info(turn_str)

    rttm_name = diarization.uri + ".rttm"
    LOG.info("saving to %s", rttm_name)
    with open(rttm_name, "w", encoding="utf-8") as rttm:
        diarization.write_rttm(rttm)
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment