GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents

- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
  - Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
Author: HackWeasel
Date: 2025-12-12 17:47:14 -05:00
Commit: 310491a557
750 changed files with 232701 additions and 0 deletions

View File

@@ -0,0 +1,52 @@
"""
Resource backend implementations for GT 2.0
Provides unified interfaces for all resource types:
- LLM inference (Groq, NVIDIA NIM, OpenAI, Anthropic)
- Vector databases (PGVector)
- Document processing (Unstructured)
- External services (OAuth2, iframe)
- AI literacy resources
"""
from typing import Dict, Any
import logging
logger = logging.getLogger(__name__)
# Registry of available backends
BACKEND_REGISTRY: Dict[str, Any] = {}
def register_backend(name: str, backend_class):
"""Register a resource backend"""
BACKEND_REGISTRY[name] = backend_class
logger.info(f"Registered backend: {name}")
def get_backend(name: str):
"""Get a registered backend"""
if name not in BACKEND_REGISTRY:
raise ValueError(f"Backend not found: {name}")
return BACKEND_REGISTRY[name]
async def initialize_backends():
"""Initialize all resource backends"""
from app.core.backends.groq_proxy import GroqProxyBackend
from app.core.backends.nvidia_proxy import NvidiaProxyBackend
from app.core.backends.document_processor import DocumentProcessorBackend
from app.core.backends.embedding_backend import EmbeddingBackend
# Register backends
register_backend("groq_proxy", GroqProxyBackend())
register_backend("nvidia_proxy", NvidiaProxyBackend())
register_backend("document_processor", DocumentProcessorBackend())
register_backend("embedding", EmbeddingBackend())
logger.info("All resource backends initialized")
def get_embedding_backend():
"""Get the embedding backend instance"""
return get_backend("embedding")
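
A minimal usage sketch of the registry above, assuming the package is importable as app.core.backends and that the backend classes can be constructed outside the cluster:

    import asyncio

    from app.core.backends import (
        get_backend,
        get_embedding_backend,
        initialize_backends,
    )

    async def startup() -> None:
        # Register every backend once at application startup...
        await initialize_backends()
        # ...then resolve them by name wherever they are needed.
        groq = get_backend("groq_proxy")
        embedder = get_embedding_backend()
        print(type(groq).__name__, type(embedder).__name__)

    asyncio.run(startup())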

View File

@@ -0,0 +1,322 @@
"""
Document Processing Backend
STATELESS document chunking and preprocessing for RAG operations.
All processing happens in memory - NO user data is ever stored.
"""
import logging
import io
import gc
from typing import Dict, Any, List, Optional, BinaryIO
from dataclasses import dataclass
import hashlib
# Document processing imports
import pypdf as PyPDF2
from docx import Document as DocxDocument
from bs4 import BeautifulSoup
from langchain_text_splitters import (
RecursiveCharacterTextSplitter,
TokenTextSplitter,
SentenceTransformersTokenTextSplitter
)
logger = logging.getLogger(__name__)
@dataclass
class ChunkingStrategy:
"""Configuration for document chunking"""
strategy_type: str # 'fixed', 'semantic', 'hierarchical', 'hybrid'
chunk_size: int # Target chunk size in tokens (optimized for BGE-M3: 512)
chunk_overlap: int # Overlap between chunks (typically 128 for BGE-M3)
separator_pattern: Optional[str] = None # Custom separator for splitting
preserve_paragraphs: bool = True
preserve_sentences: bool = True
class DocumentProcessorBackend:
"""
STATELESS document chunking and processing backend.
Security principles:
- NO persistence of user data
- All processing in memory only
- Immediate memory cleanup after processing
- No caching of user content
"""
def __init__(self):
self.supported_formats = [".pdf", ".docx", ".txt", ".md", ".html"]
# BGE-M3 optimal settings
self.default_chunk_size = 512 # tokens
self.default_chunk_overlap = 128 # tokens
self.model_name = "BAAI/bge-m3" # For tokenization
logger.info("STATELESS document processor backend initialized")
async def process_document(
self,
content: bytes,
document_type: str,
strategy: Optional[ChunkingStrategy] = None,
metadata: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Process document into chunks - STATELESS operation.
Args:
content: Document content as bytes (will be cleared from memory)
document_type: File type (.pdf, .docx, .txt, .md, .html)
strategy: Chunking strategy configuration
metadata: Optional metadata (will NOT include user content)
Returns:
List of chunks with metadata (immediately returned, not stored)
"""
try:
# Use default strategy if not provided
if strategy is None:
strategy = ChunkingStrategy(
strategy_type='hybrid',
chunk_size=self.default_chunk_size,
chunk_overlap=self.default_chunk_overlap
)
# Extract text based on document type (in memory)
text = await self._extract_text_from_bytes(content, document_type)
# Clear original content from memory
del content
gc.collect()
# Apply chunking strategy
if strategy.strategy_type == 'semantic':
chunks = await self._semantic_chunking(text, strategy)
elif strategy.strategy_type == 'hierarchical':
chunks = await self._hierarchical_chunking(text, strategy)
elif strategy.strategy_type == 'hybrid':
chunks = await self._hybrid_chunking(text, strategy)
else: # 'fixed'
chunks = await self._fixed_chunking(text, strategy)
# Clear text from memory
del text
gc.collect()
# Add metadata without storing content
processed_chunks = []
for idx, chunk in enumerate(chunks):
chunk_metadata = {
"chunk_index": idx,
"total_chunks": len(chunks),
"chunking_strategy": strategy.strategy_type,
"chunk_size_tokens": strategy.chunk_size,
# Generate hash for deduplication without storing content
"content_hash": hashlib.sha256(chunk.encode()).hexdigest()[:16]
}
# Add non-sensitive metadata if provided
if metadata:
# Filter out any potential sensitive data
safe_metadata = {
k: v for k, v in metadata.items()
if k in ['document_type', 'processing_timestamp', 'tenant_id']
}
chunk_metadata.update(safe_metadata)
processed_chunks.append({
"text": chunk,
"metadata": chunk_metadata
})
logger.info(f"Processed document into {len(processed_chunks)} chunks (STATELESS)")
# Return immediately - no storage
return processed_chunks
except Exception as e:
logger.error(f"Error processing document: {e}")
# Ensure memory is cleared even on error
gc.collect()
raise
finally:
# Always ensure memory cleanup
gc.collect()
async def _extract_text_from_bytes(
self,
content: bytes,
document_type: str
) -> str:
"""Extract text from document bytes - in memory only"""
try:
if document_type == ".pdf":
return await self._extract_pdf_text(io.BytesIO(content))
elif document_type == ".docx":
return await self._extract_docx_text(io.BytesIO(content))
elif document_type == ".html":
return await self._extract_html_text(content.decode('utf-8'))
elif document_type in [".txt", ".md"]:
return content.decode('utf-8')
else:
raise ValueError(f"Unsupported document type: {document_type}")
finally:
# Clear content from memory
del content
gc.collect()
async def _extract_pdf_text(self, file_stream: BinaryIO) -> str:
"""Extract text from PDF - in memory"""
text = ""
try:
pdf_reader = PyPDF2.PdfReader(file_stream)
for page_num in range(len(pdf_reader.pages)):
page = pdf_reader.pages[page_num]
text += page.extract_text() + "\n"
finally:
file_stream.close()
gc.collect()
return text
async def _extract_docx_text(self, file_stream: BinaryIO) -> str:
"""Extract text from DOCX - in memory"""
text = ""
try:
doc = DocxDocument(file_stream)
for paragraph in doc.paragraphs:
text += paragraph.text + "\n"
finally:
file_stream.close()
gc.collect()
return text
async def _extract_html_text(self, html_content: str) -> str:
"""Extract text from HTML - in memory"""
soup = BeautifulSoup(html_content, 'html.parser')
# Remove script and style elements
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
# Clean up whitespace
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split("  "))  # split phrases on double spaces
text = '\n'.join(chunk for chunk in chunks if chunk)
return text
async def _semantic_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Semantic chunking using sentence boundaries"""
splitter = SentenceTransformersTokenTextSplitter(
model_name=self.model_name,
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
return splitter.split_text(text)
async def _hierarchical_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Hierarchical chunking preserving document structure"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=strategy.chunk_size * 3, # Approximate token to char ratio
chunk_overlap=strategy.chunk_overlap * 3,
separators=["\n\n\n", "\n\n", "\n", ". ", " ", ""],
keep_separator=True
)
return splitter.split_text(text)
async def _hybrid_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Hybrid chunking combining semantic and structural boundaries"""
# First split by structure
structural_splitter = RecursiveCharacterTextSplitter(
chunk_size=strategy.chunk_size * 4,
chunk_overlap=0,
separators=["\n\n\n", "\n\n"],
keep_separator=True
)
structural_chunks = structural_splitter.split_text(text)
# Then apply semantic splitting to each structural chunk
final_chunks = []
token_splitter = TokenTextSplitter(
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
for struct_chunk in structural_chunks:
semantic_chunks = token_splitter.split_text(struct_chunk)
final_chunks.extend(semantic_chunks)
return final_chunks
async def _fixed_chunking(
self,
text: str,
strategy: ChunkingStrategy
) -> List[str]:
"""Fixed-size chunking with token boundaries"""
splitter = TokenTextSplitter(
chunk_size=strategy.chunk_size,
chunk_overlap=strategy.chunk_overlap
)
return splitter.split_text(text)
async def validate_document(
self,
content_size: int,
document_type: str
) -> Dict[str, Any]:
"""
Validate document before processing - no content stored.
Args:
content_size: Size of document in bytes
document_type: File extension
Returns:
Validation result with any warnings
"""
MAX_SIZE = 50 * 1024 * 1024 # 50MB max
validation = {
"valid": True,
"warnings": [],
"errors": []
}
# Check file size
if content_size > MAX_SIZE:
validation["valid"] = False
validation["errors"].append(f"File size exceeds maximum of 50MB")
elif content_size > 10 * 1024 * 1024: # Warning for files over 10MB
validation["warnings"].append("Large file may take longer to process")
# Check document type
if document_type not in self.supported_formats:
validation["valid"] = False
validation["errors"].append(f"Unsupported format: {document_type}")
return validation
async def check_health(self) -> Dict[str, Any]:
"""Check document processor health - no user data exposed"""
return {
"status": "healthy",
"supported_formats": self.supported_formats,
"default_chunk_size": self.default_chunk_size,
"default_chunk_overlap": self.default_chunk_overlap,
"model": self.model_name,
"stateless": True, # Confirm stateless operation
"memory_cleared": True # Confirm memory management
}
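
A short usage sketch for the processor above, assuming the module path implied by the package __init__ and that langchain_text_splitters plus its tokenizer dependency are installed:

    import asyncio

    from app.core.backends.document_processor import (
        ChunkingStrategy,
        DocumentProcessorBackend,
    )

    async def demo() -> None:
        backend = DocumentProcessorBackend()
        strategy = ChunkingStrategy(strategy_type="hybrid", chunk_size=512, chunk_overlap=128)
        # Any UTF-8 Markdown payload works; repetition just guarantees multiple chunks.
        content = ("# Notes\n\nBGE-M3 prefers chunks of roughly 512 tokens.\n" * 200).encode("utf-8")
        chunks = await backend.process_document(content, ".md", strategy)
        print(len(chunks), "chunks; first hash:", chunks[0]["metadata"]["content_hash"])

    asyncio.run(demo())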

View File

@@ -0,0 +1,471 @@
"""
Embedding Model Backend
STATELESS embedding generation using BGE-M3 model hosted on GT's GPU clusters.
All embeddings are generated in real-time - NO user data is stored.
"""
import logging
import gc
import hashlib
import asyncio
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
# import numpy as np # Temporarily disabled for Docker build
import aiohttp
import json
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class EmbeddingRequest:
"""Request structure for embedding generation"""
texts: List[str]
model: str = "BAAI/bge-m3"
batch_size: int = 32
normalize: bool = True
instruction: Optional[str] = None # For instruction-based embeddings
class EmbeddingBackend:
"""
STATELESS embedding backend for BGE-M3 model.
Security principles:
- NO persistence of embeddings or text
- All processing via GT's internal GPU cluster
- Immediate memory cleanup after generation
- No caching of user content
- Request signing and verification
"""
def __init__(self):
self.model_name = "BAAI/bge-m3"
self.embedding_dimensions = 1024 # BGE-M3 dimensions
self.max_batch_size = 32
self.max_sequence_length = 8192 # BGE-M3 supports up to 8192 tokens
# Determine endpoint based on configuration
self.embedding_endpoint = self._get_embedding_endpoint()
# Timeout for embedding requests
self.request_timeout = 60 # seconds for model loading
logger.info(f"STATELESS embedding backend initialized for {self.model_name}")
logger.info(f"Using embedding endpoint: {self.embedding_endpoint}")
def _get_embedding_endpoint(self) -> str:
"""
Get the embedding endpoint based on configuration.
Priority:
1. Model registry from config sync (database-backed)
2. Environment variables (BGE_M3_LOCAL_MODE, BGE_M3_EXTERNAL_ENDPOINT)
3. Default local endpoint
"""
# Try to get configuration from model registry first (loaded from database)
try:
from app.services.model_service import default_model_service
# Use the default model service instance (singleton) used by config sync
model_service = default_model_service
# Try to get the model config synchronously (during initialization)
# The get_model method is async, so we need to handle this carefully
bge_m3_config = model_service.model_registry.get("BAAI/bge-m3")
if bge_m3_config:
# Model registry stores endpoint as 'endpoint_url' and config as 'parameters'
endpoint = bge_m3_config.get("endpoint_url")
config = bge_m3_config.get("parameters", {})
is_local_mode = config.get("is_local_mode", True)
external_endpoint = config.get("external_endpoint")
logger.info(f"Found BGE-M3 in registry: endpoint_url={endpoint}, is_local_mode={is_local_mode}, external_endpoint={external_endpoint}")
if endpoint:
logger.info(f"Using BGE-M3 endpoint from model registry (is_local_mode={is_local_mode}): {endpoint}")
return endpoint
else:
logger.warning(f"BGE-M3 found in registry but endpoint_url is None/empty. Full config: {bge_m3_config}")
else:
available_models = list(model_service.model_registry.keys())
logger.debug(f"BGE-M3 not found in model registry during init (expected on first startup). Available models: {available_models}")
except Exception as e:
logger.debug(f"Model registry not yet available during startup (will be populated after config sync): {e}")
# Fall back to Settings fields (environment variables or .env file)
is_local_mode = getattr(settings, 'bge_m3_local_mode', True)
external_endpoint = getattr(settings, 'bge_m3_external_endpoint', None)
if not is_local_mode and external_endpoint:
logger.info(f"Using external BGE-M3 endpoint from settings: {external_endpoint}")
return external_endpoint
# Default to local endpoint
local_endpoint = getattr(
settings,
'embedding_endpoint',
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
)
logger.info(f"Using local BGE-M3 endpoint: {local_endpoint}")
return local_endpoint
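# For reference, a hypothetical .env for the external-endpoint branch above
# (variable names follow the docstring; the exact Settings mapping is assumed):
#   BGE_M3_LOCAL_MODE=false
#   BGE_M3_EXTERNAL_ENDPOINT=https://embeddings.example.com/v1/embeddings
# With BGE_M3_LOCAL_MODE left true, the default local vLLM endpoint is used instead.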
async def update_endpoint_config(self, is_local_mode: bool, external_endpoint: str = None):
"""
Update the embedding endpoint configuration dynamically.
This allows switching between local and external endpoints without restart.
"""
if is_local_mode:
self.embedding_endpoint = getattr(
settings,
'embedding_endpoint',
'http://gentwo-vllm-embeddings:8000/v1/embeddings'
)
else:
if external_endpoint:
self.embedding_endpoint = external_endpoint
else:
raise ValueError("External endpoint must be provided when not in local mode")
logger.info(f"BGE-M3 endpoint updated to: {self.embedding_endpoint}")
logger.info(f"Mode: {'Local GT Edge' if is_local_mode else 'External API'}")
def refresh_endpoint_from_registry(self):
"""
Refresh the embedding endpoint from the model registry.
Called by config sync when BGE-M3 configuration changes.
"""
logger.info(f"Refreshing embedding endpoint - current: {self.embedding_endpoint}")
new_endpoint = self._get_embedding_endpoint()
if new_endpoint != self.embedding_endpoint:
logger.info(f"Refreshing BGE-M3 endpoint from {self.embedding_endpoint} to {new_endpoint}")
self.embedding_endpoint = new_endpoint
else:
logger.info(f"BGE-M3 endpoint unchanged: {self.embedding_endpoint}")
async def generate_embeddings(
self,
texts: List[str],
instruction: Optional[str] = None,
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings for texts using BGE-M3 - STATELESS operation.
Args:
texts: List of texts to embed (will be cleared from memory)
instruction: Optional instruction for query vs document embeddings
tenant_id: Tenant ID for audit logging (not stored with data)
request_id: Request ID for tracing
Returns:
List of embedding vectors (immediately returned, not stored)
"""
try:
# Validate input
if not texts:
return []
if len(texts) > self.max_batch_size:
# Process in batches
return await self._batch_process_embeddings(
texts, instruction, tenant_id, request_id
)
# Prepare request
request_data = {
"model": self.model_name,
"input": texts,
"encoding_format": "float",
"dimensions": self.embedding_dimensions
}
# Add instruction if provided (for query vs document distinction)
if instruction:
request_data["instruction"] = instruction
# Add metadata for audit (not stored with embeddings)
metadata = {
"tenant_id": tenant_id,
"request_id": request_id,
"text_count": len(texts),
# Hash for deduplication without storing content
"content_hash": hashlib.sha256(
"".join(texts).encode()
).hexdigest()[:16]
}
# Call vLLM service - NO FALLBACKS
embeddings = await self._call_embedding_service(request_data, metadata)
# Clear texts from memory immediately
del texts
gc.collect()
# Validate response
if not embeddings or len(embeddings) == 0:
raise ValueError("No embeddings returned from service")
# Normalize if needed
if self._should_normalize():
embeddings = self._normalize_embeddings(embeddings)
logger.info(
f"Generated {len(embeddings)} embeddings (STATELESS) "
f"for tenant {tenant_id}"
)
# Return immediately - no storage
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
# Ensure memory is cleared even on error
gc.collect()
raise
finally:
# Always ensure memory cleanup
gc.collect()
async def _batch_process_embeddings(
self,
texts: List[str],
instruction: Optional[str],
tenant_id: str,
request_id: str
) -> List[List[float]]:
"""Process large text lists in batches using vLLM service"""
all_embeddings = []
for i in range(0, len(texts), self.max_batch_size):
batch = texts[i:i + self.max_batch_size]
# Prepare request for this batch
request_data = {
"model": self.model_name,
"input": batch,
"encoding_format": "float",
"dimensions": self.embedding_dimensions
}
if instruction:
request_data["instruction"] = instruction
metadata = {
"tenant_id": tenant_id,
"request_id": f"{request_id}_batch_{i}",
"text_count": len(batch),
"content_hash": hashlib.sha256(
"".join(batch).encode()
).hexdigest()[:16]
}
batch_embeddings = await self._call_embedding_service(request_data, metadata)
all_embeddings.extend(batch_embeddings)
# Clear batch from memory
del batch
gc.collect()
return all_embeddings
async def _call_embedding_service(
self,
request_data: Dict[str, Any],
metadata: Dict[str, Any]
) -> List[List[float]]:
"""Call internal GPU cluster embedding service"""
async with aiohttp.ClientSession() as session:
try:
# Add capability token for authentication
headers = {
"Content-Type": "application/json",
"X-Tenant-ID": metadata.get("tenant_id", ""),
"X-Request-ID": metadata.get("request_id", ""),
# Authorization will be added by Resource Cluster
}
async with session.post(
self.embedding_endpoint,
json=request_data,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.request_timeout)
) as response:
if response.status != 200:
error_text = await response.text()
raise ValueError(
f"Embedding service error: {response.status} - {error_text}"
)
result = await response.json()
# Extract embeddings from response
if "data" in result:
embeddings = [item["embedding"] for item in result["data"]]
elif "embeddings" in result:
embeddings = result["embeddings"]
else:
raise ValueError("Invalid embedding service response format")
return embeddings
except asyncio.TimeoutError:
raise ValueError(f"Embedding service timeout after {self.request_timeout}s")
except Exception as e:
logger.error(f"Error calling embedding service: {e}")
raise
def _should_normalize(self) -> bool:
"""Check if embeddings should be normalized"""
# BGE-M3 embeddings are typically normalized for similarity search
return True
def _normalize_embeddings(
self,
embeddings: List[List[float]]
) -> List[List[float]]:
"""Normalize embedding vectors to unit length"""
import math  # stand-in for numpy while it is disabled in the Docker build
normalized = []
for embedding in embeddings:
# Compute the L2 norm and scale the vector to unit length
norm = math.sqrt(sum(x * x for x in embedding))
if norm > 0:
normalized_vec = [x / norm for x in embedding]
else:
normalized_vec = embedding[:]
normalized.append(normalized_vec)
return normalized
async def generate_query_embeddings(
self,
queries: List[str],
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings specifically for queries.
BGE-M3 can use different instructions for queries vs documents.
"""
# For BGE-M3, queries can use a specific instruction
instruction = "Represent this sentence for searching relevant passages: "
return await self.generate_embeddings(
queries, instruction, tenant_id, request_id
)
async def generate_document_embeddings(
self,
documents: List[str],
tenant_id: str = None,
request_id: str = None
) -> List[List[float]]:
"""
Generate embeddings specifically for documents.
BGE-M3 can use different instructions for documents vs queries.
"""
# For BGE-M3, documents typically don't need special instruction
return await self.generate_embeddings(
documents, None, tenant_id, request_id
)
async def validate_texts(
self,
texts: List[str]
) -> Dict[str, Any]:
"""
Validate texts before embedding - no content stored.
Args:
texts: List of texts to validate
Returns:
Validation result with any warnings
"""
validation = {
"valid": True,
"warnings": [],
"errors": [],
"stats": {
"total_texts": len(texts),
"max_length": 0,
"avg_length": 0
}
}
if not texts:
validation["valid"] = False
validation["errors"].append("No texts provided")
return validation
# Check text lengths
lengths = [len(text) for text in texts]
validation["stats"]["max_length"] = max(lengths)
validation["stats"]["avg_length"] = sum(lengths) // len(lengths)
# BGE-M3 max sequence length check (approximate)
max_chars = self.max_sequence_length * 4 # Rough char to token ratio
for i, length in enumerate(lengths):
if length > max_chars:
validation["warnings"].append(
f"Text {i} may exceed model's max sequence length"
)
elif length == 0:
validation["errors"].append(f"Text {i} is empty")
validation["valid"] = False
# Batch size check
if len(texts) > self.max_batch_size * 10:
validation["warnings"].append(
f"Large batch ({len(texts)} texts) will be processed in chunks"
)
return validation
async def check_health(self) -> Dict[str, Any]:
"""Check embedding backend health - no user data exposed"""
try:
# Test connection to vLLM service
test_text = ["Health check test"]
test_embeddings = await self.generate_embeddings(
test_text,
tenant_id="health_check",
request_id="health_check"
)
health_status = {
"status": "healthy",
"model": self.model_name,
"dimensions": self.embedding_dimensions,
"max_batch_size": self.max_batch_size,
"max_sequence_length": self.max_sequence_length,
"endpoint": self.embedding_endpoint,
"stateless": True,
"memory_cleared": True,
"vllm_service_connected": len(test_embeddings) > 0
}
except Exception as e:
health_status = {
"status": "unhealthy",
"error": str(e),
"model": self.model_name,
"endpoint": self.embedding_endpoint
}
return health_status
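
A usage sketch for the embedding backend, assuming a reachable BGE-M3 endpoint; because the returned vectors are unit-normalized, cosine similarity reduces to a dot product:

    import asyncio

    from app.core.backends.embedding_backend import EmbeddingBackend

    async def demo() -> None:
        backend = EmbeddingBackend()
        [query_vec] = await backend.generate_query_embeddings(
            ["How are chunks deduplicated?"], tenant_id="demo", request_id="demo-q"
        )
        [doc_vec] = await backend.generate_document_embeddings(
            ["Each chunk carries a truncated SHA-256 content hash for deduplication."],
            tenant_id="demo", request_id="demo-d"
        )
        # Unit-length vectors: the dot product equals the cosine similarity.
        print(f"cosine similarity: {sum(q * d for q, d in zip(query_vec, doc_vec)):.3f}")

    asyncio.run(demo())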

View File

@@ -0,0 +1,780 @@
"""
Groq Cloud LLM Proxy Backend
Provides high-availability LLM inference through Groq Cloud with:
- HAProxy load balancing across multiple endpoints
- Automatic failover handled by HAProxy
- Token usage tracking and cost calculation
- Streaming response support
- Circuit breaker pattern for enhanced reliability
"""
import asyncio
import json
import os
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
import httpx
try:
from groq import AsyncGroq
GROQ_AVAILABLE = True
except ImportError:
# Groq not available in development mode
AsyncGroq = None
GROQ_AVAILABLE = False
import logging
from app.core.config import get_settings, get_model_configs
from app.services.model_service import get_model_service
logger = logging.getLogger(__name__)
settings = get_settings()
# Groq Compound tool pricing (per request/execution)
# Source: https://groq.com/pricing (Dec 2, 2025)
COMPOUND_TOOL_PRICES = {
# Web Search variants
"search": 0.008, # API returns "search" for web search
"web_search": 0.008, # $8 per 1K = $0.008 per request (Advanced Search)
"advanced_search": 0.008, # $8 per 1K requests
"basic_search": 0.005, # $5 per 1K requests
# Other tools
"visit_website": 0.001, # $1 per 1K requests
"python": 0.00005, # API returns "python" for code execution
"code_interpreter": 0.00005, # Alternative API identifier
"code_execution": 0.00005, # Alias for backwards compatibility
"browser_automation": 0.00002, # $0.08/hr ≈ $0.00002 per execution
}
# Model pricing per million tokens (input/output)
# Source: https://groq.com/pricing (Dec 2, 2025)
GROQ_MODEL_PRICES = {
"llama-3.3-70b-versatile": {"input": 0.59, "output": 0.79},
"llama-3.1-8b-instant": {"input": 0.05, "output": 0.08},
"llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
"meta-llama/llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
"llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
"meta-llama/llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
"llama-guard-4-12b": {"input": 0.20, "output": 0.20},
"meta-llama/llama-guard-4-12b": {"input": 0.20, "output": 0.20},
"gpt-oss-120b": {"input": 0.15, "output": 0.60},
"openai/gpt-oss-120b": {"input": 0.15, "output": 0.60},
"gpt-oss-20b": {"input": 0.075, "output": 0.30},
"openai/gpt-oss-20b": {"input": 0.075, "output": 0.30},
"kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
"moonshotai/kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
"qwen3-32b": {"input": 0.29, "output": 0.59},
# Compound models - 50/50 blended pricing from underlying models
# compound: GPT-OSS-120B ($0.15/$0.60) + Llama 4 Scout ($0.11/$0.34) = $0.13/$0.47
"compound": {"input": 0.13, "output": 0.47},
"groq/compound": {"input": 0.13, "output": 0.47},
"compound-beta": {"input": 0.13, "output": 0.47},
# compound-mini: GPT-OSS-120B ($0.15/$0.60) + Llama 3.3 70B ($0.59/$0.79) = $0.37/$0.695
"compound-mini": {"input": 0.37, "output": 0.695},
"groq/compound-mini": {"input": 0.37, "output": 0.695},
"compound-mini-beta": {"input": 0.37, "output": 0.695},
}
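# The blended compound rates above follow from a 50/50 average of the two
# underlying models' per-million prices:
#   compound:      input (0.15 + 0.11) / 2 = 0.13,   output (0.60 + 0.34) / 2 = 0.47
#   compound-mini: input (0.15 + 0.59) / 2 = 0.37,   output (0.60 + 0.79) / 2 = 0.695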
class GroqProxyBackend:
"""LLM inference via Groq Cloud with HAProxy load balancing"""
def __init__(self):
self.settings = get_settings()
self.client = None
self.usage_metrics = {}
self.circuit_breaker_status = {}
self._initialize_client()
def _initialize_client(self):
"""Initialize Groq client to use HAProxy load balancer"""
if not GROQ_AVAILABLE:
logger.warning("Groq client not available - running in development mode")
return
if self.settings.groq_api_key:
# Use HAProxy load balancer instead of direct Groq API
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
# Initialize client with HAProxy endpoint
self.client = AsyncGroq(
api_key=self.settings.groq_api_key,
base_url=haproxy_endpoint,
timeout=httpx.Timeout(30.0), # Increased timeout for load balancing
max_retries=1 # Let HAProxy handle retries
)
# Initialize circuit breaker
self.circuit_breaker_status = {
"state": "closed", # closed, open, half_open
"failure_count": 0,
"last_failure_time": None,
"failure_threshold": 5,
"recovery_timeout": 60 # seconds
}
logger.info(f"Initialized Groq client with HAProxy endpoint: {haproxy_endpoint}")
async def execute_inference(
self,
prompt: str,
model: str = "llama-3.1-70b-versatile",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None
) -> Dict[str, Any]:
"""Execute LLM inference with HAProxy load balancing and circuit breaker"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - service temporarily unavailable")
# Validate model and get configuration
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
if not model_config:
# Try to get from model service registry
model_service = get_model_service(tenant_id)
model_info = await model_service.get_model(model)
if not model_info:
raise ValueError(f"Unsupported model: {model}")
model_config = {
"max_tokens": model_info["performance"]["max_tokens"],
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
"supports_streaming": model_info["capabilities"].get("streaming", False)
}
# Apply token limits
max_tokens = min(max_tokens, model_config["max_tokens"])
# Prepare messages
messages = [
{"role": "user", "content": prompt}
]
try:
# Get tenant-specific API key
if not tenant_id:
raise ValueError("tenant_id is required for Groq inference")
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
start_time = time.time()
if stream:
# Return the stream generator directly; awaiting an async generator would raise a TypeError
return self._stream_inference(
messages, model, temperature, max_tokens, user_id, tenant_id, client
)
else:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False
)
# Track successful usage
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
response.usage.total_tokens if response.usage else 0,
latency, model_config["cost_per_1k_tokens"]
)
# Track in model service
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=True,
latency_ms=latency
)
# Reset circuit breaker on success
await self._record_success()
return {
"content": response.choices[0].message.content,
"model": model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
"cost_cents": self._calculate_cost(
response.usage.total_tokens if response.usage else 0,
model_config["cost_per_1k_tokens"]
)
},
"latency_ms": latency,
"load_balanced": True,
"haproxy_backend": "groq_general_backend"
}
except Exception as e:
logger.error(f"HAProxy Groq inference failed: {e}")
# Track failure in model service (resolve it here; it may not be bound yet
# if the failure happened before the success path assigned it)
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=False
)
# Record failure for circuit breaker
await self._record_failure()
# Re-raise the exception - no client-side fallback needed
# HAProxy handles all failover logic
raise Exception(f"Groq inference failed (via HAProxy): {str(e)}")
async def _stream_inference(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
user_id: str,
tenant_id: str,
client: AsyncGroq = None
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses"""
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
start_time = time.time()
total_tokens = 0
try:
# Use provided client or get tenant-specific client
if not client:
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
stream = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
total_tokens += len(content.split()) # Approximate token count
# Yield SSE formatted data
yield f"data: {json.dumps({'content': content})}\n\n"
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
total_tokens, latency,
model_config["cost_per_1k_tokens"]
)
# Send completion signal
yield f"data: {json.dumps({'done': True})}\n\n"
except Exception as e:
logger.error(f"Streaming inference error: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
async def check_health(self) -> Dict[str, Any]:
"""Check health of HAProxy load balancer and circuit breaker status"""
try:
# Check HAProxy health via stats endpoint
haproxy_stats_url = self.settings.haproxy_stats_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats"
async with httpx.AsyncClient() as client:
response = await client.get(
haproxy_stats_url,
timeout=5.0,
auth=("admin", "gt2_haproxy_stats_password")
)
if response.status_code == 200:
# Parse HAProxy stats (simplified)
stats_healthy = "UP" in response.text
return {
"haproxy_load_balancer": {
"healthy": stats_healthy,
"stats_accessible": True,
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"],
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat() if self.circuit_breaker_status["last_failure_time"] else None
},
"groq_endpoints": {
"managed_by": "haproxy",
"failover_handled_by": "haproxy"
}
}
else:
return {
"haproxy_load_balancer": {
"healthy": False,
"error": f"Stats endpoint returned {response.status_code}",
"last_check": datetime.utcnow().isoformat()
}
}
except Exception as e:
return {
"haproxy_load_balancer": {
"healthy": False,
"error": str(e),
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"]
}
}
async def _is_circuit_closed(self) -> bool:
"""Check if circuit breaker allows requests"""
if self.circuit_breaker_status["state"] == "closed":
return True
if self.circuit_breaker_status["state"] == "open":
# Check if recovery timeout has passed
if self.circuit_breaker_status["last_failure_time"]:
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
# Move to half-open state
self.circuit_breaker_status["state"] = "half_open"
logger.info("Circuit breaker moved to half-open state")
return True
return False
if self.circuit_breaker_status["state"] == "half_open":
# Allow limited requests in half-open state
return True
return False
async def _record_success(self):
"""Record successful request for circuit breaker"""
if self.circuit_breaker_status["state"] == "half_open":
# Success in half-open state closes the circuit
self.circuit_breaker_status["state"] = "closed"
self.circuit_breaker_status["failure_count"] = 0
logger.info("Circuit breaker closed after successful request")
# Reset failure count on any success
self.circuit_breaker_status["failure_count"] = 0
async def _record_failure(self):
"""Record failed request for circuit breaker"""
self.circuit_breaker_status["failure_count"] += 1
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
self.circuit_breaker_status["state"] = "open"
logger.warning(f"Circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
async def _track_usage(
self,
user_id: str,
tenant_id: str,
model: str,
tokens: int,
latency: float,
cost_per_1k: float
):
"""Track usage metrics for billing and monitoring"""
# Create usage key
usage_key = f"{tenant_id}:{user_id}:{model}"
# Initialize metrics if not exists
if usage_key not in self.usage_metrics:
self.usage_metrics[usage_key] = {
"total_tokens": 0,
"total_requests": 0,
"total_cost_cents": 0,
"average_latency": 0
}
# Update metrics
metrics = self.usage_metrics[usage_key]
metrics["total_tokens"] += tokens
metrics["total_requests"] += 1
metrics["total_cost_cents"] += self._calculate_cost(tokens, cost_per_1k)
# Update average latency
prev_avg = metrics["average_latency"]
prev_count = metrics["total_requests"] - 1
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
# Log high-level metrics
if metrics["total_requests"] % 100 == 0:
logger.info(f"Usage milestone for {usage_key}: {metrics}")
def _calculate_cost(self, tokens: int, cost_per_1k: float) -> int:
"""Calculate cost in cents"""
return int((tokens / 1000) * cost_per_1k * 100)
def _calculate_compound_cost(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Calculate detailed cost breakdown for Groq Compound responses.
Compound API returns usage_breakdown with per-model token counts
and executed_tools list showing which tools were called.
Returns:
Dict with total cost in dollars and detailed breakdown
"""
total_cost = 0.0
breakdown = {"models": [], "tools": [], "total_cost_dollars": 0.0, "total_cost_cents": 0}
# Parse usage_breakdown for per-model token costs
usage_breakdown = response_data.get("usage_breakdown", {})
models_usage = usage_breakdown.get("models", [])
for model_usage in models_usage:
model_name = model_usage.get("model", "")
usage = model_usage.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
# Get model pricing (try multiple name formats)
model_prices = GROQ_MODEL_PRICES.get(model_name)
if not model_prices:
# Try without provider prefix
short_name = model_name.split("/")[-1] if "/" in model_name else model_name
model_prices = GROQ_MODEL_PRICES.get(short_name, {"input": 0.15, "output": 0.60})
# Calculate cost per million tokens
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
model_total = input_cost + output_cost
breakdown["models"].append({
"model": model_name,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"input_cost_dollars": round(input_cost, 6),
"output_cost_dollars": round(output_cost, 6),
"total_cost_dollars": round(model_total, 6)
})
total_cost += model_total
# Parse executed_tools for tool costs
executed_tools = response_data.get("executed_tools", [])
for tool in executed_tools:
# Handle both string and dict formats
tool_name = tool if isinstance(tool, str) else tool.get("name", "unknown")
tool_cost = COMPOUND_TOOL_PRICES.get(tool_name.lower(), 0.008) # Default to advanced search
breakdown["tools"].append({
"tool": tool_name,
"cost_dollars": round(tool_cost, 6)
})
total_cost += tool_cost
breakdown["total_cost_dollars"] = round(total_cost, 6)
breakdown["total_cost_cents"] = int(total_cost * 100)
return breakdown
def _is_compound_model(self, model: str) -> bool:
"""Check if model is a Groq Compound model"""
model_lower = model.lower()
return "compound" in model_lower or model_lower.startswith("groq/compound")
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available Groq models with their configurations"""
models = []
model_configs = get_model_configs()
for model_id, config in model_configs.get("groq", {}).items():
models.append({
"id": model_id,
"name": model_id.replace("-", " ").title(),
"provider": "groq",
"max_tokens": config["max_tokens"],
"cost_per_1k_tokens": config["cost_per_1k_tokens"],
"supports_streaming": config["supports_streaming"],
"supports_function_calling": config["supports_function_calling"]
})
return models
async def execute_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str = "llama-3.1-70b-versatile",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
"""Execute LLM inference using messages format (conversation style)"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - service temporarily unavailable")
# Validate model and get configuration
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
if not model_config:
# Try to get from model service registry
model_service = get_model_service(tenant_id)
model_info = await model_service.get_model(model)
if not model_info:
raise ValueError(f"Unsupported model: {model}")
model_config = {
"max_tokens": model_info["performance"]["max_tokens"],
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
"supports_streaming": model_info["capabilities"].get("streaming", False)
}
# Apply token limits
max_tokens = min(max_tokens, model_config["max_tokens"])
try:
# Get tenant-specific API key
if not tenant_id:
raise ValueError("tenant_id is required for Groq inference")
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
start_time = time.time()
# Translate GT 2.0 "agent" role to OpenAI/Groq "assistant" for external API compatibility
# Use dictionary unpacking to preserve ALL fields including tool_call_id
external_messages = []
for msg in messages:
external_msg = {
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
}
external_messages.append(external_msg)
if stream:
# Return the stream generator directly; awaiting an async generator would raise a TypeError
return self._stream_inference_with_messages(
external_messages, model, temperature, max_tokens, user_id, tenant_id, client
)
else:
# Prepare request parameters
request_params = {
"model": model,
"messages": external_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": False
}
# Add tools if provided
if tools:
request_params["tools"] = tools
if tool_choice:
request_params["tool_choice"] = tool_choice
# Debug: Log messages being sent to Groq
logger.info(f"🔧 Sending {len(external_messages)} messages to Groq API")
for i, msg in enumerate(external_messages):
if msg.get("role") == "tool":
logger.info(f"🔧 Groq Message {i}: role=tool, tool_call_id={msg.get('tool_call_id')}")
else:
logger.info(f"🔧 Groq Message {i}: role={msg.get('role')}, has_tool_calls={bool(msg.get('tool_calls'))}")
response = await client.chat.completions.create(**request_params)
# Track successful usage
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
response.usage.total_tokens if response.usage else 0,
latency, model_config["cost_per_1k_tokens"]
)
# Track in model service
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=True,
latency_ms=latency
)
# Reset circuit breaker on success
await self._record_success()
# Build base response
result = {
"content": response.choices[0].message.content,
"model": model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
"cost_cents": self._calculate_cost(
response.usage.total_tokens if response.usage else 0,
model_config["cost_per_1k_tokens"]
)
},
"latency_ms": latency,
"load_balanced": True,
"haproxy_backend": "groq_general_backend"
}
# For Compound models, extract and calculate detailed cost breakdown
if self._is_compound_model(model):
# Convert response to dict for processing
response_dict = response.model_dump() if hasattr(response, 'model_dump') else {}
# Extract usage_breakdown and executed_tools if present
usage_breakdown = getattr(response, 'usage_breakdown', None)
executed_tools = getattr(response, 'executed_tools', None)
if usage_breakdown or executed_tools:
compound_data = {
"usage_breakdown": usage_breakdown if isinstance(usage_breakdown, dict) else {},
"executed_tools": executed_tools if isinstance(executed_tools, list) else []
}
# Calculate detailed cost breakdown
cost_breakdown = self._calculate_compound_cost(compound_data)
# Add compound-specific data to response
result["usage_breakdown"] = compound_data.get("usage_breakdown", {})
result["executed_tools"] = compound_data.get("executed_tools", [])
result["cost_breakdown"] = cost_breakdown
# Update cost_cents with accurate compound calculation
if cost_breakdown["total_cost_cents"] > 0:
result["usage"]["cost_cents"] = cost_breakdown["total_cost_cents"]
logger.info(f"Compound model cost breakdown: {cost_breakdown}")
return result
except Exception as e:
logger.error(f"HAProxy Groq inference with messages failed: {e}")
# Track failure in model service (resolve it here; it may not be bound yet
# if the failure happened before the success path assigned it)
model_service = get_model_service(tenant_id)
await model_service.track_model_usage(
model_id=model,
success=False
)
# Record failure for circuit breaker
await self._record_failure()
# Re-raise the exception
raise Exception(f"Groq inference with messages failed (via HAProxy): {str(e)}")
async def _stream_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
user_id: str,
tenant_id: str,
client: AsyncGroq = None
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses using messages format"""
model_configs = get_model_configs(tenant_id)
model_config = model_configs.get("groq", {}).get(model)
start_time = time.time()
total_tokens = 0
try:
# Use provided client or get tenant-specific client
if not client:
api_key = await self._get_tenant_api_key(tenant_id)
client = self._get_client(api_key)
stream = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
total_tokens += len(content.split()) # Approximate token count
# Yield just the content (SSE formatting handled by caller)
yield content
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
await self._track_usage(
user_id, tenant_id, model,
total_tokens, latency,
model_config["cost_per_1k_tokens"] if model_config else 0.0
)
except Exception as e:
logger.error(f"Streaming inference with messages error: {e}")
raise e
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string from X-Tenant-ID header
Returns:
Decrypted Groq API key
Raises:
ValueError: If no API key configured (results in HTTP 503 to client)
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
raise ValueError(str(e))
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
def _get_client(self, api_key: str) -> AsyncGroq:
"""Get Groq client with specified API key"""
if not GROQ_AVAILABLE:
raise Exception("Groq client not available in development mode")
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
return AsyncGroq(
api_key=api_key,
base_url=haproxy_endpoint,
timeout=httpx.Timeout(30.0),
max_retries=1
)
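
A usage sketch for the Groq proxy above, assuming a tenant with a Groq key configured in the Control Panel and a reachable HAProxy endpoint; the tenant domain is a placeholder:

    import asyncio

    from app.core.backends.groq_proxy import GroqProxyBackend

    async def demo() -> None:
        backend = GroqProxyBackend()
        result = await backend.execute_inference_with_messages(
            messages=[
                {"role": "user", "content": "Summarize the circuit breaker states in one sentence."},
            ],
            model="llama-3.3-70b-versatile",
            max_tokens=256,
            user_id="demo-user",
            tenant_id="example.edu",
        )
        print(result["content"])
        print("cost (cents):", result["usage"]["cost_cents"])

    asyncio.run(demo())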

View File

@@ -0,0 +1,407 @@
"""
NVIDIA NIM LLM Proxy Backend
Provides LLM inference through NVIDIA NIM with:
- OpenAI-compatible API format (build.nvidia.com)
- Token usage tracking and cost calculation
- Streaming response support
- Circuit breaker pattern for enhanced reliability
"""
import json
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
import httpx
import logging
from app.core.config import get_settings
logger = logging.getLogger(__name__)
# NVIDIA NIM Model pricing per million tokens (input/output)
# Source: build.nvidia.com (Dec 2025 pricing estimates)
# Note: Actual pricing may vary - check build.nvidia.com for current rates
NVIDIA_MODEL_PRICES = {
# Llama Nemotron family
"nvidia/llama-3.1-nemotron-ultra-253b-v1": {"input": 2.0, "output": 6.0},
"nvidia/llama-3.1-nemotron-super-49b-v1": {"input": 0.5, "output": 1.5},
"nvidia/llama-3.1-nemotron-nano-8b-v1": {"input": 0.1, "output": 0.3},
# Standard Llama models via NIM
"meta/llama-3.1-8b-instruct": {"input": 0.1, "output": 0.3},
"meta/llama-3.1-70b-instruct": {"input": 0.5, "output": 1.0},
"meta/llama-3.1-405b-instruct": {"input": 2.0, "output": 6.0},
# Mistral models
"mistralai/mistral-7b-instruct-v0.3": {"input": 0.1, "output": 0.2},
"mistralai/mixtral-8x7b-instruct-v0.1": {"input": 0.3, "output": 0.6},
# Default fallback
"default": {"input": 0.5, "output": 1.5},
}
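# Worked example with the per-million rates above: a call to
# nvidia/llama-3.1-nemotron-super-49b-v1 with 2,000 prompt and 1,000 completion tokens costs
#   (2_000 / 1_000_000) * 0.5 + (1_000 / 1_000_000) * 1.5 = 0.001 + 0.0015 = $0.0025,
# which the cents-based accounting below truncates to 0 cents via int().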
class NvidiaProxyBackend:
"""LLM inference via NVIDIA NIM with OpenAI-compatible API"""
def __init__(self):
self.settings = get_settings()
self.base_url = getattr(self.settings, 'nvidia_nim_endpoint', None) or "https://integrate.api.nvidia.com/v1"
self.usage_metrics = {}
self.circuit_breaker_status = {
"state": "closed", # closed, open, half_open
"failure_count": 0,
"last_failure_time": None,
"failure_threshold": 5,
"recovery_timeout": 60 # seconds
}
logger.info(f"Initialized NVIDIA NIM backend with endpoint: {self.base_url}")
async def _get_tenant_api_key(self, tenant_id: str) -> str:
"""
Get API key for tenant from Control Panel database.
NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
API keys are managed in Control Panel and fetched via internal API.
Args:
tenant_id: Tenant domain string from X-Tenant-ID header
Returns:
Decrypted NVIDIA API key
Raises:
ValueError: If no API key configured (results in HTTP 503 to client)
"""
from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError
client = get_api_key_client()
try:
key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
return key_info["api_key"]
except APIKeyNotConfiguredError as e:
logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
raise ValueError(str(e))
except RuntimeError as e:
logger.error(f"Control Panel error: {e}")
raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")
def _get_client(self, api_key: str) -> httpx.AsyncClient:
"""Get configured HTTP client for NVIDIA NIM API"""
return httpx.AsyncClient(
base_url=self.base_url,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
timeout=httpx.Timeout(120.0) # Longer timeout for large models
)
async def execute_inference(
self,
prompt: str,
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None
) -> Dict[str, Any]:
"""Execute LLM inference with simple prompt"""
messages = [{"role": "user", "content": prompt}]
return await self.execute_inference_with_messages(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
user_id=user_id,
tenant_id=tenant_id
)
async def execute_inference_with_messages(
self,
messages: List[Dict[str, str]],
model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
temperature: float = 0.7,
max_tokens: int = 4000,
stream: bool = False,
user_id: str = None,
tenant_id: str = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
"""Execute LLM inference using messages format (conversation style)"""
# Check circuit breaker
if not await self._is_circuit_closed():
raise Exception("Circuit breaker is open - NVIDIA NIM service temporarily unavailable")
if not tenant_id:
raise ValueError("tenant_id is required for NVIDIA NIM inference")
try:
api_key = await self._get_tenant_api_key(tenant_id)
# Translate GT 2.0 "agent" role to OpenAI "assistant" for external API compatibility
external_messages = []
for msg in messages:
external_msg = {
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
}
external_messages.append(external_msg)
# Build request payload
request_data = {
"model": model,
"messages": external_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream
}
# Add tools if provided
if tools:
request_data["tools"] = tools
if tool_choice:
request_data["tool_choice"] = tool_choice
start_time = time.time()
if stream:
# Give the streaming generator its own client; returning from inside the
# "async with" block below would close the client before iteration begins
client = self._get_client(api_key)
return self._stream_inference_with_messages(
client, request_data, user_id, tenant_id, model
)
async with self._get_client(api_key) as client:
# Non-streaming request
response = await client.post("/chat/completions", json=request_data)
response.raise_for_status()
data = response.json()
latency = (time.time() - start_time) * 1000
# Calculate cost
usage = data.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
cost_cents = int((input_cost + output_cost) * 100)
# Track usage
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
# Reset circuit breaker on success
await self._record_success()
# Build response
result = {
"content": data["choices"][0]["message"]["content"],
"model": model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost_cents": cost_cents
},
"latency_ms": latency,
"provider": "nvidia"
}
# Include tool calls if present
message = data["choices"][0]["message"]
if message.get("tool_calls"):
result["tool_calls"] = message["tool_calls"]
return result
except httpx.HTTPStatusError as e:
logger.error(f"NVIDIA NIM API error: {e.response.status_code} - {e.response.text}")
await self._record_failure()
raise Exception(f"NVIDIA NIM inference failed: HTTP {e.response.status_code}")
except Exception as e:
logger.error(f"NVIDIA NIM inference failed: {e}")
await self._record_failure()
raise Exception(f"NVIDIA NIM inference failed: {str(e)}")
async def _stream_inference_with_messages(
self,
client: httpx.AsyncClient,
request_data: Dict[str, Any],
user_id: str,
tenant_id: str,
model: str
) -> AsyncGenerator[str, None]:
"""Stream LLM inference responses"""
start_time = time.time()
total_tokens = 0
try:
async with client.stream("POST", "/chat/completions", json=request_data) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
if chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("content"):
content = chunk["choices"][0]["delta"]["content"]
total_tokens += len(content.split()) # Approximate
yield content
except json.JSONDecodeError:
continue
# Track usage after streaming completes
latency = (time.time() - start_time) * 1000
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
cost_cents = int((total_tokens / 1_000_000) * model_prices["output"] * 100)
await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)
await self._record_success()
except Exception as e:
logger.error(f"NVIDIA NIM streaming error: {e}")
await self._record_failure()
raise e
finally:
# Close the dedicated client handed over by execute_inference_with_messages
await client.aclose()
async def check_health(self) -> Dict[str, Any]:
"""Check health of NVIDIA NIM backend and circuit breaker status"""
return {
"nvidia_nim": {
"endpoint": self.base_url,
"status": "available" if self.circuit_breaker_status["state"] == "closed" else "degraded",
"last_check": datetime.utcnow().isoformat()
},
"circuit_breaker": {
"state": self.circuit_breaker_status["state"],
"failure_count": self.circuit_breaker_status["failure_count"],
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat()
if self.circuit_breaker_status["last_failure_time"] else None
}
}
async def _is_circuit_closed(self) -> bool:
"""Check if circuit breaker allows requests"""
if self.circuit_breaker_status["state"] == "closed":
return True
if self.circuit_breaker_status["state"] == "open":
# Check if recovery timeout has passed
if self.circuit_breaker_status["last_failure_time"]:
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
# Move to half-open state
self.circuit_breaker_status["state"] = "half_open"
logger.info("NVIDIA NIM circuit breaker moved to half-open state")
return True
return False
if self.circuit_breaker_status["state"] == "half_open":
# Allow limited requests in half-open state
return True
return False
async def _record_success(self):
"""Record successful request for circuit breaker"""
if self.circuit_breaker_status["state"] == "half_open":
# Success in half-open state closes the circuit
self.circuit_breaker_status["state"] = "closed"
self.circuit_breaker_status["failure_count"] = 0
logger.info("NVIDIA NIM circuit breaker closed after successful request")
# Reset failure count on any success
self.circuit_breaker_status["failure_count"] = 0
async def _record_failure(self):
"""Record failed request for circuit breaker"""
self.circuit_breaker_status["failure_count"] += 1
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
self.circuit_breaker_status["state"] = "open"
logger.warning(f"NVIDIA NIM circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
async def _track_usage(
self,
user_id: str,
tenant_id: str,
model: str,
tokens: int,
latency: float,
cost_cents: int
):
"""Track usage metrics for billing and monitoring"""
# Create usage key
usage_key = f"{tenant_id}:{user_id}:{model}"
# Initialize metrics if not exists
if usage_key not in self.usage_metrics:
self.usage_metrics[usage_key] = {
"total_tokens": 0,
"total_requests": 0,
"total_cost_cents": 0,
"average_latency": 0
}
# Update metrics
metrics = self.usage_metrics[usage_key]
metrics["total_tokens"] += tokens
metrics["total_requests"] += 1
metrics["total_cost_cents"] += cost_cents
# Update average latency
prev_avg = metrics["average_latency"]
prev_count = metrics["total_requests"] - 1
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
# Log high-level metrics periodically
if metrics["total_requests"] % 100 == 0:
logger.info(f"NVIDIA NIM usage milestone for {usage_key}: {metrics}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int, model: str) -> int:
"""Calculate cost in cents based on token usage"""
model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
return int((input_cost + output_cost) * 100)
async def get_available_models(self) -> List[Dict[str, Any]]:
"""Get list of available NVIDIA NIM models with their configurations"""
models = []
for model_id, prices in NVIDIA_MODEL_PRICES.items():
if model_id == "default":
continue
models.append({
"id": model_id,
"name": model_id.split("/")[-1].replace("-", " ").title(),
"provider": "nvidia",
"max_tokens": 4096, # Default for most NIM models
"cost_per_1k_input": prices["input"],
"cost_per_1k_output": prices["output"],
"supports_streaming": True,
"supports_function_calling": True
})
return models
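
Finally, a usage sketch for the NVIDIA NIM proxy, assuming a tenant with an NVIDIA key in the Control Panel; the tenant domain is a placeholder, and the streaming path yields plain content chunks as implemented above:

    import asyncio

    from app.core.backends.nvidia_proxy import NvidiaProxyBackend

    async def demo() -> None:
        backend = NvidiaProxyBackend()
        stream = await backend.execute_inference_with_messages(
            messages=[{"role": "user", "content": "Name one Nemotron model size."}],
            model="nvidia/llama-3.1-nemotron-super-49b-v1",
            stream=True,
            user_id="demo-user",
            tenant_id="example.edu",
        )
        async for chunk in stream:
            print(chunk, end="", flush=True)
        print()

    asyncio.run(demo())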