How to implement the `Provider` trait and integrate new LLM providers.
Provider Trait Overview
All providers implement the `Provider` trait defined in `crates/goose/src/providers/base.rs:456`.
Core Methods
/// Interface implemented by every LLM provider.
///
/// The `Send + Sync` bound allows a provider to be shared between threads/tasks.
#[async_trait]
pub trait Provider: Send + Sync {
/// Provider identifier
fn get_name(&self) -> &str;
/// Primary streaming method
///
/// Produces the model's reply to `messages` (given the `system` prompt and the
/// available `tools`) incrementally as a `MessageStream`.
async fn stream(
&self,
model_config: &ModelConfig,
session_id: &str,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<MessageStream, ProviderError>;
/// Get model configuration
fn get_model_config(&self) -> ModelConfig;
/// Non-streaming completion (default impl uses stream)
///
/// Returns the complete reply message together with its token usage.
/// NOTE(review): shown here as a bare signature; per the comment above the
/// real trait supplies a default body built on `stream` — confirm in base.rs.
async fn complete(
&self,
model_config: &ModelConfig,
session_id: &str,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError>;
// ... additional methods
}
Key Types
MessageStream - Stream of message chunks:
/// Pinned, boxed, `Send` stream of provider output. Each item is either a
/// `ProviderError` or a pair of an optional message chunk and optional usage
/// information.
pub type MessageStream = Pin<Box<dyn Stream<Item = Result<
(Option<Message>, Option<ProviderUsage>),
ProviderError
>> + Send>>;
/// Token usage for one provider call, paired with a model name.
pub struct ProviderUsage {
/// Name of the model the usage refers to (presumably the model that served
/// the request — confirm).
pub model: String,
pub usage: Usage,
}
/// Token counts for one call. Every count is optional; `None` presumably means
/// the provider did not report that figure.
pub struct Usage {
pub input_tokens: Option<i32>,
pub output_tokens: Option<i32>,
pub total_tokens: Option<i32>,
}
Implementing a Basic Provider
1. Define Provider Struct
use goose::providers::base::{Provider, ProviderError};
use goose::model::ModelConfig;
use async_trait::async_trait;
/// Example provider: credentials, endpoint, model settings and the HTTP
/// client used for every request it sends.
pub struct MyProvider {
    api_key: String,
    base_url: String,
    model_config: ModelConfig,
    client: reqwest::Client,
}

impl MyProvider {
    /// Create a provider for `base_url` authenticated with `api_key`,
    /// using a freshly constructed `reqwest::Client`.
    pub fn new(
        api_key: String,
        base_url: String,
        model_config: ModelConfig,
    ) -> Self {
        let client = reqwest::Client::new();
        Self {
            client,
            model_config,
            base_url,
            api_key,
        }
    }
}
2. Implement Provider Trait
#[async_trait]
impl Provider for MyProvider {
    fn get_name(&self) -> &str {
        "myprovider"
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model_config.clone()
    }

    /// Send a streaming chat-completions request and return its parsed output.
    ///
    /// # Errors
    /// - `RequestFailed` if the request cannot be sent, or the server answers
    ///   with a non-success status other than 429/413.
    /// - `RateLimitExceeded` for HTTP 429.
    /// - `ContextLengthExceeded` for HTTP 413.
    /// - Any error produced by `build_request` or `process_stream`.
    async fn stream(
        &self,
        model_config: &ModelConfig,
        session_id: &str,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        // Build API request
        let request = self.build_request(model_config, system, messages, tools)?;

        // Accept a base URL configured with or without a trailing slash.
        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
        let response = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await
            .map_err(|e| ProviderError::RequestFailed(format!("Request failed: {}", e)))?;

        // A non-2xx reply carries an error document, not an SSE stream: report it
        // here instead of letting the SSE parser fail on it later.
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(match status.as_u16() {
                429 => ProviderError::RateLimitExceeded(format!("Rate limit exceeded: {}", body)),
                413 => ProviderError::ContextLengthExceeded(format!(
                    "Request exceeds model context length: {}",
                    body
                )),
                _ => ProviderError::RequestFailed(format!("HTTP {}: {}", status, body)),
            });
        }

        // Convert response to MessageStream
        self.process_stream(response).await
    }
}
3. Implement ProviderDef
use goose::providers::base::{ProviderDef, ProviderMetadata, ConfigKey};
use futures::future::BoxFuture;
// The keys declared in `metadata()` and the variables read in `from_env()`
// share these constants so the advertised configuration matches what is
// actually read.
/// Config key / environment variable holding the API key.
const MYPROVIDER_API_KEY: &str = "MYPROVIDER_API_KEY";
/// Config key / environment variable overriding the API base URL.
const MYPROVIDER_BASE_URL: &str = "MYPROVIDER_BASE_URL";
/// Base URL used when `MYPROVIDER_BASE_URL` is not set.
const MYPROVIDER_DEFAULT_BASE_URL: &str = "https://api.myprovider.com";

impl ProviderDef for MyProvider {
    type Provider = Self;

    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "myprovider",
            "My Provider",
            "Custom LLM provider",
            "my-model-name",
            vec!["my-model-name", "my-other-model"],
            "https://myprovider.com/models",
            vec![
                ConfigKey::new(
                    MYPROVIDER_API_KEY,
                    true, // required
                    true, // secret
                    None, // no default
                    true, // primary
                ),
                ConfigKey::new(
                    MYPROVIDER_BASE_URL,
                    false, // optional
                    false, // not secret
                    Some(MYPROVIDER_DEFAULT_BASE_URL),
                    false, // not primary
                ),
            ],
        )
    }

    /// Build the provider from environment variables.
    ///
    /// # Errors
    /// Fails if `MYPROVIDER_API_KEY` is unset or blank.
    fn from_env(
        model_config: ModelConfig,
        _extensions: Vec<ExtensionConfig>,
    ) -> BoxFuture<'static, anyhow::Result<Self::Provider>> {
        Box::pin(async move {
            // Treat a blank variable like an unset one rather than sending an
            // empty credential or requesting an empty URL.
            let read = |name: &str| std::env::var(name).ok().filter(|v| !v.trim().is_empty());
            let api_key = read(MYPROVIDER_API_KEY)
                .ok_or_else(|| anyhow::anyhow!("{} not set", MYPROVIDER_API_KEY))?;
            let base_url = read(MYPROVIDER_BASE_URL)
                .unwrap_or_else(|| MYPROVIDER_DEFAULT_BASE_URL.to_string());
            Ok(Self::new(api_key, base_url, model_config))
        })
    }
}
Message Format Conversion
Convert Goose’s internal message format to provider-specific format:
/// Assemble the request body: an optional system message followed by the
/// converted conversation, the model name, `"stream": true`, and — only when
/// any are given — the converted tools.
fn build_request(
    &self,
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<serde_json::Value, ProviderError> {
    // The system prompt leads the list, but only when it is non-empty.
    let system_entry = if system.is_empty() {
        None
    } else {
        Some(serde_json::json!({ "role": "system", "content": system }))
    };
    // Stops at the first message that fails to convert.
    let api_messages = system_entry
        .into_iter()
        .map(Ok)
        .chain(messages.iter().map(|msg| self.convert_message(msg)))
        .collect::<Result<Vec<_>, ProviderError>>()?;

    let mut request = serde_json::json!({
        "model": model_config.model_name,
        "messages": api_messages,
        "stream": true,
    });
    if !tools.is_empty() {
        request["tools"] = self.convert_tools(tools)?;
    }
    Ok(request)
}
/// Map one goose message to a `{ "role", "content" }` JSON object.
///
/// Only parts for which `as_text()` returns `Some` are forwarded; they are
/// concatenated with no separator.
/// NOTE(review): non-text parts (presumably including tool requests and tool
/// results) are silently dropped here — extend this before relying on tools.
fn convert_message(
    &self,
    message: &Message,
) -> Result<serde_json::Value, ProviderError> {
    use rmcp::model::Role;

    let role = match message.role {
        Role::User => "user",
        Role::Assistant => "assistant",
    };
    let text_parts: Vec<_> = message
        .content
        .iter()
        .filter_map(|part| part.as_text())
        .collect();
    Ok(serde_json::json!({
        "role": role,
        "content": text_parts.concat(),
    }))
}
Streaming Implementation
Process Server-Sent Events (SSE) stream:
use futures::stream::Stream;
use futures::StreamExt;
/// Turn a streaming HTTP response into a `MessageStream`.
///
/// The returned stream is `'static` (it outlives this call), so it must not
/// borrow `self`; events are therefore parsed with the associated function
/// `Self::parse_sse_data`, which needs no provider state.
async fn process_stream(
    &self,
    response: reqwest::Response,
) -> Result<MessageStream, ProviderError> {
    let mut body = response.bytes_stream();
    // `try_stream!` (unlike `stream!`) supports `?`: an `Err` is emitted as the
    // final item and the stream ends; `yield x` emits `Ok(x)`.
    let message_stream = async_stream::try_stream! {
        // Buffer raw bytes and decode only complete events: a multi-byte UTF-8
        // character can be split across two network chunks, and decoding each
        // chunk separately would turn both halves into U+FFFD.
        let mut buffer: Vec<u8> = Vec::new();
        while let Some(chunk) = body.next().await {
            let chunk = chunk.map_err(|e| ProviderError::StreamError(
                format!("Stream error: {}", e)
            ))?;
            // Drop carriage returns so `\r\n`-terminated events are framed like
            // `\n`-terminated ones. JSON never needs a raw CR (inside strings it
            // must be escaped), so payloads are unaffected.
            buffer.extend(chunk.iter().copied().filter(|&b| b != b'\r'));
            // Every complete SSE event is terminated by a blank line.
            while let Some(pos) = buffer.windows(2).position(|w| w == b"\n\n") {
                let raw: Vec<u8> = buffer.drain(..pos + 2).collect();
                let event_text = String::from_utf8_lossy(&raw[..pos]);
                if let Some(event) = Self::parse_sse_data(&event_text)? {
                    yield (Some(event.message), event.usage);
                }
            }
        }
    };
    Ok(Box::pin(message_stream))
}
/// One parsed streaming event: a message chunk plus optional usage.
struct StreamEvent {
    message: Message,
    usage: Option<ProviderUsage>,
}
/// Parse one SSE event (the text between blank-line separators).
///
/// Delegates to `Self::parse_sse_data`; kept as a method for callers that
/// hold `&self`.
fn parse_sse(
    &self,
    data: &str,
) -> Result<Option<StreamEvent>, ProviderError> {
    Self::parse_sse_data(data)
}
/// Parse one SSE event without borrowing the provider.
///
/// Returns `Ok(None)` when there is nothing to emit: events with no `data:`
/// field (keep-alives / comments), the `[DONE]` sentinel, and deltas without
/// text content.
///
/// # Errors
/// `ProviderError::ParseError` if the `data:` payload is not valid JSON.
fn parse_sse_data(data: &str) -> Result<Option<StreamEvent>, ProviderError> {
    // An event may carry several `data:` lines, which SSE joins with '\n';
    // the single space after the colon is optional.
    let payload: Vec<&str> = data
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|value| value.strip_prefix(' ').unwrap_or(value))
        .collect();
    // Comment-only / keep-alive events: skip rather than parse "" as JSON.
    if payload.is_empty() {
        return Ok(None);
    }
    let json_data = payload.join("\n");
    if json_data == "[DONE]" {
        return Ok(None);
    }
    let parsed: serde_json::Value = serde_json::from_str(&json_data)
        .map_err(|e| ProviderError::ParseError(format!("Parse error: {}", e)))?;
    // Extract delta content
    let content = parsed["choices"][0]["delta"]["content"]
        .as_str()
        .unwrap_or("");
    if content.is_empty() {
        return Ok(None);
    }
    let message = Message::assistant().with_text(content);
    Ok(Some(StreamEvent {
        message,
        usage: None,
    }))
}
Tool Support
Convert MCP tools to provider format:
/// Convert MCP tool definitions into a JSON array of
/// `{"type": "function", "function": {name, description, parameters}}` entries.
fn convert_tools(
    &self,
    tools: &[Tool],
) -> Result<serde_json::Value, ProviderError> {
    let mut functions = Vec::with_capacity(tools.len());
    for tool in tools {
        functions.push(serde_json::json!({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.input_schema,
            }
        }));
    }
    Ok(serde_json::Value::Array(functions))
}
// Non-streaming responses carry tool calls under `choices[0].message`.
// Every field below comes from the remote API, so a missing or mistyped
// field is reported as a parse error instead of panicking via `unwrap()`.
if let Some(tool_calls) = parsed["choices"][0]["message"]["tool_calls"].as_array() {
    let field_error = |field: &str| {
        ProviderError::ParseError(format!("tool call is missing string field `{}`", field))
    };
    let mut message = Message::assistant();
    for tool_call in tool_calls {
        let id = tool_call["id"].as_str().ok_or_else(|| field_error("id"))?;
        let name = tool_call["function"]["name"]
            .as_str()
            .ok_or_else(|| field_error("function.name"))?;
        let arguments = tool_call["function"]["arguments"]
            .as_str()
            .ok_or_else(|| field_error("function.arguments"))?;
        // Map the JSON error explicitly instead of relying on a `From` impl.
        let parsed_arguments = serde_json::from_str(arguments).map_err(|e| {
            ProviderError::ParseError(format!("invalid arguments for tool `{}`: {}", name, e))
        })?;
        message = message.with_tool_request(id, name, parsed_arguments);
    }
    return Ok(Some(StreamEvent { message, usage: None }));
}
Advanced Features
OAuth Support
Implement OAuth authentication:
/// Obtain an OAuth token for the provider.
///
/// # Errors
/// `ProviderError::ConfigError` if the OAuth flow fails.
/// The token is not persisted yet; storing it is left to the implementor.
async fn configure_oauth(&self) -> Result<(), ProviderError> {
    use goose::providers::oauth::get_oauth_token_async;

    let scopes = ["api.access".to_string()];
    let token_result = get_oauth_token_async(
        "https://auth.myprovider.com",
        "client-id",
        "http://localhost:8080",
        &scopes,
    )
    .await;
    let _token = match token_result {
        Ok(token) => token,
        Err(e) => return Err(ProviderError::ConfigError(format!("OAuth failed: {}", e))),
    };
    // Store token securely
    // ...
    Ok(())
}
Model Discovery
Fetch available models:
/// Query the provider's `/models` endpoint and return the model names.
///
/// # Errors
/// `RequestFailed` if the request cannot be sent or returns a non-success
/// status; `ParseError` if the body does not decode as a list of `ModelInfo`.
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
    let url = format!("{}/models", self.base_url.trim_end_matches('/'));
    let response = self
        .client
        .get(url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .send()
        .await
        .map_err(|e| ProviderError::RequestFailed(e.to_string()))?
        // Without this, an error reply (e.g. a 401 body) would surface below
        // as a misleading JSON parse error.
        .error_for_status()
        .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
    let models: Vec<ModelInfo> = response
        .json()
        .await
        .map_err(|e| ProviderError::ParseError(e.to_string()))?;
    // `models` is owned, so move the names out instead of cloning them.
    Ok(models.into_iter().map(|m| m.name).collect())
}
Retry Logic
Configure retry behavior:
use goose::providers::retry::RetryConfig;
/// Retry policy for this provider's requests.
fn retry_config(&self) -> RetryConfig {
    // Retry on 429 (Too Many Requests) and on 500/502/503/504 server/gateway errors.
    let retry_on_statuses = Vec::from([429, 500, 502, 503, 504]);
    RetryConfig {
        retry_on_statuses,
        max_delay_ms: 10000,
        initial_delay_ms: 1000,
        max_retries: 3,
    }
}
Session Naming
Customize session name generation:
/// Ask the fast model for a short title based on the opening user messages,
/// returned with surrounding whitespace trimmed.
async fn generate_session_name(
    &self,
    session_id: &str,
    messages: &Conversation,
) -> Result<String, ProviderError> {
    const TITLE_PROMPT: &str = "Generate a short 4-word title for this conversation.";

    let opening = self.get_initial_user_messages(messages).join("\n");
    let request = Message::user().with_text(&opening);
    let (reply, _) = self
        .complete_fast(session_id, TITLE_PROMPT, &[request], &[])
        .await?;
    Ok(reply.as_concat_text().trim().to_string())
}
Error Handling
Use `ProviderError` for provider-specific errors:
/// Errors a provider can return (excerpt; remaining variants are elided).
pub enum ProviderError {
/// Used in this guide when an HTTP request cannot be sent.
RequestFailed(String),
/// Used when a response body or SSE payload cannot be parsed.
ParseError(String),
/// Used when reading the response byte stream fails.
StreamError(String),
/// Used for configuration/authentication failures such as a failed OAuth flow.
ConfigError(String),
/// Used for HTTP 413 replies, interpreted as the request exceeding the context length.
ContextLengthExceeded(String),
/// Used for HTTP 429 replies.
RateLimitExceeded(String),
// ...
}
if response.status() == 429 {
return Err(ProviderError::RateLimitExceeded(
"Rate limit exceeded. Try again later.".to_string()
));
}
if response.status() == 413 {
return Err(ProviderError::ContextLengthExceeded(
"Request exceeds model context length".to_string()
));
}
Testing Providers
Unit Tests
#[cfg(test)]
mod tests {
    use super::*;

    fn test_provider() -> MyProvider {
        MyProvider::new(
            "test-key".to_string(),
            "https://api.test.com".to_string(),
            ModelConfig::new_or_fail("test-model"),
        )
    }

    /// Offline check of request construction; needs no network.
    #[test]
    fn test_build_request() {
        let provider = test_provider();
        let messages = vec![Message::user().with_text("Hello")];
        let request = provider
            .build_request(&provider.model_config, "You are helpful", &messages, &[])
            .expect("request should build");
        assert_eq!(request["messages"][0]["role"], "system");
        assert_eq!(request["messages"][0]["content"], "You are helpful");
        assert_eq!(request["messages"][1]["role"], "user");
        assert_eq!(request["messages"][1]["content"], "Hello");
        assert_eq!(request["stream"], true);
        assert!(request.get("tools").is_none());
    }

    /// `stream` sends a real HTTP request to the configured base URL, so this
    /// only passes when that host is reachable and answers successfully (or is
    /// replaced by a mock server). Run explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "performs a real HTTP request"]
    async fn test_provider_stream() {
        let provider = test_provider();
        let messages = vec![Message::user().with_text("Hello")];
        let stream = provider
            .stream(
                &provider.model_config,
                "test-session",
                "You are helpful",
                &messages,
                &[],
            )
            .await;
        assert!(stream.is_ok());
    }
}
Integration Tests
/// End-to-end check against the real API. Needs `MYPROVIDER_API_KEY` (and
/// optionally `MYPROVIDER_BASE_URL`) plus network access, so it is opt-in:
/// run with `cargo test -- --ignored`.
#[tokio::test]
#[ignore = "requires MYPROVIDER_API_KEY and network access"]
async fn test_provider_integration() {
    let provider = MyProvider::from_env(
        ModelConfig::new_or_fail("test-model"),
        vec![],
    )
    .await
    .unwrap();
    let (response, usage) = provider
        .complete(
            &provider.get_model_config(),
            "test-session",
            "You are helpful",
            &[Message::user().with_text("Hello")],
            &[],
        )
        .await
        .unwrap();
    assert!(!response.content.is_empty());
    // NOTE(review): the example SSE parser always reports `usage: None`, so this
    // assertion can only pass once usage parsing is implemented.
    assert!(usage.usage.total_tokens.is_some());
}
Registration
Register your provider in `crates/goose/src/providers/provider_registry.rs`:
/// Build the provider registry, keyed by provider name.
pub fn register_providers() -> HashMap<String, Box<dyn ProviderFactory>> {
    let mut providers: HashMap<String, Box<dyn ProviderFactory>> = HashMap::new();
    // ... existing providers
    // NOTE(review): `MyProviderFactory` is not defined in this guide, which
    // implements `ProviderDef` for `MyProvider`; confirm how the registry
    // consumes `ProviderDef` implementations.
    let factory: Box<dyn ProviderFactory> = Box::new(MyProviderFactory);
    providers.insert(String::from("myprovider"), factory);
    providers
}
Next Steps
- Learn about Declarative Providers
- Implement OAuth Providers
- See existing providers in `crates/goose/src/providers/`