Created
December 14, 2024 01:09
-
-
Save schroneko/0088dc9e00628c305975fc114fd24536 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import contextlib | |
| import wave | |
| import logging | |
| import asyncio | |
| import os | |
| import sounddevice as sd | |
| import soundfile as sf | |
| from google import genai | |
# Module logger named 'Live'. Its level is INFO, so the logger.debug(...)
# tracing calls in this module are filtered out unless the level is lowered.
logger = logging.getLogger('Live')
logger.setLevel(logging.INFO)
| # wave ファイル保存用のコンテキストマネージャー | |
| @contextlib.contextmanager | |
| def wave_file(filename, channels=1, rate=24000, sample_width=2): | |
| with wave.open(filename, "wb") as wf: | |
| wf.setnchannels(channels) | |
| wf.setsampwidth(sample_width) | |
| wf.setframerate(rate) | |
| yield wf | |
def play_audio(filename):
    """Play the audio file at ``filename`` with sounddevice.

    The file is decoded with soundfile; this call blocks until playback has
    finished.
    """
    # sf.read returns (data, samplerate), which matches sd.play's first two
    # positional parameters.
    sd.play(*sf.read(filename))
    # Block until playback completes.
    sd.wait()
class AudioLoop:
    """Interactive text-in / audio-out chat loop over the Gemini Live API.

    Each line typed by the user is sent as one complete turn. The model's
    audio reply is written to ``audio_<n>.wav`` (n counts up from 0) and then
    played back.
    """

    def __init__(self, config=None):
        """Create the loop.

        Args:
            config: Live API connection config passed to ``connect``. When
                omitted, a config requesting AUDIO responses is used.
        """
        self.session = None  # set by run() once the Live connection is open
        self.index = 0  # numeric suffix for the next audio_<index>.wav file
        if config is None:
            config = {
                "generation_config": {"response_modalities": ["AUDIO"]}
            }
        self.config = config

    async def run(self):
        """Connect to the Live API and alternate send/recv until the user quits."""
        print("Type 'q' to quit")
        logger.debug('connect')
        client = genai.Client(http_options={'api_version': 'v1alpha'})
        async with client.aio.live.connect(model="gemini-2.0-flash-exp", config=self.config) as session:
            self.session = session
            while True:
                if not await self.send():
                    break
                await self.recv()

    async def send(self):
        """Read one line from stdin and send it to the model as a complete turn.

        Returns:
            False if the user typed 'q' (case-insensitive), True otherwise.
        """
        logger.debug('send')
        # input() blocks, so run it in a worker thread to keep the event loop
        # (and the open session) responsive.
        text = await asyncio.to_thread(input, "message > ")
        if text.lower() == 'q':
            return False
        await self.session.send(text, end_of_turn=True)
        logger.debug('sent')
        return True

    async def recv(self):
        """Receive one model turn, save its audio to a WAV file, then play it."""
        file_name = f"audio_{self.index}.wav"
        with wave_file(file_name) as wav:
            self.index += 1
            logger.debug('receive')
            async for response in self.session.receive():
                logger.debug('got chunk: %s', response)
                server_content = response.server_content
                if server_content is None:
                    logger.error('Unhandled server message! - %s', response)
                    break
                model_turn = server_content.model_turn
                if model_turn is not None:
                    # `parts` may be missing, and a part may carry no inline
                    # audio (e.g. a text part). Skip those instead of raising
                    # AttributeError/TypeError.
                    for part in model_turn.parts or ():
                        if part.inline_data is None:
                            continue
                        print('.', end='')
                        logger.debug('Got pcm_data, mimetype: %s', part.inline_data.mime_type)
                        wav.writeframes(part.inline_data.data)
                if server_content.turn_complete:
                    print('\n')
                    break
        # Play only after the `with` above has closed (finalized) the WAV file.
        # play_audio blocks until playback ends, so run it in a worker thread,
        # as send() does for input(), rather than stalling the event loop while
        # the Live session is still open.
        await asyncio.to_thread(play_audio, file_name)
        # Pause 2 s before returning to the prompt.
        await asyncio.sleep(2)
async def text_to_text_demo():
    """Run the Text-to-Text demo.

    Opens a Live API session configured for TEXT responses, sends one fixed
    greeting, and prints every text chunk of the reply on its own line.
    """
    text_only = {"generation_config": {"response_modalities": ["TEXT"]}}
    client = genai.Client(http_options={'api_version': 'v1alpha'})
    prompt = "Hello? Gemini are you there?"
    async with client.aio.live.connect(model="gemini-2.0-flash-exp", config=text_only) as session:
        print("> ", prompt, "\n")
        await session.send(prompt, end_of_turn=True)
        async for reply in session.receive():
            # Messages without text are skipped.
            if reply.text is None:
                continue
            print(f'- {reply.text}')
async def text_to_audio_demo():
    """Run the Text-to-Audio demo.

    Opens a Live API session configured for AUDIO responses and sends one
    fixed greeting. The audio reply is written to ``audio.wav``, with a dot
    printed per received chunk, and the file is played back once the session
    has been closed.
    """
    client = genai.Client(http_options={'api_version': 'v1alpha'})
    config = {
        "generation_config": {"response_modalities": ["AUDIO"]}
    }
    file_name = 'audio.wav'
    async with client.aio.live.connect(model="gemini-2.0-flash-exp", config=config) as session:
        with wave_file(file_name) as wav:
            message = "Hello? Gemini are you there?"
            print("> ", message, "\n")
            await session.send(message, end_of_turn=True)
            first = True
            async for response in session.receive():
                if response.data is None:
                    continue
                if first:
                    first = False
                    # The first part is not guaranteed to carry inline audio,
                    # so report the MIME type of the first part that does.
                    parts = response.server_content.model_turn.parts
                    mime_type = next(
                        (p.inline_data.mime_type for p in parts if p.inline_data is not None),
                        None,
                    )
                    print(mime_type)
                print('.', end='')
                wav.writeframes(response.data)
            # Terminate the line of progress dots.
            print()
    # Play only after the WAV file is closed (header finalized) and the session
    # released. play_audio blocks until done, so keep it off the event loop.
    await asyncio.to_thread(play_audio, file_name)
async def main():
    """Run the three demos in order: text-to-text, text-to-audio, interactive loop.

    Raises:
        ValueError: if the GOOGLE_API_KEY environment variable is unset or empty.
    """
    # The clients below are built without an explicit key; presumably
    # genai.Client picks GOOGLE_API_KEY up from the environment. Fail early
    # with a clear message if it is missing.
    api_key = os.environ.get('GOOGLE_API_KEY')
    if not api_key:
        raise ValueError("GOOGLE_API_KEY environment variable is not set")

    # (banner, coroutine factory) pairs, executed strictly one after another.
    demos = (
        ("Starting Text-to-Text demo...", text_to_text_demo),
        ("\nStarting Text-to-Audio demo...", text_to_audio_demo),
        ("\nStarting Interactive Audio Loop...", lambda: AudioLoop().run()),
    )
    for banner, start_demo in demos:
        print(banner)
        await start_demo()


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment