Files
website/venv/lib/python3.11/site-packages/openai/helpers/local_audio_player.py

166 lines
5.7 KiB
Python

# mypy: ignore-errors
from __future__ import annotations
import queue
import asyncio
from typing import Any, Union, Callable, AsyncGenerator, cast
from typing_extensions import TYPE_CHECKING
from .. import _legacy_response
from .._extras import numpy as np, sounddevice as sd
from .._response import StreamedBinaryAPIResponse, AsyncStreamedBinaryAPIResponse
if TYPE_CHECKING:
import numpy.typing as npt
# Sample rate passed as `samplerate=` to every `sd.OutputStream` opened in this
# module. NOTE(review): presumably Hz and chosen to match the PCM audio the TTS
# endpoint returns -- confirm against the API documentation.
SAMPLE_RATE = 24000
class LocalAudioPlayer:
    """Play audio buffers or binary API responses through a `sounddevice` output stream.

    Audio is always handed to the output stream as float32 frames shaped
    ``(n_frames, n_channels)``; int16 input is rescaled by ``1 / 32767``.
    """

    def __init__(
        self,
        should_stop: Union[Callable[[], bool], None] = None,
    ):
        """Create a player.

        Args:
            should_stop: Optional zero-argument callable polled from the audio
                callback of `play`; playback stops as soon as it returns a
                truthy value.
        """
        self.channels = 1
        self.dtype = np.float32
        self.should_stop = should_stop

    def _as_frames(self, samples: npt.NDArray[Any]) -> npt.NDArray[Any]:
        """Normalise *samples* to the 2-D layout the output callbacks write.

        * int16 input (while ``self.dtype`` is float32) is converted to float32,
          divided by 32767.0 and reshaped to ``(-1, self.channels)``.
        * Any other 1-D input is reshaped to ``(-1, self.channels)`` so it can be
          sliced straight into the stream's ``(frames, channels)`` buffer.
        * Input that is already 2-D (or higher) and not int16 is returned unchanged.
        """
        if samples.dtype == np.int16 and self.dtype == np.float32:
            return (samples.astype(np.float32) / 32767.0).reshape(-1, self.channels)
        if samples.ndim == 1:
            return samples.reshape(-1, self.channels)
        return samples

    async def _tts_response_to_buffer(
        self,
        response: Union[
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> npt.NDArray[np.float32]:
        """Read a whole binary response and return it as float32 frames of shape ``(n, 1)``.

        The body is consumed in 1024-byte chunks (synchronously for the legacy
        and sync-streamed response types, asynchronously otherwise) and the
        joined bytes are interpreted as native-endian int16 samples.
        NOTE(review): this assumes the response carries raw 16-bit mono PCM --
        confirm against the request's ``response_format``.
        """
        chunks: list[bytes] = []
        sync_types = (_legacy_response.HttpxBinaryResponseContent, StreamedBinaryAPIResponse)
        if isinstance(response, sync_types):
            for chunk in response.iter_bytes(chunk_size=1024):
                if chunk:
                    chunks.append(chunk)
        else:
            async for chunk in response.iter_bytes(chunk_size=1024):
                if chunk:
                    chunks.append(chunk)

        audio_bytes = b"".join(chunks)
        audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
        return audio_np.reshape(-1, 1)

    async def play(
        self,
        input: Union[
            npt.NDArray[np.int16],
            npt.NDArray[np.float32],
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> None:
        """Play *input* to completion (or until ``should_stop`` returns true).

        Args:
            input: Either a NumPy array of int16 / float32 samples, or a binary
                API response whose body is decoded by `_tts_response_to_buffer`.
                1-D arrays are treated as ``self.channels``-channel audio.

        Raises:
            ValueError: If *input* is an array whose dtype is neither float32
                nor (with ``self.dtype`` float32) int16.
        """
        audio_content: npt.NDArray[np.float32]
        if isinstance(input, np.ndarray):
            convertible_int16 = input.dtype == np.int16 and self.dtype == np.float32
            if not (convertible_int16 or input.dtype == np.float32):
                raise ValueError(f"Unsupported dtype: {input.dtype}")
            # Also reshapes 1-D float32 input, which previously crashed on `.shape[1]`.
            audio_content = self._as_frames(input)
        else:
            audio_content = await self._tts_response_to_buffer(input)

        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        idx = 0

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ) -> None:
            # Runs on the audio thread; must signal the event loop thread-safely.
            nonlocal idx

            remainder = len(audio_content) - idx
            if remainder == 0 or (callable(self.should_stop) and self.should_stop()):
                # The buffer arrives uninitialised and is still played after
                # CallbackStop, so silence it before stopping.
                outdata.fill(0)
                loop.call_soon_threadsafe(event.set)
                raise sd.CallbackStop

            valid_frames = min(frame_count, remainder)
            outdata[:valid_frames] = audio_content[idx : idx + valid_frames]
            outdata[valid_frames:] = 0  # pad the final, partially filled block
            idx += valid_frames

        stream = sd.OutputStream(
            samplerate=SAMPLE_RATE,
            callback=callback,
            dtype=audio_content.dtype,
            channels=audio_content.shape[1],
        )
        with stream:
            await event.wait()

    async def play_stream(
        self,
        buffer_stream: AsyncGenerator[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None], None],
    ) -> None:
        """Play buffers as they are produced by *buffer_stream*.

        A producer task copies buffers from the async generator into a bounded
        thread-safe queue (50 entries); the audio callback drains that queue.
        A ``None`` item from the generator, or its exhaustion, ends playback.
        If the generator raises, playback is still shut down and the exception
        is re-raised from this coroutine.
        """
        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        buffer_queue: queue.Queue[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None]] = queue.Queue(maxsize=50)

        async def buffer_producer() -> None:
            try:
                async for buffer in buffer_stream:
                    if buffer is None:
                        break
                    # `Queue.put` blocks when the queue is full, so run it off-loop.
                    await loop.run_in_executor(None, buffer_queue.put, buffer)
            finally:
                # Always signal completion -- even when the generator raised --
                # otherwise the callback never stops and `event.wait()` hangs.
                await loop.run_in_executor(None, buffer_queue.put, None)

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ) -> None:
            # Runs on the audio thread.
            nonlocal current_buffer, buffer_pos

            frames_written = 0
            while frames_written < frame_count:
                if current_buffer is None or buffer_pos >= len(current_buffer):
                    try:
                        next_buffer = buffer_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Underrun: output silence for the rest of this block.
                        outdata[frames_written:] = 0
                        return

                    if next_buffer is None:
                        # End of stream: silence the unwritten tail, which is
                        # still played after CallbackStop.
                        current_buffer = None
                        outdata[frames_written:] = 0
                        loop.call_soon_threadsafe(event.set)
                        raise sd.CallbackStop

                    # Convert int16 and give 1-D buffers a channel axis.
                    current_buffer = self._as_frames(next_buffer)
                    buffer_pos = 0

                remaining_frames = len(current_buffer) - buffer_pos
                frames_to_write = min(frame_count - frames_written, remaining_frames)
                outdata[frames_written : frames_written + frames_to_write] = current_buffer[
                    buffer_pos : buffer_pos + frames_to_write
                ]
                buffer_pos += frames_to_write
                frames_written += frames_to_write

        current_buffer = None
        buffer_pos = 0

        producer_task = asyncio.create_task(buffer_producer())
        with sd.OutputStream(
            samplerate=SAMPLE_RATE,
            channels=self.channels,
            dtype=self.dtype,
            callback=callback,
        ):
            await event.wait()

        # Propagates any exception raised by `buffer_stream`.
        await producer_task