Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching; see the sketch below)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
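One of the fixes above, replacing substring matching with proper URL hostname validation, might look roughly like this sketch (illustrative only; the helper name and allow-list are assumed, not taken from this repository):

```python
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.example.com"}  # assumed allow-list for outbound requests


def is_allowed_url(url: str) -> bool:
    # Compare the parsed hostname exactly instead of checking whether an
    # allowed string merely appears somewhere in the URL; a substring check
    # would accept e.g. "https://api.example.com.evil.net/".
    hostname = urlparse(url).hostname
    return hostname is not None and hostname.lower() in ALLOWED_HOSTS
```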
447 lines
16 KiB
Python
"""
|
|
Embedding Service for GT 2.0 Resource Cluster
|
|
|
|
Provides embedding generation with:
|
|
- BGE-M3 model integration
|
|
- Batch processing capabilities
|
|
- Rate limiting and quota management
|
|
- Capability-based authentication
|
|
- Stateless operation (no data storage)
|
|
|
|
GT 2.0 Architecture Principles:
|
|
- Perfect Tenant Isolation: Per-request capability validation
|
|
- Zero Downtime: Stateless design, circuit breakers
|
|
- Self-Contained Security: Capability-based auth
|
|
- No Complexity Addition: Simple interface, no database
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import os
|
|
from typing import Dict, Any, List, Optional
|
|
from datetime import datetime, timedelta
|
|
from dataclasses import dataclass, asdict
|
|
import uuid
|
|
|
|
from app.core.backends.embedding_backend import EmbeddingBackend, EmbeddingRequest
|
|
from app.core.capability_auth import verify_capability_token, CapabilityError
|
|
from app.core.config import get_settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
settings = get_settings()
|
|
|
|
|
|
@dataclass
|
|
class EmbeddingResponse:
|
|
"""Response structure for embedding generation"""
|
|
request_id: str
|
|
embeddings: List[List[float]]
|
|
model: str
|
|
dimensions: int
|
|
tokens_used: int
|
|
processing_time_ms: int
|
|
tenant_id: str
|
|
created_at: str
|
|
|
|
|
|
@dataclass
|
|
class EmbeddingStats:
|
|
"""Statistics for embedding requests"""
|
|
total_requests: int = 0
|
|
total_tokens_processed: int = 0
|
|
total_processing_time_ms: int = 0
|
|
average_processing_time_ms: float = 0.0
|
|
last_request_at: Optional[str] = None
|
|
|
|
|
|
class EmbeddingService:
|
|
"""
|
|
STATELESS embedding service for GT 2.0 Resource Cluster.
|
|
|
|
Key features:
|
|
- BGE-M3 model for high-quality embeddings
|
|
- Batch processing for efficiency
|
|
- Rate limiting per capability token
|
|
- Memory-conscious processing
|
|
- No persistent storage
|
|
"""
|
|
|
|
def __init__(self):
|
|
self.backend = EmbeddingBackend()
|
|
self.stats = EmbeddingStats()
|
|
|
|
# Initialize BGE-M3 tokenizer for accurate token counting
|
|
try:
|
|
from transformers import AutoTokenizer
|
|
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
|
|
logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
|
|
self.tokenizer = None
|
|
|
|
# Rate limiting settings (per capability token)
|
|
self.rate_limits = {
|
|
"requests_per_minute": 60,
|
|
"tokens_per_minute": 50000,
|
|
"max_batch_size": 32
|
|
}
|
|
|
|
# Track requests for rate limiting (in-memory, temporary)
|
|
self._request_tracker = {}
|
|
|
|
logger.info("STATELESS embedding service initialized")
|
|
|
|
async def generate_embeddings(
|
|
self,
|
|
texts: List[str],
|
|
capability_token: str,
|
|
instruction: Optional[str] = None,
|
|
request_id: Optional[str] = None,
|
|
normalize: bool = True
|
|
) -> EmbeddingResponse:
|
|
"""
|
|
Generate embeddings with capability-based authentication.
|
|
|
|
Args:
|
|
texts: List of texts to embed
|
|
capability_token: JWT token with embedding permissions
|
|
instruction: Optional instruction for embedding context
|
|
request_id: Optional request ID for tracking
|
|
normalize: Whether to normalize embeddings
|
|
|
|
Returns:
|
|
EmbeddingResponse with generated embeddings
|
|
|
|
Raises:
|
|
CapabilityError: If token invalid or insufficient permissions
|
|
ValueError: If request parameters invalid
|
|
"""
|
|
start_time = time.time()
|
|
request_id = request_id or str(uuid.uuid4())
|
|
|
|
try:
|
|
# Verify capability token and extract permissions
|
|
capability = await verify_capability_token(capability_token)
|
|
tenant_id = capability.get("tenant_id")
|
|
user_id = capability.get("sub") # Extract user ID from token
|
|
|
|
# Check embedding permissions
|
|
await self._verify_embedding_permissions(capability, len(texts))
|
|
|
|
# Apply rate limiting
|
|
await self._check_rate_limits(capability_token, len(texts))
|
|
|
|
# Validate input
|
|
self._validate_embedding_request(texts)
|
|
|
|
# Generate embeddings via backend
|
|
embeddings = await self.backend.generate_embeddings(
|
|
texts=texts,
|
|
instruction=instruction,
|
|
tenant_id=tenant_id,
|
|
request_id=request_id
|
|
)
|
|
|
|
# Calculate processing metrics
|
|
processing_time_ms = int((time.time() - start_time) * 1000)
|
|
total_tokens = self._estimate_tokens(texts)
|
|
|
|
# Update statistics
|
|
self._update_stats(total_tokens, processing_time_ms)
|
|
|
|
# Log embedding usage for billing (non-blocking)
|
|
# Fire and forget - don't wait for completion
|
|
asyncio.create_task(
|
|
self._log_embedding_usage(
|
|
tenant_id=tenant_id,
|
|
user_id=user_id,
|
|
tokens_used=total_tokens,
|
|
embedding_count=len(embeddings),
|
|
model=self.backend.model_name,
|
|
request_id=request_id
|
|
)
|
|
)
|
|
|
|
# Create response
|
|
response = EmbeddingResponse(
|
|
request_id=request_id,
|
|
embeddings=embeddings,
|
|
model=self.backend.model_name,
|
|
dimensions=self.backend.embedding_dimensions,
|
|
tokens_used=total_tokens,
|
|
processing_time_ms=processing_time_ms,
|
|
tenant_id=tenant_id,
|
|
created_at=datetime.utcnow().isoformat()
|
|
)
|
|
|
|
logger.info(
|
|
f"Generated {len(embeddings)} embeddings for tenant {tenant_id} "
|
|
f"in {processing_time_ms}ms"
|
|
)
|
|
|
|
return response
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error generating embeddings: {e}")
|
|
raise
|
|
finally:
|
|
# Always ensure cleanup
|
|
if 'texts' in locals():
|
|
del texts
|
|
|
|
async def get_model_info(self) -> Dict[str, Any]:
|
|
"""Get information about the embedding model"""
|
|
return {
|
|
"model_name": self.backend.model_name,
|
|
"dimensions": self.backend.embedding_dimensions,
|
|
"max_sequence_length": self.backend.max_sequence_length,
|
|
"max_batch_size": self.backend.max_batch_size,
|
|
"supports_instruction": True,
|
|
"normalization_default": True
|
|
}
|
|
|
|
async def get_service_stats(
|
|
self,
|
|
capability_token: str
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Get service statistics (for admin users only).
|
|
|
|
Args:
|
|
capability_token: JWT token with admin permissions
|
|
|
|
Returns:
|
|
Service statistics
|
|
"""
|
|
# Verify admin permissions
|
|
capability = await verify_capability_token(capability_token)
|
|
if not self._has_admin_permissions(capability):
|
|
raise CapabilityError("Admin permissions required")
|
|
|
|
return {
|
|
"model_info": await self.get_model_info(),
|
|
"statistics": asdict(self.stats),
|
|
"rate_limits": self.rate_limits,
|
|
"active_requests": len(self._request_tracker)
|
|
}
|
|
|
|
async def health_check(self) -> Dict[str, Any]:
|
|
"""Check service health"""
|
|
return {
|
|
"status": "healthy",
|
|
"service": "embedding_service",
|
|
"model": self.backend.model_name,
|
|
"backend_ready": True,
|
|
"last_request": self.stats.last_request_at
|
|
}
|
|
|
|
async def _verify_embedding_permissions(
|
|
self,
|
|
capability: Dict[str, Any],
|
|
text_count: int
|
|
) -> None:
|
|
"""Verify that capability token has embedding permissions"""
|
|
|
|
# Check for embedding capability
|
|
capabilities = capability.get("capabilities", [])
|
|
embedding_caps = [
|
|
cap for cap in capabilities
|
|
if cap.get("resource") == "embeddings"
|
|
]
|
|
|
|
if not embedding_caps:
|
|
raise CapabilityError("No embedding permissions in capability token")
|
|
|
|
# Check constraints
|
|
embedding_cap = embedding_caps[0] # Use first embedding capability
|
|
constraints = embedding_cap.get("constraints", {})
|
|
|
|
# Check batch size limit
|
|
max_batch = constraints.get("max_batch_size", self.rate_limits["max_batch_size"])
|
|
if text_count > max_batch:
|
|
raise CapabilityError(f"Batch size {text_count} exceeds limit {max_batch}")
|
|
|
|
# Check rate limits
|
|
rate_limit = constraints.get("rate_limit_per_minute", self.rate_limits["requests_per_minute"])
|
|
token_limit = constraints.get("tokens_per_minute", self.rate_limits["tokens_per_minute"])
|
|
|
|
logger.debug(f"Embedding permissions verified: batch={text_count}, limits=({rate_limit}, {token_limit})")
|
|
|
|
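    # Illustrative example (an assumption; the authoritative schema lives in
    # app.core.capability_auth) of the token payload fields this service reads:
    #
    #     {
    #         "tenant_id": "tenant-123",
    #         "sub": "user-456",
    #         "capabilities": [
    #             {
    #                 "resource": "embeddings",
    #                 "constraints": {
    #                     "max_batch_size": 32,
    #                     "rate_limit_per_minute": 60,
    #                     "tokens_per_minute": 50000
    #                 }
    #             },
    #             {"resource": "admin", "actions": ["stats"]}
    #         ]
    #     }
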
    async def _check_rate_limits(
        self,
        capability_token: str,
        text_count: int
    ) -> None:
        """Check rate limits for capability token"""

        now = time.time()
        token_hash = hash(capability_token) % 10000  # Simple tracking key

        # Clean old entries (older than 1 minute)
        cleanup_time = now - 60
        self._request_tracker = {
            k: v for k, v in self._request_tracker.items()
            if v.get("last_request", 0) > cleanup_time
        }

        # Get or create tracker for this token
        if token_hash not in self._request_tracker:
            self._request_tracker[token_hash] = {
                "requests": 0,
                "tokens": 0,
                "last_request": now
            }

        tracker = self._request_tracker[token_hash]

        # Check request rate limit
        if tracker["requests"] >= self.rate_limits["requests_per_minute"]:
            raise CapabilityError("Rate limit exceeded: too many requests per minute")

        # Estimate tokens and check token limit
        estimated_tokens = self._estimate_tokens([f"text_{i}" for i in range(text_count)])
        if tracker["tokens"] + estimated_tokens > self.rate_limits["tokens_per_minute"]:
            raise CapabilityError("Rate limit exceeded: too many tokens per minute")

        # Update tracker
        tracker["requests"] += 1
        tracker["tokens"] += estimated_tokens
        tracker["last_request"] = now

    def _validate_embedding_request(self, texts: List[str]) -> None:
        """Validate embedding request parameters"""

        if not texts:
            raise ValueError("No texts provided for embedding")

        if not isinstance(texts, list):
            raise ValueError("Texts must be a list")

        if len(texts) > self.backend.max_batch_size:
            raise ValueError(f"Batch size {len(texts)} exceeds maximum {self.backend.max_batch_size}")

        # Check individual text lengths
        for i, text in enumerate(texts):
            if not isinstance(text, str):
                raise ValueError(f"Text at index {i} must be a string")

            if len(text.strip()) == 0:
                raise ValueError(f"Text at index {i} is empty")

            # Simple token estimation for length check
            estimated_tokens = len(text.split()) * 1.3  # Rough estimation
            if estimated_tokens > self.backend.max_sequence_length:
                raise ValueError(f"Text at index {i} exceeds maximum length")

    def _estimate_tokens(self, texts: List[str]) -> int:
        """
        Count tokens using actual BGE-M3 tokenizer.
        Falls back to word-count estimation if tokenizer unavailable.
        """
        if self.tokenizer is not None:
            try:
                total_tokens = 0
                for text in texts:
                    tokens = self.tokenizer.encode(text, add_special_tokens=False)
                    total_tokens += len(tokens)
                return total_tokens
            except Exception as e:
                logger.warning(f"Tokenizer error, falling back to estimation: {e}")

        # Fallback: word count * 1.3 (rough estimation)
        total_words = sum(len(text.split()) for text in texts)
        return int(total_words * 1.3)

    def _has_admin_permissions(self, capability: Dict[str, Any]) -> bool:
        """Check if capability has admin permissions"""
        capabilities = capability.get("capabilities", [])
        return any(
            cap.get("resource") == "admin" and "stats" in cap.get("actions", [])
            for cap in capabilities
        )

    def _update_stats(self, tokens_processed: int, processing_time_ms: int) -> None:
        """Update service statistics"""
        self.stats.total_requests += 1
        self.stats.total_tokens_processed += tokens_processed
        self.stats.total_processing_time_ms += processing_time_ms
        self.stats.average_processing_time_ms = (
            self.stats.total_processing_time_ms / self.stats.total_requests
        )
        self.stats.last_request_at = datetime.utcnow().isoformat()

    async def _log_embedding_usage(
        self,
        tenant_id: str,
        user_id: str,
        tokens_used: int,
        embedding_count: int,
        model: str = "BAAI/bge-m3",
        request_id: Optional[str] = None
    ) -> None:
        """
        Log embedding usage to control panel database for billing.

        This method logs usage asynchronously and does not block the embedding response.
        Failures are logged as warnings but do not raise exceptions.

        Args:
            tenant_id: Tenant identifier
            user_id: User identifier (from capability token 'sub')
            tokens_used: Number of tokens processed
            embedding_count: Number of embeddings generated
            model: Embedding model name
            request_id: Optional request ID for tracking
        """
        try:
            import asyncpg

            # Calculate cost: BGE-M3 pricing ~$0.10 per million tokens
            cost_cents = (tokens_used / 1_000_000) * 0.10 * 100
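            # Sanity check of the formula above (illustrative): 50,000 tokens
            # -> (50_000 / 1_000_000) * $0.10 * 100 = 0.5 cents.
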
            # Connect to control panel database
            # Using environment variables from docker-compose
            db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
            if not db_password:
                logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
                return

            conn = await asyncpg.connect(
                host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
                database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
                user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
                password=db_password,
                timeout=5.0
            )

            try:
                # Insert usage log
                await conn.execute("""
                    INSERT INTO public.embedding_usage_logs
                    (tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
                    VALUES ($1, $2, $3, $4, $5, $6, $7)
                """, tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)

                logger.info(
                    f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
                    f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
                )
            finally:
                await conn.close()

        except Exception as e:
            # Log warning but don't fail the embedding request
            logger.warning(f"Failed to log embedding usage for tenant {tenant_id}: {e}")


# Global service instance
_embedding_service = None


def get_embedding_service() -> EmbeddingService:
    """Get the global embedding service instance"""
    global _embedding_service
    if _embedding_service is None:
        _embedding_service = EmbeddingService()
    return _embedding_service
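A minimal usage sketch for the service above (illustrative only: the import path is assumed, and the token placeholder must be replaced with a real JWT carrying an "embeddings" capability):

```python
import asyncio

# Assumed import path; adjust to wherever this module lives in the app package.
from app.services.embedding_service import get_embedding_service


async def main() -> None:
    service = get_embedding_service()
    response = await service.generate_embeddings(
        texts=["GT 2.0 resource cluster", "stateless embedding example"],
        capability_token="<JWT with an 'embeddings' capability>",  # placeholder
    )
    print(response.model, response.dimensions, len(response.embeddings))


asyncio.run(main())
```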