GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
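
The URL-validation fixes themselves are not part of the file shown below; as a rough sketch of the pattern the SSRF and hostname bullets describe (the allowlist and function name here are hypothetical, not the repository's actual utilities):

```python
import ipaddress
import socket
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.example.com"}  # hypothetical allowlist

def is_safe_url(url: str) -> bool:
    """Reject URLs that fail exact-hostname or DNS-resolution checks."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        return False
    # Exact hostname comparison: unlike substring matching, this rejects
    # lookalikes such as "api.example.com.attacker.net".
    if parsed.hostname not in ALLOWED_HOSTS:
        return False
    try:
        # DNS resolution check: refuse hosts resolving to private/reserved space.
        for info in socket.getaddrinfo(parsed.hostname, None):
            ip = ipaddress.ip_address(info[4][0])
            if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
                return False
    except (socket.gaierror, ValueError):
        return False
    return True
```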
apps/resource-cluster/app/services/embedding_service.py (new file, 447 lines)
@@ -0,0 +1,447 @@
"""
Embedding Service for GT 2.0 Resource Cluster

Provides embedding generation with:
- BGE-M3 model integration
- Batch processing capabilities
- Rate limiting and quota management
- Capability-based authentication
- Stateless operation (no data storage)

GT 2.0 Architecture Principles:
- Perfect Tenant Isolation: Per-request capability validation
- Zero Downtime: Stateless design, circuit breakers
- Self-Contained Security: Capability-based auth
- No Complexity Addition: Simple interface, no database
"""

import asyncio
import logging
import time
import os
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
import uuid

from app.core.backends.embedding_backend import EmbeddingBackend, EmbeddingRequest
from app.core.capability_auth import verify_capability_token, CapabilityError
from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()


@dataclass
class EmbeddingResponse:
    """Response structure for embedding generation"""
    request_id: str
    embeddings: List[List[float]]
    model: str
    dimensions: int
    tokens_used: int
    processing_time_ms: int
    tenant_id: str
    created_at: str


@dataclass
class EmbeddingStats:
    """Statistics for embedding requests"""
    total_requests: int = 0
    total_tokens_processed: int = 0
    total_processing_time_ms: int = 0
    average_processing_time_ms: float = 0.0
    last_request_at: Optional[str] = None


class EmbeddingService:
    """
    STATELESS embedding service for GT 2.0 Resource Cluster.

    Key features:
    - BGE-M3 model for high-quality embeddings
    - Batch processing for efficiency
    - Rate limiting per capability token
    - Memory-conscious processing
    - No persistent storage
    """

    def __init__(self):
        self.backend = EmbeddingBackend()
        self.stats = EmbeddingStats()

        # Initialize BGE-M3 tokenizer for accurate token counting
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
            logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
        except Exception as e:
            logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
            self.tokenizer = None

        # Rate limiting settings (per capability token)
        self.rate_limits = {
            "requests_per_minute": 60,
            "tokens_per_minute": 50000,
            "max_batch_size": 32
        }

        # Track requests for rate limiting (in-memory, temporary)
        self._request_tracker = {}

        logger.info("STATELESS embedding service initialized")

    async def generate_embeddings(
        self,
        texts: List[str],
        capability_token: str,
        instruction: Optional[str] = None,
        request_id: Optional[str] = None,
        normalize: bool = True
    ) -> EmbeddingResponse:
        """
        Generate embeddings with capability-based authentication.

        Args:
            texts: List of texts to embed
            capability_token: JWT token with embedding permissions
            instruction: Optional instruction for embedding context
            request_id: Optional request ID for tracking
            normalize: Whether to normalize embeddings

        Returns:
            EmbeddingResponse with generated embeddings

        Raises:
            CapabilityError: If token invalid or insufficient permissions
            ValueError: If request parameters invalid
        """
        start_time = time.time()
        request_id = request_id or str(uuid.uuid4())

        try:
            # Verify capability token and extract permissions
            capability = await verify_capability_token(capability_token)
            tenant_id = capability.get("tenant_id")
            user_id = capability.get("sub")  # Extract user ID from token

            # Check embedding permissions
            await self._verify_embedding_permissions(capability, len(texts))

            # Apply rate limiting
            await self._check_rate_limits(capability_token, len(texts))

            # Validate input
            self._validate_embedding_request(texts)

            # Generate embeddings via backend
            embeddings = await self.backend.generate_embeddings(
                texts=texts,
                instruction=instruction,
                tenant_id=tenant_id,
                request_id=request_id
            )

            # Calculate processing metrics
            processing_time_ms = int((time.time() - start_time) * 1000)
            total_tokens = self._estimate_tokens(texts)

            # Update statistics
            self._update_stats(total_tokens, processing_time_ms)

            # Log embedding usage for billing (non-blocking)
            # Fire and forget - don't wait for completion
            asyncio.create_task(
                self._log_embedding_usage(
                    tenant_id=tenant_id,
                    user_id=user_id,
                    tokens_used=total_tokens,
                    embedding_count=len(embeddings),
                    model=self.backend.model_name,
                    request_id=request_id
                )
            )
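            # NOTE: the event loop holds only a weak reference to tasks created
            # this way; to guarantee the billing log runs, a caller would keep
            # the returned Task referenced (e.g. in a set) until it completes.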

            # Create response
            response = EmbeddingResponse(
                request_id=request_id,
                embeddings=embeddings,
                model=self.backend.model_name,
                dimensions=self.backend.embedding_dimensions,
                tokens_used=total_tokens,
                processing_time_ms=processing_time_ms,
                tenant_id=tenant_id,
                created_at=datetime.utcnow().isoformat()
            )

            logger.info(
                f"Generated {len(embeddings)} embeddings for tenant {tenant_id} "
                f"in {processing_time_ms}ms"
            )

            return response

        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            raise
        finally:
            # Always ensure cleanup
            if 'texts' in locals():
                del texts

    async def get_model_info(self) -> Dict[str, Any]:
        """Get information about the embedding model"""
        return {
            "model_name": self.backend.model_name,
            "dimensions": self.backend.embedding_dimensions,
            "max_sequence_length": self.backend.max_sequence_length,
            "max_batch_size": self.backend.max_batch_size,
            "supports_instruction": True,
            "normalization_default": True
        }

    async def get_service_stats(
        self,
        capability_token: str
    ) -> Dict[str, Any]:
        """
        Get service statistics (for admin users only).

        Args:
            capability_token: JWT token with admin permissions

        Returns:
            Service statistics
        """
        # Verify admin permissions
        capability = await verify_capability_token(capability_token)
        if not self._has_admin_permissions(capability):
            raise CapabilityError("Admin permissions required")

        return {
            "model_info": await self.get_model_info(),
            "statistics": asdict(self.stats),
            "rate_limits": self.rate_limits,
            "active_requests": len(self._request_tracker)
        }

    async def health_check(self) -> Dict[str, Any]:
        """Check service health"""
        return {
            "status": "healthy",
            "service": "embedding_service",
            "model": self.backend.model_name,
            "backend_ready": True,
            "last_request": self.stats.last_request_at
        }

    async def _verify_embedding_permissions(
        self,
        capability: Dict[str, Any],
        text_count: int
    ) -> None:
        """Verify that capability token has embedding permissions"""

        # Check for embedding capability
        capabilities = capability.get("capabilities", [])
        embedding_caps = [
            cap for cap in capabilities
            if cap.get("resource") == "embeddings"
        ]

        if not embedding_caps:
            raise CapabilityError("No embedding permissions in capability token")

        # Check constraints
        embedding_cap = embedding_caps[0]  # Use first embedding capability
        constraints = embedding_cap.get("constraints", {})

        # Check batch size limit
        max_batch = constraints.get("max_batch_size", self.rate_limits["max_batch_size"])
        if text_count > max_batch:
            raise CapabilityError(f"Batch size {text_count} exceeds limit {max_batch}")

        # Check rate limits
        rate_limit = constraints.get("rate_limit_per_minute", self.rate_limits["requests_per_minute"])
        token_limit = constraints.get("tokens_per_minute", self.rate_limits["tokens_per_minute"])

        logger.debug(f"Embedding permissions verified: batch={text_count}, limits=({rate_limit}, {token_limit})")

    async def _check_rate_limits(
        self,
        capability_token: str,
        text_count: int
    ) -> None:
        """Check rate limits for capability token"""

        now = time.time()
        token_hash = hash(capability_token) % 10000  # Simple tracking key
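        # Python's str hash is salted per process (PYTHONHASHSEED), so these
        # keys are stable only within one process lifetime -- acceptable for an
        # in-memory, per-instance limiter. The % 10000 bucketing can collide,
        # in which case distinct tokens conservatively share a rate bucket.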

        # Clean old entries (older than 1 minute)
        cleanup_time = now - 60
        self._request_tracker = {
            k: v for k, v in self._request_tracker.items()
            if v.get("last_request", 0) > cleanup_time
        }

        # Get or create tracker for this token
        if token_hash not in self._request_tracker:
            self._request_tracker[token_hash] = {
                "requests": 0,
                "tokens": 0,
                "last_request": now
            }

        tracker = self._request_tracker[token_hash]

        # Check request rate limit
        if tracker["requests"] >= self.rate_limits["requests_per_minute"]:
            raise CapabilityError("Rate limit exceeded: too many requests per minute")

        # Estimate tokens and check token limit
        estimated_tokens = self._estimate_tokens([f"text_{i}" for i in range(text_count)])
        if tracker["tokens"] + estimated_tokens > self.rate_limits["tokens_per_minute"]:
            raise CapabilityError("Rate limit exceeded: too many tokens per minute")

        # Update tracker
        tracker["requests"] += 1
        tracker["tokens"] += estimated_tokens
        tracker["last_request"] = now

    def _validate_embedding_request(self, texts: List[str]) -> None:
        """Validate embedding request parameters"""

        if not texts:
            raise ValueError("No texts provided for embedding")

        if not isinstance(texts, list):
            raise ValueError("Texts must be a list")

        if len(texts) > self.backend.max_batch_size:
            raise ValueError(f"Batch size {len(texts)} exceeds maximum {self.backend.max_batch_size}")

        # Check individual text lengths
        for i, text in enumerate(texts):
            if not isinstance(text, str):
                raise ValueError(f"Text at index {i} must be a string")

            if len(text.strip()) == 0:
                raise ValueError(f"Text at index {i} is empty")

            # Simple token estimation for length check
            estimated_tokens = len(text.split()) * 1.3  # Rough estimation
            if estimated_tokens > self.backend.max_sequence_length:
                raise ValueError(f"Text at index {i} exceeds maximum length")

    def _estimate_tokens(self, texts: List[str]) -> int:
        """
        Count tokens using actual BGE-M3 tokenizer.
        Falls back to word-count estimation if tokenizer unavailable.
        """
        if self.tokenizer is not None:
            try:
                total_tokens = 0
                for text in texts:
                    tokens = self.tokenizer.encode(text, add_special_tokens=False)
                    total_tokens += len(tokens)
                return total_tokens
            except Exception as e:
                logger.warning(f"Tokenizer error, falling back to estimation: {e}")

        # Fallback: word count * 1.3 (rough estimation)
        total_words = sum(len(text.split()) for text in texts)
        return int(total_words * 1.3)

    def _has_admin_permissions(self, capability: Dict[str, Any]) -> bool:
        """Check if capability has admin permissions"""
        capabilities = capability.get("capabilities", [])
        return any(
            cap.get("resource") == "admin" and "stats" in cap.get("actions", [])
            for cap in capabilities
        )

    def _update_stats(self, tokens_processed: int, processing_time_ms: int) -> None:
        """Update service statistics"""
        self.stats.total_requests += 1
        self.stats.total_tokens_processed += tokens_processed
        self.stats.total_processing_time_ms += processing_time_ms
        self.stats.average_processing_time_ms = (
            self.stats.total_processing_time_ms / self.stats.total_requests
        )
        self.stats.last_request_at = datetime.utcnow().isoformat()

    async def _log_embedding_usage(
        self,
        tenant_id: str,
        user_id: str,
        tokens_used: int,
        embedding_count: int,
        model: str = "BAAI/bge-m3",
        request_id: Optional[str] = None
    ) -> None:
        """
        Log embedding usage to control panel database for billing.

        This method logs usage asynchronously and does not block the embedding response.
        Failures are logged as warnings but do not raise exceptions.

        Args:
            tenant_id: Tenant identifier
            user_id: User identifier (from capability token 'sub')
            tokens_used: Number of tokens processed
            embedding_count: Number of embeddings generated
            model: Embedding model name
            request_id: Optional request ID for tracking
        """
        try:
            import asyncpg

            # Calculate cost: BGE-M3 pricing ~$0.10 per million tokens
            cost_cents = (tokens_used / 1_000_000) * 0.10 * 100
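            # e.g. 50,000 tokens -> (50_000 / 1_000_000) * 0.10 * 100 = 0.5 cents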

            # Connect to control panel database
            # Using environment variables from docker-compose
            db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
            if not db_password:
                logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
                return

            conn = await asyncpg.connect(
                host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
                database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
                user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
                password=db_password,
                timeout=5.0
            )

            try:
                # Insert usage log
                await conn.execute("""
                    INSERT INTO public.embedding_usage_logs
                    (tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
                    VALUES ($1, $2, $3, $4, $5, $6, $7)
                """, tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)

                logger.info(
                    f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
                    f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
                )
            finally:
                await conn.close()

        except Exception as e:
            # Log warning but don't fail the embedding request
            logger.warning(f"Failed to log embedding usage for tenant {tenant_id}: {e}")


# Global service instance
_embedding_service = None


def get_embedding_service() -> EmbeddingService:
    """Get the global embedding service instance"""
    global _embedding_service
    if _embedding_service is None:
        _embedding_service = EmbeddingService()
    return _embedding_service
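
A minimal caller sketch, assuming the module path shown in the diff header and a valid capability JWT; the token value and texts are placeholders:

```python
import asyncio

from app.services.embedding_service import get_embedding_service

async def main():
    service = get_embedding_service()
    response = await service.generate_embeddings(
        texts=["What is GT 2.0?", "Stateless embedding service"],
        capability_token="<capability-jwt>",  # placeholder; must pass verify_capability_token
    )
    print(response.model, response.dimensions, len(response.embeddings))

asyncio.run(main())
```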