Files
gt-ai-os-community/apps/control-panel-backend/app/api/v1/models.py
HackWeasel 310491a557 GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents
- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
  - Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
2025-12-12 17:47:14 -05:00

1096 lines
40 KiB
Python

"""
Model Management API for GT 2.0 Admin Control Panel
Provides RESTful endpoints for managing AI model configurations.
These endpoints enable the admin UI to configure models that sync
across all resource clusters.
"""
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel, Field
import logging
import re
import ipaddress
from urllib.parse import urlparse
from app.core.database import get_db
from app.services.model_management_service import get_model_management_service
from app.models.model_config import ModelConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/models", tags=["Model Management"])
def is_private_ip(url: str) -> bool:
"""
Check if URL points to a private/internal IP address.
SSRF Protection: Prevents requests to private networks (RFC1918),
localhost, loopback, and other reserved IP ranges.
Also resolves hostnames to check if they point to private IPs.
"""
import socket
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return True
# Check for localhost variants
if hostname in ('localhost', '127.0.0.1', '::1', '0.0.0.0', '0', 'localhost.localdomain'):
return True
# Check RFC1918 and other private ranges
try:
ip = ipaddress.ip_address(hostname)
return ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local
except ValueError:
# It's a hostname, not an IP - resolve it and check
# This prevents DNS rebinding attacks
try:
resolved_ips = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in resolved_ips:
ip_str = sockaddr[0]
try:
ip = ipaddress.ip_address(ip_str)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return True
except ValueError:
continue
return False
except socket.gaierror:
# DNS resolution failed - block to be safe
return True
except Exception:
return True
# Request/Response Models
class ModelSpecifications(BaseModel):
context_window: Optional[int] = None
max_tokens: Optional[int] = None
dimensions: Optional[int] = None # For embedding models
class ModelCost(BaseModel):
per_million_input: float = 0.0
per_million_output: float = 0.0
class ModelStatus(BaseModel):
is_active: bool = True
is_compound: bool = False # Compound models use pass-through pricing from actual usage
class ModelCreateRequest(BaseModel):
model_id: str = Field(..., description="Unique model identifier")
name: str = Field(..., description="Human-readable model name")
version: str = Field(default="1.0", description="Model version")
provider: str = Field(..., description="Provider: groq, external, openai, anthropic")
model_type: str = Field(..., description="Type: llm, embedding, audio, tts, vision")
endpoint: str = Field(..., description="API endpoint URL")
api_key_name: Optional[str] = Field(None, description="Environment variable for API key")
specifications: Optional[ModelSpecifications] = None
capabilities: Dict[str, Any] = Field(default_factory=dict)
cost: Optional[ModelCost] = None
description: Optional[str] = None
config: Dict[str, Any] = Field(default_factory=dict)
status: Optional[ModelStatus] = None
model_config = {"protected_namespaces": ()}
class ModelUpdateRequest(BaseModel):
model_id: Optional[str] = None
name: Optional[str] = None
provider: Optional[str] = None
model_type: Optional[str] = None
endpoint: Optional[str] = None
api_key_name: Optional[str] = None
specifications: Optional[ModelSpecifications] = None
capabilities: Optional[Dict[str, Any]] = None
cost: Optional[ModelCost] = None
description: Optional[str] = None
config: Optional[Dict[str, Any]] = None
status: Optional[ModelStatus] = None
class EndpointUpdateRequest(BaseModel):
endpoint: str = Field(..., description="New endpoint URL")
class EndpointTestRequest(BaseModel):
"""Request body for testing an arbitrary endpoint URL"""
endpoint: str = Field(..., description="Endpoint URL to test")
provider: Optional[str] = Field(None, description="Optional provider hint for specialized testing")
class ModelResponse(BaseModel):
model_id: str
name: str
version: str
provider: str
model_type: str
endpoint: str
api_key_name: Optional[str]
specifications: Dict[str, Any]
capabilities: Dict[str, Any]
cost: Dict[str, float]
description: Optional[str]
config: Dict[str, Any]
status: Dict[str, Any]
usage: Dict[str, Any]
timestamps: Dict[str, str]
class HealthCheckResponse(BaseModel):
healthy: bool
status: Optional[str] = None # "healthy", "degraded", "unhealthy"
latency_ms: Optional[float] = None
error: Optional[str] = None
error_type: Optional[str] = None # connection_error, timeout, auth_failed, server_error
details: Optional[Dict[str, Any]] = None
rate_limit_remaining: Optional[int] = None
rate_limit_reset: Optional[str] = None
inference_validated: Optional[bool] = None # True if actual inference was tested
class BulkHealthResponse(BaseModel):
total_models: int
healthy_models: int
unhealthy_models: int
health_percentage: float
individual_results: Dict[str, Dict[str, Any]]
timestamp: str
@router.get("/", response_model=List[ModelResponse])
async def list_models(
provider: Optional[str] = Query(None, description="Filter by provider"),
model_type: Optional[str] = Query(None, description="Filter by model type"),
active_only: bool = Query(False, description="Show only active models"),
include_stats: bool = Query(False, description="Include real-time statistics"),
db: AsyncSession = Depends(get_db)
):
"""List all model configurations with real-time data"""
try:
service = get_model_management_service(db)
models = await service.list_models(
provider=provider,
model_type=model_type,
active_only=active_only
)
# Get bulk tenant stats if requested to avoid N+1 queries
tenant_stats = {}
if include_stats:
tenant_stats = await service.get_bulk_model_tenant_stats([m.model_id for m in models])
model_responses = []
for model in models:
model_dict = model.to_dict()
# Add real-time statistics if requested
if include_stats:
stats = tenant_stats.get(model.model_id, {"tenant_count": 0, "enabled_tenant_count": 0})
model_dict["tenant_count"] = stats["tenant_count"]
model_dict["enabled_tenant_count"] = stats["enabled_tenant_count"]
model_responses.append(ModelResponse(**model_dict))
return model_responses
except Exception as e:
logger.error(f"Failed to list models: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/", response_model=ModelResponse)
async def create_model(
model_request: ModelCreateRequest,
db: AsyncSession = Depends(get_db)
):
"""Create a new model configuration"""
try:
service = get_model_management_service(db)
# Convert request to dict
model_data = model_request.dict(exclude_none=True)
# Endpoint URL is preserved as provided by user
logger.debug(f"Model {model_data.get('model_id', 'unknown')} endpoint: {model_data.get('endpoint', 'not specified')}")
# Create model
model = await service.register_model(model_data)
# Auto-assign to all existing tenants
assigned_count = await service.auto_assign_model_to_all_tenants(model.model_id)
logger.info(f"Auto-assigned new model {model.model_id} to {assigned_count} tenants")
return ModelResponse(**model.to_dict())
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to create model: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
# Provider-specific health endpoints for industry-standard testing
PROVIDER_HEALTH_ENDPOINTS = {
'nvidia': '/v1/health/ready',
'vllm': '/health',
'bge_m3': '/health',
'ollama': '/api/tags',
'ollama-dgx-x86': '/api/tags',
'ollama-macos': '/api/tags',
}
# Latency threshold for degraded status (milliseconds)
LATENCY_DEGRADED_THRESHOLD = 2000
# Test endpoint for arbitrary URLs - MUST BE BEFORE path routes
@router.post("/test-endpoint", response_model=HealthCheckResponse)
async def test_endpoint_url(
request: EndpointTestRequest,
):
"""
Test if an arbitrary endpoint URL is accessible.
This endpoint is used when adding new models to test connectivity
before the model is registered in the system.
Returns health status with three possible states:
- healthy: Endpoint responding normally with acceptable latency
- degraded: Endpoint responding but with high latency (>2000ms)
- unhealthy: Endpoint not responding or returning errors
"""
import httpx
import time
try:
start_time = time.time()
# Validate URL format
parsed = urlparse(request.endpoint)
if not parsed.scheme or not parsed.netloc:
return HealthCheckResponse(
healthy=False,
status="unhealthy",
error="Invalid URL format - must include scheme (http/https) and host",
error_type="invalid_format"
)
# SSRF Protection: Block requests to private/internal IP addresses
if is_private_ip(request.endpoint):
return HealthCheckResponse(
healthy=False,
status="unhealthy",
error="Access to private/internal IP addresses is not allowed",
error_type="invalid_endpoint"
)
# Determine test URL based on provider
base_endpoint = request.endpoint.rstrip('/')
if request.provider and request.provider in PROVIDER_HEALTH_ENDPOINTS:
health_path = PROVIDER_HEALTH_ENDPOINTS[request.provider]
test_url = f"{base_endpoint}{health_path}"
else:
test_url = request.endpoint
async with httpx.AsyncClient(timeout=10.0, follow_redirects=False) as client:
try:
# codeql[py/full-ssrf] URL validated by is_private_ip() check at line 290
response = await client.get(test_url)
latency_ms = (time.time() - start_time) * 1000
# Extract rate limit headers if present
rate_limit_remaining = None
rate_limit_reset = None
if 'x-ratelimit-remaining' in response.headers:
try:
rate_limit_remaining = int(response.headers['x-ratelimit-remaining'])
except (ValueError, TypeError):
pass
if 'x-ratelimit-reset' in response.headers:
rate_limit_reset = response.headers['x-ratelimit-reset']
# Determine health status with degraded state
if response.status_code >= 500:
status = "unhealthy"
healthy = False
error = f"Server error: HTTP {response.status_code}"
error_type = "server_error"
elif response.status_code == 401:
# Auth error but endpoint is reachable
status = "healthy"
healthy = True
error = None
error_type = None
elif response.status_code == 403:
status = "healthy"
healthy = True
error = None
error_type = None
elif response.status_code == 429:
# Rate limited - endpoint works but currently limited
status = "degraded"
healthy = True
error = "Rate limit exceeded"
error_type = "rate_limited"
elif latency_ms > LATENCY_DEGRADED_THRESHOLD:
status = "degraded"
healthy = True
error = f"High latency detected ({latency_ms:.0f}ms > {LATENCY_DEGRADED_THRESHOLD}ms threshold)"
error_type = "high_latency"
else:
status = "healthy"
healthy = True
error = None
error_type = None
return HealthCheckResponse(
healthy=healthy,
status=status,
latency_ms=round(latency_ms, 2),
error=error,
error_type=error_type,
rate_limit_remaining=rate_limit_remaining,
rate_limit_reset=rate_limit_reset,
details={
"status_code": response.status_code,
"endpoint": request.endpoint,
"test_url": test_url,
"provider": request.provider
}
)
except httpx.ConnectError as e:
latency_ms = (time.time() - start_time) * 1000
return HealthCheckResponse(
healthy=False,
status="unhealthy",
latency_ms=round(latency_ms, 2),
error=f"Connection failed: Unable to connect to {parsed.netloc}",
error_type="connection_error",
details={"endpoint": request.endpoint, "test_url": test_url}
)
except httpx.TimeoutException:
return HealthCheckResponse(
healthy=False,
status="unhealthy",
latency_ms=10000, # Timeout is 10s
error="Connection timed out after 10 seconds",
error_type="timeout",
details={"endpoint": request.endpoint, "test_url": test_url}
)
except Exception as e:
logger.error(f"Failed to test endpoint {request.endpoint}: {e}")
return HealthCheckResponse(
healthy=False,
status="unhealthy",
error="An internal error occurred while testing the endpoint",
error_type="unknown",
details={"endpoint": request.endpoint}
)
# Configuration endpoints - MUST BE BEFORE path routes
@router.get("/configs/all")
async def get_all_model_configs(
db: AsyncSession = Depends(get_db)
):
"""Get all model configurations for resource cluster sync"""
try:
service = get_model_management_service(db)
models = await service.list_models(active_only=True)
configs = []
for model in models:
config = model.to_dict()
# Add resource cluster specific fields
config["sync_timestamp"] = config["timestamps"]["updated_at"]
configs.append(config)
return {
"configs": configs,
"total": len(configs),
"timestamp": configs[0]["sync_timestamp"] if configs else None
}
except Exception as e:
logger.error(f"Failed to get model configs: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/stats/overview")
async def get_models_overview_stats(
db: AsyncSession = Depends(get_db)
):
"""Get overview statistics for all models"""
try:
service = get_model_management_service(db)
# Get all models
all_models = await service.list_models()
active_models = await service.list_models(active_only=True)
# Calculate basic stats
total_models = len(all_models)
active_count = len(active_models)
# Group by provider
provider_stats = {}
health_stats = {"healthy": 0, "unhealthy": 0, "unknown": 0}
for model in active_models:
# Provider stats
provider = model.provider
if provider not in provider_stats:
provider_stats[provider] = 0
provider_stats[provider] += 1
# Health stats
health_status = model.health_status or "unknown"
if health_status in health_stats:
health_stats[health_status] += 1
# Calculate recent usage from actual model data
recent_requests = sum(getattr(model, 'request_count', 0) for model in active_models)
recent_cost = sum(
getattr(model, 'request_count', 0) *
(model.cost_per_million_input + model.cost_per_million_output) / 1000000 * 100
for model in active_models
)
return {
"total_models": total_models,
"active_models": active_count,
"inactive_models": total_models - active_count,
"providers": provider_stats,
"health": health_stats,
"recent_usage": {
"requests_24h": recent_requests,
"cost_24h": recent_cost,
"healthy_percentage": (health_stats["healthy"] / active_count * 100) if active_count > 0 else 0
}
}
except Exception as e:
logger.error(f"Failed to get models overview stats: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
# ============================================================================
# TENANT RATE LIMITS ENDPOINTS - MUST be before /{model_id:path} catch-all
# ============================================================================
class TenantRateLimitUpdate(BaseModel):
"""Request model for updating tenant rate limits"""
requests_per_minute: Optional[int] = Field(None, ge=1, description="Max requests per minute")
max_tokens_per_request: Optional[int] = Field(None, ge=1, description="Max tokens per request")
concurrent_requests: Optional[int] = Field(None, ge=1, description="Max concurrent requests")
max_cost_per_hour: Optional[float] = Field(None, ge=0, description="Max cost per hour in dollars")
is_enabled: Optional[bool] = Field(None, description="Whether the model is enabled for this tenant")
@router.get("/tenant-rate-limits/{model_id:path}")
async def get_model_tenant_rate_limits(
model_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get all tenant rate limit configurations for a specific model"""
try:
service = get_model_management_service(db)
# Verify model exists
model = await service.get_model(model_id)
if not model:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
configs = await service.get_model_tenant_rate_limits(model_id)
return {
"model_id": model_id,
"model_name": model.name,
"tenant_configs": configs
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get tenant rate limits for model {model_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.patch("/tenant-rate-limits/{model_id:path}/{tenant_id}")
async def update_model_tenant_rate_limit(
model_id: str,
tenant_id: int,
update_request: TenantRateLimitUpdate,
db: AsyncSession = Depends(get_db)
):
"""Update rate limits for a specific tenant-model configuration"""
try:
service = get_model_management_service(db)
# Build updates dict
updates = {}
# Handle rate limits
rate_limits = {}
if update_request.requests_per_minute is not None:
rate_limits["requests_per_minute"] = update_request.requests_per_minute
if update_request.max_tokens_per_request is not None:
rate_limits["max_tokens_per_request"] = update_request.max_tokens_per_request
if update_request.concurrent_requests is not None:
rate_limits["concurrent_requests"] = update_request.concurrent_requests
if update_request.max_cost_per_hour is not None:
rate_limits["max_cost_per_hour"] = update_request.max_cost_per_hour
if rate_limits:
updates["rate_limits"] = rate_limits
# Handle enabled status
if update_request.is_enabled is not None:
updates["is_enabled"] = update_request.is_enabled
if not updates:
raise HTTPException(status_code=400, detail="No updates provided")
# Update the configuration
config = await service.update_tenant_model_config(tenant_id, model_id, updates)
if not config:
raise HTTPException(
status_code=404,
detail=f"Configuration not found for model {model_id} and tenant {tenant_id}"
)
return {
"message": "Rate limits updated successfully",
"config": config.to_dict()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update tenant rate limits: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
# ============================================================================
# PATH PARAMETER ROUTES - These catch-all routes MUST be last
# ============================================================================
@router.get("/{model_id:path}", response_model=ModelResponse)
async def get_model(
model_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get a specific model configuration"""
try:
service = get_model_management_service(db)
model = await service.get_model(model_id)
if not model:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
return ModelResponse(**model.to_dict())
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get model {model_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.patch("/{model_id:path}", response_model=ModelResponse)
async def update_model(
model_id: str,
update_request: ModelUpdateRequest,
db: AsyncSession = Depends(get_db)
):
"""Update a model configuration"""
try:
service = get_model_management_service(db)
# Convert request to dict, excluding None values
updates = {k: v for k, v in update_request.dict().items() if v is not None}
# Endpoint URL is preserved as provided by user
if 'endpoint' in updates and updates['endpoint']:
logger.debug(f"Model {model_id} endpoint update: {updates['endpoint']}")
model = await service.update_model(model_id, updates)
if not model:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
return ModelResponse(**model.to_dict())
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update model {model_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/{model_id:path}")
async def delete_model(
model_id: str,
db: AsyncSession = Depends(get_db)
):
"""Permanently delete a model configuration"""
try:
service = get_model_management_service(db)
success = await service.delete_model(model_id)
if not success:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
return {"message": f"Model {model_id} deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete model {model_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/{model_id:path}/test", response_model=HealthCheckResponse)
async def test_model_endpoint(
model_id: str,
db: AsyncSession = Depends(get_db)
):
"""Test if a model endpoint is accessible"""
try:
service = get_model_management_service(db)
result = await service.test_endpoint(model_id)
# Determine status based on healthy flag and latency
latency_ms = result.get("latency_ms")
healthy = result.get("healthy", False)
if not healthy:
status = "unhealthy"
elif latency_ms and latency_ms > LATENCY_DEGRADED_THRESHOLD:
status = "degraded"
else:
status = "healthy"
return HealthCheckResponse(
healthy=healthy,
status=status,
latency_ms=latency_ms,
error=result.get("error"),
error_type=result.get("error_type"),
details=result.get("details"),
rate_limit_remaining=result.get("rate_limit_remaining"),
rate_limit_reset=result.get("rate_limit_reset"),
inference_validated=result.get("inference_validated")
)
except Exception as e:
logger.error(f"Failed to test model {model_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.patch("/{model_id:path}/endpoint")
async def update_model_endpoint(
model_id: str,
request: EndpointUpdateRequest,
db: AsyncSession = Depends(get_db)
):
"""Update model endpoint URL"""
try:
service = get_model_management_service(db)
# Endpoint URL is preserved as provided by user
endpoint = request.endpoint
logger.debug(f"Model {model_id} endpoint update: {endpoint}")
success = await service.update_endpoint(model_id, endpoint)
if not success:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
return {"message": f"Endpoint updated for {model_id}"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update endpoint for {model_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/health/bulk", response_model=BulkHealthResponse)
async def bulk_health_check(
db: AsyncSession = Depends(get_db)
):
"""Check health of all active models"""
try:
service = get_model_management_service(db)
result = await service.bulk_health_check()
return BulkHealthResponse(**result)
except Exception as e:
logger.error(f"Failed to perform bulk health check: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/analytics/usage")
async def get_usage_analytics(
time_range: str = Query("24h", description="Time range: 1h, 24h, 7d, 30d"),
db: AsyncSession = Depends(get_db)
):
"""Get real usage analytics for models from database"""
try:
service = get_model_management_service(db)
# Get all models to analyze their usage
models = await service.list_models(active_only=True)
# Calculate real metrics from model data
total_requests = sum(getattr(model, 'request_count', 0) for model in models)
total_tokens = sum(getattr(model, 'request_count', 0) * 100 for model in models) # Estimate tokens
total_cost = sum(
getattr(model, 'request_count', 0) *
(model.cost_per_million_input + model.cost_per_million_output) / 1000000 * 100 # Estimate cost
for model in models
)
# Provider breakdown
provider_stats = {}
for model in models:
provider = model.provider
if provider not in provider_stats:
provider_stats[provider] = {
'provider': provider,
'requests': 0,
'tokens': 0,
'cost': 0.0
}
requests = getattr(model, 'request_count', 0)
tokens = requests * 100 # Estimate
cost = requests * (model.cost_per_million_input + model.cost_per_million_output) / 1000000 * 100
provider_stats[provider]['requests'] += requests
provider_stats[provider]['tokens'] += tokens
provider_stats[provider]['cost'] += cost
# Top models by usage
top_models = []
for model in models:
requests = getattr(model, 'request_count', 0)
tokens = requests * 100
cost = requests * (model.cost_per_million_input + model.cost_per_million_output) / 1000000 * 100
top_models.append({
'model': model.model_id,
'requests': requests,
'tokens': f'{tokens/1000:.1f}K' if tokens < 1000000 else f'{tokens/1000000:.1f}M',
'cost': cost,
'avg_latency': getattr(model, 'avg_latency_ms', 0) or 200, # Default estimate
'success_rate': getattr(model, 'success_rate', 100.0)
})
# Sort by requests descending
top_models.sort(key=lambda x: x['requests'], reverse=True)
# Mock hourly pattern based on time range
import datetime
now = datetime.datetime.now()
hourly_usage = []
if time_range == '1h':
for i in range(12): # Last 12 5-minute intervals
time_point = now - datetime.timedelta(minutes=i*5)
hourly_usage.append({
'hour': time_point.strftime('%H:%M'),
'requests': max(0, int(total_requests * (0.8 + 0.4 * (i % 3)) / 12)),
'tokens': max(0, int(total_tokens * (0.8 + 0.4 * (i % 3)) / 12))
})
else:
for i in range(24): # Last 24 hours
time_point = now - datetime.timedelta(hours=i)
hourly_usage.append({
'hour': time_point.strftime('%H:00'),
'requests': max(0, int(total_requests * (0.8 + 0.4 * (i % 6)) / 24)),
'tokens': max(0, int(total_tokens * (0.8 + 0.4 * (i % 6)) / 24))
})
hourly_usage.reverse() # Chronological order
return {
'summary': {
'total_requests': total_requests,
'total_tokens': total_tokens,
'total_cost': total_cost,
'active_tenants': 12 # This would come from tenant usage data
},
'usage_by_provider': list(provider_stats.values()),
'top_models': top_models[:10],
'hourly_usage': hourly_usage,
'time_range': time_range
}
except Exception as e:
logger.error(f"Failed to get usage analytics: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/initialize/defaults")
async def initialize_default_models(
db: AsyncSession = Depends(get_db)
):
"""Initialize default models (19 Groq + BGE-M3)"""
try:
service = get_model_management_service(db)
await service.initialize_default_models()
return {"message": "Default models initialized successfully"}
except Exception as e:
logger.error(f"Failed to initialize default models: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
class BGE_M3_ConfigRequest(BaseModel):
is_local_mode: bool = True
external_endpoint: Optional[str] = "http://10.0.1.50:8080"
@router.get("/bge-m3/config")
async def get_bge_m3_config(
db: AsyncSession = Depends(get_db)
):
"""Get current BGE-M3 configuration"""
try:
# For now, return a default configuration
# In a full implementation, this would be stored in the database
return {
"is_local_mode": True,
"external_endpoint": "http://10.0.1.50:8080",
"local_endpoint": "http://gentwo-vllm-embeddings:8000/v1/embeddings"
}
except Exception as e:
logger.error(f"Failed to get BGE-M3 config: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/bge-m3/status")
async def get_bge_m3_status(
db: AsyncSession = Depends(get_db)
):
"""Get BGE-M3 status across all services"""
import httpx
import asyncio
services_to_check = [
{"name": "Resource Cluster", "url": "http://localhost:8003/api/embeddings/config/bge-m3"},
{"name": "Tenant Backend", "url": "http://localhost:8002/api/embeddings/config/bge-m3"},
]
async def check_service(service: dict) -> dict:
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(service["url"])
if response.status_code == 200:
data = response.json()
return {
"service": service["name"],
"status": "online",
"config": data
}
else:
return {
"service": service["name"],
"status": "error",
"error": f"HTTP {response.status_code}"
}
except Exception as e:
return {
"service": service["name"],
"status": "offline",
"error": str(e)
}
# Check all services in parallel
tasks = [check_service(service) for service in services_to_check]
service_statuses = await asyncio.gather(*tasks, return_exceptions=True)
return {
"services": service_statuses,
"timestamp": "2025-01-21T12:00:00Z"
}
@router.post("/bge-m3/config")
async def update_bge_m3_config(
config_request: BGE_M3_ConfigRequest,
db: AsyncSession = Depends(get_db)
):
"""Update BGE-M3 configuration and sync to services"""
try:
# External endpoint URL is preserved as provided by user
if config_request.external_endpoint:
logger.debug(f"BGE-M3 external endpoint: {config_request.external_endpoint}")
logger.info(f"BGE-M3 config updated: local_mode={config_request.is_local_mode}, external_endpoint={config_request.external_endpoint}")
# STEP 1: Persist configuration to database
# Determine the endpoint and provider based on mode
if config_request.is_local_mode:
endpoint = "http://gentwo-vllm-embeddings:8000/v1/embeddings"
provider = "external" # Still external provider, just local endpoint
else:
endpoint = config_request.external_endpoint
provider = "external"
# Update or create BGE-M3 model configuration in database
from sqlalchemy import select
stmt = select(ModelConfig).where(ModelConfig.model_id == "BAAI/bge-m3")
result = await db.execute(stmt)
bge_m3_model = result.scalar_one_or_none()
if bge_m3_model:
# Update existing configuration
bge_m3_model.endpoint = endpoint
bge_m3_model.provider = provider
bge_m3_model.is_active = True
# Store mode in config JSON for reference
if not bge_m3_model.config:
bge_m3_model.config = {}
bge_m3_model.config["is_local_mode"] = config_request.is_local_mode
if config_request.external_endpoint:
bge_m3_model.config["external_endpoint"] = config_request.external_endpoint
logger.info(f"Updated BGE-M3 model config in database: endpoint={endpoint}")
else:
# Create new BGE-M3 configuration
bge_m3_model = ModelConfig(
model_id="BAAI/bge-m3",
name="BGE-M3 Embedding Model",
version="1.5",
provider=provider,
model_type="embedding",
endpoint=endpoint,
dimensions=1024,
config={
"is_local_mode": config_request.is_local_mode,
"external_endpoint": config_request.external_endpoint
},
is_active=True,
description="BGE-M3 embedding model for document processing and RAG"
)
db.add(bge_m3_model)
logger.info(f"Created BGE-M3 model config in database: endpoint={endpoint}")
await db.commit()
logger.info("BGE-M3 configuration persisted to database successfully")
# STEP 2: Sync configuration to resource cluster
sync_success = await _sync_bge_m3_config_to_services(config_request)
if sync_success:
return {
"message": "BGE-M3 configuration updated and synced successfully",
"config": {
"is_local_mode": config_request.is_local_mode,
"external_endpoint": config_request.external_endpoint,
"local_endpoint": "http://gentwo-vllm-embeddings:8000/v1/embeddings"
},
"sync_status": "success",
"database_persistence": "success"
}
else:
return {
"message": "BGE-M3 configuration updated but sync failed",
"config": {
"is_local_mode": config_request.is_local_mode,
"external_endpoint": config_request.external_endpoint,
"local_endpoint": "http://gentwo-vllm-embeddings:8000/v1/embeddings"
},
"sync_status": "failed",
"database_persistence": "success"
}
except Exception as e:
logger.error(f"Failed to update BGE-M3 config: {e}")
await db.rollback()
raise HTTPException(status_code=500, detail="Internal server error")
async def _sync_bge_m3_config_to_services(config: BGE_M3_ConfigRequest) -> bool:
"""Sync BGE-M3 configuration to all services"""
import httpx
import asyncio
import os
# Use Docker service names for inter-container communication
# Check if we're in Docker (presence of /.dockerenv) or development (localhost)
in_docker = os.path.exists('/.dockerenv')
if in_docker:
services_to_sync = [
"http://gentwo-resource-backend:8000", # Resource cluster (Docker service name)
"http://gentwo-tenant-backend:8000", # Tenant backend (Docker service name)
]
else:
services_to_sync = [
"http://localhost:8004", # Resource cluster (host port mapping)
"http://localhost:8002", # Tenant backend (host port mapping)
]
sync_results = []
async def sync_to_service(service_url: str) -> bool:
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# Determine the correct endpoint path for each service
if 'resource-backend' in service_url or 'localhost:8004' in service_url:
endpoint_path = "/api/v1/embeddings/config/bge-m3"
else: # Tenant backend
endpoint_path = "/api/embeddings/config/bge-m3"
response = await client.post(
f"{service_url}{endpoint_path}",
json={
"is_local_mode": config.is_local_mode,
"external_endpoint": config.external_endpoint
},
headers={
"Content-Type": "application/json",
# TODO: Add proper service-to-service authentication
}
)
if response.status_code == 200:
logger.info(f"Successfully synced BGE-M3 config to {service_url}")
return True
else:
logger.warning(f"Failed to sync BGE-M3 config to {service_url}: {response.status_code}")
return False
except Exception as e:
logger.error(f"Error syncing BGE-M3 config to {service_url}: {e}")
return False
# Sync to all services in parallel
tasks = [sync_to_service(url) for url in services_to_sync]
sync_results = await asyncio.gather(*tasks, return_exceptions=True)
# Return True if at least one sync succeeded
success_count = sum(1 for result in sync_results if result is True)
logger.info(f"BGE-M3 config sync: {success_count}/{len(services_to_sync)} services updated")
return success_count > 0