- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2 - Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2 - Made more general-purpose (flexible targets, expanded tools) - Added nemotron-mini-agent.csv for fast local inference via Ollama - Added nemotron-agent.csv for advanced reasoning via Ollama - Added wiki page: Projects for NVIDIA NIMs and Nemotron
424 lines
14 KiB
Python
424 lines
14 KiB
Python
"""
|
|
GT 2.0 MCP Tool Executor
|
|
|
|
Handles execution of MCP tools from agents. This is the main endpoint
|
|
that receives tool calls from the tenant backend and routes them to
|
|
the appropriate MCP servers with proper authentication and rate limiting.
|
|
"""
|
|
|
|
from typing import Dict, Any, List, Optional, Union
|
|
from fastapi import APIRouter, HTTPException, Header
|
|
from pydantic import BaseModel, Field
|
|
import logging
|
|
import asyncio
|
|
from datetime import datetime
|
|
|
|
# Removed: from app.core.security import verify_capability_token
|
|
from app.services.mcp_rag_server import mcp_rag_server
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Every route declared in this file is registered under the "/mcp" prefix.
router = APIRouter(prefix="/mcp", tags=["mcp_execution"])
|
|
|
|
|
|
# Request/Response Models
|
|
class MCPToolCall(BaseModel):
    """MCP tool call request"""

    # All three fields are required: Field(...) declares no default.
    # Name of the tool; passed through to the MCP server's handler.
    tool_name: str = Field(..., description="Name of the tool to execute")
    # Routing key used by the executor to choose which MCP server is called.
    server_name: str = Field(..., description="MCP server that provides the tool")
    # Free-form arguments; handed to the tool without inspection in this file.
    parameters: Dict[str, Any] = Field(..., description="Tool parameters")
|
|
|
|
|
|
class MCPToolResult(BaseModel):
    """MCP tool execution result"""

    # True when the server's response dict carried no "error" key.
    success: bool
    # Echo of the tool / server the call was addressed to.
    tool_name: str
    server_name: str
    # Elapsed time of the call in milliseconds.
    execution_time_ms: float
    # Payload returned by the MCP server; {} when the executor itself failed.
    result: Dict[str, Any]
    # Error text for failed calls; stays None on success.
    error: Optional[str] = None
    # Completion time, produced with datetime.isoformat().
    timestamp: str
|
|
|
|
|
|
class MCPBatchRequest(BaseModel):
    """Request for executing multiple MCP tools"""

    # A batch must contain between 1 and 10 tool calls (inclusive).
    # NOTE(review): Pydantic v2 renamed min_items/max_items to
    # min_length/max_length - confirm which Pydantic version is pinned.
    tool_calls: List[MCPToolCall] = Field(..., min_items=1, max_items=10)
|
|
|
|
|
|
class MCPBatchResponse(BaseModel):
    """Response for batch tool execution"""

    # One entry per submitted tool call.
    results: List[MCPToolResult]
    # Number of entries in `results` with success=True / success=False.
    success_count: int
    error_count: int
    # Elapsed time for the whole batch, in milliseconds.
    total_execution_time_ms: float
|
|
|
|
|
|
# Rate limiting (simple in-memory counter)
# Maps "<user_id>:<server_name>" to the timestamps (epoch seconds) of the
# accepted requests that were still inside the window at the last check.
_rate_limits: Dict[str, List[float]] = {}


def check_rate_limit(
    user_id: str,
    server_name: str,
    limit: int = 60,
    window_seconds: float = 60.0,
) -> bool:
    """Record a request and report whether it fits within the rate limit.

    A sliding window is kept per (user, server) pair. Only accepted requests
    are recorded, so rejected calls do not extend the time a caller stays
    blocked.

    Args:
        user_id: Identifier of the calling user.
        server_name: MCP server the request is addressed to.
        limit: Maximum number of accepted requests per window
            (default 60, the previously hard-coded value).
        window_seconds: Length of the sliding window in seconds
            (default 60, the previously hard-coded value).

    Returns:
        True if the request is allowed (and has been recorded), False if the
        caller has already used up ``limit`` requests within the window.
    """
    # TODO: Implement proper rate limiting with Redis or similar
    key = f"{user_id}:{server_name}"
    now = datetime.now().timestamp()

    # Keep only the timestamps that are still inside the window.
    recent = [t for t in _rate_limits.get(key, ()) if now - t < window_seconds]

    allowed = len(recent) < limit
    if allowed:
        recent.append(now)

    # Store the pruned list even on rejection so stale entries are dropped.
    _rate_limits[key] = recent
    return allowed
|
|
|
|
|
|
@router.post("/tool", response_model=MCPToolResult)
|
|
async def execute_mcp_tool(
|
|
request: MCPToolCall,
|
|
x_tenant_domain: str = Header(..., description="Tenant domain for isolation"),
|
|
x_user_id: str = Header(..., description="User ID for authorization"),
|
|
agent_context: Optional[Dict[str, Any]] = None
|
|
):
|
|
"""
|
|
Execute a single MCP tool.
|
|
|
|
This is the main endpoint that agents use to execute MCP tools.
|
|
It handles rate limiting and routing to the appropriate MCP server.
|
|
User authentication is handled by the tenant backend before reaching here.
|
|
"""
|
|
start_time = datetime.now()
|
|
|
|
try:
|
|
# Validate required headers
|
|
if not x_user_id or not x_tenant_domain:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="Missing required authentication headers"
|
|
)
|
|
|
|
# Check rate limiting
|
|
if not check_rate_limit(x_user_id, request.server_name):
|
|
raise HTTPException(
|
|
status_code=429,
|
|
detail="Rate limit exceeded for MCP server"
|
|
)
|
|
|
|
# Route to appropriate MCP server (no capability token needed)
|
|
if request.server_name == "rag_server":
|
|
result = await mcp_rag_server.handle_tool_call(
|
|
tool_name=request.tool_name,
|
|
parameters=request.parameters,
|
|
tenant_domain=x_tenant_domain,
|
|
user_id=x_user_id,
|
|
agent_context=agent_context
|
|
)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=404,
|
|
detail=f"Unknown MCP server: {request.server_name}"
|
|
)
|
|
|
|
# Calculate execution time
|
|
end_time = datetime.now()
|
|
execution_time = (end_time - start_time).total_seconds() * 1000
|
|
|
|
# Check if tool execution was successful
|
|
success = "error" not in result
|
|
error_message = result.get("error") if not success else None
|
|
|
|
logger.info(f"🔧 MCP Tool executed: {request.tool_name} ({execution_time:.2f}ms) - {'✅' if success else '❌'}")
|
|
|
|
return MCPToolResult(
|
|
success=success,
|
|
tool_name=request.tool_name,
|
|
server_name=request.server_name,
|
|
execution_time_ms=execution_time,
|
|
result=result,
|
|
error=error_message,
|
|
timestamp=end_time.isoformat()
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error executing MCP tool {request.tool_name}: {e}")
|
|
|
|
end_time = datetime.now()
|
|
execution_time = (end_time - start_time).total_seconds() * 1000
|
|
|
|
return MCPToolResult(
|
|
success=False,
|
|
tool_name=request.tool_name,
|
|
server_name=request.server_name,
|
|
execution_time_ms=execution_time,
|
|
result={},
|
|
error=f"Tool execution failed: {str(e)}",
|
|
timestamp=end_time.isoformat()
|
|
)
|
|
|
|
|
|
class MCPExecuteRequest(BaseModel):
    """Direct execution request format used by RAG orchestrator"""

    # Identifier translated to an executor server name by /execute.
    server_id: str = Field(..., description="Server ID (rag_server)")
    tool_name: str = Field(..., description="Tool name to execute")
    parameters: Dict[str, Any] = Field(..., description="Tool parameters")
    # Caller identity travels in the body here instead of in HTTP headers.
    tenant_domain: str = Field(..., description="Tenant domain")
    user_id: str = Field(..., description="User ID")
    # Optional; defaults to None and is forwarded unchanged to the tool call.
    agent_context: Optional[Dict[str, Any]] = Field(None, description="Agent context with dataset info")
|
|
|
|
|
|
@router.post("/execute")
|
|
async def execute_mcp_direct(request: MCPExecuteRequest):
|
|
"""
|
|
Direct execution endpoint used by RAG orchestrator.
|
|
Simplified without capability tokens - uses user context for authorization.
|
|
"""
|
|
logger.info(f"🔧 Direct MCP execution request: server={request.server_id}, tool={request.tool_name}, tenant={request.tenant_domain}, user={request.user_id}")
|
|
logger.debug(f"📝 Tool parameters: {request.parameters}")
|
|
|
|
try:
|
|
# Map server_id to server_name
|
|
server_mapping = {
|
|
"rag_server": "rag_server"
|
|
}
|
|
|
|
server_name = server_mapping.get(request.server_id)
|
|
if not server_name:
|
|
logger.error(f"❌ Unknown server_id: {request.server_id}")
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Unknown server_id: {request.server_id}"
|
|
)
|
|
|
|
logger.info(f"🎯 Mapped server_id '{request.server_id}' → server_name '{server_name}'")
|
|
|
|
# Create simplified tool call request
|
|
tool_call = MCPToolCall(
|
|
tool_name=request.tool_name,
|
|
server_name=server_name,
|
|
parameters=request.parameters
|
|
)
|
|
|
|
# Execute the tool with agent context
|
|
result = await execute_mcp_tool(
|
|
request=tool_call,
|
|
x_tenant_domain=request.tenant_domain,
|
|
x_user_id=request.user_id,
|
|
agent_context=request.agent_context
|
|
)
|
|
|
|
# Return result in format expected by RAG orchestrator
|
|
if result.success:
|
|
return result.result
|
|
else:
|
|
return {
|
|
"success": False,
|
|
"error": result.error
|
|
}
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Direct MCP execution failed: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": "MCP execution failed"
|
|
}
|
|
|
|
|
|
@router.post("/batch", response_model=MCPBatchResponse)
|
|
async def execute_mcp_batch(
|
|
request: MCPBatchRequest,
|
|
x_tenant_domain: str = Header(..., description="Tenant domain for isolation"),
|
|
x_user_id: str = Header(..., description="User ID for authorization")
|
|
):
|
|
"""
|
|
Execute multiple MCP tools in batch.
|
|
|
|
Useful for agents that need to call multiple tools simultaneously
|
|
for more efficient execution.
|
|
"""
|
|
batch_start_time = datetime.now()
|
|
|
|
try:
|
|
# Validate required headers
|
|
if not x_user_id or not x_tenant_domain:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="Missing required authentication headers"
|
|
)
|
|
|
|
# Execute all tool calls concurrently
|
|
tasks = []
|
|
for tool_call in request.tool_calls:
|
|
# Create individual tool call request
|
|
individual_request = MCPToolCall(
|
|
tool_name=tool_call.tool_name,
|
|
server_name=tool_call.server_name,
|
|
parameters=tool_call.parameters
|
|
)
|
|
|
|
# Create task for concurrent execution
|
|
task = execute_mcp_tool(
|
|
request=individual_request,
|
|
x_tenant_domain=x_tenant_domain,
|
|
x_user_id=x_user_id
|
|
)
|
|
tasks.append(task)
|
|
|
|
# Execute all tools concurrently
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
# Process results
|
|
tool_results = []
|
|
success_count = 0
|
|
error_count = 0
|
|
|
|
for result in results:
|
|
if isinstance(result, Exception):
|
|
# Handle exceptions from individual tool calls
|
|
tool_results.append(MCPToolResult(
|
|
success=False,
|
|
tool_name="unknown",
|
|
server_name="unknown",
|
|
execution_time_ms=0,
|
|
result={},
|
|
error=str(result),
|
|
timestamp=datetime.now().isoformat()
|
|
))
|
|
error_count += 1
|
|
else:
|
|
tool_results.append(result)
|
|
if result.success:
|
|
success_count += 1
|
|
else:
|
|
error_count += 1
|
|
|
|
# Calculate total execution time
|
|
batch_end_time = datetime.now()
|
|
total_execution_time = (batch_end_time - batch_start_time).total_seconds() * 1000
|
|
|
|
return MCPBatchResponse(
|
|
results=tool_results,
|
|
success_count=success_count,
|
|
error_count=error_count,
|
|
total_execution_time_ms=total_execution_time
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error executing MCP batch: {e}")
|
|
raise HTTPException(status_code=500, detail=f"Batch execution failed: {str(e)}")
|
|
|
|
|
|
@router.post("/rag/{tool_name}")
|
|
async def execute_rag_tool(
|
|
tool_name: str,
|
|
parameters: Dict[str, Any],
|
|
x_tenant_domain: Optional[str] = Header(None),
|
|
x_user_id: Optional[str] = Header(None)
|
|
):
|
|
"""
|
|
Direct endpoint for executing RAG tools.
|
|
|
|
Convenience endpoint for common RAG operations without
|
|
needing to specify server name.
|
|
"""
|
|
# Create standard tool call request
|
|
tool_call = MCPToolCall(
|
|
tool_name=tool_name,
|
|
server_name="rag_server",
|
|
parameters=parameters
|
|
)
|
|
|
|
return await execute_mcp_tool(
|
|
request=tool_call,
|
|
x_tenant_domain=x_tenant_domain,
|
|
x_user_id=x_user_id
|
|
)
|
|
|
|
|
|
@router.post("/conversation/{tool_name}")
|
|
async def execute_conversation_tool(
|
|
tool_name: str,
|
|
parameters: Dict[str, Any],
|
|
x_tenant_domain: Optional[str] = Header(None),
|
|
x_user_id: Optional[str] = Header(None)
|
|
):
|
|
"""
|
|
Direct endpoint for executing conversation search tools.
|
|
|
|
Convenience endpoint for common conversation search operations
|
|
without needing to specify server name.
|
|
"""
|
|
# Create standard tool call request
|
|
tool_call = MCPToolCall(
|
|
tool_name=tool_name,
|
|
server_name="conversation_server",
|
|
parameters=parameters
|
|
)
|
|
|
|
return await execute_mcp_tool(
|
|
request=tool_call,
|
|
x_tenant_domain=x_tenant_domain,
|
|
x_user_id=x_user_id
|
|
)
|
|
|
|
|
|
@router.get("/status")
|
|
async def get_executor_status(
|
|
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID", description="Tenant ID for context")
|
|
):
|
|
"""
|
|
Get status of the MCP executor and connected servers.
|
|
|
|
Returns health information and statistics about MCP tool execution.
|
|
"""
|
|
try:
|
|
# Calculate basic statistics
|
|
total_requests = sum(len(requests) for requests in _rate_limits.values())
|
|
active_users = len(_rate_limits)
|
|
|
|
return {
|
|
"status": "healthy",
|
|
"timestamp": datetime.now().isoformat(),
|
|
"statistics": {
|
|
"total_requests_last_hour": total_requests, # Approximate
|
|
"active_users": active_users,
|
|
"available_servers": 2, # RAG and conversation servers
|
|
"total_tools": len(mcp_rag_server.available_tools) + len(mcp_conversation_server.available_tools)
|
|
},
|
|
"servers": {
|
|
"rag_server": {
|
|
"status": "healthy",
|
|
"tools_count": len(mcp_rag_server.available_tools),
|
|
"tools": mcp_rag_server.available_tools
|
|
},
|
|
"conversation_server": {
|
|
"status": "healthy",
|
|
"tools_count": len(mcp_conversation_server.available_tools),
|
|
"tools": mcp_conversation_server.available_tools
|
|
}
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting executor status: {e}")
|
|
raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}")
|
|
|
|
|
|
# Health check endpoint
|
|
@router.get("/health")
|
|
async def health_check():
|
|
"""Simple health check endpoint"""
|
|
return {
|
|
"status": "healthy",
|
|
"timestamp": datetime.now().isoformat(),
|
|
"service": "mcp_executor"
|
|
} |