Overview

Custom nodes allow you to extend ScrapeGraphAI with specialized functionality tailored to your scraping needs. All nodes inherit from the BaseNode abstract base class.

BaseNode Architecture

The BaseNode class provides the foundation for all nodes in ScrapeGraphAI:
~/workspace/source/scrapegraphai/nodes/base_node.py
from abc import ABC, abstractmethod
from typing import List, Optional

from ..utils import get_logger  # logging helper from scrapegraphai.utils

class BaseNode(ABC):
    def __init__(
        self,
        node_name: str,
        node_type: str,
        input: str,
        output: List[str],
        min_input_len: int = 1,
        node_config: Optional[dict] = None,
    ):
        self.node_name = node_name
        self.input = input
        self.output = output
        self.min_input_len = min_input_len
        self.node_config = node_config
        self.logger = get_logger()

        if node_type not in ["node", "conditional_node"]:
            raise ValueError(
                f"node_type must be 'node' or 'conditional_node', got '{node_type}'"
            )
        self.node_type = node_type

    @abstractmethod
    def execute(self, state: dict) -> dict:
        """Execute the node's logic and return updated state."""
        pass

Key Attributes

  • node_name: Unique identifier for the node
  • node_type: Either "node" or "conditional_node"
  • input: Boolean expression defining required state keys
  • output: List of state keys this node will update
  • min_input_len: Minimum number of input keys required
  • node_config: Dictionary of additional configuration
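To see the `node_type` check in action without subclassing the real (abstract) `BaseNode`, here is a minimal stand-in that mirrors its validation; the class is purely illustrative:

```python
class MiniNode:
    """Stripped-down stand-in mirroring BaseNode's node_type validation."""

    def __init__(self, node_name, node_type, input, output):
        if node_type not in ("node", "conditional_node"):
            raise ValueError(
                f"node_type must be 'node' or 'conditional_node', got '{node_type}'"
            )
        self.node_name = node_name
        self.node_type = node_type
        self.input = input
        self.output = output

MiniNode("Fetch", "node", "url", ["doc"])      # accepted
try:
    MiniNode("Bad", "router", "url", ["doc"])  # rejected
except ValueError as e:
    print(e)
```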

Creating a Simple Node

Here’s how to create a basic custom node:
1. Import Required Classes

from typing import List, Optional
from scrapegraphai.nodes import BaseNode
2. Define Your Node Class

class CustomTextCleanerNode(BaseNode):
    """
    A node that cleans and normalizes text content.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "CustomTextCleaner",
    ):
        super().__init__(node_name, "node", input, output, 1, node_config)

        # Custom configuration (defaults apply when node_config is None)
        node_config = node_config or {}
        self.remove_html = node_config.get("remove_html", True)
        self.lowercase = node_config.get("lowercase", False)
        self.verbose = node_config.get("verbose", False)

3. Implement the execute() Method

    def execute(self, state: dict) -> dict:
        """
        Executes the text cleaning logic.

        Args:
            state: Current graph state

        Returns:
            Updated state with cleaned text
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        # Get input keys from state
        input_keys = self.get_input_keys(state)
        input_data = [state[key] for key in input_keys]

        text = input_data[0]

        # Clean the text
        if self.remove_html:
            import re
            text = re.sub(r'<[^>]+>', '', text)

        if self.lowercase:
            text = text.lower()

        # Remove extra whitespace
        text = ' '.join(text.split())

        if self.verbose:
            self.logger.info(f"Cleaned text length: {len(text)} characters")

        # Update state with output
        state.update({self.output[0]: text})
        return state
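
The cleaning steps above are easy to exercise in isolation; this standalone sketch repeats the same logic (note the regex tag stripping is naive and suited only to simple markup):

```python
import re

def clean(text: str, remove_html: bool = True, lowercase: bool = False) -> str:
    """Same cleaning steps as CustomTextCleanerNode.execute, in isolation."""
    if remove_html:
        text = re.sub(r'<[^>]+>', '', text)  # naive tag stripping
    if lowercase:
        text = text.lower()
    return ' '.join(text.split())            # collapse runs of whitespace

print(clean("<p>Hello   <b>WORLD</b>!</p>", lowercase=True))  # hello world!
```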

Using Your Custom Node

from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode

# Fetch node that loads the page content into state under "doc"
fetch_node = FetchNode(
    input="url | local_dir",
    output=["doc"],
)

# Create custom node instance
cleaner_node = CustomTextCleanerNode(
    input="doc",
    output=["cleaned_text"],
    node_config={
        "remove_html": True,
        "lowercase": True,
        "verbose": True,
    },
)

# Use in graph
graph = BaseGraph(
    nodes=[fetch_node, cleaner_node],
    edges=[(fetch_node, cleaner_node)],
    entry_point=fetch_node,
)

Advanced Example: Conditional Node

Conditional nodes determine the next node based on logic:
class ContentTypeRouter(BaseNode):
    """
    Routes execution based on content type detection.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "ContentRouter",
    ):
        super().__init__(node_name, "conditional_node", input, output, 1, node_config)

        # These will be set by BaseGraph
        self.true_node_name = None
        self.false_node_name = None

    def execute(self, state: dict) -> str:
        """
        Returns the name of the next node to execute.

        Returns:
            Node name to execute next
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        input_keys = self.get_input_keys(state)
        content = state[input_keys[0]]

        # Detect content type
        if "<table" in content.lower():
            self.logger.info("Detected table content")
            return self.true_node_name  # Route to table parser
        else:
            self.logger.info("Detected regular content")
            return self.false_node_name  # Route to regular parser

Conditional nodes must return a node name (a string) instead of an updated state, and they require exactly two outgoing edges in the graph.
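
The routing decision itself can be tested standalone; this sketch uses hypothetical node names for the two branches:

```python
def route(content: str, true_node: str = "table_parser", false_node: str = "text_parser") -> str:
    """Standalone version of ContentTypeRouter's decision (node names are hypothetical)."""
    return true_node if "<table" in content.lower() else false_node

print(route("<TABLE><tr><td>1</td></tr></TABLE>"))  # table_parser
print(route("<p>plain paragraph</p>"))              # text_parser
```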

Real-World Example: API Integration Node

Here’s a practical example that enriches scraped data with external API calls:
import requests
from typing import List, Optional
from scrapegraphai.nodes import BaseNode

class APIEnrichmentNode(BaseNode):
    """
    Enriches scraped data with information from an external API.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "APIEnrichment",
    ):
        super().__init__(node_name, "node", input, output, 1, node_config)

        # API configuration (guard against node_config being None)
        node_config = node_config or {}
        self.api_url = node_config.get("api_url")
        self.api_key = node_config.get("api_key")
        self.timeout = node_config.get("timeout", 10)
        self.verbose = node_config.get("verbose", False)

        if not self.api_url:
            raise ValueError("api_url is required in node_config")

    def execute(self, state: dict) -> dict:
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        input_keys = self.get_input_keys(state)
        data = state[input_keys[0]]

        enriched_data = []

        # Process each item
        for item in data:
            try:
                # Call external API
                headers = {"Authorization": f"Bearer {self.api_key}"}
                response = requests.get(
                    self.api_url,
                    params={"query": item.get("name")},
                    headers=headers,
                    timeout=self.timeout
                )
                response.raise_for_status()

                # Merge API data with scraped data
                api_data = response.json()
                enriched_item = {**item, **api_data}
                enriched_data.append(enriched_item)

                if self.verbose:
                    self.logger.info(f"Enriched item: {item.get('name')}")

            except requests.RequestException as e:
                self.logger.warning(f"API call failed: {e}")
                enriched_data.append(item)  # Keep original data

        # Update state
        state.update({self.output[0]: enriched_data})
        return state
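
Flaky external APIs often warrant retries with backoff. A generic helper like the following (an illustration, not part of ScrapeGraphAI) could wrap the `requests.get` call above:

```python
import time

def with_retries(fn, retries=3, backoff=0.5, exceptions=(Exception,)):
    """Call fn(), retrying up to `retries` times with exponential backoff."""
    for attempt in range(retries):
        try:
            return fn()
        except exceptions:
            if attempt == retries - 1:
                raise                           # out of attempts: surface the error
            time.sleep(backoff * (2 ** attempt))

# Example: an operation that fails twice, then succeeds
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(with_retries(flaky, backoff=0.01))  # ok
```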

Input Expression Parsing

The get_input_keys() method parses input expressions:
# Simple input
input = "url"  # Requires 'url' in state

# OR logic
input = "url | local_dir"  # Accepts either

# AND logic
input = "user_prompt & parsed_doc"  # Requires both

# Complex expressions
input = "user_prompt & (relevant_chunks | parsed_doc | doc)"
# Requires user_prompt AND at least one of the others

# In your execute method:
input_keys = self.get_input_keys(state)
# Returns: ['user_prompt', 'relevant_chunks'] (for example)
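
As an illustration of how such expressions could be evaluated (a simplified stand-in, not ScrapeGraphAI's actual parser), each identifier can be mapped to a boolean and the expression evaluated:

```python
import re

def matches(expression: str, state: dict) -> bool:
    """Check whether the available state keys satisfy an input expression.
    Simplified illustration only; assumes keys are plain identifiers."""
    expr = re.sub(r'\w+', lambda m: str(m.group(0) in state), expression)
    expr = expr.replace('&', ' and ').replace('|', ' or ')
    return eval(expr)  # safe here: expr contains only True/False/and/or/parens

state = {"user_prompt": "...", "parsed_doc": "..."}
print(matches("user_prompt & (relevant_chunks | parsed_doc | doc)", state))  # True
print(matches("url | local_dir", state))                                     # False
```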

Node Configuration Patterns

Pattern 1: Model Integration

class LLMProcessingNode(BaseNode):
    def __init__(self, input, output, node_config=None, node_name="LLMProcessor"):
        super().__init__(node_name, "node", input, output, 1, node_config)

        self.llm_model = node_config.get("llm_model")
        self.temperature = node_config.get("temperature", 0.7)
        self.max_tokens = node_config.get("max_tokens", 1000)

    def execute(self, state: dict) -> dict:
        # Use self.llm_model for processing
        pass

Pattern 2: Multiple Outputs

class DataSplitterNode(BaseNode):
    def __init__(self, input, output, node_config=None, node_name="DataSplitter"):
        super().__init__(node_name, "node", input, output, 1, node_config)

    def execute(self, state: dict) -> dict:
        input_keys = self.get_input_keys(state)
        data = state[input_keys[0]]

        # Split into multiple outputs
        tables = self._extract_tables(data)
        text = self._extract_text(data)
        images = self._extract_images(data)

        # Update multiple output keys
        state.update({
            self.output[0]: tables,
            self.output[1]: text,
            self.output[2]: images,
        })
        return state
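
To keep values aligned with `self.output` when updating several keys at once, zipping the two sequences is a compact option (a small helper sketch, not part of the library):

```python
def split_outputs(output_keys, parts):
    """Pair each output key with its value, ready for state.update()."""
    if len(output_keys) != len(parts):
        raise ValueError("number of outputs must match number of parts")
    return dict(zip(output_keys, parts))

state = {}
state.update(split_outputs(["tables", "text", "images"], ([], "body", [])))
print(state)  # {'tables': [], 'text': 'body', 'images': []}
```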

Pattern 3: Stateful Processing

class CacheNode(BaseNode):
    def __init__(self, input, output, node_config=None, node_name="Cache"):
        super().__init__(node_name, "node", input, output, 1, node_config)
        self._cache = {}  # Instance variable

    def execute(self, state: dict) -> dict:
        input_keys = self.get_input_keys(state)
        key = state[input_keys[0]]

        # Check cache
        if key in self._cache:
            result = self._cache[key]
            self.logger.info("Cache hit")
        else:
            result = self._expensive_operation(key)
            self._cache[key] = result
            self.logger.info("Cache miss")

        state.update({self.output[0]: result})
        return state
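
The payoff of this pattern is that the expensive operation runs once per key for the lifetime of the node instance (the cache does not persist across processes). A self-contained illustration:

```python
class MemoCache:
    """Minimal stand-in for the cache-on-the-instance pattern above."""

    def __init__(self, compute):
        self._cache = {}
        self._compute = compute
        self.misses = 0

    def get(self, key):
        if key not in self._cache:
            self.misses += 1                     # cache miss: run the expensive operation
            self._cache[key] = self._compute(key)
        return self._cache[key]

cache = MemoCache(lambda k: k.upper())
cache.get("url"); cache.get("url"); cache.get("url")
print(cache.misses)  # 1 -- later lookups were served from the cache
```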

Testing Custom Nodes

Always test your custom nodes in isolation:
import unittest

class TestCustomNode(unittest.TestCase):
    def test_text_cleaner(self):
        node = CustomTextCleanerNode(
            input="doc",
            output=["cleaned_text"],
            node_config={"lowercase": True}
        )

        state = {"doc": "<p>Hello WORLD!</p>"}
        result = node.execute(state)

        self.assertEqual(result["cleaned_text"], "hello world!")

    def test_missing_input(self):
        node = CustomTextCleanerNode(
            input="doc",
            output=["cleaned_text"]
        )

        state = {}  # Missing 'doc'
        with self.assertRaises(ValueError):
            node.execute(state)

Built-in Node Reference

Learn from existing nodes in the codebase:
  • FetchNode: Downloads web pages or loads files (~/workspace/source/scrapegraphai/nodes/fetch_node.py:1)
  • ParseNode: Parses HTML and splits into chunks (~/workspace/source/scrapegraphai/nodes/parse_node.py:1)
  • RAGNode: Stores documents in vector database (~/workspace/source/scrapegraphai/nodes/rag_node.py:1)
  • GenerateAnswerNode: Generates LLM responses
  • ConditionalNode: Routes based on conditions
  • DescriptionNode: Generates content descriptions
  • RobotsNode: Checks robots.txt compliance

Best Practices

  1. Always call super().__init__(): Ensures proper BaseNode initialization
  2. Use self.logger: Leverage the built-in logger for debugging
  3. Validate node_config: Check for required configuration keys
  4. Handle errors gracefully: Use try-except blocks and meaningful error messages
  5. Document input/output: Clearly specify expected state keys
  6. Set min_input_len: Define minimum required inputs
  7. Test thoroughly: Write unit tests for edge cases
