GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents

- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
  - Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
HackWeasel
2025-12-12 17:47:14 -05:00
commit 310491a557
750 changed files with 232701 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
# GT 2.0 Resource Cluster Environment Variables
# Environment
ENVIRONMENT=development
DEBUG=true
# Security
SECRET_KEY=your-secret-key-here-change-in-production
# External LLM Providers
GROQ_API_KEY=your-groq-api-key
OPENAI_API_KEY=your-openai-api-key
ANTHROPIC_API_KEY=your-anthropic-api-key
# Service Ports
SERVICE_PORT=8003
PROMETHEUS_PORT=9091
# Consul Service Discovery (optional)
CONSUL_HOST=localhost
CONSUL_PORT=8500
# Redis
REDIS_URL=redis://localhost:6379/1
# ChromaDB
CHROMADB_HOST=localhost
CHROMADB_PORT=8000
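A minimal sketch of how these variables could be consumed, assuming a pydantic-based settings class (the actual configuration module is not shown in this excerpt; field names and defaults below are illustrative):
# Hypothetical settings loader for the variables above (pydantic v1 style;
# with pydantic v2 the BaseSettings import moves to the pydantic_settings package).
from typing import Optional
from pydantic import BaseSettings

class ResourceClusterSettings(BaseSettings):
    environment: str = "development"
    debug: bool = True
    secret_key: str = "your-secret-key-here-change-in-production"
    groq_api_key: Optional[str] = None
    service_port: int = 8003
    prometheus_port: int = 9091
    redis_url: str = "redis://localhost:6379/1"
    chromadb_host: str = "localhost"
    chromadb_port: int = 8000

settings = ResourceClusterSettings()  # values are read from the process environment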

View File

@@ -0,0 +1,43 @@
# Resource Cluster Dockerfile
FROM python:3.11-slim
# Build arg for dev dependencies (default: false for production)
ARG INSTALL_DEV=false
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements (dev requirements may not exist in production builds)
COPY requirements.txt .
COPY requirements-dev.tx[t] ./
# Install Python dependencies
# Dev dependencies only installed when INSTALL_DEV=true
RUN pip install --no-cache-dir -r requirements.txt && \
if [ "$INSTALL_DEV" = "true" ] && [ -f requirements-dev.txt ]; then \
pip install --no-cache-dir -r requirements-dev.txt; \
fi
# Copy application code
COPY . .
# Create non-root user and data directory
RUN useradd -m -u 1000 appuser && \
mkdir -p /data/resource-cluster && \
chown -R appuser:appuser /app /data
USER appuser
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run the application with multiple workers for production
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]

View File

@@ -0,0 +1,38 @@
# Development Dockerfile for Resource Cluster
# This is separate from production Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
g++ \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements file
COPY requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Create a non-root user for development and set up data directory
RUN useradd -m -u 1000 devuser && \
mkdir -p /data/resource-cluster && \
chown -R devuser:devuser /app /data
USER devuser
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Development command (will be overridden by docker-compose)
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]

View File

@@ -0,0 +1,3 @@
"""
GT 2.0 Resource Cluster - Air-gapped resource management hub
"""

View File

@@ -0,0 +1,3 @@
"""
API endpoints for GT 2.0 Resource Cluster
"""

View File

@@ -0,0 +1,283 @@
"""
Agent orchestration API endpoints
Provides endpoints for:
- Individual agent execution by agent ID
- Agent execution status tracking
- Workflow orchestration
- Capability-based authentication for all operations
"""
from fastapi import APIRouter, HTTPException, Depends, Path
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
from datetime import datetime
import logging
import uuid
import asyncio
from app.core.security import capability_validator, CapabilityToken
from app.api.auth import verify_capability
router = APIRouter()
logger = logging.getLogger(__name__)
class AgentExecutionRequest(BaseModel):
"""Agent execution request for specific agent"""
input_data: Dict[str, Any] = Field(..., description="Input data for the agent")
parameters: Optional[Dict[str, Any]] = Field(default={}, description="Execution parameters")
timeout_seconds: Optional[int] = Field(default=300, description="Execution timeout")
priority: Optional[int] = Field(default=0, description="Execution priority")
class AgentExecutionResponse(BaseModel):
"""Agent execution response"""
execution_id: str = Field(..., description="Unique execution identifier")
agent_id: str = Field(..., description="Agent identifier")
status: str = Field(..., description="Execution status")
created_at: datetime = Field(..., description="Creation timestamp")
class AgentExecutionStatus(BaseModel):
"""Agent execution status"""
execution_id: str = Field(..., description="Execution identifier")
agent_id: str = Field(..., description="Agent identifier")
status: str = Field(..., description="Current status")
progress: Optional[float] = Field(default=None, description="Execution progress (0-100)")
result: Optional[Dict[str, Any]] = Field(default=None, description="Execution result if completed")
error: Optional[str] = Field(default=None, description="Error message if failed")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
completed_at: Optional[datetime] = Field(default=None, description="Completion timestamp")
# Global execution tracking
_active_executions: Dict[str, Dict[str, Any]] = {}
class AgentRequest(BaseModel):
"""Legacy agent execution request for backward compatibility"""
agent_type: str = Field(..., description="Type of agent to execute")
task: str = Field(..., description="Task for the agent")
context: Dict[str, Any] = Field(default={}, description="Additional context")
@router.post("/execute")
async def execute_agent(
request: AgentRequest,
token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
"""Execute an agent workflow using the legacy request format"""
try:
from app.services.agent_orchestrator import AgentOrchestrator
# Initialize orchestrator
orchestrator = AgentOrchestrator()
# Build the workflow config from the legacy request fields (agent_type, task, context)
input_data = {"task": request.task, "context": request.context}
workflow_config = {
"type": "sequential",
"agents": [request.agent_type],
"input_data": input_data,
"configuration": {}
}
# Generate unique workflow ID (uuid is imported at module level)
workflow_id = f"workflow_{uuid.uuid4().hex[:8]}"
# Create and register workflow
workflow = await orchestrator.create_workflow(workflow_id, workflow_config)
# Execute the workflow
result = await orchestrator.execute_workflow(
workflow_id=workflow_id,
input_data=input_data,
capability_token=token.token
)
# codeql[py/stack-trace-exposure] returns workflow result dict, not error details
return {
"success": True,
"workflow_id": workflow_id,
"result": result,
"execution_time": result.get("execution_time", 0)
}
except ValueError as e:
logger.warning(f"Invalid agent request: {e}")
raise HTTPException(status_code=400, detail="Invalid request parameters")
except Exception as e:
logger.error(f"Agent execution failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Agent execution failed")
@router.post("/{agent_id}/execute", response_model=AgentExecutionResponse)
async def execute_agent_by_id(
agent_id: str = Path(..., description="Agent identifier"),
request: AgentExecutionRequest = ...,
token: CapabilityToken = Depends(verify_capability)
) -> AgentExecutionResponse:
"""Execute a specific agent by ID"""
try:
# Generate unique execution ID
execution_id = f"exec_{uuid.uuid4().hex[:12]}"
# Create execution record
execution_data = {
"execution_id": execution_id,
"agent_id": agent_id,
"status": "queued",
"input_data": request.input_data,
"parameters": request.parameters or {},
"timeout_seconds": request.timeout_seconds,
"priority": request.priority,
"created_at": datetime.utcnow(),
"updated_at": datetime.utcnow(),
"token": token.token
}
# Store execution
_active_executions[execution_id] = execution_data
# Start async execution
asyncio.create_task(_execute_agent_async(execution_id, agent_id, request, token))
logger.info(f"Started agent execution {execution_id} for agent {agent_id}")
return AgentExecutionResponse(
execution_id=execution_id,
agent_id=agent_id,
status="queued",
created_at=execution_data["created_at"]
)
except Exception as e:
logger.error(f"Failed to start agent execution: {e}")
raise HTTPException(status_code=500, detail="Failed to start agent execution")
@router.get("/executions/{execution_id}", response_model=AgentExecutionStatus)
async def get_execution_status(
execution_id: str = Path(..., description="Execution identifier"),
token: CapabilityToken = Depends(verify_capability)
) -> AgentExecutionStatus:
"""Get agent execution status"""
if execution_id not in _active_executions:
raise HTTPException(status_code=404, detail="Execution not found")
execution = _active_executions[execution_id]
return AgentExecutionStatus(
execution_id=execution_id,
agent_id=execution["agent_id"],
status=execution["status"],
progress=execution.get("progress"),
result=execution.get("result"),
error=execution.get("error"),
created_at=execution["created_at"],
updated_at=execution["updated_at"],
completed_at=execution.get("completed_at")
)
async def _execute_agent_async(execution_id: str, agent_id: str, request: AgentExecutionRequest, token: CapabilityToken):
"""Execute agent asynchronously"""
try:
# Update status to running
_active_executions[execution_id].update({
"status": "running",
"updated_at": datetime.utcnow(),
"progress": 0.0
})
# Simulate agent execution - replace with real agent orchestrator
await asyncio.sleep(0.5) # Initial setup
_active_executions[execution_id]["progress"] = 25.0
await asyncio.sleep(1.0) # Processing
_active_executions[execution_id]["progress"] = 50.0
await asyncio.sleep(1.0) # Generating result
_active_executions[execution_id]["progress"] = 75.0
# Simulate successful completion
result = {
"agent_id": agent_id,
"output": f"Agent {agent_id} completed successfully",
"processed_data": request.input_data,
"execution_time_seconds": 2.5,
"tokens_used": 150,
"cost": 0.002
}
# Update to completed
_active_executions[execution_id].update({
"status": "completed",
"progress": 100.0,
"result": result,
"updated_at": datetime.utcnow(),
"completed_at": datetime.utcnow()
})
logger.info(f"Agent execution {execution_id} completed successfully")
except asyncio.TimeoutError:
_active_executions[execution_id].update({
"status": "timeout",
"error": "Execution timeout",
"updated_at": datetime.utcnow(),
"completed_at": datetime.utcnow()
})
logger.error(f"Agent execution {execution_id} timed out")
except Exception as e:
_active_executions[execution_id].update({
"status": "failed",
"error": str(e),
"updated_at": datetime.utcnow(),
"completed_at": datetime.utcnow()
})
logger.error(f"Agent execution {execution_id} failed: {e}")
@router.get("/")
async def list_available_agents(
token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
"""List available agents for execution"""
# Return available agents - replace with real agent registry
available_agents = {
"coding_assistant": {
"id": "coding_assistant",
"name": "Coding Agent",
"description": "AI agent specialized in code generation and review",
"capabilities": ["code_generation", "code_review", "debugging"],
"status": "available"
},
"research_agent": {
"id": "research_agent",
"name": "Research Agent",
"description": "AI agent for information gathering and analysis",
"capabilities": ["web_search", "document_analysis", "summarization"],
"status": "available"
},
"data_analyst": {
"id": "data_analyst",
"name": "Data Analyst",
"description": "AI agent for data analysis and visualization",
"capabilities": ["data_processing", "visualization", "statistical_analysis"],
"status": "available"
}
}
return {
"agents": available_agents,
"total_count": len(available_agents),
"available_count": len([a for a in available_agents.values() if a["status"] == "available"])
}
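Usage sketch for the endpoints above, assuming the router is mounted under /api/v1/agents; the host, agent id, and token are placeholders, not values defined in this commit:
# Hedged client sketch: queue an execution, then poll its status.
import asyncio
import httpx

async def run_agent_demo() -> dict:
    headers = {"Authorization": "Bearer <capability-token>"}
    async with httpx.AsyncClient(base_url="http://localhost:8000/api/v1/agents") as client:
        # Queue an execution for a specific agent
        resp = await client.post(
            "/coding_assistant/execute",
            json={"input_data": {"task": "say hello"}, "timeout_seconds": 60},
            headers=headers,
        )
        execution_id = resp.json()["execution_id"]
        # Poll the status endpoint until the execution finishes
        while True:
            status = (await client.get(f"/executions/{execution_id}", headers=headers)).json()
            if status["status"] in {"completed", "failed", "timeout"}:
                return status
            await asyncio.sleep(0.5)

# asyncio.run(run_agent_demo())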

View File

@@ -0,0 +1,20 @@
"""
Authentication utilities for API endpoints
"""
from fastapi import HTTPException, Header
from app.core.security import capability_validator, CapabilityToken
async def verify_capability(authorization: str = Header(None)) -> CapabilityToken:
"""Verify capability token from Authorization header"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing or invalid authorization header")
token_str = authorization.replace("Bearer ", "")
token = capability_validator.verify_capability_token(token_str)
if not token:
raise HTTPException(status_code=401, detail="Invalid capability token")
return token
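A small test sketch for this dependency, assuming pytest with the pytest-asyncio plugin (neither is included in this commit):
# Hypothetical test: verify_capability rejects requests without a Bearer token.
import pytest
from fastapi import HTTPException
from app.api.auth import verify_capability

@pytest.mark.asyncio
async def test_missing_authorization_header_is_rejected():
    with pytest.raises(HTTPException) as exc_info:
        await verify_capability(authorization=None)
    assert exc_info.value.status_code == 401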

View File

@@ -0,0 +1,333 @@
"""
Embedding Generation API Endpoints for GT 2.0 Resource Cluster
Provides OpenAI-compatible embedding API with:
- BGE-M3 model integration
- Capability-based authentication
- Rate limiting and quota management
- Batch processing support
- Stateless operation
GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Per-request capability validation
- Zero Downtime: Stateless design, no persistent state
- Self-Contained Security: JWT capability tokens
"""
from fastapi import APIRouter, HTTPException, Depends, Header, Request
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import logging
import asyncio
from datetime import datetime
from app.core.security import capability_validator, CapabilityToken
from app.api.auth import verify_capability
from app.services.embedding_service import get_embedding_service, EmbeddingService
from app.core.capability_auth import CapabilityError
router = APIRouter()
logger = logging.getLogger(__name__)
# OpenAI-compatible request/response models
class EmbeddingRequest(BaseModel):
"""OpenAI-compatible embedding request"""
input: List[str] = Field(..., description="List of texts to embed")
model: str = Field(default="BAAI/bge-m3", description="Embedding model name")
encoding_format: str = Field(default="float", description="Encoding format (float)")
dimensions: Optional[int] = Field(None, description="Number of dimensions (auto-detected)")
user: Optional[str] = Field(None, description="User identifier")
# BGE-M3 specific parameters
instruction: Optional[str] = Field(None, description="Instruction for query/document context")
normalize: bool = Field(True, description="Normalize embeddings to unit length")
class EmbeddingData(BaseModel):
"""Single embedding data object"""
object: str = "embedding"
embedding: List[float] = Field(..., description="Embedding vector")
index: int = Field(..., description="Index of the embedding in the input")
class EmbeddingUsage(BaseModel):
"""Token usage information"""
prompt_tokens: int = Field(..., description="Tokens in the input")
total_tokens: int = Field(..., description="Total tokens processed")
class EmbeddingResponse(BaseModel):
"""OpenAI-compatible embedding response"""
object: str = "list"
data: List[EmbeddingData] = Field(..., description="List of embedding objects")
model: str = Field(..., description="Model used for embeddings")
usage: EmbeddingUsage = Field(..., description="Token usage information")
# GT 2.0 specific metadata
gt2_metadata: Dict[str, Any] = Field(default_factory=dict, description="GT 2.0 processing metadata")
class EmbeddingModelInfo(BaseModel):
"""Embedding model information"""
model_name: str
dimensions: int
max_sequence_length: int
max_batch_size: int
supports_instruction: bool
normalization_default: bool
class ServiceHealthResponse(BaseModel):
"""Service health response"""
status: str
service: str
model: str
backend_ready: bool
last_request: Optional[str]
class BGE_M3_ConfigRequest(BaseModel):
"""BGE-M3 configuration update request"""
is_local_mode: bool = True
external_endpoint: Optional[str] = None
class BGE_M3_ConfigResponse(BaseModel):
"""BGE-M3 configuration response"""
is_local_mode: bool
current_endpoint: str
external_endpoint: Optional[str]
message: str
@router.post("/", response_model=EmbeddingResponse)
async def create_embeddings(
request: EmbeddingRequest,
token: CapabilityToken = Depends(verify_capability),
x_request_id: Optional[str] = Header(None)
) -> EmbeddingResponse:
"""
Generate embeddings for input texts using BGE-M3 model.
Compatible with OpenAI Embeddings API format.
Requires capability token with 'embeddings' permissions.
"""
try:
# Get embedding service
embedding_service = get_embedding_service()
# Generate embeddings
result = await embedding_service.generate_embeddings(
texts=request.input,
capability_token=token.token, # Pass raw token for verification
instruction=request.instruction,
request_id=x_request_id,
normalize=request.normalize
)
# Convert to OpenAI-compatible format
embedding_data = [
EmbeddingData(
embedding=embedding,
index=i
)
for i, embedding in enumerate(result.embeddings)
]
usage = EmbeddingUsage(
prompt_tokens=result.tokens_used,
total_tokens=result.tokens_used
)
response = EmbeddingResponse(
data=embedding_data,
model=result.model,
usage=usage,
gt2_metadata={
"request_id": result.request_id,
"tenant_id": result.tenant_id,
"processing_time_ms": result.processing_time_ms,
"dimensions": result.dimensions,
"created_at": result.created_at
}
)
logger.info(
f"Generated {len(result.embeddings)} embeddings "
f"for tenant {result.tenant_id} in {result.processing_time_ms}ms"
)
return response
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
logger.warning(f"Invalid request: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/models", response_model=EmbeddingModelInfo)
async def get_model_info(
token: CapabilityToken = Depends(verify_capability)
) -> EmbeddingModelInfo:
"""
Get information about the embedding model.
Requires capability token with 'embeddings' permissions.
"""
try:
embedding_service = get_embedding_service()
model_info = await embedding_service.get_model_info()
return EmbeddingModelInfo(**model_info)
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error getting model info: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/stats")
async def get_service_stats(
token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
"""
Get embedding service statistics.
Requires capability token with 'admin' permissions.
"""
try:
embedding_service = get_embedding_service()
stats = await embedding_service.get_service_stats(token.token)
return stats
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error getting service stats: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/health", response_model=ServiceHealthResponse)
async def health_check() -> ServiceHealthResponse:
"""
Check embedding service health.
Public endpoint - no authentication required.
"""
try:
embedding_service = get_embedding_service()
health = await embedding_service.health_check()
return ServiceHealthResponse(**health)
except Exception as e:
logger.error(f"Health check failed: {e}")
raise HTTPException(status_code=500, detail="Service unhealthy")
@router.post("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
async def update_bge_m3_config(
config_request: BGE_M3_ConfigRequest,
token: CapabilityToken = Depends(verify_capability)
) -> BGE_M3_ConfigResponse:
"""
Update BGE-M3 configuration for the embedding service.
This allows switching between local and external endpoints at runtime.
Requires capability token with 'admin' permissions.
"""
try:
# Verify admin permissions
if not token.payload.get("admin", False):
raise HTTPException(status_code=403, detail="Admin permissions required")
embedding_service = get_embedding_service()
# Update the embedding backend configuration
backend = embedding_service.backend
await backend.update_endpoint_config(
is_local_mode=config_request.is_local_mode,
external_endpoint=config_request.external_endpoint
)
logger.info(
f"BGE-M3 configuration updated by {token.payload.get('tenant_id', 'unknown')}: "
f"local_mode={config_request.is_local_mode}, "
f"external_endpoint={config_request.external_endpoint}"
)
return BGE_M3_ConfigResponse(
is_local_mode=config_request.is_local_mode,
current_endpoint=backend.embedding_endpoint,
external_endpoint=config_request.external_endpoint,
message=f"BGE-M3 configuration updated to {'local' if config_request.is_local_mode else 'external'} mode"
)
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error updating BGE-M3 config: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
async def get_bge_m3_config(
token: CapabilityToken = Depends(verify_capability)
) -> BGE_M3_ConfigResponse:
"""
Get current BGE-M3 configuration.
Requires capability token with 'embeddings' permissions.
"""
try:
embedding_service = get_embedding_service()
backend = embedding_service.backend
# Determine if currently in local mode
is_local_mode = "gentwo-vllm-embeddings" in backend.embedding_endpoint
return BGE_M3_ConfigResponse(
is_local_mode=is_local_mode,
current_endpoint=backend.embedding_endpoint,
external_endpoint=None, # We don't store this currently
message="Current BGE-M3 configuration"
)
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error getting BGE-M3 config: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
# Legacy endpoint compatibility
@router.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings_legacy(
request: EmbeddingRequest,
token: CapabilityToken = Depends(verify_capability),
x_request_id: Optional[str] = Header(None)
) -> EmbeddingResponse:
"""
Legacy endpoint for embedding generation.
Redirects to main embedding endpoint for compatibility.
"""
return await create_embeddings(request, token, x_request_id)
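Hedged client sketch for the OpenAI-compatible endpoint above; the service URL, mount prefix, and token are placeholders rather than values fixed by this commit:
# Request a batch of BGE-M3 embeddings and unpack the OpenAI-style response.
import httpx

payload = {
    "input": ["GT 2.0 resource cluster", "BGE-M3 embeddings"],
    "model": "BAAI/bge-m3",
    "normalize": True,
}
resp = httpx.post(
    "http://localhost:8000/api/v1/embeddings",
    json=payload,
    headers={"Authorization": "Bearer <capability-token>"},
    timeout=30.0,
)
resp.raise_for_status()
data = resp.json()
vectors = [item["embedding"] for item in data["data"]]
print(len(vectors), data["usage"]["total_tokens"])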

View File

@@ -0,0 +1,58 @@
"""
Health check endpoints for Resource Cluster
"""
from fastapi import APIRouter, HTTPException
from typing import Dict, Any
import logging
from app.core.backends import get_backend
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/")
async def health_check() -> Dict[str, Any]:
"""Basic health check"""
return {
"status": "healthy",
"service": "resource-cluster"
}
@router.get("/ready")
async def readiness_check() -> Dict[str, Any]:
"""Readiness check for Kubernetes"""
try:
# Check if critical backends are initialized
groq_backend = get_backend("groq_proxy")
return {
"status": "ready",
"backends": {
"groq_proxy": groq_backend is not None
}
}
except Exception as e:
logger.error(f"Readiness check failed: {e}")
raise HTTPException(status_code=503, detail="Service not ready")
@router.get("/backends")
async def backend_health() -> Dict[str, Any]:
"""Check health of all resource backends"""
health_status = {}
try:
# Check Groq backend
groq_backend = get_backend("groq_proxy")
groq_health = await groq_backend.check_health()
health_status["groq"] = groq_health
except Exception as e:
health_status["groq"] = {"error": str(e)}
return {
"status": "operational",
"backends": health_status
}
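A short monitoring sketch against the backend health endpoint above (host and mount prefix are placeholders):
# Flag any backend that reported an error in the health payload.
import httpx

report = httpx.get("http://localhost:8000/health/backends", timeout=5.0).json()
unhealthy = {name: info for name, info in report["backends"].items() if "error" in info}
print("unhealthy backends:", unhealthy or "none")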

View File

@@ -0,0 +1,231 @@
"""
LLM Inference API endpoints
Provides capability-based access to LLM models with:
- Token validation and capability checking
- Multiple model support (Groq, OpenAI, Anthropic)
- Streaming and non-streaming responses
- Usage tracking and cost calculation
"""
from fastapi import APIRouter, HTTPException, Depends, Header, Request
from fastapi.responses import StreamingResponse
from typing import Dict, Any, Optional, List, Union
from pydantic import BaseModel, Field
import logging
from app.core.security import capability_validator, CapabilityToken
from app.core.backends import get_backend
from app.api.auth import verify_capability
from app.services.model_router import get_model_router
router = APIRouter()
logger = logging.getLogger(__name__)
class InferenceRequest(BaseModel):
"""LLM inference request supporting both prompt and messages format"""
prompt: Optional[str] = Field(default=None, description="Input prompt for the model")
messages: Optional[list] = Field(default=None, description="Conversation messages in OpenAI format")
model: str = Field(default="llama-3.1-70b-versatile", description="Model identifier")
temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
max_tokens: int = Field(default=4000, ge=1, le=32000, description="Maximum tokens to generate")
stream: bool = Field(default=False, description="Enable streaming response")
system_prompt: Optional[str] = Field(default=None, description="System prompt for context")
tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Available tools for function calling")
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(default=None, description="Tool choice strategy")
user_id: Optional[str] = Field(default=None, description="User identifier for tenant isolation")
tenant_id: Optional[str] = Field(default=None, description="Tenant identifier for isolation")
class InferenceResponse(BaseModel):
"""LLM inference response"""
content: str = Field(..., description="Generated text")
model: str = Field(..., description="Model used")
usage: Dict[str, Any] = Field(..., description="Token usage and cost information")
latency_ms: float = Field(..., description="Inference latency in milliseconds")
@router.post("/", response_model=InferenceResponse)
async def execute_inference(
request: InferenceRequest,
token: CapabilityToken = Depends(verify_capability)
) -> InferenceResponse:
"""Execute LLM inference with capability checking"""
# Validate request format
if not request.prompt and not request.messages:
raise HTTPException(
status_code=400,
detail="Either 'prompt' or 'messages' must be provided"
)
# Check if user has access to the requested model
resource = f"llm:{request.model.replace('-', '_')}"
if not capability_validator.check_resource_access(token, resource, "inference"):
# Try generic LLM access
if not capability_validator.check_resource_access(token, "llm:*", "inference"):
# Try groq specific access
if not capability_validator.check_resource_access(token, "llm:groq", "inference"):
raise HTTPException(
status_code=403,
detail=f"No capability for model: {request.model}"
)
# Get resource limits from token
limits = capability_validator.get_resource_limits(token, resource)
# Apply token limits
max_tokens = min(
request.max_tokens,
limits.get("max_tokens_per_request", request.max_tokens)
)
# Ensure tenant isolation
user_id = request.user_id or token.sub
tenant_id = request.tenant_id or token.tenant_id
try:
# Get model router for tenant
model_router = await get_model_router(tenant_id)
# Prepare prompt for routing
prompt = request.prompt
if request.system_prompt and prompt:
prompt = f"{request.system_prompt}\n\n{prompt}"
# Route inference request to appropriate provider
result = await model_router.route_inference(
model_id=request.model,
prompt=prompt,
messages=request.messages,
temperature=request.temperature,
max_tokens=max_tokens,
stream=False,
user_id=user_id,
tenant_id=tenant_id,
tools=request.tools,
tool_choice=request.tool_choice
)
return InferenceResponse(**result)
except Exception as e:
logger.error(f"Inference error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
@router.post("/stream")
async def stream_inference(
request: InferenceRequest,
token: CapabilityToken = Depends(verify_capability)
):
"""Stream LLM inference responses"""
# Validate request format
if not request.prompt and not request.messages:
raise HTTPException(
status_code=400,
detail="Either 'prompt' or 'messages' must be provided"
)
# Check streaming capability
resource = f"llm:{request.model.replace('-', '_')}"
if not capability_validator.check_resource_access(token, resource, "streaming"):
if not capability_validator.check_resource_access(token, "llm:*", "streaming"):
if not capability_validator.check_resource_access(token, "llm:groq", "streaming"):
raise HTTPException(
status_code=403,
detail="No streaming capability for this model"
)
# Ensure tenant isolation
user_id = request.user_id or token.sub
tenant_id = request.tenant_id or token.tenant_id
try:
# Get model router for tenant
model_router = await get_model_router(tenant_id)
# Prepare prompt for routing
prompt = request.prompt
if request.system_prompt and prompt:
prompt = f"{request.system_prompt}\n\n{prompt}"
# For now, fall back to groq backend for streaming (TODO: implement streaming in model router)
backend = get_backend("groq_proxy")
# Handle different request formats
if request.messages:
# Use messages format for streaming
async def generate():
async for chunk in backend._stream_inference_with_messages(
messages=request.messages,
model=request.model,
temperature=request.temperature,
max_tokens=request.max_tokens,
user_id=user_id,
tenant_id=tenant_id
):
yield f"data: {chunk}\n\n"
yield "data: [DONE]\n\n"
else:
# Use prompt format for streaming
async def generate():
async for chunk in backend._stream_inference(
messages=[{"role": "user", "content": prompt}],
model=request.model,
temperature=request.temperature,
max_tokens=request.max_tokens,
user_id=user_id,
tenant_id=tenant_id
):
yield f"data: {chunk}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Disable nginx buffering
}
)
except Exception as e:
logger.error(f"Streaming inference error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
@router.get("/models")
async def list_available_models(
token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
"""List available models based on user capabilities"""
try:
# Get model router for token's tenant
tenant_id = getattr(token, 'tenant_id', None)
model_router = await get_model_router(tenant_id)
# Get all available models from registry
all_models = await model_router.list_available_models()
# Filter based on user capabilities
accessible_models = []
for model in all_models:
resource = f"llm:{model['id'].replace('-', '_')}"
if capability_validator.check_resource_access(token, resource, "inference"):
accessible_models.append(model)
elif capability_validator.check_resource_access(token, "llm:*", "inference"):
accessible_models.append(model)
return {
"models": accessible_models,
"total": len(accessible_models)
}
except Exception as e:
logger.error(f"Error listing models: {e}")
raise HTTPException(status_code=500, detail="Failed to list models")
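Hedged request sketch for the non-streaming inference endpoint; the path prefix, model id, and token are placeholders, not values defined in this commit:
# Send a messages-format request and read the InferenceResponse fields.
import httpx

body = {
    "messages": [{"role": "user", "content": "Summarize GT 2.0 in one sentence."}],
    "model": "llama-3.1-70b-versatile",
    "temperature": 0.2,
    "max_tokens": 256,
}
resp = httpx.post(
    "http://localhost:8000/api/v1/inference/",
    json=body,
    headers={"Authorization": "Bearer <capability-token>"},
    timeout=60.0,
)
resp.raise_for_status()
result = resp.json()
print(result["content"], result["usage"], result["latency_ms"])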

View File

@@ -0,0 +1,91 @@
"""
Internal API endpoints for service-to-service communication.
These endpoints are used by Control Panel to notify Resource Cluster
of configuration changes that require cache invalidation.
"""
from fastapi import APIRouter, Header, HTTPException, status
from typing import Optional
import logging
from app.core.config import get_settings
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/internal", tags=["Internal"])
settings = get_settings()
async def verify_service_auth(
x_service_auth: str = Header(None),
x_service_name: str = Header(None)
) -> bool:
"""Verify service-to-service authentication"""
if not x_service_auth or not x_service_name:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Service authentication required"
)
expected_token = settings.service_auth_token or "internal-service-token"
if x_service_auth != expected_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid service authentication"
)
allowed_services = ["control-panel-backend", "control-panel"]
if x_service_name not in allowed_services:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Service {x_service_name} not authorized"
)
return True
@router.post("/cache/api-keys/invalidate")
async def invalidate_api_key_cache(
tenant_domain: Optional[str] = None,
provider: Optional[str] = None,
x_service_auth: str = Header(None),
x_service_name: str = Header(None)
):
"""
Invalidate cached API keys.
Called by Control Panel when API keys are added, updated, or removed.
Args:
tenant_domain: If provided, only invalidate for this tenant
provider: If provided with tenant_domain, only invalidate this provider
"""
await verify_service_auth(x_service_auth, x_service_name)
from app.clients.api_key_client import get_api_key_client
client = get_api_key_client()
await client.invalidate_cache(tenant_domain=tenant_domain, provider=provider)
logger.info(
f"Cache invalidated: tenant={tenant_domain or 'all'}, provider={provider or 'all'}"
)
return {
"success": True,
"message": f"Cache invalidated for tenant={tenant_domain or 'all'}, provider={provider or 'all'}"
}
@router.get("/cache/api-keys/stats")
async def get_api_key_cache_stats(
x_service_auth: str = Header(None),
x_service_name: str = Header(None)
):
"""Get API key cache statistics for monitoring"""
await verify_service_auth(x_service_auth, x_service_name)
from app.clients.api_key_client import get_api_key_client
client = get_api_key_client()
return client.get_cache_stats()
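A sketch of the Control Panel side of this call, with a placeholder host and shared token (the real values come from deployment configuration, not this commit):
# Invalidate the cached Groq key for one tenant via the internal API.
import httpx

httpx.post(
    "http://resource-cluster:8000/internal/cache/api-keys/invalidate",
    params={"tenant_domain": "acme.example", "provider": "groq"},
    headers={
        "X-Service-Auth": "<shared-service-token>",
        "X-Service-Name": "control-panel-backend",
    },
    timeout=10.0,
)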

View File

@@ -0,0 +1,366 @@
"""
LLM API endpoints for GT 2.0 Resource Cluster
Provides OpenAI-compatible API for LLM inference with:
- Multi-provider routing (Groq, OpenAI, Anthropic)
- Capability-based authentication
- Rate limiting and quota management
- Response streaming support
- Model availability management
GT 2.0 Security Features:
- JWT capability token authentication
- Tenant isolation in all operations
- No persistent state stored
- Stateless request processing
"""
import asyncio
import logging
from typing import Dict, Any, Optional
from datetime import datetime, timezone
from fastapi import APIRouter, HTTPException, Depends, Header, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from app.core.capability_auth import verify_capability_token, get_current_capability
from app.services.llm_gateway import get_llm_gateway, LLMRequest, LLMGateway
logger = logging.getLogger(__name__)
router = APIRouter(tags=["llm"])
class ChatCompletionRequest(BaseModel):
"""OpenAI-compatible chat completion request"""
model: str = Field(..., description="Model to use for completion")
messages: list = Field(..., description="List of messages")
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling parameter")
frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Frequency penalty")
presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Presence penalty")
stop: Optional[list] = Field(None, description="Stop sequences")
stream: bool = Field(False, description="Whether to stream the response")
functions: Optional[list] = Field(None, description="Available functions for function calling")
function_call: Optional[Dict[str, Any]] = Field(None, description="Function call configuration")
user: Optional[str] = Field(None, description="User identifier for tracking")
class ModelListResponse(BaseModel):
"""Response for model list endpoint"""
object: str = "list"
data: list = Field(..., description="List of available models")
@router.post("/chat/completions")
async def create_chat_completion(
request: ChatCompletionRequest,
authorization: str = Header(..., description="Bearer token"),
capability_payload: Dict[str, Any] = Depends(get_current_capability),
gateway: LLMGateway = Depends(get_llm_gateway)
):
"""
Create a chat completion using the specified model.
Compatible with OpenAI API format for easy integration.
"""
try:
# Extract capability token from Authorization header
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
capability_token = authorization[7:] # Remove "Bearer " prefix
# Get user and tenant from capability payload
user_id = capability_payload.get("sub", "unknown")
tenant_id = capability_payload.get("tenant_id", "unknown")
# Create internal LLM request
llm_request = LLMRequest(
model=request.model,
messages=request.messages,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
frequency_penalty=request.frequency_penalty,
presence_penalty=request.presence_penalty,
stop=request.stop,
stream=request.stream,
functions=request.functions,
function_call=request.function_call,
user=request.user or user_id
)
# Process request through gateway
result = await gateway.chat_completion(
request=llm_request,
capability_token=capability_token,
user_id=user_id,
tenant_id=tenant_id
)
# Handle streaming vs non-streaming response
if request.stream:
# codeql[py/stack-trace-exposure] returns LLM response stream, not error details
return StreamingResponse(
result,
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "text/plain; charset=utf-8"
}
)
else:
return JSONResponse(content=result.to_dict())
except ValueError as e:
logger.warning(f"Invalid LLM request: {e}")
raise HTTPException(status_code=400, detail="Invalid request parameters")
except PermissionError as e:
logger.warning(f"Permission denied for LLM request: {e}")
raise HTTPException(status_code=403, detail="Permission denied")
except Exception as e:
logger.error(f"LLM request failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/models", response_model=ModelListResponse)
async def list_models(
capability_payload: Dict[str, Any] = Depends(get_current_capability),
gateway: LLMGateway = Depends(get_llm_gateway)
):
"""
List available models.
Returns models available to the user based on their capabilities.
"""
try:
# Get all available models
models = await gateway.get_available_models()
# Filter models based on user capabilities
user_capabilities = capability_payload.get("capabilities", [])
llm_capability = None
for cap in user_capabilities:
if cap.get("resource") == "llm":
llm_capability = cap
break
if llm_capability:
allowed_models = llm_capability.get("constraints", {}).get("allowed_models", [])
if allowed_models:
models = [model for model in models if model["id"] in allowed_models]
# Format response to match OpenAI API
formatted_models = []
for model in models:
formatted_models.append({
"id": model["id"],
"object": "model",
"created": int(datetime.now(timezone.utc).timestamp()),
"owned_by": f"gt2-{model['provider']}",
"permission": [],
"root": model["id"],
"parent": None,
"max_tokens": model["max_tokens"],
"context_window": model["context_window"],
"capabilities": model["capabilities"],
"supports_streaming": model["supports_streaming"],
"supports_functions": model["supports_functions"]
})
return ModelListResponse(data=formatted_models)
except Exception as e:
logger.error(f"Failed to list models: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve models")
@router.get("/models/{model_id}")
async def get_model(
model_id: str,
capability_payload: Dict[str, Any] = Depends(get_current_capability),
gateway: LLMGateway = Depends(get_llm_gateway)
):
"""
Get information about a specific model.
"""
try:
models = await gateway.get_available_models()
# Find the requested model
model = next((m for m in models if m["id"] == model_id), None)
if not model:
raise HTTPException(status_code=404, detail="Model not found")
# Check if user has access to this model
user_capabilities = capability_payload.get("capabilities", [])
llm_capability = None
for cap in user_capabilities:
if cap.get("resource") == "llm":
llm_capability = cap
break
if llm_capability:
allowed_models = llm_capability.get("constraints", {}).get("allowed_models", [])
if allowed_models and model_id not in allowed_models:
raise HTTPException(status_code=403, detail="Access to model not allowed")
# Format response
return {
"id": model["id"],
"object": "model",
"created": int(datetime.now(timezone.utc).timestamp()),
"owned_by": f"gt2-{model['provider']}",
"permission": [],
"root": model["id"],
"parent": None,
"max_tokens": model["max_tokens"],
"context_window": model["context_window"],
"capabilities": model["capabilities"],
"supports_streaming": model["supports_streaming"],
"supports_functions": model["supports_functions"]
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get model {model_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve model")
@router.get("/stats")
async def get_gateway_stats(
capability_payload: Dict[str, Any] = Depends(get_current_capability),
gateway: LLMGateway = Depends(get_llm_gateway)
):
"""
Get LLM gateway statistics.
Requires admin capability for detailed stats.
"""
try:
# Check if user has admin capabilities
user_capabilities = capability_payload.get("capabilities", [])
has_admin = any(
cap.get("resource") == "admin"
for cap in user_capabilities
)
stats = await gateway.get_gateway_stats()
if has_admin:
# Return full stats for admins
return stats
else:
# Return limited stats for regular users
return {
"total_requests": stats["total_requests"],
"success_rate": (
stats["successful_requests"] / max(stats["total_requests"], 1)
) * 100,
"available_models": len([
model for model in await gateway.get_available_models()
]),
"timestamp": stats["timestamp"]
}
except Exception as e:
logger.error(f"Failed to get gateway stats: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve statistics")
@router.get("/health")
async def health_check(
gateway: LLMGateway = Depends(get_llm_gateway)
):
"""
Health check endpoint for the LLM gateway.
Public endpoint for load balancer health checks.
"""
try:
health = await gateway.health_check()
if health["status"] == "healthy":
return JSONResponse(content=health, status_code=200)
else:
return JSONResponse(content=health, status_code=503)
except Exception as e:
logger.error(f"Health check failed: {e}")
return JSONResponse(
content={
"status": "error",
"error": "Health check failed",
"timestamp": datetime.now(timezone.utc).isoformat()
},
status_code=503
)
# Provider-specific endpoints for debugging and monitoring
@router.post("/providers/groq/test")
async def test_groq_connection(
capability_payload: Dict[str, Any] = Depends(get_current_capability),
gateway: LLMGateway = Depends(get_llm_gateway)
):
"""
Test connection to Groq API.
Requires admin capability.
"""
try:
# Check admin capability
user_capabilities = capability_payload.get("capabilities", [])
has_admin = any(
cap.get("resource") == "admin"
for cap in user_capabilities
)
if not has_admin:
raise HTTPException(status_code=403, detail="Admin capability required")
# Test simple request to Groq
test_request = LLMRequest(
model="llama3-8b-8192",
messages=[{"role": "user", "content": "Hello, this is a test."}],
max_tokens=10,
stream=False
)
# Use system capability token for testing
# TODO: Generate system token or use admin token
capability_token = "system-test-token"
user_id = "system-test"
tenant_id = "system"
result = await gateway._process_groq_request(
test_request,
"test-request-id",
gateway.models["llama3-8b-8192"]
)
return {
"status": "success",
"provider": "groq",
"response_received": bool(result),
"timestamp": datetime.now(timezone.utc).isoformat()
}
except Exception as e:
logger.error(f"Groq connection test failed: {e}")
return JSONResponse(
content={
"status": "error",
"provider": "groq",
"error": "Groq connection test failed",
"timestamp": datetime.now(timezone.utc).isoformat()
},
status_code=500
)
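The model filtering in list_models and get_model above assumes capability payloads shaped roughly as follows; this is an illustrative sketch, not a schema defined in this commit:
# Example capability payload and the same filtering logic the endpoints apply.
capability_payload = {
    "sub": "user-123",
    "tenant_id": "acme",
    "capabilities": [
        {"resource": "llm", "constraints": {"allowed_models": ["llama3-8b-8192"]}},
        {"resource": "admin"},  # unlocks full /stats detail and the provider test endpoint
    ],
}

llm_caps = [c for c in capability_payload["capabilities"] if c.get("resource") == "llm"]
allowed = llm_caps[0].get("constraints", {}).get("allowed_models", []) if llm_caps else []
print(allowed)  # ['llama3-8b-8192']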

View File

@@ -0,0 +1,145 @@
"""
RAG (Retrieval-Augmented Generation) API endpoints
"""
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, Any, List
from pydantic import BaseModel, Field
import logging
from app.core.security import capability_validator, CapabilityToken
from app.api.auth import verify_capability
router = APIRouter()
logger = logging.getLogger(__name__)
class DocumentUploadRequest(BaseModel):
"""Document upload request"""
content: str = Field(..., description="Document content")
metadata: Dict[str, Any] = Field(default={}, description="Document metadata")
collection: str = Field(default="default", description="Collection name")
class SearchRequest(BaseModel):
"""Semantic search request"""
query: str = Field(..., description="Search query")
collection: str = Field(default="default", description="Collection to search")
top_k: int = Field(default=5, ge=1, le=100, description="Number of results")
@router.post("/upload")
async def upload_document(
request: DocumentUploadRequest,
token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
"""Upload document for RAG processing"""
try:
import uuid
import hashlib
# Generate document ID
doc_id = f"doc_{uuid.uuid4().hex[:8]}"
# Create content hash for deduplication
content_hash = hashlib.sha256(request.content.encode()).hexdigest()[:16]
# Process the document content
# In production, this would:
# 1. Split document into chunks
# 2. Generate embeddings using the embedding service
# 3. Store in ChromaDB collection
# For now, simulate document processing
word_count = len(request.content.split())
chunk_count = max(1, word_count // 200) # Simulate ~200 words per chunk
# Store metadata with content
document_data = {
"document_id": doc_id,
"content_hash": content_hash,
"content": request.content,
"metadata": request.metadata,
"collection": request.collection,
"tenant_id": token.tenant_id,
"user_id": token.user_id,
"word_count": word_count,
"chunk_count": chunk_count
}
# In production: Store in ChromaDB
# collection = chromadb_client.get_or_create_collection(request.collection)
# collection.add(documents=[request.content], ids=[doc_id], metadatas=[request.metadata])
logger.info(f"Document uploaded: {doc_id} ({word_count} words, {chunk_count} chunks)")
return {
"success": True,
"document_id": doc_id,
"content_hash": content_hash,
"collection": request.collection,
"word_count": word_count,
"chunk_count": chunk_count,
"message": "Document processed and stored for RAG retrieval"
}
except Exception as e:
logger.error(f"Document upload failed: {e}")
raise HTTPException(status_code=500, detail=f"Document upload failed: {str(e)}")
@router.post("/search")
async def semantic_search(
request: SearchRequest,
token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
"""Perform semantic search"""
try:
# In production, this would:
# 1. Generate embedding for the query using embedding service
# 2. Search ChromaDB collection for similar vectors
# 3. Return ranked results with metadata
# For now, simulate semantic search with keyword matching
import time
search_start = time.time()
# Simulate query processing
query_terms = request.query.lower().split()
# Mock search results
mock_results = [
{
"document_id": f"doc_result_{i}",
"content": f"Sample content matching '{request.query}' - result {i+1}",
"metadata": {
"source": f"document_{i+1}.txt",
"author": "System",
"created_at": "2025-01-01T00:00:00Z"
},
"similarity_score": 0.9 - (i * 0.1),
"chunk_id": f"chunk_{i+1}"
}
for i in range(min(request.top_k, 3)) # Return up to 3 mock results
]
search_time = time.time() - search_start
logger.info(f"Semantic search completed: query='{request.query}', results={len(mock_results)}, time={search_time:.3f}s")
return {
"success": True,
"query": request.query,
"collection": request.collection,
"results": mock_results,
"total_results": len(mock_results),
"search_time_ms": int(search_time * 1000),
"tenant_id": token.tenant_id,
"user_id": token.user_id
}
except Exception as e:
logger.error(f"Semantic search failed: {e}")
raise HTTPException(status_code=500, detail=f"Semantic search failed: {str(e)}")
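The handlers above simulate chunking, embedding, and retrieval; a minimal sketch of the ChromaDB-backed path their comments describe might look like this (host, port, collection name, and the separate embedding-service call are assumptions):
# Hedged sketch of ChromaDB storage and query for the RAG endpoints.
import chromadb

client = chromadb.HttpClient(host="localhost", port=8000)  # CHROMADB_HOST / CHROMADB_PORT
collection = client.get_or_create_collection("default")

def index_chunk(chunk_id: str, text: str, metadata: dict, embedding: list[float]) -> None:
    """Store one chunk with its precomputed embedding (embedding service call not shown)."""
    collection.add(ids=[chunk_id], documents=[text], metadatas=[metadata], embeddings=[embedding])

def search(query_embedding: list[float], top_k: int = 5) -> dict:
    """Return the top_k most similar chunks for a query embedding."""
    return collection.query(query_embeddings=[query_embedding], n_results=top_k)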

View File

@@ -0,0 +1,125 @@
"""
Agent template library API endpoints
"""
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, Any, List
from pydantic import BaseModel, Field
import logging
from app.core.security import capability_validator, CapabilityToken
from app.api.auth import verify_capability
router = APIRouter()
logger = logging.getLogger(__name__)
class TemplateResponse(BaseModel):
"""Agent template response"""
template_id: str = Field(..., description="Template identifier")
name: str = Field(..., description="Template name")
description: str = Field(..., description="Template description")
category: str = Field(..., description="Template category")
configuration: Dict[str, Any] = Field(..., description="Template configuration")
@router.get("/", response_model=List[TemplateResponse])
async def list_templates(
token: CapabilityToken = Depends(verify_capability)
) -> List[TemplateResponse]:
"""List available agent templates"""
# Template library with predefined agent configurations
templates = [
TemplateResponse(
template_id="research_assistant",
name="Research & Analysis Agent",
description="Specialized in information synthesis and analysis",
category="research",
configuration={
"model": "llama-3.1-70b-versatile",
"temperature": 0.7,
"capabilities": ["llm:groq", "rag:semantic_search", "tools:web_search"]
}
),
TemplateResponse(
template_id="coding_assistant",
name="Software Development Agent",
description="Focused on code quality and best practices",
category="development",
configuration={
"model": "llama-3.1-70b-versatile",
"temperature": 0.3,
"capabilities": ["llm:groq", "tools:github_integration", "resources:documentation"]
}
)
]
return templates
@router.get("/{template_id}")
async def get_template(
template_id: str,
token: CapabilityToken = Depends(verify_capability)
) -> TemplateResponse:
"""Get specific agent template"""
try:
# Template library - in production this would be stored in database/filesystem
templates = {
"research_assistant": TemplateResponse(
template_id="research_assistant",
name="Research & Analysis Agent",
description="Specialized in information synthesis and analysis",
category="research",
configuration={
"model": "llama-3.1-70b-versatile",
"temperature": 0.7,
"capabilities": ["llm:groq", "rag:semantic_search", "tools:web_search"],
"system_prompt": "You are a research agent focused on thorough analysis and information synthesis.",
"max_tokens": 4000,
"tools": ["web_search", "document_analysis", "citation_formatter"]
}
),
"coding_assistant": TemplateResponse(
template_id="coding_assistant",
name="Software Development Agent",
description="Focused on code quality and best practices",
category="development",
configuration={
"model": "llama-3.1-70b-versatile",
"temperature": 0.3,
"capabilities": ["llm:groq", "tools:github_integration", "resources:documentation"],
"system_prompt": "You are a senior software engineer focused on code quality, best practices, and clean architecture.",
"max_tokens": 4000,
"tools": ["code_analysis", "github_integration", "documentation_generator"]
}
),
"creative_writing": TemplateResponse(
template_id="creative_writing",
name="Creative Writing Agent",
description="Specialized in creative content generation",
category="creative",
configuration={
"model": "llama-3.1-70b-versatile",
"temperature": 0.9,
"capabilities": ["llm:groq", "tools:style_guide"],
"system_prompt": "You are a creative writing agent focused on engaging, original content.",
"max_tokens": 4000,
"tools": ["style_analyzer", "plot_generator", "character_development"]
}
)
}
template = templates.get(template_id)
if not template:
raise HTTPException(status_code=404, detail=f"Template '{template_id}' not found")
return template
except HTTPException:
raise
except Exception as e:
logger.error(f"Template retrieval failed: {e}")
raise HTTPException(status_code=500, detail=f"Template retrieval failed: {str(e)}")
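A hedged sketch of wiring a retrieved template into the inference endpoint defined earlier in this commit; host, paths, and token are placeholders and the field mapping is illustrative:
# Fetch a template, then reuse its configuration for an inference request.
import httpx

headers = {"Authorization": "Bearer <capability-token>"}
template = httpx.get(
    "http://localhost:8000/api/v1/templates/coding_assistant", headers=headers
).json()
cfg = template["configuration"]

body = {
    "messages": [{"role": "user", "content": "Review this function for bugs: ..."}],
    "model": cfg["model"],
    "temperature": cfg["temperature"],
    "max_tokens": cfg.get("max_tokens", 4000),
    "system_prompt": cfg.get("system_prompt"),
}
resp = httpx.post("http://localhost:8000/api/v1/inference/", json=body, headers=headers)
print(resp.json()["content"])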

View File

@@ -0,0 +1,847 @@
"""
GT 2.0 Resource Cluster - AI Inference API (OpenAI Compatible Format)
IMPORTANT: This module maintains OpenAI API compatibility for AI model inference.
Other Resource Cluster endpoints use CB-REST standard.
"""
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from urllib.parse import urlparse
import logging
import json
import asyncio
import time
import uuid
logger = logging.getLogger(__name__)
def is_provider_endpoint(endpoint_url: str, provider_domains: List[str]) -> bool:
"""
Safely check if URL belongs to a specific provider.
Uses proper URL parsing to prevent bypass via URLs like
'evil.groq.com.attacker.com' or 'groq.com.evil.com'.
"""
try:
parsed = urlparse(endpoint_url)
hostname = (parsed.hostname or "").lower()
for domain in provider_domains:
domain = domain.lower()
# Match exact domain or subdomain (e.g., api.groq.com matches groq.com)
if hostname == domain or hostname.endswith(f".{domain}"):
return True
return False
except Exception:
return False
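# Illustrative checks (not part of the committed file) showing the intent of the
# hostname matching above:
#   is_provider_endpoint("https://api.groq.com/openai/v1", ["groq.com"])       -> True
#   is_provider_endpoint("https://groq.com.evil.com/v1", ["groq.com"])         -> False
#   is_provider_endpoint("https://evil.groq.com.attacker.com", ["groq.com"])   -> False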
router = APIRouter(prefix="/ai", tags=["AI Inference"])
# OpenAI Compatible Request/Response Models
class ChatMessage(BaseModel):
role: str = Field(..., description="Message role: system, user, agent")
content: Optional[str] = Field(None, description="Message content")
name: Optional[str] = Field(None, description="Optional name for the message")
tool_calls: Optional[List[Dict[str, Any]]] = Field(None, description="Tool calls made by the agent")
tool_call_id: Optional[str] = Field(None, description="ID of the tool call this message is responding to")
class ChatCompletionRequest(BaseModel):
model: str = Field(..., description="Model identifier")
messages: List[ChatMessage] = Field(..., description="Chat messages")
temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(None, ge=1, le=32000)
top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
n: Optional[int] = Field(1, ge=1, le=10)
stream: Optional[bool] = Field(False)
stop: Optional[Union[str, List[str]]] = None
presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
class ChatChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[str] = None
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
cost_cents: Optional[int] = Field(None, description="Total cost in cents")
class ModelUsageBreakdown(BaseModel):
"""Per-model token usage for Compound responses"""
model: str
prompt_tokens: int
completion_tokens: int
input_cost_dollars: Optional[float] = None
output_cost_dollars: Optional[float] = None
total_cost_dollars: Optional[float] = None
class ToolCostBreakdown(BaseModel):
"""Per-tool cost for Compound responses"""
tool: str
cost_dollars: float
class CostBreakdown(BaseModel):
"""Detailed cost breakdown for Compound models"""
models: List[ModelUsageBreakdown] = Field(default_factory=list)
tools: List[ToolCostBreakdown] = Field(default_factory=list)
total_cost_dollars: float = 0.0
total_cost_cents: int = 0
class UsageBreakdown(BaseModel):
"""Usage breakdown for Compound responses"""
models: List[Dict[str, Any]] = Field(default_factory=list)
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[ChatChoice]
usage: Usage
system_fingerprint: Optional[str] = None
# Compound-specific fields (optional)
usage_breakdown: Optional[UsageBreakdown] = Field(None, description="Per-model usage for Compound models")
executed_tools: Optional[List[str]] = Field(None, description="Tools executed by Compound models")
cost_breakdown: Optional[CostBreakdown] = Field(None, description="Detailed cost breakdown for Compound models")
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]] = Field(..., description="Text to embed")
model: str = Field(..., description="Embedding model")
encoding_format: Optional[str] = Field("float", description="Encoding format")
user: Optional[str] = None
class EmbeddingData(BaseModel):
object: str = "embedding"
index: int
embedding: List[float]
class EmbeddingResponse(BaseModel):
object: str = "list"
data: List[EmbeddingData]
model: str
usage: Usage
class ImageGenerationRequest(BaseModel):
prompt: str = Field(..., description="Image description")
model: str = Field("dall-e-3", description="Image model")
n: Optional[int] = Field(1, ge=1, le=10)
size: Optional[str] = Field("1024x1024")
quality: Optional[str] = Field("standard")
style: Optional[str] = Field("vivid")
response_format: Optional[str] = Field("url")
user: Optional[str] = None
class ImageData(BaseModel):
url: Optional[str] = None
b64_json: Optional[str] = None
revised_prompt: Optional[str] = None
class ImageGenerationResponse(BaseModel):
created: int
data: List[ImageData]
# Import real LLM Gateway
from app.services.llm_gateway import LLMGateway
from app.services.admin_model_config_service import get_admin_model_service
# Initialize real LLM service
llm_gateway = LLMGateway()
admin_model_service = get_admin_model_service()
async def process_chat_completion(request: ChatCompletionRequest, tenant_id: str = None) -> ChatCompletionResponse:
"""Process chat completion using real LLM Gateway with admin configurations"""
try:
# Get model configuration from admin service
# First try by model_id string, then by UUID for new UUID-based selection
model_config = await admin_model_service.get_model_config(request.model)
if not model_config:
# Try looking up by UUID (frontend may send database UUID)
model_config = await admin_model_service.get_model_by_uuid(request.model)
if not model_config:
raise ValueError(f"Model {request.model} not found in admin configuration")
# Store the actual model_id for external API calls (in case request.model is a UUID)
actual_model_id = model_config.model_id
if not model_config.is_active:
raise ValueError(f"Model {actual_model_id} is not active")
# Tenant ID is required for API key lookup
if not tenant_id:
raise ValueError("Tenant ID is required for chat completions - no fallback to environment variables")
# Check tenant access - use actual model_id for access check
has_access = await admin_model_service.check_tenant_access(tenant_id, actual_model_id)
if not has_access:
raise ValueError(f"Tenant {tenant_id} does not have access to model {actual_model_id}")
# Get API key for the provider from Control Panel database (NO env fallback)
api_key = None
if model_config.provider == "groq":
api_key = await admin_model_service.get_groq_api_key(tenant_id=tenant_id)
# Route to configured endpoint (generic routing for any provider)
endpoint_url = getattr(model_config, 'endpoint', None)
if endpoint_url:
return await _call_generic_api(request, model_config, endpoint_url, tenant_id, actual_model_id)
elif model_config.provider == "groq":
return await _call_groq_api(request, model_config, api_key, actual_model_id)
else:
raise ValueError(f"Provider {model_config.provider} not implemented - no endpoint configured")
except Exception as e:
logger.error(f"Chat completion failed: {e}")
raise
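# Resolution note (illustrative): request.model may be either the provider model_id string
# (e.g. "llama-3.1-70b") or the admin database UUID for that record; both lookups above
# resolve to the same ModelConfig, and external calls always use model_config.model_id.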
async def _call_generic_api(request: ChatCompletionRequest, model_config, endpoint_url: str, tenant_id: str, actual_model_id: str = None) -> ChatCompletionResponse:
"""Call any OpenAI-compatible endpoint"""
# Use actual_model_id for external API calls (in case request.model is a UUID)
model_id_for_api = actual_model_id or model_config.model_id
import httpx
# Convert request to OpenAI format - translate GT 2.0 "agent" role to OpenAI "assistant" for external API compatibility
api_messages = []
for msg in request.messages:
# Translate GT 2.0 "agent" role to OpenAI-compatible "assistant" role for external APIs
external_role = "assistant" if msg.role == "agent" else msg.role
# Preserve all message fields including tool_call_id, tool_calls, etc.
api_msg = {
"role": external_role,
"content": msg.content
}
# Add tool_calls if present
if msg.tool_calls:
api_msg["tool_calls"] = msg.tool_calls
# Add tool_call_id if present (for tool response messages)
if msg.tool_call_id:
api_msg["tool_call_id"] = msg.tool_call_id
# Add name if present
if msg.name:
api_msg["name"] = msg.name
api_messages.append(api_msg)
api_request = {
"model": model_id_for_api, # Use actual model_id string, not UUID
"messages": api_messages,
"temperature": request.temperature,
"max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
"top_p": request.top_p,
"stream": False # Handle streaming separately
}
# Add tools if provided
if request.tools:
api_request["tools"] = request.tools
if request.tool_choice:
api_request["tool_choice"] = request.tool_choice
headers = {"Content-Type": "application/json"}
# Add API key based on endpoint - fetch from Control Panel DB (NO env fallback)
if is_provider_endpoint(endpoint_url, ["groq.com"]):
api_key = await admin_model_service.get_groq_api_key(tenant_id=tenant_id)
headers["Authorization"] = f"Bearer {api_key}"
elif is_provider_endpoint(endpoint_url, ["nvidia.com", "integrate.api.nvidia.com"]):
# Fetch NVIDIA API key from Control Panel
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
headers["Authorization"] = f"Bearer {key_info['api_key']}"
except APIKeyNotConfiguredError:
raise ValueError(f"NVIDIA API key not configured for tenant '{tenant_id}'. Please add your NVIDIA API key in the Control Panel.")
try:
async with httpx.AsyncClient() as client:
response = await client.post(
endpoint_url,
headers=headers,
json=api_request,
timeout=300.0 # 5 minutes - allows complex agent operations to complete
)
if response.status_code != 200:
raise ValueError(f"API error: {response.status_code} - {response.text}")
api_response = response.json()
except httpx.TimeoutException:
logger.error(f"API timeout after 300s for endpoint {endpoint_url}")
raise ValueError("API request timed out after 5 minutes - try reducing system prompt length or max_tokens")
except httpx.HTTPStatusError as e:
logger.error(f"API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"API HTTP error: {e.response.status_code}")
except Exception as e:
logger.error(f"API request failed: {type(e).__name__}: {e}")
raise ValueError(f"API request failed: {type(e).__name__}: {str(e)}")
# Convert API response to our format - translate OpenAI "assistant" back to GT 2.0 "agent"
choices = []
for choice in api_response["choices"]:
# Translate OpenAI-compatible "assistant" role back to GT 2.0 "agent" role
internal_role = "agent" if choice["message"]["role"] == "assistant" else choice["message"]["role"]
# Preserve all message fields from API response
message_data = {
"role": internal_role,
"content": choice["message"].get("content"),
}
# Add tool calls if present
if "tool_calls" in choice["message"]:
message_data["tool_calls"] = choice["message"]["tool_calls"]
# Add tool_call_id if present (for tool response messages)
if "tool_call_id" in choice["message"]:
message_data["tool_call_id"] = choice["message"]["tool_call_id"]
# Add name if present
if "name" in choice["message"]:
message_data["name"] = choice["message"]["name"]
choices.append(ChatChoice(
index=choice["index"],
message=ChatMessage(**message_data),
finish_reason=choice.get("finish_reason")
))
# Calculate cost_breakdown for Compound models
cost_breakdown = None
if "compound" in request.model.lower():
from app.core.backends.groq_proxy import GroqProxyBackend
proxy = GroqProxyBackend()
# Extract executed_tools from choices[0].message.executed_tools (Groq Compound format)
executed_tools_data = []
if "choices" in api_response and api_response["choices"]:
message = api_response["choices"][0].get("message", {})
raw_tools = message.get("executed_tools", [])
# Convert to format expected by _calculate_compound_cost: list of tool names/types
for tool in raw_tools:
if isinstance(tool, dict):
# Extract tool type (e.g., "search", "code_execution")
tool_type = tool.get("type", "search")
executed_tools_data.append(tool_type)
elif isinstance(tool, str):
executed_tools_data.append(tool)
if executed_tools_data:
logger.info(f"Compound executed_tools: {executed_tools_data}")
# Use actual per-model breakdown from usage_breakdown if available
usage_breakdown = api_response.get("usage_breakdown", {})
models_data = usage_breakdown.get("models", [])
if models_data:
logger.info(f"Compound using per-model breakdown: {len(models_data)} model calls")
cost_breakdown = proxy._calculate_compound_cost({
"usage_breakdown": {"models": models_data},
"executed_tools": executed_tools_data
})
else:
# Fallback: use aggregate tokens
usage = api_response.get("usage", {})
cost_breakdown = proxy._calculate_compound_cost({
"usage_breakdown": {
"models": [{
"model": api_response.get("model", request.model),
"usage": {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0)
}
}]
},
"executed_tools": executed_tools_data
})
logger.info(f"Compound cost_breakdown (generic API): ${cost_breakdown.get('total_cost_dollars', 0):.6f}")
return ChatCompletionResponse(
id=api_response["id"],
created=api_response["created"],
model=api_response["model"],
choices=choices,
usage=Usage(
prompt_tokens=api_response["usage"]["prompt_tokens"],
completion_tokens=api_response["usage"]["completion_tokens"],
total_tokens=api_response["usage"]["total_tokens"]
),
cost_breakdown=cost_breakdown
)
async def _call_groq_api(request: ChatCompletionRequest, model_config, api_key: str, actual_model_id: str = None) -> ChatCompletionResponse:
"""Call Groq API directly"""
# Use actual_model_id for external API calls (in case request.model is a UUID)
model_id_for_api = actual_model_id or model_config.model_id
import httpx
# Convert request to Groq format - translate GT 2.0 "agent" role to OpenAI "assistant" for external API compatibility
groq_messages = []
for msg in request.messages:
# Translate GT 2.0 "agent" role to OpenAI-compatible "assistant" role for external APIs
external_role = "assistant" if msg.role == "agent" else msg.role
# Preserve all message fields including tool_call_id, tool_calls, etc.
groq_msg = {
"role": external_role,
"content": msg.content
}
# Add tool_calls if present
if msg.tool_calls:
groq_msg["tool_calls"] = msg.tool_calls
# Add tool_call_id if present (for tool response messages)
if msg.tool_call_id:
groq_msg["tool_call_id"] = msg.tool_call_id
# Add name if present
if msg.name:
groq_msg["name"] = msg.name
groq_messages.append(groq_msg)
groq_request = {
"model": model_id_for_api, # Use actual model_id string, not UUID
"messages": groq_messages,
"temperature": request.temperature,
"max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
"top_p": request.top_p,
"stream": False # Handle streaming separately
}
# Add tools if provided
if request.tools:
groq_request["tools"] = request.tools
if request.tool_choice:
groq_request["tool_choice"] = request.tool_choice
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.groq.com/openai/v1/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
json=groq_request,
timeout=300.0 # 5 minutes - allows complex agent operations to complete
)
if response.status_code != 200:
raise ValueError(f"Groq API error: {response.status_code} - {response.text}")
groq_response = response.json()
except httpx.TimeoutException:
logger.error(f"Groq API timeout after 300s for model {request.model}")
raise ValueError("Groq API request timed out after 5 minutes - try reducing system prompt length or max_tokens")
except httpx.HTTPStatusError as e:
logger.error(f"Groq API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"Groq API HTTP error: {e.response.status_code}")
except Exception as e:
logger.error(f"Groq API request failed: {type(e).__name__}: {e}")
raise ValueError(f"Groq API request failed: {type(e).__name__}: {str(e)}")
# Convert Groq response to our format - translate OpenAI "assistant" back to GT 2.0 "agent"
choices = []
for choice in groq_response["choices"]:
# Translate OpenAI-compatible "assistant" role back to GT 2.0 "agent" role
internal_role = "agent" if choice["message"]["role"] == "assistant" else choice["message"]["role"]
# Preserve all message fields from Groq response
message_data = {
"role": internal_role,
"content": choice["message"].get("content"),
}
# Add tool calls if present
if "tool_calls" in choice["message"]:
message_data["tool_calls"] = choice["message"]["tool_calls"]
# Add tool_call_id if present (for tool response messages)
if "tool_call_id" in choice["message"]:
message_data["tool_call_id"] = choice["message"]["tool_call_id"]
# Add name if present
if "name" in choice["message"]:
message_data["name"] = choice["message"]["name"]
choices.append(ChatChoice(
index=choice["index"],
message=ChatMessage(**message_data),
finish_reason=choice.get("finish_reason")
))
# Build response with Compound-specific fields if present
response_data = {
"id": groq_response["id"],
"created": groq_response["created"],
"model": groq_response["model"],
"choices": choices,
"usage": Usage(
prompt_tokens=groq_response["usage"]["prompt_tokens"],
completion_tokens=groq_response["usage"]["completion_tokens"],
total_tokens=groq_response["usage"]["total_tokens"]
)
}
# Extract Compound-specific fields if present (for accurate billing)
usage_breakdown_data = None
executed_tools_data = None
if "usage_breakdown" in groq_response.get("usage", {}):
usage_breakdown_data = groq_response["usage"]["usage_breakdown"]
response_data["usage_breakdown"] = UsageBreakdown(models=usage_breakdown_data)
logger.debug(f"Compound usage_breakdown: {usage_breakdown_data}")
# Check for executed_tools in the response (Compound models)
if "x_groq" in groq_response:
x_groq = groq_response["x_groq"]
if "usage" in x_groq and "executed_tools" in x_groq["usage"]:
executed_tools_data = x_groq["usage"]["executed_tools"]
response_data["executed_tools"] = executed_tools_data
logger.debug(f"Compound executed_tools: {executed_tools_data}")
# Calculate cost breakdown for Compound models using actual usage data
if usage_breakdown_data or executed_tools_data:
try:
from app.core.backends.groq_proxy import GroqProxyBackend
proxy = GroqProxyBackend()
cost_breakdown = proxy._calculate_compound_cost({
"usage_breakdown": {"models": usage_breakdown_data or []},
"executed_tools": executed_tools_data or []
})
response_data["cost_breakdown"] = CostBreakdown(
models=[ModelUsageBreakdown(**m) for m in cost_breakdown.get("models", [])],
tools=[ToolCostBreakdown(**t) for t in cost_breakdown.get("tools", [])],
total_cost_dollars=cost_breakdown.get("total_cost_dollars", 0.0),
total_cost_cents=cost_breakdown.get("total_cost_cents", 0)
)
logger.info(f"Compound cost_breakdown: ${cost_breakdown['total_cost_dollars']:.6f} ({cost_breakdown['total_cost_cents']} cents)")
except Exception as e:
logger.warning(f"Failed to calculate Compound cost breakdown: {e}")
# Fallback: If this is a Compound model and we don't have cost_breakdown yet,
# calculate it from standard token usage (Groq may not return detailed breakdown)
if "compound" in request.model.lower() and "cost_breakdown" not in response_data:
try:
from app.core.backends.groq_proxy import GroqProxyBackend
proxy = GroqProxyBackend()
# Build usage data from standard response tokens
# Match the structure expected by _calculate_compound_cost
usage = groq_response.get("usage", {})
cost_breakdown = proxy._calculate_compound_cost({
"usage_breakdown": {
"models": [{
"model": groq_response.get("model", request.model),
"usage": {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0)
}
}]
},
"executed_tools": [] # No tool data available from standard response
})
response_data["cost_breakdown"] = CostBreakdown(
models=[ModelUsageBreakdown(**m) for m in cost_breakdown.get("models", [])],
tools=[],
total_cost_dollars=cost_breakdown.get("total_cost_dollars", 0.0),
total_cost_cents=cost_breakdown.get("total_cost_cents", 0)
)
logger.info(f"Compound cost_breakdown (from tokens): ${cost_breakdown['total_cost_dollars']:.6f} ({cost_breakdown['total_cost_cents']} cents)")
except Exception as e:
logger.warning(f"Failed to calculate Compound cost breakdown from tokens: {e}")
return ChatCompletionResponse(**response_data)
@router.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
request: ChatCompletionRequest,
http_request: Request
):
"""
OpenAI-compatible chat completions endpoint
This endpoint maintains full OpenAI API compatibility for seamless integration
with existing AI tools and libraries.
"""
try:
# Verify capability token from Authorization header
auth_header = http_request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
# Extract tenant ID from headers
tenant_id = http_request.headers.get("X-Tenant-ID")
# Handle streaming responses
if request.stream:
# codeql[py/stack-trace-exposure] returns LLM response stream, not error details
return StreamingResponse(
stream_chat_completion(request, tenant_id, auth_header),
media_type="text/plain"
)
# Regular response using real LLM Gateway
response = await process_chat_completion(request, tenant_id)
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Chat completion error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
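# Example call (hypothetical host/port and token; the /ai prefix comes from this router and
# the tenant id is taken from the X-Tenant-ID header):
#   curl -X POST http://localhost:8000/ai/chat/completions \
#     -H "Authorization: Bearer <capability-token>" \
#     -H "X-Tenant-ID: acme.example.com" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama-3.1-70b", "messages": [{"role": "user", "content": "Hello"}]}'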
@router.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(
request: EmbeddingRequest,
http_request: Request
):
"""
OpenAI-compatible embeddings endpoint
Creates embeddings for the given input text(s).
"""
try:
# Verify capability token
auth_header = http_request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
# TODO: Implement embeddings via LLM Gateway (Day 3)
raise HTTPException(status_code=501, detail="Embeddings endpoint not yet implemented")
except HTTPException:
raise
except Exception as e:
logger.error(f"Embedding creation error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/images/generations", response_model=ImageGenerationResponse)
async def create_image(
request: ImageGenerationRequest,
http_request: Request
):
"""
OpenAI-compatible image generation endpoint
Generates images from text prompts.
"""
try:
# Verify capability token
auth_header = http_request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
# Mock response (replace with actual image generation)
response = ImageGenerationResponse(
created=int(time.time()),
data=[
ImageData(
url=f"https://api.gt2.com/generated/{uuid.uuid4().hex}.png",
revised_prompt=request.prompt
)
for _ in range(request.n or 1)
]
)
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Image generation error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/models")
async def list_models(http_request: Request):
"""
List available AI models (OpenAI compatible format)
"""
try:
# Verify capability token
auth_header = http_request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
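# Static, OpenAI-style catalog returned for client compatibility; per-tenant model
# availability is served by the model management API (/api/v1/models).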
models = {
"object": "list",
"data": [
{
"id": "gpt-4",
"object": "model",
"created": 1687882410,
"owned_by": "openai",
"permission": [],
"root": "gpt-4",
"parent": None
},
{
"id": "claude-3-sonnet",
"object": "model",
"created": 1687882410,
"owned_by": "anthropic",
"permission": [],
"root": "claude-3-sonnet",
"parent": None
},
{
"id": "llama-3.1-70b",
"object": "model",
"created": 1687882410,
"owned_by": "groq",
"permission": [],
"root": "llama-3.1-70b",
"parent": None
},
{
"id": "text-embedding-3-small",
"object": "model",
"created": 1687882410,
"owned_by": "openai",
"permission": [],
"root": "text-embedding-3-small",
"parent": None
}
]
}
return models
except HTTPException:
raise
except Exception as e:
logger.error(f"List models error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
async def stream_chat_completion(request: ChatCompletionRequest, tenant_id: str, auth_header: str = None):
"""Stream chat completion responses using real AI providers"""
try:
from app.services.llm_gateway import LLMGateway, LLMRequest
gateway = LLMGateway()
# Create a unique request ID for this stream
response_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
created_time = int(time.time())
# Create LLM request with streaming enabled - translate GT 2.0 "agent" to OpenAI "assistant"
streaming_messages = []
for msg in request.messages:
# Translate GT 2.0 "agent" role to OpenAI-compatible "assistant" role for external APIs
external_role = "assistant" if msg.role == "agent" else msg.role
streaming_messages.append({"role": external_role, "content": msg.content})
llm_request = LLMRequest(
model=request.model,
messages=streaming_messages,
temperature=request.temperature,
max_tokens=request.max_tokens,
top_p=request.top_p,
stream=True
)
# Extract real capability token from authorization header
capability_token = "dummy_capability_token"
user_id = "test_user"
if auth_header and auth_header.startswith("Bearer "):
capability_token = auth_header.replace("Bearer ", "")
# TODO: Extract user ID from token if possible
user_id = "test_user"
# Stream from the LLM Gateway
stream_generator = await gateway.chat_completion(
request=llm_request,
capability_token=capability_token,
user_id=user_id,
tenant_id=tenant_id
)
# Process streaming chunks
async for chunk_data in stream_generator:
# The chunk_data from Groq proxy should already be formatted
# Parse it if it's a string, or use directly if it's already a dict
if isinstance(chunk_data, str):
# Extract content from SSE format like "data: {content: 'text'}"
if chunk_data.startswith("data: "):
chunk_json = chunk_data[6:].strip()
if chunk_json and chunk_json != "[DONE]":
try:
chunk_dict = json.loads(chunk_json)
content = chunk_dict.get("content", "")
except json.JSONDecodeError:
content = ""
else:
content = ""
else:
content = chunk_data
else:
content = chunk_data.get("content", "")
if content:
# Format as OpenAI-compatible streaming chunk
stream_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": request.model,
"choices": [{
"index": 0,
"delta": {"content": content},
"finish_reason": None
}]
}
yield f"data: {json.dumps(stream_chunk)}\n\n"
# Send final chunk
final_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": request.model,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Streaming error: {e}")
error_chunk = {
"error": {
"message": str(e),
"type": "server_error"
}
}
yield f"data: {json.dumps(error_chunk)}\n\n"

View File

@@ -0,0 +1,411 @@
"""
Integration Proxy API for GT 2.0
RESTful API for secure external service integration through the Resource Cluster.
Provides capability-based access control and sandbox restrictions.
"""
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from app.core.security import verify_capability_token
from app.services.integration_proxy import (
IntegrationProxyService, ProxyRequest, ProxyResponse, IntegrationConfig,
IntegrationType, SandboxLevel
)
router = APIRouter()
# Request/Response Models
class ExecuteIntegrationRequest(BaseModel):
"""Request to execute integration"""
integration_id: str = Field(..., description="Integration ID to execute")
method: str = Field(..., description="HTTP method (GET, POST, PUT, DELETE)")
endpoint: str = Field(..., description="Endpoint path or full URL")
headers: Optional[Dict[str, str]] = Field(None, description="Request headers")
data: Optional[Dict[str, Any]] = Field(None, description="Request data")
params: Optional[Dict[str, str]] = Field(None, description="Query parameters")
timeout_override: Optional[int] = Field(None, description="Override timeout in seconds")
class IntegrationExecutionResponse(BaseModel):
"""Response from integration execution"""
success: bool
status_code: int
data: Optional[Dict[str, Any]]
headers: Dict[str, str]
execution_time_ms: int
sandbox_applied: bool
restrictions_applied: List[str]
error_message: Optional[str]
class CreateIntegrationRequest(BaseModel):
"""Request to create integration configuration"""
name: str = Field(..., description="Human-readable integration name")
integration_type: str = Field(..., description="Type of integration")
base_url: str = Field(..., description="Base URL for the service")
authentication_method: str = Field(..., description="Authentication method")
auth_config: Dict[str, Any] = Field(..., description="Authentication configuration")
sandbox_level: str = Field("basic", description="Sandbox restriction level")
max_requests_per_hour: int = Field(1000, description="Rate limit per hour")
max_response_size_bytes: int = Field(10485760, description="Max response size (10MB default)")
timeout_seconds: int = Field(30, description="Request timeout")
allowed_methods: Optional[List[str]] = Field(None, description="Allowed HTTP methods")
allowed_endpoints: Optional[List[str]] = Field(None, description="Allowed endpoints")
blocked_endpoints: Optional[List[str]] = Field(None, description="Blocked endpoints")
allowed_domains: Optional[List[str]] = Field(None, description="Allowed domains")
class IntegrationConfigResponse(BaseModel):
"""Integration configuration response"""
id: str
name: str
integration_type: str
base_url: str
authentication_method: str
sandbox_level: str
max_requests_per_hour: int
max_response_size_bytes: int
timeout_seconds: int
allowed_methods: List[str]
allowed_endpoints: List[str]
blocked_endpoints: List[str]
allowed_domains: List[str]
is_active: bool
created_at: str
created_by: str
class IntegrationUsageResponse(BaseModel):
"""Integration usage analytics response"""
integration_id: str
total_requests: int
successful_requests: int
error_count: int
success_rate: float
avg_execution_time_ms: float
date_range: Dict[str, str]
# Dependency injection
async def get_integration_proxy_service() -> IntegrationProxyService:
"""Get integration proxy service"""
return IntegrationProxyService()
@router.post("/execute", response_model=IntegrationExecutionResponse)
async def execute_integration(
request: ExecuteIntegrationRequest,
authorization: str = Header(...),
proxy_service: IntegrationProxyService = Depends(get_integration_proxy_service)
):
"""
Execute external integration with capability-based access control.
- **integration_id**: ID of the configured integration
- **method**: HTTP method (GET, POST, PUT, DELETE)
- **endpoint**: API endpoint path or full URL
- **headers**: Optional request headers
- **data**: Optional request body data
- **params**: Optional query parameters
- **timeout_override**: Optional timeout override
"""
try:
# Create proxy request
proxy_request = ProxyRequest(
integration_id=request.integration_id,
method=request.method.upper(),
endpoint=request.endpoint,
headers=request.headers,
data=request.data,
params=request.params,
timeout_override=request.timeout_override
)
# Execute integration
response = await proxy_service.execute_integration(
request=proxy_request,
capability_token=authorization
)
return IntegrationExecutionResponse(
success=response.success,
status_code=response.status_code,
data=response.data,
headers=response.headers,
execution_time_ms=response.execution_time_ms,
sandbox_applied=response.sandbox_applied,
restrictions_applied=response.restrictions_applied,
error_message=response.error_message
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Integration execution failed: {str(e)}")
@router.get("", response_model=List[IntegrationConfigResponse])
async def list_integrations(
authorization: str = Header(...),
proxy_service: IntegrationProxyService = Depends(get_integration_proxy_service)
):
"""
List available integrations based on user capabilities.
Returns only integrations the user has permission to access.
"""
try:
integrations = await proxy_service.list_integrations(authorization)
return [
IntegrationConfigResponse(
id=config.id,
name=config.name,
integration_type=config.integration_type.value,
base_url=config.base_url,
authentication_method=config.authentication_method,
sandbox_level=config.sandbox_level.value,
max_requests_per_hour=config.max_requests_per_hour,
max_response_size_bytes=config.max_response_size_bytes,
timeout_seconds=config.timeout_seconds,
allowed_methods=config.allowed_methods,
allowed_endpoints=config.allowed_endpoints,
blocked_endpoints=config.blocked_endpoints,
allowed_domains=config.allowed_domains,
is_active=config.is_active,
created_at=config.created_at.isoformat(),
created_by=config.created_by
)
for config in integrations
]
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list integrations: {str(e)}")
@router.post("", response_model=IntegrationConfigResponse)
async def create_integration(
request: CreateIntegrationRequest,
authorization: str = Header(...),
proxy_service: IntegrationProxyService = Depends(get_integration_proxy_service)
):
"""
Create new integration configuration (admin only).
- **name**: Human-readable name for the integration
- **integration_type**: Type of integration (communication, development, etc.)
- **base_url**: Base URL for the external service
- **authentication_method**: oauth2, api_key, basic_auth, certificate
- **auth_config**: Authentication details (encrypted storage)
- **sandbox_level**: none, basic, restricted, strict
"""
try:
# Verify admin capability
token_data = await verify_capability_token(authorization)
if not token_data:
raise HTTPException(status_code=401, detail="Invalid capability token")
# Check admin permissions
if not any("admin" in str(cap) for cap in token_data.get("capabilities", [])):
raise HTTPException(status_code=403, detail="Admin capability required")
# Generate unique ID
import uuid
integration_id = str(uuid.uuid4())
# Create integration config
config = IntegrationConfig(
id=integration_id,
name=request.name,
integration_type=IntegrationType(request.integration_type.lower()),
base_url=request.base_url,
authentication_method=request.authentication_method,
auth_config=request.auth_config,
sandbox_level=SandboxLevel(request.sandbox_level.lower()),
max_requests_per_hour=request.max_requests_per_hour,
max_response_size_bytes=request.max_response_size_bytes,
timeout_seconds=request.timeout_seconds,
allowed_methods=request.allowed_methods or ["GET", "POST"],
allowed_endpoints=request.allowed_endpoints or [],
blocked_endpoints=request.blocked_endpoints or [],
allowed_domains=request.allowed_domains or [],
created_by=token_data.get("sub", "unknown")
)
# Store configuration
success = await proxy_service.store_integration_config(config)
if not success:
raise HTTPException(status_code=500, detail="Failed to store integration configuration")
return IntegrationConfigResponse(
id=config.id,
name=config.name,
integration_type=config.integration_type.value,
base_url=config.base_url,
authentication_method=config.authentication_method,
sandbox_level=config.sandbox_level.value,
max_requests_per_hour=config.max_requests_per_hour,
max_response_size_bytes=config.max_response_size_bytes,
timeout_seconds=config.timeout_seconds,
allowed_methods=config.allowed_methods,
allowed_endpoints=config.allowed_endpoints,
blocked_endpoints=config.blocked_endpoints,
allowed_domains=config.allowed_domains,
is_active=config.is_active,
created_at=config.created_at.isoformat(),
created_by=config.created_by
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create integration: {str(e)}")
@router.get("/{integration_id}/usage", response_model=IntegrationUsageResponse)
async def get_integration_usage(
integration_id: str,
days: int = 30,
authorization: str = Header(...),
proxy_service: IntegrationProxyService = Depends(get_integration_proxy_service)
):
"""
Get usage analytics for specific integration.
- **days**: Number of days to analyze (default 30)
"""
try:
# Verify capability for this integration
token_data = await verify_capability_token(authorization)
if not token_data:
raise HTTPException(status_code=401, detail="Invalid capability token")
# Get usage analytics
usage = await proxy_service.get_integration_usage_analytics(integration_id, days)
return IntegrationUsageResponse(
integration_id=usage["integration_id"],
total_requests=usage["total_requests"],
successful_requests=usage["successful_requests"],
error_count=usage["error_count"],
success_rate=usage["success_rate"],
avg_execution_time_ms=usage["avg_execution_time_ms"],
date_range=usage["date_range"]
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get usage analytics: {str(e)}")
# Integration type and sandbox level catalogs
@router.get("/catalog/types")
async def get_integration_types():
"""Get available integration types for UI builders"""
return {
"integration_types": [
{
"value": "communication",
"label": "Communication",
"description": "Slack, Teams, Discord integration"
},
{
"value": "development",
"label": "Development",
"description": "GitHub, GitLab, Jira integration"
},
{
"value": "project_management",
"label": "Project Management",
"description": "Asana, Monday.com integration"
},
{
"value": "database",
"label": "Database",
"description": "PostgreSQL, MySQL, MongoDB connectors"
},
{
"value": "custom_api",
"label": "Custom API",
"description": "Custom REST/GraphQL APIs"
},
{
"value": "webhook",
"label": "Webhook",
"description": "Outbound webhook calls"
}
]
}
@router.get("/catalog/sandbox-levels")
async def get_sandbox_levels():
"""Get available sandbox levels for UI builders"""
return {
"sandbox_levels": [
{
"value": "none",
"label": "No Restrictions",
"description": "Trusted integrations with full access"
},
{
"value": "basic",
"label": "Basic Restrictions",
"description": "Basic timeout and size limits"
},
{
"value": "restricted",
"label": "Restricted Access",
"description": "Limited API calls and data access"
},
{
"value": "strict",
"label": "Maximum Security",
"description": "Strict restrictions and monitoring"
}
]
}
@router.get("/catalog/auth-methods")
async def get_authentication_methods():
"""Get available authentication methods for UI builders"""
return {
"auth_methods": [
{
"value": "api_key",
"label": "API Key",
"description": "Simple API key authentication",
"fields": ["api_key", "key_header", "key_prefix"]
},
{
"value": "basic_auth",
"label": "Basic Authentication",
"description": "Username and password authentication",
"fields": ["username", "password"]
},
{
"value": "oauth2",
"label": "OAuth 2.0",
"description": "OAuth 2.0 bearer token authentication",
"fields": ["access_token", "refresh_token", "client_id", "client_secret"]
},
{
"value": "certificate",
"label": "Certificate",
"description": "Client certificate authentication",
"fields": ["cert_path", "key_path", "ca_path"]
}
]
}

View File

@@ -0,0 +1,424 @@
"""
GT 2.0 MCP Tool Executor
Handles execution of MCP tools from agents. This is the main endpoint
that receives tool calls from the tenant backend and routes them to
the appropriate MCP servers with proper authentication and rate limiting.
"""
from typing import Dict, Any, List, Optional, Union
from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel, Field
import logging
import asyncio
from datetime import datetime
# Removed: from app.core.security import verify_capability_token
from app.services.mcp_rag_server import mcp_rag_server
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/mcp", tags=["mcp_execution"])
# Request/Response Models
class MCPToolCall(BaseModel):
"""MCP tool call request"""
tool_name: str = Field(..., description="Name of the tool to execute")
server_name: str = Field(..., description="MCP server that provides the tool")
parameters: Dict[str, Any] = Field(..., description="Tool parameters")
class MCPToolResult(BaseModel):
"""MCP tool execution result"""
success: bool
tool_name: str
server_name: str
execution_time_ms: float
result: Dict[str, Any]
error: Optional[str] = None
timestamp: str
class MCPBatchRequest(BaseModel):
"""Request for executing multiple MCP tools"""
tool_calls: List[MCPToolCall] = Field(..., min_items=1, max_items=10)
class MCPBatchResponse(BaseModel):
"""Response for batch tool execution"""
results: List[MCPToolResult]
success_count: int
error_count: int
total_execution_time_ms: float
# Rate limiting (simple in-memory counter)
_rate_limits = {}
def check_rate_limit(user_id: str, server_name: str) -> bool:
"""Simple rate limiting check"""
# TODO: Implement proper rate limiting with Redis or similar
key = f"{user_id}:{server_name}"
current_time = datetime.now().timestamp()
if key not in _rate_limits:
_rate_limits[key] = []
# Remove old entries (older than 1 minute)
_rate_limits[key] = [t for t in _rate_limits[key] if current_time - t < 60]
# Check if under limit (60 requests per minute)
if len(_rate_limits[key]) >= 60:
return False
# Add current request
_rate_limits[key].append(current_time)
return True
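# Sketch of the limiter's behaviour (illustrative): each (user_id, server_name) pair keeps the
# timestamps of its calls from the last 60 seconds, so the 61st call inside one minute is rejected.
#   check_rate_limit("user-1", "rag_server")  -> True   (first call)
#   ... 59 more calls within the same minute ...
#   check_rate_limit("user-1", "rag_server")  -> False  (61st call in the window)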
@router.post("/tool", response_model=MCPToolResult)
async def execute_mcp_tool(
request: MCPToolCall,
x_tenant_domain: str = Header(..., description="Tenant domain for isolation"),
x_user_id: str = Header(..., description="User ID for authorization"),
agent_context: Optional[Dict[str, Any]] = None
):
"""
Execute a single MCP tool.
This is the main endpoint that agents use to execute MCP tools.
It handles rate limiting and routing to the appropriate MCP server.
User authentication is handled by the tenant backend before reaching here.
"""
start_time = datetime.now()
try:
# Validate required headers
if not x_user_id or not x_tenant_domain:
raise HTTPException(
status_code=400,
detail="Missing required authentication headers"
)
# Check rate limiting
if not check_rate_limit(x_user_id, request.server_name):
raise HTTPException(
status_code=429,
detail="Rate limit exceeded for MCP server"
)
# Route to appropriate MCP server (no capability token needed)
if request.server_name == "rag_server":
result = await mcp_rag_server.handle_tool_call(
tool_name=request.tool_name,
parameters=request.parameters,
tenant_domain=x_tenant_domain,
user_id=x_user_id,
agent_context=agent_context
)
else:
raise HTTPException(
status_code=404,
detail=f"Unknown MCP server: {request.server_name}"
)
# Calculate execution time
end_time = datetime.now()
execution_time = (end_time - start_time).total_seconds() * 1000
# Check if tool execution was successful
success = "error" not in result
error_message = result.get("error") if not success else None
logger.info(f"🔧 MCP Tool executed: {request.tool_name} ({execution_time:.2f}ms) - {'' if success else ''}")
return MCPToolResult(
success=success,
tool_name=request.tool_name,
server_name=request.server_name,
execution_time_ms=execution_time,
result=result,
error=error_message,
timestamp=end_time.isoformat()
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error executing MCP tool {request.tool_name}: {e}")
end_time = datetime.now()
execution_time = (end_time - start_time).total_seconds() * 1000
return MCPToolResult(
success=False,
tool_name=request.tool_name,
server_name=request.server_name,
execution_time_ms=execution_time,
result={},
error=f"Tool execution failed: {str(e)}",
timestamp=end_time.isoformat()
)
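# Example payload for POST /mcp/tool (illustrative tool name and parameters; the
# X-Tenant-Domain and X-User-ID headers are required):
# {
#   "tool_name": "search_documents",
#   "server_name": "rag_server",
#   "parameters": {"query": "quarterly revenue", "top_k": 5}
# }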
class MCPExecuteRequest(BaseModel):
"""Direct execution request format used by RAG orchestrator"""
server_id: str = Field(..., description="Server ID (rag_server)")
tool_name: str = Field(..., description="Tool name to execute")
parameters: Dict[str, Any] = Field(..., description="Tool parameters")
tenant_domain: str = Field(..., description="Tenant domain")
user_id: str = Field(..., description="User ID")
agent_context: Optional[Dict[str, Any]] = Field(None, description="Agent context with dataset info")
@router.post("/execute")
async def execute_mcp_direct(request: MCPExecuteRequest):
"""
Direct execution endpoint used by RAG orchestrator.
Simplified without capability tokens - uses user context for authorization.
"""
logger.info(f"🔧 Direct MCP execution request: server={request.server_id}, tool={request.tool_name}, tenant={request.tenant_domain}, user={request.user_id}")
logger.debug(f"📝 Tool parameters: {request.parameters}")
try:
# Map server_id to server_name
server_mapping = {
"rag_server": "rag_server"
}
server_name = server_mapping.get(request.server_id)
if not server_name:
logger.error(f"❌ Unknown server_id: {request.server_id}")
raise HTTPException(
status_code=400,
detail=f"Unknown server_id: {request.server_id}"
)
logger.info(f"🎯 Mapped server_id '{request.server_id}' → server_name '{server_name}'")
# Create simplified tool call request
tool_call = MCPToolCall(
tool_name=request.tool_name,
server_name=server_name,
parameters=request.parameters
)
# Execute the tool with agent context
result = await execute_mcp_tool(
request=tool_call,
x_tenant_domain=request.tenant_domain,
x_user_id=request.user_id,
agent_context=request.agent_context
)
# Return result in format expected by RAG orchestrator
if result.success:
return result.result
else:
return {
"success": False,
"error": result.error
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Direct MCP execution failed: {e}")
return {
"success": False,
"error": "MCP execution failed"
}
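# Example body for POST /mcp/execute as sent by the RAG orchestrator (illustrative values):
# {
#   "server_id": "rag_server",
#   "tool_name": "search_documents",
#   "parameters": {"query": "onboarding checklist"},
#   "tenant_domain": "acme.example.com",
#   "user_id": "user-123",
#   "agent_context": {"dataset_ids": ["dataset-42"]}
# }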
@router.post("/batch", response_model=MCPBatchResponse)
async def execute_mcp_batch(
request: MCPBatchRequest,
x_tenant_domain: str = Header(..., description="Tenant domain for isolation"),
x_user_id: str = Header(..., description="User ID for authorization")
):
"""
Execute multiple MCP tools in batch.
Useful for agents that need to call multiple tools simultaneously
for more efficient execution.
"""
batch_start_time = datetime.now()
try:
# Validate required headers
if not x_user_id or not x_tenant_domain:
raise HTTPException(
status_code=400,
detail="Missing required authentication headers"
)
# Execute all tool calls concurrently
tasks = []
for tool_call in request.tool_calls:
# Create individual tool call request
individual_request = MCPToolCall(
tool_name=tool_call.tool_name,
server_name=tool_call.server_name,
parameters=tool_call.parameters
)
# Create task for concurrent execution
task = execute_mcp_tool(
request=individual_request,
x_tenant_domain=x_tenant_domain,
x_user_id=x_user_id
)
tasks.append(task)
# Execute all tools concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results
tool_results = []
success_count = 0
error_count = 0
for result in results:
if isinstance(result, Exception):
# Handle exceptions from individual tool calls
tool_results.append(MCPToolResult(
success=False,
tool_name="unknown",
server_name="unknown",
execution_time_ms=0,
result={},
error=str(result),
timestamp=datetime.now().isoformat()
))
error_count += 1
else:
tool_results.append(result)
if result.success:
success_count += 1
else:
error_count += 1
# Calculate total execution time
batch_end_time = datetime.now()
total_execution_time = (batch_end_time - batch_start_time).total_seconds() * 1000
return MCPBatchResponse(
results=tool_results,
success_count=success_count,
error_count=error_count,
total_execution_time_ms=total_execution_time
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error executing MCP batch: {e}")
raise HTTPException(status_code=500, detail=f"Batch execution failed: {str(e)}")
@router.post("/rag/{tool_name}")
async def execute_rag_tool(
tool_name: str,
parameters: Dict[str, Any],
x_tenant_domain: Optional[str] = Header(None),
x_user_id: Optional[str] = Header(None)
):
"""
Direct endpoint for executing RAG tools.
Convenience endpoint for common RAG operations without
needing to specify server name.
"""
# Create standard tool call request
tool_call = MCPToolCall(
tool_name=tool_name,
server_name="rag_server",
parameters=parameters
)
return await execute_mcp_tool(
request=tool_call,
x_tenant_domain=x_tenant_domain,
x_user_id=x_user_id
)
@router.post("/conversation/{tool_name}")
async def execute_conversation_tool(
tool_name: str,
parameters: Dict[str, Any],
x_tenant_domain: Optional[str] = Header(None),
x_user_id: Optional[str] = Header(None)
):
"""
Direct endpoint for executing conversation search tools.
Convenience endpoint for common conversation search operations without
needing to specify the server name. Note: execute_mcp_tool currently routes
only "rag_server", so calls here will fail until a conversation server is
wired into the executor.
"""
# Create standard tool call request
tool_call = MCPToolCall(
tool_name=tool_name,
server_name="conversation_server",
parameters=parameters
)
return await execute_mcp_tool(
request=tool_call,
x_tenant_domain=x_tenant_domain,
x_user_id=x_user_id
)
@router.get("/status")
async def get_executor_status(
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
):
"""
Get status of the MCP executor and connected servers.
Returns health information and statistics about MCP tool execution.
"""
try:
# Calculate basic statistics
total_requests = sum(len(requests) for requests in _rate_limits.values())
active_users = len(_rate_limits)
return {
"status": "healthy",
"timestamp": datetime.now().isoformat(),
"statistics": {
"total_requests_last_hour": total_requests, # Approximate
"active_users": active_users,
"available_servers": 2, # RAG and conversation servers
"total_tools": len(mcp_rag_server.available_tools) + len(mcp_conversation_server.available_tools)
},
"servers": {
"rag_server": {
"status": "healthy",
"tools_count": len(mcp_rag_server.available_tools),
"tools": mcp_rag_server.available_tools
},
"conversation_server": {
"status": "healthy",
"tools_count": len(mcp_conversation_server.available_tools),
"tools": mcp_conversation_server.available_tools
}
}
}
except Exception as e:
logger.error(f"Error getting executor status: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}")
# Health check endpoint
@router.get("/health")
async def health_check():
"""Simple health check endpoint"""
return {
"status": "healthy",
"timestamp": datetime.now().isoformat(),
"service": "mcp_executor"
}

View File

@@ -0,0 +1,238 @@
"""
GT 2.0 MCP Registry API
Manages registration and discovery of MCP servers in the resource cluster.
Provides endpoints for:
- Registering MCP servers
- Listing available MCP servers and tools
- Getting tool schemas
- Server health monitoring
"""
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
import logging
from app.core.security import verify_capability_token
from app.services.mcp_server import SecureMCPWrapper, MCPServerConfig
from app.services.mcp_rag_server import mcp_rag_server
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/mcp", tags=["mcp"])
# Request/Response Models
class MCPServerInfo(BaseModel):
"""Information about an MCP server"""
server_name: str
server_type: str
available_tools: List[str]
status: str
description: str
required_capabilities: List[str]
class MCPToolSchema(BaseModel):
"""MCP tool schema information"""
name: str
description: str
parameters: Dict[str, Any]
server_name: str
class ListServersResponse(BaseModel):
"""Response for listing MCP servers"""
servers: List[MCPServerInfo]
total_count: int
class ListToolsResponse(BaseModel):
"""Response for listing MCP tools"""
tools: List[MCPToolSchema]
total_count: int
servers_count: int
# Global MCP wrapper instance
mcp_wrapper = SecureMCPWrapper()
@router.get("/servers", response_model=ListServersResponse)
async def list_mcp_servers(
knowledge_search_enabled: bool = Query(True, description="Whether dataset/knowledge search is enabled"),
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
):
"""
List all available MCP servers and their status.
Returns information about registered MCP servers that the user
can access based on their capability tokens.
"""
try:
servers = []
if knowledge_search_enabled:
rag_config = mcp_rag_server.get_server_config()
servers.append(MCPServerInfo(
server_name=rag_config.server_name,
server_type=rag_config.server_type,
available_tools=rag_config.available_tools,
status="healthy",
description="Dataset and document search capabilities for RAG operations",
required_capabilities=rag_config.required_capabilities
))
return ListServersResponse(
servers=servers,
total_count=len(servers)
)
except Exception as e:
logger.error(f"Error listing MCP servers: {e}")
raise HTTPException(status_code=500, detail=f"Failed to list servers: {str(e)}")
@router.get("/tools", response_model=ListToolsResponse)
async def list_mcp_tools(
server_name: Optional[str] = Query(None, description="Filter by server name"),
knowledge_search_enabled: bool = Query(True, description="Whether dataset/knowledge search is enabled"),
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
):
"""
List all available MCP tools across servers.
Can be filtered by server name to get tools for a specific server.
"""
try:
all_tools = []
servers_included = 0
if knowledge_search_enabled and (not server_name or server_name == "rag_server"):
rag_schemas = mcp_rag_server.get_tool_schemas()
for tool_name, schema in rag_schemas.items():
all_tools.append(MCPToolSchema(
name=tool_name,
description=schema.get("description", ""),
parameters=schema.get("parameters", {}),
server_name="rag_server"
))
servers_included += 1
return ListToolsResponse(
tools=all_tools,
total_count=len(all_tools),
servers_count=servers_included
)
except Exception as e:
logger.error(f"Error listing MCP tools: {e}")
raise HTTPException(status_code=500, detail=f"Failed to list tools: {str(e)}")
@router.get("/servers/{server_name}/tools")
async def get_server_tools(
server_name: str,
knowledge_search_enabled: bool = Query(True, description="Whether dataset/knowledge search is enabled"),
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
):
"""Get tools and schemas for a specific MCP server"""
try:
if server_name == "rag_server":
if knowledge_search_enabled:
return {
"server_name": server_name,
"server_type": "rag",
"tools": mcp_rag_server.get_tool_schemas()
}
else:
return {
"server_name": server_name,
"server_type": "rag",
"tools": {}
}
else:
raise HTTPException(status_code=404, detail=f"MCP server not found: {server_name}")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting server tools for {server_name}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get server tools: {str(e)}")
@router.get("/servers/{server_name}/health")
async def check_server_health(
server_name: str,
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
):
"""Check health status of a specific MCP server"""
try:
if server_name == "rag_server":
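# Static placeholder health payload; this endpoint does not probe the RAG server directly.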
return {
"server_name": server_name,
"status": "healthy",
"timestamp": "2025-01-15T12:00:00Z",
"response_time_ms": 5,
"tools_available": True
}
else:
raise HTTPException(status_code=404, detail=f"MCP server not found: {server_name}")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error checking health for {server_name}: {e}")
raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
@router.get("/capabilities")
async def get_mcp_capabilities(
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
):
"""
Get MCP capabilities summary for the current user.
Returns what MCP servers and tools the user has access to
based on their capability tokens.
"""
try:
capabilities = {
"user_id": "resource_cluster_user",
"tenant_domain": x_tenant_id or "default",
"available_servers": [
{
"server_name": "rag_server",
"server_type": "rag",
"tools_count": len(mcp_rag_server.available_tools),
"required_capability": "mcp:rag:*"
}
],
"total_tools": len(mcp_rag_server.available_tools),
"access_level": "full"
}
return capabilities
except Exception as e:
logger.error(f"Error getting MCP capabilities: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get capabilities: {str(e)}")
async def initialize_mcp_servers():
"""Initialize and register MCP servers"""
try:
logger.info("Initializing MCP servers...")
rag_config = mcp_rag_server.get_server_config()
logger.info(f"RAG server initialized with {len(rag_config.available_tools)} tools")
logger.info("All MCP servers initialized successfully")
except Exception as e:
logger.error(f"Error initializing MCP servers: {e}")
raise
# Export the initialization function
__all__ = ["router", "initialize_mcp_servers", "mcp_wrapper"]

View File

@@ -0,0 +1,460 @@
"""
Model Management API Endpoints - Simplified for Development
Provides REST API for model registry without capability checks for now.
"""
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, HTTPException, status, Query, Header
from pydantic import BaseModel, Field
from datetime import datetime
import logging
from app.services.model_service import default_model_service as model_service
from app.services.admin_model_config_service import AdminModelConfigService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/models", tags=["Model Management"])
# Initialize admin model config service
admin_model_service = AdminModelConfigService()
class ModelRegistrationRequest(BaseModel):
"""Request model for registering a new model"""
model_id: str = Field(..., description="Unique model identifier")
name: str = Field(..., description="Human-readable model name")
version: str = Field(..., description="Model version")
provider: str = Field(..., description="Model provider (groq, openai, local, etc.)")
model_type: str = Field(..., description="Model type (llm, embedding, image_gen, etc.)")
description: str = Field("", description="Model description")
capabilities: Optional[Dict[str, Any]] = Field(None, description="Model capabilities")
parameters: Optional[Dict[str, Any]] = Field(None, description="Model parameters")
endpoint_url: Optional[str] = Field(None, description="Model endpoint URL")
max_tokens: Optional[int] = Field(4000, description="Maximum tokens per request")
context_window: Optional[int] = Field(4000, description="Context window size")
cost_per_1k_tokens: Optional[float] = Field(0.0, description="Cost per 1000 tokens")
model_config = {"protected_namespaces": ()}
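# Example registration payload (illustrative; provider and model values are placeholders):
# {
#   "model_id": "llama-3.1-70b",
#   "name": "Llama 3.1 70B",
#   "version": "1.0",
#   "provider": "groq",
#   "model_type": "llm",
#   "max_tokens": 4000,
#   "context_window": 8192,
#   "cost_per_1k_tokens": 0.59
# }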
class ModelUpdateRequest(BaseModel):
"""Request model for updating model metadata"""
name: Optional[str] = None
description: Optional[str] = None
deployment_status: Optional[str] = None
health_status: Optional[str] = None
capabilities: Optional[Dict[str, Any]] = None
parameters: Optional[Dict[str, Any]] = None
class ModelUsageRequest(BaseModel):
"""Request model for tracking model usage"""
success: bool = Field(True, description="Whether the request was successful")
latency_ms: Optional[float] = Field(None, description="Request latency in milliseconds")
tokens_used: Optional[int] = Field(None, description="Number of tokens used")
@router.get("/", summary="List all models")
async def list_models(
provider: Optional[str] = Query(None, description="Filter by provider"),
model_type: Optional[str] = Query(None, description="Filter by model type"),
deployment_status: Optional[str] = Query(None, description="Filter by deployment status"),
health_status: Optional[str] = Query(None, description="Filter by health status"),
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for filtering accessible models")
) -> Dict[str, Any]:
"""List all registered models with optional filters"""
try:
# Get models from admin backend via sync service
# If tenant ID is provided, filter to only models accessible to that tenant
if x_tenant_id:
admin_models = await admin_model_service.get_tenant_models(x_tenant_id)
logger.info(f"Retrieved {len(admin_models)} tenant-specific models from admin backend for tenant {x_tenant_id}")
else:
admin_models = await admin_model_service.get_all_models(active_only=True)
logger.info(f"Retrieved {len(admin_models)} models from admin backend")
# Convert admin models to resource cluster format
models = []
for admin_model in admin_models:
model_dict = {
"id": admin_model.model_id, # model_id string for backwards compatibility
"uuid": admin_model.uuid, # Database UUID for unique identification
"name": admin_model.name,
"description": f"{admin_model.provider.title()} model with {admin_model.context_window or 'default'} context window",
"provider": admin_model.provider,
"model_type": admin_model.model_type,
"performance": {
"max_tokens": admin_model.max_tokens or 4096,
"context_window": admin_model.context_window or 4096,
"cost_per_1k_tokens": (admin_model.cost_per_1k_input + admin_model.cost_per_1k_output) / 2,
"latency_p50_ms": 150 # Default estimate, could be enhanced with real metrics
},
"status": {
"health": "healthy" if admin_model.is_active else "unhealthy",
"deployment": "available" if admin_model.is_active else "unavailable"
}
}
models.append(model_dict)
# If no models from admin, return empty list
if not models:
logger.warning("No models configured in admin backend")
models = []
# Apply filters if provided
filtered_models = models
if provider:
filtered_models = [m for m in filtered_models if m["provider"] == provider]
if model_type:
filtered_models = [m for m in filtered_models if m["model_type"] == model_type]
if deployment_status:
filtered_models = [m for m in filtered_models if m["status"]["deployment"] == deployment_status]
if health_status:
filtered_models = [m for m in filtered_models if m["status"]["health"] == health_status]
return {
"models": filtered_models,
"total": len(filtered_models),
"filters": {
"provider": provider,
"model_type": model_type,
"deployment_status": deployment_status,
"health_status": health_status
},
"last_updated": "2025-09-09T13:00:00Z"
}
except Exception as e:
logger.error(f"Error listing models: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to list models"
)
@router.post("/", status_code=status.HTTP_201_CREATED, summary="Register a new model")
async def register_model(
model_request: ModelRegistrationRequest
) -> Dict[str, Any]:
"""Register a new model in the registry"""
try:
model = await model_service.register_model(
model_id=model_request.model_id,
name=model_request.name,
version=model_request.version,
provider=model_request.provider,
model_type=model_request.model_type,
description=model_request.description,
capabilities=model_request.capabilities,
parameters=model_request.parameters,
endpoint_url=model_request.endpoint_url,
max_tokens=model_request.max_tokens,
context_window=model_request.context_window,
cost_per_1k_tokens=model_request.cost_per_1k_tokens
)
return {
"message": "Model registered successfully",
"model": model
}
except Exception as e:
logger.error(f"Error registering model {model_request.model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to register model"
)
@router.get("/{model_id}", summary="Get model details")
async def get_model(
model_id: str,
) -> Dict[str, Any]:
"""Get detailed information about a specific model"""
try:
model = await model_service.get_model(model_id)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model_id} not found"
)
return {"model": model}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting model {model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get model"
)
@router.put("/{model_id}", summary="Update model metadata")
async def update_model(
model_id: str,
update_request: ModelUpdateRequest,
) -> Dict[str, Any]:
"""Update model metadata and status"""
try:
# Check if model exists
model = await model_service.get_model(model_id)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model_id} not found"
)
# Update status fields
if update_request.deployment_status or update_request.health_status:
success = await model_service.update_model_status(
model_id,
deployment_status=update_request.deployment_status,
health_status=update_request.health_status
)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update model status"
)
# For other fields, we'd need to extend the model service
# This is a simplified implementation
updated_model = await model_service.get_model(model_id)
return {
"message": "Model updated successfully",
"model": updated_model
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error updating model {model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update model"
)
@router.delete("/{model_id}", summary="Retire a model")
async def retire_model(
model_id: str,
reason: str = Query("", description="Reason for retirement"),
) -> Dict[str, Any]:
"""Retire a model (mark as no longer available)"""
try:
success = await model_service.retire_model(model_id, reason)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model_id} not found"
)
return {
"message": f"Model {model_id} retired successfully",
"reason": reason
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retiring model {model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retire model"
)
@router.post("/{model_id}/usage", summary="Track model usage")
async def track_model_usage(
model_id: str,
usage_request: ModelUsageRequest,
) -> Dict[str, Any]:
"""Track usage and performance metrics for a model"""
try:
await model_service.track_model_usage(
model_id,
success=usage_request.success,
latency_ms=usage_request.latency_ms
)
return {
"message": "Usage tracked successfully",
"model_id": model_id
}
except Exception as e:
logger.error(f"Error tracking usage for model {model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
@router.get("/{model_id}/health", summary="Check model health")
async def check_model_health(
model_id: str,
) -> Dict[str, Any]:
"""Check the health status of a specific model"""
try:
health_result = await model_service.check_model_health(model_id)
# codeql[py/stack-trace-exposure] returns health status dict, not error details
return {
"model_id": model_id,
"health": health_result
}
except Exception as e:
logger.error(f"Error checking health for model {model_id}: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
@router.get("/health/bulk", summary="Bulk health check")
async def bulk_health_check(
) -> Dict[str, Any]:
"""Check health of all registered models"""
try:
health_results = await model_service.bulk_health_check()
return {
"health_check": health_results,
"timestamp": "2024-01-01T00:00:00Z" # Would use actual timestamp
}
except Exception as e:
logger.error(f"Error in bulk health check: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
@router.get("/analytics", summary="Get model analytics")
async def get_model_analytics(
model_id: Optional[str] = Query(None, description="Specific model ID"),
timeframe_hours: int = Query(24, description="Analytics timeframe in hours"),
) -> Dict[str, Any]:
"""Get analytics for model usage and performance"""
try:
analytics = await model_service.get_model_analytics(
model_id=model_id,
timeframe_hours=timeframe_hours
)
return {
"analytics": analytics,
"timeframe_hours": timeframe_hours,
"generated_at": "2024-01-01T00:00:00Z" # Would use actual timestamp
}
except Exception as e:
logger.error(f"Error getting analytics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get analytics"
)
@router.post("/initialize", summary="Initialize default models")
async def initialize_default_models(
) -> Dict[str, Any]:
"""Initialize the registry with default models"""
try:
await model_service.initialize_default_models()
models = await model_service.list_models()
return {
"message": "Default models initialized successfully",
"total_models": len(models)
}
except Exception as e:
logger.error(f"Error initializing default models: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to initialize default models"
)
@router.get("/providers/available", summary="Get available providers")
async def get_available_providers(
) -> Dict[str, Any]:
"""Get list of available model providers"""
try:
models = await model_service.list_models()
providers = {}
for model in models:
provider = model["provider"]
if provider not in providers:
providers[provider] = {
"name": provider,
"model_count": 0,
"model_types": set(),
"status": "available"
}
providers[provider]["model_count"] += 1
providers[provider]["model_types"].add(model["model_type"])
# Convert sets to lists for JSON serialization
for provider_info in providers.values():
provider_info["model_types"] = list(provider_info["model_types"])
return {
"providers": list(providers.values()),
"total_providers": len(providers)
}
except Exception as e:
logger.error(f"Error getting available providers: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get available providers"
)
@router.post("/sync", summary="Force sync from admin cluster")
async def force_sync_models() -> Dict[str, Any]:
"""Force immediate sync of models from admin cluster"""
try:
await admin_model_service.force_sync()
models = await admin_model_service.get_all_models(active_only=True)
return {
"message": "Models synced successfully",
"models_count": len(models),
"sync_timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error forcing model sync: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to sync models"
)
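# Illustrative usage sketch (not part of the router): given the "/api/v1/models"
# prefix above, a tenant-scoped client can list its accessible models with the
# X-Tenant-ID header. Host, port, and tenant domain below are placeholders.
#
#   import httpx
#   resp = httpx.get(
#       "http://resource-cluster:8000/api/v1/models/",
#       headers={"X-Tenant-ID": "test-company"},
#       params={"provider": "groq"},
#   )
#   accessible_models = resp.json()["models"]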

View File

@@ -0,0 +1,358 @@
"""
RAG API endpoints for Resource Cluster
STATELESS processing of documents and embeddings.
All data is immediately returned to tenant - nothing is stored.
"""
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Body
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import logging
from app.core.backends.document_processor import DocumentProcessorBackend, ChunkingStrategy
from app.core.backends.embedding_backend import EmbeddingBackend
from app.core.security import verify_capability_token
logger = logging.getLogger(__name__)
router = APIRouter(tags=["rag"])
class ProcessDocumentRequest(BaseModel):
"""Request for document processing"""
document_type: str = Field(..., description="File type (.pdf, .docx, .txt, .md, .html)")
chunking_strategy: str = Field(default="hybrid", description="Chunking strategy")
chunk_size: int = Field(default=512, description="Target chunk size in tokens")
chunk_overlap: int = Field(default=128, description="Overlap between chunks")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Non-sensitive metadata")
class GenerateEmbeddingsRequest(BaseModel):
"""Request for embedding generation"""
texts: List[str] = Field(..., description="Texts to embed")
instruction: Optional[str] = Field(default=None, description="Optional instruction for embeddings")
class ProcessDocumentResponse(BaseModel):
"""Response from document processing"""
chunks: List[Dict[str, Any]] = Field(..., description="Document chunks with metadata")
chunk_count: int = Field(..., description="Number of chunks generated")
processing_time_ms: int = Field(..., description="Processing time in milliseconds")
class GenerateEmbeddingsResponse(BaseModel):
"""Response from embedding generation"""
embeddings: List[List[float]] = Field(..., description="Generated embeddings")
embedding_count: int = Field(..., description="Number of embeddings generated")
dimensions: int = Field(..., description="Embedding dimensions")
model: str = Field(..., description="Model used for embeddings")
# Initialize backends
document_processor = DocumentProcessorBackend()
embedding_backend = EmbeddingBackend()
@router.post("/process-document", response_model=ProcessDocumentResponse)
async def process_document(
file: UploadFile = File(...),
request: ProcessDocumentRequest = Depends(),
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> ProcessDocumentResponse:
"""
Process a document into chunks - STATELESS operation.
Security:
- No user data is stored
- Document processed in memory only
- Immediate response with chunks
- Memory cleared after processing
"""
import time
start_time = time.time()
try:
# Verify RAG capabilities
if "rag_processing" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="RAG processing capability not granted"
)
# Read file content (will be cleared from memory)
content = await file.read()
# Validate document
validation = await document_processor.validate_document(
content_size=len(content),
document_type=request.document_type
)
if not validation["valid"]:
raise HTTPException(
status_code=400,
detail=f"Document validation failed: {validation['errors']}"
)
# Create chunking strategy
strategy = ChunkingStrategy(
strategy_type=request.chunking_strategy,
chunk_size=request.chunk_size,
chunk_overlap=request.chunk_overlap
)
# Process document (stateless)
chunks = await document_processor.process_document(
content=content,
document_type=request.document_type,
strategy=strategy,
metadata={
"tenant_id": capabilities.get("tenant_id"),
"document_type": request.document_type,
"processing_timestamp": time.time()
}
)
# Clear content from memory
del content
processing_time = int((time.time() - start_time) * 1000)
logger.info(
f"Processed document into {len(chunks)} chunks for tenant "
f"{capabilities.get('tenant_id')} (STATELESS)"
)
return ProcessDocumentResponse(
chunks=chunks,
chunk_count=len(chunks),
processing_time_ms=processing_time
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error processing document: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-embeddings", response_model=GenerateEmbeddingsResponse)
async def generate_embeddings(
request: GenerateEmbeddingsRequest,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> GenerateEmbeddingsResponse:
"""
Generate embeddings for texts - STATELESS operation.
Security:
- No text content is stored
- Embeddings generated via GPU cluster
- Immediate response with vectors
- Memory cleared after generation
"""
try:
# Verify embedding capabilities
if "embedding_generation" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="Embedding generation capability not granted"
)
# Validate texts
validation = await embedding_backend.validate_texts(request.texts)
if not validation["valid"]:
raise HTTPException(
status_code=400,
detail=f"Text validation failed: {validation['errors']}"
)
# Generate embeddings (stateless)
embeddings = await embedding_backend.generate_embeddings(
texts=request.texts,
instruction=request.instruction,
tenant_id=capabilities.get("tenant_id"),
request_id=capabilities.get("request_id")
)
logger.info(
f"Generated {len(embeddings)} embeddings for tenant "
f"{capabilities.get('tenant_id')} (STATELESS)"
)
return GenerateEmbeddingsResponse(
embeddings=embeddings,
embedding_count=len(embeddings),
dimensions=embedding_backend.embedding_dimensions,
model=embedding_backend.model_name
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-query-embeddings", response_model=GenerateEmbeddingsResponse)
async def generate_query_embeddings(
request: GenerateEmbeddingsRequest,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> GenerateEmbeddingsResponse:
"""
Generate embeddings specifically for queries - STATELESS operation.
Uses BGE-M3 query instruction for better retrieval performance.
"""
try:
# Verify embedding capabilities
if "embedding_generation" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="Embedding generation capability not granted"
)
# Validate queries
validation = await embedding_backend.validate_texts(request.texts)
if not validation["valid"]:
raise HTTPException(
status_code=400,
detail=f"Query validation failed: {validation['errors']}"
)
# Generate query embeddings (stateless)
embeddings = await embedding_backend.generate_query_embeddings(
queries=request.texts,
tenant_id=capabilities.get("tenant_id"),
request_id=capabilities.get("request_id")
)
logger.info(
f"Generated {len(embeddings)} query embeddings for tenant "
f"{capabilities.get('tenant_id')} (STATELESS)"
)
return GenerateEmbeddingsResponse(
embeddings=embeddings,
embedding_count=len(embeddings),
dimensions=embedding_backend.embedding_dimensions,
model=embedding_backend.model_name
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error generating query embeddings: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-document-embeddings", response_model=GenerateEmbeddingsResponse)
async def generate_document_embeddings(
request: GenerateEmbeddingsRequest,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> GenerateEmbeddingsResponse:
"""
Generate embeddings specifically for documents - STATELESS operation.
Uses BGE-M3 document configuration for optimal indexing.
"""
try:
# Verify embedding capabilities
if "embedding_generation" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="Embedding generation capability not granted"
)
# Validate documents
validation = await embedding_backend.validate_texts(request.texts)
if not validation["valid"]:
raise HTTPException(
status_code=400,
detail=f"Document validation failed: {validation['errors']}"
)
# Generate document embeddings (stateless)
embeddings = await embedding_backend.generate_document_embeddings(
documents=request.texts,
tenant_id=capabilities.get("tenant_id"),
request_id=capabilities.get("request_id")
)
logger.info(
f"Generated {len(embeddings)} document embeddings for tenant "
f"{capabilities.get('tenant_id')} (STATELESS)"
)
return GenerateEmbeddingsResponse(
embeddings=embeddings,
embedding_count=len(embeddings),
dimensions=embedding_backend.embedding_dimensions,
model=embedding_backend.model_name
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error generating document embeddings: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/health")
async def health_check() -> Dict[str, Any]:
"""
Check RAG processing health - no user data exposed.
"""
try:
doc_health = await document_processor.check_health()
embed_health = await embedding_backend.check_health()
overall_status = "healthy"
if doc_health["status"] != "healthy" or embed_health["status"] != "healthy":
overall_status = "degraded"
# codeql[py/stack-trace-exposure] returns health status dict, not error details
return {
"status": overall_status,
"document_processor": doc_health,
"embedding_backend": embed_health,
"stateless": True,
"memory_management": "active"
}
except Exception as e:
logger.error(f"Health check failed: {e}")
return {
"status": "unhealthy",
"error": "Health check failed"
}
@router.get("/capabilities")
async def get_rag_capabilities() -> Dict[str, Any]:
"""
Get RAG processing capabilities - no sensitive data.
"""
return {
"document_processor": {
"supported_formats": document_processor.supported_formats,
"chunking_strategies": ["fixed", "semantic", "hierarchical", "hybrid"],
"default_chunk_size": document_processor.default_chunk_size,
"default_chunk_overlap": document_processor.default_chunk_overlap
},
"embedding_backend": {
"model": embedding_backend.model_name,
"dimensions": embedding_backend.embedding_dimensions,
"max_batch_size": embedding_backend.max_batch_size,
"max_sequence_length": embedding_backend.max_sequence_length
},
"security": {
"stateless_processing": True,
"memory_cleanup": True,
"data_encryption": True,
"tenant_isolation": True
}
}
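# Illustrative call sketch (not part of the router): the mount prefix and the way
# the capability token is conveyed depend on how this router is included in the
# app and on verify_capability_token; names below are placeholders.
#
#   import httpx
#   resp = httpx.post(
#       f"{rag_prefix}/generate-embeddings",
#       headers=capability_headers,  # token must grant "embedding_generation"
#       json={"texts": ["What is the GT 2.0 Resource Cluster?"]},
#   )
#   vectors = resp.json()["embeddings"]  # embedding_count vectors of size `dimensions`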

View File

@@ -0,0 +1,404 @@
"""
GT 2.0 Resource Cluster - Resource Management API with CB-REST Standards
This module handles non-AI endpoints using the CB-REST standard.
AI inference endpoints maintain OpenAI compatibility.
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, Query, Request, BackgroundTasks
from pydantic import BaseModel, Field
import asyncio
import logging
import uuid
from datetime import datetime, timedelta
from app.core.api_standards import (
format_response,
format_error,
ErrorCode,
APIError
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/resources", tags=["Resource Management"])
# Request/Response Models
class HealthCheckRequest(BaseModel):
resource_id: str = Field(..., description="Resource identifier")
deep_check: bool = Field(False, description="Perform deep health check")
class RAGProcessRequest(BaseModel):
document_content: str = Field(..., description="Document content to process")
chunking_strategy: str = Field("semantic", description="Chunking strategy")
chunk_size: int = Field(1000, ge=100, le=10000)
chunk_overlap: int = Field(100, ge=0, le=500)
embedding_model: str = Field("text-embedding-3-small")
class SemanticSearchRequest(BaseModel):
query: str = Field(..., description="Search query")
collection_id: str = Field(..., description="Vector collection ID")
top_k: int = Field(10, ge=1, le=100)
relevance_threshold: float = Field(0.7, ge=0.0, le=1.0)
filters: Optional[Dict[str, Any]] = None
class AgentExecutionRequest(BaseModel):
agent_type: str = Field(..., description="Agent type")
task: Dict[str, Any] = Field(..., description="Task configuration")
timeout: int = Field(300, ge=10, le=3600, description="Timeout in seconds")
execution_context: Optional[Dict[str, Any]] = None
@router.get("/health/system")
async def system_health(request: Request):
"""
Get overall system health status
CB-REST Capability Required: health:system:read
"""
try:
health_status = {
"overall_health": "healthy",
"service_statuses": [
{"service": "ai_inference", "status": "healthy", "latency_ms": 45},
{"service": "rag_processing", "status": "healthy", "latency_ms": 120},
{"service": "vector_storage", "status": "healthy", "latency_ms": 30},
{"service": "agent_orchestration", "status": "healthy", "latency_ms": 85}
],
"resource_utilization": {
"cpu_percent": 42.5,
"memory_percent": 68.3,
"gpu_percent": 35.0,
"disk_percent": 55.2
},
"performance_metrics": {
"requests_per_second": 145,
"average_latency_ms": 95,
"error_rate_percent": 0.02,
"active_connections": 234
},
"timestamp": datetime.utcnow().isoformat()
}
return format_response(
data=health_status,
capability_used="health:system:read",
request_id=getattr(request.state, 'request_id', None)
)
except Exception as e:
logger.error(f"Failed to get system health: {e}")
return format_error(
code=ErrorCode.SYSTEM_ERROR,
message="Internal server error",
capability_used="health:system:read",
request_id=getattr(request.state, 'request_id', None)
)
@router.post("/health/check")
async def check_resource_health(
request: Request,
health_req: HealthCheckRequest,
background_tasks: BackgroundTasks
):
"""
Perform health check on a specific resource
CB-REST Capability Required: health:resource:check
"""
try:
# Mock health check result
health_result = {
"resource_id": health_req.resource_id,
"status": "healthy",
"latency_ms": 87,
"last_successful_request": datetime.utcnow().isoformat(),
"error_count_24h": 3,
"success_rate_24h": 99.97,
"details": {
"endpoint_reachable": True,
"authentication_valid": True,
"rate_limit_ok": True,
"response_time_acceptable": True
}
}
if health_req.deep_check:
health_result["deep_check_results"] = {
"model_loaded": True,
"memory_usage_mb": 2048,
"inference_test_passed": True,
"test_latency_ms": 145
}
return format_response(
data=health_result,
capability_used="health:resource:check",
request_id=getattr(request.state, 'request_id', None)
)
except Exception as e:
logger.error(f"Failed to check resource health: {e}")
return format_error(
code=ErrorCode.SYSTEM_ERROR,
message="Internal server error",
capability_used="health:resource:check",
request_id=getattr(request.state, 'request_id', None)
)
@router.post("/rag/process-document")
async def process_document(
request: Request,
rag_req: RAGProcessRequest,
background_tasks: BackgroundTasks
):
"""
Process document for RAG pipeline
CB-REST Capability Required: rag:document:process
"""
try:
processing_id = str(uuid.uuid4())
# Start async processing
background_tasks.add_task(
process_document_async,
processing_id,
rag_req
)
return format_response(
data={
"processing_id": processing_id,
"status": "processing",
"chunk_preview": [
{
"chunk_id": f"chunk_{i}",
"text": f"Sample chunk {i} from document...",
"metadata": {"position": i, "size": rag_req.chunk_size}
}
for i in range(3)
],
"estimated_completion": (datetime.utcnow() + timedelta(seconds=30)).isoformat()
},
capability_used="rag:document:process",
request_id=getattr(request.state, 'request_id', None)
)
except Exception as e:
logger.error(f"Failed to process document: {e}")
return format_error(
code=ErrorCode.SYSTEM_ERROR,
message="Internal server error",
capability_used="rag:document:process",
request_id=getattr(request.state, 'request_id', None)
)
@router.post("/rag/semantic-search")
async def semantic_search(
request: Request,
search_req: SemanticSearchRequest
):
"""
Perform semantic search in vector database
CB-REST Capability Required: rag:search:execute
"""
try:
# Mock search results
results = [
{
"document_id": f"doc_{i}",
"chunk_id": f"chunk_{i}",
"text": f"Relevant text snippet {i} matching query: {search_req.query[:50]}...",
"relevance_score": 0.95 - (i * 0.05),
"metadata": {
"source": f"document_{i}.pdf",
"page": i + 1,
"timestamp": datetime.utcnow().isoformat()
}
}
for i in range(min(search_req.top_k, 5))
]
return format_response(
data={
"results": results,
"query_embedding": [0.1] * 10, # Truncated for brevity
"search_metadata": {
"collection_id": search_req.collection_id,
"documents_searched": 1500,
"search_time_ms": 145,
"model_used": "text-embedding-3-small"
}
},
capability_used="rag:search:execute",
request_id=getattr(request.state, 'request_id', None)
)
except Exception as e:
logger.error(f"Failed to perform semantic search: {e}")
return format_error(
code=ErrorCode.SYSTEM_ERROR,
message="Internal server error",
capability_used="rag:search:execute",
request_id=getattr(request.state, 'request_id', None)
)
@router.post("/agents/execute")
async def execute_agent(
request: Request,
agent_req: AgentExecutionRequest,
background_tasks: BackgroundTasks
):
"""
Execute an agentic workflow
CB-REST Capability Required: agent:*:execute
"""
try:
execution_id = str(uuid.uuid4())
# Start async agent execution
background_tasks.add_task(
execute_agent_async,
execution_id,
agent_req
)
return format_response(
data={
"execution_id": execution_id,
"status": "queued",
"estimated_duration": agent_req.timeout // 2,
"resource_allocation": {
"cpu_cores": 2,
"memory_mb": 4096,
"gpu_allocation": 0.25
}
},
capability_used="agent:*:execute",
request_id=getattr(request.state, 'request_id', None)
)
except Exception as e:
logger.error(f"Failed to execute agent: {e}")
return format_error(
code=ErrorCode.SYSTEM_ERROR,
message="Internal server error",
capability_used="agent:*:execute",
request_id=getattr(request.state, 'request_id', None)
)
@router.get("/agents/{execution_id}/status")
async def get_agent_status(
request: Request,
execution_id: str
):
"""
Get agent execution status
CB-REST Capability Required: agent:{execution_id}:status
"""
try:
# Mock status
status = {
"execution_id": execution_id,
"status": "running",
"progress_percent": 65,
"current_task": {
"name": "data_analysis",
"status": "in_progress",
"started_at": datetime.utcnow().isoformat()
},
"memory_usage": {
"working_memory_mb": 512,
"context_size": 8192,
"tool_calls_made": 12
},
"performance_metrics": {
"steps_completed": 8,
"total_steps": 12,
"average_step_time_ms": 2500,
"errors_encountered": 0
}
}
return format_response(
data=status,
capability_used=f"agent:{execution_id}:status",
request_id=getattr(request.state, 'request_id', None)
)
except Exception as e:
logger.error(f"Failed to get agent status: {e}")
return format_error(
code=ErrorCode.SYSTEM_ERROR,
message="Internal server error",
capability_used=f"agent:{execution_id}:status",
request_id=getattr(request.state, 'request_id', None)
)
@router.post("/usage/record")
async def record_usage(
request: Request,
operation_type: str,
resource_id: str,
usage_metrics: Dict[str, Any]
):
"""
Record resource usage for billing and analytics
CB-REST Capability Required: usage:*:write
"""
try:
usage_record = {
"record_id": str(uuid.uuid4()),
"recorded": True,
"updated_quotas": {
"tokens_remaining": 950000,
"requests_remaining": 9500,
"cost_accumulated_cents": 125
},
"warnings": []
}
# Check for quota warnings
if usage_metrics.get("tokens_used", 0) > 10000:
usage_record["warnings"].append({
"type": "high_token_usage",
"message": "High token usage detected",
"threshold": 10000,
"actual": usage_metrics.get("tokens_used", 0)
})
return format_response(
data=usage_record,
capability_used="usage:*:write",
request_id=getattr(request.state, 'request_id', None)
)
except Exception as e:
logger.error(f"Failed to record usage: {e}")
return format_error(
code=ErrorCode.SYSTEM_ERROR,
message="Internal server error",
capability_used="usage:*:write",
request_id=getattr(request.state, 'request_id', None)
)
# Async helper functions
async def process_document_async(processing_id: str, rag_req: RAGProcessRequest):
"""Background task for document processing"""
# Implement actual document processing logic here
await asyncio.sleep(30) # Simulate processing
logger.info(f"Document processing completed: {processing_id}")
async def execute_agent_async(execution_id: str, agent_req: AgentExecutionRequest):
"""Background task for agent execution"""
# Implement actual agent execution logic here
await asyncio.sleep(agent_req.timeout // 2) # Simulate execution
logger.info(f"Agent execution completed: {execution_id}")

View File

@@ -0,0 +1,569 @@
"""
GT 2.0 Resource Cluster - External Services API
Orchestrate external web services with perfect tenant isolation
"""
from fastapi import APIRouter, HTTPException, Depends, Body
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import logging
from datetime import datetime, timedelta
from app.core.security import verify_capability_token
from app.services.service_manager import ServiceManager, ServiceInstance
logger = logging.getLogger(__name__)
router = APIRouter(tags=["services"])
# Initialize service manager
service_manager = ServiceManager()
class CreateServiceRequest(BaseModel):
"""Request to create a new service instance"""
service_type: str = Field(..., description="Service type: ctfd, canvas, guacamole")
config_overrides: Optional[Dict[str, Any]] = Field(default=None, description="Custom configuration overrides")
class ServiceInstanceResponse(BaseModel):
"""Service instance details response"""
instance_id: str
tenant_id: str
service_type: str
status: str
endpoint_url: str
sso_token: Optional[str]
created_at: str
last_heartbeat: str
resource_usage: Dict[str, Any]
class ServiceHealthResponse(BaseModel):
"""Service health status response"""
status: str
instance_status: str
endpoint: str
last_check: str
pod_phase: Optional[str] = None
restart_count: Optional[int] = None
error: Optional[str] = None
class ServiceListResponse(BaseModel):
"""List of service instances response"""
instances: List[ServiceInstanceResponse]
total: int
class SSOTokenResponse(BaseModel):
"""SSO token generation response"""
token: str
expires_at: str
iframe_config: Dict[str, Any]
@router.post("/instances", response_model=ServiceInstanceResponse)
async def create_service_instance(
request: CreateServiceRequest,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> ServiceInstanceResponse:
"""
Create a new external service instance for a tenant.
Supports:
- CTFd cybersecurity challenges platform
- Canvas LMS learning management system
- Guacamole remote desktop access
"""
try:
# Verify external services capability
if "external_services" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="External services capability not granted"
)
# Validate service type
supported_services = ["ctfd", "canvas", "guacamole"]
if request.service_type not in supported_services:
raise HTTPException(
status_code=400,
detail=f"Unsupported service type. Supported: {supported_services}"
)
# Extract tenant ID from capabilities
tenant_id = capabilities.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=400,
detail="Tenant ID not found in capabilities"
)
# Create service instance
instance = await service_manager.create_service_instance(
tenant_id=tenant_id,
service_type=request.service_type,
config_overrides=request.config_overrides
)
logger.info(
f"Created {request.service_type} instance {instance.instance_id} "
f"for tenant {tenant_id}"
)
return ServiceInstanceResponse(
instance_id=instance.instance_id,
tenant_id=instance.tenant_id,
service_type=instance.service_type,
status=instance.status,
endpoint_url=instance.endpoint_url,
sso_token=instance.sso_token,
created_at=instance.created_at.isoformat(),
last_heartbeat=instance.last_heartbeat.isoformat(),
resource_usage=instance.resource_usage or {}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to create service instance: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/instances/{instance_id}", response_model=ServiceInstanceResponse)
async def get_service_instance(
instance_id: str,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> ServiceInstanceResponse:
"""Get details of a specific service instance"""
try:
# Verify external services capability
if "external_services" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="External services capability not granted"
)
instance = await service_manager.get_service_instance(instance_id)
if not instance:
raise HTTPException(
status_code=404,
detail=f"Service instance {instance_id} not found"
)
# Verify tenant access
tenant_id = capabilities.get("tenant_id")
if instance.tenant_id != tenant_id:
raise HTTPException(
status_code=403,
detail="Access denied to this service instance"
)
return ServiceInstanceResponse(
instance_id=instance.instance_id,
tenant_id=instance.tenant_id,
service_type=instance.service_type,
status=instance.status,
endpoint_url=instance.endpoint_url,
sso_token=instance.sso_token,
created_at=instance.created_at.isoformat(),
last_heartbeat=instance.last_heartbeat.isoformat(),
resource_usage=instance.resource_usage or {}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get service instance {instance_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/tenant/{tenant_id}", response_model=ServiceListResponse)
async def list_tenant_services(
tenant_id: str,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> ServiceListResponse:
"""List all service instances for a tenant"""
try:
# Verify external services capability
if "external_services" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="External services capability not granted"
)
# Verify tenant access
if capabilities.get("tenant_id") != tenant_id:
raise HTTPException(
status_code=403,
detail="Access denied to this tenant's services"
)
instances = await service_manager.list_tenant_instances(tenant_id)
instance_responses = [
ServiceInstanceResponse(
instance_id=instance.instance_id,
tenant_id=instance.tenant_id,
service_type=instance.service_type,
status=instance.status,
endpoint_url=instance.endpoint_url,
sso_token=instance.sso_token,
created_at=instance.created_at.isoformat(),
last_heartbeat=instance.last_heartbeat.isoformat(),
resource_usage=instance.resource_usage or {}
)
for instance in instances
]
return ServiceListResponse(
instances=instance_responses,
total=len(instance_responses)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to list services for tenant {tenant_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/instances/{instance_id}")
async def stop_service_instance(
instance_id: str,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> Dict[str, Any]:
"""Stop and remove a service instance"""
try:
# Verify external services capability
if "external_services" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="External services capability not granted"
)
instance = await service_manager.get_service_instance(instance_id)
if not instance:
raise HTTPException(
status_code=404,
detail=f"Service instance {instance_id} not found"
)
# Verify tenant access
tenant_id = capabilities.get("tenant_id")
if instance.tenant_id != tenant_id:
raise HTTPException(
status_code=403,
detail="Access denied to this service instance"
)
success = await service_manager.stop_service_instance(instance_id)
if not success:
raise HTTPException(
status_code=500,
detail=f"Failed to stop service instance {instance_id}"
)
logger.info(
f"Stopped {instance.service_type} instance {instance_id} "
f"for tenant {tenant_id}"
)
return {
"success": True,
"message": f"Service instance {instance_id} stopped successfully",
"stopped_at": datetime.utcnow().isoformat()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to stop service instance {instance_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/health/{instance_id}", response_model=ServiceHealthResponse)
async def get_service_health(
instance_id: str,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> ServiceHealthResponse:
"""Get health status of a service instance"""
try:
# Verify external services capability
if "external_services" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="External services capability not granted"
)
instance = await service_manager.get_service_instance(instance_id)
if not instance:
raise HTTPException(
status_code=404,
detail=f"Service instance {instance_id} not found"
)
# Verify tenant access
tenant_id = capabilities.get("tenant_id")
if instance.tenant_id != tenant_id:
raise HTTPException(
status_code=403,
detail="Access denied to this service instance"
)
health = await service_manager.get_service_health(instance_id)
return ServiceHealthResponse(
status=health.get("status", "unknown"),
instance_status=health.get("instance_status", "unknown"),
endpoint=health.get("endpoint", instance.endpoint_url),
last_check=health.get("last_check", datetime.utcnow().isoformat()),
pod_phase=health.get("pod_phase"),
restart_count=health.get("restart_count"),
error=health.get("error")
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get health for service instance {instance_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/sso-token/{instance_id}", response_model=SSOTokenResponse)
async def generate_sso_token(
instance_id: str,
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> SSOTokenResponse:
"""Generate SSO token for iframe embedding"""
try:
# Verify external services capability
if "external_services" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="External services capability not granted"
)
instance = await service_manager.get_service_instance(instance_id)
if not instance:
raise HTTPException(
status_code=404,
detail=f"Service instance {instance_id} not found"
)
# Verify tenant access
tenant_id = capabilities.get("tenant_id")
if instance.tenant_id != tenant_id:
raise HTTPException(
status_code=403,
detail="Access denied to this service instance"
)
# Generate new SSO token
sso_token = await service_manager._generate_sso_token(instance)
# Update instance with new token
instance.sso_token = sso_token
await service_manager._persist_instance(instance)
# Generate iframe configuration
iframe_config = {
"src": f"{instance.endpoint_url}?sso_token={sso_token}",
"sandbox": [
"allow-same-origin",
"allow-scripts",
"allow-forms",
"allow-popups",
"allow-modals"
],
"allow": "camera; microphone; clipboard-read; clipboard-write",
"referrerpolicy": "strict-origin-when-cross-origin",
"loading": "lazy"
}
# Set security policies based on service type
if instance.service_type == "guacamole":
iframe_config["sandbox"].extend([
"allow-pointer-lock",
"allow-fullscreen"
])
elif instance.service_type == "ctfd":
iframe_config["sandbox"].extend([
"allow-downloads",
"allow-top-navigation-by-user-activation"
])
expires_at = (datetime.utcnow() + timedelta(hours=24)).isoformat()  # Token expires in 24 hours
logger.info(
f"Generated SSO token for {instance.service_type} instance "
f"{instance_id} for tenant {tenant_id}"
)
return SSOTokenResponse(
token=sso_token,
expires_at=expires_at,
iframe_config=iframe_config
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to generate SSO token for {instance_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/templates")
async def get_service_templates(
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> Dict[str, Any]:
"""Get available service templates and their capabilities"""
try:
# Verify external services capability
if "external_services" not in capabilities.get("resources", []):
raise HTTPException(
status_code=403,
detail="External services capability not granted"
)
# Return sanitized template information (no sensitive config)
templates = {
"ctfd": {
"name": "CTFd Platform",
"description": "Cybersecurity capture-the-flag challenges and competitions",
"category": "cybersecurity",
"features": [
"Challenge creation and management",
"Team-based competitions",
"Scoring and leaderboards",
"User management and registration",
"Real-time updates and notifications"
],
"resource_requirements": {
"memory": "2Gi",
"cpu": "1000m",
"storage": "7Gi"
},
"estimated_startup_time": "2-3 minutes",
"ports": {"http": 8000},
"sso_supported": True
},
"canvas": {
"name": "Canvas LMS",
"description": "Learning management system for educational courses",
"category": "education",
"features": [
"Course creation and management",
"Assignment and grading system",
"Discussion forums and messaging",
"Grade book and analytics",
"Integration with external tools"
],
"resource_requirements": {
"memory": "4Gi",
"cpu": "2000m",
"storage": "30Gi"
},
"estimated_startup_time": "3-5 minutes",
"ports": {"http": 3000},
"sso_supported": True
},
"guacamole": {
"name": "Apache Guacamole",
"description": "Remote desktop access for cyber lab environments",
"category": "remote_access",
"features": [
"RDP, VNC, and SSH connections",
"Session recording and playback",
"Multi-user concurrent access",
"Connection sharing and collaboration",
"File transfer capabilities"
],
"resource_requirements": {
"memory": "1Gi",
"cpu": "500m",
"storage": "11Gi"
},
"estimated_startup_time": "2-4 minutes",
"ports": {"http": 8080},
"sso_supported": True
}
}
return {
"templates": templates,
"total": len(templates),
"categories": list(set(t["category"] for t in templates.values())),
"extensible": True,
"note": "Additional service templates can be added through the GT 2.0 extensibility framework"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get service templates: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/capabilities")
async def get_service_capabilities() -> Dict[str, Any]:
"""Get service management capabilities - no authentication required"""
return {
"service_orchestration": {
"platform": "kubernetes",
"isolation": "namespace_based",
"network_policies": True,
"resource_quotas": True,
"auto_scaling": False, # Fixed replicas for now
"health_monitoring": True,
"automatic_recovery": True
},
"supported_services": [
"ctfd",
"canvas",
"guacamole"
],
"security_features": {
"tenant_isolation": True,
"container_security": True,
"network_isolation": True,
"sso_integration": True,
"encrypted_storage": True,
"capability_based_auth": True
},
"resource_management": {
"cpu_limits": True,
"memory_limits": True,
"storage_quotas": True,
"persistent_volumes": True,
"automatic_cleanup": True
},
"deployment_features": {
"rolling_updates": True,
"health_checks": True,
"restart_policies": True,
"ingress_management": True,
"tls_termination": True,
"certificate_management": True
}
}
@router.post("/cleanup/orphaned")
async def cleanup_orphaned_resources(
capabilities: Dict[str, Any] = Depends(verify_capability_token)
) -> Dict[str, Any]:
"""Clean up orphaned Kubernetes resources"""
try:
# Verify admin capabilities (this is a dangerous operation)
if "admin" not in capabilities.get("user_type", ""):
raise HTTPException(
status_code=403,
detail="Admin privileges required for cleanup operations"
)
await service_manager.cleanup_orphaned_resources()
return {
"success": True,
"message": "Orphaned resource cleanup completed",
"cleanup_time": datetime.utcnow().isoformat()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to cleanup orphaned resources: {e}")
raise HTTPException(status_code=500, detail=str(e))
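# Illustrative lifecycle sketch (not part of the router): a tenant UI holding the
# "external_services" capability would typically drive these endpoints in order.
# The exact mount prefix depends on how this router is included in the app.
#
#   1. POST   /instances                {"service_type": "ctfd"}   -> instance_id, endpoint_url
#   2. GET    /health/{instance_id}     poll until status == "healthy"
#   3. POST   /sso-token/{instance_id}  -> token + iframe_config for embedding
#   4. DELETE /instances/{instance_id}  when the lab session ends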

View File

@@ -0,0 +1 @@
# API clients for external service communication

View File

@@ -0,0 +1,219 @@
"""
API Key Client for fetching tenant-specific API keys from Control Panel.
This client handles:
- Fetching decrypted API keys from Control Panel's internal API
- 5-minute in-memory caching to reduce database calls
- Service-to-service authentication
- NO FALLBACKS - per GT 2.0 principles
"""
import asyncio
import logging
import time
from typing import Dict, Any, Optional
from dataclasses import dataclass
import httpx
logger = logging.getLogger(__name__)
@dataclass
class CachedAPIKey:
"""Cached API key entry with expiration tracking"""
api_key: str
api_secret: Optional[str]
metadata: Dict[str, Any]
fetched_at: float
def is_expired(self, ttl_seconds: int = 300) -> bool:
"""Check if cache entry has expired (default 5 minutes)"""
return (time.time() - self.fetched_at) > ttl_seconds
class APIKeyNotConfiguredError(Exception):
"""Raised when no API key is configured for a tenant/provider"""
pass
class APIKeyClient:
"""
Client for fetching tenant API keys from Control Panel.
Features:
- 5-minute TTL cache for API keys
- Service-to-service authentication
- NO fallback to environment variables (per GT 2.0 NO FALLBACKS principle)
"""
CACHE_TTL_SECONDS = 300 # 5 minutes
def __init__(
self,
control_panel_url: str,
service_auth_token: str,
service_name: str = "resource-cluster"
):
self.control_panel_url = control_panel_url.rstrip('/')
self.service_auth_token = service_auth_token
self.service_name = service_name
# In-memory cache: key = "{tenant_domain}:{provider}"
self._cache: Dict[str, CachedAPIKey] = {}
self._cache_lock = asyncio.Lock()
def _get_headers(self) -> Dict[str, str]:
"""Get headers for service-to-service authentication"""
return {
"X-Service-Auth": self.service_auth_token,
"X-Service-Name": self.service_name,
"Content-Type": "application/json"
}
async def get_api_key(
self,
tenant_domain: str,
provider: str
) -> Dict[str, Any]:
"""
Get decrypted API key for a tenant and provider.
Args:
tenant_domain: Tenant domain string (e.g., "test-company")
provider: API provider name (e.g., "groq")
Returns:
Dict with 'api_key', 'api_secret' (optional), 'metadata'
Raises:
APIKeyNotConfiguredError: If API key not configured or disabled
RuntimeError: If Control Panel unreachable
"""
cache_key = f"{tenant_domain}:{provider}"
# Check cache first
async with self._cache_lock:
if cache_key in self._cache:
cached = self._cache[cache_key]
if not cached.is_expired(self.CACHE_TTL_SECONDS):
logger.debug(f"API key cache hit for {cache_key}")
return {
"api_key": cached.api_key,
"api_secret": cached.api_secret,
"metadata": cached.metadata
}
# Fetch from Control Panel
url = f"{self.control_panel_url}/internal/api-keys/{tenant_domain}/{provider}"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=self._get_headers())
if response.status_code == 404:
raise APIKeyNotConfiguredError(
f"No API key configured for provider '{provider}' "
f"for tenant '{tenant_domain}'. "
f"Please configure a {provider.upper()} API key in the Control Panel."
)
if response.status_code == 401:
raise RuntimeError("Service authentication failed - check SERVICE_AUTH_TOKEN")
if response.status_code == 403:
raise RuntimeError(f"Service '{self.service_name}' not authorized")
response.raise_for_status()
data = response.json()
# Update cache
async with self._cache_lock:
self._cache[cache_key] = CachedAPIKey(
api_key=data["api_key"],
api_secret=data.get("api_secret"),
metadata=data.get("metadata", {}),
fetched_at=time.time()
)
logger.info(f"Fetched API key for tenant '{tenant_domain}' provider '{provider}'")
return {
"api_key": data["api_key"],
"api_secret": data.get("api_secret"),
"metadata": data.get("metadata", {})
}
except httpx.HTTPStatusError as e:
logger.error(f"Control Panel API error: {e.response.status_code}")
if e.response.status_code == 404:
raise APIKeyNotConfiguredError(
f"No API key configured for provider '{provider}' "
f"for tenant '{tenant_domain}'"
)
raise RuntimeError(f"Control Panel API error: HTTP {e.response.status_code}")
except httpx.RequestError as e:
logger.error(f"Control Panel unreachable: {e}")
raise RuntimeError(f"Control Panel unreachable at {self.control_panel_url}")
async def invalidate_cache(
self,
tenant_domain: Optional[str] = None,
provider: Optional[str] = None
):
"""
Invalidate cached entries.
Args:
tenant_domain: If provided, only invalidate for this tenant
provider: If provided with tenant_domain, only invalidate this provider
"""
async with self._cache_lock:
if tenant_domain is None:
# Clear all
self._cache.clear()
logger.info("Cleared all API key caches")
elif provider:
# Clear specific tenant+provider
cache_key = f"{tenant_domain}:{provider}"
if cache_key in self._cache:
del self._cache[cache_key]
logger.info(f"Cleared cache for {cache_key}")
else:
# Clear all for tenant
keys_to_remove = [k for k in self._cache if k.startswith(f"{tenant_domain}:")]
for key in keys_to_remove:
del self._cache[key]
logger.info(f"Cleared cache for tenant: {tenant_domain}")
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics for monitoring"""
now = time.time()
valid_count = sum(
1 for k in self._cache.values()
if not k.is_expired(self.CACHE_TTL_SECONDS)
)
return {
"total_entries": len(self._cache),
"valid_entries": valid_count,
"cache_ttl_seconds": self.CACHE_TTL_SECONDS
}
# Singleton instance
_api_key_client: Optional[APIKeyClient] = None
def get_api_key_client() -> APIKeyClient:
"""Get or create the singleton API key client"""
global _api_key_client
if _api_key_client is None:
from app.core.config import get_settings
settings = get_settings()
_api_key_client = APIKeyClient(
control_panel_url=settings.control_panel_url,
service_auth_token=settings.service_auth_token,
service_name="resource-cluster"
)
return _api_key_client
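# Illustrative usage sketch: a caller inside the resource cluster fetching a
# tenant's Groq key. Assumes settings expose control_panel_url and
# service_auth_token as referenced in get_api_key_client above.
#
#   async def fetch_groq_key(tenant_domain: str) -> str:
#       client = get_api_key_client()
#       # Raises APIKeyNotConfiguredError if the tenant has no Groq key - NO FALLBACKS
#       creds = await client.get_api_key(tenant_domain, "groq")
#       return creds["api_key"]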

View File

@@ -0,0 +1,3 @@
"""
Core utilities and configuration for Resource Cluster
"""

View File

@@ -0,0 +1,140 @@
"""
GT 2.0 Resource Cluster - API Standards Integration
This module integrates CB-REST standards for non-AI endpoints while
maintaining OpenAI compatibility for AI inference endpoints.
"""
import os
import sys
from pathlib import Path
# Add the api-standards package to the path
api_standards_path = Path(__file__).parent.parent.parent.parent.parent / "packages" / "api-standards" / "src"
if api_standards_path.exists():
sys.path.insert(0, str(api_standards_path))
# Import CB-REST standards
try:
from response import StandardResponse, format_response, format_error
from capability import (
init_capability_verifier,
verify_capability,
require_capability,
Capability,
CapabilityToken
)
from errors import ErrorCode, APIError, raise_api_error
from middleware import (
RequestCorrelationMiddleware,
CapabilityMiddleware,
TenantIsolationMiddleware,
RateLimitMiddleware
)
except ImportError as e:
# Fallback for development - create minimal implementations
print(f"Warning: Could not import api-standards package: {e}")
# Create minimal implementations for development
class StandardResponse:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def format_response(data, capability_used, request_id=None):
return {
"data": data,
"error": None,
"capability_used": capability_used,
"request_id": request_id or "dev-mode"
}
def format_error(code, message, capability_used="none", **kwargs):
return {
"data": None,
"error": {
"code": code,
"message": message,
**kwargs
},
"capability_used": capability_used,
"request_id": kwargs.get("request_id", "dev-mode")
}
class ErrorCode:
CAPABILITY_INSUFFICIENT = "CAPABILITY_INSUFFICIENT"
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"
INVALID_REQUEST = "INVALID_REQUEST"
SYSTEM_ERROR = "SYSTEM_ERROR"
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
class APIError(Exception):
def __init__(self, code, message, **kwargs):
self.code = code
self.message = message
self.kwargs = kwargs
super().__init__(message)
# Export all CB-REST components
__all__ = [
'StandardResponse',
'format_response',
'format_error',
'init_capability_verifier',
'verify_capability',
'require_capability',
'Capability',
'CapabilityToken',
'ErrorCode',
'APIError',
'raise_api_error',
'RequestCorrelationMiddleware',
'CapabilityMiddleware',
'TenantIsolationMiddleware',
'RateLimitMiddleware'
]
def setup_api_standards(app, secret_key: str):
"""
Set up API standards for the Resource Cluster
IMPORTANT: This only applies CB-REST to non-AI endpoints.
AI inference endpoints maintain OpenAI compatibility.
Args:
app: FastAPI application instance
secret_key: Secret key for JWT signing
"""
# Initialize capability verifier
if 'init_capability_verifier' in globals():
init_capability_verifier(secret_key)
# Add middleware in correct order
if 'RequestCorrelationMiddleware' in globals():
app.add_middleware(RequestCorrelationMiddleware)
if 'RateLimitMiddleware' in globals():
app.add_middleware(
RateLimitMiddleware,
requests_per_minute=1000 # Higher limit for resource cluster
)
# Note: No TenantIsolationMiddleware for Resource Cluster
# as it serves multiple tenants with capability-based access
if 'CapabilityMiddleware' in globals():
# Exclude AI inference endpoints from CB-REST middleware
# to maintain OpenAI compatibility
app.add_middleware(
CapabilityMiddleware,
exclude_paths=[
"/health",
"/ready",
"/metrics",
"/ai/chat/completions", # OpenAI compatible
"/ai/embeddings", # OpenAI compatible
"/ai/images/generations", # OpenAI compatible
"/ai/models" # OpenAI compatible
]
)
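# Illustrative wiring sketch (module and setting names are placeholders): the
# application factory would call setup_api_standards once at startup.
#
#   import os
#   from fastapi import FastAPI
#   from app.core.api_standards import setup_api_standards
#
#   app = FastAPI(title="GT 2.0 Resource Cluster")
#   setup_api_standards(app, secret_key=os.environ["SECRET_KEY"])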

View File

@@ -0,0 +1,52 @@
"""
Resource backend implementations for GT 2.0
Provides unified interfaces for all resource types:
- LLM inference (Groq, OpenAI, Anthropic)
- Vector databases (PGVector)
- Document processing (Unstructured)
- External services (OAuth2, iframe)
- AI literacy resources
"""
from typing import Dict, Any
import logging
logger = logging.getLogger(__name__)
# Registry of available backends
BACKEND_REGISTRY: Dict[str, Any] = {}
def register_backend(name: str, backend_class):
"""Register a resource backend"""
BACKEND_REGISTRY[name] = backend_class
logger.info(f"Registered backend: {name}")
def get_backend(name: str):
"""Get a registered backend"""
if name not in BACKEND_REGISTRY:
raise ValueError(f"Backend not found: {name}")
return BACKEND_REGISTRY[name]
async def initialize_backends():
"""Initialize all resource backends"""
from app.core.backends.groq_proxy import GroqProxyBackend
from app.core.backends.nvidia_proxy import NvidiaProxyBackend
from app.core.backends.document_processor import DocumentProcessorBackend
from app.core.backends.embedding_backend import EmbeddingBackend
# Register backends
register_backend("groq_proxy", GroqProxyBackend())
register_backend("nvidia_proxy", NvidiaProxyBackend())
register_backend("document_processor", DocumentProcessorBackend())
register_backend("embedding", EmbeddingBackend())
logger.info("All resource backends initialized")
def get_embedding_backend():
"""Get the embedding backend instance"""
return get_backend("embedding")
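# Illustrative extension sketch (MyCustomBackend is hypothetical): new resource
# backends plug into the same registry used by the built-ins.
#
#   from app.core.backends import register_backend, get_backend
#
#   register_backend("my_custom", MyCustomBackend())
#   backend = get_backend("my_custom")  # raises ValueError if the name was never registered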

View File

@@ -0,0 +1,322 @@
"""
Document Processing Backend
STATELESS document chunking and preprocessing for RAG operations.
All processing happens in memory - NO user data is ever stored.
"""
import logging
import io
import gc
from typing import Dict, Any, List, Optional, BinaryIO
from dataclasses import dataclass
import hashlib
# Document processing imports
import pypdf as PyPDF2
from docx import Document as DocxDocument
from bs4 import BeautifulSoup
from langchain_text_splitters import (
RecursiveCharacterTextSplitter,
TokenTextSplitter,
SentenceTransformersTokenTextSplitter
)
logger = logging.getLogger(__name__)
@dataclass
class ChunkingStrategy:
"""Configuration for document chunking"""
strategy_type: str # 'fixed', 'semantic', 'hierarchical', 'hybrid'
chunk_size: int # Target chunk size in tokens (optimized for BGE-M3: 512)
chunk_overlap: int # Overlap between chunks (typically 128 for BGE-M3)
separator_pattern: Optional[str] = None # Custom separator for splitting
preserve_paragraphs: bool = True
preserve_sentences: bool = True
class DocumentProcessorBackend:
"""
STATELESS document chunking and processing backend.
Security principles:
- NO persistence of user data
- All processing in memory only
- Immediate memory cleanup after processing
- No caching of user content
"""
def __init__(self):
self.supported_formats = [".pdf", ".docx", ".txt", ".md", ".html"]
# BGE-M3 optimal settings
self.default_chunk_size = 512 # tokens
self.default_chunk_overlap = 128 # tokens
self.model_name = "BAAI/bge-m3" # For tokenization
logger.info("STATELESS document processor backend initialized")
async def process_document(
self,
content: bytes,
document_type: str,
strategy: Optional[ChunkingStrategy] = None,
metadata: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Process document into chunks - STATELESS operation.
Args:
content: Document content as bytes (will be cleared from memory)
document_type: File type (.pdf, .docx, .txt, .md, .html)
strategy: Chunking strategy configuration
metadata: Optional metadata (will NOT include user content)
Returns:
List of chunks with metadata (immediately returned, not stored)
"""
try:
# Use default strategy if not provided
if strategy is None:
strategy = ChunkingStrategy(
strategy_type='hybrid',
chunk_size=self.default_chunk_size,
chunk_overlap=self.default_chunk_overlap
)
# Extract text based on document type (in memory)
text = await self._extract_text_from_bytes(content, document_type)
# Clear original content from memory
del content
gc.collect()
# Apply chunking strategy
if strategy.strategy_type == 'semantic':
chunks = await self._semantic_chunking(text, strategy)
elif strategy.strategy_type == 'hierarchical':
chunks = await self._hierarchical_chunking(text, strategy)
elif strategy.strategy_type == 'hybrid':
chunks = await self._hybrid_chunking(text, strategy)
else: # 'fixed'
chunks = await self._fixed_chunking(text, strategy)
# Clear text from memory
del text
gc.collect()
# Add metadata without storing content
processed_chunks = []
for idx, chunk in enumerate(chunks):
chunk_metadata = {
"chunk_index": idx,
"total_chunks": len(chunks),
"chunking_strategy": strategy.strategy_type,
"chunk_size_tokens": strategy.chunk_size,
# Generate hash for deduplication without storing content
"content_hash": hashlib.sha256(chunk.encode()).hexdigest()[:16]
}
# Add non-sensitive metadata if provided
if metadata:
# Filter out any potential sensitive data
safe_metadata = {
k: v for k, v in metadata.items()
if k in ['document_type', 'processing_timestamp', 'tenant_id']
}
chunk_metadata.update(safe_metadata)
processed_chunks.append({
"text": chunk,
"metadata": chunk_metadata
})
logger.info(f"Processed document into {len(processed_chunks)} chunks (STATELESS)")
# Return immediately - no storage
return processed_chunks
except Exception as e:
logger.error(f"Error processing document: {e}")
# Ensure memory is cleared even on error
gc.collect()
raise
finally:
# Always ensure memory cleanup
gc.collect()
async def _extract_text_from_bytes(
self,
content: bytes,
document_type: str
) -> str:
"""Extract text from document bytes - in memory only"""
try:
if document_type == ".pdf":
return await self._extract_pdf_text(io.BytesIO(content))
elif document_type == ".docx":
return await self._extract_docx_text(io.BytesIO(content))
elif document_type == ".html":
return await self._extract_html_text(content.decode('utf-8'))
elif document_type in [".txt", ".md"]:
return content.decode('utf-8')
else:
raise ValueError(f"Unsupported document type: {document_type}")
finally:
# Clear content from memory
del content
gc.collect()
async def _extract_pdf_text(self, file_stream: BinaryIO) -> str:
"""Extract text from PDF - in memory"""
text = ""
try:
pdf_reader = PyPDF2.PdfReader(file_stream)
for page_num in range(len(pdf_reader.pages)):
page = pdf_reader.pages[page_num]
text += page.extract_text() + "\n"
finally:
file_stream.close()
gc.collect()
return text
async def _extract_docx_text(self, file_stream: BinaryIO) -> str:
"""Extract text from DOCX - in memory"""
text = ""
try:
doc = DocxDocument(file_stream)
for paragraph in doc.paragraphs:
text += paragraph.text + "\n"
finally:
file_stream.close()
gc.collect()
return text
async def _extract_html_text(self, html_content: str) -> str:
"""Extract text from HTML - in memory"""
soup = BeautifulSoup(html_content, 'html.parser')
# Remove script and style elements
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
# Clean up whitespace: strip lines, then split on double-space runs so that
# normally single-spaced sentences stay intact
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
text = '\n'.join(chunk for chunk in chunks if chunk)
return text
async def _semantic_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Semantic chunking using sentence boundaries"""
splitter = SentenceTransformersTokenTextSplitter(
model_name=self.model_name,
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
return splitter.split_text(text)
async def _hierarchical_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Hierarchical chunking preserving document structure"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=strategy.chunk_size * 3, # Approximate token to char ratio
chunk_overlap=strategy.chunk_overlap * 3,
separators=["\n\n\n", "\n\n", "\n", ". ", " ", ""],
keep_separator=True
)
return splitter.split_text(text)
async def _hybrid_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Hybrid chunking combining semantic and structural boundaries"""
# First split by structure
structural_splitter = RecursiveCharacterTextSplitter(
chunk_size=strategy.chunk_size * 4,
chunk_overlap=0,
separators=["\n\n\n", "\n\n"],
keep_separator=True
)
structural_chunks = structural_splitter.split_text(text)
# Then apply semantic splitting to each structural chunk
final_chunks = []
token_splitter = TokenTextSplitter(
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
for struct_chunk in structural_chunks:
semantic_chunks = token_splitter.split_text(struct_chunk)
final_chunks.extend(semantic_chunks)
return final_chunks
async def _fixed_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Fixed-size chunking with token boundaries"""
splitter = TokenTextSplitter(
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
return splitter.split_text(text)
async def validate_document(
self,
content_size: int,
document_type: str
) -> Dict[str, Any]:
"""
Validate document before processing - no content stored.
Args:
content_size: Size of document in bytes
document_type: File extension
Returns:
Validation result with any warnings
"""
MAX_SIZE = 50 * 1024 * 1024 # 50MB max
validation = {
"valid": True,
"warnings": [],
"errors": []
}
# Check file size
if content_size > MAX_SIZE:
validation["valid"] = False
validation["errors"].append(f"File size exceeds maximum of 50MB")
elif content_size > 10 * 1024 * 1024: # Warning for files over 10MB
validation["warnings"].append("Large file may take longer to process")
# Check document type
if document_type not in self.supported_formats:
validation["valid"] = False
validation["errors"].append(f"Unsupported format: {document_type}")
return validation
async def check_health(self) -> Dict[str, Any]:
"""Check document processor health - no user data exposed"""
return {
"status": "healthy",
"supported_formats": self.supported_formats,
"default_chunk_size": self.default_chunk_size,
"default_chunk_overlap": self.default_chunk_overlap,
"model": self.model_name,
"stateless": True, # Confirm stateless operation
"memory_cleared": True # Confirm memory management
}
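# Illustrative usage sketch of the stateless chunking flow with the default hybrid
# strategy; the inline markdown bytes are placeholder content.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        backend = DocumentProcessorBackend()
        content = b"# Title\n\nSome markdown body text to be chunked."
        check = await backend.validate_document(len(content), ".md")
        if check["valid"]:
            chunks = await backend.process_document(content, ".md")
            print(len(chunks), chunks[0]["metadata"]["content_hash"])

    asyncio.run(_demo())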

View File

@@ -0,0 +1,471 @@
"""
Embedding Model Backend
STATELESS embedding generation using BGE-M3 model hosted on GT's GPU clusters.
All embeddings are generated in real-time - NO user data is stored.
"""
import logging
import gc
import hashlib
import asyncio
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
# import numpy as np # Temporarily disabled for Docker build
import aiohttp
import json
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class EmbeddingRequest:
"""Request structure for embedding generation"""
texts: List[str]
model: str = "BAAI/bge-m3"
batch_size: int = 32
normalize: bool = True
instruction: Optional[str] = None # For instruction-based embeddings
class EmbeddingBackend:
"""
STATELESS embedding backend for BGE-M3 model.
Security principles:
- NO persistence of embeddings or text
- All processing via GT's internal GPU cluster
- Immediate memory cleanup after generation
- No caching of user content
- Request signing and verification
"""
def __init__(self):
self.model_name = "BAAI/bge-m3"
self.embedding_dimensions = 1024 # BGE-M3 dimensions
self.max_batch_size = 32
self.max_sequence_length = 8192 # BGE-M3 supports up to 8192 tokens
# Determine endpoint based on configuration
self.embedding_endpoint = self._get_embedding_endpoint()
# Timeout for embedding requests
self.request_timeout = 60 # seconds for model loading
logger.info(f"STATELESS embedding backend initialized for {self.model_name}")
logger.info(f"Using embedding endpoint: {self.embedding_endpoint}")
def _get_embedding_endpoint(self) -> str:
"""
Get the embedding endpoint based on configuration.
Priority:
1. Model registry from config sync (database-backed)
2. Environment variables (BGE_M3_LOCAL_MODE, BGE_M3_EXTERNAL_ENDPOINT)
3. Default local endpoint
"""
# Try to get configuration from model registry first (loaded from database)
try:
from app.services.model_service import default_model_service
import asyncio
# Use the default model service instance (singleton) used by config sync
model_service = default_model_service
# Try to get the model config synchronously (during initialization)
# The get_model method is async, so we need to handle this carefully
bge_m3_config = model_service.model_registry.get("BAAI/bge-m3")
if bge_m3_config:
# Model registry stores endpoint as 'endpoint_url' and config as 'parameters'
endpoint = bge_m3_config.get("endpoint_url")
config = bge_m3_config.get("parameters", {})
is_local_mode = config.get("is_local_mode", True)
external_endpoint = config.get("external_endpoint")
logger.info(f"Found BGE-M3 in registry: endpoint_url={endpoint}, is_local_mode={is_local_mode}, external_endpoint={external_endpoint}")
if endpoint:
logger.info(f"Using BGE-M3 endpoint from model registry (is_local_mode={is_local_mode}): {endpoint}")
return endpoint
else:
logger.warning(f"BGE-M3 found in registry but endpoint_url is None/empty. Full config: {bge_m3_config}")
else:
available_models = list(model_service.model_registry.keys())
logger.debug(f"BGE-M3 not found in model registry during init (expected on first startup). Available models: {available_models}")
except Exception as e:
logger.debug(f"Model registry not yet available during startup (will be populated after config sync): {e}")
# Fall back to Settings fields (environment variables or .env file)
is_local_mode = getattr(settings, 'bge_m3_local_mode', True)
external_endpoint = getattr(settings, 'bge_m3_external_endpoint', None)
if not is_local_mode and external_endpoint:
logger.info(f"Using external BGE-M3 endpoint from settings: {external_endpoint}")
return external_endpoint
# Default to local endpoint
local_endpoint = getattr(
settings,
'embedding_endpoint',
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
)
logger.info(f"Using local BGE-M3 endpoint: {local_endpoint}")
return local_endpoint
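# Example environment configuration sketch for the fallback path above. Variable names
# mirror the documented BGE_M3_LOCAL_MODE / BGE_M3_EXTERNAL_ENDPOINT settings; the exact
# .env spelling of the local endpoint override is an assumption.
#
#   BGE_M3_LOCAL_MODE=false
#   BGE_M3_EXTERNAL_ENDPOINT=https://embeddings.example.com/v1/embeddings
#   EMBEDDING_ENDPOINT=http://gentwo-vllm-embeddings:8000/v1/embeddings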
async def update_endpoint_config(self, is_local_mode: bool, external_endpoint: str = None):
"""
Update the embedding endpoint configuration dynamically.
This allows switching between local and external endpoints without restart.
"""
if is_local_mode:
self.embedding_endpoint = getattr(
settings,
'embedding_endpoint',
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
)
else:
if external_endpoint:
self.embedding_endpoint = external_endpoint
else:
raise ValueError("External endpoint must be provided when not in local mode")
logger.info(f"BGE-M3 endpoint updated to: {self.embedding_endpoint}")
logger.info(f"Mode: {'Local GT Edge' if is_local_mode else 'External API'}")
def refresh_endpoint_from_registry(self):
"""
Refresh the embedding endpoint from the model registry.
Called by config sync when BGE-M3 configuration changes.
"""
logger.info(f"Refreshing embedding endpoint - current: {self.embedding_endpoint}")
new_endpoint = self._get_embedding_endpoint()
if new_endpoint != self.embedding_endpoint:
logger.info(f"Refreshing BGE-M3 endpoint from {self.embedding_endpoint} to {new_endpoint}")
self.embedding_endpoint = new_endpoint
else:
logger.info(f"BGE-M3 endpoint unchanged: {self.embedding_endpoint}")
async def generate_embeddings(
self,
texts: List[str],
instruction: Optional[str] = None,
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings for texts using BGE-M3 - STATELESS operation.
Args:
texts: List of texts to embed (will be cleared from memory)
instruction: Optional instruction for query vs document embeddings
tenant_id: Tenant ID for audit logging (not stored with data)
request_id: Request ID for tracing
Returns:
List of embedding vectors (immediately returned, not stored)
"""
try:
# Validate input
if not texts:
return []
if len(texts) > self.max_batch_size:
# Process in batches
return await self._batch_process_embeddings(
texts, instruction, tenant_id, request_id
)
# Prepare request
request_data = {
"model": self.model_name,
"input": texts,
"encoding_format": "float",
"dimensions": self.embedding_dimensions
}
# Add instruction if provided (for query vs document distinction)
if instruction:
request_data["instruction"] = instruction
# Add metadata for audit (not stored with embeddings)
metadata = {
"tenant_id": tenant_id,
"request_id": request_id,
"text_count": len(texts),
# Hash for deduplication without storing content
"content_hash": hashlib.sha256(
"".join(texts).encode()
).hexdigest()[:16]
}
# Call vLLM service - NO FALLBACKS
embeddings = await self._call_embedding_service(request_data, metadata)
# Clear texts from memory immediately
del texts
gc.collect()
# Validate response
if not embeddings or len(embeddings) == 0:
raise ValueError("No embeddings returned from service")
# Normalize if needed
if self._should_normalize():
embeddings = self._normalize_embeddings(embeddings)
logger.info(
f"Generated {len(embeddings)} embeddings (STATELESS) "
f"for tenant {tenant_id}"
)
# Return immediately - no storage
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
# Ensure memory is cleared even on error
gc.collect()
raise
finally:
# Always ensure memory cleanup
gc.collect()
async def _batch_process_embeddings(
self,
texts: List[str],
instruction: Optional[str],
tenant_id: str,
request_id: str
) -> List[List[float]]:
"""Process large text lists in batches using vLLM service"""
all_embeddings = []
for i in range(0, len(texts), self.max_batch_size):
batch = texts[i:i + self.max_batch_size]
# Prepare request for this batch
request_data = {
"model": self.model_name,
"input": batch,
"encoding_format": "float",
"dimensions": self.embedding_dimensions
}
if instruction:
request_data["instruction"] = instruction
metadata = {
"tenant_id": tenant_id,
"request_id": f"{request_id}_batch_{i}",
"text_count": len(batch),
"content_hash": hashlib.sha256(
"".join(batch).encode()
).hexdigest()[:16]
}
batch_embeddings = await self._call_embedding_service(request_data, metadata)
all_embeddings.extend(batch_embeddings)
# Clear batch from memory
del batch
gc.collect()
return all_embeddings
async def _call_embedding_service(
self,
request_data: Dict[str, Any],
metadata: Dict[str, Any]
) -> List[List[float]]:
"""Call internal GPU cluster embedding service"""
async with aiohttp.ClientSession() as session:
try:
# Add capability token for authentication
headers = {
"Content-Type": "application/json",
"X-Tenant-ID": metadata.get("tenant_id", ""),
"X-Request-ID": metadata.get("request_id", ""),
# Authorization will be added by Resource Cluster
}
async with session.post(
self.embedding_endpoint,
json=request_data,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.request_timeout)
) as response:
if response.status != 200:
error_text = await response.text()
raise ValueError(
f"Embedding service error: {response.status} - {error_text}"
)
result = await response.json()
# Extract embeddings from response
if "data" in result:
embeddings = [item["embedding"] for item in result["data"]]
elif "embeddings" in result:
embeddings = result["embeddings"]
else:
raise ValueError("Invalid embedding service response format")
return embeddings
except asyncio.TimeoutError:
raise ValueError(f"Embedding service timeout after {self.request_timeout}s")
except Exception as e:
logger.error(f"Error calling embedding service: {e}")
raise
def _should_normalize(self) -> bool:
"""Check if embeddings should be normalized"""
# BGE-M3 embeddings are typically normalized for similarity search
return True
def _normalize_embeddings(
self,
embeddings: List[List[float]]
) -> List[List[float]]:
"""Normalize embedding vectors to unit length"""
normalized = []
for embedding in embeddings:
# Simple normalization without numpy (for now)
import math
# Calculate norm
norm = math.sqrt(sum(x * x for x in embedding))
if norm > 0:
normalized_vec = [x / norm for x in embedding]
else:
normalized_vec = embedding[:]
normalized.append(normalized_vec)
return normalized
async def generate_query_embeddings(
self,
queries: List[str],
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings specifically for queries.
BGE-M3 can use different instructions for queries vs documents.
"""
# For BGE-M3, queries can use a specific instruction
instruction = "Represent this sentence for searching relevant passages: "
return await self.generate_embeddings(
queries, instruction, tenant_id, request_id
)
async def generate_document_embeddings(
self,
documents: List[str],
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings specifically for documents.
BGE-M3 can use different instructions for documents vs queries.
"""
# For BGE-M3, documents typically don't need special instruction
return await self.generate_embeddings(
documents, None, tenant_id, request_id
)
async def validate_texts(
self,
texts: List[str]
) -> Dict[str, Any]:
"""
Validate texts before embedding - no content stored.
Args:
texts: List of texts to validate
Returns:
Validation result with any warnings
"""
validation = {
"valid": True,
"warnings": [],
"errors": [],
"stats": {
"total_texts": len(texts),
"max_length": 0,
"avg_length": 0
}
}
if not texts:
validation["valid"] = False
validation["errors"].append("No texts provided")
return validation
# Check text lengths
lengths = [len(text) for text in texts]
validation["stats"]["max_length"] = max(lengths)
validation["stats"]["avg_length"] = sum(lengths) // len(lengths)
# BGE-M3 max sequence length check (approximate)
max_chars = self.max_sequence_length * 4 # Rough char to token ratio
for i, length in enumerate(lengths):
if length > max_chars:
validation["warnings"].append(
f"Text {i} may exceed model's max sequence length"
)
elif length == 0:
validation["errors"].append(f"Text {i} is empty")
validation["valid"] = False
# Batch size check
if len(texts) > self.max_batch_size * 10:
validation["warnings"].append(
f"Large batch ({len(texts)} texts) will be processed in chunks"
)
return validation
async def check_health(self) -> Dict[str, Any]:
"""Check embedding backend health - no user data exposed"""
try:
# Test connection to vLLM service
test_text = ["Health check test"]
test_embeddings = await self.generate_embeddings(
test_text,
tenant_id="health_check",
request_id="health_check"
)
health_status = {
"status": "healthy",
"model": self.model_name,
"dimensions": self.embedding_dimensions,
"max_batch_size": self.max_batch_size,
"max_sequence_length": self.max_sequence_length,
"endpoint": self.embedding_endpoint,
"stateless": True,
"memory_cleared": True,
"vllm_service_connected": len(test_embeddings) > 0
}
except Exception as e:
health_status = {
"status": "unhealthy",
"error": str(e),
"model": self.model_name,
"endpoint": self.embedding_endpoint
}
return health_status
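# Illustrative usage sketch, assuming a reachable BGE-M3 endpoint; tenant and request
# identifiers are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        backend = EmbeddingBackend()
        doc_vecs = await backend.generate_document_embeddings(
            ["GT 2.0 is an air-gapped resource management hub."],
            tenant_id="demo-tenant", request_id="demo-1"
        )
        query_vecs = await backend.generate_query_embeddings(
            ["What is GT 2.0?"], tenant_id="demo-tenant", request_id="demo-2"
        )
        print(len(doc_vecs[0]), len(query_vecs[0]))  # expect 1024-dimensional vectors

    asyncio.run(_demo())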

View File

@@ -0,0 +1,780 @@
"""
Groq Cloud LLM Proxy Backend
Provides high-availability LLM inference through Groq Cloud with:
- HAProxy load balancing across multiple endpoints
- Automatic failover handled by HAProxy
- Token usage tracking and cost calculation
- Streaming response support
- Circuit breaker pattern for enhanced reliability
"""
import asyncio
import json
import os
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
import httpx
try:
from groq import AsyncGroq
GROQ_AVAILABLE = True
except ImportError:
# Groq not available in development mode
AsyncGroq = None
GROQ_AVAILABLE = False
import logging
from app.core.config import get_settings, get_model_configs
from app.services.model_service import get_model_service
logger = logging.getLogger(__name__)
settings = get_settings()
# Groq Compound tool pricing (per request/execution)
# Source: https://groq.com/pricing (Dec 2, 2025)
COMPOUND_TOOL_PRICES = {
# Web Search variants
"search": 0.008, # API returns "search" for web search
"web_search": 0.008, # $8 per 1K = $0.008 per request (Advanced Search)
"advanced_search": 0.008, # $8 per 1K requests
"basic_search": 0.005, # $5 per 1K requests
# Other tools
"visit_website": 0.001, # $1 per 1K requests
"python": 0.00005, # API returns "python" for code execution
"code_interpreter": 0.00005, # Alternative API identifier
"code_execution": 0.00005, # Alias for backwards compatibility
"browser_automation": 0.00002, # $0.08/hr ≈ $0.00002 per execution
}
# Model pricing per million tokens (input/output)
# Source: https://groq.com/pricing (Dec 2, 2025)
GROQ_MODEL_PRICES = {
"llama-3.3-70b-versatile": {"input": 0.59, "output": 0.79},
"llama-3.1-8b-instant": {"input": 0.05, "output": 0.08},
"llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
"meta-llama/llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
"llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
"meta-llama/llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
"llama-guard-4-12b": {"input": 0.20, "output": 0.20},
"meta-llama/llama-guard-4-12b": {"input": 0.20, "output": 0.20},
"gpt-oss-120b": {"input": 0.15, "output": 0.60},
"openai/gpt-oss-120b": {"input": 0.15, "output": 0.60},
"gpt-oss-20b": {"input": 0.075, "output": 0.30},
"openai/gpt-oss-20b": {"input": 0.075, "output": 0.30},
"kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
"moonshotai/kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
"qwen3-32b": {"input": 0.29, "output": 0.59},
# Compound models - 50/50 blended pricing from underlying models
# compound: GPT-OSS-120B ($0.15/$0.60) + Llama 4 Scout ($0.11/$0.34) = $0.13/$0.47
"compound": {"input": 0.13, "output": 0.47},
"groq/compound": {"input": 0.13, "output": 0.47},
"compound-beta": {"input": 0.13, "output": 0.47},
# compound-mini: GPT-OSS-120B ($0.15/$0.60) + Llama 3.3 70B ($0.59/$0.79) = $0.37/$0.695
"compound-mini": {"input": 0.37, "output": 0.695},
"groq/compound-mini": {"input": 0.37, "output": 0.695},
"compound-mini-beta": {"input": 0.37, "output": 0.695},
}
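# Worked example of the per-million pricing above (illustrative only, not used at runtime):
# a llama-3.3-70b-versatile call with 1,000 prompt tokens and 500 completion tokens costs
# (1_000 / 1_000_000) * $0.59 + (500 / 1_000_000) * $0.79 ≈ $0.00059 + $0.0004 ≈ $0.001.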
class GroqProxyBackend:
"""LLM inference via Groq Cloud with HAProxy load balancing"""
def __init__(self):
self.settings = get_settings()
self.client = None
self.usage_metrics = {}
self.circuit_breaker_status = {}
self._initialize_client()
def _initialize_client(self):
"""Initialize Groq client to use HAProxy load balancer"""
if not GROQ_AVAILABLE:
logger.warning("Groq client not available - running in development mode")
return
if self.settings.groq_api_key:
# Use HAProxy load balancer instead of direct Groq API
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
# Initialize client with HAProxy endpoint
self.client = AsyncGroq(
api_key=self.settings.groq_api_key,
base_url=haproxy_endpoint,
timeout=httpx.Timeout(30.0), # Increased timeout for load balancing
max_retries=1 # Let HAProxy handle retries
)
# Initialize circuit breaker
self.circuit_breaker_status = {
"state": "closed", # closed, open, half_open
"failure_count": 0,
"last_failure_time": None,
"failure_threshold": 5,
"recovery_timeout": 60 # seconds
}
logger.info(f"Initialized Groq client with HAProxy endpoint: {haproxy_endpoint}")
async def execute_inference(
self,
prompt: str,
model: str = "llama-3.1-70b-versatile",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None
) -> Dict[str, Any]:
"""Execute LLM inference with HAProxy load balancing and circuit breaker"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - service temporarily unavailable")
# Validate model and get configuration
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
if not model_config:
# Try to get from model service registry
model_service = get_model_service(tenant_id)
model_info = await model_service.get_model(model)
if not model_info:
raise ValueError(f"Unsupported model: {model}")
model_config = {
"max_tokens": model_info["performance"]["max_tokens"],
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
"supports_streaming": model_info["capabilities"].get("streaming", False)
}
# Apply token limits
max_tokens = min(max_tokens, model_config["max_tokens"])
# Prepare messages
messages = [
{"role": "user", "content": prompt}
]
try:
# Get tenant-specific API key
if not tenant_id:
raise ValueError("tenant_id is required for Groq inference")
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
start_time = time.time()
if stream:
return await self._stream_inference(
messages, model, temperature, max_tokens, user_id, tenant_id, client
)
else:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False
)
# Track successful usage
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
response.usage.total_tokens if response.usage else 0,
latency, model_config["cost_per_1k_tokens"]
)
# Track in model service
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=True,
latency_ms=latency
)
# Reset circuit breaker on success
await self._record_success()
return {
"content": response.choices[0].message.content,
"model": model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
"cost_cents": self._calculate_cost(
response.usage.total_tokens if response.usage else 0,
model_config["cost_per_1k_tokens"]
)
},
"latency_ms": latency,
"load_balanced": True,
"haproxy_backend": "groq_general_backend"
}
except Exception as e:
logger.error(f"HAProxy Groq inference failed: {e}")
# Track failure in model service
await model_service.track_model_usage(
model_id=model,
success=False
)
# Record failure for circuit breaker
await self._record_failure()
# Re-raise the exception - no client-side fallback needed
# HAProxy handles all failover logic
raise Exception(f"Groq inference failed (via HAProxy): {str(e)}")
async def _stream_inference(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
user_id: str,
tenant_id: str,
client: AsyncGroq = None
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses"""
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
start_time = time.time()
total_tokens = 0
try:
# Use provided client or get tenant-specific client
if not client:
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
stream = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
total_tokens += len(content.split()) # Approximate token count
# Yield SSE formatted data
yield f"data: {json.dumps({'content': content})}\n\n"
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
total_tokens, latency,
model_config["cost_per_1k_tokens"]
)
# Send completion signal
yield f"data: {json.dumps({'done': True})}\n\n"
except Exception as e:
logger.error(f"Streaming inference error: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
async def check_health(self) -> Dict[str, Any]:
"""Check health of HAProxy load balancer and circuit breaker status"""
try:
# Check HAProxy health via stats endpoint
haproxy_stats_url = self.settings.haproxy_stats_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats"
async with httpx.AsyncClient() as client:
response = await client.get(
haproxy_stats_url,
timeout=5.0,
auth=("admin", "gt2_haproxy_stats_password")
)
if response.status_code == 200:
# Parse HAProxy stats (simplified)
stats_healthy = "UP" in response.text
return {
"haproxy_load_balancer": {
"healthy": stats_healthy,
"stats_accessible": True,
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"],
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat() if self.circuit_breaker_status["last_failure_time"] else None
},
"groq_endpoints": {
"managed_by": "haproxy",
"failover_handled_by": "haproxy"
}
}
else:
return {
"haproxy_load_balancer": {
"healthy": False,
"error": f"Stats endpoint returned {response.status_code}",
"last_check": datetime.utcnow().isoformat()
}
}
except Exception as e:
return {
"haproxy_load_balancer": {
"healthy": False,
"error": str(e),
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"]
}
}
async def _is_circuit_closed(self) -> bool:
"""Check if circuit breaker allows requests"""
if self.circuit_breaker_status["state"] == "closed":
return True
if self.circuit_breaker_status["state"] == "open":
# Check if recovery timeout has passed
if self.circuit_breaker_status["last_failure_time"]:
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
# Move to half-open state
self.circuit_breaker_status["state"] = "half_open"
logger.info("Circuit breaker moved to half-open state")
return True
return False
if self.circuit_breaker_status["state"] == "half_open":
# Allow limited requests in half-open state
return True
return False
async def _record_success(self):
"""Record successful request for circuit breaker"""
if self.circuit_breaker_status["state"] == "half_open":
# Success in half-open state closes the circuit
self.circuit_breaker_status["state"] = "closed"
self.circuit_breaker_status["failure_count"] = 0
logger.info("Circuit breaker closed after successful request")
# Reset failure count on any success
self.circuit_breaker_status["failure_count"] = 0
async def _record_failure(self):
"""Record failed request for circuit breaker"""
self.circuit_breaker_status["failure_count"] += 1
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
self.circuit_breaker_status["state"] = "open"
logger.warning(f"Circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
async def _track_usage(
self,
user_id: str,
tenant_id: str,
model: str,
tokens: int,
latency: float,
cost_per_1k: float
):
"""Track usage metrics for billing and monitoring"""
# Create usage key
usage_key = f"{tenant_id}:{user_id}:{model}"
# Initialize metrics if not exists
if usage_key not in self.usage_metrics:
self.usage_metrics[usage_key] = {
"total_tokens": 0,
"total_requests": 0,
"total_cost_cents": 0,
"average_latency": 0
}
# Update metrics
metrics = self.usage_metrics[usage_key]
metrics["total_tokens"] += tokens
metrics["total_requests"] += 1
metrics["total_cost_cents"] += self._calculate_cost(tokens, cost_per_1k)
# Update average latency
prev_avg = metrics["average_latency"]
prev_count = metrics["total_requests"] - 1
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
# Log high-level metrics
if metrics["total_requests"] % 100 == 0:
logger.info(f"Usage milestone for {usage_key}: {metrics}")
def _calculate_cost(self, tokens: int, cost_per_1k: float) -> int:
"""Calculate cost in cents"""
return int((tokens / 1000) * cost_per_1k * 100)
def _calculate_compound_cost(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Calculate detailed cost breakdown for Groq Compound responses.
Compound API returns usage_breakdown with per-model token counts
and executed_tools list showing which tools were called.
Returns:
Dict with total cost in dollars and detailed breakdown
"""
total_cost = 0.0
breakdown = {"models": [], "tools": [], "total_cost_dollars": 0.0, "total_cost_cents": 0}
# Parse usage_breakdown for per-model token costs
usage_breakdown = response_data.get("usage_breakdown", {})
models_usage = usage_breakdown.get("models", [])
for model_usage in models_usage:
model_name = model_usage.get("model", "")
usage = model_usage.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
# Get model pricing (try multiple name formats)
model_prices = GROQ_MODEL_PRICES.get(model_name)
if not model_prices:
# Try without provider prefix
short_name = model_name.split("/")[-1] if "/" in model_name else model_name
model_prices = GROQ_MODEL_PRICES.get(short_name, {"input": 0.15, "output": 0.60})
# Calculate cost per million tokens
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
model_total = input_cost + output_cost
breakdown["models"].append({
"model": model_name,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"input_cost_dollars": round(input_cost, 6),
"output_cost_dollars": round(output_cost, 6),
"total_cost_dollars": round(model_total, 6)
})
total_cost += model_total
# Parse executed_tools for tool costs
executed_tools = response_data.get("executed_tools", [])
for tool in executed_tools:
# Handle both string and dict formats
tool_name = tool if isinstance(tool, str) else tool.get("name", "unknown")
tool_cost = COMPOUND_TOOL_PRICES.get(tool_name.lower(), 0.008) # Default to advanced search
breakdown["tools"].append({
"tool": tool_name,
"cost_dollars": round(tool_cost, 6)
})
total_cost += tool_cost
breakdown["total_cost_dollars"] = round(total_cost, 6)
breakdown["total_cost_cents"] = int(total_cost * 100)
return breakdown
def _is_compound_model(self, model: str) -> bool:
"""Check if model is a Groq Compound model"""
model_lower = model.lower()
return "compound" in model_lower or model_lower.startswith("groq/compound")
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available Groq models with their configurations"""
models = []
model_configs = get_model_configs()
for model_id, config in model_configs.get("groq", {}).items():
models.append({
"id": model_id,
"name": model_id.replace("-", " ").title(),
"provider": "groq",
"max_tokens": config["max_tokens"],
"cost_per_1k_tokens": config["cost_per_1k_tokens"],
"supports_streaming": config["supports_streaming"],
"supports_function_calling": config["supports_function_calling"]
})
return models
async def execute_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str = "llama-3.1-70b-versatile",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
"""Execute LLM inference using messages format (conversation style)"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - service temporarily unavailable")
# Validate model and get configuration
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
if not model_config:
# Try to get from model service registry
model_service = get_model_service(tenant_id)
model_info = await model_service.get_model(model)
if not model_info:
raise ValueError(f"Unsupported model: {model}")
model_config = {
"max_tokens": model_info["performance"]["max_tokens"],
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
"supports_streaming": model_info["capabilities"].get("streaming", False)
}
# Apply token limits
max_tokens = min(max_tokens, model_config["max_tokens"])
try:
# Get tenant-specific API key
if not tenant_id:
raise ValueError("tenant_id is required for Groq inference")
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
start_time = time.time()
# Translate GT 2.0 "agent" role to OpenAI/Groq "assistant" for external API compatibility
# Use dictionary unpacking to preserve ALL fields including tool_call_id
external_messages = []
for msg in messages:
external_msg = {
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
}
external_messages.append(external_msg)
if stream:
return await self._stream_inference_with_messages(
external_messages, model, temperature, max_tokens, user_id, tenant_id, client
)
else:
# Prepare request parameters
request_params = {
"model": model,
"messages": external_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": False
}
# Add tools if provided
if tools:
request_params["tools"] = tools
if tool_choice:
request_params["tool_choice"] = tool_choice
# Debug: Log messages being sent to Groq
logger.info(f"🔧 Sending {len(external_messages)} messages to Groq API")
for i, msg in enumerate(external_messages):
if msg.get("role") == "tool":
logger.info(f"🔧 Groq Message {i}: role=tool, tool_call_id={msg.get('tool_call_id')}")
else:
logger.info(f"🔧 Groq Message {i}: role={msg.get('role')}, has_tool_calls={bool(msg.get('tool_calls'))}")
response = await client.chat.completions.create(**request_params)
# Track successful usage
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
response.usage.total_tokens if response.usage else 0,
latency, model_config["cost_per_1k_tokens"]
)
# Track in model service
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=True,
latency_ms=latency
)
# Reset circuit breaker on success
await self._record_success()
# Build base response
result = {
"content": response.choices[0].message.content,
"model": model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
"cost_cents": self._calculate_cost(
response.usage.total_tokens if response.usage else 0,
model_config["cost_per_1k_tokens"]
)
},
"latency_ms": latency,
"load_balanced": True,
"haproxy_backend": "groq_general_backend"
}
# For Compound models, extract and calculate detailed cost breakdown
if self._is_compound_model(model):
# Convert response to dict for processing
response_dict = response.model_dump() if hasattr(response, 'model_dump') else {}
# Extract usage_breakdown and executed_tools if present
usage_breakdown = getattr(response, 'usage_breakdown', None)
executed_tools = getattr(response, 'executed_tools', None)
if usage_breakdown or executed_tools:
compound_data = {
"usage_breakdown": usage_breakdown if isinstance(usage_breakdown, dict) else {},
"executed_tools": executed_tools if isinstance(executed_tools, list) else []
}
# Calculate detailed cost breakdown
cost_breakdown = self._calculate_compound_cost(compound_data)
# Add compound-specific data to response
result["usage_breakdown"] = compound_data.get("usage_breakdown", {})
result["executed_tools"] = compound_data.get("executed_tools", [])
result["cost_breakdown"] = cost_breakdown
# Update cost_cents with accurate compound calculation
if cost_breakdown["total_cost_cents"] > 0:
result["usage"]["cost_cents"] = cost_breakdown["total_cost_cents"]
logger.info(f"Compound model cost breakdown: {cost_breakdown}")
return result
except Exception as e:
logger.error(f"HAProxy Groq inference with messages failed: {e}")
# Track failure in model service
await model_service.track_model_usage(
model_id=model,
success=False
)
# Record failure for circuit breaker
await self._record_failure()
# Re-raise the exception
raise Exception(f"Groq inference with messages failed (via HAProxy): {str(e)}")
async def _stream_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
user_id: str,
tenant_id: str,
client: AsyncGroq = None
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses using messages format"""
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
start_time = time.time()
total_tokens = 0
try:
# Use provided client or get tenant-specific client
if not client:
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
stream = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
total_tokens += len(content.split()) # Approximate token count
# Yield just the content (SSE formatting handled by caller)
yield content
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
total_tokens, latency,
model_config["cost_per_1k_tokens"] if model_config else 0.0
)
except Exception as e:
logger.error(f"Streaming inference with messages error: {e}")
raise e
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string from X-Tenant-ID header
Returns:
Decrypted Groq API key
Raises:
ValueError: If no API key configured (results in HTTP 503 to client)
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
raise ValueError(str(e))
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
def _get_client(self, api_key: str) -> AsyncGroq:
"""Get Groq client with specified API key"""
if not GROQ_AVAILABLE:
raise Exception("Groq client not available in development mode")
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
return AsyncGroq(
api_key=api_key,
base_url=haproxy_endpoint,
timeout=httpx.Timeout(30.0),
max_retries=1
)
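# Illustrative usage sketch, assuming a tenant with a Groq API key configured in the
# Control Panel and a reachable HAProxy endpoint; identifiers are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        backend = GroqProxyBackend()
        result = await backend.execute_inference(
            prompt="Summarize GT 2.0 in one sentence.",
            model="llama-3.1-8b-instant",
            max_tokens=128,
            user_id="demo-user",
            tenant_id="demo-tenant"
        )
        print(result["content"], result["usage"]["cost_cents"])

    asyncio.run(_demo())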

View File

@@ -0,0 +1,407 @@
"""
NVIDIA NIM LLM Proxy Backend
Provides LLM inference through NVIDIA NIM with:
- OpenAI-compatible API format (build.nvidia.com)
- Token usage tracking and cost calculation
- Streaming response support
- Circuit breaker pattern for enhanced reliability
"""
import json
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
import httpx
import logging
from app.core.config import get_settings
logger = logging.getLogger(__name__)
# NVIDIA NIM Model pricing per million tokens (input/output)
# Source: build.nvidia.com (Dec 2025 pricing estimates)
# Note: Actual pricing may vary - check build.nvidia.com for current rates
NVIDIA_MODEL_PRICES = {
# Llama Nemotron family
"nvidia/llama-3.1-nemotron-ultra-253b-v1": {"input": 2.0, "output": 6.0},
"nvidia/llama-3.1-nemotron-super-49b-v1": {"input": 0.5, "output": 1.5},
"nvidia/llama-3.1-nemotron-nano-8b-v1": {"input": 0.1, "output": 0.3},
# Standard Llama models via NIM
"meta/llama-3.1-8b-instruct": {"input": 0.1, "output": 0.3},
"meta/llama-3.1-70b-instruct": {"input": 0.5, "output": 1.0},
"meta/llama-3.1-405b-instruct": {"input": 2.0, "output": 6.0},
# Mistral models
"mistralai/mistral-7b-instruct-v0.3": {"input": 0.1, "output": 0.2},
"mistralai/mixtral-8x7b-instruct-v0.1": {"input": 0.3, "output": 0.6},
# Default fallback
"default": {"input": 0.5, "output": 1.5},
}
class NvidiaProxyBackend:
"""LLM inference via NVIDIA NIM with OpenAI-compatible API"""
def __init__(self):
self.settings = get_settings()
self.base_url = getattr(self.settings, 'nvidia_nim_endpoint', None) or "https://integrate.api.nvidia.com/v1"
self.usage_metrics = {}
self.circuit_breaker_status = {
"state": "closed", # closed, open, half_open
"failure_count": 0,
"last_failure_time": None,
"failure_threshold": 5,
"recovery_timeout": 60 # seconds
}
logger.info(f"Initialized NVIDIA NIM backend with endpoint: {self.base_url}")
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string from X-Tenant-ID header
Returns:
Decrypted NVIDIA API key
Raises:
ValueError: If no API key configured (results in HTTP 503 to client)
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
raise ValueError(str(e))
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
def _get_client(self, api_key: str) -> httpx.AsyncClient:
"""Get configured HTTP client for NVIDIA NIM API"""
return httpx.AsyncClient(
base_url=self.base_url,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
timeout=httpx.Timeout(120.0) # Longer timeout for large models
)
async def execute_inference(
self,
prompt: str,
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None
) -> Dict[str, Any]:
"""Execute LLM inference with simple prompt"""
messages = [{"role": "user", "content": prompt}]
return await self.execute_inference_with_messages(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
user_id=user_id,
tenant_id=tenant_id
)
async def execute_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
"""Execute LLM inference using messages format (conversation style)"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - NVIDIA NIM service temporarily unavailable")
if not tenant_id:
raise ValueError("tenant_id is required for NVIDIA NIM inference")
try:
api_key = await self._get_tenant_api_key(tenant_id)
# Translate GT 2.0 "agent" role to OpenAI "assistant" for external API compatibility
external_messages = []
for msg in messages:
external_msg = {
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
}
external_messages.append(external_msg)
# Build request payload
request_data = {
"model": model,
"messages": external_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream
}
# Add tools if provided
if tools:
request_data["tools"] = tools
if tool_choice:
request_data["tool_choice"] = tool_choice
start_time = time.time()
async with self._get_client(api_key) as client:
if stream:
# Return generator for streaming
return self._stream_inference_with_messages(
client, request_data, user_id, tenant_id, model
)
# Non-streaming request
response = await client.post("/chat/completions", json=request_data)
response.raise_for_status()
data = response.json()
latency = (time.time() - start_time) * 1000
# Calculate cost
usage = data.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
cost_cents = int((input_cost + output_cost) * 100)
# Track usage
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
# Reset circuit breaker on success
await self._record_success()
# Build response
result = {
"content": data["choices"][0]["message"]["content"],
"model": model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost_cents": cost_cents
},
"latency_ms": latency,
"provider": "nvidia"
}
# Include tool calls if present
message = data["choices"][0]["message"]
if message.get("tool_calls"):
result["tool_calls"] = message["tool_calls"]
return result
except httpx.HTTPStatusError as e:
logger.error(f"NVIDIA NIM API error: {e.response.status_code} - {e.response.text}")
await self._record_failure()
raise Exception(f"NVIDIA NIM inference failed: HTTP {e.response.status_code}")
except Exception as e:
logger.error(f"NVIDIA NIM inference failed: {e}")
await self._record_failure()
raise Exception(f"NVIDIA NIM inference failed: {str(e)}")
async def _stream_inference_with_messages(
self,
client: httpx.AsyncClient,
request_data: Dict[str, Any],
user_id: str,
tenant_id: str,
model: str
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses"""
start_time = time.time()
total_tokens = 0
try:
async with client.stream("POST", "/chat/completions", json=request_data) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
if chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("content"):
content = chunk["choices"][0]["delta"]["content"]
total_tokens += len(content.split()) # Approximate
yield content
except json.JSONDecodeError:
continue
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
cost_cents = int((total_tokens / 1_000_000) * model_prices["output"] * 100)
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
await self._record_success()
except Exception as e:
logger.error(f"NVIDIA NIM streaming error: {e}")
await self._record_failure()
raise e
async def check_health(self) -> Dict[str, Any]:
"""Check health of NVIDIA NIM backend and circuit breaker status"""
return {
"nvidia_nim": {
"endpoint": self.base_url,
"status": "available" if self.circuit_breaker_status["state"] == "closed" else "degraded",
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"],
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat()
if self.circuit_breaker_status["last_failure_time"] else None
}
}
async def _is_circuit_closed(self) -> bool:
"""Check if circuit breaker allows requests"""
if self.circuit_breaker_status["state"] == "closed":
return True
if self.circuit_breaker_status["state"] == "open":
# Check if recovery timeout has passed
if self.circuit_breaker_status["last_failure_time"]:
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
# Move to half-open state
self.circuit_breaker_status["state"] = "half_open"
logger.info("NVIDIA NIM circuit breaker moved to half-open state")
return True
return False
if self.circuit_breaker_status["state"] == "half_open":
# Allow limited requests in half-open state
return True
return False
async def _record_success(self):
"""Record successful request for circuit breaker"""
if self.circuit_breaker_status["state"] == "half_open":
# Success in half-open state closes the circuit
self.circuit_breaker_status["state"] = "closed"
self.circuit_breaker_status["failure_count"] = 0
logger.info("NVIDIA NIM circuit breaker closed after successful request")
# Reset failure count on any success
self.circuit_breaker_status["failure_count"] = 0
async def _record_failure(self):
"""Record failed request for circuit breaker"""
self.circuit_breaker_status["failure_count"] += 1
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
self.circuit_breaker_status["state"] = "open"
logger.warning(f"NVIDIA NIM circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
async def _track_usage(
self,
user_id: str,
tenant_id: str,
model: str,
tokens: int,
latency: float,
cost_cents: int
):
"""Track usage metrics for billing and monitoring"""
# Create usage key
usage_key = f"{tenant_id}:{user_id}:{model}"
# Initialize metrics if not exists
if usage_key not in self.usage_metrics:
self.usage_metrics[usage_key] = {
"total_tokens": 0,
"total_requests": 0,
"total_cost_cents": 0,
"average_latency": 0
}
# Update metrics
metrics = self.usage_metrics[usage_key]
metrics["total_tokens"] += tokens
metrics["total_requests"] += 1
metrics["total_cost_cents"] += cost_cents
# Update average latency
prev_avg = metrics["average_latency"]
prev_count = metrics["total_requests"] - 1
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
# Log high-level metrics periodically
if metrics["total_requests"] % 100 == 0:
logger.info(f"NVIDIA NIM usage milestone for {usage_key}: {metrics}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int, model: str) -> int:
"""Calculate cost in cents based on token usage"""
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
return int((input_cost + output_cost) * 100)
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available NVIDIA NIM models with their configurations"""
models = []
for model_id, prices in NVIDIA_MODEL_PRICES.items():
if model_id == "default":
continue
models.append({
"id": model_id,
"name": model_id.split("/")[-1].replace("-", " ").title(),
"provider": "nvidia",
"max_tokens": 4096, # Default for most NIM models
"cost_per_1k_input": prices["input"],
"cost_per_1k_output": prices["output"],
"supports_streaming": True,
"supports_function_calling": True
})
return models
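# Illustrative usage sketch, assuming a tenant with an NVIDIA API key configured in the
# Control Panel; identifiers are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        backend = NvidiaProxyBackend()
        result = await backend.execute_inference_with_messages(
            messages=[{"role": "user", "content": "Give one use case for NVIDIA NIM."}],
            model="nvidia/llama-3.1-nemotron-nano-8b-v1",
            max_tokens=128,
            user_id="demo-user",
            tenant_id="demo-tenant"
        )
        print(result["content"], result["usage"]["cost_cents"])

    asyncio.run(_demo())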

View File

@@ -0,0 +1,457 @@
"""
Capability-Based Authentication for GT 2.0 Resource Cluster
Implements JWT capability token verification with:
- Cryptographic signature validation
- Fine-grained resource permissions
- Rate limiting and constraints enforcement
- Tenant isolation validation
- Zero external dependencies
GT 2.0 Security Principles:
- Self-contained: No external auth services
- Stateless: All permissions in JWT token
- Cryptographic: signed JWT verification (HS256 today, RS256 planned)
- Isolated: Perfect tenant separation
"""
import jwt
import logging
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
from fastapi import HTTPException, Depends, Header
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class CapabilityError(Exception):
"""Capability authentication error"""
pass
class ResourceType(str, Enum):
"""Resource types in GT 2.0"""
LLM = "llm"
EMBEDDING = "embedding"
VECTOR_STORAGE = "vector_storage"
EXTERNAL_SERVICES = "external_services"
ADMIN = "admin"
class ActionType(str, Enum):
"""Action types for resources"""
READ = "read"
WRITE = "write"
EXECUTE = "execute"
ADMIN = "admin"
@dataclass
class Capability:
"""Individual capability definition"""
resource: ResourceType
actions: List[ActionType]
constraints: Dict[str, Any]
expires_at: Optional[datetime] = None
def allows_action(self, action: ActionType) -> bool:
"""Check if capability allows specific action"""
return action in self.actions
def is_expired(self) -> bool:
"""Check if capability is expired"""
if not self.expires_at:
return False
return datetime.now(timezone.utc) > self.expires_at
def check_constraint(self, constraint_name: str, value: Any) -> bool:
"""Check if value satisfies constraint"""
if constraint_name not in self.constraints:
return True # No constraint means allowed
constraint_value = self.constraints[constraint_name]
if constraint_name == "max_tokens":
return value <= constraint_value
elif constraint_name == "allowed_models":
return value in constraint_value
elif constraint_name == "max_requests_per_hour":
# This would be checked separately with rate limiting
return True
elif constraint_name == "allowed_tenants":
return value in constraint_value
return True
@dataclass
class CapabilityToken:
"""Parsed capability token"""
subject: str
tenant_id: str
capabilities: List[Capability]
issued_at: datetime
expires_at: datetime
issuer: str
token_version: str
def has_capability(self, resource: ResourceType, action: ActionType) -> bool:
"""Check if token has specific capability"""
for cap in self.capabilities:
if cap.resource == resource and cap.allows_action(action) and not cap.is_expired():
return True
return False
def get_capability(self, resource: ResourceType) -> Optional[Capability]:
"""Get capability for specific resource"""
for cap in self.capabilities:
if cap.resource == resource and not cap.is_expired():
return cap
return None
def is_expired(self) -> bool:
"""Check if entire token is expired"""
return datetime.now(timezone.utc) > self.expires_at
class CapabilityAuthenticator:
"""
Handles capability token verification and authorization.
Uses JWT tokens with embedded permissions for stateless authentication.
"""
def __init__(self):
self.settings = get_settings()
# In production, this would be loaded from secure storage
# For development, using the secret key
self.secret_key = self.settings.secret_key
self.algorithm = "HS256" # TODO: Upgrade to RS256 with public/private keys
logger.info("Capability authenticator initialized")
async def verify_token(self, token: str) -> CapabilityToken:
"""
Verify and parse capability token.
Args:
token: JWT capability token
Returns:
Parsed capability token
Raises:
CapabilityError: If token is invalid or expired
"""
try:
# Decode JWT token
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
audience="gt2-resource-cluster"
)
# Validate required fields
required_fields = ["sub", "tenant_id", "capabilities", "iat", "exp", "iss"]
for field in required_fields:
if field not in payload:
raise CapabilityError(f"Missing required field: {field}")
# Parse timestamps
issued_at = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
# Check token expiration
if datetime.now(timezone.utc) > expires_at:
raise CapabilityError("Token has expired")
# Parse capabilities
capabilities = []
for cap_data in payload["capabilities"]:
try:
capability = Capability(
resource=ResourceType(cap_data["resource"]),
actions=[ActionType(action) for action in cap_data["actions"]],
constraints=cap_data.get("constraints", {}),
expires_at=datetime.fromtimestamp(
cap_data["expires_at"], tz=timezone.utc
) if cap_data.get("expires_at") else None
)
capabilities.append(capability)
except (KeyError, ValueError) as e:
logger.warning(f"Invalid capability in token: {e}")
# Skip invalid capabilities rather than rejecting entire token
continue
# Create capability token
capability_token = CapabilityToken(
subject=payload["sub"],
tenant_id=payload["tenant_id"],
capabilities=capabilities,
issued_at=issued_at,
expires_at=expires_at,
issuer=payload["iss"],
token_version=payload.get("token_version", "1.0")
)
logger.debug(f"Capability token verified for {capability_token.subject}")
return capability_token
except jwt.ExpiredSignatureError:
raise CapabilityError("Token has expired")
except jwt.InvalidTokenError as e:
raise CapabilityError(f"Invalid token: {e}")
except Exception as e:
logger.error(f"Token verification failed: {e}")
raise CapabilityError(f"Token verification failed: {e}")
async def check_resource_access(
self,
capability_token: CapabilityToken,
resource: ResourceType,
action: ActionType,
constraints: Optional[Dict[str, Any]] = None
) -> bool:
"""
Check if token allows access to resource with specific action.
Args:
capability_token: Verified capability token
resource: Resource type to access
action: Action to perform
constraints: Additional constraints to check
Returns:
True if access is allowed
Raises:
CapabilityError: If access is denied
"""
try:
# Check token expiration
if capability_token.is_expired():
raise CapabilityError("Token has expired")
# Find matching capability
capability = capability_token.get_capability(resource)
if not capability:
raise CapabilityError(f"No capability for resource: {resource}")
# Check action permission
if not capability.allows_action(action):
raise CapabilityError(f"Action {action} not allowed for resource {resource}")
# Check constraints if provided
if constraints:
for constraint_name, value in constraints.items():
if not capability.check_constraint(constraint_name, value):
raise CapabilityError(
f"Constraint violation: {constraint_name} = {value}"
)
return True
except CapabilityError:
raise
except Exception as e:
logger.error(f"Resource access check failed: {e}")
raise CapabilityError(f"Access check failed: {e}")
# Global authenticator instance
capability_authenticator = CapabilityAuthenticator()
async def verify_capability_token(token: str) -> Dict[str, Any]:
"""
Verify capability token and return payload.
Args:
token: JWT capability token
Returns:
Token payload as dictionary
Raises:
CapabilityError: If token is invalid
"""
capability_token = await capability_authenticator.verify_token(token)
return {
"sub": capability_token.subject,
"tenant_id": capability_token.tenant_id,
"capabilities": [
{
"resource": cap.resource.value,
"actions": [action.value for action in cap.actions],
"constraints": cap.constraints
}
for cap in capability_token.capabilities
],
"iat": capability_token.issued_at.timestamp(),
"exp": capability_token.expires_at.timestamp(),
"iss": capability_token.issuer,
"token_version": capability_token.token_version
}
async def get_current_capability(
authorization: str = Header(..., description="Bearer token")
) -> Dict[str, Any]:
"""
FastAPI dependency to get current capability from Authorization header.
Args:
authorization: Authorization header with Bearer token
Returns:
Capability payload
Raises:
HTTPException: If authentication fails
"""
try:
if not authorization.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail="Invalid authorization header format"
)
token = authorization[7:] # Remove "Bearer " prefix
payload = await verify_capability_token(token)
return payload
except CapabilityError as e:
logger.warning(f"Capability authentication failed: {e}")
raise HTTPException(status_code=401, detail=str(e))
except Exception as e:
logger.error(f"Authentication error: {e}")
raise HTTPException(status_code=500, detail="Authentication error")
def require_capability(
resource: ResourceType,
action: ActionType,
constraints: Optional[Dict[str, Any]] = None
):
"""
FastAPI dependency to require specific capability.
Args:
resource: Required resource type
action: Required action type
constraints: Additional constraints to check
Returns:
Dependency function
"""
async def _check_capability(
capability_payload: Dict[str, Any] = Depends(get_current_capability)
) -> Dict[str, Any]:
try:
# Reconstruct capability token from payload
capabilities = []
for cap_data in capability_payload["capabilities"]:
capability = Capability(
resource=ResourceType(cap_data["resource"]),
actions=[ActionType(action) for action in cap_data["actions"]],
constraints=cap_data["constraints"]
)
capabilities.append(capability)
capability_token = CapabilityToken(
subject=capability_payload["sub"],
tenant_id=capability_payload["tenant_id"],
capabilities=capabilities,
issued_at=datetime.fromtimestamp(capability_payload["iat"], tz=timezone.utc),
expires_at=datetime.fromtimestamp(capability_payload["exp"], tz=timezone.utc),
issuer=capability_payload["iss"],
token_version=capability_payload["token_version"]
)
# Check required capability
await capability_authenticator.check_resource_access(
capability_token=capability_token,
resource=resource,
action=action,
constraints=constraints
)
return capability_payload
except CapabilityError as e:
logger.warning(f"Capability check failed: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Capability check error: {e}")
raise HTTPException(status_code=500, detail="Authorization error")
return _check_capability
# Convenience functions for common capability checks
async def require_llm_capability(
capability_payload: Dict[str, Any] = Depends(
require_capability(ResourceType.LLM, ActionType.EXECUTE)
)
) -> Dict[str, Any]:
"""Require LLM execution capability"""
return capability_payload
async def require_embedding_capability(
capability_payload: Dict[str, Any] = Depends(
require_capability(ResourceType.EMBEDDING, ActionType.EXECUTE)
)
) -> Dict[str, Any]:
"""Require embedding generation capability"""
return capability_payload
async def require_admin_capability(
capability_payload: Dict[str, Any] = Depends(
require_capability(ResourceType.ADMIN, ActionType.ADMIN)
)
) -> Dict[str, Any]:
"""Require admin capability"""
return capability_payload
async def verify_capability_token_dependency(
authorization: str = Header(..., description="Bearer token")
) -> Dict[str, Any]:
"""
FastAPI dependency for ChromaDB MCP API that verifies capability token.
Returns token payload with raw_token field for service layer use.
"""
try:
if not authorization.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail="Invalid authorization header format"
)
token = authorization[7:] # Remove "Bearer " prefix
payload = await verify_capability_token(token)
# Add raw token for service layer
payload["raw_token"] = token
return payload
except CapabilityError as e:
logger.warning(f"Capability authentication failed: {e}")
raise HTTPException(status_code=401, detail=str(e))
except Exception as e:
logger.error(f"Authentication error: {e}")
raise HTTPException(status_code=500, detail="Authentication error")

View File

@@ -0,0 +1,293 @@
"""
GT 2.0 Resource Cluster Configuration
Central configuration for the air-gapped Resource Cluster that manages
all AI resources, document processing, and external service integrations.
"""
import os
from typing import List, Dict, Any, Optional
from pydantic_settings import BaseSettings
from pydantic import Field, validator
class Settings(BaseSettings):
"""Resource Cluster settings with environment variable support"""
# Environment
environment: str = Field(default="development", description="Runtime environment")
debug: bool = Field(default=False, description="Debug mode")
# Service Identity
cluster_name: str = Field(default="gt-resource-cluster", description="Cluster identifier")
service_port: int = Field(default=8003, description="Service port")
# Security
secret_key: str = Field(..., description="JWT signing key for capability tokens")
algorithm: str = Field(default="HS256", description="JWT algorithm")
capability_token_expire_minutes: int = Field(default=60, description="Capability token expiry")
# External LLM Providers (via HAProxy)
groq_api_key: Optional[str] = Field(default=None, description="Groq Cloud API key")
groq_endpoints: List[str] = Field(
default=["https://api.groq.com/openai/v1"],
description="Groq API endpoints for load balancing"
)
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
anthropic_api_key: Optional[str] = Field(default=None, description="Anthropic API key")
# NVIDIA NIM Configuration
nvidia_nim_endpoint: str = Field(
default="https://integrate.api.nvidia.com/v1",
description="NVIDIA NIM API endpoint (cloud or self-hosted)"
)
nvidia_nim_enabled: bool = Field(
default=True,
description="Enable NVIDIA NIM backend for GPU-accelerated inference"
)
# HAProxy Configuration
haproxy_groq_endpoint: str = Field(
default="http://haproxy-groq-lb-service.gt-resource.svc.cluster.local",
description="HAProxy load balancer endpoint for Groq API"
)
haproxy_stats_endpoint: str = Field(
default="http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats",
description="HAProxy statistics endpoint"
)
haproxy_admin_socket: str = Field(
default="/var/run/haproxy.sock",
description="HAProxy admin socket for runtime configuration"
)
haproxy_enabled: bool = Field(
default=True,
description="Enable HAProxy load balancing for external APIs"
)
# Control Panel Integration (for API key retrieval)
control_panel_url: str = Field(
default="http://control-panel-backend:8000",
description="Control Panel internal API URL for service-to-service calls"
)
service_auth_token: str = Field(
default="internal-service-token",
description="Service-to-service authentication token"
)
# Admin Cluster Configuration Sync
admin_cluster_url: str = Field(
default="http://localhost:8001",
description="Admin cluster URL for configuration sync"
)
config_sync_interval: int = Field(
default=10,
description="Configuration sync interval in seconds"
)
config_sync_enabled: bool = Field(
default=True,
description="Enable automatic configuration sync from admin cluster"
)
# Consul Service Discovery
consul_host: str = Field(default="localhost", description="Consul host")
consul_port: int = Field(default=8500, description="Consul port")
consul_token: Optional[str] = Field(default=None, description="Consul ACL token")
# Document Processing
chunking_engine_workers: int = Field(default=4, description="Parallel document processors")
max_document_size_mb: int = Field(default=50, description="Maximum document size")
supported_document_types: List[str] = Field(
default=[".pdf", ".docx", ".txt", ".md", ".html", ".pptx", ".xlsx", ".csv"],
description="Supported document formats"
)
# BGE-M3 Embedding Configuration
embedding_endpoint: str = Field(
default="http://gentwo-vllm-embeddings:8000/v1/embeddings",
description="Default embedding endpoint (local or external)"
)
bge_m3_local_mode: bool = Field(
default=True,
description="Use local BGE-M3 embedding service (True) or external endpoint (False)"
)
bge_m3_external_endpoint: Optional[str] = Field(
default=None,
description="External BGE-M3 embedding endpoint URL (when local_mode=False)"
)
# Vector Database (ChromaDB)
chromadb_host: str = Field(default="localhost", description="ChromaDB host")
chromadb_port: int = Field(default=8000, description="ChromaDB port")
chromadb_encryption_key: Optional[str] = Field(
default=None,
description="Encryption key for vector storage"
)
# Resource Limits
max_concurrent_inferences: int = Field(default=100, description="Max concurrent LLM calls")
max_tokens_per_request: int = Field(default=8000, description="Max tokens per LLM request")
rate_limit_requests_per_minute: int = Field(default=60, description="Global rate limit")
# Storage Paths
data_directory: str = Field(
default="/tmp/gt2-resource-cluster" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster",
description="Base data directory"
)
template_library_path: str = Field(
default="/tmp/gt2-resource-cluster/templates" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster/templates",
description="Agent template library"
)
models_cache_path: str = Field( # Renamed to avoid pydantic warning
default="/tmp/gt2-resource-cluster/models" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster/models",
description="Local model cache"
)
# Redis removed - Resource Cluster uses PostgreSQL for caching and rate limiting
# Monitoring
prometheus_enabled: bool = Field(default=True, description="Enable Prometheus metrics")
prometheus_port: int = Field(default=9091, description="Prometheus metrics port")
# CORS Configuration (for tenant backends)
cors_origins: List[str] = Field(
default=["http://localhost:8002", "https://*.gt2.com"],
description="Allowed CORS origins"
)
# Trusted Host Configuration
trusted_hosts: List[str] = Field(
default=["localhost", "*.gt2.com", "resource-cluster", "gentwo-resource-backend",
"gt2-resource-backend", "testserver", "127.0.0.1", "*"],
description="Allowed host headers for TrustedHostMiddleware"
)
# Feature Flags
enable_model_caching: bool = Field(default=True, description="Cache model responses")
enable_usage_tracking: bool = Field(default=True, description="Track resource usage")
enable_cost_calculation: bool = Field(default=True, description="Calculate usage costs")
@validator("data_directory")
def validate_data_directory(cls, v):
# Ensure directory exists with secure permissions
os.makedirs(v, exist_ok=True, mode=0o700)
return v
@validator("template_library_path")
def validate_template_library_path(cls, v):
os.makedirs(v, exist_ok=True, mode=0o700)
return v
@validator("models_cache_path")
def validate_models_cache_path(cls, v):
os.makedirs(v, exist_ok=True, mode=0o700)
return v
model_config = {
"env_file": ".env",
"env_file_encoding": "utf-8",
"case_sensitive": False,
"extra": "ignore",
}
def get_settings(tenant_id: Optional[str] = None) -> Settings:
"""Get tenant-scoped application settings"""
# For development, return shared settings without tenant isolation
if os.getenv("ENVIRONMENT") == "development":
return Settings()
# In production, settings should be tenant-scoped
# This prevents global state from affecting tenant isolation
if tenant_id:
# Create tenant-specific settings with proper isolation
settings = Settings()
# Add tenant-specific configurations here if needed
return settings
else:
# Default settings for non-tenant operations
return Settings()
def get_resource_families(tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get tenant-scoped resource family definitions (from CLAUDE.md)"""
# Base resource families - can be extended per tenant in production
return {
"ai_ml": {
"name": "AI/ML Resources",
"subtypes": ["llm", "embedding", "image_generation", "function_calling"]
},
"rag_engine": {
"name": "RAG Engine Resources",
"subtypes": ["vector_db", "document_processor", "semantic_search", "retrieval"]
},
"agentic_workflow": {
"name": "Agentic Workflow Resources",
"subtypes": ["single_agent", "multi_agent", "orchestration", "memory"]
},
"app_integration": {
"name": "App Integration Resources",
"subtypes": ["oauth2", "webhook", "api_connector", "database_connector"]
},
"external_service": {
"name": "External Web Services",
"subtypes": ["iframe_embed", "sso_service", "remote_desktop", "learning_platform"]
},
"ai_literacy": {
"name": "AI Literacy & Cognitive Skills",
"subtypes": ["strategic_game", "logic_puzzle", "philosophical_dilemma", "educational_content"]
}
}
def get_model_configs(tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get tenant-scoped model configurations for different providers"""
# Base model configurations - can be customized per tenant in production
return {
"groq": {
"llama-3.1-70b-versatile": {
"max_tokens": 8000,
"cost_per_1k_tokens": 0.59,
"supports_streaming": True,
"supports_function_calling": True
},
"llama-3.1-8b-instant": {
"max_tokens": 8000,
"cost_per_1k_tokens": 0.05,
"supports_streaming": True,
"supports_function_calling": True
},
"mixtral-8x7b-32768": {
"max_tokens": 32768,
"cost_per_1k_tokens": 0.27,
"supports_streaming": True,
"supports_function_calling": False
}
},
"openai": {
"gpt-4-turbo": {
"max_tokens": 128000,
"cost_per_1k_tokens": 10.0,
"supports_streaming": True,
"supports_function_calling": True
},
"gpt-3.5-turbo": {
"max_tokens": 16385,
"cost_per_1k_tokens": 0.5,
"supports_streaming": True,
"supports_function_calling": True
}
},
"anthropic": {
"claude-3-opus": {
"max_tokens": 200000,
"cost_per_1k_tokens": 15.0,
"supports_streaming": True,
"supports_function_calling": False
},
"claude-3-sonnet": {
"max_tokens": 200000,
"cost_per_1k_tokens": 3.0,
"supports_streaming": True,
"supports_function_calling": False
}
}
}
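A quick sketch of loading these settings in development; SECRET_KEY below is a placeholder value, and the printed results follow the field defaults above.

# Minimal usage sketch (development only; the secret is a placeholder).
import os

os.environ.setdefault("SECRET_KEY", "dev-only-placeholder")
os.environ.setdefault("ENVIRONMENT", "development")

settings = get_settings()
print(settings.cluster_name, settings.service_port)       # gt-resource-cluster 8003
print(get_resource_families()["rag_engine"]["subtypes"])  # ['vector_db', 'document_processor', ...]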

View File

@@ -0,0 +1,45 @@
"""
GT 2.0 Resource Cluster Exceptions
Custom exceptions for the resource cluster.
"""
class ResourceClusterError(Exception):
"""Base exception for resource cluster errors"""
pass
class ProviderError(ResourceClusterError):
"""Error from AI model provider"""
pass
class ModelNotFoundError(ResourceClusterError):
"""Requested model not found"""
pass
class CapabilityError(ResourceClusterError):
"""Capability token validation error"""
pass
class MCPError(ResourceClusterError):
"""MCP service error"""
pass
class DocumentProcessingError(ResourceClusterError):
"""Document processing error"""
pass
class RateLimitError(ResourceClusterError):
"""Rate limit exceeded"""
pass
class CircuitBreakerError(ProviderError):
"""Circuit breaker is open"""
pass
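Because CircuitBreakerError subclasses ProviderError, a handler written for provider failures also catches breaker trips; a tiny sketch:

# Sketch: the exception hierarchy lets one handler cover both failure modes.
try:
    raise CircuitBreakerError("NVIDIA NIM circuit breaker is open")
except ProviderError as exc:
    # CircuitBreakerError -> ProviderError -> ResourceClusterError
    print(f"provider unavailable: {exc}")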

View File

@@ -0,0 +1,273 @@
"""
GT 2.0 Resource Cluster Security
Capability-based authentication and authorization for resource access.
Implements cryptographically signed JWT tokens with embedded capabilities.
"""
import hashlib
import json
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from app.core.config import get_settings
settings = get_settings()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class ResourceCapability(BaseModel):
"""Individual resource capability"""
resource: str # e.g., "llm:groq", "rag:semantic_search"
actions: List[str] # e.g., ["inference", "streaming"]
limits: Dict[str, Any] = {} # e.g., {"max_tokens": 4000, "requests_per_minute": 60}
constraints: Dict[str, Any] = {} # e.g., {"valid_until": "2024-12-31", "ip_restrictions": []}
class CapabilityToken(BaseModel):
"""Capability-based JWT token payload"""
sub: str # User or service identifier
tenant_id: str # Tenant identifier
capabilities: List[ResourceCapability] # Granted capabilities
capability_hash: str # SHA256 hash of capabilities for integrity
exp: Optional[datetime] = None # Expiration time
iat: Optional[datetime] = None # Issued at time
jti: Optional[str] = None # JWT ID for revocation
class CapabilityValidator:
"""Validates and enforces capability-based access control"""
def __init__(self):
self.settings = get_settings()
def create_capability_token(
self,
user_id: str,
tenant_id: str,
capabilities: List[Dict[str, Any]],
expires_delta: Optional[timedelta] = None
) -> str:
"""Create a cryptographically signed capability token"""
# Convert capabilities to ResourceCapability objects
capability_objects = [
ResourceCapability(**cap) for cap in capabilities
]
# Generate capability hash for integrity verification
capability_hash = self._generate_capability_hash(capability_objects)
# Set token expiration
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=self.settings.capability_token_expire_minutes)
# Create token payload
token_data = CapabilityToken(
sub=user_id,
tenant_id=tenant_id,
capabilities=[cap.dict() for cap in capability_objects],
capability_hash=capability_hash,
exp=expire,
iat=datetime.utcnow(),
jti=self._generate_jti()
)
# Encode JWT token
encoded_jwt = jwt.encode(
token_data.dict(),
self.settings.secret_key,
algorithm=self.settings.algorithm
)
return encoded_jwt
def verify_capability_token(self, token: str) -> Optional[CapabilityToken]:
"""Verify and decode a capability token"""
try:
# Decode JWT token
payload = jwt.decode(
token,
self.settings.secret_key,
algorithms=[self.settings.algorithm]
)
# Convert to CapabilityToken object
capability_token = CapabilityToken(**payload)
# Verify capability hash integrity
capability_objects = []
for cap in capability_token.capabilities:
if isinstance(cap, dict):
capability_objects.append(ResourceCapability(**cap))
else:
capability_objects.append(cap)
expected_hash = self._generate_capability_hash(capability_objects)
if capability_token.capability_hash != expected_hash:
raise ValueError("Capability hash mismatch - token may be tampered")
return capability_token
except (JWTError, ValueError) as e:
return None
def check_resource_access(
self,
token: CapabilityToken,
resource: str,
action: str,
context: Dict[str, Any] = {}
) -> bool:
"""Check if token grants access to specific resource and action"""
for capability in token.capabilities:
# Handle both dict and ResourceCapability object formats
if isinstance(capability, dict):
cap_resource = capability["resource"]
cap_actions = capability.get("actions", [])
cap_constraints = capability.get("constraints", {})
else:
cap_resource = capability.resource
cap_actions = capability.actions
cap_constraints = capability.constraints
# Check if capability matches resource
if self._matches_resource(cap_resource, resource):
# Check if action is allowed
if action in cap_actions:
# Check additional constraints
if self._check_constraints(cap_constraints, context):
return True
return False
def get_resource_limits(
self,
token: CapabilityToken,
resource: str
) -> Dict[str, Any]:
"""Get resource-specific limits from token"""
for capability in token.capabilities:
# Handle both dict and ResourceCapability object formats
if isinstance(capability, dict):
cap_resource = capability["resource"]
cap_limits = capability.get("limits", {})
else:
cap_resource = capability.resource
cap_limits = capability.limits
if self._matches_resource(cap_resource, resource):
return cap_limits
return {}
def _generate_capability_hash(self, capabilities: List[ResourceCapability]) -> str:
"""Generate SHA256 hash of capabilities for integrity verification"""
# Sort capabilities for consistent hashing
sorted_caps = sorted(
[cap.dict() for cap in capabilities],
key=lambda x: x["resource"]
)
# Create hash
cap_string = json.dumps(sorted_caps, sort_keys=True)
return hashlib.sha256(cap_string.encode()).hexdigest()
def _generate_jti(self) -> str:
"""Generate unique JWT ID"""
import uuid
return str(uuid.uuid4())
def _matches_resource(self, pattern: str, resource: str) -> bool:
"""Check if resource pattern matches requested resource"""
# Handle wildcards (e.g., "llm:*" matches "llm:groq")
if pattern.endswith(":*"):
prefix = pattern[:-2]
return resource.startswith(prefix + ":")
# Handle exact matches
return pattern == resource
def _check_constraints(self, constraints: Dict[str, Any], context: Dict[str, Any]) -> bool:
"""Check additional constraints like time validity and IP restrictions"""
# Check time validity
if "valid_until" in constraints:
valid_until = datetime.fromisoformat(constraints["valid_until"])
if datetime.utcnow() > valid_until:
return False
# Check IP restrictions
if "ip_restrictions" in constraints and "client_ip" in context:
allowed_ips = constraints["ip_restrictions"]
if allowed_ips and context["client_ip"] not in allowed_ips:
return False
# Check tenant restrictions
if "allowed_tenants" in constraints and "tenant_id" in context:
allowed_tenants = constraints["allowed_tenants"]
if allowed_tenants and context["tenant_id"] not in allowed_tenants:
return False
return True
# Global validator instance
capability_validator = CapabilityValidator()
def verify_capability_token(token: str) -> Optional[CapabilityToken]:
"""Standalone function for FastAPI dependency injection"""
return capability_validator.verify_capability_token(token)
def create_resource_capability(
resource_type: str,
resource_id: str,
actions: List[str],
limits: Dict[str, Any] = {},
constraints: Dict[str, Any] = {}
) -> Dict[str, Any]:
"""Helper function to create a resource capability"""
return {
"resource": f"{resource_type}:{resource_id}",
"actions": actions,
"limits": limits,
"constraints": constraints
}
def create_assistant_capabilities(assistant_config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Create capabilities from agent configuration"""
capabilities = []
# Extract capabilities from agent config
for cap in assistant_config.get("capabilities", []):
capabilities.append(cap)
# Add default LLM capability if specified
if "primary_llm" in assistant_config.get("resource_preferences", {}):
llm_model = assistant_config["resource_preferences"]["primary_llm"]
capabilities.append(create_resource_capability(
"llm",
llm_model.replace(":", "_"),
["inference", "streaming"],
{
"max_tokens": assistant_config["resource_preferences"].get("max_tokens", 4000),
"temperature": assistant_config["resource_preferences"].get("temperature", 0.7)
}
))
return capabilities
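To make the token flow concrete, here is a hedged sketch that mints a capability token for an LLM resource and then checks access with it; the user, tenant, and model identifiers are made up for illustration.

# Sketch: issue and verify a capability token (illustrative identifiers only).
caps = [
    create_resource_capability(
        "llm", "groq_llama-3.1-8b-instant",
        actions=["inference", "streaming"],
        limits={"max_tokens": 4000},
    )
]
token = capability_validator.create_capability_token(
    user_id="user-123", tenant_id="example-tenant", capabilities=caps
)

decoded = capability_validator.verify_capability_token(token)
assert decoded is not None
# Wildcards also work: a capability for "llm:*" would match any "llm:<model>" resource.
print(capability_validator.check_resource_access(
    decoded, resource="llm:groq_llama-3.1-8b-instant", action="inference"
))  # True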

View File

@@ -0,0 +1,234 @@
"""
GT 2.0 Resource Cluster - Main Application
Air-gapped resource management hub for AI/ML resources, RAG engines,
agentic workflows, app integrations, external services, and AI literacy.
"""
import asyncio
from contextlib import asynccontextmanager
from datetime import datetime
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from prometheus_client import make_asgi_app
import logging
from app.core.config import get_settings
from app.api import inference, embeddings, rag, agents, templates, health, internal
from app.api.v1 import services, models, ai_inference, mcp_registry, mcp_executor
from app.core.backends import initialize_backends
from app.services.consul_registry import ConsulRegistry
from app.services.config_sync import get_config_sync_service
from app.api.v1.mcp_registry import initialize_mcp_servers
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
settings = get_settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application lifecycle"""
# Startup
logger.info("Starting GT 2.0 Resource Cluster")
# Initialize resource backends
await initialize_backends()
# Initialize MCP servers (RAG and Conversation)
try:
await initialize_mcp_servers()
logger.info("MCP servers initialized")
except Exception as e:
logger.error(f"MCP server initialization failed: {e}")
# Start configuration sync from admin cluster
if settings.config_sync_enabled:
config_sync = get_config_sync_service()
# Perform initial sync before starting background loop
try:
await config_sync.sync_configurations()
logger.info("Initial configuration sync completed")
# Give config sync time to complete provider updates
await asyncio.sleep(0.5)
# Verify BGE-M3 model is loaded in registry before refreshing embedding backend
try:
from app.services.model_service import default_model_service
from app.core.backends import get_embedding_backend
# Retry logic to wait for BGE-M3 to appear in registry
max_retries = 3
retry_delay = 1.0 # seconds
bge_m3_found = False
for attempt in range(max_retries):
bge_m3_config = default_model_service.model_registry.get("BAAI/bge-m3")
if bge_m3_config:
endpoint = bge_m3_config.get("endpoint_url")
config = bge_m3_config.get("parameters", {})
is_local_mode = config.get("is_local_mode", True)
logger.info(f"BGE-M3 found in registry on attempt {attempt + 1}: endpoint={endpoint}, is_local_mode={is_local_mode}")
bge_m3_found = True
break
else:
logger.debug(f"BGE-M3 not yet in registry (attempt {attempt + 1}/{max_retries}), retrying...")
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
if not bge_m3_found:
logger.warning("BGE-M3 not found in registry after initial sync - will use defaults until next sync")
# Refresh embedding backend with database configuration
embedding_backend = get_embedding_backend()
embedding_backend.refresh_endpoint_from_registry()
logger.info(f"Embedding backend refreshed with database configuration: {embedding_backend.embedding_endpoint}")
except Exception as e:
logger.warning(f"Failed to refresh embedding backend on startup: {e}")
except Exception as e:
logger.warning(f"Initial configuration sync failed: {e}")
# Start sync loop in background
asyncio.create_task(config_sync.start_sync_loop())
logger.info("Started configuration sync from admin cluster")
# Register with Consul for service discovery
if settings.environment == "production":
consul = ConsulRegistry()
await consul.register_service(
name="resource-cluster",
service_id=f"resource-cluster-{settings.cluster_name}",
address="localhost",
port=settings.service_port,
tags=["ai", "resource", "cluster"],
check_interval="10s"
)
logger.info(f"Resource Cluster started on port {settings.service_port}")
yield
# Shutdown
logger.info("Shutting down Resource Cluster")
# Deregister from Consul
if settings.environment == "production":
await consul.deregister_service(f"resource-cluster-{settings.cluster_name}")
# Create FastAPI application
app = FastAPI(
title="GT 2.0 Resource Cluster",
description="Centralized AI resource management with high availability",
version="1.0.0",
lifespan=lifespan
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Add trusted host middleware with configurable hosts
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=settings.trusted_hosts
)
# Include API routers
app.include_router(health.router, prefix="/health", tags=["health"])
app.include_router(inference.router, prefix="/api/v1/inference", tags=["inference"])
app.include_router(embeddings.router, prefix="/api/v1/embeddings", tags=["embeddings"])
app.include_router(rag.router, prefix="/api/v1/rag", tags=["rag"])
app.include_router(agents.router, prefix="/api/v1/agents", tags=["agents"])
app.include_router(templates.router, prefix="/api/v1/templates", tags=["templates"])
app.include_router(services.router, prefix="/api/v1/services", tags=["services"])
app.include_router(models.router, tags=["models"])
app.include_router(ai_inference.router, prefix="/api/v1", tags=["ai"]) # Add AI inference router
app.include_router(mcp_registry.router, prefix="/api/v1", tags=["mcp"])
app.include_router(mcp_executor.router, prefix="/api/v1", tags=["mcp"])
app.include_router(internal.router, tags=["internal"]) # Internal service-to-service APIs
# Mount Prometheus metrics endpoint
if settings.prometheus_enabled:
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
@app.get("/")
async def root():
"""Root endpoint"""
return {
"service": "GT 2.0 Resource Cluster",
"version": "1.0.0",
"status": "operational",
"environment": settings.environment,
"capabilities": {
"ai_ml": ["llm", "embeddings", "image_generation"],
"rag_engine": ["vector_search", "document_processing"],
"agentic_workflows": ["single_agent", "multi_agent"],
"app_integrations": ["oauth2", "webhooks"],
"external_services": ["ctfd", "canvas", "guacamole", "iframe_embed", "sso"],
"ai_literacy": ["games", "puzzles", "education"]
}
}
@app.get("/health")
async def health_check():
"""Docker health check endpoint (without trailing slash)"""
return {
"status": "healthy",
"service": "resource-cluster",
"timestamp": datetime.utcnow()
}
@app.get("/ready")
async def ready_check():
"""Kubernetes readiness probe endpoint"""
return {
"status": "ready",
"service": "resource-cluster",
"timestamp": datetime.utcnow(),
"health": "ok"
}
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""Global exception handler"""
logger.error(f"Unhandled exception: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"error": "Internal server error",
"message": str(exc) if settings.debug else "An error occurred processing your request"
}
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=settings.service_port,
reload=settings.debug,
log_level="info" if not settings.debug else "debug"
)
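For a quick smoke test of the endpoints registered above, something along these lines should work once the service is running; the base URL assumes the default service_port of 8003.

# Sketch: smoke-test the root and health endpoints (assumes the default port 8003).
import httpx

base = "http://localhost:8003"
print(httpx.get(f"{base}/health").json())                     # {"status": "healthy", ...}
print(httpx.get(f"{base}/").json()["capabilities"]["ai_ml"])  # ["llm", "embeddings", "image_generation"]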

View File

@@ -0,0 +1 @@
# GT 2.0 Resource Cluster Models

View File

@@ -0,0 +1,68 @@
"""
Access Group Models for GT 2.0 Resource Cluster
Simplified models for resource access control.
These are lighter versions focused on MCP resource management.
"""
from datetime import datetime
from enum import Enum
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
class AccessGroup(str, Enum):
"""Resource access levels"""
INDIVIDUAL = "individual" # Private to owner
TEAM = "team" # Shared with specific users
ORGANIZATION = "organization" # Read-only for all tenant users
@dataclass
class Resource:
"""Base resource model for MCP services"""
id: str
name: str
resource_type: str
owner_id: str
tenant_domain: str
access_group: AccessGroup
team_members: List[str]
created_at: datetime
updated_at: datetime
metadata: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary representation"""
return {
"id": self.id,
"name": self.name,
"resource_type": self.resource_type,
"owner_id": self.owner_id,
"tenant_domain": self.tenant_domain,
"access_group": self.access_group.value,
"team_members": self.team_members,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"metadata": self.metadata
}
def can_access(self, user_id: str, tenant_domain: str) -> bool:
"""Check if user can access this resource"""
# Check tenant isolation
if self.tenant_domain != tenant_domain:
return False
# Owner always has access
if self.owner_id == user_id:
return True
# Check access group permissions
if self.access_group == AccessGroup.INDIVIDUAL:
return False
elif self.access_group == AccessGroup.TEAM:
return user_id in self.team_members
elif self.access_group == AccessGroup.ORGANIZATION:
return True # All tenant users have read access
return False
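A short sketch of the access rules above, using made-up identifiers:

# Sketch: tenant isolation plus access-group checks (identifiers are made up).
from datetime import datetime

doc_store = Resource(
    id="res-1", name="Shared RAG store", resource_type="vector_db",
    owner_id="alice", tenant_domain="example.com",
    access_group=AccessGroup.TEAM, team_members=["bob"],
    created_at=datetime.utcnow(), updated_at=datetime.utcnow(), metadata={},
)

print(doc_store.can_access("alice", "example.com"))  # True  (owner)
print(doc_store.can_access("bob", "example.com"))    # True  (team member)
print(doc_store.can_access("carol", "example.com"))  # False (not a team member)
print(doc_store.can_access("bob", "other.com"))      # False (tenant isolation)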

View File

@@ -0,0 +1,76 @@
"""
GT 2.0 Resource Cluster Providers
External AI model providers for the resource cluster.
"""
from typing import Dict, Any, Optional
import logging
from .external_provider import ExternalProvider
logger = logging.getLogger(__name__)
class ProviderFactory:
"""Factory for creating provider instances dynamically"""
def __init__(self):
self.providers = {}
self.initialized = False
async def initialize(self):
"""Initialize all providers"""
if self.initialized:
return
try:
# Initialize external provider (BGE-M3)
external_provider = ExternalProvider()
await external_provider.initialize()
self.providers["external"] = external_provider
logger.info("Provider factory initialized successfully")
self.initialized = True
except Exception as e:
logger.error(f"Failed to initialize provider factory: {e}")
raise
def get_provider(self, provider_name: str) -> Optional[Any]:
"""Get provider instance by name"""
return self.providers.get(provider_name)
def list_providers(self) -> Dict[str, Any]:
"""List all available providers"""
return {
name: {
"name": provider.name if hasattr(provider, "name") else name,
"status": "initialized" if provider else "error"
}
for name, provider in self.providers.items()
}
# Global provider factory instance
_provider_factory = None
async def get_provider_factory() -> ProviderFactory:
"""Get initialized provider factory"""
global _provider_factory
if _provider_factory is None:
_provider_factory = ProviderFactory()
await _provider_factory.initialize()
return _provider_factory
def get_external_provider():
"""Get external provider instance (synchronous)"""
global _provider_factory
if _provider_factory and "external" in _provider_factory.providers:
return _provider_factory.providers["external"]
return None
__all__ = ["ExternalProvider", "ProviderFactory", "get_provider_factory", "get_external_provider"]

View File

@@ -0,0 +1,306 @@
"""
GT 2.0 External Provider
Handles external AI services like BGE-M3 embedding model on GT Edge network.
Provides unified interface for external model access with health monitoring.
"""
import asyncio
import httpx
import json
import time
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime, timedelta
from app.core.config import get_settings
from app.core.exceptions import ProviderError
logger = logging.getLogger(__name__)
settings = get_settings()
class ExternalProvider:
"""Provider for external AI models and services"""
def __init__(self):
self.name = "external"
self.models = {}
self.health_status = {}
self.circuit_breaker = {}
self.retry_attempts = 3
self.timeout = 30.0
async def initialize(self):
"""Initialize external provider with default models"""
await self.register_bge_m3_model()
logger.info("External provider initialized")
async def register_bge_m3_model(self):
"""Register BGE-M3 embedding model on GT Edge network"""
model_config = {
"model_id": "bge-m3-embedding",
"name": "BGE-M3 Multilingual Embedding",
"version": "1.0",
"provider": "external",
"model_type": "embedding",
"endpoint": "http://10.0.0.100:8080", # GT Edge network default
"dimensions": 1024,
"max_input_tokens": 8192,
"cost_per_1k_tokens": 0.0, # Internal model, no cost
"description": "BGE-M3 multilingual embedding model on GT Edge network",
"capabilities": {
"languages": ["en", "zh", "fr", "de", "es", "ru", "ja", "ko"],
"max_sequence_length": 8192,
"output_dimensions": 1024,
"supports_retrieval": True,
"supports_clustering": True
}
}
self.models["bge-m3-embedding"] = model_config
await self._initialize_circuit_breaker("bge-m3-embedding")
logger.info("Registered BGE-M3 embedding model")
async def generate_embeddings(
self,
model_id: str,
texts: Union[str, List[str]],
**kwargs
) -> Dict[str, Any]:
"""Generate embeddings using external model"""
if model_id not in self.models:
raise ProviderError(f"Model {model_id} not found in external provider")
model_config = self.models[model_id]
if not await self._check_circuit_breaker(model_id):
raise ProviderError(f"Circuit breaker open for model {model_id}")
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
# Prepare request payload
payload = {
"model": model_id,
"input": texts,
"encoding_format": "float",
**kwargs
}
# Make request to external model
endpoint = model_config["endpoint"]
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{endpoint}/v1/embeddings",
json=payload,
headers={
"Content-Type": "application/json",
"User-Agent": "GT-2.0-Resource-Cluster/1.0"
}
)
response.raise_for_status()
result = response.json()
# Calculate metrics
latency_ms = (time.time() - start_time) * 1000
total_tokens = sum(len(text.split()) for text in texts)
# Update circuit breaker with success
await self._record_success(model_id, latency_ms)
# Format response
embeddings = []
for i, embedding_data in enumerate(result.get("data", [])):
embeddings.append({
"object": "embedding",
"index": i,
"embedding": embedding_data.get("embedding", [])
})
return {
"object": "list",
"data": embeddings,
"model": model_id,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens
},
"provider": "external",
"latency_ms": latency_ms,
"timestamp": datetime.utcnow().isoformat()
}
except httpx.RequestError as e:
await self._record_failure(model_id, str(e))
raise ProviderError(f"External model request failed: {e}")
except httpx.HTTPStatusError as e:
await self._record_failure(model_id, f"HTTP {e.response.status_code}")
raise ProviderError(f"External model returned error: {e.response.status_code}")
except Exception as e:
await self._record_failure(model_id, str(e))
raise ProviderError(f"External model error: {e}")
async def health_check(self, model_id: str = None) -> Dict[str, Any]:
"""Check health of external models"""
if model_id:
return await self._check_model_health(model_id)
# Check all models
health_results = {}
for mid in self.models.keys():
health_results[mid] = await self._check_model_health(mid)
# Calculate overall health
total_models = len(health_results)
healthy_models = sum(1 for h in health_results.values() if h.get("healthy", False))
return {
"provider": "external",
"overall_healthy": healthy_models == total_models,
"total_models": total_models,
"healthy_models": healthy_models,
"health_percentage": (healthy_models / total_models * 100) if total_models > 0 else 0,
"models": health_results,
"timestamp": datetime.utcnow().isoformat()
}
async def _check_model_health(self, model_id: str) -> Dict[str, Any]:
"""Check health of specific external model"""
if model_id not in self.models:
return {
"healthy": False,
"error": "Model not found",
"timestamp": datetime.utcnow().isoformat()
}
model_config = self.models[model_id]
try:
start_time = time.time()
# Health check endpoint
endpoint = model_config["endpoint"]
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(f"{endpoint}/health")
latency_ms = (time.time() - start_time) * 1000
if response.status_code == 200:
return {
"healthy": True,
"latency_ms": latency_ms,
"endpoint": endpoint,
"timestamp": datetime.utcnow().isoformat()
}
else:
return {
"healthy": False,
"error": f"HTTP {response.status_code}",
"latency_ms": latency_ms,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
return {
"healthy": False,
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
}
async def _initialize_circuit_breaker(self, model_id: str):
"""Initialize circuit breaker for model"""
self.circuit_breaker[model_id] = {
"state": "closed", # closed, open, half_open
"failure_count": 0,
"success_count": 0,
"last_failure_time": 0,
"failure_threshold": 5,
"success_threshold": 3,
"timeout": 60 # seconds to wait before trying half_open
}
async def _check_circuit_breaker(self, model_id: str) -> bool:
"""Check if circuit breaker allows requests"""
cb = self.circuit_breaker.get(model_id, {})
if cb.get("state") == "closed":
return True
elif cb.get("state") == "open":
# Check if timeout has passed
if time.time() - cb.get("last_failure_time", 0) > cb.get("timeout", 60):
cb["state"] = "half_open"
cb["success_count"] = 0
return True
return False
elif cb.get("state") == "half_open":
return True
return False
async def _record_success(self, model_id: str, latency_ms: float):
"""Record successful request for circuit breaker"""
cb = self.circuit_breaker.get(model_id, {})
if cb.get("state") == "half_open":
cb["success_count"] += 1
if cb["success_count"] >= cb.get("success_threshold", 3):
cb["state"] = "closed"
cb["failure_count"] = 0
# Update health status
self.health_status[model_id] = {
"healthy": True,
"last_success": time.time(),
"latency_ms": latency_ms
}
async def _record_failure(self, model_id: str, error: str):
"""Record failed request for circuit breaker"""
cb = self.circuit_breaker.get(model_id, {})
cb["failure_count"] += 1
cb["last_failure_time"] = time.time()
if cb["failure_count"] >= cb.get("failure_threshold", 5):
cb["state"] = "open"
# Update health status
self.health_status[model_id] = {
"healthy": False,
"last_failure": time.time(),
"error": error
}
logger.warning(f"External model {model_id} failure: {error}")
def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available external models"""
return list(self.models.values())
def update_model_endpoint(self, model_id: str, endpoint: str):
"""Update model endpoint (called from config sync)"""
if model_id in self.models:
old_endpoint = self.models[model_id]["endpoint"]
self.models[model_id]["endpoint"] = endpoint
logger.info(f"Updated {model_id} endpoint: {old_endpoint} -> {endpoint}")
else:
logger.warning(f"Attempted to update unknown model: {model_id}")
# Global external provider instance
_external_provider = None
async def get_external_provider() -> ExternalProvider:
"""Get external provider instance"""
global _external_provider
if _external_provider is None:
_external_provider = ExternalProvider()
await _external_provider.initialize()
return _external_provider
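A hedged usage sketch for this provider; it assumes the BGE-M3 endpoint registered above is actually reachable on the GT Edge network.

# Sketch: fetch the provider and request embeddings (requires a reachable BGE-M3 endpoint).
import asyncio

async def demo():
    provider = await get_external_provider()
    result = await provider.generate_embeddings(
        "bge-m3-embedding",
        ["air-gapped resource management", "capability-based security"],
    )
    print(result["model"], len(result["data"]), "embeddings")
    print(await provider.health_check("bge-m3-embedding"))

asyncio.run(demo())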

View File

@@ -0,0 +1,3 @@
"""
Service layer for Resource Cluster
"""

View File

@@ -0,0 +1,342 @@
"""
Admin Model Configuration Service for GT 2.0 Resource Cluster
This service fetches model configurations from the Admin Control Panel
and provides them to the Resource Cluster for LLM routing and capabilities.
"""
import asyncio
import logging
import httpx
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass
import json
from app.core.config import get_settings
logger = logging.getLogger(__name__)
@dataclass
class AdminModelConfig:
"""Model configuration from admin cluster"""
uuid: str # Database UUID - unique identifier for this model config
model_id: str # Business identifier - the model name used in API calls
name: str
provider: str
model_type: str
endpoint: str
api_key_name: Optional[str]
context_window: Optional[int]
max_tokens: Optional[int]
capabilities: Dict[str, Any]
cost_per_1k_input: float
cost_per_1k_output: float
is_active: bool
tenant_restrictions: Dict[str, Any]
required_capabilities: List[str]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for LLM Gateway"""
return {
"uuid": self.uuid,
"model_id": self.model_id,
"name": self.name,
"provider": self.provider,
"model_type": self.model_type,
"endpoint": self.endpoint,
"api_key_name": self.api_key_name,
"context_window": self.context_window,
"max_tokens": self.max_tokens,
"capabilities": self.capabilities,
"cost_per_1k_input": self.cost_per_1k_input,
"cost_per_1k_output": self.cost_per_1k_output,
"is_active": self.is_active,
"tenant_restrictions": self.tenant_restrictions,
"required_capabilities": self.required_capabilities
}
class AdminModelConfigService:
"""Service for fetching model configurations from Admin Control Panel"""
def __init__(self):
self.settings = get_settings()
self._model_cache: Dict[str, AdminModelConfig] = {} # model_id -> config
self._uuid_cache: Dict[str, AdminModelConfig] = {} # uuid -> config (for UUID-based lookups)
self._tenant_model_cache: Dict[str, List[str]] = {} # tenant_id -> list of allowed model_ids
self._last_sync: datetime = datetime.min
self._sync_interval = timedelta(seconds=self.settings.config_sync_interval)
self._sync_lock = asyncio.Lock()
async def get_model_config(self, model_id: str) -> Optional[AdminModelConfig]:
"""Get configuration for a specific model by model_id string"""
await self._ensure_fresh_cache()
return self._model_cache.get(model_id)
async def get_model_by_uuid(self, uuid: str) -> Optional[AdminModelConfig]:
"""Get configuration for a specific model by database UUID"""
await self._ensure_fresh_cache()
return self._uuid_cache.get(uuid)
async def get_all_models(self, active_only: bool = True) -> List[AdminModelConfig]:
"""Get all model configurations"""
await self._ensure_fresh_cache()
models = list(self._model_cache.values())
if active_only:
models = [m for m in models if m.is_active]
return models
async def get_tenant_models(self, tenant_id: str) -> List[AdminModelConfig]:
"""Get models available to a specific tenant"""
await self._ensure_fresh_cache()
# Get tenant's allowed model IDs - try multiple formats
allowed_model_ids = self._get_tenant_model_ids(tenant_id)
# Return model configs for allowed models
models = []
for model_id in allowed_model_ids:
if model_id in self._model_cache and self._model_cache[model_id].is_active:
models.append(self._model_cache[model_id])
return models
async def check_tenant_access(self, tenant_id: str, model_id: str) -> bool:
"""Check if a tenant has access to a specific model"""
await self._ensure_fresh_cache()
# Check if model exists and is active
model_config = self._model_cache.get(model_id)
if not model_config or not model_config.is_active:
return False
# Only use tenant-specific access (no global access)
# This enforces proper tenant model assignments
allowed_models = self._get_tenant_model_ids(tenant_id)
return model_id in allowed_models
def _get_tenant_model_ids(self, tenant_id: str) -> List[str]:
"""Get model IDs for tenant, handling multiple tenant ID formats"""
# Try exact match first (e.g., "test-company")
allowed_models = self._tenant_model_cache.get(tenant_id, [])
if not allowed_models:
# Try converting "test-company" to "test" format
if "-" in tenant_id:
domain_format = tenant_id.split("-")[0]
allowed_models = self._tenant_model_cache.get(domain_format, [])
# Try converting "test" to "test-company" format
elif tenant_id + "-company" in self._tenant_model_cache:
allowed_models = self._tenant_model_cache.get(tenant_id + "-company", [])
# Also try tenant_id as numeric string
for key, models in self._tenant_model_cache.items():
if key.isdigit() and tenant_id in key:
allowed_models.extend(models)
break
logger.debug(f"Tenant {tenant_id} has access to models: {allowed_models}")
return allowed_models
async def get_groq_api_key(self, tenant_id: str = None) -> Optional[str]:
"""
Get Groq API key for a tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string (required for tenant requests)
Returns:
Decrypted Groq API key
Raises:
ValueError: If no API key configured for tenant
"""
if not tenant_id:
raise ValueError("tenant_id is required to fetch Groq API key - no fallback to environment variables")
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No Groq API key configured for tenant '{tenant_id}': {e}")
raise ValueError(f"No Groq API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
except RuntimeError as e:
logger.error(f"Control Panel API error when fetching API key: {e}")
raise ValueError(f"Unable to retrieve API key - Control Panel service unavailable: {e}")
async def _ensure_fresh_cache(self):
"""Ensure model cache is fresh, sync if needed"""
now = datetime.utcnow()
if now - self._last_sync > self._sync_interval:
async with self._sync_lock:
# Double-check after acquiring lock
now = datetime.utcnow()
if now - self._last_sync <= self._sync_interval:
return
await self._sync_from_admin()
async def _sync_from_admin(self):
"""Sync model configurations from admin cluster"""
try:
# Use correct URL for containerized environment
import os
if os.path.exists('/.dockerenv'):
admin_url = "http://control-panel-backend:8000"
else:
admin_url = self.settings.admin_cluster_url.rstrip('/')
async with httpx.AsyncClient(timeout=30.0) as client:
# Fetch all model configurations
models_response = await client.get(
f"{admin_url}/api/v1/models/?active_only=true&include_stats=true"
)
# Fetch tenant model assignments with proper authentication
tenant_models_response = await client.get(
f"{admin_url}/api/v1/tenant-models/tenants/all",
headers={
"Authorization": "Bearer admin-dev-token",
"Content-Type": "application/json"
}
)
if models_response.status_code == 200:
models_data = models_response.json()
if models_data and len(models_data) > 0:
await self._update_model_cache(models_data)
logger.info(f"Successfully synced {len(models_data)} models from admin cluster")
# Update tenant model assignments if available
if tenant_models_response.status_code == 200:
tenant_data = tenant_models_response.json()
if tenant_data and len(tenant_data) > 0:
await self._update_tenant_cache(tenant_data)
logger.info(f"Successfully synced {len(tenant_data)} tenant model assignments")
else:
logger.warning("No tenant model assignments found")
else:
logger.error(f"Failed to fetch tenant assignments: {tenant_models_response.status_code}")
# Log the actual error for debugging
try:
error_response = tenant_models_response.json()
logger.error(f"Tenant assignments error: {error_response}")
except Exception:
logger.error(f"Tenant assignments error text: {tenant_models_response.text}")
self._last_sync = datetime.utcnow()
return
else:
logger.warning("Admin cluster returned empty model list")
else:
logger.warning(f"Failed to fetch models from admin cluster: {models_response.status_code}")
logger.info("No models configured in admin backend")
self._last_sync = datetime.utcnow()
logger.info(f"Loaded {len(self._model_cache)} models successfully")
except Exception as e:
logger.error(f"Failed to sync from admin cluster: {e}")
# Log final state - no fallback models
if not self._model_cache:
logger.warning("No models available - admin backend has no models configured")
async def _update_model_cache(self, models_data: List[Dict[str, Any]]):
"""Update model configuration cache"""
new_cache = {}
new_uuid_cache = {}
for model_data in models_data:
try:
specs = model_data.get("specifications", {})
cost = model_data.get("cost", {})
status = model_data.get("status", {})
# Get UUID from 'id' field in API response (Control Panel returns UUID as 'id')
model_uuid = model_data.get("id", "")
model_config = AdminModelConfig(
uuid=model_uuid,
model_id=model_data["model_id"],
name=model_data.get("name", model_data["model_id"]),
provider=model_data["provider"],
model_type=model_data["model_type"],
endpoint=model_data.get("endpoint", ""),
api_key_name=model_data.get("api_key_name"),
context_window=specs.get("context_window"),
max_tokens=specs.get("max_tokens"),
capabilities=model_data.get("capabilities", {}),
cost_per_1k_input=cost.get("per_1k_input", 0.0),
cost_per_1k_output=cost.get("per_1k_output", 0.0),
is_active=status.get("is_active", False),
tenant_restrictions=model_data.get("tenant_restrictions", {"global_access": True}),
required_capabilities=model_data.get("required_capabilities", [])
)
new_cache[model_config.model_id] = model_config
# Also index by UUID for UUID-based lookups
if model_uuid:
new_uuid_cache[model_uuid] = model_config
except Exception as e:
logger.error(f"Failed to parse model config {model_data.get('model_id', 'unknown')}: {e}")
self._model_cache = new_cache
self._uuid_cache = new_uuid_cache
async def _update_tenant_cache(self, tenant_data: List[Dict[str, Any]]):
"""Update tenant model access cache from tenant-models endpoint"""
new_tenant_cache = {}
for assignment in tenant_data:
try:
# The tenant-models endpoint returns different format than the old endpoint
tenant_domain = assignment.get("tenant_domain", "")
model_id = assignment["model_id"]
is_enabled = assignment.get("is_enabled", True)
if is_enabled and tenant_domain:
if tenant_domain not in new_tenant_cache:
new_tenant_cache[tenant_domain] = []
new_tenant_cache[tenant_domain].append(model_id)
# Also add by tenant_id for backward compatibility
tenant_id = str(assignment.get("tenant_id", ""))
if tenant_id and tenant_id not in new_tenant_cache:
new_tenant_cache[tenant_id] = []
if tenant_id:
new_tenant_cache[tenant_id].append(model_id)
except Exception as e:
logger.error(f"Failed to parse tenant assignment: {e}")
self._tenant_model_cache = new_tenant_cache
logger.debug(f"Updated tenant cache: {self._tenant_model_cache}")
async def force_sync(self):
"""Force immediate sync from admin cluster"""
self._last_sync = datetime.min
await self._ensure_fresh_cache()
# Global instance
_admin_model_service = None
def get_admin_model_service() -> AdminModelConfigService:
"""Get singleton admin model service"""
global _admin_model_service
if _admin_model_service is None:
_admin_model_service = AdminModelConfigService()
return _admin_model_service
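
A minimal usage sketch for the service above, assuming an async caller and a reachable Control Panel; the tenant domain is an illustrative placeholder:

async def resolve_groq_key(tenant_domain: str = "acme.example.com") -> str:
    # Refresh the model/tenant caches from the admin cluster, then fetch the key.
    service = get_admin_model_service()
    await service.force_sync()
    # Raises ValueError if no Groq key is configured for this tenant in Control Panel.
    return await service.get_groq_api_key(tenant_id=tenant_domain)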

View File

@@ -0,0 +1,931 @@
"""
Agent Orchestration System for GT 2.0 Resource Cluster
Provides multi-agent workflow execution with:
- Sequential, parallel, and conditional agent workflows
- Inter-agent communication and memory management
- Capability-based access control
- Agent lifecycle management
- Performance monitoring and metrics
GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Agent sessions isolated per tenant
- Zero Downtime: Stateless design, resumable workflows
- Self-Contained Security: Capability-based agent permissions
- No Complexity Addition: Simple orchestration patterns
"""
import asyncio
import logging
import json
import time
import uuid
from typing import Dict, Any, List, Optional, Union, Callable, Coroutine
from datetime import datetime, timedelta
from enum import Enum
from dataclasses import dataclass, asdict
import traceback
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class AgentStatus(str, Enum):
"""Agent execution status"""
IDLE = "idle"
RUNNING = "running"
WAITING = "waiting"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class WorkflowType(str, Enum):
"""Types of agent workflows"""
SEQUENTIAL = "sequential"
PARALLEL = "parallel"
CONDITIONAL = "conditional"
PIPELINE = "pipeline"
MAP_REDUCE = "map_reduce"
class MessageType(str, Enum):
"""Inter-agent message types"""
DATA = "data"
CONTROL = "control"
ERROR = "error"
HEARTBEAT = "heartbeat"
@dataclass
class AgentDefinition:
"""Definition of an agent"""
agent_id: str
agent_type: str
name: str
description: str
capabilities_required: List[str]
memory_limit_mb: int = 256
timeout_seconds: int = 300
retry_count: int = 3
environment: Optional[Dict[str, Any]] = None
@dataclass
class AgentMessage:
"""Message between agents"""
message_id: str
from_agent: str
to_agent: str
message_type: MessageType
content: Dict[str, Any]
timestamp: str
expires_at: Optional[str] = None
@dataclass
class AgentState:
"""Current state of an agent"""
agent_id: str
status: AgentStatus
current_task: Optional[str]
memory_usage_mb: int
cpu_usage_percent: float
started_at: str
last_activity: str
error_message: Optional[str] = None
output_data: Optional[Dict[str, Any]] = None
@dataclass
class WorkflowExecution:
"""Workflow execution instance"""
workflow_id: str
workflow_type: WorkflowType
tenant_id: str
created_by: str
agents: List[AgentDefinition]
workflow_config: Dict[str, Any]
status: AgentStatus
started_at: str
completed_at: Optional[str] = None
results: Optional[Dict[str, Any]] = None
error_message: Optional[str] = None
class AgentMemoryManager:
"""Manages agent memory and state"""
def __init__(self):
# In-memory storage (PostgreSQL used for persistent storage)
self._agent_memory: Dict[str, Dict[str, Any]] = {}
self._shared_memory: Dict[str, Dict[str, Any]] = {}
self._message_queues: Dict[str, List[AgentMessage]] = {}
async def store_agent_memory(
self,
agent_id: str,
key: str,
value: Any,
ttl_seconds: Optional[int] = None
) -> None:
"""Store data in agent-specific memory"""
if agent_id not in self._agent_memory:
self._agent_memory[agent_id] = {}
self._agent_memory[agent_id][key] = {
"value": value,
"created_at": datetime.utcnow().isoformat(),
"expires_at": (
datetime.utcnow() + timedelta(seconds=ttl_seconds)
).isoformat() if ttl_seconds else None
}
logger.debug(f"Stored memory for agent {agent_id}: {key}")
async def get_agent_memory(
self,
agent_id: str,
key: str
) -> Optional[Any]:
"""Retrieve data from agent-specific memory"""
if agent_id not in self._agent_memory:
return None
memory_item = self._agent_memory[agent_id].get(key)
if not memory_item:
return None
# Check expiration
if memory_item.get("expires_at"):
expires_at = datetime.fromisoformat(memory_item["expires_at"])
if datetime.utcnow() > expires_at:
del self._agent_memory[agent_id][key]
return None
return memory_item["value"]
async def store_shared_memory(
self,
tenant_id: str,
key: str,
value: Any,
ttl_seconds: Optional[int] = None
) -> None:
"""Store data in tenant-shared memory"""
if tenant_id not in self._shared_memory:
self._shared_memory[tenant_id] = {}
self._shared_memory[tenant_id][key] = {
"value": value,
"created_at": datetime.utcnow().isoformat(),
"expires_at": (
datetime.utcnow() + timedelta(seconds=ttl_seconds)
).isoformat() if ttl_seconds else None
}
logger.debug(f"Stored shared memory for tenant {tenant_id}: {key}")
async def get_shared_memory(
self,
tenant_id: str,
key: str
) -> Optional[Any]:
"""Retrieve data from tenant-shared memory"""
if tenant_id not in self._shared_memory:
return None
memory_item = self._shared_memory[tenant_id].get(key)
if not memory_item:
return None
# Check expiration
if memory_item.get("expires_at"):
expires_at = datetime.fromisoformat(memory_item["expires_at"])
if datetime.utcnow() > expires_at:
del self._shared_memory[tenant_id][key]
return None
return memory_item["value"]
async def send_message(
self,
message: AgentMessage
) -> None:
"""Send message to agent queue"""
if message.to_agent not in self._message_queues:
self._message_queues[message.to_agent] = []
self._message_queues[message.to_agent].append(message)
logger.debug(f"Message sent from {message.from_agent} to {message.to_agent}")
async def receive_messages(
self,
agent_id: str,
message_type: Optional[MessageType] = None
) -> List[AgentMessage]:
"""Receive messages for agent"""
if agent_id not in self._message_queues:
return []
messages = self._message_queues[agent_id]
# Filter expired messages
now = datetime.utcnow()
messages = [
msg for msg in messages
if not msg.expires_at or datetime.fromisoformat(msg.expires_at) > now
]
# Filter by message type if specified
if message_type:
messages = [msg for msg in messages if msg.message_type == message_type]
# Clear processed messages
if message_type:
self._message_queues[agent_id] = [
msg for msg in self._message_queues[agent_id]
if msg.message_type != message_type or
(msg.expires_at and datetime.fromisoformat(msg.expires_at) <= now)
]
else:
self._message_queues[agent_id] = []
return messages
async def cleanup_agent_memory(self, agent_id: str) -> None:
"""Clean up memory for completed agent"""
if agent_id in self._agent_memory:
del self._agent_memory[agent_id]
if agent_id in self._message_queues:
del self._message_queues[agent_id]
logger.debug(f"Cleaned up memory for agent {agent_id}")
class AgentOrchestrator:
"""
Main agent orchestration system for GT 2.0.
Manages agent lifecycle, workflows, communication, and resource allocation.
All operations are tenant-isolated and capability-protected.
"""
def __init__(self):
self.memory_manager = AgentMemoryManager()
self.active_workflows: Dict[str, WorkflowExecution] = {}
self.agent_states: Dict[str, AgentState] = {}
# Built-in agent types
self.agent_registry: Dict[str, Dict[str, Any]] = {
"data_processor": {
"description": "Processes and transforms data",
"capabilities": ["data.read", "data.transform"],
"memory_limit_mb": 512,
"timeout_seconds": 300
},
"llm_agent": {
"description": "Interacts with LLM services",
"capabilities": ["llm.inference", "llm.chat"],
"memory_limit_mb": 256,
"timeout_seconds": 600
},
"embedding_agent": {
"description": "Generates text embeddings",
"capabilities": ["embeddings.generate"],
"memory_limit_mb": 256,
"timeout_seconds": 180
},
"rag_agent": {
"description": "Performs retrieval-augmented generation",
"capabilities": ["rag.search", "rag.generate"],
"memory_limit_mb": 512,
"timeout_seconds": 450
},
"integration_agent": {
"description": "Connects to external services",
"capabilities": ["integration.call", "integration.webhook"],
"memory_limit_mb": 256,
"timeout_seconds": 300
}
}
logger.info("Agent orchestrator initialized")
async def create_workflow(
self,
workflow_type: WorkflowType,
agents: List[AgentDefinition],
workflow_config: Dict[str, Any],
capability_token: str,
workflow_name: Optional[str] = None
) -> str:
"""
Create a new agent workflow.
Args:
workflow_type: Type of workflow to create
agents: List of agents to include in workflow
workflow_config: Configuration for the workflow
capability_token: JWT token with workflow permissions
workflow_name: Optional name for the workflow
Returns:
Workflow ID
"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
user_id = capability.get("sub")
# Check workflow permissions
await self._verify_workflow_permissions(capability, workflow_type, agents)
# Generate workflow ID
workflow_id = str(uuid.uuid4())
# Create workflow execution
workflow = WorkflowExecution(
workflow_id=workflow_id,
workflow_type=workflow_type,
tenant_id=tenant_id,
created_by=user_id,
agents=agents,
workflow_config=workflow_config,
status=AgentStatus.IDLE,
started_at=datetime.utcnow().isoformat()
)
# Store workflow
self.active_workflows[workflow_id] = workflow
logger.info(
f"Created {workflow_type} workflow {workflow_id} "
f"with {len(agents)} agents for tenant {tenant_id}"
)
return workflow_id
async def execute_workflow(
self,
workflow_id: str,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""
Execute an agent workflow.
Args:
workflow_id: ID of workflow to execute
input_data: Input data for the workflow
capability_token: JWT token with execution permissions
Returns:
Workflow execution results
"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
# Get workflow
workflow = self.active_workflows.get(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
# Check tenant isolation
if workflow.tenant_id != tenant_id:
raise CapabilityError("Insufficient permissions for workflow")
# Check workflow permissions
await self._verify_execution_permissions(capability, workflow)
try:
# Update workflow status
workflow.status = AgentStatus.RUNNING
# Execute based on workflow type
if workflow.workflow_type == WorkflowType.SEQUENTIAL:
results = await self._execute_sequential_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.PARALLEL:
results = await self._execute_parallel_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.CONDITIONAL:
results = await self._execute_conditional_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.PIPELINE:
results = await self._execute_pipeline_workflow(
workflow, input_data, capability_token
)
elif workflow.workflow_type == WorkflowType.MAP_REDUCE:
results = await self._execute_map_reduce_workflow(
workflow, input_data, capability_token
)
else:
raise ValueError(f"Unsupported workflow type: {workflow.workflow_type}")
# Update workflow completion
workflow.status = AgentStatus.COMPLETED
workflow.completed_at = datetime.utcnow().isoformat()
workflow.results = results
logger.info(f"Completed workflow {workflow_id} successfully")
return results
except Exception as e:
# Update workflow error status
workflow.status = AgentStatus.FAILED
workflow.completed_at = datetime.utcnow().isoformat()
workflow.error_message = str(e)
logger.error(f"Workflow {workflow_id} failed: {e}")
raise
async def get_workflow_status(
self,
workflow_id: str,
capability_token: str
) -> Dict[str, Any]:
"""Get status of a workflow"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
# Get workflow
workflow = self.active_workflows.get(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
# Check tenant isolation
if workflow.tenant_id != tenant_id:
raise CapabilityError("Insufficient permissions for workflow")
# Get agent states for this workflow
agent_states = {
agent.agent_id: asdict(self.agent_states.get(agent.agent_id))
for agent in workflow.agents
if agent.agent_id in self.agent_states
}
return {
"workflow": asdict(workflow),
"agent_states": agent_states
}
async def cancel_workflow(
self,
workflow_id: str,
capability_token: str
) -> None:
"""Cancel a running workflow"""
# Verify capability token
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
# Get workflow
workflow = self.active_workflows.get(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
# Check tenant isolation
if workflow.tenant_id != tenant_id:
raise CapabilityError("Insufficient permissions for workflow")
# Cancel workflow
workflow.status = AgentStatus.CANCELLED
workflow.completed_at = datetime.utcnow().isoformat()
# Cancel all agents in workflow
for agent in workflow.agents:
if agent.agent_id in self.agent_states:
self.agent_states[agent.agent_id].status = AgentStatus.CANCELLED
logger.info(f"Cancelled workflow {workflow_id}")
async def _execute_sequential_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents sequentially"""
results = {}
current_data = input_data
for agent in workflow.agents:
agent_result = await self._execute_agent(
agent, current_data, capability_token
)
results[agent.agent_id] = agent_result
# Pass output to next agent
if "output" in agent_result:
current_data = agent_result["output"]
return {
"workflow_type": "sequential",
"final_output": current_data,
"agent_results": results
}
async def _execute_parallel_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents in parallel"""
# Create tasks for all agents
tasks = []
for agent in workflow.agents:
task = asyncio.create_task(
self._execute_agent(agent, input_data, capability_token)
)
tasks.append((agent.agent_id, task))
# Wait for all tasks to complete
results = {}
for agent_id, task in tasks:
try:
results[agent_id] = await task
except Exception as e:
results[agent_id] = {"error": str(e)}
return {
"workflow_type": "parallel",
"agent_results": results
}
async def _execute_conditional_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents based on conditions"""
results = {}
condition_config = workflow.workflow_config.get("conditions", {})
for agent in workflow.agents:
# Check if agent should execute based on conditions
should_execute = await self._evaluate_condition(
agent.agent_id, condition_config, input_data, results
)
if should_execute:
agent_result = await self._execute_agent(
agent, input_data, capability_token
)
results[agent.agent_id] = agent_result
else:
results[agent.agent_id] = {"status": "skipped"}
return {
"workflow_type": "conditional",
"agent_results": results
}
async def _execute_pipeline_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute agents in pipeline with data transformation"""
results = {}
current_data = input_data
for i, agent in enumerate(workflow.agents):
# Add pipeline metadata
pipeline_data = {
**current_data,
"_pipeline_stage": i,
"_pipeline_total": len(workflow.agents)
}
agent_result = await self._execute_agent(
agent, pipeline_data, capability_token
)
results[agent.agent_id] = agent_result
# Transform data for next stage
if "transformed_output" in agent_result:
current_data = agent_result["transformed_output"]
elif "output" in agent_result:
current_data = agent_result["output"]
return {
"workflow_type": "pipeline",
"final_output": current_data,
"agent_results": results
}
async def _execute_map_reduce_workflow(
self,
workflow: WorkflowExecution,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute map-reduce workflow"""
# Separate map and reduce agents
map_agents = [a for a in workflow.agents if a.agent_type.endswith("_mapper")]
reduce_agents = [a for a in workflow.agents if a.agent_type.endswith("_reducer")]
# Execute map phase
map_tasks = []
input_chunks = input_data.get("chunks", [input_data])
for i, chunk in enumerate(input_chunks):
for agent in map_agents:
task = asyncio.create_task(
self._execute_agent(agent, chunk, capability_token)
)
map_tasks.append((f"{agent.agent_id}_chunk_{i}", task))
# Collect map results
map_results = {}
for task_id, task in map_tasks:
try:
map_results[task_id] = await task
except Exception as e:
map_results[task_id] = {"error": str(e)}
# Execute reduce phase
reduce_results = {}
reduce_input = {"map_results": map_results}
for agent in reduce_agents:
agent_result = await self._execute_agent(
agent, reduce_input, capability_token
)
reduce_results[agent.agent_id] = agent_result
return {
"workflow_type": "map_reduce",
"map_results": map_results,
"reduce_results": reduce_results
}
async def _execute_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute a single agent"""
start_time = time.time()
# Create agent state
agent_state = AgentState(
agent_id=agent.agent_id,
status=AgentStatus.RUNNING,
current_task=f"Executing {agent.agent_type}",
memory_usage_mb=0,
cpu_usage_percent=0.0,
started_at=datetime.utcnow().isoformat(),
last_activity=datetime.utcnow().isoformat()
)
self.agent_states[agent.agent_id] = agent_state
try:
# Simulate agent execution based on type
if agent.agent_type == "data_processor":
result = await self._execute_data_processor(agent, input_data)
elif agent.agent_type == "llm_agent":
result = await self._execute_llm_agent(agent, input_data, capability_token)
elif agent.agent_type == "embedding_agent":
result = await self._execute_embedding_agent(agent, input_data, capability_token)
elif agent.agent_type == "rag_agent":
result = await self._execute_rag_agent(agent, input_data, capability_token)
elif agent.agent_type == "integration_agent":
result = await self._execute_integration_agent(agent, input_data, capability_token)
else:
result = await self._execute_custom_agent(agent, input_data)
# Update agent state
agent_state.status = AgentStatus.COMPLETED
agent_state.output_data = result
agent_state.last_activity = datetime.utcnow().isoformat()
processing_time = time.time() - start_time
logger.info(
f"Agent {agent.agent_id} completed in {processing_time:.2f}s"
)
return {
"status": "completed",
"processing_time": processing_time,
"output": result
}
except Exception as e:
# Update agent error state
agent_state.status = AgentStatus.FAILED
agent_state.error_message = str(e)
agent_state.last_activity = datetime.utcnow().isoformat()
logger.error(f"Agent {agent.agent_id} failed: {e}")
return {
"status": "failed",
"error": str(e),
"processing_time": time.time() - start_time
}
# Agent execution implementations would go here...
# For now, these are placeholder implementations
async def _execute_data_processor(
self,
agent: AgentDefinition,
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute data processing agent"""
await asyncio.sleep(0.1) # Simulate processing
return {
"processed_data": input_data,
"processing_info": "Data processed successfully"
}
async def _execute_llm_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute LLM agent"""
await asyncio.sleep(0.2) # Simulate LLM call
return {
"llm_response": f"LLM processed: {input_data.get('prompt', 'No prompt provided')}",
"model_used": "groq/llama-3-8b"
}
async def _execute_embedding_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute embedding agent"""
await asyncio.sleep(0.1) # Simulate embedding generation
texts = input_data.get("texts", [""])
return {
"embeddings": [[0.1] * 1024 for _ in texts], # Mock embeddings
"model_used": "BAAI/bge-m3"
}
async def _execute_rag_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute RAG agent"""
await asyncio.sleep(0.3) # Simulate RAG processing
return {
"rag_response": "RAG generated response",
"retrieved_docs": ["doc1", "doc2"],
"confidence_score": 0.85
}
async def _execute_integration_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any],
capability_token: str
) -> Dict[str, Any]:
"""Execute integration agent"""
await asyncio.sleep(0.1) # Simulate external API call
return {
"integration_result": "External API called successfully",
"response_data": input_data
}
async def _execute_custom_agent(
self,
agent: AgentDefinition,
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute custom agent type"""
await asyncio.sleep(0.1) # Simulate custom processing
return {
"custom_result": f"Custom agent {agent.agent_type} executed",
"input_data": input_data
}
async def _verify_workflow_permissions(
self,
capability: Dict[str, Any],
workflow_type: WorkflowType,
agents: List[AgentDefinition]
) -> None:
"""Verify workflow creation permissions"""
capabilities = capability.get("capabilities", [])
# Check for workflow creation permission
workflow_caps = [
cap for cap in capabilities
if cap.get("resource") == "workflows"
]
if not workflow_caps:
raise CapabilityError("No workflow permissions in capability token")
# Check specific workflow type permission
workflow_cap = workflow_caps[0]
actions = workflow_cap.get("actions", [])
if "create" not in actions:
raise CapabilityError("No workflow creation permission")
# Check agent-specific permissions
for agent in agents:
for required_cap in agent.capabilities_required:
if not any(
cap.get("resource") == required_cap.split(".")[0]
for cap in capabilities
):
raise CapabilityError(
f"Missing capability for agent {agent.agent_id}: {required_cap}"
)
async def _verify_execution_permissions(
self,
capability: Dict[str, Any],
workflow: WorkflowExecution
) -> None:
"""Verify workflow execution permissions"""
capabilities = capability.get("capabilities", [])
# Check for workflow execution permission
workflow_caps = [
cap for cap in capabilities
if cap.get("resource") == "workflows"
]
if not workflow_caps:
raise CapabilityError("No workflow permissions in capability token")
workflow_cap = workflow_caps[0]
actions = workflow_cap.get("actions", [])
if "execute" not in actions:
raise CapabilityError("No workflow execution permission")
async def _evaluate_condition(
self,
agent_id: str,
condition_config: Dict[str, Any],
input_data: Dict[str, Any],
results: Dict[str, Any]
) -> bool:
"""Evaluate condition for conditional workflow"""
agent_condition = condition_config.get(agent_id, {})
if not agent_condition:
return True # No condition means always execute
condition_type = agent_condition.get("type", "always")
if condition_type == "always":
return True
elif condition_type == "never":
return False
elif condition_type == "input_contains":
key = agent_condition.get("key")
value = agent_condition.get("value")
return input_data.get(key) == value
elif condition_type == "previous_success":
previous_agent = agent_condition.get("previous_agent")
return (
previous_agent in results and
results[previous_agent].get("status") == "completed"
)
elif condition_type == "previous_failure":
previous_agent = agent_condition.get("previous_agent")
return (
previous_agent in results and
results[previous_agent].get("status") == "failed"
)
return True # Default to execute if condition not recognized
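# Illustrative shape of workflow_config["conditions"] for a CONDITIONAL workflow
# (agent IDs are made up; only the keys read above are shown):
#   {
#       "summarize-agent": {"type": "previous_success", "previous_agent": "fetch-agent"},
#       "fallback-agent": {"type": "previous_failure", "previous_agent": "fetch-agent"},
#       "tagging-agent": {"type": "input_contains", "key": "mode", "value": "tagging"},
#   }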
# Global orchestrator instance
_agent_orchestrator = None
def get_agent_orchestrator() -> AgentOrchestrator:
"""Get the global agent orchestrator instance"""
global _agent_orchestrator
if _agent_orchestrator is None:
_agent_orchestrator = AgentOrchestrator()
return _agent_orchestrator
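
As a hedged illustration of the orchestrator API above, a caller holding a suitable capability token might run a two-step sequential workflow as follows; the agent IDs, names, and token contents are assumptions rather than values from the codebase:

async def run_sequential_example(capability_token: str) -> dict:
    orchestrator = get_agent_orchestrator()
    agents = [
        AgentDefinition(
            agent_id="prep-1",
            agent_type="data_processor",
            name="Prep",
            description="Normalize the input payload",
            capabilities_required=["data.read", "data.transform"],
        ),
        AgentDefinition(
            agent_id="summarize-1",
            agent_type="llm_agent",
            name="Summarize",
            description="Summarize the prepared payload",
            capabilities_required=["llm.inference"],
        ),
    ]
    # The token must grant "workflows" create/execute plus "data" and "llm" resources.
    workflow_id = await orchestrator.create_workflow(
        WorkflowType.SEQUENTIAL, agents, workflow_config={}, capability_token=capability_token
    )
    return await orchestrator.execute_workflow(workflow_id, {"prompt": "hello"}, capability_token)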

View File

@@ -0,0 +1,280 @@
"""
GT 2.0 Configuration Sync Service
Syncs model configurations from admin cluster to resource cluster.
Enables admin control panel to control AI model routing.
"""
import asyncio
import httpx
import json
import time
import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from pathlib import Path
from app.core.config import get_settings
from app.services.model_service import default_model_service
from app.providers.external_provider import get_external_provider
logger = logging.getLogger(__name__)
settings = get_settings()
class ConfigSyncService:
"""Syncs model configurations from admin cluster"""
def __init__(self):
# Force Docker service name for admin cluster communication in containerized environment
if hasattr(settings, 'admin_cluster_url') and settings.admin_cluster_url:
# Check if we're running in Docker (container environment)
import os
if os.path.exists('/.dockerenv'):
self.admin_cluster_url = "http://control-panel-backend:8000"
else:
self.admin_cluster_url = settings.admin_cluster_url
else:
self.admin_cluster_url = "http://control-panel-backend:8000"
self.sync_interval = settings.config_sync_interval or 60 # seconds
# Use the default singleton model service instance
self.model_service = default_model_service
self.last_sync = 0
self.sync_running = False
async def start_sync_loop(self):
"""Start the configuration sync loop"""
logger.info("Starting configuration sync loop")
while True:
try:
if not self.sync_running:
await self.sync_configurations()
await asyncio.sleep(self.sync_interval)
except Exception as e:
logger.error(f"Config sync loop error: {e}")
await asyncio.sleep(30) # Wait 30s on error
async def sync_configurations(self):
"""Sync model configurations from admin cluster"""
if self.sync_running:
return
self.sync_running = True
try:
logger.debug("Syncing model configurations from admin cluster")
# Fetch all model configurations from admin cluster
configs = await self._fetch_admin_configs()
if configs:
# Update local model registry
await self._update_local_registry(configs)
# Update provider configurations
await self._update_provider_configs(configs)
self.last_sync = time.time()
logger.info(f"Successfully synced {len(configs)} model configurations")
else:
logger.warning("No configurations received from admin cluster")
except Exception as e:
logger.error(f"Configuration sync failed: {e}")
finally:
self.sync_running = False
async def _fetch_admin_configs(self) -> Optional[List[Dict[str, Any]]]:
"""Fetch model configurations from admin cluster"""
try:
logger.debug(f"Attempting to fetch configs from: {self.admin_cluster_url}/api/v1/models/configs/all")
async with httpx.AsyncClient(timeout=30.0) as client:
# Add authentication for admin cluster access
headers = {
"Authorization": "Bearer admin-cluster-sync-token",
"Content-Type": "application/json"
}
response = await client.get(
f"{self.admin_cluster_url}/api/v1/models/configs/all",
headers=headers
)
logger.debug(f"Admin cluster response: {response.status_code}")
if response.status_code == 200:
data = response.json()
configs = data.get("configs", [])
logger.debug(f"Successfully fetched {len(configs)} model configurations")
return configs
else:
logger.warning(f"Admin cluster returned {response.status_code}: {response.text}")
return None
except httpx.RequestError as e:
logger.error(f"Failed to connect to admin cluster: {e}")
return None
except Exception as e:
logger.error(f"Error fetching admin configs: {e}")
return None
async def _update_local_registry(self, configs: List[Dict[str, Any]]):
"""Update local model registry with admin configurations"""
try:
for config in configs:
await self.model_service.register_or_update_model(
model_id=config["model_id"],
name=config["name"],
version=config["version"],
provider=config["provider"],
model_type=config["model_type"],
endpoint=config["endpoint"],
api_key_name=config.get("api_key_name"),
specifications=config.get("specifications", {}),
capabilities=config.get("capabilities", {}),
cost=config.get("cost", {}),
description=config.get("description"),
config=config.get("config", {}),
status=config.get("status", {}),
sync_timestamp=config.get("sync_timestamp")
)
# Log BGE-M3 configuration details for debugging persistence
if "bge-m3" in config["model_id"].lower():
model_config = config.get("config", {})
logger.info(
f"Synced BGE-M3 configuration from database: "
f"endpoint={config['endpoint']}, "
f"is_local_mode={model_config.get('is_local_mode', True)}, "
f"external_endpoint={model_config.get('external_endpoint', 'None')}"
)
except Exception as e:
logger.error(f"Failed to update local registry: {e}")
raise
async def _update_provider_configs(self, configs: List[Dict[str, Any]]):
"""Update provider configurations based on admin settings"""
try:
# Group configs by provider
provider_configs = {}
for config in configs:
provider = config["provider"]
if provider not in provider_configs:
provider_configs[provider] = []
provider_configs[provider].append(config)
# Update each provider
for provider, provider_models in provider_configs.items():
await self._update_provider(provider, provider_models)
except Exception as e:
logger.error(f"Failed to update provider configs: {e}")
raise
async def _update_provider(self, provider: str, models: List[Dict[str, Any]]):
"""Update specific provider configuration"""
try:
# Generic provider update - all providers are now supported automatically
provider_models = [m for m in models if m["provider"] == provider]
logger.debug(f"Updated {provider} provider with {len(provider_models)} models")
# Keep legacy support for specific providers if needed
if provider == "groq":
await self._update_groq_provider(models)
elif provider == "external":
await self._update_external_provider(models)
elif provider == "openai":
await self._update_openai_provider(models)
elif provider == "anthropic":
await self._update_anthropic_provider(models)
elif provider == "vllm":
await self._update_vllm_provider(models)
except Exception as e:
logger.error(f"Failed to update {provider} provider: {e}")
raise
async def _update_groq_provider(self, models: List[Dict[str, Any]]):
"""Update Groq provider configuration"""
# Update available Groq models
groq_models = [m for m in models if m["provider"] == "groq"]
logger.debug(f"Updated Groq provider with {len(groq_models)} models")
async def _update_external_provider(self, models: List[Dict[str, Any]]):
"""Update external provider configuration (BGE-M3, etc.)"""
external_models = [m for m in models if m["provider"] == "external"]
if external_models:
external_provider = await get_external_provider()
for model in external_models:
if "bge-m3" in model["model_id"].lower():
# Update BGE-M3 endpoint configuration
external_provider.update_model_endpoint(
model["model_id"],
model["endpoint"]
)
logger.debug(f"Updated BGE-M3 endpoint: {model['endpoint']}")
# Also refresh the embedding backend instance
try:
from app.core.backends import get_embedding_backend
embedding_backend = get_embedding_backend()
embedding_backend.refresh_endpoint_from_registry()
logger.info(f"Refreshed embedding backend with new BGE-M3 endpoint from database")
except Exception as e:
logger.error(f"Failed to refresh embedding backend: {e}")
logger.debug(f"Updated external provider with {len(external_models)} models")
async def _update_openai_provider(self, models: List[Dict[str, Any]]):
"""Update OpenAI provider configuration"""
openai_models = [m for m in models if m["provider"] == "openai"]
logger.debug(f"Updated OpenAI provider with {len(openai_models)} models")
async def _update_anthropic_provider(self, models: List[Dict[str, Any]]):
"""Update Anthropic provider configuration"""
anthropic_models = [m for m in models if m["provider"] == "anthropic"]
logger.debug(f"Updated Anthropic provider with {len(anthropic_models)} models")
async def _update_vllm_provider(self, models: List[Dict[str, Any]]):
"""Update vLLM provider configuration (BGE-M3 embeddings, etc.)"""
vllm_models = [m for m in models if m["provider"] == "vllm"]
for model in vllm_models:
if model["model_type"] == "embedding":
# This is an embedding model like BGE-M3
logger.debug(f"Updated vLLM embedding model: {model['model_id']} -> {model['endpoint']}")
else:
logger.debug(f"Updated vLLM model: {model['model_id']} -> {model['endpoint']}")
logger.debug(f"Updated vLLM provider with {len(vllm_models)} models")
async def force_sync(self):
"""Force immediate configuration sync"""
logger.info("Force syncing configurations")
await self.sync_configurations()
def get_sync_status(self) -> Dict[str, Any]:
"""Get current sync status"""
return {
"last_sync": datetime.fromtimestamp(self.last_sync).isoformat() if self.last_sync else None,
"sync_running": self.sync_running,
"admin_cluster_url": self.admin_cluster_url,
"sync_interval": self.sync_interval,
"next_sync": datetime.fromtimestamp(self.last_sync + self.sync_interval).isoformat() if self.last_sync else None
}
# Global config sync service instance
_config_sync_service = None
def get_config_sync_service() -> ConfigSyncService:
"""Get configuration sync service instance"""
global _config_sync_service
if _config_sync_service is None:
_config_sync_service = ConfigSyncService()
return _config_sync_service
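
A minimal sketch of wiring this service into application startup, assuming a FastAPI-style startup hook; the exact wiring in main.py may differ:

import asyncio

async def on_startup() -> None:
    sync_service = get_config_sync_service()
    await sync_service.force_sync()                       # immediate sync at boot
    asyncio.create_task(sync_service.start_sync_loop())   # periodic sync in the background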

View File

@@ -0,0 +1,101 @@
"""
Consul Service Registry
Handles service registration and discovery for the Resource Cluster.
"""
import logging
from typing import Dict, Any, List, Optional
import consul
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class ConsulRegistry:
"""Service registry using Consul"""
def __init__(self):
self.consul = None
try:
self.consul = consul.Consul(
host=settings.consul_host,
port=settings.consul_port,
token=settings.consul_token
)
except Exception as e:
logger.warning(f"Consul not available: {e}")
async def register_service(
self,
name: str,
service_id: str,
address: str,
port: int,
tags: List[str] = None,
check_interval: str = "10s"
) -> bool:
"""Register service with Consul"""
if not self.consul:
logger.warning("Consul not available, skipping registration")
return False
try:
self.consul.agent.service.register(
name=name,
service_id=service_id,
address=address,
port=port,
tags=tags or [],
check=consul.Check.http(
f"http://{address}:{port}/health",
interval=check_interval
)
)
logger.info(f"Registered service {service_id} with Consul")
return True
except Exception as e:
logger.error(f"Failed to register with Consul: {e}")
return False
async def deregister_service(self, service_id: str) -> bool:
"""Deregister service from Consul"""
if not self.consul:
return False
try:
self.consul.agent.service.deregister(service_id)
logger.info(f"Deregistered service {service_id} from Consul")
return True
except Exception as e:
logger.error(f"Failed to deregister from Consul: {e}")
return False
async def discover_service(self, service_name: str) -> List[Dict[str, Any]]:
"""Discover service instances"""
if not self.consul:
return []
try:
_, services = self.consul.health.service(service_name, passing=True)
instances = []
for service in services:
instances.append({
"id": service["Service"]["ID"],
"address": service["Service"]["Address"],
"port": service["Service"]["Port"],
"tags": service["Service"]["Tags"]
})
return instances
except Exception as e:
logger.error(f"Failed to discover service: {e}")
return []
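
A short usage sketch for the registry above; the service name, address, port, and tags are illustrative, and registration is skipped when Consul is unavailable:

async def register_resource_cluster() -> None:
    registry = ConsulRegistry()
    registered = await registry.register_service(
        name="resource-cluster",
        service_id="resource-cluster-1",
        address="resource-cluster",
        port=8000,
        tags=["gt2", "resource-cluster"],
    )
    if registered:
        instances = await registry.discover_service("resource-cluster")
        logger.info(f"Discovered {len(instances)} resource-cluster instance(s)")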

View File

@@ -0,0 +1,536 @@
"""
Enhanced Document Processing Pipeline with Dual-Engine Support
Implements the DocumentProcessingPipeline from CLAUDE.md with both native
and Unstructured.io engine support, capability-based selection, and
stateless processing.
"""
import logging
import asyncio
import gc
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
import hashlib
import json
from app.core.backends.document_processor import (
DocumentProcessorBackend,
ChunkingStrategy
)
logger = logging.getLogger(__name__)
@dataclass
class ProcessingResult:
"""Result of document processing"""
chunks: List[Dict[str, str]]
embeddings: Optional[List[List[float]]] # Optional embeddings
metadata: Dict[str, Any]
engine_used: str
processing_time_ms: float
token_count: int
@dataclass
class ProcessingOptions:
"""Options for document processing"""
engine_preference: str = "auto" # "native", "unstructured", "auto"
chunking_strategy: str = "hybrid" # "fixed", "semantic", "hierarchical", "hybrid"
chunk_size: int = 512 # tokens for BGE-M3
chunk_overlap: int = 128 # overlap tokens
generate_embeddings: bool = True
extract_metadata: bool = True
language_detection: bool = True
ocr_enabled: bool = False # For scanned PDFs
class UnstructuredAPIEngine:
"""
Mock Unstructured.io API engine for advanced document parsing.
In production, this would call the actual Unstructured API.
"""
def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None):
self.api_key = api_key
self.api_url = api_url or "https://api.unstructured.io"
self.supported_features = [
"table_extraction",
"image_extraction",
"ocr",
"language_detection",
"metadata_extraction",
"hierarchical_parsing"
]
async def process(
self,
content: bytes,
file_type: str,
options: Dict[str, Any]
) -> Dict[str, Any]:
"""
Process document using Unstructured API.
This is a mock implementation. In production:
1. Send content to Unstructured API
2. Handle rate limiting and retries
3. Parse structured response
"""
# Mock processing delay
await asyncio.sleep(0.5)
# Mock response structure
return {
"elements": [
{
"type": "Title",
"text": "Document Title",
"metadata": {"page_number": 1}
},
{
"type": "NarrativeText",
"text": "This is the main content of the document...",
"metadata": {"page_number": 1}
}
],
"metadata": {
"languages": ["en"],
"page_count": 1,
"has_tables": False,
"has_images": False
}
}
class NativeChunkingEngine:
"""
Native chunking engine using the existing DocumentProcessorBackend.
Fast, lightweight, and suitable for most text documents.
"""
def __init__(self):
self.processor = DocumentProcessorBackend()
async def process(
self,
content: bytes,
file_type: str,
options: ProcessingOptions
) -> List[Dict[str, Any]]:
"""Process document using native chunking"""
strategy = ChunkingStrategy(
strategy_type=options.chunking_strategy,
chunk_size=options.chunk_size,
chunk_overlap=options.chunk_overlap,
preserve_paragraphs=True,
preserve_sentences=True
)
chunks = await self.processor.process_document(
content=content,
document_type=file_type,
strategy=strategy,
metadata={
"processing_timestamp": datetime.utcnow().isoformat(),
"engine": "native"
}
)
return chunks
class DocumentProcessingPipeline:
"""
Dual-engine document processing pipeline with capability-based selection.
Features:
- Native engine for fast, simple processing
- Unstructured API for advanced features
- Capability-based engine selection
- Stateless processing with memory cleanup
- Optional embedding generation
"""
def __init__(self, resource_cluster_url: Optional[str] = None):
self.resource_cluster_url = resource_cluster_url or "http://localhost:8004"
self.native_engine = NativeChunkingEngine()
self.unstructured_engine = None # Lazy initialization
self.embedding_cache = {} # Cache for frequently used embeddings
logger.info("Document Processing Pipeline initialized")
def select_engine(
self,
filename: str,
token_data: Dict[str, Any],
options: ProcessingOptions
) -> str:
"""
Select processing engine based on file type and capabilities.
Args:
filename: Name of the file being processed
token_data: Capability token data
options: Processing options
Returns:
Engine name: "native" or "unstructured"
"""
# Check if user has premium parsing capability
has_premium = any(
cap.get("resource") == "premium_parsing"
for cap in token_data.get("capabilities", [])
)
# Force native if no premium capability
if not has_premium and options.engine_preference == "unstructured":
logger.info("Premium parsing requested but not available, using native engine")
return "native"
# Auto selection logic
if options.engine_preference == "auto":
# Use Unstructured for complex formats if available
complex_formats = [".pdf", ".docx", ".pptx", ".xlsx"]
needs_ocr = options.ocr_enabled
needs_tables = filename.lower().endswith((".xlsx", ".csv"))
if has_premium and (
any(filename.lower().endswith(fmt) for fmt in complex_formats) or
needs_ocr or needs_tables
):
return "unstructured"
else:
return "native"
# Respect explicit preference if capability allows
if options.engine_preference == "unstructured" and has_premium:
return "unstructured"
return "native"
async def process_document(
self,
file: bytes,
filename: str,
token_data: Dict[str, Any],
options: Optional[ProcessingOptions] = None
) -> ProcessingResult:
"""
Process document with selected engine.
Args:
file: Document content as bytes
filename: Name of the file
token_data: Capability token data
options: Processing options
Returns:
ProcessingResult with chunks, embeddings, and metadata
"""
start_time = datetime.utcnow()
try:
# Use default options if not provided
if options is None:
options = ProcessingOptions()
# Determine file type
file_type = self._get_file_extension(filename)
# Select engine based on capabilities
engine = self.select_engine(filename, token_data, options)
# Process with selected engine
if engine == "unstructured" and token_data.get("has_capability", {}).get("premium_parsing"):
result = await self._process_with_unstructured(file, filename, token_data, options)
else:
result = await self._process_with_native(file, filename, token_data, options)
# Generate embeddings if requested
embeddings = None
if options.generate_embeddings:
embeddings = await self._generate_embeddings(result.chunks, token_data)
# Calculate processing time
processing_time = (datetime.utcnow() - start_time).total_seconds() * 1000
# Calculate token count
token_count = sum(len(chunk["text"].split()) for chunk in result.chunks)
return ProcessingResult(
chunks=result.chunks,
embeddings=embeddings,
metadata={
"filename": filename,
"file_type": file_type,
"processing_timestamp": start_time.isoformat(),
"chunk_count": len(result.chunks),
"engine_used": engine,
"options": {
"chunking_strategy": options.chunking_strategy,
"chunk_size": options.chunk_size,
"chunk_overlap": options.chunk_overlap
}
},
engine_used=engine,
processing_time_ms=processing_time,
token_count=token_count
)
except Exception as e:
logger.error(f"Error processing document: {e}")
raise
finally:
# Ensure memory cleanup
del file
gc.collect()
async def _process_with_native(
self,
file: bytes,
filename: str,
token_data: Dict[str, Any],
options: ProcessingOptions
) -> ProcessingResult:
"""Process document with native engine"""
file_type = self._get_file_extension(filename)
chunks = await self.native_engine.process(file, file_type, options)
return ProcessingResult(
chunks=chunks,
embeddings=None,
metadata={"engine": "native"},
engine_used="native",
processing_time_ms=0,
token_count=0
)
async def _process_with_unstructured(
self,
file: bytes,
filename: str,
token_data: Dict[str, Any],
options: ProcessingOptions
) -> ProcessingResult:
"""Process document with Unstructured API"""
# Initialize Unstructured engine if needed
if self.unstructured_engine is None:
# Get API key from token constraints or environment
api_key = token_data.get("constraints", {}).get("unstructured_api_key")
self.unstructured_engine = UnstructuredAPIEngine(api_key=api_key)
file_type = self._get_file_extension(filename)
# Process with Unstructured
unstructured_result = await self.unstructured_engine.process(
content=file,
file_type=file_type,
options={
"ocr": options.ocr_enabled,
"extract_tables": True,
"extract_images": False, # Don't extract images for security
"languages": ["en", "es", "fr", "de", "zh"]
}
)
# Convert Unstructured elements to chunks
chunks = []
for element in unstructured_result.get("elements", []):
chunk_text = element.get("text", "")
if chunk_text.strip():
chunks.append({
"text": chunk_text,
"metadata": {
"element_type": element.get("type"),
"page_number": element.get("metadata", {}).get("page_number"),
"engine": "unstructured"
}
})
# Apply chunking strategy if chunks are too large
final_chunks = await self._apply_chunking_to_elements(chunks, options)
return ProcessingResult(
chunks=final_chunks,
embeddings=None,
metadata={
"engine": "unstructured",
"detected_languages": unstructured_result.get("metadata", {}).get("languages", []),
"page_count": unstructured_result.get("metadata", {}).get("page_count", 0),
"has_tables": unstructured_result.get("metadata", {}).get("has_tables", False),
"has_images": unstructured_result.get("metadata", {}).get("has_images", False)
},
engine_used="unstructured",
processing_time_ms=0,
token_count=0
)
async def _apply_chunking_to_elements(
self,
elements: List[Dict[str, Any]],
options: ProcessingOptions
) -> List[Dict[str, Any]]:
"""Apply chunking strategy to Unstructured elements if needed"""
final_chunks = []
for element in elements:
text = element["text"]
# Estimate token count (rough approximation)
estimated_tokens = len(text.split()) * 1.3
# If element is small enough, keep as is
if estimated_tokens <= options.chunk_size:
final_chunks.append(element)
else:
# Split large elements using native chunking
sub_chunks = await self._chunk_text(
text,
options.chunk_size,
options.chunk_overlap
)
for idx, sub_chunk in enumerate(sub_chunks):
chunk_metadata = element["metadata"].copy()
chunk_metadata["sub_chunk_index"] = idx
chunk_metadata["parent_element_type"] = element["metadata"].get("element_type")
final_chunks.append({
"text": sub_chunk,
"metadata": chunk_metadata
})
return final_chunks
async def _chunk_text(
self,
text: str,
chunk_size: int,
chunk_overlap: int
) -> List[str]:
"""Simple text chunking for large elements"""
words = text.split()
chunks = []
# Simple word-based chunking; guard against a non-positive step when overlap >= size
step = max(1, chunk_size - chunk_overlap)
for i in range(0, len(words), step):
chunk_words = words[i:i + chunk_size]
chunks.append(" ".join(chunk_words))
return chunks
async def _generate_embeddings(
self,
chunks: List[Dict[str, Any]],
token_data: Dict[str, Any]
) -> List[List[float]]:
"""
Generate embeddings for chunks.
This is a mock implementation. In production, this would:
1. Call the embedding service (BGE-M3 or similar)
2. Handle batching for efficiency
3. Apply caching for common chunks
"""
embeddings = []
for chunk in chunks:
# Check cache first
chunk_hash = hashlib.sha256(chunk["text"].encode()).hexdigest()
if chunk_hash in self.embedding_cache:
embeddings.append(self.embedding_cache[chunk_hash])
else:
# Mock embedding generation
# In production: call embedding API
embedding = [0.1] * 1024 # Mock 1024-dim embedding (BGE-M3's dense output size)
embeddings.append(embedding)
# Cache for reuse (with size limit)
if len(self.embedding_cache) < 1000:
self.embedding_cache[chunk_hash] = embedding
return embeddings
def _get_file_extension(self, filename: str) -> str:
"""Extract file extension from filename"""
parts = filename.lower().split(".")
if len(parts) > 1:
return f".{parts[-1]}"
return ".txt" # Default to text
async def validate_document(
self,
file_size: int,
filename: str,
token_data: Dict[str, Any]
) -> Dict[str, Any]:
"""
Validate document before processing.
Args:
file_size: Size of file in bytes
filename: Name of the file
token_data: Capability token data
Returns:
Validation result with warnings and errors
"""
# Get size limits from token
max_size = token_data.get("constraints", {}).get("max_file_size", 50 * 1024 * 1024)
validation = {
"valid": True,
"warnings": [],
"errors": [],
"recommendations": []
}
# Check file size
if file_size > max_size:
validation["valid"] = False
validation["errors"].append(f"File exceeds maximum size of {max_size / 1024 / 1024:.1f} MiB")
elif file_size > 10 * 1024 * 1024:
validation["warnings"].append("Large file may take longer to process")
validation["recommendations"].append("Consider using streaming processing for better performance")
# Check file type
file_type = self._get_file_extension(filename)
supported_types = [".pdf", ".docx", ".txt", ".md", ".html", ".csv", ".xlsx", ".pptx"]
if file_type not in supported_types:
validation["valid"] = False
validation["errors"].append(f"Unsupported file type: {file_type}")
validation["recommendations"].append(f"Supported types: {', '.join(supported_types)}")
# Check for special processing needs
if file_type in [".xlsx", ".csv"]:
validation["recommendations"].append("Table extraction will be applied automatically")
if file_type == ".pdf":
validation["recommendations"].append("Enable OCR if document contains scanned images")
return validation
async def get_processing_stats(self) -> Dict[str, Any]:
"""Get processing statistics"""
return {
"engines_available": ["native", "unstructured"],
"native_engine_status": "ready",
"unstructured_engine_status": "ready" if self.unstructured_engine else "not_initialized",
"embedding_cache_size": len(self.embedding_cache),
"supported_formats": [".pdf", ".docx", ".txt", ".md", ".html", ".csv", ".xlsx", ".pptx"],
"default_chunk_size": 512,
"default_chunk_overlap": 128,
"stateless": True
}
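
A hedged usage sketch for the pipeline above, validating an upload and then processing it with the native engine; the filename and token contents are illustrative placeholders:

async def process_upload(content: bytes) -> ProcessingResult:
    pipeline = DocumentProcessingPipeline()
    token_data = {"capabilities": [], "constraints": {"max_file_size": 10 * 1024 * 1024}}
    validation = await pipeline.validate_document(len(content), "notes.md", token_data)
    if not validation["valid"]:
        raise ValueError(validation["errors"])
    options = ProcessingOptions(engine_preference="native", generate_embeddings=False)
    return await pipeline.process_document(content, "notes.md", token_data, options)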

View File

@@ -0,0 +1,447 @@
"""
Embedding Service for GT 2.0 Resource Cluster
Provides embedding generation with:
- BGE-M3 model integration
- Batch processing capabilities
- Rate limiting and quota management
- Capability-based authentication
- Stateless operation (no data storage)
GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Per-request capability validation
- Zero Downtime: Stateless design, circuit breakers
- Self-Contained Security: Capability-based auth
- No Complexity Addition: Simple interface, no database
"""
import asyncio
import logging
import time
import os
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
import uuid
from app.core.backends.embedding_backend import EmbeddingBackend, EmbeddingRequest
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class EmbeddingResponse:
"""Response structure for embedding generation"""
request_id: str
embeddings: List[List[float]]
model: str
dimensions: int
tokens_used: int
processing_time_ms: int
tenant_id: str
created_at: str
@dataclass
class EmbeddingStats:
"""Statistics for embedding requests"""
total_requests: int = 0
total_tokens_processed: int = 0
total_processing_time_ms: int = 0
average_processing_time_ms: float = 0.0
last_request_at: Optional[str] = None
class EmbeddingService:
"""
STATELESS embedding service for GT 2.0 Resource Cluster.
Key features:
- BGE-M3 model for high-quality embeddings
- Batch processing for efficiency
- Rate limiting per capability token
- Memory-conscious processing
- No persistent storage
"""
def __init__(self):
self.backend = EmbeddingBackend()
self.stats = EmbeddingStats()
# Initialize BGE-M3 tokenizer for accurate token counting
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
except Exception as e:
logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
self.tokenizer = None
# Rate limiting settings (per capability token)
self.rate_limits = {
"requests_per_minute": 60,
"tokens_per_minute": 50000,
"max_batch_size": 32
}
# Track requests for rate limiting (in-memory, temporary)
self._request_tracker = {}
logger.info("STATELESS embedding service initialized")
async def generate_embeddings(
self,
texts: List[str],
capability_token: str,
instruction: Optional[str] = None,
request_id: Optional[str] = None,
normalize: bool = True
) -> EmbeddingResponse:
"""
Generate embeddings with capability-based authentication.
Args:
texts: List of texts to embed
capability_token: JWT token with embedding permissions
instruction: Optional instruction for embedding context
request_id: Optional request ID for tracking
normalize: Whether to normalize embeddings
Returns:
EmbeddingResponse with generated embeddings
Raises:
CapabilityError: If token invalid or insufficient permissions
ValueError: If request parameters invalid
"""
start_time = time.time()
request_id = request_id or str(uuid.uuid4())
try:
# Verify capability token and extract permissions
capability = await verify_capability_token(capability_token)
tenant_id = capability.get("tenant_id")
user_id = capability.get("sub") # Extract user ID from token
# Check embedding permissions
await self._verify_embedding_permissions(capability, len(texts))
# Apply rate limiting
await self._check_rate_limits(capability_token, len(texts))
# Validate input
self._validate_embedding_request(texts)
# Generate embeddings via backend
embeddings = await self.backend.generate_embeddings(
texts=texts,
instruction=instruction,
tenant_id=tenant_id,
request_id=request_id
)
# Calculate processing metrics
processing_time_ms = int((time.time() - start_time) * 1000)
total_tokens = self._estimate_tokens(texts)
# Update statistics
self._update_stats(total_tokens, processing_time_ms)
# Log embedding usage for billing (non-blocking)
# Fire and forget - don't wait for completion
asyncio.create_task(
self._log_embedding_usage(
tenant_id=tenant_id,
user_id=user_id,
tokens_used=total_tokens,
embedding_count=len(embeddings),
model=self.backend.model_name,
request_id=request_id
)
)
# Create response
response = EmbeddingResponse(
request_id=request_id,
embeddings=embeddings,
model=self.backend.model_name,
dimensions=self.backend.embedding_dimensions,
tokens_used=total_tokens,
processing_time_ms=processing_time_ms,
tenant_id=tenant_id,
created_at=datetime.utcnow().isoformat()
)
logger.info(
f"Generated {len(embeddings)} embeddings for tenant {tenant_id} "
f"in {processing_time_ms}ms"
)
return response
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise
finally:
# Always ensure cleanup
if 'texts' in locals():
del texts
async def get_model_info(self) -> Dict[str, Any]:
"""Get information about the embedding model"""
return {
"model_name": self.backend.model_name,
"dimensions": self.backend.embedding_dimensions,
"max_sequence_length": self.backend.max_sequence_length,
"max_batch_size": self.backend.max_batch_size,
"supports_instruction": True,
"normalization_default": True
}
async def get_service_stats(
self,
capability_token: str
) -> Dict[str, Any]:
"""
Get service statistics (for admin users only).
Args:
capability_token: JWT token with admin permissions
Returns:
Service statistics
"""
# Verify admin permissions
capability = await verify_capability_token(capability_token)
if not self._has_admin_permissions(capability):
raise CapabilityError("Admin permissions required")
return {
"model_info": await self.get_model_info(),
"statistics": asdict(self.stats),
"rate_limits": self.rate_limits,
"active_requests": len(self._request_tracker)
}
async def health_check(self) -> Dict[str, Any]:
"""Check service health"""
return {
"status": "healthy",
"service": "embedding_service",
"model": self.backend.model_name,
"backend_ready": True,
"last_request": self.stats.last_request_at
}
async def _verify_embedding_permissions(
self,
capability: Dict[str, Any],
text_count: int
) -> None:
"""Verify that capability token has embedding permissions"""
# Check for embedding capability
capabilities = capability.get("capabilities", [])
embedding_caps = [
cap for cap in capabilities
if cap.get("resource") == "embeddings"
]
if not embedding_caps:
raise CapabilityError("No embedding permissions in capability token")
# Check constraints
embedding_cap = embedding_caps[0] # Use first embedding capability
constraints = embedding_cap.get("constraints", {})
# Check batch size limit
max_batch = constraints.get("max_batch_size", self.rate_limits["max_batch_size"])
if text_count > max_batch:
raise CapabilityError(f"Batch size {text_count} exceeds limit {max_batch}")
# Check rate limits
rate_limit = constraints.get("rate_limit_per_minute", self.rate_limits["requests_per_minute"])
token_limit = constraints.get("tokens_per_minute", self.rate_limits["tokens_per_minute"])
logger.debug(f"Embedding permissions verified: batch={text_count}, limits=({rate_limit}, {token_limit})")
async def _check_rate_limits(
self,
capability_token: str,
        texts: List[str]
) -> None:
"""Check rate limits for capability token"""
now = time.time()
token_hash = hash(capability_token) % 10000 # Simple tracking key
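        # NOTE: hash() is salted per process and the modulo can collide, so this
        # tracker is best-effort and only meaningful within a single process.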
# Clean old entries (older than 1 minute)
cleanup_time = now - 60
self._request_tracker = {
k: v for k, v in self._request_tracker.items()
if v.get("last_request", 0) > cleanup_time
}
# Get or create tracker for this token
if token_hash not in self._request_tracker:
self._request_tracker[token_hash] = {
"requests": 0,
"tokens": 0,
"last_request": now
}
tracker = self._request_tracker[token_hash]
# Check request rate limit
if tracker["requests"] >= self.rate_limits["requests_per_minute"]:
raise CapabilityError("Rate limit exceeded: too many requests per minute")
# Estimate tokens and check token limit
        estimated_tokens = self._estimate_tokens(texts)
if tracker["tokens"] + estimated_tokens > self.rate_limits["tokens_per_minute"]:
raise CapabilityError("Rate limit exceeded: too many tokens per minute")
# Update tracker
tracker["requests"] += 1
tracker["tokens"] += estimated_tokens
tracker["last_request"] = now
def _validate_embedding_request(self, texts: List[str]) -> None:
"""Validate embedding request parameters"""
if not texts:
raise ValueError("No texts provided for embedding")
if not isinstance(texts, list):
raise ValueError("Texts must be a list")
if len(texts) > self.backend.max_batch_size:
raise ValueError(f"Batch size {len(texts)} exceeds maximum {self.backend.max_batch_size}")
# Check individual text lengths
for i, text in enumerate(texts):
if not isinstance(text, str):
raise ValueError(f"Text at index {i} must be a string")
if len(text.strip()) == 0:
raise ValueError(f"Text at index {i} is empty")
# Simple token estimation for length check
estimated_tokens = len(text.split()) * 1.3 # Rough estimation
if estimated_tokens > self.backend.max_sequence_length:
raise ValueError(f"Text at index {i} exceeds maximum length")
def _estimate_tokens(self, texts: List[str]) -> int:
"""
Count tokens using actual BGE-M3 tokenizer.
Falls back to word-count estimation if tokenizer unavailable.
"""
if self.tokenizer is not None:
try:
total_tokens = 0
for text in texts:
tokens = self.tokenizer.encode(text, add_special_tokens=False)
total_tokens += len(tokens)
return total_tokens
except Exception as e:
logger.warning(f"Tokenizer error, falling back to estimation: {e}")
# Fallback: word count * 1.3 (rough estimation)
total_words = sum(len(text.split()) for text in texts)
return int(total_words * 1.3)
def _has_admin_permissions(self, capability: Dict[str, Any]) -> bool:
"""Check if capability has admin permissions"""
capabilities = capability.get("capabilities", [])
return any(
cap.get("resource") == "admin" and "stats" in cap.get("actions", [])
for cap in capabilities
)
def _update_stats(self, tokens_processed: int, processing_time_ms: int) -> None:
"""Update service statistics"""
self.stats.total_requests += 1
self.stats.total_tokens_processed += tokens_processed
self.stats.total_processing_time_ms += processing_time_ms
self.stats.average_processing_time_ms = (
self.stats.total_processing_time_ms / self.stats.total_requests
)
self.stats.last_request_at = datetime.utcnow().isoformat()
async def _log_embedding_usage(
self,
tenant_id: str,
user_id: str,
tokens_used: int,
embedding_count: int,
model: str = "BAAI/bge-m3",
request_id: Optional[str] = None
) -> None:
"""
Log embedding usage to control panel database for billing.
This method logs usage asynchronously and does not block the embedding response.
Failures are logged as warnings but do not raise exceptions.
Args:
tenant_id: Tenant identifier
user_id: User identifier (from capability token 'sub')
tokens_used: Number of tokens processed
embedding_count: Number of embeddings generated
model: Embedding model name
request_id: Optional request ID for tracking
"""
try:
import asyncpg
# Calculate cost: BGE-M3 pricing ~$0.10 per million tokens
cost_cents = (tokens_used / 1_000_000) * 0.10 * 100
# Connect to control panel database
# Using environment variables from docker-compose
db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
if not db_password:
logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
return
conn = await asyncpg.connect(
host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
password=db_password,
timeout=5.0
)
try:
# Insert usage log
await conn.execute("""
INSERT INTO public.embedding_usage_logs
(tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""", tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
logger.info(
f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
)
finally:
await conn.close()
except Exception as e:
# Log warning but don't fail the embedding request
logger.warning(f"Failed to log embedding usage for tenant {tenant_id}: {e}")
# Global service instance
_embedding_service = None
def get_embedding_service() -> EmbeddingService:
"""Get the global embedding service instance"""
global _embedding_service
if _embedding_service is None:
_embedding_service = EmbeddingService()
return _embedding_service
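
# Illustrative usage sketch (not wired into any endpoint): shows how a caller
# might invoke the stateless service, assuming a valid capability token with
# an "embeddings" capability issued elsewhere in GT 2.0.
async def _example_generate_embeddings(capability_token: str) -> None:
    """Sketch only; never called by the service itself."""
    service = get_embedding_service()
    response = await service.generate_embeddings(
        texts=["GT 2.0 resource cluster overview"],
        capability_token=capability_token,
        instruction="Represent this sentence for retrieval",
    )
    logger.debug(
        "Generated %d embeddings of dimension %d in %dms",
        len(response.embeddings),
        response.dimensions,
        response.processing_time_ms,
    )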

View File

@@ -0,0 +1,729 @@
"""
Integration Proxy Service for GT 2.0
Secure proxy service for external integrations with capability-based access control,
sandbox restrictions, and comprehensive audit logging. All external calls are routed
through this service in the Resource Cluster for security and monitoring.
"""
import asyncio
import json
import httpx
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, asdict
from enum import Enum
import logging
from contextlib import asynccontextmanager
from app.core.security import verify_capability_token
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class IntegrationType(Enum):
"""Types of external integrations"""
COMMUNICATION = "communication" # Slack, Teams, Discord
DEVELOPMENT = "development" # GitHub, GitLab, Jira
PROJECT_MANAGEMENT = "project_management" # Asana, Monday.com
DATABASE = "database" # PostgreSQL, MySQL, MongoDB
CUSTOM_API = "custom_api" # Custom REST/GraphQL APIs
WEBHOOK = "webhook" # Outbound webhook calls
class SandboxLevel(Enum):
"""Sandbox restriction levels"""
NONE = "none" # No restrictions (trusted)
BASIC = "basic" # Basic timeout and size limits
RESTRICTED = "restricted" # Limited API calls and data access
STRICT = "strict" # Maximum restrictions
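    # Concrete per-level limits (timeouts, payload sizes) are applied by
    # SandboxManager._get_max_timeout() and _get_max_data_size() below.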
@dataclass
class IntegrationConfig:
"""Configuration for external integration"""
id: str
name: str
integration_type: IntegrationType
base_url: str
authentication_method: str # oauth2, api_key, basic_auth, certificate
sandbox_level: SandboxLevel
# Authentication details (encrypted)
auth_config: Dict[str, Any]
# Rate limits and constraints
max_requests_per_hour: int = 1000
max_response_size_bytes: int = 10 * 1024 * 1024 # 10MB
timeout_seconds: int = 30
# Allowed operations
allowed_methods: List[str] = None
allowed_endpoints: List[str] = None
blocked_endpoints: List[str] = None
# Network restrictions
allowed_domains: List[str] = None
# Created metadata
created_at: datetime = None
created_by: str = ""
is_active: bool = True
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
if self.allowed_methods is None:
self.allowed_methods = ["GET", "POST"]
if self.allowed_endpoints is None:
self.allowed_endpoints = []
if self.blocked_endpoints is None:
self.blocked_endpoints = []
if self.allowed_domains is None:
self.allowed_domains = []
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
data = asdict(self)
data["integration_type"] = self.integration_type.value
data["sandbox_level"] = self.sandbox_level.value
data["created_at"] = self.created_at.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "IntegrationConfig":
"""Create from dictionary"""
data["integration_type"] = IntegrationType(data["integration_type"])
data["sandbox_level"] = SandboxLevel(data["sandbox_level"])
data["created_at"] = datetime.fromisoformat(data["created_at"])
return cls(**data)
@dataclass
class ProxyRequest:
"""Request to proxy to external service"""
integration_id: str
method: str
endpoint: str
headers: Optional[Dict[str, str]] = None
data: Optional[Dict[str, Any]] = None
params: Optional[Dict[str, str]] = None
timeout_override: Optional[int] = None
def __post_init__(self):
if self.headers is None:
self.headers = {}
if self.data is None:
self.data = {}
if self.params is None:
self.params = {}
@dataclass
class ProxyResponse:
"""Response from proxied external service"""
success: bool
status_code: int
data: Optional[Dict[str, Any]]
headers: Dict[str, str]
execution_time_ms: int
sandbox_applied: bool
restrictions_applied: List[str]
error_message: Optional[str] = None
def __post_init__(self):
if self.headers is None:
self.headers = {}
if self.restrictions_applied is None:
self.restrictions_applied = []
class SandboxManager:
"""Manages sandbox restrictions for external integrations"""
def __init__(self):
self.active_requests: Dict[str, datetime] = {}
self.rate_limiters: Dict[str, List[datetime]] = {}
def apply_sandbox_restrictions(
self,
config: IntegrationConfig,
request: ProxyRequest,
capability_token: Dict[str, Any]
) -> Tuple[ProxyRequest, List[str]]:
"""Apply sandbox restrictions to request"""
restrictions_applied = []
if config.sandbox_level == SandboxLevel.NONE:
return request, restrictions_applied
# Apply timeout restrictions
if config.sandbox_level in [SandboxLevel.BASIC, SandboxLevel.RESTRICTED, SandboxLevel.STRICT]:
max_timeout = self._get_max_timeout(config.sandbox_level)
if request.timeout_override is None or request.timeout_override > max_timeout:
request.timeout_override = max_timeout
restrictions_applied.append(f"timeout_limited_to_{max_timeout}s")
# Apply endpoint restrictions
if config.sandbox_level in [SandboxLevel.RESTRICTED, SandboxLevel.STRICT]:
# Check blocked endpoints first
if request.endpoint in config.blocked_endpoints:
raise PermissionError(f"Endpoint {request.endpoint} is blocked")
# Then check allowed endpoints if specified
if config.allowed_endpoints and request.endpoint not in config.allowed_endpoints:
raise PermissionError(f"Endpoint {request.endpoint} not allowed")
restrictions_applied.append("endpoint_validation")
# Apply method restrictions
if config.sandbox_level == SandboxLevel.STRICT:
allowed_methods = config.allowed_methods or ["GET", "POST"]
if request.method not in allowed_methods:
raise PermissionError(f"HTTP method {request.method} not allowed in strict mode")
restrictions_applied.append("method_restricted")
# Apply data size restrictions
if request.data:
data_size = len(json.dumps(request.data).encode())
max_size = self._get_max_data_size(config.sandbox_level)
if data_size > max_size:
raise ValueError(f"Request data size {data_size} exceeds limit {max_size}")
restrictions_applied.append("data_size_validated")
# Apply capability-based restrictions
constraints = capability_token.get("constraints", {})
if "integration_timeout_seconds" in constraints:
max_cap_timeout = constraints["integration_timeout_seconds"]
if request.timeout_override > max_cap_timeout:
request.timeout_override = max_cap_timeout
restrictions_applied.append(f"capability_timeout_{max_cap_timeout}s")
return request, restrictions_applied
def _get_max_timeout(self, sandbox_level: SandboxLevel) -> int:
"""Get maximum timeout for sandbox level"""
timeouts = {
SandboxLevel.BASIC: 60,
SandboxLevel.RESTRICTED: 30,
SandboxLevel.STRICT: 15
}
return timeouts.get(sandbox_level, 30)
def _get_max_data_size(self, sandbox_level: SandboxLevel) -> int:
"""Get maximum data size for sandbox level"""
sizes = {
SandboxLevel.BASIC: 1024 * 1024, # 1MB
SandboxLevel.RESTRICTED: 512 * 1024, # 512KB
SandboxLevel.STRICT: 256 * 1024 # 256KB
}
return sizes.get(sandbox_level, 512 * 1024)
async def check_rate_limits(self, integration_id: str, config: IntegrationConfig) -> bool:
"""Check if request is within rate limits"""
now = datetime.utcnow()
hour_ago = now - timedelta(hours=1)
# Initialize or clean rate limiter
if integration_id not in self.rate_limiters:
self.rate_limiters[integration_id] = []
# Remove old requests
self.rate_limiters[integration_id] = [
req_time for req_time in self.rate_limiters[integration_id]
if req_time > hour_ago
]
# Check rate limit
if len(self.rate_limiters[integration_id]) >= config.max_requests_per_hour:
return False
# Record this request
self.rate_limiters[integration_id].append(now)
return True
class IntegrationProxyService:
"""
Integration Proxy Service for secure external API access.
Features:
- Capability-based access control
- Sandbox restrictions based on trust level
- Rate limiting and usage tracking
- Comprehensive audit logging
- Response sanitization and size limits
"""
def __init__(self, base_path: Optional[Path] = None):
self.base_path = base_path or Path("/data/resource-cluster/integrations")
self.configs_path = self.base_path / "configs"
self.usage_path = self.base_path / "usage"
self.audit_path = self.base_path / "audit"
self.sandbox_manager = SandboxManager()
self.http_client = None
# Ensure directories exist with proper permissions
self._ensure_directories()
def _ensure_directories(self):
"""Ensure storage directories exist with proper permissions"""
for path in [self.configs_path, self.usage_path, self.audit_path]:
path.mkdir(parents=True, exist_ok=True, mode=0o700)
@asynccontextmanager
async def get_http_client(self):
"""Get HTTP client with proper configuration"""
if self.http_client is None:
self.http_client = httpx.AsyncClient(
timeout=httpx.Timeout(60.0),
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
)
try:
yield self.http_client
finally:
# Client stays open for reuse
pass
async def execute_integration(
self,
request: ProxyRequest,
capability_token: str
) -> ProxyResponse:
"""Execute integration request with security and sandbox restrictions"""
start_time = datetime.utcnow()
try:
# Verify capability token
token_obj = verify_capability_token(capability_token)
if not token_obj:
raise PermissionError("Invalid capability token")
# Convert token object to dict for compatibility
token_data = {
"tenant_id": token_obj.tenant_id,
"sub": token_obj.sub,
"capabilities": [cap.dict() if hasattr(cap, 'dict') else cap for cap in token_obj.capabilities],
"constraints": {}
}
# Load integration configuration
config = await self._load_integration_config(request.integration_id)
if not config or not config.is_active:
raise ValueError(f"Integration {request.integration_id} not found or inactive")
# Validate capability for this integration
required_capability = f"integration:{request.integration_id}:{request.method.lower()}"
if not self._has_capability(token_data, required_capability):
raise PermissionError(f"Missing capability: {required_capability}")
# Check rate limits
if not await self.sandbox_manager.check_rate_limits(request.integration_id, config):
raise PermissionError("Rate limit exceeded")
# Apply sandbox restrictions
sandboxed_request, restrictions = self.sandbox_manager.apply_sandbox_restrictions(
config, request, token_data
)
# Execute the request
response = await self._execute_proxied_request(config, sandboxed_request)
response.sandbox_applied = len(restrictions) > 0
response.restrictions_applied = restrictions
# Calculate execution time
execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
response.execution_time_ms = int(execution_time)
# Log usage
await self._log_usage(
integration_id=request.integration_id,
tenant_id=token_data.get("tenant_id"),
user_id=token_data.get("sub"),
method=request.method,
endpoint=request.endpoint,
success=response.success,
execution_time_ms=response.execution_time_ms
)
# Audit log
await self._audit_log(
action="integration_executed",
integration_id=request.integration_id,
user_id=token_data.get("sub"),
details={
"method": request.method,
"endpoint": request.endpoint,
"success": response.success,
"restrictions_applied": restrictions
}
)
return response
except Exception as e:
logger.error(f"Integration execution failed: {e}")
# Log error
execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
await self._log_usage(
integration_id=request.integration_id,
tenant_id=token_data.get("tenant_id") if 'token_data' in locals() else "unknown",
user_id=token_data.get("sub") if 'token_data' in locals() else "unknown",
method=request.method,
endpoint=request.endpoint,
success=False,
execution_time_ms=int(execution_time),
error=str(e)
)
return ProxyResponse(
success=False,
status_code=500,
data=None,
headers={},
execution_time_ms=int(execution_time),
sandbox_applied=False,
restrictions_applied=[],
error_message=str(e)
)
async def _execute_proxied_request(
self,
config: IntegrationConfig,
request: ProxyRequest
) -> ProxyResponse:
"""Execute the actual HTTP request to external service"""
# Build URL
if request.endpoint.startswith('http'):
url = request.endpoint
else:
url = f"{config.base_url.rstrip('/')}/{request.endpoint.lstrip('/')}"
# Apply authentication
headers = request.headers.copy()
await self._apply_authentication(config, headers)
# Set timeout
timeout = request.timeout_override or config.timeout_seconds
try:
async with self.get_http_client() as client:
# Execute request
if request.method.upper() == "GET":
response = await client.get(
url,
headers=headers,
params=request.params,
timeout=timeout
)
elif request.method.upper() == "POST":
response = await client.post(
url,
headers=headers,
json=request.data,
params=request.params,
timeout=timeout
)
elif request.method.upper() == "PUT":
response = await client.put(
url,
headers=headers,
json=request.data,
params=request.params,
timeout=timeout
)
elif request.method.upper() == "DELETE":
response = await client.delete(
url,
headers=headers,
params=request.params,
timeout=timeout
)
else:
raise ValueError(f"Unsupported HTTP method: {request.method}")
# Check response size
if len(response.content) > config.max_response_size_bytes:
raise ValueError(f"Response size exceeds limit: {len(response.content)}")
# Parse response
try:
data = response.json() if response.content else {}
except json.JSONDecodeError:
data = {"raw_content": response.text}
                return ProxyResponse(
                    success=200 <= response.status_code < 300,
                    status_code=response.status_code,
                    data=data,
                    headers=dict(response.headers),
                    execution_time_ms=0,  # Will be set by caller
                    sandbox_applied=False,  # Will be set by caller
                    restrictions_applied=[]  # Will be set by caller
                )
except httpx.TimeoutException:
return ProxyResponse(
success=False,
status_code=408,
data=None,
headers={},
execution_time_ms=timeout * 1000,
sandbox_applied=False,
restrictions_applied=[],
error_message="Request timeout"
)
except Exception as e:
return ProxyResponse(
success=False,
status_code=500,
data=None,
headers={},
execution_time_ms=0,
sandbox_applied=False,
restrictions_applied=[],
error_message=str(e)
)
async def _apply_authentication(self, config: IntegrationConfig, headers: Dict[str, str]):
"""Apply authentication to request headers"""
auth_config = config.auth_config
if config.authentication_method == "api_key":
api_key = auth_config.get("api_key")
key_header = auth_config.get("key_header", "Authorization")
key_prefix = auth_config.get("key_prefix", "Bearer")
if api_key:
headers[key_header] = f"{key_prefix} {api_key}"
elif config.authentication_method == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
elif config.authentication_method == "oauth2":
access_token = auth_config.get("access_token")
if access_token:
headers["Authorization"] = f"Bearer {access_token}"
# Add custom headers
custom_headers = auth_config.get("custom_headers", {})
headers.update(custom_headers)
def _has_capability(self, token_data: Dict[str, Any], required_capability: str) -> bool:
"""Check if token has required capability"""
capabilities = token_data.get("capabilities", [])
for capability in capabilities:
if isinstance(capability, dict):
resource = capability.get("resource", "")
# Handle wildcard matching
if resource == required_capability:
return True
if resource.endswith("*"):
prefix = resource[:-1] # Remove the *
if required_capability.startswith(prefix):
return True
elif isinstance(capability, str):
# Handle wildcard matching for string capabilities
if capability == required_capability:
return True
if capability.endswith("*"):
prefix = capability[:-1] # Remove the *
if required_capability.startswith(prefix):
return True
return False
async def _load_integration_config(self, integration_id: str) -> Optional[IntegrationConfig]:
"""Load integration configuration from storage"""
config_file = self.configs_path / f"{integration_id}.json"
if not config_file.exists():
return None
try:
with open(config_file, "r") as f:
data = json.load(f)
return IntegrationConfig.from_dict(data)
except Exception as e:
logger.error(f"Failed to load integration config {integration_id}: {e}")
return None
async def store_integration_config(self, config: IntegrationConfig) -> bool:
"""Store integration configuration"""
config_file = self.configs_path / f"{config.id}.json"
try:
with open(config_file, "w") as f:
json.dump(config.to_dict(), f, indent=2)
# Set secure permissions
config_file.chmod(0o600)
return True
except Exception as e:
logger.error(f"Failed to store integration config {config.id}: {e}")
return False
async def _log_usage(
self,
integration_id: str,
tenant_id: str,
user_id: str,
method: str,
endpoint: str,
success: bool,
execution_time_ms: int,
error: Optional[str] = None
):
"""Log integration usage for analytics"""
date_str = datetime.utcnow().strftime("%Y-%m-%d")
usage_file = self.usage_path / f"usage_{date_str}.jsonl"
usage_record = {
"timestamp": datetime.utcnow().isoformat(),
"integration_id": integration_id,
"tenant_id": tenant_id,
"user_id": user_id,
"method": method,
"endpoint": endpoint,
"success": success,
"execution_time_ms": execution_time_ms,
"error": error
}
try:
with open(usage_file, "a") as f:
f.write(json.dumps(usage_record) + "\n")
# Set secure permissions on file
usage_file.chmod(0o600)
except Exception as e:
logger.error(f"Failed to log usage: {e}")
async def _audit_log(
self,
action: str,
integration_id: str,
user_id: str,
details: Dict[str, Any]
):
"""Log audit trail for integration actions"""
date_str = datetime.utcnow().strftime("%Y-%m-%d")
audit_file = self.audit_path / f"audit_{date_str}.jsonl"
audit_record = {
"timestamp": datetime.utcnow().isoformat(),
"action": action,
"integration_id": integration_id,
"user_id": user_id,
"details": details
}
try:
with open(audit_file, "a") as f:
f.write(json.dumps(audit_record) + "\n")
# Set secure permissions on file
audit_file.chmod(0o600)
except Exception as e:
logger.error(f"Failed to log audit: {e}")
async def list_integrations(self, capability_token: str) -> List[IntegrationConfig]:
"""List available integrations based on capabilities"""
token_obj = verify_capability_token(capability_token)
if not token_obj:
raise PermissionError("Invalid capability token")
# Convert token object to dict for compatibility
token_data = {
"tenant_id": token_obj.tenant_id,
"sub": token_obj.sub,
"capabilities": [cap.dict() if hasattr(cap, 'dict') else cap for cap in token_obj.capabilities],
"constraints": {}
}
integrations = []
for config_file in self.configs_path.glob("*.json"):
try:
with open(config_file, "r") as f:
data = json.load(f)
config = IntegrationConfig.from_dict(data)
# Check if user has capability for this integration
required_capability = f"integration:{config.id}:*"
if self._has_capability(token_data, required_capability):
integrations.append(config)
except Exception as e:
logger.warning(f"Failed to load integration config {config_file}: {e}")
return integrations
async def get_integration_usage_analytics(
self,
integration_id: str,
days: int = 30
) -> Dict[str, Any]:
"""Get usage analytics for integration"""
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=days-1) # Include today in the range
total_requests = 0
successful_requests = 0
total_execution_time = 0
error_count = 0
# Process usage logs
for day_offset in range(days):
date = start_date + timedelta(days=day_offset)
date_str = date.strftime("%Y-%m-%d")
usage_file = self.usage_path / f"usage_{date_str}.jsonl"
if usage_file.exists():
try:
with open(usage_file, "r") as f:
for line in f:
record = json.loads(line.strip())
if record["integration_id"] == integration_id:
total_requests += 1
if record["success"]:
successful_requests += 1
else:
error_count += 1
total_execution_time += record["execution_time_ms"]
except Exception as e:
logger.warning(f"Failed to process usage file {usage_file}: {e}")
return {
"integration_id": integration_id,
"total_requests": total_requests,
"successful_requests": successful_requests,
"error_count": error_count,
"success_rate": successful_requests / total_requests if total_requests > 0 else 0,
"avg_execution_time_ms": total_execution_time / total_requests if total_requests > 0 else 0,
"date_range": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
}
}
async def close(self):
"""Close HTTP client and cleanup resources"""
if self.http_client:
await self.http_client.aclose()
self.http_client = None
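
# Illustrative usage sketch (not wired into any endpoint): executes a GET
# through the proxy with sandboxing applied. The integration id and endpoint
# below are hypothetical; a matching config must first be stored via
# store_integration_config() and the token must grant
# "integration:example-github:get".
async def _example_proxy_call(capability_token: str) -> ProxyResponse:
    """Sketch only; never called by the service itself."""
    service = IntegrationProxyService()
    request = ProxyRequest(
        integration_id="example-github",  # hypothetical integration id
        method="GET",
        endpoint="/repos/example/repo/issues",
        params={"state": "open"},
    )
    try:
        return await service.execute_integration(request, capability_token)
    finally:
        await service.close()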

View File

@@ -0,0 +1,925 @@
"""
LLM Gateway Service for GT 2.0 Resource Cluster
Provides unified access to LLM providers with:
- Groq Cloud integration for fast inference
- OpenAI API compatibility
- Rate limiting and quota management
- Capability-based authentication
- Model routing and load balancing
- Response streaming support
GT 2.0 Architecture Principles:
- Stateless: No persistent connections or state
- Zero downtime: Circuit breakers and failover
- Self-contained: No external configuration dependencies
- Capability-based: JWT token authorization
"""
import asyncio
import logging
import json
import time
from typing import Dict, Any, List, Optional, AsyncGenerator, Union
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass, asdict
import uuid
import httpx
from enum import Enum
from urllib.parse import urlparse
from app.core.config import get_settings
def is_provider_endpoint(endpoint_url: str, provider_domains: List[str]) -> bool:
"""
Safely check if URL belongs to a specific provider.
Uses proper URL parsing to prevent bypass via URLs like
'evil.groq.com.attacker.com' or 'groq.com.evil.com'.
"""
try:
parsed = urlparse(endpoint_url)
hostname = (parsed.hostname or "").lower()
for domain in provider_domains:
domain = domain.lower()
# Match exact domain or subdomain (e.g., api.groq.com matches groq.com)
if hostname == domain or hostname.endswith(f".{domain}"):
return True
return False
except Exception:
return False
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.services.admin_model_config_service import get_admin_model_service, AdminModelConfig
logger = logging.getLogger(__name__)
settings = get_settings()
class ModelProvider(str, Enum):
"""Supported LLM providers"""
GROQ = "groq"
OPENAI = "openai"
ANTHROPIC = "anthropic"
NVIDIA = "nvidia"
LOCAL = "local"
class ModelCapability(str, Enum):
"""Model capabilities for routing"""
CHAT = "chat"
COMPLETION = "completion"
EMBEDDING = "embedding"
FUNCTION_CALLING = "function_calling"
VISION = "vision"
CODE = "code"
@dataclass
class ModelConfig:
"""Model configuration and capabilities"""
model_id: str
provider: ModelProvider
capabilities: List[ModelCapability]
max_tokens: int
context_window: int
cost_per_token: float
rate_limit_rpm: int
supports_streaming: bool
supports_functions: bool
is_available: bool = True
@dataclass
class LLMRequest:
"""Standardized LLM request format"""
model: str
messages: List[Dict[str, str]]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
functions: Optional[List[Dict[str, Any]]] = None
function_call: Optional[Union[str, Dict[str, str]]] = None
user: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API calls"""
result = asdict(self)
# Remove None values
return {k: v for k, v in result.items() if v is not None}
@dataclass
class LLMResponse:
"""Standardized LLM response format"""
id: str
object: str
created: int
model: str
choices: List[Dict[str, Any]]
usage: Dict[str, int]
provider: str
request_id: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API responses"""
return asdict(self)
class LLMGateway:
"""
LLM Gateway with unified API and multi-provider support.
Provides OpenAI-compatible API while routing to optimal providers
based on model capabilities, availability, and cost.
"""
def __init__(self):
self.settings = get_settings()
self.http_client = httpx.AsyncClient(timeout=120.0)
        self.admin_service = get_admin_model_service()
        # Built-in model registry used for routing and rate limits in chat_completion();
        # admin-configured models are resolved separately via the admin service
        self.models = self._initialize_model_configs()
# Rate limiting tracking
self.rate_limits: Dict[str, Dict[str, Any]] = {}
# Provider health tracking
self.provider_health: Dict[ModelProvider, bool] = {
provider: True for provider in ModelProvider
}
# Request statistics
self.stats = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"provider_usage": {provider.value: 0 for provider in ModelProvider},
"model_usage": {},
"average_latency": 0.0
}
logger.info("LLM Gateway initialized with admin-configured models")
    async def get_admin_models(self, tenant_id: Optional[str] = None) -> List[AdminModelConfig]:
        """Get admin-configured models, optionally filtered by tenant"""
if tenant_id:
return await self.admin_service.get_tenant_models(tenant_id)
else:
return await self.admin_service.get_all_models(active_only=True)
async def get_model_config(self, model_id: str, tenant_id: Optional[str] = None) -> Optional[AdminModelConfig]:
"""Get configuration for a specific model"""
config = await self.admin_service.get_model_config(model_id)
# Check tenant access if tenant_id provided
if config and tenant_id:
has_access = await self.admin_service.check_tenant_access(tenant_id, model_id)
if not has_access:
return None
return config
async def get_groq_api_key(self) -> Optional[str]:
"""Get Groq API key from admin service"""
return await self.admin_service.get_groq_api_key()
def _initialize_model_configs(self) -> Dict[str, ModelConfig]:
"""Initialize supported model configurations"""
models = {}
# Groq models (fast inference)
groq_models = [
ModelConfig(
model_id="llama3-8b-8192",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
max_tokens=8192,
context_window=8192,
cost_per_token=0.00001,
rate_limit_rpm=30,
supports_streaming=True,
supports_functions=False
),
ModelConfig(
model_id="llama3-70b-8192",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
max_tokens=8192,
context_window=8192,
cost_per_token=0.00008,
rate_limit_rpm=15,
supports_streaming=True,
supports_functions=False
),
ModelConfig(
model_id="mixtral-8x7b-32768",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT, ModelCapability.CODE],
max_tokens=32768,
context_window=32768,
cost_per_token=0.00005,
rate_limit_rpm=20,
supports_streaming=True,
supports_functions=False
),
ModelConfig(
model_id="gemma-7b-it",
provider=ModelProvider.GROQ,
capabilities=[ModelCapability.CHAT],
max_tokens=8192,
context_window=8192,
cost_per_token=0.00001,
rate_limit_rpm=30,
supports_streaming=True,
supports_functions=False
)
]
# OpenAI models (function calling, embeddings)
openai_models = [
ModelConfig(
model_id="gpt-4-turbo-preview",
provider=ModelProvider.OPENAI,
capabilities=[ModelCapability.CHAT, ModelCapability.FUNCTION_CALLING, ModelCapability.VISION],
max_tokens=4096,
context_window=128000,
cost_per_token=0.00003,
rate_limit_rpm=10,
supports_streaming=True,
supports_functions=True
),
ModelConfig(
model_id="gpt-3.5-turbo",
provider=ModelProvider.OPENAI,
capabilities=[ModelCapability.CHAT, ModelCapability.FUNCTION_CALLING],
max_tokens=4096,
context_window=16385,
cost_per_token=0.000002,
rate_limit_rpm=60,
supports_streaming=True,
supports_functions=True
),
ModelConfig(
model_id="text-embedding-3-small",
provider=ModelProvider.OPENAI,
capabilities=[ModelCapability.EMBEDDING],
max_tokens=8191,
context_window=8191,
cost_per_token=0.00000002,
rate_limit_rpm=3000,
supports_streaming=False,
supports_functions=False
)
]
# Add all models to registry
for model_list in [groq_models, openai_models]:
for model in model_list:
models[model.model_id] = model
return models
async def chat_completion(
self,
request: LLMRequest,
capability_token: str,
user_id: str,
tenant_id: str
) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""
Process chat completion request with capability validation.
Args:
request: LLM request parameters
capability_token: JWT capability token
user_id: User identifier for rate limiting
tenant_id: Tenant identifier for isolation
Returns:
LLM response or streaming generator
"""
start_time = time.time()
request_id = str(uuid.uuid4())
try:
# Verify capabilities
await self._verify_llm_capability(capability_token, request.model, user_id, tenant_id)
# Validate model availability
model_config = self.models.get(request.model)
if not model_config:
raise ValueError(f"Model {request.model} not supported")
if not model_config.is_available:
raise ValueError(f"Model {request.model} is currently unavailable")
# Check rate limits
await self._check_rate_limits(user_id, model_config)
# Route to configured endpoint (generic routing for any provider)
if hasattr(model_config, 'endpoint') and model_config.endpoint:
result = await self._process_generic_request(request, request_id, model_config, tenant_id)
elif model_config.provider == ModelProvider.GROQ:
result = await self._process_groq_request(request, request_id, model_config, tenant_id)
elif model_config.provider == ModelProvider.OPENAI:
result = await self._process_openai_request(request, request_id, model_config)
else:
raise ValueError(f"Provider {model_config.provider} not implemented - ensure endpoint is configured")
# Update statistics
latency = time.time() - start_time
await self._update_stats(request.model, model_config.provider, latency, True)
logger.info(f"LLM request completed: {request_id} ({latency:.3f}s)")
return result
except Exception as e:
latency = time.time() - start_time
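            # Provider may be unknown if the failure occurred before model lookup;
            # failures are attributed to GROQ by default for coarse statistics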
await self._update_stats(request.model, ModelProvider.GROQ, latency, False)
logger.error(f"LLM request failed: {request_id} - {e}")
raise
async def _verify_llm_capability(
self,
capability_token: str,
model: str,
user_id: str,
tenant_id: str
) -> None:
"""Verify user has capability to use specific model"""
try:
payload = await verify_capability_token(capability_token)
# Check tenant match
if payload.get("tenant_id") != tenant_id:
raise CapabilityError("Tenant mismatch in capability token")
# Find LLM capability (match "llm" or "llm:provider" format)
capabilities = payload.get("capabilities", [])
llm_capability = None
for cap in capabilities:
resource = cap.get("resource", "")
if resource == "llm" or resource.startswith("llm:"):
llm_capability = cap
break
if not llm_capability:
raise CapabilityError("No LLM capability found in token")
# Check model access
allowed_models = llm_capability.get("constraints", {}).get("allowed_models", [])
if allowed_models and model not in allowed_models:
raise CapabilityError(f"Model {model} not allowed in capability")
# Check rate limits (per-minute window)
max_requests_per_minute = llm_capability.get("constraints", {}).get("max_requests_per_minute")
if max_requests_per_minute:
await self._check_user_rate_limit(user_id, max_requests_per_minute)
except CapabilityError:
raise
except Exception as e:
raise CapabilityError(f"Capability verification failed: {e}")
async def _check_rate_limits(self, user_id: str, model_config: ModelConfig) -> None:
"""Check if user is within rate limits for model"""
now = time.time()
minute_ago = now - 60
# Initialize user rate limit tracking
if user_id not in self.rate_limits:
self.rate_limits[user_id] = {}
if model_config.model_id not in self.rate_limits[user_id]:
self.rate_limits[user_id][model_config.model_id] = []
user_requests = self.rate_limits[user_id][model_config.model_id]
# Remove old requests
user_requests[:] = [req_time for req_time in user_requests if req_time > minute_ago]
# Check limit
if len(user_requests) >= model_config.rate_limit_rpm:
raise ValueError(f"Rate limit exceeded for model {model_config.model_id}")
# Add current request
user_requests.append(now)
async def _check_user_rate_limit(self, user_id: str, max_requests_per_minute: int) -> None:
"""
Check user's rate limit with per-minute enforcement window.
Enforces limits from Control Panel database (single source of truth).
Time window: 60 seconds (not 1 hour).
Args:
user_id: User identifier
max_requests_per_minute: Maximum requests allowed in 60-second window
Raises:
ValueError: If rate limit exceeded
"""
now = time.time()
minute_ago = now - 60 # 60-second window (was 3600 for hour)
if user_id not in self.rate_limits:
self.rate_limits[user_id] = {}
if "total_requests" not in self.rate_limits[user_id]:
self.rate_limits[user_id]["total_requests"] = []
total_requests = self.rate_limits[user_id]["total_requests"]
# Remove requests outside the 60-second window
total_requests[:] = [req_time for req_time in total_requests if req_time > minute_ago]
# Check limit
if len(total_requests) >= max_requests_per_minute:
raise ValueError(
f"Rate limit exceeded: {max_requests_per_minute} requests per minute. "
f"Try again in {int(60 - (now - total_requests[0]))} seconds."
)
# Add current request
total_requests.append(now)
async def _process_groq_request(
self,
request: LLMRequest,
request_id: str,
model_config: ModelConfig,
tenant_id: str
) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""
Process request using Groq API with tenant-specific API key.
API keys are fetched from Control Panel database - NO environment variable fallback.
"""
try:
# Get API key from Control Panel database (NO env fallback)
api_key = await self._get_tenant_api_key(tenant_id)
# Prepare Groq API request
groq_request = {
"model": request.model,
"messages": request.messages,
"max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
"temperature": request.temperature or 0.7,
"top_p": request.top_p or 1.0,
"stream": request.stream
}
if request.stop:
groq_request["stop"] = request.stop
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
if request.stream:
return self._stream_groq_response(groq_request, headers, request_id)
else:
return await self._get_groq_response(groq_request, headers, request_id)
except Exception as e:
logger.error(f"Groq API request failed: {e}")
raise ValueError(f"Groq API error: {e}")
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
raise ValueError(f"No Groq API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
async def _get_tenant_nvidia_api_key(self, tenant_id: str) -> str:
"""
Get NVIDIA NIM API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
raise ValueError(f"No NVIDIA API key configured for tenant '{tenant_id}'. Please configure in Control Panel → API Keys.")
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
async def _get_groq_response(
self,
groq_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> LLMResponse:
"""Get non-streaming response from Groq"""
try:
response = await self.http_client.post(
"https://api.groq.com/openai/v1/chat/completions",
json=groq_request,
headers=headers
)
response.raise_for_status()
data = response.json()
# Convert to standardized format
return LLMResponse(
id=data.get("id", request_id),
object=data.get("object", "chat.completion"),
created=data.get("created", int(time.time())),
model=data.get("model", groq_request["model"]),
choices=data.get("choices", []),
usage=data.get("usage", {}),
provider="groq",
request_id=request_id
)
except httpx.HTTPStatusError as e:
logger.error(f"Groq API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"Groq API error: {e.response.status_code}")
except Exception as e:
logger.error(f"Groq API error: {e}")
raise ValueError(f"Groq API request failed: {e}")
async def _stream_groq_response(
self,
groq_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> AsyncGenerator[str, None]:
"""Stream response from Groq"""
try:
async with self.http_client.stream(
"POST",
"https://api.groq.com/openai/v1/chat/completions",
json=groq_request,
headers=headers
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
# Add provider and request_id to chunk
data["provider"] = "groq"
data["request_id"] = request_id
yield f"data: {json.dumps(data)}\n\n"
except json.JSONDecodeError:
continue
yield "data: [DONE]\n\n"
except httpx.HTTPStatusError as e:
logger.error(f"Groq streaming error: {e.response.status_code}")
yield f"data: {json.dumps({'error': f'Groq API error: {e.response.status_code}'})}\n\n"
except Exception as e:
logger.error(f"Groq streaming error: {e}")
yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"
async def _process_generic_request(
self,
request: LLMRequest,
request_id: str,
model_config: Any,
tenant_id: str
    ) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""
Process request using generic endpoint (OpenAI-compatible).
For Groq endpoints, API keys are fetched from Control Panel database.
"""
try:
# Build OpenAI-compatible request
generic_request = {
"model": request.model,
"messages": request.messages,
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"stream": request.stream
}
# Add optional parameters
if hasattr(request, 'tools') and request.tools:
generic_request["tools"] = request.tools
if hasattr(request, 'tool_choice') and request.tool_choice:
generic_request["tool_choice"] = request.tool_choice
headers = {"Content-Type": "application/json"}
endpoint_url = model_config.endpoint
# For Groq endpoints, use tenant-specific API key from Control Panel DB
if is_provider_endpoint(endpoint_url, ["groq.com"]):
api_key = await self._get_tenant_api_key(tenant_id)
headers["Authorization"] = f"Bearer {api_key}"
# For NVIDIA NIM endpoints, use tenant-specific API key from Control Panel DB
elif is_provider_endpoint(endpoint_url, ["nvidia.com", "integrate.api.nvidia.com"]):
api_key = await self._get_tenant_nvidia_api_key(tenant_id)
headers["Authorization"] = f"Bearer {api_key}"
# For other endpoints, use model_config.api_key if configured
elif hasattr(model_config, 'api_key') and model_config.api_key:
headers["Authorization"] = f"Bearer {model_config.api_key}"
logger.info(f"Sending request to generic endpoint: {endpoint_url}")
if request.stream:
                return self._stream_generic_response(generic_request, headers, endpoint_url, request_id, model_config)
else:
return await self._get_generic_response(generic_request, headers, endpoint_url, request_id, model_config)
except Exception as e:
logger.error(f"Generic request processing failed: {e}")
raise ValueError(f"Generic inference failed: {e}")
async def _get_generic_response(
self,
generic_request: Dict[str, Any],
headers: Dict[str, str],
endpoint_url: str,
request_id: str,
model_config: Any
) -> LLMResponse:
"""Get non-streaming response from generic endpoint"""
try:
response = await self.http_client.post(
endpoint_url,
json=generic_request,
headers=headers
)
response.raise_for_status()
data = response.json()
# Convert to standardized format
return LLMResponse(
id=data.get("id", request_id),
object=data.get("object", "chat.completion"),
created=data.get("created", int(time.time())),
model=data.get("model", generic_request["model"]),
choices=data.get("choices", []),
usage=data.get("usage", {}),
provider=getattr(model_config, 'provider', 'generic'),
request_id=request_id
)
except httpx.HTTPStatusError as e:
logger.error(f"Generic API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"Generic API error: {e.response.status_code}")
except Exception as e:
logger.error(f"Generic response error: {e}")
raise ValueError(f"Generic response processing failed: {e}")
async def _stream_generic_response(
self,
generic_request: Dict[str, Any],
headers: Dict[str, str],
endpoint_url: str,
request_id: str,
model_config: Any
):
"""Stream response from generic endpoint"""
try:
# For now, just do a non-streaming request and convert to streaming format
# This can be enhanced to support actual streaming later
response = await self._get_generic_response(generic_request, headers, endpoint_url, request_id, model_config)
# Convert to streaming format
if response.choices and len(response.choices) > 0:
content = response.choices[0].get("message", {}).get("content", "")
yield f"data: {json.dumps({'choices': [{'delta': {'content': content}}]})}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Generic streaming error: {e}")
yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"
async def _process_openai_request(
self,
request: LLMRequest,
request_id: str,
model_config: ModelConfig
) -> Union[LLMResponse, AsyncGenerator[str, None]]:
"""Process request using OpenAI API"""
try:
# Prepare OpenAI API request
openai_request = {
"model": request.model,
"messages": request.messages,
"max_tokens": min(request.max_tokens or 1024, model_config.max_tokens),
"temperature": request.temperature or 0.7,
"top_p": request.top_p or 1.0,
"stream": request.stream
}
if request.stop:
openai_request["stop"] = request.stop
headers = {
"Authorization": f"Bearer {settings.openai_api_key}",
"Content-Type": "application/json"
}
if request.stream:
return self._stream_openai_response(openai_request, headers, request_id)
else:
return await self._get_openai_response(openai_request, headers, request_id)
except Exception as e:
logger.error(f"OpenAI API request failed: {e}")
raise ValueError(f"OpenAI API error: {e}")
async def _get_openai_response(
self,
openai_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> LLMResponse:
"""Get non-streaming response from OpenAI"""
try:
response = await self.http_client.post(
"https://api.openai.com/v1/chat/completions",
json=openai_request,
headers=headers
)
response.raise_for_status()
data = response.json()
# Convert to standardized format
return LLMResponse(
id=data.get("id", request_id),
object=data.get("object", "chat.completion"),
created=data.get("created", int(time.time())),
model=data.get("model", openai_request["model"]),
choices=data.get("choices", []),
usage=data.get("usage", {}),
provider="openai",
request_id=request_id
)
except httpx.HTTPStatusError as e:
logger.error(f"OpenAI API HTTP error: {e.response.status_code} - {e.response.text}")
raise ValueError(f"OpenAI API error: {e.response.status_code}")
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise ValueError(f"OpenAI API request failed: {e}")
async def _stream_openai_response(
self,
openai_request: Dict[str, Any],
headers: Dict[str, str],
request_id: str
) -> AsyncGenerator[str, None]:
"""Stream response from OpenAI"""
try:
async with self.http_client.stream(
"POST",
"https://api.openai.com/v1/chat/completions",
json=openai_request,
headers=headers
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
# Add provider and request_id to chunk
data["provider"] = "openai"
data["request_id"] = request_id
yield f"data: {json.dumps(data)}\n\n"
except json.JSONDecodeError:
continue
yield "data: [DONE]\n\n"
except httpx.HTTPStatusError as e:
logger.error(f"OpenAI streaming error: {e.response.status_code}")
yield f"data: {json.dumps({'error': f'OpenAI API error: {e.response.status_code}'})}\n\n"
except Exception as e:
logger.error(f"OpenAI streaming error: {e}")
yield f"data: {json.dumps({'error': f'Streaming error: {e}'})}\n\n"
async def _update_stats(
self,
model: str,
provider: ModelProvider,
latency: float,
success: bool
) -> None:
"""Update request statistics"""
self.stats["total_requests"] += 1
if success:
self.stats["successful_requests"] += 1
else:
self.stats["failed_requests"] += 1
self.stats["provider_usage"][provider.value] += 1
if model not in self.stats["model_usage"]:
self.stats["model_usage"][model] = 0
self.stats["model_usage"][model] += 1
# Update rolling average latency
total_requests = self.stats["total_requests"]
current_avg = self.stats["average_latency"]
self.stats["average_latency"] = ((current_avg * (total_requests - 1)) + latency) / total_requests
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available models with capabilities"""
models = []
for model_id, config in self.models.items():
if config.is_available:
models.append({
"id": model_id,
"provider": config.provider.value,
"capabilities": [cap.value for cap in config.capabilities],
"max_tokens": config.max_tokens,
"context_window": config.context_window,
"supports_streaming": config.supports_streaming,
"supports_functions": config.supports_functions
})
return models
async def get_gateway_stats(self) -> Dict[str, Any]:
"""Get gateway statistics"""
return {
**self.stats,
"provider_health": {
provider.value: health
for provider, health in self.provider_health.items()
},
"active_models": len([m for m in self.models.values() if m.is_available]),
"timestamp": datetime.now(timezone.utc).isoformat()
}
async def health_check(self) -> Dict[str, Any]:
"""Health check for the LLM gateway"""
healthy_providers = sum(1 for health in self.provider_health.values() if health)
total_providers = len(self.provider_health)
return {
"status": "healthy" if healthy_providers > 0 else "degraded",
"providers_healthy": healthy_providers,
"total_providers": total_providers,
"available_models": len([m for m in self.models.values() if m.is_available]),
"total_requests": self.stats["total_requests"],
"success_rate": (
self.stats["successful_requests"] / max(self.stats["total_requests"], 1)
) * 100,
"average_latency_ms": self.stats["average_latency"] * 1000
}
async def close(self):
"""Close HTTP client and cleanup resources"""
await self.http_client.aclose()
# Global gateway instance
llm_gateway = LLMGateway()
# Factory function for dependency injection
def get_llm_gateway() -> LLMGateway:
"""Get LLM gateway instance"""
return llm_gateway
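
# Illustrative usage sketch (not wired into any endpoint): a non-streaming chat
# completion through the gateway. The model, user and tenant identifiers below
# are assumptions; availability depends on admin configuration and the
# capability token's constraints.
async def _example_chat_completion(capability_token: str) -> None:
    """Sketch only; never called by the gateway itself."""
    gateway = get_llm_gateway()
    request = LLMRequest(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": "Summarize GT 2.0 in one sentence."}],
        max_tokens=128,
        temperature=0.2,
    )
    response = await gateway.chat_completion(
        request,
        capability_token=capability_token,
        user_id="example-user",      # hypothetical identifiers
        tenant_id="example-tenant",
    )
    logger.info("Gateway response id=%s usage=%s", response.id, response.usage)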

View File

@@ -0,0 +1,599 @@
"""
GT 2.0 MCP RAG Server
Provides RAG (Retrieval-Augmented Generation) capabilities as an MCP server.
Agents can use this server to search datasets, query documents, and retrieve
relevant context for user queries.
Tools provided:
- search_datasets: Search across user's accessible datasets
- query_documents: Query specific documents for relevant chunks
- get_relevant_chunks: Get relevant text chunks based on similarity
- list_user_datasets: List all datasets accessible to the user
- get_dataset_info: Get detailed information about a dataset
"""
import asyncio
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from dataclasses import dataclass
import httpx
import json
from app.core.security import verify_capability_token
from app.services.mcp_server import MCPServerResource, MCPServerConfig
logger = logging.getLogger(__name__)
@dataclass
class RAGSearchParams:
"""Parameters for RAG search operations"""
query: str
dataset_ids: Optional[List[str]] = None
search_method: str = "hybrid" # hybrid, vector, text
max_results: int = 10
similarity_threshold: float = 0.7
include_metadata: bool = True
@dataclass
class RAGSearchResult:
"""Result from RAG search operation"""
chunk_id: str
document_id: str
dataset_id: str
dataset_name: str
document_name: str
content: str
similarity_score: float
chunk_index: int
metadata: Dict[str, Any]
class MCPRAGServer:
"""
MCP server for RAG operations in GT 2.0.
Provides secure, tenant-isolated access to document search capabilities
through standardized MCP tool interfaces.
"""
def __init__(self, tenant_backend_url: str = "http://tenant-backend:8000"):
self.tenant_backend_url = tenant_backend_url
self.server_name = "rag_server"
self.server_type = "rag"
# Define available tools (streamlined for simplicity)
self.available_tools = [
"search_datasets"
]
# Tool schemas for MCP protocol (enhanced with flexible parameters)
self.tool_schemas = {
"search_datasets": {
"name": "search_datasets",
"description": "Search through datasets containing uploaded documents, PDFs, and files. Use when users ask about documentation, reference materials, checking files, looking up information, need data from uploaded content, want to know what's in the dataset, search our data, check if we have something, or look through files.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "What to search for in the datasets"
},
"dataset_ids": {
"type": "array",
"items": {"type": "string"},
"description": "(Optional) List of specific dataset IDs to search within"
},
"file_pattern": {
"type": "string",
"description": "(Optional) File pattern filter (e.g., '*.pdf', '*.txt')"
},
"search_all": {
"type": "boolean",
"default": False,
"description": "(Optional) Search across all accessible datasets (ignores dataset_ids)"
},
"max_results": {
"type": "integer",
"default": 10,
"description": "(Optional) Number of results to return (default: 10)"
}
},
"required": ["query"]
}
}
}
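# Illustrative tool-call payload that satisfies the schema above (values are made up):
# {
#   "tool_name": "search_datasets",
#   "parameters": {"query": "quarterly revenue figures", "search_all": true, "max_results": 5}
# }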
async def handle_tool_call(
self,
tool_name: str,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str,
agent_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Handle MCP tool call with tenant isolation and user context.
Args:
tool_name: Name of the tool being called
parameters: Tool parameters from the LLM
tenant_domain: Tenant domain for isolation
user_id: User making the request
agent_context: Optional agent configuration (e.g. the agent's selected dataset IDs)
Returns:
Tool execution result or error
"""
logger.info(f"🚀 MCP RAG Server: handle_tool_call called - tool={tool_name}, tenant={tenant_domain}, user={user_id}")
logger.info(f"📝 MCP RAG Server: parameters={parameters}")
try:
# Validate tool exists
if tool_name not in self.available_tools:
return {
"error": f"Unknown tool: {tool_name}",
"tool_name": tool_name
}
# Route to appropriate handler
if tool_name == "search_datasets":
return await self._search_datasets(parameters, tenant_domain, user_id, agent_context)
else:
return {
"error": f"Tool handler not implemented: {tool_name}",
"tool_name": tool_name
}
except Exception as e:
logger.error(f"Error handling tool call {tool_name}: {e}")
return {
"error": f"Tool execution failed: {str(e)}",
"tool_name": tool_name
}
def _verify_user_access(self, user_id: str, tenant_domain: str) -> bool:
"""Verify user has access to tenant resources (simplified check)"""
# In a real system, this would query the database to verify
# that the user has access to the tenant's resources
# For now, we trust that the tenant backend has already verified this
return bool(user_id and tenant_domain)
async def _search_datasets(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str,
agent_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Search across user's datasets"""
logger.info(f"🔍 RAG Server: search_datasets called for user {user_id} in tenant {tenant_domain}")
logger.info(f"📝 RAG Server: search parameters = {parameters}")
logger.info(f"📝 RAG Server: parameter types: {[(k, type(v)) for k, v in parameters.items()]}")
try:
query = parameters.get("query", "").strip()
list_mode = parameters.get("list_mode", False)
# Handle list mode - list available datasets instead of searching
if list_mode:
logger.info(f"🔍 RAG Server: List mode activated - fetching available datasets")
async with httpx.AsyncClient(timeout=15.0) as client:
response = await client.get(
f"{self.tenant_backend_url}/api/v1/datasets/internal/list",
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
datasets = response.json()
logger.info(f"✅ RAG Server: Successfully listed {len(datasets)} datasets")
return {
"success": True,
"datasets": datasets,
"total_count": len(datasets),
"list_mode": True
}
else:
logger.error(f"❌ RAG Server: Failed to list datasets: {response.status_code} - {response.text}")
return {"error": f"Failed to list datasets: {response.status_code}"}
# Normal search mode
if not query:
logger.error("❌ RAG Server: Query parameter is required")
return {"error": "Query parameter is required"}
# Prepare search request with enhanced parameters
dataset_ids = parameters.get("dataset_ids")
file_pattern = parameters.get("file_pattern")
search_all = parameters.get("search_all", False)
# Handle legacy dataset_id parameter (backwards compatibility)
if dataset_ids is None and parameters.get("dataset_id"):
dataset_ids = [parameters.get("dataset_id")]
# Ensure dataset_ids is properly formatted
if dataset_ids is None:
dataset_ids = []
elif isinstance(dataset_ids, str):
dataset_ids = [dataset_ids]
# If search_all is True, ignore dataset_ids filter
if search_all:
dataset_ids = []
# AGENT-AWARE: If no datasets specified, use agent's configured datasets
if not dataset_ids and not search_all and agent_context:
agent_dataset_ids = agent_context.get('selected_dataset_ids', [])
if agent_dataset_ids:
dataset_ids = agent_dataset_ids
agent_name = agent_context.get('agent_name', 'Unknown')
logger.info(f"✅ RAG Server: Using agent '{agent_name}' datasets: {dataset_ids}")
else:
logger.warning(f"⚠️ RAG Server: Agent context available but no datasets configured")
elif not dataset_ids and not search_all:
logger.warning(f"⚠️ RAG Server: No dataset_ids provided and no agent context available")
search_request = {
"query": query,
"search_type": parameters.get("search_method", "hybrid"),
"max_results": parameters.get("max_results", 10), # No arbitrary cap
"dataset_ids": dataset_ids,
"min_similarity": 0.3
}
# Add file_pattern if provided
if file_pattern:
search_request["file_pattern"] = file_pattern
logger.info(f"🎯 RAG Server: prepared search request = {search_request}")
# Call tenant backend search API
logger.info(f"🌐 RAG Server: calling tenant backend at {self.tenant_backend_url}/api/v1/search/")
logger.info(f"🌐 RAG Server: request headers: X-Tenant-Domain='{tenant_domain}', X-User-ID='{user_id}'")
logger.info(f"🌐 RAG Server: request body: {search_request}")
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.tenant_backend_url}/api/v1/search/",
json=search_request,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id,
"Content-Type": "application/json"
}
)
logger.info(f"📊 RAG Server: tenant backend response: {response.status_code}")
if response.status_code != 200:
logger.error(f"📊 RAG Server: error response body: {response.text}")
if response.status_code == 200:
data = response.json()
logger.info(f"✅ RAG Server: search successful, got {len(data.get('results', []))} results")
# Format results for MCP response
results = []
for result in data.get("results", []):
results.append({
"chunk_id": result.get("chunk_id"),
"document_id": result.get("document_id"),
"dataset_id": result.get("dataset_id"),
"content": result.get("text", ""),
"similarity_score": result.get("hybrid_score", 0.0),
"metadata": result.get("metadata", {})
})
return {
"success": True,
"query": query,
"results_count": len(results),
"results": results,
"search_method": data.get("search_type", "hybrid")
}
else:
error_text = response.text
logger.error(f"❌ RAG Server: search failed: {response.status_code} - {error_text}")
return {
"error": f"Search failed: {response.status_code} - {error_text}",
"query": query
}
except Exception as e:
logger.error(f"Dataset search error: {e}")
return {
"error": f"Search operation failed: {str(e)}",
"query": parameters.get("query", "")
}
async def _query_documents(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""Query specific documents for relevant chunks"""
try:
query = parameters.get("query", "").strip()
document_ids = parameters.get("document_ids", [])
if not query or not document_ids:
return {"error": "Both query and document_ids are required"}
# Use search API with document ID filter
search_request = {
"query": query,
"search_type": "hybrid",
"max_results": parameters.get("max_results", 5),
"document_ids": document_ids
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.tenant_backend_url}/api/v1/search/documents",
json=search_request,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id,
"Content-Type": "application/json"
}
)
if response.status_code == 200:
data = response.json()
return {
"success": True,
"query": query,
"document_ids": document_ids,
"results": data.get("results", [])
}
else:
return {
"error": f"Document query failed: {response.status_code}",
"query": query,
"document_ids": document_ids
}
except Exception as e:
return {
"error": f"Document query failed: {str(e)}",
"query": parameters.get("query", "")
}
async def _list_user_datasets(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""List user's accessible datasets"""
try:
include_stats = parameters.get("include_stats", True)
async with httpx.AsyncClient(timeout=15.0) as client:
params = {"include_stats": include_stats}
response = await client.get(
f"{self.tenant_backend_url}/api/v1/datasets",
params=params,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
data = response.json()
datasets = data.get("data", []) if isinstance(data, dict) else data
# Format for MCP response
formatted_datasets = []
for dataset in datasets:
formatted_datasets.append({
"id": dataset.get("id"),
"name": dataset.get("name"),
"description": dataset.get("description"),
"document_count": dataset.get("document_count", 0),
"chunk_count": dataset.get("chunk_count", 0),
"created_at": dataset.get("created_at"),
"access_group": dataset.get("access_group", "individual")
})
return {
"success": True,
"datasets": formatted_datasets,
"total_count": len(formatted_datasets)
}
else:
return {
"error": f"Failed to list datasets: {response.status_code}"
}
except Exception as e:
return {
"error": f"Failed to list datasets: {str(e)}"
}
async def _get_dataset_info(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""Get detailed information about a dataset"""
try:
dataset_id = parameters.get("dataset_id")
if not dataset_id:
return {"error": "dataset_id parameter is required"}
async with httpx.AsyncClient(timeout=15.0) as client:
response = await client.get(
f"{self.tenant_backend_url}/api/v1/datasets/{dataset_id}",
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
data = response.json()
dataset = data.get("data", data)
return {
"success": True,
"dataset": {
"id": dataset.get("id"),
"name": dataset.get("name"),
"description": dataset.get("description"),
"document_count": dataset.get("document_count", 0),
"chunk_count": dataset.get("chunk_count", 0),
"vector_count": dataset.get("vector_count", 0),
"storage_size_mb": dataset.get("storage_size_mb", 0),
"created_at": dataset.get("created_at"),
"updated_at": dataset.get("updated_at"),
"access_group": dataset.get("access_group"),
"tags": dataset.get("tags", [])
}
}
elif response.status_code == 404:
return {
"error": f"Dataset not found: {dataset_id}"
}
else:
return {
"error": f"Failed to get dataset info: {response.status_code}"
}
except Exception as e:
return {
"error": f"Failed to get dataset info: {str(e)}"
}
async def _get_user_agent_datasets(self, tenant_domain: str, user_id: str) -> List[str]:
"""Auto-detect agent datasets for the current user"""
try:
# Get user's agents and their configured datasets
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"{self.tenant_backend_url}/api/v1/agents",
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id
}
)
if response.status_code == 200:
agents_data = response.json()
agents = agents_data.get("data", []) if isinstance(agents_data, dict) else agents_data
# Collect all dataset IDs from all user's agents
all_dataset_ids = set()
for agent in agents:
agent_dataset_ids = agent.get("selected_dataset_ids", [])
if agent_dataset_ids:
all_dataset_ids.update(agent_dataset_ids)
logger.info(f"🔍 RAG Server: Agent {agent.get('name', 'unknown')} has datasets: {agent_dataset_ids}")
return list(all_dataset_ids)
else:
logger.warning(f"⚠️ RAG Server: Failed to get agents: {response.status_code}")
return []
except Exception as e:
logger.error(f"❌ RAG Server: Error getting user agent datasets: {e}")
return []
async def _get_relevant_chunks(
self,
parameters: Dict[str, Any],
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""Get most relevant chunks for a query"""
try:
query = parameters.get("query", "").strip()
if not query:
return {"error": "query parameter is required"}
chunk_count = min(parameters.get("chunk_count", 3), 10) # Cap at 10
min_similarity = parameters.get("min_similarity", 0.6)
dataset_ids = parameters.get("dataset_ids")
search_request = {
"query": query,
"search_type": "vector", # Use vector search for relevance
"max_results": chunk_count,
"min_similarity": min_similarity,
"dataset_ids": dataset_ids
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.tenant_backend_url}/api/v1/search",
json=search_request,
headers={
"X-Tenant-Domain": tenant_domain,
"X-User-ID": user_id,
"Content-Type": "application/json"
}
)
if response.status_code == 200:
data = response.json()
chunks = []
for result in data.get("results", []):
chunks.append({
"chunk_id": result.get("chunk_id"),
"document_id": result.get("document_id"),
"dataset_id": result.get("dataset_id"),
"content": result.get("text", ""),
"similarity_score": result.get("vector_similarity", 0.0),
"chunk_index": result.get("rank", 0),
"metadata": result.get("metadata", {})
})
return {
"success": True,
"query": query,
"chunks": chunks,
"chunk_count": len(chunks),
"min_similarity": min_similarity
}
else:
return {
"error": f"Chunk retrieval failed: {response.status_code}"
}
except Exception as e:
return {
"error": f"Failed to get relevant chunks: {str(e)}"
}
def get_server_config(self) -> MCPServerConfig:
"""Get MCP server configuration"""
return MCPServerConfig(
server_name=self.server_name,
server_url="internal://mcp-rag-server",
server_type=self.server_type,
available_tools=self.available_tools,
required_capabilities=["mcp:rag:*"],
sandbox_mode=True,
max_memory_mb=256,
max_cpu_percent=25,
timeout_seconds=30,
network_isolation=False, # Needs to access tenant backend
max_requests_per_minute=120,
max_concurrent_requests=10
)
def get_tool_schemas(self) -> Dict[str, Any]:
"""Get MCP tool schemas for this server"""
return self.tool_schemas
# Global instance
mcp_rag_server = MCPRAGServer()
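# Illustrative invocation of the global server (a sketch; tenant, user, and agent
# values below are assumptions, not part of this module):
# result = await mcp_rag_server.handle_tool_call(
#     tool_name="search_datasets",
#     parameters={"query": "onboarding checklist", "max_results": 5},
#     tenant_domain="acme.example",
#     user_id="user-123",
#     agent_context={"agent_name": "Docs Agent", "selected_dataset_ids": ["ds-1"]},
# )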

View File

@@ -0,0 +1,491 @@
"""
MCP Sandbox Service for GT 2.0
Provides secure sandboxed execution environment for MCP servers.
Implements resource isolation, monitoring, and security constraints.
"""
import os
import asyncio
import resource
import signal
import tempfile
import shutil
from typing import Dict, Any, Optional, Callable, Tuple
from datetime import datetime, timedelta
from pathlib import Path
import logging
import json
import psutil
from contextlib import asynccontextmanager
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class SandboxConfig:
"""Configuration for sandbox environment"""
# Resource limits
max_memory_mb: int = 512
max_cpu_percent: int = 50
max_disk_mb: int = 100
timeout_seconds: int = 30
# Security settings
network_isolation: bool = True
readonly_filesystem: bool = False
allowed_paths: list = None
blocked_paths: list = None
allowed_commands: list = None
# Process limits
max_processes: int = 10
max_open_files: int = 100
max_threads: int = 20
def __post_init__(self):
if self.allowed_paths is None:
self.allowed_paths = ["/tmp", "/var/tmp"]
if self.blocked_paths is None:
self.blocked_paths = ["/etc", "/root", "/home", "/usr/bin", "/usr/sbin"]
if self.allowed_commands is None:
self.allowed_commands = ["ls", "cat", "grep", "find", "echo", "pwd"]
class ProcessSandbox:
"""
Process-level sandbox for MCP tool execution
Uses OS-level isolation and resource limits
"""
def __init__(self, config: SandboxConfig):
self.config = config
self.process: Optional[asyncio.subprocess.Process] = None
self.start_time: Optional[datetime] = None
self.temp_dir: Optional[Path] = None
self.resource_monitor_task: Optional[asyncio.Task] = None
async def __aenter__(self):
"""Enter sandbox context"""
await self.setup()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit sandbox context and cleanup"""
await self.cleanup()
async def setup(self):
"""Setup sandbox environment"""
# Create temporary directory for sandbox
self.temp_dir = Path(tempfile.mkdtemp(prefix="mcp_sandbox_"))
os.chmod(self.temp_dir, 0o700) # Restrict access
# Set resource limits for child processes
self._set_resource_limits()
# Start resource monitoring
self.resource_monitor_task = asyncio.create_task(self._monitor_resources())
self.start_time = datetime.utcnow()
logger.info(f"Sandbox setup complete: {self.temp_dir}")
async def cleanup(self):
"""Cleanup sandbox environment"""
# Stop resource monitoring
if self.resource_monitor_task:
self.resource_monitor_task.cancel()
try:
await self.resource_monitor_task
except asyncio.CancelledError:
pass
# Terminate process if still running
if self.process and self.process.returncode is None:
try:
self.process.terminate()
await asyncio.wait_for(self.process.wait(), timeout=5)
except asyncio.TimeoutError:
self.process.kill()
await self.process.wait()
# Remove temporary directory
if self.temp_dir and self.temp_dir.exists():
shutil.rmtree(self.temp_dir, ignore_errors=True)
logger.info("Sandbox cleanup complete")
async def execute(
self,
command: str,
args: list = None,
input_data: str = None,
env: Dict[str, str] = None
) -> Tuple[int, str, str]:
"""
Execute command in sandbox
Args:
command: Command to execute
args: Command arguments
input_data: Input to send to process
env: Environment variables
Returns:
Tuple of (return_code, stdout, stderr)
"""
# Validate command
if not self._validate_command(command):
raise PermissionError(f"Command not allowed: {command}")
# Prepare environment
sandbox_env = self._prepare_environment(env)
# Prepare command with arguments
full_command = [command] + (args or [])
try:
# Create process with resource limits
self.process = await asyncio.create_subprocess_exec(
*full_command,
stdin=asyncio.subprocess.PIPE if input_data else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=str(self.temp_dir),
env=sandbox_env,
preexec_fn=self._set_process_limits if os.name == 'posix' else None
)
# Execute with timeout
stdout, stderr = await asyncio.wait_for(
self.process.communicate(input=input_data.encode() if input_data else None),
timeout=self.config.timeout_seconds
)
return self.process.returncode, stdout.decode(), stderr.decode()
except asyncio.TimeoutError:
if self.process:
self.process.kill()
await self.process.wait()
raise TimeoutError(f"Command exceeded {self.config.timeout_seconds}s timeout")
except Exception as e:
logger.error(f"Sandbox execution error: {e}")
raise
async def execute_function(
self,
func: Callable,
*args,
**kwargs
) -> Any:
"""
Execute Python function in sandbox
Uses multiprocessing for isolation
"""
import multiprocessing
import pickle
# Create pipe for communication
parent_conn, child_conn = multiprocessing.Pipe()
def sandbox_wrapper(conn, func, args, kwargs):
"""Wrapper to execute function in child process"""
try:
# Apply resource limits
self._set_process_limits()
# Execute function
result = func(*args, **kwargs)
# Send result back
conn.send(("success", pickle.dumps(result)))
except Exception as e:
conn.send(("error", str(e)))
finally:
conn.close()
# Create and start process
process = multiprocessing.Process(
target=sandbox_wrapper,
args=(child_conn, func, args, kwargs)
)
process.start()
# Wait for result with timeout
try:
if parent_conn.poll(self.config.timeout_seconds):
status, data = parent_conn.recv()
if status == "success":
return pickle.loads(data)
else:
raise RuntimeError(f"Sandbox function error: {data}")
else:
process.terminate()
process.join(timeout=5)
if process.is_alive():
process.kill()
raise TimeoutError(f"Function exceeded {self.config.timeout_seconds}s timeout")
finally:
parent_conn.close()
if process.is_alive():
process.terminate()
process.join()
def _validate_command(self, command: str) -> bool:
"""Validate if command is allowed"""
# Check if command is in allowed list
command_name = os.path.basename(command)
if self.config.allowed_commands and command_name not in self.config.allowed_commands:
return False
# Check for dangerous patterns
dangerous_patterns = [
"rm -rf",
"dd if=",
"mkfs",
"format",
">", # Redirect that could overwrite files
"|", # Pipe that could chain commands
";", # Command separator
"&", # Background execution
"`", # Command substitution
"$(" # Command substitution
]
for pattern in dangerous_patterns:
if pattern in command:
return False
return True
def _prepare_environment(self, custom_env: Dict[str, str] = None) -> Dict[str, str]:
"""Prepare sandboxed environment variables"""
# Start with minimal environment
sandbox_env = {
"PATH": "/usr/local/bin:/usr/bin:/bin",
"HOME": str(self.temp_dir),
"TEMP": str(self.temp_dir),
"TMP": str(self.temp_dir),
"USER": "sandbox",
"SHELL": "/bin/sh"
}
# Add custom environment variables if provided
if custom_env:
# Filter out dangerous variables
dangerous_vars = ["LD_PRELOAD", "LD_LIBRARY_PATH", "PYTHONPATH", "PATH"]
for key, value in custom_env.items():
if key not in dangerous_vars:
sandbox_env[key] = value
return sandbox_env
def _set_resource_limits(self):
"""Set resource limits for the process"""
if os.name != 'posix':
return # Resource limits only work on POSIX systems
# Memory limit
memory_bytes = self.config.max_memory_mb * 1024 * 1024
resource.setrlimit(resource.RLIMIT_AS, (memory_bytes, memory_bytes))
# CPU time limit
resource.setrlimit(resource.RLIMIT_CPU, (self.config.timeout_seconds, self.config.timeout_seconds))
# File size limit
file_size_bytes = self.config.max_disk_mb * 1024 * 1024
resource.setrlimit(resource.RLIMIT_FSIZE, (file_size_bytes, file_size_bytes))
# Process limit
resource.setrlimit(resource.RLIMIT_NPROC, (self.config.max_processes, self.config.max_processes))
# Open files limit
resource.setrlimit(resource.RLIMIT_NOFILE, (self.config.max_open_files, self.config.max_open_files))
def _set_process_limits(self):
"""Set limits for child process (called in child context)"""
if os.name != 'posix':
return
# Drop privileges if running as root (shouldn't happen in production).
# The group must be dropped before the user: once the UID is no longer 0,
# setgid() would be denied.
if os.getuid() == 0:
os.setgid(65534)  # nogroup
os.setuid(65534)  # nobody user
# Set resource limits
self._set_resource_limits()
# Set process group for easier cleanup
os.setpgrp()
async def _monitor_resources(self):
"""Monitor resource usage of sandboxed process"""
while True:
try:
if self.process and self.process.returncode is None:
# Get process info
try:
proc = psutil.Process(self.process.pid)
# Check CPU usage
cpu_percent = proc.cpu_percent(interval=0.1)
if cpu_percent > self.config.max_cpu_percent:
logger.warning(f"Sandbox CPU usage high: {cpu_percent}%")
# Could throttle or terminate if consistently high
# Check memory usage
memory_info = proc.memory_info()
memory_mb = memory_info.rss / (1024 * 1024)
if memory_mb > self.config.max_memory_mb:
logger.warning(f"Sandbox memory limit exceeded: {memory_mb}MB")
self.process.terminate()
break
# Check runtime
if self.start_time:
runtime = (datetime.utcnow() - self.start_time).total_seconds()
if runtime > self.config.timeout_seconds:
logger.warning(f"Sandbox timeout exceeded: {runtime}s")
self.process.terminate()
break
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass # Process ended or inaccessible
await asyncio.sleep(1) # Check every second
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Resource monitoring error: {e}")
await asyncio.sleep(1)
class ContainerSandbox:
"""
Container-based sandbox for stronger isolation
Uses Docker or Podman for execution
"""
def __init__(self, config: SandboxConfig):
self.config = config
self.container_id: Optional[str] = None
self.container_runtime = self._detect_container_runtime()
def _detect_container_runtime(self) -> str:
"""Detect available container runtime"""
# Try Docker first
if shutil.which("docker"):
return "docker"
# Try Podman as alternative
elif shutil.which("podman"):
return "podman"
else:
logger.warning("No container runtime found, falling back to process sandbox")
return None
@asynccontextmanager
async def create_container(self, image: str = "alpine:latest"):
"""Create and manage container lifecycle"""
if not self.container_runtime:
raise RuntimeError("No container runtime available")
try:
# Create container with resource limits
create_cmd = [
self.container_runtime, "create",
"--rm", # Auto-remove after stop
f"--memory={self.config.max_memory_mb}m",
f"--cpus={self.config.max_cpu_percent / 100}",
"--network=none" if self.config.network_isolation else "--network=bridge",
"--read-only" if self.config.readonly_filesystem else "",
f"--tmpfs=/tmp:size={self.config.max_disk_mb}m",
"--security-opt=no-new-privileges",
"--cap-drop=ALL", # Drop all capabilities
image,
"sleep", "infinity" # Keep container running
]
# Remove empty strings from command
create_cmd = [arg for arg in create_cmd if arg]
# Create container
proc = await asyncio.create_subprocess_exec(
*create_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise RuntimeError(f"Failed to create container: {stderr.decode()}")
self.container_id = stdout.decode().strip()
# Start container
start_cmd = [self.container_runtime, "start", self.container_id]
proc = await asyncio.create_subprocess_exec(*start_cmd)
await proc.wait()
logger.info(f"Container sandbox created: {self.container_id[:12]}")
yield self
finally:
# Cleanup container
if self.container_id:
stop_cmd = [self.container_runtime, "stop", self.container_id]
proc = await asyncio.create_subprocess_exec(*stop_cmd)
await proc.wait()
logger.info(f"Container sandbox cleaned up: {self.container_id[:12]}")
async def execute(self, command: str, args: list = None) -> Tuple[int, str, str]:
"""Execute command in container"""
if not self.container_id:
raise RuntimeError("Container not created")
exec_cmd = [
self.container_runtime, "exec",
self.container_id,
command
] + (args or [])
proc = await asyncio.create_subprocess_exec(
*exec_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=self.config.timeout_seconds
)
return proc.returncode, stdout.decode(), stderr.decode()
except asyncio.TimeoutError:
# Kill process in container
kill_cmd = [self.container_runtime, "exec", self.container_id, "kill", "-9", "-1"]
await asyncio.create_subprocess_exec(*kill_cmd)
raise TimeoutError(f"Command exceeded {self.config.timeout_seconds}s timeout")
# Factory function to get appropriate sandbox
def create_sandbox(config: SandboxConfig, prefer_container: bool = True) -> Any:
"""
Create appropriate sandbox based on availability and preference
Args:
config: Sandbox configuration
prefer_container: Prefer container over process sandbox
Returns:
ProcessSandbox or ContainerSandbox instance
"""
if prefer_container and shutil.which("docker"):
return ContainerSandbox(config)
elif prefer_container and shutil.which("podman"):
return ContainerSandbox(config)
else:
return ProcessSandbox(config)
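# Illustrative usage of the factory (a sketch; the command and limits are assumptions):
# async def _demo_sandbox():
#     cfg = SandboxConfig(max_memory_mb=128, timeout_seconds=10)
#     sandbox = create_sandbox(cfg, prefer_container=False)  # yields a ProcessSandbox
#     async with sandbox:
#         rc, out, err = await sandbox.execute("echo", ["hello"])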

View File

@@ -0,0 +1,698 @@
"""
MCP Server Resource Wrapper for GT 2.0
Encapsulates MCP (Model Context Protocol) servers as GT 2.0 resources.
Provides security sandboxing and capability-based access control.
"""
from typing import List, Optional, Dict, Any, AsyncIterator
from datetime import datetime, timedelta
from enum import Enum
import asyncio
import logging
import json
from dataclasses import dataclass
from app.models.access_group import AccessGroup, Resource
from app.core.security import verify_capability_token
logger = logging.getLogger(__name__)
class MCPServerStatus(str, Enum):
"""MCP server operational status"""
HEALTHY = "healthy"
DEGRADED = "degraded"
UNHEALTHY = "unhealthy"
STARTING = "starting"
STOPPING = "stopping"
STOPPED = "stopped"
@dataclass
class MCPServerConfig:
"""Configuration for an MCP server instance"""
server_name: str
server_url: str
server_type: str # filesystem, github, slack, etc.
available_tools: List[str]
required_capabilities: List[str]
# Security settings
sandbox_mode: bool = True
max_memory_mb: int = 512
max_cpu_percent: int = 50
timeout_seconds: int = 30
network_isolation: bool = True
# Rate limiting
max_requests_per_minute: int = 60
max_concurrent_requests: int = 5
# Authentication
auth_type: Optional[str] = None # none, api_key, oauth2
auth_config: Optional[Dict[str, Any]] = None
class MCPServerResource(Resource):
"""
MCP server encapsulated as a GT 2.0 resource
Inherits from Resource for access control
"""
# MCP-specific configuration
server_config: MCPServerConfig
# Runtime state
status: MCPServerStatus = MCPServerStatus.STOPPED
last_health_check: Optional[datetime] = None
error_count: int = 0
total_requests: int = 0
# Connection management
connection_pool_size: int = 5
active_connections: int = 0
def to_capability_requirement(self) -> str:
"""Generate capability requirement string for this MCP server"""
return f"mcp:{self.server_config.server_name}:*"
def validate_tool_access(self, tool_name: str, capability_token: Dict[str, Any]) -> bool:
"""Check if capability token allows access to specific tool"""
required_capability = f"mcp:{self.server_config.server_name}:{tool_name}"
capabilities = capability_token.get("capabilities", [])
for cap in capabilities:
resource = cap.get("resource", "")
if resource == required_capability or resource == f"mcp:{self.server_config.server_name}:*":
return True
return False
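# Example capability strings (illustrative): a token capability whose resource is
# "mcp:rag_server:*" grants every tool on the "rag_server" MCP resource, while
# "mcp:rag_server:search_datasets" grants only that single tool.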
class SecureMCPWrapper:
"""
Secure wrapper for MCP servers with GT 2.0 security integration
Provides sandboxing, rate limiting, and capability-based access
"""
def __init__(self, resource_cluster_url: str = "http://localhost:8004"):
self.resource_cluster_url = resource_cluster_url
self.mcp_resources: Dict[str, MCPServerResource] = {}
self.rate_limiters: Dict[str, asyncio.Semaphore] = {}
self.audit_log = []
async def register_mcp_server(
self,
server_config: MCPServerConfig,
owner_id: str,
tenant_domain: str,
access_group: AccessGroup = AccessGroup.INDIVIDUAL
) -> MCPServerResource:
"""
Register an MCP server as a GT 2.0 resource
Args:
server_config: MCP server configuration
owner_id: User who owns this MCP resource
tenant_domain: Tenant domain
access_group: Access control level
Returns:
Registered MCP server resource
"""
# Create MCP resource
resource = MCPServerResource(
id=f"mcp-{server_config.server_name}-{datetime.utcnow().timestamp()}",
name=f"MCP Server: {server_config.server_name}",
resource_type="mcp_server",
owner_id=owner_id,
tenant_domain=tenant_domain,
access_group=access_group,
team_members=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
metadata={
"server_type": server_config.server_type,
"tools_count": len(server_config.available_tools)
},
server_config=server_config
)
# Initialize rate limiter
self.rate_limiters[resource.id] = asyncio.Semaphore(
server_config.max_concurrent_requests
)
# Store resource
self.mcp_resources[resource.id] = resource
# Start health monitoring
asyncio.create_task(self._monitor_health(resource.id))
logger.info(f"Registered MCP server: {server_config.server_name} as resource {resource.id}")
return resource
async def execute_tool(
self,
mcp_resource_id: str,
tool_name: str,
parameters: Dict[str, Any],
capability_token: str,
user_id: str
) -> Dict[str, Any]:
"""
Execute an MCP tool with security constraints
Args:
mcp_resource_id: MCP resource identifier
tool_name: Tool to execute
parameters: Tool parameters
capability_token: JWT capability token
user_id: User executing the tool
Returns:
Tool execution result
"""
# Load MCP resource
mcp_resource = self.mcp_resources.get(mcp_resource_id)
if not mcp_resource:
raise ValueError(f"MCP resource not found: {mcp_resource_id}")
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise PermissionError("Invalid capability token")
# Check tenant match
if token_data.get("tenant_id") != mcp_resource.tenant_domain:
raise PermissionError("Tenant mismatch")
# Validate tool access
if not mcp_resource.validate_tool_access(tool_name, token_data):
raise PermissionError(f"No capability for tool: {tool_name}")
# Check if tool exists
if tool_name not in mcp_resource.server_config.available_tools:
raise ValueError(f"Tool not available: {tool_name}")
# Apply rate limiting
async with self.rate_limiters[mcp_resource_id]:
try:
# Execute tool with timeout and sandboxing
result = await self._execute_tool_sandboxed(
mcp_resource, tool_name, parameters, user_id
)
# Update metrics
mcp_resource.total_requests += 1
# Audit log
self._log_tool_execution(
mcp_resource_id, tool_name, user_id, "success", result
)
return result
except Exception as e:
# Update error metrics
mcp_resource.error_count += 1
# Audit log
self._log_tool_execution(
mcp_resource_id, tool_name, user_id, "error", str(e)
)
raise
async def _execute_tool_sandboxed(
self,
mcp_resource: MCPServerResource,
tool_name: str,
parameters: Dict[str, Any],
user_id: str
) -> Dict[str, Any]:
"""Execute tool in sandboxed environment"""
# Create sandbox context
sandbox_context = {
"user_id": user_id,
"tenant_domain": mcp_resource.tenant_domain,
"resource_limits": {
"max_memory_mb": mcp_resource.server_config.max_memory_mb,
"max_cpu_percent": mcp_resource.server_config.max_cpu_percent,
"timeout_seconds": mcp_resource.server_config.timeout_seconds
},
"network_isolation": mcp_resource.server_config.network_isolation
}
# Execute based on server type
if mcp_resource.server_config.server_type == "filesystem":
return await self._execute_filesystem_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "github":
return await self._execute_github_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "slack":
return await self._execute_slack_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "web":
return await self._execute_web_tool(
tool_name, parameters, sandbox_context
)
elif mcp_resource.server_config.server_type == "database":
return await self._execute_database_tool(
tool_name, parameters, sandbox_context
)
else:
return await self._execute_custom_tool(
mcp_resource, tool_name, parameters, sandbox_context
)
async def _execute_filesystem_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute filesystem MCP tools"""
if tool_name == "read_file":
# Simulate file reading with sandbox constraints
file_path = parameters.get("path", "")
# Security validation
if not self._validate_file_path(file_path, sandbox_context):
raise PermissionError("Access denied to file path")
return {
"tool": "read_file",
"content": f"File content from {file_path}",
"size_bytes": 1024,
"mime_type": "text/plain"
}
elif tool_name == "write_file":
file_path = parameters.get("path", "")
content = parameters.get("content", "")
# Security validation
if not self._validate_file_path(file_path, sandbox_context):
raise PermissionError("Access denied to file path")
if len(content) > 1024 * 1024: # 1MB limit
raise ValueError("File content too large")
return {
"tool": "write_file",
"path": file_path,
"bytes_written": len(content),
"status": "success"
}
elif tool_name == "list_directory":
dir_path = parameters.get("path", "")
if not self._validate_file_path(dir_path, sandbox_context):
raise PermissionError("Access denied to directory path")
return {
"tool": "list_directory",
"path": dir_path,
"entries": ["file1.txt", "file2.txt", "subdir/"],
"total_entries": 3
}
else:
raise ValueError(f"Unknown filesystem tool: {tool_name}")
async def _execute_github_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute GitHub MCP tools"""
if tool_name == "get_repository":
repo_name = parameters.get("repository", "")
return {
"tool": "get_repository",
"repository": repo_name,
"owner": "example",
"description": "Example repository",
"language": "Python",
"stars": 123,
"forks": 45
}
elif tool_name == "create_issue":
title = parameters.get("title", "")
body = parameters.get("body", "")
return {
"tool": "create_issue",
"issue_number": 42,
"title": title,
"url": f"https://github.com/example/repo/issues/42",
"status": "created"
}
elif tool_name == "search_code":
query = parameters.get("query", "")
return {
"tool": "search_code",
"query": query,
"results": [
{
"file": "main.py",
"line": 15,
"content": f"# Code matching {query}"
}
],
"total_results": 1
}
else:
raise ValueError(f"Unknown GitHub tool: {tool_name}")
async def _execute_slack_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute Slack MCP tools"""
if tool_name == "send_message":
channel = parameters.get("channel", "")
message = parameters.get("message", "")
return {
"tool": "send_message",
"channel": channel,
"message": message,
"timestamp": datetime.utcnow().isoformat(),
"status": "sent"
}
elif tool_name == "get_channel_history":
channel = parameters.get("channel", "")
limit = parameters.get("limit", 10)
return {
"tool": "get_channel_history",
"channel": channel,
"messages": [
{
"user": "user1",
"text": "Hello world!",
"timestamp": datetime.utcnow().isoformat()
}
] * min(limit, 10),
"total_messages": limit
}
else:
raise ValueError(f"Unknown Slack tool: {tool_name}")
async def _execute_web_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute web MCP tools"""
if tool_name == "fetch_url":
url = parameters.get("url", "")
# URL validation
if not self._validate_url(url, sandbox_context):
raise PermissionError("Access denied to URL")
return {
"tool": "fetch_url",
"url": url,
"status_code": 200,
"content": f"Content from {url}",
"headers": {"content-type": "text/html"}
}
elif tool_name == "submit_form":
url = parameters.get("url", "")
form_data = parameters.get("form_data", {})
if not self._validate_url(url, sandbox_context):
raise PermissionError("Access denied to URL")
return {
"tool": "submit_form",
"url": url,
"form_data": form_data,
"status_code": 200,
"response": "Form submitted successfully"
}
else:
raise ValueError(f"Unknown web tool: {tool_name}")
async def _execute_database_tool(
self,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute database MCP tools"""
if tool_name == "execute_query":
query = parameters.get("query", "")
# Query validation
if not self._validate_sql_query(query, sandbox_context):
raise PermissionError("Query not allowed")
return {
"tool": "execute_query",
"query": query,
"rows": [
{"id": 1, "name": "Example"},
{"id": 2, "name": "Data"}
],
"row_count": 2,
"execution_time_ms": 15
}
else:
raise ValueError(f"Unknown database tool: {tool_name}")
async def _execute_custom_tool(
self,
mcp_resource: MCPServerResource,
tool_name: str,
parameters: Dict[str, Any],
sandbox_context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute custom MCP tool via WebSocket transport"""
# This would connect to the actual MCP server via WebSocket
# For now, simulate the execution
await asyncio.sleep(0.1) # Simulate network delay
return {
"tool": tool_name,
"parameters": parameters,
"result": f"Custom tool {tool_name} executed successfully",
"server_type": mcp_resource.server_config.server_type,
"execution_time_ms": 100
}
def _validate_file_path(
self,
file_path: str,
sandbox_context: Dict[str, Any]
) -> bool:
"""Validate file path for security"""
# Basic path traversal prevention
if ".." in file_path or file_path.startswith("/"):
return False
# Check allowed extensions
allowed_extensions = [".txt", ".md", ".json", ".py", ".js"]
if not any(file_path.endswith(ext) for ext in allowed_extensions):
return False
return True
def _validate_url(
self,
url: str,
sandbox_context: Dict[str, Any]
) -> bool:
"""Validate URL for security"""
# Basic URL validation
if not url.startswith(("http://", "https://")):
return False
# Block internal/localhost URLs if network isolation enabled
if sandbox_context.get("network_isolation", True):
if any(domain in url for domain in ["localhost", "127.0.0.1", "10.", "192.168."]):
return False
return True
def _validate_sql_query(
self,
query: str,
sandbox_context: Dict[str, Any]
) -> bool:
"""Validate SQL query for security"""
# Block dangerous SQL operations
dangerous_keywords = [
"DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER",
"TRUNCATE", "EXEC", "EXECUTE", "xp_", "sp_"
]
query_upper = query.upper()
for keyword in dangerous_keywords:
if keyword in query_upper:
return False
return True
def _log_tool_execution(
self,
mcp_resource_id: str,
tool_name: str,
user_id: str,
status: str,
result: Any
) -> None:
"""Log tool execution for audit"""
log_entry = {
"timestamp": datetime.utcnow().isoformat(),
"mcp_resource_id": mcp_resource_id,
"tool_name": tool_name,
"user_id": user_id,
"status": status,
"result_summary": str(result)[:200] if result else None
}
self.audit_log.append(log_entry)
# Keep only last 1000 entries
if len(self.audit_log) > 1000:
self.audit_log = self.audit_log[-1000:]
async def _monitor_health(self, mcp_resource_id: str) -> None:
"""Monitor MCP server health"""
while mcp_resource_id in self.mcp_resources:
try:
mcp_resource = self.mcp_resources[mcp_resource_id]
# Simulate health check
await asyncio.sleep(30) # Check every 30 seconds
# Update health status
# Check the higher threshold first, otherwise the UNHEALTHY branch is unreachable
if mcp_resource.error_count > 50:
mcp_resource.status = MCPServerStatus.UNHEALTHY
elif mcp_resource.error_count > 10:
mcp_resource.status = MCPServerStatus.DEGRADED
else:
mcp_resource.status = MCPServerStatus.HEALTHY
mcp_resource.last_health_check = datetime.utcnow()
logger.debug(f"Health check for MCP resource {mcp_resource_id}: {mcp_resource.status}")
except Exception as e:
logger.error(f"Health check failed for MCP resource {mcp_resource_id}: {e}")
if mcp_resource_id in self.mcp_resources:
self.mcp_resources[mcp_resource_id].status = MCPServerStatus.UNHEALTHY
async def get_resource_status(
self,
mcp_resource_id: str,
capability_token: str
) -> Dict[str, Any]:
"""Get MCP resource status"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise PermissionError("Invalid capability token")
# Load MCP resource
mcp_resource = self.mcp_resources.get(mcp_resource_id)
if not mcp_resource:
raise ValueError(f"MCP resource not found: {mcp_resource_id}")
# Check tenant match
if token_data.get("tenant_id") != mcp_resource.tenant_domain:
raise PermissionError("Tenant mismatch")
return {
"resource_id": mcp_resource_id,
"name": mcp_resource.name,
"server_type": mcp_resource.server_config.server_type,
"status": mcp_resource.status,
"total_requests": mcp_resource.total_requests,
"error_count": mcp_resource.error_count,
"active_connections": mcp_resource.active_connections,
"last_health_check": mcp_resource.last_health_check.isoformat() if mcp_resource.last_health_check else None,
"available_tools": mcp_resource.server_config.available_tools
}
async def list_mcp_resources(
self,
capability_token: str,
tenant_domain: Optional[str] = None
) -> List[Dict[str, Any]]:
"""List available MCP resources"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise PermissionError("Invalid capability token")
tenant_filter = tenant_domain or token_data.get("tenant_id")
resources = []
for resource in self.mcp_resources.values():
if resource.tenant_domain == tenant_filter:
resources.append({
"resource_id": resource.id,
"name": resource.name,
"server_type": resource.server_config.server_type,
"status": resource.status,
"tool_count": len(resource.server_config.available_tools),
"created_at": resource.created_at.isoformat()
})
return resources
# Global MCP wrapper instance
_mcp_wrapper = None
def get_mcp_wrapper() -> SecureMCPWrapper:
"""Get the global MCP wrapper instance"""
global _mcp_wrapper
if _mcp_wrapper is None:
_mcp_wrapper = SecureMCPWrapper()
return _mcp_wrapper
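# Illustrative registration-and-execution flow (a sketch; identifiers, token, and
# file name are assumptions, not part of this module):
# async def _demo_mcp():
#     wrapper = get_mcp_wrapper()
#     config = MCPServerConfig(
#         server_name="files", server_url="internal://mcp-files", server_type="filesystem",
#         available_tools=["read_file"], required_capabilities=["mcp:files:*"])
#     resource = await wrapper.register_mcp_server(config, owner_id="user-1", tenant_domain="acme.example")
#     result = await wrapper.execute_tool(
#         resource.id, "read_file", {"path": "notes.txt"},
#         capability_token="<jwt>", user_id="user-1")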

View File

@@ -0,0 +1,296 @@
"""
GT 2.0 Model Router
Routes inference requests to appropriate providers based on model registry.
Integrates with provider factory for dynamic provider selection.
"""
import asyncio
import logging
from typing import Dict, Any, Optional, AsyncIterator
from datetime import datetime
from app.services.model_service import get_model_service
from app.providers import get_provider_factory
from app.core.backends import get_backend
from app.core.exceptions import ProviderError
logger = logging.getLogger(__name__)
class ModelRouter:
"""Routes model requests to appropriate providers"""
def __init__(self, tenant_id: Optional[str] = None):
self.tenant_id = tenant_id
# Use default model service for shared model registry (config sync writes to default)
# Note: Tenant isolation is handled via capability tokens, not separate databases
self.model_service = get_model_service(None)
self.provider_factory = None
self.backend_cache = {}
async def initialize(self):
"""Initialize model router"""
try:
self.provider_factory = await get_provider_factory()
logger.info(f"Model router initialized for tenant: {self.tenant_id or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize model router: {e}")
raise
async def route_inference(
self,
model_id: str,
prompt: Optional[str] = None,
messages: Optional[list] = None,
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: Optional[str] = None,
tenant_id: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""Route inference request to appropriate provider"""
# Get model configuration from registry
model_config = await self.model_service.get_model(model_id)
if not model_config:
raise ProviderError(f"Model {model_id} not found in registry")
provider = model_config["provider"]
# Track model usage
start_time = datetime.utcnow()
try:
# Route to configured endpoint (generic routing for any provider)
endpoint_url = model_config.get("endpoint")
if not endpoint_url:
raise ProviderError(f"No endpoint configured for model {model_id}")
result = await self._route_to_generic_endpoint(
endpoint_url, model_id, prompt, messages, temperature, max_tokens, stream, user_id, tenant_id, tools, tool_choice, **kwargs
)
# Calculate latency
latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
# Track successful usage
await self.model_service.track_model_usage(
model_id, success=True, latency_ms=latency_ms
)
return result
except Exception as e:
# Track failed usage
latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
await self.model_service.track_model_usage(
model_id, success=False, latency_ms=latency_ms
)
logger.error(f"Model routing failed for {model_id}: {e}")
raise
async def _route_to_groq(
self,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
tools: Optional[list],
tool_choice: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to Groq backend"""
try:
backend = get_backend("groq_proxy")
if not backend:
raise ProviderError("Groq backend not available")
if messages:
return await backend.execute_inference_with_messages(
messages=messages,
model=model_id,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
user_id=user_id,
tenant_id=tenant_id,
tools=tools,
tool_choice=tool_choice
)
else:
return await backend.execute_inference(
prompt=prompt,
model=model_id,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
user_id=user_id,
tenant_id=tenant_id
)
except Exception as e:
logger.error(f"Groq routing failed: {e}")
raise ProviderError(f"Groq inference failed: {e}")
async def _route_to_external(
self,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to external provider"""
try:
if not self.provider_factory:
await self.initialize()
external_provider = self.provider_factory.get_provider("external")
if not external_provider:
raise ProviderError("External provider not available")
# For embedding models
if model_id == "bge-m3-embedding":
# Convert prompt/messages to text list
texts = []
if messages:
texts = [msg.get("content", "") for msg in messages if msg.get("content")]
elif prompt:
texts = [prompt]
return await external_provider.generate_embeddings(
model_id=model_id,
texts=texts
)
else:
raise ProviderError(f"External model {model_id} not supported for inference")
except Exception as e:
logger.error(f"External routing failed: {e}")
raise ProviderError(f"External inference failed: {e}")
async def _route_to_openai(
self,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to OpenAI provider"""
raise ProviderError("OpenAI provider not implemented - use Groq models instead")
async def _route_to_generic_endpoint(
self,
endpoint_url: str,
model_id: str,
prompt: Optional[str],
messages: Optional[list],
temperature: float,
max_tokens: int,
stream: bool,
user_id: Optional[str],
tenant_id: Optional[str],
tools: Optional[list],
tool_choice: Optional[str],
**kwargs
) -> Dict[str, Any]:
"""Route request to any configured endpoint using OpenAI-compatible API"""
import httpx
import time
try:
# Build OpenAI-compatible request
request_data = {
"model": model_id,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream
}
# Use messages if provided, otherwise convert prompt to messages
if messages:
request_data["messages"] = messages
elif prompt:
request_data["messages"] = [{"role": "user", "content": prompt}]
else:
raise ProviderError("Either messages or prompt must be provided")
# Add tools if provided
if tools:
request_data["tools"] = tools
if tool_choice:
request_data["tool_choice"] = tool_choice
# Add any additional parameters
request_data.update(kwargs)
logger.info(f"Routing request to endpoint: {endpoint_url}")
logger.debug(f"Request data: {request_data}")
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
endpoint_url,
json=request_data,
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
error_text = response.text
logger.error(f"Endpoint {endpoint_url} returned {response.status_code}: {error_text}")
raise ProviderError(f"Endpoint error: {response.status_code} - {error_text}")
result = response.json()
logger.debug(f"Endpoint response: {result}")
return result
except httpx.RequestError as e:
logger.error(f"Request to {endpoint_url} failed: {e}")
raise ProviderError(f"Connection to endpoint failed: {str(e)}")
except Exception as e:
logger.error(f"Generic endpoint routing failed: {e}")
raise ProviderError(f"Inference failed: {str(e)}")
async def list_available_models(self) -> list:
"""List all available models from registry"""
# Get all models (deployment status filtering available if needed)
models = await self.model_service.list_models()
return models
async def get_model_health(self, model_id: str) -> Dict[str, Any]:
"""Check health of specific model"""
return await self.model_service.check_model_health(model_id)
# Global model router instances per tenant
_model_routers = {}
async def get_model_router(tenant_id: Optional[str] = None) -> ModelRouter:
"""Get model router instance for tenant"""
global _model_routers
cache_key = tenant_id or "default"
if cache_key not in _model_routers:
router = ModelRouter(tenant_id)
await router.initialize()
_model_routers[cache_key] = router
return _model_routers[cache_key]
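# Illustrative call-through (a sketch; the tenant id and model id are assumptions,
# any model registered in the model registry would work):
# async def _demo_route():
#     router = await get_model_router("acme.example")
#     reply = await router.route_inference(
#         model_id="example-chat-model",
#         messages=[{"role": "user", "content": "Summarize this repository"}],
#         temperature=0.2, max_tokens=256)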

View File

@@ -0,0 +1,720 @@
"""
GT 2.0 Model Management Service - Stateless Version
Provides centralized model registry, versioning, deployment, and lifecycle management
for all AI models across the Resource Cluster using in-memory storage.
"""
import json
import time
import asyncio
from typing import Dict, Any, List, Optional, Union
from datetime import datetime, timedelta
from pathlib import Path
import hashlib
import httpx
import logging
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class ModelService:
"""Stateless model management service with in-memory registry"""
def __init__(self, tenant_id: Optional[str] = None):
self.tenant_id = tenant_id
self.settings = get_settings(tenant_id)
# In-memory model registry for stateless operation
self.model_registry: Dict[str, Dict[str, Any]] = {}
self.last_cache_update = 0
self.cache_ttl = 300 # 5 minutes
# Performance tracking (in-memory)
self.performance_metrics: Dict[str, Dict[str, Any]] = {}
# Initialize with default models synchronously
self._initialize_default_models_sync()
async def register_model(
self,
model_id: str,
name: str,
version: str,
provider: str,
model_type: str,
description: str = "",
capabilities: Dict[str, Any] = None,
parameters: Dict[str, Any] = None,
endpoint_url: str = None,
**kwargs
) -> Dict[str, Any]:
"""Register a new model in the in-memory registry"""
now = datetime.utcnow()
# Create or update model entry
model_entry = {
"id": model_id,
"name": name,
"version": version,
"provider": provider,
"model_type": model_type,
"description": description,
"capabilities": capabilities or {},
"parameters": parameters or {},
# Performance metrics
"max_tokens": kwargs.get("max_tokens", 4000),
"context_window": kwargs.get("context_window", 4000),
"cost_per_1k_tokens": kwargs.get("cost_per_1k_tokens", 0.0),
"latency_p50_ms": kwargs.get("latency_p50_ms", 0.0),
"latency_p95_ms": kwargs.get("latency_p95_ms", 0.0),
# Deployment status
"deployment_status": kwargs.get("deployment_status", "available"),
"health_status": kwargs.get("health_status", "unknown"),
"last_health_check": kwargs.get("last_health_check"),
# Usage tracking
"request_count": kwargs.get("request_count", 0),
"error_count": kwargs.get("error_count", 0),
"success_rate": kwargs.get("success_rate", 1.0),
# Lifecycle
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"retired_at": kwargs.get("retired_at"),
# Configuration
"endpoint_url": endpoint_url,
"api_key_required": kwargs.get("api_key_required", True),
"rate_limits": kwargs.get("rate_limits", {})
}
self.model_registry[model_id] = model_entry
logger.info(f"Registered model: {model_id} ({name} v{version})")
return model_entry
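# Illustrative registration call (a sketch; the instance name, endpoint URL, and ids
# are assumptions): extra keyword arguments such as context_window land in **kwargs.
# await model_service.register_model(
#     model_id="nemotron-mini", name="Nemotron Mini", version="1.0",
#     provider="nvidia_nim", model_type="chat",
#     endpoint_url="http://nim-gateway:8000/v1/chat/completions",
#     context_window=8192, max_tokens=4096)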
async def get_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""Get model by ID"""
return self.model_registry.get(model_id)
async def list_models(
self,
provider: str = None,
model_type: str = None,
deployment_status: str = None,
health_status: str = None
) -> List[Dict[str, Any]]:
"""List models with optional filters"""
models = list(self.model_registry.values())
# Apply filters
if provider:
models = [m for m in models if m["provider"] == provider]
if model_type:
models = [m for m in models if m["model_type"] == model_type]
if deployment_status:
models = [m for m in models if m["deployment_status"] == deployment_status]
if health_status:
models = [m for m in models if m["health_status"] == health_status]
# Sort by created_at desc
models.sort(key=lambda x: x["created_at"], reverse=True)
return models
async def update_model_status(
self,
model_id: str,
deployment_status: str = None,
health_status: str = None
) -> bool:
"""Update model deployment and health status"""
model = self.model_registry.get(model_id)
if not model:
return False
if deployment_status:
model["deployment_status"] = deployment_status
if health_status:
model["health_status"] = health_status
model["last_health_check"] = datetime.utcnow().isoformat()
model["updated_at"] = datetime.utcnow().isoformat()
return True
async def track_model_usage(
self,
model_id: str,
success: bool = True,
latency_ms: float = None
):
"""Track model usage and performance metrics"""
model = self.model_registry.get(model_id)
if not model:
return
# Update usage counters
model["request_count"] += 1
if not success:
model["error_count"] += 1
# Calculate success rate
model["success_rate"] = (model["request_count"] - model["error_count"]) / model["request_count"]
# Update latency metrics (simple running average)
if latency_ms is not None:
if model["latency_p50_ms"] == 0:
model["latency_p50_ms"] = latency_ms
else:
# Simple exponential moving average
alpha = 0.1
model["latency_p50_ms"] = alpha * latency_ms + (1 - alpha) * model["latency_p50_ms"]
# P95 approximation (conservative estimate)
model["latency_p95_ms"] = max(model["latency_p95_ms"], latency_ms * 1.5)
model["updated_at"] = datetime.utcnow().isoformat()
async def retire_model(self, model_id: str, reason: str = "") -> bool:
"""Retire a model (mark as no longer available)"""
model = self.model_registry.get(model_id)
if not model:
return False
model["deployment_status"] = "retired"
model["retired_at"] = datetime.utcnow().isoformat()
model["updated_at"] = datetime.utcnow().isoformat()
if reason:
model["description"] += f"\n\nRetired: {reason}"
logger.info(f"Retired model: {model_id} - {reason}")
return True
async def check_model_health(self, model_id: str) -> Dict[str, Any]:
"""Check health of a specific model"""
model = await self.get_model(model_id)
if not model:
return {"healthy": False, "error": "Model not found"}
# Provider-specific health checks first, then a generic endpoint check as fallback
if model["provider"] == "groq":
return await self._check_groq_model_health(model)
elif model["provider"] == "openai":
return await self._check_openai_model_health(model)
elif model["provider"] == "local":
return await self._check_local_model_health(model)
elif model.get("endpoint_url") or model.get("endpoint"):
return await self._check_generic_model_health(model)
else:
return {"healthy": False, "error": f"No health check method for provider: {model['provider']}"}
async def _check_groq_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Health check for Groq models"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://api.groq.com/openai/v1/models",
headers={"Authorization": f"Bearer {settings.groq_api_key}"},
timeout=10.0
)
if response.status_code == 200:
models = response.json()
model_ids = [m["id"] for m in models.get("data", [])]
is_available = model["id"] in model_ids
await self.update_model_status(
model["id"],
health_status="healthy" if is_available else "unhealthy"
)
return {
"healthy": is_available,
"latency_ms": response.elapsed.total_seconds() * 1000,
"available_models": len(model_ids)
}
else:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": f"API error: {response.status_code}"}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": str(e)}
async def _check_openai_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Health check for OpenAI models"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {settings.openai_api_key}"},
timeout=10.0
)
if response.status_code == 200:
models = response.json()
model_ids = [m["id"] for m in models.get("data", [])]
is_available = model["id"] in model_ids
await self.update_model_status(
model["id"],
health_status="healthy" if is_available else "unhealthy"
)
return {
"healthy": is_available,
"latency_ms": response.elapsed.total_seconds() * 1000
}
else:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": f"API error: {response.status_code}"}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": str(e)}
async def _check_generic_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Generic health check for any provider with configured endpoint"""
try:
endpoint_url = model.get("endpoint")
if not endpoint_url:
return {"healthy": False, "error": "No endpoint URL configured"}
# Try a simple health check by making a minimal request
async with httpx.AsyncClient(timeout=10.0) as client:
# For OpenAI-compatible endpoints, try a models list request
try:
# Try /v1/models endpoint first (common for OpenAI-compatible APIs)
models_url = endpoint_url.replace("/chat/completions", "/models").replace("/v1/chat/completions", "/v1/models")
response = await client.get(models_url)
if response.status_code == 200:
await self.update_model_status(model["id"], health_status="healthy")
return {
"healthy": True,
"provider": model.get("provider", "unknown"),
"latency_ms": 0, # Could measure actual latency
"last_check": datetime.utcnow().isoformat(),
"details": "Endpoint responding to models request"
}
except Exception:
pass
# If models endpoint doesn't work, try a basic health endpoint
try:
health_url = endpoint_url.replace("/chat/completions", "/health").replace("/v1/chat/completions", "/health")
response = await client.get(health_url)
if response.status_code == 200:
await self.update_model_status(model["id"], health_status="healthy")
return {
"healthy": True,
"provider": model.get("provider", "unknown"),
"latency_ms": 0,
"last_check": datetime.utcnow().isoformat(),
"details": "Endpoint responding to health check"
}
except Exception:
pass
# Neither probe succeeded; record status as unknown but do not mark the model unhealthy
await self.update_model_status(model["id"], health_status="unknown")
return {
"healthy": True, # Assume healthy for generic endpoints
"provider": model.get("provider", "unknown"),
"latency_ms": 0,
"last_check": datetime.utcnow().isoformat(),
"details": "Generic endpoint - health check not available"
}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": f"Health check failed: {str(e)}"}
async def _check_local_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""Health check for local models"""
try:
endpoint_url = model.get("endpoint_url")
if not endpoint_url:
return {"healthy": False, "error": "No endpoint URL configured"}
async with httpx.AsyncClient() as client:
response = await client.get(
f"{endpoint_url}/health",
timeout=5.0
)
healthy = response.status_code == 200
await self.update_model_status(
model["id"],
health_status="healthy" if healthy else "unhealthy"
)
return {
"healthy": healthy,
"latency_ms": response.elapsed.total_seconds() * 1000
}
except Exception as e:
await self.update_model_status(model["id"], health_status="unhealthy")
return {"healthy": False, "error": str(e)}
async def bulk_health_check(self) -> Dict[str, Any]:
"""Check health of all registered models"""
models = await self.list_models()
health_results = {}
# Run health checks concurrently
tasks = []
for model in models:
task = asyncio.create_task(self.check_model_health(model["id"]))
tasks.append((model["id"], task))
for model_id, task in tasks:
try:
health_result = await task
health_results[model_id] = health_result
except Exception as e:
health_results[model_id] = {"healthy": False, "error": str(e)}
# Calculate overall health statistics
total_models = len(health_results)
healthy_models = sum(1 for result in health_results.values() if result.get("healthy", False))
return {
"total_models": total_models,
"healthy_models": healthy_models,
"unhealthy_models": total_models - healthy_models,
"health_percentage": (healthy_models / total_models * 100) if total_models > 0 else 0,
"individual_results": health_results
}
async def get_model_analytics(
self,
model_id: str = None,
timeframe_hours: int = 24
) -> Dict[str, Any]:
"""Get analytics for model usage and performance"""
models = await self.list_models()
if model_id:
models = [m for m in models if m["id"] == model_id]
analytics = {
"total_models": len(models),
"by_provider": {},
"by_type": {},
"performance_summary": {
"avg_latency_p50": 0,
"avg_success_rate": 0,
"total_requests": 0,
"total_errors": 0
},
"top_performers": [],
"models": models
}
total_latency = 0
total_success_rate = 0
total_requests = 0
total_errors = 0
for model in models:
# Provider statistics
provider = model["provider"]
if provider not in analytics["by_provider"]:
analytics["by_provider"][provider] = {"count": 0, "requests": 0}
analytics["by_provider"][provider]["count"] += 1
analytics["by_provider"][provider]["requests"] += model["request_count"]
# Type statistics
model_type = model["model_type"]
if model_type not in analytics["by_type"]:
analytics["by_type"][model_type] = {"count": 0, "requests": 0}
analytics["by_type"][model_type]["count"] += 1
analytics["by_type"][model_type]["requests"] += model["request_count"]
# Performance aggregation
total_latency += model["latency_p50_ms"]
total_success_rate += model["success_rate"]
total_requests += model["request_count"]
total_errors += model["error_count"]
# Calculate averages
if len(models) > 0:
analytics["performance_summary"]["avg_latency_p50"] = total_latency / len(models)
analytics["performance_summary"]["avg_success_rate"] = total_success_rate / len(models)
analytics["performance_summary"]["total_requests"] = total_requests
analytics["performance_summary"]["total_errors"] = total_errors
# Top performers (by success rate and low latency)
analytics["top_performers"] = sorted(
[m for m in models if m["request_count"] > 0],
key=lambda x: (x["success_rate"], -x["latency_p50_ms"]),
reverse=True
)[:5]
return analytics
async def _initialize_default_models(self):
"""Initialize registry with default models"""
# Groq models
groq_models = [
{
"model_id": "llama-3.1-405b-reasoning",
"name": "Llama 3.1 405B Reasoning",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Largest Llama model optimized for complex reasoning tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 2.5,
"capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-70b-versatile",
"name": "Llama 3.1 70B Versatile",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Balanced Llama model for general-purpose tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.8,
"capabilities": {"general": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-8b-instant",
"name": "Llama 3.1 8B Instant",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Fast Llama model for quick responses",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.2,
"capabilities": {"fast": True, "streaming": True}
},
{
"model_id": "mixtral-8x7b-32768",
"name": "Mixtral 8x7B",
"version": "1.0",
"provider": "groq",
"model_type": "llm",
"description": "Mixtral model for balanced performance",
"max_tokens": 32768,
"context_window": 32768,
"cost_per_1k_tokens": 0.27,
"capabilities": {"general": True, "streaming": True}
}
]
for model_config in groq_models:
await self.register_model(**model_config)
logger.info("Initialized default model registry with in-memory storage")
def _initialize_default_models_sync(self):
"""Initialize registry with default models synchronously"""
# Groq models
groq_models = [
{
"model_id": "llama-3.1-405b-reasoning",
"name": "Llama 3.1 405B Reasoning",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Largest Llama model optimized for complex reasoning tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 2.5,
"capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-70b-versatile",
"name": "Llama 3.1 70B Versatile",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Balanced Llama model for general-purpose tasks",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.8,
"capabilities": {"general": True, "function_calling": True, "streaming": True}
},
{
"model_id": "llama-3.1-8b-instant",
"name": "Llama 3.1 8B Instant",
"version": "3.1",
"provider": "groq",
"model_type": "llm",
"description": "Fast Llama model for quick responses",
"max_tokens": 8000,
"context_window": 32768,
"cost_per_1k_tokens": 0.2,
"capabilities": {"fast": True, "streaming": True}
},
{
"model_id": "mixtral-8x7b-32768",
"name": "Mixtral 8x7B",
"version": "1.0",
"provider": "groq",
"model_type": "llm",
"description": "Mixtral model for balanced performance",
"max_tokens": 32768,
"context_window": 32768,
"cost_per_1k_tokens": 0.27,
"capabilities": {"general": True, "streaming": True}
},
{
"model_id": "groq/compound",
"name": "Groq Compound Model",
"version": "1.0",
"provider": "groq",
"model_type": "llm",
"description": "Groq compound AI model",
"max_tokens": 8000,
"context_window": 8000,
"cost_per_1k_tokens": 0.5,
"capabilities": {"general": True, "streaming": True}
}
]
for model_config in groq_models:
now = datetime.utcnow()
model_entry = {
"id": model_config["model_id"],
"name": model_config["name"],
"version": model_config["version"],
"provider": model_config["provider"],
"model_type": model_config["model_type"],
"description": model_config["description"],
"capabilities": model_config["capabilities"],
"parameters": {},
# Performance metrics
"max_tokens": model_config["max_tokens"],
"context_window": model_config["context_window"],
"cost_per_1k_tokens": model_config["cost_per_1k_tokens"],
"latency_p50_ms": 0.0,
"latency_p95_ms": 0.0,
# Deployment status
"deployment_status": "available",
"health_status": "unknown",
"last_health_check": None,
# Usage tracking
"request_count": 0,
"error_count": 0,
"success_rate": 1.0,
# Lifecycle
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"retired_at": None,
# Configuration
"endpoint_url": None,
"api_key_required": True,
"rate_limits": {}
}
self.model_registry[model_config["model_id"]] = model_entry
logger.info("Initialized default model registry with in-memory storage (sync)")
async def register_or_update_model(
self,
model_id: str,
name: str,
version: str = "1.0",
provider: str = "unknown",
model_type: str = "llm",
endpoint: str = "",
api_key_name: str = None,
specifications: Dict[str, Any] = None,
capabilities: Dict[str, Any] = None,
cost: Dict[str, Any] = None,
description: str = "",
config: Dict[str, Any] = None,
status: Dict[str, Any] = None,
sync_timestamp: str = None
) -> Dict[str, Any]:
"""Register a new model or update existing one from admin cluster sync"""
specifications = specifications or {}
capabilities = capabilities or {}
cost = cost or {}
config = config or {}
status = status or {}
# Check if model exists
existing_model = self.model_registry.get(model_id)
if existing_model:
# Update existing model
existing_model.update({
"name": name,
"version": version,
"provider": provider,
"model_type": model_type,
"description": description,
"capabilities": capabilities,
"parameters": config,
"endpoint_url": endpoint,
"api_key_required": bool(api_key_name),
"max_tokens": specifications.get("max_tokens", existing_model.get("max_tokens", 4000)),
"context_window": specifications.get("context_window", existing_model.get("context_window", 4000)),
"cost_per_1k_tokens": cost.get("per_1k_input", existing_model.get("cost_per_1k_tokens", 0.0)),
"deployment_status": "deployed" if status.get("is_active", True) else "retired",
"updated_at": datetime.utcnow().isoformat()
})
if "bge-m3" in model_id.lower():
logger.info(f"Updated BGE-M3 model: endpoint_url={endpoint}, parameters={config}")
logger.debug(f"Updated model: {model_id}")
return existing_model
else:
# Register new model
return await self.register_model(
model_id=model_id,
name=name,
version=version,
provider=provider,
model_type=model_type,
description=description,
capabilities=capabilities,
parameters=config,
endpoint_url=endpoint,
max_tokens=specifications.get("max_tokens", 4000),
context_window=specifications.get("context_window", 4000),
cost_per_1k_tokens=cost.get("per_1k_input", 0.0),
api_key_required=bool(api_key_name)
)
def get_model_service(tenant_id: Optional[str] = None) -> ModelService:
"""Get tenant-isolated model service instance"""
return ModelService(tenant_id=tenant_id)
# Default model service for development/non-tenant operations
default_model_service = get_model_service()
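# Minimal usage sketch (illustrative only, not invoked by the service):
# registers a hypothetical locally hosted model, records one request, and
# prints the aggregate health report. The tenant ID, model ID, and endpoint
# URL below are placeholder values.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        svc = get_model_service(tenant_id="example-tenant")
        await svc.register_model(
            model_id="example-local-llm",
            name="Example Local LLM",
            version="1.0",
            provider="local",
            model_type="llm",
            description="Placeholder model for demonstration",
            endpoint_url="http://localhost:9000",
        )
        await svc.track_model_usage("example-local-llm", success=True, latency_ms=42.0)
        print(await svc.bulk_health_check())

    asyncio.run(_demo())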

View File

@@ -0,0 +1,931 @@
"""
GT 2.0 Resource Cluster - Service Manager
Orchestrates external web services (CTFd, Canvas LMS, Guacamole, JupyterHub)
with perfect tenant isolation and security.
"""
import asyncio
import json
import logging
import subprocess
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, asdict, field
from pathlib import Path
try:
import docker
import kubernetes
from kubernetes import client, config
from kubernetes.client.rest import ApiException
DOCKER_AVAILABLE = True
KUBERNETES_AVAILABLE = True
except ImportError:
# For development containerization mode, these are optional
docker = None
kubernetes = None
client = None
config = None
ApiException = Exception
DOCKER_AVAILABLE = False
KUBERNETES_AVAILABLE = False
from app.core.config import get_settings
from app.core.security import verify_capability_token
from app.utils.encryption import encrypt_data, decrypt_data
logger = logging.getLogger(__name__)
@dataclass
class ServiceInstance:
"""Represents a deployed service instance"""
instance_id: str
tenant_id: str
service_type: str # 'ctfd', 'canvas', 'guacamole', 'jupyter'
status: str # 'starting', 'running', 'stopping', 'stopped', 'error'
endpoint_url: str
internal_port: int
external_port: int
namespace: str
deployment_name: str
service_name: str
ingress_name: str
sso_token: Optional[str] = None
created_at: datetime = field(default_factory=datetime.utcnow)
last_heartbeat: datetime = field(default_factory=datetime.utcnow)
resource_usage: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
data['last_heartbeat'] = self.last_heartbeat.isoformat()
return data
@dataclass
class ServiceTemplate:
"""Service deployment template configuration"""
service_type: str
image: str
ports: Dict[str, int]
environment: Dict[str, str]
volumes: List[Dict[str, str]]
resource_limits: Dict[str, str]
security_context: Dict[str, Any]
health_check: Dict[str, Any]
sso_config: Dict[str, Any]
class ServiceManager:
"""Manages external web service instances with Kubernetes orchestration"""
def __init__(self):
# Initialize Docker client if available
if DOCKER_AVAILABLE:
try:
self.docker_client = docker.from_env()
except Exception as e:
logger.warning(f"Could not initialize Docker client: {e}")
self.docker_client = None
else:
self.docker_client = None
self.k8s_client = None
self.active_instances: Dict[str, ServiceInstance] = {}
self.service_templates: Dict[str, ServiceTemplate] = {}
self.base_namespace = "gt-services"
self.storage_path = Path("/tmp/resource-cluster/services")
self.storage_path.mkdir(parents=True, exist_ok=True)
# Initialize Kubernetes client if available
if KUBERNETES_AVAILABLE:
try:
config.load_incluster_config()  # If running in cluster
self.k8s_client = client.ApiClient()
except Exception:
try:
config.load_kube_config()  # If running locally
self.k8s_client = client.ApiClient()
except Exception:
logger.warning("Could not load Kubernetes config - using mock mode")
self.k8s_client = None
else:
logger.warning("Kubernetes not available - running in development containerization mode")
self._initialize_service_templates()
self._load_persistent_instances()
def _initialize_service_templates(self):
"""Initialize service deployment templates"""
# CTFd Template
self.service_templates['ctfd'] = ServiceTemplate(
service_type='ctfd',
image='ctfd/ctfd:3.6.0',
ports={'http': 8000},
environment={
'SECRET_KEY': '${TENANT_SECRET_KEY}',
'DATABASE_URL': 'sqlite:////data/ctfd.db',
'DATABASE_CACHE_URL': 'postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants',
'UPLOAD_FOLDER': '/data/uploads',
'LOG_FOLDER': '/data/logs',
},
volumes=[
{'name': 'ctfd-data', 'mountPath': '/data', 'size': '5Gi'},
{'name': 'ctfd-uploads', 'mountPath': '/uploads', 'size': '2Gi'}
],
resource_limits={
'memory': '2Gi',
'cpu': '1000m'
},
security_context={
'runAsNonRoot': True,
'runAsUser': 1000,
'fsGroup': 1000,
'readOnlyRootFilesystem': False
},
health_check={
'path': '/health',
'port': 8000,
'initial_delay': 30,
'period': 10
},
sso_config={
'enabled': True,
'provider': 'oauth2',
'callback_path': '/auth/oauth/callback'
}
)
# Canvas LMS Template
self.service_templates['canvas'] = ServiceTemplate(
service_type='canvas',
image='instructure/canvas-lms:stable',
ports={'http': 3000},
environment={
'CANVAS_LMS_ADMIN_EMAIL': 'admin@${TENANT_DOMAIN}',
'CANVAS_LMS_ADMIN_PASSWORD': '${CANVAS_ADMIN_PASSWORD}',
'CANVAS_LMS_ACCOUNT_NAME': '${TENANT_NAME}',
'CANVAS_LMS_STATS_COLLECTION': 'opt_out',
'POSTGRES_PASSWORD': '${POSTGRES_PASSWORD}',
'DATABASE_CACHE_URL': 'postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants'
},
volumes=[
{'name': 'canvas-data', 'mountPath': '/app/log', 'size': '10Gi'},
{'name': 'canvas-files', 'mountPath': '/app/public/files', 'size': '20Gi'}
],
resource_limits={
'memory': '4Gi',
'cpu': '2000m'
},
security_context={
'runAsNonRoot': True,
'runAsUser': 1000,
'fsGroup': 1000
},
health_check={
'path': '/health_check',
'port': 3000,
'initial_delay': 60,
'period': 15
},
sso_config={
'enabled': True,
'provider': 'saml',
'metadata_url': '/auth/saml/metadata'
}
)
# Guacamole Template
self.service_templates['guacamole'] = ServiceTemplate(
service_type='guacamole',
image='guacamole/guacamole:1.5.3',
ports={'http': 8080},
environment={
'GUACD_HOSTNAME': 'guacd',
'GUACD_PORT': '4822',
'MYSQL_HOSTNAME': 'mysql',
'MYSQL_PORT': '3306',
'MYSQL_DATABASE': 'guacamole_db',
'MYSQL_USER': 'guacamole_user',
'MYSQL_PASSWORD': '${MYSQL_PASSWORD}',
'GUAC_LOG_LEVEL': 'INFO'
},
volumes=[
{'name': 'guacamole-data', 'mountPath': '/config', 'size': '1Gi'},
{'name': 'guacamole-recordings', 'mountPath': '/recordings', 'size': '10Gi'}
],
resource_limits={
'memory': '1Gi',
'cpu': '500m'
},
security_context={
'runAsNonRoot': True,
'runAsUser': 1001,
'fsGroup': 1001
},
health_check={
'path': '/guacamole',
'port': 8080,
'initial_delay': 45,
'period': 10
},
sso_config={
'enabled': True,
'provider': 'openid',
'extension': 'guacamole-auth-openid'
}
)
# JupyterHub Template
self.service_templates['jupyter'] = ServiceTemplate(
service_type='jupyter',
image='jupyterhub/jupyterhub:4.0',
ports={'http': 8000},
environment={
'JUPYTERHUB_CRYPT_KEY': '${JUPYTERHUB_CRYPT_KEY}',
'CONFIGPROXY_AUTH_TOKEN': '${CONFIGPROXY_AUTH_TOKEN}',
'DOCKER_NETWORK_NAME': 'jupyterhub',
'DOCKER_NOTEBOOK_IMAGE': 'jupyter/datascience-notebook:lab-4.0.7'
},
volumes=[
{'name': 'jupyter-data', 'mountPath': '/srv/jupyterhub', 'size': '5Gi'},
{'name': 'docker-socket', 'mountPath': '/var/run/docker.sock', 'hostPath': '/var/run/docker.sock'}
],
resource_limits={
'memory': '2Gi',
'cpu': '1000m'
},
security_context={
'runAsNonRoot': False, # Needs Docker access
'runAsUser': 0,
'privileged': True
},
health_check={
'path': '/hub/health',
'port': 8000,
'initial_delay': 30,
'period': 15
},
sso_config={
'enabled': True,
'provider': 'oauth',
'authenticator_class': 'oauthenticator.generic.GenericOAuthenticator'
}
)
async def create_service_instance(
self,
tenant_id: str,
service_type: str,
config_overrides: Dict[str, Any] = None
) -> ServiceInstance:
"""Create a new service instance for a tenant"""
if service_type not in self.service_templates:
raise ValueError(f"Unsupported service type: {service_type}")
template = self.service_templates[service_type]
instance_id = f"{service_type}-{tenant_id}-{uuid.uuid4().hex[:8]}"
namespace = f"{self.base_namespace}-{tenant_id}"
# Generate unique ports
external_port = await self._get_available_port()
# Create service instance object
instance = ServiceInstance(
instance_id=instance_id,
tenant_id=tenant_id,
service_type=service_type,
status='starting',
endpoint_url=f"https://{service_type}.{tenant_id}.gt2.com",
internal_port=template.ports['http'],
external_port=external_port,
namespace=namespace,
deployment_name=f"{service_type}-{instance_id}",
service_name=f"{service_type}-service-{instance_id}",
ingress_name=f"{service_type}-ingress-{instance_id}",
resource_usage={'cpu': 0, 'memory': 0, 'storage': 0}
)
try:
# Create Kubernetes namespace if not exists
await self._create_namespace(namespace, tenant_id)
# Deploy the service
await self._deploy_service(instance, template, config_overrides)
# Generate SSO token
instance.sso_token = await self._generate_sso_token(instance)
# Store instance
self.active_instances[instance_id] = instance
await self._persist_instance(instance)
logger.info(f"Created {service_type} instance {instance_id} for tenant {tenant_id}")
return instance
except Exception as e:
logger.error(f"Failed to create service instance: {e}")
instance.status = 'error'
raise
async def _create_namespace(self, namespace: str, tenant_id: str):
"""Create Kubernetes namespace with proper labeling and network policies"""
if not self.k8s_client:
logger.info(f"Mock: Created namespace {namespace}")
return
v1 = client.CoreV1Api(self.k8s_client)
# Create namespace
namespace_manifest = client.V1Namespace(
metadata=client.V1ObjectMeta(
name=namespace,
labels={
'gt.tenant-id': tenant_id,
'gt.cluster': 'resource',
'gt.isolation': 'tenant'
},
annotations={
'gt.created-by': 'service-manager',
'gt.creation-time': datetime.utcnow().isoformat()
}
)
)
try:
v1.create_namespace(namespace_manifest)
logger.info(f"Created namespace: {namespace}")
except ApiException as e:
if e.status == 409: # Already exists
logger.info(f"Namespace {namespace} already exists")
else:
raise
# Apply network policy for tenant isolation
await self._apply_network_policy(namespace, tenant_id)
async def _apply_network_policy(self, namespace: str, tenant_id: str):
"""Apply network policy for tenant isolation"""
if not self.k8s_client:
logger.info(f"Mock: Applied network policy to {namespace}")
return
networking_v1 = client.NetworkingV1Api(self.k8s_client)
# Network policy that only allows:
# 1. Intra-namespace communication
# 2. Communication to system namespaces (DNS, etc.)
# 3. Egress to external services (for updates, etc.)
network_policy = client.V1NetworkPolicy(
metadata=client.V1ObjectMeta(
name=f"tenant-isolation-{tenant_id}",
namespace=namespace,
labels={'gt.tenant-id': tenant_id}
),
spec=client.V1NetworkPolicySpec(
pod_selector=client.V1LabelSelector(), # All pods in namespace
policy_types=['Ingress', 'Egress'],
ingress=[
# Allow ingress from same namespace
client.V1NetworkPolicyIngressRule(
_from=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': namespace}
)
)]
),
# Allow ingress from ingress controller
client.V1NetworkPolicyIngressRule(
_from=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': 'ingress-nginx'}
)
)]
)
],
egress=[
# Allow egress within namespace
client.V1NetworkPolicyEgressRule(
to=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': namespace}
)
)]
),
# Allow DNS
client.V1NetworkPolicyEgressRule(
to=[client.V1NetworkPolicyPeer(
namespace_selector=client.V1LabelSelector(
match_labels={'name': 'kube-system'}
)
)],
ports=[client.V1NetworkPolicyPort(port=53, protocol='UDP')]
),
# Allow external HTTPS (for updates, etc.)
client.V1NetworkPolicyEgressRule(
ports=[
client.V1NetworkPolicyPort(port=443, protocol='TCP'),
client.V1NetworkPolicyPort(port=80, protocol='TCP')
]
)
]
)
)
try:
networking_v1.create_namespaced_network_policy(
namespace=namespace,
body=network_policy
)
logger.info(f"Applied network policy to namespace: {namespace}")
except ApiException as e:
if e.status == 409: # Already exists
logger.info(f"Network policy already exists in {namespace}")
else:
logger.error(f"Failed to create network policy: {e}")
raise
async def _deploy_service(
self,
instance: ServiceInstance,
template: ServiceTemplate,
config_overrides: Dict[str, Any] = None
):
"""Deploy service to Kubernetes cluster"""
if not self.k8s_client:
logger.info(f"Mock: Deployed {template.service_type} service")
instance.status = 'running'
return
# Prepare environment variables with tenant-specific values
environment = template.environment.copy()
if config_overrides:
environment.update(config_overrides.get('environment', {}))
# Substitute tenant-specific values
env_vars = []
for key, value in environment.items():
substituted_value = value.replace('${TENANT_ID}', instance.tenant_id)
substituted_value = substituted_value.replace('${TENANT_DOMAIN}', f"{instance.tenant_id}.gt2.com")
env_vars.append(client.V1EnvVar(name=key, value=substituted_value))
# Create volumes
volumes = []
volume_mounts = []
for vol_config in template.volumes:
vol_name = f"{vol_config['name']}-{instance.instance_id}"
volumes.append(client.V1Volume(
name=vol_name,
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
claim_name=vol_name
)
))
volume_mounts.append(client.V1VolumeMount(
name=vol_name,
mount_path=vol_config['mountPath']
))
# Create PVCs first
await self._create_persistent_volumes(instance, template)
# Create deployment
deployment = client.V1Deployment(
metadata=client.V1ObjectMeta(
name=instance.deployment_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id,
'gt.service-type': template.service_type
}
),
spec=client.V1DeploymentSpec(
replicas=1,
selector=client.V1LabelSelector(
match_labels={'instance': instance.instance_id}
),
template=client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
}
),
spec=client.V1PodSpec(
containers=[client.V1Container(
name=template.service_type,
image=template.image,
ports=[client.V1ContainerPort(
container_port=template.ports['http']
)],
env=env_vars,
volume_mounts=volume_mounts,
resources=client.V1ResourceRequirements(
limits=template.resource_limits,
requests=template.resource_limits
),
# Map camelCase template keys to the client's snake_case parameters
security_context=client.V1SecurityContext(
run_as_non_root=template.security_context.get('runAsNonRoot', True),
run_as_user=template.security_context.get('runAsUser'),
read_only_root_filesystem=template.security_context.get('readOnlyRootFilesystem'),
privileged=template.security_context.get('privileged', False)
),
liveness_probe=client.V1Probe(
http_get=client.V1HTTPGetAction(
path=template.health_check['path'],
port=template.health_check['port']
),
initial_delay_seconds=template.health_check['initial_delay'],
period_seconds=template.health_check['period']
),
readiness_probe=client.V1Probe(
http_get=client.V1HTTPGetAction(
path=template.health_check['path'],
port=template.health_check['port']
),
initial_delay_seconds=10,
period_seconds=5
)
)],
volumes=volumes,
security_context=client.V1PodSecurityContext(
run_as_non_root=template.security_context.get('runAsNonRoot', True),
fs_group=template.security_context.get('fsGroup', 1000)
)
)
)
)
)
# Deploy to Kubernetes
apps_v1 = client.AppsV1Api(self.k8s_client)
apps_v1.create_namespaced_deployment(
namespace=instance.namespace,
body=deployment
)
# Create service
await self._create_service(instance, template)
# Create ingress
await self._create_ingress(instance, template)
logger.info(f"Deployed {template.service_type} service: {instance.deployment_name}")
async def _create_persistent_volumes(self, instance: ServiceInstance, template: ServiceTemplate):
"""Create persistent volume claims for the service"""
if not self.k8s_client:
return
v1 = client.CoreV1Api(self.k8s_client)
for vol_config in template.volumes:
if 'hostPath' in vol_config: # Skip host path volumes
continue
pvc_name = f"{vol_config['name']}-{instance.instance_id}"
pvc = client.V1PersistentVolumeClaim(
metadata=client.V1ObjectMeta(
name=pvc_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
}
),
spec=client.V1PersistentVolumeClaimSpec(
access_modes=['ReadWriteOnce'],
resources=client.V1ResourceRequirements(
requests={'storage': vol_config['size']}
),
storage_class_name='fast-ssd' # Assuming SSD storage class
)
)
try:
v1.create_namespaced_persistent_volume_claim(
namespace=instance.namespace,
body=pvc
)
logger.info(f"Created PVC: {pvc_name}")
except ApiException as e:
if e.status != 409: # Ignore if already exists
raise
async def _create_service(self, instance: ServiceInstance, template: ServiceTemplate):
"""Create Kubernetes service for the instance"""
if not self.k8s_client:
return
v1 = client.CoreV1Api(self.k8s_client)
service = client.V1Service(
metadata=client.V1ObjectMeta(
name=instance.service_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
}
),
spec=client.V1ServiceSpec(
selector={'instance': instance.instance_id},
ports=[client.V1ServicePort(
port=80,
target_port=template.ports['http'],
protocol='TCP'
)],
type='ClusterIP'
)
)
v1.create_namespaced_service(
namespace=instance.namespace,
body=service
)
logger.info(f"Created service: {instance.service_name}")
async def _create_ingress(self, instance: ServiceInstance, template: ServiceTemplate):
"""Create ingress for external access with TLS"""
if not self.k8s_client:
return
networking_v1 = client.NetworkingV1Api(self.k8s_client)
hostname = f"{template.service_type}.{instance.tenant_id}.gt2.com"
ingress = client.V1Ingress(
metadata=client.V1ObjectMeta(
name=instance.ingress_name,
namespace=instance.namespace,
labels={
'app': template.service_type,
'instance': instance.instance_id,
'gt.tenant-id': instance.tenant_id
},
annotations={
'kubernetes.io/ingress.class': 'nginx',
'cert-manager.io/cluster-issuer': 'letsencrypt-prod',
'nginx.ingress.kubernetes.io/ssl-redirect': 'true',
'nginx.ingress.kubernetes.io/force-ssl-redirect': 'true',
'nginx.ingress.kubernetes.io/auth-url': f'https://auth.{instance.tenant_id}.gt2.com/auth',
'nginx.ingress.kubernetes.io/auth-signin': f'https://auth.{instance.tenant_id}.gt2.com/signin'
}
),
spec=client.V1IngressSpec(
tls=[client.V1IngressTLS(
hosts=[hostname],
secret_name=f"{template.service_type}-tls-{instance.instance_id}"
)],
rules=[client.V1IngressRule(
host=hostname,
http=client.V1HTTPIngressRuleValue(
paths=[client.V1HTTPIngressPath(
path='/',
path_type='Prefix',
backend=client.V1IngressBackend(
service=client.V1IngressServiceBackend(
name=instance.service_name,
port=client.V1ServiceBackendPort(number=80)
)
)
)]
)
)]
)
)
networking_v1.create_namespaced_ingress(
namespace=instance.namespace,
body=ingress
)
logger.info(f"Created ingress: {instance.ingress_name} for {hostname}")
async def _get_available_port(self) -> int:
"""Get next available port for service"""
used_ports = {instance.external_port for instance in self.active_instances.values()}
port = 30000 # Start from NodePort range
while port in used_ports:
port += 1
return port
async def _generate_sso_token(self, instance: ServiceInstance) -> str:
"""Generate SSO token for iframe embedding"""
token_data = {
'tenant_id': instance.tenant_id,
'service_type': instance.service_type,
'instance_id': instance.instance_id,
'expires_at': (datetime.utcnow() + timedelta(hours=24)).isoformat(),
'permissions': ['read', 'write', 'admin']
}
# Encrypt the token data
encrypted_token = encrypt_data(json.dumps(token_data))
return encrypted_token.decode('utf-8')
async def get_service_instance(self, instance_id: str) -> Optional[ServiceInstance]:
"""Get service instance by ID"""
return self.active_instances.get(instance_id)
async def list_tenant_instances(self, tenant_id: str) -> List[ServiceInstance]:
"""List all service instances for a tenant"""
return [
instance for instance in self.active_instances.values()
if instance.tenant_id == tenant_id
]
async def stop_service_instance(self, instance_id: str) -> bool:
"""Stop a running service instance"""
instance = self.active_instances.get(instance_id)
if not instance:
return False
try:
instance.status = 'stopping'
if self.k8s_client:
# Delete Kubernetes resources
await self._cleanup_kubernetes_resources(instance)
instance.status = 'stopped'
logger.info(f"Stopped service instance: {instance_id}")
return True
except Exception as e:
logger.error(f"Failed to stop instance {instance_id}: {e}")
instance.status = 'error'
return False
async def _cleanup_kubernetes_resources(self, instance: ServiceInstance):
"""Clean up all Kubernetes resources for an instance"""
if not self.k8s_client:
return
apps_v1 = client.AppsV1Api(self.k8s_client)
v1 = client.CoreV1Api(self.k8s_client)
networking_v1 = client.NetworkingV1Api(self.k8s_client)
try:
# Delete deployment
apps_v1.delete_namespaced_deployment(
name=instance.deployment_name,
namespace=instance.namespace,
body=client.V1DeleteOptions()
)
# Delete service
v1.delete_namespaced_service(
name=instance.service_name,
namespace=instance.namespace,
body=client.V1DeleteOptions()
)
# Delete ingress
networking_v1.delete_namespaced_ingress(
name=instance.ingress_name,
namespace=instance.namespace,
body=client.V1DeleteOptions()
)
# Delete PVCs (optional - may want to preserve data)
# Note: In production, you might want to keep PVCs for data persistence
logger.info(f"Cleaned up Kubernetes resources for: {instance.instance_id}")
except ApiException as e:
logger.error(f"Error cleaning up resources: {e}")
raise
async def get_service_health(self, instance_id: str) -> Dict[str, Any]:
"""Get health status of a service instance"""
instance = self.active_instances.get(instance_id)
if not instance:
return {'status': 'not_found'}
if not self.k8s_client:
return {
'status': 'healthy',
'instance_status': instance.status,
'endpoint': instance.endpoint_url,
'last_check': datetime.utcnow().isoformat()
}
# Check Kubernetes pod status
v1 = client.CoreV1Api(self.k8s_client)
try:
pods = v1.list_namespaced_pod(
namespace=instance.namespace,
label_selector=f'instance={instance.instance_id}'
)
if not pods.items:
return {
'status': 'no_pods',
'instance_status': instance.status
}
pod = pods.items[0]
pod_status = 'unknown'
if pod.status.phase == 'Running':
# Check container status
if pod.status.container_statuses:
container_status = pod.status.container_statuses[0]
if container_status.ready:
pod_status = 'healthy'
else:
pod_status = 'unhealthy'
else:
pod_status = 'starting'
elif pod.status.phase == 'Pending':
pod_status = 'starting'
elif pod.status.phase == 'Failed':
pod_status = 'failed'
# Update instance heartbeat
instance.last_heartbeat = datetime.utcnow()
return {
'status': pod_status,
'instance_status': instance.status,
'pod_phase': pod.status.phase,
'endpoint': instance.endpoint_url,
'last_check': datetime.utcnow().isoformat(),
'restart_count': pod.status.container_statuses[0].restart_count if pod.status.container_statuses else 0
}
except ApiException as e:
logger.error(f"Failed to get health for {instance_id}: {e}")
return {
'status': 'error',
'error': str(e),
'instance_status': instance.status
}
async def _persist_instance(self, instance: ServiceInstance):
"""Persist instance data to disk"""
instance_file = self.storage_path / f"{instance.instance_id}.json"
with open(instance_file, 'w') as f:
json.dump(instance.to_dict(), f, indent=2)
def _load_persistent_instances(self):
"""Load persistent instances from disk on startup"""
if not self.storage_path.exists():
return
for instance_file in self.storage_path.glob("*.json"):
try:
with open(instance_file, 'r') as f:
data = json.load(f)
# Reconstruct instance object
instance = ServiceInstance(
instance_id=data['instance_id'],
tenant_id=data['tenant_id'],
service_type=data['service_type'],
status=data['status'],
endpoint_url=data['endpoint_url'],
internal_port=data['internal_port'],
external_port=data['external_port'],
namespace=data['namespace'],
deployment_name=data['deployment_name'],
service_name=data['service_name'],
ingress_name=data['ingress_name'],
sso_token=data.get('sso_token'),
created_at=datetime.fromisoformat(data['created_at']),
last_heartbeat=datetime.fromisoformat(data['last_heartbeat']),
resource_usage=data.get('resource_usage', {})
)
self.active_instances[instance.instance_id] = instance
logger.info(f"Loaded persistent instance: {instance.instance_id}")
except Exception as e:
logger.error(f"Failed to load instance from {instance_file}: {e}")
async def cleanup_orphaned_resources(self):
"""Clean up orphaned Kubernetes resources"""
if not self.k8s_client:
return
logger.info("Starting cleanup of orphaned resources...")
# This would implement logic to find and clean up:
# 1. Deployments without corresponding instances
# 2. Services without deployments
# 3. Unused PVCs
# 4. Expired certificates
# Implementation would query Kubernetes for resources with GT labels
# and cross-reference with active instances
logger.info("Cleanup completed")

View File

@@ -0,0 +1,8 @@
"""
GT 2.0 Resource Cluster - Utilities Package
Common utilities for encryption, validation, and helper functions
"""
from .encryption import encrypt_data, decrypt_data
__all__ = ["encrypt_data", "decrypt_data"]

View File

@@ -0,0 +1,73 @@
"""
GT 2.0 Resource Cluster - Encryption Utilities
Secure data encryption for SSO tokens and sensitive data
"""
import base64
import os
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from typing import Union
import logging
logger = logging.getLogger(__name__)
class EncryptionManager:
"""Handles encryption and decryption of sensitive data"""
def __init__(self):
self._key = None
self._fernet = None
self._initialize_encryption()
def _initialize_encryption(self):
"""Initialize encryption key from environment or generate new one"""
# Get encryption key from environment or generate new one
key_material = os.environ.get("GT_ENCRYPTION_KEY", "default-dev-key-change-in-production")
# Derive a proper encryption key using PBKDF2
salt = b"GT2.0-Resource-Cluster-Salt" # Fixed salt for consistency
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(key_material.encode()))
self._key = key
self._fernet = Fernet(key)
logger.info("Encryption manager initialized")
def encrypt(self, data: Union[str, bytes]) -> bytes:
"""Encrypt data and return base64 encoded result"""
if isinstance(data, str):
data = data.encode('utf-8')
encrypted = self._fernet.encrypt(data)
return base64.urlsafe_b64encode(encrypted)
def decrypt(self, encrypted_data: Union[str, bytes]) -> str:
"""Decrypt base64 encoded data and return string"""
if isinstance(encrypted_data, str):
encrypted_data = encrypted_data.encode('utf-8')
# Decode from base64 first
decoded = base64.urlsafe_b64decode(encrypted_data)
# Decrypt
decrypted = self._fernet.decrypt(decoded)
return decrypted.decode('utf-8')
# Global encryption manager instance
_encryption_manager = EncryptionManager()
def encrypt_data(data: Union[str, bytes]) -> bytes:
"""Encrypt data using global encryption manager"""
return _encryption_manager.encrypt(data)
def decrypt_data(encrypted_data: Union[str, bytes]) -> str:
"""Decrypt data using global encryption manager"""
return _encryption_manager.decrypt(encrypted_data)
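# Round-trip usage sketch (illustrative only): encrypt_data returns URL-safe
# base64-encoded bytes and decrypt_data restores the original string.
if __name__ == "__main__":
    token = encrypt_data("tenant-a:sso-session")
    assert decrypt_data(token) == "tenant-a:sso-session"
    print(token.decode("utf-8"))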

View File

@@ -0,0 +1,54 @@
# Python cache and build artifacts
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
.hypothesis/
# Environment
.env
.env.*
.venv
env/
venv/
ENV/
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# Version control
.git/
.gitignore
# Documentation
README.md
*.md
# Logs
*.log

View File

@@ -0,0 +1,72 @@
[tool.poetry]
name = "gt2-resource-cluster"
version = "1.0.0"
description = "GT 2.0 Resource Cluster - Centralized AI resource management with HA support"
authors = ["GT Edge AI"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pydocstyle]
convention = "google"
add-ignore = ["D100", "D104"] # Allow missing docstrings in __init__.py
match = "(?!test_).*\\.py" # Exclude test files
[tool.pytest.ini_options]
minversion = "7.0"
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
"--cov=app",
"--cov-report=html",
"--cov-report=term-missing",
"--cov-fail-under=80",
"--strict-markers",
"-v",
]
markers = [
"unit: Fast isolated tests (<100ms)",
"integration: Cross-service tests",
"slow: Long-running tests (>1s)",
"security: Security-focused tests",
]
asyncio_mode = "auto"
[tool.black]
line-length = 100
target-version = ['py311']
[tool.mypy]
python_version = "3.11"
ignore_missing_imports = true
strict_optional = true
[tool.coverage.run]
source = ["app"]
omit = [
"*/tests/*",
"*/migrations/*",
"*/venv/*",
"*/env/*",
]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise AssertionError",
"raise NotImplementedError",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
[tool.bandit]
exclude_dirs = ["tests", "migrations", "venv", ".venv"]
skips = ["B101", "B601"] # B101=assert_used, B601=shell_injection (for subprocess)

View File

@@ -0,0 +1,34 @@
[tool:pytest]
minversion = 6.0
addopts =
-ra
--strict-markers
--strict-config
--cov=app
--cov-report=term-missing:skip-covered
--cov-report=html:htmlcov
--cov-report=xml
--cov-fail-under=80
-p no:warnings
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
markers =
slow: marks tests as slow
integration: marks tests as integration tests
unit: marks tests as unit tests
security: marks tests as security-focused
model_service: marks tests for model management service
groq_proxy: marks tests for groq proxy with haproxy
consul: marks tests for consul service registry
asyncio_mode = auto
env =
SECRET_KEY = test-secret-key-for-testing-only
JWT_SECRET = test-jwt-secret-for-testing-only
GROQ_API_KEY = test-groq-api-key
CONSUL_HOST = localhost
CONSUL_PORT = 8500
HAPROXY_GROQ_ENDPOINT = http://test-haproxy:8000
REDIS_URL = redis://localhost:6379/15
DEBUG = True

View File

@@ -0,0 +1,10 @@
# GT 2.0 Resource Cluster Development Dependencies
# Install with: pip install -r requirements-dev.txt
-r requirements.txt
# Testing
pytest==7.4.3
pytest-asyncio==0.21.1
pytest-cov==4.1.0
respx==0.20.2 # httpx mocking library

View File

@@ -0,0 +1,12 @@
# Testing dependencies for GT 2.0 Resource Cluster
pytest==7.4.3
pytest-asyncio==0.21.1
pytest-mock==3.12.0
pytest-cov==4.1.0
httpx==0.25.2
factory-boy==3.3.0
faker==20.1.0
freezegun==1.2.2
pytest-env==1.1.3
pytest-xdist==3.3.1
aiosqlite==0.19.0

View File

@@ -0,0 +1,52 @@
# GT 2.0 Resource Cluster Requirements (Production)
# FastAPI framework and dependencies
fastapi==0.121.2
uvicorn[standard]==0.38.0
python-multipart==0.0.20
# Async and networking
httpx==0.28.1
aiohttp==3.13.2
websockets==12.0
# Security and authentication
python-jose[cryptography]==3.4.0
passlib[bcrypt]==1.7.4
bcrypt==4.1.3
cryptography==44.0.1
PyJWT==2.10.1
# Database
sqlalchemy==2.0.44
asyncpg==0.29.0
# LLM and AI integrations
groq==0.34.1
openai==1.6.1
transformers>=4.35.0 # BGE-M3 tokenizer for accurate embedding token counting
# Document processing
pypdf==6.4.1
python-docx==1.1.0
markdown==3.5.1
beautifulsoup4==4.12.2
langchain-text-splitters==0.3.9
# Vector processing (numpy needed for transformers)
numpy==1.24.4
# Service discovery and load balancing
haproxy-stats==1.5
python-consul==1.1.0
# Monitoring and observability
prometheus-client==0.19.0
# Configuration and utilities
pydantic==2.12.4
pydantic-settings==2.1.0
python-dotenv==1.0.0
pyyaml==6.0.1
aiofiles==23.2.1