Skip to content

Instantly share code, notes, and snippets.

@gaspardpetit
Created December 27, 2023 03:48
Show Gist options
  • Select an option

  • Save gaspardpetit/e2af3728d922239e0a6ec80e53fb5f58 to your computer and use it in GitHub Desktop.

Select an option

Save gaspardpetit/e2af3728d922239e0a6ec80e53fb5f58 to your computer and use it in GitHub Desktop.
Speaker diarization using pyannote.audio
import os
import torch
import torchaudio
import logging
from pyannote.audio import Pipeline
from pyannote.core import Annotation
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from pyannote.audio.pipelines.utils.hook import ProgressHook
# Name of the environment variable from which the Hugging Face access token is read.
ENV_HUGGINGFACE_TOKEN: str = "HUGGINGFACE_TOKEN"
# Logger for this module, named after the module itself.
LOG: logging.Logger = logging.getLogger(name=__name__)
class DiarizationPipeline(object):
    """
    Process-wide singleton that owns the pyannote speaker-diarization pipeline.

    The first ``DiarizationPipeline()`` call selects a device, loads the
    ``pyannote/speaker-diarization-3.1`` pipeline and moves it to that device.
    Every later call returns the same, already-initialized instance.

    Attributes:
    - device (torch.device): The device (CPU or CUDA) used for diarization.
    - pipeline (Pipeline): The loaded diarization pipeline.

    Methods:
    - _get_huggingface_token(): Retrieves the Hugging Face token from the environment variables.
    - _get_device(): Picks CUDA when available, otherwise CPU.
    - _load_pipeline(device): Loads the diarization pipeline onto ``device``.
    - _init_once(): Initializes the DiarizationPipeline singleton instance.
    """

    __instance = None

    def __new__(cls):
        if cls.__instance is None:
            # Fully initialize the instance BEFORE publishing it. If loading fails
            # (missing token, download error, ...), nothing is cached, and the next
            # call retries instead of returning a half-built object that has no
            # ``device``/``pipeline`` attributes.
            instance = super(DiarizationPipeline, cls).__new__(cls)
            instance._init_once()
            cls.__instance = instance
        return cls.__instance

    @staticmethod
    def _get_huggingface_token() -> str:
        """
        Retrieves the Hugging Face token from the environment variables.

        Raises:
        - EnvironmentError: if the variable named by ENV_HUGGINGFACE_TOKEN is unset or empty.
        """
        token = os.getenv(ENV_HUGGINGFACE_TOKEN)
        if not token:
            raise EnvironmentError(f"{ENV_HUGGINGFACE_TOKEN} environment variable is not set")
        return token

    @staticmethod
    def _get_device() -> torch.device:
        """Returns the CUDA device when torch reports one is available, otherwise the CPU."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        LOG.info(f"Using device: {device.type}")
        return device

    @staticmethod
    def _load_pipeline(device: torch.device) -> Pipeline:
        """
        Loads the ``pyannote/speaker-diarization-3.1`` pipeline and moves it to ``device``.

        Raises:
        - EnvironmentError: if the Hugging Face token is not set.
        - RuntimeError: if pyannote could not provide the pipeline.
        """
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=DiarizationPipeline._get_huggingface_token()
        )
        # NOTE: Pipeline.from_pretrained may return None instead of raising
        # (e.g. gated model whose conditions were not accepted, or a rejected
        # token). Fail with a clear message rather than an AttributeError on `.to`.
        if pipeline is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization-3.1'; check that "
                f"{ENV_HUGGINGFACE_TOKEN} holds a valid token and that the model's "
                "user conditions have been accepted on Hugging Face"
            )
        return pipeline.to(device)

    def _init_once(self):
        """Initializes the singleton: chooses the device, then loads the pipeline onto it."""
        self.device: torch.device = DiarizationPipeline._get_device()
        self.pipeline: Pipeline = DiarizationPipeline._load_pipeline(self.device)
class DiarizationService:
    """
    Performs speaker diarization on audio files using the shared DiarizationPipeline.

    Attributes:
    - diarization (DiarizationPipeline): The diarization pipeline singleton.

    Methods:
    - _load_audio(audio_path): Loads the audio waveform and sample rate from an audio file.
    - diarize(audio_path, ...): Performs diarization on the specified audio file.
    """

    def __init__(self):
        # DiarizationPipeline is a singleton, so every service shares one loaded model.
        self.diarization: DiarizationPipeline = DiarizationPipeline()

    @staticmethod
    def _load_audio(audio_path: str) -> tuple:
        """Loads ``(waveform, sample_rate)`` from an audio file via ``torchaudio.load``."""
        waveform, sample_rate = torchaudio.load(audio_path)
        return waveform, sample_rate

    def diarize(self, audio_path: str, *, num_speakers=None,
                min_speakers=None, max_speakers=None) -> Annotation:
        """
        Performs diarization on the specified audio file.

        Parameters:
        - audio_path: Path of the audio file to diarize.
        - num_speakers: Optional exact number of speakers (int), if known.
        - min_speakers / max_speakers: Optional bounds (int) on the number of
          speakers. ``None`` (the default) lets the pipeline estimate the count.

        Returns:
        - diarization: The diarization results. Its ``uri`` is the URL-quoted file
          name without directory or extension.
        """
        from urllib.parse import quote

        waveform, sample_rate = DiarizationService._load_audio(audio_path)
        # The audio is decoded up front and handed to the pipeline in memory.
        with ProgressHook() as hook:
            diarization: Annotation = self.diarization.pipeline(
                {"waveform": waveform, "sample_rate": sample_rate},
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                hook=hook,
            )

        # Set the diarization file id. URL-quoting removes spaces and other
        # special characters, presumably to keep the space-separated RTTM output valid.
        audio_name_without_extension = os.path.splitext(os.path.basename(audio_path))[0]
        diarization.uri = quote(audio_name_without_extension)
        return diarization
def main(audio_file: str = "test.wav"):
    """
    Diarizes ``audio_file``, logs every speaker turn and saves the result as RTTM.

    The RTTM file is named after the annotation's ``uri`` and is written to the
    current working directory.
    """
    import time

    # The date format ends in 'Z' (the UTC designator), but logging renders
    # timestamps in local time by default. Convert with gmtime so the suffix is true.
    formatter = logging.Formatter(
        fmt='%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d][%(funcName)s] %(message)s',
        datefmt='%Y-%m-%dT%H:%M:%SZ',
    )
    formatter.converter = time.gmtime
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG, handlers=[handler])

    # Optionally paste a token here instead of exporting HUGGINGFACE_TOKEN.
    # The untouched '<...>' placeholder is deliberately NOT exported. Exporting it
    # would bypass the "environment variable is not set" check and fail later,
    # with a far less helpful error.
    hardcoded_token = '<huggingface_token>'
    is_placeholder = hardcoded_token.startswith('<') and hardcoded_token.endswith('>')
    if ENV_HUGGINGFACE_TOKEN not in os.environ and not is_placeholder:
        os.environ[ENV_HUGGINGFACE_TOKEN] = hardcoded_token

    diarization = DiarizationService().diarize(audio_file)

    turn_str = "".join(
        f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}\n"
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    )
    LOG.info(turn_str)

    rttm_name = diarization.uri + ".rttm"
    LOG.info("saving to %s", rttm_name)
    with open(rttm_name, "w", encoding="utf-8") as rttm:
        diarization.write_rttm(rttm)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment