Overview
LiteLLM allows you to add custom LLM providers by implementing the `CustomLLM` base class. This enables you to integrate proprietary models, internal APIs, or specialized LLM services that aren’t natively supported.
When to Use Custom Providers
- Integrate internal/proprietary LLM APIs
- Add support for new LLM providers before native support
- Wrap existing APIs with custom logic (rate limiting, caching, etc.)
- Implement specialized model endpoints
Quick Start
Create Custom Provider Class
Inherit from
`CustomLLM` and implement the required methods:

from litellm.llms.custom_llm import CustomLLM, CustomLLMError
from typing import Any, Callable, Optional, Union

import httpx

from litellm.utils import ModelResponse
class MyCustomLLM(CustomLLM):
    """Minimal custom provider: routes completion calls to an internal API."""

    def __init__(self):
        super().__init__()

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},  # mutable default mirrors the CustomLLM base-class signature
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Any] = None,
    ):
        """Handle a synchronous completion request.

        Returns:
            The provided ``model_response`` with ``choices[0]`` populated
            from the custom API's reply (mutated in place, as LiteLLM expects).
        """
        # Delegate the actual HTTP/API work to a private helper.
        response_text = self._call_your_api(
            model=model,
            messages=messages,
            api_key=api_key,
            **optional_params,
        )

        # Format as ModelResponse: LiteLLM expects the mutated object back.
        model_response.choices[0].message.content = response_text
        return model_response

    def _call_your_api(self, model, messages, api_key, **kwargs):
        # Your API call logic here (placeholder for the real request).
        return "Response from custom API"
Register Your Provider
Register the custom provider with LiteLLM:
import litellm

# Register your custom provider: LiteLLM consults custom_provider_map to
# resolve the "my-custom-llm/<model>" prefix to this handler instance.
litellm.custom_provider_map = [
    {
        "provider": "my-custom-llm",     # prefix used in model="my-custom-llm/..."
        "custom_handler": MyCustomLLM()  # an instance, not the class itself
    }
]
Implementation Guide
Required Methods
completion() - Synchronous Completion
completion() - Synchronous Completion
def completion(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers={},  # mutable default mirrors the CustomLLM base-class signature
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[Any] = None,
) -> ModelResponse:
    """
    Handle synchronous completion requests.

    Returns:
        ModelResponse with populated choices (and usage, when the API
        reports it).
    """
    # Transform messages to your API format
    api_messages = self._transform_messages(messages)

    # Call your API
    response = self._make_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": api_messages,
            **optional_params,
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout,
    )

    # Populate the ModelResponse LiteLLM handed us (mutated in place).
    model_response.choices[0].message.content = response["text"]
    model_response.model = model

    # Set usage if available
    if "usage" in response:
        model_response.usage.prompt_tokens = response["usage"]["prompt_tokens"]
        model_response.usage.completion_tokens = response["usage"]["completion_tokens"]
        model_response.usage.total_tokens = response["usage"]["total_tokens"]

    return model_response
acompletion() - Async Completion
acompletion() - Async Completion
async def acompletion(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers={},  # mutable default mirrors the CustomLLM base-class signature
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[Any] = None,
) -> ModelResponse:
    """
    Handle asynchronous completion requests.

    Same contract as completion(), but awaits an async HTTP helper so the
    event loop is not blocked.
    """
    # Use async HTTP client
    response = await self._make_async_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": messages,
            **optional_params,
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout,
    )

    model_response.choices[0].message.content = response["text"]
    return model_response
streaming() - Sync Streaming
streaming() - Sync Streaming
from litellm.types.utils import GenericStreamingChunk
from typing import Iterator


def streaming(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers={},  # mutable default mirrors the CustomLLM base-class signature
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[Any] = None,
) -> Iterator[GenericStreamingChunk]:
    """
    Handle synchronous streaming requests.

    Yields:
        GenericStreamingChunk objects, one per piece of generated text.
    """
    stream = self._make_streaming_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": messages,
            "stream": True,
            **optional_params,
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout,
    )

    for chunk in stream:
        # Only emit chunks that actually carry text (skip keep-alives).
        if chunk.get("text"):
            yield GenericStreamingChunk(
                text=chunk["text"],
                is_finished=chunk.get("is_finished", False),
                finish_reason=chunk.get("finish_reason"),
            )
astreaming() - Async Streaming
astreaming() - Async Streaming
from typing import AsyncIterator


async def astreaming(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers={},  # mutable default mirrors the CustomLLM base-class signature
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[Any] = None,
) -> AsyncIterator[GenericStreamingChunk]:
    """
    Handle asynchronous streaming requests.

    Yields:
        GenericStreamingChunk objects, one per piece of generated text.
    """
    stream = self._make_async_streaming_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": messages,
            "stream": True,
            **optional_params,
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout,
    )

    async for chunk in stream:
        # Only emit chunks that actually carry text (skip keep-alives).
        if chunk.get("text"):
            yield GenericStreamingChunk(
                text=chunk["text"],
                is_finished=chunk.get("is_finished", False),
                finish_reason=chunk.get("finish_reason"),
            )
Optional Methods
def embedding(
    self,
    model: str,
    input: list,  # parameter name mirrors the base-class signature (shadows the builtin)
    model_response: EmbeddingResponse,
    print_verbose: Callable,
    logging_obj: Any,
    optional_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    litellm_params=None,
) -> EmbeddingResponse:
    """
    Handle embedding requests.

    NOTE(review): EmbeddingResponse is not imported in this snippet — import
    it from litellm (e.g. ``from litellm.types.utils import EmbeddingResponse``);
    confirm the path against your litellm version.
    """
    response = self._make_request(
        endpoint=f"{api_base}/embeddings",
        payload={
            "model": model,
            "input": input,
            **optional_params,
        },
        api_key=api_key,
    )

    # Mutate and return the response shell LiteLLM provided.
    model_response.data = response["embeddings"]
    return model_response
Complete Example
Full Custom Provider Implementation
Full Custom Provider Implementation
from litellm.llms.custom_llm import CustomLLM, CustomLLMError
from litellm.utils import ModelResponse
from litellm.types.utils import GenericStreamingChunk
from typing import Optional, Union, Iterator, AsyncIterator
import httpx
import json
class MyInternalLLM(CustomLLM):
    """Custom provider for a company-internal, OpenAI-compatible LLM service.

    Implements sync/async completion and streaming against a
    ``/chat/completions`` endpoint.
    """

    def __init__(self):
        super().__init__()
        # Fallback endpoint used when the caller does not supply api_base.
        self.base_url = "https://internal-llm.company.com/api/v1"

    def _make_request(self, endpoint: str, payload: dict, api_key: str, timeout=None):
        """Helper for HTTP requests; maps failures onto CustomLLMError."""
        try:
            response = httpx.post(
                endpoint,
                json=payload,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                timeout=timeout or 60.0,
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            # Preserve the upstream status code and body for the caller.
            raise CustomLLMError(
                status_code=e.response.status_code,
                message=f"API request failed: {e.response.text}",
            ) from e
        except Exception as e:
            raise CustomLLMError(
                status_code=500,
                message=f"Request error: {str(e)}",
            ) from e

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        **kwargs,
    ) -> ModelResponse:
        """Synchronous completion"""
        print_verbose(f"MyInternalLLM: completion call for {model}")

        # Make API request
        response = self._make_request(
            endpoint=f"{api_base or self.base_url}/chat/completions",
            payload={
                "model": model,
                "messages": messages,
                "temperature": optional_params.get("temperature", 0.7),
                "max_tokens": optional_params.get("max_tokens", 1000),
            },
            api_key=api_key,
            timeout=kwargs.get("timeout"),
        )

        # Populate response (mutate the shell LiteLLM handed us).
        model_response.choices[0].message.content = response["choices"][0]["message"]["content"]
        model_response.model = model

        # Set usage
        if "usage" in response:
            model_response.usage.prompt_tokens = response["usage"].get("prompt_tokens", 0)
            model_response.usage.completion_tokens = response["usage"].get("completion_tokens", 0)
            model_response.usage.total_tokens = response["usage"].get("total_tokens", 0)

        return model_response

    async def acompletion(self, **kwargs) -> ModelResponse:
        """Async completion"""
        # For simplicity, wrap the sync version. NOTE: this blocks the event
        # loop; in production, use an async HTTP client instead.
        return self.completion(**kwargs)

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        api_key: str,
        **kwargs,
    ) -> Iterator[GenericStreamingChunk]:
        """Streaming completion over server-sent events."""
        with httpx.stream(
            "POST",
            f"{api_base or self.base_url}/chat/completions",
            json={
                "model": model,
                "messages": messages,
                "stream": True,
            },
            headers={"Authorization": f"Bearer {api_key}"},
        ) as response:
            for line in response.iter_lines():
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                # OpenAI-style streams end with a non-JSON "[DONE]" sentinel;
                # feeding it to json.loads would raise.
                if data_str.strip() == "[DONE]":
                    break
                data = json.loads(data_str)
                if data.get("choices"):
                    delta = data["choices"][0].get("delta", {})
                    if "content" in delta:
                        yield GenericStreamingChunk(
                            text=delta["content"],
                            is_finished=data["choices"][0].get("finish_reason") is not None,
                        )

    async def astreaming(self, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Async streaming"""
        # Wrap sync streaming for simplicity; this also blocks the event loop.
        for chunk in self.streaming(**kwargs):
            yield chunk
# Register provider: the "provider" key becomes the model-name prefix.
import litellm

litellm.custom_provider_map = [
    {
        "provider": "my-internal-llm",
        "custom_handler": MyInternalLLM()
    }
]

# Use it: "my-internal-llm/<model>" routes through the handler above.
response = litellm.completion(
    model="my-internal-llm/gpt-custom",
    messages=[{"role": "user", "content": "Hello!"}],
    api_key="internal-key-123"
)

print(response.choices[0].message.content)
Error Handling
Always use `CustomLLMError` for exceptions:
from litellm.llms.custom_llm import CustomLLMError
class MyCustomLLM(CustomLLM):
    """Demonstrates mapping transport failures onto CustomLLMError."""

    def completion(self, **kwargs):
        """Run a completion, converting all failures into CustomLLMError.

        Chains the original exception (``from e``) so tracebacks keep the
        root cause.
        """
        try:
            response = self._make_api_call(**kwargs)
            return self._format_response(response)
        except httpx.HTTPStatusError as e:
            # Map HTTP errors, preserving the upstream status code.
            raise CustomLLMError(
                status_code=e.response.status_code,
                message=f"API error: {e.response.text}",
            ) from e
        except httpx.TimeoutException as e:
            # Handle timeouts with the conventional 408 status.
            raise CustomLLMError(
                status_code=408,
                message="Request timeout",
            ) from e
        except Exception as e:
            # Catch-all for unexpected errors.
            raise CustomLLMError(
                status_code=500,
                message=f"Internal error: {str(e)}",
            ) from e
Using with LiteLLM Proxy
Deploy custom providers via the proxy. Example `config.yaml`:
model_list:
  - model_name: my-custom-model              # name clients request through the proxy
    litellm_params:
      model: my-custom-llm/custom-model-v1   # <provider>/<model>: prefix routes to the custom handler
      api_key: os.environ/CUSTOM_API_KEY     # resolved from the environment at startup
      api_base: https://internal-api.company.com/v1
# In your proxy startup script
from my_custom_provider import MyCustomLLM
import litellm

# The "provider" key must match the prefix used in config.yaml ("my-custom-llm/...").
litellm.custom_provider_map = [
    {
        "provider": "my-custom-llm",
        "custom_handler": MyCustomLLM()
    }
]
Best Practices
Use HTTP Clients from LiteLLM
Use HTTP Clients from LiteLLM
Leverage LiteLLM’s HTTP handlers for connection pooling:
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
class MyCustomLLM(CustomLLM):
    """Uses LiteLLM's pooled HTTP handlers instead of ad-hoc httpx calls."""

    def __init__(self):
        super().__init__()
        # Reusable handlers provide connection pooling across requests.
        self.http_handler = HTTPHandler()
        self.async_handler = AsyncHTTPHandler()

    def completion(self, client=None, **kwargs):
        """Prefer the caller-supplied client; fall back to the pooled handler."""
        http_client = client or self.http_handler
        # Build the request from the kwargs LiteLLM passes through
        # (the original snippet referenced `url`/`payload` without defining them).
        url = f"{kwargs['api_base']}/chat/completions"
        payload = {"model": kwargs["model"], "messages": kwargs["messages"]}
        response = http_client.post(url, json=payload)
        return response.json()
Implement Proper Token Counting
Implement Proper Token Counting
Provide accurate usage statistics:
def completion(self, model, messages, **kwargs) -> ModelResponse:
    """Completion with explicit token accounting via the provided encoding."""
    # LiteLLM passes both the tokenizer and the response shell via kwargs.
    encoding = kwargs.get("encoding")
    # The original snippet referenced model_response without binding it.
    model_response = kwargs["model_response"]

    # Count prompt tokens
    # (assumes every message "content" is a plain string — TODO confirm for multimodal)
    prompt_text = " ".join(m["content"] for m in messages)
    prompt_tokens = len(encoding.encode(prompt_text))

    # Make request
    response = self._make_request(**kwargs)

    # Count completion tokens
    completion_text = response["text"]
    completion_tokens = len(encoding.encode(completion_text))

    # Set usage
    model_response.usage.prompt_tokens = prompt_tokens
    model_response.usage.completion_tokens = completion_tokens
    model_response.usage.total_tokens = prompt_tokens + completion_tokens

    return model_response
Handle Rate Limiting
Handle Rate Limiting
Implement exponential backoff:
import time


def _make_request_with_retry(self, endpoint, payload, api_key, max_retries=3):
    """Call _make_request, retrying HTTP 429 with exponential backoff.

    Makes up to ``max_retries`` attempts total; any non-429 error (or a 429
    on the final attempt) propagates to the caller.
    """
    for attempt in range(max_retries):
        try:
            return self._make_request(endpoint, payload, api_key)
        except CustomLLMError as e:
            if e.status_code == 429 and attempt < max_retries - 1:
                # Exponential backoff: 1s, 2s, 4s, ...
                wait_time = 2 ** attempt
                time.sleep(wait_time)
                continue
            raise
Support Function Calling
Support Function Calling
Transform tool calls appropriately:
def completion(self, messages, optional_params, **kwargs):
    """Completion with OpenAI-style tool/function calling support."""
    # The original snippet referenced model_response without binding it.
    model_response = kwargs["model_response"]

    # Extract tools from optional_params
    tools = optional_params.get("tools", [])

    # Transform to your API format
    api_tools = self._transform_tools(tools)

    # Include in request
    response = self._make_request(
        payload={
            "messages": messages,
            "tools": api_tools,
        }
    )

    # Transform tool calls back to OpenAI format
    if response.get("tool_calls"):
        model_response.choices[0].message.tool_calls = self._transform_tool_calls(
            response["tool_calls"]
        )

    return model_response
Testing Your Provider
import pytest
from litellm import completion
def test_custom_provider_completion():
    """Sync completion returns content, echoes the model name, and reports usage."""
    model_name = "my-custom-llm/test-model"
    response = completion(
        model=model_name,
        messages=[{"role": "user", "content": "Hello!"}],
        api_key="test-key",
    )
    assert response.choices[0].message.content
    assert response.model == model_name
    assert response.usage.total_tokens > 0
@pytest.mark.asyncio
async def test_custom_provider_async():
    """Async completions must go through acompletion(); litellm.completion is sync
    and awaiting its return value raises TypeError."""
    from litellm import acompletion

    response = await acompletion(
        model="my-custom-llm/test-model",
        messages=[{"role": "user", "content": "Hello!"}],
        api_key="test-key",
    )
    assert response.choices[0].message.content
def test_custom_provider_streaming():
    """Streaming yields at least one chunk carrying delta content."""
    stream = completion(
        model="my-custom-llm/test-model",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
        api_key="test-key",
    )
    collected = [c for c in stream]
    assert len(collected) > 0
    assert any(c.choices[0].delta.content for c in collected)
Reference
Source Files
- Base class: litellm/llms/custom_llm.py:47
- HTTP handlers: litellm/llms/custom_httpx/http_handler.py
- Error handling: litellm/llms/custom_llm.py:34