gt-ai-os-community/apps/control-panel-backend/app/services/message_dmz.py
HackWeasel b9dfb86260 GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 17:04:45 -05:00
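
Two of the items above (SSRF protection with DNS resolution checking, and hostname validation that replaces substring matching) follow the same pattern: parse the URL, validate the exact hostname, and resolve it before use, rather than checking whether a trusted string appears somewhere in the raw URL, which would also match hosts like trusted.example.com.attacker.net. Below is a minimal standard-library sketch of that idea, not the repository's actual implementation; the function and constant names are illustrative.

import ipaddress
import socket
from urllib.parse import urlparse

ALLOWED_SCHEMES = {"http", "https"}


def is_safe_outbound_url(url: str) -> bool:
    """Reject URLs whose hostname resolves to a private, loopback, or reserved address."""
    parsed = urlparse(url)
    if parsed.scheme not in ALLOWED_SCHEMES or not parsed.hostname:
        return False
    try:
        # Resolve the exact parsed hostname; never substring-match on the raw URL
        infos = socket.getaddrinfo(parsed.hostname, None)
    except socket.gaierror:
        return False
    for info in infos:
        ip_str = info[4][0].split("%")[0]  # drop any IPv6 zone id
        ip = ipaddress.ip_address(ip_str)
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            return False
    return True

In practice the resolved addresses would also be pinned for the actual outbound request to avoid DNS rebinding; that part is out of scope for this sketch.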

360 lines
12 KiB
Python

"""
Message DMZ Service for secure air-gap communication
Implements security controls for cross-cluster messaging including:
- Message validation and sanitization
- Command signature verification
- Audit logging
- Rate limiting
- Security policy enforcement
"""
import json
import logging
import hashlib
import hmac
import re
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List, Set
from collections import defaultdict
import asyncio

from app.core.config import settings
from app.schemas.messages import CommandType, AlertSeverity

logger = logging.getLogger(__name__)


class SecurityViolation(Exception):
    """Raised when a security policy is violated"""
    pass


class MessageDMZ:
    """
    Security DMZ for message bus communication

    Provides defense-in-depth security controls for cross-cluster messaging
    """

    def __init__(self):
        # Rate limiting
        self.rate_limits: Dict[str, List[datetime]] = defaultdict(list)
        self.rate_limit_window = timedelta(minutes=1)
        self.max_messages_per_minute = 100

        # Command whitelist
        self.allowed_commands = set(CommandType)

        # Blocked patterns (for detecting potential injection attacks)
        self.blocked_patterns = [
            r'<script[^>]*>.*?</script>',  # XSS
            r'javascript:',  # JavaScript URI
            r'on\w+\s*=',  # Event handlers
            r'DROP\s+TABLE',  # SQL injection
            r'DELETE\s+FROM',  # SQL injection
            r'INSERT\s+INTO',  # SQL injection
            r'UPDATE\s+SET',  # SQL injection
            r'--',  # SQL comment
            r'/\*.*\*/',  # SQL block comment
            r'\.\./+',  # Path traversal
            r'\\x[0-9a-fA-F]{2}',  # Hex encoding
            r'%[0-9a-fA-F]{2}',  # URL encoding suspicious patterns
        ]

        # Audit log
        self.audit_log: List[Dict[str, Any]] = []
        self.max_audit_entries = 10000

        # Security metrics
        self.metrics = {
            'messages_validated': 0,
            'messages_rejected': 0,
            'signature_failures': 0,
            'rate_limit_violations': 0,
            'injection_attempts': 0,
        }

    async def validate_incoming_message(
        self,
        message: Dict[str, Any],
        source: str
    ) -> Dict[str, Any]:
        """
        Validate incoming message from another cluster

        Args:
            message: Raw message data
            source: Source cluster identifier

        Returns:
            Validated and sanitized message

        Raises:
            SecurityViolation: If message fails validation
        """
        try:
            # Check rate limits
            if not self._check_rate_limit(source):
                self.metrics['rate_limit_violations'] += 1
                raise SecurityViolation(f"Rate limit exceeded for source: {source}")

            # Verify required fields
            required_fields = ['command_id', 'command_type', 'timestamp', 'signature']
            for field in required_fields:
                if field not in message:
                    raise SecurityViolation(f"Missing required field: {field}")

            # Verify timestamp (prevent replay attacks)
            if not self._verify_timestamp(message['timestamp']):
                raise SecurityViolation("Message timestamp is too old or invalid")

            # Verify command type is allowed
            if message['command_type'] not in self.allowed_commands:
                raise SecurityViolation(f"Unknown command type: {message['command_type']}")

            # Verify signature
            if not self._verify_signature(message):
                self.metrics['signature_failures'] += 1
                raise SecurityViolation("Invalid message signature")

            # Sanitize payload
            if 'payload' in message:
                message['payload'] = self._sanitize_payload(message['payload'])

            # Log successful validation
            self._audit_log('message_validated', source, message['command_id'])
            self.metrics['messages_validated'] += 1

            return message

        except SecurityViolation:
            self.metrics['messages_rejected'] += 1
            self._audit_log('message_rejected', source, message.get('command_id', 'unknown'))
            raise
        except Exception as e:
            logger.error(f"Unexpected error validating message: {e}")
            self.metrics['messages_rejected'] += 1
            raise SecurityViolation(f"Message validation failed: {str(e)}")

    async def prepare_outgoing_message(
        self,
        command_type: str,
        payload: Dict[str, Any],
        target: str
    ) -> Dict[str, Any]:
        """
        Prepare message for sending to another cluster

        Args:
            command_type: Type of command
            payload: Command payload
            target: Target cluster identifier

        Returns:
            Prepared and signed message
        """
        # Sanitize payload
        sanitized_payload = self._sanitize_payload(payload)

        # Create message structure
        message = {
            'command_type': command_type,
            'payload': sanitized_payload,
            'target_cluster': target,
            'timestamp': datetime.utcnow().isoformat(),
            'source': 'admin_cluster'
        }

        # Sign message
        signature = self._create_signature(message)
        message['signature'] = signature

        # Audit log
        self._audit_log('message_prepared', target, command_type)

        return message

    def _check_rate_limit(self, source: str) -> bool:
        """Check if source has exceeded rate limits"""
        now = datetime.utcnow()

        # Clean old entries
        cutoff = now - self.rate_limit_window
        self.rate_limits[source] = [
            ts for ts in self.rate_limits[source]
            if ts > cutoff
        ]

        # Check limit
        if len(self.rate_limits[source]) >= self.max_messages_per_minute:
            return False

        # Add current timestamp
        self.rate_limits[source].append(now)
        return True

    def _verify_timestamp(self, timestamp_str: str, max_age_seconds: int = 300) -> bool:
        """Verify message timestamp is recent (prevent replay attacks)"""
        try:
            timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
            age = (datetime.utcnow() - timestamp.replace(tzinfo=None)).total_seconds()

            # Message too old
            if age > max_age_seconds:
                return False

            # Message from future (clock skew tolerance of 30 seconds)
            if age < -30:
                return False

            return True
        except (ValueError, AttributeError):
            return False

    def _verify_signature(self, message: Dict[str, Any]) -> bool:
        """Verify message signature"""
        signature = message.get('signature', '')

        # Create message to verify (exclude signature field)
        message_copy = {k: v for k, v in message.items() if k != 'signature'}
        message_json = json.dumps(message_copy, sort_keys=True)

        # Verify signature
        expected_signature = hmac.new(
            settings.MESSAGE_BUS_SECRET_KEY.encode(),
            message_json.encode(),
            hashlib.sha256
        ).hexdigest()

        return hmac.compare_digest(signature, expected_signature)

    def _create_signature(self, message: Dict[str, Any]) -> str:
        """Create message signature"""
        message_json = json.dumps(message, sort_keys=True)
        return hmac.new(
            settings.MESSAGE_BUS_SECRET_KEY.encode(),
            message_json.encode(),
            hashlib.sha256
        ).hexdigest()

    def _sanitize_payload(self, payload: Any) -> Any:
        """
        Sanitize payload to prevent injection attacks

        Recursively sanitizes strings in dictionaries and lists
        """
        if isinstance(payload, str):
            # Check for blocked patterns
            for pattern in self.blocked_patterns:
                if re.search(pattern, payload, re.IGNORECASE):
                    self.metrics['injection_attempts'] += 1
                    raise SecurityViolation(f"Potential injection attempt detected")

            # Basic sanitization
            # Remove control characters except standard whitespace
            sanitized = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]', '', payload)

            # Limit string length
            max_length = 10000
            if len(sanitized) > max_length:
                sanitized = sanitized[:max_length]

            return sanitized

        elif isinstance(payload, dict):
            return {
                self._sanitize_payload(k): self._sanitize_payload(v)
                for k, v in payload.items()
            }

        elif isinstance(payload, list):
            return [self._sanitize_payload(item) for item in payload]

        else:
            # Numbers, booleans, None are safe
            return payload

    def _audit_log(
        self,
        event_type: str,
        target: str,
        details: Any
    ) -> None:
        """Add entry to audit log"""
        entry = {
            'timestamp': datetime.utcnow().isoformat(),
            'event_type': event_type,
            'target': target,
            'details': details
        }

        self.audit_log.append(entry)

        # Rotate log if too large
        if len(self.audit_log) > self.max_audit_entries:
            self.audit_log = self.audit_log[-self.max_audit_entries:]

        # Log to application logger
        logger.info(f"DMZ Audit: {event_type} - Target: {target} - Details: {details}")

    def get_security_metrics(self) -> Dict[str, Any]:
        """Get security metrics"""
        return {
            **self.metrics,
            'audit_log_size': len(self.audit_log),
            'rate_limited_sources': len(self.rate_limits),
            'timestamp': datetime.utcnow().isoformat()
        }

    def get_audit_log(
        self,
        limit: int = 100,
        event_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Get audit log entries"""
        logs = self.audit_log[-limit:]

        if event_type:
            logs = [log for log in logs if log['event_type'] == event_type]

        return logs

    async def validate_command_permissions(
        self,
        command_type: str,
        user_id: int,
        user_type: str,
        tenant_id: Optional[int] = None
    ) -> bool:
        """
        Validate user has permission to execute command

        Args:
            command_type: Type of command
            user_id: User ID
            user_type: User type (super_admin, tenant_admin, tenant_user)
            tenant_id: Tenant ID (for tenant-scoped commands)

        Returns:
            True if user has permission, False otherwise
        """
        # Super admins can execute all commands
        if user_type == 'super_admin':
            return True

        # Tenant admins can execute tenant-scoped commands for their tenant
        if user_type == 'tenant_admin' and tenant_id:
            tenant_commands = [
                CommandType.USER_CREATE,
                CommandType.USER_UPDATE,
                CommandType.USER_SUSPEND,
                CommandType.RESOURCE_ASSIGN,
                CommandType.RESOURCE_UNASSIGN
            ]
            return command_type in tenant_commands

        # Regular users cannot execute admin commands
        return False


# Global DMZ instance
message_dmz = MessageDMZ()
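
For reference, a minimal usage sketch of the module above, assuming it is imported from its package path and that MESSAGE_BUS_SECRET_KEY is configured on app.core.config.settings; the command type, payload, and cluster identifiers below are illustrative, not values defined by the repository.

import asyncio

from app.services.message_dmz import SecurityViolation, message_dmz


async def demo() -> None:
    # Sign an outgoing command: the payload is sanitized and an
    # HMAC-SHA256 signature over the key-sorted JSON body is attached.
    outgoing = await message_dmz.prepare_outgoing_message(
        command_type="user_create",      # illustrative command type
        payload={"username": "alice"},   # illustrative payload
        target="tenant_cluster_1",       # illustrative cluster id
    )
    print("signed:", outgoing["signature"][:16], "...")

    # A message missing required fields is rejected, counted in the
    # security metrics, and written to the audit log.
    try:
        await message_dmz.validate_incoming_message({}, source="tenant_cluster_1")
    except SecurityViolation as exc:
        print("rejected:", exc)

    print(message_dmz.get_security_metrics())


asyncio.run(demo())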