# NOTE: The lines below are repository-viewer residue (last commit summary and
# file statistics) that was captured together with this file. They are not
# part of the module and do not describe its contents.
#   - Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
#   - Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
#   - Made more general-purpose (flexible targets, expanded tools)
#   - Added nemotron-mini-agent.csv for fast local inference via Ollama
#   - Added nemotron-agent.csv for advanced reasoning via Ollama
#   - Added wiki page: Projects for NVIDIA NIMs and Nemotron
#   461 lines, 17 KiB, Python
"""
API Key Management Service for tenant-specific external API keys
"""

import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from cryptography.fernet import Fernet
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.core.config import settings
from app.models.audit import AuditLog
from app.models.tenant import Tenant


class APIKeyService:
    """Service for managing tenant-specific API keys.

    Keys are kept on ``Tenant.api_keys`` as a mapping of provider id to a
    record with the Fernet-encrypted ``key`` / ``secret`` plus ``enabled``,
    ``metadata``, ``updated_at`` and ``updated_by``. Every mutation writes an
    ``AuditLog`` row in the same transaction and then asks the Resource
    Cluster to drop its cached copy of the key.
    """

    # Supported API key providers - NVIDIA, Groq, and Backblaze
    SUPPORTED_PROVIDERS = {
        'nvidia': {
            'name': 'NVIDIA NIM',
            'description': 'GPU-accelerated inference on DGX Cloud via build.nvidia.com',
            'required_format': 'nvapi-*',
            'test_endpoint': 'https://integrate.api.nvidia.com/v1/models',
        },
        'groq': {
            'name': 'Groq Cloud LLM',
            'description': 'High-performance LLM inference',
            'required_format': 'gsk_*',
            'test_endpoint': 'https://api.groq.com/openai/v1/models',
        },
        'backblaze': {
            'name': 'Backblaze B2',
            'description': 'S3-compatible backup storage',
            'required_format': None,  # Key ID and Application Key
            'test_endpoint': None,
        },
    }

    # Timeout, in seconds, for the single request made by test_api_key().
    _TEST_TIMEOUT_SECONDS = 10.0

    # HTTP status -> (valid, message, error_type) for non-200 responses that
    # test_api_key() reports specifically. 429 counts as valid: the provider
    # accepted the credentials and only throttled the call.
    _STATUS_OUTCOMES = {
        401: (False, 'Invalid or expired API key', 'auth_failed'),
        403: (False, 'Insufficient permissions for this API key', 'insufficient_permissions'),
        429: (True, 'Rate limit exceeded - key is valid but currently limited', 'rate_limited'),
    }

    _logger = logging.getLogger(__name__)

    def __init__(self, db: "AsyncSession"):
        """Bind the service to a session and set up the Fernet cipher.

        The cipher key is read from ``API_KEY_ENCRYPTION_KEY``. When the
        variable is unset, a fresh key is generated and exported to
        ``os.environ`` for the rest of this process.

        Args:
            db: Async SQLAlchemy session used for all reads and writes.
        """
        self.db = db
        encryption_key = os.getenv('API_KEY_ENCRYPTION_KEY')
        if not encryption_key:
            # In production, this should be stored securely. A generated key
            # only lives in this process's environment, so anything encrypted
            # with it cannot be decrypted by a process that generates its own.
            encryption_key = Fernet.generate_key().decode()
            os.environ['API_KEY_ENCRYPTION_KEY'] = encryption_key
            self._logger.warning(
                "API_KEY_ENCRYPTION_KEY is not set; generated a process-local "
                "encryption key. Stored API keys will not be decryptable by "
                "processes that do not share this environment."
            )
        self.cipher = Fernet(encryption_key.encode())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Same text as ``datetime.utcnow().isoformat()`` without calling the
        deprecated ``utcnow``; the offset is dropped so stored values keep
        their existing format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @classmethod
    def _required_prefix(cls, provider: str) -> Optional[str]:
        """Return the mandatory key prefix for *provider*, or ``None``.

        Derived from ``required_format`` by removing the ``*`` wildcard
        (``'nvapi-*'`` -> ``'nvapi-'``).
        """
        pattern = cls.SUPPORTED_PROVIDERS.get(provider, {}).get('required_format')
        return pattern.replace('*', '') if pattern else None

    async def _get_tenant(self, tenant_id: int) -> "Tenant":
        """Load a tenant by primary key.

        Raises:
            ValueError: If no tenant with that id exists.
        """
        result = await self.db.execute(
            select(Tenant).where(Tenant.id == tenant_id)
        )
        tenant = result.scalar_one_or_none()
        if not tenant:
            raise ValueError(f"Tenant {tenant_id} not found")
        return tenant

    async def _save_api_keys(
        self,
        tenant: "Tenant",
        tenant_id: int,
        api_keys: Dict[str, Any],
        *,
        action: str,
        provider: str,
        details: Dict[str, Any]
    ) -> None:
        """Persist *api_keys* on the tenant together with an audit record.

        Both writes go through a single commit so a key change is never
        stored without its audit entry (or the other way round).
        """
        tenant.api_keys = api_keys
        # The dict may have been mutated in place; mark the attribute dirty
        # explicitly so the change is included in the flush.
        flag_modified(tenant, "api_keys")
        self.db.add(AuditLog(
            tenant_id=tenant_id,
            action=action,
            resource_type='api_key',
            resource_id=provider,
            details=details
        ))
        await self.db.commit()

    @staticmethod
    def _rate_limit_info(headers: Any) -> Dict[str, Any]:
        """Extract ``x-ratelimit-*`` values from response headers.

        ``rate_limit_remaining`` is an int when the header parses as one,
        otherwise ``None``; ``rate_limit_reset`` is passed through verbatim.
        """
        remaining = None
        raw_remaining = headers.get('x-ratelimit-remaining')
        if raw_remaining is not None:
            try:
                remaining = int(raw_remaining)
            except (ValueError, TypeError):
                remaining = None
        return {
            'rate_limit_remaining': remaining,
            'rate_limit_reset': headers.get('x-ratelimit-reset'),
        }

    @staticmethod
    def _count_models(response: Any) -> Optional[int]:
        """Return ``len(body['data'])`` when the JSON body has that list."""
        try:
            payload = response.json()
        except ValueError:
            # Body is not valid JSON; the model count is simply unknown.
            return None
        models = payload.get('data') if isinstance(payload, dict) else None
        return len(models) if isinstance(models, list) else None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def set_api_key(
        self,
        tenant_id: int,
        provider: str,
        api_key: str,
        api_secret: Optional[str] = None,
        enabled: bool = True,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Set or update an API key for a tenant.

        Args:
            tenant_id: Primary key of the tenant.
            provider: One of the ``SUPPORTED_PROVIDERS`` ids.
            api_key: Plain-text key; encrypted before it is stored.
            api_secret: Optional plain-text secret, encrypted the same way.
            enabled: Whether the key may be handed out by
                ``get_decrypted_key`` with ``require_enabled=True``.
            metadata: Free-form data stored next to the key.

        Returns:
            ``tenant_id``, ``provider``, ``enabled`` and ``updated_at``.

        Raises:
            ValueError: Unsupported provider, empty key, key not matching
                the provider's required prefix, or unknown tenant.
        """
        if provider not in self.SUPPORTED_PROVIDERS:
            raise ValueError(f"Unsupported provider: {provider}")
        if not api_key:
            raise ValueError(f"API key for {provider} must not be empty")

        # Validate key format if required
        prefix = self._required_prefix(provider)
        if prefix and not api_key.startswith(prefix):
            raise ValueError(f"Invalid API key format for {provider}")

        tenant = await self._get_tenant(tenant_id)
        # Read before committing: unless the session was created with
        # expire_on_commit=False, loaded attributes are expired by commit()
        # and touching them afterwards needs I/O that AsyncSession cannot
        # perform implicitly.
        tenant_domain = tenant.domain

        encrypted_key = self.cipher.encrypt(api_key.encode()).decode()
        encrypted_secret = None
        if api_secret:
            encrypted_secret = self.cipher.encrypt(api_secret.encode()).decode()

        updated_at = self._utc_timestamp()
        api_keys = tenant.api_keys or {}
        api_keys[provider] = {
            'key': encrypted_key,
            'secret': encrypted_secret,
            'enabled': enabled,
            'metadata': metadata or {},
            'updated_at': updated_at,
            'updated_by': 'admin'  # Should come from auth context
        }

        await self._save_api_keys(
            tenant,
            tenant_id,
            api_keys,
            action='api_key_updated',
            provider=provider,
            details={'provider': provider, 'enabled': enabled}
        )

        # Invalidate Resource Cluster cache so it picks up the new key
        await self._invalidate_resource_cluster_cache(tenant_domain, provider)

        return {
            'tenant_id': tenant_id,
            'provider': provider,
            'enabled': enabled,
            'updated_at': updated_at
        }

    async def get_api_keys(self, tenant_id: int) -> Dict[str, Any]:
        """Get all API keys for a tenant (without decryption).

        Returns:
            Mapping of provider id to ``configured`` / ``enabled`` /
            ``updated_at`` / ``metadata``. Key material is never included.

        Raises:
            ValueError: If the tenant does not exist.
        """
        tenant = await self._get_tenant(tenant_id)
        api_keys = tenant.api_keys or {}

        # Return key status without actual keys
        return {
            provider: {
                'configured': True,
                'enabled': info.get('enabled', False),
                'updated_at': info.get('updated_at'),
                'metadata': info.get('metadata', {})
            }
            for provider, info in api_keys.items()
        }

    async def get_decrypted_key(
        self,
        tenant_id: int,
        provider: str,
        require_enabled: bool = True
    ) -> Dict[str, Any]:
        """Get decrypted API key for a specific provider.

        Args:
            tenant_id: Primary key of the tenant.
            provider: Provider id the key was stored under.
            require_enabled: When true, a disabled key is treated as an error.

        Returns:
            ``provider``, ``api_key``, ``api_secret`` (``None`` when no
            secret is stored), ``metadata`` and ``enabled``.

        Raises:
            ValueError: Unknown tenant, provider not configured, or key
                disabled while ``require_enabled`` is true.
        """
        tenant = await self._get_tenant(tenant_id)

        api_keys = tenant.api_keys or {}
        if provider not in api_keys:
            raise ValueError(f"API key for {provider} not configured for tenant {tenant_id}")

        key_info = api_keys[provider]
        if require_enabled and not key_info.get('enabled', False):
            raise ValueError(f"API key for {provider} is disabled for tenant {tenant_id}")

        # Decrypt the key
        decrypted_key = self.cipher.decrypt(key_info['key'].encode()).decode()
        decrypted_secret = None
        if key_info.get('secret'):
            decrypted_secret = self.cipher.decrypt(key_info['secret'].encode()).decode()

        return {
            'provider': provider,
            'api_key': decrypted_key,
            'api_secret': decrypted_secret,
            'metadata': key_info.get('metadata', {}),
            'enabled': key_info.get('enabled', False)
        }

    async def disable_api_key(self, tenant_id: int, provider: str) -> bool:
        """Disable an API key without removing it.

        Returns:
            Always ``True`` on success.

        Raises:
            ValueError: Unknown tenant, or no key stored for *provider*.
        """
        tenant = await self._get_tenant(tenant_id)
        # Captured before commit; see set_api_key for the reason.
        tenant_domain = tenant.domain

        api_keys = tenant.api_keys or {}
        if provider not in api_keys:
            raise ValueError(f"API key for {provider} not configured")

        api_keys[provider]['enabled'] = False
        api_keys[provider]['updated_at'] = self._utc_timestamp()

        await self._save_api_keys(
            tenant,
            tenant_id,
            api_keys,
            action='api_key_disabled',
            provider=provider,
            details={'provider': provider}
        )

        # Invalidate Resource Cluster cache
        await self._invalidate_resource_cluster_cache(tenant_domain, provider)

        return True

    async def remove_api_key(self, tenant_id: int, provider: str) -> bool:
        """Completely remove an API key.

        Returns:
            ``True`` when a key was removed, ``False`` when the tenant had no
            key for *provider* (nothing is written in that case).

        Raises:
            ValueError: If the tenant does not exist.
        """
        tenant = await self._get_tenant(tenant_id)

        api_keys = tenant.api_keys or {}
        if provider not in api_keys:
            return False

        # Captured before commit; see set_api_key for the reason.
        tenant_domain = tenant.domain
        del api_keys[provider]

        await self._save_api_keys(
            tenant,
            tenant_id,
            api_keys,
            action='api_key_removed',
            provider=provider,
            details={'provider': provider}
        )

        # Invalidate Resource Cluster cache
        await self._invalidate_resource_cluster_cache(tenant_domain, provider)

        return True

    async def test_api_key(self, tenant_id: int, provider: str) -> Dict[str, Any]:
        """Test if an API key is valid by making a test request.

        Sends one GET to the provider's ``test_endpoint`` and maps the
        outcome to a result dict. Network failures are reported in the
        result rather than raised.

        Returns:
            Always contains ``provider``, ``valid`` and ``message``.
            ``error_type`` is present on every non-200 outcome;
            ``status_code`` and the ``rate_limit_*`` fields are present
            whenever a response was received; ``models_available`` only on
            200.

        Raises:
            ValueError: Propagated from ``get_decrypted_key`` (unknown
                tenant, key not configured, or key disabled).
        """
        # Imported here so the rest of the service works without httpx.
        import httpx

        key_info = await self.get_decrypted_key(tenant_id, provider)
        # A key stored under a provider that is not (or no longer) in the
        # table has no endpoint to test against; report that, don't crash.
        provider_info = self.SUPPORTED_PROVIDERS.get(provider) or {}
        endpoint = provider_info.get('test_endpoint')

        if not endpoint:
            return {
                'provider': provider,
                'testable': False,
                'valid': False,
                'message': 'No test endpoint available for this provider',
                'error_type': 'not_testable'
            }

        # Validate key format before making request
        api_key = key_info['api_key']
        prefix = self._required_prefix(provider)
        if prefix and not api_key.startswith(prefix):
            return {
                'provider': provider,
                'valid': False,
                'message': f'Invalid key format (should start with {prefix})',
                'error_type': 'invalid_format'
            }

        headers = self._get_auth_headers(provider, api_key)

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    endpoint,
                    headers=headers,
                    timeout=self._TEST_TIMEOUT_SECONDS
                )
        except httpx.ConnectError:
            return {
                'provider': provider,
                'valid': False,
                'message': f"Connection failed: Unable to reach {endpoint}",
                'error_type': 'connection_error'
            }
        except httpx.TimeoutException:
            return {
                'provider': provider,
                'valid': False,
                'message': f'Connection timed out after {self._TEST_TIMEOUT_SECONDS:g} seconds',
                'error_type': 'timeout'
            }
        except Exception as e:
            # Boundary handler: this method reports failures, it never raises
            # for transport problems.
            return {
                'provider': provider,
                'valid': False,
                'error': str(e),
                'message': f"Test failed: {str(e)}",
                'error_type': 'unknown'
            }

        status = response.status_code
        rate_limits = self._rate_limit_info(response.headers)

        if status == 200:
            return {
                'provider': provider,
                'valid': True,
                'message': 'API key is valid',
                'status_code': status,
                **rate_limits,
                'models_available': self._count_models(response)
            }

        if status in self._STATUS_OUTCOMES:
            valid, message, error_type = self._STATUS_OUTCOMES[status]
        else:
            valid = False
            message = f'Test failed with HTTP {status}'
            error_type = 'server_error' if status >= 500 else 'unknown'

        return {
            'provider': provider,
            'valid': valid,
            'message': message,
            'status_code': status,
            'error_type': error_type,
            **rate_limits
        }

    def _get_auth_headers(self, provider: str, api_key: str) -> Dict[str, str]:
        """Build authorization headers based on provider.

        ``anthropic`` uses an ``x-api-key`` header; every other provider
        gets a standard ``Authorization: Bearer`` header.
        """
        if provider == 'anthropic':
            return {'x-api-key': api_key}
        return {'Authorization': f"Bearer {api_key}"}

    async def get_api_key_usage(self, tenant_id: int, provider: str) -> Dict[str, Any]:
        """Get usage statistics for an API key.

        NOTE: Placeholder. The figures below are hard-coded mock values and
        do not come from real usage records.
        """
        # This would query usage records for the specific provider
        # For now, return mock data
        return {
            'provider': provider,
            'tenant_id': tenant_id,
            'usage': {
                'requests_today': 1234,
                'tokens_today': 456789,
                'cost_today_cents': 234,
                'requests_month': 45678,
                'tokens_month': 12345678,
                'cost_month_cents': 8901
            }
        }

    async def _invalidate_resource_cluster_cache(
        self,
        tenant_domain: str,
        provider: str
    ) -> None:
        """
        Notify Resource Cluster to invalidate its API key cache.

        This is called after API keys are modified, disabled, or removed
        to ensure the Resource Cluster doesn't use stale cached keys.

        Non-critical: If this fails, the cache will expire naturally after TTL.
        """
        try:
            # Imported lazily; an import failure is treated like any other
            # failure here (logged, not raised).
            from app.clients.resource_cluster_client import get_resource_cluster_client

            client = get_resource_cluster_client()
            await client.invalidate_api_key_cache(
                tenant_domain=tenant_domain,
                provider=provider
            )
        except Exception as e:
            # Log but don't fail the main operation
            self._logger.warning(
                "Failed to invalidate Resource Cluster cache (non-critical): %s", e
            )

    @classmethod
    def get_supported_providers(cls) -> List[Dict[str, Any]]:
        """Get list of supported API key providers.

        Returns:
            One dict per provider with ``id``, ``name``, ``description`` and
            ``requires_secret`` (true only for ``backblaze``).
        """
        return [
            {
                'id': provider_id,
                'name': info['name'],
                'description': info['description'],
                'requires_secret': provider_id == 'backblaze'
            }
            for provider_id, info in cls.SUPPORTED_PROVIDERS.items()
        ]