GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
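For reference, SSRF protection with DNS resolution checking (the second bullet above) typically follows the pattern in this minimal sketch. The helper below is illustrative, not necessarily the code shipped in this release:

# Hedged sketch: resolve the hostname and reject private/reserved ranges
# before fetching. is_safe_url() is a hypothetical name for illustration.
import ipaddress
import socket
from urllib.parse import urlparse

def is_safe_url(url: str) -> bool:
    hostname = urlparse(url).hostname
    if not hostname:
        return False
    try:
        # Resolve every address the name maps to (A and AAAA records)
        infos = socket.getaddrinfo(hostname, None)
    except socket.gaierror:
        return False
    for info in infos:
        ip = ipaddress.ip_address(info[4][0])
        # Block loopback, link-local, private, and reserved ranges
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            return False
    return True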
apps/resource-cluster/app/services/__init__.py (3 lines, new file)
@@ -0,0 +1,3 @@
"""
Service layer for Resource Cluster
"""
apps/resource-cluster/app/services/admin_model_config_service.py (342 lines, new file)
@@ -0,0 +1,342 @@
"""
Admin Model Configuration Service for GT 2.0 Resource Cluster

This service fetches model configurations from the Admin Control Panel
and provides them to the Resource Cluster for LLM routing and capabilities.
"""

import asyncio
import logging
import httpx
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass
import json

from app.core.config import get_settings

logger = logging.getLogger(__name__)


@dataclass
class AdminModelConfig:
    """Model configuration from admin cluster"""
    uuid: str  # Database UUID - unique identifier for this model config
    model_id: str  # Business identifier - the model name used in API calls
    name: str
    provider: str
    model_type: str
    endpoint: str
    api_key_name: Optional[str]
    context_window: Optional[int]
    max_tokens: Optional[int]
    capabilities: Dict[str, Any]
    cost_per_1k_input: float
    cost_per_1k_output: float
    is_active: bool
    tenant_restrictions: Dict[str, Any]
    required_capabilities: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for LLM Gateway"""
        return {
            "uuid": self.uuid,
            "model_id": self.model_id,
            "name": self.name,
            "provider": self.provider,
            "model_type": self.model_type,
            "endpoint": self.endpoint,
            "api_key_name": self.api_key_name,
            "context_window": self.context_window,
            "max_tokens": self.max_tokens,
            "capabilities": self.capabilities,
            "cost_per_1k_input": self.cost_per_1k_input,
            "cost_per_1k_output": self.cost_per_1k_output,
            "is_active": self.is_active,
            "tenant_restrictions": self.tenant_restrictions,
            "required_capabilities": self.required_capabilities
        }


class AdminModelConfigService:
    """Service for fetching model configurations from Admin Control Panel"""

    def __init__(self):
        self.settings = get_settings()
        self._model_cache: Dict[str, AdminModelConfig] = {}  # model_id -> config
        self._uuid_cache: Dict[str, AdminModelConfig] = {}  # uuid -> config (for UUID-based lookups)
        self._tenant_model_cache: Dict[str, List[str]] = {}  # tenant_id -> list of allowed model_ids
        self._last_sync: datetime = datetime.min
        self._sync_interval = timedelta(seconds=self.settings.config_sync_interval)
        self._sync_lock = asyncio.Lock()

    async def get_model_config(self, model_id: str) -> Optional[AdminModelConfig]:
        """Get configuration for a specific model by model_id string"""
        await self._ensure_fresh_cache()
        return self._model_cache.get(model_id)

    async def get_model_by_uuid(self, uuid: str) -> Optional[AdminModelConfig]:
        """Get configuration for a specific model by database UUID"""
        await self._ensure_fresh_cache()
        return self._uuid_cache.get(uuid)

    async def get_all_models(self, active_only: bool = True) -> List[AdminModelConfig]:
        """Get all model configurations"""
        await self._ensure_fresh_cache()
        models = list(self._model_cache.values())
        if active_only:
            models = [m for m in models if m.is_active]
        return models

    async def get_tenant_models(self, tenant_id: str) -> List[AdminModelConfig]:
        """Get models available to a specific tenant"""
        await self._ensure_fresh_cache()

        # Get tenant's allowed model IDs - try multiple formats
        allowed_model_ids = self._get_tenant_model_ids(tenant_id)

        # Return model configs for allowed models
        models = []
        for model_id in allowed_model_ids:
            if model_id in self._model_cache and self._model_cache[model_id].is_active:
                models.append(self._model_cache[model_id])

        return models

    async def check_tenant_access(self, tenant_id: str, model_id: str) -> bool:
        """Check if a tenant has access to a specific model"""
        await self._ensure_fresh_cache()

        # Check if model exists and is active
        model_config = self._model_cache.get(model_id)
        if not model_config or not model_config.is_active:
            return False

        # Only use tenant-specific access (no global access)
        # This enforces proper tenant model assignments
        allowed_models = self._get_tenant_model_ids(tenant_id)
        return model_id in allowed_models

    def _get_tenant_model_ids(self, tenant_id: str) -> List[str]:
        """Get model IDs for tenant, handling multiple tenant ID formats"""
        # Copy the cached list so the extend() below cannot mutate the cache.
        # Try exact match first (e.g., "test-company")
        allowed_models = list(self._tenant_model_cache.get(tenant_id, []))

        if not allowed_models:
            # Try converting "test-company" to "test" format
            if "-" in tenant_id:
                domain_format = tenant_id.split("-")[0]
                allowed_models = list(self._tenant_model_cache.get(domain_format, []))

            # Try converting "test" to "test-company" format
            elif tenant_id + "-company" in self._tenant_model_cache:
                allowed_models = list(self._tenant_model_cache.get(tenant_id + "-company", []))

            # Also try tenant_id as numeric string
            for key, models in self._tenant_model_cache.items():
                if key.isdigit() and tenant_id in key:
                    allowed_models.extend(models)
                    break

        logger.debug(f"Tenant {tenant_id} has access to models: {allowed_models}")
        return allowed_models

    async def get_groq_api_key(self, tenant_id: Optional[str] = None) -> Optional[str]:
        """
        Get Groq API key for a tenant from Control Panel database.

        NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
        API keys are managed in Control Panel and fetched via internal API.

        Args:
            tenant_id: Tenant domain string (required for tenant requests)

        Returns:
            Decrypted Groq API key

        Raises:
            ValueError: If no API key is configured for the tenant
        """
        if not tenant_id:
            raise ValueError("tenant_id is required to fetch Groq API key - no fallback to environment variables")

        from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError

        client = get_api_key_client()

        try:
            key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
            return key_info["api_key"]
        except APIKeyNotConfiguredError as e:
            logger.error(f"No Groq API key configured for tenant '{tenant_id}': {e}")
            raise ValueError(f"No Groq API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
        except RuntimeError as e:
            logger.error(f"Control Panel API error when fetching API key: {e}")
            raise ValueError(f"Unable to retrieve API key - Control Panel service unavailable: {e}")

    async def _ensure_fresh_cache(self):
        """Ensure model cache is fresh, sync if needed"""
        now = datetime.utcnow()
        if now - self._last_sync > self._sync_interval:
            async with self._sync_lock:
                # Double-check after acquiring lock
                now = datetime.utcnow()
                if now - self._last_sync <= self._sync_interval:
                    return

                await self._sync_from_admin()

    async def _sync_from_admin(self):
        """Sync model configurations from admin cluster"""
        try:
            # Use correct URL for containerized environment
            import os
            if os.path.exists('/.dockerenv'):
                admin_url = "http://control-panel-backend:8000"
            else:
                admin_url = self.settings.admin_cluster_url.rstrip('/')

            async with httpx.AsyncClient(timeout=30.0) as client:
                # Fetch all model configurations
                models_response = await client.get(
                    f"{admin_url}/api/v1/models/?active_only=true&include_stats=true"
                )

                # Fetch tenant model assignments with proper authentication
                tenant_models_response = await client.get(
                    f"{admin_url}/api/v1/tenant-models/tenants/all",
                    headers={
                        "Authorization": "Bearer admin-dev-token",
                        "Content-Type": "application/json"
                    }
                )

                if models_response.status_code == 200:
                    models_data = models_response.json()
                    if models_data and len(models_data) > 0:
                        await self._update_model_cache(models_data)
                        logger.info(f"Successfully synced {len(models_data)} models from admin cluster")

                        # Update tenant model assignments if available
                        if tenant_models_response.status_code == 200:
                            tenant_data = tenant_models_response.json()
                            if tenant_data and len(tenant_data) > 0:
                                await self._update_tenant_cache(tenant_data)
                                logger.info(f"Successfully synced {len(tenant_data)} tenant model assignments")
                            else:
                                logger.warning("No tenant model assignments found")
                        else:
                            logger.error(f"Failed to fetch tenant assignments: {tenant_models_response.status_code}")
                            # Log the actual error for debugging
                            try:
                                error_response = tenant_models_response.json()
                                logger.error(f"Tenant assignments error: {error_response}")
                            except Exception:
                                logger.error(f"Tenant assignments error text: {tenant_models_response.text}")

                        self._last_sync = datetime.utcnow()
                        return
                    else:
                        logger.warning("Admin cluster returned empty model list")
                else:
                    logger.warning(f"Failed to fetch models from admin cluster: {models_response.status_code}")

                logger.info("No models configured in admin backend")
                self._last_sync = datetime.utcnow()
                logger.info(f"Loaded {len(self._model_cache)} models successfully")

        except Exception as e:
            logger.error(f"Failed to sync from admin cluster: {e}")

        # Log final state - no fallback models
        if not self._model_cache:
            logger.warning("No models available - admin backend has no models configured")

    async def _update_model_cache(self, models_data: List[Dict[str, Any]]):
        """Update model configuration cache"""
        new_cache = {}
        new_uuid_cache = {}

        for model_data in models_data:
            try:
                specs = model_data.get("specifications", {})
                cost = model_data.get("cost", {})
                status = model_data.get("status", {})

                # Get UUID from 'id' field in API response (Control Panel returns UUID as 'id')
                model_uuid = model_data.get("id", "")

                model_config = AdminModelConfig(
                    uuid=model_uuid,
                    model_id=model_data["model_id"],
                    name=model_data.get("name", model_data["model_id"]),
                    provider=model_data["provider"],
                    model_type=model_data["model_type"],
                    endpoint=model_data.get("endpoint", ""),
                    api_key_name=model_data.get("api_key_name"),
                    context_window=specs.get("context_window"),
                    max_tokens=specs.get("max_tokens"),
                    capabilities=model_data.get("capabilities", {}),
                    cost_per_1k_input=cost.get("per_1k_input", 0.0),
                    cost_per_1k_output=cost.get("per_1k_output", 0.0),
                    is_active=status.get("is_active", False),
                    tenant_restrictions=model_data.get("tenant_restrictions", {"global_access": True}),
                    required_capabilities=model_data.get("required_capabilities", [])
                )

                new_cache[model_config.model_id] = model_config

                # Also index by UUID for UUID-based lookups
                if model_uuid:
                    new_uuid_cache[model_uuid] = model_config

            except Exception as e:
                logger.error(f"Failed to parse model config {model_data.get('model_id', 'unknown')}: {e}")

        self._model_cache = new_cache
        self._uuid_cache = new_uuid_cache

    async def _update_tenant_cache(self, tenant_data: List[Dict[str, Any]]):
        """Update tenant model access cache from tenant-models endpoint"""
        new_tenant_cache = {}

        for assignment in tenant_data:
            try:
                # The tenant-models endpoint returns a different format than the old endpoint
                tenant_domain = assignment.get("tenant_domain", "")
                model_id = assignment["model_id"]
                is_enabled = assignment.get("is_enabled", True)

                if is_enabled and tenant_domain:
                    if tenant_domain not in new_tenant_cache:
                        new_tenant_cache[tenant_domain] = []
                    new_tenant_cache[tenant_domain].append(model_id)

                    # Also add by tenant_id for backward compatibility
                    tenant_id = str(assignment.get("tenant_id", ""))
                    if tenant_id and tenant_id not in new_tenant_cache:
                        new_tenant_cache[tenant_id] = []
                    if tenant_id:
                        new_tenant_cache[tenant_id].append(model_id)

            except Exception as e:
                logger.error(f"Failed to parse tenant assignment: {e}")

        self._tenant_model_cache = new_tenant_cache
        logger.debug(f"Updated tenant cache: {self._tenant_model_cache}")

    async def force_sync(self):
        """Force immediate sync from admin cluster"""
        self._last_sync = datetime.min
        await self._ensure_fresh_cache()


# Global instance
_admin_model_service = None


def get_admin_model_service() -> AdminModelConfigService:
    """Get singleton admin model service"""
    global _admin_model_service
    if _admin_model_service is None:
        _admin_model_service = AdminModelConfigService()
    return _admin_model_service
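A minimal usage sketch for the service above, assuming it runs inside the Resource Cluster's event loop; the tenant and model IDs are illustrative:

# Hedged example: gate an LLM call on tenant access, then read the config.
from app.services.admin_model_config_service import get_admin_model_service

async def route_request(tenant_id: str, model_id: str):
    service = get_admin_model_service()
    # check_tenant_access() transparently refreshes the cache when stale
    if not await service.check_tenant_access(tenant_id, model_id):
        raise PermissionError(f"Tenant {tenant_id} cannot use {model_id}")
    config = await service.get_model_config(model_id)
    return config.to_dict()  # shape consumed by the LLM Gateway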
apps/resource-cluster/app/services/agent_orchestrator.py (931 lines, new file)
@@ -0,0 +1,931 @@
"""
Agent Orchestration System for GT 2.0 Resource Cluster

Provides multi-agent workflow execution with:
- Sequential, parallel, and conditional agent workflows
- Inter-agent communication and memory management
- Capability-based access control
- Agent lifecycle management
- Performance monitoring and metrics

GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Agent sessions isolated per tenant
- Zero Downtime: Stateless design, resumable workflows
- Self-Contained Security: Capability-based agent permissions
- No Complexity Addition: Simple orchestration patterns
"""

import asyncio
import logging
import json
import time
import uuid
from typing import Dict, Any, List, Optional, Union, Callable, Coroutine
from datetime import datetime, timedelta
from enum import Enum
from dataclasses import dataclass, asdict
import traceback

from app.core.capability_auth import verify_capability_token, CapabilityError
from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()


class AgentStatus(str, Enum):
    """Agent execution status"""
    IDLE = "idle"
    RUNNING = "running"
    WAITING = "waiting"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class WorkflowType(str, Enum):
    """Types of agent workflows"""
    SEQUENTIAL = "sequential"
    PARALLEL = "parallel"
    CONDITIONAL = "conditional"
    PIPELINE = "pipeline"
    MAP_REDUCE = "map_reduce"


class MessageType(str, Enum):
    """Inter-agent message types"""
    DATA = "data"
    CONTROL = "control"
    ERROR = "error"
    HEARTBEAT = "heartbeat"


@dataclass
class AgentDefinition:
    """Definition of an agent"""
    agent_id: str
    agent_type: str
    name: str
    description: str
    capabilities_required: List[str]
    memory_limit_mb: int = 256
    timeout_seconds: int = 300
    retry_count: int = 3
    environment: Optional[Dict[str, Any]] = None


@dataclass
class AgentMessage:
    """Message between agents"""
    message_id: str
    from_agent: str
    to_agent: str
    message_type: MessageType
    content: Dict[str, Any]
    timestamp: str
    expires_at: Optional[str] = None


@dataclass
class AgentState:
    """Current state of an agent"""
    agent_id: str
    status: AgentStatus
    current_task: Optional[str]
    memory_usage_mb: int
    cpu_usage_percent: float
    started_at: str
    last_activity: str
    error_message: Optional[str] = None
    output_data: Optional[Dict[str, Any]] = None


@dataclass
class WorkflowExecution:
    """Workflow execution instance"""
    workflow_id: str
    workflow_type: WorkflowType
    tenant_id: str
    created_by: str
    agents: List[AgentDefinition]
    workflow_config: Dict[str, Any]
    status: AgentStatus
    started_at: str
    completed_at: Optional[str] = None
    results: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None


class AgentMemoryManager:
    """Manages agent memory and state"""

    def __init__(self):
        # In-memory storage (PostgreSQL used for persistent storage)
        self._agent_memory: Dict[str, Dict[str, Any]] = {}
        self._shared_memory: Dict[str, Dict[str, Any]] = {}
        self._message_queues: Dict[str, List[AgentMessage]] = {}

    async def store_agent_memory(
        self,
        agent_id: str,
        key: str,
        value: Any,
        ttl_seconds: Optional[int] = None
    ) -> None:
        """Store data in agent-specific memory"""
        if agent_id not in self._agent_memory:
            self._agent_memory[agent_id] = {}

        self._agent_memory[agent_id][key] = {
            "value": value,
            "created_at": datetime.utcnow().isoformat(),
            "expires_at": (
                datetime.utcnow() + timedelta(seconds=ttl_seconds)
            ).isoformat() if ttl_seconds else None
        }

        logger.debug(f"Stored memory for agent {agent_id}: {key}")

    async def get_agent_memory(
        self,
        agent_id: str,
        key: str
    ) -> Optional[Any]:
        """Retrieve data from agent-specific memory"""
        if agent_id not in self._agent_memory:
            return None

        memory_item = self._agent_memory[agent_id].get(key)
        if not memory_item:
            return None

        # Check expiration
        if memory_item.get("expires_at"):
            expires_at = datetime.fromisoformat(memory_item["expires_at"])
            if datetime.utcnow() > expires_at:
                del self._agent_memory[agent_id][key]
                return None

        return memory_item["value"]

    async def store_shared_memory(
        self,
        tenant_id: str,
        key: str,
        value: Any,
        ttl_seconds: Optional[int] = None
    ) -> None:
        """Store data in tenant-shared memory"""
        if tenant_id not in self._shared_memory:
            self._shared_memory[tenant_id] = {}

        self._shared_memory[tenant_id][key] = {
            "value": value,
            "created_at": datetime.utcnow().isoformat(),
            "expires_at": (
                datetime.utcnow() + timedelta(seconds=ttl_seconds)
            ).isoformat() if ttl_seconds else None
        }

        logger.debug(f"Stored shared memory for tenant {tenant_id}: {key}")

    async def get_shared_memory(
        self,
        tenant_id: str,
        key: str
    ) -> Optional[Any]:
        """Retrieve data from tenant-shared memory"""
        if tenant_id not in self._shared_memory:
            return None

        memory_item = self._shared_memory[tenant_id].get(key)
        if not memory_item:
            return None

        # Check expiration
        if memory_item.get("expires_at"):
            expires_at = datetime.fromisoformat(memory_item["expires_at"])
            if datetime.utcnow() > expires_at:
                del self._shared_memory[tenant_id][key]
                return None

        return memory_item["value"]

    async def send_message(
        self,
        message: AgentMessage
    ) -> None:
        """Send message to agent queue"""
        if message.to_agent not in self._message_queues:
            self._message_queues[message.to_agent] = []

        self._message_queues[message.to_agent].append(message)
        logger.debug(f"Message sent from {message.from_agent} to {message.to_agent}")

    async def receive_messages(
        self,
        agent_id: str,
        message_type: Optional[MessageType] = None
    ) -> List[AgentMessage]:
        """Receive messages for agent"""
        if agent_id not in self._message_queues:
            return []

        messages = self._message_queues[agent_id]

        # Filter expired messages
        now = datetime.utcnow()
        messages = [
            msg for msg in messages
            if not msg.expires_at or datetime.fromisoformat(msg.expires_at) > now
        ]

        # Filter by message type if specified
        if message_type:
            messages = [msg for msg in messages if msg.message_type == message_type]

        # Clear processed messages
        if message_type:
            self._message_queues[agent_id] = [
                msg for msg in self._message_queues[agent_id]
                if msg.message_type != message_type or
                (msg.expires_at and datetime.fromisoformat(msg.expires_at) <= now)
            ]
        else:
            self._message_queues[agent_id] = []

        return messages

    async def cleanup_agent_memory(self, agent_id: str) -> None:
        """Clean up memory for completed agent"""
        if agent_id in self._agent_memory:
            del self._agent_memory[agent_id]
        if agent_id in self._message_queues:
            del self._message_queues[agent_id]

        logger.debug(f"Cleaned up memory for agent {agent_id}")


class AgentOrchestrator:
    """
    Main agent orchestration system for GT 2.0.

    Manages agent lifecycle, workflows, communication, and resource allocation.
    All operations are tenant-isolated and capability-protected.
    """

    def __init__(self):
        self.memory_manager = AgentMemoryManager()
        self.active_workflows: Dict[str, WorkflowExecution] = {}
        self.agent_states: Dict[str, AgentState] = {}

        # Built-in agent types
        self.agent_registry: Dict[str, Dict[str, Any]] = {
            "data_processor": {
                "description": "Processes and transforms data",
                "capabilities": ["data.read", "data.transform"],
                "memory_limit_mb": 512,
                "timeout_seconds": 300
            },
            "llm_agent": {
                "description": "Interacts with LLM services",
                "capabilities": ["llm.inference", "llm.chat"],
                "memory_limit_mb": 256,
                "timeout_seconds": 600
            },
            "embedding_agent": {
                "description": "Generates text embeddings",
                "capabilities": ["embeddings.generate"],
                "memory_limit_mb": 256,
                "timeout_seconds": 180
            },
            "rag_agent": {
                "description": "Performs retrieval-augmented generation",
                "capabilities": ["rag.search", "rag.generate"],
                "memory_limit_mb": 512,
                "timeout_seconds": 450
            },
            "integration_agent": {
                "description": "Connects to external services",
                "capabilities": ["integration.call", "integration.webhook"],
                "memory_limit_mb": 256,
                "timeout_seconds": 300
            }
        }

        logger.info("Agent orchestrator initialized")

    async def create_workflow(
        self,
        workflow_type: WorkflowType,
        agents: List[AgentDefinition],
        workflow_config: Dict[str, Any],
        capability_token: str,
        workflow_name: Optional[str] = None
    ) -> str:
        """
        Create a new agent workflow.

        Args:
            workflow_type: Type of workflow to create
            agents: List of agents to include in workflow
            workflow_config: Configuration for the workflow
            capability_token: JWT token with workflow permissions
            workflow_name: Optional name for the workflow

        Returns:
            Workflow ID
        """
        # Verify capability token
        capability = await verify_capability_token(capability_token)
        tenant_id = capability.get("tenant_id")
        user_id = capability.get("sub")

        # Check workflow permissions
        await self._verify_workflow_permissions(capability, workflow_type, agents)

        # Generate workflow ID
        workflow_id = str(uuid.uuid4())

        # Create workflow execution
        workflow = WorkflowExecution(
            workflow_id=workflow_id,
            workflow_type=workflow_type,
            tenant_id=tenant_id,
            created_by=user_id,
            agents=agents,
            workflow_config=workflow_config,
            status=AgentStatus.IDLE,
            started_at=datetime.utcnow().isoformat()
        )

        # Store workflow
        self.active_workflows[workflow_id] = workflow

        logger.info(
            f"Created {workflow_type} workflow {workflow_id} "
            f"with {len(agents)} agents for tenant {tenant_id}"
        )

        return workflow_id

    async def execute_workflow(
        self,
        workflow_id: str,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """
        Execute an agent workflow.

        Args:
            workflow_id: ID of workflow to execute
            input_data: Input data for the workflow
            capability_token: JWT token with execution permissions

        Returns:
            Workflow execution results
        """
        # Verify capability token
        capability = await verify_capability_token(capability_token)
        tenant_id = capability.get("tenant_id")

        # Get workflow
        workflow = self.active_workflows.get(workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {workflow_id} not found")

        # Check tenant isolation
        if workflow.tenant_id != tenant_id:
            raise CapabilityError("Insufficient permissions for workflow")

        # Check workflow permissions
        await self._verify_execution_permissions(capability, workflow)

        try:
            # Update workflow status
            workflow.status = AgentStatus.RUNNING

            # Execute based on workflow type
            if workflow.workflow_type == WorkflowType.SEQUENTIAL:
                results = await self._execute_sequential_workflow(
                    workflow, input_data, capability_token
                )
            elif workflow.workflow_type == WorkflowType.PARALLEL:
                results = await self._execute_parallel_workflow(
                    workflow, input_data, capability_token
                )
            elif workflow.workflow_type == WorkflowType.CONDITIONAL:
                results = await self._execute_conditional_workflow(
                    workflow, input_data, capability_token
                )
            elif workflow.workflow_type == WorkflowType.PIPELINE:
                results = await self._execute_pipeline_workflow(
                    workflow, input_data, capability_token
                )
            elif workflow.workflow_type == WorkflowType.MAP_REDUCE:
                results = await self._execute_map_reduce_workflow(
                    workflow, input_data, capability_token
                )
            else:
                raise ValueError(f"Unsupported workflow type: {workflow.workflow_type}")

            # Update workflow completion
            workflow.status = AgentStatus.COMPLETED
            workflow.completed_at = datetime.utcnow().isoformat()
            workflow.results = results

            logger.info(f"Completed workflow {workflow_id} successfully")

            return results

        except Exception as e:
            # Update workflow error status
            workflow.status = AgentStatus.FAILED
            workflow.completed_at = datetime.utcnow().isoformat()
            workflow.error_message = str(e)

            logger.error(f"Workflow {workflow_id} failed: {e}")
            raise

    async def get_workflow_status(
        self,
        workflow_id: str,
        capability_token: str
    ) -> Dict[str, Any]:
        """Get status of a workflow"""
        # Verify capability token
        capability = await verify_capability_token(capability_token)
        tenant_id = capability.get("tenant_id")

        # Get workflow
        workflow = self.active_workflows.get(workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {workflow_id} not found")

        # Check tenant isolation
        if workflow.tenant_id != tenant_id:
            raise CapabilityError("Insufficient permissions for workflow")

        # Get agent states for this workflow
        agent_states = {
            agent.agent_id: asdict(self.agent_states.get(agent.agent_id))
            for agent in workflow.agents
            if agent.agent_id in self.agent_states
        }

        return {
            "workflow": asdict(workflow),
            "agent_states": agent_states
        }

    async def cancel_workflow(
        self,
        workflow_id: str,
        capability_token: str
    ) -> None:
        """Cancel a running workflow"""
        # Verify capability token
        capability = await verify_capability_token(capability_token)
        tenant_id = capability.get("tenant_id")

        # Get workflow
        workflow = self.active_workflows.get(workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {workflow_id} not found")

        # Check tenant isolation
        if workflow.tenant_id != tenant_id:
            raise CapabilityError("Insufficient permissions for workflow")

        # Cancel workflow
        workflow.status = AgentStatus.CANCELLED
        workflow.completed_at = datetime.utcnow().isoformat()

        # Cancel all agents in workflow
        for agent in workflow.agents:
            if agent.agent_id in self.agent_states:
                self.agent_states[agent.agent_id].status = AgentStatus.CANCELLED

        logger.info(f"Cancelled workflow {workflow_id}")

    async def _execute_sequential_workflow(
        self,
        workflow: WorkflowExecution,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute agents sequentially"""
        results = {}
        current_data = input_data

        for agent in workflow.agents:
            agent_result = await self._execute_agent(
                agent, current_data, capability_token
            )
            results[agent.agent_id] = agent_result

            # Pass output to next agent
            if "output" in agent_result:
                current_data = agent_result["output"]

        return {
            "workflow_type": "sequential",
            "final_output": current_data,
            "agent_results": results
        }

    async def _execute_parallel_workflow(
        self,
        workflow: WorkflowExecution,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute agents in parallel"""
        # Create tasks for all agents
        tasks = []
        for agent in workflow.agents:
            task = asyncio.create_task(
                self._execute_agent(agent, input_data, capability_token)
            )
            tasks.append((agent.agent_id, task))

        # Wait for all tasks to complete
        results = {}
        for agent_id, task in tasks:
            try:
                results[agent_id] = await task
            except Exception as e:
                results[agent_id] = {"error": str(e)}

        return {
            "workflow_type": "parallel",
            "agent_results": results
        }

    async def _execute_conditional_workflow(
        self,
        workflow: WorkflowExecution,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute agents based on conditions"""
        results = {}
        condition_config = workflow.workflow_config.get("conditions", {})

        for agent in workflow.agents:
            # Check if agent should execute based on conditions
            should_execute = await self._evaluate_condition(
                agent.agent_id, condition_config, input_data, results
            )

            if should_execute:
                agent_result = await self._execute_agent(
                    agent, input_data, capability_token
                )
                results[agent.agent_id] = agent_result
            else:
                results[agent.agent_id] = {"status": "skipped"}

        return {
            "workflow_type": "conditional",
            "agent_results": results
        }

    async def _execute_pipeline_workflow(
        self,
        workflow: WorkflowExecution,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute agents in pipeline with data transformation"""
        results = {}
        current_data = input_data

        for i, agent in enumerate(workflow.agents):
            # Add pipeline metadata
            pipeline_data = {
                **current_data,
                "_pipeline_stage": i,
                "_pipeline_total": len(workflow.agents)
            }

            agent_result = await self._execute_agent(
                agent, pipeline_data, capability_token
            )
            results[agent.agent_id] = agent_result

            # Transform data for next stage
            if "transformed_output" in agent_result:
                current_data = agent_result["transformed_output"]
            elif "output" in agent_result:
                current_data = agent_result["output"]

        return {
            "workflow_type": "pipeline",
            "final_output": current_data,
            "agent_results": results
        }

    async def _execute_map_reduce_workflow(
        self,
        workflow: WorkflowExecution,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute map-reduce workflow"""
        # Separate map and reduce agents
        map_agents = [a for a in workflow.agents if a.agent_type.endswith("_mapper")]
        reduce_agents = [a for a in workflow.agents if a.agent_type.endswith("_reducer")]

        # Execute map phase
        map_tasks = []
        input_chunks = input_data.get("chunks", [input_data])

        for i, chunk in enumerate(input_chunks):
            for agent in map_agents:
                task = asyncio.create_task(
                    self._execute_agent(agent, chunk, capability_token)
                )
                map_tasks.append((f"{agent.agent_id}_chunk_{i}", task))

        # Collect map results
        map_results = {}
        for task_id, task in map_tasks:
            try:
                map_results[task_id] = await task
            except Exception as e:
                map_results[task_id] = {"error": str(e)}

        # Execute reduce phase
        reduce_results = {}
        reduce_input = {"map_results": map_results}

        for agent in reduce_agents:
            agent_result = await self._execute_agent(
                agent, reduce_input, capability_token
            )
            reduce_results[agent.agent_id] = agent_result

        return {
            "workflow_type": "map_reduce",
            "map_results": map_results,
            "reduce_results": reduce_results
        }

    async def _execute_agent(
        self,
        agent: AgentDefinition,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute a single agent"""
        start_time = time.time()

        # Create agent state
        agent_state = AgentState(
            agent_id=agent.agent_id,
            status=AgentStatus.RUNNING,
            current_task=f"Executing {agent.agent_type}",
            memory_usage_mb=0,
            cpu_usage_percent=0.0,
            started_at=datetime.utcnow().isoformat(),
            last_activity=datetime.utcnow().isoformat()
        )
        self.agent_states[agent.agent_id] = agent_state

        try:
            # Simulate agent execution based on type
            if agent.agent_type == "data_processor":
                result = await self._execute_data_processor(agent, input_data)
            elif agent.agent_type == "llm_agent":
                result = await self._execute_llm_agent(agent, input_data, capability_token)
            elif agent.agent_type == "embedding_agent":
                result = await self._execute_embedding_agent(agent, input_data, capability_token)
            elif agent.agent_type == "rag_agent":
                result = await self._execute_rag_agent(agent, input_data, capability_token)
            elif agent.agent_type == "integration_agent":
                result = await self._execute_integration_agent(agent, input_data, capability_token)
            else:
                result = await self._execute_custom_agent(agent, input_data)

            # Update agent state
            agent_state.status = AgentStatus.COMPLETED
            agent_state.output_data = result
            agent_state.last_activity = datetime.utcnow().isoformat()

            processing_time = time.time() - start_time

            logger.info(
                f"Agent {agent.agent_id} completed in {processing_time:.2f}s"
            )

            return {
                "status": "completed",
                "processing_time": processing_time,
                "output": result
            }

        except Exception as e:
            # Update agent error state
            agent_state.status = AgentStatus.FAILED
            agent_state.error_message = str(e)
            agent_state.last_activity = datetime.utcnow().isoformat()

            logger.error(f"Agent {agent.agent_id} failed: {e}")

            return {
                "status": "failed",
                "error": str(e),
                "processing_time": time.time() - start_time
            }

    # Agent execution implementations would go here...
    # For now, these are placeholder implementations

    async def _execute_data_processor(
        self,
        agent: AgentDefinition,
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute data processing agent"""
        await asyncio.sleep(0.1)  # Simulate processing
        return {
            "processed_data": input_data,
            "processing_info": "Data processed successfully"
        }

    async def _execute_llm_agent(
        self,
        agent: AgentDefinition,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute LLM agent"""
        await asyncio.sleep(0.2)  # Simulate LLM call
        return {
            "llm_response": f"LLM processed: {input_data.get('prompt', 'No prompt provided')}",
            "model_used": "groq/llama-3-8b"
        }

    async def _execute_embedding_agent(
        self,
        agent: AgentDefinition,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute embedding agent"""
        await asyncio.sleep(0.1)  # Simulate embedding generation
        texts = input_data.get("texts", [""])
        return {
            "embeddings": [[0.1] * 1024 for _ in texts],  # Mock embeddings
            "model_used": "BAAI/bge-m3"
        }

    async def _execute_rag_agent(
        self,
        agent: AgentDefinition,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute RAG agent"""
        await asyncio.sleep(0.3)  # Simulate RAG processing
        return {
            "rag_response": "RAG generated response",
            "retrieved_docs": ["doc1", "doc2"],
            "confidence_score": 0.85
        }

    async def _execute_integration_agent(
        self,
        agent: AgentDefinition,
        input_data: Dict[str, Any],
        capability_token: str
    ) -> Dict[str, Any]:
        """Execute integration agent"""
        await asyncio.sleep(0.1)  # Simulate external API call
        return {
            "integration_result": "External API called successfully",
            "response_data": input_data
        }

    async def _execute_custom_agent(
        self,
        agent: AgentDefinition,
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute custom agent type"""
        await asyncio.sleep(0.1)  # Simulate custom processing
        return {
            "custom_result": f"Custom agent {agent.agent_type} executed",
            "input_data": input_data
        }

    async def _verify_workflow_permissions(
        self,
        capability: Dict[str, Any],
        workflow_type: WorkflowType,
        agents: List[AgentDefinition]
    ) -> None:
        """Verify workflow creation permissions"""
        capabilities = capability.get("capabilities", [])

        # Check for workflow creation permission
        workflow_caps = [
            cap for cap in capabilities
            if cap.get("resource") == "workflows"
        ]

        if not workflow_caps:
            raise CapabilityError("No workflow permissions in capability token")

        # Check specific workflow type permission
        workflow_cap = workflow_caps[0]
        actions = workflow_cap.get("actions", [])

        if "create" not in actions:
            raise CapabilityError("No workflow creation permission")

        # Check agent-specific permissions
        for agent in agents:
            for required_cap in agent.capabilities_required:
                if not any(
                    cap.get("resource") == required_cap.split(".")[0]
                    for cap in capabilities
                ):
                    raise CapabilityError(
                        f"Missing capability for agent {agent.agent_id}: {required_cap}"
                    )

    async def _verify_execution_permissions(
        self,
        capability: Dict[str, Any],
        workflow: WorkflowExecution
    ) -> None:
        """Verify workflow execution permissions"""
        capabilities = capability.get("capabilities", [])

        # Check for workflow execution permission
        workflow_caps = [
            cap for cap in capabilities
            if cap.get("resource") == "workflows"
        ]

        if not workflow_caps:
            raise CapabilityError("No workflow permissions in capability token")

        workflow_cap = workflow_caps[0]
        actions = workflow_cap.get("actions", [])

        if "execute" not in actions:
            raise CapabilityError("No workflow execution permission")

    async def _evaluate_condition(
        self,
        agent_id: str,
        condition_config: Dict[str, Any],
        input_data: Dict[str, Any],
        results: Dict[str, Any]
    ) -> bool:
        """Evaluate condition for conditional workflow"""
        agent_condition = condition_config.get(agent_id, {})

        if not agent_condition:
            return True  # No condition means always execute

        condition_type = agent_condition.get("type", "always")

        if condition_type == "always":
            return True
        elif condition_type == "never":
            return False
        elif condition_type == "input_contains":
            key = agent_condition.get("key")
            value = agent_condition.get("value")
            return input_data.get(key) == value
        elif condition_type == "previous_success":
            previous_agent = agent_condition.get("previous_agent")
            return (
                previous_agent in results and
                results[previous_agent].get("status") == "completed"
            )
        elif condition_type == "previous_failure":
            previous_agent = agent_condition.get("previous_agent")
            return (
                previous_agent in results and
                results[previous_agent].get("status") == "failed"
            )

        return True  # Default to execute if condition not recognized
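    # Hedged example (editor's note, not part of the original file): a
    # workflow_config["conditions"] mapping that _evaluate_condition() above
    # would accept. Agent IDs are illustrative.
    #
    #   conditions = {
    #       "summarizer":     {"type": "always"},
    #       "publisher":      {"type": "previous_success", "previous_agent": "summarizer"},
    #       "fallback_agent": {"type": "previous_failure", "previous_agent": "summarizer"},
    #       "debug_agent":    {"type": "input_contains", "key": "debug", "value": True},
    #   }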

# Global orchestrator instance
_agent_orchestrator = None


def get_agent_orchestrator() -> AgentOrchestrator:
    """Get the global agent orchestrator instance"""
    global _agent_orchestrator
    if _agent_orchestrator is None:
        _agent_orchestrator = AgentOrchestrator()
    return _agent_orchestrator
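A minimal sketch of driving the orchestrator above, assuming the caller holds a capability token carrying workflows create/execute actions; agent IDs and inputs are illustrative:

# Hedged sketch: create and run a two-stage sequential workflow.
from app.services.agent_orchestrator import (
    get_agent_orchestrator, AgentDefinition, WorkflowType
)

async def run_example(capability_token: str):
    orchestrator = get_agent_orchestrator()
    agents = [
        AgentDefinition(
            agent_id="extract-1", agent_type="data_processor",
            name="Extractor", description="Normalizes input",
            capabilities_required=["data.read"]
        ),
        AgentDefinition(
            agent_id="llm-1", agent_type="llm_agent",
            name="Writer", description="Drafts a summary",
            capabilities_required=["llm.inference"]
        ),
    ]
    workflow_id = await orchestrator.create_workflow(
        WorkflowType.SEQUENTIAL, agents, {}, capability_token
    )
    # Each agent's "output" feeds the next; the final output is returned
    return await orchestrator.execute_workflow(
        workflow_id, {"prompt": "Summarize the release notes"}, capability_token
    )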
apps/resource-cluster/app/services/config_sync.py (280 lines, new file)
@@ -0,0 +1,280 @@
"""
GT 2.0 Configuration Sync Service

Syncs model configurations from admin cluster to resource cluster.
Enables admin control panel to control AI model routing.
"""

import asyncio
import httpx
import json
import time
import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from pathlib import Path

from app.core.config import get_settings
from app.services.model_service import default_model_service
from app.providers.external_provider import get_external_provider

logger = logging.getLogger(__name__)
settings = get_settings()


class ConfigSyncService:
    """Syncs model configurations from admin cluster"""

    def __init__(self):
        # Force Docker service name for admin cluster communication in containerized environment
        if hasattr(settings, 'admin_cluster_url') and settings.admin_cluster_url:
            # Check if we're running in Docker (container environment)
            import os
            if os.path.exists('/.dockerenv'):
                self.admin_cluster_url = "http://control-panel-backend:8000"
            else:
                self.admin_cluster_url = settings.admin_cluster_url
        else:
            self.admin_cluster_url = "http://control-panel-backend:8000"
        self.sync_interval = settings.config_sync_interval or 60  # seconds
        # Use the default singleton model service instance
        self.model_service = default_model_service
        self.last_sync = 0
        self.sync_running = False

    async def start_sync_loop(self):
        """Start the configuration sync loop"""
        logger.info("Starting configuration sync loop")

        while True:
            try:
                if not self.sync_running:
                    await self.sync_configurations()
                await asyncio.sleep(self.sync_interval)
            except Exception as e:
                logger.error(f"Config sync loop error: {e}")
                await asyncio.sleep(30)  # Wait 30s on error

    async def sync_configurations(self):
        """Sync model configurations from admin cluster"""
        if self.sync_running:
            return

        self.sync_running = True

        try:
            logger.debug("Syncing model configurations from admin cluster")

            # Fetch all model configurations from admin cluster
            configs = await self._fetch_admin_configs()

            if configs:
                # Update local model registry
                await self._update_local_registry(configs)

                # Update provider configurations
                await self._update_provider_configs(configs)

                self.last_sync = time.time()
                logger.info(f"Successfully synced {len(configs)} model configurations")
            else:
                logger.warning("No configurations received from admin cluster")

        except Exception as e:
            logger.error(f"Configuration sync failed: {e}")
        finally:
            self.sync_running = False

    async def _fetch_admin_configs(self) -> Optional[List[Dict[str, Any]]]:
        """Fetch model configurations from admin cluster"""
        try:
            logger.debug(f"Attempting to fetch configs from: {self.admin_cluster_url}/api/v1/models/configs/all")

            async with httpx.AsyncClient(timeout=30.0) as client:
                # Add authentication for admin cluster access
                headers = {
                    "Authorization": "Bearer admin-cluster-sync-token",
                    "Content-Type": "application/json"
                }

                response = await client.get(
                    f"{self.admin_cluster_url}/api/v1/models/configs/all",
                    headers=headers
                )

                logger.debug(f"Admin cluster response: {response.status_code}")

                if response.status_code == 200:
                    data = response.json()
                    configs = data.get("configs", [])
                    logger.debug(f"Successfully fetched {len(configs)} model configurations")
                    return configs
                else:
                    logger.warning(f"Admin cluster returned {response.status_code}: {response.text}")
                    return None

        except httpx.RequestError as e:
            logger.error(f"Failed to connect to admin cluster: {e}")
            return None
        except Exception as e:
            logger.error(f"Error fetching admin configs: {e}")
            return None

    async def _update_local_registry(self, configs: List[Dict[str, Any]]):
        """Update local model registry with admin configurations"""
        try:
            for config in configs:
                await self.model_service.register_or_update_model(
                    model_id=config["model_id"],
                    name=config["name"],
                    version=config["version"],
                    provider=config["provider"],
                    model_type=config["model_type"],
                    endpoint=config["endpoint"],
                    api_key_name=config.get("api_key_name"),
                    specifications=config.get("specifications", {}),
                    capabilities=config.get("capabilities", {}),
                    cost=config.get("cost", {}),
                    description=config.get("description"),
                    config=config.get("config", {}),
                    status=config.get("status", {}),
                    sync_timestamp=config.get("sync_timestamp")
                )

                # Log BGE-M3 configuration details for debugging persistence
                if "bge-m3" in config["model_id"].lower():
                    model_config = config.get("config", {})
                    logger.info(
                        f"Synced BGE-M3 configuration from database: "
                        f"endpoint={config['endpoint']}, "
                        f"is_local_mode={model_config.get('is_local_mode', True)}, "
                        f"external_endpoint={model_config.get('external_endpoint', 'None')}"
                    )

        except Exception as e:
            logger.error(f"Failed to update local registry: {e}")
            raise

    async def _update_provider_configs(self, configs: List[Dict[str, Any]]):
        """Update provider configurations based on admin settings"""
        try:
            # Group configs by provider
            provider_configs = {}
            for config in configs:
                provider = config["provider"]
                if provider not in provider_configs:
                    provider_configs[provider] = []
                provider_configs[provider].append(config)

            # Update each provider
            for provider, provider_models in provider_configs.items():
                await self._update_provider(provider, provider_models)

        except Exception as e:
            logger.error(f"Failed to update provider configs: {e}")
            raise

    async def _update_provider(self, provider: str, models: List[Dict[str, Any]]):
        """Update specific provider configuration"""
        try:
            # Generic provider update - all providers are now supported automatically
            provider_models = [m for m in models if m["provider"] == provider]
            logger.debug(f"Updated {provider} provider with {len(provider_models)} models")

            # Keep legacy support for specific providers if needed
            if provider == "groq":
                await self._update_groq_provider(models)
            elif provider == "external":
                await self._update_external_provider(models)
            elif provider == "openai":
                await self._update_openai_provider(models)
            elif provider == "anthropic":
                await self._update_anthropic_provider(models)
            elif provider == "vllm":
                await self._update_vllm_provider(models)

        except Exception as e:
            logger.error(f"Failed to update {provider} provider: {e}")
            raise

    async def _update_groq_provider(self, models: List[Dict[str, Any]]):
        """Update Groq provider configuration"""
        # Update available Groq models
        groq_models = [m for m in models if m["provider"] == "groq"]
        logger.debug(f"Updated Groq provider with {len(groq_models)} models")

    async def _update_external_provider(self, models: List[Dict[str, Any]]):
        """Update external provider configuration (BGE-M3, etc.)"""
        external_models = [m for m in models if m["provider"] == "external"]

        if external_models:
            external_provider = await get_external_provider()

            for model in external_models:
                if "bge-m3" in model["model_id"].lower():
                    # Update BGE-M3 endpoint configuration
                    external_provider.update_model_endpoint(
                        model["model_id"],
                        model["endpoint"]
                    )
                    logger.debug(f"Updated BGE-M3 endpoint: {model['endpoint']}")

                    # Also refresh the embedding backend instance
                    try:
                        from app.core.backends import get_embedding_backend
                        embedding_backend = get_embedding_backend()
                        embedding_backend.refresh_endpoint_from_registry()
                        logger.info("Refreshed embedding backend with new BGE-M3 endpoint from database")
                    except Exception as e:
                        logger.error(f"Failed to refresh embedding backend: {e}")

        logger.debug(f"Updated external provider with {len(external_models)} models")

    async def _update_openai_provider(self, models: List[Dict[str, Any]]):
        """Update OpenAI provider configuration"""
        openai_models = [m for m in models if m["provider"] == "openai"]
        logger.debug(f"Updated OpenAI provider with {len(openai_models)} models")

    async def _update_anthropic_provider(self, models: List[Dict[str, Any]]):
        """Update Anthropic provider configuration"""
        anthropic_models = [m for m in models if m["provider"] == "anthropic"]
        logger.debug(f"Updated Anthropic provider with {len(anthropic_models)} models")

    async def _update_vllm_provider(self, models: List[Dict[str, Any]]):
        """Update vLLM provider configuration (BGE-M3 embeddings, etc.)"""
        vllm_models = [m for m in models if m["provider"] == "vllm"]

        for model in vllm_models:
            if model["model_type"] == "embedding":
                # This is an embedding model like BGE-M3
                logger.debug(f"Updated vLLM embedding model: {model['model_id']} -> {model['endpoint']}")
            else:
                logger.debug(f"Updated vLLM model: {model['model_id']} -> {model['endpoint']}")

        logger.debug(f"Updated vLLM provider with {len(vllm_models)} models")

    async def force_sync(self):
        """Force immediate configuration sync"""
        logger.info("Force syncing configurations")
        await self.sync_configurations()

    def get_sync_status(self) -> Dict[str, Any]:
        """Get current sync status"""
        return {
            "last_sync": datetime.fromtimestamp(self.last_sync).isoformat() if self.last_sync else None,
            "sync_running": self.sync_running,
            "admin_cluster_url": self.admin_cluster_url,
            "sync_interval": self.sync_interval,
            "next_sync": datetime.fromtimestamp(self.last_sync + self.sync_interval).isoformat() if self.last_sync else None
        }


# Global config sync service instance
_config_sync_service = None


def get_config_sync_service() -> ConfigSyncService:
    """Get configuration sync service instance"""
    global _config_sync_service
    if _config_sync_service is None:
        _config_sync_service = ConfigSyncService()
    return _config_sync_service
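A minimal sketch of wiring the sync loop into application startup. This assumes the Resource Cluster entrypoint is a FastAPI app, which is not confirmed by this diff; adjust to the real entrypoint:

# Hedged sketch: run start_sync_loop() as a background task for the app's lifetime.
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI

from app.services.config_sync import get_config_sync_service

@asynccontextmanager
async def lifespan(app: FastAPI):
    sync_service = get_config_sync_service()
    task = asyncio.create_task(sync_service.start_sync_loop())
    yield
    task.cancel()  # stop background sync on shutdown

app = FastAPI(lifespan=lifespan)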
apps/resource-cluster/app/services/consul_registry.py (101 lines, new file)
@@ -0,0 +1,101 @@
"""
Consul Service Registry

Handles service registration and discovery for the Resource Cluster.
"""

import logging
from typing import Dict, Any, List, Optional

import consul

from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()


class ConsulRegistry:
    """Service registry using Consul"""

    def __init__(self):
        self.consul = None
        try:
            self.consul = consul.Consul(
                host=settings.consul_host,
                port=settings.consul_port,
                token=settings.consul_token
            )
        except Exception as e:
            logger.warning(f"Consul not available: {e}")

    async def register_service(
        self,
        name: str,
        service_id: str,
        address: str,
        port: int,
        tags: Optional[List[str]] = None,
        check_interval: str = "10s"
    ) -> bool:
        """Register service with Consul"""
        if not self.consul:
            logger.warning("Consul not available, skipping registration")
            return False

        try:
            self.consul.agent.service.register(
                name=name,
                service_id=service_id,
                address=address,
                port=port,
                tags=tags or [],
                check=consul.Check.http(
                    f"http://{address}:{port}/health",
                    interval=check_interval
                )
            )
            logger.info(f"Registered service {service_id} with Consul")
            return True

        except Exception as e:
            logger.error(f"Failed to register with Consul: {e}")
            return False

    async def deregister_service(self, service_id: str) -> bool:
        """Deregister service from Consul"""
        if not self.consul:
            return False

        try:
            self.consul.agent.service.deregister(service_id)
            logger.info(f"Deregistered service {service_id} from Consul")
            return True

        except Exception as e:
            logger.error(f"Failed to deregister from Consul: {e}")
            return False

    async def discover_service(self, service_name: str) -> List[Dict[str, Any]]:
        """Discover service instances"""
        if not self.consul:
            return []

        try:
            _, services = self.consul.health.service(service_name, passing=True)

            instances = []
            for service in services:
                instances.append({
                    "id": service["Service"]["ID"],
                    "address": service["Service"]["Address"],
                    "port": service["Service"]["Port"],
                    "tags": service["Service"]["Tags"]
                })

            return instances

        except Exception as e:
            logger.error(f"Failed to discover service: {e}")
            return []
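A short sketch of registering this service at startup; the service name, address, and port are placeholder assumptions:

# Illustrative only: registering the Resource Cluster with the registry above.
registry = ConsulRegistry()

async def register_on_boot():
    ok = await registry.register_service(
        name="resource-cluster",
        service_id="resource-cluster-1",
        address="10.0.0.5",   # placeholder address
        port=8004,            # placeholder port
        tags=["gt2", "resource"],
    )
    # Consul being unavailable is non-fatal by design; registration is skipped
    # and the service keeps running.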
@@ -0,0 +1,536 @@
"""
Enhanced Document Processing Pipeline with Dual-Engine Support

Implements the DocumentProcessingPipeline from CLAUDE.md with both native
and Unstructured.io engine support, capability-based selection, and
stateless processing.
"""

import logging
import asyncio
import gc
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
import hashlib
import json

from app.core.backends.document_processor import (
    DocumentProcessorBackend,
    ChunkingStrategy
)

logger = logging.getLogger(__name__)


@dataclass
class ProcessingResult:
    """Result of document processing"""
    chunks: List[Dict[str, Any]]  # each chunk carries "text" plus a metadata dict
    embeddings: Optional[List[List[float]]]  # Optional embeddings
    metadata: Dict[str, Any]
    engine_used: str
    processing_time_ms: float
    token_count: int


@dataclass
class ProcessingOptions:
    """Options for document processing"""
    engine_preference: str = "auto"     # "native", "unstructured", "auto"
    chunking_strategy: str = "hybrid"   # "fixed", "semantic", "hierarchical", "hybrid"
    chunk_size: int = 512               # tokens for BGE-M3
    chunk_overlap: int = 128            # overlap tokens
    generate_embeddings: bool = True
    extract_metadata: bool = True
    language_detection: bool = True
    ocr_enabled: bool = False           # For scanned PDFs


class UnstructuredAPIEngine:
    """
    Mock Unstructured.io API engine for advanced document parsing.
    In production, this would call the actual Unstructured API.
    """

    def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None):
        self.api_key = api_key
        self.api_url = api_url or "https://api.unstructured.io"
        self.supported_features = [
            "table_extraction",
            "image_extraction",
            "ocr",
            "language_detection",
            "metadata_extraction",
            "hierarchical_parsing"
        ]
    async def process(
        self,
        content: bytes,
        file_type: str,
        options: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Process document using the Unstructured API.

        This is a mock implementation. In production:
        1. Send content to the Unstructured API
        2. Handle rate limiting and retries
        3. Parse the structured response
        """
        # Mock processing delay
        await asyncio.sleep(0.5)

        # Mock response structure
        return {
            "elements": [
                {
                    "type": "Title",
                    "text": "Document Title",
                    "metadata": {"page_number": 1}
                },
                {
                    "type": "NarrativeText",
                    "text": "This is the main content of the document...",
                    "metadata": {"page_number": 1}
                }
            ],
            "metadata": {
                "languages": ["en"],
                "page_count": 1,
                "has_tables": False,
                "has_images": False
            }
        }
class NativeChunkingEngine:
    """
    Native chunking engine using the existing DocumentProcessorBackend.
    Fast, lightweight, and suitable for most text documents.
    """

    def __init__(self):
        self.processor = DocumentProcessorBackend()

    async def process(
        self,
        content: bytes,
        file_type: str,
        options: ProcessingOptions
    ) -> List[Dict[str, Any]]:
        """Process document using native chunking"""
        strategy = ChunkingStrategy(
            strategy_type=options.chunking_strategy,
            chunk_size=options.chunk_size,
            chunk_overlap=options.chunk_overlap,
            preserve_paragraphs=True,
            preserve_sentences=True
        )

        chunks = await self.processor.process_document(
            content=content,
            document_type=file_type,
            strategy=strategy,
            metadata={
                "processing_timestamp": datetime.utcnow().isoformat(),
                "engine": "native"
            }
        )

        return chunks
class DocumentProcessingPipeline:
    """
    Dual-engine document processing pipeline with capability-based selection.

    Features:
    - Native engine for fast, simple processing
    - Unstructured API for advanced features
    - Capability-based engine selection
    - Stateless processing with memory cleanup
    - Optional embedding generation
    """

    def __init__(self, resource_cluster_url: Optional[str] = None):
        self.resource_cluster_url = resource_cluster_url or "http://localhost:8004"
        self.native_engine = NativeChunkingEngine()
        self.unstructured_engine = None  # Lazy initialization
        self.embedding_cache = {}        # Cache for frequently used embeddings

        logger.info("Document Processing Pipeline initialized")

    def select_engine(
        self,
        filename: str,
        token_data: Dict[str, Any],
        options: ProcessingOptions
    ) -> str:
        """
        Select processing engine based on file type and capabilities.

        Args:
            filename: Name of the file being processed
            token_data: Capability token data
            options: Processing options

        Returns:
            Engine name: "native" or "unstructured"
        """
        # Check if the user has the premium parsing capability
        has_premium = any(
            cap.get("resource") == "premium_parsing"
            for cap in token_data.get("capabilities", [])
        )

        # Force native if no premium capability
        if not has_premium and options.engine_preference == "unstructured":
            logger.info("Premium parsing requested but not available, using native engine")
            return "native"

        # Auto selection logic
        if options.engine_preference == "auto":
            # Use Unstructured for complex formats if available
            complex_formats = [".pdf", ".docx", ".pptx", ".xlsx"]
            needs_ocr = options.ocr_enabled
            needs_tables = filename.lower().endswith((".xlsx", ".csv"))

            if has_premium and (
                any(filename.lower().endswith(fmt) for fmt in complex_formats) or
                needs_ocr or needs_tables
            ):
                return "unstructured"
            else:
                return "native"

        # Respect explicit preference if capability allows
        if options.engine_preference == "unstructured" and has_premium:
            return "unstructured"

        return "native"
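To make the selection rules concrete, a small sketch exercising select_engine directly; the token payload follows the shape the method expects, and the values are illustrative:

# Illustrative only: engine selection outcomes for a sample capability token.
pipeline = DocumentProcessingPipeline()
token = {"capabilities": [{"resource": "premium_parsing"}]}
opts = ProcessingOptions(engine_preference="auto")

pipeline.select_engine("report.pdf", token, opts)   # -> "unstructured" (premium + complex format)
pipeline.select_engine("notes.txt", token, opts)    # -> "native" (simple format)
pipeline.select_engine("report.pdf", {"capabilities": []}, opts)  # -> "native" (no premium)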
    async def process_document(
        self,
        file: bytes,
        filename: str,
        token_data: Dict[str, Any],
        options: Optional[ProcessingOptions] = None
    ) -> ProcessingResult:
        """
        Process document with the selected engine.

        Args:
            file: Document content as bytes
            filename: Name of the file
            token_data: Capability token data
            options: Processing options

        Returns:
            ProcessingResult with chunks, embeddings, and metadata
        """
        start_time = datetime.utcnow()

        try:
            # Use default options if not provided
            if options is None:
                options = ProcessingOptions()

            # Determine file type
            file_type = self._get_file_extension(filename)

            # Select engine based on capabilities; select_engine already falls
            # back to "native" when the premium parsing capability is absent
            engine = self.select_engine(filename, token_data, options)

            # Process with the selected engine
            if engine == "unstructured":
                result = await self._process_with_unstructured(file, filename, token_data, options)
            else:
                result = await self._process_with_native(file, filename, token_data, options)

            # Generate embeddings if requested
            embeddings = None
            if options.generate_embeddings:
                embeddings = await self._generate_embeddings(result.chunks, token_data)

            # Calculate processing time
            processing_time = (datetime.utcnow() - start_time).total_seconds() * 1000

            # Approximate token count (word-based)
            token_count = sum(len(chunk["text"].split()) for chunk in result.chunks)

            return ProcessingResult(
                chunks=result.chunks,
                embeddings=embeddings,
                metadata={
                    "filename": filename,
                    "file_type": file_type,
                    "processing_timestamp": start_time.isoformat(),
                    "chunk_count": len(result.chunks),
                    "engine_used": engine,
                    "options": {
                        "chunking_strategy": options.chunking_strategy,
                        "chunk_size": options.chunk_size,
                        "chunk_overlap": options.chunk_overlap
                    }
                },
                engine_used=engine,
                processing_time_ms=processing_time,
                token_count=token_count
            )

        except Exception as e:
            logger.error(f"Error processing document: {e}")
            raise
        finally:
            # Drop our reference to the payload and nudge the collector
            del file
            gc.collect()
    async def _process_with_native(
        self,
        file: bytes,
        filename: str,
        token_data: Dict[str, Any],
        options: ProcessingOptions
    ) -> ProcessingResult:
        """Process document with the native engine"""
        file_type = self._get_file_extension(filename)
        chunks = await self.native_engine.process(file, file_type, options)

        return ProcessingResult(
            chunks=chunks,
            embeddings=None,
            metadata={"engine": "native"},
            engine_used="native",
            processing_time_ms=0,  # aggregate timing is computed in process_document
            token_count=0
        )
    async def _process_with_unstructured(
        self,
        file: bytes,
        filename: str,
        token_data: Dict[str, Any],
        options: ProcessingOptions
    ) -> ProcessingResult:
        """Process document with the Unstructured API"""
        # Initialize the Unstructured engine if needed
        if self.unstructured_engine is None:
            # Get API key from token constraints or environment
            api_key = token_data.get("constraints", {}).get("unstructured_api_key")
            self.unstructured_engine = UnstructuredAPIEngine(api_key=api_key)

        file_type = self._get_file_extension(filename)

        # Process with Unstructured
        unstructured_result = await self.unstructured_engine.process(
            content=file,
            file_type=file_type,
            options={
                "ocr": options.ocr_enabled,
                "extract_tables": True,
                "extract_images": False,  # Don't extract images for security
                "languages": ["en", "es", "fr", "de", "zh"]
            }
        )

        # Convert Unstructured elements to chunks
        chunks = []
        for element in unstructured_result.get("elements", []):
            chunk_text = element.get("text", "")
            if chunk_text.strip():
                chunks.append({
                    "text": chunk_text,
                    "metadata": {
                        "element_type": element.get("type"),
                        "page_number": element.get("metadata", {}).get("page_number"),
                        "engine": "unstructured"
                    }
                })

        # Apply chunking strategy if chunks are too large
        final_chunks = await self._apply_chunking_to_elements(chunks, options)

        return ProcessingResult(
            chunks=final_chunks,
            embeddings=None,
            metadata={
                "engine": "unstructured",
                "detected_languages": unstructured_result.get("metadata", {}).get("languages", []),
                "page_count": unstructured_result.get("metadata", {}).get("page_count", 0),
                "has_tables": unstructured_result.get("metadata", {}).get("has_tables", False),
                "has_images": unstructured_result.get("metadata", {}).get("has_images", False)
            },
            engine_used="unstructured",
            processing_time_ms=0,
            token_count=0
        )
    async def _apply_chunking_to_elements(
        self,
        elements: List[Dict[str, Any]],
        options: ProcessingOptions
    ) -> List[Dict[str, Any]]:
        """Apply chunking strategy to Unstructured elements if needed"""
        final_chunks = []

        for element in elements:
            text = element["text"]

            # Estimate token count (rough approximation)
            estimated_tokens = len(text.split()) * 1.3

            # If the element is small enough, keep it as is
            if estimated_tokens <= options.chunk_size:
                final_chunks.append(element)
            else:
                # Split large elements using native chunking
                sub_chunks = await self._chunk_text(
                    text,
                    options.chunk_size,
                    options.chunk_overlap
                )

                for idx, sub_chunk in enumerate(sub_chunks):
                    chunk_metadata = element["metadata"].copy()
                    chunk_metadata["sub_chunk_index"] = idx
                    chunk_metadata["parent_element_type"] = element["metadata"].get("element_type")

                    final_chunks.append({
                        "text": sub_chunk,
                        "metadata": chunk_metadata
                    })

        return final_chunks

    async def _chunk_text(
        self,
        text: str,
        chunk_size: int,
        chunk_overlap: int
    ) -> List[str]:
        """Simple text chunking for large elements"""
        words = text.split()
        chunks = []

        # Simple word-based chunking; guard the stride so an overlap >= chunk
        # size cannot produce a zero or negative step
        stride = max(1, chunk_size - chunk_overlap)
        for i in range(0, len(words), stride):
            chunk_words = words[i:i + chunk_size]
            chunks.append(" ".join(chunk_words))

        return chunks
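As a worked example of the stride arithmetic above, using the default option values (this is just arithmetic, not new behavior):

# With chunk_size=512 and chunk_overlap=128 the stride is 512 - 128 = 384,
# so windows start at word 0, 384, 768, ...
words = [f"w{i}" for i in range(1000)]
stride = max(1, 512 - 128)                    # 384
starts = list(range(0, len(words), stride))   # [0, 384, 768]
# chunks: words[0:512], words[384:896], words[768:1000]
# consecutive chunks share 128 words of context.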
    async def _generate_embeddings(
        self,
        chunks: List[Dict[str, Any]],
        token_data: Dict[str, Any]
    ) -> List[List[float]]:
        """
        Generate embeddings for chunks.

        This is a mock implementation. In production, this would:
        1. Call the embedding service (BGE-M3 or similar)
        2. Handle batching for efficiency
        3. Apply caching for common chunks
        """
        embeddings = []

        for chunk in chunks:
            # Check cache first
            chunk_hash = hashlib.sha256(chunk["text"].encode()).hexdigest()

            if chunk_hash in self.embedding_cache:
                embeddings.append(self.embedding_cache[chunk_hash])
            else:
                # Mock embedding generation
                # In production: call the embedding API
                embedding = [0.1] * 1024  # Mock 1024-dim embedding (BGE-M3 dense size)
                embeddings.append(embedding)

                # Cache for reuse (with size limit)
                if len(self.embedding_cache) < 1000:
                    self.embedding_cache[chunk_hash] = embedding

        return embeddings
    def _get_file_extension(self, filename: str) -> str:
        """Extract file extension from filename"""
        parts = filename.lower().split(".")
        if len(parts) > 1:
            return f".{parts[-1]}"
        return ".txt"  # Default to text
    async def validate_document(
        self,
        file_size: int,
        filename: str,
        token_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Validate document before processing.

        Args:
            file_size: Size of file in bytes
            filename: Name of the file
            token_data: Capability token data

        Returns:
            Validation result with warnings and errors
        """
        # Get size limits from the token (default: 50 MiB)
        max_size = token_data.get("constraints", {}).get("max_file_size", 50 * 1024 * 1024)

        validation = {
            "valid": True,
            "warnings": [],
            "errors": [],
            "recommendations": []
        }

        # Check file size
        if file_size > max_size:
            validation["valid"] = False
            validation["errors"].append(f"File exceeds maximum size of {max_size / 1024 / 1024:.1f} MiB")
        elif file_size > 10 * 1024 * 1024:
            validation["warnings"].append("Large file may take longer to process")
            validation["recommendations"].append("Consider using streaming processing for better performance")

        # Check file type
        file_type = self._get_file_extension(filename)
        supported_types = [".pdf", ".docx", ".txt", ".md", ".html", ".csv", ".xlsx", ".pptx"]

        if file_type not in supported_types:
            validation["valid"] = False
            validation["errors"].append(f"Unsupported file type: {file_type}")
            validation["recommendations"].append(f"Supported types: {', '.join(supported_types)}")

        # Check for special processing needs
        if file_type in [".xlsx", ".csv"]:
            validation["recommendations"].append("Table extraction will be applied automatically")

        if file_type == ".pdf":
            validation["recommendations"].append("Enable OCR if document contains scanned images")

        return validation
    async def get_processing_stats(self) -> Dict[str, Any]:
        """Get processing statistics"""
        return {
            "engines_available": ["native", "unstructured"],
            "native_engine_status": "ready",
            "unstructured_engine_status": "ready" if self.unstructured_engine else "not_initialized",
            "embedding_cache_size": len(self.embedding_cache),
            "supported_formats": [".pdf", ".docx", ".txt", ".md", ".html", ".csv", ".xlsx", ".pptx"],
            "default_chunk_size": 512,
            "default_chunk_overlap": 128,
            "stateless": True
        }
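End to end, a hedged sketch of driving the pipeline; the token payload and file bytes are illustrative placeholders:

# Illustrative only: one validate-then-process pass through the pipeline.
import asyncio

async def demo():
    pipeline = DocumentProcessingPipeline()
    token = {"capabilities": [{"resource": "premium_parsing"}], "constraints": {}}
    data = b"Some document text..."  # placeholder payload

    check = await pipeline.validate_document(len(data), "notes.txt", token)
    if check["valid"]:
        result = await pipeline.process_document(data, "notes.txt", token)
        print(result.engine_used, len(result.chunks), result.token_count)

asyncio.run(demo())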
447
apps/resource-cluster/app/services/embedding_service.py
Normal file
@@ -0,0 +1,447 @@
"""
Embedding Service for GT 2.0 Resource Cluster

Provides embedding generation with:
- BGE-M3 model integration
- Batch processing capabilities
- Rate limiting and quota management
- Capability-based authentication
- Stateless operation (no data storage)

GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Per-request capability validation
- Zero Downtime: Stateless design, circuit breakers
- Self-Contained Security: Capability-based auth
- No Complexity Addition: Simple interface, no database
"""

import asyncio
import logging
import time
import os
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
import uuid

from app.core.backends.embedding_backend import EmbeddingBackend, EmbeddingRequest
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class EmbeddingResponse:
    """Response structure for embedding generation"""
    request_id: str
    embeddings: List[List[float]]
    model: str
    dimensions: int
    tokens_used: int
    processing_time_ms: int
    tenant_id: str
    created_at: str


@dataclass
class EmbeddingStats:
    """Statistics for embedding requests"""
    total_requests: int = 0
    total_tokens_processed: int = 0
    total_processing_time_ms: int = 0
    average_processing_time_ms: float = 0.0
    last_request_at: Optional[str] = None


class EmbeddingService:
    """
    STATELESS embedding service for the GT 2.0 Resource Cluster.

    Key features:
    - BGE-M3 model for high-quality embeddings
    - Batch processing for efficiency
    - Rate limiting per capability token
    - Memory-conscious processing
    - No persistent storage
    """

    def __init__(self):
        self.backend = EmbeddingBackend()
        self.stats = EmbeddingStats()

        # Initialize the BGE-M3 tokenizer for accurate token counting
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
            logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
        except Exception as e:
            logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
            self.tokenizer = None

        # Rate limiting settings (per capability token)
        self.rate_limits = {
            "requests_per_minute": 60,
            "tokens_per_minute": 50000,
            "max_batch_size": 32
        }

        # Track requests for rate limiting (in-memory, temporary)
        self._request_tracker = {}

        logger.info("STATELESS embedding service initialized")
    async def generate_embeddings(
        self,
        texts: List[str],
        capability_token: str,
        instruction: Optional[str] = None,
        request_id: Optional[str] = None,
        normalize: bool = True
    ) -> EmbeddingResponse:
        """
        Generate embeddings with capability-based authentication.

        Args:
            texts: List of texts to embed
            capability_token: JWT token with embedding permissions
            instruction: Optional instruction for embedding context
            request_id: Optional request ID for tracking
            normalize: Whether to normalize embeddings

        Returns:
            EmbeddingResponse with generated embeddings

        Raises:
            CapabilityError: If the token is invalid or permissions are insufficient
            ValueError: If request parameters are invalid
        """
        start_time = time.time()
        request_id = request_id or str(uuid.uuid4())

        try:
            # Verify capability token and extract permissions
            capability = await verify_capability_token(capability_token)
            tenant_id = capability.get("tenant_id")
            user_id = capability.get("sub")  # Extract user ID from token

            # Check embedding permissions
            await self._verify_embedding_permissions(capability, len(texts))

            # Apply rate limiting
            await self._check_rate_limits(capability_token, len(texts))

            # Validate input
            self._validate_embedding_request(texts)

            # Generate embeddings via the backend
            embeddings = await self.backend.generate_embeddings(
                texts=texts,
                instruction=instruction,
                tenant_id=tenant_id,
                request_id=request_id
            )

            # Calculate processing metrics
            processing_time_ms = int((time.time() - start_time) * 1000)
            total_tokens = self._estimate_tokens(texts)

            # Update statistics
            self._update_stats(total_tokens, processing_time_ms)

            # Log embedding usage for billing (non-blocking).
            # Fire and forget: do not wait for completion
            asyncio.create_task(
                self._log_embedding_usage(
                    tenant_id=tenant_id,
                    user_id=user_id,
                    tokens_used=total_tokens,
                    embedding_count=len(embeddings),
                    model=self.backend.model_name,
                    request_id=request_id
                )
            )

            # Create response
            response = EmbeddingResponse(
                request_id=request_id,
                embeddings=embeddings,
                model=self.backend.model_name,
                dimensions=self.backend.embedding_dimensions,
                tokens_used=total_tokens,
                processing_time_ms=processing_time_ms,
                tenant_id=tenant_id,
                created_at=datetime.utcnow().isoformat()
            )

            logger.info(
                f"Generated {len(embeddings)} embeddings for tenant {tenant_id} "
                f"in {processing_time_ms}ms"
            )

            return response

        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            raise
        finally:
            # Always drop our reference to the input payload
            del texts
    async def get_model_info(self) -> Dict[str, Any]:
        """Get information about the embedding model"""
        return {
            "model_name": self.backend.model_name,
            "dimensions": self.backend.embedding_dimensions,
            "max_sequence_length": self.backend.max_sequence_length,
            "max_batch_size": self.backend.max_batch_size,
            "supports_instruction": True,
            "normalization_default": True
        }

    async def get_service_stats(
        self,
        capability_token: str
    ) -> Dict[str, Any]:
        """
        Get service statistics (for admin users only).

        Args:
            capability_token: JWT token with admin permissions

        Returns:
            Service statistics
        """
        # Verify admin permissions
        capability = await verify_capability_token(capability_token)
        if not self._has_admin_permissions(capability):
            raise CapabilityError("Admin permissions required")

        return {
            "model_info": await self.get_model_info(),
            "statistics": asdict(self.stats),
            "rate_limits": self.rate_limits,
            "active_requests": len(self._request_tracker)
        }

    async def health_check(self) -> Dict[str, Any]:
        """Check service health"""
        return {
            "status": "healthy",
            "service": "embedding_service",
            "model": self.backend.model_name,
            "backend_ready": True,
            "last_request": self.stats.last_request_at
        }
    async def _verify_embedding_permissions(
        self,
        capability: Dict[str, Any],
        text_count: int
    ) -> None:
        """Verify that the capability token has embedding permissions"""
        # Check for an embedding capability
        capabilities = capability.get("capabilities", [])
        embedding_caps = [
            cap for cap in capabilities
            if cap.get("resource") == "embeddings"
        ]

        if not embedding_caps:
            raise CapabilityError("No embedding permissions in capability token")

        # Check constraints
        embedding_cap = embedding_caps[0]  # Use first embedding capability
        constraints = embedding_cap.get("constraints", {})

        # Check batch size limit
        max_batch = constraints.get("max_batch_size", self.rate_limits["max_batch_size"])
        if text_count > max_batch:
            raise CapabilityError(f"Batch size {text_count} exceeds limit {max_batch}")

        # Read per-token rate limits (enforced in _check_rate_limits)
        rate_limit = constraints.get("rate_limit_per_minute", self.rate_limits["requests_per_minute"])
        token_limit = constraints.get("tokens_per_minute", self.rate_limits["tokens_per_minute"])

        logger.debug(f"Embedding permissions verified: batch={text_count}, limits=({rate_limit}, {token_limit})")

    async def _check_rate_limits(
        self,
        capability_token: str,
        text_count: int
    ) -> None:
        """Check rate limits for a capability token"""
        now = time.time()
        token_hash = hash(capability_token) % 10000  # Simple tracking key

        # Drop entries idle for more than one minute
        cleanup_time = now - 60
        self._request_tracker = {
            k: v for k, v in self._request_tracker.items()
            if v.get("last_request", 0) > cleanup_time
        }

        # Get or create a tracker for this token
        if token_hash not in self._request_tracker:
            self._request_tracker[token_hash] = {
                "requests": 0,
                "tokens": 0,
                "window_start": now,
                "last_request": now
            }

        tracker = self._request_tracker[token_hash]

        # Reset counters once the one-minute window has elapsed, so a busy
        # token does not accumulate counts indefinitely
        if now - tracker.get("window_start", now) > 60:
            tracker["requests"] = 0
            tracker["tokens"] = 0
            tracker["window_start"] = now

        # Check the request rate limit
        if tracker["requests"] >= self.rate_limits["requests_per_minute"]:
            raise CapabilityError("Rate limit exceeded: too many requests per minute")

        # Coarse token estimate for the limit check (~1.3 tokens per text at
        # this point; the accurate count is computed after generation)
        estimated_tokens = int(text_count * 1.3)
        if tracker["tokens"] + estimated_tokens > self.rate_limits["tokens_per_minute"]:
            raise CapabilityError("Rate limit exceeded: too many tokens per minute")

        # Update the tracker
        tracker["requests"] += 1
        tracker["tokens"] += estimated_tokens
        tracker["last_request"] = now
    def _validate_embedding_request(self, texts: List[str]) -> None:
        """Validate embedding request parameters"""
        if not texts:
            raise ValueError("No texts provided for embedding")

        if not isinstance(texts, list):
            raise ValueError("Texts must be a list")

        if len(texts) > self.backend.max_batch_size:
            raise ValueError(f"Batch size {len(texts)} exceeds maximum {self.backend.max_batch_size}")

        # Check individual text lengths
        for i, text in enumerate(texts):
            if not isinstance(text, str):
                raise ValueError(f"Text at index {i} must be a string")

            if len(text.strip()) == 0:
                raise ValueError(f"Text at index {i} is empty")

            # Simple token estimation for the length check
            estimated_tokens = len(text.split()) * 1.3  # Rough estimation
            if estimated_tokens > self.backend.max_sequence_length:
                raise ValueError(f"Text at index {i} exceeds maximum length")
    def _estimate_tokens(self, texts: List[str]) -> int:
        """
        Count tokens using the actual BGE-M3 tokenizer.
        Falls back to word-count estimation if the tokenizer is unavailable.
        """
        if self.tokenizer is not None:
            try:
                total_tokens = 0
                for text in texts:
                    tokens = self.tokenizer.encode(text, add_special_tokens=False)
                    total_tokens += len(tokens)
                return total_tokens
            except Exception as e:
                logger.warning(f"Tokenizer error, falling back to estimation: {e}")

        # Fallback: word count * 1.3 (rough estimation)
        total_words = sum(len(text.split()) for text in texts)
        return int(total_words * 1.3)
    def _has_admin_permissions(self, capability: Dict[str, Any]) -> bool:
        """Check if the capability has admin permissions"""
        capabilities = capability.get("capabilities", [])
        return any(
            cap.get("resource") == "admin" and "stats" in cap.get("actions", [])
            for cap in capabilities
        )

    def _update_stats(self, tokens_processed: int, processing_time_ms: int) -> None:
        """Update service statistics"""
        self.stats.total_requests += 1
        self.stats.total_tokens_processed += tokens_processed
        self.stats.total_processing_time_ms += processing_time_ms
        self.stats.average_processing_time_ms = (
            self.stats.total_processing_time_ms / self.stats.total_requests
        )
        self.stats.last_request_at = datetime.utcnow().isoformat()
    async def _log_embedding_usage(
        self,
        tenant_id: str,
        user_id: str,
        tokens_used: int,
        embedding_count: int,
        model: str = "BAAI/bge-m3",
        request_id: Optional[str] = None
    ) -> None:
        """
        Log embedding usage to the control panel database for billing.

        This method logs usage asynchronously and does not block the embedding
        response. Failures are logged as warnings but do not raise exceptions.

        Args:
            tenant_id: Tenant identifier
            user_id: User identifier (from capability token 'sub')
            tokens_used: Number of tokens processed
            embedding_count: Number of embeddings generated
            model: Embedding model name
            request_id: Optional request ID for tracking
        """
        try:
            import asyncpg

            # Calculate cost: BGE-M3 pricing ~$0.10 per million tokens, in cents
            cost_cents = (tokens_used / 1_000_000) * 0.10 * 100

            # Connect to the control panel database using environment
            # variables from docker-compose
            db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
            if not db_password:
                logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
                return

            conn = await asyncpg.connect(
                host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
                database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
                user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
                password=db_password,
                timeout=5.0
            )

            try:
                # Insert the usage log
                await conn.execute("""
                    INSERT INTO public.embedding_usage_logs
                    (tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
                    VALUES ($1, $2, $3, $4, $5, $6, $7)
                """, tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)

                logger.info(
                    f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
                    f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
                )
            finally:
                await conn.close()

        except Exception as e:
            # Log a warning but don't fail the embedding request
            logger.warning(f"Failed to log embedding usage for tenant {tenant_id}: {e}")
# Global service instance
_embedding_service = None


def get_embedding_service() -> EmbeddingService:
    """Get the global embedding service instance"""
    global _embedding_service
    if _embedding_service is None:
        _embedding_service = EmbeddingService()
    return _embedding_service
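A hedged call-site sketch for the service above; the token value and texts are placeholders, and a real call would raise CapabilityError without a valid token:

# Illustrative only: generating a small batch of embeddings.
service = get_embedding_service()

async def embed_demo(capability_token: str):
    resp = await service.generate_embeddings(
        texts=["first passage", "second passage"],
        capability_token=capability_token,  # placeholder JWT
    )
    # resp.dimensions comes from the backend config; tokens_used is counted
    # with the BGE-M3 tokenizer when it loaded successfully
    return resp.embeddings, resp.tokens_used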
729
apps/resource-cluster/app/services/integration_proxy.py
Normal file
@@ -0,0 +1,729 @@
"""
Integration Proxy Service for GT 2.0

Secure proxy service for external integrations with capability-based access control,
sandbox restrictions, and comprehensive audit logging. All external calls are routed
through this service in the Resource Cluster for security and monitoring.
"""

import asyncio
import json
import httpx
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, asdict
from enum import Enum
import logging
from contextlib import asynccontextmanager

from app.core.security import verify_capability_token
from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()
class IntegrationType(Enum):
    """Types of external integrations"""
    COMMUNICATION = "communication"            # Slack, Teams, Discord
    DEVELOPMENT = "development"                # GitHub, GitLab, Jira
    PROJECT_MANAGEMENT = "project_management"  # Asana, Monday.com
    DATABASE = "database"                      # PostgreSQL, MySQL, MongoDB
    CUSTOM_API = "custom_api"                  # Custom REST/GraphQL APIs
    WEBHOOK = "webhook"                        # Outbound webhook calls


class SandboxLevel(Enum):
    """Sandbox restriction levels"""
    NONE = "none"              # No restrictions (trusted)
    BASIC = "basic"            # Basic timeout and size limits
    RESTRICTED = "restricted"  # Limited API calls and data access
    STRICT = "strict"          # Maximum restrictions
@dataclass
class IntegrationConfig:
    """Configuration for an external integration"""
    id: str
    name: str
    integration_type: IntegrationType
    base_url: str
    authentication_method: str  # oauth2, api_key, basic_auth, certificate
    sandbox_level: SandboxLevel

    # Authentication details (encrypted)
    auth_config: Dict[str, Any]

    # Rate limits and constraints
    max_requests_per_hour: int = 1000
    max_response_size_bytes: int = 10 * 1024 * 1024  # 10MB
    timeout_seconds: int = 30

    # Allowed operations
    allowed_methods: Optional[List[str]] = None
    allowed_endpoints: Optional[List[str]] = None
    blocked_endpoints: Optional[List[str]] = None

    # Network restrictions
    allowed_domains: Optional[List[str]] = None

    # Created metadata
    created_at: Optional[datetime] = None
    created_by: str = ""
    is_active: bool = True

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.utcnow()
        if self.allowed_methods is None:
            self.allowed_methods = ["GET", "POST"]
        if self.allowed_endpoints is None:
            self.allowed_endpoints = []
        if self.blocked_endpoints is None:
            self.blocked_endpoints = []
        if self.allowed_domains is None:
            self.allowed_domains = []

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for storage"""
        data = asdict(self)
        data["integration_type"] = self.integration_type.value
        data["sandbox_level"] = self.sandbox_level.value
        data["created_at"] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "IntegrationConfig":
        """Create from dictionary"""
        data["integration_type"] = IntegrationType(data["integration_type"])
        data["sandbox_level"] = SandboxLevel(data["sandbox_level"])
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        return cls(**data)
@dataclass
class ProxyRequest:
    """Request to proxy to an external service"""
    integration_id: str
    method: str
    endpoint: str
    headers: Optional[Dict[str, str]] = None
    data: Optional[Dict[str, Any]] = None
    params: Optional[Dict[str, str]] = None
    timeout_override: Optional[int] = None

    def __post_init__(self):
        if self.headers is None:
            self.headers = {}
        if self.data is None:
            self.data = {}
        if self.params is None:
            self.params = {}


@dataclass
class ProxyResponse:
    """Response from a proxied external service"""
    success: bool
    status_code: int
    data: Optional[Dict[str, Any]]
    headers: Dict[str, str]
    execution_time_ms: int
    sandbox_applied: bool
    restrictions_applied: List[str]
    error_message: Optional[str] = None

    def __post_init__(self):
        if self.headers is None:
            self.headers = {}
        if self.restrictions_applied is None:
            self.restrictions_applied = []
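A quick construction sketch for the request type above; the integration id and endpoint are placeholders:

# Illustrative only: a minimal proxied GET against a configured integration.
req = ProxyRequest(
    integration_id="github-main",          # placeholder id
    method="GET",
    endpoint="/repos/acme/widgets/issues",
    params={"state": "open"},
)
# headers/data/params default to empty dicts via __post_init__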
class SandboxManager:
    """Manages sandbox restrictions for external integrations"""

    def __init__(self):
        self.active_requests: Dict[str, datetime] = {}
        self.rate_limiters: Dict[str, List[datetime]] = {}

    def apply_sandbox_restrictions(
        self,
        config: IntegrationConfig,
        request: ProxyRequest,
        capability_token: Dict[str, Any]
    ) -> Tuple[ProxyRequest, List[str]]:
        """Apply sandbox restrictions to a request"""
        restrictions_applied = []

        if config.sandbox_level == SandboxLevel.NONE:
            return request, restrictions_applied

        # Apply timeout restrictions
        if config.sandbox_level in [SandboxLevel.BASIC, SandboxLevel.RESTRICTED, SandboxLevel.STRICT]:
            max_timeout = self._get_max_timeout(config.sandbox_level)
            if request.timeout_override is None or request.timeout_override > max_timeout:
                request.timeout_override = max_timeout
                restrictions_applied.append(f"timeout_limited_to_{max_timeout}s")

        # Apply endpoint restrictions
        if config.sandbox_level in [SandboxLevel.RESTRICTED, SandboxLevel.STRICT]:
            # Check blocked endpoints first
            if request.endpoint in config.blocked_endpoints:
                raise PermissionError(f"Endpoint {request.endpoint} is blocked")

            # Then check allowed endpoints if specified
            if config.allowed_endpoints and request.endpoint not in config.allowed_endpoints:
                raise PermissionError(f"Endpoint {request.endpoint} not allowed")

            restrictions_applied.append("endpoint_validation")

        # Apply method restrictions
        if config.sandbox_level == SandboxLevel.STRICT:
            allowed_methods = config.allowed_methods or ["GET", "POST"]
            if request.method not in allowed_methods:
                raise PermissionError(f"HTTP method {request.method} not allowed in strict mode")
            restrictions_applied.append("method_restricted")

        # Apply data size restrictions
        if request.data:
            data_size = len(json.dumps(request.data).encode())
            max_size = self._get_max_data_size(config.sandbox_level)
            if data_size > max_size:
                raise ValueError(f"Request data size {data_size} exceeds limit {max_size}")
            restrictions_applied.append("data_size_validated")

        # Apply capability-based restrictions
        constraints = capability_token.get("constraints", {})
        if "integration_timeout_seconds" in constraints:
            max_cap_timeout = constraints["integration_timeout_seconds"]
            if request.timeout_override > max_cap_timeout:
                request.timeout_override = max_cap_timeout
                restrictions_applied.append(f"capability_timeout_{max_cap_timeout}s")

        return request, restrictions_applied

    def _get_max_timeout(self, sandbox_level: SandboxLevel) -> int:
        """Get the maximum timeout for a sandbox level"""
        timeouts = {
            SandboxLevel.BASIC: 60,
            SandboxLevel.RESTRICTED: 30,
            SandboxLevel.STRICT: 15
        }
        return timeouts.get(sandbox_level, 30)

    def _get_max_data_size(self, sandbox_level: SandboxLevel) -> int:
        """Get the maximum data size for a sandbox level"""
        sizes = {
            SandboxLevel.BASIC: 1024 * 1024,      # 1MB
            SandboxLevel.RESTRICTED: 512 * 1024,  # 512KB
            SandboxLevel.STRICT: 256 * 1024       # 256KB
        }
        return sizes.get(sandbox_level, 512 * 1024)

    async def check_rate_limits(self, integration_id: str, config: IntegrationConfig) -> bool:
        """Check whether a request is within the hourly rate limit"""
        now = datetime.utcnow()
        hour_ago = now - timedelta(hours=1)

        # Initialize or clean the rate limiter
        if integration_id not in self.rate_limiters:
            self.rate_limiters[integration_id] = []

        # Remove requests older than one hour (sliding window)
        self.rate_limiters[integration_id] = [
            req_time for req_time in self.rate_limiters[integration_id]
            if req_time > hour_ago
        ]

        # Check the rate limit
        if len(self.rate_limiters[integration_id]) >= config.max_requests_per_hour:
            return False

        # Record this request
        self.rate_limiters[integration_id].append(now)
        return True
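The limiter above keeps a sliding one-hour window of timestamps per integration; a short sketch of the observable behavior, with illustrative config values:

# Illustrative only: with max_requests_per_hour=2, the third call within the
# same hour is rejected; capacity returns as old timestamps age out.
sandbox = SandboxManager()
cfg = IntegrationConfig(
    id="demo", name="Demo", integration_type=IntegrationType.CUSTOM_API,
    base_url="https://api.example.com", authentication_method="api_key",
    sandbox_level=SandboxLevel.BASIC, auth_config={},
    max_requests_per_hour=2,
)

async def limiter_demo():
    assert await sandbox.check_rate_limits("demo", cfg)      # 1st: allowed
    assert await sandbox.check_rate_limits("demo", cfg)      # 2nd: allowed
    assert not await sandbox.check_rate_limits("demo", cfg)  # 3rd: over limit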
class IntegrationProxyService:
    """
    Integration Proxy Service for secure external API access.

    Features:
    - Capability-based access control
    - Sandbox restrictions based on trust level
    - Rate limiting and usage tracking
    - Comprehensive audit logging
    - Response sanitization and size limits
    """

    def __init__(self, base_path: Optional[Path] = None):
        self.base_path = base_path or Path("/data/resource-cluster/integrations")
        self.configs_path = self.base_path / "configs"
        self.usage_path = self.base_path / "usage"
        self.audit_path = self.base_path / "audit"

        self.sandbox_manager = SandboxManager()
        self.http_client = None

        # Ensure directories exist with proper permissions
        self._ensure_directories()

    def _ensure_directories(self):
        """Ensure storage directories exist with proper permissions"""
        for path in [self.configs_path, self.usage_path, self.audit_path]:
            path.mkdir(parents=True, exist_ok=True, mode=0o700)

    @asynccontextmanager
    async def get_http_client(self):
        """Get an HTTP client with proper configuration"""
        if self.http_client is None:
            self.http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0),
                limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
            )
        try:
            yield self.http_client
        finally:
            # Client stays open for reuse
            pass
    async def execute_integration(
        self,
        request: ProxyRequest,
        capability_token: str
    ) -> ProxyResponse:
        """Execute an integration request with security and sandbox restrictions"""
        start_time = datetime.utcnow()

        try:
            # Verify the capability token
            token_obj = verify_capability_token(capability_token)
            if not token_obj:
                raise PermissionError("Invalid capability token")

            # Convert the token object to a dict for compatibility
            token_data = {
                "tenant_id": token_obj.tenant_id,
                "sub": token_obj.sub,
                "capabilities": [cap.dict() if hasattr(cap, 'dict') else cap for cap in token_obj.capabilities],
                "constraints": {}
            }

            # Load the integration configuration
            config = await self._load_integration_config(request.integration_id)
            if not config or not config.is_active:
                raise ValueError(f"Integration {request.integration_id} not found or inactive")

            # Validate the capability for this integration
            required_capability = f"integration:{request.integration_id}:{request.method.lower()}"
            if not self._has_capability(token_data, required_capability):
                raise PermissionError(f"Missing capability: {required_capability}")

            # Check rate limits
            if not await self.sandbox_manager.check_rate_limits(request.integration_id, config):
                raise PermissionError("Rate limit exceeded")

            # Apply sandbox restrictions
            sandboxed_request, restrictions = self.sandbox_manager.apply_sandbox_restrictions(
                config, request, token_data
            )

            # Execute the request
            response = await self._execute_proxied_request(config, sandboxed_request)
            response.sandbox_applied = len(restrictions) > 0
            response.restrictions_applied = restrictions

            # Calculate execution time
            execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
            response.execution_time_ms = int(execution_time)

            # Log usage
            await self._log_usage(
                integration_id=request.integration_id,
                tenant_id=token_data.get("tenant_id"),
                user_id=token_data.get("sub"),
                method=request.method,
                endpoint=request.endpoint,
                success=response.success,
                execution_time_ms=response.execution_time_ms
            )

            # Audit log
            await self._audit_log(
                action="integration_executed",
                integration_id=request.integration_id,
                user_id=token_data.get("sub"),
                details={
                    "method": request.method,
                    "endpoint": request.endpoint,
                    "success": response.success,
                    "restrictions_applied": restrictions
                }
            )

            return response

        except Exception as e:
            logger.error(f"Integration execution failed: {e}")

            # Log the error; token_data may not exist if verification failed
            execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
            await self._log_usage(
                integration_id=request.integration_id,
                tenant_id=token_data.get("tenant_id") if 'token_data' in locals() else "unknown",
                user_id=token_data.get("sub") if 'token_data' in locals() else "unknown",
                method=request.method,
                endpoint=request.endpoint,
                success=False,
                execution_time_ms=int(execution_time),
                error=str(e)
            )

            return ProxyResponse(
                success=False,
                status_code=500,
                data=None,
                headers={},
                execution_time_ms=int(execution_time),
                sandbox_applied=False,
                restrictions_applied=[],
                error_message=str(e)
            )
    async def _execute_proxied_request(
        self,
        config: IntegrationConfig,
        request: ProxyRequest
    ) -> ProxyResponse:
        """Execute the actual HTTP request to the external service"""
        # Build the URL
        if request.endpoint.startswith('http'):
            url = request.endpoint
        else:
            url = f"{config.base_url.rstrip('/')}/{request.endpoint.lstrip('/')}"

        # Apply authentication
        headers = request.headers.copy()
        await self._apply_authentication(config, headers)

        # Set the timeout
        timeout = request.timeout_override or config.timeout_seconds

        try:
            async with self.get_http_client() as client:
                # Execute the request
                if request.method.upper() == "GET":
                    response = await client.get(
                        url,
                        headers=headers,
                        params=request.params,
                        timeout=timeout
                    )
                elif request.method.upper() == "POST":
                    response = await client.post(
                        url,
                        headers=headers,
                        json=request.data,
                        params=request.params,
                        timeout=timeout
                    )
                elif request.method.upper() == "PUT":
                    response = await client.put(
                        url,
                        headers=headers,
                        json=request.data,
                        params=request.params,
                        timeout=timeout
                    )
                elif request.method.upper() == "DELETE":
                    response = await client.delete(
                        url,
                        headers=headers,
                        params=request.params,
                        timeout=timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {request.method}")

                # Check response size
                if len(response.content) > config.max_response_size_bytes:
                    raise ValueError(f"Response size exceeds limit: {len(response.content)}")

                # Parse the response
                try:
                    data = response.json() if response.content else {}
                except json.JSONDecodeError:
                    data = {"raw_content": response.text}

                return ProxyResponse(
                    success=200 <= response.status_code < 300,
                    status_code=response.status_code,
                    data=data,
                    headers=dict(response.headers),
                    execution_time_ms=0,     # Set by caller
                    sandbox_applied=False,   # Set by caller
                    restrictions_applied=[]  # Set by caller
                )

        except httpx.TimeoutException:
            return ProxyResponse(
                success=False,
                status_code=408,
                data=None,
                headers={},
                execution_time_ms=timeout * 1000,
                sandbox_applied=False,
                restrictions_applied=[],
                error_message="Request timeout"
            )
        except Exception as e:
            return ProxyResponse(
                success=False,
                status_code=500,
                data=None,
                headers={},
                execution_time_ms=0,
                sandbox_applied=False,
                restrictions_applied=[],
                error_message=str(e)
            )
    async def _apply_authentication(self, config: IntegrationConfig, headers: Dict[str, str]):
        """Apply authentication to request headers"""
        auth_config = config.auth_config

        if config.authentication_method == "api_key":
            api_key = auth_config.get("api_key")
            key_header = auth_config.get("key_header", "Authorization")
            key_prefix = auth_config.get("key_prefix", "Bearer")

            if api_key:
                headers[key_header] = f"{key_prefix} {api_key}"

        elif config.authentication_method == "basic_auth":
            username = auth_config.get("username")
            password = auth_config.get("password")

            if username and password:
                import base64
                credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
                headers["Authorization"] = f"Basic {credentials}"

        elif config.authentication_method == "oauth2":
            access_token = auth_config.get("access_token")
            if access_token:
                headers["Authorization"] = f"Bearer {access_token}"

        # Add custom headers
        custom_headers = auth_config.get("custom_headers", {})
        headers.update(custom_headers)
    def _has_capability(self, token_data: Dict[str, Any], required_capability: str) -> bool:
        """Check if the token has the required capability"""
        capabilities = token_data.get("capabilities", [])

        for capability in capabilities:
            if isinstance(capability, dict):
                resource = capability.get("resource", "")
                # Exact match first, then wildcard prefix match
                if resource == required_capability:
                    return True
                if resource.endswith("*"):
                    prefix = resource[:-1]  # Remove the *
                    if required_capability.startswith(prefix):
                        return True
            elif isinstance(capability, str):
                # Same matching rules for string capabilities
                if capability == required_capability:
                    return True
                if capability.endswith("*"):
                    prefix = capability[:-1]  # Remove the *
                    if required_capability.startswith(prefix):
                        return True

        return False
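Concretely, the wildcard rule above is a plain prefix match; a few examples with illustrative capability strings (the scratch base_path avoids writing under /data):

# Illustrative only: matching semantics of _has_capability.
svc = IntegrationProxyService(base_path=Path("/tmp/gt2-demo"))  # writable scratch path
token = {"capabilities": ["integration:github-main:*"]}

svc._has_capability(token, "integration:github-main:get")   # True (prefix match)
svc._has_capability(token, "integration:github-main:post")  # True (prefix match)
svc._has_capability(token, "integration:jira-main:get")     # False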
    async def _load_integration_config(self, integration_id: str) -> Optional[IntegrationConfig]:
        """Load integration configuration from storage"""
        config_file = self.configs_path / f"{integration_id}.json"

        if not config_file.exists():
            return None

        try:
            with open(config_file, "r") as f:
                data = json.load(f)
            return IntegrationConfig.from_dict(data)
        except Exception as e:
            logger.error(f"Failed to load integration config {integration_id}: {e}")
            return None

    async def store_integration_config(self, config: IntegrationConfig) -> bool:
        """Store integration configuration"""
        config_file = self.configs_path / f"{config.id}.json"

        try:
            with open(config_file, "w") as f:
                json.dump(config.to_dict(), f, indent=2)

            # Set secure permissions
            config_file.chmod(0o600)
            return True

        except Exception as e:
            logger.error(f"Failed to store integration config {config.id}: {e}")
            return False
    async def _log_usage(
        self,
        integration_id: str,
        tenant_id: str,
        user_id: str,
        method: str,
        endpoint: str,
        success: bool,
        execution_time_ms: int,
        error: Optional[str] = None
    ):
        """Log integration usage for analytics"""
        date_str = datetime.utcnow().strftime("%Y-%m-%d")
        usage_file = self.usage_path / f"usage_{date_str}.jsonl"

        usage_record = {
            "timestamp": datetime.utcnow().isoformat(),
            "integration_id": integration_id,
            "tenant_id": tenant_id,
            "user_id": user_id,
            "method": method,
            "endpoint": endpoint,
            "success": success,
            "execution_time_ms": execution_time_ms,
            "error": error
        }
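
        # Each JSONL line is one self-contained record, e.g. (hypothetical
        # values):
        # {"timestamp": "2025-01-15T10:30:00", "integration_id": "crm",
        #  "tenant_id": "acme", "user_id": "u1", "method": "GET",
        #  "endpoint": "/contacts", "success": true,
        #  "execution_time_ms": 42, "error": null}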

        try:
            with open(usage_file, "a") as f:
                f.write(json.dumps(usage_record) + "\n")

            # Set secure permissions on file
            usage_file.chmod(0o600)

        except Exception as e:
            logger.error(f"Failed to log usage: {e}")

    async def _audit_log(
        self,
        action: str,
        integration_id: str,
        user_id: str,
        details: Dict[str, Any]
    ):
        """Log audit trail for integration actions"""
        date_str = datetime.utcnow().strftime("%Y-%m-%d")
        audit_file = self.audit_path / f"audit_{date_str}.jsonl"

        audit_record = {
            "timestamp": datetime.utcnow().isoformat(),
            "action": action,
            "integration_id": integration_id,
            "user_id": user_id,
            "details": details
        }

        try:
            with open(audit_file, "a") as f:
                f.write(json.dumps(audit_record) + "\n")

            # Set secure permissions on file
            audit_file.chmod(0o600)

        except Exception as e:
            logger.error(f"Failed to log audit: {e}")

    async def list_integrations(self, capability_token: str) -> List[IntegrationConfig]:
        """List available integrations based on capabilities"""
        token_obj = verify_capability_token(capability_token)
        if not token_obj:
            raise PermissionError("Invalid capability token")

        # Convert token object to dict for compatibility
        token_data = {
            "tenant_id": token_obj.tenant_id,
            "sub": token_obj.sub,
            "capabilities": [cap.dict() if hasattr(cap, 'dict') else cap for cap in token_obj.capabilities],
            "constraints": {}
        }

        integrations = []

        for config_file in self.configs_path.glob("*.json"):
            try:
                with open(config_file, "r") as f:
                    data = json.load(f)
                config = IntegrationConfig.from_dict(data)

                # Check if user has capability for this integration
                required_capability = f"integration:{config.id}:*"
                if self._has_capability(token_data, required_capability):
                    integrations.append(config)

            except Exception as e:
                logger.warning(f"Failed to load integration config {config_file}: {e}")

        return integrations

    async def get_integration_usage_analytics(
        self,
        integration_id: str,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get usage analytics for integration"""
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days - 1)  # Include today in the range
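        # Example window (hypothetical date): with days=30 and today being
        # 2025-01-30, the range spans 2025-01-01 through 2025-01-30 inclusive,
        # so exactly 30 daily usage_*.jsonl files are scanned below.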

        total_requests = 0
        successful_requests = 0
        total_execution_time = 0
        error_count = 0

        # Process usage logs
        for day_offset in range(days):
            date = start_date + timedelta(days=day_offset)
            date_str = date.strftime("%Y-%m-%d")
            usage_file = self.usage_path / f"usage_{date_str}.jsonl"

            if usage_file.exists():
                try:
                    with open(usage_file, "r") as f:
                        for line in f:
                            record = json.loads(line.strip())
                            if record["integration_id"] == integration_id:
                                total_requests += 1
                                if record["success"]:
                                    successful_requests += 1
                                else:
                                    error_count += 1
                                total_execution_time += record["execution_time_ms"]
                except Exception as e:
                    logger.warning(f"Failed to process usage file {usage_file}: {e}")

        return {
            "integration_id": integration_id,
            "total_requests": total_requests,
            "successful_requests": successful_requests,
            "error_count": error_count,
            "success_rate": successful_requests / total_requests if total_requests > 0 else 0,
            "avg_execution_time_ms": total_execution_time / total_requests if total_requests > 0 else 0,
            "date_range": {
                "start": start_date.isoformat(),
                "end": end_date.isoformat()
            }
        }

    async def close(self):
        """Close HTTP client and cleanup resources"""
        if self.http_client:
            await self.http_client.aclose()
            self.http_client = None
925
apps/resource-cluster/app/services/llm_gateway.py
Normal file
@@ -0,0 +1,925 @@
"""
LLM Gateway Service for GT 2.0 Resource Cluster

Provides unified access to LLM providers with:
- Groq Cloud integration for fast inference
- OpenAI API compatibility
- Rate limiting and quota management
- Capability-based authentication
- Model routing and load balancing
- Response streaming support

GT 2.0 Architecture Principles:
- Stateless: No persistent connections or state
- Zero downtime: Circuit breakers and failover
- Self-contained: No external configuration dependencies
- Capability-based: JWT token authorization
"""

import asyncio
import logging
import json
import time
from typing import Dict, Any, List, Optional, AsyncGenerator, Union
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass, asdict
import uuid
import httpx
from enum import Enum
from urllib.parse import urlparse

from app.core.config import get_settings
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.services.admin_model_config_service import get_admin_model_service, AdminModelConfig


def is_provider_endpoint(endpoint_url: str, provider_domains: List[str]) -> bool:
    """
    Safely check if URL belongs to a specific provider.

    Uses proper URL parsing to prevent bypass via URLs like
    'evil.groq.com.attacker.com' or 'groq.com.evil.com'.
    """
    try:
        parsed = urlparse(endpoint_url)
        hostname = (parsed.hostname or "").lower()
        for domain in provider_domains:
            domain = domain.lower()
            # Match exact domain or subdomain (e.g., api.groq.com matches groq.com)
            if hostname == domain or hostname.endswith(f".{domain}"):
                return True
        return False
    except Exception:
        return False
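
# Illustrative checks (hypothetical URLs) showing why hostname parsing, not
# substring matching, is required here:
#   is_provider_endpoint("https://api.groq.com/openai/v1", ["groq.com"])  -> True
#   is_provider_endpoint("https://groq.com.evil.com/v1", ["groq.com"])    -> False
#   is_provider_endpoint("https://evilgroq.com/v1", ["groq.com"])         -> False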


logger = logging.getLogger(__name__)
settings = get_settings()

class ModelProvider(str, Enum):
    """Supported LLM providers"""
    GROQ = "groq"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    NVIDIA = "nvidia"
    LOCAL = "local"


class ModelCapability(str, Enum):
    """Model capabilities for routing"""
    CHAT = "chat"
    COMPLETION = "completion"
    EMBEDDING = "embedding"
    FUNCTION_CALLING = "function_calling"
    VISION = "vision"
    CODE = "code"

@dataclass
class ModelConfig:
    """Model configuration and capabilities"""
    model_id: str
    provider: ModelProvider
    capabilities: List[ModelCapability]
    max_tokens: int
    context_window: int
    cost_per_token: float
    rate_limit_rpm: int
    supports_streaming: bool
    supports_functions: bool
    is_available: bool = True

@dataclass
class LLMRequest:
    """Standardized LLM request format"""
    model: str
    messages: List[Dict[str, str]]
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = False
    functions: Optional[List[Dict[str, Any]]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None
    user: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API calls"""
        result = asdict(self)
        # Remove None values
        return {k: v for k, v in result.items() if v is not None}
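
    # Minimal construction sketch (hypothetical values): only fields that are
    # actually set survive to_dict(), keeping provider payloads clean.
    #   req = LLMRequest(model="llama3-8b-8192",
    #                    messages=[{"role": "user", "content": "Hi"}])
    #   req.to_dict()  # -> {"model": ..., "messages": ..., "stream": False}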


@dataclass
class LLMResponse:
    """Standardized LLM response format"""
    id: str
    object: str
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]
    provider: str
    request_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API responses"""
        return asdict(self)


class LLMGateway:
    """
    LLM Gateway with unified API and multi-provider support.

    Provides OpenAI-compatible API while routing to optimal providers
    based on model capabilities, availability, and cost.
    """

    def __init__(self):
        self.settings = get_settings()
        self.http_client = httpx.AsyncClient(timeout=120.0)
        self.admin_service = get_admin_model_service()

        # Static model registry; chat_completion() and the stats/health
        # methods below read from self.models, so it must be built here.
        self.models = self._initialize_model_configs()

        # Rate limiting tracking
        self.rate_limits: Dict[str, Dict[str, Any]] = {}

        # Provider health tracking
        self.provider_health: Dict[ModelProvider, bool] = {
            provider: True for provider in ModelProvider
        }

        # Request statistics
        self.stats = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "provider_usage": {provider.value: 0 for provider in ModelProvider},
            "model_usage": {},
            "average_latency": 0.0
        }

        logger.info("LLM Gateway initialized with admin-configured models")

    # NOTE: Python keeps only the last definition of a method name; this
    # admin-backed get_available_models() is shadowed by the capability-listing
    # variant defined later in this class.
    async def get_available_models(self, tenant_id: Optional[str] = None) -> List[AdminModelConfig]:
        """Get available models, optionally filtered by tenant"""
        if tenant_id:
            return await self.admin_service.get_tenant_models(tenant_id)
        else:
            return await self.admin_service.get_all_models(active_only=True)

    async def get_model_config(self, model_id: str, tenant_id: Optional[str] = None) -> Optional[AdminModelConfig]:
        """Get configuration for a specific model"""
        config = await self.admin_service.get_model_config(model_id)

        # Check tenant access if tenant_id provided
        if config and tenant_id:
            has_access = await self.admin_service.check_tenant_access(tenant_id, model_id)
            if not has_access:
                return None

        return config

    async def get_groq_api_key(self) -> Optional[str]:
        """Get Groq API key from admin service"""
        return await self.admin_service.get_groq_api_key()

    def _initialize_model_configs(self) -> Dict[str, ModelConfig]:
        """Initialize supported model configurations"""
        models = {}

        # Groq models (fast inference)
        groq_models = [
            ModelConfig(
                model_id="llama3-8b-8192",
                provider=ModelProvider.GROQ,
                capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
                max_tokens=8192,
                context_window=8192,
                cost_per_token=0.00001,
                rate_limit_rpm=30,
                supports_streaming=True,
                supports_functions=False
            ),
            ModelConfig(
                model_id="llama3-70b-8192",
                provider=ModelProvider.GROQ,
                capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
                max_tokens=8192,
                context_window=8192,
                cost_per_token=0.00008,
                rate_limit_rpm=15,
                supports_streaming=True,
                supports_functions=False
            ),
            ModelConfig(
                model_id="mixtral-8x7b-32768",
                provider=ModelProvider.GROQ,
                capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
                max_tokens=32768,
                context_window=32768,
                cost_per_token=0.00005,
                rate_limit_rpm=20,
                supports_streaming=True,
                supports_functions=False
            ),
            ModelConfig(
                model_id="gemma-7b-it",
                provider=ModelProvider.GROQ,
                capabilities=[ModelCapability.CHAT],
                max_tokens=8192,
                context_window=8192,
                cost_per_token=0.00001,
                rate_limit_rpm=30,
                supports_streaming=True,
                supports_functions=False
            )
        ]

        # OpenAI models (function calling, embeddings)
        openai_models = [
            ModelConfig(
                model_id="gpt-4-turbo-preview",
                provider=ModelProvider.OPENAI,
                capabilities=[ModelCapability.CHAT, ModelCapability.FUNCTION_CALLING, ModelCapability.VISION],
                max_tokens=4096,
                context_window=128000,
                cost_per_token=0.00003,
                rate_limit_rpm=10,
                supports_streaming=True,
                supports_functions=True
            ),
            ModelConfig(
                model_id="gpt-3.5-turbo",
                provider=ModelProvider.OPENAI,
                capabilities=[ModelCapability.CHAT, ModelCapability.FUNCTION_CALLING],
                max_tokens=4096,
                context_window=16385,
                cost_per_token=0.000002,
                rate_limit_rpm=60,
                supports_streaming=True,
                supports_functions=True
            ),
            ModelConfig(
                model_id="text-embedding-3-small",
                provider=ModelProvider.OPENAI,
                capabilities=[ModelCapability.EMBEDDING],
                max_tokens=8191,
                context_window=8191,
                cost_per_token=0.00000002,
                rate_limit_rpm=3000,
                supports_streaming=False,
                supports_functions=False
            )
        ]

        # Add all models to registry
        for model_list in [groq_models, openai_models]:
            for model in model_list:
                models[model.model_id] = model

        return models
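
    # Registry lookup sketch: self.models["mixtral-8x7b-32768"].context_window
    # -> 32768; unknown model IDs return None via self.models.get(...) and are
    # rejected in chat_completion() below.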

    async def chat_completion(
        self,
        request: LLMRequest,
        capability_token: str,
        user_id: str,
        tenant_id: str
    ) -> Union[LLMResponse, AsyncGenerator[str, None]]:
        """
        Process chat completion request with capability validation.

        Args:
            request: LLM request parameters
            capability_token: JWT capability token
            user_id: User identifier for rate limiting
            tenant_id: Tenant identifier for isolation

        Returns:
            LLM response or streaming generator
        """
        start_time = time.time()
        request_id = str(uuid.uuid4())

        try:
            # Verify capabilities
            await self._verify_llm_capability(capability_token, request.model, user_id, tenant_id)

            # Validate model availability
            model_config = self.models.get(request.model)
            if not model_config:
                raise ValueError(f"Model {request.model} not supported")

            if not model_config.is_available:
                raise ValueError(f"Model {request.model} is currently unavailable")

            # Check rate limits
            await self._check_rate_limits(user_id, model_config)

            # Route to configured endpoint (generic routing for any provider)
            if hasattr(model_config, 'endpoint') and model_config.endpoint:
                result = await self._process_generic_request(request, request_id, model_config, tenant_id)
            elif model_config.provider == ModelProvider.GROQ:
                result = await self._process_groq_request(request, request_id, model_config, tenant_id)
            elif model_config.provider == ModelProvider.OPENAI:
                result = await self._process_openai_request(request, request_id, model_config)
            else:
                raise ValueError(f"Provider {model_config.provider} not implemented - ensure endpoint is configured")

            # Update statistics
            latency = time.time() - start_time
            await self._update_stats(request.model, model_config.provider, latency, True)

            logger.info(f"LLM request completed: {request_id} ({latency:.3f}s)")
            return result

        except Exception as e:
            latency = time.time() - start_time
            # Provider may be unresolved at failure time; stats default to GROQ
            await self._update_stats(request.model, ModelProvider.GROQ, latency, False)

            logger.error(f"LLM request failed: {request_id} - {e}")
            raise

    async def _verify_llm_capability(
        self,
        capability_token: str,
        model: str,
        user_id: str,
        tenant_id: str
    ) -> None:
        """Verify user has capability to use specific model"""
        try:
            payload = await verify_capability_token(capability_token)

            # Check tenant match
            if payload.get("tenant_id") != tenant_id:
                raise CapabilityError("Tenant mismatch in capability token")

            # Find LLM capability (match "llm" or "llm:provider" format)
            capabilities = payload.get("capabilities", [])
            llm_capability = None

            for cap in capabilities:
                resource = cap.get("resource", "")
                if resource == "llm" or resource.startswith("llm:"):
                    llm_capability = cap
                    break

            if not llm_capability:
                raise CapabilityError("No LLM capability found in token")

            # Check model access
            allowed_models = llm_capability.get("constraints", {}).get("allowed_models", [])
            if allowed_models and model not in allowed_models:
                raise CapabilityError(f"Model {model} not allowed in capability")

            # Check rate limits (per-minute window)
            max_requests_per_minute = llm_capability.get("constraints", {}).get("max_requests_per_minute")
            if max_requests_per_minute:
                await self._check_user_rate_limit(user_id, max_requests_per_minute)

        except CapabilityError:
            raise
        except Exception as e:
            raise CapabilityError(f"Capability verification failed: {e}")

    async def _check_rate_limits(self, user_id: str, model_config: ModelConfig) -> None:
        """Check if user is within rate limits for model"""
        now = time.time()
        minute_ago = now - 60

        # Initialize user rate limit tracking
        if user_id not in self.rate_limits:
            self.rate_limits[user_id] = {}

        if model_config.model_id not in self.rate_limits[user_id]:
            self.rate_limits[user_id][model_config.model_id] = []

        user_requests = self.rate_limits[user_id][model_config.model_id]

        # Remove old requests
        user_requests[:] = [req_time for req_time in user_requests if req_time > minute_ago]

        # Check limit
        if len(user_requests) >= model_config.rate_limit_rpm:
            raise ValueError(f"Rate limit exceeded for model {model_config.model_id}")

        # Add current request
        user_requests.append(now)

    async def _check_user_rate_limit(self, user_id: str, max_requests_per_minute: int) -> None:
        """
        Check user's rate limit with per-minute enforcement window.

        Enforces limits from Control Panel database (single source of truth).
        Time window: 60 seconds (not 1 hour).

        Args:
            user_id: User identifier
            max_requests_per_minute: Maximum requests allowed in 60-second window

        Raises:
            ValueError: If rate limit exceeded
        """
        now = time.time()
        minute_ago = now - 60  # 60-second window (was 3600 for hour)

        if user_id not in self.rate_limits:
            self.rate_limits[user_id] = {}

        if "total_requests" not in self.rate_limits[user_id]:
            self.rate_limits[user_id]["total_requests"] = []

        total_requests = self.rate_limits[user_id]["total_requests"]

        # Remove requests outside the 60-second window
        total_requests[:] = [req_time for req_time in total_requests if req_time > minute_ago]

        # Check limit
        if len(total_requests) >= max_requests_per_minute:
            raise ValueError(
                f"Rate limit exceeded: {max_requests_per_minute} requests per minute. "
                f"Try again in {int(60 - (now - total_requests[0]))} seconds."
            )

        # Add current request
        total_requests.append(now)
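
    # Sliding-window sketch (hypothetical numbers): with a limit of 60/min,
    # a user who made 60 calls in the last 45 seconds is rejected; the oldest
    # timestamp expires ~15 seconds later, which is the retry hint above.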

    async def _process_groq_request(
        self,
        request: LLMRequest,
        request_id: str,
        model_config: ModelConfig,
        tenant_id: str
    ) -> Union[LLMResponse, AsyncGenerator[str, None]]:
        """
        Process request using Groq API with tenant-specific API key.

        API keys are fetched from Control Panel database - NO environment variable fallback.
        """
        try:
            # Get API key from Control Panel database (NO env fallback)
            api_key = await self._get_tenant_api_key(tenant_id)

            # Prepare Groq API request
            groq_request = {
                "model": request.model,
                "messages": request.messages,
                "max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
                "temperature": request.temperature or 0.7,
                "top_p": request.top_p or 1.0,
                "stream": request.stream
            }

            if request.stop:
                groq_request["stop"] = request.stop

            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            }

            if request.stream:
                return self._stream_groq_response(groq_request, headers, request_id)
            else:
                return await self._get_groq_response(groq_request, headers, request_id)

        except Exception as e:
            logger.error(f"Groq API request failed: {e}")
            raise ValueError(f"Groq API error: {e}")

    async def _get_tenant_api_key(self, tenant_id: str) -> str:
        """
        Get API key for tenant from Control Panel database.

        NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
        """
        from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError

        client = get_api_key_client()

        try:
            key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
            return key_info["api_key"]
        except APIKeyNotConfiguredError as e:
            logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
            raise ValueError(f"No Groq API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
        except RuntimeError as e:
            logger.error(f"Control Panel error: {e}")
            raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")

    async def _get_tenant_nvidia_api_key(self, tenant_id: str) -> str:
        """
        Get NVIDIA NIM API key for tenant from Control Panel database.

        NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
        """
        from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError

        client = get_api_key_client()

        try:
            key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
            return key_info["api_key"]
        except APIKeyNotConfiguredError as e:
            logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
            raise ValueError(f"No NVIDIA API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
        except RuntimeError as e:
            logger.error(f"Control Panel error: {e}")
            raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")

    async def _get_groq_response(
        self,
        groq_request: Dict[str, Any],
        headers: Dict[str, str],
        request_id: str
    ) -> LLMResponse:
        """Get non-streaming response from Groq"""
        try:
            response = await self.http_client.post(
                "https://api.groq.com/openai/v1/chat/completions",
                json=groq_request,
                headers=headers
            )
            response.raise_for_status()

            data = response.json()

            # Convert to standardized format
            return LLMResponse(
                id=data.get("id", request_id),
                object=data.get("object", "chat.completion"),
                created=data.get("created", int(time.time())),
                model=data.get("model", groq_request["model"]),
                choices=data.get("choices", []),
                usage=data.get("usage", {}),
                provider="groq",
                request_id=request_id
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"Groq API HTTP error: {e.response.status_code} - {e.response.text}")
            raise ValueError(f"Groq API error: {e.response.status_code}")
        except Exception as e:
            logger.error(f"Groq API error: {e}")
            raise ValueError(f"Groq API request failed: {e}")

    async def _stream_groq_response(
        self,
        groq_request: Dict[str, Any],
        headers: Dict[str, str],
        request_id: str
    ) -> AsyncGenerator[str, None]:
        """Stream response from Groq"""
        try:
            async with self.http_client.stream(
                "POST",
                "https://api.groq.com/openai/v1/chat/completions",
                json=groq_request,
                headers=headers
            ) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data_str = line[6:]  # Remove "data: " prefix

                        if data_str.strip() == "[DONE]":
                            break

                        try:
                            data = json.loads(data_str)
                            # Add provider and request_id to chunk
                            data["provider"] = "groq"
                            data["request_id"] = request_id
                            yield f"data: {json.dumps(data)}\n\n"
                        except json.JSONDecodeError:
                            continue

                yield "data: [DONE]\n\n"

        except httpx.HTTPStatusError as e:
            logger.error(f"Groq streaming error: {e.response.status_code}")
            yield f"data: {json.dumps({'error': f'Groq API error: {e.response.status_code}'})}\n\n"
        except Exception as e:
            logger.error(f"Groq streaming error: {e}")
            yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"

    async def _process_generic_request(
        self,
        request: LLMRequest,
        request_id: str,
        model_config: Any,
        tenant_id: str
    ) -> Union[LLMResponse, AsyncGenerator[str, None]]:
        """
        Process request using generic endpoint (OpenAI-compatible).

        For Groq endpoints, API keys are fetched from Control Panel database.
        """
        try:
            # Build OpenAI-compatible request
            generic_request = {
                "model": request.model,
                "messages": request.messages,
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "stream": request.stream
            }

            # Add optional parameters
            if hasattr(request, 'tools') and request.tools:
                generic_request["tools"] = request.tools
            if hasattr(request, 'tool_choice') and request.tool_choice:
                generic_request["tool_choice"] = request.tool_choice

            headers = {"Content-Type": "application/json"}

            endpoint_url = model_config.endpoint

            # For Groq endpoints, use tenant-specific API key from Control Panel DB
            if is_provider_endpoint(endpoint_url, ["groq.com"]):
                api_key = await self._get_tenant_api_key(tenant_id)
                headers["Authorization"] = f"Bearer {api_key}"
            # For NVIDIA NIM endpoints, use tenant-specific API key from Control Panel DB
            elif is_provider_endpoint(endpoint_url, ["nvidia.com", "integrate.api.nvidia.com"]):
                api_key = await self._get_tenant_nvidia_api_key(tenant_id)
                headers["Authorization"] = f"Bearer {api_key}"
            # For other endpoints, use model_config.api_key if configured
            elif hasattr(model_config, 'api_key') and model_config.api_key:
                headers["Authorization"] = f"Bearer {model_config.api_key}"

            logger.info(f"Sending request to generic endpoint: {endpoint_url}")

            if request.stream:
                # Async generators are returned directly, not awaited
                return self._stream_generic_response(generic_request, headers, endpoint_url, request_id, model_config)
            else:
                return await self._get_generic_response(generic_request, headers, endpoint_url, request_id, model_config)

        except Exception as e:
            logger.error(f"Generic request processing failed: {e}")
            raise ValueError(f"Generic inference failed: {e}")

    async def _get_generic_response(
        self,
        generic_request: Dict[str, Any],
        headers: Dict[str, str],
        endpoint_url: str,
        request_id: str,
        model_config: Any
    ) -> LLMResponse:
        """Get non-streaming response from generic endpoint"""
        try:
            response = await self.http_client.post(
                endpoint_url,
                json=generic_request,
                headers=headers
            )
            response.raise_for_status()

            data = response.json()

            # Convert to standardized format
            return LLMResponse(
                id=data.get("id", request_id),
                object=data.get("object", "chat.completion"),
                created=data.get("created", int(time.time())),
                model=data.get("model", generic_request["model"]),
                choices=data.get("choices", []),
                usage=data.get("usage", {}),
                provider=getattr(model_config, 'provider', 'generic'),
                request_id=request_id
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"Generic API HTTP error: {e.response.status_code} - {e.response.text}")
            raise ValueError(f"Generic API error: {e.response.status_code}")
        except Exception as e:
            logger.error(f"Generic response error: {e}")
            raise ValueError(f"Generic response processing failed: {e}")

    async def _stream_generic_response(
        self,
        generic_request: Dict[str, Any],
        headers: Dict[str, str],
        endpoint_url: str,
        request_id: str,
        model_config: Any
    ) -> AsyncGenerator[str, None]:
        """Stream response from generic endpoint"""
        try:
            # For now, just do a non-streaming request and convert to streaming format
            # This can be enhanced to support actual streaming later
            response = await self._get_generic_response(generic_request, headers, endpoint_url, request_id, model_config)

            # Convert to streaming format
            if response.choices and len(response.choices) > 0:
                content = response.choices[0].get("message", {}).get("content", "")
                yield f"data: {json.dumps({'choices': [{'delta': {'content': content}}]})}\n\n"
            yield "data: [DONE]\n\n"

        except Exception as e:
            logger.error(f"Generic streaming error: {e}")
            yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"

    async def _process_openai_request(
        self,
        request: LLMRequest,
        request_id: str,
        model_config: ModelConfig
    ) -> Union[LLMResponse, AsyncGenerator[str, None]]:
        """Process request using OpenAI API"""
        try:
            # Prepare OpenAI API request
            openai_request = {
                "model": request.model,
                "messages": request.messages,
                "max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
                "temperature": request.temperature or 0.7,
                "top_p": request.top_p or 1.0,
                "stream": request.stream
            }

            if request.stop:
                openai_request["stop"] = request.stop

            headers = {
                "Authorization": f"Bearer {settings.openai_api_key}",
                "Content-Type": "application/json"
            }

            if request.stream:
                return self._stream_openai_response(openai_request, headers, request_id)
            else:
                return await self._get_openai_response(openai_request, headers, request_id)

        except Exception as e:
            logger.error(f"OpenAI API request failed: {e}")
            raise ValueError(f"OpenAI API error: {e}")

    async def _get_openai_response(
        self,
        openai_request: Dict[str, Any],
        headers: Dict[str, str],
        request_id: str
    ) -> LLMResponse:
        """Get non-streaming response from OpenAI"""
        try:
            response = await self.http_client.post(
                "https://api.openai.com/v1/chat/completions",
                json=openai_request,
                headers=headers
            )
            response.raise_for_status()

            data = response.json()

            # Convert to standardized format
            return LLMResponse(
                id=data.get("id", request_id),
                object=data.get("object", "chat.completion"),
                created=data.get("created", int(time.time())),
                model=data.get("model", openai_request["model"]),
                choices=data.get("choices", []),
                usage=data.get("usage", {}),
                provider="openai",
                request_id=request_id
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenAI API HTTP error: {e.response.status_code} - {e.response.text}")
            raise ValueError(f"OpenAI API error: {e.response.status_code}")
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise ValueError(f"OpenAI API request failed: {e}")

    async def _stream_openai_response(
        self,
        openai_request: Dict[str, Any],
        headers: Dict[str, str],
        request_id: str
    ) -> AsyncGenerator[str, None]:
        """Stream response from OpenAI"""
        try:
            async with self.http_client.stream(
                "POST",
                "https://api.openai.com/v1/chat/completions",
                json=openai_request,
                headers=headers
            ) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data_str = line[6:]  # Remove "data: " prefix

                        if data_str.strip() == "[DONE]":
                            break

                        try:
                            data = json.loads(data_str)
                            # Add provider and request_id to chunk
                            data["provider"] = "openai"
                            data["request_id"] = request_id
                            yield f"data: {json.dumps(data)}\n\n"
                        except json.JSONDecodeError:
                            continue

                yield "data: [DONE]\n\n"

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenAI streaming error: {e.response.status_code}")
            yield f"data: {json.dumps({'error': f'OpenAI API error: {e.response.status_code}'})}\n\n"
        except Exception as e:
            logger.error(f"OpenAI streaming error: {e}")
            yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"

    async def _update_stats(
        self,
        model: str,
        provider: ModelProvider,
        latency: float,
        success: bool
    ) -> None:
        """Update request statistics"""
        self.stats["total_requests"] += 1

        if success:
            self.stats["successful_requests"] += 1
        else:
            self.stats["failed_requests"] += 1

        self.stats["provider_usage"][provider.value] += 1

        if model not in self.stats["model_usage"]:
            self.stats["model_usage"][model] = 0
        self.stats["model_usage"][model] += 1

        # Update rolling average latency
        total_requests = self.stats["total_requests"]
        current_avg = self.stats["average_latency"]
        self.stats["average_latency"] = ((current_avg * (total_requests - 1)) + latency) / total_requests

    async def get_available_models(self) -> List[Dict[str, Any]]:
        """Get list of available models with capabilities"""
        models = []

        for model_id, config in self.models.items():
            if config.is_available:
                models.append({
                    "id": model_id,
                    "provider": config.provider.value,
                    "capabilities": [cap.value for cap in config.capabilities],
                    "max_tokens": config.max_tokens,
                    "context_window": config.context_window,
                    "supports_streaming": config.supports_streaming,
                    "supports_functions": config.supports_functions
                })

        return models

    async def get_gateway_stats(self) -> Dict[str, Any]:
        """Get gateway statistics"""
        return {
            **self.stats,
            "provider_health": {
                provider.value: health
                for provider, health in self.provider_health.items()
            },
            "active_models": len([m for m in self.models.values() if m.is_available]),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }

    async def health_check(self) -> Dict[str, Any]:
        """Health check for the LLM gateway"""
        healthy_providers = sum(1 for health in self.provider_health.values() if health)
        total_providers = len(self.provider_health)

        return {
            "status": "healthy" if healthy_providers > 0 else "degraded",
            "providers_healthy": healthy_providers,
            "total_providers": total_providers,
            "available_models": len([m for m in self.models.values() if m.is_available]),
            "total_requests": self.stats["total_requests"],
            "success_rate": (
                self.stats["successful_requests"] / max(self.stats["total_requests"], 1)
            ) * 100,
            "average_latency_ms": self.stats["average_latency"] * 1000
        }

    async def close(self):
        """Close HTTP client and cleanup resources"""
        await self.http_client.aclose()


# Global gateway instance
llm_gateway = LLMGateway()


# Factory function for dependency injection
def get_llm_gateway() -> LLMGateway:
    """Get LLM gateway instance"""
    return llm_gateway
599
apps/resource-cluster/app/services/mcp_rag_server.py
Normal file
@@ -0,0 +1,599 @@
"""
GT 2.0 MCP RAG Server

Provides RAG (Retrieval-Augmented Generation) capabilities as an MCP server.
Agents can use this server to search datasets, query documents, and retrieve
relevant context for user queries.

Tools provided:
- search_datasets: Search across user's accessible datasets
- query_documents: Query specific documents for relevant chunks
- get_relevant_chunks: Get relevant text chunks based on similarity
- list_user_datasets: List all datasets accessible to the user
- get_dataset_info: Get detailed information about a dataset

Note: only search_datasets is currently registered in available_tools; the
remaining handlers are retained for future registration.
"""

import asyncio
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from dataclasses import dataclass
import httpx
import json

from app.core.security import verify_capability_token
from app.services.mcp_server import MCPServerResource, MCPServerConfig

logger = logging.getLogger(__name__)


@dataclass
class RAGSearchParams:
    """Parameters for RAG search operations"""
    query: str
    dataset_ids: Optional[List[str]] = None
    search_method: str = "hybrid"  # hybrid, vector, text
    max_results: int = 10
    similarity_threshold: float = 0.7
    include_metadata: bool = True


@dataclass
class RAGSearchResult:
    """Result from RAG search operation"""
    chunk_id: str
    document_id: str
    dataset_id: str
    dataset_name: str
    document_name: str
    content: str
    similarity_score: float
    chunk_index: int
    metadata: Dict[str, Any]


class MCPRAGServer:
    """
    MCP server for RAG operations in GT 2.0.

    Provides secure, tenant-isolated access to document search capabilities
    through standardized MCP tool interfaces.
    """

    def __init__(self, tenant_backend_url: str = "http://tenant-backend:8000"):
        self.tenant_backend_url = tenant_backend_url
        self.server_name = "rag_server"
        self.server_type = "rag"

        # Define available tools (streamlined for simplicity)
        self.available_tools = [
            "search_datasets"
        ]

        # Tool schemas for MCP protocol (enhanced with flexible parameters)
        self.tool_schemas = {
            "search_datasets": {
                "name": "search_datasets",
                "description": (
                    "Search through datasets containing uploaded documents, PDFs, "
                    "and files. Use when users ask about documentation or reference "
                    "materials, want to check or look through files, need data from "
                    "uploaded content, or want to know whether something exists in "
                    "the datasets."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "What to search for in the datasets"
                        },
                        "dataset_ids": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "(Optional) List of specific dataset IDs to search within"
                        },
                        "file_pattern": {
                            "type": "string",
                            "description": "(Optional) File pattern filter (e.g., '*.pdf', '*.txt')"
                        },
                        "search_all": {
                            "type": "boolean",
                            "default": False,
                            "description": "(Optional) Search across all accessible datasets (ignores dataset_ids)"
                        },
                        "max_results": {
                            "type": "integer",
                            "default": 10,
                            "description": "(Optional) Number of results to return (default: 10)"
                        }
                    },
                    "required": ["query"]
                }
            }
        }
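
        # Example tool-call payload an agent might send (hypothetical values):
        #   {"name": "search_datasets",
        #    "arguments": {"query": "Q3 revenue", "max_results": 5}}
        # Only "query" is required; the other parameters refine the search.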

    async def handle_tool_call(
        self,
        tool_name: str,
        parameters: Dict[str, Any],
        tenant_domain: str,
        user_id: str,
        agent_context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Handle MCP tool call with tenant isolation and user context.

        Args:
            tool_name: Name of the tool being called
            parameters: Tool parameters from the LLM
            tenant_domain: Tenant domain for isolation
            user_id: User making the request
            agent_context: Optional agent configuration (e.g. selected datasets)

        Returns:
            Tool execution result or error
        """
        logger.info(f"🚀 MCP RAG Server: handle_tool_call called - tool={tool_name}, tenant={tenant_domain}, user={user_id}")
        logger.info(f"📝 MCP RAG Server: parameters={parameters}")
        try:
            # Validate tool exists
            if tool_name not in self.available_tools:
                return {
                    "error": f"Unknown tool: {tool_name}",
                    "tool_name": tool_name
                }

            # Route to appropriate handler
            if tool_name == "search_datasets":
                return await self._search_datasets(parameters, tenant_domain, user_id, agent_context)
            else:
                return {
                    "error": f"Tool handler not implemented: {tool_name}",
                    "tool_name": tool_name
                }

        except Exception as e:
            logger.error(f"Error handling tool call {tool_name}: {e}")
            return {
                "error": f"Tool execution failed: {str(e)}",
                "tool_name": tool_name
            }

    def _verify_user_access(self, user_id: str, tenant_domain: str) -> bool:
        """Verify user has access to tenant resources (simplified check)"""
        # In a real system, this would query the database to verify
        # that the user has access to the tenant's resources.
        # For now, we trust that the tenant backend has already verified this.
        return bool(user_id and tenant_domain)

    async def _search_datasets(
        self,
        parameters: Dict[str, Any],
        tenant_domain: str,
        user_id: str,
        agent_context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Search across user's datasets"""
        logger.info(f"🔍 RAG Server: search_datasets called for user {user_id} in tenant {tenant_domain}")
        logger.info(f"📝 RAG Server: search parameters = {parameters}")
        logger.info(f"📝 RAG Server: parameter types: {[(k, type(v)) for k, v in parameters.items()]}")

        try:
            query = parameters.get("query", "").strip()
            list_mode = parameters.get("list_mode", False)

            # Handle list mode - list available datasets instead of searching
            if list_mode:
                logger.info("🔍 RAG Server: List mode activated - fetching available datasets")

                async with httpx.AsyncClient(timeout=15.0) as client:
                    response = await client.get(
                        f"{self.tenant_backend_url}/api/v1/datasets/internal/list",
                        headers={
                            "X-Tenant-Domain": tenant_domain,
                            "X-User-ID": user_id
                        }
                    )

                    if response.status_code == 200:
                        datasets = response.json()
                        logger.info(f"✅ RAG Server: Successfully listed {len(datasets)} datasets")
                        return {
                            "success": True,
                            "datasets": datasets,
                            "total_count": len(datasets),
                            "list_mode": True
                        }
                    else:
                        logger.error(f"❌ RAG Server: Failed to list datasets: {response.status_code} - {response.text}")
                        return {"error": f"Failed to list datasets: {response.status_code}"}

            # Normal search mode
            if not query:
                logger.error("❌ RAG Server: Query parameter is required")
                return {"error": "Query parameter is required"}

            # Prepare search request with enhanced parameters
            dataset_ids = parameters.get("dataset_ids")
            file_pattern = parameters.get("file_pattern")
            search_all = parameters.get("search_all", False)

            # Handle legacy dataset_id parameter (backwards compatibility)
            if dataset_ids is None and parameters.get("dataset_id"):
                dataset_ids = [parameters.get("dataset_id")]

            # Ensure dataset_ids is properly formatted
            if dataset_ids is None:
                dataset_ids = []
            elif isinstance(dataset_ids, str):
                dataset_ids = [dataset_ids]
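            # Normalization sketch: None -> [], "ds-123" -> ["ds-123"],
            # ["ds-1", "ds-2"] passes through unchanged, so the backend always
            # receives a list (hypothetical IDs).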

            # If search_all is True, ignore dataset_ids filter
            if search_all:
                dataset_ids = []

            # AGENT-AWARE: If no datasets specified, use agent's configured datasets
            if not dataset_ids and not search_all and agent_context:
                agent_dataset_ids = agent_context.get('selected_dataset_ids', [])
                if agent_dataset_ids:
                    dataset_ids = agent_dataset_ids
                    agent_name = agent_context.get('agent_name', 'Unknown')
                    logger.info(f"✅ RAG Server: Using agent '{agent_name}' datasets: {dataset_ids}")
                else:
                    logger.warning("⚠️ RAG Server: Agent context available but no datasets configured")
            elif not dataset_ids and not search_all:
                logger.warning("⚠️ RAG Server: No dataset_ids provided and no agent context available")

            search_request = {
                "query": query,
                "search_type": parameters.get("search_method", "hybrid"),
                "max_results": parameters.get("max_results", 10),  # No arbitrary cap
                "dataset_ids": dataset_ids,
                "min_similarity": 0.3
            }

            # Add file_pattern if provided
            if file_pattern:
                search_request["file_pattern"] = file_pattern

            logger.info(f"🎯 RAG Server: prepared search request = {search_request}")

            # Call tenant backend search API
            logger.info(f"🌐 RAG Server: calling tenant backend at {self.tenant_backend_url}/api/v1/search/")
            logger.info(f"🌐 RAG Server: request headers: X-Tenant-Domain='{tenant_domain}', X-User-ID='{user_id}'")
            logger.info(f"🌐 RAG Server: request body: {search_request}")

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.tenant_backend_url}/api/v1/search/",
                    json=search_request,
                    headers={
                        "X-Tenant-Domain": tenant_domain,
                        "X-User-ID": user_id,
                        "Content-Type": "application/json"
                    }
                )

                logger.info(f"📊 RAG Server: tenant backend response: {response.status_code}")
                if response.status_code != 200:
                    logger.error(f"📊 RAG Server: error response body: {response.text}")

                if response.status_code == 200:
                    data = response.json()
                    logger.info(f"✅ RAG Server: search successful, got {len(data.get('results', []))} results")

                    # Format results for MCP response
                    results = []
                    for result in data.get("results", []):
                        results.append({
                            "chunk_id": result.get("chunk_id"),
                            "document_id": result.get("document_id"),
                            "dataset_id": result.get("dataset_id"),
                            "content": result.get("text", ""),
                            "similarity_score": result.get("hybrid_score", 0.0),
                            "metadata": result.get("metadata", {})
                        })

                    return {
                        "success": True,
                        "query": query,
                        "results_count": len(results),
                        "results": results,
                        "search_method": data.get("search_type", "hybrid")
                    }
                else:
                    error_text = response.text
                    logger.error(f"❌ RAG Server: search failed: {response.status_code} - {error_text}")
                    return {
                        "error": f"Search failed: {response.status_code} - {error_text}",
                        "query": query
                    }

        except Exception as e:
            logger.error(f"Dataset search error: {e}")
            return {
                "error": f"Search operation failed: {str(e)}",
                "query": parameters.get("query", "")
            }

    async def _query_documents(
        self,
        parameters: Dict[str, Any],
        tenant_domain: str,
        user_id: str
    ) -> Dict[str, Any]:
        """Query specific documents for relevant chunks"""
        try:
            query = parameters.get("query", "").strip()
            document_ids = parameters.get("document_ids", [])

            if not query or not document_ids:
                return {"error": "Both query and document_ids are required"}

            # Use search API with document ID filter
            search_request = {
                "query": query,
                "search_type": "hybrid",
                "max_results": parameters.get("max_results", 5),
                "document_ids": document_ids
            }

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.tenant_backend_url}/api/v1/search/documents",
                    json=search_request,
                    headers={
                        "X-Tenant-Domain": tenant_domain,
                        "X-User-ID": user_id,
                        "Content-Type": "application/json"
                    }
                )

                if response.status_code == 200:
                    data = response.json()
                    return {
                        "success": True,
                        "query": query,
                        "document_ids": document_ids,
                        "results": data.get("results", [])
                    }
                else:
                    return {
                        "error": f"Document query failed: {response.status_code}",
                        "query": query,
                        "document_ids": document_ids
                    }

        except Exception as e:
            return {
                "error": f"Document query failed: {str(e)}",
                "query": parameters.get("query", "")
            }

    async def _list_user_datasets(
        self,
        parameters: Dict[str, Any],
        tenant_domain: str,
        user_id: str
    ) -> Dict[str, Any]:
        """List user's accessible datasets"""
        try:
            include_stats = parameters.get("include_stats", True)

            async with httpx.AsyncClient(timeout=15.0) as client:
                params = {"include_stats": include_stats}
                response = await client.get(
                    f"{self.tenant_backend_url}/api/v1/datasets",
                    params=params,
                    headers={
                        "X-Tenant-Domain": tenant_domain,
                        "X-User-ID": user_id
                    }
                )

                if response.status_code == 200:
                    data = response.json()
                    datasets = data.get("data", []) if isinstance(data, dict) else data

                    # Format for MCP response
                    formatted_datasets = []
                    for dataset in datasets:
                        formatted_datasets.append({
                            "id": dataset.get("id"),
                            "name": dataset.get("name"),
                            "description": dataset.get("description"),
                            "document_count": dataset.get("document_count", 0),
                            "chunk_count": dataset.get("chunk_count", 0),
                            "created_at": dataset.get("created_at"),
                            "access_group": dataset.get("access_group", "individual")
                        })

                    return {
                        "success": True,
                        "datasets": formatted_datasets,
                        "total_count": len(formatted_datasets)
                    }
                else:
                    return {
                        "error": f"Failed to list datasets: {response.status_code}"
                    }

        except Exception as e:
            return {
                "error": f"Failed to list datasets: {str(e)}"
            }

    async def _get_dataset_info(
        self,
        parameters: Dict[str, Any],
        tenant_domain: str,
        user_id: str
    ) -> Dict[str, Any]:
        """Get detailed information about a dataset"""
        try:
            dataset_id = parameters.get("dataset_id")
            if not dataset_id:
                return {"error": "dataset_id parameter is required"}

            async with httpx.AsyncClient(timeout=15.0) as client:
                response = await client.get(
                    f"{self.tenant_backend_url}/api/v1/datasets/{dataset_id}",
                    headers={
                        "X-Tenant-Domain": tenant_domain,
                        "X-User-ID": user_id
                    }
                )

                if response.status_code == 200:
                    data = response.json()
                    dataset = data.get("data", data)

                    return {
                        "success": True,
                        "dataset": {
                            "id": dataset.get("id"),
                            "name": dataset.get("name"),
                            "description": dataset.get("description"),
                            "document_count": dataset.get("document_count", 0),
                            "chunk_count": dataset.get("chunk_count", 0),
                            "vector_count": dataset.get("vector_count", 0),
                            "storage_size_mb": dataset.get("storage_size_mb", 0),
                            "created_at": dataset.get("created_at"),
                            "updated_at": dataset.get("updated_at"),
                            "access_group": dataset.get("access_group"),
                            "tags": dataset.get("tags", [])
                        }
                    }
                elif response.status_code == 404:
                    return {
                        "error": f"Dataset not found: {dataset_id}"
                    }
                else:
                    return {
                        "error": f"Failed to get dataset info: {response.status_code}"
                    }

        except Exception as e:
            return {
                "error": f"Failed to get dataset info: {str(e)}"
            }

    async def _get_user_agent_datasets(self, tenant_domain: str, user_id: str) -> List[str]:
        """Auto-detect agent datasets for the current user"""
        try:
            # Get user's agents and their configured datasets
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(
                    f"{self.tenant_backend_url}/api/v1/agents",
                    headers={
                        "X-Tenant-Domain": tenant_domain,
                        "X-User-ID": user_id
                    }
                )

                if response.status_code == 200:
                    agents_data = response.json()
                    agents = agents_data.get("data", []) if isinstance(agents_data, dict) else agents_data

                    # Collect all dataset IDs from all user's agents
                    all_dataset_ids = set()
                    for agent in agents:
                        agent_dataset_ids = agent.get("selected_dataset_ids", [])
                        if agent_dataset_ids:
                            all_dataset_ids.update(agent_dataset_ids)
                            logger.info(f"🔍 RAG Server: Agent {agent.get('name', 'unknown')} has datasets: {agent_dataset_ids}")

                    return list(all_dataset_ids)
                else:
                    logger.warning(f"⚠️ RAG Server: Failed to get agents: {response.status_code}")
                    return []

        except Exception as e:
            logger.error(f"❌ RAG Server: Error getting user agent datasets: {e}")
            return []

    async def _get_relevant_chunks(
        self,
        parameters: Dict[str, Any],
        tenant_domain: str,
        user_id: str
    ) -> Dict[str, Any]:
        """Get most relevant chunks for a query"""
        try:
            query = parameters.get("query", "").strip()
            if not query:
                return {"error": "query parameter is required"}

            chunk_count = min(parameters.get("chunk_count", 3), 10)  # Cap at 10
            min_similarity = parameters.get("min_similarity", 0.6)
            dataset_ids = parameters.get("dataset_ids")

            search_request = {
                "query": query,
                "search_type": "vector",  # Use vector search for relevance
                "max_results": chunk_count,
                "min_similarity": min_similarity,
                "dataset_ids": dataset_ids
            }

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.tenant_backend_url}/api/v1/search",
                    json=search_request,
                    headers={
                        "X-Tenant-Domain": tenant_domain,
                        "X-User-ID": user_id,
                        "Content-Type": "application/json"
                    }
                )

                if response.status_code == 200:
                    data = response.json()
                    chunks = []

                    for result in data.get("results", []):
                        chunks.append({
                            "chunk_id": result.get("chunk_id"),
                            "document_id": result.get("document_id"),
                            "dataset_id": result.get("dataset_id"),
                            "content": result.get("text", ""),
                            "similarity_score": result.get("vector_similarity", 0.0),
                            "chunk_index": result.get("rank", 0),
                            "metadata": result.get("metadata", {})
                        })

                    return {
                        "success": True,
                        "query": query,
                        "chunks": chunks,
                        "chunk_count": len(chunks),
                        "min_similarity": min_similarity
                    }
                else:
                    return {
                        "error": f"Chunk retrieval failed: {response.status_code}"
                    }

        except Exception as e:
            return {
                "error": f"Failed to get relevant chunks: {str(e)}"
            }

    def get_server_config(self) -> MCPServerConfig:
        """Get MCP server configuration"""
        return MCPServerConfig(
            server_name=self.server_name,
            server_url="internal://mcp-rag-server",
            server_type=self.server_type,
            available_tools=self.available_tools,
            required_capabilities=["mcp:rag:*"],
            sandbox_mode=True,
            max_memory_mb=256,
            max_cpu_percent=25,
            timeout_seconds=30,
            network_isolation=False,  # Needs to access tenant backend
            max_requests_per_minute=120,
            max_concurrent_requests=10
        )

    def get_tool_schemas(self) -> Dict[str, Any]:
        """Get MCP tool schemas for this server"""
        return self.tool_schemas


# Global instance
mcp_rag_server = MCPRAGServer()

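# Usage sketch (illustrative, not part of this commit): how a caller might
# retrieve relevant chunks through the global instance. The tenant domain
# and user id below are hypothetical placeholders, and _get_relevant_chunks
# is an internal method, so treat this purely as an example of the shape of
# the call and the result.
async def _demo_get_relevant_chunks():
    result = await mcp_rag_server._get_relevant_chunks(
        parameters={"query": "quarterly revenue", "chunk_count": 5},
        tenant_domain="acme.example.com",
        user_id="user-123",
    )
    if result.get("success"):
        for chunk in result["chunks"]:
            print(chunk["similarity_score"], chunk["content"][:80])
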
491
apps/resource-cluster/app/services/mcp_sandbox.py
Normal file
@@ -0,0 +1,491 @@
"""
|
||||
MCP Sandbox Service for GT 2.0
|
||||
|
||||
Provides secure sandboxed execution environment for MCP servers.
|
||||
Implements resource isolation, monitoring, and security constraints.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import resource
|
||||
import signal
|
||||
import tempfile
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional, Callable, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import json
|
||||
import psutil
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SandboxConfig:
|
||||
"""Configuration for sandbox environment"""
|
||||
# Resource limits
|
||||
max_memory_mb: int = 512
|
||||
max_cpu_percent: int = 50
|
||||
max_disk_mb: int = 100
|
||||
timeout_seconds: int = 30
|
||||
|
||||
# Security settings
|
||||
network_isolation: bool = True
|
||||
readonly_filesystem: bool = False
|
||||
allowed_paths: list = None
|
||||
blocked_paths: list = None
|
||||
allowed_commands: list = None
|
||||
|
||||
# Process limits
|
||||
max_processes: int = 10
|
||||
max_open_files: int = 100
|
||||
max_threads: int = 20
|
||||
|
||||
def __post_init__(self):
|
||||
if self.allowed_paths is None:
|
||||
self.allowed_paths = ["/tmp", "/var/tmp"]
|
||||
if self.blocked_paths is None:
|
||||
self.blocked_paths = ["/etc", "/root", "/home", "/usr/bin", "/usr/sbin"]
|
||||
if self.allowed_commands is None:
|
||||
self.allowed_commands = ["ls", "cat", "grep", "find", "echo", "pwd"]
|
||||
|
||||
|
||||
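# Example (illustrative, not part of this commit): building a stricter
# configuration than the defaults above; all values are assumptions chosen
# for demonstration only.
def _example_strict_config() -> SandboxConfig:
    return SandboxConfig(
        max_memory_mb=128,
        max_cpu_percent=10,
        timeout_seconds=5,
        network_isolation=True,
        allowed_commands=["ls", "cat"],
    )

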
class ProcessSandbox:
    """
    Process-level sandbox for MCP tool execution
    Uses OS-level isolation and resource limits
    """

    def __init__(self, config: SandboxConfig):
        self.config = config
        self.process: Optional[asyncio.subprocess.Process] = None
        self.start_time: Optional[datetime] = None
        self.temp_dir: Optional[Path] = None
        self.resource_monitor_task: Optional[asyncio.Task] = None

    async def __aenter__(self):
        """Enter sandbox context"""
        await self.setup()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit sandbox context and cleanup"""
        await self.cleanup()

    async def setup(self):
        """Setup sandbox environment"""
        # Create temporary directory for sandbox
        self.temp_dir = Path(tempfile.mkdtemp(prefix="mcp_sandbox_"))
        os.chmod(self.temp_dir, 0o700)  # Restrict access

        # Set resource limits for child processes
        self._set_resource_limits()

        # Start resource monitoring
        self.resource_monitor_task = asyncio.create_task(self._monitor_resources())

        self.start_time = datetime.utcnow()
        logger.info(f"Sandbox setup complete: {self.temp_dir}")

    async def cleanup(self):
        """Cleanup sandbox environment"""
        # Stop resource monitoring
        if self.resource_monitor_task:
            self.resource_monitor_task.cancel()
            try:
                await self.resource_monitor_task
            except asyncio.CancelledError:
                pass

        # Terminate process if still running
        if self.process and self.process.returncode is None:
            try:
                self.process.terminate()
                await asyncio.wait_for(self.process.wait(), timeout=5)
            except asyncio.TimeoutError:
                self.process.kill()
                await self.process.wait()

        # Remove temporary directory
        if self.temp_dir and self.temp_dir.exists():
            shutil.rmtree(self.temp_dir, ignore_errors=True)

        logger.info("Sandbox cleanup complete")

    async def execute(
        self,
        command: str,
        args: list = None,
        input_data: str = None,
        env: Dict[str, str] = None
    ) -> Tuple[int, str, str]:
        """
        Execute command in sandbox

        Args:
            command: Command to execute
            args: Command arguments
            input_data: Input to send to process
            env: Environment variables

        Returns:
            Tuple of (return_code, stdout, stderr)
        """
        # Validate command
        if not self._validate_command(command):
            raise PermissionError(f"Command not allowed: {command}")

        # Prepare environment
        sandbox_env = self._prepare_environment(env)

        # Prepare command with arguments
        full_command = [command] + (args or [])

        try:
            # Create process with resource limits
            self.process = await asyncio.create_subprocess_exec(
                *full_command,
                stdin=asyncio.subprocess.PIPE if input_data else None,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(self.temp_dir),
                env=sandbox_env,
                preexec_fn=self._set_process_limits if os.name == 'posix' else None
            )

            # Execute with timeout
            stdout, stderr = await asyncio.wait_for(
                self.process.communicate(input=input_data.encode() if input_data else None),
                timeout=self.config.timeout_seconds
            )

            return self.process.returncode, stdout.decode(), stderr.decode()

        except asyncio.TimeoutError:
            if self.process:
                self.process.kill()
                await self.process.wait()
            raise TimeoutError(f"Command exceeded {self.config.timeout_seconds}s timeout")
        except Exception as e:
            logger.error(f"Sandbox execution error: {e}")
            raise

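    # Usage sketch (illustrative, not part of this commit):
    #
    #     async with ProcessSandbox(SandboxConfig(timeout_seconds=5)) as sb:
    #         code, out, err = await sb.execute("ls", args=["-la"])
    #
    # "ls" is permitted by the default allowed_commands list above; a command
    # outside that list raises PermissionError before any process is spawned.
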
    async def execute_function(
        self,
        func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Execute Python function in sandbox
        Uses multiprocessing for isolation
        """
        import multiprocessing
        import pickle

        # Create pipe for communication
        parent_conn, child_conn = multiprocessing.Pipe()

        def sandbox_wrapper(conn, func, args, kwargs):
            """Wrapper to execute function in child process"""
            try:
                # Apply resource limits
                self._set_process_limits()

                # Execute function
                result = func(*args, **kwargs)

                # Send result back
                conn.send(("success", pickle.dumps(result)))
            except Exception as e:
                conn.send(("error", str(e)))
            finally:
                conn.close()

        # Create and start process
        process = multiprocessing.Process(
            target=sandbox_wrapper,
            args=(child_conn, func, args, kwargs)
        )
        process.start()

        # Wait for result with timeout
        try:
            if parent_conn.poll(self.config.timeout_seconds):
                status, data = parent_conn.recv()
                if status == "success":
                    return pickle.loads(data)
                else:
                    raise RuntimeError(f"Sandbox function error: {data}")
            else:
                process.terminate()
                process.join(timeout=5)
                if process.is_alive():
                    process.kill()
                raise TimeoutError(f"Function exceeded {self.config.timeout_seconds}s timeout")
        finally:
            parent_conn.close()
            if process.is_alive():
                process.terminate()
                process.join()

    def _validate_command(self, command: str) -> bool:
        """Validate if command is allowed"""
        # Check if command is in allowed list
        command_name = os.path.basename(command)
        if self.config.allowed_commands and command_name not in self.config.allowed_commands:
            return False

        # Check for dangerous patterns
        dangerous_patterns = [
            "rm -rf",
            "dd if=",
            "mkfs",
            "format",
            ">",   # Redirect that could overwrite files
            "|",   # Pipe that could chain commands
            ";",   # Command separator
            "&",   # Background execution
            "`",   # Command substitution
            "$("   # Command substitution
        ]

        for pattern in dangerous_patterns:
            if pattern in command:
                return False

        return True

    def _prepare_environment(self, custom_env: Dict[str, str] = None) -> Dict[str, str]:
        """Prepare sandboxed environment variables"""
        # Start with minimal environment
        sandbox_env = {
            "PATH": "/usr/local/bin:/usr/bin:/bin",
            "HOME": str(self.temp_dir),
            "TEMP": str(self.temp_dir),
            "TMP": str(self.temp_dir),
            "USER": "sandbox",
            "SHELL": "/bin/sh"
        }

        # Add custom environment variables if provided
        if custom_env:
            # Filter out dangerous variables
            dangerous_vars = ["LD_PRELOAD", "LD_LIBRARY_PATH", "PYTHONPATH", "PATH"]
            for key, value in custom_env.items():
                if key not in dangerous_vars:
                    sandbox_env[key] = value

        return sandbox_env

    def _set_resource_limits(self):
        """Set resource limits for the process"""
        if os.name != 'posix':
            return  # Resource limits only work on POSIX systems

        # Memory limit
        memory_bytes = self.config.max_memory_mb * 1024 * 1024
        resource.setrlimit(resource.RLIMIT_AS, (memory_bytes, memory_bytes))

        # CPU time limit
        resource.setrlimit(resource.RLIMIT_CPU, (self.config.timeout_seconds, self.config.timeout_seconds))

        # File size limit
        file_size_bytes = self.config.max_disk_mb * 1024 * 1024
        resource.setrlimit(resource.RLIMIT_FSIZE, (file_size_bytes, file_size_bytes))

        # Process limit
        resource.setrlimit(resource.RLIMIT_NPROC, (self.config.max_processes, self.config.max_processes))

        # Open files limit
        resource.setrlimit(resource.RLIMIT_NOFILE, (self.config.max_open_files, self.config.max_open_files))

    def _set_process_limits(self):
        """Set limits for child process (called in child context)"""
        if os.name != 'posix':
            return

        # Drop privileges if running as root (shouldn't happen in production).
        # Drop the group before the user: once the UID is no longer root,
        # setgid would fail with a permission error.
        if os.getuid() == 0:
            os.setgid(65534)  # nogroup
            os.setuid(65534)  # nobody user

        # Set resource limits
        self._set_resource_limits()

        # Set process group for easier cleanup
        os.setpgrp()

    async def _monitor_resources(self):
        """Monitor resource usage of sandboxed process"""
        while True:
            try:
                if self.process and self.process.returncode is None:
                    # Get process info
                    try:
                        proc = psutil.Process(self.process.pid)

                        # Check CPU usage
                        cpu_percent = proc.cpu_percent(interval=0.1)
                        if cpu_percent > self.config.max_cpu_percent:
                            logger.warning(f"Sandbox CPU usage high: {cpu_percent}%")
                            # Could throttle or terminate if consistently high

                        # Check memory usage
                        memory_info = proc.memory_info()
                        memory_mb = memory_info.rss / (1024 * 1024)
                        if memory_mb > self.config.max_memory_mb:
                            logger.warning(f"Sandbox memory limit exceeded: {memory_mb}MB")
                            self.process.terminate()
                            break

                        # Check runtime
                        if self.start_time:
                            runtime = (datetime.utcnow() - self.start_time).total_seconds()
                            if runtime > self.config.timeout_seconds:
                                logger.warning(f"Sandbox timeout exceeded: {runtime}s")
                                self.process.terminate()
                                break

                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        pass  # Process ended or inaccessible

                await asyncio.sleep(1)  # Check every second

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Resource monitoring error: {e}")
                await asyncio.sleep(1)


class ContainerSandbox:
    """
    Container-based sandbox for stronger isolation
    Uses Docker or Podman for execution
    """

    def __init__(self, config: SandboxConfig):
        self.config = config
        self.container_id: Optional[str] = None
        self.container_runtime = self._detect_container_runtime()

    def _detect_container_runtime(self) -> str:
        """Detect available container runtime"""
        # Try Docker first
        if shutil.which("docker"):
            return "docker"
        # Try Podman as alternative
        elif shutil.which("podman"):
            return "podman"
        else:
            logger.warning("No container runtime found, falling back to process sandbox")
            return None

    @asynccontextmanager
    async def create_container(self, image: str = "alpine:latest"):
        """Create and manage container lifecycle"""
        if not self.container_runtime:
            raise RuntimeError("No container runtime available")

        try:
            # Create container with resource limits
            create_cmd = [
                self.container_runtime, "create",
                "--rm",  # Auto-remove after stop
                f"--memory={self.config.max_memory_mb}m",
                f"--cpus={self.config.max_cpu_percent / 100}",
                "--network=none" if self.config.network_isolation else "--network=bridge",
                "--read-only" if self.config.readonly_filesystem else "",
                f"--tmpfs=/tmp:size={self.config.max_disk_mb}m",
                "--security-opt=no-new-privileges",
                "--cap-drop=ALL",  # Drop all capabilities
                image,
                "sleep", "infinity"  # Keep container running
            ]

            # Remove empty strings from command
            create_cmd = [arg for arg in create_cmd if arg]

            # Create container
            proc = await asyncio.create_subprocess_exec(
                *create_cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            stdout, stderr = await proc.communicate()

            if proc.returncode != 0:
                raise RuntimeError(f"Failed to create container: {stderr.decode()}")

            self.container_id = stdout.decode().strip()

            # Start container
            start_cmd = [self.container_runtime, "start", self.container_id]
            proc = await asyncio.create_subprocess_exec(*start_cmd)
            await proc.wait()

            logger.info(f"Container sandbox created: {self.container_id[:12]}")

            yield self

        finally:
            # Cleanup container
            if self.container_id:
                stop_cmd = [self.container_runtime, "stop", self.container_id]
                proc = await asyncio.create_subprocess_exec(*stop_cmd)
                await proc.wait()

                logger.info(f"Container sandbox cleaned up: {self.container_id[:12]}")

    async def execute(self, command: str, args: list = None) -> Tuple[int, str, str]:
        """Execute command in container"""
        if not self.container_id:
            raise RuntimeError("Container not created")

        exec_cmd = [
            self.container_runtime, "exec",
            self.container_id,
            command
        ] + (args or [])

        proc = await asyncio.create_subprocess_exec(
            *exec_cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(),
                timeout=self.config.timeout_seconds
            )
            return proc.returncode, stdout.decode(), stderr.decode()
        except asyncio.TimeoutError:
            # Kill all processes in the container
            kill_cmd = [self.container_runtime, "exec", self.container_id, "kill", "-9", "-1"]
            await asyncio.create_subprocess_exec(*kill_cmd)
            raise TimeoutError(f"Command exceeded {self.config.timeout_seconds}s timeout")


# Factory function to get appropriate sandbox
def create_sandbox(config: SandboxConfig, prefer_container: bool = True) -> Any:
    """
    Create appropriate sandbox based on availability and preference

    Args:
        config: Sandbox configuration
        prefer_container: Prefer container over process sandbox

    Returns:
        ProcessSandbox or ContainerSandbox instance
    """
    if prefer_container and (shutil.which("docker") or shutil.which("podman")):
        return ContainerSandbox(config)
    else:
        return ProcessSandbox(config)

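# Usage sketch (illustrative, not part of this commit): the factory returns a
# ContainerSandbox when Docker or Podman is installed, otherwise a
# ProcessSandbox; the image name below is an assumption.
async def _demo_create_sandbox():
    sandbox = create_sandbox(SandboxConfig(network_isolation=True))
    if isinstance(sandbox, ContainerSandbox):
        async with sandbox.create_container("alpine:latest"):
            code, out, err = await sandbox.execute("echo", ["hello"])
    else:
        async with sandbox:
            code, out, err = await sandbox.execute("echo", ["hello"])
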
698
apps/resource-cluster/app/services/mcp_server.py
Normal file
@@ -0,0 +1,698 @@
"""
|
||||
MCP Server Resource Wrapper for GT 2.0
|
||||
|
||||
Encapsulates MCP (Model Context Protocol) servers as GT 2.0 resources.
|
||||
Provides security sandboxing and capability-based access control.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, AsyncIterator
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.models.access_group import AccessGroup, Resource
|
||||
from app.core.security import verify_capability_token
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStatus(str, Enum):
|
||||
"""MCP server operational status"""
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
UNHEALTHY = "unhealthy"
|
||||
STARTING = "starting"
|
||||
STOPPING = "stopping"
|
||||
STOPPED = "stopped"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPServerConfig:
|
||||
"""Configuration for an MCP server instance"""
|
||||
server_name: str
|
||||
server_url: str
|
||||
server_type: str # filesystem, github, slack, etc.
|
||||
available_tools: List[str]
|
||||
required_capabilities: List[str]
|
||||
|
||||
# Security settings
|
||||
sandbox_mode: bool = True
|
||||
max_memory_mb: int = 512
|
||||
max_cpu_percent: int = 50
|
||||
timeout_seconds: int = 30
|
||||
network_isolation: bool = True
|
||||
|
||||
# Rate limiting
|
||||
max_requests_per_minute: int = 60
|
||||
max_concurrent_requests: int = 5
|
||||
|
||||
# Authentication
|
||||
auth_type: Optional[str] = None # none, api_key, oauth2
|
||||
auth_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class MCPServerResource(Resource):
    """
    MCP server encapsulated as a GT 2.0 resource
    Inherits from Resource for access control
    """

    # MCP-specific configuration
    server_config: MCPServerConfig

    # Runtime state
    status: MCPServerStatus = MCPServerStatus.STOPPED
    last_health_check: Optional[datetime] = None
    error_count: int = 0
    total_requests: int = 0

    # Connection management
    connection_pool_size: int = 5
    active_connections: int = 0

    def to_capability_requirement(self) -> str:
        """Generate capability requirement string for this MCP server"""
        return f"mcp:{self.server_config.server_name}:*"

    def validate_tool_access(self, tool_name: str, capability_token: Dict[str, Any]) -> bool:
        """Check if capability token allows access to specific tool"""
        required_capability = f"mcp:{self.server_config.server_name}:{tool_name}"

        capabilities = capability_token.get("capabilities", [])
        for cap in capabilities:
            resource = cap.get("resource", "")
            if resource == required_capability or resource == f"mcp:{self.server_config.server_name}:*":
                return True

        return False

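# Illustrative token payload (assumption, not part of this commit): the
# decoded shape that validate_tool_access above expects from
# verify_capability_token. A wildcard resource grants every tool on the
# named server; all identifiers here are hypothetical.
def _example_capability_payload() -> Dict[str, Any]:
    return {
        "tenant_id": "acme.example.com",
        "capabilities": [
            {"resource": "mcp:filesystem:read_file"},  # a single tool
            {"resource": "mcp:github:*"},              # all GitHub tools
        ],
    }

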
class SecureMCPWrapper:
    """
    Secure wrapper for MCP servers with GT 2.0 security integration
    Provides sandboxing, rate limiting, and capability-based access
    """

    def __init__(self, resource_cluster_url: str = "http://localhost:8004"):
        self.resource_cluster_url = resource_cluster_url
        self.mcp_resources: Dict[str, MCPServerResource] = {}
        self.rate_limiters: Dict[str, asyncio.Semaphore] = {}
        self.audit_log = []

    async def register_mcp_server(
        self,
        server_config: MCPServerConfig,
        owner_id: str,
        tenant_domain: str,
        access_group: AccessGroup = AccessGroup.INDIVIDUAL
    ) -> MCPServerResource:
        """
        Register an MCP server as a GT 2.0 resource

        Args:
            server_config: MCP server configuration
            owner_id: User who owns this MCP resource
            tenant_domain: Tenant domain
            access_group: Access control level

        Returns:
            Registered MCP server resource
        """
        # Create MCP resource
        resource = MCPServerResource(
            id=f"mcp-{server_config.server_name}-{datetime.utcnow().timestamp()}",
            name=f"MCP Server: {server_config.server_name}",
            resource_type="mcp_server",
            owner_id=owner_id,
            tenant_domain=tenant_domain,
            access_group=access_group,
            team_members=[],
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
            metadata={
                "server_type": server_config.server_type,
                "tools_count": len(server_config.available_tools)
            },
            server_config=server_config
        )

        # Initialize rate limiter
        self.rate_limiters[resource.id] = asyncio.Semaphore(
            server_config.max_concurrent_requests
        )

        # Store resource
        self.mcp_resources[resource.id] = resource

        # Start health monitoring
        asyncio.create_task(self._monitor_health(resource.id))

        logger.info(f"Registered MCP server: {server_config.server_name} as resource {resource.id}")

        return resource

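    # Note (illustrative): execute_tool below enforces, in order: resource
    # lookup, JWT verification, tenant match, per-tool capability check,
    # tool availability, then a semaphore-based concurrency cap before the
    # sandboxed dispatch. A failure at any step raises before execution.
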
    async def execute_tool(
        self,
        mcp_resource_id: str,
        tool_name: str,
        parameters: Dict[str, Any],
        capability_token: str,
        user_id: str
    ) -> Dict[str, Any]:
        """
        Execute an MCP tool with security constraints

        Args:
            mcp_resource_id: MCP resource identifier
            tool_name: Tool to execute
            parameters: Tool parameters
            capability_token: JWT capability token
            user_id: User executing the tool

        Returns:
            Tool execution result
        """
        # Load MCP resource
        mcp_resource = self.mcp_resources.get(mcp_resource_id)
        if not mcp_resource:
            raise ValueError(f"MCP resource not found: {mcp_resource_id}")

        # Verify capability token
        token_data = verify_capability_token(capability_token)
        if not token_data:
            raise PermissionError("Invalid capability token")

        # Check tenant match
        if token_data.get("tenant_id") != mcp_resource.tenant_domain:
            raise PermissionError("Tenant mismatch")

        # Validate tool access
        if not mcp_resource.validate_tool_access(tool_name, token_data):
            raise PermissionError(f"No capability for tool: {tool_name}")

        # Check if tool exists
        if tool_name not in mcp_resource.server_config.available_tools:
            raise ValueError(f"Tool not available: {tool_name}")

        # Apply rate limiting
        async with self.rate_limiters[mcp_resource_id]:
            try:
                # Execute tool with timeout and sandboxing
                result = await self._execute_tool_sandboxed(
                    mcp_resource, tool_name, parameters, user_id
                )

                # Update metrics
                mcp_resource.total_requests += 1

                # Audit log
                self._log_tool_execution(
                    mcp_resource_id, tool_name, user_id, "success", result
                )

                return result

            except Exception as e:
                # Update error metrics
                mcp_resource.error_count += 1

                # Audit log
                self._log_tool_execution(
                    mcp_resource_id, tool_name, user_id, "error", str(e)
                )

                raise

    async def _execute_tool_sandboxed(
        self,
        mcp_resource: MCPServerResource,
        tool_name: str,
        parameters: Dict[str, Any],
        user_id: str
    ) -> Dict[str, Any]:
        """Execute tool in sandboxed environment"""

        # Create sandbox context
        sandbox_context = {
            "user_id": user_id,
            "tenant_domain": mcp_resource.tenant_domain,
            "resource_limits": {
                "max_memory_mb": mcp_resource.server_config.max_memory_mb,
                "max_cpu_percent": mcp_resource.server_config.max_cpu_percent,
                "timeout_seconds": mcp_resource.server_config.timeout_seconds
            },
            "network_isolation": mcp_resource.server_config.network_isolation
        }

        # Execute based on server type
        if mcp_resource.server_config.server_type == "filesystem":
            return await self._execute_filesystem_tool(
                tool_name, parameters, sandbox_context
            )
        elif mcp_resource.server_config.server_type == "github":
            return await self._execute_github_tool(
                tool_name, parameters, sandbox_context
            )
        elif mcp_resource.server_config.server_type == "slack":
            return await self._execute_slack_tool(
                tool_name, parameters, sandbox_context
            )
        elif mcp_resource.server_config.server_type == "web":
            return await self._execute_web_tool(
                tool_name, parameters, sandbox_context
            )
        elif mcp_resource.server_config.server_type == "database":
            return await self._execute_database_tool(
                tool_name, parameters, sandbox_context
            )
        else:
            return await self._execute_custom_tool(
                mcp_resource, tool_name, parameters, sandbox_context
            )

    async def _execute_filesystem_tool(
        self,
        tool_name: str,
        parameters: Dict[str, Any],
        sandbox_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute filesystem MCP tools"""

        if tool_name == "read_file":
            # Simulate file reading with sandbox constraints
            file_path = parameters.get("path", "")

            # Security validation
            if not self._validate_file_path(file_path, sandbox_context):
                raise PermissionError("Access denied to file path")

            return {
                "tool": "read_file",
                "content": f"File content from {file_path}",
                "size_bytes": 1024,
                "mime_type": "text/plain"
            }

        elif tool_name == "write_file":
            file_path = parameters.get("path", "")
            content = parameters.get("content", "")

            # Security validation
            if not self._validate_file_path(file_path, sandbox_context):
                raise PermissionError("Access denied to file path")

            if len(content) > 1024 * 1024:  # 1MB limit
                raise ValueError("File content too large")

            return {
                "tool": "write_file",
                "path": file_path,
                "bytes_written": len(content),
                "status": "success"
            }

        elif tool_name == "list_directory":
            dir_path = parameters.get("path", "")

            if not self._validate_file_path(dir_path, sandbox_context):
                raise PermissionError("Access denied to directory path")

            return {
                "tool": "list_directory",
                "path": dir_path,
                "entries": ["file1.txt", "file2.txt", "subdir/"],
                "total_entries": 3
            }

        else:
            raise ValueError(f"Unknown filesystem tool: {tool_name}")

    async def _execute_github_tool(
        self,
        tool_name: str,
        parameters: Dict[str, Any],
        sandbox_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute GitHub MCP tools"""

        if tool_name == "get_repository":
            repo_name = parameters.get("repository", "")

            return {
                "tool": "get_repository",
                "repository": repo_name,
                "owner": "example",
                "description": "Example repository",
                "language": "Python",
                "stars": 123,
                "forks": 45
            }

        elif tool_name == "create_issue":
            title = parameters.get("title", "")
            body = parameters.get("body", "")

            return {
                "tool": "create_issue",
                "issue_number": 42,
                "title": title,
                "url": "https://github.com/example/repo/issues/42",
                "status": "created"
            }

        elif tool_name == "search_code":
            query = parameters.get("query", "")

            return {
                "tool": "search_code",
                "query": query,
                "results": [
                    {
                        "file": "main.py",
                        "line": 15,
                        "content": f"# Code matching {query}"
                    }
                ],
                "total_results": 1
            }

        else:
            raise ValueError(f"Unknown GitHub tool: {tool_name}")

    async def _execute_slack_tool(
        self,
        tool_name: str,
        parameters: Dict[str, Any],
        sandbox_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute Slack MCP tools"""

        if tool_name == "send_message":
            channel = parameters.get("channel", "")
            message = parameters.get("message", "")

            return {
                "tool": "send_message",
                "channel": channel,
                "message": message,
                "timestamp": datetime.utcnow().isoformat(),
                "status": "sent"
            }

        elif tool_name == "get_channel_history":
            channel = parameters.get("channel", "")
            limit = parameters.get("limit", 10)

            return {
                "tool": "get_channel_history",
                "channel": channel,
                "messages": [
                    {
                        "user": "user1",
                        "text": "Hello world!",
                        "timestamp": datetime.utcnow().isoformat()
                    }
                ] * min(limit, 10),
                "total_messages": limit
            }

        else:
            raise ValueError(f"Unknown Slack tool: {tool_name}")

    async def _execute_web_tool(
        self,
        tool_name: str,
        parameters: Dict[str, Any],
        sandbox_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute web MCP tools"""

        if tool_name == "fetch_url":
            url = parameters.get("url", "")

            # URL validation
            if not self._validate_url(url, sandbox_context):
                raise PermissionError("Access denied to URL")

            return {
                "tool": "fetch_url",
                "url": url,
                "status_code": 200,
                "content": f"Content from {url}",
                "headers": {"content-type": "text/html"}
            }

        elif tool_name == "submit_form":
            url = parameters.get("url", "")
            form_data = parameters.get("form_data", {})

            if not self._validate_url(url, sandbox_context):
                raise PermissionError("Access denied to URL")

            return {
                "tool": "submit_form",
                "url": url,
                "form_data": form_data,
                "status_code": 200,
                "response": "Form submitted successfully"
            }

        else:
            raise ValueError(f"Unknown web tool: {tool_name}")

    async def _execute_database_tool(
        self,
        tool_name: str,
        parameters: Dict[str, Any],
        sandbox_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute database MCP tools"""

        if tool_name == "execute_query":
            query = parameters.get("query", "")

            # Query validation
            if not self._validate_sql_query(query, sandbox_context):
                raise PermissionError("Query not allowed")

            return {
                "tool": "execute_query",
                "query": query,
                "rows": [
                    {"id": 1, "name": "Example"},
                    {"id": 2, "name": "Data"}
                ],
                "row_count": 2,
                "execution_time_ms": 15
            }

        else:
            raise ValueError(f"Unknown database tool: {tool_name}")

    async def _execute_custom_tool(
        self,
        mcp_resource: MCPServerResource,
        tool_name: str,
        parameters: Dict[str, Any],
        sandbox_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute custom MCP tool via WebSocket transport"""

        # This would connect to the actual MCP server via WebSocket
        # For now, simulate the execution

        await asyncio.sleep(0.1)  # Simulate network delay

        return {
            "tool": tool_name,
            "parameters": parameters,
            "result": f"Custom tool {tool_name} executed successfully",
            "server_type": mcp_resource.server_config.server_type,
            "execution_time_ms": 100
        }

    def _validate_file_path(
        self,
        file_path: str,
        sandbox_context: Dict[str, Any]
    ) -> bool:
        """Validate file path for security"""

        # Basic path traversal prevention
        if ".." in file_path or file_path.startswith("/"):
            return False

        # Check allowed extensions
        allowed_extensions = [".txt", ".md", ".json", ".py", ".js"]
        if not any(file_path.endswith(ext) for ext in allowed_extensions):
            return False

        return True

    def _validate_url(
        self,
        url: str,
        sandbox_context: Dict[str, Any]
    ) -> bool:
        """Validate URL for security"""

        # Basic URL validation
        if not url.startswith(("http://", "https://")):
            return False

        # Block internal/localhost URLs if network isolation enabled
        if sandbox_context.get("network_isolation", True):
            if any(domain in url for domain in ["localhost", "127.0.0.1", "10.", "192.168."]):
                return False

        return True

    def _validate_sql_query(
        self,
        query: str,
        sandbox_context: Dict[str, Any]
    ) -> bool:
        """Validate SQL query for security"""

        # Block dangerous SQL operations. Keywords must be upper-case here
        # because they are matched against the upper-cased query below.
        dangerous_keywords = [
            "DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER",
            "TRUNCATE", "EXEC", "EXECUTE", "XP_", "SP_"
        ]

        query_upper = query.upper()
        for keyword in dangerous_keywords:
            if keyword in query_upper:
                return False

        return True

    def _log_tool_execution(
        self,
        mcp_resource_id: str,
        tool_name: str,
        user_id: str,
        status: str,
        result: Any
    ) -> None:
        """Log tool execution for audit"""

        log_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "mcp_resource_id": mcp_resource_id,
            "tool_name": tool_name,
            "user_id": user_id,
            "status": status,
            "result_summary": str(result)[:200] if result else None
        }

        self.audit_log.append(log_entry)

        # Keep only last 1000 entries
        if len(self.audit_log) > 1000:
            self.audit_log = self.audit_log[-1000:]

    async def _monitor_health(self, mcp_resource_id: str) -> None:
        """Monitor MCP server health"""

        while mcp_resource_id in self.mcp_resources:
            try:
                mcp_resource = self.mcp_resources[mcp_resource_id]

                # Simulate health check
                await asyncio.sleep(30)  # Check every 30 seconds

                # Update health status. Check the higher threshold first;
                # otherwise the unhealthy branch is unreachable.
                if mcp_resource.error_count > 50:
                    mcp_resource.status = MCPServerStatus.UNHEALTHY
                elif mcp_resource.error_count > 10:
                    mcp_resource.status = MCPServerStatus.DEGRADED
                else:
                    mcp_resource.status = MCPServerStatus.HEALTHY

                mcp_resource.last_health_check = datetime.utcnow()

                logger.debug(f"Health check for MCP resource {mcp_resource_id}: {mcp_resource.status}")

            except Exception as e:
                logger.error(f"Health check failed for MCP resource {mcp_resource_id}: {e}")

                if mcp_resource_id in self.mcp_resources:
                    self.mcp_resources[mcp_resource_id].status = MCPServerStatus.UNHEALTHY

    async def get_resource_status(
        self,
        mcp_resource_id: str,
        capability_token: str
    ) -> Dict[str, Any]:
        """Get MCP resource status"""

        # Verify capability token
        token_data = verify_capability_token(capability_token)
        if not token_data:
            raise PermissionError("Invalid capability token")

        # Load MCP resource
        mcp_resource = self.mcp_resources.get(mcp_resource_id)
        if not mcp_resource:
            raise ValueError(f"MCP resource not found: {mcp_resource_id}")

        # Check tenant match
        if token_data.get("tenant_id") != mcp_resource.tenant_domain:
            raise PermissionError("Tenant mismatch")

        return {
            "resource_id": mcp_resource_id,
            "name": mcp_resource.name,
            "server_type": mcp_resource.server_config.server_type,
            "status": mcp_resource.status,
            "total_requests": mcp_resource.total_requests,
            "error_count": mcp_resource.error_count,
            "active_connections": mcp_resource.active_connections,
            "last_health_check": mcp_resource.last_health_check.isoformat() if mcp_resource.last_health_check else None,
            "available_tools": mcp_resource.server_config.available_tools
        }

    async def list_mcp_resources(
        self,
        capability_token: str,
        tenant_domain: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """List available MCP resources"""

        # Verify capability token
        token_data = verify_capability_token(capability_token)
        if not token_data:
            raise PermissionError("Invalid capability token")

        tenant_filter = tenant_domain or token_data.get("tenant_id")

        resources = []
        for resource in self.mcp_resources.values():
            if resource.tenant_domain == tenant_filter:
                resources.append({
                    "resource_id": resource.id,
                    "name": resource.name,
                    "server_type": resource.server_config.server_type,
                    "status": resource.status,
                    "tool_count": len(resource.server_config.available_tools),
                    "created_at": resource.created_at.isoformat()
                })

        return resources


# Global MCP wrapper instance
_mcp_wrapper = None


def get_mcp_wrapper() -> SecureMCPWrapper:
    """Get the global MCP wrapper instance"""
    global _mcp_wrapper
    if _mcp_wrapper is None:
        _mcp_wrapper = SecureMCPWrapper()
    return _mcp_wrapper

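# Usage sketch (illustrative, not part of this commit): registering a
# filesystem server and executing a tool through the singleton wrapper.
# Server details, owner, tenant, and the JWT are hypothetical placeholders.
async def _demo_execute_tool():
    wrapper = get_mcp_wrapper()
    fs_resource = await wrapper.register_mcp_server(
        MCPServerConfig(
            server_name="filesystem",
            server_url="internal://mcp-filesystem",
            server_type="filesystem",
            available_tools=["read_file", "write_file", "list_directory"],
            required_capabilities=["mcp:filesystem:*"],
        ),
        owner_id="user-123",
        tenant_domain="acme.example.com",
    )
    return await wrapper.execute_tool(
        mcp_resource_id=fs_resource.id,
        tool_name="read_file",
        parameters={"path": "notes.txt"},
        capability_token="<jwt>",
        user_id="user-123",
    )
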
296
apps/resource-cluster/app/services/model_router.py
Normal file
@@ -0,0 +1,296 @@
"""
|
||||
GT 2.0 Model Router
|
||||
|
||||
Routes inference requests to appropriate providers based on model registry.
|
||||
Integrates with provider factory for dynamic provider selection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, AsyncIterator
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.model_service import get_model_service
|
||||
from app.providers import get_provider_factory
|
||||
from app.core.backends import get_backend
|
||||
from app.core.exceptions import ProviderError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelRouter:
|
||||
"""Routes model requests to appropriate providers"""
|
||||
|
||||
def __init__(self, tenant_id: Optional[str] = None):
|
||||
self.tenant_id = tenant_id
|
||||
# Use default model service for shared model registry (config sync writes to default)
|
||||
# Note: Tenant isolation is handled via capability tokens, not separate databases
|
||||
self.model_service = get_model_service(None)
|
||||
self.provider_factory = None
|
||||
self.backend_cache = {}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize model router"""
|
||||
try:
|
||||
self.provider_factory = await get_provider_factory()
|
||||
logger.info(f"Model router initialized for tenant: {self.tenant_id or 'default'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model router: {e}")
|
||||
raise
|
||||
|
||||
    async def route_inference(
        self,
        model_id: str,
        prompt: Optional[str] = None,
        messages: Optional[list] = None,
        temperature: float = 0.7,
        max_tokens: int = 4000,
        stream: bool = False,
        user_id: Optional[str] = None,
        tenant_id: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Route inference request to appropriate provider"""

        # Get model configuration from registry
        model_config = await self.model_service.get_model(model_id)
        if not model_config:
            raise ProviderError(f"Model {model_id} not found in registry")

        provider = model_config["provider"]

        # Track model usage
        start_time = datetime.utcnow()

        try:
            # Route to configured endpoint (generic routing for any provider).
            # Entries synced from the admin panel carry "endpoint", while
            # locally registered entries carry "endpoint_url" (see
            # model_service.register_model), so accept either key.
            endpoint_url = model_config.get("endpoint") or model_config.get("endpoint_url")
            if not endpoint_url:
                raise ProviderError(f"No endpoint configured for model {model_id}")

            result = await self._route_to_generic_endpoint(
                endpoint_url, model_id, prompt, messages, temperature, max_tokens,
                stream, user_id, tenant_id, tools, tool_choice, **kwargs
            )

            # Calculate latency
            latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000

            # Track successful usage
            await self.model_service.track_model_usage(
                model_id, success=True, latency_ms=latency_ms
            )

            return result

        except Exception as e:
            # Track failed usage
            latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
            await self.model_service.track_model_usage(
                model_id, success=False, latency_ms=latency_ms
            )
            logger.error(f"Model routing failed for {model_id}: {e}")
            raise

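    # Usage sketch (illustrative, not part of this commit); the model id is a
    # hypothetical registry entry:
    #
    #     router = await get_model_router(tenant_id="acme")
    #     result = await router.route_inference(
    #         model_id="llama-3.3-70b",
    #         messages=[{"role": "user", "content": "Hello"}],
    #         temperature=0.2,
    #         max_tokens=256,
    #     )
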
    async def _route_to_groq(
        self,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        tools: Optional[list],
        tool_choice: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to Groq backend"""
        try:
            backend = get_backend("groq_proxy")
            if not backend:
                raise ProviderError("Groq backend not available")

            if messages:
                return await backend.execute_inference_with_messages(
                    messages=messages,
                    model=model_id,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=stream,
                    user_id=user_id,
                    tenant_id=tenant_id,
                    tools=tools,
                    tool_choice=tool_choice
                )
            else:
                return await backend.execute_inference(
                    prompt=prompt,
                    model=model_id,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=stream,
                    user_id=user_id,
                    tenant_id=tenant_id
                )

        except Exception as e:
            logger.error(f"Groq routing failed: {e}")
            raise ProviderError(f"Groq inference failed: {e}")

    async def _route_to_external(
        self,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to external provider"""
        try:
            if not self.provider_factory:
                await self.initialize()

            external_provider = self.provider_factory.get_provider("external")
            if not external_provider:
                raise ProviderError("External provider not available")

            # For embedding models
            if model_id == "bge-m3-embedding":
                # Convert prompt/messages to text list
                texts = []
                if messages:
                    texts = [msg.get("content", "") for msg in messages if msg.get("content")]
                elif prompt:
                    texts = [prompt]

                return await external_provider.generate_embeddings(
                    model_id=model_id,
                    texts=texts
                )
            else:
                raise ProviderError(f"External model {model_id} not supported for inference")

        except Exception as e:
            logger.error(f"External routing failed: {e}")
            raise ProviderError(f"External inference failed: {e}")

    async def _route_to_openai(
        self,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to OpenAI provider"""
        raise ProviderError("OpenAI provider not implemented - use Groq models instead")

    async def _route_to_generic_endpoint(
        self,
        endpoint_url: str,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        tools: Optional[list],
        tool_choice: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to any configured endpoint using OpenAI-compatible API"""
        import httpx

        try:
            # Build OpenAI-compatible request
            request_data = {
                "model": model_id,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream
            }

            # Use messages if provided, otherwise convert prompt to messages
            if messages:
                request_data["messages"] = messages
            elif prompt:
                request_data["messages"] = [{"role": "user", "content": prompt}]
            else:
                raise ProviderError("Either messages or prompt must be provided")

            # Add tools if provided
            if tools:
                request_data["tools"] = tools
            if tool_choice:
                request_data["tool_choice"] = tool_choice

            # Add any additional parameters
            request_data.update(kwargs)

            logger.info(f"Routing request to endpoint: {endpoint_url}")
            logger.debug(f"Request data: {request_data}")

            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    endpoint_url,
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Endpoint {endpoint_url} returned {response.status_code}: {error_text}")
                    raise ProviderError(f"Endpoint error: {response.status_code} - {error_text}")

                result = response.json()
                logger.debug(f"Endpoint response: {result}")
                return result

        except httpx.RequestError as e:
            logger.error(f"Request to {endpoint_url} failed: {e}")
            raise ProviderError(f"Connection to endpoint failed: {str(e)}")
        except Exception as e:
            logger.error(f"Generic endpoint routing failed: {e}")
            raise ProviderError(f"Inference failed: {str(e)}")

    async def list_available_models(self) -> list:
        """List all available models from registry"""
        # Get all models (deployment status filtering available if needed)
        models = await self.model_service.list_models()
        return models

    async def get_model_health(self, model_id: str) -> Dict[str, Any]:
        """Check health of specific model"""
        return await self.model_service.check_model_health(model_id)


# Global model router instances per tenant
_model_routers = {}


async def get_model_router(tenant_id: Optional[str] = None) -> ModelRouter:
    """Get model router instance for tenant"""
    global _model_routers

    cache_key = tenant_id or "default"

    if cache_key not in _model_routers:
        router = ModelRouter(tenant_id)
        await router.initialize()
        _model_routers[cache_key] = router

    return _model_routers[cache_key]

720
apps/resource-cluster/app/services/model_service.py
Normal file
@@ -0,0 +1,720 @@
"""
|
||||
GT 2.0 Model Management Service - Stateless Version
|
||||
|
||||
Provides centralized model registry, versioning, deployment, and lifecycle management
|
||||
for all AI models across the Resource Cluster using in-memory storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class ModelService:
|
||||
"""Stateless model management service with in-memory registry"""
|
||||
|
||||
def __init__(self, tenant_id: Optional[str] = None):
|
||||
self.tenant_id = tenant_id
|
||||
self.settings = get_settings(tenant_id)
|
||||
|
||||
# In-memory model registry for stateless operation
|
||||
self.model_registry: Dict[str, Dict[str, Any]] = {}
|
||||
self.last_cache_update = 0
|
||||
self.cache_ttl = 300 # 5 minutes
|
||||
|
||||
# Performance tracking (in-memory)
|
||||
self.performance_metrics: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Initialize with default models synchronously
|
||||
self._initialize_default_models_sync()
|
||||
|
||||
    async def register_model(
        self,
        model_id: str,
        name: str,
        version: str,
        provider: str,
        model_type: str,
        description: str = "",
        capabilities: Dict[str, Any] = None,
        parameters: Dict[str, Any] = None,
        endpoint_url: str = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Register a new model in the in-memory registry"""

        now = datetime.utcnow()

        # Create or update model entry
        model_entry = {
            "id": model_id,
            "name": name,
            "version": version,
            "provider": provider,
            "model_type": model_type,
            "description": description,
            "capabilities": capabilities or {},
            "parameters": parameters or {},

            # Performance metrics
            "max_tokens": kwargs.get("max_tokens", 4000),
            "context_window": kwargs.get("context_window", 4000),
            "cost_per_1k_tokens": kwargs.get("cost_per_1k_tokens", 0.0),
            "latency_p50_ms": kwargs.get("latency_p50_ms", 0.0),
            "latency_p95_ms": kwargs.get("latency_p95_ms", 0.0),

            # Deployment status
            "deployment_status": kwargs.get("deployment_status", "available"),
            "health_status": kwargs.get("health_status", "unknown"),
            "last_health_check": kwargs.get("last_health_check"),

            # Usage tracking
            "request_count": kwargs.get("request_count", 0),
            "error_count": kwargs.get("error_count", 0),
            "success_rate": kwargs.get("success_rate", 1.0),

            # Lifecycle
            "created_at": now.isoformat(),
            "updated_at": now.isoformat(),
            "retired_at": kwargs.get("retired_at"),

            # Configuration
            "endpoint_url": endpoint_url,
            "api_key_required": kwargs.get("api_key_required", True),
            "rate_limits": kwargs.get("rate_limits", {})
        }

        self.model_registry[model_id] = model_entry

        logger.info(f"Registered model: {model_id} ({name} v{version})")
        return model_entry

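    # Usage sketch (illustrative, not part of this commit); all values below
    # are assumptions chosen for demonstration:
    #
    #     service = ModelService()
    #     await service.register_model(
    #         model_id="llama-3.3-70b",
    #         name="Llama 3.3 70B",
    #         version="1.0",
    #         provider="groq",
    #         model_type="chat",
    #         endpoint_url="https://api.groq.com/openai/v1/chat/completions",
    #         context_window=128000,
    #     )
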
    async def get_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model by ID"""
        return self.model_registry.get(model_id)

    async def list_models(
        self,
        provider: str = None,
        model_type: str = None,
        deployment_status: str = None,
        health_status: str = None
    ) -> List[Dict[str, Any]]:
        """List models with optional filters"""

        models = list(self.model_registry.values())

        # Apply filters
        if provider:
            models = [m for m in models if m["provider"] == provider]
        if model_type:
            models = [m for m in models if m["model_type"] == model_type]
        if deployment_status:
            models = [m for m in models if m["deployment_status"] == deployment_status]
        if health_status:
            models = [m for m in models if m["health_status"] == health_status]

        # Sort by created_at desc
        models.sort(key=lambda x: x["created_at"], reverse=True)
        return models

    async def update_model_status(
        self,
        model_id: str,
        deployment_status: str = None,
        health_status: str = None
    ) -> bool:
        """Update model deployment and health status"""

        model = self.model_registry.get(model_id)
        if not model:
            return False

        if deployment_status:
            model["deployment_status"] = deployment_status
        if health_status:
            model["health_status"] = health_status
            model["last_health_check"] = datetime.utcnow().isoformat()

        model["updated_at"] = datetime.utcnow().isoformat()

        return True

    async def track_model_usage(
        self,
        model_id: str,
        success: bool = True,
        latency_ms: float = None
    ):
        """Track model usage and performance metrics"""

        model = self.model_registry.get(model_id)
        if not model:
            return

        # Update usage counters
        model["request_count"] += 1
        if not success:
            model["error_count"] += 1

        # Calculate success rate
        model["success_rate"] = (model["request_count"] - model["error_count"]) / model["request_count"]

        # Update latency metrics (simple running average)
        if latency_ms is not None:
            if model["latency_p50_ms"] == 0:
                model["latency_p50_ms"] = latency_ms
            else:
                # Simple exponential moving average
                alpha = 0.1
                model["latency_p50_ms"] = alpha * latency_ms + (1 - alpha) * model["latency_p50_ms"]

            # P95 approximation (conservative estimate)
            model["latency_p95_ms"] = max(model["latency_p95_ms"], latency_ms * 1.5)

        model["updated_at"] = datetime.utcnow().isoformat()

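    # Note (illustrative): the latency tracking above is an exponential moving
    # average, p50 ≈ 0.1 * new + 0.9 * old, so a single slow request shifts
    # the estimate by only a tenth of the difference. The p95 figure is a
    # conservative max-based approximation rather than a true percentile.
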
    async def retire_model(self, model_id: str, reason: str = "") -> bool:
        """Retire a model (mark as no longer available)"""

        model = self.model_registry.get(model_id)
        if not model:
            return False

        model["deployment_status"] = "retired"
        model["retired_at"] = datetime.utcnow().isoformat()
        model["updated_at"] = datetime.utcnow().isoformat()

        if reason:
            model["description"] += f"\n\nRetired: {reason}"

        logger.info(f"Retired model: {model_id} - {reason}")
        return True

    async def check_model_health(self, model_id: str) -> Dict[str, Any]:
        """Check health of a specific model"""

        model = await self.get_model(model_id)
        if not model:
            return {"healthy": False, "error": "Model not found"}

        # Generic health check for any provider with an endpoint
        if "endpoint" in model and model["endpoint"]:
            return await self._check_generic_model_health(model)
        elif model["provider"] == "groq":
            return await self._check_groq_model_health(model)
        elif model["provider"] == "openai":
            return await self._check_openai_model_health(model)
        elif model["provider"] == "local":
            return await self._check_local_model_health(model)
        else:
            return {"healthy": False, "error": f"No health check method for provider: {model['provider']}"}

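    # Illustrative usage sketch (hypothetical model id and caller):
    #   result = await service.check_model_health("llama-3.1-8b-instant")
    #   if not result["healthy"]:
    #       logger.warning(f"model degraded: {result.get('error')}")
    # Only "healthy" is guaranteed in the result; other keys vary by branch.
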
    async def _check_groq_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Health check for Groq models"""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.groq.com/openai/v1/models",
                    headers={"Authorization": f"Bearer {settings.groq_api_key}"},
                    timeout=10.0
                )

                if response.status_code == 200:
                    models = response.json()
                    model_ids = [m["id"] for m in models.get("data", [])]
                    is_available = model["id"] in model_ids

                    await self.update_model_status(
                        model["id"],
                        health_status="healthy" if is_available else "unhealthy"
                    )

                    return {
                        "healthy": is_available,
                        "latency_ms": response.elapsed.total_seconds() * 1000,
                        "available_models": len(model_ids)
                    }
                else:
                    await self.update_model_status(model["id"], health_status="unhealthy")
                    return {"healthy": False, "error": f"API error: {response.status_code}"}

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": str(e)}

    async def _check_openai_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Health check for OpenAI models"""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models",
                    headers={"Authorization": f"Bearer {settings.openai_api_key}"},
                    timeout=10.0
                )

                if response.status_code == 200:
                    models = response.json()
                    model_ids = [m["id"] for m in models.get("data", [])]
                    is_available = model["id"] in model_ids

                    await self.update_model_status(
                        model["id"],
                        health_status="healthy" if is_available else "unhealthy"
                    )

                    return {
                        "healthy": is_available,
                        "latency_ms": response.elapsed.total_seconds() * 1000
                    }
                else:
                    await self.update_model_status(model["id"], health_status="unhealthy")
                    return {"healthy": False, "error": f"API error: {response.status_code}"}

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": str(e)}

    async def _check_generic_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Generic health check for any provider with a configured endpoint"""
        try:
            endpoint_url = model.get("endpoint")
            if not endpoint_url:
                return {"healthy": False, "error": "No endpoint URL configured"}

            # Try a simple health check by making a minimal request
            async with httpx.AsyncClient(timeout=10.0) as client:
                # For OpenAI-compatible endpoints, try a models list request
                try:
                    # Try the /v1/models endpoint first (common for OpenAI-compatible APIs)
                    models_url = endpoint_url.replace("/chat/completions", "/models").replace("/v1/chat/completions", "/v1/models")
                    response = await client.get(models_url)

                    if response.status_code == 200:
                        await self.update_model_status(model["id"], health_status="healthy")
                        return {
                            "healthy": True,
                            "provider": model.get("provider", "unknown"),
                            "latency_ms": 0,  # Could measure actual latency
                            "last_check": datetime.utcnow().isoformat(),
                            "details": "Endpoint responding to models request"
                        }
                except Exception:
                    pass

                # If the models endpoint doesn't work, try a basic health endpoint
                try:
                    health_url = endpoint_url.replace("/chat/completions", "/health").replace("/v1/chat/completions", "/health")
                    response = await client.get(health_url)

                    if response.status_code == 200:
                        await self.update_model_status(model["id"], health_status="healthy")
                        return {
                            "healthy": True,
                            "provider": model.get("provider", "unknown"),
                            "latency_ms": 0,
                            "last_check": datetime.utcnow().isoformat(),
                            "details": "Endpoint responding to health check"
                        }
                except Exception:
                    pass

                # If neither probe succeeds, record "unknown" status but still
                # report healthy so generic endpoints are not flagged spuriously
                await self.update_model_status(model["id"], health_status="unknown")
                return {
                    "healthy": True,  # Assume healthy for generic endpoints
                    "provider": model.get("provider", "unknown"),
                    "latency_ms": 0,
                    "last_check": datetime.utcnow().isoformat(),
                    "details": "Generic endpoint - health check not available"
                }

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": f"Health check failed: {str(e)}"}

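    # Illustrative note on the URL rewriting above (hypothetical endpoint):
    #   "https://llm.example.com/v1/chat/completions"
    #     -> models probe: "https://llm.example.com/v1/models"
    #     -> health probe: "https://llm.example.com/health"
    # Endpoints that do not follow the OpenAI-style path layout fall through
    # to the final "unknown" branch.
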
    async def _check_local_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Health check for local models"""
        try:
            endpoint_url = model.get("endpoint_url")
            if not endpoint_url:
                return {"healthy": False, "error": "No endpoint URL configured"}

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{endpoint_url}/health",
                    timeout=5.0
                )

                healthy = response.status_code == 200
                await self.update_model_status(
                    model["id"],
                    health_status="healthy" if healthy else "unhealthy"
                )

                return {
                    "healthy": healthy,
                    "latency_ms": response.elapsed.total_seconds() * 1000
                }

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": str(e)}

    async def bulk_health_check(self) -> Dict[str, Any]:
        """Check health of all registered models"""

        models = await self.list_models()
        health_results = {}

        # Run health checks concurrently
        tasks = []
        for model in models:
            task = asyncio.create_task(self.check_model_health(model["id"]))
            tasks.append((model["id"], task))

        for model_id, task in tasks:
            try:
                health_result = await task
                health_results[model_id] = health_result
            except Exception as e:
                health_results[model_id] = {"healthy": False, "error": str(e)}

        # Calculate overall health statistics
        total_models = len(health_results)
        healthy_models = sum(1 for result in health_results.values() if result.get("healthy", False))

        return {
            "total_models": total_models,
            "healthy_models": healthy_models,
            "unhealthy_models": total_models - healthy_models,
            "health_percentage": (healthy_models / total_models * 100) if total_models > 0 else 0,
            "individual_results": health_results
        }

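    # Illustrative note: wrapping each check in asyncio.create_task() before
    # any await lets the HTTP probes run concurrently; the second loop only
    # collects results. A roughly equivalent form (sketch) would be:
    #   results = await asyncio.gather(
    #       *(self.check_model_health(m["id"]) for m in models),
    #       return_exceptions=True)
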
    async def get_model_analytics(
        self,
        model_id: str = None,
        timeframe_hours: int = 24
    ) -> Dict[str, Any]:
        """Get analytics for model usage and performance"""

        models = await self.list_models()
        if model_id:
            models = [m for m in models if m["id"] == model_id]

        analytics = {
            "total_models": len(models),
            "by_provider": {},
            "by_type": {},
            "performance_summary": {
                "avg_latency_p50": 0,
                "avg_success_rate": 0,
                "total_requests": 0,
                "total_errors": 0
            },
            "top_performers": [],
            "models": models
        }

        total_latency = 0
        total_success_rate = 0
        total_requests = 0
        total_errors = 0

        for model in models:
            # Provider statistics
            provider = model["provider"]
            if provider not in analytics["by_provider"]:
                analytics["by_provider"][provider] = {"count": 0, "requests": 0}
            analytics["by_provider"][provider]["count"] += 1
            analytics["by_provider"][provider]["requests"] += model["request_count"]

            # Type statistics
            model_type = model["model_type"]
            if model_type not in analytics["by_type"]:
                analytics["by_type"][model_type] = {"count": 0, "requests": 0}
            analytics["by_type"][model_type]["count"] += 1
            analytics["by_type"][model_type]["requests"] += model["request_count"]

            # Performance aggregation
            total_latency += model["latency_p50_ms"]
            total_success_rate += model["success_rate"]
            total_requests += model["request_count"]
            total_errors += model["error_count"]

        # Calculate averages
        if len(models) > 0:
            analytics["performance_summary"]["avg_latency_p50"] = total_latency / len(models)
            analytics["performance_summary"]["avg_success_rate"] = total_success_rate / len(models)

        analytics["performance_summary"]["total_requests"] = total_requests
        analytics["performance_summary"]["total_errors"] = total_errors

        # Top performers (by success rate and low latency)
        analytics["top_performers"] = sorted(
            [m for m in models if m["request_count"] > 0],
            key=lambda x: (x["success_rate"], -x["latency_p50_ms"]),
            reverse=True
        )[:5]

        return analytics

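    # Illustrative note on the top_performers sort key: with reverse=True,
    # (success_rate, -latency_p50_ms) ranks by success rate descending and
    # breaks ties by latency ascending - e.g. (0.99, 120 ms) outranks
    # (0.99, 300 ms).
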
    async def _initialize_default_models(self):
        """Initialize registry with default models"""

        # Groq models
        groq_models = [
            {
                "model_id": "llama-3.1-405b-reasoning",
                "name": "Llama 3.1 405B Reasoning",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Largest Llama model optimized for complex reasoning tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 2.5,
                "capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-70b-versatile",
                "name": "Llama 3.1 70B Versatile",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Balanced Llama model for general-purpose tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.8,
                "capabilities": {"general": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-8b-instant",
                "name": "Llama 3.1 8B Instant",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Fast Llama model for quick responses",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.2,
                "capabilities": {"fast": True, "streaming": True}
            },
            {
                "model_id": "mixtral-8x7b-32768",
                "name": "Mixtral 8x7B",
                "version": "1.0",
                "provider": "groq",
                "model_type": "llm",
                "description": "Mixtral model for balanced performance",
                "max_tokens": 32768,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.27,
                "capabilities": {"general": True, "streaming": True}
            }
        ]

        for model_config in groq_models:
            await self.register_model(**model_config)

        logger.info("Initialized default model registry with in-memory storage")

    def _initialize_default_models_sync(self):
        """Initialize registry with default models synchronously"""

        # Groq models
        groq_models = [
            {
                "model_id": "llama-3.1-405b-reasoning",
                "name": "Llama 3.1 405B Reasoning",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Largest Llama model optimized for complex reasoning tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 2.5,
                "capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-70b-versatile",
                "name": "Llama 3.1 70B Versatile",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Balanced Llama model for general-purpose tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.8,
                "capabilities": {"general": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-8b-instant",
                "name": "Llama 3.1 8B Instant",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Fast Llama model for quick responses",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.2,
                "capabilities": {"fast": True, "streaming": True}
            },
            {
                "model_id": "mixtral-8x7b-32768",
                "name": "Mixtral 8x7B",
                "version": "1.0",
                "provider": "groq",
                "model_type": "llm",
                "description": "Mixtral model for balanced performance",
                "max_tokens": 32768,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.27,
                "capabilities": {"general": True, "streaming": True}
            },
            {
                "model_id": "groq/compound",
                "name": "Groq Compound Model",
                "version": "1.0",
                "provider": "groq",
                "model_type": "llm",
                "description": "Groq compound AI model",
                "max_tokens": 8000,
                "context_window": 8000,
                "cost_per_1k_tokens": 0.5,
                "capabilities": {"general": True, "streaming": True}
            }
        ]

        for model_config in groq_models:
            now = datetime.utcnow()
            model_entry = {
                "id": model_config["model_id"],
                "name": model_config["name"],
                "version": model_config["version"],
                "provider": model_config["provider"],
                "model_type": model_config["model_type"],
                "description": model_config["description"],
                "capabilities": model_config["capabilities"],
                "parameters": {},

                # Performance metrics
                "max_tokens": model_config["max_tokens"],
                "context_window": model_config["context_window"],
                "cost_per_1k_tokens": model_config["cost_per_1k_tokens"],
                "latency_p50_ms": 0.0,
                "latency_p95_ms": 0.0,

                # Deployment status
                "deployment_status": "available",
                "health_status": "unknown",
                "last_health_check": None,

                # Usage tracking
                "request_count": 0,
                "error_count": 0,
                "success_rate": 1.0,

                # Lifecycle
                "created_at": now.isoformat(),
                "updated_at": now.isoformat(),
                "retired_at": None,

                # Configuration
                "endpoint_url": None,
                "api_key_required": True,
                "rate_limits": {}
            }

            self.model_registry[model_config["model_id"]] = model_entry

        logger.info("Initialized default model registry with in-memory storage (sync)")

    async def register_or_update_model(
        self,
        model_id: str,
        name: str,
        version: str = "1.0",
        provider: str = "unknown",
        model_type: str = "llm",
        endpoint: str = "",
        api_key_name: str = None,
        specifications: Dict[str, Any] = None,
        capabilities: Dict[str, Any] = None,
        cost: Dict[str, Any] = None,
        description: str = "",
        config: Dict[str, Any] = None,
        status: Dict[str, Any] = None,
        sync_timestamp: str = None
    ) -> Dict[str, Any]:
        """Register a new model or update an existing one from admin cluster sync"""

        specifications = specifications or {}
        capabilities = capabilities or {}
        cost = cost or {}
        config = config or {}
        status = status or {}

        # Check if the model already exists
        existing_model = self.model_registry.get(model_id)

        if existing_model:
            # Update existing model
            existing_model.update({
                "name": name,
                "version": version,
                "provider": provider,
                "model_type": model_type,
                "description": description,
                "capabilities": capabilities,
                "parameters": config,
                "endpoint_url": endpoint,
                "api_key_required": bool(api_key_name),
                "max_tokens": specifications.get("max_tokens", existing_model.get("max_tokens", 4000)),
                "context_window": specifications.get("context_window", existing_model.get("context_window", 4000)),
                "cost_per_1k_tokens": cost.get("per_1k_input", existing_model.get("cost_per_1k_tokens", 0.0)),
                "deployment_status": "deployed" if status.get("is_active", True) else "retired",
                "updated_at": datetime.utcnow().isoformat()
            })

            if "bge-m3" in model_id.lower():
                logger.info(f"Updated BGE-M3 model: endpoint_url={endpoint}, parameters={config}")
            logger.debug(f"Updated model: {model_id}")
            return existing_model
        else:
            # Register a new model
            return await self.register_model(
                model_id=model_id,
                name=name,
                version=version,
                provider=provider,
                model_type=model_type,
                description=description,
                capabilities=capabilities,
                parameters=config,
                endpoint_url=endpoint,
                max_tokens=specifications.get("max_tokens", 4000),
                context_window=specifications.get("context_window", 4000),
                cost_per_1k_tokens=cost.get("per_1k_input", 0.0),
                api_key_required=bool(api_key_name)
            )


def get_model_service(tenant_id: Optional[str] = None) -> ModelService:
    """Get tenant-isolated model service instance"""
    return ModelService(tenant_id=tenant_id)


# Default model service for development/non-tenant operations
default_model_service = get_model_service()

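# Illustrative usage sketch (hypothetical tenant id and model id):
#   service = get_model_service(tenant_id="acme")
#   await service.track_model_usage("llama-3.1-8b-instant", success=True, latency_ms=240.0)
#   report = await service.bulk_health_check()
#   print(report["health_percentage"])
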
931
apps/resource-cluster/app/services/service_manager.py
Normal file
@@ -0,0 +1,931 @@
"""
GT 2.0 Resource Cluster - Service Manager
Orchestrates external web services (CTFd, Canvas LMS, Guacamole, JupyterHub)
with perfect tenant isolation and security.
"""

import asyncio
import json
import logging
import subprocess
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, asdict, field
from pathlib import Path

try:
    import docker
    import kubernetes
    from kubernetes import client, config
    from kubernetes.client.rest import ApiException
    DOCKER_AVAILABLE = True
    KUBERNETES_AVAILABLE = True
except ImportError:
    # For development containerization mode, these are optional
    docker = None
    kubernetes = None
    client = None
    config = None
    ApiException = Exception
    DOCKER_AVAILABLE = False
    KUBERNETES_AVAILABLE = False

from app.core.config import get_settings
from app.core.security import verify_capability_token
from app.utils.encryption import encrypt_data, decrypt_data

logger = logging.getLogger(__name__)

@dataclass
class ServiceInstance:
    """Represents a deployed service instance"""
    instance_id: str
    tenant_id: str
    service_type: str  # 'ctfd', 'canvas', 'guacamole', 'jupyter'
    status: str  # 'starting', 'running', 'stopping', 'stopped', 'error'
    endpoint_url: str
    internal_port: int
    external_port: int
    namespace: str
    deployment_name: str
    service_name: str
    ingress_name: str
    sso_token: Optional[str] = None
    # default_factory gives each instance its own timestamp; a plain
    # datetime.utcnow() default would be evaluated once at class definition
    # and shared by every instance
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_heartbeat: datetime = field(default_factory=datetime.utcnow)
    resource_usage: Dict[str, Any] = None

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        data['last_heartbeat'] = self.last_heartbeat.isoformat()
        return data

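# Illustrative sketch (hypothetical values): an instance round-trips through
# to_dict() for JSON persistence, with datetimes rendered as ISO-8601 strings:
#   inst = ServiceInstance(instance_id="ctfd-acme-1a2b3c4d", tenant_id="acme",
#                          service_type="ctfd", status="running",
#                          endpoint_url="https://ctfd.acme.gt2.com",
#                          internal_port=8000, external_port=30000,
#                          namespace="gt-services-acme",
#                          deployment_name="ctfd-ctfd-acme-1a2b3c4d",
#                          service_name="ctfd-service-ctfd-acme-1a2b3c4d",
#                          ingress_name="ctfd-ingress-ctfd-acme-1a2b3c4d")
#   json.dumps(inst.to_dict())
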
@dataclass
class ServiceTemplate:
    """Service deployment template configuration"""
    service_type: str
    image: str
    ports: Dict[str, int]
    environment: Dict[str, str]
    volumes: List[Dict[str, str]]
    resource_limits: Dict[str, str]
    security_context: Dict[str, Any]
    health_check: Dict[str, Any]
    sso_config: Dict[str, Any]

class ServiceManager:
    """Manages external web service instances with Kubernetes orchestration"""

    def __init__(self):
        # Initialize Docker client if available
        if DOCKER_AVAILABLE:
            try:
                self.docker_client = docker.from_env()
            except Exception as e:
                logger.warning(f"Could not initialize Docker client: {e}")
                self.docker_client = None
        else:
            self.docker_client = None

        self.k8s_client = None
        self.active_instances: Dict[str, ServiceInstance] = {}
        self.service_templates: Dict[str, ServiceTemplate] = {}
        self.base_namespace = "gt-services"
        self.storage_path = Path("/tmp/resource-cluster/services")
        self.storage_path.mkdir(parents=True, exist_ok=True)

        # Initialize Kubernetes client if available
        if KUBERNETES_AVAILABLE:
            try:
                config.load_incluster_config()  # If running in cluster
            except Exception:
                try:
                    config.load_kube_config()  # If running locally
                except Exception:
                    logger.warning("Could not load Kubernetes config - using mock mode")

            self.k8s_client = client.ApiClient() if client else None
        else:
            logger.warning("Kubernetes not available - running in development containerization mode")
        self._initialize_service_templates()
        self._load_persistent_instances()

    def _initialize_service_templates(self):
        """Initialize service deployment templates"""

        # CTFd Template
        self.service_templates['ctfd'] = ServiceTemplate(
            service_type='ctfd',
            image='ctfd/ctfd:3.6.0',
            ports={'http': 8000},
            environment={
                'SECRET_KEY': '${TENANT_SECRET_KEY}',
                'DATABASE_URL': 'sqlite:////data/ctfd.db',
                'DATABASE_CACHE_URL': 'postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants',
                'UPLOAD_FOLDER': '/data/uploads',
                'LOG_FOLDER': '/data/logs',
            },
            volumes=[
                {'name': 'ctfd-data', 'mountPath': '/data', 'size': '5Gi'},
                {'name': 'ctfd-uploads', 'mountPath': '/uploads', 'size': '2Gi'}
            ],
            resource_limits={
                'memory': '2Gi',
                'cpu': '1000m'
            },
            security_context={
                'runAsNonRoot': True,
                'runAsUser': 1000,
                'fsGroup': 1000,
                'readOnlyRootFilesystem': False
            },
            health_check={
                'path': '/health',
                'port': 8000,
                'initial_delay': 30,
                'period': 10
            },
            sso_config={
                'enabled': True,
                'provider': 'oauth2',
                'callback_path': '/auth/oauth/callback'
            }
        )

        # Canvas LMS Template
        self.service_templates['canvas'] = ServiceTemplate(
            service_type='canvas',
            image='instructure/canvas-lms:stable',
            ports={'http': 3000},
            environment={
                'CANVAS_LMS_ADMIN_EMAIL': 'admin@${TENANT_DOMAIN}',
                'CANVAS_LMS_ADMIN_PASSWORD': '${CANVAS_ADMIN_PASSWORD}',
                'CANVAS_LMS_ACCOUNT_NAME': '${TENANT_NAME}',
                'CANVAS_LMS_STATS_COLLECTION': 'opt_out',
                'POSTGRES_PASSWORD': '${POSTGRES_PASSWORD}',
                'DATABASE_CACHE_URL': 'postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants'
            },
            volumes=[
                {'name': 'canvas-data', 'mountPath': '/app/log', 'size': '10Gi'},
                {'name': 'canvas-files', 'mountPath': '/app/public/files', 'size': '20Gi'}
            ],
            resource_limits={
                'memory': '4Gi',
                'cpu': '2000m'
            },
            security_context={
                'runAsNonRoot': True,
                'runAsUser': 1000,
                'fsGroup': 1000
            },
            health_check={
                'path': '/health_check',
                'port': 3000,
                'initial_delay': 60,
                'period': 15
            },
            sso_config={
                'enabled': True,
                'provider': 'saml',
                'metadata_url': '/auth/saml/metadata'
            }
        )

        # Guacamole Template
        self.service_templates['guacamole'] = ServiceTemplate(
            service_type='guacamole',
            image='guacamole/guacamole:1.5.3',
            ports={'http': 8080},
            environment={
                'GUACD_HOSTNAME': 'guacd',
                'GUACD_PORT': '4822',
                'MYSQL_HOSTNAME': 'mysql',
                'MYSQL_PORT': '3306',
                'MYSQL_DATABASE': 'guacamole_db',
                'MYSQL_USER': 'guacamole_user',
                'MYSQL_PASSWORD': '${MYSQL_PASSWORD}',
                'GUAC_LOG_LEVEL': 'INFO'
            },
            volumes=[
                {'name': 'guacamole-data', 'mountPath': '/config', 'size': '1Gi'},
                {'name': 'guacamole-recordings', 'mountPath': '/recordings', 'size': '10Gi'}
            ],
            resource_limits={
                'memory': '1Gi',
                'cpu': '500m'
            },
            security_context={
                'runAsNonRoot': True,
                'runAsUser': 1001,
                'fsGroup': 1001
            },
            health_check={
                'path': '/guacamole',
                'port': 8080,
                'initial_delay': 45,
                'period': 10
            },
            sso_config={
                'enabled': True,
                'provider': 'openid',
                'extension': 'guacamole-auth-openid'
            }
        )

        # JupyterHub Template
        self.service_templates['jupyter'] = ServiceTemplate(
            service_type='jupyter',
            image='jupyterhub/jupyterhub:4.0',
            ports={'http': 8000},
            environment={
                'JUPYTERHUB_CRYPT_KEY': '${JUPYTERHUB_CRYPT_KEY}',
                'CONFIGPROXY_AUTH_TOKEN': '${CONFIGPROXY_AUTH_TOKEN}',
                'DOCKER_NETWORK_NAME': 'jupyterhub',
                'DOCKER_NOTEBOOK_IMAGE': 'jupyter/datascience-notebook:lab-4.0.7'
            },
            volumes=[
                {'name': 'jupyter-data', 'mountPath': '/srv/jupyterhub', 'size': '5Gi'},
                {'name': 'docker-socket', 'mountPath': '/var/run/docker.sock', 'hostPath': '/var/run/docker.sock'}
            ],
            resource_limits={
                'memory': '2Gi',
                'cpu': '1000m'
            },
            security_context={
                'runAsNonRoot': False,  # Needs Docker access
                'runAsUser': 0,
                'privileged': True
            },
            health_check={
                'path': '/hub/health',
                'port': 8000,
                'initial_delay': 30,
                'period': 15
            },
            sso_config={
                'enabled': True,
                'provider': 'oauth',
                'authenticator_class': 'oauthenticator.generic.GenericOAuthenticator'
            }
        )

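    # Illustrative sketch: registering an additional template (hypothetical
    # 'gitea' service) follows the same shape - any '${...}' placeholders in
    # environment values are substituted per tenant at deploy time:
    #   self.service_templates['gitea'] = ServiceTemplate(
    #       service_type='gitea', image='gitea/gitea:1.21',
    #       ports={'http': 3000},
    #       environment={'GITEA__server__DOMAIN': '${TENANT_DOMAIN}'},
    #       volumes=[{'name': 'gitea-data', 'mountPath': '/data', 'size': '5Gi'}],
    #       resource_limits={'memory': '1Gi', 'cpu': '500m'},
    #       security_context={'runAsNonRoot': True, 'runAsUser': 1000, 'fsGroup': 1000},
    #       health_check={'path': '/api/healthz', 'port': 3000, 'initial_delay': 30, 'period': 10},
    #       sso_config={'enabled': True, 'provider': 'oauth2', 'callback_path': '/user/oauth2/callback'}
    #   )
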
    async def create_service_instance(
        self,
        tenant_id: str,
        service_type: str,
        config_overrides: Dict[str, Any] = None
    ) -> ServiceInstance:
        """Create a new service instance for a tenant"""

        if service_type not in self.service_templates:
            raise ValueError(f"Unsupported service type: {service_type}")

        template = self.service_templates[service_type]
        instance_id = f"{service_type}-{tenant_id}-{uuid.uuid4().hex[:8]}"
        namespace = f"{self.base_namespace}-{tenant_id}"

        # Generate unique ports
        external_port = await self._get_available_port()

        # Create service instance object
        instance = ServiceInstance(
            instance_id=instance_id,
            tenant_id=tenant_id,
            service_type=service_type,
            status='starting',
            endpoint_url=f"https://{service_type}.{tenant_id}.gt2.com",
            internal_port=template.ports['http'],
            external_port=external_port,
            namespace=namespace,
            deployment_name=f"{service_type}-{instance_id}",
            service_name=f"{service_type}-service-{instance_id}",
            ingress_name=f"{service_type}-ingress-{instance_id}",
            resource_usage={'cpu': 0, 'memory': 0, 'storage': 0}
        )

        try:
            # Create the Kubernetes namespace if it does not exist
            await self._create_namespace(namespace, tenant_id)

            # Deploy the service
            await self._deploy_service(instance, template, config_overrides)

            # Generate SSO token
            instance.sso_token = await self._generate_sso_token(instance)

            # Store instance
            self.active_instances[instance_id] = instance
            await self._persist_instance(instance)

            logger.info(f"Created {service_type} instance {instance_id} for tenant {tenant_id}")
            return instance

        except Exception as e:
            logger.error(f"Failed to create service instance: {e}")
            instance.status = 'error'
            raise

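    # Illustrative usage sketch (hypothetical tenant):
    #   manager = ServiceManager()
    #   instance = await manager.create_service_instance("acme", "ctfd")
    #   print(instance.endpoint_url)   # -> https://ctfd.acme.gt2.com
    # The call raises ValueError for unknown service types and re-raises any
    # deployment failure after marking the instance status as 'error'.
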
    async def _create_namespace(self, namespace: str, tenant_id: str):
        """Create Kubernetes namespace with proper labeling and network policies"""

        if not self.k8s_client:
            logger.info(f"Mock: Created namespace {namespace}")
            return

        v1 = client.CoreV1Api(self.k8s_client)

        # Create namespace
        namespace_manifest = client.V1Namespace(
            metadata=client.V1ObjectMeta(
                name=namespace,
                labels={
                    'gt.tenant-id': tenant_id,
                    'gt.cluster': 'resource',
                    'gt.isolation': 'tenant'
                },
                annotations={
                    'gt.created-by': 'service-manager',
                    'gt.creation-time': datetime.utcnow().isoformat()
                }
            )
        )

        try:
            v1.create_namespace(namespace_manifest)
            logger.info(f"Created namespace: {namespace}")
        except ApiException as e:
            if e.status == 409:  # Already exists
                logger.info(f"Namespace {namespace} already exists")
            else:
                raise

        # Apply network policy for tenant isolation
        await self._apply_network_policy(namespace, tenant_id)

    async def _apply_network_policy(self, namespace: str, tenant_id: str):
        """Apply network policy for tenant isolation"""

        if not self.k8s_client:
            logger.info(f"Mock: Applied network policy to {namespace}")
            return

        networking_v1 = client.NetworkingV1Api(self.k8s_client)

        # Network policy that only allows:
        # 1. Intra-namespace communication
        # 2. Communication to system namespaces (DNS, etc.)
        # 3. Egress to external services (for updates, etc.)
        network_policy = client.V1NetworkPolicy(
            metadata=client.V1ObjectMeta(
                name=f"tenant-isolation-{tenant_id}",
                namespace=namespace,
                labels={'gt.tenant-id': tenant_id}
            ),
            spec=client.V1NetworkPolicySpec(
                pod_selector=client.V1LabelSelector(),  # All pods in namespace
                policy_types=['Ingress', 'Egress'],
                ingress=[
                    # Allow ingress from the same namespace
                    # (the Python client exposes the reserved word 'from' as
                    # the '_from' keyword argument)
                    client.V1NetworkPolicyIngressRule(
                        _from=[client.V1NetworkPolicyPeer(
                            namespace_selector=client.V1LabelSelector(
                                match_labels={'name': namespace}
                            )
                        )]
                    ),
                    # Allow ingress from the ingress controller
                    client.V1NetworkPolicyIngressRule(
                        _from=[client.V1NetworkPolicyPeer(
                            namespace_selector=client.V1LabelSelector(
                                match_labels={'name': 'ingress-nginx'}
                            )
                        )]
                    )
                ],
                egress=[
                    # Allow egress within the namespace
                    client.V1NetworkPolicyEgressRule(
                        to=[client.V1NetworkPolicyPeer(
                            namespace_selector=client.V1LabelSelector(
                                match_labels={'name': namespace}
                            )
                        )]
                    ),
                    # Allow DNS
                    client.V1NetworkPolicyEgressRule(
                        to=[client.V1NetworkPolicyPeer(
                            namespace_selector=client.V1LabelSelector(
                                match_labels={'name': 'kube-system'}
                            )
                        )],
                        ports=[client.V1NetworkPolicyPort(port=53, protocol='UDP')]
                    ),
                    # Allow external HTTPS (for updates, etc.)
                    client.V1NetworkPolicyEgressRule(
                        ports=[
                            client.V1NetworkPolicyPort(port=443, protocol='TCP'),
                            client.V1NetworkPolicyPort(port=80, protocol='TCP')
                        ]
                    )
                ]
            )
        )

        try:
            networking_v1.create_namespaced_network_policy(
                namespace=namespace,
                body=network_policy
            )
            logger.info(f"Applied network policy to namespace: {namespace}")
        except ApiException as e:
            if e.status == 409:  # Already exists
                logger.info(f"Network policy already exists in {namespace}")
            else:
                logger.error(f"Failed to create network policy: {e}")
                raise

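    # For reference, the generated object corresponds roughly to this manifest
    # (illustrative sketch; assumes namespaces carry a 'name: <ns>' label):
    #   apiVersion: networking.k8s.io/v1
    #   kind: NetworkPolicy
    #   spec:
    #     podSelector: {}
    #     policyTypes: [Ingress, Egress]
    #     ingress:  same-namespace peers, plus the ingress-nginx namespace
    #     egress:   same-namespace peers, kube-system on 53/UDP, and 80+443/TCP
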
    async def _deploy_service(
        self,
        instance: ServiceInstance,
        template: ServiceTemplate,
        config_overrides: Dict[str, Any] = None
    ):
        """Deploy service to Kubernetes cluster"""

        if not self.k8s_client:
            logger.info(f"Mock: Deployed {template.service_type} service")
            instance.status = 'running'
            return

        # Prepare environment variables with tenant-specific values
        environment = template.environment.copy()
        if config_overrides:
            environment.update(config_overrides.get('environment', {}))

        # Substitute tenant-specific values
        env_vars = []
        for key, value in environment.items():
            substituted_value = value.replace('${TENANT_ID}', instance.tenant_id)
            substituted_value = substituted_value.replace('${TENANT_DOMAIN}', f"{instance.tenant_id}.gt2.com")
            env_vars.append(client.V1EnvVar(name=key, value=substituted_value))

        # Create volumes
        volumes = []
        volume_mounts = []
        for vol_config in template.volumes:
            vol_name = f"{vol_config['name']}-{instance.instance_id}"
            volumes.append(client.V1Volume(
                name=vol_name,
                persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
                    claim_name=vol_name
                )
            ))
            volume_mounts.append(client.V1VolumeMount(
                name=vol_name,
                mount_path=vol_config['mountPath']
            ))

        # Create PVCs first
        await self._create_persistent_volumes(instance, template)

        # Create deployment
        deployment = client.V1Deployment(
            metadata=client.V1ObjectMeta(
                name=instance.deployment_name,
                namespace=instance.namespace,
                labels={
                    'app': template.service_type,
                    'instance': instance.instance_id,
                    'gt.tenant-id': instance.tenant_id,
                    'gt.service-type': template.service_type
                }
            ),
            spec=client.V1DeploymentSpec(
                replicas=1,
                selector=client.V1LabelSelector(
                    match_labels={'instance': instance.instance_id}
                ),
                template=client.V1PodTemplateSpec(
                    metadata=client.V1ObjectMeta(
                        labels={
                            'app': template.service_type,
                            'instance': instance.instance_id,
                            'gt.tenant-id': instance.tenant_id
                        }
                    ),
                    spec=client.V1PodSpec(
                        containers=[client.V1Container(
                            name=template.service_type,
                            image=template.image,
                            ports=[client.V1ContainerPort(
                                container_port=template.ports['http']
                            )],
                            env=env_vars,
                            volume_mounts=volume_mounts,
                            resources=client.V1ResourceRequirements(
                                limits=template.resource_limits,
                                requests=template.resource_limits
                            ),
                            # Template keys use Kubernetes YAML camelCase; the
                            # Python client expects snake_case kwargs, so map
                            # them explicitly (fsGroup is pod-level, below)
                            security_context=client.V1SecurityContext(
                                run_as_non_root=template.security_context.get('runAsNonRoot', True),
                                run_as_user=template.security_context.get('runAsUser'),
                                privileged=template.security_context.get('privileged'),
                                read_only_root_filesystem=template.security_context.get('readOnlyRootFilesystem')
                            ),
                            liveness_probe=client.V1Probe(
                                http_get=client.V1HTTPGetAction(
                                    path=template.health_check['path'],
                                    port=template.health_check['port']
                                ),
                                initial_delay_seconds=template.health_check['initial_delay'],
                                period_seconds=template.health_check['period']
                            ),
                            readiness_probe=client.V1Probe(
                                http_get=client.V1HTTPGetAction(
                                    path=template.health_check['path'],
                                    port=template.health_check['port']
                                ),
                                initial_delay_seconds=10,
                                period_seconds=5
                            )
                        )],
                        volumes=volumes,
                        security_context=client.V1PodSecurityContext(
                            run_as_non_root=template.security_context.get('runAsNonRoot', True),
                            fs_group=template.security_context.get('fsGroup', 1000)
                        )
                    )
                )
            )
        )

        # Deploy to Kubernetes
        apps_v1 = client.AppsV1Api(self.k8s_client)
        apps_v1.create_namespaced_deployment(
            namespace=instance.namespace,
            body=deployment
        )

        # Create service
        await self._create_service(instance, template)

        # Create ingress
        await self._create_ingress(instance, template)

        logger.info(f"Deployed {template.service_type} service: {instance.deployment_name}")

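    # Illustrative note on the substitution above (hypothetical tenant "acme"):
    #   'admin@${TENANT_DOMAIN}'  ->  'admin@acme.gt2.com'
    # Only ${TENANT_ID} and ${TENANT_DOMAIN} are substituted here; other
    # placeholders (e.g. ${TENANT_SECRET_KEY}) pass through unchanged.
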
    async def _create_persistent_volumes(self, instance: ServiceInstance, template: ServiceTemplate):
        """Create persistent volume claims for the service"""

        if not self.k8s_client:
            return

        v1 = client.CoreV1Api(self.k8s_client)

        for vol_config in template.volumes:
            if 'hostPath' in vol_config:  # Skip host path volumes
                continue

            pvc_name = f"{vol_config['name']}-{instance.instance_id}"

            pvc = client.V1PersistentVolumeClaim(
                metadata=client.V1ObjectMeta(
                    name=pvc_name,
                    namespace=instance.namespace,
                    labels={
                        'app': template.service_type,
                        'instance': instance.instance_id,
                        'gt.tenant-id': instance.tenant_id
                    }
                ),
                spec=client.V1PersistentVolumeClaimSpec(
                    access_modes=['ReadWriteOnce'],
                    resources=client.V1ResourceRequirements(
                        requests={'storage': vol_config['size']}
                    ),
                    storage_class_name='fast-ssd'  # Assuming an SSD storage class
                )
            )

            try:
                v1.create_namespaced_persistent_volume_claim(
                    namespace=instance.namespace,
                    body=pvc
                )
                logger.info(f"Created PVC: {pvc_name}")
            except ApiException as e:
                if e.status != 409:  # Ignore if it already exists
                    raise

    async def _create_service(self, instance: ServiceInstance, template: ServiceTemplate):
        """Create Kubernetes service for the instance"""

        if not self.k8s_client:
            return

        v1 = client.CoreV1Api(self.k8s_client)

        service = client.V1Service(
            metadata=client.V1ObjectMeta(
                name=instance.service_name,
                namespace=instance.namespace,
                labels={
                    'app': template.service_type,
                    'instance': instance.instance_id,
                    'gt.tenant-id': instance.tenant_id
                }
            ),
            spec=client.V1ServiceSpec(
                selector={'instance': instance.instance_id},
                ports=[client.V1ServicePort(
                    port=80,
                    target_port=template.ports['http'],
                    protocol='TCP'
                )],
                type='ClusterIP'
            )
        )

        v1.create_namespaced_service(
            namespace=instance.namespace,
            body=service
        )

        logger.info(f"Created service: {instance.service_name}")

    async def _create_ingress(self, instance: ServiceInstance, template: ServiceTemplate):
        """Create ingress for external access with TLS"""

        if not self.k8s_client:
            return

        networking_v1 = client.NetworkingV1Api(self.k8s_client)

        hostname = f"{template.service_type}.{instance.tenant_id}.gt2.com"

        ingress = client.V1Ingress(
            metadata=client.V1ObjectMeta(
                name=instance.ingress_name,
                namespace=instance.namespace,
                labels={
                    'app': template.service_type,
                    'instance': instance.instance_id,
                    'gt.tenant-id': instance.tenant_id
                },
                annotations={
                    'kubernetes.io/ingress.class': 'nginx',
                    'cert-manager.io/cluster-issuer': 'letsencrypt-prod',
                    'nginx.ingress.kubernetes.io/ssl-redirect': 'true',
                    'nginx.ingress.kubernetes.io/force-ssl-redirect': 'true',
                    'nginx.ingress.kubernetes.io/auth-url': f'https://auth.{instance.tenant_id}.gt2.com/auth',
                    'nginx.ingress.kubernetes.io/auth-signin': f'https://auth.{instance.tenant_id}.gt2.com/signin'
                }
            ),
            spec=client.V1IngressSpec(
                tls=[client.V1IngressTLS(
                    hosts=[hostname],
                    secret_name=f"{template.service_type}-tls-{instance.instance_id}"
                )],
                rules=[client.V1IngressRule(
                    host=hostname,
                    http=client.V1HTTPIngressRuleValue(
                        paths=[client.V1HTTPIngressPath(
                            path='/',
                            path_type='Prefix',
                            backend=client.V1IngressBackend(
                                service=client.V1IngressServiceBackend(
                                    name=instance.service_name,
                                    port=client.V1ServiceBackendPort(number=80)
                                )
                            )
                        )]
                    )
                )]
            )
        )

        networking_v1.create_namespaced_ingress(
            namespace=instance.namespace,
            body=ingress
        )

        logger.info(f"Created ingress: {instance.ingress_name} for {hostname}")

    async def _get_available_port(self) -> int:
        """Get the next available port for a service"""
        used_ports = {instance.external_port for instance in self.active_instances.values()}
        port = 30000  # Start from the NodePort range
        while port in used_ports:
            port += 1
        return port

    async def _generate_sso_token(self, instance: ServiceInstance) -> str:
        """Generate SSO token for iframe embedding"""
        token_data = {
            'tenant_id': instance.tenant_id,
            'service_type': instance.service_type,
            'instance_id': instance.instance_id,
            'expires_at': (datetime.utcnow() + timedelta(hours=24)).isoformat(),
            'permissions': ['read', 'write', 'admin']
        }

        # Encrypt the token data
        encrypted_token = encrypt_data(json.dumps(token_data))
        return encrypted_token.decode('utf-8')

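    # Illustrative sketch (assuming decrypt_data is the inverse of
    # encrypt_data): a consumer would validate the token roughly as
    #   payload = json.loads(decrypt_data(token.encode('utf-8')))
    #   if datetime.fromisoformat(payload['expires_at']) < datetime.utcnow():
    #       raise PermissionError("SSO token expired")
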
    async def get_service_instance(self, instance_id: str) -> Optional[ServiceInstance]:
        """Get service instance by ID"""
        return self.active_instances.get(instance_id)

    async def list_tenant_instances(self, tenant_id: str) -> List[ServiceInstance]:
        """List all service instances for a tenant"""
        return [
            instance for instance in self.active_instances.values()
            if instance.tenant_id == tenant_id
        ]

    async def stop_service_instance(self, instance_id: str) -> bool:
        """Stop a running service instance"""
        instance = self.active_instances.get(instance_id)
        if not instance:
            return False

        try:
            instance.status = 'stopping'

            if self.k8s_client:
                # Delete Kubernetes resources
                await self._cleanup_kubernetes_resources(instance)

            instance.status = 'stopped'
            logger.info(f"Stopped service instance: {instance_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to stop instance {instance_id}: {e}")
            instance.status = 'error'
            return False

    async def _cleanup_kubernetes_resources(self, instance: ServiceInstance):
        """Clean up all Kubernetes resources for an instance"""

        if not self.k8s_client:
            return

        apps_v1 = client.AppsV1Api(self.k8s_client)
        v1 = client.CoreV1Api(self.k8s_client)
        networking_v1 = client.NetworkingV1Api(self.k8s_client)

        try:
            # Delete deployment
            apps_v1.delete_namespaced_deployment(
                name=instance.deployment_name,
                namespace=instance.namespace,
                body=client.V1DeleteOptions()
            )

            # Delete service
            v1.delete_namespaced_service(
                name=instance.service_name,
                namespace=instance.namespace,
                body=client.V1DeleteOptions()
            )

            # Delete ingress
            networking_v1.delete_namespaced_ingress(
                name=instance.ingress_name,
                namespace=instance.namespace,
                body=client.V1DeleteOptions()
            )

            # Delete PVCs (optional - may want to preserve data)
            # Note: In production, you might want to keep PVCs for data persistence

            logger.info(f"Cleaned up Kubernetes resources for: {instance.instance_id}")

        except ApiException as e:
            logger.error(f"Error cleaning up resources: {e}")
            raise

    async def get_service_health(self, instance_id: str) -> Dict[str, Any]:
        """Get health status of a service instance"""
        instance = self.active_instances.get(instance_id)
        if not instance:
            return {'status': 'not_found'}

        if not self.k8s_client:
            return {
                'status': 'healthy',
                'instance_status': instance.status,
                'endpoint': instance.endpoint_url,
                'last_check': datetime.utcnow().isoformat()
            }

        # Check Kubernetes pod status
        v1 = client.CoreV1Api(self.k8s_client)

        try:
            pods = v1.list_namespaced_pod(
                namespace=instance.namespace,
                label_selector=f'instance={instance.instance_id}'
            )

            if not pods.items:
                return {
                    'status': 'no_pods',
                    'instance_status': instance.status
                }

            pod = pods.items[0]
            pod_status = 'unknown'

            if pod.status.phase == 'Running':
                # Check container status
                if pod.status.container_statuses:
                    container_status = pod.status.container_statuses[0]
                    if container_status.ready:
                        pod_status = 'healthy'
                    else:
                        pod_status = 'unhealthy'
                else:
                    pod_status = 'starting'
            elif pod.status.phase == 'Pending':
                pod_status = 'starting'
            elif pod.status.phase == 'Failed':
                pod_status = 'failed'

            # Update instance heartbeat
            instance.last_heartbeat = datetime.utcnow()

            return {
                'status': pod_status,
                'instance_status': instance.status,
                'pod_phase': pod.status.phase,
                'endpoint': instance.endpoint_url,
                'last_check': datetime.utcnow().isoformat(),
                'restart_count': pod.status.container_statuses[0].restart_count if pod.status.container_statuses else 0
            }

        except ApiException as e:
            logger.error(f"Failed to get health for {instance_id}: {e}")
            return {
                'status': 'error',
                'error': str(e),
                'instance_status': instance.status
            }

    async def _persist_instance(self, instance: ServiceInstance):
        """Persist instance data to disk"""
        instance_file = self.storage_path / f"{instance.instance_id}.json"

        with open(instance_file, 'w') as f:
            json.dump(instance.to_dict(), f, indent=2)

    def _load_persistent_instances(self):
        """Load persistent instances from disk on startup"""
        if not self.storage_path.exists():
            return

        for instance_file in self.storage_path.glob("*.json"):
            try:
                with open(instance_file, 'r') as f:
                    data = json.load(f)

                # Reconstruct the instance object
                instance = ServiceInstance(
                    instance_id=data['instance_id'],
                    tenant_id=data['tenant_id'],
                    service_type=data['service_type'],
                    status=data['status'],
                    endpoint_url=data['endpoint_url'],
                    internal_port=data['internal_port'],
                    external_port=data['external_port'],
                    namespace=data['namespace'],
                    deployment_name=data['deployment_name'],
                    service_name=data['service_name'],
                    ingress_name=data['ingress_name'],
                    sso_token=data.get('sso_token'),
                    created_at=datetime.fromisoformat(data['created_at']),
                    last_heartbeat=datetime.fromisoformat(data['last_heartbeat']),
                    resource_usage=data.get('resource_usage', {})
                )

                self.active_instances[instance.instance_id] = instance
                logger.info(f"Loaded persistent instance: {instance.instance_id}")

            except Exception as e:
                logger.error(f"Failed to load instance from {instance_file}: {e}")

    async def cleanup_orphaned_resources(self):
        """Clean up orphaned Kubernetes resources"""
        if not self.k8s_client:
            return

        logger.info("Starting cleanup of orphaned resources...")

        # This would implement logic to find and clean up:
        # 1. Deployments without corresponding instances
        # 2. Services without deployments
        # 3. Unused PVCs
        # 4. Expired certificates

        # Implementation would query Kubernetes for resources with GT labels
        # and cross-reference with active instances

        logger.info("Cleanup completed")