GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
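The SSRF and hostname-validation fixes listed above live outside the files shown below. As a rough illustrative sketch only (not this project's code), DNS-resolution checking combined with exact hostname matching, with a hypothetical ALLOWED_HOSTS allow-list, typically takes this shape:

# Illustrative sketch, not part of this commit.
import ipaddress
import socket
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.example.internal"}  # hypothetical allow-list

def is_safe_url(url: str) -> bool:
    """Reject URLs that do not resolve to a public address on an allowed host."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return False
    # Exact hostname comparison instead of substring matching
    if parsed.hostname not in ALLOWED_HOSTS:
        return False
    # Resolve DNS and reject private, loopback, link-local, or reserved targets
    try:
        addr_info = socket.getaddrinfo(parsed.hostname, parsed.port or 443)
    except socket.gaierror:
        return False
    for _family, _type, _proto, _canon, sockaddr in addr_info:
        ip = ipaddress.ip_address(sockaddr[0])
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            return False
    return True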
3
apps/resource-cluster/app/core/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
Core utilities and configuration for Resource Cluster
"""
140
apps/resource-cluster/app/core/api_standards.py
Normal file
@@ -0,0 +1,140 @@
"""
GT 2.0 Resource Cluster - API Standards Integration

This module integrates CB-REST standards for non-AI endpoints while
maintaining OpenAI compatibility for AI inference endpoints.
"""

import os
import sys
from pathlib import Path

# Add the api-standards package to the path
api_standards_path = Path(__file__).parent.parent.parent.parent.parent / "packages" / "api-standards" / "src"
if api_standards_path.exists():
    sys.path.insert(0, str(api_standards_path))

# Import CB-REST standards
try:
    from response import StandardResponse, format_response, format_error
    from capability import (
        init_capability_verifier,
        verify_capability,
        require_capability,
        Capability,
        CapabilityToken
    )
    from errors import ErrorCode, APIError, raise_api_error
    from middleware import (
        RequestCorrelationMiddleware,
        CapabilityMiddleware,
        TenantIsolationMiddleware,
        RateLimitMiddleware
    )
except ImportError as e:
    # Fallback for development - create minimal implementations
    print(f"Warning: Could not import api-standards package: {e}")

    # Create minimal implementations for development
    class StandardResponse:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    def format_response(data, capability_used, request_id=None):
        return {
            "data": data,
            "error": None,
            "capability_used": capability_used,
            "request_id": request_id or "dev-mode"
        }

    def format_error(code, message, capability_used="none", **kwargs):
        return {
            "data": None,
            "error": {
                "code": code,
                "message": message,
                **kwargs
            },
            "capability_used": capability_used,
            "request_id": kwargs.get("request_id", "dev-mode")
        }

    class ErrorCode:
        CAPABILITY_INSUFFICIENT = "CAPABILITY_INSUFFICIENT"
        RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"
        INVALID_REQUEST = "INVALID_REQUEST"
        SYSTEM_ERROR = "SYSTEM_ERROR"
        RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"

    class APIError(Exception):
        def __init__(self, code, message, **kwargs):
            self.code = code
            self.message = message
            self.kwargs = kwargs
            super().__init__(message)


# Export all CB-REST components
__all__ = [
    'StandardResponse',
    'format_response',
    'format_error',
    'init_capability_verifier',
    'verify_capability',
    'require_capability',
    'Capability',
    'CapabilityToken',
    'ErrorCode',
    'APIError',
    'raise_api_error',
    'RequestCorrelationMiddleware',
    'CapabilityMiddleware',
    'TenantIsolationMiddleware',
    'RateLimitMiddleware'
]


def setup_api_standards(app, secret_key: str):
    """
    Setup API standards for the Resource Cluster

    IMPORTANT: This only applies CB-REST to non-AI endpoints.
    AI inference endpoints maintain OpenAI compatibility.

    Args:
        app: FastAPI application instance
        secret_key: Secret key for JWT signing
    """
    # Initialize capability verifier
    if 'init_capability_verifier' in globals():
        init_capability_verifier(secret_key)

    # Add middleware in correct order
    if 'RequestCorrelationMiddleware' in globals():
        app.add_middleware(RequestCorrelationMiddleware)

    if 'RateLimitMiddleware' in globals():
        app.add_middleware(
            RateLimitMiddleware,
            requests_per_minute=1000  # Higher limit for resource cluster
        )

    # Note: No TenantIsolationMiddleware for Resource Cluster
    # as it serves multiple tenants with capability-based access

    if 'CapabilityMiddleware' in globals():
        # Exclude AI inference endpoints from CB-REST middleware
        # to maintain OpenAI compatibility
        app.add_middleware(
            CapabilityMiddleware,
            exclude_paths=[
                "/health",
                "/ready",
                "/metrics",
                "/ai/chat/completions",  # OpenAI compatible
                "/ai/embeddings",  # OpenAI compatible
                "/ai/images/generations",  # OpenAI compatible
                "/ai/models"  # OpenAI compatible
            ]
        )
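A minimal wiring sketch for the module above, assuming the dev-fallback signatures shown here and a SECRET_KEY environment variable (hypothetical example, not part of the commit):

# Hypothetical wiring example, not part of this commit.
import os
from fastapi import FastAPI

from app.core.api_standards import format_response, setup_api_standards

app = FastAPI()
setup_api_standards(app, secret_key=os.environ["SECRET_KEY"])

@app.get("/health")
async def health():
    # /health is in exclude_paths, so CapabilityMiddleware does not gate it
    return format_response({"status": "ok"}, capability_used="none")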
52
apps/resource-cluster/app/core/backends/__init__.py
Normal file
@@ -0,0 +1,52 @@
"""
Resource backend implementations for GT 2.0

Provides unified interfaces for all resource types:
- LLM inference (Groq, OpenAI, Anthropic)
- Vector databases (PGVector)
- Document processing (Unstructured)
- External services (OAuth2, iframe)
- AI literacy resources
"""

from typing import Dict, Any
import logging

logger = logging.getLogger(__name__)

# Registry of available backends
BACKEND_REGISTRY: Dict[str, Any] = {}


def register_backend(name: str, backend_class):
    """Register a resource backend"""
    BACKEND_REGISTRY[name] = backend_class
    logger.info(f"Registered backend: {name}")


def get_backend(name: str):
    """Get a registered backend"""
    if name not in BACKEND_REGISTRY:
        raise ValueError(f"Backend not found: {name}")
    return BACKEND_REGISTRY[name]


async def initialize_backends():
    """Initialize all resource backends"""
    from app.core.backends.groq_proxy import GroqProxyBackend
    from app.core.backends.nvidia_proxy import NvidiaProxyBackend
    from app.core.backends.document_processor import DocumentProcessorBackend
    from app.core.backends.embedding_backend import EmbeddingBackend

    # Register backends
    register_backend("groq_proxy", GroqProxyBackend())
    register_backend("nvidia_proxy", NvidiaProxyBackend())
    register_backend("document_processor", DocumentProcessorBackend())
    register_backend("embedding", EmbeddingBackend())

    logger.info("All resource backends initialized")


def get_embedding_backend():
    """Get the embedding backend instance"""
    return get_backend("embedding")
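A hypothetical startup sketch showing how the registry above is typically used (the registry functions are the ones defined in this file; the calling service is assumed):

# Hypothetical startup sketch, not part of this commit.
from app.core.backends import get_backend, initialize_backends

async def on_startup():
    await initialize_backends()
    embedding = get_backend("embedding")            # EmbeddingBackend instance
    processor = get_backend("document_processor")   # DocumentProcessorBackend instance
    return embedding, processor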
322
apps/resource-cluster/app/core/backends/document_processor.py
Normal file
@@ -0,0 +1,322 @@
"""
Document Processing Backend

STATELESS document chunking and preprocessing for RAG operations.
All processing happens in memory - NO user data is ever stored.
"""

import logging
import io
import gc
from typing import Dict, Any, List, Optional, BinaryIO
from dataclasses import dataclass
import hashlib

# Document processing imports
import pypdf as PyPDF2
from docx import Document as DocxDocument
from bs4 import BeautifulSoup
from langchain_text_splitters import (
    RecursiveCharacterTextSplitter,
    TokenTextSplitter,
    SentenceTransformersTokenTextSplitter
)

logger = logging.getLogger(__name__)


@dataclass
class ChunkingStrategy:
    """Configuration for document chunking"""
    strategy_type: str  # 'fixed', 'semantic', 'hierarchical', 'hybrid'
    chunk_size: int  # Target chunk size in tokens (optimized for BGE-M3: 512)
    chunk_overlap: int  # Overlap between chunks (typically 128 for BGE-M3)
    separator_pattern: Optional[str] = None  # Custom separator for splitting
    preserve_paragraphs: bool = True
    preserve_sentences: bool = True


class DocumentProcessorBackend:
    """
    STATELESS document chunking and processing backend.

    Security principles:
    - NO persistence of user data
    - All processing in memory only
    - Immediate memory cleanup after processing
    - No caching of user content
    """

    def __init__(self):
        self.supported_formats = [".pdf", ".docx", ".txt", ".md", ".html"]
        # BGE-M3 optimal settings
        self.default_chunk_size = 512  # tokens
        self.default_chunk_overlap = 128  # tokens
        self.model_name = "BAAI/bge-m3"  # For tokenization
        logger.info("STATELESS document processor backend initialized")

    async def process_document(
        self,
        content: bytes,
        document_type: str,
        strategy: Optional[ChunkingStrategy] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Process document into chunks - STATELESS operation.

        Args:
            content: Document content as bytes (will be cleared from memory)
            document_type: File type (.pdf, .docx, .txt, .md, .html)
            strategy: Chunking strategy configuration
            metadata: Optional metadata (will NOT include user content)

        Returns:
            List of chunks with metadata (immediately returned, not stored)
        """
        try:
            # Use default strategy if not provided
            if strategy is None:
                strategy = ChunkingStrategy(
                    strategy_type='hybrid',
                    chunk_size=self.default_chunk_size,
                    chunk_overlap=self.default_chunk_overlap
                )

            # Extract text based on document type (in memory)
            text = await self._extract_text_from_bytes(content, document_type)

            # Clear original content from memory
            del content
            gc.collect()

            # Apply chunking strategy
            if strategy.strategy_type == 'semantic':
                chunks = await self._semantic_chunking(text, strategy)
            elif strategy.strategy_type == 'hierarchical':
                chunks = await self._hierarchical_chunking(text, strategy)
            elif strategy.strategy_type == 'hybrid':
                chunks = await self._hybrid_chunking(text, strategy)
            else:  # 'fixed'
                chunks = await self._fixed_chunking(text, strategy)

            # Clear text from memory
            del text
            gc.collect()

            # Add metadata without storing content
            processed_chunks = []
            for idx, chunk in enumerate(chunks):
                chunk_metadata = {
                    "chunk_index": idx,
                    "total_chunks": len(chunks),
                    "chunking_strategy": strategy.strategy_type,
                    "chunk_size_tokens": strategy.chunk_size,
                    # Generate hash for deduplication without storing content
                    "content_hash": hashlib.sha256(chunk.encode()).hexdigest()[:16]
                }

                # Add non-sensitive metadata if provided
                if metadata:
                    # Filter out any potential sensitive data
                    safe_metadata = {
                        k: v for k, v in metadata.items()
                        if k in ['document_type', 'processing_timestamp', 'tenant_id']
                    }
                    chunk_metadata.update(safe_metadata)

                processed_chunks.append({
                    "text": chunk,
                    "metadata": chunk_metadata
                })

            logger.info(f"Processed document into {len(processed_chunks)} chunks (STATELESS)")

            # Return immediately - no storage
            return processed_chunks

        except Exception as e:
            logger.error(f"Error processing document: {e}")
            # Ensure memory is cleared even on error
            gc.collect()
            raise
        finally:
            # Always ensure memory cleanup
            gc.collect()

    async def _extract_text_from_bytes(
        self,
        content: bytes,
        document_type: str
    ) -> str:
        """Extract text from document bytes - in memory only"""

        try:
            if document_type == ".pdf":
                return await self._extract_pdf_text(io.BytesIO(content))
            elif document_type == ".docx":
                return await self._extract_docx_text(io.BytesIO(content))
            elif document_type == ".html":
                return await self._extract_html_text(content.decode('utf-8'))
            elif document_type in [".txt", ".md"]:
                return content.decode('utf-8')
            else:
                raise ValueError(f"Unsupported document type: {document_type}")
        finally:
            # Clear content from memory
            del content
            gc.collect()

    async def _extract_pdf_text(self, file_stream: BinaryIO) -> str:
        """Extract text from PDF - in memory"""
        text = ""
        try:
            pdf_reader = PyPDF2.PdfReader(file_stream)
            for page_num in range(len(pdf_reader.pages)):
                page = pdf_reader.pages[page_num]
                text += page.extract_text() + "\n"
        finally:
            file_stream.close()
            gc.collect()
        return text

    async def _extract_docx_text(self, file_stream: BinaryIO) -> str:
        """Extract text from DOCX - in memory"""
        text = ""
        try:
            doc = DocxDocument(file_stream)
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
        finally:
            file_stream.close()
            gc.collect()
        return text

    async def _extract_html_text(self, html_content: str) -> str:
        """Extract text from HTML - in memory"""
        soup = BeautifulSoup(html_content, 'html.parser')
        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.decompose()
        text = soup.get_text()
        # Clean up whitespace
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = '\n'.join(chunk for chunk in chunks if chunk)
        return text

    async def _semantic_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Semantic chunking using sentence boundaries"""
        splitter = SentenceTransformersTokenTextSplitter(
            model_name=self.model_name,
            chunk_size=strategy.chunk_size,
            chunk_overlap=strategy.chunk_overlap
        )
        return splitter.split_text(text)

    async def _hierarchical_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Hierarchical chunking preserving document structure"""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=strategy.chunk_size * 3,  # Approximate token to char ratio
            chunk_overlap=strategy.chunk_overlap * 3,
            separators=["\n\n\n", "\n\n", "\n", ". ", " ", ""],
            keep_separator=True
        )
        return splitter.split_text(text)

    async def _hybrid_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Hybrid chunking combining semantic and structural boundaries"""
        # First split by structure
        structural_splitter = RecursiveCharacterTextSplitter(
            chunk_size=strategy.chunk_size * 4,
            chunk_overlap=0,
            separators=["\n\n\n", "\n\n"],
            keep_separator=True
        )
        structural_chunks = structural_splitter.split_text(text)

        # Then apply semantic splitting to each structural chunk
        final_chunks = []
        token_splitter = TokenTextSplitter(
            chunk_size=strategy.chunk_size,
            chunk_overlap=strategy.chunk_overlap
        )

        for struct_chunk in structural_chunks:
            semantic_chunks = token_splitter.split_text(struct_chunk)
            final_chunks.extend(semantic_chunks)

        return final_chunks

    async def _fixed_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Fixed-size chunking with token boundaries"""
        splitter = TokenTextSplitter(
            chunk_size=strategy.chunk_size,
            chunk_overlap=strategy.chunk_overlap
        )
        return splitter.split_text(text)

    async def validate_document(
        self,
        content_size: int,
        document_type: str
    ) -> Dict[str, Any]:
        """
        Validate document before processing - no content stored.

        Args:
            content_size: Size of document in bytes
            document_type: File extension

        Returns:
            Validation result with any warnings
        """
        MAX_SIZE = 50 * 1024 * 1024  # 50MB max

        validation = {
            "valid": True,
            "warnings": [],
            "errors": []
        }

        # Check file size
        if content_size > MAX_SIZE:
            validation["valid"] = False
            validation["errors"].append(f"File size exceeds maximum of 50MB")
        elif content_size > 10 * 1024 * 1024:  # Warning for files over 10MB
            validation["warnings"].append("Large file may take longer to process")

        # Check document type
        if document_type not in self.supported_formats:
            validation["valid"] = False
            validation["errors"].append(f"Unsupported format: {document_type}")

        return validation

    async def check_health(self) -> Dict[str, Any]:
        """Check document processor health - no user data exposed"""
        return {
            "status": "healthy",
            "supported_formats": self.supported_formats,
            "default_chunk_size": self.default_chunk_size,
            "default_chunk_overlap": self.default_chunk_overlap,
            "model": self.model_name,
            "stateless": True,  # Confirm stateless operation
            "memory_cleared": True  # Confirm memory management
        }
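An illustrative caller for the processor above (hypothetical, not part of the commit): it validates by size and extension, then chunks a Markdown document with the BGE-M3 defaults (512-token chunks, 128-token overlap):

# Hypothetical caller, not part of this commit.
from app.core.backends.document_processor import ChunkingStrategy, DocumentProcessorBackend

async def chunk_markdown(raw: bytes):
    processor = DocumentProcessorBackend()
    validation = await processor.validate_document(len(raw), ".md")
    if not validation["valid"]:
        raise ValueError(validation["errors"])
    strategy = ChunkingStrategy(strategy_type="hybrid", chunk_size=512, chunk_overlap=128)
    return await processor.process_document(raw, ".md", strategy)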
471
apps/resource-cluster/app/core/backends/embedding_backend.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""
|
||||
Embedding Model Backend
|
||||
|
||||
STATELESS embedding generation using BGE-M3 model hosted on GT's GPU clusters.
|
||||
All embeddings are generated in real-time - NO user data is stored.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import gc
|
||||
import hashlib
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
# import numpy as np # Temporarily disabled for Docker build
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingRequest:
|
||||
"""Request structure for embedding generation"""
|
||||
texts: List[str]
|
||||
model: str = "BAAI/bge-m3"
|
||||
batch_size: int = 32
|
||||
normalize: bool = True
|
||||
instruction: Optional[str] = None # For instruction-based embeddings
|
||||
|
||||
|
||||
class EmbeddingBackend:
|
||||
"""
|
||||
STATELESS embedding backend for BGE-M3 model.
|
||||
|
||||
Security principles:
|
||||
- NO persistence of embeddings or text
|
||||
- All processing via GT's internal GPU cluster
|
||||
- Immediate memory cleanup after generation
|
||||
- No caching of user content
|
||||
- Request signing and verification
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_name = "BAAI/bge-m3"
|
||||
self.embedding_dimensions = 1024 # BGE-M3 dimensions
|
||||
self.max_batch_size = 32
|
||||
self.max_sequence_length = 8192 # BGE-M3 supports up to 8192 tokens
|
||||
|
||||
# Determine endpoint based on configuration
|
||||
self.embedding_endpoint = self._get_embedding_endpoint()
|
||||
|
||||
# Timeout for embedding requests
|
||||
self.request_timeout = 60 # seconds for model loading
|
||||
|
||||
logger.info(f"STATELESS embedding backend initialized for {self.model_name}")
|
||||
logger.info(f"Using embedding endpoint: {self.embedding_endpoint}")
|
||||
|
||||
def _get_embedding_endpoint(self) -> str:
|
||||
"""
|
||||
Get the embedding endpoint based on configuration.
|
||||
Priority:
|
||||
1. Model registry from config sync (database-backed)
|
||||
2. Environment variables (BGE_M3_LOCAL_MODE, BGE_M3_EXTERNAL_ENDPOINT)
|
||||
3. Default local endpoint
|
||||
"""
|
||||
# Try to get configuration from model registry first (loaded from database)
|
||||
try:
|
||||
from app.services.model_service import default_model_service
|
||||
import asyncio
|
||||
|
||||
# Use the default model service instance (singleton) used by config sync
|
||||
model_service = default_model_service
|
||||
|
||||
# Try to get the model config synchronously (during initialization)
|
||||
# The get_model method is async, so we need to handle this carefully
|
||||
bge_m3_config = model_service.model_registry.get("BAAI/bge-m3")
|
||||
|
||||
if bge_m3_config:
|
||||
# Model registry stores endpoint as 'endpoint_url' and config as 'parameters'
|
||||
endpoint = bge_m3_config.get("endpoint_url")
|
||||
config = bge_m3_config.get("parameters", {})
|
||||
is_local_mode = config.get("is_local_mode", True)
|
||||
external_endpoint = config.get("external_endpoint")
|
||||
|
||||
logger.info(f"Found BGE-M3 in registry: endpoint_url={endpoint}, is_local_mode={is_local_mode}, external_endpoint={external_endpoint}")
|
||||
|
||||
if endpoint:
|
||||
logger.info(f"Using BGE-M3 endpoint from model registry (is_local_mode={is_local_mode}): {endpoint}")
|
||||
return endpoint
|
||||
else:
|
||||
logger.warning(f"BGE-M3 found in registry but endpoint_url is None/empty. Full config: {bge_m3_config}")
|
||||
else:
|
||||
available_models = list(model_service.model_registry.keys())
|
||||
logger.debug(f"BGE-M3 not found in model registry during init (expected on first startup). Available models: {available_models}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Model registry not yet available during startup (will be populated after config sync): {e}")
|
||||
|
||||
# Fall back to Settings fields (environment variables or .env file)
|
||||
is_local_mode = getattr(settings, 'bge_m3_local_mode', True)
|
||||
external_endpoint = getattr(settings, 'bge_m3_external_endpoint', None)
|
||||
|
||||
if not is_local_mode and external_endpoint:
|
||||
logger.info(f"Using external BGE-M3 endpoint from settings: {external_endpoint}")
|
||||
return external_endpoint
|
||||
|
||||
# Default to local endpoint
|
||||
local_endpoint = getattr(
|
||||
settings,
|
||||
'embedding_endpoint',
|
||||
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
|
||||
)
|
||||
logger.info(f"Using local BGE-M3 endpoint: {local_endpoint}")
|
||||
return local_endpoint
|
||||
|
||||
async def update_endpoint_config(self, is_local_mode: bool, external_endpoint: str = None):
|
||||
"""
|
||||
Update the embedding endpoint configuration dynamically.
|
||||
This allows switching between local and external endpoints without restart.
|
||||
"""
|
||||
if is_local_mode:
|
||||
self.embedding_endpoint = getattr(
|
||||
settings,
|
||||
'embedding_endpoint',
|
||||
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
|
||||
)
|
||||
else:
|
||||
if external_endpoint:
|
||||
self.embedding_endpoint = external_endpoint
|
||||
else:
|
||||
raise ValueError("External endpoint must be provided when not in local mode")
|
||||
|
||||
logger.info(f"BGE-M3 endpoint updated to: {self.embedding_endpoint}")
|
||||
logger.info(f"Mode: {'Local GT Edge' if is_local_mode else 'External API'}")
|
||||
|
||||
def refresh_endpoint_from_registry(self):
|
||||
"""
|
||||
Refresh the embedding endpoint from the model registry.
|
||||
Called by config sync when BGE-M3 configuration changes.
|
||||
"""
|
||||
logger.info(f"Refreshing embedding endpoint - current: {self.embedding_endpoint}")
|
||||
new_endpoint = self._get_embedding_endpoint()
|
||||
if new_endpoint != self.embedding_endpoint:
|
||||
logger.info(f"Refreshing BGE-M3 endpoint from {self.embedding_endpoint} to {new_endpoint}")
|
||||
self.embedding_endpoint = new_endpoint
|
||||
else:
|
||||
logger.info(f"BGE-M3 endpoint unchanged: {self.embedding_endpoint}")
|
||||
|
||||
async def generate_embeddings(
|
||||
self,
|
||||
texts: List[str],
|
||||
instruction: Optional[str] = None,
|
||||
tenant_id: str = None,
|
||||
request_id: str = None
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for texts using BGE-M3 - STATELESS operation.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed (will be cleared from memory)
|
||||
instruction: Optional instruction for query vs document embeddings
|
||||
tenant_id: Tenant ID for audit logging (not stored with data)
|
||||
request_id: Request ID for tracing
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (immediately returned, not stored)
|
||||
"""
|
||||
try:
|
||||
# Validate input
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
if len(texts) > self.max_batch_size:
|
||||
# Process in batches
|
||||
return await self._batch_process_embeddings(
|
||||
texts, instruction, tenant_id, request_id
|
||||
)
|
||||
|
||||
# Prepare request
|
||||
request_data = {
|
||||
"model": self.model_name,
|
||||
"input": texts,
|
||||
"encoding_format": "float",
|
||||
"dimensions": self.embedding_dimensions
|
||||
}
|
||||
|
||||
# Add instruction if provided (for query vs document distinction)
|
||||
if instruction:
|
||||
request_data["instruction"] = instruction
|
||||
|
||||
# Add metadata for audit (not stored with embeddings)
|
||||
metadata = {
|
||||
"tenant_id": tenant_id,
|
||||
"request_id": request_id,
|
||||
"text_count": len(texts),
|
||||
# Hash for deduplication without storing content
|
||||
"content_hash": hashlib.sha256(
|
||||
"".join(texts).encode()
|
||||
).hexdigest()[:16]
|
||||
}
|
||||
|
||||
# Call vLLM service - NO FALLBACKS
|
||||
embeddings = await self._call_embedding_service(request_data, metadata)
|
||||
|
||||
# Clear texts from memory immediately
|
||||
del texts
|
||||
gc.collect()
|
||||
|
||||
# Validate response
|
||||
if not embeddings or len(embeddings) == 0:
|
||||
raise ValueError("No embeddings returned from service")
|
||||
|
||||
# Normalize if needed
|
||||
if self._should_normalize():
|
||||
embeddings = self._normalize_embeddings(embeddings)
|
||||
|
||||
logger.info(
|
||||
f"Generated {len(embeddings)} embeddings (STATELESS) "
|
||||
f"for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Return immediately - no storage
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embeddings: {e}")
|
||||
# Ensure memory is cleared even on error
|
||||
gc.collect()
|
||||
raise
|
||||
finally:
|
||||
# Always ensure memory cleanup
|
||||
gc.collect()
|
||||
|
||||
async def _batch_process_embeddings(
|
||||
self,
|
||||
texts: List[str],
|
||||
instruction: Optional[str],
|
||||
tenant_id: str,
|
||||
request_id: str
|
||||
) -> List[List[float]]:
|
||||
"""Process large text lists in batches using vLLM service"""
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), self.max_batch_size):
|
||||
batch = texts[i:i + self.max_batch_size]
|
||||
|
||||
# Prepare request for this batch
|
||||
request_data = {
|
||||
"model": self.model_name,
|
||||
"input": batch,
|
||||
"encoding_format": "float",
|
||||
"dimensions": self.embedding_dimensions
|
||||
}
|
||||
|
||||
if instruction:
|
||||
request_data["instruction"] = instruction
|
||||
|
||||
metadata = {
|
||||
"tenant_id": tenant_id,
|
||||
"request_id": f"{request_id}_batch_{i}",
|
||||
"text_count": len(batch),
|
||||
"content_hash": hashlib.sha256(
|
||||
"".join(batch).encode()
|
||||
).hexdigest()[:16]
|
||||
}
|
||||
|
||||
batch_embeddings = await self._call_embedding_service(request_data, metadata)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
# Clear batch from memory
|
||||
del batch
|
||||
gc.collect()
|
||||
|
||||
return all_embeddings
|
||||
|
||||
|
||||
async def _call_embedding_service(
|
||||
self,
|
||||
request_data: Dict[str, Any],
|
||||
metadata: Dict[str, Any]
|
||||
) -> List[List[float]]:
|
||||
"""Call internal GPU cluster embedding service"""
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
# Add capability token for authentication
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Tenant-ID": metadata.get("tenant_id", ""),
|
||||
"X-Request-ID": metadata.get("request_id", ""),
|
||||
# Authorization will be added by Resource Cluster
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
self.embedding_endpoint,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.request_timeout)
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise ValueError(
|
||||
f"Embedding service error: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# Extract embeddings from response
|
||||
if "data" in result:
|
||||
embeddings = [item["embedding"] for item in result["data"]]
|
||||
elif "embeddings" in result:
|
||||
embeddings = result["embeddings"]
|
||||
else:
|
||||
raise ValueError("Invalid embedding service response format")
|
||||
|
||||
return embeddings
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise ValueError(f"Embedding service timeout after {self.request_timeout}s")
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling embedding service: {e}")
|
||||
raise
|
||||
|
||||
def _should_normalize(self) -> bool:
|
||||
"""Check if embeddings should be normalized"""
|
||||
# BGE-M3 embeddings are typically normalized for similarity search
|
||||
return True
|
||||
|
||||
def _normalize_embeddings(
|
||||
self,
|
||||
embeddings: List[List[float]]
|
||||
) -> List[List[float]]:
|
||||
"""Normalize embedding vectors to unit length"""
|
||||
normalized = []
|
||||
|
||||
for embedding in embeddings:
|
||||
# Simple normalization without numpy (for now)
|
||||
import math
|
||||
|
||||
# Calculate norm
|
||||
norm = math.sqrt(sum(x * x for x in embedding))
|
||||
|
||||
if norm > 0:
|
||||
normalized_vec = [x / norm for x in embedding]
|
||||
else:
|
||||
normalized_vec = embedding[:]
|
||||
|
||||
normalized.append(normalized_vec)
|
||||
|
||||
return normalized
|
||||
|
||||
async def generate_query_embeddings(
|
||||
self,
|
||||
queries: List[str],
|
||||
tenant_id: str = None,
|
||||
request_id: str = None
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings specifically for queries.
|
||||
BGE-M3 can use different instructions for queries vs documents.
|
||||
"""
|
||||
# For BGE-M3, queries can use a specific instruction
|
||||
instruction = "Represent this sentence for searching relevant passages: "
|
||||
return await self.generate_embeddings(
|
||||
queries, instruction, tenant_id, request_id
|
||||
)
|
||||
|
||||
async def generate_document_embeddings(
|
||||
self,
|
||||
documents: List[str],
|
||||
tenant_id: str = None,
|
||||
request_id: str = None
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings specifically for documents.
|
||||
BGE-M3 can use different instructions for documents vs queries.
|
||||
"""
|
||||
# For BGE-M3, documents typically don't need special instruction
|
||||
return await self.generate_embeddings(
|
||||
documents, None, tenant_id, request_id
|
||||
)
|
||||
|
||||
async def validate_texts(
|
||||
self,
|
||||
texts: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate texts before embedding - no content stored.
|
||||
|
||||
Args:
|
||||
texts: List of texts to validate
|
||||
|
||||
Returns:
|
||||
Validation result with any warnings
|
||||
"""
|
||||
validation = {
|
||||
"valid": True,
|
||||
"warnings": [],
|
||||
"errors": [],
|
||||
"stats": {
|
||||
"total_texts": len(texts),
|
||||
"max_length": 0,
|
||||
"avg_length": 0
|
||||
}
|
||||
}
|
||||
|
||||
if not texts:
|
||||
validation["valid"] = False
|
||||
validation["errors"].append("No texts provided")
|
||||
return validation
|
||||
|
||||
# Check text lengths
|
||||
lengths = [len(text) for text in texts]
|
||||
validation["stats"]["max_length"] = max(lengths)
|
||||
validation["stats"]["avg_length"] = sum(lengths) // len(lengths)
|
||||
|
||||
# BGE-M3 max sequence length check (approximate)
|
||||
max_chars = self.max_sequence_length * 4 # Rough char to token ratio
|
||||
|
||||
for i, length in enumerate(lengths):
|
||||
if length > max_chars:
|
||||
validation["warnings"].append(
|
||||
f"Text {i} may exceed model's max sequence length"
|
||||
)
|
||||
elif length == 0:
|
||||
validation["errors"].append(f"Text {i} is empty")
|
||||
validation["valid"] = False
|
||||
|
||||
# Batch size check
|
||||
if len(texts) > self.max_batch_size * 10:
|
||||
validation["warnings"].append(
|
||||
f"Large batch ({len(texts)} texts) will be processed in chunks"
|
||||
)
|
||||
|
||||
return validation
|
||||
|
||||
async def check_health(self) -> Dict[str, Any]:
|
||||
"""Check embedding backend health - no user data exposed"""
|
||||
try:
|
||||
# Test connection to vLLM service
|
||||
test_text = ["Health check test"]
|
||||
test_embeddings = await self.generate_embeddings(
|
||||
test_text,
|
||||
tenant_id="health_check",
|
||||
request_id="health_check"
|
||||
)
|
||||
|
||||
health_status = {
|
||||
"status": "healthy",
|
||||
"model": self.model_name,
|
||||
"dimensions": self.embedding_dimensions,
|
||||
"max_batch_size": self.max_batch_size,
|
||||
"max_sequence_length": self.max_sequence_length,
|
||||
"endpoint": self.embedding_endpoint,
|
||||
"stateless": True,
|
||||
"memory_cleared": True,
|
||||
"vllm_service_connected": len(test_embeddings) > 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
health_status = {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"model": self.model_name,
|
||||
"endpoint": self.embedding_endpoint
|
||||
}
|
||||
|
||||
return health_status
|
||||
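A hypothetical caller sketch for the embedding backend above (not part of the commit), assuming initialize_backends() has already run so get_embedding_backend() resolves to the registered instance:

# Hypothetical caller, not part of this commit.
from app.core.backends import get_embedding_backend

async def embed_for_search(query: str, passages: list, tenant_id: str, request_id: str):
    backend = get_embedding_backend()
    # Queries get the BGE-M3 retrieval instruction, documents do not
    query_vectors = await backend.generate_query_embeddings([query], tenant_id, request_id)
    passage_vectors = await backend.generate_document_embeddings(passages, tenant_id, request_id)
    return query_vectors[0], passage_vectors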
780
apps/resource-cluster/app/core/backends/groq_proxy.py
Normal file
@@ -0,0 +1,780 @@
|
||||
"""
|
||||
Groq Cloud LLM Proxy Backend
|
||||
|
||||
Provides high-availability LLM inference through Groq Cloud with:
|
||||
- HAProxy load balancing across multiple endpoints
|
||||
- Automatic failover handled by HAProxy
|
||||
- Token usage tracking and cost calculation
|
||||
- Streaming response support
|
||||
- Circuit breaker pattern for enhanced reliability
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
try:
|
||||
from groq import AsyncGroq
|
||||
GROQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
# Groq not available in development mode
|
||||
AsyncGroq = None
|
||||
GROQ_AVAILABLE = False
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings, get_model_configs
|
||||
from app.services.model_service import get_model_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
# Groq Compound tool pricing (per request/execution)
|
||||
# Source: https://groq.com/pricing (Dec 2, 2025)
|
||||
COMPOUND_TOOL_PRICES = {
|
||||
# Web Search variants
|
||||
"search": 0.008, # API returns "search" for web search
|
||||
"web_search": 0.008, # $8 per 1K = $0.008 per request (Advanced Search)
|
||||
"advanced_search": 0.008, # $8 per 1K requests
|
||||
"basic_search": 0.005, # $5 per 1K requests
|
||||
# Other tools
|
||||
"visit_website": 0.001, # $1 per 1K requests
|
||||
"python": 0.00005, # API returns "python" for code execution
|
||||
"code_interpreter": 0.00005, # Alternative API identifier
|
||||
"code_execution": 0.00005, # Alias for backwards compatibility
|
||||
"browser_automation": 0.00002, # $0.08/hr ≈ $0.00002 per execution
|
||||
}
|
||||
|
||||
# Model pricing per million tokens (input/output)
|
||||
# Source: https://groq.com/pricing (Dec 2, 2025)
|
||||
GROQ_MODEL_PRICES = {
|
||||
"llama-3.3-70b-versatile": {"input": 0.59, "output": 0.79},
|
||||
"llama-3.1-8b-instant": {"input": 0.05, "output": 0.08},
|
||||
"llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
|
||||
"llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
|
||||
"llama-guard-4-12b": {"input": 0.20, "output": 0.20},
|
||||
"meta-llama/llama-guard-4-12b": {"input": 0.20, "output": 0.20},
|
||||
"gpt-oss-120b": {"input": 0.15, "output": 0.60},
|
||||
"openai/gpt-oss-120b": {"input": 0.15, "output": 0.60},
|
||||
"gpt-oss-20b": {"input": 0.075, "output": 0.30},
|
||||
"openai/gpt-oss-20b": {"input": 0.075, "output": 0.30},
|
||||
"kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
|
||||
"moonshotai/kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
|
||||
"qwen3-32b": {"input": 0.29, "output": 0.59},
|
||||
# Compound models - 50/50 blended pricing from underlying models
|
||||
# compound: GPT-OSS-120B ($0.15/$0.60) + Llama 4 Scout ($0.11/$0.34) = $0.13/$0.47
|
||||
"compound": {"input": 0.13, "output": 0.47},
|
||||
"groq/compound": {"input": 0.13, "output": 0.47},
|
||||
"compound-beta": {"input": 0.13, "output": 0.47},
|
||||
# compound-mini: GPT-OSS-120B ($0.15/$0.60) + Llama 3.3 70B ($0.59/$0.79) = $0.37/$0.695
|
||||
"compound-mini": {"input": 0.37, "output": 0.695},
|
||||
"groq/compound-mini": {"input": 0.37, "output": 0.695},
|
||||
"compound-mini-beta": {"input": 0.37, "output": 0.695},
|
||||
}
|
||||
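# Worked example (illustrative note, not part of the original commit): a compound
# request that used 2,000 prompt / 500 completion tokens on "openai/gpt-oss-120b"
# plus one "web_search" tool call would be costed by _calculate_compound_cost below as:
#   input  = 2_000 / 1_000_000 * 0.15 = $0.000300
#   output =   500 / 1_000_000 * 0.60 = $0.000300
#   tool   = COMPOUND_TOOL_PRICES["web_search"] = $0.008
#   total  = $0.0086 -> total_cost_cents = int(0.0086 * 100) = 0 (sub-cent totals truncate)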
|
||||
|
||||
class GroqProxyBackend:
|
||||
"""LLM inference via Groq Cloud with HAProxy load balancing"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.client = None
|
||||
self.usage_metrics = {}
|
||||
self.circuit_breaker_status = {}
|
||||
self._initialize_client()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize Groq client to use HAProxy load balancer"""
|
||||
if not GROQ_AVAILABLE:
|
||||
logger.warning("Groq client not available - running in development mode")
|
||||
return
|
||||
|
||||
if self.settings.groq_api_key:
|
||||
# Use HAProxy load balancer instead of direct Groq API
|
||||
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
|
||||
|
||||
# Initialize client with HAProxy endpoint
|
||||
self.client = AsyncGroq(
|
||||
api_key=self.settings.groq_api_key,
|
||||
base_url=haproxy_endpoint,
|
||||
timeout=httpx.Timeout(30.0), # Increased timeout for load balancing
|
||||
max_retries=1 # Let HAProxy handle retries
|
||||
)
|
||||
|
||||
# Initialize circuit breaker
|
||||
self.circuit_breaker_status = {
|
||||
"state": "closed", # closed, open, half_open
|
||||
"failure_count": 0,
|
||||
"last_failure_time": None,
|
||||
"failure_threshold": 5,
|
||||
"recovery_timeout": 60 # seconds
|
||||
}
|
||||
|
||||
logger.info(f"Initialized Groq client with HAProxy endpoint: {haproxy_endpoint}")
|
||||
|
||||
async def execute_inference(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = "llama-3.1-70b-versatile",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4000,
|
||||
stream: bool = False,
|
||||
user_id: str = None,
|
||||
tenant_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute LLM inference with HAProxy load balancing and circuit breaker"""
|
||||
|
||||
# Check circuit breaker
|
||||
if not await self._is_circuit_closed():
|
||||
raise Exception("Circuit breaker is open - service temporarily unavailable")
|
||||
|
||||
# Validate model and get configuration
|
||||
model_configs = get_model_configs(tenant_id)
|
||||
model_config = model_configs.get("groq", {}).get(model)
|
||||
if not model_config:
|
||||
# Try to get from model service registry
|
||||
model_service = get_model_service(tenant_id)
|
||||
model_info = await model_service.get_model(model)
|
||||
if not model_info:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
model_config = {
|
||||
"max_tokens": model_info["performance"]["max_tokens"],
|
||||
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
|
||||
"supports_streaming": model_info["capabilities"].get("streaming", False)
|
||||
}
|
||||
|
||||
# Apply token limits
|
||||
max_tokens = min(max_tokens, model_config["max_tokens"])
|
||||
|
||||
# Prepare messages
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
# Get tenant-specific API key
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required for Groq inference")
|
||||
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
client = self._get_client(api_key)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if stream:
|
||||
return await self._stream_inference(
|
||||
messages, model, temperature, max_tokens, user_id, tenant_id, client
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False
|
||||
)
|
||||
|
||||
# Track successful usage
|
||||
latency = (time.time() - start_time) * 1000
|
||||
await self._track_usage(
|
||||
user_id, tenant_id, model,
|
||||
response.usage.total_tokens if response.usage else 0,
|
||||
latency, model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
|
||||
# Track in model service
|
||||
model_service = get_model_service(tenant_id)
|
||||
await model_service.track_model_usage(
|
||||
model_id=model,
|
||||
success=True,
|
||||
latency_ms=latency
|
||||
)
|
||||
|
||||
# Reset circuit breaker on success
|
||||
await self._record_success()
|
||||
|
||||
return {
|
||||
"content": response.choices[0].message.content,
|
||||
"model": model,
|
||||
"usage": {
|
||||
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
"total_tokens": response.usage.total_tokens if response.usage else 0,
|
||||
"cost_cents": self._calculate_cost(
|
||||
response.usage.total_tokens if response.usage else 0,
|
||||
model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
},
|
||||
"latency_ms": latency,
|
||||
"load_balanced": True,
|
||||
"haproxy_backend": "groq_general_backend"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HAProxy Groq inference failed: {e}")
|
||||
|
||||
# Track failure in model service
|
||||
await model_service.track_model_usage(
|
||||
model_id=model,
|
||||
success=False
|
||||
)
|
||||
|
||||
# Record failure for circuit breaker
|
||||
await self._record_failure()
|
||||
|
||||
# Re-raise the exception - no client-side fallback needed
|
||||
# HAProxy handles all failover logic
|
||||
raise Exception(f"Groq inference failed (via HAProxy): {str(e)}")
|
||||
|
||||
async def _stream_inference(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
client: AsyncGroq = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream LLM inference responses"""
|
||||
|
||||
model_configs = get_model_configs(tenant_id)
|
||||
model_config = model_configs.get("groq", {}).get(model)
|
||||
start_time = time.time()
|
||||
total_tokens = 0
|
||||
|
||||
try:
|
||||
# Use provided client or get tenant-specific client
|
||||
if not client:
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
client = self._get_client(api_key)
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
total_tokens += len(content.split()) # Approximate token count
|
||||
|
||||
# Yield SSE formatted data
|
||||
yield f"data: {json.dumps({'content': content})}\n\n"
|
||||
|
||||
# Track usage after streaming completes
|
||||
latency = (time.time() - start_time) * 1000
|
||||
await self._track_usage(
|
||||
user_id, tenant_id, model,
|
||||
total_tokens, latency,
|
||||
model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
|
||||
# Send completion signal
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming inference error: {e}")
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
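    # Wire format emitted by _stream_inference above (Server-Sent Events over HTTP):
    #   data: {"content": "<delta text>"}\n\n   -- one event per streamed delta
    #   data: {"done": true}\n\n                -- end-of-stream marker
    #   data: {"error": "<message>"}\n\n        -- emitted if streaming fails
    # Streamed token counts are approximated via whitespace splits, so usage/cost
    # figures for streamed responses are estimates rather than exact token counts.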
async def check_health(self) -> Dict[str, Any]:
|
||||
"""Check health of HAProxy load balancer and circuit breaker status"""
|
||||
|
||||
try:
|
||||
# Check HAProxy health via stats endpoint
|
||||
haproxy_stats_url = self.settings.haproxy_stats_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
haproxy_stats_url,
|
||||
timeout=5.0,
|
||||
auth=("admin", "gt2_haproxy_stats_password")
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Parse HAProxy stats (simplified)
|
||||
stats_healthy = "UP" in response.text
|
||||
|
||||
return {
|
||||
"haproxy_load_balancer": {
|
||||
"healthy": stats_healthy,
|
||||
"stats_accessible": True,
|
||||
"last_check": datetime.utcnow().isoformat()
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": self.circuit_breaker_status["state"],
|
||||
"failure_count": self.circuit_breaker_status["failure_count"],
|
||||
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat() if self.circuit_breaker_status["last_failure_time"] else None
|
||||
},
|
||||
"groq_endpoints": {
|
||||
"managed_by": "haproxy",
|
||||
"failover_handled_by": "haproxy"
|
||||
}
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"haproxy_load_balancer": {
|
||||
"healthy": False,
|
||||
"error": f"Stats endpoint returned {response.status_code}",
|
||||
"last_check": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"haproxy_load_balancer": {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"last_check": datetime.utcnow().isoformat()
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": self.circuit_breaker_status["state"],
|
||||
"failure_count": self.circuit_breaker_status["failure_count"]
|
||||
}
|
||||
}
|
||||
|
||||
async def _is_circuit_closed(self) -> bool:
|
||||
"""Check if circuit breaker allows requests"""
|
||||
|
||||
if self.circuit_breaker_status["state"] == "closed":
|
||||
return True
|
||||
|
||||
if self.circuit_breaker_status["state"] == "open":
|
||||
# Check if recovery timeout has passed
|
||||
if self.circuit_breaker_status["last_failure_time"]:
|
||||
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
|
||||
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
|
||||
# Move to half-open state
|
||||
self.circuit_breaker_status["state"] = "half_open"
|
||||
logger.info("Circuit breaker moved to half-open state")
|
||||
return True
|
||||
return False
|
||||
|
||||
if self.circuit_breaker_status["state"] == "half_open":
|
||||
# Allow limited requests in half-open state
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _record_success(self):
|
||||
"""Record successful request for circuit breaker"""
|
||||
|
||||
if self.circuit_breaker_status["state"] == "half_open":
|
||||
# Success in half-open state closes the circuit
|
||||
self.circuit_breaker_status["state"] = "closed"
|
||||
self.circuit_breaker_status["failure_count"] = 0
|
||||
logger.info("Circuit breaker closed after successful request")
|
||||
|
||||
# Reset failure count on any success
|
||||
self.circuit_breaker_status["failure_count"] = 0
|
||||
|
||||
async def _record_failure(self):
|
||||
"""Record failed request for circuit breaker"""
|
||||
|
||||
self.circuit_breaker_status["failure_count"] += 1
|
||||
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
|
||||
|
||||
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
|
||||
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
|
||||
self.circuit_breaker_status["state"] = "open"
|
||||
logger.warning(f"Circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
|
||||
|
||||
async def _track_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
model: str,
|
||||
tokens: int,
|
||||
latency: float,
|
||||
cost_per_1k: float
|
||||
):
|
||||
"""Track usage metrics for billing and monitoring"""
|
||||
|
||||
# Create usage key
|
||||
usage_key = f"{tenant_id}:{user_id}:{model}"
|
||||
|
||||
# Initialize metrics if not exists
|
||||
if usage_key not in self.usage_metrics:
|
||||
self.usage_metrics[usage_key] = {
|
||||
"total_tokens": 0,
|
||||
"total_requests": 0,
|
||||
"total_cost_cents": 0,
|
||||
"average_latency": 0
|
||||
}
|
||||
|
||||
# Update metrics
|
||||
metrics = self.usage_metrics[usage_key]
|
||||
metrics["total_tokens"] += tokens
|
||||
metrics["total_requests"] += 1
|
||||
metrics["total_cost_cents"] += self._calculate_cost(tokens, cost_per_1k)
|
||||
|
||||
# Update average latency
|
||||
prev_avg = metrics["average_latency"]
|
||||
prev_count = metrics["total_requests"] - 1
|
||||
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
|
||||
|
||||
# Log high-level metrics
|
||||
if metrics["total_requests"] % 100 == 0:
|
||||
logger.info(f"Usage milestone for {usage_key}: {metrics}")
|
||||
|
||||
def _calculate_cost(self, tokens: int, cost_per_1k: float) -> int:
|
||||
"""Calculate cost in cents"""
|
||||
return int((tokens / 1000) * cost_per_1k * 100)
|
||||
|
||||
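    # Worked example (illustrative): _calculate_cost(1_500, 0.59)
    #   = int((1_500 / 1_000) * 0.59 * 100) = int(88.5) = 88 cents.
    # Note the truncation toward zero: sub-cent request costs round down to 0.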
def _calculate_compound_cost(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate detailed cost breakdown for Groq Compound responses.
|
||||
|
||||
Compound API returns usage_breakdown with per-model token counts
|
||||
and executed_tools list showing which tools were called.
|
||||
|
||||
Returns:
|
||||
Dict with total cost in dollars and detailed breakdown
|
||||
"""
|
||||
total_cost = 0.0
|
||||
breakdown = {"models": [], "tools": [], "total_cost_dollars": 0.0, "total_cost_cents": 0}
|
||||
|
||||
# Parse usage_breakdown for per-model token costs
|
||||
usage_breakdown = response_data.get("usage_breakdown", {})
|
||||
models_usage = usage_breakdown.get("models", [])
|
||||
|
||||
for model_usage in models_usage:
|
||||
model_name = model_usage.get("model", "")
|
||||
usage = model_usage.get("usage", {})
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
# Get model pricing (try multiple name formats)
|
||||
model_prices = GROQ_MODEL_PRICES.get(model_name)
|
||||
if not model_prices:
|
||||
# Try without provider prefix
|
||||
short_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
model_prices = GROQ_MODEL_PRICES.get(short_name, {"input": 0.15, "output": 0.60})
|
||||
|
||||
# Calculate cost per million tokens
|
||||
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
|
||||
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
|
||||
model_total = input_cost + output_cost
|
||||
|
||||
breakdown["models"].append({
|
||||
"model": model_name,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"input_cost_dollars": round(input_cost, 6),
|
||||
"output_cost_dollars": round(output_cost, 6),
|
||||
"total_cost_dollars": round(model_total, 6)
|
||||
})
|
||||
total_cost += model_total
|
||||
|
||||
# Parse executed_tools for tool costs
|
||||
executed_tools = response_data.get("executed_tools", [])
|
||||
|
||||
for tool in executed_tools:
|
||||
# Handle both string and dict formats
|
||||
tool_name = tool if isinstance(tool, str) else tool.get("name", "unknown")
|
||||
tool_cost = COMPOUND_TOOL_PRICES.get(tool_name.lower(), 0.008) # Default to advanced search
|
||||
|
||||
breakdown["tools"].append({
|
||||
"tool": tool_name,
|
||||
"cost_dollars": round(tool_cost, 6)
|
||||
})
|
||||
total_cost += tool_cost
|
||||
|
||||
breakdown["total_cost_dollars"] = round(total_cost, 6)
|
||||
breakdown["total_cost_cents"] = int(total_cost * 100)
|
||||
|
||||
return breakdown
|
||||
|
||||
def _is_compound_model(self, model: str) -> bool:
|
||||
"""Check if model is a Groq Compound model"""
|
||||
model_lower = model.lower()
|
||||
return "compound" in model_lower or model_lower.startswith("groq/compound")
|
||||
|
||||
async def get_available_models(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of available Groq models with their configurations"""
|
||||
models = []
|
||||
|
||||
model_configs = get_model_configs()
|
||||
for model_id, config in model_configs.get("groq", {}).items():
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"name": model_id.replace("-", " ").title(),
|
||||
"provider": "groq",
|
||||
"max_tokens": config["max_tokens"],
|
||||
"cost_per_1k_tokens": config["cost_per_1k_tokens"],
|
||||
"supports_streaming": config["supports_streaming"],
|
||||
"supports_function_calling": config["supports_function_calling"]
|
||||
})
|
||||
|
||||
return models
|
||||
|
||||
async def execute_inference_with_messages(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str = "llama-3.1-70b-versatile",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4000,
|
||||
stream: bool = False,
|
||||
user_id: str = None,
|
||||
tenant_id: str = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute LLM inference using messages format (conversation style)"""
|
||||
|
||||
# Check circuit breaker
|
||||
if not await self._is_circuit_closed():
|
||||
raise Exception("Circuit breaker is open - service temporarily unavailable")
|
||||
|
||||
# Validate model and get configuration
|
||||
model_configs = get_model_configs(tenant_id)
|
||||
model_config = model_configs.get("groq", {}).get(model)
|
||||
if not model_config:
|
||||
# Try to get from model service registry
|
||||
model_service = get_model_service(tenant_id)
|
||||
model_info = await model_service.get_model(model)
|
||||
if not model_info:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
model_config = {
|
||||
"max_tokens": model_info["performance"]["max_tokens"],
|
||||
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
|
||||
"supports_streaming": model_info["capabilities"].get("streaming", False)
|
||||
}
|
||||
|
||||
# Apply token limits
|
||||
max_tokens = min(max_tokens, model_config["max_tokens"])
|
||||
|
||||
try:
|
||||
# Get tenant-specific API key
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required for Groq inference")
|
||||
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
client = self._get_client(api_key)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Translate GT 2.0 "agent" role to OpenAI/Groq "assistant" for external API compatibility
|
||||
# Use dictionary unpacking to preserve ALL fields including tool_call_id
|
||||
external_messages = []
|
||||
for msg in messages:
|
||||
external_msg = {
|
||||
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
|
||||
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
|
||||
}
|
||||
external_messages.append(external_msg)
|
||||
|
||||
if stream:
|
||||
return await self._stream_inference_with_messages(
|
||||
external_messages, model, temperature, max_tokens, user_id, tenant_id, client
|
||||
)
|
||||
else:
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": external_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
if tool_choice:
|
||||
request_params["tool_choice"] = tool_choice
|
||||
|
||||
# Debug: Log messages being sent to Groq
|
||||
logger.info(f"🔧 Sending {len(external_messages)} messages to Groq API")
|
||||
for i, msg in enumerate(external_messages):
|
||||
if msg.get("role") == "tool":
|
||||
logger.info(f"🔧 Groq Message {i}: role=tool, tool_call_id={msg.get('tool_call_id')}")
|
||||
else:
|
||||
logger.info(f"🔧 Groq Message {i}: role={msg.get('role')}, has_tool_calls={bool(msg.get('tool_calls'))}")
|
||||
|
||||
response = await client.chat.completions.create(**request_params)
|
||||
|
||||
# Track successful usage
|
||||
latency = (time.time() - start_time) * 1000
|
||||
await self._track_usage(
|
||||
user_id, tenant_id, model,
|
||||
response.usage.total_tokens if response.usage else 0,
|
||||
latency, model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
|
||||
# Track in model service
|
||||
model_service = get_model_service(tenant_id)
|
||||
await model_service.track_model_usage(
|
||||
model_id=model,
|
||||
success=True,
|
||||
latency_ms=latency
|
||||
)
|
||||
|
||||
# Reset circuit breaker on success
|
||||
await self._record_success()
|
||||
|
||||
# Build base response
|
||||
result = {
|
||||
"content": response.choices[0].message.content,
|
||||
"model": model,
|
||||
"usage": {
|
||||
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
"total_tokens": response.usage.total_tokens if response.usage else 0,
|
||||
"cost_cents": self._calculate_cost(
|
||||
response.usage.total_tokens if response.usage else 0,
|
||||
model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
},
|
||||
"latency_ms": latency,
|
||||
"load_balanced": True,
|
||||
"haproxy_backend": "groq_general_backend"
|
||||
}
|
||||
|
||||
# For Compound models, extract and calculate detailed cost breakdown
|
||||
if self._is_compound_model(model):
|
||||
# Convert response to dict for processing
|
||||
response_dict = response.model_dump() if hasattr(response, 'model_dump') else {}
|
||||
|
||||
# Extract usage_breakdown and executed_tools if present
|
||||
usage_breakdown = getattr(response, 'usage_breakdown', None)
|
||||
executed_tools = getattr(response, 'executed_tools', None)
|
||||
|
||||
if usage_breakdown or executed_tools:
|
||||
compound_data = {
|
||||
"usage_breakdown": usage_breakdown if isinstance(usage_breakdown, dict) else {},
|
||||
"executed_tools": executed_tools if isinstance(executed_tools, list) else []
|
||||
}
|
||||
|
||||
# Calculate detailed cost breakdown
|
||||
cost_breakdown = self._calculate_compound_cost(compound_data)
|
||||
|
||||
# Add compound-specific data to response
|
||||
result["usage_breakdown"] = compound_data.get("usage_breakdown", {})
|
||||
result["executed_tools"] = compound_data.get("executed_tools", [])
|
||||
result["cost_breakdown"] = cost_breakdown
|
||||
|
||||
# Update cost_cents with accurate compound calculation
|
||||
if cost_breakdown["total_cost_cents"] > 0:
|
||||
result["usage"]["cost_cents"] = cost_breakdown["total_cost_cents"]
|
||||
|
||||
logger.info(f"Compound model cost breakdown: {cost_breakdown}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HAProxy Groq inference with messages failed: {e}")
|
||||
|
||||
# Track failure in model service (resolve it here: the service may not have
# been created yet if the failure happened before the success path above)
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=False
)
|
||||
|
||||
# Record failure for circuit breaker
|
||||
await self._record_failure()
|
||||
|
||||
# Re-raise the exception
|
||||
raise Exception(f"Groq inference with messages failed (via HAProxy): {str(e)}")
|
||||
|
||||
async def _stream_inference_with_messages(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
client: AsyncGroq = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream LLM inference responses using messages format"""
|
||||
|
||||
model_configs = get_model_configs(tenant_id)
|
||||
model_config = model_configs.get("groq", {}).get(model)
|
||||
start_time = time.time()
|
||||
total_tokens = 0
|
||||
|
||||
try:
|
||||
# Use provided client or get tenant-specific client
|
||||
if not client:
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
client = self._get_client(api_key)
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
total_tokens += len(content.split()) # Approximate token count
|
||||
|
||||
# Yield just the content (SSE formatting handled by caller)
|
||||
yield content
|
||||
|
||||
# Track usage after streaming completes
|
||||
latency = (time.time() - start_time) * 1000
|
||||
await self._track_usage(
|
||||
user_id, tenant_id, model,
|
||||
total_tokens, latency,
|
||||
model_config["cost_per_1k_tokens"] if model_config else 0.0
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming inference with messages error: {e}")
|
||||
raise e
|
||||
|
||||
async def _get_tenant_api_key(self, tenant_id: str) -> str:
|
||||
"""
|
||||
Get API key for tenant from Control Panel database.
|
||||
|
||||
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
|
||||
API keys are managed in Control Panel and fetched via internal API.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant domain string from X-Tenant-ID header
|
||||
|
||||
Returns:
|
||||
Decrypted Groq API key
|
||||
|
||||
Raises:
|
||||
ValueError: If no API key configured (results in HTTP 503 to client)
|
||||
"""
|
||||
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
|
||||
|
||||
client = get_api_key_client()
|
||||
|
||||
try:
|
||||
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
|
||||
return key_info["api_key"]
|
||||
except APIKeyNotConfiguredError as e:
|
||||
logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
|
||||
raise ValueError(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Control Panel error: {e}")
|
||||
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
|
||||
|
||||
def _get_client(self, api_key: str) -> AsyncGroq:
|
||||
"""Get Groq client with specified API key"""
|
||||
if not GROQ_AVAILABLE:
|
||||
raise Exception("Groq client not available in development mode")
|
||||
|
||||
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
|
||||
|
||||
return AsyncGroq(
|
||||
api_key=api_key,
|
||||
base_url=haproxy_endpoint,
|
||||
timeout=httpx.Timeout(30.0),
|
||||
max_retries=1
|
||||
)
|
||||
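# Rough caller sketch for the proxy above. The backend instance, tenant domain,
# user ID and prompt text are placeholders, not values shipped in this release.
async def groq_proxy_demo(backend) -> dict:
    messages = [
        {"role": "user", "content": "Summarize the incident report."},
        # GT 2.0 "agent" role is rewritten to "assistant" before reaching Groq
        {"role": "agent", "content": "Which format do you prefer?"},
        {"role": "user", "content": "Bullet points, please."},
    ]
    result = await backend.execute_inference_with_messages(
        messages=messages,
        model="llama-3.1-8b-instant",
        max_tokens=512,
        tenant_id="acme.example.com",  # tenant domain from X-Tenant-ID
        user_id="user-123",
    )
    return result  # contains content, usage (incl. cost_cents) and latency_ms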
407
apps/resource-cluster/app/core/backends/nvidia_proxy.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
NVIDIA NIM LLM Proxy Backend
|
||||
|
||||
Provides LLM inference through NVIDIA NIM with:
|
||||
- OpenAI-compatible API format (build.nvidia.com)
|
||||
- Token usage tracking and cost calculation
|
||||
- Streaming response support
|
||||
- Circuit breaker pattern for enhanced reliability
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NVIDIA NIM Model pricing per million tokens (input/output)
|
||||
# Source: build.nvidia.com (Dec 2025 pricing estimates)
|
||||
# Note: Actual pricing may vary - check build.nvidia.com for current rates
|
||||
NVIDIA_MODEL_PRICES = {
|
||||
# Llama Nemotron family
|
||||
"nvidia/llama-3.1-nemotron-ultra-253b-v1": {"input": 2.0, "output": 6.0},
|
||||
"nvidia/llama-3.1-nemotron-super-49b-v1": {"input": 0.5, "output": 1.5},
|
||||
"nvidia/llama-3.1-nemotron-nano-8b-v1": {"input": 0.1, "output": 0.3},
|
||||
# Standard Llama models via NIM
|
||||
"meta/llama-3.1-8b-instruct": {"input": 0.1, "output": 0.3},
|
||||
"meta/llama-3.1-70b-instruct": {"input": 0.5, "output": 1.0},
|
||||
"meta/llama-3.1-405b-instruct": {"input": 2.0, "output": 6.0},
|
||||
# Mistral models
|
||||
"mistralai/mistral-7b-instruct-v0.3": {"input": 0.1, "output": 0.2},
|
||||
"mistralai/mixtral-8x7b-instruct-v0.1": {"input": 0.3, "output": 0.6},
|
||||
# Default fallback
|
||||
"default": {"input": 0.5, "output": 1.5},
|
||||
}
|
||||
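# Worked example of the cost arithmetic used below (prices are per 1M tokens).
_prices = NVIDIA_MODEL_PRICES["nvidia/llama-3.1-nemotron-super-49b-v1"]
_input_cost = (10_000 / 1_000_000) * _prices["input"]     # 0.005 USD for 10k prompt tokens
_output_cost = (2_000 / 1_000_000) * _prices["output"]    # 0.003 USD for 2k completion tokens
_cost_cents = int((_input_cost + _output_cost) * 100)     # 0 - int() truncates sub-cent totals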
|
||||
|
||||
class NvidiaProxyBackend:
|
||||
"""LLM inference via NVIDIA NIM with OpenAI-compatible API"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.base_url = getattr(self.settings, 'nvidia_nim_endpoint', None) or "https://integrate.api.nvidia.com/v1"
|
||||
self.usage_metrics = {}
|
||||
self.circuit_breaker_status = {
|
||||
"state": "closed", # closed, open, half_open
|
||||
"failure_count": 0,
|
||||
"last_failure_time": None,
|
||||
"failure_threshold": 5,
|
||||
"recovery_timeout": 60 # seconds
|
||||
}
|
||||
logger.info(f"Initialized NVIDIA NIM backend with endpoint: {self.base_url}")
|
||||
|
||||
async def _get_tenant_api_key(self, tenant_id: str) -> str:
|
||||
"""
|
||||
Get API key for tenant from Control Panel database.
|
||||
|
||||
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
|
||||
API keys are managed in Control Panel and fetched via internal API.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant domain string from X-Tenant-ID header
|
||||
|
||||
Returns:
|
||||
Decrypted NVIDIA API key
|
||||
|
||||
Raises:
|
||||
ValueError: If no API key configured (results in HTTP 503 to client)
|
||||
"""
|
||||
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
|
||||
|
||||
client = get_api_key_client()
|
||||
|
||||
try:
|
||||
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
|
||||
return key_info["api_key"]
|
||||
except APIKeyNotConfiguredError as e:
|
||||
logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
|
||||
raise ValueError(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Control Panel error: {e}")
|
||||
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
|
||||
|
||||
def _get_client(self, api_key: str) -> httpx.AsyncClient:
|
||||
"""Get configured HTTP client for NVIDIA NIM API"""
|
||||
return httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
timeout=httpx.Timeout(120.0) # Longer timeout for large models
|
||||
)
|
||||
|
||||
async def execute_inference(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4000,
|
||||
stream: bool = False,
|
||||
user_id: str = None,
|
||||
tenant_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute LLM inference with simple prompt"""
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
return await self.execute_inference_with_messages(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=stream,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
async def execute_inference_with_messages(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4000,
|
||||
stream: bool = False,
|
||||
user_id: str = None,
|
||||
tenant_id: str = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute LLM inference using messages format (conversation style)"""
|
||||
|
||||
# Check circuit breaker
|
||||
if not await self._is_circuit_closed():
|
||||
raise Exception("Circuit breaker is open - NVIDIA NIM service temporarily unavailable")
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required for NVIDIA NIM inference")
|
||||
|
||||
try:
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
|
||||
# Translate GT 2.0 "agent" role to OpenAI "assistant" for external API compatibility
|
||||
external_messages = []
|
||||
for msg in messages:
|
||||
external_msg = {
|
||||
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
|
||||
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
|
||||
}
|
||||
external_messages.append(external_msg)
|
||||
|
||||
# Build request payload
|
||||
request_data = {
|
||||
"model": model,
|
||||
"messages": external_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
request_data["tools"] = tools
|
||||
if tool_choice:
|
||||
request_data["tool_choice"] = tool_choice
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if stream:
# Streaming: hand the generator its own client. Returning from inside the
# async-with block would close the client before the caller consumes the stream.
return self._stream_inference_with_messages(
self._get_client(api_key), request_data, user_id, tenant_id, model
)

async with self._get_client(api_key) as client:
|
||||
|
||||
# Non-streaming request
|
||||
response = await client.post("/chat/completions", json=request_data)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
latency = (time.time() - start_time) * 1000
|
||||
|
||||
# Calculate cost
|
||||
usage = data.get("usage", {})
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)
|
||||
|
||||
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
|
||||
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
|
||||
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
|
||||
cost_cents = int((input_cost + output_cost) * 100)
|
||||
|
||||
# Track usage
|
||||
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
|
||||
|
||||
# Reset circuit breaker on success
|
||||
await self._record_success()
|
||||
|
||||
# Build response
|
||||
result = {
|
||||
"content": data["choices"][0]["message"]["content"],
|
||||
"model": model,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"cost_cents": cost_cents
|
||||
},
|
||||
"latency_ms": latency,
|
||||
"provider": "nvidia"
|
||||
}
|
||||
|
||||
# Include tool calls if present
|
||||
message = data["choices"][0]["message"]
|
||||
if message.get("tool_calls"):
|
||||
result["tool_calls"] = message["tool_calls"]
|
||||
|
||||
return result
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"NVIDIA NIM API error: {e.response.status_code} - {e.response.text}")
|
||||
await self._record_failure()
|
||||
raise Exception(f"NVIDIA NIM inference failed: HTTP {e.response.status_code}")
|
||||
except Exception as e:
|
||||
logger.error(f"NVIDIA NIM inference failed: {e}")
|
||||
await self._record_failure()
|
||||
raise Exception(f"NVIDIA NIM inference failed: {str(e)}")
|
||||
|
||||
async def _stream_inference_with_messages(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
request_data: Dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
model: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream LLM inference responses"""
|
||||
|
||||
start_time = time.time()
|
||||
total_tokens = 0
|
||||
|
||||
try:
|
||||
async with client.stream("POST", "/chat/completions", json=request_data) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:] # Remove "data: " prefix
|
||||
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
if chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("content"):
|
||||
content = chunk["choices"][0]["delta"]["content"]
|
||||
total_tokens += len(content.split()) # Approximate
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Track usage after streaming completes
|
||||
latency = (time.time() - start_time) * 1000
|
||||
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
|
||||
cost_cents = int((total_tokens / 1_000_000) * model_prices["output"] * 100)
|
||||
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
|
||||
|
||||
await self._record_success()
|
||||
|
||||
except Exception as e:
logger.error(f"NVIDIA NIM streaming error: {e}")
await self._record_failure()
raise e
finally:
# Close the dedicated client created for this stream by the caller
await client.aclose()
|
||||
|
||||
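# Caller-side sketch (shown out of line for reference): re-wrapping the raw
# chunks yielded above as Server-Sent Events. The route shape and request_data
# contents (messages, model, tenant_id, ...) are placeholders.
from fastapi.responses import StreamingResponse

async def nim_stream_route(backend: "NvidiaProxyBackend", request_data: dict) -> StreamingResponse:
    async def sse():
        generator = await backend.execute_inference_with_messages(stream=True, **request_data)
        async for chunk in generator:
            yield f"data: {chunk}\n\n"
        yield "data: [DONE]\n\n"
    return StreamingResponse(sse(), media_type="text/event-stream")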
async def check_health(self) -> Dict[str, Any]:
|
||||
"""Check health of NVIDIA NIM backend and circuit breaker status"""
|
||||
|
||||
return {
|
||||
"nvidia_nim": {
|
||||
"endpoint": self.base_url,
|
||||
"status": "available" if self.circuit_breaker_status["state"] == "closed" else "degraded",
|
||||
"last_check": datetime.utcnow().isoformat()
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": self.circuit_breaker_status["state"],
|
||||
"failure_count": self.circuit_breaker_status["failure_count"],
|
||||
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat()
|
||||
if self.circuit_breaker_status["last_failure_time"] else None
|
||||
}
|
||||
}
|
||||
|
||||
async def _is_circuit_closed(self) -> bool:
|
||||
"""Check if circuit breaker allows requests"""
|
||||
|
||||
if self.circuit_breaker_status["state"] == "closed":
|
||||
return True
|
||||
|
||||
if self.circuit_breaker_status["state"] == "open":
|
||||
# Check if recovery timeout has passed
|
||||
if self.circuit_breaker_status["last_failure_time"]:
|
||||
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
|
||||
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
|
||||
# Move to half-open state
|
||||
self.circuit_breaker_status["state"] = "half_open"
|
||||
logger.info("NVIDIA NIM circuit breaker moved to half-open state")
|
||||
return True
|
||||
return False
|
||||
|
||||
if self.circuit_breaker_status["state"] == "half_open":
|
||||
# Allow limited requests in half-open state
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _record_success(self):
|
||||
"""Record successful request for circuit breaker"""
|
||||
|
||||
if self.circuit_breaker_status["state"] == "half_open":
|
||||
# Success in half-open state closes the circuit
|
||||
self.circuit_breaker_status["state"] = "closed"
|
||||
self.circuit_breaker_status["failure_count"] = 0
|
||||
logger.info("NVIDIA NIM circuit breaker closed after successful request")
|
||||
|
||||
# Reset failure count on any success
|
||||
self.circuit_breaker_status["failure_count"] = 0
|
||||
|
||||
async def _record_failure(self):
|
||||
"""Record failed request for circuit breaker"""
|
||||
|
||||
self.circuit_breaker_status["failure_count"] += 1
|
||||
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
|
||||
|
||||
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
|
||||
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
|
||||
self.circuit_breaker_status["state"] = "open"
|
||||
logger.warning(f"NVIDIA NIM circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
|
||||
|
||||
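# Standalone walk-through of the breaker defaults above (threshold 5, 60s
# recovery). Demo only - it drives the private helpers directly.
async def breaker_demo(backend: "NvidiaProxyBackend") -> None:
    for _ in range(5):
        await backend._record_failure()          # 5th failure opens the circuit
    assert not await backend._is_circuit_closed()
    # After ~60s the next check moves the breaker to "half_open"; one successful
    # request then closes it again via _record_success().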
async def _track_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
model: str,
|
||||
tokens: int,
|
||||
latency: float,
|
||||
cost_cents: int
|
||||
):
|
||||
"""Track usage metrics for billing and monitoring"""
|
||||
|
||||
# Create usage key
|
||||
usage_key = f"{tenant_id}:{user_id}:{model}"
|
||||
|
||||
# Initialize metrics if not exists
|
||||
if usage_key not in self.usage_metrics:
|
||||
self.usage_metrics[usage_key] = {
|
||||
"total_tokens": 0,
|
||||
"total_requests": 0,
|
||||
"total_cost_cents": 0,
|
||||
"average_latency": 0
|
||||
}
|
||||
|
||||
# Update metrics
|
||||
metrics = self.usage_metrics[usage_key]
|
||||
metrics["total_tokens"] += tokens
|
||||
metrics["total_requests"] += 1
|
||||
metrics["total_cost_cents"] += cost_cents
|
||||
|
||||
# Update average latency
|
||||
prev_avg = metrics["average_latency"]
|
||||
prev_count = metrics["total_requests"] - 1
|
||||
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
|
||||
|
||||
# Log high-level metrics periodically
|
||||
if metrics["total_requests"] % 100 == 0:
|
||||
logger.info(f"NVIDIA NIM usage milestone for {usage_key}: {metrics}")
|
||||
|
||||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int, model: str) -> int:
|
||||
"""Calculate cost in cents based on token usage"""
|
||||
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
|
||||
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
|
||||
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
|
||||
return int((input_cost + output_cost) * 100)
|
||||
|
||||
async def get_available_models(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of available NVIDIA NIM models with their configurations"""
|
||||
models = []
|
||||
|
||||
for model_id, prices in NVIDIA_MODEL_PRICES.items():
|
||||
if model_id == "default":
|
||||
continue
|
||||
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"name": model_id.split("/")[-1].replace("-", " ").title(),
|
||||
"provider": "nvidia",
|
||||
"max_tokens": 4096, # Default for most NIM models
|
||||
"cost_per_1k_input": prices["input"],
|
||||
"cost_per_1k_output": prices["output"],
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": True
|
||||
})
|
||||
|
||||
return models
|
||||
457
apps/resource-cluster/app/core/capability_auth.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
Capability-Based Authentication for GT 2.0 Resource Cluster
|
||||
|
||||
Implements JWT capability token verification with:
|
||||
- Cryptographic signature validation
|
||||
- Fine-grained resource permissions
|
||||
- Rate limiting and constraints enforcement
|
||||
- Tenant isolation validation
|
||||
- Zero external dependencies
|
||||
|
||||
GT 2.0 Security Principles:
|
||||
- Self-contained: No external auth services
|
||||
- Stateless: All permissions in JWT token
|
||||
- Cryptographic: RSA signature verification
|
||||
- Isolated: Perfect tenant separation
|
||||
"""
|
||||
|
||||
import jwt
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import HTTPException, Depends, Header
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class CapabilityError(Exception):
|
||||
"""Capability authentication error"""
|
||||
pass
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
|
||||
"""Resource types in GT 2.0"""
|
||||
LLM = "llm"
|
||||
EMBEDDING = "embedding"
|
||||
VECTOR_STORAGE = "vector_storage"
|
||||
EXTERNAL_SERVICES = "external_services"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
"""Action types for resources"""
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXECUTE = "execute"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Capability:
|
||||
"""Individual capability definition"""
|
||||
resource: ResourceType
|
||||
actions: List[ActionType]
|
||||
constraints: Dict[str, Any]
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
def allows_action(self, action: ActionType) -> bool:
|
||||
"""Check if capability allows specific action"""
|
||||
return action in self.actions
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if capability is expired"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
def check_constraint(self, constraint_name: str, value: Any) -> bool:
|
||||
"""Check if value satisfies constraint"""
|
||||
if constraint_name not in self.constraints:
|
||||
return True # No constraint means allowed
|
||||
|
||||
constraint_value = self.constraints[constraint_name]
|
||||
|
||||
if constraint_name == "max_tokens":
|
||||
return value <= constraint_value
|
||||
elif constraint_name == "allowed_models":
|
||||
return value in constraint_value
|
||||
elif constraint_name == "max_requests_per_hour":
|
||||
# This would be checked separately with rate limiting
|
||||
return True
|
||||
elif constraint_name == "allowed_tenants":
|
||||
return value in constraint_value
|
||||
|
||||
return True
|
||||
|
||||
|
||||
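# Illustrative use of the dataclass above; model names and limits are examples.
_example_cap = Capability(
    resource=ResourceType.LLM,
    actions=[ActionType.EXECUTE],
    constraints={"max_tokens": 4000, "allowed_models": ["llama-3.1-8b-instant"]},
)
assert _example_cap.allows_action(ActionType.EXECUTE)
assert _example_cap.check_constraint("max_tokens", 2048)
assert not _example_cap.check_constraint("allowed_models", "gpt-4-turbo")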
@dataclass
|
||||
class CapabilityToken:
|
||||
"""Parsed capability token"""
|
||||
subject: str
|
||||
tenant_id: str
|
||||
capabilities: List[Capability]
|
||||
issued_at: datetime
|
||||
expires_at: datetime
|
||||
issuer: str
|
||||
token_version: str
|
||||
|
||||
def has_capability(self, resource: ResourceType, action: ActionType) -> bool:
|
||||
"""Check if token has specific capability"""
|
||||
for cap in self.capabilities:
|
||||
if cap.resource == resource and cap.allows_action(action) and not cap.is_expired():
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_capability(self, resource: ResourceType) -> Optional[Capability]:
|
||||
"""Get capability for specific resource"""
|
||||
for cap in self.capabilities:
|
||||
if cap.resource == resource and not cap.is_expired():
|
||||
return cap
|
||||
return None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if entire token is expired"""
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
|
||||
class CapabilityAuthenticator:
|
||||
"""
|
||||
Handles capability token verification and authorization.
|
||||
|
||||
Uses JWT tokens with embedded permissions for stateless authentication.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
|
||||
# In production, this would be loaded from secure storage
|
||||
# For development, using the secret key
|
||||
self.secret_key = self.settings.secret_key
|
||||
self.algorithm = "HS256" # TODO: Upgrade to RS256 with public/private keys
|
||||
|
||||
logger.info("Capability authenticator initialized")
|
||||
|
||||
async def verify_token(self, token: str) -> CapabilityToken:
|
||||
"""
|
||||
Verify and parse capability token.
|
||||
|
||||
Args:
|
||||
token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Parsed capability token
|
||||
|
||||
Raises:
|
||||
CapabilityError: If token is invalid or expired
|
||||
"""
|
||||
try:
|
||||
# Decode JWT token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.secret_key,
|
||||
algorithms=[self.algorithm],
|
||||
audience="gt2-resource-cluster"
|
||||
)
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ["sub", "tenant_id", "capabilities", "iat", "exp", "iss"]
|
||||
for field in required_fields:
|
||||
if field not in payload:
|
||||
raise CapabilityError(f"Missing required field: {field}")
|
||||
|
||||
# Parse timestamps
|
||||
issued_at = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
|
||||
expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
|
||||
# Check token expiration
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise CapabilityError("Token has expired")
|
||||
|
||||
# Parse capabilities
|
||||
capabilities = []
|
||||
for cap_data in payload["capabilities"]:
|
||||
try:
|
||||
capability = Capability(
|
||||
resource=ResourceType(cap_data["resource"]),
|
||||
actions=[ActionType(action) for action in cap_data["actions"]],
|
||||
constraints=cap_data.get("constraints", {}),
|
||||
expires_at=datetime.fromtimestamp(
|
||||
cap_data["expires_at"], tz=timezone.utc
|
||||
) if cap_data.get("expires_at") else None
|
||||
)
|
||||
capabilities.append(capability)
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.warning(f"Invalid capability in token: {e}")
|
||||
# Skip invalid capabilities rather than rejecting entire token
|
||||
continue
|
||||
|
||||
# Create capability token
|
||||
capability_token = CapabilityToken(
|
||||
subject=payload["sub"],
|
||||
tenant_id=payload["tenant_id"],
|
||||
capabilities=capabilities,
|
||||
issued_at=issued_at,
|
||||
expires_at=expires_at,
|
||||
issuer=payload["iss"],
|
||||
token_version=payload.get("token_version", "1.0")
|
||||
)
|
||||
|
||||
logger.debug(f"Capability token verified for {capability_token.subject}")
|
||||
return capability_token
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise CapabilityError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise CapabilityError(f"Invalid token: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification failed: {e}")
|
||||
raise CapabilityError(f"Token verification failed: {e}")
|
||||
|
||||
async def check_resource_access(
|
||||
self,
|
||||
capability_token: CapabilityToken,
|
||||
resource: ResourceType,
|
||||
action: ActionType,
|
||||
constraints: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if token allows access to resource with specific action.
|
||||
|
||||
Args:
|
||||
capability_token: Verified capability token
|
||||
resource: Resource type to access
|
||||
action: Action to perform
|
||||
constraints: Additional constraints to check
|
||||
|
||||
Returns:
|
||||
True if access is allowed
|
||||
|
||||
Raises:
|
||||
CapabilityError: If access is denied
|
||||
"""
|
||||
try:
|
||||
# Check token expiration
|
||||
if capability_token.is_expired():
|
||||
raise CapabilityError("Token has expired")
|
||||
|
||||
# Find matching capability
|
||||
capability = capability_token.get_capability(resource)
|
||||
if not capability:
|
||||
raise CapabilityError(f"No capability for resource: {resource}")
|
||||
|
||||
# Check action permission
|
||||
if not capability.allows_action(action):
|
||||
raise CapabilityError(f"Action {action} not allowed for resource {resource}")
|
||||
|
||||
# Check constraints if provided
|
||||
if constraints:
|
||||
for constraint_name, value in constraints.items():
|
||||
if not capability.check_constraint(constraint_name, value):
|
||||
raise CapabilityError(
|
||||
f"Constraint violation: {constraint_name} = {value}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except CapabilityError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Resource access check failed: {e}")
|
||||
raise CapabilityError(f"Access check failed: {e}")
|
||||
|
||||
|
||||
# Global authenticator instance
|
||||
capability_authenticator = CapabilityAuthenticator()
|
||||
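# Dev-only sketch of minting a token that verify_token() accepts. Assumes the
# shared HS256 secret; the issuer string and capability values are examples.
import time as _time

def mint_dev_capability_token(secret: str, tenant_id: str) -> str:
    now = int(_time.time())
    payload = {
        "sub": "user-123",
        "tenant_id": tenant_id,
        "iss": "gt2-admin-cluster",       # example issuer
        "aud": "gt2-resource-cluster",    # audience enforced by jwt.decode()
        "iat": now,
        "exp": now + 3600,
        "capabilities": [
            {"resource": "llm", "actions": ["execute"], "constraints": {"max_tokens": 4000}}
        ],
    }
    return jwt.encode(payload, secret, algorithm="HS256")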
|
||||
|
||||
async def verify_capability_token(token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify capability token and return payload.
|
||||
|
||||
Args:
|
||||
token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Token payload as dictionary
|
||||
|
||||
Raises:
|
||||
CapabilityError: If token is invalid
|
||||
"""
|
||||
capability_token = await capability_authenticator.verify_token(token)
|
||||
|
||||
return {
|
||||
"sub": capability_token.subject,
|
||||
"tenant_id": capability_token.tenant_id,
|
||||
"capabilities": [
|
||||
{
|
||||
"resource": cap.resource.value,
|
||||
"actions": [action.value for action in cap.actions],
|
||||
"constraints": cap.constraints
|
||||
}
|
||||
for cap in capability_token.capabilities
|
||||
],
|
||||
"iat": capability_token.issued_at.timestamp(),
|
||||
"exp": capability_token.expires_at.timestamp(),
|
||||
"iss": capability_token.issuer,
|
||||
"token_version": capability_token.token_version
|
||||
}
|
||||
|
||||
|
||||
async def get_current_capability(
|
||||
authorization: str = Header(..., description="Bearer token")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI dependency to get current capability from Authorization header.
|
||||
|
||||
Args:
|
||||
authorization: Authorization header with Bearer token
|
||||
|
||||
Returns:
|
||||
Capability payload
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
try:
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid authorization header format"
|
||||
)
|
||||
|
||||
token = authorization[7:] # Remove "Bearer " prefix
|
||||
payload = await verify_capability_token(token)
|
||||
|
||||
return payload
|
||||
|
||||
except CapabilityError as e:
|
||||
logger.warning(f"Capability authentication failed: {e}")
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Authentication error")
|
||||
|
||||
|
||||
def require_capability(  # dependency factory: must return a callable, not a coroutine
|
||||
resource: ResourceType,
|
||||
action: ActionType,
|
||||
constraints: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
FastAPI dependency to require specific capability.
|
||||
|
||||
Args:
|
||||
resource: Required resource type
|
||||
action: Required action type
|
||||
constraints: Additional constraints to check
|
||||
|
||||
Returns:
|
||||
Dependency function
|
||||
"""
|
||||
async def _check_capability(
|
||||
capability_payload: Dict[str, Any] = Depends(get_current_capability)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
# Reconstruct capability token from payload
|
||||
capabilities = []
|
||||
for cap_data in capability_payload["capabilities"]:
|
||||
capability = Capability(
|
||||
resource=ResourceType(cap_data["resource"]),
|
||||
actions=[ActionType(action) for action in cap_data["actions"]],
|
||||
constraints=cap_data["constraints"]
|
||||
)
|
||||
capabilities.append(capability)
|
||||
|
||||
capability_token = CapabilityToken(
|
||||
subject=capability_payload["sub"],
|
||||
tenant_id=capability_payload["tenant_id"],
|
||||
capabilities=capabilities,
|
||||
issued_at=datetime.fromtimestamp(capability_payload["iat"], tz=timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(capability_payload["exp"], tz=timezone.utc),
|
||||
issuer=capability_payload["iss"],
|
||||
token_version=capability_payload["token_version"]
|
||||
)
|
||||
|
||||
# Check required capability
|
||||
await capability_authenticator.check_resource_access(
|
||||
capability_token=capability_token,
|
||||
resource=resource,
|
||||
action=action,
|
||||
constraints=constraints
|
||||
)
|
||||
|
||||
return capability_payload
|
||||
|
||||
except CapabilityError as e:
|
||||
logger.warning(f"Capability check failed: {e}")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Capability check error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Authorization error")
|
||||
|
||||
return _check_capability
|
||||
|
||||
|
||||
# Convenience functions for common capability checks
|
||||
|
||||
async def require_llm_capability(
|
||||
capability_payload: Dict[str, Any] = Depends(
|
||||
require_capability(ResourceType.LLM, ActionType.EXECUTE)
|
||||
)
|
||||
) -> Dict[str, Any]:
|
||||
"""Require LLM execution capability"""
|
||||
return capability_payload
|
||||
|
||||
|
||||
async def require_embedding_capability(
|
||||
capability_payload: Dict[str, Any] = Depends(
|
||||
require_capability(ResourceType.EMBEDDING, ActionType.EXECUTE)
|
||||
)
|
||||
) -> Dict[str, Any]:
|
||||
"""Require embedding generation capability"""
|
||||
return capability_payload
|
||||
|
||||
|
||||
async def require_admin_capability(
|
||||
capability_payload: Dict[str, Any] = Depends(
|
||||
require_capability(ResourceType.ADMIN, ActionType.ADMIN)
|
||||
)
|
||||
) -> Dict[str, Any]:
|
||||
"""Require admin capability"""
|
||||
return capability_payload
|
||||
|
||||
|
||||
async def verify_capability_token_dependency(
|
||||
authorization: str = Header(..., description="Bearer token")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI dependency for ChromaDB MCP API that verifies capability token.
|
||||
|
||||
Returns token payload with raw_token field for service layer use.
|
||||
"""
|
||||
try:
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid authorization header format"
|
||||
)
|
||||
|
||||
token = authorization[7:] # Remove "Bearer " prefix
|
||||
payload = await verify_capability_token(token)
|
||||
|
||||
# Add raw token for service layer
|
||||
payload["raw_token"] = token
|
||||
|
||||
return payload
|
||||
|
||||
except CapabilityError as e:
|
||||
logger.warning(f"Capability authentication failed: {e}")
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Authentication error")
|
||||
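# Hypothetical endpoint wiring for the dependencies above; route path and
# response body are placeholders.
from fastapi import APIRouter

router = APIRouter()

@router.post("/api/v1/inference")
async def run_inference(capability: Dict[str, Any] = Depends(require_llm_capability)):
    # `capability` carries sub / tenant_id / capabilities from the verified token;
    # a 401/403 is raised before this body runs if the token is missing or insufficient.
    return {"tenant_id": capability["tenant_id"], "status": "authorized"}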
293
apps/resource-cluster/app/core/config.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
GT 2.0 Resource Cluster Configuration
|
||||
|
||||
Central configuration for the air-gapped Resource Cluster that manages
|
||||
all AI resources, document processing, and external service integrations.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field, validator
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Resource Cluster settings with environment variable support"""
|
||||
|
||||
# Environment
|
||||
environment: str = Field(default="development", description="Runtime environment")
|
||||
debug: bool = Field(default=False, description="Debug mode")
|
||||
|
||||
# Service Identity
|
||||
cluster_name: str = Field(default="gt-resource-cluster", description="Cluster identifier")
|
||||
service_port: int = Field(default=8003, description="Service port")
|
||||
|
||||
# Security
|
||||
secret_key: str = Field(..., description="JWT signing key for capability tokens")
|
||||
algorithm: str = Field(default="HS256", description="JWT algorithm")
|
||||
capability_token_expire_minutes: int = Field(default=60, description="Capability token expiry")
|
||||
|
||||
# External LLM Providers (via HAProxy)
|
||||
groq_api_key: Optional[str] = Field(default=None, description="Groq Cloud API key")
|
||||
groq_endpoints: List[str] = Field(
|
||||
default=["https://api.groq.com/openai/v1"],
|
||||
description="Groq API endpoints for load balancing"
|
||||
)
|
||||
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
|
||||
anthropic_api_key: Optional[str] = Field(default=None, description="Anthropic API key")
|
||||
|
||||
# NVIDIA NIM Configuration
|
||||
nvidia_nim_endpoint: str = Field(
|
||||
default="https://integrate.api.nvidia.com/v1",
|
||||
description="NVIDIA NIM API endpoint (cloud or self-hosted)"
|
||||
)
|
||||
nvidia_nim_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable NVIDIA NIM backend for GPU-accelerated inference"
|
||||
)
|
||||
|
||||
# HAProxy Configuration
|
||||
haproxy_groq_endpoint: str = Field(
|
||||
default="http://haproxy-groq-lb-service.gt-resource.svc.cluster.local",
|
||||
description="HAProxy load balancer endpoint for Groq API"
|
||||
)
|
||||
haproxy_stats_endpoint: str = Field(
|
||||
default="http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats",
|
||||
description="HAProxy statistics endpoint"
|
||||
)
|
||||
haproxy_admin_socket: str = Field(
|
||||
default="/var/run/haproxy.sock",
|
||||
description="HAProxy admin socket for runtime configuration"
|
||||
)
|
||||
haproxy_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable HAProxy load balancing for external APIs"
|
||||
)
|
||||
|
||||
# Control Panel Integration (for API key retrieval)
|
||||
control_panel_url: str = Field(
|
||||
default="http://control-panel-backend:8000",
|
||||
description="Control Panel internal API URL for service-to-service calls"
|
||||
)
|
||||
service_auth_token: str = Field(
|
||||
default="internal-service-token",
|
||||
description="Service-to-service authentication token"
|
||||
)
|
||||
|
||||
# Admin Cluster Configuration Sync
|
||||
admin_cluster_url: str = Field(
|
||||
default="http://localhost:8001",
|
||||
description="Admin cluster URL for configuration sync"
|
||||
)
|
||||
config_sync_interval: int = Field(
|
||||
default=10,
|
||||
description="Configuration sync interval in seconds"
|
||||
)
|
||||
config_sync_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable automatic configuration sync from admin cluster"
|
||||
)
|
||||
|
||||
# Consul Service Discovery
|
||||
consul_host: str = Field(default="localhost", description="Consul host")
|
||||
consul_port: int = Field(default=8500, description="Consul port")
|
||||
consul_token: Optional[str] = Field(default=None, description="Consul ACL token")
|
||||
|
||||
# Document Processing
|
||||
chunking_engine_workers: int = Field(default=4, description="Parallel document processors")
|
||||
max_document_size_mb: int = Field(default=50, description="Maximum document size")
|
||||
supported_document_types: List[str] = Field(
|
||||
default=[".pdf", ".docx", ".txt", ".md", ".html", ".pptx", ".xlsx", ".csv"],
|
||||
description="Supported document formats"
|
||||
)
|
||||
|
||||
# BGE-M3 Embedding Configuration
|
||||
embedding_endpoint: str = Field(
|
||||
default="http://gentwo-vllm-embeddings:8000/v1/embeddings",
|
||||
description="Default embedding endpoint (local or external)"
|
||||
)
|
||||
bge_m3_local_mode: bool = Field(
|
||||
default=True,
|
||||
description="Use local BGE-M3 embedding service (True) or external endpoint (False)"
|
||||
)
|
||||
bge_m3_external_endpoint: Optional[str] = Field(
|
||||
default=None,
|
||||
description="External BGE-M3 embedding endpoint URL (when local_mode=False)"
|
||||
)
|
||||
|
||||
# Vector Database (ChromaDB)
|
||||
chromadb_host: str = Field(default="localhost", description="ChromaDB host")
|
||||
chromadb_port: int = Field(default=8000, description="ChromaDB port")
|
||||
chromadb_encryption_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Encryption key for vector storage"
|
||||
)
|
||||
|
||||
# Resource Limits
|
||||
max_concurrent_inferences: int = Field(default=100, description="Max concurrent LLM calls")
|
||||
max_tokens_per_request: int = Field(default=8000, description="Max tokens per LLM request")
|
||||
rate_limit_requests_per_minute: int = Field(default=60, description="Global rate limit")
|
||||
|
||||
# Storage Paths
|
||||
data_directory: str = Field(
|
||||
default="/tmp/gt2-resource-cluster" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster",
|
||||
description="Base data directory"
|
||||
)
|
||||
template_library_path: str = Field(
|
||||
default="/tmp/gt2-resource-cluster/templates" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster/templates",
|
||||
description="Agent template library"
|
||||
)
|
||||
models_cache_path: str = Field( # Renamed to avoid pydantic warning
|
||||
default="/tmp/gt2-resource-cluster/models" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster/models",
|
||||
description="Local model cache"
|
||||
)
|
||||
|
||||
# Redis removed - Resource Cluster uses PostgreSQL for caching and rate limiting
|
||||
|
||||
# Monitoring
|
||||
prometheus_enabled: bool = Field(default=True, description="Enable Prometheus metrics")
|
||||
prometheus_port: int = Field(default=9091, description="Prometheus metrics port")
|
||||
|
||||
# CORS Configuration (for tenant backends)
|
||||
cors_origins: List[str] = Field(
|
||||
default=["http://localhost:8002", "https://*.gt2.com"],
|
||||
description="Allowed CORS origins"
|
||||
)
|
||||
|
||||
# Trusted Host Configuration
|
||||
trusted_hosts: List[str] = Field(
|
||||
default=["localhost", "*.gt2.com", "resource-cluster", "gentwo-resource-backend",
|
||||
"gt2-resource-backend", "testserver", "127.0.0.1", "*"],
|
||||
description="Allowed host headers for TrustedHostMiddleware"
|
||||
)
|
||||
|
||||
# Feature Flags
|
||||
enable_model_caching: bool = Field(default=True, description="Cache model responses")
|
||||
enable_usage_tracking: bool = Field(default=True, description="Track resource usage")
|
||||
enable_cost_calculation: bool = Field(default=True, description="Calculate usage costs")
|
||||
|
||||
@validator("data_directory")
|
||||
def validate_data_directory(cls, v):
|
||||
# Ensure directory exists with secure permissions
|
||||
os.makedirs(v, exist_ok=True, mode=0o700)
|
||||
return v
|
||||
|
||||
@validator("template_library_path")
|
||||
def validate_template_library_path(cls, v):
|
||||
os.makedirs(v, exist_ok=True, mode=0o700)
|
||||
return v
|
||||
|
||||
@validator("models_cache_path")
|
||||
def validate_models_cache_path(cls, v):
|
||||
os.makedirs(v, exist_ok=True, mode=0o700)
|
||||
return v
|
||||
|
||||
model_config = {
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore",
|
||||
}
|
||||
|
||||
|
||||
def get_settings(tenant_id: Optional[str] = None) -> Settings:
|
||||
"""Get tenant-scoped application settings"""
|
||||
# For development, use a simple cache without tenant isolation
|
||||
if os.getenv("ENVIRONMENT") == "development":
|
||||
return Settings()
|
||||
|
||||
# In production, settings should be tenant-scoped
|
||||
# This prevents global state from affecting tenant isolation
|
||||
if tenant_id:
|
||||
# Create tenant-specific settings with proper isolation
|
||||
settings = Settings()
|
||||
# Add tenant-specific configurations here if needed
|
||||
return settings
|
||||
else:
|
||||
# Default settings for non-tenant operations
|
||||
return Settings()
|
||||
|
||||
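# Settings resolution sketch: secret_key has no default, so it must be supplied
# via the environment or .env. Values below are placeholders.
def _settings_demo() -> None:
    os.environ.setdefault("SECRET_KEY", "dev-only-secret")
    os.environ["HAPROXY_ENABLED"] = "false"   # env names are case-insensitive
    s = get_settings()
    assert s.haproxy_enabled is False
    assert s.service_port == 8003             # default retained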
|
||||
def get_resource_families(tenant_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get tenant-scoped resource family definitions (from CLAUDE.md)"""
|
||||
# Base resource families - can be extended per tenant in production
|
||||
return {
|
||||
"ai_ml": {
|
||||
"name": "AI/ML Resources",
|
||||
"subtypes": ["llm", "embedding", "image_generation", "function_calling"]
|
||||
},
|
||||
"rag_engine": {
|
||||
"name": "RAG Engine Resources",
|
||||
"subtypes": ["vector_db", "document_processor", "semantic_search", "retrieval"]
|
||||
},
|
||||
"agentic_workflow": {
|
||||
"name": "Agentic Workflow Resources",
|
||||
"subtypes": ["single_agent", "multi_agent", "orchestration", "memory"]
|
||||
},
|
||||
"app_integration": {
|
||||
"name": "App Integration Resources",
|
||||
"subtypes": ["oauth2", "webhook", "api_connector", "database_connector"]
|
||||
},
|
||||
"external_service": {
|
||||
"name": "External Web Services",
|
||||
"subtypes": ["iframe_embed", "sso_service", "remote_desktop", "learning_platform"]
|
||||
},
|
||||
"ai_literacy": {
|
||||
"name": "AI Literacy & Cognitive Skills",
|
||||
"subtypes": ["strategic_game", "logic_puzzle", "philosophical_dilemma", "educational_content"]
|
||||
}
|
||||
}
|
||||
|
||||
def get_model_configs(tenant_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get tenant-scoped model configurations for different providers"""
|
||||
# Base model configurations - can be customized per tenant in production
|
||||
return {
|
||||
"groq": {
|
||||
"llama-3.1-70b-versatile": {
|
||||
"max_tokens": 8000,
|
||||
"cost_per_1k_tokens": 0.59,
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": True
|
||||
},
|
||||
"llama-3.1-8b-instant": {
|
||||
"max_tokens": 8000,
|
||||
"cost_per_1k_tokens": 0.05,
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": True
|
||||
},
|
||||
"mixtral-8x7b-32768": {
|
||||
"max_tokens": 32768,
|
||||
"cost_per_1k_tokens": 0.27,
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": False
|
||||
}
|
||||
},
|
||||
"openai": {
|
||||
"gpt-4-turbo": {
|
||||
"max_tokens": 128000,
|
||||
"cost_per_1k_tokens": 10.0,
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": True
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
"max_tokens": 16385,
|
||||
"cost_per_1k_tokens": 0.5,
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": True
|
||||
}
|
||||
},
|
||||
"anthropic": {
|
||||
"claude-3-opus": {
|
||||
"max_tokens": 200000,
|
||||
"cost_per_1k_tokens": 15.0,
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": False
|
||||
},
|
||||
"claude-3-sonnet": {
|
||||
"max_tokens": 200000,
|
||||
"cost_per_1k_tokens": 3.0,
|
||||
"supports_streaming": True,
|
||||
"supports_function_calling": False
|
||||
}
|
||||
}
|
||||
}
|
||||
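# How the proxy backends consume these entries (sketch; tenant domain is illustrative).
def _model_config_demo() -> None:
    cfg = get_model_configs("acme.example.com")["groq"]["llama-3.1-8b-instant"]
    assert min(32_000, cfg["max_tokens"]) == 8000      # requested budget is clamped
    assert cfg["supports_function_calling"] is True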
45
apps/resource-cluster/app/core/exceptions.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
GT 2.0 Resource Cluster Exceptions
|
||||
|
||||
Custom exceptions for the resource cluster.
|
||||
"""
|
||||
|
||||
|
||||
class ResourceClusterError(Exception):
|
||||
"""Base exception for resource cluster errors"""
|
||||
pass
|
||||
|
||||
|
||||
class ProviderError(ResourceClusterError):
|
||||
"""Error from AI model provider"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotFoundError(ResourceClusterError):
|
||||
"""Requested model not found"""
|
||||
pass
|
||||
|
||||
|
||||
class CapabilityError(ResourceClusterError):
|
||||
"""Capability token validation error"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPError(ResourceClusterError):
|
||||
"""MCP service error"""
|
||||
pass
|
||||
|
||||
|
||||
class DocumentProcessingError(ResourceClusterError):
|
||||
"""Document processing error"""
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitError(ResourceClusterError):
|
||||
"""Rate limit exceeded"""
|
||||
pass
|
||||
|
||||
|
||||
class CircuitBreakerError(ProviderError):
|
||||
"""Circuit breaker is open"""
|
||||
pass
|
||||
273
apps/resource-cluster/app/core/security.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
GT 2.0 Resource Cluster Security
|
||||
|
||||
Capability-based authentication and authorization for resource access.
|
||||
Implements cryptographically signed JWT tokens with embedded capabilities.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class ResourceCapability(BaseModel):
|
||||
"""Individual resource capability"""
|
||||
resource: str # e.g., "llm:groq", "rag:semantic_search"
|
||||
actions: List[str] # e.g., ["inference", "streaming"]
|
||||
limits: Dict[str, Any] = {} # e.g., {"max_tokens": 4000, "requests_per_minute": 60}
|
||||
constraints: Dict[str, Any] = {} # e.g., {"valid_until": "2024-12-31", "ip_restrictions": []}
|
||||
|
||||
|
||||
class CapabilityToken(BaseModel):
|
||||
"""Capability-based JWT token payload"""
|
||||
sub: str # User or service identifier
|
||||
tenant_id: str # Tenant identifier
|
||||
capabilities: List[ResourceCapability] # Granted capabilities
|
||||
capability_hash: str # SHA256 hash of capabilities for integrity
|
||||
exp: Optional[datetime] = None # Expiration time
|
||||
iat: Optional[datetime] = None # Issued at time
|
||||
jti: Optional[str] = None # JWT ID for revocation
|
||||
|
||||
|
||||
class CapabilityValidator:
|
||||
"""Validates and enforces capability-based access control"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
|
||||
def create_capability_token(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
capabilities: List[Dict[str, Any]],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""Create a cryptographically signed capability token"""
|
||||
|
||||
# Convert capabilities to ResourceCapability objects
|
||||
capability_objects = [
|
||||
ResourceCapability(**cap) for cap in capabilities
|
||||
]
|
||||
|
||||
# Generate capability hash for integrity verification
|
||||
capability_hash = self._generate_capability_hash(capability_objects)
|
||||
|
||||
# Set token expiration
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=self.settings.capability_token_expire_minutes)
|
||||
|
||||
# Create token payload
|
||||
token_data = CapabilityToken(
|
||||
sub=user_id,
|
||||
tenant_id=tenant_id,
|
||||
capabilities=[cap.dict() for cap in capability_objects],
|
||||
capability_hash=capability_hash,
|
||||
exp=expire,
|
||||
iat=datetime.utcnow(),
|
||||
jti=self._generate_jti()
|
||||
)
|
||||
|
||||
# Encode JWT token
|
||||
encoded_jwt = jwt.encode(
|
||||
token_data.dict(),
|
||||
self.settings.secret_key,
|
||||
algorithm=self.settings.algorithm
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
def verify_capability_token(self, token: str) -> Optional[CapabilityToken]:
|
||||
"""Verify and decode a capability token"""
|
||||
try:
|
||||
# Decode JWT token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.settings.secret_key,
|
||||
algorithms=[self.settings.algorithm]
|
||||
)
|
||||
|
||||
# Convert to CapabilityToken object
|
||||
capability_token = CapabilityToken(**payload)
|
||||
|
||||
# Verify capability hash integrity
|
||||
capability_objects = []
|
||||
for cap in capability_token.capabilities:
|
||||
if isinstance(cap, dict):
|
||||
capability_objects.append(ResourceCapability(**cap))
|
||||
else:
|
||||
capability_objects.append(cap)
|
||||
|
||||
expected_hash = self._generate_capability_hash(capability_objects)
|
||||
|
||||
if capability_token.capability_hash != expected_hash:
|
||||
raise ValueError("Capability hash mismatch - token may be tampered")
|
||||
|
||||
return capability_token
|
||||
|
||||
except (JWTError, ValueError) as e:
|
||||
return None
|
||||
|
||||
def check_resource_access(
|
||||
self,
|
||||
token: CapabilityToken,
|
||||
resource: str,
|
||||
action: str,
|
||||
context: Dict[str, Any] = {}
|
||||
) -> bool:
|
||||
"""Check if token grants access to specific resource and action"""
|
||||
|
||||
for capability in token.capabilities:
|
||||
# Handle both dict and ResourceCapability object formats
|
||||
if isinstance(capability, dict):
|
||||
cap_resource = capability["resource"]
|
||||
cap_actions = capability.get("actions", [])
|
||||
cap_constraints = capability.get("constraints", {})
|
||||
else:
|
||||
cap_resource = capability.resource
|
||||
cap_actions = capability.actions
|
||||
cap_constraints = capability.constraints
|
||||
|
||||
# Check if capability matches resource
|
||||
if self._matches_resource(cap_resource, resource):
|
||||
# Check if action is allowed
|
||||
if action in cap_actions:
|
||||
# Check additional constraints
|
||||
if self._check_constraints(cap_constraints, context):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
    def get_resource_limits(
        self,
        token: CapabilityToken,
        resource: str
    ) -> Dict[str, Any]:
        """Get resource-specific limits from token"""
        for capability in token.capabilities:
            # Handle both dict and ResourceCapability object formats
            if isinstance(capability, dict):
                cap_resource = capability["resource"]
                cap_limits = capability.get("limits", {})
            else:
                cap_resource = capability.resource
                cap_limits = capability.limits

            if self._matches_resource(cap_resource, resource):
                return cap_limits

        return {}

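    # Illustrative example (added for documentation; values are hypothetical): if the
    # first matching capability declares
    #   {"resource": "llm:groq", "limits": {"max_tokens": 4000, "requests_per_minute": 60}}
    # then get_resource_limits(token, "llm:groq") returns that limits dict; a resource
    # with no matching capability yields {}.
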
    def _generate_capability_hash(self, capabilities: List[ResourceCapability]) -> str:
        """Generate SHA256 hash of capabilities for integrity verification"""
        # Sort capabilities for consistent hashing
        sorted_caps = sorted(
            [cap.dict() for cap in capabilities],
            key=lambda x: x["resource"]
        )

        # Create hash
        cap_string = json.dumps(sorted_caps, sort_keys=True)
        return hashlib.sha256(cap_string.encode()).hexdigest()

    def _generate_jti(self) -> str:
        """Generate unique JWT ID"""
        import uuid
        return str(uuid.uuid4())

    def _matches_resource(self, pattern: str, resource: str) -> bool:
        """Check if resource pattern matches requested resource"""
        # Handle wildcards (e.g., "llm:*" matches "llm:groq")
        if pattern.endswith(":*"):
            prefix = pattern[:-2]
            return resource.startswith(prefix + ":")

        # Handle exact matches
        return pattern == resource

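    # Illustrative matching behaviour (added for documentation; names are hypothetical):
    #   _matches_resource("llm:*", "llm:groq")       -> True   (wildcard prefix match)
    #   _matches_resource("llm:*", "embedding:bge")  -> False  (different prefix)
    #   _matches_resource("llm:*", "llm")            -> False  (wildcard requires "llm:<something>")
    #   _matches_resource("llm:groq", "llm:groq")    -> True   (exact match)
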
    def _check_constraints(self, constraints: Dict[str, Any], context: Dict[str, Any]) -> bool:
        """Check additional constraints like time validity and IP restrictions"""
        # Check time validity
        if "valid_until" in constraints:
            valid_until = datetime.fromisoformat(constraints["valid_until"])
            if datetime.utcnow() > valid_until:
                return False

        # Check IP restrictions
        if "ip_restrictions" in constraints and "client_ip" in context:
            allowed_ips = constraints["ip_restrictions"]
            if allowed_ips and context["client_ip"] not in allowed_ips:
                return False

        # Check tenant restrictions
        if "allowed_tenants" in constraints and "tenant_id" in context:
            allowed_tenants = constraints["allowed_tenants"]
            if allowed_tenants and context["tenant_id"] not in allowed_tenants:
                return False

        return True

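    # Illustrative constraint block (added for documentation; all values hypothetical):
    #   {"valid_until": "2026-01-01T00:00:00",
    #    "ip_restrictions": ["10.0.0.5"],
    #    "allowed_tenants": ["tenant-abc"]}
    # The time window is always enforced when "valid_until" is present; the IP and
    # tenant checks are enforced only when the caller supplies the matching
    # "client_ip" or "tenant_id" keys in the context dict.

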
# Global validator instance
capability_validator = CapabilityValidator()


def verify_capability_token(token: str) -> Optional[CapabilityToken]:
    """Standalone function for FastAPI dependency injection"""
    return capability_validator.verify_capability_token(token)


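# Illustrative wiring sketch (added for documentation, not part of this module's API):
# one way the standalone helper above could back a FastAPI dependency. The dependency
# name and the 401 detail text are hypothetical.
#
#   from fastapi import Depends, HTTPException, status
#   from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
#
#   _bearer_scheme = HTTPBearer()
#
#   async def require_capability_token(
#       credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme),
#   ) -> CapabilityToken:
#       token = verify_capability_token(credentials.credentials)
#       if token is None:
#           raise HTTPException(
#               status_code=status.HTTP_401_UNAUTHORIZED,
#               detail="Invalid or expired capability token",
#           )
#       return token

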
def create_resource_capability(
    resource_type: str,
    resource_id: str,
    actions: List[str],
    limits: Optional[Dict[str, Any]] = None,
    constraints: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Helper function to create a resource capability"""
    return {
        "resource": f"{resource_type}:{resource_id}",
        "actions": actions,
        "limits": limits or {},
        "constraints": constraints or {}
    }


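# Illustrative example (added for documentation; identifiers are hypothetical):
#   create_resource_capability("llm", "groq_llama3", ["inference", "streaming"], limits={"max_tokens": 4000})
# returns
#   {"resource": "llm:groq_llama3", "actions": ["inference", "streaming"],
#    "limits": {"max_tokens": 4000}, "constraints": {}}

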
def create_assistant_capabilities(assistant_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Create capabilities from agent configuration"""
    capabilities = []

    # Extract capabilities declared directly in the agent config
    capabilities.extend(assistant_config.get("capabilities", []))

    # Add default LLM capability if specified
    if "primary_llm" in assistant_config.get("resource_preferences", {}):
        llm_model = assistant_config["resource_preferences"]["primary_llm"]
        capabilities.append(create_resource_capability(
            "llm",
            llm_model.replace(":", "_"),
            ["inference", "streaming"],
            {
                "max_tokens": assistant_config["resource_preferences"].get("max_tokens", 4000),
                "temperature": assistant_config["resource_preferences"].get("temperature", 0.7)
            }
        ))

    return capabilities


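# Illustrative example (added for documentation; model and field values are hypothetical):
# an assistant config such as
#   {"capabilities": [], "resource_preferences": {"primary_llm": "groq:llama3", "max_tokens": 2000}}
# yields a single derived capability equivalent to
#   create_resource_capability("llm", "groq_llama3", ["inference", "streaming"],
#                              {"max_tokens": 2000, "temperature": 0.7})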