GT AI OS Community Edition v2.0.33

Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Committed by HackWeasel on 2025-12-12 17:04:45 -05:00
Commit: b9dfb86260
746 changed files with 232071 additions and 0 deletions

View File

@@ -0,0 +1,131 @@
"""
GT 2.0 Tenant Backend - CB-REST API Standards Integration
This module integrates the CB-REST standards into the Tenant backend
"""
import os
import sys
from pathlib import Path
# Add the api-standards package to the path
api_standards_path = Path(__file__).parent.parent.parent.parent.parent / "packages" / "api-standards" / "src"
if api_standards_path.exists():
sys.path.insert(0, str(api_standards_path))
# Import CB-REST standards
try:
from response import StandardResponse, format_response, format_error
from capability import (
init_capability_verifier,
verify_capability,
require_capability,
Capability,
CapabilityToken
)
from errors import ErrorCode, APIError, raise_api_error
from middleware import (
RequestCorrelationMiddleware,
CapabilityMiddleware,
TenantIsolationMiddleware,
RateLimitMiddleware
)
except ImportError as e:
# Fallback for development - create minimal implementations
print(f"Warning: Could not import api-standards package: {e}")
# Create minimal implementations for development
class StandardResponse:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def format_response(data, capability_used, request_id=None):
return {
"data": data,
"error": None,
"capability_used": capability_used,
"request_id": request_id or "dev-mode"
}
def format_error(code, message, capability_used="none", **kwargs):
return {
"data": None,
"error": {
"code": code,
"message": message,
**kwargs
},
"capability_used": capability_used,
"request_id": kwargs.get("request_id", "dev-mode")
}
class ErrorCode:
CAPABILITY_INSUFFICIENT = "CAPABILITY_INSUFFICIENT"
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"
INVALID_REQUEST = "INVALID_REQUEST"
SYSTEM_ERROR = "SYSTEM_ERROR"
TENANT_ISOLATION_VIOLATION = "TENANT_ISOLATION_VIOLATION"
class APIError(Exception):
def __init__(self, code, message, **kwargs):
self.code = code
self.message = message
self.kwargs = kwargs
super().__init__(message)
# Export all CB-REST components
__all__ = [
'StandardResponse',
'format_response',
'format_error',
'init_capability_verifier',
'verify_capability',
'require_capability',
'Capability',
'CapabilityToken',
'ErrorCode',
'APIError',
'raise_api_error',
'RequestCorrelationMiddleware',
'CapabilityMiddleware',
'TenantIsolationMiddleware',
'RateLimitMiddleware'
]
def setup_api_standards(app, secret_key: str, tenant_id: str):
"""
Setup CB-REST API standards for the tenant application
Args:
app: FastAPI application instance
secret_key: Secret key for JWT signing
tenant_id: Tenant identifier for isolation
"""
# Initialize capability verifier
if 'init_capability_verifier' in globals():
init_capability_verifier(secret_key)
# Add middleware in correct order
if 'RequestCorrelationMiddleware' in globals():
app.add_middleware(RequestCorrelationMiddleware)
if 'RateLimitMiddleware' in globals():
app.add_middleware(
RateLimitMiddleware,
requests_per_minute=100 # Per-tenant rate limiting
)
if 'TenantIsolationMiddleware' in globals():
app.add_middleware(
TenantIsolationMiddleware,
tenant_id=tenant_id,
enforce_isolation=True
)
if 'CapabilityMiddleware' in globals():
app.add_middleware(
CapabilityMiddleware,
exclude_paths=["/health", "/ready", "/metrics", "/api/v1/auth/login"]
)
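
A minimal, hypothetical wiring sketch (not part of the committed file) showing how setup_api_standards could be attached to a FastAPI app; the import path and environment variable names are assumptions.

import os
from fastapi import FastAPI
from app.core.api_standards import setup_api_standards  # assumed module path

app = FastAPI(title="GT 2.0 Tenant Backend")

# In the real service these values come from Settings; env vars are shown for brevity.
setup_api_standards(
    app,
    secret_key=os.environ.get("SECRET_KEY", "dev-only-secret"),
    tenant_id=os.environ.get("TENANT_ID", "tenant-001"),
)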

View File

@@ -0,0 +1,162 @@
"""
Composite ASGI Router for GT 2.0 Tenant Backend
Handles routing between FastAPI and Socket.IO applications to prevent
ASGI protocol conflicts while maintaining both WebSocket systems.
Architecture:
- `/socket.io/*` → Socket.IO ASGIApp (agentic real-time features)
- All other paths → FastAPI app (REST API, native WebSocket)
"""
import logging
from typing import Dict, Any, Callable, Awaitable
logger = logging.getLogger(__name__)
class CompositeASGIRouter:
"""
ASGI router that handles both FastAPI and Socket.IO applications
without protocol conflicts.
"""
def __init__(self, fastapi_app, socketio_app):
"""
Initialize composite router with both applications.
Args:
fastapi_app: FastAPI application instance
socketio_app: Socket.IO ASGIApp instance
"""
self.fastapi_app = fastapi_app
self.socketio_app = socketio_app
logger.info("Composite ASGI router initialized for FastAPI + Socket.IO")
async def __call__(self, scope: Dict[str, Any], receive: Callable, send: Callable) -> None:
"""
ASGI application entry point that routes requests based on path.
Args:
scope: ASGI scope containing request information
receive: ASGI receive callable
send: ASGI send callable
"""
try:
# Extract path from scope
path = scope.get("path", "")
# Route based on path pattern
if self._is_socketio_path(path):
# Only log Socket.IO routing at DEBUG level for non-operational paths
if self._should_log_route(path):
logger.debug(f"Routing to Socket.IO: {path}")
await self.socketio_app(scope, receive, send)
else:
# Only log FastAPI routing at DEBUG level for non-operational paths
if self._should_log_route(path):
logger.debug(f"Routing to FastAPI: {path}")
await self.fastapi_app(scope, receive, send)
except Exception as e:
logger.error(f"Error in ASGI routing: {e}")
# Fallback to FastAPI for error handling
try:
await self.fastapi_app(scope, receive, send)
except Exception as fallback_error:
logger.error(f"Fallback routing also failed: {fallback_error}")
# Last resort: send basic error response
await self._send_error_response(scope, send)
def _is_socketio_path(self, path: str) -> bool:
"""
Determine if path should be routed to Socket.IO.
Args:
path: Request path
Returns:
True if path should go to Socket.IO, False for FastAPI
"""
socketio_patterns = [
"/socket.io/",
"/socket.io"
]
# Check if path starts with any Socket.IO pattern
for pattern in socketio_patterns:
if path.startswith(pattern):
return True
return False
def _should_log_route(self, path: str) -> bool:
"""
Determine if this path should be logged during routing.
Operational endpoints like health checks and metrics are excluded
to reduce log noise during normal operation.
Args:
path: Request path
Returns:
True if path should be logged, False for operational endpoints
"""
operational_endpoints = [
"/health",
"/ready",
"/metrics",
"/api/v1/health"
]
# Don't log operational endpoints
if any(path.startswith(endpoint) for endpoint in operational_endpoints):
return False
return True
async def _send_error_response(self, scope: Dict[str, Any], send: Callable) -> None:
"""
Send basic error response when both applications fail.
Args:
scope: ASGI scope
send: ASGI send callable
"""
try:
if scope["type"] == "http":
await send({
"type": "http.response.start",
"status": 500,
"headers": [
[b"content-type", b"application/json"],
[b"content-length", b"32"]  # must match the byte length of the JSON body below
]
})
await send({
"type": "http.response.body",
"body": b'{"error": "ASGI routing failed"}'
})
elif scope["type"] == "websocket":
await send({
"type": "websocket.close",
"code": 1011,
"reason": "ASGI routing failed"
})
except Exception as e:
logger.error(f"Failed to send error response: {e}")
def create_composite_asgi_app(fastapi_app, socketio_app):
"""
Factory function to create composite ASGI application.
Args:
fastapi_app: FastAPI application instance
socketio_app: Socket.IO ASGIApp instance
Returns:
CompositeASGIRouter instance
"""
return CompositeASGIRouter(fastapi_app, socketio_app)
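
A hypothetical composition sketch (not part of the committed file) showing the router in front of a FastAPI app and a python-socketio ASGIApp; the import path is an assumption.

import socketio
import uvicorn
from fastapi import FastAPI
from app.core.asgi_router import create_composite_asgi_app  # assumed module path

fastapi_app = FastAPI()
sio = socketio.AsyncServer(async_mode="asgi")
socketio_app = socketio.ASGIApp(sio)

# /socket.io/* is served by Socket.IO; every other path (REST, native WebSocket) by FastAPI.
app = create_composite_asgi_app(fastapi_app, socketio_app)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)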

View File

@@ -0,0 +1,202 @@
"""
Simple in-memory cache with TTL support for Gen Two performance optimization.
This module provides a thread-safe caching layer for expensive database queries
and API calls. Each Uvicorn worker maintains its own cache instance.
Key features:
- TTL-based expiration (configurable per-key)
- LRU eviction when cache reaches max size
- Thread-safe for concurrent request handling
- Pattern-based deletion for cache invalidation
Usage:
from app.core.cache import get_cache
cache = get_cache()
# Get cached value with 60-second TTL
cached_data = cache.get("agents_minimal_user123", ttl=60)
if not cached_data:
data = await fetch_from_db()
cache.set("agents_minimal_user123", data)
"""
from typing import Any, Optional, Dict, Tuple
from datetime import datetime, timedelta
from threading import Lock
import logging
logger = logging.getLogger(__name__)
class SimpleCache:
"""
Thread-safe TTL cache for API responses and database query results.
This cache is per-worker (each Uvicorn worker maintains separate cache).
Cache keys should include tenant_domain or user_id for proper isolation.
Attributes:
max_entries: Maximum number of cache entries before LRU eviction
_cache: Internal cache storage (key -> (timestamp, data))
_lock: Thread lock for safe concurrent access
"""
def __init__(self, max_entries: int = 1000):
"""
Initialize cache with maximum entry limit.
Args:
max_entries: Maximum cache entries (default 1000)
Typical: 200KB per agent list × 1000 = 200MB per worker
"""
self._cache: Dict[str, Tuple[datetime, Any]] = {}
self._lock = Lock()
self._max_entries = max_entries
self._hits = 0
self._misses = 0
logger.info(f"SimpleCache initialized with max_entries={max_entries}")
def get(self, key: str, ttl: int = 60) -> Optional[Any]:
"""
Get cached value if not expired.
Args:
key: Cache key (should include tenant/user for isolation)
ttl: Time-to-live in seconds (default 60)
Returns:
Cached data if found and not expired, None otherwise
Example:
data = cache.get("agents_minimal_user123", ttl=60)
if data is None:
# Cache miss - fetch from database
data = await fetch_from_db()
cache.set("agents_minimal_user123", data)
"""
with self._lock:
if key not in self._cache:
self._misses += 1
logger.debug(f"Cache miss: {key}")
return None
timestamp, data = self._cache[key]
age = (datetime.utcnow() - timestamp).total_seconds()
if age > ttl:
# Expired - remove and return None
del self._cache[key]
self._misses += 1
logger.debug(f"Cache expired: {key} (age={age:.1f}s, ttl={ttl}s)")
return None
self._hits += 1
logger.debug(f"Cache hit: {key} (age={age:.1f}s, ttl={ttl}s)")
return data
def set(self, key: str, data: Any) -> None:
"""
Set cache value with current timestamp.
Args:
key: Cache key
data: Data to cache (should be JSON-serializable)
Note:
If cache is full, oldest entry is evicted (LRU)
"""
with self._lock:
# LRU eviction if cache full
if len(self._cache) >= self._max_entries:
oldest_key = min(self._cache.items(), key=lambda x: x[1][0])[0]
del self._cache[oldest_key]
logger.warning(
f"Cache full ({self._max_entries} entries), "
f"evicted oldest key: {oldest_key}"
)
self._cache[key] = (datetime.utcnow(), data)
logger.debug(f"Cache set: {key} (total entries: {len(self._cache)})")
def delete(self, pattern: str) -> int:
"""
Delete all keys matching pattern (prefix match).
Args:
pattern: Key prefix to match (e.g., "agents_minimal_")
Returns:
Number of keys deleted
Example:
# Delete all agent cache entries for a user
count = cache.delete(f"agents_minimal_{user_id}")
count += cache.delete(f"agents_summary_{user_id}")
"""
with self._lock:
keys_to_delete = [k for k in self._cache.keys() if k.startswith(pattern)]
for k in keys_to_delete:
del self._cache[k]
if keys_to_delete:
logger.info(f"Cache invalidated {len(keys_to_delete)} entries matching '{pattern}'")
return len(keys_to_delete)
def clear(self) -> None:
"""Clear entire cache (use with caution)."""
with self._lock:
entry_count = len(self._cache)
self._cache.clear()
self._hits = 0
self._misses = 0
logger.warning(f"Cache cleared (removed {entry_count} entries)")
def size(self) -> int:
"""Get number of cached entries."""
return len(self._cache)
def stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
Dict with size, hits, misses, hit_rate
"""
total_requests = self._hits + self._misses
hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0
return {
"size": len(self._cache),
"max_entries": self._max_entries,
"hits": self._hits,
"misses": self._misses,
"hit_rate_percent": round(hit_rate, 2),
}
# Singleton cache instance per worker
_cache: Optional[SimpleCache] = None
def get_cache() -> SimpleCache:
"""
Get or create singleton cache instance.
Each Uvicorn worker creates its own cache instance (isolated per-process).
Returns:
SimpleCache instance
"""
global _cache
if _cache is None:
_cache = SimpleCache(max_entries=1000)
return _cache
def clear_cache() -> None:
"""Clear global cache (for testing or emergency use)."""
cache = get_cache()
cache.clear()
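
A short cache-aside sketch using the import path given in the module docstring; fetch_agents_from_db is a hypothetical placeholder for the real database query.

from app.core.cache import get_cache

async def fetch_agents_from_db(user_id: str) -> list:
    ...  # placeholder for the real PostgreSQL query

async def list_agents_cached(user_id: str) -> list:
    cache = get_cache()
    key = f"agents_minimal_{user_id}"  # include user/tenant in the key for isolation
    data = cache.get(key, ttl=60)
    if data is None:
        data = await fetch_agents_from_db(user_id)  # cache miss: hit the database
        cache.set(key, data)
    return data

async def on_agent_updated(user_id: str) -> None:
    # Invalidate all cached agent listings for this user after a write.
    get_cache().delete(f"agents_minimal_{user_id}")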

View File

@@ -0,0 +1,380 @@
"""
GT 2.0 Tenant Backend - Capability Client
Generate JWT capability tokens for Resource Cluster API calls
"""
import json
import time
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from jose import jwt
from app.core.config import get_settings
import logging
import httpx
logger = logging.getLogger(__name__)
settings = get_settings()
class CapabilityClient:
"""Generates capability-based JWT tokens for Resource Cluster access"""
def __init__(self):
# Use tenant-specific secret key for token signing
self.secret_key = settings.secret_key
self.algorithm = "HS256"
self.issuer = f"gt2-tenant-{settings.tenant_id}"
self.http_client = httpx.AsyncClient(timeout=10.0)
self.control_panel_url = settings.control_panel_url
async def generate_capability_token(
self,
user_email: str,
tenant_id: str,
resources: List[str],
expires_hours: int = 24,
additional_claims: Optional[Dict[str, Any]] = None
) -> str:
"""
Generate a JWT capability token for Resource Cluster API access.
Args:
user_email: Email of the user making the request
tenant_id: Tenant identifier
resources: List of resource capabilities (e.g., ['external_services', 'rag_processing'])
expires_hours: Token expiration time in hours
additional_claims: Additional JWT claims to include
Returns:
Signed JWT token string
"""
now = datetime.utcnow()
expiry = now + timedelta(hours=expires_hours)
# Build capability token payload
payload = {
# Standard JWT claims
"iss": self.issuer, # Issuer
"sub": user_email, # Subject (user)
"aud": "gt2-resource-cluster", # Audience
"iat": int(now.timestamp()), # Issued at
"exp": int(expiry.timestamp()), # Expiration
"nbf": int(now.timestamp()), # Not before
"jti": f"{tenant_id}-{user_email}-{int(now.timestamp())}", # JWT ID
# GT 2.0 specific claims
"tenant_id": tenant_id,
"user_email": user_email,
"user_type": "tenant_user",
# Capability grants
"capabilities": await self._build_capabilities(resources, tenant_id, expiry),
# Security metadata
"capability_hash": self._generate_capability_hash(resources, tenant_id),
"token_version": "2.0",
"security_level": "standard"
}
# Add any additional claims
if additional_claims:
payload.update(additional_claims)
# Sign the token
try:
token = jwt.encode(
payload,
self.secret_key,
algorithm=self.algorithm
)
logger.info(
f"Generated capability token for {user_email} with resources: {resources}"
)
return token
except Exception as e:
logger.error(f"Failed to generate capability token: {e}")
raise RuntimeError(f"Token generation failed: {e}")
async def _build_capabilities(
self,
resources: List[str],
tenant_id: str,
expiry: datetime
) -> List[Dict[str, Any]]:
"""
Build capability grants for resources with constraints from Control Panel.
For LLM resources, fetches real rate limits from Control Panel API.
For other resources, uses default constraints.
"""
capabilities = []
for resource in resources:
capability = {
"resource": resource,
"actions": self._get_default_actions(resource),
"constraints": await self._get_constraints_for_resource(resource, tenant_id),
"valid_until": expiry.isoformat()
}
capabilities.append(capability)
return capabilities
async def _get_constraints_for_resource(
self,
resource: str,
tenant_id: str
) -> Dict[str, Any]:
"""
Get constraints for a resource, fetching from Control Panel for LLM resources.
GT 2.0 Principle: Single source of truth in database.
Fails fast if Control Panel is unreachable for LLM resources.
"""
# For LLM resources, fetch real config from Control Panel
if resource in ["llm", "llm_inference"]:
# Note: We don't have model_id at this point in the flow
# This is called during general capability token generation
# For now, return default constraints that will be overridden
# when model-specific tokens are generated
return self._get_default_constraints(resource)
# For non-LLM resources, use defaults
return self._get_default_constraints(resource)
async def _fetch_tenant_model_config(
self,
tenant_id: str,
model_id: str
) -> Optional[Dict[str, Any]]:
"""
Fetch tenant model configuration from Control Panel API.
Returns rate limits from database (single source of truth).
Fails fast if Control Panel is unreachable (no fallbacks).
Args:
tenant_id: Tenant identifier
model_id: Model identifier
Returns:
Model config with rate_limits, or None if not found
Raises:
RuntimeError: If Control Panel API is unreachable (fail fast)
"""
try:
url = f"{self.control_panel_url}/api/v1/tenant-models/tenants/{tenant_id}/models/{model_id}"
logger.debug(f"Fetching model config from Control Panel: {url}")
response = await self.http_client.get(url)
if response.status_code == 404:
logger.warning(f"Model {model_id} not configured for tenant {tenant_id}")
return None
response.raise_for_status()
config = response.json()
logger.info(f"Fetched model config for {model_id}: rate_limits={config.get('rate_limits')}")
return config
except httpx.HTTPStatusError as e:
logger.error(f"Control Panel API error: {e.response.status_code}")
raise RuntimeError(
f"Failed to fetch model config from Control Panel: HTTP {e.response.status_code}"
)
except httpx.RequestError as e:
logger.error(f"Control Panel API unreachable: {e}")
raise RuntimeError(
f"Control Panel API unreachable - cannot generate capability token. "
f"Ensure Control Panel is running at {self.control_panel_url}"
)
except Exception as e:
logger.error(f"Unexpected error fetching model config: {e}")
raise RuntimeError(f"Failed to fetch model config: {e}")
def _get_default_actions(self, resource: str) -> List[str]:
"""Get default actions for a resource type"""
action_mappings = {
"external_services": ["create", "read", "update", "delete", "health_check", "sso_token"],
"rag_processing": ["process_document", "generate_embeddings", "vector_search"],
"llm_inference": ["chat_completion", "streaming", "function_calling"],
"llm": ["execute"], # Use valid ActionType from resource cluster
"agent_orchestration": ["execute", "status", "interrupt"],
"ai_literacy": ["play_games", "solve_puzzles", "dialogue", "analytics"],
"app_integrations": ["read", "write", "webhook"],
"admin": ["all"],
# MCP Server Resources
"mcp:rag": ["search_datasets", "query_documents", "list_user_datasets", "get_dataset_info", "get_relevant_chunks"]
}
return action_mappings.get(resource, ["read"])
def _get_default_constraints(self, resource: str) -> Dict[str, Any]:
"""Get default constraints for a resource type"""
constraint_mappings = {
"external_services": {
"max_instances_per_user": 10,
"max_cpu_per_instance": "2000m",
"max_memory_per_instance": "4Gi",
"max_storage_per_instance": "50Gi",
"allowed_service_types": ["ctfd", "canvas", "guacamole"]
},
"rag_processing": {
"max_document_size_mb": 100,
"max_batch_size": 50,
"max_requests_per_hour": 1000
},
"llm_inference": {
"max_tokens_per_request": 4000,
"max_requests_per_hour": 100,
"allowed_models": [] # Models dynamically determined by admin backend
},
"llm": {
"max_tokens_per_request": 4000,
"max_requests_per_hour": 100,
"allowed_models": [] # Models dynamically determined by admin backend
},
"agent_orchestration": {
"max_concurrent_agents": 5,
"max_execution_time_minutes": 30
},
"ai_literacy": {
"max_sessions_per_day": 20,
"max_session_duration_hours": 4
},
"app_integrations": {
"max_api_calls_per_hour": 500,
"allowed_domains": ["api.example.com"]
},
# MCP Server Resources
"mcp:rag": {
"max_requests_per_hour": 500,
"max_results_per_query": 50
}
}
return constraint_mappings.get(resource, {})
def _generate_capability_hash(self, resources: List[str], tenant_id: str) -> str:
"""Generate a hash of the capabilities for verification"""
import hashlib
# Create a deterministic string from capabilities
capability_string = f"{tenant_id}:{':'.join(sorted(resources))}"
# Hash with SHA-256
hash_object = hashlib.sha256(capability_string.encode())
return hash_object.hexdigest()[:16] # First 16 characters
async def verify_capability_token(self, token: str) -> Dict[str, Any]:
"""
Verify and decode a capability token.
Args:
token: JWT token to verify
Returns:
Decoded token payload
Raises:
ValueError: If token is invalid or expired
"""
try:
# Decode and verify the token
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
audience="gt2-resource-cluster"
)
# Additional validation
if payload.get("iss") != self.issuer:
raise ValueError("Invalid token issuer")
# Check if token is still valid
now = datetime.utcnow()
if payload.get("exp", 0) < now.timestamp():
raise ValueError("Token has expired")
if payload.get("nbf", 0) > now.timestamp():
raise ValueError("Token not yet valid")
logger.debug(f"Verified capability token for user {payload.get('user_email')}")
return payload
except jwt.ExpiredSignatureError:
raise ValueError("Token has expired")
except jwt.JWTClaimsError as e:
raise ValueError(f"Token claims validation failed: {e}")
except jwt.JWTError as e:
raise ValueError(f"Token validation failed: {e}")
except Exception as e:
logger.error(f"Capability token verification failed: {e}")
raise ValueError(f"Invalid token: {e}")
async def refresh_capability_token(
self,
current_token: str,
extend_hours: int = 24
) -> str:
"""
Refresh an existing capability token with extended expiration.
Args:
current_token: Current JWT token
extend_hours: Hours to extend from now
Returns:
New JWT token with extended expiration
"""
# Verify current token
payload = await self.verify_capability_token(current_token)
# Extract current capabilities
resources = [cap.get("resource") for cap in payload.get("capabilities", [])]
# Generate new token with extended expiration
return await self.generate_capability_token(
user_email=payload.get("user_email"),
tenant_id=payload.get("tenant_id"),
resources=resources,
expires_hours=extend_hours
)
def get_token_info(self, token: str) -> Dict[str, Any]:
"""
Get information about a token without full verification.
Useful for debugging and logging.
"""
try:
# Decode without verification to get claims
payload = jwt.get_unverified_claims(token)
return {
"user_email": payload.get("user_email"),
"tenant_id": payload.get("tenant_id"),
"resources": [cap.get("resource") for cap in payload.get("capabilities", [])],
"expires_at": datetime.fromtimestamp(payload.get("exp", 0)).isoformat(),
"issued_at": datetime.fromtimestamp(payload.get("iat", 0)).isoformat(),
"token_version": payload.get("token_version"),
"security_level": payload.get("security_level")
}
except Exception as e:
logger.error(f"Failed to get token info: {e}")
return {"error": str(e)}
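
A hypothetical issuance/verification sketch (not part of the committed file); it assumes tenant settings (TENANT_ID, TENANT_DOMAIN, SECRET_KEY, etc.) are already configured, and the import paths are assumptions.

import asyncio
from app.core.capability_client import CapabilityClient  # assumed module path
from app.core.config import get_settings

async def main() -> None:
    settings = get_settings()
    client = CapabilityClient()
    token = await client.generate_capability_token(
        user_email="user@example.com",
        tenant_id=settings.tenant_id,  # keeps the issuer check in verify_capability_token consistent
        resources=["rag_processing", "llm_inference"],
        expires_hours=1,
    )
    payload = await client.verify_capability_token(token)
    print(payload["capabilities"])

asyncio.run(main())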

View File

@@ -0,0 +1,289 @@
"""
GT 2.0 Tenant Backend Configuration
Environment-based configuration for tenant applications with perfect isolation.
Each tenant gets its own isolated backend instance with separate database files.
"""
import os
from typing import List, Optional
from pydantic_settings import BaseSettings
from pydantic import Field, validator
class Settings(BaseSettings):
"""Application settings with environment variable support"""
# Environment
environment: str = Field(default="development", description="Runtime environment")
debug: bool = Field(default=False, description="Debug mode")
# Tenant Identification (Critical for isolation)
tenant_id: str = Field(..., description="Unique tenant identifier")
tenant_domain: str = Field(..., description="Tenant domain (e.g., customer1)")
# Database Configuration (PostgreSQL + PGVector direct connection)
database_url: str = Field(
default="postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres-primary:5432/gt2_tenants",
description="PostgreSQL connection URL (direct to primary)"
)
# PostgreSQL Configuration
postgres_schema: str = Field(
default="tenant_test",
description="PostgreSQL schema for tenant data (tenant_{tenant_domain})"
)
postgres_pool_size: int = Field(
default=10,
description="Connection pool size for PostgreSQL"
)
postgres_max_overflow: int = Field(
default=20,
description="Max overflow connections for PostgreSQL pool"
)
# Authentication & Security
secret_key: str = Field(..., description="JWT signing key")
algorithm: str = Field(default="HS256", description="JWT algorithm")
# OAuth2 Configuration
require_oauth2_auth: bool = Field(
default=True,
description="Require OAuth2 authentication for API endpoints"
)
oauth2_proxy_url: str = Field(
default="http://oauth2-proxy:4180",
description="Internal URL of OAuth2 Proxy service"
)
oauth2_issuer_url: str = Field(
default="https://auth.gt2.com",
description="OAuth2 provider issuer URL"
)
oauth2_audience: str = Field(
default="gt2-tenant-client",
description="OAuth2 token audience"
)
# Resource Cluster Integration
resource_cluster_url: str = Field(
default="http://localhost:8004",
description="URL of the Resource Cluster API"
)
resource_cluster_api_key: Optional[str] = Field(
default=None,
description="API key for Resource Cluster authentication"
)
# MCP Service Configuration
mcp_service_url: str = Field(
default="http://resource-cluster:8000",
description="URL of the MCP service for tool execution"
)
# Control Panel Integration
control_panel_url: str = Field(
default="http://localhost:8001",
description="URL of the Control Panel API"
)
service_auth_token: str = Field(
default="internal-service-token",
description="Service-to-service authentication token"
)
# WebSocket Configuration
websocket_ping_interval: int = Field(default=25, description="WebSocket ping interval")
websocket_ping_timeout: int = Field(default=20, description="WebSocket ping timeout")
# File Upload Configuration
max_file_size_mb: int = Field(default=10, description="Maximum file upload size in MB")
allowed_file_types: List[str] = Field(
default=[".pdf", ".docx", ".txt", ".md", ".csv", ".xlsx"],
description="Allowed file extensions for upload"
)
upload_directory: str = Field(
default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}/uploads" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/uploads",
description="Directory for uploaded files"
)
temp_directory: str = Field(
default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}/temp" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/temp",
description="Temporary directory for file processing"
)
file_storage_path: str = Field(
default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}",
description="Root directory for file storage (conversation files, etc.)"
)
# File Context Settings (for chat attachments)
max_chunks_per_file: int = Field(
default=50,
description="Maximum chunks per file (enforces diversity across files)"
)
max_total_file_chunks: int = Field(
default=100,
description="Maximum total chunks across all attached files"
)
file_context_token_safety_margin: float = Field(
default=0.05,
description="Safety margin for token budget calculations (0.05 = 5%)"
)
# Rate Limiting
rate_limit_requests: int = Field(default=1000, description="Requests per minute per IP")
rate_limit_window_seconds: int = Field(default=60, description="Rate limit window")
# CORS Configuration
cors_origins: List[str] = Field(
default=["http://localhost:3001", "http://localhost:3002", "https://*.gt2.com"],
description="Allowed CORS origins"
)
# Security
allowed_hosts: List[str] = Field(
default=["localhost", "*.gt2.com", "testserver", "gentwo-tenant-backend", "tenant-backend"],
description="Allowed host headers"
)
# Vector Storage Configuration (PGVector integrated with PostgreSQL)
vector_dimensions: int = Field(
default=384,
description="Vector dimensions for embeddings (all-MiniLM-L6-v2 model)"
)
embedding_model: str = Field(
default="all-MiniLM-L6-v2",
description="Embedding model for document processing"
)
vector_similarity_threshold: float = Field(
default=0.3,
description="Minimum similarity threshold for vector search"
)
# Legacy ChromaDB Configuration (DEPRECATED - replaced by PGVector)
chromadb_mode: str = Field(
default="disabled",
description="ChromaDB mode - DEPRECATED, using PGVector instead"
)
chromadb_host: str = Field(
default_factory=lambda: f"tenant-{os.getenv('TENANT_DOMAIN', 'test')}-chromadb",
description="ChromaDB host - DEPRECATED"
)
chromadb_port: int = Field(
default=8000,
description="ChromaDB HTTP port - DEPRECATED"
)
chromadb_path: str = Field(
default_factory=lambda: f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/chromadb",
description="ChromaDB file storage path - DEPRECATED"
)
# Redis removed - PostgreSQL handles all caching and session storage needs
# Logging Configuration
log_level: str = Field(default="INFO", description="Logging level")
log_format: str = Field(default="json", description="Log format: json or text")
# Performance
worker_processes: int = Field(default=1, description="Number of worker processes")
max_connections: int = Field(default=100, description="Maximum concurrent connections")
# Monitoring
prometheus_enabled: bool = Field(default=True, description="Enable Prometheus metrics")
prometheus_port: int = Field(default=9090, description="Prometheus metrics port")
# Feature Flags
enable_file_upload: bool = Field(default=True, description="Enable file upload feature")
enable_voice_input: bool = Field(default=False, description="Enable voice input (future)")
enable_document_analysis: bool = Field(default=True, description="Enable document analysis")
@validator("tenant_id")
def validate_tenant_id(cls, v):
if not v or len(v) < 3:
raise ValueError("Tenant ID must be at least 3 characters long")
return v
@validator("tenant_domain")
def validate_tenant_domain(cls, v):
if not v or not v.replace("-", "").replace("_", "").isalnum():
raise ValueError("Tenant domain must be alphanumeric with optional hyphens/underscores")
return v
@validator("upload_directory")
def validate_upload_directory(cls, v):
# Ensure the upload directory exists with secure permissions
os.makedirs(v, exist_ok=True, mode=0o700)
return v
model_config = {
"env_file": ".env",
"env_file_encoding": "utf-8",
"case_sensitive": False,
"extra": "ignore",
}
def get_settings(tenant_id: Optional[str] = None) -> Settings:
"""Get tenant-scoped application settings"""
# For development and testing, use simple settings without caching
if os.getenv("ENVIRONMENT") in ["development", "test"]:
return Settings()
# In production, settings should be tenant-scoped
# This prevents global state from affecting tenant isolation
if tenant_id:
# Create tenant-specific settings with proper isolation
settings = Settings()
# In production, this could load tenant-specific overrides
return settings
else:
# Default settings for non-tenant operations
return Settings()
# Security and isolation utilities
def get_tenant_data_path(tenant_domain: str) -> str:
"""Get the secure data path for a tenant"""
if os.getenv('ENVIRONMENT') == 'test':
return f"/tmp/gt2-data/{tenant_domain}"
return f"/data/{tenant_domain}"
def get_tenant_database_url(tenant_domain: str) -> str:
"""Get the database URL for a specific tenant (PostgreSQL)"""
return f"postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants"
def get_tenant_schema_name(tenant_domain: str) -> str:
"""Get the PostgreSQL schema name for a specific tenant"""
# Clean domain name for schema usage
clean_domain = tenant_domain.replace('-', '_').replace('.', '_').lower()
return f"tenant_{clean_domain}"
def ensure_tenant_isolation(tenant_id: str) -> None:
"""Ensure proper tenant isolation is configured"""
settings = get_settings()
if settings.tenant_id != tenant_id:
raise ValueError(f"Tenant ID mismatch: expected {settings.tenant_id}, got {tenant_id}")
# Verify database URL contains tenant identifier (Settings defines database_url; the old database_path field no longer exists)
if settings.tenant_domain not in settings.database_url:
raise ValueError("Database URL does not contain tenant identifier - isolation breach risk")
# Verify upload directory contains tenant identifier
if settings.tenant_domain not in settings.upload_directory:
raise ValueError("Upload directory does not contain tenant identifier - isolation breach risk")
# Development helpers
def is_development() -> bool:
"""Check if running in development mode"""
return get_settings().environment == "development"
def is_production() -> bool:
"""Check if running in production mode"""
return get_settings().environment == "production"
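
A hypothetical configuration sketch (not part of the committed file) showing Settings driven entirely by environment variables; the values are illustrative, and the import path is the one used by the sibling modules above.

import os
from app.core.config import get_settings, get_tenant_schema_name

# Required fields have no defaults, so they must come from the environment (or .env).
os.environ.setdefault("TENANT_ID", "tenant-001")
os.environ.setdefault("TENANT_DOMAIN", "customer1")
os.environ.setdefault("SECRET_KEY", "dev-only-secret")
os.environ.setdefault("ENVIRONMENT", "test")  # keeps upload/temp directories under /tmp

settings = get_settings()
print(settings.tenant_domain)                    # customer1
print(get_tenant_schema_name("customer-1.dev"))  # tenant_customer_1_dev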

View File

@@ -0,0 +1,131 @@
"""
GT 2.0 Tenant Backend Database Configuration - PostgreSQL + PGVector Client
Migrated from DuckDB service to PostgreSQL + PGVector for enterprise readiness:
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
- BionicGPT Row Level Security patterns for enterprise isolation
- MVCC concurrency solving DuckDB file locking issues
- Hybrid vector + full-text search in single queries
- Connection pooling for 10,000+ concurrent connections
"""
import os
import logging
from typing import Generator, Optional, Any, Dict, List
from contextlib import contextmanager, asynccontextmanager
from sqlalchemy.ext.declarative import declarative_base
from app.core.config import get_settings
from app.core.postgresql_client import (
get_postgresql_client, init_postgresql, close_postgresql,
get_db_session, execute_query, execute_command,
fetch_one, fetch_scalar, health_check, get_database_info
)
# Legacy DuckDB imports removed - PostgreSQL + PGVector only
# SQLAlchemy Base for ORM models
Base = declarative_base()
logger = logging.getLogger(__name__)
settings = get_settings()
# PostgreSQL client is managed by postgresql_client module
async def init_database() -> None:
"""Initialize PostgreSQL + PGVector connection"""
logger.info("Initializing PostgreSQL + PGVector database connection...")
try:
await init_postgresql()
logger.info("PostgreSQL + PGVector connection initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL database: {e}")
raise
async def close_database() -> None:
"""Close PostgreSQL connections"""
try:
await close_postgresql()
logger.info("PostgreSQL connections closed")
except Exception as e:
logger.error(f"Error closing PostgreSQL connections: {e}")
async def get_db_client_instance():
"""Get the PostgreSQL client instance"""
return await get_postgresql_client()
# get_db_session is imported from postgresql_client
# execute_query is imported from postgresql_client
# execute_command is imported from postgresql_client
async def execute_transaction(commands: List[Dict[str, Any]]) -> List[int]:
"""Execute multiple commands in a transaction (PostgreSQL format)"""
client = await get_postgresql_client()
pg_commands = [(cmd.get('query', cmd.get('command', '')), tuple(cmd.get('params', {}).values())) for cmd in commands]
return await client.execute_transaction(pg_commands)
# fetch_one is imported from postgresql_client
async def fetch_all(query: str, *args) -> List[Dict[str, Any]]:
"""Execute query and return all rows"""
return await execute_query(query, *args)
# fetch_scalar is imported from postgresql_client
# get_database_info is imported from postgresql_client
# health_check is imported from postgresql_client
# Legacy compatibility functions (for gradual migration)
def get_db() -> Generator[None, None, None]:
"""Legacy sync database dependency - deprecated"""
logger.warning("get_db() is deprecated. Use async get_db_session() instead")
# Return a dummy generator for compatibility
yield None
@contextmanager
def get_db_session_sync():
"""Legacy sync session - deprecated"""
logger.warning("get_db_session_sync() is deprecated. Use async get_db_session() instead")
yield None
def execute_raw_query(query: str, params: Optional[Dict] = None) -> List[Dict]:
"""Legacy sync query execution - deprecated"""
logger.error("execute_raw_query() is deprecated and not supported with PostgreSQL async client")
raise NotImplementedError("Use async execute_query() instead")
def verify_tenant_isolation() -> bool:
"""Verify tenant isolation - PostgreSQL schema-based isolation with RLS is always enabled"""
return True
# Initialize database on module import (for FastAPI startup)
async def startup_database():
"""Initialize database during FastAPI startup"""
await init_database()
async def shutdown_database():
"""Cleanup database during FastAPI shutdown"""
await close_database()
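
A hypothetical FastAPI lifespan sketch (not part of the committed file) wiring startup_database/shutdown_database; the import path is an assumption.

from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.core.database import startup_database, shutdown_database  # assumed module path

@asynccontextmanager
async def lifespan(app: FastAPI):
    await startup_database()       # opens the PostgreSQL + PGVector connection pool
    try:
        yield
    finally:
        await shutdown_database()  # closes the pool cleanly

app = FastAPI(lifespan=lifespan)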

View File

@@ -0,0 +1,348 @@
"""
GT 2.0 Database Interface - Legacy Database Abstraction (deprecated)
Provides a unified interface for database operations
following GT 2.0 principles of Zero Downtime, Perfect Tenant Isolation, and Elegant Simplicity.
Post-migration: DuckDB (which had earlier replaced SQLite) has itself been replaced by PostgreSQL + PGVector;
this interface is retained for compatibility only. Use postgresql_client.py directly for new code.
"""
import asyncio
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, AsyncGenerator, Union
from contextlib import asynccontextmanager
from dataclasses import dataclass
class DatabaseEngine(Enum):
"""Supported database engines - DEPRECATED: Use PostgreSQL directly"""
POSTGRESQL = "postgresql"
@dataclass
class DatabaseConfig:
"""Database configuration"""
engine: DatabaseEngine
database_path: str
tenant_id: str
shard_id: Optional[str] = None
encryption_key: Optional[str] = None
connection_params: Optional[Dict[str, Any]] = None
@dataclass
class QueryResult:
"""Standardized query result"""
rows: List[Dict[str, Any]]
row_count: int
columns: List[str]
execution_time_ms: float
class DatabaseInterface(ABC):
"""
Abstract database interface for GT 2.0 tenant isolation.
PostgreSQL is now the only supported engine, providing MVCC concurrency for
zero-downtime operations and perfect tenant isolation.
"""
def __init__(self, config: DatabaseConfig):
self.config = config
self.tenant_id = config.tenant_id
self.database_path = config.database_path
self.engine = config.engine
# Connection Management
@abstractmethod
async def initialize(self) -> None:
"""Initialize database connection and create tables"""
pass
@abstractmethod
async def close(self) -> None:
"""Close database connections"""
pass
@abstractmethod
async def is_initialized(self) -> bool:
"""Check if database is properly initialized"""
pass
@abstractmethod
@asynccontextmanager
async def get_session(self) -> AsyncGenerator[Any, None]:
"""Get database session context manager"""
pass
# Schema Management
@abstractmethod
async def create_tables(self) -> None:
"""Create all required tables"""
pass
@abstractmethod
async def get_schema_version(self) -> Optional[str]:
"""Get current database schema version"""
pass
@abstractmethod
async def migrate_schema(self, target_version: str) -> bool:
"""Migrate database schema to target version"""
pass
# Query Operations
@abstractmethod
async def execute_query(
self,
query: str,
params: Optional[Dict[str, Any]] = None
) -> QueryResult:
"""Execute SELECT query and return results"""
pass
@abstractmethod
async def execute_command(
self,
command: str,
params: Optional[Dict[str, Any]] = None
) -> int:
"""Execute INSERT/UPDATE/DELETE command and return affected rows"""
pass
@abstractmethod
async def execute_batch(
self,
commands: List[str],
params: Optional[List[Dict[str, Any]]] = None
) -> List[int]:
"""Execute batch commands in transaction"""
pass
# Transaction Management
@abstractmethod
@asynccontextmanager
async def transaction(self) -> AsyncGenerator[Any, None]:
"""Transaction context manager"""
pass
@abstractmethod
async def begin_transaction(self) -> Any:
"""Begin transaction and return transaction handle"""
pass
@abstractmethod
async def commit_transaction(self, tx: Any) -> None:
"""Commit transaction"""
pass
@abstractmethod
async def rollback_transaction(self, tx: Any) -> None:
"""Rollback transaction"""
pass
# Health and Monitoring
@abstractmethod
async def health_check(self) -> Dict[str, Any]:
"""Check database health and return status"""
pass
@abstractmethod
async def get_statistics(self) -> Dict[str, Any]:
"""Get database statistics"""
pass
@abstractmethod
async def optimize(self) -> bool:
"""Optimize database performance"""
pass
# Backup and Recovery
@abstractmethod
async def backup(self, backup_path: str) -> bool:
"""Create database backup"""
pass
@abstractmethod
async def restore(self, backup_path: str) -> bool:
"""Restore from database backup"""
pass
# Sharding Support
@abstractmethod
async def create_shard(self, shard_id: str) -> bool:
"""Create new database shard"""
pass
@abstractmethod
async def get_shard_info(self) -> Dict[str, Any]:
"""Get information about current shard"""
pass
@abstractmethod
async def migrate_to_shard(self, source_db: 'DatabaseInterface') -> bool:
"""Migrate data from another database instance"""
pass
# Vector Operations (PGVector integration)
@abstractmethod
async def store_embeddings(
self,
collection: str,
embeddings: List[List[float]],
documents: List[str],
metadata: List[Dict[str, Any]]
) -> bool:
"""Store embeddings with documents and metadata"""
pass
@abstractmethod
async def query_embeddings(
self,
collection: str,
query_embedding: List[float],
limit: int = 10,
filter_metadata: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""Query embeddings by similarity"""
pass
# Data Import/Export
@abstractmethod
async def export_data(
self,
format: str = "json",
tables: Optional[List[str]] = None
) -> Dict[str, Any]:
"""Export database data"""
pass
@abstractmethod
async def import_data(
self,
data: Dict[str, Any],
format: str = "json",
merge_strategy: str = "replace"
) -> bool:
"""Import database data"""
pass
# Security and Encryption
@abstractmethod
async def encrypt_database(self, encryption_key: str) -> bool:
"""Enable database encryption"""
pass
@abstractmethod
async def verify_encryption(self) -> bool:
"""Verify database encryption status"""
pass
# Performance and Indexing
@abstractmethod
async def create_index(
self,
table: str,
columns: List[str],
index_name: Optional[str] = None,
unique: bool = False
) -> bool:
"""Create database index"""
pass
@abstractmethod
async def drop_index(self, index_name: str) -> bool:
"""Drop database index"""
pass
@abstractmethod
async def analyze_queries(self) -> Dict[str, Any]:
"""Analyze query performance"""
pass
# Utility Methods
async def get_engine_info(self) -> Dict[str, Any]:
"""Get database engine information"""
return {
"engine": self.engine.value,
"tenant_id": self.tenant_id,
"database_path": self.database_path,
"shard_id": self.config.shard_id,
"supports_mvcc": self.engine == DatabaseEngine.POSTGRESQL,
"supports_sharding": self.engine == DatabaseEngine.POSTGRESQL,
"file_based": False  # PostgreSQL is server-based, not a single-file engine
}
async def validate_tenant_isolation(self) -> bool:
"""Validate that tenant isolation is maintained"""
try:
stats = await self.get_statistics()
return (
self.tenant_id in self.database_path and
stats.get("isolated", False)
)
except Exception:
return False
class DatabaseFactory:
"""Factory for creating database instances"""
@staticmethod
async def create_database(config: DatabaseConfig) -> DatabaseInterface:
"""Create database instance - PostgreSQL only"""
raise NotImplementedError("Database interface deprecated. Use PostgreSQL directly via postgresql_client.py")
@staticmethod
async def migrate_database(
source_config: DatabaseConfig,
target_config: DatabaseConfig,
migration_options: Optional[Dict[str, Any]] = None
) -> bool:
"""Migrate data from source to target database"""
source_db = await DatabaseFactory.create_database(source_config)
target_db = await DatabaseFactory.create_database(target_config)
try:
await source_db.initialize()
await target_db.initialize()
# Export data from source
data = await source_db.export_data()
# Import data to target
success = await target_db.import_data(data)
if success and migration_options and migration_options.get("verify", True):
# Verify migration
source_stats = await source_db.get_statistics()
target_stats = await target_db.get_statistics()
return source_stats.get("row_count", 0) == target_stats.get("row_count", 0)
return success
finally:
await source_db.close()
await target_db.close()
# Error Classes
class DatabaseError(Exception):
"""Base database error"""
pass
class DatabaseConnectionError(DatabaseError):
"""Database connection error"""
pass
class DatabaseMigrationError(DatabaseError):
"""Database migration error"""
pass
class DatabaseShardingError(DatabaseError):
"""Database sharding error"""
pass
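
A small sketch (not part of the committed file) illustrating the deprecation path: the factory refuses to build instances and points callers at postgresql_client.py. The import path and DSN are assumptions.

import asyncio
from app.core.database_interface import DatabaseConfig, DatabaseEngine, DatabaseFactory  # assumed module path

config = DatabaseConfig(
    engine=DatabaseEngine.POSTGRESQL,
    database_path="postgresql://user:password@tenant-postgres:5432/gt2_tenants",  # placeholder DSN
    tenant_id="tenant-001",
)

try:
    asyncio.run(DatabaseFactory.create_database(config))
except NotImplementedError as exc:
    print(exc)  # directs callers to postgresql_client.py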

View File

@@ -0,0 +1,265 @@
"""
Resource Access Control Dependencies for FastAPI
Provides declarative access control for agents and datasets using team-based permissions.
"""
from typing import Callable
from uuid import UUID
from fastapi import Depends, HTTPException
from app.api.dependencies import get_current_user
from app.services.team_service import TeamService
from app.core.permissions import get_user_role
from app.core.postgresql_client import get_postgresql_client
import logging
logger = logging.getLogger(__name__)
def require_resource_access(
resource_type: str,
required_permission: str = "read"
) -> Callable:
"""
FastAPI dependency factory for resource access control.
Creates a dependency that verifies user has required permission on a resource
via ownership, organization visibility, or team membership.
Args:
resource_type: 'agent' or 'dataset'
required_permission: 'read' or 'edit' (default: 'read')
Returns:
FastAPI dependency function
Usage:
@router.get("/agents/{agent_id}")
async def get_agent(
agent_id: str,
_: None = Depends(require_resource_access("agent", "read"))
):
# User has read access if we reach here
...
@router.put("/agents/{agent_id}")
async def update_agent(
agent_id: str,
_: None = Depends(require_resource_access("agent", "edit"))
):
# User has edit access if we reach here
...
"""
async def check_access(
resource_id: str,
current_user: dict = Depends(get_current_user)
) -> None:
"""
Verify user has required permission on resource.
Raises HTTPException(403) if access denied.
"""
user_id = current_user["user_id"]
tenant_domain = current_user["tenant_domain"]
user_email = current_user.get("email", user_id)
try:
pg_client = await get_postgresql_client()
# Check if admin/developer (bypass all checks)
user_role = await get_user_role(pg_client, user_email, tenant_domain)
if user_role in ["admin", "developer"]:
logger.debug(f"Admin/developer {user_id} has full access to {resource_type} {resource_id}")
return
# Check if user owns the resource
ownership_query = f"""
SELECT created_by FROM {resource_type}s
WHERE id = $1::uuid
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
"""
owner_id = await pg_client.fetch_scalar(ownership_query, resource_id, tenant_domain)
if owner_id and str(owner_id) == str(user_id):
logger.debug(f"User {user_id} owns {resource_type} {resource_id}")
return
# Check if resource is organization-wide
visibility_query = f"""
SELECT visibility FROM {resource_type}s
WHERE id = $1::uuid
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
"""
visibility = await pg_client.fetch_scalar(visibility_query, resource_id, tenant_domain)
if visibility == "organization":
logger.debug(f"{resource_type.capitalize()} {resource_id} is organization-wide")
return
# Check team-based access using TeamService
team_service = TeamService(tenant_domain, user_id, user_email)
has_permission = await team_service.check_user_resource_permission(
user_id=user_id,
resource_type=resource_type,
resource_id=resource_id,
required_permission=required_permission
)
if has_permission:
logger.debug(f"User {user_id} has {required_permission} permission on {resource_type} {resource_id} via team")
return
# Access denied
logger.warning(f"Access denied: User {user_id} cannot access {resource_type} {resource_id} (required: {required_permission})")
raise HTTPException(
status_code=403,
detail=f"You do not have {required_permission} permission for this {resource_type}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error checking resource access: {e}")
raise HTTPException(
status_code=500,
detail=f"Error verifying {resource_type} access"
)
return check_access
def require_agent_access(required_permission: str = "read") -> Callable:
"""
Convenience wrapper for agent access control.
Usage:
@router.get("/agents/{agent_id}")
async def get_agent(
agent_id: str,
_: None = Depends(require_agent_access("read"))
):
...
"""
return require_resource_access("agent", required_permission)
def require_dataset_access(required_permission: str = "read") -> Callable:
"""
Convenience wrapper for dataset access control.
Usage:
@router.get("/datasets/{dataset_id}")
async def get_dataset(
dataset_id: str,
_: None = Depends(require_dataset_access("read"))
):
...
"""
return require_resource_access("dataset", required_permission)
async def check_agent_edit_permission(
agent_id: str,
user_id: str,
tenant_domain: str,
user_email: str = None
) -> bool:
"""
Helper function to check if user can edit an agent.
Can be used in service layer without FastAPI dependency injection.
Args:
agent_id: UUID of the agent
user_id: UUID of the user
tenant_domain: Tenant domain
user_email: User email (optional)
Returns:
True if user can edit agent
"""
try:
pg_client = await get_postgresql_client()
# Check if admin/developer
user_role = await get_user_role(pg_client, user_email or user_id, tenant_domain)
if user_role in ["admin", "developer"]:
return True
# Check ownership
query = """
SELECT created_by FROM agents
WHERE id = $1::uuid
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
"""
owner_id = await pg_client.fetch_scalar(query, agent_id, tenant_domain)
if owner_id and str(owner_id) == str(user_id):
return True
# Check team edit permission
team_service = TeamService(tenant_domain, user_id, user_email or user_id)
return await team_service.check_user_resource_permission(
user_id=user_id,
resource_type="agent",
resource_id=agent_id,
required_permission="edit"
)
except Exception as e:
logger.error(f"Error checking agent edit permission: {e}")
return False
async def check_dataset_edit_permission(
dataset_id: str,
user_id: str,
tenant_domain: str,
user_email: str = None
) -> bool:
"""
Helper function to check if user can edit a dataset.
Can be used in service layer without FastAPI dependency injection.
Args:
dataset_id: UUID of the dataset
user_id: UUID of the user
tenant_domain: Tenant domain
user_email: User email (optional)
Returns:
True if user can edit dataset
"""
try:
pg_client = await get_postgresql_client()
# Check if admin/developer
user_role = await get_user_role(pg_client, user_email or user_id, tenant_domain)
if user_role in ["admin", "developer"]:
return True
# Check ownership
query = """
SELECT user_id FROM datasets
WHERE id = $1::uuid
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
"""
owner_id = await pg_client.fetch_scalar(query, dataset_id, tenant_domain)
if owner_id and str(owner_id) == str(user_id):
return True
# Check team edit permission
team_service = TeamService(tenant_domain, user_id, user_email or user_id)
return await team_service.check_user_resource_permission(
user_id=user_id,
resource_type="dataset",
resource_id=dataset_id,
required_permission="edit"
)
except Exception as e:
logger.error(f"Error checking dataset edit permission: {e}")
return False
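
A hypothetical service-layer sketch (not part of the committed file) using check_agent_edit_permission outside of dependency injection; the import path is an assumption.

from fastapi import HTTPException
from app.api.resource_access import check_agent_edit_permission  # assumed module path

async def rename_agent(agent_id: str, new_name: str, current_user: dict) -> None:
    allowed = await check_agent_edit_permission(
        agent_id=agent_id,
        user_id=current_user["user_id"],
        tenant_domain=current_user["tenant_domain"],
        user_email=current_user.get("email"),
    )
    if not allowed:
        raise HTTPException(status_code=403, detail="You do not have edit permission for this agent")
    # ... perform the rename via the PostgreSQL client ...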

View File

@@ -0,0 +1,169 @@
"""
GT 2.0 Tenant Backend Logging Configuration
Structured logging with tenant isolation and security considerations.
"""
import logging
import logging.config
import sys
from typing import Dict, Any
from app.core.config import get_settings
def setup_logging() -> None:
"""Setup logging configuration for the tenant backend"""
settings = get_settings()
# Determine log directory based on environment
if settings.environment == "test":
log_dir = f"/tmp/gt2-data/{settings.tenant_domain}/logs"
else:
log_dir = f"/data/{settings.tenant_domain}/logs"
# Create logging configuration
log_config: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
"json": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(funcName)s() - %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": settings.log_level,
"formatter": "json" if settings.log_format == "json" else "default",
"stream": sys.stdout,
},
"file": {
"class": "logging.handlers.RotatingFileHandler",
"level": "INFO",
"formatter": "json" if settings.log_format == "json" else "detailed",
"filename": f"{log_dir}/tenant-backend.log",
"maxBytes": 10485760, # 10MB
"backupCount": 5,
"encoding": "utf-8",
},
},
"loggers": {
"": { # Root logger
"level": settings.log_level,
"handlers": ["console"],
"propagate": False,
},
"app": {
"level": settings.log_level,
"handlers": ["console", "file"] if settings.environment == "production" else ["console"],
"propagate": False,
},
"sqlalchemy.engine": {
"level": "INFO" if settings.debug else "WARNING",
"handlers": ["console"],
"propagate": False,
},
"uvicorn.access": {
"level": "WARNING", # Suppress INFO level access logs (operational endpoints)
"handlers": ["console"],
"propagate": False,
},
"uvicorn.error": {
"level": "INFO",
"handlers": ["console"],
"propagate": False,
},
},
}
# Create log directory if it doesn't exist
import os
os.makedirs(log_dir, exist_ok=True, mode=0o700)
# Apply logging configuration
logging.config.dictConfig(log_config)
# Add tenant context to all logs
class TenantContextFilter(logging.Filter):
def filter(self, record):
record.tenant_id = settings.tenant_id
record.tenant_domain = settings.tenant_domain
return True
tenant_filter = TenantContextFilter()
# Add tenant filter to all handlers
for handler in logging.getLogger().handlers:
handler.addFilter(tenant_filter)
# Log startup information
logger = logging.getLogger("app.startup")
logger.info(
"Tenant backend logging initialized",
extra={
"tenant_id": settings.tenant_id,
"tenant_domain": settings.tenant_domain,
"environment": settings.environment,
"log_level": settings.log_level,
"log_format": settings.log_format,
}
)
def get_logger(name: str) -> logging.Logger:
"""Get logger with consistent naming and formatting"""
return logging.getLogger(f"app.{name}")
class SecurityRedactionFilter(logging.Filter):
"""Filter to redact sensitive information from logs"""
SENSITIVE_FIELDS = [
"password", "token", "secret", "key", "authorization",
"cookie", "session", "csrf", "api_key", "jwt"
]
def filter(self, record):
if hasattr(record, 'args') and record.args:
# Redact sensitive information from log messages
record.args = self._redact_sensitive_data(record.args)
if hasattr(record, 'msg') and isinstance(record.msg, str):
for field in self.SENSITIVE_FIELDS:
if field.lower() in record.msg.lower():
record.msg = record.msg.replace(field, "[REDACTED]")
return True
def _redact_sensitive_data(self, data):
"""Recursively redact sensitive data from log arguments"""
if isinstance(data, dict):
return {
key: "[REDACTED]" if any(sensitive in key.lower() for sensitive in self.SENSITIVE_FIELDS)
else self._redact_sensitive_data(value)
for key, value in data.items()
}
elif isinstance(data, (list, tuple)):
return type(data)(self._redact_sensitive_data(item) for item in data)
return data
def setup_security_logging():
"""Setup security-focused logging with redaction"""
security_filter = SecurityRedactionFilter()
# Add security filter to all loggers
for name in ["app", "uvicorn", "sqlalchemy"]:
logger = logging.getLogger(name)
logger.addFilter(security_filter)
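
A hypothetical startup sketch (not part of the committed file); it assumes tenant settings are already configured, and the import path is an assumption.

from app.core.logging import setup_logging, setup_security_logging, get_logger  # assumed module path

setup_logging()            # console/file handlers plus the tenant-context filter
setup_security_logging()   # redaction filter on app/uvicorn/sqlalchemy loggers

logger = get_logger("services.agents")  # resolves to the "app.services.agents" logger
logger.info("Agent service started")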

View File

@@ -0,0 +1,175 @@
"""
Path Security Utilities for GT AI OS
Provides path sanitization and validation to prevent path traversal attacks.
"""
import re
from pathlib import Path
from typing import Optional
def sanitize_path_component(component: str) -> str:
"""
Sanitize a single path component to prevent path traversal.
Removes or replaces dangerous characters including:
- Path separators (/ and \\)
- Parent directory references (..)
- Null bytes
- Other special characters
Args:
component: The path component to sanitize
Returns:
Sanitized component safe for use in file paths
"""
if not component:
return ""
# Remove null bytes
sanitized = component.replace('\x00', '')
# Remove path separators
sanitized = re.sub(r'[/\\]', '', sanitized)
# Remove parent directory references
sanitized = sanitized.replace('..', '')
# For tenant domains and similar identifiers, allow alphanumeric, hyphen, underscore
# For filenames, allow alphanumeric, hyphen, underscore, and single dots
sanitized = re.sub(r'[^a-zA-Z0-9_\-.]', '_', sanitized)
# Prevent leading dots (hidden files) and multiple consecutive dots
sanitized = re.sub(r'^\.+', '', sanitized)
sanitized = re.sub(r'\.{2,}', '.', sanitized)
return sanitized
def sanitize_tenant_domain(domain: str) -> str:
"""
Sanitize a tenant domain for safe use in file paths.
More restrictive than general path component sanitization.
Only allows lowercase alphanumeric characters, hyphens, and underscores.
Args:
domain: The tenant domain to sanitize
Returns:
Sanitized domain safe for use in file paths
"""
if not domain:
raise ValueError("Tenant domain cannot be empty")
# Convert to lowercase and sanitize
sanitized = domain.lower()
sanitized = re.sub(r'[^a-z0-9_\-]', '_', sanitized)
sanitized = sanitized.strip('_-')
if not sanitized:
raise ValueError("Tenant domain resulted in empty string after sanitization")
return sanitized
def sanitize_filename(filename: str) -> str:
"""
Sanitize a filename for safe storage.
Preserves the file extension but sanitizes the rest.
Args:
filename: The filename to sanitize
Returns:
Sanitized filename
"""
if not filename:
return ""
# Get the extension
path = Path(filename)
stem = path.stem
suffix = path.suffix
# Sanitize the stem (filename without extension)
safe_stem = sanitize_path_component(stem)
# Sanitize the extension (should just be alphanumeric)
safe_suffix = ""
if suffix:
safe_suffix = '.' + re.sub(r'[^a-zA-Z0-9]', '', suffix[1:])
result = safe_stem + safe_suffix
if not result:
result = "unnamed"
return result
def safe_join_path(base: Path, *components: str, require_within_base: bool = True) -> Path:
"""
Safely join path components, preventing traversal attacks.
Args:
base: The base directory that all paths must stay within
components: Path components to join to the base
require_within_base: If True, verify the result is within base
Returns:
The joined path
Raises:
ValueError: If the resulting path would be outside the base directory
"""
if not base:
raise ValueError("Base path cannot be empty")
# Sanitize all components
sanitized = [sanitize_path_component(c) for c in components if c]
# Filter out empty components
sanitized = [c for c in sanitized if c]
if not sanitized:
return base
# Join the path
result = base.joinpath(*sanitized)
# Verify the result is within the base directory
if require_within_base:
try:
resolved_base = base.resolve()
resolved_result = result.resolve()
# Check if result is within base
resolved_result.relative_to(resolved_base)
except (ValueError, RuntimeError):
raise ValueError(f"Path traversal detected: result would be outside base directory")
return result
def validate_file_extension(filename: str, allowed_extensions: Optional[list] = None) -> bool:
"""
Validate that a file has an allowed extension.
Args:
filename: The filename to check
allowed_extensions: List of allowed extensions (e.g., ['.txt', '.pdf']).
If None, all extensions are allowed.
Returns:
True if the extension is allowed, False otherwise
"""
if allowed_extensions is None:
return True
path = Path(filename)
extension = path.suffix.lower()
return extension in [ext.lower() for ext in allowed_extensions]
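# Illustrative sketch (hypothetical paths and filenames): building a per-tenant upload
# location with the helpers above. safe_join_path() raises ValueError if a crafted
# component would resolve outside the base directory.
def _example_upload_path() -> Path:
    base = Path("/var/lib/gt/uploads")
    tenant = sanitize_tenant_domain("Acme-Corp.example")   # -> "acme-corp_example"
    filename = sanitize_filename("..\\report.pdf")          # traversal stripped -> "report.pdf"
    return safe_join_path(base, tenant, filename)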

View File

@@ -0,0 +1,138 @@
"""
GT 2.0 Role-Based Permissions
Enforces organization-level resource sharing based on user roles.
Visibility Levels:
- individual: Only the creator can see and edit
- organization: All users can read, only admins/developers can create and edit
"""
from fastapi import HTTPException, status
import logging
logger = logging.getLogger(__name__)
# Role hierarchy: admin/developer > analyst > student
ADMIN_ROLES = ["admin", "developer"]
# Visibility levels
VISIBILITY_INDIVIDUAL = "individual"
VISIBILITY_ORGANIZATION = "organization"
async def get_user_role(pg_client, user_email: str, tenant_domain: str) -> str:
"""
Get the role for a user in the tenant database.
Returns: 'admin', 'developer', 'analyst', or 'student'
"""
query = """
SELECT role FROM users
WHERE email = $1
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
LIMIT 1
"""
role = await pg_client.fetch_scalar(query, user_email, tenant_domain)
return role or "student"
def can_share_to_organization(user_role: str) -> bool:
"""
Check if a user can share resources at the organization level.
Only admin and developer roles can share to organization.
"""
return user_role in ADMIN_ROLES
def validate_visibility_permission(visibility: str, user_role: str) -> None:
"""
Validate that the user has permission to set the given visibility level.
Raises HTTPException if not authorized.
Rules:
- admin/developer: Can set individual or organization visibility
- analyst/student: Can only set individual visibility
"""
if visibility == VISIBILITY_ORGANIZATION and not can_share_to_organization(user_role):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Only admin and developer users can share resources to organization. Your role: {user_role}"
)
def can_edit_resource(resource_creator_id: str, current_user_id: str, user_role: str, resource_visibility: str) -> bool:
"""
Check if user can edit a resource.
Rules:
- Owner can always edit their own resources
- Admin/developer can edit any resource
- Organization-shared resources: read-only for non-admins who didn't create it
"""
# Admin and developer can edit anything
if user_role in ADMIN_ROLES:
return True
# Owner can always edit
if resource_creator_id == current_user_id:
return True
# Organization resources are read-only for non-admins
return False
def can_delete_resource(resource_creator_id: str, current_user_id: str, user_role: str) -> bool:
"""
Check if user can delete a resource.
Rules:
- Owner can delete their own resources
- Admin/developer can delete any resource
- Others cannot delete
"""
# Admin and developer can delete anything
if user_role in ADMIN_ROLES:
return True
# Owner can delete
if resource_creator_id == current_user_id:
return True
return False
def is_effective_owner(resource_creator_id: str, current_user_id: str, user_role: str) -> bool:
"""
Check if user is effective owner of a resource.
Effective owners have identical access to actual owners:
- Actual resource creator
- Admin/developer users (tenant admins)
This determines whether user gets owner-level field visibility in ResponseFilter
and whether they can perform owner-only actions like sharing.
Note: Tenant isolation is enforced at query level via tenant_id checks.
This function only determines ownership semantics within the tenant.
Args:
resource_creator_id: UUID of resource creator
current_user_id: UUID of current user
user_role: User's role in tenant (admin, developer, analyst, student)
Returns:
True if user should have owner-level access
Examples:
>>> is_effective_owner("user123", "admin456", "admin")
True # Admin has owner-level access to all resources
>>> is_effective_owner("user123", "user123", "student")
True # Actual owner
>>> is_effective_owner("user123", "user456", "analyst")
False # Different user, not admin
"""
# Admins and developers have identical access to owners
if user_role in ADMIN_ROLES:
return True
# Actual owner
return resource_creator_id == current_user_id
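# Illustrative sketch: how an endpoint might combine these checks before saving a shared
# resource. The pg_client and identifiers are placeholders supplied by the caller;
# validate_visibility_permission() raises HTTP 403 unless the role is admin or developer.
async def _example_share_check(pg_client, user_email: str, tenant_domain: str) -> None:
    role = await get_user_role(pg_client, user_email, tenant_domain)
    validate_visibility_permission(VISIBILITY_ORGANIZATION, role)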

View File

@@ -0,0 +1,498 @@
"""
GT 2.0 PostgreSQL + PGVector Client for Tenant Backend
Replaces DuckDB service with direct PostgreSQL connections, providing:
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
- BionicGPT Row Level Security patterns for enterprise isolation
- MVCC concurrency solving DuckDB file locking issues
- Hybrid vector + full-text search in single queries
- Connection pooling for 10,000+ concurrent connections
"""
import asyncio
import logging
from typing import Dict, List, Optional, Any, AsyncGenerator, Tuple, Union
from contextlib import asynccontextmanager
import json
from datetime import datetime
from uuid import UUID
import asyncpg
from asyncpg import Pool, Connection
from asyncpg.exceptions import PostgresError
from app.core.config import get_settings, get_tenant_schema_name
logger = logging.getLogger(__name__)
class PostgreSQLClient:
"""PostgreSQL + PGVector client for tenant backend operations"""
def __init__(self, database_url: str, tenant_domain: str):
self.database_url = database_url
self.tenant_domain = tenant_domain
self.schema_name = get_tenant_schema_name(tenant_domain)
self._pool: Optional[Pool] = None
self._initialized = False
async def __aenter__(self):
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async def initialize(self) -> None:
"""Initialize connection pool and verify schema"""
if self._initialized:
return
logger.info(f"Initializing PostgreSQL connection pool for tenant: {self.tenant_domain}")
logger.info(f"Schema: {self.schema_name}, URL: {self.database_url}")
try:
# Create connection pool with resilient settings
# Sized for 100+ concurrent users with RAG/vector search workloads
self._pool = await asyncpg.create_pool(
self.database_url,
min_size=10,
max_size=50, # Increased from 20 to handle 100+ concurrent users
command_timeout=120, # Increased from 60s for queries under load
timeout=10, # Connection acquire timeout increased for high load
max_inactive_connection_lifetime=3600, # Recycle connections after 1 hour
server_settings={
'application_name': f'gt2_tenant_{self.tenant_domain}'
},
# Enable prepared statements for direct postgres connection (performance gain)
statement_cache_size=100
)
# Verify schema exists and has required tables
await self._verify_schema()
self._initialized = True
logger.info(f"PostgreSQL client initialized successfully for tenant: {self.tenant_domain}")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL client: {e}")
if self._pool:
await self._pool.close()
self._pool = None
raise
async def close(self) -> None:
"""Close connection pool"""
if self._pool:
await self._pool.close()
self._pool = None
self._initialized = False
logger.info(f"PostgreSQL connection pool closed for tenant: {self.tenant_domain}")
async def _verify_schema(self) -> None:
"""Verify tenant schema exists and has required tables"""
async with self._pool.acquire() as conn:
# Check if schema exists
schema_exists = await conn.fetchval("""
SELECT EXISTS (
SELECT 1 FROM information_schema.schemata
WHERE schema_name = $1
)
""", self.schema_name)
if not schema_exists:
raise RuntimeError(f"Tenant schema '{self.schema_name}' does not exist. Run schema initialization first.")
# Check for required tables
required_tables = ['tenants', 'users', 'agents', 'datasets', 'conversations', 'messages', 'documents', 'document_chunks']
for table in required_tables:
table_exists = await conn.fetchval(f"""
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = $1 AND table_name = $2
)
""", self.schema_name, table)
if not table_exists:
logger.warning(f"Table '{table}' not found in schema '{self.schema_name}'")
logger.info(f"Schema verification complete for tenant: {self.tenant_domain}")
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator[Connection, None]:
"""Get a connection from the pool"""
if not self._pool:
raise RuntimeError("PostgreSQL client not initialized. Call initialize() first.")
async with self._pool.acquire() as conn:
try:
# Set schema search path for this connection
await conn.execute(f"SET search_path TO {self.schema_name}, public")
# Session variable logging removed - no longer using RLS
yield conn
except Exception as e:
logger.error(f"Database connection error: {e}")
raise
async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
"""Execute a SELECT query and return results"""
async with self.get_connection() as conn:
try:
rows = await conn.fetch(query, *args)
return [dict(row) for row in rows]
except PostgresError as e:
logger.error(f"Query execution failed: {e}, Query: {query}")
raise
async def execute_command(self, command: str, *args) -> int:
"""Execute an INSERT/UPDATE/DELETE command and return affected rows"""
async with self.get_connection() as conn:
try:
result = await conn.execute(command, *args)
# Parse result like "INSERT 0 5" to get affected rows
return int(result.split()[-1]) if result else 0
except PostgresError as e:
logger.error(f"Command execution failed: {e}, Command: {command}")
raise
async def fetch_one(self, query: str, *args) -> Optional[Dict[str, Any]]:
"""Execute query and return first row"""
async with self.get_connection() as conn:
try:
row = await conn.fetchrow(query, *args)
return dict(row) if row else None
except PostgresError as e:
logger.error(f"Fetch one failed: {e}, Query: {query}")
raise
async def fetch_scalar(self, query: str, *args) -> Any:
"""Execute query and return single value"""
async with self.get_connection() as conn:
try:
return await conn.fetchval(query, *args)
except PostgresError as e:
logger.error(f"Fetch scalar failed: {e}, Query: {query}")
raise
async def execute_transaction(self, commands: List[Tuple[str, tuple]]) -> List[int]:
"""Execute multiple commands in a transaction"""
async with self.get_connection() as conn:
async with conn.transaction():
results = []
for command, args in commands:
try:
result = await conn.execute(command, *args)
results.append(int(result.split()[-1]) if result else 0)
except PostgresError as e:
logger.error(f"Transaction command failed: {e}, Command: {command}")
raise
return results
# Vector Search Operations (PGVector)
async def vector_similarity_search(
self,
query_vector: List[float],
table: str = "document_chunks",
limit: int = 10,
similarity_threshold: float = 0.3,
user_id: Optional[str] = None,
dataset_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Perform vector similarity search using PGVector"""
# Convert Python list to PostgreSQL array format
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
query = f"""
SELECT
id,
content,
1 - (embedding <=> $1::vector) as similarity_score,
metadata
FROM {table}
WHERE embedding IS NOT NULL
AND 1 - (embedding <=> $1::vector) > $2
"""
params = [vector_str, similarity_threshold]
param_idx = 3
# Add user isolation if specified
if user_id:
query += f" AND user_id = ${param_idx}"
params.append(user_id)
param_idx += 1
# Add dataset filtering if specified
if dataset_id:
query += f" AND dataset_id = ${param_idx}"
params.append(dataset_id)
param_idx += 1
query += f" ORDER BY embedding <=> $1::vector LIMIT ${param_idx}"
params.append(limit)
return await self.execute_query(query, *params)
async def hybrid_search(
self,
query_text: str,
query_vector: List[float],
user_id: str,
limit: int = 10,
similarity_threshold: float = 0.3,
text_weight: float = 0.3,
vector_weight: float = 0.7,
dataset_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Perform hybrid search combining vector similarity and full-text search"""
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
# Use the enhanced_hybrid_search_chunks function from BionicGPT integration
query = """
SELECT
id,
document_id,
content,
similarity_score,
text_rank,
combined_score,
metadata,
access_verified
FROM enhanced_hybrid_search_chunks($1, $2::vector, $3::uuid, $4, $5, $6, $7, $8)
"""
return await self.execute_query(
query,
query_text,
vector_str,
user_id,
dataset_id,
limit,
similarity_threshold,
text_weight,
vector_weight
)
async def insert_document_chunk(
self,
document_id: str,
tenant_id: int,
user_id: str,
chunk_index: int,
content: str,
content_hash: str,
embedding: List[float],
dataset_id: Optional[str] = None,
token_count: int = 0,
metadata: Optional[Dict] = None
) -> str:
"""Insert a document chunk with vector embedding"""
vector_str = '[' + ','.join(map(str, embedding)) + ']'
metadata_json = json.dumps(metadata or {})
query = """
INSERT INTO document_chunks (
document_id, tenant_id, user_id, dataset_id, chunk_index,
content, content_hash, token_count, embedding, metadata
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10::jsonb)
RETURNING id
"""
return await self.fetch_scalar(
query,
document_id, tenant_id, user_id, dataset_id, chunk_index,
content, content_hash, token_count, vector_str, metadata_json
)
# Health Check and Statistics
async def health_check(self) -> Dict[str, Any]:
"""Perform health check on PostgreSQL connection"""
try:
if not self._pool:
return {"status": "unhealthy", "reason": "Connection pool not initialized"}
# Test basic connectivity
test_result = await self.fetch_scalar("SELECT 1")
# Get pool statistics
pool_stats = {
"size": self._pool.get_size(),
"min_size": self._pool.get_min_size(),
"max_size": self._pool.get_max_size(),
"idle_size": self._pool.get_idle_size()
}
# Test schema access
schema_test = await self.fetch_scalar("""
SELECT EXISTS (
SELECT 1 FROM information_schema.schemata
WHERE schema_name = $1
)
""", self.schema_name)
return {
"status": "healthy" if test_result == 1 and schema_test else "degraded",
"connectivity": "ok" if test_result == 1 else "failed",
"schema_access": "ok" if schema_test else "failed",
"tenant_domain": self.tenant_domain,
"schema_name": self.schema_name,
"pool_stats": pool_stats,
"database_type": "postgresql_pgvector"
}
except Exception as e:
logger.error(f"PostgreSQL health check failed: {e}")
return {"status": "unhealthy", "reason": str(e)}
async def get_database_stats(self) -> Dict[str, Any]:
"""Get database statistics for monitoring"""
try:
# Get table counts and sizes
stats_query = """
SELECT
schemaname,
tablename,
n_tup_ins as inserts,
n_tup_upd as updates,
n_tup_del as deletes,
n_live_tup as live_tuples,
n_dead_tup as dead_tuples
FROM pg_stat_user_tables
WHERE schemaname = $1
"""
table_stats = await self.execute_query(stats_query, self.schema_name)
# Get total schema size
size_query = """
SELECT pg_size_pretty(
SUM(pg_total_relation_size(quote_ident(schemaname)||'.'||quote_ident(tablename)))
) as schema_size
FROM pg_tables
WHERE schemaname = $1
"""
schema_size = await self.fetch_scalar(size_query, self.schema_name)
# Get vector index statistics if available
vector_stats_query = """
SELECT
COUNT(*) as vector_count,
AVG(vector_dims(embedding)) as avg_dimensions
FROM document_chunks
WHERE embedding IS NOT NULL
"""
            try:
                vector_stats = await self.fetch_one(vector_stats_query)
            except Exception:
                # vector statistics are unavailable when the PGVector extension or table is missing
                vector_stats = {"vector_count": 0, "avg_dimensions": 0}
return {
"tenant_domain": self.tenant_domain,
"schema_name": self.schema_name,
"schema_size": schema_size,
"table_stats": table_stats,
"vector_stats": vector_stats,
"engine_type": "PostgreSQL + PGVector",
"mvcc_enabled": True,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Failed to get database statistics: {e}")
return {"error": str(e)}
# Global client instance (singleton pattern for tenant backend)
_pg_client: Optional[PostgreSQLClient] = None
async def get_postgresql_client() -> PostgreSQLClient:
"""Get or create PostgreSQL client instance"""
global _pg_client
if not _pg_client:
settings = get_settings()
_pg_client = PostgreSQLClient(
database_url=settings.database_url,
tenant_domain=settings.tenant_domain
)
await _pg_client.initialize()
return _pg_client
async def init_postgresql() -> None:
"""Initialize PostgreSQL client during startup"""
logger.info("Initializing PostgreSQL client...")
await get_postgresql_client()
logger.info("PostgreSQL client initialized successfully")
async def close_postgresql() -> None:
"""Close PostgreSQL client during shutdown"""
global _pg_client
if _pg_client:
await _pg_client.close()
_pg_client = None
logger.info("PostgreSQL client closed")
# Context manager for database operations
@asynccontextmanager
async def get_db_session():
"""Async context manager for database operations"""
client = await get_postgresql_client()
async with client.get_connection() as conn:
yield conn
# Convenience functions for common operations
async def execute_query(query: str, *args) -> List[Dict[str, Any]]:
"""Execute a SELECT query"""
client = await get_postgresql_client()
return await client.execute_query(query, *args)
async def execute_command(command: str, *args) -> int:
"""Execute an INSERT/UPDATE/DELETE command"""
client = await get_postgresql_client()
return await client.execute_command(command, *args)
async def fetch_one(query: str, *args) -> Optional[Dict[str, Any]]:
"""Execute query and return first row"""
client = await get_postgresql_client()
return await client.fetch_one(query, *args)
async def fetch_scalar(query: str, *args) -> Any:
"""Execute query and return single value"""
client = await get_postgresql_client()
return await client.fetch_scalar(query, *args)
async def health_check() -> Dict[str, Any]:
"""Perform database health check"""
try:
client = await get_postgresql_client()
return await client.health_check()
except Exception as e:
return {"status": "unhealthy", "reason": str(e)}
async def get_database_info() -> Dict[str, Any]:
"""Get database information and statistics"""
try:
client = await get_postgresql_client()
return await client.get_database_stats()
except Exception as e:
return {"error": str(e)}

View File

@@ -0,0 +1,531 @@
"""
Resource Cluster Client for GT 2.0 Tenant Backend
Provides stateless access to Resource Cluster services including:
- Document processing
- Embedding generation
- Vector storage (ChromaDB)
- Model inference
Perfect tenant isolation with capability-based authentication.
"""
import logging
import asyncio
import aiohttp
import json
import gc
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
from app.core.config import get_settings
from app.core.capability_client import CapabilityClient
logger = logging.getLogger(__name__)
class ResourceClusterClient:
"""
Client for accessing Resource Cluster services with capability-based auth.
GT 2.0 Security Principles:
- Capability tokens for fine-grained access control
- Stateless operations (no data persistence in Resource Cluster)
- Perfect tenant isolation
- Immediate memory cleanup
"""
def __init__(self):
self.settings = get_settings()
self.capability_client = CapabilityClient()
# Resource Cluster endpoints
# IMPORTANT: Use Docker service name for stability across container restarts
# Fixed 2025-09-12: Changed from hardcoded IP to service name for reliability
self.base_url = getattr(
self.settings,
'resource_cluster_url', # Matches Pydantic field name (case insensitive)
'http://gentwo-resource-backend:8000' # Fallback uses service name, not IP
)
self.endpoints = {
'document_processor': f"{self.base_url}/api/v1/process/document",
'embedding_generator': f"{self.base_url}/api/v1/embeddings/generate",
'chromadb_backend': f"{self.base_url}/api/v1/vectors",
'inference': f"{self.base_url}/api/v1/ai/chat/completions" # Updated to match actual endpoint
}
# Request timeouts
self.request_timeout = 300 # seconds - 5 minutes for complex agent operations
self.upload_timeout = 300 # seconds for large documents
logger.info("Resource Cluster client initialized")
async def _get_capability_token(
self,
tenant_id: str,
user_id: str,
resources: List[str]
) -> str:
"""Generate capability token for Resource Cluster access"""
try:
token = await self.capability_client.generate_capability_token(
user_email=user_id, # Using user_id as email for now
tenant_id=tenant_id,
resources=resources,
expires_hours=1
)
return token
except Exception as e:
logger.error(f"Failed to generate capability token: {e}")
raise
async def _make_request(
self,
method: str,
endpoint: str,
data: Dict[str, Any],
tenant_id: str,
user_id: str,
resources: List[str],
        timeout: Optional[int] = None
) -> Dict[str, Any]:
"""Make authenticated request to Resource Cluster"""
try:
# Get capability token
token = await self._get_capability_token(tenant_id, user_id, resources)
# Prepare headers
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {token}',
'X-Tenant-ID': tenant_id,
'X-User-ID': user_id,
'X-Request-ID': f"{tenant_id}_{user_id}_{datetime.utcnow().timestamp()}"
}
# Make request
timeout_config = aiohttp.ClientTimeout(total=timeout or self.request_timeout)
async with aiohttp.ClientSession(timeout=timeout_config) as session:
async with session.request(
method=method.upper(),
url=endpoint,
json=data,
headers=headers
) as response:
if response.status not in [200, 201]:
error_text = await response.text()
raise RuntimeError(
f"Resource Cluster error: {response.status} - {error_text}"
)
result = await response.json()
return result
except Exception as e:
logger.error(f"Resource Cluster request failed: {e}")
raise
# Document Processing
async def process_document(
self,
content: bytes,
document_type: str,
strategy_type: str = "hybrid",
tenant_id: str = None,
user_id: str = None
) -> List[Dict[str, Any]]:
"""Process document into chunks via Resource Cluster"""
try:
# Convert bytes to base64 for JSON transport
import base64
content_b64 = base64.b64encode(content).decode('utf-8')
request_data = {
"content": content_b64,
"document_type": document_type,
"strategy": {
"strategy_type": strategy_type,
"chunk_size": 512,
"chunk_overlap": 128
}
}
# Clear original content from memory
del content
gc.collect()
result = await self._make_request(
method='POST',
endpoint=self.endpoints['document_processor'],
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['document_processing'],
timeout=self.upload_timeout
)
chunks = result.get('chunks', [])
logger.info(f"Processed document into {len(chunks)} chunks")
return chunks
except Exception as e:
logger.error(f"Document processing failed: {e}")
gc.collect()
raise
# Embedding Generation
async def generate_document_embeddings(
self,
documents: List[str],
tenant_id: str,
user_id: str
) -> List[List[float]]:
"""Generate embeddings for documents"""
try:
request_data = {
"texts": documents,
"model": "BAAI/bge-m3",
"instruction": None # Document embeddings don't need instruction
}
result = await self._make_request(
method='POST',
endpoint=self.endpoints['embedding_generator'],
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['embedding_generation']
)
embeddings = result.get('embeddings', [])
# Clear documents from memory
del documents
gc.collect()
logger.info(f"Generated {len(embeddings)} document embeddings")
return embeddings
except Exception as e:
logger.error(f"Document embedding generation failed: {e}")
gc.collect()
raise
async def generate_query_embeddings(
self,
queries: List[str],
tenant_id: str,
user_id: str
) -> List[List[float]]:
"""Generate embeddings for queries"""
try:
request_data = {
"texts": queries,
"model": "BAAI/bge-m3",
"instruction": "Represent this sentence for searching relevant passages: "
}
result = await self._make_request(
method='POST',
endpoint=self.endpoints['embedding_generator'],
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['embedding_generation']
)
embeddings = result.get('embeddings', [])
# Clear queries from memory
del queries
gc.collect()
logger.info(f"Generated {len(embeddings)} query embeddings")
return embeddings
except Exception as e:
logger.error(f"Query embedding generation failed: {e}")
gc.collect()
raise
# Vector Storage (ChromaDB)
async def create_vector_collection(
self,
tenant_id: str,
user_id: str,
dataset_name: str,
metadata: Optional[Dict[str, Any]] = None
) -> bool:
"""Create vector collection in ChromaDB"""
try:
request_data = {
"tenant_id": tenant_id,
"user_id": user_id,
"dataset_name": dataset_name,
"metadata": metadata or {}
}
result = await self._make_request(
method='POST',
endpoint=f"{self.endpoints['chromadb_backend']}/collections",
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['vector_storage']
)
success = result.get('success', False)
logger.info(f"Created vector collection for {dataset_name}: {success}")
return success
except Exception as e:
logger.error(f"Vector collection creation failed: {e}")
raise
async def store_vectors(
self,
tenant_id: str,
user_id: str,
dataset_name: str,
documents: List[str],
embeddings: List[List[float]],
        metadata: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
) -> bool:
"""Store vectors in ChromaDB"""
try:
request_data = {
"tenant_id": tenant_id,
"user_id": user_id,
"dataset_name": dataset_name,
"documents": documents,
"embeddings": embeddings,
"metadata": metadata or [],
"ids": ids
}
result = await self._make_request(
method='POST',
endpoint=f"{self.endpoints['chromadb_backend']}/store",
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['vector_storage']
)
# Clear vectors from memory immediately
del documents, embeddings
gc.collect()
success = result.get('success', False)
logger.info(f"Stored vectors in {dataset_name}: {success}")
return success
except Exception as e:
logger.error(f"Vector storage failed: {e}")
gc.collect()
raise
async def search_vectors(
self,
tenant_id: str,
user_id: str,
dataset_name: str,
query_embedding: List[float],
top_k: int = 5,
filter_metadata: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""Search vectors in ChromaDB"""
try:
request_data = {
"tenant_id": tenant_id,
"user_id": user_id,
"dataset_name": dataset_name,
"query_embedding": query_embedding,
"top_k": top_k,
"filter_metadata": filter_metadata or {}
}
result = await self._make_request(
method='POST',
endpoint=f"{self.endpoints['chromadb_backend']}/search",
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['vector_storage']
)
# Clear query embedding from memory
del query_embedding
gc.collect()
results = result.get('results', [])
logger.info(f"Found {len(results)} vector search results")
return results
except Exception as e:
logger.error(f"Vector search failed: {e}")
gc.collect()
raise
async def delete_vector_collection(
self,
tenant_id: str,
user_id: str,
dataset_name: str
) -> bool:
"""Delete vector collection from ChromaDB"""
try:
request_data = {
"tenant_id": tenant_id,
"user_id": user_id,
"dataset_name": dataset_name
}
result = await self._make_request(
method='DELETE',
endpoint=f"{self.endpoints['chromadb_backend']}/collections",
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['vector_storage']
)
success = result.get('success', False)
logger.info(f"Deleted vector collection {dataset_name}: {success}")
return success
except Exception as e:
logger.error(f"Vector collection deletion failed: {e}")
raise
# Model Inference
async def inference_with_context(
self,
messages: List[Dict[str, str]],
context: str,
model: str = "llama-3.1-70b-versatile",
tenant_id: str = None,
user_id: str = None
) -> Dict[str, Any]:
"""Perform inference with RAG context"""
try:
# Inject context into system message
enhanced_messages = []
system_context = f"Use the following context to answer the user's question:\n\n{context}\n\n"
for msg in messages:
if msg.get("role") == "system":
enhanced_msg = msg.copy()
enhanced_msg["content"] = system_context + enhanced_msg["content"]
enhanced_messages.append(enhanced_msg)
else:
enhanced_messages.append(msg)
# Add system message if none exists
if not any(msg.get("role") == "system" for msg in enhanced_messages):
enhanced_messages.insert(0, {
"role": "system",
"content": system_context + "You are a helpful AI agent."
})
request_data = {
"messages": enhanced_messages,
"model": model,
"temperature": 0.7,
"max_tokens": 4000,
"user_id": user_id,
"tenant_id": tenant_id
}
result = await self._make_request(
method='POST',
endpoint=self.endpoints['inference'],
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['llm_inference']
)
# Clear context from memory
del context, enhanced_messages
gc.collect()
return result
except Exception as e:
logger.error(f"Inference with context failed: {e}")
gc.collect()
raise
async def check_health(self) -> Dict[str, Any]:
"""Check Resource Cluster health"""
try:
# Test basic connectivity
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.base_url}/health") as response:
if response.status == 200:
health_data = await response.json()
return {
"status": "healthy",
"resource_cluster": health_data,
"endpoints": list(self.endpoints.keys()),
"base_url": self.base_url
}
else:
return {
"status": "unhealthy",
"error": f"Health check failed: {response.status}",
"base_url": self.base_url
}
except Exception as e:
return {
"status": "unhealthy",
"error": str(e),
"base_url": self.base_url
}
async def call_inference_endpoint(
self,
tenant_id: str,
user_id: str,
endpoint: str = "chat/completions",
        data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Call AI inference endpoint on Resource Cluster"""
try:
# Use the direct inference endpoint
inference_url = self.endpoints['inference']
            # Copy the request payload; tenant/user context is attached via headers in _make_request
            request_data = data.copy() if data else {}
# Make request with capability token
result = await self._make_request(
method='POST',
endpoint=inference_url,
data=request_data,
tenant_id=tenant_id,
user_id=user_id,
resources=['llm'] # Use valid ResourceType from resource cluster
)
return result
except Exception as e:
logger.error(f"Inference endpoint call failed: {e}")
raise
# Streaming removed for reliability - using non-streaming only
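# Illustrative sketch: document ingestion via the Resource Cluster. The tenant/user ids
# and raw bytes are placeholders, and each chunk is assumed to carry a "content" field;
# every call obtains its own capability token scoped to the resources it needs.
async def _example_ingest(raw_pdf: bytes, tenant_id: str, user_id: str) -> int:
    client = ResourceClusterClient()
    chunks = await client.process_document(
        content=raw_pdf,
        document_type="pdf",
        tenant_id=tenant_id,
        user_id=user_id,
    )
    texts = [chunk.get("content", "") for chunk in chunks]
    embeddings = await client.generate_document_embeddings(texts, tenant_id, user_id)
    return len(embeddings)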

View File

@@ -0,0 +1,320 @@
"""
Response Filtering Utilities for GT 2.0
Provides field-level authorization and data filtering for API responses.
Implements principle of least privilege - users only see data they're authorized to access.
Security principles:
1. Owner-only fields: resource_preferences, advanced RAG configs (max_chunks_per_query, history_context)
2. Viewer fields: Public + usage stats + prompt_template + personality_config + dataset connections
(Team members with read access need these fields to effectively use shared agents)
3. Public fields: id, name, description, category, basic metadata
4. No internal UUIDs, implementation details, or system configuration exposure
"""
from typing import Dict, Any, List, Optional, Set
import logging
logger = logging.getLogger(__name__)
class ResponseFilter:
"""Filter API responses based on user permissions and access level"""
# Define field access levels for agents
# REQUIRED fields that must always be present for AgentResponse schema
AGENT_REQUIRED_FIELDS = {
'id', 'name', 'description', 'created_at', 'updated_at'
}
AGENT_PUBLIC_FIELDS = AGENT_REQUIRED_FIELDS | {
'category', 'conversation_count', 'usage_count', 'is_favorite', 'tags',
'created_by_name', 'can_edit', 'can_delete', 'is_owner',
# Include these for display purposes
'model', 'visibility', 'disclaimer', 'easy_prompts',
# Dataset connections for showing dataset count on agent tiles
'dataset_connection', 'selected_dataset_ids'
}
AGENT_VIEWER_FIELDS = AGENT_PUBLIC_FIELDS | {
'temperature', 'max_tokens', 'total_cost_cents', 'template_id',
# Essential fields for using shared agents (team collaboration)
'prompt_template', 'personality_config',
'dataset_connection', 'selected_dataset_ids'
}
AGENT_OWNER_FIELDS = AGENT_VIEWER_FIELDS | {
# Advanced configuration fields (owner-only)
'resource_preferences', 'max_chunks_per_query', 'history_context',
# Team sharing configuration (owner-only for editing)
'team_shares'
}
# Define field access levels for datasets
# Fields for all users (public/shared datasets) - stats are informational, not sensitive
DATASET_PUBLIC_FIELDS = {
'id', 'name', 'description', 'created_by_name', 'owner_name',
'document_count', 'chunk_count', 'vector_count', 'storage_size_mb',
'tags', 'created_at', 'updated_at', 'access_group',
# Permission flags for UI controls
'is_owner', 'can_edit', 'can_delete', 'can_share',
# Team sharing flag for proper visibility indicators
'shared_via_team'
}
DATASET_VIEWER_FIELDS = DATASET_PUBLIC_FIELDS | {
'summary' # Viewers can see dataset summary
}
DATASET_OWNER_FIELDS = DATASET_VIEWER_FIELDS | {
# Only owners see internal configuration
'owner_id', 'team_members', 'chunking_strategy', 'chunk_size',
'chunk_overlap', 'embedding_model', 'summary_generated_at',
# Team sharing configuration (owner-only for editing)
'team_shares'
}
# Define field access levels for files
# Public fields include processing info since it's informational metadata, not sensitive
FILE_PUBLIC_FIELDS = {
'id', 'original_filename', 'content_type', 'file_type', 'file_size', 'file_size_bytes',
'created_at', 'updated_at', 'category',
# Processing fields - informational, not sensitive
'processing_status', 'chunk_count', 'processing_progress', 'processing_stage',
# Permission flags for UI controls
'can_delete'
}
FILE_OWNER_FIELDS = FILE_PUBLIC_FIELDS | {
'user_id', 'dataset_id', 'storage_path', 'metadata'
}
@staticmethod
def filter_agent_response(
agent_data: Dict[str, Any],
is_owner: bool = False,
can_view: bool = True
) -> Dict[str, Any]:
"""
Filter agent response fields based on user permissions
Args:
agent_data: Full agent data dictionary
is_owner: Whether user owns this agent
can_view: Whether user can view detailed information
Returns:
Filtered dictionary with only authorized fields
"""
if is_owner:
allowed_fields = ResponseFilter.AGENT_OWNER_FIELDS
logger.info(f"🔓 Agent '{agent_data.get('name', 'Unknown')}': Using OWNER fields (is_owner=True, can_view={can_view})")
elif can_view:
allowed_fields = ResponseFilter.AGENT_VIEWER_FIELDS
logger.info(f"👁️ Agent '{agent_data.get('name', 'Unknown')}': Using VIEWER fields (is_owner=False, can_view=True)")
else:
allowed_fields = ResponseFilter.AGENT_PUBLIC_FIELDS
logger.info(f"🌍 Agent '{agent_data.get('name', 'Unknown')}': Using PUBLIC fields (is_owner=False, can_view=False)")
filtered = {
key: value for key, value in agent_data.items()
if key in allowed_fields
}
# Ensure defaults for optional fields that were filtered out
# This prevents AgentResponse schema validation errors
default_values = {
'personality_config': {},
'resource_preferences': {},
'tags': [],
'easy_prompts': [],
'conversation_count': 0,
'usage_count': 0,
'total_cost_cents': 0,
'is_favorite': False,
'can_edit': False,
'can_delete': False,
'is_owner': is_owner
}
for key, default_value in default_values.items():
if key not in filtered:
filtered[key] = default_value
# Log field filtering for security audit
removed_fields = set(agent_data.keys()) - set(filtered.keys())
if removed_fields:
logger.info(
f"🔒 Filtered agent '{agent_data.get('name', 'Unknown')}' - removed fields: {removed_fields} "
f"(is_owner={is_owner}, can_view={can_view})"
)
# Special logging for prompt_template field
if 'prompt_template' in agent_data:
if 'prompt_template' in filtered:
logger.info(f"✅ Agent '{agent_data.get('name', 'Unknown')}': prompt_template INCLUDED in response")
else:
logger.warning(f"❌ Agent '{agent_data.get('name', 'Unknown')}': prompt_template FILTERED OUT (is_owner={is_owner}, can_view={can_view})")
return filtered
@staticmethod
def filter_dataset_response(
dataset_data: Dict[str, Any],
is_owner: bool = False,
can_view: bool = True
) -> Dict[str, Any]:
"""
Filter dataset response fields based on user permissions
Args:
dataset_data: Full dataset data dictionary
is_owner: Whether user owns this dataset
can_view: Whether user can view the dataset
Returns:
Filtered dictionary with only authorized fields
"""
if is_owner:
allowed_fields = ResponseFilter.DATASET_OWNER_FIELDS
elif can_view:
allowed_fields = ResponseFilter.DATASET_VIEWER_FIELDS
else:
allowed_fields = ResponseFilter.DATASET_PUBLIC_FIELDS
filtered = {
key: value for key, value in dataset_data.items()
if key in allowed_fields
}
# Security: Never expose owner_id UUID to non-owners
if not is_owner and 'owner_id' in filtered:
del filtered['owner_id']
# Ensure defaults for optional fields to prevent schema validation errors
default_values = {
'tags': [],
'is_owner': is_owner,
'can_edit': False,
'can_delete': False,
'can_share': False,
# Always set these to None for non-owners (security)
'team_members': None if not is_owner else filtered.get('team_members', []),
'owner_id': None if not is_owner else filtered.get('owner_id'),
# Internal fields - null for all except detail view
'agent_has_access': None,
'user_owns': None,
# Stats fields - use actual values or safe defaults for frontend compatibility
# These are informational only, not sensitive
'chunk_count': filtered.get('chunk_count', 0),
'vector_count': filtered.get('vector_count', 0),
'storage_size_mb': filtered.get('storage_size_mb', 0.0),
'updated_at': filtered.get('updated_at'),
'summary': None
}
for key, default_value in default_values.items():
if key not in filtered:
filtered[key] = default_value
# Log field filtering for security audit
removed_fields = set(dataset_data.keys()) - set(filtered.keys())
if removed_fields:
logger.debug(
f"Filtered dataset response - removed fields: {removed_fields} "
f"(is_owner={is_owner}, can_view={can_view})"
)
return filtered
@staticmethod
def filter_file_response(
file_data: Dict[str, Any],
is_owner: bool = False
) -> Dict[str, Any]:
"""
Filter file response fields based on user permissions
Args:
file_data: Full file data dictionary
is_owner: Whether user owns this file
Returns:
Filtered dictionary with only authorized fields
"""
allowed_fields = (
ResponseFilter.FILE_OWNER_FIELDS if is_owner
else ResponseFilter.FILE_PUBLIC_FIELDS
)
filtered = {
key: value for key, value in file_data.items()
if key in allowed_fields
}
# Log field filtering for security audit
removed_fields = set(file_data.keys()) - set(filtered.keys())
if removed_fields:
logger.debug(
f"Filtered file response - removed fields: {removed_fields} "
f"(is_owner={is_owner})"
)
return filtered
@staticmethod
def filter_batch_responses(
items: List[Dict[str, Any]],
filter_func: callable,
ownership_map: Optional[Dict[str, bool]] = None
) -> List[Dict[str, Any]]:
"""
Filter a batch of items using the provided filter function
Args:
items: List of items to filter
filter_func: Function to apply to each item (e.g., filter_agent_response)
ownership_map: Optional map of item_id -> is_owner boolean
Returns:
List of filtered items
"""
filtered_items = []
for item in items:
item_id = item.get('id')
is_owner = ownership_map.get(item_id, False) if ownership_map else False
filtered_item = filter_func(item, is_owner=is_owner)
filtered_items.append(filtered_item)
return filtered_items
@staticmethod
def sanitize_dataset_summary(
summary_data: Dict[str, Any],
user_can_access: bool = True
) -> Optional[Dict[str, Any]]:
"""
Sanitize dataset summary for inclusion in chat context
Args:
summary_data: Dataset summary with metadata
user_can_access: Whether user should have access to this dataset
Returns:
Sanitized summary or None if user shouldn't access
"""
if not user_can_access:
return None
# Only include safe fields in summary
safe_fields = {
'id', 'name', 'description', 'summary',
'document_count', 'chunk_count'
}
return {
key: value for key, value in summary_data.items()
if key in safe_fields
}
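# Illustrative sketch: owner-aware filtering for a list endpoint. The "creator_id" field
# on each row is a placeholder for however ownership is stored; filter_batch_responses()
# then applies the owner or public field set per item.
def _example_filter_agents(rows: List[Dict[str, Any]], current_user_id: str) -> List[Dict[str, Any]]:
    ownership_map = {row["id"]: row.get("creator_id") == current_user_id for row in rows}
    return ResponseFilter.filter_batch_responses(
        rows,
        ResponseFilter.filter_agent_response,
        ownership_map=ownership_map,
    )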

View File

@@ -0,0 +1,314 @@
"""
Security module for GT 2.0 Tenant Backend
Provides JWT capability token verification and user authentication.
"""
import os
import jwt
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from fastapi import Header
import logging
logger = logging.getLogger(__name__)
def get_jwt_secret() -> str:
"""Get JWT secret from environment variable.
The JWT_SECRET is auto-generated by installers using:
openssl rand -hex 32
This provides a 256-bit secret suitable for HS256 signing.
"""
secret = os.environ.get('JWT_SECRET')
if not secret:
raise ValueError("JWT_SECRET environment variable is required. Run the installer to generate one.")
return secret
def verify_capability_token(token: str) -> Optional[Dict[str, Any]]:
"""
Verify JWT capability token using HS256 symmetric key
Args:
token: JWT token string
Returns:
Token payload if valid, None otherwise
"""
try:
secret = get_jwt_secret()
# Verify token with HS256 symmetric key
payload = jwt.decode(token, secret, algorithms=["HS256"])
# Check expiration
if "exp" in payload:
if datetime.utcnow().timestamp() > payload["exp"]:
logger.warning("Token expired")
return None
return payload
except jwt.InvalidTokenError as e:
logger.warning(f"Invalid token: {e}")
return None
except Exception as e:
logger.error(f"Token verification error: {e}")
return None
def create_capability_token(
user_id: str,
tenant_id: str,
capabilities: list,
expires_hours: int = 4
) -> str:
"""
Create JWT capability token using HS256 symmetric key
Args:
user_id: User identifier
tenant_id: Tenant domain
capabilities: List of capability objects
expires_hours: Token expiration in hours
Returns:
JWT token string
"""
try:
secret = get_jwt_secret()
payload = {
"sub": user_id,
"email": user_id,
"user_type": "tenant_user",
# Current tenant context (primary structure)
"current_tenant": {
"id": tenant_id,
"domain": tenant_id,
"name": f"Tenant {tenant_id}",
"role": "tenant_user",
"display_name": user_id,
"email": user_id,
"is_primary": True,
"capabilities": capabilities
},
# Available tenants for tenant switching
"available_tenants": [{
"id": tenant_id,
"domain": tenant_id,
"name": f"Tenant {tenant_id}",
"role": "tenant_user"
}],
# Standard JWT fields
"iat": datetime.utcnow().timestamp(),
"exp": (datetime.utcnow() + timedelta(hours=expires_hours)).timestamp()
}
return jwt.encode(payload, secret, algorithm="HS256")
except Exception as e:
logger.error(f"Failed to create capability token: {e}")
raise ValueError("Failed to create capability token")
async def get_current_user(authorization: str = Header(None)) -> Dict[str, Any]:
"""
Get current user from authorization header - REQUIRED for all endpoints
Raises 401 if authentication fails - following GT 2.0 security principles
"""
from fastapi import HTTPException, status
if not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
headers={"WWW-Authenticate": "Bearer"}
)
if not authorization.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"}
)
# Extract token
token = authorization.replace("Bearer ", "")
payload = verify_capability_token(token)
if not payload:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"}
)
# Extract tenant context from new JWT structure
current_tenant = payload.get('current_tenant', {})
available_tenants = payload.get('available_tenants', [])
user_type = payload.get('user_type', 'tenant_user')
# For admin users, allow access to any tenant backend
if user_type == 'super_admin' and current_tenant.get('domain') == 'admin':
# Admin users accessing tenant backends - create tenant context for the current backend
from app.core.config import get_settings
settings = get_settings()
# Override the admin context with the current tenant backend's context
current_tenant = {
'id': settings.tenant_id,
'domain': settings.tenant_domain,
'name': f'Tenant {settings.tenant_domain}',
'role': 'super_admin',
'display_name': payload.get('email', 'Admin User'),
'email': payload.get('email'),
'is_primary': True,
'capabilities': [
{'resource': '*', 'actions': ['*'], 'constraints': {}},
]
}
logger.info(f"Admin user {payload.get('email')} accessing tenant backend {settings.tenant_domain}")
# Validate tenant context exists
if not current_tenant or not current_tenant.get('id'):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No valid tenant context in token",
headers={"WWW-Authenticate": "Bearer"}
)
# Return user dict with clean tenant context structure
return {
'sub': payload.get('sub'),
'email': payload.get('email'),
'user_id': payload.get('sub'),
'user_type': payload.get('user_type', 'tenant_user'),
# Current tenant context (primary structure)
'tenant_id': str(current_tenant.get('id')),
'tenant_domain': current_tenant.get('domain'),
'tenant_name': current_tenant.get('name'),
'tenant_role': current_tenant.get('role'),
'tenant_display_name': current_tenant.get('display_name'),
'tenant_email': current_tenant.get('email'),
'is_primary_tenant': current_tenant.get('is_primary', False),
# Tenant-specific capabilities
'capabilities': current_tenant.get('capabilities', []),
# Available tenants for tenant switching
'available_tenants': available_tenants
}
def get_current_user_email(authorization: str) -> str:
"""
Extract user email from authorization header
"""
if authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
payload = verify_capability_token(token)
if payload:
current_tenant = payload.get('current_tenant', {})
# Prefer tenant-specific email, fallback to user email, then sub
return (current_tenant.get('email') or
payload.get('email') or
payload.get('sub', 'test@example.com'))
return 'anonymous@example.com'
def get_tenant_info(authorization: str) -> Dict[str, str]:
"""
Extract tenant information from authorization header
"""
if authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
payload = verify_capability_token(token)
if payload:
current_tenant = payload.get('current_tenant', {})
if current_tenant:
return {
'tenant_id': str(current_tenant.get('id')),
'tenant_domain': current_tenant.get('domain'),
'tenant_name': current_tenant.get('name'),
'tenant_role': current_tenant.get('role')
}
return {
'tenant_id': 'default',
'tenant_domain': 'default',
'tenant_name': 'Default Tenant',
'tenant_role': 'tenant_user'
}
def verify_jwt_token(token: str) -> Optional[Dict[str, Any]]:
"""
Verify JWT token - alias for verify_capability_token
"""
return verify_capability_token(token)
async def get_user_context_unified(
authorization: Optional[str] = Header(None),
x_tenant_domain: Optional[str] = Header(None),
x_user_id: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
Unified authentication for both JWT (user requests) and header-based (service requests).
Supports two auth modes:
1. JWT Authentication: Authorization header with Bearer token (for direct user requests)
2. Header Authentication: X-Tenant-Domain + X-User-ID headers (for internal service requests)
Returns user context with tenant information for both modes.
"""
from fastapi import HTTPException, status
# Mode 1: Header-based authentication (for internal services like MCP)
if x_tenant_domain and x_user_id:
logger.info(f"Using header auth: tenant={x_tenant_domain}, user={x_user_id}")
return {
"tenant_domain": x_tenant_domain,
"tenant_id": x_tenant_domain,
"id": x_user_id,
"sub": x_user_id,
"email": x_user_id,
"user_id": x_user_id,
"user_type": "internal_service",
"tenant_role": "tenant_user"
}
# Mode 2: JWT authentication (for direct user requests)
if authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
payload = verify_capability_token(token)
if payload:
logger.info(f"Using JWT auth: user={payload.get('sub')}")
# Extract tenant context from JWT structure
current_tenant = payload.get('current_tenant', {})
return {
'sub': payload.get('sub'),
'email': payload.get('email'),
'user_id': payload.get('sub'),
'id': payload.get('sub'),
'user_type': payload.get('user_type', 'tenant_user'),
'tenant_id': str(current_tenant.get('id', 'default')),
'tenant_domain': current_tenant.get('domain', 'default'),
'tenant_name': current_tenant.get('name', 'Default Tenant'),
'tenant_role': current_tenant.get('role', 'tenant_user'),
'capabilities': current_tenant.get('capabilities', [])
}
# No valid authentication provided
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication: provide either Authorization header or X-Tenant-Domain + X-User-ID headers"
)
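# Illustrative sketch: minting and verifying a capability token. The identifiers and
# capability shape are placeholders; JWT_SECRET must be present in the environment
# (the installer generates it), otherwise get_jwt_secret() raises ValueError.
def _example_token_roundtrip() -> Optional[Dict[str, Any]]:
    token = create_capability_token(
        user_id="analyst@example.com",
        tenant_id="acme",
        capabilities=[{"resource": "agents", "actions": ["read"], "constraints": {}}],
        expires_hours=1,
    )
    return verify_capability_token(token)  # payload dict, or None if invalid/expired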

View File

@@ -0,0 +1,165 @@
"""
User UUID Resolution Utilities for GT 2.0
Handles email-to-UUID resolution across all services to ensure
consistent user identification in database operations.
"""
import logging
from typing import Dict, Any, Optional, Tuple
from fastapi import HTTPException
logger = logging.getLogger(__name__)
async def resolve_user_uuid(current_user: Dict[str, Any]) -> Tuple[str, str, str]:
"""
Resolve user email to UUID for internal services.
Args:
current_user: User data from JWT token
Returns:
Tuple of (tenant_domain, user_email, user_uuid)
Raises:
HTTPException: If UUID resolution fails
"""
tenant_domain = current_user.get("tenant_domain", "test")
user_email = current_user["email"]
# Import here to avoid circular imports
from app.api.auth import get_tenant_user_uuid_by_email
user_uuid = await get_tenant_user_uuid_by_email(user_email)
if not user_uuid:
logger.error(f"Failed to resolve UUID for user {user_email} in tenant {tenant_domain}")
raise HTTPException(
status_code=404,
detail=f"User {user_email} not found in tenant system"
)
logger.info(f"✅ Resolved user {user_email} to UUID: {user_uuid}")
return tenant_domain, user_email, user_uuid
async def ensure_user_uuid(email_or_uuid: str, tenant_domain: Optional[str] = None) -> str:
"""
Ensure we have a UUID, converting email if needed.
Args:
email_or_uuid: Either an email address or UUID string
tenant_domain: Tenant domain for lookup context
Returns:
UUID string
Raises:
ValueError: If email cannot be resolved to UUID or input is invalid
"""
import uuid
import re
# Validate input is not empty or None
if not email_or_uuid or not isinstance(email_or_uuid, str):
raise ValueError(f"Invalid user identifier: {email_or_uuid}")
email_or_uuid = email_or_uuid.strip()
# Check if it's an email
if "@" in email_or_uuid:
# It's an email, resolve to UUID
from app.api.auth import get_tenant_user_uuid_by_email
user_uuid = await get_tenant_user_uuid_by_email(email_or_uuid)
if not user_uuid:
error_msg = f"Cannot resolve email {email_or_uuid} to UUID"
if tenant_domain:
error_msg += f" in tenant {tenant_domain}"
logger.error(error_msg)
raise ValueError(error_msg)
logger.debug(f"Resolved email {email_or_uuid} to UUID: {user_uuid}")
return user_uuid
# Check if it's a valid UUID format
try:
uuid_obj = uuid.UUID(email_or_uuid)
return str(uuid_obj) # Return normalized UUID string
except (ValueError, TypeError):
# Not a valid UUID, could be a numeric ID or other format
pass
# Handle numeric user IDs or other legacy formats
if email_or_uuid.isdigit():
logger.warning(f"Received numeric user ID '{email_or_uuid}', attempting database lookup")
# Try to resolve numeric ID to proper UUID via database
from app.core.postgresql_client import get_postgresql_client
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
tenant_schema = f"tenant_{tenant_domain.replace('.', '_').replace('-', '_')}" if tenant_domain else "tenant_test"
# Try to find user by numeric ID (assuming it might be a legacy ID)
user_row = await conn.fetchrow(
f"SELECT id FROM {tenant_schema}.users WHERE id::text = $1 OR email = $1 LIMIT 1",
email_or_uuid
)
if user_row:
return str(user_row['id'])
# If not found, try finding the first user (fallback for development)
logger.warning(f"User '{email_or_uuid}' not found, using first available user as fallback")
first_user = await conn.fetchrow(f"SELECT id FROM {tenant_schema}.users LIMIT 1")
if first_user:
logger.info(f"Using fallback user UUID: {first_user['id']}")
return str(first_user['id'])
except Exception as e:
logger.error(f"Database lookup failed for user '{email_or_uuid}': {e}")
# If all else fails, raise an error
error_msg = f"Cannot resolve user identifier '{email_or_uuid}' to UUID. Expected email or valid UUID format."
if tenant_domain:
error_msg += f" Tenant: {tenant_domain}"
logger.error(error_msg)
raise ValueError(error_msg)
def get_user_sql_clause(param_num: int, user_identifier: str) -> str:
"""
Get the appropriate SQL clause for user identification.
Args:
param_num: Parameter number in SQL query (e.g., 3 for $3)
user_identifier: Either email or UUID
Returns:
SQL clause string for user lookup
"""
if "@" in user_identifier:
# Email - do lookup
return f"(SELECT id FROM users WHERE email = ${param_num} LIMIT 1)"
else:
# UUID - use directly
return f"${param_num}::uuid"
def is_uuid_format(identifier: str) -> bool:
"""
Check if a string looks like a UUID.
Args:
identifier: String to check
Returns:
True if looks like UUID, False if looks like email
"""
return "@" not in identifier and len(identifier) == 36 and identifier.count("-") == 4