GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts: - Fix stack trace exposure in error responses - Add SSRF protection with DNS resolution checking - Implement proper URL hostname validation (replaces substring matching) - Add centralized path sanitization to prevent path traversal - Fix ReDoS vulnerability in email validation regex - Improve HTML sanitization in validation utilities - Fix capability wildcard matching in auth utilities - Update glob dependency to address CVE - Add CodeQL suppression comments for verified false positives 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
5
apps/tenant-backend/app/api/__init__.py
Normal file
5
apps/tenant-backend/app/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend API Module
|
||||
|
||||
FastAPI routers and endpoints for tenant-specific functionality.
|
||||
"""
|
||||
757
apps/tenant-backend/app/api/auth.py
Normal file
757
apps/tenant-backend/app/api/auth.py
Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Authentication API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles JWT authentication via Control Panel Backend.
|
||||
No mocks - following GT 2.0 philosophy of building on real foundations.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header, Request, Response, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from pydantic import BaseModel, EmailStr
|
||||
import jwt
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.security import create_capability_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["authentication"])
|
||||
security = HTTPBearer(auto_error=False)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# Development authentication function will be defined after class definitions
|
||||
|
||||
|
||||
# Pydantic models
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
user: dict
|
||||
tenant: Optional[dict] = None
|
||||
|
||||
|
||||
class TFASetupResponse(BaseModel):
|
||||
"""Response when TFA is enforced but not yet configured
|
||||
Session data (QR code, temp token) stored server-side in HTTP-only cookie"""
|
||||
requires_tfa: bool = True
|
||||
tfa_configured: bool = False
|
||||
|
||||
|
||||
class TFAVerificationResponse(BaseModel):
|
||||
"""Response when TFA is configured and verification is required
|
||||
Session data (temp token) stored server-side in HTTP-only cookie"""
|
||||
requires_tfa: bool = True
|
||||
tfa_configured: bool = True
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
full_name: str
|
||||
user_type: str
|
||||
tenant_id: Optional[int]
|
||||
capabilities: list
|
||||
is_active: bool
|
||||
|
||||
|
||||
# No development authentication function - violates No Mocks principle
|
||||
# All authentication MUST go through Control Panel Backend
|
||||
|
||||
|
||||
async def get_tenant_user_uuid_by_email(email: str) -> Optional[str]:
|
||||
"""
|
||||
Query tenant database to get user UUID by email.
|
||||
This maps Control Panel users to tenant-specific UUIDs for resource access.
|
||||
"""
|
||||
try:
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
client = await get_postgresql_client()
|
||||
if not client or not client._initialized:
|
||||
logger.warning("PostgreSQL client not initialized, cannot query tenant user")
|
||||
return None
|
||||
|
||||
# Query tenant schema for user by email
|
||||
query = f"""
|
||||
SELECT id FROM {client.schema_name}.users
|
||||
WHERE email = $1
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
async with client._pool.acquire() as conn:
|
||||
user_uuid = await conn.fetchval(query, email)
|
||||
|
||||
if user_uuid:
|
||||
logger.info(f"Found tenant user UUID {user_uuid} for email {email}")
|
||||
return str(user_uuid)
|
||||
else:
|
||||
logger.warning(f"No tenant user found for email {email}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying tenant user by email {email}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def verify_token_with_control_panel(token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify JWT token with Control Panel Backend.
|
||||
This ensures consistency across the entire GT 2.0 platform.
|
||||
No fallbacks - if Control Panel is unavailable, authentication fails.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/verify-token",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get("success") and data.get("data", {}).get("valid"):
|
||||
return data["data"]
|
||||
|
||||
return {"valid": False, "error": "Invalid token"}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Control Panel unavailable for token verification: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
authorization: HTTPAuthorizationCredentials = Depends(security)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract current user from JWT token.
|
||||
Validates with Control Panel for consistency.
|
||||
No fallbacks - authentication is required.
|
||||
"""
|
||||
if not authorization or not authorization.credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = authorization.credentials
|
||||
validation = await verify_token_with_control_panel(token)
|
||||
|
||||
if not validation.get("valid"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=validation.get("error", "Invalid authentication token"),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user_data = validation.get("user", {})
|
||||
user_type = user_data.get("user_type", "")
|
||||
|
||||
# For super_admin users, allow access to any tenant backend
|
||||
# They will assume the tenant context of the backend they're accessing
|
||||
if user_type == "super_admin":
|
||||
logger.info(f"Super admin user {user_data.get('email')} accessing tenant backend {settings.tenant_id}")
|
||||
|
||||
# Override user data with tenant backend context and admin capabilities
|
||||
user_data.update({
|
||||
'tenant_id': settings.tenant_id,
|
||||
'tenant_domain': settings.tenant_domain,
|
||||
'tenant_name': f'Tenant {settings.tenant_domain}',
|
||||
'tenant_role': 'super_admin',
|
||||
'capabilities': [
|
||||
{'resource': '*', 'actions': ['*'], 'constraints': {}}
|
||||
]
|
||||
})
|
||||
# Tenant ID validation removed - any authenticated Control Panel user can access any tenant
|
||||
|
||||
return user_data
|
||||
|
||||
|
||||
@router.post("/auth/login", response_model=Union[LoginResponse, TFASetupResponse, TFAVerificationResponse])
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
request: Request,
|
||||
response: Response
|
||||
):
|
||||
"""
|
||||
Authenticate user via Control Panel Backend.
|
||||
This ensures a single source of truth for user authentication.
|
||||
No fallbacks - if Control Panel is unavailable, login fails.
|
||||
"""
|
||||
try:
|
||||
# Forward authentication request to Control Panel
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/login",
|
||||
json={
|
||||
"email": login_data.email,
|
||||
"password": login_data.password
|
||||
},
|
||||
headers={
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown"),
|
||||
"X-App-Type": "tenant_app" # Distinguish from control_panel sessions
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if cp_response.status_code == 200:
|
||||
data = cp_response.json()
|
||||
|
||||
# Forward Set-Cookie headers from Control Panel to client
|
||||
if "set-cookie" in cp_response.headers:
|
||||
response.headers["set-cookie"] = cp_response.headers["set-cookie"]
|
||||
|
||||
# Check if this is a TFA response (setup or verification)
|
||||
if data.get("requires_tfa"):
|
||||
logger.info(
|
||||
f"TFA required for user {data.get('user_email')}: "
|
||||
f"configured={data.get('tfa_configured')}"
|
||||
)
|
||||
|
||||
# Return TFA response directly without modification
|
||||
if data.get("tfa_configured"):
|
||||
# TFA verification required
|
||||
return TFAVerificationResponse(**data)
|
||||
else:
|
||||
# TFA setup required
|
||||
return TFASetupResponse(**data)
|
||||
|
||||
# Handle normal login response (no TFA required)
|
||||
user = data.get("user", {})
|
||||
user_type = user.get("user_type", "")
|
||||
|
||||
# For super_admin users, allow access to any tenant backend
|
||||
# They will assume the tenant context of the backend they're accessing
|
||||
if user_type == "super_admin":
|
||||
logger.info(f"Super admin user {user.get('email')} accessing tenant backend {settings.tenant_id}")
|
||||
# Admin users can access any tenant backend - no tenant validation needed
|
||||
# Tenant ID validation removed - any authenticated Control Panel user can access any tenant
|
||||
|
||||
logger.info(
|
||||
f"User login successful: {user.get('email')} (ID: {user.get('id')})"
|
||||
)
|
||||
|
||||
# Use the original Control Panel JWT token - do not replace with tenant UUID token
|
||||
# UUID mapping will be handled at the service level when accessing tenant resources
|
||||
access_token = data["access_token"]
|
||||
logger.info(f"Using original Control Panel JWT for user {user.get('email')}")
|
||||
|
||||
# Get user's role from tenant database for frontend
|
||||
from app.core.permissions import get_user_role
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
tenant_role = await get_user_role(pg_client, user.get('email'), settings.tenant_domain)
|
||||
|
||||
# Add tenant role to user object for frontend
|
||||
user['role'] = tenant_role
|
||||
logger.info(f"Added tenant role '{tenant_role}' to user {user.get('email')}")
|
||||
|
||||
# Create tenant context for frontend
|
||||
tenant_info = {
|
||||
"id": settings.tenant_id,
|
||||
"domain": settings.tenant_domain,
|
||||
"name": f"Tenant {settings.tenant_domain}"
|
||||
}
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
expires_in=data.get("expires_in", 86400),
|
||||
user=user,
|
||||
tenant=tenant_info
|
||||
)
|
||||
|
||||
elif response.status_code == 401:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Control Panel login failed: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Authentication service error"
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Control Panel connection failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Login error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/refresh")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Refresh authentication token.
|
||||
For now, returns the same token (Control Panel tokens have 24hr expiry).
|
||||
"""
|
||||
# Get current token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# In a production system, we'd generate a new token here
|
||||
# For now, return the existing token since Control Panel tokens last 24 hours
|
||||
return {
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 86400, # 24 hours
|
||||
"user": current_user
|
||||
}
|
||||
|
||||
|
||||
@router.post("/auth/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Logout user (forward to Control Panel for audit logging).
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
if token:
|
||||
# Forward logout to Control Panel for audit logging
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/auth/logout",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(
|
||||
f"User logout successful: {current_user.get('email')} (ID: {current_user.get('id')})"
|
||||
)
|
||||
|
||||
# Always return success for logout
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Logout error: {str(e)}")
|
||||
# Always return success for logout
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.get("/auth/me")
|
||||
async def get_current_user_info(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get current user information.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"data": current_user
|
||||
}
|
||||
|
||||
|
||||
@router.get("/auth/verify")
|
||||
async def verify_token(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Verify if token is valid.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"valid": True,
|
||||
"user": current_user
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PASSWORD RESET PROXY ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/auth/request-password-reset")
|
||||
async def request_password_reset_proxy(
|
||||
data: dict,
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
Proxy password reset request to Control Panel Backend.
|
||||
Forwards client IP for rate limiting.
|
||||
"""
|
||||
try:
|
||||
# Get client IP for rate limiting
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/auth/request-password-reset",
|
||||
json=data,
|
||||
headers={
|
||||
"X-Forwarded-For": client_ip,
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to forward password reset request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Password reset service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Password reset request error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password reset request failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/reset-password")
|
||||
async def reset_password_proxy(data: dict):
|
||||
"""
|
||||
Proxy password reset to Control Panel Backend.
|
||||
"""
|
||||
try:
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/auth/reset-password",
|
||||
json=data,
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Return response with original status code
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response.json().get("detail", "Password reset failed")
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to forward password reset: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Password reset service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Password reset error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password reset failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/auth/verify-reset-token")
|
||||
async def verify_reset_token_proxy(token: str):
|
||||
"""
|
||||
Proxy token verification to Control Panel Backend.
|
||||
"""
|
||||
try:
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/auth/verify-reset-token",
|
||||
params={"token": token},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to verify reset token: {str(e)}")
|
||||
return {"valid": False, "error": "Token verification service unavailable"}
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {str(e)}")
|
||||
return {"valid": False, "error": "Token verification failed"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TFA PROXY ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/auth/tfa/session-data")
|
||||
async def get_tfa_session_data(request: Request, response: Response):
|
||||
"""
|
||||
Proxy TFA session data request to Control Panel Backend.
|
||||
Forwards cookies for session validation.
|
||||
"""
|
||||
try:
|
||||
# Get cookies from request
|
||||
cookie_header = request.headers.get("cookie", "")
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/session-data",
|
||||
headers={
|
||||
"Cookie": cookie_header,
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if cp_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=cp_response.status_code,
|
||||
detail=cp_response.json().get("detail", "Failed to get TFA session data")
|
||||
)
|
||||
|
||||
return cp_response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get TFA session data: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/auth/tfa/session-qr-code")
|
||||
async def get_tfa_session_qr_code(request: Request):
|
||||
"""
|
||||
Proxy TFA QR code blob request to Control Panel Backend.
|
||||
Forwards cookies for session validation.
|
||||
Returns PNG image blob (never exposes TOTP secret to JavaScript).
|
||||
"""
|
||||
try:
|
||||
# Get cookies from request
|
||||
cookie_header = request.headers.get("cookie", "")
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/session-qr-code",
|
||||
headers={
|
||||
"Cookie": cookie_header,
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if cp_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=cp_response.status_code,
|
||||
detail="Failed to get TFA QR code"
|
||||
)
|
||||
|
||||
# Return raw PNG bytes with image/png content type
|
||||
from fastapi.responses import Response
|
||||
return Response(
|
||||
content=cp_response.content,
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Cache-Control": "no-store, no-cache, must-revalidate",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0"
|
||||
}
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get TFA QR code: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"TFA session data error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get TFA session data"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/tfa/verify-login")
|
||||
async def verify_tfa_login_proxy(
|
||||
data: dict,
|
||||
request: Request,
|
||||
response: Response
|
||||
):
|
||||
"""
|
||||
Proxy TFA verification to Control Panel Backend.
|
||||
Forwards cookies for session validation.
|
||||
"""
|
||||
try:
|
||||
# Get cookies from request
|
||||
cookie_header = request.headers.get("cookie", "")
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/verify-login",
|
||||
json=data,
|
||||
headers={
|
||||
"Cookie": cookie_header,
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Forward Set-Cookie headers (cookie deletion after verification)
|
||||
if "set-cookie" in cp_response.headers:
|
||||
response.headers["set-cookie"] = cp_response.headers["set-cookie"]
|
||||
|
||||
if cp_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=cp_response.status_code,
|
||||
detail=cp_response.json().get("detail", "TFA verification failed")
|
||||
)
|
||||
|
||||
return cp_response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to verify TFA: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"TFA verification error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="TFA verification failed"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SESSION STATUS ENDPOINT (Issue #264)
|
||||
# ============================================================================
|
||||
|
||||
class SessionStatusResponse(BaseModel):
|
||||
"""Response for session status check"""
|
||||
is_valid: bool
|
||||
seconds_remaining: int # Seconds until idle timeout
|
||||
show_warning: bool # True if < 5 minutes remaining
|
||||
absolute_seconds_remaining: Optional[int] = None # Seconds until absolute timeout
|
||||
|
||||
|
||||
@router.get("/auth/session/status", response_model=SessionStatusResponse)
|
||||
async def get_session_status(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get current session status for frontend session monitoring.
|
||||
|
||||
Proxies request to Control Panel Backend which is the authoritative
|
||||
source for session state. This endpoint replaces the complex react-idle-timer
|
||||
approach with a simple polling mechanism.
|
||||
|
||||
Frontend calls this every 60 seconds to check session health.
|
||||
|
||||
Returns:
|
||||
- is_valid: Whether session is currently valid
|
||||
- seconds_remaining: Seconds until idle timeout (30 min from last activity)
|
||||
- show_warning: True if warning should be shown (< 5 min remaining)
|
||||
- absolute_seconds_remaining: Seconds until absolute timeout (8 hours from login)
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No token provided"
|
||||
)
|
||||
|
||||
# Forward to Control Panel Backend (authoritative session source)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/session/status",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=5.0 # Short timeout for health check
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return SessionStatusResponse(
|
||||
is_valid=data.get("is_valid", True),
|
||||
seconds_remaining=data.get("seconds_remaining", 1800),
|
||||
show_warning=data.get("show_warning", False),
|
||||
absolute_seconds_remaining=data.get("absolute_seconds_remaining")
|
||||
)
|
||||
|
||||
elif response.status_code == 401:
|
||||
# Session expired - return invalid status
|
||||
return SessionStatusResponse(
|
||||
is_valid=False,
|
||||
seconds_remaining=0,
|
||||
show_warning=False,
|
||||
absolute_seconds_remaining=None
|
||||
)
|
||||
|
||||
else:
|
||||
# Unexpected response - return safe defaults
|
||||
logger.warning(
|
||||
f"Unexpected session status response: {response.status_code}"
|
||||
)
|
||||
return SessionStatusResponse(
|
||||
is_valid=True,
|
||||
seconds_remaining=1800, # 30 minutes default
|
||||
show_warning=False,
|
||||
absolute_seconds_remaining=None
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
# Control Panel unavailable - FAIL CLOSED for security
|
||||
logger.error(f"Session status check failed - Control Panel unavailable: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session validation service unavailable"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Session status proxy error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session status check failed"
|
||||
)
|
||||
144
apps/tenant-backend/app/api/conversations.py
Normal file
144
apps/tenant-backend/app/api/conversations.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Conversation API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles conversation management for AI chat sessions.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
router = APIRouter(tags=["conversations"])
|
||||
|
||||
|
||||
class CreateConversationRequest(BaseModel):
|
||||
"""Request model for creating a conversation"""
|
||||
agent_id: str = Field(..., description="Agent ID to use for conversation")
|
||||
title: Optional[str] = Field(None, description="Optional conversation title")
|
||||
|
||||
|
||||
class MessageRequest(BaseModel):
|
||||
"""Request model for sending a message"""
|
||||
content: str = Field(..., description="Message content")
|
||||
stream: bool = Field(default=False, description="Stream the response")
|
||||
|
||||
|
||||
async def get_conversation_service(db: AsyncSession = Depends(get_db)) -> ConversationService:
|
||||
"""Get conversation service instance"""
|
||||
return ConversationService(db)
|
||||
|
||||
|
||||
@router.get("/conversations")
|
||||
async def list_conversations(
|
||||
agent_id: Optional[str] = Query(None, description="Filter by agent ID"),
|
||||
limit: int = Query(20, ge=1, le=100, description="Number of results"),
|
||||
offset: int = Query(0, ge=0, description="Pagination offset"),
|
||||
service: ConversationService = Depends(get_conversation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> JSONResponse:
|
||||
"""List user's conversations"""
|
||||
try:
|
||||
# Extract email from user object
|
||||
user_email = current_user.get("email") if isinstance(current_user, dict) else current_user
|
||||
result = await service.list_conversations(
|
||||
user_identifier=user_email,
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
return JSONResponse(status_code=200, content=result)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/conversations")
|
||||
async def create_conversation(
|
||||
request: CreateConversationRequest,
|
||||
service: ConversationService = Depends(get_conversation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> JSONResponse:
|
||||
"""Create new conversation"""
|
||||
try:
|
||||
# Extract email from user object
|
||||
user_email = current_user.get("email") if isinstance(current_user, dict) else current_user
|
||||
result = await service.create_conversation(
|
||||
agent_id=request.agent_id,
|
||||
title=request.title,
|
||||
user_identifier=user_email
|
||||
)
|
||||
return JSONResponse(status_code=201, content=result)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}")
|
||||
async def get_conversation(
|
||||
conversation_id: int,
|
||||
include_messages: bool = Query(False, description="Include messages in response"),
|
||||
service: ConversationService = Depends(get_conversation_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
) -> JSONResponse:
|
||||
"""Get conversation details"""
|
||||
try:
|
||||
result = await service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user,
|
||||
include_messages=include_messages
|
||||
)
|
||||
return JSONResponse(status_code=200, content=result)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}")
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
request: dict,
|
||||
service: ConversationService = Depends(get_conversation_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
) -> JSONResponse:
|
||||
"""Update a conversation title"""
|
||||
try:
|
||||
title = request.get("title")
|
||||
if not title:
|
||||
raise ValueError("Title is required")
|
||||
|
||||
await service.update_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user,
|
||||
title=title
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"message": "Conversation updated successfully"})
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
async def delete_conversation(
|
||||
conversation_id: int,
|
||||
service: ConversationService = Depends(get_conversation_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
) -> JSONResponse:
|
||||
"""Delete a conversation"""
|
||||
try:
|
||||
await service.delete_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"message": "Conversation deleted successfully"})
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
404
apps/tenant-backend/app/api/documents.py
Normal file
404
apps/tenant-backend/app/api/documents.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""
|
||||
Document API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles document upload and management using PostgreSQL file service with perfect tenant isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Form, Query
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.path_security import sanitize_filename
|
||||
from app.services.postgresql_file_service import PostgreSQLFileService
|
||||
from app.services.document_processor import DocumentProcessor, get_document_processor
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["documents"])
|
||||
|
||||
|
||||
@router.get("/documents")
|
||||
async def list_documents(
|
||||
status: Optional[str] = Query(None, description="Filter by processing status"),
|
||||
dataset_id: Optional[str] = Query(None, description="Filter by dataset ID"),
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""List user's documents with optional filtering using PostgreSQL file service"""
|
||||
try:
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
# Get PostgreSQL file service
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test'),
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# List files (documents) with optional dataset filter
|
||||
files = await file_service.list_files(
|
||||
category="documents",
|
||||
dataset_id=dataset_id,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
# Get chunk counts and document status for these documents
|
||||
document_ids = [file_info["id"] for file_info in files]
|
||||
chunk_counts = {}
|
||||
document_status = {}
|
||||
if document_ids:
|
||||
from app.core.database import get_postgresql_client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get chunk counts
|
||||
chunk_query = """
|
||||
SELECT document_id, COUNT(*) as chunk_count
|
||||
FROM document_chunks
|
||||
WHERE document_id = ANY($1)
|
||||
GROUP BY document_id
|
||||
"""
|
||||
chunk_results = await pg_client.execute_query(chunk_query, document_ids)
|
||||
chunk_counts = {str(row["document_id"]): row["chunk_count"] for row in chunk_results}
|
||||
|
||||
# Get document processing status and progress
|
||||
status_query = """
|
||||
SELECT id, processing_status, chunk_count, chunks_processed,
|
||||
total_chunks_expected, processing_progress, processing_stage,
|
||||
error_message, created_at, updated_at
|
||||
FROM documents
|
||||
WHERE id = ANY($1)
|
||||
"""
|
||||
status_results = await pg_client.execute_query(status_query, document_ids)
|
||||
document_status = {str(row["id"]): row for row in status_results}
|
||||
|
||||
# Convert to expected document format
|
||||
documents = []
|
||||
for file_info in files:
|
||||
doc_id = str(file_info["id"])
|
||||
chunk_count = chunk_counts.get(doc_id, 0)
|
||||
status_info = document_status.get(doc_id, {})
|
||||
|
||||
documents.append({
|
||||
"id": file_info["id"],
|
||||
"filename": file_info["filename"],
|
||||
"original_filename": file_info["original_filename"],
|
||||
"file_type": file_info["content_type"],
|
||||
"file_size_bytes": file_info["file_size"],
|
||||
"dataset_id": file_info.get("dataset_id"),
|
||||
"processing_status": status_info.get("processing_status", "completed"),
|
||||
"chunk_count": chunk_count,
|
||||
"chunks_processed": status_info.get("chunks_processed", chunk_count),
|
||||
"total_chunks_expected": status_info.get("total_chunks_expected", chunk_count),
|
||||
"processing_progress": status_info.get("processing_progress", 100 if chunk_count > 0 else 0),
|
||||
"processing_stage": status_info.get("processing_stage", "Completed" if chunk_count > 0 else "Pending"),
|
||||
"error_message": status_info.get("error_message"),
|
||||
"vector_count": chunk_count, # Each chunk gets one vector
|
||||
"created_at": file_info["created_at"],
|
||||
"processed_at": status_info.get("updated_at", file_info["created_at"])
|
||||
})
|
||||
|
||||
# Apply status filter if provided
|
||||
if status:
|
||||
documents = [doc for doc in documents if doc["processing_status"] == status]
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list documents: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/documents")
|
||||
async def upload_document(
|
||||
file: UploadFile = File(...),
|
||||
dataset_id: Optional[str] = Form(None),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Upload new document using PostgreSQL file service"""
|
||||
try:
|
||||
# Validate file
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
|
||||
# Get file extension
|
||||
import pathlib
|
||||
file_extension = pathlib.Path(file.filename).suffix.lower()
|
||||
|
||||
# Validate file type
|
||||
allowed_extensions = ['.pdf', '.docx', '.txt', '.md', '.html', '.csv', '.json']
|
||||
if file_extension not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
|
||||
)
|
||||
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
# Get PostgreSQL file service
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test'),
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# Store file
|
||||
result = await file_service.store_file(
|
||||
file=file,
|
||||
dataset_id=dataset_id,
|
||||
category="documents"
|
||||
)
|
||||
|
||||
# Start document processing if dataset_id is provided
|
||||
if dataset_id:
|
||||
try:
|
||||
# Get document processor with tenant domain
|
||||
tenant_domain = current_user.get('tenant_domain', 'test')
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Process document for RAG (async)
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# Create temporary file for processing
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
|
||||
# Write file content to temp file
|
||||
file.file.seek(0) # Reset file pointer
|
||||
temp_file.write(await file.read())
|
||||
temp_file.flush()
|
||||
|
||||
# Process document using existing document ID
|
||||
try:
|
||||
processed_doc = await processor.process_file(
|
||||
file_path=Path(temp_file.name),
|
||||
dataset_id=dataset_id,
|
||||
user_id=tenant_user_uuid,
|
||||
original_filename=file.filename,
|
||||
document_id=result["id"] # Use existing document ID
|
||||
)
|
||||
|
||||
processing_status = "completed"
|
||||
chunk_count = getattr(processed_doc, 'chunk_count', 0)
|
||||
|
||||
except Exception as proc_error:
|
||||
logger.error(f"Document processing failed: {proc_error}")
|
||||
processing_status = "failed"
|
||||
chunk_count = 0
|
||||
finally:
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
except Exception as proc_error:
|
||||
logger.error(f"Failed to initiate document processing: {proc_error}")
|
||||
processing_status = "pending"
|
||||
chunk_count = 0
|
||||
else:
|
||||
processing_status = "completed"
|
||||
chunk_count = 0
|
||||
|
||||
# Return in expected format
|
||||
return {
|
||||
"id": result["id"],
|
||||
"filename": result["filename"],
|
||||
"original_filename": file.filename,
|
||||
"file_type": result["content_type"],
|
||||
"file_size_bytes": result["file_size"],
|
||||
"processing_status": processing_status,
|
||||
"chunk_count": chunk_count,
|
||||
"vector_count": chunk_count, # Each chunk gets one vector
|
||||
"created_at": result["upload_timestamp"],
|
||||
"processed_at": result["upload_timestamp"]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload document: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/documents/{document_id}/process")
|
||||
async def process_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Process a document for RAG pipeline (text extraction, chunking, embedding generation)"""
|
||||
try:
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
tenant_domain = current_user.get('tenant_domain', 'test')
|
||||
|
||||
# Get PostgreSQL file service to verify document exists
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# Get file info to verify ownership and get file metadata
|
||||
file_info = await file_service.get_file_info(document_id)
|
||||
if not file_info:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Get document processor with tenant domain
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Get file extension for temp file
|
||||
import pathlib
|
||||
original_filename = file_info.get("original_filename", file_info.get("filename", "unknown"))
|
||||
# Sanitize the filename to prevent path injection
|
||||
safe_filename = sanitize_filename(original_filename)
|
||||
file_extension = pathlib.Path(safe_filename).suffix.lower() if safe_filename else ".tmp"
|
||||
|
||||
# Create temporary file from database content for processing
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# codeql[py/path-injection] file_extension derived from sanitize_filename() at line 273
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
|
||||
# Stream file content from database to temp file
|
||||
async for chunk in file_service.get_file(document_id):
|
||||
temp_file.write(chunk)
|
||||
temp_file.flush()
|
||||
|
||||
# Process document
|
||||
try:
|
||||
processed_doc = await processor.process_file(
|
||||
file_path=Path(temp_file.name),
|
||||
dataset_id=file_info.get("dataset_id"),
|
||||
user_id=tenant_user_uuid,
|
||||
original_filename=original_filename
|
||||
)
|
||||
|
||||
processing_status = "completed"
|
||||
chunk_count = getattr(processed_doc, 'chunk_count', 0)
|
||||
|
||||
except Exception as proc_error:
|
||||
logger.error(f"Document processing failed for {document_id}: {proc_error}")
|
||||
processing_status = "failed"
|
||||
chunk_count = 0
|
||||
finally:
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"processing_status": processing_status,
|
||||
"message": "Document processed successfully" if processing_status == "completed" else f"Processing failed: {processing_status}",
|
||||
"chunk_count": chunk_count,
|
||||
"processed_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process document {document_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/status")
|
||||
async def get_document_processing_status(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get processing status of a document"""
|
||||
try:
|
||||
# Get document processor to check status
|
||||
tenant_domain = current_user.get('tenant_domain', 'test')
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
status = await processor.get_processing_status(document_id)
|
||||
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"processing_status": status.get("status", "unknown"),
|
||||
"error_message": status.get("error_message"),
|
||||
"chunk_count": status.get("chunk_count", 0),
|
||||
"chunks_processed": status.get("chunks_processed", 0),
|
||||
"total_chunks_expected": status.get("total_chunks_expected", 0),
|
||||
"processing_progress": status.get("processing_progress", 0),
|
||||
"processing_stage": status.get("processing_stage", "")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get processing status for {document_id}: {e}", exc_info=True)
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"processing_status": "unknown",
|
||||
"error_message": "Unable to retrieve processing status",
|
||||
"chunk_count": 0,
|
||||
"chunks_processed": 0,
|
||||
"total_chunks_expected": 0,
|
||||
"processing_progress": 0,
|
||||
"processing_stage": "Error"
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/documents/{document_id}")
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a document and its associated data"""
|
||||
try:
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
# Get PostgreSQL file service
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test'),
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# Verify document exists and user has permission to delete it
|
||||
file_info = await file_service.get_file_info(document_id)
|
||||
if not file_info:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Delete the document
|
||||
success = await file_service.delete_file(document_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete document")
|
||||
|
||||
return {
|
||||
"message": "Document deleted successfully",
|
||||
"document_id": document_id
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document {document_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Additional endpoints can be added here as needed for RAG processing
|
||||
99
apps/tenant-backend/app/api/embeddings.py
Normal file
99
apps/tenant-backend/app/api/embeddings.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
BGE-M3 Embedding Configuration API for Tenant Backend
|
||||
|
||||
Provides endpoint to update embedding configuration at runtime.
|
||||
This allows the tenant backend to switch between local and external embedding endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.services.embedding_client import get_embedding_client
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BGE_M3_ConfigRequest(BaseModel):
|
||||
"""BGE-M3 configuration update request"""
|
||||
is_local_mode: bool = True
|
||||
external_endpoint: Optional[str] = None
|
||||
|
||||
|
||||
class BGE_M3_ConfigResponse(BaseModel):
|
||||
"""BGE-M3 configuration response"""
|
||||
is_local_mode: bool
|
||||
current_endpoint: str
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
|
||||
async def update_bge_m3_config(
|
||||
config_request: BGE_M3_ConfigRequest
|
||||
) -> BGE_M3_ConfigResponse:
|
||||
"""
|
||||
Update BGE-M3 configuration for the tenant backend.
|
||||
|
||||
This allows switching between local and external endpoints at runtime.
|
||||
No authentication required for service-to-service calls.
|
||||
"""
|
||||
try:
|
||||
# Get the global embedding client
|
||||
embedding_client = get_embedding_client()
|
||||
|
||||
# Determine new endpoint
|
||||
if config_request.is_local_mode:
|
||||
new_endpoint = os.getenv('EMBEDDING_ENDPOINT', 'http://host.docker.internal:8005')
|
||||
else:
|
||||
if not config_request.external_endpoint:
|
||||
raise HTTPException(status_code=400, detail="External endpoint required when not in local mode")
|
||||
new_endpoint = config_request.external_endpoint
|
||||
|
||||
# Update the client endpoint
|
||||
embedding_client.update_endpoint(new_endpoint)
|
||||
|
||||
# Update environment variables for future client instances
|
||||
os.environ['BGE_M3_LOCAL_MODE'] = str(config_request.is_local_mode).lower()
|
||||
if config_request.external_endpoint:
|
||||
os.environ['BGE_M3_EXTERNAL_ENDPOINT'] = config_request.external_endpoint
|
||||
|
||||
logger.info(
|
||||
f"BGE-M3 configuration updated: "
|
||||
f"local_mode={config_request.is_local_mode}, "
|
||||
f"endpoint={new_endpoint}"
|
||||
)
|
||||
|
||||
return BGE_M3_ConfigResponse(
|
||||
is_local_mode=config_request.is_local_mode,
|
||||
current_endpoint=new_endpoint,
|
||||
message=f"BGE-M3 configuration updated to {'local' if config_request.is_local_mode else 'external'} mode"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating BGE-M3 config: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
|
||||
async def get_bge_m3_config() -> BGE_M3_ConfigResponse:
|
||||
"""
|
||||
Get current BGE-M3 configuration.
|
||||
"""
|
||||
try:
|
||||
embedding_client = get_embedding_client()
|
||||
|
||||
# Determine if currently in local mode
|
||||
is_local_mode = os.getenv('BGE_M3_LOCAL_MODE', 'true').lower() == 'true'
|
||||
|
||||
return BGE_M3_ConfigResponse(
|
||||
is_local_mode=is_local_mode,
|
||||
current_endpoint=embedding_client.base_url,
|
||||
message="Current BGE-M3 configuration"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting BGE-M3 config: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
358
apps/tenant-backend/app/api/events.py
Normal file
358
apps/tenant-backend/app/api/events.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Event Automation API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Manages event subscriptions, triggers, and automation workflows
|
||||
with perfect tenant isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db_session
|
||||
from app.core.security import get_current_user_email, get_tenant_info
|
||||
from app.services.event_service import EventService, EventType, ActionType, EventActionConfig
|
||||
from app.schemas.event import (
|
||||
EventSubscriptionCreate, EventSubscriptionResponse, EventActionCreate,
|
||||
EventResponse, EventStatistics, ScheduledTaskResponse
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["events"])
|
||||
|
||||
|
||||
@router.post("/subscriptions", response_model=EventSubscriptionResponse)
|
||||
async def create_event_subscription(
|
||||
subscription: EventSubscriptionCreate,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Create a new event subscription"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
|
||||
# Convert actions
|
||||
actions = []
|
||||
for action_data in subscription.actions:
|
||||
action_config = EventActionConfig(
|
||||
action_type=ActionType(action_data.action_type),
|
||||
config=action_data.config,
|
||||
delay_seconds=action_data.delay_seconds,
|
||||
retry_count=action_data.retry_count,
|
||||
retry_delay=action_data.retry_delay,
|
||||
condition=action_data.condition
|
||||
)
|
||||
actions.append(action_config)
|
||||
|
||||
subscription_id = await event_service.create_subscription(
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
event_type=EventType(subscription.event_type),
|
||||
actions=actions,
|
||||
name=subscription.name,
|
||||
description=subscription.description
|
||||
)
|
||||
|
||||
# Get created subscription
|
||||
subscriptions = await event_service.get_user_subscriptions(
|
||||
current_user, tenant_info["tenant_id"]
|
||||
)
|
||||
created_subscription = next(
|
||||
(s for s in subscriptions if s.id == subscription_id), None
|
||||
)
|
||||
|
||||
if not created_subscription:
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve created subscription")
|
||||
|
||||
return EventSubscriptionResponse.from_orm(created_subscription)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create event subscription: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/subscriptions", response_model=List[EventSubscriptionResponse])
|
||||
async def list_event_subscriptions(
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""List user's event subscriptions"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
subscriptions = await event_service.get_user_subscriptions(
|
||||
current_user, tenant_info["tenant_id"]
|
||||
)
|
||||
|
||||
return [EventSubscriptionResponse.from_orm(sub) for sub in subscriptions]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list event subscriptions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/subscriptions/{subscription_id}/status")
|
||||
async def update_subscription_status(
|
||||
subscription_id: str,
|
||||
is_active: bool,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Update event subscription status"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
success = await event_service.update_subscription_status(
|
||||
subscription_id, current_user, is_active
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": f"Subscription {'activated' if is_active else 'deactivated'} successfully"
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update subscription status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/subscriptions/{subscription_id}")
|
||||
async def delete_event_subscription(
|
||||
subscription_id: str,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Delete event subscription"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
success = await event_service.delete_subscription(subscription_id, current_user)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||
|
||||
return JSONResponse(content={"message": "Subscription deleted successfully"})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete subscription: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/emit")
|
||||
async def emit_event(
|
||||
event_type: str,
|
||||
data: Dict[str, Any],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Manually emit an event"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
|
||||
# Validate event type
|
||||
try:
|
||||
event_type_enum = EventType(event_type)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid event type: {event_type}")
|
||||
|
||||
event_id = await event_service.emit_event(
|
||||
event_type=event_type_enum,
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
data=data,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"event_id": event_id,
|
||||
"message": "Event emitted successfully"
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to emit event: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[EventResponse])
|
||||
async def get_event_history(
|
||||
event_types: Optional[List[str]] = Query(None, description="Filter by event types"),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Get event history for user"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
|
||||
# Convert event types if provided
|
||||
event_type_enums = None
|
||||
if event_types:
|
||||
try:
|
||||
event_type_enums = [EventType(et) for et in event_types]
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid event type: {e}")
|
||||
|
||||
events = await event_service.get_event_history(
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
event_types=event_type_enums,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return [EventResponse.from_orm(event) for event in events]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get event history: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=EventStatistics)
|
||||
async def get_event_statistics(
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Get event statistics for user"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
stats = await event_service.get_event_statistics(
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
days=days
|
||||
)
|
||||
|
||||
return EventStatistics(**stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get event statistics: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/types")
|
||||
async def get_available_event_types():
|
||||
"""Get available event types and actions"""
|
||||
return JSONResponse(content={
|
||||
"event_types": [
|
||||
{"value": et.value, "description": et.value.replace("_", " ").title()}
|
||||
for et in EventType
|
||||
],
|
||||
"action_types": [
|
||||
{"value": at.value, "description": at.value.replace("_", " ").title()}
|
||||
for at in ActionType
|
||||
]
|
||||
})
|
||||
|
||||
|
||||
# Document automation endpoints
|
||||
@router.post("/documents/{document_id}/auto-process")
|
||||
async def trigger_document_processing(
|
||||
document_id: int,
|
||||
chunking_strategy: Optional[str] = Query("hybrid", description="Chunking strategy"),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Trigger automated document processing"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
|
||||
event_id = await event_service.emit_event(
|
||||
event_type=EventType.DOCUMENT_UPLOADED,
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
data={
|
||||
"document_id": document_id,
|
||||
"filename": f"document_{document_id}",
|
||||
"chunking_strategy": chunking_strategy,
|
||||
"manual_trigger": True
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"event_id": event_id,
|
||||
"message": "Document processing automation triggered"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger document processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Conversation automation endpoints
|
||||
@router.post("/conversations/{conversation_id}/auto-analyze")
|
||||
async def trigger_conversation_analysis(
|
||||
conversation_id: int,
|
||||
analysis_type: str = Query("sentiment", description="Type of analysis"),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Trigger automated conversation analysis"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
|
||||
event_id = await event_service.emit_event(
|
||||
event_type=EventType.CONVERSATION_STARTED,
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
data={
|
||||
"conversation_id": conversation_id,
|
||||
"analysis_type": analysis_type,
|
||||
"manual_trigger": True
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"event_id": event_id,
|
||||
"message": "Conversation analysis automation triggered"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger conversation analysis: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Default subscriptions endpoint
|
||||
@router.post("/setup-defaults")
|
||||
async def setup_default_subscriptions(
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Setup default event subscriptions for user"""
|
||||
try:
|
||||
from app.services.event_service import setup_default_subscriptions
|
||||
|
||||
event_service = EventService(db)
|
||||
await setup_default_subscriptions(
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
event_service=event_service
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Default event subscriptions created successfully"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup default subscriptions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
3
apps/tenant-backend/app/api/internal/__init__.py
Normal file
3
apps/tenant-backend/app/api/internal/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Internal API endpoints for service-to-service communication
|
||||
"""
|
||||
137
apps/tenant-backend/app/api/messages.py
Normal file
137
apps/tenant-backend/app/api/messages.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Message API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles message management within conversations.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
import logging
|
||||
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.core.security import get_current_user
|
||||
from app.core.user_resolver import resolve_user_uuid
|
||||
|
||||
router = APIRouter(tags=["messages"])
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
"""Request model for sending a message"""
|
||||
content: str = Field(..., description="Message content")
|
||||
stream: bool = Field(default=False, description="Stream the response")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_conversation_service(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ConversationService:
|
||||
"""Get properly initialized conversation service"""
|
||||
tenant_domain, user_email, user_id = await resolve_user_uuid(current_user)
|
||||
return ConversationService(tenant_domain=tenant_domain, user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}/messages")
|
||||
async def list_messages(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> JSONResponse:
|
||||
"""List messages in conversation"""
|
||||
try:
|
||||
# Get properly initialized service
|
||||
service = await get_conversation_service(current_user)
|
||||
|
||||
# Get conversation with messages
|
||||
result = await service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=service.user_id,
|
||||
include_messages=True
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"messages": result.get("messages", [])}
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/messages")
|
||||
async def send_message(
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
stream: bool = False,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> JSONResponse:
|
||||
"""Send a message and get AI response"""
|
||||
try:
|
||||
# Get properly initialized service
|
||||
service = await get_conversation_service(current_user)
|
||||
|
||||
# Send message
|
||||
result = await service.send_message(
|
||||
conversation_id=conversation_id,
|
||||
content=content,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
# Generate title after first message
|
||||
if result.get("is_first_message", False):
|
||||
try:
|
||||
await service.auto_generate_conversation_title(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=service.user_id
|
||||
)
|
||||
logger.info(f"✅ Title generated for conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title: {e}")
|
||||
# Don't fail the request if title generation fails
|
||||
|
||||
return JSONResponse(status_code=201, content=result)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}/messages/stream")
|
||||
async def stream_message(
|
||||
conversation_id: int,
|
||||
content: str,
|
||||
service: ConversationService = Depends(get_conversation_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""Stream AI response for a message"""
|
||||
|
||||
async def generate():
|
||||
"""Generate SSE stream"""
|
||||
try:
|
||||
async for chunk in service.stream_message_response(
|
||||
conversation_id=conversation_id,
|
||||
message_content=content,
|
||||
user_identifier=current_user
|
||||
):
|
||||
# Format as Server-Sent Event
|
||||
yield f"data: {json.dumps({'content': chunk})}\n\n"
|
||||
|
||||
# Send completion signal
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Message streaming failed: {e}", exc_info=True)
|
||||
yield f"data: {json.dumps({'error': 'An internal error occurred. Please try again.'})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
411
apps/tenant-backend/app/api/v1/access_groups.py
Normal file
411
apps/tenant-backend/app/api/v1/access_groups.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Access Groups API for GT 2.0 Tenant Backend
|
||||
|
||||
RESTful API endpoints for managing resource access groups and permissions.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db_session
|
||||
from app.core.security import get_current_user, verify_capability_token
|
||||
from app.models.access_group import (
|
||||
AccessGroup, ResourceCreate, ResourceUpdate, ResourceResponse,
|
||||
AccessGroupModel
|
||||
)
|
||||
from app.services.access_controller import AccessController, AccessControlMiddleware
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/access-groups",
|
||||
tags=["access-groups"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
async def get_access_controller(
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain")
|
||||
) -> AccessController:
|
||||
"""Dependency to get access controller for tenant"""
|
||||
return AccessController(x_tenant_domain)
|
||||
|
||||
|
||||
async def get_middleware(
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain")
|
||||
) -> AccessControlMiddleware:
|
||||
"""Dependency to get access control middleware"""
|
||||
return AccessControlMiddleware(x_tenant_domain)
|
||||
|
||||
|
||||
@router.post("/resources", response_model=ResourceResponse)
|
||||
async def create_resource(
|
||||
resource: ResourceCreate,
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Create a new resource with access control.
|
||||
|
||||
- **name**: Resource name
|
||||
- **resource_type**: Type of resource (agent, dataset, etc.)
|
||||
- **access_group**: Access level (individual, team, organization)
|
||||
- **team_members**: List of user IDs for team access
|
||||
- **metadata**: Resource-specific metadata
|
||||
"""
|
||||
try:
|
||||
# Extract bearer token
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Invalid authorization header")
|
||||
|
||||
capability_token = authorization.replace("Bearer ", "")
|
||||
|
||||
# Get current user
|
||||
user = await get_current_user(authorization, db)
|
||||
|
||||
# Create resource
|
||||
created_resource = await controller.create_resource(
|
||||
user_id=user.id,
|
||||
resource_data=resource,
|
||||
capability_token=capability_token
|
||||
)
|
||||
|
||||
return ResourceResponse(
|
||||
id=created_resource.id,
|
||||
name=created_resource.name,
|
||||
resource_type=created_resource.resource_type,
|
||||
owner_id=created_resource.owner_id,
|
||||
tenant_domain=created_resource.tenant_domain,
|
||||
access_group=created_resource.access_group,
|
||||
team_members=created_resource.team_members,
|
||||
created_at=created_resource.created_at,
|
||||
updated_at=created_resource.updated_at,
|
||||
metadata=created_resource.metadata,
|
||||
file_path=created_resource.file_path
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create resource: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/resources/{resource_id}/access", response_model=ResourceResponse)
|
||||
async def update_resource_access(
|
||||
resource_id: str,
|
||||
access_update: AccessGroupModel,
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
middleware: AccessControlMiddleware = Depends(get_middleware),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Update resource access group.
|
||||
|
||||
Only the resource owner can change access settings.
|
||||
|
||||
- **access_group**: New access level
|
||||
- **team_members**: Updated team members list (for team access)
|
||||
"""
|
||||
try:
|
||||
# Get current user
|
||||
user = await get_current_user(authorization, db)
|
||||
capability_token = authorization.replace("Bearer ", "")
|
||||
|
||||
# Verify permission
|
||||
await middleware.verify_request(
|
||||
user_id=user.id,
|
||||
resource_id=resource_id,
|
||||
action="share",
|
||||
capability_token=capability_token
|
||||
)
|
||||
|
||||
# Update access
|
||||
updated_resource = await controller.update_resource_access(
|
||||
user_id=user.id,
|
||||
resource_id=resource_id,
|
||||
new_access_group=access_update.access_group,
|
||||
team_members=access_update.team_members
|
||||
)
|
||||
|
||||
return ResourceResponse(
|
||||
id=updated_resource.id,
|
||||
name=updated_resource.name,
|
||||
resource_type=updated_resource.resource_type,
|
||||
owner_id=updated_resource.owner_id,
|
||||
tenant_domain=updated_resource.tenant_domain,
|
||||
access_group=updated_resource.access_group,
|
||||
team_members=updated_resource.team_members,
|
||||
created_at=updated_resource.created_at,
|
||||
updated_at=updated_resource.updated_at,
|
||||
metadata=updated_resource.metadata,
|
||||
file_path=updated_resource.file_path
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update access: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/resources", response_model=List[ResourceResponse])
|
||||
async def list_accessible_resources(
|
||||
resource_type: Optional[str] = Query(None, description="Filter by resource type"),
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
List all resources accessible to the current user.
|
||||
|
||||
Returns resources based on access groups:
|
||||
- Individual: Only owned resources
|
||||
- Team: Resources shared with user's teams
|
||||
- Organization: All organization-wide resources
|
||||
"""
|
||||
try:
|
||||
# Get current user
|
||||
user = await get_current_user(authorization, db)
|
||||
|
||||
# Get accessible resources
|
||||
resources = await controller.list_accessible_resources(
|
||||
user_id=user.id,
|
||||
resource_type=resource_type
|
||||
)
|
||||
|
||||
return [
|
||||
ResourceResponse(
|
||||
id=r.id,
|
||||
name=r.name,
|
||||
resource_type=r.resource_type,
|
||||
owner_id=r.owner_id,
|
||||
tenant_domain=r.tenant_domain,
|
||||
access_group=r.access_group,
|
||||
team_members=r.team_members,
|
||||
created_at=r.created_at,
|
||||
updated_at=r.updated_at,
|
||||
metadata=r.metadata,
|
||||
file_path=r.file_path
|
||||
)
|
||||
for r in resources
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list resources: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/resources/{resource_id}", response_model=ResourceResponse)
|
||||
async def get_resource(
|
||||
resource_id: str,
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
middleware: AccessControlMiddleware = Depends(get_middleware),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Get a specific resource if user has access.
|
||||
|
||||
Checks read permission based on access group.
|
||||
"""
|
||||
try:
|
||||
# Get current user
|
||||
user = await get_current_user(authorization, db)
|
||||
capability_token = authorization.replace("Bearer ", "")
|
||||
|
||||
# Verify permission
|
||||
await middleware.verify_request(
|
||||
user_id=user.id,
|
||||
resource_id=resource_id,
|
||||
action="read",
|
||||
capability_token=capability_token
|
||||
)
|
||||
|
||||
# Load resource
|
||||
resource = await controller._load_resource(resource_id)
|
||||
|
||||
return ResourceResponse(
|
||||
id=resource.id,
|
||||
name=resource.name,
|
||||
resource_type=resource.resource_type,
|
||||
owner_id=resource.owner_id,
|
||||
tenant_domain=resource.tenant_domain,
|
||||
access_group=resource.access_group,
|
||||
team_members=resource.team_members,
|
||||
created_at=resource.created_at,
|
||||
updated_at=resource.updated_at,
|
||||
metadata=resource.metadata,
|
||||
file_path=resource.file_path
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get resource: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/resources/{resource_id}")
|
||||
async def delete_resource(
|
||||
resource_id: str,
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
middleware: AccessControlMiddleware = Depends(get_middleware),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Delete a resource.
|
||||
|
||||
Only the resource owner can delete.
|
||||
"""
|
||||
try:
|
||||
# Get current user
|
||||
user = await get_current_user(authorization, db)
|
||||
capability_token = authorization.replace("Bearer ", "")
|
||||
|
||||
# Verify permission
|
||||
await middleware.verify_request(
|
||||
user_id=user.id,
|
||||
resource_id=resource_id,
|
||||
action="delete",
|
||||
capability_token=capability_token
|
||||
)
|
||||
|
||||
# Delete resource (implementation needed)
|
||||
# await controller.delete_resource(resource_id)
|
||||
|
||||
return {"status": "deleted", "resource_id": resource_id}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete resource: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_resource_stats(
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Get resource statistics for the current user.
|
||||
|
||||
Returns counts by type and access group.
|
||||
"""
|
||||
try:
|
||||
# Get current user
|
||||
user = await get_current_user(authorization, db)
|
||||
|
||||
# Get stats
|
||||
stats = await controller.get_resource_stats(user_id=user.id)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/resources/{resource_id}/team-members/{user_id}")
|
||||
async def add_team_member(
|
||||
resource_id: str,
|
||||
user_id: str,
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
middleware: AccessControlMiddleware = Depends(get_middleware),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Add a user to team access for a resource.
|
||||
|
||||
Only the resource owner can add team members.
|
||||
Resource must have team access group.
|
||||
"""
|
||||
try:
|
||||
# Get current user
|
||||
current_user = await get_current_user(authorization, db)
|
||||
capability_token = authorization.replace("Bearer ", "")
|
||||
|
||||
# Verify permission
|
||||
await middleware.verify_request(
|
||||
user_id=current_user.id,
|
||||
resource_id=resource_id,
|
||||
action="share",
|
||||
capability_token=capability_token
|
||||
)
|
||||
|
||||
# Load resource
|
||||
resource = await controller._load_resource(resource_id)
|
||||
|
||||
# Check if team access
|
||||
if resource.access_group != AccessGroup.TEAM:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Resource must have team access to add members"
|
||||
)
|
||||
|
||||
# Add team member
|
||||
resource.add_team_member(user_id)
|
||||
|
||||
# Save changes (implementation needed)
|
||||
# await controller.save_resource(resource)
|
||||
|
||||
return {"status": "added", "user_id": user_id, "resource_id": resource_id}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to add team member: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/resources/{resource_id}/team-members/{user_id}")
|
||||
async def remove_team_member(
|
||||
resource_id: str,
|
||||
user_id: str,
|
||||
authorization: str = Header(..., description="Bearer token"),
|
||||
x_tenant_domain: str = Header(..., description="Tenant domain"),
|
||||
controller: AccessController = Depends(get_access_controller),
|
||||
middleware: AccessControlMiddleware = Depends(get_middleware),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Remove a user from team access for a resource.
|
||||
|
||||
Only the resource owner can remove team members.
|
||||
"""
|
||||
try:
|
||||
# Get current user
|
||||
current_user = await get_current_user(authorization, db)
|
||||
capability_token = authorization.replace("Bearer ", "")
|
||||
|
||||
# Verify permission
|
||||
await middleware.verify_request(
|
||||
user_id=current_user.id,
|
||||
resource_id=resource_id,
|
||||
action="share",
|
||||
capability_token=capability_token
|
||||
)
|
||||
|
||||
# Load resource
|
||||
resource = await controller._load_resource(resource_id)
|
||||
|
||||
# Remove team member
|
||||
resource.remove_team_member(user_id)
|
||||
|
||||
# Save changes (implementation needed)
|
||||
# await controller.save_resource(resource)
|
||||
|
||||
return {"status": "removed", "user_id": user_id, "resource_id": resource_id}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to remove team member: {str(e)}")
|
||||
1465
apps/tenant-backend/app/api/v1/agents.py
Normal file
1465
apps/tenant-backend/app/api/v1/agents.py
Normal file
File diff suppressed because it is too large
Load Diff
582
apps/tenant-backend/app/api/v1/api_keys.py
Normal file
582
apps/tenant-backend/app/api/v1/api_keys.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
Enhanced API Keys Management API for GT 2.0
|
||||
|
||||
RESTful API for advanced API key management with capability-based permissions,
|
||||
configurable constraints, and comprehensive audit logging.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.security import get_current_user, verify_capability_token
|
||||
from app.services.enhanced_api_keys import (
|
||||
EnhancedAPIKeyService, APIKeyConfig, APIKeyStatus, APIKeyScope, SharingPermission
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
"""Request to create a new API key"""
|
||||
name: str = Field(..., description="Human-readable name for the key")
|
||||
description: Optional[str] = Field(None, description="Description of the key's purpose")
|
||||
capabilities: List[str] = Field(..., description="List of capability strings")
|
||||
scope: str = Field("user", description="Key scope: user, tenant, admin")
|
||||
expires_in_days: int = Field(90, description="Expiration time in days")
|
||||
rate_limit_per_hour: Optional[int] = Field(None, description="Custom rate limit per hour")
|
||||
daily_quota: Optional[int] = Field(None, description="Custom daily quota")
|
||||
cost_limit_cents: Optional[int] = Field(None, description="Custom cost limit in cents")
|
||||
allowed_endpoints: Optional[List[str]] = Field(None, description="Allowed endpoints")
|
||||
blocked_endpoints: Optional[List[str]] = Field(None, description="Blocked endpoints")
|
||||
allowed_ips: Optional[List[str]] = Field(None, description="Allowed IP addresses")
|
||||
tenant_constraints: Optional[Dict[str, Any]] = Field(None, description="Custom tenant constraints")
|
||||
|
||||
|
||||
class APIKeyResponse(BaseModel):
|
||||
"""API key configuration response"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
owner_id: str
|
||||
scope: str
|
||||
capabilities: List[str]
|
||||
rate_limit_per_hour: int
|
||||
daily_quota: int
|
||||
cost_limit_cents: int
|
||||
max_tokens_per_request: int
|
||||
allowed_endpoints: List[str]
|
||||
blocked_endpoints: List[str]
|
||||
allowed_ips: List[str]
|
||||
status: str
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime]
|
||||
last_rotated: Optional[datetime]
|
||||
usage: Dict[str, Any]
|
||||
|
||||
|
||||
class CreateAPIKeyResponse(BaseModel):
|
||||
"""Response when creating a new API key"""
|
||||
api_key: APIKeyResponse
|
||||
raw_key: str = Field(..., description="The actual API key (only shown once)")
|
||||
warning: str = Field(..., description="Security warning about key storage")
|
||||
|
||||
|
||||
class RotateAPIKeyResponse(BaseModel):
|
||||
"""Response when rotating an API key"""
|
||||
api_key: APIKeyResponse
|
||||
new_raw_key: str = Field(..., description="The new API key (only shown once)")
|
||||
warning: str = Field(..., description="Security warning about updating systems")
|
||||
|
||||
|
||||
class APIKeyUsageResponse(BaseModel):
|
||||
"""API key usage analytics response"""
|
||||
total_requests: int
|
||||
total_errors: int
|
||||
avg_requests_per_day: float
|
||||
rate_limit_hits: int
|
||||
keys_analyzed: int
|
||||
date_range: Dict[str, str]
|
||||
most_used_endpoints: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class ValidateAPIKeyRequest(BaseModel):
|
||||
"""Request to validate an API key"""
|
||||
api_key: str = Field(..., description="Raw API key to validate")
|
||||
endpoint: Optional[str] = Field(None, description="Endpoint being accessed")
|
||||
client_ip: Optional[str] = Field(None, description="Client IP address")
|
||||
|
||||
|
||||
class ValidateAPIKeyResponse(BaseModel):
|
||||
"""API key validation response"""
|
||||
valid: bool
|
||||
error_message: Optional[str]
|
||||
capability_token: Optional[str]
|
||||
rate_limit_remaining: Optional[int]
|
||||
quota_remaining: Optional[int]
|
||||
|
||||
|
||||
# Dependency injection
|
||||
async def get_api_key_service(
|
||||
authorization: str = Header(...),
|
||||
current_user: str = Depends(get_current_user)
|
||||
) -> EnhancedAPIKeyService:
|
||||
"""Get enhanced API key service"""
|
||||
# Extract tenant from token (mock implementation)
|
||||
tenant_domain = "customer1.com" # Would extract from JWT
|
||||
|
||||
# Use tenant-specific signing key
|
||||
signing_key = f"signing_key_for_{tenant_domain}"
|
||||
|
||||
return EnhancedAPIKeyService(tenant_domain, signing_key)
|
||||
|
||||
|
||||
@router.post("", response_model=CreateAPIKeyResponse)
|
||||
async def create_api_key(
|
||||
request: CreateAPIKeyRequest,
|
||||
authorization: str = Header(...),
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a new API key with specified capabilities and constraints.
|
||||
|
||||
- **name**: Human-readable name for the key
|
||||
- **capabilities**: List of capability strings (e.g., ["llm:gpt-4", "rag:search"])
|
||||
- **scope**: user, tenant, or admin level
|
||||
- **expires_in_days**: Expiration time (default 90 days)
|
||||
- **rate_limit_per_hour**: Custom rate limit (optional)
|
||||
- **allowed_endpoints**: Restrict to specific endpoints (optional)
|
||||
- **tenant_constraints**: Custom constraints for the key (optional)
|
||||
"""
|
||||
try:
|
||||
# Convert scope string to enum
|
||||
scope = APIKeyScope(request.scope.lower())
|
||||
|
||||
# Build constraints from request
|
||||
constraints = request.tenant_constraints or {}
|
||||
|
||||
# Apply custom limits if provided
|
||||
if request.rate_limit_per_hour:
|
||||
constraints["rate_limit_per_hour"] = request.rate_limit_per_hour
|
||||
if request.daily_quota:
|
||||
constraints["daily_quota"] = request.daily_quota
|
||||
if request.cost_limit_cents:
|
||||
constraints["cost_limit_cents"] = request.cost_limit_cents
|
||||
|
||||
# Create API key
|
||||
api_key, raw_key = await api_key_service.create_api_key(
|
||||
name=request.name,
|
||||
owner_id=current_user,
|
||||
capabilities=request.capabilities,
|
||||
scope=scope,
|
||||
expires_in_days=request.expires_in_days,
|
||||
constraints=constraints,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
# Apply custom settings if provided
|
||||
if request.allowed_endpoints:
|
||||
api_key.allowed_endpoints = request.allowed_endpoints
|
||||
if request.blocked_endpoints:
|
||||
api_key.blocked_endpoints = request.blocked_endpoints
|
||||
if request.allowed_ips:
|
||||
api_key.allowed_ips = request.allowed_ips
|
||||
if request.description:
|
||||
api_key.description = request.description
|
||||
|
||||
# Store updated key
|
||||
await api_key_service._store_api_key(api_key)
|
||||
|
||||
return CreateAPIKeyResponse(
|
||||
api_key=APIKeyResponse(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
description=api_key.description,
|
||||
owner_id=api_key.owner_id,
|
||||
scope=api_key.scope.value,
|
||||
capabilities=api_key.capabilities,
|
||||
rate_limit_per_hour=api_key.rate_limit_per_hour,
|
||||
daily_quota=api_key.daily_quota,
|
||||
cost_limit_cents=api_key.cost_limit_cents,
|
||||
max_tokens_per_request=api_key.max_tokens_per_request,
|
||||
allowed_endpoints=api_key.allowed_endpoints,
|
||||
blocked_endpoints=api_key.blocked_endpoints,
|
||||
allowed_ips=api_key.allowed_ips,
|
||||
status=api_key.status.value,
|
||||
created_at=api_key.created_at,
|
||||
expires_at=api_key.expires_at,
|
||||
last_rotated=api_key.last_rotated,
|
||||
usage=api_key.usage.to_dict()
|
||||
),
|
||||
raw_key=raw_key,
|
||||
warning="Store this API key securely. It will not be shown again."
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create API key: {str(e)}")
|
||||
|
||||
|
||||
@router.get("", response_model=List[APIKeyResponse])
|
||||
async def list_api_keys(
|
||||
include_usage: bool = Query(True, description="Include usage statistics"),
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
authorization: str = Header(...),
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List API keys for the current user.
|
||||
|
||||
- **include_usage**: Include detailed usage statistics
|
||||
- **status**: Filter by key status (active, suspended, expired, revoked)
|
||||
"""
|
||||
try:
|
||||
api_keys = await api_key_service.list_user_api_keys(
|
||||
owner_id=current_user,
|
||||
capability_token=authorization,
|
||||
include_usage=include_usage
|
||||
)
|
||||
|
||||
# Filter by status if provided
|
||||
if status:
|
||||
try:
|
||||
status_filter = APIKeyStatus(status.lower())
|
||||
api_keys = [key for key in api_keys if key.status == status_filter]
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid status: {status}")
|
||||
|
||||
return [
|
||||
APIKeyResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
description=key.description,
|
||||
owner_id=key.owner_id,
|
||||
scope=key.scope.value,
|
||||
capabilities=key.capabilities,
|
||||
rate_limit_per_hour=key.rate_limit_per_hour,
|
||||
daily_quota=key.daily_quota,
|
||||
cost_limit_cents=key.cost_limit_cents,
|
||||
max_tokens_per_request=key.max_tokens_per_request,
|
||||
allowed_endpoints=key.allowed_endpoints,
|
||||
blocked_endpoints=key.blocked_endpoints,
|
||||
allowed_ips=key.allowed_ips,
|
||||
status=key.status.value,
|
||||
created_at=key.created_at,
|
||||
expires_at=key.expires_at,
|
||||
last_rotated=key.last_rotated,
|
||||
usage=key.usage.to_dict()
|
||||
)
|
||||
for key in api_keys
|
||||
]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list API keys: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{key_id}", response_model=APIKeyResponse)
|
||||
async def get_api_key(
|
||||
key_id: str,
|
||||
authorization: str = Header(...),
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get detailed information about a specific API key.
|
||||
"""
|
||||
try:
|
||||
# Get user's keys and find the requested one
|
||||
user_keys = await api_key_service.list_user_api_keys(
|
||||
owner_id=current_user,
|
||||
capability_token=authorization,
|
||||
include_usage=True
|
||||
)
|
||||
|
||||
api_key = next((key for key in user_keys if key.id == key_id), None)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
|
||||
return APIKeyResponse(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
description=api_key.description,
|
||||
owner_id=api_key.owner_id,
|
||||
scope=api_key.scope.value,
|
||||
capabilities=api_key.capabilities,
|
||||
rate_limit_per_hour=api_key.rate_limit_per_hour,
|
||||
daily_quota=api_key.daily_quota,
|
||||
cost_limit_cents=api_key.cost_limit_cents,
|
||||
max_tokens_per_request=api_key.max_tokens_per_request,
|
||||
allowed_endpoints=api_key.allowed_endpoints,
|
||||
blocked_endpoints=api_key.blocked_endpoints,
|
||||
allowed_ips=api_key.allowed_ips,
|
||||
status=api_key.status.value,
|
||||
created_at=api_key.created_at,
|
||||
expires_at=api_key.expires_at,
|
||||
last_rotated=api_key.last_rotated,
|
||||
usage=api_key.usage.to_dict()
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get API key: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{key_id}/rotate", response_model=RotateAPIKeyResponse)
|
||||
async def rotate_api_key(
|
||||
key_id: str,
|
||||
authorization: str = Header(...),
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Rotate API key (generate new key value).
|
||||
|
||||
The old key will be invalidated and a new key will be generated.
|
||||
"""
|
||||
try:
|
||||
api_key, new_raw_key = await api_key_service.rotate_api_key(
|
||||
key_id=key_id,
|
||||
owner_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
return RotateAPIKeyResponse(
|
||||
api_key=APIKeyResponse(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
description=api_key.description,
|
||||
owner_id=api_key.owner_id,
|
||||
scope=api_key.scope.value,
|
||||
capabilities=api_key.capabilities,
|
||||
rate_limit_per_hour=api_key.rate_limit_per_hour,
|
||||
daily_quota=api_key.daily_quota,
|
||||
cost_limit_cents=api_key.cost_limit_cents,
|
||||
max_tokens_per_request=api_key.max_tokens_per_request,
|
||||
allowed_endpoints=api_key.allowed_endpoints,
|
||||
blocked_endpoints=api_key.blocked_endpoints,
|
||||
allowed_ips=api_key.allowed_ips,
|
||||
status=api_key.status.value,
|
||||
created_at=api_key.created_at,
|
||||
expires_at=api_key.expires_at,
|
||||
last_rotated=api_key.last_rotated,
|
||||
usage=api_key.usage.to_dict()
|
||||
),
|
||||
new_raw_key=new_raw_key,
|
||||
warning="Update all systems using this API key with the new value."
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to rotate API key: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{key_id}/revoke")
|
||||
async def revoke_api_key(
|
||||
key_id: str,
|
||||
authorization: str = Header(...),
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Revoke API key (mark as revoked and disable access).
|
||||
|
||||
Revoked keys cannot be restored and will immediately stop working.
|
||||
"""
|
||||
try:
|
||||
success = await api_key_service.revoke_api_key(
|
||||
key_id=key_id,
|
||||
owner_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"API key {key_id} has been revoked",
|
||||
"key_id": key_id
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to revoke API key: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/validate", response_model=ValidateAPIKeyResponse)
|
||||
async def validate_api_key(
|
||||
request: ValidateAPIKeyRequest,
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service)
|
||||
):
|
||||
"""
|
||||
Validate an API key and get capability token.
|
||||
|
||||
This endpoint is used by other services to validate API keys
|
||||
and generate capability tokens for resource access.
|
||||
"""
|
||||
try:
|
||||
valid, api_key, error_message = await api_key_service.validate_api_key(
|
||||
raw_key=request.api_key,
|
||||
endpoint=request.endpoint or "",
|
||||
client_ip=request.client_ip or "",
|
||||
user_agent=""
|
||||
)
|
||||
|
||||
response = ValidateAPIKeyResponse(
|
||||
valid=valid,
|
||||
error_message=error_message,
|
||||
capability_token=None,
|
||||
rate_limit_remaining=None,
|
||||
quota_remaining=None
|
||||
)
|
||||
|
||||
if valid and api_key:
|
||||
# Generate capability token
|
||||
capability_token = await api_key_service.generate_capability_token(api_key)
|
||||
response.capability_token = capability_token
|
||||
|
||||
# Add rate limit and quota info
|
||||
response.rate_limit_remaining = max(0, api_key.rate_limit_per_hour - api_key.usage.requests_count)
|
||||
response.quota_remaining = max(0, api_key.daily_quota - api_key.usage.requests_count)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to validate API key: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{key_id}/usage", response_model=APIKeyUsageResponse)
|
||||
async def get_api_key_usage(
|
||||
key_id: str,
|
||||
days: int = Query(30, description="Number of days to analyze"),
|
||||
authorization: str = Header(...),
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get usage analytics for a specific API key.
|
||||
|
||||
- **days**: Number of days to analyze (default 30)
|
||||
"""
|
||||
try:
|
||||
analytics = await api_key_service.get_usage_analytics(
|
||||
owner_id=current_user,
|
||||
key_id=key_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
return APIKeyUsageResponse(
|
||||
total_requests=analytics["total_requests"],
|
||||
total_errors=analytics["total_errors"],
|
||||
avg_requests_per_day=analytics["avg_requests_per_day"],
|
||||
rate_limit_hits=analytics["rate_limit_hits"],
|
||||
keys_analyzed=analytics["keys_analyzed"],
|
||||
date_range=analytics["date_range"],
|
||||
most_used_endpoints=analytics.get("most_used_endpoints", [])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage analytics: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/analytics/summary", response_model=APIKeyUsageResponse)
|
||||
async def get_usage_summary(
|
||||
days: int = Query(30, description="Number of days to analyze"),
|
||||
authorization: str = Header(...),
|
||||
api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get usage analytics summary for all user's API keys.
|
||||
|
||||
- **days**: Number of days to analyze (default 30)
|
||||
"""
|
||||
try:
|
||||
analytics = await api_key_service.get_usage_analytics(
|
||||
owner_id=current_user,
|
||||
key_id=None, # All keys
|
||||
days=days
|
||||
)
|
||||
|
||||
return APIKeyUsageResponse(
|
||||
total_requests=analytics["total_requests"],
|
||||
total_errors=analytics["total_errors"],
|
||||
avg_requests_per_day=analytics["avg_requests_per_day"],
|
||||
rate_limit_hits=analytics["rate_limit_hits"],
|
||||
keys_analyzed=analytics["keys_analyzed"],
|
||||
date_range=analytics["date_range"],
|
||||
most_used_endpoints=analytics.get("most_used_endpoints", [])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage summary: {str(e)}")
|
||||
|
||||
|
||||
# Capability and scope catalogs for UI builders
|
||||
@router.get("/catalog/capabilities")
|
||||
async def get_capability_catalog():
|
||||
"""Get available capabilities for UI builders"""
|
||||
return {
|
||||
"capabilities": [
|
||||
# AI/ML Resources
|
||||
{"value": "llm:gpt-4", "label": "GPT-4 Language Model", "category": "AI/ML"},
|
||||
{"value": "llm:claude-sonnet", "label": "Claude Sonnet", "category": "AI/ML"},
|
||||
{"value": "llm:groq", "label": "Groq Models", "category": "AI/ML"},
|
||||
{"value": "embedding:openai", "label": "OpenAI Embeddings", "category": "AI/ML"},
|
||||
{"value": "image:dall-e", "label": "DALL-E Image Generation", "category": "AI/ML"},
|
||||
|
||||
# RAG & Knowledge
|
||||
{"value": "rag:search", "label": "RAG Search", "category": "Knowledge"},
|
||||
{"value": "rag:upload", "label": "Document Upload", "category": "Knowledge"},
|
||||
{"value": "rag:dataset_management", "label": "Dataset Management", "category": "Knowledge"},
|
||||
|
||||
# Automation
|
||||
{"value": "automation:create", "label": "Create Automations", "category": "Automation"},
|
||||
{"value": "automation:execute", "label": "Execute Automations", "category": "Automation"},
|
||||
{"value": "automation:api_calls", "label": "API Call Actions", "category": "Automation"},
|
||||
{"value": "automation:webhooks", "label": "Webhook Actions", "category": "Automation"},
|
||||
|
||||
# External Services
|
||||
{"value": "external:github", "label": "GitHub Integration", "category": "External"},
|
||||
{"value": "external:slack", "label": "Slack Integration", "category": "External"},
|
||||
|
||||
# Administrative
|
||||
{"value": "admin:user_management", "label": "User Management", "category": "Admin"},
|
||||
{"value": "admin:tenant_settings", "label": "Tenant Settings", "category": "Admin"}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/catalog/scopes")
|
||||
async def get_scope_catalog():
|
||||
"""Get available scopes for UI builders"""
|
||||
return {
|
||||
"scopes": [
|
||||
{
|
||||
"value": "user",
|
||||
"label": "User Scope",
|
||||
"description": "Access to user-specific operations and data",
|
||||
"default_limits": {
|
||||
"rate_limit_per_hour": 1000,
|
||||
"daily_quota": 10000,
|
||||
"cost_limit_cents": 1000
|
||||
}
|
||||
},
|
||||
{
|
||||
"value": "tenant",
|
||||
"label": "Tenant Scope",
|
||||
"description": "Access to tenant-wide operations and data",
|
||||
"default_limits": {
|
||||
"rate_limit_per_hour": 5000,
|
||||
"daily_quota": 50000,
|
||||
"cost_limit_cents": 5000
|
||||
}
|
||||
},
|
||||
{
|
||||
"value": "admin",
|
||||
"label": "Admin Scope",
|
||||
"description": "Administrative access with elevated privileges",
|
||||
"default_limits": {
|
||||
"rate_limit_per_hour": 10000,
|
||||
"daily_quota": 100000,
|
||||
"cost_limit_cents": 10000
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
860
apps/tenant-backend/app/api/v1/auth.py
Normal file
860
apps/tenant-backend/app/api/v1/auth.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""
|
||||
Authentication endpoints for Tenant Backend
|
||||
|
||||
Real authentication that connects to Control Panel Backend.
|
||||
No mocks - following GT 2.0 philosophy of building on real foundations.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from pydantic import BaseModel, EmailStr
|
||||
import jwt
|
||||
import structlog
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.security import create_capability_token
|
||||
from app.core.database import get_postgresql_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
security = HTTPBearer()
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# Auth logging helper function (Issue #152)
|
||||
async def log_auth_event(
|
||||
event_type: str,
|
||||
email: str,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
failure_reason: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
tenant_domain: str = None
|
||||
):
|
||||
"""
|
||||
Log authentication events to auth_logs table for security monitoring.
|
||||
|
||||
Args:
|
||||
event_type: 'login', 'logout', or 'failed_login'
|
||||
email: User's email address
|
||||
user_id: User ID (optional for failed logins)
|
||||
success: Whether the auth attempt succeeded
|
||||
failure_reason: Reason for failure (if applicable)
|
||||
ip_address: IP address of the request
|
||||
user_agent: User agent string from request
|
||||
tenant_domain: Tenant domain for the auth event
|
||||
"""
|
||||
try:
|
||||
if not tenant_domain:
|
||||
tenant_domain = settings.tenant_domain or "test-company"
|
||||
|
||||
schema_name = f"tenant_{tenant_domain.replace('-', '_')}"
|
||||
|
||||
client = await get_postgresql_client()
|
||||
query = f"""
|
||||
INSERT INTO {schema_name}.auth_logs (
|
||||
user_id,
|
||||
email,
|
||||
event_type,
|
||||
success,
|
||||
failure_reason,
|
||||
ip_address,
|
||||
user_agent,
|
||||
tenant_domain
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
"""
|
||||
|
||||
await client.execute_command(
|
||||
query,
|
||||
user_id or "unknown",
|
||||
email,
|
||||
event_type,
|
||||
success,
|
||||
failure_reason,
|
||||
ip_address,
|
||||
user_agent,
|
||||
tenant_domain
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Auth event logged",
|
||||
event_type=event_type,
|
||||
email=email,
|
||||
success=success,
|
||||
tenant=tenant_domain
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail the authentication if logging fails
|
||||
logger.error(f"Failed to log auth event: {e}", event_type=event_type, email=email)
|
||||
|
||||
|
||||
# Pydantic models
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
user: dict
|
||||
|
||||
|
||||
class TokenValidation(BaseModel):
|
||||
valid: bool
|
||||
user: Optional[dict] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
full_name: str
|
||||
user_type: str
|
||||
tenant_id: Optional[int]
|
||||
capabilities: list
|
||||
is_active: bool
|
||||
|
||||
|
||||
# Helper functions
|
||||
async def development_login(login_data: LoginRequest) -> LoginResponse:
|
||||
"""
|
||||
Development authentication that creates a valid JWT token
|
||||
for testing purposes when Control Panel is unavailable.
|
||||
"""
|
||||
# Simple development authentication - check for test credentials
|
||||
test_users = {
|
||||
"gtadmin@test.com": {"password": "password", "role": "admin"},
|
||||
"admin@test.com": {"password": "password", "role": "admin"},
|
||||
"test@example.com": {"password": "password", "role": "developer"}
|
||||
}
|
||||
|
||||
user_info = test_users.get(str(login_data.email).lower())
|
||||
if not user_info or login_data.password != user_info["password"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password"
|
||||
)
|
||||
|
||||
# Create capability token using GT 2.0 security
|
||||
# NIST compliant: 30 minutes idle timeout (Issue #242)
|
||||
token = create_capability_token(
|
||||
user_id=str(login_data.email),
|
||||
tenant_id="test-company",
|
||||
capabilities=[
|
||||
{"resource": "agents", "actions": ["read", "write", "delete"]},
|
||||
{"resource": "datasets", "actions": ["read", "write", "delete"]},
|
||||
{"resource": "conversations", "actions": ["read", "write", "delete"]}
|
||||
],
|
||||
expires_hours=0.5 # 30 minutes (NIST compliant)
|
||||
)
|
||||
|
||||
user_data = {
|
||||
"id": 1,
|
||||
"email": str(login_data.email),
|
||||
"full_name": "Test User",
|
||||
"role": user_info["role"],
|
||||
"tenant_id": "test-company",
|
||||
"tenant_domain": "test-company",
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Development login successful",
|
||||
email=login_data.email,
|
||||
tenant_id="test-company"
|
||||
)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=token,
|
||||
expires_in=1800, # 30 minutes (NIST compliant)
|
||||
user=user_data
|
||||
)
|
||||
async def verify_token_with_control_panel(token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify JWT token with Control Panel Backend.
|
||||
This ensures consistency across the entire GT 2.0 platform.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/auth/verify-token",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get("success") and data.get("data", {}).get("valid"):
|
||||
return data["data"]
|
||||
|
||||
return {"valid": False, "error": "Invalid token"}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to verify token with control panel", error=str(e))
|
||||
# Fallback to local verification if control panel is unreachable
|
||||
try:
|
||||
# Use same RSA keys as token creation for consistency
|
||||
from app.core.security import verify_capability_token
|
||||
payload = verify_capability_token(token)
|
||||
if payload:
|
||||
return {"valid": True, "user": payload}
|
||||
else:
|
||||
return {"valid": False, "error": "Invalid token"}
|
||||
except Exception as e:
|
||||
logger.error(f"Local token verification failed: {e}")
|
||||
return {"valid": False, "error": "Token verification failed"}
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
||||
) -> UserInfo:
|
||||
"""
|
||||
Get current user from JWT token.
|
||||
Validates with Control Panel for consistency.
|
||||
"""
|
||||
token = credentials.credentials
|
||||
validation = await verify_token_with_control_panel(token)
|
||||
|
||||
if not validation.get("valid"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=validation.get("error", "Invalid authentication token"),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user_data = validation.get("user", {})
|
||||
|
||||
# Ensure user belongs to this tenant
|
||||
if settings.tenant_id and str(user_data.get("tenant_id")) != str(settings.tenant_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User does not belong to this tenant"
|
||||
)
|
||||
|
||||
return UserInfo(
|
||||
id=user_data.get("id"),
|
||||
email=user_data.get("email"),
|
||||
full_name=user_data.get("full_name"),
|
||||
user_type=user_data.get("user_type"),
|
||||
tenant_id=user_data.get("tenant_id"),
|
||||
capabilities=user_data.get("capabilities", []),
|
||||
is_active=user_data.get("is_active", True)
|
||||
)
|
||||
|
||||
|
||||
# API endpoints
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
Authenticate user via Control Panel Backend.
|
||||
For development, falls back to test authentication.
|
||||
"""
|
||||
logger.warning(f"Login attempt for {login_data.email}")
|
||||
logger.warning(f"Settings environment: {settings.environment}")
|
||||
try:
|
||||
# Try Control Panel first
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/login",
|
||||
json={
|
||||
"email": login_data.email,
|
||||
"password": login_data.password
|
||||
},
|
||||
headers={
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown"),
|
||||
"X-App-Type": "tenant_app" # Distinguish from control_panel sessions
|
||||
},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
|
||||
# Verify user belongs to this tenant
|
||||
user = data.get("user", {})
|
||||
if settings.tenant_id and str(user.get("tenant_id")) != str(settings.tenant_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User does not belong to this tenant"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"User login successful via Control Panel",
|
||||
user_id=user.get("id"),
|
||||
email=user.get("email"),
|
||||
tenant_id=user.get("tenant_id")
|
||||
)
|
||||
|
||||
# Log successful login (Issue #152)
|
||||
await log_auth_event(
|
||||
event_type="login",
|
||||
email=user.get("email"),
|
||||
user_id=str(user.get("id")),
|
||||
success=True,
|
||||
ip_address=request.client.host if request.client else "unknown",
|
||||
user_agent=request.headers.get("user-agent", "unknown"),
|
||||
tenant_domain=settings.tenant_domain
|
||||
)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=data["access_token"],
|
||||
expires_in=data.get("expires_in", 86400),
|
||||
user=user
|
||||
)
|
||||
else:
|
||||
# Control Panel returned non-200, fall back to development auth
|
||||
logger.warning(f"Control Panel returned {response.status_code}, using development auth")
|
||||
logger.warning(f"Environment is: {settings.environment}")
|
||||
if settings.environment == "development":
|
||||
logger.warning("Calling development_login fallback")
|
||||
return await development_login(login_data)
|
||||
else:
|
||||
logger.warning("Not in development mode, raising 401")
|
||||
|
||||
# Log failed login attempt (Issue #152)
|
||||
await log_auth_event(
|
||||
event_type="failed_login",
|
||||
email=login_data.email,
|
||||
success=False,
|
||||
failure_reason="Invalid credentials",
|
||||
ip_address=request.client.host if request.client else "unknown",
|
||||
user_agent=request.headers.get("user-agent", "unknown"),
|
||||
tenant_domain=settings.tenant_domain
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password"
|
||||
)
|
||||
|
||||
except (httpx.RequestError, httpx.TimeoutException) as e:
|
||||
logger.warning("Control Panel unavailable, using development auth", error=str(e))
|
||||
|
||||
# Development fallback authentication
|
||||
if settings.environment == "development":
|
||||
return await development_login(login_data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Login error", error=str(e))
|
||||
|
||||
# Development fallback for any other errors
|
||||
if settings.environment == "development":
|
||||
logger.warning("Falling back to development auth due to error")
|
||||
return await development_login(login_data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Logout user (forward to Control Panel for audit logging).
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# Forward logout to Control Panel for audit logging
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/auth/logout",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(
|
||||
"User logout successful",
|
||||
user_id=current_user.id,
|
||||
email=current_user.email
|
||||
)
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
else:
|
||||
# Log locally even if Control Panel fails
|
||||
logger.warning(
|
||||
"Control Panel logout failed, but logging out locally",
|
||||
user_id=current_user.id,
|
||||
status_code=response.status_code
|
||||
)
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Logout error", error=str(e), user_id=current_user.id)
|
||||
# Always return success for logout
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_current_user_info(current_user: UserInfo = Depends(get_current_user)):
|
||||
"""
|
||||
Get current user information.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"data": current_user.dict()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/verify")
|
||||
async def verify_token(
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Verify if token is valid.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"valid": True,
|
||||
"user": current_user.dict()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Refresh authentication token via Control Panel Backend.
|
||||
Proxies the refresh request to maintain consistent token lifecycle.
|
||||
"""
|
||||
try:
|
||||
# Get current token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No token provided"
|
||||
)
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
# Note: Control Panel auth endpoints are at /api/v1/* (not /api/v1/auth/*)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/refresh",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Handle successful refresh
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(
|
||||
"Token refresh successful",
|
||||
user_id=current_user.id,
|
||||
email=current_user.email
|
||||
)
|
||||
return {
|
||||
"access_token": data.get("access_token", token),
|
||||
"token_type": "bearer",
|
||||
"expires_in": data.get("expires_in", 86400),
|
||||
"user": current_user.dict()
|
||||
}
|
||||
|
||||
# Handle refresh failure (expired or invalid token)
|
||||
elif response.status_code == 401:
|
||||
logger.warning(
|
||||
"Token refresh failed - token expired or invalid",
|
||||
user_id=current_user.id
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token refresh failed - please login again"
|
||||
)
|
||||
|
||||
# Handle other errors
|
||||
else:
|
||||
logger.error(
|
||||
"Token refresh unexpected response",
|
||||
status_code=response.status_code,
|
||||
user_id=current_user.id
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail="Token refresh failed"
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to forward token refresh", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token refresh service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Token refresh proxy error", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Token refresh failed"
|
||||
)
|
||||
|
||||
|
||||
# Two-Factor Authentication Proxy Endpoints
|
||||
@router.post("/tfa/enable")
|
||||
async def enable_tfa_proxy(
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Proxy TFA enable request to Control Panel Backend.
|
||||
User-initiated from settings page.
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/enable",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Return response with original status code
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response.json().get("detail", "TFA enable failed")
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to forward TFA enable request", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("TFA enable proxy error", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="TFA enable failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tfa/verify-setup")
|
||||
async def verify_setup_tfa_proxy(
|
||||
data: dict,
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Proxy TFA setup verification to Control Panel Backend.
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/verify-setup",
|
||||
json=data,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Return response with original status code
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response.json().get("detail", "TFA verification failed")
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to forward TFA verify setup", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("TFA verify setup proxy error", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="TFA verification failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tfa/disable")
|
||||
async def disable_tfa_proxy(
|
||||
data: dict,
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Proxy TFA disable request to Control Panel Backend.
|
||||
Requires password confirmation.
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/disable",
|
||||
json=data,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Return response with original status code
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response.json().get("detail", "TFA disable failed")
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to forward TFA disable request", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("TFA disable proxy error", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="TFA disable failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tfa/verify-login")
|
||||
async def verify_login_tfa_proxy(data: dict, request: Request):
|
||||
"""
|
||||
Proxy TFA login verification to Control Panel Backend.
|
||||
Called after password verification with temp token + 6-digit code.
|
||||
"""
|
||||
try:
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/verify-login",
|
||||
json=data,
|
||||
headers={
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Return response with original status code
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response.json().get("detail", "TFA verification failed")
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to forward TFA verify login", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("TFA verify login proxy error", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="TFA verification failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tfa/status")
|
||||
async def get_tfa_status_proxy(
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Proxy TFA status request to Control Panel Backend.
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/status",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to get TFA status", error=str(e))
|
||||
return {"tfa_enabled": False, "tfa_required": False, "tfa_status": "disabled"}
|
||||
except Exception as e:
|
||||
logger.error("TFA status proxy error", error=str(e))
|
||||
return {"tfa_enabled": False, "tfa_required": False, "tfa_status": "disabled"}
|
||||
|
||||
|
||||
class SessionStatusResponse(BaseModel):
|
||||
"""Response for session status check"""
|
||||
is_valid: bool
|
||||
seconds_remaining: int # Seconds until idle timeout
|
||||
show_warning: bool # True if < 5 minutes remaining
|
||||
absolute_seconds_remaining: Optional[int] = None # Seconds until absolute timeout
|
||||
|
||||
|
||||
@router.get("/session/status", response_model=SessionStatusResponse)
|
||||
async def get_session_status(
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get current session status for frontend session monitoring.
|
||||
|
||||
Proxies request to Control Panel Backend which is the authoritative
|
||||
source for session state. This endpoint replaces the complex react-idle-timer
|
||||
approach with a simple polling mechanism.
|
||||
|
||||
Frontend calls this every 60 seconds to check session health.
|
||||
|
||||
Returns:
|
||||
- is_valid: Whether session is currently valid
|
||||
- seconds_remaining: Seconds until idle timeout (4 hours from last activity)
|
||||
- show_warning: True if warning should be shown (< 5 min remaining)
|
||||
- absolute_seconds_remaining: Seconds until absolute timeout (8 hours from login)
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No token provided"
|
||||
)
|
||||
|
||||
# Forward to Control Panel Backend (authoritative session source)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/session/status",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0 # Increased timeout for proxy/Cloudflare scenarios
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return SessionStatusResponse(
|
||||
is_valid=data.get("is_valid", True),
|
||||
seconds_remaining=data.get("seconds_remaining", 14400), # 4 hours default
|
||||
show_warning=data.get("show_warning", False),
|
||||
absolute_seconds_remaining=data.get("absolute_seconds_remaining")
|
||||
)
|
||||
|
||||
elif response.status_code == 401:
|
||||
# Session expired - return invalid status
|
||||
return SessionStatusResponse(
|
||||
is_valid=False,
|
||||
seconds_remaining=0,
|
||||
show_warning=False,
|
||||
absolute_seconds_remaining=None
|
||||
)
|
||||
|
||||
else:
|
||||
# Unexpected response - return safe defaults
|
||||
logger.warning(
|
||||
"Unexpected session status response",
|
||||
status_code=response.status_code,
|
||||
user_id=current_user.id
|
||||
)
|
||||
return SessionStatusResponse(
|
||||
is_valid=True,
|
||||
seconds_remaining=14400, # 4 hours default (matches IDLE_TIMEOUT_MINUTES)
|
||||
show_warning=False,
|
||||
absolute_seconds_remaining=None
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
# Control Panel unavailable - FAIL CLOSED for security
|
||||
# Return session invalid to force re-authentication
|
||||
logger.error(
|
||||
"Session status check failed - Control Panel unavailable",
|
||||
error=str(e),
|
||||
user_id=current_user.id
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session validation service unavailable"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Session status proxy error", error=str(e), user_id=current_user.id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session status check failed"
|
||||
)
|
||||
219
apps/tenant-backend/app/api/v1/auth_logs.py
Normal file
219
apps/tenant-backend/app/api/v1/auth_logs.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Authentication Logs API Endpoints
|
||||
Issue: #152
|
||||
|
||||
Provides endpoints for querying authentication event logs including:
|
||||
- User logins
|
||||
- User logouts
|
||||
- Failed login attempts
|
||||
|
||||
Used by observability dashboard for security monitoring and audit trails.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.database import get_postgresql_client
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/auth-logs")
|
||||
async def get_auth_logs(
|
||||
event_type: Optional[str] = Query(None, description="Filter by event type: login, logout, failed_login"),
|
||||
start_date: Optional[datetime] = Query(None, description="Start date for filtering (ISO format)"),
|
||||
end_date: Optional[datetime] = Query(None, description="End date for filtering (ISO format)"),
|
||||
user_email: Optional[str] = Query(None, description="Filter by user email"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
|
||||
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get authentication logs with optional filtering.
|
||||
|
||||
Returns paginated list of authentication events including logins, logouts,
|
||||
and failed login attempts.
|
||||
"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
schema_name = f"tenant_{tenant_domain.replace('-', '_')}"
|
||||
|
||||
# Build query conditions
|
||||
conditions = []
|
||||
params = []
|
||||
param_counter = 1
|
||||
|
||||
if event_type:
|
||||
if event_type not in ['login', 'logout', 'failed_login']:
|
||||
raise HTTPException(status_code=400, detail="Invalid event_type. Must be: login, logout, or failed_login")
|
||||
conditions.append(f"event_type = ${param_counter}")
|
||||
params.append(event_type)
|
||||
param_counter += 1
|
||||
|
||||
if start_date:
|
||||
conditions.append(f"created_at >= ${param_counter}")
|
||||
params.append(start_date)
|
||||
param_counter += 1
|
||||
|
||||
if end_date:
|
||||
conditions.append(f"created_at <= ${param_counter}")
|
||||
params.append(end_date)
|
||||
param_counter += 1
|
||||
|
||||
if user_email:
|
||||
conditions.append(f"email = ${param_counter}")
|
||||
params.append(user_email)
|
||||
param_counter += 1
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
# Get total count
|
||||
client = await get_postgresql_client()
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as total
|
||||
FROM {schema_name}.auth_logs
|
||||
{where_clause}
|
||||
"""
|
||||
count_result = await client.fetch_one(count_query, *params)
|
||||
total_count = count_result['total'] if count_result else 0
|
||||
|
||||
# Get paginated results
|
||||
query = f"""
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
email,
|
||||
event_type,
|
||||
success,
|
||||
failure_reason,
|
||||
ip_address,
|
||||
user_agent,
|
||||
tenant_domain,
|
||||
created_at,
|
||||
metadata
|
||||
FROM {schema_name}.auth_logs
|
||||
{where_clause}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ${param_counter} OFFSET ${param_counter + 1}
|
||||
"""
|
||||
params.extend([limit, offset])
|
||||
|
||||
logs = await client.fetch_all(query, *params)
|
||||
|
||||
# Format results
|
||||
formatted_logs = []
|
||||
for log in logs:
|
||||
formatted_logs.append({
|
||||
"id": str(log['id']),
|
||||
"user_id": log['user_id'],
|
||||
"email": log['email'],
|
||||
"event_type": log['event_type'],
|
||||
"success": log['success'],
|
||||
"failure_reason": log['failure_reason'],
|
||||
"ip_address": log['ip_address'],
|
||||
"user_agent": log['user_agent'],
|
||||
"tenant_domain": log['tenant_domain'],
|
||||
"created_at": log['created_at'].isoformat() if log['created_at'] else None,
|
||||
"metadata": log['metadata']
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Retrieved authentication logs",
|
||||
tenant=tenant_domain,
|
||||
count=len(formatted_logs),
|
||||
filters={"event_type": event_type, "user_email": user_email}
|
||||
)
|
||||
|
||||
return {
|
||||
"logs": formatted_logs,
|
||||
"pagination": {
|
||||
"total": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + limit) < total_count
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve auth logs: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/auth-logs/summary")
|
||||
async def get_auth_logs_summary(
|
||||
days: int = Query(7, ge=1, le=90, description="Number of days to summarize"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get authentication log summary statistics.
|
||||
|
||||
Returns aggregated counts of login events by type for the specified time period.
|
||||
"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
schema_name = f"tenant_{tenant_domain.replace('-', '_')}"
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
client = await get_postgresql_client()
|
||||
query = f"""
|
||||
SELECT
|
||||
event_type,
|
||||
success,
|
||||
COUNT(*) as count
|
||||
FROM {schema_name}.auth_logs
|
||||
WHERE created_at >= $1 AND created_at <= $2
|
||||
GROUP BY event_type, success
|
||||
ORDER BY event_type, success
|
||||
"""
|
||||
|
||||
results = await client.fetch_all(query, start_date, end_date)
|
||||
|
||||
# Format summary
|
||||
summary = {
|
||||
"period_days": days,
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"events": {
|
||||
"successful_logins": 0,
|
||||
"failed_logins": 0,
|
||||
"logouts": 0,
|
||||
"total": 0
|
||||
}
|
||||
}
|
||||
|
||||
for row in results:
|
||||
event_type = row['event_type']
|
||||
success = row['success']
|
||||
count = row['count']
|
||||
|
||||
if event_type == 'login' and success:
|
||||
summary['events']['successful_logins'] = count
|
||||
elif event_type == 'failed_login' or (event_type == 'login' and not success):
|
||||
summary['events']['failed_logins'] += count
|
||||
elif event_type == 'logout':
|
||||
summary['events']['logouts'] = count
|
||||
|
||||
summary['events']['total'] += count
|
||||
|
||||
logger.info(
|
||||
"Retrieved auth logs summary",
|
||||
tenant=tenant_domain,
|
||||
days=days,
|
||||
total_events=summary['events']['total']
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve auth logs summary: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
539
apps/tenant-backend/app/api/v1/automation.py
Normal file
539
apps/tenant-backend/app/api/v1/automation.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""
|
||||
Automation Management API
|
||||
|
||||
REST endpoints for creating, managing, and monitoring automations
|
||||
with capability-based access control.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.services.event_bus import TenantEventBus, TriggerType, EVENT_CATALOG
|
||||
from app.services.automation_executor import AutomationChainExecutor
|
||||
from app.core.dependencies import get_current_user, get_tenant_domain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/automation", tags=["automation"])
|
||||
|
||||
|
||||
class AutomationCreate(BaseModel):
|
||||
"""Create automation request"""
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
trigger_type: str = Field(..., regex="^(cron|webhook|event|manual)$")
|
||||
trigger_config: Dict[str, Any] = Field(default_factory=dict)
|
||||
conditions: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
actions: List[Dict[str, Any]] = Field(..., min_items=1)
|
||||
is_active: bool = True
|
||||
max_retries: int = Field(default=3, ge=0, le=5)
|
||||
timeout_seconds: int = Field(default=300, ge=1, le=3600)
|
||||
|
||||
|
||||
class AutomationUpdate(BaseModel):
|
||||
"""Update automation request"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
trigger_config: Optional[Dict[str, Any]] = None
|
||||
conditions: Optional[List[Dict[str, Any]]] = None
|
||||
actions: Optional[List[Dict[str, Any]]] = None
|
||||
is_active: Optional[bool] = None
|
||||
max_retries: Optional[int] = Field(None, ge=0, le=5)
|
||||
timeout_seconds: Optional[int] = Field(None, ge=1, le=3600)
|
||||
|
||||
|
||||
class AutomationResponse(BaseModel):
|
||||
"""Automation response"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
owner_id: str
|
||||
trigger_type: str
|
||||
trigger_config: Dict[str, Any]
|
||||
conditions: List[Dict[str, Any]]
|
||||
actions: List[Dict[str, Any]]
|
||||
is_active: bool
|
||||
max_retries: int
|
||||
timeout_seconds: int
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TriggerAutomationRequest(BaseModel):
|
||||
"""Manual trigger request"""
|
||||
event_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
variables: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@router.get("/catalog/events")
|
||||
async def get_event_catalog():
|
||||
"""Get available event types and their required fields"""
|
||||
return {
|
||||
"events": EVENT_CATALOG,
|
||||
"trigger_types": [trigger.value for trigger in TriggerType]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/catalog/actions")
|
||||
async def get_action_catalog():
|
||||
"""Get available action types and their configurations"""
|
||||
return {
|
||||
"actions": {
|
||||
"webhook": {
|
||||
"description": "Send HTTP request to external endpoint",
|
||||
"required_fields": ["url"],
|
||||
"optional_fields": ["method", "headers", "body"],
|
||||
"example": {
|
||||
"type": "webhook",
|
||||
"url": "https://api.example.com/notify",
|
||||
"method": "POST",
|
||||
"headers": {"Content-Type": "application/json"},
|
||||
"body": {"message": "Automation triggered"}
|
||||
}
|
||||
},
|
||||
"email": {
|
||||
"description": "Send email notification",
|
||||
"required_fields": ["to", "subject"],
|
||||
"optional_fields": ["body", "template"],
|
||||
"example": {
|
||||
"type": "email",
|
||||
"to": "user@example.com",
|
||||
"subject": "Automation Alert",
|
||||
"body": "Your automation has completed"
|
||||
}
|
||||
},
|
||||
"log": {
|
||||
"description": "Write to application logs",
|
||||
"required_fields": ["message"],
|
||||
"optional_fields": ["level"],
|
||||
"example": {
|
||||
"type": "log",
|
||||
"message": "Document processed successfully",
|
||||
"level": "info"
|
||||
}
|
||||
},
|
||||
"api_call": {
|
||||
"description": "Call internal or external API",
|
||||
"required_fields": ["endpoint"],
|
||||
"optional_fields": ["method", "headers", "body"],
|
||||
"example": {
|
||||
"type": "api_call",
|
||||
"endpoint": "/api/v1/documents/process",
|
||||
"method": "POST",
|
||||
"body": {"document_id": "${document_id}"}
|
||||
}
|
||||
},
|
||||
"data_transform": {
|
||||
"description": "Transform data between steps",
|
||||
"required_fields": ["transform_type", "source", "target"],
|
||||
"optional_fields": ["path", "mapping"],
|
||||
"example": {
|
||||
"type": "data_transform",
|
||||
"transform_type": "extract",
|
||||
"source": "api_response",
|
||||
"target": "document_id",
|
||||
"path": "data.document.id"
|
||||
}
|
||||
},
|
||||
"conditional": {
|
||||
"description": "Execute actions based on conditions",
|
||||
"required_fields": ["condition", "then"],
|
||||
"optional_fields": ["else"],
|
||||
"example": {
|
||||
"type": "conditional",
|
||||
"condition": {"left": "$status", "operator": "equals", "right": "success"},
|
||||
"then": [{"type": "log", "message": "Success"}],
|
||||
"else": [{"type": "log", "message": "Failed"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("", response_model=AutomationResponse)
|
||||
async def create_automation(
|
||||
automation: AutomationCreate,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Create a new automation"""
|
||||
try:
|
||||
# Initialize event bus
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Convert trigger type to enum
|
||||
trigger_type = TriggerType(automation.trigger_type)
|
||||
|
||||
# Create automation
|
||||
created_automation = await event_bus.create_automation(
|
||||
name=automation.name,
|
||||
owner_id=current_user,
|
||||
trigger_type=trigger_type,
|
||||
trigger_config=automation.trigger_config,
|
||||
actions=automation.actions,
|
||||
conditions=automation.conditions
|
||||
)
|
||||
|
||||
# Set additional properties
|
||||
created_automation.max_retries = automation.max_retries
|
||||
created_automation.timeout_seconds = automation.timeout_seconds
|
||||
created_automation.is_active = automation.is_active
|
||||
|
||||
# Log creation
|
||||
await event_bus.emit_event(
|
||||
event_type="automation.created",
|
||||
user_id=current_user,
|
||||
data={
|
||||
"automation_id": created_automation.id,
|
||||
"name": created_automation.name,
|
||||
"trigger_type": trigger_type.value
|
||||
}
|
||||
)
|
||||
|
||||
return AutomationResponse(
|
||||
id=created_automation.id,
|
||||
name=created_automation.name,
|
||||
description=automation.description,
|
||||
owner_id=created_automation.owner_id,
|
||||
trigger_type=trigger_type.value,
|
||||
trigger_config=created_automation.trigger_config,
|
||||
conditions=created_automation.conditions,
|
||||
actions=created_automation.actions,
|
||||
is_active=created_automation.is_active,
|
||||
max_retries=created_automation.max_retries,
|
||||
timeout_seconds=created_automation.timeout_seconds,
|
||||
created_at=created_automation.created_at.isoformat(),
|
||||
updated_at=created_automation.updated_at.isoformat()
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create automation")
|
||||
|
||||
|
||||
@router.get("", response_model=List[AutomationResponse])
|
||||
async def list_automations(
|
||||
owner_only: bool = True,
|
||||
active_only: bool = False,
|
||||
trigger_type: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""List automations with optional filtering"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get automations
|
||||
owner_filter = current_user if owner_only else None
|
||||
automations = await event_bus.list_automations(owner_id=owner_filter)
|
||||
|
||||
# Apply filters
|
||||
filtered = []
|
||||
for automation in automations:
|
||||
if active_only and not automation.is_active:
|
||||
continue
|
||||
|
||||
if trigger_type and automation.trigger_type.value != trigger_type:
|
||||
continue
|
||||
|
||||
filtered.append(AutomationResponse(
|
||||
id=automation.id,
|
||||
name=automation.name,
|
||||
description="", # Not stored in current model
|
||||
owner_id=automation.owner_id,
|
||||
trigger_type=automation.trigger_type.value,
|
||||
trigger_config=automation.trigger_config,
|
||||
conditions=automation.conditions,
|
||||
actions=automation.actions,
|
||||
is_active=automation.is_active,
|
||||
max_retries=automation.max_retries,
|
||||
timeout_seconds=automation.timeout_seconds,
|
||||
created_at=automation.created_at.isoformat(),
|
||||
updated_at=automation.updated_at.isoformat()
|
||||
))
|
||||
|
||||
# Apply limit
|
||||
return filtered[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing automations: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list automations")
|
||||
|
||||
|
||||
@router.get("/{automation_id}", response_model=AutomationResponse)
|
||||
async def get_automation(
|
||||
automation_id: str,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Get automation by ID"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
return AutomationResponse(
|
||||
id=automation.id,
|
||||
name=automation.name,
|
||||
description="",
|
||||
owner_id=automation.owner_id,
|
||||
trigger_type=automation.trigger_type.value,
|
||||
trigger_config=automation.trigger_config,
|
||||
conditions=automation.conditions,
|
||||
actions=automation.actions,
|
||||
is_active=automation.is_active,
|
||||
max_retries=automation.max_retries,
|
||||
timeout_seconds=automation.timeout_seconds,
|
||||
created_at=automation.created_at.isoformat(),
|
||||
updated_at=automation.updated_at.isoformat()
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get automation")
|
||||
|
||||
|
||||
@router.delete("/{automation_id}")
|
||||
async def delete_automation(
|
||||
automation_id: str,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Delete automation"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Check if automation exists and user owns it
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Delete automation
|
||||
success = await event_bus.delete_automation(automation_id, current_user)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete automation")
|
||||
|
||||
# Log deletion
|
||||
await event_bus.emit_event(
|
||||
event_type="automation.deleted",
|
||||
user_id=current_user,
|
||||
data={
|
||||
"automation_id": automation_id,
|
||||
"name": automation.name
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "deleted", "automation_id": automation_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete automation")
|
||||
|
||||
|
||||
@router.post("/{automation_id}/trigger")
|
||||
async def trigger_automation(
|
||||
automation_id: str,
|
||||
trigger_request: TriggerAutomationRequest,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Manually trigger an automation"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get automation
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Check if automation supports manual triggering
|
||||
if automation.trigger_type != TriggerType.MANUAL:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Automation does not support manual triggering"
|
||||
)
|
||||
|
||||
# Create manual trigger event
|
||||
await event_bus.emit_event(
|
||||
event_type="automation.manual_trigger",
|
||||
user_id=current_user,
|
||||
data={
|
||||
"automation_id": automation_id,
|
||||
"trigger_data": trigger_request.event_data,
|
||||
"variables": trigger_request.variables
|
||||
},
|
||||
metadata={
|
||||
"trigger_type": TriggerType.MANUAL.value
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "triggered",
|
||||
"automation_id": automation_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to trigger automation")
|
||||
|
||||
|
||||
@router.get("/{automation_id}/executions")
|
||||
async def get_execution_history(
|
||||
automation_id: str,
|
||||
limit: int = 10,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Get execution history for automation"""
|
||||
try:
|
||||
# Check automation ownership first
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Get execution history
|
||||
executor = AutomationChainExecutor(tenant_domain, event_bus)
|
||||
executions = await executor.get_execution_history(automation_id, limit)
|
||||
|
||||
return {
|
||||
"automation_id": automation_id,
|
||||
"executions": executions,
|
||||
"total": len(executions)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting execution history: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get execution history")
|
||||
|
||||
|
||||
@router.post("/{automation_id}/test")
|
||||
async def test_automation(
|
||||
automation_id: str,
|
||||
test_data: Dict[str, Any] = {},
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Test automation with sample data"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get automation
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Create test event
|
||||
test_event_type = automation.trigger_config.get("event_types", ["test.event"])[0]
|
||||
|
||||
await event_bus.emit_event(
|
||||
event_type=test_event_type,
|
||||
user_id=current_user,
|
||||
data={
|
||||
"test": True,
|
||||
"automation_id": automation_id,
|
||||
**test_data
|
||||
},
|
||||
metadata={
|
||||
"test_mode": True,
|
||||
"test_timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "test_triggered",
|
||||
"automation_id": automation_id,
|
||||
"test_event": test_event_type,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to test automation")
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_automation_stats(
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Get automation statistics for current user"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get user's automations
|
||||
automations = await event_bus.list_automations(owner_id=current_user)
|
||||
|
||||
# Calculate stats
|
||||
total = len(automations)
|
||||
active = sum(1 for a in automations if a.is_active)
|
||||
by_trigger_type = {}
|
||||
|
||||
for automation in automations:
|
||||
trigger = automation.trigger_type.value
|
||||
by_trigger_type[trigger] = by_trigger_type.get(trigger, 0) + 1
|
||||
|
||||
# Get recent events
|
||||
recent_events = await event_bus.get_event_history(
|
||||
user_id=current_user,
|
||||
limit=10
|
||||
)
|
||||
|
||||
return {
|
||||
"total_automations": total,
|
||||
"active_automations": active,
|
||||
"inactive_automations": total - active,
|
||||
"by_trigger_type": by_trigger_type,
|
||||
"recent_events": [
|
||||
{
|
||||
"type": event.type,
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"data": event.data
|
||||
}
|
||||
for event in recent_events
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation stats: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get automation stats")
|
||||
174
apps/tenant-backend/app/api/v1/categories.py
Normal file
174
apps/tenant-backend/app/api/v1/categories.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Category API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Provides tenant-scoped agent category management with CRUD operations.
|
||||
Supports Issue #215 requirements for editable/deletable categories.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.category_service import CategoryService
|
||||
from app.schemas.category import (
|
||||
CategoryCreate,
|
||||
CategoryUpdate,
|
||||
CategoryResponse,
|
||||
CategoryListResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/categories", tags=["categories"])
|
||||
|
||||
|
||||
async def get_category_service(current_user: Dict[str, Any]) -> CategoryService:
|
||||
"""Helper function to create CategoryService with proper context"""
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
# Get user ID from token or lookup
|
||||
user_id = current_user.get('sub', current_user.get('user_id', user_email))
|
||||
tenant_domain = current_user.get('tenant_domain', 'test-company')
|
||||
|
||||
return CategoryService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=str(user_id),
|
||||
user_email=user_email
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=CategoryListResponse)
|
||||
async def list_categories(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all categories for the tenant.
|
||||
|
||||
Returns all active categories with permission flags indicating
|
||||
whether the current user can edit/delete each category.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
categories = await service.get_all_categories()
|
||||
|
||||
return CategoryListResponse(
|
||||
categories=[CategoryResponse(**cat) for cat in categories],
|
||||
total=len(categories)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing categories: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{category_id}", response_model=CategoryResponse)
|
||||
async def get_category(
|
||||
category_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get a specific category by ID.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
category = await service.get_category_by_id(category_id)
|
||||
|
||||
if not category:
|
||||
raise HTTPException(status_code=404, detail="Category not found")
|
||||
|
||||
return CategoryResponse(**category)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting category {category_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("", response_model=CategoryResponse, status_code=201)
|
||||
async def create_category(
|
||||
data: CategoryCreate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a new custom category.
|
||||
|
||||
The creating user becomes the owner and can edit/delete the category.
|
||||
All users in the tenant can use the category for their agents.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
category = await service.create_category(
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
icon=data.icon
|
||||
)
|
||||
|
||||
return CategoryResponse(**category)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating category: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{category_id}", response_model=CategoryResponse)
|
||||
async def update_category(
|
||||
category_id: str,
|
||||
data: CategoryUpdate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update a category.
|
||||
|
||||
Requires permission: admin/developer role OR be the category creator.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
category = await service.update_category(
|
||||
category_id=category_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
icon=data.icon
|
||||
)
|
||||
|
||||
return CategoryResponse(**category)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating category {category_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{category_id}")
|
||||
async def delete_category(
|
||||
category_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete a category (soft delete).
|
||||
|
||||
Requires permission: admin/developer role OR be the category creator.
|
||||
Note: Agents using this category will retain their category value,
|
||||
but the category will no longer appear in selection lists.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
await service.delete_category(category_id)
|
||||
|
||||
return {"message": "Category deleted successfully"}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting category {category_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
1563
apps/tenant-backend/app/api/v1/chat.py
Normal file
1563
apps/tenant-backend/app/api/v1/chat.py
Normal file
File diff suppressed because it is too large
Load Diff
631
apps/tenant-backend/app/api/v1/conversations.py
Normal file
631
apps/tenant-backend/app/api/v1/conversations.py
Normal file
@@ -0,0 +1,631 @@
|
||||
"""
|
||||
Conversation API endpoints for GT 2.0 Tenant Backend - PostgreSQL Migration
|
||||
|
||||
Basic conversation endpoints during PostgreSQL migration.
|
||||
Full functionality will be restored as migration completes.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.websocket.manager import websocket_manager
|
||||
|
||||
# TEMPORARY: Basic response schemas during migration
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from fastapi import File, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
class ConversationResponse(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
agent_id: Optional[str]
|
||||
agent_name: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
last_message_at: Optional[datetime] = None
|
||||
message_count: int = 0
|
||||
token_count: int = 0
|
||||
is_archived: bool = False
|
||||
unread_count: int = 0
|
||||
|
||||
class ConversationListResponse(BaseModel):
|
||||
conversations: List[ConversationResponse]
|
||||
total: int
|
||||
|
||||
# Message creation model
|
||||
class MessageCreate(BaseModel):
|
||||
"""Request body for creating a message in a conversation"""
|
||||
role: str = Field(..., description="Message role: user, assistant, agent, or system")
|
||||
content: str = Field(..., description="Message content (supports any length)")
|
||||
model_used: Optional[str] = Field(None, description="Model used to generate the message")
|
||||
token_count: int = Field(0, ge=0, description="Token count for the message")
|
||||
metadata: Optional[Dict] = Field(None, description="Additional message metadata")
|
||||
attachments: Optional[List] = Field(None, description="Message attachments")
|
||||
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/conversations", tags=["conversations"])
|
||||
|
||||
|
||||
@router.get("", response_model=ConversationListResponse)
|
||||
async def list_conversations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
agent_id: Optional[str] = Query(None, description="Filter by agent"),
|
||||
search: Optional[str] = Query(None, description="Search in conversation titles"),
|
||||
time_filter: str = Query("all", description="Filter by time: 'today', 'week', 'month', 'all'"),
|
||||
limit: int = Query(20, ge=1),
|
||||
offset: int = Query(0, ge=0)
|
||||
) -> ConversationListResponse:
|
||||
"""List user's conversations using PostgreSQL with server-side filtering"""
|
||||
try:
|
||||
# Extract tenant domain from user context
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
# Get conversations from PostgreSQL with filters
|
||||
result = await service.list_conversations(
|
||||
user_identifier=current_user["email"],
|
||||
agent_id=agent_id,
|
||||
search=search,
|
||||
time_filter=time_filter,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
conversations = [
|
||||
ConversationResponse(
|
||||
id=conv["id"],
|
||||
title=conv["title"],
|
||||
agent_id=conv["agent_id"],
|
||||
agent_name=conv.get("agent_name"),
|
||||
created_at=conv["created_at"],
|
||||
updated_at=conv.get("updated_at"),
|
||||
last_message_at=conv.get("last_message_at"),
|
||||
message_count=conv.get("message_count", 0),
|
||||
token_count=conv.get("token_count", 0),
|
||||
is_archived=conv.get("is_archived", False),
|
||||
unread_count=conv.get("unread_count", 0)
|
||||
)
|
||||
for conv in result["conversations"]
|
||||
]
|
||||
|
||||
return ConversationListResponse(
|
||||
conversations=conversations,
|
||||
total=result["total"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list conversations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("", response_model=Dict[str, Any])
|
||||
async def create_conversation(
|
||||
agent_id: str,
|
||||
title: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new conversation with an agent"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
result = await service.create_conversation(
|
||||
agent_id=agent_id,
|
||||
title=title,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create conversation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=Dict[str, Any])
|
||||
async def get_conversation(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get a specific conversation with details"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
result = await service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/messages", response_model=List[Dict[str, Any]])
|
||||
async def get_conversation_messages(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0)
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get messages for a conversation"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
messages = await service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"],
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages for conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/messages", response_model=Dict[str, Any])
|
||||
async def add_message(
|
||||
conversation_id: str,
|
||||
message: MessageCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a message to a conversation (supports messages of any length)"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
result = await service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
user_identifier=current_user["email"],
|
||||
model_used=message.model_used,
|
||||
token_count=message.token_count,
|
||||
metadata=message.metadata,
|
||||
attachments=message.attachments
|
||||
)
|
||||
|
||||
# Broadcast message creation via WebSocket for unread tracking and sidebar updates
|
||||
try:
|
||||
# Get updated conversation details (message count, timestamp)
|
||||
conv_details = await service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
# Import the Socket.IO broadcast function
|
||||
from app.websocket.manager import broadcast_conversation_update
|
||||
|
||||
await broadcast_conversation_update(
|
||||
conversation_id=conversation_id,
|
||||
event='conversation:message_added',
|
||||
data={
|
||||
'conversation_id': conversation_id,
|
||||
'message_id': result.get('id'),
|
||||
'sender_id': current_user.get('id'),
|
||||
'role': message.role,
|
||||
'content': message.content[:100], # First 100 chars for preview
|
||||
'message_count': conv_details.get('message_count', conv_details.get('total_messages', 0)),
|
||||
'last_message_at': result.get('created_at'),
|
||||
'title': conv_details.get('title', 'New Conversation')
|
||||
}
|
||||
)
|
||||
except Exception as ws_error:
|
||||
logger.warning(f"Failed to broadcast message via WebSocket: {ws_error}")
|
||||
# Don't fail the request if WebSocket broadcast fails
|
||||
|
||||
# Check if we should generate a title after this message
|
||||
if message.role == "agent" or message.role == "assistant":
|
||||
# This is an AI response - check if it's after the first user message
|
||||
try:
|
||||
# Get all messages in conversation
|
||||
messages = await service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
# Count user and agent messages
|
||||
user_messages = [m for m in messages if m["role"] == "user"]
|
||||
agent_messages = [m for m in messages if m["role"] in ["agent", "assistant"]]
|
||||
|
||||
# Generate title if this is the first exchange (1 user + 1 agent message)
|
||||
if len(user_messages) == 1 and len(agent_messages) == 1:
|
||||
logger.info(f"🎯 First exchange complete, generating title for conversation {conversation_id}")
|
||||
try:
|
||||
await service.auto_generate_conversation_title(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
logger.info(f"✅ Title generated for conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title for conversation {conversation_id}: {e}")
|
||||
# Don't fail the request if title generation fails
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for title generation: {e}")
|
||||
# Don't fail the request if title check fails
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add message to conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{conversation_id}")
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
title: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Update a conversation (e.g., rename title)"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
success = await service.update_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"],
|
||||
title=title
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found or access denied")
|
||||
|
||||
return {"message": "Conversation updated successfully", "conversation_id": conversation_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}")
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, str]:
|
||||
"""Delete a conversation (soft delete)"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
success = await service.delete_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found or access denied")
|
||||
|
||||
return {"message": "Conversation deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/recent", response_model=ConversationListResponse)
|
||||
async def get_recent_conversations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
days_back: int = Query(7, ge=1, le=30, description="Number of days back"),
|
||||
max_conversations: int = Query(10, ge=1, le=25, description="Maximum conversations"),
|
||||
include_archived: bool = Query(False, description="Include deleted conversations")
|
||||
):
|
||||
"""
|
||||
Get recent conversation summaries.
|
||||
|
||||
Used by MCP conversation server for recent activity context.
|
||||
"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
# Get recent conversations
|
||||
result = await service.list_conversations(
|
||||
user_identifier=current_user["email"],
|
||||
limit=max_conversations,
|
||||
offset=0,
|
||||
order_by="updated_at",
|
||||
order_direction="desc"
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
conversations = [
|
||||
ConversationResponse(
|
||||
id=conv["id"],
|
||||
title=conv["title"],
|
||||
agent_id=conv["agent_id"],
|
||||
created_at=conv["created_at"]
|
||||
)
|
||||
for conv in result["conversations"]
|
||||
]
|
||||
|
||||
return ConversationListResponse(
|
||||
conversations=conversations,
|
||||
total=len(conversations)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get recent conversations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Dataset management models
|
||||
class AddDatasetsRequest(BaseModel):
|
||||
"""Request to add datasets to a conversation"""
|
||||
dataset_ids: List[str] = Field(..., min_items=1, description="Dataset IDs to add to conversation")
|
||||
|
||||
class DatasetOperationResponse(BaseModel):
|
||||
"""Response for dataset operations"""
|
||||
success: bool
|
||||
message: str
|
||||
conversation_id: str
|
||||
dataset_count: int
|
||||
|
||||
# Conversation file models
|
||||
class ConversationFileResponse(BaseModel):
|
||||
"""Response for conversation file operations"""
|
||||
id: str
|
||||
filename: str
|
||||
original_filename: str
|
||||
content_type: str
|
||||
file_size_bytes: int
|
||||
processing_status: str
|
||||
uploaded_at: datetime
|
||||
processed_at: Optional[datetime] = None
|
||||
|
||||
class ConversationFileListResponse(BaseModel):
|
||||
"""Response for listing conversation files"""
|
||||
conversation_id: str
|
||||
files: List[ConversationFileResponse]
|
||||
total_files: int
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/datasets", response_model=DatasetOperationResponse)
|
||||
async def add_datasets_to_conversation(
|
||||
conversation_id: str,
|
||||
request: AddDatasetsRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> DatasetOperationResponse:
|
||||
"""Add datasets to a conversation for context awareness"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
# Add datasets to conversation
|
||||
success = await service.add_datasets_to_conversation(
|
||||
conversation_id=conversation_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Failed to add datasets to conversation. Check conversation exists and you have access."
|
||||
)
|
||||
|
||||
logger.info(f"Added {len(request.dataset_ids)} datasets to conversation {conversation_id}")
|
||||
|
||||
return DatasetOperationResponse(
|
||||
success=True,
|
||||
message=f"Successfully added {len(request.dataset_ids)} dataset(s) to conversation",
|
||||
conversation_id=conversation_id,
|
||||
dataset_count=len(request.dataset_ids)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add datasets to conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/datasets")
|
||||
async def get_conversation_datasets(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get datasets associated with a conversation"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
dataset_ids = await service.get_conversation_datasets(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"dataset_ids": dataset_ids,
|
||||
"dataset_count": len(dataset_ids)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get datasets for conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Conversation File Management Endpoints
|
||||
|
||||
@router.post("/{conversation_id}/files", response_model=List[ConversationFileResponse])
|
||||
async def upload_conversation_files(
|
||||
conversation_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> List[ConversationFileResponse]:
|
||||
"""Upload files directly to conversation (replaces dataset-based uploads)"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
# Upload files to conversation
|
||||
uploaded_files = await service.upload_files(
|
||||
conversation_id=conversation_id,
|
||||
files=files,
|
||||
user_id=current_user["email"]
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
file_responses = []
|
||||
for file_data in uploaded_files:
|
||||
file_responses.append(ConversationFileResponse(
|
||||
id=file_data["id"],
|
||||
filename=file_data["filename"],
|
||||
original_filename=file_data["original_filename"],
|
||||
content_type=file_data["content_type"],
|
||||
file_size_bytes=file_data["file_size_bytes"],
|
||||
processing_status=file_data["processing_status"],
|
||||
uploaded_at=file_data["uploaded_at"],
|
||||
processed_at=file_data.get("processed_at")
|
||||
))
|
||||
|
||||
logger.info(f"Uploaded {len(uploaded_files)} files to conversation {conversation_id}")
|
||||
return file_responses
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload files to conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/files", response_model=ConversationFileListResponse)
|
||||
async def list_conversation_files(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ConversationFileListResponse:
|
||||
"""List files attached to conversation"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
files = await service.list_files(conversation_id)
|
||||
|
||||
# Convert to response format
|
||||
file_responses = []
|
||||
for file_data in files:
|
||||
file_responses.append(ConversationFileResponse(
|
||||
id=str(file_data["id"]), # Convert UUID to string
|
||||
filename=file_data["filename"],
|
||||
original_filename=file_data["original_filename"],
|
||||
content_type=file_data["content_type"],
|
||||
file_size_bytes=file_data["file_size_bytes"],
|
||||
processing_status=file_data["processing_status"],
|
||||
uploaded_at=file_data["uploaded_at"],
|
||||
processed_at=file_data.get("processed_at")
|
||||
))
|
||||
|
||||
return ConversationFileListResponse(
|
||||
conversation_id=conversation_id,
|
||||
files=file_responses,
|
||||
total_files=len(file_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list files for conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}/files/{file_id}")
|
||||
async def delete_conversation_file(
|
||||
conversation_id: str,
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, str]:
|
||||
"""Delete specific file from conversation"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
success = await service.delete_file(
|
||||
conversation_id=conversation_id,
|
||||
file_id=file_id,
|
||||
user_id=current_user["email"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="File not found or access denied")
|
||||
|
||||
return {"message": "File deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {file_id} from conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/files/{file_id}/download")
|
||||
async def download_conversation_file(
|
||||
conversation_id: str,
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Download specific conversation file"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
# Get file record for metadata
|
||||
file_record = await service._get_file_record(file_id)
|
||||
if not file_record or file_record['conversation_id'] != conversation_id:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Get file content
|
||||
content = await service.get_file_content(file_id, current_user["email"])
|
||||
if not content:
|
||||
raise HTTPException(status_code=404, detail="File content not found")
|
||||
|
||||
# Return file as streaming response
|
||||
def iter_content():
|
||||
yield content
|
||||
|
||||
return StreamingResponse(
|
||||
iter_content(),
|
||||
media_type=file_record['content_type'],
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{file_record['original_filename']}\"",
|
||||
"Content-Length": str(file_record['file_size_bytes'])
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file {file_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
425
apps/tenant-backend/app/api/v1/dataset_sharing.py
Normal file
425
apps/tenant-backend/app/api/v1/dataset_sharing.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Dataset Sharing API for GT 2.0
|
||||
|
||||
RESTful API for hierarchical dataset sharing with capability-based access control.
|
||||
Enables secure collaboration while maintaining perfect tenant isolation.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.security import get_current_user, verify_capability_token
|
||||
from app.services.dataset_sharing import (
|
||||
DatasetSharingService, DatasetShare, DatasetInfo, SharingPermission
|
||||
)
|
||||
from app.services.access_controller import AccessController
|
||||
from app.models.access_group import AccessGroup
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class ShareDatasetRequest(BaseModel):
|
||||
"""Request to share a dataset"""
|
||||
dataset_id: str = Field(..., description="Dataset ID to share")
|
||||
access_group: str = Field(..., description="Access group: individual, team, organization")
|
||||
team_members: Optional[List[str]] = Field(None, description="Team members for team sharing")
|
||||
team_permissions: Optional[Dict[str, str]] = Field(None, description="Permissions for team members")
|
||||
expires_days: Optional[int] = Field(None, description="Expiration in days")
|
||||
|
||||
|
||||
class UpdatePermissionRequest(BaseModel):
|
||||
"""Request to update team member permissions"""
|
||||
user_id: str = Field(..., description="User ID to update")
|
||||
permission: str = Field(..., description="Permission: read, write, admin")
|
||||
|
||||
|
||||
class DatasetShareResponse(BaseModel):
|
||||
"""Dataset sharing configuration response"""
|
||||
id: str
|
||||
dataset_id: str
|
||||
owner_id: str
|
||||
access_group: str
|
||||
team_members: List[str]
|
||||
team_permissions: Dict[str, str]
|
||||
shared_at: datetime
|
||||
expires_at: Optional[datetime]
|
||||
is_active: bool
|
||||
|
||||
|
||||
class DatasetInfoResponse(BaseModel):
|
||||
"""Dataset information response"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
owner_id: str
|
||||
document_count: int
|
||||
size_bytes: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
tags: List[str]
|
||||
|
||||
|
||||
class SharingStatsResponse(BaseModel):
|
||||
"""Sharing statistics response"""
|
||||
owned_datasets: int
|
||||
shared_with_me: int
|
||||
sharing_breakdown: Dict[str, int]
|
||||
total_team_members: int
|
||||
expired_shares: int
|
||||
|
||||
|
||||
# Dependency injection
|
||||
async def get_dataset_sharing_service(
|
||||
authorization: str = Header(...),
|
||||
current_user: str = Depends(get_current_user)
|
||||
) -> DatasetSharingService:
|
||||
"""Get dataset sharing service with access controller"""
|
||||
# Extract tenant from token (mock implementation)
|
||||
tenant_domain = "customer1.com" # Would extract from JWT
|
||||
|
||||
access_controller = AccessController(tenant_domain)
|
||||
return DatasetSharingService(tenant_domain, access_controller)
|
||||
|
||||
|
||||
@router.post("/share", response_model=DatasetShareResponse)
|
||||
async def share_dataset(
|
||||
request: ShareDatasetRequest,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Share a dataset with specified access group.
|
||||
|
||||
- **dataset_id**: Dataset to share
|
||||
- **access_group**: individual, team, or organization
|
||||
- **team_members**: Required for team sharing
|
||||
- **team_permissions**: Optional custom permissions
|
||||
- **expires_days**: Optional expiration period
|
||||
"""
|
||||
try:
|
||||
# Convert access group string to enum
|
||||
access_group = AccessGroup(request.access_group.lower())
|
||||
|
||||
# Convert permission strings to enums
|
||||
team_permissions = {}
|
||||
if request.team_permissions:
|
||||
for user, perm in request.team_permissions.items():
|
||||
team_permissions[user] = SharingPermission(perm.lower())
|
||||
|
||||
# Calculate expiration
|
||||
expires_at = None
|
||||
if request.expires_days:
|
||||
expires_at = datetime.utcnow() + timedelta(days=request.expires_days)
|
||||
|
||||
# Share dataset
|
||||
share = await sharing_service.share_dataset(
|
||||
dataset_id=request.dataset_id,
|
||||
owner_id=current_user,
|
||||
access_group=access_group,
|
||||
team_members=request.team_members,
|
||||
team_permissions=team_permissions,
|
||||
expires_at=expires_at,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
return DatasetShareResponse(
|
||||
id=share.id,
|
||||
dataset_id=share.dataset_id,
|
||||
owner_id=share.owner_id,
|
||||
access_group=share.access_group.value,
|
||||
team_members=share.team_members,
|
||||
team_permissions={k: v.value for k, v in share.team_permissions.items()},
|
||||
shared_at=share.shared_at,
|
||||
expires_at=share.expires_at,
|
||||
is_active=share.is_active
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to share dataset: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{dataset_id}", response_model=DatasetShareResponse)
|
||||
async def get_dataset_sharing(
|
||||
dataset_id: str,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get sharing configuration for a dataset.
|
||||
|
||||
Returns sharing details if user has access to view them.
|
||||
"""
|
||||
try:
|
||||
share = await sharing_service.get_dataset_sharing(
|
||||
dataset_id=dataset_id,
|
||||
user_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
if not share:
|
||||
raise HTTPException(status_code=404, detail="Dataset sharing not found or access denied")
|
||||
|
||||
return DatasetShareResponse(
|
||||
id=share.id,
|
||||
dataset_id=share.dataset_id,
|
||||
owner_id=share.owner_id,
|
||||
access_group=share.access_group.value,
|
||||
team_members=share.team_members,
|
||||
team_permissions={k: v.value for k, v in share.team_permissions.items()},
|
||||
shared_at=share.shared_at,
|
||||
expires_at=share.expires_at,
|
||||
is_active=share.is_active
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get dataset sharing: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{dataset_id}/access-check")
|
||||
async def check_dataset_access(
|
||||
dataset_id: str,
|
||||
permission: str = Query("read", description="Permission to check: read, write, admin"),
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Check if user has specified permission on dataset.
|
||||
|
||||
- **permission**: read, write, or admin
|
||||
"""
|
||||
try:
|
||||
# Convert permission string to enum
|
||||
required_permission = SharingPermission(permission.lower())
|
||||
|
||||
allowed, reason = await sharing_service.check_dataset_access(
|
||||
dataset_id=dataset_id,
|
||||
user_id=current_user,
|
||||
permission=required_permission
|
||||
)
|
||||
|
||||
return {
|
||||
"allowed": allowed,
|
||||
"reason": reason,
|
||||
"permission": permission,
|
||||
"user_id": current_user,
|
||||
"dataset_id": dataset_id
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid permission: {str(e)}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to check access: {str(e)}")
|
||||
|
||||
|
||||
@router.get("", response_model=List[DatasetInfoResponse])
|
||||
async def list_accessible_datasets(
|
||||
include_owned: bool = Query(True, description="Include user's own datasets"),
|
||||
include_shared: bool = Query(True, description="Include datasets shared with user"),
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List datasets accessible to user.
|
||||
|
||||
- **include_owned**: Include user's own datasets
|
||||
- **include_shared**: Include datasets shared with user
|
||||
"""
|
||||
try:
|
||||
datasets = await sharing_service.list_accessible_datasets(
|
||||
user_id=current_user,
|
||||
capability_token=authorization,
|
||||
include_owned=include_owned,
|
||||
include_shared=include_shared
|
||||
)
|
||||
|
||||
return [
|
||||
DatasetInfoResponse(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
owner_id=dataset.owner_id,
|
||||
document_count=dataset.document_count,
|
||||
size_bytes=dataset.size_bytes,
|
||||
created_at=dataset.created_at,
|
||||
updated_at=dataset.updated_at,
|
||||
tags=dataset.tags
|
||||
)
|
||||
for dataset in datasets
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list datasets: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{dataset_id}/revoke")
|
||||
async def revoke_dataset_sharing(
|
||||
dataset_id: str,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Revoke dataset sharing (make it private).
|
||||
|
||||
Only the dataset owner can revoke sharing.
|
||||
"""
|
||||
try:
|
||||
success = await sharing_service.revoke_dataset_sharing(
|
||||
dataset_id=dataset_id,
|
||||
owner_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to revoke sharing")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Dataset {dataset_id} sharing revoked",
|
||||
"dataset_id": dataset_id
|
||||
}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to revoke sharing: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/{dataset_id}/permissions")
|
||||
async def update_team_permissions(
|
||||
dataset_id: str,
|
||||
request: UpdatePermissionRequest,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update team member permissions for a dataset.
|
||||
|
||||
Only the dataset owner can update permissions.
|
||||
"""
|
||||
try:
|
||||
# Convert permission string to enum
|
||||
permission = SharingPermission(request.permission.lower())
|
||||
|
||||
success = await sharing_service.update_team_permissions(
|
||||
dataset_id=dataset_id,
|
||||
owner_id=current_user,
|
||||
user_id=request.user_id,
|
||||
permission=permission,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to update permissions")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Updated {request.user_id} permission to {request.permission}",
|
||||
"dataset_id": dataset_id,
|
||||
"user_id": request.user_id,
|
||||
"permission": request.permission
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid permission: {str(e)}")
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update permissions: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stats/sharing", response_model=SharingStatsResponse)
|
||||
async def get_sharing_statistics(
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get sharing statistics for current user.
|
||||
|
||||
Returns counts of owned, shared datasets and sharing breakdown.
|
||||
"""
|
||||
try:
|
||||
stats = await sharing_service.get_sharing_statistics(
|
||||
user_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
# Convert AccessGroup enum keys to strings
|
||||
sharing_breakdown = {
|
||||
group.value: count for group, count in stats["sharing_breakdown"].items()
|
||||
}
|
||||
|
||||
return SharingStatsResponse(
|
||||
owned_datasets=stats["owned_datasets"],
|
||||
shared_with_me=stats["shared_with_me"],
|
||||
sharing_breakdown=sharing_breakdown,
|
||||
total_team_members=stats["total_team_members"],
|
||||
expired_shares=stats["expired_shares"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
|
||||
|
||||
|
||||
# Action and permission catalogs for UI builders
|
||||
@router.get("/catalog/permissions")
|
||||
async def get_permission_catalog():
|
||||
"""Get available sharing permissions for UI builders"""
|
||||
return {
|
||||
"permissions": [
|
||||
{
|
||||
"value": "read",
|
||||
"label": "Read Only",
|
||||
"description": "Can view and search dataset"
|
||||
},
|
||||
{
|
||||
"value": "write",
|
||||
"label": "Read & Write",
|
||||
"description": "Can view, search, and add documents"
|
||||
},
|
||||
{
|
||||
"value": "admin",
|
||||
"label": "Administrator",
|
||||
"description": "Can modify sharing settings"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/catalog/access-groups")
|
||||
async def get_access_group_catalog():
|
||||
"""Get available access groups for UI builders"""
|
||||
return {
|
||||
"access_groups": [
|
||||
{
|
||||
"value": "individual",
|
||||
"label": "Private",
|
||||
"description": "Only accessible to owner"
|
||||
},
|
||||
{
|
||||
"value": "team",
|
||||
"label": "Team",
|
||||
"description": "Shared with specific team members"
|
||||
},
|
||||
{
|
||||
"value": "organization",
|
||||
"label": "Organization",
|
||||
"description": "Read-only access for all tenant users"
|
||||
}
|
||||
]
|
||||
}
|
||||
1024
apps/tenant-backend/app/api/v1/datasets.py
Normal file
1024
apps/tenant-backend/app/api/v1/datasets.py
Normal file
File diff suppressed because it is too large
Load Diff
256
apps/tenant-backend/app/api/v1/documents.py
Normal file
256
apps/tenant-backend/app/api/v1/documents.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
GT 2.0 Documents API - Wrapper for Files API
|
||||
|
||||
Provides document-centric interface that wraps the underlying files API.
|
||||
This maintains the document abstraction for the frontend while leveraging
|
||||
the existing file storage infrastructure.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, UploadFile, File, Form
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.api.v1.files import (
|
||||
get_file_info,
|
||||
download_file,
|
||||
delete_file,
|
||||
list_files,
|
||||
get_document_summary as get_file_summary,
|
||||
upload_file
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
|
||||
|
||||
@router.post("", status_code=201)
|
||||
@router.post("/", status_code=201) # Support both with and without trailing slash
|
||||
async def upload_document(
|
||||
file: UploadFile = File(...),
|
||||
dataset_id: Optional[str] = Form(None, description="Associate with dataset"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Upload document (proxy to files API) - accepts dataset_id from FormData"""
|
||||
try:
|
||||
logger.info(f"Document upload requested - file: {file.filename}, dataset_id: {dataset_id}")
|
||||
# Proxy to files upload endpoint with "documents" category
|
||||
return await upload_file(file, dataset_id, "documents", current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Document upload failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{document_id}")
|
||||
async def get_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get document details (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files API - documents are stored as files
|
||||
return await get_file_info(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{document_id}/summary")
|
||||
async def get_document_summary(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get AI-generated summary for a document (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files summary endpoint
|
||||
# codeql[py/stack-trace-exposure] proxies to files API, returns summary dict
|
||||
return await get_file_summary(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Document summary generation failed for {document_id}: {e}", exc_info=True)
|
||||
# Return a fallback response
|
||||
return {
|
||||
"summary": "Summary generation is currently unavailable. Please try again later.",
|
||||
"key_topics": [],
|
||||
"document_type": "unknown",
|
||||
"language": "en",
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_documents(
|
||||
dataset_id: Optional[str] = Query(None, description="Filter by dataset"),
|
||||
status: Optional[str] = Query(None, description="Filter by processing status"),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""List documents with optional filtering (proxy to files API)"""
|
||||
try:
|
||||
# Map documents request to files API
|
||||
# Documents are files in the "documents" category
|
||||
result = await list_files(
|
||||
dataset_id=dataset_id,
|
||||
category="documents",
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
# Extract just the files array from the response object
|
||||
# The list_files endpoint returns {files: [...], total: N, limit: N, offset: N}
|
||||
# But frontend expects just the array
|
||||
if isinstance(result, dict) and 'files' in result:
|
||||
return result['files']
|
||||
elif isinstance(result, list):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"Unexpected response format from list_files: {type(result)}")
|
||||
return []
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list documents: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{document_id}")
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Delete document and its metadata (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files delete endpoint
|
||||
return await delete_file(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{document_id}/download")
|
||||
async def download_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Download document file (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files download endpoint
|
||||
return await download_file(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/{document_id}/process")
|
||||
async def process_document(
|
||||
document_id: str,
|
||||
chunking_strategy: Optional[str] = Query("hybrid", description="Chunking strategy"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Trigger document processing (chunking and embedding generation)"""
|
||||
try:
|
||||
from app.services.document_processor import get_document_processor
|
||||
from app.core.user_resolver import resolve_user_uuid
|
||||
|
||||
logger.info(f"Manual processing requested for document {document_id}")
|
||||
|
||||
# Get user info
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get document processor
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Get document info to verify it exists and get metadata
|
||||
from app.services.postgresql_file_service import PostgreSQLFileService
|
||||
file_service = PostgreSQLFileService(tenant_domain=tenant_domain, user_id=user_uuid)
|
||||
|
||||
try:
|
||||
doc_info = await file_service.get_file_info(document_id)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Trigger processing using the file service's processing method
|
||||
await file_service._process_document_from_database(
|
||||
processor=processor,
|
||||
document_id=document_id,
|
||||
dataset_id=doc_info.get("dataset_id"),
|
||||
user_uuid=user_uuid,
|
||||
filename=doc_info["original_filename"]
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Document processing started",
|
||||
"document_id": document_id,
|
||||
"chunking_strategy": chunking_strategy
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/processing-status")
|
||||
async def get_processing_status(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get processing status for multiple documents"""
|
||||
try:
|
||||
from app.services.document_processor import get_document_processor
|
||||
from app.core.user_resolver import resolve_user_uuid
|
||||
|
||||
# Get user info
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get document IDs from request
|
||||
document_ids = request.get("document_ids", [])
|
||||
if not document_ids:
|
||||
raise HTTPException(status_code=400, detail="document_ids required")
|
||||
|
||||
# Get processor instance
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Get status for each document
|
||||
status_results = {}
|
||||
for doc_id in document_ids:
|
||||
try:
|
||||
status_info = await processor.get_processing_status(doc_id)
|
||||
status_results[doc_id] = {
|
||||
"status": status_info["status"],
|
||||
"error_message": status_info["error_message"],
|
||||
"progress": status_info["processing_progress"],
|
||||
"stage": status_info["processing_stage"],
|
||||
"chunks_processed": status_info["chunks_processed"],
|
||||
"total_chunks_expected": status_info["total_chunks_expected"]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status for doc {doc_id}: {e}", exc_info=True)
|
||||
status_results[doc_id] = {
|
||||
"status": "error",
|
||||
"error_message": "Failed to get processing status",
|
||||
"progress": 0,
|
||||
"stage": "unknown"
|
||||
}
|
||||
|
||||
return status_results
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get processing status: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
513
apps/tenant-backend/app/api/v1/external_services.py
Normal file
513
apps/tenant-backend/app/api/v1/external_services.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend - External Services API
|
||||
Manage external web service instances with Resource Cluster integration
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.api.auth import get_current_user
|
||||
from app.core.database import get_db_session
|
||||
from app.services.external_service import ExternalServiceManager
|
||||
from app.core.capability_client import CapabilityClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["external_services"])
|
||||
|
||||
class CreateServiceRequest(BaseModel):
|
||||
"""Request to create external service"""
|
||||
service_type: str = Field(..., description="Service type: ctfd, canvas, guacamole")
|
||||
service_name: str = Field(..., description="Human-readable service name")
|
||||
description: Optional[str] = Field(None, description="Service description")
|
||||
config_overrides: Optional[Dict[str, Any]] = Field(None, description="Custom configuration")
|
||||
template_id: Optional[str] = Field(None, description="Template to use as base")
|
||||
|
||||
class ShareServiceRequest(BaseModel):
|
||||
"""Request to share service with other users"""
|
||||
share_with_emails: List[str] = Field(..., description="Email addresses to share with")
|
||||
access_level: str = Field(default="read", description="Access level: read, write")
|
||||
|
||||
class ServiceResponse(BaseModel):
|
||||
"""Service instance response"""
|
||||
id: str
|
||||
service_type: str
|
||||
service_name: str
|
||||
description: Optional[str]
|
||||
endpoint_url: str
|
||||
status: str
|
||||
health_status: str
|
||||
created_by: str
|
||||
allowed_users: List[str]
|
||||
access_level: str
|
||||
created_at: str
|
||||
last_accessed: Optional[str]
|
||||
|
||||
class ServiceListResponse(BaseModel):
|
||||
"""List of services response"""
|
||||
services: List[ServiceResponse]
|
||||
total: int
|
||||
|
||||
class EmbedConfigResponse(BaseModel):
|
||||
"""Iframe embed configuration response"""
|
||||
iframe_url: str
|
||||
sandbox_attributes: List[str]
|
||||
security_policies: Dict[str, Any]
|
||||
sso_token: str
|
||||
expires_at: str
|
||||
|
||||
class ServiceAnalyticsResponse(BaseModel):
|
||||
"""Service analytics response"""
|
||||
instance_id: str
|
||||
service_type: str
|
||||
service_name: str
|
||||
analytics_period_days: int
|
||||
total_sessions: int
|
||||
total_time_hours: float
|
||||
unique_users: int
|
||||
average_session_duration_minutes: float
|
||||
daily_usage: Dict[str, Any]
|
||||
uptime_percentage: float
|
||||
|
||||
@router.post("/create", response_model=ServiceResponse)
|
||||
async def create_external_service(
|
||||
request: CreateServiceRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceResponse:
|
||||
"""Create a new external service instance"""
|
||||
try:
|
||||
# Initialize service manager
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=24
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
# Create service instance
|
||||
instance = await service_manager.create_service_instance(
|
||||
service_type=request.service_type,
|
||||
service_name=request.service_name,
|
||||
user_email=current_user['email'],
|
||||
config_overrides=request.config_overrides,
|
||||
template_id=request.template_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created {request.service_type} service '{request.service_name}' "
|
||||
f"for user {current_user['email']}"
|
||||
)
|
||||
|
||||
return ServiceResponse(
|
||||
id=instance.id,
|
||||
service_type=instance.service_type,
|
||||
service_name=instance.service_name,
|
||||
description=instance.description,
|
||||
endpoint_url=instance.endpoint_url,
|
||||
status=instance.status,
|
||||
health_status=instance.health_status,
|
||||
created_by=instance.created_by,
|
||||
allowed_users=instance.allowed_users,
|
||||
access_level=instance.access_level,
|
||||
created_at=instance.created_at.isoformat(),
|
||||
last_accessed=instance.last_accessed.isoformat() if instance.last_accessed else None
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Invalid request: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid request parameters")
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Resource cluster error: {e}")
|
||||
raise HTTPException(status_code=502, detail="Resource cluster unavailable")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create external service: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/list", response_model=ServiceListResponse)
|
||||
async def list_external_services(
|
||||
service_type: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceListResponse:
|
||||
"""List external services accessible to the user"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
instances = await service_manager.list_user_services(
|
||||
user_email=current_user['email'],
|
||||
service_type=service_type,
|
||||
status=status
|
||||
)
|
||||
|
||||
services = [
|
||||
ServiceResponse(
|
||||
id=instance.id,
|
||||
service_type=instance.service_type,
|
||||
service_name=instance.service_name,
|
||||
description=instance.description,
|
||||
endpoint_url=instance.endpoint_url,
|
||||
status=instance.status,
|
||||
health_status=instance.health_status,
|
||||
created_by=instance.created_by,
|
||||
allowed_users=instance.allowed_users,
|
||||
access_level=instance.access_level,
|
||||
created_at=instance.created_at.isoformat(),
|
||||
last_accessed=instance.last_accessed.isoformat() if instance.last_accessed else None
|
||||
)
|
||||
for instance in instances
|
||||
]
|
||||
|
||||
return ServiceListResponse(
|
||||
services=services,
|
||||
total=len(services)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list external services: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{instance_id}", response_model=ServiceResponse)
|
||||
async def get_external_service(
|
||||
instance_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceResponse:
|
||||
"""Get specific external service details"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
instance = await service_manager.get_service_instance(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
if not instance:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Service instance not found or access denied"
|
||||
)
|
||||
|
||||
return ServiceResponse(
|
||||
id=instance.id,
|
||||
service_type=instance.service_type,
|
||||
service_name=instance.service_name,
|
||||
description=instance.description,
|
||||
endpoint_url=instance.endpoint_url,
|
||||
status=instance.status,
|
||||
health_status=instance.health_status,
|
||||
created_by=instance.created_by,
|
||||
allowed_users=instance.allowed_users,
|
||||
access_level=instance.access_level,
|
||||
created_at=instance.created_at.isoformat(),
|
||||
last_accessed=instance.last_accessed.isoformat() if instance.last_accessed else None
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get external service {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.delete("/{instance_id}")
|
||||
async def stop_external_service(
|
||||
instance_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""Stop external service instance"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=1
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
success = await service_manager.stop_service_instance(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to stop service instance"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Service instance {instance_id} stopped successfully",
|
||||
"stopped_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop external service {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{instance_id}/health")
|
||||
async def get_service_health(
|
||||
instance_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get service health status"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=1
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
health = await service_manager.get_service_health(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
# codeql[py/stack-trace-exposure] returns health status dict, not error details
|
||||
return health
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get service health for {instance_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("/{instance_id}/embed-config", response_model=EmbedConfigResponse)
|
||||
async def get_embed_config(
|
||||
instance_id: str,
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> EmbedConfigResponse:
|
||||
"""Get iframe embed configuration with SSO token"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=24
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
# Generate SSO token and get embed config
|
||||
sso_data = await service_manager.generate_sso_token(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
# Log access event
|
||||
await service_manager.log_service_access(
|
||||
service_instance_id=instance_id,
|
||||
service_type="unknown", # Will be filled by service lookup
|
||||
user_email=current_user['email'],
|
||||
access_type="embed_access",
|
||||
session_id=f"embed_{datetime.utcnow().timestamp()}",
|
||||
ip_address=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
referer=request.headers.get("referer")
|
||||
)
|
||||
|
||||
return EmbedConfigResponse(
|
||||
iframe_url=sso_data['iframe_config']['src'],
|
||||
sandbox_attributes=sso_data['iframe_config']['sandbox'],
|
||||
security_policies={
|
||||
'allow': sso_data['iframe_config']['allow'],
|
||||
'referrerpolicy': sso_data['iframe_config']['referrerpolicy'],
|
||||
'loading': sso_data['iframe_config']['loading']
|
||||
},
|
||||
sso_token=sso_data['token'],
|
||||
expires_at=sso_data['expires_at']
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Resource cluster error: {e}")
|
||||
raise HTTPException(status_code=502, detail="Resource cluster unavailable")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embed config for {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{instance_id}/analytics", response_model=ServiceAnalyticsResponse)
|
||||
async def get_service_analytics(
|
||||
instance_id: str,
|
||||
days: int = 30,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceAnalyticsResponse:
|
||||
"""Get service usage analytics"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
analytics = await service_manager.get_service_analytics(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email'],
|
||||
days=days
|
||||
)
|
||||
|
||||
return ServiceAnalyticsResponse(**analytics)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get analytics for {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("/{instance_id}/share")
|
||||
async def share_service(
|
||||
instance_id: str,
|
||||
request: ShareServiceRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""Share service instance with other users"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
success = await service_manager.share_service_instance(
|
||||
instance_id=instance_id,
|
||||
owner_email=current_user['email'],
|
||||
share_with_emails=request.share_with_emails,
|
||||
access_level=request.access_level
|
||||
)
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"shared_with": request.share_with_emails,
|
||||
"access_level": request.access_level,
|
||||
"shared_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to share service {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/templates/list")
|
||||
async def list_service_templates(
|
||||
service_type: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""List available service templates"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
templates = await service_manager.list_service_templates(
|
||||
service_type=service_type,
|
||||
category=category,
|
||||
public_only=True
|
||||
)
|
||||
|
||||
return {
|
||||
"templates": [template.to_dict() for template in templates],
|
||||
"total": len(templates)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list service templates: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/types/supported")
|
||||
async def get_supported_service_types() -> Dict[str, Any]:
|
||||
"""Get supported external service types and their capabilities"""
|
||||
return {
|
||||
"supported_types": [
|
||||
{
|
||||
"type": "ctfd",
|
||||
"name": "CTFd Platform",
|
||||
"description": "Cybersecurity capture-the-flag challenges and competitions",
|
||||
"category": "cybersecurity",
|
||||
"features": [
|
||||
"Challenge creation and management",
|
||||
"Team-based competitions",
|
||||
"Scoring and leaderboards",
|
||||
"User registration and management",
|
||||
"Real-time notifications"
|
||||
],
|
||||
"resource_requirements": {
|
||||
"cpu": "1000m",
|
||||
"memory": "2Gi",
|
||||
"storage": "7Gi"
|
||||
},
|
||||
"estimated_startup_time": "2-3 minutes",
|
||||
"sso_supported": True
|
||||
},
|
||||
{
|
||||
"type": "canvas",
|
||||
"name": "Canvas LMS",
|
||||
"description": "Learning management system for educational courses",
|
||||
"category": "education",
|
||||
"features": [
|
||||
"Course creation and management",
|
||||
"Assignment and grading system",
|
||||
"Discussion forums and messaging",
|
||||
"Grade book and analytics",
|
||||
"External tool integrations"
|
||||
],
|
||||
"resource_requirements": {
|
||||
"cpu": "2000m",
|
||||
"memory": "4Gi",
|
||||
"storage": "30Gi"
|
||||
},
|
||||
"estimated_startup_time": "3-5 minutes",
|
||||
"sso_supported": True
|
||||
},
|
||||
{
|
||||
"type": "guacamole",
|
||||
"name": "Apache Guacamole",
|
||||
"description": "Remote desktop access for cyber lab environments",
|
||||
"category": "remote_access",
|
||||
"features": [
|
||||
"RDP, VNC, and SSH connections",
|
||||
"Session recording and playback",
|
||||
"Multi-user concurrent access",
|
||||
"Connection sharing and collaboration",
|
||||
"File transfer capabilities"
|
||||
],
|
||||
"resource_requirements": {
|
||||
"cpu": "500m",
|
||||
"memory": "1Gi",
|
||||
"storage": "11Gi"
|
||||
},
|
||||
"estimated_startup_time": "2-4 minutes",
|
||||
"sso_supported": True
|
||||
}
|
||||
],
|
||||
"total_types": 3,
|
||||
"categories": ["cybersecurity", "education", "remote_access"],
|
||||
"extensible": True
|
||||
}
|
||||
342
apps/tenant-backend/app/api/v1/files.py
Normal file
342
apps/tenant-backend/app/api/v1/files.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
GT 2.0 Files API - PostgreSQL File Storage
|
||||
|
||||
Provides file upload, download, and management using PostgreSQL unified storage.
|
||||
Replaces MinIO integration with PostgreSQL 3-tier storage strategy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Query, Form
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.user_resolver import resolve_user_uuid
|
||||
from app.core.response_filter import ResponseFilter
|
||||
from app.core.permissions import get_user_role, is_effective_owner
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.services.postgresql_file_service import PostgreSQLFileService
|
||||
from app.services.document_summarizer import DocumentSummarizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/files", tags=["files"])
|
||||
|
||||
|
||||
@router.post("/upload", status_code=201)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
dataset_id: Optional[str] = Form(None, description="Associate with dataset"),
|
||||
category: str = Form("documents", description="File category"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Upload file using PostgreSQL storage"""
|
||||
try:
|
||||
logger.info(f"File upload started: {file.filename}, size: {file.size if hasattr(file, 'size') else 'unknown'}")
|
||||
logger.info(f"Current user: {current_user}")
|
||||
logger.info(f"Dataset ID: {dataset_id}, Category: {category}")
|
||||
|
||||
if not file.filename:
|
||||
logger.error("No filename provided in upload request")
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain = current_user.get('tenant_domain', 'test-company')
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
logger.info(f"Creating file service for tenant: {tenant_domain}, user: {user_email} (UUID: {user_uuid})")
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
# Store file
|
||||
logger.info(f"Storing file: {file.filename}")
|
||||
result = await file_service.store_file(
|
||||
file=file,
|
||||
dataset_id=dataset_id,
|
||||
category=category
|
||||
)
|
||||
|
||||
logger.info(f"File uploaded successfully: {file.filename} -> {result['id']}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"File upload failed for {file.filename if file and file.filename else 'unknown'}: {e}", exc_info=True)
|
||||
logger.error(f"Exception type: {type(e).__name__}")
|
||||
logger.error(f"Current user context: {current_user}")
|
||||
raise HTTPException(status_code=500, detail="Failed to upload file")
|
||||
|
||||
|
||||
@router.get("/{file_id}")
|
||||
async def download_file(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Download file by ID with streaming support"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
# Get file info first
|
||||
file_info = await file_service.get_file_info(file_id)
|
||||
|
||||
# Stream file content
|
||||
file_stream = file_service.get_file(file_id)
|
||||
|
||||
return StreamingResponse(
|
||||
file_stream,
|
||||
media_type=file_info['content_type'],
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{file_info['original_filename']}\"",
|
||||
"Content-Length": str(file_info['file_size'])
|
||||
}
|
||||
)
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
except Exception as e:
|
||||
logger.error(f"File download failed for {file_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{file_id}/info")
|
||||
async def get_file_info(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get file metadata"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
file_info = await file_service.get_file_info(file_id)
|
||||
|
||||
# Apply security filtering using effective ownership
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.permissions import get_user_role, is_effective_owner
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
is_owner = is_effective_owner(file_info.get("user_id"), user_uuid, user_role)
|
||||
|
||||
filtered_info = ResponseFilter.filter_file_response(
|
||||
file_info,
|
||||
is_owner=is_owner
|
||||
)
|
||||
|
||||
return filtered_info
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Get file info failed for {file_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_files(
|
||||
dataset_id: Optional[str] = Query(None, description="Filter by dataset"),
|
||||
category: str = Query("documents", description="Filter by category"),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""List user files with filtering"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
files = await file_service.list_files(
|
||||
dataset_id=dataset_id,
|
||||
category=category,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
# Apply security filtering to file list using effective ownership
|
||||
filtered_files = []
|
||||
for file_info in files:
|
||||
is_owner = is_effective_owner(file_info.get("user_id"), user_uuid, user_role)
|
||||
filtered_file = ResponseFilter.filter_file_response(
|
||||
file_info,
|
||||
is_owner=is_owner
|
||||
)
|
||||
filtered_files.append(filtered_file)
|
||||
|
||||
return {
|
||||
"files": filtered_files,
|
||||
"total": len(filtered_files),
|
||||
"limit": limit,
|
||||
"offset": offset
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"List files failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{file_id}")
|
||||
async def delete_file(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Delete file and its metadata"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
success = await file_service.delete_file(file_id)
|
||||
|
||||
if success:
|
||||
return {"message": "File deleted successfully"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="File not found or delete failed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Delete file failed for {file_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/cleanup")
|
||||
async def cleanup_orphaned_files(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Clean up orphaned files (admin operation)"""
|
||||
try:
|
||||
# Only allow admin users to run cleanup
|
||||
user_roles = current_user.get('roles', [])
|
||||
if 'admin' not in user_roles:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
cleanup_count = await file_service.cleanup_orphaned_files()
|
||||
|
||||
return {
|
||||
"message": f"Cleaned up {cleanup_count} orphaned files",
|
||||
"count": cleanup_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{file_id}/summary")
|
||||
async def get_document_summary(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get AI-generated summary for a document"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get file service to retrieve document content
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid
|
||||
)
|
||||
|
||||
# Get file info
|
||||
file_info = await file_service.get_file_info(file_id)
|
||||
|
||||
# Initialize summarizer
|
||||
summarizer = DocumentSummarizer()
|
||||
|
||||
# Get file content (for text files)
|
||||
# Note: This assumes text content is available
|
||||
# In production, you'd need to extract text from PDFs, etc.
|
||||
file_stream = file_service.get_file(file_id)
|
||||
content = ""
|
||||
async for chunk in file_stream:
|
||||
content += chunk.decode('utf-8', errors='ignore')
|
||||
|
||||
# Generate summary
|
||||
summary_result = await summarizer.generate_document_summary(
|
||||
document_id=file_id,
|
||||
content=content[:summarizer.max_content_length], # Truncate if too long
|
||||
filename=file_info['original_filename'],
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# codeql[py/stack-trace-exposure] returns document summary dict, not error details
|
||||
return {
|
||||
"summary": summary_result.get("summary", "No summary available"),
|
||||
"key_topics": summary_result.get("key_topics", []),
|
||||
"document_type": summary_result.get("document_type"),
|
||||
"language": summary_result.get("language", "en"),
|
||||
"metadata": summary_result.get("metadata", {})
|
||||
}
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Document summary generation failed for {file_id}: {e}", exc_info=True)
|
||||
# Return a fallback response instead of failing completely
|
||||
return {
|
||||
"summary": "Summary generation is currently unavailable. Please try again later.",
|
||||
"key_topics": [],
|
||||
"document_type": "unknown",
|
||||
"language": "en",
|
||||
"metadata": {}
|
||||
}
|
||||
520
apps/tenant-backend/app/api/v1/games.py
Normal file
520
apps/tenant-backend/app/api/v1/games.py
Normal file
@@ -0,0 +1,520 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.services.game_service import GameService, PuzzleService, PhilosophicalDialogueService
|
||||
from app.models.game import GameSession, PuzzleSession, PhilosophicalDialogue, LearningAnalytics
|
||||
|
||||
router = APIRouter(prefix="/games", tags=["AI Literacy & Games"])
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class GameConfigRequest(BaseModel):
|
||||
game_type: str = Field(..., description="Type of game: chess, go")
|
||||
difficulty: str = Field(default="intermediate", description="Difficulty level")
|
||||
name: Optional[str] = Field(None, description="Custom game name")
|
||||
ai_personality: Optional[str] = Field(default="teaching", description="AI opponent personality")
|
||||
time_control: Optional[str] = Field(None, description="Time control settings")
|
||||
|
||||
|
||||
class GameMoveRequest(BaseModel):
|
||||
move_data: Dict[str, Any] = Field(..., description="Move data specific to game type")
|
||||
request_analysis: Optional[bool] = Field(default=False, description="Request move analysis")
|
||||
|
||||
|
||||
class PuzzleConfigRequest(BaseModel):
|
||||
puzzle_type: str = Field(..., description="Type of puzzle: lateral_thinking, logical_deduction, etc.")
|
||||
difficulty: Optional[int] = Field(None, description="Difficulty level 1-10", ge=1, le=10)
|
||||
category: Optional[str] = Field(None, description="Puzzle category")
|
||||
|
||||
|
||||
class PuzzleSolutionRequest(BaseModel):
|
||||
solution: Dict[str, Any] = Field(..., description="Puzzle solution attempt")
|
||||
reasoning: Optional[str] = Field(None, description="User's reasoning explanation")
|
||||
|
||||
|
||||
class HintRequest(BaseModel):
|
||||
hint_level: int = Field(default=1, description="Hint level 1-3", ge=1, le=3)
|
||||
|
||||
|
||||
class DilemmaConfigRequest(BaseModel):
|
||||
dilemma_type: str = Field(..., description="Type of dilemma: ethical_frameworks, game_theory, ai_consciousness")
|
||||
topic: str = Field(..., description="Specific topic within the dilemma type")
|
||||
complexity: Optional[str] = Field(default="intermediate", description="Complexity level")
|
||||
|
||||
|
||||
class DilemmaResponseRequest(BaseModel):
|
||||
response: str = Field(..., description="User's response to the dilemma")
|
||||
framework: Optional[str] = Field(None, description="Ethical framework being applied")
|
||||
|
||||
|
||||
class DilemmaFinalPositionRequest(BaseModel):
|
||||
final_position: Dict[str, Any] = Field(..., description="User's final position on the dilemma")
|
||||
key_insights: Optional[List[str]] = Field(default=[], description="Key insights gained")
|
||||
|
||||
|
||||
# Strategic Games Endpoints
|
||||
@router.get("/available")
|
||||
async def get_available_games(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get available games and user progress"""
|
||||
service = GameService(db)
|
||||
return await service.get_available_games(current_user["user_id"])
|
||||
|
||||
|
||||
@router.post("/start")
|
||||
async def start_game_session(
|
||||
config: GameConfigRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Start a new game session"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
session = await service.start_game_session(
|
||||
user_id=current_user["user_id"],
|
||||
game_type=config.game_type,
|
||||
config=config.dict()
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session.id,
|
||||
"game_type": session.game_type,
|
||||
"difficulty": session.difficulty_level,
|
||||
"initial_state": session.current_state,
|
||||
"ai_config": session.ai_opponent_config,
|
||||
"started_at": session.started_at.isoformat()
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/session/{session_id}")
|
||||
async def get_game_session(
|
||||
session_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get current game session details"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
analysis = await service.get_game_analysis(session_id, current_user["user_id"])
|
||||
return analysis
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/session/{session_id}/move")
|
||||
async def make_move(
|
||||
session_id: str,
|
||||
move: GameMoveRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Make a move in the game"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
result = await service.make_move(
|
||||
session_id=session_id,
|
||||
user_id=current_user["user_id"],
|
||||
move_data=move.move_data
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/session/{session_id}/analysis")
|
||||
async def get_game_analysis(
|
||||
session_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get detailed game analysis"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
analysis = await service.get_game_analysis(session_id, current_user["user_id"])
|
||||
return analysis
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_game_history(
|
||||
game_type: Optional[str] = Query(None, description="Filter by game type"),
|
||||
limit: int = Query(20, description="Number of games to return", ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get user's game history"""
|
||||
service = GameService(db)
|
||||
return await service.get_user_game_history(
|
||||
user_id=current_user["user_id"],
|
||||
game_type=game_type,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
|
||||
# Logic Puzzles Endpoints
|
||||
@router.get("/puzzles/available")
|
||||
async def get_available_puzzles(
|
||||
category: Optional[str] = Query(None, description="Filter by puzzle category"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get available puzzle categories and difficulty recommendations"""
|
||||
service = PuzzleService(db)
|
||||
return await service.get_available_puzzles(current_user["user_id"], category)
|
||||
|
||||
|
||||
@router.post("/puzzles/start")
|
||||
async def start_puzzle_session(
|
||||
config: PuzzleConfigRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Start a new puzzle session"""
|
||||
service = PuzzleService(db)
|
||||
|
||||
try:
|
||||
session = await service.start_puzzle_session(
|
||||
user_id=current_user["user_id"],
|
||||
puzzle_type=config.puzzle_type,
|
||||
difficulty=config.difficulty
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session.id,
|
||||
"puzzle_type": session.puzzle_type,
|
||||
"difficulty": session.difficulty_rating,
|
||||
"puzzle": session.puzzle_definition,
|
||||
"estimated_time": session.estimated_time_minutes,
|
||||
"started_at": session.started_at.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/puzzles/{session_id}/solve")
|
||||
async def submit_puzzle_solution(
|
||||
session_id: str,
|
||||
solution: PuzzleSolutionRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Submit a solution for the puzzle"""
|
||||
service = PuzzleService(db)
|
||||
|
||||
try:
|
||||
result = await service.submit_solution(
|
||||
session_id=session_id,
|
||||
user_id=current_user["user_id"],
|
||||
solution=solution.solution,
|
||||
reasoning=solution.reasoning
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/puzzles/{session_id}/hint")
|
||||
async def get_puzzle_hint(
|
||||
session_id: str,
|
||||
hint_request: HintRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get a hint for the current puzzle"""
|
||||
service = PuzzleService(db)
|
||||
|
||||
try:
|
||||
hint = await service.get_hint(
|
||||
session_id=session_id,
|
||||
user_id=current_user["user_id"],
|
||||
hint_level=hint_request.hint_level
|
||||
)
|
||||
return hint
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# Philosophical Dilemmas Endpoints
|
||||
@router.get("/dilemmas/available")
|
||||
async def get_available_dilemmas(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get available philosophical dilemmas"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
return await service.get_available_dilemmas(current_user["user_id"])
|
||||
|
||||
|
||||
@router.post("/dilemmas/start")
|
||||
async def start_philosophical_dialogue(
|
||||
config: DilemmaConfigRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Start a new philosophical dialogue session"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
|
||||
try:
|
||||
dialogue = await service.start_dialogue_session(
|
||||
user_id=current_user["user_id"],
|
||||
dilemma_type=config.dilemma_type,
|
||||
topic=config.topic
|
||||
)
|
||||
|
||||
return {
|
||||
"dialogue_id": dialogue.id,
|
||||
"dilemma_title": dialogue.dilemma_title,
|
||||
"scenario": dialogue.scenario_description,
|
||||
"framework_options": dialogue.framework_options,
|
||||
"complexity": dialogue.complexity_level,
|
||||
"estimated_time": dialogue.estimated_discussion_time,
|
||||
"started_at": dialogue.started_at.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/dilemmas/{dialogue_id}/respond")
|
||||
async def submit_dilemma_response(
|
||||
dialogue_id: str,
|
||||
response: DilemmaResponseRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Submit a response to the philosophical dilemma"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
|
||||
try:
|
||||
result = await service.submit_response(
|
||||
dialogue_id=dialogue_id,
|
||||
user_id=current_user["user_id"],
|
||||
response=response.response,
|
||||
framework=response.framework
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/dilemmas/{dialogue_id}/conclude")
|
||||
async def conclude_philosophical_dialogue(
|
||||
dialogue_id: str,
|
||||
final_position: DilemmaFinalPositionRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Conclude the philosophical dialogue with final assessment"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
|
||||
try:
|
||||
result = await service.conclude_dialogue(
|
||||
dialogue_id=dialogue_id,
|
||||
user_id=current_user["user_id"],
|
||||
final_position=final_position.final_position
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/dilemmas/{dialogue_id}")
|
||||
async def get_dialogue_session(
|
||||
dialogue_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get philosophical dialogue session details"""
|
||||
query = """
|
||||
SELECT * FROM philosophical_dialogues
|
||||
WHERE id = :dialogue_id AND user_id = :user_id
|
||||
"""
|
||||
|
||||
result = await db.execute(query, {
|
||||
"dialogue_id": dialogue_id,
|
||||
"user_id": current_user["user_id"]
|
||||
})
|
||||
dialogue = result.fetchone()
|
||||
|
||||
if not dialogue:
|
||||
raise HTTPException(status_code=404, detail="Dialogue session not found")
|
||||
|
||||
return {
|
||||
"dialogue_id": dialogue["id"],
|
||||
"dilemma_title": dialogue["dilemma_title"],
|
||||
"scenario": dialogue["scenario_description"],
|
||||
"dialogue_history": dialogue["dialogue_history"],
|
||||
"frameworks_explored": dialogue["frameworks_explored"],
|
||||
"status": dialogue["dialogue_status"],
|
||||
"exchange_count": dialogue["exchange_count"],
|
||||
"started_at": dialogue["started_at"],
|
||||
"last_exchange_at": dialogue["last_exchange_at"]
|
||||
}
|
||||
|
||||
|
||||
# Learning Analytics Endpoints
|
||||
@router.get("/analytics/progress")
|
||||
async def get_learning_progress(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get comprehensive learning progress and analytics"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
return {
|
||||
"overall_progress": {
|
||||
"total_sessions": analytics.total_sessions,
|
||||
"total_time_minutes": analytics.total_time_minutes,
|
||||
"current_streak": analytics.current_streak_days,
|
||||
"longest_streak": analytics.longest_streak_days
|
||||
},
|
||||
"game_ratings": {
|
||||
"chess": analytics.chess_rating,
|
||||
"go": analytics.go_rating,
|
||||
"puzzle_level": analytics.puzzle_solving_level,
|
||||
"philosophical_depth": analytics.philosophical_depth_level
|
||||
},
|
||||
"cognitive_skills": {
|
||||
"strategic_thinking": analytics.strategic_thinking_score,
|
||||
"logical_reasoning": analytics.logical_reasoning_score,
|
||||
"creative_problem_solving": analytics.creative_problem_solving_score,
|
||||
"ethical_reasoning": analytics.ethical_reasoning_score,
|
||||
"pattern_recognition": analytics.pattern_recognition_score,
|
||||
"metacognitive_awareness": analytics.metacognitive_awareness_score
|
||||
},
|
||||
"thinking_style": {
|
||||
"system1_reliance": analytics.system1_reliance_average,
|
||||
"system2_engagement": analytics.system2_engagement_average,
|
||||
"intuition_accuracy": analytics.intuition_accuracy_score,
|
||||
"reflection_frequency": analytics.reflection_frequency_score
|
||||
},
|
||||
"ai_collaboration": {
|
||||
"dependency_index": analytics.ai_dependency_index,
|
||||
"prompt_engineering": analytics.prompt_engineering_skill,
|
||||
"output_evaluation": analytics.ai_output_evaluation_skill,
|
||||
"collaborative_solving": analytics.collaborative_problem_solving
|
||||
},
|
||||
"achievements": analytics.achievement_badges,
|
||||
"recommendations": analytics.recommended_activities,
|
||||
"last_activity": analytics.last_activity_date.isoformat() if analytics.last_activity_date else None
|
||||
}
|
||||
|
||||
|
||||
@router.get("/analytics/trends")
|
||||
async def get_learning_trends(
|
||||
timeframe: str = Query("30d", description="Timeframe: 7d, 30d, 90d, 1y"),
|
||||
skill_area: Optional[str] = Query(None, description="Specific skill area to analyze"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get learning trends and performance over time"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
# This would typically involve more complex analytics queries
|
||||
# For now, return structured trend data
|
||||
return {
|
||||
"timeframe": timeframe,
|
||||
"skill_progression": analytics.skill_progression_data,
|
||||
"performance_trends": {
|
||||
"chess_rating_history": [{"date": "2024-01-01", "rating": 1200}], # Mock data
|
||||
"puzzle_completion_rate": [{"week": 1, "rate": 0.75}],
|
||||
"session_frequency": [{"week": 1, "sessions": 3}]
|
||||
},
|
||||
"comparative_metrics": {
|
||||
"peer_comparison": "Above average",
|
||||
"improvement_rate": "15% this month",
|
||||
"consistency_score": 0.85
|
||||
},
|
||||
"insights": [
|
||||
"Strong improvement in strategic thinking",
|
||||
"Puzzle-solving speed has increased 20%",
|
||||
"Consider more challenging philosophical dilemmas"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/analytics/recommendations")
|
||||
async def get_learning_recommendations(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get personalized learning recommendations"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
return {
|
||||
"next_activities": [
|
||||
{
|
||||
"type": "chess",
|
||||
"difficulty": "advanced",
|
||||
"reason": "Your tactical skills have improved significantly",
|
||||
"estimated_time": 30
|
||||
},
|
||||
{
|
||||
"type": "logical_deduction",
|
||||
"difficulty": 6,
|
||||
"reason": "Ready for more complex reasoning challenges",
|
||||
"estimated_time": 20
|
||||
},
|
||||
{
|
||||
"type": "ai_consciousness",
|
||||
"topic": "chinese_room",
|
||||
"reason": "Explore deeper philosophical concepts",
|
||||
"estimated_time": 25
|
||||
}
|
||||
],
|
||||
"skill_focus_areas": [
|
||||
"Synthesis of multiple ethical frameworks",
|
||||
"Advanced strategic planning in Go",
|
||||
"Metacognitive awareness development"
|
||||
],
|
||||
"adaptive_settings": {
|
||||
"chess_difficulty": "advanced",
|
||||
"puzzle_difficulty": min(analytics.puzzle_solving_level + 1, 10),
|
||||
"philosophical_complexity": "advanced" if analytics.philosophical_depth_level > 6 else "intermediate"
|
||||
},
|
||||
"learning_goals": analytics.learning_goals or [
|
||||
"Improve system 2 thinking engagement",
|
||||
"Develop better AI collaboration skills",
|
||||
"Master ethical framework application"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/analytics/reflection")
|
||||
async def submit_learning_reflection(
|
||||
reflection_data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Submit learning reflection and self-assessment"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
# Process reflection data and update analytics
|
||||
# This would involve sophisticated analysis of user self-reflection
|
||||
|
||||
return {
|
||||
"reflection_recorded": True,
|
||||
"metacognitive_feedback": "Your self-awareness of thinking patterns is improving",
|
||||
"updated_recommendations": [
|
||||
"Continue exploring areas where intuition conflicts with analysis",
|
||||
"Practice explaining your reasoning process more explicitly"
|
||||
],
|
||||
"insights_gained": reflection_data.get("insights", [])
|
||||
}
|
||||
172
apps/tenant-backend/app/api/v1/models.py
Normal file
172
apps/tenant-backend/app/api/v1/models.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Tenant Models API - Interface to Resource Cluster Model Management
|
||||
|
||||
Provides tenant-scoped access to available AI models from the Resource Cluster.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.config import get_settings
|
||||
from app.core.cache import get_cache
|
||||
from app.services.resource_cluster_client import ResourceClusterClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
cache = get_cache()
|
||||
|
||||
router = APIRouter(prefix="/api/v1/models", tags=["Models"])
|
||||
|
||||
|
||||
@router.get("/", summary="List available models for tenant")
|
||||
async def list_available_models(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get list of AI models available to the current tenant"""
|
||||
|
||||
try:
|
||||
# Get tenant domain from current user
|
||||
tenant_domain = current_user.get("tenant_domain", "default")
|
||||
|
||||
# Check cache first (5-minute TTL)
|
||||
cache_key = f"models_list_{tenant_domain}"
|
||||
cached_models = cache.get(cache_key, ttl=300)
|
||||
if cached_models:
|
||||
logger.debug(f"Returning cached model list for tenant {tenant_domain}")
|
||||
return {**cached_models, "cached": True}
|
||||
|
||||
# Call Resource Cluster models API - use Docker service name if in container
|
||||
import os
|
||||
if os.path.exists('/.dockerenv'):
|
||||
resource_cluster_url = "http://resource-cluster:8000"
|
||||
else:
|
||||
resource_cluster_url = settings.resource_cluster_url
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{resource_cluster_url}/api/v1/models/",
|
||||
headers={
|
||||
"X-Tenant-Domain": tenant_domain
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
models_data = response.json()
|
||||
models = models_data.get("models", [])
|
||||
|
||||
# Filter models by health and deployment status
|
||||
available_models = [
|
||||
{
|
||||
"value": model["id"], # model_id string for backwards compatibility
|
||||
"uuid": model.get("uuid"), # Database UUID for unique identification
|
||||
"label": model["name"],
|
||||
"description": model["description"],
|
||||
"provider": model["provider"],
|
||||
"model_type": model["model_type"],
|
||||
"max_tokens": model["performance"]["max_tokens"],
|
||||
"context_window": model["performance"]["context_window"],
|
||||
"cost_per_1k_tokens": model["performance"]["cost_per_1k_tokens"],
|
||||
"latency_p50_ms": model["performance"]["latency_p50_ms"],
|
||||
"health_status": model["status"]["health"],
|
||||
"deployment_status": model["status"]["deployment"]
|
||||
}
|
||||
for model in models
|
||||
if (model["status"]["deployment"] == "available" and
|
||||
model["status"]["health"] in ["healthy", "unknown"] and
|
||||
model["model_type"] != "embedding")
|
||||
]
|
||||
|
||||
# Sort by provider preference (NVIDIA first, then Groq) and then by performance
|
||||
provider_order = {"nvidia": 0, "groq": 1}
|
||||
available_models.sort(key=lambda x: (
|
||||
provider_order.get(x["provider"], 99), # NVIDIA first, then Groq
|
||||
x["latency_p50_ms"] or 999 # Lower latency first
|
||||
))
|
||||
|
||||
result = {
|
||||
"models": available_models,
|
||||
"total": len(available_models),
|
||||
"tenant_domain": tenant_domain,
|
||||
"last_updated": models_data.get("last_updated"),
|
||||
"cached": False
|
||||
}
|
||||
|
||||
# Cache the result for 5 minutes
|
||||
cache.set(cache_key, result)
|
||||
logger.debug(f"Cached model list for tenant {tenant_domain}")
|
||||
|
||||
return result
|
||||
|
||||
else:
|
||||
# Resource Cluster unavailable - return empty list
|
||||
logger.warning(f"Resource Cluster unavailable (HTTP {response.status_code})")
|
||||
return {
|
||||
"models": [],
|
||||
"total": 0,
|
||||
"tenant_domain": tenant_domain,
|
||||
"message": "No models available - resource cluster unavailable"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models from Resource Cluster: {e}")
|
||||
# Return empty list in case of error
|
||||
return {
|
||||
"models": [],
|
||||
"total": 0,
|
||||
"tenant_domain": current_user.get("tenant_domain", "default"),
|
||||
"message": "No models available - service error"
|
||||
}
|
||||
|
||||
|
||||
|
||||
@router.get("/{model_id}", summary="Get model details")
|
||||
async def get_model_details(
|
||||
model_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get detailed information about a specific model"""
|
||||
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "default")
|
||||
|
||||
# Call Resource Cluster for model details - use Docker service name if in container
|
||||
import os
|
||||
if os.path.exists('/.dockerenv'):
|
||||
resource_cluster_url = "http://resource-cluster:8000"
|
||||
else:
|
||||
resource_cluster_url = settings.resource_cluster_url
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{resource_cluster_url}/api/v1/models/{model_id}",
|
||||
headers={
|
||||
"X-Tenant-Domain": tenant_domain
|
||||
},
|
||||
timeout=15.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model {model_id} not found"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Resource Cluster unavailable"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model {model_id} details: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Failed to get model details"
|
||||
)
|
||||
2908
apps/tenant-backend/app/api/v1/observability.py
Normal file
2908
apps/tenant-backend/app/api/v1/observability.py
Normal file
File diff suppressed because it is too large
Load Diff
238
apps/tenant-backend/app/api/v1/optics.py
Normal file
238
apps/tenant-backend/app/api/v1/optics.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Optics Cost Tracking API Endpoints
|
||||
|
||||
Provides cost visibility for inference and storage usage.
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.api.v1.observability import get_current_user, get_user_role
|
||||
|
||||
from app.services.optics_service import (
|
||||
fetch_optics_settings,
|
||||
get_optics_cost_summary,
|
||||
STORAGE_COST_PER_MB_CENTS
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/optics", tags=["Optics Cost Tracking"])
|
||||
|
||||
|
||||
# Response models
|
||||
class OpticsSettingsResponse(BaseModel):
|
||||
enabled: bool
|
||||
storage_cost_per_mb_cents: float
|
||||
show_to_admins_only: bool = True
|
||||
|
||||
|
||||
class ModelCostBreakdown(BaseModel):
|
||||
model_id: str
|
||||
model_name: str
|
||||
tokens: int
|
||||
conversations: int
|
||||
messages: int
|
||||
cost_cents: float
|
||||
cost_display: str
|
||||
percentage: float
|
||||
|
||||
|
||||
class UserCostBreakdown(BaseModel):
|
||||
user_id: str
|
||||
email: str
|
||||
tokens: int
|
||||
cost_cents: float
|
||||
cost_display: str
|
||||
percentage: float
|
||||
|
||||
|
||||
class OpticsCostResponse(BaseModel):
|
||||
enabled: bool
|
||||
inference_cost_cents: float
|
||||
storage_cost_cents: float
|
||||
total_cost_cents: float
|
||||
inference_cost_display: str
|
||||
storage_cost_display: str
|
||||
total_cost_display: str
|
||||
total_tokens: int
|
||||
total_storage_mb: float
|
||||
document_count: int
|
||||
dataset_count: int
|
||||
by_model: List[ModelCostBreakdown]
|
||||
by_user: Optional[List[UserCostBreakdown]] = None
|
||||
period_start: str
|
||||
period_end: str
|
||||
|
||||
|
||||
@router.get("/settings", response_model=OpticsSettingsResponse)
|
||||
async def get_optics_settings(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Check if Optics is enabled for the current tenant.
|
||||
|
||||
This endpoint is used by the frontend to determine whether
|
||||
to show the Optics tab in the observability dashboard.
|
||||
"""
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
|
||||
try:
|
||||
settings = await fetch_optics_settings(tenant_domain)
|
||||
|
||||
return OpticsSettingsResponse(
|
||||
enabled=settings.get("enabled", False),
|
||||
storage_cost_per_mb_cents=settings.get("storage_cost_per_mb_cents", STORAGE_COST_PER_MB_CENTS),
|
||||
show_to_admins_only=True # Only admins can see user breakdown
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching optics settings: {str(e)}")
|
||||
return OpticsSettingsResponse(
|
||||
enabled=False,
|
||||
storage_cost_per_mb_cents=STORAGE_COST_PER_MB_CENTS,
|
||||
show_to_admins_only=True
|
||||
)
|
||||
|
||||
|
||||
@router.get("/costs", response_model=OpticsCostResponse)
|
||||
async def get_optics_costs(
|
||||
days: Optional[int] = Query(30, ge=1, le=365, description="Number of days to look back"),
|
||||
start_date: Optional[str] = Query(None, description="Custom start date (ISO format)"),
|
||||
end_date: Optional[str] = Query(None, description="Custom end date (ISO format)"),
|
||||
user_id: Optional[str] = Query(None, description="Filter by user ID (admin only)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get Optics cost breakdown for the current tenant.
|
||||
|
||||
Returns inference costs calculated from token usage and model pricing,
|
||||
plus storage costs at the configured rate (default 4 cents/MB).
|
||||
|
||||
- **days**: Number of days to look back (default 30)
|
||||
- **start_date**: Custom start date (overrides days)
|
||||
- **end_date**: Custom end date
|
||||
- **user_id**: Filter by specific user (admin only)
|
||||
"""
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
|
||||
# Check if Optics is enabled
|
||||
settings = await fetch_optics_settings(tenant_domain)
|
||||
if not settings.get("enabled", False):
|
||||
return OpticsCostResponse(
|
||||
enabled=False,
|
||||
inference_cost_cents=0,
|
||||
storage_cost_cents=0,
|
||||
total_cost_cents=0,
|
||||
inference_cost_display="$0.00",
|
||||
storage_cost_display="$0.00",
|
||||
total_cost_display="$0.00",
|
||||
total_tokens=0,
|
||||
total_storage_mb=0,
|
||||
document_count=0,
|
||||
dataset_count=0,
|
||||
by_model=[],
|
||||
by_user=None,
|
||||
period_start=datetime.utcnow().isoformat(),
|
||||
period_end=datetime.utcnow().isoformat()
|
||||
)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get user role for permission checks
|
||||
user_email = current_user.get("email", "")
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Handle user filter - only admins can filter by user
|
||||
filter_user_id = None
|
||||
if user_id:
|
||||
if not is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admins can filter by user"
|
||||
)
|
||||
filter_user_id = user_id
|
||||
elif not is_admin:
|
||||
# Non-admins can only see their own data
|
||||
# Get user UUID from email
|
||||
user_query = f"""
|
||||
SELECT id FROM tenant_{tenant_domain.replace('-', '_')}.users
|
||||
WHERE email = $1 LIMIT 1
|
||||
"""
|
||||
user_result = await pg_client.execute_query(user_query, user_email)
|
||||
if user_result:
|
||||
filter_user_id = str(user_result[0]["id"])
|
||||
|
||||
# Calculate date range
|
||||
date_end = datetime.utcnow()
|
||||
date_start = date_end - timedelta(days=days)
|
||||
|
||||
if start_date:
|
||||
try:
|
||||
date_start = datetime.fromisoformat(start_date.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid start_date format. Use ISO format."
|
||||
)
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
date_end = datetime.fromisoformat(end_date.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid end_date format. Use ISO format."
|
||||
)
|
||||
|
||||
try:
|
||||
cost_summary = await get_optics_cost_summary(
|
||||
pg_client=pg_client,
|
||||
tenant_domain=tenant_domain,
|
||||
date_start=date_start,
|
||||
date_end=date_end,
|
||||
user_id=filter_user_id,
|
||||
include_user_breakdown=is_admin and not filter_user_id # Only include breakdown for platform view
|
||||
)
|
||||
|
||||
# Convert to response model
|
||||
by_model = [
|
||||
ModelCostBreakdown(**item)
|
||||
for item in cost_summary.get("by_model", [])
|
||||
]
|
||||
|
||||
by_user = None
|
||||
if cost_summary.get("by_user"):
|
||||
by_user = [
|
||||
UserCostBreakdown(**item)
|
||||
for item in cost_summary["by_user"]
|
||||
]
|
||||
|
||||
return OpticsCostResponse(
|
||||
enabled=True,
|
||||
inference_cost_cents=cost_summary["inference_cost_cents"],
|
||||
storage_cost_cents=cost_summary["storage_cost_cents"],
|
||||
total_cost_cents=cost_summary["total_cost_cents"],
|
||||
inference_cost_display=cost_summary["inference_cost_display"],
|
||||
storage_cost_display=cost_summary["storage_cost_display"],
|
||||
total_cost_display=cost_summary["total_cost_display"],
|
||||
total_tokens=cost_summary["total_tokens"],
|
||||
total_storage_mb=cost_summary["total_storage_mb"],
|
||||
document_count=cost_summary["document_count"],
|
||||
dataset_count=cost_summary["dataset_count"],
|
||||
by_model=by_model,
|
||||
by_user=by_user,
|
||||
period_start=cost_summary["period_start"],
|
||||
period_end=cost_summary["period_end"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating optics costs: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to calculate costs"
|
||||
)
|
||||
535
apps/tenant-backend/app/api/v1/rag_visualization.py
Normal file
535
apps/tenant-backend/app/api/v1/rag_visualization.py
Normal file
@@ -0,0 +1,535 @@
|
||||
"""
|
||||
RAG Network Visualization API for GT 2.0
|
||||
Provides force-directed graph data and semantic relationships
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.database import get_db_session
|
||||
from app.core.security import get_current_user
|
||||
from app.services.rag_service import RAGService
|
||||
import random
|
||||
import math
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/rag/visualization", tags=["rag-visualization"])
|
||||
|
||||
class NetworkNode:
|
||||
"""Represents a node in the knowledge network"""
|
||||
def __init__(self, id: str, label: str, type: str, metadata: Dict[str, Any]):
|
||||
self.id = id
|
||||
self.label = label
|
||||
self.type = type # document, chunk, concept, query
|
||||
self.metadata = metadata
|
||||
self.x = random.uniform(-100, 100)
|
||||
self.y = random.uniform(-100, 100)
|
||||
self.size = self._calculate_size(type, metadata)
|
||||
self.color = self._get_color_by_type(type)
|
||||
self.importance = metadata.get('importance', 0.5)
|
||||
|
||||
def _calculate_size(self, type: str, metadata: Dict[str, Any]) -> int:
|
||||
"""Calculate node size based on type and metadata"""
|
||||
base_sizes = {
|
||||
"document": 20,
|
||||
"chunk": 10,
|
||||
"concept": 15,
|
||||
"query": 25
|
||||
}
|
||||
|
||||
base_size = base_sizes.get(type, 10)
|
||||
|
||||
# Adjust based on importance/usage
|
||||
if 'usage_count' in metadata:
|
||||
usage_multiplier = min(2.0, 1.0 + metadata['usage_count'] / 10.0)
|
||||
base_size = int(base_size * usage_multiplier)
|
||||
|
||||
if 'relevance_score' in metadata:
|
||||
relevance_multiplier = 0.5 + metadata['relevance_score'] * 0.5
|
||||
base_size = int(base_size * relevance_multiplier)
|
||||
|
||||
return max(5, min(50, base_size))
|
||||
|
||||
def _get_color_by_type(self, type: str) -> str:
|
||||
"""Get color based on node type"""
|
||||
colors = {
|
||||
"document": "#00d084", # GT Green
|
||||
"chunk": "#677489", # GT Gray
|
||||
"concept": "#4a5568", # Darker gray
|
||||
"query": "#ff6b6b", # Red for queries
|
||||
"dataset": "#4ade80", # Light green
|
||||
"user": "#3b82f6" # Blue
|
||||
}
|
||||
return colors.get(type, "#9aa5b1")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"label": self.label,
|
||||
"type": self.type,
|
||||
"x": float(self.x),
|
||||
"y": float(self.y),
|
||||
"size": self.size,
|
||||
"color": self.color,
|
||||
"importance": self.importance,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
class NetworkEdge:
|
||||
"""Represents an edge in the knowledge network"""
|
||||
def __init__(self, source: str, target: str, weight: float, edge_type: str = "semantic"):
|
||||
self.source = source
|
||||
self.target = target
|
||||
self.weight = max(0.0, min(1.0, weight)) # Clamp to 0-1
|
||||
self.type = edge_type
|
||||
self.color = self._get_edge_color(weight)
|
||||
self.width = self._get_edge_width(weight)
|
||||
self.animated = weight > 0.7
|
||||
|
||||
def _get_edge_color(self, weight: float) -> str:
|
||||
"""Get edge color based on weight"""
|
||||
if weight > 0.8:
|
||||
return "#00d084" # Strong connection - GT green
|
||||
elif weight > 0.6:
|
||||
return "#4ade80" # Medium-strong - light green
|
||||
elif weight > 0.4:
|
||||
return "#9aa5b1" # Medium - gray
|
||||
else:
|
||||
return "#d1d9e0" # Weak - light gray
|
||||
|
||||
def _get_edge_width(self, weight: float) -> float:
|
||||
"""Get edge width based on weight"""
|
||||
return 1.0 + (weight * 4.0) # 1-5px width
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return {
|
||||
"source": self.source,
|
||||
"target": self.target,
|
||||
"weight": self.weight,
|
||||
"type": self.type,
|
||||
"color": self.color,
|
||||
"width": self.width,
|
||||
"animated": self.animated
|
||||
}
|
||||
|
||||
@router.get("/network/{dataset_id}")
|
||||
async def get_knowledge_network(
|
||||
dataset_id: str,
|
||||
max_nodes: int = Query(default=100, le=500),
|
||||
min_similarity: float = Query(default=0.3, ge=0, le=1),
|
||||
include_concepts: bool = Query(default=True),
|
||||
layout_algorithm: str = Query(default="force", description="force, circular, hierarchical"),
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get force-directed graph data for a RAG dataset
|
||||
Returns nodes (documents/chunks/concepts) and edges (semantic relationships)
|
||||
"""
|
||||
try:
|
||||
# rag_service = RAGService(db)
|
||||
user_id = current_user["sub"]
|
||||
tenant_id = current_user.get("tenant_id", "default")
|
||||
|
||||
# TODO: Verify dataset ownership and access permissions
|
||||
# dataset = await rag_service._get_user_dataset(user_id, dataset_id)
|
||||
# if not dataset:
|
||||
# raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
# For now, generate mock data that represents the expected structure
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
# Generate mock document nodes
|
||||
doc_count = min(max_nodes // 3, 20)
|
||||
for i in range(doc_count):
|
||||
doc_node = NetworkNode(
|
||||
id=f"doc_{i}",
|
||||
label=f"Document {i+1}",
|
||||
type="document",
|
||||
metadata={
|
||||
"filename": f"document_{i+1}.pdf",
|
||||
"size_bytes": random.randint(1000, 100000),
|
||||
"chunk_count": random.randint(5, 50),
|
||||
"upload_date": datetime.utcnow().isoformat(),
|
||||
"usage_count": random.randint(0, 100),
|
||||
"relevance_score": random.uniform(0.3, 1.0)
|
||||
}
|
||||
)
|
||||
nodes.append(doc_node.to_dict())
|
||||
|
||||
# Generate chunks for this document
|
||||
chunk_count = min(5, (max_nodes - len(nodes)) // (doc_count - i))
|
||||
for j in range(chunk_count):
|
||||
chunk_node = NetworkNode(
|
||||
id=f"chunk_{i}_{j}",
|
||||
label=f"Chunk {j+1}",
|
||||
type="chunk",
|
||||
metadata={
|
||||
"document_id": f"doc_{i}",
|
||||
"chunk_index": j,
|
||||
"token_count": random.randint(50, 500),
|
||||
"semantic_density": random.uniform(0.2, 0.9)
|
||||
}
|
||||
)
|
||||
nodes.append(chunk_node.to_dict())
|
||||
|
||||
# Connect chunk to document
|
||||
edge = NetworkEdge(
|
||||
source=f"doc_{i}",
|
||||
target=f"chunk_{i}_{j}",
|
||||
weight=1.0,
|
||||
edge_type="contains"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
|
||||
# Generate concept nodes if requested
|
||||
if include_concepts:
|
||||
concept_count = min(10, max_nodes - len(nodes))
|
||||
concepts = ["AI", "Machine Learning", "Neural Networks", "Data Science",
|
||||
"Python", "JavaScript", "API", "Database", "Security", "Cloud"]
|
||||
|
||||
for i in range(concept_count):
|
||||
if i < len(concepts):
|
||||
concept_node = NetworkNode(
|
||||
id=f"concept_{i}",
|
||||
label=concepts[i],
|
||||
type="concept",
|
||||
metadata={
|
||||
"frequency": random.randint(1, 50),
|
||||
"co_occurrence_score": random.uniform(0.1, 0.8),
|
||||
"domain": "technology"
|
||||
}
|
||||
)
|
||||
nodes.append(concept_node.to_dict())
|
||||
|
||||
# Generate semantic relationships between chunks
|
||||
chunk_nodes = [n for n in nodes if n["type"] == "chunk"]
|
||||
relationship_count = 0
|
||||
max_relationships = min(len(chunk_nodes) * 2, 100)
|
||||
|
||||
for i, node1 in enumerate(chunk_nodes):
|
||||
if relationship_count >= max_relationships:
|
||||
break
|
||||
|
||||
# Connect to a few other chunks based on semantic similarity
|
||||
connection_count = random.randint(1, 4)
|
||||
|
||||
for _ in range(connection_count):
|
||||
if relationship_count >= max_relationships:
|
||||
break
|
||||
|
||||
# Select random target (avoid self-connection)
|
||||
target_idx = random.randint(0, len(chunk_nodes) - 1)
|
||||
if target_idx == i:
|
||||
continue
|
||||
|
||||
node2 = chunk_nodes[target_idx]
|
||||
|
||||
# Generate similarity score (higher for nodes with related content)
|
||||
similarity = random.uniform(0.2, 0.9)
|
||||
|
||||
# Apply minimum similarity filter
|
||||
if similarity >= min_similarity:
|
||||
edge = NetworkEdge(
|
||||
source=node1["id"],
|
||||
target=node2["id"],
|
||||
weight=similarity,
|
||||
edge_type="semantic_similarity"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
relationship_count += 1
|
||||
|
||||
# Connect concepts to relevant chunks
|
||||
concept_nodes = [n for n in nodes if n["type"] == "concept"]
|
||||
for concept in concept_nodes:
|
||||
# Connect to 2-5 relevant chunks
|
||||
connection_count = random.randint(2, 6)
|
||||
connected_chunks = random.sample(
|
||||
range(len(chunk_nodes)),
|
||||
k=min(connection_count, len(chunk_nodes))
|
||||
)
|
||||
|
||||
for chunk_idx in connected_chunks:
|
||||
chunk = chunk_nodes[chunk_idx]
|
||||
relevance = random.uniform(0.4, 0.9)
|
||||
|
||||
edge = NetworkEdge(
|
||||
source=concept["id"],
|
||||
target=chunk["id"],
|
||||
weight=relevance,
|
||||
edge_type="concept_relevance"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
|
||||
# Apply layout algorithm positioning
|
||||
if layout_algorithm == "circular":
|
||||
_apply_circular_layout(nodes)
|
||||
elif layout_algorithm == "hierarchical":
|
||||
_apply_hierarchical_layout(nodes)
|
||||
# force layout uses random positioning (already applied)
|
||||
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"metadata": {
|
||||
"dataset_name": f"Dataset {dataset_id}",
|
||||
"total_nodes": len(nodes),
|
||||
"total_edges": len(edges),
|
||||
"node_types": {
|
||||
node_type: len([n for n in nodes if n["type"] == node_type])
|
||||
for node_type in set(n["type"] for n in nodes)
|
||||
},
|
||||
"edge_types": {
|
||||
edge_type: len([e for e in edges if e["type"] == edge_type])
|
||||
for edge_type in set(e["type"] for e in edges)
|
||||
},
|
||||
"min_similarity": min_similarity,
|
||||
"layout_algorithm": layout_algorithm,
|
||||
"generated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating knowledge network: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/search/visual")
|
||||
async def visual_search(
|
||||
query: str,
|
||||
dataset_ids: Optional[List[str]] = Query(default=None),
|
||||
top_k: int = Query(default=10, le=50),
|
||||
include_network: bool = Query(default=True),
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform semantic search with visualization metadata
|
||||
Returns search results with connection paths and confidence scores
|
||||
"""
|
||||
try:
|
||||
# rag_service = RAGService(db)
|
||||
user_id = current_user["sub"]
|
||||
tenant_id = current_user.get("tenant_id", "default")
|
||||
|
||||
# TODO: Perform actual search using RAG service
|
||||
# results = await rag_service.search_documents(
|
||||
# user_id=user_id,
|
||||
# tenant_id=tenant_id,
|
||||
# query=query,
|
||||
# dataset_ids=dataset_ids,
|
||||
# top_k=top_k
|
||||
# )
|
||||
|
||||
# Generate mock search results for now
|
||||
results = []
|
||||
for i in range(top_k):
|
||||
similarity = random.uniform(0.5, 0.95) # Mock similarity scores
|
||||
results.append({
|
||||
"id": f"result_{i}",
|
||||
"document_id": f"doc_{random.randint(0, 9)}",
|
||||
"chunk_id": f"chunk_{i}",
|
||||
"content": f"Mock search result {i+1} for query: {query}",
|
||||
"metadata": {
|
||||
"filename": f"document_{i}.pdf",
|
||||
"page_number": random.randint(1, 100),
|
||||
"chunk_index": i
|
||||
},
|
||||
"similarity": similarity,
|
||||
"dataset_id": dataset_ids[0] if dataset_ids else "default"
|
||||
})
|
||||
|
||||
# Enhance results with visualization data
|
||||
visual_results = []
|
||||
for i, result in enumerate(results):
|
||||
visual_result = {
|
||||
**result,
|
||||
"visual_metadata": {
|
||||
"position": i + 1,
|
||||
"relevance_score": result["similarity"],
|
||||
"confidence_level": _calculate_confidence(result["similarity"]),
|
||||
"connection_strength": result["similarity"],
|
||||
"highlight_color": _get_relevance_color(result["similarity"]),
|
||||
"path_to_query": _generate_path(query, result),
|
||||
"semantic_distance": 1.0 - result["similarity"],
|
||||
"cluster_id": f"cluster_{random.randint(0, 2)}"
|
||||
}
|
||||
}
|
||||
visual_results.append(visual_result)
|
||||
|
||||
response = {
|
||||
"query": query,
|
||||
"results": visual_results,
|
||||
"total_results": len(visual_results),
|
||||
"search_metadata": {
|
||||
"execution_time_ms": random.randint(50, 200),
|
||||
"datasets_searched": dataset_ids or ["all"],
|
||||
"semantic_method": "embedding_similarity",
|
||||
"reranking_applied": True
|
||||
}
|
||||
}
|
||||
|
||||
# Add network visualization if requested
|
||||
if include_network:
|
||||
network_data = _generate_search_network(query, visual_results)
|
||||
response["network_visualization"] = network_data
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in visual search: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/datasets/{dataset_id}/stats")
|
||||
async def get_dataset_stats(
|
||||
dataset_id: str,
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get statistical information about a dataset for visualization"""
|
||||
try:
|
||||
# TODO: Implement actual dataset statistics
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"document_count": random.randint(10, 100),
|
||||
"chunk_count": random.randint(100, 1000),
|
||||
"total_tokens": random.randint(10000, 100000),
|
||||
"concept_count": random.randint(20, 200),
|
||||
"average_document_similarity": random.uniform(0.3, 0.8),
|
||||
"semantic_clusters": random.randint(3, 10),
|
||||
"most_connected_documents": [
|
||||
{"id": f"doc_{i}", "connections": random.randint(5, 50)}
|
||||
for i in range(5)
|
||||
],
|
||||
"topic_distribution": {
|
||||
f"topic_{i}": random.uniform(0.05, 0.3)
|
||||
for i in range(5)
|
||||
},
|
||||
"last_updated": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting dataset stats: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Helper functions
|
||||
def _calculate_confidence(similarity: float) -> str:
|
||||
"""Calculate confidence level from similarity score"""
|
||||
if similarity > 0.9:
|
||||
return "very_high"
|
||||
elif similarity > 0.75:
|
||||
return "high"
|
||||
elif similarity > 0.6:
|
||||
return "medium"
|
||||
elif similarity > 0.4:
|
||||
return "low"
|
||||
else:
|
||||
return "very_low"
|
||||
|
||||
def _get_relevance_color(similarity: float) -> str:
|
||||
"""Get color based on relevance score"""
|
||||
if similarity > 0.8:
|
||||
return "#00d084" # GT Green
|
||||
elif similarity > 0.6:
|
||||
return "#4ade80" # Light green
|
||||
elif similarity > 0.4:
|
||||
return "#fbbf24" # Yellow
|
||||
else:
|
||||
return "#ef4444" # Red
|
||||
|
||||
def _generate_path(query: str, result: Dict[str, Any]) -> List[str]:
|
||||
"""Generate conceptual path from query to result"""
|
||||
return [
|
||||
"query",
|
||||
f"dataset_{result.get('dataset_id', 'unknown')}",
|
||||
f"document_{result.get('document_id', 'unknown')}",
|
||||
f"chunk_{result.get('chunk_id', 'unknown')}"
|
||||
]
|
||||
|
||||
def _generate_search_network(query: str, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Generate network visualization for search results"""
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
# Add query node
|
||||
query_node = NetworkNode(
|
||||
id="query",
|
||||
label=query[:50] + "..." if len(query) > 50 else query,
|
||||
type="query",
|
||||
metadata={"original_query": query}
|
||||
)
|
||||
nodes.append(query_node.to_dict())
|
||||
|
||||
# Add result nodes and connections
|
||||
for i, result in enumerate(results):
|
||||
result_node = NetworkNode(
|
||||
id=f"result_{i}",
|
||||
label=result.get("content", "")[:30] + "...",
|
||||
type="chunk",
|
||||
metadata={
|
||||
"similarity": result["similarity"],
|
||||
"position": i + 1,
|
||||
"document_id": result.get("document_id")
|
||||
}
|
||||
)
|
||||
nodes.append(result_node.to_dict())
|
||||
|
||||
# Connect query to result
|
||||
edge = NetworkEdge(
|
||||
source="query",
|
||||
target=f"result_{i}",
|
||||
weight=result["similarity"],
|
||||
edge_type="search_result"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"center_node_id": "query",
|
||||
"layout": "radial"
|
||||
}
|
||||
|
||||
def _apply_circular_layout(nodes: List[Dict[str, Any]]) -> None:
|
||||
"""Apply circular layout to nodes"""
|
||||
if not nodes:
|
||||
return
|
||||
|
||||
radius = 150
|
||||
angle_step = 2 * math.pi / len(nodes)
|
||||
|
||||
for i, node in enumerate(nodes):
|
||||
angle = i * angle_step
|
||||
node["x"] = radius * math.cos(angle)
|
||||
node["y"] = radius * math.sin(angle)
|
||||
|
||||
def _apply_hierarchical_layout(nodes: List[Dict[str, Any]]) -> None:
|
||||
"""Apply hierarchical layout to nodes"""
|
||||
if not nodes:
|
||||
return
|
||||
|
||||
# Group by type
|
||||
node_types = {}
|
||||
for node in nodes:
|
||||
node_type = node["type"]
|
||||
if node_type not in node_types:
|
||||
node_types[node_type] = []
|
||||
node_types[node_type].append(node)
|
||||
|
||||
# Position by type levels
|
||||
y_positions = {"document": -100, "concept": 0, "chunk": 100, "query": -150}
|
||||
|
||||
for node_type, type_nodes in node_types.items():
|
||||
y_pos = y_positions.get(node_type, 0)
|
||||
x_step = 300 / max(1, len(type_nodes) - 1) if len(type_nodes) > 1 else 0
|
||||
|
||||
for i, node in enumerate(type_nodes):
|
||||
node["y"] = y_pos
|
||||
node["x"] = -150 + (i * x_step) if len(type_nodes) > 1 else 0
|
||||
697
apps/tenant-backend/app/api/v1/search.py
Normal file
697
apps/tenant-backend/app/api/v1/search.py
Normal file
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
Search API for GT 2.0 Tenant Backend
|
||||
|
||||
Provides hybrid vector + text search capabilities using PGVector.
|
||||
Supports semantic similarity search, full-text search, and hybrid ranking.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Header
|
||||
from pydantic import BaseModel, Field
|
||||
import logging
|
||||
import httpx
|
||||
import time
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.pgvector_search_service import (
|
||||
PGVectorSearchService,
|
||||
HybridSearchResult,
|
||||
SearchConfig,
|
||||
get_pgvector_search_service
|
||||
)
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/search", tags=["search"])
|
||||
|
||||
|
||||
async def get_user_context(
|
||||
x_tenant_domain: Optional[str] = Header(None),
|
||||
x_user_id: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get user context from headers (for internal services).
|
||||
Handles email, UUID, and numeric ID formats for user identification.
|
||||
"""
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: x_tenant_domain='{x_tenant_domain}', x_user_id='{x_user_id}'")
|
||||
|
||||
# Validate required headers
|
||||
if not x_tenant_domain or not x_user_id:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing required headers: X-Tenant-Domain and X-User-ID"
|
||||
)
|
||||
|
||||
# Validate and clean inputs
|
||||
x_tenant_domain = x_tenant_domain.strip() if x_tenant_domain else None
|
||||
x_user_id = x_user_id.strip() if x_user_id else None
|
||||
|
||||
if not x_tenant_domain or not x_user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid empty headers: X-Tenant-Domain and X-User-ID cannot be empty"
|
||||
)
|
||||
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: Processing user_id='{x_user_id}' for tenant='{x_tenant_domain}'")
|
||||
|
||||
# Use ensure_user_uuid to handle all user ID formats (email, UUID, numeric)
|
||||
from app.core.user_resolver import ensure_user_uuid
|
||||
try:
|
||||
resolved_uuid = await ensure_user_uuid(x_user_id, x_tenant_domain)
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: Resolved user_id '{x_user_id}' to UUID '{resolved_uuid}'")
|
||||
|
||||
# Determine original email if input was UUID
|
||||
user_email = x_user_id if "@" in x_user_id else None
|
||||
|
||||
# If we don't have an email, try to get it from the database
|
||||
if not user_email:
|
||||
try:
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
tenant_schema = f"tenant_{x_tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
user_row = await conn.fetchrow(
|
||||
f"SELECT email FROM {tenant_schema}.users WHERE id = $1",
|
||||
resolved_uuid
|
||||
)
|
||||
if user_row:
|
||||
user_email = user_row['email']
|
||||
else:
|
||||
# Fallback to UUID as email for backward compatibility
|
||||
user_email = resolved_uuid
|
||||
logger.warning(f"Could not find email for UUID {resolved_uuid}, using UUID as email")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to lookup email for UUID {resolved_uuid}: {e}")
|
||||
user_email = resolved_uuid
|
||||
|
||||
context = {
|
||||
"tenant_domain": x_tenant_domain,
|
||||
"id": resolved_uuid,
|
||||
"sub": resolved_uuid,
|
||||
"email": user_email,
|
||||
"user_type": "internal_service"
|
||||
}
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: Returning context: {context}")
|
||||
return context
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"🔍 GET_USER_CONTEXT ERROR: Failed to resolve user_id '{x_user_id}': {e}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid user identifier '{x_user_id}': {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🔍 GET_USER_CONTEXT ERROR: Unexpected error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal error processing user context"
|
||||
)
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""Request for hybrid search"""
|
||||
query: str = Field(..., min_length=1, max_length=1000, description="Search query")
|
||||
dataset_ids: Optional[List[str]] = Field(None, description="Optional dataset IDs to search within")
|
||||
search_type: str = Field("hybrid", description="Search type: hybrid, vector, text")
|
||||
max_results: int = Field(10, ge=1, le=200, description="Maximum results to return")
|
||||
|
||||
# Advanced search parameters
|
||||
vector_weight: Optional[float] = Field(0.7, ge=0.0, le=1.0, description="Weight for vector similarity")
|
||||
text_weight: Optional[float] = Field(0.3, ge=0.0, le=1.0, description="Weight for text relevance")
|
||||
min_similarity: Optional[float] = Field(0.3, ge=0.0, le=1.0, description="Minimum similarity threshold")
|
||||
rerank_results: Optional[bool] = Field(True, description="Apply result re-ranking")
|
||||
|
||||
|
||||
class SimilarChunksRequest(BaseModel):
|
||||
"""Request for finding similar chunks"""
|
||||
chunk_id: str = Field(..., description="Reference chunk ID")
|
||||
similarity_threshold: float = Field(0.5, ge=0.0, le=1.0, description="Minimum similarity")
|
||||
max_results: int = Field(5, ge=1, le=20, description="Maximum results")
|
||||
exclude_same_document: bool = Field(True, description="Exclude chunks from same document")
|
||||
|
||||
|
||||
class SearchResultResponse(BaseModel):
|
||||
"""Individual search result"""
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
dataset_id: Optional[str]
|
||||
text: str
|
||||
metadata: Dict[str, Any]
|
||||
vector_similarity: float
|
||||
text_relevance: float
|
||||
hybrid_score: float
|
||||
rank: int
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
"""Search response with results and metadata"""
|
||||
query: str
|
||||
search_type: str
|
||||
total_results: int
|
||||
results: List[SearchResultResponse]
|
||||
search_time_ms: float
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class ConversationSearchRequest(BaseModel):
|
||||
"""Request for searching conversation history"""
|
||||
query: str = Field(..., min_length=1, max_length=500, description="Search query")
|
||||
days_back: Optional[int] = Field(30, ge=1, le=365, description="Number of days back to search")
|
||||
max_results: Optional[int] = Field(5, ge=1, le=200, description="Maximum results to return")
|
||||
agent_filter: Optional[List[str]] = Field(None, description="Filter by agent names/IDs")
|
||||
include_user_messages: Optional[bool] = Field(True, description="Include user messages in results")
|
||||
|
||||
|
||||
class DocumentChunksResponse(BaseModel):
|
||||
"""Response for document chunks"""
|
||||
document_id: str
|
||||
total_chunks: int
|
||||
chunks: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# Search Endpoints
|
||||
|
||||
@router.post("/", response_model=SearchResponse)
|
||||
async def search_documents(
|
||||
request: SearchRequest,
|
||||
user_context: Dict[str, Any] = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Perform hybrid search across user's documents and datasets.
|
||||
|
||||
Combines vector similarity search with full-text search for optimal results.
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(f"🔍 SEARCH_API START: request={request.dict()}")
|
||||
logger.info(f"🔍 SEARCH_API: user_context={user_context}")
|
||||
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
logger.info(f"🔍 SEARCH_API: extracted tenant_domain='{tenant_domain}', user_id='{user_id}'")
|
||||
|
||||
# Validate weights sum to 1.0 for hybrid search
|
||||
if request.search_type == "hybrid":
|
||||
total_weight = (request.vector_weight or 0.7) + (request.text_weight or 0.3)
|
||||
if abs(total_weight - 1.0) > 0.01:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Vector weight and text weight must sum to 1.0"
|
||||
)
|
||||
|
||||
# Initialize search service
|
||||
logger.info(f"🔍 SEARCH_API: Initializing search service with tenant_id='{tenant_domain}', user_id='{user_id}'")
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Configure search parameters
|
||||
config = SearchConfig(
|
||||
vector_weight=request.vector_weight or 0.7,
|
||||
text_weight=request.text_weight or 0.3,
|
||||
min_vector_similarity=request.min_similarity or 0.3,
|
||||
min_text_relevance=0.01, # Fix: Use appropriate ts_rank_cd threshold
|
||||
max_results=request.max_results,
|
||||
rerank_results=request.rerank_results or True
|
||||
)
|
||||
logger.info(f"🔍 SEARCH_API: Search config created: {config.__dict__}")
|
||||
|
||||
# Execute search based on type
|
||||
results = []
|
||||
logger.info(f"🔍 SEARCH_API: Executing {request.search_type} search")
|
||||
if request.search_type == "hybrid":
|
||||
logger.info(f"🔍 SEARCH_API: Calling hybrid_search with query='{request.query}', user_id='{user_id}', dataset_ids={request.dataset_ids}")
|
||||
results = await search_service.hybrid_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
config=config,
|
||||
limit=request.max_results
|
||||
)
|
||||
logger.info(f"🔍 SEARCH_API: Hybrid search returned {len(results)} results")
|
||||
elif request.search_type == "vector":
|
||||
# Generate query embedding first
|
||||
query_embedding = await search_service._generate_query_embedding(
|
||||
request.query,
|
||||
user_id
|
||||
)
|
||||
results = await search_service.vector_similarity_search(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
similarity_threshold=config.min_vector_similarity,
|
||||
limit=request.max_results
|
||||
)
|
||||
elif request.search_type == "text":
|
||||
results = await search_service.full_text_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
limit=request.max_results
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid search_type. Must be 'hybrid', 'vector', or 'text'"
|
||||
)
|
||||
|
||||
# Calculate search time
|
||||
search_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert results to response format
|
||||
result_responses = [
|
||||
SearchResultResponse(
|
||||
chunk_id=str(result.chunk_id),
|
||||
document_id=str(result.document_id),
|
||||
dataset_id=str(result.dataset_id) if result.dataset_id else None,
|
||||
text=result.text,
|
||||
metadata=result.metadata if isinstance(result.metadata, dict) else {},
|
||||
vector_similarity=result.vector_similarity,
|
||||
text_relevance=result.text_relevance,
|
||||
hybrid_score=result.hybrid_score,
|
||||
rank=result.rank
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Search completed: query='{request.query}', "
|
||||
f"type={request.search_type}, results={len(results)}, "
|
||||
f"time={search_time_ms:.1f}ms"
|
||||
)
|
||||
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
search_type=request.search_type,
|
||||
total_results=len(results),
|
||||
results=result_responses,
|
||||
search_time_ms=search_time_ms,
|
||||
config={
|
||||
"vector_weight": config.vector_weight,
|
||||
"text_weight": config.text_weight,
|
||||
"min_similarity": config.min_vector_similarity,
|
||||
"rerank_enabled": config.rerank_results
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/chunks", response_model=DocumentChunksResponse)
|
||||
async def get_document_chunks(
|
||||
document_id: str,
|
||||
include_embeddings: bool = Query(False, description="Include embedding vectors"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all chunks for a specific document.
|
||||
|
||||
Useful for understanding document structure and chunk boundaries.
|
||||
"""
|
||||
try:
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
chunks = await search_service.get_document_chunks(
|
||||
document_id=document_id,
|
||||
user_id=user_id,
|
||||
include_embeddings=include_embeddings
|
||||
)
|
||||
|
||||
return DocumentChunksResponse(
|
||||
document_id=document_id,
|
||||
total_chunks=len(chunks),
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get document chunks: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/similar-chunks", response_model=SearchResponse)
|
||||
async def find_similar_chunks(
|
||||
request: SimilarChunksRequest,
|
||||
user_context: Dict[str, Any] = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Find chunks similar to a reference chunk.
|
||||
|
||||
Useful for exploring related content and building context.
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
results = await search_service.search_similar_chunks(
|
||||
chunk_id=request.chunk_id,
|
||||
user_id=user_id,
|
||||
similarity_threshold=request.similarity_threshold,
|
||||
limit=request.max_results,
|
||||
exclude_same_document=request.exclude_same_document
|
||||
)
|
||||
|
||||
search_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert results to response format
|
||||
result_responses = [
|
||||
SearchResultResponse(
|
||||
chunk_id=str(result.chunk_id),
|
||||
document_id=str(result.document_id),
|
||||
dataset_id=str(result.dataset_id) if result.dataset_id else None,
|
||||
text=result.text,
|
||||
metadata=result.metadata if isinstance(result.metadata, dict) else {},
|
||||
vector_similarity=result.vector_similarity,
|
||||
text_relevance=result.text_relevance,
|
||||
hybrid_score=result.hybrid_score,
|
||||
rank=result.rank
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
logger.info(f"Similar chunks search: found {len(results)} results in {search_time_ms:.1f}ms")
|
||||
|
||||
return SearchResponse(
|
||||
query=f"Similar to chunk {request.chunk_id}",
|
||||
search_type="vector_similarity",
|
||||
total_results=len(results),
|
||||
results=result_responses,
|
||||
search_time_ms=search_time_ms,
|
||||
config={
|
||||
"similarity_threshold": request.similarity_threshold,
|
||||
"exclude_same_document": request.exclude_same_document
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Similar chunks search failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
class DocumentSearchRequest(BaseModel):
|
||||
"""Request for searching specific documents"""
|
||||
query: str = Field(..., min_length=1, max_length=1000, description="Search query")
|
||||
document_ids: List[str] = Field(..., description="Document IDs to search within")
|
||||
search_type: str = Field("hybrid", description="Search type: hybrid, vector, text")
|
||||
max_results: int = Field(5, ge=1, le=20, description="Maximum results per document")
|
||||
min_similarity: Optional[float] = Field(0.6, ge=0.0, le=1.0, description="Minimum similarity threshold")
|
||||
|
||||
|
||||
@router.post("/documents", response_model=SearchResponse)
|
||||
async def search_documents_specific(
|
||||
request: DocumentSearchRequest,
|
||||
user_context: Dict[str, Any] = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Search within specific documents for relevant chunks.
|
||||
|
||||
Used by MCP RAG server for document-specific queries.
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
# Initialize search service
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Configure search for documents
|
||||
config = SearchConfig(
|
||||
vector_weight=0.7,
|
||||
text_weight=0.3,
|
||||
min_vector_similarity=request.min_similarity or 0.6,
|
||||
min_text_relevance=0.1,
|
||||
max_results=request.max_results,
|
||||
rerank_results=True
|
||||
)
|
||||
|
||||
# Execute search with document filter
|
||||
# First resolve dataset IDs from document IDs to satisfy security constraints
|
||||
dataset_ids = await search_service.get_dataset_ids_from_documents(request.document_ids, user_id)
|
||||
if not dataset_ids:
|
||||
logger.warning(f"No dataset IDs found for documents: {request.document_ids}")
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
search_type=request.search_type,
|
||||
total_results=0,
|
||||
results=[],
|
||||
search_time_ms=0.0,
|
||||
config={}
|
||||
)
|
||||
|
||||
results = []
|
||||
if request.search_type == "hybrid":
|
||||
results = await search_service.hybrid_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=dataset_ids, # Use resolved dataset IDs
|
||||
config=config,
|
||||
limit=request.max_results * len(request.document_ids) # Allow more results for filtering
|
||||
)
|
||||
# Filter results to only include specified documents
|
||||
results = [r for r in results if r.document_id in request.document_ids][:request.max_results]
|
||||
|
||||
elif request.search_type == "vector":
|
||||
query_embedding = await search_service._generate_query_embedding(
|
||||
request.query,
|
||||
user_id
|
||||
)
|
||||
results = await search_service.vector_similarity_search(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
dataset_ids=dataset_ids, # Use resolved dataset IDs
|
||||
similarity_threshold=config.min_vector_similarity,
|
||||
limit=request.max_results * len(request.document_ids)
|
||||
)
|
||||
# Filter by document IDs
|
||||
results = [r for r in results if r.document_id in request.document_ids][:request.max_results]
|
||||
|
||||
elif request.search_type == "text":
|
||||
results = await search_service.full_text_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=dataset_ids, # Use resolved dataset IDs
|
||||
limit=request.max_results * len(request.document_ids)
|
||||
)
|
||||
# Filter by document IDs
|
||||
results = [r for r in results if r.document_id in request.document_ids][:request.max_results]
|
||||
|
||||
search_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert results to response format
|
||||
result_responses = [
|
||||
SearchResultResponse(
|
||||
chunk_id=str(result.chunk_id),
|
||||
document_id=str(result.document_id),
|
||||
dataset_id=str(result.dataset_id) if result.dataset_id else None,
|
||||
text=result.text,
|
||||
metadata=result.metadata if isinstance(result.metadata, dict) else {},
|
||||
vector_similarity=result.vector_similarity,
|
||||
text_relevance=result.text_relevance,
|
||||
hybrid_score=result.hybrid_score,
|
||||
rank=result.rank
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Document search completed: query='{request.query}', "
|
||||
f"documents={len(request.document_ids)}, results={len(results)}, "
|
||||
f"time={search_time_ms:.1f}ms"
|
||||
)
|
||||
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
search_type=request.search_type,
|
||||
total_results=len(results),
|
||||
results=result_responses,
|
||||
search_time_ms=search_time_ms,
|
||||
config={
|
||||
"document_ids": request.document_ids,
|
||||
"min_similarity": request.min_similarity,
|
||||
"max_results_per_document": request.max_results
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Document search failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def search_health_check(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Health check for search service functionality.
|
||||
|
||||
Verifies that PGVector extension and search capabilities are working.
|
||||
"""
|
||||
try:
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Test basic connectivity and PGVector functionality
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
# Test PGVector extension
|
||||
result = await conn.fetchval("SELECT 1 + 1")
|
||||
if result != 2:
|
||||
raise Exception("Basic database connectivity failed")
|
||||
|
||||
# Test PGVector extension (this will fail if extension is not installed)
|
||||
await conn.fetchval("SELECT '[1,2,3]'::vector <-> '[1,2,4]'::vector")
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"tenant_id": tenant_domain,
|
||||
"pgvector_available": True,
|
||||
"search_service": "operational"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search health check failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": "Search service health check failed",
|
||||
"tenant_id": tenant_domain,
|
||||
"pgvector_available": False,
|
||||
"search_service": "error"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/conversations")
|
||||
async def search_conversations(
|
||||
request: ConversationSearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_user_context),
|
||||
x_tenant_domain: Optional[str] = Header(None, alias="X-Tenant-Domain"),
|
||||
x_user_id: Optional[str] = Header(None, alias="X-User-ID")
|
||||
):
|
||||
"""
|
||||
Search through conversation history using MCP conversation server.
|
||||
|
||||
Used by both external clients and internal MCP tools for conversation search.
|
||||
"""
|
||||
try:
|
||||
# Use same user resolution pattern as document search
|
||||
tenant_domain = current_user.get("tenant_domain") or x_tenant_domain
|
||||
user_id = current_user.get("id") or current_user.get("sub") or x_user_id
|
||||
|
||||
if not tenant_domain or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing tenant_domain or user_id in request context"
|
||||
)
|
||||
|
||||
logger.info(f"🔍 Conversation search: query='{request.query}', user={user_id}, tenant={tenant_domain}")
|
||||
|
||||
# Get resource cluster URL
|
||||
settings = get_settings()
|
||||
mcp_base_url = getattr(settings, 'resource_cluster_url', 'http://gentwo-resource-backend:8000')
|
||||
|
||||
# Build request payload for MCP execution
|
||||
request_payload = {
|
||||
"server_id": "conversation_server",
|
||||
"tool_name": "search_conversations",
|
||||
"parameters": {
|
||||
"query": request.query,
|
||||
"days_back": request.days_back or 30,
|
||||
"max_results": request.max_results or 5,
|
||||
"agent_filter": request.agent_filter,
|
||||
"include_user_messages": request.include_user_messages
|
||||
},
|
||||
"tenant_domain": tenant_domain,
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
logger.info(f"🌐 Making MCP request to: {mcp_base_url}/api/v1/mcp/execute")
|
||||
|
||||
response = await client.post(
|
||||
f"{mcp_base_url}/api/v1/mcp/execute",
|
||||
json=request_payload
|
||||
)
|
||||
|
||||
execution_time_ms = (time.time() - start_time) * 1000
|
||||
logger.info(f"📊 MCP response: {response.status_code} ({execution_time_ms:.1f}ms)")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info(f"✅ Conversation search successful ({execution_time_ms:.1f}ms)")
|
||||
return result
|
||||
else:
|
||||
error_text = response.text
|
||||
error_msg = f"MCP conversation search failed: {response.status_code} - {error_text}"
|
||||
logger.error(f"❌ {error_msg}")
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Conversation search endpoint error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
758
apps/tenant-backend/app/api/v1/teams.py
Normal file
758
apps/tenant-backend/app/api/v1/teams.py
Normal file
@@ -0,0 +1,758 @@
|
||||
"""
|
||||
Teams API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Provides team collaboration management with two-tier permissions:
|
||||
- Tier 1 (Team-level): 'read' or 'share' set by team owner
|
||||
- Tier 2 (Resource-level): 'read' or 'edit' set by resource sharer
|
||||
|
||||
Follows GT 2.0 principles:
|
||||
- Perfect tenant isolation
|
||||
- Admin bypass for tenant admins
|
||||
- Fail-fast error handling
|
||||
- PostgreSQL-first storage
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Dict, Any, List
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.team_service import TeamService
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
from app.models.collaboration_team import (
|
||||
TeamCreate,
|
||||
TeamUpdate,
|
||||
Team,
|
||||
TeamListResponse,
|
||||
TeamResponse,
|
||||
TeamWithMembers,
|
||||
TeamWithMembersResponse,
|
||||
AddMemberRequest,
|
||||
UpdateMemberPermissionRequest,
|
||||
MemberListResponse,
|
||||
MemberResponse,
|
||||
ShareResourceRequest,
|
||||
SharedResourcesResponse,
|
||||
SharedResource,
|
||||
TeamInvitation,
|
||||
InvitationListResponse,
|
||||
ObservableRequest,
|
||||
ObservableRequestListResponse,
|
||||
TeamActivityMetrics,
|
||||
TeamActivityResponse,
|
||||
ErrorResponse
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/teams", tags=["teams"])
|
||||
|
||||
|
||||
async def get_team_service_for_user(current_user: Dict[str, Any]) -> TeamService:
|
||||
"""Helper function to create TeamService with proper tenant UUID mapping"""
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
return TeamService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test-company'),
|
||||
user_id=tenant_user_uuid,
|
||||
user_email=user_email
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEAM CRUD ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("", response_model=TeamListResponse)
|
||||
async def list_teams(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all teams where the current user is owner or member.
|
||||
Returns teams with member counts and permission flags.
|
||||
|
||||
Permission flags:
|
||||
- is_owner: User created this team
|
||||
- can_manage: User can manage team (owner or admin/developer)
|
||||
"""
|
||||
logger.info(f"Listing teams for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
teams = await service.get_user_teams()
|
||||
|
||||
return TeamListResponse(
|
||||
data=teams,
|
||||
total=len(teams)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing teams: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("", response_model=TeamResponse, status_code=201)
|
||||
async def create_team(
|
||||
team_data: TeamCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a new team with the current user as owner.
|
||||
|
||||
The creator is automatically the team owner with full management permissions.
|
||||
"""
|
||||
logger.info(f"Creating team '{team_data.name}' for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
team = await service.create_team(
|
||||
name=team_data.name,
|
||||
description=team_data.description or ""
|
||||
)
|
||||
|
||||
return TeamResponse(data=team)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# TEAM INVITATION ENDPOINTS (must come before /{team_id} routes)
|
||||
# ==============================================================================
|
||||
|
||||
@router.get("/invitations", response_model=InvitationListResponse)
|
||||
async def list_my_invitations(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get current user's pending team invitations.
|
||||
|
||||
Returns list of invitations with team details and inviter information.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
invitations = await service.get_pending_invitations()
|
||||
|
||||
return InvitationListResponse(
|
||||
data=invitations,
|
||||
total=len(invitations)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing invitations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/invitations/{invitation_id}/accept", response_model=MemberResponse)
|
||||
async def accept_team_invitation(
|
||||
invitation_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Accept a team invitation.
|
||||
|
||||
Updates the invitation status to 'accepted' and grants team membership.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
member = await service.accept_invitation(invitation_id)
|
||||
|
||||
return MemberResponse(data=member)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error accepting invitation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/invitations/{invitation_id}/decline", status_code=204)
|
||||
async def decline_team_invitation(
|
||||
invitation_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Decline a team invitation.
|
||||
|
||||
Removes the invitation from the system.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
await service.decline_invitation(invitation_id)
|
||||
|
||||
return None
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error declining invitation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/observable-requests", response_model=ObservableRequestListResponse)
|
||||
async def get_observable_requests(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get pending Observable requests for the current user.
|
||||
|
||||
Returns list of teams requesting Observable access.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
requests = await service.get_observable_requests()
|
||||
|
||||
return ObservableRequestListResponse(
|
||||
data=[ObservableRequest(**req) for req in requests],
|
||||
total=len(requests)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Observable requests: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# TEAM CRUD ENDPOINTS (dynamic routes with {team_id})
|
||||
# ==============================================================================
|
||||
|
||||
@router.get("/{team_id}", response_model=TeamResponse)
|
||||
async def get_team(
|
||||
team_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get team details by ID.
|
||||
|
||||
Only accessible to team members or tenant admins.
|
||||
"""
|
||||
logger.info(f"Getting team {team_id} for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
team = await service.get_team_by_id(team_id)
|
||||
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail=f"Team {team_id} not found")
|
||||
|
||||
return TeamResponse(data=team)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{team_id}", response_model=TeamResponse)
|
||||
async def update_team(
|
||||
team_id: str,
|
||||
updates: TeamUpdate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update team name/description.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
"""
|
||||
logger.info(f"Updating team {team_id} for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
|
||||
# Convert Pydantic model to dict, excluding None values
|
||||
update_dict = updates.model_dump(exclude_none=True)
|
||||
|
||||
team = await service.update_team(team_id, update_dict)
|
||||
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail=f"Team {team_id} not found")
|
||||
|
||||
return TeamResponse(data=team)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}", status_code=204)
|
||||
async def delete_team(
|
||||
team_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete a team and all its memberships (CASCADE).
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
"""
|
||||
logger.info(f"Deleting team {team_id} for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
success = await service.delete_team(team_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Team {team_id} not found")
|
||||
|
||||
return None # 204 No Content
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEAM MEMBER ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/{team_id}/members", response_model=MemberListResponse)
|
||||
async def list_team_members(
|
||||
team_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all members of a team with their permissions.
|
||||
|
||||
Only accessible to team members or tenant admins.
|
||||
|
||||
Returns:
|
||||
- user_id, user_email, user_name
|
||||
- team_permission: 'read' or 'share'
|
||||
- resource_permissions: JSONB dict of resource-level permissions
|
||||
"""
|
||||
logger.info(f"Listing members for team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
members = await service.get_team_members(team_id)
|
||||
|
||||
return MemberListResponse(
|
||||
data=members,
|
||||
total=len(members)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing team members: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{team_id}/members", response_model=MemberResponse, status_code=201)
|
||||
async def add_team_member(
|
||||
team_id: str,
|
||||
member_data: AddMemberRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Add a user to the team with specified permission.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
|
||||
Team Permissions:
|
||||
- 'read': Can access resources shared to this team
|
||||
- 'share': Can access resources AND share own resources to this team
|
||||
|
||||
Note: Observability access is automatically requested when inviting users.
|
||||
The invited user can approve or decline the observability request separately.
|
||||
"""
|
||||
logger.info(f"Adding member {member_data.user_email} to team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
member = await service.add_member(
|
||||
team_id=team_id,
|
||||
user_email=member_data.user_email,
|
||||
team_permission=member_data.team_permission
|
||||
)
|
||||
|
||||
return MemberResponse(data=member)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding team member: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{team_id}/members/{user_id}", response_model=MemberResponse)
|
||||
async def update_member_permission(
|
||||
team_id: str,
|
||||
user_id: str,
|
||||
permission_data: UpdateMemberPermissionRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update a team member's permission level.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
|
||||
Note: DB trigger auto-clears resource_permissions when downgraded from 'share' to 'read'
|
||||
"""
|
||||
logger.info(f"PUT /teams/{team_id}/members/{user_id} - Permission update request")
|
||||
logger.info(f"Request body: {permission_data.model_dump()}")
|
||||
logger.info(f"Current user: {current_user.get('email')}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
member = await service.update_member_permission(
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
new_permission=permission_data.team_permission
|
||||
)
|
||||
|
||||
return MemberResponse(data=member)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"PermissionError updating member permission: {str(e)}")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(f"ValueError updating member permission: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error(f"RuntimeError updating member permission: {str(e)}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating member permission: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/members/{user_id}", status_code=204)
|
||||
async def remove_team_member(
|
||||
team_id: str,
|
||||
user_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Remove a user from the team.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
"""
|
||||
logger.info(f"Removing member {user_id} from team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
success = await service.remove_member(team_id, user_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Member {user_id} not found in team {team_id}")
|
||||
|
||||
return None # 204 No Content
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing team member: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RESOURCE SHARING ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/{team_id}/share", status_code=201)
|
||||
async def share_resource_to_team(
|
||||
team_id: str,
|
||||
share_data: ShareResourceRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Share a resource (agent/dataset) to team with per-user permissions.
|
||||
|
||||
Requires: Team ownership or 'share' team permission
|
||||
|
||||
Request body:
|
||||
{
|
||||
"resource_type": "agent" | "dataset",
|
||||
"resource_id": "uuid",
|
||||
"user_permissions": {
|
||||
"user_uuid_1": "read",
|
||||
"user_uuid_2": "edit"
|
||||
}
|
||||
}
|
||||
|
||||
Resource Permissions:
|
||||
- 'read': View-only access to resource
|
||||
- 'edit': Full edit access to resource
|
||||
"""
|
||||
logger.info(f"Sharing {share_data.resource_type}:{share_data.resource_id} to team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
|
||||
# Use new junction table method (Phase 2)
|
||||
await service.share_resource_to_teams(
|
||||
resource_id=share_data.resource_id,
|
||||
resource_type=share_data.resource_type,
|
||||
shared_by=current_user["user_id"],
|
||||
team_shares=[{
|
||||
"team_id": team_id,
|
||||
"user_permissions": share_data.user_permissions
|
||||
}]
|
||||
)
|
||||
|
||||
return {"message": "Resource shared successfully", "success": True}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sharing resource: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/share/{resource_type}/{resource_id}", status_code=204)
|
||||
async def unshare_resource_from_team(
|
||||
team_id: str,
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Remove resource sharing from team (removes from all members' resource_permissions).
|
||||
|
||||
Requires: Team ownership or 'share' team permission
|
||||
"""
|
||||
logger.info(f"Unsharing {resource_type}:{resource_id} from team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
|
||||
# Use new junction table method (Phase 2)
|
||||
await service.unshare_resource_from_team(
|
||||
resource_id=resource_id,
|
||||
resource_type=resource_type,
|
||||
team_id=team_id
|
||||
)
|
||||
|
||||
return None # 204 No Content
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error unsharing resource: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{team_id}/resources", response_model=SharedResourcesResponse)
|
||||
async def list_shared_resources(
|
||||
team_id: str,
|
||||
resource_type: str = Query(None, description="Filter by resource type: 'agent' or 'dataset'"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all resources shared to a team.
|
||||
|
||||
Only accessible to team members or tenant admins.
|
||||
|
||||
Returns list of:
|
||||
{
|
||||
"resource_type": "agent" | "dataset",
|
||||
"resource_id": "uuid",
|
||||
"user_permissions": {"user_id": "read|edit", ...}
|
||||
}
|
||||
"""
|
||||
logger.info(f"Listing shared resources for team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
resources = await service.get_shared_resources(
|
||||
team_id=team_id,
|
||||
resource_type=resource_type
|
||||
)
|
||||
|
||||
return SharedResourcesResponse(
|
||||
data=resources,
|
||||
total=len(resources)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing shared resources: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@router.get("/{team_id}/invitations", response_model=InvitationListResponse)
|
||||
async def list_team_invitations(
|
||||
team_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get pending invitations for a team (owner view).
|
||||
|
||||
Shows all users who have been invited but haven't accepted yet.
|
||||
Requires team ownership or admin role.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
invitations = await service.get_team_pending_invitations(team_id)
|
||||
|
||||
return InvitationListResponse(
|
||||
data=invitations,
|
||||
total=len(invitations)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing team invitations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/invitations/{invitation_id}", status_code=204)
|
||||
async def cancel_team_invitation(
|
||||
team_id: str,
|
||||
invitation_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Cancel a pending invitation (owner only).
|
||||
|
||||
Removes the invitation before the user accepts it.
|
||||
Requires team ownership or admin role.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
await service.cancel_invitation(team_id, invitation_id)
|
||||
|
||||
return None
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error canceling invitation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Observable Member Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/{team_id}/members/{user_id}/request-observable", status_code=200)
|
||||
async def request_observable_access(
|
||||
team_id: str,
|
||||
user_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Request Observable access from a team member.
|
||||
|
||||
Sets Observable status to pending for the target user.
|
||||
Requires owner or manager permission.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
result = await service.request_observable_status(team_id, user_id)
|
||||
|
||||
return {"success": True, "data": result}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error requesting Observable access: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{team_id}/observable/approve", status_code=200)
|
||||
async def approve_observable_request(
|
||||
team_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Approve Observable status for current user in a team.
|
||||
|
||||
User explicitly consents to team managers viewing their activity.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
result = await service.approve_observable_consent(team_id)
|
||||
|
||||
return {"success": True, "data": result}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error approving Observable request: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/observable", status_code=200)
|
||||
async def revoke_observable_status(
|
||||
team_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Revoke Observable status for current user in a team.
|
||||
|
||||
Immediately removes manager access to user's activity data.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
result = await service.revoke_observable_status(team_id)
|
||||
|
||||
return {"success": True, "data": result}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking Observable status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{team_id}/activity", response_model=TeamActivityResponse)
|
||||
async def get_team_activity(
|
||||
team_id: str,
|
||||
days: int = Query(7, ge=1, le=365, description="Number of days to include in metrics"),
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get team activity metrics for Observable members.
|
||||
|
||||
Returns aggregated activity data for team members who have approved Observable status.
|
||||
Requires owner or manager permission.
|
||||
|
||||
Args:
|
||||
team_id: Team UUID
|
||||
days: Number of days to include (1-365, default 7)
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
activity = await service.get_team_activity(team_id, days)
|
||||
|
||||
return TeamActivityResponse(
|
||||
data=TeamActivityMetrics(**activity)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting team activity: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
348
apps/tenant-backend/app/api/v1/users.py
Normal file
348
apps/tenant-backend/app/api/v1/users.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
User API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles user preferences and favorite agents management.
|
||||
Follows GT 2.0 principles: no mocks, real implementations, fail fast.
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Header
|
||||
from typing import Optional
|
||||
|
||||
from app.services.user_service import UserService
|
||||
from app.schemas.user import (
|
||||
UserPreferencesResponse,
|
||||
UpdateUserPreferencesRequest,
|
||||
FavoriteAgentsResponse,
|
||||
UpdateFavoriteAgentsRequest,
|
||||
AddFavoriteAgentRequest,
|
||||
RemoveFavoriteAgentRequest,
|
||||
CustomCategoriesResponse,
|
||||
UpdateCustomCategoriesRequest
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
def get_user_context(
|
||||
x_tenant_domain: Optional[str] = Header(None),
|
||||
x_user_id: Optional[str] = Header(None),
|
||||
x_user_email: Optional[str] = Header(None)
|
||||
) -> tuple[str, str, str]:
|
||||
"""Extract user context from headers"""
|
||||
if not x_tenant_domain:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-Tenant-Domain header is required"
|
||||
)
|
||||
if not x_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-User-ID header is required"
|
||||
)
|
||||
|
||||
return x_tenant_domain, x_user_id, x_user_email or x_user_id
|
||||
|
||||
|
||||
# User Preferences Endpoints
|
||||
|
||||
@router.get("/me/preferences", response_model=UserPreferencesResponse)
|
||||
async def get_user_preferences(
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Get current user's preferences from PostgreSQL.
|
||||
|
||||
Returns all user preferences stored in the JSONB preferences column.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Getting user preferences", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
preferences = await service.get_user_preferences()
|
||||
|
||||
return UserPreferencesResponse(preferences=preferences)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user preferences", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve user preferences"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/preferences")
|
||||
async def update_user_preferences(
|
||||
request: UpdateUserPreferencesRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Update current user's preferences in PostgreSQL.
|
||||
|
||||
Merges provided preferences with existing preferences using JSONB || operator.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Updating user preferences", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.update_user_preferences(request.preferences)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {"success": True, "message": "Preferences updated successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update user preferences", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user preferences"
|
||||
)
|
||||
|
||||
|
||||
# Favorite Agents Endpoints
|
||||
|
||||
@router.get("/me/favorite-agents", response_model=FavoriteAgentsResponse)
|
||||
async def get_favorite_agents(
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Get current user's favorited agent IDs from PostgreSQL.
|
||||
|
||||
Returns list of agent UUIDs that the user has marked as favorites.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Getting favorite agent IDs", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
favorite_ids = await service.get_favorite_agent_ids()
|
||||
|
||||
return FavoriteAgentsResponse(favorite_agent_ids=favorite_ids)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get favorite agents", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve favorite agents"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/favorite-agents")
|
||||
async def update_favorite_agents(
|
||||
request: UpdateFavoriteAgentsRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Update current user's favorite agent IDs in PostgreSQL.
|
||||
|
||||
Replaces the entire list of favorite agent IDs with the provided list.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Updating favorite agent IDs",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
agent_count=len(request.agent_ids)
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.update_favorite_agent_ids(request.agent_ids)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Favorite agents updated successfully",
|
||||
"favorite_agent_ids": request.agent_ids
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update favorite agents", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update favorite agents"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/me/favorite-agents/add")
|
||||
async def add_favorite_agent(
|
||||
request: AddFavoriteAgentRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Add a single agent to user's favorites.
|
||||
|
||||
Idempotent - does nothing if agent is already in favorites.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Adding agent to favorites",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
agent_id=request.agent_id
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.add_favorite_agent(request.agent_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Agent added to favorites",
|
||||
"agent_id": request.agent_id
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to add favorite agent", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add favorite agent"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/me/favorite-agents/remove")
|
||||
async def remove_favorite_agent(
|
||||
request: RemoveFavoriteAgentRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Remove a single agent from user's favorites.
|
||||
|
||||
Idempotent - does nothing if agent is not in favorites.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Removing agent from favorites",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
agent_id=request.agent_id
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.remove_favorite_agent(request.agent_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Agent removed from favorites",
|
||||
"agent_id": request.agent_id
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to remove favorite agent", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to remove favorite agent"
|
||||
)
|
||||
|
||||
|
||||
# Custom Categories Endpoints
|
||||
|
||||
@router.get("/me/custom-categories", response_model=CustomCategoriesResponse)
|
||||
async def get_custom_categories(
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Get current user's custom agent categories from PostgreSQL.
|
||||
|
||||
Returns list of custom categories with name and description.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Getting custom categories", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
categories = await service.get_custom_categories()
|
||||
|
||||
return CustomCategoriesResponse(categories=categories)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get custom categories", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve custom categories"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/custom-categories")
|
||||
async def update_custom_categories(
|
||||
request: UpdateCustomCategoriesRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Update current user's custom agent categories in PostgreSQL.
|
||||
|
||||
Replaces the entire list of custom categories with the provided list.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Updating custom categories",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
category_count=len(request.categories)
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.update_custom_categories(request.categories)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Custom categories updated successfully",
|
||||
"categories": [cat.dict() for cat in request.categories]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update custom categories", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update custom categories"
|
||||
)
|
||||
376
apps/tenant-backend/app/api/v1/webhooks.py
Normal file
376
apps/tenant-backend/app/api/v1/webhooks.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
Webhook Receiver API
|
||||
|
||||
Handles incoming webhooks for automation triggers with security validation
|
||||
and rate limiting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import hmac
|
||||
import hashlib
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, BackgroundTasks, Depends, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.event_bus import TenantEventBus, TriggerType
|
||||
from app.services.automation_executor import AutomationChainExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||
|
||||
|
||||
class WebhookRegistration(BaseModel):
|
||||
"""Webhook registration model"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
secret: Optional[str] = None
|
||||
rate_limit: int = 60 # Requests per minute
|
||||
allowed_ips: Optional[list[str]] = None
|
||||
events_to_trigger: list[str] = []
|
||||
|
||||
|
||||
class WebhookPayload(BaseModel):
|
||||
"""Generic webhook payload"""
|
||||
event: Optional[str] = None
|
||||
data: Dict[str, Any] = {}
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
|
||||
# In-memory webhook registry (in production, use database)
|
||||
webhook_registry: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Rate limiting tracker (in production, use Redis)
|
||||
rate_limiter: Dict[str, list[datetime]] = {}
|
||||
|
||||
|
||||
def validate_webhook_registration(tenant_domain: str, webhook_id: str) -> Dict[str, Any]:
|
||||
"""Validate webhook is registered"""
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
|
||||
if key not in webhook_registry:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
|
||||
return webhook_registry[key]
|
||||
|
||||
|
||||
def check_rate_limit(key: str, limit: int) -> bool:
|
||||
"""Check if request is within rate limit"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get request history
|
||||
if key not in rate_limiter:
|
||||
rate_limiter[key] = []
|
||||
|
||||
# Remove old requests (older than 1 minute)
|
||||
rate_limiter[key] = [
|
||||
ts for ts in rate_limiter[key]
|
||||
if (now - ts).total_seconds() < 60
|
||||
]
|
||||
|
||||
# Check limit
|
||||
if len(rate_limiter[key]) >= limit:
|
||||
return False
|
||||
|
||||
# Add current request
|
||||
rate_limiter[key].append(now)
|
||||
return True
|
||||
|
||||
|
||||
def validate_hmac_signature(
|
||||
signature: str,
|
||||
body: bytes,
|
||||
secret: str
|
||||
) -> bool:
|
||||
"""Validate HMAC signature"""
|
||||
if not signature or not secret:
|
||||
return False
|
||||
|
||||
# Calculate expected signature
|
||||
expected = hmac.new(
|
||||
secret.encode(),
|
||||
body,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Compare signatures
|
||||
return hmac.compare_digest(signature, expected)
|
||||
|
||||
|
||||
def sanitize_webhook_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize webhook payload to prevent injection attacks"""
|
||||
# Remove potentially dangerous keys
|
||||
dangerous_keys = ["__proto__", "constructor", "prototype"]
|
||||
|
||||
def clean_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
cleaned = {}
|
||||
for key, value in d.items():
|
||||
if key not in dangerous_keys:
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = clean_dict(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [
|
||||
clean_dict(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
# Limit string length
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
cleaned[key] = value[:10000]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
return cleaned
|
||||
|
||||
return clean_dict(payload)
|
||||
|
||||
|
||||
@router.post("/{tenant_domain}/{webhook_id}")
|
||||
async def receive_webhook(
|
||||
tenant_domain: str,
|
||||
webhook_id: str,
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_webhook_signature: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
Receive webhook and trigger automations.
|
||||
|
||||
Note: In production, webhooks terminate at NGINX in DMZ, not directly here.
|
||||
Traffic flow: Internet → NGINX (DMZ) → OAuth2 Proxy → This endpoint
|
||||
"""
|
||||
try:
|
||||
# Validate webhook registration
|
||||
webhook = validate_webhook_registration(tenant_domain, webhook_id)
|
||||
|
||||
# Check rate limiting
|
||||
rate_key = f"webhook:{tenant_domain}:{webhook_id}"
|
||||
if not check_rate_limit(rate_key, webhook.get("rate_limit", 60)):
|
||||
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
|
||||
# Validate signature if configured
|
||||
if webhook.get("secret"):
|
||||
if not validate_hmac_signature(x_webhook_signature, body, webhook["secret"]):
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Check IP whitelist if configured
|
||||
client_ip = request.client.host
|
||||
allowed_ips = webhook.get("allowed_ips")
|
||||
if allowed_ips and client_ip not in allowed_ips:
|
||||
logger.warning(f"Webhook request from unauthorized IP: {client_ip}")
|
||||
raise HTTPException(status_code=403, detail="IP not authorized")
|
||||
|
||||
# Parse and sanitize payload
|
||||
try:
|
||||
payload = await request.json()
|
||||
payload = sanitize_webhook_payload(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid webhook payload: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid payload format")
|
||||
|
||||
# Queue for processing
|
||||
background_tasks.add_task(
|
||||
process_webhook_automation,
|
||||
tenant_domain=tenant_domain,
|
||||
webhook_id=webhook_id,
|
||||
payload=payload,
|
||||
webhook_config=webhook
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "accepted",
|
||||
"webhook_id": webhook_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
async def process_webhook_automation(
|
||||
tenant_domain: str,
|
||||
webhook_id: str,
|
||||
payload: Dict[str, Any],
|
||||
webhook_config: Dict[str, Any]
|
||||
):
|
||||
"""Process webhook and trigger associated automations"""
|
||||
try:
|
||||
# Initialize event bus
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Create webhook event
|
||||
event_type = payload.get("event", "webhook.received")
|
||||
|
||||
# Emit event to trigger automations
|
||||
await event_bus.emit_event(
|
||||
event_type=event_type,
|
||||
user_id=webhook_config.get("owner_id", "system"),
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"payload": payload,
|
||||
"source": webhook_config.get("name", "Unknown")
|
||||
},
|
||||
metadata={
|
||||
"trigger_type": TriggerType.WEBHOOK.value,
|
||||
"webhook_config": webhook_config
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Webhook processed: {webhook_id} → {event_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook automation: {e}")
|
||||
|
||||
# Emit failure event
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
await event_bus.emit_event(
|
||||
event_type="webhook.failed",
|
||||
user_id="system",
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/{tenant_domain}/register")
|
||||
async def register_webhook(
|
||||
tenant_domain: str,
|
||||
registration: WebhookRegistration,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""Register a new webhook endpoint"""
|
||||
import secrets
|
||||
|
||||
# Generate webhook ID
|
||||
webhook_id = secrets.token_urlsafe(16)
|
||||
|
||||
# Store registration
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
webhook_registry[key] = {
|
||||
"id": webhook_id,
|
||||
"name": registration.name,
|
||||
"description": registration.description,
|
||||
"secret": registration.secret or secrets.token_urlsafe(32),
|
||||
"rate_limit": registration.rate_limit,
|
||||
"allowed_ips": registration.allowed_ips,
|
||||
"events_to_trigger": registration.events_to_trigger,
|
||||
"owner_id": user_email,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"url": f"/webhooks/{tenant_domain}/{webhook_id}"
|
||||
}
|
||||
|
||||
return {
|
||||
"webhook_id": webhook_id,
|
||||
"url": f"/webhooks/{tenant_domain}/{webhook_id}",
|
||||
"secret": webhook_registry[key]["secret"],
|
||||
"created_at": webhook_registry[key]["created_at"]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{tenant_domain}/list")
|
||||
async def list_webhooks(
|
||||
tenant_domain: str,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""List registered webhooks for tenant"""
|
||||
webhooks = []
|
||||
|
||||
for key, webhook in webhook_registry.items():
|
||||
if key.startswith(f"{tenant_domain}:"):
|
||||
# Only show webhooks owned by user
|
||||
if webhook.get("owner_id") == user_email:
|
||||
webhooks.append({
|
||||
"id": webhook["id"],
|
||||
"name": webhook["name"],
|
||||
"description": webhook.get("description"),
|
||||
"url": webhook["url"],
|
||||
"rate_limit": webhook["rate_limit"],
|
||||
"created_at": webhook["created_at"]
|
||||
})
|
||||
|
||||
return {
|
||||
"webhooks": webhooks,
|
||||
"total": len(webhooks)
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{tenant_domain}/{webhook_id}")
|
||||
async def delete_webhook(
|
||||
tenant_domain: str,
|
||||
webhook_id: str,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""Delete a webhook registration"""
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
|
||||
if key not in webhook_registry:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
|
||||
webhook = webhook_registry[key]
|
||||
|
||||
# Check ownership
|
||||
if webhook.get("owner_id") != user_email:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Delete webhook
|
||||
del webhook_registry[key]
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"webhook_id": webhook_id
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{tenant_domain}/{webhook_id}/test")
|
||||
async def test_webhook(
|
||||
tenant_domain: str,
|
||||
webhook_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""Send a test payload to webhook"""
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
|
||||
if key not in webhook_registry:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
|
||||
webhook = webhook_registry[key]
|
||||
|
||||
# Check ownership
|
||||
if webhook.get("owner_id") != user_email:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Create test payload
|
||||
test_payload = {
|
||||
"event": "webhook.test",
|
||||
"data": {
|
||||
"message": "This is a test webhook",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Process webhook
|
||||
background_tasks.add_task(
|
||||
process_webhook_automation,
|
||||
tenant_domain=tenant_domain,
|
||||
webhook_id=webhook_id,
|
||||
payload=test_payload,
|
||||
webhook_config=webhook
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "test_sent",
|
||||
"webhook_id": webhook_id,
|
||||
"payload": test_payload
|
||||
}
|
||||
534
apps/tenant-backend/app/api/v1/workflows.py
Normal file
534
apps/tenant-backend/app/api/v1/workflows.py
Normal file
@@ -0,0 +1,534 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
from app.services.workflow_service import WorkflowService, WorkflowValidationError
|
||||
from app.models.workflow import WorkflowStatus, TriggerType, InteractionMode
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/workflows", tags=["workflows"])
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class WorkflowCreateRequest(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
definition: Dict[str, Any] = Field(..., description="Workflow nodes and edges")
|
||||
triggers: Optional[List[Dict[str, Any]]] = Field(default=[])
|
||||
interaction_modes: Optional[List[str]] = Field(default=["button"])
|
||||
api_key_ids: Optional[List[str]] = Field(default=[])
|
||||
webhook_ids: Optional[List[str]] = Field(default=[])
|
||||
dataset_ids: Optional[List[str]] = Field(default=[])
|
||||
integration_ids: Optional[List[str]] = Field(default=[])
|
||||
config: Optional[Dict[str, Any]] = Field(default={})
|
||||
timeout_seconds: Optional[int] = Field(default=300, ge=30, le=3600)
|
||||
max_retries: Optional[int] = Field(default=3, ge=0, le=10)
|
||||
|
||||
|
||||
class WorkflowUpdateRequest(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
definition: Optional[Dict[str, Any]] = None
|
||||
triggers: Optional[List[Dict[str, Any]]] = None
|
||||
interaction_modes: Optional[List[str]] = None
|
||||
status: Optional[WorkflowStatus] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
timeout_seconds: Optional[int] = Field(None, ge=30, le=3600)
|
||||
max_retries: Optional[int] = Field(None, ge=0, le=10)
|
||||
|
||||
|
||||
class WorkflowExecutionRequest(BaseModel):
|
||||
input_data: Dict[str, Any] = Field(default={})
|
||||
trigger_type: Optional[str] = Field(default="manual")
|
||||
interaction_mode: Optional[str] = Field(default="api")
|
||||
|
||||
|
||||
class WorkflowTriggerRequest(BaseModel):
|
||||
type: TriggerType
|
||||
config: Dict[str, Any] = Field(default={})
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
message: str = Field(..., min_length=1, max_length=10000)
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
class WorkflowResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
status: str
|
||||
definition: Dict[str, Any]
|
||||
interaction_modes: List[str]
|
||||
execution_count: int
|
||||
last_executed: Optional[str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class WorkflowExecutionResponse(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
progress_percentage: int
|
||||
current_node_id: Optional[str]
|
||||
input_data: Dict[str, Any]
|
||||
output_data: Dict[str, Any]
|
||||
error_details: Optional[str]
|
||||
started_at: str
|
||||
completed_at: Optional[str]
|
||||
duration_ms: Optional[int]
|
||||
tokens_used: int
|
||||
cost_cents: int
|
||||
interaction_mode: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# Workflow CRUD endpoints
|
||||
@router.post("/", response_model=WorkflowResponse)
|
||||
def create_workflow(
|
||||
workflow_data: WorkflowCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new workflow - temporary mock implementation"""
|
||||
try:
|
||||
# TODO: Implement proper PostgreSQL service integration
|
||||
# For now, return a mock workflow to avoid database integration issues
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
mock_workflow = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": workflow_data.name,
|
||||
"description": workflow_data.description or "",
|
||||
"status": "draft",
|
||||
"definition": workflow_data.definition,
|
||||
"interaction_modes": workflow_data.interaction_modes or ["button"],
|
||||
"execution_count": 0,
|
||||
"last_executed": None,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"updated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
return WorkflowResponse(**mock_workflow)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/", response_model=List[WorkflowResponse])
|
||||
def list_workflows(
|
||||
status: Optional[WorkflowStatus] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""List user's workflows - temporary mock implementation"""
|
||||
try:
|
||||
# TODO: Implement proper PostgreSQL service integration
|
||||
# For now, return empty list to avoid database integration issues
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list workflows: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{workflow_id}", response_model=WorkflowResponse)
|
||||
def get_workflow(
|
||||
workflow_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get workflow by ID"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
workflow = service.get_workflow(workflow_id, current_user["sub"])
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Convert to dict with proper datetime formatting
|
||||
workflow_dict = {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"description": workflow.description,
|
||||
"status": workflow.status,
|
||||
"definition": workflow.definition,
|
||||
"interaction_modes": workflow.interaction_modes,
|
||||
"execution_count": workflow.execution_count,
|
||||
"last_executed": workflow.last_executed.isoformat() if workflow.last_executed else None,
|
||||
"created_at": workflow.created_at.isoformat(),
|
||||
"updated_at": workflow.updated_at.isoformat()
|
||||
}
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/{workflow_id}", response_model=WorkflowResponse)
|
||||
def update_workflow(
|
||||
workflow_id: str,
|
||||
updates: WorkflowUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Update a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Filter out None values
|
||||
update_data = {k: v for k, v in updates.dict().items() if v is not None}
|
||||
|
||||
workflow = service.update_workflow(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
updates=update_data
|
||||
)
|
||||
|
||||
# Convert to dict with proper datetime formatting
|
||||
workflow_dict = {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"description": workflow.description,
|
||||
"status": workflow.status,
|
||||
"definition": workflow.definition,
|
||||
"interaction_modes": workflow.interaction_modes,
|
||||
"execution_count": workflow.execution_count,
|
||||
"last_executed": workflow.last_executed.isoformat() if workflow.last_executed else None,
|
||||
"created_at": workflow.created_at.isoformat(),
|
||||
"updated_at": workflow.updated_at.isoformat()
|
||||
}
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
except WorkflowValidationError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{workflow_id}")
|
||||
def delete_workflow(
|
||||
workflow_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
success = service.delete_workflow(workflow_id, current_user["sub"])
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
return {"message": "Workflow deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {str(e)}")
|
||||
|
||||
|
||||
# Workflow execution endpoints
|
||||
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
|
||||
async def execute_workflow(
|
||||
workflow_id: str,
|
||||
execution_data: WorkflowExecutionRequest,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Execute a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
execution = await service.execute_workflow(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
input_data=execution_data.input_data,
|
||||
trigger_type=execution_data.trigger_type,
|
||||
interaction_mode=execution_data.interaction_mode
|
||||
)
|
||||
|
||||
return WorkflowExecutionResponse.from_orm(execution)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to execute workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{workflow_id}/executions")
|
||||
async def list_workflow_executions(
|
||||
workflow_id: str,
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""List workflow executions"""
|
||||
try:
|
||||
# Verify workflow ownership first
|
||||
service = WorkflowService(db)
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Get executions (implementation would query WorkflowExecution table)
|
||||
# For now, return empty list
|
||||
return []
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list executions: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
|
||||
async def get_execution_status(
|
||||
execution_id: str,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get execution status"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
execution = await service.get_execution_status(execution_id, current_user["sub"])
|
||||
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
return WorkflowExecutionResponse.from_orm(execution)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get execution: {str(e)}")
|
||||
|
||||
|
||||
# Workflow trigger endpoints
|
||||
@router.post("/{workflow_id}/triggers")
|
||||
async def create_workflow_trigger(
|
||||
workflow_id: str,
|
||||
trigger_data: WorkflowTriggerRequest,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Create a trigger for a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Verify workflow ownership
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
trigger = await service.create_workflow_trigger(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
tenant_id=current_user["tenant_id"],
|
||||
trigger_config=trigger_data.dict()
|
||||
)
|
||||
|
||||
return {"id": trigger.id, "webhook_url": trigger.webhook_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create trigger: {str(e)}")
|
||||
|
||||
|
||||
# Chat interface endpoints
|
||||
@router.post("/{workflow_id}/chat/sessions")
|
||||
async def create_chat_session(
|
||||
workflow_id: str,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Create a chat session for workflow interaction"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Verify workflow ownership
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Check if workflow supports chat mode
|
||||
if "chat" not in workflow.interaction_modes:
|
||||
raise HTTPException(status_code=400, detail="Workflow does not support chat mode")
|
||||
|
||||
session = await service.create_chat_session(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
tenant_id=current_user["tenant_id"]
|
||||
)
|
||||
|
||||
return {"session_id": session.id, "expires_at": session.expires_at.isoformat()}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create chat session: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/chat/message")
|
||||
async def send_chat_message(
|
||||
workflow_id: str,
|
||||
message_data: ChatMessageRequest,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Send a message to workflow chat"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Create or get session
|
||||
session_id = message_data.session_id
|
||||
if not session_id:
|
||||
session = await service.create_chat_session(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
tenant_id=current_user["tenant_id"]
|
||||
)
|
||||
session_id = session.id
|
||||
|
||||
# Add user message
|
||||
user_message = await service.add_chat_message(
|
||||
session_id=session_id,
|
||||
user_id=current_user["sub"],
|
||||
role="user",
|
||||
content=message_data.message
|
||||
)
|
||||
|
||||
# Execute workflow with message as input
|
||||
execution = await service.execute_workflow(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
input_data={"message": message_data.message},
|
||||
trigger_type="chat",
|
||||
interaction_mode="chat"
|
||||
)
|
||||
|
||||
# Add agent response (in real implementation, this would come from workflow execution)
|
||||
assistant_response = execution.output_data.get('result', 'Workflow response')
|
||||
|
||||
assistant_message = await service.add_chat_message(
|
||||
session_id=session_id,
|
||||
user_id=current_user["sub"],
|
||||
role="agent",
|
||||
content=assistant_response,
|
||||
execution_id=execution.id
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"user_message": {
|
||||
"id": user_message.id,
|
||||
"content": user_message.content,
|
||||
"timestamp": user_message.created_at.isoformat()
|
||||
},
|
||||
"assistant_message": {
|
||||
"id": assistant_message.id,
|
||||
"content": assistant_message.content,
|
||||
"timestamp": assistant_message.created_at.isoformat()
|
||||
},
|
||||
"execution": {
|
||||
"id": execution.id,
|
||||
"status": execution.status
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to process chat message: {str(e)}")
|
||||
|
||||
|
||||
# Workflow interface generation endpoints
|
||||
@router.get("/{workflow_id}/interface/{mode}")
|
||||
async def get_workflow_interface(
|
||||
workflow_id: str,
|
||||
mode: InteractionMode,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get workflow interface configuration for specified mode"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
if mode not in workflow.interaction_modes:
|
||||
raise HTTPException(status_code=400, detail=f"Workflow does not support {mode} mode")
|
||||
|
||||
# Generate interface configuration based on mode
|
||||
from app.models.workflow import INTERACTION_MODE_CONFIGS
|
||||
|
||||
interface_config = INTERACTION_MODE_CONFIGS.get(mode, {})
|
||||
|
||||
# Customize based on workflow definition
|
||||
if mode == "form":
|
||||
# Generate form fields from workflow inputs
|
||||
trigger_nodes = [n for n in workflow.definition.get('nodes', []) if n.get('type') == 'trigger']
|
||||
form_fields = []
|
||||
|
||||
for node in trigger_nodes:
|
||||
if node.get('data', {}).get('input_schema'):
|
||||
form_fields.extend(node['data']['input_schema'])
|
||||
|
||||
interface_config['form_fields'] = form_fields
|
||||
|
||||
elif mode == "button":
|
||||
# Simple button configuration
|
||||
interface_config['button_text'] = f"Execute {workflow.name}"
|
||||
interface_config['description'] = workflow.description
|
||||
|
||||
elif mode == "dashboard":
|
||||
# Dashboard metrics configuration
|
||||
interface_config['metrics'] = {
|
||||
'execution_count': workflow.execution_count,
|
||||
'total_cost': workflow.total_cost_cents / 100,
|
||||
'avg_execution_time': workflow.average_execution_time_ms,
|
||||
'status': workflow.status
|
||||
}
|
||||
|
||||
return {
|
||||
'workflow_id': workflow_id,
|
||||
'mode': mode,
|
||||
'config': interface_config,
|
||||
'workflow': {
|
||||
'name': workflow.name,
|
||||
'description': workflow.description,
|
||||
'status': workflow.status
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get interface: {str(e)}")
|
||||
|
||||
|
||||
# Node type information endpoints
|
||||
@router.get("/node-types")
|
||||
async def get_workflow_node_types():
|
||||
"""Get available workflow node types and their configurations"""
|
||||
from app.models.workflow import WORKFLOW_NODE_TYPES
|
||||
return WORKFLOW_NODE_TYPES
|
||||
|
||||
|
||||
@router.get("/interaction-modes")
|
||||
async def get_interaction_modes():
|
||||
"""Get available interaction modes and their configurations"""
|
||||
from app.models.workflow import INTERACTION_MODE_CONFIGS
|
||||
return INTERACTION_MODE_CONFIGS
|
||||
395
apps/tenant-backend/app/api/websocket.py
Normal file
395
apps/tenant-backend/app/api/websocket.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
WebSocket API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Provides secure WebSocket connections for real-time chat with:
|
||||
- JWT authentication
|
||||
- Perfect tenant isolation
|
||||
- Conversation-based messaging
|
||||
- Automatic cleanup on disconnect
|
||||
|
||||
GT 2.0 Security Features:
|
||||
- Token-based authentication
|
||||
- Rate limiting per user
|
||||
- Connection limits per tenant
|
||||
- Automatic session cleanup
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db_session
|
||||
from app.core.security import get_current_user_email, get_tenant_info
|
||||
from app.websocket import (
|
||||
websocket_manager,
|
||||
get_websocket_manager,
|
||||
authenticate_websocket_connection,
|
||||
WebSocketManager
|
||||
)
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
@router.websocket("/chat/{conversation_id}")
|
||||
async def websocket_chat_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str,
|
||||
token: str = Query(..., description="JWT authentication token")
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time chat in a specific conversation.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
conversation_id: Conversation ID to join
|
||||
token: JWT authentication token
|
||||
"""
|
||||
connection_id = None
|
||||
|
||||
try:
|
||||
# Authenticate connection
|
||||
try:
|
||||
user_id, tenant_id = await authenticate_websocket_connection(token)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=1008, reason=f"Authentication failed: {e}")
|
||||
return
|
||||
|
||||
# Verify conversation access
|
||||
async with get_db_session() as db:
|
||||
conversation_service = ConversationService(db)
|
||||
conversation = await conversation_service.get_conversation(conversation_id, user_id)
|
||||
|
||||
if not conversation:
|
||||
await websocket.close(code=1008, reason="Conversation not found or access denied")
|
||||
return
|
||||
|
||||
# Establish WebSocket connection
|
||||
manager = get_websocket_manager()
|
||||
connection_id = await manager.connect(
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
logger.info(f"WebSocket chat connection established: {connection_id} for conversation {conversation_id}")
|
||||
|
||||
# Message handling loop
|
||||
while True:
|
||||
try:
|
||||
# Receive message
|
||||
data = await websocket.receive_text()
|
||||
message_data = json.loads(data)
|
||||
|
||||
# Handle message
|
||||
success = await manager.handle_message(connection_id, message_data)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to handle message from {connection_id}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected: {connection_id}")
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON received from {connection_id}")
|
||||
await manager.send_to_connection(connection_id, {
|
||||
"type": "error",
|
||||
"message": "Invalid JSON format",
|
||||
"code": "INVALID_JSON"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket message loop: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket chat endpoint: {e}")
|
||||
try:
|
||||
await websocket.close(code=1011, reason=f"Server error: {e}")
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Cleanup connection
|
||||
if connection_id:
|
||||
await manager.disconnect(connection_id, reason="Connection closed")
|
||||
|
||||
|
||||
@router.websocket("/general")
|
||||
async def websocket_general_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="JWT authentication token")
|
||||
):
|
||||
"""
|
||||
General WebSocket endpoint for notifications and system messages.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
token: JWT authentication token
|
||||
"""
|
||||
connection_id = None
|
||||
|
||||
try:
|
||||
# Authenticate connection
|
||||
try:
|
||||
user_id, tenant_id = await authenticate_websocket_connection(token)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=1008, reason=f"Authentication failed: {e}")
|
||||
return
|
||||
|
||||
# Establish WebSocket connection (no specific conversation)
|
||||
manager = get_websocket_manager()
|
||||
connection_id = await manager.connect(
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None
|
||||
)
|
||||
|
||||
logger.info(f"General WebSocket connection established: {connection_id}")
|
||||
|
||||
# Message handling loop
|
||||
while True:
|
||||
try:
|
||||
# Receive message
|
||||
data = await websocket.receive_text()
|
||||
message_data = json.loads(data)
|
||||
|
||||
# Handle message
|
||||
success = await manager.handle_message(connection_id, message_data)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to handle message from {connection_id}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"General WebSocket disconnected: {connection_id}")
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON received from {connection_id}")
|
||||
await manager.send_to_connection(connection_id, {
|
||||
"type": "error",
|
||||
"message": "Invalid JSON format",
|
||||
"code": "INVALID_JSON"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in general WebSocket message loop: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in general WebSocket endpoint: {e}")
|
||||
try:
|
||||
await websocket.close(code=1011, reason=f"Server error: {e}")
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Cleanup connection
|
||||
if connection_id:
|
||||
await manager.disconnect(connection_id, reason="Connection closed")
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_websocket_stats(
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info)
|
||||
):
|
||||
"""Get WebSocket connection statistics for tenant"""
|
||||
try:
|
||||
manager = get_websocket_manager()
|
||||
stats = manager.get_connection_stats()
|
||||
|
||||
# Filter stats for current tenant
|
||||
tenant_id = tenant_info["tenant_id"]
|
||||
tenant_stats = {
|
||||
"total_connections": stats["connections_by_tenant"].get(tenant_id, 0),
|
||||
"active_conversations": len([
|
||||
conv_id for conv_id, connections in manager.conversation_connections.items()
|
||||
if any(
|
||||
manager.connections.get(conn_id, {}).tenant_id == tenant_id
|
||||
for conn_id in connections
|
||||
)
|
||||
]),
|
||||
"user_connections": stats["connections_by_user"].get(current_user, 0)
|
||||
}
|
||||
|
||||
return JSONResponse(content=tenant_stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get WebSocket stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/broadcast/tenant")
|
||||
async def broadcast_to_tenant(
|
||||
message: dict,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Broadcast message to all connections in tenant.
|
||||
(Admin/system use only)
|
||||
"""
|
||||
try:
|
||||
# Check if user has admin permissions
|
||||
# TODO: Implement proper admin role checking
|
||||
|
||||
manager = get_websocket_manager()
|
||||
tenant_id = tenant_info["tenant_id"]
|
||||
|
||||
broadcast_message = {
|
||||
"type": "system_broadcast",
|
||||
"message": message.get("content", ""),
|
||||
"timestamp": message.get("timestamp"),
|
||||
"sender": "system"
|
||||
}
|
||||
|
||||
await manager.broadcast_to_tenant(tenant_id, broadcast_message)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Broadcast sent successfully",
|
||||
"recipients": len(manager.tenant_connections.get(tenant_id, set()))
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to broadcast to tenant: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/broadcast")
|
||||
async def broadcast_to_conversation(
|
||||
conversation_id: str,
|
||||
message: dict,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Broadcast message to all participants in a conversation.
|
||||
"""
|
||||
try:
|
||||
# Verify user has access to conversation
|
||||
conversation_service = ConversationService(db)
|
||||
conversation = await conversation_service.get_conversation(conversation_id, current_user)
|
||||
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
manager = get_websocket_manager()
|
||||
|
||||
broadcast_message = {
|
||||
"type": "conversation_broadcast",
|
||||
"conversation_id": conversation_id,
|
||||
"message": message.get("content", ""),
|
||||
"timestamp": message.get("timestamp"),
|
||||
"sender": current_user
|
||||
}
|
||||
|
||||
await manager.broadcast_to_conversation(conversation_id, broadcast_message)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Broadcast sent successfully",
|
||||
"conversation_id": conversation_id,
|
||||
"recipients": len(manager.conversation_connections.get(conversation_id, set()))
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to broadcast to conversation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/connections/{connection_id}/send")
|
||||
async def send_to_connection(
|
||||
connection_id: str,
|
||||
message: dict,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info)
|
||||
):
|
||||
"""
|
||||
Send message to specific connection.
|
||||
(Admin/system use only)
|
||||
"""
|
||||
try:
|
||||
# Check if user has admin permissions or owns the connection
|
||||
manager = get_websocket_manager()
|
||||
connection = manager.connections.get(connection_id)
|
||||
|
||||
if not connection:
|
||||
raise HTTPException(status_code=404, detail="Connection not found")
|
||||
|
||||
# Verify tenant access
|
||||
if connection.tenant_id != tenant_info["tenant_id"]:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# TODO: Add admin role check or connection ownership verification
|
||||
|
||||
target_message = {
|
||||
"type": "direct_message",
|
||||
"message": message.get("content", ""),
|
||||
"timestamp": message.get("timestamp"),
|
||||
"sender": current_user
|
||||
}
|
||||
|
||||
success = await manager.send_to_connection(connection_id, target_message)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=410, detail="Connection no longer active")
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Message sent successfully",
|
||||
"connection_id": connection_id
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to connection: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/connections/{connection_id}")
|
||||
async def disconnect_connection(
|
||||
connection_id: str,
|
||||
reason: str = Query("Admin disconnect", description="Disconnect reason"),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info)
|
||||
):
|
||||
"""
|
||||
Forcefully disconnect a WebSocket connection.
|
||||
(Admin use only)
|
||||
"""
|
||||
try:
|
||||
# Check if user has admin permissions
|
||||
# TODO: Implement proper admin role checking
|
||||
|
||||
manager = get_websocket_manager()
|
||||
connection = manager.connections.get(connection_id)
|
||||
|
||||
if not connection:
|
||||
raise HTTPException(status_code=404, detail="Connection not found")
|
||||
|
||||
# Verify tenant access
|
||||
if connection.tenant_id != tenant_info["tenant_id"]:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
await manager.disconnect(connection_id, code=1008, reason=reason)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Connection disconnected successfully",
|
||||
"connection_id": connection_id,
|
||||
"reason": reason
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disconnect connection: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user