Defensive Programming: Preventing Errors
Welcome to Defensive Programming! Instead of just handling errors when they occur, let’s learn to prevent them in the first place. Think of this as building a fortress around your code to keep errors out.
What is Defensive Programming?
Defensive programming is writing code that anticipates and prevents potential errors before they happen. It’s like being a chess grandmaster who thinks 10 moves ahead.
Offensive vs Defensive Programming
# Offensive Programming - assumes everything works
def calculate_average(numbers):
    """Return the arithmetic mean of ``numbers`` without any guarding.

    This is the deliberately "offensive" variant: an empty sequence makes
    the division raise ZeroDivisionError.
    """
    total = sum(numbers)
    count = len(numbers)
    return total / count
# What if numbers is empty? CRASH!
calculate_average([]) # ZeroDivisionError
# Defensive Programming - anticipates problems
def calculate_average_safe(numbers):
    """Return the mean of ``numbers``, or 0 when ``numbers`` is empty."""
    if numbers:
        return sum(numbers) / len(numbers)
    return 0  # Or raise ValueError, or return None
calculate_average_safe([]) # Returns 0 safely
Input Validation
Type Checking
def add_numbers(a, b):
    """Return ``a + b`` after checking that both operands are int or float.

    Raises:
        TypeError: if either operand is not an int or float; the message
            names which argument was rejected and its type.
    """
    for position, operand in (("First", a), ("Second", b)):
        if not isinstance(operand, (int, float)):
            raise TypeError(
                f"{position} argument must be a number, got {type(operand)}"
            )
    return a + b
# Test
try:
result = add_numbers(5, "3") # TypeError
except TypeError as e:
print(f"Error: {e}")
Value Range Validation
def set_age(age):
    """Validate and return an age in whole years.

    Args:
        age: candidate age; must be an ``int`` between 0 and 150 inclusive.

    Returns:
        The validated age, unchanged.

    Raises:
        TypeError: if ``age`` is not an integer (booleans are rejected too).
        ValueError: if ``age`` is negative or greater than 150.
    """
    # bool is a subclass of int, so True/False would otherwise slip through
    # the integer check and be returned as an "age".
    if isinstance(age, bool) or not isinstance(age, int):
        raise TypeError("Age must be an integer")
    if age < 0:
        raise ValueError("Age cannot be negative")
    if age > 150:
        raise ValueError("Age seems unreasonably high")
    return age
def set_percentage(value):
    """Validate and return a percentage between 0 and 100 inclusive.

    Raises:
        TypeError: if ``value`` is not an int or float.
        ValueError: if ``value`` lies outside the closed range [0, 100].
    """
    is_number = isinstance(value, (int, float))
    if not is_number:
        raise TypeError("Percentage must be a number")

    # Keep the chained comparison: it is False for NaN, so NaN is rejected.
    in_bounds = 0 <= value <= 100
    if in_bounds:
        return value
    raise ValueError("Percentage must be between 0 and 100")
Complex Validation with Contracts
from typing import List, Union
import inspect
def validate_contract(**validators):
    """Build a decorator that validates named arguments before each call.

    Args:
        **validators: maps a parameter name to a predicate. The predicate
            receives the value bound to that parameter (defaults included)
            and must return a truthy value when the value is acceptable.

    Returns:
        A decorator. The decorated function raises ``ValueError`` naming
        the first parameter whose predicate rejects its value; otherwise it
        calls the original function with the original arguments.
    """
    import functools  # local import keeps this snippet self-contained

    def decorator(func):
        # The signature never changes, so inspect it once at decoration
        # time instead of on every call.
        sig = inspect.signature(func)

        @functools.wraps(func)  # preserve __name__, __doc__, etc.
        def wrapper(*args, **kwargs):
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()

            # Validate each parameter that has a registered predicate
            for param_name, validator in validators.items():
                if param_name in bound_args.arguments:
                    value = bound_args.arguments[param_name]
                    if not validator(value):
                        raise ValueError(
                            f"Invalid value for {param_name}: {value}"
                        )
            return func(*args, **kwargs)

        return wrapper

    return decorator
# Validation functions
def is_positive_number(x):
    """Return True when ``x`` is an int or float strictly greater than zero."""
    if not isinstance(x, (int, float)):
        return False
    return x > 0
def is_non_empty_list(x):
    """Return True when ``x`` is a list containing at least one element."""
    if not isinstance(x, list):
        return False
    return len(x) > 0
def is_valid_email(x):
    """Return True when ``x`` looks like a plausible e-mail address.

    This is a shape check, not RFC validation. It requires:
      * a string with exactly one "@",
      * a non-empty local part,
      * a domain containing at least one "." and no empty labels,
      * no whitespace anywhere.
    """
    if not isinstance(x, str):
        return False
    if any(ch.isspace() for ch in x):
        return False

    local, at, domain = x.partition("@")
    # An empty ``at`` means no "@" at all; a second "@" would end up in domain.
    if not at or not local or "@" in domain:
        return False

    labels = domain.split(".")
    # At least two labels (one dot) and none of them empty: rejects ".com",
    # "example." and "a..b".
    return len(labels) >= 2 and all(labels)
@validate_contract(
    price=is_positive_number,
    items=is_non_empty_list,
    email=is_valid_email
)
def process_order(price, items, email):
    """Print a short summary of the order and return its processing result.

    The arguments are checked by the ``validate_contract`` decorator before
    this body runs.
    """
    summary_lines = (
        f"Processing order for {email}",
        f"Items: {items}",
        f"Total: ${price}",
    )
    for line in summary_lines:
        print(line)
    return {"status": "processed", "total": price}
# Test
try:
process_order(100, ["item1", "item2"], "user@example.com")
process_order(-50, [], "invalid-email") # Multiple validation errors
except ValueError as e:
print(f"Validation error: {e}")
Resource Management
Safe File Operations
import os
from pathlib import Path
def safe_read_file(filepath, encoding="utf-8", max_size_mb=10):
    """Read a text file after a series of safety checks.

    Args:
        filepath: path of the file to read (str or path-like).
        encoding: text encoding used to decode the file.
        max_size_mb: maximum accepted file size in mebibytes.

    Returns:
        The decoded file content as a string.

    Raises:
        FileNotFoundError: if the path does not exist.
        ValueError: if the path is not a regular file, the file is larger
            than ``max_size_mb``, or its bytes cannot be decoded.
        PermissionError: if the file is not readable.
        RuntimeError: for any other failure while reading.
    """
    # Convert to Path object for better path handling
    path = Path(filepath)

    if not path.exists():
        raise FileNotFoundError(f"File not found: {filepath}")
    if not path.is_file():
        raise ValueError(f"Path is not a file: {filepath}")

    max_size_bytes = max_size_mb * 1024 * 1024
    file_size = path.stat().st_size
    if file_size > max_size_bytes:
        raise ValueError(f"File too large: {file_size} bytes (max: {max_size_bytes})")

    if not os.access(path, os.R_OK):
        raise PermissionError(f"Cannot read file: {filepath}")

    try:
        with open(path, "r", encoding=encoding) as file:
            return file.read()
    except (FileNotFoundError, PermissionError):
        # The checks above can go stale before open() runs; keep the
        # exception types this function already promises instead of
        # degrading them to RuntimeError.
        raise
    except UnicodeDecodeError as e:
        raise ValueError(f"File encoding error: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Error reading file: {e}") from e
# Usage
try:
    content = safe_read_file("large_file.txt", max_size_mb=1)
# safe_read_file raises RuntimeError for unexpected read failures, so it
# has to be listed here as well or the script aborts.
except (FileNotFoundError, ValueError, PermissionError, RuntimeError) as e:
    print(f"File read failed: {e}")
Database Connection Safety
import sqlite3
from contextlib import contextmanager
@contextmanager
def safe_database_connection(db_path, timeout=5.0):
    """Context manager that opens a SQLite connection and always closes it.

    Args:
        db_path: path of the database file; must be a non-empty string.
            Missing parent directories are created when the file does not
            exist yet.
        timeout: lock wait time in seconds, passed to ``sqlite3.connect``
            and applied (in milliseconds) as ``PRAGMA busy_timeout``.

    Yields:
        The open ``sqlite3.Connection`` with foreign keys enabled.

    Raises:
        ValueError: if ``db_path`` is empty or not a string.
        RuntimeError: for any ``sqlite3.Error`` raised while connecting or
            inside the ``with`` body; the original error is attached as
            ``__cause__``.
    """
    if not db_path or not isinstance(db_path, str):
        raise ValueError("Invalid database path")

    # Create the parent directory for databases that do not exist yet
    db_file = Path(db_path)
    if not db_file.exists():
        db_file.parent.mkdir(parents=True, exist_ok=True)

    connection = None
    try:
        connection = sqlite3.connect(db_path, timeout=timeout)
        connection.execute("PRAGMA foreign_keys = ON")
        # int() keeps the interpolated value purely numeric
        connection.execute(f"PRAGMA busy_timeout = {int(timeout * 1000)}")
        yield connection
    except sqlite3.Error as e:
        # Errors raised inside the caller's ``with`` body arrive here too.
        raise RuntimeError(f"Database error: {e}") from e
    finally:
        if connection is not None:
            connection.close()
def safe_execute_query(db_path, query, params=None):
    """Execute a (read-oriented) SQL statement with basic safety checks.

    Args:
        db_path: database path handed to ``safe_database_connection``.
        query: SQL text; must be a non-empty string that does not contain
            the keywords DROP, DELETE, UPDATE or INSERT as whole words.
        params: optional list, tuple or dict of bound parameters.

    Returns:
        ``cursor.fetchall()`` for statements starting with SELECT,
        otherwise ``cursor.rowcount`` after committing.

    Raises:
        ValueError: for an invalid query, a blocked keyword, invalid
            ``params``, or a data integrity error.
        RuntimeError: for operational and other database errors.
    """
    import re  # local import keeps this snippet self-contained

    if not query or not isinstance(query, str):
        raise ValueError("Invalid query")

    # Basic keyword screen (parameterized queries remain the real defence).
    # Match whole words only, so identifiers such as ``updated_at`` or
    # ``last_update`` are not mistaken for statements.
    query_upper = query.upper()
    if re.search(r"\b(?:DROP|DELETE|UPDATE|INSERT)\b", query_upper):
        raise ValueError("Potentially dangerous query detected")

    if params is not None and not isinstance(params, (list, tuple, dict)):
        raise ValueError("Parameters must be a list, tuple, or dict")

    try:
        with safe_database_connection(db_path) as conn:
            cursor = conn.cursor()
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)

            # For SELECT queries, return results
            if query_upper.strip().startswith("SELECT"):
                return cursor.fetchall()
            conn.commit()
            return cursor.rowcount
    except sqlite3.IntegrityError as e:
        raise ValueError(f"Data integrity error: {e}") from e
    except sqlite3.OperationalError as e:
        raise RuntimeError(f"Database operation error: {e}") from e
    except RuntimeError as e:
        # The connection context manager rewraps sqlite3 errors raised in
        # the ``with`` body as RuntimeError, so look at the underlying
        # error to keep the mapping documented above.
        cause = e.__cause__ or e.__context__
        if isinstance(cause, sqlite3.IntegrityError):
            raise ValueError(f"Data integrity error: {cause}") from cause
        if isinstance(cause, sqlite3.OperationalError):
            raise RuntimeError(f"Database operation error: {cause}") from cause
        raise
# Usage
try:
# Safe query execution
results = safe_execute_query("example.db",
"SELECT * FROM users WHERE age > ?",
(18,))
print(f"Found {len(results)} users")
except (ValueError, RuntimeError) as e:
print(f"Database operation failed: {e}")
Error Recovery and Fallbacks
Graceful Degradation
class DataProcessor:
    """Fetch user data from a chain of sources, degrading gracefully.

    Sources are tried in order: database, cache, then built-in defaults.
    The database and cache here are simulations driven by magic user ids.
    """

    def __init__(self):
        self.primary_source = "database"
        self.secondary_source = "cache"
        self.fallback_source = "defaults"

    def get_user_data(self, user_id):
        """Return user data from the first source that succeeds.

        Raises:
            RuntimeError: if even the defaults cannot be produced; the
                last underlying error is attached as ``__cause__``.
        """
        # Try primary source first
        try:
            return self._get_from_database(user_id)
        except Exception as e:
            print(f"Primary source failed: {e}")

        # Fallback to cache
        try:
            return self._get_from_cache(user_id)
        except Exception as e:
            print(f"Secondary source failed: {e}")

        # Final fallback to defaults
        try:
            return self._get_defaults(user_id)
        except Exception as e:
            print(f"All sources failed: {e}")
            raise RuntimeError("Unable to retrieve user data from any source") from e

    def _get_from_database(self, user_id):
        """Simulate database access.

        Fails for "error_user" and for "cache_error": the latter must also
        miss the database, otherwise the cache (and therefore the defaults
        fallback) would never be exercised for it.
        """
        if user_id in ("error_user", "cache_error"):
            raise ConnectionError("Database connection failed")
        return {"name": "John", "source": "database"}

    def _get_from_cache(self, user_id):
        """Simulate cache access; fails for "cache_error"."""
        if user_id == "cache_error":
            raise KeyError("Cache miss")
        return {"name": "John", "source": "cache"}

    def _get_defaults(self, user_id):
        """Provide default data."""
        return {"name": "Unknown", "source": "defaults"}
# Usage
processor = DataProcessor()
# Test different scenarios
print(processor.get_user_data("normal_user")) # Uses database
print(processor.get_user_data("error_user")) # Falls back to cache
print(processor.get_user_data("cache_error")) # Falls back to defaults
Circuit Breaker Pattern
import time
from enum import Enum
class CircuitState(Enum):
    """States of the circuit breaker."""

    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Failing, requests blocked
    HALF_OPEN = "half_open"  # Testing if service recovered


class CircuitBreaker:
    """Stop calling a failing dependency until it has had time to recover.

    Args:
        failure_threshold: number of consecutive failures that opens the
            circuit.
        recovery_timeout: seconds to wait after the last failure before a
            trial call is allowed.
        expected_exception: exception type (or tuple of types) counted as
            a failure; anything else propagates without being counted.
    """

    def __init__(self, failure_threshold=5, recovery_timeout=60, expected_exception=Exception):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.failure_count = 0
        self.last_failure_time = None
        self.state = CircuitState.CLOSED

    def call(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` under circuit breaker protection.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerError: if the circuit is OPEN and the recovery
                timeout has not elapsed yet.
            Exception: whatever ``func`` raises, re-raised unchanged.
        """
        if self.state == CircuitState.OPEN:
            if self._should_attempt_reset():
                self.state = CircuitState.HALF_OPEN
            else:
                raise CircuitBreakerError("Circuit breaker is OPEN")

        try:
            result = func(*args, **kwargs)
        except self.expected_exception:
            self._on_failure()
            raise  # bare raise keeps the original traceback intact
        self._on_success()
        return result

    def _should_attempt_reset(self):
        """Check if enough time has passed to try resetting."""
        if self.last_failure_time is None:
            return True
        return time.time() - self.last_failure_time >= self.recovery_timeout

    def _on_success(self):
        """Handle successful call."""
        if self.state == CircuitState.HALF_OPEN:
            self.state = CircuitState.CLOSED
            print("Circuit breaker reset to CLOSED")
        # Any success ends the current failure streak. Without this reset,
        # isolated failures spread over a long time would add up and open
        # the circuit on a healthy service.
        self.failure_count = 0

    def _on_failure(self):
        """Handle failed call."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        # A failed trial call in HALF_OPEN re-opens the circuit at once.
        if self.state == CircuitState.HALF_OPEN or self.failure_count >= self.failure_threshold:
            self.state = CircuitState.OPEN
            print(f"Circuit breaker opened after {self.failure_count} failures")


class CircuitBreakerError(Exception):
    """Exception raised when circuit breaker is open."""
# Usage
def unreliable_service():
    """Simulate a flaky service that fails roughly 70% of the time.

    Returns "Service response" on success and raises ConnectionError
    otherwise.
    """
    import random

    failed = random.random() < 0.7  # 70% failure rate
    if not failed:
        return "Service response"
    raise ConnectionError("Service temporarily unavailable")
breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10)
for i in range(10):
try:
result = breaker.call(unreliable_service)
print(f"Call {i+1}: Success - {result}")
except CircuitBreakerError:
print(f"Call {i+1}: Circuit breaker OPEN")
except ConnectionError:
print(f"Call {i+1}: Service error")
time.sleep(1)
Sanitization and Security
Input Sanitization
import re
import html
def sanitize_string(input_str, max_length=1000):
    """Return a cleaned, HTML-escaped copy of ``input_str``.

    Steps, in order: drop NUL bytes, truncate to ``max_length`` characters,
    remove ``<script>...</script>`` elements, then HTML-escape the rest.

    Raises:
        TypeError: if ``input_str`` is not a string.
    """
    if not isinstance(input_str, str):
        raise TypeError("Input must be a string")

    without_nulls = input_str.replace('\x00', '')
    clipped = without_nulls[:max_length]

    script_element = re.compile(
        r'<script[^>]*>.*?</script>', re.IGNORECASE | re.DOTALL
    )
    without_scripts = script_element.sub('', clipped)
    return html.escape(without_scripts)
def sanitize_filename(filename):
    """Sanitize a file name so it cannot be used for directory traversal.

    Path separators, the characters ``<>:"|?*`` and ASCII control
    characters (including NUL) are removed, and the result is truncated
    to 255 characters.

    Returns:
        The sanitized file name.

    Raises:
        TypeError: if ``filename`` is not a string.
        ValueError: if nothing is left after sanitizing, or the result
            starts with a dot.
    """
    if not isinstance(filename, str):
        raise TypeError("Filename must be a string")

    # Remove path separators
    sanitized = re.sub(r'[\/\\]', '', filename)
    # Remove dangerous characters. Control characters are included because
    # an embedded NUL makes file APIs reject or truncate the name.
    sanitized = re.sub(r'[<>:"|?*\x00-\x1f]', '', sanitized)

    # Limit length
    sanitized = sanitized[:255]

    # Ensure it's not empty and doesn't start with dot
    if not sanitized or sanitized.startswith('.'):
        raise ValueError("Invalid filename")
    return sanitized
# Test sanitization
print(sanitize_string('<script>alert("xss")</script>Hello World'))
try:
    print(sanitize_filename('../../../etc/passwd'))
except ValueError as e:
    # Once the separators are stripped the name starts with a dot, so the
    # traversal attempt is rejected rather than returned.
    print(f"Rejected filename: {e}")
SQL Injection Prevention
def safe_sql_query(db_path, table, columns=None, where_clause=None, params=None):
    """Build and run a SELECT statement from validated parts.

    Args:
        db_path: database path handed to ``safe_database_connection``.
        table: table name; must be a plain identifier.
        columns: column name or iterable of column names (plain
            identifiers or ``"*"``); defaults to all columns.
        where_clause: optional condition of the form
            ``<identifier> <operator> ?`` where the operator is one of
            ``=``, ``==``, ``!=``, ``<>``, ``<``, ``<=``, ``>``, ``>=``.
        params: values bound to the ``?`` placeholder.

    Returns:
        The list of rows returned by the query.

    Raises:
        ValueError: if the table, a column or the WHERE clause is invalid.
        RuntimeError: if executing the query fails.
    """
    # fullmatch, not match + "$": "$" also matches before a trailing
    # newline, which would let "users\n" through.
    identifier = re.compile(r'[a-zA-Z_][a-zA-Z0-9_]*')

    if not isinstance(table, str) or not identifier.fullmatch(table):
        raise ValueError("Invalid table name")

    if columns is None:
        columns = ["*"]
    elif isinstance(columns, str):
        columns = [columns]

    for col in columns:
        if col == "*":
            continue
        if not isinstance(col, str) or not identifier.fullmatch(col):
            raise ValueError(f"Invalid column name: {col}")

    cols_str = ", ".join(columns)
    query = f"SELECT {cols_str} FROM {table}"

    if where_clause:
        # Only allow "<identifier> <operator> ?" with a real comparison
        # operator, not an arbitrary run of operator characters.
        condition = where_clause.strip()
        if not re.fullmatch(
            r'[a-zA-Z_][a-zA-Z0-9_]*\s*(?:==|=|!=|<>|<=|>=|<|>)\s*\?',
            condition,
        ):
            raise ValueError("Invalid WHERE clause format")
        query += f" WHERE {condition}"

    # Execute with parameters
    try:
        with safe_database_connection(db_path) as conn:
            cursor = conn.cursor()
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            return cursor.fetchall()
    except Exception as e:
        raise RuntimeError(f"Query execution failed: {e}") from e
# Usage
try:
    # Safe query
    results = safe_sql_query("users.db", "users", ["name", "email"], "age > ?", (18,))
    print(f"Found {len(results)} users")
    # This would fail validation
    # results = safe_sql_query("users.db", "users; DROP TABLE users;", ["*"])
except ValueError as e:
    print(f"Validation error: {e}")
except RuntimeError as e:
    # Raised for any database failure, e.g. the table does not exist yet
    print(f"Query failed: {e}")
Testing for Defensive Code
Unit Tests for Error Conditions
import pytest
def test_safe_divide():
    """Check safe_divide on normal input, division by zero and bad types."""
    from your_module import safe_divide  # Assume this function exists

    # Normal cases
    assert safe_divide(10, 2) == 5
    assert safe_divide(10, 0) == 0  # Should handle division by zero

    # Type errors: a string in either position must be rejected
    for bad_args in (("10", 2), (10, "2")):
        with pytest.raises(TypeError):
            safe_divide(*bad_args)
def test_input_validation():
    """Check that set_age accepts a valid age and rejects invalid ones."""
    assert set_age(25) == 25

    rejected = (
        ("25", TypeError),
        (-5, ValueError),
        (200, ValueError),
    )
    for bad_age, expected_error in rejected:
        with pytest.raises(expected_error):
            set_age(bad_age)
def test_file_operations():
    """Check safe_read_file on a missing path, a directory and a real file."""
    import tempfile  # local import keeps this snippet self-contained

    # Work in a throwaway directory so the test can never overwrite or
    # delete a real file in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Test with non-existent file
        with pytest.raises(FileNotFoundError):
            safe_read_file(os.path.join(tmp_dir, "nonexistent.txt"))

        # Test with directory
        with pytest.raises(ValueError):
            safe_read_file(tmp_dir)

        # Test with valid file (create one first)
        test_file = os.path.join(tmp_dir, "test_file.txt")
        with open(test_file, "w", encoding="utf-8") as f:
            f.write("test content")
        assert safe_read_file(test_file) == "test content"
# Run tests
if __name__ == "__main__":
pytest.main([__file__])
Fuzz Testing
import random
import string
def fuzz_test_function(func, iterations=1000):
    """Call ``func`` repeatedly with random arguments and collect failures.

    One random value is generated for every required positional parameter
    of ``func``. Each value is picked from a mix of integers, floats, a
    string, ``None`` and an empty list.

    Args:
        func: callable to fuzz; its signature must be introspectable.
        iterations: number of calls to make.

    Returns:
        A list of ``(iteration, exception_type_name, message)`` tuples,
        one per call that raised.
    """
    import inspect  # local import keeps this snippet self-contained

    def random_value():
        return random.choice(
            [random.randint(-1000, 1000), random.random(), "string", None, []]
        )

    # Derive the argument count from the signature instead of matching on
    # func.__name__: a name match silently skips every function whose name
    # is not hard-coded, including thin wrappers around the known ones.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    arity = sum(
        1
        for param in inspect.signature(func).parameters.values()
        if param.kind in positional_kinds and param.default is param.empty
    )

    errors = []
    for i in range(iterations):
        fuzz_args = [random_value() for _ in range(arity)]
        try:
            func(*fuzz_args)
        except Exception as e:
            errors.append((i, type(e).__name__, str(e)))
    return errors
# Test with fuzzing
def add_numbers_fuzz(a, b):
    """Wrap add_numbers so that rejected inputs yield 0 instead of raising."""
    try:
        total = add_numbers(a, b)
    except (TypeError, ValueError):
        total = 0  # Safe fallback
    return total
def set_age_fuzz(age):
    """Wrap set_age so that rejected inputs yield 0 instead of raising."""
    try:
        validated = set_age(age)
    except (TypeError, ValueError):
        validated = 0  # Safe fallback
    return validated
# Run fuzz tests
print("Fuzz testing add_numbers...")
errors = fuzz_test_function(add_numbers_fuzz)
print(f"Found {len(errors)} errors in {1000} iterations")
print("Fuzz testing set_age...")
errors = fuzz_test_function(set_age_fuzz)
print(f"Found {len(errors)} errors in {1000} iterations")
Real-World Examples
Example 1: Safe Web API Client
import requests
from urllib.parse import urlparse
import time
class SafeAPIClient:
    """HTTP client that validates its configuration and retries failures.

    Args:
        base_url: absolute http(s) URL that every endpoint is appended to.
        timeout: default per-request timeout in seconds; must be positive.
        max_retries: number of retries after the first attempt; must be a
            non-negative integer.

    Raises:
        TypeError: if ``base_url`` is not a string.
        ValueError: if any argument fails validation.
    """

    def __init__(self, base_url, timeout=30, max_retries=3):
        self.base_url = self._validate_base_url(base_url)
        self.timeout = self._validate_timeout(timeout)
        self.max_retries = self._validate_max_retries(max_retries)
        self.session = requests.Session()

    def _validate_base_url(self, url):
        """Validate base URL and return it without a trailing slash."""
        if not isinstance(url, str):
            raise TypeError("Base URL must be a string")
        parsed = urlparse(url)
        if not parsed.scheme or not parsed.netloc:
            raise ValueError("Invalid URL format")
        if parsed.scheme not in ["http", "https"]:
            raise ValueError("Only HTTP and HTTPS URLs are supported")
        return url.rstrip("/")

    def _validate_timeout(self, timeout):
        """Validate timeout value."""
        if not isinstance(timeout, (int, float)) or timeout <= 0:
            raise ValueError("Timeout must be a positive number")
        return timeout

    def _validate_max_retries(self, max_retries):
        """Validate the retry count.

        A negative count would make the retry loop run zero times, so the
        request would silently return None.
        """
        is_int = isinstance(max_retries, int) and not isinstance(max_retries, bool)
        if not is_int or max_retries < 0:
            raise ValueError("Max retries must be a non-negative integer")
        return max_retries

    def _safe_request(self, method, endpoint, **kwargs):
        """Send a request, retrying timeouts, connection and 5xx errors.

        Retries use exponential backoff (1s, 2s, 4s, ...).

        Returns:
            The successful response object.

        Raises:
            ValueError: for a malformed endpoint or a 4xx response.
            RuntimeError: when all attempts fail.
        """
        # Endpoints are relative to base_url and must not start with "/"
        if not isinstance(endpoint, str) or endpoint.startswith("/"):
            raise ValueError("Invalid endpoint format")
        url = f"{self.base_url}/{endpoint}"

        kwargs.setdefault("timeout", self.timeout)
        attempts = self.max_retries + 1

        for attempt in range(attempts):
            is_last = attempt == self.max_retries
            try:
                response = self.session.request(method, url, **kwargs)
                response.raise_for_status()
                return response
            except requests.exceptions.Timeout as e:
                if is_last:
                    raise RuntimeError(
                        f"Request timeout after {self.max_retries + 1} attempts"
                    ) from e
            except requests.exceptions.ConnectionError as e:
                if is_last:
                    raise RuntimeError(
                        f"Connection failed after {self.max_retries + 1} attempts"
                    ) from e
            except requests.exceptions.HTTPError as e:
                # Don't retry client errors (4xx)
                if 400 <= e.response.status_code < 500:
                    raise ValueError(f"Client error: {e.response.status_code}") from e
                # Retry server errors (5xx)
                if is_last:
                    raise RuntimeError(f"Server error: {e.response.status_code}") from e
            time.sleep(2 ** attempt)  # Exponential backoff

        # Only reachable if max_retries was changed to a negative value
        # after construction.
        raise RuntimeError("Request was not attempted: max_retries is negative")

    def get(self, endpoint, **kwargs):
        """Safe GET request."""
        return self._safe_request("GET", endpoint, **kwargs)

    def post(self, endpoint, data=None, json=None, **kwargs):
        """Safe POST request."""
        return self._safe_request("POST", endpoint, data=data, json=json, **kwargs)
# Usage
try:
client = SafeAPIClient("https://jsonplaceholder.typicode.com")
# Safe GET request
response = client.get("posts/1")
print(f"Response: {response.json()}")
# This would fail validation
# client = SafeAPIClient("not-a-url")
except (TypeError, ValueError, RuntimeError) as e:
print(f"API client error: {e}")
Example 2: Configuration Manager
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional
class SafeConfigManager:
    """JSON-backed configuration store with schema validation.

    Args:
        config_file: path of the JSON configuration file.
        schema: optional mapping of config key to the type its value must
            have. Keys missing from the schema are not validated.

    A missing, unreadable or invalid file is reported on stdout and
    replaced by the built-in defaults; only an oversized file raises.
    """

    def __init__(self, config_file: str, schema: Optional[Dict] = None):
        self.config_file = Path(config_file)
        self.schema = schema or {}
        self._config = self._load_config()

    @staticmethod
    def _type_matches(value: Any, expected_type: Any) -> bool:
        """Return True when ``value`` satisfies ``expected_type``.

        bool is a subclass of int, so True/False would otherwise pass for
        an ``int`` entry; they are only accepted when bool is asked for.
        """
        if expected_type is int and isinstance(value, bool):
            return False
        return isinstance(value, expected_type)

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration file, falling back to defaults on error.

        Raises:
            ValueError: if the file is larger than 1 MB.
        """
        if not self.config_file.exists():
            print(f"Config file {self.config_file} not found, using defaults")
            return self._get_defaults()

        # Check file size (prevent huge files)
        if self.config_file.stat().st_size > 1024 * 1024:  # 1MB limit
            raise ValueError("Config file too large")

        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
            # Valid JSON may still be a list, string or number; everything
            # below relies on dict methods, so reject those roots here.
            if not isinstance(config, dict):
                raise TypeError("Config root must be a JSON object")
            self._validate_config(config)
            return config
        except json.JSONDecodeError as e:
            print(f"Invalid JSON in config file: {e}")
            return self._get_defaults()
        except Exception as e:
            print(f"Error loading config: {e}")
            return self._get_defaults()

    def _validate_config(self, config: Dict[str, Any]) -> None:
        """Validate every schema key present in ``config``.

        Raises:
            TypeError: if a value has the wrong type.
            ValueError: if an int is negative or a str is blank.
        """
        for key, expected_type in self.schema.items():
            if key not in config:
                continue
            value = config[key]
            if not self._type_matches(value, expected_type):
                raise TypeError(f"Config key '{key}' must be {expected_type.__name__}, got {type(value).__name__}")
            # Additional validation
            if expected_type == int and value < 0:
                raise ValueError(f"Config key '{key}' must be non-negative")
            elif expected_type == str and not value.strip():
                raise ValueError(f"Config key '{key}' cannot be empty")

    def _get_defaults(self) -> Dict[str, Any]:
        """Get default configuration."""
        return {
            "debug": False,
            "max_connections": 10,
            "timeout": 30,
            "log_level": "INFO"
        }

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value safely."""
        return self._config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Validate ``value``, store it under ``key`` and save the file.

        The same rules as for loading apply to keys listed in the schema.

        Raises:
            TypeError: if the value has the wrong type.
            ValueError: if an int is negative or a str is blank.
        """
        if key in self.schema:
            expected_type = self.schema[key]
            if not self._type_matches(value, expected_type):
                raise TypeError(f"Value for '{key}' must be {expected_type.__name__}")
            # Additional validation
            if expected_type == int and value < 0:
                raise ValueError(f"Value for '{key}' must be non-negative")
            if expected_type == str and not value.strip():
                raise ValueError(f"Value for '{key}' cannot be empty")
        self._config[key] = value
        self._save_config()

    def _save_config(self) -> None:
        """Save configuration to file; failures are reported, not raised."""
        try:
            # Create directory if it doesn't exist
            self.config_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.config_file, "w", encoding="utf-8") as f:
                json.dump(self._config, f, indent=2)
        except Exception as e:
            print(f"Error saving config: {e}")
# Usage
schema = {
"debug": bool,
"max_connections": int,
"timeout": int,
"log_level": str
}
config = SafeConfigManager("app_config.json", schema)
print(f"Debug mode: {config.get('debug')}")
print(f"Max connections: {config.get('max_connections')}")
# Safe setting
config.set("max_connections", 20)
config.set("debug", True)
# This would fail validation
try:
config.set("max_connections", "not_a_number")
except TypeError as e:
print(f"Validation error: {e}")
Best Practices Summary
1. Validate All Inputs
def process_data(data):
    """Validate a data record; return None when it is acceptable.

    Raises:
        TypeError: if ``data`` is not a dict or its "value" is not a number.
        ValueError: if "name" or "value" is missing, or "value" is not
            positive.
    """
    # Type checking
    if not isinstance(data, dict):
        raise TypeError("Data must be a dictionary")

    # Required fields
    for field in ("name", "value"):
        if field not in data:
            raise ValueError(f"Missing required field: {field}")

    # Value validation
    value = data["value"]
    if not isinstance(value, (int, float)):
        raise TypeError("Value must be a number")
    if value <= 0:
        raise ValueError("Value must be positive")
2. Use Resource Management
# Context managers for automatic cleanup
with open("file.txt", "r") as f:
data = f.read()
# Custom context managers
class DatabaseConnection:
    """Minimal custom context manager that closes itself on exit."""

    # Set to True by close(); lets callers check that cleanup happened.
    closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised in the with-body propagates.
        self.close()

    def close(self):
        """Release the connection. __exit__ calls this, so it must exist."""
        self.closed = True
3. Implement Fallbacks
def get_data(source):
    """Return data from the primary source, then the backup, then defaults."""
    try:
        return get_from_primary(source)
    except Exception:
        pass  # fall through to the backup source

    try:
        return get_from_backup(source)
    except Exception:
        return get_defaults()
4. Sanitize Inputs
def sanitize_input(text):
    """Strip angle brackets, HTML-escape the rest and cap it at 1000 chars."""
    # Remove dangerous characters
    without_brackets = re.sub(r'[<>]', '', text)
    # Escape special characters
    escaped = html.escape(without_brackets)
    # Limit length
    return escaped[:1000]
5. Test Error Conditions
def test_edge_cases():
    """Template test: invalid inputs must raise, boundary values must work.

    ``my_function`` and ``expected`` are placeholders that are not defined
    in this snippet; substitute the function under test and its expected
    results.
    """
    # Test with invalid inputs
    with pytest.raises(ValueError):
        my_function(None)
    with pytest.raises(TypeError):
        my_function("invalid")
    # Test boundary conditions
    assert my_function(0) == expected
    assert my_function(1000) == expected
Practice Exercises
Exercise 1: Safe Calculator
Create a calculator that:
- Validates all inputs
- Handles division by zero
- Prevents overflow
- Sanitizes expressions
Exercise 2: File Manager
Build a file manager that:
- Validates file paths
- Checks permissions
- Handles large files
- Provides safe operations
Exercise 3: User Registration System
Create a registration system that:
- Validates emails
- Checks password strength
- Prevents SQL injection
- Sanitizes all inputs
Exercise 4: Configuration Validator
Build a configuration validator that:
- Validates config files
- Provides defaults
- Handles corruption
- Supports schemas
Exercise 5: Network Request Handler
Create a network client that:
- Validates URLs
- Handles timeouts
- Implements retries
- Sanitizes responses
Summary
Defensive programming prevents errors before they occur:
Input Validation:
- Type checking with `isinstance()`
- Range validation
- Format validation
- Sanitization
Resource Management:
- Context managers (`with` statements)
- Proper cleanup in `finally` blocks
- Resource limits
Error Recovery:
- Fallback strategies
- Circuit breakers
- Graceful degradation
Security:
- Input sanitization
- SQL injection prevention
- XSS protection
Testing:
- Unit tests for error conditions
- Fuzz testing
- Boundary testing
Key Principles:
- Fail fast with clear error messages
- Never trust user input
- Use safe defaults
- Test thoroughly
- Log errors for debugging
Next: Object-Oriented Programming - organizing your code! 🏗️