GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents
- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
- Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
52
apps/resource-cluster/app/core/backends/__init__.py
Normal file
@@ -0,0 +1,52 @@
"""
Resource backend implementations for GT 2.0

Provides unified interfaces for all resource types:
- LLM inference (Groq, OpenAI, Anthropic)
- Vector databases (PGVector)
- Document processing (Unstructured)
- External services (OAuth2, iframe)
- AI literacy resources
"""

from typing import Dict, Any
import logging

logger = logging.getLogger(__name__)

# Registry of available backends
BACKEND_REGISTRY: Dict[str, Any] = {}


def register_backend(name: str, backend_class):
    """Register a resource backend"""
    BACKEND_REGISTRY[name] = backend_class
    logger.info(f"Registered backend: {name}")


def get_backend(name: str):
    """Get a registered backend"""
    if name not in BACKEND_REGISTRY:
        raise ValueError(f"Backend not found: {name}")
    return BACKEND_REGISTRY[name]


async def initialize_backends():
    """Initialize all resource backends"""
    from app.core.backends.groq_proxy import GroqProxyBackend
    from app.core.backends.nvidia_proxy import NvidiaProxyBackend
    from app.core.backends.document_processor import DocumentProcessorBackend
    from app.core.backends.embedding_backend import EmbeddingBackend

    # Register backends
    register_backend("groq_proxy", GroqProxyBackend())
    register_backend("nvidia_proxy", NvidiaProxyBackend())
    register_backend("document_processor", DocumentProcessorBackend())
    register_backend("embedding", EmbeddingBackend())

    logger.info("All resource backends initialized")


def get_embedding_backend():
    """Get the embedding backend instance"""
    return get_backend("embedding")
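A minimal usage sketch of the registry contract above. The MyBackend class and the call site are hypothetical illustrations, not part of this commit; note that initialize_backends() registers backend instances, not classes.

# Hypothetical caller of the registry in __init__.py (not part of the diff).
from app.core.backends import register_backend, get_backend

class MyBackend:
    async def check_health(self):
        return {"status": "healthy"}

register_backend("my_backend", MyBackend())   # instances are registered, mirroring initialize_backends()
backend = get_backend("my_backend")           # raises ValueError for an unknown name
# At service startup the app would run: await initialize_backends()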
322
apps/resource-cluster/app/core/backends/document_processor.py
Normal file
@@ -0,0 +1,322 @@
"""
Document Processing Backend

STATELESS document chunking and preprocessing for RAG operations.
All processing happens in memory - NO user data is ever stored.
"""

import logging
import io
import gc
from typing import Dict, Any, List, Optional, BinaryIO
from dataclasses import dataclass
import hashlib

# Document processing imports
import pypdf as PyPDF2
from docx import Document as DocxDocument
from bs4 import BeautifulSoup
from langchain_text_splitters import (
    RecursiveCharacterTextSplitter,
    TokenTextSplitter,
    SentenceTransformersTokenTextSplitter
)

logger = logging.getLogger(__name__)


@dataclass
class ChunkingStrategy:
    """Configuration for document chunking"""
    strategy_type: str  # 'fixed', 'semantic', 'hierarchical', 'hybrid'
    chunk_size: int  # Target chunk size in tokens (optimized for BGE-M3: 512)
    chunk_overlap: int  # Overlap between chunks (typically 128 for BGE-M3)
    separator_pattern: Optional[str] = None  # Custom separator for splitting
    preserve_paragraphs: bool = True
    preserve_sentences: bool = True


class DocumentProcessorBackend:
    """
    STATELESS document chunking and processing backend.

    Security principles:
    - NO persistence of user data
    - All processing in memory only
    - Immediate memory cleanup after processing
    - No caching of user content
    """

    def __init__(self):
        self.supported_formats = [".pdf", ".docx", ".txt", ".md", ".html"]
        # BGE-M3 optimal settings
        self.default_chunk_size = 512  # tokens
        self.default_chunk_overlap = 128  # tokens
        self.model_name = "BAAI/bge-m3"  # For tokenization
        logger.info("STATELESS document processor backend initialized")

    async def process_document(
        self,
        content: bytes,
        document_type: str,
        strategy: Optional[ChunkingStrategy] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Process document into chunks - STATELESS operation.

        Args:
            content: Document content as bytes (will be cleared from memory)
            document_type: File type (.pdf, .docx, .txt, .md, .html)
            strategy: Chunking strategy configuration
            metadata: Optional metadata (will NOT include user content)

        Returns:
            List of chunks with metadata (immediately returned, not stored)
        """
        try:
            # Use default strategy if not provided
            if strategy is None:
                strategy = ChunkingStrategy(
                    strategy_type='hybrid',
                    chunk_size=self.default_chunk_size,
                    chunk_overlap=self.default_chunk_overlap
                )

            # Extract text based on document type (in memory)
            text = await self._extract_text_from_bytes(content, document_type)

            # Clear original content from memory
            del content
            gc.collect()

            # Apply chunking strategy
            if strategy.strategy_type == 'semantic':
                chunks = await self._semantic_chunking(text, strategy)
            elif strategy.strategy_type == 'hierarchical':
                chunks = await self._hierarchical_chunking(text, strategy)
            elif strategy.strategy_type == 'hybrid':
                chunks = await self._hybrid_chunking(text, strategy)
            else:  # 'fixed'
                chunks = await self._fixed_chunking(text, strategy)

            # Clear text from memory
            del text
            gc.collect()

            # Add metadata without storing content
            processed_chunks = []
            for idx, chunk in enumerate(chunks):
                chunk_metadata = {
                    "chunk_index": idx,
                    "total_chunks": len(chunks),
                    "chunking_strategy": strategy.strategy_type,
                    "chunk_size_tokens": strategy.chunk_size,
                    # Generate hash for deduplication without storing content
                    "content_hash": hashlib.sha256(chunk.encode()).hexdigest()[:16]
                }

                # Add non-sensitive metadata if provided
                if metadata:
                    # Filter out any potential sensitive data
                    safe_metadata = {
                        k: v for k, v in metadata.items()
                        if k in ['document_type', 'processing_timestamp', 'tenant_id']
                    }
                    chunk_metadata.update(safe_metadata)

                processed_chunks.append({
                    "text": chunk,
                    "metadata": chunk_metadata
                })

            logger.info(f"Processed document into {len(processed_chunks)} chunks (STATELESS)")

            # Return immediately - no storage
            return processed_chunks

        except Exception as e:
            logger.error(f"Error processing document: {e}")
            # Ensure memory is cleared even on error
            gc.collect()
            raise
        finally:
            # Always ensure memory cleanup
            gc.collect()

    async def _extract_text_from_bytes(
        self,
        content: bytes,
        document_type: str
    ) -> str:
        """Extract text from document bytes - in memory only"""

        try:
            if document_type == ".pdf":
                return await self._extract_pdf_text(io.BytesIO(content))
            elif document_type == ".docx":
                return await self._extract_docx_text(io.BytesIO(content))
            elif document_type == ".html":
                return await self._extract_html_text(content.decode('utf-8'))
            elif document_type in [".txt", ".md"]:
                return content.decode('utf-8')
            else:
                raise ValueError(f"Unsupported document type: {document_type}")
        finally:
            # Clear content from memory
            del content
            gc.collect()

    async def _extract_pdf_text(self, file_stream: BinaryIO) -> str:
        """Extract text from PDF - in memory"""
        text = ""
        try:
            pdf_reader = PyPDF2.PdfReader(file_stream)
            for page_num in range(len(pdf_reader.pages)):
                page = pdf_reader.pages[page_num]
                text += page.extract_text() + "\n"
        finally:
            file_stream.close()
            gc.collect()
        return text

    async def _extract_docx_text(self, file_stream: BinaryIO) -> str:
        """Extract text from DOCX - in memory"""
        text = ""
        try:
            doc = DocxDocument(file_stream)
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
        finally:
            file_stream.close()
            gc.collect()
        return text

    async def _extract_html_text(self, html_content: str) -> str:
        """Extract text from HTML - in memory"""
        soup = BeautifulSoup(html_content, 'html.parser')
        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.decompose()
        text = soup.get_text()
        # Clean up whitespace (split on double spaces, per the standard BeautifulSoup recipe)
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = '\n'.join(chunk for chunk in chunks if chunk)
        return text

    async def _semantic_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Semantic chunking using sentence boundaries"""
        splitter = SentenceTransformersTokenTextSplitter(
            model_name=self.model_name,
            chunk_size=strategy.chunk_size,
            chunk_overlap=strategy.chunk_overlap
        )
        return splitter.split_text(text)

    async def _hierarchical_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Hierarchical chunking preserving document structure"""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=strategy.chunk_size * 3,  # Approximate token to char ratio
            chunk_overlap=strategy.chunk_overlap * 3,
            separators=["\n\n\n", "\n\n", "\n", ". ", " ", ""],
            keep_separator=True
        )
        return splitter.split_text(text)

    async def _hybrid_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Hybrid chunking combining semantic and structural boundaries"""
        # First split by structure
        structural_splitter = RecursiveCharacterTextSplitter(
            chunk_size=strategy.chunk_size * 4,
            chunk_overlap=0,
            separators=["\n\n\n", "\n\n"],
            keep_separator=True
        )
        structural_chunks = structural_splitter.split_text(text)

        # Then apply semantic splitting to each structural chunk
        final_chunks = []
        token_splitter = TokenTextSplitter(
            chunk_size=strategy.chunk_size,
            chunk_overlap=strategy.chunk_overlap
        )

        for struct_chunk in structural_chunks:
            semantic_chunks = token_splitter.split_text(struct_chunk)
            final_chunks.extend(semantic_chunks)

        return final_chunks

    async def _fixed_chunking(
        self,
        text: str,
        strategy: ChunkingStrategy
    ) -> List[str]:
        """Fixed-size chunking with token boundaries"""
        splitter = TokenTextSplitter(
            chunk_size=strategy.chunk_size,
            chunk_overlap=strategy.chunk_overlap
        )
        return splitter.split_text(text)

    async def validate_document(
        self,
        content_size: int,
        document_type: str
    ) -> Dict[str, Any]:
        """
        Validate document before processing - no content stored.

        Args:
            content_size: Size of document in bytes
            document_type: File extension

        Returns:
            Validation result with any warnings
        """
        MAX_SIZE = 50 * 1024 * 1024  # 50MB max

        validation = {
            "valid": True,
            "warnings": [],
            "errors": []
        }

        # Check file size
        if content_size > MAX_SIZE:
            validation["valid"] = False
            validation["errors"].append("File size exceeds maximum of 50MB")
        elif content_size > 10 * 1024 * 1024:  # Warning for files over 10MB
            validation["warnings"].append("Large file may take longer to process")

        # Check document type
        if document_type not in self.supported_formats:
            validation["valid"] = False
            validation["errors"].append(f"Unsupported format: {document_type}")

        return validation

    async def check_health(self) -> Dict[str, Any]:
        """Check document processor health - no user data exposed"""
        return {
            "status": "healthy",
            "supported_formats": self.supported_formats,
            "default_chunk_size": self.default_chunk_size,
            "default_chunk_overlap": self.default_chunk_overlap,
            "model": self.model_name,
            "stateless": True,  # Confirm stateless operation
            "memory_cleared": True  # Confirm memory management
        }
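A short driver sketch for the processor above, assuming the langchain text splitters and their tokenizer dependencies are installed. The markdown bytes, tenant id, and the main() wrapper are invented for illustration; the 512/128 values are the BGE-M3 defaults from the class.

# Hypothetical caller of DocumentProcessorBackend (not part of the diff).
import asyncio
from app.core.backends.document_processor import DocumentProcessorBackend, ChunkingStrategy

async def main():
    processor = DocumentProcessorBackend()
    content = b"# Title\n\nSome markdown body text for chunking.\n"
    ok = await processor.validate_document(content_size=len(content), document_type=".md")
    if ok["valid"]:
        chunks = await processor.process_document(
            content,
            document_type=".md",
            strategy=ChunkingStrategy(strategy_type="hybrid", chunk_size=512, chunk_overlap=128),
            metadata={"document_type": ".md", "tenant_id": "tenant-a"},  # only whitelisted keys survive
        )
        print(len(chunks), chunks[0]["metadata"]["content_hash"])

asyncio.run(main())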
471
apps/resource-cluster/app/core/backends/embedding_backend.py
Normal file
@@ -0,0 +1,471 @@
"""
Embedding Model Backend

STATELESS embedding generation using BGE-M3 model hosted on GT's GPU clusters.
All embeddings are generated in real-time - NO user data is stored.
"""

import logging
import gc
import hashlib
import asyncio
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
# import numpy as np  # Temporarily disabled for Docker build
import aiohttp
import json

from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()


@dataclass
class EmbeddingRequest:
    """Request structure for embedding generation"""
    texts: List[str]
    model: str = "BAAI/bge-m3"
    batch_size: int = 32
    normalize: bool = True
    instruction: Optional[str] = None  # For instruction-based embeddings


class EmbeddingBackend:
    """
    STATELESS embedding backend for BGE-M3 model.

    Security principles:
    - NO persistence of embeddings or text
    - All processing via GT's internal GPU cluster
    - Immediate memory cleanup after generation
    - No caching of user content
    - Request signing and verification
    """

    def __init__(self):
        self.model_name = "BAAI/bge-m3"
        self.embedding_dimensions = 1024  # BGE-M3 dimensions
        self.max_batch_size = 32
        self.max_sequence_length = 8192  # BGE-M3 supports up to 8192 tokens

        # Determine endpoint based on configuration
        self.embedding_endpoint = self._get_embedding_endpoint()

        # Timeout for embedding requests
        self.request_timeout = 60  # seconds for model loading

        logger.info(f"STATELESS embedding backend initialized for {self.model_name}")
        logger.info(f"Using embedding endpoint: {self.embedding_endpoint}")

    def _get_embedding_endpoint(self) -> str:
        """
        Get the embedding endpoint based on configuration.
        Priority:
        1. Model registry from config sync (database-backed)
        2. Environment variables (BGE_M3_LOCAL_MODE, BGE_M3_EXTERNAL_ENDPOINT)
        3. Default local endpoint
        """
        # Try to get configuration from model registry first (loaded from database)
        try:
            from app.services.model_service import default_model_service
            import asyncio

            # Use the default model service instance (singleton) used by config sync
            model_service = default_model_service

            # Try to get the model config synchronously (during initialization)
            # The get_model method is async, so we need to handle this carefully
            bge_m3_config = model_service.model_registry.get("BAAI/bge-m3")

            if bge_m3_config:
                # Model registry stores endpoint as 'endpoint_url' and config as 'parameters'
                endpoint = bge_m3_config.get("endpoint_url")
                config = bge_m3_config.get("parameters", {})
                is_local_mode = config.get("is_local_mode", True)
                external_endpoint = config.get("external_endpoint")

                logger.info(f"Found BGE-M3 in registry: endpoint_url={endpoint}, is_local_mode={is_local_mode}, external_endpoint={external_endpoint}")

                if endpoint:
                    logger.info(f"Using BGE-M3 endpoint from model registry (is_local_mode={is_local_mode}): {endpoint}")
                    return endpoint
                else:
                    logger.warning(f"BGE-M3 found in registry but endpoint_url is None/empty. Full config: {bge_m3_config}")
            else:
                available_models = list(model_service.model_registry.keys())
                logger.debug(f"BGE-M3 not found in model registry during init (expected on first startup). Available models: {available_models}")
        except Exception as e:
            logger.debug(f"Model registry not yet available during startup (will be populated after config sync): {e}")

        # Fall back to Settings fields (environment variables or .env file)
        is_local_mode = getattr(settings, 'bge_m3_local_mode', True)
        external_endpoint = getattr(settings, 'bge_m3_external_endpoint', None)

        if not is_local_mode and external_endpoint:
            logger.info(f"Using external BGE-M3 endpoint from settings: {external_endpoint}")
            return external_endpoint

        # Default to local endpoint
        local_endpoint = getattr(
            settings,
            'embedding_endpoint',
            'http://gentwo-vllm-embeddings:8000/v1/embeddings'
        )
        logger.info(f"Using local BGE-M3 endpoint: {local_endpoint}")
        return local_endpoint

    async def update_endpoint_config(self, is_local_mode: bool, external_endpoint: str = None):
        """
        Update the embedding endpoint configuration dynamically.
        This allows switching between local and external endpoints without restart.
        """
        if is_local_mode:
            self.embedding_endpoint = getattr(
                settings,
                'embedding_endpoint',
                'http://gentwo-vllm-embeddings:8000/v1/embeddings'
            )
        else:
            if external_endpoint:
                self.embedding_endpoint = external_endpoint
            else:
                raise ValueError("External endpoint must be provided when not in local mode")

        logger.info(f"BGE-M3 endpoint updated to: {self.embedding_endpoint}")
        logger.info(f"Mode: {'Local GT Edge' if is_local_mode else 'External API'}")

    def refresh_endpoint_from_registry(self):
        """
        Refresh the embedding endpoint from the model registry.
        Called by config sync when BGE-M3 configuration changes.
        """
        logger.info(f"Refreshing embedding endpoint - current: {self.embedding_endpoint}")
        new_endpoint = self._get_embedding_endpoint()
        if new_endpoint != self.embedding_endpoint:
            logger.info(f"Refreshing BGE-M3 endpoint from {self.embedding_endpoint} to {new_endpoint}")
            self.embedding_endpoint = new_endpoint
        else:
            logger.info(f"BGE-M3 endpoint unchanged: {self.embedding_endpoint}")

    async def generate_embeddings(
        self,
        texts: List[str],
        instruction: Optional[str] = None,
        tenant_id: str = None,
        request_id: str = None
    ) -> List[List[float]]:
        """
        Generate embeddings for texts using BGE-M3 - STATELESS operation.

        Args:
            texts: List of texts to embed (will be cleared from memory)
            instruction: Optional instruction for query vs document embeddings
            tenant_id: Tenant ID for audit logging (not stored with data)
            request_id: Request ID for tracing

        Returns:
            List of embedding vectors (immediately returned, not stored)
        """
        try:
            # Validate input
            if not texts:
                return []

            if len(texts) > self.max_batch_size:
                # Process in batches
                return await self._batch_process_embeddings(
                    texts, instruction, tenant_id, request_id
                )

            # Prepare request
            request_data = {
                "model": self.model_name,
                "input": texts,
                "encoding_format": "float",
                "dimensions": self.embedding_dimensions
            }

            # Add instruction if provided (for query vs document distinction)
            if instruction:
                request_data["instruction"] = instruction

            # Add metadata for audit (not stored with embeddings)
            metadata = {
                "tenant_id": tenant_id,
                "request_id": request_id,
                "text_count": len(texts),
                # Hash for deduplication without storing content
                "content_hash": hashlib.sha256(
                    "".join(texts).encode()
                ).hexdigest()[:16]
            }

            # Call vLLM service - NO FALLBACKS
            embeddings = await self._call_embedding_service(request_data, metadata)

            # Clear texts from memory immediately
            del texts
            gc.collect()

            # Validate response
            if not embeddings or len(embeddings) == 0:
                raise ValueError("No embeddings returned from service")

            # Normalize if needed
            if self._should_normalize():
                embeddings = self._normalize_embeddings(embeddings)

            logger.info(
                f"Generated {len(embeddings)} embeddings (STATELESS) "
                f"for tenant {tenant_id}"
            )

            # Return immediately - no storage
            return embeddings

        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            # Ensure memory is cleared even on error
            gc.collect()
            raise
        finally:
            # Always ensure memory cleanup
            gc.collect()

    async def _batch_process_embeddings(
        self,
        texts: List[str],
        instruction: Optional[str],
        tenant_id: str,
        request_id: str
    ) -> List[List[float]]:
        """Process large text lists in batches using vLLM service"""
        all_embeddings = []

        for i in range(0, len(texts), self.max_batch_size):
            batch = texts[i:i + self.max_batch_size]

            # Prepare request for this batch
            request_data = {
                "model": self.model_name,
                "input": batch,
                "encoding_format": "float",
                "dimensions": self.embedding_dimensions
            }

            if instruction:
                request_data["instruction"] = instruction

            metadata = {
                "tenant_id": tenant_id,
                "request_id": f"{request_id}_batch_{i}",
                "text_count": len(batch),
                "content_hash": hashlib.sha256(
                    "".join(batch).encode()
                ).hexdigest()[:16]
            }

            batch_embeddings = await self._call_embedding_service(request_data, metadata)
            all_embeddings.extend(batch_embeddings)

            # Clear batch from memory
            del batch
            gc.collect()

        return all_embeddings

    async def _call_embedding_service(
        self,
        request_data: Dict[str, Any],
        metadata: Dict[str, Any]
    ) -> List[List[float]]:
        """Call internal GPU cluster embedding service"""

        async with aiohttp.ClientSession() as session:
            try:
                # Add capability token for authentication
                headers = {
                    "Content-Type": "application/json",
                    "X-Tenant-ID": metadata.get("tenant_id", ""),
                    "X-Request-ID": metadata.get("request_id", ""),
                    # Authorization will be added by Resource Cluster
                }

                async with session.post(
                    self.embedding_endpoint,
                    json=request_data,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=self.request_timeout)
                ) as response:

                    if response.status != 200:
                        error_text = await response.text()
                        raise ValueError(
                            f"Embedding service error: {response.status} - {error_text}"
                        )

                    result = await response.json()

                    # Extract embeddings from response
                    if "data" in result:
                        embeddings = [item["embedding"] for item in result["data"]]
                    elif "embeddings" in result:
                        embeddings = result["embeddings"]
                    else:
                        raise ValueError("Invalid embedding service response format")

                    return embeddings

            except asyncio.TimeoutError:
                raise ValueError(f"Embedding service timeout after {self.request_timeout}s")
            except Exception as e:
                logger.error(f"Error calling embedding service: {e}")
                raise

    def _should_normalize(self) -> bool:
        """Check if embeddings should be normalized"""
        # BGE-M3 embeddings are typically normalized for similarity search
        return True

    def _normalize_embeddings(
        self,
        embeddings: List[List[float]]
    ) -> List[List[float]]:
        """Normalize embedding vectors to unit length"""
        # Simple normalization without numpy (for now)
        import math

        normalized = []

        for embedding in embeddings:
            # Calculate norm
            norm = math.sqrt(sum(x * x for x in embedding))

            if norm > 0:
                normalized_vec = [x / norm for x in embedding]
            else:
                normalized_vec = embedding[:]

            normalized.append(normalized_vec)

        return normalized

    async def generate_query_embeddings(
        self,
        queries: List[str],
        tenant_id: str = None,
        request_id: str = None
    ) -> List[List[float]]:
        """
        Generate embeddings specifically for queries.
        BGE-M3 can use different instructions for queries vs documents.
        """
        # For BGE-M3, queries can use a specific instruction
        instruction = "Represent this sentence for searching relevant passages: "
        return await self.generate_embeddings(
            queries, instruction, tenant_id, request_id
        )

    async def generate_document_embeddings(
        self,
        documents: List[str],
        tenant_id: str = None,
        request_id: str = None
    ) -> List[List[float]]:
        """
        Generate embeddings specifically for documents.
        BGE-M3 can use different instructions for documents vs queries.
        """
        # For BGE-M3, documents typically don't need special instruction
        return await self.generate_embeddings(
            documents, None, tenant_id, request_id
        )

    async def validate_texts(
        self,
        texts: List[str]
    ) -> Dict[str, Any]:
        """
        Validate texts before embedding - no content stored.

        Args:
            texts: List of texts to validate

        Returns:
            Validation result with any warnings
        """
        validation = {
            "valid": True,
            "warnings": [],
            "errors": [],
            "stats": {
                "total_texts": len(texts),
                "max_length": 0,
                "avg_length": 0
            }
        }

        if not texts:
            validation["valid"] = False
            validation["errors"].append("No texts provided")
            return validation

        # Check text lengths
        lengths = [len(text) for text in texts]
        validation["stats"]["max_length"] = max(lengths)
        validation["stats"]["avg_length"] = sum(lengths) // len(lengths)

        # BGE-M3 max sequence length check (approximate)
        max_chars = self.max_sequence_length * 4  # Rough char to token ratio

        for i, length in enumerate(lengths):
            if length > max_chars:
                validation["warnings"].append(
                    f"Text {i} may exceed model's max sequence length"
                )
            elif length == 0:
                validation["errors"].append(f"Text {i} is empty")
                validation["valid"] = False

        # Batch size check
        if len(texts) > self.max_batch_size * 10:
            validation["warnings"].append(
                f"Large batch ({len(texts)} texts) will be processed in chunks"
            )

        return validation

    async def check_health(self) -> Dict[str, Any]:
        """Check embedding backend health - no user data exposed"""
        try:
            # Test connection to vLLM service
            test_text = ["Health check test"]
            test_embeddings = await self.generate_embeddings(
                test_text,
                tenant_id="health_check",
                request_id="health_check"
            )

            health_status = {
                "status": "healthy",
                "model": self.model_name,
                "dimensions": self.embedding_dimensions,
                "max_batch_size": self.max_batch_size,
                "max_sequence_length": self.max_sequence_length,
                "endpoint": self.embedding_endpoint,
                "stateless": True,
                "memory_cleared": True,
                "vllm_service_connected": len(test_embeddings) > 0
            }

        except Exception as e:
            health_status = {
                "status": "unhealthy",
                "error": str(e),
                "model": self.model_name,
                "endpoint": self.embedding_endpoint
            }

        return health_status
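Because _normalize_embeddings scales each vector to unit length, cosine similarity between any two returned embeddings reduces to a plain dot product. A standalone sketch with toy numbers (no service call; not part of this commit):

# Standalone illustration: with unit-length vectors, cosine similarity == dot product.
import math

def normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec[:]

q = normalize([0.2, 0.9, 0.1])   # e.g. a query embedding
d = normalize([0.1, 0.8, 0.3])   # e.g. a document embedding
cosine = sum(a * b for a, b in zip(q, d))   # no division by norms needed
print(round(cosine, 4))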
780
apps/resource-cluster/app/core/backends/groq_proxy.py
Normal file
@@ -0,0 +1,780 @@
|
||||
"""
|
||||
Groq Cloud LLM Proxy Backend
|
||||
|
||||
Provides high-availability LLM inference through Groq Cloud with:
|
||||
- HAProxy load balancing across multiple endpoints
|
||||
- Automatic failover handled by HAProxy
|
||||
- Token usage tracking and cost calculation
|
||||
- Streaming response support
|
||||
- Circuit breaker pattern for enhanced reliability
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
try:
|
||||
from groq import AsyncGroq
|
||||
GROQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
# Groq not available in development mode
|
||||
AsyncGroq = None
|
||||
GROQ_AVAILABLE = False
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings, get_model_configs
|
||||
from app.services.model_service import get_model_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
# Groq Compound tool pricing (per request/execution)
|
||||
# Source: https://groq.com/pricing (Dec 2, 2025)
|
||||
COMPOUND_TOOL_PRICES = {
|
||||
# Web Search variants
|
||||
"search": 0.008, # API returns "search" for web search
|
||||
"web_search": 0.008, # $8 per 1K = $0.008 per request (Advanced Search)
|
||||
"advanced_search": 0.008, # $8 per 1K requests
|
||||
"basic_search": 0.005, # $5 per 1K requests
|
||||
# Other tools
|
||||
"visit_website": 0.001, # $1 per 1K requests
|
||||
"python": 0.00005, # API returns "python" for code execution
|
||||
"code_interpreter": 0.00005, # Alternative API identifier
|
||||
"code_execution": 0.00005, # Alias for backwards compatibility
|
||||
"browser_automation": 0.00002, # $0.08/hr ≈ $0.00002 per execution
|
||||
}
|
||||
|
||||
# Model pricing per million tokens (input/output)
|
||||
# Source: https://groq.com/pricing (Dec 2, 2025)
|
||||
GROQ_MODEL_PRICES = {
|
||||
"llama-3.3-70b-versatile": {"input": 0.59, "output": 0.79},
|
||||
"llama-3.1-8b-instant": {"input": 0.05, "output": 0.08},
|
||||
"llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct": {"input": 0.20, "output": 0.60},
|
||||
"llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct": {"input": 0.11, "output": 0.34},
|
||||
"llama-guard-4-12b": {"input": 0.20, "output": 0.20},
|
||||
"meta-llama/llama-guard-4-12b": {"input": 0.20, "output": 0.20},
|
||||
"gpt-oss-120b": {"input": 0.15, "output": 0.60},
|
||||
"openai/gpt-oss-120b": {"input": 0.15, "output": 0.60},
|
||||
"gpt-oss-20b": {"input": 0.075, "output": 0.30},
|
||||
"openai/gpt-oss-20b": {"input": 0.075, "output": 0.30},
|
||||
"kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
|
||||
"moonshotai/kimi-k2-instruct-0905": {"input": 1.00, "output": 3.00},
|
||||
"qwen3-32b": {"input": 0.29, "output": 0.59},
|
||||
# Compound models - 50/50 blended pricing from underlying models
|
||||
# compound: GPT-OSS-120B ($0.15/$0.60) + Llama 4 Scout ($0.11/$0.34) = $0.13/$0.47
|
||||
"compound": {"input": 0.13, "output": 0.47},
|
||||
"groq/compound": {"input": 0.13, "output": 0.47},
|
||||
"compound-beta": {"input": 0.13, "output": 0.47},
|
||||
# compound-mini: GPT-OSS-120B ($0.15/$0.60) + Llama 3.3 70B ($0.59/$0.79) = $0.37/$0.695
|
||||
"compound-mini": {"input": 0.37, "output": 0.695},
|
||||
"groq/compound-mini": {"input": 0.37, "output": 0.695},
|
||||
"compound-mini-beta": {"input": 0.37, "output": 0.695},
|
||||
}
|
||||
|
||||
|
||||
class GroqProxyBackend:
|
||||
"""LLM inference via Groq Cloud with HAProxy load balancing"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.client = None
|
||||
self.usage_metrics = {}
|
||||
self.circuit_breaker_status = {}
|
||||
self._initialize_client()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize Groq client to use HAProxy load balancer"""
|
||||
if not GROQ_AVAILABLE:
|
||||
logger.warning("Groq client not available - running in development mode")
|
||||
return
|
||||
|
||||
if self.settings.groq_api_key:
|
||||
# Use HAProxy load balancer instead of direct Groq API
|
||||
haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"
|
||||
|
||||
# Initialize client with HAProxy endpoint
|
||||
self.client = AsyncGroq(
|
||||
api_key=self.settings.groq_api_key,
|
||||
base_url=haproxy_endpoint,
|
||||
timeout=httpx.Timeout(30.0), # Increased timeout for load balancing
|
||||
max_retries=1 # Let HAProxy handle retries
|
||||
)
|
||||
|
||||
# Initialize circuit breaker
|
||||
self.circuit_breaker_status = {
|
||||
"state": "closed", # closed, open, half_open
|
||||
"failure_count": 0,
|
||||
"last_failure_time": None,
|
||||
"failure_threshold": 5,
|
||||
"recovery_timeout": 60 # seconds
|
||||
}
|
||||
|
||||
logger.info(f"Initialized Groq client with HAProxy endpoint: {haproxy_endpoint}")
|
||||
|
||||
async def execute_inference(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = "llama-3.1-70b-versatile",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4000,
|
||||
stream: bool = False,
|
||||
user_id: str = None,
|
||||
tenant_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute LLM inference with HAProxy load balancing and circuit breaker"""
|
||||
|
||||
# Check circuit breaker
|
||||
if not await self._is_circuit_closed():
|
||||
raise Exception("Circuit breaker is open - service temporarily unavailable")
|
||||
|
||||
# Validate model and get configuration
|
||||
model_configs = get_model_configs(tenant_id)
|
||||
model_config = model_configs.get("groq", {}).get(model)
|
||||
if not model_config:
|
||||
# Try to get from model service registry
|
||||
model_service = get_model_service(tenant_id)
|
||||
model_info = await model_service.get_model(model)
|
||||
if not model_info:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
model_config = {
|
||||
"max_tokens": model_info["performance"]["max_tokens"],
|
||||
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
|
||||
"supports_streaming": model_info["capabilities"].get("streaming", False)
|
||||
}
|
||||
|
||||
# Apply token limits
|
||||
max_tokens = min(max_tokens, model_config["max_tokens"])
|
||||
|
||||
# Prepare messages
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
# Get tenant-specific API key
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required for Groq inference")
|
||||
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
client = self._get_client(api_key)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if stream:
|
||||
return await self._stream_inference(
|
||||
messages, model, temperature, max_tokens, user_id, tenant_id, client
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False
|
||||
)
|
||||
|
||||
# Track successful usage
|
||||
latency = (time.time() - start_time) * 1000
|
||||
await self._track_usage(
|
||||
user_id, tenant_id, model,
|
||||
response.usage.total_tokens if response.usage else 0,
|
||||
latency, model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
|
||||
# Track in model service
|
||||
model_service = get_model_service(tenant_id)
|
||||
await model_service.track_model_usage(
|
||||
model_id=model,
|
||||
success=True,
|
||||
latency_ms=latency
|
||||
)
|
||||
|
||||
# Reset circuit breaker on success
|
||||
await self._record_success()
|
||||
|
||||
return {
|
||||
"content": response.choices[0].message.content,
|
||||
"model": model,
|
||||
"usage": {
|
||||
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
"total_tokens": response.usage.total_tokens if response.usage else 0,
|
||||
"cost_cents": self._calculate_cost(
|
||||
response.usage.total_tokens if response.usage else 0,
|
||||
model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
},
|
||||
"latency_ms": latency,
|
||||
"load_balanced": True,
|
||||
"haproxy_backend": "groq_general_backend"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HAProxy Groq inference failed: {e}")
|
||||
|
||||
# Track failure in model service
|
||||
await model_service.track_model_usage(
|
||||
model_id=model,
|
||||
success=False
|
||||
)
|
||||
|
||||
# Record failure for circuit breaker
|
||||
await self._record_failure()
|
||||
|
||||
# Re-raise the exception - no client-side fallback needed
|
||||
# HAProxy handles all failover logic
|
||||
raise Exception(f"Groq inference failed (via HAProxy): {str(e)}")
|
||||
|
||||
async def _stream_inference(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
client: AsyncGroq = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream LLM inference responses"""
|
||||
|
||||
model_configs = get_model_configs(tenant_id)
|
||||
model_config = model_configs.get("groq", {}).get(model)
|
||||
start_time = time.time()
|
||||
total_tokens = 0
|
||||
|
||||
try:
|
||||
# Use provided client or get tenant-specific client
|
||||
if not client:
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
client = self._get_client(api_key)
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
total_tokens += len(content.split()) # Approximate token count
|
||||
|
||||
# Yield SSE formatted data
|
||||
yield f"data: {json.dumps({'content': content})}\n\n"
|
||||
|
||||
# Track usage after streaming completes
|
||||
latency = (time.time() - start_time) * 1000
|
||||
await self._track_usage(
|
||||
user_id, tenant_id, model,
|
||||
total_tokens, latency,
|
||||
model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
|
||||
# Send completion signal
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming inference error: {e}")
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
async def check_health(self) -> Dict[str, Any]:
|
||||
"""Check health of HAProxy load balancer and circuit breaker status"""
|
||||
|
||||
try:
|
||||
# Check HAProxy health via stats endpoint
|
||||
haproxy_stats_url = self.settings.haproxy_stats_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local:8404/stats"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
haproxy_stats_url,
|
||||
timeout=5.0,
|
||||
auth=("admin", "gt2_haproxy_stats_password")
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Parse HAProxy stats (simplified)
|
||||
stats_healthy = "UP" in response.text
|
||||
|
||||
return {
|
||||
"haproxy_load_balancer": {
|
||||
"healthy": stats_healthy,
|
||||
"stats_accessible": True,
|
||||
"last_check": datetime.utcnow().isoformat()
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": self.circuit_breaker_status["state"],
|
||||
"failure_count": self.circuit_breaker_status["failure_count"],
|
||||
"last_failure": self.circuit_breaker_status["last_failure_time"].isoformat() if self.circuit_breaker_status["last_failure_time"] else None
|
||||
},
|
||||
"groq_endpoints": {
|
||||
"managed_by": "haproxy",
|
||||
"failover_handled_by": "haproxy"
|
||||
}
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"haproxy_load_balancer": {
|
||||
"healthy": False,
|
||||
"error": f"Stats endpoint returned {response.status_code}",
|
||||
"last_check": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"haproxy_load_balancer": {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"last_check": datetime.utcnow().isoformat()
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": self.circuit_breaker_status["state"],
|
||||
"failure_count": self.circuit_breaker_status["failure_count"]
|
||||
}
|
||||
}
|
||||
|
||||
async def _is_circuit_closed(self) -> bool:
|
||||
"""Check if circuit breaker allows requests"""
|
||||
|
||||
if self.circuit_breaker_status["state"] == "closed":
|
||||
return True
|
||||
|
||||
if self.circuit_breaker_status["state"] == "open":
|
||||
# Check if recovery timeout has passed
|
||||
if self.circuit_breaker_status["last_failure_time"]:
|
||||
time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
|
||||
if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
|
||||
# Move to half-open state
|
||||
self.circuit_breaker_status["state"] = "half_open"
|
||||
logger.info("Circuit breaker moved to half-open state")
|
||||
return True
|
||||
return False
|
||||
|
||||
if self.circuit_breaker_status["state"] == "half_open":
|
||||
# Allow limited requests in half-open state
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _record_success(self):
|
||||
"""Record successful request for circuit breaker"""
|
||||
|
||||
if self.circuit_breaker_status["state"] == "half_open":
|
||||
# Success in half-open state closes the circuit
|
||||
self.circuit_breaker_status["state"] = "closed"
|
||||
self.circuit_breaker_status["failure_count"] = 0
|
||||
logger.info("Circuit breaker closed after successful request")
|
||||
|
||||
# Reset failure count on any success
|
||||
self.circuit_breaker_status["failure_count"] = 0
|
||||
|
||||
async def _record_failure(self):
|
||||
"""Record failed request for circuit breaker"""
|
||||
|
||||
self.circuit_breaker_status["failure_count"] += 1
|
||||
self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()
|
||||
|
||||
if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
|
||||
if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
|
||||
self.circuit_breaker_status["state"] = "open"
|
||||
logger.warning(f"Circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
|
||||
|
||||
async def _track_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
model: str,
|
||||
tokens: int,
|
||||
latency: float,
|
||||
cost_per_1k: float
|
||||
):
|
||||
"""Track usage metrics for billing and monitoring"""
|
||||
|
||||
# Create usage key
|
||||
usage_key = f"{tenant_id}:{user_id}:{model}"
|
||||
|
||||
# Initialize metrics if not exists
|
||||
if usage_key not in self.usage_metrics:
|
||||
self.usage_metrics[usage_key] = {
|
||||
"total_tokens": 0,
|
||||
"total_requests": 0,
|
||||
"total_cost_cents": 0,
|
||||
"average_latency": 0
|
||||
}
|
||||
|
||||
# Update metrics
|
||||
metrics = self.usage_metrics[usage_key]
|
||||
metrics["total_tokens"] += tokens
|
||||
metrics["total_requests"] += 1
|
||||
metrics["total_cost_cents"] += self._calculate_cost(tokens, cost_per_1k)
|
||||
|
||||
# Update average latency
|
||||
prev_avg = metrics["average_latency"]
|
||||
prev_count = metrics["total_requests"] - 1
|
||||
metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]
|
||||
|
||||
# Log high-level metrics
|
||||
if metrics["total_requests"] % 100 == 0:
|
||||
logger.info(f"Usage milestone for {usage_key}: {metrics}")
|
||||
|
||||
def _calculate_cost(self, tokens: int, cost_per_1k: float) -> int:
|
||||
"""Calculate cost in cents"""
|
||||
return int((tokens / 1000) * cost_per_1k * 100)
|
||||
|
||||
def _calculate_compound_cost(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate detailed cost breakdown for Groq Compound responses.
|
||||
|
||||
Compound API returns usage_breakdown with per-model token counts
|
||||
and executed_tools list showing which tools were called.
|
||||
|
||||
Returns:
|
||||
Dict with total cost in dollars and detailed breakdown
|
||||
"""
|
||||
total_cost = 0.0
|
||||
breakdown = {"models": [], "tools": [], "total_cost_dollars": 0.0, "total_cost_cents": 0}
|
||||
|
||||
# Parse usage_breakdown for per-model token costs
|
||||
usage_breakdown = response_data.get("usage_breakdown", {})
|
||||
models_usage = usage_breakdown.get("models", [])
|
||||
|
||||
for model_usage in models_usage:
|
||||
model_name = model_usage.get("model", "")
|
||||
usage = model_usage.get("usage", {})
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
# Get model pricing (try multiple name formats)
|
||||
model_prices = GROQ_MODEL_PRICES.get(model_name)
|
||||
if not model_prices:
|
||||
# Try without provider prefix
|
||||
short_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
model_prices = GROQ_MODEL_PRICES.get(short_name, {"input": 0.15, "output": 0.60})
|
||||
|
||||
# Calculate cost per million tokens
|
||||
input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
|
||||
output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
|
||||
model_total = input_cost + output_cost
|
||||
|
||||
breakdown["models"].append({
|
||||
"model": model_name,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"input_cost_dollars": round(input_cost, 6),
|
||||
"output_cost_dollars": round(output_cost, 6),
|
||||
"total_cost_dollars": round(model_total, 6)
|
||||
})
|
||||
total_cost += model_total
|
||||
|
||||
# Parse executed_tools for tool costs
|
||||
executed_tools = response_data.get("executed_tools", [])
|
||||
|
||||
for tool in executed_tools:
|
||||
# Handle both string and dict formats
|
||||
tool_name = tool if isinstance(tool, str) else tool.get("name", "unknown")
|
||||
tool_cost = COMPOUND_TOOL_PRICES.get(tool_name.lower(), 0.008) # Default to advanced search
|
||||
|
||||
breakdown["tools"].append({
|
||||
"tool": tool_name,
|
||||
"cost_dollars": round(tool_cost, 6)
|
||||
})
|
||||
total_cost += tool_cost
|
||||
|
||||
breakdown["total_cost_dollars"] = round(total_cost, 6)
|
||||
breakdown["total_cost_cents"] = int(total_cost * 100)
|
||||
|
||||
return breakdown
|
||||
|
||||
def _is_compound_model(self, model: str) -> bool:
|
||||
"""Check if model is a Groq Compound model"""
|
||||
model_lower = model.lower()
|
||||
return "compound" in model_lower or model_lower.startswith("groq/compound")
|
||||
|
||||
async def get_available_models(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of available Groq models with their configurations"""
|
||||
models = []
|
||||
|
||||
model_configs = get_model_configs()
|
||||
for model_id, config in model_configs.get("groq", {}).items():
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"name": model_id.replace("-", " ").title(),
|
||||
"provider": "groq",
|
||||
"max_tokens": config["max_tokens"],
|
||||
"cost_per_1k_tokens": config["cost_per_1k_tokens"],
|
||||
"supports_streaming": config["supports_streaming"],
|
||||
"supports_function_calling": config["supports_function_calling"]
|
||||
})
|
||||
|
||||
return models
|
||||
|
||||
async def execute_inference_with_messages(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str = "llama-3.1-70b-versatile",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4000,
|
||||
stream: bool = False,
|
||||
user_id: str = None,
|
||||
tenant_id: str = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute LLM inference using messages format (conversation style)"""
|
||||
|
||||
# Check circuit breaker
|
||||
if not await self._is_circuit_closed():
|
||||
raise Exception("Circuit breaker is open - service temporarily unavailable")
|
||||
|
||||
# Validate model and get configuration
|
||||
model_configs = get_model_configs(tenant_id)
|
||||
model_config = model_configs.get("groq", {}).get(model)
|
||||
if not model_config:
|
||||
# Try to get from model service registry
|
||||
model_service = get_model_service(tenant_id)
|
||||
model_info = await model_service.get_model(model)
|
||||
if not model_info:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
model_config = {
|
||||
"max_tokens": model_info["performance"]["max_tokens"],
|
||||
"cost_per_1k_tokens": model_info["performance"]["cost_per_1k_tokens"],
|
||||
"supports_streaming": model_info["capabilities"].get("streaming", False)
|
||||
}
|
||||
|
||||
# Apply token limits
|
||||
max_tokens = min(max_tokens, model_config["max_tokens"])
|
||||
|
||||
try:
|
||||
# Get tenant-specific API key
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required for Groq inference")
|
||||
|
||||
api_key = await self._get_tenant_api_key(tenant_id)
|
||||
client = self._get_client(api_key)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Translate GT 2.0 "agent" role to OpenAI/Groq "assistant" for external API compatibility
|
||||
# Use dictionary unpacking to preserve ALL fields including tool_call_id
|
||||
external_messages = []
|
||||
for msg in messages:
|
||||
external_msg = {
|
||||
**msg, # Preserve ALL fields including tool_call_id, tool_calls, etc.
|
||||
"role": "assistant" if msg.get("role") == "agent" else msg.get("role")
|
||||
}
|
||||
external_messages.append(external_msg)
|
||||
|
||||
if stream:
|
||||
return await self._stream_inference_with_messages(
|
||||
external_messages, model, temperature, max_tokens, user_id, tenant_id, client
|
||||
)
|
||||
else:
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": external_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
if tool_choice:
|
||||
request_params["tool_choice"] = tool_choice
|
||||
|
||||
# Debug: Log messages being sent to Groq
|
||||
logger.info(f"🔧 Sending {len(external_messages)} messages to Groq API")
|
||||
for i, msg in enumerate(external_messages):
|
||||
if msg.get("role") == "tool":
|
||||
logger.info(f"🔧 Groq Message {i}: role=tool, tool_call_id={msg.get('tool_call_id')}")
|
||||
else:
|
||||
logger.info(f"🔧 Groq Message {i}: role={msg.get('role')}, has_tool_calls={bool(msg.get('tool_calls'))}")
|
||||
|
||||
response = await client.chat.completions.create(**request_params)
|
||||
|
||||
# Track successful usage
|
||||
latency = (time.time() - start_time) * 1000
|
||||
await self._track_usage(
|
||||
user_id, tenant_id, model,
|
||||
response.usage.total_tokens if response.usage else 0,
|
||||
latency, model_config["cost_per_1k_tokens"]
|
||||
)
|
||||
|
||||
# Track in model service
|
||||
model_service = get_model_service(tenant_id)
|
||||
await model_service.track_model_usage(
|
||||
model_id=model,
|
||||
success=True,
|
||||
latency_ms=latency
|
||||
)
|
||||
|
||||
# Reset circuit breaker on success
|
||||
await self._record_success()
|
||||
|
||||
# Build base response
|
||||
result = {
|
                "content": response.choices[0].message.content,
                "model": model,
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
                    "completion_tokens": response.usage.completion_tokens if response.usage else 0,
                    "total_tokens": response.usage.total_tokens if response.usage else 0,
                    "cost_cents": self._calculate_cost(
                        response.usage.total_tokens if response.usage else 0,
                        model_config["cost_per_1k_tokens"]
                    )
                },
                "latency_ms": latency,
                "load_balanced": True,
                "haproxy_backend": "groq_general_backend"
            }

            # For Compound models, extract and calculate detailed cost breakdown
            if self._is_compound_model(model):
                # Convert response to dict for processing
                response_dict = response.model_dump() if hasattr(response, 'model_dump') else {}

                # Extract usage_breakdown and executed_tools if present
                usage_breakdown = getattr(response, 'usage_breakdown', None)
                executed_tools = getattr(response, 'executed_tools', None)

                if usage_breakdown or executed_tools:
                    compound_data = {
                        "usage_breakdown": usage_breakdown if isinstance(usage_breakdown, dict) else {},
                        "executed_tools": executed_tools if isinstance(executed_tools, list) else []
                    }

                    # Calculate detailed cost breakdown
                    cost_breakdown = self._calculate_compound_cost(compound_data)

                    # Add compound-specific data to response
                    result["usage_breakdown"] = compound_data.get("usage_breakdown", {})
                    result["executed_tools"] = compound_data.get("executed_tools", [])
                    result["cost_breakdown"] = cost_breakdown

                    # Update cost_cents with accurate compound calculation
                    if cost_breakdown["total_cost_cents"] > 0:
                        result["usage"]["cost_cents"] = cost_breakdown["total_cost_cents"]

                    logger.info(f"Compound model cost breakdown: {cost_breakdown}")

            return result

        except Exception as e:
            logger.error(f"HAProxy Groq inference with messages failed: {e}")

            # Track failure in model service
            await model_service.track_model_usage(
                model_id=model,
                success=False
            )

            # Record failure for circuit breaker
            await self._record_failure()

            # Re-raise the exception
            raise Exception(f"Groq inference with messages failed (via HAProxy): {str(e)}")

    async def _stream_inference_with_messages(
        self,
        messages: List[Dict[str, str]],
        model: str,
        temperature: float,
        max_tokens: int,
        user_id: str,
        tenant_id: str,
        client: AsyncGroq = None
    ) -> AsyncGenerator[str, None]:
        """Stream LLM inference responses using messages format"""

        model_configs = get_model_configs(tenant_id)
        model_config = model_configs.get("groq", {}).get(model)
        start_time = time.time()
        total_tokens = 0

        try:
            # Use provided client or get tenant-specific client
            if not client:
                api_key = await self._get_tenant_api_key(tenant_id)
                client = self._get_client(api_key)

            stream = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True
            )

            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    total_tokens += len(content.split())  # Approximate token count

                    # Yield just the content (SSE formatting handled by caller)
                    yield content

            # Track usage after streaming completes
            latency = (time.time() - start_time) * 1000
            await self._track_usage(
                user_id, tenant_id, model,
                total_tokens, latency,
                model_config["cost_per_1k_tokens"] if model_config else 0.0
            )

        except Exception as e:
            logger.error(f"Streaming inference with messages error: {e}")
            raise e

    async def _get_tenant_api_key(self, tenant_id: str) -> str:
        """
        Get API key for tenant from Control Panel database.

        NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
        API keys are managed in Control Panel and fetched via internal API.

        Args:
            tenant_id: Tenant domain string from X-Tenant-ID header

        Returns:
            Decrypted Groq API key

        Raises:
            ValueError: If no API key configured (results in HTTP 503 to client)
        """
        from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError

        client = get_api_key_client()

        try:
            key_info = await client.get_api_key(tenant_domain=tenant_id, provider="groq")
            return key_info["api_key"]
        except APIKeyNotConfiguredError as e:
            logger.error(f"No Groq API key for tenant '{tenant_id}': {e}")
            raise ValueError(str(e))
        except RuntimeError as e:
            logger.error(f"Control Panel error: {e}")
            raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")

    def _get_client(self, api_key: str) -> AsyncGroq:
        """Get Groq client with specified API key"""
        if not GROQ_AVAILABLE:
            raise Exception("Groq client not available in development mode")

        haproxy_endpoint = self.settings.haproxy_groq_endpoint or "http://haproxy-groq-lb-service.gt-resource.svc.cluster.local"

        return AsyncGroq(
            api_key=api_key,
            base_url=haproxy_endpoint,
            timeout=httpx.Timeout(30.0),
            max_retries=1
        )
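
# A minimal caller-side sketch of the SSE framing that _stream_inference_with_messages
# leaves to its caller ("SSE formatting handled by caller"). FastAPI, the helper name,
# and the generation parameters below are illustrative assumptions, not part of this
# module's actual API surface.
from fastapi.responses import StreamingResponse


async def stream_chat_sse(backend, messages, model, user_id, tenant_id):
    """Wrap the backend's raw token stream in Server-Sent Events frames."""

    async def sse_events():
        async for token in backend._stream_inference_with_messages(
            messages=messages,
            model=model,
            temperature=0.7,
            max_tokens=1024,
            user_id=user_id,
            tenant_id=tenant_id,
        ):
            yield f"data: {token}\n\n"   # one SSE frame per streamed chunk
        yield "data: [DONE]\n\n"         # conventional end-of-stream marker

    return StreamingResponse(sse_events(), media_type="text/event-stream")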
407
apps/resource-cluster/app/core/backends/nvidia_proxy.py
Normal file
@@ -0,0 +1,407 @@
"""
NVIDIA NIM LLM Proxy Backend

Provides LLM inference through NVIDIA NIM with:
- OpenAI-compatible API format (build.nvidia.com)
- Token usage tracking and cost calculation
- Streaming response support
- Circuit breaker pattern for enhanced reliability
"""

import json
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime
import httpx
import logging

from app.core.config import get_settings

logger = logging.getLogger(__name__)

# NVIDIA NIM Model pricing per million tokens (input/output)
# Source: build.nvidia.com (Dec 2025 pricing estimates)
# Note: Actual pricing may vary - check build.nvidia.com for current rates
NVIDIA_MODEL_PRICES = {
    # Llama Nemotron family
    "nvidia/llama-3.1-nemotron-ultra-253b-v1": {"input": 2.0, "output": 6.0},
    "nvidia/llama-3.1-nemotron-super-49b-v1": {"input": 0.5, "output": 1.5},
    "nvidia/llama-3.1-nemotron-nano-8b-v1": {"input": 0.1, "output": 0.3},
    # Standard Llama models via NIM
    "meta/llama-3.1-8b-instruct": {"input": 0.1, "output": 0.3},
    "meta/llama-3.1-70b-instruct": {"input": 0.5, "output": 1.0},
    "meta/llama-3.1-405b-instruct": {"input": 2.0, "output": 6.0},
    # Mistral models
    "mistralai/mistral-7b-instruct-v0.3": {"input": 0.1, "output": 0.2},
    "mistralai/mixtral-8x7b-instruct-v0.1": {"input": 0.3, "output": 0.6},
    # Default fallback
    "default": {"input": 0.5, "output": 1.5},
}
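

# Worked example of the pricing arithmetic above (illustrative only: this helper is
# not used by the backend, and the token counts are made up).
def _example_cost_cents() -> int:
    """Cost of 1,000 prompt + 500 completion tokens on nemotron-super-49b."""
    prices = NVIDIA_MODEL_PRICES["nvidia/llama-3.1-nemotron-super-49b-v1"]
    input_cost = (1_000 / 1_000_000) * prices["input"]    # 0.0005 USD
    output_cost = (500 / 1_000_000) * prices["output"]    # 0.00075 USD
    # 0.00125 USD is 0.125 cents; int() truncation reports this request as 0 cents
    return int((input_cost + output_cost) * 100)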


class NvidiaProxyBackend:
    """LLM inference via NVIDIA NIM with OpenAI-compatible API"""

    def __init__(self):
        self.settings = get_settings()
        self.base_url = getattr(self.settings, 'nvidia_nim_endpoint', None) or "https://integrate.api.nvidia.com/v1"
        self.usage_metrics = {}
        self.circuit_breaker_status = {
            "state": "closed",  # closed, open, half_open
            "failure_count": 0,
            "last_failure_time": None,
            "failure_threshold": 5,
            "recovery_timeout": 60  # seconds
        }
        logger.info(f"Initialized NVIDIA NIM backend with endpoint: {self.base_url}")

    async def _get_tenant_api_key(self, tenant_id: str) -> str:
        """
        Get API key for tenant from Control Panel database.

        NO environment variable fallback - per GT 2.0 NO FALLBACKS principle.
        API keys are managed in Control Panel and fetched via internal API.

        Args:
            tenant_id: Tenant domain string from X-Tenant-ID header

        Returns:
            Decrypted NVIDIA API key

        Raises:
            ValueError: If no API key configured (results in HTTP 503 to client)
        """
        from app.clients.api_key_client import get_api_key_client, APIKeyNotConfiguredError

        client = get_api_key_client()

        try:
            key_info = await client.get_api_key(tenant_domain=tenant_id, provider="nvidia")
            return key_info["api_key"]
        except APIKeyNotConfiguredError as e:
            logger.error(f"No NVIDIA API key for tenant '{tenant_id}': {e}")
            raise ValueError(str(e))
        except RuntimeError as e:
            logger.error(f"Control Panel error: {e}")
            raise ValueError(f"Unable to retrieve API key - service unavailable: {e}")

    def _get_client(self, api_key: str) -> httpx.AsyncClient:
        """Get configured HTTP client for NVIDIA NIM API"""
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            },
            timeout=httpx.Timeout(120.0)  # Longer timeout for large models
        )

    async def execute_inference(
        self,
        prompt: str,
        model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
        temperature: float = 0.7,
        max_tokens: int = 4000,
        stream: bool = False,
        user_id: str = None,
        tenant_id: str = None
    ) -> Dict[str, Any]:
        """Execute LLM inference with simple prompt"""

        messages = [{"role": "user", "content": prompt}]

        return await self.execute_inference_with_messages(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
            user_id=user_id,
            tenant_id=tenant_id
        )

    async def execute_inference_with_messages(
        self,
        messages: List[Dict[str, str]],
        model: str = "nvidia/llama-3.1-nemotron-super-49b-v1",
        temperature: float = 0.7,
        max_tokens: int = 4000,
        stream: bool = False,
        user_id: str = None,
        tenant_id: str = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None
    ) -> Dict[str, Any]:
        """Execute LLM inference using messages format (conversation style)"""

        # Check circuit breaker
        if not await self._is_circuit_closed():
            raise Exception("Circuit breaker is open - NVIDIA NIM service temporarily unavailable")

        if not tenant_id:
            raise ValueError("tenant_id is required for NVIDIA NIM inference")

        try:
            api_key = await self._get_tenant_api_key(tenant_id)

            # Translate GT 2.0 "agent" role to OpenAI "assistant" for external API compatibility
            external_messages = []
            for msg in messages:
                external_msg = {
                    **msg,  # Preserve ALL fields including tool_call_id, tool_calls, etc.
                    "role": "assistant" if msg.get("role") == "agent" else msg.get("role")
                }
                external_messages.append(external_msg)

            # Build request payload
            request_data = {
                "model": model,
                "messages": external_messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream
            }

            # Add tools if provided
            if tools:
                request_data["tools"] = tools
                if tool_choice:
                    request_data["tool_choice"] = tool_choice

            start_time = time.time()

            if stream:
                # Return generator for streaming; it gets its own client so the
                # connection stays open for the lifetime of the stream
                return self._stream_inference_with_messages(
                    self._get_client(api_key), request_data, user_id, tenant_id, model
                )

            async with self._get_client(api_key) as client:
                # Non-streaming request
                response = await client.post("/chat/completions", json=request_data)
                response.raise_for_status()
                data = response.json()

                latency = (time.time() - start_time) * 1000

                # Calculate cost
                usage = data.get("usage", {})
                prompt_tokens = usage.get("prompt_tokens", 0)
                completion_tokens = usage.get("completion_tokens", 0)
                total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)

                model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
                input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
                output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
                cost_cents = int((input_cost + output_cost) * 100)

                # Track usage
                await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)

                # Reset circuit breaker on success
                await self._record_success()

                # Build response
                result = {
                    "content": data["choices"][0]["message"]["content"],
                    "model": model,
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": total_tokens,
                        "cost_cents": cost_cents
                    },
                    "latency_ms": latency,
                    "provider": "nvidia"
                }

                # Include tool calls if present
                message = data["choices"][0]["message"]
                if message.get("tool_calls"):
                    result["tool_calls"] = message["tool_calls"]

                return result

        except httpx.HTTPStatusError as e:
            logger.error(f"NVIDIA NIM API error: {e.response.status_code} - {e.response.text}")
            await self._record_failure()
            raise Exception(f"NVIDIA NIM inference failed: HTTP {e.response.status_code}")
        except Exception as e:
            logger.error(f"NVIDIA NIM inference failed: {e}")
            await self._record_failure()
            raise Exception(f"NVIDIA NIM inference failed: {str(e)}")
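
    # Descriptive note on the tool-calling path above: `tools` is forwarded in the
    # OpenAI-compatible schema accepted by NIM, e.g. (illustrative values only):
    #
    #   tools = [{
    #       "type": "function",
    #       "function": {
    #           "name": "get_weather",
    #           "description": "Look up current weather for a city",
    #           "parameters": {
    #               "type": "object",
    #               "properties": {"city": {"type": "string"}},
    #               "required": ["city"],
    #           },
    #       },
    #   }]
    #
    # When the model elects to call a tool, the raw `tool_calls` list from the API
    # response is passed through on the returned result dict.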

    async def _stream_inference_with_messages(
        self,
        client: httpx.AsyncClient,
        request_data: Dict[str, Any],
        user_id: str,
        tenant_id: str,
        model: str
    ) -> AsyncGenerator[str, None]:
        """Stream LLM inference responses"""

        start_time = time.time()
        total_tokens = 0

        try:
            # The client is dedicated to this stream; close it once streaming ends
            async with client:
                async with client.stream("POST", "/chat/completions", json=request_data) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]  # Remove "data: " prefix

                            if data_str == "[DONE]":
                                break

                            try:
                                chunk = json.loads(data_str)
                                if chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("content"):
                                    content = chunk["choices"][0]["delta"]["content"]
                                    total_tokens += len(content.split())  # Approximate
                                    yield content
                            except json.JSONDecodeError:
                                continue

            # Track usage after streaming completes
            latency = (time.time() - start_time) * 1000
            model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
            cost_cents = int((total_tokens / 1_000_000) * model_prices["output"] * 100)
            await self._track_usage(user_id, tenant_id, model, total_tokens, latency, cost_cents)

            await self._record_success()

        except Exception as e:
            logger.error(f"NVIDIA NIM streaming error: {e}")
            await self._record_failure()
            raise e

    async def check_health(self) -> Dict[str, Any]:
        """Check health of NVIDIA NIM backend and circuit breaker status"""

        return {
            "nvidia_nim": {
                "endpoint": self.base_url,
                "status": "available" if self.circuit_breaker_status["state"] == "closed" else "degraded",
                "last_check": datetime.utcnow().isoformat()
            },
            "circuit_breaker": {
                "state": self.circuit_breaker_status["state"],
                "failure_count": self.circuit_breaker_status["failure_count"],
                "last_failure": self.circuit_breaker_status["last_failure_time"].isoformat()
                if self.circuit_breaker_status["last_failure_time"] else None
            }
        }

    async def _is_circuit_closed(self) -> bool:
        """Check if circuit breaker allows requests"""

        if self.circuit_breaker_status["state"] == "closed":
            return True

        if self.circuit_breaker_status["state"] == "open":
            # Check if recovery timeout has passed
            if self.circuit_breaker_status["last_failure_time"]:
                time_since_failure = (datetime.utcnow() - self.circuit_breaker_status["last_failure_time"]).total_seconds()
                if time_since_failure > self.circuit_breaker_status["recovery_timeout"]:
                    # Move to half-open state
                    self.circuit_breaker_status["state"] = "half_open"
                    logger.info("NVIDIA NIM circuit breaker moved to half-open state")
                    return True
            return False

        if self.circuit_breaker_status["state"] == "half_open":
            # Allow limited requests in half-open state
            return True

        return False

    async def _record_success(self):
        """Record successful request for circuit breaker"""

        if self.circuit_breaker_status["state"] == "half_open":
            # Success in half-open state closes the circuit
            self.circuit_breaker_status["state"] = "closed"
            self.circuit_breaker_status["failure_count"] = 0
            logger.info("NVIDIA NIM circuit breaker closed after successful request")

        # Reset failure count on any success
        self.circuit_breaker_status["failure_count"] = 0

    async def _record_failure(self):
        """Record failed request for circuit breaker"""

        self.circuit_breaker_status["failure_count"] += 1
        self.circuit_breaker_status["last_failure_time"] = datetime.utcnow()

        if self.circuit_breaker_status["failure_count"] >= self.circuit_breaker_status["failure_threshold"]:
            if self.circuit_breaker_status["state"] in ["closed", "half_open"]:
                self.circuit_breaker_status["state"] = "open"
                logger.warning(f"NVIDIA NIM circuit breaker opened after {self.circuit_breaker_status['failure_count']} failures")
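
    # Descriptive walkthrough of the breaker state machine implemented above:
    #   closed    - requests flow; each failure increments failure_count.
    #   open      - entered once failure_count reaches failure_threshold (5);
    #               requests are rejected until recovery_timeout (60 s) has
    #               elapsed since last_failure_time.
    #   half_open - entered by _is_circuit_closed() after the timeout; a single
    #               success (_record_success) closes the circuit and resets
    #               failure_count, while a further failure reopens it.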

    async def _track_usage(
        self,
        user_id: str,
        tenant_id: str,
        model: str,
        tokens: int,
        latency: float,
        cost_cents: int
    ):
        """Track usage metrics for billing and monitoring"""

        # Create usage key
        usage_key = f"{tenant_id}:{user_id}:{model}"

        # Initialize metrics if not exists
        if usage_key not in self.usage_metrics:
            self.usage_metrics[usage_key] = {
                "total_tokens": 0,
                "total_requests": 0,
                "total_cost_cents": 0,
                "average_latency": 0
            }

        # Update metrics
        metrics = self.usage_metrics[usage_key]
        metrics["total_tokens"] += tokens
        metrics["total_requests"] += 1
        metrics["total_cost_cents"] += cost_cents

        # Update average latency
        prev_avg = metrics["average_latency"]
        prev_count = metrics["total_requests"] - 1
        metrics["average_latency"] = (prev_avg * prev_count + latency) / metrics["total_requests"]

        # Log high-level metrics periodically
        if metrics["total_requests"] % 100 == 0:
            logger.info(f"NVIDIA NIM usage milestone for {usage_key}: {metrics}")

    def _calculate_cost(self, prompt_tokens: int, completion_tokens: int, model: str) -> int:
        """Calculate cost in cents based on token usage"""
        model_prices = NVIDIA_MODEL_PRICES.get(model, NVIDIA_MODEL_PRICES["default"])
        input_cost = (prompt_tokens / 1_000_000) * model_prices["input"]
        output_cost = (completion_tokens / 1_000_000) * model_prices["output"]
        return int((input_cost + output_cost) * 100)

    async def get_available_models(self) -> List[Dict[str, Any]]:
        """Get list of available NVIDIA NIM models with their configurations"""
        models = []

        for model_id, prices in NVIDIA_MODEL_PRICES.items():
            if model_id == "default":
                continue

            models.append({
                "id": model_id,
                "name": model_id.split("/")[-1].replace("-", " ").title(),
                "provider": "nvidia",
                "max_tokens": 4096,  # Default for most NIM models
                # NOTE: the two rates below are the per-million-token estimates
                # from NVIDIA_MODEL_PRICES, despite the "per_1k" key names
                "cost_per_1k_input": prices["input"],
                "cost_per_1k_output": prices["output"],
                "supports_streaming": True,
                "supports_function_calling": True
            })

        return models
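
# Hedged end-to-end sketch of exercising this backend directly. The tenant domain
# and prompt are made up; running it requires GT 2.0 settings plus an NVIDIA API
# key configured for that tenant in the Control Panel. Illustrative only - nothing
# in GT 2.0 invokes this block.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = NvidiaProxyBackend()

        # Model catalog and current circuit-breaker state
        models = await backend.get_available_models()
        health = await backend.check_health()
        print(f"{len(models)} models, breaker state: {health['circuit_breaker']['state']}")

        # Single-turn completion via the simple-prompt wrapper
        result = await backend.execute_inference(
            prompt="In one sentence, what is NVIDIA NIM?",
            model="nvidia/llama-3.1-nemotron-super-49b-v1",
            user_id="demo-user",
            tenant_id="example-tenant.edu",  # hypothetical tenant domain
        )
        print(result["content"], result["usage"]["cost_cents"])

    asyncio.run(_demo())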