Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,59 @@
# SPDX-License-Identifier: MIT

import argparse
import functools
import sys

import grpc

# Exit codes shared by the CLI scripts. Pipelines that compose these scripts
# rely on a non-zero status to detect failure; see also `cli_main` below.
EXIT_OK = 0
EXIT_GENERIC_ERROR = 1
EXIT_BAD_INPUT = 2 # malformed args, missing file, empty/whitespace text, ...
EXIT_UNAVAILABLE = 3 # gRPC UNAVAILABLE (server down, wrong port, ...)
EXIT_INVALID_ARGUMENT = 4 # gRPC INVALID_ARGUMENT or NOT_FOUND (bad model/lang/voice)
EXIT_INTERRUPTED = 130 # SIGINT


def _grpc_exit_code(error: grpc.RpcError) -> int:
code = error.code() if callable(getattr(error, "code", None)) else None
if code == grpc.StatusCode.UNAVAILABLE:
return EXIT_UNAVAILABLE
if code in (grpc.StatusCode.INVALID_ARGUMENT, grpc.StatusCode.NOT_FOUND):
return EXIT_INVALID_ARGUMENT
return EXIT_GENERIC_ERROR


def cli_main(func):
"""Translate exceptions raised by a CLI ``main`` into consistent exit codes.

Wrapped function may return an int exit code or ``None`` (treated as
``EXIT_OK``). Unhandled exceptions are caught and mapped: gRPC ``RpcError``
via status code, ``FileNotFoundError`` / ``ValueError`` → ``EXIT_BAD_INPUT``,
anything else → ``EXIT_GENERIC_ERROR``. The error is also printed to stderr
so CI logs surface the cause.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
return EXIT_OK if result is None else int(result)
except KeyboardInterrupt:
return EXIT_INTERRUPTED
except grpc.RpcError as e:
details = e.details() if callable(getattr(e, "details", None)) else str(e)
print(f"Error: {details}", file=sys.stderr)
return _grpc_exit_code(e)
except (FileNotFoundError, IsADirectoryError, ValueError) as e:
print(f"Error: {e}", file=sys.stderr)
return EXIT_BAD_INPUT
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return EXIT_GENERIC_ERROR

return wrapper


def validate_grpc_message_size(value):
"""Validate that the GRPC message size is within acceptable limits."""
Expand Down
21 changes: 19 additions & 2 deletions riva/client/audio_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
import queue
from typing import Dict, Union, Optional

import pyaudio

def _require_pyaudio():
try:
import pyaudio
return pyaudio
except ImportError as e:
raise ImportError(
"pyaudio is required for audio device I/O. Install the system PortAudio "
"headers first (e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, "
"`brew install portaudio` on macOS), then `pip install pyaudio`."
) from e


class MicrophoneStream:
Expand All @@ -20,6 +30,8 @@ def __init__(self, rate: int, chunk: int, device: int = None) -> None:
self.closed = True

def __enter__(self):
pyaudio = _require_pyaudio()
self._pa_module = pyaudio
self._audio_interface = pyaudio.PyAudio()
self._audio_stream = self._audio_interface.open(
format=pyaudio.paInt16,
Expand Down Expand Up @@ -50,7 +62,7 @@ def __exit__(self, type, value, traceback):
def _fill_buffer(self, in_data, frame_count, time_info, status_flags):
"""Continuously collect data from the audio stream into the buffer."""
self._buff.put(in_data)
return None, pyaudio.paContinue
return None, self._pa_module.paContinue

def __next__(self) -> bytes:
if self.closed:
Expand All @@ -76,13 +88,15 @@ def __iter__(self):


def get_audio_device_info(device_id: int) -> Dict[str, Union[int, float, str]]:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
info = p.get_device_info_by_index(device_id)
p.terminate()
return info


def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]]]:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
try:
info = p.get_default_input_device_info()
Expand All @@ -93,6 +107,7 @@ def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]


def list_output_devices() -> None:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
print("Output audio devices:")
for i in range(p.get_device_count()):
Expand All @@ -104,6 +119,7 @@ def list_output_devices() -> None:


def list_input_devices() -> None:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
print("Input audio devices:")
for i in range(p.get_device_count()):
Expand All @@ -118,6 +134,7 @@ class SoundCallBack:
def __init__(
self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,
) -> None:
pyaudio = _require_pyaudio()
self.pa = pyaudio.PyAudio()
self.stream = self.pa.open(
output_device_index=output_device_index,
Expand Down
Loading