Giao diện
Input Validation Security
Validate và sanitize mọi input - tuyến phòng thủ đầu tiên chống lại attacks
Learning Outcomes
Sau khi hoàn thành trang này, bạn sẽ:
- 🎯 Master Pydantic validation cho type-safe input handling
- 🎯 Hiểu input sanitization techniques và khi nào cần dùng
- 🎯 Phòng chống SQL injection với parameterized queries
- 🎯 Ngăn chặn Command injection attacks
- 🎯 Tránh các Production Pitfalls về validation
Tại sao Input Validation quan trọng?
python
# Rule #1: NEVER trust user input
# Mọi data từ bên ngoài đều có thể là malicious
# User input sources:
# - Form data, query params, headers
# - File uploads
# - API requests (JSON, XML)
# - Database reads (có thể bị tampered)
# - Environment variables (trong shared environments)
Defense in Depth
┌─────────────────────────────────────────────────────┐
│ User Input │
└─────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Layer 1: Type Validation (Pydantic) │
│ - Correct types, required fields │
└─────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Layer 2: Business Validation │
│ - Range checks, format validation │
└─────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Layer 3: Sanitization │
│ - Escape special chars, normalize │
└─────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Layer 4: Context-specific Encoding │
│ - SQL params, HTML escape, shell escape │
└─────────────────────────────────────────────────────┘
Pydantic Validation
Basic Validation
python
# SECURITY: Never hardcode credentials in production code
# Use environment variables or secure credential management systems
from pydantic import BaseModel, Field, EmailStr, field_validator
from typing import Optional
from datetime import date
class UserCreate(BaseModel):
    """Registration payload; every field is checked before the object exists."""

    username: str = Field(
        ...,
        min_length=3,
        max_length=50,
        pattern=r'^[a-zA-Z0-9_]+$',  # letters, digits and underscore only
        description="Username (3-50 chars, alphanumeric)"
    )
    email: EmailStr  # Pydantic's built-in email syntax check
    password: str = Field(
        ...,
        min_length=8,
        max_length=128,
        description="Password (8-128 chars)"
    )
    age: Optional[int] = Field(default=None, ge=13, le=150)
    birth_date: Optional[date] = None

    @field_validator('password')
    @classmethod
    def password_strength(cls, v: str) -> str:
        """Require one character from each class; report the first missing one."""
        required_classes = (
            (str.isupper, 'Password must contain uppercase letter'),
            (str.islower, 'Password must contain lowercase letter'),
            (str.isdigit, 'Password must contain digit'),
            ('!@#$%^&*()_+-='.__contains__, 'Password must contain special character'),
        )
        for has_class, message in required_classes:
            if not any(map(has_class, v)):
                raise ValueError(message)
        return v

    @field_validator('birth_date')
    @classmethod
    def validate_birth_date(cls, v: Optional[date]) -> Optional[date]:
        """Accept only dates strictly before today (None is allowed)."""
        if v is not None and v >= date.today():
            raise ValueError('Birth date must be in the past')
        return v
# Usage
from pydantic import ValidationError

try:
    user = UserCreate(
        username="john_doe",
        email="john@example.com",
        # Example value only - in a real app the password arrives in the
        # user's request. Reading os.getenv("DB_PASSWORD") here would be wrong
        # twice: it is an unrelated secret, and it is None when unset, which
        # always fails validation.
        password="Str0ng!Pass",
        age=25
    )
except ValidationError as e:
    print(e.json())  # Structured error response
Custom Validators
python
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Self  # Python 3.11+
import re


class PaymentRequest(BaseModel):
    """Payment request with per-field and cross-field (model) validation."""

    card_number: str
    expiry_month: int = Field(ge=1, le=12)
    expiry_year: int = Field(ge=2024, le=2050)
    cvv: str
    # NOTE: float keeps the example short; use Decimal for real money amounts
    amount: float = Field(gt=0, le=10000)
    currency: str = Field(pattern=r'^[A-Z]{3}$')  # ISO 4217 code, e.g. "USD"

    @field_validator('card_number')
    @classmethod
    def validate_card_number(cls, v: str) -> str:
        """Validate card number format and Luhn checksum.

        Returns the number with spaces and dashes removed.
        """
        # Remove spaces and dashes
        cleaned = re.sub(r'[\s-]', '', v)
        # ASCII digits only: str.isdigit() also accepts e.g. Arabic-Indic or
        # superscript digits, which are not valid card digits
        if not re.fullmatch(r'[0-9]+', cleaned):
            raise ValueError('Card number must contain only digits')
        if len(cleaned) not in (13, 14, 15, 16):
            raise ValueError('Invalid card number length')
        # Luhn algorithm check
        if not cls._luhn_check(cleaned):
            raise ValueError('Invalid card number (Luhn check failed)')
        return cleaned

    @staticmethod
    def _luhn_check(card_number: str) -> bool:
        """Return True if ``card_number`` (ASCII digits) passes the Luhn check."""
        digits = [int(d) for d in card_number]
        odd_digits = digits[-1::-2]   # rightmost digit and every second one
        even_digits = digits[-2::-2]  # the digits that get doubled
        total = sum(odd_digits)
        for d in even_digits:
            # divmod(d * 2, 10) splits e.g. 16 into (1, 6): sums its digits
            total += sum(divmod(d * 2, 10))
        return total % 10 == 0

    @field_validator('cvv')
    @classmethod
    def validate_cvv(cls, v: str) -> str:
        """Validate CVV format (3 or 4 ASCII digits)."""
        if not re.fullmatch(r'[0-9]{3,4}', v):
            raise ValueError('CVV must be 3 or 4 digits')
        return v

    @model_validator(mode='after')
    def validate_expiry(self) -> Self:
        """Reject cards whose expiry month is already over."""
        from datetime import date
        today = date.today()
        # A card stays valid through the end of its expiry month
        if (self.expiry_year < today.year or
                (self.expiry_year == today.year and
                 self.expiry_month < today.month)):
            raise ValueError('Card has expired')
        return self


# Strict Mode và Coercion
python
from pydantic import BaseModel, ConfigDict
from pydantic import Field


# Lax (default) mode: compatible input is converted to the declared type
class LooseModel(BaseModel):
    count: int


loose = LooseModel(count="42")  # OK - "42" -> 42
print(loose.count)  # 42


# Strict mode: the input must already have the declared type
class StrictModel(BaseModel):
    model_config = ConfigDict(strict=True)
    count: int


# StrictModel(count="42")  # ValidationError!
strict = StrictModel(count=42)  # OK


# Strictness can also be chosen field by field
class MixedModel(BaseModel):
    loose_count: int  # Allows coercion
    strict_count: int = Field(strict=True)  # No coercion


# Input Sanitization
String Sanitization
python
import re
import html
import unicodedata
def sanitize_string(
    value: str,
    max_length: int = 1000,
    allow_newlines: bool = False,
    strip_html: bool = True
) -> str:
    """
    Sanitize string input.

    Args:
        value: Input string
        max_length: Maximum length of the returned string
        allow_newlines: Keep line breaks (normalized to "\\n", at most two in a
            row); otherwise collapse every whitespace run to a single space
        strip_html: HTML-escape the result (tags are escaped, not removed)

    Returns:
        Sanitized string

    Raises:
        ValueError: If ``value`` is not a str
    """
    if not isinstance(value, str):
        raise ValueError("Input must be a string")

    # 1. NFKC folds compatibility characters (full-width letters, ligatures)
    #    into canonical forms so later checks see normalized text. It does NOT
    #    map look-alikes from other scripts (e.g. Cyrillic 'а'), so it is not
    #    a complete homograph defense.
    value = unicodedata.normalize('NFKC', value)

    # 2. Drop non-whitespace control characters (category Cc) such as NUL
    #    (can bypass security checks / truncate in C code) or ESC (terminal
    #    escape injection). Whitespace controls are handled in step 3.
    value = ''.join(
        ch for ch in value
        if ch.isspace() or unicodedata.category(ch) != 'Cc'
    )

    # 3. Strip or preserve whitespace
    if allow_newlines:
        # Normalize line endings
        value = value.replace('\r\n', '\n').replace('\r', '\n')
        # Remove excessive newlines
        value = re.sub(r'\n{3,}', '\n\n', value)
    else:
        # Replace all whitespace with single space
        value = ' '.join(value.split())

    # 4. Escape HTML if needed
    if strip_html:
        value = html.escape(value)

    # 5. Truncate to max length
    if len(value) > max_length:
        value = value[:max_length]
        if strip_html:
            # After html.escape every '&' starts an entity that ends in ';'.
            # Drop a trailing entity the cut left incomplete (e.g. '&l').
            value = re.sub(r'&[^;]*\Z', '', value)

    return value.strip()
# Usage
user_input = " <script>alert('xss')</script> Hello\x00World "
safe = sanitize_string(user_input)
# Tags are HTML-escaped (not removed); the NUL byte is deleted, not replaced
print(safe)  # "&lt;script&gt;alert(&#x27;xss&#x27;)&lt;/script&gt; HelloWorld"


# Filename Sanitization
python
import os
import re
from pathlib import Path
def sanitize_filename(
filename: str,
max_length: int = 255,
allowed_extensions: set[str] | None = None
) -> str:
"""
Sanitize filename to prevent path traversal and other attacks.
Args:
filename: Original filename
max_length: Maximum filename length
allowed_extensions: Set of allowed extensions (e.g., {'.jpg', '.png'})
Returns:
Safe filename
Raises:
ValueError: If filename is invalid or extension not allowed
"""
if not filename:
raise ValueError("Filename cannot be empty")
# 1. Get just the filename (prevent path traversal)
filename = os.path.basename(filename)
# 2. Remove null bytes
filename = filename.replace('\x00', '')
# 3. Remove/replace dangerous characters
# Keep only alphanumeric, dash, underscore, dot
filename = re.sub(r'[^\w\-.]', '_', filename)
# 4. Remove leading/trailing dots and spaces
filename = filename.strip('. ')
# 5. Prevent hidden files (Unix)
if filename.startswith('.'):
filename = '_' + filename[1:]
# 6. Check extension
if allowed_extensions:
ext = Path(filename).suffix.lower()
if ext not in allowed_extensions:
raise ValueError(f"Extension {ext} not allowed")
# 7. Truncate if too long (preserve extension)
if len(filename) > max_length:
name = Path(filename).stem
ext = Path(filename).suffix
max_name_length = max_length - len(ext)
filename = name[:max_name_length] + ext
# 8. Ensure not empty after sanitization
if not filename or filename == '.':
raise ValueError("Invalid filename after sanitization")
return filename
# Usage
dangerous = "../../../etc/passwd"
safe = sanitize_filename(dangerous)
print(safe)  # "passwd" - os.path.basename() discards every directory component
upload = "my photo<script>.jpg"
safe = sanitize_filename(upload, allowed_extensions={'.jpg', '.png', '.gif'})
print(safe)  # "my_photo_script_.jpg"
# URL validation: scheme whitelist, host whitelist, no embedded credentials
from urllib.parse import urlparse, urljoin
from typing import Optional

ALLOWED_SCHEMES = {'http', 'https'}
# Hosts accepted when allow_external=False - replace with your own domains
ALLOWED_HOSTS = {'example.com', 'www.example.com'}


def sanitize_url(
    url: str,
    base_url: Optional[str] = None,
    allow_external: bool = False
) -> str:
    """
    Sanitize and validate URL.

    Args:
        url: URL to sanitize
        base_url: Base URL for relative URLs
        allow_external: Whether to allow hosts outside ALLOWED_HOSTS

    Returns:
        Sanitized URL (resolved against base_url when it was relative)

    Raises:
        ValueError: If URL is invalid or not allowed
    """
    if not url:
        raise ValueError("URL cannot be empty")

    # Handle relative URLs
    if base_url and not url.startswith(('http://', 'https://')):
        url = urljoin(base_url, url)

    parsed = urlparse(url)

    # 1. Validate scheme (blocks javascript:, data:, file:, ...)
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise ValueError(f"Scheme {parsed.scheme} not allowed")

    # 2. Validate host. .hostname is lower-cased and excludes port and
    #    "user:pass@", unlike .netloc, so it is the right thing to compare.
    host = parsed.hostname
    if not host:
        raise ValueError("URL must include a host")
    if not allow_external and host not in ALLOWED_HOSTS:
        raise ValueError(f"Host {host} not allowed")

    # 3. Reject credentials in the authority part
    #    (https://example.com@evil.com really points at evil.com). An '@' in
    #    the path or query string is legitimate and no longer rejected.
    if parsed.username is not None or parsed.password is not None:
        raise ValueError("Credentials in URL not allowed")

    # 4. Defense in depth: script-capable schemes smuggled anywhere in the URL
    url_lower = url.lower()
    for pattern in ('javascript:', 'data:', 'vbscript:'):
        if pattern in url_lower:
            raise ValueError(f"Suspicious pattern in URL: {pattern}")

    return url


# Usage
try:
    safe_url = sanitize_url("https://example.com/page?q=search")
    print(safe_url)
except ValueError as e:
    print(f"Invalid URL: {e}")
SQL Injection Prevention
The Problem
python
# ❌ VULNERABLE: String concatenation
def get_user_vulnerable(username: str):
    # The raw username becomes part of the SQL text itself
    sql = "SELECT * FROM users WHERE username = '{}'".format(username)
    cursor.execute(sql)
    row = cursor.fetchone()
    return row

# Attack: username = "' OR '1'='1' --"
# Query becomes: SELECT * FROM users WHERE username = '' OR '1'='1' --'
# Returns ALL users!
# Attack: username = "'; DROP TABLE users; --"
# Query becomes: SELECT * FROM users WHERE username = ''; DROP TABLE users; --'
# Deletes entire table! (with drivers that allow stacked statements, e.g.
# psycopg2; sqlite3's execute() rejects more than one statement)


# Parameterized Queries (The Solution)
python
import sqlite3
from typing import Optional, Any
# ✅ SAFE: Parameterized queries
def get_user_safe(username: str, db_path: str = 'app.db') -> Optional[dict]:
    """
    Get user by username using a parameterized query.

    Args:
        username: Untrusted username; bound as a parameter, never spliced
            into the SQL text
        db_path: SQLite database file (defaults to 'app.db')

    Returns:
        Dict with id, username and email, or None if no row matches
    """
    conn = sqlite3.connect(db_path)
    try:
        # The value is sent separately from the SQL ("bound"), so quotes in
        # `username` can never change the structure of the query
        row = conn.execute(
            "SELECT id, username, email FROM users WHERE username = ?",
            (username,),  # Tuple of parameters
        ).fetchone()
    finally:
        # Close even if execute() raises (e.g. missing table)
        conn.close()

    if row is None:
        return None
    return {"id": row[0], "username": row[1], "email": row[2]}
# ✅ SAFE: Named parameters
def search_users_safe(
    name: Optional[str] = None,
    email: Optional[str] = None,
    min_age: Optional[int] = None,
    db_path: str = 'app.db',
) -> list[dict]:
    """
    Search users with multiple optional filters.

    Args:
        name: Case-insensitive substring of the user's name; '%' and '_'
            typed by the user match literally
        email: Exact email match
        min_age: Minimum age (inclusive)
        db_path: SQLite database file (defaults to 'app.db')

    Returns:
        One dict per matching row, keyed by column name
    """
    query = "SELECT * FROM users WHERE 1=1"
    params: dict[str, Any] = {}

    if name:
        # Escape LIKE wildcards so user input cannot widen the match
        escaped = name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        query += " AND name LIKE :name ESCAPE '\\'"
        params['name'] = f"%{escaped}%"
    if email:
        query += " AND email = :email"
        params['email'] = email
    if min_age is not None:
        query += " AND age >= :min_age"
        params['min_age'] = min_age

    conn = sqlite3.connect(db_path)
    try:
        # sqlite3.Row supports dict(row); the default tuple rows do not
        conn.row_factory = sqlite3.Row
        rows = conn.execute(query, params).fetchall()
    finally:
        conn.close()
    return [dict(row) for row in rows]
# SQLAlchemy ORM
from typing import Optional

from sqlalchemy import create_engine, select, text
from sqlalchemy.orm import Session

from models import User

engine = create_engine("sqlite:///app.db")


# ✅ SAFE: ORM queries (automatically parameterized)
def get_user_orm(username: str) -> Optional[User]:
    """Return the user with this username, or None."""
    with Session(engine) as session:
        stmt = select(User).where(User.username == username)
        return session.scalar(stmt)


# ✅ SAFE: Filter with multiple conditions
def search_users_orm(
    name: Optional[str] = None,
    is_active: bool = True
) -> list[User]:
    """Return users with the given active flag, optionally filtered by name."""
    with Session(engine) as session:
        stmt = select(User).where(User.is_active == is_active)
        if name:
            # Case-insensitive substring match (SQLAlchemy 2.0+).
            # autoescape=True escapes '%' and '_' so they match literally.
            stmt = stmt.where(User.name.icontains(name, autoescape=True))
        return list(session.scalars(stmt))


# ⚠️ CAREFUL: Raw SQL with text()
def raw_query_safe(table_name: str, user_id: int):
    """Even raw queries should use parameters."""
    # ❌ VULNERABLE: Don't interpolate unvalidated table names!
    # query = text(f"SELECT * FROM {table_name} WHERE id = :id")

    # ✅ SAFE: Whitelist table names - identifiers cannot be bound parameters
    allowed_tables = {'users', 'products', 'orders'}
    if table_name not in allowed_tables:
        raise ValueError(f"Invalid table: {table_name}")

    with Session(engine) as session:
        # Table name is validated, parameter is bound
        query = text(f"SELECT * FROM {table_name} WHERE id = :id")
        result = session.execute(query, {"id": user_id})
        return result.fetchone()
Dynamic Query Building (Safe Pattern)
python
from typing import Any, Optional
from enum import Enum


class SortOrder(Enum):
    ASC = "ASC"
    DESC = "DESC"


class UserQueryBuilder:
    """Safe dynamic query builder.

    Values always travel as named bound parameters. Identifiers (sort column,
    sort direction) cannot be bound, so they are checked against whitelists.
    """

    # Whitelist of allowed columns
    ALLOWED_COLUMNS = {'id', 'username', 'email', 'created_at', 'age'}
    ALLOWED_SORT_COLUMNS = {'id', 'username', 'created_at'}

    def __init__(self):
        self.conditions: list[str] = []
        self.params: dict[str, Any] = {}
        self.order_by: Optional[str] = None

    def filter_by_username(self, username: str) -> 'UserQueryBuilder':
        self.conditions.append("username = :username")
        self.params['username'] = username
        return self

    def filter_by_email_domain(self, domain: str) -> 'UserQueryBuilder':
        self.conditions.append("email LIKE :email_pattern")
        self.params['email_pattern'] = f"%@{domain}"
        return self

    def filter_by_age_range(
        self,
        min_age: Optional[int] = None,
        max_age: Optional[int] = None
    ) -> 'UserQueryBuilder':
        if min_age is not None:
            self.conditions.append("age >= :min_age")
            self.params['min_age'] = min_age
        if max_age is not None:
            self.conditions.append("age <= :max_age")
            self.params['max_age'] = max_age
        return self

    def sort_by(
        self,
        column: str,
        order: SortOrder = SortOrder.ASC
    ) -> 'UserQueryBuilder':
        # Whitelist validation for column names
        if column not in self.ALLOWED_SORT_COLUMNS:
            raise ValueError(f"Cannot sort by column: {column}")
        # Only SortOrder members reach the SQL text; anything else (e.g. the
        # string "DESC; DROP TABLE users") gets a clear ValueError
        if not isinstance(order, SortOrder):
            raise ValueError(f"Invalid sort order: {order!r}")
        self.order_by = f"{column} {order.value}"
        return self

    def build(self) -> tuple[str, dict[str, Any]]:
        """Return (sql, params); params is a copy the caller may mutate."""
        query = "SELECT id, username, email, age FROM users"
        if self.conditions:
            query += " WHERE " + " AND ".join(self.conditions)
        if self.order_by:
            query += f" ORDER BY {self.order_by}"
        return query, dict(self.params)
# Usage: every filter returns the builder, so calls can be chained
builder = UserQueryBuilder()
chain = builder.filter_by_email_domain("example.com").filter_by_age_range(min_age=18)
query, params = chain.sort_by("created_at", SortOrder.DESC).build()
# query: "SELECT ... WHERE email LIKE :email_pattern AND age >= :min_age ORDER BY created_at DESC"
# params: {'email_pattern': '%@example.com', 'min_age': 18}


# Command Injection Prevention
The Problem
python
import os
import subprocess


# ❌ VULNERABLE: Shell injection
def ping_host_vulnerable(host: str) -> str:
    """Ping a host - VULNERABLE to command injection."""
    shell_line = "ping -c 1 {}".format(host)
    exit_status = os.system(shell_line)  # NEVER DO THIS!
    return str(exit_status)

# Attack: host = "google.com; rm -rf /"
# Executes: ping -c 1 google.com; rm -rf /


# ❌ VULNERABLE: subprocess with shell=True
def get_file_info_vulnerable(filename: str) -> str:
    completed = subprocess.run(
        "ls -la {}".format(filename),
        shell=True,  # DANGEROUS!
        capture_output=True,
        text=True,
    )
    return completed.stdout

# Attack: filename = "file.txt; cat /etc/passwd"


# Safe Command Execution
python
import subprocess
import shlex
from typing import Optional
# ✅ SAFE: Use list of arguments (no shell)
def ping_host_safe(host: str, count: int = 1) -> tuple[bool, str]:
    """
    Ping a host safely.

    Args:
        host: Hostname or IPv4 address to ping
        count: Number of ping attempts (1-10)

    Returns:
        Tuple of (success, output)

    Raises:
        ValueError: If host or count fails validation
    """
    import re

    # Validate input
    if not host or not isinstance(host, str):
        raise ValueError("Invalid host")

    # Whitelist hostname characters. fullmatch (rather than match + '$',
    # which also matches before a trailing newline) checks the whole string.
    # The leading alphanumeric stops option injection such as "-f".
    # 253 is the maximum DNS name length.
    if len(host) > 253 or not re.fullmatch(r'[a-zA-Z0-9][a-zA-Z0-9\-\.]*', host):
        raise ValueError("Invalid hostname format")

    # Validate count (bool is a subclass of int, so reject it explicitly)
    if isinstance(count, bool) or not isinstance(count, int) or not 1 <= count <= 10:
        raise ValueError("Count must be between 1 and 10")

    try:
        result = subprocess.run(
            ["ping", "-c", str(count), host],  # List of args, no shell
            capture_output=True,
            text=True,
            timeout=30  # Prevent hanging
        )
    except subprocess.TimeoutExpired:
        return False, "Ping timed out"
    except OSError as e:
        # e.g. the ping binary is missing or not executable
        return False, str(e)
    return result.returncode == 0, result.stdout
# ✅ SAFE: shlex.quote for unavoidable shell usage
def run_with_shell_safe(command: str, user_input: str) -> str:
    """
    Run ``command`` through the shell with one shell-quoted user argument.

    Only ``user_input`` is escaped (shlex.quote); ``command`` is inserted
    verbatim, so it must come from trusted code, never from a request.
    Prefer subprocess with a list of arguments whenever a shell is not needed.
    """
    quoted_arg = shlex.quote(user_input)
    completed = subprocess.run(
        "{} {}".format(command, quoted_arg),
        shell=True,
        capture_output=True,
        text=True,
    )
    return completed.stdout
# Example: shlex.quote wraps the whole value in single quotes,
# so the shell sees one literal argument
user_file = "file.txt; rm -rf /"
safe = shlex.quote(user_file)
print(safe)  # 'file.txt; rm -rf /' (quoted, safe)


# Whitelist Approach
python
from enum import Enum
from typing import Optional
import subprocess
class AllowedCommand(Enum):
    """Whitelist of allowed commands."""
    PING = "ping"
    DIG = "dig"
    NSLOOKUP = "nslookup"
    TRACEROUTE = "traceroute"


def run_network_diagnostic(
    command: AllowedCommand,
    target: str,
    options: Optional[list[str]] = None
) -> tuple[int, str, str]:
    """
    Run network diagnostic command safely (no shell).

    Args:
        command: Whitelisted command to run
        target: Target hostname or IPv4 address
        options: Extra arguments. A flag that takes a value may be given as
            two items (["-c", "3"]), one item ("-c 3") or "flag=value"
            ("-type=MX"). Flags must be whitelisted for ``command``; values
            must be short tokens of letters, digits and dots.

    Returns:
        Tuple of (return_code, stdout, stderr)

    Raises:
        ValueError: If the command, target, an option or an option value is
            not allowed
    """
    import re

    if not isinstance(command, AllowedCommand):
        raise ValueError("Command must be an AllowedCommand member")

    # Validate target (hostname or IPv4 - the pattern covers dotted quads).
    # fullmatch rejects a trailing newline that '$' would let through; the
    # leading alphanumeric stops the target being parsed as an option.
    if (not isinstance(target, str) or len(target) > 253
            or not re.fullmatch(r'[a-zA-Z0-9][a-zA-Z0-9\-\.]*', target)):
        raise ValueError("Invalid target format")

    allowed_options = {
        AllowedCommand.PING: {'-c', '-W', '-i'},
        AllowedCommand.DIG: {'+short', '+trace', '-t'},
        AllowedCommand.NSLOOKUP: {'-type'},
        AllowedCommand.TRACEROUTE: {'-m', '-w'},
    }
    # Flags whose value is a separate argv item (e.g. "-c", "3")
    value_flags = {'-c', '-W', '-i', '-t', '-m', '-w'}
    # Counts, timeouts, intervals, record types: letters, digits and dots only
    value_re = re.compile(r'[A-Za-z0-9.]{1,16}')

    allowed = allowed_options[command]
    extra_args: list[str] = []
    items = iter(options or [])
    for opt in items:
        if '=' in opt:
            # "-type=MX": flag and value stay together in one argv item
            flag, value = opt.split('=', 1)
            if flag not in allowed:
                raise ValueError(f"Option {flag} not allowed for {command.value}")
            if not value_re.fullmatch(value):
                raise ValueError(f"Invalid value for option {flag}")
            extra_args.append(opt)
            continue

        flag, *inline = opt.split() or ['']
        if flag not in allowed:
            raise ValueError(f"Option {flag} not allowed for {command.value}")
        if flag in value_flags:
            # The value is either in the same item ("-c 3") or the next one
            if len(inline) > 1:
                raise ValueError(f"Option {flag} takes a single value")
            value = inline[0] if inline else next(items, None)
            if value is None or not value_re.fullmatch(value):
                raise ValueError(f"Invalid value for option {flag}")
            extra_args += [flag, value]
        elif inline:
            raise ValueError(f"Option {flag} does not take a value")
        else:
            extra_args.append(flag)

    # Options go before the target: BSD/macOS tools stop parsing flags at the
    # first positional argument
    cmd = [command.value, *extra_args, target]

    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        timeout=60
    )
    return result.returncode, result.stdout, result.stderr
# Usage
code, out, err = run_network_diagnostic(
    AllowedCommand.PING,
    "google.com",
    ["-c", "3"]
)


# Avoid Shell When Possible
python
import subprocess
from pathlib import Path
# ❌ AVOID: Using shell for file operations
def list_files_bad(directory: str) -> str:
return subprocess.run(
f"ls -la {directory}",
shell=True,
capture_output=True,
text=True
).stdout
# ✅ BETTER: Use Python's built-in libraries
def list_files_good(directory: str) -> list[dict]:
"""List files using pathlib - no shell needed."""
path = Path(directory)
# Validate path
if not path.exists():
raise ValueError("Directory does not exist")
if not path.is_dir():
raise ValueError("Path is not a directory")
# Prevent path traversal
try:
path = path.resolve()
# Optionally check if within allowed base directory
# base = Path("/allowed/base").resolve()
# if not str(path).startswith(str(base)):
# raise ValueError("Access denied")
except Exception:
raise ValueError("Invalid path")
files = []
for item in path.iterdir():
stat = item.stat()
files.append({
"name": item.name,
"size": stat.st_size,
"is_dir": item.is_dir(),
"modified": stat.st_mtime
})
return filesProduction Pitfalls
Pitfall 1: Validation Bypass via Type Coercion
python
from pydantic import BaseModel

# ❌ PITFALL: Pydantic's default ("lax") mode converts many inputs to bool
class UserInput(BaseModel):
    is_admin: bool

# A form checkbox or a sloppy client sends strings. Lax mode turns strings
# such as "1", "on", "yes", "true", "t", "y" (and the int 1) into True.
# (It is not naive truthiness: "false", "off", "no", "0" are parsed as False.)
user = UserInput(is_admin="on")
print(user.is_admin)  # True - a request string silently became a boolean flag

# ✅ FIX: Use strict mode so only real booleans are accepted
from pydantic import ConfigDict

class UserInputStrict(BaseModel):
    model_config = ConfigDict(strict=True)
    is_admin: bool

# UserInputStrict(is_admin="on")  # ValidationError!


# Pitfall 2: Incomplete Sanitization
python
import html

# ❌ BUG: Only escaping HTML, not preventing all XSS
def render_comment_bad(comment: str) -> str:
    safe = html.escape(comment)
    return f'<div onclick="handleClick(\'{safe}\')">{safe}</div>'

# Attack: comment = "'); alert('xss'); ('"
# html.escape turns ' into &#x27;, but the browser decodes entities in the
# attribute before running it as JavaScript, so the handler becomes:
#   handleClick(''); alert('xss'); ('')
# Still vulnerable in onclick context!

# ✅ FIX: Context-aware escaping - encode for each context, innermost first
import json

def render_comment_good(comment: str) -> str:
    """Render ``comment`` as element text and as a JS argument in onclick."""
    # Element content: HTML escaping is sufficient
    html_safe = html.escape(comment)
    # onclick is JavaScript inside an HTML attribute, so encode twice:
    # json.dumps makes a valid JS string literal (escapes quotes and
    # backslashes), then html.escape stops its double quotes from closing
    # the onclick="..." attribute. The browser undoes only the HTML layer.
    js_safe = html.escape(json.dumps(comment), quote=True)
    return f'<div onclick="handleClick({js_safe})">{html_safe}</div>'


# Pitfall 3: TOCTOU (Time-of-Check to Time-of-Use)
python
import os
from pathlib import Path

# ❌ BUG: Race condition between check and use
def read_file_bad(filepath: str) -> str:
    path = Path(filepath)
    # Check
    if not path.exists():
        raise ValueError("File not found")
    if not path.is_file():
        raise ValueError("Not a file")
    # Time passes... attacker could swap file with symlink here
    # Use
    return path.read_text()  # Could read different file!

# ✅ FIX: Confine to a base directory, then let open() be the only check
def read_file_good(filepath: str, base_dir: str = "/safe/dir") -> str:
    """
    Read file with TOCTOU protection.

    Args:
        filepath: Untrusted path, interpreted relative to base_dir
        base_dir: Directory the file must resolve into

    Returns:
        File contents

    Raises:
        ValueError: If the path escapes base_dir, does not exist, is a
            directory, became a symlink before opening, or is unreadable
    """
    import errno

    base = Path(base_dir).resolve()
    try:
        # resolve() collapses ".." and follows symlinks, so the containment
        # check below sees the real target
        resolved = (base / filepath).resolve()
    except (OSError, RuntimeError):
        raise ValueError("Invalid path") from None

    # Ensure within base directory. Compare path components, not string
    # prefixes: "/safe/dir2" starts with "/safe/dir" but lies outside it.
    if not resolved.is_relative_to(base):
        raise ValueError("Path traversal detected")

    def _no_follow_opener(path, flags):
        # O_NOFOLLOW (where the OS has it) refuses to open the file if its
        # last component was swapped for a symlink after resolve() ran.
        # Intermediate directories can still be raced; use OS-level
        # sandboxing if attackers control the directory tree.
        return os.open(path, flags | getattr(os, "O_NOFOLLOW", 0))

    # Open directly - no separate exists()/is_file() checks - and map errors
    try:
        with open(resolved, 'r', opener=_no_follow_opener) as f:
            return f.read()
    except FileNotFoundError:
        raise ValueError("File not found") from None
    except IsADirectoryError:
        raise ValueError("Path is a directory") from None
    except PermissionError:
        raise ValueError("Permission denied") from None
    except OSError as e:
        if e.errno == errno.ELOOP:  # O_NOFOLLOW hit a symlink
            raise ValueError("Symlinks not allowed") from None
        raise


# Pitfall 4: Regex DoS (ReDoS)
python
import re

# ❌ BUG: Catastrophic backtracking
evil_pattern = re.compile(r'^(a+)+$')
# Nested quantifiers let the engine try exponentially many ways to split the
# 'a's before failing on a near-miss input. This takes exponential time!
# evil_pattern.match('a' * 30 + 'b')  # Hangs!

# ✅ FIX: Use non-backtracking patterns or timeout
import signal

def regex_with_timeout(pattern: str, text: str, timeout: float = 1) -> bool:
    """
    Match regex with timeout protection.

    Unix only and main thread only: it uses SIGALRM via signal.setitimer,
    and Python lets only the main thread install signal handlers. Any
    interval timer the caller had armed is cancelled.

    Args:
        pattern: Regular expression, matched at the start of ``text``
        text: Input to match
        timeout: Seconds before giving up; fractions allowed

    Returns:
        True on a match; False on no match or on timeout (fail closed)
    """
    def handler(signum, frame):
        raise TimeoutError("Regex timeout")

    old_handler = signal.signal(signal.SIGALRM, handler)
    # setitimer accepts fractional seconds; signal.alarm() only takes ints
    signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        return bool(re.match(pattern, text))
    except TimeoutError:
        return False
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # disarm before restoring
        signal.signal(signal.SIGALRM, old_handler)

# Better: make the pattern itself safe
# - Python 3.11+ supports atomic groups (?>...) and possessive quantifiers (a++)
# - Or use the google-re2 library for guaranteed linear time


# Pitfall 5: Mass Assignment
python
from pydantic import BaseModel
from typing import Optional

# NOTE: `app`, `get_user` and `save_user` stand for your FastAPI app and
# data-access helpers; get_user is presumably returning a User instance.

class User(BaseModel):
    id: int
    username: str
    email: str
    is_admin: bool = False
    password_hash: str

# ❌ BUG: Accepting all fields from request
@app.put("/users/{user_id}")
async def update_user_bad(user_id: int, data: dict):
    user = get_user(user_id)
    for key, value in data.items():
        setattr(user, key, value)  # Can set is_admin (or password_hash)!
    save_user(user)

# ✅ FIX: Explicit update model
class UserUpdate(BaseModel):
    """Only fields that can be updated by user."""
    username: Optional[str] = None
    email: Optional[str] = None
    # is_admin NOT included - cannot be set by user. Unknown keys in the
    # request body are ignored under Pydantic's default model config.

@app.put("/users/{user_id}")
async def update_user_good(user_id: int, data: UserUpdate):
    """Apply only whitelisted fields that the client actually sent."""
    user = get_user(user_id)
    # exclude_unset: skip fields the client omitted.
    # exclude_none: skip explicit nulls, which setattr() would otherwise write
    # over required str fields without any validation.
    update_data = data.model_dump(exclude_unset=True, exclude_none=True)
    for key, value in update_data.items():
        setattr(user, key, value)
    save_user(user)


# Quick Reference
python
# Placeholders: cursor, user_id, User, host, user_input and filename stand
# for your own objects.

# === PYDANTIC VALIDATION ===
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator

class SafeInput(BaseModel):
    model_config = ConfigDict(strict=True)  # No type coercion
    name: str = Field(..., min_length=1, max_length=100)
    email: EmailStr

    @field_validator('name')
    @classmethod
    def validate_name(cls, v: str) -> str:
        # Custom validation
        return v

# === SQL INJECTION PREVENTION ===
# Always use parameterized queries
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
# SQLAlchemy ORM is safe by default
from sqlalchemy import select
stmt = select(User).where(User.id == user_id)

# === COMMAND INJECTION PREVENTION ===
# Use list of args, not shell
import subprocess
subprocess.run(["ping", "-c", "1", host])
# If shell needed, use shlex.quote
import shlex
safe_arg = shlex.quote(user_input)

# === SANITIZATION ===
import html
safe_html = html.escape(user_input)
# Filename sanitization: basename() only strips directories - also whitelist
# characters and extensions before using the name
import os
safe_name = os.path.basename(filename)


# Cross-links
- Prerequisites: dataclasses & Pydantic
- Related: Common Vulnerabilities - OWASP Top 10
- See Also: FastAPI - Request validation
- See Also: SQLAlchemy - Safe database queries