
Overview

The Speech-to-Text (STT) component uses Kyutai’s streaming ASR (Automatic Speech Recognition) models to transcribe user speech in real time, with low latency and integrated Voice Activity Detection (VAD).

Key Features:
  • Real-time streaming transcription
  • ~2.5 second algorithmic delay (configurable)
  • Integrated pause prediction (VAD)
  • WebSocket-based binary protocol (MessagePack)
  • Word-level timestamps

Architecture

STT Service

  • Technology: Rust (moshi-server)
  • Location: services/moshi-server/
  • Model: Kyutai Streaming ASR
  • Deployment: Docker container with GPU access

Service Configuration

Docker Compose (docker-compose.yml:78):
stt:
  image: moshi-server:latest
  command: ["worker", "--config", "configs/stt.toml"]
  environment:
    - HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN
  deploy:
    resources:
      reservations:
        devices:
          - driver: nvidia
            count: all
            capabilities: [gpu]
Resource Usage:
  • VRAM: ~2.5 GB
  • Concurrent streams: Limited by capacity management

Python Client

File: unmute/stt/speech_to_text.py

SpeechToText Class

class SpeechToText(ServiceWithStartup):
    def __init__(
        self,
        stt_instance: str = STT_SERVER,
        delay_sec: float = STT_DELAY_SEC,
    ):
        self.stt_instance = stt_instance
        self.delay_sec = delay_sec
        self.websocket: websockets.ClientConnection | None = None
        self.sent_samples: int = 0
        self.current_time: float = -STT_DELAY_SEC  # Starts negative (see Timing & Latency)
        self.pause_prediction = ExponentialMovingAverage()

Connection Flow

Startup Sequence

File: speech_to_text.py:130
async def start_up(self):
    # 1. Connect WebSocket
    self.websocket = await websockets.connect(
        self.stt_instance + SPEECH_TO_TEXT_PATH,
        additional_headers=HEADERS
    )
    
    # 2. Wait for Ready message
    message_bytes = await self.websocket.recv()
    message = STTMessageAdapter.validate_python(
        msgpack.unpackb(message_bytes)
    )
    
    if isinstance(message, STTReadyMessage):
        mt.STT_ACTIVE_SESSIONS.inc()
        return
    elif isinstance(message, STTErrorMessage):
        # Server at capacity
        raise MissingServiceAtCapacity("stt")

Sending Audio

File: speech_to_text.py:105
async def send_audio(self, audio: np.ndarray) -> None:
    # Validate: 1D float32 array
    if audio.ndim != 1:
        raise ValueError(f"Expected 1D array, got {audio.shape=}")
    
    if audio.dtype != np.float32:
        audio = audio_to_float32(audio)
    
    # Track sent samples (for timing)
    self.sent_samples += len(audio)
    
    # Send as MessagePack
    await self._send({
        "type": "Audio",
        "pcm": audio.tolist()
    })
Frame Size: 480 samples (20 ms @ 24 kHz)
Encoding: MessagePack with use_single_float=True for efficiency
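As a quick sanity check on those numbers, here is a standalone sketch; the constants come from this page, and the per-value byte costs come from the MessagePack spec:

```python
SAMPLE_RATE = 24_000      # Hz, per this page
FRAME_TIME_SEC = 0.020    # 20 ms frames

samples_per_frame = round(SAMPLE_RATE * FRAME_TIME_SEC)  # 480

# Per the MessagePack spec, a float32 costs 5 bytes (tag + payload) and a
# float64 costs 9, so use_single_float=True nearly halves the PCM payload:
payload_f32 = samples_per_frame * 5   # 2400 bytes per frame
payload_f64 = samples_per_frame * 9   # 4320 bytes per frame
```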

Receiving Messages

File: speech_to_text.py:175

The STT client is an async iterator:
async def __aiter__(self) -> AsyncIterator[STTWordMessage | STTMarkerMessage]:
    async for message_bytes in self.websocket:
        data = msgpack.unpackb(message_bytes)
        message: STTMessage = STTMessageAdapter.validate_python(data)
        
        match message:
            case STTWordMessage():
                # New transcribed word
                mt.STT_RECV_WORDS.inc()
                self.received_words += 1
                yield message
                
            case STTStepMessage():
                # VAD update (every frame)
                self.current_time += FRAME_TIME_SEC
                self.pause_prediction.update(
                    dt=FRAME_TIME_SEC,
                    new_value=message.prs[2]  # Pause score
                )
                
            case STTMarkerMessage():
                # Echo marker (for timing tests)
                yield message

Message Types

Client → Server

Audio Message

class STTClientAudioMessage:
    type: Literal["Audio"] = "Audio"
    pcm: list[float]  # 480 samples, float32
Example:
{
    "type": "Audio",
    "pcm": [0.0001, -0.0002, ...] # 480 values
}

Marker Message

class STTClientMarkerMessage:
    type: Literal["Marker"] = "Marker"
    id: int
Used for latency measurements; the server echoes the marker back.
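A hedged sketch of how such a probe could look; `send` and `receive` are hypothetical stand-ins for the MessagePack-over-WebSocket calls, not the real client API:

```python
import time

# Hypothetical round-trip probe: send a Marker, then wait for its echo.
def measure_marker_latency(send, receive, marker_id: int = 1) -> float:
    t0 = time.monotonic()
    send({"type": "Marker", "id": marker_id})
    while True:
        msg = receive()
        # The server echoes the marker back; match it by id
        if msg.get("type") == "Marker" and msg.get("id") == marker_id:
            return time.monotonic() - t0
```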

Server → Client

Ready Message

class STTReadyMessage(BaseModel):
    type: Literal["Ready"]
Sent immediately after connection. Signals server is ready.

Word Message

class STTWordMessage(BaseModel):
    type: Literal["Word"]
    text: str
    start_time: float  # Seconds since first audio
Example:
{
    "type": "Word",
    "text": "hello",
    "start_time": 1.52
}

Step Message

class STTStepMessage(BaseModel):
    type: Literal["Step"]
    step_idx: int
    prs: list[float]  # [3 probabilities]
Sent every audio frame (20ms). Contains VAD scores. PRS Array:
  • prs[0]: Probability of speech continuing
  • prs[1]: Probability of speech ending soon
  • prs[2]: Pause prediction score (0-1, higher = more likely pause)
Example:
{
    "type": "Step",
    "step_idx": 42,
    "prs": [0.8, 0.15, 0.05]  # Likely continuing
}

{
    "type": "Step",
    "step_idx": 87,
    "prs": [0.1, 0.2, 0.7]  # Likely pausing
}

Error Message

class STTErrorMessage(BaseModel):
    type: Literal["Error"]
    message: str
Sent when server cannot accept connection (at capacity).

Voice Activity Detection

Pause Prediction

The backend uses prs[2] from Step messages to detect pauses.

File: unmute/unmute_handler.py:372
def determine_pause(self) -> bool:
    if self.chatbot.conversation_state() != "user_speaking":
        return False
    
    # Pause detected when the smoothed score exceeds the threshold
    return stt.pause_prediction.value > 0.6
Threshold: 0.6 (configurable)
Smoothing: exponential moving average

Exponential Moving Average

File: unmute/stt/exponential_moving_average.py
from math import exp

class ExponentialMovingAverage:
    def __init__(
        self,
        attack_time: float = 0.01,   # Fast rise
        release_time: float = 0.01,  # Fast fall
        initial_value: float = 1.0,
    ):
        self.attack_time = attack_time
        self.release_time = release_time
        self.value: float = initial_value
    
    def update(self, dt: float, new_value: float):
        if new_value > self.value:
            # Attack (rising)
            alpha = 1 - exp(-dt / self.attack_time)
        else:
            # Release (falling)
            alpha = 1 - exp(-dt / self.release_time)
        
        self.value = alpha * new_value + (1 - alpha) * self.value
Purpose: Smooth noisy VAD scores to prevent false pause detections.
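To see the effect, here is a standalone re-implementation of the update rule fed a one-frame spike; the asymmetric release_time is an illustrative choice, not the project default:

```python
from math import exp

def ema_update(value, new_value, dt, attack_time=0.01, release_time=0.5):
    # Same rule as above: fast time constant when rising, slow when falling
    tau = attack_time if new_value > value else release_time
    alpha = 1 - exp(-dt / tau)
    return alpha * new_value + (1 - alpha) * value

# A one-frame spike to 1.0 rises almost fully in a single 20 ms step...
v = ema_update(0.0, 1.0, dt=0.02)   # ≈ 0.86
# ...but decays slowly afterwards, so brief dropouts don't read as pauses:
v = ema_update(v, 0.0, dt=0.02)     # ≈ 0.83
```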

Interruption Detection

File: unmute/unmute_handler.py:352

Two methods for detecting user interruption:
  1. STT Word: Any word from STT during bot speaking
    if self.chatbot.conversation_state() == "bot_speaking":
        if stt_word_received:
            await self.interrupt_bot()
    
  2. VAD-based: Pause prediction drops below threshold
    if (
        conversation_state == "bot_speaking"
        and stt.pause_prediction.value < 0.4  # Low = user speaking
        and audio_received_sec > UNINTERRUPTIBLE_BY_VAD_TIME_SEC
    ):
        await self.interrupt_bot()
    
Cooldown Period: First 3 seconds (VAD-based only)
  • Prevents echo cancellation issues
  • STT word-based interruption always works
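The two rules can be condensed into one predicate. This is an illustrative sketch, not the project's actual code; the thresholds and the cooldown constant are taken from above:

```python
UNINTERRUPTIBLE_BY_VAD_TIME_SEC = 3.0  # cooldown from the docs

def should_interrupt(
    conversation_state: str,
    stt_word_received: bool,
    pause_prediction: float,
    audio_received_sec: float,
) -> bool:
    """Combine the word-based and VAD-based interruption rules."""
    if conversation_state != "bot_speaking":
        return False
    if stt_word_received:
        # Word-based interruption is always active
        return True
    # VAD-based interruption only applies after the cooldown
    return (
        pause_prediction < 0.4
        and audio_received_sec > UNINTERRUPTIBLE_BY_VAD_TIME_SEC
    )
```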

Flushing

File: unmute/unmute_handler.py:340

When a pause is detected, the STT needs to be “flushed” to process remaining audio:
if self.determine_pause():
    # Calculate flush time
    self.stt_end_of_flush_time = stt.current_time + stt.delay_sec
    
    # Send zeros to flush the delay buffer
    num_frames = int(ceil(stt.delay_sec / FRAME_TIME_SEC)) + 1
    zero = np.zeros(SAMPLES_PER_FRAME, dtype=np.float32)
    for _ in range(num_frames):
        await stt.send_audio(zero)
Why Zeros? The STT has an internal delay buffer (~2.5 s); sending zeros pushes the remaining audio through the model.

Flush Timing:
if stt.current_time > self.stt_end_of_flush_time:
    # Flushing complete
    elapsed = self.stt_flush_timer.time()
    rtf = stt.delay_sec / elapsed  # Real-time factor
    await self._generate_response()
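With the documented defaults (2.5 s delay, 20 ms frames, 480 samples per frame), the flush arithmetic works out as:

```python
from math import ceil

STT_DELAY_SEC = 2.5       # algorithmic delay (page default)
FRAME_TIME_SEC = 0.020    # 20 ms per frame
SAMPLES_PER_FRAME = 480

# Frames of silence needed to push the delay buffer through the model:
num_frames = int(ceil(STT_DELAY_SEC / FRAME_TIME_SEC)) + 1
print(num_frames)                       # 126
print(num_frames * SAMPLES_PER_FRAME)   # 60480 zero samples in total
```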

Timing & Latency

Time to First Token (TTFT)

File: speech_to_text.py:203
if self.waiting_first_step and self.time_since_first_audio_sent.started:
    self.waiting_first_step = False
    mt.STT_TTFT.observe(self.time_since_first_audio_sent.time())
Typical: 50-100ms after first audio frame

Algorithmic Delay

Constant: STT_DELAY_SEC = 2.5 (configurable)
  • Purpose: Look-ahead for better accuracy
  • Trade-off: Higher delay = better accuracy, higher latency
  • Current Time: Tracked via step messages
self.current_time: float = -STT_DELAY_SEC  # Start negative

# Each step:
self.current_time += FRAME_TIME_SEC  # +20ms

# After flush:
assert self.current_time >= 0  # Caught up to real-time
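Two consequences of this bookkeeping, worked out with the default constants:

```python
STT_DELAY_SEC = 2.5
FRAME_TIME_SEC = 0.020

# current_time starts at -STT_DELAY_SEC and gains FRAME_TIME_SEC per Step,
# so it crosses zero after delay/frame_time Step messages:
steps_to_realtime = round(STT_DELAY_SEC / FRAME_TIME_SEC)   # 125 steps

# A Word with start_time 1.52 s (the earlier example) can only be emitted
# once current_time >= 1.52, i.e. after ~4.02 s of audio has been sent:
audio_needed_sec = 1.52 + STT_DELAY_SEC
```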

Metrics

File: unmute/metrics.py

STT-Specific Metrics

# Session metrics
STT_SESSIONS = Counter('unmute_stt_sessions_total')
STT_ACTIVE_SESSIONS = Gauge('unmute_stt_active_sessions')
STT_SESSION_DURATION = Histogram('unmute_stt_session_duration_seconds')

# Audio metrics
STT_SENT_FRAMES = Counter('unmute_stt_sent_frames_total')
STT_RECV_FRAMES = Counter('unmute_stt_recv_frames_total')
STT_AUDIO_DURATION = Histogram('unmute_stt_audio_duration_seconds')

# Word metrics
STT_RECV_WORDS = Counter('unmute_stt_recv_words_total')
STT_NUM_WORDS = Histogram('unmute_stt_num_words')

# Latency metrics
STT_TTFT = Histogram(
    'unmute_stt_ttft_seconds',
    buckets=[0.01, 0.05, 0.1, 0.5, 1.0, 5.0]
)

Real-Time Factor

Calculated during flush:
rtf = stt.delay_sec / elapsed_time
# > 1.0 = Faster than real-time (good)
# < 1.0 = Slower than real-time (bad, will lag)
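For intuition, the same formula as a tiny helper with two worked cases:

```python
def real_time_factor(delay_sec: float, elapsed_sec: float) -> float:
    # How much faster than real time the flush buffer is processed
    return delay_sec / elapsed_sec

rtf_fast = real_time_factor(2.5, 0.5)   # 5.0: 2.5 s of audio in 0.5 s
rtf_slow = real_time_factor(2.5, 5.0)   # 0.5: slower than real time, lags
```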

Integration with UnmuteHandler

Startup

File: unmute/unmute_handler.py:422
async def start_up_stt(self):
    async def _init() -> SpeechToText:
        return await find_instance("stt", SpeechToText)
    
    async def _run(stt: SpeechToText):
        await self._stt_loop(stt)
    
    async def _close(stt: SpeechToText):
        await stt.shutdown()
    
    quest = await self.quest_manager.add(
        Quest("stt", _init, _run, _close)
    )
    
    # Wait for STT to be ready
    await quest.get()

Message Loop

File: unmute/unmute_handler.py:436
async def _stt_loop(self, stt: SpeechToText):
    try:
        async for data in stt:
            if isinstance(data, STTMarkerMessage):
                continue  # Ignore markers
            
            # Send transcription to frontend
            await self.output_queue.put(
                ora.ConversationItemInputAudioTranscriptionDelta(
                    delta=data.text,
                    start_time=data.start_time,
                )
            )
            
            # Skip empty first message
            if data.text == "":
                continue
            
            # Interrupt bot if speaking
            if self.chatbot.conversation_state() == "bot_speaking":
                await self.interrupt_bot()
            
            # Add to chat history
            await self.add_chat_message_delta(data.text, "user")
            
    except websockets.ConnectionClosed:
        logger.info("STT connection closed")

Error Handling

Connection Failures

try:
    stt = await find_instance("stt", SpeechToText)
except MissingServiceAtCapacity:
    # All STT instances at capacity
    await emit_error("Too many users, try again later")
except MissingServiceTimeout:
    # No STT instances responding
    await emit_error("Service timeout, try again later")

WebSocket Disconnect

try:
    async for message in stt:
        ...  # Process message
except websockets.ConnectionClosedOK:
    # Normal shutdown
    pass
except websockets.ConnectionClosedError:
    # Unexpected disconnect
    logger.error("STT connection lost")
    raise

Graceful Shutdown

File: speech_to_text.py:157
async def shutdown(self):
    if self.shutdown_complete.is_set():
        return  # Already shut down
    
    # Record metrics
    mt.STT_ACTIVE_SESSIONS.dec()
    if self.time_since_first_audio_sent.started:
        mt.STT_SESSION_DURATION.observe(
            self.time_since_first_audio_sent.time()
        )
        mt.STT_AUDIO_DURATION.observe(
            self.sent_samples / SAMPLE_RATE
        )
    
    # Close WebSocket
    await self.websocket.close()
    await self.shutdown_complete.wait()

Testing & Debugging

Dummy STT

File: unmute/stt/dummy_speech_to_text.py

For testing without a GPU:
class DummySpeechToText:
    async def __aiter__(self):
        # Simulate word stream
        await asyncio.sleep(0.5)
        yield STTWordMessage(
            type="Word",
            text="hello",
            start_time=0.5
        )

Example Script

File: unmute/scripts/stt_from_file_example.py

Transcribe an audio file:
stt = SpeechToText()
await stt.start_up()

# Send audio from file
for chunk in audio_chunks:
    await stt.send_audio(chunk)

# Receive transcription
async for word in stt:
    print(word.text, word.start_time)

Performance Tuning

Delay Configuration

Adjust STT_DELAY_SEC in environment:
export KYUTAI_STT_DELAY_SEC=2.0  # Faster, less accurate
export KYUTAI_STT_DELAY_SEC=3.0  # Slower, more accurate
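Presumably the constant is read from the environment along these lines; the variable name comes from this page, but the reader code itself is an assumption, not the project's source:

```python
import os

# Hypothetical reader for the delay setting; the default matches
# STT_DELAY_SEC = 2.5 from the docs.
STT_DELAY_SEC = float(os.environ.get("KYUTAI_STT_DELAY_SEC", "2.5"))
```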

VAD Threshold

Adjust pause detection sensitivity:
# More sensitive (faster pause detection)
if stt.pause_prediction.value > 0.5:
    pause_detected = True

# Less sensitive (wait longer for pause)
if stt.pause_prediction.value > 0.7:
    pause_detected = True

EMA Parameters

Adjust smoothing for VAD scores:
self.pause_prediction = ExponentialMovingAverage(
    attack_time=0.005,  # Faster rise
    release_time=0.02,  # Slower fall
)
