Skip to main content
Goose’s provider system allows integration with any LLM service. This guide explains how to implement the Provider trait and integrate new LLM providers.

Provider Trait Overview

All providers implement the Provider trait defined in crates/goose/src/providers/base.rs:456.

Core Methods

/// Interface implemented by every LLM backend described in this guide.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
/// Only the core methods are listed here; the trait has further methods
/// that this excerpt elides.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Provider identifier
    fn get_name(&self) -> &str;
    
    /// Primary streaming method
    ///
    /// Receives the model configuration, a session id, the system prompt,
    /// the conversation messages and the available tools, and resolves to
    /// a `MessageStream` on success or a `ProviderError` on failure.
    async fn stream(
        &self,
        model_config: &ModelConfig,
        session_id: &str,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError>;
    
    /// Get model configuration
    ///
    /// Returns an owned `ModelConfig` rather than a reference.
    fn get_model_config(&self) -> ModelConfig;
    
    /// Non-streaming completion (default impl uses stream)
    ///
    /// Takes the same arguments as `stream` and resolves to the complete
    /// `Message` paired with its `ProviderUsage`.
    async fn complete(
        &self,
        model_config: &ModelConfig,
        session_id: &str,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError>;
    
    // ... additional methods
}

Key Types

MessageStream - Stream of message chunks:
/// Pinned, boxed, `Send` stream returned by `Provider::stream`.
///
/// Each item is either an error or a pair of an optional `Message` and an
/// optional `ProviderUsage`; either half of the pair may be absent.
pub type MessageStream = Pin<Box<dyn Stream<Item = Result<
    (Option<Message>, Option<ProviderUsage>),
    ProviderError
>> + Send>>;
ProviderUsage - Token usage information:
/// Token usage information reported alongside a provider response.
pub struct ProviderUsage {
    /// Model name — presumably the model that served the request; confirm
    /// against the real definition.
    pub model: String,
    /// Token counters for the request.
    pub usage: Usage,
}

/// Token counters; every field is optional, so a counter may be absent.
pub struct Usage {
    /// Number of input tokens, if known.
    pub input_tokens: Option<i32>,
    /// Number of output tokens, if known.
    pub output_tokens: Option<i32>,
    /// Total number of tokens, if known.
    pub total_tokens: Option<i32>,
}

Implementing a Basic Provider

1. Define Provider Struct

use goose::providers::base::{Provider, ProviderError};
use goose::model::ModelConfig;
use async_trait::async_trait;

/// Example provider that talks to an HTTP chat-completions API.
pub struct MyProvider {
    /// Credential sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// URL prefix that endpoint paths such as `/chat/completions` are appended to.
    base_url: String,
    /// Model configuration returned by `get_model_config`.
    model_config: ModelConfig,
    /// HTTP client used for all requests made by this provider.
    client: reqwest::Client,
}

impl MyProvider {
    pub fn new(
        api_key: String,
        base_url: String,
        model_config: ModelConfig,
    ) -> Self {
        Self {
            api_key,
            base_url,
            model_config,
            client: reqwest::Client::new(),
        }
    }
}

2. Implement Provider Trait

#[async_trait]
impl Provider for MyProvider {
    /// Returns the identifier of this provider.
    fn get_name(&self) -> &str {
        "myprovider"
    }

    /// Returns a clone of the model configuration stored on the provider.
    fn get_model_config(&self) -> ModelConfig {
        self.model_config.clone()
    }

    /// Sends a streaming chat-completions request and adapts the HTTP
    /// response into a `MessageStream`.
    ///
    /// # Errors
    ///
    /// * `ProviderError::RequestFailed` if the request cannot be sent or the
    ///   server answers with a non-success status other than the two below.
    /// * `ProviderError::RateLimitExceeded` on HTTP 429.
    /// * `ProviderError::ContextLengthExceeded` on HTTP 413.
    /// * Whatever `build_request` or `process_stream` return.
    async fn stream(
        &self,
        model_config: &ModelConfig,
        session_id: &str,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        // Build API request
        let request = self.build_request(
            model_config,
            system,
            messages,
            tools,
        )?;

        // Make streaming API call
        let url = format!("{}/chat/completions", self.base_url);
        let response = self.client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await
            .map_err(|e| ProviderError::RequestFailed(
                format!("Request failed: {}", e)
            ))?;

        // An error response carries an error body, not an SSE stream; reject it
        // here instead of letting the stream parser choke on it later.
        let status = response.status();
        if !status.is_success() {
            return Err(match status.as_u16() {
                429 => ProviderError::RateLimitExceeded(
                    "Rate limit exceeded. Try again later.".to_string()
                ),
                413 => ProviderError::ContextLengthExceeded(
                    "Request exceeds model context length".to_string()
                ),
                _ => ProviderError::RequestFailed(
                    format!("Request failed with status {}", status)
                ),
            });
        }

        // Convert response to MessageStream
        let stream = self.process_stream(response).await?;
        Ok(stream)
    }
}

3. Implement ProviderDef

use goose::providers::base::{ProviderDef, ProviderMetadata, ConfigKey};
use futures::future::BoxFuture;

impl ProviderDef for MyProvider {
    type Provider = Self;

    /// Describes the provider: its identifiers, default and known models,
    /// a models documentation URL, and the configuration keys it accepts.
    fn metadata() -> ProviderMetadata {
        // Flag order follows the annotations in the original call:
        // required, secret, default value, primary.
        // NOTE(review): the keys are declared as "API_KEY"/"BASE_URL" while
        // `from_env` reads "MYPROVIDER_API_KEY"/"MYPROVIDER_BASE_URL" —
        // confirm whether a prefix is applied elsewhere.
        let api_key = ConfigKey::new("API_KEY", true, true, None, true);
        let base_url = ConfigKey::new(
            "BASE_URL",
            false,
            false,
            Some("https://api.myprovider.com"),
            false,
        );

        ProviderMetadata::new(
            "myprovider",
            "My Provider",
            "Custom LLM provider",
            "my-model-name",
            vec!["my-model-name", "my-other-model"],
            "https://myprovider.com/models",
            vec![api_key, base_url],
        )
    }

    /// Builds the provider from environment variables.
    ///
    /// `MYPROVIDER_API_KEY` is mandatory; `MYPROVIDER_BASE_URL` falls back to
    /// `https://api.myprovider.com` when it cannot be read.
    fn from_env(
        model_config: ModelConfig,
        _extensions: Vec<ExtensionConfig>,
    ) -> BoxFuture<'static, anyhow::Result<Self::Provider>> {
        let build = async move {
            let api_key = match std::env::var("MYPROVIDER_API_KEY") {
                Ok(value) => value,
                Err(_) => return Err(anyhow::anyhow!("MYPROVIDER_API_KEY not set")),
            };

            let base_url = match std::env::var("MYPROVIDER_BASE_URL") {
                Ok(value) => value,
                Err(_) => "https://api.myprovider.com".to_string(),
            };

            Ok(Self::new(api_key, base_url, model_config))
        };

        Box::pin(build)
    }
}

Message Format Conversion

Convert Goose’s internal message format to provider-specific format:
/// Assembles the JSON body for a streaming chat-completions request.
///
/// The system prompt (when non-empty) becomes the first message, followed by
/// every converted conversation message. A `tools` entry is added only when
/// at least one tool is supplied.
///
/// # Errors
///
/// Propagates the first error from `convert_message` or `convert_tools`.
fn build_request(
    &self,
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<serde_json::Value, ProviderError> {
    let mut payload_messages = Vec::new();

    // System prompt leads the conversation
    if !system.is_empty() {
        payload_messages.push(serde_json::json!({
            "role": "system",
            "content": system
        }));
    }

    // Conversation messages, stopping at the first conversion failure
    let converted = messages
        .iter()
        .map(|message| self.convert_message(message))
        .collect::<Result<Vec<_>, _>>()?;
    payload_messages.extend(converted);

    let mut body = serde_json::json!({
        "model": model_config.model_name,
        "messages": payload_messages,
        "stream": true,
    });

    // Tools are attached only when there are any
    if !tools.is_empty() {
        let converted_tools = self.convert_tools(tools)?;
        body["tools"] = converted_tools;
    }

    Ok(body)
}

/// Converts one internal message into a `{"role", "content"}` JSON object.
///
/// Only the text parts of the message are kept; they are concatenated
/// without a separator. Non-text content is dropped.
fn convert_message(
    &self,
    message: &Message,
) -> Result<serde_json::Value, ProviderError> {
    let role_name = match message.role {
        rmcp::model::Role::Assistant => "assistant",
        rmcp::model::Role::User => "user",
    };

    // Concatenate every text fragment in order.
    let text: String = message
        .content
        .iter()
        .filter_map(|part| part.as_text())
        .collect();

    let converted = serde_json::json!({
        "role": role_name,
        "content": text
    });
    Ok(converted)
}

Streaming Implementation

Process the Server-Sent Events (SSE) stream:
use futures::stream::Stream;
use futures::StreamExt;

/// Turns a streaming HTTP response into a `MessageStream`.
///
/// Incoming bytes are buffered until a blank line (`"\n\n"`) terminates an
/// SSE event; each complete event is handed to `parse_sse`, and every event
/// that yields content is emitted as `(Some(message), usage)`.
///
/// # Errors
///
/// The stream item is `Err(ProviderError::StreamError)` when reading from the
/// response fails, or whatever error `parse_sse` returns. The stream ends
/// after the first error.
async fn process_stream(
    &self,
    response: reqwest::Response,
) -> Result<MessageStream, ProviderError> {
    let mut stream = response.bytes_stream();
    // Buffer raw bytes rather than text: a multi-byte UTF-8 character may be
    // split across two network chunks, and decoding each chunk on its own
    // would replace both halves with U+FFFD.
    let mut buffer: Vec<u8> = Vec::new();

    // NOTE(review): the generated stream borrows `self` through `parse_sse`;
    // if `MessageStream` is required to be `'static`, `parse_sse` needs to be
    // callable without `self` — confirm against the real trait.
    //
    // `try_stream!` (not `stream!`) is what allows `?` inside the body; the
    // yielded values are wrapped in `Ok` automatically.
    let message_stream = async_stream::try_stream! {
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.map_err(|e| ProviderError::StreamError(
                format!("Stream error: {}", e)
            ))?;

            buffer.extend_from_slice(&chunk);

            // Process complete SSE messages
            while let Some(pos) = buffer.windows(2).position(|w| w == b"\n\n") {
                // Remove the event and its terminator in place, keeping the
                // remainder of the buffer without reallocating it.
                let event_bytes: Vec<u8> = buffer.drain(..pos + 2).collect();
                let message = String::from_utf8_lossy(&event_bytes[..pos]);

                if let Some(event) = self.parse_sse(&message)? {
                    yield (Some(event.message), event.usage);
                }
            }
        }
    };

    Ok(Box::pin(message_stream))
}

/// One parsed SSE event, as produced by `parse_sse`.
struct StreamEvent {
    /// Message built from the event's content.
    message: Message,
    /// Usage information attached to the event, if any.
    usage: Option<ProviderUsage>,
}

/// Parses a single SSE event into a `StreamEvent`.
///
/// Returns `Ok(None)` when the event carries nothing to emit: it has no
/// `data:` field (for example a keep-alive comment), the field is empty, it
/// is the `[DONE]` terminator, or the delta contains no text.
///
/// # Errors
///
/// Returns `ProviderError::ParseError` when the `data:` payload is not valid
/// JSON.
fn parse_sse(
    &self,
    data: &str,
) -> Result<Option<StreamEvent>, ProviderError> {
    // Parse SSE data field. The space after the colon is optional in the SSE
    // format, so match on "data:" and trim the payload.
    let json_data = match data
        .lines()
        .find_map(|line| line.strip_prefix("data:"))
    {
        Some(payload) => payload.trim(),
        // Events without a data field are not errors; skip them instead of
        // feeding an empty string to the JSON parser.
        None => return Ok(None),
    };

    if json_data.is_empty() || json_data == "[DONE]" {
        return Ok(None);
    }

    let parsed: serde_json::Value = serde_json::from_str(json_data)
        .map_err(|e| ProviderError::ParseError(
            format!("Parse error: {}", e)
        ))?;

    // Extract delta content
    let content = parsed["choices"][0]["delta"]["content"]
        .as_str()
        .unwrap_or("");

    if content.is_empty() {
        return Ok(None);
    }

    let message = Message::assistant().with_text(content);

    Ok(Some(StreamEvent {
        message,
        usage: None,
    }))
}

Tool Support

Convert MCP tools to provider format:
/// Converts the tool list into a JSON array of `"function"` entries, each
/// carrying the tool's name, description and input schema as parameters.
fn convert_tools(
    &self,
    tools: &[Tool],
) -> Result<serde_json::Value, ProviderError> {
    let mut functions = Vec::with_capacity(tools.len());

    for tool in tools {
        functions.push(serde_json::json!({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.input_schema
            }
        }));
    }

    Ok(serde_json::Value::Array(functions))
}
Handle tool calls in responses:
// Build one assistant message containing every tool call in the response.
// The fields come from the remote API, so a missing or mistyped field is
// reported as a `ParseError` instead of panicking.
if let Some(tool_calls) = parsed["choices"][0]["message"]["tool_calls"].as_array() {
    let mut message = Message::assistant();

    for tool_call in tool_calls {
        let id = tool_call["id"].as_str().ok_or_else(|| {
            ProviderError::ParseError("Tool call is missing a string `id`".to_string())
        })?;
        let name = tool_call["function"]["name"].as_str().ok_or_else(|| {
            ProviderError::ParseError("Tool call is missing a string `function.name`".to_string())
        })?;
        // The arguments arrive as a JSON document encoded inside a string.
        let arguments = tool_call["function"]["arguments"].as_str().ok_or_else(|| {
            ProviderError::ParseError(
                "Tool call is missing a string `function.arguments`".to_string(),
            )
        })?;

        message = message.with_tool_request(
            id,
            name,
            serde_json::from_str(arguments).map_err(|e| {
                ProviderError::ParseError(format!("Invalid tool call arguments: {}", e))
            })?,
        );
    }

    return Ok(Some(StreamEvent { message, usage: None }));
}

Advanced Features

OAuth Support

Implement OAuth authentication:
/// Obtains an OAuth token for the provider.
///
/// # Errors
///
/// Returns `ProviderError::ConfigError` when the token cannot be obtained.
async fn configure_oauth(&self) -> Result<(), ProviderError> {
    use goose::providers::oauth::get_oauth_token_async;

    let scopes = ["api.access".to_string()];
    let outcome = get_oauth_token_async(
        "https://auth.myprovider.com",
        "client-id",
        "http://localhost:8080",
        &scopes,
    ).await;

    let token = match outcome {
        Ok(token) => token,
        Err(e) => {
            return Err(ProviderError::ConfigError(
                format!("OAuth failed: {}", e)
            ));
        }
    };

    // Store token securely
    // ...

    Ok(())
}

Model Discovery

Fetch available models:
/// Fetches the names of the models available at `{base_url}/models`.
///
/// # Errors
///
/// * `ProviderError::RequestFailed` when the request cannot be sent or the
///   server replies with an error status.
/// * `ProviderError::ParseError` when the body cannot be decoded as a list
///   of `ModelInfo`.
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
    let response = self.client
        .get(format!("{}/models", self.base_url))
        .header("Authorization", format!("Bearer {}", self.api_key))
        .send()
        .await
        // Surface 4xx/5xx as a request failure; otherwise the error body is
        // decoded as a model list and reported as a misleading parse error.
        .and_then(|response| response.error_for_status())
        .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

    let models: Vec<ModelInfo> = response
        .json()
        .await
        .map_err(|e| ProviderError::ParseError(e.to_string()))?;

    // The decoded list is owned, so move the names out instead of cloning.
    Ok(models.into_iter().map(|m| m.name).collect())
}

Retry Logic

Configure retry behavior:
use goose::providers::retry::RetryConfig;

/// Returns the retry settings for this provider: up to 3 retries, delays
/// between 1000 and 10000 (the `_ms` fields), for the listed HTTP statuses.
fn retry_config(&self) -> RetryConfig {
    let retry_on_statuses = vec![429, 500, 502, 503, 504];

    RetryConfig {
        max_retries: 3,
        initial_delay_ms: 1_000,
        max_delay_ms: 10_000,
        retry_on_statuses,
    }
}

Session Naming

Customize session name generation:
/// Asks the model for a short title for the conversation.
///
/// The initial user messages are joined with newlines and sent as a single
/// user message together with a fixed titling prompt; the reply's text is
/// returned with surrounding whitespace trimmed.
///
/// # Errors
///
/// Propagates any error from `complete_fast`.
async fn generate_session_name(
    &self,
    session_id: &str,
    messages: &Conversation,
) -> Result<String, ProviderError> {
    let system = "Generate a short 4-word title for this conversation.";

    let context = self.get_initial_user_messages(messages);
    let prompt = context.join("\n");
    let request = Message::user().with_text(&prompt);

    let (reply, _) = self
        .complete_fast(session_id, system, &[request], &[])
        .await?;

    Ok(reply.as_concat_text().trim().to_string())
}

Error Handling

Use ProviderError for provider-specific errors:
/// Errors a provider can report; each variant carries a message string.
/// The variant list is abbreviated in this excerpt.
pub enum ProviderError {
    /// The HTTP request failed (used in this guide when sending fails).
    RequestFailed(String),
    /// A response could not be parsed (used in this guide for JSON decoding failures).
    ParseError(String),
    /// Reading from the response stream failed.
    StreamError(String),
    /// Configuration problem (used in this guide when OAuth fails).
    ConfigError(String),
    /// The request is too large for the model (used in this guide for HTTP 413).
    ContextLengthExceeded(String),
    /// The provider is throttling requests (used in this guide for HTTP 429).
    RateLimitExceeded(String),
    // ...
}
Example:
if response.status() == 429 {
    return Err(ProviderError::RateLimitExceeded(
        "Rate limit exceeded. Try again later.".to_string()
    ));
}

if response.status() == 413 {
    return Err(ProviderError::ContextLengthExceeded(
        "Request exceeds model context length".to_string()
    ));
}

Testing Providers

Unit Tests

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `stream` resolves to `Ok` for a minimal conversation.
    // NOTE(review): this calls `https://api.test.com` through the real
    // client, so it likely needs a mock server to pass — confirm.
    #[tokio::test]
    async fn test_provider_stream() {
        let api_key = "test-key".to_string();
        let base_url = "https://api.test.com".to_string();
        let provider = MyProvider::new(
            api_key,
            base_url,
            ModelConfig::new_or_fail("test-model"),
        );

        let conversation = vec![Message::user().with_text("Hello")];
        let outcome = provider
            .stream(
                &provider.model_config,
                "test-session",
                "You are helpful",
                &conversation,
                &[],
            )
            .await;

        assert!(outcome.is_ok());
    }
}

Integration Tests

/// End-to-end check: builds the provider from the environment, runs a
/// non-streaming completion, and expects content plus a total token count.
#[tokio::test]
async fn test_provider_integration() {
    let model = ModelConfig::new_or_fail("test-model");
    let provider = MyProvider::from_env(model, vec![]).await.unwrap();

    let config = provider.get_model_config();
    let prompt = [Message::user().with_text("Hello")];
    let (reply, usage) = provider
        .complete(&config, "test-session", "You are helpful", &prompt, &[])
        .await
        .unwrap();

    assert!(!reply.content.is_empty());
    assert!(usage.usage.total_tokens.is_some());
}

Registration

Register your provider in crates/goose/src/providers/provider_registry.rs:
/// Builds the map from provider name to the factory that creates it,
/// including the `"myprovider"` entry.
pub fn register_providers() -> HashMap<String, Box<dyn ProviderFactory>> {
    let mut registry: HashMap<String, Box<dyn ProviderFactory>> = HashMap::new();

    // ... existing providers

    let factory: Box<dyn ProviderFactory> = Box::new(MyProviderFactory);
    registry.insert(String::from("myprovider"), factory);

    registry
}

Next Steps

Build docs developers (and LLMs) love