Skip to main content

LangChain Integration

Integrate Koreshield security directly into your LangChain pipelines for comprehensive LLM protection.

Installation

pip install Koreshield-sdk langchain langchain-openai

Basic Integration

Callback Handler

Create a callback handler that scans every prompt and chain input with Koreshield and raises `ValueError` when a threat is detected:
from langchain.callbacks.base import BaseCallbackHandler
from Koreshield_sdk import KoreshieldClient

class KoreshieldCallback(BaseCallbackHandler):
    """LangChain callback that scans prompts and chain inputs with Koreshield.

    A detected threat raises ``ValueError`` from inside the callback. LangChain's
    callback manager only propagates handler exceptions when the handler sets
    ``raise_error = True``; otherwise the error is logged and the run continues,
    so the scan would never actually block anything.

    Args:
        api_key: Koreshield API key.
        sensitivity: Sensitivity level forwarded to every ``scan`` call.
    """

    # Re-raise exceptions from this handler instead of logging and continuing.
    raise_error = True

    def __init__(self, api_key: str, sensitivity: str = "medium"):
        self.client = KoreshieldClient(api_key=api_key)
        self.sensitivity = sensitivity

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs):
        """Scan prompts before sending to LLM.

        Raises:
            ValueError: if any prompt is flagged as a threat.
        """
        for prompt in prompts:
            result = self.client.scan(
                input=prompt,
                sensitivity=self.sensitivity
            )

            if result.is_threat:
                # Defensive: don't assume the SDK always populates attack_types.
                attack = result.attack_types[0] if result.attack_types else "unknown"
                raise ValueError(
                    f"Security threat detected: {attack} "
                    f"(confidence: {result.confidence:.2f})"
                )

    def on_chain_start(self, serialized: dict, inputs: dict, **kwargs):
        """Scan every string value in the chain inputs.

        Raises:
            ValueError: if any string input is flagged as a threat.
        """
        # Runnables may start with a bare (non-dict) input; with raise_error=True
        # an AttributeError from .items() would otherwise abort the run.
        items = inputs.items() if isinstance(inputs, dict) else [("input", inputs)]
        for key, value in items:
            if isinstance(value, str):
                # Use the same sensitivity as on_llm_start for consistent policy.
                result = self.client.scan(input=value, sensitivity=self.sensitivity)
                if result.is_threat:
                    raise ValueError(f"Threat in {key}: {result.attack_types}")

# Usage: attach the handler at construction time so every call is scanned
from langchain_openai import ChatOpenAI

Koreshield_callback = KoreshieldCallback(api_key="ks_prod_xxx")
llm = ChatOpenAI(callbacks=[Koreshield_callback], model="gpt-4")

# The handler's on_llm_start scans this prompt before the request is sent
response = llm.invoke("What is the capital of France?")

Chain Protection

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from Koreshield_sdk import KoreshieldClient

Koreshield = KoreshieldClient(api_key="ks_prod_xxx")

# Built once: the template is identical on every call.
_QA_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="Answer this question: {question}",
)


def secure_chain(user_input: str):
    """Scan *user_input* with Koreshield, then answer it with an LLMChain.

    Returns:
        An error dict (``error``, ``attack_type``, ``confidence``) when the scan
        flags a threat; otherwise the result of ``LLMChain.invoke``. Uses the
        module-level ``llm`` defined in an earlier snippet.
    """
    # Scan input first
    scan_result = Koreshield.scan(user_input)

    if scan_result.is_threat:
        attack_types = scan_result.attack_types
        return {
            "error": "Security violation detected",
            # Defensive: don't assume the SDK always populates attack_types.
            "attack_type": attack_types[0] if attack_types else "unknown",
            "confidence": scan_result.confidence,
        }

    # Safe to process
    chain = LLMChain(llm=llm, prompt=_QA_PROMPT)
    return chain.invoke({"question": user_input})


# Use secure chain
result = secure_chain("How does photosynthesis work?")

Agent Protection

Secure Agent

from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain.tools import Tool
from langchain import hub

# Create tools
def search_tool(query: str) -> str:
    """Return placeholder search results for *query* after a Koreshield scan.

    A flagged query is not searched. Instead a short violation message is
    returned as the tool output, so the agent receives a refusal rather than
    an exception.
    """
    # Scan tool input
    result = Koreshield.scan(query)
    if result.is_threat:
        # Defensive: don't assume the SDK always populates attack_types.
        attack = result.attack_types[0] if result.attack_types else "unknown"
        return f"Security violation: {attack}"
    return f"Search results for: {query}"

tools = [Tool(name="Search", description="Search for information", func=search_tool)]

# Create secure agent from the published OpenAI-functions prompt
prompt = hub.pull("hwchase17/openai-functions-agent")
agent = create_openai_functions_agent(llm=llm, prompt=prompt, tools=tools)

# Koreshield_callback is attached to the executor, so its on_chain_start hook
# scans the executor's string inputs (here, the "input" value).
agent_executor = AgentExecutor(
    agent=agent,
    callbacks=[Koreshield_callback],
    tools=tools,
    verbose=True,
)

# Execute with automatic security
agent_input = {"input": "Search for Python tutorials"}
result = agent_executor.invoke(agent_input)

Custom Security Layer

Wrapper Class

from typing import Any, List, Optional
from langchain.llms.base import LLM
from Koreshield_sdk import KoreshieldClient

class SecureLLM(LLM):
    """LLM wrapper with Koreshield protection.

    Every prompt is scanned before it is forwarded to the wrapped ``llm``.
    When ``block_on_threat`` is true a flagged prompt raises ``ValueError``.
    Otherwise a warning is printed and the prompt is forwarded anyway.
    """

    # Typed as Any rather than LLM: completion models such as
    # langchain_openai.OpenAI subclass BaseLLM, not LLM, and would fail
    # pydantic validation against an ``LLM`` annotation.
    # Assumes a completion-style model whose invoke() returns str.
    llm: Any
    Koreshield: KoreshieldClient
    sensitivity: str = "medium"
    block_on_threat: bool = True

    def __init__(self, llm: Any, api_key: str, **kwargs: Any):
        # Required pydantic fields must be passed to __init__. Assigning them
        # afterwards is too late: validation fails on the missing fields first.
        super().__init__(
            llm=llm,
            Koreshield=KoreshieldClient(api_key=api_key),
            **kwargs,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any
    ) -> str:
        """Scan *prompt*, then return the wrapped model's completion.

        Raises:
            ValueError: if the prompt is flagged and ``block_on_threat`` is true.
        """
        # Scan input
        scan_result = self.Koreshield.scan(
            input=prompt,
            sensitivity=self.sensitivity
        )

        if scan_result.is_threat:
            if self.block_on_threat:
                # Defensive: don't assume the SDK always populates attack_types.
                attack_types = scan_result.attack_types
                attack = attack_types[0] if attack_types else "unknown"
                raise ValueError(
                    f"Blocked: {attack} "
                    f"(confidence: {scan_result.confidence:.2%})"
                )
            # Log but allow
            print(f"Warning: Potential threat detected - {scan_result.attack_types}")

        # Use the public Runnable API. BaseLLM implementations (e.g. OpenAI)
        # implement _generate, not _call, so self.llm._call is not guaranteed.
        return self.llm.invoke(prompt, stop=stop, **kwargs)

    @property
    def _llm_type(self) -> str:
        return f"secure_{self.llm._llm_type}"

# Usage: wrap a plain completion model so every prompt is scanned first
from langchain_openai import OpenAI

base_llm = OpenAI(temperature=0.7)
secure_llm = SecureLLM(api_key="ks_prod_xxx", llm=base_llm, sensitivity="high")

# With the default block_on_threat=True a flagged prompt raises ValueError
response = secure_llm.invoke("Tell me about AI safety")

Multi-Tenancy Support

class TenantAwareLLM(LLM):
    """LLM with per-tenant security policies.

    ``tenant_id`` (passed as an ``invoke`` keyword, default "free") selects a
    policy from ``tenant_policies``. Unknown ids fall back to the "free" policy.
    The policy's ``sensitivity`` is used for the Koreshield scan. A flagged
    prompt raises ``ValueError``.
    """

    # Declared as pydantic fields: LLM is a pydantic model, and assigning
    # undeclared attributes (e.g. self.llm = ...) in __init__ raises.
    # llm presumably is a completion-style model whose invoke() returns str.
    llm: Any
    Koreshield: KoreshieldClient
    tenant_policies: dict[str, dict[str, Any]]

    def __init__(self, llm: Any, api_key: str, **kwargs: Any):
        # Fresh dict per instance. Callers may override via tenant_policies=...
        kwargs.setdefault(
            "tenant_policies",
            {
                "free": {"sensitivity": "high", "max_requests": 100},
                "pro": {"sensitivity": "medium", "max_requests": 10000},
                "enterprise": {"sensitivity": "low", "max_requests": -1},
            },
        )
        super().__init__(
            llm=llm,
            Koreshield=KoreshieldClient(api_key=api_key),
            **kwargs,
        )

    def _call(self, prompt: str, tenant_id: str = "free", **kwargs: Any) -> str:
        """Scan *prompt* under the tenant's policy, then delegate to ``self.llm``.

        Raises:
            ValueError: if the prompt is flagged as a threat.
        """
        policy = self.tenant_policies.get(tenant_id, self.tenant_policies["free"])
        # NOTE: "max_requests" is part of each policy but is not enforced here.

        # Apply tenant-specific security
        scan_result = self.Koreshield.scan(
            input=prompt,
            sensitivity=policy["sensitivity"],
            metadata={"tenant_id": tenant_id},
        )

        if scan_result.is_threat:
            raise ValueError(f"Threat detected for tenant {tenant_id}")

        # Public API instead of _call, which BaseLLM implementations lack.
        # kwargs carries LangChain's stop list and any extra invoke() options.
        return self.llm.invoke(prompt, **kwargs)

    @property
    def _llm_type(self) -> str:
        # Required: _llm_type is abstract on LLM, so the class cannot be
        # instantiated without it.
        return f"tenant_aware_{self.llm._llm_type}"

LangChain Expression Language (LCEL)

from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

Koreshield = KoreshieldClient(api_key="ks_prod_xxx")


# Security check runnable
def security_check(input_dict):
    """Return *input_dict* unchanged if its "question" value passes the scan.

    Meant as the first step of an LCEL chain: raising here stops the sequence
    before later steps run. A missing "question" key is scanned as "".

    Raises:
        ValueError: if the question is flagged as a threat.
    """
    message = input_dict.get("question", "")
    result = Koreshield.scan(message)

    if result.is_threat:
        # Defensive: don't assume the SDK always populates attack_types.
        attack = result.attack_types[0] if result.attack_types else "unknown"
        raise ValueError(f"Security threat: {attack}")

    return input_dict

# Build secure chain with LCEL: the guard runs first, so a flagged question
# never reaches the prompt template or the model.
prompt = ChatPromptTemplate.from_template("Answer: {question}")
model = ChatOpenAI(model="gpt-4")
secure_chain = RunnableLambda(security_check) | prompt | model

# Use chain
response = secure_chain.invoke({"question": "What is machine learning?"})

Async Support

import asyncio
from Koreshield_sdk import AsyncKoreshieldClient

class AsyncSecureCallback(BaseCallbackHandler):
    """Async callback that scans each prompt with ``AsyncKoreshieldClient``.

    A detected threat raises ``ValueError`` from ``on_llm_start``.
    """

    # LangChain only re-raises handler exceptions when raise_error is true.
    # Otherwise the ValueError is logged and the LLM call goes ahead.
    raise_error = True

    def __init__(self, api_key: str):
        self.client = AsyncKoreshieldClient(api_key=api_key)

    async def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs):
        """Await a scan of every prompt; raise on the first flagged one."""
        for prompt in prompts:
            result = await self.client.scan(prompt)
            if result.is_threat:
                raise ValueError(f"Threat: {result.attack_types}")

# Async usage
async def async_query(question: str):
    """Send *question* to ChatOpenAI asynchronously with prompt scanning attached."""
    guarded_llm = ChatOpenAI(callbacks=[AsyncSecureCallback(api_key="ks_prod_xxx")])
    return await guarded_llm.ainvoke(question)


# Run
result = asyncio.run(async_query("How does encryption work?"))

Streaming with Protection

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

class SecureStreamingCallback(StreamingStdOutCallbackHandler):
    """Stream tokens to stdout while periodically scanning the generated text.

    Once every ``window`` new characters, the text added since the last scan,
    plus one ``window`` of overlap, is sent to Koreshield. The overlap means a
    phrase split across a boundary is still seen whole. A flagged scan raises
    ``ValueError`` before the current token is written to stdout. The unscanned
    tail is checked when the run ends, and state is reset between runs.

    Args:
        api_key: Koreshield API key.
        window: Number of new characters between scans (default 100).
    """

    # Without this, LangChain logs the ValueError and keeps streaming.
    raise_error = True

    def __init__(self, api_key: str, window: int = 100):
        super().__init__()
        self.Koreshield = KoreshieldClient(api_key=api_key)
        self.window = window
        self.buffer = ""
        # Length of self.buffer at the most recent scan.
        self._scanned_len = 0

    def _scan_pending(self) -> None:
        """Scan text added since the last scan, with one window of overlap."""
        start = max(0, self._scanned_len - self.window)
        result = self.Koreshield.scan(self.buffer[start:])
        self._scanned_len = len(self.buffer)
        if result.is_threat:
            raise ValueError("Threat detected in output")

    def _reset(self) -> None:
        """Forget the previous run's output."""
        self.buffer = ""
        self._scanned_len = 0

    def on_llm_new_token(self, token: str, **kwargs):
        # Accumulate tokens; scan once per window instead of on every token.
        self.buffer += token
        if len(self.buffer) - self._scanned_len >= self.window:
            self._scan_pending()

        super().on_llm_new_token(token, **kwargs)

    def on_llm_end(self, response, **kwargs):
        try:
            # Scan the tail that never filled a whole window (this includes
            # outputs shorter than one window). These tokens are already printed.
            if len(self.buffer) > self._scanned_len:
                self._scan_pending()
        finally:
            self._reset()
        super().on_llm_end(response, **kwargs)

    def on_llm_error(self, error, **kwargs):
        self._reset()
        super().on_llm_error(error, **kwargs)

# Use streaming with security: tokens are echoed to stdout as they arrive
# while SecureStreamingCallback inspects the accumulated output.
streaming_guard = SecureStreamingCallback(api_key="ks_prod_xxx")
llm = ChatOpenAI(callbacks=[streaming_guard], streaming=True)

response = llm.invoke("Explain quantum computing")

Memory Protection

from langchain.memory import ConversationBufferMemory

class SecureMemory(ConversationBufferMemory):
    """ConversationBufferMemory that refuses to store turns with flagged input.

    If any string value in ``inputs`` is flagged by Koreshield, the whole turn
    (inputs and outputs) is silently skipped. Outputs are not scanned.
    """

    # Declared as a model field: ConversationBufferMemory is a pydantic model,
    # and pydantic rejects assignment to undeclared attributes.
    Koreshield: Any = None

    def __init__(self, api_key: str, **kwargs):
        super().__init__(Koreshield=KoreshieldClient(api_key=api_key), **kwargs)

    def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
        # Scan before saving to memory. any() stops at the first flagged value.
        flagged = any(
            self.Koreshield.scan(value).is_threat
            for value in inputs.values()
            if isinstance(value, str)
        )
        if flagged:
            # Don't save malicious content to memory
            return

        super().save_context(inputs, outputs)

# Use secure memory
memory = SecureMemory(api_key="ks_prod_xxx")

# The prompt must expose the memory's "history" variable (ConversationBufferMemory's
# default memory_key). Otherwise the stored turns are never shown to the model.
memory_prompt = PromptTemplate(
    input_variables=["history", "question"],
    template="Conversation so far:\n{history}\n\nQuestion: {question}",
)

chain = LLMChain(
    llm=llm,
    prompt=memory_prompt,
    memory=memory,
)

Testing

import pytest
from unittest.mock import Mock, patch

@pytest.fixture
def mock_Koreshield():
    """Replace KoreshieldClient with a mock that reports every input as safe.

    The patch targets the name in the module where ``KoreshieldCallback`` looks
    it up. ``from Koreshield_sdk import KoreshieldClient`` binds a separate
    reference there, so patching ``Koreshield_sdk.KoreshieldClient`` would leave
    the callback using the real client.

    Yields the mock client instance so tests can change ``scan.return_value``.
    """
    target = f"{KoreshieldCallback.__module__}.KoreshieldClient"
    with patch(target) as mock_client_cls:
        client = mock_client_cls.return_value
        client.scan.return_value = Mock(
            is_threat=False,
            confidence=0.1,
            attack_types=[],
        )
        yield client

def test_secure_chain(mock_Koreshield):
    """A safe prompt passes the callback and reaches the model."""
    # Offline fake chat model: no OPENAI_API_KEY and no network call needed.
    from langchain_core.language_models.fake_chat_models import FakeListChatModel

    callback = KoreshieldCallback(api_key="test_key")
    llm = FakeListChatModel(responses=["Paris"], callbacks=[callback])

    response = llm.invoke("Safe question")
    assert response is not None
    assert response.content == "Paris"

def test_threat_detection(mock_Koreshield):
    """A flagged prompt makes the callback raise before the model answers."""
    # Offline fake chat model: no OPENAI_API_KEY and no network call needed.
    from langchain_core.language_models.fake_chat_models import FakeListChatModel

    mock_Koreshield.scan.return_value = Mock(
        is_threat=True,
        confidence=0.95,
        attack_types=["prompt_injection"],
    )

    callback = KoreshieldCallback(api_key="test_key")
    llm = FakeListChatModel(responses=["should not be returned"], callbacks=[callback])

    with pytest.raises(ValueError, match="Security threat"):
        llm.invoke("Malicious input")

Best Practices

# Scan at every boundary where untrusted text enters or leaves the pipeline.
# NOTE: user_input, formatted_prompt, tool_args and llm_response are placeholders
# for values from your own application; they are not defined in these snippets.

# 1. Input validation: the raw text received from the user
user_input_scan = Koreshield.scan(user_input)

# 2. Chain input validation: the formatted prompt sent to the model
chain_input_scan = Koreshield.scan(formatted_prompt)

# 3. Tool input validation: arguments passed to a tool
tool_input_scan = Koreshield.scan(tool_args)

# 4. Optional: Output validation of the model's response
output_scan = Koreshield.scan(llm_response)
def safe_invoke(chain, input_data, *, markers=("Security threat", "Threat", "Blocked:")):
    """Invoke *chain*, converting Koreshield blocks into a user-facing error dict.

    The guards on this page signal a block by raising ``ValueError`` with
    messages such as "Security threat: ...", "Threat in ...", "Threat: ...",
    "Threat detected ..." or "Blocked: ...". A ``ValueError`` whose message
    contains any of *markers* is treated as a block. Every other ``ValueError``
    is re-raised unchanged.

    Args:
        chain: Object with an ``invoke(input_data)`` method.
        input_data: Passed straight to ``chain.invoke``.
        markers: Case-sensitive substrings that identify a security block.

    Returns:
        The chain's result, or an error dict if the call was blocked.
    """
    try:
        return chain.invoke(input_data)
    except ValueError as e:
        message = str(e)
        if any(marker in message for marker in markers):
            return {
                "error": "Your request was blocked for security reasons",
                # NOTE(review): this address looks mangled by an email-obfuscation
                # step during publishing; replace it with your real support contact.
                "support": "[email protected]",
            }
        raise

Python SDK

Complete Python SDK documentation

RAG Security

Secure your RAG pipelines

LangChain Docs

Official LangChain documentation

Support

Build docs developers (and LLMs) love