Last active
July 27, 2025 10:29
-
-
Save BHznJNs/1f774f33b2cb12d0e5207fc97e8935c3 to your computer and use it in GitHub Desktop.
A model that captures an environment noise sample, uses it to reduce noise in incoming audio frames, runs VAD on the frames, and finally writes the audio data containing a sentence to a .wav file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import threading | |
| import pyaudio | |
| from abc import ABC, abstractmethod | |
| from contextlib import contextmanager | |
| from loguru import logger | |
class BaseRecorder(ABC):
    """Abstract base for threaded microphone recorders built on PyAudio.

    Subclasses implement ``_record``, which runs on a background thread and
    must check ``_is_recording`` under ``_lock`` (or via ``is_recording``)
    to know when to exit its capture loop.
    """

    def __init__(self, frame_size: int, channels: int):
        self.frame_size = frame_size
        self.format = pyaudio.paInt16
        self.channels = channels
        self.sample_rate = 16000  # Hz; fixed capture rate
        # BUGFIX: use the module-level helper instead of instantiating a
        # throwaway PyAudio() object — the original leaked a PortAudio
        # instance that was never terminate()d.
        self.sample_size = pyaudio.get_sample_size(self.format)
        self._is_recording = False
        self._audio_thread: threading.Thread | None = None
        self._lock = threading.RLock()

    @property
    def is_recording(self) -> bool:
        """Thread-safe view of the recording flag."""
        with self._lock:
            return self._is_recording

    def start(self):
        """Spawn the background recording thread (no-op if already running)."""
        with self._lock:
            if self._is_recording:
                logger.info("Already recording.")
                return
            self._is_recording = True
            self._audio_thread = threading.Thread(target=self._record)
            self._audio_thread.start()
        logger.info("Recording started.")

    def stop(self):
        """Signal the recording thread to finish and wait for it.

        BUGFIX: the join happens *outside* the lock.  ``_record`` acquires
        ``_lock`` every iteration to check the flag, so joining while still
        holding the lock deadlocks: the capture thread can never observe
        ``_is_recording == False`` and the join never returns.
        """
        with self._lock:
            if not self._is_recording:
                logger.info("Not recording.")
                return
            self._is_recording = False
            thread = self._audio_thread
        if thread is not None:
            thread.join()
        logger.info("Recording stopped.")

    def shutdown(self):
        """Release resources; the base implementation just stops recording."""
        self.stop()

    @contextmanager
    def _audio_resources(self):
        """Yield an open ``(PyAudio, input Stream)`` pair; always cleans up."""
        p = pyaudio.PyAudio()
        stream = p.open(format=self.format,
                        channels=self.channels,
                        rate=self.sample_rate,
                        input=True,
                        frames_per_buffer=self.frame_size)
        try:
            yield p, stream
        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

    @abstractmethod
    def _record(self):
        """Capture loop run on the background thread; provided by subclasses."""
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import collections | |
| import tempfile | |
| import threading | |
| import wave | |
| import webrtcvad | |
| import itertools | |
| from pathlib import Path | |
| from typing import Callable, Iterable | |
| from queue import Queue, Empty as QueueEmptyException | |
| from loguru import logger | |
| from .BaseRecorder import BaseRecorder | |
class RealtimeRecorder(BaseRecorder):
    """Voice-activity-triggered microphone recorder.

    Reads 30 ms, 16 kHz mono frames from the microphone, classifies each
    frame with WebRTC VAD, and groups voiced runs into sentence segments:
    a segment starts once ~90% of the last 300 ms is voiced and ends once
    ~90% of the last 1200 ms is unvoiced.  Each finished segment is written
    to a temporary .wav file and ``on_record_result`` is called with its
    path.
    """

    def __init__(self, on_record_result: Callable[[str], None]):
        self.sample_rate = 16000          # webrtcvad supports 8/16/32/48 kHz
        self.frame_duration_ms = 30       # webrtcvad supports 10/20/30 ms frames
        self.frame_size = int(self.sample_rate * self.frame_duration_ms / 1000)
        super().__init__(frame_size=self.frame_size, channels=1)
        # Leading silence kept before speech starts, and trailing silence
        # that terminates a sentence.
        pre_sentence_silence_duration_ms = 300
        post_sentence_silence_duration_ms = 1200
        self._pre_padding_frames = int(pre_sentence_silence_duration_ms / self.frame_duration_ms)
        self._post_padding_frames = int(post_sentence_silence_duration_ms / self.frame_duration_ms)
        self._vad = webrtcvad.Vad(1)      # aggressiveness: 0 (lenient) .. 3 (strict)
        self._VAD_TRIGGER_THRESHOLD = 0.9
        self._frame_queue: Queue[bytes] = Queue()
        # BUGFIX: initialize so stop()/shutdown() before start() is safe.
        self._vad_collector_thread: threading.Thread | None = None
        self._segment_id = 0
        self._segment_max_id = 128        # ids wrap so old temp files are reused
        self._temp_dir = Path(tempfile.mkdtemp(prefix="gibberish_temp_"))
        self._on_record_result = on_record_result

    def start(self):
        """Start audio capture and the VAD collector thread.

        BUGFIX: the original unconditionally spawned a new collector thread
        and reset ``_segment_id`` even when already recording, leaking a
        duplicate collector; now a repeated start is a logged no-op.
        """
        if self.is_recording:
            logger.info("Already recording.")
            return
        super().start()
        self._segment_id = 0
        self._vad_collector_thread = threading.Thread(target=self._vad_collector)
        self._vad_collector_thread.start()

    def stop(self, cancel: bool = False):
        """Stop capture, then drain and join the collector thread.

        ``cancel`` is accepted for interface compatibility but currently
        unused.  BUGFIX: ``BaseRecorder.stop()`` takes no arguments, so the
        original ``super().stop(cancel)`` raised ``TypeError`` on every call;
        also return early when not recording so no stray sentinel is queued.
        """
        if not self.is_recording:
            logger.info("Not recording.")
            return
        super().stop()
        # Sentinel frame: queued after the capture thread has joined, so the
        # collector finishes all real frames first, then exits.
        self._frame_queue.put(b"")
        if self._vad_collector_thread is not None:
            self._vad_collector_thread.join()
            self._vad_collector_thread = None

    def shutdown(self):
        """Stop recording and delete the temporary segment directory."""
        super().shutdown()
        import shutil
        shutil.rmtree(self._temp_dir, ignore_errors=True)

    def _recording_filepath_generator(self, id: int) -> str:
        """Return the .wav path used for segment ``id`` inside the temp dir."""
        return str(self._temp_dir / f"{id}.wav")

    def _save_speech_segment(self, frames: list[bytes]) -> str:
        """Write ``frames`` to the next temp .wav file and return its path."""
        output_path = self._recording_filepath_generator(self._segment_id)
        self._segment_id = (self._segment_id + 1) % self._segment_max_id
        with wave.open(output_path, "wb") as wf:
            wf.setsampwidth(self.sample_size)
            wf.setnchannels(self.channels)
            wf.setframerate(self.sample_rate)
            # One write instead of a per-frame loop.
            wf.writeframes(b"".join(frames))
        logger.debug(f"Saved speech segment to {output_path}")
        return output_path

    def _vad_collector(self):
        """Consume queued frames and emit speech segments.

        Keeps a sliding window of ``(frame, is_speech)`` pairs sized to the
        post-padding duration; triggers a segment when the most recent
        pre-padding slice is mostly voiced, and closes it when the whole
        window is mostly unvoiced.
        """
        def calculate_voice_rate(frame_iter: Iterable[tuple[bytes, bool]], size: int) -> float:
            # Fraction of voiced frames relative to `size`, not the iterable's
            # actual length (a partially filled window counts as unvoiced).
            voiced_frame_count = sum(1 for frame in frame_iter if frame[1])
            return voiced_frame_count / size

        def get_trigger_view() -> itertools.islice | None:
            # View over the newest `_pre_padding_frames` entries, or None if
            # the window has not filled that far yet.
            nonlocal window_buffer
            window_buffer_len = len(window_buffer)
            if window_buffer_len < self._pre_padding_frames:
                return None
            start_index = window_buffer_len - self._pre_padding_frames
            return itertools.islice(window_buffer, start_index, None)

        def trigger_controller(frame: bytes):
            nonlocal is_triggered, window_buffer, segment_buffer
            assert window_buffer.maxlen is not None
            if is_triggered:
                # While triggered: accumulate audio, watch for sustained silence.
                voiced_rate = calculate_voice_rate(window_buffer, window_buffer.maxlen)
                unvoiced_rate = 1 - voiced_rate
                segment_buffer.append(frame)
                if unvoiced_rate > self._VAD_TRIGGER_THRESHOLD:
                    is_triggered = False
                    record_path = self._save_speech_segment(segment_buffer)
                    segment_buffer.clear()
                    self._on_record_result(record_path)
            else:
                trigger_view = get_trigger_view()
                if trigger_view is None:
                    return
                voiced_rate = calculate_voice_rate(trigger_view, self._pre_padding_frames)
                if voiced_rate > self._VAD_TRIGGER_THRESHOLD:
                    is_triggered = True
                    # Seed the segment with the buffered pre-speech audio.
                    segment_buffer.extend(window_frame[0] for window_frame in window_buffer)

        is_triggered = False
        segment_buffer: list[bytes] = []
        window_buffer: collections.deque[tuple[bytes, bool]] = \
            collections.deque(maxlen=self._post_padding_frames)
        while True:
            try:
                frame = self._frame_queue.get(timeout=0.1)
            except QueueEmptyException:
                continue
            if frame == b"":  # sentinel queued by stop()
                break
            is_speech = self._vad.is_speech(frame, self.sample_rate)
            window_buffer.append((frame, is_speech))
            trigger_controller(frame)

    def _record(self):
        """Background capture loop: push raw microphone frames onto the queue."""
        with self._audio_resources() as (_, stream):
            while True:
                with self._lock:
                    if not self._is_recording:
                        break
                data = stream.read(self.frame_size, exception_on_overflow=False)
                self._frame_queue.put(data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment