GT AI OS Community Edition v2.0.33

Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching; see the sketch after this list)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives
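For context on the hostname-validation fix: substring checks such as `"trusted.example" in url` also match `https://trusted.example.evil.net/`, so the safe pattern is to parse the URL and compare the exact hostname against an allowlist. A minimal sketch of the idea (illustrative names, not the code shipped in this release):

```python
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.example.com"}  # hypothetical allowlist


def is_allowed_url(url: str) -> bool:
    """Compare the parsed hostname exactly instead of substring matching."""
    hostname = urlparse(url).hostname
    # A check like '"api.example.com" in url' would also accept
    # "https://api.example.com.evil.net/"; an exact comparison does not.
    return hostname is not None and hostname.lower() in ALLOWED_HOSTS
```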

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: HackWeasel
Date: 2025-12-12 17:04:45 -05:00
Commit: b9dfb86260
746 changed files with 232071 additions and 0 deletions


@@ -0,0 +1,3 @@
"""
Service layer for Resource Cluster
"""


@@ -0,0 +1,342 @@
"""
Admin Model Configuration Service for GT 2.0 Resource Cluster
This service fetches model configurations from the Admin Control Panel
and provides them to the Resource Cluster for LLM routing and capabilities.
"""
import asyncio
import logging
import httpx
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass
import json
from app.core.config import get_settings
logger = logging.getLogger(__name__)
@dataclass
class AdminModelConfig:
"""Model configuration from admin cluster"""
uuid: str # Database UUID - unique identifier for this model config
model_id: str # Business identifier - the model name used in API calls
name: str
provider: str
model_type: str
endpoint: str
api_key_name: Optional[str]
context_window: Optional[int]
max_tokens: Optional[int]
capabilities: Dict[str, Any]
cost_per_1k_input: float
cost_per_1k_output: float
is_active: bool
tenant_restrictions: Dict[str, Any]
required_capabilities: List[str]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for LLM Gateway"""
return {
"uuid": self.uuid,
"model_id": self.model_id,
"name": self.name,
"provider": self.provider,
"model_type": self.model_type,
"endpoint": self.endpoint,
"api_key_name": self.api_key_name,
"context_window": self.context_window,
"max_tokens": self.max_tokens,
"capabilities": self.capabilities,
"cost_per_1k_input": self.cost_per_1k_input,
"cost_per_1k_output": self.cost_per_1k_output,
"is_active": self.is_active,
"tenant_restrictions": self.tenant_restrictions,
"required_capabilities": self.required_capabilities
}
class AdminModelConfigService:
"""Service for fetching model configurations from Admin Control Panel"""
def __init__(self):
self.settings = get_settings()
self._model_cache: Dict[str, AdminModelConfig] = {} # model_id -> config
self._uuid_cache: Dict[str, AdminModelConfig] = {} # uuid -> config (for UUID-based lookups)
self._tenant_model_cache: Dict[str, List[str]] = {} # tenant_id -> list of allowed model_ids
self._last_sync: datetime = datetime.min
self._sync_interval = timedelta(seconds=self.settings.config_sync_interval)
self._sync_lock = asyncio.Lock()
async def get_model_config(self, model_id: str) -> Optional[AdminModelConfig]:
"""Get configuration for a specific model by model_id string"""
await self._ensure_fresh_cache()
return self._model_cache.get(model_id)
async def get_model_by_uuid(self, uuid: str) -> Optional[AdminModelConfig]:
"""Get configuration for a specific model by database UUID"""
await self._ensure_fresh_cache()
return self._uuid_cache.get(uuid)
async def get_all_models(self, active_only: bool = True) -> List[AdminModelConfig]:
"""Get all model configurations"""
await self._ensure_fresh_cache()
models = list(self._model_cache.values())
if active_only:
models = [m for m in models if m.is_active]
return models
async def get_tenant_models(self, tenant_id: str) -> List[AdminModelConfig]:
"""Get models available to a specific tenant"""
await self._ensure_fresh_cache()
# Get tenant's allowed model IDs - try multiple formats
allowed_model_ids = self._get_tenant_model_ids(tenant_id)
# Return model configs for allowed models
models = []
for model_id in allowed_model_ids:
if model_id in self._model_cache and self._model_cache[model_id].is_active:
models.append(self._model_cache[model_id])
return models
async def check_tenant_access(self, tenant_id: str, model_id: str) -> bool:
"""Check if a tenant has access to a specific model"""
await self._ensure_fresh_cache()
# Check if model exists and is active
model_config = self._model_cache.get(model_id)
if not model_config or not model_config.is_active:
return False
# Only use tenant-specific access (no global access)
# This enforces proper tenant model assignments
allowed_models = self._get_tenant_model_ids(tenant_id)
return model_id in allowed_models
def _get_tenant_model_ids(self, tenant_id: str) -> List[str]:
"""Get model IDs for tenant, handling multiple tenant ID formats"""
# Try exact match first (e.g., "test-company")
allowed_models = self._tenant_model_cache.get(tenant_id, [])
if not allowed_models:
# Try converting "test-company" to "test" format
if "-" in tenant_id:
domain_format = tenant_id.split("-")[0]
allowed_models = self._tenant_model_cache.get(domain_format, [])
# Try converting "test" to "test-company" format
elif tenant_id + "-company" in self._tenant_model_cache:
allowed_models = self._tenant_model_cache.get(tenant_id + "-company", [])
# Also try tenant_id as numeric string
for key, models in self._tenant_model_cache.items():
if key.isdigit() and tenant_id in key:
allowed_models.extend(models)
break
logger.debug(f"Tenant {tenant_id} has access to models: {allowed_models}")
return allowed_models
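# Lookup order for, e.g., tenant_id "test-company" (a descriptive note, not new
# behavior): the exact key first; on a miss, the domain prefix "test" (or
# "<tenant_id>-company" for a bare id), plus any all-digit key that contains
# the tenant_id string.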
async def get_groq_api_key(self, tenant_id: Optional[str] = None) -> Optional[str]:
"""
Get Groq API key for a tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string (required for tenant requests)
Returns:
Decrypted Groq API key
Raises:
ValueError: If no API key configured for tenant
"""
if not tenant_id:
raise ValueError("tenant_id is required to fetch Groq API key - no fallback to environment variables")
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No Groq API key configured for tenant '{tenant_id}': {e}")
raise ValueError(f"No Groq API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
except RuntimeError as e:
logger.error(f"Control Panel API error when fetching API key: {e}")
raise ValueError(f"Unable to retrieve API key - Control Panel service unavailable: {e}")
async def _ensure_fresh_cache(self):
"""Ensure model cache is fresh, sync if needed"""
now = datetime.utcnow()
if now - self._last_sync > self._sync_interval:
async with self._sync_lock:
# Double-check after acquiring lock
now = datetime.utcnow()
if now - self._last_sync <= self._sync_interval:
return
await self._sync_from_admin()
async def _sync_from_admin(self):
"""Sync model configurations from admin cluster"""
try:
# Use correct URL for containerized environment
import os
if os.path.exists('/.dockerenv'):
admin_url = "http://control-panel-backend:8000"
else:
admin_url = self.settings.admin_cluster_url.rstrip('/')
async with httpx.AsyncClient(timeout=30.0) as client:
# Fetch all model configurations
models_response = await client.get(
f"{admin_url}/api/v1/models/?active_only=true&include_stats=true"
)
# Fetch tenant model assignments with proper authentication
tenant_models_response = await client.get(
f"{admin_url}/api/v1/tenant-models/tenants/all",
headers={
"Authorization": "Bearer admin-dev-token",
"Content-Type": "application/json"
}
)
if models_response.status_code == 200:
models_data = models_response.json()
if models_data and len(models_data) > 0:
await self._update_model_cache(models_data)
logger.info(f"Successfully synced {len(models_data)} models from admin cluster")
# Update tenant model assignments if available
if tenant_models_response.status_code == 200:
tenant_data = tenant_models_response.json()
if tenant_data and len(tenant_data) > 0:
await self._update_tenant_cache(tenant_data)
logger.info(f"Successfully synced {len(tenant_data)} tenant model assignments")
else:
logger.warning("No tenant model assignments found")
else:
logger.error(f"Failed to fetch tenant assignments: {tenant_models_response.status_code}")
# Log the actual error for debugging
try:
error_response = tenant_models_response.json()
logger.error(f"Tenant assignments error: {error_response}")
except Exception:
logger.error(f"Tenant assignments error text: {tenant_models_response.text}")
self._last_sync = datetime.utcnow()
return
else:
logger.warning("Admin cluster returned empty model list")
else:
logger.warning(f"Failed to fetch models from admin cluster: {models_response.status_code}")
logger.info("No models configured in admin backend")
self._last_sync = datetime.utcnow()
logger.info(f"Loaded {len(self._model_cache)} models successfully")
except Exception as e:
logger.error(f"Failed to sync from admin cluster: {e}")
# Log final state - no fallback models
if not self._model_cache:
logger.warning("No models available - admin backend has no models configured")
async def _update_model_cache(self, models_data: List[Dict[str, Any]]):
"""Update model configuration cache"""
new_cache = {}
new_uuid_cache = {}
for model_data in models_data:
try:
specs = model_data.get("specifications", {})
cost = model_data.get("cost", {})
status = model_data.get("status", {})
# Get UUID from 'id' field in API response (Control Panel returns UUID as 'id')
model_uuid = model_data.get("id", "")
model_config = AdminModelConfig(
uuid=model_uuid,
model_id=model_data["model_id"],
name=model_data.get("name", model_data["model_id"]),
provider=model_data["provider"],
model_type=model_data["model_type"],
endpoint=model_data.get("endpoint", ""),
api_key_name=model_data.get("api_key_name"),
context_window=specs.get("context_window"),
max_tokens=specs.get("max_tokens"),
capabilities=model_data.get("capabilities", {}),
cost_per_1k_input=cost.get("per_1k_input", 0.0),
cost_per_1k_output=cost.get("per_1k_output", 0.0),
is_active=status.get("is_active", False),
tenant_restrictions=model_data.get("tenant_restrictions", {"global_access": True}),
required_capabilities=model_data.get("required_capabilities", [])
)
new_cache[model_config.model_id] = model_config
# Also index by UUID for UUID-based lookups
if model_uuid:
new_uuid_cache[model_uuid] = model_config
except Exception as e:
logger.error(f"Failed to parse model config {model_data.get('model_id', 'unknown')}: {e}")
self._model_cache = new_cache
self._uuid_cache = new_uuid_cache
async def _update_tenant_cache(self, tenant_data: List[Dict[str, Any]]):
"""Update tenant model access cache from tenant-models endpoint"""
new_tenant_cache = {}
for assignment in tenant_data:
try:
# The tenant-models endpoint returns different format than the old endpoint
tenant_domain = assignment.get("tenant_domain", "")
model_id = assignment["model_id"]
is_enabled = assignment.get("is_enabled", True)
if is_enabled and tenant_domain:
if tenant_domain not in new_tenant_cache:
new_tenant_cache[tenant_domain] = []
new_tenant_cache[tenant_domain].append(model_id)
# Also add by tenant_id for backward compatibility
tenant_id = str(assignment.get("tenant_id", ""))
if tenant_id and tenant_id not in new_tenant_cache:
new_tenant_cache[tenant_id] = []
if tenant_id:
new_tenant_cache[tenant_id].append(model_id)
except Exception as e:
logger.error(f"Failed to parse tenant assignment: {e}")
self._tenant_model_cache = new_tenant_cache
logger.debug(f"Updated tenant cache: {self._tenant_model_cache}")
async def force_sync(self):
"""Force immediate sync from admin cluster"""
self._last_sync = datetime.min
await self._ensure_fresh_cache()
# Global instance
_admin_model_service = None
def get_admin_model_service() -> AdminModelConfigService:
"""Get singleton admin model service"""
global _admin_model_service
if _admin_model_service is None:
_admin_model_service = AdminModelConfigService()
return _admin_model_service
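# Example usage (a sketch for an async caller; the model and tenant IDs are
# illustrative, not values shipped with this service):
#
#     service = get_admin_model_service()
#     config = await service.get_model_config("llama-3.1-8b-instant")
#     if config and await service.check_tenant_access("acme", config.model_id):
#         payload = config.to_dict()  # hand off to the LLM Gateway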


@@ -0,0 +1,931 @@
"""
Agent Orchestration System for GT 2.0 Resource Cluster
Provides multi-agent workflow execution with:
- Sequential, parallel, and conditional agent workflows
- Inter-agent communication and memory management
- Capability-based access control
- Agent lifecycle management
- Performance monitoring and metrics
GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Agent sessions isolated per tenant
- Zero Downtime: Stateless design, resumable workflows
- Self-Contained Security: Capability-based agent permissions
- No Complexity Addition: Simple orchestration patterns
"""
import asyncio
import logging
import json
import time
import uuid
from typing import Dict, Any, List, Optional, Union, Callable, Coroutine
from datetime import datetime, timedelta
from enum import Enum
from dataclasses import dataclass, asdict
import traceback
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class AgentStatus(str, Enum):
"""Agent execution status"""
IDLE = "idle"
RUNNING = "running"
WAITING = "waiting"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class WorkflowType(str, Enum):
"""Types of agent workflows"""
SEQUENTIAL = "sequential"
PARALLEL = "parallel"
CONDITIONAL = "conditional"
PIPELINE = "pipeline"
MAP_REDUCE = "map_reduce"
class MessageType(str, Enum):
"""Inter-agent message types"""
DATA = "data"
CONTROL = "control"
ERROR = "error"
HEARTBEAT = "heartbeat"
@dataclass
class AgentDefinition:
"""Definition of an agent"""
agent_id: str
agent_type: str
name: str
description: str
capabilities_required: List[str]
memory_limit_mb: int = 256
timeout_seconds: int = 300
retry_count: int = 3
environment: Optional[Dict[str, Any]] = None
@dataclass
class AgentMessage:
"""Message between agents"""
message_id: str
from_agent: str
to_agent: str
message_type: MessageType
content: Dict[str, Any]
timestamp: str
expires_at: Optional[str] = None
@dataclass
class AgentState:
"""Current state of an agent"""
agent_id: str
status: AgentStatus
current_task: Optional[str]
memory_usage_mb: int
cpu_usage_percent: float
started_at: str
last_activity: str
error_message: Optional[str] = None
output_data: Optional[Dict[str, Any]] = None
@dataclass
class WorkflowExecution:
"""Workflow execution instance"""
workflow_id: str
workflow_type: WorkflowType
tenant_id: str
created_by: str
agents: List[AgentDefinition]
workflow_config: Dict[str, Any]
status: AgentStatus
started_at: str
completed_at: Optional[str] = None
results: Optional[Dict[str, Any]] = None
error_message: Optional[str] = None
class AgentMemoryManager:
"""Manages agent memory and state"""
def __init__(self):
# In-memory storage (PostgreSQL used for persistent storage)
self._agent_memory: Dict[str, Dict[str, Any]] = {}
self._shared_memory: Dict[str, Dict[str, Any]] = {}
self._message_queues: Dict[str, List[AgentMessage]] = {}
async def store_agent_memory(
self,
agent_id: str,
key: str,
value: Any,
ttl_seconds: Optional[int] = None
) -> None:
"""Store data in agent-specific memory"""
if agent_id not in self._agent_memory:
self._agent_memory[agent_id] = {}
self._agent_memory[agent_id][key] = {
"value": value,
"created_at": datetime.utcnow().isoformat(),
"expires_at": (
datetime.utcnow() + timedelta(seconds=ttl_seconds)
).isoformat() if ttl_seconds else None
}
logger.debug(f"Stored memory for agent {agent_id}: {key}")
async def get_agent_memory(
self,
agent_id: str,
key: str
) -> Optional[Any]:
"""Retrieve data from agent-specific memory"""
if agent_id not in self._agent_memory:
return None
memory_item = self._agent_memory[agent_id].get(key)
if not memory_item:
return None
# Check expiration
if memory_item.get("expires_at"):
expires_at = datetime.fromisoformat(memory_item["expires_at"])
if datetime.utcnow() > expires_at:
del self._agent_memory[agent_id][key]
return None
return memory_item["value"]
async def store_shared_memory(
self,
tenant_id: str,
key: str,
value: Any,
ttl_seconds: Optional[int] = None
) -> None:
"""Store data in tenant-shared memory"""
if tenant_id not in self._shared_memory:
self._shared_memory[tenant_id] = {}
self._shared_memory[tenant_id][key] = {
"value": value,
"created_at": datetime.utcnow().isoformat(),
"expires_at": (
datetime.utcnow() + timedelta(seconds=ttl_seconds)
).isoformat() if ttl_seconds else None
}
logger.debug(f"Stored shared memory for tenant {tenant_id}: {key}")
async def get_shared_memory(
self,
tenant_id: str,
key: str
) -> Optional[Any]:
"""Retrieve data from tenant-shared memory"""
if tenant_id not in self._shared_memory:
return None
memory_item = self._shared_memory[tenant_id].get(key)
if not memory_item:
return None
# Check expiration
if memory_item.get("expires_at"):
expires_at = datetime.fromisoformat(memory_item["expires_at"])
if datetime.utcnow() > expires_at:
del self._shared_memory[tenant_id][key]
return None
return memory_item["value"]
async def send_message(
self,
message: AgentMessage
) -> None:
"""Send message to agent queue"""
if message.to_agent not in self._message_queues:
self._message_queues[message.to_agent] = []
self._message_queues[message.to_agent].append(message)
logger.debug(f"Message sent from {message.from_agent} to {message.to_agent}")
async def receive_messages(
self,
agent_id: str,
message_type: Optional[MessageType] = None
) -> List[AgentMessage]:
"""Receive messages for agent"""
if agent_id not in self._message_queues:
return []
messages = self._message_queues[agent_id]
# Filter expired messages
now = datetime.utcnow()
messages = [
msg for msg in messages
if not msg.expires_at or datetime.fromisoformat(msg.expires_at) > now
]
# Filter by message type if specified
if message_type:
messages = [msg for msg in messages if msg.message_type == message_type]
# Clear processed messages
if message_type:
self._message_queues[agent_id] = [
msg for msg in self._message_queues[agent_id]
if msg.message_type != message_type or
(msg.expires_at and datetime.fromisoformat(msg.expires_at) <= now)
]
else:
self._message_queues[agent_id] = []
return messages
async def cleanup_agent_memory(self, agent_id: str) -> None:
"""Clean up memory for completed agent"""
if agent_id in self._agent_memory:
del self._agent_memory[agent_id]
if agent_id in self._message_queues:
del self._message_queues[agent_id]
logger.debug(f"Cleaned up memory for agent {agent_id}")
class AgentOrchestrator:
"""
Main agent orchestration system for GT 2.0.
Manages agent lifecycle, workflows, communication, and resource allocation.
All operations are tenant-isolated and capability-protected.
"""
def __init__(self):
self.memory_manager = AgentMemoryManager()
self.active_workflows: Dict[str, WorkflowExecution] = {}
self.agent_states: Dict[str, AgentState] = {}
# Built-in agent types
self.agent_registry: Dict[str, Dict[str, Any]] = {
"data_processor": {
"description": "Processes and transforms data",
"capabilities": ["data.read", "data.transform"],
"memory_limit_mb": 512,
"timeout_seconds": 300
},
"llm_agent": {
"description": "Interacts with LLM services",
"capabilities": ["llm.inference", "llm.chat"],
"memory_limit_mb": 256,
"timeout_seconds": 600
},
"embedding_agent": {
"description": "Generates text embeddings",
"capabilities": ["embeddings.generate"],
"memory_limit_mb": 256,
"timeout_seconds": 180
},
"rag_agent": {
"description": "Performs retrieval-augmented generation",
"capabilities": ["rag.search", "rag.generate"],
"memory_limit_mb": 512,
"timeout_seconds": 450
},
"integration_agent": {
"description": "Connects to external services",
"capabilities": ["integration.call", "integration.webhook"],
"memory_limit_mb": 256,
"timeout_seconds": 300
}
}
logger.info("Agent orchestrator initialized")
async def create_workflow(
self,
workflow_type: WorkflowType,
agents: List[AgentDefinition],
workflow_config: Dict[str, Any],
capability_token: str,
workflow_name: Optional[str] = None
) -> str:
"""
Create a new agent workflow.
Args:
workflow_type: Type of workflow to create
agents: List of agents to include in workflow
workflow_config: Configuration for the workflow
capability_token: JWT token with workflow permissions
workflow_name: Optional name for the workflow
Returns:
Workflow ID
"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
user_id = capability.get("sub")
# Check workflow permissions
await self._verify_workflow_permissions(capability, workflow_type, agents)
# Generate workflow ID
workflow_id = str(uuid.uuid4())
# Create workflow execution
workflow = WorkflowExecution(
workflow_id=workflow_id,
workflow_type=workflow_type,
tenant_id=tenant_id,
created_by=user_id,
agents=agents,
workflow_config=workflow_config,
status=AgentStatus.IDLE,
started_at=datetime.utcnow().isoformat()
)
# Store workflow
self.active_workflows[workflow_id] = workflow
logger.info(
f"Created {workflow_type} workflow {workflow_id} "
f"with {len(agents)} agents for tenant {tenant_id}"
)
return workflow_id
async def execute_workflow(
self,
workflow_id: str,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""
Execute an agent workflow.
Args:
workflow_id: ID of workflow to execute
input_data: Input data for the workflow
capability_token: JWT token with execution permissions
Returns:
Workflow execution results
"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
# Get workflow
workflow = self.active_workflows.get(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
# Check tenant isolation
if workflow.tenant_id != tenant_id:
raise CapabilityError("Insufficient permissions for workflow")
# Check workflow permissions
await self._verify_execution_permissions(capability, workflow)
try:
# Update workflow status
workflow.status = AgentStatus.RUNNING
# Execute based on workflow type
if workflow.workflow_type == WorkflowType.SEQUENTIAL:
results = await self._execute_sequential_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.PARALLEL:
results = await self._execute_parallel_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.CONDITIONAL:
results = await self._execute_conditional_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.PIPELINE:
results = await self._execute_pipeline_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.MAP_REDUCE:
results = await self._execute_map_reduce_workflow(
workflow, input_data, capability_token
)
else:
raise ValueError(f"Unsupported workflow type: {workflow.workflow_type}")
# Update workflow completion
workflow.status = AgentStatus.COMPLETED
workflow.completed_at = datetime.utcnow().isoformat()
workflow.results = results
logger.info(f"Completed workflow {workflow_id} successfully")
return results
except Exception as e:
# Update workflow error status
workflow.status = AgentStatus.FAILED
workflow.completed_at = datetime.utcnow().isoformat()
workflow.error_message = str(e)
logger.error(f"Workflow {workflow_id} failed: {e}")
raise
async def get_workflow_status(
self,
workflow_id: str,
capability_token: str
) -> Dict[str, Any]:
"""Get status of a workflow"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
# Get workflow
workflow = self.active_workflows.get(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
# Check tenant isolation
if workflow.tenant_id != tenant_id:
raise CapabilityError("Insufficient permissions for workflow")
# Get agent states for this workflow
agent_states = {
agent.agent_id: asdict(self.agent_states.get(agent.agent_id))
for agent in workflow.agents
if agent.agent_id in self.agent_states
}
return {
"workflow": asdict(workflow),
"agent_states": agent_states
}
async def cancel_workflow(
self,
workflow_id: str,
capability_token: str
) -> None:
"""Cancel a running workflow"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
# Get workflow
workflow = self.active_workflows.get(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
# Check tenant isolation
if workflow.tenant_id != tenant_id:
raise CapabilityError("Insufficient permissions for workflow")
# Cancel workflow
workflow.status = AgentStatus.CANCELLED
workflow.completed_at = datetime.utcnow().isoformat()
# Cancel all agents in workflow
for agent in workflow.agents:
if agent.agent_id in self.agent_states:
self.agent_states[agent.agent_id].status = AgentStatus.CANCELLED
logger.info(f"Cancelled workflow {workflow_id}")
async def _execute_sequential_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents sequentially"""
results = {}
current_data = input_data
for agent in workflow.agents:
agent_result = await self._execute_agent(
agent, current_data, capability_token
)
results[agent.agent_id] = agent_result
# Pass output to next agent
if "output" in agent_result:
current_data = agent_result["output"]
return {
"workflow_type": "sequential",
"final_output": current_data,
"agent_results": results
}
async def _execute_parallel_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents in parallel"""
# Create tasks for all agents
tasks = []
for agent in workflow.agents:
task = asyncio.create_task(
self._execute_agent(agent, input_data, capability_token)
)
tasks.append((agent.agent_id, task))
# Wait for all tasks to complete
results = {}
for agent_id, task in tasks:
try:
results[agent_id] = await task
except Exception as e:
results[agent_id] = {"error": str(e)}
return {
"workflow_type": "parallel",
"agent_results": results
}
async def _execute_conditional_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents based on conditions"""
results = {}
condition_config = workflow.workflow_config.get("conditions", {})
for agent in workflow.agents:
# Check if agent should execute based on conditions
should_execute = await self._evaluate_condition(
agent.agent_id, condition_config, input_data, results
)
if should_execute:
agent_result = await self._execute_agent(
agent, input_data, capability_token
)
results[agent.agent_id] = agent_result
else:
results[agent.agent_id] = {"status": "skipped"}
return {
"workflow_type": "conditional",
"agent_results": results
}
async def _execute_pipeline_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents in pipeline with data transformation"""
results = {}
current_data = input_data
for i, agent in enumerate(workflow.agents):
# Add pipeline metadata
pipeline_data = {
**current_data,
"_pipeline_stage": i,
"_pipeline_total": len(workflow.agents)
}
agent_result = await self._execute_agent(
agent, pipeline_data, capability_token
)
results[agent.agent_id] = agent_result
# Transform data for next stage
if "transformed_output" in agent_result:
current_data = agent_result["transformed_output"]
elif "output" in agent_result:
current_data = agent_result["output"]
return {
"workflow_type": "pipeline",
"final_output": current_data,
"agent_results": results
}
async def _execute_map_reduce_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute map-reduce workflow"""
# Separate map and reduce agents
map_agents = [a for a in workflow.agents if a.agent_type.endswith("_mapper")]
reduce_agents = [a for a in workflow.agents if a.agent_type.endswith("_reducer")]
# Execute map phase
map_tasks = []
input_chunks = input_data.get("chunks", [input_data])
for i, chunk in enumerate(input_chunks):
for agent in map_agents:
task = asyncio.create_task(
self._execute_agent(agent, chunk, capability_token)
)
map_tasks.append((f"{agent.agent_id}_chunk_{i}", task))
# Collect map results
map_results = {}
for task_id, task in map_tasks:
try:
map_results[task_id] = await task
except Exception as e:
map_results[task_id] = {"error": str(e)}
# Execute reduce phase
reduce_results = {}
reduce_input = {"map_results": map_results}
for agent in reduce_agents:
agent_result = await self._execute_agent(
agent, reduce_input, capability_token
)
reduce_results[agent.agent_id] = agent_result
return {
"workflow_type": "map_reduce",
"map_results": map_results,
"reduce_results": reduce_results
}
async def _execute_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute a single agent"""
start_time = time.time()
# Create agent state
agent_state = AgentState(
agent_id=agent.agent_id,
status=AgentStatus.RUNNING,
current_task=f"Executing {agent.agent_type}",
memory_usage_mb=0,
cpu_usage_percent=0.0,
started_at=datetime.utcnow().isoformat(),
last_activity=datetime.utcnow().isoformat()
)
self.agent_states[agent.agent_id] = agent_state
try:
# Simulate agent execution based on type
if agent.agent_type == "data_processor":
result = await self._execute_data_processor(agent, input_data)
elif agent.agent_type == "llm_agent":
result = await self._execute_llm_agent(agent, input_data, capability_token)
elif agent.agent_type == "embedding_agent":
result = await self._execute_embedding_agent(agent, input_data, capability_token)
elif agent.agent_type == "rag_agent":
result = await self._execute_rag_agent(agent, input_data, capability_token)
elif agent.agent_type == "integration_agent":
result = await self._execute_integration_agent(agent, input_data, capability_token)
else:
result = await self._execute_custom_agent(agent, input_data)
# Update agent state
agent_state.status = AgentStatus.COMPLETED
agent_state.output_data = result
agent_state.last_activity = datetime.utcnow().isoformat()
processing_time = time.time() - start_time
logger.info(
f"Agent {agent.agent_id} completed in {processing_time:.2f}s"
)
return {
"status": "completed",
"processing_time": processing_time,
"output": result
}
except Exception as e:
# Update agent error state
agent_state.status = AgentStatus.FAILED
agent_state.error_message = str(e)
agent_state.last_activity = datetime.utcnow().isoformat()
logger.error(f"Agent {agent.agent_id} failed: {e}")
return {
"status": "failed",
"error": str(e),
"processing_time": time.time() - start_time
}
# Agent execution implementations would go here...
# For now, these are placeholder implementations
async def _execute_data_processor(
self,
agent: AgentDefinition,
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute data processing agent"""
await asyncio.sleep(0.1) # Simulate processing
return {
"processed_data": input_data,
"processing_info": "Data processed successfully"
}
async def _execute_llm_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute LLM agent"""
await asyncio.sleep(0.2) # Simulate LLM call
return {
"llm_response": f"LLM processed: {input_data.get('prompt', 'No prompt provided')}",
"model_used": "groq/llama-3-8b"
}
async def _execute_embedding_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute embedding agent"""
await asyncio.sleep(0.1) # Simulate embedding generation
texts = input_data.get("texts", [""])
return {
"embeddings": [[0.1] * 1024 for _ in texts], # Mock embeddings
"model_used": "BAAI/bge-m3"
}
async def _execute_rag_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute RAG agent"""
await asyncio.sleep(0.3) # Simulate RAG processing
return {
"rag_response": "RAG generated response",
"retrieved_docs": ["doc1", "doc2"],
"confidence_score": 0.85
}
async def _execute_integration_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute integration agent"""
await asyncio.sleep(0.1) # Simulate external API call
return {
"integration_result": "External API called successfully",
"response_data": input_data
}
async def _execute_custom_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute custom agent type"""
await asyncio.sleep(0.1) # Simulate custom processing
return {
"custom_result": f"Custom agent {agent.agent_type} executed",
"input_data": input_data
}
async def _verify_workflow_permissions(
self,
capability: Dict[str, Any],
workflow_type: WorkflowType,
agents: List[AgentDefinition]
) -> None:
"""Verify workflow creation permissions"""
capabilities = capability.get("capabilities", [])
# Check for workflow creation permission
workflow_caps = [
cap for cap in capabilities
if cap.get("resource") == "workflows"
]
if not workflow_caps:
raise CapabilityError("No workflow permissions in capability token")
# Check specific workflow type permission
workflow_cap = workflow_caps[0]
actions = workflow_cap.get("actions", [])
if "create" not in actions:
raise CapabilityError("No workflow creation permission")
# Check agent-specific permissions
for agent in agents:
for required_cap in agent.capabilities_required:
if not any(
cap.get("resource") == required_cap.split(".")[0]
for cap in capabilities
):
raise CapabilityError(
f"Missing capability for agent {agent.agent_id}: {required_cap}"
)
async def _verify_execution_permissions(
self,
capability: Dict[str, Any],
workflow: WorkflowExecution
) -> None:
"""Verify workflow execution permissions"""
capabilities = capability.get("capabilities", [])
# Check for workflow execution permission
workflow_caps = [
cap for cap in capabilities
if cap.get("resource") == "workflows"
]
if not workflow_caps:
raise CapabilityError("No workflow permissions in capability token")
workflow_cap = workflow_caps[0]
actions = workflow_cap.get("actions", [])
if "execute" not in actions:
raise CapabilityError("No workflow execution permission")
async def _evaluate_condition(
self,
agent_id: str,
condition_config: Dict[str, Any],
input_data: Dict[str, Any],
results: Dict[str, Any]
) -> bool:
"""Evaluate condition for conditional workflow"""
agent_condition = condition_config.get(agent_id, {})
if not agent_condition:
return True # No condition means always execute
condition_type = agent_condition.get("type", "always")
if condition_type == "always":
return True
elif condition_type == "never":
return False
elif condition_type == "input_contains":
key = agent_condition.get("key")
value = agent_condition.get("value")
return input_data.get(key) == value
elif condition_type == "previous_success":
previous_agent = agent_condition.get("previous_agent")
return (
previous_agent in results and
results[previous_agent].get("status") == "completed"
)
elif condition_type == "previous_failure":
previous_agent = agent_condition.get("previous_agent")
return (
previous_agent in results and
results[previous_agent].get("status") == "failed"
)
return True # Default to execute if condition not recognized
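# Example condition_config for a conditional workflow (values are illustrative):
#
#     {"agent-2": {"type": "previous_success", "previous_agent": "agent-1"}}
#
# agent-2 then runs only when agent-1's result reports status "completed".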
# Global orchestrator instance
_agent_orchestrator = None
def get_agent_orchestrator() -> AgentOrchestrator:
"""Get the global agent orchestrator instance"""
global _agent_orchestrator
if _agent_orchestrator is None:
_agent_orchestrator = AgentOrchestrator()
return _agent_orchestrator
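# Example: create and execute a two-agent sequential workflow (a sketch; the
# capability token must grant "workflows" create/execute actions plus each
# agent's required resources, and all IDs here are illustrative):
#
#     orchestrator = get_agent_orchestrator()
#     agents = [
#         AgentDefinition("clean", "data_processor", "Cleaner", "Cleans input",
#                         capabilities_required=["data.read"]),
#         AgentDefinition("draft", "llm_agent", "Writer", "Drafts a summary",
#                         capabilities_required=["llm.inference"]),
#     ]
#     wf_id = await orchestrator.create_workflow(
#         WorkflowType.SEQUENTIAL, agents, {}, capability_token)
#     result = await orchestrator.execute_workflow(
#         wf_id, {"prompt": "summarize"}, capability_token)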


@@ -0,0 +1,280 @@
"""
GT 2.0 Configuration Sync Service
Syncs model configurations from admin cluster to resource cluster.
Enables admin control panel to control AI model routing.
"""
import asyncio
import httpx
import json
import time
import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from pathlib import Path
from app.core.config import get_settings
from app.services.model_service import default_model_service
from app.providers.external_provider import get_external_provider
logger = logging.getLogger(__name__)
settings = get_settings()
class ConfigSyncService:
"""Syncs model configurations from admin cluster"""
def __init__(self):
# Force Docker service name for admin cluster communication in containerized environment
if hasattr(settings, 'admin_cluster_url') and settings.admin_cluster_url:
# Check if we're running in Docker (container environment)
import os
if os.path.exists('/.dockerenv'):
self.admin_cluster_url = "http://control-panel-backend:8000"
else:
self.admin_cluster_url = settings.admin_cluster_url
else:
self.admin_cluster_url = "http://control-panel-backend:8000"
self.sync_interval = settings.config_sync_interval or 60 # seconds
# Use the default singleton model service instance
self.model_service = default_model_service
self.last_sync = 0
self.sync_running = False
async def start_sync_loop(self):
"""Start the configuration sync loop"""
logger.info("Starting configuration sync loop")
while True:
try:
if not self.sync_running:
await self.sync_configurations()
await asyncio.sleep(self.sync_interval)
except Exception as e:
logger.error(f"Config sync loop error: {e}")
await asyncio.sleep(30) # Wait 30s on error
async def sync_configurations(self):
"""Sync model configurations from admin cluster"""
if self.sync_running:
return
self.sync_running = True
try:
logger.debug("Syncing model configurations from admin cluster")
# Fetch all model configurations from admin cluster
configs = await self._fetch_admin_configs()
if configs:
# Update local model registry
await self._update_local_registry(configs)
# Update provider configurations
await self._update_provider_configs(configs)
self.last_sync = time.time()
logger.info(f"Successfully synced {len(configs)} model configurations")
else:
logger.warning("No configurations received from admin cluster")
except Exception as e:
logger.error(f"Configuration sync failed: {e}")
finally:
self.sync_running = False
async def _fetch_admin_configs(self) -> Optional[List[Dict[str, Any]]]:
"""Fetch model configurations from admin cluster"""
try:
logger.debug(f"Attempting to fetch configs from: {self.admin_cluster_url}/api/v1/models/configs/all")
async with httpx.AsyncClient(timeout=30.0) as client:
# Add authentication for admin cluster access
headers = {
"Authorization": "Bearer admin-cluster-sync-token",
"Content-Type": "application/json"
}
response = await client.get(
f"{self.admin_cluster_url}/api/v1/models/configs/all",
headers=headers
)
logger.debug(f"Admin cluster response: {response.status_code}")
if response.status_code == 200:
data = response.json()
configs = data.get("configs", [])
logger.debug(f"Successfully fetched {len(configs)} model configurations")
return configs
else:
logger.warning(f"Admin cluster returned {response.status_code}: {response.text}")
return None
except httpx.RequestError as e:
logger.error(f"Failed to connect to admin cluster: {e}")
return None
except Exception as e:
logger.error(f"Error fetching admin configs: {e}")
return None
async def _update_local_registry(self, configs: List[Dict[str, Any]]):
"""Update local model registry with admin configurations"""
try:
for config in configs:
await self.model_service.register_or_update_model(
model_id=config["model_id"],
name=config["name"],
version=config["version"],
provider=config["provider"],
model_type=config["model_type"],
endpoint=config["endpoint"],
api_key_name=config.get("api_key_name"),
specifications=config.get("specifications", {}),
capabilities=config.get("capabilities", {}),
cost=config.get("cost", {}),
description=config.get("description"),
config=config.get("config", {}),
status=config.get("status", {}),
sync_timestamp=config.get("sync_timestamp")
)
# Log BGE-M3 configuration details for debugging persistence
if "bge-m3" in config["model_id"].lower():
model_config = config.get("config", {})
logger.info(
f"Synced BGE-M3 configuration from database: "
f"endpoint={config['endpoint']}, "
f"is_local_mode={model_config.get('is_local_mode', True)}, "
f"external_endpoint={model_config.get('external_endpoint', 'None')}"
)
except Exception as e:
logger.error(f"Failed to update local registry: {e}")
raise
async def _update_provider_configs(self, configs: List[Dict[str, Any]]):
"""Update provider configurations based on admin settings"""
try:
# Group configs by provider
provider_configs = {}
for config in configs:
provider = config["provider"]
if provider not in provider_configs:
provider_configs[provider] = []
provider_configs[provider].append(config)
# Update each provider
for provider, provider_models in provider_configs.items():
await self._update_provider(provider, provider_models)
except Exception as e:
logger.error(f"Failed to update provider configs: {e}")
raise
async def _update_provider(self, provider: str, models: List[Dict[str, Any]]):
"""Update specific provider configuration"""
try:
# Generic provider update - all providers are now supported automatically
provider_models = [m for m in models if m["provider"] == provider]
logger.debug(f"Updated {provider} provider with {len(provider_models)} models")
# Keep legacy support for specific providers if needed
if provider == "groq":
await self._update_groq_provider(models)
elif provider == "external":
await self._update_external_provider(models)
elif provider == "openai":
await self._update_openai_provider(models)
elif provider == "anthropic":
await self._update_anthropic_provider(models)
elif provider == "vllm":
await self._update_vllm_provider(models)
except Exception as e:
logger.error(f"Failed to update {provider} provider: {e}")
raise
async def _update_groq_provider(self, models: List[Dict[str, Any]]):
"""Update Groq provider configuration"""
# Update available Groq models
groq_models = [m for m in models if m["provider"] == "groq"]
logger.debug(f"Updated Groq provider with {len(groq_models)} models")
async def _update_external_provider(self, models: List[Dict[str, Any]]):
"""Update external provider configuration (BGE-M3, etc.)"""
external_models = [m for m in models if m["provider"] == "external"]
if external_models:
external_provider = await get_external_provider()
for model in external_models:
if "bge-m3" in model["model_id"].lower():
# Update BGE-M3 endpoint configuration
external_provider.update_model_endpoint(
model["model_id"],
model["endpoint"]
)
logger.debug(f"Updated BGE-M3 endpoint: {model['endpoint']}")
# Also refresh the embedding backend instance
try:
from app.core.backends import get_embedding_backend
embedding_backend = get_embedding_backend()
embedding_backend.refresh_endpoint_from_registry()
logger.info(f"Refreshed embedding backend with new BGE-M3 endpoint from database")
except Exception as e:
logger.error(f"Failed to refresh embedding backend: {e}")
logger.debug(f"Updated external provider with {len(external_models)} models")
async def _update_openai_provider(self, models: List[Dict[str, Any]]):
"""Update OpenAI provider configuration"""
openai_models = [m for m in models if m["provider"] == "openai"]
logger.debug(f"Updated OpenAI provider with {len(openai_models)} models")
async def _update_anthropic_provider(self, models: List[Dict[str, Any]]):
"""Update Anthropic provider configuration"""
anthropic_models = [m for m in models if m["provider"] == "anthropic"]
logger.debug(f"Updated Anthropic provider with {len(anthropic_models)} models")
async def _update_vllm_provider(self, models: List[Dict[str, Any]]):
"""Update vLLM provider configuration (BGE-M3 embeddings, etc.)"""
vllm_models = [m for m in models if m["provider"] == "vllm"]
for model in vllm_models:
if model["model_type"] == "embedding":
# This is an embedding model like BGE-M3
logger.debug(f"Updated vLLM embedding model: {model['model_id']} -> {model['endpoint']}")
else:
logger.debug(f"Updated vLLM model: {model['model_id']} -> {model['endpoint']}")
logger.debug(f"Updated vLLM provider with {len(vllm_models)} models")
async def force_sync(self):
"""Force immediate configuration sync"""
logger.info("Force syncing configurations")
await self.sync_configurations()
def get_sync_status(self) -> Dict[str, Any]:
"""Get current sync status"""
return {
"last_sync": datetime.fromtimestamp(self.last_sync).isoformat() if self.last_sync else None,
"sync_running": self.sync_running,
"admin_cluster_url": self.admin_cluster_url,
"sync_interval": self.sync_interval,
"next_sync": datetime.fromtimestamp(self.last_sync + self.sync_interval).isoformat() if self.last_sync else None
}
# Global config sync service instance
_config_sync_service = None
def get_config_sync_service() -> ConfigSyncService:
"""Get configuration sync service instance"""
global _config_sync_service
if _config_sync_service is None:
_config_sync_service = ConfigSyncService()
return _config_sync_service
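# Example wiring (a sketch assuming a FastAPI app with a startup hook; not part
# of this module):
#
#     @app.on_event("startup")
#     async def start_config_sync():
#         sync_service = get_config_sync_service()
#         asyncio.create_task(sync_service.start_sync_loop())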


@@ -0,0 +1,101 @@
"""
Consul Service Registry
Handles service registration and discovery for the Resource Cluster.
"""
import logging
from typing import Dict, Any, List, Optional
import consul
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class ConsulRegistry:
"""Service registry using Consul"""
def __init__(self):
self.consul = None
try:
self.consul = consul.Consul(
host=settings.consul_host,
port=settings.consul_port,
token=settings.consul_token
)
except Exception as e:
logger.warning(f"Consul not available: {e}")
async def register_service(
self,
name: str,
service_id: str,
address: str,
port: int,
tags: Optional[List[str]] = None,
check_interval: str = "10s"
) -> bool:
"""Register service with Consul"""
if not self.consul:
logger.warning("Consul not available, skipping registration")
return False
try:
self.consul.agent.service.register(
name=name,
service_id=service_id,
address=address,
port=port,
tags=tags or [],
check=consul.Check.http(
f"http://{address}:{port}/health",
interval=check_interval
)
)
logger.info(f"Registered service {service_id} with Consul")
return True
except Exception as e:
logger.error(f"Failed to register with Consul: {e}")
return False
async def deregister_service(self, service_id: str) -> bool:
"""Deregister service from Consul"""
if not self.consul:
return False
try:
self.consul.agent.service.deregister(service_id)
logger.info(f"Deregistered service {service_id} from Consul")
return True
except Exception as e:
logger.error(f"Failed to deregister from Consul: {e}")
return False
async def discover_service(self, service_name: str) -> List[Dict[str, Any]]:
"""Discover service instances"""
if not self.consul:
return []
try:
_, services = self.consul.health.service(service_name, passing=True)
instances = []
for service in services:
instances.append({
"id": service["Service"]["ID"],
"address": service["Service"]["Address"],
"port": service["Service"]["Port"],
"tags": service["Service"]["Tags"]
})
return instances
except Exception as e:
logger.error(f"Failed to discover service: {e}")
return []
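# Example usage (a sketch; the service name, address, and port are
# illustrative):
#
#     registry = ConsulRegistry()
#     await registry.register_service(
#         name="resource-cluster", service_id="resource-cluster-1",
#         address="10.0.0.5", port=8004, tags=["api"])
#     instances = await registry.discover_service("resource-cluster")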


@@ -0,0 +1,536 @@
"""
Enhanced Document Processing Pipeline with Dual-Engine Support
Implements the DocumentProcessingPipeline from CLAUDE.md with both native
and Unstructured.io engine support, capability-based selection, and
stateless processing.
"""
import logging
import asyncio
import gc
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
import hashlib
import json
from app.core.backends.document_processor import (
DocumentProcessorBackend,
ChunkingStrategy
)
logger = logging.getLogger(__name__)
@dataclass
class ProcessingResult:
"""Result of document processing"""
chunks: List[Dict[str, str]]
embeddings: Optional[List[List[float]]] # Optional embeddings
metadata: Dict[str, Any]
engine_used: str
processing_time_ms: float
token_count: int
@dataclass
class ProcessingOptions:
"""Options for document processing"""
engine_preference: str = "auto" # "native", "unstructured", "auto"
chunking_strategy: str = "hybrid" # "fixed", "semantic", "hierarchical", "hybrid"
chunk_size: int = 512 # tokens for BGE-M3
chunk_overlap: int = 128 # overlap tokens
generate_embeddings: bool = True
extract_metadata: bool = True
language_detection: bool = True
ocr_enabled: bool = False # For scanned PDFs
class UnstructuredAPIEngine:
"""
Mock Unstructured.io API engine for advanced document parsing.
In production, this would call the actual Unstructured API.
"""
def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None):
self.api_key = api_key
self.api_url = api_url or "https://api.unstructured.io"
self.supported_features = [
"table_extraction",
"image_extraction",
"ocr",
"language_detection",
"metadata_extraction",
"hierarchical_parsing"
]
async def process(
self,
content: bytes,
file_type: str,
options: Dict[str, Any]
) -> Dict[str, Any]:
"""
Process document using Unstructured API.
This is a mock implementation. In production:
1. Send content to Unstructured API
2. Handle rate limiting and retries
3. Parse structured response
"""
# Mock processing delay
await asyncio.sleep(0.5)
# Mock response structure
return {
"elements": [
{
"type": "Title",
"text": "Document Title",
"metadata": {"page_number": 1}
},
{
"type": "NarrativeText",
"text": "This is the main content of the document...",
"metadata": {"page_number": 1}
}
],
"metadata": {
"languages": ["en"],
"page_count": 1,
"has_tables": False,
"has_images": False
}
}
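# A real implementation would POST the bytes to the hosted API, roughly (a
# sketch; the endpoint path and header/field names are assumptions, not
# verified against current Unstructured API docs):
#
#     async with httpx.AsyncClient(timeout=60.0) as client:
#         resp = await client.post(
#             f"{self.api_url}/general/v0/general",
#             headers={"unstructured-api-key": self.api_key},
#             files={"files": (f"document{file_type}", content)},
#         )
#         elements = resp.json()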
class NativeChunkingEngine:
"""
Native chunking engine using the existing DocumentProcessorBackend.
Fast, lightweight, and suitable for most text documents.
"""
def __init__(self):
self.processor = DocumentProcessorBackend()
async def process(
self,
content: bytes,
file_type: str,
options: ProcessingOptions
) -> List[Dict[str, Any]]:
"""Process document using native chunking"""
strategy = ChunkingStrategy(
strategy_type=options.chunking_strategy,
chunk_size=options.chunk_size,
chunk_overlap=options.chunk_overlap,
preserve_paragraphs=True,
preserve_sentences=True
)
chunks = await self.processor.process_document(
content=content,
document_type=file_type,
strategy=strategy,
metadata={
"processing_timestamp": datetime.utcnow().isoformat(),
"engine": "native"
}
)
return chunks
class DocumentProcessingPipeline:
"""
Dual-engine document processing pipeline with capability-based selection.
Features:
- Native engine for fast, simple processing
- Unstructured API for advanced features
- Capability-based engine selection
- Stateless processing with memory cleanup
- Optional embedding generation
"""
def __init__(self, resource_cluster_url: Optional[str] = None):
self.resource_cluster_url = resource_cluster_url or "http://localhost:8004"
self.native_engine = NativeChunkingEngine()
self.unstructured_engine = None # Lazy initialization
self.embedding_cache = {} # Cache for frequently used embeddings
logger.info("Document Processing Pipeline initialized")
def select_engine(
self,
filename: str,
token_data: Dict[str, Any],
options: ProcessingOptions
) -> str:
"""
Select processing engine based on file type and capabilities.
Args:
filename: Name of the file being processed
token_data: Capability token data
options: Processing options
Returns:
Engine name: "native" or "unstructured"
"""
# Check if user has premium parsing capability
has_premium = any(
cap.get("resource") == "premium_parsing"
for cap in token_data.get("capabilities", [])
)
# Force native if no premium capability
if not has_premium and options.engine_preference == "unstructured":
logger.info("Premium parsing requested but not available, using native engine")
return "native"
# Auto selection logic
if options.engine_preference == "auto":
# Use Unstructured for complex formats if available
complex_formats = [".pdf", ".docx", ".pptx", ".xlsx"]
needs_ocr = options.ocr_enabled
needs_tables = filename.lower().endswith((".xlsx", ".csv"))
if has_premium and (
any(filename.lower().endswith(fmt) for fmt in complex_formats) or
needs_ocr or needs_tables
):
return "unstructured"
else:
return "native"
# Respect explicit preference if capability allows
if options.engine_preference == "unstructured" and has_premium:
return "unstructured"
return "native"
async def process_document(
self,
file: bytes,
filename: str,
token_data: Dict[str, Any],
options: Optional[ProcessingOptions] = None
) -> ProcessingResult:
"""
Process document with selected engine.
Args:
file: Document content as bytes
filename: Name of the file
token_data: Capability token data
options: Processing options
Returns:
ProcessingResult with chunks, embeddings, and metadata
"""
start_time = datetime.utcnow()
try:
# Use default options if not provided
if options is None:
options = ProcessingOptions()
# Determine file type
file_type = self._get_file_extension(filename)
# Select engine based on capabilities
engine = self.select_engine(filename, token_data, options)
# Process with selected engine
if engine == "unstructured" and token_data.get("has_capability", {}).get("premium_parsing"):
result = await self._process_with_unstructured(file, filename, token_data, options)
else:
result = await self._process_with_native(file, filename, token_data, options)
# Generate embeddings if requested
embeddings = None
if options.generate_embeddings:
embeddings = await self._generate_embeddings(result.chunks, token_data)
# Calculate processing time
processing_time = (datetime.utcnow() - start_time).total_seconds() * 1000
# Calculate token count
token_count = sum(len(chunk["text"].split()) for chunk in result.chunks)
return ProcessingResult(
chunks=result.chunks,
embeddings=embeddings,
metadata={
"filename": filename,
"file_type": file_type,
"processing_timestamp": start_time.isoformat(),
"chunk_count": len(result.chunks),
"engine_used": engine,
"options": {
"chunking_strategy": options.chunking_strategy,
"chunk_size": options.chunk_size,
"chunk_overlap": options.chunk_overlap
}
},
engine_used=engine,
processing_time_ms=processing_time,
token_count=token_count
)
except Exception as e:
logger.error(f"Error processing document: {e}")
raise
finally:
# Ensure memory cleanup
del file
gc.collect()
async def _process_with_native(
self,
file: bytes,
filename: str,
token_data: Dict[str, Any],
options: ProcessingOptions
) -> ProcessingResult:
"""Process document with native engine"""
file_type = self._get_file_extension(filename)
chunks = await self.native_engine.process(file, file_type, options)
return ProcessingResult(
chunks=chunks,
embeddings=None,
metadata={"engine": "native"},
engine_used="native",
processing_time_ms=0,
token_count=0
)
async def _process_with_unstructured(
self,
file: bytes,
filename: str,
token_data: Dict[str, Any],
options: ProcessingOptions
) -> ProcessingResult:
"""Process document with Unstructured API"""
# Initialize Unstructured engine if needed
if self.unstructured_engine is None:
# Get API key from token constraints or environment
api_key = token_data.get("constraints", {}).get("unstructured_api_key")
self.unstructured_engine = UnstructuredAPIEngine(api_key=api_key)
file_type = self._get_file_extension(filename)
# Process with Unstructured
unstructured_result = await self.unstructured_engine.process(
content=file,
file_type=file_type,
options={
"ocr": options.ocr_enabled,
"extract_tables": True,
"extract_images": False, # Don't extract images for security
"languages": ["en", "es", "fr", "de", "zh"]
}
)
# Convert Unstructured elements to chunks
chunks = []
for element in unstructured_result.get("elements", []):
chunk_text = element.get("text", "")
if chunk_text.strip():
chunks.append({
"text": chunk_text,
"metadata": {
"element_type": element.get("type"),
"page_number": element.get("metadata", {}).get("page_number"),
"engine": "unstructured"
}
})
# Apply chunking strategy if chunks are too large
final_chunks = await self._apply_chunking_to_elements(chunks, options)
return ProcessingResult(
chunks=final_chunks,
embeddings=None,
metadata={
"engine": "unstructured",
"detected_languages": unstructured_result.get("metadata", {}).get("languages", []),
"page_count": unstructured_result.get("metadata", {}).get("page_count", 0),
"has_tables": unstructured_result.get("metadata", {}).get("has_tables", False),
"has_images": unstructured_result.get("metadata", {}).get("has_images", False)
},
engine_used="unstructured",
processing_time_ms=0,
token_count=0
)
async def _apply_chunking_to_elements(
self,
elements: List[Dict[str, Any]],
options: ProcessingOptions
) -> List[Dict[str, Any]]:
"""Apply chunking strategy to Unstructured elements if needed"""
final_chunks = []
for element in elements:
text = element["text"]
# Estimate token count (rough approximation)
estimated_tokens = len(text.split()) * 1.3
# If element is small enough, keep as is
if estimated_tokens <= options.chunk_size:
final_chunks.append(element)
else:
# Split large elements using native chunking
sub_chunks = await self._chunk_text(
text,
options.chunk_size,
options.chunk_overlap
)
for idx, sub_chunk in enumerate(sub_chunks):
chunk_metadata = element["metadata"].copy()
chunk_metadata["sub_chunk_index"] = idx
chunk_metadata["parent_element_type"] = element["metadata"].get("element_type")
final_chunks.append({
"text": sub_chunk,
"metadata": chunk_metadata
})
return final_chunks
async def _chunk_text(
self,
text: str,
chunk_size: int,
chunk_overlap: int
) -> List[str]:
"""Simple text chunking for large elements"""
words = text.split()
chunks = []
        # Simple word-based chunking; guard the step so a chunk_overlap >= chunk_size
        # cannot produce a non-positive range() step (which would raise ValueError)
        step = max(1, chunk_size - chunk_overlap)
        for i in range(0, len(words), step):
            chunk_words = words[i:i + chunk_size]
            chunks.append(" ".join(chunk_words))
        return chunks
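    # Illustrative sketch (values hypothetical): with chunk_size=512 and
    # chunk_overlap=128, the window advances 384 words per step, so a
    # 1,000-word text yields chunks starting at words 0, 384, and 768, and
    # consecutive chunks share 128 words of context:
    #
    #     chunks = await self._chunk_text(text_1000_words, 512, 128)
    #     # len(chunks) == 3; chunks[0] and chunks[1] overlap on words 384..511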
async def _generate_embeddings(
self,
chunks: List[Dict[str, Any]],
token_data: Dict[str, Any]
) -> List[List[float]]:
"""
Generate embeddings for chunks.
This is a mock implementation. In production, this would:
1. Call the embedding service (BGE-M3 or similar)
2. Handle batching for efficiency
3. Apply caching for common chunks
"""
embeddings = []
for chunk in chunks:
# Check cache first
chunk_hash = hashlib.sha256(chunk["text"].encode()).hexdigest()
if chunk_hash in self.embedding_cache:
embeddings.append(self.embedding_cache[chunk_hash])
else:
# Mock embedding generation
# In production: call embedding API
embedding = [0.1] * 768 # Mock 768-dim embedding (BGE-M3 size)
embeddings.append(embedding)
# Cache for reuse (with size limit)
if len(self.embedding_cache) < 1000:
self.embedding_cache[chunk_hash] = embedding
return embeddings
def _get_file_extension(self, filename: str) -> str:
"""Extract file extension from filename"""
parts = filename.lower().split(".")
if len(parts) > 1:
return f".{parts[-1]}"
return ".txt" # Default to text
async def validate_document(
self,
file_size: int,
filename: str,
token_data: Dict[str, Any]
) -> Dict[str, Any]:
"""
Validate document before processing.
Args:
file_size: Size of file in bytes
filename: Name of the file
token_data: Capability token data
Returns:
Validation result with warnings and errors
"""
# Get size limits from token
max_size = token_data.get("constraints", {}).get("max_file_size", 50 * 1024 * 1024)
validation = {
"valid": True,
"warnings": [],
"errors": [],
"recommendations": []
}
# Check file size
if file_size > max_size:
validation["valid"] = False
validation["errors"].append(f"File exceeds maximum size of {max_size / 1024 / 1024:.1f} MiB")
elif file_size > 10 * 1024 * 1024:
validation["warnings"].append("Large file may take longer to process")
validation["recommendations"].append("Consider using streaming processing for better performance")
# Check file type
file_type = self._get_file_extension(filename)
supported_types = [".pdf", ".docx", ".txt", ".md", ".html", ".csv", ".xlsx", ".pptx"]
if file_type not in supported_types:
validation["valid"] = False
validation["errors"].append(f"Unsupported file type: {file_type}")
validation["recommendations"].append(f"Supported types: {', '.join(supported_types)}")
# Check for special processing needs
if file_type in [".xlsx", ".csv"]:
validation["recommendations"].append("Table extraction will be applied automatically")
if file_type == ".pdf":
validation["recommendations"].append("Enable OCR if document contains scanned images")
return validation
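    # Example result for an oversize unsupported upload (values illustrative,
    # assuming the default 50 MB max_file_size constraint):
    #     {"valid": False,
    #      "warnings": [],
    #      "errors": ["File exceeds maximum size of 50.0 MiB",
    #                 "Unsupported file type: .exe"],
    #      "recommendations": ["Supported types: .pdf, .docx, ..."]}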
async def get_processing_stats(self) -> Dict[str, Any]:
"""Get processing statistics"""
return {
"engines_available": ["native", "unstructured"],
"native_engine_status": "ready",
"unstructured_engine_status": "ready" if self.unstructured_engine else "not_initialized",
"embedding_cache_size": len(self.embedding_cache),
"supported_formats": [".pdf", ".docx", ".txt", ".md", ".html", ".csv", ".xlsx", ".pptx"],
"default_chunk_size": 512,
"default_chunk_overlap": 128,
"stateless": True
}
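# Minimal usage sketch (the processor class is defined earlier in this file;
# the name below is hypothetical):
#
#     processor = DocumentProcessorService()
#     validation = await processor.validate_document(
#         file_size=len(data), filename="report.pdf", token_data=token_data
#     )
#     if validation["valid"]:
#         stats = await processor.get_processing_stats()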

View File

@@ -0,0 +1,447 @@
"""
Embedding Service for GT 2.0 Resource Cluster
Provides embedding generation with:
- BGE-M3 model integration
- Batch processing capabilities
- Rate limiting and quota management
- Capability-based authentication
- Stateless operation (no data storage)
GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Per-request capability validation
- Zero Downtime: Stateless design, circuit breakers
- Self-Contained Security: Capability-based auth
- No Complexity Addition: Simple interface, no database
"""
import asyncio
import logging
import time
import os
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
import uuid
from app.core.backends.embedding_backend import EmbeddingBackend, EmbeddingRequest
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class EmbeddingResponse:
"""Response structure for embedding generation"""
request_id: str
embeddings: List[List[float]]
model: str
dimensions: int
tokens_used: int
processing_time_ms: int
tenant_id: str
created_at: str
@dataclass
class EmbeddingStats:
"""Statistics for embedding requests"""
total_requests: int = 0
total_tokens_processed: int = 0
total_processing_time_ms: int = 0
average_processing_time_ms: float = 0.0
last_request_at: Optional[str] = None
class EmbeddingService:
"""
STATELESS embedding service for GT 2.0 Resource Cluster.
Key features:
- BGE-M3 model for high-quality embeddings
- Batch processing for efficiency
- Rate limiting per capability token
- Memory-conscious processing
- No persistent storage
"""
def __init__(self):
self.backend = EmbeddingBackend()
self.stats = EmbeddingStats()
# Initialize BGE-M3 tokenizer for accurate token counting
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
except Exception as e:
logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
self.tokenizer = None
# Rate limiting settings (per capability token)
self.rate_limits = {
"requests_per_minute": 60,
"tokens_per_minute": 50000,
"max_batch_size": 32
}
# Track requests for rate limiting (in-memory, temporary)
self._request_tracker = {}
logger.info("STATELESS embedding service initialized")
async def generate_embeddings(
self,
texts: List[str],
capability_token: str,
instruction: Optional[str] = None,
request_id: Optional[str] = None,
normalize: bool = True
) -> EmbeddingResponse:
"""
Generate embeddings with capability-based authentication.
Args:
texts: List of texts to embed
capability_token: JWT token with embedding permissions
instruction: Optional instruction for embedding context
request_id: Optional request ID for tracking
normalize: Whether to normalize embeddings
Returns:
EmbeddingResponse with generated embeddings
Raises:
CapabilityError: If token invalid or insufficient permissions
ValueError: If request parameters invalid
"""
start_time = time.time()
request_id = request_id or str(uuid.uuid4())
try:
# Verify capability token and extract permissions
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
user_id = capability.get("sub") # Extract user ID from token
# Check embedding permissions
await self._verify_embedding_permissions(capability, len(texts))
# Apply rate limiting
await self._check_rate_limits(capability_token, len(texts))
# Validate input
self._validate_embedding_request(texts)
# Generate embeddings via backend
embeddings = await self.backend.generate_embeddings(
texts=texts,
instruction=instruction,
tenant_id=tenant_id,
request_id=request_id
)
# Calculate processing metrics
processing_time_ms = int((time.time() - start_time) * 1000)
total_tokens = self._estimate_tokens(texts)
# Update statistics
self._update_stats(total_tokens, processing_time_ms)
# Log embedding usage for billing (non-blocking)
# Fire and forget - don't wait for completion
asyncio.create_task(
self._log_embedding_usage(
tenant_id=tenant_id,
user_id=user_id,
tokens_used=total_tokens,
embedding_count=len(embeddings),
model=self.backend.model_name,
request_id=request_id
)
)
# Create response
response = EmbeddingResponse(
request_id=request_id,
embeddings=embeddings,
model=self.backend.model_name,
dimensions=self.backend.embedding_dimensions,
tokens_used=total_tokens,
processing_time_ms=processing_time_ms,
tenant_id=tenant_id,
created_at=datetime.utcnow().isoformat()
)
logger.info(
f"Generated {len(embeddings)} embeddings for tenant {tenant_id} "
f"in {processing_time_ms}ms"
)
return response
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise
finally:
# Always ensure cleanup
if 'texts' in locals():
del texts
async def get_model_info(self) -> Dict[str, Any]:
"""Get information about the embedding model"""
return {
"model_name": self.backend.model_name,
"dimensions": self.backend.embedding_dimensions,
"max_sequence_length": self.backend.max_sequence_length,
"max_batch_size": self.backend.max_batch_size,
"supports_instruction": True,
"normalization_default": True
}
async def get_service_stats(
self,
capability_token: str
) -> Dict[str, Any]:
"""
Get service statistics (for admin users only).
Args:
capability_token: JWT token with admin permissions
Returns:
Service statistics
"""
# Verify admin permissions
capability = await verify_capability_token(capability_token)
if not self._has_admin_permissions(capability):
raise CapabilityError("Admin permissions required")
return {
"model_info": await self.get_model_info(),
"statistics": asdict(self.stats),
"rate_limits": self.rate_limits,
"active_requests": len(self._request_tracker)
}
async def health_check(self) -> Dict[str, Any]:
"""Check service health"""
return {
"status": "healthy",
"service": "embedding_service",
"model": self.backend.model_name,
"backend_ready": True,
"last_request": self.stats.last_request_at
}
async def _verify_embedding_permissions(
self,
capability: Dict[str, Any],
text_count: int
) -> None:
"""Verify that capability token has embedding permissions"""
# Check for embedding capability
capabilities = capability.get("capabilities", [])
embedding_caps = [
cap for cap in capabilities
if cap.get("resource") == "embeddings"
]
if not embedding_caps:
raise CapabilityError("No embedding permissions in capability token")
# Check constraints
embedding_cap = embedding_caps[0] # Use first embedding capability
constraints = embedding_cap.get("constraints", {})
# Check batch size limit
max_batch = constraints.get("max_batch_size", self.rate_limits["max_batch_size"])
if text_count > max_batch:
raise CapabilityError(f"Batch size {text_count} exceeds limit {max_batch}")
        # Read per-token rate limit constraints; enforcement below in
        # _check_rate_limits currently applies the service-wide defaults
        rate_limit = constraints.get("rate_limit_per_minute", self.rate_limits["requests_per_minute"])
        token_limit = constraints.get("tokens_per_minute", self.rate_limits["tokens_per_minute"])
        logger.debug(f"Embedding permissions verified: batch={text_count}, limits=({rate_limit}, {token_limit})")
async def _check_rate_limits(
self,
capability_token: str,
text_count: int
) -> None:
"""Check rate limits for capability token"""
now = time.time()
token_hash = hash(capability_token) % 10000 # Simple tracking key
# Clean old entries (older than 1 minute)
cleanup_time = now - 60
self._request_tracker = {
k: v for k, v in self._request_tracker.items()
if v.get("last_request", 0) > cleanup_time
}
# Get or create tracker for this token
if token_hash not in self._request_tracker:
self._request_tracker[token_hash] = {
"requests": 0,
"tokens": 0,
"last_request": now
}
tracker = self._request_tracker[token_hash]
# Check request rate limit
if tracker["requests"] >= self.rate_limits["requests_per_minute"]:
raise CapabilityError("Rate limit exceeded: too many requests per minute")
        # Estimate tokens and check token limit. The actual texts are not
        # available here, so per-item placeholder strings serve as a coarse
        # proxy for content length.
        estimated_tokens = self._estimate_tokens([f"text_{i}" for i in range(text_count)])
if tracker["tokens"] + estimated_tokens > self.rate_limits["tokens_per_minute"]:
raise CapabilityError("Rate limit exceeded: too many tokens per minute")
# Update tracker
tracker["requests"] += 1
tracker["tokens"] += estimated_tokens
tracker["last_request"] = now
def _validate_embedding_request(self, texts: List[str]) -> None:
"""Validate embedding request parameters"""
if not texts:
raise ValueError("No texts provided for embedding")
if not isinstance(texts, list):
raise ValueError("Texts must be a list")
if len(texts) > self.backend.max_batch_size:
raise ValueError(f"Batch size {len(texts)} exceeds maximum {self.backend.max_batch_size}")
# Check individual text lengths
for i, text in enumerate(texts):
if not isinstance(text, str):
raise ValueError(f"Text at index {i} must be a string")
if len(text.strip()) == 0:
raise ValueError(f"Text at index {i} is empty")
# Simple token estimation for length check
estimated_tokens = len(text.split()) * 1.3 # Rough estimation
if estimated_tokens > self.backend.max_sequence_length:
raise ValueError(f"Text at index {i} exceeds maximum length")
def _estimate_tokens(self, texts: List[str]) -> int:
"""
Count tokens using actual BGE-M3 tokenizer.
Falls back to word-count estimation if tokenizer unavailable.
"""
if self.tokenizer is not None:
try:
total_tokens = 0
for text in texts:
tokens = self.tokenizer.encode(text, add_special_tokens=False)
total_tokens += len(tokens)
return total_tokens
except Exception as e:
logger.warning(f"Tokenizer error, falling back to estimation: {e}")
# Fallback: word count * 1.3 (rough estimation)
total_words = sum(len(text.split()) for text in texts)
return int(total_words * 1.3)
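    # Worked example (approximate): "Hello world, how are you?" splits into
    # 5 words, so the fallback estimate is int(5 * 1.3) = 6 tokens; the real
    # BGE-M3 tokenizer typically returns a slightly different count, which is
    # why the tokenizer path is preferred when available.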
def _has_admin_permissions(self, capability: Dict[str, Any]) -> bool:
"""Check if capability has admin permissions"""
capabilities = capability.get("capabilities", [])
return any(
cap.get("resource") == "admin" and "stats" in cap.get("actions", [])
for cap in capabilities
)
def _update_stats(self, tokens_processed: int, processing_time_ms: int) -> None:
"""Update service statistics"""
self.stats.total_requests += 1
self.stats.total_tokens_processed += tokens_processed
self.stats.total_processing_time_ms += processing_time_ms
self.stats.average_processing_time_ms = (
self.stats.total_processing_time_ms / self.stats.total_requests
)
self.stats.last_request_at = datetime.utcnow().isoformat()
async def _log_embedding_usage(
self,
tenant_id: str,
user_id: str,
tokens_used: int,
embedding_count: int,
model: str = "BAAI/bge-m3",
request_id: Optional[str] = None
) -> None:
"""
Log embedding usage to control panel database for billing.
This method logs usage asynchronously and does not block the embedding response.
Failures are logged as warnings but do not raise exceptions.
Args:
tenant_id: Tenant identifier
user_id: User identifier (from capability token 'sub')
tokens_used: Number of tokens processed
embedding_count: Number of embeddings generated
model: Embedding model name
request_id: Optional request ID for tracking
"""
try:
import asyncpg
# Calculate cost: BGE-M3 pricing ~$0.10 per million tokens
cost_cents = (tokens_used / 1_000_000) * 0.10 * 100
# Connect to control panel database
# Using environment variables from docker-compose
db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
if not db_password:
logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
return
conn = await asyncpg.connect(
host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
password=db_password,
timeout=5.0
)
try:
# Insert usage log
await conn.execute("""
INSERT INTO public.embedding_usage_logs
(tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""", tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
logger.info(
f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
)
finally:
await conn.close()
except Exception as e:
# Log warning but don't fail the embedding request
logger.warning(f"Failed to log embedding usage for tenant {tenant_id}: {e}")
# Global service instance
_embedding_service = None
def get_embedding_service() -> EmbeddingService:
"""Get the global embedding service instance"""
global _embedding_service
if _embedding_service is None:
_embedding_service = EmbeddingService()
return _embedding_service
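# Minimal usage sketch (token value hypothetical):
#
#     service = get_embedding_service()
#     response = await service.generate_embeddings(
#         texts=["GT 2.0 resource cluster"],
#         capability_token=jwt_with_embeddings_capability,
#     )
#     assert len(response.embeddings) == 1
#     assert len(response.embeddings[0]) == response.dimensions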

View File

@@ -0,0 +1,729 @@
"""
Integration Proxy Service for GT 2.0
Secure proxy service for external integrations with capability-based access control,
sandbox restrictions, and comprehensive audit logging. All external calls are routed
through this service in the Resource Cluster for security and monitoring.
"""
import asyncio
import json
import httpx
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, asdict
from enum import Enum
import logging
from contextlib import asynccontextmanager
from app.core.security import verify_capability_token
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class IntegrationType(Enum):
"""Types of external integrations"""
COMMUNICATION = "communication" # Slack, Teams, Discord
DEVELOPMENT = "development" # GitHub, GitLab, Jira
PROJECT_MANAGEMENT = "project_management" # Asana, Monday.com
DATABASE = "database" # PostgreSQL, MySQL, MongoDB
CUSTOM_API = "custom_api" # Custom REST/GraphQL APIs
WEBHOOK = "webhook" # Outbound webhook calls
class SandboxLevel(Enum):
"""Sandbox restriction levels"""
NONE = "none" # No restrictions (trusted)
BASIC = "basic" # Basic timeout and size limits
RESTRICTED = "restricted" # Limited API calls and data access
STRICT = "strict" # Maximum restrictions
@dataclass
class IntegrationConfig:
"""Configuration for external integration"""
id: str
name: str
integration_type: IntegrationType
base_url: str
authentication_method: str # oauth2, api_key, basic_auth, certificate
sandbox_level: SandboxLevel
# Authentication details (encrypted)
auth_config: Dict[str, Any]
# Rate limits and constraints
max_requests_per_hour: int = 1000
max_response_size_bytes: int = 10 * 1024 * 1024 # 10MB
timeout_seconds: int = 30
    # Allowed operations
    allowed_methods: Optional[List[str]] = None
    allowed_endpoints: Optional[List[str]] = None
    blocked_endpoints: Optional[List[str]] = None
    # Network restrictions
    allowed_domains: Optional[List[str]] = None
    # Created metadata
    created_at: Optional[datetime] = None
    created_by: str = ""
    is_active: bool = True
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
if self.allowed_methods is None:
self.allowed_methods = ["GET", "POST"]
if self.allowed_endpoints is None:
self.allowed_endpoints = []
if self.blocked_endpoints is None:
self.blocked_endpoints = []
if self.allowed_domains is None:
self.allowed_domains = []
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
data = asdict(self)
data["integration_type"] = self.integration_type.value
data["sandbox_level"] = self.sandbox_level.value
data["created_at"] = self.created_at.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "IntegrationConfig":
"""Create from dictionary"""
data["integration_type"] = IntegrationType(data["integration_type"])
data["sandbox_level"] = SandboxLevel(data["sandbox_level"])
data["created_at"] = datetime.fromisoformat(data["created_at"])
return cls(**data)
@dataclass
class ProxyRequest:
"""Request to proxy to external service"""
integration_id: str
method: str
endpoint: str
headers: Optional[Dict[str, str]] = None
data: Optional[Dict[str, Any]] = None
params: Optional[Dict[str, str]] = None
timeout_override: Optional[int] = None
def __post_init__(self):
if self.headers is None:
self.headers = {}
if self.data is None:
self.data = {}
if self.params is None:
self.params = {}
@dataclass
class ProxyResponse:
"""Response from proxied external service"""
success: bool
status_code: int
data: Optional[Dict[str, Any]]
headers: Dict[str, str]
execution_time_ms: int
sandbox_applied: bool
restrictions_applied: List[str]
error_message: Optional[str] = None
def __post_init__(self):
if self.headers is None:
self.headers = {}
if self.restrictions_applied is None:
self.restrictions_applied = []
class SandboxManager:
"""Manages sandbox restrictions for external integrations"""
def __init__(self):
self.active_requests: Dict[str, datetime] = {}
self.rate_limiters: Dict[str, List[datetime]] = {}
def apply_sandbox_restrictions(
self,
config: IntegrationConfig,
request: ProxyRequest,
capability_token: Dict[str, Any]
) -> Tuple[ProxyRequest, List[str]]:
"""Apply sandbox restrictions to request"""
restrictions_applied = []
if config.sandbox_level == SandboxLevel.NONE:
return request, restrictions_applied
# Apply timeout restrictions
if config.sandbox_level in [SandboxLevel.BASIC, SandboxLevel.RESTRICTED, SandboxLevel.STRICT]:
max_timeout = self._get_max_timeout(config.sandbox_level)
if request.timeout_override is None or request.timeout_override > max_timeout:
request.timeout_override = max_timeout
restrictions_applied.append(f"timeout_limited_to_{max_timeout}s")
# Apply endpoint restrictions
if config.sandbox_level in [SandboxLevel.RESTRICTED, SandboxLevel.STRICT]:
# Check blocked endpoints first
if request.endpoint in config.blocked_endpoints:
raise PermissionError(f"Endpoint {request.endpoint} is blocked")
# Then check allowed endpoints if specified
if config.allowed_endpoints and request.endpoint not in config.allowed_endpoints:
raise PermissionError(f"Endpoint {request.endpoint} not allowed")
restrictions_applied.append("endpoint_validation")
# Apply method restrictions
if config.sandbox_level == SandboxLevel.STRICT:
allowed_methods = config.allowed_methods or ["GET", "POST"]
if request.method not in allowed_methods:
raise PermissionError(f"HTTP method {request.method} not allowed in strict mode")
restrictions_applied.append("method_restricted")
# Apply data size restrictions
if request.data:
data_size = len(json.dumps(request.data).encode())
max_size = self._get_max_data_size(config.sandbox_level)
if data_size > max_size:
raise ValueError(f"Request data size {data_size} exceeds limit {max_size}")
restrictions_applied.append("data_size_validated")
# Apply capability-based restrictions
constraints = capability_token.get("constraints", {})
if "integration_timeout_seconds" in constraints:
max_cap_timeout = constraints["integration_timeout_seconds"]
if request.timeout_override > max_cap_timeout:
request.timeout_override = max_cap_timeout
restrictions_applied.append(f"capability_timeout_{max_cap_timeout}s")
return request, restrictions_applied
def _get_max_timeout(self, sandbox_level: SandboxLevel) -> int:
"""Get maximum timeout for sandbox level"""
timeouts = {
SandboxLevel.BASIC: 60,
SandboxLevel.RESTRICTED: 30,
SandboxLevel.STRICT: 15
}
return timeouts.get(sandbox_level, 30)
def _get_max_data_size(self, sandbox_level: SandboxLevel) -> int:
"""Get maximum data size for sandbox level"""
sizes = {
SandboxLevel.BASIC: 1024 * 1024, # 1MB
SandboxLevel.RESTRICTED: 512 * 1024, # 512KB
SandboxLevel.STRICT: 256 * 1024 # 256KB
}
return sizes.get(sandbox_level, 512 * 1024)
async def check_rate_limits(self, integration_id: str, config: IntegrationConfig) -> bool:
"""Check if request is within rate limits"""
now = datetime.utcnow()
hour_ago = now - timedelta(hours=1)
# Initialize or clean rate limiter
if integration_id not in self.rate_limiters:
self.rate_limiters[integration_id] = []
# Remove old requests
self.rate_limiters[integration_id] = [
req_time for req_time in self.rate_limiters[integration_id]
if req_time > hour_ago
]
# Check rate limit
if len(self.rate_limiters[integration_id]) >= config.max_requests_per_hour:
return False
# Record this request
self.rate_limiters[integration_id].append(now)
return True
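    # Sliding-window sketch: with max_requests_per_hour=3, calls at t=0m, 10m,
    # and 20m all pass; a fourth call at t=30m is rejected, but succeeds again
    # at t=61m once the t=0m entry ages out of the one-hour window.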
class IntegrationProxyService:
"""
Integration Proxy Service for secure external API access.
Features:
- Capability-based access control
- Sandbox restrictions based on trust level
- Rate limiting and usage tracking
- Comprehensive audit logging
- Response sanitization and size limits
"""
def __init__(self, base_path: Optional[Path] = None):
self.base_path = base_path or Path("/data/resource-cluster/integrations")
self.configs_path = self.base_path / "configs"
self.usage_path = self.base_path / "usage"
self.audit_path = self.base_path / "audit"
self.sandbox_manager = SandboxManager()
self.http_client = None
# Ensure directories exist with proper permissions
self._ensure_directories()
def _ensure_directories(self):
"""Ensure storage directories exist with proper permissions"""
for path in [self.configs_path, self.usage_path, self.audit_path]:
path.mkdir(parents=True, exist_ok=True, mode=0o700)
@asynccontextmanager
async def get_http_client(self):
"""Get HTTP client with proper configuration"""
if self.http_client is None:
self.http_client = httpx.AsyncClient(
timeout=httpx.Timeout(60.0),
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
)
try:
yield self.http_client
finally:
# Client stays open for reuse
pass
async def execute_integration(
self,
request: ProxyRequest,
capability_token: str
) -> ProxyResponse:
"""Execute integration request with security and sandbox restrictions"""
start_time = datetime.utcnow()
try:
# Verify capability token
token_obj = verify_capability_token(capability_token)
if not token_obj:
raise PermissionError("Invalid capability token")
# Convert token object to dict for compatibility
token_data = {
"tenant_id": token_obj.tenant_id,
"sub": token_obj.sub,
"capabilities": [cap.dict() if hasattr(cap, 'dict') else cap for cap in token_obj.capabilities],
"constraints": {}
}
# Load integration configuration
config = await self._load_integration_config(request.integration_id)
if not config or not config.is_active:
raise ValueError(f"Integration {request.integration_id} not found or inactive")
# Validate capability for this integration
required_capability = f"integration:{request.integration_id}:{request.method.lower()}"
if not self._has_capability(token_data, required_capability):
raise PermissionError(f"Missing capability: {required_capability}")
# Check rate limits
if not await self.sandbox_manager.check_rate_limits(request.integration_id, config):
raise PermissionError("Rate limit exceeded")
# Apply sandbox restrictions
sandboxed_request, restrictions = self.sandbox_manager.apply_sandbox_restrictions(
config, request, token_data
)
# Execute the request
response = await self._execute_proxied_request(config, sandboxed_request)
response.sandbox_applied = len(restrictions) > 0
response.restrictions_applied = restrictions
# Calculate execution time
execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
response.execution_time_ms = int(execution_time)
# Log usage
await self._log_usage(
integration_id=request.integration_id,
tenant_id=token_data.get("tenant_id"),
user_id=token_data.get("sub"),
method=request.method,
endpoint=request.endpoint,
success=response.success,
execution_time_ms=response.execution_time_ms
)
# Audit log
await self._audit_log(
action="integration_executed",
integration_id=request.integration_id,
user_id=token_data.get("sub"),
details={
"method": request.method,
"endpoint": request.endpoint,
"success": response.success,
"restrictions_applied": restrictions
}
)
return response
except Exception as e:
logger.error(f"Integration execution failed: {e}")
# Log error
execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
await self._log_usage(
integration_id=request.integration_id,
tenant_id=token_data.get("tenant_id") if 'token_data' in locals() else "unknown",
user_id=token_data.get("sub") if 'token_data' in locals() else "unknown",
method=request.method,
endpoint=request.endpoint,
success=False,
execution_time_ms=int(execution_time),
error=str(e)
)
return ProxyResponse(
success=False,
status_code=500,
data=None,
headers={},
execution_time_ms=int(execution_time),
sandbox_applied=False,
restrictions_applied=[],
error_message=str(e)
)
async def _execute_proxied_request(
self,
config: IntegrationConfig,
request: ProxyRequest
) -> ProxyResponse:
"""Execute the actual HTTP request to external service"""
# Build URL
if request.endpoint.startswith('http'):
url = request.endpoint
else:
url = f"{config.base_url.rstrip('/')}/{request.endpoint.lstrip('/')}"
# Apply authentication
headers = request.headers.copy()
await self._apply_authentication(config, headers)
# Set timeout
timeout = request.timeout_override or config.timeout_seconds
try:
async with self.get_http_client() as client:
# Execute request
if request.method.upper() == "GET":
response = await client.get(
url,
headers=headers,
params=request.params,
timeout=timeout
)
elif request.method.upper() == "POST":
response = await client.post(
url,
headers=headers,
json=request.data,
params=request.params,
timeout=timeout
)
elif request.method.upper() == "PUT":
response = await client.put(
url,
headers=headers,
json=request.data,
params=request.params,
timeout=timeout
)
elif request.method.upper() == "DELETE":
response = await client.delete(
url,
headers=headers,
params=request.params,
timeout=timeout
)
else:
raise ValueError(f"Unsupported HTTP method: {request.method}")
# Check response size
if len(response.content) > config.max_response_size_bytes:
raise ValueError(f"Response size exceeds limit: {len(response.content)}")
# Parse response
try:
data = response.json() if response.content else {}
except json.JSONDecodeError:
data = {"raw_content": response.text}
                return ProxyResponse(
                    success=200 <= response.status_code < 300,
                    status_code=response.status_code,
                    data=data,
                    headers=dict(response.headers),
                    execution_time_ms=0,  # Will be set by caller
                    sandbox_applied=False,  # Will be set by caller
                    restrictions_applied=[]  # Will be set by caller; field has no default
                )
except httpx.TimeoutException:
return ProxyResponse(
success=False,
status_code=408,
data=None,
headers={},
execution_time_ms=timeout * 1000,
sandbox_applied=False,
restrictions_applied=[],
error_message="Request timeout"
)
except Exception as e:
return ProxyResponse(
success=False,
status_code=500,
data=None,
headers={},
execution_time_ms=0,
sandbox_applied=False,
restrictions_applied=[],
error_message=str(e)
)
async def _apply_authentication(self, config: IntegrationConfig, headers: Dict[str, str]):
"""Apply authentication to request headers"""
auth_config = config.auth_config
if config.authentication_method == "api_key":
api_key = auth_config.get("api_key")
key_header = auth_config.get("key_header", "Authorization")
key_prefix = auth_config.get("key_prefix", "Bearer")
if api_key:
headers[key_header] = f"{key_prefix} {api_key}"
elif config.authentication_method == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
elif config.authentication_method == "oauth2":
access_token = auth_config.get("access_token")
if access_token:
headers["Authorization"] = f"Bearer {access_token}"
# Add custom headers
custom_headers = auth_config.get("custom_headers", {})
headers.update(custom_headers)
def _has_capability(self, token_data: Dict[str, Any], required_capability: str) -> bool:
"""Check if token has required capability"""
capabilities = token_data.get("capabilities", [])
for capability in capabilities:
if isinstance(capability, dict):
resource = capability.get("resource", "")
# Handle wildcard matching
if resource == required_capability:
return True
if resource.endswith("*"):
prefix = resource[:-1] # Remove the *
if required_capability.startswith(prefix):
return True
elif isinstance(capability, str):
# Handle wildcard matching for string capabilities
if capability == required_capability:
return True
if capability.endswith("*"):
prefix = capability[:-1] # Remove the *
if required_capability.startswith(prefix):
return True
return False
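    # Wildcard matching sketch (capability strings hypothetical):
    #     "integration:slack:post"  matches required "integration:slack:post"
    #     "integration:slack:*"     matches required "integration:slack:post"
    #     "integration:*"           matches required "integration:slack:post"
    #     "integration:jira:*"      does NOT match "integration:slack:post"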
async def _load_integration_config(self, integration_id: str) -> Optional[IntegrationConfig]:
"""Load integration configuration from storage"""
config_file = self.configs_path / f"{integration_id}.json"
if not config_file.exists():
return None
try:
with open(config_file, "r") as f:
data = json.load(f)
return IntegrationConfig.from_dict(data)
except Exception as e:
logger.error(f"Failed to load integration config {integration_id}: {e}")
return None
async def store_integration_config(self, config: IntegrationConfig) -> bool:
"""Store integration configuration"""
config_file = self.configs_path / f"{config.id}.json"
try:
with open(config_file, "w") as f:
json.dump(config.to_dict(), f, indent=2)
# Set secure permissions
config_file.chmod(0o600)
return True
except Exception as e:
logger.error(f"Failed to store integration config {config.id}: {e}")
return False
async def _log_usage(
self,
integration_id: str,
tenant_id: str,
user_id: str,
method: str,
endpoint: str,
success: bool,
execution_time_ms: int,
error: Optional[str] = None
):
"""Log integration usage for analytics"""
date_str = datetime.utcnow().strftime("%Y-%m-%d")
usage_file = self.usage_path / f"usage_{date_str}.jsonl"
usage_record = {
"timestamp": datetime.utcnow().isoformat(),
"integration_id": integration_id,
"tenant_id": tenant_id,
"user_id": user_id,
"method": method,
"endpoint": endpoint,
"success": success,
"execution_time_ms": execution_time_ms,
"error": error
}
try:
with open(usage_file, "a") as f:
f.write(json.dumps(usage_record) + "\n")
# Set secure permissions on file
usage_file.chmod(0o600)
except Exception as e:
logger.error(f"Failed to log usage: {e}")
async def _audit_log(
self,
action: str,
integration_id: str,
user_id: str,
details: Dict[str, Any]
):
"""Log audit trail for integration actions"""
date_str = datetime.utcnow().strftime("%Y-%m-%d")
audit_file = self.audit_path / f"audit_{date_str}.jsonl"
audit_record = {
"timestamp": datetime.utcnow().isoformat(),
"action": action,
"integration_id": integration_id,
"user_id": user_id,
"details": details
}
try:
with open(audit_file, "a") as f:
f.write(json.dumps(audit_record) + "\n")
# Set secure permissions on file
audit_file.chmod(0o600)
except Exception as e:
logger.error(f"Failed to log audit: {e}")
async def list_integrations(self, capability_token: str) -> List[IntegrationConfig]:
"""List available integrations based on capabilities"""
token_obj = verify_capability_token(capability_token)
if not token_obj:
raise PermissionError("Invalid capability token")
# Convert token object to dict for compatibility
token_data = {
"tenant_id": token_obj.tenant_id,
"sub": token_obj.sub,
"capabilities": [cap.dict() if hasattr(cap, 'dict') else cap for cap in token_obj.capabilities],
"constraints": {}
}
integrations = []
for config_file in self.configs_path.glob("*.json"):
try:
with open(config_file, "r") as f:
data = json.load(f)
config = IntegrationConfig.from_dict(data)
# Check if user has capability for this integration
required_capability = f"integration:{config.id}:*"
if self._has_capability(token_data, required_capability):
integrations.append(config)
except Exception as e:
logger.warning(f"Failed to load integration config {config_file}: {e}")
return integrations
async def get_integration_usage_analytics(
self,
integration_id: str,
days: int = 30
) -> Dict[str, Any]:
"""Get usage analytics for integration"""
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=days-1) # Include today in the range
total_requests = 0
successful_requests = 0
total_execution_time = 0
error_count = 0
# Process usage logs
for day_offset in range(days):
date = start_date + timedelta(days=day_offset)
date_str = date.strftime("%Y-%m-%d")
usage_file = self.usage_path / f"usage_{date_str}.jsonl"
if usage_file.exists():
try:
with open(usage_file, "r") as f:
for line in f:
record = json.loads(line.strip())
if record["integration_id"] == integration_id:
total_requests += 1
if record["success"]:
successful_requests += 1
else:
error_count += 1
total_execution_time += record["execution_time_ms"]
except Exception as e:
logger.warning(f"Failed to process usage file {usage_file}: {e}")
return {
"integration_id": integration_id,
"total_requests": total_requests,
"successful_requests": successful_requests,
"error_count": error_count,
"success_rate": successful_requests / total_requests if total_requests > 0 else 0,
"avg_execution_time_ms": total_execution_time / total_requests if total_requests > 0 else 0,
"date_range": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
}
}
async def close(self):
"""Close HTTP client and cleanup resources"""
if self.http_client:
await self.http_client.aclose()
self.http_client = None
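# Minimal usage sketch (integration id and token hypothetical):
#
#     proxy = IntegrationProxyService()
#     response = await proxy.execute_integration(
#         ProxyRequest(
#             integration_id="slack-notify",
#             method="POST",
#             endpoint="/api/chat.postMessage",
#             data={"channel": "#alerts", "text": "build finished"},
#         ),
#         capability_token=jwt_with_integration_capability,
#     )
#     await proxy.close()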

View File

@@ -0,0 +1,925 @@
"""
LLM Gateway Service for GT 2.0 Resource Cluster
Provides unified access to LLM providers with:
- Groq Cloud integration for fast inference
- OpenAI API compatibility
- Rate limiting and quota management
- Capability-based authentication
- Model routing and load balancing
- Response streaming support
GT 2.0 Architecture Principles:
- Stateless: No persistent connections or state
- Zero downtime: Circuit breakers and failover
- Self-contained: No external configuration dependencies
- Capability-based: JWT token authorization
"""
import asyncio
import logging
import json
import time
from typing import Dict, Any, List, Optional, AsyncGenerator, Union
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass, asdict
import uuid
import httpx
from enum import Enum
from urllib.parse import urlparse
from app.core.config import get_settings
def is_provider_endpoint(endpoint_url: str, provider_domains: List[str]) -> bool:
"""
Safely check if URL belongs to a specific provider.
Uses proper URL parsing to prevent bypass via URLs like
'evil.groq.com.attacker.com' or 'groq.com.evil.com'.
"""
try:
parsed = urlparse(endpoint_url)
hostname = (parsed.hostname or "").lower()
for domain in provider_domains:
domain = domain.lower()
# Match exact domain or subdomain (e.g., api.groq.com matches groq.com)
if hostname == domain or hostname.endswith(f".{domain}"):
return True
return False
except Exception:
return False
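# Hostname-validation sketch -- exact domains and subdomains pass, while
# substring lookalikes are rejected:
#     is_provider_endpoint("https://api.groq.com/v1", ["groq.com"])       -> True
#     is_provider_endpoint("https://groq.com.evil.com/v1", ["groq.com"])  -> False
#     is_provider_endpoint("https://evilgroq.com/v1", ["groq.com"])       -> False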
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.services.admin_model_config_service import get_admin_model_service, AdminModelConfig
logger = logging.getLogger(__name__)
settings = get_settings()
class ModelProvider(str, Enum):
"""Supported LLM providers"""
GROQ = "groq"
OPENAI = "openai"
ANTHROPIC = "anthropic"
NVIDIA = "nvidia"
LOCAL = "local"
class ModelCapability(str, Enum):
"""Model capabilities for routing"""
CHAT = "chat"
COMPLETION = "completion"
EMBEDDING = "embedding"
FUNCTION_CALLING = "function_calling"
VISION = "vision"
CODE = "code"
@dataclass
class ModelConfig:
"""Model configuration and capabilities"""
model_id: str
provider: ModelProvider
capabilities: List[ModelCapability]
max_tokens: int
context_window: int
cost_per_token: float
rate_limit_rpm: int
supports_streaming: bool
supports_functions: bool
is_available: bool = True
@dataclass
class LLMRequest:
"""Standardized LLM request format"""
model: str
messages: List[Dict[str, str]]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
functions: Optional[List[Dict[str, Any]]] = None
function_call: Optional[Union[str, Dict[str, str]]] = None
user: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API calls"""
result = asdict(self)
# Remove None values
return {k: v for k, v in result.items() if v is not None}
@dataclass
class LLMResponse:
"""Standardized LLM response format"""
id: str
object: str
created: int
model: str
choices: List[Dict[str, Any]]
usage: Dict[str, int]
provider: str
request_id: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API responses"""
return asdict(self)
class LLMGateway:
"""
LLM Gateway with unified API and multi-provider support.
Provides OpenAI-compatible API while routing to optimal providers
based on model capabilities, availability, and cost.
"""
def __init__(self):
self.settings = get_settings()
self.http_client = httpx.AsyncClient(timeout=120.0)
self.admin_service = get_admin_model_service()
# Rate limiting tracking
self.rate_limits: Dict[str, Dict[str, Any]] = {}
# Provider health tracking
self.provider_health: Dict[ModelProvider, bool] = {
provider: True for provider in ModelProvider
}
# Request statistics
self.stats = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"provider_usage": {provider.value: 0 for provider in ModelProvider},
"model_usage": {},
"average_latency": 0.0
}
logger.info("LLM Gateway initialized with admin-configured models")
async def get_available_models(self, tenant_id: Optional[str] = None) -> List[AdminModelConfig]:
"""Get available models, optionally filtered by tenant"""
if tenant_id:
return await self.admin_service.get_tenant_models(tenant_id)
else:
return await self.admin_service.get_all_models(active_only=True)
async def get_model_config(self, model_id: str, tenant_id: Optional[str] = None) -> Optional[AdminModelConfig]:
"""Get configuration for a specific model"""
config = await self.admin_service.get_model_config(model_id)
# Check tenant access if tenant_id provided
if config and tenant_id:
has_access = await self.admin_service.check_tenant_access(tenant_id, model_id)
if not has_access:
return None
return config
async def get_groq_api_key(self) -> Optional[str]:
"""Get Groq API key from admin service"""
return await self.admin_service.get_groq_api_key()
def _initialize_model_configs(self) -> Dict[str, ModelConfig]:
"""Initialize supported model configurations"""
models = {}
# Groq models (fast inference)
groq_models = [
ModelConfig(
model_id="llama3-8b-8192",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
max_tokens=8192,
context_window=8192,
cost_per_token=0.00001,
rate_limit_rpm=30,
supports_streaming=True,
supports_functions=False
),
ModelConfig(
model_id="llama3-70b-8192",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
max_tokens=8192,
context_window=8192,
cost_per_token=0.00008,
rate_limit_rpm=15,
supports_streaming=True,
supports_functions=False
),
ModelConfig(
model_id="mixtral-8x7b-32768",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
max_tokens=32768,
context_window=32768,
cost_per_token=0.00005,
rate_limit_rpm=20,
supports_streaming=True,
supports_functions=False
),
ModelConfig(
model_id="gemma-7b-it",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT],
max_tokens=8192,
context_window=8192,
cost_per_token=0.00001,
rate_limit_rpm=30,
supports_streaming=True,
supports_functions=False
)
]
# OpenAI models (function calling, embeddings)
openai_models = [
ModelConfig(
model_id="gpt-4-turbo-preview",
provider=ModelProvider.OPENAI,
capabilities=[ModelCapability.CHAT, ModelCapability.FUNCTION_CALLING, ModelCapability.VISION],
max_tokens=4096,
context_window=128000,
cost_per_token=0.00003,
rate_limit_rpm=10,
supports_streaming=True,
supports_functions=True
),
ModelConfig(
model_id="gpt-3.5-turbo",
provider=ModelProvider.OPENAI,
capabilities=[ModelCapability.CHAT, ModelCapability.FUNCTION_CALLING],
max_tokens=4096,
context_window=16385,
cost_per_token=0.000002,
rate_limit_rpm=60,
supports_streaming=True,
supports_functions=True
),
ModelConfig(
model_id="text-embedding-3-small",
provider=ModelProvider.OPENAI,
capabilities=[ModelCapability.EMBEDDING],
max_tokens=8191,
context_window=8191,
cost_per_token=0.00000002,
rate_limit_rpm=3000,
supports_streaming=False,
supports_functions=False
)
]
# Add all models to registry
for model_list in [groq_models, openai_models]:
for model in model_list:
models[model.model_id] = model
return models
async def chat_completion(
self,
request: LLMRequest,
capability_token: str,
user_id: str,
tenant_id: str
) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""
Process chat completion request with capability validation.
Args:
request: LLM request parameters
capability_token: JWT capability token
user_id: User identifier for rate limiting
tenant_id: Tenant identifier for isolation
Returns:
LLM response or streaming generator
"""
start_time = time.time()
request_id = str(uuid.uuid4())
try:
# Verify capabilities
await self._verify_llm_capability(capability_token, request.model, user_id, tenant_id)
# Validate model availability
model_config = self.models.get(request.model)
if not model_config:
raise ValueError(f"Model {request.model} not supported")
if not model_config.is_available:
raise ValueError(f"Model {request.model} is currently unavailable")
# Check rate limits
await self._check_rate_limits(user_id, model_config)
# Route to configured endpoint (generic routing for any provider)
if hasattr(model_config, 'endpoint') and model_config.endpoint:
result = await self._process_generic_request(request, request_id, model_config, tenant_id)
elif model_config.provider == ModelProvider.GROQ:
result = await self._process_groq_request(request, request_id, model_config, tenant_id)
elif model_config.provider == ModelProvider.OPENAI:
result = await self._process_openai_request(request, request_id, model_config)
else:
raise ValueError(f"Provider {model_config.provider} not implemented - ensure endpoint is configured")
# Update statistics
latency = time.time() - start_time
await self._update_stats(request.model, model_config.provider, latency, True)
logger.info(f"LLM request completed: {request_id} ({latency:.3f}s)")
return result
        except Exception as e:
            latency = time.time() - start_time
            # Attribute the failure to the resolved model's provider when known;
            # fall back to GROQ if the failure happened before model lookup
            failed_provider = (
                model_config.provider
                if 'model_config' in locals() and model_config
                else ModelProvider.GROQ
            )
            await self._update_stats(request.model, failed_provider, latency, False)
            logger.error(f"LLM request failed: {request_id} - {e}")
            raise
async def _verify_llm_capability(
self,
capability_token: str,
model: str,
user_id: str,
tenant_id: str
) -> None:
"""Verify user has capability to use specific model"""
try:
payload = await verify_capability_token(capability_token)
# Check tenant match
if payload.get("tenant_id") != tenant_id:
raise CapabilityError("Tenant mismatch in capability token")
# Find LLM capability (match "llm" or "llm:provider" format)
capabilities = payload.get("capabilities", [])
llm_capability = None
for cap in capabilities:
resource = cap.get("resource", "")
if resource == "llm" or resource.startswith("llm:"):
llm_capability = cap
break
if not llm_capability:
raise CapabilityError("No LLM capability found in token")
# Check model access
allowed_models = llm_capability.get("constraints", {}).get("allowed_models", [])
if allowed_models and model not in allowed_models:
raise CapabilityError(f"Model {model} not allowed in capability")
# Check rate limits (per-minute window)
max_requests_per_minute = llm_capability.get("constraints", {}).get("max_requests_per_minute")
if max_requests_per_minute:
await self._check_user_rate_limit(user_id, max_requests_per_minute)
except CapabilityError:
raise
except Exception as e:
raise CapabilityError(f"Capability verification failed: {e}")
async def _check_rate_limits(self, user_id: str, model_config: ModelConfig) -> None:
"""Check if user is within rate limits for model"""
now = time.time()
minute_ago = now - 60
# Initialize user rate limit tracking
if user_id not in self.rate_limits:
self.rate_limits[user_id] = {}
if model_config.model_id not in self.rate_limits[user_id]:
self.rate_limits[user_id][model_config.model_id] = []
user_requests = self.rate_limits[user_id][model_config.model_id]
# Remove old requests
user_requests[:] = [req_time for req_time in user_requests if req_time > minute_ago]
# Check limit
if len(user_requests) >= model_config.rate_limit_rpm:
raise ValueError(f"Rate limit exceeded for model {model_config.model_id}")
# Add current request
user_requests.append(now)
async def _check_user_rate_limit(self, user_id: str, max_requests_per_minute: int) -> None:
"""
Check user's rate limit with per-minute enforcement window.
Enforces limits from Control Panel database (single source of truth).
Time window: 60 seconds (not 1 hour).
Args:
user_id: User identifier
max_requests_per_minute: Maximum requests allowed in 60-second window
Raises:
ValueError: If rate limit exceeded
"""
now = time.time()
minute_ago = now - 60 # 60-second window (was 3600 for hour)
if user_id not in self.rate_limits:
self.rate_limits[user_id] = {}
if "total_requests" not in self.rate_limits[user_id]:
self.rate_limits[user_id]["total_requests"] = []
total_requests = self.rate_limits[user_id]["total_requests"]
# Remove requests outside the 60-second window
total_requests[:] = [req_time for req_time in total_requests if req_time > minute_ago]
# Check limit
if len(total_requests) >= max_requests_per_minute:
raise ValueError(
f"Rate limit exceeded: {max_requests_per_minute} requests per minute. "
f"Try again in {int(60 - (now - total_requests[0]))} seconds."
)
# Add current request
total_requests.append(now)
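    # Window sketch: with max_requests_per_minute=60, the 61st request inside
    # any rolling 60-second span raises ValueError, and the error message
    # reports how many seconds remain until the oldest request leaves the window.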
async def _process_groq_request(
self,
request: LLMRequest,
request_id: str,
model_config: ModelConfig,
tenant_id: str
) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""
Process request using Groq API with tenant-specific API key.
API keys are fetched from Control Panel database - NO environment variable fallback.
"""
try:
# Get API key from Control Panel database (NO env fallback)
api_key = await self._get_tenant_api_key(tenant_id)
# Prepare Groq API request
groq_request = {
"model": request.model,
"messages": request.messages,
"max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
"temperature": request.temperature or 0.7,
"top_p": request.top_p or 1.0,
"stream": request.stream
}
if request.stop:
groq_request["stop"] = request.stop
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
if request.stream:
return self._stream_groq_response(groq_request, headers, request_id)
else:
return await self._get_groq_response(groq_request, headers, request_id)
except Exception as e:
logger.error(f"Groq API request failed: {e}")
raise ValueError(f"Groq API error: {e}")
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
raise ValueError(f"No Groq API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
async def _get_tenant_nvidia_api_key(self, tenant_id: str) -> str:
"""
Get NVIDIA NIM API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
raise ValueError(f"No NVIDIA API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
async def _get_groq_response(
self,
groq_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> LLMResponse:
"""Get non-streaming response from Groq"""
try:
response = await self.http_client.post(
"https://api.groq.com/openai/v1/chat/completions",
json=groq_request,
headers=headers
)
response.raise_for_status()
data = response.json()
# Convert to standardized format
return LLMResponse(
id=data.get("id", request_id),
object=data.get("object", "chat.completion"),
created=data.get("created", int(time.time())),
model=data.get("model", groq_request["model"]),
choices=data.get("choices", []),
usage=data.get("usage", {}),
provider="groq",
request_id=request_id
)
except httpx.HTTPStatusError as e:
logger.error(f"Groq API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"Groq API error: {e.response.status_code}")
except Exception as e:
logger.error(f"Groq API error: {e}")
raise ValueError(f"Groq API request failed: {e}")
async def _stream_groq_response(
self,
groq_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> AsyncGenerator[str, None]:
"""Stream response from Groq"""
try:
async with self.http_client.stream(
"POST",
"https://api.groq.com/openai/v1/chat/completions",
json=groq_request,
headers=headers
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
# Add provider and request_id to chunk
data["provider"] = "groq"
data["request_id"] = request_id
yield f"data: {json.dumps(data)}\n\n"
except json.JSONDecodeError:
continue
yield "data: [DONE]\n\n"
except httpx.HTTPStatusError as e:
logger.error(f"Groq streaming error: {e.response.status_code}")
yield f"data: {json.dumps({'error': f'Groq API error: {e.response.status_code}'})}\n\n"
except Exception as e:
logger.error(f"Groq streaming error: {e}")
yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"
async def _process_generic_request(
self,
request: LLMRequest,
request_id: str,
model_config: Any,
tenant_id: str
    ) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""
Process request using generic endpoint (OpenAI-compatible).
For Groq endpoints, API keys are fetched from Control Panel database.
"""
try:
# Build OpenAI-compatible request
generic_request = {
"model": request.model,
"messages": request.messages,
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"stream": request.stream
}
# Add optional parameters
if hasattr(request, 'tools') and request.tools:
generic_request["tools"] = request.tools
if hasattr(request, 'tool_choice') and request.tool_choice:
generic_request["tool_choice"] = request.tool_choice
headers = {"Content-Type": "application/json"}
endpoint_url = model_config.endpoint
# For Groq endpoints, use tenant-specific API key from Control Panel DB
if is_provider_endpoint(endpoint_url, ["groq.com"]):
api_key = await self._get_tenant_api_key(tenant_id)
headers["Authorization"] = f"Bearer {api_key}"
# For NVIDIA NIM endpoints, use tenant-specific API key from Control Panel DB
elif is_provider_endpoint(endpoint_url, ["nvidia.com", "integrate.api.nvidia.com"]):
api_key = await self._get_tenant_nvidia_api_key(tenant_id)
headers["Authorization"] = f"Bearer {api_key}"
# For other endpoints, use model_config.api_key if configured
elif hasattr(model_config, 'api_key') and model_config.api_key:
headers["Authorization"] = f"Bearer {model_config.api_key}"
logger.info(f"Sending request to generic endpoint: {endpoint_url}")
            if request.stream:
                # _stream_generic_response is an async generator: return it
                # directly instead of awaiting it (awaiting would raise TypeError)
                return self._stream_generic_response(generic_request, headers, endpoint_url, request_id, model_config)
else:
return await self._get_generic_response(generic_request, headers, endpoint_url, request_id, model_config)
except Exception as e:
logger.error(f"Generic request processing failed: {e}")
raise ValueError(f"Generic inference failed: {e}")
async def _get_generic_response(
self,
generic_request: Dict[str, Any],
headers: Dict[str, str],
endpoint_url: str,
request_id: str,
model_config: Any
) -> LLMResponse:
"""Get non-streaming response from generic endpoint"""
try:
response = await self.http_client.post(
endpoint_url,
json=generic_request,
headers=headers
)
response.raise_for_status()
data = response.json()
# Convert to standardized format
return LLMResponse(
id=data.get("id", request_id),
object=data.get("object", "chat.completion"),
created=data.get("created", int(time.time())),
model=data.get("model", generic_request["model"]),
choices=data.get("choices", []),
usage=data.get("usage", {}),
provider=getattr(model_config, 'provider', 'generic'),
request_id=request_id
)
except httpx.HTTPStatusError as e:
logger.error(f"Generic API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"Generic API error: {e.response.status_code}")
except Exception as e:
logger.error(f"Generic response error: {e}")
raise ValueError(f"Generic response processing failed: {e}")
async def _stream_generic_response(
self,
generic_request: Dict[str, Any],
headers: Dict[str, str],
endpoint_url: str,
request_id: str,
model_config: Any
):
"""Stream response from generic endpoint"""
try:
# For now, just do a non-streaming request and convert to streaming format
# This can be enhanced to support actual streaming later
response = await self._get_generic_response(generic_request, headers, endpoint_url, request_id, model_config)
# Convert to streaming format
if response.choices and len(response.choices) > 0:
content = response.choices[0].get("message", {}).get("content", "")
yield f"data: {json.dumps({'choices': [{'delta': {'content': content}}]})}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Generic streaming error: {e}")
yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"
async def _process_openai_request(
self,
request: LLMRequest,
request_id: str,
model_config: ModelConfig
) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""Process request using OpenAI API"""
try:
# Prepare OpenAI API request
openai_request = {
"model": request.model,
"messages": request.messages,
"max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
"temperature": request.temperature or 0.7,
"top_p": request.top_p or 1.0,
"stream": request.stream
}
if request.stop:
openai_request["stop"] = request.stop
headers = {
"Authorization": f"Bearer {settings.openai_api_key}",
"Content-Type": "application/json"
}
if request.stream:
return self._stream_openai_response(openai_request, headers, request_id)
else:
return await self._get_openai_response(openai_request, headers, request_id)
except Exception as e:
logger.error(f"OpenAI API request failed: {e}")
raise ValueError(f"OpenAI API error: {e}")
async def _get_openai_response(
self,
openai_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> LLMResponse:
"""Get non-streaming response from OpenAI"""
try:
response = await self.http_client.post(
"https://api.openai.com/v1/chat/completions",
json=openai_request,
headers=headers
)
response.raise_for_status()
data = response.json()
# Convert to standardized format
return LLMResponse(
id=data.get("id", request_id),
object=data.get("object", "chat.completion"),
created=data.get("created", int(time.time())),
model=data.get("model", openai_request["model"]),
choices=data.get("choices", []),
usage=data.get("usage", {}),
provider="openai",
request_id=request_id
)
except httpx.HTTPStatusError as e:
logger.error(f"OpenAI API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"OpenAI API error: {e.response.status_code}")
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise ValueError(f"OpenAI API request failed: {e}")
async def _stream_openai_response(
self,
openai_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> AsyncGenerator[str, None]:
"""Stream response from OpenAI"""
try:
async with self.http_client.stream(
"POST",
"https://api.openai.com/v1/chat/completions",
json=openai_request,
headers=headers
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
# Add provider and request_id to chunk
data["provider"] = "openai"
data["request_id"] = request_id
yield f"data: {json.dumps(data)}\n\n"
except json.JSONDecodeError:
continue
yield "data: [DONE]\n\n"
except httpx.HTTPStatusError as e:
logger.error(f"OpenAI streaming error: {e.response.status_code}")
yield f"data: {json.dumps({'error': f'OpenAI API error: {e.response.status_code}'})}\n\n"
except Exception as e:
logger.error(f"OpenAI streaming error: {e}")
yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"
async def _update_stats(
self,
model: str,
provider: ModelProvider,
latency: float,
success: bool
) -> None:
"""Update request statistics"""
self.stats["total_requests"] += 1
if success:
self.stats["successful_requests"] += 1
else:
self.stats["failed_requests"] += 1
self.stats["provider_usage"][provider.value] += 1
if model not in self.stats["model_usage"]:
self.stats["model_usage"][model] = 0
self.stats["model_usage"][model] += 1
# Update rolling average latency
total_requests = self.stats["total_requests"]
current_avg = self.stats["average_latency"]
self.stats["average_latency"] = ((current_avg * (total_requests - 1)) + latency) / total_requests
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available models with capabilities"""
models = []
for model_id, config in self.models.items():
if config.is_available:
models.append({
"id": model_id,
"provider": config.provider.value,
"capabilities": [cap.value for cap in config.capabilities],
"max_tokens": config.max_tokens,
"context_window": config.context_window,
"supports_streaming": config.supports_streaming,
"supports_functions": config.supports_functions
})
return models
async def get_gateway_stats(self) -> Dict[str, Any]:
"""Get gateway statistics"""
return {
**self.stats,
"provider_health": {
provider.value: health
for provider, health in self.provider_health.items()
},
"active_models": len([m for m in self.models.values() if m.is_available]),
"timestamp": datetime.now(timezone.utc).isoformat()
}
async def health_check(self) -> Dict[str, Any]:
"""Health check for the LLM gateway"""
healthy_providers = sum(1 for health in self.provider_health.values() if health)
total_providers = len(self.provider_health)
return {
"status": "healthy" if healthy_providers > 0 else "degraded",
"providers_healthy": healthy_providers,
"total_providers": total_providers,
"available_models": len([m for m in self.models.values() if m.is_available]),
"total_requests": self.stats["total_requests"],
"success_rate": (
self.stats["successful_requests"] / max(self.stats["total_requests"], 1)
) * 100,
"average_latency_ms": self.stats["average_latency"] * 1000
}
async def close(self):
"""Close HTTP client and cleanup resources"""
await self.http_client.aclose()
# Global gateway instance
llm_gateway = LLMGateway()
# Factory function for dependency injection
def get_llm_gateway() -> LLMGateway:
"""Get LLM gateway instance"""
return llm_gateway
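# Example (illustrative only): a minimal sketch that inspects gateway health
# and lists the registered models. Both methods are defined on LLMGateway
# above; no hypothetical entry points are assumed.
#
# import asyncio
#
# async def _example_status():
#     gateway = get_llm_gateway()
#     print(await gateway.health_check())
#     for model in await gateway.get_available_models():
#         print(model["id"], model["capabilities"])
#
# asyncio.run(_example_status())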

View File

@@ -0,0 +1,599 @@
"""
GT 2.0 MCP RAG Server
Provides RAG (Retrieval-Augmented Generation) capabilities as an MCP server.
Agents can use this server to search datasets, query documents, and retrieve
relevant context for user queries.
Tools provided:
- search_datasets: Search across user's accessible datasets
- query_documents: Query specific documents for relevant chunks
- get_relevant_chunks: Get relevant text chunks based on similarity
- list_user_datasets: List all datasets accessible to the user
- get_dataset_info: Get detailed information about a dataset
"""
import asyncio
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from dataclasses import dataclass
import httpx
import json
from app.core.security import verify_capability_token
from app.services.mcp_server import MCPServerResource, MCPServerConfig
logger = logging.getLogger(__name__)
@dataclass
class RAGSearchParams:
"""Parameters for RAG search operations"""
query: str
dataset_ids: Optional[List[str]] = None
search_method: str = "hybrid" # hybrid, vector, text
max_results: int = 10
similarity_threshold: float = 0.7
include_metadata: bool = True
@dataclass
class RAGSearchResult:
"""Result from RAG search operation"""
chunk_id: str
document_id: str
dataset_id: str
dataset_name: str
document_name: str
content: str
similarity_score: float
chunk_index: int
metadata: Dict[str, Any]
class MCPRAGServer:
"""
MCP server for RAG operations in GT 2.0.
Provides secure, tenant-isolated access to document search capabilities
through standardized MCP tool interfaces.
"""
def __init__(self, tenant_backend_url: str = "http://tenant-backend:8000"):
self.tenant_backend_url = tenant_backend_url
self.server_name = "rag_server"
self.server_type = "rag"
# Define available tools (streamlined for simplicity)
self.available_tools = [
"search_datasets"
]
# Tool schemas for MCP protocol (enhanced with flexible parameters)
self.tool_schemas = {
"search_datasets": {
"name": "search_datasets",
"description": "Search through datasets containing uploaded documents, PDFs, and files. Use when users ask about documentation, reference materials, checking files, looking up information, need data from uploaded content, want to know what's in the dataset, search our data, check if we have something, or look through files.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "What to search for in the datasets"
},
"dataset_ids": {
"type": "array",
"items": {"type": "string"},
"description": "(Optional) List of specific dataset IDs to search within"
},
"file_pattern": {
"type": "string",
"description": "(Optional) File pattern filter (e.g., '*.pdf', '*.txt')"
},
"search_all": {
"type": "boolean",
"default": False,
"description": "(Optional) Search across all accessible datasets (ignores dataset_ids)"
},
"max_results": {
"type": "integer",
"default": 10,
"description": "(Optional) Number of results to return (default: 10)"
}
},
"required": ["query"]
}
}
}
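# Example (illustrative only) of tool-call arguments that satisfy the
# schema above; "query" is the only required field:
#   {"query": "onboarding policy", "max_results": 5, "search_all": True}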
async def handle_tool_call(
self,
tool_name: str,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str,
agent_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Handle MCP tool call with tenant isolation and user context.
Args:
tool_name: Name of the tool being called
parameters: Tool parameters from the LLM
tenant_domain: Tenant domain for isolation
user_id: User making the request
agent_context: Optional agent context (e.g. selected_dataset_ids) used for dataset defaults
Returns:
Tool execution result or error
"""
logger.info(f"🚀 MCP RAG Server: handle_tool_call called - tool={tool_name}, tenant={tenant_domain}, user={user_id}")
logger.info(f"📝 MCP RAG Server: parameters={parameters}")
try:
# Validate tool exists
if tool_name not in self.available_tools:
return {
"error": f"Unknown tool: {tool_name}",
"tool_name": tool_name
}
# Route to appropriate handler
if tool_name == "search_datasets":
return await self._search_datasets(parameters, tenant_domain, user_id, agent_context)
else:
return {
"error": f"Tool handler not implemented: {tool_name}",
"tool_name": tool_name
}
except Exception as e:
logger.error(f"Error handling tool call {tool_name}: {e}")
return {
"error": f"Tool execution failed: {str(e)}",
"tool_name": tool_name
}
def _verify_user_access(self, user_id: str, tenant_domain: str) -> bool:
"""Verify user has access to tenant resources (simplified check)"""
# In a real system, this would query the database to verify
# that the user has access to the tenant's resources
# For now, we trust that the tenant backend has already verified this
return bool(user_id and tenant_domain)
async def _search_datasets(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str,
agent_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Search across user's datasets"""
logger.info(f"🔍 RAG Server: search_datasets called for user {user_id} in tenant {tenant_domain}")
logger.info(f"📝 RAG Server: search parameters = {parameters}")
logger.info(f"📝 RAG Server: parameter types: {[(k, type(v)) for k, v in parameters.items()]}")
try:
query = parameters.get("query", "").strip()
list_mode = parameters.get("list_mode", False)
# Handle list mode - list available datasets instead of searching
if list_mode:
logger.info(f"🔍 RAG Server: List mode activated - fetching available datasets")
async with httpx.AsyncClient(timeout=15.0) as client:
response = await client.get(
f"{self.tenant_backend_url}/api/v1/datasets/internal/list",
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
datasets = response.json()
logger.info(f"✅ RAG Server: Successfully listed {len(datasets)} datasets")
return {
"success": True,
"datasets": datasets,
"total_count": len(datasets),
"list_mode": True
}
else:
logger.error(f"❌ RAG Server: Failed to list datasets: {response.status_code} - {response.text}")
return {"error": f"Failed to list datasets: {response.status_code}"}
# Normal search mode
if not query:
logger.error("❌ RAG Server: Query parameter is required")
return {"error": "Query parameter is required"}
# Prepare search request with enhanced parameters
dataset_ids = parameters.get("dataset_ids")
file_pattern = parameters.get("file_pattern")
search_all = parameters.get("search_all", False)
# Handle legacy dataset_id parameter (backwards compatibility)
if dataset_ids is None and parameters.get("dataset_id"):
dataset_ids = [parameters.get("dataset_id")]
# Ensure dataset_ids is properly formatted
if dataset_ids is None:
dataset_ids = []
elif isinstance(dataset_ids, str):
dataset_ids = [dataset_ids]
# If search_all is True, ignore dataset_ids filter
if search_all:
dataset_ids = []
# AGENT-AWARE: If no datasets specified, use agent's configured datasets
if not dataset_ids and not search_all and agent_context:
agent_dataset_ids = agent_context.get('selected_dataset_ids', [])
if agent_dataset_ids:
dataset_ids = agent_dataset_ids
agent_name = agent_context.get('agent_name', 'Unknown')
logger.info(f"✅ RAG Server: Using agent '{agent_name}' datasets: {dataset_ids}")
else:
logger.warning(f"⚠️ RAG Server: Agent context available but no datasets configured")
elif not dataset_ids and not search_all:
logger.warning(f"⚠️ RAG Server: No dataset_ids provided and no agent context available")
search_request = {
"query": query,
"search_type": parameters.get("search_method", "hybrid"),
"max_results": parameters.get("max_results", 10), # No arbitrary cap
"dataset_ids": dataset_ids,
"min_similarity": 0.3
}
# Add file_pattern if provided
if file_pattern:
search_request["file_pattern"] = file_pattern
logger.info(f"🎯 RAG Server: prepared search request = {search_request}")
# Call tenant backend search API
logger.info(f"🌐 RAG Server: calling tenant backend at {self.tenant_backend_url}/api/v1/search/")
logger.info(f"🌐 RAG Server: request headers: X-Tenant-Domain='{tenant_domain}', X-User-ID='{user_id}'")
logger.info(f"🌐 RAG Server: request body: {search_request}")
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.tenant_backend_url}/api/v1/search/",
json=search_request,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id,
"Content-Type": "application/json"
}
)
logger.info(f"📊 RAG Server: tenant backend response: {response.status_code}")
if response.status_code != 200:
logger.error(f"📊 RAG Server: error response body: {response.text}")
if response.status_code == 200:
data = response.json()
logger.info(f"✅ RAG Server: search successful, got {len(data.get('results', []))} results")
# Format results for MCP response
results = []
for result in data.get("results", []):
results.append({
"chunk_id": result.get("chunk_id"),
"document_id": result.get("document_id"),
"dataset_id": result.get("dataset_id"),
"content": result.get("text", ""),
"similarity_score": result.get("hybrid_score", 0.0),
"metadata": result.get("metadata", {})
})
return {
"success": True,
"query": query,
"results_count": len(results),
"results": results,
"search_method": data.get("search_type", "hybrid")
}
else:
error_text = response.text
logger.error(f"❌ RAG Server: search failed: {response.status_code} - {error_text}")
return {
"error": f"Search failed: {response.status_code} - {error_text}",
"query": query
}
except Exception as e:
logger.error(f"Dataset search error: {e}")
return {
"error": f"Search operation failed: {str(e)}",
"query": parameters.get("query", "")
}
async def _query_documents(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""Query specific documents for relevant chunks"""
try:
query = parameters.get("query", "").strip()
document_ids = parameters.get("document_ids", [])
if not query or not document_ids:
return {"error": "Both query and document_ids are required"}
# Use search API with document ID filter
search_request = {
"query": query,
"search_type": "hybrid",
"max_results": parameters.get("max_results", 5),
"document_ids": document_ids
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.tenant_backend_url}/api/v1/search/documents",
json=search_request,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id,
"Content-Type": "application/json"
}
)
if response.status_code == 200:
data = response.json()
return {
"success": True,
"query": query,
"document_ids": document_ids,
"results": data.get("results", [])
}
else:
return {
"error": f"Document query failed: {response.status_code}",
"query": query,
"document_ids": document_ids
}
except Exception as e:
return {
"error": f"Document query failed: {str(e)}",
"query": parameters.get("query", "")
}
async def _list_user_datasets(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""List user's accessible datasets"""
try:
include_stats = parameters.get("include_stats", True)
async with httpx.AsyncClient(timeout=15.0) as client:
params = {"include_stats": include_stats}
response = await client.get(
f"{self.tenant_backend_url}/api/v1/datasets",
params=params,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
data = response.json()
datasets = data.get("data", []) if isinstance(data, dict) else data
# Format for MCP response
formatted_datasets = []
for dataset in datasets:
formatted_datasets.append({
"id": dataset.get("id"),
"name": dataset.get("name"),
"description": dataset.get("description"),
"document_count": dataset.get("document_count", 0),
"chunk_count": dataset.get("chunk_count", 0),
"created_at": dataset.get("created_at"),
"access_group": dataset.get("access_group", "individual")
})
return {
"success": True,
"datasets": formatted_datasets,
"total_count": len(formatted_datasets)
}
else:
return {
"error": f"Failed to list datasets: {response.status_code}"
}
except Exception as e:
return {
"error": f"Failed to list datasets: {str(e)}"
}
async def _get_dataset_info(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""Get detailed information about a dataset"""
try:
dataset_id = parameters.get("dataset_id")
if not dataset_id:
return {"error": "dataset_id parameter is required"}
async with httpx.AsyncClient(timeout=15.0) as client:
response = await client.get(
f"{self.tenant_backend_url}/api/v1/datasets/{dataset_id}",
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
data = response.json()
dataset = data.get("data", data)
return {
"success": True,
"dataset": {
"id": dataset.get("id"),
"name": dataset.get("name"),
"description": dataset.get("description"),
"document_count": dataset.get("document_count", 0),
"chunk_count": dataset.get("chunk_count", 0),
"vector_count": dataset.get("vector_count", 0),
"storage_size_mb": dataset.get("storage_size_mb", 0),
"created_at": dataset.get("created_at"),
"updated_at": dataset.get("updated_at"),
"access_group": dataset.get("access_group"),
"tags": dataset.get("tags", [])
}
}
elif response.status_code == 404:
return {
"error": f"Dataset not found: {dataset_id}"
}
else:
return {
"error": f"Failed to get dataset info: {response.status_code}"
}
except Exception as e:
return {
"error": f"Failed to get dataset info: {str(e)}"
}
async def _get_user_agent_datasets(self, tenant_domain: str, user_id: str) -> List[str]:
"""Auto-detect agent datasets for the current user"""
try:
# Get user's agents and their configured datasets
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"{self.tenant_backend_url}/api/v1/agents",
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
agents_data = response.json()
agents = agents_data.get("data", []) if isinstance(agents_data, dict) else agents_data
# Collect all dataset IDs from all user's agents
all_dataset_ids = set()
for agent in agents:
agent_dataset_ids = agent.get("selected_dataset_ids", [])
if agent_dataset_ids:
all_dataset_ids.update(agent_dataset_ids)
logger.info(f"🔍 RAG Server: Agent {agent.get('name', 'unknown')} has datasets: {agent_dataset_ids}")
return list(all_dataset_ids)
else:
logger.warning(f"⚠️ RAG Server: Failed to get agents: {response.status_code}")
return []
except Exception as e:
logger.error(f"❌ RAG Server: Error getting user agent datasets: {e}")
return []
async def _get_relevant_chunks(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""Get most relevant chunks for a query"""
try:
query = parameters.get("query", "").strip()
if not query:
return {"error": "query parameter is required"}
chunk_count = min(parameters.get("chunk_count", 3), 10) # Cap at 10
min_similarity = parameters.get("min_similarity", 0.6)
dataset_ids = parameters.get("dataset_ids")
search_request = {
"query": query,
"search_type": "vector", # Use vector search for relevance
"max_results": chunk_count,
"min_similarity": min_similarity,
"dataset_ids": dataset_ids
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.tenant_backend_url}/api/v1/search",
json=search_request,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id,
"Content-Type": "application/json"
}
)
if response.status_code == 200:
data = response.json()
chunks = []
for result in data.get("results", []):
chunks.append({
"chunk_id": result.get("chunk_id"),
"document_id": result.get("document_id"),
"dataset_id": result.get("dataset_id"),
"content": result.get("text", ""),
"similarity_score": result.get("vector_similarity", 0.0),
"chunk_index": result.get("rank", 0),
"metadata": result.get("metadata", {})
})
return {
"success": True,
"query": query,
"chunks": chunks,
"chunk_count": len(chunks),
"min_similarity": min_similarity
}
else:
return {
"error": f"Chunk retrieval failed: {response.status_code}"
}
except Exception as e:
return {
"error": f"Failed to get relevant chunks: {str(e)}"
}
def get_server_config(self) -> MCPServerConfig:
"""Get MCP server configuration"""
return MCPServerConfig(
server_name=self.server_name,
server_url="internal://mcp-rag-server",
server_type=self.server_type,
available_tools=self.available_tools,
required_capabilities=["mcp:rag:*"],
sandbox_mode=True,
max_memory_mb=256,
max_cpu_percent=25,
timeout_seconds=30,
network_isolation=False, # Needs to access tenant backend
max_requests_per_minute=120,
max_concurrent_requests=10
)
def get_tool_schemas(self) -> Dict[str, Any]:
"""Get MCP tool schemas for this server"""
return self.tool_schemas
# Global instance
mcp_rag_server = MCPRAGServer()
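# Example (illustrative only): a minimal sketch of calling the search tool
# directly. Tenant and user values are placeholders; handle_tool_call is the
# entry point defined above.
#
# async def _example_search():
#     result = await mcp_rag_server.handle_tool_call(
#         tool_name="search_datasets",
#         parameters={"query": "quarterly revenue", "max_results": 5},
#         tenant_domain="acme.example",
#         user_id="user-123",
#     )
#     print(result.get("results_count"), result.get("results"))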

View File

@@ -0,0 +1,491 @@
"""
MCP Sandbox Service for GT 2.0
Provides secure sandboxed execution environment for MCP servers.
Implements resource isolation, monitoring, and security constraints.
"""
import os
import asyncio
import resource
import signal
import tempfile
import shutil
from typing import Dict, Any, Optional, Callable, Tuple
from datetime import datetime, timedelta
from pathlib import Path
import logging
import json
import psutil
from contextlib import asynccontextmanager
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class SandboxConfig:
"""Configuration for sandbox environment"""
# Resource limits
max_memory_mb: int = 512
max_cpu_percent: int = 50
max_disk_mb: int = 100
timeout_seconds: int = 30
# Security settings
network_isolation: bool = True
readonly_filesystem: bool = False
allowed_paths: list = None
blocked_paths: list = None
allowed_commands: list = None
# Process limits
max_processes: int = 10
max_open_files: int = 100
max_threads: int = 20
def __post_init__(self):
if self.allowed_paths is None:
self.allowed_paths = ["/tmp", "/var/tmp"]
if self.blocked_paths is None:
self.blocked_paths = ["/etc", "/root", "/home", "/usr/bin", "/usr/sbin"]
if self.allowed_commands is None:
self.allowed_commands = ["ls", "cat", "grep", "find", "echo", "pwd"]
class ProcessSandbox:
"""
Process-level sandbox for MCP tool execution
Uses OS-level isolation and resource limits
"""
def __init__(self, config: SandboxConfig):
self.config = config
self.process: Optional[asyncio.subprocess.Process] = None
self.start_time: Optional[datetime] = None
self.temp_dir: Optional[Path] = None
self.resource_monitor_task: Optional[asyncio.Task] = None
async def __aenter__(self):
"""Enter sandbox context"""
await self.setup()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit sandbox context and cleanup"""
await self.cleanup()
async def setup(self):
"""Setup sandbox environment"""
# Create temporary directory for sandbox
self.temp_dir = Path(tempfile.mkdtemp(prefix="mcp_sandbox_"))
os.chmod(self.temp_dir, 0o700) # Restrict access
# Set resource limits for child processes
self._set_resource_limits()
# Start resource monitoring
self.resource_monitor_task = asyncio.create_task(self._monitor_resources())
self.start_time = datetime.utcnow()
logger.info(f"Sandbox setup complete: {self.temp_dir}")
async def cleanup(self):
"""Cleanup sandbox environment"""
# Stop resource monitoring
if self.resource_monitor_task:
self.resource_monitor_task.cancel()
try:
await self.resource_monitor_task
except asyncio.CancelledError:
pass
# Terminate process if still running
if self.process and self.process.returncode is None:
try:
self.process.terminate()
await asyncio.wait_for(self.process.wait(), timeout=5)
except asyncio.TimeoutError:
self.process.kill()
await self.process.wait()
# Remove temporary directory
if self.temp_dir and self.temp_dir.exists():
shutil.rmtree(self.temp_dir, ignore_errors=True)
logger.info("Sandbox cleanup complete")
async def execute(
self,
command: str,
args: list = None,
input_data: str = None,
env: Dict[str, str] = None
) -> Tuple[int, str, str]:
"""
Execute command in sandbox
Args:
command: Command to execute
args: Command arguments
input_data: Input to send to process
env: Environment variables
Returns:
Tuple of (return_code, stdout, stderr)
"""
# Validate command
if not self._validate_command(command):
raise PermissionError(f"Command not allowed: {command}")
# Prepare environment
sandbox_env = self._prepare_environment(env)
# Prepare command with arguments
full_command = [command] + (args or [])
try:
# Create process with resource limits
self.process = await asyncio.create_subprocess_exec(
*full_command,
stdin=asyncio.subprocess.PIPE if input_data else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=str(self.temp_dir),
env=sandbox_env,
preexec_fn=self._set_process_limits if os.name == 'posix' else None
)
# Execute with timeout
stdout, stderr = await asyncio.wait_for(
self.process.communicate(input=input_data.encode() if input_data else None),
timeout=self.config.timeout_seconds
)
return self.process.returncode, stdout.decode(), stderr.decode()
except asyncio.TimeoutError:
if self.process:
self.process.kill()
await self.process.wait()
raise TimeoutError(f"Command exceeded {self.config.timeout_seconds}s timeout")
except Exception as e:
logger.error(f"Sandbox execution error: {e}")
raise
async def execute_function(
self,
func: Callable,
*args,
**kwargs
) -> Any:
"""
Execute Python function in sandbox
Uses multiprocessing for isolation
"""
import multiprocessing
import pickle
# Create pipe for communication
parent_conn, child_conn = multiprocessing.Pipe()
def sandbox_wrapper(conn, func, args, kwargs):
"""Wrapper to execute function in child process"""
try:
# Apply resource limits
self._set_process_limits()
# Execute function
result = func(*args, **kwargs)
# Send result back
conn.send(("success", pickle.dumps(result)))
except Exception as e:
conn.send(("error", str(e)))
finally:
conn.close()
# Create and start process
process = multiprocessing.Process(
target=sandbox_wrapper,
args=(child_conn, func, args, kwargs)
)
process.start()
# Wait for result with timeout
try:
if parent_conn.poll(self.config.timeout_seconds):
status, data = parent_conn.recv()
if status == "success":
return pickle.loads(data)
else:
raise RuntimeError(f"Sandbox function error: {data}")
else:
process.terminate()
process.join(timeout=5)
if process.is_alive():
process.kill()
raise TimeoutError(f"Function exceeded {self.config.timeout_seconds}s timeout")
finally:
parent_conn.close()
if process.is_alive():
process.terminate()
process.join()
def _validate_command(self, command: str) -> bool:
"""Validate if command is allowed"""
# Check if command is in allowed list
command_name = os.path.basename(command)
if self.config.allowed_commands and command_name not in self.config.allowed_commands:
return False
# Check for dangerous patterns
dangerous_patterns = [
"rm -rf",
"dd if=",
"mkfs",
"format",
">", # Redirect that could overwrite files
"|", # Pipe that could chain commands
";", # Command separator
"&", # Background execution
"`", # Command substitution
"$(" # Command substitution
]
for pattern in dangerous_patterns:
if pattern in command:
return False
return True
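# Examples (illustrative): _validate_command("echo") -> True (in the allowed
# list, no dangerous patterns); _validate_command("sh -c 'ls | wc'") -> False
# (not in the allowed list, and contains the "|" pattern).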
def _prepare_environment(self, custom_env: Dict[str, str] = None) -> Dict[str, str]:
"""Prepare sandboxed environment variables"""
# Start with minimal environment
sandbox_env = {
"PATH": "/usr/local/bin:/usr/bin:/bin",
"HOME": str(self.temp_dir),
"TEMP": str(self.temp_dir),
"TMP": str(self.temp_dir),
"USER": "sandbox",
"SHELL": "/bin/sh"
}
# Add custom environment variables if provided
if custom_env:
# Filter out dangerous variables
dangerous_vars = ["LD_PRELOAD", "LD_LIBRARY_PATH", "PYTHONPATH", "PATH"]
for key, value in custom_env.items():
if key not in dangerous_vars:
sandbox_env[key] = value
return sandbox_env
def _set_resource_limits(self):
"""Set resource limits for the process"""
if os.name != 'posix':
return # Resource limits only work on POSIX systems
# Memory limit
memory_bytes = self.config.max_memory_mb * 1024 * 1024
resource.setrlimit(resource.RLIMIT_AS, (memory_bytes, memory_bytes))
# CPU time limit
resource.setrlimit(resource.RLIMIT_CPU, (self.config.timeout_seconds, self.config.timeout_seconds))
# File size limit
file_size_bytes = self.config.max_disk_mb * 1024 * 1024
resource.setrlimit(resource.RLIMIT_FSIZE, (file_size_bytes, file_size_bytes))
# Process limit
resource.setrlimit(resource.RLIMIT_NPROC, (self.config.max_processes, self.config.max_processes))
# Open files limit
resource.setrlimit(resource.RLIMIT_NOFILE, (self.config.max_open_files, self.config.max_open_files))
def _set_process_limits(self):
"""Set limits for child process (called in child context)"""
if os.name != 'posix':
return
# Drop privileges if running as root (shouldn't happen in production);
# drop the group first, since setuid() revokes the right to call setgid()
if os.getuid() == 0:
    os.setgid(65534)  # nogroup
    os.setuid(65534)  # nobody user
# Set resource limits
self._set_resource_limits()
# Set process group for easier cleanup
os.setpgrp()
async def _monitor_resources(self):
"""Monitor resource usage of sandboxed process"""
while True:
try:
if self.process and self.process.returncode is None:
# Get process info
try:
proc = psutil.Process(self.process.pid)
# Check CPU usage
cpu_percent = proc.cpu_percent(interval=0.1)
if cpu_percent > self.config.max_cpu_percent:
logger.warning(f"Sandbox CPU usage high: {cpu_percent}%")
# Could throttle or terminate if consistently high
# Check memory usage
memory_info = proc.memory_info()
memory_mb = memory_info.rss / (1024 * 1024)
if memory_mb > self.config.max_memory_mb:
logger.warning(f"Sandbox memory limit exceeded: {memory_mb}MB")
self.process.terminate()
break
# Check runtime
if self.start_time:
runtime = (datetime.utcnow() - self.start_time).total_seconds()
if runtime > self.config.timeout_seconds:
logger.warning(f"Sandbox timeout exceeded: {runtime}s")
self.process.terminate()
break
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass # Process ended or inaccessible
await asyncio.sleep(1) # Check every second
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Resource monitoring error: {e}")
await asyncio.sleep(1)
class ContainerSandbox:
"""
Container-based sandbox for stronger isolation
Uses Docker or Podman for execution
"""
def __init__(self, config: SandboxConfig):
self.config = config
self.container_id: Optional[str] = None
self.container_runtime = self._detect_container_runtime()
def _detect_container_runtime(self) -> str:
"""Detect available container runtime"""
# Try Docker first
if shutil.which("docker"):
return "docker"
# Try Podman as alternative
elif shutil.which("podman"):
return "podman"
else:
logger.warning("No container runtime found, falling back to process sandbox")
return None
@asynccontextmanager
async def create_container(self, image: str = "alpine:latest"):
"""Create and manage container lifecycle"""
if not self.container_runtime:
raise RuntimeError("No container runtime available")
try:
# Create container with resource limits
create_cmd = [
self.container_runtime, "create",
"--rm", # Auto-remove after stop
f"--memory={self.config.max_memory_mb}m",
f"--cpus={self.config.max_cpu_percent / 100}",
"--network=none" if self.config.network_isolation else "--network=bridge",
"--read-only" if self.config.readonly_filesystem else "",
f"--tmpfs=/tmp:size={self.config.max_disk_mb}m",
"--security-opt=no-new-privileges",
"--cap-drop=ALL", # Drop all capabilities
image,
"sleep", "infinity" # Keep container running
]
# Remove empty strings from command
create_cmd = [arg for arg in create_cmd if arg]
# Create container
proc = await asyncio.create_subprocess_exec(
*create_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise RuntimeError(f"Failed to create container: {stderr.decode()}")
self.container_id = stdout.decode().strip()
# Start container
start_cmd = [self.container_runtime, "start", self.container_id]
proc = await asyncio.create_subprocess_exec(*start_cmd)
await proc.wait()
logger.info(f"Container sandbox created: {self.container_id[:12]}")
yield self
finally:
# Cleanup container
if self.container_id:
stop_cmd = [self.container_runtime, "stop", self.container_id]
proc = await asyncio.create_subprocess_exec(*stop_cmd)
await proc.wait()
logger.info(f"Container sandbox cleaned up: {self.container_id[:12]}")
async def execute(self, command: str, args: list = None) -> Tuple[int, str, str]:
"""Execute command in container"""
if not self.container_id:
raise RuntimeError("Container not created")
exec_cmd = [
self.container_runtime, "exec",
self.container_id,
command
] + (args or [])
proc = await asyncio.create_subprocess_exec(
*exec_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=self.config.timeout_seconds
)
return proc.returncode, stdout.decode(), stderr.decode()
except asyncio.TimeoutError:
# Kill process in container
kill_cmd = [self.container_runtime, "exec", self.container_id, "kill", "-9", "-1"]
await asyncio.create_subprocess_exec(*kill_cmd)
raise TimeoutError(f"Command exceeded {self.config.timeout_seconds}s timeout")
# Factory function to get appropriate sandbox
def create_sandbox(config: SandboxConfig, prefer_container: bool = True) -> Any:
"""
Create appropriate sandbox based on availability and preference
Args:
config: Sandbox configuration
prefer_container: Prefer container over process sandbox
Returns:
ProcessSandbox or ContainerSandbox instance
"""
if prefer_container and shutil.which("docker"):
return ContainerSandbox(config)
elif prefer_container and shutil.which("podman"):
return ContainerSandbox(config)
else:
return ProcessSandbox(config)
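# Example (illustrative only): running an allowed command in a process
# sandbox. The default SandboxConfig permits "echo"; execute() returns
# (return_code, stdout, stderr) as defined above.
#
# async def _example_sandbox():
#     config = SandboxConfig(timeout_seconds=5)
#     async with create_sandbox(config, prefer_container=False) as sandbox:
#         code, out, err = await sandbox.execute("echo", ["hello"])
#         print(code, out.strip())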

View File

@@ -0,0 +1,698 @@
"""
MCP Server Resource Wrapper for GT 2.0
Encapsulates MCP (Model Context Protocol) servers as GT 2.0 resources.
Provides security sandboxing and capability-based access control.
"""
from typing import List, Optional, Dict, Any, AsyncIterator
from datetime import datetime, timedelta
from enum import Enum
import asyncio
import logging
import json
from dataclasses import dataclass
from app.models.access_group import AccessGroup, Resource
from app.core.security import verify_capability_token
logger = logging.getLogger(__name__)
class MCPServerStatus(str, Enum):
"""MCP server operational status"""
HEALTHY = "healthy"
DEGRADED = "degraded"
UNHEALTHY = "unhealthy"
STARTING = "starting"
STOPPING = "stopping"
STOPPED = "stopped"
@dataclass
class MCPServerConfig:
"""Configuration for an MCP server instance"""
server_name: str
server_url: str
server_type: str # filesystem, github, slack, etc.
available_tools: List[str]
required_capabilities: List[str]
# Security settings
sandbox_mode: bool = True
max_memory_mb: int = 512
max_cpu_percent: int = 50
timeout_seconds: int = 30
network_isolation: bool = True
# Rate limiting
max_requests_per_minute: int = 60
max_concurrent_requests: int = 5
# Authentication
auth_type: Optional[str] = None # none, api_key, oauth2
auth_config: Optional[Dict[str, Any]] = None
class MCPServerResource(Resource):
"""
MCP server encapsulated as a GT 2.0 resource
Inherits from Resource for access control
"""
# MCP-specific configuration
server_config: MCPServerConfig
# Runtime state
status: MCPServerStatus = MCPServerStatus.STOPPED
last_health_check: Optional[datetime] = None
error_count: int = 0
total_requests: int = 0
# Connection management
connection_pool_size: int = 5
active_connections: int = 0
def to_capability_requirement(self) -> str:
"""Generate capability requirement string for this MCP server"""
return f"mcp:{self.server_config.server_name}:*"
def validate_tool_access(self, tool_name: str, capability_token: Dict[str, Any]) -> bool:
"""Check if capability token allows access to specific tool"""
required_capability = f"mcp:{self.server_config.server_name}:{tool_name}"
capabilities = capability_token.get("capabilities", [])
for cap in capabilities:
resource = cap.get("resource", "")
if resource == required_capability or resource == f"mcp:{self.server_config.server_name}:*":
return True
return False
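# Example (illustrative only) of a decoded capability token that grants
# access to every tool on a server named "rag_server":
#   {"tenant_id": "acme.example",
#    "capabilities": [{"resource": "mcp:rag_server:*"}]}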
class SecureMCPWrapper:
"""
Secure wrapper for MCP servers with GT 2.0 security integration
Provides sandboxing, rate limiting, and capability-based access
"""
def __init__(self, resource_cluster_url: str = "http://localhost:8004"):
self.resource_cluster_url = resource_cluster_url
self.mcp_resources: Dict[str, MCPServerResource] = {}
self.rate_limiters: Dict[str, asyncio.Semaphore] = {}
self.audit_log = []
async def register_mcp_server(
self,
server_config: MCPServerConfig,
owner_id: str,
tenant_domain: str,
access_group: AccessGroup = AccessGroup.INDIVIDUAL
) -> MCPServerResource:
"""
Register an MCP server as a GT 2.0 resource
Args:
server_config: MCP server configuration
owner_id: User who owns this MCP resource
tenant_domain: Tenant domain
access_group: Access control level
Returns:
Registered MCP server resource
"""
# Create MCP resource
resource = MCPServerResource(
id=f"mcp-{server_config.server_name}-{datetime.utcnow().timestamp()}",
name=f"MCP Server: {server_config.server_name}",
resource_type="mcp_server",
owner_id=owner_id,
tenant_domain=tenant_domain,
access_group=access_group,
team_members=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
metadata={
"server_type": server_config.server_type,
"tools_count": len(server_config.available_tools)
},
server_config=server_config
)
# Initialize rate limiter
self.rate_limiters[resource.id] = asyncio.Semaphore(
server_config.max_concurrent_requests
)
# Store resource
self.mcp_resources[resource.id] = resource
# Start health monitoring
asyncio.create_task(self._monitor_health(resource.id))
logger.info(f"Registered MCP server: {server_config.server_name} as resource {resource.id}")
return resource
async def execute_tool(
self,
mcp_resource_id: str,
tool_name: str,
parameters: Dict[str, Any],
capability_token: str,
user_id: str
) -> Dict[str, Any]:
"""
Execute an MCP tool with security constraints
Args:
mcp_resource_id: MCP resource identifier
tool_name: Tool to execute
parameters: Tool parameters
capability_token: JWT capability token
user_id: User executing the tool
Returns:
Tool execution result
"""
# Load MCP resource
mcp_resource = self.mcp_resources.get(mcp_resource_id)
if not mcp_resource:
raise ValueError(f"MCP resource not found: {mcp_resource_id}")
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise PermissionError("Invalid capability token")
# Check tenant match
if token_data.get("tenant_id") != mcp_resource.tenant_domain:
raise PermissionError("Tenant mismatch")
# Validate tool access
if not mcp_resource.validate_tool_access(tool_name, token_data):
raise PermissionError(f"No capability for tool: {tool_name}")
# Check if tool exists
if tool_name not in mcp_resource.server_config.available_tools:
raise ValueError(f"Tool not available: {tool_name}")
# Apply rate limiting
async with self.rate_limiters[mcp_resource_id]:
try:
# Execute tool with timeout and sandboxing
result = await self._execute_tool_sandboxed(
mcp_resource, tool_name, parameters, user_id
)
# Update metrics
mcp_resource.total_requests += 1
# Audit log
self._log_tool_execution(
mcp_resource_id, tool_name, user_id, "success", result
)
return result
except Exception as e:
# Update error metrics
mcp_resource.error_count += 1
# Audit log
self._log_tool_execution(
mcp_resource_id, tool_name, user_id, "error", str(e)
)
raise
async def _execute_tool_sandboxed(
self,
mcp_resource: MCPServerResource,
tool_name: str,
parameters: Dict[str, Any],
user_id: str
) -> Dict[str, Any]:
"""Execute tool in sandboxed environment"""
# Create sandbox context
sandbox_context = {
"user_id": user_id,
"tenant_domain": mcp_resource.tenant_domain,
"resource_limits": {
"max_memory_mb": mcp_resource.server_config.max_memory_mb,
"max_cpu_percent": mcp_resource.server_config.max_cpu_percent,
"timeout_seconds": mcp_resource.server_config.timeout_seconds
},
"network_isolation": mcp_resource.server_config.network_isolation
}
# Execute based on server type
if mcp_resource.server_config.server_type == "filesystem":
return await self._execute_filesystem_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "github":
return await self._execute_github_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "slack":
return await self._execute_slack_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "web":
return await self._execute_web_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "database":
return await self._execute_database_tool(
tool_name, parameters, sandbox_context
)
else:
return await self._execute_custom_tool(
mcp_resource, tool_name, parameters, sandbox_context
)
async def _execute_filesystem_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute filesystem MCP tools"""
if tool_name == "read_file":
# Simulate file reading with sandbox constraints
file_path = parameters.get("path", "")
# Security validation
if not self._validate_file_path(file_path, sandbox_context):
raise PermissionError("Access denied to file path")
return {
"tool": "read_file",
"content": f"File content from {file_path}",
"size_bytes": 1024,
"mime_type": "text/plain"
}
elif tool_name == "write_file":
file_path = parameters.get("path", "")
content = parameters.get("content", "")
# Security validation
if not self._validate_file_path(file_path, sandbox_context):
raise PermissionError("Access denied to file path")
if len(content) > 1024 * 1024: # 1MB limit
raise ValueError("File content too large")
return {
"tool": "write_file",
"path": file_path,
"bytes_written": len(content),
"status": "success"
}
elif tool_name == "list_directory":
dir_path = parameters.get("path", "")
if not self._validate_file_path(dir_path, sandbox_context):
raise PermissionError("Access denied to directory path")
return {
"tool": "list_directory",
"path": dir_path,
"entries": ["file1.txt", "file2.txt", "subdir/"],
"total_entries": 3
}
else:
raise ValueError(f"Unknown filesystem tool: {tool_name}")
async def _execute_github_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute GitHub MCP tools"""
if tool_name == "get_repository":
repo_name = parameters.get("repository", "")
return {
"tool": "get_repository",
"repository": repo_name,
"owner": "example",
"description": "Example repository",
"language": "Python",
"stars": 123,
"forks": 45
}
elif tool_name == "create_issue":
title = parameters.get("title", "")
body = parameters.get("body", "")
return {
"tool": "create_issue",
"issue_number": 42,
"title": title,
"url": f"https://github.com/example/repo/issues/42",
"status": "created"
}
elif tool_name == "search_code":
query = parameters.get("query", "")
return {
"tool": "search_code",
"query": query,
"results": [
{
"file": "main.py",
"line": 15,
"content": f"# Code matching {query}"
}
],
"total_results": 1
}
else:
raise ValueError(f"Unknown GitHub tool: {tool_name}")
async def _execute_slack_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute Slack MCP tools"""
if tool_name == "send_message":
channel = parameters.get("channel", "")
message = parameters.get("message", "")
return {
"tool": "send_message",
"channel": channel,
"message": message,
"timestamp": datetime.utcnow().isoformat(),
"status": "sent"
}
elif tool_name == "get_channel_history":
channel = parameters.get("channel", "")
limit = parameters.get("limit", 10)
return {
"tool": "get_channel_history",
"channel": channel,
"messages": [
{
"user": "user1",
"text": "Hello world!",
"timestamp": datetime.utcnow().isoformat()
}
] * min(limit, 10),
"total_messages": limit
}
else:
raise ValueError(f"Unknown Slack tool: {tool_name}")
async def _execute_web_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute web MCP tools"""
if tool_name == "fetch_url":
url = parameters.get("url", "")
# URL validation
if not self._validate_url(url, sandbox_context):
raise PermissionError("Access denied to URL")
return {
"tool": "fetch_url",
"url": url,
"status_code": 200,
"content": f"Content from {url}",
"headers": {"content-type": "text/html"}
}
elif tool_name == "submit_form":
url = parameters.get("url", "")
form_data = parameters.get("form_data", {})
if not self._validate_url(url, sandbox_context):
raise PermissionError("Access denied to URL")
return {
"tool": "submit_form",
"url": url,
"form_data": form_data,
"status_code": 200,
"response": "Form submitted successfully"
}
else:
raise ValueError(f"Unknown web tool: {tool_name}")
async def _execute_database_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute database MCP tools"""
if tool_name == "execute_query":
query = parameters.get("query", "")
# Query validation
if not self._validate_sql_query(query, sandbox_context):
raise PermissionError("Query not allowed")
return {
"tool": "execute_query",
"query": query,
"rows": [
{"id": 1, "name": "Example"},
{"id": 2, "name": "Data"}
],
"row_count": 2,
"execution_time_ms": 15
}
else:
raise ValueError(f"Unknown database tool: {tool_name}")
async def _execute_custom_tool(
self,
mcp_resource: MCPServerResource,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute custom MCP tool via WebSocket transport"""
# This would connect to the actual MCP server via WebSocket
# For now, simulate the execution
await asyncio.sleep(0.1) # Simulate network delay
return {
"tool": tool_name,
"parameters": parameters,
"result": f"Custom tool {tool_name} executed successfully",
"server_type": mcp_resource.server_config.server_type,
"execution_time_ms": 100
}
def _validate_file_path(
    self,
    file_path: str,
    sandbox_context: Dict[str, Any]
) -> bool:
    """Validate file or directory path for security"""
    if not file_path:
        return False
    # Basic path traversal prevention
    if ".." in file_path or file_path.startswith("/"):
        return False
    # Directory paths (trailing "/" or no extension) are allowed so that
    # list_directory can pass validation; files must use an allowed extension
    name = file_path.rsplit("/", 1)[-1]
    if file_path.endswith("/") or "." not in name:
        return True
    allowed_extensions = [".txt", ".md", ".json", ".py", ".js"]
    return any(file_path.endswith(ext) for ext in allowed_extensions)
def _validate_url(
    self,
    url: str,
    sandbox_context: Dict[str, Any]
) -> bool:
    """Validate URL for security"""
    from urllib.parse import urlparse
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return False
    # Block internal/localhost hosts when network isolation is enabled,
    # matching on the parsed hostname rather than a substring of the URL
    if sandbox_context.get("network_isolation", True):
        host = (parsed.hostname or "").lower()
        if host in ("localhost", "127.0.0.1") or host.startswith(("10.", "192.168.")):
            return False
    return True
def _validate_sql_query(
self,
query: str,
sandbox_context: Dict[str, Any]
) -> bool:
"""Validate SQL query for security"""
# Block dangerous SQL operations
dangerous_keywords = [
"DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER",
"TRUNCATE", "EXEC", "EXECUTE", "xp_", "sp_"
]
query_upper = query.upper()
for keyword in dangerous_keywords:
if keyword in query_upper:
return False
return True
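# Examples (illustrative): "SELECT id FROM users" passes; "DROP TABLE users"
# is rejected. The substring check is deliberately conservative: a query
# referencing a column named "created_at" would also trip the "CREATE" keyword.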
def _log_tool_execution(
self,
mcp_resource_id: str,
tool_name: str,
user_id: str,
status: str,
result: Any
) -> None:
"""Log tool execution for audit"""
log_entry = {
"timestamp": datetime.utcnow().isoformat(),
"mcp_resource_id": mcp_resource_id,
"tool_name": tool_name,
"user_id": user_id,
"status": status,
"result_summary": str(result)[:200] if result else None
}
self.audit_log.append(log_entry)
# Keep only last 1000 entries
if len(self.audit_log) > 1000:
self.audit_log = self.audit_log[-1000:]
async def _monitor_health(self, mcp_resource_id: str) -> None:
"""Monitor MCP server health"""
while mcp_resource_id in self.mcp_resources:
try:
mcp_resource = self.mcp_resources[mcp_resource_id]
# Simulate health check
await asyncio.sleep(30) # Check every 30 seconds
# Update health status (check the higher threshold first, otherwise
# the UNHEALTHY branch is unreachable)
if mcp_resource.error_count > 50:
    mcp_resource.status = MCPServerStatus.UNHEALTHY
elif mcp_resource.error_count > 10:
    mcp_resource.status = MCPServerStatus.DEGRADED
else:
    mcp_resource.status = MCPServerStatus.HEALTHY
mcp_resource.last_health_check = datetime.utcnow()
logger.debug(f"Health check for MCP resource {mcp_resource_id}: {mcp_resource.status}")
except Exception as e:
logger.error(f"Health check failed for MCP resource {mcp_resource_id}: {e}")
if mcp_resource_id in self.mcp_resources:
self.mcp_resources[mcp_resource_id].status = MCPServerStatus.UNHEALTHY
async def get_resource_status(
self,
mcp_resource_id: str,
capability_token: str
) -> Dict[str, Any]:
"""Get MCP resource status"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise PermissionError("Invalid capability token")
# Load MCP resource
mcp_resource = self.mcp_resources.get(mcp_resource_id)
if not mcp_resource:
raise ValueError(f"MCP resource not found: {mcp_resource_id}")
# Check tenant match
if token_data.get("tenant_id") != mcp_resource.tenant_domain:
raise PermissionError("Tenant mismatch")
return {
"resource_id": mcp_resource_id,
"name": mcp_resource.name,
"server_type": mcp_resource.server_config.server_type,
"status": mcp_resource.status,
"total_requests": mcp_resource.total_requests,
"error_count": mcp_resource.error_count,
"active_connections": mcp_resource.active_connections,
"last_health_check": mcp_resource.last_health_check.isoformat() if mcp_resource.last_health_check else None,
"available_tools": mcp_resource.server_config.available_tools
}
async def list_mcp_resources(
self,
capability_token: str,
tenant_domain: Optional[str] = None
) -> List[Dict[str, Any]]:
"""List available MCP resources"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise PermissionError("Invalid capability token")
tenant_filter = tenant_domain or token_data.get("tenant_id")
resources = []
for resource in self.mcp_resources.values():
if resource.tenant_domain == tenant_filter:
resources.append({
"resource_id": resource.id,
"name": resource.name,
"server_type": resource.server_config.server_type,
"status": resource.status,
"tool_count": len(resource.server_config.available_tools),
"created_at": resource.created_at.isoformat()
})
return resources
# Global MCP wrapper instance
_mcp_wrapper = None
def get_mcp_wrapper() -> SecureMCPWrapper:
"""Get the global MCP wrapper instance"""
global _mcp_wrapper
if _mcp_wrapper is None:
_mcp_wrapper = SecureMCPWrapper()
return _mcp_wrapper
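# Example (illustrative only): registering a filesystem MCP server and
# executing one of its tools. The capability token below is a placeholder;
# in practice it must be a JWT accepted by verify_capability_token.
#
# async def _example_mcp():
#     wrapper = get_mcp_wrapper()
#     config = MCPServerConfig(
#         server_name="files",
#         server_url="internal://files",
#         server_type="filesystem",
#         available_tools=["read_file", "list_directory"],
#         required_capabilities=["mcp:files:*"],
#     )
#     resource = await wrapper.register_mcp_server(
#         config, owner_id="user-123", tenant_domain="acme.example"
#     )
#     result = await wrapper.execute_tool(
#         resource.id, "list_directory", {"path": "docs/"},
#         capability_token="<jwt>", user_id="user-123",
#     )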

View File

@@ -0,0 +1,296 @@
"""
GT 2.0 Model Router
Routes inference requests to appropriate providers based on model registry.
Integrates with provider factory for dynamic provider selection.
"""
import asyncio
import logging
from typing import Dict, Any, Optional, AsyncIterator
from datetime import datetime
from app.services.model_service import get_model_service
from app.providers import get_provider_factory
from app.core.backends import get_backend
from app.core.exceptions import ProviderError
logger = logging.getLogger(__name__)
class ModelRouter:
"""Routes model requests to appropriate providers"""
def __init__(self, tenant_id: Optional[str] = None):
self.tenant_id = tenant_id
# Use default model service for shared model registry (config sync writes to default)
# Note: Tenant isolation is handled via capability tokens, not separate databases
self.model_service = get_model_service(None)
self.provider_factory = None
self.backend_cache = {}
async def initialize(self):
"""Initialize model router"""
try:
self.provider_factory = await get_provider_factory()
logger.info(f"Model router initialized for tenant: {self.tenant_id or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize model router: {e}")
raise
async def route_inference(
self,
model_id: str,
prompt: Optional[str] = None,
messages: Optional[list] = None,
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: Optional[str] = None,
tenant_id: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""Route inference request to appropriate provider"""
# Get model configuration from registry
model_config = await self.model_service.get_model(model_id)
if not model_config:
raise ProviderError(f"Model {model_id} not found in registry")
provider = model_config["provider"]
# Track model usage
start_time = datetime.utcnow()
try:
# Route to configured endpoint (generic routing for any provider)
endpoint_url = model_config.get("endpoint")
if not endpoint_url:
raise ProviderError(f"No endpoint configured for model {model_id}")
result = await self._route_to_generic_endpoint(
endpoint_url, model_id, prompt, messages, temperature, max_tokens, stream, user_id, tenant_id, tools, tool_choice, **kwargs
)
# Calculate latency
latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
# Track successful usage
await self.model_service.track_model_usage(
model_id, success=True, latency_ms=latency_ms
)
return result
except Exception as e:
# Track failed usage
latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
await self.model_service.track_model_usage(
model_id, success=False, latency_ms=latency_ms
)
logger.error(f"Model routing failed for {model_id}: {e}")
raise
async def _route_to_groq(
self,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
tools: Optional[list],
tool_choice: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to Groq backend"""
try:
backend = get_backend("groq_proxy")
if not backend:
raise ProviderError("Groq backend not available")
if messages:
return await backend.execute_inference_with_messages(
messages=messages,
model=model_id,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
user_id=user_id,
tenant_id=tenant_id,
tools=tools,
tool_choice=tool_choice
)
else:
return await backend.execute_inference(
prompt=prompt,
model=model_id,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
user_id=user_id,
tenant_id=tenant_id
)
except Exception as e:
logger.error(f"Groq routing failed: {e}")
raise ProviderError(f"Groq inference failed: {e}")
async def _route_to_external(
self,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to external provider"""
try:
if not self.provider_factory:
await self.initialize()
external_provider = self.provider_factory.get_provider("external")
if not external_provider:
raise ProviderError("External provider not available")
# For embedding models
if model_id == "bge-m3-embedding":
# Convert prompt/messages to text list
texts = []
if messages:
texts = [msg.get("content", "") for msg in messages if msg.get("content")]
elif prompt:
texts = [prompt]
return await external_provider.generate_embeddings(
model_id=model_id,
texts=texts
)
else:
raise ProviderError(f"External model {model_id} not supported for inference")
except Exception as e:
logger.error(f"External routing failed: {e}")
raise ProviderError(f"External inference failed: {e}")
async def _route_to_openai(
self,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to OpenAI provider"""
raise ProviderError("OpenAI provider not implemented - use Groq models instead")
async def _route_to_generic_endpoint(
self,
endpoint_url: str,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
tools: Optional[list],
tool_choice: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to any configured endpoint using OpenAI-compatible API"""
import httpx
try:
# Build OpenAI-compatible request
request_data = {
"model": model_id,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream
}
# Use messages if provided, otherwise convert prompt to messages
if messages:
request_data["messages"] = messages
elif prompt:
request_data["messages"] = [{"role": "user", "content": prompt}]
else:
raise ProviderError("Either messages or prompt must be provided")
# Add tools if provided
if tools:
request_data["tools"] = tools
if tool_choice:
request_data["tool_choice"] = tool_choice
# Add any additional parameters
request_data.update(kwargs)
logger.info(f"Routing request to endpoint: {endpoint_url}")
logger.debug(f"Request data: {request_data}")
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
endpoint_url,
json=request_data,
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
error_text = response.text
logger.error(f"Endpoint {endpoint_url} returned {response.status_code}: {error_text}")
raise ProviderError(f"Endpoint error: {response.status_code} - {error_text}")
result = response.json()
logger.debug(f"Endpoint response: {result}")
return result
except httpx.RequestError as e:
logger.error(f"Request to {endpoint_url} failed: {e}")
raise ProviderError(f"Connection to endpoint failed: {str(e)}")
except Exception as e:
logger.error(f"Generic endpoint routing failed: {e}")
raise ProviderError(f"Inference failed: {str(e)}")
async def list_available_models(self) -> list:
"""List all available models from registry"""
# Get all models (deployment status filtering available if needed)
models = await self.model_service.list_models()
return models
async def get_model_health(self, model_id: str) -> Dict[str, Any]:
"""Check health of specific model"""
return await self.model_service.check_model_health(model_id)
# Global model router instances per tenant
_model_routers = {}
async def get_model_router(tenant_id: Optional[str] = None) -> ModelRouter:
"""Get model router instance for tenant"""
global _model_routers
cache_key = tenant_id or "default"
if cache_key not in _model_routers:
router = ModelRouter(tenant_id)
await router.initialize()
_model_routers[cache_key] = router
return _model_routers[cache_key]
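# Usage sketch (illustrative, not shipped code): routing a chat request through
# the tenant-scoped router. Assumes a model with an OpenAI-compatible endpoint
# has already been synced into the registry; the tenant ID and model ID below
# are placeholders.
if __name__ == "__main__":
import asyncio
async def _demo():
router = await get_model_router(tenant_id="acme")
result = await router.route_inference(
model_id="llama-3.1-8b-instant",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.2,
max_tokens=256,
)
print(result["choices"][0]["message"]["content"])
asyncio.run(_demo())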

View File

@@ -0,0 +1,720 @@
"""
GT 2.0 Model Management Service - Stateless Version
Provides centralized model registry, versioning, deployment, and lifecycle management
for all AI models across the Resource Cluster using in-memory storage.
"""
import json
import time
import asyncio
from typing import Dict, Any, List, Optional, Union
from datetime import datetime, timedelta
from pathlib import Path
import hashlib
import httpx
import logging
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class ModelService:
"""Stateless model management service with in-memory registry"""
def __init__(self, tenant_id: Optional[str] = None):
self.tenant_id = tenant_id
self.settings = get_settings(tenant_id)
# In-memory model registry for stateless operation
self.model_registry: Dict[str, Dict[str, Any]] = {}
self.last_cache_update = 0
self.cache_ttl = 300 # 5 minutes
# Performance tracking (in-memory)
self.performance_metrics: Dict[str, Dict[str, Any]] = {}
# Initialize with default models synchronously
self._initialize_default_models_sync()
async def register_model(
self,
model_id: str,
name: str,
version: str,
provider: str,
model_type: str,
description: str = "",
capabilities: Dict[str, Any] = None,
parameters: Dict[str, Any] = None,
endpoint_url: str = None,
**kwargs
) -> Dict[str, Any]:
"""Register a new model in the in-memory registry"""
now = datetime.utcnow()
# Create or update model entry
model_entry = {
"id": model_id,
"name": name,
"version": version,
"provider": provider,
"model_type": model_type,
"description": description,
"capabilities": capabilities or {},
"parameters": parameters or {},
# Performance metrics
"max_tokens": kwargs.get("max_tokens", 4000),
"context_window": kwargs.get("context_window", 4000),
"cost_per_1k_tokens": kwargs.get("cost_per_1k_tokens", 0.0),
"latency_p50_ms": kwargs.get("latency_p50_ms", 0.0),
"latency_p95_ms": kwargs.get("latency_p95_ms", 0.0),
# Deployment status
"deployment_status": kwargs.get("deployment_status", "available"),
"health_status": kwargs.get("health_status", "unknown"),
"last_health_check": kwargs.get("last_health_check"),
# Usage tracking
"request_count": kwargs.get("request_count", 0),
"error_count": kwargs.get("error_count", 0),
"success_rate": kwargs.get("success_rate", 1.0),
# Lifecycle
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"retired_at": kwargs.get("retired_at"),
# Configuration
"endpoint_url": endpoint_url,
"api_key_required": kwargs.get("api_key_required", True),
"rate_limits": kwargs.get("rate_limits", {})
}
self.model_registry[model_id] = model_entry
logger.info(f"Registered model: {model_id} ({name} v{version})")
return model_entry
async def get_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""Get model by ID"""
return self.model_registry.get(model_id)
async def list_models(
self,
provider: str = None,
model_type: str = None,
deployment_status: str = None,
health_status: str = None
) -> List[Dict[str, Any]]:
"""List models with optional filters"""
models = list(self.model_registry.values())
# Apply filters
if provider:
models = [m for m in models if m["provider"] == provider]
if model_type:
models = [m for m in models if m["model_type"] == model_type]
if deployment_status:
models = [m for m in models if m["deployment_status"] == deployment_status]
if health_status:
models = [m for m in models if m["health_status"] == health_status]
# Sort by created_at desc
models.sort(key=lambda x: x["created_at"], reverse=True)
return models
async def update_model_status(
self,
model_id: str,
deployment_status: str = None,
health_status: str = None
) -> bool:
"""Update model deployment and health status"""
model = self.model_registry.get(model_id)
if not model:
return False
if deployment_status:
model["deployment_status"] = deployment_status
if health_status:
model["health_status"] = health_status
model["last_health_check"] = datetime.utcnow().isoformat()
model["updated_at"] = datetime.utcnow().isoformat()
return True
async def track_model_usage(
self,
model_id: str,
success: bool = True,
latency_ms: float = None
):
"""Track model usage and performance metrics"""
model = self.model_registry.get(model_id)
if not model:
return
# Update usage counters
model["request_count"] += 1
if not success:
model["error_count"] += 1
# Calculate success rate
model["success_rate"] = (model["request_count"] - model["error_count"]) / model["request_count"]
# Update latency metrics (simple running average)
if latency_ms is not None:
if model["latency_p50_ms"] == 0:
model["latency_p50_ms"] = latency_ms
else:
# Simple exponential moving average
alpha = 0.1
model["latency_p50_ms"] = alpha * latency_ms + (1 - alpha) * model["latency_p50_ms"]
# P95 approximation (conservative estimate)
model["latency_p95_ms"] = max(model["latency_p95_ms"], latency_ms * 1.5)
model["updated_at"] = datetime.utcnow().isoformat()
async def retire_model(self, model_id: str, reason: str = "") -> bool:
"""Retire a model (mark as no longer available)"""
model = self.model_registry.get(model_id)
if not model:
return False
model["deployment_status"] = "retired"
model["retired_at"] = datetime.utcnow().isoformat()
model["updated_at"] = datetime.utcnow().isoformat()
if reason:
model["description"] += f"\n\nRetired: {reason}"
logger.info(f"Retired model: {model_id} - {reason}")
return True
async def check_model_health(self, model_id: str) -> Dict[str, Any]:
"""Check health of a specific model"""
model = await self.get_model(model_id)
if not model:
return {"healthy": False, "error": "Model not found"}
# Generic health check for any provider with a configured endpoint
if model.get("endpoint") or model.get("endpoint_url"):
return await self._check_generic_model_health(model)
elif model["provider"] == "groq":
return await self._check_groq_model_health(model)
elif model["provider"] == "openai":
return await self._check_openai_model_health(model)
elif model["provider"] == "local":
return await self._check_local_model_health(model)
else:
return {"healthy": False, "error": f"No health check method for provider: {model['provider']}"}
async def _check_groq_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Health check for Groq models"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://api.groq.com/openai/v1/models",
headers={"Authorization": f"Bearer {settings.groq_api_key}"},
timeout=10.0
)
if response.status_code == 200:
models = response.json()
model_ids = [m["id"] for m in models.get("data", [])]
is_available = model["id"] in model_ids
await self.update_model_status(
model["id"],
health_status="healthy" if is_available else "unhealthy"
)
return {
"healthy": is_available,
"latency_ms": response.elapsed.total_seconds() * 1000,
"available_models": len(model_ids)
}
else:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": f"API error: {response.status_code}"}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": str(e)}
async def _check_openai_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Health check for OpenAI models"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {settings.openai_api_key}"},
timeout=10.0
)
if response.status_code == 200:
models = response.json()
model_ids = [m["id"] for m in models.get("data", [])]
is_available = model["id"] in model_ids
await self.update_model_status(
model["id"],
health_status="healthy" if is_available else "unhealthy"
)
return {
"healthy": is_available,
"latency_ms": response.elapsed.total_seconds() * 1000
}
else:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": f"API error: {response.status_code}"}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": str(e)}
async def _check_generic_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Generic health check for any provider with configured endpoint"""
try:
endpoint_url = model.get("endpoint")
if not endpoint_url:
return {"healthy": False, "error": "No endpoint URL configured"}
# Try a simple health check by making a minimal request
async with httpx.AsyncClient(timeout=10.0) as client:
# For OpenAI-compatible endpoints, try a models list request
try:
# Try /v1/models endpoint first (common for OpenAI-compatible APIs)
models_url = endpoint_url.replace("/chat/completions", "/models")
response = await client.get(models_url)
if response.status_code == 200:
await self.update_model_status(model["id"], health_status="healthy")
return {
"healthy": True,
"provider": model.get("provider", "unknown"),
"latency_ms": 0, # Could measure actual latency
"last_check": datetime.utcnow().isoformat(),
"details": "Endpoint responding to models request"
}
except Exception:
pass
# If models endpoint doesn't work, try a basic health endpoint
try:
health_url = endpoint_url.replace("/chat/completions", "/health")
response = await client.get(health_url)
if response.status_code == 200:
await self.update_model_status(model["id"], health_status="healthy")
return {
"healthy": True,
"provider": model.get("provider", "unknown"),
"latency_ms": 0,
"last_check": datetime.utcnow().isoformat(),
"details": "Endpoint responding to health check"
}
except Exception:
pass
# If neither probe responds, record health as unknown; generic endpoints
# may expose neither /models nor /health, so the absence of both is not
# treated as a hard failure.
await self.update_model_status(model["id"], health_status="unknown")
return {
"healthy": True, # Optimistic default for generic endpoints
"provider": model.get("provider", "unknown"),
"latency_ms": 0,
"last_check": datetime.utcnow().isoformat(),
"details": "Generic endpoint - health check not available"
}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": f"Health check failed: {str(e)}"}
async def _check_local_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Health check for local models"""
try:
endpoint_url = model.get("endpoint_url")
if not endpoint_url:
return {"healthy": False, "error": "No endpoint URL configured"}
async with httpx.AsyncClient() as client:
response = await client.get(
f"{endpoint_url}/health",
timeout=5.0
)
healthy = response.status_code == 200
await self.update_model_status(
model["id"],
health_status="healthy" if healthy else "unhealthy"
)
return {
"healthy": healthy,
"latency_ms": response.elapsed.total_seconds() * 1000
}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": str(e)}
async def bulk_health_check(self) -> Dict[str, Any]:
"""Check health of all registered models"""
models = await self.list_models()
health_results = {}
# Run health checks concurrently
tasks = []
for model in models:
task = asyncio.create_task(self.check_model_health(model["id"]))
tasks.append((model["id"], task))
for model_id, task in tasks:
try:
health_result = await task
health_results[model_id] = health_result
except Exception as e:
health_results[model_id] = {"healthy": False, "error": str(e)}
# Calculate overall health statistics
total_models = len(health_results)
healthy_models = sum(1 for result in health_results.values() if result.get("healthy", False))
return {
"total_models": total_models,
"healthy_models": healthy_models,
"unhealthy_models": total_models - healthy_models,
"health_percentage": (healthy_models / total_models * 100) if total_models > 0 else 0,
"individual_results": health_results
}
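# Illustrative shape of a bulk_health_check() result (values are examples):
#
# {
#     "total_models": 5,
#     "healthy_models": 4,
#     "unhealthy_models": 1,
#     "health_percentage": 80.0,
#     "individual_results": {"llama-3.1-8b-instant": {"healthy": True, ...}}
# }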
async def get_model_analytics(
self,
model_id: str = None,
timeframe_hours: int = 24
) -> Dict[str, Any]:
"""Get analytics for model usage and performance"""
models = await self.list_models()
if model_id:
models = [m for m in models if m["id"] == model_id]
analytics = {
"total_models": len(models),
"by_provider": {},
"by_type": {},
"performance_summary": {
"avg_latency_p50": 0,
"avg_success_rate": 0,
"total_requests": 0,
"total_errors": 0
},
"top_performers": [],
"models": models
}
total_latency = 0
total_success_rate = 0
total_requests = 0
total_errors = 0
for model in models:
# Provider statistics
provider = model["provider"]
if provider not in analytics["by_provider"]:
analytics["by_provider"][provider] = {"count": 0, "requests": 0}
analytics["by_provider"][provider]["count"] += 1
analytics["by_provider"][provider]["requests"] += model["request_count"]
# Type statistics
model_type = model["model_type"]
if model_type not in analytics["by_type"]:
analytics["by_type"][model_type] = {"count": 0, "requests": 0}
analytics["by_type"][model_type]["count"] += 1
analytics["by_type"][model_type]["requests"] += model["request_count"]
# Performance aggregation
total_latency += model["latency_p50_ms"]
total_success_rate += model["success_rate"]
total_requests += model["request_count"]
total_errors += model["error_count"]
# Calculate averages
if len(models) > 0:
analytics["performance_summary"]["avg_latency_p50"] = total_latency / len(models)
analytics["performance_summary"]["avg_success_rate"] = total_success_rate / len(models)
analytics["performance_summary"]["total_requests"] = total_requests
analytics["performance_summary"]["total_errors"] = total_errors
# Top performers (by success rate and low latency)
analytics["top_performers"] = sorted(
[m for m in models if m["request_count"] > 0],
key=lambda x: (x["success_rate"], -x["latency_p50_ms"]),
reverse=True
)[:5]
return analytics
async def _initialize_default_models(self):
"""Initialize registry with default models"""
# Groq models
groq_models = [
{
"model_id": "llama-3.1-405b-reasoning",
"name": "Llama 3.1 405B Reasoning",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Largest Llama model optimized for complex reasoning tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 2.5,
"capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-70b-versatile",
"name": "Llama 3.1 70B Versatile",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Balanced Llama model for general-purpose tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.8,
"capabilities": {"general": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-8b-instant",
"name": "Llama 3.1 8B Instant",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Fast Llama model for quick responses",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.2,
"capabilities": {"fast": True, "streaming": True}
},
{
"model_id": "mixtral-8x7b-32768",
"name": "Mixtral 8x7B",
"version": "1.0",
"provider": "groq",
"model_type": "llm",
"description": "Mixtral model for balanced performance",
"max_tokens": 32768,
"context_window": 32768,
"cost_per_1k_tokens": 0.27,
"capabilities": {"general": True, "streaming": True}
}
]
for model_config in groq_models:
await self.register_model(**model_config)
logger.info("Initialized default model registry with in-memory storage")
def _initialize_default_models_sync(self):
"""Initialize registry with default models synchronously"""
# Groq models
groq_models = [
{
"model_id": "llama-3.1-405b-reasoning",
"name": "Llama 3.1 405B Reasoning",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Largest Llama model optimized for complex reasoning tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 2.5,
"capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-70b-versatile",
"name": "Llama 3.1 70B Versatile",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Balanced Llama model for general-purpose tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.8,
"capabilities": {"general": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-8b-instant",
"name": "Llama 3.1 8B Instant",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Fast Llama model for quick responses",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.2,
"capabilities": {"fast": True, "streaming": True}
},
{
"model_id": "mixtral-8x7b-32768",
"name": "Mixtral 8x7B",
"version": "1.0",
"provider": "groq",
"model_type": "llm",
"description": "Mixtral model for balanced performance",
"max_tokens": 32768,
"context_window": 32768,
"cost_per_1k_tokens": 0.27,
"capabilities": {"general": True, "streaming": True}
},
{
"model_id": "groq/compound",
"name": "Groq Compound Model",
"version": "1.0",
"provider": "groq",
"model_type": "llm",
"description": "Groq compound AI model",
"max_tokens": 8000,
"context_window": 8000,
"cost_per_1k_tokens": 0.5,
"capabilities": {"general": True, "streaming": True}
}
]
for model_config in groq_models:
now = datetime.utcnow()
model_entry = {
"id": model_config["model_id"],
"name": model_config["name"],
"version": model_config["version"],
"provider": model_config["provider"],
"model_type": model_config["model_type"],
"description": model_config["description"],
"capabilities": model_config["capabilities"],
"parameters": {},
# Performance metrics
"max_tokens": model_config["max_tokens"],
"context_window": model_config["context_window"],
"cost_per_1k_tokens": model_config["cost_per_1k_tokens"],
"latency_p50_ms": 0.0,
"latency_p95_ms": 0.0,
# Deployment status
"deployment_status": "available",
"health_status": "unknown",
"last_health_check": None,
# Usage tracking
"request_count": 0,
"error_count": 0,
"success_rate": 1.0,
# Lifecycle
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"retired_at": None,
# Configuration
"endpoint_url": None,
"api_key_required": True,
"rate_limits": {}
}
self.model_registry[model_config["model_id"]] = model_entry
logger.info("Initialized default model registry with in-memory storage (sync)")
async def register_or_update_model(
self,
model_id: str,
name: str,
version: str = "1.0",
provider: str = "unknown",
model_type: str = "llm",
endpoint: str = "",
api_key_name: str = None,
specifications: Dict[str, Any] = None,
capabilities: Dict[str, Any] = None,
cost: Dict[str, Any] = None,
description: str = "",
config: Dict[str, Any] = None,
status: Dict[str, Any] = None,
sync_timestamp: str = None
) -> Dict[str, Any]:
"""Register a new model or update existing one from admin cluster sync"""
specifications = specifications or {}
capabilities = capabilities or {}
cost = cost or {}
config = config or {}
status = status or {}
# Check if model exists
existing_model = self.model_registry.get(model_id)
if existing_model:
# Update existing model
existing_model.update({
"name": name,
"version": version,
"provider": provider,
"model_type": model_type,
"description": description,
"capabilities": capabilities,
"parameters": config,
"endpoint_url": endpoint,
"api_key_required": bool(api_key_name),
"max_tokens": specifications.get("max_tokens", existing_model.get("max_tokens", 4000)),
"context_window": specifications.get("context_window", existing_model.get("context_window", 4000)),
"cost_per_1k_tokens": cost.get("per_1k_input", existing_model.get("cost_per_1k_tokens", 0.0)),
"deployment_status": "deployed" if status.get("is_active", True) else "retired",
"updated_at": datetime.utcnow().isoformat()
})
if "bge-m3" in model_id.lower():
logger.info(f"Updated BGE-M3 model: endpoint_url={endpoint}, parameters={config}")
logger.debug(f"Updated model: {model_id}")
return existing_model
else:
# Register new model
return await self.register_model(
model_id=model_id,
name=name,
version=version,
provider=provider,
model_type=model_type,
description=description,
capabilities=capabilities,
parameters=config,
endpoint_url=endpoint,
max_tokens=specifications.get("max_tokens", 4000),
context_window=specifications.get("context_window", 4000),
cost_per_1k_tokens=cost.get("per_1k_input", 0.0),
api_key_required=bool(api_key_name)
)
def get_model_service(tenant_id: Optional[str] = None) -> ModelService:
"""Get tenant-isolated model service instance"""
return ModelService(tenant_id=tenant_id)
# Default model service for development/non-tenant operations
default_model_service = get_model_service()
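# Usage sketch (illustrative): registering a model synced from the admin
# cluster, recording one request, and reading aggregate analytics. The model
# name, endpoint, and cost values below are placeholders, not shipped
# configuration.
if __name__ == "__main__":
async def _demo():
svc = get_model_service(tenant_id="acme")
await svc.register_or_update_model(
model_id="my-llm",
name="My LLM",
provider="custom",
endpoint="http://llm.internal:8000/v1/chat/completions",
specifications={"max_tokens": 2048, "context_window": 8192},
cost={"per_1k_input": 0.1},
)
await svc.track_model_usage("my-llm", success=True, latency_ms=120.0)
print((await svc.get_model_analytics())["performance_summary"])
asyncio.run(_demo())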

View File

@@ -0,0 +1,931 @@
"""
GT 2.0 Resource Cluster - Service Manager
Orchestrates external web services (CTFd, Canvas LMS, Guacamole, JupyterHub)
with strict per-tenant isolation and security controls.
"""
import asyncio
import json
import logging
import subprocess
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, asdict, field
from pathlib import Path
try:
import docker
import kubernetes
from kubernetes import client, config
from kubernetes.client.rest import ApiException
DOCKER_AVAILABLE = True
KUBERNETES_AVAILABLE = True
except ImportError:
# For development containerization mode, these are optional
docker = None
kubernetes = None
client = None
config = None
ApiException = Exception
DOCKER_AVAILABLE = False
KUBERNETES_AVAILABLE = False
from app.core.config import get_settings
from app.core.security import verify_capability_token
from app.utils.encryption import encrypt_data, decrypt_data
logger = logging.getLogger(__name__)
@dataclass
class ServiceInstance:
"""Represents a deployed service instance"""
instance_id: str
tenant_id: str
service_type: str # 'ctfd', 'canvas', 'guacamole', 'jupyter'
status: str # 'starting', 'running', 'stopping', 'stopped', 'error'
endpoint_url: str
internal_port: int
external_port: int
namespace: str
deployment_name: str
service_name: str
ingress_name: str
sso_token: Optional[str] = None
created_at: datetime = field(default_factory=datetime.utcnow)
last_heartbeat: datetime = field(default_factory=datetime.utcnow)
resource_usage: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
data['last_heartbeat'] = self.last_heartbeat.isoformat()
return data
@dataclass
class ServiceTemplate:
"""Service deployment template configuration"""
service_type: str
image: str
ports: Dict[str, int]
environment: Dict[str, str]
volumes: List[Dict[str, str]]
resource_limits: Dict[str, str]
security_context: Dict[str, Any]
health_check: Dict[str, Any]
sso_config: Dict[str, Any]
class ServiceManager:
"""Manages external web service instances with Kubernetes orchestration"""
def __init__(self):
# Initialize Docker client if available
if DOCKER_AVAILABLE:
try:
self.docker_client = docker.from_env()
except Exception as e:
logger.warning(f"Could not initialize Docker client: {e}")
self.docker_client = None
else:
self.docker_client = None
self.k8s_client = None
self.active_instances: Dict[str, ServiceInstance] = {}
self.service_templates: Dict[str, ServiceTemplate] = {}
self.base_namespace = "gt-services"
self.storage_path = Path("/tmp/resource-cluster/services")
self.storage_path.mkdir(parents=True, exist_ok=True)
# Initialize Kubernetes client if available
if KUBERNETES_AVAILABLE:
try:
config.load_incluster_config() # If running in cluster
except Exception:
try:
config.load_kube_config() # If running locally
except Exception:
logger.warning("Could not load Kubernetes config - using mock mode")
self.k8s_client = client.ApiClient() if client else None
else:
logger.warning("Kubernetes not available - running in development containerization mode")
self._initialize_service_templates()
self._load_persistent_instances()
def _initialize_service_templates(self):
"""Initialize service deployment templates"""
# CTFd Template
self.service_templates['ctfd'] = ServiceTemplate(
service_type='ctfd',
image='ctfd/ctfd:3.6.0',
ports={'http': 8000},
environment={
'SECRET_KEY': '${TENANT_SECRET_KEY}',
'DATABASE_URL': 'sqlite:////data/ctfd.db',
'DATABASE_CACHE_URL': 'postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants',
'UPLOAD_FOLDER': '/data/uploads',
'LOG_FOLDER': '/data/logs',
},
volumes=[
{'name': 'ctfd-data', 'mountPath': '/data', 'size': '5Gi'},
{'name': 'ctfd-uploads', 'mountPath': '/uploads', 'size': '2Gi'}
],
resource_limits={
'memory': '2Gi',
'cpu': '1000m'
},
security_context={
'runAsNonRoot': True,
'runAsUser': 1000,
'fsGroup': 1000,
'readOnlyRootFilesystem': False
},
health_check={
'path': '/health',
'port': 8000,
'initial_delay': 30,
'period': 10
},
sso_config={
'enabled': True,
'provider': 'oauth2',
'callback_path': '/auth/oauth/callback'
}
)
# Canvas LMS Template
self.service_templates['canvas'] = ServiceTemplate(
service_type='canvas',
image='instructure/canvas-lms:stable',
ports={'http': 3000},
environment={
'CANVAS_LMS_ADMIN_EMAIL': 'admin@${TENANT_DOMAIN}',
'CANVAS_LMS_ADMIN_PASSWORD': '${CANVAS_ADMIN_PASSWORD}',
'CANVAS_LMS_ACCOUNT_NAME': '${TENANT_NAME}',
'CANVAS_LMS_STATS_COLLECTION': 'opt_out',
'POSTGRES_PASSWORD': '${POSTGRES_PASSWORD}',
'DATABASE_CACHE_URL': 'postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants'
},
volumes=[
{'name': 'canvas-data', 'mountPath': '/app/log', 'size': '10Gi'},
{'name': 'canvas-files', 'mountPath': '/app/public/files', 'size': '20Gi'}
],
resource_limits={
'memory': '4Gi',
'cpu': '2000m'
},
security_context={
'runAsNonRoot': True,
'runAsUser': 1000,
'fsGroup': 1000
},
health_check={
'path': '/health_check',
'port': 3000,
'initial_delay': 60,
'period': 15
},
sso_config={
'enabled': True,
'provider': 'saml',
'metadata_url': '/auth/saml/metadata'
}
)
# Guacamole Template
self.service_templates['guacamole'] = ServiceTemplate(
service_type='guacamole',
image='guacamole/guacamole:1.5.3',
ports={'http': 8080},
environment={
'GUACD_HOSTNAME': 'guacd',
'GUACD_PORT': '4822',
'MYSQL_HOSTNAME': 'mysql',
'MYSQL_PORT': '3306',
'MYSQL_DATABASE': 'guacamole_db',
'MYSQL_USER': 'guacamole_user',
'MYSQL_PASSWORD': '${MYSQL_PASSWORD}',
'GUAC_LOG_LEVEL': 'INFO'
},
volumes=[
{'name': 'guacamole-data', 'mountPath': '/config', 'size': '1Gi'},
{'name': 'guacamole-recordings', 'mountPath': '/recordings', 'size': '10Gi'}
],
resource_limits={
'memory': '1Gi',
'cpu': '500m'
},
security_context={
'runAsNonRoot': True,
'runAsUser': 1001,
'fsGroup': 1001
},
health_check={
'path': '/guacamole',
'port': 8080,
'initial_delay': 45,
'period': 10
},
sso_config={
'enabled': True,
'provider': 'openid',
'extension': 'guacamole-auth-openid'
}
)
# JupyterHub Template
self.service_templates['jupyter'] = ServiceTemplate(
service_type='jupyter',
image='jupyterhub/jupyterhub:4.0',
ports={'http': 8000},
environment={
'JUPYTERHUB_CRYPT_KEY': '${JUPYTERHUB_CRYPT_KEY}',
'CONFIGPROXY_AUTH_TOKEN': '${CONFIGPROXY_AUTH_TOKEN}',
'DOCKER_NETWORK_NAME': 'jupyterhub',
'DOCKER_NOTEBOOK_IMAGE': 'jupyter/datascience-notebook:lab-4.0.7'
},
volumes=[
{'name': 'jupyter-data', 'mountPath': '/srv/jupyterhub', 'size': '5Gi'},
{'name': 'docker-socket', 'mountPath': '/var/run/docker.sock', 'hostPath': '/var/run/docker.sock'}
],
resource_limits={
'memory': '2Gi',
'cpu': '1000m'
},
security_context={
'runAsNonRoot': False, # Needs Docker access
'runAsUser': 0,
'privileged': True
},
health_check={
'path': '/hub/health',
'port': 8000,
'initial_delay': 30,
'period': 15
},
sso_config={
'enabled': True,
'provider': 'oauth',
'authenticator_class': 'oauthenticator.generic.GenericOAuthenticator'
}
)
async def create_service_instance(
self,
tenant_id: str,
service_type: str,
config_overrides: Dict[str, Any] = None
) -> ServiceInstance:
"""Create a new service instance for a tenant"""
if service_type not in self.service_templates:
raise ValueError(f"Unsupported service type: {service_type}")
template = self.service_templates[service_type]
instance_id = f"{service_type}-{tenant_id}-{uuid.uuid4().hex[:8]}"
namespace = f"{self.base_namespace}-{tenant_id}"
# Generate unique ports
external_port = await self._get_available_port()
# Create service instance object
instance = ServiceInstance(
instance_id=instance_id,
tenant_id=tenant_id,
service_type=service_type,
status='starting',
endpoint_url=f"https://{service_type}.{tenant_id}.gt2.com",
internal_port=template.ports['http'],
external_port=external_port,
namespace=namespace,
deployment_name=f"{service_type}-{instance_id}",
service_name=f"{service_type}-service-{instance_id}",
ingress_name=f"{service_type}-ingress-{instance_id}",
resource_usage={'cpu': 0, 'memory': 0, 'storage': 0}
)
try:
# Create Kubernetes namespace if not exists
await self._create_namespace(namespace, tenant_id)
# Deploy the service
await self._deploy_service(instance, template, config_overrides)
# Generate SSO token
instance.sso_token = await self._generate_sso_token(instance)
# Store instance
self.active_instances[instance_id] = instance
await self._persist_instance(instance)
logger.info(f"Created {service_type} instance {instance_id} for tenant {tenant_id}")
return instance
except Exception as e:
logger.error(f"Failed to create service instance: {e}")
instance.status = 'error'
raise
async def _create_namespace(self, namespace: str, tenant_id: str):
"""Create Kubernetes namespace with proper labeling and network policies"""
if not self.k8s_client:
logger.info(f"Mock: Created namespace {namespace}")
return
v1 = client.CoreV1Api(self.k8s_client)
# Create namespace
namespace_manifest = client.V1Namespace(
metadata=client.V1ObjectMeta(
name=namespace,
labels={
'gt.tenant-id': tenant_id,
'gt.cluster': 'resource',
'gt.isolation': 'tenant'
},
annotations={
'gt.created-by': 'service-manager',
'gt.creation-time': datetime.utcnow().isoformat()
}
)
)
try:
v1.create_namespace(namespace_manifest)
logger.info(f"Created namespace: {namespace}")
except ApiException as e:
if e.status == 409: # Already exists
logger.info(f"Namespace {namespace} already exists")
else:
raise
# Apply network policy for tenant isolation
await self._apply_network_policy(namespace, tenant_id)
async def _apply_network_policy(self, namespace: str, tenant_id: str):
"""Apply network policy for tenant isolation"""
if not self.k8s_client:
logger.info(f"Mock: Applied network policy to {namespace}")
return
networking_v1 = client.NetworkingV1Api(self.k8s_client)
# Network policy that only allows:
# 1. Intra-namespace communication
# 2. Communication to system namespaces (DNS, etc.)
# 3. Egress to external services (for updates, etc.)
network_policy = client.V1NetworkPolicy(
metadata=client.V1ObjectMeta(
name=f"tenant-isolation-{tenant_id}",
namespace=namespace,
labels={'gt.tenant-id': tenant_id}
),
spec=client.V1NetworkPolicySpec(
pod_selector=client.V1LabelSelector(), # All pods in namespace
policy_types=['Ingress', 'Egress'],
ingress=[
# Allow ingress from same namespace
client.V1NetworkPolicyIngressRule(
from_=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': namespace}
)
)]
),
# Allow ingress from ingress controller
client.V1NetworkPolicyIngressRule(
from_=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': 'ingress-nginx'}
)
)]
)
],
egress=[
# Allow egress within namespace
client.V1NetworkPolicyEgressRule(
to=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': namespace}
)
)]
),
# Allow DNS
client.V1NetworkPolicyEgressRule(
to=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': 'kube-system'}
)
)],
ports=[client.V1NetworkPolicyPort(port=53, protocol='UDP')]
),
# Allow external HTTPS (for updates, etc.)
client.V1NetworkPolicyEgressRule(
ports=[
client.V1NetworkPolicyPort(port=443, protocol='TCP'),
client.V1NetworkPolicyPort(port=80, protocol='TCP')
]
)
]
)
)
try:
networking_v1.create_namespaced_network_policy(
namespace=namespace,
body=network_policy
)
logger.info(f"Applied network policy to namespace: {namespace}")
except ApiException as e:
if e.status == 409: # Already exists
logger.info(f"Network policy already exists in {namespace}")
else:
logger.error(f"Failed to create network policy: {e}")
raise
async def _deploy_service(
self,
instance: ServiceInstance,
template: ServiceTemplate,
config_overrides: Dict[str, Any] = None
):
"""Deploy service to Kubernetes cluster"""
if not self.k8s_client:
logger.info(f"Mock: Deployed {template.service_type} service")
instance.status = 'running'
return
# Prepare environment variables with tenant-specific values
environment = template.environment.copy()
if config_overrides:
environment.update(config_overrides.get('environment', {}))
# Substitute tenant-specific values
env_vars = []
for key, value in environment.items():
substituted_value = value.replace('${TENANT_ID}', instance.tenant_id)
substituted_value = substituted_value.replace('${TENANT_DOMAIN}', f"{instance.tenant_id}.gt2.com")
env_vars.append(client.V1EnvVar(name=key, value=substituted_value))
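# Note: only ${TENANT_ID} and ${TENANT_DOMAIN} are substituted here; secret
# placeholders such as ${TENANT_SECRET_KEY} or ${MYSQL_PASSWORD} must arrive
# already resolved via config_overrides["environment"], which is merged above.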
# Create volumes
volumes = []
volume_mounts = []
for vol_config in template.volumes:
vol_name = f"{vol_config['name']}-{instance.instance_id}"
volumes.append(client.V1Volume(
name=vol_name,
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
claim_name=vol_name
)
))
volume_mounts.append(client.V1VolumeMount(
name=vol_name,
mount_path=vol_config['mountPath']
))
# Create PVCs first
await self._create_persistent_volumes(instance, template)
# Create deployment
deployment = client.V1Deployment(
metadata=client.V1ObjectMeta(
name=instance.deployment_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id,
'gt.service-type': template.service_type
}
),
spec=client.V1DeploymentSpec(
replicas=1,
selector=client.V1LabelSelector(
match_labels={'instance': instance.instance_id}
),
template=client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
}
),
spec=client.V1PodSpec(
containers=[client.V1Container(
name=template.service_type,
image=template.image,
ports=[client.V1ContainerPort(
container_port=template.ports['http']
)],
env=env_vars,
volume_mounts=volume_mounts,
resources=client.V1ResourceRequirements(
limits=template.resource_limits,
requests=template.resource_limits
),
security_context=client.V1SecurityContext(
run_as_non_root=template.security_context.get('runAsNonRoot', True),
run_as_user=template.security_context.get('runAsUser'),
privileged=template.security_context.get('privileged', False),
read_only_root_filesystem=template.security_context.get('readOnlyRootFilesystem')
),
liveness_probe=client.V1Probe(
http_get=client.V1HTTPGetAction(
path=template.health_check['path'],
port=template.health_check['port']
),
initial_delay_seconds=template.health_check['initial_delay'],
period_seconds=template.health_check['period']
),
readiness_probe=client.V1Probe(
http_get=client.V1HTTPGetAction(
path=template.health_check['path'],
port=template.health_check['port']
),
initial_delay_seconds=10,
period_seconds=5
)
)],
volumes=volumes,
security_context=client.V1PodSecurityContext(
run_as_non_root=template.security_context.get('runAsNonRoot', True),
fs_group=template.security_context.get('fsGroup', 1000)
)
)
)
)
)
# Deploy to Kubernetes
apps_v1 = client.AppsV1Api(self.k8s_client)
apps_v1.create_namespaced_deployment(
namespace=instance.namespace,
body=deployment
)
# Create service
await self._create_service(instance, template)
# Create ingress
await self._create_ingress(instance, template)
logger.info(f"Deployed {template.service_type} service: {instance.deployment_name}")
async def _create_persistent_volumes(self, instance: ServiceInstance, template: ServiceTemplate):
"""Create persistent volume claims for the service"""
if not self.k8s_client:
return
v1 = client.CoreV1Api(self.k8s_client)
for vol_config in template.volumes:
if 'hostPath' in vol_config: # Skip host path volumes
continue
pvc_name = f"{vol_config['name']}-{instance.instance_id}"
pvc = client.V1PersistentVolumeClaim(
metadata=client.V1ObjectMeta(
name=pvc_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
}
),
spec=client.V1PersistentVolumeClaimSpec(
access_modes=['ReadWriteOnce'],
resources=client.V1ResourceRequirements(
requests={'storage': vol_config['size']}
),
storage_class_name='fast-ssd' # Assuming SSD storage class
)
)
try:
v1.create_namespaced_persistent_volume_claim(
namespace=instance.namespace,
body=pvc
)
logger.info(f"Created PVC: {pvc_name}")
except ApiException as e:
if e.status != 409: # Ignore if already exists
raise
async def _create_service(self, instance: ServiceInstance, template: ServiceTemplate):
"""Create Kubernetes service for the instance"""
if not self.k8s_client:
return
v1 = client.CoreV1Api(self.k8s_client)
service = client.V1Service(
metadata=client.V1ObjectMeta(
name=instance.service_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
}
),
spec=client.V1ServiceSpec(
selector={'instance': instance.instance_id},
ports=[client.V1ServicePort(
port=80,
target_port=template.ports['http'],
protocol='TCP'
)],
type='ClusterIP'
)
)
v1.create_namespaced_service(
namespace=instance.namespace,
body=service
)
logger.info(f"Created service: {instance.service_name}")
async def _create_ingress(self, instance: ServiceInstance, template: ServiceTemplate):
"""Create ingress for external access with TLS"""
if not self.k8s_client:
return
networking_v1 = client.NetworkingV1Api(self.k8s_client)
hostname = f"{template.service_type}.{instance.tenant_id}.gt2.com"
ingress = client.V1Ingress(
metadata=client.V1ObjectMeta(
name=instance.ingress_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
},
annotations={
'kubernetes.io/ingress.class': 'nginx',
'cert-manager.io/cluster-issuer': 'letsencrypt-prod',
'nginx.ingress.kubernetes.io/ssl-redirect': 'true',
'nginx.ingress.kubernetes.io/force-ssl-redirect': 'true',
'nginx.ingress.kubernetes.io/auth-url': f'https://auth.{instance.tenant_id}.gt2.com/auth',
'nginx.ingress.kubernetes.io/auth-signin': f'https://auth.{instance.tenant_id}.gt2.com/signin'
}
),
spec=client.V1IngressSpec(
tls=[client.V1IngressTLS(
hosts=[hostname],
secret_name=f"{template.service_type}-tls-{instance.instance_id}"
)],
rules=[client.V1IngressRule(
host=hostname,
http=client.V1HTTPIngressRuleValue(
paths=[client.V1HTTPIngressPath(
path='/',
path_type='Prefix',
backend=client.V1IngressBackend(
service=client.V1IngressServiceBackend(
name=instance.service_name,
port=client.V1ServiceBackendPort(number=80)
)
)
)]
)
)]
)
)
networking_v1.create_namespaced_ingress(
namespace=instance.namespace,
body=ingress
)
logger.info(f"Created ingress: {instance.ingress_name} for {hostname}")
async def _get_available_port(self) -> int:
"""Get next available port for service"""
used_ports = {instance.external_port for instance in self.active_instances.values()}
port = 30000 # Start from NodePort range
while port in used_ports:
port += 1
return port
async def _generate_sso_token(self, instance: ServiceInstance) -> str:
"""Generate SSO token for iframe embedding"""
token_data = {
'tenant_id': instance.tenant_id,
'service_type': instance.service_type,
'instance_id': instance.instance_id,
'expires_at': (datetime.utcnow() + timedelta(hours=24)).isoformat(),
'permissions': ['read', 'write', 'admin']
}
# Encrypt the token data
encrypted_token = encrypt_data(json.dumps(token_data))
return encrypted_token.decode('utf-8')
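# Consumer-side sketch (illustrative): validating an SSO token before trusting
# an embedded session. Assumes decrypt_data() is the inverse of the
# encrypt_data() call above and returns the original JSON string.
#
# def validate_sso_token(token: str) -> dict:
#     data = json.loads(decrypt_data(token.encode("utf-8")))
#     if datetime.utcnow() > datetime.fromisoformat(data["expires_at"]):
#         raise PermissionError("SSO token expired")
#     return data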
async def get_service_instance(self, instance_id: str) -> Optional[ServiceInstance]:
"""Get service instance by ID"""
return self.active_instances.get(instance_id)
async def list_tenant_instances(self, tenant_id: str) -> List[ServiceInstance]:
"""List all service instances for a tenant"""
return [
instance for instance in self.active_instances.values()
if instance.tenant_id == tenant_id
]
async def stop_service_instance(self, instance_id: str) -> bool:
"""Stop a running service instance"""
instance = self.active_instances.get(instance_id)
if not instance:
return False
try:
instance.status = 'stopping'
if self.k8s_client:
# Delete Kubernetes resources
await self._cleanup_kubernetes_resources(instance)
instance.status = 'stopped'
logger.info(f"Stopped service instance: {instance_id}")
return True
except Exception as e:
logger.error(f"Failed to stop instance {instance_id}: {e}")
instance.status = 'error'
return False
async def _cleanup_kubernetes_resources(self, instance: ServiceInstance):
"""Clean up all Kubernetes resources for an instance"""
if not self.k8s_client:
return
apps_v1 = client.AppsV1Api(self.k8s_client)
v1 = client.CoreV1Api(self.k8s_client)
networking_v1 = client.NetworkingV1Api(self.k8s_client)
try:
# Delete deployment
apps_v1.delete_namespaced_deployment(
name=instance.deployment_name,
namespace=instance.namespace,
body=client.V1DeleteOptions()
)
# Delete service
v1.delete_namespaced_service(
name=instance.service_name,
namespace=instance.namespace,
body=client.V1DeleteOptions()
)
# Delete ingress
networking_v1.delete_namespaced_ingress(
name=instance.ingress_name,
namespace=instance.namespace,
body=client.V1DeleteOptions()
)
# Delete PVCs (optional - may want to preserve data)
# Note: In production, you might want to keep PVCs for data persistence
logger.info(f"Cleaned up Kubernetes resources for: {instance.instance_id}")
except ApiException as e:
logger.error(f"Error cleaning up resources: {e}")
raise
async def get_service_health(self, instance_id: str) -> Dict[str, Any]:
"""Get health status of a service instance"""
instance = self.active_instances.get(instance_id)
if not instance:
return {'status': 'not_found'}
if not self.k8s_client:
return {
'status': 'healthy',
'instance_status': instance.status,
'endpoint': instance.endpoint_url,
'last_check': datetime.utcnow().isoformat()
}
# Check Kubernetes pod status
v1 = client.CoreV1Api(self.k8s_client)
try:
pods = v1.list_namespaced_pod(
namespace=instance.namespace,
label_selector=f'instance={instance.instance_id}'
)
if not pods.items:
return {
'status': 'no_pods',
'instance_status': instance.status
}
pod = pods.items[0]
pod_status = 'unknown'
if pod.status.phase == 'Running':
# Check container status
if pod.status.container_statuses:
container_status = pod.status.container_statuses[0]
if container_status.ready:
pod_status = 'healthy'
else:
pod_status = 'unhealthy'
else:
pod_status = 'starting'
elif pod.status.phase == 'Pending':
pod_status = 'starting'
elif pod.status.phase == 'Failed':
pod_status = 'failed'
# Update instance heartbeat
instance.last_heartbeat = datetime.utcnow()
return {
'status': pod_status,
'instance_status': instance.status,
'pod_phase': pod.status.phase,
'endpoint': instance.endpoint_url,
'last_check': datetime.utcnow().isoformat(),
'restart_count': pod.status.container_statuses[0].restart_count if pod.status.container_statuses else 0
}
except ApiException as e:
logger.error(f"Failed to get health for {instance_id}: {e}")
return {
'status': 'error',
'error': str(e),
'instance_status': instance.status
}
async def _persist_instance(self, instance: ServiceInstance):
"""Persist instance data to disk"""
instance_file = self.storage_path / f"{instance.instance_id}.json"
with open(instance_file, 'w') as f:
json.dump(instance.to_dict(), f, indent=2)
def _load_persistent_instances(self):
"""Load persistent instances from disk on startup"""
if not self.storage_path.exists():
return
for instance_file in self.storage_path.glob("*.json"):
try:
with open(instance_file, 'r') as f:
data = json.load(f)
# Reconstruct instance object
instance = ServiceInstance(
instance_id=data['instance_id'],
tenant_id=data['tenant_id'],
service_type=data['service_type'],
status=data['status'],
endpoint_url=data['endpoint_url'],
internal_port=data['internal_port'],
external_port=data['external_port'],
namespace=data['namespace'],
deployment_name=data['deployment_name'],
service_name=data['service_name'],
ingress_name=data['ingress_name'],
sso_token=data.get('sso_token'),
created_at=datetime.fromisoformat(data['created_at']),
last_heartbeat=datetime.fromisoformat(data['last_heartbeat']),
resource_usage=data.get('resource_usage', {})
)
self.active_instances[instance.instance_id] = instance
logger.info(f"Loaded persistent instance: {instance.instance_id}")
except Exception as e:
logger.error(f"Failed to load instance from {instance_file}: {e}")
async def cleanup_orphaned_resources(self):
"""Clean up orphaned Kubernetes resources"""
if not self.k8s_client:
return
logger.info("Starting cleanup of orphaned resources...")
# This would implement logic to find and clean up:
# 1. Deployments without corresponding instances
# 2. Services without deployments
# 3. Unused PVCs
# 4. Expired certificates
# Implementation would query Kubernetes for resources with GT labels
# and cross-reference with active instances
logger.info("Cleanup completed")