GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents

- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
  - Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
HackWeasel committed on 2025-12-12 17:47:14 -05:00
commit 310491a557
750 changed files with 232701 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
"""
Core utilities and configuration for Resource Cluster
"""

View File

@@ -0,0 +1,140 @@
"""
GT 2.0 Resource Cluster - API Standards Integration
This module integrates CB-REST standards for non-AI endpoints while
maintaining OpenAI compatibility for AI inference endpoints.
"""
import os
import sys
from pathlib import Path
# Add the api-standards package to the path
api_standards_path = Path(__file__).parent.parent.parent.parent.parent / "packages" / "api-standards" / "src"
if api_standards_path.exists():
sys.path.insert(0, str(api_standards_path))
# Import CB-REST standards
try:
from response import StandardResponse, format_response, format_error
from capability import (
init_capability_verifier,
verify_capability,
require_capability,
Capability,
CapabilityToken
)
from errors import ErrorCode, APIError, raise_api_error
from middleware import (
RequestCorrelationMiddleware,
CapabilityMiddleware,
TenantIsolationMiddleware,
RateLimitMiddleware
)
except ImportError as e:
# Fallback for development - create minimal implementations
print(f"Warning: Could not import api-standards package: {e}")
# Create minimal implementations for development
class StandardResponse:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def format_response(data, capability_used, request_id=None):
return {
"data": data,
"error": None,
"capability_used": capability_used,
"request_id": request_id or "dev-mode"
}
def format_error(code, message, capability_used="none", **kwargs):
return {
"data": None,
"error": {
"code": code,
"message": message,
**kwargs
},
"capability_used": capability_used,
"request_id": kwargs.get("request_id", "dev-mode")
}
class ErrorCode:
CAPABILITY_INSUFFICIENT = "CAPABILITY_INSUFFICIENT"
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"
INVALID_REQUEST = "INVALID_REQUEST"
SYSTEM_ERROR = "SYSTEM_ERROR"
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
class APIError(Exception):
def __init__(self, code, message, **kwargs):
self.code = code
self.message = message
self.kwargs = kwargs
super().__init__(message)
# Export all CB-REST components
__all__ = [
'StandardResponse',
'format_response',
'format_error',
'init_capability_verifier',
'verify_capability',
'require_capability',
'Capability',
'CapabilityToken',
'ErrorCode',
'APIError',
'raise_api_error',
'RequestCorrelationMiddleware',
'CapabilityMiddleware',
'TenantIsolationMiddleware',
'RateLimitMiddleware'
]
def setup_api_standards(app, secret_key: str):
"""
Setup API standards for the Resource Cluster
IMPORTANT: This only applies CB-REST to non-AI endpoints.
AI inference endpoints maintain OpenAI compatibility.
Args:
app: FastAPI application instance
secret_key: Secret key for JWT signing
"""
# Initialize capability verifier
if 'init_capability_verifier' in globals():
init_capability_verifier(secret_key)
# Add middleware in correct order
if 'RequestCorrelationMiddleware' in globals():
app.add_middleware(RequestCorrelationMiddleware)
if 'RateLimitMiddleware' in globals():
app.add_middleware(
RateLimitMiddleware,
requests_per_minute=1000 # Higher limit for resource cluster
)
# Note: No TenantIsolationMiddleware for Resource Cluster
# as it serves multiple tenants with capability-based access
if 'CapabilityMiddleware' in globals():
# Exclude AI inference endpoints from CB-REST middleware
# to maintain OpenAI compatibility
app.add_middleware(
CapabilityMiddleware,
exclude_paths=[
"/health",
"/ready",
"/metrics",
"/ai/chat/completions", # OpenAI compatible
"/ai/embeddings", # OpenAI compatible
"/ai/images/generations", # OpenAI compatible
"/ai/models" # OpenAI compatible
]
)
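A minimal wiring sketch for the module above (an illustration, not part of this commit), assuming it is importable as app.core.api_standards and that the JWT secret comes from an environment variable:

import os

from fastapi import FastAPI

from app.core.api_standards import setup_api_standards  # assumed import path

app = FastAPI(title="GT 2.0 Resource Cluster")

# Applies CB-REST middleware to non-AI routes only; the /ai/* paths listed in
# exclude_paths above keep their OpenAI-compatible request/response shapes.
setup_api_standards(app, secret_key=os.environ["API_STANDARDS_SECRET_KEY"])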

View File

@@ -0,0 +1,52 @@
"""
Resource backend implementations for GT 2.0
Provides unified interfaces for all resource types:
- LLM inference (Groq, OpenAI, Anthropic)
- Vector databases (PGVector)
- Document processing (Unstructured)
- External services (OAuth2, iframe)
- AI literacy resources
"""
from typing import Dict, Any
import logging
logger = logging.getLogger(__name__)
# Registry of available backends
BACKEND_REGISTRY: Dict[str, Any] = {}
def register_backend(name: str, backend_class):
"""Register a resource backend"""
BACKEND_REGISTRY[name] = backend_class
logger.info(f"Registered backend: {name}")
def get_backend(name: str):
"""Get a registered backend"""
if name not in BACKEND_REGISTRY:
raise ValueError(f"Backend not found: {name}")
return BACKEND_REGISTRY[name]
async def initialize_backends():
"""Initialize all resource backends"""
from app.core.backends.groq_proxy import GroqProxyBackend
from app.core.backends.nvidia_proxy import NvidiaProxyBackend
from app.core.backends.document_processor import DocumentProcessorBackend
from app.core.backends.embedding_backend import EmbeddingBackend
# Register backends
register_backend("groq_proxy", GroqProxyBackend())
register_backend("nvidia_proxy", NvidiaProxyBackend())
register_backend("document_processor", DocumentProcessorBackend())
register_backend("embedding", EmbeddingBackend())
logger.info("All resource backends initialized")
def get_embedding_backend():
"""Get the embedding backend instance"""
return get_backend("embedding")
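A minimal usage sketch for the registry above (illustrative; assumes the package is importable as app.core.backends and that the registered backends can reach their upstream services):

import asyncio

from app.core.backends import get_backend, initialize_backends  # assumed import path


async def main() -> None:
    await initialize_backends()            # registers groq_proxy, nvidia_proxy, document_processor, embedding
    embedding = get_backend("embedding")   # same instance that get_embedding_backend() returns
    print(await embedding.check_health())


asyncio.run(main())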

View File

@@ -0,0 +1,322 @@
"""
Document Processing Backend
STATELESS document chunking and preprocessing for RAG operations.
All processing happens in memory - NO user data is ever stored.
"""
import logging
import io
import gc
from typing import Dict, Any, List, Optional, BinaryIO
from dataclasses import dataclass
import hashlib
# Document processing imports
import pypdf as PyPDF2
from docx import Document as DocxDocument
from bs4 import BeautifulSoup
from langchain_text_splitters import (
RecursiveCharacterTextSplitter,
TokenTextSplitter,
SentenceTransformersTokenTextSplitter
)
logger = logging.getLogger(__name__)
@dataclass
class ChunkingStrategy:
"""Configuration for document chunking"""
strategy_type: str # 'fixed', 'semantic', 'hierarchical', 'hybrid'
chunk_size: int # Target chunk size in tokens (optimized for BGE-M3: 512)
chunk_overlap: int # Overlap between chunks (typically 128 for BGE-M3)
separator_pattern: Optional[str] = None # Custom separator for splitting
preserve_paragraphs: bool = True
preserve_sentences: bool = True
class DocumentProcessorBackend:
"""
STATELESS document chunking and processing backend.
Security principles:
- NO persistence of user data
- All processing in memory only
- Immediate memory cleanup after processing
- No caching of user content
"""
def __init__(self):
self.supported_formats = [".pdf", ".docx", ".txt", ".md", ".html"]
# BGE-M3 optimal settings
self.default_chunk_size = 512 # tokens
self.default_chunk_overlap = 128 # tokens
self.model_name = "BAAI/bge-m3" # For tokenization
logger.info("STATELESS document processor backend initialized")
async def process_document(
self,
content: bytes,
document_type: str,
strategy: Optional[ChunkingStrategy] = None,
metadata: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Process document into chunks - STATELESS operation.
Args:
content: Document content as bytes (will be cleared from memory)
document_type: File type (.pdf, .docx, .txt, .md, .html)
strategy: Chunking strategy configuration
metadata: Optional metadata (will NOT include user content)
Returns:
List of chunks with metadata (immediately returned, not stored)
"""
try:
# Use default strategy if not provided
if strategy is None:
strategy = ChunkingStrategy(
strategy_type='hybrid',
chunk_size=self.default_chunk_size,
chunk_overlap=self.default_chunk_overlap
)
# Extract text based on document type (in memory)
text = await self._extract_text_from_bytes(content, document_type)
# Clear original content from memory
del content
gc.collect()
# Apply chunking strategy
if strategy.strategy_type == 'semantic':
chunks = await self._semantic_chunking(text, strategy)
elif strategy.strategy_type == 'hierarchical':
chunks = await self._hierarchical_chunking(text, strategy)
elif strategy.strategy_type == 'hybrid':
chunks = await self._hybrid_chunking(text, strategy)
else: # 'fixed'
chunks = await self._fixed_chunking(text, strategy)
# Clear text from memory
del text
gc.collect()
# Add metadata without storing content
processed_chunks = []
for idx, chunk in enumerate(chunks):
chunk_metadata = {
"chunk_index": idx,
"total_chunks": len(chunks),
"chunking_strategy": strategy.strategy_type,
"chunk_size_tokens": strategy.chunk_size,
# Generate hash for deduplication without storing content
"content_hash": hashlib.sha256(chunk.encode()).hexdigest()[:16]
}
# Add non-sensitive metadata if provided
if metadata:
# Filter out any potential sensitive data
safe_metadata = {
k: v for k, v in metadata.items()
if k in ['document_type', 'processing_timestamp', 'tenant_id']
}
chunk_metadata.update(safe_metadata)
processed_chunks.append({
"text": chunk,
"metadata": chunk_metadata
})
logger.info(f"Processed document into {len(processed_chunks)} chunks (STATELESS)")
# Return immediately - no storage
return processed_chunks
except Exception as e:
logger.error(f"Error processing document: {e}")
# Ensure memory is cleared even on error
gc.collect()
raise
finally:
# Always ensure memory cleanup
gc.collect()
async def _extract_text_from_bytes(
self,
content: bytes,
document_type: str
) -> str:
"""Extract text from document bytes - in memory only"""
try:
if document_type == ".pdf":
return await self._extract_pdf_text(io.BytesIO(content))
elif document_type == ".docx":
return await self._extract_docx_text(io.BytesIO(content))
elif document_type == ".html":
return await self._extract_html_text(content.decode('utf-8'))
elif document_type in [".txt", ".md"]:
return content.decode('utf-8')
else:
raise ValueError(f"Unsupported document type: {document_type}")
finally:
# Clear content from memory
del content
gc.collect()
async def _extract_pdf_text(self, file_stream: BinaryIO) -> str:
"""Extract text from PDF - in memory"""
text = ""
try:
pdf_reader = PyPDF2.PdfReader(file_stream)
for page_num in range(len(pdf_reader.pages)):
page = pdf_reader.pages[page_num]
text += page.extract_text() + "\n"
finally:
file_stream.close()
gc.collect()
return text
async def _extract_docx_text(self, file_stream: BinaryIO) -> str:
"""Extract text from DOCX - in memory"""
text = ""
try:
doc = DocxDocument(file_stream)
for paragraph in doc.paragraphs:
text += paragraph.text + "\n"
finally:
file_stream.close()
gc.collect()
return text
async def _extract_html_text(self, html_content: str) -> str:
"""Extract text from HTML - in memory"""
soup = BeautifulSoup(html_content, 'html.parser')
# Remove script and style elements
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
# Clean up whitespace
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = '\n'.join(chunk for chunk in chunks if chunk)
return text
async def _semantic_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Semantic chunking using sentence boundaries"""
splitter = SentenceTransformersTokenTextSplitter(
model_name=self.model_name,
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
return splitter.split_text(text)
async def _hierarchical_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Hierarchical chunking preserving document structure"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=strategy.chunk_size * 3, # Approximate token to char ratio
chunk_overlap=strategy.chunk_overlap * 3,
separators=["\n\n\n", "\n\n", "\n", ". ", " ", ""],
keep_separator=True
)
return splitter.split_text(text)
async def _hybrid_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Hybrid chunking combining semantic and structural boundaries"""
# First split by structure
structural_splitter = RecursiveCharacterTextSplitter(
chunk_size=strategy.chunk_size * 4,
chunk_overlap=0,
separators=["\n\n\n", "\n\n"],
keep_separator=True
)
structural_chunks = structural_splitter.split_text(text)
# Then apply semantic splitting to each structural chunk
final_chunks = []
token_splitter = TokenTextSplitter(
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
for struct_chunk in structural_chunks:
semantic_chunks = token_splitter.split_text(struct_chunk)
final_chunks.extend(semantic_chunks)
return final_chunks
async def _fixed_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Fixed-size chunking with token boundaries"""
splitter = TokenTextSplitter(
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
return splitter.split_text(text)
async def validate_document(
self,
content_size: int,
document_type: str
) -> Dict[str, Any]:
"""
Validate document before processing - no content stored.
Args:
content_size: Size of document in bytes
document_type: File extension
Returns:
Validation result with any warnings
"""
MAX_SIZE = 50 * 1024 * 1024 # 50MB max
validation = {
"valid": True,
"warnings": [],
"errors": []
}
# Check file size
if content_size > MAX_SIZE:
validation["valid"] = False
validation["errors"].append(f"File size exceeds maximum of 50MB")
elif content_size > 10 * 1024 * 1024: # Warning for files over 10MB
validation["warnings"].append("Large file may take longer to process")
# Check document type
if document_type not in self.supported_formats:
validation["valid"] = False
validation["errors"].append(f"Unsupported format: {document_type}")
return validation
async def check_health(self) -> Dict[str, Any]:
"""Check document processor health - no user data exposed"""
return {
"status": "healthy",
"supported_formats": self.supported_formats,
"default_chunk_size": self.default_chunk_size,
"default_chunk_overlap": self.default_chunk_overlap,
"model": self.model_name,
"stateless": True, # Confirm stateless operation
"memory_cleared": True # Confirm memory management
}
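A minimal usage sketch for the processor above (illustrative; the file name and strategy values are arbitrary examples):

import asyncio

from app.core.backends.document_processor import ChunkingStrategy, DocumentProcessorBackend  # assumed import path


async def main() -> None:
    processor = DocumentProcessorBackend()
    with open("README.md", "rb") as fh:      # any local Markdown file
        content = fh.read()
    strategy = ChunkingStrategy(strategy_type="hybrid", chunk_size=512, chunk_overlap=128)
    chunks = await processor.process_document(content, document_type=".md", strategy=strategy)
    # Only derived metadata (index, strategy, 16-char content hash) accompanies each chunk.
    print(len(chunks), chunks[0]["metadata"]["content_hash"])


asyncio.run(main())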

View File

@@ -0,0 +1,471 @@
"""
Embedding Model Backend
STATELESS embedding generation using BGE-M3 model hosted on GT's GPU clusters.
All embeddings are generated in real-time - NO user data is stored.
"""
import logging
import gc
import hashlib
import asyncio
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
# import numpy as np # Temporarily disabled for Docker build
import aiohttp
import json
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class EmbeddingRequest:
"""Request structure for embedding generation"""
texts: List[str]
model: str = "BAAI/bge-m3"
batch_size: int = 32
normalize: bool = True
instruction: Optional[str] = None # For instruction-based embeddings
class EmbeddingBackend:
"""
STATELESS embedding backend for BGE-M3 model.
Security principles:
- NO persistence of embeddings or text
- All processing via GT's internal GPU cluster
- Immediate memory cleanup after generation
- No caching of user content
- Request signing and verification
"""
def __init__(self):
self.model_name = "BAAI/bge-m3"
self.embedding_dimensions = 1024 # BGE-M3 dimensions
self.max_batch_size = 32
self.max_sequence_length = 8192 # BGE-M3 supports up to 8192 tokens
# Determine endpoint based on configuration
self.embedding_endpoint = self._get_embedding_endpoint()
# Timeout for embedding requests
self.request_timeout = 60 # seconds for model loading
logger.info(f"STATELESS embedding backend initialized for {self.model_name}")
logger.info(f"Using embedding endpoint: {self.embedding_endpoint}")
def _get_embedding_endpoint(self) -> str:
"""
Get the embedding endpoint based on configuration.
Priority:
1. Model registry from config sync (database-backed)
2. Environment variables (BGE_M3_LOCAL_MODE, BGE_M3_EXTERNAL_ENDPOINT)
3. Default local endpoint
"""
# Try to get configuration from model registry first (loaded from database)
try:
from app.services.model_service import default_model_service
import asyncio
# Use the default model service instance (singleton) used by config sync
model_service = default_model_service
# Try to get the model config synchronously (during initialization)
# The get_model method is async, so we need to handle this carefully
bge_m3_config = model_service.model_registry.get("BAAI/bge-m3")
if bge_m3_config:
# Model registry stores endpoint as 'endpoint_url' and config as 'parameters'
endpoint = bge_m3_config.get("endpoint_url")
config = bge_m3_config.get("parameters", {})
is_local_mode = config.get("is_local_mode", True)
external_endpoint = config.get("external_endpoint")
logger.info(f"Found BGE-M3 in registry: endpoint_url={endpoint}, is_local_mode={is_local_mode}, external_endpoint={external_endpoint}")
if endpoint:
logger.info(f"Using BGE-M3 endpoint from model registry (is_local_mode={is_local_mode}): {endpoint}")
return endpoint
else:
logger.warning(f"BGE-M3 found in registry but endpoint_url is None/empty. Full config: {bge_m3_config}")
else:
available_models = list(model_service.model_registry.keys())
logger.debug(f"BGE-M3 not found in model registry during init (expected on first startup). Available models: {available_models}")
except Exception as e:
logger.debug(f"Model registry not yet available during startup (will be populated after config sync): {e}")
# Fall back to Settings fields (environment variables or .env file)
is_local_mode = getattr(settings, 'bge_m3_local_mode', True)
external_endpoint = getattr(settings, 'bge_m3_external_endpoint', None)
if not is_local_mode and external_endpoint:
logger.info(f"Using external BGE-M3 endpoint from settings: {external_endpoint}")
return external_endpoint
# Default to local endpoint
local_endpoint = getattr(
settings,
'embedding_endpoint',
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
)
logger.info(f"Using local BGE-M3 endpoint: {local_endpoint}")
return local_endpoint
async def update_endpoint_config(self, is_local_mode: bool, external_endpoint: str = None):
"""
Update the embedding endpoint configuration dynamically.
This allows switching between local and external endpoints without restart.
"""
if is_local_mode:
self.embedding_endpoint = getattr(
settings,
'embedding_endpoint',
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
)
else:
if external_endpoint:
self.embedding_endpoint = external_endpoint
else:
raise ValueError("External endpoint must be provided when not in local mode")
logger.info(f"BGE-M3 endpoint updated to: {self.embedding_endpoint}")
logger.info(f"Mode: {'Local GT Edge' if is_local_mode else 'External API'}")
def refresh_endpoint_from_registry(self):
"""
Refresh the embedding endpoint from the model registry.
Called by config sync when BGE-M3 configuration changes.
"""
logger.info(f"Refreshing embedding endpoint - current: {self.embedding_endpoint}")
new_endpoint = self._get_embedding_endpoint()
if new_endpoint != self.embedding_endpoint:
logger.info(f"Refreshing BGE-M3 endpoint from {self.embedding_endpoint} to {new_endpoint}")
self.embedding_endpoint = new_endpoint
else:
logger.info(f"BGE-M3 endpoint unchanged: {self.embedding_endpoint}")
async def generate_embeddings(
self,
texts: List[str],
instruction: Optional[str] = None,
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings for texts using BGE-M3 - STATELESS operation.
Args:
texts: List of texts to embed (will be cleared from memory)
instruction: Optional instruction for query vs document embeddings
tenant_id: Tenant ID for audit logging (not stored with data)
request_id: Request ID for tracing
Returns:
List of embedding vectors (immediately returned, not stored)
"""
try:
# Validate input
if not texts:
return []
if len(texts) > self.max_batch_size:
# Process in batches
return await self._batch_process_embeddings(
texts, instruction, tenant_id, request_id
)
# Prepare request
request_data = {
"model": self.model_name,
"input": texts,
"encoding_format": "float",
"dimensions": self.embedding_dimensions
}
# Add instruction if provided (for query vs document distinction)
if instruction:
request_data["instruction"] = instruction
# Add metadata for audit (not stored with embeddings)
metadata = {
"tenant_id": tenant_id,
"request_id": request_id,
"text_count": len(texts),
# Hash for deduplication without storing content
"content_hash": hashlib.sha256(
"".join(texts).encode()
).hexdigest()[:16]
}
# Call vLLM service - NO FALLBACKS
embeddings = await self._call_embedding_service(request_data, metadata)
# Clear texts from memory immediately
del texts
gc.collect()
# Validate response
if not embeddings or len(embeddings) == 0:
raise ValueError("No embeddings returned from service")
# Normalize if needed
if self._should_normalize():
embeddings = self._normalize_embeddings(embeddings)
logger.info(
f"Generated {len(embeddings)} embeddings (STATELESS) "
f"for tenant {tenant_id}"
)
# Return immediately - no storage
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
# Ensure memory is cleared even on error
gc.collect()
raise
finally:
# Always ensure memory cleanup
gc.collect()
async def _batch_process_embeddings(
self,
texts: List[str],
instruction: Optional[str],
tenant_id: str,
request_id: str
) -> List[List[float]]:
"""Process large text lists in batches using vLLM service"""
all_embeddings = []
for i in range(0, len(texts), self.max_batch_size):
batch = texts[i:i + self.max_batch_size]
# Prepare request for this batch
request_data = {
"model": self.model_name,
"input": batch,
"encoding_format": "float",
"dimensions": self.embedding_dimensions
}
if instruction:
request_data["instruction"] = instruction
metadata = {
"tenant_id": tenant_id,
"request_id": f"{request_id}_batch_{i}",
"text_count": len(batch),
"content_hash": hashlib.sha256(
"".join(batch).encode()
).hexdigest()[:16]
}
batch_embeddings = await self._call_embedding_service(request_data, metadata)
all_embeddings.extend(batch_embeddings)
# Clear batch from memory
del batch
gc.collect()
return all_embeddings
async def _call_embedding_service(
self,
request_data: Dict[str, Any],
metadata: Dict[str, Any]
) -> List[List[float]]:
"""Call internal GPU cluster embedding service"""
async with aiohttp.ClientSession() as session:
try:
# Add capability token for authentication
headers = {
"Content-Type": "application/json",
"X-Tenant-ID": metadata.get("tenant_id", ""),
"X-Request-ID": metadata.get("request_id", ""),
# Authorization will be added by Resource Cluster
}
async with session.post(
self.embedding_endpoint,
json=request_data,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.request_timeout)
) as response:
if response.status != 200:
error_text = await response.text()
raise ValueError(
f"Embedding service error: {response.status} - {error_text}"
)
result = await response.json()
# Extract embeddings from response
if "data" in result:
embeddings = [item["embedding"] for item in result["data"]]
elif "embeddings" in result:
embeddings = result["embeddings"]
else:
raise ValueError("Invalid embedding service response format")
return embeddings
except asyncio.TimeoutError:
raise ValueError(f"Embedding service timeout after {self.request_timeout}s")
except Exception as e:
logger.error(f"Error calling embedding service: {e}")
raise
def _should_normalize(self) -> bool:
"""Check if embeddings should be normalized"""
# BGE-M3 embeddings are typically normalized for similarity search
return True
def _normalize_embeddings(
self,
embeddings: List[List[float]]
) -> List[List[float]]:
"""Normalize embedding vectors to unit length"""
normalized = []
for embedding in embeddings:
# Simple normalization without numpy (for now)
import math
# Calculate norm
norm = math.sqrt(sum(x * x for x in embedding))
if norm > 0:
normalized_vec = [x / norm for x in embedding]
else:
normalized_vec = embedding[:]
normalized.append(normalized_vec)
return normalized
async def generate_query_embeddings(
self,
queries: List[str],
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings specifically for queries.
BGE-M3 can use different instructions for queries vs documents.
"""
# For BGE-M3, queries can use a specific instruction
instruction = "Represent this sentence for searching relevant passages: "
return await self.generate_embeddings(
queries, instruction, tenant_id, request_id
)
async def generate_document_embeddings(
self,
documents: List[str],
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings specifically for documents.
BGE-M3 can use different instructions for documents vs queries.
"""
# For BGE-M3, documents typically don't need special instruction
return await self.generate_embeddings(
documents, None, tenant_id, request_id
)
async def validate_texts(
self,
texts: List[str]
) -> Dict[str, Any]:
"""
Validate texts before embedding - no content stored.
Args:
texts: List of texts to validate
Returns:
Validation result with any warnings
"""
validation = {
"valid": True,
"warnings": [],
"errors": [],
"stats": {
"total_texts": len(texts),
"max_length": 0,
"avg_length": 0
}
}
if not texts:
validation["valid"] = False
validation["errors"].append("No texts provided")
return validation
# Check text lengths
lengths = [len(text) for text in texts]
validation["stats"]["max_length"] = max(lengths)
validation["stats"]["avg_length"] = sum(lengths) // len(lengths)
# BGE-M3 max sequence length check (approximate)
max_chars = self.max_sequence_length * 4 # Rough char to token ratio
for i, length in enumerate(lengths):
if length > max_chars:
validation["warnings"].append(
f"Text {i} may exceed model's max sequence length"
)
elif length == 0:
validation["errors"].append(f"Text {i} is empty")
validation["valid"] = False
# Batch size check
if len(texts) > self.max_batch_size * 10:
validation["warnings"].append(
f"Large batch ({len(texts)} texts) will be processed in chunks"
)
return validation
async def check_health(self) -> Dict[str, Any]:
"""Check embedding backend health - no user data exposed"""
try:
# Test connection to vLLM service
test_text = ["Health check test"]
test_embeddings = await self.generate_embeddings(
test_text,
tenant_id="health_check",
request_id="health_check"
)
health_status = {
"status": "healthy",
"model": self.model_name,
"dimensions": self.embedding_dimensions,
"max_batch_size": self.max_batch_size,
"max_sequence_length": self.max_sequence_length,
"endpoint": self.embedding_endpoint,
"stateless": True,
"memory_cleared": True,
"vllm_service_connected": len(test_embeddings) > 0
}
except Exception as e:
health_status = {
"status": "unhealthy",
"error": str(e),
"model": self.model_name,
"endpoint": self.embedding_endpoint
}
return health_status
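A minimal usage sketch for the backend above (illustrative; assumes the configured BGE-M3 endpoint is reachable, and the tenant and request IDs are arbitrary examples):

import asyncio

from app.core.backends.embedding_backend import EmbeddingBackend  # assumed import path


async def main() -> None:
    backend = EmbeddingBackend()
    docs = ["GT 2.0 stores no user content.", "Embeddings are generated on demand."]
    doc_vecs = await backend.generate_document_embeddings(docs, tenant_id="acme", request_id="req-1")
    query_vecs = await backend.generate_query_embeddings(
        ["how is user data handled?"], tenant_id="acme", request_id="req-2"
    )
    print(len(doc_vecs), len(doc_vecs[0]))   # 2 vectors of 1024 dimensions each


asyncio.run(main())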

View File

@@ -0,0 +1,780 @@
"""
Groq Cloud LLM Proxy Backend
Provides high-availability LLM inference through Groq Cloud with:
- HAProxy load balancing across multiple endpoints
- Automatic failover handled by HAProxy
- Token usage tracking and cost calculation
- Streaming response support
- Circuit breaker pattern for enhanced reliability
"""
import asyncio
import json
import os
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
import httpx
try:
from groq import AsyncGroq
GROQ_AVAILABLE = True
except ImportError:
# Groq not available in development mode
AsyncGroq = None
GROQ_AVAILABLE = False
import logging
from app.core.config import get_settings, get_model_configs
from app.services.model_service import get_model_service
logger = logging.getLogger(__name__)
settings = get_settings()
# Groq Compound tool pricing (per request/execution)
# Source: https://groq.com/pricing (Dec 2, 2025)
COMPOUND_TOOL_PRICES = {
# Web Search variants
"search": 0.008, # API returns "search" for web search
"web_search": 0.008, # $8 per 1K = $0.008 per request (Advanced Search)
"advanced_search": 0.008, # $8 per 1K requests
"basic_search": 0.005, # $5 per 1K requests
# Other tools
"visit_website": 0.001, # $1 per 1K requests
"python": 0.00005, # API returns "python" for code execution
"code_interpreter": 0.00005, # Alternative API identifier
"code_execution": 0.00005, # Alias for backwards compatibility
"browser_automation": 0.00002, # $0.08/hr ≈ $0.00002 per execution
}
# Model pricing per million tokens (input/output)
# Source: https://groq.com/pricing (Dec 2, 2025)
GROQ_MODEL_PRICES = {
"llama-3.3-70b-versatile": {"input": 0.59, "output": 0.79},
"llama-3.1-8b-instant": {"input": 0.05, "output": 0.08},
"llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
"meta-llama/llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
"llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
"meta-llama/llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
"llama-guard-4-12b": {"input": 0.20, "output": 0.20},
"meta-llama/llama-guard-4-12b": {"input": 0.20, "output": 0.20},
"gpt-oss-120b": {"input": 0.15, "output": 0.60},
"openai/gpt-oss-120b": {"input": 0.15, "output": 0.60},
"gpt-oss-20b": {"input": 0.075, "output": 0.30},
"openai/gpt-oss-20b": {"input": 0.075, "output": 0.30},
"kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
"moonshotai/kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
"qwen3-32b": {"input": 0.29, "output": 0.59},
# Compound models - 50/50 blended pricing from underlying models
# compound: GPT-OSS-120B ($0.15/$0.60) + Llama 4 Scout ($0.11/$0.34) = $0.13/$0.47
"compound": {"input": 0.13, "output": 0.47},
"groq/compound": {"input": 0.13, "output": 0.47},
"compound-beta": {"input": 0.13, "output": 0.47},
# compound-mini: GPT-OSS-120B ($0.15/$0.60) + Llama 3.3 70B ($0.59/$0.79) = $0.37/$0.695
"compound-mini": {"input": 0.37, "output": 0.695},
"groq/compound-mini": {"input": 0.37, "output": 0.695},
"compound-mini-beta": {"input": 0.37, "output": 0.695},
}
class GroqProxyBackend:
"""LLM inference via Groq Cloud with HAProxy load balancing"""
def __init__(self):
self.settings = get_settings()
self.client = None
self.usage_metrics = {}
self.circuit_breaker_status = {}
self._initialize_client()
def _initialize_client(self):
"""Initialize Groq client to use HAProxy load balancer"""
if not GROQ_AVAILABLE:
logger.warning("Groq client not available - running in development mode")
return
if self.settings.groq_api_key:
# Use HAProxy load balancer instead of direct Groq API
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
# Initialize client with HAProxy endpoint
self.client = AsyncGroq(
api_key=self.settings.groq_api_key,
base_url=haproxy_endpoint,
timeout=httpx.Timeout(30.0), # Increased timeout for load balancing
max_retries=1 # Let HAProxy handle retries
)
# Initialize circuit breaker
self.circuit_breaker_status = {
"state": "closed", # closed, open, half_open
"failure_count": 0,
"last_failure_time": None,
"failure_threshold": 5,
"recovery_timeout": 60 # seconds
}
logger.info(f"Initialized Groq client with HAProxy endpoint: {haproxy_endpoint}")
async def execute_inference(
self,
prompt: str,
model: str = "llama-3.1-70b-versatile",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None
) -> Dict[str, Any]:
"""Execute LLM inference with HAProxy load balancing and circuit breaker"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - service temporarily unavailable")
# Validate model and get configuration
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
if not model_config:
# Try to get from model service registry
model_service = get_model_service(tenant_id)
model_info = await model_service.get_model(model)
if not model_info:
raise ValueError(f"Unsupported model: {model}")
model_config = {
"max_tokens": model_info["performance"]["max_tokens"],
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
"supports_streaming": model_info["capabilities"].get("streaming", False)
}
# Apply token limits
max_tokens = min(max_tokens, model_config["max_tokens"])
# Prepare messages
messages = [
{"role": "user", "content": prompt}
]
try:
# Get tenant-specific API key
if not tenant_id:
raise ValueError("tenant_id is required for Groq inference")
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
start_time = time.time()
if stream:
return await self._stream_inference(
messages, model, temperature, max_tokens, user_id, tenant_id, client
)
else:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False
)
# Track successful usage
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
response.usage.total_tokens if response.usage else 0,
latency, model_config["cost_per_1k_tokens"]
)
# Track in model service
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=True,
latency_ms=latency
)
# Reset circuit breaker on success
await self._record_success()
return {
"content": response.choices[0].message.content,
"model": model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
"cost_cents": self._calculate_cost(
response.usage.total_tokens if response.usage else 0,
model_config["cost_per_1k_tokens"]
)
},
"latency_ms": latency,
"load_balanced": True,
"haproxy_backend": "groq_general_backend"
}
except Exception as e:
logger.error(f"HAProxy Groq inference failed: {e}")
# Track failure in model service (resolve it here: it is not yet bound if the
# failure happened before the success path assigned it above)
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=False
)
# Record failure for circuit breaker
await self._record_failure()
# Re-raise the exception - no client-side fallback needed
# HAProxy handles all failover logic
raise Exception(f"Groq inference failed (via HAProxy): {str(e)}")
async def _stream_inference(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
user_id: str,
tenant_id: str,
client: AsyncGroq = None
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses"""
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
start_time = time.time()
total_tokens = 0
try:
# Use provided client or get tenant-specific client
if not client:
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
stream = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
total_tokens += len(content.split()) # Approximate token count
# Yield SSE formatted data
yield f"data: {json.dumps({'content': content})}\n\n"
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
total_tokens, latency,
model_config["cost_per_1k_tokens"]
)
# Send completion signal
yield f"data: {json.dumps({'done': True})}\n\n"
except Exception as e:
logger.error(f"Streaming inference error: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
async def check_health(self) -> Dict[str, Any]:
"""Check health of HAProxy load balancer and circuit breaker status"""
try:
# Check HAProxy health via stats endpoint
haproxy_stats_url = self.settings.haproxy_stats_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats"
async with httpx.AsyncClient() as client:
response = await client.get(
haproxy_stats_url,
timeout=5.0,
auth=("admin", "gt2_haproxy_stats_password")
)
if response.status_code == 200:
# Parse HAProxy stats (simplified)
stats_healthy = "UP" in response.text
return {
"haproxy_load_balancer": {
"healthy": stats_healthy,
"stats_accessible": True,
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"],
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat() if self.circuit_breaker_status["last_failure_time"] else None
},
"groq_endpoints": {
"managed_by": "haproxy",
"failover_handled_by": "haproxy"
}
}
else:
return {
"haproxy_load_balancer": {
"healthy": False,
"error": f"Stats endpoint returned {response.status_code}",
"last_check": datetime.utcnow().isoformat()
}
}
except Exception as e:
return {
"haproxy_load_balancer": {
"healthy": False,
"error": str(e),
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"]
}
}
async def _is_circuit_closed(self) -> bool:
"""Check if circuit breaker allows requests"""
if self.circuit_breaker_status["state"] == "closed":
return True
if self.circuit_breaker_status["state"] == "open":
# Check if recovery timeout has passed
if self.circuit_breaker_status["last_failure_time"]:
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
# Move to half-open state
self.circuit_breaker_status["state"] = "half_open"
logger.info("Circuit breaker moved to half-open state")
return True
return False
if self.circuit_breaker_status["state"] == "half_open":
# Allow limited requests in half-open state
return True
return False
async def _record_success(self):
"""Record successful request for circuit breaker"""
if self.circuit_breaker_status["state"] == "half_open":
# Success in half-open state closes the circuit
self.circuit_breaker_status["state"] = "closed"
self.circuit_breaker_status["failure_count"] = 0
logger.info("Circuit breaker closed after successful request")
# Reset failure count on any success
self.circuit_breaker_status["failure_count"] = 0
async def _record_failure(self):
"""Record failed request for circuit breaker"""
self.circuit_breaker_status["failure_count"] += 1
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
self.circuit_breaker_status["state"] = "open"
logger.warning(f"Circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
async def _track_usage(
self,
user_id: str,
tenant_id: str,
model: str,
tokens: int,
latency: float,
cost_per_1k: float
):
"""Track usage metrics for billing and monitoring"""
# Create usage key
usage_key = f"{tenant_id}:{user_id}:{model}"
# Initialize metrics if not exists
if usage_key not in self.usage_metrics:
self.usage_metrics[usage_key] = {
"total_tokens": 0,
"total_requests": 0,
"total_cost_cents": 0,
"average_latency": 0
}
# Update metrics
metrics = self.usage_metrics[usage_key]
metrics["total_tokens"] += tokens
metrics["total_requests"] += 1
metrics["total_cost_cents"] += self._calculate_cost(tokens, cost_per_1k)
# Update average latency
prev_avg = metrics["average_latency"]
prev_count = metrics["total_requests"] - 1
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
# Log high-level metrics
if metrics["total_requests"] % 100 == 0:
logger.info(f"Usage milestone for {usage_key}: {metrics}")
def _calculate_cost(self, tokens: int, cost_per_1k: float) -> int:
"""Calculate cost in cents"""
return int((tokens / 1000) * cost_per_1k * 100)
def _calculate_compound_cost(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Calculate detailed cost breakdown for Groq Compound responses.
Compound API returns usage_breakdown with per-model token counts
and executed_tools list showing which tools were called.
Returns:
Dict with total cost in dollars and detailed breakdown
"""
total_cost = 0.0
breakdown = {"models": [], "tools": [], "total_cost_dollars": 0.0, "total_cost_cents": 0}
# Parse usage_breakdown for per-model token costs
usage_breakdown = response_data.get("usage_breakdown", {})
models_usage = usage_breakdown.get("models", [])
for model_usage in models_usage:
model_name = model_usage.get("model", "")
usage = model_usage.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
# Get model pricing (try multiple name formats)
model_prices = GROQ_MODEL_PRICES.get(model_name)
if not model_prices:
# Try without provider prefix
short_name = model_name.split("/")[-1] if "/" in model_name else model_name
model_prices = GROQ_MODEL_PRICES.get(short_name, {"input": 0.15, "output": 0.60})
# Calculate cost per million tokens
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
model_total = input_cost + output_cost
breakdown["models"].append({
"model": model_name,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"input_cost_dollars": round(input_cost, 6),
"output_cost_dollars": round(output_cost, 6),
"total_cost_dollars": round(model_total, 6)
})
total_cost += model_total
# Parse executed_tools for tool costs
executed_tools = response_data.get("executed_tools", [])
for tool in executed_tools:
# Handle both string and dict formats
tool_name = tool if isinstance(tool, str) else tool.get("name", "unknown")
tool_cost = COMPOUND_TOOL_PRICES.get(tool_name.lower(), 0.008) # Default to advanced search
breakdown["tools"].append({
"tool": tool_name,
"cost_dollars": round(tool_cost, 6)
})
total_cost += tool_cost
breakdown["total_cost_dollars"] = round(total_cost, 6)
breakdown["total_cost_cents"] = int(total_cost * 100)
return breakdown
def _is_compound_model(self, model: str) -> bool:
"""Check if model is a Groq Compound model"""
model_lower = model.lower()
return "compound" in model_lower or model_lower.startswith("groq/compound")
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available Groq models with their configurations"""
models = []
model_configs = get_model_configs()
for model_id, config in model_configs.get("groq", {}).items():
models.append({
"id": model_id,
"name": model_id.replace("-", " ").title(),
"provider": "groq",
"max_tokens": config["max_tokens"],
"cost_per_1k_tokens": config["cost_per_1k_tokens"],
"supports_streaming": config["supports_streaming"],
"supports_function_calling": config["supports_function_calling"]
})
return models
async def execute_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str = "llama-3.1-70b-versatile",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
"""Execute LLM inference using messages format (conversation style)"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - service temporarily unavailable")
# Validate model and get configuration
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
if not model_config:
# Try to get from model service registry
model_service = get_model_service(tenant_id)
model_info = await model_service.get_model(model)
if not model_info:
raise ValueError(f"Unsupported model: {model}")
model_config = {
"max_tokens": model_info["performance"]["max_tokens"],
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
"supports_streaming": model_info["capabilities"].get("streaming", False)
}
# Apply token limits
max_tokens = min(max_tokens, model_config["max_tokens"])
try:
# Get tenant-specific API key
if not tenant_id:
raise ValueError("tenant_id is required for Groq inference")
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
start_time = time.time()
# Translate GT 2.0 "agent" role to OpenAI/Groq "assistant" for external API compatibility
# Use dictionary unpacking to preserve ALL fields including tool_call_id
external_messages = []
for msg in messages:
external_msg = {
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
}
external_messages.append(external_msg)
if stream:
return await self._stream_inference_with_messages(
external_messages, model, temperature, max_tokens, user_id, tenant_id, client
)
else:
# Prepare request parameters
request_params = {
"model": model,
"messages": external_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": False
}
# Add tools if provided
if tools:
request_params["tools"] = tools
if tool_choice:
request_params["tool_choice"] = tool_choice
# Debug: Log messages being sent to Groq
logger.info(f"🔧 Sending {len(external_messages)} messages to Groq API")
for i, msg in enumerate(external_messages):
if msg.get("role") == "tool":
logger.info(f"🔧 Groq Message {i}: role=tool, tool_call_id={msg.get('tool_call_id')}")
else:
logger.info(f"🔧 Groq Message {i}: role={msg.get('role')}, has_tool_calls={bool(msg.get('tool_calls'))}")
response = await client.chat.completions.create(**request_params)
# Track successful usage
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
response.usage.total_tokens if response.usage else 0,
latency, model_config["cost_per_1k_tokens"]
)
# Track in model service
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=True,
latency_ms=latency
)
# Reset circuit breaker on success
await self._record_success()
# Build base response
result = {
"content": response.choices[0].message.content,
"model": model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
"cost_cents": self._calculate_cost(
response.usage.total_tokens if response.usage else 0,
model_config["cost_per_1k_tokens"]
)
},
"latency_ms": latency,
"load_balanced": True,
"haproxy_backend": "groq_general_backend"
}
# For Compound models, extract and calculate detailed cost breakdown
if self._is_compound_model(model):
# Convert response to dict for processing
response_dict = response.model_dump() if hasattr(response, 'model_dump') else {}
# Extract usage_breakdown and executed_tools if present
usage_breakdown = getattr(response, 'usage_breakdown', None)
executed_tools = getattr(response, 'executed_tools', None)
if usage_breakdown or executed_tools:
compound_data = {
"usage_breakdown": usage_breakdown if isinstance(usage_breakdown, dict) else {},
"executed_tools": executed_tools if isinstance(executed_tools, list) else []
}
# Calculate detailed cost breakdown
cost_breakdown = self._calculate_compound_cost(compound_data)
# Add compound-specific data to response
result["usage_breakdown"] = compound_data.get("usage_breakdown", {})
result["executed_tools"] = compound_data.get("executed_tools", [])
result["cost_breakdown"] = cost_breakdown
# Update cost_cents with accurate compound calculation
if cost_breakdown["total_cost_cents"] > 0:
result["usage"]["cost_cents"] = cost_breakdown["total_cost_cents"]
logger.info(f"Compound model cost breakdown: {cost_breakdown}")
return result
except Exception as e:
logger.error(f"HAProxy Groq inference with messages failed: {e}")
# Track failure in model service (resolve it here: it is not yet bound if the
# failure happened before the success path assigned it above)
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=False
)
# Record failure for circuit breaker
await self._record_failure()
# Re-raise the exception
raise Exception(f"Groq inference with messages failed (via HAProxy): {str(e)}")
async def _stream_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
user_id: str,
tenant_id: str,
client: AsyncGroq = None
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses using messages format"""
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
start_time = time.time()
total_tokens = 0
try:
# Use provided client or get tenant-specific client
if not client:
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
stream = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
total_tokens += len(content.split()) # Approximate token count
# Yield just the content (SSE formatting handled by caller)
yield content
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
total_tokens, latency,
model_config["cost_per_1k_tokens"] if model_config else 0.0
)
except Exception as e:
logger.error(f"Streaming inference with messages error: {e}")
raise e
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string from X-Tenant-ID header
Returns:
Decrypted Groq API key
Raises:
ValueError: If no API key configured (results in HTTP 503 to client)
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
raise ValueError(str(e))
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
def _get_client(self, api_key: str) -> AsyncGroq:
"""Get Groq client with specified API key"""
if not GROQ_AVAILABLE:
raise Exception("Groq client not available in development mode")
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
return AsyncGroq(
api_key=api_key,
base_url=haproxy_endpoint,
timeout=httpx.Timeout(30.0),
max_retries=1
)
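A minimal usage sketch for the proxy above (illustrative; assumes the tenant already has a Groq API key configured in the Control Panel, and the prompt, model, tenant, and user values are arbitrary examples):

import asyncio

from app.core.backends.groq_proxy import GroqProxyBackend  # assumed import path


async def main() -> None:
    backend = GroqProxyBackend()
    result = await backend.execute_inference(
        prompt="Summarize the GT 2.0 stateless design in one sentence.",
        model="llama-3.3-70b-versatile",   # priced in GROQ_MODEL_PRICES above
        tenant_id="acme.example.com",      # the Groq key is resolved per tenant via the Control Panel
        user_id="user-123",
    )
    print(result["content"], result["usage"]["cost_cents"])


asyncio.run(main())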

View File

@@ -0,0 +1,407 @@
"""
NVIDIA NIM LLM Proxy Backend
Provides LLM inference through NVIDIA NIM with:
- OpenAI-compatible API format (build.nvidia.com)
- Token usage tracking and cost calculation
- Streaming response support
- Circuit breaker pattern for enhanced reliability
"""
import json
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
import httpx
import logging
from app.core.config import get_settings
logger = logging.getLogger(__name__)
# NVIDIA NIM Model pricing per million tokens (input/output)
# Source: build.nvidia.com (Dec 2025 pricing estimates)
# Note: Actual pricing may vary - check build.nvidia.com for current rates
NVIDIA_MODEL_PRICES = {
# Llama Nemotron family
"nvidia/llama-3.1-nemotron-ultra-253b-v1": {"input": 2.0, "output": 6.0},
"nvidia/llama-3.1-nemotron-super-49b-v1": {"input": 0.5, "output": 1.5},
"nvidia/llama-3.1-nemotron-nano-8b-v1": {"input": 0.1, "output": 0.3},
# Standard Llama models via NIM
"meta/llama-3.1-8b-instruct": {"input": 0.1, "output": 0.3},
"meta/llama-3.1-70b-instruct": {"input": 0.5, "output": 1.0},
"meta/llama-3.1-405b-instruct": {"input": 2.0, "output": 6.0},
# Mistral models
"mistralai/mistral-7b-instruct-v0.3": {"input": 0.1, "output": 0.2},
"mistralai/mixtral-8x7b-instruct-v0.1": {"input": 0.3, "output": 0.6},
# Default fallback
"default": {"input": 0.5, "output": 1.5},
}
class NvidiaProxyBackend:
"""LLM inference via NVIDIA NIM with OpenAI-compatible API"""
def __init__(self):
self.settings = get_settings()
self.base_url = getattr(self.settings, 'nvidia_nim_endpoint', None) or "https://integrate.api.nvidia.com/v1"
self.usage_metrics = {}
self.circuit_breaker_status = {
"state": "closed", # closed, open, half_open
"failure_count": 0,
"last_failure_time": None,
"failure_threshold": 5,
"recovery_timeout": 60 # seconds
}
logger.info(f"Initialized NVIDIA NIM backend with endpoint: {self.base_url}")
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string from X-Tenant-ID header
Returns:
Decrypted NVIDIA API key
Raises:
ValueError: If no API key configured (results in HTTP 503 to client)
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
raise ValueError(str(e))
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
def _get_client(self, api_key: str) -> httpx.AsyncClient:
"""Get configured HTTP client for NVIDIA NIM API"""
return httpx.AsyncClient(
base_url=self.base_url,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
timeout=httpx.Timeout(120.0) # Longer timeout for large models
)
async def execute_inference(
self,
prompt: str,
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None
) -> Dict[str, Any]:
"""Execute LLM inference with simple prompt"""
messages = [{"role": "user", "content": prompt}]
return await self.execute_inference_with_messages(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
user_id=user_id,
tenant_id=tenant_id
)
async def execute_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
"""Execute LLM inference using messages format (conversation style)"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - NVIDIA NIM service temporarily unavailable")
if not tenant_id:
raise ValueError("tenant_id is required for NVIDIA NIM inference")
try:
api_key = await self._get_tenant_api_key(tenant_id)
# Translate GT 2.0 "agent" role to OpenAI "assistant" for external API compatibility
external_messages = []
for msg in messages:
external_msg = {
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
}
external_messages.append(external_msg)
# Build request payload
request_data = {
"model": model,
"messages": external_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream
}
# Add tools if provided
if tools:
request_data["tools"] = tools
if tool_choice:
request_data["tool_choice"] = tool_choice
start_time = time.time()
async with self._get_client(api_key) as client:
if stream:
# Return generator for streaming
return self._stream_inference_with_messages(
client, request_data, user_id, tenant_id, model
)
# Non-streaming request
response = await client.post("/chat/completions", json=request_data)
response.raise_for_status()
data = response.json()
latency = (time.time() - start_time) * 1000
# Calculate cost
usage = data.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
cost_cents = int((input_cost + output_cost) * 100)
# Track usage
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
# Reset circuit breaker on success
await self._record_success()
# Build response
result = {
"content": data["choices"][0]["message"]["content"],
"model": model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost_cents": cost_cents
},
"latency_ms": latency,
"provider": "nvidia"
}
# Include tool calls if present
message = data["choices"][0]["message"]
if message.get("tool_calls"):
result["tool_calls"] = message["tool_calls"]
return result
except httpx.HTTPStatusError as e:
logger.error(f"NVIDIA NIM API error: {e.response.status_code} - {e.response.text}")
await self._record_failure()
raise Exception(f"NVIDIA NIM inference failed: HTTP {e.response.status_code}")
except Exception as e:
logger.error(f"NVIDIA NIM inference failed: {e}")
await self._record_failure()
raise Exception(f"NVIDIA NIM inference failed: {str(e)}")
async def _stream_inference_with_messages(
self,
client: httpx.AsyncClient,
request_data: Dict[str, Any],
user_id: str,
tenant_id: str,
model: str
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses"""
start_time = time.time()
total_tokens = 0
try:
async with client.stream("POST", "/chat/completions", json=request_data) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
if chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("content"):
content = chunk["choices"][0]["delta"]["content"]
total_tokens += len(content.split()) # Approximate
yield content
except json.JSONDecodeError:
continue
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
cost_cents = int((total_tokens / 1_000_000) * model_prices["output"] * 100)
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
await self._record_success()
except Exception as e:
logger.error(f"NVIDIA NIM streaming error: {e}")
await self._record_failure()
raise
finally:
# Close the client that execute_inference_with_messages opened for this stream
await client.aclose()
async def check_health(self) -> Dict[str, Any]:
"""Check health of NVIDIA NIM backend and circuit breaker status"""
return {
"nvidia_nim": {
"endpoint": self.base_url,
"status": "available" if self.circuit_breaker_status["state"] == "closed" else "degraded",
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"],
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat()
if self.circuit_breaker_status["last_failure_time"] else None
}
}
async def _is_circuit_closed(self) -> bool:
"""Check if circuit breaker allows requests"""
if self.circuit_breaker_status["state"] == "closed":
return True
if self.circuit_breaker_status["state"] == "open":
# Check if recovery timeout has passed
if self.circuit_breaker_status["last_failure_time"]:
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
# Move to half-open state
self.circuit_breaker_status["state"] = "half_open"
logger.info("NVIDIA NIM circuit breaker moved to half-open state")
return True
return False
if self.circuit_breaker_status["state"] == "half_open":
# Allow limited requests in half-open state
return True
return False
async def _record_success(self):
"""Record successful request for circuit breaker"""
if self.circuit_breaker_status["state"] == "half_open":
# Success in half-open state closes the circuit
self.circuit_breaker_status["state"] = "closed"
self.circuit_breaker_status["failure_count"] = 0
logger.info("NVIDIA NIM circuit breaker closed after successful request")
# Reset failure count on any success
self.circuit_breaker_status["failure_count"] = 0
async def _record_failure(self):
"""Record failed request for circuit breaker"""
self.circuit_breaker_status["failure_count"] += 1
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
self.circuit_breaker_status["state"] = "open"
logger.warning(f"NVIDIA NIM circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
async def _track_usage(
self,
user_id: str,
tenant_id: str,
model: str,
tokens: int,
latency: float,
cost_cents: int
):
"""Track usage metrics for billing and monitoring"""
# Create usage key
usage_key = f"{tenant_id}:{user_id}:{model}"
# Initialize metrics if not exists
if usage_key not in self.usage_metrics:
self.usage_metrics[usage_key] = {
"total_tokens": 0,
"total_requests": 0,
"total_cost_cents": 0,
"average_latency": 0
}
# Update metrics
metrics = self.usage_metrics[usage_key]
metrics["total_tokens"] += tokens
metrics["total_requests"] += 1
metrics["total_cost_cents"] += cost_cents
# Update average latency
prev_avg = metrics["average_latency"]
prev_count = metrics["total_requests"] - 1
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
# Log high-level metrics periodically
if metrics["total_requests"] % 100 == 0:
logger.info(f"NVIDIA NIM usage milestone for {usage_key}: {metrics}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int, model: str) -> int:
"""Calculate cost in cents based on token usage"""
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
return int((input_cost + output_cost) * 100)
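# Worked example with hypothetical prices of $0.35 per 1M input tokens and
# $1.00 per 1M output tokens (actual values live in NVIDIA_MODEL_PRICES):
#
#     _calculate_cost(50_000, 20_000, model)
#     = int(((50_000 / 1_000_000) * 0.35 + (20_000 / 1_000_000) * 1.00) * 100)
#     = int((0.0175 + 0.02) * 100) = int(3.75) = 3 cents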
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available NVIDIA NIM models with their configurations"""
models = []
for model_id, prices in NVIDIA_MODEL_PRICES.items():
if model_id == "default":
continue
models.append({
"id": model_id,
"name": model_id.split("/")[-1].replace("-", " ").title(),
"provider": "nvidia",
"max_tokens": 4096, # Default for most NIM models
"cost_per_1k_input": prices["input"],
"cost_per_1k_output": prices["output"],
"supports_streaming": True,
"supports_function_calling": True
})
return models
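# Minimal usage sketch for the inference methods above. The enclosing client
# class and its construction are defined earlier in this module, so the
# instance is taken as a parameter here rather than named.
async def example_nim_chat(nim_client: Any, tenant_id: str, user_id: str) -> str:
    """Illustrative only: send a short conversation and return the reply text."""
    result = await nim_client.execute_inference_with_messages(
        messages=[
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Explain what a circuit breaker does in one sentence."},
        ],
        model="nvidia/llama-3.1-nemotron-super-49b-v1",
        temperature=0.2,
        max_tokens=256,
        user_id=user_id,
        tenant_id=tenant_id,
    )
    return result["content"]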

View File

@@ -0,0 +1,457 @@
"""
Capability-Based Authentication for GT 2.0 Resource Cluster
Implements JWT capability token verification with:
- Cryptographic signature validation
- Fine-grained resource permissions
- Rate limiting and constraints enforcement
- Tenant isolation validation
- Zero external dependencies
GT 2.0 Security Principles:
- Self-contained: No external auth services
- Stateless: All permissions in JWT token
- Cryptographic: signed-JWT verification (HS256 today, RS256 planned)
- Isolated: Perfect tenant separation
"""
import jwt
import logging
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
from fastapi import HTTPException, Depends, Header
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class CapabilityError(Exception):
"""Capability authentication error"""
pass
class ResourceType(str, Enum):
"""Resource types in GT 2.0"""
LLM = "llm"
EMBEDDING = "embedding"
VECTOR_STORAGE = "vector_storage"
EXTERNAL_SERVICES = "external_services"
ADMIN = "admin"
class ActionType(str, Enum):
"""Action types for resources"""
READ = "read"
WRITE = "write"
EXECUTE = "execute"
ADMIN = "admin"
@dataclass
class Capability:
"""Individual capability definition"""
resource: ResourceType
actions: List[ActionType]
constraints: Dict[str, Any]
expires_at: Optional[datetime] = None
def allows_action(self, action: ActionType) -> bool:
"""Check if capability allows specific action"""
return action in self.actions
def is_expired(self) -> bool:
"""Check if capability is expired"""
if not self.expires_at:
return False
return datetime.now(timezone.utc) > self.expires_at
def check_constraint(self, constraint_name: str, value: Any) -> bool:
"""Check if value satisfies constraint"""
if constraint_name not in self.constraints:
return True # No constraint means allowed
constraint_value = self.constraints[constraint_name]
if constraint_name == "max_tokens":
return value <= constraint_value
elif constraint_name == "allowed_models":
return value in constraint_value
elif constraint_name == "max_requests_per_hour":
# This would be checked separately with rate limiting
return True
elif constraint_name == "allowed_tenants":
return value in constraint_value
return True
@dataclass
class CapabilityToken:
"""Parsed capability token"""
subject: str
tenant_id: str
capabilities: List[Capability]
issued_at: datetime
expires_at: datetime
issuer: str
token_version: str
def has_capability(self, resource: ResourceType, action: ActionType) -> bool:
"""Check if token has specific capability"""
for cap in self.capabilities:
if cap.resource == resource and cap.allows_action(action) and not cap.is_expired():
return True
return False
def get_capability(self, resource: ResourceType) -> Optional[Capability]:
"""Get capability for specific resource"""
for cap in self.capabilities:
if cap.resource == resource and not cap.is_expired():
return cap
return None
def is_expired(self) -> bool:
"""Check if entire token is expired"""
return datetime.now(timezone.utc) > self.expires_at
class CapabilityAuthenticator:
"""
Handles capability token verification and authorization.
Uses JWT tokens with embedded permissions for stateless authentication.
"""
def __init__(self):
self.settings = get_settings()
# In production, this would be loaded from secure storage
# For development, using the secret key
self.secret_key = self.settings.secret_key
self.algorithm = "HS256" # TODO: Upgrade to RS256 with public/private keys
logger.info("Capability authenticator initialized")
async def verify_token(self, token: str) -> CapabilityToken:
"""
Verify and parse capability token.
Args:
token: JWT capability token
Returns:
Parsed capability token
Raises:
CapabilityError: If token is invalid or expired
"""
try:
# Decode JWT token
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
audience="gt2-resource-cluster"
)
# Validate required fields
required_fields = ["sub", "tenant_id", "capabilities", "iat", "exp", "iss"]
for field in required_fields:
if field not in payload:
raise CapabilityError(f"Missing required field: {field}")
# Parse timestamps
issued_at = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
# Check token expiration
if datetime.now(timezone.utc) > expires_at:
raise CapabilityError("Token has expired")
# Parse capabilities
capabilities = []
for cap_data in payload["capabilities"]:
try:
capability = Capability(
resource=ResourceType(cap_data["resource"]),
actions=[ActionType(action) for action in cap_data["actions"]],
constraints=cap_data.get("constraints", {}),
expires_at=datetime.fromtimestamp(
cap_data["expires_at"], tz=timezone.utc
) if cap_data.get("expires_at") else None
)
capabilities.append(capability)
except (KeyError, ValueError) as e:
logger.warning(f"Invalid capability in token: {e}")
# Skip invalid capabilities rather than rejecting entire token
continue
# Create capability token
capability_token = CapabilityToken(
subject=payload["sub"],
tenant_id=payload["tenant_id"],
capabilities=capabilities,
issued_at=issued_at,
expires_at=expires_at,
issuer=payload["iss"],
token_version=payload.get("token_version", "1.0")
)
logger.debug(f"Capability token verified for {capability_token.subject}")
return capability_token
except jwt.ExpiredSignatureError:
raise CapabilityError("Token has expired")
except jwt.InvalidTokenError as e:
raise CapabilityError(f"Invalid token: {e}")
except Exception as e:
logger.error(f"Token verification failed: {e}")
raise CapabilityError(f"Token verification failed: {e}")
async def check_resource_access(
self,
capability_token: CapabilityToken,
resource: ResourceType,
action: ActionType,
constraints: Optional[Dict[str, Any]] = None
) -> bool:
"""
Check if token allows access to resource with specific action.
Args:
capability_token: Verified capability token
resource: Resource type to access
action: Action to perform
constraints: Additional constraints to check
Returns:
True if access is allowed
Raises:
CapabilityError: If access is denied
"""
try:
# Check token expiration
if capability_token.is_expired():
raise CapabilityError("Token has expired")
# Find matching capability
capability = capability_token.get_capability(resource)
if not capability:
raise CapabilityError(f"No capability for resource: {resource}")
# Check action permission
if not capability.allows_action(action):
raise CapabilityError(f"Action {action} not allowed for resource {resource}")
# Check constraints if provided
if constraints:
for constraint_name, value in constraints.items():
if not capability.check_constraint(constraint_name, value):
raise CapabilityError(
f"Constraint violation: {constraint_name} = {value}"
)
return True
except CapabilityError:
raise
except Exception as e:
logger.error(f"Resource access check failed: {e}")
raise CapabilityError(f"Access check failed: {e}")
# Global authenticator instance
capability_authenticator = CapabilityAuthenticator()
async def verify_capability_token(token: str) -> Dict[str, Any]:
"""
Verify capability token and return payload.
Args:
token: JWT capability token
Returns:
Token payload as dictionary
Raises:
CapabilityError: If token is invalid
"""
capability_token = await capability_authenticator.verify_token(token)
return {
"sub": capability_token.subject,
"tenant_id": capability_token.tenant_id,
"capabilities": [
{
"resource": cap.resource.value,
"actions": [action.value for action in cap.actions],
"constraints": cap.constraints
}
for cap in capability_token.capabilities
],
"iat": capability_token.issued_at.timestamp(),
"exp": capability_token.expires_at.timestamp(),
"iss": capability_token.issuer,
"token_version": capability_token.token_version
}
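# Illustrative sketch of minting a token this verifier accepts (development/HS256
# only); the subject, tenant, issuer, and lifetime below are placeholder values.
def mint_example_capability_token(secret_key: str) -> str:
    """Create a short-lived HS256 token carrying a single LLM execute capability."""
    now = int(datetime.now(timezone.utc).timestamp())
    payload = {
        "sub": "dev-user",
        "tenant_id": "dev-tenant",
        "aud": "gt2-resource-cluster",  # must match the audience checked in verify_token
        "iss": "dev-issuer",
        "iat": now,
        "exp": now + 3600,  # one hour
        "token_version": "1.0",
        "capabilities": [
            {
                "resource": ResourceType.LLM.value,
                "actions": [ActionType.EXECUTE.value],
                "constraints": {"max_tokens": 4000},
            }
        ],
    }
    return jwt.encode(payload, secret_key, algorithm="HS256")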
async def get_current_capability(
authorization: str = Header(..., description="Bearer token")
) -> Dict[str, Any]:
"""
FastAPI dependency to get current capability from Authorization header.
Args:
authorization: Authorization header with Bearer token
Returns:
Capability payload
Raises:
HTTPException: If authentication fails
"""
try:
if not authorization.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail="Invalid authorization header format"
)
token = authorization[7:] # Remove "Bearer " prefix
payload = await verify_capability_token(token)
return payload
except CapabilityError as e:
logger.warning(f"Capability authentication failed: {e}")
raise HTTPException(status_code=401, detail=str(e))
except Exception as e:
logger.error(f"Authentication error: {e}")
raise HTTPException(status_code=500, detail="Authentication error")
def require_capability(
resource: ResourceType,
action: ActionType,
constraints: Optional[Dict[str, Any]] = None
):
"""
FastAPI dependency to require specific capability.
Args:
resource: Required resource type
action: Required action type
constraints: Additional constraints to check
Returns:
Dependency function
"""
async def _check_capability(
capability_payload: Dict[str, Any] = Depends(get_current_capability)
) -> Dict[str, Any]:
try:
# Reconstruct capability token from payload
capabilities = []
for cap_data in capability_payload["capabilities"]:
capability = Capability(
resource=ResourceType(cap_data["resource"]),
actions=[ActionType(action) for action in cap_data["actions"]],
constraints=cap_data["constraints"]
)
capabilities.append(capability)
capability_token = CapabilityToken(
subject=capability_payload["sub"],
tenant_id=capability_payload["tenant_id"],
capabilities=capabilities,
issued_at=datetime.fromtimestamp(capability_payload["iat"], tz=timezone.utc),
expires_at=datetime.fromtimestamp(capability_payload["exp"], tz=timezone.utc),
issuer=capability_payload["iss"],
token_version=capability_payload["token_version"]
)
# Check required capability
await capability_authenticator.check_resource_access(
capability_token=capability_token,
resource=resource,
action=action,
constraints=constraints
)
return capability_payload
except CapabilityError as e:
logger.warning(f"Capability check failed: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Capability check error: {e}")
raise HTTPException(status_code=500, detail="Authorization error")
return _check_capability
# Convenience functions for common capability checks
async def require_llm_capability(
capability_payload: Dict[str, Any] = Depends(
require_capability(ResourceType.LLM, ActionType.EXECUTE)
)
) -> Dict[str, Any]:
"""Require LLM execution capability"""
return capability_payload
async def require_embedding_capability(
capability_payload: Dict[str, Any] = Depends(
require_capability(ResourceType.EMBEDDING, ActionType.EXECUTE)
)
) -> Dict[str, Any]:
"""Require embedding generation capability"""
return capability_payload
async def require_admin_capability(
capability_payload: Dict[str, Any] = Depends(
require_capability(ResourceType.ADMIN, ActionType.ADMIN)
)
) -> Dict[str, Any]:
"""Require admin capability"""
return capability_payload
async def verify_capability_token_dependency(
authorization: str = Header(..., description="Bearer token")
) -> Dict[str, Any]:
"""
FastAPI dependency for ChromaDB MCP API that verifies capability token.
Returns token payload with raw_token field for service layer use.
"""
try:
if not authorization.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail="Invalid authorization header format"
)
token = authorization[7:] # Remove "Bearer " prefix
payload = await verify_capability_token(token)
# Add raw token for service layer
payload["raw_token"] = token
return payload
except CapabilityError as e:
logger.warning(f"Capability authentication failed: {e}")
raise HTTPException(status_code=401, detail=str(e))
except Exception as e:
logger.error(f"Authentication error: {e}")
raise HTTPException(status_code=500, detail="Authentication error")

View File

@@ -0,0 +1,293 @@
"""
GT 2.0 Resource Cluster Configuration
Central configuration for the air-gapped Resource Cluster that manages
all AI resources, document processing, and external service integrations.
"""
import os
from typing import List, Dict, Any, Optional
from pydantic_settings import BaseSettings
from pydantic import Field, validator
class Settings(BaseSettings):
"""Resource Cluster settings with environment variable support"""
# Environment
environment: str = Field(default="development", description="Runtime environment")
debug: bool = Field(default=False, description="Debug mode")
# Service Identity
cluster_name: str = Field(default="gt-resource-cluster", description="Cluster identifier")
service_port: int = Field(default=8003, description="Service port")
# Security
secret_key: str = Field(..., description="JWT signing key for capability tokens")
algorithm: str = Field(default="HS256", description="JWT algorithm")
capability_token_expire_minutes: int = Field(default=60, description="Capability token expiry")
# External LLM Providers (via HAProxy)
groq_api_key: Optional[str] = Field(default=None, description="Groq Cloud API key")
groq_endpoints: List[str] = Field(
default=["https://api.groq.com/openai/v1"],
description="Groq API endpoints for load balancing"
)
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
anthropic_api_key: Optional[str] = Field(default=None, description="Anthropic API key")
# NVIDIA NIM Configuration
nvidia_nim_endpoint: str = Field(
default="https://integrate.api.nvidia.com/v1",
description="NVIDIA NIM API endpoint (cloud or self-hosted)"
)
nvidia_nim_enabled: bool = Field(
default=True,
description="Enable NVIDIA NIM backend for GPU-accelerated inference"
)
# HAProxy Configuration
haproxy_groq_endpoint: str = Field(
default="http://haproxy-groq-lb-service.gt-resource.svc.cluster.local",
description="HAProxy load balancer endpoint for Groq API"
)
haproxy_stats_endpoint: str = Field(
default="http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats",
description="HAProxy statistics endpoint"
)
haproxy_admin_socket: str = Field(
default="/var/run/haproxy.sock",
description="HAProxy admin socket for runtime configuration"
)
haproxy_enabled: bool = Field(
default=True,
description="Enable HAProxy load balancing for external APIs"
)
# Control Panel Integration (for API key retrieval)
control_panel_url: str = Field(
default="http://control-panel-backend:8000",
description="Control Panel internal API URL for service-to-service calls"
)
service_auth_token: str = Field(
default="internal-service-token",
description="Service-to-service authentication token"
)
# Admin Cluster Configuration Sync
admin_cluster_url: str = Field(
default="http://localhost:8001",
description="Admin cluster URL for configuration sync"
)
config_sync_interval: int = Field(
default=10,
description="Configuration sync interval in seconds"
)
config_sync_enabled: bool = Field(
default=True,
description="Enable automatic configuration sync from admin cluster"
)
# Consul Service Discovery
consul_host: str = Field(default="localhost", description="Consul host")
consul_port: int = Field(default=8500, description="Consul port")
consul_token: Optional[str] = Field(default=None, description="Consul ACL token")
# Document Processing
chunking_engine_workers: int = Field(default=4, description="Parallel document processors")
max_document_size_mb: int = Field(default=50, description="Maximum document size")
supported_document_types: List[str] = Field(
default=[".pdf", ".docx", ".txt", ".md", ".html", ".pptx", ".xlsx", ".csv"],
description="Supported document formats"
)
# BGE-M3 Embedding Configuration
embedding_endpoint: str = Field(
default="http://gentwo-vllm-embeddings:8000/v1/embeddings",
description="Default embedding endpoint (local or external)"
)
bge_m3_local_mode: bool = Field(
default=True,
description="Use local BGE-M3 embedding service (True) or external endpoint (False)"
)
bge_m3_external_endpoint: Optional[str] = Field(
default=None,
description="External BGE-M3 embedding endpoint URL (when local_mode=False)"
)
# Vector Database (ChromaDB)
chromadb_host: str = Field(default="localhost", description="ChromaDB host")
chromadb_port: int = Field(default=8000, description="ChromaDB port")
chromadb_encryption_key: Optional[str] = Field(
default=None,
description="Encryption key for vector storage"
)
# Resource Limits
max_concurrent_inferences: int = Field(default=100, description="Max concurrent LLM calls")
max_tokens_per_request: int = Field(default=8000, description="Max tokens per LLM request")
rate_limit_requests_per_minute: int = Field(default=60, description="Global rate limit")
# Storage Paths
data_directory: str = Field(
default="/tmp/gt2-resource-cluster" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster",
description="Base data directory"
)
template_library_path: str = Field(
default="/tmp/gt2-resource-cluster/templates" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster/templates",
description="Agent template library"
)
models_cache_path: str = Field( # Renamed to avoid pydantic warning
default="/tmp/gt2-resource-cluster/models" if os.getenv("ENVIRONMENT") != "production" else "/data/resource-cluster/models",
description="Local model cache"
)
# Redis removed - Resource Cluster uses PostgreSQL for caching and rate limiting
# Monitoring
prometheus_enabled: bool = Field(default=True, description="Enable Prometheus metrics")
prometheus_port: int = Field(default=9091, description="Prometheus metrics port")
# CORS Configuration (for tenant backends)
cors_origins: List[str] = Field(
default=["http://localhost:8002", "https://*.gt2.com"],
description="Allowed CORS origins"
)
# Trusted Host Configuration
trusted_hosts: List[str] = Field(
default=["localhost", "*.gt2.com", "resource-cluster", "gentwo-resource-backend",
"gt2-resource-backend", "testserver", "127.0.0.1", "*"],
description="Allowed host headers for TrustedHostMiddleware"
)
# Feature Flags
enable_model_caching: bool = Field(default=True, description="Cache model responses")
enable_usage_tracking: bool = Field(default=True, description="Track resource usage")
enable_cost_calculation: bool = Field(default=True, description="Calculate usage costs")
@validator("data_directory")
def validate_data_directory(cls, v):
# Ensure directory exists with secure permissions
os.makedirs(v, exist_ok=True, mode=0o700)
return v
@validator("template_library_path")
def validate_template_library_path(cls, v):
os.makedirs(v, exist_ok=True, mode=0o700)
return v
@validator("models_cache_path")
def validate_models_cache_path(cls, v):
os.makedirs(v, exist_ok=True, mode=0o700)
return v
model_config = {
"env_file": ".env",
"env_file_encoding": "utf-8",
"case_sensitive": False,
"extra": "ignore",
}
def get_settings(tenant_id: Optional[str] = None) -> Settings:
"""Get tenant-scoped application settings"""
# For development, return default settings without tenant-specific scoping
if os.getenv("ENVIRONMENT") == "development":
return Settings()
# In production, settings should be tenant-scoped
# This prevents global state from affecting tenant isolation
if tenant_id:
# Create tenant-specific settings with proper isolation
settings = Settings()
# Add tenant-specific configurations here if needed
return settings
else:
# Default settings for non-tenant operations
return Settings()
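# Usage sketch: secret_key is the only field without a default, so it must be
# supplied via the environment (or .env) before settings load; everything else
# can rely on the defaults above. Values shown are placeholders.
#
#     export SECRET_KEY="dev-only-signing-key"
#
#     settings = get_settings()
#     print(settings.nvidia_nim_endpoint, settings.service_port)
#
# For tests, fields can also be passed directly, e.g. Settings(secret_key="test-key").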
def get_resource_families(tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get tenant-scoped resource family definitions (from CLAUDE.md)"""
# Base resource families - can be extended per tenant in production
return {
"ai_ml": {
"name": "AI/ML Resources",
"subtypes": ["llm", "embedding", "image_generation", "function_calling"]
},
"rag_engine": {
"name": "RAG Engine Resources",
"subtypes": ["vector_db", "document_processor", "semantic_search", "retrieval"]
},
"agentic_workflow": {
"name": "Agentic Workflow Resources",
"subtypes": ["single_agent", "multi_agent", "orchestration", "memory"]
},
"app_integration": {
"name": "App Integration Resources",
"subtypes": ["oauth2", "webhook", "api_connector", "database_connector"]
},
"external_service": {
"name": "External Web Services",
"subtypes": ["iframe_embed", "sso_service", "remote_desktop", "learning_platform"]
},
"ai_literacy": {
"name": "AI Literacy & Cognitive Skills",
"subtypes": ["strategic_game", "logic_puzzle", "philosophical_dilemma", "educational_content"]
}
}
def get_model_configs(tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get tenant-scoped model configurations for different providers"""
# Base model configurations - can be customized per tenant in production
return {
"groq": {
"llama-3.1-70b-versatile": {
"max_tokens": 8000,
"cost_per_1k_tokens": 0.59,
"supports_streaming": True,
"supports_function_calling": True
},
"llama-3.1-8b-instant": {
"max_tokens": 8000,
"cost_per_1k_tokens": 0.05,
"supports_streaming": True,
"supports_function_calling": True
},
"mixtral-8x7b-32768": {
"max_tokens": 32768,
"cost_per_1k_tokens": 0.27,
"supports_streaming": True,
"supports_function_calling": False
}
},
"openai": {
"gpt-4-turbo": {
"max_tokens": 128000,
"cost_per_1k_tokens": 10.0,
"supports_streaming": True,
"supports_function_calling": True
},
"gpt-3.5-turbo": {
"max_tokens": 16385,
"cost_per_1k_tokens": 0.5,
"supports_streaming": True,
"supports_function_calling": True
}
},
"anthropic": {
"claude-3-opus": {
"max_tokens": 200000,
"cost_per_1k_tokens": 15.0,
"supports_streaming": True,
"supports_function_calling": False
},
"claude-3-sonnet": {
"max_tokens": 200000,
"cost_per_1k_tokens": 3.0,
"supports_streaming": True,
"supports_function_calling": False
}
}
}

View File

@@ -0,0 +1,45 @@
"""
GT 2.0 Resource Cluster Exceptions
Custom exceptions for the resource cluster.
"""
class ResourceClusterError(Exception):
"""Base exception for resource cluster errors"""
pass
class ProviderError(ResourceClusterError):
"""Error from AI model provider"""
pass
class ModelNotFoundError(ResourceClusterError):
"""Requested model not found"""
pass
class CapabilityError(ResourceClusterError):
"""Capability token validation error"""
pass
class MCPError(ResourceClusterError):
"""MCP service error"""
pass
class DocumentProcessingError(ResourceClusterError):
"""Document processing error"""
pass
class RateLimitError(ResourceClusterError):
"""Rate limit exceeded"""
pass
class CircuitBreakerError(ProviderError):
"""Circuit breaker is open"""
pass

View File

@@ -0,0 +1,273 @@
"""
GT 2.0 Resource Cluster Security
Capability-based authentication and authorization for resource access.
Implements cryptographically signed JWT tokens with embedded capabilities.
"""
import hashlib
import json
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from app.core.config import get_settings
settings = get_settings()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class ResourceCapability(BaseModel):
"""Individual resource capability"""
resource: str # e.g., "llm:groq", "rag:semantic_search"
actions: List[str] # e.g., ["inference", "streaming"]
limits: Dict[str, Any] = {} # e.g., {"max_tokens": 4000, "requests_per_minute": 60}
constraints: Dict[str, Any] = {} # e.g., {"valid_until": "2024-12-31", "ip_restrictions": []}
class CapabilityToken(BaseModel):
"""Capability-based JWT token payload"""
sub: str # User or service identifier
tenant_id: str # Tenant identifier
capabilities: List[ResourceCapability] # Granted capabilities
capability_hash: str # SHA256 hash of capabilities for integrity
exp: Optional[datetime] = None # Expiration time
iat: Optional[datetime] = None # Issued at time
jti: Optional[str] = None # JWT ID for revocation
class CapabilityValidator:
"""Validates and enforces capability-based access control"""
def __init__(self):
self.settings = get_settings()
def create_capability_token(
self,
user_id: str,
tenant_id: str,
capabilities: List[Dict[str, Any]],
expires_delta: Optional[timedelta] = None
) -> str:
"""Create a cryptographically signed capability token"""
# Convert capabilities to ResourceCapability objects
capability_objects = [
ResourceCapability(**cap) for cap in capabilities
]
# Generate capability hash for integrity verification
capability_hash = self._generate_capability_hash(capability_objects)
# Set token expiration
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=self.settings.capability_token_expire_minutes)
# Create token payload
token_data = CapabilityToken(
sub=user_id,
tenant_id=tenant_id,
capabilities=[cap.dict() for cap in capability_objects],
capability_hash=capability_hash,
exp=expire,
iat=datetime.utcnow(),
jti=self._generate_jti()
)
# Encode JWT token
encoded_jwt = jwt.encode(
token_data.dict(),
self.settings.secret_key,
algorithm=self.settings.algorithm
)
return encoded_jwt
def verify_capability_token(self, token: str) -> Optional[CapabilityToken]:
"""Verify and decode a capability token"""
try:
# Decode JWT token
payload = jwt.decode(
token,
self.settings.secret_key,
algorithms=[self.settings.algorithm]
)
# Convert to CapabilityToken object
capability_token = CapabilityToken(**payload)
# Verify capability hash integrity
capability_objects = []
for cap in capability_token.capabilities:
if isinstance(cap, dict):
capability_objects.append(ResourceCapability(**cap))
else:
capability_objects.append(cap)
expected_hash = self._generate_capability_hash(capability_objects)
if capability_token.capability_hash != expected_hash:
raise ValueError("Capability hash mismatch - token may be tampered")
return capability_token
except (JWTError, ValueError):
return None
def check_resource_access(
self,
token: CapabilityToken,
resource: str,
action: str,
context: Dict[str, Any] = {}
) -> bool:
"""Check if token grants access to specific resource and action"""
for capability in token.capabilities:
# Handle both dict and ResourceCapability object formats
if isinstance(capability, dict):
cap_resource = capability["resource"]
cap_actions = capability.get("actions", [])
cap_constraints = capability.get("constraints", {})
else:
cap_resource = capability.resource
cap_actions = capability.actions
cap_constraints = capability.constraints
# Check if capability matches resource
if self._matches_resource(cap_resource, resource):
# Check if action is allowed
if action in cap_actions:
# Check additional constraints
if self._check_constraints(cap_constraints, context):
return True
return False
def get_resource_limits(
self,
token: CapabilityToken,
resource: str
) -> Dict[str, Any]:
"""Get resource-specific limits from token"""
for capability in token.capabilities:
# Handle both dict and ResourceCapability object formats
if isinstance(capability, dict):
cap_resource = capability["resource"]
cap_limits = capability.get("limits", {})
else:
cap_resource = capability.resource
cap_limits = capability.limits
if self._matches_resource(cap_resource, resource):
return cap_limits
return {}
def _generate_capability_hash(self, capabilities: List[ResourceCapability]) -> str:
"""Generate SHA256 hash of capabilities for integrity verification"""
# Sort capabilities for consistent hashing
sorted_caps = sorted(
[cap.dict() for cap in capabilities],
key=lambda x: x["resource"]
)
# Create hash
cap_string = json.dumps(sorted_caps, sort_keys=True)
return hashlib.sha256(cap_string.encode()).hexdigest()
def _generate_jti(self) -> str:
"""Generate unique JWT ID"""
import uuid
return str(uuid.uuid4())
def _matches_resource(self, pattern: str, resource: str) -> bool:
"""Check if resource pattern matches requested resource"""
# Handle wildcards (e.g., "llm:*" matches "llm:groq")
if pattern.endswith(":*"):
prefix = pattern[:-2]
return resource.startswith(prefix + ":")
# Handle exact matches
return pattern == resource
def _check_constraints(self, constraints: Dict[str, Any], context: Dict[str, Any]) -> bool:
"""Check additional constraints like time validity and IP restrictions"""
# Check time validity
if "valid_until" in constraints:
valid_until = datetime.fromisoformat(constraints["valid_until"])
if datetime.utcnow() > valid_until:
return False
# Check IP restrictions
if "ip_restrictions" in constraints and "client_ip" in context:
allowed_ips = constraints["ip_restrictions"]
if allowed_ips and context["client_ip"] not in allowed_ips:
return False
# Check tenant restrictions
if "allowed_tenants" in constraints and "tenant_id" in context:
allowed_tenants = constraints["allowed_tenants"]
if allowed_tenants and context["tenant_id"] not in allowed_tenants:
return False
return True
# Global validator instance
capability_validator = CapabilityValidator()
def verify_capability_token(token: str) -> Optional[CapabilityToken]:
"""Standalone function for FastAPI dependency injection"""
return capability_validator.verify_capability_token(token)
def create_resource_capability(
resource_type: str,
resource_id: str,
actions: List[str],
limits: Dict[str, Any] = {},
constraints: Dict[str, Any] = {}
) -> Dict[str, Any]:
"""Helper function to create a resource capability"""
return {
"resource": f"{resource_type}:{resource_id}",
"actions": actions,
"limits": limits,
"constraints": constraints
}
def create_assistant_capabilities(assistant_config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Create capabilities from agent configuration"""
capabilities = []
# Extract capabilities from agent config
for cap in assistant_config.get("capabilities", []):
capabilities.append(cap)
# Add default LLM capability if specified
if "primary_llm" in assistant_config.get("resource_preferences", {}):
llm_model = assistant_config["resource_preferences"]["primary_llm"]
capabilities.append(create_resource_capability(
"llm",
llm_model.replace(":", "_"),
["inference", "streaming"],
{
"max_tokens": assistant_config["resource_preferences"].get("max_tokens", 4000),
"temperature": assistant_config["resource_preferences"].get("temperature", 0.7)
}
))
return capabilities
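# End-to-end sketch of the capability flow above; the identifiers and model
# name are placeholders chosen for illustration.
def example_llm_capability_roundtrip(user_id: str = "user-123", tenant_id: str = "tenant-abc") -> bool:
    """Build a capability, mint a token, verify it, and check resource access."""
    capability = create_resource_capability(
        "llm",
        "groq_llama-3.1-8b-instant",
        actions=["inference", "streaming"],
        limits={"max_tokens": 4000},
    )
    token = capability_validator.create_capability_token(
        user_id=user_id,
        tenant_id=tenant_id,
        capabilities=[capability],
    )
    parsed = capability_validator.verify_capability_token(token)
    if parsed is None:
        return False
    return capability_validator.check_resource_access(
        parsed,
        resource="llm:groq_llama-3.1-8b-instant",
        action="inference",
    )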