GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
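As an aside, a minimal sketch (illustrative only, not code from this release) of the hostname-validation approach the notes describe, using Python's standard `urllib.parse`: comparing the parsed hostname against an exact allowlist avoids the substring-matching pitfall where a URL such as `https://api.example.com.attacker.io` would pass a naive `"api.example.com" in url` check. The allowlist name below is hypothetical.

```python
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.example.com"}  # hypothetical allowlist, for illustration


def is_allowed_url(url: str) -> bool:
    """Validate the exact parsed hostname instead of substring-matching the raw URL."""
    hostname = urlparse(url).hostname
    return hostname is not None and hostname.lower() in ALLOWED_HOSTS


# A substring check ('api.example.com' in url) would wrongly accept the first URL.
assert not is_allowed_url("https://api.example.com.attacker.io/steal")
assert is_allowed_url("https://api.example.com/v1/data")
```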
131
apps/tenant-backend/app/core/api_standards.py
Normal file
@@ -0,0 +1,131 @@
"""
GT 2.0 Tenant Backend - CB-REST API Standards Integration

This module integrates the CB-REST standards into the Tenant backend
"""

import os
import sys
from pathlib import Path

# Add the api-standards package to the path
api_standards_path = Path(__file__).parent.parent.parent.parent.parent / "packages" / "api-standards" / "src"
if api_standards_path.exists():
    sys.path.insert(0, str(api_standards_path))

# Import CB-REST standards
try:
    from response import StandardResponse, format_response, format_error
    from capability import (
        init_capability_verifier,
        verify_capability,
        require_capability,
        Capability,
        CapabilityToken
    )
    from errors import ErrorCode, APIError, raise_api_error
    from middleware import (
        RequestCorrelationMiddleware,
        CapabilityMiddleware,
        TenantIsolationMiddleware,
        RateLimitMiddleware
    )
except ImportError as e:
    # Fallback for development - create minimal implementations
    print(f"Warning: Could not import api-standards package: {e}")

    # Create minimal implementations for development
    class StandardResponse:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    def format_response(data, capability_used, request_id=None):
        return {
            "data": data,
            "error": None,
            "capability_used": capability_used,
            "request_id": request_id or "dev-mode"
        }

    def format_error(code, message, capability_used="none", **kwargs):
        return {
            "data": None,
            "error": {
                "code": code,
                "message": message,
                **kwargs
            },
            "capability_used": capability_used,
            "request_id": kwargs.get("request_id", "dev-mode")
        }

    class ErrorCode:
        CAPABILITY_INSUFFICIENT = "CAPABILITY_INSUFFICIENT"
        RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"
        INVALID_REQUEST = "INVALID_REQUEST"
        SYSTEM_ERROR = "SYSTEM_ERROR"
        TENANT_ISOLATION_VIOLATION = "TENANT_ISOLATION_VIOLATION"

    class APIError(Exception):
        def __init__(self, code, message, **kwargs):
            self.code = code
            self.message = message
            self.kwargs = kwargs
            super().__init__(message)


# Export all CB-REST components
__all__ = [
    'StandardResponse',
    'format_response',
    'format_error',
    'init_capability_verifier',
    'verify_capability',
    'require_capability',
    'Capability',
    'CapabilityToken',
    'ErrorCode',
    'APIError',
    'raise_api_error',
    'RequestCorrelationMiddleware',
    'CapabilityMiddleware',
    'TenantIsolationMiddleware',
    'RateLimitMiddleware'
]


def setup_api_standards(app, secret_key: str, tenant_id: str):
    """
    Setup CB-REST API standards for the tenant application

    Args:
        app: FastAPI application instance
        secret_key: Secret key for JWT signing
        tenant_id: Tenant identifier for isolation
    """
    # Initialize capability verifier
    if 'init_capability_verifier' in globals():
        init_capability_verifier(secret_key)

    # Add middleware in correct order
    if 'RequestCorrelationMiddleware' in globals():
        app.add_middleware(RequestCorrelationMiddleware)

    if 'RateLimitMiddleware' in globals():
        app.add_middleware(
            RateLimitMiddleware,
            requests_per_minute=100  # Per-tenant rate limiting
        )

    if 'TenantIsolationMiddleware' in globals():
        app.add_middleware(
            TenantIsolationMiddleware,
            tenant_id=tenant_id,
            enforce_isolation=True
        )

    if 'CapabilityMiddleware' in globals():
        app.add_middleware(
            CapabilityMiddleware,
            exclude_paths=["/health", "/ready", "/metrics", "/api/v1/auth/login"]
        )
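A minimal usage sketch (not part of this diff) of how `setup_api_standards` might be wired into the tenant application, assuming the required tenant settings from `config.py` below are available in the environment:

```python
from fastapi import FastAPI

from app.core.api_standards import setup_api_standards
from app.core.config import get_settings

settings = get_settings()
app = FastAPI(title=f"GT 2.0 Tenant Backend ({settings.tenant_domain})")

# Registers the capability verifier and the CB-REST middleware stack
setup_api_standards(app, secret_key=settings.secret_key, tenant_id=settings.tenant_id)
```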
162
apps/tenant-backend/app/core/asgi_router.py
Normal file
@@ -0,0 +1,162 @@
"""
Composite ASGI Router for GT 2.0 Tenant Backend

Handles routing between FastAPI and Socket.IO applications to prevent
ASGI protocol conflicts while maintaining both WebSocket systems.

Architecture:
- `/socket.io/*` → Socket.IO ASGIApp (agentic real-time features)
- All other paths → FastAPI app (REST API, native WebSocket)
"""

import logging
from typing import Dict, Any, Callable, Awaitable

logger = logging.getLogger(__name__)


class CompositeASGIRouter:
    """
    ASGI router that handles both FastAPI and Socket.IO applications
    without protocol conflicts.
    """

    def __init__(self, fastapi_app, socketio_app):
        """
        Initialize composite router with both applications.

        Args:
            fastapi_app: FastAPI application instance
            socketio_app: Socket.IO ASGIApp instance
        """
        self.fastapi_app = fastapi_app
        self.socketio_app = socketio_app
        logger.info("Composite ASGI router initialized for FastAPI + Socket.IO")

    async def __call__(self, scope: Dict[str, Any], receive: Callable, send: Callable) -> None:
        """
        ASGI application entry point that routes requests based on path.

        Args:
            scope: ASGI scope containing request information
            receive: ASGI receive callable
            send: ASGI send callable
        """
        try:
            # Extract path from scope
            path = scope.get("path", "")

            # Route based on path pattern
            if self._is_socketio_path(path):
                # Only log Socket.IO routing at DEBUG level for non-operational paths
                if self._should_log_route(path):
                    logger.debug(f"Routing to Socket.IO: {path}")
                await self.socketio_app(scope, receive, send)
            else:
                # Only log FastAPI routing at DEBUG level for non-operational paths
                if self._should_log_route(path):
                    logger.debug(f"Routing to FastAPI: {path}")
                await self.fastapi_app(scope, receive, send)

        except Exception as e:
            logger.error(f"Error in ASGI routing: {e}")
            # Fallback to FastAPI for error handling
            try:
                await self.fastapi_app(scope, receive, send)
            except Exception as fallback_error:
                logger.error(f"Fallback routing also failed: {fallback_error}")
                # Last resort: send basic error response
                await self._send_error_response(scope, send)

    def _is_socketio_path(self, path: str) -> bool:
        """
        Determine if path should be routed to Socket.IO.

        Args:
            path: Request path

        Returns:
            True if path should go to Socket.IO, False for FastAPI
        """
        socketio_patterns = [
            "/socket.io/",
            "/socket.io"
        ]

        # Check if path starts with any Socket.IO pattern
        for pattern in socketio_patterns:
            if path.startswith(pattern):
                return True

        return False

    def _should_log_route(self, path: str) -> bool:
        """
        Determine if this path should be logged during routing.

        Operational endpoints like health checks and metrics are excluded
        to reduce log noise during normal operation.

        Args:
            path: Request path

        Returns:
            True if path should be logged, False for operational endpoints
        """
        operational_endpoints = [
            "/health",
            "/ready",
            "/metrics",
            "/api/v1/health"
        ]

        # Don't log operational endpoints
        if any(path.startswith(endpoint) for endpoint in operational_endpoints):
            return False

        return True

    async def _send_error_response(self, scope: Dict[str, Any], send: Callable) -> None:
        """
        Send basic error response when both applications fail.

        Args:
            scope: ASGI scope
            send: ASGI send callable
        """
        try:
            if scope["type"] == "http":
                await send({
                    "type": "http.response.start",
                    "status": 500,
                    "headers": [
                        [b"content-type", b"application/json"],
                        [b"content-length", b"32"]
                    ]
                })
                await send({
                    "type": "http.response.body",
                    "body": b'{"error": "ASGI routing failed"}'
                })
            elif scope["type"] == "websocket":
                await send({
                    "type": "websocket.close",
                    "code": 1011,
                    "reason": "ASGI routing failed"
                })
        except Exception as e:
            logger.error(f"Failed to send error response: {e}")


def create_composite_asgi_app(fastapi_app, socketio_app):
    """
    Factory function to create composite ASGI application.

    Args:
        fastapi_app: FastAPI application instance
        socketio_app: Socket.IO ASGIApp instance

    Returns:
        CompositeASGIRouter instance
    """
    return CompositeASGIRouter(fastapi_app, socketio_app)
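A sketch of how the composite router might be assembled and served (assumes `python-socketio` and `uvicorn`; the actual wiring in this repository may differ):

```python
import socketio
import uvicorn
from fastapi import FastAPI

from app.core.asgi_router import create_composite_asgi_app

fastapi_app = FastAPI()
sio = socketio.AsyncServer(async_mode="asgi")
socketio_app = socketio.ASGIApp(sio, socketio_path="socket.io")

# /socket.io/* goes to Socket.IO, everything else to FastAPI
app = create_composite_asgi_app(fastapi_app, socketio_app)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```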
202
apps/tenant-backend/app/core/cache.py
Normal file
@@ -0,0 +1,202 @@
"""
Simple in-memory cache with TTL support for Gen Two performance optimization.

This module provides a thread-safe caching layer for expensive database queries
and API calls. Each Uvicorn worker maintains its own cache instance.

Key features:
- TTL-based expiration (configurable per-key)
- LRU eviction when cache reaches max size
- Thread-safe for concurrent request handling
- Pattern-based deletion for cache invalidation

Usage:
    from app.core.cache import get_cache

    cache = get_cache()

    # Get cached value with 60-second TTL
    cached_data = cache.get("agents_minimal_user123", ttl=60)
    if not cached_data:
        data = await fetch_from_db()
        cache.set("agents_minimal_user123", data)
"""

from typing import Any, Optional, Dict, Tuple
from datetime import datetime, timedelta
from threading import Lock
import logging

logger = logging.getLogger(__name__)


class SimpleCache:
    """
    Thread-safe TTL cache for API responses and database query results.

    This cache is per-worker (each Uvicorn worker maintains its own cache).
    Cache keys should include tenant_domain or user_id for proper isolation.

    Attributes:
        max_entries: Maximum number of cache entries before LRU eviction
        _cache: Internal cache storage (key -> (timestamp, data))
        _lock: Thread lock for safe concurrent access
    """

    def __init__(self, max_entries: int = 1000):
        """
        Initialize cache with maximum entry limit.

        Args:
            max_entries: Maximum cache entries (default 1000)
                Typical: 200KB per agent list × 1000 = 200MB per worker
        """
        self._cache: Dict[str, Tuple[datetime, Any]] = {}
        self._lock = Lock()
        self._max_entries = max_entries
        self._hits = 0
        self._misses = 0
        logger.info(f"SimpleCache initialized with max_entries={max_entries}")

    def get(self, key: str, ttl: int = 60) -> Optional[Any]:
        """
        Get cached value if not expired.

        Args:
            key: Cache key (should include tenant/user for isolation)
            ttl: Time-to-live in seconds (default 60)

        Returns:
            Cached data if found and not expired, None otherwise

        Example:
            data = cache.get("agents_minimal_user123", ttl=60)
            if data is None:
                # Cache miss - fetch from database
                data = await fetch_from_db()
                cache.set("agents_minimal_user123", data)
        """
        with self._lock:
            if key not in self._cache:
                self._misses += 1
                logger.debug(f"Cache miss: {key}")
                return None

            timestamp, data = self._cache[key]
            age = (datetime.utcnow() - timestamp).total_seconds()

            if age > ttl:
                # Expired - remove and return None
                del self._cache[key]
                self._misses += 1
                logger.debug(f"Cache expired: {key} (age={age:.1f}s, ttl={ttl}s)")
                return None

            self._hits += 1
            logger.debug(f"Cache hit: {key} (age={age:.1f}s, ttl={ttl}s)")
            return data

    def set(self, key: str, data: Any) -> None:
        """
        Set cache value with current timestamp.

        Args:
            key: Cache key
            data: Data to cache (should be JSON-serializable)

        Note:
            If cache is full, oldest entry is evicted (LRU)
        """
        with self._lock:
            # LRU eviction if cache full
            if len(self._cache) >= self._max_entries:
                oldest_key = min(self._cache.items(), key=lambda x: x[1][0])[0]
                del self._cache[oldest_key]
                logger.warning(
                    f"Cache full ({self._max_entries} entries), "
                    f"evicted oldest key: {oldest_key}"
                )

            self._cache[key] = (datetime.utcnow(), data)
            logger.debug(f"Cache set: {key} (total entries: {len(self._cache)})")

    def delete(self, pattern: str) -> int:
        """
        Delete all keys matching pattern (prefix match).

        Args:
            pattern: Key prefix to match (e.g., "agents_minimal_")

        Returns:
            Number of keys deleted

        Example:
            # Delete all agent cache entries for a user
            count = cache.delete(f"agents_minimal_{user_id}")
            count += cache.delete(f"agents_summary_{user_id}")
        """
        with self._lock:
            keys_to_delete = [k for k in self._cache.keys() if k.startswith(pattern)]
            for k in keys_to_delete:
                del self._cache[k]

            if keys_to_delete:
                logger.info(f"Cache invalidated {len(keys_to_delete)} entries matching '{pattern}'")

            return len(keys_to_delete)

    def clear(self) -> None:
        """Clear entire cache (use with caution)."""
        with self._lock:
            entry_count = len(self._cache)
            self._cache.clear()
            self._hits = 0
            self._misses = 0
            logger.warning(f"Cache cleared (removed {entry_count} entries)")

    def size(self) -> int:
        """Get number of cached entries."""
        return len(self._cache)

    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with size, hits, misses, hit_rate
        """
        total_requests = self._hits + self._misses
        hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0

        return {
            "size": len(self._cache),
            "max_entries": self._max_entries,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate_percent": round(hit_rate, 2),
        }


# Singleton cache instance per worker
_cache: Optional[SimpleCache] = None


def get_cache() -> SimpleCache:
    """
    Get or create singleton cache instance.

    Each Uvicorn worker creates its own cache instance (isolated per-process).

    Returns:
        SimpleCache instance
    """
    global _cache
    if _cache is None:
        _cache = SimpleCache(max_entries=1000)
    return _cache


def clear_cache() -> None:
    """Clear global cache (for testing or emergency use)."""
    cache = get_cache()
    cache.clear()
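A short sketch of the invalidation pattern the `delete()` docstring describes — dropping a user's cached agent listings after a write (the key prefixes follow the docstring examples above and are otherwise assumptions):

```python
from app.core.cache import get_cache

cache = get_cache()


def invalidate_agent_cache(user_id: str) -> int:
    """Drop all cached agent listings for a user after a create/update/delete."""
    removed = cache.delete(f"agents_minimal_{user_id}")
    removed += cache.delete(f"agents_summary_{user_id}")
    return removed


print(cache.stats())  # {"size": ..., "hits": ..., "misses": ..., "hit_rate_percent": ...}
```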
380
apps/tenant-backend/app/core/capability_client.py
Normal file
@@ -0,0 +1,380 @@
"""
GT 2.0 Tenant Backend - Capability Client
Generate JWT capability tokens for Resource Cluster API calls
"""

import json
import time
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from jose import jwt
from app.core.config import get_settings
import logging
import httpx

logger = logging.getLogger(__name__)
settings = get_settings()


class CapabilityClient:
    """Generates capability-based JWT tokens for Resource Cluster access"""

    def __init__(self):
        # Use tenant-specific secret key for token signing
        self.secret_key = settings.secret_key
        self.algorithm = "HS256"
        self.issuer = f"gt2-tenant-{settings.tenant_id}"
        self.http_client = httpx.AsyncClient(timeout=10.0)
        self.control_panel_url = settings.control_panel_url

    async def generate_capability_token(
        self,
        user_email: str,
        tenant_id: str,
        resources: List[str],
        expires_hours: int = 24,
        additional_claims: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Generate a JWT capability token for Resource Cluster API access.

        Args:
            user_email: Email of the user making the request
            tenant_id: Tenant identifier
            resources: List of resource capabilities (e.g., ['external_services', 'rag_processing'])
            expires_hours: Token expiration time in hours
            additional_claims: Additional JWT claims to include

        Returns:
            Signed JWT token string
        """
        now = datetime.utcnow()
        expiry = now + timedelta(hours=expires_hours)

        # Build capability token payload
        payload = {
            # Standard JWT claims
            "iss": self.issuer,             # Issuer
            "sub": user_email,              # Subject (user)
            "aud": "gt2-resource-cluster",  # Audience
            "iat": int(now.timestamp()),    # Issued at
            "exp": int(expiry.timestamp()), # Expiration
            "nbf": int(now.timestamp()),    # Not before
            "jti": f"{tenant_id}-{user_email}-{int(now.timestamp())}",  # JWT ID

            # GT 2.0 specific claims
            "tenant_id": tenant_id,
            "user_email": user_email,
            "user_type": "tenant_user",

            # Capability grants
            "capabilities": await self._build_capabilities(resources, tenant_id, expiry),

            # Security metadata
            "capability_hash": self._generate_capability_hash(resources, tenant_id),
            "token_version": "2.0",
            "security_level": "standard"
        }

        # Add any additional claims
        if additional_claims:
            payload.update(additional_claims)

        # Sign the token
        try:
            token = jwt.encode(
                payload,
                self.secret_key,
                algorithm=self.algorithm
            )

            logger.info(
                f"Generated capability token for {user_email} with resources: {resources}"
            )

            return token

        except Exception as e:
            logger.error(f"Failed to generate capability token: {e}")
            raise RuntimeError(f"Token generation failed: {e}")

    async def _build_capabilities(
        self,
        resources: List[str],
        tenant_id: str,
        expiry: datetime
    ) -> List[Dict[str, Any]]:
        """
        Build capability grants for resources with constraints from Control Panel.

        For LLM resources, fetches real rate limits from Control Panel API.
        For other resources, uses default constraints.
        """
        capabilities = []

        for resource in resources:
            capability = {
                "resource": resource,
                "actions": self._get_default_actions(resource),
                "constraints": await self._get_constraints_for_resource(resource, tenant_id),
                "valid_until": expiry.isoformat()
            }
            capabilities.append(capability)

        return capabilities

    async def _get_constraints_for_resource(
        self,
        resource: str,
        tenant_id: str
    ) -> Dict[str, Any]:
        """
        Get constraints for a resource, fetching from Control Panel for LLM resources.

        GT 2.0 Principle: Single source of truth in database.
        Fails fast if Control Panel is unreachable for LLM resources.
        """
        # For LLM resources, fetch real config from Control Panel
        if resource in ["llm", "llm_inference"]:
            # Note: We don't have model_id at this point in the flow
            # This is called during general capability token generation
            # For now, return default constraints that will be overridden
            # when model-specific tokens are generated
            return self._get_default_constraints(resource)

        # For non-LLM resources, use defaults
        return self._get_default_constraints(resource)

    async def _fetch_tenant_model_config(
        self,
        tenant_id: str,
        model_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch tenant model configuration from Control Panel API.

        Returns rate limits from database (single source of truth).
        Fails fast if Control Panel is unreachable (no fallbacks).

        Args:
            tenant_id: Tenant identifier
            model_id: Model identifier

        Returns:
            Model config with rate_limits, or None if not found

        Raises:
            RuntimeError: If Control Panel API is unreachable (fail fast)
        """
        try:
            url = f"{self.control_panel_url}/api/v1/tenant-models/tenants/{tenant_id}/models/{model_id}"

            logger.debug(f"Fetching model config from Control Panel: {url}")

            response = await self.http_client.get(url)

            if response.status_code == 404:
                logger.warning(f"Model {model_id} not configured for tenant {tenant_id}")
                return None

            response.raise_for_status()

            config = response.json()
            logger.info(f"Fetched model config for {model_id}: rate_limits={config.get('rate_limits')}")

            return config

        except httpx.HTTPStatusError as e:
            logger.error(f"Control Panel API error: {e.response.status_code}")
            raise RuntimeError(
                f"Failed to fetch model config from Control Panel: HTTP {e.response.status_code}"
            )
        except httpx.RequestError as e:
            logger.error(f"Control Panel API unreachable: {e}")
            raise RuntimeError(
                f"Control Panel API unreachable - cannot generate capability token. "
                f"Ensure Control Panel is running at {self.control_panel_url}"
            )
        except Exception as e:
            logger.error(f"Unexpected error fetching model config: {e}")
            raise RuntimeError(f"Failed to fetch model config: {e}")

    def _get_default_actions(self, resource: str) -> List[str]:
        """Get default actions for a resource type"""

        action_mappings = {
            "external_services": ["create", "read", "update", "delete", "health_check", "sso_token"],
            "rag_processing": ["process_document", "generate_embeddings", "vector_search"],
            "llm_inference": ["chat_completion", "streaming", "function_calling"],
            "llm": ["execute"],  # Use valid ActionType from resource cluster
            "agent_orchestration": ["execute", "status", "interrupt"],
            "ai_literacy": ["play_games", "solve_puzzles", "dialogue", "analytics"],
            "app_integrations": ["read", "write", "webhook"],
            "admin": ["all"],
            # MCP Server Resources
            "mcp:rag": ["search_datasets", "query_documents", "list_user_datasets", "get_dataset_info", "get_relevant_chunks"]
        }

        return action_mappings.get(resource, ["read"])

    def _get_default_constraints(self, resource: str) -> Dict[str, Any]:
        """Get default constraints for a resource type"""

        constraint_mappings = {
            "external_services": {
                "max_instances_per_user": 10,
                "max_cpu_per_instance": "2000m",
                "max_memory_per_instance": "4Gi",
                "max_storage_per_instance": "50Gi",
                "allowed_service_types": ["ctfd", "canvas", "guacamole"]
            },
            "rag_processing": {
                "max_document_size_mb": 100,
                "max_batch_size": 50,
                "max_requests_per_hour": 1000
            },
            "llm_inference": {
                "max_tokens_per_request": 4000,
                "max_requests_per_hour": 100,
                "allowed_models": []  # Models dynamically determined by admin backend
            },
            "llm": {
                "max_tokens_per_request": 4000,
                "max_requests_per_hour": 100,
                "allowed_models": []  # Models dynamically determined by admin backend
            },
            "agent_orchestration": {
                "max_concurrent_agents": 5,
                "max_execution_time_minutes": 30
            },
            "ai_literacy": {
                "max_sessions_per_day": 20,
                "max_session_duration_hours": 4
            },
            "app_integrations": {
                "max_api_calls_per_hour": 500,
                "allowed_domains": ["api.example.com"]
            },
            # MCP Server Resources
            "mcp:rag": {
                "max_requests_per_hour": 500,
                "max_results_per_query": 50
            }
        }

        return constraint_mappings.get(resource, {})

    def _generate_capability_hash(self, resources: List[str], tenant_id: str) -> str:
        """Generate a hash of the capabilities for verification"""
        import hashlib

        # Create a deterministic string from capabilities
        capability_string = f"{tenant_id}:{':'.join(sorted(resources))}"

        # Hash with SHA-256
        hash_object = hashlib.sha256(capability_string.encode())
        return hash_object.hexdigest()[:16]  # First 16 characters

    async def verify_capability_token(self, token: str) -> Dict[str, Any]:
        """
        Verify and decode a capability token.

        Args:
            token: JWT token to verify

        Returns:
            Decoded token payload

        Raises:
            ValueError: If token is invalid or expired
        """
        try:
            # Decode and verify the token
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                audience="gt2-resource-cluster"
            )

            # Additional validation
            if payload.get("iss") != self.issuer:
                raise ValueError("Invalid token issuer")

            # Check if token is still valid
            now = datetime.utcnow()
            if payload.get("exp", 0) < now.timestamp():
                raise ValueError("Token has expired")

            if payload.get("nbf", 0) > now.timestamp():
                raise ValueError("Token not yet valid")

            logger.debug(f"Verified capability token for user {payload.get('user_email')}")

            return payload

        except jwt.ExpiredSignatureError:
            raise ValueError("Token has expired")
        except jwt.JWTClaimsError as e:
            raise ValueError(f"Token claims validation failed: {e}")
        except jwt.JWTError as e:
            raise ValueError(f"Token validation failed: {e}")
        except Exception as e:
            logger.error(f"Capability token verification failed: {e}")
            raise ValueError(f"Invalid token: {e}")

    async def refresh_capability_token(
        self,
        current_token: str,
        extend_hours: int = 24
    ) -> str:
        """
        Refresh an existing capability token with extended expiration.

        Args:
            current_token: Current JWT token
            extend_hours: Hours to extend from now

        Returns:
            New JWT token with extended expiration
        """
        # Verify current token
        payload = await self.verify_capability_token(current_token)

        # Extract current capabilities
        resources = [cap.get("resource") for cap in payload.get("capabilities", [])]

        # Generate new token with extended expiration
        return await self.generate_capability_token(
            user_email=payload.get("user_email"),
            tenant_id=payload.get("tenant_id"),
            resources=resources,
            expires_hours=extend_hours
        )

    def get_token_info(self, token: str) -> Dict[str, Any]:
        """
        Get information about a token without full verification.
        Useful for debugging and logging.
        """
        try:
            # Decode without verification to get claims
            payload = jwt.get_unverified_claims(token)

            return {
                "user_email": payload.get("user_email"),
                "tenant_id": payload.get("tenant_id"),
                "resources": [cap.get("resource") for cap in payload.get("capabilities", [])],
                "expires_at": datetime.fromtimestamp(payload.get("exp", 0)).isoformat(),
                "issued_at": datetime.fromtimestamp(payload.get("iat", 0)).isoformat(),
                "token_version": payload.get("token_version"),
                "security_level": payload.get("security_level")
            }

        except Exception as e:
            logger.error(f"Failed to get token info: {e}")
            return {"error": str(e)}
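A hedged usage sketch for the client above (assumes the module-level `settings` can be constructed, i.e. `TENANT_ID`, `TENANT_DOMAIN`, and `SECRET_KEY` are available in the environment; the resource names follow the mappings defined in the class):

```python
import asyncio

from app.core.capability_client import CapabilityClient


async def main() -> None:
    client = CapabilityClient()
    token = await client.generate_capability_token(
        user_email="user@tenant.example",
        tenant_id="tenant-123",
        resources=["rag_processing", "llm"],
        expires_hours=1,
    )
    # Round-trip: verify the token we just issued and inspect its grants
    payload = await client.verify_capability_token(token)
    print(payload["capabilities"])


asyncio.run(main())
```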
289
apps/tenant-backend/app/core/config.py
Normal file
@@ -0,0 +1,289 @@
"""
GT 2.0 Tenant Backend Configuration

Environment-based configuration for tenant applications with perfect isolation.
Each tenant gets its own isolated backend instance with separate database files.
"""

import os
from typing import List, Optional

from pydantic_settings import BaseSettings
from pydantic import Field, validator


class Settings(BaseSettings):
    """Application settings with environment variable support"""

    # Environment
    environment: str = Field(default="development", description="Runtime environment")
    debug: bool = Field(default=False, description="Debug mode")

    # Tenant Identification (Critical for isolation)
    tenant_id: str = Field(..., description="Unique tenant identifier")
    tenant_domain: str = Field(..., description="Tenant domain (e.g., customer1)")

    # Database Configuration (PostgreSQL + PGVector direct connection)
    database_url: str = Field(
        default="postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres-primary:5432/gt2_tenants",
        description="PostgreSQL connection URL (direct to primary)"
    )

    # PostgreSQL Configuration
    postgres_schema: str = Field(
        default="tenant_test",
        description="PostgreSQL schema for tenant data (tenant_{tenant_domain})"
    )
    postgres_pool_size: int = Field(
        default=10,
        description="Connection pool size for PostgreSQL"
    )
    postgres_max_overflow: int = Field(
        default=20,
        description="Max overflow connections for PostgreSQL pool"
    )

    # Authentication & Security
    secret_key: str = Field(..., description="JWT signing key")
    algorithm: str = Field(default="HS256", description="JWT algorithm")

    # OAuth2 Configuration
    require_oauth2_auth: bool = Field(
        default=True,
        description="Require OAuth2 authentication for API endpoints"
    )
    oauth2_proxy_url: str = Field(
        default="http://oauth2-proxy:4180",
        description="Internal URL of OAuth2 Proxy service"
    )
    oauth2_issuer_url: str = Field(
        default="https://auth.gt2.com",
        description="OAuth2 provider issuer URL"
    )
    oauth2_audience: str = Field(
        default="gt2-tenant-client",
        description="OAuth2 token audience"
    )

    # Resource Cluster Integration
    resource_cluster_url: str = Field(
        default="http://localhost:8004",
        description="URL of the Resource Cluster API"
    )
    resource_cluster_api_key: Optional[str] = Field(
        default=None,
        description="API key for Resource Cluster authentication"
    )

    # MCP Service Configuration
    mcp_service_url: str = Field(
        default="http://resource-cluster:8000",
        description="URL of the MCP service for tool execution"
    )

    # Control Panel Integration
    control_panel_url: str = Field(
        default="http://localhost:8001",
        description="URL of the Control Panel API"
    )
    service_auth_token: str = Field(
        default="internal-service-token",
        description="Service-to-service authentication token"
    )

    # WebSocket Configuration
    websocket_ping_interval: int = Field(default=25, description="WebSocket ping interval")
    websocket_ping_timeout: int = Field(default=20, description="WebSocket ping timeout")

    # File Upload Configuration
    max_file_size_mb: int = Field(default=10, description="Maximum file upload size in MB")
    allowed_file_types: List[str] = Field(
        default=[".pdf", ".docx", ".txt", ".md", ".csv", ".xlsx"],
        description="Allowed file extensions for upload"
    )
    upload_directory: str = Field(
        default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}/uploads" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/uploads",
        description="Directory for uploaded files"
    )
    temp_directory: str = Field(
        default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}/temp" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/temp",
        description="Temporary directory for file processing"
    )
    file_storage_path: str = Field(
        default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}",
        description="Root directory for file storage (conversation files, etc.)"
    )

    # File Context Settings (for chat attachments)
    max_chunks_per_file: int = Field(
        default=50,
        description="Maximum chunks per file (enforces diversity across files)"
    )
    max_total_file_chunks: int = Field(
        default=100,
        description="Maximum total chunks across all attached files"
    )
    file_context_token_safety_margin: float = Field(
        default=0.05,
        description="Safety margin for token budget calculations (0.05 = 5%)"
    )

    # Rate Limiting
    rate_limit_requests: int = Field(default=1000, description="Requests per minute per IP")
    rate_limit_window_seconds: int = Field(default=60, description="Rate limit window")

    # CORS Configuration
    cors_origins: List[str] = Field(
        default=["http://localhost:3001", "http://localhost:3002", "https://*.gt2.com"],
        description="Allowed CORS origins"
    )

    # Security
    allowed_hosts: List[str] = Field(
        default=["localhost", "*.gt2.com", "testserver", "gentwo-tenant-backend", "tenant-backend"],
        description="Allowed host headers"
    )

    # Vector Storage Configuration (PGVector integrated with PostgreSQL)
    vector_dimensions: int = Field(
        default=384,
        description="Vector dimensions for embeddings (all-MiniLM-L6-v2 model)"
    )
    embedding_model: str = Field(
        default="all-MiniLM-L6-v2",
        description="Embedding model for document processing"
    )
    vector_similarity_threshold: float = Field(
        default=0.3,
        description="Minimum similarity threshold for vector search"
    )

    # Legacy ChromaDB Configuration (DEPRECATED - replaced by PGVector)
    chromadb_mode: str = Field(
        default="disabled",
        description="ChromaDB mode - DEPRECATED, using PGVector instead"
    )
    chromadb_host: str = Field(
        default_factory=lambda: f"tenant-{os.getenv('TENANT_DOMAIN', 'test')}-chromadb",
        description="ChromaDB host - DEPRECATED"
    )
    chromadb_port: int = Field(
        default=8000,
        description="ChromaDB HTTP port - DEPRECATED"
    )
    chromadb_path: str = Field(
        default_factory=lambda: f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/chromadb",
        description="ChromaDB file storage path - DEPRECATED"
    )

    # Redis removed - PostgreSQL handles all caching and session storage needs

    # Logging Configuration
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format: json or text")

    # Performance
    worker_processes: int = Field(default=1, description="Number of worker processes")
    max_connections: int = Field(default=100, description="Maximum concurrent connections")

    # Monitoring
    prometheus_enabled: bool = Field(default=True, description="Enable Prometheus metrics")
    prometheus_port: int = Field(default=9090, description="Prometheus metrics port")

    # Feature Flags
    enable_file_upload: bool = Field(default=True, description="Enable file upload feature")
    enable_voice_input: bool = Field(default=False, description="Enable voice input (future)")
    enable_document_analysis: bool = Field(default=True, description="Enable document analysis")

    @validator("tenant_id")
    def validate_tenant_id(cls, v):
        if not v or len(v) < 3:
            raise ValueError("Tenant ID must be at least 3 characters long")
        return v

    @validator("tenant_domain")
    def validate_tenant_domain(cls, v):
        if not v or not v.replace("-", "").replace("_", "").isalnum():
            raise ValueError("Tenant domain must be alphanumeric with optional hyphens/underscores")
        return v

    @validator("upload_directory")
    def validate_upload_directory(cls, v):
        # Ensure the upload directory exists with secure permissions
        os.makedirs(v, exist_ok=True, mode=0o700)
        return v

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }


def get_settings(tenant_id: Optional[str] = None) -> Settings:
    """Get tenant-scoped application settings"""
    # For development and testing, use simple settings without caching
    if os.getenv("ENVIRONMENT") in ["development", "test"]:
        return Settings()

    # In production, settings should be tenant-scoped
    # This prevents global state from affecting tenant isolation
    if tenant_id:
        # Create tenant-specific settings with proper isolation
        settings = Settings()
        # In production, this could load tenant-specific overrides
        return settings
    else:
        # Default settings for non-tenant operations
        return Settings()


# Security and isolation utilities
def get_tenant_data_path(tenant_domain: str) -> str:
    """Get the secure data path for a tenant"""
    if os.getenv('ENVIRONMENT') == 'test':
        return f"/tmp/gt2-data/{tenant_domain}"
    return f"/data/{tenant_domain}"


def get_tenant_database_url(tenant_domain: str) -> str:
    """Get the database URL for a specific tenant (PostgreSQL)"""
    return f"postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants"


def get_tenant_schema_name(tenant_domain: str) -> str:
    """Get the PostgreSQL schema name for a specific tenant"""
    # Clean domain name for schema usage
    clean_domain = tenant_domain.replace('-', '_').replace('.', '_').lower()
    return f"tenant_{clean_domain}"


def ensure_tenant_isolation(tenant_id: str) -> None:
    """Ensure proper tenant isolation is configured"""
    settings = get_settings()

    if settings.tenant_id != tenant_id:
        raise ValueError(f"Tenant ID mismatch: expected {settings.tenant_id}, got {tenant_id}")

    # Verify database path contains tenant identifier
    if settings.tenant_domain not in settings.database_path:
        raise ValueError("Database path does not contain tenant identifier - isolation breach risk")

    # Verify upload directory contains tenant identifier
    if settings.tenant_domain not in settings.upload_directory:
        raise ValueError("Upload directory does not contain tenant identifier - isolation breach risk")


# Development helpers
def is_development() -> bool:
    """Check if running in development mode"""
    return get_settings().environment == "development"


def is_production() -> bool:
    """Check if running in production mode"""
    return get_settings().environment == "production"
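A sketch of how these settings load, assuming the required fields (which have no defaults) are supplied through the environment — pydantic-settings reads matching, case-insensitive variable names:

```python
import os

from app.core.config import get_settings, get_tenant_schema_name

# tenant_id, tenant_domain and secret_key are required fields, so they must
# come from the environment or a .env file before Settings() is constructed.
os.environ.setdefault("TENANT_ID", "tenant-123")
os.environ.setdefault("TENANT_DOMAIN", "customer1")
os.environ.setdefault("SECRET_KEY", "change-me")

settings = get_settings()
print(settings.tenant_domain)                          # "customer1"
print(get_tenant_schema_name(settings.tenant_domain))  # "tenant_customer1"
```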
131
apps/tenant-backend/app/core/database.py
Normal file
@@ -0,0 +1,131 @@
"""
GT 2.0 Tenant Backend Database Configuration - PostgreSQL + PGVector Client

Migrated from DuckDB service to PostgreSQL + PGVector for enterprise readiness:
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
- BionicGPT Row Level Security patterns for enterprise isolation
- MVCC concurrency solving DuckDB file locking issues
- Hybrid vector + full-text search in single queries
- Connection pooling for 10,000+ concurrent connections
"""

import os
import logging
from typing import Generator, Optional, Any, Dict, List
from contextlib import contextmanager, asynccontextmanager

from sqlalchemy.ext.declarative import declarative_base

from app.core.config import get_settings
from app.core.postgresql_client import (
    get_postgresql_client, init_postgresql, close_postgresql,
    get_db_session, execute_query, execute_command,
    fetch_one, fetch_scalar, health_check, get_database_info
)

# Legacy DuckDB imports removed - PostgreSQL + PGVector only

# SQLAlchemy Base for ORM models
Base = declarative_base()

logger = logging.getLogger(__name__)
settings = get_settings()

# PostgreSQL client is managed by postgresql_client module


async def init_database() -> None:
    """Initialize PostgreSQL + PGVector connection"""
    logger.info("Initializing PostgreSQL + PGVector database connection...")

    try:
        await init_postgresql()
        logger.info("PostgreSQL + PGVector connection initialized successfully")

    except Exception as e:
        logger.error(f"Failed to initialize PostgreSQL database: {e}")
        raise


async def close_database() -> None:
    """Close PostgreSQL connections"""
    try:
        await close_postgresql()
        logger.info("PostgreSQL connections closed")
    except Exception as e:
        logger.error(f"Error closing PostgreSQL connections: {e}")


async def get_db_client_instance():
    """Get the PostgreSQL client instance"""
    return await get_postgresql_client()


# get_db_session is imported from postgresql_client

# execute_query is imported from postgresql_client

# execute_command is imported from postgresql_client


async def execute_transaction(commands: List[Dict[str, Any]]) -> List[int]:
    """Execute multiple commands in a transaction (PostgreSQL format)"""
    client = await get_postgresql_client()
    pg_commands = [(cmd.get('query', cmd.get('command', '')), tuple(cmd.get('params', {}).values())) for cmd in commands]
    return await client.execute_transaction(pg_commands)


# fetch_one is imported from postgresql_client


async def fetch_all(query: str, *args) -> List[Dict[str, Any]]:
    """Execute query and return all rows"""
    return await execute_query(query, *args)


# fetch_scalar is imported from postgresql_client

# get_database_info is imported from postgresql_client

# health_check is imported from postgresql_client


# Legacy compatibility functions (for gradual migration)
def get_db() -> Generator[None, None, None]:
    """Legacy sync database dependency - deprecated"""
    logger.warning("get_db() is deprecated. Use async get_db_session() instead")
    # Return a dummy generator for compatibility
    yield None


@contextmanager
def get_db_session_sync():
    """Legacy sync session - deprecated"""
    logger.warning("get_db_session_sync() is deprecated. Use async get_db_session() instead")
    yield None


def execute_raw_query(query: str, params: Optional[Dict] = None) -> List[Dict]:
    """Legacy sync query execution - deprecated"""
    logger.error("execute_raw_query() is deprecated and not supported with PostgreSQL async client")
    raise NotImplementedError("Use async execute_query() instead")


def verify_tenant_isolation() -> bool:
    """Verify tenant isolation - PostgreSQL schema-based isolation with RLS is always enabled"""
    return True


# Initialize database on module import (for FastAPI startup)
async def startup_database():
    """Initialize database during FastAPI startup"""
    await init_database()


async def shutdown_database():
    """Cleanup database during FastAPI shutdown"""
    await close_database()
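A minimal sketch of wiring `startup_database`/`shutdown_database` into a FastAPI lifespan handler (the application may register these differently; assumes the tenant environment variables required by `config.py` are set before import):

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.database import startup_database, shutdown_database


@asynccontextmanager
async def lifespan(app: FastAPI):
    await startup_database()       # open the PostgreSQL + PGVector pool
    try:
        yield
    finally:
        await shutdown_database()  # close connections on shutdown


app = FastAPI(lifespan=lifespan)
```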
348
apps/tenant-backend/app/core/database_interface.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
GT 2.0 Database Interface - DuckDB Implementation
|
||||
|
||||
Provides a unified interface for DuckDB database operations
|
||||
following GT 2.0 principles of Zero Downtime, Perfect Tenant Isolation, and Elegant Simplicity.
|
||||
Post-migration: SQLite has been completely replaced with DuckDB for enhanced MVCC performance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, AsyncGenerator, Union
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class DatabaseEngine(Enum):
|
||||
"""Supported database engines - DEPRECATED: Use PostgreSQL directly"""
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database configuration"""
|
||||
engine: DatabaseEngine
|
||||
database_path: str
|
||||
tenant_id: str
|
||||
shard_id: Optional[str] = None
|
||||
encryption_key: Optional[str] = None
|
||||
connection_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
"""Standardized query result"""
|
||||
rows: List[Dict[str, Any]]
|
||||
row_count: int
|
||||
columns: List[str]
|
||||
execution_time_ms: float
|
||||
|
||||
|
||||
class DatabaseInterface(ABC):
|
||||
"""
|
||||
Abstract database interface for GT 2.0 tenant isolation.
|
||||
|
||||
DuckDB implementation with MVCC concurrency for true zero-downtime operations,
|
||||
perfect tenant isolation, and 10x analytical performance improvements.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DatabaseConfig):
|
||||
self.config = config
|
||||
self.tenant_id = config.tenant_id
|
||||
self.database_path = config.database_path
|
||||
self.engine = config.engine
|
||||
|
||||
# Connection Management
|
||||
@abstractmethod
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize database connection and create tables"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close database connections"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def is_initialized(self) -> bool:
|
||||
"""Check if database is properly initialized"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncGenerator[Any, None]:
|
||||
"""Get database session context manager"""
|
||||
pass
|
||||
|
||||
# Schema Management
|
||||
@abstractmethod
|
||||
async def create_tables(self) -> None:
|
||||
"""Create all required tables"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_schema_version(self) -> Optional[str]:
|
||||
"""Get current database schema version"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def migrate_schema(self, target_version: str) -> bool:
|
||||
"""Migrate database schema to target version"""
|
||||
pass
|
||||
|
||||
# Query Operations
|
||||
@abstractmethod
|
||||
async def execute_query(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> QueryResult:
|
||||
"""Execute SELECT query and return results"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_command(
|
||||
self,
|
||||
command: str,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> int:
|
||||
"""Execute INSERT/UPDATE/DELETE command and return affected rows"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_batch(
|
||||
self,
|
||||
commands: List[str],
|
||||
params: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[int]:
|
||||
"""Execute batch commands in transaction"""
|
||||
pass
|
||||
|
||||
# Transaction Management
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def transaction(self) -> AsyncGenerator[Any, None]:
|
||||
"""Transaction context manager"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def begin_transaction(self) -> Any:
|
||||
"""Begin transaction and return transaction handle"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def commit_transaction(self, tx: Any) -> None:
|
||||
"""Commit transaction"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def rollback_transaction(self, tx: Any) -> None:
|
||||
"""Rollback transaction"""
|
||||
pass
|
||||
|
||||
# Health and Monitoring
|
||||
@abstractmethod
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Check database health and return status"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get database statistics"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def optimize(self) -> bool:
|
||||
"""Optimize database performance"""
|
||||
pass
|
||||
|
||||
# Backup and Recovery
|
||||
@abstractmethod
|
||||
async def backup(self, backup_path: str) -> bool:
|
||||
"""Create database backup"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def restore(self, backup_path: str) -> bool:
|
||||
"""Restore from database backup"""
|
||||
pass
|
||||
|
||||
# Sharding Support (DuckDB specific)
|
||||
@abstractmethod
|
||||
async def create_shard(self, shard_id: str) -> bool:
|
||||
"""Create new database shard"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_info(self) -> Dict[str, Any]:
|
||||
"""Get information about current shard"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def migrate_to_shard(self, source_db: 'DatabaseInterface') -> bool:
|
||||
"""Migrate data from another database instance"""
|
||||
pass
|
||||
|
||||
# Vector Operations (ChromaDB integration)
|
||||
@abstractmethod
|
||||
async def store_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadata: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Store embeddings with documents and metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def query_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: List[float],
|
||||
limit: int = 10,
|
||||
filter_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query embeddings by similarity"""
|
||||
pass
|
||||
|
||||
# Data Import/Export
|
||||
@abstractmethod
|
||||
async def export_data(
|
||||
self,
|
||||
format: str = "json",
|
||||
tables: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Export database data"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def import_data(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
format: str = "json",
|
||||
merge_strategy: str = "replace"
|
||||
) -> bool:
|
||||
"""Import database data"""
|
||||
pass
|
||||
|
||||
# Security and Encryption
|
||||
@abstractmethod
|
||||
async def encrypt_database(self, encryption_key: str) -> bool:
|
||||
"""Enable database encryption"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def verify_encryption(self) -> bool:
|
||||
"""Verify database encryption status"""
|
||||
pass
|
||||
|
||||
# Performance and Indexing
|
||||
@abstractmethod
|
||||
async def create_index(
|
||||
self,
|
||||
table: str,
|
||||
columns: List[str],
|
||||
index_name: Optional[str] = None,
|
||||
unique: bool = False
|
||||
) -> bool:
|
||||
"""Create database index"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def drop_index(self, index_name: str) -> bool:
|
||||
"""Drop database index"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_queries(self) -> Dict[str, Any]:
|
||||
"""Analyze query performance"""
|
||||
pass
|
||||
|
||||
# Utility Methods
|
||||
async def get_engine_info(self) -> Dict[str, Any]:
|
||||
"""Get database engine information"""
|
||||
return {
|
||||
"engine": self.engine.value,
|
||||
"tenant_id": self.tenant_id,
|
||||
"database_path": self.database_path,
|
||||
"shard_id": self.config.shard_id,
|
||||
"supports_mvcc": self.engine == DatabaseEngine.POSTGRESQL,
|
||||
"supports_sharding": self.engine == DatabaseEngine.POSTGRESQL,
|
||||
"file_based": True
|
||||
}
|
||||
|
||||
async def validate_tenant_isolation(self) -> bool:
|
||||
"""Validate that tenant isolation is maintained"""
|
||||
try:
|
||||
stats = await self.get_statistics()
|
||||
return (
|
||||
self.tenant_id in self.database_path and
|
||||
stats.get("isolated", False)
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class DatabaseFactory:
|
||||
"""Factory for creating database instances"""
|
||||
|
||||
@staticmethod
|
||||
async def create_database(config: DatabaseConfig) -> DatabaseInterface:
|
||||
"""Create database instance - PostgreSQL only"""
|
||||
raise NotImplementedError("Database interface deprecated. Use PostgreSQL directly via postgresql_client.py")
|
||||
|
||||
@staticmethod
|
||||
async def migrate_database(
|
||||
source_config: DatabaseConfig,
|
||||
target_config: DatabaseConfig,
|
||||
migration_options: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""Migrate data from source to target database"""
|
||||
source_db = await DatabaseFactory.create_database(source_config)
|
||||
target_db = await DatabaseFactory.create_database(target_config)
|
||||
|
||||
try:
|
||||
await source_db.initialize()
|
||||
await target_db.initialize()
|
||||
|
||||
# Export data from source
|
||||
data = await source_db.export_data()
|
||||
|
||||
# Import data to target
|
||||
success = await target_db.import_data(data)
|
||||
|
||||
if success and migration_options and migration_options.get("verify", True):
|
||||
# Verify migration
|
||||
source_stats = await source_db.get_statistics()
|
||||
target_stats = await target_db.get_statistics()
|
||||
|
||||
return source_stats.get("row_count", 0) == target_stats.get("row_count", 0)
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
await source_db.close()
|
||||
await target_db.close()
|
||||
|
||||
|
||||
# Error Classes
|
||||
class DatabaseError(Exception):
|
||||
"""Base database error"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseConnectionError(DatabaseError):
|
||||
"""Database connection error"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseMigrationError(DatabaseError):
|
||||
"""Database migration error"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseShardingError(DatabaseError):
|
||||
"""Database sharding error"""
|
||||
pass
|
||||
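A minimal sketch of how callers might lean on the error hierarchy above when wrapping legacy database operations; the wrapper below is illustrative only (it reuses the error classes defined above), and new code is expected to go through postgresql_client.py rather than the deprecated factory.

import logging

logger = logging.getLogger(__name__)


async def run_with_db_errors(operation, *args, **kwargs):
    """Sketch: normalize low-level failures into the DatabaseError hierarchy."""
    try:
        return await operation(*args, **kwargs)
    except (ConnectionError, OSError) as exc:
        logger.error(f"Database connectivity failure: {exc}")
        raise DatabaseConnectionError(str(exc)) from exc
    except DatabaseError:
        raise  # already classified, e.g. DatabaseMigrationError from migrate_database
    except Exception as exc:
        logger.error(f"Unclassified database failure: {exc}")
        raise DatabaseError(str(exc)) from exc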
265
apps/tenant-backend/app/core/dependencies/resource_access.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Resource Access Control Dependencies for FastAPI
|
||||
|
||||
Provides declarative access control for agents and datasets using team-based permissions.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import Depends, HTTPException
|
||||
from app.api.dependencies import get_current_user
|
||||
from app.services.team_service import TeamService
|
||||
from app.core.permissions import get_user_role
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def require_resource_access(
|
||||
resource_type: str,
|
||||
required_permission: str = "read"
|
||||
) -> Callable:
|
||||
"""
|
||||
FastAPI dependency factory for resource access control.
|
||||
|
||||
Creates a dependency that verifies user has required permission on a resource
|
||||
via ownership, organization visibility, or team membership.
|
||||
|
||||
Args:
|
||||
resource_type: 'agent' or 'dataset'
|
||||
required_permission: 'read' or 'edit' (default: 'read')
|
||||
|
||||
Returns:
|
||||
FastAPI dependency function
|
||||
|
||||
Usage:
|
||||
@router.get("/agents/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
_: None = Depends(require_resource_access("agent", "read"))
|
||||
):
|
||||
# User has read access if we reach here
|
||||
...
|
||||
|
||||
@router.put("/agents/{agent_id}")
|
||||
async def update_agent(
|
||||
agent_id: str,
|
||||
_: None = Depends(require_resource_access("agent", "edit"))
|
||||
):
|
||||
# User has edit access if we reach here
|
||||
...
|
||||
"""
|
||||
|
||||
async def check_access(
|
||||
resource_id: str,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> None:
|
||||
"""
|
||||
Verify user has required permission on resource.
|
||||
|
||||
Raises HTTPException(403) if access denied.
|
||||
"""
|
||||
user_id = current_user["user_id"]
|
||||
tenant_domain = current_user["tenant_domain"]
|
||||
user_email = current_user.get("email", user_id)
|
||||
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if admin/developer (bypass all checks)
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
if user_role in ["admin", "developer"]:
|
||||
logger.debug(f"Admin/developer {user_id} has full access to {resource_type} {resource_id}")
|
||||
return
|
||||
|
||||
# Check if user owns the resource
|
||||
ownership_query = f"""
|
||||
SELECT created_by FROM {resource_type}s
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
owner_id = await pg_client.fetch_scalar(ownership_query, resource_id, tenant_domain)
|
||||
|
||||
if owner_id and str(owner_id) == str(user_id):
|
||||
logger.debug(f"User {user_id} owns {resource_type} {resource_id}")
|
||||
return
|
||||
|
||||
# Check if resource is organization-wide
|
||||
visibility_query = f"""
|
||||
SELECT visibility FROM {resource_type}s
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
visibility = await pg_client.fetch_scalar(visibility_query, resource_id, tenant_domain)
|
||||
|
||||
if visibility == "organization":
|
||||
logger.debug(f"{resource_type.capitalize()} {resource_id} is organization-wide")
|
||||
return
|
||||
|
||||
# Check team-based access using TeamService
|
||||
team_service = TeamService(tenant_domain, user_id, user_email)
|
||||
has_permission = await team_service.check_user_resource_permission(
|
||||
user_id=user_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
required_permission=required_permission
|
||||
)
|
||||
|
||||
if has_permission:
|
||||
logger.debug(f"User {user_id} has {required_permission} permission on {resource_type} {resource_id} via team")
|
||||
return
|
||||
|
||||
# Access denied
|
||||
logger.warning(f"Access denied: User {user_id} cannot access {resource_type} {resource_id} (required: {required_permission})")
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"You do not have {required_permission} permission for this {resource_type}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking resource access: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error verifying {resource_type} access"
|
||||
)
|
||||
|
||||
return check_access
|
||||
|
||||
|
||||
def require_agent_access(required_permission: str = "read") -> Callable:
|
||||
"""
|
||||
Convenience wrapper for agent access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/agents/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
_: None = Depends(require_agent_access("read"))
|
||||
):
|
||||
...
|
||||
"""
|
||||
return require_resource_access("agent", required_permission)
|
||||
|
||||
|
||||
def require_dataset_access(required_permission: str = "read") -> Callable:
|
||||
"""
|
||||
Convenience wrapper for dataset access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/datasets/{dataset_id}")
|
||||
async def get_dataset(
|
||||
dataset_id: str,
|
||||
_: None = Depends(require_dataset_access("read"))
|
||||
):
|
||||
...
|
||||
"""
|
||||
return require_resource_access("dataset", required_permission)
|
||||
|
||||
|
||||
async def check_agent_edit_permission(
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
tenant_domain: str,
|
||||
user_email: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Helper function to check if user can edit an agent.
|
||||
|
||||
Can be used in service layer without FastAPI dependency injection.
|
||||
|
||||
Args:
|
||||
agent_id: UUID of the agent
|
||||
user_id: UUID of the user
|
||||
tenant_domain: Tenant domain
|
||||
user_email: User email (optional)
|
||||
|
||||
Returns:
|
||||
True if user can edit agent
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if admin/developer
|
||||
user_role = await get_user_role(pg_client, user_email or user_id, tenant_domain)
|
||||
if user_role in ["admin", "developer"]:
|
||||
return True
|
||||
|
||||
# Check ownership
|
||||
query = """
|
||||
SELECT created_by FROM agents
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
owner_id = await pg_client.fetch_scalar(query, agent_id, tenant_domain)
|
||||
|
||||
if owner_id and str(owner_id) == str(user_id):
|
||||
return True
|
||||
|
||||
# Check team edit permission
|
||||
team_service = TeamService(tenant_domain, user_id, user_email or user_id)
|
||||
return await team_service.check_user_resource_permission(
|
||||
user_id=user_id,
|
||||
resource_type="agent",
|
||||
resource_id=agent_id,
|
||||
required_permission="edit"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking agent edit permission: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_dataset_edit_permission(
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
tenant_domain: str,
|
||||
user_email: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Helper function to check if user can edit a dataset.
|
||||
|
||||
Can be used in service layer without FastAPI dependency injection.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset
|
||||
user_id: UUID of the user
|
||||
tenant_domain: Tenant domain
|
||||
user_email: User email (optional)
|
||||
|
||||
Returns:
|
||||
True if user can edit dataset
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if admin/developer
|
||||
user_role = await get_user_role(pg_client, user_email or user_id, tenant_domain)
|
||||
if user_role in ["admin", "developer"]:
|
||||
return True
|
||||
|
||||
# Check ownership
|
||||
query = """
|
||||
SELECT user_id FROM datasets
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
owner_id = await pg_client.fetch_scalar(query, dataset_id, tenant_domain)
|
||||
|
||||
if owner_id and str(owner_id) == str(user_id):
|
||||
return True
|
||||
|
||||
# Check team edit permission
|
||||
team_service = TeamService(tenant_domain, user_id, user_email or user_id)
|
||||
return await team_service.check_user_resource_permission(
|
||||
user_id=user_id,
|
||||
resource_type="dataset",
|
||||
resource_id=dataset_id,
|
||||
required_permission="edit"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking dataset edit permission: {e}")
|
||||
return False
|
||||
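A minimal sketch of using the service-layer helper above outside of FastAPI dependency injection; the function name, error message, and the assumption that current_user carries user_id, tenant_domain, and email are illustrative.

from fastapi import HTTPException

from app.core.dependencies.resource_access import check_dataset_edit_permission


async def rename_dataset(dataset_id: str, new_name: str, current_user: dict) -> None:
    """Sketch: enforce edit rights before mutating a dataset in a service method."""
    allowed = await check_dataset_edit_permission(
        dataset_id=dataset_id,
        user_id=current_user["user_id"],
        tenant_domain=current_user["tenant_domain"],
        user_email=current_user.get("email"),
    )
    if not allowed:
        raise HTTPException(status_code=403, detail="You do not have edit permission for this dataset")
    # ...proceed with the UPDATE once the check passes...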
169
apps/tenant-backend/app/core/logging_config.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend Logging Configuration
|
||||
|
||||
Structured logging with tenant isolation and security considerations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Setup logging configuration for the tenant backend"""
|
||||
settings = get_settings()
|
||||
|
||||
# Determine log directory based on environment
|
||||
if settings.environment == "test":
|
||||
log_dir = f"/tmp/gt2-data/{settings.tenant_domain}/logs"
|
||||
else:
|
||||
log_dir = f"/data/{settings.tenant_domain}/logs"
|
||||
|
||||
# Create logging configuration
|
||||
log_config: Dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
},
|
||||
"json": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
},
|
||||
"detailed": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(funcName)s() - %(message)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": settings.log_level,
|
||||
"formatter": "json" if settings.log_format == "json" else "default",
|
||||
"stream": sys.stdout,
|
||||
},
|
||||
"file": {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": "INFO",
|
||||
"formatter": "json" if settings.log_format == "json" else "detailed",
|
||||
"filename": f"{log_dir}/tenant-backend.log",
|
||||
"maxBytes": 10485760, # 10MB
|
||||
"backupCount": 5,
|
||||
"encoding": "utf-8",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"": { # Root logger
|
||||
"level": settings.log_level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"app": {
|
||||
"level": settings.log_level,
|
||||
"handlers": ["console", "file"] if settings.environment == "production" else ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"sqlalchemy.engine": {
|
||||
"level": "INFO" if settings.debug else "WARNING",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.access": {
|
||||
"level": "WARNING", # Suppress INFO level access logs (operational endpoints)
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.error": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Create log directory if it doesn't exist
|
||||
import os
|
||||
os.makedirs(log_dir, exist_ok=True, mode=0o700)
|
||||
|
||||
# Apply logging configuration
|
||||
logging.config.dictConfig(log_config)
|
||||
|
||||
# Add tenant context to all logs
|
||||
class TenantContextFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
record.tenant_id = settings.tenant_id
|
||||
record.tenant_domain = settings.tenant_domain
|
||||
return True
|
||||
|
||||
tenant_filter = TenantContextFilter()
|
||||
|
||||
# Add tenant filter to all handlers
|
||||
for handler in logging.getLogger().handlers:
|
||||
handler.addFilter(tenant_filter)
|
||||
|
||||
# Log startup information
|
||||
logger = logging.getLogger("app.startup")
|
||||
logger.info(
|
||||
"Tenant backend logging initialized",
|
||||
extra={
|
||||
"tenant_id": settings.tenant_id,
|
||||
"tenant_domain": settings.tenant_domain,
|
||||
"environment": settings.environment,
|
||||
"log_level": settings.log_level,
|
||||
"log_format": settings.log_format,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get logger with consistent naming and formatting"""
|
||||
return logging.getLogger(f"app.{name}")
|
||||
|
||||
|
||||
|
||||
class SecurityRedactionFilter(logging.Filter):
|
||||
"""Filter to redact sensitive information from logs"""
|
||||
|
||||
SENSITIVE_FIELDS = [
|
||||
"password", "token", "secret", "key", "authorization",
|
||||
"cookie", "session", "csrf", "api_key", "jwt"
|
||||
]
|
||||
|
||||
def filter(self, record):
|
||||
if hasattr(record, 'args') and record.args:
|
||||
# Redact sensitive information from log messages
|
||||
record.args = self._redact_sensitive_data(record.args)
|
||||
|
||||
if hasattr(record, 'msg') and isinstance(record.msg, str):
|
||||
for field in self.SENSITIVE_FIELDS:
|
||||
if field.lower() in record.msg.lower():
|
||||
record.msg = record.msg.replace(field, "[REDACTED]")
|
||||
|
||||
return True
|
||||
|
||||
def _redact_sensitive_data(self, data):
|
||||
"""Recursively redact sensitive data from log arguments"""
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
key: "[REDACTED]" if any(sensitive in key.lower() for sensitive in self.SENSITIVE_FIELDS)
|
||||
else self._redact_sensitive_data(value)
|
||||
for key, value in data.items()
|
||||
}
|
||||
elif isinstance(data, (list, tuple)):
|
||||
return type(data)(self._redact_sensitive_data(item) for item in data)
|
||||
return data
|
||||
|
||||
|
||||
def setup_security_logging():
|
||||
"""Setup security-focused logging with redaction"""
|
||||
security_filter = SecurityRedactionFilter()
|
||||
|
||||
# Add security filter to all loggers
|
||||
for name in ["app", "uvicorn", "sqlalchemy"]:
|
||||
logger = logging.getLogger(name)
|
||||
logger.addFilter(security_filter)
|
||||
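A minimal startup sketch, assuming the application calls both setup functions from a FastAPI lifespan hook; the hook placement and log messages are assumptions, not part of this module.

from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.logging_config import get_logger, setup_logging, setup_security_logging


@asynccontextmanager
async def lifespan(app: FastAPI):
    setup_logging()            # configure handlers, formatters, and tenant context
    setup_security_logging()   # attach the redaction filter to app/uvicorn/sqlalchemy loggers
    logger = get_logger("startup")
    logger.info("Application starting")
    yield
    logger.info("Application shutting down")


app = FastAPI(lifespan=lifespan)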
175
apps/tenant-backend/app/core/path_security.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Path Security Utilities for GT AI OS
|
||||
|
||||
Provides path sanitization and validation to prevent path traversal attacks.
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def sanitize_path_component(component: str) -> str:
|
||||
"""
|
||||
Sanitize a single path component to prevent path traversal.
|
||||
|
||||
Removes or replaces dangerous characters including:
|
||||
- Path separators (/ and \\)
|
||||
- Parent directory references (..)
|
||||
- Null bytes
|
||||
- Other special characters
|
||||
|
||||
Args:
|
||||
component: The path component to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized component safe for use in file paths
|
||||
"""
|
||||
if not component:
|
||||
return ""
|
||||
|
||||
# Remove null bytes
|
||||
sanitized = component.replace('\x00', '')
|
||||
|
||||
# Remove path separators
|
||||
sanitized = re.sub(r'[/\\]', '', sanitized)
|
||||
|
||||
# Remove parent directory references
|
||||
sanitized = sanitized.replace('..', '')
|
||||
|
||||
# For tenant domains and similar identifiers, allow alphanumeric, hyphen, underscore
|
||||
# For filenames, allow alphanumeric, hyphen, underscore, and single dots
|
||||
sanitized = re.sub(r'[^a-zA-Z0-9_\-.]', '_', sanitized)
|
||||
|
||||
# Prevent leading dots (hidden files) and multiple consecutive dots
|
||||
sanitized = re.sub(r'^\.+', '', sanitized)
|
||||
sanitized = re.sub(r'\.{2,}', '.', sanitized)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_tenant_domain(domain: str) -> str:
|
||||
"""
|
||||
Sanitize a tenant domain for safe use in file paths.
|
||||
|
||||
More restrictive than general path component sanitization.
|
||||
Only allows lowercase alphanumeric characters, hyphens, and underscores.
|
||||
|
||||
Args:
|
||||
domain: The tenant domain to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized domain safe for use in file paths
|
||||
"""
|
||||
if not domain:
|
||||
raise ValueError("Tenant domain cannot be empty")
|
||||
|
||||
# Convert to lowercase and sanitize
|
||||
sanitized = domain.lower()
|
||||
sanitized = re.sub(r'[^a-z0-9_\-]', '_', sanitized)
|
||||
sanitized = sanitized.strip('_-')
|
||||
|
||||
if not sanitized:
|
||||
raise ValueError("Tenant domain resulted in empty string after sanitization")
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
Sanitize a filename for safe storage.
|
||||
|
||||
Preserves the file extension but sanitizes the rest.
|
||||
|
||||
Args:
|
||||
filename: The filename to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized filename
|
||||
"""
|
||||
if not filename:
|
||||
return ""
|
||||
|
||||
# Get the extension
|
||||
path = Path(filename)
|
||||
stem = path.stem
|
||||
suffix = path.suffix
|
||||
|
||||
# Sanitize the stem (filename without extension)
|
||||
safe_stem = sanitize_path_component(stem)
|
||||
|
||||
# Sanitize the extension (should just be alphanumeric)
|
||||
safe_suffix = ""
|
||||
if suffix:
|
||||
safe_suffix = '.' + re.sub(r'[^a-zA-Z0-9]', '', suffix[1:])
|
||||
|
||||
result = safe_stem + safe_suffix
|
||||
|
||||
if not result:
|
||||
result = "unnamed"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def safe_join_path(base: Path, *components: str, require_within_base: bool = True) -> Path:
|
||||
"""
|
||||
Safely join path components, preventing traversal attacks.
|
||||
|
||||
Args:
|
||||
base: The base directory that all paths must stay within
|
||||
components: Path components to join to the base
|
||||
require_within_base: If True, verify the result is within base
|
||||
|
||||
Returns:
|
||||
The joined path
|
||||
|
||||
Raises:
|
||||
ValueError: If the resulting path would be outside the base directory
|
||||
"""
|
||||
if not base:
|
||||
raise ValueError("Base path cannot be empty")
|
||||
|
||||
# Sanitize all components
|
||||
sanitized = [sanitize_path_component(c) for c in components if c]
|
||||
|
||||
# Filter out empty components
|
||||
sanitized = [c for c in sanitized if c]
|
||||
|
||||
if not sanitized:
|
||||
return base
|
||||
|
||||
# Join the path
|
||||
result = base.joinpath(*sanitized)
|
||||
|
||||
# Verify the result is within the base directory
|
||||
if require_within_base:
|
||||
try:
|
||||
resolved_base = base.resolve()
|
||||
resolved_result = result.resolve()
|
||||
|
||||
# Check if result is within base
|
||||
resolved_result.relative_to(resolved_base)
|
||||
except (ValueError, RuntimeError):
|
||||
raise ValueError(f"Path traversal detected: result would be outside base directory")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_file_extension(filename: str, allowed_extensions: Optional[list] = None) -> bool:
|
||||
"""
|
||||
Validate that a file has an allowed extension.
|
||||
|
||||
Args:
|
||||
filename: The filename to check
|
||||
allowed_extensions: List of allowed extensions (e.g., ['.txt', '.pdf']).
|
||||
If None, all extensions are allowed.
|
||||
|
||||
Returns:
|
||||
True if the extension is allowed, False otherwise
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
return True
|
||||
|
||||
path = Path(filename)
|
||||
extension = path.suffix.lower()
|
||||
|
||||
return extension in [ext.lower() for ext in allowed_extensions]
|
||||
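A minimal sketch combining these helpers when persisting an uploaded file; the /data storage root and the extension whitelist are assumptions chosen for illustration.

from pathlib import Path

from app.core.path_security import (
    safe_join_path,
    sanitize_filename,
    sanitize_tenant_domain,
    validate_file_extension,
)

DATA_ROOT = Path("/data")                      # assumed storage root
ALLOWED_EXTENSIONS = [".pdf", ".txt", ".md"]   # assumed whitelist


def resolve_upload_path(tenant_domain: str, raw_filename: str) -> Path:
    """Return a traversal-safe location for an uploaded document."""
    if not validate_file_extension(raw_filename, ALLOWED_EXTENSIONS):
        raise ValueError(f"Unsupported file type: {raw_filename}")

    tenant = sanitize_tenant_domain(tenant_domain)   # e.g. "Acme.Corp" -> "acme_corp"
    filename = sanitize_filename(raw_filename)       # e.g. "report (final).pdf" -> "report__final_.pdf"
    # safe_join_path re-sanitizes each component and verifies the result stays under DATA_ROOT.
    return safe_join_path(DATA_ROOT, tenant, "uploads", filename)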
138
apps/tenant-backend/app/core/permissions.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
GT 2.0 Role-Based Permissions
|
||||
Enforces organization-level resource sharing based on user roles.
|
||||
|
||||
Visibility Levels:
|
||||
- individual: Only the creator can see and edit
|
||||
- organization: All users can read, only admins/developers can create and edit
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Role hierarchy: admin/developer > analyst > student
|
||||
ADMIN_ROLES = ["admin", "developer"]
|
||||
|
||||
# Visibility levels
|
||||
VISIBILITY_INDIVIDUAL = "individual"
|
||||
VISIBILITY_ORGANIZATION = "organization"
|
||||
|
||||
|
||||
async def get_user_role(pg_client, user_email: str, tenant_domain: str) -> str:
|
||||
"""
|
||||
Get the role for a user in the tenant database.
|
||||
Returns: 'admin', 'developer', 'analyst', or 'student'
|
||||
"""
|
||||
query = """
|
||||
SELECT role FROM users
|
||||
WHERE email = $1
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
LIMIT 1
|
||||
"""
|
||||
role = await pg_client.fetch_scalar(query, user_email, tenant_domain)
|
||||
return role or "student"
|
||||
|
||||
|
||||
def can_share_to_organization(user_role: str) -> bool:
|
||||
"""
|
||||
Check if a user can share resources at the organization level.
|
||||
Only admin and developer roles can share to organization.
|
||||
"""
|
||||
return user_role in ADMIN_ROLES
|
||||
|
||||
|
||||
def validate_visibility_permission(visibility: str, user_role: str) -> None:
|
||||
"""
|
||||
Validate that the user has permission to set the given visibility level.
|
||||
Raises HTTPException if not authorized.
|
||||
|
||||
Rules:
|
||||
- admin/developer: Can set individual or organization visibility
|
||||
- analyst/student: Can only set individual visibility
|
||||
"""
|
||||
if visibility == VISIBILITY_ORGANIZATION and not can_share_to_organization(user_role):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Only admin and developer users can share resources to organization. Your role: {user_role}"
|
||||
)
|
||||
|
||||
|
||||
def can_edit_resource(resource_creator_id: str, current_user_id: str, user_role: str, resource_visibility: str) -> bool:
|
||||
"""
|
||||
Check if user can edit a resource.
|
||||
|
||||
Rules:
|
||||
- Owner can always edit their own resources
|
||||
- Admin/developer can edit any resource
|
||||
- Organization-shared resources: read-only for non-admins who didn't create it
|
||||
"""
|
||||
# Admin and developer can edit anything
|
||||
if user_role in ADMIN_ROLES:
|
||||
return True
|
||||
|
||||
# Owner can always edit
|
||||
if resource_creator_id == current_user_id:
|
||||
return True
|
||||
|
||||
# Organization resources are read-only for non-admins
|
||||
return False
|
||||
|
||||
|
||||
def can_delete_resource(resource_creator_id: str, current_user_id: str, user_role: str) -> bool:
|
||||
"""
|
||||
Check if user can delete a resource.
|
||||
|
||||
Rules:
|
||||
- Owner can delete their own resources
|
||||
- Admin/developer can delete any resource
|
||||
- Others cannot delete
|
||||
"""
|
||||
# Admin and developer can delete anything
|
||||
if user_role in ADMIN_ROLES:
|
||||
return True
|
||||
|
||||
# Owner can delete
|
||||
if resource_creator_id == current_user_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_effective_owner(resource_creator_id: str, current_user_id: str, user_role: str) -> bool:
|
||||
"""
|
||||
Check if user is effective owner of a resource.
|
||||
|
||||
Effective owners have identical access to actual owners:
|
||||
- Actual resource creator
|
||||
- Admin/developer users (tenant admins)
|
||||
|
||||
This determines whether user gets owner-level field visibility in ResponseFilter
|
||||
and whether they can perform owner-only actions like sharing.
|
||||
|
||||
Note: Tenant isolation is enforced at query level via tenant_id checks.
|
||||
This function only determines ownership semantics within the tenant.
|
||||
|
||||
Args:
|
||||
resource_creator_id: UUID of resource creator
|
||||
current_user_id: UUID of current user
|
||||
user_role: User's role in tenant (admin, developer, analyst, student)
|
||||
|
||||
Returns:
|
||||
True if user should have owner-level access
|
||||
|
||||
Examples:
|
||||
>>> is_effective_owner("user123", "admin456", "admin")
|
||||
True # Admin has owner-level access to all resources
|
||||
>>> is_effective_owner("user123", "user123", "student")
|
||||
True # Actual owner
|
||||
>>> is_effective_owner("user123", "user456", "analyst")
|
||||
False # Different user, not admin
|
||||
"""
|
||||
# Admins and developers have identical access to owners
|
||||
if user_role in ADMIN_ROLES:
|
||||
return True
|
||||
|
||||
# Actual owner
|
||||
return resource_creator_id == current_user_id
|
||||
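A minimal sketch of how an update path could apply these checks together; the handler shape, field names, and error message are illustrative.

from fastapi import HTTPException, status

from app.core.permissions import (
    can_edit_resource,
    get_user_role,
    validate_visibility_permission,
)


async def update_agent_visibility(pg_client, agent: dict, new_visibility: str, current_user: dict, tenant_domain: str) -> None:
    """Sketch: gate a visibility change behind role and ownership checks."""
    role = await get_user_role(pg_client, current_user["email"], tenant_domain)

    # Only admins/developers may promote a resource to organization visibility.
    validate_visibility_permission(new_visibility, role)

    # Owners and admins/developers may edit; everyone else gets a 403.
    if not can_edit_resource(agent["created_by"], current_user["user_id"], role, agent.get("visibility", "individual")):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not allowed to edit this agent")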
498
apps/tenant-backend/app/core/postgresql_client.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
GT 2.0 PostgreSQL + PGVector Client for Tenant Backend
|
||||
|
||||
Replaces DuckDB service with direct PostgreSQL connections, providing:
|
||||
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
|
||||
- BionicGPT Row Level Security patterns for enterprise isolation
|
||||
- MVCC concurrency solving DuckDB file locking issues
|
||||
- Hybrid vector + full-text search in single queries
|
||||
- Connection pooling for 10,000+ concurrent connections
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator, Tuple, Union
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
from asyncpg import Pool, Connection
|
||||
from asyncpg.exceptions import PostgresError
|
||||
|
||||
from app.core.config import get_settings, get_tenant_schema_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PostgreSQLClient:
|
||||
"""PostgreSQL + PGVector client for tenant backend operations"""
|
||||
|
||||
def __init__(self, database_url: str, tenant_domain: str):
|
||||
self.database_url = database_url
|
||||
self.tenant_domain = tenant_domain
|
||||
self.schema_name = get_tenant_schema_name(tenant_domain)
|
||||
self._pool: Optional[Pool] = None
|
||||
self._initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize connection pool and verify schema"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info(f"Initializing PostgreSQL connection pool for tenant: {self.tenant_domain}")
|
||||
logger.info(f"Schema: {self.schema_name}, URL: {self.database_url}")
|
||||
|
||||
try:
|
||||
# Create connection pool with resilient settings
|
||||
# Sized for 100+ concurrent users with RAG/vector search workloads
|
||||
self._pool = await asyncpg.create_pool(
|
||||
self.database_url,
|
||||
min_size=10,
|
||||
max_size=50, # Increased from 20 to handle 100+ concurrent users
|
||||
command_timeout=120, # Increased from 60s for queries under load
|
||||
timeout=10, # Connection acquire timeout increased for high load
|
||||
max_inactive_connection_lifetime=3600, # Recycle connections after 1 hour
|
||||
server_settings={
|
||||
'application_name': f'gt2_tenant_{self.tenant_domain}'
|
||||
},
|
||||
# Enable prepared statements for direct postgres connection (performance gain)
|
||||
statement_cache_size=100
|
||||
)
|
||||
|
||||
# Verify schema exists and has required tables
|
||||
await self._verify_schema()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"PostgreSQL client initialized successfully for tenant: {self.tenant_domain}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL client: {e}")
|
||||
if self._pool:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close connection pool"""
|
||||
if self._pool:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
self._initialized = False
|
||||
logger.info(f"PostgreSQL connection pool closed for tenant: {self.tenant_domain}")
|
||||
|
||||
async def _verify_schema(self) -> None:
|
||||
"""Verify tenant schema exists and has required tables"""
|
||||
async with self._pool.acquire() as conn:
|
||||
# Check if schema exists
|
||||
schema_exists = await conn.fetchval("""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.schemata
|
||||
WHERE schema_name = $1
|
||||
)
|
||||
""", self.schema_name)
|
||||
|
||||
if not schema_exists:
|
||||
raise RuntimeError(f"Tenant schema '{self.schema_name}' does not exist. Run schema initialization first.")
|
||||
|
||||
# Check for required tables
|
||||
required_tables = ['tenants', 'users', 'agents', 'datasets', 'conversations', 'messages', 'documents', 'document_chunks']
|
||||
|
||||
for table in required_tables:
|
||||
table_exists = await conn.fetchval(f"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = $1 AND table_name = $2
|
||||
)
|
||||
""", self.schema_name, table)
|
||||
|
||||
if not table_exists:
|
||||
logger.warning(f"Table '{table}' not found in schema '{self.schema_name}'")
|
||||
|
||||
logger.info(f"Schema verification complete for tenant: {self.tenant_domain}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator[Connection, None]:
|
||||
"""Get a connection from the pool"""
|
||||
if not self._pool:
|
||||
raise RuntimeError("PostgreSQL client not initialized. Call initialize() first.")
|
||||
|
||||
async with self._pool.acquire() as conn:
|
||||
try:
|
||||
# Set schema search path for this connection
|
||||
await conn.execute(f"SET search_path TO {self.schema_name}, public")
|
||||
|
||||
# Session variable logging removed - no longer using RLS
|
||||
|
||||
yield conn
|
||||
except Exception as e:
|
||||
logger.error(f"Database connection error: {e}")
|
||||
raise
|
||||
|
||||
async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
|
||||
"""Execute a SELECT query and return results"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
rows = await conn.fetch(query, *args)
|
||||
return [dict(row) for row in rows]
|
||||
except PostgresError as e:
|
||||
logger.error(f"Query execution failed: {e}, Query: {query}")
|
||||
raise
|
||||
|
||||
async def execute_command(self, command: str, *args) -> int:
|
||||
"""Execute an INSERT/UPDATE/DELETE command and return affected rows"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
result = await conn.execute(command, *args)
|
||||
# Parse result like "INSERT 0 5" to get affected rows
|
||||
return int(result.split()[-1]) if result else 0
|
||||
except PostgresError as e:
|
||||
logger.error(f"Command execution failed: {e}, Command: {command}")
|
||||
raise
|
||||
|
||||
async def fetch_one(self, query: str, *args) -> Optional[Dict[str, Any]]:
|
||||
"""Execute query and return first row"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
row = await conn.fetchrow(query, *args)
|
||||
return dict(row) if row else None
|
||||
except PostgresError as e:
|
||||
logger.error(f"Fetch one failed: {e}, Query: {query}")
|
||||
raise
|
||||
|
||||
async def fetch_scalar(self, query: str, *args) -> Any:
|
||||
"""Execute query and return single value"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
return await conn.fetchval(query, *args)
|
||||
except PostgresError as e:
|
||||
logger.error(f"Fetch scalar failed: {e}, Query: {query}")
|
||||
raise
|
||||
|
||||
async def execute_transaction(self, commands: List[Tuple[str, tuple]]) -> List[int]:
|
||||
"""Execute multiple commands in a transaction"""
|
||||
async with self.get_connection() as conn:
|
||||
async with conn.transaction():
|
||||
results = []
|
||||
for command, args in commands:
|
||||
try:
|
||||
result = await conn.execute(command, *args)
|
||||
results.append(int(result.split()[-1]) if result else 0)
|
||||
except PostgresError as e:
|
||||
logger.error(f"Transaction command failed: {e}, Command: {command}")
|
||||
raise
|
||||
return results
|
||||
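A minimal sketch of batching related writes through execute_transaction so they commit or roll back together; the table and column names are assumptions based on the required-tables list above.

from typing import List

from app.core.postgresql_client import PostgreSQLClient


async def record_agent_usage(client: PostgreSQLClient, agent_id: str, user_id: str) -> List[int]:
    """Sketch: bump a counter and insert a row atomically."""
    commands = [
        (
            "UPDATE agents SET usage_count = usage_count + 1 WHERE id = $1::uuid",
            (agent_id,),
        ),
        (
            "INSERT INTO conversations (agent_id, user_id, title) VALUES ($1::uuid, $2::uuid, $3)",
            (agent_id, user_id, "New conversation"),
        ),
    ]
    return await client.execute_transaction(commands)  # e.g. [1, 1]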
|
||||
# Vector Search Operations (PGVector)
|
||||
|
||||
async def vector_similarity_search(
|
||||
self,
|
||||
query_vector: List[float],
|
||||
table: str = "document_chunks",
|
||||
limit: int = 10,
|
||||
similarity_threshold: float = 0.3,
|
||||
user_id: Optional[str] = None,
|
||||
dataset_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Perform vector similarity search using PGVector"""
|
||||
|
||||
# Convert Python list to PostgreSQL array format
|
||||
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
id,
|
||||
content,
|
||||
1 - (embedding <=> $1::vector) as similarity_score,
|
||||
metadata
|
||||
FROM {table}
|
||||
WHERE embedding IS NOT NULL
|
||||
AND 1 - (embedding <=> $1::vector) > $2
|
||||
"""
|
||||
|
||||
params = [vector_str, similarity_threshold]
|
||||
param_idx = 3
|
||||
|
||||
# Add user isolation if specified
|
||||
if user_id:
|
||||
query += f" AND user_id = ${param_idx}"
|
||||
params.append(user_id)
|
||||
param_idx += 1
|
||||
|
||||
# Add dataset filtering if specified
|
||||
if dataset_id:
|
||||
query += f" AND dataset_id = ${param_idx}"
|
||||
params.append(dataset_id)
|
||||
param_idx += 1
|
||||
|
||||
query += f" ORDER BY embedding <=> $1::vector LIMIT ${param_idx}"
|
||||
params.append(limit)
|
||||
|
||||
return await self.execute_query(query, *params)
|
||||
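A minimal retrieval sketch, assuming the query embedding has already been produced elsewhere (for example via the Resource Cluster client later in this commit); the threshold and limit simply echo the defaults above.

from typing import Any, Dict, List

from app.core.postgresql_client import get_postgresql_client


async def find_similar_chunks(query_embedding: List[float], user_id: str, dataset_id: str) -> List[Dict[str, Any]]:
    """Sketch: similarity search scoped to one user's dataset."""
    client = await get_postgresql_client()
    return await client.vector_similarity_search(
        query_vector=query_embedding,
        limit=5,
        similarity_threshold=0.3,
        user_id=user_id,         # keeps results within the caller's rows
        dataset_id=dataset_id,   # optional narrowing to a single dataset
    )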
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query_text: str,
|
||||
query_vector: List[float],
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
similarity_threshold: float = 0.3,
|
||||
text_weight: float = 0.3,
|
||||
vector_weight: float = 0.7,
|
||||
dataset_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Perform hybrid search combining vector similarity and full-text search"""
|
||||
|
||||
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
|
||||
|
||||
# Use the enhanced_hybrid_search_chunks function from BionicGPT integration
|
||||
query = """
|
||||
SELECT
|
||||
id,
|
||||
document_id,
|
||||
content,
|
||||
similarity_score,
|
||||
text_rank,
|
||||
combined_score,
|
||||
metadata,
|
||||
access_verified
|
||||
FROM enhanced_hybrid_search_chunks($1, $2::vector, $3::uuid, $4, $5, $6, $7, $8)
|
||||
"""
|
||||
|
||||
return await self.execute_query(
|
||||
query,
|
||||
query_text,
|
||||
vector_str,
|
||||
user_id,
|
||||
dataset_id,
|
||||
limit,
|
||||
similarity_threshold,
|
||||
text_weight,
|
||||
vector_weight
|
||||
)
|
||||
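A short sketch showing how the keyword arguments map onto the positional parameters of enhanced_hybrid_search_chunks; the 0.3/0.7 split mirrors the defaults above and the function name is taken from the query.

from typing import Any, Dict, List

from app.core.postgresql_client import PostgreSQLClient


async def search_documents(client: PostgreSQLClient, question: str, question_embedding: List[float], user_id: str) -> List[Dict[str, Any]]:
    """Sketch: lexical + vector retrieval in a single call."""
    return await client.hybrid_search(
        query_text=question,              # -> $1 full-text query
        query_vector=question_embedding,  # -> $2 PGVector embedding
        user_id=user_id,                  # -> $3 row-ownership filter
        dataset_id=None,                  # -> $4 search across all of the user's datasets
        limit=10,                         # -> $5
        similarity_threshold=0.3,         # -> $6
        text_weight=0.3,                  # -> $7
        vector_weight=0.7,                # -> $8
    )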
|
||||
async def insert_document_chunk(
|
||||
self,
|
||||
document_id: str,
|
||||
tenant_id: int,
|
||||
user_id: str,
|
||||
chunk_index: int,
|
||||
content: str,
|
||||
content_hash: str,
|
||||
embedding: List[float],
|
||||
dataset_id: Optional[str] = None,
|
||||
token_count: int = 0,
|
||||
metadata: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""Insert a document chunk with vector embedding"""
|
||||
|
||||
vector_str = '[' + ','.join(map(str, embedding)) + ']'
|
||||
metadata_json = json.dumps(metadata or {})
|
||||
|
||||
query = """
|
||||
INSERT INTO document_chunks (
|
||||
document_id, tenant_id, user_id, dataset_id, chunk_index,
|
||||
content, content_hash, token_count, embedding, metadata
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10::jsonb)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
return await self.fetch_scalar(
|
||||
query,
|
||||
document_id, tenant_id, user_id, dataset_id, chunk_index,
|
||||
content, content_hash, token_count, vector_str, metadata_json
|
||||
)
|
||||
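A minimal ingestion sketch, assuming chunk texts and matching embeddings are already available; the tenant_id value, metadata payload, and the whitespace token count are placeholders, not part of this module.

import hashlib
from typing import List

from app.core.postgresql_client import PostgreSQLClient


async def store_chunks(
    client: PostgreSQLClient,
    document_id: str,
    tenant_id: int,
    user_id: str,
    dataset_id: str,
    chunks: List[str],
    embeddings: List[List[float]],
) -> List[str]:
    """Sketch: persist parallel lists of chunk text and embeddings."""
    chunk_ids = []
    for index, (text, embedding) in enumerate(zip(chunks, embeddings)):
        chunk_id = await client.insert_document_chunk(
            document_id=document_id,
            tenant_id=tenant_id,
            user_id=user_id,
            chunk_index=index,
            content=text,
            content_hash=hashlib.sha256(text.encode("utf-8")).hexdigest(),
            embedding=embedding,
            dataset_id=dataset_id,
            token_count=len(text.split()),  # rough placeholder, not a real tokenizer
            metadata={"source": "upload"},
        )
        chunk_ids.append(chunk_id)
    return chunk_ids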
|
||||
# Health Check and Statistics
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check on PostgreSQL connection"""
|
||||
try:
|
||||
if not self._pool:
|
||||
return {"status": "unhealthy", "reason": "Connection pool not initialized"}
|
||||
|
||||
# Test basic connectivity
|
||||
test_result = await self.fetch_scalar("SELECT 1")
|
||||
|
||||
# Get pool statistics
|
||||
pool_stats = {
|
||||
"size": self._pool.get_size(),
|
||||
"min_size": self._pool.get_min_size(),
|
||||
"max_size": self._pool.get_max_size(),
|
||||
"idle_size": self._pool.get_idle_size()
|
||||
}
|
||||
|
||||
# Test schema access
|
||||
schema_test = await self.fetch_scalar("""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.schemata
|
||||
WHERE schema_name = $1
|
||||
)
|
||||
""", self.schema_name)
|
||||
|
||||
return {
|
||||
"status": "healthy" if test_result == 1 and schema_test else "degraded",
|
||||
"connectivity": "ok" if test_result == 1 else "failed",
|
||||
"schema_access": "ok" if schema_test else "failed",
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"schema_name": self.schema_name,
|
||||
"pool_stats": pool_stats,
|
||||
"database_type": "postgresql_pgvector"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL health check failed: {e}")
|
||||
return {"status": "unhealthy", "reason": str(e)}
|
||||
|
||||
async def get_database_stats(self) -> Dict[str, Any]:
|
||||
"""Get database statistics for monitoring"""
|
||||
try:
|
||||
# Get table counts and sizes
|
||||
stats_query = """
|
||||
SELECT
|
||||
schemaname,
|
||||
tablename,
|
||||
n_tup_ins as inserts,
|
||||
n_tup_upd as updates,
|
||||
n_tup_del as deletes,
|
||||
n_live_tup as live_tuples,
|
||||
n_dead_tup as dead_tuples
|
||||
FROM pg_stat_user_tables
|
||||
WHERE schemaname = $1
|
||||
"""
|
||||
|
||||
table_stats = await self.execute_query(stats_query, self.schema_name)
|
||||
|
||||
# Get total schema size
|
||||
size_query = """
|
||||
SELECT pg_size_pretty(
|
||||
SUM(pg_total_relation_size(quote_ident(schemaname)||'.'||quote_ident(tablename)))
|
||||
) as schema_size
|
||||
FROM pg_tables
|
||||
WHERE schemaname = $1
|
||||
"""
|
||||
|
||||
schema_size = await self.fetch_scalar(size_query, self.schema_name)
|
||||
|
||||
# Get vector index statistics if available
|
||||
vector_stats_query = """
|
||||
SELECT
|
||||
COUNT(*) as vector_count,
|
||||
AVG(vector_dims(embedding)) as avg_dimensions
|
||||
FROM document_chunks
|
||||
WHERE embedding IS NOT NULL
|
||||
"""
|
||||
|
||||
try:
|
||||
vector_stats = await self.fetch_one(vector_stats_query)
|
||||
except Exception:
|
||||
vector_stats = {"vector_count": 0, "avg_dimensions": 0}
|
||||
|
||||
return {
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"schema_name": self.schema_name,
|
||||
"schema_size": schema_size,
|
||||
"table_stats": table_stats,
|
||||
"vector_stats": vector_stats,
|
||||
"engine_type": "PostgreSQL + PGVector",
|
||||
"mvcc_enabled": True,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database statistics: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Global client instance (singleton pattern for tenant backend)
|
||||
_pg_client: Optional[PostgreSQLClient] = None
|
||||
|
||||
|
||||
async def get_postgresql_client() -> PostgreSQLClient:
|
||||
"""Get or create PostgreSQL client instance"""
|
||||
global _pg_client
|
||||
|
||||
if not _pg_client:
|
||||
settings = get_settings()
|
||||
_pg_client = PostgreSQLClient(
|
||||
database_url=settings.database_url,
|
||||
tenant_domain=settings.tenant_domain
|
||||
)
|
||||
await _pg_client.initialize()
|
||||
|
||||
return _pg_client
|
||||
|
||||
|
||||
async def init_postgresql() -> None:
|
||||
"""Initialize PostgreSQL client during startup"""
|
||||
logger.info("Initializing PostgreSQL client...")
|
||||
await get_postgresql_client()
|
||||
logger.info("PostgreSQL client initialized successfully")
|
||||
|
||||
|
||||
async def close_postgresql() -> None:
|
||||
"""Close PostgreSQL client during shutdown"""
|
||||
global _pg_client
|
||||
|
||||
if _pg_client:
|
||||
await _pg_client.close()
|
||||
_pg_client = None
|
||||
logger.info("PostgreSQL client closed")
|
||||
|
||||
|
||||
# Context manager for database operations
|
||||
@asynccontextmanager
|
||||
async def get_db_session():
|
||||
"""Async context manager for database operations"""
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
# Convenience functions for common operations
|
||||
async def execute_query(query: str, *args) -> List[Dict[str, Any]]:
|
||||
"""Execute a SELECT query"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.execute_query(query, *args)
|
||||
|
||||
|
||||
async def execute_command(command: str, *args) -> int:
|
||||
"""Execute an INSERT/UPDATE/DELETE command"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.execute_command(command, *args)
|
||||
|
||||
|
||||
async def fetch_one(query: str, *args) -> Optional[Dict[str, Any]]:
|
||||
"""Execute query and return first row"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.fetch_one(query, *args)
|
||||
|
||||
|
||||
async def fetch_scalar(query: str, *args) -> Any:
|
||||
"""Execute query and return single value"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.fetch_scalar(query, *args)
|
||||
|
||||
|
||||
async def health_check() -> Dict[str, Any]:
|
||||
"""Perform database health check"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
return await client.health_check()
|
||||
except Exception as e:
|
||||
return {"status": "unhealthy", "reason": str(e)}
|
||||
|
||||
|
||||
async def get_database_info() -> Dict[str, Any]:
|
||||
"""Get database information and statistics"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
return await client.get_database_stats()
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
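A minimal lifecycle sketch: initialize the shared pool once, reuse the module-level convenience helpers, and close on shutdown; the sample query assumes the agents table from the schema verification above.

import asyncio

from app.core import postgresql_client as db


async def main() -> None:
    await db.init_postgresql()                 # build the shared pool once
    try:
        print(await db.health_check())         # {"status": "healthy", ...}
        agents = await db.execute_query("SELECT id, name FROM agents LIMIT 5")
        print(agents)
    finally:
        await db.close_postgresql()            # release connections


if __name__ == "__main__":
    asyncio.run(main())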
531
apps/tenant-backend/app/core/resource_client.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
Resource Cluster Client for GT 2.0 Tenant Backend
|
||||
|
||||
Provides stateless access to Resource Cluster services including:
|
||||
- Document processing
|
||||
- Embedding generation
|
||||
- Vector storage (ChromaDB)
|
||||
- Model inference
|
||||
|
||||
Perfect tenant isolation with capability-based authentication.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import gc
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.capability_client import CapabilityClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResourceClusterClient:
|
||||
"""
|
||||
Client for accessing Resource Cluster services with capability-based auth.
|
||||
|
||||
GT 2.0 Security Principles:
|
||||
- Capability tokens for fine-grained access control
|
||||
- Stateless operations (no data persistence in Resource Cluster)
|
||||
- Perfect tenant isolation
|
||||
- Immediate memory cleanup
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.capability_client = CapabilityClient()
|
||||
|
||||
# Resource Cluster endpoints
|
||||
# IMPORTANT: Use Docker service name for stability across container restarts
|
||||
# Fixed 2025-09-12: Changed from hardcoded IP to service name for reliability
|
||||
self.base_url = getattr(
|
||||
self.settings,
|
||||
'resource_cluster_url', # Matches Pydantic field name (case insensitive)
|
||||
'http://gentwo-resource-backend:8000' # Fallback uses service name, not IP
|
||||
)
|
||||
|
||||
self.endpoints = {
|
||||
'document_processor': f"{self.base_url}/api/v1/process/document",
|
||||
'embedding_generator': f"{self.base_url}/api/v1/embeddings/generate",
|
||||
'chromadb_backend': f"{self.base_url}/api/v1/vectors",
|
||||
'inference': f"{self.base_url}/api/v1/ai/chat/completions" # Updated to match actual endpoint
|
||||
}
|
||||
|
||||
# Request timeouts
|
||||
self.request_timeout = 300 # seconds - 5 minutes for complex agent operations
|
||||
self.upload_timeout = 300 # seconds for large documents
|
||||
|
||||
logger.info("Resource Cluster client initialized")
|
||||
|
||||
async def _get_capability_token(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
resources: List[str]
|
||||
) -> str:
|
||||
"""Generate capability token for Resource Cluster access"""
|
||||
try:
|
||||
token = await self.capability_client.generate_capability_token(
|
||||
user_email=user_id, # Using user_id as email for now
|
||||
tenant_id=tenant_id,
|
||||
resources=resources,
|
||||
expires_hours=1
|
||||
)
|
||||
return token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate capability token: {e}")
|
||||
raise
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
data: Dict[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
resources: List[str],
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make authenticated request to Resource Cluster"""
|
||||
try:
|
||||
# Get capability token
|
||||
token = await self._get_capability_token(tenant_id, user_id, resources)
|
||||
|
||||
# Prepare headers
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {token}',
|
||||
'X-Tenant-ID': tenant_id,
|
||||
'X-User-ID': user_id,
|
||||
'X-Request-ID': f"{tenant_id}_{user_id}_{datetime.utcnow().timestamp()}"
|
||||
}
|
||||
|
||||
# Make request
|
||||
timeout_config = aiohttp.ClientTimeout(total=timeout or self.request_timeout)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout_config) as session:
|
||||
async with session.request(
|
||||
method=method.upper(),
|
||||
url=endpoint,
|
||||
json=data,
|
||||
headers=headers
|
||||
) as response:
|
||||
|
||||
if response.status not in [200, 201]:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(
|
||||
f"Resource Cluster error: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Resource Cluster request failed: {e}")
|
||||
raise
|
||||
|
||||
# Document Processing
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
content: bytes,
|
||||
document_type: str,
|
||||
strategy_type: str = "hybrid",
|
||||
tenant_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process document into chunks via Resource Cluster"""
|
||||
try:
|
||||
# Convert bytes to base64 for JSON transport
|
||||
import base64
|
||||
content_b64 = base64.b64encode(content).decode('utf-8')
|
||||
|
||||
request_data = {
|
||||
"content": content_b64,
|
||||
"document_type": document_type,
|
||||
"strategy": {
|
||||
"strategy_type": strategy_type,
|
||||
"chunk_size": 512,
|
||||
"chunk_overlap": 128
|
||||
}
|
||||
}
|
||||
|
||||
# Clear original content from memory
|
||||
del content
|
||||
gc.collect()
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['document_processor'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['document_processing'],
|
||||
timeout=self.upload_timeout
|
||||
)
|
||||
|
||||
chunks = result.get('chunks', [])
|
||||
logger.info(f"Processed document into {len(chunks)} chunks")
|
||||
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Document processing failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
# Embedding Generation
|
||||
|
||||
async def generate_document_embeddings(
|
||||
self,
|
||||
documents: List[str],
|
||||
tenant_id: str,
|
||||
user_id: str
|
||||
) -> List[List[float]]:
|
||||
"""Generate embeddings for documents"""
|
||||
try:
|
||||
request_data = {
|
||||
"texts": documents,
|
||||
"model": "BAAI/bge-m3",
|
||||
"instruction": None # Document embeddings don't need instruction
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['embedding_generator'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['embedding_generation']
|
||||
)
|
||||
|
||||
embeddings = result.get('embeddings', [])
|
||||
|
||||
# Clear documents from memory
|
||||
del documents
|
||||
gc.collect()
|
||||
|
||||
logger.info(f"Generated {len(embeddings)} document embeddings")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Document embedding generation failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def generate_query_embeddings(
|
||||
self,
|
||||
queries: List[str],
|
||||
tenant_id: str,
|
||||
user_id: str
|
||||
) -> List[List[float]]:
|
||||
"""Generate embeddings for queries"""
|
||||
try:
|
||||
request_data = {
|
||||
"texts": queries,
|
||||
"model": "BAAI/bge-m3",
|
||||
"instruction": "Represent this sentence for searching relevant passages: "
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['embedding_generator'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['embedding_generation']
|
||||
)
|
||||
|
||||
embeddings = result.get('embeddings', [])
|
||||
|
||||
# Clear queries from memory
|
||||
del queries
|
||||
gc.collect()
|
||||
|
||||
logger.info(f"Generated {len(embeddings)} query embeddings")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Query embedding generation failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
# Vector Storage (ChromaDB)
|
||||
|
||||
async def create_vector_collection(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""Create vector collection in ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name,
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/collections",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
success = result.get('success', False)
|
||||
logger.info(f"Created vector collection for {dataset_name}: {success}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector collection creation failed: {e}")
|
||||
raise
|
||||
|
||||
async def store_vectors(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str,
|
||||
documents: List[str],
|
||||
embeddings: List[List[float]],
|
||||
metadata: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None
|
||||
) -> bool:
|
||||
"""Store vectors in ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name,
|
||||
"documents": documents,
|
||||
"embeddings": embeddings,
|
||||
"metadata": metadata or [],
|
||||
"ids": ids
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/store",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
# Clear vectors from memory immediately
|
||||
del documents, embeddings
|
||||
gc.collect()
|
||||
|
||||
success = result.get('success', False)
|
||||
logger.info(f"Stored vectors in {dataset_name}: {success}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector storage failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def search_vectors(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str,
|
||||
query_embedding: List[float],
|
||||
top_k: int = 5,
|
||||
filter_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search vectors in ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name,
|
||||
"query_embedding": query_embedding,
|
||||
"top_k": top_k,
|
||||
"filter_metadata": filter_metadata or {}
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/search",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
# Clear query embedding from memory
|
||||
del query_embedding
|
||||
gc.collect()
|
||||
|
||||
results = result.get('results', [])
|
||||
logger.info(f"Found {len(results)} vector search results")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector search failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def delete_vector_collection(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str
|
||||
) -> bool:
|
||||
"""Delete vector collection from ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='DELETE',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/collections",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
success = result.get('success', False)
|
||||
logger.info(f"Deleted vector collection {dataset_name}: {success}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector collection deletion failed: {e}")
|
||||
raise
|
||||
|
||||
# Model Inference
|
||||
|
||||
async def inference_with_context(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
context: str,
|
||||
model: str = "llama-3.1-70b-versatile",
|
||||
tenant_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Perform inference with RAG context"""
|
||||
try:
|
||||
# Inject context into system message
|
||||
enhanced_messages = []
|
||||
system_context = f"Use the following context to answer the user's question:\n\n{context}\n\n"
|
||||
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
enhanced_msg = msg.copy()
|
||||
enhanced_msg["content"] = system_context + enhanced_msg["content"]
|
||||
enhanced_messages.append(enhanced_msg)
|
||||
else:
|
||||
enhanced_messages.append(msg)
|
||||
|
||||
# Add system message if none exists
|
||||
if not any(msg.get("role") == "system" for msg in enhanced_messages):
|
||||
enhanced_messages.insert(0, {
|
||||
"role": "system",
|
||||
"content": system_context + "You are a helpful AI agent."
|
||||
})
|
||||
|
||||
request_data = {
|
||||
"messages": enhanced_messages,
|
||||
"model": model,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4000,
|
||||
"user_id": user_id,
|
||||
"tenant_id": tenant_id
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['inference'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['llm_inference']
|
||||
)
|
||||
|
||||
# Clear context from memory
|
||||
del context, enhanced_messages
|
||||
gc.collect()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Inference with context failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def check_health(self) -> Dict[str, Any]:
|
||||
"""Check Resource Cluster health"""
|
||||
try:
|
||||
# Test basic connectivity
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.base_url}/health") as response:
|
||||
if response.status == 200:
|
||||
health_data = await response.json()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"resource_cluster": health_data,
|
||||
"endpoints": list(self.endpoints.keys()),
|
||||
"base_url": self.base_url
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": f"Health check failed: {response.status}",
|
||||
"base_url": self.base_url
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"base_url": self.base_url
|
||||
}
|
||||
|
||||
async def call_inference_endpoint(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
endpoint: str = "chat/completions",
|
||||
data: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Call AI inference endpoint on Resource Cluster"""
|
||||
try:
|
||||
# Use the direct inference endpoint
|
||||
inference_url = self.endpoints['inference']
|
||||
|
||||
# Add tenant/user context to request
|
||||
request_data = data.copy() if data else {}
|
||||
|
||||
# Make request with capability token
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=inference_url,
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['llm'] # Use valid ResourceType from resource cluster
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Inference endpoint call failed: {e}")
|
||||
raise
|
||||
|
||||
# Streaming removed for reliability - using non-streaming only
|
||||
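Usage sketch for the client methods above. This is illustrative only: the enclosing client class and its constructor are not shown in this diff, so the `client` instance here is an assumption.

# Illustrative usage; `client` is an assumed, already-configured instance of the
# enclosing Resource Cluster client class (not shown in this diff).
import asyncio

async def demo(client):
    health = await client.check_health()
    print(health["status"])

    answer = await client.inference_with_context(
        messages=[{"role": "user", "content": "Summarize the onboarding guide."}],
        context="Retrieved RAG chunks would be concatenated here.",
        tenant_id="tenant-a",
        user_id="user@example.com",
    )
    print(answer)

    deleted = await client.delete_vector_collection(
        tenant_id="tenant-a",
        user_id="user@example.com",
        dataset_name="onboarding-docs",
    )
    print(f"collection removed: {deleted}")

# asyncio.run(demo(client)) once a client instance is available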
320
apps/tenant-backend/app/core/response_filter.py
Normal file
@@ -0,0 +1,320 @@
"""
|
||||
Response Filtering Utilities for GT 2.0
|
||||
|
||||
Provides field-level authorization and data filtering for API responses.
|
||||
Implements principle of least privilege - users only see data they're authorized to access.
|
||||
|
||||
Security principles:
|
||||
1. Owner-only fields: resource_preferences, advanced RAG configs (max_chunks_per_query, history_context)
|
||||
2. Viewer fields: Public + usage stats + prompt_template + personality_config + dataset connections
|
||||
(Team members with read access need these fields to effectively use shared agents)
|
||||
3. Public fields: id, name, description, category, basic metadata
|
||||
4. No internal UUIDs, implementation details, or system configuration exposure
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseFilter:
|
||||
"""Filter API responses based on user permissions and access level"""
|
||||
|
||||
# Define field access levels for agents
|
||||
# REQUIRED fields that must always be present for AgentResponse schema
|
||||
AGENT_REQUIRED_FIELDS = {
|
||||
'id', 'name', 'description', 'created_at', 'updated_at'
|
||||
}
|
||||
|
||||
AGENT_PUBLIC_FIELDS = AGENT_REQUIRED_FIELDS | {
|
||||
'category', 'conversation_count', 'usage_count', 'is_favorite', 'tags',
|
||||
'created_by_name', 'can_edit', 'can_delete', 'is_owner',
|
||||
# Include these for display purposes
|
||||
'model', 'visibility', 'disclaimer', 'easy_prompts',
|
||||
# Dataset connections for showing dataset count on agent tiles
|
||||
'dataset_connection', 'selected_dataset_ids'
|
||||
}
|
||||
|
||||
AGENT_VIEWER_FIELDS = AGENT_PUBLIC_FIELDS | {
|
||||
'temperature', 'max_tokens', 'total_cost_cents', 'template_id',
|
||||
# Essential fields for using shared agents (team collaboration)
|
||||
'prompt_template', 'personality_config',
|
||||
'dataset_connection', 'selected_dataset_ids'
|
||||
}
|
||||
|
||||
AGENT_OWNER_FIELDS = AGENT_VIEWER_FIELDS | {
|
||||
# Advanced configuration fields (owner-only)
|
||||
'resource_preferences', 'max_chunks_per_query', 'history_context',
|
||||
# Team sharing configuration (owner-only for editing)
|
||||
'team_shares'
|
||||
}
|
||||
|
||||
# Define field access levels for datasets
|
||||
# Fields for all users (public/shared datasets) - stats are informational, not sensitive
|
||||
DATASET_PUBLIC_FIELDS = {
|
||||
'id', 'name', 'description', 'created_by_name', 'owner_name',
|
||||
'document_count', 'chunk_count', 'vector_count', 'storage_size_mb',
|
||||
'tags', 'created_at', 'updated_at', 'access_group',
|
||||
# Permission flags for UI controls
|
||||
'is_owner', 'can_edit', 'can_delete', 'can_share',
|
||||
# Team sharing flag for proper visibility indicators
|
||||
'shared_via_team'
|
||||
}
|
||||
|
||||
DATASET_VIEWER_FIELDS = DATASET_PUBLIC_FIELDS | {
|
||||
'summary' # Viewers can see dataset summary
|
||||
}
|
||||
|
||||
DATASET_OWNER_FIELDS = DATASET_VIEWER_FIELDS | {
|
||||
# Only owners see internal configuration
|
||||
'owner_id', 'team_members', 'chunking_strategy', 'chunk_size',
|
||||
'chunk_overlap', 'embedding_model', 'summary_generated_at',
|
||||
# Team sharing configuration (owner-only for editing)
|
||||
'team_shares'
|
||||
}
|
||||
|
||||
# Define field access levels for files
|
||||
# Public fields include processing info since it's informational metadata, not sensitive
|
||||
FILE_PUBLIC_FIELDS = {
|
||||
'id', 'original_filename', 'content_type', 'file_type', 'file_size', 'file_size_bytes',
|
||||
'created_at', 'updated_at', 'category',
|
||||
# Processing fields - informational, not sensitive
|
||||
'processing_status', 'chunk_count', 'processing_progress', 'processing_stage',
|
||||
# Permission flags for UI controls
|
||||
'can_delete'
|
||||
}
|
||||
|
||||
FILE_OWNER_FIELDS = FILE_PUBLIC_FIELDS | {
|
||||
'user_id', 'dataset_id', 'storage_path', 'metadata'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def filter_agent_response(
|
||||
agent_data: Dict[str, Any],
|
||||
is_owner: bool = False,
|
||||
can_view: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter agent response fields based on user permissions
|
||||
|
||||
Args:
|
||||
agent_data: Full agent data dictionary
|
||||
is_owner: Whether user owns this agent
|
||||
can_view: Whether user can view detailed information
|
||||
|
||||
Returns:
|
||||
Filtered dictionary with only authorized fields
|
||||
"""
|
||||
if is_owner:
|
||||
allowed_fields = ResponseFilter.AGENT_OWNER_FIELDS
|
||||
logger.info(f"🔓 Agent '{agent_data.get('name', 'Unknown')}': Using OWNER fields (is_owner=True, can_view={can_view})")
|
||||
elif can_view:
|
||||
allowed_fields = ResponseFilter.AGENT_VIEWER_FIELDS
|
||||
logger.info(f"👁️ Agent '{agent_data.get('name', 'Unknown')}': Using VIEWER fields (is_owner=False, can_view=True)")
|
||||
else:
|
||||
allowed_fields = ResponseFilter.AGENT_PUBLIC_FIELDS
|
||||
logger.info(f"🌍 Agent '{agent_data.get('name', 'Unknown')}': Using PUBLIC fields (is_owner=False, can_view=False)")
|
||||
|
||||
filtered = {
|
||||
key: value for key, value in agent_data.items()
|
||||
if key in allowed_fields
|
||||
}
|
||||
|
||||
# Ensure defaults for optional fields that were filtered out
|
||||
# This prevents AgentResponse schema validation errors
|
||||
default_values = {
|
||||
'personality_config': {},
|
||||
'resource_preferences': {},
|
||||
'tags': [],
|
||||
'easy_prompts': [],
|
||||
'conversation_count': 0,
|
||||
'usage_count': 0,
|
||||
'total_cost_cents': 0,
|
||||
'is_favorite': False,
|
||||
'can_edit': False,
|
||||
'can_delete': False,
|
||||
'is_owner': is_owner
|
||||
}
|
||||
|
||||
for key, default_value in default_values.items():
|
||||
if key not in filtered:
|
||||
filtered[key] = default_value
|
||||
|
||||
# Log field filtering for security audit
|
||||
removed_fields = set(agent_data.keys()) - set(filtered.keys())
|
||||
if removed_fields:
|
||||
logger.info(
|
||||
f"🔒 Filtered agent '{agent_data.get('name', 'Unknown')}' - removed fields: {removed_fields} "
|
||||
f"(is_owner={is_owner}, can_view={can_view})"
|
||||
)
|
||||
|
||||
# Special logging for prompt_template field
|
||||
if 'prompt_template' in agent_data:
|
||||
if 'prompt_template' in filtered:
|
||||
logger.info(f"✅ Agent '{agent_data.get('name', 'Unknown')}': prompt_template INCLUDED in response")
|
||||
else:
|
||||
logger.warning(f"❌ Agent '{agent_data.get('name', 'Unknown')}': prompt_template FILTERED OUT (is_owner={is_owner}, can_view={can_view})")
|
||||
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def filter_dataset_response(
|
||||
dataset_data: Dict[str, Any],
|
||||
is_owner: bool = False,
|
||||
can_view: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter dataset response fields based on user permissions
|
||||
|
||||
Args:
|
||||
dataset_data: Full dataset data dictionary
|
||||
is_owner: Whether user owns this dataset
|
||||
can_view: Whether user can view the dataset
|
||||
|
||||
Returns:
|
||||
Filtered dictionary with only authorized fields
|
||||
"""
|
||||
if is_owner:
|
||||
allowed_fields = ResponseFilter.DATASET_OWNER_FIELDS
|
||||
elif can_view:
|
||||
allowed_fields = ResponseFilter.DATASET_VIEWER_FIELDS
|
||||
else:
|
||||
allowed_fields = ResponseFilter.DATASET_PUBLIC_FIELDS
|
||||
|
||||
filtered = {
|
||||
key: value for key, value in dataset_data.items()
|
||||
if key in allowed_fields
|
||||
}
|
||||
|
||||
# Security: Never expose owner_id UUID to non-owners
|
||||
if not is_owner and 'owner_id' in filtered:
|
||||
del filtered['owner_id']
|
||||
|
||||
# Ensure defaults for optional fields to prevent schema validation errors
|
||||
default_values = {
|
||||
'tags': [],
|
||||
'is_owner': is_owner,
|
||||
'can_edit': False,
|
||||
'can_delete': False,
|
||||
'can_share': False,
|
||||
# Always set these to None for non-owners (security)
|
||||
'team_members': None if not is_owner else filtered.get('team_members', []),
|
||||
'owner_id': None if not is_owner else filtered.get('owner_id'),
|
||||
# Internal fields - null for all except detail view
|
||||
'agent_has_access': None,
|
||||
'user_owns': None,
|
||||
# Stats fields - use actual values or safe defaults for frontend compatibility
|
||||
# These are informational only, not sensitive
|
||||
'chunk_count': filtered.get('chunk_count', 0),
|
||||
'vector_count': filtered.get('vector_count', 0),
|
||||
'storage_size_mb': filtered.get('storage_size_mb', 0.0),
|
||||
'updated_at': filtered.get('updated_at'),
|
||||
'summary': None
|
||||
}
|
||||
|
||||
for key, default_value in default_values.items():
|
||||
if key not in filtered:
|
||||
filtered[key] = default_value
|
||||
|
||||
# Log field filtering for security audit
|
||||
removed_fields = set(dataset_data.keys()) - set(filtered.keys())
|
||||
if removed_fields:
|
||||
logger.debug(
|
||||
f"Filtered dataset response - removed fields: {removed_fields} "
|
||||
f"(is_owner={is_owner}, can_view={can_view})"
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def filter_file_response(
|
||||
file_data: Dict[str, Any],
|
||||
is_owner: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter file response fields based on user permissions
|
||||
|
||||
Args:
|
||||
file_data: Full file data dictionary
|
||||
is_owner: Whether user owns this file
|
||||
|
||||
Returns:
|
||||
Filtered dictionary with only authorized fields
|
||||
"""
|
||||
allowed_fields = (
|
||||
ResponseFilter.FILE_OWNER_FIELDS if is_owner
|
||||
else ResponseFilter.FILE_PUBLIC_FIELDS
|
||||
)
|
||||
|
||||
filtered = {
|
||||
key: value for key, value in file_data.items()
|
||||
if key in allowed_fields
|
||||
}
|
||||
|
||||
# Log field filtering for security audit
|
||||
removed_fields = set(file_data.keys()) - set(filtered.keys())
|
||||
if removed_fields:
|
||||
logger.debug(
|
||||
f"Filtered file response - removed fields: {removed_fields} "
|
||||
f"(is_owner={is_owner})"
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def filter_batch_responses(
|
||||
items: List[Dict[str, Any]],
|
||||
filter_func: callable,
|
||||
ownership_map: Optional[Dict[str, bool]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Filter a batch of items using the provided filter function
|
||||
|
||||
Args:
|
||||
items: List of items to filter
|
||||
filter_func: Function to apply to each item (e.g., filter_agent_response)
|
||||
ownership_map: Optional map of item_id -> is_owner boolean
|
||||
|
||||
Returns:
|
||||
List of filtered items
|
||||
"""
|
||||
filtered_items = []
|
||||
|
||||
for item in items:
|
||||
item_id = item.get('id')
|
||||
is_owner = ownership_map.get(item_id, False) if ownership_map else False
|
||||
|
||||
filtered_item = filter_func(item, is_owner=is_owner)
|
||||
filtered_items.append(filtered_item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
@staticmethod
|
||||
def sanitize_dataset_summary(
|
||||
summary_data: Dict[str, Any],
|
||||
user_can_access: bool = True
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Sanitize dataset summary for inclusion in chat context
|
||||
|
||||
Args:
|
||||
summary_data: Dataset summary with metadata
|
||||
user_can_access: Whether user should have access to this dataset
|
||||
|
||||
Returns:
|
||||
Sanitized summary or None if user shouldn't access
|
||||
"""
|
||||
if not user_can_access:
|
||||
return None
|
||||
|
||||
# Only include safe fields in summary
|
||||
safe_fields = {
|
||||
'id', 'name', 'description', 'summary',
|
||||
'document_count', 'chunk_count'
|
||||
}
|
||||
|
||||
return {
|
||||
key: value for key, value in summary_data.items()
|
||||
if key in safe_fields
|
||||
}
|
||||
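A short usage sketch for ResponseFilter. The record below is invented for illustration; only functions defined above are used, and the assertions follow directly from the field sets and defaults in this file.

# Illustrative agent record; only the keys matter, the values are made up.
agent_row = {
    "id": "agent-1",
    "name": "Support Agent",
    "description": "Answers support questions",
    "created_at": "2025-01-01T00:00:00Z",
    "updated_at": "2025-01-02T00:00:00Z",
    "prompt_template": "You are a support agent...",
    "resource_preferences": {"gpu": "a100"},      # owner-only field
    "internal_debug_flags": {"trace": True},      # not in any field set, always dropped
}

owner_view = ResponseFilter.filter_agent_response(agent_row, is_owner=True)
viewer_view = ResponseFilter.filter_agent_response(agent_row, is_owner=False, can_view=True)
public_view = ResponseFilter.filter_agent_response(agent_row, is_owner=False, can_view=False)

assert owner_view["resource_preferences"] == {"gpu": "a100"}
assert viewer_view["resource_preferences"] == {}      # filtered out, then refilled with the safe default
assert "prompt_template" in viewer_view               # viewers need it to use shared agents
assert "prompt_template" not in public_view
assert "internal_debug_flags" not in owner_view       # unknown fields never pass the allow-list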
314
apps/tenant-backend/app/core/security.py
Normal file
@@ -0,0 +1,314 @@
"""
|
||||
Security module for GT 2.0 Tenant Backend
|
||||
|
||||
Provides JWT capability token verification and user authentication.
|
||||
"""
|
||||
|
||||
import os
|
||||
import jwt
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import Header
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_jwt_secret() -> str:
|
||||
"""Get JWT secret from environment variable.
|
||||
|
||||
The JWT_SECRET is auto-generated by installers using:
|
||||
openssl rand -hex 32
|
||||
|
||||
This provides a 256-bit secret suitable for HS256 signing.
|
||||
"""
|
||||
secret = os.environ.get('JWT_SECRET')
|
||||
if not secret:
|
||||
raise ValueError("JWT_SECRET environment variable is required. Run the installer to generate one.")
|
||||
return secret
|
||||
|
||||
|
||||
def verify_capability_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT capability token using HS256 symmetric key
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Token payload if valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
secret = get_jwt_secret()
|
||||
|
||||
# Verify token with HS256 symmetric key
|
||||
payload = jwt.decode(token, secret, algorithms=["HS256"])
|
||||
|
||||
# Check expiration
|
||||
if "exp" in payload:
|
||||
if datetime.utcnow().timestamp() > payload["exp"]:
|
||||
logger.warning("Token expired")
|
||||
return None
|
||||
|
||||
return payload
|
||||
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.warning(f"Invalid token: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_capability_token(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
capabilities: list,
|
||||
expires_hours: int = 4
|
||||
) -> str:
|
||||
"""
|
||||
Create JWT capability token using HS256 symmetric key
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tenant_id: Tenant domain
|
||||
capabilities: List of capability objects
|
||||
expires_hours: Token expiration in hours
|
||||
|
||||
Returns:
|
||||
JWT token string
|
||||
"""
|
||||
try:
|
||||
secret = get_jwt_secret()
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": user_id,
|
||||
"user_type": "tenant_user",
|
||||
|
||||
# Current tenant context (primary structure)
|
||||
"current_tenant": {
|
||||
"id": tenant_id,
|
||||
"domain": tenant_id,
|
||||
"name": f"Tenant {tenant_id}",
|
||||
"role": "tenant_user",
|
||||
"display_name": user_id,
|
||||
"email": user_id,
|
||||
"is_primary": True,
|
||||
"capabilities": capabilities
|
||||
},
|
||||
|
||||
# Available tenants for tenant switching
|
||||
"available_tenants": [{
|
||||
"id": tenant_id,
|
||||
"domain": tenant_id,
|
||||
"name": f"Tenant {tenant_id}",
|
||||
"role": "tenant_user"
|
||||
}],
|
||||
|
||||
# Standard JWT fields
|
||||
"iat": datetime.utcnow().timestamp(),
|
||||
"exp": (datetime.utcnow() + timedelta(hours=expires_hours)).timestamp()
|
||||
}
|
||||
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create capability token: {e}")
|
||||
raise ValueError("Failed to create capability token")
|
||||
|
||||
|
||||
async def get_current_user(authorization: str = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current user from authorization header - REQUIRED for all endpoints
|
||||
Raises 401 if authentication fails - following GT 2.0 security principles
|
||||
"""
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Extract token
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Extract tenant context from new JWT structure
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
available_tenants = payload.get('available_tenants', [])
|
||||
user_type = payload.get('user_type', 'tenant_user')
|
||||
|
||||
# For admin users, allow access to any tenant backend
|
||||
if user_type == 'super_admin' and current_tenant.get('domain') == 'admin':
|
||||
# Admin users accessing tenant backends - create tenant context for the current backend
|
||||
from app.core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
# Override the admin context with the current tenant backend's context
|
||||
current_tenant = {
|
||||
'id': settings.tenant_id,
|
||||
'domain': settings.tenant_domain,
|
||||
'name': f'Tenant {settings.tenant_domain}',
|
||||
'role': 'super_admin',
|
||||
'display_name': payload.get('email', 'Admin User'),
|
||||
'email': payload.get('email'),
|
||||
'is_primary': True,
|
||||
'capabilities': [
|
||||
{'resource': '*', 'actions': ['*'], 'constraints': {}},
|
||||
]
|
||||
}
|
||||
logger.info(f"Admin user {payload.get('email')} accessing tenant backend {settings.tenant_domain}")
|
||||
|
||||
# Validate tenant context exists
|
||||
if not current_tenant or not current_tenant.get('id'):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No valid tenant context in token",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Return user dict with clean tenant context structure
|
||||
return {
|
||||
'sub': payload.get('sub'),
|
||||
'email': payload.get('email'),
|
||||
'user_id': payload.get('sub'),
|
||||
'user_type': payload.get('user_type', 'tenant_user'),
|
||||
|
||||
# Current tenant context (primary structure)
|
||||
'tenant_id': str(current_tenant.get('id')),
|
||||
'tenant_domain': current_tenant.get('domain'),
|
||||
'tenant_name': current_tenant.get('name'),
|
||||
'tenant_role': current_tenant.get('role'),
|
||||
'tenant_display_name': current_tenant.get('display_name'),
|
||||
'tenant_email': current_tenant.get('email'),
|
||||
'is_primary_tenant': current_tenant.get('is_primary', False),
|
||||
|
||||
# Tenant-specific capabilities
|
||||
'capabilities': current_tenant.get('capabilities', []),
|
||||
|
||||
# Available tenants for tenant switching
|
||||
'available_tenants': available_tenants
|
||||
}
|
||||
|
||||
|
||||
def get_current_user_email(authorization: str) -> str:
|
||||
"""
|
||||
Extract user email from authorization header
|
||||
"""
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
if payload:
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
# Prefer tenant-specific email, fallback to user email, then sub
|
||||
return (current_tenant.get('email') or
|
||||
payload.get('email') or
|
||||
payload.get('sub', 'test@example.com'))
|
||||
|
||||
return 'anonymous@example.com'
|
||||
|
||||
|
||||
def get_tenant_info(authorization: str) -> Dict[str, str]:
|
||||
"""
|
||||
Extract tenant information from authorization header
|
||||
"""
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
if payload:
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
if current_tenant:
|
||||
return {
|
||||
'tenant_id': str(current_tenant.get('id')),
|
||||
'tenant_domain': current_tenant.get('domain'),
|
||||
'tenant_name': current_tenant.get('name'),
|
||||
'tenant_role': current_tenant.get('role')
|
||||
}
|
||||
|
||||
return {
|
||||
'tenant_id': 'default',
|
||||
'tenant_domain': 'default',
|
||||
'tenant_name': 'Default Tenant',
|
||||
'tenant_role': 'tenant_user'
|
||||
}
|
||||
|
||||
|
||||
def verify_jwt_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT token - alias for verify_capability_token
|
||||
"""
|
||||
return verify_capability_token(token)
|
||||
|
||||
|
||||
async def get_user_context_unified(
|
||||
authorization: Optional[str] = Header(None),
|
||||
x_tenant_domain: Optional[str] = Header(None),
|
||||
x_user_id: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Unified authentication for both JWT (user requests) and header-based (service requests).
|
||||
|
||||
Supports two auth modes:
|
||||
1. JWT Authentication: Authorization header with Bearer token (for direct user requests)
|
||||
2. Header Authentication: X-Tenant-Domain + X-User-ID headers (for internal service requests)
|
||||
|
||||
Returns user context with tenant information for both modes.
|
||||
"""
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
# Mode 1: Header-based authentication (for internal services like MCP)
|
||||
if x_tenant_domain and x_user_id:
|
||||
logger.info(f"Using header auth: tenant={x_tenant_domain}, user={x_user_id}")
|
||||
return {
|
||||
"tenant_domain": x_tenant_domain,
|
||||
"tenant_id": x_tenant_domain,
|
||||
"id": x_user_id,
|
||||
"sub": x_user_id,
|
||||
"email": x_user_id,
|
||||
"user_id": x_user_id,
|
||||
"user_type": "internal_service",
|
||||
"tenant_role": "tenant_user"
|
||||
}
|
||||
|
||||
# Mode 2: JWT authentication (for direct user requests)
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
|
||||
if payload:
|
||||
logger.info(f"Using JWT auth: user={payload.get('sub')}")
|
||||
# Extract tenant context from JWT structure
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
return {
|
||||
'sub': payload.get('sub'),
|
||||
'email': payload.get('email'),
|
||||
'user_id': payload.get('sub'),
|
||||
'id': payload.get('sub'),
|
||||
'user_type': payload.get('user_type', 'tenant_user'),
|
||||
'tenant_id': str(current_tenant.get('id', 'default')),
|
||||
'tenant_domain': current_tenant.get('domain', 'default'),
|
||||
'tenant_name': current_tenant.get('name', 'Default Tenant'),
|
||||
'tenant_role': current_tenant.get('role', 'tenant_user'),
|
||||
'capabilities': current_tenant.get('capabilities', [])
|
||||
}
|
||||
|
||||
# No valid authentication provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication: provide either Authorization header or X-Tenant-Domain + X-User-ID headers"
|
||||
)
|
||||
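A minimal round trip through the token helpers above. This is a dev-only sketch: it assumes JWT_SECRET is set in the environment (the placeholder below is for local experimentation only), and the capability dict simply mirrors the shape used in the admin override in this file.

import os
os.environ.setdefault("JWT_SECRET", "0" * 64)  # placeholder secret, never use in production

token = create_capability_token(
    user_id="user@example.com",
    tenant_id="acme",
    capabilities=[{"resource": "dataset", "actions": ["read"], "constraints": {}}],
    expires_hours=1,
)
claims = verify_capability_token(token)
assert claims is not None
assert claims["sub"] == "user@example.com"
assert claims["current_tenant"]["domain"] == "acme"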
165
apps/tenant-backend/app/core/user_resolver.py
Normal file
@@ -0,0 +1,165 @@
"""
|
||||
User UUID Resolution Utilities for GT 2.0
|
||||
|
||||
Handles email-to-UUID resolution across all services to ensure
|
||||
consistent user identification in database operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_user_uuid(current_user: Dict[str, Any]) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Resolve user email to UUID for internal services.
|
||||
|
||||
Args:
|
||||
current_user: User data from JWT token
|
||||
|
||||
Returns:
|
||||
Tuple of (tenant_domain, user_email, user_uuid)
|
||||
|
||||
Raises:
|
||||
HTTPException: If UUID resolution fails
|
||||
"""
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
user_email = current_user["email"]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
|
||||
user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
|
||||
if not user_uuid:
|
||||
logger.error(f"Failed to resolve UUID for user {user_email} in tenant {tenant_domain}")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"User {user_email} not found in tenant system"
|
||||
)
|
||||
|
||||
logger.info(f"✅ Resolved user {user_email} to UUID: {user_uuid}")
|
||||
return tenant_domain, user_email, user_uuid
|
||||
|
||||
|
||||
async def ensure_user_uuid(email_or_uuid: str, tenant_domain: Optional[str] = None) -> str:
|
||||
"""
|
||||
Ensure we have a UUID, converting email if needed.
|
||||
|
||||
Args:
|
||||
email_or_uuid: Either an email address or UUID string
|
||||
tenant_domain: Tenant domain for lookup context
|
||||
|
||||
Returns:
|
||||
UUID string
|
||||
|
||||
Raises:
|
||||
ValueError: If email cannot be resolved to UUID or input is invalid
|
||||
"""
|
||||
import uuid
|
||||
import re
|
||||
|
||||
# Validate input is not empty or None
|
||||
if not email_or_uuid or not isinstance(email_or_uuid, str):
|
||||
raise ValueError(f"Invalid user identifier: {email_or_uuid}")
|
||||
|
||||
email_or_uuid = email_or_uuid.strip()
|
||||
|
||||
# Check if it's an email
|
||||
if "@" in email_or_uuid:
|
||||
# It's an email, resolve to UUID
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
|
||||
user_uuid = await get_tenant_user_uuid_by_email(email_or_uuid)
|
||||
|
||||
if not user_uuid:
|
||||
error_msg = f"Cannot resolve email {email_or_uuid} to UUID"
|
||||
if tenant_domain:
|
||||
error_msg += f" in tenant {tenant_domain}"
|
||||
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.debug(f"Resolved email {email_or_uuid} to UUID: {user_uuid}")
|
||||
return user_uuid
|
||||
|
||||
# Check if it's a valid UUID format
|
||||
try:
|
||||
uuid_obj = uuid.UUID(email_or_uuid)
|
||||
return str(uuid_obj) # Return normalized UUID string
|
||||
except (ValueError, TypeError):
|
||||
# Not a valid UUID, could be a numeric ID or other format
|
||||
pass
|
||||
|
||||
# Handle numeric user IDs or other legacy formats
|
||||
if email_or_uuid.isdigit():
|
||||
logger.warning(f"Received numeric user ID '{email_or_uuid}', attempting database lookup")
|
||||
# Try to resolve numeric ID to proper UUID via database
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
tenant_schema = f"tenant_{tenant_domain.replace('.', '_').replace('-', '_')}" if tenant_domain else "tenant_test"
|
||||
|
||||
# Try to find user by numeric ID (assuming it might be a legacy ID)
|
||||
user_row = await conn.fetchrow(
|
||||
f"SELECT id FROM {tenant_schema}.users WHERE id::text = $1 OR email = $1 LIMIT 1",
|
||||
email_or_uuid
|
||||
)
|
||||
|
||||
if user_row:
|
||||
return str(user_row['id'])
|
||||
|
||||
# If not found, try finding the first user (fallback for development)
|
||||
logger.warning(f"User '{email_or_uuid}' not found, using first available user as fallback")
|
||||
first_user = await conn.fetchrow(f"SELECT id FROM {tenant_schema}.users LIMIT 1")
|
||||
|
||||
if first_user:
|
||||
logger.info(f"Using fallback user UUID: {first_user['id']}")
|
||||
return str(first_user['id'])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database lookup failed for user '{email_or_uuid}': {e}")
|
||||
|
||||
# If all else fails, raise an error
|
||||
error_msg = f"Cannot resolve user identifier '{email_or_uuid}' to UUID. Expected email or valid UUID format."
|
||||
if tenant_domain:
|
||||
error_msg += f" Tenant: {tenant_domain}"
|
||||
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def get_user_sql_clause(param_num: int, user_identifier: str) -> str:
|
||||
"""
|
||||
Get the appropriate SQL clause for user identification.
|
||||
|
||||
Args:
|
||||
param_num: Parameter number in SQL query (e.g., 3 for $3)
|
||||
user_identifier: Either email or UUID
|
||||
|
||||
Returns:
|
||||
SQL clause string for user lookup
|
||||
"""
|
||||
if "@" in user_identifier:
|
||||
# Email - do lookup
|
||||
return f"(SELECT id FROM users WHERE email = ${param_num} LIMIT 1)"
|
||||
else:
|
||||
# UUID - use directly
|
||||
return f"${param_num}::uuid"
|
||||
|
||||
|
||||
def is_uuid_format(identifier: str) -> bool:
|
||||
"""
|
||||
Check if a string looks like a UUID.
|
||||
|
||||
Args:
|
||||
identifier: String to check
|
||||
|
||||
Returns:
|
||||
True if looks like UUID, False if looks like email
|
||||
"""
|
||||
return "@" not in identifier and len(identifier) == 36 and identifier.count("-") == 4
|
||||
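Usage sketch for the resolver helpers above. Identifiers are illustrative; the email path needs the auth service and a database, so only the UUID and SQL-clause paths are exercised here.

import asyncio

async def demo():
    # Valid UUID strings are normalized and returned without any database lookup
    uid = await ensure_user_uuid("123e4567-e89b-12d3-a456-426614174000")
    print(uid)

    # Email identifiers get an embedded subquery; UUIDs are cast directly
    print(get_user_sql_clause(3, "user@example.com"))
    # -> (SELECT id FROM users WHERE email = $3 LIMIT 1)
    print(get_user_sql_clause(3, uid))
    # -> $3::uuid

    print(is_uuid_format(uid))  # True

# asyncio.run(demo())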