Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching); see the sketch below
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
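The hostname-validation and SSRF items above can be summarized with a minimal sketch. This is illustrative only and not the code shipped in this change; is_safe_url and ALLOWED_HOSTS are hypothetical names, and a real deployment would load its allowlist from configuration:

# Illustrative sketch only, not part of this release's code.
import ipaddress
import socket
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.groq.com", "api.openai.com"}  # hypothetical allowlist

def is_safe_url(url: str) -> bool:
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return False
    hostname = parsed.hostname or ""
    # Exact hostname comparison instead of substring matching, which a crafted
    # URL such as "https://api.groq.com.evil.example" would otherwise pass.
    if hostname not in ALLOWED_HOSTS:
        return False
    # SSRF protection: resolve the hostname and reject private, loopback,
    # link-local, and reserved addresses before any HTTP request is made.
    try:
        infos = socket.getaddrinfo(hostname, None)
    except socket.gaierror:
        return False
    for info in infos:
        ip = ipaddress.ip_address(info[4][0])
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            return False
    return True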
720 lines · 28 KiB · Python
"""
|
|
GT 2.0 Model Management Service - Stateless Version
|
|
|
|
Provides centralized model registry, versioning, deployment, and lifecycle management
|
|
for all AI models across the Resource Cluster using in-memory storage.
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
import asyncio
|
|
from typing import Dict, Any, List, Optional, Union
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
import hashlib
|
|
import httpx
|
|
import logging
|
|
|
|
from app.core.config import get_settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
settings = get_settings()
|
|
|
|
|
|
class ModelService:
    """Stateless model management service with in-memory registry"""

    def __init__(self, tenant_id: Optional[str] = None):
        self.tenant_id = tenant_id
        self.settings = get_settings(tenant_id)

        # In-memory model registry for stateless operation
        self.model_registry: Dict[str, Dict[str, Any]] = {}
        self.last_cache_update = 0
        self.cache_ttl = 300  # 5 minutes

        # Performance tracking (in-memory)
        self.performance_metrics: Dict[str, Dict[str, Any]] = {}

        # Initialize with default models synchronously
        self._initialize_default_models_sync()

    async def register_model(
        self,
        model_id: str,
        name: str,
        version: str,
        provider: str,
        model_type: str,
        description: str = "",
        capabilities: Dict[str, Any] = None,
        parameters: Dict[str, Any] = None,
        endpoint_url: str = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Register a new model in the in-memory registry"""

        now = datetime.utcnow()

        # Create or update model entry
        model_entry = {
            "id": model_id,
            "name": name,
            "version": version,
            "provider": provider,
            "model_type": model_type,
            "description": description,
            "capabilities": capabilities or {},
            "parameters": parameters or {},

            # Performance metrics
            "max_tokens": kwargs.get("max_tokens", 4000),
            "context_window": kwargs.get("context_window", 4000),
            "cost_per_1k_tokens": kwargs.get("cost_per_1k_tokens", 0.0),
            "latency_p50_ms": kwargs.get("latency_p50_ms", 0.0),
            "latency_p95_ms": kwargs.get("latency_p95_ms", 0.0),

            # Deployment status
            "deployment_status": kwargs.get("deployment_status", "available"),
            "health_status": kwargs.get("health_status", "unknown"),
            "last_health_check": kwargs.get("last_health_check"),

            # Usage tracking
            "request_count": kwargs.get("request_count", 0),
            "error_count": kwargs.get("error_count", 0),
            "success_rate": kwargs.get("success_rate", 1.0),

            # Lifecycle
            "created_at": now.isoformat(),
            "updated_at": now.isoformat(),
            "retired_at": kwargs.get("retired_at"),

            # Configuration
            "endpoint_url": endpoint_url,
            "api_key_required": kwargs.get("api_key_required", True),
            "rate_limits": kwargs.get("rate_limits", {})
        }

        self.model_registry[model_id] = model_entry

        logger.info(f"Registered model: {model_id} ({name} v{version})")
        return model_entry

    async def get_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model by ID"""
        return self.model_registry.get(model_id)

    async def list_models(
        self,
        provider: str = None,
        model_type: str = None,
        deployment_status: str = None,
        health_status: str = None
    ) -> List[Dict[str, Any]]:
        """List models with optional filters"""

        models = list(self.model_registry.values())

        # Apply filters
        if provider:
            models = [m for m in models if m["provider"] == provider]
        if model_type:
            models = [m for m in models if m["model_type"] == model_type]
        if deployment_status:
            models = [m for m in models if m["deployment_status"] == deployment_status]
        if health_status:
            models = [m for m in models if m["health_status"] == health_status]

        # Sort by created_at desc
        models.sort(key=lambda x: x["created_at"], reverse=True)
        return models

    async def update_model_status(
        self,
        model_id: str,
        deployment_status: str = None,
        health_status: str = None
    ) -> bool:
        """Update model deployment and health status"""

        model = self.model_registry.get(model_id)
        if not model:
            return False

        if deployment_status:
            model["deployment_status"] = deployment_status
        if health_status:
            model["health_status"] = health_status
            model["last_health_check"] = datetime.utcnow().isoformat()

        model["updated_at"] = datetime.utcnow().isoformat()

        return True

    async def track_model_usage(
        self,
        model_id: str,
        success: bool = True,
        latency_ms: float = None
    ):
        """Track model usage and performance metrics"""

        model = self.model_registry.get(model_id)
        if not model:
            return

        # Update usage counters
        model["request_count"] += 1
        if not success:
            model["error_count"] += 1

        # Calculate success rate
        model["success_rate"] = (model["request_count"] - model["error_count"]) / model["request_count"]

        # Update latency metrics (simple running average)
        if latency_ms is not None:
            if model["latency_p50_ms"] == 0:
                model["latency_p50_ms"] = latency_ms
            else:
                # Simple exponential moving average
                alpha = 0.1
                model["latency_p50_ms"] = alpha * latency_ms + (1 - alpha) * model["latency_p50_ms"]

            # P95 approximation (conservative estimate)
            model["latency_p95_ms"] = max(model["latency_p95_ms"], latency_ms * 1.5)

        model["updated_at"] = datetime.utcnow().isoformat()

    async def retire_model(self, model_id: str, reason: str = "") -> bool:
        """Retire a model (mark as no longer available)"""

        model = self.model_registry.get(model_id)
        if not model:
            return False

        model["deployment_status"] = "retired"
        model["retired_at"] = datetime.utcnow().isoformat()
        model["updated_at"] = datetime.utcnow().isoformat()

        if reason:
            model["description"] += f"\n\nRetired: {reason}"

        logger.info(f"Retired model: {model_id} - {reason}")
        return True

    async def check_model_health(self, model_id: str) -> Dict[str, Any]:
        """Check health of a specific model"""

        model = await self.get_model(model_id)
        if not model:
            return {"healthy": False, "error": "Model not found"}

        # Generic health check for any provider with a configured endpoint
        # (the registry stores the endpoint under "endpoint_url")
        if model.get("endpoint_url"):
            return await self._check_generic_model_health(model)
        elif model["provider"] == "groq":
            return await self._check_groq_model_health(model)
        elif model["provider"] == "openai":
            return await self._check_openai_model_health(model)
        elif model["provider"] == "local":
            return await self._check_local_model_health(model)
        else:
            return {"healthy": False, "error": f"No health check method for provider: {model['provider']}"}

    async def _check_groq_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Health check for Groq models"""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.groq.com/openai/v1/models",
                    headers={"Authorization": f"Bearer {settings.groq_api_key}"},
                    timeout=10.0
                )

                if response.status_code == 200:
                    models = response.json()
                    model_ids = [m["id"] for m in models.get("data", [])]
                    is_available = model["id"] in model_ids

                    await self.update_model_status(
                        model["id"],
                        health_status="healthy" if is_available else "unhealthy"
                    )

                    return {
                        "healthy": is_available,
                        "latency_ms": response.elapsed.total_seconds() * 1000,
                        "available_models": len(model_ids)
                    }
                else:
                    await self.update_model_status(model["id"], health_status="unhealthy")
                    return {"healthy": False, "error": f"API error: {response.status_code}"}

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": str(e)}

    async def _check_openai_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Health check for OpenAI models"""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models",
                    headers={"Authorization": f"Bearer {settings.openai_api_key}"},
                    timeout=10.0
                )

                if response.status_code == 200:
                    models = response.json()
                    model_ids = [m["id"] for m in models.get("data", [])]
                    is_available = model["id"] in model_ids

                    await self.update_model_status(
                        model["id"],
                        health_status="healthy" if is_available else "unhealthy"
                    )

                    return {
                        "healthy": is_available,
                        "latency_ms": response.elapsed.total_seconds() * 1000
                    }
                else:
                    await self.update_model_status(model["id"], health_status="unhealthy")
                    return {"healthy": False, "error": f"API error: {response.status_code}"}

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": str(e)}

    async def _check_generic_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Generic health check for any provider with a configured endpoint"""
        try:
            endpoint_url = model.get("endpoint_url")
            if not endpoint_url:
                return {"healthy": False, "error": "No endpoint URL configured"}

            # Try a simple health check by making a minimal request
            async with httpx.AsyncClient(timeout=10.0) as client:
                # For OpenAI-compatible endpoints, try a models list request
                try:
                    # Try the /v1/models endpoint first (common for OpenAI-compatible APIs);
                    # replace the most specific suffix first so "/v1/chat/completions" maps correctly.
                    models_url = endpoint_url.replace("/v1/chat/completions", "/v1/models").replace("/chat/completions", "/models")
                    response = await client.get(models_url)

                    if response.status_code == 200:
                        await self.update_model_status(model["id"], health_status="healthy")
                        return {
                            "healthy": True,
                            "provider": model.get("provider", "unknown"),
                            "latency_ms": 0,  # Could measure actual latency
                            "last_check": datetime.utcnow().isoformat(),
                            "details": "Endpoint responding to models request"
                        }
                except Exception:
                    # Endpoint did not respond to the models probe; fall through to the next check
                    pass

                # If the models endpoint doesn't work, try a basic health endpoint
                try:
                    health_url = endpoint_url.replace("/v1/chat/completions", "/health").replace("/chat/completions", "/health")
                    response = await client.get(health_url)

                    if response.status_code == 200:
                        await self.update_model_status(model["id"], health_status="healthy")
                        return {
                            "healthy": True,
                            "provider": model.get("provider", "unknown"),
                            "latency_ms": 0,
                            "last_check": datetime.utcnow().isoformat(),
                            "details": "Endpoint responding to health check"
                        }
                except Exception:
                    # Health endpoint unavailable; fall back to the generic result below
                    pass

                # If neither works, assume healthy if the endpoint is reachable at all
                await self.update_model_status(model["id"], health_status="unknown")
                return {
                    "healthy": True,  # Assume healthy for generic endpoints
                    "provider": model.get("provider", "unknown"),
                    "latency_ms": 0,
                    "last_check": datetime.utcnow().isoformat(),
                    "details": "Generic endpoint - health check not available"
                }

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": f"Health check failed: {str(e)}"}

    async def _check_local_model_health(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """Health check for local models"""
        try:
            endpoint_url = model.get("endpoint_url")
            if not endpoint_url:
                return {"healthy": False, "error": "No endpoint URL configured"}

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{endpoint_url}/health",
                    timeout=5.0
                )

                healthy = response.status_code == 200
                await self.update_model_status(
                    model["id"],
                    health_status="healthy" if healthy else "unhealthy"
                )

                return {
                    "healthy": healthy,
                    "latency_ms": response.elapsed.total_seconds() * 1000
                }

        except Exception as e:
            await self.update_model_status(model["id"], health_status="unhealthy")
            return {"healthy": False, "error": str(e)}

    async def bulk_health_check(self) -> Dict[str, Any]:
        """Check health of all registered models"""

        models = await self.list_models()
        health_results = {}

        # Run health checks concurrently
        tasks = []
        for model in models:
            task = asyncio.create_task(self.check_model_health(model["id"]))
            tasks.append((model["id"], task))

        for model_id, task in tasks:
            try:
                health_result = await task
                health_results[model_id] = health_result
            except Exception as e:
                health_results[model_id] = {"healthy": False, "error": str(e)}

        # Calculate overall health statistics
        total_models = len(health_results)
        healthy_models = sum(1 for result in health_results.values() if result.get("healthy", False))

        return {
            "total_models": total_models,
            "healthy_models": healthy_models,
            "unhealthy_models": total_models - healthy_models,
            "health_percentage": (healthy_models / total_models * 100) if total_models > 0 else 0,
            "individual_results": health_results
        }

    async def get_model_analytics(
        self,
        model_id: str = None,
        timeframe_hours: int = 24
    ) -> Dict[str, Any]:
        """Get analytics for model usage and performance"""

        models = await self.list_models()
        if model_id:
            models = [m for m in models if m["id"] == model_id]

        analytics = {
            "total_models": len(models),
            "by_provider": {},
            "by_type": {},
            "performance_summary": {
                "avg_latency_p50": 0,
                "avg_success_rate": 0,
                "total_requests": 0,
                "total_errors": 0
            },
            "top_performers": [],
            "models": models
        }

        total_latency = 0
        total_success_rate = 0
        total_requests = 0
        total_errors = 0

        for model in models:
            # Provider statistics
            provider = model["provider"]
            if provider not in analytics["by_provider"]:
                analytics["by_provider"][provider] = {"count": 0, "requests": 0}
            analytics["by_provider"][provider]["count"] += 1
            analytics["by_provider"][provider]["requests"] += model["request_count"]

            # Type statistics
            model_type = model["model_type"]
            if model_type not in analytics["by_type"]:
                analytics["by_type"][model_type] = {"count": 0, "requests": 0}
            analytics["by_type"][model_type]["count"] += 1
            analytics["by_type"][model_type]["requests"] += model["request_count"]

            # Performance aggregation
            total_latency += model["latency_p50_ms"]
            total_success_rate += model["success_rate"]
            total_requests += model["request_count"]
            total_errors += model["error_count"]

        # Calculate averages
        if len(models) > 0:
            analytics["performance_summary"]["avg_latency_p50"] = total_latency / len(models)
            analytics["performance_summary"]["avg_success_rate"] = total_success_rate / len(models)

        analytics["performance_summary"]["total_requests"] = total_requests
        analytics["performance_summary"]["total_errors"] = total_errors

        # Top performers (by success rate and low latency)
        analytics["top_performers"] = sorted(
            [m for m in models if m["request_count"] > 0],
            key=lambda x: (x["success_rate"], -x["latency_p50_ms"]),
            reverse=True
        )[:5]

        return analytics

    async def _initialize_default_models(self):
        """Initialize registry with default models"""

        # Groq models
        groq_models = [
            {
                "model_id": "llama-3.1-405b-reasoning",
                "name": "Llama 3.1 405B Reasoning",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Largest Llama model optimized for complex reasoning tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 2.5,
                "capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-70b-versatile",
                "name": "Llama 3.1 70B Versatile",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Balanced Llama model for general-purpose tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.8,
                "capabilities": {"general": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-8b-instant",
                "name": "Llama 3.1 8B Instant",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Fast Llama model for quick responses",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.2,
                "capabilities": {"fast": True, "streaming": True}
            },
            {
                "model_id": "mixtral-8x7b-32768",
                "name": "Mixtral 8x7B",
                "version": "1.0",
                "provider": "groq",
                "model_type": "llm",
                "description": "Mixtral model for balanced performance",
                "max_tokens": 32768,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.27,
                "capabilities": {"general": True, "streaming": True}
            }
        ]

        for model_config in groq_models:
            await self.register_model(**model_config)

        logger.info("Initialized default model registry with in-memory storage")

    def _initialize_default_models_sync(self):
        """Initialize registry with default models synchronously"""

        # Groq models
        groq_models = [
            {
                "model_id": "llama-3.1-405b-reasoning",
                "name": "Llama 3.1 405B Reasoning",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Largest Llama model optimized for complex reasoning tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 2.5,
                "capabilities": {"reasoning": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-70b-versatile",
                "name": "Llama 3.1 70B Versatile",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Balanced Llama model for general-purpose tasks",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.8,
                "capabilities": {"general": True, "function_calling": True, "streaming": True}
            },
            {
                "model_id": "llama-3.1-8b-instant",
                "name": "Llama 3.1 8B Instant",
                "version": "3.1",
                "provider": "groq",
                "model_type": "llm",
                "description": "Fast Llama model for quick responses",
                "max_tokens": 8000,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.2,
                "capabilities": {"fast": True, "streaming": True}
            },
            {
                "model_id": "mixtral-8x7b-32768",
                "name": "Mixtral 8x7B",
                "version": "1.0",
                "provider": "groq",
                "model_type": "llm",
                "description": "Mixtral model for balanced performance",
                "max_tokens": 32768,
                "context_window": 32768,
                "cost_per_1k_tokens": 0.27,
                "capabilities": {"general": True, "streaming": True}
            },
            {
                "model_id": "groq/compound",
                "name": "Groq Compound Model",
                "version": "1.0",
                "provider": "groq",
                "model_type": "llm",
                "description": "Groq compound AI model",
                "max_tokens": 8000,
                "context_window": 8000,
                "cost_per_1k_tokens": 0.5,
                "capabilities": {"general": True, "streaming": True}
            }
        ]

        for model_config in groq_models:
            now = datetime.utcnow()
            model_entry = {
                "id": model_config["model_id"],
                "name": model_config["name"],
                "version": model_config["version"],
                "provider": model_config["provider"],
                "model_type": model_config["model_type"],
                "description": model_config["description"],
                "capabilities": model_config["capabilities"],
                "parameters": {},

                # Performance metrics
                "max_tokens": model_config["max_tokens"],
                "context_window": model_config["context_window"],
                "cost_per_1k_tokens": model_config["cost_per_1k_tokens"],
                "latency_p50_ms": 0.0,
                "latency_p95_ms": 0.0,

                # Deployment status
                "deployment_status": "available",
                "health_status": "unknown",
                "last_health_check": None,

                # Usage tracking
                "request_count": 0,
                "error_count": 0,
                "success_rate": 1.0,

                # Lifecycle
                "created_at": now.isoformat(),
                "updated_at": now.isoformat(),
                "retired_at": None,

                # Configuration
                "endpoint_url": None,
                "api_key_required": True,
                "rate_limits": {}
            }

            self.model_registry[model_config["model_id"]] = model_entry

        logger.info("Initialized default model registry with in-memory storage (sync)")

    async def register_or_update_model(
        self,
        model_id: str,
        name: str,
        version: str = "1.0",
        provider: str = "unknown",
        model_type: str = "llm",
        endpoint: str = "",
        api_key_name: str = None,
        specifications: Dict[str, Any] = None,
        capabilities: Dict[str, Any] = None,
        cost: Dict[str, Any] = None,
        description: str = "",
        config: Dict[str, Any] = None,
        status: Dict[str, Any] = None,
        sync_timestamp: str = None
    ) -> Dict[str, Any]:
        """Register a new model or update existing one from admin cluster sync"""

        specifications = specifications or {}
        capabilities = capabilities or {}
        cost = cost or {}
        config = config or {}
        status = status or {}

        # Check if model exists
        existing_model = self.model_registry.get(model_id)

        if existing_model:
            # Update existing model
            existing_model.update({
                "name": name,
                "version": version,
                "provider": provider,
                "model_type": model_type,
                "description": description,
                "capabilities": capabilities,
                "parameters": config,
                "endpoint_url": endpoint,
                "api_key_required": bool(api_key_name),
                "max_tokens": specifications.get("max_tokens", existing_model.get("max_tokens", 4000)),
                "context_window": specifications.get("context_window", existing_model.get("context_window", 4000)),
                "cost_per_1k_tokens": cost.get("per_1k_input", existing_model.get("cost_per_1k_tokens", 0.0)),
                "deployment_status": "deployed" if status.get("is_active", True) else "retired",
                "updated_at": datetime.utcnow().isoformat()
            })

            if "bge-m3" in model_id.lower():
                logger.info(f"Updated BGE-M3 model: endpoint_url={endpoint}, parameters={config}")
            logger.debug(f"Updated model: {model_id}")
            return existing_model
        else:
            # Register new model
            return await self.register_model(
                model_id=model_id,
                name=name,
                version=version,
                provider=provider,
                model_type=model_type,
                description=description,
                capabilities=capabilities,
                parameters=config,
                endpoint_url=endpoint,
                max_tokens=specifications.get("max_tokens", 4000),
                context_window=specifications.get("context_window", 4000),
                cost_per_1k_tokens=cost.get("per_1k_input", 0.0),
                api_key_required=bool(api_key_name)
            )


def get_model_service(tenant_id: Optional[str] = None) -> ModelService:
    """Get tenant-isolated model service instance"""
    return ModelService(tenant_id=tenant_id)


# Default model service for development/non-tenant operations
default_model_service = get_model_service()
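A minimal usage sketch for the service defined above. This is illustrative only and not part of the file; the import path app.services.model_service is assumed, and the bulk health check will reach out to the configured providers (for example the Groq API) using the keys in settings:

# Illustrative usage sketch, not part of the module itself.
import asyncio

from app.services.model_service import get_model_service  # assumed module path

async def main():
    service = get_model_service(tenant_id="tenant-a")

    # Register (or sync) a model that exposes an OpenAI-compatible endpoint.
    await service.register_or_update_model(
        model_id="bge-m3",
        name="BGE-M3 Embeddings",
        provider="local",
        model_type="embedding",
        endpoint="http://embeddings:8080/v1/chat/completions",
    )

    # Record a successful call and its latency, then inspect aggregate analytics.
    await service.track_model_usage("bge-m3", success=True, latency_ms=42.0)
    analytics = await service.get_model_analytics()
    print(analytics["performance_summary"])

    # Health-check every registered model concurrently.
    report = await service.bulk_health_check()
    print(f"{report['healthy_models']}/{report['total_models']} models healthy")

asyncio.run(main())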