GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents

- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
  - Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
Author: HackWeasel
Date: 2025-12-12 17:47:14 -05:00
Commit: 310491a557
750 changed files with 232701 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
"""
GT 2.0 Control Panel Services
"""

View File

@@ -0,0 +1,461 @@
"""
API Key Management Service for tenant-specific external API keys
"""
import os
import json
from typing import Dict, Any, Optional, List
from datetime import datetime
from cryptography.fernet import Fernet
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from sqlalchemy.orm.attributes import flag_modified
from app.models.tenant import Tenant
from app.models.audit import AuditLog
from app.core.config import settings
class APIKeyService:
"""Service for managing tenant-specific API keys"""
# Supported API key providers - NVIDIA, Groq, and Backblaze
SUPPORTED_PROVIDERS = {
'nvidia': {
'name': 'NVIDIA NIM',
'description': 'GPU-accelerated inference on DGX Cloud via build.nvidia.com',
'required_format': 'nvapi-*',
'test_endpoint': 'https://integrate.api.nvidia.com/v1/models'
},
'groq': {
'name': 'Groq Cloud LLM',
'description': 'High-performance LLM inference',
'required_format': 'gsk_*',
'test_endpoint': 'https://api.groq.com/openai/v1/models'
},
'backblaze': {
'name': 'Backblaze B2',
'description': 'S3-compatible backup storage',
'required_format': None, # Key ID and Application Key
'test_endpoint': None
}
}
def __init__(self, db: AsyncSession):
self.db = db
# Use environment variable or generate a key for encryption
encryption_key = os.getenv('API_KEY_ENCRYPTION_KEY')
if not encryption_key:
# In production, this should be stored securely
encryption_key = Fernet.generate_key().decode()
os.environ['API_KEY_ENCRYPTION_KEY'] = encryption_key
self.cipher = Fernet(encryption_key.encode() if isinstance(encryption_key, str) else encryption_key)
async def set_api_key(
self,
tenant_id: int,
provider: str,
api_key: str,
api_secret: Optional[str] = None,
enabled: bool = True,
metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Set or update an API key for a tenant"""
if provider not in self.SUPPORTED_PROVIDERS:
raise ValueError(f"Unsupported provider: {provider}")
# Validate key format if required
provider_info = self.SUPPORTED_PROVIDERS[provider]
if provider_info['required_format'] and not api_key.startswith(provider_info['required_format'].replace('*', '')):
raise ValueError(f"Invalid API key format for {provider}")
# Get tenant
result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
# Encrypt API key
encrypted_key = self.cipher.encrypt(api_key.encode()).decode()
encrypted_secret = None
if api_secret:
encrypted_secret = self.cipher.encrypt(api_secret.encode()).decode()
# Update tenant's API keys
api_keys = tenant.api_keys or {}
api_keys[provider] = {
'key': encrypted_key,
'secret': encrypted_secret,
'enabled': enabled,
'metadata': metadata or {},
'updated_at': datetime.utcnow().isoformat(),
'updated_by': 'admin' # Should come from auth context
}
tenant.api_keys = api_keys
flag_modified(tenant, "api_keys")
await self.db.commit()
# Log the action
audit_log = AuditLog(
tenant_id=tenant_id,
action='api_key_updated',
resource_type='api_key',
resource_id=provider,
details={'provider': provider, 'enabled': enabled}
)
self.db.add(audit_log)
await self.db.commit()
# Invalidate Resource Cluster cache so it picks up the new key
await self._invalidate_resource_cluster_cache(tenant.domain, provider)
return {
'tenant_id': tenant_id,
'provider': provider,
'enabled': enabled,
'updated_at': api_keys[provider]['updated_at']
}
async def get_api_keys(self, tenant_id: int) -> Dict[str, Any]:
"""Get all API keys for a tenant (without decryption)"""
result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
api_keys = tenant.api_keys or {}
# Return key status without actual keys
return {
provider: {
'configured': True,
'enabled': info.get('enabled', False),
'updated_at': info.get('updated_at'),
'metadata': info.get('metadata', {})
}
for provider, info in api_keys.items()
}
async def get_decrypted_key(
self,
tenant_id: int,
provider: str,
require_enabled: bool = True
) -> Dict[str, Any]:
"""Get decrypted API key for a specific provider"""
result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
api_keys = tenant.api_keys or {}
if provider not in api_keys:
raise ValueError(f"API key for {provider} not configured for tenant {tenant_id}")
key_info = api_keys[provider]
if require_enabled and not key_info.get('enabled', False):
raise ValueError(f"API key for {provider} is disabled for tenant {tenant_id}")
# Decrypt the key
decrypted_key = self.cipher.decrypt(key_info['key'].encode()).decode()
decrypted_secret = None
if key_info.get('secret'):
decrypted_secret = self.cipher.decrypt(key_info['secret'].encode()).decode()
return {
'provider': provider,
'api_key': decrypted_key,
'api_secret': decrypted_secret,
'metadata': key_info.get('metadata', {}),
'enabled': key_info.get('enabled', False)
}
async def disable_api_key(self, tenant_id: int, provider: str) -> bool:
"""Disable an API key without removing it"""
result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
api_keys = tenant.api_keys or {}
if provider not in api_keys:
raise ValueError(f"API key for {provider} not configured")
api_keys[provider]['enabled'] = False
api_keys[provider]['updated_at'] = datetime.utcnow().isoformat()
tenant.api_keys = api_keys
flag_modified(tenant, "api_keys")
await self.db.commit()
# Log the action
audit_log = AuditLog(
tenant_id=tenant_id,
action='api_key_disabled',
resource_type='api_key',
resource_id=provider,
details={'provider': provider}
)
self.db.add(audit_log)
await self.db.commit()
# Invalidate Resource Cluster cache
await self._invalidate_resource_cluster_cache(tenant.domain, provider)
return True
async def remove_api_key(self, tenant_id: int, provider: str) -> bool:
"""Completely remove an API key"""
result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
api_keys = tenant.api_keys or {}
if provider in api_keys:
del api_keys[provider]
tenant.api_keys = api_keys
flag_modified(tenant, "api_keys")
await self.db.commit()
# Log the action
audit_log = AuditLog(
tenant_id=tenant_id,
action='api_key_removed',
resource_type='api_key',
resource_id=provider,
details={'provider': provider}
)
self.db.add(audit_log)
await self.db.commit()
# Invalidate Resource Cluster cache
await self._invalidate_resource_cluster_cache(tenant.domain, provider)
return True
return False
async def test_api_key(self, tenant_id: int, provider: str) -> Dict[str, Any]:
"""Test if an API key is valid by making a test request with detailed error mapping"""
import httpx
# Get decrypted key
key_info = await self.get_decrypted_key(tenant_id, provider)
provider_info = self.SUPPORTED_PROVIDERS[provider]
if not provider_info.get('test_endpoint'):
return {
'provider': provider,
'testable': False,
'valid': False,
'message': 'No test endpoint available for this provider',
'error_type': 'not_testable'
}
# Validate key format before making request
api_key = key_info['api_key']
if provider == 'nvidia' and not api_key.startswith('nvapi-'):
return {
'provider': provider,
'valid': False,
'message': 'Invalid key format (should start with nvapi-)',
'error_type': 'invalid_format'
}
if provider == 'groq' and not api_key.startswith('gsk_'):
return {
'provider': provider,
'valid': False,
'message': 'Invalid key format (should start with gsk_)',
'error_type': 'invalid_format'
}
# Build authorization headers based on provider
headers = self._get_auth_headers(provider, api_key)
try:
async with httpx.AsyncClient() as client:
response = await client.get(
provider_info['test_endpoint'],
headers=headers,
timeout=10.0
)
# Extract rate limit headers
rate_limit_remaining = None
rate_limit_reset = None
if 'x-ratelimit-remaining' in response.headers:
try:
rate_limit_remaining = int(response.headers['x-ratelimit-remaining'])
except (ValueError, TypeError):
pass
if 'x-ratelimit-reset' in response.headers:
rate_limit_reset = response.headers['x-ratelimit-reset']
# Count available models if response is successful
models_available = None
if response.status_code == 200:
try:
data = response.json()
if 'data' in data and isinstance(data['data'], list):
models_available = len(data['data'])
except Exception:
pass
# Detailed error mapping
if response.status_code == 200:
return {
'provider': provider,
'valid': True,
'message': 'API key is valid',
'status_code': response.status_code,
'rate_limit_remaining': rate_limit_remaining,
'rate_limit_reset': rate_limit_reset,
'models_available': models_available
}
elif response.status_code == 401:
return {
'provider': provider,
'valid': False,
'message': 'Invalid or expired API key',
'status_code': response.status_code,
'error_type': 'auth_failed',
'rate_limit_remaining': rate_limit_remaining,
'rate_limit_reset': rate_limit_reset
}
elif response.status_code == 403:
return {
'provider': provider,
'valid': False,
'message': 'Insufficient permissions for this API key',
'status_code': response.status_code,
'error_type': 'insufficient_permissions',
'rate_limit_remaining': rate_limit_remaining,
'rate_limit_reset': rate_limit_reset
}
elif response.status_code == 429:
return {
'provider': provider,
'valid': True, # Key is valid, just rate limited
'message': 'Rate limit exceeded - key is valid but currently limited',
'status_code': response.status_code,
'error_type': 'rate_limited',
'rate_limit_remaining': rate_limit_remaining,
'rate_limit_reset': rate_limit_reset
}
else:
return {
'provider': provider,
'valid': False,
'message': f'Test failed with HTTP {response.status_code}',
'status_code': response.status_code,
'error_type': 'server_error' if response.status_code >= 500 else 'unknown',
'rate_limit_remaining': rate_limit_remaining,
'rate_limit_reset': rate_limit_reset
}
except httpx.ConnectError:
return {
'provider': provider,
'valid': False,
'message': f"Connection failed: Unable to reach {provider_info['test_endpoint']}",
'error_type': 'connection_error'
}
except httpx.TimeoutException:
return {
'provider': provider,
'valid': False,
'message': 'Connection timed out after 10 seconds',
'error_type': 'timeout'
}
except Exception as e:
return {
'provider': provider,
'valid': False,
'error': str(e),
'message': f"Test failed: {str(e)}",
'error_type': 'unknown'
}
def _get_auth_headers(self, provider: str, api_key: str) -> Dict[str, str]:
"""Build authorization headers based on provider"""
if provider in ('nvidia', 'groq', 'openai', 'cohere', 'huggingface'):
return {'Authorization': f"Bearer {api_key}"}
elif provider == 'anthropic':
return {'x-api-key': api_key}
else:
return {'Authorization': f"Bearer {api_key}"}
async def get_api_key_usage(self, tenant_id: int, provider: str) -> Dict[str, Any]:
"""Get usage statistics for an API key"""
# This would query usage records for the specific provider
# For now, return mock data
return {
'provider': provider,
'tenant_id': tenant_id,
'usage': {
'requests_today': 1234,
'tokens_today': 456789,
'cost_today_cents': 234,
'requests_month': 45678,
'tokens_month': 12345678,
'cost_month_cents': 8901
}
}
async def _invalidate_resource_cluster_cache(
self,
tenant_domain: str,
provider: str
) -> None:
"""
Notify Resource Cluster to invalidate its API key cache.
This is called after API keys are modified, disabled, or removed
to ensure the Resource Cluster doesn't use stale cached keys.
Non-critical: If this fails, the cache will expire naturally after TTL.
"""
try:
from app.clients.resource_cluster_client import get_resource_cluster_client
client = get_resource_cluster_client()
await client.invalidate_api_key_cache(
tenant_domain=tenant_domain,
provider=provider
)
except Exception as e:
# Log but don't fail the main operation
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Failed to invalidate Resource Cluster cache (non-critical): {e}")
@classmethod
def get_supported_providers(cls) -> List[Dict[str, Any]]:
"""Get list of supported API key providers"""
return [
{
'id': provider_id,
'name': info['name'],
'description': info['description'],
'requires_secret': provider_id == 'backblaze'
}
for provider_id, info in cls.SUPPORTED_PROVIDERS.items()
]
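# --- Illustrative usage sketch (assumption: an AsyncSession factory named
# `async_session` exists elsewhere; it is not defined in this module) ---
async def _example_rotate_nvidia_key(async_session, tenant_id: int, new_key: str) -> dict:
    """Hypothetical set -> test -> disable-on-failure flow for an NVIDIA NIM key."""
    async with async_session() as session:
        service = APIKeyService(session)
        # set_api_key enforces the 'nvapi-' prefix declared in SUPPORTED_PROVIDERS
        await service.set_api_key(tenant_id=tenant_id, provider='nvidia', api_key=new_key)
        # test_api_key maps HTTP 401/403/429 responses to structured error types
        result = await service.test_api_key(tenant_id, 'nvidia')
        if not result['valid']:
            await service.disable_api_key(tenant_id, 'nvidia')
        return result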

View File

@@ -0,0 +1,344 @@
"""
Backup Service - Manages system backups and restoration
"""
import os
import asyncio
import hashlib
from typing import Dict, Any, Optional, List
from datetime import datetime
from pathlib import Path
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc, and_
from fastapi import HTTPException, status
import structlog
from app.models.system import BackupRecord, BackupType
logger = structlog.get_logger()
class BackupService:
"""Service for creating and managing system backups"""
BACKUP_SCRIPT = "/app/scripts/backup/backup-compose.sh"
RESTORE_SCRIPT = "/app/scripts/backup/restore-compose.sh"
BACKUP_DIR = os.getenv("GT2_BACKUP_DIR", "/app/backups")
def __init__(self, db: AsyncSession):
self.db = db
async def create_backup(
self,
backup_type: str = "manual",
description: str = None,
created_by: str = None
) -> Dict[str, Any]:
"""Create a new system backup"""
try:
# Validate backup type
if backup_type not in ["manual", "pre_update", "scheduled"]:
raise ValueError(f"Invalid backup type: {backup_type}")
# Ensure backup directory exists
os.makedirs(self.BACKUP_DIR, exist_ok=True)
# Generate backup filename
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
backup_filename = f"gt2_backup_{timestamp}.tar.gz"
backup_path = os.path.join(self.BACKUP_DIR, backup_filename)
# Get current version
current_version = await self._get_current_version()
# Create backup record
backup_record = BackupRecord(
backup_type=BackupType[backup_type],
location=backup_path,
version=current_version,
description=description or f"{backup_type.replace('_', ' ').title()} backup",
created_by=created_by,
components=self._get_backup_components()
)
self.db.add(backup_record)
await self.db.commit()
await self.db.refresh(backup_record)
# Run backup script in background
asyncio.create_task(
self._run_backup_process(backup_record.uuid, backup_path)
)
logger.info(f"Backup job {backup_record.uuid} created")
return backup_record.to_dict()
except Exception as e:
logger.error(f"Failed to create backup: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create backup: {str(e)}"
)
async def list_backups(
self,
limit: int = 50,
offset: int = 0,
backup_type: str = None
) -> Dict[str, Any]:
"""List available backups"""
try:
# Build query
query = select(BackupRecord)
if backup_type:
query = query.where(BackupRecord.backup_type == BackupType[backup_type])
query = query.order_by(desc(BackupRecord.created_at)).limit(limit).offset(offset)
result = await self.db.execute(query)
backups = result.scalars().all()
# Get total count
count_query = select(BackupRecord)
if backup_type:
count_query = count_query.where(BackupRecord.backup_type == BackupType[backup_type])
count_result = await self.db.execute(count_query)
total = len(count_result.scalars().all())
# Calculate total storage used by backups
backup_list = [b.to_dict() for b in backups]
storage_used = sum(b.get("size", 0) or 0 for b in backup_list)
return {
"backups": backup_list,
"total": total,
"limit": limit,
"offset": offset,
"storage_used": storage_used
}
except Exception as e:
logger.error(f"Failed to list backups: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to list backups: {str(e)}"
)
async def get_backup(self, backup_id: str) -> Dict[str, Any]:
"""Get details of a specific backup"""
stmt = select(BackupRecord).where(BackupRecord.uuid == backup_id)
result = await self.db.execute(stmt)
backup = result.scalar_one_or_none()
if not backup:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Backup {backup_id} not found"
)
# Check if file actually exists
file_exists = os.path.exists(backup.location)
backup_dict = backup.to_dict()
backup_dict["file_exists"] = file_exists
return backup_dict
async def restore_backup(
self,
backup_id: str,
components: List[str] = None
) -> Dict[str, Any]:
"""Restore from a backup"""
# Get backup record
stmt = select(BackupRecord).where(BackupRecord.uuid == backup_id)
result = await self.db.execute(stmt)
backup = result.scalar_one_or_none()
if not backup:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Backup {backup_id} not found"
)
if not backup.is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Backup is marked as invalid and cannot be restored"
)
# Check if backup file exists
if not os.path.exists(backup.location):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Backup file not found on disk"
)
# Verify checksum if available
if backup.checksum:
calculated_checksum = await self._calculate_checksum(backup.location)
if calculated_checksum != backup.checksum:
backup.is_valid = False
await self.db.commit()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Backup checksum mismatch - file may be corrupted"
)
# Run restore in background
asyncio.create_task(self._run_restore_process(backup.location, components))
return {
"message": "Restore initiated",
"backup_id": backup_id,
"version": backup.version,
"components": components or list(backup.components.keys())
}
async def delete_backup(self, backup_id: str) -> Dict[str, Any]:
"""Delete a backup"""
stmt = select(BackupRecord).where(BackupRecord.uuid == backup_id)
result = await self.db.execute(stmt)
backup = result.scalar_one_or_none()
if not backup:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Backup {backup_id} not found"
)
# Delete file from disk
try:
if os.path.exists(backup.location):
os.remove(backup.location)
logger.info(f"Deleted backup file: {backup.location}")
except Exception as e:
logger.warning(f"Failed to delete backup file: {str(e)}")
# Delete database record
await self.db.delete(backup)
await self.db.commit()
return {
"message": "Backup deleted",
"backup_id": backup_id
}
async def _run_backup_process(self, backup_uuid: str, backup_path: str):
"""Background task to create backup"""
try:
# Reload backup record
stmt = select(BackupRecord).where(BackupRecord.uuid == backup_uuid)
result = await self.db.execute(stmt)
backup = result.scalar_one_or_none()
if not backup:
logger.error(f"Backup {backup_uuid} not found")
return
logger.info(f"Starting backup process: {backup_uuid}")
# Run backup script
process = await asyncio.create_subprocess_exec(
self.BACKUP_SCRIPT,
backup_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
# Success - calculate file size and checksum
if os.path.exists(backup_path):
backup.size_bytes = os.path.getsize(backup_path)
backup.checksum = await self._calculate_checksum(backup_path)
logger.info(f"Backup completed: {backup_uuid} ({backup.size_bytes} bytes)")
else:
backup.is_valid = False
logger.error(f"Backup file not created: {backup_path}")
else:
# Failure
backup.is_valid = False
error_msg = stderr.decode() if stderr else "Unknown error"
logger.error(f"Backup failed: {error_msg}")
await self.db.commit()
except Exception as e:
logger.error(f"Backup process error: {str(e)}")
# Mark backup as invalid
stmt = select(BackupRecord).where(BackupRecord.uuid == backup_uuid)
result = await self.db.execute(stmt)
backup = result.scalar_one_or_none()
if backup:
backup.is_valid = False
await self.db.commit()
async def _run_restore_process(self, backup_path: str, components: List[str] = None):
"""Background task to restore from backup"""
try:
logger.info(f"Starting restore process from: {backup_path}")
# Build restore command
cmd = [self.RESTORE_SCRIPT, backup_path]
if components:
cmd.extend(components)
# Run restore script
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
logger.info("Restore completed successfully")
else:
error_msg = stderr.decode() if stderr else "Unknown error"
logger.error(f"Restore failed: {error_msg}")
except Exception as e:
logger.error(f"Restore process error: {str(e)}")
async def _get_current_version(self) -> str:
"""Get current system version"""
try:
from app.models.system import SystemVersion
stmt = select(SystemVersion.version).where(
SystemVersion.is_current == True
).order_by(desc(SystemVersion.installed_at)).limit(1)
result = await self.db.execute(stmt)
version = result.scalar_one_or_none()
return version or "unknown"
except Exception:
return "unknown"
def _get_backup_components(self) -> Dict[str, bool]:
"""Get list of components to backup"""
return {
"databases": True,
"docker_volumes": True,
"configs": True,
"logs": False # Logs typically excluded to save space
}
async def _calculate_checksum(self, filepath: str) -> str:
"""Calculate SHA256 checksum of a file"""
try:
sha256_hash = hashlib.sha256()
with open(filepath, "rb") as f:
# Read file in chunks to handle large files
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
except Exception as e:
logger.error(f"Failed to calculate checksum: {str(e)}")
return ""

View File

@@ -0,0 +1,452 @@
"""
Default Model Configurations for GT 2.0
This module contains the default configurations for the Groq LLM and audio
models, plus the BGE-M3 embedding model on the GT Edge network and local Ollama endpoints.
"""
from typing import List, Dict, Any
def get_default_models() -> List[Dict[str, Any]]:
"""Get list of all default model configurations"""
# Groq LLM Models (11 models)
groq_llm_models = [
{
"model_id": "llama-3.3-70b-versatile",
"name": "Llama 3.3 70B Versatile",
"version": "3.3",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 128000,
"max_tokens": 32768,
},
"capabilities": {
"reasoning": True,
"function_calling": True,
"streaming": True,
"multilingual": True
},
"cost": {
"per_1k_input": 0.59,
"per_1k_output": 0.79
},
"description": "Latest Llama 3.3 70B model optimized for versatile tasks with large context window",
"is_active": True
},
{
"model_id": "llama-3.3-70b-specdec",
"name": "Llama 3.3 70B Speculative Decoding",
"version": "3.3",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 8192,
"max_tokens": 8192,
},
"capabilities": {
"reasoning": True,
"function_calling": True,
"streaming": True
},
"cost": {
"per_1k_input": 0.59,
"per_1k_output": 0.79
},
"description": "Llama 3.3 70B with speculative decoding for faster inference",
"is_active": True
},
{
"model_id": "llama-3.2-90b-text-preview",
"name": "Llama 3.2 90B Text Preview",
"version": "3.2",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 128000,
"max_tokens": 8000,
},
"capabilities": {
"reasoning": True,
"function_calling": True,
"streaming": True
},
"cost": {
"per_1k_input": 0.2,
"per_1k_output": 0.2
},
"description": "Large Llama 3.2 model with enhanced text processing capabilities",
"is_active": True
},
{
"model_id": "llama-3.1-405b-reasoning",
"name": "Llama 3.1 405B Reasoning",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 131072,
"max_tokens": 32768,
},
"capabilities": {
"reasoning": True,
"function_calling": True,
"streaming": True,
"multilingual": True
},
"cost": {
"per_1k_input": 2.5,
"per_1k_output": 2.5
},
"description": "Largest Llama model optimized for complex reasoning tasks",
"is_active": True
},
{
"model_id": "llama-3.1-70b-versatile",
"name": "Llama 3.1 70B Versatile",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 131072,
"max_tokens": 32768,
},
"capabilities": {
"reasoning": True,
"function_calling": True,
"streaming": True,
"multilingual": True
},
"cost": {
"per_1k_input": 0.59,
"per_1k_output": 0.79
},
"description": "Balanced Llama model for general-purpose tasks with large context",
"is_active": True
},
{
"model_id": "llama-3.1-8b-instant",
"name": "Llama 3.1 8B Instant",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 131072,
"max_tokens": 8192,
},
"capabilities": {
"streaming": True,
"multilingual": True
},
"cost": {
"per_1k_input": 0.05,
"per_1k_output": 0.08
},
"description": "Fast and efficient Llama model for quick responses",
"is_active": True
},
{
"model_id": "llama3-groq-70b-8192-tool-use-preview",
"name": "Llama 3 Groq 70B Tool Use Preview",
"version": "3.0",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 8192,
"max_tokens": 8192,
},
"capabilities": {
"function_calling": True,
"streaming": True
},
"cost": {
"per_1k_input": 0.89,
"per_1k_output": 0.89
},
"description": "Llama 3 70B optimized for tool use and function calling",
"is_active": True
},
{
"model_id": "llama3-groq-8b-8192-tool-use-preview",
"name": "Llama 3 Groq 8B Tool Use Preview",
"version": "3.0",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 8192,
"max_tokens": 8192,
},
"capabilities": {
"function_calling": True,
"streaming": True
},
"cost": {
"per_1k_input": 0.19,
"per_1k_output": 0.19
},
"description": "Compact Llama 3 model optimized for tool use and function calling",
"is_active": True
},
{
"model_id": "mixtral-8x7b-32768",
"name": "Mixtral 8x7B",
"version": "1.0",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 32768,
"max_tokens": 32768,
},
"capabilities": {
"reasoning": True,
"streaming": True,
"multilingual": True
},
"cost": {
"per_1k_input": 0.24,
"per_1k_output": 0.24
},
"description": "Mixture of experts model with strong multilingual capabilities",
"is_active": True
},
{
"model_id": "gemma2-9b-it",
"name": "Gemma 2 9B Instruction Tuned",
"version": "2.0",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 8192,
"max_tokens": 8192,
},
"capabilities": {
"streaming": True,
"multilingual": False
},
"cost": {
"per_1k_input": 0.2,
"per_1k_output": 0.2
},
"description": "Google's Gemma 2 model optimized for instruction following",
"is_active": True
},
{
"model_id": "llama-guard-3-8b",
"name": "Llama Guard 3 8B",
"version": "3.0",
"provider": "groq",
"model_type": "llm",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"specifications": {
"context_window": 8192,
"max_tokens": 8192,
},
"capabilities": {
"streaming": False,
"safety_classification": True
},
"cost": {
"per_1k_input": 0.2,
"per_1k_output": 0.2
},
"description": "Safety classification model for content moderation",
"is_active": True
}
]
# Groq Audio Models (3 models)
groq_audio_models = [
{
"model_id": "whisper-large-v3",
"name": "Whisper Large v3",
"version": "3.0",
"provider": "groq",
"model_type": "audio",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"capabilities": {
"transcription": True,
"multilingual": True
},
"cost": {
"per_1k_input": 0.111,
"per_1k_output": 0.111
},
"description": "High-quality speech transcription with multilingual support",
"is_active": True
},
{
"model_id": "whisper-large-v3-turbo",
"name": "Whisper Large v3 Turbo",
"version": "3.0",
"provider": "groq",
"model_type": "audio",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"capabilities": {
"transcription": True,
"multilingual": True
},
"cost": {
"per_1k_input": 0.04,
"per_1k_output": 0.04
},
"description": "Fast speech transcription optimized for speed",
"is_active": True
},
{
"model_id": "distil-whisper-large-v3-en",
"name": "Distil-Whisper Large v3 English",
"version": "3.0",
"provider": "groq",
"model_type": "audio",
"endpoint": "https://api.groq.com/openai/v1",
"api_key_name": "GROQ_API_KEY",
"capabilities": {
"transcription": True,
"multilingual": False
},
"cost": {
"per_1k_input": 0.02,
"per_1k_output": 0.02
},
"description": "Compact English-only transcription model",
"is_active": True
}
]
# BGE-M3 Embedding Model (External on GT Edge)
external_models = [
{
"model_id": "bge-m3",
"name": "BAAI BGE-M3 Multilingual Embeddings",
"version": "1.0",
"provider": "external",
"model_type": "embedding",
"endpoint": "http://10.0.1.50:8080", # GT Edge local network
"specifications": {
"dimensions": 1024,
"max_tokens": 8192,
},
"capabilities": {
"multilingual": True,
"dense_retrieval": True,
"sparse_retrieval": True,
"colbert": True
},
"cost": {
"per_1k_input": 0.0,
"per_1k_output": 0.0
},
"description": "State-of-the-art multilingual embedding model running on GT Edge local network",
"config": {
"batch_size": 32,
"normalize": True,
"pooling_method": "mean"
},
"is_active": True
}
]
# Local Ollama Models (for on-premise deployments)
ollama_models = [
{
"model_id": "ollama-local-dgx-x86",
"name": "Local Ollama (DGX/X86)",
"version": "1.0",
"provider": "ollama",
"model_type": "llm",
"endpoint": "http://ollama-host:11434/v1/chat/completions",
"api_key_name": None, # No API key needed for local Ollama
"specifications": {
"context_window": 131072,
"max_tokens": 4096,
},
"capabilities": {
"streaming": True,
"function_calling": False
},
"cost": {
"per_1k_input": 0.0,
"per_1k_output": 0.0
},
"description": "Local Ollama instance for DGX and x86 Linux deployments. Uses ollama-host DNS resolution.",
"is_active": True
},
{
"model_id": "ollama-local-macos",
"name": "Local Ollama (MacOS)",
"version": "1.0",
"provider": "ollama",
"model_type": "llm",
"endpoint": "http://host.docker.internal:11434/v1/chat/completions",
"api_key_name": None, # No API key needed for local Ollama
"specifications": {
"context_window": 131072,
"max_tokens": 4096,
},
"capabilities": {
"streaming": True,
"function_calling": False
},
"cost": {
"per_1k_input": 0.0,
"per_1k_output": 0.0
},
"description": "Local Ollama instance for macOS deployments. Uses host.docker.internal for Docker-to-host networking.",
"is_active": True
}
]
# TTS Models (placeholder - will be added when available)
tts_models = [
# Future TTS models from Groq/PlayAI
]
# Combine all models
all_models = groq_llm_models + groq_audio_models + external_models + ollama_models + tts_models
return all_models
def get_groq_models() -> List[Dict[str, Any]]:
"""Get only Groq models"""
return [model for model in get_default_models() if model["provider"] == "groq"]
def get_external_models() -> List[Dict[str, Any]]:
"""Get only external models (BGE-M3, etc.)"""
return [model for model in get_default_models() if model["provider"] == "external"]
def get_ollama_models() -> List[Dict[str, Any]]:
"""Get only Ollama models (local inference)"""
return [model for model in get_default_models() if model["provider"] == "ollama"]
def get_models_by_type(model_type: str) -> List[Dict[str, Any]]:
"""Get models by type (llm, embedding, audio, tts)"""
return [model for model in get_default_models() if model["model_type"] == model_type]
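# --- Illustrative usage sketch: inspecting the defaults at seed time ---
def _example_summarize_defaults() -> Dict[str, Any]:
    """Hypothetical summary: model count per provider and the cheapest active LLM."""
    counts: Dict[str, int] = {}
    for model in get_default_models():
        counts[model["provider"]] = counts.get(model["provider"], 0) + 1
    active_llms = [m for m in get_models_by_type("llm") if m["is_active"]]
    cheapest = min(active_llms, key=lambda m: m["cost"]["per_1k_input"])
    return {"models_by_provider": counts, "cheapest_llm": cheapest["model_id"]}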

View File

@@ -0,0 +1,484 @@
"""
Dremio SQL Federation Service for cross-cluster analytics
"""
import asyncio
import json
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text
from app.models.tenant import Tenant
from app.models.user import User
from app.models.ai_resource import AIResource
from app.models.usage import UsageRecord
from app.core.config import settings
class DremioService:
"""Service for Dremio SQL federation and cross-cluster queries"""
def __init__(self, db: AsyncSession):
self.db = db
self.dremio_url = settings.DREMIO_URL or "http://dremio:9047"
self.dremio_username = settings.DREMIO_USERNAME or "admin"
self.dremio_password = settings.DREMIO_PASSWORD or "admin123"
self.auth_token = None
self.token_expires = None
async def _authenticate(self) -> str:
"""Authenticate with Dremio and get token"""
# Check if we have a valid token
if self.auth_token and self.token_expires and self.token_expires > datetime.utcnow():
return self.auth_token
# Get new token
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.dremio_url}/apiv2/login",
json={
"userName": self.dremio_username,
"password": self.dremio_password
}
)
if response.status_code == 200:
data = response.json()
self.auth_token = data['token']
# Token typically expires in 24 hours
self.token_expires = datetime.utcnow() + timedelta(hours=23)
return self.auth_token
else:
raise Exception(f"Dremio authentication failed: {response.status_code}")
async def execute_query(self, sql: str) -> List[Dict[str, Any]]:
"""Execute a SQL query via Dremio"""
token = await self._authenticate()
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.dremio_url}/api/v3/sql",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
},
json={"sql": sql},
timeout=30.0
)
if response.status_code == 200:
job_id = response.json()['id']
# Wait for job completion
while True:
job_response = await client.get(
f"{self.dremio_url}/api/v3/job/{job_id}",
headers={"Authorization": f"Bearer {token}"}
)
job_data = job_response.json()
if job_data['jobState'] == 'COMPLETED':
break
elif job_data['jobState'] in ['FAILED', 'CANCELLED']:
raise Exception(f"Query failed: {job_data.get('errorMessage', 'Unknown error')}")
await asyncio.sleep(0.5)
# Get results
results_response = await client.get(
f"{self.dremio_url}/api/v3/job/{job_id}/results",
headers={"Authorization": f"Bearer {token}"}
)
if results_response.status_code == 200:
return results_response.json()['rows']
else:
raise Exception(f"Failed to get results: {results_response.status_code}")
else:
raise Exception(f"Query execution failed: {response.status_code}")
async def get_tenant_dashboard_data(self, tenant_id: int) -> Dict[str, Any]:
"""Get comprehensive dashboard data for a tenant"""
# Get tenant info
result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
# Federated queries across clusters
dashboard_data = {
'tenant': tenant.to_dict(),
'metrics': {},
'analytics': {},
'alerts': []
}
# 1. User metrics from Admin Cluster
user_metrics = await self._get_user_metrics(tenant_id)
dashboard_data['metrics']['users'] = user_metrics
# 2. Resource usage from Resource Cluster (via Dremio)
resource_usage = await self._get_resource_usage_federated(tenant_id)
dashboard_data['metrics']['resources'] = resource_usage
# 3. Application metrics from Tenant Cluster (via Dremio)
app_metrics = await self._get_application_metrics_federated(tenant.domain)
dashboard_data['metrics']['applications'] = app_metrics
# 4. Performance metrics
performance_data = await self._get_performance_metrics(tenant_id)
dashboard_data['analytics']['performance'] = performance_data
# 5. Security alerts
security_alerts = await self._get_security_alerts(tenant_id)
dashboard_data['alerts'] = security_alerts
return dashboard_data
async def _get_user_metrics(self, tenant_id: int) -> Dict[str, Any]:
"""Get user metrics from Admin Cluster database"""
# Total users
user_count_result = await self.db.execute(
select(User).where(User.tenant_id == tenant_id)
)
users = user_count_result.scalars().all()
# Active users (logged in within 7 days)
seven_days_ago = datetime.utcnow() - timedelta(days=7)
active_users = [u for u in users if u.last_login and u.last_login > seven_days_ago]
return {
'total_users': len(users),
'active_users': len(active_users),
'inactive_users': len(users) - len(active_users),
'user_growth_7d': 0, # Would calculate from historical data
'by_role': {
'admin': len([u for u in users if u.user_type == 'tenant_admin']),
'developer': len([u for u in users if u.user_type == 'developer']),
'analyst': len([u for u in users if u.user_type == 'analyst']),
'student': len([u for u in users if u.user_type == 'student'])
}
}
async def _get_resource_usage_federated(self, tenant_id: int) -> Dict[str, Any]:
"""Get resource usage via Dremio federation to Resource Cluster"""
try:
# Query Resource Cluster data via Dremio
sql = f"""
SELECT
resource_type,
COUNT(*) as request_count,
SUM(tokens_used) as total_tokens,
SUM(cost_cents) as total_cost_cents,
AVG(processing_time_ms) as avg_latency_ms
FROM resource_cluster.usage_records
WHERE tenant_id = {tenant_id}
AND started_at >= CURRENT_DATE - INTERVAL '7' DAY
GROUP BY resource_type
"""
results = await self.execute_query(sql)
# Process results
usage_by_type = {}
total_requests = 0
total_tokens = 0
total_cost = 0
for row in results:
usage_by_type[row['resource_type']] = {
'requests': row['request_count'],
'tokens': row['total_tokens'],
'cost_cents': row['total_cost_cents'],
'avg_latency_ms': row['avg_latency_ms']
}
total_requests += row['request_count']
total_tokens += row['total_tokens'] or 0
total_cost += row['total_cost_cents'] or 0
return {
'total_requests_7d': total_requests,
'total_tokens_7d': total_tokens,
'total_cost_cents_7d': total_cost,
'by_resource_type': usage_by_type
}
except Exception as e:
# Fallback to local database query if Dremio fails
print(f"Dremio query failed, using local data: {e}")
return await self._get_resource_usage_local(tenant_id)
async def _get_resource_usage_local(self, tenant_id: int) -> Dict[str, Any]:
"""Fallback: Get resource usage from local database"""
seven_days_ago = datetime.utcnow() - timedelta(days=7)
result = await self.db.execute(
select(UsageRecord).where(
UsageRecord.tenant_id == tenant_id,
UsageRecord.started_at >= seven_days_ago
)
)
usage_records = result.scalars().all()
usage_by_type = {}
total_requests = len(usage_records)
total_tokens = sum(r.tokens_used or 0 for r in usage_records)
total_cost = sum(r.cost_cents or 0 for r in usage_records)
for record in usage_records:
if record.operation_type not in usage_by_type:
usage_by_type[record.operation_type] = {
'requests': 0,
'tokens': 0,
'cost_cents': 0
}
usage_by_type[record.operation_type]['requests'] += 1
usage_by_type[record.operation_type]['tokens'] += record.tokens_used or 0
usage_by_type[record.operation_type]['cost_cents'] += record.cost_cents or 0
return {
'total_requests_7d': total_requests,
'total_tokens_7d': total_tokens,
'total_cost_cents_7d': total_cost,
'by_resource_type': usage_by_type
}
async def _get_application_metrics_federated(self, tenant_domain: str) -> Dict[str, Any]:
"""Get application metrics via Dremio federation to Tenant Cluster"""
try:
# Query Tenant Cluster data via Dremio
sql = f"""
SELECT
COUNT(DISTINCT c.id) as total_conversations,
COUNT(m.id) as total_messages,
COUNT(DISTINCT a.id) as total_assistants,
COUNT(DISTINCT d.id) as total_documents,
SUM(d.chunk_count) as total_chunks,
AVG(m.processing_time_ms) as avg_response_time_ms
FROM tenant_{tenant_domain}.conversations c
LEFT JOIN tenant_{tenant_domain}.messages m ON c.id = m.conversation_id
LEFT JOIN tenant_{tenant_domain}.agents a ON c.agent_id = a.id
LEFT JOIN tenant_{tenant_domain}.documents d ON d.created_at >= CURRENT_DATE - INTERVAL '7' DAY
WHERE c.created_at >= CURRENT_DATE - INTERVAL '7' DAY
"""
results = await self.execute_query(sql)
if results:
row = results[0]
return {
'conversations': row['total_conversations'] or 0,
'messages': row['total_messages'] or 0,
'agents': row['total_assistants'] or 0,
'documents': row['total_documents'] or 0,
'document_chunks': row['total_chunks'] or 0,
'avg_response_time_ms': row['avg_response_time_ms'] or 0
}
except Exception as e:
print(f"Dremio tenant query failed: {e}")
# Return default metrics if query fails
return {
'conversations': 0,
'messages': 0,
'agents': 0,
'documents': 0,
'document_chunks': 0,
'avg_response_time_ms': 0
}
async def _get_performance_metrics(self, tenant_id: int) -> Dict[str, Any]:
"""Get performance metrics for the tenant"""
# This would aggregate performance data from various sources
return {
'api_latency_p50_ms': 45,
'api_latency_p95_ms': 120,
'api_latency_p99_ms': 250,
'uptime_percentage': 99.95,
'error_rate_percentage': 0.12,
'concurrent_users': 23,
'requests_per_second': 45.6
}
async def _get_security_alerts(self, tenant_id: int) -> List[Dict[str, Any]]:
"""Get security alerts for the tenant"""
# This would query security monitoring systems
alerts = []
# Check for common security issues
# 1. Check for expired API keys
result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none()
if tenant and tenant.api_keys:
for provider, info in tenant.api_keys.items():
updated_at = datetime.fromisoformat(info.get('updated_at', '2020-01-01T00:00:00'))
if (datetime.utcnow() - updated_at).days > 90:
alerts.append({
'severity': 'warning',
'type': 'api_key_rotation',
'message': f'API key for {provider} has not been rotated in over 90 days',
'timestamp': datetime.utcnow().isoformat()
})
# 2. Check for high error rates (would come from monitoring)
# 3. Check for unusual access patterns (would come from logs)
return alerts
async def create_virtual_datasets(self, tenant_id: int) -> Dict[str, Any]:
"""Create Dremio virtual datasets for tenant analytics"""
token = await self._authenticate()
# Create virtual datasets that join data across clusters
datasets = [
{
'name': f'tenant_{tenant_id}_unified_usage',
'sql': f"""
SELECT
ac.user_email,
ac.user_type,
rc.resource_type,
rc.operation_type,
rc.tokens_used,
rc.cost_cents,
rc.started_at,
tc.conversation_id,
tc.assistant_name
FROM admin_cluster.users ac
JOIN resource_cluster.usage_records rc ON ac.email = rc.user_id
LEFT JOIN tenant_cluster.conversations tc ON rc.conversation_id = tc.id
WHERE ac.tenant_id = {tenant_id}
"""
},
{
'name': f'tenant_{tenant_id}_cost_analysis',
'sql': f"""
SELECT
DATE_TRUNC('day', started_at) as date,
resource_type,
SUM(tokens_used) as daily_tokens,
SUM(cost_cents) as daily_cost_cents,
COUNT(*) as daily_requests
FROM resource_cluster.usage_records
WHERE tenant_id = {tenant_id}
GROUP BY DATE_TRUNC('day', started_at), resource_type
"""
}
]
created_datasets = []
for dataset in datasets:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.dremio_url}/api/v3/catalog",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
},
json={
"entityType": "dataset",
"path": ["Analytics", dataset['name']],
"dataset": {
"type": "VIRTUAL",
"sql": dataset['sql'],
"sqlContext": ["@admin"]
}
}
)
if response.status_code in [200, 201]:
created_datasets.append(dataset['name'])
return {
'tenant_id': tenant_id,
'datasets_created': created_datasets,
'status': 'success'
}
async def get_custom_analytics(
self,
tenant_id: int,
query_type: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Run custom analytics queries for a tenant"""
if not start_date:
start_date = datetime.utcnow() - timedelta(days=30)
if not end_date:
end_date = datetime.utcnow()
queries = {
'user_activity': f"""
SELECT
u.email,
u.user_type,
COUNT(DISTINCT ur.conversation_id) as conversations,
SUM(ur.tokens_used) as total_tokens,
SUM(ur.cost_cents) as total_cost_cents
FROM admin_cluster.users u
LEFT JOIN resource_cluster.usage_records ur ON u.email = ur.user_id
WHERE u.tenant_id = {tenant_id}
AND ur.started_at BETWEEN '{start_date.isoformat()}' AND '{end_date.isoformat()}'
GROUP BY u.email, u.user_type
ORDER BY total_cost_cents DESC
""",
'resource_trends': f"""
SELECT
DATE_TRUNC('day', started_at) as date,
resource_type,
COUNT(*) as requests,
SUM(tokens_used) as tokens,
SUM(cost_cents) as cost_cents
FROM resource_cluster.usage_records
WHERE tenant_id = {tenant_id}
AND started_at BETWEEN '{start_date.isoformat()}' AND '{end_date.isoformat()}'
GROUP BY DATE_TRUNC('day', started_at), resource_type
ORDER BY date DESC
""",
'cost_optimization': f"""
SELECT
resource_type,
operation_type,
AVG(tokens_used) as avg_tokens,
AVG(cost_cents) as avg_cost_cents,
COUNT(*) as request_count,
SUM(cost_cents) as total_cost_cents
FROM resource_cluster.usage_records
WHERE tenant_id = {tenant_id}
AND started_at BETWEEN '{start_date.isoformat()}' AND '{end_date.isoformat()}'
GROUP BY resource_type, operation_type
HAVING COUNT(*) > 10
ORDER BY total_cost_cents DESC
LIMIT 20
"""
}
if query_type not in queries:
raise ValueError(f"Unknown query type: {query_type}")
try:
results = await self.execute_query(queries[query_type])
return results
except Exception as e:
print(f"Analytics query failed: {e}")
return []
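# --- Illustrative usage sketch (assumption: an AsyncSession factory named
# `async_session`; Dremio connection details come from settings as above) ---
async def _example_tenant_dashboard(async_session, tenant_id: int) -> Dict[str, Any]:
    """Hypothetical read path: federated dashboard plus a 30-day user-activity report."""
    async with async_session() as session:
        service = DremioService(session)
        dashboard = await service.get_tenant_dashboard_data(tenant_id)
        activity = await service.get_custom_analytics(tenant_id, "user_activity")
        return {"dashboard": dashboard, "user_activity": activity}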

View File

@@ -0,0 +1,307 @@
"""
Groq LLM integration service with high availability and failover support
"""
import asyncio
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime, timedelta
import httpx
import json
import logging
from contextlib import asynccontextmanager
from app.models.ai_resource import AIResource
from app.models.usage import UsageRecord
logger = logging.getLogger(__name__)
class GroqAPIError(Exception):
"""Custom exception for Groq API errors"""
def __init__(self, message: str, status_code: Optional[int] = None, response_body: Optional[str] = None):
self.message = message
self.status_code = status_code
self.response_body = response_body
super().__init__(self.message)
class GroqClient:
"""High-availability Groq API client with automatic failover"""
def __init__(self, resource: AIResource, api_key: str):
self.resource = resource
self.api_key = api_key
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
)
self._current_endpoint_index = 0
self._endpoint_failures = {}
self._rate_limit_reset = None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.client.aclose()
def _get_next_endpoint(self) -> Optional[str]:
"""Get next available endpoint with circuit breaker logic"""
endpoints = self.resource.get_available_endpoints()
if not endpoints:
return None
# Try current endpoint first if not in failure state
current_endpoint = endpoints[self._current_endpoint_index % len(endpoints)]
failure_info = self._endpoint_failures.get(current_endpoint)
if not failure_info or failure_info["reset_time"] < datetime.utcnow():
return current_endpoint
# Find next healthy endpoint
for i in range(len(endpoints)):
endpoint = endpoints[(self._current_endpoint_index + i + 1) % len(endpoints)]
failure_info = self._endpoint_failures.get(endpoint)
if not failure_info or failure_info["reset_time"] < datetime.utcnow():
self._current_endpoint_index = (self._current_endpoint_index + i + 1) % len(endpoints)
return endpoint
return None
def _mark_endpoint_failed(self, endpoint: str, backoff_minutes: int = 5):
"""Mark endpoint as failed with exponential backoff"""
current_failures = self._endpoint_failures.get(endpoint, {"count": 0})
current_failures["count"] += 1
# Exponential backoff: 5min, 10min, 20min, 40min, max 60min
backoff_time = min(backoff_minutes * (2 ** (current_failures["count"] - 1)), 60)
current_failures["reset_time"] = datetime.utcnow() + timedelta(minutes=backoff_time)
self._endpoint_failures[endpoint] = current_failures
logger.warning(f"Marked endpoint {endpoint} as failed for {backoff_time} minutes (failure #{current_failures['count']})")
def _reset_endpoint_failures(self, endpoint: str):
"""Reset failure count for successful endpoint"""
if endpoint in self._endpoint_failures:
del self._endpoint_failures[endpoint]
async def _make_request(self, method: str, path: str, **kwargs) -> Dict[str, Any]:
"""Make HTTP request with automatic failover"""
last_error = None
for attempt in range(len(self.resource.get_available_endpoints()) + 1):
endpoint = self._get_next_endpoint()
if not endpoint:
raise GroqAPIError("No healthy endpoints available")
url = f"{endpoint.rstrip('/')}/{path.lstrip('/')}"
try:
logger.debug(f"Making {method} request to {url}")
response = await self.client.request(method, url, **kwargs)
# Handle rate limiting
if response.status_code == 429:
retry_after = int(response.headers.get("retry-after", "60"))
self._rate_limit_reset = datetime.utcnow() + timedelta(seconds=retry_after)
raise GroqAPIError(f"Rate limited, retry after {retry_after} seconds", 429)
# Handle server errors with failover
if response.status_code >= 500:
self._mark_endpoint_failed(endpoint)
last_error = GroqAPIError(f"Server error: {response.status_code}", response.status_code, response.text)
continue
# Handle client errors (don't retry)
if response.status_code >= 400:
raise GroqAPIError(f"Client error: {response.status_code}", response.status_code, response.text)
# Success - reset failures for this endpoint
self._reset_endpoint_failures(endpoint)
return response.json()
except httpx.RequestError as e:
logger.warning(f"Request failed for endpoint {endpoint}: {e}")
self._mark_endpoint_failed(endpoint)
last_error = GroqAPIError(f"Request failed: {str(e)}")
continue
# All endpoints failed
raise last_error or GroqAPIError("All endpoints failed")
async def health_check(self) -> bool:
"""Check if the Groq API is healthy"""
try:
await self._make_request("GET", "models")
return True
except Exception as e:
logger.error(f"Health check failed: {e}")
return False
async def list_models(self) -> List[Dict[str, Any]]:
"""List available models"""
response = await self._make_request("GET", "models")
return response.get("data", [])
async def chat_completion(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
stream: bool = False,
**kwargs
) -> Dict[str, Any]:
"""Create chat completion"""
config = self.resource.merge_config(kwargs)
payload = {
"model": model or self.resource.model_name,
"messages": messages,
"stream": stream,
**config
}
# Remove None values
payload = {k: v for k, v in payload.items() if v is not None}
start_time = time.time()
response = await self._make_request("POST", "chat/completions", json=payload)
latency_ms = int((time.time() - start_time) * 1000)
# Log performance metrics
if latency_ms > self.resource.latency_sla_ms:
logger.warning(f"Request exceeded SLA: {latency_ms}ms > {self.resource.latency_sla_ms}ms")
return {
**response,
"_metadata": {
"latency_ms": latency_ms,
"model_used": payload["model"],
"endpoint_used": self._get_next_endpoint()
}
}
async def chat_completion_stream(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
**kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
"""Create streaming chat completion"""
config = self.resource.merge_config(kwargs)
payload = {
"model": model or self.resource.model_name,
"messages": messages,
"stream": True,
**config
}
# Remove None values
payload = {k: v for k, v in payload.items() if v is not None}
endpoint = self._get_next_endpoint()
if not endpoint:
raise GroqAPIError("No healthy endpoints available")
url = f"{endpoint.rstrip('/')}/chat/completions"
async with self.client.stream("POST", url, json=payload) as response:
if response.status_code >= 400:
error_text = await response.aread()
raise GroqAPIError(f"Stream error: {response.status_code}", response.status_code, error_text.decode())
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:] # Remove "data: " prefix
if data.strip() == "[DONE]":
break
try:
yield json.loads(data)
except json.JSONDecodeError:
continue
class GroqService:
"""Service for managing Groq resources and API interactions"""
def __init__(self):
self._clients: Dict[int, GroqClient] = {}
@asynccontextmanager
async def get_client(self, resource: AIResource, api_key: str):
"""Get or create a Groq client for the resource"""
if resource.id not in self._clients:
self._clients[resource.id] = GroqClient(resource, api_key)
try:
yield self._clients[resource.id]
finally:
# Keep clients alive for reuse, cleanup handled separately
pass
async def health_check_resource(self, resource: AIResource, api_key: str) -> bool:
"""Perform health check on a Groq resource"""
try:
async with self.get_client(resource, api_key) as client:
is_healthy = await client.health_check()
resource.update_health_status("healthy" if is_healthy else "unhealthy")
return is_healthy
except Exception as e:
logger.error(f"Health check failed for resource {resource.id}: {e}")
resource.update_health_status("unhealthy")
return False
async def chat_completion(
self,
resource: AIResource,
api_key: str,
messages: List[Dict[str, str]],
user_email: str,
tenant_id: int,
**kwargs
) -> Dict[str, Any]:
"""Create chat completion with usage tracking"""
async with self.get_client(resource, api_key) as client:
response = await client.chat_completion(messages, **kwargs)
# Extract usage information
usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", 0)
# Calculate cost
cost_cents = resource.calculate_cost(total_tokens)
# Create usage record (would be saved to database)
usage_record = {
"tenant_id": tenant_id,
"resource_id": resource.id,
"user_email": user_email,
"request_type": "chat_completion",
"tokens_used": total_tokens,
"cost_cents": cost_cents,
"model_used": response.get("_metadata", {}).get("model_used", resource.model_name),
"latency_ms": response.get("_metadata", {}).get("latency_ms", 0)
}
logger.info(f"Chat completion: {total_tokens} tokens, ${cost_cents/100:.4f} cost")
return {
**response,
"_usage_record": usage_record
}
async def cleanup_clients(self):
"""Cleanup inactive clients"""
for resource_id, client in list(self._clients.items()):
try:
await client.client.aclose()
except Exception:
pass
self._clients.clear()
# Global service instance
groq_service = GroqService()
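# --- Illustrative usage sketch (assumptions: `resource` is a configured Groq
# AIResource and `api_key` is a decrypted key obtained elsewhere, e.g. via the
# tenant APIKeyService) ---
async def _example_groq_completion(resource: AIResource, api_key: str) -> str:
    """Hypothetical call returning just the assistant text."""
    response = await groq_service.chat_completion(
        resource=resource,
        api_key=api_key,
        messages=[{"role": "user", "content": "Summarize GT 2.0 in one sentence."}],
        user_email="user@example.com",
        tenant_id=1,
    )
    # Groq's OpenAI-compatible response shape: choices[0].message.content
    return response["choices"][0]["message"]["content"]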

View File

@@ -0,0 +1,435 @@
"""
RabbitMQ Message Bus Service for cross-cluster communication
Implements secure message passing between Admin, Tenant, and Resource clusters
with cryptographic signing and an air-gapped communication protocol.
"""
import asyncio
import json
import logging
import hashlib
import hmac
import uuid
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List, Callable
from dataclasses import dataclass, asdict
import aio_pika
from aio_pika import Message, ExchangeType, DeliveryMode
from aio_pika.abc import AbstractRobustConnection, AbstractRobustChannel
from app.core.config import settings
logger = logging.getLogger(__name__)
@dataclass
class AdminCommand:
"""Base class for admin commands sent via message bus"""
command_id: str
command_type: str
target_cluster: str # 'tenant' or 'resource'
target_namespace: Optional[str] # For tenant-specific commands
payload: Dict[str, Any]
timestamp: str
signature: str = ""
def to_dict(self) -> Dict[str, Any]:
"""Convert command to dictionary for JSON serialization"""
return asdict(self)
def sign(self, secret_key: str) -> None:
"""Sign the command with HMAC-SHA256"""
# Create message to sign (exclude signature field)
message = json.dumps({
'command_id': self.command_id,
'command_type': self.command_type,
'target_cluster': self.target_cluster,
'target_namespace': self.target_namespace,
'payload': self.payload,
'timestamp': self.timestamp
}, sort_keys=True)
# Generate signature
self.signature = hmac.new(
secret_key.encode(),
message.encode(),
hashlib.sha256
).hexdigest()
@classmethod
def verify_signature(cls, data: Dict[str, Any], secret_key: str) -> bool:
"""Verify command signature"""
signature = data.get('signature', '')
# Create message to verify (exclude signature field)
message = json.dumps({
'command_id': data.get('command_id'),
'command_type': data.get('command_type'),
'target_cluster': data.get('target_cluster'),
'target_namespace': data.get('target_namespace'),
'payload': data.get('payload'),
'timestamp': data.get('timestamp')
}, sort_keys=True)
# Verify signature
expected_signature = hmac.new(
secret_key.encode(),
message.encode(),
hashlib.sha256
).hexdigest()
return hmac.compare_digest(signature, expected_signature)
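# Illustrative sketch of the signing round trip (placeholder command type and secret):
#
#     cmd = AdminCommand(
#         command_id=str(uuid.uuid4()),
#         command_type="health_check",
#         target_cluster="resource",
#         target_namespace=None,
#         payload={},
#         timestamp=datetime.utcnow().isoformat(),
#     )
#     cmd.sign("shared-secret")
#     assert AdminCommand.verify_signature(cmd.to_dict(), "shared-secret")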
class MessageBusService:
"""RabbitMQ message bus service for cross-cluster communication"""
def __init__(self):
self.connection: Optional[AbstractRobustConnection] = None
self.channel: Optional[AbstractRobustChannel] = None
self.command_callbacks: Dict[str, List[Callable]] = {}
self.response_futures: Dict[str, asyncio.Future] = {}
self.secret_key = settings.MESSAGE_BUS_SECRET_KEY or "PRODUCTION_MESSAGE_BUS_SECRET_REQUIRED"
async def connect(self) -> None:
"""Establish connection to RabbitMQ"""
try:
# Get connection URL from settings
rabbitmq_url = settings.RABBITMQ_URL or "amqp://admin:dev_rabbitmq_password@localhost:5672/gt2"
# Create robust connection (auto-reconnect on failure)
self.connection = await aio_pika.connect_robust(
rabbitmq_url,
client_properties={
'connection_name': 'gt2-control-panel'
}
)
# Create channel
self.channel = await self.connection.channel()
await self.channel.set_qos(prefetch_count=10)
# Declare exchanges
await self._declare_exchanges()
# Set up queues for receiving responses
await self._setup_response_queue()
logger.info("Connected to RabbitMQ message bus")
except Exception as e:
logger.error(f"Failed to connect to RabbitMQ: {e}")
raise
async def disconnect(self) -> None:
"""Close RabbitMQ connection"""
if self.channel:
await self.channel.close()
if self.connection:
await self.connection.close()
logger.info("Disconnected from RabbitMQ message bus")
async def _declare_exchanges(self) -> None:
"""Declare message exchanges for cross-cluster communication"""
# Admin commands exchange (fanout to all clusters)
await self.channel.declare_exchange(
name='gt2.admin.commands',
type=ExchangeType.TOPIC,
durable=True
)
# Tenant cluster exchange
await self.channel.declare_exchange(
name='gt2.tenant.commands',
type=ExchangeType.DIRECT,
durable=True
)
# Resource cluster exchange
await self.channel.declare_exchange(
name='gt2.resource.commands',
type=ExchangeType.DIRECT,
durable=True
)
# Response exchange (for command responses)
await self.channel.declare_exchange(
name='gt2.responses',
type=ExchangeType.DIRECT,
durable=True
)
# System alerts exchange
await self.channel.declare_exchange(
name='gt2.alerts',
type=ExchangeType.FANOUT,
durable=True
)
async def _setup_response_queue(self) -> None:
"""Set up queue for receiving command responses"""
# Declare response queue for this control panel instance
queue_name = f"gt2.admin.responses.{uuid.uuid4().hex[:8]}"
queue = await self.channel.declare_queue(
name=queue_name,
exclusive=True, # Exclusive to this connection
auto_delete=True # Delete when connection closes
)
# Bind to response exchange
await queue.bind(
exchange='gt2.responses',
routing_key=queue_name
)
# Start consuming responses
await queue.consume(self._handle_response)
self.response_queue_name = queue_name
async def send_tenant_command(
self,
command_type: str,
tenant_namespace: str,
payload: Dict[str, Any],
wait_for_response: bool = False,
timeout: int = 30
) -> Optional[Dict[str, Any]]:
"""
Send command to tenant cluster
Args:
command_type: Type of command (e.g., 'provision', 'deploy', 'suspend')
tenant_namespace: Target tenant namespace
payload: Command payload
wait_for_response: Whether to wait for response
timeout: Response timeout in seconds
Returns:
Response data if wait_for_response is True, else None
"""
command = AdminCommand(
command_id=str(uuid.uuid4()),
command_type=command_type,
target_cluster='tenant',
target_namespace=tenant_namespace,
payload=payload,
timestamp=datetime.utcnow().isoformat()
)
# Sign the command
command.sign(self.secret_key)
# Create response future if needed
if wait_for_response:
future = asyncio.Future()
self.response_futures[command.command_id] = future
# Send command
await self._publish_command(command)
# Wait for response if requested
if wait_for_response:
try:
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
logger.error(f"Command {command.command_id} timed out after {timeout}s")
del self.response_futures[command.command_id]
return None
finally:
# Clean up future
if command.command_id in self.response_futures:
del self.response_futures[command.command_id]
return None
async def send_resource_command(
self,
command_type: str,
payload: Dict[str, Any],
wait_for_response: bool = False,
timeout: int = 30
) -> Optional[Dict[str, Any]]:
"""
Send command to resource cluster
Args:
command_type: Type of command (e.g., 'health_check', 'update_config')
payload: Command payload
wait_for_response: Whether to wait for response
timeout: Response timeout in seconds
Returns:
Response data if wait_for_response is True, else None
"""
command = AdminCommand(
command_id=str(uuid.uuid4()),
command_type=command_type,
target_cluster='resource',
target_namespace=None,
payload=payload,
timestamp=datetime.utcnow().isoformat()
)
# Sign the command
command.sign(self.secret_key)
# Create response future if needed
if wait_for_response:
future = asyncio.Future()
self.response_futures[command.command_id] = future
# Send command
await self._publish_command(command)
# Wait for response if requested
if wait_for_response:
try:
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
logger.error(f"Command {command.command_id} timed out after {timeout}s")
del self.response_futures[command.command_id]
return None
finally:
# Clean up future
if command.command_id in self.response_futures:
del self.response_futures[command.command_id]
return None
async def _publish_command(self, command: AdminCommand) -> None:
"""Publish command to appropriate exchange"""
# Determine exchange and routing key
if command.target_cluster == 'tenant':
exchange_name = 'gt2.tenant.commands'
routing_key = command.target_namespace or 'all'
elif command.target_cluster == 'resource':
exchange_name = 'gt2.resource.commands'
routing_key = 'all'
else:
exchange_name = 'gt2.admin.commands'
routing_key = f"{command.target_cluster}.{command.command_type}"
# Create message
message = Message(
body=json.dumps(command.to_dict()).encode(),
delivery_mode=DeliveryMode.PERSISTENT,
headers={
'command_id': command.command_id,
'command_type': command.command_type,
'timestamp': command.timestamp,
'reply_to': self.response_queue_name if hasattr(self, 'response_queue_name') else None
}
)
# Get exchange
exchange = await self.channel.get_exchange(exchange_name)
# Publish message
await exchange.publish(
message=message,
routing_key=routing_key
)
logger.info(f"Published command {command.command_id} to {exchange_name}/{routing_key}")
async def _handle_response(self, message: aio_pika.IncomingMessage) -> None:
"""Handle response messages"""
async with message.process():
try:
# Parse response
data = json.loads(message.body.decode())
# Verify signature
if not AdminCommand.verify_signature(data, self.secret_key):
logger.error(f"Invalid signature for response: {data.get('command_id')}")
return
command_id = data.get('command_id')
# Check if we're waiting for this response
if command_id in self.response_futures:
future = self.response_futures[command_id]
if not future.done():
future.set_result(data.get('payload'))
# Log response
logger.info(f"Received response for command {command_id}")
except Exception as e:
logger.error(f"Error handling response: {e}")
async def publish_alert(
self,
alert_type: str,
severity: str,
message: str,
details: Optional[Dict[str, Any]] = None
) -> None:
"""
Publish system alert to all clusters
Args:
alert_type: Type of alert (e.g., 'security', 'health', 'deployment')
severity: Alert severity ('info', 'warning', 'error', 'critical')
message: Alert message
details: Additional alert details
"""
alert_data = {
'alert_id': str(uuid.uuid4()),
'alert_type': alert_type,
'severity': severity,
'message': message,
'details': details or {},
'timestamp': datetime.utcnow().isoformat(),
'source': 'admin_cluster'
}
# Sign the alert
alert_json = json.dumps(alert_data, sort_keys=True)
signature = hmac.new(
self.secret_key.encode(),
alert_json.encode(),
hashlib.sha256
).hexdigest()
alert_data['signature'] = signature
# Create message
message = Message(
body=json.dumps(alert_data).encode(),
delivery_mode=DeliveryMode.PERSISTENT,
headers={
'alert_type': alert_type,
'severity': severity,
'timestamp': alert_data['timestamp']
}
)
# Get alerts exchange
exchange = await self.channel.get_exchange('gt2.alerts')
# Publish alert
await exchange.publish(
message=message,
routing_key='' # Fanout exchange, routing key ignored
)
logger.info(f"Published {severity} alert: {message}")
# Global message bus instance
message_bus = MessageBusService()
async def initialize_message_bus():
"""Initialize the message bus connection"""
await message_bus.connect()
async def shutdown_message_bus():
"""Shutdown the message bus connection"""
await message_bus.disconnect()
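# Illustrative usage sketch: provisioning a tenant namespace and waiting for the reply.
# The command type string and payload shape are examples, not a fixed contract.
async def example_provision_tenant(namespace: str) -> Optional[Dict[str, Any]]:
    await initialize_message_bus()
    try:
        response = await message_bus.send_tenant_command(
            command_type="provision",
            tenant_namespace=namespace,
            payload={"plan": "standard"},
            wait_for_response=True,
            timeout=30,
        )
        if response is None:
            logger.warning(f"Provision command for {namespace} timed out")
        return response
    finally:
        await shutdown_message_bus()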

View File

@@ -0,0 +1,360 @@
"""
Message DMZ Service for secure air-gap communication
Implements security controls for cross-cluster messaging including:
- Message validation and sanitization
- Command signature verification
- Audit logging
- Rate limiting
- Security policy enforcement
"""
import json
import logging
import hashlib
import hmac
import re
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List, Set
from collections import defaultdict
import asyncio
from app.core.config import settings
from app.schemas.messages import CommandType, AlertSeverity
logger = logging.getLogger(__name__)
class SecurityViolation(Exception):
"""Raised when a security policy is violated"""
pass
class MessageDMZ:
"""
Security DMZ for message bus communication
Provides defense-in-depth security controls for cross-cluster messaging
"""
def __init__(self):
# Rate limiting
self.rate_limits: Dict[str, List[datetime]] = defaultdict(list)
self.rate_limit_window = timedelta(minutes=1)
self.max_messages_per_minute = 100
        # Command whitelist: compare against the serialized string values used in messages
        self.allowed_commands = {command.value for command in CommandType}
# Blocked patterns (for detecting potential injection attacks)
self.blocked_patterns = [
r'<script[^>]*>.*?</script>', # XSS
r'javascript:', # JavaScript URI
r'on\w+\s*=', # Event handlers
r'DROP\s+TABLE', # SQL injection
r'DELETE\s+FROM', # SQL injection
r'INSERT\s+INTO', # SQL injection
r'UPDATE\s+SET', # SQL injection
r'--', # SQL comment
r'/\*.*\*/', # SQL block comment
r'\.\./+', # Path traversal
r'\\x[0-9a-fA-F]{2}', # Hex encoding
r'%[0-9a-fA-F]{2}', # URL encoding suspicious patterns
]
# Audit log
self.audit_log: List[Dict[str, Any]] = []
self.max_audit_entries = 10000
# Security metrics
self.metrics = {
'messages_validated': 0,
'messages_rejected': 0,
'signature_failures': 0,
'rate_limit_violations': 0,
'injection_attempts': 0,
}
async def validate_incoming_message(
self,
message: Dict[str, Any],
source: str
) -> Dict[str, Any]:
"""
Validate incoming message from another cluster
Args:
message: Raw message data
source: Source cluster identifier
Returns:
Validated and sanitized message
Raises:
SecurityViolation: If message fails validation
"""
try:
# Check rate limits
if not self._check_rate_limit(source):
self.metrics['rate_limit_violations'] += 1
raise SecurityViolation(f"Rate limit exceeded for source: {source}")
# Verify required fields
required_fields = ['command_id', 'command_type', 'timestamp', 'signature']
for field in required_fields:
if field not in message:
raise SecurityViolation(f"Missing required field: {field}")
# Verify timestamp (prevent replay attacks)
if not self._verify_timestamp(message['timestamp']):
raise SecurityViolation("Message timestamp is too old or invalid")
# Verify command type is allowed
if message['command_type'] not in self.allowed_commands:
raise SecurityViolation(f"Unknown command type: {message['command_type']}")
# Verify signature
if not self._verify_signature(message):
self.metrics['signature_failures'] += 1
raise SecurityViolation("Invalid message signature")
# Sanitize payload
if 'payload' in message:
message['payload'] = self._sanitize_payload(message['payload'])
# Log successful validation
self._audit_log('message_validated', source, message['command_id'])
self.metrics['messages_validated'] += 1
return message
except SecurityViolation:
self.metrics['messages_rejected'] += 1
self._audit_log('message_rejected', source, message.get('command_id', 'unknown'))
raise
except Exception as e:
logger.error(f"Unexpected error validating message: {e}")
self.metrics['messages_rejected'] += 1
raise SecurityViolation(f"Message validation failed: {str(e)}")
async def prepare_outgoing_message(
self,
command_type: str,
payload: Dict[str, Any],
target: str
) -> Dict[str, Any]:
"""
Prepare message for sending to another cluster
Args:
command_type: Type of command
payload: Command payload
target: Target cluster identifier
Returns:
Prepared and signed message
"""
# Sanitize payload
sanitized_payload = self._sanitize_payload(payload)
# Create message structure
message = {
'command_type': command_type,
'payload': sanitized_payload,
'target_cluster': target,
'timestamp': datetime.utcnow().isoformat(),
'source': 'admin_cluster'
}
# Sign message
signature = self._create_signature(message)
message['signature'] = signature
# Audit log
self._audit_log('message_prepared', target, command_type)
return message
def _check_rate_limit(self, source: str) -> bool:
"""Check if source has exceeded rate limits"""
now = datetime.utcnow()
# Clean old entries
cutoff = now - self.rate_limit_window
self.rate_limits[source] = [
ts for ts in self.rate_limits[source]
if ts > cutoff
]
# Check limit
if len(self.rate_limits[source]) >= self.max_messages_per_minute:
return False
# Add current timestamp
self.rate_limits[source].append(now)
return True
def _verify_timestamp(self, timestamp_str: str, max_age_seconds: int = 300) -> bool:
"""Verify message timestamp is recent (prevent replay attacks)"""
try:
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
age = (datetime.utcnow() - timestamp.replace(tzinfo=None)).total_seconds()
# Message too old
if age > max_age_seconds:
return False
# Message from future (clock skew tolerance of 30 seconds)
if age < -30:
return False
return True
except (ValueError, AttributeError):
return False
def _verify_signature(self, message: Dict[str, Any]) -> bool:
"""Verify message signature"""
signature = message.get('signature', '')
# Create message to verify (exclude signature field)
message_copy = {k: v for k, v in message.items() if k != 'signature'}
message_json = json.dumps(message_copy, sort_keys=True)
# Verify signature
expected_signature = hmac.new(
settings.MESSAGE_BUS_SECRET_KEY.encode(),
message_json.encode(),
hashlib.sha256
).hexdigest()
return hmac.compare_digest(signature, expected_signature)
def _create_signature(self, message: Dict[str, Any]) -> str:
"""Create message signature"""
message_json = json.dumps(message, sort_keys=True)
return hmac.new(
settings.MESSAGE_BUS_SECRET_KEY.encode(),
message_json.encode(),
hashlib.sha256
).hexdigest()
def _sanitize_payload(self, payload: Any) -> Any:
"""
Sanitize payload to prevent injection attacks
Recursively sanitizes strings in dictionaries and lists
"""
if isinstance(payload, str):
# Check for blocked patterns
for pattern in self.blocked_patterns:
if re.search(pattern, payload, re.IGNORECASE):
self.metrics['injection_attempts'] += 1
raise SecurityViolation(f"Potential injection attempt detected")
# Basic sanitization
# Remove control characters except standard whitespace
sanitized = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]', '', payload)
# Limit string length
max_length = 10000
if len(sanitized) > max_length:
sanitized = sanitized[:max_length]
return sanitized
elif isinstance(payload, dict):
return {
self._sanitize_payload(k): self._sanitize_payload(v)
for k, v in payload.items()
}
elif isinstance(payload, list):
return [self._sanitize_payload(item) for item in payload]
else:
# Numbers, booleans, None are safe
return payload
def _audit_log(
self,
event_type: str,
target: str,
details: Any
) -> None:
"""Add entry to audit log"""
entry = {
'timestamp': datetime.utcnow().isoformat(),
'event_type': event_type,
'target': target,
'details': details
}
self.audit_log.append(entry)
# Rotate log if too large
if len(self.audit_log) > self.max_audit_entries:
self.audit_log = self.audit_log[-self.max_audit_entries:]
# Log to application logger
logger.info(f"DMZ Audit: {event_type} - Target: {target} - Details: {details}")
def get_security_metrics(self) -> Dict[str, Any]:
"""Get security metrics"""
return {
**self.metrics,
'audit_log_size': len(self.audit_log),
'rate_limited_sources': len(self.rate_limits),
'timestamp': datetime.utcnow().isoformat()
}
def get_audit_log(
self,
limit: int = 100,
event_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Get audit log entries"""
logs = self.audit_log[-limit:]
if event_type:
logs = [log for log in logs if log['event_type'] == event_type]
return logs
async def validate_command_permissions(
self,
command_type: str,
user_id: int,
user_type: str,
tenant_id: Optional[int] = None
) -> bool:
"""
Validate user has permission to execute command
Args:
command_type: Type of command
user_id: User ID
user_type: User type (super_admin, tenant_admin, tenant_user)
tenant_id: Tenant ID (for tenant-scoped commands)
Returns:
True if user has permission, False otherwise
"""
# Super admins can execute all commands
if user_type == 'super_admin':
return True
# Tenant admins can execute tenant-scoped commands for their tenant
if user_type == 'tenant_admin' and tenant_id:
tenant_commands = [
CommandType.USER_CREATE,
CommandType.USER_UPDATE,
CommandType.USER_SUSPEND,
CommandType.RESOURCE_ASSIGN,
CommandType.RESOURCE_UNASSIGN
]
return command_type in tenant_commands
# Regular users cannot execute admin commands
return False
# Global DMZ instance
message_dmz = MessageDMZ()
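# Illustrative usage sketch: a queue consumer handing a raw message to the DMZ and
# dropping it on any security violation. The source label is an example value.
async def example_validate_incoming(raw_message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    try:
        return await message_dmz.validate_incoming_message(raw_message, source="tenant_cluster")
    except SecurityViolation as exc:
        logger.warning(f"Rejected message from tenant cluster: {exc}")
        return None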

File diff suppressed because it is too large

View File

@@ -0,0 +1,525 @@
"""
GT 2.0 Resource Allocation Management Service
Manages CPU, memory, storage, and API quotas for tenants following GT 2.0 principles:
- Granular resource control per tenant
- Real-time usage monitoring
- Automatic scaling within limits
- Cost tracking and optimization
"""
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple
from enum import Enum
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, func, and_
from app.models.tenant import Tenant
from app.models.resource_usage import ResourceUsage, ResourceQuota, ResourceAlert
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class ResourceType(Enum):
"""Types of resources that can be allocated"""
CPU = "cpu"
MEMORY = "memory"
STORAGE = "storage"
API_CALLS = "api_calls"
GPU_TIME = "gpu_time"
VECTOR_OPERATIONS = "vector_operations"
MODEL_INFERENCE = "model_inference"
class AlertLevel(Enum):
"""Resource usage alert levels"""
INFO = "info"
WARNING = "warning"
CRITICAL = "critical"
@dataclass
class ResourceLimit:
"""Resource limit configuration"""
resource_type: ResourceType
max_value: float
warning_threshold: float = 0.8 # 80% of max
critical_threshold: float = 0.95 # 95% of max
unit: str = "units"
cost_per_unit: float = 0.0
@dataclass
class ResourceUsageData:
"""Current resource usage data"""
resource_type: ResourceType
current_usage: float
max_allowed: float
percentage_used: float
cost_accrued: float
last_updated: datetime
class ResourceAllocationService:
"""
Service for managing resource allocation and monitoring usage across tenants.
Features:
- Dynamic quota allocation
- Real-time usage tracking
- Automatic scaling policies
- Cost optimization
- Alert generation
"""
def __init__(self, db: AsyncSession):
self.db = db
# Default resource templates
self.resource_templates = {
"startup": {
ResourceType.CPU: ResourceLimit(ResourceType.CPU, 2.0, unit="cores", cost_per_unit=0.10),
ResourceType.MEMORY: ResourceLimit(ResourceType.MEMORY, 4096, unit="MB", cost_per_unit=0.05),
ResourceType.STORAGE: ResourceLimit(ResourceType.STORAGE, 10240, unit="MB", cost_per_unit=0.01),
ResourceType.API_CALLS: ResourceLimit(ResourceType.API_CALLS, 10000, unit="calls/hour", cost_per_unit=0.001),
ResourceType.MODEL_INFERENCE: ResourceLimit(ResourceType.MODEL_INFERENCE, 1000, unit="tokens", cost_per_unit=0.002),
},
"standard": {
ResourceType.CPU: ResourceLimit(ResourceType.CPU, 4.0, unit="cores", cost_per_unit=0.10),
ResourceType.MEMORY: ResourceLimit(ResourceType.MEMORY, 8192, unit="MB", cost_per_unit=0.05),
ResourceType.STORAGE: ResourceLimit(ResourceType.STORAGE, 51200, unit="MB", cost_per_unit=0.01),
ResourceType.API_CALLS: ResourceLimit(ResourceType.API_CALLS, 50000, unit="calls/hour", cost_per_unit=0.001),
ResourceType.MODEL_INFERENCE: ResourceLimit(ResourceType.MODEL_INFERENCE, 10000, unit="tokens", cost_per_unit=0.002),
},
"enterprise": {
ResourceType.CPU: ResourceLimit(ResourceType.CPU, 16.0, unit="cores", cost_per_unit=0.10),
ResourceType.MEMORY: ResourceLimit(ResourceType.MEMORY, 32768, unit="MB", cost_per_unit=0.05),
ResourceType.STORAGE: ResourceLimit(ResourceType.STORAGE, 102400, unit="MB", cost_per_unit=0.01),
ResourceType.API_CALLS: ResourceLimit(ResourceType.API_CALLS, 200000, unit="calls/hour", cost_per_unit=0.001),
ResourceType.MODEL_INFERENCE: ResourceLimit(ResourceType.MODEL_INFERENCE, 100000, unit="tokens", cost_per_unit=0.002),
ResourceType.GPU_TIME: ResourceLimit(ResourceType.GPU_TIME, 1000, unit="minutes", cost_per_unit=0.50),
}
}
async def allocate_resources(self, tenant_id: int, template: str = "standard") -> bool:
"""
Allocate initial resources to a tenant based on template.
Args:
tenant_id: Tenant database ID
template: Resource template name
Returns:
True if allocation successful
"""
try:
# Get tenant
result = await self.db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
logger.error(f"Tenant {tenant_id} not found")
return False
# Get resource template
if template not in self.resource_templates:
logger.error(f"Unknown resource template: {template}")
return False
resources = self.resource_templates[template]
# Create resource quotas
for resource_type, limit in resources.items():
quota = ResourceQuota(
tenant_id=tenant_id,
resource_type=resource_type.value,
max_value=limit.max_value,
warning_threshold=limit.warning_threshold,
critical_threshold=limit.critical_threshold,
unit=limit.unit,
cost_per_unit=limit.cost_per_unit,
current_usage=0.0,
is_active=True
)
self.db.add(quota)
await self.db.commit()
logger.info(f"Allocated {template} resources to tenant {tenant.domain}")
return True
except Exception as e:
logger.error(f"Failed to allocate resources to tenant {tenant_id}: {e}")
await self.db.rollback()
return False
async def get_tenant_resource_usage(self, tenant_id: int) -> Dict[str, ResourceUsageData]:
"""
Get current resource usage for a tenant.
Args:
tenant_id: Tenant database ID
Returns:
Dictionary of resource usage data
"""
try:
# Get all quotas for tenant
result = await self.db.execute(
select(ResourceQuota).where(
and_(ResourceQuota.tenant_id == tenant_id, ResourceQuota.is_active == True)
)
)
quotas = result.scalars().all()
usage_data = {}
for quota in quotas:
resource_type = ResourceType(quota.resource_type)
percentage_used = (quota.current_usage / quota.max_value) * 100 if quota.max_value > 0 else 0
usage_data[quota.resource_type] = ResourceUsageData(
resource_type=resource_type,
current_usage=quota.current_usage,
max_allowed=quota.max_value,
percentage_used=percentage_used,
cost_accrued=quota.current_usage * quota.cost_per_unit,
last_updated=quota.updated_at
)
return usage_data
except Exception as e:
logger.error(f"Failed to get resource usage for tenant {tenant_id}: {e}")
return {}
async def update_resource_usage(
self,
tenant_id: int,
resource_type: ResourceType,
usage_delta: float
) -> bool:
"""
Update resource usage for a tenant.
Args:
tenant_id: Tenant database ID
resource_type: Type of resource being used
usage_delta: Change in usage (positive for increase, negative for decrease)
Returns:
True if update successful
"""
try:
# Get resource quota
result = await self.db.execute(
select(ResourceQuota).where(
and_(
ResourceQuota.tenant_id == tenant_id,
ResourceQuota.resource_type == resource_type.value,
ResourceQuota.is_active == True
)
)
)
quota = result.scalar_one_or_none()
if not quota:
logger.warning(f"No quota found for {resource_type.value} for tenant {tenant_id}")
return False
# Calculate new usage
new_usage = max(0, quota.current_usage + usage_delta)
# Check if usage exceeds quota
if new_usage > quota.max_value:
logger.warning(
f"Resource usage would exceed quota for tenant {tenant_id}: "
f"{resource_type.value} {new_usage} > {quota.max_value}"
)
return False
# Update usage
quota.current_usage = new_usage
quota.updated_at = datetime.utcnow()
# Record usage history
usage_record = ResourceUsage(
tenant_id=tenant_id,
resource_type=resource_type.value,
usage_amount=usage_delta,
timestamp=datetime.utcnow(),
cost=usage_delta * quota.cost_per_unit
)
self.db.add(usage_record)
await self.db.commit()
# Check for alerts
await self._check_usage_alerts(tenant_id, quota)
return True
except Exception as e:
logger.error(f"Failed to update resource usage: {e}")
await self.db.rollback()
return False
async def _check_usage_alerts(self, tenant_id: int, quota: ResourceQuota) -> None:
"""Check if resource usage triggers alerts"""
try:
percentage_used = (quota.current_usage / quota.max_value) if quota.max_value > 0 else 0
alert_level = None
message = None
            if percentage_used >= quota.critical_threshold:
                alert_level = AlertLevel.CRITICAL
                message = f"Critical: {quota.resource_type} usage at {percentage_used:.1%}"
            elif percentage_used >= quota.warning_threshold:
                alert_level = AlertLevel.WARNING
                message = f"Warning: {quota.resource_type} usage at {percentage_used:.1%}"
if alert_level:
# Check if we already have a recent alert
recent_alert = await self.db.execute(
select(ResourceAlert).where(
and_(
ResourceAlert.tenant_id == tenant_id,
ResourceAlert.resource_type == quota.resource_type,
ResourceAlert.alert_level == alert_level.value,
ResourceAlert.created_at >= datetime.utcnow() - timedelta(hours=1)
)
)
)
if not recent_alert.scalar_one_or_none():
# Create new alert
alert = ResourceAlert(
tenant_id=tenant_id,
resource_type=quota.resource_type,
alert_level=alert_level.value,
message=message,
current_usage=quota.current_usage,
max_value=quota.max_value,
percentage_used=percentage_used
)
self.db.add(alert)
await self.db.commit()
logger.warning(f"Resource alert for tenant {tenant_id}: {message}")
except Exception as e:
logger.error(f"Failed to check usage alerts: {e}")
async def get_tenant_costs(self, tenant_id: int, start_date: datetime, end_date: datetime) -> Dict[str, Any]:
"""
Calculate costs for a tenant over a date range.
Args:
tenant_id: Tenant database ID
start_date: Start of cost calculation period
end_date: End of cost calculation period
Returns:
Cost breakdown by resource type
"""
try:
# Get usage records for the period
result = await self.db.execute(
select(ResourceUsage).where(
and_(
ResourceUsage.tenant_id == tenant_id,
ResourceUsage.timestamp >= start_date,
ResourceUsage.timestamp <= end_date
)
)
)
usage_records = result.scalars().all()
# Calculate costs by resource type
costs_by_type = {}
total_cost = 0.0
for record in usage_records:
if record.resource_type not in costs_by_type:
costs_by_type[record.resource_type] = {
"total_usage": 0.0,
"total_cost": 0.0,
"usage_events": 0
}
costs_by_type[record.resource_type]["total_usage"] += record.usage_amount
costs_by_type[record.resource_type]["total_cost"] += record.cost
costs_by_type[record.resource_type]["usage_events"] += 1
total_cost += record.cost
return {
"tenant_id": tenant_id,
"period_start": start_date.isoformat(),
"period_end": end_date.isoformat(),
"total_cost": round(total_cost, 4),
"costs_by_resource": costs_by_type,
"currency": "USD"
}
except Exception as e:
logger.error(f"Failed to calculate costs for tenant {tenant_id}: {e}")
return {}
async def scale_tenant_resources(
self,
tenant_id: int,
resource_type: ResourceType,
scale_factor: float
) -> bool:
"""
Scale tenant resources up or down.
Args:
tenant_id: Tenant database ID
resource_type: Type of resource to scale
scale_factor: Scaling factor (1.5 = 50% increase, 0.8 = 20% decrease)
Returns:
True if scaling successful
"""
try:
# Get current quota
result = await self.db.execute(
select(ResourceQuota).where(
and_(
ResourceQuota.tenant_id == tenant_id,
ResourceQuota.resource_type == resource_type.value,
ResourceQuota.is_active == True
)
)
)
quota = result.scalar_one_or_none()
if not quota:
logger.error(f"No quota found for {resource_type.value} for tenant {tenant_id}")
return False
# Calculate new limit
new_max_value = quota.max_value * scale_factor
# Ensure we don't scale below current usage
if new_max_value < quota.current_usage:
logger.warning(
f"Cannot scale {resource_type.value} below current usage: "
f"{new_max_value} < {quota.current_usage}"
)
return False
# Update quota
quota.max_value = new_max_value
quota.updated_at = datetime.utcnow()
await self.db.commit()
logger.info(
f"Scaled {resource_type.value} for tenant {tenant_id} by {scale_factor}x to {new_max_value}"
)
return True
except Exception as e:
logger.error(f"Failed to scale resources for tenant {tenant_id}: {e}")
await self.db.rollback()
return False
async def get_system_resource_overview(self) -> Dict[str, Any]:
"""
Get system-wide resource usage overview.
Returns:
System resource usage statistics
"""
try:
# Get aggregate usage by resource type
result = await self.db.execute(
select(
ResourceQuota.resource_type,
func.sum(ResourceQuota.current_usage).label('total_usage'),
func.sum(ResourceQuota.max_value).label('total_allocated'),
func.count(ResourceQuota.tenant_id).label('tenant_count')
).where(ResourceQuota.is_active == True)
.group_by(ResourceQuota.resource_type)
            )
            rows = result.all()
            overview = {}
            for row in rows:
resource_type = row.resource_type
total_usage = float(row.total_usage or 0)
total_allocated = float(row.total_allocated or 0)
tenant_count = int(row.tenant_count or 0)
utilization = (total_usage / total_allocated) * 100 if total_allocated > 0 else 0
overview[resource_type] = {
"total_usage": total_usage,
"total_allocated": total_allocated,
"utilization_percentage": round(utilization, 2),
"tenant_count": tenant_count
                }
            # Count distinct tenants that currently hold active quotas
            tenant_count_result = await self.db.execute(
                select(func.count(func.distinct(ResourceQuota.tenant_id)))
                .where(ResourceQuota.is_active == True)
            )
            return {
                "timestamp": datetime.utcnow().isoformat(),
                "resource_overview": overview,
                "total_tenants": int(tenant_count_result.scalar() or 0)
            }
except Exception as e:
logger.error(f"Failed to get system resource overview: {e}")
return {}
async def get_resource_alerts(self, tenant_id: Optional[int] = None, hours: int = 24) -> List[Dict[str, Any]]:
"""
Get resource alerts for tenant(s).
Args:
tenant_id: Specific tenant ID (None for all tenants)
hours: Hours back to look for alerts
Returns:
List of alert dictionaries
"""
try:
query = select(ResourceAlert).where(
ResourceAlert.created_at >= datetime.utcnow() - timedelta(hours=hours)
)
if tenant_id:
query = query.where(ResourceAlert.tenant_id == tenant_id)
query = query.order_by(ResourceAlert.created_at.desc())
result = await self.db.execute(query)
alerts = result.scalars().all()
return [
{
"id": alert.id,
"tenant_id": alert.tenant_id,
"resource_type": alert.resource_type,
"alert_level": alert.alert_level,
"message": alert.message,
"current_usage": alert.current_usage,
"max_value": alert.max_value,
"percentage_used": alert.percentage_used,
"created_at": alert.created_at.isoformat()
}
for alert in alerts
]
except Exception as e:
logger.error(f"Failed to get resource alerts: {e}")
return []
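# Illustrative usage sketch: allocating the standard quota template to a tenant and
# recording token usage for one inference call. Tenant ID and token count are examples.
async def example_allocate_and_track(db: AsyncSession) -> None:
    service = ResourceAllocationService(db)
    await service.allocate_resources(tenant_id=42, template="standard")
    # Returns False when the update would exceed the tenant's quota
    within_quota = await service.update_resource_usage(42, ResourceType.MODEL_INFERENCE, 750)
    if not within_quota:
        logger.warning("Tenant 42 has exhausted its model inference quota")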

View File

@@ -0,0 +1,821 @@
"""
Comprehensive Resource management service for all GT 2.0 resource families
Supports business logic and validation for:
- AI/ML Resources (LLMs, embeddings, image generation, function calling)
- RAG Engine Resources (vector databases, document processing, retrieval systems)
- Agentic Workflow Resources (multi-step AI workflows, agent frameworks)
- App Integration Resources (external tools, APIs, webhooks)
- External Web Services (Canvas LMS, CTFd, Guacamole, iframe-embedded services)
- AI Literacy & Cognitive Skills (educational games, puzzles, learning content)
"""
import asyncio
from typing import Dict, Any, List, Optional, Union
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from sqlalchemy.orm import selectinload
import logging
import json
import base64
from cryptography.fernet import Fernet
from app.core.config import get_settings
from app.models.ai_resource import AIResource
from app.models.tenant import Tenant, TenantResource
from app.models.usage import UsageRecord
from app.models.user_data import UserResourceData, UserPreferences, UserProgress, SessionData
from app.models.resource_schemas import validate_resource_config, get_config_schema
from app.services.groq_service import groq_service
# Encryption uses the existing Fernet implementation from GT 2.0 (imported above)
logger = logging.getLogger(__name__)
class ResourceService:
"""Comprehensive service for managing all GT 2.0 resource families with HA and business logic"""
def __init__(self, db: AsyncSession):
self.db = db
async def create_resource(self, resource_data: Dict[str, Any]) -> AIResource:
"""Create a new resource with comprehensive validation for all resource families"""
# Validate required fields (model_name is now optional for non-AI resources)
required_fields = ["name", "resource_type", "provider"]
for field in required_fields:
if field not in resource_data:
raise ValueError(f"Missing required field: {field}")
# Validate resource type
valid_resource_types = [
"ai_ml", "rag_engine", "agentic_workflow",
"app_integration", "external_service", "ai_literacy"
]
if resource_data["resource_type"] not in valid_resource_types:
raise ValueError(f"Invalid resource_type. Must be one of: {valid_resource_types}")
# Validate and apply configuration based on resource type and subtype
resource_subtype = resource_data.get("resource_subtype")
if "configuration" in resource_data:
try:
validated_config = validate_resource_config(
resource_data["resource_type"],
resource_subtype or "default",
resource_data["configuration"]
)
resource_data["configuration"] = validated_config
except Exception as e:
logger.warning(f"Configuration validation failed: {e}. Using provided config as-is.")
# Apply resource-family-specific defaults
await self._apply_resource_defaults(resource_data)
# Validate specific requirements by resource family
await self._validate_resource_requirements(resource_data)
# Create resource
resource = AIResource(**resource_data)
self.db.add(resource)
await self.db.commit()
await self.db.refresh(resource)
logger.info(f"Created {resource.resource_type} resource: {resource.name} ({resource.provider})")
return resource
async def get_resource(self, resource_id: int) -> Optional[AIResource]:
"""Get resource by ID with relationships"""
result = await self.db.execute(
select(AIResource)
.options(selectinload(AIResource.tenant_resources))
.where(AIResource.id == resource_id)
)
return result.scalar_one_or_none()
async def get_resource_by_uuid(self, resource_uuid: str) -> Optional[AIResource]:
"""Get resource by UUID"""
result = await self.db.execute(
select(AIResource)
.where(AIResource.uuid == resource_uuid)
)
return result.scalar_one_or_none()
async def list_resources(
self,
provider: Optional[str] = None,
resource_type: Optional[str] = None,
is_active: Optional[bool] = None,
health_status: Optional[str] = None
) -> List[AIResource]:
"""List resources with filtering"""
query = select(AIResource).options(selectinload(AIResource.tenant_resources))
conditions = []
if provider:
conditions.append(AIResource.provider == provider)
if resource_type:
conditions.append(AIResource.resource_type == resource_type)
if is_active is not None:
conditions.append(AIResource.is_active == is_active)
if health_status:
conditions.append(AIResource.health_status == health_status)
if conditions:
query = query.where(and_(*conditions))
result = await self.db.execute(query.order_by(AIResource.priority.desc(), AIResource.created_at))
return result.scalars().all()
async def update_resource(self, resource_id: int, updates: Dict[str, Any]) -> Optional[AIResource]:
"""Update resource with validation"""
resource = await self.get_resource(resource_id)
if not resource:
return None
# Update fields
for key, value in updates.items():
if hasattr(resource, key):
setattr(resource, key, value)
resource.updated_at = datetime.utcnow()
await self.db.commit()
await self.db.refresh(resource)
logger.info(f"Updated resource {resource_id}: {list(updates.keys())}")
return resource
async def delete_resource(self, resource_id: int) -> bool:
"""Delete resource (soft delete by deactivating)"""
resource = await self.get_resource(resource_id)
if not resource:
return False
# Check if resource is in use by tenants
result = await self.db.execute(
select(TenantResource)
.where(and_(
TenantResource.resource_id == resource_id,
TenantResource.is_enabled == True
))
)
active_assignments = result.scalars().all()
if active_assignments:
raise ValueError(f"Cannot delete resource in use by {len(active_assignments)} tenants")
# Soft delete
resource.is_active = False
resource.health_status = "deleted"
resource.updated_at = datetime.utcnow()
await self.db.commit()
logger.info(f"Deleted resource {resource_id}")
return True
async def assign_resource_to_tenant(
self,
resource_id: int,
tenant_id: int,
usage_limits: Optional[Dict[str, Any]] = None
) -> TenantResource:
"""Assign resource to tenant with usage limits"""
# Validate resource exists and is active
resource = await self.get_resource(resource_id)
if not resource or not resource.is_active:
raise ValueError("Resource not found or inactive")
# Validate tenant exists
tenant_result = await self.db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = tenant_result.scalar_one_or_none()
if not tenant:
raise ValueError("Tenant not found")
# Check if assignment already exists
existing_result = await self.db.execute(
select(TenantResource)
.where(and_(
TenantResource.tenant_id == tenant_id,
TenantResource.resource_id == resource_id
))
)
existing = existing_result.scalar_one_or_none()
if existing:
# Update existing assignment
existing.is_enabled = True
existing.usage_limits = usage_limits or {}
existing.updated_at = datetime.utcnow()
await self.db.commit()
return existing
# Create new assignment
assignment = TenantResource(
tenant_id=tenant_id,
resource_id=resource_id,
usage_limits=usage_limits or {},
is_enabled=True
)
self.db.add(assignment)
await self.db.commit()
await self.db.refresh(assignment)
logger.info(f"Assigned resource {resource_id} to tenant {tenant_id}")
return assignment
async def unassign_resource_from_tenant(self, resource_id: int, tenant_id: int) -> bool:
"""Remove resource assignment from tenant"""
result = await self.db.execute(
select(TenantResource)
.where(and_(
TenantResource.tenant_id == tenant_id,
TenantResource.resource_id == resource_id
))
)
assignment = result.scalar_one_or_none()
if not assignment:
return False
assignment.is_enabled = False
assignment.updated_at = datetime.utcnow()
await self.db.commit()
logger.info(f"Unassigned resource {resource_id} from tenant {tenant_id}")
return True
async def get_tenant_resources(self, tenant_id: int) -> List[AIResource]:
"""Get all resources assigned to a tenant"""
result = await self.db.execute(
select(AIResource)
.join(TenantResource)
.where(and_(
TenantResource.tenant_id == tenant_id,
TenantResource.is_enabled == True,
AIResource.is_active == True
))
.order_by(AIResource.priority.desc())
)
return result.scalars().all()
async def health_check_all_resources(self) -> Dict[str, Any]:
"""Perform health checks on all active resources"""
resources = await self.list_resources(is_active=True)
results = {
"total_resources": len(resources),
"healthy": 0,
"unhealthy": 0,
"unknown": 0,
"details": []
}
        # Run health checks concurrently, tracking which resource each task belongs to
        tasks = []
        task_resources = []
        for resource in resources:
            if resource.provider == "groq" and resource.api_key_encrypted:
                try:
                    # Decrypt API key using tenant encryption key
                    api_key = await self._decrypt_api_key(resource.api_key_encrypted, resource.tenant_id)
                    tasks.append(self._health_check_resource(resource, api_key))
                    task_resources.append(resource)
                except Exception as e:
                    logger.error(f"Failed to decrypt API key for resource {resource.id}: {e}")
                    resource.update_health_status("unhealthy")
        if tasks:
            health_results = await asyncio.gather(*tasks, return_exceptions=True)
            for resource, result in zip(task_resources, health_results):
                if isinstance(result, Exception):
                    logger.error(f"Health check failed for resource {resource.id}: {result}")
                    resource.update_health_status("unhealthy")
                # Success case: health status was already updated inside _health_check_resource
# Count results
for resource in resources:
results["details"].append({
"id": resource.id,
"name": resource.name,
"provider": resource.provider,
"health_status": resource.health_status,
"last_check": resource.last_health_check.isoformat() if resource.last_health_check else None
})
if resource.health_status == "healthy":
results["healthy"] += 1
elif resource.health_status == "unhealthy":
results["unhealthy"] += 1
else:
results["unknown"] += 1
await self.db.commit() # Save health status updates
return results
async def _health_check_resource(self, resource: AIResource, api_key: str) -> bool:
"""Internal method to health check a single resource"""
try:
if resource.provider == "groq":
return await groq_service.health_check_resource(resource, api_key)
else:
# For other providers, implement specific health checks
logger.warning(f"No health check implementation for provider: {resource.provider}")
resource.update_health_status("unknown")
return False
except Exception as e:
logger.error(f"Health check failed for resource {resource.id}: {e}")
resource.update_health_status("unhealthy")
return False
async def get_resource_usage_stats(
self,
resource_id: int,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> Dict[str, Any]:
"""Get usage statistics for a resource"""
if not start_date:
start_date = datetime.utcnow() - timedelta(days=30)
if not end_date:
end_date = datetime.utcnow()
# Get usage records
result = await self.db.execute(
select(UsageRecord)
.where(and_(
UsageRecord.resource_id == resource_id,
UsageRecord.created_at >= start_date,
UsageRecord.created_at <= end_date
))
.order_by(UsageRecord.created_at.desc())
)
usage_records = result.scalars().all()
# Calculate statistics
total_requests = len(usage_records)
total_tokens = sum(record.tokens_used for record in usage_records)
total_cost_cents = sum(record.cost_cents for record in usage_records)
avg_tokens_per_request = total_tokens / total_requests if total_requests > 0 else 0
avg_cost_per_request = total_cost_cents / total_requests if total_requests > 0 else 0
# Group by day for trending
daily_stats = {}
for record in usage_records:
date_key = record.created_at.date().isoformat()
if date_key not in daily_stats:
daily_stats[date_key] = {
"requests": 0,
"tokens": 0,
"cost_cents": 0
}
daily_stats[date_key]["requests"] += 1
daily_stats[date_key]["tokens"] += record.tokens_used
daily_stats[date_key]["cost_cents"] += record.cost_cents
return {
"resource_id": resource_id,
"period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat()
},
"summary": {
"total_requests": total_requests,
"total_tokens": total_tokens,
"total_cost_dollars": total_cost_cents / 100,
"avg_tokens_per_request": round(avg_tokens_per_request, 2),
"avg_cost_per_request_cents": round(avg_cost_per_request, 2)
},
"daily_stats": daily_stats
}
async def get_tenant_usage_stats(
self,
tenant_id: int,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> Dict[str, Any]:
"""Get usage statistics for all resources used by a tenant"""
if not start_date:
start_date = datetime.utcnow() - timedelta(days=30)
if not end_date:
end_date = datetime.utcnow()
# Get usage records with resource information
result = await self.db.execute(
select(UsageRecord, AIResource)
.join(AIResource, UsageRecord.resource_id == AIResource.id)
.where(and_(
UsageRecord.tenant_id == tenant_id,
UsageRecord.created_at >= start_date,
UsageRecord.created_at <= end_date
))
.order_by(UsageRecord.created_at.desc())
)
records_with_resources = result.all()
# Calculate statistics by resource
resource_stats = {}
total_cost_cents = 0
total_requests = 0
for usage_record, ai_resource in records_with_resources:
resource_id = ai_resource.id
if resource_id not in resource_stats:
resource_stats[resource_id] = {
"resource_name": ai_resource.name,
"provider": ai_resource.provider,
"model_name": ai_resource.model_name,
"requests": 0,
"tokens": 0,
"cost_cents": 0
}
resource_stats[resource_id]["requests"] += 1
resource_stats[resource_id]["tokens"] += usage_record.tokens_used
resource_stats[resource_id]["cost_cents"] += usage_record.cost_cents
total_cost_cents += usage_record.cost_cents
total_requests += 1
return {
"tenant_id": tenant_id,
"period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat()
},
"summary": {
"total_requests": total_requests,
"total_cost_dollars": total_cost_cents / 100,
"resources_used": len(resource_stats)
},
"by_resource": resource_stats
}
# Resource-family-specific methods
async def _apply_resource_defaults(self, resource_data: Dict[str, Any]) -> None:
"""Apply defaults based on resource family and provider"""
resource_type = resource_data["resource_type"]
provider = resource_data["provider"]
if resource_type == "ai_ml" and provider == "groq":
# Apply Groq-specific defaults for AI/ML resources
groq_defaults = AIResource.get_groq_defaults()
for key, value in groq_defaults.items():
if key not in resource_data:
resource_data[key] = value
elif resource_type == "external_service":
# Apply defaults for external web services
if "sandbox_config" not in resource_data:
resource_data["sandbox_config"] = {
"permissions": ["allow-same-origin", "allow-scripts", "allow-forms"],
"csp_policy": "default-src 'self'",
"secure": True
}
if "personalization_mode" not in resource_data:
resource_data["personalization_mode"] = "user_scoped" # Most external services are user-specific
elif resource_type == "ai_literacy":
# Apply defaults for AI literacy resources
if "personalization_mode" not in resource_data:
resource_data["personalization_mode"] = "user_scoped" # Track individual progress
if "configuration" not in resource_data:
resource_data["configuration"] = {
"difficulty_adaptive": True,
"progress_tracking": True,
"explanation_mode": True
}
elif resource_type == "rag_engine":
# Apply defaults for RAG engines
if "personalization_mode" not in resource_data:
resource_data["personalization_mode"] = "shared" # RAG engines typically shared
if "configuration" not in resource_data:
resource_data["configuration"] = {
"chunk_size": 512,
"similarity_threshold": 0.7,
"max_results": 10
}
elif resource_type == "agentic_workflow":
# Apply defaults for agentic workflows
if "personalization_mode" not in resource_data:
resource_data["personalization_mode"] = "user_scoped" # Workflows are typically user-specific
if "configuration" not in resource_data:
resource_data["configuration"] = {
"max_iterations": 10,
"human_in_loop": True,
"retry_on_failure": True
}
elif resource_type == "app_integration":
# Apply defaults for app integrations
if "personalization_mode" not in resource_data:
resource_data["personalization_mode"] = "shared" # Most integrations are shared
if "configuration" not in resource_data:
resource_data["configuration"] = {
"timeout_seconds": 30,
"retry_attempts": 3,
"auth_method": "api_key"
}
# Set default personalization mode if not specified
if "personalization_mode" not in resource_data:
resource_data["personalization_mode"] = "shared"
async def _validate_resource_requirements(self, resource_data: Dict[str, Any]) -> None:
"""Validate resource-specific requirements"""
resource_type = resource_data["resource_type"]
resource_subtype = resource_data.get("resource_subtype")
if resource_type == "ai_ml":
# AI/ML resources must have model_name
if not resource_data.get("model_name"):
raise ValueError("AI/ML resources must specify model_name")
# Validate AI/ML subtypes
valid_ai_subtypes = ["llm", "embedding", "image_generation", "function_calling"]
if resource_subtype and resource_subtype not in valid_ai_subtypes:
raise ValueError(f"Invalid AI/ML subtype. Must be one of: {valid_ai_subtypes}")
elif resource_type == "external_service":
# External services must have iframe_url or primary_endpoint
if not resource_data.get("iframe_url") and not resource_data.get("primary_endpoint"):
raise ValueError("External service resources must specify iframe_url or primary_endpoint")
# Validate external service subtypes
valid_external_subtypes = ["lms", "cyber_range", "iframe", "custom"]
if resource_subtype and resource_subtype not in valid_external_subtypes:
raise ValueError(f"Invalid external service subtype. Must be one of: {valid_external_subtypes}")
elif resource_type == "ai_literacy":
# AI literacy resources must have appropriate subtype
valid_literacy_subtypes = ["strategic_game", "logic_puzzle", "philosophical_dilemma", "educational_content"]
if not resource_subtype or resource_subtype not in valid_literacy_subtypes:
raise ValueError(f"AI literacy resources must specify valid subtype: {valid_literacy_subtypes}")
elif resource_type == "rag_engine":
# RAG engines must have appropriate configuration
valid_rag_subtypes = ["vector_database", "document_processor", "retrieval_system"]
if resource_subtype and resource_subtype not in valid_rag_subtypes:
raise ValueError(f"Invalid RAG engine subtype. Must be one of: {valid_rag_subtypes}")
elif resource_type == "agentic_workflow":
# Agentic workflows must have appropriate configuration
valid_workflow_subtypes = ["workflow", "agent_framework", "multi_agent"]
if resource_subtype and resource_subtype not in valid_workflow_subtypes:
raise ValueError(f"Invalid agentic workflow subtype. Must be one of: {valid_workflow_subtypes}")
elif resource_type == "app_integration":
# App integrations must have endpoint or webhook configuration
if not resource_data.get("primary_endpoint") and not resource_data.get("configuration", {}).get("webhook_enabled"):
raise ValueError("App integration resources must specify primary_endpoint or enable webhooks")
valid_integration_subtypes = ["api", "webhook", "oauth_app", "custom"]
if resource_subtype and resource_subtype not in valid_integration_subtypes:
raise ValueError(f"Invalid app integration subtype. Must be one of: {valid_integration_subtypes}")
# User data separation methods
async def get_user_resource_data(
self,
user_id: int,
resource_id: int,
data_type: str,
session_id: Optional[str] = None
) -> Optional[UserResourceData]:
"""Get user-specific data for a resource"""
query = select(UserResourceData).where(and_(
UserResourceData.user_id == user_id,
UserResourceData.resource_id == resource_id,
UserResourceData.data_type == data_type
))
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def set_user_resource_data(
self,
user_id: int,
tenant_id: int,
resource_id: int,
data_type: str,
data_key: str,
data_value: Dict[str, Any],
session_id: Optional[str] = None,
expires_minutes: Optional[int] = None
) -> UserResourceData:
"""Set user-specific data for a resource"""
# Check if data already exists
existing = await self.get_user_resource_data(user_id, resource_id, data_type)
if existing:
# Update existing data
existing.data_key = data_key
existing.data_value = data_value
existing.accessed_at = datetime.utcnow()
if expires_minutes:
existing.expiry_date = datetime.utcnow() + timedelta(minutes=expires_minutes)
await self.db.commit()
await self.db.refresh(existing)
return existing
else:
# Create new data
expiry_date = None
if expires_minutes:
expiry_date = datetime.utcnow() + timedelta(minutes=expires_minutes)
user_data = UserResourceData(
user_id=user_id,
tenant_id=tenant_id,
resource_id=resource_id,
data_type=data_type,
data_key=data_key,
data_value=data_value,
expiry_date=expiry_date
)
self.db.add(user_data)
await self.db.commit()
await self.db.refresh(user_data)
logger.info(f"Created user data: user={user_id}, resource={resource_id}, type={data_type}")
return user_data
async def get_user_progress(self, user_id: int, resource_id: int) -> Optional[UserProgress]:
"""Get user progress for AI literacy resources"""
result = await self.db.execute(
select(UserProgress).where(and_(
UserProgress.user_id == user_id,
UserProgress.resource_id == resource_id
))
)
return result.scalar_one_or_none()
async def update_user_progress(
self,
user_id: int,
tenant_id: int,
resource_id: int,
skill_area: str,
progress_data: Dict[str, Any]
) -> UserProgress:
"""Update user progress for learning resources"""
existing = await self.get_user_progress(user_id, resource_id)
if existing:
# Update existing progress
for key, value in progress_data.items():
if hasattr(existing, key):
setattr(existing, key, value)
existing.last_activity = datetime.utcnow()
await self.db.commit()
await self.db.refresh(existing)
return existing
else:
# Create new progress record
progress = UserProgress(
user_id=user_id,
tenant_id=tenant_id,
resource_id=resource_id,
skill_area=skill_area,
**progress_data
)
self.db.add(progress)
await self.db.commit()
await self.db.refresh(progress)
logger.info(f"Created user progress: user={user_id}, resource={resource_id}, skill={skill_area}")
return progress
# Enhanced filtering and search
async def list_resources_by_family(
self,
resource_type: str,
resource_subtype: Optional[str] = None,
tenant_id: Optional[int] = None,
user_id: Optional[int] = None,
include_inactive: bool = False
) -> List[AIResource]:
"""List resources by resource family with optional filtering"""
query = select(AIResource).options(selectinload(AIResource.tenant_resources))
conditions = [AIResource.resource_type == resource_type]
if resource_subtype:
conditions.append(AIResource.resource_subtype == resource_subtype)
if not include_inactive:
conditions.append(AIResource.is_active == True)
if tenant_id:
# Filter to resources available to this tenant
query = query.join(TenantResource).where(and_(
TenantResource.tenant_id == tenant_id,
TenantResource.is_enabled == True
))
if conditions:
query = query.where(and_(*conditions))
result = await self.db.execute(
query.order_by(AIResource.priority.desc(), AIResource.created_at)
)
return result.scalars().all()
async def get_resource_families_summary(self, tenant_id: Optional[int] = None) -> Dict[str, Any]:
"""Get summary of all resource families"""
base_query = select(
AIResource.resource_type,
AIResource.resource_subtype,
func.count(AIResource.id).label('count'),
func.count(func.nullif(AIResource.health_status == 'healthy', False)).label('healthy_count')
).group_by(AIResource.resource_type, AIResource.resource_subtype)
if tenant_id:
base_query = base_query.join(TenantResource).where(and_(
TenantResource.tenant_id == tenant_id,
TenantResource.is_enabled == True,
AIResource.is_active == True
))
else:
base_query = base_query.where(AIResource.is_active == True)
result = await self.db.execute(base_query)
rows = result.all()
# Organize by resource family
families = {}
for row in rows:
family = row.resource_type
if family not in families:
families[family] = {
"total_resources": 0,
"healthy_resources": 0,
"subtypes": {}
}
subtype = row.resource_subtype or "default"
families[family]["total_resources"] += row.count
families[family]["healthy_resources"] += row.healthy_count or 0
families[family]["subtypes"][subtype] = {
"count": row.count,
"healthy_count": row.healthy_count or 0
}
return families
async def _decrypt_api_key(self, encrypted_api_key: str, tenant_id: str) -> str:
"""Decrypt API key using tenant-specific encryption key"""
try:
settings = get_settings()
# Generate tenant-specific encryption key from settings secret
tenant_key = base64.urlsafe_b64encode(
f"{settings.secret_key}:{tenant_id}".encode()[:32].ljust(32, b'\0')
)
cipher = Fernet(tenant_key)
# Decrypt the API key
decrypted_bytes = cipher.decrypt(encrypted_api_key.encode())
return decrypted_bytes.decode()
except Exception as e:
logger.error(f"Failed to decrypt API key for tenant {tenant_id}: {e}")
raise ValueError(f"API key decryption failed: {e}")
async def _encrypt_api_key(self, api_key: str, tenant_id: str) -> str:
"""Encrypt API key using tenant-specific encryption key"""
try:
settings = get_settings()
# Generate tenant-specific encryption key from settings secret
tenant_key = base64.urlsafe_b64encode(
f"{settings.secret_key}:{tenant_id}".encode()[:32].ljust(32, b'\0')
)
cipher = Fernet(tenant_key)
# Encrypt the API key
encrypted_bytes = cipher.encrypt(api_key.encode())
return encrypted_bytes.decode()
except Exception as e:
logger.error(f"Failed to encrypt API key for tenant {tenant_id}: {e}")
raise ValueError(f"API key encryption failed: {e}")

View File

@@ -0,0 +1,366 @@
"""
GT 2.0 Session Management Service
NIST SP 800-63B AAL2 Compliant Server-Side Session Management (Issue #264)
- Server-side session tracking is authoritative
- Idle timeout: 30 minutes (NIST AAL2 requirement)
- Absolute timeout: 12 hours (NIST AAL2 maximum)
- Warning threshold: 30 minutes before absolute expiry
- Session tokens are SHA-256 hashed before storage
"""
from typing import Optional, Tuple, Dict, Any
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session as DBSession
from sqlalchemy import and_
import secrets
import hashlib
import logging
from app.models.session import Session
logger = logging.getLogger(__name__)
class SessionService:
"""
Service for OWASP/NIST compliant session management.
Key features:
- Server-side session state is the single source of truth
- Session tokens hashed with SHA-256 (never stored in plaintext)
- Idle timeout tracked via last_activity_at
- Absolute timeout prevents indefinite session extension
- Warning signals sent when approaching expiry
"""
# Session timeout configuration (NIST SP 800-63B AAL2 Compliant)
IDLE_TIMEOUT_MINUTES = 30 # 30 minutes - NIST AAL2 requirement for inactivity timeout
ABSOLUTE_TIMEOUT_HOURS = 12 # 12 hours - NIST AAL2 maximum session duration
# Warning threshold: Show notice 30 minutes before absolute timeout
ABSOLUTE_WARNING_THRESHOLD_MINUTES = 30
def __init__(self, db: DBSession):
self.db = db
@staticmethod
def generate_session_token() -> str:
"""
Generate a cryptographically secure session token.
Uses secrets.token_urlsafe for CSPRNG (Cryptographically Secure
Pseudo-Random Number Generator). 32 bytes = 256 bits of entropy.
"""
return secrets.token_urlsafe(32)
@staticmethod
def hash_token(token: str) -> str:
"""
Hash session token with SHA-256 for secure storage.
OWASP: Never store session tokens in plaintext.
"""
return hashlib.sha256(token.encode('utf-8')).hexdigest()
def create_session(
self,
user_id: int,
tenant_id: Optional[int] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
app_type: str = 'control_panel'
) -> Tuple[str, datetime]:
"""
Create a new server-side session.
Args:
user_id: The authenticated user's ID
tenant_id: Optional tenant context
ip_address: Client IP for security auditing
user_agent: Client user agent for security auditing
app_type: 'control_panel' or 'tenant_app' to distinguish session source
Returns:
Tuple of (session_token, absolute_expires_at)
The token should be included in JWT claims.
"""
# Generate session token (this gets sent to client in JWT)
session_token = self.generate_session_token()
token_hash = self.hash_token(session_token)
# Calculate absolute expiration
now = datetime.now(timezone.utc)
absolute_expires_at = now + timedelta(hours=self.ABSOLUTE_TIMEOUT_HOURS)
# Create session record
session = Session(
user_id=user_id,
session_token_hash=token_hash,
absolute_expires_at=absolute_expires_at,
ip_address=ip_address,
user_agent=user_agent[:500] if user_agent and len(user_agent) > 500 else user_agent,
tenant_id=tenant_id,
is_active=True,
app_type=app_type
)
self.db.add(session)
self.db.commit()
self.db.refresh(session)
logger.info(f"Created session for user_id={user_id}, tenant_id={tenant_id}, app_type={app_type}, expires={absolute_expires_at}")
return session_token, absolute_expires_at
def validate_session(self, session_token: str) -> Tuple[bool, Optional[str], Optional[int], Optional[Dict[str, Any]]]:
"""
Validate a session and return status information.
This is the core validation method called on every authenticated request.
Args:
session_token: The plaintext session token from JWT
Returns:
Tuple of (is_valid, expiry_reason, seconds_until_idle_expiry, session_info)
- is_valid: Whether the session is currently valid
- expiry_reason: 'idle' or 'absolute' if expired, None if valid
- seconds_until_idle_expiry: Seconds until idle timeout (for warning)
- session_info: Dict with user_id, tenant_id if valid
"""
token_hash = self.hash_token(session_token)
# Find active session
session = self.db.query(Session).filter(
and_(
Session.session_token_hash == token_hash,
Session.is_active == True
)
).first()
if not session:
logger.debug(f"Session not found or inactive for token hash prefix: {token_hash[:8]}...")
return False, 'not_found', None, None
now = datetime.now(timezone.utc)
# Ensure session timestamps are timezone-aware for comparison
absolute_expires = session.absolute_expires_at
if absolute_expires.tzinfo is None:
absolute_expires = absolute_expires.replace(tzinfo=timezone.utc)
last_activity = session.last_activity_at
if last_activity.tzinfo is None:
last_activity = last_activity.replace(tzinfo=timezone.utc)
# Check absolute timeout first (cannot be extended)
if now >= absolute_expires:
self._revoke_session_internal(session, 'absolute_timeout')
logger.info(f"Session expired (absolute) for user_id={session.user_id}")
return False, 'absolute', None, {'user_id': session.user_id, 'tenant_id': session.tenant_id}
# Check idle timeout
idle_expires_at = last_activity + timedelta(minutes=self.IDLE_TIMEOUT_MINUTES)
if now >= idle_expires_at:
self._revoke_session_internal(session, 'idle_timeout')
logger.info(f"Session expired (idle) for user_id={session.user_id}")
return False, 'idle', None, {'user_id': session.user_id, 'tenant_id': session.tenant_id}
# Session is valid - calculate time until idle expiry
seconds_until_idle = int((idle_expires_at - now).total_seconds())
# Also check seconds until absolute expiry (use whichever is sooner)
seconds_until_absolute = int((absolute_expires - now).total_seconds())
seconds_remaining = min(seconds_until_idle, seconds_until_absolute)
return True, None, seconds_remaining, {
'user_id': session.user_id,
'tenant_id': session.tenant_id,
'session_id': str(session.id),
'absolute_seconds_remaining': seconds_until_absolute
}
def update_activity(self, session_token: str) -> bool:
"""
Update the last_activity_at timestamp for a session.
This should be called on every authenticated request to track idle time.
Args:
session_token: The plaintext session token from JWT
Returns:
True if session was updated, False if session not found/inactive
"""
token_hash = self.hash_token(session_token)
result = self.db.query(Session).filter(
and_(
Session.session_token_hash == token_hash,
Session.is_active == True
)
).update({
Session.last_activity_at: datetime.now(timezone.utc)
})
self.db.commit()
if result > 0:
logger.debug(f"Updated activity for session hash prefix: {token_hash[:8]}...")
return True
return False
def revoke_session(self, session_token: str, reason: str = 'logout') -> bool:
"""
Revoke a session (e.g., on logout).
Args:
session_token: The plaintext session token
reason: Revocation reason ('logout', 'admin_revoke', etc.)
Returns:
True if session was revoked, False if not found
"""
token_hash = self.hash_token(session_token)
session = self.db.query(Session).filter(
and_(
Session.session_token_hash == token_hash,
Session.is_active == True
)
).first()
if not session:
return False
self._revoke_session_internal(session, reason)
logger.info(f"Session revoked for user_id={session.user_id}, reason={reason}")
return True
def revoke_all_user_sessions(self, user_id: int, reason: str = 'password_change') -> int:
"""
Revoke all active sessions for a user.
This should be called on password change, account lockout, etc.
Args:
user_id: The user whose sessions to revoke
reason: Revocation reason
Returns:
Number of sessions revoked
"""
now = datetime.now(timezone.utc)
result = self.db.query(Session).filter(
and_(
Session.user_id == user_id,
Session.is_active == True
)
).update({
Session.is_active: False,
Session.revoked_at: now,
Session.ended_at: now, # Always set ended_at when session ends
Session.revoke_reason: reason
})
self.db.commit()
if result > 0:
logger.info(f"Revoked {result} sessions for user_id={user_id}, reason={reason}")
return result
def get_active_sessions_for_user(self, user_id: int) -> list:
"""
Get all active sessions for a user.
Useful for "active sessions" UI where users can see/revoke their sessions.
Args:
user_id: The user to query
Returns:
List of session dictionaries (without sensitive data)
"""
sessions = self.db.query(Session).filter(
and_(
Session.user_id == user_id,
Session.is_active == True
)
).all()
return [s.to_dict() for s in sessions]
def cleanup_expired_sessions(self) -> int:
"""
Clean up expired sessions (for scheduled maintenance).
This marks expired sessions as inactive rather than deleting them
to preserve audit trail.
Returns:
Number of sessions cleaned up
"""
now = datetime.now(timezone.utc)
idle_cutoff = now - timedelta(minutes=self.IDLE_TIMEOUT_MINUTES)
# Mark absolute-expired sessions
absolute_count = self.db.query(Session).filter(
and_(
Session.is_active == True,
Session.absolute_expires_at < now
)
).update({
Session.is_active: False,
Session.revoked_at: now,
Session.ended_at: now, # Always set ended_at when session ends
Session.revoke_reason: 'absolute_timeout'
})
# Mark idle-expired sessions
idle_count = self.db.query(Session).filter(
and_(
Session.is_active == True,
Session.last_activity_at < idle_cutoff
)
).update({
Session.is_active: False,
Session.revoked_at: now,
Session.ended_at: now, # Always set ended_at when session ends
Session.revoke_reason: 'idle_timeout'
})
self.db.commit()
total = absolute_count + idle_count
if total > 0:
logger.info(f"Cleaned up {total} expired sessions (absolute={absolute_count}, idle={idle_count})")
return total
def _revoke_session_internal(self, session: Session, reason: str) -> None:
"""Internal helper to revoke a session."""
now = datetime.now(timezone.utc)
session.is_active = False
session.revoked_at = now
session.ended_at = now # Always set ended_at when session ends
session.revoke_reason = reason
self.db.commit()
def should_show_warning(self, absolute_seconds_remaining: int) -> bool:
"""
Check if a warning should be shown to the user.
Warning is based on ABSOLUTE timeout (not idle), because:
- If browser is open, polling keeps idle timeout from expiring
- Absolute timeout is the only one that will actually log user out
- This gives users 30 minutes' notice before forced re-authentication
Args:
absolute_seconds_remaining: Seconds until absolute session expiry
Returns:
True if warning should be shown (< 30 minutes until absolute timeout)
"""
return absolute_seconds_remaining <= (self.ABSOLUTE_WARNING_THRESHOLD_MINUTES * 60)
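# --- Illustrative sketch (not part of the original file) ---
# How a login/request/logout flow might drive SessionService. The `db` argument
# is assumed to be a synchronous SQLAlchemy Session bound to the control-panel
# database; surfacing the expiry warning to the client is left as a placeholder.
def _example_session_lifecycle(db: DBSession, user_id: int) -> None:
    service = SessionService(db)
    # On login: create the server-side session and put the token in JWT claims.
    token, absolute_expiry = service.create_session(user_id=user_id, app_type="control_panel")
    # On every authenticated request: validate, then refresh last_activity_at.
    is_valid, reason, seconds_remaining, info = service.validate_session(token)
    if is_valid:
        service.update_activity(token)
        if service.should_show_warning(info["absolute_seconds_remaining"]):
            pass  # tell the client the session is approaching absolute expiry
    # On logout: revoke server-side so the token cannot be replayed.
    service.revoke_session(token, reason="logout")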

View File

@@ -0,0 +1,343 @@
"""
GT 2.0 Template Service
Handles applying tenant templates to existing tenants
"""
import logging
import os
import uuid
from typing import Dict, Any, List
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text
from sqlalchemy.dialects.postgresql import insert
from app.models.tenant_template import TenantTemplate
from app.models.tenant import Tenant
from app.models.tenant_model_config import TenantModelConfig
logger = logging.getLogger(__name__)
class TemplateService:
"""Service for applying tenant templates"""
def __init__(self):
tenant_password = os.environ["TENANT_POSTGRES_PASSWORD"]
self.tenant_db_url = f"postgresql://gt2_tenant_user:{tenant_password}@gentwo-tenant-postgres-primary:5432/gt2_tenants"
async def apply_template(
self,
template_id: int,
tenant_id: int,
control_panel_db: AsyncSession
) -> Dict[str, Any]:
"""
Apply a template to an existing tenant
Args:
template_id: ID of template to apply
tenant_id: ID of tenant to apply to
control_panel_db: Control panel database session
Returns:
Dict with applied resources summary
"""
try:
template = await control_panel_db.get(TenantTemplate, template_id)
if not template:
raise ValueError(f"Template {template_id} not found")
tenant = await control_panel_db.get(Tenant, tenant_id)
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
logger.info(f"Applying template '{template.name}' to tenant '{tenant.domain}'")
template_data = template.template_data
results = {
"models_added": 0,
"agents_added": 0,
"datasets_added": 0
}
results["models_added"] = await self._apply_model_configs(
template_data.get("model_configs", []),
tenant_id,
control_panel_db
)
tenant_schema = f"tenant_{tenant.domain.replace('-', '_').replace('.', '_')}"
results["agents_added"] = await self._apply_agents(
template_data.get("agents", []),
tenant_schema
)
results["datasets_added"] = await self._apply_datasets(
template_data.get("datasets", []),
tenant_schema
)
logger.info(f"Template applied successfully: {results}")
return results
except Exception as e:
logger.error(f"Failed to apply template: {e}")
raise
async def _apply_model_configs(
self,
model_configs: List[Dict],
tenant_id: int,
db: AsyncSession
) -> int:
"""Apply model configurations to control panel DB"""
count = 0
for config in model_configs:
stmt = insert(TenantModelConfig).values(
tenant_id=tenant_id,
model_id=config["model_id"],
is_enabled=config.get("is_enabled", True),
rate_limits=config.get("rate_limits", {}),
usage_constraints=config.get("usage_constraints", {}),
priority=config.get("priority", 5),
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
).on_conflict_do_update(
index_elements=['tenant_id', 'model_id'],
set_={
'is_enabled': config.get("is_enabled", True),
'rate_limits': config.get("rate_limits", {}),
'updated_at': datetime.utcnow()
}
)
await db.execute(stmt)
count += 1
await db.commit()
logger.info(f"Applied {count} model configs")
return count
async def _apply_agents(
self,
agents: List[Dict],
tenant_schema: str
) -> int:
"""Apply agents to tenant DB"""
from asyncpg import connect
count = 0
conn = await connect(self.tenant_db_url)
try:
for agent in agents:
result = await conn.fetchrow(f"""
SELECT id FROM {tenant_schema}.tenants LIMIT 1
""")
tenant_id = result['id'] if result else None
result = await conn.fetchrow(f"""
SELECT id FROM {tenant_schema}.users LIMIT 1
""")
created_by = result['id'] if result else None
if not tenant_id or not created_by:
logger.warning(f"No tenant or user found in {tenant_schema}, skipping agents")
break
agent_id = str(uuid.uuid4())
await conn.execute(f"""
INSERT INTO {tenant_schema}.agents (
id, name, description, system_prompt, tenant_id, created_by,
model, temperature, max_tokens, visibility, configuration,
is_active, access_group, agent_type, disclaimer, easy_prompts,
created_at, updated_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, NOW(), NOW()
)
ON CONFLICT (id) DO NOTHING
""",
agent_id,
agent.get("name"),
agent.get("description"),
agent.get("system_prompt"),
tenant_id,
created_by,
agent.get("model"),
agent.get("temperature"),
agent.get("max_tokens"),
agent.get("visibility", "individual"),
agent.get("configuration", {}),
True,
"individual",
agent.get("agent_type", "conversational"),
agent.get("disclaimer"),
agent.get("easy_prompts", [])
)
count += 1
logger.info(f"Applied {count} agents to {tenant_schema}")
finally:
await conn.close()
return count
async def _apply_datasets(
self,
datasets: List[Dict],
tenant_schema: str
) -> int:
"""Apply datasets to tenant DB"""
from asyncpg import connect
count = 0
conn = await connect(self.tenant_db_url)
try:
for dataset in datasets:
result = await conn.fetchrow(f"""
SELECT id FROM {tenant_schema}.tenants LIMIT 1
""")
tenant_id = result['id'] if result else None
result = await conn.fetchrow(f"""
SELECT id FROM {tenant_schema}.users LIMIT 1
""")
created_by = result['id'] if result else None
if not tenant_id or not created_by:
logger.warning(f"No tenant or user found in {tenant_schema}, skipping datasets")
break
dataset_id = str(uuid.uuid4())
collection_name = f"dataset_{dataset_id.replace('-', '_')}"
await conn.execute(f"""
INSERT INTO {tenant_schema}.datasets (
id, name, description, tenant_id, created_by, collection_name,
document_count, total_size_bytes, embedding_model, visibility,
metadata, is_active, access_group, search_method,
specialized_language, chunk_size, chunk_overlap,
created_at, updated_at
) VALUES (
$1, $2, $3, $4, $5, $6, 0, 0, $7, $8, $9, $10, $11, $12, $13, $14, $15, NOW(), NOW()
)
ON CONFLICT (id) DO NOTHING
""",
dataset_id,
dataset.get("name"),
dataset.get("description"),
tenant_id,
created_by,
collection_name,
dataset.get("embedding_model", "BAAI/bge-m3"),
dataset.get("visibility", "individual"),
dataset.get("metadata", {}),
True,
"individual",
dataset.get("search_method", "hybrid"),
dataset.get("specialized_language", False),
dataset.get("chunk_size", 512),
dataset.get("chunk_overlap", 128)
)
count += 1
logger.info(f"Applied {count} datasets to {tenant_schema}")
finally:
await conn.close()
return count
async def export_tenant_as_template(
self,
tenant_id: int,
template_name: str,
template_description: str,
control_panel_db: AsyncSession
) -> TenantTemplate:
"""Export existing tenant configuration as a new template"""
try:
tenant = await control_panel_db.get(Tenant, tenant_id)
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
logger.info(f"Exporting tenant '{tenant.domain}' as template '{template_name}'")
result = await control_panel_db.execute(
select(TenantModelConfig).where(TenantModelConfig.tenant_id == tenant_id)
)
model_configs = result.scalars().all()
model_config_data = [
{
"model_id": mc.model_id,
"is_enabled": mc.is_enabled,
"rate_limits": mc.rate_limits,
"usage_constraints": mc.usage_constraints,
"priority": mc.priority
}
for mc in model_configs
]
tenant_schema = f"tenant_{tenant.domain.replace('-', '_').replace('.', '_')}"
from asyncpg import connect
conn = await connect(self.tenant_db_url)
try:
query = f"""
SELECT name, description, system_prompt, model, temperature, max_tokens,
visibility, configuration, agent_type, disclaimer, easy_prompts
FROM {tenant_schema}.agents
WHERE is_active = true
"""
logger.info(f"Executing agents query: {query}")
agents_data = await conn.fetch(query)
logger.info(f"Found {len(agents_data)} agents")
agents = [dict(row) for row in agents_data]
datasets_data = await conn.fetch(f"""
SELECT name, description, embedding_model, visibility, metadata,
search_method, specialized_language, chunk_size, chunk_overlap
FROM {tenant_schema}.datasets
WHERE is_active = true
LIMIT 10
""")
datasets = [dict(row) for row in datasets_data]
finally:
await conn.close()
template_data = {
"model_configs": model_config_data,
"agents": agents,
"datasets": datasets
}
new_template = TenantTemplate(
name=template_name,
description=template_description,
template_data=template_data,
is_default=False,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
control_panel_db.add(new_template)
await control_panel_db.commit()
await control_panel_db.refresh(new_template)
logger.info(f"Template '{template_name}' created successfully with ID {new_template.id}")
return new_template
except Exception as e:
logger.error(f"Failed to export tenant as template: {e}")
await control_panel_db.rollback()
raise
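# --- Illustrative sketch (not part of the original file) ---
# How a caller might export one tenant's configuration and apply it to another.
# The tenant ids, template name, and description are placeholders; the real
# call sites live in the control-panel API layer.
async def _example_clone_tenant_config(control_panel_db: AsyncSession) -> None:
    service = TemplateService()
    template = await service.export_tenant_as_template(
        tenant_id=1,
        template_name="engineering-baseline",
        template_description="Models, agents, and datasets from the reference tenant",
        control_panel_db=control_panel_db,
    )
    results = await service.apply_template(
        template_id=template.id,
        tenant_id=2,
        control_panel_db=control_panel_db,
    )
    logger.info(f"Applied template {template.id}: {results}")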

View File

@@ -0,0 +1,397 @@
"""
GT 2.0 Tenant Provisioning Service
Implements automated tenant infrastructure provisioning following GT 2.0 principles:
- File-based isolation with OS-level permissions
- Perfect tenant separation
- Zero downtime deployment
- Self-contained security
"""
import os
import asyncio
import logging
# DuckDB removed - PostgreSQL + PGVector unified storage
import json
import subprocess
from pathlib import Path
from typing import Dict, Any, Optional
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from app.models.tenant import Tenant
from app.core.config import get_settings
from app.services.message_bus import message_bus
logger = logging.getLogger(__name__)
settings = get_settings()
class TenantProvisioningService:
"""
Service for automated tenant infrastructure provisioning.
Follows GT 2.0 PostgreSQL + PGVector architecture principles:
- PostgreSQL schema per tenant (MVCC concurrency)
- PGVector embeddings per tenant (replaces ChromaDB)
- Database-level tenant isolation with RLS
- Encrypted data at rest
"""
def __init__(self):
self.base_data_path = Path("/data")
self.message_bus = message_bus
async def provision_tenant(self, tenant_id: int, db: AsyncSession) -> bool:
"""
Complete tenant provisioning process.
Args:
tenant_id: Database ID of tenant to provision
db: Database session
Returns:
True if successful, False otherwise
"""
try:
# Get tenant details
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
logger.error(f"Tenant {tenant_id} not found")
return False
logger.info(f"Starting provisioning for tenant {tenant.domain}")
# Step 1: Create tenant directory structure
await self._create_directory_structure(tenant)
# Step 2: Initialize PostgreSQL schema
await self._initialize_database(tenant)
# Step 3: Setup PGVector extensions (handled by schema creation)
# Step 4: Create configuration files
await self._create_configuration_files(tenant)
# Step 5: Setup OS user (for production)
await self._setup_os_user(tenant)
# Step 6: Send provisioning message to tenant cluster
await self._notify_tenant_cluster(tenant)
# Step 7: Update tenant status
await self._update_tenant_status(tenant_id, "active", db)
logger.info(f"Tenant {tenant.domain} provisioned successfully")
return True
except Exception as e:
logger.error(f"Failed to provision tenant {tenant_id}: {e}")
await self._update_tenant_status(tenant_id, "failed", db)
return False
async def _create_directory_structure(self, tenant: Tenant) -> None:
"""Create tenant directory structure with proper permissions"""
tenant_path = self.base_data_path / tenant.domain
# Create main directories
directories = [
tenant_path,
tenant_path / "shared",
tenant_path / "shared" / "models",
tenant_path / "shared" / "configs",
tenant_path / "users",
tenant_path / "sessions",
tenant_path / "documents",
tenant_path / "vector_storage",
tenant_path / "backups"
]
for directory in directories:
directory.mkdir(parents=True, exist_ok=True, mode=0o700)
logger.info(f"Created directory structure for {tenant.domain}")
async def _initialize_database(self, tenant: Tenant) -> None:
"""Initialize PostgreSQL schema for tenant"""
schema_name = f"tenant_{tenant.domain.replace('-', '_').replace('.', '_')}"
# PostgreSQL schema creation is handled by the main database migration scripts
# Schema name follows pattern: tenant_{domain}
logger.info(f"PostgreSQL schema initialization for {tenant.domain} handled by migration scripts")
async def _setup_vector_storage(self, tenant: Tenant) -> None:
"""Setup PGVector extensions for tenant (handled by PostgreSQL migration)"""
# PGVector extensions handled by PostgreSQL migration scripts
# Vector storage is now unified within PostgreSQL schema
logger.info(f"PGVector setup for {tenant.domain} handled by PostgreSQL migration scripts")
async def _create_configuration_files(self, tenant: Tenant) -> None:
"""Create tenant-specific configuration files"""
tenant_path = self.base_data_path / tenant.domain
config_path = tenant_path / "shared" / "configs"
# Main tenant configuration
tenant_config = {
"tenant_id": tenant.uuid,
"tenant_domain": tenant.domain,
"tenant_name": tenant.name,
"template": tenant.template,
"max_users": tenant.max_users,
"resource_limits": tenant.resource_limits,
"postgresql_schema": f"tenant_{tenant.domain.replace('-', '_').replace('.', '_')}",
"vector_storage_path": str(tenant_path / "vector_storage"),
"documents_path": str(tenant_path / "documents"),
"created_at": datetime.utcnow().isoformat(),
"encryption_enabled": True,
"backup_enabled": True
}
config_file = config_path / "tenant_config.json"
with open(config_file, 'w') as f:
json.dump(tenant_config, f, indent=2)
os.chmod(config_file, 0o600)
# Environment file for tenant backend
tenant_db_password = os.environ["TENANT_POSTGRES_PASSWORD"]
env_config = f"""
# GT 2.0 Tenant Configuration - {tenant.domain}
ENVIRONMENT=production
TENANT_ID={tenant.uuid}
TENANT_DOMAIN={tenant.domain}
DATABASE_URL=postgresql://gt2_tenant_user:{tenant_db_password}@tenant-pgbouncer:5432/gt2_tenants
POSTGRES_SCHEMA=tenant_{tenant.domain.replace('-', '_').replace('.', '_')}
DOCUMENTS_PATH={tenant_path}/documents
# Security
SECRET_KEY=will_be_replaced_with_vault_key
ENCRYPT_DATA=true
SECURE_DELETE=true
# Resource Limits
MAX_USERS={tenant.max_users}
MAX_STORAGE_GB={tenant.resource_limits.get('max_storage_gb', 100)}
MAX_API_CALLS_PER_HOUR={tenant.resource_limits.get('max_api_calls_per_hour', 1000)}
# Integration
CONTROL_PANEL_URL=http://control-panel-backend:8001
RESOURCE_CLUSTER_URL=http://resource-cluster:8004
"""
# Write tenant environment configuration file
# Security Note: This file contains tenant-specific configuration values (URLs, limits),
# not sensitive credentials like API keys or passwords. File permissions are set to 0o600
# (owner read/write only) for defense in depth. Actual secrets are stored securely in the
# database and accessed via the Control Panel API.
env_file = config_path / "tenant.env"
with open(env_file, 'w') as f:
f.write(env_config)
os.chmod(env_file, 0o600)
logger.info(f"Created configuration files for {tenant.domain}")
async def _setup_os_user(self, tenant: Tenant) -> None:
"""Create OS user for tenant (production only)"""
if settings.environment == "development":
logger.info(f"Skipping OS user creation in development for {tenant.domain}")
return
try:
# Create system user for tenant
username = f"gt-{tenant.domain}"
tenant_path = self.base_data_path / tenant.domain
# Check if user already exists
result = subprocess.run(
["id", username],
capture_output=True,
text=True
)
if result.returncode != 0:
# Create user
subprocess.run([
"useradd",
"--system",
"--home-dir", str(tenant_path),
"--shell", "/usr/sbin/nologin",
"--comment", f"GT 2.0 Tenant {tenant.domain}",
username
], check=True)
logger.info(f"Created OS user {username}")
# Set ownership
subprocess.run([
"chown", "-R", f"{username}:{username}", str(tenant_path)
], check=True)
logger.info(f"Set ownership for {tenant.domain}")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to setup OS user for {tenant.domain}: {e}")
# Don't fail the entire provisioning for this
async def _notify_tenant_cluster(self, tenant: Tenant) -> None:
"""Send provisioning message to tenant cluster via RabbitMQ"""
try:
message = {
"action": "tenant_provisioned",
"tenant_id": tenant.uuid,
"tenant_domain": tenant.domain,
"namespace": tenant.namespace,
"config_path": f"/data/{tenant.domain}/shared/configs/tenant_config.json",
"timestamp": datetime.utcnow().isoformat()
}
await self.message_bus.send_tenant_command(
command_type="tenant_provisioned",
tenant_namespace=tenant.namespace,
payload=message
)
logger.info(f"Sent provisioning notification for {tenant.domain}")
except Exception as e:
logger.error(f"Failed to notify tenant cluster for {tenant.domain}: {e}")
# Don't fail provisioning for this
async def _update_tenant_status(self, tenant_id: int, status: str, db: AsyncSession) -> None:
"""Update tenant status in database"""
try:
await db.execute(
update(Tenant)
.where(Tenant.id == tenant_id)
.values(
status=status,
updated_at=datetime.utcnow()
)
)
await db.commit()
except Exception as e:
logger.error(f"Failed to update tenant status: {e}")
async def deprovision_tenant(self, tenant_id: int, db: AsyncSession) -> bool:
"""
Safely deprovision tenant (archive data, don't delete).
Args:
tenant_id: Database ID of tenant to deprovision
db: Database session
Returns:
True if successful, False otherwise
"""
try:
# Get tenant details
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
logger.error(f"Tenant {tenant_id} not found")
return False
logger.info(f"Starting deprovisioning for tenant {tenant.domain}")
# Step 1: Create backup
await self._create_tenant_backup(tenant)
# Step 2: Notify tenant cluster to stop services
await self._notify_tenant_shutdown(tenant)
# Step 3: Archive data (don't delete)
await self._archive_tenant_data(tenant)
# Step 4: Update status
await self._update_tenant_status(tenant_id, "archived", db)
logger.info(f"Tenant {tenant.domain} deprovisioned successfully")
return True
except Exception as e:
logger.error(f"Failed to deprovision tenant {tenant_id}: {e}")
return False
async def _create_tenant_backup(self, tenant: Tenant) -> None:
"""Create complete backup of tenant data"""
tenant_path = self.base_data_path / tenant.domain
backup_path = tenant_path / "backups" / f"full_backup_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.tar.gz"
# Create compressed backup
subprocess.run([
"tar", "-czf", str(backup_path),
"-C", str(tenant_path.parent),
tenant.domain,
"--exclude", "backups"
], check=True)
logger.info(f"Created backup for {tenant.domain}: {backup_path}")
async def _notify_tenant_shutdown(self, tenant: Tenant) -> None:
"""Notify tenant cluster to shutdown services"""
try:
message = {
"action": "tenant_shutdown",
"tenant_id": tenant.uuid,
"tenant_domain": tenant.domain,
"timestamp": datetime.utcnow().isoformat()
}
await self.message_bus.send_tenant_command(
command_type="tenant_shutdown",
tenant_namespace=tenant.namespace,
payload=message
)
except Exception as e:
logger.error(f"Failed to notify tenant shutdown: {e}")
async def _archive_tenant_data(self, tenant: Tenant) -> None:
"""Archive tenant data (rename directory)"""
tenant_path = self.base_data_path / tenant.domain
archive_path = self.base_data_path / f"{tenant.domain}_archived_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
if tenant_path.exists():
tenant_path.rename(archive_path)
logger.info(f"Archived tenant data: {archive_path}")
# Background task functions for FastAPI
async def deploy_tenant_infrastructure(tenant_id: int) -> None:
"""Background task to deploy tenant infrastructure"""
from app.core.database import get_db_session
provisioning_service = TenantProvisioningService()
async with get_db_session() as db:
success = await provisioning_service.provision_tenant(tenant_id, db)
if success:
logger.info(f"Tenant {tenant_id} provisioned successfully")
else:
logger.error(f"Failed to provision tenant {tenant_id}")
async def archive_tenant_infrastructure(tenant_id: int) -> None:
"""Background task to archive tenant infrastructure"""
from app.core.database import get_db_session
provisioning_service = TenantProvisioningService()
async with get_db_session() as db:
success = await provisioning_service.deprovision_tenant(tenant_id, db)
if success:
logger.info(f"Tenant {tenant_id} archived successfully")
else:
logger.error(f"Failed to archive tenant {tenant_id}")

View File

@@ -0,0 +1,525 @@
"""
Update Service - Manages system updates and version checking
"""
import os
import json
import asyncio
import httpx
from typing import Dict, Any, Optional, List
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
from fastapi import HTTPException, status
import structlog
from app.models.system import SystemVersion, UpdateJob, UpdateStatus, BackupRecord
from app.services.backup_service import BackupService
logger = structlog.get_logger()
class UpdateService:
"""Service for checking and executing system updates"""
GITHUB_API_BASE = "https://api.github.com"
REPO_OWNER = "GT-Edge-AI-Internal"
REPO_NAME = "gt-ai-os-community"
DEPLOY_SCRIPT = "/app/scripts/deploy.sh"
ROLLBACK_SCRIPT = "/app/scripts/rollback.sh"
MIN_DISK_SPACE_GB = 5
def __init__(self, db: AsyncSession):
self.db = db
async def check_for_updates(self) -> Dict[str, Any]:
"""Check GitHub for available updates"""
try:
# Get current version
current_version = await self._get_current_version()
# Query GitHub releases API
url = f"{self.GITHUB_API_BASE}/repos/{self.REPO_OWNER}/{self.REPO_NAME}/releases/latest"
async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
response = await client.get(url)
if response.status_code == 404:
logger.warning("No releases found in repository")
return {
"update_available": False,
"current_version": current_version,
"latest_version": None,
"release_notes": None,
"published_at": None,
"download_url": None,
"checked_at": datetime.utcnow().isoformat()
}
if response.status_code != 200:
logger.error(f"GitHub API error: {response.status_code}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to check for updates from GitHub"
)
release_data = response.json()
latest_version = release_data.get("tag_name", "").lstrip("v")
release_notes = release_data.get("body", "")
published_at = release_data.get("published_at")
update_available = self._is_newer_version(latest_version, current_version)
update_type = self._determine_update_type(latest_version, current_version) if update_available else None
return {
"update_available": update_available,
"available": update_available, # Alias for frontend compatibility
"current_version": current_version,
"latest_version": latest_version,
"update_type": update_type,
"release_notes": release_notes,
"published_at": published_at,
"released_at": published_at, # Alias for frontend compatibility
"download_url": release_data.get("html_url"),
"checked_at": datetime.utcnow().isoformat()
}
except httpx.RequestError as e:
logger.error(f"Network error checking for updates: {str(e)}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Network error while checking for updates"
)
except Exception as e:
logger.error(f"Error checking for updates: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to check for updates: {str(e)}"
)
async def validate_update(self, target_version: str) -> Dict[str, Any]:
"""Run pre-update validation checks"""
validation_results = {
"valid": True,
"checks": [],
"warnings": [],
"errors": []
}
# Check 1: Disk space
disk_check = await self._check_disk_space()
validation_results["checks"].append(disk_check)
if not disk_check["passed"]:
validation_results["valid"] = False
validation_results["errors"].append(disk_check["message"])
# Check 2: Container health
container_check = await self._check_container_health()
validation_results["checks"].append(container_check)
if not container_check["passed"]:
validation_results["valid"] = False
validation_results["errors"].append(container_check["message"])
# Check 3: Database connectivity
db_check = await self._check_database_connectivity()
validation_results["checks"].append(db_check)
if not db_check["passed"]:
validation_results["valid"] = False
validation_results["errors"].append(db_check["message"])
# Check 4: Recent backup exists
backup_check = await self._check_recent_backup()
validation_results["checks"].append(backup_check)
if not backup_check["passed"]:
validation_results["warnings"].append(backup_check["message"])
# Check 5: No running updates
running_update = await self._check_running_updates()
if running_update:
validation_results["valid"] = False
validation_results["errors"].append(
f"Update job {running_update} is already in progress"
)
return validation_results
async def execute_update(
self,
target_version: str,
create_backup: bool = True,
started_by: Optional[str] = None
) -> str:
"""Execute system update"""
# Create update job
update_job = UpdateJob(
target_version=target_version,
status=UpdateStatus.pending,
started_by=started_by
)
update_job.add_log(f"Update to version {target_version} initiated", "info")
self.db.add(update_job)
await self.db.commit()
await self.db.refresh(update_job)
job_uuid = update_job.uuid
# Start update in background
asyncio.create_task(self._run_update_process(job_uuid, target_version, create_backup))
logger.info(f"Update job {job_uuid} created for version {target_version}")
return job_uuid
async def get_update_status(self, update_id: str) -> Dict[str, Any]:
"""Get current status of an update job"""
stmt = select(UpdateJob).where(UpdateJob.uuid == update_id)
result = await self.db.execute(stmt)
update_job = result.scalar_one_or_none()
if not update_job:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Update job {update_id} not found"
)
return update_job.to_dict()
async def rollback(self, update_id: str, reason: Optional[str] = None) -> Dict[str, Any]:
"""Rollback a failed update"""
stmt = select(UpdateJob).where(UpdateJob.uuid == update_id)
result = await self.db.execute(stmt)
update_job = result.scalar_one_or_none()
if not update_job:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Update job {update_id} not found"
)
if update_job.status not in [UpdateStatus.failed, UpdateStatus.in_progress]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot rollback update in status: {update_job.status}"
)
update_job.rollback_reason = reason or "Manual rollback requested"
update_job.add_log(f"Rollback initiated: {update_job.rollback_reason}", "warning")
await self.db.commit()
# Execute rollback in background
asyncio.create_task(self._run_rollback_process(update_id))
return {"message": "Rollback initiated", "update_id": update_id}
async def _run_update_process(
self,
job_uuid: str,
target_version: str,
create_backup: bool
):
"""Background task to run update process"""
try:
# Reload job from database
stmt = select(UpdateJob).where(UpdateJob.uuid == job_uuid)
result = await self.db.execute(stmt)
update_job = result.scalar_one_or_none()
if not update_job:
logger.error(f"Update job {job_uuid} not found")
return
update_job.status = UpdateStatus.in_progress
await self.db.commit()
# Stage 1: Create pre-update backup
if create_backup:
update_job.current_stage = "creating_backup"
update_job.add_log("Creating pre-update backup", "info")
await self.db.commit()
backup_service = BackupService(self.db)
backup_result = await backup_service.create_backup(
backup_type="pre_update",
description=f"Pre-update backup before upgrading to {target_version}"
)
update_job.backup_id = backup_result["id"]
update_job.add_log(f"Backup created: {backup_result['uuid']}", "info")
await self.db.commit()
# Stage 2: Execute deploy script
update_job.current_stage = "executing_update"
update_job.add_log(f"Running deploy script for version {target_version}", "info")
await self.db.commit()
# Run deploy.sh script
process = await asyncio.create_subprocess_exec(
self.DEPLOY_SCRIPT,
target_version,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
# Success
update_job.status = UpdateStatus.completed
update_job.current_stage = "completed"
update_job.completed_at = datetime.utcnow()
update_job.add_log(f"Update to {target_version} completed successfully", "info")
# Record new version
await self._record_version(target_version, update_job.started_by)
else:
# Failure
update_job.status = UpdateStatus.failed
update_job.current_stage = "failed"
update_job.completed_at = datetime.utcnow()
error_msg = stderr.decode() if stderr else "Unknown error"
update_job.error_message = error_msg
update_job.add_log(f"Update failed: {error_msg}", "error")
await self.db.commit()
except Exception as e:
logger.error(f"Update process error: {str(e)}")
stmt = select(UpdateJob).where(UpdateJob.uuid == job_uuid)
result = await self.db.execute(stmt)
update_job = result.scalar_one_or_none()
if update_job:
update_job.status = UpdateStatus.failed
update_job.error_message = str(e)
update_job.completed_at = datetime.utcnow()
update_job.add_log(f"Update process exception: {str(e)}", "error")
await self.db.commit()
async def _run_rollback_process(self, job_uuid: str):
"""Background task to run rollback process"""
try:
stmt = select(UpdateJob).where(UpdateJob.uuid == job_uuid)
result = await self.db.execute(stmt)
update_job = result.scalar_one_or_none()
if not update_job:
logger.error(f"Update job {job_uuid} not found")
return
update_job.current_stage = "rolling_back"
update_job.add_log("Executing rollback script", "warning")
await self.db.commit()
# Run rollback script
process = await asyncio.create_subprocess_exec(
self.ROLLBACK_SCRIPT,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
update_job.status = UpdateStatus.rolled_back
update_job.current_stage = "rolled_back"
update_job.completed_at = datetime.utcnow()
update_job.add_log("Rollback completed successfully", "info")
else:
error_msg = stderr.decode() if stderr else "Unknown error"
update_job.add_log(f"Rollback failed: {error_msg}", "error")
await self.db.commit()
except Exception as e:
logger.error(f"Rollback process error: {str(e)}")
async def _get_current_version(self) -> str:
"""Get currently installed version"""
stmt = select(SystemVersion).where(
SystemVersion.is_current == True
).order_by(desc(SystemVersion.installed_at)).limit(1)
result = await self.db.execute(stmt)
current = result.scalar_one_or_none()
return current.version if current else "unknown"
async def _record_version(self, version: str, installed_by: str):
"""Record new system version"""
# Mark all versions as not current
stmt = select(SystemVersion).where(SystemVersion.is_current == True)
result = await self.db.execute(stmt)
old_versions = result.scalars().all()
for old_version in old_versions:
old_version.is_current = False
# Create new version record
new_version = SystemVersion(
version=version,
installed_by=installed_by,
is_current=True
)
self.db.add(new_version)
await self.db.commit()
def _is_newer_version(self, latest: str, current: str) -> bool:
"""Compare version strings"""
try:
latest_parts = [int(x) for x in latest.split(".")]
current_parts = [int(x) for x in current.split(".")]
# Pad shorter version with zeros
max_len = max(len(latest_parts), len(current_parts))
latest_parts += [0] * (max_len - len(latest_parts))
current_parts += [0] * (max_len - len(current_parts))
return latest_parts > current_parts
except (ValueError, AttributeError):
return False
def _determine_update_type(self, latest: str, current: str) -> str:
"""Determine if update is major, minor, or patch"""
try:
latest_parts = [int(x) for x in latest.split(".")]
current_parts = [int(x) for x in current.split(".")]
# Pad to at least 3 parts for comparison
while len(latest_parts) < 3:
latest_parts.append(0)
while len(current_parts) < 3:
current_parts.append(0)
if latest_parts[0] > current_parts[0]:
return "major"
elif latest_parts[1] > current_parts[1]:
return "minor"
else:
return "patch"
except (ValueError, IndexError, AttributeError):
return "patch"
async def _check_disk_space(self) -> Dict[str, Any]:
"""Check available disk space"""
try:
stat = os.statvfs("/")
free_gb = (stat.f_bavail * stat.f_frsize) / (1024 ** 3)
passed = free_gb >= self.MIN_DISK_SPACE_GB
return {
"name": "disk_space",
"passed": passed,
"message": f"Available disk space: {free_gb:.2f} GB (minimum: {self.MIN_DISK_SPACE_GB} GB)",
"details": {"free_gb": round(free_gb, 2)}
}
except Exception as e:
return {
"name": "disk_space",
"passed": False,
"message": f"Failed to check disk space: {str(e)}",
"details": {}
}
async def _check_container_health(self) -> Dict[str, Any]:
"""Check Docker container health"""
try:
# Run docker ps to check container status
process = await asyncio.create_subprocess_exec(
"docker", "ps", "--format", "{{.Names}}|{{.Status}}",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
return {
"name": "container_health",
"passed": False,
"message": "Failed to check container status",
"details": {"error": stderr.decode()}
}
containers = stdout.decode().strip().split("\n")
unhealthy = [c for c in containers if "unhealthy" in c.lower()]
return {
"name": "container_health",
"passed": len(unhealthy) == 0,
"message": f"Container health check: {len(containers)} running, {len(unhealthy)} unhealthy",
"details": {"total": len(containers), "unhealthy": len(unhealthy)}
}
except Exception as e:
return {
"name": "container_health",
"passed": False,
"message": f"Failed to check container health: {str(e)}",
"details": {}
}
async def _check_database_connectivity(self) -> Dict[str, Any]:
"""Check database connection"""
try:
await self.db.execute(select(1))
return {
"name": "database_connectivity",
"passed": True,
"message": "Database connection healthy",
"details": {}
}
except Exception as e:
return {
"name": "database_connectivity",
"passed": False,
"message": f"Database connection failed: {str(e)}",
"details": {}
}
async def _check_recent_backup(self) -> Dict[str, Any]:
"""Check if a recent backup exists"""
try:
from datetime import timedelta
from app.models.system import BackupRecord
one_day_ago = datetime.utcnow() - timedelta(days=1)
stmt = select(BackupRecord).where(
and_(
BackupRecord.created_at >= one_day_ago,
BackupRecord.is_valid == True
)
).order_by(desc(BackupRecord.created_at)).limit(1)
result = await self.db.execute(stmt)
recent_backup = result.scalar_one_or_none()
if recent_backup:
return {
"name": "recent_backup",
"passed": True,
"message": f"Recent backup found: {recent_backup.uuid}",
"details": {"backup_id": recent_backup.id, "created_at": recent_backup.created_at.isoformat()}
}
else:
return {
"name": "recent_backup",
"passed": False,
"message": "No backup found within last 24 hours",
"details": {}
}
except Exception as e:
return {
"name": "recent_backup",
"passed": False,
"message": f"Failed to check for recent backups: {str(e)}",
"details": {}
}
async def _check_running_updates(self) -> Optional[str]:
"""Check for running update jobs"""
stmt = select(UpdateJob.uuid).where(
UpdateJob.status == UpdateStatus.in_progress
).limit(1)
result = await self.db.execute(stmt)
running = result.scalar_one_or_none()
return running
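# --- Illustrative sketch (not part of the original file) ---
# End-to-end update flow a caller might run against UpdateService: check for a
# release, validate preconditions, kick off the update, then poll the job. The
# polling interval and admin identity are placeholders, and the loop assumes
# UpdateJob.to_dict() exposes the status enum as its string value.
async def _example_run_update(db: AsyncSession) -> None:
    service = UpdateService(db)
    check = await service.check_for_updates()
    if not check["update_available"]:
        return
    validation = await service.validate_update(check["latest_version"])
    if not validation["valid"]:
        logger.error("Pre-update validation failed", errors=validation["errors"])
        return
    job_id = await service.execute_update(
        target_version=check["latest_version"],
        create_backup=True,
        started_by="admin",  # placeholder identity
    )
    status_info = await service.get_update_status(job_id)
    while status_info.get("status") in ("pending", "in_progress"):
        await asyncio.sleep(10)
        status_info = await service.get_update_status(job_id)
    logger.info("Update finished", final_status=status_info.get("status"))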