Files
gt-ai-os-community/apps/resource-cluster/app/api/v1/mcp_executor.py
HackWeasel 310491a557 GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents
- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
  - Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
2025-12-12 17:47:14 -05:00

424 lines
14 KiB
Python

"""
GT 2.0 MCP Tool Executor
Handles execution of MCP tools from agents. This is the main endpoint
that receives tool calls from the tenant backend and routes them to
the appropriate MCP servers with proper authentication and rate limiting.
"""
import asyncio
import logging
import time
from datetime import datetime
from typing import Dict, Any, List, Optional, Union

from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel, Field

# Removed: from app.core.security import verify_capability_token
from app.services.mcp_rag_server import mcp_rag_server
# Module-level logger and the router that all MCP execution endpoints below
# are registered on (mounted under the "/mcp" prefix).
logger = logging.getLogger(name=__name__)
router = APIRouter(tags=["mcp_execution"], prefix="/mcp")
# Request/Response Models
class MCPToolCall(BaseModel):
    """MCP tool call request: which tool to run, on which MCP server, with which parameters.

    All three fields are required (declared with ``Field(...)``).
    """
    tool_name: str = Field(..., description="Name of the tool to execute")
    server_name: str = Field(..., description="MCP server that provides the tool")
    parameters: Dict[str, Any] = Field(..., description="Tool parameters")
class MCPToolResult(BaseModel):
    """MCP tool execution result returned by the executor endpoints."""
    # True when the tool call completed without an error.
    success: bool
    # Echo of the requested tool and MCP server.
    tool_name: str
    server_name: str
    # Elapsed time spent handling the call, in milliseconds.
    execution_time_ms: float
    # Result dict produced by the MCP server (the executor sends {} when the call raised).
    result: Dict[str, Any]
    # Error message for failed calls; None on success.
    error: Optional[str] = None
    # Completion time; the executor fills it with datetime.isoformat().
    timestamp: str
class MCPBatchRequest(BaseModel):
    """Request for executing multiple MCP tools in a single call."""
    # Between 1 and 10 tool calls per batch.
    # NOTE(review): min_items/max_items are the Pydantic v1 keyword names
    # (v2 uses min_length/max_length) — confirm the installed Pydantic
    # version still enforces them.
    tool_calls: List[MCPToolCall] = Field(..., min_items=1, max_items=10)
class MCPBatchResponse(BaseModel):
    """Response for batch tool execution: per-call results plus aggregate counts."""
    # One entry per requested tool call.
    results: List[MCPToolResult]
    # Number of entries in `results` with success=True / success=False.
    success_count: int
    error_count: int
    # Elapsed time for the whole batch, in milliseconds.
    total_execution_time_ms: float
# Rate limiting (simple in-memory, per-process sliding window).
# Maps "<user_id>:<server_name>" -> monotonic timestamps (seconds) of the
# requests accepted for that key within its most recent window.
_rate_limits: Dict[str, List[float]] = {}


def check_rate_limit(
    user_id: str,
    server_name: str,
    max_requests: int = 60,
    window_seconds: float = 60.0,
) -> bool:
    """Accept and record a request if the caller is under its rate limit.

    Sliding-window limiter: a request is accepted when fewer than
    ``max_requests`` requests were accepted for the same
    ``(user_id, server_name)`` pair during the last ``window_seconds``.

    Args:
        user_id: Caller identity; part of the bucket key.
        server_name: MCP server being called; limits are tracked per server.
        max_requests: Maximum accepted requests per window (default 60).
        window_seconds: Length of the sliding window in seconds (default 60).

    Returns:
        True if the request was accepted (and recorded), False if the caller
        is over the limit (the rejected request is not recorded).
    """
    # TODO: Implement proper rate limiting with Redis or similar; this state
    # is per-process and lost on restart.
    key = f"{user_id}:{server_name}"
    # Monotonic clock: wall-clock adjustments (NTP steps, manual changes)
    # would otherwise keep stale entries alive or expire fresh ones early.
    now = time.monotonic()
    recent = [t for t in _rate_limits.get(key, ()) if now - t < window_seconds]
    if len(recent) >= max_requests:
        _rate_limits[key] = recent
        return False
    recent.append(now)
    _rate_limits[key] = recent
    return True
@router.post("/tool", response_model=MCPToolResult)
async def execute_mcp_tool(
    request: MCPToolCall,
    x_tenant_domain: str = Header(..., description="Tenant domain for isolation"),
    x_user_id: str = Header(..., description="User ID for authorization"),
    agent_context: Optional[Dict[str, Any]] = None
):
    """
    Execute a single MCP tool.

    This is the main endpoint that agents use to execute MCP tools. It checks
    that the tenant/user headers are present, rejects unknown MCP servers,
    enforces the per-user, per-server rate limit and routes the call to the
    matching MCP server. User authentication is handled by the tenant backend
    before reaching here.

    It is also called directly as a Python function by the other endpoints in
    this module, which pass the header values explicitly.

    Args:
        request: Tool name, target server name and tool parameters.
        x_tenant_domain: Tenant domain header value.
        x_user_id: User ID header value; also part of the rate-limit key.
        agent_context: Optional agent context forwarded to the MCP server.

    NOTE(review): ``agent_context`` is a dict-typed parameter, so FastAPI most
    likely treats it as a second body field. That would make the HTTP
    ``/tool`` endpoint expect an embedded body
    (``{"request": ..., "agent_context": ...}``) — confirm this is intended.

    Returns:
        MCPToolResult. ``success`` is False when the server's result dict
        contains an ``"error"`` key, or when handling the call raised.

    Raises:
        HTTPException: 400 if a header is missing, 404 for an unknown MCP
            server, 429 when the rate limit is exceeded.
    """
    started = time.perf_counter()
    try:
        # Validate required headers
        if not x_user_id or not x_tenant_domain:
            raise HTTPException(
                status_code=400,
                detail="Missing required authentication headers"
            )

        # Reject unknown servers before touching the rate limiter, so that
        # arbitrary client-supplied server names cannot grow its state.
        if request.server_name != "rag_server":
            raise HTTPException(
                status_code=404,
                detail=f"Unknown MCP server: {request.server_name}"
            )

        if not check_rate_limit(x_user_id, request.server_name):
            raise HTTPException(
                status_code=429,
                detail="Rate limit exceeded for MCP server"
            )

        # Route to the MCP server (no capability token needed)
        result = await mcp_rag_server.handle_tool_call(
            tool_name=request.tool_name,
            parameters=request.parameters,
            tenant_domain=x_tenant_domain,
            user_id=x_user_id,
            agent_context=agent_context
        )

        execution_time = (time.perf_counter() - started) * 1000

        # The server reports tool failures in-band through an "error" key.
        success = "error" not in result
        error_message = None if success else result.get("error")

        logger.info(
            f"🔧 MCP Tool executed: {request.tool_name} ({execution_time:.2f}ms) - "
            f"{'success' if success else 'error'}"
        )

        return MCPToolResult(
            success=success,
            tool_name=request.tool_name,
            server_name=request.server_name,
            execution_time_ms=execution_time,
            result=result,
            error=error_message,
            timestamp=datetime.now().isoformat()
        )

    except HTTPException:
        raise
    except Exception as e:
        # The error is converted into a failed result, so keep the traceback in the log.
        logger.exception(f"Error executing MCP tool {request.tool_name}: {e}")
        return MCPToolResult(
            success=False,
            tool_name=request.tool_name,
            server_name=request.server_name,
            execution_time_ms=(time.perf_counter() - started) * 1000,
            result={},
            error=f"Tool execution failed: {str(e)}",
            timestamp=datetime.now().isoformat()
        )
class MCPExecuteRequest(BaseModel):
    """Direct execution request format used by RAG orchestrator.

    Unlike ``/tool``, the tenant and user identity travel in the body rather
    than in headers. All fields except ``agent_context`` are required.
    """
    server_id: str = Field(..., description="Server ID (rag_server)")
    tool_name: str = Field(..., description="Tool name to execute")
    parameters: Dict[str, Any] = Field(..., description="Tool parameters")
    tenant_domain: str = Field(..., description="Tenant domain")
    user_id: str = Field(..., description="User ID")
    # Optional; forwarded unchanged to the tool execution.
    agent_context: Optional[Dict[str, Any]] = Field(None, description="Agent context with dataset info")
@router.post("/execute")
async def execute_mcp_direct(request: MCPExecuteRequest):
    """
    Execute an MCP tool for the RAG orchestrator, with identity in the body.

    The ``server_id`` is translated to an internal server name, and the call
    is delegated to ``execute_mcp_tool`` together with the request's agent
    context. No capability token is used; the tenant/user fields supply the
    authorization context.

    Returns:
        The tool's raw result dict on success; otherwise
        ``{"success": False, "error": <message>}``.

    Raises:
        HTTPException: 400 for an unknown ``server_id``; HTTPExceptions from
            ``execute_mcp_tool`` are propagated unchanged.
    """
    logger.info(f"🔧 Direct MCP execution request: server={request.server_id}, tool={request.tool_name}, tenant={request.tenant_domain}, user={request.user_id}")
    logger.debug(f"📝 Tool parameters: {request.parameters}")

    try:
        # server_id (external) -> server_name (internal routing key)
        known_servers = {"rag_server": "rag_server"}
        resolved_name = known_servers.get(request.server_id)

        if not resolved_name:
            logger.error(f"❌ Unknown server_id: {request.server_id}")
            raise HTTPException(status_code=400, detail=f"Unknown server_id: {request.server_id}")

        logger.info(f"🎯 Mapped server_id '{request.server_id}' → server_name '{resolved_name}'")

        outcome = await execute_mcp_tool(
            request=MCPToolCall(
                tool_name=request.tool_name,
                server_name=resolved_name,
                parameters=request.parameters,
            ),
            x_tenant_domain=request.tenant_domain,
            x_user_id=request.user_id,
            agent_context=request.agent_context,
        )

        # Shape expected by the RAG orchestrator: raw result on success,
        # a small error envelope otherwise.
        if not outcome.success:
            return {"success": False, "error": outcome.error}
        return outcome.result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Direct MCP execution failed: {e}")
        return {"success": False, "error": "MCP execution failed"}
@router.post("/batch", response_model=MCPBatchResponse)
async def execute_mcp_batch(
    request: MCPBatchRequest,
    x_tenant_domain: str = Header(..., description="Tenant domain for isolation"),
    x_user_id: str = Header(..., description="User ID for authorization")
):
    """
    Execute multiple MCP tools in batch.

    All tool calls run concurrently through ``execute_mcp_tool`` with the same
    tenant/user headers. A failure of one call is reported as a failed entry
    in ``results`` and does not abort the batch. This includes an
    HTTPException such as an unknown server or an exceeded rate limit.

    Args:
        request: The tool calls to run (1 to 10 per ``MCPBatchRequest``).
        x_tenant_domain: Tenant domain header value.
        x_user_id: User ID header value.

    Returns:
        MCPBatchResponse with one result per tool call in request order, the
        success/error counts, and the total elapsed time in milliseconds.

    Raises:
        HTTPException: 400 if a header is missing, 500 on unexpected failure.
    """
    started = time.perf_counter()

    try:
        # Validate required headers
        if not x_user_id or not x_tenant_domain:
            raise HTTPException(
                status_code=400,
                detail="Missing required authentication headers"
            )

        tool_calls = request.tool_calls

        # return_exceptions=True: a failing call becomes an entry in
        # `outcomes` instead of propagating out of gather().
        outcomes = await asyncio.gather(
            *(
                execute_mcp_tool(
                    request=tool_call,
                    x_tenant_domain=x_tenant_domain,
                    x_user_id=x_user_id
                )
                for tool_call in tool_calls
            ),
            return_exceptions=True
        )

        tool_results = []
        # gather() returns results in input order, so each outcome pairs with
        # its tool call and failed entries keep the real tool/server names.
        for tool_call, outcome in zip(tool_calls, outcomes):
            # BaseException rather than Exception: asyncio.CancelledError is
            # not an Exception subclass (Python 3.8+) but can be returned here.
            if isinstance(outcome, BaseException):
                if isinstance(outcome, HTTPException):
                    error = f"{outcome.status_code}: {outcome.detail}"
                else:
                    error = str(outcome)
                tool_results.append(MCPToolResult(
                    success=False,
                    tool_name=tool_call.tool_name,
                    server_name=tool_call.server_name,
                    execution_time_ms=0,
                    result={},
                    error=error,
                    timestamp=datetime.now().isoformat()
                ))
            else:
                tool_results.append(outcome)

        success_count = sum(1 for tool_result in tool_results if tool_result.success)
        error_count = len(tool_results) - success_count

        return MCPBatchResponse(
            results=tool_results,
            success_count=success_count,
            error_count=error_count,
            total_execution_time_ms=(time.perf_counter() - started) * 1000
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error executing MCP batch: {e}")
        raise HTTPException(status_code=500, detail=f"Batch execution failed: {str(e)}") from e
@router.post("/rag/{tool_name}")
async def execute_rag_tool(
    tool_name: str,
    parameters: Dict[str, Any],
    x_tenant_domain: Optional[str] = Header(None),
    x_user_id: Optional[str] = Header(None)
):
    """
    Run a RAG tool without naming the MCP server explicitly.

    A shortcut over ``execute_mcp_tool`` that always targets ``"rag_server"``.
    The tool name comes from the URL path and the parameters from the request
    body. Missing tenant/user headers arrive as ``None``, and
    ``execute_mcp_tool`` rejects them with a 400.
    """
    rag_call = MCPToolCall(
        parameters=parameters,
        server_name="rag_server",
        tool_name=tool_name,
    )
    outcome = await execute_mcp_tool(
        x_user_id=x_user_id,
        x_tenant_domain=x_tenant_domain,
        request=rag_call,
    )
    return outcome
@router.post("/conversation/{tool_name}")
async def execute_conversation_tool(
    tool_name: str,
    parameters: Dict[str, Any],
    x_tenant_domain: Optional[str] = Header(None),
    x_user_id: Optional[str] = Header(None)
):
    """
    Run a conversation-search tool without naming the MCP server explicitly.

    A shortcut over ``execute_mcp_tool`` that always targets
    ``"conversation_server"``. The tool name comes from the URL path and the
    parameters from the request body.

    NOTE(review): ``execute_mcp_tool`` only routes ``"rag_server"``, so calls
    to this endpoint currently fail with a 404 ("Unknown MCP server") unless
    they are rejected earlier (missing headers, rate limiting).
    """
    conversation_call = MCPToolCall(
        parameters=parameters,
        server_name="conversation_server",
        tool_name=tool_name,
    )
    outcome = await execute_mcp_tool(
        x_user_id=x_user_id,
        x_tenant_domain=x_tenant_domain,
        request=conversation_call,
    )
    return outcome
@router.get("/status")
async def get_executor_status(
    x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
):
    """
    Get status of the MCP executor and its MCP servers.

    Returns health information and statistics derived from the in-memory
    rate-limiter state.

    Args:
        x_tenant_id: Optional tenant ID header; not used in the response.

    Returns:
        Dict with overall ``status``, a ``timestamp``, ``statistics`` and a
        per-server ``servers`` section.

    Raises:
        HTTPException: 500 if collecting the status fails.
    """
    try:
        # Rate-limiter buckets are keyed "<user_id>:<server_name>" and hold
        # timestamps of recently accepted requests. Buckets are only pruned
        # when the same key is checked again, so these numbers are approximate.
        total_requests = sum(len(timestamps) for timestamps in _rate_limits.values())
        # Counts user/server pairs seen by the limiter, not distinct users.
        active_users = len(_rate_limits)
        rag_tools = mcp_rag_server.available_tools

        return {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
            "statistics": {
                # Key kept for compatibility; reflects the limiter's retained
                # entries rather than a true one-hour count.
                "total_requests_last_hour": total_requests,
                "active_users": active_users,
                # Only the RAG server is imported and routed by this module.
                "available_servers": 1,
                "total_tools": len(rag_tools)
            },
            "servers": {
                "rag_server": {
                    "status": "healthy",
                    "tools_count": len(rag_tools),
                    "tools": rag_tools
                },
                # Kept for response-shape compatibility: no conversation MCP
                # server is imported here (referencing one raised NameError).
                "conversation_server": {
                    "status": "unavailable",
                    "tools_count": 0,
                    "tools": []
                }
            }
        }
    except Exception as e:
        logger.exception(f"Error getting executor status: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}") from e
# Health check endpoint
@router.get("/health")
async def health_check():
    """Liveness probe: report that the MCP executor service is up, with the current time."""
    payload = dict(status="healthy")
    payload["timestamp"] = datetime.now().isoformat()
    payload["service"] = "mcp_executor"
    return payload