# gt-ai-os-community/apps/resource-cluster/app/api/embeddings.py
"""
Embedding Generation API Endpoints for GT 2.0 Resource Cluster
Provides OpenAI-compatible embedding API with:
- BGE-M3 model integration
- Capability-based authentication
- Rate limiting and quota management
- Batch processing support
- Stateless operation
GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Per-request capability validation
- Zero Downtime: Stateless design, no persistent state
- Self-Contained Security: JWT capability tokens
"""
from fastapi import APIRouter, HTTPException, Depends, Header
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import logging

from app.core.security import CapabilityToken
from app.api.auth import verify_capability
from app.services.embedding_service import get_embedding_service, EmbeddingService
from app.core.capability_auth import CapabilityError

router = APIRouter()
logger = logging.getLogger(__name__)


# OpenAI-compatible request/response models
class EmbeddingRequest(BaseModel):
    """OpenAI-compatible embedding request"""

    input: List[str] = Field(..., description="List of texts to embed")
    model: str = Field(default="BAAI/bge-m3", description="Embedding model name")
    encoding_format: str = Field(default="float", description="Encoding format (float)")
    dimensions: Optional[int] = Field(None, description="Number of dimensions (auto-detected)")
    user: Optional[str] = Field(None, description="User identifier")

    # BGE-M3 specific parameters
    instruction: Optional[str] = Field(None, description="Instruction for query/document context")
    normalize: bool = Field(True, description="Normalize embeddings to unit length")
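

# Example request body accepted by EmbeddingRequest (values are illustrative
# only; the exact mount path of this router depends on how the app includes it):
#
#   {
#       "input": ["What is BGE-M3?", "BGE-M3 is a multilingual embedding model."],
#       "model": "BAAI/bge-m3",
#       "instruction": "Represent this sentence for retrieval:",
#       "normalize": true
#   }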


class EmbeddingData(BaseModel):
    """Single embedding data object"""

    object: str = "embedding"
    embedding: List[float] = Field(..., description="Embedding vector")
    index: int = Field(..., description="Index of the embedding in the input")


class EmbeddingUsage(BaseModel):
    """Token usage information"""

    prompt_tokens: int = Field(..., description="Tokens in the input")
    total_tokens: int = Field(..., description="Total tokens processed")


class EmbeddingResponse(BaseModel):
    """OpenAI-compatible embedding response"""

    object: str = "list"
    data: List[EmbeddingData] = Field(..., description="List of embedding objects")
    model: str = Field(..., description="Model used for embeddings")
    usage: EmbeddingUsage = Field(..., description="Token usage information")

    # GT 2.0 specific metadata
    gt2_metadata: Dict[str, Any] = Field(default_factory=dict, description="GT 2.0 processing metadata")
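

# Example response shape (vector truncated and all values illustrative; the
# gt2_metadata fields mirror the result object returned by the embedding service):
#
#   {
#       "object": "list",
#       "data": [{"object": "embedding", "embedding": [0.013, -0.021, ...], "index": 0}],
#       "model": "BAAI/bge-m3",
#       "usage": {"prompt_tokens": 12, "total_tokens": 12},
#       "gt2_metadata": {"request_id": "...", "tenant_id": "...",
#                        "processing_time_ms": 42, "dimensions": 1024,
#                        "created_at": "..."}
#   }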


class EmbeddingModelInfo(BaseModel):
    """Embedding model information"""

    model_name: str
    dimensions: int
    max_sequence_length: int
    max_batch_size: int
    supports_instruction: bool
    normalization_default: bool


class ServiceHealthResponse(BaseModel):
    """Service health response"""

    status: str
    service: str
    model: str
    backend_ready: bool
    last_request: Optional[str]


class BGE_M3_ConfigRequest(BaseModel):
    """BGE-M3 configuration update request"""

    is_local_mode: bool = True
    external_endpoint: Optional[str] = None


class BGE_M3_ConfigResponse(BaseModel):
    """BGE-M3 configuration response"""

    is_local_mode: bool
    current_endpoint: str
    external_endpoint: Optional[str]
    message: str
@router.post("/", response_model=EmbeddingResponse)
async def create_embeddings(
request: EmbeddingRequest,
token: CapabilityToken = Depends(verify_capability),
x_request_id: Optional[str] = Header(None)
) -> EmbeddingResponse:
"""
Generate embeddings for input texts using BGE-M3 model.
Compatible with OpenAI Embeddings API format.
Requires capability token with 'embeddings' permissions.
"""
try:
# Get embedding service
embedding_service = get_embedding_service()
# Generate embeddings
result = await embedding_service.generate_embeddings(
texts=request.input,
capability_token=token.token, # Pass raw token for verification
instruction=request.instruction,
request_id=x_request_id,
normalize=request.normalize
)
# Convert to OpenAI-compatible format
embedding_data = [
EmbeddingData(
embedding=embedding,
index=i
)
for i, embedding in enumerate(result.embeddings)
]
usage = EmbeddingUsage(
prompt_tokens=result.tokens_used,
total_tokens=result.tokens_used
)
response = EmbeddingResponse(
data=embedding_data,
model=result.model,
usage=usage,
gt2_metadata={
"request_id": result.request_id,
"tenant_id": result.tenant_id,
"processing_time_ms": result.processing_time_ms,
"dimensions": result.dimensions,
"created_at": result.created_at
}
)
logger.info(
f"Generated {len(result.embeddings)} embeddings "
f"for tenant {result.tenant_id} in {result.processing_time_ms}ms"
)
return response
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
logger.warning(f"Invalid request: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
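

# Minimal client sketch for the endpoint above (the host, the /embeddings
# mount prefix, and the bearer-token header scheme are deployment-specific
# assumptions, not guarantees made by this module):
#
#   import httpx
#
#   resp = httpx.post(
#       "http://resource-cluster:8000/embeddings/",
#       headers={"Authorization": f"Bearer {capability_token}"},
#       json={"input": ["hello world"], "model": "BAAI/bge-m3"},
#   )
#   vector = resp.json()["data"][0]["embedding"]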
@router.get("/models", response_model=EmbeddingModelInfo)
async def get_model_info(
token: CapabilityToken = Depends(verify_capability)
) -> EmbeddingModelInfo:
"""
Get information about the embedding model.
Requires capability token with 'embeddings' permissions.
"""
try:
embedding_service = get_embedding_service()
model_info = await embedding_service.get_model_info()
return EmbeddingModelInfo(**model_info)
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error getting model info: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/stats")
async def get_service_stats(
token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
"""
Get embedding service statistics.
Requires capability token with 'admin' permissions.
"""
try:
embedding_service = get_embedding_service()
stats = await embedding_service.get_service_stats(token.token)
return stats
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error getting service stats: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/health", response_model=ServiceHealthResponse)
async def health_check() -> ServiceHealthResponse:
"""
Check embedding service health.
Public endpoint - no authentication required.
"""
try:
embedding_service = get_embedding_service()
health = await embedding_service.health_check()
return ServiceHealthResponse(**health)
except Exception as e:
logger.error(f"Health check failed: {e}")
raise HTTPException(status_code=500, detail="Service unhealthy")
@router.post("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
async def update_bge_m3_config(
config_request: BGE_M3_ConfigRequest,
token: CapabilityToken = Depends(verify_capability)
) -> BGE_M3_ConfigResponse:
"""
Update BGE-M3 configuration for the embedding service.
This allows switching between local and external endpoints at runtime.
Requires capability token with 'admin' permissions.
"""
try:
# Verify admin permissions
if not token.payload.get("admin", False):
raise HTTPException(status_code=403, detail="Admin permissions required")
embedding_service = get_embedding_service()
# Update the embedding backend configuration
backend = embedding_service.backend
await backend.update_endpoint_config(
is_local_mode=config_request.is_local_mode,
external_endpoint=config_request.external_endpoint
)
logger.info(
f"BGE-M3 configuration updated by {token.payload.get('tenant_id', 'unknown')}: "
f"local_mode={config_request.is_local_mode}, "
f"external_endpoint={config_request.external_endpoint}"
)
return BGE_M3_ConfigResponse(
is_local_mode=config_request.is_local_mode,
current_endpoint=backend.embedding_endpoint,
external_endpoint=config_request.external_endpoint,
message=f"BGE-M3 configuration updated to {'local' if config_request.is_local_mode else 'external'} mode"
)
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error updating BGE-M3 config: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
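

# Example admin request to switch the service to an external BGE-M3 endpoint
# (the endpoint URL is illustrative; the capability token must carry the
# 'admin' claim checked above):
#
#   POST /config/bge-m3
#   {
#       "is_local_mode": false,
#       "external_endpoint": "https://embeddings.example.com/v1"
#   }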
@router.get("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
async def get_bge_m3_config(
token: CapabilityToken = Depends(verify_capability)
) -> BGE_M3_ConfigResponse:
"""
Get current BGE-M3 configuration.
Requires capability token with 'embeddings' permissions.
"""
try:
embedding_service = get_embedding_service()
backend = embedding_service.backend
# Determine if currently in local mode
is_local_mode = "gentwo-vllm-embeddings" in backend.embedding_endpoint
return BGE_M3_ConfigResponse(
is_local_mode=is_local_mode,
current_endpoint=backend.embedding_endpoint,
external_endpoint=None, # We don't store this currently
message="Current BGE-M3 configuration"
)
except CapabilityError as e:
logger.warning(f"Capability error: {e}")
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error getting BGE-M3 config: {e}")
raise HTTPException(status_code=500, detail="Internal server error")


# Legacy endpoint compatibility
@router.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings_legacy(
    request: EmbeddingRequest,
    token: CapabilityToken = Depends(verify_capability),
    x_request_id: Optional[str] = Header(None),
) -> EmbeddingResponse:
    """
    Legacy endpoint for embedding generation.
    Delegates to the main embedding endpoint for compatibility.
    """
    return await create_embeddings(request, token, x_request_id)