use goose::providers::base::*;
use goose::providers::oauth::get_oauth_token_async;
use async_trait::async_trait;
/// Model provider that talks to Databricks Model Serving endpoints.
///
/// Each request is authorised with an OAuth access token obtained for
/// `client_id` / `scopes` against the workspace at `host`.
pub struct DatabricksProvider {
    /// HTTP client used to POST invocation requests.
    client: reqwest::Client,
    /// Model configuration supplied at construction time.
    model_config: ModelConfig,
    /// Workspace base URL; request paths are appended to it.
    host: String,
    /// OAuth client identifier passed to the token flow.
    client_id: String,
    /// OAuth scopes requested when fetching a token.
    scopes: Vec<String>,
}
impl DatabricksProvider {
    /// OAuth client ID used by default.
    const DEFAULT_CLIENT_ID: &'static str = "databricks-cli";

    /// OAuth scopes requested by default.
    // The outer reference carries an explicit `'static`: eliding lifetimes in
    // associated constants is deprecated (`elided_lifetimes_in_associated_constant`).
    const DEFAULT_SCOPES: &'static [&'static str] = &["all-apis", "offline_access"];

    /// Local redirect URI handed to the OAuth token flow.
    const REDIRECT_URL: &'static str = "http://localhost:8020";

    /// Creates a provider for the Databricks workspace at `host`.
    ///
    /// Uses the default OAuth client ID and scopes and a fresh HTTP client.
    /// Trailing slashes are stripped from `host` so that URLs built as
    /// `{host}/...` never contain `//`.
    pub fn new(host: String, model_config: ModelConfig) -> Self {
        let host = host.trim_end_matches('/').to_string();
        Self {
            host,
            client_id: Self::DEFAULT_CLIENT_ID.to_string(),
            scopes: Self::DEFAULT_SCOPES
                .iter()
                .copied()
                .map(String::from)
                .collect(),
            model_config,
            client: reqwest::Client::new(),
        }
    }

    /// Obtains an OAuth access token for this workspace via `get_oauth_token_async`.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::ConfigError`] wrapping the underlying error
    /// message if the token could not be obtained.
    async fn get_token(&self) -> Result<String, ProviderError> {
        get_oauth_token_async(&self.host, &self.client_id, Self::REDIRECT_URL, &self.scopes)
            .await
            .map_err(|e| ProviderError::ConfigError(format!("Failed to get OAuth token: {e}")))
    }
}
#[async_trait]
impl Provider for DatabricksProvider {
    /// Identifier of this provider.
    fn get_name(&self) -> &str {
        "databricks"
    }

    /// Returns a copy of the model configuration given at construction.
    fn get_model_config(&self) -> ModelConfig {
        self.model_config.clone()
    }

    /// Runs the OAuth flow interactively, printing progress to stdout.
    ///
    /// # Errors
    ///
    /// Propagates the error from the token request if authentication fails.
    async fn configure_oauth(&self) -> Result<(), ProviderError> {
        println!("Starting Databricks OAuth authentication...");
        println!("A browser window will open for authentication.");
        // Only success matters here; the token value itself is not used.
        self.get_token().await?;
        println!("\nAuthentication successful!");
        println!("Token cached for future use.");
        Ok(())
    }

    /// Sends a request to the model's serving endpoint and returns the
    /// resulting message stream.
    ///
    /// # Errors
    ///
    /// - Propagates token-acquisition and request-building errors.
    /// - Returns [`ProviderError::RequestFailed`] if the HTTP request cannot be sent.
    /// - Returns any error produced while processing the response.
    async fn stream(
        &self,
        model_config: &ModelConfig,
        _session_id: &str,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        let token = self.get_token().await?;
        let request = self.build_request(model_config, system, messages, tools)?;

        // Tolerate a trailing slash on the configured host so the path never has `//`.
        let url = format!(
            "{}/serving-endpoints/{}/invocations",
            self.host.trim_end_matches('/'),
            model_config.model_name
        );

        // NOTE(review): non-2xx HTTP statuses are not rejected here; presumably
        // `process_stream` inspects the status/body — confirm.
        let response = self
            .client
            .post(url)
            .bearer_auth(token)
            .json(&request)
            .send()
            .await
            .map_err(|e| ProviderError::RequestFailed(format!("Request failed: {e}")))?;

        self.process_stream(response).await
    }
}
impl ProviderDef for DatabricksProvider {
    type Provider = Self;

    /// Static description of this provider: names, default and known models,
    /// documentation link, and the configuration keys it exposes.
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "databricks",
            "Databricks",
            "Databricks Model Serving with OAuth",
            "databricks-meta-llama-3-1-70b-instruct",
            vec![
                "databricks-meta-llama-3-1-70b-instruct",
                "databricks-meta-llama-3-1-405b-instruct",
            ],
            "https://docs.databricks.com/en/machine-learning/foundation-models/",
            vec![
                // NOTE(review): `from_env` reads `DATABRICKS_HOST`, while this key is
                // named `HOST`; confirm the config layer maps one to the other.
                ConfigKey::new("HOST", true, false, None, true),
                ConfigKey::new_oauth("ACCESS_TOKEN", true, true, None, true),
            ],
        )
    }

    /// Builds a provider from the `DATABRICKS_HOST` environment variable.
    ///
    /// Leading and trailing whitespace is trimmed from the host value.
    ///
    /// # Errors
    ///
    /// Fails if `DATABRICKS_HOST` meets any of these conditions:
    /// - it is unset;
    /// - it is not valid UTF-8;
    /// - it is empty or whitespace-only.
    fn from_env(
        model_config: ModelConfig,
        _extensions: Vec<ExtensionConfig>,
    ) -> BoxFuture<'static, anyhow::Result<Self::Provider>> {
        Box::pin(async move {
            let raw = match std::env::var("DATABRICKS_HOST") {
                Ok(value) => value,
                Err(std::env::VarError::NotPresent) => {
                    anyhow::bail!("DATABRICKS_HOST environment variable not set")
                }
                Err(std::env::VarError::NotUnicode(_)) => {
                    anyhow::bail!("DATABRICKS_HOST environment variable is not valid UTF-8")
                }
            };
            // An empty host would otherwise surface later as a confusing URL/OAuth failure.
            let host = raw.trim();
            if host.is_empty() {
                anyhow::bail!("DATABRICKS_HOST environment variable is empty");
            }
            Ok(Self::new(host.to_string(), model_config))
        })
    }
}