Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
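
The hostname validation item above replaces substring checks with exact host comparison. A minimal sketch of that idea, assuming urllib.parse and a hypothetical allow-list (the service's actual validation utilities may differ):

    from urllib.parse import urlparse

    # Hypothetical allow-list for illustration; the real list lives in the validation utilities.
    ALLOWED_HOSTS = {"api.groq.com", "api.openai.com", "api.anthropic.com"}

    def is_allowed_url(url: str) -> bool:
        # Compare the parsed hostname exactly rather than testing a substring, so URLs such as
        # "https://evil.example/api.groq.com" or "https://api.groq.com.evil.example/" are rejected.
        hostname = urlparse(url).hostname
        return hostname is not None and hostname.lower() in ALLOWED_HOSTS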
"""
|
|
LLM Inference API endpoints
|
|
|
|
Provides capability-based access to LLM models with:
|
|
- Token validation and capability checking
|
|
- Multiple model support (Groq, OpenAI, Anthropic)
|
|
- Streaming and non-streaming responses
|
|
- Usage tracking and cost calculation
|
|
"""

from fastapi import APIRouter, HTTPException, Depends, Header, Request
from fastapi.responses import StreamingResponse
from typing import Dict, Any, Optional, List, Union
from pydantic import BaseModel, Field
import logging

from app.core.security import capability_validator, CapabilityToken
from app.core.backends import get_backend
from app.api.auth import verify_capability
from app.services.model_router import get_model_router

router = APIRouter()
logger = logging.getLogger(__name__)


class InferenceRequest(BaseModel):
    """LLM inference request supporting both prompt and messages format"""
    prompt: Optional[str] = Field(default=None, description="Input prompt for the model")
    messages: Optional[list] = Field(default=None, description="Conversation messages in OpenAI format")
    model: str = Field(default="llama-3.1-70b-versatile", description="Model identifier")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: int = Field(default=4000, ge=1, le=32000, description="Maximum tokens to generate")
    stream: bool = Field(default=False, description="Enable streaming response")
    system_prompt: Optional[str] = Field(default=None, description="System prompt for context")
    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="Available tools for function calling")
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(default=None, description="Tool choice strategy")
    user_id: Optional[str] = Field(default=None, description="User identifier for tenant isolation")
    tenant_id: Optional[str] = Field(default=None, description="Tenant identifier for isolation")
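
    # Illustrative payload for this schema (a sketch; values below are examples, not defaults):
    # {
    #     "messages": [{"role": "user", "content": "Summarize the release notes."}],
    #     "model": "llama-3.1-70b-versatile",
    #     "temperature": 0.7,
    #     "max_tokens": 1024,
    #     "stream": false
    # }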


class InferenceResponse(BaseModel):
    """LLM inference response"""
    content: str = Field(..., description="Generated text")
    model: str = Field(..., description="Model used")
    usage: Dict[str, Any] = Field(..., description="Token usage and cost information")
    latency_ms: float = Field(..., description="Inference latency in milliseconds")


@router.post("/", response_model=InferenceResponse)
async def execute_inference(
    request: InferenceRequest,
    token: CapabilityToken = Depends(verify_capability)
) -> InferenceResponse:
    """Execute LLM inference with capability checking"""

    # Validate request format
    if not request.prompt and not request.messages:
        raise HTTPException(
            status_code=400,
            detail="Either 'prompt' or 'messages' must be provided"
        )

    # Check if user has access to the requested model
    resource = f"llm:{request.model.replace('-', '_')}"
    if not capability_validator.check_resource_access(token, resource, "inference"):
        # Try generic LLM access
        if not capability_validator.check_resource_access(token, "llm:*", "inference"):
            # Try Groq-specific access
            if not capability_validator.check_resource_access(token, "llm:groq", "inference"):
                raise HTTPException(
                    status_code=403,
                    detail=f"No capability for model: {request.model}"
                )
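
    # The checks above accept, in order, a model-specific capability such as
    # "llm:llama_3_1_70b_versatile" (derived from the requested model id), the wildcard
    # "llm:*", or the provider-level "llm:groq"; any one of them grants "inference" access.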

    # Get resource limits from token
    limits = capability_validator.get_resource_limits(token, resource)

    # Apply token limits
    max_tokens = min(
        request.max_tokens,
        limits.get("max_tokens_per_request", request.max_tokens)
    )
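    # For example, a token limited to {"max_tokens_per_request": 2000} clamps a request
    # asking for 4000 tokens down to 2000; with no limit set, the requested value is kept.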

    # Ensure tenant isolation
    user_id = request.user_id or token.sub
    tenant_id = request.tenant_id or token.tenant_id

    try:
        # Get model router for tenant
        model_router = await get_model_router(tenant_id)

        # Prepare prompt for routing
        prompt = request.prompt
        if request.system_prompt and prompt:
            prompt = f"{request.system_prompt}\n\n{prompt}"

        # Route inference request to the appropriate provider
        result = await model_router.route_inference(
            model_id=request.model,
            prompt=prompt,
            messages=request.messages,
            temperature=request.temperature,
            max_tokens=max_tokens,
            stream=False,
            user_id=user_id,
            tenant_id=tenant_id,
            tools=request.tools,
            tool_choice=request.tool_choice
        )

        return InferenceResponse(**result)

    except Exception as e:
        logger.error(f"Inference error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")


@router.post("/stream")
async def stream_inference(
    request: InferenceRequest,
    token: CapabilityToken = Depends(verify_capability)
):
    """Stream LLM inference responses"""

    # Validate request format
    if not request.prompt and not request.messages:
        raise HTTPException(
            status_code=400,
            detail="Either 'prompt' or 'messages' must be provided"
        )

    # Check streaming capability
    resource = f"llm:{request.model.replace('-', '_')}"
    if not capability_validator.check_resource_access(token, resource, "streaming"):
        if not capability_validator.check_resource_access(token, "llm:*", "streaming"):
            if not capability_validator.check_resource_access(token, "llm:groq", "streaming"):
                raise HTTPException(
                    status_code=403,
                    detail="No streaming capability for this model"
                )

    # Ensure tenant isolation
    user_id = request.user_id or token.sub
    tenant_id = request.tenant_id or token.tenant_id

    try:
        # Get model router for tenant
        model_router = await get_model_router(tenant_id)

        # Prepare prompt for routing
        prompt = request.prompt
        if request.system_prompt and prompt:
            prompt = f"{request.system_prompt}\n\n{prompt}"

        # For now, fall back to the Groq backend for streaming (TODO: implement streaming in model router)
        backend = get_backend("groq_proxy")

        # Handle different request formats
        if request.messages:
            # Use messages format for streaming
            async def generate():
                async for chunk in backend._stream_inference_with_messages(
                    messages=request.messages,
                    model=request.model,
                    temperature=request.temperature,
                    max_tokens=request.max_tokens,
                    user_id=user_id,
                    tenant_id=tenant_id
                ):
                    yield f"data: {chunk}\n\n"
                yield "data: [DONE]\n\n"
        else:
            # Use prompt format for streaming
            async def generate():
                async for chunk in backend._stream_inference(
                    messages=[{"role": "user", "content": prompt}],
                    model=request.model,
                    temperature=request.temperature,
                    max_tokens=request.max_tokens,
                    user_id=user_id,
                    tenant_id=tenant_id
                ):
                    yield f"data: {chunk}\n\n"
                yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # Disable nginx buffering
            }
        )

    except Exception as e:
        logger.error(f"Streaming inference error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")


@router.get("/models")
async def list_available_models(
    token: CapabilityToken = Depends(verify_capability)
) -> Dict[str, Any]:
    """List available models based on user capabilities"""

    try:
        # Get model router for the token's tenant
        tenant_id = getattr(token, 'tenant_id', None)
        model_router = await get_model_router(tenant_id)

        # Get all available models from the registry
        all_models = await model_router.list_available_models()

        # Filter based on user capabilities
        accessible_models = []
        for model in all_models:
            resource = f"llm:{model['id'].replace('-', '_')}"
            if capability_validator.check_resource_access(token, resource, "inference"):
                accessible_models.append(model)
            elif capability_validator.check_resource_access(token, "llm:*", "inference"):
                accessible_models.append(model)
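
        # Illustrative result shape: {"models": [{"id": "llama-3.1-70b-versatile", ...}], "total": 1}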

        return {
            "models": accessible_models,
            "total": len(accessible_models)
        }

    except Exception as e:
        logger.error(f"Error listing models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models")