GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
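For context on the hostname-validation and SSRF items above, a rough sketch of the technique being described follows. This is illustrative only, not the code shipped in this release; `is_url_allowed` and `ALLOWED_HOSTS` are hypothetical names.

    import ipaddress
    import socket
    from urllib.parse import urlparse

    # Hypothetical allowlist; a real deployment would load this from configuration.
    ALLOWED_HOSTS = {"api.example.com"}

    def is_url_allowed(url: str) -> bool:
        parsed = urlparse(url)
        if parsed.scheme not in ("http", "https"):
            return False
        hostname = parsed.hostname
        # Exact hostname comparison instead of substring matching, so a URL like
        # https://api.example.com.evil.net/ no longer slips through.
        if hostname not in ALLOWED_HOSTS:
            return False
        # Resolve the name and reject private, loopback, and link-local targets,
        # which blocks SSRF via DNS records that point at internal addresses.
        try:
            infos = socket.getaddrinfo(hostname, None)
        except socket.gaierror:
            return False
        for family, _, _, _, sockaddr in infos:
            ip = ipaddress.ip_address(sockaddr[0])
            if ip.is_private or ip.is_loopback or ip.is_link_local:
                return False
        return True

One caveat worth noting: resolved addresses can change between this check and the actual request, so hardened implementations typically pin the connection to the validated IP or re-validate at connect time.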
apps/resource-cluster/app/services/model_router.py (new file, 296 lines)
@@ -0,0 +1,296 @@
"""
GT 2.0 Model Router

Routes inference requests to appropriate providers based on model registry.
Integrates with provider factory for dynamic provider selection.
"""

import asyncio
import logging
from typing import Dict, Any, Optional, AsyncIterator
from datetime import datetime

from app.services.model_service import get_model_service
from app.providers import get_provider_factory
from app.core.backends import get_backend
from app.core.exceptions import ProviderError

logger = logging.getLogger(__name__)


class ModelRouter:
    """Routes model requests to appropriate providers"""

    def __init__(self, tenant_id: Optional[str] = None):
        self.tenant_id = tenant_id
        # Use default model service for shared model registry (config sync writes to default)
        # Note: Tenant isolation is handled via capability tokens, not separate databases
        self.model_service = get_model_service(None)
        self.provider_factory = None
        self.backend_cache = {}

    async def initialize(self):
        """Initialize model router"""
        try:
            self.provider_factory = await get_provider_factory()
            logger.info(f"Model router initialized for tenant: {self.tenant_id or 'default'}")
        except Exception as e:
            logger.error(f"Failed to initialize model router: {e}")
            raise

    async def route_inference(
        self,
        model_id: str,
        prompt: Optional[str] = None,
        messages: Optional[list] = None,
        temperature: float = 0.7,
        max_tokens: int = 4000,
        stream: bool = False,
        user_id: Optional[str] = None,
        tenant_id: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Route inference request to appropriate provider"""

        # Get model configuration from registry
        model_config = await self.model_service.get_model(model_id)
        if not model_config:
            raise ProviderError(f"Model {model_id} not found in registry")

        provider = model_config["provider"]

        # Track model usage
        start_time = datetime.utcnow()

        try:
            # Route to configured endpoint (generic routing for any provider)
            endpoint_url = model_config.get("endpoint")
            if not endpoint_url:
                raise ProviderError(f"No endpoint configured for model {model_id}")

            result = await self._route_to_generic_endpoint(
                endpoint_url, model_id, prompt, messages, temperature, max_tokens, stream, user_id, tenant_id, tools, tool_choice, **kwargs
            )

            # Calculate latency
            latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000

            # Track successful usage
            await self.model_service.track_model_usage(
                model_id, success=True, latency_ms=latency_ms
            )

            return result

        except Exception as e:
            # Track failed usage
            latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
            await self.model_service.track_model_usage(
                model_id, success=False, latency_ms=latency_ms
            )
            logger.error(f"Model routing failed for {model_id}: {e}")
            raise

    async def _route_to_groq(
        self,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        tools: Optional[list],
        tool_choice: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to Groq backend"""
        try:
            backend = get_backend("groq_proxy")
            if not backend:
                raise ProviderError("Groq backend not available")

            if messages:
                return await backend.execute_inference_with_messages(
                    messages=messages,
                    model=model_id,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=stream,
                    user_id=user_id,
                    tenant_id=tenant_id,
                    tools=tools,
                    tool_choice=tool_choice
                )
            else:
                return await backend.execute_inference(
                    prompt=prompt,
                    model=model_id,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=stream,
                    user_id=user_id,
                    tenant_id=tenant_id
                )

        except Exception as e:
            logger.error(f"Groq routing failed: {e}")
            raise ProviderError(f"Groq inference failed: {e}")

    async def _route_to_external(
        self,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to external provider"""
        try:
            if not self.provider_factory:
                await self.initialize()

            external_provider = self.provider_factory.get_provider("external")
            if not external_provider:
                raise ProviderError("External provider not available")

            # For embedding models
            if model_id == "bge-m3-embedding":
                # Convert prompt/messages to text list
                texts = []
                if messages:
                    texts = [msg.get("content", "") for msg in messages if msg.get("content")]
                elif prompt:
                    texts = [prompt]

                return await external_provider.generate_embeddings(
                    model_id=model_id,
                    texts=texts
                )
            else:
                raise ProviderError(f"External model {model_id} not supported for inference")

        except Exception as e:
            logger.error(f"External routing failed: {e}")
            raise ProviderError(f"External inference failed: {e}")

    async def _route_to_openai(
        self,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to OpenAI provider"""
        raise ProviderError("OpenAI provider not implemented - use Groq models instead")

    async def _route_to_generic_endpoint(
        self,
        endpoint_url: str,
        model_id: str,
        prompt: Optional[str],
        messages: Optional[list],
        temperature: float,
        max_tokens: int,
        stream: bool,
        user_id: Optional[str],
        tenant_id: Optional[str],
        tools: Optional[list],
        tool_choice: Optional[str],
        **kwargs
    ) -> Dict[str, Any]:
        """Route request to any configured endpoint using OpenAI-compatible API"""
        import httpx
        import time

        try:
            # Build OpenAI-compatible request
            request_data = {
                "model": model_id,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream
            }

            # Use messages if provided, otherwise convert prompt to messages
            if messages:
                request_data["messages"] = messages
            elif prompt:
                request_data["messages"] = [{"role": "user", "content": prompt}]
            else:
                raise ProviderError("Either messages or prompt must be provided")

            # Add tools if provided
            if tools:
                request_data["tools"] = tools
            if tool_choice:
                request_data["tool_choice"] = tool_choice

            # Add any additional parameters
            request_data.update(kwargs)

            logger.info(f"Routing request to endpoint: {endpoint_url}")
            logger.debug(f"Request data: {request_data}")

            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    endpoint_url,
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Endpoint {endpoint_url} returned {response.status_code}: {error_text}")
                    raise ProviderError(f"Endpoint error: {response.status_code} - {error_text}")

                result = response.json()
                logger.debug(f"Endpoint response: {result}")
                return result

        except httpx.RequestError as e:
            logger.error(f"Request to {endpoint_url} failed: {e}")
            raise ProviderError(f"Connection to endpoint failed: {str(e)}")
        except Exception as e:
            logger.error(f"Generic endpoint routing failed: {e}")
            raise ProviderError(f"Inference failed: {str(e)}")

    async def list_available_models(self) -> list:
        """List all available models from registry"""
        # Get all models (deployment status filtering available if needed)
        models = await self.model_service.list_models()
        return models

    async def get_model_health(self, model_id: str) -> Dict[str, Any]:
        """Check health of specific model"""
        return await self.model_service.check_model_health(model_id)


# Global model router instances per tenant
_model_routers = {}


async def get_model_router(tenant_id: Optional[str] = None) -> ModelRouter:
    """Get model router instance for tenant"""
    global _model_routers

    cache_key = tenant_id or "default"

    if cache_key not in _model_routers:
        router = ModelRouter(tenant_id)
        await router.initialize()
        _model_routers[cache_key] = router

    return _model_routers[cache_key]
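For reference, service code would obtain and use the router roughly like this. This is a minimal sketch based on the signatures above; the model id, tenant id, and message content are placeholders.

    # Hypothetical usage; ids and message content are placeholders.
    from app.services.model_router import get_model_router

    async def example():
        # Returns the cached per-tenant router, initializing it on first use.
        router = await get_model_router(tenant_id="tenant-a")
        result = await router.route_inference(
            model_id="example-chat-model",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.2,
            max_tokens=256,
        )
        return result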