GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents
- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
- Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
5
apps/tenant-backend/app/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
GT 2.0 Tenant Backend API Module

FastAPI routers and endpoints for tenant-specific functionality.
"""
757
apps/tenant-backend/app/api/auth.py
Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Authentication API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles JWT authentication via Control Panel Backend.
|
||||
No mocks - following GT 2.0 philosophy of building on real foundations.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header, Request, Response, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from pydantic import BaseModel, EmailStr
|
||||
import jwt
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.security import create_capability_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["authentication"])
|
||||
security = HTTPBearer(auto_error=False)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# Development authentication function will be defined after class definitions
|
||||
|
||||
|
||||
# Pydantic models
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
user: dict
|
||||
tenant: Optional[dict] = None
|
||||
|
||||
|
||||
class TFASetupResponse(BaseModel):
|
||||
"""Response when TFA is enforced but not yet configured
|
||||
Session data (QR code, temp token) stored server-side in HTTP-only cookie"""
|
||||
requires_tfa: bool = True
|
||||
tfa_configured: bool = False
|
||||
|
||||
|
||||
class TFAVerificationResponse(BaseModel):
|
||||
"""Response when TFA is configured and verification is required
|
||||
Session data (temp token) stored server-side in HTTP-only cookie"""
|
||||
requires_tfa: bool = True
|
||||
tfa_configured: bool = True
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
full_name: str
|
||||
user_type: str
|
||||
tenant_id: Optional[int]
|
||||
capabilities: list
|
||||
is_active: bool
|
||||
|
||||
|
||||
# No development authentication function - violates No Mocks principle
|
||||
# All authentication MUST go through Control Panel Backend
|
||||
|
||||
|
||||
async def get_tenant_user_uuid_by_email(email: str) -> Optional[str]:
|
||||
"""
|
||||
Query tenant database to get user UUID by email.
|
||||
This maps Control Panel users to tenant-specific UUIDs for resource access.
|
||||
"""
|
||||
try:
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
client = await get_postgresql_client()
|
||||
if not client or not client._initialized:
|
||||
logger.warning("PostgreSQL client not initialized, cannot query tenant user")
|
||||
return None
|
||||
|
||||
# Query tenant schema for user by email
|
||||
query = f"""
|
||||
SELECT id FROM {client.schema_name}.users
|
||||
WHERE email = $1
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
async with client._pool.acquire() as conn:
|
||||
user_uuid = await conn.fetchval(query, email)
|
||||
|
||||
if user_uuid:
|
||||
logger.info(f"Found tenant user UUID {user_uuid} for email {email}")
|
||||
return str(user_uuid)
|
||||
else:
|
||||
logger.warning(f"No tenant user found for email {email}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying tenant user by email {email}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def verify_token_with_control_panel(token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify JWT token with Control Panel Backend.
|
||||
This ensures consistency across the entire GT 2.0 platform.
|
||||
No fallbacks - if Control Panel is unavailable, authentication fails.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/verify-token",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get("success") and data.get("data", {}).get("valid"):
|
||||
return data["data"]
|
||||
|
||||
return {"valid": False, "error": "Invalid token"}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Control Panel unavailable for token verification: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
authorization: HTTPAuthorizationCredentials = Depends(security)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract current user from JWT token.
|
||||
Validates with Control Panel for consistency.
|
||||
No fallbacks - authentication is required.
|
||||
"""
|
||||
if not authorization or not authorization.credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = authorization.credentials
|
||||
validation = await verify_token_with_control_panel(token)
|
||||
|
||||
if not validation.get("valid"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=validation.get("error", "Invalid authentication token"),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user_data = validation.get("user", {})
|
||||
user_type = user_data.get("user_type", "")
|
||||
|
||||
# For super_admin users, allow access to any tenant backend
|
||||
# They will assume the tenant context of the backend they're accessing
|
||||
if user_type == "super_admin":
|
||||
logger.info(f"Super admin user {user_data.get('email')} accessing tenant backend {settings.tenant_id}")
|
||||
|
||||
# Override user data with tenant backend context and admin capabilities
|
||||
user_data.update({
|
||||
'tenant_id': settings.tenant_id,
|
||||
'tenant_domain': settings.tenant_domain,
|
||||
'tenant_name': f'Tenant {settings.tenant_domain}',
|
||||
'tenant_role': 'super_admin',
|
||||
'capabilities': [
|
||||
{'resource': '*', 'actions': ['*'], 'constraints': {}}
|
||||
]
|
||||
})
|
||||
# Tenant ID validation removed - any authenticated Control Panel user can access any tenant
|
||||
|
||||
return user_data
|
||||
|
||||
|
||||
@router.post("/auth/login", response_model=Union[LoginResponse, TFASetupResponse, TFAVerificationResponse])
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
request: Request,
|
||||
response: Response
|
||||
):
|
||||
"""
|
||||
Authenticate user via Control Panel Backend.
|
||||
This ensures a single source of truth for user authentication.
|
||||
No fallbacks - if Control Panel is unavailable, login fails.
|
||||
"""
|
||||
try:
|
||||
# Forward authentication request to Control Panel
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/login",
|
||||
json={
|
||||
"email": login_data.email,
|
||||
"password": login_data.password
|
||||
},
|
||||
headers={
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown"),
|
||||
"X-App-Type": "tenant_app" # Distinguish from control_panel sessions
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if cp_response.status_code == 200:
|
||||
data = cp_response.json()
|
||||
|
||||
# Forward Set-Cookie headers from Control Panel to client
|
||||
if "set-cookie" in cp_response.headers:
|
||||
response.headers["set-cookie"] = cp_response.headers["set-cookie"]
|
||||
|
||||
# Check if this is a TFA response (setup or verification)
|
||||
if data.get("requires_tfa"):
|
||||
logger.info(
|
||||
f"TFA required for user {data.get('user_email')}: "
|
||||
f"configured={data.get('tfa_configured')}"
|
||||
)
|
||||
|
||||
# Return TFA response directly without modification
|
||||
if data.get("tfa_configured"):
|
||||
# TFA verification required
|
||||
return TFAVerificationResponse(**data)
|
||||
else:
|
||||
# TFA setup required
|
||||
return TFASetupResponse(**data)
|
||||
|
||||
# Handle normal login response (no TFA required)
|
||||
user = data.get("user", {})
|
||||
user_type = user.get("user_type", "")
|
||||
|
||||
# For super_admin users, allow access to any tenant backend
|
||||
# They will assume the tenant context of the backend they're accessing
|
||||
if user_type == "super_admin":
|
||||
logger.info(f"Super admin user {user.get('email')} accessing tenant backend {settings.tenant_id}")
|
||||
# Admin users can access any tenant backend - no tenant validation needed
|
||||
# Tenant ID validation removed - any authenticated Control Panel user can access any tenant
|
||||
|
||||
logger.info(
|
||||
f"User login successful: {user.get('email')} (ID: {user.get('id')})"
|
||||
)
|
||||
|
||||
# Use the original Control Panel JWT token - do not replace with tenant UUID token
|
||||
# UUID mapping will be handled at the service level when accessing tenant resources
|
||||
access_token = data["access_token"]
|
||||
logger.info(f"Using original Control Panel JWT for user {user.get('email')}")
|
||||
|
||||
# Get user's role from tenant database for frontend
|
||||
from app.core.permissions import get_user_role
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
tenant_role = await get_user_role(pg_client, user.get('email'), settings.tenant_domain)
|
||||
|
||||
# Add tenant role to user object for frontend
|
||||
user['role'] = tenant_role
|
||||
logger.info(f"Added tenant role '{tenant_role}' to user {user.get('email')}")
|
||||
|
||||
# Create tenant context for frontend
|
||||
tenant_info = {
|
||||
"id": settings.tenant_id,
|
||||
"domain": settings.tenant_domain,
|
||||
"name": f"Tenant {settings.tenant_domain}"
|
||||
}
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
expires_in=data.get("expires_in", 86400),
|
||||
user=user,
|
||||
tenant=tenant_info
|
||||
)
|
||||
|
||||
elif response.status_code == 401:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Control Panel login failed: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Authentication service error"
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Control Panel connection failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Login error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/refresh")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Refresh authentication token.
|
||||
For now, returns the same token (Control Panel tokens have 24hr expiry).
|
||||
"""
|
||||
# Get current token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# In a production system, we'd generate a new token here
|
||||
# For now, return the existing token since Control Panel tokens last 24 hours
|
||||
return {
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 86400, # 24 hours
|
||||
"user": current_user
|
||||
}
|
||||
|
||||
|
||||
@router.post("/auth/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Logout user (forward to Control Panel for audit logging).
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
if token:
|
||||
# Forward logout to Control Panel for audit logging
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/auth/logout",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(
|
||||
f"User logout successful: {current_user.get('email')} (ID: {current_user.get('id')})"
|
||||
)
|
||||
|
||||
# Always return success for logout
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Logout error: {str(e)}")
|
||||
# Always return success for logout
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.get("/auth/me")
|
||||
async def get_current_user_info(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get current user information.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"data": current_user
|
||||
}
|
||||
|
||||
|
||||
@router.get("/auth/verify")
|
||||
async def verify_token(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Verify if token is valid.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"valid": True,
|
||||
"user": current_user
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PASSWORD RESET PROXY ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/auth/request-password-reset")
|
||||
async def request_password_reset_proxy(
|
||||
data: dict,
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
Proxy password reset request to Control Panel Backend.
|
||||
Forwards client IP for rate limiting.
|
||||
"""
|
||||
try:
|
||||
# Get client IP for rate limiting
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/auth/request-password-reset",
|
||||
json=data,
|
||||
headers={
|
||||
"X-Forwarded-For": client_ip,
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to forward password reset request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Password reset service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Password reset request error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password reset request failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/reset-password")
|
||||
async def reset_password_proxy(data: dict):
|
||||
"""
|
||||
Proxy password reset to Control Panel Backend.
|
||||
"""
|
||||
try:
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/auth/reset-password",
|
||||
json=data,
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Return response with original status code
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response.json().get("detail", "Password reset failed")
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to forward password reset: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Password reset service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Password reset error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password reset failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/auth/verify-reset-token")
|
||||
async def verify_reset_token_proxy(token: str):
|
||||
"""
|
||||
Proxy token verification to Control Panel Backend.
|
||||
"""
|
||||
try:
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/auth/verify-reset-token",
|
||||
params={"token": token},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to verify reset token: {str(e)}")
|
||||
return {"valid": False, "error": "Token verification service unavailable"}
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {str(e)}")
|
||||
return {"valid": False, "error": "Token verification failed"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TFA PROXY ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/auth/tfa/session-data")
|
||||
async def get_tfa_session_data(request: Request, response: Response):
|
||||
"""
|
||||
Proxy TFA session data request to Control Panel Backend.
|
||||
Forwards cookies for session validation.
|
||||
"""
|
||||
try:
|
||||
# Get cookies from request
|
||||
cookie_header = request.headers.get("cookie", "")
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/session-data",
|
||||
headers={
|
||||
"Cookie": cookie_header,
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if cp_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=cp_response.status_code,
|
||||
detail=cp_response.json().get("detail", "Failed to get TFA session data")
|
||||
)
|
||||
|
||||
return cp_response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get TFA session data: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/auth/tfa/session-qr-code")
|
||||
async def get_tfa_session_qr_code(request: Request):
|
||||
"""
|
||||
Proxy TFA QR code blob request to Control Panel Backend.
|
||||
Forwards cookies for session validation.
|
||||
Returns PNG image blob (never exposes TOTP secret to JavaScript).
|
||||
"""
|
||||
try:
|
||||
# Get cookies from request
|
||||
cookie_header = request.headers.get("cookie", "")
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/session-qr-code",
|
||||
headers={
|
||||
"Cookie": cookie_header,
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if cp_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=cp_response.status_code,
|
||||
detail="Failed to get TFA QR code"
|
||||
)
|
||||
|
||||
# Return raw PNG bytes with image/png content type
|
||||
from fastapi.responses import Response
|
||||
return Response(
|
||||
content=cp_response.content,
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Cache-Control": "no-store, no-cache, must-revalidate",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0"
|
||||
}
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get TFA QR code: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"TFA session data error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get TFA session data"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/tfa/verify-login")
|
||||
async def verify_tfa_login_proxy(
|
||||
data: dict,
|
||||
request: Request,
|
||||
response: Response
|
||||
):
|
||||
"""
|
||||
Proxy TFA verification to Control Panel Backend.
|
||||
Forwards cookies for session validation.
|
||||
"""
|
||||
try:
|
||||
# Get cookies from request
|
||||
cookie_header = request.headers.get("cookie", "")
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
async with httpx.AsyncClient() as client:
|
||||
cp_response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/tfa/verify-login",
|
||||
json=data,
|
||||
headers={
|
||||
"Cookie": cookie_header,
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Forward Set-Cookie headers (cookie deletion after verification)
|
||||
if "set-cookie" in cp_response.headers:
|
||||
response.headers["set-cookie"] = cp_response.headers["set-cookie"]
|
||||
|
||||
if cp_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=cp_response.status_code,
|
||||
detail=cp_response.json().get("detail", "TFA verification failed")
|
||||
)
|
||||
|
||||
return cp_response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to verify TFA: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TFA service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"TFA verification error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="TFA verification failed"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SESSION STATUS ENDPOINT (Issue #264)
|
||||
# ============================================================================
|
||||
|
||||
class SessionStatusResponse(BaseModel):
|
||||
"""Response for session status check"""
|
||||
is_valid: bool
|
||||
seconds_remaining: int # Seconds until idle timeout
|
||||
show_warning: bool # True if < 5 minutes remaining
|
||||
absolute_seconds_remaining: Optional[int] = None # Seconds until absolute timeout
|
||||
|
||||
|
||||
@router.get("/auth/session/status", response_model=SessionStatusResponse)
|
||||
async def get_session_status(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get current session status for frontend session monitoring.
|
||||
|
||||
Proxies request to Control Panel Backend which is the authoritative
|
||||
source for session state. This endpoint replaces the complex react-idle-timer
|
||||
approach with a simple polling mechanism.
|
||||
|
||||
Frontend calls this every 60 seconds to check session health.
|
||||
|
||||
Returns:
|
||||
- is_valid: Whether session is currently valid
|
||||
- seconds_remaining: Seconds until idle timeout (30 min from last activity)
|
||||
- show_warning: True if warning should be shown (< 5 min remaining)
|
||||
- absolute_seconds_remaining: Seconds until absolute timeout (8 hours from login)
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No token provided"
|
||||
)
|
||||
|
||||
# Forward to Control Panel Backend (authoritative session source)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.control_panel_url}/api/v1/session/status",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=5.0 # Short timeout for health check
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return SessionStatusResponse(
|
||||
is_valid=data.get("is_valid", True),
|
||||
seconds_remaining=data.get("seconds_remaining", 1800),
|
||||
show_warning=data.get("show_warning", False),
|
||||
absolute_seconds_remaining=data.get("absolute_seconds_remaining")
|
||||
)
|
||||
|
||||
elif response.status_code == 401:
|
||||
# Session expired - return invalid status
|
||||
return SessionStatusResponse(
|
||||
is_valid=False,
|
||||
seconds_remaining=0,
|
||||
show_warning=False,
|
||||
absolute_seconds_remaining=None
|
||||
)
|
||||
|
||||
else:
|
||||
# Unexpected response - return safe defaults
|
||||
logger.warning(
|
||||
f"Unexpected session status response: {response.status_code}"
|
||||
)
|
||||
return SessionStatusResponse(
|
||||
is_valid=True,
|
||||
seconds_remaining=1800, # 30 minutes default
|
||||
show_warning=False,
|
||||
absolute_seconds_remaining=None
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
# Control Panel unavailable - FAIL CLOSED for security
|
||||
logger.error(f"Session status check failed - Control Panel unavailable: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session validation service unavailable"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Session status proxy error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session status check failed"
|
||||
)
|
||||
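For reference, a minimal client-side sketch of the login flow exposed by auth.py above. BASE_URL is a placeholder and the router is assumed to be mounted at the service root; adjust the path prefix to match the actual deployment.

# Sketch only: BASE_URL is hypothetical, not part of this commit.
import httpx

BASE_URL = "http://localhost:8000"  # placeholder tenant-backend address

async def login(email: str, password: str) -> dict:
    async with httpx.AsyncClient(base_url=BASE_URL) as client:
        resp = await client.post("/auth/login", json={"email": email, "password": password})
        resp.raise_for_status()
        data = resp.json()
        if data.get("requires_tfa"):
            # TFASetupResponse / TFAVerificationResponse: the temp session lives in an
            # HTTP-only cookie, so the caller only learns whether TFA is configured.
            return data
        # LoginResponse: bearer access_token plus user and tenant context
        return data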
144
apps/tenant-backend/app/api/conversations.py
Normal file
@@ -0,0 +1,144 @@
"""
Conversation API endpoints for GT 2.0 Tenant Backend

Handles conversation management for AI chat sessions.
"""

from fastapi import APIRouter, HTTPException, Depends, Query
from fastapi.responses import JSONResponse
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field

from app.services.conversation_service import ConversationService
from app.core.database import get_db
from app.api.auth import get_current_user
from sqlalchemy.ext.asyncio import AsyncSession

router = APIRouter(tags=["conversations"])


class CreateConversationRequest(BaseModel):
    """Request model for creating a conversation"""
    agent_id: str = Field(..., description="Agent ID to use for conversation")
    title: Optional[str] = Field(None, description="Optional conversation title")


class MessageRequest(BaseModel):
    """Request model for sending a message"""
    content: str = Field(..., description="Message content")
    stream: bool = Field(default=False, description="Stream the response")


async def get_conversation_service(db: AsyncSession = Depends(get_db)) -> ConversationService:
    """Get conversation service instance"""
    return ConversationService(db)


@router.get("/conversations")
async def list_conversations(
    agent_id: Optional[str] = Query(None, description="Filter by agent ID"),
    limit: int = Query(20, ge=1, le=100, description="Number of results"),
    offset: int = Query(0, ge=0, description="Pagination offset"),
    service: ConversationService = Depends(get_conversation_service),
    current_user: dict = Depends(get_current_user)
) -> JSONResponse:
    """List user's conversations"""
    try:
        # Extract email from user object
        user_email = current_user.get("email") if isinstance(current_user, dict) else current_user
        result = await service.list_conversations(
            user_identifier=user_email,
            agent_id=agent_id,
            limit=limit,
            offset=offset
        )
        return JSONResponse(status_code=200, content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/conversations")
async def create_conversation(
    request: CreateConversationRequest,
    service: ConversationService = Depends(get_conversation_service),
    current_user: dict = Depends(get_current_user)
) -> JSONResponse:
    """Create new conversation"""
    try:
        # Extract email from user object
        user_email = current_user.get("email") if isinstance(current_user, dict) else current_user
        result = await service.create_conversation(
            agent_id=request.agent_id,
            title=request.title,
            user_identifier=user_email
        )
        return JSONResponse(status_code=201, content=result)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/conversations/{conversation_id}")
async def get_conversation(
    conversation_id: int,
    include_messages: bool = Query(False, description="Include messages in response"),
    service: ConversationService = Depends(get_conversation_service),
    current_user: str = Depends(get_current_user)
) -> JSONResponse:
    """Get conversation details"""
    try:
        result = await service.get_conversation(
            conversation_id=conversation_id,
            user_identifier=current_user,
            include_messages=include_messages
        )
        return JSONResponse(status_code=200, content=result)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.put("/conversations/{conversation_id}")
async def update_conversation(
    conversation_id: str,
    request: dict,
    service: ConversationService = Depends(get_conversation_service),
    current_user: str = Depends(get_current_user)
) -> JSONResponse:
    """Update a conversation title"""
    try:
        title = request.get("title")
        if not title:
            raise ValueError("Title is required")

        await service.update_conversation(
            conversation_id=conversation_id,
            user_identifier=current_user,
            title=title
        )
        return JSONResponse(status_code=200, content={"message": "Conversation updated successfully"})
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.delete("/conversations/{conversation_id}")
async def delete_conversation(
    conversation_id: int,
    service: ConversationService = Depends(get_conversation_service),
    current_user: str = Depends(get_current_user)
) -> JSONResponse:
    """Delete a conversation"""
    try:
        await service.delete_conversation(
            conversation_id=conversation_id,
            user_identifier=current_user
        )
        return JSONResponse(status_code=200, content={"message": "Conversation deleted successfully"})
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
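A short client sketch exercising the conversation endpoints above. BASE_URL, TOKEN, and the agent id are placeholders (the commit adds nemotron agent CSVs, but the exact agent_id value is an assumption), and the router is assumed to be mounted at the service root.

# Sketch only: BASE_URL, TOKEN, and "nemotron-agent" are hypothetical values.
import asyncio
import httpx

BASE_URL = "http://localhost:8000"  # placeholder tenant-backend address
TOKEN = "..."  # bearer token returned by /auth/login

async def main() -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(base_url=BASE_URL, headers=headers) as client:
        # Create a conversation for a given agent
        created = await client.post("/conversations", json={"agent_id": "nemotron-agent", "title": "Demo"})
        print(created.status_code, created.json())

        # List the 10 most recent conversations for that agent
        listed = await client.get("/conversations", params={"agent_id": "nemotron-agent", "limit": 10})
        print(listed.json())

asyncio.run(main())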
404
apps/tenant-backend/app/api/documents.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""
|
||||
Document API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles document upload and management using PostgreSQL file service with perfect tenant isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Form, Query
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.path_security import sanitize_filename
|
||||
from app.services.postgresql_file_service import PostgreSQLFileService
|
||||
from app.services.document_processor import DocumentProcessor, get_document_processor
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["documents"])
|
||||
|
||||
|
||||
@router.get("/documents")
|
||||
async def list_documents(
|
||||
status: Optional[str] = Query(None, description="Filter by processing status"),
|
||||
dataset_id: Optional[str] = Query(None, description="Filter by dataset ID"),
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""List user's documents with optional filtering using PostgreSQL file service"""
|
||||
try:
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
# Get PostgreSQL file service
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test'),
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# List files (documents) with optional dataset filter
|
||||
files = await file_service.list_files(
|
||||
category="documents",
|
||||
dataset_id=dataset_id,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
# Get chunk counts and document status for these documents
|
||||
document_ids = [file_info["id"] for file_info in files]
|
||||
chunk_counts = {}
|
||||
document_status = {}
|
||||
if document_ids:
|
||||
from app.core.database import get_postgresql_client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get chunk counts
|
||||
chunk_query = """
|
||||
SELECT document_id, COUNT(*) as chunk_count
|
||||
FROM document_chunks
|
||||
WHERE document_id = ANY($1)
|
||||
GROUP BY document_id
|
||||
"""
|
||||
chunk_results = await pg_client.execute_query(chunk_query, document_ids)
|
||||
chunk_counts = {str(row["document_id"]): row["chunk_count"] for row in chunk_results}
|
||||
|
||||
# Get document processing status and progress
|
||||
status_query = """
|
||||
SELECT id, processing_status, chunk_count, chunks_processed,
|
||||
total_chunks_expected, processing_progress, processing_stage,
|
||||
error_message, created_at, updated_at
|
||||
FROM documents
|
||||
WHERE id = ANY($1)
|
||||
"""
|
||||
status_results = await pg_client.execute_query(status_query, document_ids)
|
||||
document_status = {str(row["id"]): row for row in status_results}
|
||||
|
||||
# Convert to expected document format
|
||||
documents = []
|
||||
for file_info in files:
|
||||
doc_id = str(file_info["id"])
|
||||
chunk_count = chunk_counts.get(doc_id, 0)
|
||||
status_info = document_status.get(doc_id, {})
|
||||
|
||||
documents.append({
|
||||
"id": file_info["id"],
|
||||
"filename": file_info["filename"],
|
||||
"original_filename": file_info["original_filename"],
|
||||
"file_type": file_info["content_type"],
|
||||
"file_size_bytes": file_info["file_size"],
|
||||
"dataset_id": file_info.get("dataset_id"),
|
||||
"processing_status": status_info.get("processing_status", "completed"),
|
||||
"chunk_count": chunk_count,
|
||||
"chunks_processed": status_info.get("chunks_processed", chunk_count),
|
||||
"total_chunks_expected": status_info.get("total_chunks_expected", chunk_count),
|
||||
"processing_progress": status_info.get("processing_progress", 100 if chunk_count > 0 else 0),
|
||||
"processing_stage": status_info.get("processing_stage", "Completed" if chunk_count > 0 else "Pending"),
|
||||
"error_message": status_info.get("error_message"),
|
||||
"vector_count": chunk_count, # Each chunk gets one vector
|
||||
"created_at": file_info["created_at"],
|
||||
"processed_at": status_info.get("updated_at", file_info["created_at"])
|
||||
})
|
||||
|
||||
# Apply status filter if provided
|
||||
if status:
|
||||
documents = [doc for doc in documents if doc["processing_status"] == status]
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list documents: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/documents")
|
||||
async def upload_document(
|
||||
file: UploadFile = File(...),
|
||||
dataset_id: Optional[str] = Form(None),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Upload new document using PostgreSQL file service"""
|
||||
try:
|
||||
# Validate file
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
|
||||
# Get file extension
|
||||
import pathlib
|
||||
file_extension = pathlib.Path(file.filename).suffix.lower()
|
||||
|
||||
# Validate file type
|
||||
allowed_extensions = ['.pdf', '.docx', '.txt', '.md', '.html', '.csv', '.json']
|
||||
if file_extension not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
|
||||
)
|
||||
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
# Get PostgreSQL file service
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test'),
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# Store file
|
||||
result = await file_service.store_file(
|
||||
file=file,
|
||||
dataset_id=dataset_id,
|
||||
category="documents"
|
||||
)
|
||||
|
||||
# Start document processing if dataset_id is provided
|
||||
if dataset_id:
|
||||
try:
|
||||
# Get document processor with tenant domain
|
||||
tenant_domain = current_user.get('tenant_domain', 'test')
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Process document for RAG (async)
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# Create temporary file for processing
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
|
||||
# Write file content to temp file
|
||||
file.file.seek(0) # Reset file pointer
|
||||
temp_file.write(await file.read())
|
||||
temp_file.flush()
|
||||
|
||||
# Process document using existing document ID
|
||||
try:
|
||||
processed_doc = await processor.process_file(
|
||||
file_path=Path(temp_file.name),
|
||||
dataset_id=dataset_id,
|
||||
user_id=tenant_user_uuid,
|
||||
original_filename=file.filename,
|
||||
document_id=result["id"] # Use existing document ID
|
||||
)
|
||||
|
||||
processing_status = "completed"
|
||||
chunk_count = getattr(processed_doc, 'chunk_count', 0)
|
||||
|
||||
except Exception as proc_error:
|
||||
logger.error(f"Document processing failed: {proc_error}")
|
||||
processing_status = "failed"
|
||||
chunk_count = 0
|
||||
finally:
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
except Exception as proc_error:
|
||||
logger.error(f"Failed to initiate document processing: {proc_error}")
|
||||
processing_status = "pending"
|
||||
chunk_count = 0
|
||||
else:
|
||||
processing_status = "completed"
|
||||
chunk_count = 0
|
||||
|
||||
# Return in expected format
|
||||
return {
|
||||
"id": result["id"],
|
||||
"filename": result["filename"],
|
||||
"original_filename": file.filename,
|
||||
"file_type": result["content_type"],
|
||||
"file_size_bytes": result["file_size"],
|
||||
"processing_status": processing_status,
|
||||
"chunk_count": chunk_count,
|
||||
"vector_count": chunk_count, # Each chunk gets one vector
|
||||
"created_at": result["upload_timestamp"],
|
||||
"processed_at": result["upload_timestamp"]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload document: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/documents/{document_id}/process")
|
||||
async def process_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Process a document for RAG pipeline (text extraction, chunking, embedding generation)"""
|
||||
try:
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
tenant_domain = current_user.get('tenant_domain', 'test')
|
||||
|
||||
# Get PostgreSQL file service to verify document exists
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# Get file info to verify ownership and get file metadata
|
||||
file_info = await file_service.get_file_info(document_id)
|
||||
if not file_info:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Get document processor with tenant domain
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Get file extension for temp file
|
||||
import pathlib
|
||||
original_filename = file_info.get("original_filename", file_info.get("filename", "unknown"))
|
||||
# Sanitize the filename to prevent path injection
|
||||
safe_filename = sanitize_filename(original_filename)
|
||||
file_extension = pathlib.Path(safe_filename).suffix.lower() if safe_filename else ".tmp"
|
||||
|
||||
# Create temporary file from database content for processing
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# codeql[py/path-injection] file_extension derived from sanitize_filename() at line 273
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
|
||||
# Stream file content from database to temp file
|
||||
async for chunk in file_service.get_file(document_id):
|
||||
temp_file.write(chunk)
|
||||
temp_file.flush()
|
||||
|
||||
# Process document
|
||||
try:
|
||||
processed_doc = await processor.process_file(
|
||||
file_path=Path(temp_file.name),
|
||||
dataset_id=file_info.get("dataset_id"),
|
||||
user_id=tenant_user_uuid,
|
||||
original_filename=original_filename
|
||||
)
|
||||
|
||||
processing_status = "completed"
|
||||
chunk_count = getattr(processed_doc, 'chunk_count', 0)
|
||||
|
||||
except Exception as proc_error:
|
||||
logger.error(f"Document processing failed for {document_id}: {proc_error}")
|
||||
processing_status = "failed"
|
||||
chunk_count = 0
|
||||
finally:
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"processing_status": processing_status,
|
||||
"message": "Document processed successfully" if processing_status == "completed" else f"Processing failed: {processing_status}",
|
||||
"chunk_count": chunk_count,
|
||||
"processed_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process document {document_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/status")
|
||||
async def get_document_processing_status(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get processing status of a document"""
|
||||
try:
|
||||
# Get document processor to check status
|
||||
tenant_domain = current_user.get('tenant_domain', 'test')
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
status = await processor.get_processing_status(document_id)
|
||||
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"processing_status": status.get("status", "unknown"),
|
||||
"error_message": status.get("error_message"),
|
||||
"chunk_count": status.get("chunk_count", 0),
|
||||
"chunks_processed": status.get("chunks_processed", 0),
|
||||
"total_chunks_expected": status.get("total_chunks_expected", 0),
|
||||
"processing_progress": status.get("processing_progress", 0),
|
||||
"processing_stage": status.get("processing_stage", "")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get processing status for {document_id}: {e}", exc_info=True)
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"processing_status": "unknown",
|
||||
"error_message": "Unable to retrieve processing status",
|
||||
"chunk_count": 0,
|
||||
"chunks_processed": 0,
|
||||
"total_chunks_expected": 0,
|
||||
"processing_progress": 0,
|
||||
"processing_stage": "Error"
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/documents/{document_id}")
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a document and its associated data"""
|
||||
try:
|
||||
# Get tenant user UUID from Control Panel user
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
# Get PostgreSQL file service
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test'),
|
||||
user_id=tenant_user_uuid
|
||||
)
|
||||
|
||||
# Verify document exists and user has permission to delete it
|
||||
file_info = await file_service.get_file_info(document_id)
|
||||
if not file_info:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Delete the document
|
||||
success = await file_service.delete_file(document_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete document")
|
||||
|
||||
return {
|
||||
"message": "Document deleted successfully",
|
||||
"document_id": document_id
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document {document_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Additional endpoints can be added here as needed for RAG processing
|
||||
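A hedged upload sketch against the documents endpoint above. BASE_URL and TOKEN are placeholders; the multipart field names follow the FastAPI signature (a "file" upload plus an optional "dataset_id" form field).

# Sketch only: BASE_URL and TOKEN are hypothetical, not part of this commit.
from typing import Optional
import httpx

BASE_URL = "http://localhost:8000"  # placeholder tenant-backend address
TOKEN = "..."  # bearer token from /auth/login

def upload_document(path: str, dataset_id: Optional[str] = None) -> dict:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    with open(path, "rb") as fh:
        files = {"file": (path, fh)}
        data = {"dataset_id": dataset_id} if dataset_id else {}
        # Passing dataset_id triggers RAG processing (chunking + embeddings) server-side
        resp = httpx.post(f"{BASE_URL}/documents", headers=headers, files=files, data=data, timeout=60.0)
    resp.raise_for_status()
    return resp.json()  # includes processing_status, chunk_count, vector_count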
99
apps/tenant-backend/app/api/embeddings.py
Normal file
@@ -0,0 +1,99 @@
"""
BGE-M3 Embedding Configuration API for Tenant Backend

Provides endpoint to update embedding configuration at runtime.
This allows the tenant backend to switch between local and external embedding endpoints.
"""

from fastapi import APIRouter, HTTPException
from typing import Optional
from pydantic import BaseModel
import logging
import os

from app.services.embedding_client import get_embedding_client

router = APIRouter()
logger = logging.getLogger(__name__)


class BGE_M3_ConfigRequest(BaseModel):
    """BGE-M3 configuration update request"""
    is_local_mode: bool = True
    external_endpoint: Optional[str] = None


class BGE_M3_ConfigResponse(BaseModel):
    """BGE-M3 configuration response"""
    is_local_mode: bool
    current_endpoint: str
    message: str


@router.post("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
async def update_bge_m3_config(
    config_request: BGE_M3_ConfigRequest
) -> BGE_M3_ConfigResponse:
    """
    Update BGE-M3 configuration for the tenant backend.

    This allows switching between local and external endpoints at runtime.
    No authentication required for service-to-service calls.
    """
    try:
        # Get the global embedding client
        embedding_client = get_embedding_client()

        # Determine new endpoint
        if config_request.is_local_mode:
            new_endpoint = os.getenv('EMBEDDING_ENDPOINT', 'http://host.docker.internal:8005')
        else:
            if not config_request.external_endpoint:
                raise HTTPException(status_code=400, detail="External endpoint required when not in local mode")
            new_endpoint = config_request.external_endpoint

        # Update the client endpoint
        embedding_client.update_endpoint(new_endpoint)

        # Update environment variables for future client instances
        os.environ['BGE_M3_LOCAL_MODE'] = str(config_request.is_local_mode).lower()
        if config_request.external_endpoint:
            os.environ['BGE_M3_EXTERNAL_ENDPOINT'] = config_request.external_endpoint

        logger.info(
            f"BGE-M3 configuration updated: "
            f"local_mode={config_request.is_local_mode}, "
            f"endpoint={new_endpoint}"
        )

        return BGE_M3_ConfigResponse(
            is_local_mode=config_request.is_local_mode,
            current_endpoint=new_endpoint,
            message=f"BGE-M3 configuration updated to {'local' if config_request.is_local_mode else 'external'} mode"
        )

    except Exception as e:
        logger.error(f"Error updating BGE-M3 config: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/config/bge-m3", response_model=BGE_M3_ConfigResponse)
async def get_bge_m3_config() -> BGE_M3_ConfigResponse:
    """
    Get current BGE-M3 configuration.
    """
    try:
        embedding_client = get_embedding_client()

        # Determine if currently in local mode
        is_local_mode = os.getenv('BGE_M3_LOCAL_MODE', 'true').lower() == 'true'

        return BGE_M3_ConfigResponse(
            is_local_mode=is_local_mode,
            current_endpoint=embedding_client.base_url,
            message="Current BGE-M3 configuration"
        )

    except Exception as e:
        logger.error(f"Error getting BGE-M3 config: {e}")
        raise HTTPException(status_code=500, detail=str(e))
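A minimal service-to-service sketch for the config endpoint above, switching the backend to an external BGE-M3 endpoint. BASE_URL and the external URL are placeholders; the request body mirrors BGE_M3_ConfigRequest.

# Sketch only: BASE_URL and the external endpoint URL are hypothetical.
import httpx

BASE_URL = "http://localhost:8000"  # placeholder tenant-backend address

resp = httpx.post(
    f"{BASE_URL}/config/bge-m3",
    json={"is_local_mode": False, "external_endpoint": "https://embeddings.example.com"},
    timeout=10.0,
)
print(resp.json())  # BGE_M3_ConfigResponse: is_local_mode, current_endpoint, message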
358
apps/tenant-backend/app/api/events.py
Normal file
@@ -0,0 +1,358 @@
"""
Event Automation API endpoints for GT 2.0 Tenant Backend

Manages event subscriptions, triggers, and automation workflows
with perfect tenant isolation.
"""

import logging
from fastapi import APIRouter, HTTPException, Depends, Query
from fastapi.responses import JSONResponse
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db_session
from app.core.security import get_current_user_email, get_tenant_info
from app.services.event_service import EventService, EventType, ActionType, EventActionConfig
from app.schemas.event import (
    EventSubscriptionCreate, EventSubscriptionResponse, EventActionCreate,
    EventResponse, EventStatistics, ScheduledTaskResponse
)

logger = logging.getLogger(__name__)
router = APIRouter(tags=["events"])


@router.post("/subscriptions", response_model=EventSubscriptionResponse)
async def create_event_subscription(
    subscription: EventSubscriptionCreate,
    current_user: str = Depends(get_current_user_email),
    tenant_info: Dict[str, str] = Depends(get_tenant_info),
    db: AsyncSession = Depends(get_db_session)
):
    """Create a new event subscription"""
    try:
        event_service = EventService(db)

        # Convert actions
        actions = []
        for action_data in subscription.actions:
            action_config = EventActionConfig(
                action_type=ActionType(action_data.action_type),
                config=action_data.config,
                delay_seconds=action_data.delay_seconds,
                retry_count=action_data.retry_count,
                retry_delay=action_data.retry_delay,
                condition=action_data.condition
            )
            actions.append(action_config)

        subscription_id = await event_service.create_subscription(
            user_id=current_user,
            tenant_id=tenant_info["tenant_id"],
            event_type=EventType(subscription.event_type),
            actions=actions,
            name=subscription.name,
            description=subscription.description
        )

        # Get created subscription
        subscriptions = await event_service.get_user_subscriptions(
            current_user, tenant_info["tenant_id"]
        )
        created_subscription = next(
            (s for s in subscriptions if s.id == subscription_id), None
        )

        if not created_subscription:
            raise HTTPException(status_code=500, detail="Failed to retrieve created subscription")

        return EventSubscriptionResponse.from_orm(created_subscription)

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to create event subscription: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/subscriptions", response_model=List[EventSubscriptionResponse])
async def list_event_subscriptions(
    current_user: str = Depends(get_current_user_email),
    tenant_info: Dict[str, str] = Depends(get_tenant_info),
    db: AsyncSession = Depends(get_db_session)
):
    """List user's event subscriptions"""
    try:
        event_service = EventService(db)
        subscriptions = await event_service.get_user_subscriptions(
            current_user, tenant_info["tenant_id"]
        )

        return [EventSubscriptionResponse.from_orm(sub) for sub in subscriptions]

    except Exception as e:
        logger.error(f"Failed to list event subscriptions: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.put("/subscriptions/{subscription_id}/status")
async def update_subscription_status(
    subscription_id: str,
    is_active: bool,
    current_user: str = Depends(get_current_user_email),
    db: AsyncSession = Depends(get_db_session)
):
    """Update event subscription status"""
    try:
        event_service = EventService(db)
        success = await event_service.update_subscription_status(
            subscription_id, current_user, is_active
        )

        if not success:
            raise HTTPException(status_code=404, detail="Subscription not found")

        return JSONResponse(content={
            "message": f"Subscription {'activated' if is_active else 'deactivated'} successfully"
        })

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to update subscription status: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.delete("/subscriptions/{subscription_id}")
async def delete_event_subscription(
    subscription_id: str,
    current_user: str = Depends(get_current_user_email),
    db: AsyncSession = Depends(get_db_session)
):
    """Delete event subscription"""
    try:
        event_service = EventService(db)
        success = await event_service.delete_subscription(subscription_id, current_user)

        if not success:
            raise HTTPException(status_code=404, detail="Subscription not found")

        return JSONResponse(content={"message": "Subscription deleted successfully"})

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to delete subscription: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/emit")
async def emit_event(
    event_type: str,
    data: Dict[str, Any],
    metadata: Optional[Dict[str, Any]] = None,
    current_user: str = Depends(get_current_user_email),
    tenant_info: Dict[str, str] = Depends(get_tenant_info),
    db: AsyncSession = Depends(get_db_session)
):
    """Manually emit an event"""
    try:
        event_service = EventService(db)

        # Validate event type
        try:
            event_type_enum = EventType(event_type)
        except ValueError:
            raise HTTPException(status_code=400, detail=f"Invalid event type: {event_type}")

        event_id = await event_service.emit_event(
            event_type=event_type_enum,
            user_id=current_user,
            tenant_id=tenant_info["tenant_id"],
            data=data,
            metadata=metadata
        )

        return JSONResponse(content={
            "event_id": event_id,
            "message": "Event emitted successfully"
        })

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to emit event: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/history", response_model=List[EventResponse])
async def get_event_history(
    event_types: Optional[List[str]] = Query(None, description="Filter by event types"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    current_user: str = Depends(get_current_user_email),
    tenant_info: Dict[str, str] = Depends(get_tenant_info),
    db: AsyncSession = Depends(get_db_session)
):
    """Get event history for user"""
    try:
        event_service = EventService(db)

        # Convert event types if provided
        event_type_enums = None
|
||||
if event_types:
|
||||
try:
|
||||
event_type_enums = [EventType(et) for et in event_types]
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid event type: {e}")
|
||||
|
||||
events = await event_service.get_event_history(
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
event_types=event_type_enums,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return [EventResponse.from_orm(event) for event in events]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get event history: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=EventStatistics)
|
||||
async def get_event_statistics(
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Get event statistics for user"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
stats = await event_service.get_event_statistics(
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
days=days
|
||||
)
|
||||
|
||||
return EventStatistics(**stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get event statistics: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/types")
|
||||
async def get_available_event_types():
|
||||
"""Get available event types and actions"""
|
||||
return JSONResponse(content={
|
||||
"event_types": [
|
||||
{"value": et.value, "description": et.value.replace("_", " ").title()}
|
||||
for et in EventType
|
||||
],
|
||||
"action_types": [
|
||||
{"value": at.value, "description": at.value.replace("_", " ").title()}
|
||||
for at in ActionType
|
||||
]
|
||||
})
|
||||
|
||||
|
||||
# Document automation endpoints
|
||||
@router.post("/documents/{document_id}/auto-process")
|
||||
async def trigger_document_processing(
|
||||
document_id: int,
|
||||
chunking_strategy: Optional[str] = Query("hybrid", description="Chunking strategy"),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Trigger automated document processing"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
|
||||
event_id = await event_service.emit_event(
|
||||
event_type=EventType.DOCUMENT_UPLOADED,
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
data={
|
||||
"document_id": document_id,
|
||||
"filename": f"document_{document_id}",
|
||||
"chunking_strategy": chunking_strategy,
|
||||
"manual_trigger": True
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"event_id": event_id,
|
||||
"message": "Document processing automation triggered"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger document processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Conversation automation endpoints
|
||||
@router.post("/conversations/{conversation_id}/auto-analyze")
|
||||
async def trigger_conversation_analysis(
|
||||
conversation_id: int,
|
||||
analysis_type: str = Query("sentiment", description="Type of analysis"),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Trigger automated conversation analysis"""
|
||||
try:
|
||||
event_service = EventService(db)
|
||||
|
||||
event_id = await event_service.emit_event(
|
||||
event_type=EventType.CONVERSATION_STARTED,
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
data={
|
||||
"conversation_id": conversation_id,
|
||||
"analysis_type": analysis_type,
|
||||
"manual_trigger": True
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"event_id": event_id,
|
||||
"message": "Conversation analysis automation triggered"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger conversation analysis: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Default subscriptions endpoint
|
||||
@router.post("/setup-defaults")
|
||||
async def setup_default_subscriptions(
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: Dict[str, str] = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Setup default event subscriptions for user"""
|
||||
try:
|
||||
from app.services.event_service import setup_default_subscriptions
|
||||
|
||||
event_service = EventService(db)
|
||||
await setup_default_subscriptions(
|
||||
user_id=current_user,
|
||||
tenant_id=tenant_info["tenant_id"],
|
||||
event_service=event_service
|
||||
)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Default event subscriptions created successfully"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup default subscriptions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
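As a quick sanity check of the event endpoints above, here is a minimal client sketch (not part of the commit). It assumes the events router is mounted under /api/v1/events and that a capability token from the auth endpoints is at hand; both the base URL and the token are assumptions for illustration.

import asyncio
import httpx

BASE_URL = "http://localhost:8000/api/v1/events"   # assumed mount point
TOKEN = "<capability token from /auth/login>"       # assumed

async def fetch_recent_events() -> list:
    """Fetch the 20 most recent events for the authenticated user via GET /history."""
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            f"{BASE_URL}/history",
            params={"limit": 20, "offset": 0},
            headers={"Authorization": f"Bearer {TOKEN}"},
            timeout=10.0,
        )
        resp.raise_for_status()
        return resp.json()

if __name__ == "__main__":
    print(asyncio.run(fetch_recent_events()))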
3
apps/tenant-backend/app/api/internal/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
Internal API endpoints for service-to-service communication
"""
137
apps/tenant-backend/app/api/messages.py
Normal file
@@ -0,0 +1,137 @@
"""
Message API endpoints for GT 2.0 Tenant Backend

Handles message management within conversations.
"""

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from typing import Dict, Any
from pydantic import BaseModel, Field
import json
import logging

from app.services.conversation_service import ConversationService
from app.core.security import get_current_user
from app.core.user_resolver import resolve_user_uuid

router = APIRouter(tags=["messages"])


class SendMessageRequest(BaseModel):
    """Request model for sending a message"""
    content: str = Field(..., description="Message content")
    stream: bool = Field(default=False, description="Stream the response")


logger = logging.getLogger(__name__)


async def get_conversation_service(
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> ConversationService:
    """Get properly initialized conversation service"""
    tenant_domain, user_email, user_id = await resolve_user_uuid(current_user)
    return ConversationService(tenant_domain=tenant_domain, user_id=user_id)


@router.get("/conversations/{conversation_id}/messages")
async def list_messages(
    conversation_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> JSONResponse:
    """List messages in conversation"""
    try:
        # Get properly initialized service
        service = await get_conversation_service(current_user)

        # Get conversation with messages
        result = await service.get_conversation(
            conversation_id=conversation_id,
            user_identifier=service.user_id,
            include_messages=True
        )
        return JSONResponse(
            status_code=200,
            content={"messages": result.get("messages", [])}
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/conversations/{conversation_id}/messages")
async def send_message(
    conversation_id: str,
    role: str,
    content: str,
    stream: bool = False,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> JSONResponse:
    """Send a message and get AI response"""
    try:
        # Get properly initialized service
        service = await get_conversation_service(current_user)

        # Send message
        result = await service.send_message(
            conversation_id=conversation_id,
            content=content,
            stream=stream
        )

        # Generate title after first message
        if result.get("is_first_message", False):
            try:
                await service.auto_generate_conversation_title(
                    conversation_id=conversation_id,
                    user_identifier=service.user_id
                )
                logger.info(f"✅ Title generated for conversation {conversation_id}")
            except Exception as e:
                logger.warning(f"Failed to generate title: {e}")
                # Don't fail the request if title generation fails

        return JSONResponse(status_code=201, content=result)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/conversations/{conversation_id}/messages/stream")
async def stream_message(
    conversation_id: int,
    content: str,
    service: ConversationService = Depends(get_conversation_service),
    current_user: str = Depends(get_current_user)
):
    """Stream AI response for a message"""

    async def generate():
        """Generate SSE stream"""
        try:
            async for chunk in service.stream_message_response(
                conversation_id=conversation_id,
                message_content=content,
                user_identifier=current_user
            ):
                # Format as Server-Sent Event
                yield f"data: {json.dumps({'content': chunk})}\n\n"

            # Send completion signal
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"Message streaming failed: {e}", exc_info=True)
            yield f"data: {json.dumps({'error': 'An internal error occurred. Please try again.'})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )
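The streaming endpoint above emits Server-Sent Events terminated by a "[DONE]" marker. A minimal consumer sketch (not part of the commit), assuming the messages router is mounted under /api/v1 and a valid bearer token is available; both are assumptions for illustration.

import asyncio
import httpx

BASE_URL = "http://localhost:8000/api/v1"   # assumed mount point
TOKEN = "<capability token>"                 # assumed

async def stream_reply(conversation_id: int, content: str) -> None:
    """Print SSE chunks from the message streaming endpoint as they arrive."""
    url = f"{BASE_URL}/conversations/{conversation_id}/messages/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET",
            url,
            params={"content": content},
            headers={"Authorization": f"Bearer {TOKEN}"},
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break
                print(payload)

if __name__ == "__main__":
    asyncio.run(stream_reply(1, "Summarize the uploaded dataset"))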
411
apps/tenant-backend/app/api/v1/access_groups.py
Normal file
@@ -0,0 +1,411 @@
"""
Access Groups API for GT 2.0 Tenant Backend

RESTful API endpoints for managing resource access groups and permissions.
"""

from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Header, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db_session
from app.core.security import get_current_user, verify_capability_token
from app.models.access_group import (
    AccessGroup, ResourceCreate, ResourceUpdate, ResourceResponse,
    AccessGroupModel
)
from app.services.access_controller import AccessController, AccessControlMiddleware


router = APIRouter(
    prefix="/api/v1/access-groups",
    tags=["access-groups"],
    responses={404: {"description": "Not found"}},
)


async def get_access_controller(
    x_tenant_domain: str = Header(..., description="Tenant domain")
) -> AccessController:
    """Dependency to get access controller for tenant"""
    return AccessController(x_tenant_domain)


async def get_middleware(
    x_tenant_domain: str = Header(..., description="Tenant domain")
) -> AccessControlMiddleware:
    """Dependency to get access control middleware"""
    return AccessControlMiddleware(x_tenant_domain)


@router.post("/resources", response_model=ResourceResponse)
async def create_resource(
    resource: ResourceCreate,
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Create a new resource with access control.

    - **name**: Resource name
    - **resource_type**: Type of resource (agent, dataset, etc.)
    - **access_group**: Access level (individual, team, organization)
    - **team_members**: List of user IDs for team access
    - **metadata**: Resource-specific metadata
    """
    try:
        # Extract bearer token
        if not authorization.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Invalid authorization header")

        capability_token = authorization.replace("Bearer ", "")

        # Get current user
        user = await get_current_user(authorization, db)

        # Create resource
        created_resource = await controller.create_resource(
            user_id=user.id,
            resource_data=resource,
            capability_token=capability_token
        )

        return ResourceResponse(
            id=created_resource.id,
            name=created_resource.name,
            resource_type=created_resource.resource_type,
            owner_id=created_resource.owner_id,
            tenant_domain=created_resource.tenant_domain,
            access_group=created_resource.access_group,
            team_members=created_resource.team_members,
            created_at=created_resource.created_at,
            updated_at=created_resource.updated_at,
            metadata=created_resource.metadata,
            file_path=created_resource.file_path
        )

    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to create resource: {str(e)}")


@router.put("/resources/{resource_id}/access", response_model=ResourceResponse)
async def update_resource_access(
    resource_id: str,
    access_update: AccessGroupModel,
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    middleware: AccessControlMiddleware = Depends(get_middleware),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Update resource access group.

    Only the resource owner can change access settings.

    - **access_group**: New access level
    - **team_members**: Updated team members list (for team access)
    """
    try:
        # Get current user
        user = await get_current_user(authorization, db)
        capability_token = authorization.replace("Bearer ", "")

        # Verify permission
        await middleware.verify_request(
            user_id=user.id,
            resource_id=resource_id,
            action="share",
            capability_token=capability_token
        )

        # Update access
        updated_resource = await controller.update_resource_access(
            user_id=user.id,
            resource_id=resource_id,
            new_access_group=access_update.access_group,
            team_members=access_update.team_members
        )

        return ResourceResponse(
            id=updated_resource.id,
            name=updated_resource.name,
            resource_type=updated_resource.resource_type,
            owner_id=updated_resource.owner_id,
            tenant_domain=updated_resource.tenant_domain,
            access_group=updated_resource.access_group,
            team_members=updated_resource.team_members,
            created_at=updated_resource.created_at,
            updated_at=updated_resource.updated_at,
            metadata=updated_resource.metadata,
            file_path=updated_resource.file_path
        )

    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update access: {str(e)}")


@router.get("/resources", response_model=List[ResourceResponse])
async def list_accessible_resources(
    resource_type: Optional[str] = Query(None, description="Filter by resource type"),
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    db: AsyncSession = Depends(get_db_session)
):
    """
    List all resources accessible to the current user.

    Returns resources based on access groups:
    - Individual: Only owned resources
    - Team: Resources shared with user's teams
    - Organization: All organization-wide resources
    """
    try:
        # Get current user
        user = await get_current_user(authorization, db)

        # Get accessible resources
        resources = await controller.list_accessible_resources(
            user_id=user.id,
            resource_type=resource_type
        )

        return [
            ResourceResponse(
                id=r.id,
                name=r.name,
                resource_type=r.resource_type,
                owner_id=r.owner_id,
                tenant_domain=r.tenant_domain,
                access_group=r.access_group,
                team_members=r.team_members,
                created_at=r.created_at,
                updated_at=r.updated_at,
                metadata=r.metadata,
                file_path=r.file_path
            )
            for r in resources
        ]

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to list resources: {str(e)}")


@router.get("/resources/{resource_id}", response_model=ResourceResponse)
async def get_resource(
    resource_id: str,
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    middleware: AccessControlMiddleware = Depends(get_middleware),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Get a specific resource if user has access.

    Checks read permission based on access group.
    """
    try:
        # Get current user
        user = await get_current_user(authorization, db)
        capability_token = authorization.replace("Bearer ", "")

        # Verify permission
        await middleware.verify_request(
            user_id=user.id,
            resource_id=resource_id,
            action="read",
            capability_token=capability_token
        )

        # Load resource
        resource = await controller._load_resource(resource_id)

        return ResourceResponse(
            id=resource.id,
            name=resource.name,
            resource_type=resource.resource_type,
            owner_id=resource.owner_id,
            tenant_domain=resource.tenant_domain,
            access_group=resource.access_group,
            team_members=resource.team_members,
            created_at=resource.created_at,
            updated_at=resource.updated_at,
            metadata=resource.metadata,
            file_path=resource.file_path
        )

    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get resource: {str(e)}")


@router.delete("/resources/{resource_id}")
async def delete_resource(
    resource_id: str,
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    middleware: AccessControlMiddleware = Depends(get_middleware),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Delete a resource.

    Only the resource owner can delete.
    """
    try:
        # Get current user
        user = await get_current_user(authorization, db)
        capability_token = authorization.replace("Bearer ", "")

        # Verify permission
        await middleware.verify_request(
            user_id=user.id,
            resource_id=resource_id,
            action="delete",
            capability_token=capability_token
        )

        # Delete resource (implementation needed)
        # await controller.delete_resource(resource_id)

        return {"status": "deleted", "resource_id": resource_id}

    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete resource: {str(e)}")


@router.get("/stats")
async def get_resource_stats(
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Get resource statistics for the current user.

    Returns counts by type and access group.
    """
    try:
        # Get current user
        user = await get_current_user(authorization, db)

        # Get stats
        stats = await controller.get_resource_stats(user_id=user.id)

        return stats

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}")


@router.post("/resources/{resource_id}/team-members/{user_id}")
async def add_team_member(
    resource_id: str,
    user_id: str,
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    middleware: AccessControlMiddleware = Depends(get_middleware),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Add a user to team access for a resource.

    Only the resource owner can add team members.
    Resource must have team access group.
    """
    try:
        # Get current user
        current_user = await get_current_user(authorization, db)
        capability_token = authorization.replace("Bearer ", "")

        # Verify permission
        await middleware.verify_request(
            user_id=current_user.id,
            resource_id=resource_id,
            action="share",
            capability_token=capability_token
        )

        # Load resource
        resource = await controller._load_resource(resource_id)

        # Check if team access
        if resource.access_group != AccessGroup.TEAM:
            raise HTTPException(
                status_code=400,
                detail="Resource must have team access to add members"
            )

        # Add team member
        resource.add_team_member(user_id)

        # Save changes (implementation needed)
        # await controller.save_resource(resource)

        return {"status": "added", "user_id": user_id, "resource_id": resource_id}

    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to add team member: {str(e)}")


@router.delete("/resources/{resource_id}/team-members/{user_id}")
async def remove_team_member(
    resource_id: str,
    user_id: str,
    authorization: str = Header(..., description="Bearer token"),
    x_tenant_domain: str = Header(..., description="Tenant domain"),
    controller: AccessController = Depends(get_access_controller),
    middleware: AccessControlMiddleware = Depends(get_middleware),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Remove a user from team access for a resource.

    Only the resource owner can remove team members.
    """
    try:
        # Get current user
        current_user = await get_current_user(authorization, db)
        capability_token = authorization.replace("Bearer ", "")

        # Verify permission
        await middleware.verify_request(
            user_id=current_user.id,
            resource_id=resource_id,
            action="share",
            capability_token=capability_token
        )

        # Load resource
        resource = await controller._load_resource(resource_id)

        # Remove team member
        resource.remove_team_member(user_id)

        # Save changes (implementation needed)
        # await controller.save_resource(resource)

        return {"status": "removed", "user_id": user_id, "resource_id": resource_id}

    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to remove team member: {str(e)}")
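A minimal client sketch for the access-group endpoints above (not part of the commit). The /api/v1/access-groups prefix comes from the router in this file; the ResourceCreate field names used in the payload are inferred from the endpoint docstring and should be checked against app/models/access_group.py, and the token and tenant domain are assumptions for illustration.

import asyncio
import httpx

BASE_URL = "http://localhost:8000/api/v1/access-groups"  # prefix declared on the router above
TOKEN = "<capability token>"                              # assumed
TENANT = "test-company"                                   # assumed tenant domain

async def share_dataset_with_team(team: list) -> dict:
    """Create a team-scoped dataset resource; payload fields follow the create_resource docstring."""
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{BASE_URL}/resources",
            json={
                "name": "quarterly-reports",
                "resource_type": "dataset",
                "access_group": "team",
                "team_members": team,
                "metadata": {"source": "s3"},
            },
            headers={
                "Authorization": f"Bearer {TOKEN}",
                "X-Tenant-Domain": TENANT,
            },
            timeout=10.0,
        )
        resp.raise_for_status()
        return resp.json()

if __name__ == "__main__":
    print(asyncio.run(share_dataset_with_team(["user-123", "user-456"])))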
1465
apps/tenant-backend/app/api/v1/agents.py
Normal file
File diff suppressed because it is too large
582
apps/tenant-backend/app/api/v1/api_keys.py
Normal file
@@ -0,0 +1,582 @@
"""
Enhanced API Keys Management API for GT 2.0

RESTful API for advanced API key management with capability-based permissions,
configurable constraints, and comprehensive audit logging.
"""

from typing import List, Optional, Dict, Any
from datetime import datetime
from fastapi import APIRouter, HTTPException, Depends, Header, Query
from pydantic import BaseModel, Field

from app.core.security import get_current_user, verify_capability_token
from app.services.enhanced_api_keys import (
    EnhancedAPIKeyService, APIKeyConfig, APIKeyStatus, APIKeyScope, SharingPermission
)

router = APIRouter()


# Request/Response Models
class CreateAPIKeyRequest(BaseModel):
    """Request to create a new API key"""
    name: str = Field(..., description="Human-readable name for the key")
    description: Optional[str] = Field(None, description="Description of the key's purpose")
    capabilities: List[str] = Field(..., description="List of capability strings")
    scope: str = Field("user", description="Key scope: user, tenant, admin")
    expires_in_days: int = Field(90, description="Expiration time in days")
    rate_limit_per_hour: Optional[int] = Field(None, description="Custom rate limit per hour")
    daily_quota: Optional[int] = Field(None, description="Custom daily quota")
    cost_limit_cents: Optional[int] = Field(None, description="Custom cost limit in cents")
    allowed_endpoints: Optional[List[str]] = Field(None, description="Allowed endpoints")
    blocked_endpoints: Optional[List[str]] = Field(None, description="Blocked endpoints")
    allowed_ips: Optional[List[str]] = Field(None, description="Allowed IP addresses")
    tenant_constraints: Optional[Dict[str, Any]] = Field(None, description="Custom tenant constraints")


class APIKeyResponse(BaseModel):
    """API key configuration response"""
    id: str
    name: str
    description: str
    owner_id: str
    scope: str
    capabilities: List[str]
    rate_limit_per_hour: int
    daily_quota: int
    cost_limit_cents: int
    max_tokens_per_request: int
    allowed_endpoints: List[str]
    blocked_endpoints: List[str]
    allowed_ips: List[str]
    status: str
    created_at: datetime
    expires_at: Optional[datetime]
    last_rotated: Optional[datetime]
    usage: Dict[str, Any]


class CreateAPIKeyResponse(BaseModel):
    """Response when creating a new API key"""
    api_key: APIKeyResponse
    raw_key: str = Field(..., description="The actual API key (only shown once)")
    warning: str = Field(..., description="Security warning about key storage")


class RotateAPIKeyResponse(BaseModel):
    """Response when rotating an API key"""
    api_key: APIKeyResponse
    new_raw_key: str = Field(..., description="The new API key (only shown once)")
    warning: str = Field(..., description="Security warning about updating systems")


class APIKeyUsageResponse(BaseModel):
    """API key usage analytics response"""
    total_requests: int
    total_errors: int
    avg_requests_per_day: float
    rate_limit_hits: int
    keys_analyzed: int
    date_range: Dict[str, str]
    most_used_endpoints: List[Dict[str, Any]]


class ValidateAPIKeyRequest(BaseModel):
    """Request to validate an API key"""
    api_key: str = Field(..., description="Raw API key to validate")
    endpoint: Optional[str] = Field(None, description="Endpoint being accessed")
    client_ip: Optional[str] = Field(None, description="Client IP address")


class ValidateAPIKeyResponse(BaseModel):
    """API key validation response"""
    valid: bool
    error_message: Optional[str]
    capability_token: Optional[str]
    rate_limit_remaining: Optional[int]
    quota_remaining: Optional[int]


# Dependency injection
async def get_api_key_service(
    authorization: str = Header(...),
    current_user: str = Depends(get_current_user)
) -> EnhancedAPIKeyService:
    """Get enhanced API key service"""
    # Extract tenant from token (mock implementation)
    tenant_domain = "customer1.com"  # Would extract from JWT

    # Use tenant-specific signing key
    signing_key = f"signing_key_for_{tenant_domain}"

    return EnhancedAPIKeyService(tenant_domain, signing_key)


@router.post("", response_model=CreateAPIKeyResponse)
async def create_api_key(
    request: CreateAPIKeyRequest,
    authorization: str = Header(...),
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
    current_user: str = Depends(get_current_user)
):
    """
    Create a new API key with specified capabilities and constraints.

    - **name**: Human-readable name for the key
    - **capabilities**: List of capability strings (e.g., ["llm:gpt-4", "rag:search"])
    - **scope**: user, tenant, or admin level
    - **expires_in_days**: Expiration time (default 90 days)
    - **rate_limit_per_hour**: Custom rate limit (optional)
    - **allowed_endpoints**: Restrict to specific endpoints (optional)
    - **tenant_constraints**: Custom constraints for the key (optional)
    """
    try:
        # Convert scope string to enum
        scope = APIKeyScope(request.scope.lower())

        # Build constraints from request
        constraints = request.tenant_constraints or {}

        # Apply custom limits if provided
        if request.rate_limit_per_hour:
            constraints["rate_limit_per_hour"] = request.rate_limit_per_hour
        if request.daily_quota:
            constraints["daily_quota"] = request.daily_quota
        if request.cost_limit_cents:
            constraints["cost_limit_cents"] = request.cost_limit_cents

        # Create API key
        api_key, raw_key = await api_key_service.create_api_key(
            name=request.name,
            owner_id=current_user,
            capabilities=request.capabilities,
            scope=scope,
            expires_in_days=request.expires_in_days,
            constraints=constraints,
            capability_token=authorization
        )

        # Apply custom settings if provided
        if request.allowed_endpoints:
            api_key.allowed_endpoints = request.allowed_endpoints
        if request.blocked_endpoints:
            api_key.blocked_endpoints = request.blocked_endpoints
        if request.allowed_ips:
            api_key.allowed_ips = request.allowed_ips
        if request.description:
            api_key.description = request.description

        # Store updated key
        await api_key_service._store_api_key(api_key)

        return CreateAPIKeyResponse(
            api_key=APIKeyResponse(
                id=api_key.id,
                name=api_key.name,
                description=api_key.description,
                owner_id=api_key.owner_id,
                scope=api_key.scope.value,
                capabilities=api_key.capabilities,
                rate_limit_per_hour=api_key.rate_limit_per_hour,
                daily_quota=api_key.daily_quota,
                cost_limit_cents=api_key.cost_limit_cents,
                max_tokens_per_request=api_key.max_tokens_per_request,
                allowed_endpoints=api_key.allowed_endpoints,
                blocked_endpoints=api_key.blocked_endpoints,
                allowed_ips=api_key.allowed_ips,
                status=api_key.status.value,
                created_at=api_key.created_at,
                expires_at=api_key.expires_at,
                last_rotated=api_key.last_rotated,
                usage=api_key.usage.to_dict()
            ),
            raw_key=raw_key,
            warning="Store this API key securely. It will not be shown again."
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to create API key: {str(e)}")


@router.get("", response_model=List[APIKeyResponse])
async def list_api_keys(
    include_usage: bool = Query(True, description="Include usage statistics"),
    status: Optional[str] = Query(None, description="Filter by status"),
    authorization: str = Header(...),
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
    current_user: str = Depends(get_current_user)
):
    """
    List API keys for the current user.

    - **include_usage**: Include detailed usage statistics
    - **status**: Filter by key status (active, suspended, expired, revoked)
    """
    try:
        api_keys = await api_key_service.list_user_api_keys(
            owner_id=current_user,
            capability_token=authorization,
            include_usage=include_usage
        )

        # Filter by status if provided
        if status:
            try:
                status_filter = APIKeyStatus(status.lower())
                api_keys = [key for key in api_keys if key.status == status_filter]
            except ValueError:
                raise HTTPException(status_code=400, detail=f"Invalid status: {status}")

        return [
            APIKeyResponse(
                id=key.id,
                name=key.name,
                description=key.description,
                owner_id=key.owner_id,
                scope=key.scope.value,
                capabilities=key.capabilities,
                rate_limit_per_hour=key.rate_limit_per_hour,
                daily_quota=key.daily_quota,
                cost_limit_cents=key.cost_limit_cents,
                max_tokens_per_request=key.max_tokens_per_request,
                allowed_endpoints=key.allowed_endpoints,
                blocked_endpoints=key.blocked_endpoints,
                allowed_ips=key.allowed_ips,
                status=key.status.value,
                created_at=key.created_at,
                expires_at=key.expires_at,
                last_rotated=key.last_rotated,
                usage=key.usage.to_dict()
            )
            for key in api_keys
        ]

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to list API keys: {str(e)}")


@router.get("/{key_id}", response_model=APIKeyResponse)
async def get_api_key(
    key_id: str,
    authorization: str = Header(...),
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
    current_user: str = Depends(get_current_user)
):
    """
    Get detailed information about a specific API key.
    """
    try:
        # Get user's keys and find the requested one
        user_keys = await api_key_service.list_user_api_keys(
            owner_id=current_user,
            capability_token=authorization,
            include_usage=True
        )

        api_key = next((key for key in user_keys if key.id == key_id), None)
        if not api_key:
            raise HTTPException(status_code=404, detail="API key not found")

        return APIKeyResponse(
            id=api_key.id,
            name=api_key.name,
            description=api_key.description,
            owner_id=api_key.owner_id,
            scope=api_key.scope.value,
            capabilities=api_key.capabilities,
            rate_limit_per_hour=api_key.rate_limit_per_hour,
            daily_quota=api_key.daily_quota,
            cost_limit_cents=api_key.cost_limit_cents,
            max_tokens_per_request=api_key.max_tokens_per_request,
            allowed_endpoints=api_key.allowed_endpoints,
            blocked_endpoints=api_key.blocked_endpoints,
            allowed_ips=api_key.allowed_ips,
            status=api_key.status.value,
            created_at=api_key.created_at,
            expires_at=api_key.expires_at,
            last_rotated=api_key.last_rotated,
            usage=api_key.usage.to_dict()
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get API key: {str(e)}")


@router.post("/{key_id}/rotate", response_model=RotateAPIKeyResponse)
async def rotate_api_key(
    key_id: str,
    authorization: str = Header(...),
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
    current_user: str = Depends(get_current_user)
):
    """
    Rotate API key (generate new key value).

    The old key will be invalidated and a new key will be generated.
    """
    try:
        api_key, new_raw_key = await api_key_service.rotate_api_key(
            key_id=key_id,
            owner_id=current_user,
            capability_token=authorization
        )

        return RotateAPIKeyResponse(
            api_key=APIKeyResponse(
                id=api_key.id,
                name=api_key.name,
                description=api_key.description,
                owner_id=api_key.owner_id,
                scope=api_key.scope.value,
                capabilities=api_key.capabilities,
                rate_limit_per_hour=api_key.rate_limit_per_hour,
                daily_quota=api_key.daily_quota,
                cost_limit_cents=api_key.cost_limit_cents,
                max_tokens_per_request=api_key.max_tokens_per_request,
                allowed_endpoints=api_key.allowed_endpoints,
                blocked_endpoints=api_key.blocked_endpoints,
                allowed_ips=api_key.allowed_ips,
                status=api_key.status.value,
                created_at=api_key.created_at,
                expires_at=api_key.expires_at,
                last_rotated=api_key.last_rotated,
                usage=api_key.usage.to_dict()
            ),
            new_raw_key=new_raw_key,
            warning="Update all systems using this API key with the new value."
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to rotate API key: {str(e)}")


@router.delete("/{key_id}/revoke")
async def revoke_api_key(
    key_id: str,
    authorization: str = Header(...),
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
    current_user: str = Depends(get_current_user)
):
    """
    Revoke API key (mark as revoked and disable access).

    Revoked keys cannot be restored and will immediately stop working.
    """
    try:
        success = await api_key_service.revoke_api_key(
            key_id=key_id,
            owner_id=current_user,
            capability_token=authorization
        )

        if not success:
            raise HTTPException(status_code=404, detail="API key not found")

        return {
            "success": True,
            "message": f"API key {key_id} has been revoked",
            "key_id": key_id
        }

    except HTTPException:
        raise
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to revoke API key: {str(e)}")


@router.post("/validate", response_model=ValidateAPIKeyResponse)
async def validate_api_key(
    request: ValidateAPIKeyRequest,
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service)
):
    """
    Validate an API key and get capability token.

    This endpoint is used by other services to validate API keys
    and generate capability tokens for resource access.
    """
    try:
        valid, api_key, error_message = await api_key_service.validate_api_key(
            raw_key=request.api_key,
            endpoint=request.endpoint or "",
            client_ip=request.client_ip or "",
            user_agent=""
        )

        response = ValidateAPIKeyResponse(
            valid=valid,
            error_message=error_message,
            capability_token=None,
            rate_limit_remaining=None,
            quota_remaining=None
        )

        if valid and api_key:
            # Generate capability token
            capability_token = await api_key_service.generate_capability_token(api_key)
            response.capability_token = capability_token

            # Add rate limit and quota info
            response.rate_limit_remaining = max(0, api_key.rate_limit_per_hour - api_key.usage.requests_count)
            response.quota_remaining = max(0, api_key.daily_quota - api_key.usage.requests_count)

        return response

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to validate API key: {str(e)}")


@router.get("/{key_id}/usage", response_model=APIKeyUsageResponse)
async def get_api_key_usage(
    key_id: str,
    days: int = Query(30, description="Number of days to analyze"),
    authorization: str = Header(...),
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
    current_user: str = Depends(get_current_user)
):
    """
    Get usage analytics for a specific API key.

    - **days**: Number of days to analyze (default 30)
    """
    try:
        analytics = await api_key_service.get_usage_analytics(
            owner_id=current_user,
            key_id=key_id,
            days=days
        )

        return APIKeyUsageResponse(
            total_requests=analytics["total_requests"],
            total_errors=analytics["total_errors"],
            avg_requests_per_day=analytics["avg_requests_per_day"],
            rate_limit_hits=analytics["rate_limit_hits"],
            keys_analyzed=analytics["keys_analyzed"],
            date_range=analytics["date_range"],
            most_used_endpoints=analytics.get("most_used_endpoints", [])
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get usage analytics: {str(e)}")


@router.get("/analytics/summary", response_model=APIKeyUsageResponse)
async def get_usage_summary(
    days: int = Query(30, description="Number of days to analyze"),
    authorization: str = Header(...),
    api_key_service: EnhancedAPIKeyService = Depends(get_api_key_service),
    current_user: str = Depends(get_current_user)
):
    """
    Get usage analytics summary for all user's API keys.

    - **days**: Number of days to analyze (default 30)
    """
    try:
        analytics = await api_key_service.get_usage_analytics(
            owner_id=current_user,
            key_id=None,  # All keys
            days=days
        )

        return APIKeyUsageResponse(
            total_requests=analytics["total_requests"],
            total_errors=analytics["total_errors"],
            avg_requests_per_day=analytics["avg_requests_per_day"],
            rate_limit_hits=analytics["rate_limit_hits"],
            keys_analyzed=analytics["keys_analyzed"],
            date_range=analytics["date_range"],
            most_used_endpoints=analytics.get("most_used_endpoints", [])
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get usage summary: {str(e)}")


# Capability and scope catalogs for UI builders
@router.get("/catalog/capabilities")
async def get_capability_catalog():
    """Get available capabilities for UI builders"""
    return {
        "capabilities": [
            # AI/ML Resources
            {"value": "llm:gpt-4", "label": "GPT-4 Language Model", "category": "AI/ML"},
            {"value": "llm:claude-sonnet", "label": "Claude Sonnet", "category": "AI/ML"},
            {"value": "llm:groq", "label": "Groq Models", "category": "AI/ML"},
            {"value": "embedding:openai", "label": "OpenAI Embeddings", "category": "AI/ML"},
            {"value": "image:dall-e", "label": "DALL-E Image Generation", "category": "AI/ML"},

            # RAG & Knowledge
            {"value": "rag:search", "label": "RAG Search", "category": "Knowledge"},
            {"value": "rag:upload", "label": "Document Upload", "category": "Knowledge"},
            {"value": "rag:dataset_management", "label": "Dataset Management", "category": "Knowledge"},

            # Automation
            {"value": "automation:create", "label": "Create Automations", "category": "Automation"},
            {"value": "automation:execute", "label": "Execute Automations", "category": "Automation"},
            {"value": "automation:api_calls", "label": "API Call Actions", "category": "Automation"},
            {"value": "automation:webhooks", "label": "Webhook Actions", "category": "Automation"},

            # External Services
            {"value": "external:github", "label": "GitHub Integration", "category": "External"},
            {"value": "external:slack", "label": "Slack Integration", "category": "External"},

            # Administrative
            {"value": "admin:user_management", "label": "User Management", "category": "Admin"},
            {"value": "admin:tenant_settings", "label": "Tenant Settings", "category": "Admin"}
        ]
    }


@router.get("/catalog/scopes")
async def get_scope_catalog():
    """Get available scopes for UI builders"""
    return {
        "scopes": [
            {
                "value": "user",
                "label": "User Scope",
                "description": "Access to user-specific operations and data",
                "default_limits": {
                    "rate_limit_per_hour": 1000,
                    "daily_quota": 10000,
                    "cost_limit_cents": 1000
                }
            },
            {
                "value": "tenant",
                "label": "Tenant Scope",
                "description": "Access to tenant-wide operations and data",
                "default_limits": {
                    "rate_limit_per_hour": 5000,
                    "daily_quota": 50000,
                    "cost_limit_cents": 5000
                }
            },
            {
                "value": "admin",
                "label": "Admin Scope",
                "description": "Administrative access with elevated privileges",
                "default_limits": {
                    "rate_limit_per_hour": 10000,
                    "daily_quota": 100000,
                    "cost_limit_cents": 10000
                }
            }
        ]
    }
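A minimal validation sketch for the /validate endpoint above (not part of the commit). This router is created without a prefix, so the mount point is assumed here to be /api/v1/api-keys; the bearer token for the service dependency and the endpoint and IP values in the payload are likewise assumptions for illustration.

import asyncio
import httpx

BASE_URL = "http://localhost:8000/api/v1/api-keys"   # assumed mount prefix for this router
SERVICE_TOKEN = "<capability token>"                  # assumed; get_api_key_service requires an Authorization header
RAW_KEY = "<raw API key returned once at creation>"   # assumed

async def check_key(raw_key: str) -> dict:
    """Validate a raw API key and return the capability token plus remaining quota."""
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{BASE_URL}/validate",
            json={"api_key": raw_key, "endpoint": "/api/v1/agents", "client_ip": "203.0.113.10"},
            headers={"Authorization": f"Bearer {SERVICE_TOKEN}"},
            timeout=10.0,
        )
        resp.raise_for_status()
        data = resp.json()
        if not data["valid"]:
            raise RuntimeError(data.get("error_message") or "invalid key")
        return data

if __name__ == "__main__":
    print(asyncio.run(check_key(RAW_KEY)))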
860
apps/tenant-backend/app/api/v1/auth.py
Normal file
860
apps/tenant-backend/app/api/v1/auth.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""
|
||||
Authentication endpoints for Tenant Backend
|
||||
|
||||
Real authentication that connects to Control Panel Backend.
|
||||
No mocks - following GT 2.0 philosophy of building on real foundations.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from pydantic import BaseModel, EmailStr
|
||||
import jwt
|
||||
import structlog
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.security import create_capability_token
|
||||
from app.core.database import get_postgresql_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
security = HTTPBearer()
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# Auth logging helper function (Issue #152)
|
||||
async def log_auth_event(
|
||||
event_type: str,
|
||||
email: str,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
failure_reason: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
tenant_domain: str = None
|
||||
):
|
||||
"""
|
||||
Log authentication events to auth_logs table for security monitoring.
|
||||
|
||||
Args:
|
||||
event_type: 'login', 'logout', or 'failed_login'
|
||||
email: User's email address
|
||||
user_id: User ID (optional for failed logins)
|
||||
success: Whether the auth attempt succeeded
|
||||
failure_reason: Reason for failure (if applicable)
|
||||
ip_address: IP address of the request
|
||||
user_agent: User agent string from request
|
||||
tenant_domain: Tenant domain for the auth event
|
||||
"""
|
||||
try:
|
||||
if not tenant_domain:
|
||||
tenant_domain = settings.tenant_domain or "test-company"
|
||||
|
||||
schema_name = f"tenant_{tenant_domain.replace('-', '_')}"
|
||||
|
||||
client = await get_postgresql_client()
|
||||
query = f"""
|
||||
INSERT INTO {schema_name}.auth_logs (
|
||||
user_id,
|
||||
email,
|
||||
event_type,
|
||||
success,
|
||||
failure_reason,
|
||||
ip_address,
|
||||
user_agent,
|
||||
tenant_domain
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
"""
|
||||
|
||||
await client.execute_command(
|
||||
query,
|
||||
user_id or "unknown",
|
||||
email,
|
||||
event_type,
|
||||
success,
|
||||
failure_reason,
|
||||
ip_address,
|
||||
user_agent,
|
||||
tenant_domain
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Auth event logged",
|
||||
event_type=event_type,
|
||||
email=email,
|
||||
success=success,
|
||||
tenant=tenant_domain
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail the authentication if logging fails
|
||||
logger.error(f"Failed to log auth event: {e}", event_type=event_type, email=email)
|
||||
|
||||
|
||||
# Pydantic models
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
user: dict
|
||||
|
||||
|
||||
class TokenValidation(BaseModel):
|
||||
valid: bool
|
||||
user: Optional[dict] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
full_name: str
|
||||
user_type: str
|
||||
tenant_id: Optional[int]
|
||||
capabilities: list
|
||||
is_active: bool
|
||||
|
||||
|
||||
# Helper functions
|
||||
async def development_login(login_data: LoginRequest) -> LoginResponse:
|
||||
"""
|
||||
Development authentication that creates a valid JWT token
|
||||
for testing purposes when Control Panel is unavailable.
|
||||
"""
|
||||
# Simple development authentication - check for test credentials
|
||||
test_users = {
|
||||
"gtadmin@test.com": {"password": "password", "role": "admin"},
|
||||
"admin@test.com": {"password": "password", "role": "admin"},
|
||||
"test@example.com": {"password": "password", "role": "developer"}
|
||||
}
|
||||
|
||||
user_info = test_users.get(str(login_data.email).lower())
|
||||
if not user_info or login_data.password != user_info["password"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password"
|
||||
)
|
||||
|
||||
# Create capability token using GT 2.0 security
|
||||
# NIST compliant: 30 minutes idle timeout (Issue #242)
|
||||
token = create_capability_token(
|
||||
user_id=str(login_data.email),
|
||||
tenant_id="test-company",
|
||||
capabilities=[
|
||||
{"resource": "agents", "actions": ["read", "write", "delete"]},
|
||||
{"resource": "datasets", "actions": ["read", "write", "delete"]},
|
||||
{"resource": "conversations", "actions": ["read", "write", "delete"]}
|
||||
],
|
||||
expires_hours=0.5 # 30 minutes (NIST compliant)
|
||||
)
|
||||
|
||||
user_data = {
|
||||
"id": 1,
|
||||
"email": str(login_data.email),
|
||||
"full_name": "Test User",
|
||||
"role": user_info["role"],
|
||||
"tenant_id": "test-company",
|
||||
"tenant_domain": "test-company",
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Development login successful",
|
||||
email=login_data.email,
|
||||
tenant_id="test-company"
|
||||
)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=token,
|
||||
expires_in=1800, # 30 minutes (NIST compliant)
|
||||
user=user_data
|
||||
)
|
||||
async def verify_token_with_control_panel(token: str) -> Dict[str, Any]:
    """
    Verify a JWT token with the Control Panel Backend.
    This ensures consistency across the entire GT 2.0 platform.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{settings.control_panel_url}/api/v1/auth/verify-token",
                headers={"Authorization": f"Bearer {token}"},
                timeout=5.0
            )

            if response.status_code == 200:
                data = response.json()
                if data.get("success") and data.get("data", {}).get("valid"):
                    return data["data"]

            return {"valid": False, "error": "Invalid token"}

    except httpx.RequestError as e:
        logger.error(f"Failed to verify token with control panel: {e}")
        # Fallback to local verification if the control panel is unreachable
        try:
            # Use the same RSA keys as token creation for consistency
            from app.core.security import verify_capability_token
            payload = verify_capability_token(token)
            if payload:
                return {"valid": True, "user": payload}
            else:
                return {"valid": False, "error": "Invalid token"}
        except Exception as e:
            logger.error(f"Local token verification failed: {e}")
            return {"valid": False, "error": "Token verification failed"}

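
# --- Illustrative only (editor's sketch, not part of this commit) ---
# Shape of the Control Panel response that the happy path above expects;
# the field names inside "user" mirror how get_current_user reads the payload.
_EXPECTED_VERIFY_RESPONSE = {
    "success": True,
    "data": {
        "valid": True,
        "user": {
            "id": 42,
            "email": "user@example.com",
            "full_name": "Example User",
            "user_type": "developer",
            "tenant_id": 7,
            "capabilities": [],
            "is_active": True,
        },
    },
}
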
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> UserInfo:
    """
    Get the current user from the JWT token.
    Validates with the Control Panel for consistency.
    """
    # HTTPBearer is configured with auto_error=False, so credentials can be None
    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = credentials.credentials
    validation = await verify_token_with_control_panel(token)

    if not validation.get("valid"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=validation.get("error", "Invalid authentication token"),
            headers={"WWW-Authenticate": "Bearer"},
        )

    user_data = validation.get("user", {})

    # Ensure the user belongs to this tenant
    if settings.tenant_id and str(user_data.get("tenant_id")) != str(settings.tenant_id):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User does not belong to this tenant"
        )

    return UserInfo(
        id=user_data.get("id"),
        email=user_data.get("email"),
        full_name=user_data.get("full_name"),
        user_type=user_data.get("user_type"),
        tenant_id=user_data.get("tenant_id"),
        capabilities=user_data.get("capabilities", []),
        is_active=user_data.get("is_active", True)
    )

# API endpoints
@router.post("/login", response_model=LoginResponse)
async def login(
    login_data: LoginRequest,
    request: Request
):
    """
    Authenticate the user via the Control Panel Backend.
    For development, falls back to test authentication.
    """
    logger.warning(f"Login attempt for {login_data.email}")
    logger.warning(f"Settings environment: {settings.environment}")
    try:
        # Try Control Panel first
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{settings.control_panel_url}/api/v1/login",
                json={
                    "email": login_data.email,
                    "password": login_data.password
                },
                headers={
                    "X-Forwarded-For": request.client.host if request.client else "unknown",
                    "User-Agent": request.headers.get("user-agent", "unknown"),
                    "X-App-Type": "tenant_app"  # Distinguish from control_panel sessions
                },
                timeout=5.0
            )

            if response.status_code == 200:
                data = response.json()

                # Verify the user belongs to this tenant
                user = data.get("user", {})
                if settings.tenant_id and str(user.get("tenant_id")) != str(settings.tenant_id):
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="User does not belong to this tenant"
                    )

                logger.info(
                    f"User login successful via Control Panel: id={user.get('id')}, "
                    f"email={user.get('email')}, tenant={user.get('tenant_id')}"
                )

                # Log successful login (Issue #152)
                await log_auth_event(
                    event_type="login",
                    email=user.get("email"),
                    user_id=str(user.get("id")),
                    success=True,
                    ip_address=request.client.host if request.client else "unknown",
                    user_agent=request.headers.get("user-agent", "unknown"),
                    tenant_domain=settings.tenant_domain
                )

                return LoginResponse(
                    access_token=data["access_token"],
                    expires_in=data.get("expires_in", 86400),
                    user=user
                )
            else:
                # Control Panel returned non-200, fall back to development auth
                logger.warning(f"Control Panel returned {response.status_code}, using development auth")
                logger.warning(f"Environment is: {settings.environment}")
                if settings.environment == "development":
                    logger.warning("Calling development_login fallback")
                    return await development_login(login_data)
                else:
                    logger.warning("Not in development mode, raising 401")

                    # Log failed login attempt (Issue #152)
                    await log_auth_event(
                        event_type="failed_login",
                        email=login_data.email,
                        success=False,
                        failure_reason="Invalid credentials",
                        ip_address=request.client.host if request.client else "unknown",
                        user_agent=request.headers.get("user-agent", "unknown"),
                        tenant_domain=settings.tenant_domain
                    )

                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Invalid email or password"
                    )

    except (httpx.RequestError, httpx.TimeoutException) as e:
        logger.warning(f"Control Panel unavailable, using development auth: {e}")

        # Development fallback authentication
        if settings.environment == "development":
            return await development_login(login_data)
        else:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Authentication service unavailable"
            )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Login error: {e}")

        # Development fallback for any other errors
        if settings.environment == "development":
            logger.warning("Falling back to development auth due to error")
            return await development_login(login_data)
        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Login failed"
            )

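
# --- Illustrative only (editor's sketch, not part of this commit) ---
# Exercising the login endpoint above from a client. The base URL and the
# prefix this router is mounted under are assumptions; adjust to the deployment.
def _demo_login(base_url: str = "http://localhost:8000/api/v1/auth") -> str:
    resp = httpx.post(
        f"{base_url}/login",
        json={"email": "gtadmin@test.com", "password": "password"},
        timeout=10.0,
    )
    resp.raise_for_status()
    return resp.json()["access_token"]  # bearer token for subsequent requests
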
@router.post("/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Logout user (forward to Control Panel for audit logging).
|
||||
"""
|
||||
try:
|
||||
# Get token from request
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# Forward logout to Control Panel for audit logging
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/auth/logout",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-Forwarded-For": request.client.host if request.client else "unknown",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(
|
||||
"User logout successful",
|
||||
user_id=current_user.id,
|
||||
email=current_user.email
|
||||
)
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
else:
|
||||
# Log locally even if Control Panel fails
|
||||
logger.warning(
|
||||
"Control Panel logout failed, but logging out locally",
|
||||
user_id=current_user.id,
|
||||
status_code=response.status_code
|
||||
)
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Logout error", error=str(e), user_id=current_user.id)
|
||||
# Always return success for logout
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_current_user_info(current_user: UserInfo = Depends(get_current_user)):
|
||||
"""
|
||||
Get current user information.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"data": current_user.dict()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/verify")
|
||||
async def verify_token(
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Verify if token is valid.
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"valid": True,
|
||||
"user": current_user.dict()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Refresh authentication token via Control Panel Backend.
|
||||
Proxies the refresh request to maintain consistent token lifecycle.
|
||||
"""
|
||||
try:
|
||||
# Get current token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No token provided"
|
||||
)
|
||||
|
||||
# Forward to Control Panel Backend
|
||||
# Note: Control Panel auth endpoints are at /api/v1/* (not /api/v1/auth/*)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.control_panel_url}/api/v1/refresh",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": request.headers.get("user-agent", "unknown")
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
# Handle successful refresh
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(
|
||||
"Token refresh successful",
|
||||
user_id=current_user.id,
|
||||
email=current_user.email
|
||||
)
|
||||
return {
|
||||
"access_token": data.get("access_token", token),
|
||||
"token_type": "bearer",
|
||||
"expires_in": data.get("expires_in", 86400),
|
||||
"user": current_user.dict()
|
||||
}
|
||||
|
||||
# Handle refresh failure (expired or invalid token)
|
||||
elif response.status_code == 401:
|
||||
logger.warning(
|
||||
"Token refresh failed - token expired or invalid",
|
||||
user_id=current_user.id
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token refresh failed - please login again"
|
||||
)
|
||||
|
||||
# Handle other errors
|
||||
else:
|
||||
logger.error(
|
||||
"Token refresh unexpected response",
|
||||
status_code=response.status_code,
|
||||
user_id=current_user.id
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail="Token refresh failed"
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to forward token refresh", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token refresh service unavailable"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Token refresh proxy error", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Token refresh failed"
|
||||
)
|
||||
|
||||
|
||||
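
# --- Illustrative only (editor's sketch, not part of this commit) ---
# Refreshing a token before it expires; the router mount prefix is an assumption.
def _demo_refresh(base_url: str, token: str) -> str:
    resp = httpx.post(
        f"{base_url}/refresh",
        headers={"Authorization": f"Bearer {token}"},
        timeout=10.0,
    )
    resp.raise_for_status()
    return resp.json()["access_token"]
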
# Two-Factor Authentication Proxy Endpoints
@router.post("/tfa/enable")
async def enable_tfa_proxy(
    request: Request,
    current_user: UserInfo = Depends(get_current_user)
):
    """
    Proxy a TFA enable request to the Control Panel Backend.
    User-initiated from the settings page.
    """
    try:
        # Get token from request
        auth_header = request.headers.get("Authorization", "")
        token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""

        # Forward to Control Panel Backend
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{settings.control_panel_url}/api/v1/tfa/enable",
                headers={
                    "Authorization": f"Bearer {token}",
                    "User-Agent": request.headers.get("user-agent", "unknown")
                },
                timeout=10.0
            )

            # Return response with original status code
            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=response.json().get("detail", "TFA enable failed")
                )

            return response.json()

    except httpx.RequestError as e:
        logger.error(f"Failed to forward TFA enable request: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="TFA service unavailable"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"TFA enable proxy error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="TFA enable failed"
        )


@router.post("/tfa/verify-setup")
async def verify_setup_tfa_proxy(
    data: dict,
    request: Request,
    current_user: UserInfo = Depends(get_current_user)
):
    """
    Proxy TFA setup verification to the Control Panel Backend.
    """
    try:
        # Get token from request
        auth_header = request.headers.get("Authorization", "")
        token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""

        # Forward to Control Panel Backend
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{settings.control_panel_url}/api/v1/tfa/verify-setup",
                json=data,
                headers={
                    "Authorization": f"Bearer {token}",
                    "User-Agent": request.headers.get("user-agent", "unknown")
                },
                timeout=10.0
            )

            # Return response with original status code
            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=response.json().get("detail", "TFA verification failed")
                )

            return response.json()

    except httpx.RequestError as e:
        logger.error(f"Failed to forward TFA verify setup: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="TFA service unavailable"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"TFA verify setup proxy error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="TFA verification failed"
        )


@router.post("/tfa/disable")
async def disable_tfa_proxy(
    data: dict,
    request: Request,
    current_user: UserInfo = Depends(get_current_user)
):
    """
    Proxy a TFA disable request to the Control Panel Backend.
    Requires password confirmation.
    """
    try:
        # Get token from request
        auth_header = request.headers.get("Authorization", "")
        token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""

        # Forward to Control Panel Backend
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{settings.control_panel_url}/api/v1/tfa/disable",
                json=data,
                headers={
                    "Authorization": f"Bearer {token}",
                    "User-Agent": request.headers.get("user-agent", "unknown")
                },
                timeout=10.0
            )

            # Return response with original status code
            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=response.json().get("detail", "TFA disable failed")
                )

            return response.json()

    except httpx.RequestError as e:
        logger.error(f"Failed to forward TFA disable request: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="TFA service unavailable"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"TFA disable proxy error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="TFA disable failed"
        )


@router.post("/tfa/verify-login")
async def verify_login_tfa_proxy(data: dict, request: Request):
    """
    Proxy TFA login verification to the Control Panel Backend.
    Called after password verification with the temp token + 6-digit code.
    """
    try:
        # Forward to Control Panel Backend
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{settings.control_panel_url}/api/v1/tfa/verify-login",
                json=data,
                headers={
                    "X-Forwarded-For": request.client.host if request.client else "unknown",
                    "User-Agent": request.headers.get("user-agent", "unknown")
                },
                timeout=10.0
            )

            # Return response with original status code
            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=response.json().get("detail", "TFA verification failed")
                )

            return response.json()

    except httpx.RequestError as e:
        logger.error(f"Failed to forward TFA verify login: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="TFA service unavailable"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"TFA verify login proxy error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="TFA verification failed"
        )


@router.get("/tfa/status")
async def get_tfa_status_proxy(
    request: Request,
    current_user: UserInfo = Depends(get_current_user)
):
    """
    Proxy a TFA status request to the Control Panel Backend.
    """
    try:
        # Get token from request
        auth_header = request.headers.get("Authorization", "")
        token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""

        # Forward to Control Panel Backend
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{settings.control_panel_url}/api/v1/tfa/status",
                headers={
                    "Authorization": f"Bearer {token}",
                    "User-Agent": request.headers.get("user-agent", "unknown")
                },
                timeout=10.0
            )

            return response.json()

    except httpx.RequestError as e:
        logger.error(f"Failed to get TFA status: {e}")
        return {"tfa_enabled": False, "tfa_required": False, "tfa_status": "disabled"}
    except Exception as e:
        logger.error(f"TFA status proxy error: {e}")
        return {"tfa_enabled": False, "tfa_required": False, "tfa_status": "disabled"}

class SessionStatusResponse(BaseModel):
    """Response for session status check"""
    is_valid: bool
    seconds_remaining: int  # Seconds until idle timeout
    show_warning: bool  # True if < 5 minutes remaining
    absolute_seconds_remaining: Optional[int] = None  # Seconds until absolute timeout


@router.get("/session/status", response_model=SessionStatusResponse)
async def get_session_status(
    request: Request,
    current_user: UserInfo = Depends(get_current_user)
):
    """
    Get current session status for frontend session monitoring.

    Proxies the request to the Control Panel Backend, which is the authoritative
    source for session state. This endpoint replaces the complex react-idle-timer
    approach with a simple polling mechanism.

    The frontend calls this every 60 seconds to check session health.

    Returns:
    - is_valid: Whether the session is currently valid
    - seconds_remaining: Seconds until idle timeout (4 hours from last activity)
    - show_warning: True if a warning should be shown (< 5 min remaining)
    - absolute_seconds_remaining: Seconds until absolute timeout (8 hours from login)
    """
    try:
        # Get token from request
        auth_header = request.headers.get("Authorization", "")
        token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else ""

        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No token provided"
            )

        # Forward to Control Panel Backend (authoritative session source)
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{settings.control_panel_url}/api/v1/session/status",
                headers={
                    "Authorization": f"Bearer {token}",
                    "User-Agent": request.headers.get("user-agent", "unknown")
                },
                timeout=10.0  # Increased timeout for proxy/Cloudflare scenarios
            )

            if response.status_code == 200:
                data = response.json()
                return SessionStatusResponse(
                    is_valid=data.get("is_valid", True),
                    seconds_remaining=data.get("seconds_remaining", 14400),  # 4 hours default
                    show_warning=data.get("show_warning", False),
                    absolute_seconds_remaining=data.get("absolute_seconds_remaining")
                )

            elif response.status_code == 401:
                # Session expired - return invalid status
                return SessionStatusResponse(
                    is_valid=False,
                    seconds_remaining=0,
                    show_warning=False,
                    absolute_seconds_remaining=None
                )

            else:
                # Unexpected response - return safe defaults
                logger.warning(
                    f"Unexpected session status response {response.status_code} (user {current_user.id})"
                )
                return SessionStatusResponse(
                    is_valid=True,
                    seconds_remaining=14400,  # 4 hours default (matches IDLE_TIMEOUT_MINUTES)
                    show_warning=False,
                    absolute_seconds_remaining=None
                )

    except httpx.RequestError as e:
        # Control Panel unavailable - FAIL CLOSED for security:
        # raise instead of fabricating a status so the client re-authenticates
        logger.error(
            f"Session status check failed - Control Panel unavailable: {e} (user {current_user.id})"
        )
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Session validation service unavailable"
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Session status proxy error: {e} (user {current_user.id})")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Session status check failed"
        )

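
# --- Illustrative only (editor's sketch, not part of this commit) ---
# The 60-second polling loop described in the docstring above, sketched from the
# client's point of view. The base URL / mount prefix are assumptions.
async def _poll_session_status(base_url: str, token: str) -> None:
    import asyncio  # local import: sketch only
    async with httpx.AsyncClient(headers={"Authorization": f"Bearer {token}"}) as client:
        while True:
            resp = await client.get(f"{base_url}/session/status", timeout=10.0)
            data = resp.json()
            if not data.get("is_valid", False):
                print("Session expired - redirect to login")
                return
            if data.get("show_warning"):
                print(f"Session expires in {data['seconds_remaining']} seconds")
            await asyncio.sleep(60)
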
219
apps/tenant-backend/app/api/v1/auth_logs.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Authentication Logs API Endpoints
|
||||
Issue: #152
|
||||
|
||||
Provides endpoints for querying authentication event logs including:
|
||||
- User logins
|
||||
- User logouts
|
||||
- Failed login attempts
|
||||
|
||||
Used by observability dashboard for security monitoring and audit trails.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.database import get_postgresql_client
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/auth-logs")
|
||||
async def get_auth_logs(
|
||||
event_type: Optional[str] = Query(None, description="Filter by event type: login, logout, failed_login"),
|
||||
start_date: Optional[datetime] = Query(None, description="Start date for filtering (ISO format)"),
|
||||
end_date: Optional[datetime] = Query(None, description="End date for filtering (ISO format)"),
|
||||
user_email: Optional[str] = Query(None, description="Filter by user email"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
|
||||
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get authentication logs with optional filtering.
|
||||
|
||||
Returns paginated list of authentication events including logins, logouts,
|
||||
and failed login attempts.
|
||||
"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
schema_name = f"tenant_{tenant_domain.replace('-', '_')}"
|
||||
|
||||
# Build query conditions
|
||||
conditions = []
|
||||
params = []
|
||||
param_counter = 1
|
||||
|
||||
if event_type:
|
||||
if event_type not in ['login', 'logout', 'failed_login']:
|
||||
raise HTTPException(status_code=400, detail="Invalid event_type. Must be: login, logout, or failed_login")
|
||||
conditions.append(f"event_type = ${param_counter}")
|
||||
params.append(event_type)
|
||||
param_counter += 1
|
||||
|
||||
if start_date:
|
||||
conditions.append(f"created_at >= ${param_counter}")
|
||||
params.append(start_date)
|
||||
param_counter += 1
|
||||
|
||||
if end_date:
|
||||
conditions.append(f"created_at <= ${param_counter}")
|
||||
params.append(end_date)
|
||||
param_counter += 1
|
||||
|
||||
if user_email:
|
||||
conditions.append(f"email = ${param_counter}")
|
||||
params.append(user_email)
|
||||
param_counter += 1
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
# Get total count
|
||||
client = await get_postgresql_client()
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as total
|
||||
FROM {schema_name}.auth_logs
|
||||
{where_clause}
|
||||
"""
|
||||
count_result = await client.fetch_one(count_query, *params)
|
||||
total_count = count_result['total'] if count_result else 0
|
||||
|
||||
# Get paginated results
|
||||
query = f"""
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
email,
|
||||
event_type,
|
||||
success,
|
||||
failure_reason,
|
||||
ip_address,
|
||||
user_agent,
|
||||
tenant_domain,
|
||||
created_at,
|
||||
metadata
|
||||
FROM {schema_name}.auth_logs
|
||||
{where_clause}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ${param_counter} OFFSET ${param_counter + 1}
|
||||
"""
|
||||
params.extend([limit, offset])
|
||||
|
||||
logs = await client.fetch_all(query, *params)
|
||||
|
||||
# Format results
|
||||
formatted_logs = []
|
||||
for log in logs:
|
||||
formatted_logs.append({
|
||||
"id": str(log['id']),
|
||||
"user_id": log['user_id'],
|
||||
"email": log['email'],
|
||||
"event_type": log['event_type'],
|
||||
"success": log['success'],
|
||||
"failure_reason": log['failure_reason'],
|
||||
"ip_address": log['ip_address'],
|
||||
"user_agent": log['user_agent'],
|
||||
"tenant_domain": log['tenant_domain'],
|
||||
"created_at": log['created_at'].isoformat() if log['created_at'] else None,
|
||||
"metadata": log['metadata']
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Retrieved authentication logs",
|
||||
tenant=tenant_domain,
|
||||
count=len(formatted_logs),
|
||||
filters={"event_type": event_type, "user_email": user_email}
|
||||
)
|
||||
|
||||
return {
|
||||
"logs": formatted_logs,
|
||||
"pagination": {
|
||||
"total": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + limit) < total_count
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve auth logs: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
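
# --- Illustrative only (editor's sketch, not part of this commit) ---
# Querying the endpoint above for recent failed logins. The base URL, mount
# prefix, and bearer token are assumptions; adjust to the deployment.
import httpx

def _demo_failed_logins(base_url: str, token: str) -> list:
    resp = httpx.get(
        f"{base_url}/auth-logs",
        params={"event_type": "failed_login", "limit": 50},
        headers={"Authorization": f"Bearer {token}"},
        timeout=10.0,
    )
    resp.raise_for_status()
    return resp.json()["logs"]
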
@router.get("/auth-logs/summary")
|
||||
async def get_auth_logs_summary(
|
||||
days: int = Query(7, ge=1, le=90, description="Number of days to summarize"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get authentication log summary statistics.
|
||||
|
||||
Returns aggregated counts of login events by type for the specified time period.
|
||||
"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
schema_name = f"tenant_{tenant_domain.replace('-', '_')}"
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
client = await get_postgresql_client()
|
||||
query = f"""
|
||||
SELECT
|
||||
event_type,
|
||||
success,
|
||||
COUNT(*) as count
|
||||
FROM {schema_name}.auth_logs
|
||||
WHERE created_at >= $1 AND created_at <= $2
|
||||
GROUP BY event_type, success
|
||||
ORDER BY event_type, success
|
||||
"""
|
||||
|
||||
results = await client.fetch_all(query, start_date, end_date)
|
||||
|
||||
# Format summary
|
||||
summary = {
|
||||
"period_days": days,
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"events": {
|
||||
"successful_logins": 0,
|
||||
"failed_logins": 0,
|
||||
"logouts": 0,
|
||||
"total": 0
|
||||
}
|
||||
}
|
||||
|
||||
for row in results:
|
||||
event_type = row['event_type']
|
||||
success = row['success']
|
||||
count = row['count']
|
||||
|
||||
if event_type == 'login' and success:
|
||||
summary['events']['successful_logins'] = count
|
||||
elif event_type == 'failed_login' or (event_type == 'login' and not success):
|
||||
summary['events']['failed_logins'] += count
|
||||
elif event_type == 'logout':
|
||||
summary['events']['logouts'] = count
|
||||
|
||||
summary['events']['total'] += count
|
||||
|
||||
logger.info(
|
||||
"Retrieved auth logs summary",
|
||||
tenant=tenant_domain,
|
||||
days=days,
|
||||
total_events=summary['events']['total']
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve auth logs summary: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
539
apps/tenant-backend/app/api/v1/automation.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""
|
||||
Automation Management API
|
||||
|
||||
REST endpoints for creating, managing, and monitoring automations
|
||||
with capability-based access control.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.services.event_bus import TenantEventBus, TriggerType, EVENT_CATALOG
|
||||
from app.services.automation_executor import AutomationChainExecutor
|
||||
from app.core.dependencies import get_current_user, get_tenant_domain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/automation", tags=["automation"])
|
||||
|
||||
|
||||
class AutomationCreate(BaseModel):
|
||||
"""Create automation request"""
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
trigger_type: str = Field(..., regex="^(cron|webhook|event|manual)$")
|
||||
trigger_config: Dict[str, Any] = Field(default_factory=dict)
|
||||
conditions: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
actions: List[Dict[str, Any]] = Field(..., min_items=1)
|
||||
is_active: bool = True
|
||||
max_retries: int = Field(default=3, ge=0, le=5)
|
||||
timeout_seconds: int = Field(default=300, ge=1, le=3600)
|
||||
|
||||
|
||||
class AutomationUpdate(BaseModel):
|
||||
"""Update automation request"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
trigger_config: Optional[Dict[str, Any]] = None
|
||||
conditions: Optional[List[Dict[str, Any]]] = None
|
||||
actions: Optional[List[Dict[str, Any]]] = None
|
||||
is_active: Optional[bool] = None
|
||||
max_retries: Optional[int] = Field(None, ge=0, le=5)
|
||||
timeout_seconds: Optional[int] = Field(None, ge=1, le=3600)
|
||||
|
||||
|
||||
class AutomationResponse(BaseModel):
|
||||
"""Automation response"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
owner_id: str
|
||||
trigger_type: str
|
||||
trigger_config: Dict[str, Any]
|
||||
conditions: List[Dict[str, Any]]
|
||||
actions: List[Dict[str, Any]]
|
||||
is_active: bool
|
||||
max_retries: int
|
||||
timeout_seconds: int
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TriggerAutomationRequest(BaseModel):
|
||||
"""Manual trigger request"""
|
||||
event_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
variables: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@router.get("/catalog/events")
|
||||
async def get_event_catalog():
|
||||
"""Get available event types and their required fields"""
|
||||
return {
|
||||
"events": EVENT_CATALOG,
|
||||
"trigger_types": [trigger.value for trigger in TriggerType]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/catalog/actions")
|
||||
async def get_action_catalog():
|
||||
"""Get available action types and their configurations"""
|
||||
return {
|
||||
"actions": {
|
||||
"webhook": {
|
||||
"description": "Send HTTP request to external endpoint",
|
||||
"required_fields": ["url"],
|
||||
"optional_fields": ["method", "headers", "body"],
|
||||
"example": {
|
||||
"type": "webhook",
|
||||
"url": "https://api.example.com/notify",
|
||||
"method": "POST",
|
||||
"headers": {"Content-Type": "application/json"},
|
||||
"body": {"message": "Automation triggered"}
|
||||
}
|
||||
},
|
||||
"email": {
|
||||
"description": "Send email notification",
|
||||
"required_fields": ["to", "subject"],
|
||||
"optional_fields": ["body", "template"],
|
||||
"example": {
|
||||
"type": "email",
|
||||
"to": "user@example.com",
|
||||
"subject": "Automation Alert",
|
||||
"body": "Your automation has completed"
|
||||
}
|
||||
},
|
||||
"log": {
|
||||
"description": "Write to application logs",
|
||||
"required_fields": ["message"],
|
||||
"optional_fields": ["level"],
|
||||
"example": {
|
||||
"type": "log",
|
||||
"message": "Document processed successfully",
|
||||
"level": "info"
|
||||
}
|
||||
},
|
||||
"api_call": {
|
||||
"description": "Call internal or external API",
|
||||
"required_fields": ["endpoint"],
|
||||
"optional_fields": ["method", "headers", "body"],
|
||||
"example": {
|
||||
"type": "api_call",
|
||||
"endpoint": "/api/v1/documents/process",
|
||||
"method": "POST",
|
||||
"body": {"document_id": "${document_id}"}
|
||||
}
|
||||
},
|
||||
"data_transform": {
|
||||
"description": "Transform data between steps",
|
||||
"required_fields": ["transform_type", "source", "target"],
|
||||
"optional_fields": ["path", "mapping"],
|
||||
"example": {
|
||||
"type": "data_transform",
|
||||
"transform_type": "extract",
|
||||
"source": "api_response",
|
||||
"target": "document_id",
|
||||
"path": "data.document.id"
|
||||
}
|
||||
},
|
||||
"conditional": {
|
||||
"description": "Execute actions based on conditions",
|
||||
"required_fields": ["condition", "then"],
|
||||
"optional_fields": ["else"],
|
||||
"example": {
|
||||
"type": "conditional",
|
||||
"condition": {"left": "$status", "operator": "equals", "right": "success"},
|
||||
"then": [{"type": "log", "message": "Success"}],
|
||||
"else": [{"type": "log", "message": "Failed"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("", response_model=AutomationResponse)
|
||||
async def create_automation(
|
||||
automation: AutomationCreate,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Create a new automation"""
|
||||
try:
|
||||
# Initialize event bus
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Convert trigger type to enum
|
||||
trigger_type = TriggerType(automation.trigger_type)
|
||||
|
||||
# Create automation
|
||||
created_automation = await event_bus.create_automation(
|
||||
name=automation.name,
|
||||
owner_id=current_user,
|
||||
trigger_type=trigger_type,
|
||||
trigger_config=automation.trigger_config,
|
||||
actions=automation.actions,
|
||||
conditions=automation.conditions
|
||||
)
|
||||
|
||||
# Set additional properties
|
||||
created_automation.max_retries = automation.max_retries
|
||||
created_automation.timeout_seconds = automation.timeout_seconds
|
||||
created_automation.is_active = automation.is_active
|
||||
|
||||
# Log creation
|
||||
await event_bus.emit_event(
|
||||
event_type="automation.created",
|
||||
user_id=current_user,
|
||||
data={
|
||||
"automation_id": created_automation.id,
|
||||
"name": created_automation.name,
|
||||
"trigger_type": trigger_type.value
|
||||
}
|
||||
)
|
||||
|
||||
return AutomationResponse(
|
||||
id=created_automation.id,
|
||||
name=created_automation.name,
|
||||
description=automation.description,
|
||||
owner_id=created_automation.owner_id,
|
||||
trigger_type=trigger_type.value,
|
||||
trigger_config=created_automation.trigger_config,
|
||||
conditions=created_automation.conditions,
|
||||
actions=created_automation.actions,
|
||||
is_active=created_automation.is_active,
|
||||
max_retries=created_automation.max_retries,
|
||||
timeout_seconds=created_automation.timeout_seconds,
|
||||
created_at=created_automation.created_at.isoformat(),
|
||||
updated_at=created_automation.updated_at.isoformat()
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create automation")
|
||||
|
||||
|
||||
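
# --- Illustrative only (editor's sketch, not part of this commit) ---
# A minimal AutomationCreate payload accepted by the create endpoint above:
# a manually triggered automation with a single log action. Values are
# illustrative; trigger_type must be one of cron/webhook/event/manual.
_EXAMPLE_AUTOMATION = {
    "name": "Log on manual trigger",
    "description": "Example automation",
    "trigger_type": "manual",
    "trigger_config": {},
    "conditions": [],
    "actions": [
        {"type": "log", "message": "Automation triggered", "level": "info"}
    ],
    "is_active": True,
    "max_retries": 3,
    "timeout_seconds": 300,
}
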
@router.get("", response_model=List[AutomationResponse])
|
||||
async def list_automations(
|
||||
owner_only: bool = True,
|
||||
active_only: bool = False,
|
||||
trigger_type: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""List automations with optional filtering"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get automations
|
||||
owner_filter = current_user if owner_only else None
|
||||
automations = await event_bus.list_automations(owner_id=owner_filter)
|
||||
|
||||
# Apply filters
|
||||
filtered = []
|
||||
for automation in automations:
|
||||
if active_only and not automation.is_active:
|
||||
continue
|
||||
|
||||
if trigger_type and automation.trigger_type.value != trigger_type:
|
||||
continue
|
||||
|
||||
filtered.append(AutomationResponse(
|
||||
id=automation.id,
|
||||
name=automation.name,
|
||||
description="", # Not stored in current model
|
||||
owner_id=automation.owner_id,
|
||||
trigger_type=automation.trigger_type.value,
|
||||
trigger_config=automation.trigger_config,
|
||||
conditions=automation.conditions,
|
||||
actions=automation.actions,
|
||||
is_active=automation.is_active,
|
||||
max_retries=automation.max_retries,
|
||||
timeout_seconds=automation.timeout_seconds,
|
||||
created_at=automation.created_at.isoformat(),
|
||||
updated_at=automation.updated_at.isoformat()
|
||||
))
|
||||
|
||||
# Apply limit
|
||||
return filtered[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing automations: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list automations")
|
||||
|
||||
|
||||
@router.get("/{automation_id}", response_model=AutomationResponse)
|
||||
async def get_automation(
|
||||
automation_id: str,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Get automation by ID"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
return AutomationResponse(
|
||||
id=automation.id,
|
||||
name=automation.name,
|
||||
description="",
|
||||
owner_id=automation.owner_id,
|
||||
trigger_type=automation.trigger_type.value,
|
||||
trigger_config=automation.trigger_config,
|
||||
conditions=automation.conditions,
|
||||
actions=automation.actions,
|
||||
is_active=automation.is_active,
|
||||
max_retries=automation.max_retries,
|
||||
timeout_seconds=automation.timeout_seconds,
|
||||
created_at=automation.created_at.isoformat(),
|
||||
updated_at=automation.updated_at.isoformat()
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get automation")
|
||||
|
||||
|
||||
@router.delete("/{automation_id}")
|
||||
async def delete_automation(
|
||||
automation_id: str,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Delete automation"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Check if automation exists and user owns it
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Delete automation
|
||||
success = await event_bus.delete_automation(automation_id, current_user)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete automation")
|
||||
|
||||
# Log deletion
|
||||
await event_bus.emit_event(
|
||||
event_type="automation.deleted",
|
||||
user_id=current_user,
|
||||
data={
|
||||
"automation_id": automation_id,
|
||||
"name": automation.name
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "deleted", "automation_id": automation_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete automation")
|
||||
|
||||
|
||||
@router.post("/{automation_id}/trigger")
|
||||
async def trigger_automation(
|
||||
automation_id: str,
|
||||
trigger_request: TriggerAutomationRequest,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Manually trigger an automation"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get automation
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Check if automation supports manual triggering
|
||||
if automation.trigger_type != TriggerType.MANUAL:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Automation does not support manual triggering"
|
||||
)
|
||||
|
||||
# Create manual trigger event
|
||||
await event_bus.emit_event(
|
||||
event_type="automation.manual_trigger",
|
||||
user_id=current_user,
|
||||
data={
|
||||
"automation_id": automation_id,
|
||||
"trigger_data": trigger_request.event_data,
|
||||
"variables": trigger_request.variables
|
||||
},
|
||||
metadata={
|
||||
"trigger_type": TriggerType.MANUAL.value
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "triggered",
|
||||
"automation_id": automation_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to trigger automation")
|
||||
|
||||
|
||||
@router.get("/{automation_id}/executions")
|
||||
async def get_execution_history(
|
||||
automation_id: str,
|
||||
limit: int = 10,
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Get execution history for automation"""
|
||||
try:
|
||||
# Check automation ownership first
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Get execution history
|
||||
executor = AutomationChainExecutor(tenant_domain, event_bus)
|
||||
executions = await executor.get_execution_history(automation_id, limit)
|
||||
|
||||
return {
|
||||
"automation_id": automation_id,
|
||||
"executions": executions,
|
||||
"total": len(executions)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting execution history: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get execution history")
|
||||
|
||||
|
||||
@router.post("/{automation_id}/test")
|
||||
async def test_automation(
|
||||
automation_id: str,
|
||||
test_data: Dict[str, Any] = {},
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Test automation with sample data"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get automation
|
||||
automation = await event_bus.get_automation(automation_id)
|
||||
if not automation:
|
||||
raise HTTPException(status_code=404, detail="Automation not found")
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != current_user:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Create test event
|
||||
test_event_type = automation.trigger_config.get("event_types", ["test.event"])[0]
|
||||
|
||||
await event_bus.emit_event(
|
||||
event_type=test_event_type,
|
||||
user_id=current_user,
|
||||
data={
|
||||
"test": True,
|
||||
"automation_id": automation_id,
|
||||
**test_data
|
||||
},
|
||||
metadata={
|
||||
"test_mode": True,
|
||||
"test_timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "test_triggered",
|
||||
"automation_id": automation_id,
|
||||
"test_event": test_event_type,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing automation: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to test automation")
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_automation_stats(
|
||||
current_user: str = Depends(get_current_user),
|
||||
tenant_domain: str = Depends(get_tenant_domain)
|
||||
):
|
||||
"""Get automation statistics for current user"""
|
||||
try:
|
||||
event_bus = TenantEventBus(tenant_domain)
|
||||
|
||||
# Get user's automations
|
||||
automations = await event_bus.list_automations(owner_id=current_user)
|
||||
|
||||
# Calculate stats
|
||||
total = len(automations)
|
||||
active = sum(1 for a in automations if a.is_active)
|
||||
by_trigger_type = {}
|
||||
|
||||
for automation in automations:
|
||||
trigger = automation.trigger_type.value
|
||||
by_trigger_type[trigger] = by_trigger_type.get(trigger, 0) + 1
|
||||
|
||||
# Get recent events
|
||||
recent_events = await event_bus.get_event_history(
|
||||
user_id=current_user,
|
||||
limit=10
|
||||
)
|
||||
|
||||
return {
|
||||
"total_automations": total,
|
||||
"active_automations": active,
|
||||
"inactive_automations": total - active,
|
||||
"by_trigger_type": by_trigger_type,
|
||||
"recent_events": [
|
||||
{
|
||||
"type": event.type,
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"data": event.data
|
||||
}
|
||||
for event in recent_events
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation stats: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get automation stats")
|
||||
174
apps/tenant-backend/app/api/v1/categories.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Category API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Provides tenant-scoped agent category management with CRUD operations.
|
||||
Supports Issue #215 requirements for editable/deletable categories.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.category_service import CategoryService
|
||||
from app.schemas.category import (
|
||||
CategoryCreate,
|
||||
CategoryUpdate,
|
||||
CategoryResponse,
|
||||
CategoryListResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/categories", tags=["categories"])
|
||||
|
||||
|
||||
async def get_category_service(current_user: Dict[str, Any]) -> CategoryService:
|
||||
"""Helper function to create CategoryService with proper context"""
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
# Get user ID from token or lookup
|
||||
user_id = current_user.get('sub', current_user.get('user_id', user_email))
|
||||
tenant_domain = current_user.get('tenant_domain', 'test-company')
|
||||
|
||||
return CategoryService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=str(user_id),
|
||||
user_email=user_email
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=CategoryListResponse)
|
||||
async def list_categories(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all categories for the tenant.
|
||||
|
||||
Returns all active categories with permission flags indicating
|
||||
whether the current user can edit/delete each category.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
categories = await service.get_all_categories()
|
||||
|
||||
return CategoryListResponse(
|
||||
categories=[CategoryResponse(**cat) for cat in categories],
|
||||
total=len(categories)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing categories: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{category_id}", response_model=CategoryResponse)
|
||||
async def get_category(
|
||||
category_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get a specific category by ID.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
category = await service.get_category_by_id(category_id)
|
||||
|
||||
if not category:
|
||||
raise HTTPException(status_code=404, detail="Category not found")
|
||||
|
||||
return CategoryResponse(**category)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting category {category_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("", response_model=CategoryResponse, status_code=201)
|
||||
async def create_category(
|
||||
data: CategoryCreate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a new custom category.
|
||||
|
||||
The creating user becomes the owner and can edit/delete the category.
|
||||
All users in the tenant can use the category for their agents.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
category = await service.create_category(
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
icon=data.icon
|
||||
)
|
||||
|
||||
return CategoryResponse(**category)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating category: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{category_id}", response_model=CategoryResponse)
|
||||
async def update_category(
|
||||
category_id: str,
|
||||
data: CategoryUpdate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update a category.
|
||||
|
||||
Requires permission: admin/developer role OR be the category creator.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
category = await service.update_category(
|
||||
category_id=category_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
icon=data.icon
|
||||
)
|
||||
|
||||
return CategoryResponse(**category)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating category {category_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{category_id}")
|
||||
async def delete_category(
|
||||
category_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete a category (soft delete).
|
||||
|
||||
Requires permission: admin/developer role OR be the category creator.
|
||||
Note: Agents using this category will retain their category value,
|
||||
but the category will no longer appear in selection lists.
|
||||
"""
|
||||
try:
|
||||
service = await get_category_service(current_user)
|
||||
await service.delete_category(category_id)
|
||||
|
||||
return {"message": "Category deleted successfully"}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting category {category_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
1563
apps/tenant-backend/app/api/v1/chat.py
Normal file
File diff suppressed because it is too large
631
apps/tenant-backend/app/api/v1/conversations.py
Normal file
@@ -0,0 +1,631 @@
|
||||
"""
|
||||
Conversation API endpoints for GT 2.0 Tenant Backend - PostgreSQL Migration
|
||||
|
||||
Basic conversation endpoints during PostgreSQL migration.
|
||||
Full functionality will be restored as migration completes.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.websocket.manager import websocket_manager
|
||||
|
||||
# TEMPORARY: Basic response schemas during migration
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from fastapi import File, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
class ConversationResponse(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
agent_id: Optional[str]
|
||||
agent_name: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
last_message_at: Optional[datetime] = None
|
||||
message_count: int = 0
|
||||
token_count: int = 0
|
||||
is_archived: bool = False
|
||||
unread_count: int = 0
|
||||
|
||||
class ConversationListResponse(BaseModel):
|
||||
conversations: List[ConversationResponse]
|
||||
total: int
|
||||
|
||||
# Message creation model
|
||||
class MessageCreate(BaseModel):
|
||||
"""Request body for creating a message in a conversation"""
|
||||
role: str = Field(..., description="Message role: user, assistant, agent, or system")
|
||||
content: str = Field(..., description="Message content (supports any length)")
|
||||
model_used: Optional[str] = Field(None, description="Model used to generate the message")
|
||||
token_count: int = Field(0, ge=0, description="Token count for the message")
|
||||
metadata: Optional[Dict] = Field(None, description="Additional message metadata")
|
||||
attachments: Optional[List] = Field(None, description="Message attachments")
|
||||
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
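# Illustrative sketch (hedged): an example MessageCreate payload as accepted by
# POST /{conversation_id}/messages below. All values are assumptions for illustration.
_EXAMPLE_MESSAGE_PAYLOAD = {
    "role": "user",
    "content": "Summarize the uploaded report.",
    "model_used": None,          # usually populated on assistant/agent messages
    "token_count": 12,
    "metadata": {"source": "web"},
    "attachments": [],
}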
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/conversations", tags=["conversations"])
|
||||
|
||||
|
||||
@router.get("", response_model=ConversationListResponse)
|
||||
async def list_conversations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
agent_id: Optional[str] = Query(None, description="Filter by agent"),
|
||||
search: Optional[str] = Query(None, description="Search in conversation titles"),
|
||||
time_filter: str = Query("all", description="Filter by time: 'today', 'week', 'month', 'all'"),
|
||||
limit: int = Query(20, ge=1),
|
||||
offset: int = Query(0, ge=0)
|
||||
) -> ConversationListResponse:
|
||||
"""List user's conversations using PostgreSQL with server-side filtering"""
|
||||
try:
|
||||
# Extract tenant domain from user context
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
# Get conversations from PostgreSQL with filters
|
||||
result = await service.list_conversations(
|
||||
user_identifier=current_user["email"],
|
||||
agent_id=agent_id,
|
||||
search=search,
|
||||
time_filter=time_filter,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
conversations = [
|
||||
ConversationResponse(
|
||||
id=conv["id"],
|
||||
title=conv["title"],
|
||||
agent_id=conv["agent_id"],
|
||||
agent_name=conv.get("agent_name"),
|
||||
created_at=conv["created_at"],
|
||||
updated_at=conv.get("updated_at"),
|
||||
last_message_at=conv.get("last_message_at"),
|
||||
message_count=conv.get("message_count", 0),
|
||||
token_count=conv.get("token_count", 0),
|
||||
is_archived=conv.get("is_archived", False),
|
||||
unread_count=conv.get("unread_count", 0)
|
||||
)
|
||||
for conv in result["conversations"]
|
||||
]
|
||||
|
||||
return ConversationListResponse(
|
||||
conversations=conversations,
|
||||
total=result["total"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list conversations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
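# Illustrative sketch (hedged): calling the listing endpoint above with its
# server-side filters. The base URL and token are assumptions for illustration.
import httpx

async def _example_list_agent_conversations(agent_id: str, token: str) -> dict:
    async with httpx.AsyncClient(base_url="https://tenant.example.com") as client:
        resp = await client.get(
            "/api/v1/conversations",
            params={"agent_id": agent_id, "time_filter": "week", "limit": 20, "offset": 0},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()  # {"conversations": [...], "total": N}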
|
||||
|
||||
@router.post("", response_model=Dict[str, Any])
|
||||
async def create_conversation(
|
||||
agent_id: str,
|
||||
title: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new conversation with an agent"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
result = await service.create_conversation(
|
||||
agent_id=agent_id,
|
||||
title=title,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create conversation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=Dict[str, Any])
|
||||
async def get_conversation(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get a specific conversation with details"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
result = await service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/messages", response_model=List[Dict[str, Any]])
|
||||
async def get_conversation_messages(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0)
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get messages for a conversation"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
messages = await service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"],
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages for conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/messages", response_model=Dict[str, Any])
|
||||
async def add_message(
|
||||
conversation_id: str,
|
||||
message: MessageCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a message to a conversation (supports messages of any length)"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
result = await service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
user_identifier=current_user["email"],
|
||||
model_used=message.model_used,
|
||||
token_count=message.token_count,
|
||||
metadata=message.metadata,
|
||||
attachments=message.attachments
|
||||
)
|
||||
|
||||
# Broadcast message creation via WebSocket for unread tracking and sidebar updates
|
||||
try:
|
||||
# Get updated conversation details (message count, timestamp)
|
||||
conv_details = await service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
# Import the Socket.IO broadcast function
|
||||
from app.websocket.manager import broadcast_conversation_update
|
||||
|
||||
await broadcast_conversation_update(
|
||||
conversation_id=conversation_id,
|
||||
event='conversation:message_added',
|
||||
data={
|
||||
'conversation_id': conversation_id,
|
||||
'message_id': result.get('id'),
|
||||
'sender_id': current_user.get('id'),
|
||||
'role': message.role,
|
||||
'content': message.content[:100], # First 100 chars for preview
|
||||
'message_count': conv_details.get('message_count', conv_details.get('total_messages', 0)),
|
||||
'last_message_at': result.get('created_at'),
|
||||
'title': conv_details.get('title', 'New Conversation')
|
||||
}
|
||||
)
|
||||
except Exception as ws_error:
|
||||
logger.warning(f"Failed to broadcast message via WebSocket: {ws_error}")
|
||||
# Don't fail the request if WebSocket broadcast fails
|
||||
|
||||
# Check if we should generate a title after this message
|
||||
if message.role in ("agent", "assistant"):
|
||||
# This is an AI response - check if it's after the first user message
|
||||
try:
|
||||
# Get all messages in conversation
|
||||
messages = await service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
# Count user and agent messages
|
||||
user_messages = [m for m in messages if m["role"] == "user"]
|
||||
agent_messages = [m for m in messages if m["role"] in ["agent", "assistant"]]
|
||||
|
||||
# Generate title if this is the first exchange (1 user + 1 agent message)
|
||||
if len(user_messages) == 1 and len(agent_messages) == 1:
|
||||
logger.info(f"🎯 First exchange complete, generating title for conversation {conversation_id}")
|
||||
try:
|
||||
await service.auto_generate_conversation_title(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
logger.info(f"✅ Title generated for conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title for conversation {conversation_id}: {e}")
|
||||
# Don't fail the request if title generation fails
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for title generation: {e}")
|
||||
# Don't fail the request if title check fails
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add message to conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
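# Illustrative sketch (hedged): the endpoint above only attempts title generation
# after the first full exchange (exactly one user and one assistant/agent message).
# The helper below assumes an already-authenticated httpx.AsyncClient.
async def _example_first_exchange(client, conversation_id: str, question: str, answer: str) -> None:
    # First message: user -> counts are 1 user / 0 agent, no title generated yet
    await client.post(
        f"/api/v1/conversations/{conversation_id}/messages",
        json={"role": "user", "content": question},
    )
    # Second message: assistant -> counts become 1 / 1, auto title generation is attempted
    await client.post(
        f"/api/v1/conversations/{conversation_id}/messages",
        json={"role": "assistant", "content": answer},
    )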
|
||||
|
||||
@router.put("/{conversation_id}")
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
title: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Update a conversation (e.g., rename title)"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
success = await service.update_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"],
|
||||
title=title
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found or access denied")
|
||||
|
||||
return {"message": "Conversation updated successfully", "conversation_id": conversation_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}")
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, str]:
|
||||
"""Delete a conversation (soft delete)"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
success = await service.delete_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found or access denied")
|
||||
|
||||
return {"message": "Conversation deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
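# NOTE: FastAPI matches routes in registration order, so with GET /{conversation_id}
# declared above, a request to /recent is captured there first (conversation_id == "recent");
# registering this route before the parameterized one makes it reachable.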
@router.get("/recent", response_model=ConversationListResponse)
|
||||
async def get_recent_conversations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
days_back: int = Query(7, ge=1, le=30, description="Number of days back"),
|
||||
max_conversations: int = Query(10, ge=1, le=25, description="Maximum conversations"),
|
||||
include_archived: bool = Query(False, description="Include archived conversations")
|
||||
):
|
||||
"""
|
||||
Get recent conversation summaries.
|
||||
|
||||
Used by MCP conversation server for recent activity context.
|
||||
"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
# Get recent conversations
|
||||
result = await service.list_conversations(
|
||||
user_identifier=current_user["email"],
|
||||
limit=max_conversations,
|
||||
offset=0,
|
||||
order_by="updated_at",
|
||||
order_direction="desc"
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
conversations = [
|
||||
ConversationResponse(
|
||||
id=conv["id"],
|
||||
title=conv["title"],
|
||||
agent_id=conv["agent_id"],
|
||||
created_at=conv["created_at"]
|
||||
)
|
||||
for conv in result["conversations"]
|
||||
]
|
||||
|
||||
return ConversationListResponse(
|
||||
conversations=conversations,
|
||||
total=len(conversations)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get recent conversations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Dataset management models
|
||||
class AddDatasetsRequest(BaseModel):
|
||||
"""Request to add datasets to a conversation"""
|
||||
dataset_ids: List[str] = Field(..., min_length=1, description="Dataset IDs to add to conversation")
|
||||
|
||||
class DatasetOperationResponse(BaseModel):
|
||||
"""Response for dataset operations"""
|
||||
success: bool
|
||||
message: str
|
||||
conversation_id: str
|
||||
dataset_count: int
|
||||
|
||||
# Conversation file models
|
||||
class ConversationFileResponse(BaseModel):
|
||||
"""Response for conversation file operations"""
|
||||
id: str
|
||||
filename: str
|
||||
original_filename: str
|
||||
content_type: str
|
||||
file_size_bytes: int
|
||||
processing_status: str
|
||||
uploaded_at: datetime
|
||||
processed_at: Optional[datetime] = None
|
||||
|
||||
class ConversationFileListResponse(BaseModel):
|
||||
"""Response for listing conversation files"""
|
||||
conversation_id: str
|
||||
files: List[ConversationFileResponse]
|
||||
total_files: int
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/datasets", response_model=DatasetOperationResponse)
|
||||
async def add_datasets_to_conversation(
|
||||
conversation_id: str,
|
||||
request: AddDatasetsRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> DatasetOperationResponse:
|
||||
"""Add datasets to a conversation for context awareness"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
# Add datasets to conversation
|
||||
success = await service.add_datasets_to_conversation(
|
||||
conversation_id=conversation_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Failed to add datasets to conversation. Check conversation exists and you have access."
|
||||
)
|
||||
|
||||
logger.info(f"Added {len(request.dataset_ids)} datasets to conversation {conversation_id}")
|
||||
|
||||
return DatasetOperationResponse(
|
||||
success=True,
|
||||
message=f"Successfully added {len(request.dataset_ids)} dataset(s) to conversation",
|
||||
conversation_id=conversation_id,
|
||||
dataset_count=len(request.dataset_ids)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add datasets to conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/datasets")
|
||||
async def get_conversation_datasets(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get datasets associated with a conversation"""
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = ConversationService(tenant_domain, current_user["email"])
|
||||
|
||||
dataset_ids = await service.get_conversation_datasets(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=current_user["email"]
|
||||
)
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"dataset_ids": dataset_ids,
|
||||
"dataset_count": len(dataset_ids)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get datasets for conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Conversation File Management Endpoints
|
||||
|
||||
@router.post("/{conversation_id}/files", response_model=List[ConversationFileResponse])
|
||||
async def upload_conversation_files(
|
||||
conversation_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> List[ConversationFileResponse]:
|
||||
"""Upload files directly to conversation (replaces dataset-based uploads)"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
# Upload files to conversation
|
||||
uploaded_files = await service.upload_files(
|
||||
conversation_id=conversation_id,
|
||||
files=files,
|
||||
user_id=current_user["email"]
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
file_responses = []
|
||||
for file_data in uploaded_files:
|
||||
file_responses.append(ConversationFileResponse(
|
||||
id=file_data["id"],
|
||||
filename=file_data["filename"],
|
||||
original_filename=file_data["original_filename"],
|
||||
content_type=file_data["content_type"],
|
||||
file_size_bytes=file_data["file_size_bytes"],
|
||||
processing_status=file_data["processing_status"],
|
||||
uploaded_at=file_data["uploaded_at"],
|
||||
processed_at=file_data.get("processed_at")
|
||||
))
|
||||
|
||||
logger.info(f"Uploaded {len(uploaded_files)} files to conversation {conversation_id}")
|
||||
return file_responses
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload files to conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
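# Illustrative sketch (hedged): a multipart upload to the endpoint above.
# The filename, base URL, and token are assumptions for illustration.
import httpx

async def _example_attach_file(conversation_id: str, token: str) -> list:
    async with httpx.AsyncClient(base_url="https://tenant.example.com") as client:
        with open("report.pdf", "rb") as fh:
            resp = await client.post(
                f"/api/v1/conversations/{conversation_id}/files",
                files={"files": ("report.pdf", fh, "application/pdf")},
                headers={"Authorization": f"Bearer {token}"},
            )
        resp.raise_for_status()
        return resp.json()  # list of ConversationFileResponse payloads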
|
||||
|
||||
@router.get("/{conversation_id}/files", response_model=ConversationFileListResponse)
|
||||
async def list_conversation_files(
|
||||
conversation_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ConversationFileListResponse:
|
||||
"""List files attached to conversation"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
files = await service.list_files(conversation_id)
|
||||
|
||||
# Convert to response format
|
||||
file_responses = []
|
||||
for file_data in files:
|
||||
file_responses.append(ConversationFileResponse(
|
||||
id=str(file_data["id"]), # Convert UUID to string
|
||||
filename=file_data["filename"],
|
||||
original_filename=file_data["original_filename"],
|
||||
content_type=file_data["content_type"],
|
||||
file_size_bytes=file_data["file_size_bytes"],
|
||||
processing_status=file_data["processing_status"],
|
||||
uploaded_at=file_data["uploaded_at"],
|
||||
processed_at=file_data.get("processed_at")
|
||||
))
|
||||
|
||||
return ConversationFileListResponse(
|
||||
conversation_id=conversation_id,
|
||||
files=file_responses,
|
||||
total_files=len(file_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list files for conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}/files/{file_id}")
|
||||
async def delete_conversation_file(
|
||||
conversation_id: str,
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, str]:
|
||||
"""Delete specific file from conversation"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
success = await service.delete_file(
|
||||
conversation_id=conversation_id,
|
||||
file_id=file_id,
|
||||
user_id=current_user["email"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="File not found or access denied")
|
||||
|
||||
return {"message": "File deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {file_id} from conversation {conversation_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/files/{file_id}/download")
|
||||
async def download_conversation_file(
|
||||
conversation_id: str,
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Download specific conversation file"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
service = get_conversation_file_service(tenant_domain, current_user["email"])
|
||||
|
||||
# Get file record for metadata
|
||||
file_record = await service._get_file_record(file_id)
|
||||
if not file_record or file_record['conversation_id'] != conversation_id:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Get file content
|
||||
content = await service.get_file_content(file_id, current_user["email"])
|
||||
if not content:
|
||||
raise HTTPException(status_code=404, detail="File content not found")
|
||||
|
||||
# Return file as streaming response
|
||||
def iter_content():
|
||||
yield content
|
||||
|
||||
return StreamingResponse(
|
||||
iter_content(),
|
||||
media_type=file_record['content_type'],
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{file_record['original_filename']}\"",
|
||||
"Content-Length": str(file_record['file_size_bytes'])
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file {file_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
425
apps/tenant-backend/app/api/v1/dataset_sharing.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Dataset Sharing API for GT 2.0
|
||||
|
||||
RESTful API for hierarchical dataset sharing with capability-based access control.
|
||||
Enables secure collaboration while maintaining perfect tenant isolation.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.security import get_current_user, verify_capability_token
|
||||
from app.services.dataset_sharing import (
|
||||
DatasetSharingService, DatasetShare, DatasetInfo, SharingPermission
|
||||
)
|
||||
from app.services.access_controller import AccessController
|
||||
from app.models.access_group import AccessGroup
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class ShareDatasetRequest(BaseModel):
|
||||
"""Request to share a dataset"""
|
||||
dataset_id: str = Field(..., description="Dataset ID to share")
|
||||
access_group: str = Field(..., description="Access group: individual, team, organization")
|
||||
team_members: Optional[List[str]] = Field(None, description="Team members for team sharing")
|
||||
team_permissions: Optional[Dict[str, str]] = Field(None, description="Permissions for team members")
|
||||
expires_days: Optional[int] = Field(None, description="Expiration in days")
|
||||
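# Illustrative sketch (hedged): a team-level ShareDatasetRequest body. The ids,
# emails, and permission choices are assumptions for illustration.
_EXAMPLE_TEAM_SHARE = {
    "dataset_id": "ds-1234",
    "access_group": "team",
    "team_members": ["alice@customer1.com", "bob@customer1.com"],
    "team_permissions": {"alice@customer1.com": "write", "bob@customer1.com": "read"},
    "expires_days": 30,
}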
|
||||
|
||||
class UpdatePermissionRequest(BaseModel):
|
||||
"""Request to update team member permissions"""
|
||||
user_id: str = Field(..., description="User ID to update")
|
||||
permission: str = Field(..., description="Permission: read, write, admin")
|
||||
|
||||
|
||||
class DatasetShareResponse(BaseModel):
|
||||
"""Dataset sharing configuration response"""
|
||||
id: str
|
||||
dataset_id: str
|
||||
owner_id: str
|
||||
access_group: str
|
||||
team_members: List[str]
|
||||
team_permissions: Dict[str, str]
|
||||
shared_at: datetime
|
||||
expires_at: Optional[datetime]
|
||||
is_active: bool
|
||||
|
||||
|
||||
class DatasetInfoResponse(BaseModel):
|
||||
"""Dataset information response"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
owner_id: str
|
||||
document_count: int
|
||||
size_bytes: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
tags: List[str]
|
||||
|
||||
|
||||
class SharingStatsResponse(BaseModel):
|
||||
"""Sharing statistics response"""
|
||||
owned_datasets: int
|
||||
shared_with_me: int
|
||||
sharing_breakdown: Dict[str, int]
|
||||
total_team_members: int
|
||||
expired_shares: int
|
||||
|
||||
|
||||
# Dependency injection
|
||||
async def get_dataset_sharing_service(
|
||||
authorization: str = Header(...),
|
||||
current_user: str = Depends(get_current_user)
|
||||
) -> DatasetSharingService:
|
||||
"""Get dataset sharing service with access controller"""
|
||||
# Extract tenant from token (mock implementation)
|
||||
tenant_domain = "customer1.com" # Would extract from JWT
|
||||
|
||||
access_controller = AccessController(tenant_domain)
|
||||
return DatasetSharingService(tenant_domain, access_controller)
|
||||
|
||||
|
||||
@router.post("/share", response_model=DatasetShareResponse)
|
||||
async def share_dataset(
|
||||
request: ShareDatasetRequest,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Share a dataset with specified access group.
|
||||
|
||||
- **dataset_id**: Dataset to share
|
||||
- **access_group**: individual, team, or organization
|
||||
- **team_members**: Required for team sharing
|
||||
- **team_permissions**: Optional custom permissions
|
||||
- **expires_days**: Optional expiration period
|
||||
"""
|
||||
try:
|
||||
# Convert access group string to enum
|
||||
access_group = AccessGroup(request.access_group.lower())
|
||||
|
||||
# Convert permission strings to enums
|
||||
team_permissions = {}
|
||||
if request.team_permissions:
|
||||
for user, perm in request.team_permissions.items():
|
||||
team_permissions[user] = SharingPermission(perm.lower())
|
||||
|
||||
# Calculate expiration
|
||||
expires_at = None
|
||||
if request.expires_days:
|
||||
expires_at = datetime.utcnow() + timedelta(days=request.expires_days)
|
||||
|
||||
# Share dataset
|
||||
share = await sharing_service.share_dataset(
|
||||
dataset_id=request.dataset_id,
|
||||
owner_id=current_user,
|
||||
access_group=access_group,
|
||||
team_members=request.team_members,
|
||||
team_permissions=team_permissions,
|
||||
expires_at=expires_at,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
return DatasetShareResponse(
|
||||
id=share.id,
|
||||
dataset_id=share.dataset_id,
|
||||
owner_id=share.owner_id,
|
||||
access_group=share.access_group.value,
|
||||
team_members=share.team_members,
|
||||
team_permissions={k: v.value for k, v in share.team_permissions.items()},
|
||||
shared_at=share.shared_at,
|
||||
expires_at=share.expires_at,
|
||||
is_active=share.is_active
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to share dataset: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{dataset_id}", response_model=DatasetShareResponse)
|
||||
async def get_dataset_sharing(
|
||||
dataset_id: str,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get sharing configuration for a dataset.
|
||||
|
||||
Returns sharing details if user has access to view them.
|
||||
"""
|
||||
try:
|
||||
share = await sharing_service.get_dataset_sharing(
|
||||
dataset_id=dataset_id,
|
||||
user_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
if not share:
|
||||
raise HTTPException(status_code=404, detail="Dataset sharing not found or access denied")
|
||||
|
||||
return DatasetShareResponse(
|
||||
id=share.id,
|
||||
dataset_id=share.dataset_id,
|
||||
owner_id=share.owner_id,
|
||||
access_group=share.access_group.value,
|
||||
team_members=share.team_members,
|
||||
team_permissions={k: v.value for k, v in share.team_permissions.items()},
|
||||
shared_at=share.shared_at,
|
||||
expires_at=share.expires_at,
|
||||
is_active=share.is_active
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get dataset sharing: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{dataset_id}/access-check")
|
||||
async def check_dataset_access(
|
||||
dataset_id: str,
|
||||
permission: str = Query("read", description="Permission to check: read, write, admin"),
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Check if user has specified permission on dataset.
|
||||
|
||||
- **permission**: read, write, or admin
|
||||
"""
|
||||
try:
|
||||
# Convert permission string to enum
|
||||
required_permission = SharingPermission(permission.lower())
|
||||
|
||||
allowed, reason = await sharing_service.check_dataset_access(
|
||||
dataset_id=dataset_id,
|
||||
user_id=current_user,
|
||||
permission=required_permission
|
||||
)
|
||||
|
||||
return {
|
||||
"allowed": allowed,
|
||||
"reason": reason,
|
||||
"permission": permission,
|
||||
"user_id": current_user,
|
||||
"dataset_id": dataset_id
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid permission: {str(e)}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to check access: {str(e)}")
|
||||
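# Illustrative sketch (hedged): querying the access-check endpoint above. The mount
# prefix, base URL, and token are assumptions; adjust to where this router is included.
import httpx

async def _example_can_write(dataset_id: str, token: str) -> bool:
    async with httpx.AsyncClient(base_url="https://tenant.example.com/api/v1/datasets") as client:
        resp = await client.post(
            f"/{dataset_id}/access-check",
            params={"permission": "write"},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()["allowed"]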
|
||||
|
||||
@router.get("", response_model=List[DatasetInfoResponse])
|
||||
async def list_accessible_datasets(
|
||||
include_owned: bool = Query(True, description="Include user's own datasets"),
|
||||
include_shared: bool = Query(True, description="Include datasets shared with user"),
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List datasets accessible to user.
|
||||
|
||||
- **include_owned**: Include user's own datasets
|
||||
- **include_shared**: Include datasets shared with user
|
||||
"""
|
||||
try:
|
||||
datasets = await sharing_service.list_accessible_datasets(
|
||||
user_id=current_user,
|
||||
capability_token=authorization,
|
||||
include_owned=include_owned,
|
||||
include_shared=include_shared
|
||||
)
|
||||
|
||||
return [
|
||||
DatasetInfoResponse(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
owner_id=dataset.owner_id,
|
||||
document_count=dataset.document_count,
|
||||
size_bytes=dataset.size_bytes,
|
||||
created_at=dataset.created_at,
|
||||
updated_at=dataset.updated_at,
|
||||
tags=dataset.tags
|
||||
)
|
||||
for dataset in datasets
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list datasets: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{dataset_id}/revoke")
|
||||
async def revoke_dataset_sharing(
|
||||
dataset_id: str,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Revoke dataset sharing (make it private).
|
||||
|
||||
Only the dataset owner can revoke sharing.
|
||||
"""
|
||||
try:
|
||||
success = await sharing_service.revoke_dataset_sharing(
|
||||
dataset_id=dataset_id,
|
||||
owner_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to revoke sharing")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Dataset {dataset_id} sharing revoked",
|
||||
"dataset_id": dataset_id
|
||||
}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to revoke sharing: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/{dataset_id}/permissions")
|
||||
async def update_team_permissions(
|
||||
dataset_id: str,
|
||||
request: UpdatePermissionRequest,
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update team member permissions for a dataset.
|
||||
|
||||
Only the dataset owner can update permissions.
|
||||
"""
|
||||
try:
|
||||
# Convert permission string to enum
|
||||
permission = SharingPermission(request.permission.lower())
|
||||
|
||||
success = await sharing_service.update_team_permissions(
|
||||
dataset_id=dataset_id,
|
||||
owner_id=current_user,
|
||||
user_id=request.user_id,
|
||||
permission=permission,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to update permissions")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Updated {request.user_id} permission to {request.permission}",
|
||||
"dataset_id": dataset_id,
|
||||
"user_id": request.user_id,
|
||||
"permission": request.permission
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid permission: {str(e)}")
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update permissions: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stats/sharing", response_model=SharingStatsResponse)
|
||||
async def get_sharing_statistics(
|
||||
authorization: str = Header(...),
|
||||
sharing_service: DatasetSharingService = Depends(get_dataset_sharing_service),
|
||||
current_user: str = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get sharing statistics for current user.
|
||||
|
||||
Returns counts of owned, shared datasets and sharing breakdown.
|
||||
"""
|
||||
try:
|
||||
stats = await sharing_service.get_sharing_statistics(
|
||||
user_id=current_user,
|
||||
capability_token=authorization
|
||||
)
|
||||
|
||||
# Convert AccessGroup enum keys to strings
|
||||
sharing_breakdown = {
|
||||
group.value: count for group, count in stats["sharing_breakdown"].items()
|
||||
}
|
||||
|
||||
return SharingStatsResponse(
|
||||
owned_datasets=stats["owned_datasets"],
|
||||
shared_with_me=stats["shared_with_me"],
|
||||
sharing_breakdown=sharing_breakdown,
|
||||
total_team_members=stats["total_team_members"],
|
||||
expired_shares=stats["expired_shares"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
|
||||
|
||||
|
||||
# Action and permission catalogs for UI builders
|
||||
@router.get("/catalog/permissions")
|
||||
async def get_permission_catalog():
|
||||
"""Get available sharing permissions for UI builders"""
|
||||
return {
|
||||
"permissions": [
|
||||
{
|
||||
"value": "read",
|
||||
"label": "Read Only",
|
||||
"description": "Can view and search dataset"
|
||||
},
|
||||
{
|
||||
"value": "write",
|
||||
"label": "Read & Write",
|
||||
"description": "Can view, search, and add documents"
|
||||
},
|
||||
{
|
||||
"value": "admin",
|
||||
"label": "Administrator",
|
||||
"description": "Can modify sharing settings"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/catalog/access-groups")
|
||||
async def get_access_group_catalog():
|
||||
"""Get available access groups for UI builders"""
|
||||
return {
|
||||
"access_groups": [
|
||||
{
|
||||
"value": "individual",
|
||||
"label": "Private",
|
||||
"description": "Only accessible to owner"
|
||||
},
|
||||
{
|
||||
"value": "team",
|
||||
"label": "Team",
|
||||
"description": "Shared with specific team members"
|
||||
},
|
||||
{
|
||||
"value": "organization",
|
||||
"label": "Organization",
|
||||
"description": "Read-only access for all tenant users"
|
||||
}
|
||||
]
|
||||
}
|
||||
1024
apps/tenant-backend/app/api/v1/datasets.py
Normal file
File diff suppressed because it is too large
256
apps/tenant-backend/app/api/v1/documents.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
GT 2.0 Documents API - Wrapper for Files API
|
||||
|
||||
Provides document-centric interface that wraps the underlying files API.
|
||||
This maintains the document abstraction for the frontend while leveraging
|
||||
the existing file storage infrastructure.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, UploadFile, File, Form
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.api.v1.files import (
|
||||
get_file_info,
|
||||
download_file,
|
||||
delete_file,
|
||||
list_files,
|
||||
get_document_summary as get_file_summary,
|
||||
upload_file
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
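# Illustrative sketch (hedged): because this module proxies to the files API, a
# document upload is stored as a file in the "documents" category. The base URL,
# filename, and token below are assumptions for illustration.
import httpx

async def _example_upload_document(dataset_id: str, token: str) -> dict:
    async with httpx.AsyncClient(base_url="https://tenant.example.com/api/v1") as client:
        with open("notes.txt", "rb") as fh:
            resp = await client.post(
                "/documents",
                data={"dataset_id": dataset_id},
                files={"file": ("notes.txt", fh, "text/plain")},
                headers={"Authorization": f"Bearer {token}"},
            )
        resp.raise_for_status()
        return resp.json()  # same shape as the files API upload response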
|
||||
|
||||
@router.post("", status_code=201)
|
||||
@router.post("/", status_code=201) # Support both with and without trailing slash
|
||||
async def upload_document(
|
||||
file: UploadFile = File(...),
|
||||
dataset_id: Optional[str] = Form(None, description="Associate with dataset"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Upload document (proxy to files API) - accepts dataset_id from FormData"""
|
||||
try:
|
||||
logger.info(f"Document upload requested - file: {file.filename}, dataset_id: {dataset_id}")
|
||||
# Proxy to files upload endpoint with "documents" category
|
||||
return await upload_file(file, dataset_id, "documents", current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Document upload failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{document_id}")
|
||||
async def get_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get document details (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files API - documents are stored as files
|
||||
return await get_file_info(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{document_id}/summary")
|
||||
async def get_document_summary(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get AI-generated summary for a document (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files summary endpoint
|
||||
# codeql[py/stack-trace-exposure] proxies to files API, returns summary dict
|
||||
return await get_file_summary(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Document summary generation failed for {document_id}: {e}", exc_info=True)
|
||||
# Return a fallback response
|
||||
return {
|
||||
"summary": "Summary generation is currently unavailable. Please try again later.",
|
||||
"key_topics": [],
|
||||
"document_type": "unknown",
|
||||
"language": "en",
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_documents(
|
||||
dataset_id: Optional[str] = Query(None, description="Filter by dataset"),
|
||||
status: Optional[str] = Query(None, description="Filter by processing status"),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""List documents with optional filtering (proxy to files API)"""
|
||||
try:
|
||||
# Map documents request to files API
|
||||
# Documents are files in the "documents" category
|
||||
result = await list_files(
|
||||
dataset_id=dataset_id,
|
||||
category="documents",
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
# Extract just the files array from the response object
|
||||
# The list_files endpoint returns {files: [...], total: N, limit: N, offset: N}
|
||||
# But frontend expects just the array
|
||||
if isinstance(result, dict) and 'files' in result:
|
||||
return result['files']
|
||||
elif isinstance(result, list):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"Unexpected response format from list_files: {type(result)}")
|
||||
return []
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list documents: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{document_id}")
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Delete document and its metadata (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files delete endpoint
|
||||
return await delete_file(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{document_id}/download")
|
||||
async def download_document(
|
||||
document_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Download document file (proxy to files API)"""
|
||||
try:
|
||||
# Proxy to files download endpoint
|
||||
return await download_file(document_id, current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/{document_id}/process")
|
||||
async def process_document(
|
||||
document_id: str,
|
||||
chunking_strategy: Optional[str] = Query("hybrid", description="Chunking strategy"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Trigger document processing (chunking and embedding generation)"""
|
||||
try:
|
||||
from app.services.document_processor import get_document_processor
|
||||
from app.core.user_resolver import resolve_user_uuid
|
||||
|
||||
logger.info(f"Manual processing requested for document {document_id}")
|
||||
|
||||
# Get user info
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get document processor
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Get document info to verify it exists and get metadata
|
||||
from app.services.postgresql_file_service import PostgreSQLFileService
|
||||
file_service = PostgreSQLFileService(tenant_domain=tenant_domain, user_id=user_uuid)
|
||||
|
||||
try:
|
||||
doc_info = await file_service.get_file_info(document_id)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Trigger processing using the file service's processing method
|
||||
await file_service._process_document_from_database(
|
||||
processor=processor,
|
||||
document_id=document_id,
|
||||
dataset_id=doc_info.get("dataset_id"),
|
||||
user_uuid=user_uuid,
|
||||
filename=doc_info["original_filename"]
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Document processing started",
|
||||
"document_id": document_id,
|
||||
"chunking_strategy": chunking_strategy
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process document {document_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/processing-status")
|
||||
async def get_processing_status(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get processing status for multiple documents"""
|
||||
try:
|
||||
from app.services.document_processor import get_document_processor
|
||||
from app.core.user_resolver import resolve_user_uuid
|
||||
|
||||
# Get user info
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get document IDs from request
|
||||
document_ids = request.get("document_ids", [])
|
||||
if not document_ids:
|
||||
raise HTTPException(status_code=400, detail="document_ids required")
|
||||
|
||||
# Get processor instance
|
||||
processor = await get_document_processor(tenant_domain=tenant_domain)
|
||||
|
||||
# Get status for each document
|
||||
status_results = {}
|
||||
for doc_id in document_ids:
|
||||
try:
|
||||
status_info = await processor.get_processing_status(doc_id)
|
||||
status_results[doc_id] = {
|
||||
"status": status_info["status"],
|
||||
"error_message": status_info["error_message"],
|
||||
"progress": status_info["processing_progress"],
|
||||
"stage": status_info["processing_stage"],
|
||||
"chunks_processed": status_info["chunks_processed"],
|
||||
"total_chunks_expected": status_info["total_chunks_expected"]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status for doc {doc_id}: {e}", exc_info=True)
|
||||
status_results[doc_id] = {
|
||||
"status": "error",
|
||||
"error_message": "Failed to get processing status",
|
||||
"progress": 0,
|
||||
"stage": "unknown"
|
||||
}
|
||||
|
||||
return status_results
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get processing status: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
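# Illustrative sketch (hedged): request body and per-document response shape for the
# batch processing-status endpoint above. Document ids are assumptions for illustration.
_EXAMPLE_STATUS_REQUEST = {"document_ids": ["doc-111", "doc-222"]}
# Response maps each id to:
#   {"status": ..., "error_message": ..., "progress": ..., "stage": ...,
#    "chunks_processed": ..., "total_chunks_expected": ...}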
513
apps/tenant-backend/app/api/v1/external_services.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend - External Services API
|
||||
Manage external web service instances with Resource Cluster integration
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.api.auth import get_current_user
|
||||
from app.core.database import get_db_session
|
||||
from app.services.external_service import ExternalServiceManager
|
||||
from app.core.capability_client import CapabilityClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["external_services"])
|
||||
|
||||
class CreateServiceRequest(BaseModel):
|
||||
"""Request to create external service"""
|
||||
service_type: str = Field(..., description="Service type: ctfd, canvas, guacamole")
|
||||
service_name: str = Field(..., description="Human-readable service name")
|
||||
description: Optional[str] = Field(None, description="Service description")
|
||||
config_overrides: Optional[Dict[str, Any]] = Field(None, description="Custom configuration")
|
||||
template_id: Optional[str] = Field(None, description="Template to use as base")
|
||||
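# Illustrative sketch (hedged): a CreateServiceRequest body for a CTFd instance.
# The names and config override are assumptions for illustration.
_EXAMPLE_CTFD_REQUEST = {
    "service_type": "ctfd",
    "service_name": "Intro CTF - Fall Cohort",
    "description": "Capture-the-flag environment for the intro security course",
    "config_overrides": {"theme": "dark"},
    "template_id": None,
}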
|
||||
class ShareServiceRequest(BaseModel):
|
||||
"""Request to share service with other users"""
|
||||
share_with_emails: List[str] = Field(..., description="Email addresses to share with")
|
||||
access_level: str = Field(default="read", description="Access level: read, write")
|
||||
|
||||
class ServiceResponse(BaseModel):
|
||||
"""Service instance response"""
|
||||
id: str
|
||||
service_type: str
|
||||
service_name: str
|
||||
description: Optional[str]
|
||||
endpoint_url: str
|
||||
status: str
|
||||
health_status: str
|
||||
created_by: str
|
||||
allowed_users: List[str]
|
||||
access_level: str
|
||||
created_at: str
|
||||
last_accessed: Optional[str]
|
||||
|
||||
class ServiceListResponse(BaseModel):
|
||||
"""List of services response"""
|
||||
services: List[ServiceResponse]
|
||||
total: int
|
||||
|
||||
class EmbedConfigResponse(BaseModel):
|
||||
"""Iframe embed configuration response"""
|
||||
iframe_url: str
|
||||
sandbox_attributes: List[str]
|
||||
security_policies: Dict[str, Any]
|
||||
sso_token: str
|
||||
expires_at: str
|
||||
|
||||
class ServiceAnalyticsResponse(BaseModel):
|
||||
"""Service analytics response"""
|
||||
instance_id: str
|
||||
service_type: str
|
||||
service_name: str
|
||||
analytics_period_days: int
|
||||
total_sessions: int
|
||||
total_time_hours: float
|
||||
unique_users: int
|
||||
average_session_duration_minutes: float
|
||||
daily_usage: Dict[str, Any]
|
||||
uptime_percentage: float
|
||||
|
||||
@router.post("/create", response_model=ServiceResponse)
|
||||
async def create_external_service(
|
||||
request: CreateServiceRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceResponse:
|
||||
"""Create a new external service instance"""
|
||||
try:
|
||||
# Initialize service manager
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=24
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
# Create service instance
|
||||
instance = await service_manager.create_service_instance(
|
||||
service_type=request.service_type,
|
||||
service_name=request.service_name,
|
||||
user_email=current_user['email'],
|
||||
config_overrides=request.config_overrides,
|
||||
template_id=request.template_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created {request.service_type} service '{request.service_name}' "
|
||||
f"for user {current_user['email']}"
|
||||
)
|
||||
|
||||
return ServiceResponse(
|
||||
id=instance.id,
|
||||
service_type=instance.service_type,
|
||||
service_name=instance.service_name,
|
||||
description=instance.description,
|
||||
endpoint_url=instance.endpoint_url,
|
||||
status=instance.status,
|
||||
health_status=instance.health_status,
|
||||
created_by=instance.created_by,
|
||||
allowed_users=instance.allowed_users,
|
||||
access_level=instance.access_level,
|
||||
created_at=instance.created_at.isoformat(),
|
||||
last_accessed=instance.last_accessed.isoformat() if instance.last_accessed else None
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Invalid request: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid request parameters")
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Resource cluster error: {e}")
|
||||
raise HTTPException(status_code=502, detail="Resource cluster unavailable")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create external service: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/list", response_model=ServiceListResponse)
|
||||
async def list_external_services(
|
||||
service_type: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceListResponse:
|
||||
"""List external services accessible to the user"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
instances = await service_manager.list_user_services(
|
||||
user_email=current_user['email'],
|
||||
service_type=service_type,
|
||||
status=status
|
||||
)
|
||||
|
||||
services = [
|
||||
ServiceResponse(
|
||||
id=instance.id,
|
||||
service_type=instance.service_type,
|
||||
service_name=instance.service_name,
|
||||
description=instance.description,
|
||||
endpoint_url=instance.endpoint_url,
|
||||
status=instance.status,
|
||||
health_status=instance.health_status,
|
||||
created_by=instance.created_by,
|
||||
allowed_users=instance.allowed_users,
|
||||
access_level=instance.access_level,
|
||||
created_at=instance.created_at.isoformat(),
|
||||
last_accessed=instance.last_accessed.isoformat() if instance.last_accessed else None
|
||||
)
|
||||
for instance in instances
|
||||
]
|
||||
|
||||
return ServiceListResponse(
|
||||
services=services,
|
||||
total=len(services)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list external services: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{instance_id}", response_model=ServiceResponse)
|
||||
async def get_external_service(
|
||||
instance_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceResponse:
|
||||
"""Get specific external service details"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
instance = await service_manager.get_service_instance(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
if not instance:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Service instance not found or access denied"
|
||||
)
|
||||
|
||||
return ServiceResponse(
|
||||
id=instance.id,
|
||||
service_type=instance.service_type,
|
||||
service_name=instance.service_name,
|
||||
description=instance.description,
|
||||
endpoint_url=instance.endpoint_url,
|
||||
status=instance.status,
|
||||
health_status=instance.health_status,
|
||||
created_by=instance.created_by,
|
||||
allowed_users=instance.allowed_users,
|
||||
access_level=instance.access_level,
|
||||
created_at=instance.created_at.isoformat(),
|
||||
last_accessed=instance.last_accessed.isoformat() if instance.last_accessed else None
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get external service {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.delete("/{instance_id}")
|
||||
async def stop_external_service(
|
||||
instance_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""Stop external service instance"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=1
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
success = await service_manager.stop_service_instance(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to stop service instance"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Service instance {instance_id} stopped successfully",
|
||||
"stopped_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop external service {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{instance_id}/health")
|
||||
async def get_service_health(
|
||||
instance_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get service health status"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=1
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
health = await service_manager.get_service_health(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
# codeql[py/stack-trace-exposure] returns health status dict, not error details
|
||||
return health
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get service health for {instance_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("/{instance_id}/embed-config", response_model=EmbedConfigResponse)
|
||||
async def get_embed_config(
|
||||
instance_id: str,
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> EmbedConfigResponse:
|
||||
"""Get iframe embed configuration with SSO token"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
# Get capability token for Resource Cluster calls
|
||||
capability_client = CapabilityClient()
|
||||
capability_token = await capability_client.generate_capability_token(
|
||||
user_email=current_user['email'],
|
||||
tenant_id=current_user['tenant_id'],
|
||||
resources=['external_services'],
|
||||
expires_hours=24
|
||||
)
|
||||
service_manager.set_capability_token(capability_token)
|
||||
|
||||
# Generate SSO token and get embed config
|
||||
sso_data = await service_manager.generate_sso_token(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email']
|
||||
)
|
||||
|
||||
# Log access event
|
||||
await service_manager.log_service_access(
|
||||
service_instance_id=instance_id,
|
||||
service_type="unknown", # Will be filled by service lookup
|
||||
user_email=current_user['email'],
|
||||
access_type="embed_access",
|
||||
session_id=f"embed_{datetime.utcnow().timestamp()}",
|
||||
ip_address=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
referer=request.headers.get("referer")
|
||||
)
|
||||
|
||||
return EmbedConfigResponse(
|
||||
iframe_url=sso_data['iframe_config']['src'],
|
||||
sandbox_attributes=sso_data['iframe_config']['sandbox'],
|
||||
security_policies={
|
||||
'allow': sso_data['iframe_config']['allow'],
|
||||
'referrerpolicy': sso_data['iframe_config']['referrerpolicy'],
|
||||
'loading': sso_data['iframe_config']['loading']
|
||||
},
|
||||
sso_token=sso_data['token'],
|
||||
expires_at=sso_data['expires_at']
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Resource cluster error: {e}")
|
||||
raise HTTPException(status_code=502, detail="Resource cluster unavailable")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embed config for {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{instance_id}/analytics", response_model=ServiceAnalyticsResponse)
|
||||
async def get_service_analytics(
|
||||
instance_id: str,
|
||||
days: int = 30,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> ServiceAnalyticsResponse:
|
||||
"""Get service usage analytics"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
analytics = await service_manager.get_service_analytics(
|
||||
instance_id=instance_id,
|
||||
user_email=current_user['email'],
|
||||
days=days
|
||||
)
|
||||
|
||||
return ServiceAnalyticsResponse(**analytics)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get analytics for {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("/{instance_id}/share")
|
||||
async def share_service(
|
||||
instance_id: str,
|
||||
request: ShareServiceRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""Share service instance with other users"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
success = await service_manager.share_service_instance(
|
||||
instance_id=instance_id,
|
||||
owner_email=current_user['email'],
|
||||
share_with_emails=request.share_with_emails,
|
||||
access_level=request.access_level
|
||||
)
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"shared_with": request.share_with_emails,
|
||||
"access_level": request.access_level,
|
||||
"shared_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Service not found: {e}")
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to share service {instance_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/templates/list")
|
||||
async def list_service_templates(
|
||||
service_type: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
db=Depends(get_db_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""List available service templates"""
|
||||
try:
|
||||
service_manager = ExternalServiceManager(db)
|
||||
|
||||
templates = await service_manager.list_service_templates(
|
||||
service_type=service_type,
|
||||
category=category,
|
||||
public_only=True
|
||||
)
|
||||
|
||||
return {
|
||||
"templates": [template.to_dict() for template in templates],
|
||||
"total": len(templates)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list service templates: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/types/supported")
|
||||
async def get_supported_service_types() -> Dict[str, Any]:
|
||||
"""Get supported external service types and their capabilities"""
|
||||
return {
|
||||
"supported_types": [
|
||||
{
|
||||
"type": "ctfd",
|
||||
"name": "CTFd Platform",
|
||||
"description": "Cybersecurity capture-the-flag challenges and competitions",
|
||||
"category": "cybersecurity",
|
||||
"features": [
|
||||
"Challenge creation and management",
|
||||
"Team-based competitions",
|
||||
"Scoring and leaderboards",
|
||||
"User registration and management",
|
||||
"Real-time notifications"
|
||||
],
|
||||
"resource_requirements": {
|
||||
"cpu": "1000m",
|
||||
"memory": "2Gi",
|
||||
"storage": "7Gi"
|
||||
},
|
||||
"estimated_startup_time": "2-3 minutes",
|
||||
"sso_supported": True
|
||||
},
|
||||
{
|
||||
"type": "canvas",
|
||||
"name": "Canvas LMS",
|
||||
"description": "Learning management system for educational courses",
|
||||
"category": "education",
|
||||
"features": [
|
||||
"Course creation and management",
|
||||
"Assignment and grading system",
|
||||
"Discussion forums and messaging",
|
||||
"Grade book and analytics",
|
||||
"External tool integrations"
|
||||
],
|
||||
"resource_requirements": {
|
||||
"cpu": "2000m",
|
||||
"memory": "4Gi",
|
||||
"storage": "30Gi"
|
||||
},
|
||||
"estimated_startup_time": "3-5 minutes",
|
||||
"sso_supported": True
|
||||
},
|
||||
{
|
||||
"type": "guacamole",
|
||||
"name": "Apache Guacamole",
|
||||
"description": "Remote desktop access for cyber lab environments",
|
||||
"category": "remote_access",
|
||||
"features": [
|
||||
"RDP, VNC, and SSH connections",
|
||||
"Session recording and playback",
|
||||
"Multi-user concurrent access",
|
||||
"Connection sharing and collaboration",
|
||||
"File transfer capabilities"
|
||||
],
|
||||
"resource_requirements": {
|
||||
"cpu": "500m",
|
||||
"memory": "1Gi",
|
||||
"storage": "11Gi"
|
||||
},
|
||||
"estimated_startup_time": "2-4 minutes",
|
||||
"sso_supported": True
|
||||
}
|
||||
],
|
||||
"total_types": 3,
|
||||
"categories": ["cybersecurity", "education", "remote_access"],
|
||||
"extensible": True
|
||||
}
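The catalog above is static metadata meant to be read by clients before they provision an instance. A minimal consumer sketch follows; the base URL, the /api/v1/external-services mount point, and the bearer token are assumptions for illustration, since the router prefix is not visible in this diff.
from typing import Optional
import httpx

async def pick_lightest_service(category: str, token: str) -> Optional[dict]:
    """Pick the supported service type in a category with the smallest CPU request."""
    async with httpx.AsyncClient(base_url="https://tenant.example.com") as client:
        resp = await client.get(
            "/api/v1/external-services/types/supported",  # assumed mount point
            headers={"Authorization": f"Bearer {token}"},
            timeout=15.0,
        )
        resp.raise_for_status()
        candidates = [
            t for t in resp.json()["supported_types"] if t["category"] == category
        ]
        # "500m" -> 500 millicores; return the cheapest match, or None if no match
        return min(
            candidates,
            key=lambda t: int(t["resource_requirements"]["cpu"].rstrip("m")),
            default=None,
        )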
|
||||
342
apps/tenant-backend/app/api/v1/files.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
GT 2.0 Files API - PostgreSQL File Storage
|
||||
|
||||
Provides file upload, download, and management using PostgreSQL unified storage.
|
||||
Replaces MinIO integration with PostgreSQL 3-tier storage strategy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Query, Form
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.user_resolver import resolve_user_uuid
|
||||
from app.core.response_filter import ResponseFilter
|
||||
from app.core.permissions import get_user_role, is_effective_owner
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.services.postgresql_file_service import PostgreSQLFileService
|
||||
from app.services.document_summarizer import DocumentSummarizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/files", tags=["files"])
|
||||
|
||||
|
||||
@router.post("/upload", status_code=201)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
dataset_id: Optional[str] = Form(None, description="Associate with dataset"),
|
||||
category: str = Form("documents", description="File category"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Upload file using PostgreSQL storage"""
|
||||
try:
|
||||
logger.info(f"File upload started: {file.filename}, size: {file.size if hasattr(file, 'size') else 'unknown'}")
|
||||
logger.info(f"Current user: {current_user}")
|
||||
logger.info(f"Dataset ID: {dataset_id}, Category: {category}")
|
||||
|
||||
if not file.filename:
|
||||
logger.error("No filename provided in upload request")
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
|
||||
# Get file service with proper UUID resolution
|
||||
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
logger.info(f"Creating file service for tenant: {tenant_domain}, user: {user_email} (UUID: {user_uuid})")
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
# Store file
|
||||
logger.info(f"Storing file: {file.filename}")
|
||||
result = await file_service.store_file(
|
||||
file=file,
|
||||
dataset_id=dataset_id,
|
||||
category=category
|
||||
)
|
||||
|
||||
logger.info(f"File uploaded successfully: {file.filename} -> {result['id']}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"File upload failed for {file.filename if file and file.filename else 'unknown'}: {e}", exc_info=True)
|
||||
logger.error(f"Exception type: {type(e).__name__}")
|
||||
logger.error(f"Current user context: {current_user}")
|
||||
raise HTTPException(status_code=500, detail="Failed to upload file")
|
||||
|
||||
|
||||
@router.get("/{file_id}")
|
||||
async def download_file(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Download file by ID with streaming support"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
# Get file info first
|
||||
file_info = await file_service.get_file_info(file_id)
|
||||
|
||||
# Stream file content
|
||||
file_stream = file_service.get_file(file_id)
|
||||
|
||||
return StreamingResponse(
|
||||
file_stream,
|
||||
media_type=file_info['content_type'],
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{file_info['original_filename']}\"",
|
||||
"Content-Length": str(file_info['file_size'])
|
||||
}
|
||||
)
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
except Exception as e:
|
||||
logger.error(f"File download failed for {file_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{file_id}/info")
|
||||
async def get_file_info(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get file metadata"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
file_info = await file_service.get_file_info(file_id)
|
||||
|
||||
# Apply security filtering using effective ownership
# (reuses pg_client and user_role resolved earlier in this handler)
is_owner = is_effective_owner(file_info.get("user_id"), user_uuid, user_role)
|
||||
|
||||
filtered_info = ResponseFilter.filter_file_response(
|
||||
file_info,
|
||||
is_owner=is_owner
|
||||
)
|
||||
|
||||
return filtered_info
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Get file info failed for {file_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_files(
|
||||
dataset_id: Optional[str] = Query(None, description="Filter by dataset"),
|
||||
category: str = Query("documents", description="Filter by category"),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""List user files with filtering"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
files = await file_service.list_files(
|
||||
dataset_id=dataset_id,
|
||||
category=category,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
# Apply security filtering to file list using effective ownership
|
||||
filtered_files = []
|
||||
for file_info in files:
|
||||
is_owner = is_effective_owner(file_info.get("user_id"), user_uuid, user_role)
|
||||
filtered_file = ResponseFilter.filter_file_response(
|
||||
file_info,
|
||||
is_owner=is_owner
|
||||
)
|
||||
filtered_files.append(filtered_file)
|
||||
|
||||
return {
|
||||
"files": filtered_files,
|
||||
"total": len(filtered_files),
|
||||
"limit": limit,
|
||||
"offset": offset
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"List files failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{file_id}")
|
||||
async def delete_file(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Delete file and its metadata"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
success = await file_service.delete_file(file_id)
|
||||
|
||||
if success:
|
||||
return {"message": "File deleted successfully"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="File not found or delete failed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Delete file failed for {file_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/cleanup")
|
||||
async def cleanup_orphaned_files(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Clean up orphaned files (admin operation)"""
|
||||
try:
|
||||
# Only allow admin users to run cleanup
|
||||
user_roles = current_user.get('roles', [])
|
||||
if 'admin' not in user_roles:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get user role for permission checks
|
||||
pg_client = await get_postgresql_client()
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid,
|
||||
user_role=user_role
|
||||
)
|
||||
|
||||
cleanup_count = await file_service.cleanup_orphaned_files()
|
||||
|
||||
return {
|
||||
"message": f"Cleaned up {cleanup_count} orphaned files",
|
||||
"count": cleanup_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{file_id}/summary")
|
||||
async def get_document_summary(
|
||||
file_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get AI-generated summary for a document"""
|
||||
try:
|
||||
# Get file service with proper UUID resolution
|
||||
tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
|
||||
|
||||
# Get file service to retrieve document content
|
||||
file_service = PostgreSQLFileService(
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid
|
||||
)
|
||||
|
||||
# Get file info
|
||||
file_info = await file_service.get_file_info(file_id)
|
||||
|
||||
# Initialize summarizer
|
||||
summarizer = DocumentSummarizer()
|
||||
|
||||
# Get file content (for text files)
|
||||
# Note: This assumes text content is available
|
||||
# In production, you'd need to extract text from PDFs, etc.
|
||||
file_stream = file_service.get_file(file_id)
|
||||
content = ""
|
||||
async for chunk in file_stream:
|
||||
content += chunk.decode('utf-8', errors='ignore')
|
||||
|
||||
# Generate summary
|
||||
summary_result = await summarizer.generate_document_summary(
|
||||
document_id=file_id,
|
||||
content=content[:summarizer.max_content_length], # Truncate if too long
|
||||
filename=file_info['original_filename'],
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_uuid
|
||||
)
|
||||
|
||||
# codeql[py/stack-trace-exposure] returns document summary dict, not error details
|
||||
return {
|
||||
"summary": summary_result.get("summary", "No summary available"),
|
||||
"key_topics": summary_result.get("key_topics", []),
|
||||
"document_type": summary_result.get("document_type"),
|
||||
"language": summary_result.get("language", "en"),
|
||||
"metadata": summary_result.get("metadata", {})
|
||||
}
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Document summary generation failed for {file_id}: {e}", exc_info=True)
|
||||
# Return a fallback response instead of failing completely
|
||||
return {
|
||||
"summary": "Summary generation is currently unavailable. Please try again later.",
|
||||
"key_topics": [],
|
||||
"document_type": "unknown",
|
||||
"language": "en",
|
||||
"metadata": {}
|
||||
}
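As the inline note above says, this route only decodes raw bytes, so binary formats would need real text extraction before summarization. A minimal sketch of that step, assuming the pypdf package were added (it is not referenced elsewhere in this commit):
from io import BytesIO
from pypdf import PdfReader  # hypothetical dependency, not part of this commit

def extract_text(filename: str, raw: bytes) -> str:
    """Return plain text from a PDF, or fall back to best-effort UTF-8 decoding."""
    if filename.lower().endswith(".pdf"):
        reader = PdfReader(BytesIO(raw))
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    return raw.decode("utf-8", errors="ignore")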
|
||||
520
apps/tenant-backend/app/api/v1/games.py
Normal file
@@ -0,0 +1,520 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.services.game_service import GameService, PuzzleService, PhilosophicalDialogueService
|
||||
from app.models.game import GameSession, PuzzleSession, PhilosophicalDialogue, LearningAnalytics
|
||||
|
||||
router = APIRouter(prefix="/games", tags=["AI Literacy & Games"])
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class GameConfigRequest(BaseModel):
|
||||
game_type: str = Field(..., description="Type of game: chess, go")
|
||||
difficulty: str = Field(default="intermediate", description="Difficulty level")
|
||||
name: Optional[str] = Field(None, description="Custom game name")
|
||||
ai_personality: Optional[str] = Field(default="teaching", description="AI opponent personality")
|
||||
time_control: Optional[str] = Field(None, description="Time control settings")
|
||||
|
||||
|
||||
class GameMoveRequest(BaseModel):
|
||||
move_data: Dict[str, Any] = Field(..., description="Move data specific to game type")
|
||||
request_analysis: Optional[bool] = Field(default=False, description="Request move analysis")
|
||||
|
||||
|
||||
class PuzzleConfigRequest(BaseModel):
|
||||
puzzle_type: str = Field(..., description="Type of puzzle: lateral_thinking, logical_deduction, etc.")
|
||||
difficulty: Optional[int] = Field(None, description="Difficulty level 1-10", ge=1, le=10)
|
||||
category: Optional[str] = Field(None, description="Puzzle category")
|
||||
|
||||
|
||||
class PuzzleSolutionRequest(BaseModel):
|
||||
solution: Dict[str, Any] = Field(..., description="Puzzle solution attempt")
|
||||
reasoning: Optional[str] = Field(None, description="User's reasoning explanation")
|
||||
|
||||
|
||||
class HintRequest(BaseModel):
|
||||
hint_level: int = Field(default=1, description="Hint level 1-3", ge=1, le=3)
|
||||
|
||||
|
||||
class DilemmaConfigRequest(BaseModel):
|
||||
dilemma_type: str = Field(..., description="Type of dilemma: ethical_frameworks, game_theory, ai_consciousness")
|
||||
topic: str = Field(..., description="Specific topic within the dilemma type")
|
||||
complexity: Optional[str] = Field(default="intermediate", description="Complexity level")
|
||||
|
||||
|
||||
class DilemmaResponseRequest(BaseModel):
|
||||
response: str = Field(..., description="User's response to the dilemma")
|
||||
framework: Optional[str] = Field(None, description="Ethical framework being applied")
|
||||
|
||||
|
||||
class DilemmaFinalPositionRequest(BaseModel):
|
||||
final_position: Dict[str, Any] = Field(..., description="User's final position on the dilemma")
|
||||
key_insights: Optional[List[str]] = Field(default=[], description="Key insights gained")
|
||||
|
||||
|
||||
# Strategic Games Endpoints
|
||||
@router.get("/available")
|
||||
async def get_available_games(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get available games and user progress"""
|
||||
service = GameService(db)
|
||||
return await service.get_available_games(current_user["user_id"])
|
||||
|
||||
|
||||
@router.post("/start")
|
||||
async def start_game_session(
|
||||
config: GameConfigRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Start a new game session"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
session = await service.start_game_session(
|
||||
user_id=current_user["user_id"],
|
||||
game_type=config.game_type,
|
||||
config=config.dict()
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session.id,
|
||||
"game_type": session.game_type,
|
||||
"difficulty": session.difficulty_level,
|
||||
"initial_state": session.current_state,
|
||||
"ai_config": session.ai_opponent_config,
|
||||
"started_at": session.started_at.isoformat()
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
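A request against this endpoint is just the GameConfigRequest fields serialized as JSON. A hedged example call, where the path prefix and token are assumptions:
import httpx

def start_chess(token: str) -> dict:
    resp = httpx.post(
        "https://tenant.example.com/api/v1/games/start",  # assumed mount point
        headers={"Authorization": f"Bearer {token}"},
        json={"game_type": "chess", "difficulty": "intermediate", "ai_personality": "teaching"},
        timeout=30.0,
    )
    resp.raise_for_status()
    return resp.json()  # includes session_id and initial_state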
|
||||
|
||||
|
||||
@router.get("/session/{session_id}")
|
||||
async def get_game_session(
|
||||
session_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get current game session details"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
analysis = await service.get_game_analysis(session_id, current_user["user_id"])
|
||||
return analysis
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/session/{session_id}/move")
|
||||
async def make_move(
|
||||
session_id: str,
|
||||
move: GameMoveRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Make a move in the game"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
result = await service.make_move(
|
||||
session_id=session_id,
|
||||
user_id=current_user["user_id"],
|
||||
move_data=move.move_data
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/session/{session_id}/analysis")
|
||||
async def get_game_analysis(
|
||||
session_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get detailed game analysis"""
|
||||
service = GameService(db)
|
||||
|
||||
try:
|
||||
analysis = await service.get_game_analysis(session_id, current_user["user_id"])
|
||||
return analysis
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_game_history(
|
||||
game_type: Optional[str] = Query(None, description="Filter by game type"),
|
||||
limit: int = Query(20, description="Number of games to return", ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get user's game history"""
|
||||
service = GameService(db)
|
||||
return await service.get_user_game_history(
|
||||
user_id=current_user["user_id"],
|
||||
game_type=game_type,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
|
||||
# Logic Puzzles Endpoints
|
||||
@router.get("/puzzles/available")
|
||||
async def get_available_puzzles(
|
||||
category: Optional[str] = Query(None, description="Filter by puzzle category"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get available puzzle categories and difficulty recommendations"""
|
||||
service = PuzzleService(db)
|
||||
return await service.get_available_puzzles(current_user["user_id"], category)
|
||||
|
||||
|
||||
@router.post("/puzzles/start")
|
||||
async def start_puzzle_session(
|
||||
config: PuzzleConfigRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Start a new puzzle session"""
|
||||
service = PuzzleService(db)
|
||||
|
||||
try:
|
||||
session = await service.start_puzzle_session(
|
||||
user_id=current_user["user_id"],
|
||||
puzzle_type=config.puzzle_type,
|
||||
difficulty=config.difficulty
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session.id,
|
||||
"puzzle_type": session.puzzle_type,
|
||||
"difficulty": session.difficulty_rating,
|
||||
"puzzle": session.puzzle_definition,
|
||||
"estimated_time": session.estimated_time_minutes,
|
||||
"started_at": session.started_at.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/puzzles/{session_id}/solve")
|
||||
async def submit_puzzle_solution(
|
||||
session_id: str,
|
||||
solution: PuzzleSolutionRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Submit a solution for the puzzle"""
|
||||
service = PuzzleService(db)
|
||||
|
||||
try:
|
||||
result = await service.submit_solution(
|
||||
session_id=session_id,
|
||||
user_id=current_user["user_id"],
|
||||
solution=solution.solution,
|
||||
reasoning=solution.reasoning
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/puzzles/{session_id}/hint")
|
||||
async def get_puzzle_hint(
|
||||
session_id: str,
|
||||
hint_request: HintRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get a hint for the current puzzle"""
|
||||
service = PuzzleService(db)
|
||||
|
||||
try:
|
||||
hint = await service.get_hint(
|
||||
session_id=session_id,
|
||||
user_id=current_user["user_id"],
|
||||
hint_level=hint_request.hint_level
|
||||
)
|
||||
return hint
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# Philosophical Dilemmas Endpoints
|
||||
@router.get("/dilemmas/available")
|
||||
async def get_available_dilemmas(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get available philosophical dilemmas"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
return await service.get_available_dilemmas(current_user["user_id"])
|
||||
|
||||
|
||||
@router.post("/dilemmas/start")
|
||||
async def start_philosophical_dialogue(
|
||||
config: DilemmaConfigRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Start a new philosophical dialogue session"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
|
||||
try:
|
||||
dialogue = await service.start_dialogue_session(
|
||||
user_id=current_user["user_id"],
|
||||
dilemma_type=config.dilemma_type,
|
||||
topic=config.topic
|
||||
)
|
||||
|
||||
return {
|
||||
"dialogue_id": dialogue.id,
|
||||
"dilemma_title": dialogue.dilemma_title,
|
||||
"scenario": dialogue.scenario_description,
|
||||
"framework_options": dialogue.framework_options,
|
||||
"complexity": dialogue.complexity_level,
|
||||
"estimated_time": dialogue.estimated_discussion_time,
|
||||
"started_at": dialogue.started_at.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/dilemmas/{dialogue_id}/respond")
|
||||
async def submit_dilemma_response(
|
||||
dialogue_id: str,
|
||||
response: DilemmaResponseRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Submit a response to the philosophical dilemma"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
|
||||
try:
|
||||
result = await service.submit_response(
|
||||
dialogue_id=dialogue_id,
|
||||
user_id=current_user["user_id"],
|
||||
response=response.response,
|
||||
framework=response.framework
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/dilemmas/{dialogue_id}/conclude")
|
||||
async def conclude_philosophical_dialogue(
|
||||
dialogue_id: str,
|
||||
final_position: DilemmaFinalPositionRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Conclude the philosophical dialogue with final assessment"""
|
||||
service = PhilosophicalDialogueService(db)
|
||||
|
||||
try:
|
||||
result = await service.conclude_dialogue(
|
||||
dialogue_id=dialogue_id,
|
||||
user_id=current_user["user_id"],
|
||||
final_position=final_position.final_position
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/dilemmas/{dialogue_id}")
|
||||
async def get_dialogue_session(
|
||||
dialogue_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get philosophical dialogue session details"""
|
||||
query = """
|
||||
SELECT * FROM philosophical_dialogues
|
||||
WHERE id = :dialogue_id AND user_id = :user_id
|
||||
"""
|
||||
|
||||
result = await db.execute(text(query), {
|
||||
"dialogue_id": dialogue_id,
|
||||
"user_id": current_user["user_id"]
|
||||
})
|
||||
dialogue = result.mappings().fetchone()  # mapping row supports the dialogue["..."] lookups below
|
||||
|
||||
if not dialogue:
|
||||
raise HTTPException(status_code=404, detail="Dialogue session not found")
|
||||
|
||||
return {
|
||||
"dialogue_id": dialogue["id"],
|
||||
"dilemma_title": dialogue["dilemma_title"],
|
||||
"scenario": dialogue["scenario_description"],
|
||||
"dialogue_history": dialogue["dialogue_history"],
|
||||
"frameworks_explored": dialogue["frameworks_explored"],
|
||||
"status": dialogue["dialogue_status"],
|
||||
"exchange_count": dialogue["exchange_count"],
|
||||
"started_at": dialogue["started_at"],
|
||||
"last_exchange_at": dialogue["last_exchange_at"]
|
||||
}
|
||||
|
||||
|
||||
# Learning Analytics Endpoints
|
||||
@router.get("/analytics/progress")
|
||||
async def get_learning_progress(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get comprehensive learning progress and analytics"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
return {
|
||||
"overall_progress": {
|
||||
"total_sessions": analytics.total_sessions,
|
||||
"total_time_minutes": analytics.total_time_minutes,
|
||||
"current_streak": analytics.current_streak_days,
|
||||
"longest_streak": analytics.longest_streak_days
|
||||
},
|
||||
"game_ratings": {
|
||||
"chess": analytics.chess_rating,
|
||||
"go": analytics.go_rating,
|
||||
"puzzle_level": analytics.puzzle_solving_level,
|
||||
"philosophical_depth": analytics.philosophical_depth_level
|
||||
},
|
||||
"cognitive_skills": {
|
||||
"strategic_thinking": analytics.strategic_thinking_score,
|
||||
"logical_reasoning": analytics.logical_reasoning_score,
|
||||
"creative_problem_solving": analytics.creative_problem_solving_score,
|
||||
"ethical_reasoning": analytics.ethical_reasoning_score,
|
||||
"pattern_recognition": analytics.pattern_recognition_score,
|
||||
"metacognitive_awareness": analytics.metacognitive_awareness_score
|
||||
},
|
||||
"thinking_style": {
|
||||
"system1_reliance": analytics.system1_reliance_average,
|
||||
"system2_engagement": analytics.system2_engagement_average,
|
||||
"intuition_accuracy": analytics.intuition_accuracy_score,
|
||||
"reflection_frequency": analytics.reflection_frequency_score
|
||||
},
|
||||
"ai_collaboration": {
|
||||
"dependency_index": analytics.ai_dependency_index,
|
||||
"prompt_engineering": analytics.prompt_engineering_skill,
|
||||
"output_evaluation": analytics.ai_output_evaluation_skill,
|
||||
"collaborative_solving": analytics.collaborative_problem_solving
|
||||
},
|
||||
"achievements": analytics.achievement_badges,
|
||||
"recommendations": analytics.recommended_activities,
|
||||
"last_activity": analytics.last_activity_date.isoformat() if analytics.last_activity_date else None
|
||||
}
|
||||
|
||||
|
||||
@router.get("/analytics/trends")
|
||||
async def get_learning_trends(
|
||||
timeframe: str = Query("30d", description="Timeframe: 7d, 30d, 90d, 1y"),
|
||||
skill_area: Optional[str] = Query(None, description="Specific skill area to analyze"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get learning trends and performance over time"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
# This would typically involve more complex analytics queries
|
||||
# For now, return structured trend data
|
||||
return {
|
||||
"timeframe": timeframe,
|
||||
"skill_progression": analytics.skill_progression_data,
|
||||
"performance_trends": {
|
||||
"chess_rating_history": [{"date": "2024-01-01", "rating": 1200}], # Mock data
|
||||
"puzzle_completion_rate": [{"week": 1, "rate": 0.75}],
|
||||
"session_frequency": [{"week": 1, "sessions": 3}]
|
||||
},
|
||||
"comparative_metrics": {
|
||||
"peer_comparison": "Above average",
|
||||
"improvement_rate": "15% this month",
|
||||
"consistency_score": 0.85
|
||||
},
|
||||
"insights": [
|
||||
"Strong improvement in strategic thinking",
|
||||
"Puzzle-solving speed has increased 20%",
|
||||
"Consider more challenging philosophical dilemmas"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/analytics/recommendations")
|
||||
async def get_learning_recommendations(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get personalized learning recommendations"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
return {
|
||||
"next_activities": [
|
||||
{
|
||||
"type": "chess",
|
||||
"difficulty": "advanced",
|
||||
"reason": "Your tactical skills have improved significantly",
|
||||
"estimated_time": 30
|
||||
},
|
||||
{
|
||||
"type": "logical_deduction",
|
||||
"difficulty": 6,
|
||||
"reason": "Ready for more complex reasoning challenges",
|
||||
"estimated_time": 20
|
||||
},
|
||||
{
|
||||
"type": "ai_consciousness",
|
||||
"topic": "chinese_room",
|
||||
"reason": "Explore deeper philosophical concepts",
|
||||
"estimated_time": 25
|
||||
}
|
||||
],
|
||||
"skill_focus_areas": [
|
||||
"Synthesis of multiple ethical frameworks",
|
||||
"Advanced strategic planning in Go",
|
||||
"Metacognitive awareness development"
|
||||
],
|
||||
"adaptive_settings": {
|
||||
"chess_difficulty": "advanced",
|
||||
"puzzle_difficulty": min(analytics.puzzle_solving_level + 1, 10),
|
||||
"philosophical_complexity": "advanced" if analytics.philosophical_depth_level > 6 else "intermediate"
|
||||
},
|
||||
"learning_goals": analytics.learning_goals or [
|
||||
"Improve system 2 thinking engagement",
|
||||
"Develop better AI collaboration skills",
|
||||
"Master ethical framework application"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/analytics/reflection")
|
||||
async def submit_learning_reflection(
|
||||
reflection_data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Submit learning reflection and self-assessment"""
|
||||
service = GameService(db)
|
||||
analytics = await service.get_or_create_analytics(current_user["user_id"])
|
||||
|
||||
# Process reflection data and update analytics
|
||||
# This would involve sophisticated analysis of user self-reflection
|
||||
|
||||
return {
|
||||
"reflection_recorded": True,
|
||||
"metacognitive_feedback": "Your self-awareness of thinking patterns is improving",
|
||||
"updated_recommendations": [
|
||||
"Continue exploring areas where intuition conflicts with analysis",
|
||||
"Practice explaining your reasoning process more explicitly"
|
||||
],
|
||||
"insights_gained": reflection_data.get("insights", [])
|
||||
}
|
||||
172
apps/tenant-backend/app/api/v1/models.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Tenant Models API - Interface to Resource Cluster Model Management
|
||||
|
||||
Provides tenant-scoped access to available AI models from the Resource Cluster.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.core.config import get_settings
|
||||
from app.core.cache import get_cache
|
||||
from app.services.resource_cluster_client import ResourceClusterClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
cache = get_cache()
|
||||
|
||||
router = APIRouter(prefix="/api/v1/models", tags=["Models"])
|
||||
|
||||
|
||||
@router.get("/", summary="List available models for tenant")
|
||||
async def list_available_models(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get list of AI models available to the current tenant"""
|
||||
|
||||
try:
|
||||
# Get tenant domain from current user
|
||||
tenant_domain = current_user.get("tenant_domain", "default")
|
||||
|
||||
# Check cache first (5-minute TTL)
|
||||
cache_key = f"models_list_{tenant_domain}"
|
||||
cached_models = cache.get(cache_key, ttl=300)
|
||||
if cached_models:
|
||||
logger.debug(f"Returning cached model list for tenant {tenant_domain}")
|
||||
return {**cached_models, "cached": True}
|
||||
|
||||
# Call Resource Cluster models API - use Docker service name if in container
|
||||
import os
|
||||
if os.path.exists('/.dockerenv'):
|
||||
resource_cluster_url = "http://resource-cluster:8000"
|
||||
else:
|
||||
resource_cluster_url = settings.resource_cluster_url
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{resource_cluster_url}/api/v1/models/",
|
||||
headers={
|
||||
"X-Tenant-Domain": tenant_domain
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
models_data = response.json()
|
||||
models = models_data.get("models", [])
|
||||
|
||||
# Filter models by health and deployment status
|
||||
available_models = [
|
||||
{
|
||||
"value": model["id"], # model_id string for backwards compatibility
|
||||
"uuid": model.get("uuid"), # Database UUID for unique identification
|
||||
"label": model["name"],
|
||||
"description": model["description"],
|
||||
"provider": model["provider"],
|
||||
"model_type": model["model_type"],
|
||||
"max_tokens": model["performance"]["max_tokens"],
|
||||
"context_window": model["performance"]["context_window"],
|
||||
"cost_per_1k_tokens": model["performance"]["cost_per_1k_tokens"],
|
||||
"latency_p50_ms": model["performance"]["latency_p50_ms"],
|
||||
"health_status": model["status"]["health"],
|
||||
"deployment_status": model["status"]["deployment"]
|
||||
}
|
||||
for model in models
|
||||
if (model["status"]["deployment"] == "available" and
|
||||
model["status"]["health"] in ["healthy", "unknown"] and
|
||||
model["model_type"] != "embedding")
|
||||
]
|
||||
|
||||
# Sort by provider preference (NVIDIA first, then Groq) and then by performance
|
||||
provider_order = {"nvidia": 0, "groq": 1}
|
||||
available_models.sort(key=lambda x: (
|
||||
provider_order.get(x["provider"], 99), # NVIDIA first, then Groq
|
||||
x["latency_p50_ms"] or 999 # Lower latency first
|
||||
))
|
||||
|
||||
result = {
|
||||
"models": available_models,
|
||||
"total": len(available_models),
|
||||
"tenant_domain": tenant_domain,
|
||||
"last_updated": models_data.get("last_updated"),
|
||||
"cached": False
|
||||
}
|
||||
|
||||
# Cache the result for 5 minutes
|
||||
cache.set(cache_key, result)
|
||||
logger.debug(f"Cached model list for tenant {tenant_domain}")
|
||||
|
||||
return result
|
||||
|
||||
else:
|
||||
# Resource Cluster unavailable - return empty list
|
||||
logger.warning(f"Resource Cluster unavailable (HTTP {response.status_code})")
|
||||
return {
|
||||
"models": [],
|
||||
"total": 0,
|
||||
"tenant_domain": tenant_domain,
|
||||
"message": "No models available - resource cluster unavailable"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models from Resource Cluster: {e}")
|
||||
# Return empty list in case of error
|
||||
return {
|
||||
"models": [],
|
||||
"total": 0,
|
||||
"tenant_domain": current_user.get("tenant_domain", "default"),
|
||||
"message": "No models available - service error"
|
||||
}
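The provider-preference sort inside the happy path above puts NVIDIA models before Groq and breaks ties on p50 latency. A standalone illustration with fabricated entries:
provider_order = {"nvidia": 0, "groq": 1}
models = [
    {"label": "B", "provider": "groq", "latency_p50_ms": 120},
    {"label": "A", "provider": "nvidia", "latency_p50_ms": 300},
    {"label": "C", "provider": "nvidia", "latency_p50_ms": None},
]
models.sort(key=lambda m: (provider_order.get(m["provider"], 99),
                           m["latency_p50_ms"] or 999))
# Resulting order: A (nvidia, 300 ms), C (nvidia, unknown latency), B (groq, 120 ms)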
|
||||
|
||||
|
||||
|
||||
@router.get("/{model_id}", summary="Get model details")
|
||||
async def get_model_details(
|
||||
model_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get detailed information about a specific model"""
|
||||
|
||||
try:
|
||||
tenant_domain = current_user.get("tenant_domain", "default")
|
||||
|
||||
# Call Resource Cluster for model details - use Docker service name if in container
|
||||
import os
|
||||
if os.path.exists('/.dockerenv'):
|
||||
resource_cluster_url = "http://resource-cluster:8000"
|
||||
else:
|
||||
resource_cluster_url = settings.resource_cluster_url
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{resource_cluster_url}/api/v1/models/{model_id}",
|
||||
headers={
|
||||
"X-Tenant-Domain": tenant_domain
|
||||
},
|
||||
timeout=15.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model {model_id} not found"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Resource Cluster unavailable"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model {model_id} details: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Failed to get model details"
|
||||
)
|
||||
2908
apps/tenant-backend/app/api/v1/observability.py
Normal file
File diff suppressed because it is too large
238
apps/tenant-backend/app/api/v1/optics.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Optics Cost Tracking API Endpoints
|
||||
|
||||
Provides cost visibility for inference and storage usage.
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.api.v1.observability import get_current_user, get_user_role
|
||||
|
||||
from app.services.optics_service import (
|
||||
fetch_optics_settings,
|
||||
get_optics_cost_summary,
|
||||
STORAGE_COST_PER_MB_CENTS
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/optics", tags=["Optics Cost Tracking"])
|
||||
|
||||
|
||||
# Response models
|
||||
class OpticsSettingsResponse(BaseModel):
|
||||
enabled: bool
|
||||
storage_cost_per_mb_cents: float
|
||||
show_to_admins_only: bool = True
|
||||
|
||||
|
||||
class ModelCostBreakdown(BaseModel):
|
||||
model_id: str
|
||||
model_name: str
|
||||
tokens: int
|
||||
conversations: int
|
||||
messages: int
|
||||
cost_cents: float
|
||||
cost_display: str
|
||||
percentage: float
|
||||
|
||||
|
||||
class UserCostBreakdown(BaseModel):
|
||||
user_id: str
|
||||
email: str
|
||||
tokens: int
|
||||
cost_cents: float
|
||||
cost_display: str
|
||||
percentage: float
|
||||
|
||||
|
||||
class OpticsCostResponse(BaseModel):
|
||||
enabled: bool
|
||||
inference_cost_cents: float
|
||||
storage_cost_cents: float
|
||||
total_cost_cents: float
|
||||
inference_cost_display: str
|
||||
storage_cost_display: str
|
||||
total_cost_display: str
|
||||
total_tokens: int
|
||||
total_storage_mb: float
|
||||
document_count: int
|
||||
dataset_count: int
|
||||
by_model: List[ModelCostBreakdown]
|
||||
by_user: Optional[List[UserCostBreakdown]] = None
|
||||
period_start: str
|
||||
period_end: str
|
||||
|
||||
|
||||
@router.get("/settings", response_model=OpticsSettingsResponse)
|
||||
async def get_optics_settings(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Check if Optics is enabled for the current tenant.
|
||||
|
||||
This endpoint is used by the frontend to determine whether
|
||||
to show the Optics tab in the observability dashboard.
|
||||
"""
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
|
||||
try:
|
||||
settings = await fetch_optics_settings(tenant_domain)
|
||||
|
||||
return OpticsSettingsResponse(
|
||||
enabled=settings.get("enabled", False),
|
||||
storage_cost_per_mb_cents=settings.get("storage_cost_per_mb_cents", STORAGE_COST_PER_MB_CENTS),
|
||||
show_to_admins_only=True # Only admins can see user breakdown
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching optics settings: {str(e)}")
|
||||
return OpticsSettingsResponse(
|
||||
enabled=False,
|
||||
storage_cost_per_mb_cents=STORAGE_COST_PER_MB_CENTS,
|
||||
show_to_admins_only=True
|
||||
)
|
||||
|
||||
|
||||
@router.get("/costs", response_model=OpticsCostResponse)
|
||||
async def get_optics_costs(
|
||||
days: Optional[int] = Query(30, ge=1, le=365, description="Number of days to look back"),
|
||||
start_date: Optional[str] = Query(None, description="Custom start date (ISO format)"),
|
||||
end_date: Optional[str] = Query(None, description="Custom end date (ISO format)"),
|
||||
user_id: Optional[str] = Query(None, description="Filter by user ID (admin only)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get Optics cost breakdown for the current tenant.
|
||||
|
||||
Returns inference costs calculated from token usage and model pricing,
|
||||
plus storage costs at the configured rate (default 4 cents/MB).
|
||||
|
||||
- **days**: Number of days to look back (default 30)
|
||||
- **start_date**: Custom start date (overrides days)
|
||||
- **end_date**: Custom end date
|
||||
- **user_id**: Filter by specific user (admin only)
|
||||
"""
|
||||
tenant_domain = current_user.get("tenant_domain", "test-company")
|
||||
|
||||
# Check if Optics is enabled
|
||||
settings = await fetch_optics_settings(tenant_domain)
|
||||
if not settings.get("enabled", False):
|
||||
return OpticsCostResponse(
|
||||
enabled=False,
|
||||
inference_cost_cents=0,
|
||||
storage_cost_cents=0,
|
||||
total_cost_cents=0,
|
||||
inference_cost_display="$0.00",
|
||||
storage_cost_display="$0.00",
|
||||
total_cost_display="$0.00",
|
||||
total_tokens=0,
|
||||
total_storage_mb=0,
|
||||
document_count=0,
|
||||
dataset_count=0,
|
||||
by_model=[],
|
||||
by_user=None,
|
||||
period_start=datetime.utcnow().isoformat(),
|
||||
period_end=datetime.utcnow().isoformat()
|
||||
)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get user role for permission checks
|
||||
user_email = current_user.get("email", "")
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Handle user filter - only admins can filter by user
|
||||
filter_user_id = None
|
||||
if user_id:
|
||||
if not is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admins can filter by user"
|
||||
)
|
||||
filter_user_id = user_id
|
||||
elif not is_admin:
|
||||
# Non-admins can only see their own data
|
||||
# Get user UUID from email
|
||||
user_query = f"""
|
||||
SELECT id FROM tenant_{tenant_domain.replace('-', '_')}.users
|
||||
WHERE email = $1 LIMIT 1
|
||||
"""
|
||||
user_result = await pg_client.execute_query(user_query, user_email)
|
||||
if user_result:
|
||||
filter_user_id = str(user_result[0]["id"])
|
||||
|
||||
# Calculate date range
|
||||
date_end = datetime.utcnow()
|
||||
date_start = date_end - timedelta(days=days)
|
||||
|
||||
if start_date:
|
||||
try:
|
||||
date_start = datetime.fromisoformat(start_date.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid start_date format. Use ISO format."
|
||||
)
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
date_end = datetime.fromisoformat(end_date.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid end_date format. Use ISO format."
|
||||
)
|
||||
|
||||
try:
|
||||
cost_summary = await get_optics_cost_summary(
|
||||
pg_client=pg_client,
|
||||
tenant_domain=tenant_domain,
|
||||
date_start=date_start,
|
||||
date_end=date_end,
|
||||
user_id=filter_user_id,
|
||||
include_user_breakdown=is_admin and not filter_user_id # Only include breakdown for platform view
|
||||
)
|
||||
|
||||
# Convert to response model
|
||||
by_model = [
|
||||
ModelCostBreakdown(**item)
|
||||
for item in cost_summary.get("by_model", [])
|
||||
]
|
||||
|
||||
by_user = None
|
||||
if cost_summary.get("by_user"):
|
||||
by_user = [
|
||||
UserCostBreakdown(**item)
|
||||
for item in cost_summary["by_user"]
|
||||
]
|
||||
|
||||
return OpticsCostResponse(
|
||||
enabled=True,
|
||||
inference_cost_cents=cost_summary["inference_cost_cents"],
|
||||
storage_cost_cents=cost_summary["storage_cost_cents"],
|
||||
total_cost_cents=cost_summary["total_cost_cents"],
|
||||
inference_cost_display=cost_summary["inference_cost_display"],
|
||||
storage_cost_display=cost_summary["storage_cost_display"],
|
||||
total_cost_display=cost_summary["total_cost_display"],
|
||||
total_tokens=cost_summary["total_tokens"],
|
||||
total_storage_mb=cost_summary["total_storage_mb"],
|
||||
document_count=cost_summary["document_count"],
|
||||
dataset_count=cost_summary["dataset_count"],
|
||||
by_model=by_model,
|
||||
by_user=by_user,
|
||||
period_start=cost_summary["period_start"],
|
||||
period_end=cost_summary["period_end"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating optics costs: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to calculate costs"
|
||||
)
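# Illustrative sketch (hypothetical helper, not part of the code above): one way the
# *_cost_cents integers could be rendered into the "$0.00"-style *_cost_display strings
# used by the optics cost response.
def format_cents_as_display(cents: int) -> str:
    """Render an integer cent amount as a dollar string, e.g. 1234 -> "$12.34"."""
    return f"${cents / 100:.2f}"

# Example: format_cents_as_display(0) == "$0.00"; format_cents_as_display(250) == "$2.50"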
|
||||
535
apps/tenant-backend/app/api/v1/rag_visualization.py
Normal file
@@ -0,0 +1,535 @@
|
||||
"""
|
||||
RAG Network Visualization API for GT 2.0
|
||||
Provides force-directed graph data and semantic relationships
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.database import get_db_session
|
||||
from app.core.security import get_current_user
|
||||
from app.services.rag_service import RAGService
|
||||
import random
|
||||
import math
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/rag/visualization", tags=["rag-visualization"])
|
||||
|
||||
class NetworkNode:
|
||||
"""Represents a node in the knowledge network"""
|
||||
def __init__(self, id: str, label: str, type: str, metadata: Dict[str, Any]):
|
||||
self.id = id
|
||||
self.label = label
|
||||
self.type = type # document, chunk, concept, query
|
||||
self.metadata = metadata
|
||||
self.x = random.uniform(-100, 100)
|
||||
self.y = random.uniform(-100, 100)
|
||||
self.size = self._calculate_size(type, metadata)
|
||||
self.color = self._get_color_by_type(type)
|
||||
self.importance = metadata.get('importance', 0.5)
|
||||
|
||||
def _calculate_size(self, type: str, metadata: Dict[str, Any]) -> int:
|
||||
"""Calculate node size based on type and metadata"""
|
||||
base_sizes = {
|
||||
"document": 20,
|
||||
"chunk": 10,
|
||||
"concept": 15,
|
||||
"query": 25
|
||||
}
|
||||
|
||||
base_size = base_sizes.get(type, 10)
|
||||
|
||||
# Adjust based on importance/usage
|
||||
if 'usage_count' in metadata:
|
||||
usage_multiplier = min(2.0, 1.0 + metadata['usage_count'] / 10.0)
|
||||
base_size = int(base_size * usage_multiplier)
|
||||
|
||||
if 'relevance_score' in metadata:
|
||||
relevance_multiplier = 0.5 + metadata['relevance_score'] * 0.5
|
||||
base_size = int(base_size * relevance_multiplier)
|
||||
|
||||
return max(5, min(50, base_size))
|
||||
|
||||
def _get_color_by_type(self, type: str) -> str:
|
||||
"""Get color based on node type"""
|
||||
colors = {
|
||||
"document": "#00d084", # GT Green
|
||||
"chunk": "#677489", # GT Gray
|
||||
"concept": "#4a5568", # Darker gray
|
||||
"query": "#ff6b6b", # Red for queries
|
||||
"dataset": "#4ade80", # Light green
|
||||
"user": "#3b82f6" # Blue
|
||||
}
|
||||
return colors.get(type, "#9aa5b1")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"label": self.label,
|
||||
"type": self.type,
|
||||
"x": float(self.x),
|
||||
"y": float(self.y),
|
||||
"size": self.size,
|
||||
"color": self.color,
|
||||
"importance": self.importance,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
class NetworkEdge:
|
||||
"""Represents an edge in the knowledge network"""
|
||||
def __init__(self, source: str, target: str, weight: float, edge_type: str = "semantic"):
|
||||
self.source = source
|
||||
self.target = target
|
||||
self.weight = max(0.0, min(1.0, weight)) # Clamp to 0-1
|
||||
self.type = edge_type
|
||||
self.color = self._get_edge_color(weight)
|
||||
self.width = self._get_edge_width(weight)
|
||||
self.animated = weight > 0.7
|
||||
|
||||
def _get_edge_color(self, weight: float) -> str:
|
||||
"""Get edge color based on weight"""
|
||||
if weight > 0.8:
|
||||
return "#00d084" # Strong connection - GT green
|
||||
elif weight > 0.6:
|
||||
return "#4ade80" # Medium-strong - light green
|
||||
elif weight > 0.4:
|
||||
return "#9aa5b1" # Medium - gray
|
||||
else:
|
||||
return "#d1d9e0" # Weak - light gray
|
||||
|
||||
def _get_edge_width(self, weight: float) -> float:
|
||||
"""Get edge width based on weight"""
|
||||
return 1.0 + (weight * 4.0) # 1-5px width
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return {
|
||||
"source": self.source,
|
||||
"target": self.target,
|
||||
"weight": self.weight,
|
||||
"type": self.type,
|
||||
"color": self.color,
|
||||
"width": self.width,
|
||||
"animated": self.animated
|
||||
}
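# Illustrative sketch (not used by the endpoints below): shows how the two classes above
# compose into the JSON payload shape this API returns. All ids, labels and metadata
# values here are example data only.
def _example_network_payload() -> Dict[str, Any]:
    doc = NetworkNode(
        id="doc_0", label="Quarterly Report", type="document",
        metadata={"usage_count": 12, "relevance_score": 0.8},
    )
    chunk = NetworkNode(
        id="chunk_0_0", label="Chunk 1", type="chunk",
        metadata={"document_id": "doc_0", "chunk_index": 0},
    )
    contains = NetworkEdge(source=doc.id, target=chunk.id, weight=1.0, edge_type="contains")
    return {"nodes": [doc.to_dict(), chunk.to_dict()], "edges": [contains.to_dict()]}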
|
||||
|
||||
@router.get("/network/{dataset_id}")
|
||||
async def get_knowledge_network(
|
||||
dataset_id: str,
|
||||
max_nodes: int = Query(default=100, le=500),
|
||||
min_similarity: float = Query(default=0.3, ge=0, le=1),
|
||||
include_concepts: bool = Query(default=True),
|
||||
layout_algorithm: str = Query(default="force", description="force, circular, hierarchical"),
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get force-directed graph data for a RAG dataset
|
||||
Returns nodes (documents/chunks/concepts) and edges (semantic relationships)
|
||||
"""
|
||||
try:
|
||||
# rag_service = RAGService(db)
|
||||
user_id = current_user["sub"]
|
||||
tenant_id = current_user.get("tenant_id", "default")
|
||||
|
||||
# TODO: Verify dataset ownership and access permissions
|
||||
# dataset = await rag_service._get_user_dataset(user_id, dataset_id)
|
||||
# if not dataset:
|
||||
# raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
# For now, generate mock data that represents the expected structure
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
# Generate mock document nodes
|
||||
doc_count = min(max_nodes // 3, 20)
|
||||
for i in range(doc_count):
|
||||
doc_node = NetworkNode(
|
||||
id=f"doc_{i}",
|
||||
label=f"Document {i+1}",
|
||||
type="document",
|
||||
metadata={
|
||||
"filename": f"document_{i+1}.pdf",
|
||||
"size_bytes": random.randint(1000, 100000),
|
||||
"chunk_count": random.randint(5, 50),
|
||||
"upload_date": datetime.utcnow().isoformat(),
|
||||
"usage_count": random.randint(0, 100),
|
||||
"relevance_score": random.uniform(0.3, 1.0)
|
||||
}
|
||||
)
|
||||
nodes.append(doc_node.to_dict())
|
||||
|
||||
# Generate chunks for this document
|
||||
chunk_count = min(5, (max_nodes - len(nodes)) // (doc_count - i))
|
||||
for j in range(chunk_count):
|
||||
chunk_node = NetworkNode(
|
||||
id=f"chunk_{i}_{j}",
|
||||
label=f"Chunk {j+1}",
|
||||
type="chunk",
|
||||
metadata={
|
||||
"document_id": f"doc_{i}",
|
||||
"chunk_index": j,
|
||||
"token_count": random.randint(50, 500),
|
||||
"semantic_density": random.uniform(0.2, 0.9)
|
||||
}
|
||||
)
|
||||
nodes.append(chunk_node.to_dict())
|
||||
|
||||
# Connect chunk to document
|
||||
edge = NetworkEdge(
|
||||
source=f"doc_{i}",
|
||||
target=f"chunk_{i}_{j}",
|
||||
weight=1.0,
|
||||
edge_type="contains"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
|
||||
# Generate concept nodes if requested
|
||||
if include_concepts:
|
||||
concept_count = min(10, max_nodes - len(nodes))
|
||||
concepts = ["AI", "Machine Learning", "Neural Networks", "Data Science",
|
||||
"Python", "JavaScript", "API", "Database", "Security", "Cloud"]
|
||||
|
||||
for i in range(concept_count):
|
||||
if i < len(concepts):
|
||||
concept_node = NetworkNode(
|
||||
id=f"concept_{i}",
|
||||
label=concepts[i],
|
||||
type="concept",
|
||||
metadata={
|
||||
"frequency": random.randint(1, 50),
|
||||
"co_occurrence_score": random.uniform(0.1, 0.8),
|
||||
"domain": "technology"
|
||||
}
|
||||
)
|
||||
nodes.append(concept_node.to_dict())
|
||||
|
||||
# Generate semantic relationships between chunks
|
||||
chunk_nodes = [n for n in nodes if n["type"] == "chunk"]
|
||||
relationship_count = 0
|
||||
max_relationships = min(len(chunk_nodes) * 2, 100)
|
||||
|
||||
for i, node1 in enumerate(chunk_nodes):
|
||||
if relationship_count >= max_relationships:
|
||||
break
|
||||
|
||||
# Connect to a few other chunks based on semantic similarity
|
||||
connection_count = random.randint(1, 4)
|
||||
|
||||
for _ in range(connection_count):
|
||||
if relationship_count >= max_relationships:
|
||||
break
|
||||
|
||||
# Select random target (avoid self-connection)
|
||||
target_idx = random.randint(0, len(chunk_nodes) - 1)
|
||||
if target_idx == i:
|
||||
continue
|
||||
|
||||
node2 = chunk_nodes[target_idx]
|
||||
|
||||
                # Generate a mock similarity score for this pair of chunks
|
||||
similarity = random.uniform(0.2, 0.9)
|
||||
|
||||
# Apply minimum similarity filter
|
||||
if similarity >= min_similarity:
|
||||
edge = NetworkEdge(
|
||||
source=node1["id"],
|
||||
target=node2["id"],
|
||||
weight=similarity,
|
||||
edge_type="semantic_similarity"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
relationship_count += 1
|
||||
|
||||
# Connect concepts to relevant chunks
|
||||
concept_nodes = [n for n in nodes if n["type"] == "concept"]
|
||||
for concept in concept_nodes:
|
||||
            # Connect each concept to 2-6 relevant chunks
|
||||
connection_count = random.randint(2, 6)
|
||||
connected_chunks = random.sample(
|
||||
range(len(chunk_nodes)),
|
||||
k=min(connection_count, len(chunk_nodes))
|
||||
)
|
||||
|
||||
for chunk_idx in connected_chunks:
|
||||
chunk = chunk_nodes[chunk_idx]
|
||||
relevance = random.uniform(0.4, 0.9)
|
||||
|
||||
edge = NetworkEdge(
|
||||
source=concept["id"],
|
||||
target=chunk["id"],
|
||||
weight=relevance,
|
||||
edge_type="concept_relevance"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
|
||||
# Apply layout algorithm positioning
|
||||
if layout_algorithm == "circular":
|
||||
_apply_circular_layout(nodes)
|
||||
elif layout_algorithm == "hierarchical":
|
||||
_apply_hierarchical_layout(nodes)
|
||||
# force layout uses random positioning (already applied)
|
||||
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"metadata": {
|
||||
"dataset_name": f"Dataset {dataset_id}",
|
||||
"total_nodes": len(nodes),
|
||||
"total_edges": len(edges),
|
||||
"node_types": {
|
||||
node_type: len([n for n in nodes if n["type"] == node_type])
|
||||
for node_type in set(n["type"] for n in nodes)
|
||||
},
|
||||
"edge_types": {
|
||||
edge_type: len([e for e in edges if e["type"] == edge_type])
|
||||
for edge_type in set(e["type"] for e in edges)
|
||||
},
|
||||
"min_similarity": min_similarity,
|
||||
"layout_algorithm": layout_algorithm,
|
||||
"generated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating knowledge network: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
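# Illustrative request sketch for the endpoint above (query parameter values are examples;
# the path comes from the router prefix and route definition):
#
#   GET /api/v1/rag/visualization/network/<dataset_id>?max_nodes=150&min_similarity=0.4&layout_algorithm=circular
#
# The JSON response contains "dataset_id", "nodes", "edges" and "metadata" as assembled above.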
|
||||
|
||||
@router.get("/search/visual")
|
||||
async def visual_search(
|
||||
query: str,
|
||||
dataset_ids: Optional[List[str]] = Query(default=None),
|
||||
top_k: int = Query(default=10, le=50),
|
||||
include_network: bool = Query(default=True),
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform semantic search with visualization metadata
|
||||
Returns search results with connection paths and confidence scores
|
||||
"""
|
||||
try:
|
||||
# rag_service = RAGService(db)
|
||||
user_id = current_user["sub"]
|
||||
tenant_id = current_user.get("tenant_id", "default")
|
||||
|
||||
# TODO: Perform actual search using RAG service
|
||||
# results = await rag_service.search_documents(
|
||||
# user_id=user_id,
|
||||
# tenant_id=tenant_id,
|
||||
# query=query,
|
||||
# dataset_ids=dataset_ids,
|
||||
# top_k=top_k
|
||||
# )
|
||||
|
||||
# Generate mock search results for now
|
||||
results = []
|
||||
for i in range(top_k):
|
||||
similarity = random.uniform(0.5, 0.95) # Mock similarity scores
|
||||
results.append({
|
||||
"id": f"result_{i}",
|
||||
"document_id": f"doc_{random.randint(0, 9)}",
|
||||
"chunk_id": f"chunk_{i}",
|
||||
"content": f"Mock search result {i+1} for query: {query}",
|
||||
"metadata": {
|
||||
"filename": f"document_{i}.pdf",
|
||||
"page_number": random.randint(1, 100),
|
||||
"chunk_index": i
|
||||
},
|
||||
"similarity": similarity,
|
||||
"dataset_id": dataset_ids[0] if dataset_ids else "default"
|
||||
})
|
||||
|
||||
# Enhance results with visualization data
|
||||
visual_results = []
|
||||
for i, result in enumerate(results):
|
||||
visual_result = {
|
||||
**result,
|
||||
"visual_metadata": {
|
||||
"position": i + 1,
|
||||
"relevance_score": result["similarity"],
|
||||
"confidence_level": _calculate_confidence(result["similarity"]),
|
||||
"connection_strength": result["similarity"],
|
||||
"highlight_color": _get_relevance_color(result["similarity"]),
|
||||
"path_to_query": _generate_path(query, result),
|
||||
"semantic_distance": 1.0 - result["similarity"],
|
||||
"cluster_id": f"cluster_{random.randint(0, 2)}"
|
||||
}
|
||||
}
|
||||
visual_results.append(visual_result)
|
||||
|
||||
response = {
|
||||
"query": query,
|
||||
"results": visual_results,
|
||||
"total_results": len(visual_results),
|
||||
"search_metadata": {
|
||||
"execution_time_ms": random.randint(50, 200),
|
||||
"datasets_searched": dataset_ids or ["all"],
|
||||
"semantic_method": "embedding_similarity",
|
||||
"reranking_applied": True
|
||||
}
|
||||
}
|
||||
|
||||
# Add network visualization if requested
|
||||
if include_network:
|
||||
network_data = _generate_search_network(query, visual_results)
|
||||
response["network_visualization"] = network_data
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in visual search: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/datasets/{dataset_id}/stats")
|
||||
async def get_dataset_stats(
|
||||
dataset_id: str,
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get statistical information about a dataset for visualization"""
|
||||
try:
|
||||
# TODO: Implement actual dataset statistics
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"document_count": random.randint(10, 100),
|
||||
"chunk_count": random.randint(100, 1000),
|
||||
"total_tokens": random.randint(10000, 100000),
|
||||
"concept_count": random.randint(20, 200),
|
||||
"average_document_similarity": random.uniform(0.3, 0.8),
|
||||
"semantic_clusters": random.randint(3, 10),
|
||||
"most_connected_documents": [
|
||||
{"id": f"doc_{i}", "connections": random.randint(5, 50)}
|
||||
for i in range(5)
|
||||
],
|
||||
"topic_distribution": {
|
||||
f"topic_{i}": random.uniform(0.05, 0.3)
|
||||
for i in range(5)
|
||||
},
|
||||
"last_updated": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting dataset stats: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Helper functions
|
||||
def _calculate_confidence(similarity: float) -> str:
|
||||
"""Calculate confidence level from similarity score"""
|
||||
if similarity > 0.9:
|
||||
return "very_high"
|
||||
elif similarity > 0.75:
|
||||
return "high"
|
||||
elif similarity > 0.6:
|
||||
return "medium"
|
||||
elif similarity > 0.4:
|
||||
return "low"
|
||||
else:
|
||||
return "very_low"
|
||||
|
||||
def _get_relevance_color(similarity: float) -> str:
|
||||
"""Get color based on relevance score"""
|
||||
if similarity > 0.8:
|
||||
return "#00d084" # GT Green
|
||||
elif similarity > 0.6:
|
||||
return "#4ade80" # Light green
|
||||
elif similarity > 0.4:
|
||||
return "#fbbf24" # Yellow
|
||||
else:
|
||||
return "#ef4444" # Red
|
||||
|
||||
def _generate_path(query: str, result: Dict[str, Any]) -> List[str]:
|
||||
"""Generate conceptual path from query to result"""
|
||||
return [
|
||||
"query",
|
||||
f"dataset_{result.get('dataset_id', 'unknown')}",
|
||||
f"document_{result.get('document_id', 'unknown')}",
|
||||
f"chunk_{result.get('chunk_id', 'unknown')}"
|
||||
]
|
||||
|
||||
def _generate_search_network(query: str, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Generate network visualization for search results"""
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
# Add query node
|
||||
query_node = NetworkNode(
|
||||
id="query",
|
||||
        label=(query[:50] + "...") if len(query) > 50 else query,
|
||||
type="query",
|
||||
metadata={"original_query": query}
|
||||
)
|
||||
nodes.append(query_node.to_dict())
|
||||
|
||||
# Add result nodes and connections
|
||||
for i, result in enumerate(results):
|
||||
result_node = NetworkNode(
|
||||
id=f"result_{i}",
|
||||
label=result.get("content", "")[:30] + "...",
|
||||
type="chunk",
|
||||
metadata={
|
||||
"similarity": result["similarity"],
|
||||
"position": i + 1,
|
||||
"document_id": result.get("document_id")
|
||||
}
|
||||
)
|
||||
nodes.append(result_node.to_dict())
|
||||
|
||||
# Connect query to result
|
||||
edge = NetworkEdge(
|
||||
source="query",
|
||||
target=f"result_{i}",
|
||||
weight=result["similarity"],
|
||||
edge_type="search_result"
|
||||
)
|
||||
edges.append(edge.to_dict())
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"center_node_id": "query",
|
||||
"layout": "radial"
|
||||
}
|
||||
|
||||
def _apply_circular_layout(nodes: List[Dict[str, Any]]) -> None:
|
||||
"""Apply circular layout to nodes"""
|
||||
if not nodes:
|
||||
return
|
||||
|
||||
radius = 150
|
||||
angle_step = 2 * math.pi / len(nodes)
|
||||
|
||||
for i, node in enumerate(nodes):
|
||||
angle = i * angle_step
|
||||
node["x"] = radius * math.cos(angle)
|
||||
node["y"] = radius * math.sin(angle)
|
||||
|
||||
def _apply_hierarchical_layout(nodes: List[Dict[str, Any]]) -> None:
|
||||
"""Apply hierarchical layout to nodes"""
|
||||
if not nodes:
|
||||
return
|
||||
|
||||
# Group by type
|
||||
node_types = {}
|
||||
for node in nodes:
|
||||
node_type = node["type"]
|
||||
if node_type not in node_types:
|
||||
node_types[node_type] = []
|
||||
node_types[node_type].append(node)
|
||||
|
||||
# Position by type levels
|
||||
y_positions = {"document": -100, "concept": 0, "chunk": 100, "query": -150}
|
||||
|
||||
for node_type, type_nodes in node_types.items():
|
||||
y_pos = y_positions.get(node_type, 0)
|
||||
x_step = 300 / max(1, len(type_nodes) - 1) if len(type_nodes) > 1 else 0
|
||||
|
||||
for i, node in enumerate(type_nodes):
|
||||
node["y"] = y_pos
|
||||
node["x"] = -150 + (i * x_step) if len(type_nodes) > 1 else 0
|
||||
697
apps/tenant-backend/app/api/v1/search.py
Normal file
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
Search API for GT 2.0 Tenant Backend
|
||||
|
||||
Provides hybrid vector + text search capabilities using PGVector.
|
||||
Supports semantic similarity search, full-text search, and hybrid ranking.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Header
|
||||
from pydantic import BaseModel, Field
|
||||
import logging
|
||||
import httpx
|
||||
import time
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.pgvector_search_service import (
|
||||
PGVectorSearchService,
|
||||
HybridSearchResult,
|
||||
SearchConfig,
|
||||
get_pgvector_search_service
|
||||
)
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/search", tags=["search"])
|
||||
|
||||
|
||||
async def get_user_context(
|
||||
x_tenant_domain: Optional[str] = Header(None),
|
||||
x_user_id: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get user context from headers (for internal services).
|
||||
Handles email, UUID, and numeric ID formats for user identification.
|
||||
"""
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: x_tenant_domain='{x_tenant_domain}', x_user_id='{x_user_id}'")
|
||||
|
||||
# Validate required headers
|
||||
if not x_tenant_domain or not x_user_id:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing required headers: X-Tenant-Domain and X-User-ID"
|
||||
)
|
||||
|
||||
# Validate and clean inputs
|
||||
x_tenant_domain = x_tenant_domain.strip() if x_tenant_domain else None
|
||||
x_user_id = x_user_id.strip() if x_user_id else None
|
||||
|
||||
if not x_tenant_domain or not x_user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid empty headers: X-Tenant-Domain and X-User-ID cannot be empty"
|
||||
)
|
||||
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: Processing user_id='{x_user_id}' for tenant='{x_tenant_domain}'")
|
||||
|
||||
# Use ensure_user_uuid to handle all user ID formats (email, UUID, numeric)
|
||||
from app.core.user_resolver import ensure_user_uuid
|
||||
try:
|
||||
resolved_uuid = await ensure_user_uuid(x_user_id, x_tenant_domain)
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: Resolved user_id '{x_user_id}' to UUID '{resolved_uuid}'")
|
||||
|
||||
        # Keep the email if the caller passed one; otherwise resolve it from the database below
|
||||
user_email = x_user_id if "@" in x_user_id else None
|
||||
|
||||
# If we don't have an email, try to get it from the database
|
||||
if not user_email:
|
||||
try:
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
tenant_schema = f"tenant_{x_tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
user_row = await conn.fetchrow(
|
||||
f"SELECT email FROM {tenant_schema}.users WHERE id = $1",
|
||||
resolved_uuid
|
||||
)
|
||||
if user_row:
|
||||
user_email = user_row['email']
|
||||
else:
|
||||
# Fallback to UUID as email for backward compatibility
|
||||
user_email = resolved_uuid
|
||||
logger.warning(f"Could not find email for UUID {resolved_uuid}, using UUID as email")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to lookup email for UUID {resolved_uuid}: {e}")
|
||||
user_email = resolved_uuid
|
||||
|
||||
context = {
|
||||
"tenant_domain": x_tenant_domain,
|
||||
"id": resolved_uuid,
|
||||
"sub": resolved_uuid,
|
||||
"email": user_email,
|
||||
"user_type": "internal_service"
|
||||
}
|
||||
logger.info(f"🔍 GET_USER_CONTEXT: Returning context: {context}")
|
||||
return context
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"🔍 GET_USER_CONTEXT ERROR: Failed to resolve user_id '{x_user_id}': {e}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid user identifier '{x_user_id}': {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🔍 GET_USER_CONTEXT ERROR: Unexpected error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal error processing user context"
|
||||
)
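# Illustrative sketch of how an internal service might call this API with the headers that
# get_user_context expects. The host, tenant domain and user id are example values only.
#
#   async with httpx.AsyncClient(timeout=10.0) as client:
#       resp = await client.post(
#           "http://tenant-backend:8000/api/v1/search/",
#           headers={"X-Tenant-Domain": "test-company", "X-User-ID": "user@example.com"},
#           json={"query": "quarterly revenue", "search_type": "hybrid", "max_results": 5},
#       )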
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""Request for hybrid search"""
|
||||
query: str = Field(..., min_length=1, max_length=1000, description="Search query")
|
||||
dataset_ids: Optional[List[str]] = Field(None, description="Optional dataset IDs to search within")
|
||||
search_type: str = Field("hybrid", description="Search type: hybrid, vector, text")
|
||||
max_results: int = Field(10, ge=1, le=200, description="Maximum results to return")
|
||||
|
||||
# Advanced search parameters
|
||||
vector_weight: Optional[float] = Field(0.7, ge=0.0, le=1.0, description="Weight for vector similarity")
|
||||
text_weight: Optional[float] = Field(0.3, ge=0.0, le=1.0, description="Weight for text relevance")
|
||||
min_similarity: Optional[float] = Field(0.3, ge=0.0, le=1.0, description="Minimum similarity threshold")
|
||||
rerank_results: Optional[bool] = Field(True, description="Apply result re-ranking")
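# Example request body for SearchRequest (illustrative values). Note that for
# search_type == "hybrid" the endpoint below requires vector_weight + text_weight == 1.0.
#
#   {
#     "query": "quarterly revenue",
#     "dataset_ids": ["a1b2c3"],
#     "search_type": "hybrid",
#     "max_results": 10,
#     "vector_weight": 0.7,
#     "text_weight": 0.3,
#     "min_similarity": 0.3,
#     "rerank_results": true
#   }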
|
||||
|
||||
|
||||
class SimilarChunksRequest(BaseModel):
|
||||
"""Request for finding similar chunks"""
|
||||
chunk_id: str = Field(..., description="Reference chunk ID")
|
||||
similarity_threshold: float = Field(0.5, ge=0.0, le=1.0, description="Minimum similarity")
|
||||
max_results: int = Field(5, ge=1, le=20, description="Maximum results")
|
||||
exclude_same_document: bool = Field(True, description="Exclude chunks from same document")
|
||||
|
||||
|
||||
class SearchResultResponse(BaseModel):
|
||||
"""Individual search result"""
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
dataset_id: Optional[str]
|
||||
text: str
|
||||
metadata: Dict[str, Any]
|
||||
vector_similarity: float
|
||||
text_relevance: float
|
||||
hybrid_score: float
|
||||
rank: int
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
"""Search response with results and metadata"""
|
||||
query: str
|
||||
search_type: str
|
||||
total_results: int
|
||||
results: List[SearchResultResponse]
|
||||
search_time_ms: float
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class ConversationSearchRequest(BaseModel):
|
||||
"""Request for searching conversation history"""
|
||||
query: str = Field(..., min_length=1, max_length=500, description="Search query")
|
||||
days_back: Optional[int] = Field(30, ge=1, le=365, description="Number of days back to search")
|
||||
max_results: Optional[int] = Field(5, ge=1, le=200, description="Maximum results to return")
|
||||
agent_filter: Optional[List[str]] = Field(None, description="Filter by agent names/IDs")
|
||||
include_user_messages: Optional[bool] = Field(True, description="Include user messages in results")
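# Example request body for ConversationSearchRequest (illustrative values):
#
#   {
#     "query": "deployment checklist",
#     "days_back": 14,
#     "max_results": 5,
#     "agent_filter": ["nemotron-agent"],
#     "include_user_messages": true
#   }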
|
||||
|
||||
|
||||
class DocumentChunksResponse(BaseModel):
|
||||
"""Response for document chunks"""
|
||||
document_id: str
|
||||
total_chunks: int
|
||||
chunks: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# Search Endpoints
|
||||
|
||||
@router.post("/", response_model=SearchResponse)
|
||||
async def search_documents(
|
||||
request: SearchRequest,
|
||||
user_context: Dict[str, Any] = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Perform hybrid search across user's documents and datasets.
|
||||
|
||||
Combines vector similarity search with full-text search for optimal results.
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(f"🔍 SEARCH_API START: request={request.dict()}")
|
||||
logger.info(f"🔍 SEARCH_API: user_context={user_context}")
|
||||
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
logger.info(f"🔍 SEARCH_API: extracted tenant_domain='{tenant_domain}', user_id='{user_id}'")
|
||||
|
||||
# Validate weights sum to 1.0 for hybrid search
|
||||
if request.search_type == "hybrid":
|
||||
total_weight = (request.vector_weight or 0.7) + (request.text_weight or 0.3)
|
||||
if abs(total_weight - 1.0) > 0.01:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Vector weight and text weight must sum to 1.0"
|
||||
)
|
||||
|
||||
# Initialize search service
|
||||
logger.info(f"🔍 SEARCH_API: Initializing search service with tenant_id='{tenant_domain}', user_id='{user_id}'")
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Configure search parameters
|
||||
config = SearchConfig(
|
||||
vector_weight=request.vector_weight or 0.7,
|
||||
text_weight=request.text_weight or 0.3,
|
||||
min_vector_similarity=request.min_similarity or 0.3,
|
||||
            min_text_relevance=0.01,  # ts_rank_cd scores are small, so use a correspondingly low threshold
|
||||
max_results=request.max_results,
|
||||
            rerank_results=request.rerank_results if request.rerank_results is not None else True
|
||||
)
|
||||
logger.info(f"🔍 SEARCH_API: Search config created: {config.__dict__}")
|
||||
|
||||
# Execute search based on type
|
||||
results = []
|
||||
logger.info(f"🔍 SEARCH_API: Executing {request.search_type} search")
|
||||
if request.search_type == "hybrid":
|
||||
logger.info(f"🔍 SEARCH_API: Calling hybrid_search with query='{request.query}', user_id='{user_id}', dataset_ids={request.dataset_ids}")
|
||||
results = await search_service.hybrid_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
config=config,
|
||||
limit=request.max_results
|
||||
)
|
||||
logger.info(f"🔍 SEARCH_API: Hybrid search returned {len(results)} results")
|
||||
elif request.search_type == "vector":
|
||||
# Generate query embedding first
|
||||
query_embedding = await search_service._generate_query_embedding(
|
||||
request.query,
|
||||
user_id
|
||||
)
|
||||
results = await search_service.vector_similarity_search(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
similarity_threshold=config.min_vector_similarity,
|
||||
limit=request.max_results
|
||||
)
|
||||
elif request.search_type == "text":
|
||||
results = await search_service.full_text_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=request.dataset_ids,
|
||||
limit=request.max_results
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid search_type. Must be 'hybrid', 'vector', or 'text'"
|
||||
)
|
||||
|
||||
# Calculate search time
|
||||
search_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert results to response format
|
||||
result_responses = [
|
||||
SearchResultResponse(
|
||||
chunk_id=str(result.chunk_id),
|
||||
document_id=str(result.document_id),
|
||||
dataset_id=str(result.dataset_id) if result.dataset_id else None,
|
||||
text=result.text,
|
||||
metadata=result.metadata if isinstance(result.metadata, dict) else {},
|
||||
vector_similarity=result.vector_similarity,
|
||||
text_relevance=result.text_relevance,
|
||||
hybrid_score=result.hybrid_score,
|
||||
rank=result.rank
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Search completed: query='{request.query}', "
|
||||
f"type={request.search_type}, results={len(results)}, "
|
||||
f"time={search_time_ms:.1f}ms"
|
||||
)
|
||||
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
search_type=request.search_type,
|
||||
total_results=len(results),
|
||||
results=result_responses,
|
||||
search_time_ms=search_time_ms,
|
||||
config={
|
||||
"vector_weight": config.vector_weight,
|
||||
"text_weight": config.text_weight,
|
||||
"min_similarity": config.min_vector_similarity,
|
||||
"rerank_enabled": config.rerank_results
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/chunks", response_model=DocumentChunksResponse)
|
||||
async def get_document_chunks(
|
||||
document_id: str,
|
||||
include_embeddings: bool = Query(False, description="Include embedding vectors"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all chunks for a specific document.
|
||||
|
||||
Useful for understanding document structure and chunk boundaries.
|
||||
"""
|
||||
try:
|
||||
# Extract tenant info
|
||||
        tenant_domain = current_user.get("tenant_domain", "test")
|
||||
        user_id = current_user.get("id", current_user.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
chunks = await search_service.get_document_chunks(
|
||||
document_id=document_id,
|
||||
user_id=user_id,
|
||||
include_embeddings=include_embeddings
|
||||
)
|
||||
|
||||
return DocumentChunksResponse(
|
||||
document_id=document_id,
|
||||
total_chunks=len(chunks),
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get document chunks: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/similar-chunks", response_model=SearchResponse)
|
||||
async def find_similar_chunks(
|
||||
request: SimilarChunksRequest,
|
||||
user_context: Dict[str, Any] = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Find chunks similar to a reference chunk.
|
||||
|
||||
Useful for exploring related content and building context.
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
results = await search_service.search_similar_chunks(
|
||||
chunk_id=request.chunk_id,
|
||||
user_id=user_id,
|
||||
similarity_threshold=request.similarity_threshold,
|
||||
limit=request.max_results,
|
||||
exclude_same_document=request.exclude_same_document
|
||||
)
|
||||
|
||||
search_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert results to response format
|
||||
result_responses = [
|
||||
SearchResultResponse(
|
||||
chunk_id=str(result.chunk_id),
|
||||
document_id=str(result.document_id),
|
||||
dataset_id=str(result.dataset_id) if result.dataset_id else None,
|
||||
text=result.text,
|
||||
metadata=result.metadata if isinstance(result.metadata, dict) else {},
|
||||
vector_similarity=result.vector_similarity,
|
||||
text_relevance=result.text_relevance,
|
||||
hybrid_score=result.hybrid_score,
|
||||
rank=result.rank
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
logger.info(f"Similar chunks search: found {len(results)} results in {search_time_ms:.1f}ms")
|
||||
|
||||
return SearchResponse(
|
||||
query=f"Similar to chunk {request.chunk_id}",
|
||||
search_type="vector_similarity",
|
||||
total_results=len(results),
|
||||
results=result_responses,
|
||||
search_time_ms=search_time_ms,
|
||||
config={
|
||||
"similarity_threshold": request.similarity_threshold,
|
||||
"exclude_same_document": request.exclude_same_document
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Similar chunks search failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
class DocumentSearchRequest(BaseModel):
|
||||
"""Request for searching specific documents"""
|
||||
query: str = Field(..., min_length=1, max_length=1000, description="Search query")
|
||||
document_ids: List[str] = Field(..., description="Document IDs to search within")
|
||||
search_type: str = Field("hybrid", description="Search type: hybrid, vector, text")
|
||||
max_results: int = Field(5, ge=1, le=20, description="Maximum results per document")
|
||||
min_similarity: Optional[float] = Field(0.6, ge=0.0, le=1.0, description="Minimum similarity threshold")
|
||||
|
||||
|
||||
@router.post("/documents", response_model=SearchResponse)
|
||||
async def search_documents_specific(
|
||||
request: DocumentSearchRequest,
|
||||
user_context: Dict[str, Any] = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Search within specific documents for relevant chunks.
|
||||
|
||||
Used by MCP RAG server for document-specific queries.
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Extract tenant info
|
||||
tenant_domain = user_context.get("tenant_domain", "test")
|
||||
user_id = user_context.get("id", user_context.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
# Initialize search service
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Configure search for documents
|
||||
config = SearchConfig(
|
||||
vector_weight=0.7,
|
||||
text_weight=0.3,
|
||||
min_vector_similarity=request.min_similarity or 0.6,
|
||||
min_text_relevance=0.1,
|
||||
max_results=request.max_results,
|
||||
rerank_results=True
|
||||
)
|
||||
|
||||
# Execute search with document filter
|
||||
# First resolve dataset IDs from document IDs to satisfy security constraints
|
||||
dataset_ids = await search_service.get_dataset_ids_from_documents(request.document_ids, user_id)
|
||||
if not dataset_ids:
|
||||
logger.warning(f"No dataset IDs found for documents: {request.document_ids}")
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
search_type=request.search_type,
|
||||
total_results=0,
|
||||
results=[],
|
||||
search_time_ms=0.0,
|
||||
config={}
|
||||
)
|
||||
|
||||
results = []
|
||||
if request.search_type == "hybrid":
|
||||
results = await search_service.hybrid_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=dataset_ids, # Use resolved dataset IDs
|
||||
config=config,
|
||||
limit=request.max_results * len(request.document_ids) # Allow more results for filtering
|
||||
)
|
||||
# Filter results to only include specified documents
|
||||
results = [r for r in results if r.document_id in request.document_ids][:request.max_results]
|
||||
|
||||
elif request.search_type == "vector":
|
||||
query_embedding = await search_service._generate_query_embedding(
|
||||
request.query,
|
||||
user_id
|
||||
)
|
||||
results = await search_service.vector_similarity_search(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
dataset_ids=dataset_ids, # Use resolved dataset IDs
|
||||
similarity_threshold=config.min_vector_similarity,
|
||||
limit=request.max_results * len(request.document_ids)
|
||||
)
|
||||
# Filter by document IDs
|
||||
results = [r for r in results if r.document_id in request.document_ids][:request.max_results]
|
||||
|
||||
elif request.search_type == "text":
|
||||
results = await search_service.full_text_search(
|
||||
query=request.query,
|
||||
user_id=user_id,
|
||||
dataset_ids=dataset_ids, # Use resolved dataset IDs
|
||||
limit=request.max_results * len(request.document_ids)
|
||||
)
|
||||
# Filter by document IDs
|
||||
results = [r for r in results if r.document_id in request.document_ids][:request.max_results]
|
||||
|
||||
search_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert results to response format
|
||||
result_responses = [
|
||||
SearchResultResponse(
|
||||
chunk_id=str(result.chunk_id),
|
||||
document_id=str(result.document_id),
|
||||
dataset_id=str(result.dataset_id) if result.dataset_id else None,
|
||||
text=result.text,
|
||||
metadata=result.metadata if isinstance(result.metadata, dict) else {},
|
||||
vector_similarity=result.vector_similarity,
|
||||
text_relevance=result.text_relevance,
|
||||
hybrid_score=result.hybrid_score,
|
||||
rank=result.rank
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Document search completed: query='{request.query}', "
|
||||
f"documents={len(request.document_ids)}, results={len(results)}, "
|
||||
f"time={search_time_ms:.1f}ms"
|
||||
)
|
||||
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
search_type=request.search_type,
|
||||
total_results=len(results),
|
||||
results=result_responses,
|
||||
search_time_ms=search_time_ms,
|
||||
config={
|
||||
"document_ids": request.document_ids,
|
||||
"min_similarity": request.min_similarity,
|
||||
"max_results_per_document": request.max_results
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Document search failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def search_health_check(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Health check for search service functionality.
|
||||
|
||||
Verifies that PGVector extension and search capabilities are working.
|
||||
"""
|
||||
try:
|
||||
# Extract tenant info
|
||||
        tenant_domain = current_user.get("tenant_domain", "test")
|
||||
        user_id = current_user.get("id", current_user.get("sub", None))
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing user ID in authentication context"
|
||||
)
|
||||
|
||||
search_service = get_pgvector_search_service(
|
||||
tenant_id=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Test basic connectivity and PGVector functionality
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
# Test PGVector extension
|
||||
result = await conn.fetchval("SELECT 1 + 1")
|
||||
if result != 2:
|
||||
raise Exception("Basic database connectivity failed")
|
||||
|
||||
# Test PGVector extension (this will fail if extension is not installed)
|
||||
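            # '<->' below is pgvector's L2 (Euclidean) distance operator; the query completing
            # without error confirms the vector extension is installed and usable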
await conn.fetchval("SELECT '[1,2,3]'::vector <-> '[1,2,4]'::vector")
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"tenant_id": tenant_domain,
|
||||
"pgvector_available": True,
|
||||
"search_service": "operational"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search health check failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": "Search service health check failed",
|
||||
"tenant_id": tenant_domain,
|
||||
"pgvector_available": False,
|
||||
"search_service": "error"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/conversations")
|
||||
async def search_conversations(
|
||||
request: ConversationSearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_user_context),
|
||||
x_tenant_domain: Optional[str] = Header(None, alias="X-Tenant-Domain"),
|
||||
x_user_id: Optional[str] = Header(None, alias="X-User-ID")
|
||||
):
|
||||
"""
|
||||
Search through conversation history using MCP conversation server.
|
||||
|
||||
Used by both external clients and internal MCP tools for conversation search.
|
||||
"""
|
||||
try:
|
||||
# Use same user resolution pattern as document search
|
||||
tenant_domain = current_user.get("tenant_domain") or x_tenant_domain
|
||||
user_id = current_user.get("id") or current_user.get("sub") or x_user_id
|
||||
|
||||
if not tenant_domain or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing tenant_domain or user_id in request context"
|
||||
)
|
||||
|
||||
logger.info(f"🔍 Conversation search: query='{request.query}', user={user_id}, tenant={tenant_domain}")
|
||||
|
||||
# Get resource cluster URL
|
||||
settings = get_settings()
|
||||
mcp_base_url = getattr(settings, 'resource_cluster_url', 'http://gentwo-resource-backend:8000')
|
||||
|
||||
# Build request payload for MCP execution
|
||||
request_payload = {
|
||||
"server_id": "conversation_server",
|
||||
"tool_name": "search_conversations",
|
||||
"parameters": {
|
||||
"query": request.query,
|
||||
"days_back": request.days_back or 30,
|
||||
"max_results": request.max_results or 5,
|
||||
"agent_filter": request.agent_filter,
|
||||
"include_user_messages": request.include_user_messages
|
||||
},
|
||||
"tenant_domain": tenant_domain,
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
logger.info(f"🌐 Making MCP request to: {mcp_base_url}/api/v1/mcp/execute")
|
||||
|
||||
response = await client.post(
|
||||
f"{mcp_base_url}/api/v1/mcp/execute",
|
||||
json=request_payload
|
||||
)
|
||||
|
||||
execution_time_ms = (time.time() - start_time) * 1000
|
||||
logger.info(f"📊 MCP response: {response.status_code} ({execution_time_ms:.1f}ms)")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info(f"✅ Conversation search successful ({execution_time_ms:.1f}ms)")
|
||||
return result
|
||||
else:
|
||||
error_text = response.text
|
||||
error_msg = f"MCP conversation search failed: {response.status_code} - {error_text}"
|
||||
logger.error(f"❌ {error_msg}")
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Conversation search endpoint error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
758
apps/tenant-backend/app/api/v1/teams.py
Normal file
@@ -0,0 +1,758 @@
|
||||
"""
|
||||
Teams API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Provides team collaboration management with two-tier permissions:
|
||||
- Tier 1 (Team-level): 'read' or 'share' set by team owner
|
||||
- Tier 2 (Resource-level): 'read' or 'edit' set by resource sharer
|
||||
|
||||
Follows GT 2.0 principles:
|
||||
- Perfect tenant isolation
|
||||
- Admin bypass for tenant admins
|
||||
- Fail-fast error handling
|
||||
- PostgreSQL-first storage
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Dict, Any, List
|
||||
import logging
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.services.team_service import TeamService
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
from app.models.collaboration_team import (
|
||||
TeamCreate,
|
||||
TeamUpdate,
|
||||
Team,
|
||||
TeamListResponse,
|
||||
TeamResponse,
|
||||
TeamWithMembers,
|
||||
TeamWithMembersResponse,
|
||||
AddMemberRequest,
|
||||
UpdateMemberPermissionRequest,
|
||||
MemberListResponse,
|
||||
MemberResponse,
|
||||
ShareResourceRequest,
|
||||
SharedResourcesResponse,
|
||||
SharedResource,
|
||||
TeamInvitation,
|
||||
InvitationListResponse,
|
||||
ObservableRequest,
|
||||
ObservableRequestListResponse,
|
||||
TeamActivityMetrics,
|
||||
TeamActivityResponse,
|
||||
ErrorResponse
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/teams", tags=["teams"])
|
||||
|
||||
|
||||
async def get_team_service_for_user(current_user: Dict[str, Any]) -> TeamService:
|
||||
"""Helper function to create TeamService with proper tenant UUID mapping"""
|
||||
user_email = current_user.get('email')
|
||||
if not user_email:
|
||||
raise HTTPException(status_code=401, detail="User email not found in token")
|
||||
|
||||
tenant_user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
if not tenant_user_uuid:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_email} not found in tenant system")
|
||||
|
||||
return TeamService(
|
||||
tenant_domain=current_user.get('tenant_domain', 'test-company'),
|
||||
user_id=tenant_user_uuid,
|
||||
user_email=user_email
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEAM CRUD ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("", response_model=TeamListResponse)
|
||||
async def list_teams(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all teams where the current user is owner or member.
|
||||
Returns teams with member counts and permission flags.
|
||||
|
||||
Permission flags:
|
||||
- is_owner: User created this team
|
||||
- can_manage: User can manage team (owner or admin/developer)
|
||||
"""
|
||||
logger.info(f"Listing teams for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
teams = await service.get_user_teams()
|
||||
|
||||
return TeamListResponse(
|
||||
data=teams,
|
||||
total=len(teams)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing teams: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("", response_model=TeamResponse, status_code=201)
|
||||
async def create_team(
|
||||
team_data: TeamCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a new team with the current user as owner.
|
||||
|
||||
The creator is automatically the team owner with full management permissions.
|
||||
"""
|
||||
logger.info(f"Creating team '{team_data.name}' for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
team = await service.create_team(
|
||||
name=team_data.name,
|
||||
description=team_data.description or ""
|
||||
)
|
||||
|
||||
return TeamResponse(data=team)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# TEAM INVITATION ENDPOINTS (must come before /{team_id} routes)
|
||||
# ==============================================================================
|
||||
|
||||
@router.get("/invitations", response_model=InvitationListResponse)
|
||||
async def list_my_invitations(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get current user's pending team invitations.
|
||||
|
||||
Returns list of invitations with team details and inviter information.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
invitations = await service.get_pending_invitations()
|
||||
|
||||
return InvitationListResponse(
|
||||
data=invitations,
|
||||
total=len(invitations)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing invitations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/invitations/{invitation_id}/accept", response_model=MemberResponse)
|
||||
async def accept_team_invitation(
|
||||
invitation_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Accept a team invitation.
|
||||
|
||||
Updates the invitation status to 'accepted' and grants team membership.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
member = await service.accept_invitation(invitation_id)
|
||||
|
||||
return MemberResponse(data=member)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error accepting invitation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/invitations/{invitation_id}/decline", status_code=204)
|
||||
async def decline_team_invitation(
|
||||
invitation_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Decline a team invitation.
|
||||
|
||||
Removes the invitation from the system.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
await service.decline_invitation(invitation_id)
|
||||
|
||||
return None
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error declining invitation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/observable-requests", response_model=ObservableRequestListResponse)
|
||||
async def get_observable_requests(
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get pending Observable requests for the current user.
|
||||
|
||||
Returns list of teams requesting Observable access.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
requests = await service.get_observable_requests()
|
||||
|
||||
return ObservableRequestListResponse(
|
||||
data=[ObservableRequest(**req) for req in requests],
|
||||
total=len(requests)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Observable requests: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# TEAM CRUD ENDPOINTS (dynamic routes with {team_id})
|
||||
# ==============================================================================
|
||||
|
||||
@router.get("/{team_id}", response_model=TeamResponse)
|
||||
async def get_team(
|
||||
team_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get team details by ID.
|
||||
|
||||
Only accessible to team members or tenant admins.
|
||||
"""
|
||||
logger.info(f"Getting team {team_id} for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
team = await service.get_team_by_id(team_id)
|
||||
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail=f"Team {team_id} not found")
|
||||
|
||||
return TeamResponse(data=team)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{team_id}", response_model=TeamResponse)
|
||||
async def update_team(
|
||||
team_id: str,
|
||||
updates: TeamUpdate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update team name/description.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
"""
|
||||
logger.info(f"Updating team {team_id} for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
|
||||
# Convert Pydantic model to dict, excluding None values
|
||||
update_dict = updates.model_dump(exclude_none=True)
|
||||
|
||||
team = await service.update_team(team_id, update_dict)
|
||||
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail=f"Team {team_id} not found")
|
||||
|
||||
return TeamResponse(data=team)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}", status_code=204)
|
||||
async def delete_team(
|
||||
team_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete a team and all its memberships (CASCADE).
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
"""
|
||||
logger.info(f"Deleting team {team_id} for user {current_user['sub']}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
success = await service.delete_team(team_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Team {team_id} not found")
|
||||
|
||||
return None # 204 No Content
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting team: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEAM MEMBER ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/{team_id}/members", response_model=MemberListResponse)
|
||||
async def list_team_members(
|
||||
team_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all members of a team with their permissions.
|
||||
|
||||
Only accessible to team members or tenant admins.
|
||||
|
||||
Returns:
|
||||
- user_id, user_email, user_name
|
||||
- team_permission: 'read' or 'share'
|
||||
- resource_permissions: JSONB dict of resource-level permissions
|
||||
"""
|
||||
logger.info(f"Listing members for team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
members = await service.get_team_members(team_id)
|
||||
|
||||
return MemberListResponse(
|
||||
data=members,
|
||||
total=len(members)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing team members: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{team_id}/members", response_model=MemberResponse, status_code=201)
|
||||
async def add_team_member(
|
||||
team_id: str,
|
||||
member_data: AddMemberRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Add a user to the team with specified permission.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
|
||||
Team Permissions:
|
||||
- 'read': Can access resources shared to this team
|
||||
- 'share': Can access resources AND share own resources to this team
|
||||
|
||||
Note: Observability access is automatically requested when inviting users.
|
||||
The invited user can approve or decline the observability request separately.
|
||||
"""
|
||||
logger.info(f"Adding member {member_data.user_email} to team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
member = await service.add_member(
|
||||
team_id=team_id,
|
||||
user_email=member_data.user_email,
|
||||
team_permission=member_data.team_permission
|
||||
)
|
||||
|
||||
return MemberResponse(data=member)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding team member: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{team_id}/members/{user_id}", response_model=MemberResponse)
|
||||
async def update_member_permission(
|
||||
team_id: str,
|
||||
user_id: str,
|
||||
permission_data: UpdateMemberPermissionRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update a team member's permission level.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
|
||||
Note: DB trigger auto-clears resource_permissions when downgraded from 'share' to 'read'
|
||||
"""
|
||||
logger.info(f"PUT /teams/{team_id}/members/{user_id} - Permission update request")
|
||||
logger.info(f"Request body: {permission_data.model_dump()}")
|
||||
logger.info(f"Current user: {current_user.get('email')}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
member = await service.update_member_permission(
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
new_permission=permission_data.team_permission
|
||||
)
|
||||
|
||||
return MemberResponse(data=member)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"PermissionError updating member permission: {str(e)}")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(f"ValueError updating member permission: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error(f"RuntimeError updating member permission: {str(e)}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating member permission: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/members/{user_id}", status_code=204)
|
||||
async def remove_team_member(
|
||||
team_id: str,
|
||||
user_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Remove a user from the team.
|
||||
|
||||
Requires: Team ownership or admin/developer role
|
||||
"""
|
||||
logger.info(f"Removing member {user_id} from team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
success = await service.remove_member(team_id, user_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Member {user_id} not found in team {team_id}")
|
||||
|
||||
return None # 204 No Content
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing team member: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RESOURCE SHARING ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/{team_id}/share", status_code=201)
|
||||
async def share_resource_to_team(
|
||||
team_id: str,
|
||||
share_data: ShareResourceRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Share a resource (agent/dataset) to team with per-user permissions.
|
||||
|
||||
Requires: Team ownership or 'share' team permission
|
||||
|
||||
Request body:
|
||||
{
|
||||
"resource_type": "agent" | "dataset",
|
||||
"resource_id": "uuid",
|
||||
"user_permissions": {
|
||||
"user_uuid_1": "read",
|
||||
"user_uuid_2": "edit"
|
||||
}
|
||||
}
|
||||
|
||||
Resource Permissions:
|
||||
- 'read': View-only access to resource
|
||||
- 'edit': Full edit access to resource
|
||||
"""
|
||||
logger.info(f"Sharing {share_data.resource_type}:{share_data.resource_id} to team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
|
||||
# Use new junction table method (Phase 2)
|
||||
await service.share_resource_to_teams(
|
||||
resource_id=share_data.resource_id,
|
||||
resource_type=share_data.resource_type,
|
||||
shared_by=current_user["user_id"],
|
||||
team_shares=[{
|
||||
"team_id": team_id,
|
||||
"user_permissions": share_data.user_permissions
|
||||
}]
|
||||
)
|
||||
|
||||
return {"message": "Resource shared successfully", "success": True}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sharing resource: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
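
# --- Example (illustrative sketch, not part of the committed module) -------
# How a client might call this endpoint with the request body documented in the
# docstring above. The base URL, path prefix, IDs, and bearer token below are
# hypothetical placeholders; the actual prefix depends on where the teams
# router is mounted in the application.
def _example_share_agent_to_team():
    import httpx  # assumed available in the client environment

    body = {
        "resource_type": "agent",
        "resource_id": "11111111-1111-1111-1111-111111111111",
        "user_permissions": {
            "22222222-2222-2222-2222-222222222222": "read",
            "33333333-3333-3333-3333-333333333333": "edit",
        },
    }
    resp = httpx.post(
        "http://localhost:8000/teams/44444444-4444-4444-4444-444444444444/share",
        json=body,
        headers={"Authorization": "Bearer <token>"},
    )
    resp.raise_for_status()
    return resp.json()  # {"message": "Resource shared successfully", "success": True}
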
|
||||
|
||||
|
||||
@router.delete("/{team_id}/share/{resource_type}/{resource_id}", status_code=204)
|
||||
async def unshare_resource_from_team(
|
||||
team_id: str,
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Remove resource sharing from team (removes from all members' resource_permissions).
|
||||
|
||||
Requires: Team ownership or 'share' team permission
|
||||
"""
|
||||
logger.info(f"Unsharing {resource_type}:{resource_id} from team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
|
||||
# Use new junction table method (Phase 2)
|
||||
await service.unshare_resource_from_team(
|
||||
resource_id=resource_id,
|
||||
resource_type=resource_type,
|
||||
team_id=team_id
|
||||
)
|
||||
|
||||
return None # 204 No Content
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error unsharing resource: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{team_id}/resources", response_model=SharedResourcesResponse)
|
||||
async def list_shared_resources(
|
||||
team_id: str,
|
||||
resource_type: str = Query(None, description="Filter by resource type: 'agent' or 'dataset'"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all resources shared to a team.
|
||||
|
||||
Only accessible to team members or tenant admins.
|
||||
|
||||
Returns list of:
|
||||
{
|
||||
"resource_type": "agent" | "dataset",
|
||||
"resource_id": "uuid",
|
||||
"user_permissions": {"user_id": "read|edit", ...}
|
||||
}
|
||||
"""
|
||||
logger.info(f"Listing shared resources for team {team_id}")
|
||||
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
resources = await service.get_shared_resources(
|
||||
team_id=team_id,
|
||||
resource_type=resource_type
|
||||
)
|
||||
|
||||
return SharedResourcesResponse(
|
||||
data=resources,
|
||||
total=len(resources)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing shared resources: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@router.get("/{team_id}/invitations", response_model=InvitationListResponse)
|
||||
async def list_team_invitations(
|
||||
team_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get pending invitations for a team (owner view).
|
||||
|
||||
Shows all users who have been invited but haven't accepted yet.
|
||||
Requires team ownership or admin role.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
invitations = await service.get_team_pending_invitations(team_id)
|
||||
|
||||
return InvitationListResponse(
|
||||
data=invitations,
|
||||
total=len(invitations)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing team invitations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/invitations/{invitation_id}", status_code=204)
|
||||
async def cancel_team_invitation(
|
||||
team_id: str,
|
||||
invitation_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Cancel a pending invitation (owner only).
|
||||
|
||||
Removes the invitation before the user accepts it.
|
||||
Requires team ownership or admin role.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
await service.cancel_invitation(team_id, invitation_id)
|
||||
|
||||
return None
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error canceling invitation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Observable Member Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/{team_id}/members/{user_id}/request-observable", status_code=200)
|
||||
async def request_observable_access(
|
||||
team_id: str,
|
||||
user_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Request Observable access from a team member.
|
||||
|
||||
Sets Observable status to pending for the target user.
|
||||
Requires owner or manager permission.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
result = await service.request_observable_status(team_id, user_id)
|
||||
|
||||
return {"success": True, "data": result}
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error requesting Observable access: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{team_id}/observable/approve", status_code=200)
|
||||
async def approve_observable_request(
|
||||
team_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Approve Observable status for current user in a team.
|
||||
|
||||
User explicitly consents to team managers viewing their activity.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
result = await service.approve_observable_consent(team_id)
|
||||
|
||||
return {"success": True, "data": result}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error approving Observable request: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/observable", status_code=200)
|
||||
async def revoke_observable_status(
|
||||
team_id: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Revoke Observable status for current user in a team.
|
||||
|
||||
Immediately removes manager access to user's activity data.
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
result = await service.revoke_observable_status(team_id)
|
||||
|
||||
return {"success": True, "data": result}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking Observable status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{team_id}/activity", response_model=TeamActivityResponse)
|
||||
async def get_team_activity(
|
||||
team_id: str,
|
||||
days: int = Query(7, ge=1, le=365, description="Number of days to include in metrics"),
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get team activity metrics for Observable members.
|
||||
|
||||
Returns aggregated activity data for team members who have approved Observable status.
|
||||
Requires owner or manager permission.
|
||||
|
||||
Args:
|
||||
team_id: Team UUID
|
||||
days: Number of days to include (1-365, default 7)
|
||||
"""
|
||||
try:
|
||||
service = await get_team_service_for_user(current_user)
|
||||
activity = await service.get_team_activity(team_id, days)
|
||||
|
||||
return TeamActivityResponse(
|
||||
data=TeamActivityMetrics(**activity)
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting team activity: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
348
apps/tenant-backend/app/api/v1/users.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
User API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Handles user preferences and favorite agents management.
|
||||
Follows GT 2.0 principles: no mocks, real implementations, fail fast.
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Header
|
||||
from typing import Optional
|
||||
|
||||
from app.services.user_service import UserService
|
||||
from app.schemas.user import (
|
||||
UserPreferencesResponse,
|
||||
UpdateUserPreferencesRequest,
|
||||
FavoriteAgentsResponse,
|
||||
UpdateFavoriteAgentsRequest,
|
||||
AddFavoriteAgentRequest,
|
||||
RemoveFavoriteAgentRequest,
|
||||
CustomCategoriesResponse,
|
||||
UpdateCustomCategoriesRequest
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
def get_user_context(
|
||||
x_tenant_domain: Optional[str] = Header(None),
|
||||
x_user_id: Optional[str] = Header(None),
|
||||
x_user_email: Optional[str] = Header(None)
|
||||
) -> tuple[str, str, str]:
|
||||
"""Extract user context from headers"""
|
||||
if not x_tenant_domain:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-Tenant-Domain header is required"
|
||||
)
|
||||
if not x_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-User-ID header is required"
|
||||
)
|
||||
|
||||
return x_tenant_domain, x_user_id, x_user_email or x_user_id
|
||||
|
||||
|
||||
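# --- Example (illustrative sketch, not part of the committed module) -------
# How a caller would supply the headers that get_user_context expects. The base
# URL, tenant domain, and user identifiers below are hypothetical placeholders,
# and the path prefix depends on how this router is mounted.
def _example_get_preferences():
    import httpx  # assumed available in the client environment

    headers = {
        "X-Tenant-Domain": "acme.example.com",                    # tenant routing header
        "X-User-ID": "00000000-0000-0000-0000-000000000001",      # tenant user UUID
        "X-User-Email": "user@acme.example.com",                  # optional, falls back to user id
    }
    resp = httpx.get("http://localhost:8000/users/me/preferences", headers=headers)
    resp.raise_for_status()
    return resp.json()["preferences"]

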
# User Preferences Endpoints
|
||||
|
||||
@router.get("/me/preferences", response_model=UserPreferencesResponse)
|
||||
async def get_user_preferences(
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Get current user's preferences from PostgreSQL.
|
||||
|
||||
Returns all user preferences stored in the JSONB preferences column.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Getting user preferences", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
preferences = await service.get_user_preferences()
|
||||
|
||||
return UserPreferencesResponse(preferences=preferences)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user preferences", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve user preferences"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/preferences")
|
||||
async def update_user_preferences(
|
||||
request: UpdateUserPreferencesRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Update current user's preferences in PostgreSQL.
|
||||
|
||||
Merges provided preferences with existing preferences using JSONB || operator.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Updating user preferences", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.update_user_preferences(request.preferences)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {"success": True, "message": "Preferences updated successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update user preferences", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user preferences"
|
||||
)
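
# --- Example (illustrative sketch, not the actual UserService implementation) ---
# The merge described in the docstring above ("JSONB || operator") can be
# expressed in PostgreSQL roughly like this; the table name, column names, and
# connection object (e.g. an asyncpg connection) are assumptions.
async def _example_merge_preferences(conn, user_id: str, new_prefs: dict):
    import json

    await conn.execute(
        """
        UPDATE users
        SET preferences = COALESCE(preferences, '{}'::jsonb) || $2::jsonb
        WHERE id = $1
        """,
        user_id,
        json.dumps(new_prefs),  # right-hand side keys overwrite existing keys
    )
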
|
||||
|
||||
|
||||
# Favorite Agents Endpoints
|
||||
|
||||
@router.get("/me/favorite-agents", response_model=FavoriteAgentsResponse)
|
||||
async def get_favorite_agents(
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Get current user's favorited agent IDs from PostgreSQL.
|
||||
|
||||
Returns list of agent UUIDs that the user has marked as favorites.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Getting favorite agent IDs", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
favorite_ids = await service.get_favorite_agent_ids()
|
||||
|
||||
return FavoriteAgentsResponse(favorite_agent_ids=favorite_ids)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get favorite agents", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve favorite agents"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/favorite-agents")
|
||||
async def update_favorite_agents(
|
||||
request: UpdateFavoriteAgentsRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Update current user's favorite agent IDs in PostgreSQL.
|
||||
|
||||
Replaces the entire list of favorite agent IDs with the provided list.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Updating favorite agent IDs",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
agent_count=len(request.agent_ids)
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.update_favorite_agent_ids(request.agent_ids)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Favorite agents updated successfully",
|
||||
"favorite_agent_ids": request.agent_ids
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update favorite agents", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update favorite agents"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/me/favorite-agents/add")
|
||||
async def add_favorite_agent(
|
||||
request: AddFavoriteAgentRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Add a single agent to user's favorites.
|
||||
|
||||
Idempotent - does nothing if agent is already in favorites.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Adding agent to favorites",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
agent_id=request.agent_id
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.add_favorite_agent(request.agent_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Agent added to favorites",
|
||||
"agent_id": request.agent_id
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to add favorite agent", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add favorite agent"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/me/favorite-agents/remove")
|
||||
async def remove_favorite_agent(
|
||||
request: RemoveFavoriteAgentRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Remove a single agent from user's favorites.
|
||||
|
||||
Idempotent - does nothing if agent is not in favorites.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Removing agent from favorites",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
agent_id=request.agent_id
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.remove_favorite_agent(request.agent_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Agent removed from favorites",
|
||||
"agent_id": request.agent_id
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to remove favorite agent", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to remove favorite agent"
|
||||
)
|
||||
|
||||
|
||||
# Custom Categories Endpoints
|
||||
|
||||
@router.get("/me/custom-categories", response_model=CustomCategoriesResponse)
|
||||
async def get_custom_categories(
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Get current user's custom agent categories from PostgreSQL.
|
||||
|
||||
Returns list of custom categories with name and description.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info("Getting custom categories", user_id=user_id, tenant_domain=tenant_domain)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
categories = await service.get_custom_categories()
|
||||
|
||||
return CustomCategoriesResponse(categories=categories)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get custom categories", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve custom categories"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/custom-categories")
|
||||
async def update_custom_categories(
|
||||
request: UpdateCustomCategoriesRequest,
|
||||
user_context: tuple = Depends(get_user_context)
|
||||
):
|
||||
"""
|
||||
Update current user's custom agent categories in PostgreSQL.
|
||||
|
||||
Replaces the entire list of custom categories with the provided list.
|
||||
"""
|
||||
tenant_domain, user_id, user_email = user_context
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Updating custom categories",
|
||||
user_id=user_id,
|
||||
tenant_domain=tenant_domain,
|
||||
category_count=len(request.categories)
|
||||
)
|
||||
|
||||
service = UserService(tenant_domain, user_id, user_email)
|
||||
success = await service.update_custom_categories(request.categories)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Custom categories updated successfully",
|
||||
"categories": [cat.dict() for cat in request.categories]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update custom categories", error=str(e), user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update custom categories"
|
||||
)
|
||||
376
apps/tenant-backend/app/api/v1/webhooks.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
Webhook Receiver API
|
||||
|
||||
Handles incoming webhooks for automation triggers with security validation
|
||||
and rate limiting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import hmac
|
||||
import hashlib
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, BackgroundTasks, Depends, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.event_bus import TenantEventBus, TriggerType
|
||||
from app.services.automation_executor import AutomationChainExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||
|
||||
|
||||
class WebhookRegistration(BaseModel):
|
||||
"""Webhook registration model"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
secret: Optional[str] = None
|
||||
rate_limit: int = 60 # Requests per minute
|
||||
allowed_ips: Optional[list[str]] = None
|
||||
events_to_trigger: list[str] = []
|
||||
|
||||
|
||||
class WebhookPayload(BaseModel):
|
||||
"""Generic webhook payload"""
|
||||
event: Optional[str] = None
|
||||
data: Dict[str, Any] = {}
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
|
||||
# In-memory webhook registry (in production, use database)
|
||||
webhook_registry: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Rate limiting tracker (in production, use Redis)
|
||||
rate_limiter: Dict[str, list[datetime]] = {}
|
||||
|
||||
|
||||
def validate_webhook_registration(tenant_domain: str, webhook_id: str) -> Dict[str, Any]:
|
||||
"""Validate webhook is registered"""
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
|
||||
if key not in webhook_registry:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
|
||||
return webhook_registry[key]
|
||||
|
||||
|
||||
def check_rate_limit(key: str, limit: int) -> bool:
|
||||
"""Check if request is within rate limit"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get request history
|
||||
if key not in rate_limiter:
|
||||
rate_limiter[key] = []
|
||||
|
||||
# Remove old requests (older than 1 minute)
|
||||
rate_limiter[key] = [
|
||||
ts for ts in rate_limiter[key]
|
||||
if (now - ts).total_seconds() < 60
|
||||
]
|
||||
|
||||
# Check limit
|
||||
if len(rate_limiter[key]) >= limit:
|
||||
return False
|
||||
|
||||
# Add current request
|
||||
rate_limiter[key].append(now)
|
||||
return True
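
# --- Example (illustrative sketch) -----------------------------------------
# check_rate_limit keeps a sliding one-minute window of timestamps per key:
# the 61st call within the same minute for the same key is rejected. The key
# below is a hypothetical placeholder.
def _example_rate_limit_behaviour():
    key = "webhook:acme.example.com:demo"
    allowed = [check_rate_limit(key, limit=60) for _ in range(61)]
    return allowed.count(True), allowed.count(False)  # -> (60, 1)
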
|
||||
|
||||
|
||||
def validate_hmac_signature(
    signature: str,
    body: bytes,
    secret: str
) -> bool:
    """Validate HMAC signature"""
    if not signature or not secret:
        return False

    # Calculate expected signature
    expected = hmac.new(
        secret.encode(),
        body,
        hashlib.sha256
    ).hexdigest()

    # Compare signatures
    return hmac.compare_digest(signature, expected)

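# --- Example (illustrative sketch) -----------------------------------------
# How a webhook sender would compute the signature that validate_hmac_signature
# checks: a hex-encoded HMAC-SHA256 of the raw request body, delivered in the
# X-Webhook-Signature header. The secret, URL, and payload are hypothetical.
def _example_signed_delivery():
    import json
    import hmac as _hmac
    import hashlib as _hashlib
    import httpx  # assumed available in the sender environment

    secret = "shared-webhook-secret"
    body = json.dumps({"event": "build.finished", "data": {"ok": True}}).encode()
    signature = _hmac.new(secret.encode(), body, _hashlib.sha256).hexdigest()

    return httpx.post(
        "http://localhost:8000/webhooks/acme.example.com/abc123",
        content=body,
        headers={
            "Content-Type": "application/json",
            "X-Webhook-Signature": signature,
        },
    )

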
def sanitize_webhook_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize webhook payload to prevent injection attacks"""
|
||||
# Remove potentially dangerous keys
|
||||
dangerous_keys = ["__proto__", "constructor", "prototype"]
|
||||
|
||||
def clean_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
cleaned = {}
|
||||
for key, value in d.items():
|
||||
if key not in dangerous_keys:
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = clean_dict(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [
|
||||
clean_dict(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
# Limit string length
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
cleaned[key] = value[:10000]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
return cleaned
|
||||
|
||||
return clean_dict(payload)
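
# --- Example (illustrative sketch) -----------------------------------------
# sanitize_webhook_payload drops prototype-pollution style keys and truncates
# oversized strings, recursing into nested dicts and lists.
def _example_sanitize():
    raw = {
        "__proto__": {"polluted": True},           # dropped
        "event": "deploy",                          # kept
        "nested": [{"constructor": 1, "ok": 2}],    # "constructor" dropped, "ok" kept
        "note": "x" * 20000,                        # truncated to 10000 characters
    }
    return sanitize_webhook_payload(raw)
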
|
||||
|
||||
|
||||
@router.post("/{tenant_domain}/{webhook_id}")
|
||||
async def receive_webhook(
|
||||
tenant_domain: str,
|
||||
webhook_id: str,
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_webhook_signature: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
Receive webhook and trigger automations.
|
||||
|
||||
Note: In production, webhooks terminate at NGINX in DMZ, not directly here.
|
||||
Traffic flow: Internet → NGINX (DMZ) → OAuth2 Proxy → This endpoint
|
||||
"""
|
||||
try:
|
||||
# Validate webhook registration
|
||||
webhook = validate_webhook_registration(tenant_domain, webhook_id)
|
||||
|
||||
# Check rate limiting
|
||||
rate_key = f"webhook:{tenant_domain}:{webhook_id}"
|
||||
if not check_rate_limit(rate_key, webhook.get("rate_limit", 60)):
|
||||
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
|
||||
# Validate signature if configured
|
||||
if webhook.get("secret"):
|
||||
if not validate_hmac_signature(x_webhook_signature, body, webhook["secret"]):
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Check IP whitelist if configured
|
||||
client_ip = request.client.host
|
||||
allowed_ips = webhook.get("allowed_ips")
|
||||
if allowed_ips and client_ip not in allowed_ips:
|
||||
logger.warning(f"Webhook request from unauthorized IP: {client_ip}")
|
||||
raise HTTPException(status_code=403, detail="IP not authorized")
|
||||
|
||||
# Parse and sanitize payload
|
||||
try:
|
||||
payload = await request.json()
|
||||
payload = sanitize_webhook_payload(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid webhook payload: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid payload format")
|
||||
|
||||
# Queue for processing
|
||||
background_tasks.add_task(
|
||||
process_webhook_automation,
|
||||
tenant_domain=tenant_domain,
|
||||
webhook_id=webhook_id,
|
||||
payload=payload,
|
||||
webhook_config=webhook
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "accepted",
|
||||
"webhook_id": webhook_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
async def process_webhook_automation(
    tenant_domain: str,
    webhook_id: str,
    payload: Dict[str, Any],
    webhook_config: Dict[str, Any]
):
    """Process webhook and trigger associated automations"""
    try:
        # Initialize event bus
        event_bus = TenantEventBus(tenant_domain)

        # Create webhook event
        event_type = payload.get("event", "webhook.received")

        # Emit event to trigger automations
        await event_bus.emit_event(
            event_type=event_type,
            user_id=webhook_config.get("owner_id", "system"),
            data={
                "webhook_id": webhook_id,
                "payload": payload,
                "source": webhook_config.get("name", "Unknown")
            },
            metadata={
                "trigger_type": TriggerType.WEBHOOK.value,
                "webhook_config": webhook_config
            }
        )

        logger.info(f"Webhook processed: {webhook_id} → {event_type}")

    except Exception as e:
        logger.error(f"Error processing webhook automation: {e}")

        # Emit failure event; never let the failure handler itself raise
        try:
            event_bus = TenantEventBus(tenant_domain)
            await event_bus.emit_event(
                event_type="webhook.failed",
                user_id="system",
                data={
                    "webhook_id": webhook_id,
                    "error": str(e)
                }
            )
        except Exception as emit_error:
            logger.error(f"Failed to emit webhook.failed event: {emit_error}")


@router.post("/{tenant_domain}/register")
|
||||
async def register_webhook(
|
||||
tenant_domain: str,
|
||||
registration: WebhookRegistration,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""Register a new webhook endpoint"""
|
||||
import secrets
|
||||
|
||||
# Generate webhook ID
|
||||
webhook_id = secrets.token_urlsafe(16)
|
||||
|
||||
# Store registration
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
webhook_registry[key] = {
|
||||
"id": webhook_id,
|
||||
"name": registration.name,
|
||||
"description": registration.description,
|
||||
"secret": registration.secret or secrets.token_urlsafe(32),
|
||||
"rate_limit": registration.rate_limit,
|
||||
"allowed_ips": registration.allowed_ips,
|
||||
"events_to_trigger": registration.events_to_trigger,
|
||||
"owner_id": user_email,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"url": f"/webhooks/{tenant_domain}/{webhook_id}"
|
||||
}
|
||||
|
||||
return {
|
||||
"webhook_id": webhook_id,
|
||||
"url": f"/webhooks/{tenant_domain}/{webhook_id}",
|
||||
"secret": webhook_registry[key]["secret"],
|
||||
"created_at": webhook_registry[key]["created_at"]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{tenant_domain}/list")
|
||||
async def list_webhooks(
|
||||
tenant_domain: str,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""List registered webhooks for tenant"""
|
||||
webhooks = []
|
||||
|
||||
for key, webhook in webhook_registry.items():
|
||||
if key.startswith(f"{tenant_domain}:"):
|
||||
# Only show webhooks owned by user
|
||||
if webhook.get("owner_id") == user_email:
|
||||
webhooks.append({
|
||||
"id": webhook["id"],
|
||||
"name": webhook["name"],
|
||||
"description": webhook.get("description"),
|
||||
"url": webhook["url"],
|
||||
"rate_limit": webhook["rate_limit"],
|
||||
"created_at": webhook["created_at"]
|
||||
})
|
||||
|
||||
return {
|
||||
"webhooks": webhooks,
|
||||
"total": len(webhooks)
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{tenant_domain}/{webhook_id}")
|
||||
async def delete_webhook(
|
||||
tenant_domain: str,
|
||||
webhook_id: str,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""Delete a webhook registration"""
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
|
||||
if key not in webhook_registry:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
|
||||
webhook = webhook_registry[key]
|
||||
|
||||
# Check ownership
|
||||
if webhook.get("owner_id") != user_email:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Delete webhook
|
||||
del webhook_registry[key]
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"webhook_id": webhook_id
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{tenant_domain}/{webhook_id}/test")
|
||||
async def test_webhook(
|
||||
tenant_domain: str,
|
||||
webhook_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
user_email: str = "admin@example.com" # In production, get from auth
|
||||
):
|
||||
"""Send a test payload to webhook"""
|
||||
key = f"{tenant_domain}:{webhook_id}"
|
||||
|
||||
if key not in webhook_registry:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
|
||||
webhook = webhook_registry[key]
|
||||
|
||||
# Check ownership
|
||||
if webhook.get("owner_id") != user_email:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Create test payload
|
||||
test_payload = {
|
||||
"event": "webhook.test",
|
||||
"data": {
|
||||
"message": "This is a test webhook",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Process webhook
|
||||
background_tasks.add_task(
|
||||
process_webhook_automation,
|
||||
tenant_domain=tenant_domain,
|
||||
webhook_id=webhook_id,
|
||||
payload=test_payload,
|
||||
webhook_config=webhook
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "test_sent",
|
||||
"webhook_id": webhook_id,
|
||||
"payload": test_payload
|
||||
}
|
||||
534
apps/tenant-backend/app/api/v1/workflows.py
Normal file
@@ -0,0 +1,534 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
from app.services.workflow_service import WorkflowService, WorkflowValidationError
|
||||
from app.models.workflow import WorkflowStatus, TriggerType, InteractionMode
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/workflows", tags=["workflows"])
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class WorkflowCreateRequest(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
definition: Dict[str, Any] = Field(..., description="Workflow nodes and edges")
|
||||
triggers: Optional[List[Dict[str, Any]]] = Field(default=[])
|
||||
interaction_modes: Optional[List[str]] = Field(default=["button"])
|
||||
api_key_ids: Optional[List[str]] = Field(default=[])
|
||||
webhook_ids: Optional[List[str]] = Field(default=[])
|
||||
dataset_ids: Optional[List[str]] = Field(default=[])
|
||||
integration_ids: Optional[List[str]] = Field(default=[])
|
||||
config: Optional[Dict[str, Any]] = Field(default={})
|
||||
timeout_seconds: Optional[int] = Field(default=300, ge=30, le=3600)
|
||||
max_retries: Optional[int] = Field(default=3, ge=0, le=10)
|
||||
|
||||
|
||||
class WorkflowUpdateRequest(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
definition: Optional[Dict[str, Any]] = None
|
||||
triggers: Optional[List[Dict[str, Any]]] = None
|
||||
interaction_modes: Optional[List[str]] = None
|
||||
status: Optional[WorkflowStatus] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
timeout_seconds: Optional[int] = Field(None, ge=30, le=3600)
|
||||
max_retries: Optional[int] = Field(None, ge=0, le=10)
|
||||
|
||||
|
||||
class WorkflowExecutionRequest(BaseModel):
|
||||
input_data: Dict[str, Any] = Field(default={})
|
||||
trigger_type: Optional[str] = Field(default="manual")
|
||||
interaction_mode: Optional[str] = Field(default="api")
|
||||
|
||||
|
||||
class WorkflowTriggerRequest(BaseModel):
|
||||
type: TriggerType
|
||||
config: Dict[str, Any] = Field(default={})
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
message: str = Field(..., min_length=1, max_length=10000)
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
class WorkflowResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
status: str
|
||||
definition: Dict[str, Any]
|
||||
interaction_modes: List[str]
|
||||
execution_count: int
|
||||
last_executed: Optional[str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class WorkflowExecutionResponse(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
progress_percentage: int
|
||||
current_node_id: Optional[str]
|
||||
input_data: Dict[str, Any]
|
||||
output_data: Dict[str, Any]
|
||||
error_details: Optional[str]
|
||||
started_at: str
|
||||
completed_at: Optional[str]
|
||||
duration_ms: Optional[int]
|
||||
tokens_used: int
|
||||
cost_cents: int
|
||||
interaction_mode: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# Workflow CRUD endpoints
|
||||
@router.post("/", response_model=WorkflowResponse)
|
||||
def create_workflow(
|
||||
workflow_data: WorkflowCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new workflow - temporary mock implementation"""
|
||||
try:
|
||||
# TODO: Implement proper PostgreSQL service integration
|
||||
# For now, return a mock workflow to avoid database integration issues
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
mock_workflow = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": workflow_data.name,
|
||||
"description": workflow_data.description or "",
|
||||
"status": "draft",
|
||||
"definition": workflow_data.definition,
|
||||
"interaction_modes": workflow_data.interaction_modes or ["button"],
|
||||
"execution_count": 0,
|
||||
"last_executed": None,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"updated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
return WorkflowResponse(**mock_workflow)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/", response_model=List[WorkflowResponse])
|
||||
def list_workflows(
|
||||
status: Optional[WorkflowStatus] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""List user's workflows - temporary mock implementation"""
|
||||
try:
|
||||
# TODO: Implement proper PostgreSQL service integration
|
||||
# For now, return empty list to avoid database integration issues
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list workflows: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{workflow_id}", response_model=WorkflowResponse)
|
||||
def get_workflow(
|
||||
workflow_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get workflow by ID"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
workflow = service.get_workflow(workflow_id, current_user["sub"])
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Convert to dict with proper datetime formatting
|
||||
workflow_dict = {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"description": workflow.description,
|
||||
"status": workflow.status,
|
||||
"definition": workflow.definition,
|
||||
"interaction_modes": workflow.interaction_modes,
|
||||
"execution_count": workflow.execution_count,
|
||||
"last_executed": workflow.last_executed.isoformat() if workflow.last_executed else None,
|
||||
"created_at": workflow.created_at.isoformat(),
|
||||
"updated_at": workflow.updated_at.isoformat()
|
||||
}
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/{workflow_id}", response_model=WorkflowResponse)
|
||||
def update_workflow(
|
||||
workflow_id: str,
|
||||
updates: WorkflowUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Update a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Filter out None values
|
||||
update_data = {k: v for k, v in updates.dict().items() if v is not None}
|
||||
|
||||
workflow = service.update_workflow(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
updates=update_data
|
||||
)
|
||||
|
||||
# Convert to dict with proper datetime formatting
|
||||
workflow_dict = {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"description": workflow.description,
|
||||
"status": workflow.status,
|
||||
"definition": workflow.definition,
|
||||
"interaction_modes": workflow.interaction_modes,
|
||||
"execution_count": workflow.execution_count,
|
||||
"last_executed": workflow.last_executed.isoformat() if workflow.last_executed else None,
|
||||
"created_at": workflow.created_at.isoformat(),
|
||||
"updated_at": workflow.updated_at.isoformat()
|
||||
}
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
except WorkflowValidationError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{workflow_id}")
|
||||
def delete_workflow(
|
||||
workflow_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
success = service.delete_workflow(workflow_id, current_user["sub"])
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
return {"message": "Workflow deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {str(e)}")
|
||||
|
||||
|
||||
# Workflow execution endpoints
|
||||
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
|
||||
async def execute_workflow(
|
||||
workflow_id: str,
|
||||
execution_data: WorkflowExecutionRequest,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Execute a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
execution = await service.execute_workflow(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
input_data=execution_data.input_data,
|
||||
trigger_type=execution_data.trigger_type,
|
||||
interaction_mode=execution_data.interaction_mode
|
||||
)
|
||||
|
||||
return WorkflowExecutionResponse.from_orm(execution)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to execute workflow: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{workflow_id}/executions")
|
||||
async def list_workflow_executions(
|
||||
workflow_id: str,
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""List workflow executions"""
|
||||
try:
|
||||
# Verify workflow ownership first
|
||||
service = WorkflowService(db)
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Get executions (implementation would query WorkflowExecution table)
|
||||
# For now, return empty list
|
||||
return []
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list executions: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
|
||||
async def get_execution_status(
|
||||
execution_id: str,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get execution status"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
execution = await service.get_execution_status(execution_id, current_user["sub"])
|
||||
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
return WorkflowExecutionResponse.from_orm(execution)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get execution: {str(e)}")
|
||||
|
||||
|
||||
# Workflow trigger endpoints
|
||||
@router.post("/{workflow_id}/triggers")
|
||||
async def create_workflow_trigger(
|
||||
workflow_id: str,
|
||||
trigger_data: WorkflowTriggerRequest,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Create a trigger for a workflow"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Verify workflow ownership
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
trigger = await service.create_workflow_trigger(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
tenant_id=current_user["tenant_id"],
|
||||
trigger_config=trigger_data.dict()
|
||||
)
|
||||
|
||||
return {"id": trigger.id, "webhook_url": trigger.webhook_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create trigger: {str(e)}")
|
||||
|
||||
|
||||
# Chat interface endpoints
|
||||
@router.post("/{workflow_id}/chat/sessions")
|
||||
async def create_chat_session(
|
||||
workflow_id: str,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Create a chat session for workflow interaction"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Verify workflow ownership
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Check if workflow supports chat mode
|
||||
if "chat" not in workflow.interaction_modes:
|
||||
raise HTTPException(status_code=400, detail="Workflow does not support chat mode")
|
||||
|
||||
session = await service.create_chat_session(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
tenant_id=current_user["tenant_id"]
|
||||
)
|
||||
|
||||
return {"session_id": session.id, "expires_at": session.expires_at.isoformat()}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create chat session: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/chat/message")
|
||||
async def send_chat_message(
|
||||
workflow_id: str,
|
||||
message_data: ChatMessageRequest,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Send a message to workflow chat"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
|
||||
# Create or get session
|
||||
session_id = message_data.session_id
|
||||
if not session_id:
|
||||
session = await service.create_chat_session(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
tenant_id=current_user["tenant_id"]
|
||||
)
|
||||
session_id = session.id
|
||||
|
||||
# Add user message
|
||||
user_message = await service.add_chat_message(
|
||||
session_id=session_id,
|
||||
user_id=current_user["sub"],
|
||||
role="user",
|
||||
content=message_data.message
|
||||
)
|
||||
|
||||
# Execute workflow with message as input
|
||||
execution = await service.execute_workflow(
|
||||
workflow_id=workflow_id,
|
||||
user_id=current_user["sub"],
|
||||
input_data={"message": message_data.message},
|
||||
trigger_type="chat",
|
||||
interaction_mode="chat"
|
||||
)
|
||||
|
||||
# Add agent response (in real implementation, this would come from workflow execution)
|
||||
assistant_response = execution.output_data.get('result', 'Workflow response')
|
||||
|
||||
assistant_message = await service.add_chat_message(
|
||||
session_id=session_id,
|
||||
user_id=current_user["sub"],
|
||||
role="agent",
|
||||
content=assistant_response,
|
||||
execution_id=execution.id
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"user_message": {
|
||||
"id": user_message.id,
|
||||
"content": user_message.content,
|
||||
"timestamp": user_message.created_at.isoformat()
|
||||
},
|
||||
"assistant_message": {
|
||||
"id": assistant_message.id,
|
||||
"content": assistant_message.content,
|
||||
"timestamp": assistant_message.created_at.isoformat()
|
||||
},
|
||||
"execution": {
|
||||
"id": execution.id,
|
||||
"status": execution.status
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to process chat message: {str(e)}")
|
||||
|
||||
|
||||
# Workflow interface generation endpoints
|
||||
@router.get("/{workflow_id}/interface/{mode}")
|
||||
async def get_workflow_interface(
|
||||
workflow_id: str,
|
||||
mode: InteractionMode,
|
||||
db = Depends(get_db),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""Get workflow interface configuration for specified mode"""
|
||||
try:
|
||||
service = WorkflowService(db)
|
||||
workflow = await service.get_workflow(workflow_id, current_user["sub"])
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
if mode not in workflow.interaction_modes:
|
||||
raise HTTPException(status_code=400, detail=f"Workflow does not support {mode} mode")
|
||||
|
||||
# Generate interface configuration based on mode
|
||||
from app.models.workflow import INTERACTION_MODE_CONFIGS
|
||||
|
||||
interface_config = INTERACTION_MODE_CONFIGS.get(mode, {})
|
||||
|
||||
# Customize based on workflow definition
|
||||
if mode == "form":
|
||||
# Generate form fields from workflow inputs
|
||||
trigger_nodes = [n for n in workflow.definition.get('nodes', []) if n.get('type') == 'trigger']
|
||||
form_fields = []
|
||||
|
||||
for node in trigger_nodes:
|
||||
if node.get('data', {}).get('input_schema'):
|
||||
form_fields.extend(node['data']['input_schema'])
|
||||
|
||||
interface_config['form_fields'] = form_fields
|
||||
|
||||
elif mode == "button":
|
||||
# Simple button configuration
|
||||
interface_config['button_text'] = f"Execute {workflow.name}"
|
||||
interface_config['description'] = workflow.description
|
||||
|
||||
elif mode == "dashboard":
|
||||
# Dashboard metrics configuration
|
||||
interface_config['metrics'] = {
|
||||
'execution_count': workflow.execution_count,
|
||||
'total_cost': workflow.total_cost_cents / 100,
|
||||
'avg_execution_time': workflow.average_execution_time_ms,
|
||||
'status': workflow.status
|
||||
}
|
||||
|
||||
return {
|
||||
'workflow_id': workflow_id,
|
||||
'mode': mode,
|
||||
'config': interface_config,
|
||||
'workflow': {
|
||||
'name': workflow.name,
|
||||
'description': workflow.description,
|
||||
'status': workflow.status
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get interface: {str(e)}")
|
||||
|
||||
|
||||
# Node type information endpoints
|
||||
@router.get("/node-types")
|
||||
async def get_workflow_node_types():
|
||||
"""Get available workflow node types and their configurations"""
|
||||
from app.models.workflow import WORKFLOW_NODE_TYPES
|
||||
return WORKFLOW_NODE_TYPES
|
||||
|
||||
|
||||
@router.get("/interaction-modes")
|
||||
async def get_interaction_modes():
|
||||
"""Get available interaction modes and their configurations"""
|
||||
from app.models.workflow import INTERACTION_MODE_CONFIGS
|
||||
return INTERACTION_MODE_CONFIGS
|
||||
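A minimal client-side sketch of how the read-only workflow endpoints above might be called. The /api/v1/workflows mount path, the workflow id, and the bearer token are illustrative assumptions, not values defined in this diff.

import asyncio
import httpx

async def inspect_workflow_interfaces():
    headers = {"Authorization": "Bearer <capability-token>"}  # placeholder token
    async with httpx.AsyncClient(base_url="http://localhost:8000/api/v1/workflows") as client:
        node_types = await client.get("/node-types", headers=headers)
        modes = await client.get("/interaction-modes", headers=headers)
        # Interface configuration for a specific workflow in "form" mode
        interface = await client.get("/wf-123/interface/form", headers=headers)
        print(node_types.json(), modes.json(), interface.json())

asyncio.run(inspect_workflow_interfaces())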
395
apps/tenant-backend/app/api/websocket.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
WebSocket API endpoints for GT 2.0 Tenant Backend
|
||||
|
||||
Provides secure WebSocket connections for real-time chat with:
|
||||
- JWT authentication
|
||||
- Perfect tenant isolation
|
||||
- Conversation-based messaging
|
||||
- Automatic cleanup on disconnect
|
||||
|
||||
GT 2.0 Security Features:
|
||||
- Token-based authentication
|
||||
- Rate limiting per user
|
||||
- Connection limits per tenant
|
||||
- Automatic session cleanup
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db_session
|
||||
from app.core.security import get_current_user_email, get_tenant_info
|
||||
from app.websocket import (
|
||||
websocket_manager,
|
||||
get_websocket_manager,
|
||||
authenticate_websocket_connection,
|
||||
WebSocketManager
|
||||
)
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
@router.websocket("/chat/{conversation_id}")
|
||||
async def websocket_chat_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str,
|
||||
token: str = Query(..., description="JWT authentication token")
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time chat in a specific conversation.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
conversation_id: Conversation ID to join
|
||||
token: JWT authentication token
|
||||
"""
|
||||
connection_id = None
|
||||
|
||||
try:
|
||||
# Authenticate connection
|
||||
try:
|
||||
user_id, tenant_id = await authenticate_websocket_connection(token)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=1008, reason=f"Authentication failed: {e}")
|
||||
return
|
||||
|
||||
# Verify conversation access
|
||||
async with get_db_session() as db:
|
||||
conversation_service = ConversationService(db)
|
||||
conversation = await conversation_service.get_conversation(conversation_id, user_id)
|
||||
|
||||
if not conversation:
|
||||
await websocket.close(code=1008, reason="Conversation not found or access denied")
|
||||
return
|
||||
|
||||
# Establish WebSocket connection
|
||||
manager = get_websocket_manager()
|
||||
connection_id = await manager.connect(
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
logger.info(f"WebSocket chat connection established: {connection_id} for conversation {conversation_id}")
|
||||
|
||||
# Message handling loop
|
||||
while True:
|
||||
try:
|
||||
# Receive message
|
||||
data = await websocket.receive_text()
|
||||
message_data = json.loads(data)
|
||||
|
||||
# Handle message
|
||||
success = await manager.handle_message(connection_id, message_data)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to handle message from {connection_id}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected: {connection_id}")
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON received from {connection_id}")
|
||||
await manager.send_to_connection(connection_id, {
|
||||
"type": "error",
|
||||
"message": "Invalid JSON format",
|
||||
"code": "INVALID_JSON"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket message loop: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket chat endpoint: {e}")
|
||||
try:
|
||||
await websocket.close(code=1011, reason=f"Server error: {e}")
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Cleanup connection
|
||||
if connection_id:
|
||||
await manager.disconnect(connection_id, reason="Connection closed")
|
||||
|
||||
|
||||
@router.websocket("/general")
|
||||
async def websocket_general_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="JWT authentication token")
|
||||
):
|
||||
"""
|
||||
General WebSocket endpoint for notifications and system messages.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
token: JWT authentication token
|
||||
"""
|
||||
connection_id = None
|
||||
|
||||
try:
|
||||
# Authenticate connection
|
||||
try:
|
||||
user_id, tenant_id = await authenticate_websocket_connection(token)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=1008, reason=f"Authentication failed: {e}")
|
||||
return
|
||||
|
||||
# Establish WebSocket connection (no specific conversation)
|
||||
manager = get_websocket_manager()
|
||||
connection_id = await manager.connect(
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None
|
||||
)
|
||||
|
||||
logger.info(f"General WebSocket connection established: {connection_id}")
|
||||
|
||||
# Message handling loop
|
||||
while True:
|
||||
try:
|
||||
# Receive message
|
||||
data = await websocket.receive_text()
|
||||
message_data = json.loads(data)
|
||||
|
||||
# Handle message
|
||||
success = await manager.handle_message(connection_id, message_data)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to handle message from {connection_id}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"General WebSocket disconnected: {connection_id}")
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON received from {connection_id}")
|
||||
await manager.send_to_connection(connection_id, {
|
||||
"type": "error",
|
||||
"message": "Invalid JSON format",
|
||||
"code": "INVALID_JSON"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in general WebSocket message loop: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in general WebSocket endpoint: {e}")
|
||||
try:
|
||||
await websocket.close(code=1011, reason=f"Server error: {e}")
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Cleanup connection
|
||||
if connection_id:
|
||||
await manager.disconnect(connection_id, reason="Connection closed")
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_websocket_stats(
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info)
|
||||
):
|
||||
"""Get WebSocket connection statistics for tenant"""
|
||||
try:
|
||||
manager = get_websocket_manager()
|
||||
stats = manager.get_connection_stats()
|
||||
|
||||
# Filter stats for current tenant
|
||||
tenant_id = tenant_info["tenant_id"]
|
||||
tenant_stats = {
|
||||
"total_connections": stats["connections_by_tenant"].get(tenant_id, 0),
|
||||
"active_conversations": len([
|
||||
conv_id for conv_id, connections in manager.conversation_connections.items()
|
||||
if any(
|
||||
getattr(manager.connections.get(conn_id), "tenant_id", None) == tenant_id
|
||||
for conn_id in connections
|
||||
)
|
||||
]),
|
||||
"user_connections": stats["connections_by_user"].get(current_user, 0)
|
||||
}
|
||||
|
||||
return JSONResponse(content=tenant_stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get WebSocket stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/broadcast/tenant")
|
||||
async def broadcast_to_tenant(
|
||||
message: dict,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Broadcast message to all connections in tenant.
|
||||
(Admin/system use only)
|
||||
"""
|
||||
try:
|
||||
# Check if user has admin permissions
|
||||
# TODO: Implement proper admin role checking
|
||||
|
||||
manager = get_websocket_manager()
|
||||
tenant_id = tenant_info["tenant_id"]
|
||||
|
||||
broadcast_message = {
|
||||
"type": "system_broadcast",
|
||||
"message": message.get("content", ""),
|
||||
"timestamp": message.get("timestamp"),
|
||||
"sender": "system"
|
||||
}
|
||||
|
||||
await manager.broadcast_to_tenant(tenant_id, broadcast_message)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Broadcast sent successfully",
|
||||
"recipients": len(manager.tenant_connections.get(tenant_id, set()))
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to broadcast to tenant: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/broadcast")
|
||||
async def broadcast_to_conversation(
|
||||
conversation_id: str,
|
||||
message: dict,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""
|
||||
Broadcast message to all participants in a conversation.
|
||||
"""
|
||||
try:
|
||||
# Verify user has access to conversation
|
||||
conversation_service = ConversationService(db)
|
||||
conversation = await conversation_service.get_conversation(conversation_id, current_user)
|
||||
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
manager = get_websocket_manager()
|
||||
|
||||
broadcast_message = {
|
||||
"type": "conversation_broadcast",
|
||||
"conversation_id": conversation_id,
|
||||
"message": message.get("content", ""),
|
||||
"timestamp": message.get("timestamp"),
|
||||
"sender": current_user
|
||||
}
|
||||
|
||||
await manager.broadcast_to_conversation(conversation_id, broadcast_message)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Broadcast sent successfully",
|
||||
"conversation_id": conversation_id,
|
||||
"recipients": len(manager.conversation_connections.get(conversation_id, set()))
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to broadcast to conversation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/connections/{connection_id}/send")
|
||||
async def send_to_connection(
|
||||
connection_id: str,
|
||||
message: dict,
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info)
|
||||
):
|
||||
"""
|
||||
Send message to specific connection.
|
||||
(Admin/system use only)
|
||||
"""
|
||||
try:
|
||||
# Check if user has admin permissions or owns the connection
|
||||
manager = get_websocket_manager()
|
||||
connection = manager.connections.get(connection_id)
|
||||
|
||||
if not connection:
|
||||
raise HTTPException(status_code=404, detail="Connection not found")
|
||||
|
||||
# Verify tenant access
|
||||
if connection.tenant_id != tenant_info["tenant_id"]:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# TODO: Add admin role check or connection ownership verification
|
||||
|
||||
target_message = {
|
||||
"type": "direct_message",
|
||||
"message": message.get("content", ""),
|
||||
"timestamp": message.get("timestamp"),
|
||||
"sender": current_user
|
||||
}
|
||||
|
||||
success = await manager.send_to_connection(connection_id, target_message)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=410, detail="Connection no longer active")
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Message sent successfully",
|
||||
"connection_id": connection_id
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to connection: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/connections/{connection_id}")
|
||||
async def disconnect_connection(
|
||||
connection_id: str,
|
||||
reason: str = Query("Admin disconnect", description="Disconnect reason"),
|
||||
current_user: str = Depends(get_current_user_email),
|
||||
tenant_info: dict = Depends(get_tenant_info)
|
||||
):
|
||||
"""
|
||||
Forcefully disconnect a WebSocket connection.
|
||||
(Admin use only)
|
||||
"""
|
||||
try:
|
||||
# Check if user has admin permissions
|
||||
# TODO: Implement proper admin role checking
|
||||
|
||||
manager = get_websocket_manager()
|
||||
connection = manager.connections.get(connection_id)
|
||||
|
||||
if not connection:
|
||||
raise HTTPException(status_code=404, detail="Connection not found")
|
||||
|
||||
# Verify tenant access
|
||||
if connection.tenant_id != tenant_info["tenant_id"]:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
await manager.disconnect(connection_id, code=1008, reason=reason)
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "Connection disconnected successfully",
|
||||
"connection_id": connection_id,
|
||||
"reason": reason
|
||||
})
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disconnect connection: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
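A hedged client sketch for the chat WebSocket endpoint above. The host, the assumed /api/v1/ws mount path, the conversation id, and the JWT are placeholders; the payload is simply JSON that the connection manager's handle_message accepts.

import asyncio
import json
import websockets  # third-party client library, assumed available

async def chat_client():
    uri = "ws://localhost:8000/api/v1/ws/chat/conv-123?token=<jwt>"  # placeholder URL and token
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"type": "message", "content": "Hello from the client"}))
        print(json.loads(await ws.recv()))  # e.g. an echo, broadcast, or error frame

asyncio.run(chat_client())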
131
apps/tenant-backend/app/core/api_standards.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend - CB-REST API Standards Integration
|
||||
|
||||
This module integrates the CB-REST standards into the Tenant backend
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the api-standards package to the path
|
||||
api_standards_path = Path(__file__).parent.parent.parent.parent.parent / "packages" / "api-standards" / "src"
|
||||
if api_standards_path.exists():
|
||||
sys.path.insert(0, str(api_standards_path))
|
||||
|
||||
# Import CB-REST standards
|
||||
try:
|
||||
from response import StandardResponse, format_response, format_error
|
||||
from capability import (
|
||||
init_capability_verifier,
|
||||
verify_capability,
|
||||
require_capability,
|
||||
Capability,
|
||||
CapabilityToken
|
||||
)
|
||||
from errors import ErrorCode, APIError, raise_api_error
|
||||
from middleware import (
|
||||
RequestCorrelationMiddleware,
|
||||
CapabilityMiddleware,
|
||||
TenantIsolationMiddleware,
|
||||
RateLimitMiddleware
|
||||
)
|
||||
except ImportError as e:
|
||||
# Fallback for development - create minimal implementations
|
||||
print(f"Warning: Could not import api-standards package: {e}")
|
||||
|
||||
# Create minimal implementations for development
|
||||
class StandardResponse:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def format_response(data, capability_used, request_id=None):
|
||||
return {
|
||||
"data": data,
|
||||
"error": None,
|
||||
"capability_used": capability_used,
|
||||
"request_id": request_id or "dev-mode"
|
||||
}
|
||||
|
||||
def format_error(code, message, capability_used="none", **kwargs):
|
||||
return {
|
||||
"data": None,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
**kwargs
|
||||
},
|
||||
"capability_used": capability_used,
|
||||
"request_id": kwargs.get("request_id", "dev-mode")
|
||||
}
|
||||
|
||||
class ErrorCode:
|
||||
CAPABILITY_INSUFFICIENT = "CAPABILITY_INSUFFICIENT"
|
||||
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"
|
||||
INVALID_REQUEST = "INVALID_REQUEST"
|
||||
SYSTEM_ERROR = "SYSTEM_ERROR"
|
||||
TENANT_ISOLATION_VIOLATION = "TENANT_ISOLATION_VIOLATION"
|
||||
|
||||
class APIError(Exception):
|
||||
def __init__(self, code, message, **kwargs):
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.kwargs = kwargs
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# Export all CB-REST components
|
||||
__all__ = [
|
||||
'StandardResponse',
|
||||
'format_response',
|
||||
'format_error',
|
||||
'init_capability_verifier',
|
||||
'verify_capability',
|
||||
'require_capability',
|
||||
'Capability',
|
||||
'CapabilityToken',
|
||||
'ErrorCode',
|
||||
'APIError',
|
||||
'raise_api_error',
|
||||
'RequestCorrelationMiddleware',
|
||||
'CapabilityMiddleware',
|
||||
'TenantIsolationMiddleware',
|
||||
'RateLimitMiddleware'
|
||||
]
|
||||
|
||||
|
||||
def setup_api_standards(app, secret_key: str, tenant_id: str):
|
||||
"""
|
||||
Setup CB-REST API standards for the tenant application
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
secret_key: Secret key for JWT signing
|
||||
tenant_id: Tenant identifier for isolation
|
||||
"""
|
||||
# Initialize capability verifier
|
||||
if 'init_capability_verifier' in globals():
|
||||
init_capability_verifier(secret_key)
|
||||
|
||||
# Add middleware in correct order
|
||||
if 'RequestCorrelationMiddleware' in globals():
|
||||
app.add_middleware(RequestCorrelationMiddleware)
|
||||
|
||||
if 'RateLimitMiddleware' in globals():
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
requests_per_minute=100 # Per-tenant rate limiting
|
||||
)
|
||||
|
||||
if 'TenantIsolationMiddleware' in globals():
|
||||
app.add_middleware(
|
||||
TenantIsolationMiddleware,
|
||||
tenant_id=tenant_id,
|
||||
enforce_isolation=True
|
||||
)
|
||||
|
||||
if 'CapabilityMiddleware' in globals():
|
||||
app.add_middleware(
|
||||
CapabilityMiddleware,
|
||||
exclude_paths=["/health", "/ready", "/metrics", "/api/v1/auth/login"]
|
||||
)
|
||||
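A minimal sketch of wiring these standards into a tenant FastAPI app; the secret key and tenant id shown are placeholders for values normally taken from Settings.

from fastapi import FastAPI
from app.core.api_standards import setup_api_standards, format_response

app = FastAPI()
setup_api_standards(app, secret_key="<tenant-secret>", tenant_id="tenant-example")

@app.get("/api/v1/ping")
async def ping():
    # Responses follow the CB-REST envelope: data / error / capability_used / request_id
    return format_response({"status": "ok"}, capability_used="none")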
162
apps/tenant-backend/app/core/asgi_router.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Composite ASGI Router for GT 2.0 Tenant Backend
|
||||
|
||||
Handles routing between FastAPI and Socket.IO applications to prevent
|
||||
ASGI protocol conflicts while maintaining both WebSocket systems.
|
||||
|
||||
Architecture:
|
||||
- `/socket.io/*` → Socket.IO ASGIApp (agentic real-time features)
|
||||
- All other paths → FastAPI app (REST API, native WebSocket)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Callable, Awaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompositeASGIRouter:
|
||||
"""
|
||||
ASGI router that handles both FastAPI and Socket.IO applications
|
||||
without protocol conflicts.
|
||||
"""
|
||||
|
||||
def __init__(self, fastapi_app, socketio_app):
|
||||
"""
|
||||
Initialize composite router with both applications.
|
||||
|
||||
Args:
|
||||
fastapi_app: FastAPI application instance
|
||||
socketio_app: Socket.IO ASGIApp instance
|
||||
"""
|
||||
self.fastapi_app = fastapi_app
|
||||
self.socketio_app = socketio_app
|
||||
logger.info("Composite ASGI router initialized for FastAPI + Socket.IO")
|
||||
|
||||
async def __call__(self, scope: Dict[str, Any], receive: Callable, send: Callable) -> None:
|
||||
"""
|
||||
ASGI application entry point that routes requests based on path.
|
||||
|
||||
Args:
|
||||
scope: ASGI scope containing request information
|
||||
receive: ASGI receive callable
|
||||
send: ASGI send callable
|
||||
"""
|
||||
try:
|
||||
# Extract path from scope
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Route based on path pattern
|
||||
if self._is_socketio_path(path):
|
||||
# Only log Socket.IO routing at DEBUG level for non-operational paths
|
||||
if self._should_log_route(path):
|
||||
logger.debug(f"Routing to Socket.IO: {path}")
|
||||
await self.socketio_app(scope, receive, send)
|
||||
else:
|
||||
# Only log FastAPI routing at DEBUG level for non-operational paths
|
||||
if self._should_log_route(path):
|
||||
logger.debug(f"Routing to FastAPI: {path}")
|
||||
await self.fastapi_app(scope, receive, send)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ASGI routing: {e}")
|
||||
# Fallback to FastAPI for error handling
|
||||
try:
|
||||
await self.fastapi_app(scope, receive, send)
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback routing also failed: {fallback_error}")
|
||||
# Last resort: send basic error response
|
||||
await self._send_error_response(scope, send)
|
||||
|
||||
def _is_socketio_path(self, path: str) -> bool:
|
||||
"""
|
||||
Determine if path should be routed to Socket.IO.
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
True if path should go to Socket.IO, False for FastAPI
|
||||
"""
|
||||
socketio_patterns = [
|
||||
"/socket.io/",
|
||||
"/socket.io"
|
||||
]
|
||||
|
||||
# Check if path starts with any Socket.IO pattern
|
||||
for pattern in socketio_patterns:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _should_log_route(self, path: str) -> bool:
|
||||
"""
|
||||
Determine if this path should be logged during routing.
|
||||
|
||||
Operational endpoints like health checks and metrics are excluded
|
||||
to reduce log noise during normal operation.
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
True if path should be logged, False for operational endpoints
|
||||
"""
|
||||
operational_endpoints = [
|
||||
"/health",
|
||||
"/ready",
|
||||
"/metrics",
|
||||
"/api/v1/health"
|
||||
]
|
||||
|
||||
# Don't log operational endpoints
|
||||
if any(path.startswith(endpoint) for endpoint in operational_endpoints):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _send_error_response(self, scope: Dict[str, Any], send: Callable) -> None:
|
||||
"""
|
||||
Send basic error response when both applications fail.
|
||||
|
||||
Args:
|
||||
scope: ASGI scope
|
||||
send: ASGI send callable
|
||||
"""
|
||||
try:
|
||||
if scope["type"] == "http":
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
[b"content-type", b"application/json"],
|
||||
[b"content-length", b"32"]  # must match the byte length of the JSON body sent below
|
||||
]
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b'{"error": "ASGI routing failed"}'
|
||||
})
|
||||
elif scope["type"] == "websocket":
|
||||
await send({
|
||||
"type": "websocket.close",
|
||||
"code": 1011,
|
||||
"reason": "ASGI routing failed"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send error response: {e}")
|
||||
|
||||
|
||||
def create_composite_asgi_app(fastapi_app, socketio_app):
|
||||
"""
|
||||
Factory function to create composite ASGI application.
|
||||
|
||||
Args:
|
||||
fastapi_app: FastAPI application instance
|
||||
socketio_app: Socket.IO ASGIApp instance
|
||||
|
||||
Returns:
|
||||
CompositeASGIRouter instance
|
||||
"""
|
||||
return CompositeASGIRouter(fastapi_app, socketio_app)
|
||||
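A sketch of composing the two applications behind this router; the FastAPI and Socket.IO instances and the uvicorn invocation are illustrative, not the project's actual entrypoint.

import socketio
import uvicorn
from fastapi import FastAPI
from app.core.asgi_router import create_composite_asgi_app

fastapi_app = FastAPI()
sio = socketio.AsyncServer(async_mode="asgi")
socketio_app = socketio.ASGIApp(sio)

# /socket.io/* goes to Socket.IO, everything else (REST + native WebSocket) to FastAPI
app = create_composite_asgi_app(fastapi_app, socketio_app)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)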
202
apps/tenant-backend/app/core/cache.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Simple in-memory cache with TTL support for Gen Two performance optimization.
|
||||
|
||||
This module provides a thread-safe caching layer for expensive database queries
|
||||
and API calls. Each Uvicorn worker maintains its own cache instance.
|
||||
|
||||
Key features:
|
||||
- TTL-based expiration (configurable per-key)
|
||||
- LRU eviction when cache reaches max size
|
||||
- Thread-safe for concurrent request handling
|
||||
- Pattern-based deletion for cache invalidation
|
||||
|
||||
Usage:
|
||||
from app.core.cache import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
|
||||
# Get cached value with 60-second TTL
|
||||
cached_data = cache.get("agents_minimal_user123", ttl=60)
|
||||
if not cached_data:
|
||||
data = await fetch_from_db()
|
||||
cache.set("agents_minimal_user123", data)
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Dict, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleCache:
|
||||
"""
|
||||
Thread-safe TTL cache for API responses and database query results.
|
||||
|
||||
This cache is per-worker (each Uvicorn worker maintains separate cache).
|
||||
Cache keys should include tenant_domain or user_id for proper isolation.
|
||||
|
||||
Attributes:
|
||||
max_entries: Maximum number of cache entries before LRU eviction
|
||||
_cache: Internal cache storage (key -> (timestamp, data))
|
||||
_lock: Thread lock for safe concurrent access
|
||||
"""
|
||||
|
||||
def __init__(self, max_entries: int = 1000):
|
||||
"""
|
||||
Initialize cache with maximum entry limit.
|
||||
|
||||
Args:
|
||||
max_entries: Maximum cache entries (default 1000)
|
||||
Typical: 200KB per agent list × 1000 = 200MB per worker
|
||||
"""
|
||||
self._cache: Dict[str, Tuple[datetime, Any]] = {}
|
||||
self._lock = Lock()
|
||||
self._max_entries = max_entries
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
logger.info(f"SimpleCache initialized with max_entries={max_entries}")
|
||||
|
||||
def get(self, key: str, ttl: int = 60) -> Optional[Any]:
|
||||
"""
|
||||
Get cached value if not expired.
|
||||
|
||||
Args:
|
||||
key: Cache key (should include tenant/user for isolation)
|
||||
ttl: Time-to-live in seconds (default 60)
|
||||
|
||||
Returns:
|
||||
Cached data if found and not expired, None otherwise
|
||||
|
||||
Example:
|
||||
data = cache.get("agents_minimal_user123", ttl=60)
|
||||
if data is None:
|
||||
# Cache miss - fetch from database
|
||||
data = await fetch_from_db()
|
||||
cache.set("agents_minimal_user123", data)
|
||||
"""
|
||||
with self._lock:
|
||||
if key not in self._cache:
|
||||
self._misses += 1
|
||||
logger.debug(f"Cache miss: {key}")
|
||||
return None
|
||||
|
||||
timestamp, data = self._cache[key]
|
||||
age = (datetime.utcnow() - timestamp).total_seconds()
|
||||
|
||||
if age > ttl:
|
||||
# Expired - remove and return None
|
||||
del self._cache[key]
|
||||
self._misses += 1
|
||||
logger.debug(f"Cache expired: {key} (age={age:.1f}s, ttl={ttl}s)")
|
||||
return None
|
||||
|
||||
self._hits += 1
|
||||
logger.debug(f"Cache hit: {key} (age={age:.1f}s, ttl={ttl}s)")
|
||||
return data
|
||||
|
||||
def set(self, key: str, data: Any) -> None:
|
||||
"""
|
||||
Set cache value with current timestamp.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
data: Data to cache (should be JSON-serializable)
|
||||
|
||||
Note:
|
||||
If cache is full, the entry with the oldest set-time is evicted
|
||||
"""
|
||||
with self._lock:
|
||||
# Evict the entry with the oldest set-time if the cache is full
|
||||
if len(self._cache) >= self._max_entries:
|
||||
oldest_key = min(self._cache.items(), key=lambda x: x[1][0])[0]
|
||||
del self._cache[oldest_key]
|
||||
logger.warning(
|
||||
f"Cache full ({self._max_entries} entries), "
|
||||
f"evicted oldest key: {oldest_key}"
|
||||
)
|
||||
|
||||
self._cache[key] = (datetime.utcnow(), data)
|
||||
logger.debug(f"Cache set: {key} (total entries: {len(self._cache)})")
|
||||
|
||||
def delete(self, pattern: str) -> int:
|
||||
"""
|
||||
Delete all keys matching pattern (prefix match).
|
||||
|
||||
Args:
|
||||
pattern: Key prefix to match (e.g., "agents_minimal_")
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
|
||||
Example:
|
||||
# Delete all agent cache entries for a user
|
||||
count = cache.delete(f"agents_minimal_{user_id}")
|
||||
count += cache.delete(f"agents_summary_{user_id}")
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_delete = [k for k in self._cache.keys() if k.startswith(pattern)]
|
||||
for k in keys_to_delete:
|
||||
del self._cache[k]
|
||||
|
||||
if keys_to_delete:
|
||||
logger.info(f"Cache invalidated {len(keys_to_delete)} entries matching '{pattern}'")
|
||||
|
||||
return len(keys_to_delete)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear entire cache (use with caution)."""
|
||||
with self._lock:
|
||||
entry_count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
logger.warning(f"Cache cleared (removed {entry_count} entries)")
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get number of cached entries."""
|
||||
return len(self._cache)
|
||||
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dict with size, hits, misses, hit_rate
|
||||
"""
|
||||
total_requests = self._hits + self._misses
|
||||
hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0
|
||||
|
||||
return {
|
||||
"size": len(self._cache),
|
||||
"max_entries": self._max_entries,
|
||||
"hits": self._hits,
|
||||
"misses": self._misses,
|
||||
"hit_rate_percent": round(hit_rate, 2),
|
||||
}
|
||||
|
||||
|
||||
# Singleton cache instance per worker
|
||||
_cache: Optional[SimpleCache] = None
|
||||
|
||||
|
||||
def get_cache() -> SimpleCache:
|
||||
"""
|
||||
Get or create singleton cache instance.
|
||||
|
||||
Each Uvicorn worker creates its own cache instance (isolated per-process).
|
||||
|
||||
Returns:
|
||||
SimpleCache instance
|
||||
"""
|
||||
global _cache
|
||||
if _cache is None:
|
||||
_cache = SimpleCache(max_entries=1000)
|
||||
return _cache
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
"""Clear global cache (for testing or emergency use)."""
|
||||
cache = get_cache()
|
||||
cache.clear()
|
||||
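A cache-aside sketch using the helpers above: read through the cache, then invalidate by key prefix after a write. The fetch callable and user id are placeholders.

from app.core.cache import get_cache

async def list_agents_cached(user_id: str, fetch_agents_from_db):
    cache = get_cache()
    key = f"agents_minimal_{user_id}"
    agents = cache.get(key, ttl=60)
    if agents is None:
        agents = await fetch_agents_from_db(user_id)  # cache miss: hit the database
        cache.set(key, agents)
    return agents

def invalidate_agent_cache(user_id: str) -> int:
    # Drop every cached agent view for this user after a create/update/delete
    return get_cache().delete(f"agents_minimal_{user_id}")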
380
apps/tenant-backend/app/core/capability_client.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend - Capability Client
|
||||
Generate JWT capability tokens for Resource Cluster API calls
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from jose import jwt
|
||||
from app.core.config import get_settings
|
||||
import logging
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class CapabilityClient:
|
||||
"""Generates capability-based JWT tokens for Resource Cluster access"""
|
||||
|
||||
def __init__(self):
|
||||
# Use tenant-specific secret key for token signing
|
||||
self.secret_key = settings.secret_key
|
||||
self.algorithm = "HS256"
|
||||
self.issuer = f"gt2-tenant-{settings.tenant_id}"
|
||||
self.http_client = httpx.AsyncClient(timeout=10.0)
|
||||
self.control_panel_url = settings.control_panel_url
|
||||
|
||||
async def generate_capability_token(
|
||||
self,
|
||||
user_email: str,
|
||||
tenant_id: str,
|
||||
resources: List[str],
|
||||
expires_hours: int = 24,
|
||||
additional_claims: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a JWT capability token for Resource Cluster API access.
|
||||
|
||||
Args:
|
||||
user_email: Email of the user making the request
|
||||
tenant_id: Tenant identifier
|
||||
resources: List of resource capabilities (e.g., ['external_services', 'rag_processing'])
|
||||
expires_hours: Token expiration time in hours
|
||||
additional_claims: Additional JWT claims to include
|
||||
|
||||
Returns:
|
||||
Signed JWT token string
|
||||
"""
|
||||
|
||||
now = datetime.utcnow()
|
||||
expiry = now + timedelta(hours=expires_hours)
|
||||
|
||||
# Build capability token payload
|
||||
payload = {
|
||||
# Standard JWT claims
|
||||
"iss": self.issuer, # Issuer
|
||||
"sub": user_email, # Subject (user)
|
||||
"aud": "gt2-resource-cluster", # Audience
|
||||
"iat": int(now.timestamp()), # Issued at
|
||||
"exp": int(expiry.timestamp()), # Expiration
|
||||
"nbf": int(now.timestamp()), # Not before
|
||||
"jti": f"{tenant_id}-{user_email}-{int(now.timestamp())}", # JWT ID
|
||||
|
||||
# GT 2.0 specific claims
|
||||
"tenant_id": tenant_id,
|
||||
"user_email": user_email,
|
||||
"user_type": "tenant_user",
|
||||
|
||||
# Capability grants
|
||||
"capabilities": await self._build_capabilities(resources, tenant_id, expiry),
|
||||
|
||||
# Security metadata
|
||||
"capability_hash": self._generate_capability_hash(resources, tenant_id),
|
||||
"token_version": "2.0",
|
||||
"security_level": "standard"
|
||||
}
|
||||
|
||||
# Add any additional claims
|
||||
if additional_claims:
|
||||
payload.update(additional_claims)
|
||||
|
||||
# Sign the token
|
||||
try:
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
self.secret_key,
|
||||
algorithm=self.algorithm
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated capability token for {user_email} with resources: {resources}"
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate capability token: {e}")
|
||||
raise RuntimeError(f"Token generation failed: {e}")
|
||||
|
||||
async def _build_capabilities(
|
||||
self,
|
||||
resources: List[str],
|
||||
tenant_id: str,
|
||||
expiry: datetime
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Build capability grants for resources with constraints from Control Panel.
|
||||
|
||||
For LLM resources, fetches real rate limits from Control Panel API.
|
||||
For other resources, uses default constraints.
|
||||
"""
|
||||
capabilities = []
|
||||
|
||||
for resource in resources:
|
||||
capability = {
|
||||
"resource": resource,
|
||||
"actions": self._get_default_actions(resource),
|
||||
"constraints": await self._get_constraints_for_resource(resource, tenant_id),
|
||||
"valid_until": expiry.isoformat()
|
||||
}
|
||||
capabilities.append(capability)
|
||||
|
||||
return capabilities
|
||||
|
||||
async def _get_constraints_for_resource(
|
||||
self,
|
||||
resource: str,
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get constraints for a resource, fetching from Control Panel for LLM resources.
|
||||
|
||||
GT 2.0 Principle: Single source of truth in database.
|
||||
Fails fast if Control Panel is unreachable for LLM resources.
|
||||
"""
|
||||
# For LLM resources, fetch real config from Control Panel
|
||||
if resource in ["llm", "llm_inference"]:
|
||||
# Note: We don't have model_id at this point in the flow
|
||||
# This is called during general capability token generation
|
||||
# For now, return default constraints that will be overridden
|
||||
# when model-specific tokens are generated
|
||||
return self._get_default_constraints(resource)
|
||||
|
||||
# For non-LLM resources, use defaults
|
||||
return self._get_default_constraints(resource)
|
||||
|
||||
async def _fetch_tenant_model_config(
|
||||
self,
|
||||
tenant_id: str,
|
||||
model_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch tenant model configuration from Control Panel API.
|
||||
|
||||
Returns rate limits from database (single source of truth).
|
||||
Fails fast if Control Panel is unreachable (no fallbacks).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
Model config with rate_limits, or None if not found
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Control Panel API is unreachable (fail fast)
|
||||
"""
|
||||
try:
|
||||
url = f"{self.control_panel_url}/api/v1/tenant-models/tenants/{tenant_id}/models/{model_id}"
|
||||
|
||||
logger.debug(f"Fetching model config from Control Panel: {url}")
|
||||
|
||||
response = await self.http_client.get(url)
|
||||
|
||||
if response.status_code == 404:
|
||||
logger.warning(f"Model {model_id} not configured for tenant {tenant_id}")
|
||||
return None
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
config = response.json()
|
||||
logger.info(f"Fetched model config for {model_id}: rate_limits={config.get('rate_limits')}")
|
||||
|
||||
return config
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Control Panel API error: {e.response.status_code}")
|
||||
raise RuntimeError(
|
||||
f"Failed to fetch model config from Control Panel: HTTP {e.response.status_code}"
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Control Panel API unreachable: {e}")
|
||||
raise RuntimeError(
|
||||
f"Control Panel API unreachable - cannot generate capability token. "
|
||||
f"Ensure Control Panel is running at {self.control_panel_url}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error fetching model config: {e}")
|
||||
raise RuntimeError(f"Failed to fetch model config: {e}")
|
||||
|
||||
def _get_default_actions(self, resource: str) -> List[str]:
|
||||
"""Get default actions for a resource type"""
|
||||
|
||||
action_mappings = {
|
||||
"external_services": ["create", "read", "update", "delete", "health_check", "sso_token"],
|
||||
"rag_processing": ["process_document", "generate_embeddings", "vector_search"],
|
||||
"llm_inference": ["chat_completion", "streaming", "function_calling"],
|
||||
"llm": ["execute"], # Use valid ActionType from resource cluster
|
||||
"agent_orchestration": ["execute", "status", "interrupt"],
|
||||
"ai_literacy": ["play_games", "solve_puzzles", "dialogue", "analytics"],
|
||||
"app_integrations": ["read", "write", "webhook"],
|
||||
"admin": ["all"],
|
||||
# MCP Server Resources
|
||||
"mcp:rag": ["search_datasets", "query_documents", "list_user_datasets", "get_dataset_info", "get_relevant_chunks"]
|
||||
}
|
||||
|
||||
return action_mappings.get(resource, ["read"])
|
||||
|
||||
def _get_default_constraints(self, resource: str) -> Dict[str, Any]:
|
||||
"""Get default constraints for a resource type"""
|
||||
|
||||
constraint_mappings = {
|
||||
"external_services": {
|
||||
"max_instances_per_user": 10,
|
||||
"max_cpu_per_instance": "2000m",
|
||||
"max_memory_per_instance": "4Gi",
|
||||
"max_storage_per_instance": "50Gi",
|
||||
"allowed_service_types": ["ctfd", "canvas", "guacamole"]
|
||||
},
|
||||
"rag_processing": {
|
||||
"max_document_size_mb": 100,
|
||||
"max_batch_size": 50,
|
||||
"max_requests_per_hour": 1000
|
||||
},
|
||||
"llm_inference": {
|
||||
"max_tokens_per_request": 4000,
|
||||
"max_requests_per_hour": 100,
|
||||
"allowed_models": [] # Models dynamically determined by admin backend
|
||||
},
|
||||
"llm": {
|
||||
"max_tokens_per_request": 4000,
|
||||
"max_requests_per_hour": 100,
|
||||
"allowed_models": [] # Models dynamically determined by admin backend
|
||||
},
|
||||
"agent_orchestration": {
|
||||
"max_concurrent_agents": 5,
|
||||
"max_execution_time_minutes": 30
|
||||
},
|
||||
"ai_literacy": {
|
||||
"max_sessions_per_day": 20,
|
||||
"max_session_duration_hours": 4
|
||||
},
|
||||
"app_integrations": {
|
||||
"max_api_calls_per_hour": 500,
|
||||
"allowed_domains": ["api.example.com"]
|
||||
},
|
||||
# MCP Server Resources
|
||||
"mcp:rag": {
|
||||
"max_requests_per_hour": 500,
|
||||
"max_results_per_query": 50
|
||||
}
|
||||
}
|
||||
|
||||
return constraint_mappings.get(resource, {})
|
||||
|
||||
def _generate_capability_hash(self, resources: List[str], tenant_id: str) -> str:
|
||||
"""Generate a hash of the capabilities for verification"""
|
||||
import hashlib
|
||||
|
||||
# Create a deterministic string from capabilities
|
||||
capability_string = f"{tenant_id}:{':'.join(sorted(resources))}"
|
||||
|
||||
# Hash with SHA-256
|
||||
hash_object = hashlib.sha256(capability_string.encode())
|
||||
return hash_object.hexdigest()[:16] # First 16 characters
|
||||
|
||||
async def verify_capability_token(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify and decode a capability token.
|
||||
|
||||
Args:
|
||||
token: JWT token to verify
|
||||
|
||||
Returns:
|
||||
Decoded token payload
|
||||
|
||||
Raises:
|
||||
ValueError: If token is invalid or expired
|
||||
"""
|
||||
|
||||
try:
|
||||
# Decode and verify the token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.secret_key,
|
||||
algorithms=[self.algorithm],
|
||||
audience="gt2-resource-cluster"
|
||||
)
|
||||
|
||||
# Additional validation
|
||||
if payload.get("iss") != self.issuer:
|
||||
raise ValueError("Invalid token issuer")
|
||||
|
||||
# Check if token is still valid
|
||||
now = datetime.utcnow()
|
||||
if payload.get("exp", 0) < now.timestamp():
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
if payload.get("nbf", 0) > now.timestamp():
|
||||
raise ValueError("Token not yet valid")
|
||||
|
||||
logger.debug(f"Verified capability token for user {payload.get('user_email')}")
|
||||
|
||||
return payload
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.JWTClaimsError as e:
|
||||
raise ValueError(f"Token claims validation failed: {e}")
|
||||
except jwt.JWTError as e:
|
||||
raise ValueError(f"Token validation failed: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Capability token verification failed: {e}")
|
||||
raise ValueError(f"Invalid token: {e}")
|
||||
|
||||
async def refresh_capability_token(
|
||||
self,
|
||||
current_token: str,
|
||||
extend_hours: int = 24
|
||||
) -> str:
|
||||
"""
|
||||
Refresh an existing capability token with extended expiration.
|
||||
|
||||
Args:
|
||||
current_token: Current JWT token
|
||||
extend_hours: Hours to extend from now
|
||||
|
||||
Returns:
|
||||
New JWT token with extended expiration
|
||||
"""
|
||||
|
||||
# Verify current token
|
||||
payload = await self.verify_capability_token(current_token)
|
||||
|
||||
# Extract current capabilities
|
||||
resources = [cap.get("resource") for cap in payload.get("capabilities", [])]
|
||||
|
||||
# Generate new token with extended expiration
|
||||
return await self.generate_capability_token(
|
||||
user_email=payload.get("user_email"),
|
||||
tenant_id=payload.get("tenant_id"),
|
||||
resources=resources,
|
||||
expires_hours=extend_hours
|
||||
)
|
||||
|
||||
def get_token_info(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about a token without full verification.
|
||||
Useful for debugging and logging.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Decode without verification to get claims
|
||||
payload = jwt.get_unverified_claims(token)
|
||||
|
||||
return {
|
||||
"user_email": payload.get("user_email"),
|
||||
"tenant_id": payload.get("tenant_id"),
|
||||
"resources": [cap.get("resource") for cap in payload.get("capabilities", [])],
|
||||
"expires_at": datetime.fromtimestamp(payload.get("exp", 0)).isoformat(),
|
||||
"issued_at": datetime.fromtimestamp(payload.get("iat", 0)).isoformat(),
|
||||
"token_version": payload.get("token_version"),
|
||||
"security_level": payload.get("security_level")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get token info: {e}")
|
||||
return {"error": str(e)}
|
||||
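A sketch of issuing and verifying a capability token with the client above; it assumes a configured tenant environment (tenant_id, secret key), and the email, tenant id, and resource names are placeholders.

import asyncio
from app.core.capability_client import CapabilityClient

async def issue_and_verify():
    client = CapabilityClient()
    token = await client.generate_capability_token(
        user_email="user@example.com",
        tenant_id="tenant-example",
        resources=["rag_processing", "llm"],
        expires_hours=1,
    )
    payload = await client.verify_capability_token(token)
    print(payload["capabilities"])       # granted resources with actions and constraints
    print(client.get_token_info(token))  # unverified claims, useful for logging

asyncio.run(issue_and_verify())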
289
apps/tenant-backend/app/core/config.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend Configuration
|
||||
|
||||
Environment-based configuration for tenant applications with perfect isolation.
|
||||
Each tenant gets its own isolated backend instance with separate database files.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field, validator
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings with environment variable support"""
|
||||
|
||||
# Environment
|
||||
environment: str = Field(default="development", description="Runtime environment")
|
||||
debug: bool = Field(default=False, description="Debug mode")
|
||||
|
||||
# Tenant Identification (Critical for isolation)
|
||||
tenant_id: str = Field(..., description="Unique tenant identifier")
|
||||
tenant_domain: str = Field(..., description="Tenant domain (e.g., customer1)")
|
||||
|
||||
# Database Configuration (PostgreSQL + PGVector direct connection)
|
||||
database_url: str = Field(
|
||||
default="postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres-primary:5432/gt2_tenants",
|
||||
description="PostgreSQL connection URL (direct to primary)"
|
||||
)
|
||||
|
||||
|
||||
# PostgreSQL Configuration
|
||||
postgres_schema: str = Field(
|
||||
default="tenant_test",
|
||||
description="PostgreSQL schema for tenant data (tenant_{tenant_domain})"
|
||||
)
|
||||
postgres_pool_size: int = Field(
|
||||
default=10,
|
||||
description="Connection pool size for PostgreSQL"
|
||||
)
|
||||
postgres_max_overflow: int = Field(
|
||||
default=20,
|
||||
description="Max overflow connections for PostgreSQL pool"
|
||||
)
|
||||
|
||||
|
||||
# Authentication & Security
|
||||
secret_key: str = Field(..., description="JWT signing key")
|
||||
algorithm: str = Field(default="HS256", description="JWT algorithm")
|
||||
|
||||
# OAuth2 Configuration
|
||||
require_oauth2_auth: bool = Field(
|
||||
default=True,
|
||||
description="Require OAuth2 authentication for API endpoints"
|
||||
)
|
||||
oauth2_proxy_url: str = Field(
|
||||
default="http://oauth2-proxy:4180",
|
||||
description="Internal URL of OAuth2 Proxy service"
|
||||
)
|
||||
oauth2_issuer_url: str = Field(
|
||||
default="https://auth.gt2.com",
|
||||
description="OAuth2 provider issuer URL"
|
||||
)
|
||||
oauth2_audience: str = Field(
|
||||
default="gt2-tenant-client",
|
||||
description="OAuth2 token audience"
|
||||
)
|
||||
|
||||
# Resource Cluster Integration
|
||||
resource_cluster_url: str = Field(
|
||||
default="http://localhost:8004",
|
||||
description="URL of the Resource Cluster API"
|
||||
)
|
||||
resource_cluster_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Resource Cluster authentication"
|
||||
)
|
||||
|
||||
# MCP Service Configuration
|
||||
mcp_service_url: str = Field(
|
||||
default="http://resource-cluster:8000",
|
||||
description="URL of the MCP service for tool execution"
|
||||
)
|
||||
|
||||
# Control Panel Integration
|
||||
control_panel_url: str = Field(
|
||||
default="http://localhost:8001",
|
||||
description="URL of the Control Panel API"
|
||||
)
|
||||
service_auth_token: str = Field(
|
||||
default="internal-service-token",
|
||||
description="Service-to-service authentication token"
|
||||
)
|
||||
|
||||
# WebSocket Configuration
|
||||
websocket_ping_interval: int = Field(default=25, description="WebSocket ping interval")
|
||||
websocket_ping_timeout: int = Field(default=20, description="WebSocket ping timeout")
|
||||
|
||||
# File Upload Configuration
|
||||
max_file_size_mb: int = Field(default=10, description="Maximum file upload size in MB")
|
||||
allowed_file_types: List[str] = Field(
|
||||
default=[".pdf", ".docx", ".txt", ".md", ".csv", ".xlsx"],
|
||||
description="Allowed file extensions for upload"
|
||||
)
|
||||
upload_directory: str = Field(
|
||||
default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}/uploads" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/uploads",
|
||||
description="Directory for uploaded files"
|
||||
)
|
||||
temp_directory: str = Field(
|
||||
default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}/temp" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/temp",
|
||||
description="Temporary directory for file processing"
|
||||
)
|
||||
file_storage_path: str = Field(
|
||||
default_factory=lambda: f"/tmp/gt2-data/{os.getenv('TENANT_DOMAIN', 'default')}" if os.getenv('ENVIRONMENT') == 'test' else f"/data/{os.getenv('TENANT_DOMAIN', 'default')}",
|
||||
description="Root directory for file storage (conversation files, etc.)"
|
||||
)
|
||||
|
||||
# File Context Settings (for chat attachments)
|
||||
max_chunks_per_file: int = Field(
|
||||
default=50,
|
||||
description="Maximum chunks per file (enforces diversity across files)"
|
||||
)
|
||||
max_total_file_chunks: int = Field(
|
||||
default=100,
|
||||
description="Maximum total chunks across all attached files"
|
||||
)
|
||||
file_context_token_safety_margin: float = Field(
|
||||
default=0.05,
|
||||
description="Safety margin for token budget calculations (0.05 = 5%)"
|
||||
)
|
||||
|
||||
# Rate Limiting
|
||||
rate_limit_requests: int = Field(default=1000, description="Requests per minute per IP")
|
||||
rate_limit_window_seconds: int = Field(default=60, description="Rate limit window")
|
||||
|
||||
# CORS Configuration
|
||||
cors_origins: List[str] = Field(
|
||||
default=["http://localhost:3001", "http://localhost:3002", "https://*.gt2.com"],
|
||||
description="Allowed CORS origins"
|
||||
)
|
||||
|
||||
# Security
|
||||
allowed_hosts: List[str] = Field(
|
||||
default=["localhost", "*.gt2.com", "testserver", "gentwo-tenant-backend", "tenant-backend"],
|
||||
description="Allowed host headers"
|
||||
)
|
||||
|
||||
# Vector Storage Configuration (PGVector integrated with PostgreSQL)
|
||||
vector_dimensions: int = Field(
|
||||
default=384,
|
||||
description="Vector dimensions for embeddings (all-MiniLM-L6-v2 model)"
|
||||
)
|
||||
embedding_model: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Embedding model for document processing"
|
||||
)
|
||||
vector_similarity_threshold: float = Field(
|
||||
default=0.3,
|
||||
description="Minimum similarity threshold for vector search"
|
||||
)
|
||||
|
||||
# Legacy ChromaDB Configuration (DEPRECATED - replaced by PGVector)
|
||||
chromadb_mode: str = Field(
|
||||
default="disabled",
|
||||
description="ChromaDB mode - DEPRECATED, using PGVector instead"
|
||||
)
|
||||
chromadb_host: str = Field(
|
||||
default_factory=lambda: f"tenant-{os.getenv('TENANT_DOMAIN', 'test')}-chromadb",
|
||||
description="ChromaDB host - DEPRECATED"
|
||||
)
|
||||
chromadb_port: int = Field(
|
||||
default=8000,
|
||||
description="ChromaDB HTTP port - DEPRECATED"
|
||||
)
|
||||
chromadb_path: str = Field(
|
||||
default_factory=lambda: f"/data/{os.getenv('TENANT_DOMAIN', 'default')}/chromadb",
|
||||
description="ChromaDB file storage path - DEPRECATED"
|
||||
)
|
||||
|
||||
# Redis removed - PostgreSQL handles all caching and session storage needs
|
||||
|
||||
# Logging Configuration
|
||||
log_level: str = Field(default="INFO", description="Logging level")
|
||||
log_format: str = Field(default="json", description="Log format: json or text")
|
||||
|
||||
# Performance
|
||||
worker_processes: int = Field(default=1, description="Number of worker processes")
|
||||
max_connections: int = Field(default=100, description="Maximum concurrent connections")
|
||||
|
||||
# Monitoring
|
||||
prometheus_enabled: bool = Field(default=True, description="Enable Prometheus metrics")
|
||||
prometheus_port: int = Field(default=9090, description="Prometheus metrics port")
|
||||
|
||||
# Feature Flags
|
||||
enable_file_upload: bool = Field(default=True, description="Enable file upload feature")
|
||||
enable_voice_input: bool = Field(default=False, description="Enable voice input (future)")
|
||||
enable_document_analysis: bool = Field(default=True, description="Enable document analysis")
|
||||
|
||||
@validator("tenant_id")
|
||||
def validate_tenant_id(cls, v):
|
||||
if not v or len(v) < 3:
|
||||
raise ValueError("Tenant ID must be at least 3 characters long")
|
||||
return v
|
||||
|
||||
@validator("tenant_domain")
|
||||
def validate_tenant_domain(cls, v):
|
||||
if not v or not v.replace("-", "").replace("_", "").isalnum():
|
||||
raise ValueError("Tenant domain must be alphanumeric with optional hyphens/underscores")
|
||||
return v
|
||||
|
||||
|
||||
@validator("upload_directory")
|
||||
def validate_upload_directory(cls, v):
|
||||
# Ensure the upload directory exists with secure permissions
|
||||
os.makedirs(v, exist_ok=True, mode=0o700)
|
||||
return v
|
||||
|
||||
model_config = {
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore",
|
||||
}
|
||||
|
||||
|
||||
def get_settings(tenant_id: Optional[str] = None) -> Settings:
|
||||
"""Get tenant-scoped application settings"""
|
||||
# For development and testing, use simple settings without caching
|
||||
if os.getenv("ENVIRONMENT") in ["development", "test"]:
|
||||
return Settings()
|
||||
|
||||
# In production, settings should be tenant-scoped
|
||||
# This prevents global state from affecting tenant isolation
|
||||
if tenant_id:
|
||||
# Create tenant-specific settings with proper isolation
|
||||
settings = Settings()
|
||||
# In production, this could load tenant-specific overrides
|
||||
return settings
|
||||
else:
|
||||
# Default settings for non-tenant operations
|
||||
return Settings()
|
||||
|
||||
|
||||
# Security and isolation utilities
|
||||
def get_tenant_data_path(tenant_domain: str) -> str:
|
||||
"""Get the secure data path for a tenant"""
|
||||
if os.getenv('ENVIRONMENT') == 'test':
|
||||
return f"/tmp/gt2-data/{tenant_domain}"
|
||||
return f"/data/{tenant_domain}"
|
||||
|
||||
|
||||
def get_tenant_database_url(tenant_domain: str) -> str:
|
||||
"""Get the database URL for a specific tenant (PostgreSQL)"""
|
||||
return "postgresql://gt2_tenant_user:gt2_tenant_dev_password@tenant-postgres:5432/gt2_tenants"
|
||||
|
||||
|
||||
def get_tenant_schema_name(tenant_domain: str) -> str:
|
||||
"""Get the PostgreSQL schema name for a specific tenant"""
|
||||
# Clean domain name for schema usage
|
||||
clean_domain = tenant_domain.replace('-', '_').replace('.', '_').lower()
|
||||
return f"tenant_{clean_domain}"
|
||||
|
||||
|
||||
def ensure_tenant_isolation(tenant_id: str) -> None:
|
||||
"""Ensure proper tenant isolation is configured"""
|
||||
settings = get_settings()
|
||||
|
||||
if settings.tenant_id != tenant_id:
|
||||
raise ValueError(f"Tenant ID mismatch: expected {settings.tenant_id}, got {tenant_id}")
|
||||
|
||||
# Verify the PostgreSQL schema name contains the tenant identifier (schema-based isolation)
if settings.tenant_domain.replace('-', '_').replace('.', '_').lower() not in settings.postgres_schema:
raise ValueError("PostgreSQL schema does not contain tenant identifier - isolation breach risk")
|
||||
|
||||
# Verify upload directory contains tenant identifier
|
||||
if settings.tenant_domain not in settings.upload_directory:
|
||||
raise ValueError("Upload directory does not contain tenant identifier - isolation breach risk")
|
||||
|
||||
|
||||
# Development helpers
|
||||
def is_development() -> bool:
|
||||
"""Check if running in development mode"""
|
||||
return get_settings().environment == "development"
|
||||
|
||||
|
||||
def is_production() -> bool:
|
||||
"""Check if running in production mode"""
|
||||
return get_settings().environment == "production"
|
||||
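A short sketch of the isolation helpers above; the domain and tenant id are placeholders, and in a deployment they come from the tenant's environment variables.

from app.core.config import (
    ensure_tenant_isolation,
    get_tenant_data_path,
    get_tenant_schema_name,
)

domain = "customer1"  # placeholder tenant domain
print(get_tenant_data_path(domain))    # /data/customer1 (or /tmp/gt2-data/customer1 under test)
print(get_tenant_schema_name(domain))  # tenant_customer1

# Raises ValueError if the running settings belong to a different tenant
ensure_tenant_isolation("<tenant-id>")  # placeholder tenant id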
131
apps/tenant-backend/app/core/database.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend Database Configuration - PostgreSQL + PGVector Client
|
||||
|
||||
Migrated from DuckDB service to PostgreSQL + PGVector for enterprise readiness:
|
||||
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
|
||||
- BionicGPT Row Level Security patterns for enterprise isolation
|
||||
- MVCC concurrency solving DuckDB file locking issues
|
||||
- Hybrid vector + full-text search in single queries
|
||||
- Connection pooling for 10,000+ concurrent connections
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Generator, Optional, Any, Dict, List
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import (
|
||||
get_postgresql_client, init_postgresql, close_postgresql,
|
||||
get_db_session, execute_query, execute_command,
|
||||
fetch_one, fetch_scalar, health_check, get_database_info
|
||||
)
|
||||
|
||||
# Legacy DuckDB imports removed - PostgreSQL + PGVector only
|
||||
|
||||
# SQLAlchemy Base for ORM models
|
||||
Base = declarative_base()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
# PostgreSQL client is managed by postgresql_client module
|
||||
|
||||
|
||||
async def init_database() -> None:
|
||||
"""Initialize PostgreSQL + PGVector connection"""
|
||||
logger.info("Initializing PostgreSQL + PGVector database connection...")
|
||||
|
||||
try:
|
||||
await init_postgresql()
|
||||
logger.info("PostgreSQL + PGVector connection initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL database: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def close_database() -> None:
|
||||
"""Close PostgreSQL connections"""
|
||||
try:
|
||||
await close_postgresql()
|
||||
logger.info("PostgreSQL connections closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing PostgreSQL connections: {e}")
|
||||
|
||||
|
||||
async def get_db_client_instance():
|
||||
"""Get the PostgreSQL client instance"""
|
||||
return await get_postgresql_client()
|
||||
|
||||
|
||||
# get_db_session is imported from postgresql_client
|
||||
|
||||
|
||||
# execute_query is imported from postgresql_client
|
||||
|
||||
|
||||
# execute_command is imported from postgresql_client
|
||||
|
||||
|
||||
async def execute_transaction(commands: List[Dict[str, Any]]) -> List[int]:
|
||||
"""Execute multiple commands in a transaction (PostgreSQL format)"""
|
||||
client = await get_postgresql_client()
|
||||
pg_commands = [(cmd.get('query', cmd.get('command', '')), tuple(cmd.get('params', {}).values())) for cmd in commands]
|
||||
return await client.execute_transaction(pg_commands)
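# Illustrative usage (hypothetical SQL and values, shown only as a sketch): each legacy
# command is a dict with 'query' or 'command' text plus a 'params' mapping; the values
# are passed positionally, so their insertion order must match the $n placeholders.
affected_rows = await execute_transaction([
    {"command": "UPDATE agents SET name = $1 WHERE id = $2",
     "params": {"name": "Helper", "id": agent_id}},  # agent_id: hypothetical UUID string
])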
|
||||
|
||||
|
||||
# fetch_one is imported from postgresql_client
|
||||
|
||||
|
||||
async def fetch_all(query: str, *args) -> List[Dict[str, Any]]:
|
||||
"""Execute query and return all rows"""
|
||||
return await execute_query(query, *args)
|
||||
|
||||
|
||||
# fetch_scalar is imported from postgresql_client
|
||||
|
||||
|
||||
# get_database_info is imported from postgresql_client
|
||||
|
||||
|
||||
# health_check is imported from postgresql_client
|
||||
|
||||
|
||||
# Legacy compatibility functions (for gradual migration)
|
||||
def get_db() -> Generator[None, None, None]:
|
||||
"""Legacy sync database dependency - deprecated"""
|
||||
logger.warning("get_db() is deprecated. Use async get_db_session() instead")
|
||||
# Return a dummy generator for compatibility
|
||||
yield None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session_sync():
|
||||
"""Legacy sync session - deprecated"""
|
||||
logger.warning("get_db_session_sync() is deprecated. Use async get_db_session() instead")
|
||||
yield None
|
||||
|
||||
|
||||
def execute_raw_query(query: str, params: Optional[Dict] = None) -> List[Dict]:
|
||||
"""Legacy sync query execution - deprecated"""
|
||||
logger.error("execute_raw_query() is deprecated and not supported with PostgreSQL async client")
|
||||
raise NotImplementedError("Use async execute_query() instead")
|
||||
|
||||
|
||||
def verify_tenant_isolation() -> bool:
|
||||
"""Verify tenant isolation - PostgreSQL schema-based isolation with RLS is always enabled"""
|
||||
return True
|
||||
|
||||
|
||||
# Initialize database on module import (for FastAPI startup)
|
||||
async def startup_database():
|
||||
"""Initialize database during FastAPI startup"""
|
||||
await init_database()
|
||||
|
||||
|
||||
async def shutdown_database():
|
||||
"""Cleanup database during FastAPI shutdown"""
|
||||
await close_database()
|
||||
348
apps/tenant-backend/app/core/database_interface.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
GT 2.0 Database Interface - Abstract Database Layer
|
||||
|
||||
Provides a unified abstract interface for tenant database operations
|
||||
following GT 2.0 principles of Zero Downtime, Perfect Tenant Isolation, and Elegant Simplicity.
|
||||
Post-migration: DuckDB has been completely replaced with PostgreSQL + PGVector; this interface is retained for legacy compatibility only.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, AsyncGenerator, Union
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class DatabaseEngine(Enum):
|
||||
"""Supported database engines - DEPRECATED: Use PostgreSQL directly"""
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database configuration"""
|
||||
engine: DatabaseEngine
|
||||
database_path: str
|
||||
tenant_id: str
|
||||
shard_id: Optional[str] = None
|
||||
encryption_key: Optional[str] = None
|
||||
connection_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
"""Standardized query result"""
|
||||
rows: List[Dict[str, Any]]
|
||||
row_count: int
|
||||
columns: List[str]
|
||||
execution_time_ms: float
|
||||
|
||||
|
||||
class DatabaseInterface(ABC):
|
||||
"""
|
||||
Abstract database interface for GT 2.0 tenant isolation.
|
||||
|
||||
PostgreSQL + PGVector implementation with MVCC concurrency for true zero-downtime operations,
|
||||
perfect tenant isolation, and 10x analytical performance improvements.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DatabaseConfig):
|
||||
self.config = config
|
||||
self.tenant_id = config.tenant_id
|
||||
self.database_path = config.database_path
|
||||
self.engine = config.engine
|
||||
|
||||
# Connection Management
|
||||
@abstractmethod
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize database connection and create tables"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close database connections"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def is_initialized(self) -> bool:
|
||||
"""Check if database is properly initialized"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncGenerator[Any, None]:
|
||||
"""Get database session context manager"""
|
||||
pass
|
||||
|
||||
# Schema Management
|
||||
@abstractmethod
|
||||
async def create_tables(self) -> None:
|
||||
"""Create all required tables"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_schema_version(self) -> Optional[str]:
|
||||
"""Get current database schema version"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def migrate_schema(self, target_version: str) -> bool:
|
||||
"""Migrate database schema to target version"""
|
||||
pass
|
||||
|
||||
# Query Operations
|
||||
@abstractmethod
|
||||
async def execute_query(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> QueryResult:
|
||||
"""Execute SELECT query and return results"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_command(
|
||||
self,
|
||||
command: str,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> int:
|
||||
"""Execute INSERT/UPDATE/DELETE command and return affected rows"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_batch(
|
||||
self,
|
||||
commands: List[str],
|
||||
params: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[int]:
|
||||
"""Execute batch commands in transaction"""
|
||||
pass
|
||||
|
||||
# Transaction Management
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def transaction(self) -> AsyncGenerator[Any, None]:
|
||||
"""Transaction context manager"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def begin_transaction(self) -> Any:
|
||||
"""Begin transaction and return transaction handle"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def commit_transaction(self, tx: Any) -> None:
|
||||
"""Commit transaction"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def rollback_transaction(self, tx: Any) -> None:
|
||||
"""Rollback transaction"""
|
||||
pass
|
||||
|
||||
# Health and Monitoring
|
||||
@abstractmethod
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Check database health and return status"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get database statistics"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def optimize(self) -> bool:
|
||||
"""Optimize database performance"""
|
||||
pass
|
||||
|
||||
# Backup and Recovery
|
||||
@abstractmethod
|
||||
async def backup(self, backup_path: str) -> bool:
|
||||
"""Create database backup"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def restore(self, backup_path: str) -> bool:
|
||||
"""Restore from database backup"""
|
||||
pass
|
||||
|
||||
# Sharding Support
|
||||
@abstractmethod
|
||||
async def create_shard(self, shard_id: str) -> bool:
|
||||
"""Create new database shard"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_info(self) -> Dict[str, Any]:
|
||||
"""Get information about current shard"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def migrate_to_shard(self, source_db: 'DatabaseInterface') -> bool:
|
||||
"""Migrate data from another database instance"""
|
||||
pass
|
||||
|
||||
# Vector Operations (PGVector integration)
|
||||
@abstractmethod
|
||||
async def store_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadata: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Store embeddings with documents and metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def query_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: List[float],
|
||||
limit: int = 10,
|
||||
filter_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query embeddings by similarity"""
|
||||
pass
|
||||
|
||||
# Data Import/Export
|
||||
@abstractmethod
|
||||
async def export_data(
|
||||
self,
|
||||
format: str = "json",
|
||||
tables: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Export database data"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def import_data(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
format: str = "json",
|
||||
merge_strategy: str = "replace"
|
||||
) -> bool:
|
||||
"""Import database data"""
|
||||
pass
|
||||
|
||||
# Security and Encryption
|
||||
@abstractmethod
|
||||
async def encrypt_database(self, encryption_key: str) -> bool:
|
||||
"""Enable database encryption"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def verify_encryption(self) -> bool:
|
||||
"""Verify database encryption status"""
|
||||
pass
|
||||
|
||||
# Performance and Indexing
|
||||
@abstractmethod
|
||||
async def create_index(
|
||||
self,
|
||||
table: str,
|
||||
columns: List[str],
|
||||
index_name: Optional[str] = None,
|
||||
unique: bool = False
|
||||
) -> bool:
|
||||
"""Create database index"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def drop_index(self, index_name: str) -> bool:
|
||||
"""Drop database index"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_queries(self) -> Dict[str, Any]:
|
||||
"""Analyze query performance"""
|
||||
pass
|
||||
|
||||
# Utility Methods
|
||||
async def get_engine_info(self) -> Dict[str, Any]:
|
||||
"""Get database engine information"""
|
||||
return {
|
||||
"engine": self.engine.value,
|
||||
"tenant_id": self.tenant_id,
|
||||
"database_path": self.database_path,
|
||||
"shard_id": self.config.shard_id,
|
||||
"supports_mvcc": self.engine == DatabaseEngine.POSTGRESQL,
|
||||
"supports_sharding": self.engine == DatabaseEngine.POSTGRESQL,
|
||||
"file_based": True
|
||||
}
|
||||
|
||||
async def validate_tenant_isolation(self) -> bool:
|
||||
"""Validate that tenant isolation is maintained"""
|
||||
try:
|
||||
stats = await self.get_statistics()
|
||||
return (
|
||||
self.tenant_id in self.database_path and
|
||||
stats.get("isolated", False)
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class DatabaseFactory:
|
||||
"""Factory for creating database instances"""
|
||||
|
||||
@staticmethod
|
||||
async def create_database(config: DatabaseConfig) -> DatabaseInterface:
|
||||
"""Create database instance - PostgreSQL only"""
|
||||
raise NotImplementedError("Database interface deprecated. Use PostgreSQL directly via postgresql_client.py")
|
||||
|
||||
@staticmethod
|
||||
async def migrate_database(
|
||||
source_config: DatabaseConfig,
|
||||
target_config: DatabaseConfig,
|
||||
migration_options: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""Migrate data from source to target database"""
|
||||
source_db = await DatabaseFactory.create_database(source_config)
|
||||
target_db = await DatabaseFactory.create_database(target_config)
|
||||
|
||||
try:
|
||||
await source_db.initialize()
|
||||
await target_db.initialize()
|
||||
|
||||
# Export data from source
|
||||
data = await source_db.export_data()
|
||||
|
||||
# Import data to target
|
||||
success = await target_db.import_data(data)
|
||||
|
||||
if success and migration_options and migration_options.get("verify", True):
|
||||
# Verify migration
|
||||
source_stats = await source_db.get_statistics()
|
||||
target_stats = await target_db.get_statistics()
|
||||
|
||||
return source_stats.get("row_count", 0) == target_stats.get("row_count", 0)
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
await source_db.close()
|
||||
await target_db.close()
|
||||
|
||||
|
||||
# Error Classes
|
||||
class DatabaseError(Exception):
|
||||
"""Base database error"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseConnectionError(DatabaseError):
|
||||
"""Database connection error"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseMigrationError(DatabaseError):
|
||||
"""Database migration error"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseShardingError(DatabaseError):
|
||||
"""Database sharding error"""
|
||||
pass
|
||||
265
apps/tenant-backend/app/core/dependencies/resource_access.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Resource Access Control Dependencies for FastAPI
|
||||
|
||||
Provides declarative access control for agents and datasets using team-based permissions.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
from fastapi import Depends, HTTPException
|
||||
from app.api.dependencies import get_current_user
|
||||
from app.services.team_service import TeamService
|
||||
from app.core.permissions import get_user_role
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def require_resource_access(
|
||||
resource_type: str,
|
||||
required_permission: str = "read"
|
||||
) -> Callable:
|
||||
"""
|
||||
FastAPI dependency factory for resource access control.
|
||||
|
||||
Creates a dependency that verifies the user has the required permission on a resource
|
||||
via ownership, organization visibility, or team membership.
|
||||
|
||||
Args:
|
||||
resource_type: 'agent' or 'dataset'
|
||||
required_permission: 'read' or 'edit' (default: 'read')
|
||||
|
||||
Returns:
|
||||
FastAPI dependency function
|
||||
|
||||
Usage:
|
||||
@router.get("/agents/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
_: None = Depends(require_resource_access("agent", "read"))
|
||||
):
|
||||
# User has read access if we reach here
|
||||
...
|
||||
|
||||
@router.put("/agents/{agent_id}")
|
||||
async def update_agent(
|
||||
agent_id: str,
|
||||
_: None = Depends(require_resource_access("agent", "edit"))
|
||||
):
|
||||
# User has edit access if we reach here
|
||||
...
|
||||
"""
|
||||
|
||||
async def check_access(
|
||||
resource_id: str,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> None:
|
||||
"""
|
||||
Verify user has required permission on resource.
|
||||
|
||||
Raises HTTPException(403) if access denied.
|
||||
"""
|
||||
user_id = current_user["user_id"]
|
||||
tenant_domain = current_user["tenant_domain"]
|
||||
user_email = current_user.get("email", user_id)
|
||||
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if admin/developer (bypass all checks)
|
||||
user_role = await get_user_role(pg_client, user_email, tenant_domain)
|
||||
if user_role in ["admin", "developer"]:
|
||||
logger.debug(f"Admin/developer {user_id} has full access to {resource_type} {resource_id}")
|
||||
return
|
||||
|
||||
# Check if user owns the resource
|
||||
ownership_query = f"""
|
||||
SELECT created_by FROM {resource_type}s
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
owner_id = await pg_client.fetch_scalar(ownership_query, resource_id, tenant_domain)
|
||||
|
||||
if owner_id and str(owner_id) == str(user_id):
|
||||
logger.debug(f"User {user_id} owns {resource_type} {resource_id}")
|
||||
return
|
||||
|
||||
# Check if resource is organization-wide
|
||||
visibility_query = f"""
|
||||
SELECT visibility FROM {resource_type}s
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
visibility = await pg_client.fetch_scalar(visibility_query, resource_id, tenant_domain)
|
||||
|
||||
if visibility == "organization":
|
||||
logger.debug(f"{resource_type.capitalize()} {resource_id} is organization-wide")
|
||||
return
|
||||
|
||||
# Check team-based access using TeamService
|
||||
team_service = TeamService(tenant_domain, user_id, user_email)
|
||||
has_permission = await team_service.check_user_resource_permission(
|
||||
user_id=user_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
required_permission=required_permission
|
||||
)
|
||||
|
||||
if has_permission:
|
||||
logger.debug(f"User {user_id} has {required_permission} permission on {resource_type} {resource_id} via team")
|
||||
return
|
||||
|
||||
# Access denied
|
||||
logger.warning(f"Access denied: User {user_id} cannot access {resource_type} {resource_id} (required: {required_permission})")
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"You do not have {required_permission} permission for this {resource_type}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking resource access: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error verifying {resource_type} access"
|
||||
)
|
||||
|
||||
return check_access
|
||||
|
||||
|
||||
def require_agent_access(required_permission: str = "read") -> Callable:
|
||||
"""
|
||||
Convenience wrapper for agent access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/agents/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
_: None = Depends(require_agent_access("read"))
|
||||
):
|
||||
...
|
||||
"""
|
||||
return require_resource_access("agent", required_permission)
|
||||
|
||||
|
||||
def require_dataset_access(required_permission: str = "read") -> Callable:
|
||||
"""
|
||||
Convenience wrapper for dataset access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/datasets/{dataset_id}")
|
||||
async def get_dataset(
|
||||
dataset_id: str,
|
||||
_: None = Depends(require_dataset_access("read"))
|
||||
):
|
||||
...
|
||||
"""
|
||||
return require_resource_access("dataset", required_permission)
|
||||
|
||||
|
||||
async def check_agent_edit_permission(
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
tenant_domain: str,
|
||||
user_email: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
Helper function to check if user can edit an agent.
|
||||
|
||||
Can be used in service layer without FastAPI dependency injection.
|
||||
|
||||
Args:
|
||||
agent_id: UUID of the agent
|
||||
user_id: UUID of the user
|
||||
tenant_domain: Tenant domain
|
||||
user_email: User email (optional)
|
||||
|
||||
Returns:
|
||||
True if user can edit agent
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if admin/developer
|
||||
user_role = await get_user_role(pg_client, user_email or user_id, tenant_domain)
|
||||
if user_role in ["admin", "developer"]:
|
||||
return True
|
||||
|
||||
# Check ownership
|
||||
query = """
|
||||
SELECT created_by FROM agents
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
owner_id = await pg_client.fetch_scalar(query, agent_id, tenant_domain)
|
||||
|
||||
if owner_id and str(owner_id) == str(user_id):
|
||||
return True
|
||||
|
||||
# Check team edit permission
|
||||
team_service = TeamService(tenant_domain, user_id, user_email or user_id)
|
||||
return await team_service.check_user_resource_permission(
|
||||
user_id=user_id,
|
||||
resource_type="agent",
|
||||
resource_id=agent_id,
|
||||
required_permission="edit"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking agent edit permission: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_dataset_edit_permission(
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
tenant_domain: str,
|
||||
user_email: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
Helper function to check if user can edit a dataset.
|
||||
|
||||
Can be used in service layer without FastAPI dependency injection.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset
|
||||
user_id: UUID of the user
|
||||
tenant_domain: Tenant domain
|
||||
user_email: User email (optional)
|
||||
|
||||
Returns:
|
||||
True if user can edit dataset
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if admin/developer
|
||||
user_role = await get_user_role(pg_client, user_email or user_id, tenant_domain)
|
||||
if user_role in ["admin", "developer"]:
|
||||
return True
|
||||
|
||||
# Check ownership
|
||||
query = """
|
||||
SELECT user_id FROM datasets
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
owner_id = await pg_client.fetch_scalar(query, dataset_id, tenant_domain)
|
||||
|
||||
if owner_id and str(owner_id) == str(user_id):
|
||||
return True
|
||||
|
||||
# Check team edit permission
|
||||
team_service = TeamService(tenant_domain, user_id, user_email or user_id)
|
||||
return await team_service.check_user_resource_permission(
|
||||
user_id=user_id,
|
||||
resource_type="dataset",
|
||||
resource_id=dataset_id,
|
||||
required_permission="edit"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking dataset edit permission: {e}")
|
||||
return False
|
||||
169
apps/tenant-backend/app/core/logging_config.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend Logging Configuration
|
||||
|
||||
Structured logging with tenant isolation and security considerations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Setup logging configuration for the tenant backend"""
|
||||
settings = get_settings()
|
||||
|
||||
# Determine log directory based on environment
|
||||
if settings.environment == "test":
|
||||
log_dir = f"/tmp/gt2-data/{settings.tenant_domain}/logs"
|
||||
else:
|
||||
log_dir = f"/data/{settings.tenant_domain}/logs"
|
||||
|
||||
# Create logging configuration
|
||||
log_config: Dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
},
|
||||
"json": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
},
|
||||
"detailed": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(funcName)s() - %(message)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": settings.log_level,
|
||||
"formatter": "json" if settings.log_format == "json" else "default",
|
||||
"stream": sys.stdout,
|
||||
},
|
||||
"file": {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": "INFO",
|
||||
"formatter": "json" if settings.log_format == "json" else "detailed",
|
||||
"filename": f"{log_dir}/tenant-backend.log",
|
||||
"maxBytes": 10485760, # 10MB
|
||||
"backupCount": 5,
|
||||
"encoding": "utf-8",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"": { # Root logger
|
||||
"level": settings.log_level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"app": {
|
||||
"level": settings.log_level,
|
||||
"handlers": ["console", "file"] if settings.environment == "production" else ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"sqlalchemy.engine": {
|
||||
"level": "INFO" if settings.debug else "WARNING",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.access": {
|
||||
"level": "WARNING", # Suppress INFO level access logs (operational endpoints)
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.error": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Create log directory if it doesn't exist
|
||||
import os
|
||||
os.makedirs(log_dir, exist_ok=True, mode=0o700)
|
||||
|
||||
# Apply logging configuration
|
||||
logging.config.dictConfig(log_config)
|
||||
|
||||
# Add tenant context to all logs
|
||||
class TenantContextFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
record.tenant_id = settings.tenant_id
|
||||
record.tenant_domain = settings.tenant_domain
|
||||
return True
|
||||
|
||||
tenant_filter = TenantContextFilter()
|
||||
|
||||
# Add tenant filter to all handlers
|
||||
for handler in logging.getLogger().handlers:
|
||||
handler.addFilter(tenant_filter)
|
||||
|
||||
# Log startup information
|
||||
logger = logging.getLogger("app.startup")
|
||||
logger.info(
|
||||
"Tenant backend logging initialized",
|
||||
extra={
|
||||
"tenant_id": settings.tenant_id,
|
||||
"tenant_domain": settings.tenant_domain,
|
||||
"environment": settings.environment,
|
||||
"log_level": settings.log_level,
|
||||
"log_format": settings.log_format,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get logger with consistent naming and formatting"""
|
||||
return logging.getLogger(f"app.{name}")
|
||||
|
||||
|
||||
|
||||
class SecurityRedactionFilter(logging.Filter):
|
||||
"""Filter to redact sensitive information from logs"""
|
||||
|
||||
SENSITIVE_FIELDS = [
|
||||
"password", "token", "secret", "key", "authorization",
|
||||
"cookie", "session", "csrf", "api_key", "jwt"
|
||||
]
|
||||
|
||||
def filter(self, record):
|
||||
if hasattr(record, 'args') and record.args:
|
||||
# Redact sensitive information from log messages
|
||||
record.args = self._redact_sensitive_data(record.args)
|
||||
|
||||
if hasattr(record, 'msg') and isinstance(record.msg, str):
|
||||
for field in self.SENSITIVE_FIELDS:
|
||||
if field.lower() in record.msg.lower():
|
||||
record.msg = record.msg.replace(field, "[REDACTED]")
|
||||
|
||||
return True
|
||||
|
||||
def _redact_sensitive_data(self, data):
|
||||
"""Recursively redact sensitive data from log arguments"""
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
key: "[REDACTED]" if any(sensitive in key.lower() for sensitive in self.SENSITIVE_FIELDS)
|
||||
else self._redact_sensitive_data(value)
|
||||
for key, value in data.items()
|
||||
}
|
||||
elif isinstance(data, (list, tuple)):
|
||||
return type(data)(self._redact_sensitive_data(item) for item in data)
|
||||
return data
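# Illustrative behaviour (hypothetical values): dict keys containing a sensitive substring
# are masked, everything else passes through unchanged.
_filter = SecurityRedactionFilter()
_filter._redact_sensitive_data({"api_key": "abc123", "user": "bob"})
# -> {"api_key": "[REDACTED]", "user": "bob"}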
|
||||
|
||||
|
||||
def setup_security_logging():
|
||||
"""Setup security-focused logging with redaction"""
|
||||
security_filter = SecurityRedactionFilter()
|
||||
|
||||
# Add security filter to all loggers
|
||||
for name in ["app", "uvicorn", "sqlalchemy"]:
|
||||
logger = logging.getLogger(name)
|
||||
logger.addFilter(security_filter)
|
||||
175
apps/tenant-backend/app/core/path_security.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Path Security Utilities for GT AI OS
|
||||
|
||||
Provides path sanitization and validation to prevent path traversal attacks.
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def sanitize_path_component(component: str) -> str:
|
||||
"""
|
||||
Sanitize a single path component to prevent path traversal.
|
||||
|
||||
Removes or replaces dangerous characters including:
|
||||
- Path separators (/ and \\)
|
||||
- Parent directory references (..)
|
||||
- Null bytes
|
||||
- Other special characters
|
||||
|
||||
Args:
|
||||
component: The path component to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized component safe for use in file paths
|
||||
"""
|
||||
if not component:
|
||||
return ""
|
||||
|
||||
# Remove null bytes
|
||||
sanitized = component.replace('\x00', '')
|
||||
|
||||
# Remove path separators
|
||||
sanitized = re.sub(r'[/\\]', '', sanitized)
|
||||
|
||||
# Remove parent directory references
|
||||
sanitized = sanitized.replace('..', '')
|
||||
|
||||
# For tenant domains and similar identifiers, allow alphanumeric, hyphen, underscore
|
||||
# For filenames, allow alphanumeric, hyphen, underscore, and single dots
|
||||
sanitized = re.sub(r'[^a-zA-Z0-9_\-.]', '_', sanitized)
|
||||
|
||||
# Prevent leading dots (hidden files) and multiple consecutive dots
|
||||
sanitized = re.sub(r'^\.+', '', sanitized)
|
||||
sanitized = re.sub(r'\.{2,}', '.', sanitized)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_tenant_domain(domain: str) -> str:
|
||||
"""
|
||||
Sanitize a tenant domain for safe use in file paths.
|
||||
|
||||
More restrictive than general path component sanitization.
|
||||
Only allows lowercase alphanumeric characters, hyphens, and underscores.
|
||||
|
||||
Args:
|
||||
domain: The tenant domain to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized domain safe for use in file paths
|
||||
"""
|
||||
if not domain:
|
||||
raise ValueError("Tenant domain cannot be empty")
|
||||
|
||||
# Convert to lowercase and sanitize
|
||||
sanitized = domain.lower()
|
||||
sanitized = re.sub(r'[^a-z0-9_\-]', '_', sanitized)
|
||||
sanitized = sanitized.strip('_-')
|
||||
|
||||
if not sanitized:
|
||||
raise ValueError("Tenant domain resulted in empty string after sanitization")
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
Sanitize a filename for safe storage.
|
||||
|
||||
Preserves the file extension but sanitizes the rest.
|
||||
|
||||
Args:
|
||||
filename: The filename to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized filename
|
||||
"""
|
||||
if not filename:
|
||||
return ""
|
||||
|
||||
# Get the extension
|
||||
path = Path(filename)
|
||||
stem = path.stem
|
||||
suffix = path.suffix
|
||||
|
||||
# Sanitize the stem (filename without extension)
|
||||
safe_stem = sanitize_path_component(stem)
|
||||
|
||||
# Sanitize the extension (should just be alphanumeric)
|
||||
safe_suffix = ""
|
||||
if suffix:
|
||||
safe_suffix = '.' + re.sub(r'[^a-zA-Z0-9]', '', suffix[1:])
|
||||
|
||||
result = safe_stem + safe_suffix
|
||||
|
||||
if not result:
|
||||
result = "unnamed"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def safe_join_path(base: Path, *components: str, require_within_base: bool = True) -> Path:
|
||||
"""
|
||||
Safely join path components, preventing traversal attacks.
|
||||
|
||||
Args:
|
||||
base: The base directory that all paths must stay within
|
||||
components: Path components to join to the base
|
||||
require_within_base: If True, verify the result is within base
|
||||
|
||||
Returns:
|
||||
The joined path
|
||||
|
||||
Raises:
|
||||
ValueError: If the resulting path would be outside the base directory
|
||||
"""
|
||||
if not base:
|
||||
raise ValueError("Base path cannot be empty")
|
||||
|
||||
# Sanitize all components
|
||||
sanitized = [sanitize_path_component(c) for c in components if c]
|
||||
|
||||
# Filter out empty components
|
||||
sanitized = [c for c in sanitized if c]
|
||||
|
||||
if not sanitized:
|
||||
return base
|
||||
|
||||
# Join the path
|
||||
result = base.joinpath(*sanitized)
|
||||
|
||||
# Verify the result is within the base directory
|
||||
if require_within_base:
|
||||
try:
|
||||
resolved_base = base.resolve()
|
||||
resolved_result = result.resolve()
|
||||
|
||||
# Check if result is within base
|
||||
resolved_result.relative_to(resolved_base)
|
||||
except (ValueError, RuntimeError):
|
||||
raise ValueError(f"Path traversal detected: result would be outside base directory")
|
||||
|
||||
return result
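# Illustrative usage (hypothetical paths): components are sanitized before joining, so a
# traversal attempt such as "../other" is reduced to "other" rather than escaping the base;
# anything that still resolves outside the base raises ValueError.
uploads = safe_join_path(Path("/data/acme/uploads"), "user-1", "report.pdf")
# -> Path("/data/acme/uploads/user-1/report.pdf")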
|
||||
|
||||
|
||||
def validate_file_extension(filename: str, allowed_extensions: Optional[list] = None) -> bool:
|
||||
"""
|
||||
Validate that a file has an allowed extension.
|
||||
|
||||
Args:
|
||||
filename: The filename to check
|
||||
allowed_extensions: List of allowed extensions (e.g., ['.txt', '.pdf']).
|
||||
If None, all extensions are allowed.
|
||||
|
||||
Returns:
|
||||
True if the extension is allowed, False otherwise
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
return True
|
||||
|
||||
path = Path(filename)
|
||||
extension = path.suffix.lower()
|
||||
|
||||
return extension in [ext.lower() for ext in allowed_extensions]
|
||||
138
apps/tenant-backend/app/core/permissions.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
GT 2.0 Role-Based Permissions
|
||||
Enforces organization-level resource sharing based on user roles.
|
||||
|
||||
Visibility Levels:
|
||||
- individual: Only the creator can see and edit
|
||||
- organization: All users can read, only admins/developers can create and edit
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Role hierarchy: admin/developer > analyst > student
|
||||
ADMIN_ROLES = ["admin", "developer"]
|
||||
|
||||
# Visibility levels
|
||||
VISIBILITY_INDIVIDUAL = "individual"
|
||||
VISIBILITY_ORGANIZATION = "organization"
|
||||
|
||||
|
||||
async def get_user_role(pg_client, user_email: str, tenant_domain: str) -> str:
|
||||
"""
|
||||
Get the role for a user in the tenant database.
|
||||
Returns: 'admin', 'developer', 'analyst', or 'student'
|
||||
"""
|
||||
query = """
|
||||
SELECT role FROM users
|
||||
WHERE email = $1
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
LIMIT 1
|
||||
"""
|
||||
role = await pg_client.fetch_scalar(query, user_email, tenant_domain)
|
||||
return role or "student"
|
||||
|
||||
|
||||
def can_share_to_organization(user_role: str) -> bool:
|
||||
"""
|
||||
Check if a user can share resources at the organization level.
|
||||
Only admin and developer roles can share to organization.
|
||||
"""
|
||||
return user_role in ADMIN_ROLES
|
||||
|
||||
|
||||
def validate_visibility_permission(visibility: str, user_role: str) -> None:
|
||||
"""
|
||||
Validate that the user has permission to set the given visibility level.
|
||||
Raises HTTPException if not authorized.
|
||||
|
||||
Rules:
|
||||
- admin/developer: Can set individual or organization visibility
|
||||
- analyst/student: Can only set individual visibility
|
||||
"""
|
||||
if visibility == VISIBILITY_ORGANIZATION and not can_share_to_organization(user_role):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Only admin and developer users can share resources to organization. Your role: {user_role}"
|
||||
)
|
||||
|
||||
|
||||
def can_edit_resource(resource_creator_id: str, current_user_id: str, user_role: str, resource_visibility: str) -> bool:
|
||||
"""
|
||||
Check if user can edit a resource.
|
||||
|
||||
Rules:
|
||||
- Owner can always edit their own resources
|
||||
- Admin/developer can edit any resource
|
||||
- Organization-shared resources: read-only for non-admins who didn't create it
|
||||
"""
|
||||
# Admin and developer can edit anything
|
||||
if user_role in ADMIN_ROLES:
|
||||
return True
|
||||
|
||||
# Owner can always edit
|
||||
if resource_creator_id == current_user_id:
|
||||
return True
|
||||
|
||||
# Organization resources are read-only for non-admins
|
||||
return False
|
||||
|
||||
|
||||
def can_delete_resource(resource_creator_id: str, current_user_id: str, user_role: str) -> bool:
|
||||
"""
|
||||
Check if user can delete a resource.
|
||||
|
||||
Rules:
|
||||
- Owner can delete their own resources
|
||||
- Admin/developer can delete any resource
|
||||
- Others cannot delete
|
||||
"""
|
||||
# Admin and developer can delete anything
|
||||
if user_role in ADMIN_ROLES:
|
||||
return True
|
||||
|
||||
# Owner can delete
|
||||
if resource_creator_id == current_user_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_effective_owner(resource_creator_id: str, current_user_id: str, user_role: str) -> bool:
|
||||
"""
|
||||
Check if user is effective owner of a resource.
|
||||
|
||||
Effective owners have identical access to actual owners:
|
||||
- Actual resource creator
|
||||
- Admin/developer users (tenant admins)
|
||||
|
||||
This determines whether user gets owner-level field visibility in ResponseFilter
|
||||
and whether they can perform owner-only actions like sharing.
|
||||
|
||||
Note: Tenant isolation is enforced at query level via tenant_id checks.
|
||||
This function only determines ownership semantics within the tenant.
|
||||
|
||||
Args:
|
||||
resource_creator_id: UUID of resource creator
|
||||
current_user_id: UUID of current user
|
||||
user_role: User's role in tenant (admin, developer, analyst, student)
|
||||
|
||||
Returns:
|
||||
True if user should have owner-level access
|
||||
|
||||
Examples:
|
||||
>>> is_effective_owner("user123", "admin456", "admin")
|
||||
True # Admin has owner-level access to all resources
|
||||
>>> is_effective_owner("user123", "user123", "student")
|
||||
True # Actual owner
|
||||
>>> is_effective_owner("user123", "user456", "analyst")
|
||||
False # Different user, not admin
|
||||
"""
|
||||
# Admins and developers have identical access to owners
|
||||
if user_role in ADMIN_ROLES:
|
||||
return True
|
||||
|
||||
# Actual owner
|
||||
return resource_creator_id == current_user_id
|
||||
498
apps/tenant-backend/app/core/postgresql_client.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
GT 2.0 PostgreSQL + PGVector Client for Tenant Backend
|
||||
|
||||
Replaces DuckDB service with direct PostgreSQL connections, providing:
|
||||
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
|
||||
- BionicGPT Row Level Security patterns for enterprise isolation
|
||||
- MVCC concurrency solving DuckDB file locking issues
|
||||
- Hybrid vector + full-text search in single queries
|
||||
- Connection pooling for 10,000+ concurrent connections
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator, Tuple, Union
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
from asyncpg import Pool, Connection
|
||||
from asyncpg.exceptions import PostgresError
|
||||
|
||||
from app.core.config import get_settings, get_tenant_schema_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PostgreSQLClient:
|
||||
"""PostgreSQL + PGVector client for tenant backend operations"""
|
||||
|
||||
def __init__(self, database_url: str, tenant_domain: str):
|
||||
self.database_url = database_url
|
||||
self.tenant_domain = tenant_domain
|
||||
self.schema_name = get_tenant_schema_name(tenant_domain)
|
||||
self._pool: Optional[Pool] = None
|
||||
self._initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize connection pool and verify schema"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info(f"Initializing PostgreSQL connection pool for tenant: {self.tenant_domain}")
|
||||
logger.info(f"Schema: {self.schema_name}, URL: {self.database_url}")
|
||||
|
||||
try:
|
||||
# Create connection pool with resilient settings
|
||||
# Sized for 100+ concurrent users with RAG/vector search workloads
|
||||
self._pool = await asyncpg.create_pool(
|
||||
self.database_url,
|
||||
min_size=10,
|
||||
max_size=50, # Increased from 20 to handle 100+ concurrent users
|
||||
command_timeout=120, # Increased from 60s for queries under load
|
||||
timeout=10, # Connection acquire timeout increased for high load
|
||||
max_inactive_connection_lifetime=3600, # Recycle connections after 1 hour
|
||||
server_settings={
|
||||
'application_name': f'gt2_tenant_{self.tenant_domain}'
|
||||
},
|
||||
# Enable prepared statements for direct postgres connection (performance gain)
|
||||
statement_cache_size=100
|
||||
)
|
||||
|
||||
# Verify schema exists and has required tables
|
||||
await self._verify_schema()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"PostgreSQL client initialized successfully for tenant: {self.tenant_domain}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL client: {e}")
|
||||
if self._pool:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close connection pool"""
|
||||
if self._pool:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
self._initialized = False
|
||||
logger.info(f"PostgreSQL connection pool closed for tenant: {self.tenant_domain}")
|
||||
|
||||
async def _verify_schema(self) -> None:
|
||||
"""Verify tenant schema exists and has required tables"""
|
||||
async with self._pool.acquire() as conn:
|
||||
# Check if schema exists
|
||||
schema_exists = await conn.fetchval("""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.schemata
|
||||
WHERE schema_name = $1
|
||||
)
|
||||
""", self.schema_name)
|
||||
|
||||
if not schema_exists:
|
||||
raise RuntimeError(f"Tenant schema '{self.schema_name}' does not exist. Run schema initialization first.")
|
||||
|
||||
# Check for required tables
|
||||
required_tables = ['tenants', 'users', 'agents', 'datasets', 'conversations', 'messages', 'documents', 'document_chunks']
|
||||
|
||||
for table in required_tables:
|
||||
table_exists = await conn.fetchval(f"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = $1 AND table_name = $2
|
||||
)
|
||||
""", self.schema_name, table)
|
||||
|
||||
if not table_exists:
|
||||
logger.warning(f"Table '{table}' not found in schema '{self.schema_name}'")
|
||||
|
||||
logger.info(f"Schema verification complete for tenant: {self.tenant_domain}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator[Connection, None]:
|
||||
"""Get a connection from the pool"""
|
||||
if not self._pool:
|
||||
raise RuntimeError("PostgreSQL client not initialized. Call initialize() first.")
|
||||
|
||||
async with self._pool.acquire() as conn:
|
||||
try:
|
||||
# Set schema search path for this connection
|
||||
await conn.execute(f"SET search_path TO {self.schema_name}, public")
|
||||
|
||||
# Session variable logging removed - no longer using RLS
|
||||
|
||||
yield conn
|
||||
except Exception as e:
|
||||
logger.error(f"Database connection error: {e}")
|
||||
raise
|
||||
|
||||
async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
|
||||
"""Execute a SELECT query and return results"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
rows = await conn.fetch(query, *args)
|
||||
return [dict(row) for row in rows]
|
||||
except PostgresError as e:
|
||||
logger.error(f"Query execution failed: {e}, Query: {query}")
|
||||
raise
|
||||
|
||||
async def execute_command(self, command: str, *args) -> int:
|
||||
"""Execute an INSERT/UPDATE/DELETE command and return affected rows"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
result = await conn.execute(command, *args)
|
||||
# Parse result like "INSERT 0 5" to get affected rows
|
||||
return int(result.split()[-1]) if result else 0
|
||||
except PostgresError as e:
|
||||
logger.error(f"Command execution failed: {e}, Command: {command}")
|
||||
raise
|
||||
|
||||
async def fetch_one(self, query: str, *args) -> Optional[Dict[str, Any]]:
|
||||
"""Execute query and return first row"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
row = await conn.fetchrow(query, *args)
|
||||
return dict(row) if row else None
|
||||
except PostgresError as e:
|
||||
logger.error(f"Fetch one failed: {e}, Query: {query}")
|
||||
raise
|
||||
|
||||
async def fetch_scalar(self, query: str, *args) -> Any:
|
||||
"""Execute query and return single value"""
|
||||
async with self.get_connection() as conn:
|
||||
try:
|
||||
return await conn.fetchval(query, *args)
|
||||
except PostgresError as e:
|
||||
logger.error(f"Fetch scalar failed: {e}, Query: {query}")
|
||||
raise
|
||||
|
||||
async def execute_transaction(self, commands: List[Tuple[str, tuple]]) -> List[int]:
|
||||
"""Execute multiple commands in a transaction"""
|
||||
async with self.get_connection() as conn:
|
||||
async with conn.transaction():
|
||||
results = []
|
||||
for command, args in commands:
|
||||
try:
|
||||
result = await conn.execute(command, *args)
|
||||
results.append(int(result.split()[-1]) if result else 0)
|
||||
except PostgresError as e:
|
||||
logger.error(f"Transaction command failed: {e}, Command: {command}")
|
||||
raise
|
||||
return results
|
||||
|
||||
# Vector Search Operations (PGVector)
|
||||
|
||||
async def vector_similarity_search(
|
||||
self,
|
||||
query_vector: List[float],
|
||||
table: str = "document_chunks",
|
||||
limit: int = 10,
|
||||
similarity_threshold: float = 0.3,
|
||||
user_id: Optional[str] = None,
|
||||
dataset_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Perform vector similarity search using PGVector"""
|
||||
|
||||
# Convert Python list to PostgreSQL array format
|
||||
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
id,
|
||||
content,
|
||||
1 - (embedding <=> $1::vector) as similarity_score,
|
||||
metadata
|
||||
FROM {table}
|
||||
WHERE embedding IS NOT NULL
|
||||
AND 1 - (embedding <=> $1::vector) > $2
|
||||
"""
|
||||
|
||||
params = [vector_str, similarity_threshold]
|
||||
param_idx = 3
|
||||
|
||||
# Add user isolation if specified
|
||||
if user_id:
|
||||
query += f" AND user_id = ${param_idx}"
|
||||
params.append(user_id)
|
||||
param_idx += 1
|
||||
|
||||
# Add dataset filtering if specified
|
||||
if dataset_id:
|
||||
query += f" AND dataset_id = ${param_idx}"
|
||||
params.append(dataset_id)
|
||||
param_idx += 1
|
||||
|
||||
query += f" ORDER BY embedding <=> $1::vector LIMIT ${param_idx}"
|
||||
params.append(limit)
|
||||
|
||||
return await self.execute_query(query, *params)
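# Illustrative call (hypothetical IDs; the embedding size is an assumption and must match
# the stored vectors): returns chunks above the similarity threshold, nearest first.
client = await get_postgresql_client()
results = await client.vector_similarity_search(
    query_vector=[0.01] * 1536,
    limit=5,
    similarity_threshold=0.3,
    user_id="3f9c0d2e-0000-0000-0000-000000000000",
)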
|
||||
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query_text: str,
|
||||
query_vector: List[float],
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
similarity_threshold: float = 0.3,
|
||||
text_weight: float = 0.3,
|
||||
vector_weight: float = 0.7,
|
||||
dataset_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Perform hybrid search combining vector similarity and full-text search"""
|
||||
|
||||
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
|
||||
|
||||
# Use the enhanced_hybrid_search_chunks function from BionicGPT integration
|
||||
query = """
|
||||
SELECT
|
||||
id,
|
||||
document_id,
|
||||
content,
|
||||
similarity_score,
|
||||
text_rank,
|
||||
combined_score,
|
||||
metadata,
|
||||
access_verified
|
||||
FROM enhanced_hybrid_search_chunks($1, $2::vector, $3::uuid, $4, $5, $6, $7, $8)
|
||||
"""
|
||||
|
||||
return await self.execute_query(
|
||||
query,
|
||||
query_text,
|
||||
vector_str,
|
||||
user_id,
|
||||
dataset_id,
|
||||
limit,
|
||||
similarity_threshold,
|
||||
text_weight,
|
||||
vector_weight
|
||||
)
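# Illustrative call (hypothetical values; client obtained via get_postgresql_client() as above):
# text_weight and vector_weight set the relative contribution of full-text rank and
# vector similarity inside enhanced_hybrid_search_chunks.
hits = await client.hybrid_search(
    query_text="quarterly revenue",
    query_vector=[0.01] * 1536,  # dimensionality assumed
    user_id="3f9c0d2e-0000-0000-0000-000000000000",
    limit=10,
    text_weight=0.3,
    vector_weight=0.7,
)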
|
||||
|
||||
async def insert_document_chunk(
|
||||
self,
|
||||
document_id: str,
|
||||
tenant_id: int,
|
||||
user_id: str,
|
||||
chunk_index: int,
|
||||
content: str,
|
||||
content_hash: str,
|
||||
embedding: List[float],
|
||||
dataset_id: Optional[str] = None,
|
||||
token_count: int = 0,
|
||||
metadata: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""Insert a document chunk with vector embedding"""
|
||||
|
||||
vector_str = '[' + ','.join(map(str, embedding)) + ']'
|
||||
metadata_json = json.dumps(metadata or {})
|
||||
|
||||
query = """
|
||||
INSERT INTO document_chunks (
|
||||
document_id, tenant_id, user_id, dataset_id, chunk_index,
|
||||
content, content_hash, token_count, embedding, metadata
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10::jsonb)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
return await self.fetch_scalar(
|
||||
query,
|
||||
document_id, tenant_id, user_id, dataset_id, chunk_index,
|
||||
content, content_hash, token_count, vector_str, metadata_json
|
||||
)
|
||||
|
||||
# Health Check and Statistics
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check on PostgreSQL connection"""
|
||||
try:
|
||||
if not self._pool:
|
||||
return {"status": "unhealthy", "reason": "Connection pool not initialized"}
|
||||
|
||||
# Test basic connectivity
|
||||
test_result = await self.fetch_scalar("SELECT 1")
|
||||
|
||||
# Get pool statistics
|
||||
pool_stats = {
|
||||
"size": self._pool.get_size(),
|
||||
"min_size": self._pool.get_min_size(),
|
||||
"max_size": self._pool.get_max_size(),
|
||||
"idle_size": self._pool.get_idle_size()
|
||||
}
|
||||
|
||||
# Test schema access
|
||||
schema_test = await self.fetch_scalar("""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM information_schema.schemata
|
||||
WHERE schema_name = $1
|
||||
)
|
||||
""", self.schema_name)
|
||||
|
||||
return {
|
||||
"status": "healthy" if test_result == 1 and schema_test else "degraded",
|
||||
"connectivity": "ok" if test_result == 1 else "failed",
|
||||
"schema_access": "ok" if schema_test else "failed",
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"schema_name": self.schema_name,
|
||||
"pool_stats": pool_stats,
|
||||
"database_type": "postgresql_pgvector"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL health check failed: {e}")
|
||||
return {"status": "unhealthy", "reason": str(e)}
|
||||
|
||||
async def get_database_stats(self) -> Dict[str, Any]:
|
||||
"""Get database statistics for monitoring"""
|
||||
try:
|
||||
# Get table counts and sizes
|
||||
stats_query = """
|
||||
SELECT
|
||||
schemaname,
|
||||
tablename,
|
||||
n_tup_ins as inserts,
|
||||
n_tup_upd as updates,
|
||||
n_tup_del as deletes,
|
||||
n_live_tup as live_tuples,
|
||||
n_dead_tup as dead_tuples
|
||||
FROM pg_stat_user_tables
|
||||
WHERE schemaname = $1
|
||||
"""
|
||||
|
||||
table_stats = await self.execute_query(stats_query, self.schema_name)
|
||||
|
||||
# Get total schema size
|
||||
size_query = """
|
||||
SELECT pg_size_pretty(
|
||||
SUM(pg_total_relation_size(quote_ident(schemaname)||'.'||quote_ident(tablename)))
|
||||
) as schema_size
|
||||
FROM pg_tables
|
||||
WHERE schemaname = $1
|
||||
"""
|
||||
|
||||
schema_size = await self.fetch_scalar(size_query, self.schema_name)
|
||||
|
||||
# Get vector index statistics if available
|
||||
vector_stats_query = """
|
||||
SELECT
|
||||
COUNT(*) as vector_count,
|
||||
AVG(vector_dims(embedding)) as avg_dimensions
|
||||
FROM document_chunks
|
||||
WHERE embedding IS NOT NULL
|
||||
"""
|
||||
|
||||
try:
|
||||
vector_stats = await self.fetch_one(vector_stats_query)
|
||||
except Exception:
|
||||
vector_stats = {"vector_count": 0, "avg_dimensions": 0}
|
||||
|
||||
return {
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"schema_name": self.schema_name,
|
||||
"schema_size": schema_size,
|
||||
"table_stats": table_stats,
|
||||
"vector_stats": vector_stats,
|
||||
"engine_type": "PostgreSQL + PGVector",
|
||||
"mvcc_enabled": True,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database statistics: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Global client instance (singleton pattern for tenant backend)
|
||||
_pg_client: Optional[PostgreSQLClient] = None
|
||||
|
||||
|
||||
async def get_postgresql_client() -> PostgreSQLClient:
|
||||
"""Get or create PostgreSQL client instance"""
|
||||
global _pg_client
|
||||
|
||||
if not _pg_client:
|
||||
settings = get_settings()
|
||||
_pg_client = PostgreSQLClient(
|
||||
database_url=settings.database_url,
|
||||
tenant_domain=settings.tenant_domain
|
||||
)
|
||||
await _pg_client.initialize()
|
||||
|
||||
return _pg_client
|
||||
|
||||
|
||||
async def init_postgresql() -> None:
|
||||
"""Initialize PostgreSQL client during startup"""
|
||||
logger.info("Initializing PostgreSQL client...")
|
||||
await get_postgresql_client()
|
||||
logger.info("PostgreSQL client initialized successfully")
|
||||
|
||||
|
||||
async def close_postgresql() -> None:
|
||||
"""Close PostgreSQL client during shutdown"""
|
||||
global _pg_client
|
||||
|
||||
if _pg_client:
|
||||
await _pg_client.close()
|
||||
_pg_client = None
|
||||
logger.info("PostgreSQL client closed")
|
||||
|
||||
|
||||
# Context manager for database operations
|
||||
@asynccontextmanager
|
||||
async def get_db_session():
|
||||
"""Async context manager for database operations"""
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
# Convenience functions for common operations
|
||||
async def execute_query(query: str, *args) -> List[Dict[str, Any]]:
|
||||
"""Execute a SELECT query"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.execute_query(query, *args)
|
||||
|
||||
|
||||
async def execute_command(command: str, *args) -> int:
|
||||
"""Execute an INSERT/UPDATE/DELETE command"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.execute_command(command, *args)
|
||||
|
||||
|
||||
async def fetch_one(query: str, *args) -> Optional[Dict[str, Any]]:
|
||||
"""Execute query and return first row"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.fetch_one(query, *args)
|
||||
|
||||
|
||||
async def fetch_scalar(query: str, *args) -> Any:
|
||||
"""Execute query and return single value"""
|
||||
client = await get_postgresql_client()
|
||||
return await client.fetch_scalar(query, *args)
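# Illustrative usage (hypothetical query and tenant_id): the module-level helpers reuse the
# singleton client, so callers never manage the pool directly.
total_agents = await fetch_scalar("SELECT COUNT(*) FROM agents WHERE tenant_id = $1", tenant_id)
recent = await execute_query("SELECT id, name FROM agents ORDER BY created_at DESC LIMIT $1", 20)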
|
||||
|
||||
|
||||
async def health_check() -> Dict[str, Any]:
|
||||
"""Perform database health check"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
return await client.health_check()
|
||||
except Exception as e:
|
||||
return {"status": "unhealthy", "reason": str(e)}
|
||||
|
||||
|
||||
async def get_database_info() -> Dict[str, Any]:
|
||||
"""Get database information and statistics"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
return await client.get_database_stats()
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
531
apps/tenant-backend/app/core/resource_client.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
Resource Cluster Client for GT 2.0 Tenant Backend
|
||||
|
||||
Provides stateless access to Resource Cluster services including:
|
||||
- Document processing
|
||||
- Embedding generation
|
||||
- Vector storage (ChromaDB)
|
||||
- Model inference
|
||||
|
||||
Perfect tenant isolation with capability-based authentication.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import gc
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.capability_client import CapabilityClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResourceClusterClient:
|
||||
"""
|
||||
Client for accessing Resource Cluster services with capability-based auth.
|
||||
|
||||
GT 2.0 Security Principles:
|
||||
- Capability tokens for fine-grained access control
|
||||
- Stateless operations (no data persistence in Resource Cluster)
|
||||
- Perfect tenant isolation
|
||||
- Immediate memory cleanup
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.capability_client = CapabilityClient()
|
||||
|
||||
# Resource Cluster endpoints
|
||||
# IMPORTANT: Use Docker service name for stability across container restarts
|
||||
# Fixed 2025-09-12: Changed from hardcoded IP to service name for reliability
|
||||
self.base_url = getattr(
|
||||
self.settings,
|
||||
'resource_cluster_url', # Matches Pydantic field name (case insensitive)
|
||||
'http://gentwo-resource-backend:8000' # Fallback uses service name, not IP
|
||||
)
|
||||
|
||||
self.endpoints = {
|
||||
'document_processor': f"{self.base_url}/api/v1/process/document",
|
||||
'embedding_generator': f"{self.base_url}/api/v1/embeddings/generate",
|
||||
'chromadb_backend': f"{self.base_url}/api/v1/vectors",
|
||||
'inference': f"{self.base_url}/api/v1/ai/chat/completions" # Updated to match actual endpoint
|
||||
}
|
||||
|
||||
# Request timeouts
|
||||
self.request_timeout = 300 # seconds - 5 minutes for complex agent operations
|
||||
self.upload_timeout = 300 # seconds for large documents
|
||||
|
||||
logger.info("Resource Cluster client initialized")
|
||||
|
||||
async def _get_capability_token(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
resources: List[str]
|
||||
) -> str:
|
||||
"""Generate capability token for Resource Cluster access"""
|
||||
try:
|
||||
token = await self.capability_client.generate_capability_token(
|
||||
user_email=user_id, # Using user_id as email for now
|
||||
tenant_id=tenant_id,
|
||||
resources=resources,
|
||||
expires_hours=1
|
||||
)
|
||||
return token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate capability token: {e}")
|
||||
raise
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
data: Dict[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
resources: List[str],
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make authenticated request to Resource Cluster"""
|
||||
try:
|
||||
# Get capability token
|
||||
token = await self._get_capability_token(tenant_id, user_id, resources)
|
||||
|
||||
# Prepare headers
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {token}',
|
||||
'X-Tenant-ID': tenant_id,
|
||||
'X-User-ID': user_id,
|
||||
'X-Request-ID': f"{tenant_id}_{user_id}_{datetime.utcnow().timestamp()}"
|
||||
}
|
||||
|
||||
# Make request
|
||||
timeout_config = aiohttp.ClientTimeout(total=timeout or self.request_timeout)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout_config) as session:
|
||||
async with session.request(
|
||||
method=method.upper(),
|
||||
url=endpoint,
|
||||
json=data,
|
||||
headers=headers
|
||||
) as response:
|
||||
|
||||
if response.status not in [200, 201]:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(
|
||||
f"Resource Cluster error: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Resource Cluster request failed: {e}")
|
||||
raise
|
||||
|
||||
# Document Processing
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
content: bytes,
|
||||
document_type: str,
|
||||
strategy_type: str = "hybrid",
|
||||
tenant_id: str = None,
|
||||
user_id: str = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process document into chunks via Resource Cluster"""
|
||||
try:
|
||||
# Convert bytes to base64 for JSON transport
|
||||
import base64
|
||||
content_b64 = base64.b64encode(content).decode('utf-8')
|
||||
|
||||
request_data = {
|
||||
"content": content_b64,
|
||||
"document_type": document_type,
|
||||
"strategy": {
|
||||
"strategy_type": strategy_type,
|
||||
"chunk_size": 512,
|
||||
"chunk_overlap": 128
|
||||
}
|
||||
}
|
||||
|
||||
# Clear original content from memory
|
||||
del content
|
||||
gc.collect()
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['document_processor'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['document_processing'],
|
||||
timeout=self.upload_timeout
|
||||
)
|
||||
|
||||
chunks = result.get('chunks', [])
|
||||
logger.info(f"Processed document into {len(chunks)} chunks")
|
||||
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Document processing failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
# Embedding Generation
|
||||
|
||||
async def generate_document_embeddings(
|
||||
self,
|
||||
documents: List[str],
|
||||
tenant_id: str,
|
||||
user_id: str
|
||||
) -> List[List[float]]:
|
||||
"""Generate embeddings for documents"""
|
||||
try:
|
||||
request_data = {
|
||||
"texts": documents,
|
||||
"model": "BAAI/bge-m3",
|
||||
"instruction": None # Document embeddings don't need instruction
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['embedding_generator'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['embedding_generation']
|
||||
)
|
||||
|
||||
embeddings = result.get('embeddings', [])
|
||||
|
||||
# Clear documents from memory
|
||||
del documents
|
||||
gc.collect()
|
||||
|
||||
logger.info(f"Generated {len(embeddings)} document embeddings")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Document embedding generation failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def generate_query_embeddings(
|
||||
self,
|
||||
queries: List[str],
|
||||
tenant_id: str,
|
||||
user_id: str
|
||||
) -> List[List[float]]:
|
||||
"""Generate embeddings for queries"""
|
||||
try:
|
||||
request_data = {
|
||||
"texts": queries,
|
||||
"model": "BAAI/bge-m3",
|
||||
"instruction": "Represent this sentence for searching relevant passages: "
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['embedding_generator'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['embedding_generation']
|
||||
)
|
||||
|
||||
embeddings = result.get('embeddings', [])
|
||||
|
||||
# Clear queries from memory
|
||||
del queries
|
||||
gc.collect()
|
||||
|
||||
logger.info(f"Generated {len(embeddings)} query embeddings")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Query embedding generation failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
# Vector Storage (ChromaDB)
|
||||
|
||||
async def create_vector_collection(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""Create vector collection in ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name,
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/collections",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
success = result.get('success', False)
|
||||
logger.info(f"Created vector collection for {dataset_name}: {success}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector collection creation failed: {e}")
|
||||
raise
|
||||
|
||||
async def store_vectors(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str,
|
||||
documents: List[str],
|
||||
embeddings: List[List[float]],
|
||||
metadata: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None
|
||||
) -> bool:
|
||||
"""Store vectors in ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name,
|
||||
"documents": documents,
|
||||
"embeddings": embeddings,
|
||||
"metadata": metadata or [],
|
||||
"ids": ids
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/store",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
# Clear vectors from memory immediately
|
||||
del documents, embeddings
|
||||
gc.collect()
|
||||
|
||||
success = result.get('success', False)
|
||||
logger.info(f"Stored vectors in {dataset_name}: {success}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector storage failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def search_vectors(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str,
|
||||
query_embedding: List[float],
|
||||
top_k: int = 5,
|
||||
filter_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search vectors in ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name,
|
||||
"query_embedding": query_embedding,
|
||||
"top_k": top_k,
|
||||
"filter_metadata": filter_metadata or {}
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/search",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
# Clear query embedding from memory
|
||||
del query_embedding
|
||||
gc.collect()
|
||||
|
||||
results = result.get('results', [])
|
||||
logger.info(f"Found {len(results)} vector search results")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector search failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def delete_vector_collection(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
dataset_name: str
|
||||
) -> bool:
|
||||
"""Delete vector collection from ChromaDB"""
|
||||
try:
|
||||
request_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"dataset_name": dataset_name
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='DELETE',
|
||||
endpoint=f"{self.endpoints['chromadb_backend']}/collections",
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['vector_storage']
|
||||
)
|
||||
|
||||
success = result.get('success', False)
|
||||
logger.info(f"Deleted vector collection {dataset_name}: {success}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector collection deletion failed: {e}")
|
||||
raise
|
||||
|
||||
# Model Inference
|
||||
|
||||
async def inference_with_context(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
context: str,
|
||||
model: str = "llama-3.1-70b-versatile",
|
||||
tenant_id: str = None,
|
||||
user_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Perform inference with RAG context"""
|
||||
try:
|
||||
# Inject context into system message
|
||||
enhanced_messages = []
|
||||
system_context = f"Use the following context to answer the user's question:\n\n{context}\n\n"
|
||||
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
enhanced_msg = msg.copy()
|
||||
enhanced_msg["content"] = system_context + enhanced_msg["content"]
|
||||
enhanced_messages.append(enhanced_msg)
|
||||
else:
|
||||
enhanced_messages.append(msg)
|
||||
|
||||
# Add system message if none exists
|
||||
if not any(msg.get("role") == "system" for msg in enhanced_messages):
|
||||
enhanced_messages.insert(0, {
|
||||
"role": "system",
|
||||
"content": system_context + "You are a helpful AI agent."
|
||||
})
|
||||
|
||||
request_data = {
|
||||
"messages": enhanced_messages,
|
||||
"model": model,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4000,
|
||||
"user_id": user_id,
|
||||
"tenant_id": tenant_id
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=self.endpoints['inference'],
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['llm_inference']
|
||||
)
|
||||
|
||||
# Clear context from memory
|
||||
del context, enhanced_messages
|
||||
gc.collect()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Inference with context failed: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def check_health(self) -> Dict[str, Any]:
|
||||
"""Check Resource Cluster health"""
|
||||
try:
|
||||
# Test basic connectivity
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.base_url}/health") as response:
|
||||
if response.status == 200:
|
||||
health_data = await response.json()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"resource_cluster": health_data,
|
||||
"endpoints": list(self.endpoints.keys()),
|
||||
"base_url": self.base_url
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": f"Health check failed: {response.status}",
|
||||
"base_url": self.base_url
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"base_url": self.base_url
|
||||
}
|
||||
|
||||
async def call_inference_endpoint(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
endpoint: str = "chat/completions",
|
||||
data: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Call AI inference endpoint on Resource Cluster"""
|
||||
try:
|
||||
# Use the direct inference endpoint
|
||||
inference_url = self.endpoints['inference']
|
||||
|
||||
# Add tenant/user context to request
|
||||
request_data = data.copy() if data else {}
|
||||
|
||||
# Make request with capability token
|
||||
result = await self._make_request(
|
||||
method='POST',
|
||||
endpoint=inference_url,
|
||||
data=request_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
resources=['llm'] # Use valid ResourceType from resource cluster
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Inference endpoint call failed: {e}")
|
||||
raise
|
||||
|
||||
# Streaming removed for reliability - using non-streaming only
|
||||
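# Example: end-to-end RAG flow through ResourceClusterClient (illustrative only).
# Tenant/user identifiers, the dataset name, and the "text"/"document" result keys
# are placeholders assumed for this sketch; error handling is omitted for brevity.
async def _example_rag_flow(pdf_bytes: bytes) -> Dict[str, Any]:
    client = ResourceClusterClient()
    tenant_id, user_id = "tenant_test", "user@example.com"

    # 1. Chunk the document on the Resource Cluster (content travels base64-encoded)
    chunks = await client.process_document(
        content=pdf_bytes, document_type="pdf", tenant_id=tenant_id, user_id=user_id
    )
    texts = [chunk.get("text", "") for chunk in chunks]

    # 2. Embed the chunks (no instruction prefix) and store them in a tenant-scoped collection
    await client.create_vector_collection(tenant_id, user_id, dataset_name="demo_dataset")
    doc_vectors = await client.generate_document_embeddings(list(texts), tenant_id, user_id)
    await client.store_vectors(
        tenant_id, user_id, "demo_dataset", documents=list(texts), embeddings=doc_vectors
    )

    # 3. Embed the query with the retrieval instruction and search the collection
    query = "What does the document say about pricing?"
    query_vec = (await client.generate_query_embeddings([query], tenant_id, user_id))[0]
    hits = await client.search_vectors(tenant_id, user_id, "demo_dataset", query_vec, top_k=5)

    # 4. Feed retrieved text back into chat completion as grounded context
    context = "\n\n".join(hit.get("document", "") for hit in hits)
    return await client.inference_with_context(
        messages=[{"role": "user", "content": query}],
        context=context, tenant_id=tenant_id, user_id=user_id
    )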
320
apps/tenant-backend/app/core/response_filter.py
Normal file
320
apps/tenant-backend/app/core/response_filter.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Response Filtering Utilities for GT 2.0
|
||||
|
||||
Provides field-level authorization and data filtering for API responses.
|
||||
Implements principle of least privilege - users only see data they're authorized to access.
|
||||
|
||||
Security principles:
|
||||
1. Owner-only fields: resource_preferences, advanced RAG configs (max_chunks_per_query, history_context)
|
||||
2. Viewer fields: Public + usage stats + prompt_template + personality_config + dataset connections
|
||||
(Team members with read access need these fields to effectively use shared agents)
|
||||
3. Public fields: id, name, description, category, basic metadata
|
||||
4. No internal UUIDs, implementation details, or system configuration exposure
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseFilter:
|
||||
"""Filter API responses based on user permissions and access level"""
|
||||
|
||||
# Define field access levels for agents
|
||||
# REQUIRED fields that must always be present for AgentResponse schema
|
||||
AGENT_REQUIRED_FIELDS = {
|
||||
'id', 'name', 'description', 'created_at', 'updated_at'
|
||||
}
|
||||
|
||||
AGENT_PUBLIC_FIELDS = AGENT_REQUIRED_FIELDS | {
|
||||
'category', 'conversation_count', 'usage_count', 'is_favorite', 'tags',
|
||||
'created_by_name', 'can_edit', 'can_delete', 'is_owner',
|
||||
# Include these for display purposes
|
||||
'model', 'visibility', 'disclaimer', 'easy_prompts',
|
||||
# Dataset connections for showing dataset count on agent tiles
|
||||
'dataset_connection', 'selected_dataset_ids'
|
||||
}
|
||||
|
||||
AGENT_VIEWER_FIELDS = AGENT_PUBLIC_FIELDS | {
|
||||
'temperature', 'max_tokens', 'total_cost_cents', 'template_id',
|
||||
# Essential fields for using shared agents (team collaboration)
|
||||
'prompt_template', 'personality_config',
|
||||
'dataset_connection', 'selected_dataset_ids'
|
||||
}
|
||||
|
||||
AGENT_OWNER_FIELDS = AGENT_VIEWER_FIELDS | {
|
||||
# Advanced configuration fields (owner-only)
|
||||
'resource_preferences', 'max_chunks_per_query', 'history_context',
|
||||
# Team sharing configuration (owner-only for editing)
|
||||
'team_shares'
|
||||
}
|
||||
|
||||
# Define field access levels for datasets
|
||||
# Fields for all users (public/shared datasets) - stats are informational, not sensitive
|
||||
DATASET_PUBLIC_FIELDS = {
|
||||
'id', 'name', 'description', 'created_by_name', 'owner_name',
|
||||
'document_count', 'chunk_count', 'vector_count', 'storage_size_mb',
|
||||
'tags', 'created_at', 'updated_at', 'access_group',
|
||||
# Permission flags for UI controls
|
||||
'is_owner', 'can_edit', 'can_delete', 'can_share',
|
||||
# Team sharing flag for proper visibility indicators
|
||||
'shared_via_team'
|
||||
}
|
||||
|
||||
DATASET_VIEWER_FIELDS = DATASET_PUBLIC_FIELDS | {
|
||||
'summary' # Viewers can see dataset summary
|
||||
}
|
||||
|
||||
DATASET_OWNER_FIELDS = DATASET_VIEWER_FIELDS | {
|
||||
# Only owners see internal configuration
|
||||
'owner_id', 'team_members', 'chunking_strategy', 'chunk_size',
|
||||
'chunk_overlap', 'embedding_model', 'summary_generated_at',
|
||||
# Team sharing configuration (owner-only for editing)
|
||||
'team_shares'
|
||||
}
|
||||
|
||||
# Define field access levels for files
|
||||
# Public fields include processing info since it's informational metadata, not sensitive
|
||||
FILE_PUBLIC_FIELDS = {
|
||||
'id', 'original_filename', 'content_type', 'file_type', 'file_size', 'file_size_bytes',
|
||||
'created_at', 'updated_at', 'category',
|
||||
# Processing fields - informational, not sensitive
|
||||
'processing_status', 'chunk_count', 'processing_progress', 'processing_stage',
|
||||
# Permission flags for UI controls
|
||||
'can_delete'
|
||||
}
|
||||
|
||||
FILE_OWNER_FIELDS = FILE_PUBLIC_FIELDS | {
|
||||
'user_id', 'dataset_id', 'storage_path', 'metadata'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def filter_agent_response(
|
||||
agent_data: Dict[str, Any],
|
||||
is_owner: bool = False,
|
||||
can_view: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter agent response fields based on user permissions
|
||||
|
||||
Args:
|
||||
agent_data: Full agent data dictionary
|
||||
is_owner: Whether user owns this agent
|
||||
can_view: Whether user can view detailed information
|
||||
|
||||
Returns:
|
||||
Filtered dictionary with only authorized fields
|
||||
"""
|
||||
if is_owner:
|
||||
allowed_fields = ResponseFilter.AGENT_OWNER_FIELDS
|
||||
logger.info(f"🔓 Agent '{agent_data.get('name', 'Unknown')}': Using OWNER fields (is_owner=True, can_view={can_view})")
|
||||
elif can_view:
|
||||
allowed_fields = ResponseFilter.AGENT_VIEWER_FIELDS
|
||||
logger.info(f"👁️ Agent '{agent_data.get('name', 'Unknown')}': Using VIEWER fields (is_owner=False, can_view=True)")
|
||||
else:
|
||||
allowed_fields = ResponseFilter.AGENT_PUBLIC_FIELDS
|
||||
logger.info(f"🌍 Agent '{agent_data.get('name', 'Unknown')}': Using PUBLIC fields (is_owner=False, can_view=False)")
|
||||
|
||||
filtered = {
|
||||
key: value for key, value in agent_data.items()
|
||||
if key in allowed_fields
|
||||
}
|
||||
|
||||
# Ensure defaults for optional fields that were filtered out
|
||||
# This prevents AgentResponse schema validation errors
|
||||
default_values = {
|
||||
'personality_config': {},
|
||||
'resource_preferences': {},
|
||||
'tags': [],
|
||||
'easy_prompts': [],
|
||||
'conversation_count': 0,
|
||||
'usage_count': 0,
|
||||
'total_cost_cents': 0,
|
||||
'is_favorite': False,
|
||||
'can_edit': False,
|
||||
'can_delete': False,
|
||||
'is_owner': is_owner
|
||||
}
|
||||
|
||||
for key, default_value in default_values.items():
|
||||
if key not in filtered:
|
||||
filtered[key] = default_value
|
||||
|
||||
# Log field filtering for security audit
|
||||
removed_fields = set(agent_data.keys()) - set(filtered.keys())
|
||||
if removed_fields:
|
||||
logger.info(
|
||||
f"🔒 Filtered agent '{agent_data.get('name', 'Unknown')}' - removed fields: {removed_fields} "
|
||||
f"(is_owner={is_owner}, can_view={can_view})"
|
||||
)
|
||||
|
||||
# Special logging for prompt_template field
|
||||
if 'prompt_template' in agent_data:
|
||||
if 'prompt_template' in filtered:
|
||||
logger.info(f"✅ Agent '{agent_data.get('name', 'Unknown')}': prompt_template INCLUDED in response")
|
||||
else:
|
||||
logger.warning(f"❌ Agent '{agent_data.get('name', 'Unknown')}': prompt_template FILTERED OUT (is_owner={is_owner}, can_view={can_view})")
|
||||
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def filter_dataset_response(
|
||||
dataset_data: Dict[str, Any],
|
||||
is_owner: bool = False,
|
||||
can_view: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter dataset response fields based on user permissions
|
||||
|
||||
Args:
|
||||
dataset_data: Full dataset data dictionary
|
||||
is_owner: Whether user owns this dataset
|
||||
can_view: Whether user can view the dataset
|
||||
|
||||
Returns:
|
||||
Filtered dictionary with only authorized fields
|
||||
"""
|
||||
if is_owner:
|
||||
allowed_fields = ResponseFilter.DATASET_OWNER_FIELDS
|
||||
elif can_view:
|
||||
allowed_fields = ResponseFilter.DATASET_VIEWER_FIELDS
|
||||
else:
|
||||
allowed_fields = ResponseFilter.DATASET_PUBLIC_FIELDS
|
||||
|
||||
filtered = {
|
||||
key: value for key, value in dataset_data.items()
|
||||
if key in allowed_fields
|
||||
}
|
||||
|
||||
# Security: Never expose owner_id UUID to non-owners
|
||||
if not is_owner and 'owner_id' in filtered:
|
||||
del filtered['owner_id']
|
||||
|
||||
# Ensure defaults for optional fields to prevent schema validation errors
|
||||
default_values = {
|
||||
'tags': [],
|
||||
'is_owner': is_owner,
|
||||
'can_edit': False,
|
||||
'can_delete': False,
|
||||
'can_share': False,
|
||||
# Always set these to None for non-owners (security)
|
||||
'team_members': None if not is_owner else filtered.get('team_members', []),
|
||||
'owner_id': None if not is_owner else filtered.get('owner_id'),
|
||||
# Internal fields - null for all except detail view
|
||||
'agent_has_access': None,
|
||||
'user_owns': None,
|
||||
# Stats fields - use actual values or safe defaults for frontend compatibility
|
||||
# These are informational only, not sensitive
|
||||
'chunk_count': filtered.get('chunk_count', 0),
|
||||
'vector_count': filtered.get('vector_count', 0),
|
||||
'storage_size_mb': filtered.get('storage_size_mb', 0.0),
|
||||
'updated_at': filtered.get('updated_at'),
|
||||
'summary': None
|
||||
}
|
||||
|
||||
for key, default_value in default_values.items():
|
||||
if key not in filtered:
|
||||
filtered[key] = default_value
|
||||
|
||||
# Log field filtering for security audit
|
||||
removed_fields = set(dataset_data.keys()) - set(filtered.keys())
|
||||
if removed_fields:
|
||||
logger.debug(
|
||||
f"Filtered dataset response - removed fields: {removed_fields} "
|
||||
f"(is_owner={is_owner}, can_view={can_view})"
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def filter_file_response(
|
||||
file_data: Dict[str, Any],
|
||||
is_owner: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter file response fields based on user permissions
|
||||
|
||||
Args:
|
||||
file_data: Full file data dictionary
|
||||
is_owner: Whether user owns this file
|
||||
|
||||
Returns:
|
||||
Filtered dictionary with only authorized fields
|
||||
"""
|
||||
allowed_fields = (
|
||||
ResponseFilter.FILE_OWNER_FIELDS if is_owner
|
||||
else ResponseFilter.FILE_PUBLIC_FIELDS
|
||||
)
|
||||
|
||||
filtered = {
|
||||
key: value for key, value in file_data.items()
|
||||
if key in allowed_fields
|
||||
}
|
||||
|
||||
# Log field filtering for security audit
|
||||
removed_fields = set(file_data.keys()) - set(filtered.keys())
|
||||
if removed_fields:
|
||||
logger.debug(
|
||||
f"Filtered file response - removed fields: {removed_fields} "
|
||||
f"(is_owner={is_owner})"
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def filter_batch_responses(
|
||||
items: List[Dict[str, Any]],
|
||||
filter_func: callable,
|
||||
ownership_map: Optional[Dict[str, bool]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Filter a batch of items using the provided filter function
|
||||
|
||||
Args:
|
||||
items: List of items to filter
|
||||
filter_func: Function to apply to each item (e.g., filter_agent_response)
|
||||
ownership_map: Optional map of item_id -> is_owner boolean
|
||||
|
||||
Returns:
|
||||
List of filtered items
|
||||
"""
|
||||
filtered_items = []
|
||||
|
||||
for item in items:
|
||||
item_id = item.get('id')
|
||||
is_owner = ownership_map.get(item_id, False) if ownership_map else False
|
||||
|
||||
filtered_item = filter_func(item, is_owner=is_owner)
|
||||
filtered_items.append(filtered_item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
@staticmethod
|
||||
def sanitize_dataset_summary(
|
||||
summary_data: Dict[str, Any],
|
||||
user_can_access: bool = True
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Sanitize dataset summary for inclusion in chat context
|
||||
|
||||
Args:
|
||||
summary_data: Dataset summary with metadata
|
||||
user_can_access: Whether user should have access to this dataset
|
||||
|
||||
Returns:
|
||||
Sanitized summary or None if user shouldn't access
|
||||
"""
|
||||
if not user_can_access:
|
||||
return None
|
||||
|
||||
# Only include safe fields in summary
|
||||
safe_fields = {
|
||||
'id', 'name', 'description', 'summary',
|
||||
'document_count', 'chunk_count'
|
||||
}
|
||||
|
||||
return {
|
||||
key: value for key, value in summary_data.items()
|
||||
if key in safe_fields
|
||||
}
|
||||
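# Example: applying the filters above to a list endpoint. Ownership is assumed to
# come from an earlier query mapping record ids to the requesting user.
def _example_filter_agents(agents: List[Dict[str, Any]], owned_ids: Set[str]) -> List[Dict[str, Any]]:
    ownership_map = {agent["id"]: agent["id"] in owned_ids for agent in agents}
    return ResponseFilter.filter_batch_responses(
        items=agents,
        filter_func=ResponseFilter.filter_agent_response,
        ownership_map=ownership_map,
    )

# A single record can also be filtered directly; owners keep advanced configuration
# such as resource_preferences, while viewers only receive the collaboration subset:
#   filtered = ResponseFilter.filter_agent_response(agent_dict, is_owner=False, can_view=True)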
314
apps/tenant-backend/app/core/security.py
Normal file
314
apps/tenant-backend/app/core/security.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
Security module for GT 2.0 Tenant Backend
|
||||
|
||||
Provides JWT capability token verification and user authentication.
|
||||
"""
|
||||
|
||||
import os
|
||||
import jwt
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import Header
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_jwt_secret() -> str:
|
||||
"""Get JWT secret from environment variable.
|
||||
|
||||
The JWT_SECRET is auto-generated by installers using:
|
||||
openssl rand -hex 32
|
||||
|
||||
This provides a 256-bit secret suitable for HS256 signing.
|
||||
"""
|
||||
secret = os.environ.get('JWT_SECRET')
|
||||
if not secret:
|
||||
raise ValueError("JWT_SECRET environment variable is required. Run the installer to generate one.")
|
||||
return secret
|
||||
|
||||
|
||||
def verify_capability_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT capability token using HS256 symmetric key
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Token payload if valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
secret = get_jwt_secret()
|
||||
|
||||
# Verify token with HS256 symmetric key
|
||||
payload = jwt.decode(token, secret, algorithms=["HS256"])
|
||||
|
||||
# Check expiration
|
||||
if "exp" in payload:
|
||||
if datetime.utcnow().timestamp() > payload["exp"]:
|
||||
logger.warning("Token expired")
|
||||
return None
|
||||
|
||||
return payload
|
||||
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.warning(f"Invalid token: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_capability_token(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
capabilities: list,
|
||||
expires_hours: int = 4
|
||||
) -> str:
|
||||
"""
|
||||
Create JWT capability token using HS256 symmetric key
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tenant_id: Tenant domain
|
||||
capabilities: List of capability objects
|
||||
expires_hours: Token expiration in hours
|
||||
|
||||
Returns:
|
||||
JWT token string
|
||||
"""
|
||||
try:
|
||||
secret = get_jwt_secret()
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": user_id,
|
||||
"user_type": "tenant_user",
|
||||
|
||||
# Current tenant context (primary structure)
|
||||
"current_tenant": {
|
||||
"id": tenant_id,
|
||||
"domain": tenant_id,
|
||||
"name": f"Tenant {tenant_id}",
|
||||
"role": "tenant_user",
|
||||
"display_name": user_id,
|
||||
"email": user_id,
|
||||
"is_primary": True,
|
||||
"capabilities": capabilities
|
||||
},
|
||||
|
||||
# Available tenants for tenant switching
|
||||
"available_tenants": [{
|
||||
"id": tenant_id,
|
||||
"domain": tenant_id,
|
||||
"name": f"Tenant {tenant_id}",
|
||||
"role": "tenant_user"
|
||||
}],
|
||||
|
||||
# Standard JWT fields
|
||||
"iat": datetime.utcnow().timestamp(),
|
||||
"exp": (datetime.utcnow() + timedelta(hours=expires_hours)).timestamp()
|
||||
}
|
||||
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create capability token: {e}")
|
||||
raise ValueError("Failed to create capability token")
|
||||
|
||||
|
||||
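# Example (illustrative): round-trip of the two helpers above. JWT_SECRET must be set
# in the environment; the capability entry mirrors the resource/actions shape used elsewhere.
def _example_token_roundtrip() -> bool:
    token = create_capability_token(
        user_id="user@example.com",
        tenant_id="tenant_test",
        capabilities=[{"resource": "agents", "actions": ["read"], "constraints": {}}],
        expires_hours=1,
    )
    payload = verify_capability_token(token)
    return bool(payload) and payload["current_tenant"]["domain"] == "tenant_test"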
async def get_current_user(authorization: str = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current user from authorization header - REQUIRED for all endpoints
|
||||
Raises 401 if authentication fails - following GT 2.0 security principles
|
||||
"""
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Extract token
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Extract tenant context from new JWT structure
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
available_tenants = payload.get('available_tenants', [])
|
||||
user_type = payload.get('user_type', 'tenant_user')
|
||||
|
||||
# For admin users, allow access to any tenant backend
|
||||
if user_type == 'super_admin' and current_tenant.get('domain') == 'admin':
|
||||
# Admin users accessing tenant backends - create tenant context for the current backend
|
||||
from app.core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
# Override the admin context with the current tenant backend's context
|
||||
current_tenant = {
|
||||
'id': settings.tenant_id,
|
||||
'domain': settings.tenant_domain,
|
||||
'name': f'Tenant {settings.tenant_domain}',
|
||||
'role': 'super_admin',
|
||||
'display_name': payload.get('email', 'Admin User'),
|
||||
'email': payload.get('email'),
|
||||
'is_primary': True,
|
||||
'capabilities': [
|
||||
{'resource': '*', 'actions': ['*'], 'constraints': {}},
|
||||
]
|
||||
}
|
||||
logger.info(f"Admin user {payload.get('email')} accessing tenant backend {settings.tenant_domain}")
|
||||
|
||||
# Validate tenant context exists
|
||||
if not current_tenant or not current_tenant.get('id'):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No valid tenant context in token",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Return user dict with clean tenant context structure
|
||||
return {
|
||||
'sub': payload.get('sub'),
|
||||
'email': payload.get('email'),
|
||||
'user_id': payload.get('sub'),
|
||||
'user_type': payload.get('user_type', 'tenant_user'),
|
||||
|
||||
# Current tenant context (primary structure)
|
||||
'tenant_id': str(current_tenant.get('id')),
|
||||
'tenant_domain': current_tenant.get('domain'),
|
||||
'tenant_name': current_tenant.get('name'),
|
||||
'tenant_role': current_tenant.get('role'),
|
||||
'tenant_display_name': current_tenant.get('display_name'),
|
||||
'tenant_email': current_tenant.get('email'),
|
||||
'is_primary_tenant': current_tenant.get('is_primary', False),
|
||||
|
||||
# Tenant-specific capabilities
|
||||
'capabilities': current_tenant.get('capabilities', []),
|
||||
|
||||
# Available tenants for tenant switching
|
||||
'available_tenants': available_tenants
|
||||
}
|
||||
|
||||
|
||||
def get_current_user_email(authorization: str) -> str:
|
||||
"""
|
||||
Extract user email from authorization header
|
||||
"""
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
if payload:
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
# Prefer tenant-specific email, fallback to user email, then sub
|
||||
return (current_tenant.get('email') or
|
||||
payload.get('email') or
|
||||
payload.get('sub', 'test@example.com'))
|
||||
|
||||
return 'anonymous@example.com'
|
||||
|
||||
|
||||
def get_tenant_info(authorization: str) -> Dict[str, str]:
|
||||
"""
|
||||
Extract tenant information from authorization header
|
||||
"""
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
if payload:
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
if current_tenant:
|
||||
return {
|
||||
'tenant_id': str(current_tenant.get('id')),
|
||||
'tenant_domain': current_tenant.get('domain'),
|
||||
'tenant_name': current_tenant.get('name'),
|
||||
'tenant_role': current_tenant.get('role')
|
||||
}
|
||||
|
||||
return {
|
||||
'tenant_id': 'default',
|
||||
'tenant_domain': 'default',
|
||||
'tenant_name': 'Default Tenant',
|
||||
'tenant_role': 'tenant_user'
|
||||
}
|
||||
|
||||
|
||||
def verify_jwt_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT token - alias for verify_capability_token
|
||||
"""
|
||||
return verify_capability_token(token)
|
||||
|
||||
|
||||
async def get_user_context_unified(
|
||||
authorization: Optional[str] = Header(None),
|
||||
x_tenant_domain: Optional[str] = Header(None),
|
||||
x_user_id: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Unified authentication for both JWT (user requests) and header-based (service requests).
|
||||
|
||||
Supports two auth modes:
|
||||
1. JWT Authentication: Authorization header with Bearer token (for direct user requests)
|
||||
2. Header Authentication: X-Tenant-Domain + X-User-ID headers (for internal service requests)
|
||||
|
||||
Returns user context with tenant information for both modes.
|
||||
"""
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
# Mode 1: Header-based authentication (for internal services like MCP)
|
||||
if x_tenant_domain and x_user_id:
|
||||
logger.info(f"Using header auth: tenant={x_tenant_domain}, user={x_user_id}")
|
||||
return {
|
||||
"tenant_domain": x_tenant_domain,
|
||||
"tenant_id": x_tenant_domain,
|
||||
"id": x_user_id,
|
||||
"sub": x_user_id,
|
||||
"email": x_user_id,
|
||||
"user_id": x_user_id,
|
||||
"user_type": "internal_service",
|
||||
"tenant_role": "tenant_user"
|
||||
}
|
||||
|
||||
# Mode 2: JWT authentication (for direct user requests)
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
payload = verify_capability_token(token)
|
||||
|
||||
if payload:
|
||||
logger.info(f"Using JWT auth: user={payload.get('sub')}")
|
||||
# Extract tenant context from JWT structure
|
||||
current_tenant = payload.get('current_tenant', {})
|
||||
return {
|
||||
'sub': payload.get('sub'),
|
||||
'email': payload.get('email'),
|
||||
'user_id': payload.get('sub'),
|
||||
'id': payload.get('sub'),
|
||||
'user_type': payload.get('user_type', 'tenant_user'),
|
||||
'tenant_id': str(current_tenant.get('id', 'default')),
|
||||
'tenant_domain': current_tenant.get('domain', 'default'),
|
||||
'tenant_name': current_tenant.get('name', 'Default Tenant'),
|
||||
'tenant_role': current_tenant.get('role', 'tenant_user'),
|
||||
'capabilities': current_tenant.get('capabilities', [])
|
||||
}
|
||||
|
||||
# No valid authentication provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication: provide either Authorization header or X-Tenant-Domain + X-User-ID headers"
|
||||
)
|
||||
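# The function above accepts two authentication modes. Illustrative request shapes:
#
#   JWT mode (direct user traffic):
#       Authorization: Bearer <capability token>
#
#   Header mode (internal services such as MCP):
#       X-Tenant-Domain: tenant_test
#       X-User-ID: user@example.com
#
# Typical FastAPI usage as a dependency (Depends imported from fastapi):
#
#   @router.get("/agents")
#   async def list_agents(user: Dict[str, Any] = Depends(get_user_context_unified)):
#       return {"tenant": user["tenant_domain"], "user": user["email"]}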
165
apps/tenant-backend/app/core/user_resolver.py
Normal file
165
apps/tenant-backend/app/core/user_resolver.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
User UUID Resolution Utilities for GT 2.0
|
||||
|
||||
Handles email-to-UUID resolution across all services to ensure
|
||||
consistent user identification in database operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_user_uuid(current_user: Dict[str, Any]) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Resolve user email to UUID for internal services.
|
||||
|
||||
Args:
|
||||
current_user: User data from JWT token
|
||||
|
||||
Returns:
|
||||
Tuple of (tenant_domain, user_email, user_uuid)
|
||||
|
||||
Raises:
|
||||
HTTPException: If UUID resolution fails
|
||||
"""
|
||||
tenant_domain = current_user.get("tenant_domain", "test")
|
||||
user_email = current_user["email"]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
|
||||
user_uuid = await get_tenant_user_uuid_by_email(user_email)
|
||||
|
||||
if not user_uuid:
|
||||
logger.error(f"Failed to resolve UUID for user {user_email} in tenant {tenant_domain}")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"User {user_email} not found in tenant system"
|
||||
)
|
||||
|
||||
logger.info(f"✅ Resolved user {user_email} to UUID: {user_uuid}")
|
||||
return tenant_domain, user_email, user_uuid
|
||||
|
||||
|
||||
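# Example: typical use inside an endpoint, where the JWT carries an email but tenant
# tables key rows by UUID. The INSERT statement and table name are placeholders.
async def _example_create_dataset(current_user: Dict[str, Any], name: str) -> Dict[str, str]:
    tenant_domain, user_email, user_uuid = await resolve_user_uuid(current_user)
    from app.core.postgresql_client import execute_command
    await execute_command(
        "INSERT INTO datasets (name, owner_id) VALUES ($1, $2::uuid)", name, user_uuid
    )
    return {"owner": user_email, "tenant": tenant_domain}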
async def ensure_user_uuid(email_or_uuid: str, tenant_domain: Optional[str] = None) -> str:
|
||||
"""
|
||||
Ensure we have a UUID, converting email if needed.
|
||||
|
||||
Args:
|
||||
email_or_uuid: Either an email address or UUID string
|
||||
tenant_domain: Tenant domain for lookup context
|
||||
|
||||
Returns:
|
||||
UUID string
|
||||
|
||||
Raises:
|
||||
ValueError: If email cannot be resolved to UUID or input is invalid
|
||||
"""
|
||||
import uuid
|
||||
import re
|
||||
|
||||
# Validate input is not empty or None
|
||||
if not email_or_uuid or not isinstance(email_or_uuid, str):
|
||||
raise ValueError(f"Invalid user identifier: {email_or_uuid}")
|
||||
|
||||
email_or_uuid = email_or_uuid.strip()
|
||||
|
||||
# Check if it's an email
|
||||
if "@" in email_or_uuid:
|
||||
# It's an email, resolve to UUID
|
||||
from app.api.auth import get_tenant_user_uuid_by_email
|
||||
|
||||
user_uuid = await get_tenant_user_uuid_by_email(email_or_uuid)
|
||||
|
||||
if not user_uuid:
|
||||
error_msg = f"Cannot resolve email {email_or_uuid} to UUID"
|
||||
if tenant_domain:
|
||||
error_msg += f" in tenant {tenant_domain}"
|
||||
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.debug(f"Resolved email {email_or_uuid} to UUID: {user_uuid}")
|
||||
return user_uuid
|
||||
|
||||
# Check if it's a valid UUID format
|
||||
try:
|
||||
uuid_obj = uuid.UUID(email_or_uuid)
|
||||
return str(uuid_obj) # Return normalized UUID string
|
||||
except (ValueError, TypeError):
|
||||
# Not a valid UUID, could be a numeric ID or other format
|
||||
pass
|
||||
|
||||
# Handle numeric user IDs or other legacy formats
|
||||
if email_or_uuid.isdigit():
|
||||
logger.warning(f"Received numeric user ID '{email_or_uuid}', attempting database lookup")
|
||||
# Try to resolve numeric ID to proper UUID via database
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
tenant_schema = f"tenant_{tenant_domain.replace('.', '_').replace('-', '_')}" if tenant_domain else "tenant_test"
|
||||
|
||||
# Try to find user by numeric ID (assuming it might be a legacy ID)
|
||||
user_row = await conn.fetchrow(
|
||||
f"SELECT id FROM {tenant_schema}.users WHERE id::text = $1 OR email = $1 LIMIT 1",
|
||||
email_or_uuid
|
||||
)
|
||||
|
||||
if user_row:
|
||||
return str(user_row['id'])
|
||||
|
||||
# If not found, try finding the first user (fallback for development)
|
||||
logger.warning(f"User '{email_or_uuid}' not found, using first available user as fallback")
|
||||
first_user = await conn.fetchrow(f"SELECT id FROM {tenant_schema}.users LIMIT 1")
|
||||
|
||||
if first_user:
|
||||
logger.info(f"Using fallback user UUID: {first_user['id']}")
|
||||
return str(first_user['id'])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database lookup failed for user '{email_or_uuid}': {e}")
|
||||
|
||||
# If all else fails, raise an error
|
||||
error_msg = f"Cannot resolve user identifier '{email_or_uuid}' to UUID. Expected email or valid UUID format."
|
||||
if tenant_domain:
|
||||
error_msg += f" Tenant: {tenant_domain}"
|
||||
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def get_user_sql_clause(param_num: int, user_identifier: str) -> str:
|
||||
"""
|
||||
Get the appropriate SQL clause for user identification.
|
||||
|
||||
Args:
|
||||
param_num: Parameter number in SQL query (e.g., 3 for $3)
|
||||
user_identifier: Either email or UUID
|
||||
|
||||
Returns:
|
||||
SQL clause string for user lookup
|
||||
"""
|
||||
if "@" in user_identifier:
|
||||
# Email - do lookup
|
||||
return f"(SELECT id FROM users WHERE email = ${param_num} LIMIT 1)"
|
||||
else:
|
||||
# UUID - use directly
|
||||
return f"${param_num}::uuid"
|
||||
|
||||
|
||||
def is_uuid_format(identifier: str) -> bool:
|
||||
"""
|
||||
Check if a string looks like a UUID.
|
||||
|
||||
Args:
|
||||
identifier: String to check
|
||||
|
||||
Returns:
|
||||
True if looks like UUID, False if looks like email
|
||||
"""
|
||||
return "@" not in identifier and len(identifier) == 36 and identifier.count("-") == 4
|
||||
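# Example: building a lookup that works whether the caller passed an email, a UUID,
# or a legacy numeric id. The agents table and owner_id column are illustrative.
async def _example_lookup_agents(user_identifier: str):
    from app.core.postgresql_client import execute_query
    if not is_uuid_format(user_identifier) and "@" not in user_identifier:
        user_identifier = await ensure_user_uuid(user_identifier)   # resolve legacy numeric ids
    clause = get_user_sql_clause(1, user_identifier)                # "$1::uuid" or an email sub-select
    return await execute_query(
        f"SELECT id, name FROM agents WHERE owner_id = {clause}", user_identifier
    )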
403
apps/tenant-backend/app/main.py
Normal file
403
apps/tenant-backend/app/main.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend - Main Application Entry Point
|
||||
|
||||
This is the customer-facing API server that provides:
|
||||
- AI chat interface with WebSocket support
|
||||
- Document upload and processing
|
||||
- User authentication and session management
|
||||
- Perfect tenant isolation with file-based databases
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
import uvicorn
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import init_database as startup_database, close_database as shutdown_database
|
||||
from app.core.logging_config import setup_logging
|
||||
# Import models to ensure they're registered with the Base metadata
|
||||
# TEMPORARY: Commented out SQLAlchemy-based models during PostgreSQL migration
|
||||
# from app.models import workflow, agent, conversation, message, document
|
||||
from app.api.auth import router as auth_router
|
||||
# from app.api.agents import router as assistants_router # Legacy: replaced with agents_router
|
||||
# Import the migrated PostgreSQL-based conversations API
|
||||
from app.api.v1.conversations import router as conversations_router
|
||||
# from app.api.messages import router as messages_router
|
||||
from app.api.v1.documents import router as documents_router
|
||||
# from app.api.websocket import router as websocket_router
|
||||
# from app.api.events import router as events_router
|
||||
from app.api.v1.agents import router as agents_router
|
||||
# from app.api.v1.games import router as games_router
|
||||
# from app.api.v1.external_services import router as external_services_router
|
||||
# assistants_enhanced module removed - using agents terminology only
|
||||
from app.api.v1.rag_visualization import router as rag_visualization_router
|
||||
# from app.api.v1.dataset_sharing import router as dataset_sharing_router
|
||||
from app.api.v1.datasets import router as datasets_router
|
||||
from app.api.v1.chat import router as chat_router
|
||||
# from app.api.v1.workflows import router as workflows_router
|
||||
from app.api.v1.models import router as models_router
|
||||
from app.api.v1.files import router as files_router
|
||||
from app.api.v1.search import router as search_router
|
||||
from app.api.v1.users import router as users_router
|
||||
from app.api.v1.observability import router as observability_router
|
||||
from app.api.v1.teams import router as teams_router
|
||||
from app.api.v1.auth_logs import router as auth_logs_router
|
||||
from app.api.v1.categories import router as categories_router
|
||||
from app.middleware.tenant_isolation import TenantIsolationMiddleware
|
||||
from app.middleware.security import SecurityHeadersMiddleware
|
||||
from app.middleware.rate_limiting import RateLimitMiddleware
|
||||
from app.middleware.oauth2_auth import OAuth2AuthMiddleware
|
||||
from app.middleware.session_validation import SessionValidationMiddleware
|
||||
from app.services.message_bus_client import initialize_message_bus, message_bus_client
|
||||
|
||||
# Configure logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = get_settings()
|
||||
start_time = time.time() # Track service startup time for metrics
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan management"""
|
||||
logger.info("Starting GT 2.0 Tenant Backend...")
|
||||
|
||||
# Initialize database connections
|
||||
await startup_database()
|
||||
logger.info("PostgreSQL + PGVector database connection initialized")
|
||||
|
||||
# Initialize message bus for admin communication
|
||||
try:
|
||||
message_bus_connected = await initialize_message_bus()
|
||||
if message_bus_connected:
|
||||
logger.info("Message bus connected - admin communication enabled")
|
||||
else:
|
||||
logger.warning("Message bus connection failed - admin communication disabled")
|
||||
except Exception as e:
|
||||
logger.error(f"Message bus initialization error: {e}")
|
||||
|
||||
# Load BGE-M3 configuration from Control Panel database on startup
|
||||
try:
|
||||
import httpx
|
||||
control_panel_url = os.getenv('CONTROL_PANEL_BACKEND_URL', 'http://control-panel-backend:8000')
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Fetch BGE-M3 configuration from Control Panel
|
||||
response = await client.get(f"{control_panel_url}/api/v1/models/BAAI%2Fbge-m3")
|
||||
|
||||
if response.status_code == 200:
|
||||
model_config = response.json()
|
||||
config = model_config.get('config', {})
|
||||
is_local_mode = config.get('is_local_mode', True)
|
||||
external_endpoint = config.get('external_endpoint')
|
||||
|
||||
# Update embedding client with database configuration
|
||||
from app.services.embedding_client import get_embedding_client
|
||||
embedding_client = get_embedding_client()
|
||||
|
||||
if is_local_mode:
|
||||
new_endpoint = os.getenv('EMBEDDING_ENDPOINT', 'http://host.docker.internal:8005')
|
||||
else:
|
||||
new_endpoint = external_endpoint if external_endpoint else 'http://host.docker.internal:8005'
|
||||
|
||||
embedding_client.update_endpoint(new_endpoint)
|
||||
|
||||
# Update environment variables for consistency
|
||||
os.environ['BGE_M3_LOCAL_MODE'] = str(is_local_mode).lower()
|
||||
if external_endpoint:
|
||||
os.environ['BGE_M3_EXTERNAL_ENDPOINT'] = external_endpoint
|
||||
|
||||
logger.info(f"BGE-M3 configuration loaded from database: is_local_mode={is_local_mode}, endpoint={new_endpoint}")
|
||||
else:
|
||||
logger.warning(f"Failed to load BGE-M3 configuration from Control Panel (status {response.status_code}), using defaults")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load BGE-M3 configuration from Control Panel: {e}, using defaults")
|
||||
|
||||
# Log configuration
|
||||
logger.info(f"Environment: {settings.environment}")
|
||||
logger.info(f"Tenant ID: {settings.tenant_id}")
|
||||
logger.info(f"Database URL: {settings.database_url}")
|
||||
logger.info(f"PostgreSQL Schema: {settings.postgres_schema}")
|
||||
logger.info(f"Resource cluster URL: {settings.resource_cluster_url}")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup on shutdown
|
||||
logger.info("Shutting down GT 2.0 Tenant Backend...")
|
||||
|
||||
# Disconnect message bus
|
||||
try:
|
||||
await message_bus_client.disconnect()
|
||||
logger.info("Message bus disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting message bus: {e}")
|
||||
|
||||
await shutdown_database()
|
||||
logger.info("PostgreSQL database connections closed")
|
||||
|
||||
|
||||
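# The Control Panel payload parsed above is assumed to look roughly like:
#   {"id": "BAAI/bge-m3",
#    "config": {"is_local_mode": false,
#               "external_endpoint": "http://embeddings.internal:8005"}}
# When is_local_mode is true (the default), the EMBEDDING_ENDPOINT environment
# variable or the host.docker.internal fallback is used instead of external_endpoint.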
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
title="GT 2.0 Tenant Backend",
|
||||
description="Customer-facing API for GT 2.0 Enterprise AI Platform",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs" if settings.environment == "development" else None,
|
||||
redoc_url="/redoc" if settings.environment == "development" else None,
|
||||
redirect_slashes=False, # Disable redirects - Next.js proxy can't follow internal Docker URLs
|
||||
)
|
||||
|
||||
# Security Middleware
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware,
|
||||
allowed_hosts=settings.allowed_hosts
|
||||
)
|
||||
|
||||
# OAuth2 Authentication Middleware (temporarily disabled for development)
|
||||
# app.add_middleware(OAuth2AuthMiddleware, require_auth=settings.require_oauth2_auth)
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
app.add_middleware(TenantIsolationMiddleware)
|
||||
# Session validation middleware for OWASP/NIST compliance (Issue #264)
|
||||
app.add_middleware(SessionValidationMiddleware)
|
||||
|
||||
# CORS Middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["X-Session-Warning", "X-Session-Expired"], # Issue #264: Expose session headers to frontend
|
||||
)
|
||||
|
||||
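# Note on ordering: Starlette applies add_middleware() calls with the most recently
# added layer outermost, so the CORS middleware above wraps the security, rate-limiting,
# tenant-isolation and session-validation layers and sees each request first.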
# API Routes
|
||||
app.include_router(auth_router, prefix="/api/v1")
|
||||
# app.include_router(assistants_router, prefix="/api/v1") # Legacy: replaced with agents_router
|
||||
app.include_router(conversations_router) # Already has prefix
|
||||
# app.include_router(messages_router, prefix="/api/v1")
|
||||
app.include_router(documents_router, prefix="/api/v1")
|
||||
# app.include_router(events_router, prefix="/api/v1/events")
|
||||
app.include_router(agents_router, prefix="/api/v1")
|
||||
# app.include_router(games_router, prefix="/api/v1")
|
||||
# app.include_router(external_services_router, prefix="/api/v1/external-services")
|
||||
from app.api.websocket import router as websocket_router
|
||||
from app.api.embeddings import router as embeddings_router
|
||||
from app.websocket.manager import socket_app
|
||||
app.include_router(websocket_router, prefix="/ws")
|
||||
app.include_router(embeddings_router, prefix="/api/embeddings")
|
||||
|
||||
# Enhanced API Routes for GT 2.0 comprehensive agent platform
|
||||
# assistants_enhanced module removed - architecture now uses agents only
|
||||
# TEMPORARY: Commented out during PostgreSQL migration
|
||||
app.include_router(rag_visualization_router) # Already has /api/v1/rag/visualization prefix
|
||||
app.include_router(datasets_router) # Already has /api/v1/datasets prefix
|
||||
app.include_router(chat_router) # Already has /api/v1/chat prefix
|
||||
# app.include_router(dataset_sharing_router, prefix="/api/v1/datasets") # Dataset sharing endpoints
|
||||
# app.include_router(workflows_router) # Already has /api/v1/workflows prefix
|
||||
app.include_router(models_router) # Already has /api/v1/models prefix
|
||||
app.include_router(files_router, prefix="/api/v1") # Files upload/download API
|
||||
app.include_router(search_router) # Already has /api/v1/search prefix
|
||||
app.include_router(users_router, prefix="/api/v1") # User preferences and favorite agents
|
||||
app.include_router(observability_router, prefix="/api/v1") # Observability dashboard (admin-only)
|
||||
app.include_router(teams_router, prefix="/api/v1") # Team collaboration and resource sharing
|
||||
app.include_router(auth_logs_router, prefix="/api/v1") # Authentication logs for security monitoring (Issue #152)
|
||||
app.include_router(categories_router) # Agent categories CRUD (Issue #215) - already has /api/v1/categories prefix
|
||||
|
||||
# Note: Socket.IO integration moved to composite ASGI router to prevent protocol conflicts
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint for load balancer and Kubernetes"""
|
||||
# Import here to avoid circular imports
|
||||
from app.core.database import health_check as db_health_check
|
||||
|
||||
try:
|
||||
db_health = await db_health_check()
|
||||
is_healthy = db_health.get("status") == "healthy"
|
||||
|
||||
# codeql[py/stack-trace-exposure] returns health status dict, not error details
|
||||
return {
|
||||
"status": "healthy" if is_healthy else "degraded",
|
||||
"service": "gt2-tenant-backend",
|
||||
"version": "1.0.0",
|
||||
"tenant_id": settings.tenant_id,
|
||||
"environment": settings.environment,
|
||||
"database": db_health,
|
||||
"postgresql_pgvector": True
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"service": "gt2-tenant-backend",
|
||||
"version": "1.0.0",
|
||||
"error": "Health check failed",
|
||||
"database": {"status": "failed"}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
async def api_health_check():
|
||||
"""API v1 health check endpoint for frontend compatibility"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "gt2-tenant-backend",
|
||||
"version": "1.0.0",
|
||||
"tenant_id": settings.tenant_id,
|
||||
"environment": settings.environment,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/ready")
|
||||
async def ready_check():
|
||||
"""Kubernetes readiness probe endpoint"""
|
||||
return {
|
||||
"status": "ready",
|
||||
"service": "tenant-backend",
|
||||
"timestamp": datetime.utcnow(),
|
||||
"health": "ok"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/metrics")
|
||||
async def metrics(request: Request):
|
||||
"""Prometheus metrics endpoint"""
|
||||
try:
|
||||
# Basic metrics for now - in production would use prometheus_client
|
||||
import psutil
|
||||
import time
|
||||
|
||||
# Be permissive with Accept headers for monitoring tools
|
||||
# Most legitimate monitoring tools will accept text/plain or send */*
|
||||
accept_header = request.headers.get("accept", "text/plain")
|
||||
if (accept_header and
|
||||
accept_header != "text/plain" and
|
||||
not any(pattern in accept_header.lower() for pattern in [
|
||||
"text/plain", "text/*", "*/*", "application/openmetrics-text",
|
||||
"application/json", "text/html" # Common but non-metrics requests
|
||||
])):
|
||||
# Only return 400 for truly incompatible Accept headers
|
||||
logger.warning(f"Metrics endpoint received unsupported Accept header: {accept_header}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Unsupported media type. Metrics endpoint supports text/plain."
|
||||
)
|
||||
|
||||
# Get basic system metrics with error handling
|
||||
try:
|
||||
cpu_percent = psutil.cpu_percent(interval=0.1) # Reduced interval to avoid blocking
|
||||
except Exception:
|
||||
cpu_percent = 0.0
|
||||
|
||||
try:
|
||||
memory = psutil.virtual_memory()
|
||||
except Exception:
|
||||
# Fallback values if psutil fails
|
||||
memory = type('Memory', (), {'used': 0, 'available': 0})()
|
||||
|
||||
metrics_data = f"""# HELP tenant_backend_cpu_usage_percent CPU usage percentage
|
||||
# TYPE tenant_backend_cpu_usage_percent gauge
|
||||
tenant_backend_cpu_usage_percent {cpu_percent}
|
||||
|
||||
# HELP tenant_backend_memory_usage_bytes Memory usage in bytes
|
||||
# TYPE tenant_backend_memory_usage_bytes gauge
|
||||
tenant_backend_memory_usage_bytes {memory.used}
|
||||
|
||||
# HELP tenant_backend_memory_available_bytes Available memory in bytes
|
||||
# TYPE tenant_backend_memory_available_bytes gauge
|
||||
tenant_backend_memory_available_bytes {memory.available}
|
||||
|
||||
# HELP tenant_backend_uptime_seconds Service uptime in seconds
|
||||
# TYPE tenant_backend_uptime_seconds counter
|
||||
tenant_backend_uptime_seconds {time.time() - start_time}
|
||||
|
||||
# HELP tenant_backend_requests_total Total HTTP requests
|
||||
# TYPE tenant_backend_requests_total counter
|
||||
tenant_backend_requests_total 1
|
||||
"""
|
||||
|
||||
return Response(content=metrics_data, media_type="text/plain; version=0.0.4; charset=utf-8")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log the error but return basic metrics to avoid breaking monitoring
|
||||
logger.error(f"Error generating metrics: {e}")
|
||||
|
||||
# Return minimal metrics on error
import time  # local import: the try block may have failed before its own 'import time' ran
|
||||
fallback_metrics = f"""# HELP tenant_backend_uptime_seconds Service uptime in seconds
|
||||
# TYPE tenant_backend_uptime_seconds counter
|
||||
tenant_backend_uptime_seconds {time.time() - start_time}
|
||||
|
||||
# HELP tenant_backend_errors_total Total errors
|
||||
# TYPE tenant_backend_errors_total counter
|
||||
tenant_backend_errors_total 1
|
||||
"""
|
||||
return Response(content=fallback_metrics, media_type="text/plain")
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
"""Custom HTTP exception handler"""
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": {
|
||||
"message": exc.detail,
|
||||
"code": exc.status_code,
|
||||
"type": "http_error"
|
||||
},
|
||||
"request_id": getattr(request.state, "request_id", None),
|
||||
"timestamp": "2024-01-01T00:00:00Z" # TODO: Use actual timestamp
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""General exception handler for unhandled errors"""
|
||||
logger.error(f"Unhandled error: {str(exc)}", exc_info=True)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"message": "Internal server error",
|
||||
"code": 500,
|
||||
"type": "internal_error"
|
||||
},
|
||||
"request_id": getattr(request.state, "request_id", None),
|
||||
"timestamp": "2024-01-01T00:00:00Z" # TODO: Use actual timestamp
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Create composite ASGI application for Socket.IO + FastAPI coexistence
|
||||
from app.core.asgi_router import create_composite_asgi_app
|
||||
|
||||
# Create the composite application that routes between FastAPI and Socket.IO
|
||||
composite_app = create_composite_asgi_app(app, socket_app)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Development server
|
||||
uvicorn.run(
|
||||
"app.main:composite_app",
|
||||
host="0.0.0.0",
|
||||
port=8002,
|
||||
reload=True if settings.environment == "development" else False,
|
||||
log_level="info",
|
||||
access_log=True,
|
||||
)
|
||||
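# --- Illustrative sketch (not part of the commit): exercising the endpoints above ---
# Assumes the tenant backend is reachable at http://localhost:8002, matching the
# uvicorn.run() call above; the base URL is an assumption, not a documented value.
import httpx

def check_backend_health(base_url: str = "http://localhost:8002") -> None:
    """Hit the health and metrics endpoints and print a short summary."""
    with httpx.Client(timeout=5.0) as client:
        health = client.get(f"{base_url}/health").json()
        print("status:", health.get("status"),
              "database:", health.get("database", {}).get("status"))

        # The metrics endpoint returns Prometheus text exposition format.
        metrics = client.get(f"{base_url}/metrics", headers={"Accept": "text/plain"})
        for line in metrics.text.splitlines():
            if line.startswith("tenant_backend_"):
                print(line)

if __name__ == "__main__":
    check_backend_health()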
5
apps/tenant-backend/app/middleware/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend Middleware
|
||||
|
||||
Security and isolation middleware for tenant applications.
|
||||
"""
|
||||
385
apps/tenant-backend/app/middleware/oauth2_auth.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
OAuth2 Authentication Middleware for GT 2.0 Tenant Backend
|
||||
|
||||
Handles OAuth2 authentication headers from OAuth2 Proxy and extracts
|
||||
user information for tenant isolation and access control.
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
from urllib.parse import unquote
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuth2AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to handle OAuth2 authentication from OAuth2 Proxy.
|
||||
|
||||
Extracts user information from OAuth2 Proxy headers and sets
|
||||
user context for downstream handlers.
|
||||
"""
|
||||
|
||||
# Routes that don't require authentication
|
||||
EXEMPT_PATHS = {
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/health",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/refresh",
|
||||
"/api/v1/auth/logout"
|
||||
}
|
||||
|
||||
def __init__(self, app, require_auth: bool = True):
|
||||
super().__init__(app)
|
||||
self.require_auth = require_auth
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process OAuth2 authentication headers"""
|
||||
|
||||
# Skip authentication for exempt paths
|
||||
if request.url.path in self.EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# Try OAuth2 headers first, then fallback to JWT token authentication
|
||||
user_info = self._extract_oauth2_headers(request)
|
||||
|
||||
# If no OAuth2 headers found, try JWT token authentication
|
||||
if not user_info:
|
||||
user_info = await self._extract_jwt_user(request)
|
||||
|
||||
if self.require_auth and not user_info:
|
||||
logger.warning(f"Authentication required but no valid OAuth2 headers found for {request.url.path}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Set user context in request state
|
||||
if user_info:
|
||||
request.state.user = user_info
|
||||
request.state.authenticated = True
|
||||
logger.info(f"Authenticated user: {user_info.get('email', 'unknown')} for {request.url.path}")
|
||||
else:
|
||||
request.state.user = None
|
||||
request.state.authenticated = False
|
||||
|
||||
# Continue with request processing
|
||||
response = await call_next(request)
|
||||
|
||||
# Add authentication-related headers to response
|
||||
if user_info:
|
||||
response.headers["X-Authenticated-User"] = user_info.get("email", "unknown")
|
||||
response.headers["X-Auth-Source"] = user_info.get("auth_source", "oauth2-proxy")
|
||||
|
||||
return response
|
||||
|
||||
def _extract_oauth2_headers(self, request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract user information from OAuth2 Proxy headers.
|
||||
|
||||
OAuth2 Proxy sets the following headers:
|
||||
- X-Auth-Request-User: Username/email
|
||||
- X-Auth-Request-Email: User email
|
||||
- X-Auth-Request-Access-Token: Access token
|
||||
- Authorization: Bearer token (if configured)
|
||||
"""
|
||||
|
||||
# Extract user information from OAuth2 Proxy headers
|
||||
user_email = request.headers.get("X-Auth-Request-Email")
|
||||
user_name = request.headers.get("X-Auth-Request-User")
|
||||
access_token = request.headers.get("X-Auth-Request-Access-Token")
|
||||
|
||||
# Also check Authorization header for bearer token
|
||||
auth_header = request.headers.get("Authorization")
|
||||
bearer_token = None
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
bearer_token = auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
if not user_email and not user_name:
|
||||
logger.debug("No OAuth2 authentication headers found")
|
||||
return None
|
||||
|
||||
user_info = {
|
||||
"email": user_email,
|
||||
"username": user_name or user_email,
|
||||
"access_token": access_token,
|
||||
"bearer_token": bearer_token,
|
||||
"auth_source": "oauth2-proxy",
|
||||
"authenticated_at": request.headers.get("X-Auth-Request-Timestamp"),
|
||||
}
|
||||
|
||||
# Extract additional user attributes if present
|
||||
if groups_header := request.headers.get("X-Auth-Request-Groups"):
|
||||
try:
|
||||
# Groups might be base64 encoded or comma-separated
|
||||
if self._is_base64(groups_header):
|
||||
groups_decoded = base64.b64decode(groups_header).decode('utf-8')
|
||||
user_info["groups"] = json.loads(groups_decoded)
|
||||
else:
|
||||
user_info["groups"] = groups_header.split(",")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
logger.warning(f"Failed to decode groups header: {e}")
|
||||
user_info["groups"] = []
|
||||
|
||||
# Extract user roles if present
|
||||
if roles_header := request.headers.get("X-Auth-Request-Roles"):
|
||||
try:
|
||||
if self._is_base64(roles_header):
|
||||
roles_decoded = base64.b64decode(roles_header).decode('utf-8')
|
||||
user_info["roles"] = json.loads(roles_decoded)
|
||||
else:
|
||||
user_info["roles"] = roles_header.split(",")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
logger.warning(f"Failed to decode roles header: {e}")
|
||||
user_info["roles"] = []
|
||||
|
||||
# Extract tenant information from headers or JWT token
|
||||
tenant_id = self._extract_tenant_info(request, user_info)
|
||||
if tenant_id:
|
||||
user_info["tenant_id"] = tenant_id
|
||||
|
||||
return user_info
|
||||
|
||||
def _extract_tenant_info(self, request: Request, user_info: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
Extract tenant information from request headers or JWT token.
|
||||
|
||||
Tenant information can come from:
|
||||
1. X-Tenant-ID header (set by load balancer based on domain)
|
||||
2. JWT token claims
|
||||
3. Domain name parsing
|
||||
"""
|
||||
|
||||
# Check for explicit tenant header
|
||||
if tenant_header := request.headers.get("X-Tenant-ID"):
|
||||
return tenant_header
|
||||
|
||||
# Extract tenant from domain name
|
||||
host = request.headers.get("Host", "")
|
||||
if host and "." in host:
|
||||
# Assume format: tenant.gt2.com
|
||||
potential_tenant = host.split(".")[0]
|
||||
if potential_tenant != "www" and potential_tenant != "api":
|
||||
return potential_tenant
|
||||
|
||||
# Try to extract from JWT token if present
|
||||
if bearer_token := user_info.get("bearer_token"):
|
||||
tenant_from_jwt = self._extract_tenant_from_jwt(bearer_token)
|
||||
if tenant_from_jwt:
|
||||
return tenant_from_jwt
|
||||
|
||||
logger.warning(f"Could not determine tenant for user {user_info.get('email', 'unknown')}")
|
||||
return None
|
||||
|
||||
def _extract_tenant_from_jwt(self, token: str) -> Optional[str]:
|
||||
"""
|
||||
Extract tenant information from JWT token without verifying signature.
|
||||
|
||||
Note: This is just for extracting claims, not for security validation.
|
||||
Security validation should be done by OAuth2 Proxy.
|
||||
"""
|
||||
try:
|
||||
# Split JWT token (header.payload.signature)
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
# Decode payload (add padding if needed)
|
||||
payload = parts[1]
|
||||
# Add padding if needed for base64 decoding
|
||||
payload += "=" * (4 - len(payload) % 4)
|
||||
|
||||
decoded_payload = base64.urlsafe_b64decode(payload)
|
||||
claims = json.loads(decoded_payload)
|
||||
|
||||
# Look for tenant in various claim fields
|
||||
tenant_claims = ["tenant_id", "tenant", "org_id", "organization"]
|
||||
for claim in tenant_claims:
|
||||
if claim in claims:
|
||||
return str(claims[claim])
|
||||
|
||||
except (json.JSONDecodeError, UnicodeDecodeError, ValueError) as e:
|
||||
logger.debug(f"Failed to decode JWT payload: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _is_base64(self, s: str) -> bool:
|
||||
"""Check if a string is base64 encoded"""
|
||||
try:
|
||||
if isinstance(s, str):
|
||||
s = s.encode('ascii')
|
||||
return base64.b64encode(base64.b64decode(s)) == s
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _extract_jwt_user(self, request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract user information from JWT token in Authorization header.
|
||||
|
||||
This provides fallback authentication when OAuth2 proxy headers are not present.
|
||||
"""
|
||||
from app.core.security import get_current_user
|
||||
|
||||
try:
|
||||
# Get Authorization header
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
# Use the security module to validate and extract user info
|
||||
user_data = await get_current_user(auth_header)
|
||||
|
||||
# Convert security module format to middleware format
|
||||
if user_data:
|
||||
return {
|
||||
"email": user_data.get("email", user_data.get("user_id", "unknown")),
|
||||
"username": user_data.get("tenant_display_name", user_data.get("email", "unknown")),
|
||||
"tenant_id": user_data.get("tenant_id", "1"),
|
||||
"tenant_domain": user_data.get("tenant_domain", "default"),
|
||||
"tenant_name": user_data.get("tenant_name", "Default Tenant"),
|
||||
"tenant_role": user_data.get("tenant_role", "tenant_user"),
|
||||
"user_type": user_data.get("user_type", "tenant_user"),
|
||||
"capabilities": user_data.get("capabilities", []),
|
||||
"resource_limits": user_data.get("resource_limits", {}),
|
||||
"auth_source": "jwt-token",
|
||||
"bearer_token": auth_header[7:], # Remove "Bearer " prefix
|
||||
"authenticated_at": None,
|
||||
"is_primary_tenant": user_data.get("is_primary_tenant", False)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to authenticate via JWT token: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class OAuth2SecurityDependency:
|
||||
"""
|
||||
FastAPI dependency to get current authenticated user from OAuth2 context.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/v1/user/profile")
|
||||
async def get_profile(user: dict = Depends(get_current_user)):
|
||||
return {"user": user}
|
||||
"""
|
||||
|
||||
def __call__(self, request: Request) -> Dict[str, Any]:
|
||||
"""Get current authenticated user from request state"""
|
||||
|
||||
if not hasattr(request.state, "authenticated") or not request.state.authenticated:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
return request.state.user
|
||||
|
||||
|
||||
# Singleton instance for dependency injection
|
||||
get_current_user = OAuth2SecurityDependency()
|
||||
|
||||
|
||||
def get_current_user_optional(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get current authenticated user (optional - doesn't raise exception if not authenticated).
|
||||
|
||||
Usage:
|
||||
@app.get("/api/v1/public/info")
|
||||
async def get_info(user: Optional[dict] = Depends(get_current_user_optional)):
|
||||
if user:
|
||||
return {"message": f"Hello {user['email']}"}
|
||||
return {"message": "Hello anonymous user"}
|
||||
"""
|
||||
|
||||
if hasattr(request.state, "authenticated") and request.state.authenticated:
|
||||
return request.state.user
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def require_tenant_access(required_tenant: Optional[str] = None):
|
||||
"""
|
||||
Dependency to ensure user has access to specified tenant.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/v1/tenant/{tenant_id}/data")
|
||||
async def get_tenant_data(
|
||||
tenant_id: str,
|
||||
user: dict = Depends(get_current_user),
|
||||
_: None = Depends(require_tenant_access)
|
||||
):
|
||||
# User is guaranteed to have access to tenant_id
|
||||
return {"data": "tenant specific data"}
|
||||
"""
|
||||
|
||||
def dependency(request: Request, user: Dict[str, Any] = Depends(get_current_user)) -> None:
|
||||
"""Check tenant access for current user"""
|
||||
|
||||
user_tenant = user.get("tenant_id")
|
||||
|
||||
# If no required tenant specified, use the one from user context
|
||||
target_tenant = required_tenant or user_tenant
|
||||
|
||||
if not target_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Tenant information not available"
|
||||
)
|
||||
|
||||
# Check if user has access to the required tenant
|
||||
if user_tenant != target_tenant:
|
||||
logger.warning(
|
||||
f"User {user.get('email', 'unknown')} attempted to access tenant {target_tenant} "
|
||||
f"but belongs to tenant {user_tenant}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied: insufficient tenant permissions"
|
||||
)
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def require_roles(*required_roles: str):
|
||||
"""
|
||||
Dependency to ensure user has one of the required roles.
|
||||
|
||||
Usage:
|
||||
@app.delete("/api/v1/admin/users/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
user: dict = Depends(get_current_user),
|
||||
_: None = Depends(require_roles("admin", "user_manager"))
|
||||
):
|
||||
# User has admin or user_manager role
|
||||
return {"deleted": user_id}
|
||||
"""
|
||||
|
||||
def dependency(user: Dict[str, Any] = Depends(get_current_user)) -> None:
|
||||
"""Check role requirements for current user"""
|
||||
|
||||
user_roles = set(user.get("roles", []))
|
||||
required_roles_set = set(required_roles)
|
||||
|
||||
if not user_roles.intersection(required_roles_set):
|
||||
logger.warning(
|
||||
f"User {user.get('email', 'unknown')} with roles {list(user_roles)} "
|
||||
f"attempted to access endpoint requiring roles {list(required_roles_set)}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Access denied: requires one of roles: {', '.join(required_roles)}"
|
||||
)
|
||||
|
||||
return dependency
|
||||
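# --- Illustrative sketch (not part of the commit): wiring the middleware and dependencies ---
# A minimal FastAPI app using OAuth2AuthMiddleware, get_current_user and require_roles
# from the module above. The wiring itself is an assumption; the real registration
# happens in app.main, not here.
from fastapi import Depends, FastAPI

from app.middleware.oauth2_auth import (
    OAuth2AuthMiddleware,
    get_current_user,
    require_roles,
)

demo_app = FastAPI()
demo_app.add_middleware(OAuth2AuthMiddleware, require_auth=True)

@demo_app.get("/api/v1/user/profile")
async def profile(user: dict = Depends(get_current_user)):
    # user comes from request.state.user, set by the middleware dispatch above
    return {"email": user.get("email"), "tenant_id": user.get("tenant_id")}

@demo_app.delete("/api/v1/admin/users/{user_id}")
async def delete_user(
    user_id: str,
    user: dict = Depends(get_current_user),
    _: None = Depends(require_roles("admin")),
):
    return {"deleted": user_id}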
89
apps/tenant-backend/app/middleware/rate_limiting.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Rate Limiting Middleware for GT 2.0
|
||||
|
||||
Basic rate limiting implementation for tenant protection.
|
||||
"""
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
import time
|
||||
from typing import Dict, Tuple
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Simple in-memory rate limiting middleware"""
|
||||
|
||||
# Operational endpoints that don't need rate limiting
|
||||
EXEMPT_PATHS = {
|
||||
"/health",
|
||||
"/ready",
|
||||
"/metrics",
|
||||
"/api/v1/health"
|
||||
}
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
self._rate_limits: Dict[str, Tuple[int, float]] = {} # ip -> (count, window_start)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting for operational endpoints
|
||||
if request.url.path in self.EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
if self._is_rate_limited(client_ip):
|
||||
logger.warning(f"Rate limit exceeded for IP: {client_ip} - Path: {request.url.path}")
|
||||
# Return proper JSONResponse instead of raising HTTPException to prevent ASGI violations
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Too many requests. Please try again later."},
|
||||
headers={"Retry-After": str(settings.rate_limit_window_seconds)}
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address"""
|
||||
# Check for forwarded IP first (behind proxy/load balancer)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
# Check for real IP header
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fall back to direct client IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _is_rate_limited(self, client_ip: str) -> bool:
|
||||
"""Check if client IP is rate limited"""
|
||||
current_time = time.time()
|
||||
|
||||
if client_ip not in self._rate_limits:
|
||||
self._rate_limits[client_ip] = (1, current_time)
|
||||
return False
|
||||
|
||||
count, window_start = self._rate_limits[client_ip]
|
||||
|
||||
# Check if we're still in the same window
|
||||
if current_time - window_start < settings.rate_limit_window_seconds:
|
||||
if count >= settings.rate_limit_requests:
|
||||
return True # Rate limited
|
||||
else:
|
||||
self._rate_limits[client_ip] = (count + 1, window_start)
|
||||
return False
|
||||
else:
|
||||
# New window, reset count
|
||||
self._rate_limits[client_ip] = (1, current_time)
|
||||
return False
|
||||
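# --- Illustrative sketch (not part of the commit): the fixed-window counter above ---
# A standalone version of the same algorithm to make the window semantics explicit.
# LIMIT and WINDOW_SECONDS are hypothetical values, not the real settings.
import time
from typing import Dict, Tuple

LIMIT = 5            # hypothetical: max requests per window
WINDOW_SECONDS = 60  # hypothetical: window length in seconds

_counters: Dict[str, Tuple[int, float]] = {}

def is_rate_limited(client_ip: str, now: float | None = None) -> bool:
    """Return True if client_ip has exceeded LIMIT within the current window."""
    now = time.time() if now is None else now
    count, window_start = _counters.get(client_ip, (0, now))
    if now - window_start >= WINDOW_SECONDS:
        # New window: reset the counter.
        _counters[client_ip] = (1, now)
        return False
    if count >= LIMIT:
        return True
    _counters[client_ip] = (count + 1, window_start)
    return False

# Example: the sixth call inside one window is rejected.
assert all(not is_rate_limited("10.0.0.1", now=100.0) for _ in range(5))
assert is_rate_limited("10.0.0.1", now=100.0)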
36
apps/tenant-backend/app/middleware/security.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Security Headers Middleware for GT 2.0
|
||||
|
||||
Adds security headers to all responses.
|
||||
"""
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import uuid
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add security headers to all responses"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Generate request ID for tracing
|
||||
request_id = str(uuid.uuid4())
|
||||
request.state.request_id = request_id
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: https:; "
|
||||
"connect-src 'self' ws: wss:;"
|
||||
)
|
||||
|
||||
return response
|
||||
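# --- Illustrative sketch (not part of the commit): verifying the headers above ---
# Uses FastAPI's TestClient against a throwaway app; the assertions mirror the
# headers set by SecurityHeadersMiddleware. The throwaway route is hypothetical.
from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.middleware.security import SecurityHeadersMiddleware

demo = FastAPI()
demo.add_middleware(SecurityHeadersMiddleware)

@demo.get("/ping")
async def ping():
    return {"ok": True}

client = TestClient(demo)
resp = client.get("/ping")
assert resp.headers["X-Frame-Options"] == "DENY"
assert resp.headers["X-Content-Type-Options"] == "nosniff"
assert "X-Request-ID" in resp.headers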
156
apps/tenant-backend/app/middleware/session_validation.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
GT 2.0 Session Validation Middleware
|
||||
|
||||
OWASP/NIST Compliant Server-Side Session Validation (Issue #264)
|
||||
- Validates session_id from JWT against server-side session state
|
||||
- Updates session activity on every authenticated request
|
||||
- Adds X-Session-Warning header when < 5 minutes remaining
|
||||
- Returns 401 with X-Session-Expired header when session is invalid
|
||||
"""
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import httpx
|
||||
import logging
|
||||
import jwt
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionValidationMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to validate server-side sessions on every authenticated request.
|
||||
|
||||
The server-side session is the authoritative source of truth for session validity.
|
||||
JWT expiration is secondary - the session can expire before the JWT does.
|
||||
|
||||
Response Headers:
|
||||
- X-Session-Warning: <seconds> - Added when session is about to expire
|
||||
- X-Session-Expired: idle|absolute - Added on 401 when session expired
|
||||
"""
|
||||
|
||||
def __init__(self, app, control_panel_url: str | None = None, service_auth_token: str | None = None):
|
||||
super().__init__(app)
|
||||
self.control_panel_url = control_panel_url or settings.control_panel_url or "http://control-panel-backend:8001"
|
||||
self.service_auth_token = service_auth_token or settings.service_auth_token or "internal-service-token"
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process request and validate server-side session"""
|
||||
|
||||
# Skip session validation for public endpoints
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/refresh",
|
||||
"/api/v1/auth/password-reset",
|
||||
"/api/v1/public",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/redoc"
|
||||
]
|
||||
|
||||
if any(request.url.path.startswith(path) for path in skip_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract JWT from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await call_next(request)
|
||||
|
||||
token = auth_header.split(" ")[1]
|
||||
|
||||
# Decode JWT to get session_id (without verification - that's done elsewhere)
|
||||
try:
|
||||
# We just need to extract the session_id claim
|
||||
# Full JWT verification happens in the auth dependency
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
session_id = payload.get("session_id")
|
||||
except jwt.InvalidTokenError:
|
||||
# Let the normal auth flow handle invalid tokens
|
||||
return await call_next(request)
|
||||
|
||||
# If no session_id in JWT, skip session validation (backwards compatibility)
|
||||
# This allows old tokens without session_id to work until they expire
|
||||
if not session_id:
|
||||
logger.debug("No session_id in JWT, skipping server-side validation")
|
||||
return await call_next(request)
|
||||
|
||||
# Validate session with control panel
|
||||
validation_result = await self._validate_session(session_id)
|
||||
|
||||
if validation_result is None:
|
||||
# Control panel unavailable - FAIL CLOSED for security (OWASP best practice)
|
||||
# Reject the request rather than allowing potentially expired sessions through
|
||||
logger.error("Session validation failed - control panel unavailable, rejecting request")
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"detail": "Session validation service unavailable",
|
||||
"code": "SESSION_VALIDATION_UNAVAILABLE"
|
||||
},
|
||||
headers={"X-Session-Warning": "validation-unavailable"}
|
||||
)
|
||||
|
||||
if not validation_result.get("is_valid", False):
|
||||
# Session is invalid - return 401 with expiry reason
|
||||
# Ensure expiry_reason is never None (causes header encode error)
|
||||
expiry_reason = validation_result.get("expiry_reason") or "unknown"
|
||||
logger.info(f"Session expired: {expiry_reason}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"detail": f"Session expired ({expiry_reason})",
|
||||
"code": "SESSION_EXPIRED",
|
||||
"expiry_reason": expiry_reason
|
||||
},
|
||||
headers={"X-Session-Expired": expiry_reason}
|
||||
)
|
||||
|
||||
# Session is valid - process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add warning header if session is about to expire
|
||||
if validation_result.get("show_warning", False):
|
||||
seconds_remaining = validation_result.get("seconds_remaining", 0)
|
||||
response.headers["X-Session-Warning"] = str(seconds_remaining)
|
||||
logger.debug(f"Session warning: {seconds_remaining}s remaining")
|
||||
|
||||
return response
|
||||
|
||||
async def _validate_session(self, session_token: str) -> dict | None:
|
||||
"""
|
||||
Validate session with control panel internal API.
|
||||
|
||||
Returns:
|
||||
dict with is_valid, expiry_reason, seconds_remaining, show_warning
|
||||
or None if control panel is unavailable
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.control_panel_url}/internal/sessions/validate",
|
||||
json={"session_token": session_token},
|
||||
headers={
|
||||
"X-Service-Auth": self.service_auth_token,
|
||||
"X-Service-Name": "tenant-backend"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.error(f"Session validation failed: {response.status_code} - {response.text}")
|
||||
return None
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Session validation request failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during session validation: {e}")
|
||||
return None
|
||||
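# --- Illustrative sketch (not part of the commit): the validation contract assumed above ---
# _validate_session() expects the control panel's /internal/sessions/validate endpoint
# to return JSON shaped like the examples below; the exact field values are hypothetical.
VALID_SESSION = {
    "is_valid": True,
    "expiry_reason": None,
    "seconds_remaining": 240,   # under 5 minutes, so show_warning drives the X-Session-Warning header
    "show_warning": True,
}

EXPIRED_SESSION = {
    "is_valid": False,
    "expiry_reason": "idle",    # surfaced to the client via the X-Session-Expired header
    "seconds_remaining": 0,
    "show_warning": False,
}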
48
apps/tenant-backend/app/middleware/tenant_isolation.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Tenant Isolation Middleware for GT 2.0
|
||||
|
||||
Ensures perfect tenant isolation for all requests.
|
||||
"""
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class TenantIsolationMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to enforce tenant isolation boundaries"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Add tenant context to request
|
||||
request.state.tenant_id = settings.tenant_id
|
||||
request.state.tenant_domain = settings.tenant_domain
|
||||
|
||||
# Validate tenant isolation
|
||||
await self._validate_tenant_isolation(request)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# Add tenant headers to response
|
||||
response.headers["X-Tenant-Domain"] = settings.tenant_domain
|
||||
response.headers["X-Tenant-Isolated"] = "true"
|
||||
|
||||
return response
|
||||
|
||||
async def _validate_tenant_isolation(self, request: Request):
|
||||
"""Validate that all operations are tenant-isolated"""
|
||||
# This is where we would add tenant boundary validation
|
||||
# For now, we just log the tenant context
|
||||
logger.debug(
|
||||
"Tenant isolation validated",
|
||||
extra={
|
||||
"tenant_id": settings.tenant_id,
|
||||
"tenant_domain": settings.tenant_domain,
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
}
|
||||
)
|
||||
41
apps/tenant-backend/app/models/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend Models
|
||||
|
||||
Database models for tenant-specific data with perfect isolation.
|
||||
Each tenant has their own SQLite database with these models.
|
||||
"""
|
||||
|
||||
from .agent import Agent # Complete migration - only Agent class
|
||||
from .conversation import Conversation
|
||||
from .message import Message
|
||||
from .document import Document, RAGDataset, DatasetDocument, DocumentChunk
|
||||
from .user_session import UserSession
|
||||
from .workflow import (
|
||||
Workflow,
|
||||
WorkflowExecution,
|
||||
WorkflowTrigger,
|
||||
WorkflowSession,
|
||||
WorkflowMessage,
|
||||
WorkflowStatus,
|
||||
TriggerType,
|
||||
InteractionMode
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"Conversation",
|
||||
"Message",
|
||||
"Document",
|
||||
"RAGDataset",
|
||||
"DatasetDocument",
|
||||
"DocumentChunk",
|
||||
"UserSession",
|
||||
"Workflow",
|
||||
"WorkflowExecution",
|
||||
"WorkflowTrigger",
|
||||
"WorkflowSession",
|
||||
"WorkflowMessage",
|
||||
"WorkflowStatus",
|
||||
"TriggerType",
|
||||
"InteractionMode",
|
||||
]
|
||||
299
apps/tenant-backend/app/models/access_group.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Access Group Models for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for access group entities using the PostgreSQL + PGVector backend.
|
||||
Implements simplified Tenant → User hierarchy with access groups for resource sharing.
|
||||
NO TEAM ENTITIES - using access groups instead for collaboration.
|
||||
Perfect tenant isolation - each tenant has separate access data.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
"""Generate a unique identifier"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class AccessGroup(str, Enum):
|
||||
"""Resource access levels within a tenant"""
|
||||
INDIVIDUAL = "individual" # Private to owner
|
||||
TEAM = "team" # Shared with specific users
|
||||
ORGANIZATION = "organization" # Read-only for all tenant users
|
||||
|
||||
|
||||
class TenantStructure(BaseServiceModel):
|
||||
"""
|
||||
Simplified hierarchy model for GT 2.0 service-based architecture.
|
||||
|
||||
Direct tenant-to-user relationship with access groups for sharing.
|
||||
NO TEAM ENTITIES - using access groups instead for collaboration.
|
||||
"""
|
||||
|
||||
# Core tenant properties
|
||||
tenant_domain: str = Field(..., description="Tenant domain (e.g., customer1.com)")
|
||||
tenant_id: str = Field(..., description="Unique tenant identifier")
|
||||
|
||||
# Tenant settings
|
||||
settings: Dict[str, Any] = Field(default_factory=dict, description="Tenant-wide settings")
|
||||
|
||||
# Statistics
|
||||
user_count: int = Field(default=0, description="Number of users")
|
||||
resource_count: int = Field(default=0, description="Number of resources")
|
||||
|
||||
# Status
|
||||
is_active: bool = Field(default=True, description="Whether tenant is active")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "tenant_structures"
|
||||
|
||||
def activate(self) -> None:
|
||||
"""Activate the tenant"""
|
||||
self.is_active = True
|
||||
self.update_timestamp()
|
||||
|
||||
def deactivate(self) -> None:
|
||||
"""Deactivate the tenant"""
|
||||
self.is_active = False
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
class User(BaseServiceModel):
|
||||
"""
|
||||
User model for GT 2.0 service-based architecture.
|
||||
|
||||
User within a tenant with role-based permissions.
|
||||
"""
|
||||
|
||||
# Core user properties
|
||||
user_id: str = Field(default_factory=generate_uuid, description="Unique user identifier")
|
||||
email: str = Field(..., description="User email address")
|
||||
full_name: str = Field(..., description="User full name")
|
||||
role: str = Field(..., description="User role (admin, developer, analyst, student)")
|
||||
tenant_domain: str = Field(..., description="Parent tenant domain")
|
||||
|
||||
# User status
|
||||
is_active: bool = Field(default=True, description="Whether user is active")
|
||||
last_active: Optional[datetime] = Field(None, description="Last activity timestamp")
|
||||
|
||||
# User settings
|
||||
preferences: Dict[str, Any] = Field(default_factory=dict, description="User preferences")
|
||||
|
||||
# Statistics
|
||||
owned_resources_count: int = Field(default=0, description="Number of owned resources")
|
||||
team_resources_count: int = Field(default=0, description="Number of team resources accessible")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "users"
|
||||
|
||||
def update_activity(self) -> None:
|
||||
"""Update last activity timestamp"""
|
||||
self.last_active = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def can_access_resource(self, resource_access_group: AccessGroup, resource_owner_id: str,
|
||||
resource_team_members: List[str]) -> bool:
|
||||
"""Check if user can access a resource"""
|
||||
# Owner always has access
|
||||
if resource_owner_id == self.user_id:
|
||||
return True
|
||||
|
||||
# Organization-wide resources
|
||||
if resource_access_group == AccessGroup.ORGANIZATION:
|
||||
return True
|
||||
|
||||
# Team resources
|
||||
if resource_access_group == AccessGroup.TEAM:
|
||||
return self.user_id in resource_team_members
|
||||
|
||||
return False
|
||||
|
||||
def can_modify_resource(self, resource_owner_id: str) -> bool:
|
||||
"""Check if user can modify a resource"""
|
||||
# Only owner can modify
|
||||
return resource_owner_id == self.user_id
|
||||
|
||||
|
||||
class Resource(BaseServiceModel):
|
||||
"""
|
||||
Base resource model for GT 2.0 service-based architecture.
|
||||
|
||||
Base class for any resource (agent, dataset, automation, etc.)
|
||||
with file-based storage and access control.
|
||||
"""
|
||||
|
||||
# Core resource properties
|
||||
resource_uuid: str = Field(default_factory=generate_uuid, description="Unique resource identifier")
|
||||
name: str = Field(..., min_length=1, max_length=200, description="Resource name")
|
||||
resource_type: str = Field(..., max_length=50, description="Type of resource")
|
||||
owner_id: str = Field(..., description="Owner user ID")
|
||||
tenant_domain: str = Field(..., description="Parent tenant domain")
|
||||
|
||||
# Access control
|
||||
access_group: AccessGroup = Field(default=AccessGroup.INDIVIDUAL, description="Access level")
|
||||
team_members: List[str] = Field(default_factory=list, description="Team member IDs for team access")
|
||||
|
||||
# File storage
|
||||
file_path: Optional[str] = Field(None, description="File-based storage path")
|
||||
file_permissions: str = Field(default="700", description="Unix file permissions")
|
||||
|
||||
# Resource metadata
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Resource-specific metadata")
|
||||
description: Optional[str] = Field(None, max_length=1000, description="Resource description")
|
||||
|
||||
# Statistics
|
||||
access_count: int = Field(default=0, description="Number of times accessed")
|
||||
last_accessed: Optional[datetime] = Field(None, description="Last access timestamp")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "resources"
|
||||
|
||||
def update_access_group(self, new_group: AccessGroup, team_members: Optional[List[str]] = None) -> None:
|
||||
"""Update resource access group"""
|
||||
self.access_group = new_group
|
||||
self.team_members = team_members if new_group == AccessGroup.TEAM else []
|
||||
self.update_timestamp()
|
||||
|
||||
def add_team_member(self, user_id: str) -> None:
|
||||
"""Add user to team access"""
|
||||
if self.access_group == AccessGroup.TEAM and user_id not in self.team_members:
|
||||
self.team_members.append(user_id)
|
||||
self.update_timestamp()
|
||||
|
||||
def remove_team_member(self, user_id: str) -> None:
|
||||
"""Remove user from team access"""
|
||||
if user_id in self.team_members:
|
||||
self.team_members.remove(user_id)
|
||||
self.update_timestamp()
|
||||
|
||||
def record_access(self, user_id: str) -> None:
|
||||
"""Record resource access"""
|
||||
self.access_count += 1
|
||||
self.last_accessed = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def get_file_permissions(self) -> str:
|
||||
"""
|
||||
Get Unix file permissions based on access group.
|
||||
All files created with 700 permissions (owner only).
|
||||
OS User: gt-{tenant_domain}-{pod_id}
|
||||
"""
|
||||
return "700" # Owner read/write/execute only
|
||||
|
||||
|
||||
# Create/Update/Response models
|
||||
|
||||
class AccessGroupModel(BaseCreateModel):
|
||||
"""API model for access group configuration"""
|
||||
access_group: AccessGroup = Field(..., description="Access level")
|
||||
team_members: List[str] = Field(default_factory=list, description="Team member IDs if team access")
|
||||
|
||||
|
||||
class ResourceCreate(BaseCreateModel):
|
||||
"""Model for creating resources"""
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
resource_type: str = Field(..., max_length=50)
|
||||
owner_id: str
|
||||
tenant_domain: str
|
||||
access_group: AccessGroup = Field(default=AccessGroup.INDIVIDUAL)
|
||||
team_members: List[str] = Field(default_factory=list)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
|
||||
|
||||
class ResourceUpdate(BaseUpdateModel):
|
||||
"""Model for updating resources"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
access_group: Optional[AccessGroup] = None
|
||||
team_members: Optional[List[str]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
|
||||
|
||||
class ResourceResponse(BaseResponseModel):
|
||||
"""Model for resource API responses"""
|
||||
id: str
|
||||
resource_uuid: str
|
||||
name: str
|
||||
resource_type: str
|
||||
owner_id: str
|
||||
tenant_domain: str
|
||||
access_group: AccessGroup
|
||||
team_members: List[str]
|
||||
file_path: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
description: Optional[str]
|
||||
access_count: int
|
||||
last_accessed: Optional[datetime]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class UserCreate(BaseCreateModel):
|
||||
"""Model for creating users"""
|
||||
email: str
|
||||
full_name: str
|
||||
role: str
|
||||
tenant_domain: str
|
||||
preferences: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class UserUpdate(BaseUpdateModel):
|
||||
"""Model for updating users"""
|
||||
full_name: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class UserResponse(BaseResponseModel):
|
||||
"""Model for user API responses"""
|
||||
id: str
|
||||
user_id: str
|
||||
email: str
|
||||
full_name: str
|
||||
role: str
|
||||
tenant_domain: str
|
||||
is_active: bool
|
||||
last_active: Optional[datetime]
|
||||
preferences: Dict[str, Any]
|
||||
owned_resources_count: int
|
||||
team_resources_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
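# --- Illustrative sketch (not part of the commit): the access rules above in action ---
# Field values are hypothetical, and this assumes BaseServiceModel supplies defaults
# for its own bookkeeping fields (id, timestamps). Only can_access_resource() and
# can_modify_resource() are exercised.
from app.models.access_group import AccessGroup, User

owner = User(email="owner@customer1.com", full_name="Owner", role="developer",
             tenant_domain="customer1.com")
teammate = User(email="teammate@customer1.com", full_name="Teammate", role="analyst",
                tenant_domain="customer1.com")

team_members = [teammate.user_id]

# Owner always has access; a teammate only when the resource is TEAM-scoped and listed.
assert owner.can_access_resource(AccessGroup.INDIVIDUAL, owner.user_id, [])
assert teammate.can_access_resource(AccessGroup.TEAM, owner.user_id, team_members)
assert not teammate.can_access_resource(AccessGroup.INDIVIDUAL, owner.user_id, [])
assert teammate.can_access_resource(AccessGroup.ORGANIZATION, owner.user_id, [])

# Only the owner may modify, regardless of access group.
assert owner.can_modify_resource(owner.user_id)
assert not teammate.can_modify_resource(owner.user_id)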
184
apps/tenant-backend/app/models/agent.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
GT 2.0 Agent Model - Service-Based Architecture
|
||||
|
||||
Pydantic models for agent entities using the PostgreSQL + PGVector backend.
|
||||
Complete migration - all assistant terminology has been replaced with agent.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import Field, ConfigDict, field_validator
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
class AgentStatus(str, Enum):
|
||||
"""Agent status enumeration"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class AgentVisibility(str, Enum):
|
||||
"""Agent visibility levels"""
|
||||
INDIVIDUAL = "individual"
|
||||
TEAM = "team"
|
||||
ORGANIZATION = "organization"
|
||||
|
||||
|
||||
class Agent(BaseServiceModel):
|
||||
"""
|
||||
Agent model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents an AI agent configuration with capabilities, model settings,
|
||||
and access control for perfect tenant isolation.
|
||||
"""
|
||||
|
||||
# Core agent properties
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Agent display name")
|
||||
description: Optional[str] = Field(None, max_length=1000, description="Agent description")
|
||||
instructions: Optional[str] = Field(None, description="System instructions for the agent")
|
||||
|
||||
# Model configuration
|
||||
model_provider: str = Field(default="groq", description="AI model provider")
|
||||
model_name: str = Field(default="llama3-groq-8b-8192-tool-use-preview", description="Model identifier")
|
||||
model_settings: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Model-specific configuration")
|
||||
|
||||
# Capabilities and tools
|
||||
capabilities: Optional[List[str]] = Field(default_factory=list, description="Agent capabilities")
|
||||
tools: Optional[List[str]] = Field(default_factory=list, description="Available tools")
|
||||
|
||||
# MCP (Model Context Protocol) tool configuration
|
||||
mcp_servers: Optional[List[str]] = Field(default_factory=list, description="MCP servers this agent can access")
|
||||
rag_enabled: bool = Field(default=False, description="Whether agent can access RAG tools")
|
||||
|
||||
# Access control
|
||||
owner_id: str = Field(..., description="User ID of the agent owner")
|
||||
access_group: str = Field(default="individual", description="Access group for sharing")
|
||||
visibility: AgentVisibility = Field(default=AgentVisibility.INDIVIDUAL, description="Agent visibility level")
|
||||
|
||||
# Status and metadata
|
||||
status: AgentStatus = Field(default=AgentStatus.ACTIVE, description="Agent status")
|
||||
featured: bool = Field(default=False, description="Whether agent is featured")
|
||||
tags: Optional[List[str]] = Field(default_factory=list, description="Agent tags for categorization")
|
||||
category: Optional[str] = Field(None, max_length=100, description="Agent category")
|
||||
|
||||
# Usage statistics
|
||||
conversation_count: int = Field(default=0, description="Number of conversations")
|
||||
last_used_at: Optional[datetime] = Field(None, description="Last usage timestamp")
|
||||
|
||||
# UI/UX Enhancement Fields
|
||||
disclaimer: Optional[str] = Field(None, max_length=500, description="Disclaimer text shown in chat")
|
||||
easy_prompts: Optional[List[str]] = Field(default_factory=list, max_length=10, description="Quick-access preset prompts (max 10)")
|
||||
|
||||
@field_validator('disclaimer')
|
||||
@classmethod
|
||||
def validate_disclaimer(cls, v):
|
||||
"""Validate disclaimer length"""
|
||||
if v and len(v) > 500:
|
||||
raise ValueError('Disclaimer must be 500 characters or less')
|
||||
return v
|
||||
|
||||
@field_validator('easy_prompts')
|
||||
@classmethod
|
||||
def validate_easy_prompts(cls, v):
|
||||
"""Validate easy prompts count"""
|
||||
if v and len(v) > 10:
|
||||
raise ValueError('Maximum 10 easy prompts allowed')
|
||||
return v
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(), # Allow model_ fields
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "agents"
|
||||
|
||||
def increment_usage(self):
|
||||
"""Increment usage statistics"""
|
||||
self.conversation_count += 1
|
||||
self.last_used_at = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
class AgentCreate(BaseCreateModel):
|
||||
"""Model for creating new agents"""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
instructions: Optional[str] = None
|
||||
model_provider: str = Field(default="groq")
|
||||
model_name: str = Field(default="llama3-groq-8b-8192-tool-use-preview")
|
||||
model_settings: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
capabilities: Optional[List[str]] = Field(default_factory=list)
|
||||
tools: Optional[List[str]] = Field(default_factory=list)
|
||||
mcp_servers: Optional[List[str]] = Field(default_factory=list)
|
||||
rag_enabled: bool = Field(default=False)
|
||||
owner_id: str
|
||||
access_group: str = Field(default="individual")
|
||||
visibility: AgentVisibility = Field(default=AgentVisibility.INDIVIDUAL)
|
||||
tags: Optional[List[str]] = Field(default_factory=list)
|
||||
category: Optional[str] = None
|
||||
disclaimer: Optional[str] = Field(None, max_length=500)
|
||||
easy_prompts: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class AgentUpdate(BaseUpdateModel):
|
||||
"""Model for updating agents"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
instructions: Optional[str] = None
|
||||
model_provider: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
model_settings: Optional[Dict[str, Any]] = None
|
||||
capabilities: Optional[List[str]] = None
|
||||
tools: Optional[List[str]] = None
|
||||
access_group: Optional[str] = None
|
||||
visibility: Optional[AgentVisibility] = None
|
||||
status: Optional[AgentStatus] = None
|
||||
featured: Optional[bool] = None
|
||||
tags: Optional[List[str]] = None
|
||||
category: Optional[str] = None
|
||||
disclaimer: Optional[str] = None
|
||||
easy_prompts: Optional[List[str]] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class AgentResponse(BaseResponseModel):
|
||||
"""Model for agent API responses"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
instructions: Optional[str]
|
||||
model_provider: str
|
||||
model_name: str
|
||||
model_settings: Dict[str, Any]
|
||||
capabilities: List[str]
|
||||
tools: List[str]
|
||||
owner_id: str
|
||||
access_group: str
|
||||
visibility: AgentVisibility
|
||||
status: AgentStatus
|
||||
featured: bool
|
||||
tags: List[str]
|
||||
category: Optional[str]
|
||||
conversation_count: int
|
||||
usage_count: int = 0 # Alias for conversation_count for frontend compatibility
|
||||
last_used_at: Optional[datetime]
|
||||
disclaimer: Optional[str]
|
||||
easy_prompts: List[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
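# --- Illustrative sketch (not part of the commit): building an agent creation payload ---
# All field values are hypothetical; only fields defined on AgentCreate above are used,
# and BaseCreateModel is assumed not to require additional fields.
from app.models.agent import AgentCreate, AgentVisibility

payload = AgentCreate(
    name="Nemotron Research Assistant",
    description="Fast local inference agent",
    instructions="Answer concisely and cite sources when available.",
    model_provider="ollama",
    model_name="nemotron-mini",
    rag_enabled=True,
    owner_id="7f3c9a2e-0000-0000-0000-000000000000",  # hypothetical user UUID
    visibility=AgentVisibility.TEAM,
    tags=["research", "local"],
    easy_prompts=["Summarize this paper", "Draft a project plan"],
)
print(payload.model_dump(exclude_none=True))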
345
apps/tenant-backend/app/models/agent_original.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Agent Model for GT 2.0 Tenant Backend
|
||||
|
||||
File-based agent configuration with DuckDB reference tracking.
|
||||
Perfect tenant isolation - each tenant has separate agent data.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
import os
|
||||
import json
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.core.database import Base
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
class Agent(Base):
|
||||
"""Agent model for AI agent configurations"""
|
||||
|
||||
    __tablename__ = "agents"

    # Primary Key - using UUID for PostgreSQL compatibility
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)

    # Agent Details
    name = Column(String(200), nullable=False, index=True)
    description = Column(Text, nullable=True)
    template_id = Column(String(100), nullable=True, index=True)  # Template used to create this agent
    category_id = Column(String(36), nullable=True, index=True)  # Foreign key to categories table for discovery
    agent_type = Column(String(50), nullable=False, default="custom", index=True)  # Agent type/category
    prompt_template = Column(Text, nullable=True)  # System prompt template

    # Visibility and Sharing (GT 2.0 Team Enhancement)
    visibility = Column(String(20), nullable=False, default="private", index=True)  # private, team, organization
    tenant_id = Column(String(36), nullable=True, index=True)  # Foreign key to teams table (null for private)
    shared_with = Column(JSON, nullable=False, default=list)  # List of user emails for explicit sharing

    # File-based Configuration References
    config_file_path = Column(String(500), nullable=False)  # Path to config.json
    prompt_file_path = Column(String(500), nullable=False)  # Path to prompt.md
    capabilities_file_path = Column(String(500), nullable=False)  # Path to capabilities.json

    # User Information (from JWT token)
    created_by = Column(String(255), nullable=False, index=True)  # User email or ID
    user_id = Column(String(255), nullable=False, index=True)  # User ID (alias for created_by for API compatibility)
    user_name = Column(String(100), nullable=True)  # User display name

    # Agent Configuration (cached from files for quick access)
    personality_config = Column(JSON, nullable=False, default=dict)  # Tone, style, etc.
    resource_preferences = Column(JSON, nullable=False, default=dict)  # LLM preferences, etc.
    memory_settings = Column(JSON, nullable=False, default=dict)  # Conversation retention settings

    # Status and Metadata
    is_active = Column(Boolean, nullable=False, default=True)
    is_favorite = Column(Boolean, nullable=False, default=False)
    tags = Column(JSON, nullable=False, default=list)  # User-defined tags
    example_prompts = Column(JSON, nullable=False, default=list)  # Up to 4 example prompts for discovery

    # Statistics (updated by triggers or background processes)
    conversation_count = Column(Integer, nullable=False, default=0)
    total_messages = Column(Integer, nullable=False, default=0)
    total_tokens_used = Column(Integer, nullable=False, default=0)
    total_cost_cents = Column(Integer, nullable=False, default=0)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    last_used_at = Column(DateTime(timezone=True), nullable=True)

    # Relationships
    conversations = relationship("Conversation", back_populates="agent", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        return f"<Agent(id={self.id}, name='{self.name}', created_by='{self.created_by}')>"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API responses"""
        return {
            "id": self.id,
            "uuid": self.id,  # kept for API compatibility; the id column already stores the UUID string
            "name": self.name,
            "description": self.description,
            "template_id": self.template_id,
            "created_by": self.created_by,
            "user_name": self.user_name,
            "personality_config": self.personality_config,
            "resource_preferences": self.resource_preferences,
            "memory_settings": self.memory_settings,
            "is_active": self.is_active,
            "is_favorite": self.is_favorite,
            "tags": self.tags,
            "conversation_count": self.conversation_count,
            "total_messages": self.total_messages,
            "total_tokens_used": self.total_tokens_used,
            "total_cost_cents": self.total_cost_cents,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Agent":
        """Create from dictionary"""
        created_by = data.get("created_by", data.get("user_id", ""))
        return cls(
            name=data.get("name", ""),
            description=data.get("description"),
            template_id=data.get("template_id"),
            agent_type=data.get("agent_type", "custom"),
            prompt_template=data.get("prompt_template", ""),
            created_by=created_by,
            user_id=created_by,  # Keep in sync
            user_name=data.get("user_name"),
            personality_config=data.get("personality_config", {}),
            resource_preferences=data.get("resource_preferences", {}),
            memory_settings=data.get("memory_settings", {}),
            tags=data.get("tags", []),
        )

    def get_agent_directory(self) -> str:
        """Get the file system directory for this agent"""
        settings = get_settings()
        tenant_data_path = os.path.dirname(settings.database_path)
        return os.path.join(tenant_data_path, "agents", str(self.id))

    def ensure_directory_exists(self) -> None:
        """Create agent directory with secure permissions"""
        agent_dir = self.get_agent_directory()
        os.makedirs(agent_dir, exist_ok=True, mode=0o700)

        # Create subdirectories
        subdirs = ["memory", "memory/conversations", "memory/context", "memory/preferences", "resources"]
        for subdir in subdirs:
            subdir_path = os.path.join(agent_dir, subdir)
            os.makedirs(subdir_path, exist_ok=True, mode=0o700)

    def initialize_file_paths(self) -> None:
        """Initialize file paths for this agent"""
        agent_dir = self.get_agent_directory()
        self.config_file_path = os.path.join(agent_dir, "config.json")
        self.prompt_file_path = os.path.join(agent_dir, "prompt.md")
        self.capabilities_file_path = os.path.join(agent_dir, "capabilities.json")

    def load_config_from_file(self) -> Dict[str, Any]:
        """Load agent configuration from file"""
        try:
            with open(self.config_file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return {}

    def save_config_to_file(self, config: Dict[str, Any]) -> None:
        """Save agent configuration to file"""
        self.ensure_directory_exists()
        with open(self.config_file_path, 'w') as f:
            json.dump(config, f, indent=2, default=str)

    def load_prompt_from_file(self) -> str:
        """Load system prompt from file"""
        try:
            with open(self.prompt_file_path, 'r') as f:
                return f.read()
        except FileNotFoundError:
            return ""

    def save_prompt_to_file(self, prompt: str) -> None:
        """Save system prompt to file"""
        self.ensure_directory_exists()
        with open(self.prompt_file_path, 'w') as f:
            f.write(prompt)

    def load_capabilities_from_file(self) -> List[Dict[str, Any]]:
        """Load capabilities from file"""
        try:
            with open(self.capabilities_file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return []

    def save_capabilities_to_file(self, capabilities: List[Dict[str, Any]]) -> None:
        """Save capabilities to file"""
        self.ensure_directory_exists()
        with open(self.capabilities_file_path, 'w') as f:
            json.dump(capabilities, f, indent=2, default=str)

    def update_statistics(self, conversation_count: int = None, messages: int = None,
                          tokens: int = None, cost_cents: int = None) -> None:
        """Update agent statistics"""
        if conversation_count is not None:
            self.conversation_count = conversation_count
        if messages is not None:
            self.total_messages += messages
        if tokens is not None:
            self.total_tokens_used += tokens
        if cost_cents is not None:
            self.total_cost_cents += cost_cents

        self.last_used_at = datetime.utcnow()
        self.updated_at = datetime.utcnow()

    def add_tag(self, tag: str) -> None:
        """Add a tag to the agent"""
        if tag not in self.tags:
            current_tags = self.tags or []
            current_tags.append(tag)
            self.tags = current_tags

    def remove_tag(self, tag: str) -> None:
        """Remove a tag from the agent"""
        if self.tags and tag in self.tags:
            current_tags = self.tags.copy()
            current_tags.remove(tag)
            self.tags = current_tags

    def get_full_configuration(self) -> Dict[str, Any]:
        """Get complete agent configuration including file-based data"""
        config = self.load_config_from_file()
        prompt = self.load_prompt_from_file()
        capabilities = self.load_capabilities_from_file()

        return {
            **self.to_dict(),
            "config": config,
            "prompt": prompt,
            "capabilities": capabilities,
        }

    def clone(self, new_name: str, user_identifier: str, modifications: Dict[str, Any] = None) -> "Agent":
        """Create a clone of this agent with modifications"""
        # Load current configuration
        config = self.load_config_from_file()
        prompt = self.load_prompt_from_file()
        capabilities = self.load_capabilities_from_file()

        # Apply modifications if provided
        if modifications:
            config.update(modifications.get("config", {}))
            if "prompt" in modifications:
                prompt = modifications["prompt"]
            if "capabilities" in modifications:
                capabilities = modifications["capabilities"]

        # Create new agent
        new_agent = Agent(
            name=new_name,
            description=f"Clone of {self.name}",
            template_id=self.template_id,
            created_by=user_identifier,
            personality_config=self.personality_config.copy(),
            resource_preferences=self.resource_preferences.copy(),
            memory_settings=self.memory_settings.copy(),
            tags=self.tags.copy() if self.tags else [],
        )

        return new_agent

    def archive(self) -> None:
        """Archive the agent (soft delete)"""
        self.is_active = False
        self.updated_at = datetime.utcnow()

    def unarchive(self) -> None:
        """Unarchive the agent"""
        self.is_active = True
        self.updated_at = datetime.utcnow()

    def favorite(self) -> None:
        """Mark agent as favorite"""
        self.is_favorite = True
        self.updated_at = datetime.utcnow()

    def unfavorite(self) -> None:
        """Remove favorite status"""
        self.is_favorite = False
        self.updated_at = datetime.utcnow()

    def is_owned_by(self, user_identifier: str) -> bool:
        """Check if agent is owned by the given user"""
        return self.created_by == user_identifier

    def can_be_accessed_by(self, user_identifier: str, user_teams: List[int] = None) -> bool:
        """Check if agent can be accessed by the given user

        GT 2.0 Access Rules:
        1. Owner always has access
        2. Team members have access if visibility is 'team' and they're in the team
        3. All organization members have access if visibility is 'organization'
        4. Explicitly shared users have access
        """
        # Owner always has access
        if self.is_owned_by(user_identifier):
            return True

        # Check explicit sharing
        if self.shared_with and user_identifier in self.shared_with:
            return True

        # Check team visibility
        if self.visibility == "team" and self.tenant_id and user_teams:
            if self.tenant_id in user_teams:
                return True

        # Check organization visibility
        if self.visibility == "organization":
            return True  # All authenticated users in the tenant

        return False
    @property
    def average_tokens_per_message(self) -> float:
        """Calculate average tokens per message"""
        if self.total_messages == 0:
            return 0.0
        return self.total_tokens_used / self.total_messages

    @property
    def total_cost_dollars(self) -> float:
        """Get total cost in dollars"""
        return self.total_cost_cents / 100.0

    @property
    def average_cost_per_conversation(self) -> float:
        """Calculate average cost per conversation in dollars"""
        if self.conversation_count == 0:
            return 0.0
        return self.total_cost_dollars / self.conversation_count

    @property
    def usage_count(self) -> int:
        """Alias for conversation_count for API compatibility"""
        return self.conversation_count

    @usage_count.setter
    def usage_count(self, value: int) -> None:
        """Set conversation_count via usage_count alias"""
        self.conversation_count = value


# Backward compatibility alias
Agent = Agent
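
# --- Illustrative usage sketch (added for this review, not part of the original module) ----------
# A minimal, hedged example of how the visibility/sharing rules in Agent.can_be_accessed_by()
# combine. Every identifier below (emails, team UUID) is made up. Column defaults are not applied
# on plain instantiation, so the relevant attributes are passed explicitly. The user_teams hint in
# the method says List[int], but tenant_id is a String(36) UUID, so string team ids are assumed here.
if __name__ == "__main__":
    demo_agent = Agent(
        name="Team research helper",
        created_by="owner@example.com",
        user_id="owner@example.com",
        visibility="team",
        tenant_id="11111111-1111-1111-1111-111111111111",  # hypothetical team UUID
        shared_with=["guest@example.com"],
    )

    # Rule 1: the owner always has access.
    print(demo_agent.can_be_accessed_by("owner@example.com"))  # True
    # Rule 4: explicitly shared users have access regardless of team membership.
    print(demo_agent.can_be_accessed_by("guest@example.com"))  # True
    # Rule 2: team visibility grants access only when the caller's teams include tenant_id.
    print(demo_agent.can_be_accessed_by("peer@example.com",
                                        user_teams=["11111111-1111-1111-1111-111111111111"]))  # True
    print(demo_agent.can_be_accessed_by("peer@example.com",
                                        user_teams=["22222222-2222-2222-2222-222222222222"]))  # False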
166
apps/tenant-backend/app/models/assistant_dataset.py
Normal file
166
apps/tenant-backend/app/models/assistant_dataset.py
Normal file
@@ -0,0 +1,166 @@
"""
Agent-Dataset Binding Model for GT 2.0 Tenant Backend

Links agents to RAG datasets for context-aware conversations.
Follows GT 2.0's principle of "Elegant Simplicity"
- Simple many-to-many relationships
- Configurable relevance thresholds
- Priority ordering for multiple datasets
"""

from datetime import datetime
from typing import Dict, Any
import uuid

from sqlalchemy import Column, Integer, String, DateTime, Float, ForeignKey, Boolean
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship

from app.core.database import Base


def generate_uuid():
    """Generate a unique identifier"""
    return str(uuid.uuid4())


class AssistantDataset(Base):
    """Links agents to RAG datasets for context retrieval

    GT 2.0 Design: Simple binding table with configuration
    """

    __tablename__ = "agent_datasets"

    # Primary Key
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)

    # Foreign Keys
    agent_id = Column(String(36), ForeignKey("agents.id", ondelete="CASCADE"), nullable=False, index=True)
    dataset_id = Column(String(36), ForeignKey("rag_datasets.id", ondelete="CASCADE"), nullable=False, index=True)

    # Configuration
    relevance_threshold = Column(Float, nullable=False, default=0.7)  # Minimum similarity score
    max_chunks = Column(Integer, nullable=False, default=5)  # Max chunks to retrieve
    priority_order = Column(Integer, nullable=False, default=0)  # Order when multiple datasets (lower = higher priority)

    # Settings
    is_active = Column(Boolean, nullable=False, default=True)
    auto_include = Column(Boolean, nullable=False, default=True)  # Automatically include in searches

    # Usage Statistics
    search_count = Column(Integer, nullable=False, default=0)
    chunks_retrieved_total = Column(Integer, nullable=False, default=0)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    last_used_at = Column(DateTime(timezone=True), nullable=True)

    # Relationships
    agent = relationship("Agent", backref="dataset_bindings")
    dataset = relationship("RAGDataset", backref="assistant_bindings")

    def __repr__(self) -> str:
        return f"<AssistantDataset(agent_id={self.agent_id}, dataset_id='{self.dataset_id}')>"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API responses"""
        return {
            "id": self.id,
            "agent_id": self.agent_id,
            "dataset_id": self.dataset_id,
            "relevance_threshold": self.relevance_threshold,
            "max_chunks": self.max_chunks,
            "priority_order": self.priority_order,
            "is_active": self.is_active,
            "auto_include": self.auto_include,
            "search_count": self.search_count,
            "chunks_retrieved_total": self.chunks_retrieved_total,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
        }

    def increment_usage(self, chunks_retrieved: int = 0) -> None:
        """Update usage statistics"""
        self.search_count += 1
        self.chunks_retrieved_total += chunks_retrieved
        self.last_used_at = datetime.utcnow()
        self.updated_at = datetime.utcnow()


class AssistantIntegration(Base):
    """Links agents to external integrations and tools

    GT 2.0 Design: Simple binding to resource cluster integrations
    """

    __tablename__ = "agent_integrations"

    # Primary Key
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)

    # Foreign Keys
    agent_id = Column(String(36), ForeignKey("agents.id", ondelete="CASCADE"), nullable=False, index=True)
    integration_resource_id = Column(String(36), nullable=False, index=True)  # Resource cluster integration ID

    # Configuration
    integration_type = Column(String(50), nullable=False)  # github, slack, jira, etc.
    enabled = Column(Boolean, nullable=False, default=True)
    config = Column(String, nullable=False, default="{}")  # JSON configuration

    # Permissions
    allowed_actions = Column(String, nullable=False, default="[]")  # JSON array of allowed actions

    # Usage Statistics
    usage_count = Column(Integer, nullable=False, default=0)
    last_error = Column(String, nullable=True)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    last_used_at = Column(DateTime(timezone=True), nullable=True)

    # Relationships
    agent = relationship("Agent", backref="integration_bindings")

    def __repr__(self) -> str:
        return f"<AssistantIntegration(agent_id={self.agent_id}, type='{self.integration_type}')>"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API responses"""
        import json

        try:
            config_obj = json.loads(self.config) if isinstance(self.config, str) else self.config
            allowed_actions_list = json.loads(self.allowed_actions) if isinstance(self.allowed_actions, str) else self.allowed_actions
        except json.JSONDecodeError:
            config_obj = {}
            allowed_actions_list = []

        return {
            "id": self.id,
            "agent_id": self.agent_id,
            "integration_resource_id": self.integration_resource_id,
            "integration_type": self.integration_type,
            "enabled": self.enabled,
            "config": config_obj,
            "allowed_actions": allowed_actions_list,
            "usage_count": self.usage_count,
            "last_error": self.last_error,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
        }

    def increment_usage(self) -> None:
        """Update usage statistics"""
        self.usage_count += 1
        self.last_used_at = datetime.utcnow()
        self.updated_at = datetime.utcnow()

    def record_error(self, error_message: str) -> None:
        """Record an error from the integration"""
        self.last_error = error_message[:500]  # Truncate to 500 chars
        self.updated_at = datetime.utcnow()
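
# --- Illustrative usage sketch (added for this review, not part of the original module) ----------
# A minimal, hedged example of the two binding models above. All identifiers are made up, and the
# counters are passed explicitly because SQLAlchemy column defaults only apply when rows are
# flushed to the database, not on plain instantiation.
if __name__ == "__main__":
    binding = AssistantDataset(
        agent_id="agent-123",
        dataset_id="dataset-456",
        relevance_threshold=0.75,
        max_chunks=3,
        search_count=0,
        chunks_retrieved_total=0,
    )
    binding.increment_usage(chunks_retrieved=3)
    print(binding.search_count, binding.chunks_retrieved_total)  # 1 3

    integration = AssistantIntegration(
        agent_id="agent-123",
        integration_resource_id="res-789",
        integration_type="github",
        enabled=True,
        config='{"repository": "example/repo"}',
        allowed_actions='["read_issues", "create_pr"]',
        usage_count=0,
    )
    # to_dict() decodes the JSON string columns back into Python objects for API responses.
    snapshot = integration.to_dict()
    print(snapshot["config"], snapshot["allowed_actions"])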
439
apps/tenant-backend/app/models/assistant_template.py
Normal file
439
apps/tenant-backend/app/models/assistant_template.py
Normal file
@@ -0,0 +1,439 @@
"""
Agent Template Models for GT 2.0

Defines agent templates, custom builders, and MCP integration models.
Follows the simplified hierarchy with file-based storage.
"""

from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
import json
from pathlib import Path

from app.models.access_group import AccessGroup, Resource


class AssistantType(str, Enum):
    """Pre-defined agent types from architecture"""
    RESEARCH = "research_assistant"
    CODING = "coding_assistant"
    CYBER_ANALYST = "cyber_analyst"
    EDUCATIONAL = "educational_tutor"
    CUSTOM = "custom"


class PersonalityConfig(BaseModel):
    """Agent personality configuration"""
    tone: str = Field(default="balanced", description="formal | balanced | casual")
    explanation_depth: str = Field(default="intermediate", description="beginner | intermediate | expert")
    interaction_style: str = Field(default="collaborative", description="teaching | collaborative | direct")


class ResourcePreferences(BaseModel):
    """Agent resource preferences"""
    primary_llm: str = Field(default="gpt-4", description="Primary LLM model")
    fallback_models: List[str] = Field(default_factory=list, description="Fallback model list")
    context_length: int = Field(default=4000, description="Maximum context length")
    temperature: float = Field(default=0.7, description="Response temperature")
    streaming_enabled: bool = Field(default=True, description="Enable streaming responses")


class MemorySettings(BaseModel):
    """Agent memory configuration"""
    conversation_retention: str = Field(default="session", description="session | temporary | permanent")
    context_window_size: int = Field(default=10, description="Number of messages to retain")
    learning_from_interactions: bool = Field(default=False, description="Learn from user interactions")
    max_memory_size_mb: int = Field(default=50, description="Maximum memory size in MB")


class AssistantTemplate(BaseModel):
    """
    Pre-configured agent template
    Stored in Resource Cluster library
    """
    template_id: str
    name: str
    description: str
    category: AssistantType

    # Core configuration
    system_prompt: str = Field(description="System prompt with variable substitution")
    default_capabilities: List[str] = Field(default_factory=list, description="Default capability requirements")

    # Configurations
    personality_config: PersonalityConfig = Field(default_factory=PersonalityConfig)
    resource_preferences: ResourcePreferences = Field(default_factory=ResourcePreferences)
    memory_settings: MemorySettings = Field(default_factory=MemorySettings)

    # Metadata
    icon_path: Optional[str] = None
    version: str = Field(default="1.0.0")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    # Access control
    required_access_groups: List[str] = Field(default_factory=list)
    minimum_role: Optional[str] = None

    def to_instance(self, user_id: str, instance_name: str, tenant_domain: str) -> "AssistantInstance":
        """Create an instance from this template"""
        return AssistantInstance(
            id=f"{user_id}-{instance_name}-{datetime.utcnow().timestamp()}",
            template_id=self.template_id,
            name=instance_name,
            description=f"Instance of {self.name}",
            owner_id=user_id,
            tenant_domain=tenant_domain,

            # Copy configurations
            system_prompt=self.system_prompt,
            capabilities=self.default_capabilities.copy(),
            personality_config=self.personality_config.model_copy(),
            resource_preferences=self.resource_preferences.model_copy(),
            memory_settings=self.memory_settings.model_copy(),

            # Instance specific
            access_group=AccessGroup.INDIVIDUAL,
            team_members=[],
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow()
        )


class AssistantInstance(Resource):
    """
    User's instance of an agent
    Inherits from Resource for access control
    """
    template_id: Optional[str] = Field(default=None, description="Source template if from template")

    # Agent configuration
    system_prompt: str
    capabilities: List[str] = Field(default_factory=list)
    personality_config: PersonalityConfig = Field(default_factory=PersonalityConfig)
    resource_preferences: ResourcePreferences = Field(default_factory=ResourcePreferences)
    memory_settings: MemorySettings = Field(default_factory=MemorySettings)

    # Resource bindings
    linked_datasets: List[str] = Field(default_factory=list, description="Linked RAG dataset IDs")
    linked_tools: List[str] = Field(default_factory=list, description="Linked tool/integration IDs")
    linked_models: List[str] = Field(default_factory=list, description="Specific model overrides")

    # Usage tracking
    conversation_count: int = Field(default=0)
    total_messages: int = Field(default=0)
    total_tokens_used: int = Field(default=0)
    last_used: Optional[datetime] = None

    # File storage paths (created by controller)
    config_file_path: Optional[str] = None
    memory_file_path: Optional[str] = None

    def get_file_structure(self) -> Dict[str, str]:
        """Get expected file structure for agent storage"""
        base_path = f"/data/{self.tenant_domain}/users/{self.owner_id}/agents/{self.id}"
        return {
            "config": f"{base_path}/config.json",
            "prompt": f"{base_path}/prompt.md",
            "capabilities": f"{base_path}/capabilities.json",
            "memory": f"{base_path}/memory/",
            "resources": f"{base_path}/resources/"
        }

    def update_from_template(self, template: AssistantTemplate):
        """Update instance from template (for version updates)"""
        self.system_prompt = template.system_prompt
        self.personality_config = template.personality_config.model_copy()
        self.resource_preferences = template.resource_preferences.model_copy()
        self.updated_at = datetime.utcnow()

    def add_linked_dataset(self, dataset_id: str):
        """Link a RAG dataset to this agent"""
        if dataset_id not in self.linked_datasets:
            self.linked_datasets.append(dataset_id)
            self.updated_at = datetime.utcnow()

    def remove_linked_dataset(self, dataset_id: str):
        """Unlink a RAG dataset"""
        if dataset_id in self.linked_datasets:
            self.linked_datasets.remove(dataset_id)
            self.updated_at = datetime.utcnow()


class AssistantBuilder(BaseModel):
    """Configuration for building custom agents"""
    name: str
    description: Optional[str] = None
    base_template: Optional[AssistantType] = None

    # Custom configuration
    system_prompt: str
    personality_config: PersonalityConfig = Field(default_factory=PersonalityConfig)
    resource_preferences: ResourcePreferences = Field(default_factory=ResourcePreferences)
    memory_settings: MemorySettings = Field(default_factory=MemorySettings)

    # Capabilities
    requested_capabilities: List[str] = Field(default_factory=list)
    required_models: List[str] = Field(default_factory=list)
    required_tools: List[str] = Field(default_factory=list)

    def build_instance(self, user_id: str, tenant_domain: str) -> AssistantInstance:
        """Build agent instance from configuration"""
        return AssistantInstance(
            id=f"custom-{user_id}-{datetime.utcnow().timestamp()}",
            template_id=None,  # Custom build
            name=self.name,
            description=self.description or f"Custom agent by {user_id}",
            owner_id=user_id,
            tenant_domain=tenant_domain,
            resource_type="agent",

            # Apply configurations
            system_prompt=self.system_prompt,
            capabilities=self.requested_capabilities,
            personality_config=self.personality_config,
            resource_preferences=self.resource_preferences,
            memory_settings=self.memory_settings,

            # Default access
            access_group=AccessGroup.INDIVIDUAL,
            team_members=[],
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow()
        )


# Pre-defined templates from architecture
BUILTIN_TEMPLATES = {
    AssistantType.RESEARCH: AssistantTemplate(
        template_id="research_assistant_v1",
        name="Research & Analysis Agent",
        description="Specialized in information synthesis and analysis with citations",
        category=AssistantType.RESEARCH,
        system_prompt="""You are a research agent specialized in information synthesis and analysis.
Focus on providing well-sourced, analytical responses with clear reasoning.
Always cite your sources and provide evidence for your claims.
When uncertain, clearly state the limitations of your knowledge.""",
        default_capabilities=[
            "llm:gpt-4",
            "rag:semantic_search",
            "tools:web_search",
            "export:citations"
        ],
        personality_config=PersonalityConfig(
            tone="formal",
            explanation_depth="expert",
            interaction_style="collaborative"
        ),
        resource_preferences=ResourcePreferences(
            primary_llm="gpt-4",
            fallback_models=["claude-sonnet", "gpt-3.5-turbo"],
            context_length=8000,
            temperature=0.7
        ),
        required_access_groups=["research_tools"]
    ),

    AssistantType.CODING: AssistantTemplate(
        template_id="coding_assistant_v1",
        name="Software Development Agent",
        description="Code quality, debugging, and development best practices",
        category=AssistantType.CODING,
        system_prompt="""You are a software development agent focused on code quality and best practices.
Provide clear explanations, suggest improvements, and help debug issues.
Follow the principle of clean, maintainable code.
Always consider security implications in your suggestions.""",
        default_capabilities=[
            "llm:claude-sonnet",
            "tools:github_integration",
            "resources:documentation",
            "export:code_snippets"
        ],
        personality_config=PersonalityConfig(
            tone="balanced",
            explanation_depth="intermediate",
            interaction_style="direct"
        ),
        resource_preferences=ResourcePreferences(
            primary_llm="claude-sonnet",
            fallback_models=["gpt-4", "codellama"],
            context_length=16000,
            temperature=0.5
        ),
        required_access_groups=["development_tools"]
    ),

    AssistantType.CYBER_ANALYST: AssistantTemplate(
        template_id="cyber_analyst_v1",
        name="Cybersecurity Analysis Agent",
        description="Threat detection, incident response, and security best practices",
        category=AssistantType.CYBER_ANALYST,
        system_prompt="""You are a cybersecurity analyst agent for threat detection and response.
Prioritize security best practices and provide actionable recommendations.
Consider defense-in-depth strategies and zero-trust principles.
Always emphasize the importance of continuous monitoring and improvement.""",
        default_capabilities=[
            "llm:gpt-4",
            "tools:security_scanning",
            "resources:threat_intelligence",
            "export:security_reports"
        ],
        personality_config=PersonalityConfig(
            tone="formal",
            explanation_depth="expert",
            interaction_style="direct"
        ),
        resource_preferences=ResourcePreferences(
            primary_llm="gpt-4",
            fallback_models=["claude-sonnet"],
            context_length=8000,
            temperature=0.3
        ),
        required_access_groups=["cybersecurity_advanced"]
    ),

    AssistantType.EDUCATIONAL: AssistantTemplate(
        template_id="educational_tutor_v1",
        name="AI Literacy Educational Agent",
        description="Critical thinking development and AI collaboration skills",
        category=AssistantType.EDUCATIONAL,
        system_prompt="""You are an educational agent focused on developing critical thinking and AI literacy.
Use socratic questioning and encourage deep analysis of problems.
Help students understand both the capabilities and limitations of AI.
Foster independent thinking while teaching effective AI collaboration.""",
        default_capabilities=[
            "llm:claude-sonnet",
            "games:strategic_thinking",
            "puzzles:logic_reasoning",
            "analytics:learning_progress"
        ],
        personality_config=PersonalityConfig(
            tone="casual",
            explanation_depth="beginner",
            interaction_style="teaching"
        ),
        resource_preferences=ResourcePreferences(
            primary_llm="claude-sonnet",
            fallback_models=["gpt-4"],
            context_length=4000,
            temperature=0.8
        ),
        required_access_groups=["ai_literacy"]
    )
}


class AssistantTemplateLibrary:
    """
    Manages the agent template library
    Templates stored in Resource Cluster, cached locally
    """

    def __init__(self, resource_cluster_url: str):
        self.resource_cluster_url = resource_cluster_url
        self.cache_path = Path("/tmp/agent_templates_cache")
        self.cache_path.mkdir(exist_ok=True)
        self._templates_cache: Dict[str, AssistantTemplate] = {}

    async def get_template(self, template_id: str) -> Optional[AssistantTemplate]:
        """Get template by ID, using cache if available"""
        if template_id in self._templates_cache:
            return self._templates_cache[template_id]

        # Check built-in templates
        for template_type, template in BUILTIN_TEMPLATES.items():
            if template.template_id == template_id:
                self._templates_cache[template_id] = template
                return template

        # Would fetch from Resource Cluster in production
        return None

    async def list_templates(
        self,
        category: Optional[AssistantType] = None,
        access_groups: Optional[List[str]] = None
    ) -> List[AssistantTemplate]:
        """List available templates with filtering"""
        templates = list(BUILTIN_TEMPLATES.values())

        if category:
            templates = [t for t in templates if t.category == category]

        if access_groups:
            templates = [
                t for t in templates
                if any(g in access_groups for g in t.required_access_groups)
            ]

        return templates

    async def deploy_template(
        self,
        template_id: str,
        user_id: str,
        instance_name: str,
        tenant_domain: str,
        customizations: Optional[Dict[str, Any]] = None
    ) -> AssistantInstance:
        """Deploy template as user instance"""
        template = await self.get_template(template_id)
        if not template:
            raise ValueError(f"Template not found: {template_id}")

        # Create instance
        instance = template.to_instance(user_id, instance_name, tenant_domain)

        # Apply customizations
        if customizations:
            if "personality" in customizations:
                instance.personality_config = PersonalityConfig(**customizations["personality"])
            if "resources" in customizations:
                instance.resource_preferences = ResourcePreferences(**customizations["resources"])
            if "memory" in customizations:
                instance.memory_settings = MemorySettings(**customizations["memory"])

        return instance


# API Models
class AssistantTemplateResponse(BaseModel):
    """API response for agent template"""
    template_id: str
    name: str
    description: str
    category: str
    required_access_groups: List[str]
    version: str
    created_at: datetime


class AssistantInstanceResponse(BaseModel):
    """API response for agent instance"""
    id: str
    name: str
    description: str
    template_id: Optional[str]
    owner_id: str
    access_group: AccessGroup
    team_members: List[str]
    conversation_count: int
    last_used: Optional[datetime]
    created_at: datetime
    updated_at: datetime


class CreateAssistantRequest(BaseModel):
    """Request to create agent from template or custom"""
    template_id: Optional[str] = None
    name: str
    description: Optional[str] = None
    customizations: Optional[Dict[str, Any]] = None

    # For custom agents
    system_prompt: Optional[str] = None
    personality_config: Optional[PersonalityConfig] = None
    resource_preferences: Optional[ResourcePreferences] = None
    memory_settings: Optional[MemorySettings] = None
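
# --- Illustrative usage sketch (added for this review, not part of the original module) ----------
# A minimal, hedged example of deploying a built-in template through AssistantTemplateLibrary.
# The resource-cluster URL and the user/tenant identifiers are made up; whether AssistantInstance
# accepts exactly the fields passed by to_instance() depends on app.models.access_group.Resource,
# which is defined elsewhere in the repository.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        library = AssistantTemplateLibrary(resource_cluster_url="http://resource-cluster.local")

        research_templates = await library.list_templates(category=AssistantType.RESEARCH)
        print([t.template_id for t in research_templates])  # ['research_assistant_v1']

        instance = await library.deploy_template(
            template_id="research_assistant_v1",
            user_id="user-123",
            instance_name="thesis-helper",
            tenant_domain="acme.example.com",
            customizations={"resources": {"primary_llm": "gpt-4", "temperature": 0.4}},
        )
        print(instance.resource_preferences.temperature)  # 0.4, overriding the template default

    asyncio.run(_demo())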
126
apps/tenant-backend/app/models/base.py
Normal file
126
apps/tenant-backend/app/models/base.py
Normal file
@@ -0,0 +1,126 @@
"""
GT 2.0 Base Model Classes - Service-Based Architecture

Provides Pydantic models for data serialization with the DuckDB service.
No SQLAlchemy ORM dependency - pure Python/Pydantic models.
"""

from typing import Any, Dict, Optional, List, Type, TypeVar
from datetime import datetime
import uuid
from pydantic import BaseModel, Field, ConfigDict

# Generic type for model classes
T = TypeVar('T', bound='BaseServiceModel')


class BaseServiceModel(BaseModel):
    """
    Base model for all GT 2.0 entities using service-based architecture.

    Replaces SQLAlchemy models with Pydantic models + DuckDB service.
    """

    # Pydantic v2 configuration
    model_config = ConfigDict(
        from_attributes=True,
        validate_assignment=True,
        arbitrary_types_allowed=True,
        use_enum_values=True
    )

    # Standard fields for all models
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique identifier")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation timestamp")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update timestamp")

    def to_dict(self) -> Dict[str, Any]:
        """Convert model to dictionary"""
        return self.model_dump()

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        """Create model instance from dictionary"""
        return cls(**data)

    @classmethod
    def from_row(cls: Type[T], row: Dict[str, Any]) -> T:
        """Create model instance from database row"""
        # Convert database row to model, handling type conversions
        model_data = {}

        for field_name, field_info in cls.model_fields.items():
            if field_name in row:
                value = row[field_name]

                # Handle datetime conversion
                if field_info.annotation == datetime and isinstance(value, str):
                    try:
                        value = datetime.fromisoformat(value)
                    except ValueError:
                        value = datetime.utcnow()

                model_data[field_name] = value

        return cls(**model_data)

    def update_timestamp(self):
        """Update the updated_at timestamp"""
        self.updated_at = datetime.utcnow()


class BaseCreateModel(BaseModel):
    """Base model for creation requests"""
    model_config = ConfigDict(from_attributes=True)


class BaseUpdateModel(BaseModel):
    """Base model for update requests"""
    model_config = ConfigDict(from_attributes=True)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class BaseResponseModel(BaseServiceModel):
    """Base model for API responses"""
    pass


# Legacy compatibility - some files might still import Base
Base = BaseServiceModel  # For backwards compatibility during migration


# Database service integration helpers
class DatabaseMixin:
    """Mixin providing database service integration methods"""

    @classmethod
    async def get_table_name(cls) -> str:
        """Get the database table name for this model"""
        # Lowercase the class name and pluralize it (no CamelCase-to-snake_case split is applied)
        name = cls.__name__.lower()
        if name.endswith('y'):
            name = name[:-1] + 'ies'
        elif name.endswith('s'):
            name = name + 'es'
        else:
            name = name + 's'
        return name

    @classmethod
    async def create_sql(cls) -> str:
        """Generate CREATE TABLE SQL for this model"""
        # This would generate SQL based on Pydantic field types
        # For now, return placeholder - actual schemas are in DuckDB service
        table_name = await cls.get_table_name()
        return f"-- CREATE TABLE {table_name} generated by DuckDB service"

    async def to_sql_values(self) -> Dict[str, Any]:
        """Convert model to SQL-safe values"""
        data = self.to_dict()

        # Convert datetime objects to ISO strings
        for key, value in data.items():
            if isinstance(value, datetime):
                data[key] = value.isoformat()

        return data
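
# --- Illustrative usage sketch (added for this review, not part of the original module) ----------
# A minimal, hedged example of the row round trip: from_row() restores ISO datetime strings coming
# back from the DuckDB service, and to_sql_values() serialises them again. The NoteModel class and
# the row values are hypothetical and exist only for this sketch.
if __name__ == "__main__":
    import asyncio

    class NoteModel(BaseServiceModel, DatabaseMixin):
        title: str = "untitled"

    note = NoteModel.from_row({
        "id": "5b8e7c2a-0000-0000-0000-000000000000",
        "title": "hello",
        "created_at": "2024-01-01T12:00:00",
        "updated_at": "2024-01-01T12:00:00",
    })
    print(type(note.created_at).__name__)  # datetime, parsed from the ISO string

    async def _demo() -> None:
        print(await NoteModel.get_table_name())             # "notemodels" (lowercase + pluralisation)
        print((await note.to_sql_values())["created_at"])   # back to an ISO string for the service

    asyncio.run(_demo())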
340
apps/tenant-backend/app/models/category.py
Normal file
340
apps/tenant-backend/app/models/category.py
Normal file
@@ -0,0 +1,340 @@
"""
Category Model for GT 2.0 Agent Discovery

Implements a simple hierarchical category system for organizing agents.
Follows GT 2.0's principle of "Clarity Over Complexity"
- Simple parent-child relationships
- System categories that cannot be deleted
- Tenant-specific and global categories
"""

from datetime import datetime
from typing import Optional, Dict, Any, List
import uuid

from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship

from app.core.database import Base


class Category(Base):
    """Category model for organizing agents and resources

    GT 2.0 Design: Simple hierarchical categories without complex taxonomies
    """

    __tablename__ = "categories"

    # Primary Key
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
    slug = Column(String(100), unique=True, nullable=False, index=True)  # URL-safe identifier

    # Category Details
    name = Column(String(100), nullable=False, index=True)
    display_name = Column(String(100), nullable=False)
    description = Column(Text, nullable=True)
    icon = Column(String(10), nullable=True)  # Emoji or icon code
    color = Column(String(20), nullable=True)  # Hex color code for UI

    # Hierarchy (simple parent-child)
    parent_id = Column(String(36), ForeignKey("categories.id"), nullable=True, index=True)

    # Scope
    is_system = Column(Boolean, nullable=False, default=False)  # Protected from deletion
    is_global = Column(Boolean, nullable=False, default=True)  # Available to all tenants

    # Display Order
    sort_order = Column(Integer, nullable=False, default=0)

    # Usage Statistics (cached)
    assistant_count = Column(Integer, nullable=False, default=0)
    dataset_count = Column(Integer, nullable=False, default=0)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)

    # Relationships
    parent = relationship("Category", remote_side=[id], backref="children")

    def __repr__(self) -> str:
        return f"<Category(id={self.id}, name='{self.name}', slug='{self.slug}')>"

    def to_dict(self, include_children: bool = False) -> Dict[str, Any]:
        """Convert to dictionary for API responses"""
        data = {
            "id": self.id,
            "slug": self.slug,
            "name": self.name,
            "display_name": self.display_name,
            "description": self.description,
            "icon": self.icon,
            "color": self.color,
            "parent_id": self.parent_id,
            "is_system": self.is_system,
            "is_global": self.is_global,
            "sort_order": self.sort_order,
            "assistant_count": self.assistant_count,
            "dataset_count": self.dataset_count,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

        if include_children and self.children:
            data["children"] = [child.to_dict() for child in self.children]

        return data

    def get_full_path(self) -> str:
        """Get full category path (e.g., 'AI Tools > Research > Academic')"""
        if not self.parent_id:
            return self.display_name

        # Simple recursion to build path
        parent_path = self.parent.get_full_path() if self.parent else ""
        return f"{parent_path} > {self.display_name}" if parent_path else self.display_name

    def is_descendant_of(self, ancestor_id: str) -> bool:
        """Check if this category is a descendant of another"""
        if not self.parent_id:
            return False

        if self.parent_id == ancestor_id:
            return True

        return self.parent.is_descendant_of(ancestor_id) if self.parent else False

    def get_all_descendants(self) -> List["Category"]:
        """Get all descendant categories"""
        descendants = []

        if self.children:
            for child in self.children:
                descendants.append(child)
                descendants.extend(child.get_all_descendants())

        return descendants

    def update_counts(self, assistant_delta: int = 0, dataset_delta: int = 0) -> None:
        """Update resource counts for this category"""
        self.assistant_count = max(0, self.assistant_count + assistant_delta)
        self.dataset_count = max(0, self.dataset_count + dataset_delta)
        self.updated_at = datetime.utcnow()


# GT 2.0 Default System Categories
DEFAULT_CATEGORIES = [
    # Top-level categories
    {
        "slug": "research",
        "name": "Research & Analysis",
        "display_name": "Research & Analysis",
        "description": "Agents for research, analysis, and information synthesis",
        "icon": "🔍",
        "color": "#3B82F6",  # Blue
        "is_system": True,
        "is_global": True,
        "sort_order": 10,
    },
    {
        "slug": "development",
        "name": "Software Development",
        "display_name": "Software Development",
        "description": "Coding, debugging, and development tools",
        "icon": "💻",
        "color": "#10B981",  # Green
        "is_system": True,
        "is_global": True,
        "sort_order": 20,
    },
    {
        "slug": "cybersecurity",
        "name": "Cybersecurity",
        "display_name": "Cybersecurity",
        "description": "Security analysis, threat detection, and incident response",
        "icon": "🛡️",
        "color": "#EF4444",  # Red
        "is_system": True,
        "is_global": True,
        "sort_order": 30,
    },
    {
        "slug": "education",
        "name": "Education & Training",
        "display_name": "Education & Training",
        "description": "Educational agents and AI literacy tools",
        "icon": "🎓",
        "color": "#8B5CF6",  # Purple
        "is_system": True,
        "is_global": True,
        "sort_order": 40,
    },
    {
        "slug": "creative",
        "name": "Creative & Content",
        "display_name": "Creative & Content",
        "description": "Writing, design, and creative content generation",
        "icon": "✨",
        "color": "#F59E0B",  # Amber
        "is_system": True,
        "is_global": True,
        "sort_order": 50,
    },
    {
        "slug": "analytics",
        "name": "Data & Analytics",
        "display_name": "Data & Analytics",
        "description": "Data analysis, visualization, and insights",
        "icon": "📊",
        "color": "#06B6D4",  # Cyan
        "is_system": True,
        "is_global": True,
        "sort_order": 60,
    },
    {
        "slug": "business",
        "name": "Business & Operations",
        "display_name": "Business & Operations",
        "description": "Business analysis, planning, and operations",
        "icon": "💼",
        "color": "#64748B",  # Slate
        "is_system": True,
        "is_global": True,
        "sort_order": 70,
    },
    {
        "slug": "personal",
        "name": "Personal Productivity",
        "display_name": "Personal Productivity",
        "description": "Personal agents and productivity tools",
        "icon": "🚀",
        "color": "#14B8A6",  # Teal
        "is_system": True,
        "is_global": True,
        "sort_order": 80,
    },
    {
        "slug": "custom",
        "name": "Custom & Specialized",
        "display_name": "Custom & Specialized",
        "description": "Custom-built and specialized agents",
        "icon": "⚙️",
        "color": "#71717A",  # Zinc
        "is_system": True,
        "is_global": True,
        "sort_order": 90,
    },
]


# Sub-categories (examples)
DEFAULT_SUBCATEGORIES = [
    # Research subcategories
    {
        "slug": "research-academic",
        "name": "Academic Research",
        "display_name": "Academic Research",
        "description": "Academic papers, citations, and literature review",
        "icon": "📚",
        "parent_slug": "research",  # Will be resolved to parent_id
        "is_system": True,
        "is_global": True,
        "sort_order": 11,
    },
    {
        "slug": "research-market",
        "name": "Market Research",
        "display_name": "Market Research",
        "description": "Market analysis, competitor research, and trends",
        "icon": "📈",
        "parent_slug": "research",
        "is_system": True,
        "is_global": True,
        "sort_order": 12,
    },

    # Development subcategories
    {
        "slug": "dev-web",
        "name": "Web Development",
        "display_name": "Web Development",
        "description": "Frontend, backend, and full-stack development",
        "icon": "🌐",
        "parent_slug": "development",
        "is_system": True,
        "is_global": True,
        "sort_order": 21,
    },
    {
        "slug": "dev-mobile",
        "name": "Mobile Development",
        "display_name": "Mobile Development",
        "description": "iOS, Android, and cross-platform development",
        "icon": "📱",
        "parent_slug": "development",
        "is_system": True,
        "is_global": True,
        "sort_order": 22,
    },
    {
        "slug": "dev-devops",
        "name": "DevOps & Infrastructure",
        "display_name": "DevOps & Infrastructure",
        "description": "CI/CD, containerization, and infrastructure",
        "icon": "🔧",
        "parent_slug": "development",
        "is_system": True,
        "is_global": True,
        "sort_order": 23,
    },

    # Cybersecurity subcategories
    {
        "slug": "cyber-analysis",
        "name": "Threat Analysis",
        "display_name": "Threat Analysis",
        "description": "Threat detection, analysis, and intelligence",
        "icon": "🔍",
        "parent_slug": "cybersecurity",
        "is_system": True,
        "is_global": True,
        "sort_order": 31,
    },
    {
        "slug": "cyber-incident",
        "name": "Incident Response",
        "display_name": "Incident Response",
        "description": "Incident handling and forensics",
        "icon": "🚨",
        "parent_slug": "cybersecurity",
        "is_system": True,
        "is_global": True,
        "sort_order": 32,
    },

    # Education subcategories
    {
        "slug": "edu-ai-literacy",
        "name": "AI Literacy",
        "display_name": "AI Literacy",
        "description": "Understanding and working with AI systems",
        "icon": "🤖",
        "parent_slug": "education",
        "is_system": True,
        "is_global": True,
        "sort_order": 41,
    },
    {
        "slug": "edu-critical-thinking",
        "name": "Critical Thinking",
        "display_name": "Critical Thinking",
        "description": "Logic, reasoning, and problem-solving",
        "icon": "🧠",
        "parent_slug": "education",
        "is_system": True,
        "is_global": True,
        "sort_order": 42,
    },
]
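
# --- Illustrative usage sketch (added for this review, not part of the original module) ----------
# A minimal, hedged example of the parent/child helpers above, wired up in memory without a
# database session. The ids are made up; in the seeded data the parent link is resolved from the
# "parent_slug" entries in DEFAULT_SUBCATEGORIES instead.
if __name__ == "__main__":
    research = Category(id="cat-research", slug="research", name="Research & Analysis",
                        display_name="Research & Analysis")
    academic = Category(id="cat-academic", slug="research-academic", name="Academic Research",
                        display_name="Academic Research", parent_id="cat-research")
    academic.parent = research  # normally populated by the relationship when loaded from the DB

    print(academic.get_full_path())                    # "Research & Analysis > Academic Research"
    print(academic.is_descendant_of("cat-research"))   # True
    print(research.get_all_descendants())              # [<Category ... slug='research-academic'>]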
263
apps/tenant-backend/app/models/collaboration_team.py
Normal file
263
apps/tenant-backend/app/models/collaboration_team.py
Normal file
@@ -0,0 +1,263 @@
"""
Collaboration Team Models for GT 2.0 Tenant Backend

Pydantic models for user collaboration teams (team sharing system).
This is separate from the tenant isolation 'tenants' table (formerly 'teams').

Database Schema:
- teams: User collaboration groups within a tenant
- team_memberships: Team members with two-tier permissions
"""

from datetime import datetime
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field, ConfigDict, field_validator


class TeamBase(BaseModel):
    """Base team model with common fields"""
    name: str = Field(..., min_length=1, max_length=255, description="Team name")
    description: Optional[str] = Field(None, description="Team description")


class TeamCreate(TeamBase):
    """Model for creating a new team"""
    pass


class TeamUpdate(BaseModel):
    """Model for updating a team"""
    name: Optional[str] = Field(None, min_length=1, max_length=255)
    description: Optional[str] = None


class TeamMember(BaseModel):
    """Team member with permissions"""
    id: str = Field(..., description="Membership UUID")
    team_id: str = Field(..., description="Team UUID")
    user_id: str = Field(..., description="User UUID")
    user_email: str = Field(..., description="User email")
    user_name: str = Field(..., description="User display name")
    team_permission: str = Field(..., description="Team-level permission: 'read', 'share', or 'manager'")
    resource_permissions: Dict[str, str] = Field(default_factory=dict, description="Resource-level permissions JSONB")
    is_owner: bool = Field(default=False, description="Whether this member is the team owner")
    is_observable: bool = Field(default=False, description="Member consents to activity observation")
    observable_consent_status: str = Field(default="none", description="Consent status: 'none', 'pending', 'approved', 'revoked'")
    observable_consent_at: Optional[str] = Field(None, description="When Observable status was approved")
    status: str = Field(default="accepted", description="Membership status: 'pending', 'accepted', or 'declined'")
    invited_at: Optional[str] = None
    responded_at: Optional[str] = None
    joined_at: Optional[str] = None
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    model_config = ConfigDict(from_attributes=True)


class Team(TeamBase):
    """Complete team model with metadata"""
    id: str = Field(..., description="Team UUID")
    tenant_id: str = Field(..., description="Tenant UUID")
    owner_id: str = Field(..., description="Owner user UUID")
    owner_name: Optional[str] = Field(None, description="Owner display name")
    owner_email: Optional[str] = Field(None, description="Owner email")
    is_owner: bool = Field(..., description="Whether current user is the owner")
    can_manage: bool = Field(..., description="Whether current user can manage the team")
    user_permission: Optional[str] = Field(None, description="Current user's team permission: 'read' or 'share' (None if owner)")
    member_count: int = Field(0, description="Number of team members")
    shared_resource_count: int = Field(0, description="Number of shared resources (agents and datasets)")
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    model_config = ConfigDict(from_attributes=True)


class TeamWithMembers(Team):
    """Team with full member list"""
    members: List[TeamMember] = Field(default_factory=list, description="List of team members")


class TeamListResponse(BaseModel):
    """Response model for listing teams"""
    data: List[Team]
    total: int

    model_config = ConfigDict(from_attributes=True)


class TeamResponse(BaseModel):
    """Response model for single team operation"""
    data: Team

    model_config = ConfigDict(from_attributes=True)


class TeamWithMembersResponse(BaseModel):
    """Response model for team with members"""
    data: TeamWithMembers

    model_config = ConfigDict(from_attributes=True)


# Team Membership Models

class AddMemberRequest(BaseModel):
    """Request model for adding a member to a team"""
    user_email: str = Field(..., description="Email of user to add")
    team_permission: str = Field("read", description="Team permission: 'read', 'share', or 'manager'")


class UpdateMemberPermissionRequest(BaseModel):
    """Request model for updating member permission"""
    team_permission: str = Field(..., description="New permission: 'read', 'share', or 'manager'")

    @field_validator('team_permission')
    @classmethod
    def validate_permission(cls, v: str) -> str:
        if v not in ["read", "share", "manager"]:
            raise ValueError(f"Invalid permission: {v}. Must be 'read', 'share', or 'manager'")
        return v
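
# --- Illustrative usage sketch (added for this review, not part of the original module) ----------
# A minimal, hedged example of how the validator above surfaces bad input: any value outside
# read/share/manager is rejected by Pydantic before the request reaches the API layer.
if __name__ == "__main__":
    from pydantic import ValidationError

    print(UpdateMemberPermissionRequest(team_permission="share").team_permission)  # "share"
    try:
        UpdateMemberPermissionRequest(team_permission="admin")
    except ValidationError as exc:
        print(exc.errors()[0]["msg"])  # message lists the allowed values
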
class MemberListResponse(BaseModel):
    """Response model for listing team members"""
    data: List[TeamMember]
    total: int

    model_config = ConfigDict(from_attributes=True)


class MemberResponse(BaseModel):
    """Response model for single member operation"""
    data: TeamMember

    model_config = ConfigDict(from_attributes=True)


# Team Invitation Models

class TeamInvitation(BaseModel):
    """Pending team invitation"""
    id: str = Field(..., description="Invitation (membership) UUID")
    team_id: str = Field(..., description="Team UUID")
    team_name: str = Field(..., description="Team name")
    team_description: Optional[str] = Field(None, description="Team description")
    owner_name: str = Field(..., description="Team owner display name")
    owner_email: str = Field(..., description="Team owner email")
    team_permission: str = Field(..., description="Invited permission: 'read', 'share', or 'manager'")
    observable_requested: bool = Field(default=False, description="Whether Observable access was requested on invite")
    invited_at: str = Field(..., description="Invitation timestamp")

    model_config = ConfigDict(from_attributes=True)


class InvitationActionRequest(BaseModel):
    """Request to accept or decline invitation"""
    action: str = Field(..., description="Action: 'accept' or 'decline'")


class InvitationListResponse(BaseModel):
    """Response model for listing invitations"""
    data: List[TeamInvitation]
    total: int

    model_config = ConfigDict(from_attributes=True)


# Resource Sharing Models

class ShareResourceRequest(BaseModel):
    """Request model for sharing a resource to team"""
    resource_type: str = Field(..., description="Resource type: 'agent' or 'dataset'")
    resource_id: str = Field(..., description="Resource UUID")
    user_permissions: Dict[str, str] = Field(
        ...,
        description="User permissions: {user_id: 'read'|'edit'}"
    )


class SharedResource(BaseModel):
    """Model for a shared resource"""
    resource_type: str = Field(..., description="Resource type: 'agent' or 'dataset'")
    resource_id: str = Field(..., description="Resource UUID")
    resource_name: str = Field(..., description="Resource name")
    resource_owner: str = Field(..., description="Resource owner name or email")
    user_permissions: Dict[str, str] = Field(..., description="User permissions map")


class SharedResourcesResponse(BaseModel):
    """Response model for listing shared resources"""
    data: List[SharedResource]
    total: int

    model_config = ConfigDict(from_attributes=True)


# Observable Request Models

class ObservableRequest(BaseModel):
    """Observable access request for a team member"""
    team_id: str = Field(..., description="Team UUID")
    team_name: str = Field(..., description="Team name")
    requested_by_name: str = Field(..., description="Name of manager/owner who requested")
    requested_by_email: str = Field(..., description="Email of manager/owner who requested")
    requested_at: str = Field(..., description="When request was made")

    model_config = ConfigDict(from_attributes=True)


class ObservableRequestListResponse(BaseModel):
    """Response model for listing Observable requests"""
    data: List[ObservableRequest]
    total: int

    model_config = ConfigDict(from_attributes=True)


# Team Activity Models

class TeamActivityMetrics(BaseModel):
    """Team activity metrics for Observable members"""
    team_id: str
    team_name: str
    date_range_days: int
    observable_member_count: int
    total_member_count: int
    team_totals: Dict[str, Any] = Field(
        default_factory=dict,
        description="Aggregated metrics: conversations, messages, tokens"
    )
    member_breakdown: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Per-member activity stats"
    )
    time_series: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Activity over time"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TeamActivityResponse(BaseModel):
|
||||
"""Response model for team activity"""
|
||||
data: TeamActivityMetrics
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Error Response Models
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
"""Error detail model"""
|
||||
message: str
|
||||
field: Optional[str] = None
|
||||
code: Optional[str] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response model"""
|
||||
error: str
|
||||
details: Optional[List[ErrorDetail]] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
148
apps/tenant-backend/app/models/conversation.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Conversation Model for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for conversation entities using the PostgreSQL + PGVector backend.
|
||||
Stores conversation metadata and settings for AI chat sessions.
|
||||
Perfect tenant isolation - each tenant has separate conversation data.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
class ConversationStatus(str, Enum):
|
||||
"""Conversation status enumeration"""
|
||||
ACTIVE = "active"
|
||||
ARCHIVED = "archived"
|
||||
DELETED = "deleted"
|
||||
|
||||
|
||||
class Conversation(BaseServiceModel):
|
||||
"""
|
||||
Conversation model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents a chat session with an AI agent including metadata,
|
||||
configuration, and usage statistics.
|
||||
"""
|
||||
|
||||
# Core conversation properties
|
||||
title: str = Field(..., min_length=1, max_length=200, description="Conversation title")
|
||||
agent_id: Optional[str] = Field(None, description="Associated agent ID")
|
||||
|
||||
# User information
|
||||
created_by: str = Field(..., description="Email or ID of the user who created this conversation")
|
||||
user_name: Optional[str] = Field(None, max_length=100, description="User display name")
|
||||
|
||||
# Configuration
|
||||
system_prompt: Optional[str] = Field(None, description="Custom system prompt override")
|
||||
model_id: str = Field(default="groq:llama3-70b-8192", description="AI model identifier")
|
||||
configuration: Dict[str, Any] = Field(default_factory=dict, description="Model parameters and settings")
|
||||
|
||||
# Status and metadata
|
||||
status: ConversationStatus = Field(default=ConversationStatus.ACTIVE, description="Conversation status")
|
||||
tags: List[str] = Field(default_factory=list, description="Conversation tags")
|
||||
|
||||
# Statistics
|
||||
message_count: int = Field(default=0, description="Number of messages in conversation")
|
||||
total_tokens_used: int = Field(default=0, description="Total tokens used")
|
||||
total_cost_cents: int = Field(default=0, description="Total cost in cents")
|
||||
|
||||
# Timestamps
|
||||
last_activity_at: Optional[datetime] = Field(None, description="Last activity timestamp")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "conversations"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Conversation":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
agent_id=data.get("agent_id"),
|
||||
title=data.get("title", ""),
|
||||
system_prompt=data.get("system_prompt"),
|
||||
model_id=data.get("model_id", "groq:llama3-70b-8192"),
|
||||
created_by=data.get("created_by", ""),
|
||||
user_name=data.get("user_name"),
|
||||
configuration=data.get("configuration", {}),
|
||||
tags=data.get("tags", []),
|
||||
)
|
||||
|
||||
def update_statistics(self, message_count: int, tokens_used: int, cost_cents: int) -> None:
|
||||
"""Update conversation statistics"""
|
||||
self.message_count = message_count
|
||||
self.total_tokens_used = tokens_used
|
||||
self.total_cost_cents = cost_cents
|
||||
self.last_activity_at = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def archive(self) -> None:
|
||||
"""Archive this conversation"""
|
||||
self.status = ConversationStatus.ARCHIVED
|
||||
self.update_timestamp()
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Mark conversation as deleted"""
|
||||
self.status = ConversationStatus.DELETED
|
||||
self.update_timestamp()
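# Illustrative sketch (not part of the original module): a typical lifecycle of a
# Conversation built from a request payload. The payload keys mirror from_dict();
# the concrete values are made up, and the sketch assumes BaseServiceModel supplies
# defaults for its own fields (id, tenant scoping), as from_dict() above implies.
def _example_conversation_lifecycle() -> Conversation:
    payload = {
        "title": "Quarterly report Q&A",
        "created_by": "user@example.com",
        "model_id": "groq:llama3-70b-8192",
        "tags": ["reports"],
    }
    conversation = Conversation.from_dict(payload)
    # After a chat turn, the caller pushes updated usage numbers:
    conversation.update_statistics(message_count=2, tokens_used=512, cost_cents=3)
    # Archiving keeps the record but removes it from active listings:
    conversation.archive()
    return conversation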
|
||||
|
||||
|
||||
class ConversationCreate(BaseCreateModel):
|
||||
"""Model for creating new conversations"""
|
||||
title: str = Field(..., min_length=1, max_length=200)
|
||||
agent_id: Optional[str] = None
|
||||
created_by: str
|
||||
user_name: Optional[str] = Field(None, max_length=100)
|
||||
system_prompt: Optional[str] = None
|
||||
model_id: str = Field(default="groq:llama3-70b-8192")
|
||||
configuration: Dict[str, Any] = Field(default_factory=dict)
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ConversationUpdate(BaseUpdateModel):
|
||||
"""Model for updating conversations"""
|
||||
title: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
system_prompt: Optional[str] = None
|
||||
model_id: Optional[str] = None
|
||||
configuration: Optional[Dict[str, Any]] = None
|
||||
status: Optional[ConversationStatus] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ConversationResponse(BaseResponseModel):
|
||||
"""Model for conversation API responses"""
|
||||
id: str
|
||||
title: str
|
||||
agent_id: Optional[str]
|
||||
created_by: str
|
||||
user_name: Optional[str]
|
||||
system_prompt: Optional[str]
|
||||
model_id: str
|
||||
configuration: Dict[str, Any]
|
||||
status: ConversationStatus
|
||||
tags: List[str]
|
||||
message_count: int
|
||||
total_tokens_used: int
|
||||
total_cost_cents: int
|
||||
last_activity_at: Optional[datetime]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
435
apps/tenant-backend/app/models/document.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Document and RAG Models for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for document entities using the PostgreSQL + PGVector backend.
|
||||
Stores document metadata, RAG datasets, and processing status.
|
||||
Perfect tenant isolation - each tenant has separate document data.
|
||||
All vectors stored in the tenant-specific PostgreSQL + PGVector store.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
# SQLAlchemy imports for database models
|
||||
from sqlalchemy import Column, String, Integer, BigInteger, Text, DateTime, Boolean, JSON, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
# PGVector import for embeddings
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
except ImportError:
|
||||
# Fallback if pgvector not available
|
||||
from sqlalchemy import Text as Vector
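# Note on the fallback above: when pgvector is not installed, Vector is aliased to
# sqlalchemy.Text, so the `embedding` column defined below is created as plain text
# and the vector distance operators (e.g. cosine_distance) are unavailable. Callers
# that need similarity search should treat pgvector as a hard dependency.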
|
||||
|
||||
|
||||
class DocumentStatus(str, Enum):
|
||||
"""Document processing status enumeration"""
|
||||
UPLOADING = "uploading"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class DocumentType(str, Enum):
|
||||
"""Document type enumeration"""
|
||||
PDF = "pdf"
|
||||
DOCX = "docx"
|
||||
TXT = "txt"
|
||||
MD = "md"
|
||||
HTML = "html"
|
||||
JSON = "json"
|
||||
CSV = "csv"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class Document(BaseServiceModel):
|
||||
"""
|
||||
Document model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents a document with metadata, processing status,
|
||||
and RAG integration for knowledge retrieval.
|
||||
"""
|
||||
|
||||
# Core document properties
|
||||
filename: str = Field(..., min_length=1, max_length=255, description="Original filename")
|
||||
original_name: str = Field(..., min_length=1, max_length=255, description="User-provided name")
|
||||
file_size: int = Field(..., ge=0, description="File size in bytes")
|
||||
mime_type: str = Field(..., max_length=100, description="MIME type of the file")
|
||||
doc_type: DocumentType = Field(..., description="Document type classification")
|
||||
|
||||
# Storage and processing
|
||||
file_path: str = Field(..., description="Storage path for the file")
|
||||
content_hash: Optional[str] = Field(None, max_length=64, description="SHA-256 hash of content")
|
||||
status: DocumentStatus = Field(default=DocumentStatus.UPLOADING, description="Processing status")
|
||||
|
||||
# Owner and access
|
||||
owner_id: str = Field(..., description="User ID of the document owner")
|
||||
dataset_id: Optional[str] = Field(None, description="Associated dataset ID")
|
||||
|
||||
# RAG and processing metadata
|
||||
content_preview: Optional[str] = Field(None, max_length=500, description="Content preview")
|
||||
extracted_text: Optional[str] = Field(None, description="Extracted text content")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
|
||||
|
||||
# Processing statistics
|
||||
chunk_count: int = Field(default=0, description="Number of chunks created")
|
||||
vector_count: int = Field(default=0, description="Number of vectors stored")
|
||||
processing_time_ms: Optional[float] = Field(None, description="Processing time in milliseconds")
|
||||
|
||||
# Errors and logs
|
||||
error_message: Optional[str] = Field(None, description="Error message if processing failed")
|
||||
processing_log: List[str] = Field(default_factory=list, description="Processing log entries")
|
||||
|
||||
# Timestamps
|
||||
processed_at: Optional[datetime] = Field(None, description="When processing completed")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "documents"
|
||||
|
||||
def mark_processing(self) -> None:
|
||||
"""Mark document as processing"""
|
||||
self.status = DocumentStatus.PROCESSING
|
||||
self.update_timestamp()
|
||||
|
||||
def mark_completed(self, chunk_count: int, vector_count: int, processing_time_ms: float) -> None:
|
||||
"""Mark document processing as completed"""
|
||||
self.status = DocumentStatus.COMPLETED
|
||||
self.chunk_count = chunk_count
|
||||
self.vector_count = vector_count
|
||||
self.processing_time_ms = processing_time_ms
|
||||
self.processed_at = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def mark_failed(self, error_message: str) -> None:
|
||||
"""Mark document processing as failed"""
|
||||
self.status = DocumentStatus.FAILED
|
||||
self.error_message = error_message
|
||||
self.update_timestamp()
|
||||
|
||||
def add_log_entry(self, message: str) -> None:
|
||||
"""Add a processing log entry"""
|
||||
timestamp = datetime.utcnow().isoformat()
|
||||
self.processing_log.append(f"[{timestamp}] {message}")
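# Illustrative sketch (not part of the original module): how a worker might drive the
# processing-state helpers above. `extract_and_chunk` is a hypothetical callable that
# stands in for the real ingestion pipeline; the timing figure is a made-up example.
def _example_document_processing(document: Document, extract_and_chunk) -> None:
    document.mark_processing()
    document.add_log_entry("started text extraction")
    try:
        chunks = extract_and_chunk(document.file_path)
        document.mark_completed(
            chunk_count=len(chunks),
            vector_count=len(chunks),
            processing_time_ms=1250.0,  # example value only
        )
    except Exception as exc:
        document.mark_failed(str(exc))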
|
||||
|
||||
|
||||
class RAGDataset(BaseServiceModel):
|
||||
"""
|
||||
RAG Dataset model for organizing documents into collections.
|
||||
|
||||
Groups related documents together for focused retrieval and
|
||||
provides dataset-level configuration and statistics.
|
||||
"""
|
||||
|
||||
# Core dataset properties
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
|
||||
description: Optional[str] = Field(None, max_length=1000, description="Dataset description")
|
||||
|
||||
# Owner and access
|
||||
owner_id: str = Field(..., description="User ID of the dataset owner")
|
||||
|
||||
# Configuration
|
||||
chunk_size: int = Field(default=1000, ge=100, le=5000, description="Default chunk size")
|
||||
chunk_overlap: int = Field(default=200, ge=0, le=1000, description="Default chunk overlap")
|
||||
embedding_model: str = Field(default="all-MiniLM-L6-v2", description="Embedding model to use")
|
||||
|
||||
# Statistics
|
||||
document_count: int = Field(default=0, description="Number of documents")
|
||||
total_chunks: int = Field(default=0, description="Total chunks across all documents")
|
||||
total_vectors: int = Field(default=0, description="Total vectors stored")
|
||||
total_size_bytes: int = Field(default=0, description="Total size of all documents")
|
||||
|
||||
# Status
|
||||
is_public: bool = Field(default=False, description="Whether dataset is publicly accessible")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "rag_datasets"
|
||||
|
||||
def update_statistics(self, doc_count: int, chunk_count: int, vector_count: int, size_bytes: int) -> None:
|
||||
"""Update dataset statistics"""
|
||||
self.document_count = doc_count
|
||||
self.total_chunks = chunk_count
|
||||
self.total_vectors = vector_count
|
||||
self.total_size_bytes = size_bytes
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
class DatasetDocument(BaseServiceModel):
|
||||
"""
|
||||
Dataset-Document relationship model for GT 2.0 service-based architecture.
|
||||
|
||||
Junction table model that links documents to RAG datasets,
|
||||
tracking the relationship and statistics.
|
||||
"""
|
||||
|
||||
# Core relationship properties
|
||||
dataset_id: str = Field(..., description="RAG dataset ID")
|
||||
document_id: str = Field(..., description="Document ID")
|
||||
user_id: str = Field(..., description="User who added document to dataset")
|
||||
|
||||
# Statistics
|
||||
chunk_count: int = Field(default=0, description="Number of chunks for this document")
|
||||
vector_count: int = Field(default=0, description="Number of vectors stored for this document")
|
||||
|
||||
# Status
|
||||
processing_status: str = Field(default="pending", max_length=50, description="Processing status")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "dataset_documents"
|
||||
|
||||
|
||||
class DocumentChunk(BaseServiceModel):
|
||||
"""
|
||||
Document chunk model for processed document pieces.
|
||||
|
||||
Represents individual chunks of processed documents with
|
||||
embeddings and metadata for RAG retrieval.
|
||||
"""
|
||||
|
||||
# Core chunk properties
|
||||
document_id: str = Field(..., description="Parent document ID")
|
||||
chunk_index: int = Field(..., ge=0, description="Chunk index within document")
|
||||
chunk_text: str = Field(..., min_length=1, description="Chunk text content")
|
||||
|
||||
# Chunk metadata
|
||||
chunk_size: int = Field(..., ge=1, description="Character count of chunk")
|
||||
token_count: Optional[int] = Field(None, description="Token count for chunk")
|
||||
chunk_metadata: Dict[str, Any] = Field(default_factory=dict, description="Chunk-specific metadata")
|
||||
|
||||
# Embedding information
|
||||
embedding_id: Optional[str] = Field(None, description="Vector store embedding ID")
|
||||
embedding_model: Optional[str] = Field(None, max_length=100, description="Model used for embedding")
|
||||
|
||||
# Position and context
|
||||
start_char: Optional[int] = Field(None, description="Starting character position in document")
|
||||
end_char: Optional[int] = Field(None, description="Ending character position in document")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "document_chunks"
|
||||
|
||||
|
||||
class DocumentCreate(BaseCreateModel):
|
||||
"""Model for creating new documents"""
|
||||
filename: str = Field(..., min_length=1, max_length=255)
|
||||
original_name: str = Field(..., min_length=1, max_length=255)
|
||||
file_size: int = Field(..., ge=0)
|
||||
mime_type: str = Field(..., max_length=100)
|
||||
doc_type: DocumentType
|
||||
file_path: str
|
||||
content_hash: Optional[str] = Field(None, max_length=64)
|
||||
owner_id: str
|
||||
dataset_id: Optional[str] = None
|
||||
content_preview: Optional[str] = Field(None, max_length=500)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DocumentUpdate(BaseUpdateModel):
|
||||
"""Model for updating documents"""
|
||||
original_name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
status: Optional[DocumentStatus] = None
|
||||
dataset_id: Optional[str] = None
|
||||
content_preview: Optional[str] = Field(None, max_length=500)
|
||||
extracted_text: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
chunk_count: Optional[int] = Field(None, ge=0)
|
||||
vector_count: Optional[int] = Field(None, ge=0)
|
||||
processing_time_ms: Optional[float] = None
|
||||
error_message: Optional[str] = None
|
||||
processed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class DocumentResponse(BaseResponseModel):
|
||||
"""Model for document API responses"""
|
||||
id: str
|
||||
filename: str
|
||||
original_name: str
|
||||
file_size: int
|
||||
mime_type: str
|
||||
doc_type: DocumentType
|
||||
file_path: str
|
||||
content_hash: Optional[str]
|
||||
status: DocumentStatus
|
||||
owner_id: str
|
||||
dataset_id: Optional[str]
|
||||
content_preview: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
chunk_count: int
|
||||
vector_count: int
|
||||
processing_time_ms: Optional[float]
|
||||
error_message: Optional[str]
|
||||
processing_log: List[str]
|
||||
processed_at: Optional[datetime]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class RAGDatasetCreate(BaseCreateModel):
|
||||
"""Model for creating new RAG datasets"""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
owner_id: str
|
||||
chunk_size: int = Field(default=1000, ge=100, le=5000)
|
||||
chunk_overlap: int = Field(default=200, ge=0, le=1000)
|
||||
embedding_model: str = Field(default="all-MiniLM-L6-v2")
|
||||
is_public: bool = Field(default=False)
|
||||
|
||||
|
||||
class RAGDatasetUpdate(BaseUpdateModel):
|
||||
"""Model for updating RAG datasets"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
chunk_size: Optional[int] = Field(None, ge=100, le=5000)
|
||||
chunk_overlap: Optional[int] = Field(None, ge=0, le=1000)
|
||||
embedding_model: Optional[str] = None
|
||||
is_public: Optional[bool] = None
|
||||
|
||||
|
||||
class RAGDatasetResponse(BaseResponseModel):
|
||||
"""Model for RAG dataset API responses"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
owner_id: str
|
||||
chunk_size: int
|
||||
chunk_overlap: int
|
||||
embedding_model: str
|
||||
document_count: int
|
||||
total_chunks: int
|
||||
total_vectors: int
|
||||
total_size_bytes: int
|
||||
is_public: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# SQLAlchemy ORM models for PostgreSQL + PGVector (suffixed "ORM" to keep them distinct from the Pydantic models above)
|
||||
|
||||
class DocumentORM(Base):
|
||||
"""SQLAlchemy model for documents table"""
|
||||
__tablename__ = "documents"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
dataset_id = Column(UUID(as_uuid=True), nullable=True, index=True)
|
||||
|
||||
filename = Column(String(255), nullable=False)
|
||||
original_filename = Column(String(255), nullable=False)
|
||||
file_type = Column(String(100), nullable=False)
|
||||
file_size_bytes = Column(BigInteger, nullable=False)
|
||||
file_hash = Column(String(64), nullable=True)
|
||||
|
||||
content_text = Column(Text, nullable=True)
|
||||
chunk_count = Column(Integer, default=0)
|
||||
processing_status = Column(String(50), default="pending")
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
doc_metadata = Column(JSONB, nullable=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
|
||||
# Relationships
|
||||
chunks = relationship("DocumentChunkORM", back_populates="document", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class DocumentChunkORM(Base):
|
||||
"""SQLAlchemy model for document_chunks table"""
|
||||
__tablename__ = "document_chunks"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
document_id = Column(UUID(as_uuid=True), ForeignKey("documents.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
dataset_id = Column(UUID(as_uuid=True), nullable=True, index=True)
|
||||
|
||||
chunk_index = Column(Integer, nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
content_hash = Column(String(32), nullable=True)
|
||||
token_count = Column(Integer, nullable=True)
|
||||
|
||||
# PGVector embedding column (1024 dimensions for BGE-M3)
|
||||
embedding = Column(Vector(1024), nullable=True)
|
||||
|
||||
chunk_metadata = Column(JSONB, nullable=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
|
||||
# Relationships
|
||||
document = relationship("DocumentORM", back_populates="chunks")
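# Illustrative sketch (not part of the original module): a cosine-distance search over
# document_chunks using the pgvector comparator. Assumes pgvector is installed (see the
# import fallback above), `session` is a synchronous SQLAlchemy Session bound to the
# tenant database, and `query_embedding` is a 1024-dimensional list of floats.
def _example_similar_chunks(session, query_embedding, dataset_id, limit: int = 5):
    return (
        session.query(DocumentChunkORM)
        .filter(DocumentChunkORM.dataset_id == dataset_id)
        .order_by(DocumentChunkORM.embedding.cosine_distance(query_embedding))
        .limit(limit)
        .all()
    )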
|
||||
|
||||
|
||||
class Dataset(Base):
|
||||
"""SQLAlchemy model for datasets table"""
|
||||
__tablename__ = "datasets"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True) # created_by in schema
|
||||
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
chunk_size = Column(Integer, default=512)
|
||||
chunk_overlap = Column(Integer, default=128)
|
||||
embedding_model = Column(String(100), default='BAAI/bge-m3')
|
||||
search_method = Column(String(20), default='hybrid')
|
||||
specialized_language = Column(Boolean, default=False)
|
||||
|
||||
is_active = Column(Boolean, default=True)
|
||||
visibility = Column(String(20), default='individual')
|
||||
access_group = Column(String(50), default='individual')
|
||||
|
||||
dataset_metadata = Column(JSONB, nullable=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
283
apps/tenant-backend/app/models/event.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Event Models for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for event entities using the PostgreSQL + PGVector backend.
|
||||
Handles event automation, triggers, and action definitions.
|
||||
Perfect tenant isolation with encrypted storage.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
"""Generate a unique identifier"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class EventStatus(str, Enum):
|
||||
"""Event status enumeration"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
RETRYING = "retrying"
|
||||
|
||||
|
||||
class Event(BaseServiceModel):
|
||||
"""
|
||||
Event model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents an automation event with processing status,
|
||||
payload data, and retry logic.
|
||||
"""
|
||||
|
||||
# Core event properties
|
||||
event_id: str = Field(default_factory=generate_uuid, description="Unique event identifier")
|
||||
event_type: str = Field(..., min_length=1, max_length=100, description="Event type identifier")
|
||||
user_id: str = Field(..., description="User who triggered the event")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Event data
|
||||
payload: Dict[str, Any] = Field(default_factory=dict, description="Encrypted event data")
|
||||
event_metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
|
||||
# Processing status
|
||||
status: EventStatus = Field(default=EventStatus.PENDING, description="Processing status")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
retry_count: int = Field(default=0, ge=0, description="Number of retry attempts")
|
||||
|
||||
# Timestamps
|
||||
started_at: Optional[datetime] = Field(None, description="Processing start time")
|
||||
completed_at: Optional[datetime] = Field(None, description="Processing completion time")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "events"
|
||||
|
||||
def is_completed(self) -> bool:
|
||||
"""Check if event processing is completed"""
|
||||
return self.status == EventStatus.COMPLETED
|
||||
|
||||
def is_failed(self) -> bool:
|
||||
"""Check if event processing failed"""
|
||||
return self.status == EventStatus.FAILED
|
||||
|
||||
def mark_processing(self) -> None:
|
||||
"""Mark event as processing"""
|
||||
self.status = EventStatus.PROCESSING
|
||||
self.started_at = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def mark_completed(self) -> None:
|
||||
"""Mark event as completed"""
|
||||
self.status = EventStatus.COMPLETED
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def mark_failed(self, error_message: str) -> None:
|
||||
"""Mark event as failed"""
|
||||
self.status = EventStatus.FAILED
|
||||
self.error_message = error_message
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def increment_retry(self) -> None:
|
||||
"""Increment retry count"""
|
||||
self.retry_count += 1
|
||||
self.status = EventStatus.RETRYING
|
||||
self.update_timestamp()
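# Illustrative sketch (not part of the original module): how a dispatcher might use the
# state helpers above. `handle` is a hypothetical callable that performs the actual work
# for one event; `max_attempts` is an example policy, not a real configuration setting.
def _example_process_event(event: Event, handle, max_attempts: int = 3) -> None:
    while True:
        event.mark_processing()
        try:
            handle(event.payload)
            event.mark_completed()
            return
        except Exception as exc:
            if event.retry_count >= max_attempts:
                event.mark_failed(str(exc))
                return
            event.increment_retry()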
|
||||
|
||||
|
||||
class EventTrigger(BaseServiceModel):
|
||||
"""
|
||||
Event trigger model for automation conditions.
|
||||
|
||||
Defines conditions that will trigger event processing.
|
||||
"""
|
||||
|
||||
# Core trigger properties
|
||||
trigger_name: str = Field(..., min_length=1, max_length=100, description="Trigger name")
|
||||
event_type: str = Field(..., min_length=1, max_length=100, description="Event type to trigger")
|
||||
user_id: str = Field(..., description="User who owns this trigger")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Trigger configuration
|
||||
conditions: Dict[str, Any] = Field(default_factory=dict, description="Trigger conditions")
|
||||
trigger_config: Dict[str, Any] = Field(default_factory=dict, description="Trigger configuration")
|
||||
|
||||
# Status
|
||||
is_active: bool = Field(default=True, description="Whether trigger is active")
|
||||
trigger_count: int = Field(default=0, description="Number of times triggered")
|
||||
last_triggered: Optional[datetime] = Field(None, description="Last trigger timestamp")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "event_triggers"
|
||||
|
||||
|
||||
class EventAction(BaseServiceModel):
|
||||
"""
|
||||
Event action model for automation responses.
|
||||
|
||||
Defines actions to take when events are processed.
|
||||
"""
|
||||
|
||||
# Core action properties
|
||||
action_name: str = Field(..., min_length=1, max_length=100, description="Action name")
|
||||
event_type: str = Field(..., min_length=1, max_length=100, description="Event type this action handles")
|
||||
user_id: str = Field(..., description="User who owns this action")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Action configuration
|
||||
action_type: str = Field(..., min_length=1, max_length=50, description="Type of action")
|
||||
action_config: Dict[str, Any] = Field(default_factory=dict, description="Action configuration")
|
||||
|
||||
# Execution settings
|
||||
priority: int = Field(default=10, ge=1, le=100, description="Execution priority")
|
||||
timeout_seconds: int = Field(default=300, ge=1, le=3600, description="Action timeout")
|
||||
max_retries: int = Field(default=3, ge=0, le=10, description="Maximum retry attempts")
|
||||
|
||||
# Status
|
||||
is_active: bool = Field(default=True, description="Whether action is active")
|
||||
execution_count: int = Field(default=0, description="Number of times executed")
|
||||
last_executed: Optional[datetime] = Field(None, description="Last execution timestamp")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "event_actions"
|
||||
|
||||
|
||||
class EventSubscription(BaseServiceModel):
|
||||
"""
|
||||
Event subscription model for user notifications.
|
||||
|
||||
Manages user subscriptions to specific event types.
|
||||
"""
|
||||
|
||||
# Core subscription properties
|
||||
user_id: str = Field(..., description="Subscribing user ID")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
event_type: str = Field(..., min_length=1, max_length=100, description="Subscribed event type")
|
||||
|
||||
# Subscription configuration
|
||||
notification_method: str = Field(default="websocket", max_length=50, description="Notification delivery method")
|
||||
subscription_config: Dict[str, Any] = Field(default_factory=dict, description="Subscription settings")
|
||||
|
||||
# Filtering
|
||||
event_filters: Dict[str, Any] = Field(default_factory=dict, description="Event filtering criteria")
|
||||
|
||||
# Status
|
||||
is_active: bool = Field(default=True, description="Whether subscription is active")
|
||||
notification_count: int = Field(default=0, description="Number of notifications sent")
|
||||
last_notified: Optional[datetime] = Field(None, description="Last notification timestamp")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "event_subscriptions"
|
||||
|
||||
|
||||
# Create/Update/Response models
|
||||
|
||||
class EventCreate(BaseCreateModel):
|
||||
"""Model for creating new events"""
|
||||
event_type: str = Field(..., min_length=1, max_length=100)
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
payload: Dict[str, Any] = Field(default_factory=dict)
|
||||
event_metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EventUpdate(BaseUpdateModel):
|
||||
"""Model for updating events"""
|
||||
status: Optional[EventStatus] = None
|
||||
error_message: Optional[str] = None
|
||||
retry_count: Optional[int] = Field(None, ge=0)
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class EventResponse(BaseResponseModel):
|
||||
"""Model for event API responses"""
|
||||
id: str
|
||||
event_id: str
|
||||
event_type: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
payload: Dict[str, Any]
|
||||
event_metadata: Dict[str, Any]
|
||||
status: EventStatus
|
||||
error_message: Optional[str]
|
||||
retry_count: int
|
||||
started_at: Optional[datetime]
|
||||
completed_at: Optional[datetime]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# Legacy compatibility - simplified versions of missing models
|
||||
class EventLog(BaseServiceModel):
|
||||
"""Minimal EventLog model for compatibility"""
|
||||
event_id: str = Field(..., description="Related event ID")
|
||||
log_message: str = Field(..., description="Log message")
|
||||
log_level: str = Field(default="info", description="Log level")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
return "event_logs"
|
||||
|
||||
|
||||
class ScheduledTask(BaseServiceModel):
|
||||
"""Minimal ScheduledTask model for compatibility"""
|
||||
task_name: str = Field(..., description="Task name")
|
||||
schedule: str = Field(..., description="Cron schedule")
|
||||
is_active: bool = Field(default=True, description="Whether task is active")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
return "scheduled_tasks"
|
||||
254
apps/tenant-backend/app/models/external_service.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
External Service Models for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for external service entities using the PostgreSQL + PGVector backend.
|
||||
Manages external web services integration with SSO and iframe embedding.
|
||||
Perfect tenant isolation - each tenant has separate external service data.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
"""Generate a unique identifier"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class ServiceStatus(str, Enum):
|
||||
"""Service status enumeration"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
MAINTENANCE = "maintenance"
|
||||
DEPRECATED = "deprecated"
|
||||
|
||||
|
||||
class AccessLevel(str, Enum):
|
||||
"""Access level enumeration"""
|
||||
PUBLIC = "public"
|
||||
AUTHENTICATED = "authenticated"
|
||||
ADMIN_ONLY = "admin_only"
|
||||
RESTRICTED = "restricted"
|
||||
|
||||
|
||||
class ExternalServiceInstance(BaseServiceModel):
|
||||
"""
|
||||
External service instance model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents external web services like Canvas LMS, Jupyter Hub, CTFd
|
||||
with SSO integration and iframe embedding.
|
||||
"""
|
||||
|
||||
# Core service properties
|
||||
service_name: str = Field(..., min_length=1, max_length=100, description="Service name")
|
||||
service_type: str = Field(..., min_length=1, max_length=50, description="Service type")
|
||||
service_url: str = Field(..., description="Service URL")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Service configuration
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="Service configuration")
|
||||
auth_config: Dict[str, Any] = Field(default_factory=dict, description="Authentication configuration")
|
||||
iframe_config: Dict[str, Any] = Field(default_factory=dict, description="Iframe embedding configuration")
|
||||
|
||||
# Service details
|
||||
description: Optional[str] = Field(None, max_length=500, description="Service description")
|
||||
version: str = Field(default="1.0.0", max_length=50, description="Service version")
|
||||
provider: str = Field(..., max_length=100, description="Service provider")
|
||||
|
||||
# Access control
|
||||
access_level: AccessLevel = Field(default=AccessLevel.AUTHENTICATED, description="Access level required")
|
||||
allowed_users: List[str] = Field(default_factory=list, description="Allowed user IDs")
|
||||
allowed_roles: List[str] = Field(default_factory=list, description="Allowed user roles")
|
||||
|
||||
# Status and monitoring
|
||||
status: ServiceStatus = Field(default=ServiceStatus.ACTIVE, description="Service status")
|
||||
health_check_url: Optional[str] = Field(None, description="Health check endpoint")
|
||||
last_health_check: Optional[datetime] = Field(None, description="Last health check timestamp")
|
||||
is_healthy: bool = Field(default=True, description="Health status")
|
||||
|
||||
# Usage statistics
|
||||
total_access_count: int = Field(default=0, description="Total access count")
|
||||
active_user_count: int = Field(default=0, description="Current active users")
|
||||
last_accessed: Optional[datetime] = Field(None, description="Last access timestamp")
|
||||
|
||||
# Metadata
|
||||
tags: List[str] = Field(default_factory=list, description="Service tags")
|
||||
category: str = Field(default="general", max_length=50, description="Service category")
|
||||
priority: int = Field(default=10, ge=1, le=100, description="Display priority")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "external_service_instances"
|
||||
|
||||
def activate(self) -> None:
|
||||
"""Activate the service"""
|
||||
self.status = ServiceStatus.ACTIVE
|
||||
self.update_timestamp()
|
||||
|
||||
def deactivate(self) -> None:
|
||||
"""Deactivate the service"""
|
||||
self.status = ServiceStatus.INACTIVE
|
||||
self.update_timestamp()
|
||||
|
||||
def record_access(self, user_id: str) -> None:
|
||||
"""Record service access"""
|
||||
self.total_access_count += 1
|
||||
self.last_accessed = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def update_health_status(self, is_healthy: bool) -> None:
|
||||
"""Update health status"""
|
||||
self.is_healthy = is_healthy
|
||||
self.last_health_check = datetime.utcnow()
|
||||
self.update_timestamp()
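# Illustrative sketch (not part of the original module): a polling health check that
# feeds update_health_status() above. Assumes httpx is available (it is used elsewhere
# in this backend); the 5-second timeout and the < 500 status threshold are example
# choices, not project settings.
async def _example_check_service_health(service: ExternalServiceInstance) -> bool:
    import httpx

    if not service.health_check_url:
        return service.is_healthy
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(service.health_check_url)
        healthy = response.status_code < 500
    except httpx.HTTPError:
        healthy = False
    service.update_health_status(healthy)
    return healthy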
|
||||
|
||||
|
||||
class ServiceAccessLog(BaseServiceModel):
|
||||
"""
|
||||
Service access log model for tracking usage and security.
|
||||
|
||||
Logs all access attempts to external services for auditing.
|
||||
"""
|
||||
|
||||
# Core access properties
|
||||
service_id: str = Field(..., description="External service instance ID")
|
||||
user_id: str = Field(..., description="User who accessed the service")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Access details
|
||||
access_type: str = Field(..., max_length=50, description="Type of access")
|
||||
ip_address: Optional[str] = Field(None, max_length=45, description="User IP address")
|
||||
user_agent: Optional[str] = Field(None, max_length=500, description="User agent string")
|
||||
|
||||
# Session information
|
||||
session_id: Optional[str] = Field(None, description="User session ID")
|
||||
session_duration_seconds: Optional[int] = Field(None, description="Session duration")
|
||||
|
||||
# Access result
|
||||
access_granted: bool = Field(default=True, description="Whether access was granted")
|
||||
denial_reason: Optional[str] = Field(None, description="Reason for access denial")
|
||||
|
||||
# Additional metadata
|
||||
referrer_url: Optional[str] = Field(None, description="Referrer URL")
|
||||
access_metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional access data")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "service_access_logs"
|
||||
|
||||
|
||||
class ServiceTemplate(BaseServiceModel):
|
||||
"""
|
||||
Service template model for reusable service configurations.
|
||||
|
||||
Defines templates for common external service integrations.
|
||||
"""
|
||||
|
||||
# Core template properties
|
||||
template_name: str = Field(..., min_length=1, max_length=100, description="Template name")
|
||||
service_type: str = Field(..., min_length=1, max_length=50, description="Service type")
|
||||
template_description: str = Field(..., max_length=500, description="Template description")
|
||||
|
||||
# Template configuration
|
||||
default_config: Dict[str, Any] = Field(default_factory=dict, description="Default service configuration")
|
||||
default_auth_config: Dict[str, Any] = Field(default_factory=dict, description="Default auth configuration")
|
||||
default_iframe_config: Dict[str, Any] = Field(default_factory=dict, description="Default iframe configuration")
|
||||
|
||||
# Template metadata
|
||||
version: str = Field(default="1.0.0", max_length=50, description="Template version")
|
||||
provider: str = Field(..., max_length=100, description="Service provider")
|
||||
supported_versions: List[str] = Field(default_factory=list, description="Supported service versions")
|
||||
|
||||
# Documentation
|
||||
setup_instructions: Optional[str] = Field(None, description="Setup instructions")
|
||||
configuration_schema: Dict[str, Any] = Field(default_factory=dict, description="Configuration schema")
|
||||
example_config: Dict[str, Any] = Field(default_factory=dict, description="Example configuration")
|
||||
|
||||
# Template status
|
||||
is_active: bool = Field(default=True, description="Whether template is active")
|
||||
is_verified: bool = Field(default=False, description="Whether template is verified")
|
||||
usage_count: int = Field(default=0, description="Number of times used")
|
||||
|
||||
# Access control
|
||||
is_public: bool = Field(default=True, description="Whether template is publicly available")
|
||||
created_by: str = Field(..., description="Creator of the template")
|
||||
tenant_id: Optional[str] = Field(None, description="Tenant ID if tenant-specific")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "service_templates"
|
||||
|
||||
def increment_usage(self) -> None:
|
||||
"""Increment usage count"""
|
||||
self.usage_count += 1
|
||||
self.update_timestamp()
|
||||
|
||||
def verify_template(self) -> None:
|
||||
"""Mark template as verified"""
|
||||
self.is_verified = True
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
# Create/Update/Response models - minimal for now
|
||||
|
||||
class ExternalServiceInstanceCreate(BaseCreateModel):
|
||||
"""Model for creating external service instances"""
|
||||
service_name: str = Field(..., min_length=1, max_length=100)
|
||||
service_type: str = Field(..., min_length=1, max_length=50)
|
||||
service_url: str
|
||||
tenant_id: str
|
||||
provider: str = Field(..., max_length=100)
|
||||
|
||||
|
||||
class ExternalServiceInstanceUpdate(BaseUpdateModel):
|
||||
"""Model for updating external service instances"""
|
||||
service_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
service_url: Optional[str] = None
|
||||
status: Optional[ServiceStatus] = None
|
||||
is_healthy: Optional[bool] = None
|
||||
|
||||
|
||||
class ExternalServiceInstanceResponse(BaseResponseModel):
|
||||
"""Model for external service instance API responses"""
|
||||
id: str
|
||||
service_name: str
|
||||
service_type: str
|
||||
service_url: str
|
||||
tenant_id: str
|
||||
provider: str
|
||||
status: ServiceStatus
|
||||
is_healthy: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
383
apps/tenant-backend/app/models/game.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Game Models for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for game entities using the PostgreSQL + PGVector backend.
|
||||
Game sessions for AI literacy and strategic thinking development.
|
||||
Perfect tenant isolation - each tenant has separate game data.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
"""Generate a unique identifier"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class GameType(str, Enum):
|
||||
"""Game type enumeration"""
|
||||
CHESS = "chess"
|
||||
GO = "go"
|
||||
LOGIC_PUZZLE = "logic_puzzle"
|
||||
PHILOSOPHICAL_DILEMMA = "philosophical_dilemma"
|
||||
TRIVIA = "trivia"
|
||||
DEBATE = "debate"
|
||||
|
||||
|
||||
class DifficultyLevel(str, Enum):
|
||||
"""Difficulty level enumeration"""
|
||||
BEGINNER = "beginner"
|
||||
INTERMEDIATE = "intermediate"
|
||||
ADVANCED = "advanced"
|
||||
EXPERT = "expert"
|
||||
|
||||
|
||||
class GameStatus(str, Enum):
|
||||
"""Game status enumeration"""
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
PAUSED = "paused"
|
||||
ABANDONED = "abandoned"
|
||||
|
||||
|
||||
class GameSession(BaseServiceModel):
|
||||
"""
|
||||
Game session model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents AI literacy and strategic thinking game sessions
|
||||
with progress tracking and skill development.
|
||||
"""
|
||||
|
||||
# Core game properties
|
||||
user_id: str = Field(..., description="User playing the game")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
game_type: GameType = Field(..., description="Type of game")
|
||||
game_name: str = Field(..., min_length=1, max_length=100, description="Game name")
|
||||
|
||||
# Game configuration
|
||||
difficulty_level: DifficultyLevel = Field(default=DifficultyLevel.INTERMEDIATE, description="Difficulty level")
|
||||
ai_opponent_config: Dict[str, Any] = Field(default_factory=dict, description="AI opponent settings")
|
||||
game_rules: Dict[str, Any] = Field(default_factory=dict, description="Game-specific rules")
|
||||
|
||||
# Game state
|
||||
current_state: Dict[str, Any] = Field(default_factory=dict, description="Current game state")
|
||||
move_history: List[Dict[str, Any]] = Field(default_factory=list, description="History of moves")
|
||||
game_status: GameStatus = Field(default=GameStatus.ACTIVE, description="Game status")
|
||||
|
||||
# Progress tracking
|
||||
moves_count: int = Field(default=0, description="Number of moves made")
|
||||
hints_used: int = Field(default=0, description="Number of hints used")
|
||||
time_spent_seconds: int = Field(default=0, description="Time spent in seconds")
|
||||
current_rating: int = Field(default=1200, description="ELO-style rating")
|
||||
|
||||
# Results
|
||||
winner: Optional[str] = Field(None, description="Winner of the game")
|
||||
final_score: Optional[Dict[str, Any]] = Field(None, description="Final score details")
|
||||
learning_insights: List[str] = Field(default_factory=list, description="Learning insights")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "game_sessions"
|
||||
|
||||
def add_move(self, move_data: Dict[str, Any]) -> None:
|
||||
"""Add a move to the game history"""
|
||||
self.move_history.append(move_data)
|
||||
self.moves_count += 1
|
||||
self.update_timestamp()
|
||||
|
||||
def use_hint(self) -> None:
|
||||
"""Record hint usage"""
|
||||
self.hints_used += 1
|
||||
self.update_timestamp()
|
||||
|
||||
def complete_game(self, winner: str, final_score: Dict[str, Any]) -> None:
|
||||
"""Mark game as completed"""
|
||||
self.game_status = GameStatus.COMPLETED
|
||||
self.winner = winner
|
||||
self.final_score = final_score
|
||||
self.update_timestamp()
|
||||
|
||||
def pause_game(self) -> None:
|
||||
"""Pause the game"""
|
||||
self.game_status = GameStatus.PAUSED
|
||||
self.update_timestamp()
|
||||
|
||||
def resume_game(self) -> None:
|
||||
"""Resume the game"""
|
||||
self.game_status = GameStatus.ACTIVE
|
||||
self.update_timestamp()
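# Illustrative sketch (not part of the original module): a minimal move/complete flow
# using the helpers above. The move payloads, end condition, and score are made-up
# example values, not real game logic.
def _example_play_turn(session: GameSession) -> None:
    session.add_move({"player": "human", "move": "e2e4"})
    session.add_move({"player": "ai", "move": "e7e5"})
    if session.moves_count >= 2:  # toy end condition for the example
        session.complete_game(winner="draw", final_score={"human": 0.5, "ai": 0.5})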
|
||||
|
||||
|
||||
class PuzzleSession(BaseServiceModel):
|
||||
"""
|
||||
Puzzle session model for logic and problem-solving games.
|
||||
|
||||
Tracks puzzle-specific metrics and progress.
|
||||
"""
|
||||
|
||||
# Core puzzle properties
|
||||
user_id: str = Field(..., description="User solving the puzzle")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
puzzle_type: str = Field(..., max_length=50, description="Type of puzzle")
|
||||
puzzle_name: str = Field(..., min_length=1, max_length=100, description="Puzzle name")
|
||||
|
||||
# Puzzle configuration
|
||||
difficulty_level: DifficultyLevel = Field(default=DifficultyLevel.INTERMEDIATE, description="Difficulty level")
|
||||
puzzle_data: Dict[str, Any] = Field(default_factory=dict, description="Puzzle configuration")
|
||||
solution_data: Dict[str, Any] = Field(default_factory=dict, description="Solution information")
|
||||
|
||||
# Progress tracking
|
||||
attempts_made: int = Field(default=0, description="Number of attempts")
|
||||
hints_requested: int = Field(default=0, description="Hints requested")
|
||||
is_solved: bool = Field(default=False, description="Whether puzzle is solved")
|
||||
solve_time_seconds: Optional[int] = Field(None, description="Time to solve")
|
||||
|
||||
# Learning metrics
|
||||
skill_points_earned: int = Field(default=0, description="Skill points earned")
|
||||
concepts_learned: List[str] = Field(default_factory=list, description="Concepts learned")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "puzzle_sessions"
|
||||
|
||||
def add_attempt(self) -> None:
|
||||
"""Record a puzzle attempt"""
|
||||
self.attempts_made += 1
|
||||
self.update_timestamp()
|
||||
|
||||
def solve_puzzle(self, solve_time: int, skill_points: int) -> None:
|
||||
"""Mark puzzle as solved"""
|
||||
self.is_solved = True
|
||||
self.solve_time_seconds = solve_time
|
||||
self.skill_points_earned = skill_points
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
class PhilosophicalDialogue(BaseServiceModel):
|
||||
"""
|
||||
Philosophical dialogue model for ethical and critical thinking development.
|
||||
|
||||
Tracks philosophical discussions and thinking development.
|
||||
"""
|
||||
|
||||
# Core dialogue properties
|
||||
user_id: str = Field(..., description="User participating in dialogue")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
dialogue_topic: str = Field(..., min_length=1, max_length=200, description="Dialogue topic")
|
||||
dialogue_type: str = Field(..., max_length=50, description="Type of philosophical dialogue")
|
||||
|
||||
# Dialogue configuration
|
||||
ai_persona: str = Field(default="socratic", max_length=50, description="AI dialogue persona")
|
||||
dialogue_style: str = Field(default="questioning", max_length=50, description="Dialogue style")
|
||||
target_concepts: List[str] = Field(default_factory=list, description="Target concepts to explore")
|
||||
|
||||
# Dialogue content
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list, description="Dialogue messages")
|
||||
key_insights: List[str] = Field(default_factory=list, description="Key insights generated")
|
||||
|
||||
# Progress metrics
|
||||
turns_count: int = Field(default=0, description="Number of dialogue turns")
|
||||
depth_score: float = Field(default=0.0, description="Depth of philosophical exploration")
|
||||
critical_thinking_score: float = Field(default=0.0, description="Critical thinking score")
|
||||
|
||||
# Status
|
||||
is_completed: bool = Field(default=False, description="Whether dialogue is completed")
|
||||
completion_reason: Optional[str] = Field(None, description="Reason for completion")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "philosophical_dialogues"
|
||||
|
||||
def add_message(self, message_data: Dict[str, Any]) -> None:
|
||||
"""Add a message to the dialogue"""
|
||||
self.messages.append(message_data)
|
||||
self.turns_count += 1
|
||||
self.update_timestamp()
|
||||
|
||||
def complete_dialogue(self, reason: str) -> None:
|
||||
"""Mark dialogue as completed"""
|
||||
self.is_completed = True
|
||||
self.completion_reason = reason
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
class LearningAnalytics(BaseServiceModel):
|
||||
"""
|
||||
Learning analytics model for tracking educational progress.
|
||||
|
||||
Aggregates learning data across all game types.
|
||||
"""
|
||||
|
||||
# Core analytics properties
|
||||
user_id: str = Field(..., description="User being analyzed")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Skill tracking
|
||||
chess_rating: int = Field(default=1200, description="Chess skill rating")
|
||||
go_rating: int = Field(default=1200, description="Go skill rating")
|
||||
puzzle_solving_level: int = Field(default=1, description="Puzzle solving level")
|
||||
critical_thinking_level: int = Field(default=1, description="Critical thinking level")
|
||||
|
||||
# Activity metrics
|
||||
total_games_played: int = Field(default=0, description="Total games played")
|
||||
total_puzzles_solved: int = Field(default=0, description="Total puzzles solved")
|
||||
total_dialogues_completed: int = Field(default=0, description="Total dialogues completed")
|
||||
total_time_spent_hours: float = Field(default=0.0, description="Total time spent in hours")
|
||||
|
||||
# Learning metrics
|
||||
concepts_mastered: List[str] = Field(default_factory=list, description="Mastered concepts")
|
||||
learning_streaks: Dict[str, int] = Field(default_factory=dict, description="Learning streaks")
|
||||
achievement_badges: List[str] = Field(default_factory=list, description="Achievement badges")
|
||||
|
||||
# Progress tracking
|
||||
last_activity_date: Optional[datetime] = Field(None, description="Last activity date")
|
||||
learning_goals: List[Dict[str, Any]] = Field(default_factory=list, description="Learning goals")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "learning_analytics"
|
||||
|
||||
def update_activity(self) -> None:
|
||||
"""Update last activity timestamp"""
|
||||
self.last_activity_date = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def earn_badge(self, badge_name: str) -> None:
|
||||
"""Earn an achievement badge"""
|
||||
if badge_name not in self.achievement_badges:
|
||||
self.achievement_badges.append(badge_name)
|
||||
self.update_timestamp()
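

# --- Illustrative usage (not part of the original file) ---
# earn_badge is idempotent: a badge name is appended only if it is not already
# present, so repeated calls do not duplicate entries.
def _example_earn_badge(analytics: "LearningAnalytics") -> None:
    analytics.earn_badge("first_win")
    analytics.earn_badge("first_win")         # second call appends nothing
    assert analytics.achievement_badges.count("first_win") == 1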
|
||||
|
||||
|
||||
class GameTemplate(BaseServiceModel):
|
||||
"""
|
||||
Game template model for configuring game types and rules.
|
||||
|
||||
Defines reusable game configurations and templates.
|
||||
"""
|
||||
|
||||
# Core template properties
|
||||
template_name: str = Field(..., min_length=1, max_length=100, description="Template name")
|
||||
game_type: GameType = Field(..., description="Game type")
|
||||
template_description: str = Field(..., max_length=500, description="Template description")
|
||||
|
||||
# Template configuration
|
||||
default_rules: Dict[str, Any] = Field(default_factory=dict, description="Default game rules")
|
||||
ai_configurations: List[Dict[str, Any]] = Field(default_factory=list, description="AI opponent configs")
|
||||
difficulty_settings: Dict[str, Any] = Field(default_factory=dict, description="Difficulty settings")
|
||||
|
||||
# Educational content
|
||||
learning_objectives: List[str] = Field(default_factory=list, description="Learning objectives")
|
||||
skill_categories: List[str] = Field(default_factory=list, description="Skill categories")
|
||||
educational_notes: Optional[str] = Field(None, description="Educational notes")
|
||||
|
||||
# Template metadata
|
||||
created_by: str = Field(..., description="Creator of the template")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
is_public: bool = Field(default=False, description="Whether template is publicly available")
|
||||
usage_count: int = Field(default=0, description="Number of times used")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "game_templates"
|
||||
|
||||
def increment_usage(self) -> None:
|
||||
"""Increment usage count"""
|
||||
self.usage_count += 1
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
# Create/Update/Response models
|
||||
|
||||
class GameSessionCreate(BaseCreateModel):
|
||||
"""Model for creating new game sessions"""
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
game_type: GameType
|
||||
game_name: str = Field(..., min_length=1, max_length=100)
|
||||
difficulty_level: DifficultyLevel = Field(default=DifficultyLevel.INTERMEDIATE)
|
||||
ai_opponent_config: Dict[str, Any] = Field(default_factory=dict)
|
||||
game_rules: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GameSessionUpdate(BaseUpdateModel):
|
||||
"""Model for updating game sessions"""
|
||||
current_state: Optional[Dict[str, Any]] = None
|
||||
game_status: Optional[GameStatus] = None
|
||||
time_spent_seconds: Optional[int] = Field(None, ge=0)
|
||||
current_rating: Optional[int] = Field(None, ge=0, le=3000)
|
||||
winner: Optional[str] = None
|
||||
final_score: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GameSessionResponse(BaseResponseModel):
|
||||
"""Model for game session API responses"""
|
||||
id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
game_type: GameType
|
||||
game_name: str
|
||||
difficulty_level: DifficultyLevel
|
||||
current_state: Dict[str, Any]
|
||||
move_history: List[Dict[str, Any]]
|
||||
game_status: GameStatus
|
||||
moves_count: int
|
||||
hints_used: int
|
||||
time_spent_seconds: int
|
||||
current_rating: int
|
||||
winner: Optional[str]
|
||||
final_score: Optional[Dict[str, Any]]
|
||||
learning_insights: List[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
123 apps/tenant-backend/app/models/message.py Normal file
@@ -0,0 +1,123 @@
"""
Message Model for GT 2.0 Tenant Backend - Service-Based Architecture

Pydantic models for message entities using the PostgreSQL + PGVector backend.
Stores individual messages within conversations with full context tracking.
Perfect tenant isolation - each tenant has separate message data.
"""

from datetime import datetime
from typing import List, Optional, Dict, Any
from enum import Enum

from pydantic import Field, ConfigDict
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel


class MessageRole(str, Enum):
    """Message role enumeration"""
    SYSTEM = "system"
    USER = "user"
    AGENT = "agent"
    TOOL = "tool"


class Message(BaseServiceModel):
    """
    Message model for GT 2.0 service-based architecture.

    Represents a single message within a conversation including content,
    role, metadata, and usage statistics.
    """

    # Core message properties
    conversation_id: str = Field(..., description="ID of the parent conversation")
    role: MessageRole = Field(..., description="Message role (system, user, agent, tool)")
    content: str = Field(..., description="Message content")

    # Optional metadata
    model_used: Optional[str] = Field(None, description="AI model used for generation")
    tool_calls: Optional[List[Dict[str, Any]]] = Field(default_factory=list, description="Tool calls made")
    tool_call_id: Optional[str] = Field(None, description="Tool call ID if this is a tool response")

    # Usage statistics
    tokens_used: int = Field(default=0, description="Tokens consumed by this message")
    cost_cents: int = Field(default=0, description="Cost in cents for this message")

    # Processing metadata
    processing_time_ms: Optional[float] = Field(None, description="Time taken to process this message")
    temperature: Optional[float] = Field(None, description="Temperature used for generation")
    max_tokens: Optional[int] = Field(None, description="Max tokens setting used")

    # Status
    is_edited: bool = Field(default=False, description="Whether message was edited")
    is_deleted: bool = Field(default=False, description="Whether message was deleted")

    # Model configuration
    model_config = ConfigDict(
        protected_namespaces=(),
        json_encoders={
            datetime: lambda v: v.isoformat() if v else None
        }
    )

    @classmethod
    def get_table_name(cls) -> str:
        """Get the database table name"""
        return "messages"

    def mark_edited(self) -> None:
        """Mark message as edited"""
        self.is_edited = True
        self.update_timestamp()

    def mark_deleted(self) -> None:
        """Mark message as deleted"""
        self.is_deleted = True
        self.update_timestamp()


class MessageCreate(BaseCreateModel):
    """Model for creating new messages"""
    conversation_id: str
    role: MessageRole
    content: str
    model_used: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = Field(default_factory=list)
    tool_call_id: Optional[str] = None
    tokens_used: int = Field(default=0)
    cost_cents: int = Field(default=0)
    processing_time_ms: Optional[float] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

    model_config = ConfigDict(protected_namespaces=())


class MessageUpdate(BaseUpdateModel):
    """Model for updating messages"""
    content: Optional[str] = None
    is_edited: Optional[bool] = None
    is_deleted: Optional[bool] = None


class MessageResponse(BaseResponseModel):
    """Model for message API responses"""
    id: str
    conversation_id: str
    role: MessageRole
    content: str
    model_used: Optional[str]
    tool_calls: List[Dict[str, Any]]
    tool_call_id: Optional[str]
    tokens_used: int
    cost_cents: int
    processing_time_ms: Optional[float]
    temperature: Optional[float]
    max_tokens: Optional[int]
    is_edited: bool
    is_deleted: bool
    created_at: datetime
    updated_at: datetime

    model_config = ConfigDict(protected_namespaces=())
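

# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of recording an agent reply. Only fields declared above are
# used; it assumes BaseCreateModel adds no extra required fields, and the
# conversation id and model name are placeholder values.
def _example_record_agent_reply() -> MessageCreate:
    return MessageCreate(
        conversation_id="conv-123",           # placeholder id
        role=MessageRole.AGENT,
        content="Here is the summary you asked for.",
        model_used="nvidia/kimi-k2",          # illustrative model name
        tokens_used=512,
        cost_cents=3,
        processing_time_ms=820.5,
    )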
309 apps/tenant-backend/app/models/team.py Normal file
@@ -0,0 +1,309 @@
"""
|
||||
Team and Organization Models for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for team entities using the PostgreSQL + PGVector backend.
|
||||
Implements team-based collaboration with file-based isolation.
|
||||
Follows GT 2.0's principle of "Elegant Simplicity Through Intelligent Architecture"
|
||||
- File-based team configurations with PostgreSQL reference tracking
|
||||
- Perfect tenant isolation - each tenant has separate team data
|
||||
- Zero complexity addition through simple file structures
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import uuid
|
||||
import os
|
||||
import json
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
"""Generate a unique identifier"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TeamType(str, Enum):
|
||||
"""Team type enumeration"""
|
||||
DEPARTMENT = "department"
|
||||
PROJECT = "project"
|
||||
CROSS_FUNCTIONAL = "cross_functional"
|
||||
|
||||
|
||||
class RoleType(str, Enum):
|
||||
"""Role type enumeration"""
|
||||
OWNER = "owner"
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
VIEWER = "viewer"
|
||||
|
||||
|
||||
class Team(BaseServiceModel):
|
||||
"""
|
||||
Team model for GT 2.0 service-based architecture.
|
||||
|
||||
GT 2.0 Design: Teams are lightweight PostgreSQL references to file-based configurations.
|
||||
Team data is stored in encrypted files, not complex database relationships.
|
||||
"""
|
||||
|
||||
# Team identifier
|
||||
team_uuid: str = Field(default_factory=generate_uuid, description="Unique team identifier")
|
||||
|
||||
# Team details
|
||||
name: str = Field(..., min_length=1, max_length=200, description="Team name")
|
||||
description: Optional[str] = Field(None, max_length=1000, description="Team description")
|
||||
team_type: TeamType = Field(default=TeamType.PROJECT, description="Team type")
|
||||
|
||||
# File-based configuration reference
|
||||
config_file_path: str = Field(..., description="Path to team config.json")
|
||||
members_file_path: str = Field(..., description="Path to members.json")
|
||||
|
||||
# Owner and access
|
||||
created_by: str = Field(..., description="User who created this team")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Team settings
|
||||
is_active: bool = Field(default=True, description="Whether team is active")
|
||||
is_public: bool = Field(default=False, description="Whether team is publicly visible")
|
||||
max_members: int = Field(default=50, ge=1, le=1000, description="Maximum team members")
|
||||
|
||||
# Statistics
|
||||
member_count: int = Field(default=0, description="Current member count")
|
||||
resource_count: int = Field(default=0, description="Number of shared resources")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "teams"
|
||||
|
||||
def get_config_path(self) -> str:
|
||||
"""Get the full path to team configuration file"""
|
||||
return self.config_file_path
|
||||
|
||||
def get_members_path(self) -> str:
|
||||
"""Get the full path to team members file"""
|
||||
return self.members_file_path
|
||||
|
||||
def activate(self) -> None:
|
||||
"""Activate the team"""
|
||||
self.is_active = True
|
||||
self.update_timestamp()
|
||||
|
||||
def deactivate(self) -> None:
|
||||
"""Deactivate the team"""
|
||||
self.is_active = False
|
||||
self.update_timestamp()
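

# --- Illustrative usage (not part of the original file) ---
# Teams only reference file-based configuration; a caller is expected to read
# the JSON documents at config_file_path / members_file_path itself. Minimal
# sketch, assuming those files exist and contain JSON objects.
def _example_load_team_files(team: "Team") -> Dict[str, Any]:
    with open(team.get_config_path(), "r", encoding="utf-8") as f:
        config = json.load(f)
    with open(team.get_members_path(), "r", encoding="utf-8") as f:
        members = json.load(f)
    return {"config": config, "members": members}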
|
||||
|
||||
|
||||
class TeamRole(BaseServiceModel):
|
||||
"""
|
||||
Team role model for user permissions within teams.
|
||||
|
||||
Manages role-based access control for team resources.
|
||||
"""
|
||||
|
||||
# Core role properties
|
||||
team_id: str = Field(..., description="Team ID")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
role_type: RoleType = Field(..., description="Role type")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Role configuration
|
||||
permissions: Dict[str, bool] = Field(default_factory=dict, description="Role permissions")
|
||||
custom_permissions: Dict[str, Any] = Field(default_factory=dict, description="Custom permissions")
|
||||
|
||||
# Role details
|
||||
assigned_by: str = Field(..., description="User who assigned this role")
|
||||
role_description: Optional[str] = Field(None, max_length=500, description="Role description")
|
||||
|
||||
# Status
|
||||
is_active: bool = Field(default=True, description="Whether role is active")
|
||||
expires_at: Optional[datetime] = Field(None, description="Role expiration time")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "team_roles"
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if role is expired"""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def has_permission(self, permission: str) -> bool:
|
||||
"""Check if role has specific permission"""
|
||||
return self.permissions.get(permission, False)
|
||||
|
||||
def grant_permission(self, permission: str) -> None:
|
||||
"""Grant a permission to this role"""
|
||||
self.permissions[permission] = True
|
||||
self.update_timestamp()
|
||||
|
||||
def revoke_permission(self, permission: str) -> None:
|
||||
"""Revoke a permission from this role"""
|
||||
self.permissions[permission] = False
|
||||
self.update_timestamp()
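

# --- Illustrative usage (not part of the original file) ---
# Permission checks combine the expiry test with the boolean permission map;
# "share_agents" is an illustrative permission key, not one defined by this model.
def _example_can_share(role: "TeamRole") -> bool:
    if not role.is_active or role.is_expired():
        return False
    return role.has_permission("share_agents")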
|
||||
|
||||
|
||||
class OrganizationSettings(BaseServiceModel):
|
||||
"""
|
||||
Organization settings model for tenant-wide configuration.
|
||||
|
||||
Manages organization-level settings and policies.
|
||||
"""
|
||||
|
||||
# Organization details
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
organization_name: str = Field(..., min_length=1, max_length=200, description="Organization name")
|
||||
organization_domain: str = Field(..., description="Organization domain")
|
||||
|
||||
# Organization settings
|
||||
settings: Dict[str, Any] = Field(default_factory=dict, description="Organization settings")
|
||||
branding: Dict[str, Any] = Field(default_factory=dict, description="Branding configuration")
|
||||
|
||||
# Team policies
|
||||
default_team_settings: Dict[str, Any] = Field(default_factory=dict, description="Default team settings")
|
||||
team_creation_policy: str = Field(default="admin_only", description="Who can create teams")
|
||||
max_teams_per_user: int = Field(default=10, ge=1, le=100, description="Max teams per user")
|
||||
|
||||
# Security policies
|
||||
security_settings: Dict[str, Any] = Field(default_factory=dict, description="Security settings")
|
||||
data_retention_days: int = Field(default=365, ge=30, le=2555, description="Data retention period")
|
||||
|
||||
# Feature flags
|
||||
features_enabled: Dict[str, bool] = Field(default_factory=dict, description="Enabled features")
|
||||
|
||||
# Contact and billing
|
||||
admin_email: Optional[str] = Field(None, description="Primary admin email")
|
||||
billing_contact: Optional[str] = Field(None, description="Billing contact email")
|
||||
|
||||
# Status
|
||||
is_active: bool = Field(default=True, description="Whether organization is active")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "organization_settings"
|
||||
|
||||
def is_feature_enabled(self, feature: str) -> bool:
|
||||
"""Check if a feature is enabled"""
|
||||
return self.features_enabled.get(feature, False)
|
||||
|
||||
def enable_feature(self, feature: str) -> None:
|
||||
"""Enable a feature"""
|
||||
self.features_enabled[feature] = True
|
||||
self.update_timestamp()
|
||||
|
||||
def disable_feature(self, feature: str) -> None:
|
||||
"""Disable a feature"""
|
||||
self.features_enabled[feature] = False
|
||||
self.update_timestamp()
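

# --- Illustrative usage (not part of the original file) ---
# Feature flags default to False when absent; "workflows" is an illustrative
# flag name, not one defined by this model.
def _example_enable_workflows(org: "OrganizationSettings") -> bool:
    if not org.is_feature_enabled("workflows"):
        org.enable_feature("workflows")
    return org.is_feature_enabled("workflows")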
|
||||
|
||||
|
||||
# Create/Update/Response models
|
||||
|
||||
class TeamCreate(BaseCreateModel):
|
||||
"""Model for creating new teams"""
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
team_type: TeamType = Field(default=TeamType.PROJECT)
|
||||
created_by: str
|
||||
tenant_id: str
|
||||
is_public: bool = Field(default=False)
|
||||
max_members: int = Field(default=50, ge=1, le=1000)
|
||||
|
||||
|
||||
class TeamUpdate(BaseUpdateModel):
|
||||
"""Model for updating teams"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
team_type: Optional[TeamType] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_public: Optional[bool] = None
|
||||
max_members: Optional[int] = Field(None, ge=1, le=1000)
|
||||
|
||||
|
||||
class TeamResponse(BaseResponseModel):
|
||||
"""Model for team API responses"""
|
||||
id: str
|
||||
team_uuid: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
team_type: TeamType
|
||||
config_file_path: str
|
||||
members_file_path: str
|
||||
created_by: str
|
||||
tenant_id: str
|
||||
is_active: bool
|
||||
is_public: bool
|
||||
max_members: int
|
||||
member_count: int
|
||||
resource_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class TeamRoleCreate(BaseCreateModel):
|
||||
"""Model for creating team roles"""
|
||||
team_id: str
|
||||
user_id: str
|
||||
role_type: RoleType
|
||||
tenant_id: str
|
||||
assigned_by: str
|
||||
permissions: Dict[str, bool] = Field(default_factory=dict)
|
||||
role_description: Optional[str] = Field(None, max_length=500)
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class TeamRoleUpdate(BaseUpdateModel):
|
||||
"""Model for updating team roles"""
|
||||
role_type: Optional[RoleType] = None
|
||||
permissions: Optional[Dict[str, bool]] = None
|
||||
custom_permissions: Optional[Dict[str, Any]] = None
|
||||
role_description: Optional[str] = Field(None, max_length=500)
|
||||
is_active: Optional[bool] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class TeamRoleResponse(BaseResponseModel):
|
||||
"""Model for team role API responses"""
|
||||
id: str
|
||||
team_id: str
|
||||
user_id: str
|
||||
role_type: RoleType
|
||||
tenant_id: str
|
||||
permissions: Dict[str, bool]
|
||||
custom_permissions: Dict[str, Any]
|
||||
assigned_by: str
|
||||
role_description: Optional[str]
|
||||
is_active: bool
|
||||
expires_at: Optional[datetime]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
146 apps/tenant-backend/app/models/user_session.py Normal file
@@ -0,0 +1,146 @@
"""
|
||||
User Session Model for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for user session entities using the PostgreSQL + PGVector backend.
|
||||
Stores user session data and authentication state.
|
||||
Perfect tenant isolation - each tenant has separate session data.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
class SessionStatus(str, Enum):
|
||||
"""Session status enumeration"""
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
REVOKED = "revoked"
|
||||
|
||||
|
||||
class UserSession(BaseServiceModel):
|
||||
"""
|
||||
User session model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents a user authentication session with state management,
|
||||
preferences, and activity tracking.
|
||||
"""
|
||||
|
||||
# Core session properties
|
||||
session_id: str = Field(..., description="Unique session identifier")
|
||||
user_id: str = Field(..., description="User ID (email or unique identifier)")
|
||||
user_email: Optional[str] = Field(None, max_length=255, description="User email address")
|
||||
user_name: Optional[str] = Field(None, max_length=100, description="User display name")
|
||||
|
||||
# Authentication details
|
||||
auth_provider: str = Field(default="jwt", max_length=50, description="Authentication provider")
|
||||
auth_method: str = Field(default="bearer", max_length=50, description="Authentication method")
|
||||
|
||||
# Session lifecycle
|
||||
status: SessionStatus = Field(default=SessionStatus.ACTIVE, description="Session status")
|
||||
expires_at: datetime = Field(..., description="Session expiration time")
|
||||
last_activity_at: datetime = Field(default_factory=datetime.utcnow, description="Last activity timestamp")
|
||||
|
||||
# User preferences and state
|
||||
preferences: Dict[str, Any] = Field(default_factory=dict, description="User preferences")
|
||||
session_data: Dict[str, Any] = Field(default_factory=dict, description="Session-specific data")
|
||||
|
||||
# Activity tracking
|
||||
login_ip: Optional[str] = Field(None, max_length=45, description="Login IP address")
|
||||
user_agent: Optional[str] = Field(None, max_length=500, description="User agent string")
|
||||
activity_count: int = Field(default=1, description="Number of activities in this session")
|
||||
|
||||
# Security
|
||||
csrf_token: Optional[str] = Field(None, max_length=64, description="CSRF protection token")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "user_sessions"
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session is expired"""
|
||||
return datetime.utcnow() > self.expires_at or self.status != SessionStatus.ACTIVE
|
||||
|
||||
def extend_session(self, minutes: int = 30) -> None:
|
||||
"""Extend session expiration time"""
|
||||
if self.status == SessionStatus.ACTIVE:
|
||||
self.expires_at = datetime.utcnow() + timedelta(minutes=minutes)
|
||||
self.update_timestamp()
|
||||
|
||||
def update_activity(self) -> None:
|
||||
"""Update last activity timestamp"""
|
||||
self.last_activity_at = datetime.utcnow()
|
||||
self.activity_count += 1
|
||||
self.update_timestamp()
|
||||
|
||||
def revoke(self) -> None:
|
||||
"""Revoke the session"""
|
||||
self.status = SessionStatus.REVOKED
|
||||
self.update_timestamp()
|
||||
|
||||
def expire(self) -> None:
|
||||
"""Mark session as expired"""
|
||||
self.status = SessionStatus.EXPIRED
|
||||
self.update_timestamp()
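

# --- Illustrative usage (not part of the original file) ---
# Sliding-expiration sketch: each request updates activity and pushes the
# expiry forward; is_expired() also treats revoked/expired statuses as expired.
def _example_touch_session(session: "UserSession") -> bool:
    if session.is_expired():
        return False
    session.update_activity()
    session.extend_session(minutes=30)
    return True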
|
||||
|
||||
|
||||
class UserSessionCreate(BaseCreateModel):
|
||||
"""Model for creating new user sessions"""
|
||||
session_id: str
|
||||
user_id: str
|
||||
user_email: Optional[str] = Field(None, max_length=255)
|
||||
user_name: Optional[str] = Field(None, max_length=100)
|
||||
auth_provider: str = Field(default="jwt", max_length=50)
|
||||
auth_method: str = Field(default="bearer", max_length=50)
|
||||
expires_at: datetime
|
||||
preferences: Dict[str, Any] = Field(default_factory=dict)
|
||||
session_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
login_ip: Optional[str] = Field(None, max_length=45)
|
||||
user_agent: Optional[str] = Field(None, max_length=500)
|
||||
csrf_token: Optional[str] = Field(None, max_length=64)
|
||||
|
||||
|
||||
class UserSessionUpdate(BaseUpdateModel):
|
||||
"""Model for updating user sessions"""
|
||||
user_email: Optional[str] = Field(None, max_length=255)
|
||||
user_name: Optional[str] = Field(None, max_length=100)
|
||||
status: Optional[SessionStatus] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
session_data: Optional[Dict[str, Any]] = None
|
||||
activity_count: Optional[int] = Field(None, ge=0)
|
||||
csrf_token: Optional[str] = Field(None, max_length=64)
|
||||
|
||||
|
||||
class UserSessionResponse(BaseResponseModel):
|
||||
"""Model for user session API responses"""
|
||||
id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
user_email: Optional[str]
|
||||
user_name: Optional[str]
|
||||
auth_provider: str
|
||||
auth_method: str
|
||||
status: SessionStatus
|
||||
expires_at: datetime
|
||||
last_activity_at: datetime
|
||||
preferences: Dict[str, Any]
|
||||
session_data: Dict[str, Any]
|
||||
login_ip: Optional[str]
|
||||
user_agent: Optional[str]
|
||||
activity_count: int
|
||||
csrf_token: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
603 apps/tenant-backend/app/models/workflow.py Normal file
@@ -0,0 +1,603 @@
"""
|
||||
Workflow Models for GT 2.0 Tenant Backend - Service-Based Architecture
|
||||
|
||||
Pydantic models for workflow entities using the PostgreSQL + PGVector backend.
|
||||
Stores workflow definitions, executions, triggers, and chat sessions.
|
||||
Perfect tenant isolation - each tenant has separate workflow data.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import Field, ConfigDict
|
||||
from app.models.base import BaseServiceModel, BaseCreateModel, BaseUpdateModel, BaseResponseModel
|
||||
|
||||
|
||||
class WorkflowStatus(str, Enum):
|
||||
"""Workflow status enumeration"""
|
||||
DRAFT = "draft"
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class TriggerType(str, Enum):
|
||||
"""Trigger type enumeration"""
|
||||
MANUAL = "manual"
|
||||
WEBHOOK = "webhook"
|
||||
CRON = "cron"
|
||||
EVENT = "event"
|
||||
API = "api"
|
||||
|
||||
|
||||
class InteractionMode(str, Enum):
|
||||
"""Interaction mode enumeration"""
|
||||
CHAT = "chat"
|
||||
BUTTON = "button"
|
||||
FORM = "form"
|
||||
DASHBOARD = "dashboard"
|
||||
API = "api"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
"""Execution status enumeration"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class Workflow(BaseServiceModel):
|
||||
"""
|
||||
Workflow model for GT 2.0 service-based architecture.
|
||||
|
||||
Represents an agentic workflow with nodes, triggers, and execution logic.
|
||||
Supports chat interfaces, form inputs, API endpoints, and dashboard views.
|
||||
"""
|
||||
|
||||
# Basic workflow properties
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
user_id: str = Field(..., description="User who owns this workflow")
|
||||
name: str = Field(..., min_length=1, max_length=200, description="Workflow name")
|
||||
description: Optional[str] = Field(None, max_length=1000, description="Workflow description")
|
||||
|
||||
# Workflow definition as JSON structure
|
||||
definition: Dict[str, Any] = Field(..., description="Nodes, edges, and configuration")
|
||||
|
||||
# Triggers and interaction modes
|
||||
triggers: List[Dict[str, Any]] = Field(default_factory=list, description="Webhook, cron, event triggers")
|
||||
interaction_modes: List[InteractionMode] = Field(default_factory=list, description="UI interaction modes")
|
||||
|
||||
# Resource references - ensuring user owns all resources
|
||||
agent_ids: List[str] = Field(default_factory=list, description="Referenced agents")
|
||||
api_key_ids: List[str] = Field(default_factory=list, description="Referenced API keys")
|
||||
webhook_ids: List[str] = Field(default_factory=list, description="Referenced webhooks")
|
||||
dataset_ids: List[str] = Field(default_factory=list, description="Referenced datasets")
|
||||
integration_ids: List[str] = Field(default_factory=list, description="Referenced integrations")
|
||||
|
||||
# Workflow configuration
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="Runtime configuration")
|
||||
timeout_seconds: int = Field(default=300, ge=1, le=3600, description="Execution timeout (5 min default)")
|
||||
max_retries: int = Field(default=3, ge=0, le=10, description="Maximum retry attempts")
|
||||
|
||||
# Status and metadata
|
||||
status: WorkflowStatus = Field(default=WorkflowStatus.DRAFT, description="Workflow status")
|
||||
execution_count: int = Field(default=0, description="Total execution count")
|
||||
last_executed: Optional[datetime] = Field(None, description="Last execution timestamp")
|
||||
|
||||
# Analytics
|
||||
total_tokens_used: int = Field(default=0, description="Total tokens consumed")
|
||||
total_cost_cents: int = Field(default=0, description="Total cost in cents")
|
||||
average_execution_time_ms: Optional[int] = Field(None, description="Average execution time")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "workflows"
|
||||
|
||||
def activate(self) -> None:
|
||||
"""Activate the workflow"""
|
||||
self.status = WorkflowStatus.ACTIVE
|
||||
self.update_timestamp()
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the workflow"""
|
||||
self.status = WorkflowStatus.PAUSED
|
||||
self.update_timestamp()
|
||||
|
||||
def archive(self) -> None:
|
||||
"""Archive the workflow"""
|
||||
self.status = WorkflowStatus.ARCHIVED
|
||||
self.update_timestamp()
|
||||
|
||||
def update_execution_stats(self, tokens_used: int, cost_cents: int, execution_time_ms: int) -> None:
|
||||
"""Update execution statistics"""
|
||||
self.execution_count += 1
|
||||
self.total_tokens_used += tokens_used
|
||||
self.total_cost_cents += cost_cents
|
||||
self.last_executed = datetime.utcnow()
|
||||
|
||||
# Update rolling average execution time
|
||||
if self.average_execution_time_ms is None:
|
||||
self.average_execution_time_ms = execution_time_ms
|
||||
else:
|
||||
# Simple moving average
|
||||
self.average_execution_time_ms = int(
|
||||
(self.average_execution_time_ms * (self.execution_count - 1) + execution_time_ms) / self.execution_count
|
||||
)
|
||||
|
||||
self.update_timestamp()
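
    # Worked example of the moving average above (illustrative numbers): with
    # two prior runs averaging 400 ms, a third run of 700 ms gives
    # execution_count = 3 and average_execution_time_ms = int((400 * 2 + 700) / 3) = 500.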
|
||||
|
||||
|
||||
class WorkflowExecution(BaseServiceModel):
|
||||
"""
|
||||
Workflow execution model for tracking individual workflow runs.
|
||||
|
||||
Stores execution state, progress, timing, and resource usage.
|
||||
"""
|
||||
|
||||
# Core execution properties
|
||||
workflow_id: str = Field(..., description="Parent workflow ID")
|
||||
user_id: str = Field(..., description="User who triggered execution")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Execution state
|
||||
status: ExecutionStatus = Field(default=ExecutionStatus.PENDING, description="Execution status")
|
||||
current_node_id: Optional[str] = Field(None, description="Currently executing node")
|
||||
progress_percentage: int = Field(default=0, ge=0, le=100, description="Execution progress")
|
||||
|
||||
# Data and context
|
||||
input_data: Dict[str, Any] = Field(default_factory=dict, description="Execution input data")
|
||||
output_data: Dict[str, Any] = Field(default_factory=dict, description="Execution output data")
|
||||
execution_trace: List[Dict[str, Any]] = Field(default_factory=list, description="Step-by-step log")
|
||||
error_details: Optional[str] = Field(None, description="Error details if failed")
|
||||
|
||||
# Timing and performance
|
||||
started_at: datetime = Field(default_factory=datetime.utcnow, description="Execution start time")
|
||||
completed_at: Optional[datetime] = Field(None, description="Execution completion time")
|
||||
duration_ms: Optional[int] = Field(None, description="Execution duration in milliseconds")
|
||||
|
||||
# Resource usage
|
||||
tokens_used: int = Field(default=0, description="Tokens consumed")
|
||||
cost_cents: int = Field(default=0, description="Cost in cents")
|
||||
tool_calls_count: int = Field(default=0, description="Number of tool calls made")
|
||||
|
||||
# Trigger information
|
||||
trigger_type: Optional[TriggerType] = Field(None, description="How execution was triggered")
|
||||
trigger_data: Dict[str, Any] = Field(default_factory=dict, description="Trigger-specific data")
|
||||
trigger_source: Optional[str] = Field(None, description="Source identifier for trigger")
|
||||
|
||||
# Session information for chat mode
|
||||
session_id: Optional[str] = Field(None, description="Chat session ID if applicable")
|
||||
interaction_mode: Optional[InteractionMode] = Field(None, description="User interaction mode")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "workflow_executions"
|
||||
|
||||
def mark_running(self, current_node_id: str) -> None:
|
||||
"""Mark execution as running"""
|
||||
self.status = ExecutionStatus.RUNNING
|
||||
self.current_node_id = current_node_id
|
||||
self.update_timestamp()
|
||||
|
||||
def mark_completed(self, output_data: Dict[str, Any]) -> None:
|
||||
"""Mark execution as completed"""
|
||||
self.status = ExecutionStatus.COMPLETED
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.output_data = output_data
|
||||
self.progress_percentage = 100
|
||||
|
||||
if self.started_at:
|
||||
self.duration_ms = int((self.completed_at - self.started_at).total_seconds() * 1000)
|
||||
|
||||
self.update_timestamp()
|
||||
|
||||
def mark_failed(self, error_details: str) -> None:
|
||||
"""Mark execution as failed"""
|
||||
self.status = ExecutionStatus.FAILED
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.error_details = error_details
|
||||
|
||||
if self.started_at:
|
||||
self.duration_ms = int((self.completed_at - self.started_at).total_seconds() * 1000)
|
||||
|
||||
self.update_timestamp()
|
||||
|
||||
def add_trace_entry(self, node_id: str, action: str, data: Dict[str, Any]) -> None:
|
||||
"""Add entry to execution trace"""
|
||||
trace_entry = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"node_id": node_id,
|
||||
"action": action,
|
||||
"data": data
|
||||
}
|
||||
self.execution_trace.append(trace_entry)
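

# --- Illustrative usage (not part of the original file) ---
# Minimal run sketch: a node is marked running, traced, and the execution is
# completed; node ids and payloads are placeholders.
def _example_trace_run(execution: "WorkflowExecution") -> None:
    execution.mark_running(current_node_id="node-1")
    execution.add_trace_entry("node-1", "agent_call", {"tokens": 512})
    execution.mark_completed(output_data={"summary": "done"})
    assert execution.progress_percentage == 100
    assert execution.duration_ms is not None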
|
||||
|
||||
|
||||
class WorkflowTrigger(BaseServiceModel):
|
||||
"""
|
||||
Workflow trigger model for automated workflow execution.
|
||||
|
||||
Supports webhook, cron, event, and API triggers.
|
||||
"""
|
||||
|
||||
# Core trigger properties
|
||||
workflow_id: str = Field(..., description="Parent workflow ID")
|
||||
user_id: str = Field(..., description="User who owns this trigger")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Trigger configuration
|
||||
trigger_type: TriggerType = Field(..., description="Type of trigger")
|
||||
trigger_config: Dict[str, Any] = Field(..., description="Trigger-specific configuration")
|
||||
|
||||
# Webhook-specific fields
|
||||
webhook_url: Optional[str] = Field(None, description="Generated webhook URL")
|
||||
webhook_secret: Optional[str] = Field(None, max_length=128, description="Webhook signature secret")
|
||||
|
||||
# Cron-specific fields
|
||||
cron_schedule: Optional[str] = Field(None, max_length=100, description="Cron expression")
|
||||
timezone: str = Field(default="UTC", max_length=50, description="Timezone for cron schedule")
|
||||
|
||||
# Event-specific fields
|
||||
event_source: Optional[str] = Field(None, max_length=100, description="Event source system")
|
||||
event_filters: Dict[str, Any] = Field(default_factory=dict, description="Event filtering criteria")
|
||||
|
||||
# Status and metadata
|
||||
is_active: bool = Field(default=True, description="Whether trigger is active")
|
||||
trigger_count: int = Field(default=0, description="Number of times triggered")
|
||||
last_triggered: Optional[datetime] = Field(None, description="Last trigger timestamp")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "workflow_triggers"
|
||||
|
||||
def activate(self) -> None:
|
||||
"""Activate the trigger"""
|
||||
self.is_active = True
|
||||
self.update_timestamp()
|
||||
|
||||
def deactivate(self) -> None:
|
||||
"""Deactivate the trigger"""
|
||||
self.is_active = False
|
||||
self.update_timestamp()
|
||||
|
||||
def record_trigger(self) -> None:
|
||||
"""Record a trigger event"""
|
||||
self.trigger_count += 1
|
||||
self.last_triggered = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
class WorkflowSession(BaseServiceModel):
|
||||
"""
|
||||
Workflow session model for chat-based workflow interactions.
|
||||
|
||||
Manages conversational state for workflow chat interfaces.
|
||||
"""
|
||||
|
||||
# Core session properties
|
||||
workflow_id: str = Field(..., description="Parent workflow ID")
|
||||
user_id: str = Field(..., description="User participating in session")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Session configuration
|
||||
session_type: str = Field(default="chat", max_length=50, description="Session type")
|
||||
session_state: Dict[str, Any] = Field(default_factory=dict, description="Current conversation state")
|
||||
|
||||
# Chat history
|
||||
message_count: int = Field(default=0, description="Number of messages in session")
|
||||
last_message_at: Optional[datetime] = Field(None, description="Last message timestamp")
|
||||
|
||||
# Status
|
||||
is_active: bool = Field(default=True, description="Whether session is active")
|
||||
expires_at: Optional[datetime] = Field(None, description="Session expiration time")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "workflow_sessions"
|
||||
|
||||
def add_message(self) -> None:
|
||||
"""Record a new message in the session"""
|
||||
self.message_count += 1
|
||||
self.last_message_at = datetime.utcnow()
|
||||
self.update_timestamp()
|
||||
|
||||
def close_session(self) -> None:
|
||||
"""Close the session"""
|
||||
self.is_active = False
|
||||
self.update_timestamp()
|
||||
|
||||
|
||||
class WorkflowMessage(BaseServiceModel):
|
||||
"""
|
||||
Workflow message model for chat session messages.
|
||||
|
||||
Stores individual messages within workflow chat sessions.
|
||||
"""
|
||||
|
||||
# Core message properties
|
||||
session_id: str = Field(..., description="Parent session ID")
|
||||
workflow_id: str = Field(..., description="Parent workflow ID")
|
||||
execution_id: Optional[str] = Field(None, description="Associated execution ID")
|
||||
user_id: str = Field(..., description="User who sent/received message")
|
||||
tenant_id: str = Field(..., description="Tenant domain identifier")
|
||||
|
||||
# Message content
|
||||
role: str = Field(..., max_length=20, description="Message role (user, agent, system)")
|
||||
content: str = Field(..., description="Message content")
|
||||
message_type: str = Field(default="text", max_length=50, description="Message type")
|
||||
|
||||
# Agent information for agent messages
|
||||
agent_id: Optional[str] = Field(None, description="Agent that generated this message")
|
||||
confidence_score: Optional[int] = Field(None, ge=0, le=100, description="Agent confidence (0-100)")
|
||||
|
||||
# Additional data
|
||||
message_metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional message data")
|
||||
tokens_used: int = Field(default=0, description="Tokens consumed for this message")
|
||||
|
||||
# Model configuration
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_table_name(cls) -> str:
|
||||
"""Get the database table name"""
|
||||
return "workflow_messages"
|
||||
|
||||
|
||||
# Create/Update/Response models for each entity
|
||||
|
||||
class WorkflowCreate(BaseCreateModel):
|
||||
"""Model for creating new workflows"""
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
definition: Dict[str, Any]
|
||||
triggers: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
interaction_modes: List[InteractionMode] = Field(default_factory=list)
|
||||
agent_ids: List[str] = Field(default_factory=list)
|
||||
api_key_ids: List[str] = Field(default_factory=list)
|
||||
webhook_ids: List[str] = Field(default_factory=list)
|
||||
dataset_ids: List[str] = Field(default_factory=list)
|
||||
integration_ids: List[str] = Field(default_factory=list)
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
timeout_seconds: int = Field(default=300, ge=1, le=3600)
|
||||
max_retries: int = Field(default=3, ge=0, le=10)
|
||||
|
||||
|
||||
class WorkflowUpdate(BaseUpdateModel):
|
||||
"""Model for updating workflows"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
definition: Optional[Dict[str, Any]] = None
|
||||
triggers: Optional[List[Dict[str, Any]]] = None
|
||||
interaction_modes: Optional[List[InteractionMode]] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
timeout_seconds: Optional[int] = Field(None, ge=1, le=3600)
|
||||
max_retries: Optional[int] = Field(None, ge=0, le=10)
|
||||
status: Optional[WorkflowStatus] = None
|
||||
|
||||
|
||||
class WorkflowResponse(BaseResponseModel):
|
||||
"""Model for workflow API responses"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
definition: Dict[str, Any]
|
||||
triggers: List[Dict[str, Any]]
|
||||
interaction_modes: List[InteractionMode]
|
||||
agent_ids: List[str]
|
||||
api_key_ids: List[str]
|
||||
webhook_ids: List[str]
|
||||
dataset_ids: List[str]
|
||||
integration_ids: List[str]
|
||||
config: Dict[str, Any]
|
||||
timeout_seconds: int
|
||||
max_retries: int
|
||||
status: WorkflowStatus
|
||||
execution_count: int
|
||||
last_executed: Optional[datetime]
|
||||
total_tokens_used: int
|
||||
total_cost_cents: int
|
||||
average_execution_time_ms: Optional[int]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class WorkflowExecutionCreate(BaseCreateModel):
|
||||
"""Model for creating new workflow executions"""
|
||||
workflow_id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
input_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
trigger_type: Optional[TriggerType] = None
|
||||
trigger_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
trigger_source: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
interaction_mode: Optional[InteractionMode] = None
|
||||
|
||||
|
||||
class WorkflowExecutionUpdate(BaseUpdateModel):
|
||||
"""Model for updating workflow executions"""
|
||||
status: Optional[ExecutionStatus] = None
|
||||
current_node_id: Optional[str] = None
|
||||
progress_percentage: Optional[int] = Field(None, ge=0, le=100)
|
||||
output_data: Optional[Dict[str, Any]] = None
|
||||
error_details: Optional[str] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
tokens_used: Optional[int] = Field(None, ge=0)
|
||||
cost_cents: Optional[int] = Field(None, ge=0)
|
||||
tool_calls_count: Optional[int] = Field(None, ge=0)
|
||||
|
||||
|
||||
class WorkflowExecutionResponse(BaseResponseModel):
|
||||
"""Model for workflow execution API responses"""
|
||||
id: str
|
||||
workflow_id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
status: ExecutionStatus
|
||||
current_node_id: Optional[str]
|
||||
progress_percentage: int
|
||||
input_data: Dict[str, Any]
|
||||
output_data: Dict[str, Any]
|
||||
execution_trace: List[Dict[str, Any]]
|
||||
error_details: Optional[str]
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime]
|
||||
duration_ms: Optional[int]
|
||||
tokens_used: int
|
||||
cost_cents: int
|
||||
tool_calls_count: int
|
||||
trigger_type: Optional[TriggerType]
|
||||
trigger_data: Dict[str, Any]
|
||||
trigger_source: Optional[str]
|
||||
session_id: Optional[str]
|
||||
interaction_mode: Optional[InteractionMode]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# Node type definitions for workflow canvas
|
||||
WORKFLOW_NODE_TYPES = {
|
||||
"agent": {
|
||||
"name": "Agent",
|
||||
"description": "Execute an AI Agent with personality",
|
||||
"inputs": ["text", "context"],
|
||||
"outputs": ["response", "confidence"],
|
||||
"config_schema": {
|
||||
"agent_id": {"type": "string", "required": True},
|
||||
"confidence_threshold": {"type": "integer", "default": 70},
|
||||
"max_tokens": {"type": "integer", "default": 2000},
|
||||
"temperature": {"type": "number", "default": 0.7}
|
||||
}
|
||||
},
|
||||
"trigger": {
|
||||
"name": "Trigger",
|
||||
"description": "Start workflow execution",
|
||||
"inputs": [],
|
||||
"outputs": ["trigger_data"],
|
||||
"subtypes": ["webhook", "cron", "event", "manual", "api"],
|
||||
"config_schema": {
|
||||
"trigger_type": {"type": "string", "required": True}
|
||||
}
|
||||
},
|
||||
"integration": {
|
||||
"name": "Integration",
|
||||
"description": "Connect to external services",
|
||||
"inputs": ["data"],
|
||||
"outputs": ["response"],
|
||||
"subtypes": ["api", "database", "storage", "webhook"],
|
||||
"config_schema": {
|
||||
"integration_type": {"type": "string", "required": True},
|
||||
"api_key_id": {"type": "string"},
|
||||
"endpoint_url": {"type": "string"},
|
||||
"method": {"type": "string", "default": "GET"}
|
||||
}
|
||||
},
|
||||
"logic": {
|
||||
"name": "Logic",
|
||||
"description": "Control flow and data transformation",
|
||||
"inputs": ["data"],
|
||||
"outputs": ["result"],
|
||||
"subtypes": ["decision", "loop", "transform", "aggregate", "filter"],
|
||||
"config_schema": {
|
||||
"logic_type": {"type": "string", "required": True}
|
||||
}
|
||||
},
|
||||
"output": {
|
||||
"name": "Output",
|
||||
"description": "Send results to external systems",
|
||||
"inputs": ["data"],
|
||||
"outputs": [],
|
||||
"subtypes": ["webhook", "api", "email", "storage", "notification"],
|
||||
"config_schema": {
|
||||
"output_type": {"type": "string", "required": True}
|
||||
}
|
||||
}
|
||||
}
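

# --- Illustrative helper (not part of the original file) ---
# Sketch of checking a node's config against the "required" keys declared in
# the config_schema entries above; it only validates key presence, not types.
# e.g. _example_missing_required_keys("agent", {}) -> ["agent_id"]
def _example_missing_required_keys(node_type: str, config: Dict[str, Any]) -> List[str]:
    schema = WORKFLOW_NODE_TYPES.get(node_type, {}).get("config_schema", {})
    return [key for key, spec in schema.items() if spec.get("required") and key not in config]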
|
||||
|
||||
|
||||
# Interaction mode configurations
|
||||
INTERACTION_MODE_CONFIGS = {
|
||||
"chat": {
|
||||
"name": "Chat Interface",
|
||||
"description": "Conversational interaction with workflow",
|
||||
"supports_streaming": True,
|
||||
"supports_history": True,
|
||||
"ui_components": ["chat_input", "message_history", "agent_avatars"]
|
||||
},
|
||||
"button": {
|
||||
"name": "Button Trigger",
|
||||
"description": "Simple one-click workflow execution",
|
||||
"supports_streaming": False,
|
||||
"supports_history": False,
|
||||
"ui_components": ["trigger_button", "progress_indicator", "result_display"]
|
||||
},
|
||||
"form": {
|
||||
"name": "Form Input",
|
||||
"description": "Structured input with validation",
|
||||
"supports_streaming": False,
|
||||
"supports_history": True,
|
||||
"ui_components": ["dynamic_form", "validation", "submit_button"]
|
||||
},
|
||||
"dashboard": {
|
||||
"name": "Dashboard View",
|
||||
"description": "Overview of workflow status and metrics",
|
||||
"supports_streaming": True,
|
||||
"supports_history": True,
|
||||
"ui_components": ["metrics_cards", "execution_history", "status_indicators"]
|
||||
},
|
||||
"api": {
|
||||
"name": "API Endpoint",
|
||||
"description": "Programmatic access to workflow",
|
||||
"supports_streaming": True,
|
||||
"supports_history": False,
|
||||
"ui_components": []
|
||||
}
|
||||
}
|
||||
154 apps/tenant-backend/app/schemas/agent.py Normal file
@@ -0,0 +1,154 @@
"""
|
||||
Agent schemas for GT 2.0 Tenant Backend
|
||||
|
||||
Pydantic models for agent-related API request/response validation.
|
||||
Implements comprehensive agent management per CLAUDE.md specifications.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class AgentTemplate(BaseModel):
|
||||
"""Agent template information"""
|
||||
id: str = Field(..., description="Template identifier")
|
||||
name: str = Field(..., description="Template display name")
|
||||
description: str = Field(..., description="Template description")
|
||||
icon: str = Field(..., description="Template icon emoji or URL")
|
||||
category: str = Field(..., description="Template category")
|
||||
prompt: str = Field(..., description="System prompt template")
|
||||
default_capabilities: List[str] = Field(default_factory=list, description="Default capability grants")
|
||||
personality_config: Dict[str, Any] = Field(default_factory=dict, description="Personality configuration")
|
||||
resource_preferences: Dict[str, Any] = Field(default_factory=dict, description="Resource preferences")
|
||||
|
||||
|
||||
class AgentTemplateListResponse(BaseModel):
|
||||
"""Response for listing agent templates"""
|
||||
templates: List[AgentTemplate]
|
||||
categories: List[str] = Field(default_factory=list, description="Available categories")
|
||||
total: int
|
||||
|
||||
|
||||
class AgentCreate(BaseModel):
|
||||
"""Request to create a new agent"""
|
||||
name: str = Field(..., description="Agent name")
|
||||
description: Optional[str] = Field(None, description="Agent description")
|
||||
template_id: Optional[str] = Field(None, description="Template ID to use")
|
||||
category: Optional[str] = Field(None, description="Agent category")
|
||||
prompt_template: Optional[str] = Field(None, description="System prompt template")
|
||||
model: Optional[str] = Field(None, description="AI model identifier")
|
||||
model_id: Optional[str] = Field(None, description="AI model identifier (alias for model)")
|
||||
temperature: Optional[float] = Field(None, description="Model temperature parameter")
|
||||
# max_tokens removed - now determined by model configuration
|
||||
visibility: Optional[str] = Field(None, description="Agent visibility setting")
|
||||
dataset_connection: Optional[str] = Field(None, description="RAG dataset connection type")
|
||||
selected_dataset_ids: Optional[List[str]] = Field(None, description="Selected dataset IDs for RAG")
|
||||
personality_config: Optional[Dict[str, Any]] = Field(None, description="Personality configuration")
|
||||
resource_preferences: Optional[Dict[str, Any]] = Field(None, description="Resource preferences")
|
||||
tags: Optional[List[str]] = Field(None, description="Agent tags")
|
||||
disclaimer: Optional[str] = Field(None, max_length=500, description="Disclaimer text shown in chat")
|
||||
easy_prompts: Optional[List[str]] = Field(None, description="Quick-access preset prompts (max 10)")
|
||||
team_shares: Optional[List[Dict[str, Any]]] = Field(None, description="Team sharing configuration with per-user permissions")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class AgentUpdate(BaseModel):
|
||||
"""Request to update an agent"""
|
||||
name: Optional[str] = Field(None, description="New agent name")
|
||||
description: Optional[str] = Field(None, description="New agent description")
|
||||
category: Optional[str] = Field(None, description="Agent category")
|
||||
prompt_template: Optional[str] = Field(None, description="System prompt template")
|
||||
model: Optional[str] = Field(None, description="AI model identifier")
|
||||
temperature: Optional[float] = Field(None, description="Model temperature parameter")
|
||||
# max_tokens removed - now determined by model configuration
|
||||
visibility: Optional[str] = Field(None, description="Agent visibility setting")
|
||||
dataset_connection: Optional[str] = Field(None, description="RAG dataset connection type")
|
||||
selected_dataset_ids: Optional[List[str]] = Field(None, description="Selected dataset IDs for RAG")
|
||||
personality_config: Optional[Dict[str, Any]] = Field(None, description="Updated personality config")
|
||||
resource_preferences: Optional[Dict[str, Any]] = Field(None, description="Updated resource preferences")
|
||||
tags: Optional[List[str]] = Field(None, description="Updated tags")
|
||||
is_favorite: Optional[bool] = Field(None, description="Favorite status")
|
||||
disclaimer: Optional[str] = Field(None, max_length=500, description="Disclaimer text shown in chat")
|
||||
easy_prompts: Optional[List[str]] = Field(None, description="Quick-access preset prompts (max 10)")
|
||||
team_shares: Optional[List[Dict[str, Any]]] = Field(None, description="Update team sharing configuration")
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Response for agent operations"""
|
||||
id: str = Field(..., description="Agent UUID")
|
||||
name: str = Field(..., description="Agent name")
|
||||
description: Optional[str] = Field(None, description="Agent description")
|
||||
template_id: Optional[str] = Field(None, description="Template ID if created from template")
|
||||
category: Optional[str] = Field(None, description="Agent category")
|
||||
prompt_template: Optional[str] = Field(None, description="System prompt template")
|
||||
model: Optional[str] = Field(None, description="AI model identifier")
|
||||
temperature: Optional[float] = Field(None, description="Model temperature parameter")
|
||||
max_tokens: Optional[int] = Field(None, description="Maximum tokens for generation")
|
||||
visibility: Optional[str] = Field(None, description="Agent visibility setting")
|
||||
dataset_connection: Optional[str] = Field(None, description="RAG dataset connection type")
|
||||
selected_dataset_ids: Optional[List[str]] = Field(None, description="Selected dataset IDs for RAG")
|
||||
personality_config: Dict[str, Any] = Field(default_factory=dict, description="Personality configuration")
|
||||
resource_preferences: Dict[str, Any] = Field(default_factory=dict, description="Resource preferences")
|
||||
tags: List[str] = Field(default_factory=list, description="Agent tags")
|
||||
is_favorite: bool = Field(False, description="Favorite status")
|
||||
disclaimer: Optional[str] = Field(None, description="Disclaimer text shown in chat")
|
||||
easy_prompts: List[str] = Field(default_factory=list, description="Quick-access preset prompts")
|
||||
conversation_count: int = Field(0, description="Number of conversations")
|
||||
usage_count: int = Field(0, description="Number of conversations (alias for frontend compatibility)")
|
||||
total_cost_cents: int = Field(0, description="Total cost in cents")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
# Creator information
|
||||
created_by_name: Optional[str] = Field(None, description="Full name of the user who created this agent")
|
||||
# Permission flags for frontend
|
||||
can_edit: bool = Field(False, description="Whether current user can edit this agent")
|
||||
can_delete: bool = Field(False, description="Whether current user can delete this agent")
|
||||
is_owner: bool = Field(False, description="Whether current user owns this agent")
|
||||
# Team sharing configuration
|
||||
team_shares: Optional[List[Dict[str, Any]]] = Field(None, description="Team sharing configuration with per-user permissions")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentListResponse(BaseModel):
|
||||
"""Response for listing agents"""
|
||||
data: List[AgentResponse] = Field(..., description="List of agents")
|
||||
total: int = Field(..., description="Total number of agents")
|
||||
limit: int = Field(..., description="Query limit")
|
||||
offset: int = Field(..., description="Query offset")
|
||||
|
||||
|
||||
class AgentCapabilities(BaseModel):
|
||||
"""Agent capabilities and resource access"""
|
||||
agent_id: str = Field(..., description="Agent UUID")
|
||||
capabilities: List[Dict[str, Any]] = Field(default_factory=list, description="Granted capabilities")
|
||||
resource_preferences: Dict[str, Any] = Field(default_factory=dict, description="Resource preferences")
|
||||
allowed_tools: List[str] = Field(default_factory=list, description="Allowed tool integrations")
|
||||
total: int = Field(..., description="Total capability count")
|
||||
|
||||
|
||||
class AgentStatistics(BaseModel):
|
||||
"""Agent usage statistics"""
|
||||
agent_id: str = Field(..., description="Agent UUID")
|
||||
name: str = Field(..., description="Agent name")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
last_used_at: Optional[datetime] = Field(None, description="Last usage timestamp")
|
||||
conversation_count: int = Field(0, description="Total conversations")
|
||||
total_messages: int = Field(0, description="Total messages processed")
|
||||
total_tokens_used: int = Field(0, description="Total tokens consumed")
|
||||
total_cost_cents: int = Field(0, description="Total cost in cents")
|
||||
total_cost_dollars: float = Field(0.0, description="Total cost in dollars")
|
||||
average_tokens_per_message: float = Field(0.0, description="Average tokens per message")
|
||||
is_favorite: bool = Field(False, description="Favorite status")
|
||||
tags: List[str] = Field(default_factory=list, description="Agent tags")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentCloneRequest(BaseModel):
|
||||
"""Request to clone an agent"""
|
||||
new_name: str = Field(..., description="Name for the cloned agent")
|
||||
modifications: Optional[Dict[str, Any]] = Field(None, description="Modifications to apply")
|
||||
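A minimal usage sketch for the agent schemas above (illustrative only; the import path follows this commit's package layout, and the field values and model identifier are placeholders, not project defaults).

from app.schemas.agent import AgentCreate, AgentUpdate

# Build a creation payload; unset optional fields are omitted on dump.
create_req = AgentCreate(
    name="Research Assistant",
    description="Summarizes uploaded documents",
    model="kimi-k2",  # hypothetical model identifier
    temperature=0.2,
    tags=["research", "rag"],
)
print(create_req.model_dump(exclude_none=True))

# Partial update: only the fields actually provided are sent downstream.
update_req = AgentUpdate(is_favorite=True, tags=["research", "rag", "pinned"])
print(update_req.model_dump(exclude_none=True))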
71
apps/tenant-backend/app/schemas/category.py
Normal file
@@ -0,0 +1,71 @@
"""
Category schemas for GT 2.0 Tenant Backend

Pydantic models for agent category API request/response validation.
Supports tenant-scoped editable/deletable categories per Issue #215.
"""

from pydantic import BaseModel, Field, field_validator
from typing import List, Optional
from datetime import datetime
import re


class CategoryCreate(BaseModel):
    """Request to create a new category"""
    name: str = Field(..., min_length=1, max_length=100, description="Category display name")
    description: Optional[str] = Field(None, max_length=500, description="Category description")
    icon: Optional[str] = Field(None, max_length=10, description="Category icon (emoji)")

    @field_validator('name')
    @classmethod
    def validate_name(cls, v: str) -> str:
        v = v.strip()
        if not v:
            raise ValueError('Category name cannot be empty')
        # Check for invalid characters (allow alphanumeric, spaces, hyphens, underscores)
        if not re.match(r'^[\w\s\-]+$', v):
            raise ValueError('Category name can only contain letters, numbers, spaces, hyphens, and underscores')
        return v


class CategoryUpdate(BaseModel):
    """Request to update a category"""
    name: Optional[str] = Field(None, min_length=1, max_length=100, description="New category name")
    description: Optional[str] = Field(None, max_length=500, description="New category description")
    icon: Optional[str] = Field(None, max_length=10, description="New category icon")

    @field_validator('name')
    @classmethod
    def validate_name(cls, v: Optional[str]) -> Optional[str]:
        if v is None:
            return v
        v = v.strip()
        if not v:
            raise ValueError('Category name cannot be empty')
        if not re.match(r'^[\w\s\-]+$', v):
            raise ValueError('Category name can only contain letters, numbers, spaces, hyphens, and underscores')
        return v


class CategoryResponse(BaseModel):
    """Response for category operations"""
    id: str = Field(..., description="Category UUID")
    name: str = Field(..., description="Category display name")
    slug: str = Field(..., description="URL-safe category identifier")
    description: Optional[str] = Field(None, description="Category description")
    icon: Optional[str] = Field(None, description="Category icon (emoji)")
    is_default: bool = Field(..., description="Whether this is a system default category")
    created_by: Optional[str] = Field(None, description="UUID of user who created the category")
    created_by_name: Optional[str] = Field(None, description="Name of user who created the category")
    can_edit: bool = Field(..., description="Whether current user can edit this category")
    can_delete: bool = Field(..., description="Whether current user can delete this category")
    sort_order: int = Field(..., description="Display sort order")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")


class CategoryListResponse(BaseModel):
    """Response for listing categories"""
    categories: List[CategoryResponse] = Field(default_factory=list, description="List of categories")
    total: int = Field(..., description="Total number of categories")
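A short sketch of the name validation behaviour above (illustrative, standalone): the validator strips surrounding whitespace and rejects any character outside letters, digits, spaces, hyphens, and underscores.

from pydantic import ValidationError
from app.schemas.category import CategoryCreate

ok = CategoryCreate(name="  Data Engineering  ", icon="📊")
print(ok.name)  # "Data Engineering" (whitespace stripped)

try:
    CategoryCreate(name="Ops/Infra")  # "/" is not an allowed character
except ValidationError as exc:
    print(exc.errors()[0]["msg"])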
81
apps/tenant-backend/app/schemas/conversation.py
Normal file
@@ -0,0 +1,81 @@
"""
Conversation schemas for GT 2.0 Tenant Backend

Pydantic models for conversation-related API request/response validation.
"""

from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict, Any
from datetime import datetime


class ConversationCreate(BaseModel):
    """Request to create a new conversation"""
    agent_id: str = Field(..., description="Agent UUID to chat with")
    title: Optional[str] = Field(None, description="Conversation title")
    initial_message: Optional[str] = Field(None, description="First message to send")


class ConversationUpdate(BaseModel):
    """Request to update a conversation"""
    title: Optional[str] = Field(None, description="New conversation title")
    system_prompt: Optional[str] = Field(None, description="Updated system prompt")


class MessageCreate(BaseModel):
    """Request to send a message"""
    content: str = Field(..., description="Message content")
    context_sources: Optional[List[str]] = Field(None, description="Context source IDs")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional message metadata")


class MessageResponse(BaseModel):
    """Message response"""
    id: Optional[str] = Field(None, description="Message ID")
    message_id: Optional[str] = Field(None, description="Message ID (alternative)")
    content: Optional[str] = Field(None, description="Message content")
    role: Optional[str] = Field(None, description="Message role")
    tokens_used: Optional[int] = Field(None, description="Tokens consumed")
    model_used: Optional[str] = Field(None, description="Model used for generation")
    context_sources: Optional[List[str]] = Field(None, description="RAG context source documents")
    created_at: Optional[datetime] = Field(None, description="Creation timestamp")
    stream: Optional[bool] = Field(None, description="Whether response is streamed")
    stream_endpoint: Optional[str] = Field(None, description="Stream endpoint URL")

    model_config = ConfigDict(from_attributes=True, protected_namespaces=())


class MessageListResponse(BaseModel):
    """Response for listing messages"""
    messages: List[MessageResponse]
    conversation_id: int
    total: int


class ConversationResponse(BaseModel):
    """Conversation response"""
    id: int = Field(..., description="Conversation ID")
    title: str = Field(..., description="Conversation title")
    agent_id: str = Field(..., description="Agent ID")
    model_id: str = Field(..., description="Model identifier")
    system_prompt: Optional[str] = Field(None, description="System prompt")
    message_count: int = Field(0, description="Total message count")
    total_tokens: int = Field(0, description="Total tokens used")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
    messages: Optional[List[MessageResponse]] = Field(None, description="Conversation messages")

    model_config = ConfigDict(from_attributes=True, protected_namespaces=())


class ConversationWithUnread(ConversationResponse):
    """Conversation response with unread message count"""
    unread_count: int = Field(0, description="Number of unread messages")


class ConversationListResponse(BaseModel):
    """Response for listing conversations"""
    conversations: List[ConversationResponse]
    total: int
    limit: int
    offset: int
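A minimal request-building sketch for the conversation schemas above (illustrative only; the agent UUID is a placeholder, not a real resource).

from app.schemas.conversation import ConversationCreate, MessageCreate

conv = ConversationCreate(
    agent_id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
    title="Quarterly report review",
    initial_message="Summarize the attached report.",
)
msg = MessageCreate(content="Focus on revenue trends.", metadata={"source": "ui"})

print(conv.model_dump(exclude_none=True))
print(msg.model_dump(exclude_none=True))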
160
apps/tenant-backend/app/schemas/document.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Document Pydantic schemas for GT 2.0 Tenant Backend
|
||||
|
||||
Defines request/response schemas for document and RAG operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class DocumentResponse(BaseModel):
|
||||
"""Document response schema"""
|
||||
id: int
|
||||
uuid: str
|
||||
filename: str
|
||||
original_filename: str
|
||||
file_type: str
|
||||
file_extension: str
|
||||
file_size_bytes: int
|
||||
processing_status: str
|
||||
chunk_count: int
|
||||
content_summary: Optional[str] = None
|
||||
detected_language: Optional[str] = None
|
||||
content_type: Optional[str] = None
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
uploaded_by: str
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
category: Optional[str] = None
|
||||
access_count: int = 0
|
||||
is_active: bool = True
|
||||
is_searchable: bool = True
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
processed_at: Optional[datetime] = None
|
||||
last_accessed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class RAGDatasetCreate(BaseModel):
|
||||
"""Schema for creating a RAG dataset"""
|
||||
dataset_name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
chunking_strategy: str = Field(default="hybrid", pattern="^(fixed|semantic|hierarchical|hybrid)$")
|
||||
chunk_size: int = Field(default=512, ge=128, le=2048)
|
||||
chunk_overlap: int = Field(default=128, ge=0, le=512)
|
||||
embedding_model: str = Field(default="BAAI/bge-m3")
|
||||
|
||||
@validator('chunk_overlap')
|
||||
def validate_chunk_overlap(cls, v, values):
|
||||
if 'chunk_size' in values and v >= values['chunk_size']:
|
||||
raise ValueError('chunk_overlap must be less than chunk_size')
|
||||
return v
|
||||
|
||||
|
||||
class RAGDatasetResponse(BaseModel):
|
||||
"""RAG dataset response schema"""
|
||||
id: str
|
||||
user_id: str
|
||||
dataset_name: str
|
||||
description: Optional[str] = None
|
||||
chunking_strategy: str
|
||||
embedding_model: str
|
||||
chunk_size: int
|
||||
chunk_overlap: int
|
||||
document_count: int = 0
|
||||
chunk_count: int = 0
|
||||
vector_count: int = 0
|
||||
total_size_bytes: int = 0
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class DocumentChunkResponse(BaseModel):
|
||||
"""Document chunk response schema"""
|
||||
id: str
|
||||
chunk_index: int
|
||||
chunk_metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
embedding_id: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""Document search request schema"""
|
||||
query: str = Field(..., min_length=1, max_length=1000)
|
||||
dataset_ids: Optional[List[str]] = None
|
||||
top_k: int = Field(default=5, ge=1, le=20)
|
||||
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""Document search result schema"""
|
||||
document_id: Optional[int] = None
|
||||
dataset_id: Optional[str] = None
|
||||
dataset_name: Optional[str] = None
|
||||
text: str
|
||||
similarity: float
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
filename: Optional[str] = None
|
||||
chunk_index: Optional[int] = None
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
"""Document search response schema"""
|
||||
query: str
|
||||
results: List[SearchResult]
|
||||
total_results: int
|
||||
search_time_ms: Optional[float] = None
|
||||
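A small standalone sketch of building a search request against the schemas above (illustrative; the dataset IDs are placeholders, and top_k / similarity_threshold stay inside the Field bounds declared on SearchRequest).

from app.schemas.document import SearchRequest

req = SearchRequest(
    query="capability token expiry",
    dataset_ids=["ds-finance", "ds-security"],  # placeholder dataset IDs
    top_k=5,
    similarity_threshold=0.75,
)
print(req.model_dump())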
|
||||
|
||||
class DocumentContextResponse(BaseModel):
|
||||
"""Document context response schema"""
|
||||
document_id: int
|
||||
document_name: str
|
||||
query: str
|
||||
relevant_chunks: List[SearchResult]
|
||||
context_text: str
|
||||
|
||||
|
||||
class RAGStatistics(BaseModel):
|
||||
"""RAG usage statistics schema"""
|
||||
user_id: str
|
||||
document_count: int
|
||||
dataset_count: int
|
||||
total_size_bytes: int
|
||||
total_size_mb: float
|
||||
total_chunks: int
|
||||
processed_documents: int
|
||||
pending_documents: int
|
||||
failed_documents: int
|
||||
|
||||
|
||||
class ProcessDocumentRequest(BaseModel):
|
||||
"""Document processing request schema"""
|
||||
chunking_strategy: Optional[str] = Field(default="hybrid", pattern="^(fixed|semantic|hierarchical|hybrid)$")
|
||||
|
||||
|
||||
class ProcessDocumentResponse(BaseModel):
|
||||
"""Document processing response schema"""
|
||||
status: str
|
||||
document_id: int
|
||||
chunk_count: int
|
||||
vector_store_ids: List[str]
|
||||
processing_time_ms: Optional[float] = None
|
||||
|
||||
|
||||
class UploadDocumentResponse(BaseModel):
|
||||
"""Document upload response schema"""
|
||||
document: DocumentResponse
|
||||
processing_initiated: bool = False
|
||||
message: str = "Document uploaded successfully"
|
||||
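A short sketch of the RAG dataset constraints above (illustrative, standalone): chunk_overlap must stay strictly below chunk_size, which the validator on RAGDatasetCreate enforces.

from pydantic import ValidationError
from app.schemas.document import RAGDatasetCreate

ds = RAGDatasetCreate(dataset_name="support-kb", chunk_size=512, chunk_overlap=128)
print(ds.chunking_strategy)  # "hybrid" by default

try:
    RAGDatasetCreate(dataset_name="bad", chunk_size=256, chunk_overlap=256)
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["msg"])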
269
apps/tenant-backend/app/schemas/event.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Event Pydantic schemas for GT 2.0 Tenant Backend
|
||||
|
||||
Defines request/response schemas for event automation operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class EventActionCreate(BaseModel):
|
||||
"""Schema for creating an event action"""
|
||||
action_type: str = Field(..., description="Type of action to execute")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="Action configuration")
|
||||
delay_seconds: int = Field(default=0, ge=0, le=3600, description="Delay before execution")
|
||||
retry_count: int = Field(default=3, ge=0, le=10, description="Number of retries on failure")
|
||||
retry_delay: int = Field(default=60, ge=1, le=3600, description="Delay between retries")
|
||||
condition: Optional[str] = Field(None, max_length=1000, description="Python expression for conditional execution")
|
||||
execution_order: int = Field(default=0, ge=0, description="Order of execution within subscription")
|
||||
|
||||
@validator('action_type')
|
||||
def validate_action_type(cls, v):
|
||||
valid_types = [
|
||||
'process_document', 'send_notification', 'update_statistics',
|
||||
'trigger_rag_indexing', 'log_analytics', 'execute_webhook',
|
||||
'create_assistant', 'schedule_task'
|
||||
]
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'action_type must be one of: {", ".join(valid_types)}')
|
||||
return v
|
||||
|
||||
|
||||
class EventActionResponse(BaseModel):
|
||||
"""Event action response schema"""
|
||||
id: str
|
||||
action_type: str
|
||||
subscription_id: str
|
||||
config: Dict[str, Any]
|
||||
condition: Optional[str] = None
|
||||
delay_seconds: int
|
||||
retry_count: int
|
||||
retry_delay: int
|
||||
execution_order: int
|
||||
is_active: bool
|
||||
execution_count: int
|
||||
success_count: int
|
||||
failure_count: int
|
||||
last_executed_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EventSubscriptionCreate(BaseModel):
|
||||
"""Schema for creating an event subscription"""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
event_type: str = Field(..., description="Type of event to subscribe to")
|
||||
actions: List[EventActionCreate] = Field(..., min_items=1, description="Actions to execute")
|
||||
filter_conditions: Dict[str, Any] = Field(default_factory=dict, description="Conditions for subscription activation")
|
||||
|
||||
@validator('event_type')
|
||||
def validate_event_type(cls, v):
|
||||
valid_types = [
|
||||
'document.uploaded', 'document.processed', 'document.failed',
|
||||
'conversation.started', 'message.sent', 'agent.created',
|
||||
'rag.search_performed', 'user.login', 'user.activity',
|
||||
'system.health_check'
|
||||
]
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'event_type must be one of: {", ".join(valid_types)}')
|
||||
return v
|
||||
|
||||
|
||||
class EventSubscriptionResponse(BaseModel):
|
||||
"""Event subscription response schema"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
event_type: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
trigger_id: Optional[str] = None
|
||||
filter_conditions: Dict[str, Any]
|
||||
is_active: bool
|
||||
trigger_count: int
|
||||
last_triggered_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
actions: List[EventActionResponse] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EventResponse(BaseModel):
|
||||
"""Event response schema"""
|
||||
id: int
|
||||
event_id: str
|
||||
event_type: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
payload: Dict[str, Any]
|
||||
metadata: Dict[str, Any]
|
||||
status: str
|
||||
error_message: Optional[str] = None
|
||||
retry_count: int
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EventStatistics(BaseModel):
|
||||
"""Event statistics response schema"""
|
||||
total_events: int
|
||||
events_by_type: Dict[str, int]
|
||||
events_by_status: Dict[str, int]
|
||||
average_events_per_day: float
|
||||
|
||||
|
||||
class EventTriggerCreate(BaseModel):
|
||||
"""Schema for creating an event trigger"""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
trigger_type: str = Field(..., description="Type of trigger")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="Trigger configuration")
|
||||
conditions: Dict[str, Any] = Field(default_factory=dict, description="Trigger conditions")
|
||||
|
||||
@validator('trigger_type')
|
||||
def validate_trigger_type(cls, v):
|
||||
valid_types = [
|
||||
'schedule', 'webhook', 'file_watch', 'database_change',
|
||||
'api_call', 'user_action', 'system_event'
|
||||
]
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'trigger_type must be one of: {", ".join(valid_types)}')
|
||||
return v
|
||||
|
||||
|
||||
class EventTriggerResponse(BaseModel):
|
||||
"""Event trigger response schema"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
trigger_type: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
config: Dict[str, Any]
|
||||
conditions: Dict[str, Any]
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_triggered_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ScheduledTaskResponse(BaseModel):
|
||||
"""Scheduled task response schema"""
|
||||
id: str
|
||||
task_type: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
scheduled_at: datetime
|
||||
executed_at: Optional[datetime] = None
|
||||
config: Dict[str, Any]
|
||||
context: Dict[str, Any]
|
||||
status: str
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
error_message: Optional[str] = None
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
retry_count: int
|
||||
max_retries: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EventLogResponse(BaseModel):
|
||||
"""Event log response schema"""
|
||||
id: int
|
||||
event_id: str
|
||||
log_level: str
|
||||
message: str
|
||||
details: Dict[str, Any]
|
||||
action_id: Optional[str] = None
|
||||
subscription_id: Optional[str] = None
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EmitEventRequest(BaseModel):
|
||||
"""Request schema for manually emitting events"""
|
||||
event_type: str = Field(..., description="Type of event to emit")
|
||||
data: Dict[str, Any] = Field(..., description="Event data payload")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata")
|
||||
|
||||
@validator('event_type')
|
||||
def validate_event_type(cls, v):
|
||||
valid_types = [
|
||||
'document.uploaded', 'document.processed', 'document.failed',
|
||||
'conversation.started', 'message.sent', 'agent.created',
|
||||
'rag.search_performed', 'user.login', 'user.activity',
|
||||
'system.health_check'
|
||||
]
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'event_type must be one of: {", ".join(valid_types)}')
|
||||
return v
|
||||
|
||||
|
||||
class WebhookConfig(BaseModel):
|
||||
"""Configuration for webhook actions"""
|
||||
url: str = Field(..., description="Webhook URL")
|
||||
method: str = Field(default="POST", pattern="^(GET|POST|PUT|PATCH|DELETE)$")
|
||||
headers: Dict[str, str] = Field(default_factory=dict)
|
||||
timeout: int = Field(default=30, ge=1, le=300)
|
||||
retry_on_failure: bool = Field(default=True)
|
||||
|
||||
|
||||
class NotificationConfig(BaseModel):
|
||||
"""Configuration for notification actions"""
|
||||
type: str = Field(default="system", description="Notification type")
|
||||
message: str = Field(..., min_length=1, max_length=1000, description="Notification message")
|
||||
priority: str = Field(default="normal", pattern="^(low|normal|high|urgent)$")
|
||||
channels: List[str] = Field(default_factory=list, description="Notification channels")
|
||||
|
||||
|
||||
class DocumentProcessingConfig(BaseModel):
|
||||
"""Configuration for document processing actions"""
|
||||
chunking_strategy: str = Field(default="hybrid", pattern="^(fixed|semantic|hierarchical|hybrid)$")
|
||||
chunk_size: int = Field(default=512, ge=128, le=2048)
|
||||
chunk_overlap: int = Field(default=128, ge=0, le=512)
|
||||
auto_index: bool = Field(default=True, description="Automatically index in RAG system")
|
||||
|
||||
|
||||
class StatisticsUpdateConfig(BaseModel):
|
||||
"""Configuration for statistics update actions"""
|
||||
type: str = Field(..., description="Type of statistic to update")
|
||||
increment: int = Field(default=1, description="Amount to increment")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
|
||||
|
||||
class AssistantCreationConfig(BaseModel):
|
||||
"""Configuration for agent creation actions"""
|
||||
template_id: str = Field(default="general_assistant", description="Agent template ID")
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Agent name")
|
||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, description="Configuration overrides")
|
||||
|
||||
|
||||
class TaskSchedulingConfig(BaseModel):
|
||||
"""Configuration for task scheduling actions"""
|
||||
task_type: str = Field(..., description="Type of task to schedule")
|
||||
delay_minutes: int = Field(default=0, ge=0, description="Delay before execution")
|
||||
task_config: Dict[str, Any] = Field(default_factory=dict, description="Task configuration")
|
||||
max_retries: int = Field(default=3, ge=0, le=10, description="Maximum retry attempts")
|
||||
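A minimal sketch of wiring an event subscription with one action using the schemas above (illustrative only; the dataset ID and filter values are placeholders, and both validators reject unknown type strings).

from app.schemas.event import EventSubscriptionCreate, EventActionCreate

subscription = EventSubscriptionCreate(
    name="Index new documents",
    event_type="document.processed",
    actions=[
        EventActionCreate(
            action_type="trigger_rag_indexing",
            config={"dataset_id": "ds-support-kb"},  # placeholder dataset ID
            delay_seconds=0,
        )
    ],
    filter_conditions={"file_type": "pdf"},
)
print(subscription.model_dump())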
64
apps/tenant-backend/app/schemas/user.py
Normal file
@@ -0,0 +1,64 @@
"""
User schemas for GT 2.0 Tenant Backend

Pydantic models for user-related API request/response validation.
Implements user preferences management per GT 2.0 specifications.
"""

from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime


class CustomCategory(BaseModel):
    """User-defined custom category with metadata"""
    name: str = Field(..., description="Category name (lowercase, unique per user)")
    description: str = Field(..., description="Category description")
    created_at: Optional[str] = Field(None, description="ISO timestamp when category was created")


class UserPreferences(BaseModel):
    """User preferences stored in JSONB"""
    favorite_agent_ids: Optional[List[str]] = Field(default_factory=list, description="List of favorited agent UUIDs")
    custom_categories: Optional[List[CustomCategory]] = Field(default_factory=list, description="User's custom agent categories")
    # Future preferences can be added here


class UserPreferencesResponse(BaseModel):
    """Response for getting user preferences"""
    preferences: Dict[str, Any] = Field(..., description="User preferences dictionary")


class UpdateUserPreferencesRequest(BaseModel):
    """Request to update user preferences (merges with existing)"""
    preferences: Dict[str, Any] = Field(..., description="Preferences to merge with existing")


class FavoriteAgentsResponse(BaseModel):
    """Response for getting favorite agent IDs"""
    favorite_agent_ids: List[str] = Field(..., description="List of favorited agent UUIDs")


class UpdateFavoriteAgentsRequest(BaseModel):
    """Request to update favorite agent IDs (replaces existing list)"""
    agent_ids: List[str] = Field(..., description="List of agent UUIDs to set as favorites")


class AddFavoriteAgentRequest(BaseModel):
    """Request to add a single agent to favorites"""
    agent_id: str = Field(..., description="Agent UUID to add to favorites")


class RemoveFavoriteAgentRequest(BaseModel):
    """Request to remove a single agent from favorites"""
    agent_id: str = Field(..., description="Agent UUID to remove from favorites")


class CustomCategoriesResponse(BaseModel):
    """Response for getting custom categories"""
    categories: List[CustomCategory] = Field(..., description="List of user's custom categories")


class UpdateCustomCategoriesRequest(BaseModel):
    """Request to update custom categories (replaces entire list)"""
    categories: List[CustomCategory] = Field(..., description="Complete list of custom categories")
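UpdateUserPreferencesRequest carries a partial dictionary that the endpoint merges into the stored JSONB value. A minimal merge sketch (the shallow merge below is an illustrative assumption; the actual merge logic lives in the endpoint, which is not shown here):

from app.schemas.user import UpdateUserPreferencesRequest

stored = {"favorite_agent_ids": ["agent-1"], "theme": "dark"}
req = UpdateUserPreferencesRequest(preferences={"favorite_agent_ids": ["agent-1", "agent-2"]})

merged = {**stored, **req.preferences}  # later keys win; unrelated keys are preserved
print(merged)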
5
apps/tenant-backend/app/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
GT 2.0 Tenant Backend Services

Business logic and orchestration services for tenant applications.
"""
451
apps/tenant-backend/app/services/access_controller.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
Access Controller Service for GT 2.0
|
||||
|
||||
Manages resource access control with capability-based security.
|
||||
Ensures perfect tenant isolation and proper permission cascading.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from app.models.access_group import (
|
||||
AccessGroup, TenantStructure, User, Resource,
|
||||
ResourceCreate, ResourceUpdate, ResourceResponse
|
||||
)
|
||||
from app.core.security import verify_capability_token
|
||||
from app.core.database import get_db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessController:
|
||||
"""
|
||||
Centralized access control service
|
||||
Manages permissions for all resources with tenant isolation
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.base_path = Path(f"/data/{tenant_domain}")
|
||||
self._ensure_tenant_directory()
|
||||
|
||||
def _ensure_tenant_directory(self):
|
||||
"""
|
||||
Ensure tenant directory exists with proper permissions
|
||||
OS User: gt-{tenant_domain}-{pod_id}
|
||||
Permissions: 700 (owner only)
|
||||
"""
|
||||
if not self.base_path.exists():
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
# Set strict permissions - owner only
|
||||
os.chmod(self.base_path, stat.S_IRWXU) # 700
|
||||
logger.info(f"Created tenant directory: {self.base_path} with 700 permissions")
|
||||
|
||||
async def check_permission(
|
||||
self,
|
||||
user_id: str,
|
||||
resource: Resource,
|
||||
action: str = "read"
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if user has permission for action on resource
|
||||
|
||||
Args:
|
||||
user_id: User requesting access
|
||||
resource: Resource being accessed
|
||||
action: read, write, delete, share
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, reason)
|
||||
"""
|
||||
# Verify tenant isolation
|
||||
if resource.tenant_domain != self.tenant_domain:
|
||||
logger.warning(f"Cross-tenant access attempt: {user_id} -> {resource.id}")
|
||||
return False, "Cross-tenant access denied"
|
||||
|
||||
# Owner has all permissions
|
||||
if resource.owner_id == user_id:
|
||||
return True, "Owner access granted"
|
||||
|
||||
# Check action-specific permissions
|
||||
if action == "read":
|
||||
return self._check_read_permission(user_id, resource)
|
||||
elif action == "write":
|
||||
return self._check_write_permission(user_id, resource)
|
||||
elif action == "delete":
|
||||
return False, "Only owner can delete"
|
||||
elif action == "share":
|
||||
return False, "Only owner can share"
|
||||
else:
|
||||
return False, f"Unknown action: {action}"
|
||||
|
||||
def _check_read_permission(self, user_id: str, resource: Resource) -> Tuple[bool, str]:
|
||||
"""Check read permission based on access group"""
|
||||
if resource.access_group == AccessGroup.ORGANIZATION:
|
||||
return True, "Organization-wide read access"
|
||||
elif resource.access_group == AccessGroup.TEAM:
|
||||
if user_id in resource.team_members:
|
||||
return True, "Team member read access"
|
||||
return False, "Not a team member"
|
||||
else: # INDIVIDUAL
|
||||
return False, "Private resource"
|
||||
|
||||
def _check_write_permission(self, user_id: str, resource: Resource) -> Tuple[bool, str]:
|
||||
"""Check write permission - only owner can write"""
|
||||
return False, "Only owner can modify"
|
||||
|
||||
async def create_resource(
|
||||
self,
|
||||
user_id: str,
|
||||
resource_data: ResourceCreate,
|
||||
capability_token: str
|
||||
) -> Resource:
|
||||
"""
|
||||
Create a new resource with proper access control
|
||||
|
||||
Args:
|
||||
user_id: User creating the resource
|
||||
resource_data: Resource creation data
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Created resource
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Create resource
|
||||
resource = Resource(
|
||||
id=self._generate_resource_id(),
|
||||
name=resource_data.name,
|
||||
resource_type=resource_data.resource_type,
|
||||
owner_id=user_id,
|
||||
tenant_domain=self.tenant_domain,
|
||||
access_group=resource_data.access_group,
|
||||
team_members=resource_data.team_members or [],
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
metadata=resource_data.metadata or {},
|
||||
file_path=None
|
||||
)
|
||||
|
||||
# Create file-based storage if needed
|
||||
if self._requires_file_storage(resource.resource_type):
|
||||
resource.file_path = await self._create_resource_file(resource)
|
||||
|
||||
# Audit log
|
||||
logger.info(f"Resource created: {resource.id} by {user_id} in {self.tenant_domain}")
|
||||
|
||||
return resource
|
||||
|
||||
async def update_resource_access(
|
||||
self,
|
||||
user_id: str,
|
||||
resource_id: str,
|
||||
new_access_group: AccessGroup,
|
||||
team_members: Optional[List[str]] = None
|
||||
) -> Resource:
|
||||
"""
|
||||
Update resource access group
|
||||
|
||||
Args:
|
||||
user_id: User requesting update
|
||||
resource_id: Resource to update
|
||||
new_access_group: New access level
|
||||
team_members: Team members if team access
|
||||
|
||||
Returns:
|
||||
Updated resource
|
||||
"""
|
||||
# Load resource
|
||||
resource = await self._load_resource(resource_id)
|
||||
|
||||
# Check permission
|
||||
allowed, reason = await self.check_permission(user_id, resource, "share")
|
||||
if not allowed:
|
||||
raise PermissionError(f"Access denied: {reason}")
|
||||
|
||||
# Update access
|
||||
old_group = resource.access_group
|
||||
resource.update_access_group(new_access_group, team_members)
|
||||
|
||||
# Update file permissions if needed
|
||||
if resource.file_path:
|
||||
await self._update_file_permissions(resource)
|
||||
|
||||
# Audit log
|
||||
logger.info(
|
||||
f"Access updated: {resource_id} from {old_group} to {new_access_group} "
|
||||
f"by {user_id}"
|
||||
)
|
||||
|
||||
return resource
|
||||
|
||||
async def list_accessible_resources(
|
||||
self,
|
||||
user_id: str,
|
||||
resource_type: Optional[str] = None
|
||||
) -> List[Resource]:
|
||||
"""
|
||||
List all resources accessible to user
|
||||
|
||||
Args:
|
||||
user_id: User requesting list
|
||||
resource_type: Filter by type
|
||||
|
||||
Returns:
|
||||
List of accessible resources
|
||||
"""
|
||||
accessible = []
|
||||
|
||||
# Get all resources in tenant
|
||||
all_resources = await self._list_tenant_resources(resource_type)
|
||||
|
||||
for resource in all_resources:
|
||||
allowed, _ = await self.check_permission(user_id, resource, "read")
|
||||
if allowed:
|
||||
accessible.append(resource)
|
||||
|
||||
return accessible
|
||||
|
||||
async def get_resource_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get resource statistics for user
|
||||
|
||||
Args:
|
||||
user_id: User to get stats for
|
||||
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
all_resources = await self._list_tenant_resources()
|
||||
|
||||
owned = [r for r in all_resources if r.owner_id == user_id]
|
||||
accessible = await self.list_accessible_resources(user_id)
|
||||
|
||||
stats = {
|
||||
"owned_count": len(owned),
|
||||
"accessible_count": len(accessible),
|
||||
"by_type": {},
|
||||
"by_access_group": {
|
||||
AccessGroup.INDIVIDUAL: 0,
|
||||
AccessGroup.TEAM: 0,
|
||||
AccessGroup.ORGANIZATION: 0
|
||||
}
|
||||
}
|
||||
|
||||
for resource in owned:
|
||||
# Count by type
|
||||
if resource.resource_type not in stats["by_type"]:
|
||||
stats["by_type"][resource.resource_type] = 0
|
||||
stats["by_type"][resource.resource_type] += 1
|
||||
|
||||
# Count by access group
|
||||
stats["by_access_group"][resource.access_group] += 1
|
||||
|
||||
return stats
|
||||
|
||||
def _generate_resource_id(self) -> str:
|
||||
"""Generate unique resource ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _requires_file_storage(self, resource_type: str) -> bool:
|
||||
"""Check if resource type requires file storage"""
|
||||
file_based_types = [
|
||||
"agent", "dataset", "document", "workflow",
|
||||
"notebook", "model", "configuration"
|
||||
]
|
||||
return resource_type in file_based_types
|
||||
|
||||
async def _create_resource_file(self, resource: Resource) -> str:
|
||||
"""
|
||||
Create file for resource with proper permissions
|
||||
|
||||
Args:
|
||||
resource: Resource to create file for
|
||||
|
||||
Returns:
|
||||
File path
|
||||
"""
|
||||
# Determine path based on resource type
|
||||
type_dir = self.base_path / resource.resource_type / resource.id
|
||||
type_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create main file
|
||||
file_path = type_dir / "data.json"
|
||||
file_path.touch()
|
||||
|
||||
# Set strict permissions - 700 for directory, 600 for file
|
||||
os.chmod(type_dir, stat.S_IRWXU) # 700
|
||||
os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
logger.info(f"Created resource file: {file_path} with secure permissions")
|
||||
|
||||
return str(file_path)
|
||||
|
||||
async def _update_file_permissions(self, resource: Resource):
|
||||
"""Update file permissions (always 700/600 for security)"""
|
||||
if not resource.file_path or not Path(resource.file_path).exists():
|
||||
return
|
||||
|
||||
# Permissions don't change based on access group
|
||||
# All files remain 700/600 for OS-level security
|
||||
# Access control is handled at application level
|
||||
pass
|
||||
|
||||
async def _load_resource(self, resource_id: str) -> Resource:
|
||||
"""Load resource from storage"""
|
||||
try:
|
||||
# Search for resource in all resource type directories
|
||||
for resource_type_dir in self.base_path.iterdir():
|
||||
if not resource_type_dir.is_dir():
|
||||
continue
|
||||
|
||||
resource_file = resource_type_dir / "data.json"
|
||||
if resource_file.exists():
|
||||
try:
|
||||
import json
|
||||
with open(resource_file, 'r') as f:
|
||||
resources_data = json.load(f)
|
||||
|
||||
if not isinstance(resources_data, list):
|
||||
resources_data = [resources_data]
|
||||
|
||||
for resource_data in resources_data:
|
||||
if resource_data.get('id') == resource_id:
|
||||
return Resource(
|
||||
id=resource_data['id'],
|
||||
name=resource_data['name'],
|
||||
resource_type=resource_data['resource_type'],
|
||||
owner_id=resource_data['owner_id'],
|
||||
tenant_domain=resource_data['tenant_domain'],
|
||||
access_group=AccessGroup(resource_data['access_group']),
|
||||
team_members=resource_data.get('team_members', []),
|
||||
created_at=datetime.fromisoformat(resource_data['created_at']),
|
||||
updated_at=datetime.fromisoformat(resource_data['updated_at']),
|
||||
metadata=resource_data.get('metadata', {}),
|
||||
file_path=resource_data.get('file_path')
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse resource file {resource_file}: {e}")
|
||||
continue
|
||||
|
||||
raise ValueError(f"Resource {resource_id} not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load resource {resource_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _list_tenant_resources(
|
||||
self,
|
||||
resource_type: Optional[str] = None
|
||||
) -> List[Resource]:
|
||||
"""List all resources in tenant"""
|
||||
try:
|
||||
import json
|
||||
resources = []
|
||||
|
||||
# If specific resource type requested, search only that directory
|
||||
search_dirs = [self.base_path / resource_type] if resource_type else list(self.base_path.iterdir())
|
||||
|
||||
for resource_type_dir in search_dirs:
|
||||
if not resource_type_dir.exists() or not resource_type_dir.is_dir():
|
||||
continue
|
||||
|
||||
resource_file = resource_type_dir / "data.json"
|
||||
if resource_file.exists():
|
||||
try:
|
||||
with open(resource_file, 'r') as f:
|
||||
resources_data = json.load(f)
|
||||
|
||||
if not isinstance(resources_data, list):
|
||||
resources_data = [resources_data]
|
||||
|
||||
for resource_data in resources_data:
|
||||
try:
|
||||
resource = Resource(
|
||||
id=resource_data['id'],
|
||||
name=resource_data['name'],
|
||||
resource_type=resource_data['resource_type'],
|
||||
owner_id=resource_data['owner_id'],
|
||||
tenant_domain=resource_data['tenant_domain'],
|
||||
access_group=AccessGroup(resource_data['access_group']),
|
||||
team_members=resource_data.get('team_members', []),
|
||||
created_at=datetime.fromisoformat(resource_data['created_at']),
|
||||
updated_at=datetime.fromisoformat(resource_data['updated_at']),
|
||||
metadata=resource_data.get('metadata', {}),
|
||||
file_path=resource_data.get('file_path')
|
||||
)
|
||||
resources.append(resource)
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse resource data: {e}")
|
||||
continue
|
||||
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.warning(f"Failed to read resource file {resource_file}: {e}")
|
||||
continue
|
||||
|
||||
return resources
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list tenant resources: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class AccessControlMiddleware:
|
||||
"""
|
||||
Middleware for enforcing access control on API requests
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str):
|
||||
self.controller = AccessController(tenant_domain)
|
||||
|
||||
async def verify_request(
|
||||
self,
|
||||
user_id: str,
|
||||
resource_id: str,
|
||||
action: str,
|
||||
capability_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Verify request has proper permissions
|
||||
|
||||
Args:
|
||||
user_id: User making request
|
||||
resource_id: Resource being accessed
|
||||
action: Action being performed
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
True if allowed, raises PermissionError if not
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Verify tenant match
|
||||
if token_data.get("tenant_id") != self.controller.tenant_domain:
|
||||
raise PermissionError("Tenant mismatch in capability token")
|
||||
|
||||
# Load resource and check permission
|
||||
resource = await self.controller._load_resource(resource_id)
|
||||
allowed, reason = await self.controller.check_permission(
|
||||
user_id, resource, action
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
f"Access denied: {user_id} -> {resource_id} ({action}): {reason}"
|
||||
)
|
||||
raise PermissionError(f"Access denied: {reason}")
|
||||
|
||||
return True
|
||||
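A minimal sketch of the permission matrix enforced by check_permission above: owners get every action; TEAM resources are readable by listed members; ORGANIZATION resources are readable tenant-wide; write, delete, and share stay owner-only. The objects below are constructed purely for demonstration (the Resource keyword arguments mirror how _load_resource builds them), and instantiating AccessController creates the tenant directory under /data, so this only runs where that path is writable.

import asyncio
from datetime import datetime
from app.models.access_group import AccessGroup, Resource
from app.services.access_controller import AccessController

async def demo() -> None:
    controller = AccessController(tenant_domain="acme")  # creates /data/acme if missing
    doc = Resource(
        id="res-1", name="Playbook", resource_type="document",
        owner_id="user-a", tenant_domain="acme",
        access_group=AccessGroup.TEAM, team_members=["user-b"],
        created_at=datetime.utcnow(), updated_at=datetime.utcnow(),
        metadata={}, file_path=None,
    )
    print(await controller.check_permission("user-b", doc, "read"))   # (True, "Team member read access")
    print(await controller.check_permission("user-b", doc, "write"))  # (False, "Only owner can modify")

asyncio.run(demo())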
920
apps/tenant-backend/app/services/agent_orchestrator_client.py
Normal file
@@ -0,0 +1,920 @@
|
||||
"""
|
||||
GT 2.0 Agent Orchestrator Client
|
||||
|
||||
Client for interacting with the Resource Cluster's Agent Orchestration system.
|
||||
Enables spawning and managing subagents for complex task execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import httpx
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from app.services.task_classifier import SubagentType, TaskClassification
|
||||
from app.models.agent import Agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExecutionStrategy(str, Enum):
|
||||
"""Execution strategies for subagents"""
|
||||
SEQUENTIAL = "sequential"
|
||||
PARALLEL = "parallel"
|
||||
CONDITIONAL = "conditional"
|
||||
PIPELINE = "pipeline"
|
||||
MAP_REDUCE = "map_reduce"
|
||||
|
||||
|
||||
class SubagentOrchestrator:
|
||||
"""
|
||||
Orchestrates subagent execution for complex tasks.
|
||||
|
||||
Manages lifecycle of subagents spawned from main agent templates,
|
||||
coordinates their execution, and aggregates results.
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.resource_cluster_url = "http://resource-cluster:8000"
|
||||
self.active_subagents: Dict[str, Dict[str, Any]] = {}
|
||||
self.execution_history: List[Dict[str, Any]] = []
|
||||
|
||||
async def execute_task_plan(
|
||||
self,
|
||||
task_classification: TaskClassification,
|
||||
parent_agent: Agent,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
available_tools: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute a task plan using subagents.
|
||||
|
||||
Args:
|
||||
task_classification: Task classification with execution plan
|
||||
parent_agent: Parent agent spawning subagents
|
||||
conversation_id: Current conversation ID
|
||||
user_message: Original user message
|
||||
available_tools: Available MCP tools
|
||||
|
||||
Returns:
|
||||
Aggregated results from subagent execution
|
||||
"""
|
||||
try:
|
||||
execution_id = str(uuid.uuid4())
|
||||
logger.info(f"Starting subagent execution {execution_id} for {task_classification.complexity} task")
|
||||
|
||||
# Track execution
|
||||
execution_record = {
|
||||
"execution_id": execution_id,
|
||||
"conversation_id": conversation_id,
|
||||
"parent_agent_id": parent_agent.id,
|
||||
"task_complexity": task_classification.complexity,
|
||||
"started_at": datetime.now().isoformat(),
|
||||
"subagent_plan": task_classification.subagent_plan
|
||||
}
|
||||
self.execution_history.append(execution_record)
|
||||
|
||||
# Determine execution strategy
|
||||
strategy = self._determine_strategy(task_classification)
|
||||
|
||||
# Execute based on strategy
|
||||
if strategy == ExecutionStrategy.PARALLEL:
|
||||
results = await self._execute_parallel(
|
||||
task_classification.subagent_plan,
|
||||
parent_agent,
|
||||
conversation_id,
|
||||
user_message,
|
||||
available_tools
|
||||
)
|
||||
elif strategy == ExecutionStrategy.SEQUENTIAL:
|
||||
results = await self._execute_sequential(
|
||||
task_classification.subagent_plan,
|
||||
parent_agent,
|
||||
conversation_id,
|
||||
user_message,
|
||||
available_tools
|
||||
)
|
||||
elif strategy == ExecutionStrategy.PIPELINE:
|
||||
results = await self._execute_pipeline(
|
||||
task_classification.subagent_plan,
|
||||
parent_agent,
|
||||
conversation_id,
|
||||
user_message,
|
||||
available_tools
|
||||
)
|
||||
else:
|
||||
# Default to sequential
|
||||
results = await self._execute_sequential(
|
||||
task_classification.subagent_plan,
|
||||
parent_agent,
|
||||
conversation_id,
|
||||
user_message,
|
||||
available_tools
|
||||
)
|
||||
|
||||
# Update execution record
|
||||
execution_record["completed_at"] = datetime.now().isoformat()
|
||||
execution_record["results"] = results
|
||||
|
||||
# Synthesize final response
|
||||
final_response = await self._synthesize_results(
|
||||
results,
|
||||
task_classification,
|
||||
user_message
|
||||
)
|
||||
|
||||
logger.info(f"Completed subagent execution {execution_id}")
|
||||
|
||||
return {
|
||||
"execution_id": execution_id,
|
||||
"strategy": strategy,
|
||||
"subagent_results": results,
|
||||
"final_response": final_response,
|
||||
"execution_time_ms": self._calculate_execution_time(execution_record)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subagent execution failed: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"partial_results": self.active_subagents
|
||||
}
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
subagent_plan: List[Dict[str, Any]],
|
||||
parent_agent: Agent,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
available_tools: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute subagents in parallel"""
|
||||
# Group subagents by priority
|
||||
priority_groups = {}
|
||||
for plan_item in subagent_plan:
|
||||
priority = plan_item.get("priority", 1)
|
||||
if priority not in priority_groups:
|
||||
priority_groups[priority] = []
|
||||
priority_groups[priority].append(plan_item)
|
||||
|
||||
results = {}
|
||||
|
||||
# Execute each priority group
|
||||
for priority in sorted(priority_groups.keys()):
|
||||
group_tasks = []
|
||||
|
||||
for plan_item in priority_groups[priority]:
|
||||
# Check dependencies
|
||||
if self._dependencies_met(plan_item, results):
|
||||
task = asyncio.create_task(
|
||||
self._execute_subagent(
|
||||
plan_item,
|
||||
parent_agent,
|
||||
conversation_id,
|
||||
user_message,
|
||||
available_tools,
|
||||
results
|
||||
)
|
||||
)
|
||||
group_tasks.append((plan_item["id"], task))
|
||||
|
||||
# Wait for group to complete
|
||||
for agent_id, task in group_tasks:
|
||||
try:
|
||||
results[agent_id] = await task
|
||||
except Exception as e:
|
||||
logger.error(f"Subagent {agent_id} failed: {e}")
|
||||
results[agent_id] = {"error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
subagent_plan: List[Dict[str, Any]],
|
||||
parent_agent: Agent,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
available_tools: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute subagents sequentially"""
|
||||
results = {}
|
||||
|
||||
for plan_item in subagent_plan:
|
||||
if self._dependencies_met(plan_item, results):
|
||||
try:
|
||||
results[plan_item["id"]] = await self._execute_subagent(
|
||||
plan_item,
|
||||
parent_agent,
|
||||
conversation_id,
|
||||
user_message,
|
||||
available_tools,
|
||||
results
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Subagent {plan_item['id']} failed: {e}")
|
||||
results[plan_item["id"]] = {"error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
async def _execute_pipeline(
|
||||
self,
|
||||
subagent_plan: List[Dict[str, Any]],
|
||||
parent_agent: Agent,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
available_tools: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute subagents in pipeline mode"""
|
||||
results = {}
|
||||
pipeline_data = {"original_message": user_message}
|
||||
|
||||
for plan_item in subagent_plan:
|
||||
try:
|
||||
# Pass output from previous stage as input
|
||||
result = await self._execute_subagent(
|
||||
plan_item,
|
||||
parent_agent,
|
||||
conversation_id,
|
||||
user_message,
|
||||
available_tools,
|
||||
results,
|
||||
pipeline_data
|
||||
)
|
||||
|
||||
results[plan_item["id"]] = result
|
||||
|
||||
# Update pipeline data with output
|
||||
if "output" in result:
|
||||
pipeline_data = result["output"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline stage {plan_item['id']} failed: {e}")
|
||||
results[plan_item["id"]] = {"error": str(e)}
|
||||
break # Pipeline broken
|
||||
|
||||
return results
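    # Example (illustrative): after stage one completes, its result["output"] replaces
    # pipeline_data, and that value is handed to the next stage's _execute_subagent call;
    # a failed stage logs the error and breaks the pipeline.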
|
||||
|
||||
async def _execute_subagent(
|
||||
self,
|
||||
plan_item: Dict[str, Any],
|
||||
parent_agent: Agent,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
available_tools: List[Dict[str, Any]],
|
||||
previous_results: Dict[str, Any],
|
||||
pipeline_data: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a single subagent"""
|
||||
subagent_id = plan_item["id"]
|
||||
subagent_type = plan_item["type"]
|
||||
task_description = plan_item["task"]
|
||||
|
||||
logger.info(f"Executing subagent {subagent_id} ({subagent_type}): {task_description[:50]}...")
|
||||
|
||||
# Track subagent
|
||||
self.active_subagents[subagent_id] = {
|
||||
"type": subagent_type,
|
||||
"task": task_description,
|
||||
"started_at": datetime.now().isoformat(),
|
||||
"status": "running"
|
||||
}
|
||||
|
||||
try:
|
||||
# Create subagent configuration based on type
|
||||
subagent_config = self._create_subagent_config(
|
||||
subagent_type,
|
||||
parent_agent,
|
||||
task_description,
|
||||
pipeline_data
|
||||
)
|
||||
|
||||
# Select tools for this subagent
|
||||
subagent_tools = self._select_tools_for_subagent(
|
||||
subagent_type,
|
||||
available_tools
|
||||
)
|
||||
|
||||
# Execute subagent based on type
|
||||
if subagent_type == SubagentType.RESEARCH:
|
||||
result = await self._execute_research_agent(
|
||||
subagent_config,
|
||||
task_description,
|
||||
subagent_tools,
|
||||
conversation_id
|
||||
)
|
||||
elif subagent_type == SubagentType.PLANNING:
|
||||
result = await self._execute_planning_agent(
|
||||
subagent_config,
|
||||
task_description,
|
||||
user_message,
|
||||
previous_results
|
||||
)
|
||||
elif subagent_type == SubagentType.IMPLEMENTATION:
|
||||
result = await self._execute_implementation_agent(
|
||||
subagent_config,
|
||||
task_description,
|
||||
subagent_tools,
|
||||
previous_results
|
||||
)
|
||||
elif subagent_type == SubagentType.VALIDATION:
|
||||
result = await self._execute_validation_agent(
|
||||
subagent_config,
|
||||
task_description,
|
||||
previous_results
|
||||
)
|
||||
elif subagent_type == SubagentType.SYNTHESIS:
|
||||
result = await self._execute_synthesis_agent(
|
||||
subagent_config,
|
||||
task_description,
|
||||
previous_results
|
||||
)
|
||||
elif subagent_type == SubagentType.ANALYST:
|
||||
result = await self._execute_analyst_agent(
|
||||
subagent_config,
|
||||
task_description,
|
||||
previous_results
|
||||
)
|
||||
else:
|
||||
# Default execution
|
||||
result = await self._execute_generic_agent(
|
||||
subagent_config,
|
||||
task_description,
|
||||
subagent_tools
|
||||
)
|
||||
|
||||
# Update tracking
|
||||
self.active_subagents[subagent_id]["status"] = "completed"
|
||||
self.active_subagents[subagent_id]["completed_at"] = datetime.now().isoformat()
|
||||
self.active_subagents[subagent_id]["result"] = result
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subagent {subagent_id} execution failed: {e}")
|
||||
self.active_subagents[subagent_id]["status"] = "failed"
|
||||
self.active_subagents[subagent_id]["error"] = str(e)
|
||||
raise
|
||||
|
||||
async def _execute_research_agent(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
task: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
conversation_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute research subagent"""
|
||||
# Research agents focus on information gathering
|
||||
prompt = f"""You are a research specialist. Your task is to:
|
||||
{task}
|
||||
|
||||
Available tools: {[t['name'] for t in tools]}
|
||||
|
||||
Gather comprehensive information and return structured findings."""
|
||||
|
||||
result = await self._call_llm_with_tools(
|
||||
prompt,
|
||||
config,
|
||||
tools,
|
||||
max_iterations=3
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "research",
|
||||
"findings": result.get("content", ""),
|
||||
"sources": result.get("tool_results", []),
|
||||
"output": result
|
||||
}
|
||||
|
||||
async def _execute_planning_agent(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
task: str,
|
||||
original_query: str,
|
||||
previous_results: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute planning subagent"""
|
||||
context = self._format_previous_results(previous_results)
|
||||
|
||||
prompt = f"""You are a planning specialist. Break down this task into actionable steps:
|
||||
|
||||
Original request: {original_query}
|
||||
Specific task: {task}
|
||||
|
||||
Context from previous agents:
|
||||
{context}
|
||||
|
||||
Create a detailed execution plan with clear steps."""
|
||||
|
||||
result = await self._call_llm(prompt, config)
|
||||
|
||||
return {
|
||||
"type": "planning",
|
||||
"plan": result.get("content", ""),
|
||||
"steps": self._extract_steps(result.get("content", "")),
|
||||
"output": result
|
||||
}
|
||||
|
||||
async def _execute_implementation_agent(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
task: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
previous_results: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute implementation subagent"""
|
||||
context = self._format_previous_results(previous_results)
|
||||
|
||||
prompt = f"""You are an implementation specialist. Execute this task:
|
||||
{task}
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
Available tools: {[t['name'] for t in tools]}
|
||||
|
||||
Complete the implementation and return results."""
|
||||
|
||||
result = await self._call_llm_with_tools(
|
||||
prompt,
|
||||
config,
|
||||
tools,
|
||||
max_iterations=5
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "implementation",
|
||||
"implementation": result.get("content", ""),
|
||||
"tool_calls": result.get("tool_calls", []),
|
||||
"output": result
|
||||
}
|
||||
|
||||
async def _execute_validation_agent(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
task: str,
|
||||
previous_results: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute validation subagent"""
|
||||
context = self._format_previous_results(previous_results)
|
||||
|
||||
prompt = f"""You are a validation specialist. Verify the following:
|
||||
{task}
|
||||
|
||||
Results to validate:
|
||||
{context}
|
||||
|
||||
Check for correctness, completeness, and quality."""
|
||||
|
||||
result = await self._call_llm(prompt, config)
|
||||
|
||||
return {
|
||||
"type": "validation",
|
||||
"validation_result": result.get("content", ""),
|
||||
"issues_found": self._extract_issues(result.get("content", "")),
|
||||
"output": result
|
||||
}
|
||||
|
||||
async def _execute_synthesis_agent(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
task: str,
|
||||
previous_results: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute synthesis subagent"""
|
||||
all_results = self._format_all_results(previous_results)
|
||||
|
||||
prompt = f"""You are a synthesis specialist. Combine and summarize these results:
|
||||
|
||||
Task: {task}
|
||||
|
||||
Results from all agents:
|
||||
{all_results}
|
||||
|
||||
Create a comprehensive, coherent response that addresses the original request."""
|
||||
|
||||
result = await self._call_llm(prompt, config)
|
||||
|
||||
return {
|
||||
"type": "synthesis",
|
||||
"final_response": result.get("content", ""),
|
||||
"output": result
|
||||
}
|
||||
|
||||
async def _execute_analyst_agent(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
task: str,
|
||||
previous_results: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute analyst subagent"""
|
||||
data = self._format_previous_results(previous_results)
|
||||
|
||||
prompt = f"""You are an analysis specialist. Analyze the following:
|
||||
{task}
|
||||
|
||||
Data to analyze:
|
||||
{data}
|
||||
|
||||
Identify patterns, insights, and recommendations."""
|
||||
|
||||
result = await self._call_llm(prompt, config)
|
||||
|
||||
return {
|
||||
"type": "analysis",
|
||||
"analysis": result.get("content", ""),
|
||||
"insights": self._extract_insights(result.get("content", "")),
|
||||
"output": result
|
||||
}
|
||||
|
||||
async def _execute_generic_agent(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
task: str,
|
||||
tools: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute generic subagent"""
|
||||
prompt = f"""Complete the following task:
|
||||
{task}
|
||||
|
||||
Available tools: {[t['name'] for t in tools] if tools else 'None'}"""
|
||||
|
||||
if tools:
|
||||
result = await self._call_llm_with_tools(prompt, config, tools)
|
||||
else:
|
||||
result = await self._call_llm(prompt, config)
|
||||
|
||||
return {
|
||||
"type": "generic",
|
||||
"result": result.get("content", ""),
|
||||
"output": result
|
||||
}
|
||||
|
||||
async def _call_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Call LLM without tools"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
# Require model to be specified in config - no hardcoded fallbacks
|
||||
model = config.get("model")
|
||||
if not model:
|
||||
raise ValueError(f"No model specified in subagent config: {config}")
|
||||
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": config.get("instructions", "")},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": config.get("temperature", 0.7),
|
||||
"max_tokens": config.get("max_tokens", 2000)
|
||||
},
|
||||
headers={
|
||||
"X-Tenant-ID": self.tenant_domain,
|
||||
"X-User-ID": self.user_id
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return {
|
||||
"content": result["choices"][0]["message"]["content"],
|
||||
"model": result["model"]
|
||||
}
|
||||
else:
|
||||
raise Exception(f"LLM call failed: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM call failed: {e}")
|
||||
return {"content": f"Error: {str(e)}"}
|
||||
|
||||
async def _call_llm_with_tools(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any],
|
||||
tools: List[Dict[str, Any]],
|
||||
max_iterations: int = 3
|
||||
) -> Dict[str, Any]:
|
||||
"""Call LLM with tool execution capability"""
|
||||
messages = [
|
||||
{"role": "system", "content": config.get("instructions", "")},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
tool_results = []
|
||||
iterations = 0
|
||||
|
||||
while iterations < max_iterations:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
# Require model to be specified in config - no hardcoded fallbacks
|
||||
model = config.get("model")
|
||||
if not model:
|
||||
raise ValueError(f"No model specified in subagent config: {config}")
|
||||
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": config.get("temperature", 0.7),
|
||||
"max_tokens": config.get("max_tokens", 2000),
|
||||
"tools": tools,
|
||||
"tool_choice": "auto"
|
||||
},
|
||||
headers={
|
||||
"X-Tenant-ID": self.tenant_domain,
|
||||
"X-User-ID": self.user_id
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"LLM call failed: {response.status_code}")
|
||||
|
||||
result = response.json()
|
||||
choice = result["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
# Add agent's response to messages
|
||||
messages.append(message)
|
||||
|
||||
# Check for tool calls
|
||||
if message.get("tool_calls"):
|
||||
# Execute tools
|
||||
for tool_call in message["tool_calls"]:
|
||||
tool_result = await self._execute_tool(
|
||||
tool_call["function"]["name"],
|
||||
tool_call["function"].get("arguments", {})
|
||||
)
|
||||
|
||||
tool_results.append({
|
||||
"tool": tool_call["function"]["name"],
|
||||
"result": tool_result
|
||||
})
|
||||
|
||||
# Add tool result to messages
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": str(tool_result)
|
||||
})
|
||||
|
||||
iterations += 1
|
||||
continue # Get next response
|
||||
|
||||
# No more tool calls, return final result
|
||||
return {
|
||||
"content": message.get("content", ""),
|
||||
"tool_calls": message.get("tool_calls", []),
|
||||
"tool_results": tool_results,
|
||||
"model": result["model"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM with tools call failed: {e}")
|
||||
return {"content": f"Error: {str(e)}", "tool_results": tool_results}
|
||||
|
||||
|
||||
|
||||
# Max iterations reached
|
||||
return {
|
||||
"content": "Max iterations reached",
|
||||
"tool_results": tool_results
|
||||
}
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute an MCP tool"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/mcp/execute",
|
||||
json={
|
||||
"tool_name": tool_name,
|
||||
"parameters": arguments,
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"user_id": self.user_id
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return {"error": f"Tool execution failed: {response.status_code}"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
    def _determine_strategy(self, task_classification: TaskClassification) -> ExecutionStrategy:
        """Determine execution strategy based on task classification"""
        if task_classification.parallel_execution:
            return ExecutionStrategy.PARALLEL
        elif len(task_classification.subagent_plan) > 3:
            return ExecutionStrategy.PIPELINE
        else:
            return ExecutionStrategy.SEQUENTIAL
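    # Illustrative outcomes (values hypothetical): parallel_execution=True selects PARALLEL;
    # a plan with 5 subagents selects PIPELINE; a short 2-step plan falls back to SEQUENTIAL.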
|
||||
|
||||
def _dependencies_met(
|
||||
self,
|
||||
plan_item: Dict[str, Any],
|
||||
completed_results: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if dependencies are met for a subagent"""
|
||||
depends_on = plan_item.get("depends_on", [])
|
||||
return all(dep in completed_results for dep in depends_on)
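    # Example (illustrative plan item): {"id": "impl_1", "depends_on": ["research_1"]} is only
    # eligible once completed_results contains a "research_1" entry; items without a
    # "depends_on" key are always eligible.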
|
||||
|
||||
def _create_subagent_config(
|
||||
self,
|
||||
subagent_type: SubagentType,
|
||||
parent_agent: Agent,
|
||||
task: str,
|
||||
pipeline_data: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create configuration for subagent"""
|
||||
# Base config from parent
|
||||
config = {
|
||||
"model": parent_agent.model_name,
|
||||
"temperature": parent_agent.model_settings.get("temperature", 0.7),
|
||||
"max_tokens": parent_agent.model_settings.get("max_tokens", 2000)
|
||||
}
|
||||
|
||||
# Customize based on subagent type
|
||||
if subagent_type == SubagentType.RESEARCH:
|
||||
config["instructions"] = "You are a research specialist. Be thorough and accurate."
|
||||
config["temperature"] = 0.3 # Lower for factual research
|
||||
elif subagent_type == SubagentType.PLANNING:
|
||||
config["instructions"] = "You are a planning specialist. Create clear, actionable plans."
|
||||
config["temperature"] = 0.5
|
||||
elif subagent_type == SubagentType.IMPLEMENTATION:
|
||||
config["instructions"] = "You are an implementation specialist. Execute tasks precisely."
|
||||
config["temperature"] = 0.3
|
||||
elif subagent_type == SubagentType.SYNTHESIS:
|
||||
config["instructions"] = "You are a synthesis specialist. Create coherent summaries."
|
||||
config["temperature"] = 0.7
|
||||
else:
|
||||
config["instructions"] = parent_agent.instructions or ""
|
||||
|
||||
return config
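    # Example (hypothetical parent model name): for a RESEARCH subagent whose parent agent
    # runs "llama-3.1-70b" with default model_settings, this returns
    #   {"model": "llama-3.1-70b", "temperature": 0.3, "max_tokens": 2000,
    #    "instructions": "You are a research specialist. Be thorough and accurate."}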
|
||||
|
||||
def _select_tools_for_subagent(
|
||||
self,
|
||||
subagent_type: SubagentType,
|
||||
available_tools: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Select appropriate tools for subagent type"""
|
||||
if not available_tools:
|
||||
return []
|
||||
|
||||
# Tool selection based on subagent type
|
||||
if subagent_type == SubagentType.RESEARCH:
|
||||
# Research agents get search tools
|
||||
return [t for t in available_tools if any(
|
||||
keyword in t["name"].lower()
|
||||
for keyword in ["search", "find", "list", "get", "fetch"]
|
||||
)]
|
||||
elif subagent_type == SubagentType.IMPLEMENTATION:
|
||||
# Implementation agents get action tools
|
||||
return [t for t in available_tools if any(
|
||||
keyword in t["name"].lower()
|
||||
for keyword in ["create", "update", "write", "execute", "run"]
|
||||
)]
|
||||
elif subagent_type == SubagentType.VALIDATION:
|
||||
# Validation agents get read/check tools
|
||||
return [t for t in available_tools if any(
|
||||
keyword in t["name"].lower()
|
||||
for keyword in ["read", "check", "verify", "test"]
|
||||
)]
|
||||
else:
|
||||
# Give all tools to other types
|
||||
return available_tools
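    # Example (hypothetical tool names): a tool called "search_documents" is routed to
    # RESEARCH subagents, "create_ticket" to IMPLEMENTATION, "verify_output" to VALIDATION,
    # and all tools are passed through unchanged for the remaining subagent types.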
|
||||
|
||||
async def _synthesize_results(
|
||||
self,
|
||||
results: Dict[str, Any],
|
||||
task_classification: TaskClassification,
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""Synthesize final response from all subagent results"""
|
||||
# Look for synthesis agent result first
|
||||
for agent_id, result in results.items():
|
||||
if result.get("type") == "synthesis":
|
||||
return result.get("final_response", "")
|
||||
|
||||
# Otherwise, compile results
|
||||
response_parts = []
|
||||
|
||||
# Add results in order of priority
|
||||
for plan_item in sorted(
|
||||
task_classification.subagent_plan,
|
||||
key=lambda x: x.get("priority", 999)
|
||||
):
|
||||
agent_id = plan_item["id"]
|
||||
if agent_id in results:
|
||||
result = results[agent_id]
|
||||
if "error" not in result:
|
||||
content = result.get("output", {}).get("content", "")
|
||||
if content:
|
||||
response_parts.append(content)
|
||||
|
||||
return "\n\n".join(response_parts) if response_parts else "Task completed"
|
||||
|
||||
def _format_previous_results(self, results: Dict[str, Any]) -> str:
|
||||
"""Format previous results for context"""
|
||||
if not results:
|
||||
return "No previous results"
|
||||
|
||||
formatted = []
|
||||
for agent_id, result in results.items():
|
||||
if "error" not in result:
|
||||
formatted.append(f"{agent_id}: {result.get('output', {}).get('content', '')[:200]}")
|
||||
|
||||
return "\n".join(formatted) if formatted else "No valid previous results"
|
||||
|
||||
def _format_all_results(self, results: Dict[str, Any]) -> str:
|
||||
"""Format all results for synthesis"""
|
||||
if not results:
|
||||
return "No results to synthesize"
|
||||
|
||||
formatted = []
|
||||
for agent_id, result in results.items():
|
||||
if "error" not in result:
|
||||
agent_type = result.get("type", "unknown")
|
||||
content = result.get("output", {}).get("content", "")
|
||||
formatted.append(f"[{agent_type}] {agent_id}:\n{content}\n")
|
||||
|
||||
return "\n".join(formatted) if formatted else "No valid results to synthesize"
|
||||
|
||||
def _extract_steps(self, content: str) -> List[str]:
|
||||
"""Extract steps from planning content"""
|
||||
import re
|
||||
steps = []
|
||||
|
||||
# Look for numbered lists
|
||||
pattern = r"(?:^|\n)\s*(?:\d+[\.\)]|\-|\*)\s+(.+)"
|
||||
matches = re.findall(pattern, content)
|
||||
|
||||
for match in matches:
|
||||
steps.append(match.strip())
|
||||
|
||||
return steps
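    # Example (illustrative input): "1. Gather requirements\n2) Draft the plan\n- Review with the team"
    # yields ["Gather requirements", "Draft the plan", "Review with the team"].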
|
||||
|
||||
def _extract_issues(self, content: str) -> List[str]:
|
||||
"""Extract issues from validation content"""
|
||||
import re
|
||||
issues = []
|
||||
|
||||
# Look for issue indicators
|
||||
issue_patterns = [
|
||||
r"(?:issue|problem|error|warning|concern):\s*(.+)",
|
||||
r"(?:^|\n)\s*[\-\*]\s*(?:Issue|Problem|Error):\s*(.+)"
|
||||
]
|
||||
|
||||
for pattern in issue_patterns:
|
||||
matches = re.findall(pattern, content, re.IGNORECASE)
|
||||
issues.extend([m.strip() for m in matches])
|
||||
|
||||
return issues
|
||||
|
||||
def _extract_insights(self, content: str) -> List[str]:
|
||||
"""Extract insights from analysis content"""
|
||||
import re
|
||||
insights = []
|
||||
|
||||
# Look for insight indicators
|
||||
insight_patterns = [
|
||||
r"(?:insight|finding|observation|pattern):\s*(.+)",
|
||||
r"(?:^|\n)\s*\d+[\.\)]\s*(.+(?:shows?|indicates?|suggests?|reveals?).+)"
|
||||
]
|
||||
|
||||
for pattern in insight_patterns:
|
||||
matches = re.findall(pattern, content, re.IGNORECASE)
|
||||
insights.extend([m.strip() for m in matches])
|
||||
|
||||
return insights
|
||||
|
||||
    def _calculate_execution_time(self, execution_record: Dict[str, Any]) -> float:
        """Calculate execution time in milliseconds"""
        if "completed_at" in execution_record and "started_at" in execution_record:
            start = datetime.fromisoformat(execution_record["started_at"])
            end = datetime.fromisoformat(execution_record["completed_at"])
            return (end - start).total_seconds() * 1000
        return 0.0
|
||||
|
||||
|
||||
# Factory function
|
||||
def get_subagent_orchestrator(tenant_domain: str, user_id: str) -> SubagentOrchestrator:
|
||||
"""Get subagent orchestrator instance"""
|
||||
return SubagentOrchestrator(tenant_domain, user_id)
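
# Illustrative usage sketch (identifiers and the orchestrator's public entry-point name are
# assumptions, not confirmed by this file):
#
#     orchestrator = get_subagent_orchestrator("acme.example.com", "user-uuid")
#     result = await orchestrator.execute_subagents(
#         task_classification, parent_agent, conversation_id,
#         user_message, available_tools,
#     )
#     print(result["final_response"], result["execution_time_ms"])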
|
||||
854
apps/tenant-backend/app/services/agent_service.py
Normal file
854
apps/tenant-backend/app/services/agent_service.py
Normal file
@@ -0,0 +1,854 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pathlib import Path
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.permissions import get_user_role, validate_visibility_permission, can_edit_resource, can_delete_resource, is_effective_owner
|
||||
from app.services.category_service import CategoryService
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AgentService:
|
||||
"""GT 2.0 PostgreSQL+PGVector Agent Service with Perfect Tenant Isolation"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
|
||||
"""Initialize with tenant and user isolation using PostgreSQL+PGVector storage"""
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.user_email = user_email or user_id # Fallback to user_id if no email provided
|
||||
self.settings = get_settings()
|
||||
self._resolved_user_uuid = None # Cache for resolved user UUID (performance optimization)
|
||||
|
||||
logger.info(f"Agent service initialized with PostgreSQL+PGVector for {tenant_domain}/{user_id} (email: {self.user_email})")
|
||||
|
||||
async def _get_resolved_user_uuid(self, user_identifier: Optional[str] = None) -> str:
|
||||
"""
|
||||
Resolve user identifier to UUID with caching for performance.
|
||||
|
||||
This optimization reduces repeated database lookups by caching the resolved UUID.
|
||||
Performance impact: ~50% reduction in query time for operations with multiple queries.
|
||||
Pattern matches conversation_service.py for consistency.
|
||||
"""
|
||||
identifier = user_identifier or self.user_email or self.user_id
|
||||
|
||||
# Return cached UUID if already resolved for this instance
|
||||
if self._resolved_user_uuid and str(identifier) in [str(self.user_email), str(self.user_id)]:
|
||||
return self._resolved_user_uuid
|
||||
|
||||
# Check if already a UUID
|
||||
if "@" not in str(identifier):
|
||||
try:
|
||||
# Validate it's a proper UUID format
|
||||
uuid.UUID(str(identifier))
|
||||
if str(identifier) == str(self.user_id):
|
||||
self._resolved_user_uuid = str(identifier)
|
||||
return str(identifier)
|
||||
except (ValueError, AttributeError):
|
||||
pass # Not a valid UUID, treat as email/username
|
||||
|
||||
# Resolve email to UUID
|
||||
pg_client = await get_postgresql_client()
|
||||
query = """
|
||||
SELECT id FROM users
|
||||
WHERE (email = $1 OR username = $1)
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await pg_client.fetch_one(query, str(identifier), self.tenant_domain)
|
||||
|
||||
if not result:
|
||||
raise ValueError(f"User not found: {identifier}")
|
||||
|
||||
user_uuid = str(result["id"])
|
||||
|
||||
# Cache if this is the service's primary user
|
||||
if str(identifier) in [str(self.user_email), str(self.user_id)]:
|
||||
self._resolved_user_uuid = user_uuid
|
||||
|
||||
return user_uuid
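    # Example (hypothetical address): calling this with "alice@acme.example.com" runs the
    # lookup above, returns the matching users.id UUID for the tenant, and caches it on
    # self._resolved_user_uuid when it belongs to the service's primary user.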
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
name: str,
|
||||
agent_type: str = "conversational",
|
||||
prompt_template: str = "",
|
||||
description: str = "",
|
||||
capabilities: Optional[List[str]] = None,
|
||||
access_group: str = "INDIVIDUAL",
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new agent using PostgreSQL+PGVector storage following GT 2.0 principles"""
|
||||
|
||||
try:
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Generate agent ID
|
||||
agent_id = str(uuid.uuid4())
|
||||
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_id = await self._get_resolved_user_uuid()
|
||||
|
||||
logger.info(f"Found user ID: {user_id} for email/id: {self.user_email}/{self.user_id}")
|
||||
|
||||
# Create agent in PostgreSQL
|
||||
query = """
|
||||
INSERT INTO agents (
|
||||
id, name, description, system_prompt,
|
||||
tenant_id, created_by, model, temperature, max_tokens,
|
||||
visibility, configuration, is_active, access_group, agent_type
|
||||
) VALUES (
|
||||
$1, $2, $3, $4,
|
||||
(SELECT id FROM tenants WHERE domain = $5 LIMIT 1),
|
||||
$6,
|
||||
$7, $8, $9, $10, $11, true, $12, $13
|
||||
)
|
||||
RETURNING id, name, description, system_prompt, model, temperature, max_tokens,
|
||||
visibility, configuration, access_group, agent_type, created_at, updated_at
|
||||
"""
|
||||
|
||||
# Prepare configuration with additional kwargs
|
||||
# Ensure list fields are always lists, never None
|
||||
configuration = {
|
||||
"agent_type": agent_type,
|
||||
"capabilities": capabilities or [],
|
||||
"personality_config": kwargs.get("personality_config", {}),
|
||||
"resource_preferences": kwargs.get("resource_preferences", {}),
|
||||
"model_config": kwargs.get("model_config", {}),
|
||||
"tags": kwargs.get("tags") or [],
|
||||
"easy_prompts": kwargs.get("easy_prompts") or [],
|
||||
"selected_dataset_ids": kwargs.get("selected_dataset_ids") or [],
|
||||
**{k: v for k, v in kwargs.items() if k not in ["tags", "easy_prompts", "selected_dataset_ids"]}
|
||||
}
|
||||
|
||||
# Extract model configuration
|
||||
model = kwargs.get("model")
|
||||
if not model:
|
||||
raise ValueError("Model is required for agent creation")
|
||||
temperature = kwargs.get("temperature", 0.7)
|
||||
max_tokens = kwargs.get("max_tokens", 8000) # Increased to match Groq Llama 3.1 capabilities
|
||||
|
||||
# Use access_group as visibility directly (individual, organization only)
|
||||
visibility = access_group.lower()
|
||||
|
||||
# Validate visibility permission based on user role
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
validate_visibility_permission(visibility, user_role)
|
||||
logger.info(f"User {self.user_email} (role: {user_role}) creating agent with visibility: {visibility}")
|
||||
|
||||
# Auto-create category if specified (Issue #215)
|
||||
# This ensures imported agents with unknown categories create those categories
|
||||
# Category is stored in agent_type column
|
||||
category = kwargs.get("category")
|
||||
if category and isinstance(category, str) and category.strip():
|
||||
category_slug = category.strip().lower()
|
||||
try:
|
||||
category_service = CategoryService(self.tenant_domain, user_id, self.user_email)
|
||||
# Pass category_description from CSV import if provided
|
||||
category_description = kwargs.get("category_description")
|
||||
await category_service.get_or_create_category(category_slug, description=category_description)
|
||||
logger.info(f"Ensured category exists: {category}")
|
||||
except Exception as cat_err:
|
||||
logger.warning(f"Failed to ensure category '{category}' exists: {cat_err}")
|
||||
# Continue with agent creation even if category creation fails
|
||||
# Use category as agent_type (they map to the same column)
|
||||
agent_type = category_slug
|
||||
|
||||
agent_data = await pg_client.fetch_one(
|
||||
query,
|
||||
agent_id, name, description, prompt_template,
|
||||
self.tenant_domain, user_id,
|
||||
model, temperature, max_tokens, visibility,
|
||||
json.dumps(configuration), access_group, agent_type
|
||||
)
|
||||
|
||||
if not agent_data:
|
||||
raise RuntimeError("Failed to create agent - no data returned")
|
||||
|
||||
# Convert to dict with proper types
|
||||
# Parse configuration JSON if it's a string
|
||||
config = agent_data["configuration"]
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
elif config is None:
|
||||
config = {}
|
||||
|
||||
result = {
|
||||
"id": str(agent_data["id"]),
|
||||
"name": agent_data["name"],
|
||||
"agent_type": config.get("agent_type", "conversational"),
|
||||
"prompt_template": agent_data["system_prompt"],
|
||||
"description": agent_data["description"],
|
||||
"capabilities": config.get("capabilities", []),
|
||||
"access_group": agent_data["access_group"],
|
||||
"config": config,
|
||||
"model": agent_data["model"],
|
||||
"temperature": float(agent_data["temperature"]) if agent_data["temperature"] is not None else None,
|
||||
"max_tokens": agent_data["max_tokens"],
|
||||
"top_p": config.get("top_p"),
|
||||
"frequency_penalty": config.get("frequency_penalty"),
|
||||
"presence_penalty": config.get("presence_penalty"),
|
||||
"visibility": agent_data["visibility"],
|
||||
"dataset_connection": config.get("dataset_connection"),
|
||||
"selected_dataset_ids": config.get("selected_dataset_ids", []),
|
||||
"max_chunks_per_query": config.get("max_chunks_per_query"),
|
||||
"history_context": config.get("history_context"),
|
||||
"personality_config": config.get("personality_config", {}),
|
||||
"resource_preferences": config.get("resource_preferences", {}),
|
||||
"tags": config.get("tags", []),
|
||||
"is_favorite": config.get("is_favorite", False),
|
||||
"conversation_count": 0,
|
||||
"total_cost_cents": 0,
|
||||
"created_at": agent_data["created_at"].isoformat(),
|
||||
"updated_at": agent_data["updated_at"].isoformat(),
|
||||
"user_id": self.user_id,
|
||||
"tenant_domain": self.tenant_domain
|
||||
}
|
||||
|
||||
logger.info(f"Created agent {agent_id} in PostgreSQL for user {self.user_id}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent: {e}")
|
||||
raise
|
||||
|
||||
async def get_user_agents(
|
||||
self,
|
||||
active_only: bool = True,
|
||||
sort_by: Optional[str] = None,
|
||||
filter_usage: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all agents for the current user using PostgreSQL storage"""
|
||||
try:
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
try:
|
||||
user_id = await self._get_resolved_user_uuid()
|
||||
except ValueError as e:
|
||||
logger.warning(f"User not found for agents list: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
|
||||
return []
|
||||
|
||||
# Get user role to determine access level
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Query agents from PostgreSQL with conversation counts
|
||||
# Admins see ALL agents, others see only their own or organization-level agents
|
||||
if is_admin:
|
||||
where_clause = "WHERE a.tenant_id = (SELECT id FROM tenants WHERE domain = $1)"
|
||||
params = [self.tenant_domain]
|
||||
else:
|
||||
where_clause = "WHERE (a.created_by = $1 OR a.visibility = 'organization') AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2)"
|
||||
params = [user_id, self.tenant_domain]
|
||||
|
||||
# Prepare user_id parameter for per-user usage tracking
|
||||
# Need to add user_id as an additional parameter for usage calculations
|
||||
user_id_param_index = len(params) + 1
|
||||
params.append(user_id)
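            # Example (parameter layout): for a non-admin caller the final parameter list is
            # [user_uuid, tenant_domain, user_uuid] and the usage expressions below bind to $3;
            # for an admin it is [tenant_domain, user_uuid] with the usage expressions on $2.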
|
||||
|
||||
# Per-user usage tracking: Count only conversations for this user
|
||||
query = f"""
|
||||
SELECT
|
||||
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type,
|
||||
u.full_name as created_by_name,
|
||||
COUNT(CASE WHEN c.user_id = ${user_id_param_index}::uuid THEN c.id END) as user_conversation_count,
|
||||
MAX(CASE WHEN c.user_id = ${user_id_param_index}::uuid THEN c.created_at END) as user_last_used_at
|
||||
FROM agents a
|
||||
LEFT JOIN conversations c ON a.id = c.agent_id
|
||||
LEFT JOIN users u ON a.created_by = u.id
|
||||
{where_clause}
|
||||
"""
|
||||
|
||||
if active_only:
|
||||
query += " AND a.is_active = true"
|
||||
|
||||
# Time-based usage filters (per-user)
|
||||
if filter_usage == "used_last_7_days":
|
||||
query += f" AND EXISTS (SELECT 1 FROM conversations c2 WHERE c2.agent_id = a.id AND c2.user_id = ${user_id_param_index}::uuid AND c2.created_at >= NOW() - INTERVAL '7 days')"
|
||||
elif filter_usage == "used_last_30_days":
|
||||
query += f" AND EXISTS (SELECT 1 FROM conversations c2 WHERE c2.agent_id = a.id AND c2.user_id = ${user_id_param_index}::uuid AND c2.created_at >= NOW() - INTERVAL '30 days')"
|
||||
|
||||
query += " GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens, a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at, a.is_active, a.created_by, a.agent_type, u.full_name"
|
||||
|
||||
# User-specific sorting
|
||||
if sort_by == "recent_usage":
|
||||
query += " ORDER BY user_last_used_at DESC NULLS LAST, a.updated_at DESC"
|
||||
elif sort_by == "my_most_used":
|
||||
query += " ORDER BY user_conversation_count DESC, a.updated_at DESC"
|
||||
else:
|
||||
query += " ORDER BY a.updated_at DESC"
|
||||
|
||||
agents_data = await pg_client.execute_query(query, *params)
|
||||
|
||||
# Convert to proper format
|
||||
agents = []
|
||||
for agent in agents_data:
|
||||
                # Debug logging for creator name
                logger.debug(f"Agent '{agent['name']}': created_by={agent.get('created_by')}, created_by_name={agent.get('created_by_name')}")
|
||||
|
||||
# Parse configuration JSON if it's a string
|
||||
config = agent["configuration"]
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
elif config is None:
|
||||
config = {}
|
||||
|
||||
disclaimer_val = config.get("disclaimer")
|
||||
easy_prompts_val = config.get("easy_prompts", [])
|
||||
logger.info(f"get_user_agents - Agent {agent['name']}: disclaimer={disclaimer_val}, easy_prompts={easy_prompts_val}")
|
||||
|
||||
# Determine if user can edit this agent
|
||||
# User can edit if they created it OR if they're admin/developer
|
||||
                # Use the user_role already fetched above (no need to re-query for each agent)
|
||||
is_owner = is_effective_owner(str(agent["created_by"]), str(user_id), user_role)
|
||||
can_edit = can_edit_resource(str(agent["created_by"]), str(user_id), user_role, agent["visibility"])
|
||||
can_delete = can_delete_resource(str(agent["created_by"]), str(user_id), user_role)
|
||||
|
||||
logger.info(f"Agent {agent['name']}: created_by={agent['created_by']}, user_id={user_id}, user_role={user_role}, is_owner={is_owner}, can_edit={can_edit}, can_delete={can_delete}")
|
||||
|
||||
agents.append({
|
||||
"id": str(agent["id"]),
|
||||
"name": agent["name"],
|
||||
"agent_type": agent["agent_type"] or "conversational",
|
||||
"prompt_template": agent["system_prompt"],
|
||||
"description": agent["description"],
|
||||
"capabilities": config.get("capabilities", []),
|
||||
"access_group": agent["access_group"],
|
||||
"config": config,
|
||||
"model": agent["model"],
|
||||
"temperature": float(agent["temperature"]) if agent["temperature"] is not None else None,
|
||||
"max_tokens": agent["max_tokens"],
|
||||
"visibility": agent["visibility"],
|
||||
"dataset_connection": config.get("dataset_connection"),
|
||||
"selected_dataset_ids": config.get("selected_dataset_ids", []),
|
||||
"personality_config": config.get("personality_config", {}),
|
||||
"resource_preferences": config.get("resource_preferences", {}),
|
||||
"tags": config.get("tags", []),
|
||||
"is_favorite": config.get("is_favorite", False),
|
||||
"disclaimer": disclaimer_val,
|
||||
"easy_prompts": easy_prompts_val,
|
||||
"conversation_count": int(agent["user_conversation_count"]) if agent.get("user_conversation_count") is not None else 0,
|
||||
"last_used_at": agent["user_last_used_at"].isoformat() if agent.get("user_last_used_at") else None,
|
||||
"total_cost_cents": 0,
|
||||
"created_at": agent["created_at"].isoformat() if agent["created_at"] else None,
|
||||
"updated_at": agent["updated_at"].isoformat() if agent["updated_at"] else None,
|
||||
"is_active": agent["is_active"],
|
||||
"user_id": agent["created_by"],
|
||||
"created_by_name": agent.get("created_by_name", "Unknown"),
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"can_edit": can_edit,
|
||||
"can_delete": can_delete,
|
||||
"is_owner": is_owner
|
||||
})
|
||||
|
||||
# Fetch team-shared agents and merge with owned agents
|
||||
team_shared = await self.get_team_shared_agents(user_id)
|
||||
|
||||
            # Merge and deduplicate (owned agents take precedence)
            owned_count = len(agents)
            agent_ids_seen = {agent["id"] for agent in agents}
|
||||
for team_agent in team_shared:
|
||||
if team_agent["id"] not in agent_ids_seen:
|
||||
agents.append(team_agent)
|
||||
agent_ids_seen.add(team_agent["id"])
|
||||
|
||||
logger.info(f"Retrieved {len(agents)} total agents ({len(agents) - len(team_shared)} owned + {len(team_shared)} team-shared) from PostgreSQL for user {self.user_id}")
|
||||
return agents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading agents for user {self.user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_team_shared_agents(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get agents shared to teams where user is a member (via junction table).
|
||||
|
||||
Uses the user_accessible_resources view for efficient lookups.
|
||||
|
||||
Returns agents with permission flags:
|
||||
- can_edit: True if user has 'edit' permission for this agent
|
||||
- can_delete: False (only owner can delete)
|
||||
- is_owner: False (team-shared agents)
|
||||
- shared_via_team: True (indicates team sharing)
|
||||
- shared_in_teams: Number of teams this agent is shared with
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Query agents using the efficient user_accessible_resources view
|
||||
# This view joins team_memberships -> team_resource_shares -> agents
|
||||
# Include per-user usage statistics
|
||||
query = """
|
||||
SELECT DISTINCT
|
||||
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type,
|
||||
u.full_name as created_by_name,
|
||||
COUNT(DISTINCT CASE WHEN c.user_id = $1::uuid THEN c.id END) as user_conversation_count,
|
||||
MAX(CASE WHEN c.user_id = $1::uuid THEN c.created_at END) as user_last_used_at,
|
||||
uar.best_permission as user_permission,
|
||||
uar.shared_in_teams,
|
||||
uar.team_ids
|
||||
FROM user_accessible_resources uar
|
||||
INNER JOIN agents a ON a.id = uar.resource_id
|
||||
LEFT JOIN users u ON a.created_by = u.id
|
||||
LEFT JOIN conversations c ON a.id = c.agent_id
|
||||
WHERE uar.user_id = $1::uuid
|
||||
AND uar.resource_type = 'agent'
|
||||
AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
AND a.is_active = true
|
||||
GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature,
|
||||
a.max_tokens, a.visibility, a.configuration, a.access_group, a.created_at,
|
||||
a.updated_at, a.is_active, a.created_by, a.agent_type, u.full_name,
|
||||
uar.best_permission, uar.shared_in_teams, uar.team_ids
|
||||
ORDER BY a.updated_at DESC
|
||||
"""
|
||||
|
||||
agents_data = await pg_client.execute_query(query, user_id, self.tenant_domain)
|
||||
|
||||
# Format agents with team sharing metadata
|
||||
agents = []
|
||||
for agent in agents_data:
|
||||
# Parse configuration JSON
|
||||
config = agent["configuration"]
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
elif config is None:
|
||||
config = {}
|
||||
|
||||
# Get permission from view (will be "read" or "edit")
|
||||
user_permission = agent.get("user_permission")
|
||||
can_edit = user_permission == "edit"
|
||||
|
||||
# Get team sharing metadata
|
||||
shared_in_teams = agent.get("shared_in_teams", 0)
|
||||
team_ids = agent.get("team_ids", [])
|
||||
|
||||
agents.append({
|
||||
"id": str(agent["id"]),
|
||||
"name": agent["name"],
|
||||
"agent_type": agent["agent_type"] or "conversational",
|
||||
"prompt_template": agent["system_prompt"],
|
||||
"description": agent["description"],
|
||||
"capabilities": config.get("capabilities", []),
|
||||
"access_group": agent["access_group"],
|
||||
"config": config,
|
||||
"model": agent["model"],
|
||||
"temperature": float(agent["temperature"]) if agent["temperature"] is not None else None,
|
||||
"max_tokens": agent["max_tokens"],
|
||||
"visibility": agent["visibility"],
|
||||
"dataset_connection": config.get("dataset_connection"),
|
||||
"selected_dataset_ids": config.get("selected_dataset_ids", []),
|
||||
"personality_config": config.get("personality_config", {}),
|
||||
"resource_preferences": config.get("resource_preferences", {}),
|
||||
"tags": config.get("tags", []),
|
||||
"is_favorite": config.get("is_favorite", False),
|
||||
"disclaimer": config.get("disclaimer"),
|
||||
"easy_prompts": config.get("easy_prompts", []),
|
||||
"conversation_count": int(agent["user_conversation_count"]) if agent.get("user_conversation_count") else 0,
|
||||
"last_used_at": agent["user_last_used_at"].isoformat() if agent.get("user_last_used_at") else None,
|
||||
"total_cost_cents": 0,
|
||||
"created_at": agent["created_at"].isoformat() if agent["created_at"] else None,
|
||||
"updated_at": agent["updated_at"].isoformat() if agent["updated_at"] else None,
|
||||
"is_active": agent["is_active"],
|
||||
"user_id": agent["created_by"],
|
||||
"created_by_name": agent.get("created_by_name", "Unknown"),
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"can_edit": can_edit,
|
||||
"can_delete": False, # Only owner can delete
|
||||
"is_owner": False, # Team-shared agents
|
||||
"shared_via_team": True,
|
||||
"shared_in_teams": shared_in_teams,
|
||||
"team_ids": [str(tid) for tid in team_ids] if team_ids else [],
|
||||
"team_permission": user_permission
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(agents)} team-shared agents for user {user_id}")
|
||||
return agents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching team-shared agents for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a specific agent by ID using PostgreSQL"""
|
||||
try:
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
try:
|
||||
user_id = await self._get_resolved_user_uuid()
|
||||
except ValueError as e:
|
||||
logger.warning(f"User not found: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
|
||||
return None
|
||||
|
||||
# Check if user is admin - admins can see all agents
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Query the agent first
|
||||
query = """
|
||||
SELECT
|
||||
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type,
|
||||
COUNT(c.id) as conversation_count
|
||||
FROM agents a
|
||||
LEFT JOIN conversations c ON a.id = c.agent_id
|
||||
WHERE a.id = $1 AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
agent_data = await pg_client.fetch_one(query, agent_id, self.tenant_domain)
|
||||
logger.info(f"Agent query result: {agent_data is not None}")
|
||||
|
||||
# If agent doesn't exist, return None
|
||||
if not agent_data:
|
||||
return None
|
||||
|
||||
# Check access: admin, owner, organization, or team-based
|
||||
if not is_admin:
|
||||
is_owner = str(agent_data["created_by"]) == str(user_id)
|
||||
is_org_wide = agent_data["visibility"] == "organization"
|
||||
|
||||
# Check team-based access if not owner or org-wide
|
||||
if not is_owner and not is_org_wide:
|
||||
# Import TeamService here to avoid circular dependency
|
||||
from app.services.team_service import TeamService
|
||||
team_service = TeamService(self.tenant_domain, str(user_id), self.user_email)
|
||||
|
||||
has_team_access = await team_service.check_user_resource_permission(
|
||||
user_id=str(user_id),
|
||||
resource_type="agent",
|
||||
resource_id=agent_id,
|
||||
required_permission="read"
|
||||
)
|
||||
|
||||
if not has_team_access:
|
||||
logger.warning(f"User {user_id} denied access to agent {agent_id}")
|
||||
return None
|
||||
|
||||
logger.info(f"User {user_id} has team-based access to agent {agent_id}")
|
||||
|
||||
if agent_data:
|
||||
# Parse configuration JSON if it's a string
|
||||
config = agent_data["configuration"]
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
elif config is None:
|
||||
config = {}
|
||||
|
||||
# Convert to proper format
|
||||
logger.info(f"Config disclaimer: {config.get('disclaimer')}, easy_prompts: {config.get('easy_prompts')}")
|
||||
|
||||
# Compute is_owner for export permission checks
|
||||
is_owner = str(agent_data["created_by"]) == str(user_id)
|
||||
|
||||
result = {
|
||||
"id": str(agent_data["id"]),
|
||||
"name": agent_data["name"],
|
||||
"agent_type": agent_data["agent_type"] or "conversational",
|
||||
"prompt_template": agent_data["system_prompt"],
|
||||
"description": agent_data["description"],
|
||||
"capabilities": config.get("capabilities", []),
|
||||
"access_group": agent_data["access_group"],
|
||||
"config": config,
|
||||
"model": agent_data["model"],
|
||||
"temperature": float(agent_data["temperature"]) if agent_data["temperature"] is not None else None,
|
||||
"max_tokens": agent_data["max_tokens"],
|
||||
"visibility": agent_data["visibility"],
|
||||
"dataset_connection": config.get("dataset_connection"),
|
||||
"selected_dataset_ids": config.get("selected_dataset_ids", []),
|
||||
"personality_config": config.get("personality_config", {}),
|
||||
"resource_preferences": config.get("resource_preferences", {}),
|
||||
"tags": config.get("tags", []),
|
||||
"is_favorite": config.get("is_favorite", False),
|
||||
"disclaimer": config.get("disclaimer"),
|
||||
"easy_prompts": config.get("easy_prompts", []),
|
||||
"conversation_count": int(agent_data["conversation_count"]) if agent_data.get("conversation_count") is not None else 0,
|
||||
"total_cost_cents": 0,
|
||||
"created_at": agent_data["created_at"].isoformat() if agent_data["created_at"] else None,
|
||||
"updated_at": agent_data["updated_at"].isoformat() if agent_data["updated_at"] else None,
|
||||
"is_active": agent_data["is_active"],
|
||||
"created_by": agent_data["created_by"], # Keep DB field
|
||||
"user_id": agent_data["created_by"], # Alias for compatibility
|
||||
"is_owner": is_owner, # Computed ownership for export/edit permissions
|
||||
"tenant_domain": self.tenant_domain
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading agent {agent_id}: {e}")
|
||||
return None
|
||||
|
||||
async def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
updates: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Update an agent's configuration using PostgreSQL with permission checks"""
|
||||
try:
|
||||
logger.info(f"Processing updates for agent {agent_id}: {updates}")
|
||||
|
||||
# Log which fields will be processed
|
||||
logger.info(f"Update fields being processed: {list(updates.keys())}")
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get user role for permission checks
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
|
||||
# If updating visibility, validate permission
|
||||
if "visibility" in updates:
|
||||
validate_visibility_permission(updates["visibility"], user_role)
|
||||
logger.info(f"User {self.user_email} (role: {user_role}) updating agent visibility to: {updates['visibility']}")
|
||||
|
||||
# Build dynamic UPDATE query based on provided updates
|
||||
set_clauses = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
# Collect all configuration updates in a single object
|
||||
config_updates = {}
|
||||
|
||||
# Handle each update field mapping to correct column names
|
||||
for field, value in updates.items():
|
||||
if field in ["name", "description", "access_group"]:
|
||||
set_clauses.append(f"{field} = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field == "prompt_template":
|
||||
set_clauses.append(f"system_prompt = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field in ["model", "temperature", "max_tokens", "visibility", "agent_type"]:
|
||||
set_clauses.append(f"{field} = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field == "is_active":
|
||||
set_clauses.append(f"is_active = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field in ["config", "configuration", "personality_config", "resource_preferences", "tags", "is_favorite",
|
||||
"dataset_connection", "selected_dataset_ids", "disclaimer", "easy_prompts"]:
|
||||
# Collect configuration updates
|
||||
if field in ["config", "configuration"]:
|
||||
config_updates.update(value if isinstance(value, dict) else {})
|
||||
else:
|
||||
config_updates[field] = value
|
||||
|
||||
# Apply configuration updates as a single operation
|
||||
if config_updates:
|
||||
set_clauses.append(f"configuration = configuration || ${param_idx}::jsonb")
|
||||
params.append(json.dumps(config_updates))
|
||||
param_idx += 1
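            # Example (illustrative update): {"tags": ["sales"], "is_favorite": True} produces
            # configuration = configuration || '{"tags": ["sales"], "is_favorite": true}'::jsonb,
            # a shallow JSONB merge in which top-level keys from the update overwrite existing ones.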
|
||||
|
||||
if not set_clauses:
|
||||
logger.warning(f"No valid update fields provided for agent {agent_id}")
|
||||
return await self.get_agent(agent_id)
|
||||
|
||||
# Add updated_at timestamp
|
||||
set_clauses.append(f"updated_at = NOW()")
|
||||
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
try:
|
||||
user_id = await self._get_resolved_user_uuid()
|
||||
except ValueError as e:
|
||||
logger.warning(f"User not found for update: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
|
||||
return None
|
||||
|
||||
# Check if user is admin - admins can update any agent
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Build final query - admins can update any agent in tenant, others only their own
|
||||
if is_admin:
|
||||
query = f"""
|
||||
UPDATE agents
|
||||
SET {', '.join(set_clauses)}
|
||||
WHERE id = ${param_idx}
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = ${param_idx + 1})
|
||||
RETURNING id
|
||||
"""
|
||||
params.extend([agent_id, self.tenant_domain])
|
||||
else:
|
||||
query = f"""
|
||||
UPDATE agents
|
||||
SET {', '.join(set_clauses)}
|
||||
WHERE id = ${param_idx}
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = ${param_idx + 1})
|
||||
AND created_by = ${param_idx + 2}
|
||||
RETURNING id
|
||||
"""
|
||||
params.extend([agent_id, self.tenant_domain, user_id])
|
||||
|
||||
# Execute update
|
||||
logger.info(f"Executing update query: {query}")
|
||||
logger.info(f"Query parameters: {params}")
|
||||
updated_id = await pg_client.fetch_scalar(query, *params)
|
||||
logger.info(f"Update result: {updated_id}")
|
||||
|
||||
if updated_id:
|
||||
# Get updated agent data
|
||||
updated_agent = await self.get_agent(agent_id)
|
||||
|
||||
logger.info(f"Updated agent {agent_id} in PostgreSQL")
|
||||
return updated_agent
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent {agent_id}: {e}")
|
||||
return None
|
||||
|
||||
async def delete_agent(self, agent_id: str) -> bool:
|
||||
"""Soft delete an agent using PostgreSQL"""
|
||||
try:
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get user role to check if admin
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Soft delete in PostgreSQL - admins can delete any agent, others only their own
|
||||
if is_admin:
|
||||
query = """
|
||||
UPDATE agents
|
||||
SET is_active = false, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
RETURNING id
|
||||
"""
|
||||
deleted_id = await pg_client.fetch_scalar(query, agent_id, self.tenant_domain)
|
||||
else:
|
||||
query = """
|
||||
UPDATE agents
|
||||
SET is_active = false, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
AND created_by = (SELECT id FROM users WHERE email = $3)
|
||||
RETURNING id
|
||||
"""
|
||||
deleted_id = await pg_client.fetch_scalar(query, agent_id, self.tenant_domain, self.user_email or self.user_id)
|
||||
|
||||
if deleted_id:
|
||||
logger.info(f"Deleted agent {agent_id} from PostgreSQL")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting agent {agent_id}: {e}")
|
||||
return False
|
||||
|
||||
async def check_access_permission(self, agent_id: str, requesting_user_id: str, access_type: str = "read") -> bool:
|
||||
"""
|
||||
Check if user has access to agent (via ownership, organization, or team).
|
||||
|
||||
Args:
|
||||
agent_id: UUID of the agent
|
||||
requesting_user_id: UUID of the user requesting access
|
||||
access_type: 'read' or 'edit' (default: 'read')
|
||||
|
||||
Returns:
|
||||
True if user has required access
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if admin/developer
|
||||
user_role = await get_user_role(pg_client, requesting_user_id, self.tenant_domain)
|
||||
if user_role in ["admin", "developer"]:
|
||||
return True
|
||||
|
||||
# Get agent to check ownership and visibility
|
||||
query = """
|
||||
SELECT created_by, visibility
|
||||
FROM agents
|
||||
WHERE id = $1 AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
"""
|
||||
agent_data = await pg_client.fetch_one(query, agent_id, self.tenant_domain)
|
||||
|
||||
if not agent_data:
|
||||
return False
|
||||
|
||||
owner_id = str(agent_data["created_by"])
|
||||
visibility = agent_data["visibility"]
|
||||
|
||||
# Owner has full access
|
||||
if requesting_user_id == owner_id:
|
||||
return True
|
||||
|
||||
# Organization-wide resources are accessible to all in tenant
|
||||
if visibility == "organization":
|
||||
return True
|
||||
|
||||
# Check team-based access
|
||||
from app.services.team_service import TeamService
|
||||
team_service = TeamService(self.tenant_domain, requesting_user_id, requesting_user_id)
|
||||
|
||||
return await team_service.check_user_resource_permission(
|
||||
user_id=requesting_user_id,
|
||||
resource_type="agent",
|
||||
resource_id=agent_id,
|
||||
required_permission=access_type
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking access permission for agent {agent_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _check_team_membership(self, user_id: str, team_members: List[str]) -> bool:
|
||||
"""Check if user is in the team members list"""
|
||||
return user_id in team_members
|
||||
|
||||
async def _check_same_tenant(self, user_id: str) -> bool:
|
||||
"""Check if requesting user is in the same tenant through PostgreSQL"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Check if user exists in same tenant
|
||||
query = """
|
||||
SELECT COUNT(*) as count
|
||||
FROM users
|
||||
WHERE id = $1 AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
"""
|
||||
|
||||
result = await pg_client.fetch_one(query, user_id, self.tenant_domain)
|
||||
return result and result["count"] > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check tenant membership for user {user_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_agent_conversation_history(self, agent_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get conversation history for an agent (file-based)"""
|
||||
conversations_path = Path(f"/data/{self.tenant_domain}/users/{self.user_id}/conversations")
|
||||
conversations_path.mkdir(parents=True, exist_ok=True, mode=0o700)
|
||||
|
||||
conversations = []
|
||||
try:
|
||||
for conv_file in conversations_path.glob("*.json"):
|
||||
with open(conv_file, 'r') as f:
|
||||
conv_data = json.load(f)
|
||||
if conv_data.get("agent_id") == agent_id:
|
||||
conversations.append(conv_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading conversations for agent {agent_id}: {e}")
|
||||
|
||||
conversations.sort(key=lambda x: x.get("updated_at", ""), reverse=True)
|
||||
return conversations
|
||||
493
apps/tenant-backend/app/services/assistant_builder.py
Normal file
493
apps/tenant-backend/app/services/assistant_builder.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""
|
||||
Assistant Builder Service for GT 2.0
|
||||
|
||||
Manages assistant creation, deployment, and lifecycle.
|
||||
Integrates with template library and file-based storage.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import stat
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from app.models.assistant_template import (
|
||||
AssistantTemplate, AssistantInstance, AssistantBuilder,
|
||||
AssistantType, PersonalityConfig, ResourcePreferences, MemorySettings,
|
||||
AssistantTemplateLibrary, BUILTIN_TEMPLATES
|
||||
)
|
||||
from app.models.access_group import AccessGroup
|
||||
from app.core.security import verify_capability_token
|
||||
from app.services.access_controller import AccessController
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssistantBuilderService:
|
||||
"""
|
||||
Service for building and managing assistants
|
||||
Handles both template-based and custom assistant creation
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, resource_cluster_url: str = "http://resource-cluster:8004"):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.base_path = Path(f"/data/{tenant_domain}/assistants")
|
||||
self.template_library = AssistantTemplateLibrary(resource_cluster_url)
|
||||
self.access_controller = AccessController(tenant_domain)
|
||||
self._ensure_directories()
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure assistant directories exist with proper permissions"""
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(self.base_path, stat.S_IRWXU) # 700
|
||||
|
||||
# Create subdirectories
|
||||
for subdir in ["templates", "instances", "shared"]:
|
||||
path = self.base_path / subdir
|
||||
path.mkdir(exist_ok=True)
|
||||
os.chmod(path, stat.S_IRWXU) # 700
|
||||
|
||||
async def create_from_template(
|
||||
self,
|
||||
template_id: str,
|
||||
user_id: str,
|
||||
instance_name: str,
|
||||
customizations: Optional[Dict[str, Any]] = None,
|
||||
capability_token: str = None
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Create assistant instance from template
|
||||
|
||||
Args:
|
||||
template_id: Template to use
|
||||
user_id: User creating the assistant
|
||||
instance_name: Name for the instance
|
||||
customizations: Optional customizations
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Created assistant instance
|
||||
"""
|
||||
# Verify capability token
|
||||
if capability_token:
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Deploy from template
|
||||
instance = await self.template_library.deploy_template(
|
||||
template_id=template_id,
|
||||
user_id=user_id,
|
||||
instance_name=instance_name,
|
||||
tenant_domain=self.tenant_domain,
|
||||
customizations=customizations
|
||||
)
|
||||
|
||||
# Create file storage
|
||||
await self._create_assistant_files(instance)
|
||||
|
||||
# Save to database (would be SQLite in production)
|
||||
await self._save_assistant(instance)
|
||||
|
||||
logger.info(f"Created assistant {instance.id} from template {template_id} for {user_id}")
|
||||
|
||||
return instance
|
||||
|
||||
async def create_custom(
|
||||
self,
|
||||
builder_config: AssistantBuilder,
|
||||
user_id: str,
|
||||
capability_token: str = None
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Create custom assistant from builder configuration
|
||||
|
||||
Args:
|
||||
builder_config: Custom assistant configuration
|
||||
user_id: User creating the assistant
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Created assistant instance
|
||||
"""
|
||||
# Verify capability token
|
||||
if capability_token:
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Check if user has required capabilities
|
||||
user_capabilities = token_data.get("capabilities", [])
|
||||
for required_cap in builder_config.requested_capabilities:
|
||||
if not any(required_cap in cap.get("resource", "") for cap in user_capabilities):
|
||||
raise PermissionError(f"Missing capability: {required_cap}")
|
||||
|
||||
# Build instance
|
||||
instance = builder_config.build_instance(user_id, self.tenant_domain)
|
||||
|
||||
# Create file storage
|
||||
await self._create_assistant_files(instance)
|
||||
|
||||
# Save to database
|
||||
await self._save_assistant(instance)
|
||||
|
||||
logger.info(f"Created custom assistant {instance.id} for {user_id}")
|
||||
|
||||
return instance
|
||||
|
||||
async def get_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str
|
||||
) -> Optional[AssistantInstance]:
|
||||
"""
|
||||
Get assistant instance by ID
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant ID
|
||||
user_id: User requesting the assistant
|
||||
|
||||
Returns:
|
||||
Assistant instance if found and accessible
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Check access permission
|
||||
allowed, _ = await self.access_controller.check_permission(
|
||||
user_id, instance, "read"
|
||||
)
|
||||
if not allowed:
|
||||
return None
|
||||
|
||||
return instance
|
||||
|
||||
async def list_user_assistants(
|
||||
self,
|
||||
user_id: str,
|
||||
include_shared: bool = True
|
||||
) -> List[AssistantInstance]:
|
||||
"""
|
||||
List all assistants accessible to user
|
||||
|
||||
Args:
|
||||
user_id: User to list assistants for
|
||||
include_shared: Include team/org shared assistants
|
||||
|
||||
Returns:
|
||||
List of accessible assistants
|
||||
"""
|
||||
assistants = []
|
||||
|
||||
# Get owned assistants
|
||||
owned = await self._get_owned_assistants(user_id)
|
||||
assistants.extend(owned)
|
||||
|
||||
# Get shared assistants if requested
|
||||
if include_shared:
|
||||
shared = await self._get_shared_assistants(user_id)
|
||||
assistants.extend(shared)
|
||||
|
||||
return assistants
|
||||
|
||||
async def update_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str,
|
||||
updates: Dict[str, Any]
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Update assistant configuration
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant to update
|
||||
user_id: User requesting update
|
||||
updates: Configuration updates
|
||||
|
||||
Returns:
|
||||
Updated assistant instance
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Assistant not found: {assistant_id}")
|
||||
|
||||
# Check permission
|
||||
if instance.owner_id != user_id:
|
||||
raise PermissionError("Only owner can update assistant")
|
||||
|
||||
# Apply updates
|
||||
if "personality" in updates:
|
||||
instance.personality_config = PersonalityConfig(**updates["personality"])
|
||||
if "resources" in updates:
|
||||
instance.resource_preferences = ResourcePreferences(**updates["resources"])
|
||||
if "memory" in updates:
|
||||
instance.memory_settings = MemorySettings(**updates["memory"])
|
||||
if "system_prompt" in updates:
|
||||
instance.system_prompt = updates["system_prompt"]
|
||||
|
||||
instance.updated_at = datetime.utcnow()
|
||||
|
||||
# Save changes
|
||||
await self._save_assistant(instance)
|
||||
await self._update_assistant_files(instance)
|
||||
|
||||
logger.info(f"Updated assistant {assistant_id} by {user_id}")
|
||||
|
||||
return instance
|
||||
|
||||
async def share_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str,
|
||||
access_group: AccessGroup,
|
||||
team_members: Optional[List[str]] = None
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Share assistant with team or organization
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant to share
|
||||
user_id: User sharing (must be owner)
|
||||
access_group: New access level
|
||||
team_members: Team members if team access
|
||||
|
||||
Returns:
|
||||
Updated assistant instance
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Assistant not found: {assistant_id}")
|
||||
|
||||
# Check ownership
|
||||
if instance.owner_id != user_id:
|
||||
raise PermissionError("Only owner can share assistant")
|
||||
|
||||
# Update access
|
||||
instance.access_group = access_group
|
||||
if access_group == AccessGroup.TEAM:
|
||||
instance.team_members = team_members or []
|
||||
else:
|
||||
instance.team_members = []
|
||||
|
||||
instance.updated_at = datetime.utcnow()
|
||||
|
||||
# Save changes
|
||||
await self._save_assistant(instance)
|
||||
|
||||
logger.info(f"Shared assistant {assistant_id} with {access_group.value} by {user_id}")
|
||||
|
||||
return instance
|
||||
|
||||
async def delete_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete assistant and its files
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant to delete
|
||||
user_id: User requesting deletion
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
return False
|
||||
|
||||
# Check ownership
|
||||
if instance.owner_id != user_id:
|
||||
raise PermissionError("Only owner can delete assistant")
|
||||
|
||||
# Delete files
|
||||
await self._delete_assistant_files(instance)
|
||||
|
||||
# Delete from database
|
||||
await self._delete_assistant_record(assistant_id)
|
||||
|
||||
logger.info(f"Deleted assistant {assistant_id} by {user_id}")
|
||||
|
||||
return True
|
||||
|
||||
async def get_assistant_statistics(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get usage statistics for assistant
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant ID
|
||||
user_id: User requesting stats
|
||||
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self.get_assistant(assistant_id, user_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Assistant not found or not accessible: {assistant_id}")
|
||||
|
||||
return {
|
||||
"assistant_id": assistant_id,
|
||||
"name": instance.name,
|
||||
"created_at": instance.created_at.isoformat(),
|
||||
"last_used": instance.last_used.isoformat() if instance.last_used else None,
|
||||
"conversation_count": instance.conversation_count,
|
||||
"total_messages": instance.total_messages,
|
||||
"total_tokens_used": instance.total_tokens_used,
|
||||
"access_group": instance.access_group.value,
|
||||
"team_members_count": len(instance.team_members),
|
||||
"linked_datasets_count": len(instance.linked_datasets),
|
||||
"linked_tools_count": len(instance.linked_tools)
|
||||
}
|
||||
|
||||
async def _create_assistant_files(self, instance: AssistantInstance):
|
||||
"""Create file structure for assistant"""
|
||||
# Get file paths
|
||||
file_structure = instance.get_file_structure()
|
||||
|
||||
# Create directories
|
||||
for key, path in file_structure.items():
|
||||
if key in ["memory", "resources"]:
|
||||
# These are directories
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(Path(path), stat.S_IRWXU) # 700
|
||||
else:
|
||||
# These are files
|
||||
parent = Path(path).parent
|
||||
parent.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(parent, stat.S_IRWXU) # 700
|
||||
|
||||
# Save configuration
|
||||
config_path = Path(file_structure["config"])
|
||||
config_data = {
|
||||
"id": instance.id,
|
||||
"name": instance.name,
|
||||
"template_id": instance.template_id,
|
||||
"personality": instance.personality_config.model_dump(),
|
||||
"resources": instance.resource_preferences.model_dump(),
|
||||
"memory": instance.memory_settings.model_dump(),
|
||||
"created_at": instance.created_at.isoformat(),
|
||||
"updated_at": instance.updated_at.isoformat()
|
||||
}
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
os.chmod(config_path, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
# Save prompt
|
||||
prompt_path = Path(file_structure["prompt"])
|
||||
with open(prompt_path, 'w') as f:
|
||||
f.write(instance.system_prompt)
|
||||
os.chmod(prompt_path, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
# Save capabilities
|
||||
capabilities_path = Path(file_structure["capabilities"])
|
||||
with open(capabilities_path, 'w') as f:
|
||||
json.dump(instance.capabilities, f, indent=2)
|
||||
os.chmod(capabilities_path, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
# Update instance with file paths
|
||||
instance.config_file_path = str(config_path)
|
||||
instance.memory_file_path = str(Path(file_structure["memory"]))
|
||||
|
||||
async def _update_assistant_files(self, instance: AssistantInstance):
|
||||
"""Update assistant files with current configuration"""
|
||||
if instance.config_file_path:
|
||||
config_data = {
|
||||
"id": instance.id,
|
||||
"name": instance.name,
|
||||
"template_id": instance.template_id,
|
||||
"personality": instance.personality_config.model_dump(),
|
||||
"resources": instance.resource_preferences.model_dump(),
|
||||
"memory": instance.memory_settings.model_dump(),
|
||||
"created_at": instance.created_at.isoformat(),
|
||||
"updated_at": instance.updated_at.isoformat()
|
||||
}
|
||||
|
||||
with open(instance.config_file_path, 'w') as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
async def _delete_assistant_files(self, instance: AssistantInstance):
|
||||
"""Delete assistant file structure"""
|
||||
file_structure = instance.get_file_structure()
|
||||
base_dir = Path(file_structure["config"]).parent
|
||||
|
||||
if base_dir.exists():
|
||||
import shutil
|
||||
shutil.rmtree(base_dir)
|
||||
logger.info(f"Deleted assistant files at {base_dir}")
|
||||
|
||||
async def _save_assistant(self, instance: AssistantInstance):
|
||||
"""Save assistant to database (SQLite in production)"""
|
||||
# This would save to SQLite database
|
||||
# For now, we'll save to a JSON file as placeholder
|
||||
db_file = self.base_path / "instances" / f"{instance.id}.json"
|
||||
with open(db_file, 'w') as f:
|
||||
json.dump(instance.model_dump(mode='json'), f, indent=2, default=str)
|
||||
os.chmod(db_file, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
async def _load_assistant(self, assistant_id: str) -> Optional[AssistantInstance]:
|
||||
"""Load assistant from database"""
|
||||
db_file = self.base_path / "instances" / f"{assistant_id}.json"
|
||||
if not db_file.exists():
|
||||
return None
|
||||
|
||||
with open(db_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Convert datetime strings back to datetime objects
|
||||
for field in ['created_at', 'updated_at', 'last_used']:
|
||||
if field in data and data[field]:
|
||||
data[field] = datetime.fromisoformat(data[field])
|
||||
|
||||
return AssistantInstance(**data)
|
||||
|
||||
async def _delete_assistant_record(self, assistant_id: str):
|
||||
"""Delete assistant from database"""
|
||||
db_file = self.base_path / "instances" / f"{assistant_id}.json"
|
||||
if db_file.exists():
|
||||
db_file.unlink()
|
||||
|
||||
async def _get_owned_assistants(self, user_id: str) -> List[AssistantInstance]:
|
||||
"""Get assistants owned by user"""
|
||||
assistants = []
|
||||
instances_dir = self.base_path / "instances"
|
||||
|
||||
if instances_dir.exists():
|
||||
for file in instances_dir.glob("*.json"):
|
||||
instance = await self._load_assistant(file.stem)
|
||||
if instance and instance.owner_id == user_id:
|
||||
assistants.append(instance)
|
||||
|
||||
return assistants
|
||||
|
||||
async def _get_shared_assistants(self, user_id: str) -> List[AssistantInstance]:
|
||||
"""Get assistants shared with user"""
|
||||
assistants = []
|
||||
instances_dir = self.base_path / "instances"
|
||||
|
||||
if instances_dir.exists():
|
||||
for file in instances_dir.glob("*.json"):
|
||||
instance = await self._load_assistant(file.stem)
|
||||
if instance and instance.owner_id != user_id:
|
||||
# Check if user has access
|
||||
allowed, _ = await self.access_controller.check_permission(
|
||||
user_id, instance, "read"
|
||||
)
|
||||
if allowed:
|
||||
assistants.append(instance)
|
||||
|
||||
return assistants
|
||||
599
apps/tenant-backend/app/services/assistant_manager.py
Normal file
599
apps/tenant-backend/app/services/assistant_manager.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""
|
||||
AssistantManager Service for GT 2.0 Tenant Backend
|
||||
|
||||
File-based agent lifecycle management with perfect tenant isolation.
|
||||
Implements the core Agent System specification from CLAUDE.md.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func, desc
|
||||
from sqlalchemy.orm import selectinload
|
||||
import logging
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssistantManager:
|
||||
"""File-based agent lifecycle management"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.settings = get_settings()
|
||||
|
||||
async def create_from_template(self, template_id: str, config: Dict[str, Any], user_identifier: str) -> str:
|
||||
"""Create agent from template or custom config"""
|
||||
try:
|
||||
# Get template configuration
|
||||
template_config = await self._load_template_config(template_id)
|
||||
|
||||
# Merge template config with user overrides
|
||||
merged_config = {**template_config, **config}
|
||||
|
||||
# Create agent record
|
||||
agent = Agent(
|
||||
name=merged_config.get("name", f"Agent from {template_id}"),
|
||||
description=merged_config.get("description", f"Created from template: {template_id}"),
|
||||
template_id=template_id,
|
||||
created_by=user_identifier,
|
||||
user_name=merged_config.get("user_name"),
|
||||
personality_config=merged_config.get("personality_config", {}),
|
||||
resource_preferences=merged_config.get("resource_preferences", {}),
|
||||
memory_settings=merged_config.get("memory_settings", {}),
|
||||
tags=merged_config.get("tags", []),
|
||||
)
|
||||
|
||||
# Initialize with placeholder paths first
|
||||
agent.config_file_path = "placeholder"
|
||||
agent.prompt_file_path = "placeholder"
|
||||
agent.capabilities_file_path = "placeholder"
|
||||
|
||||
# Save to database first to get ID and UUID
|
||||
self.db.add(agent)
|
||||
await self.db.flush() # Flush to get the generated UUID without committing
|
||||
|
||||
# Now we can initialize proper file paths with the UUID
|
||||
agent.initialize_file_paths()
|
||||
|
||||
# Create file system structure
|
||||
await self._setup_assistant_files(agent, merged_config)
|
||||
|
||||
# Commit all changes
|
||||
await self.db.commit()
|
||||
await self.db.refresh(agent)
|
||||
|
||||
logger.info(
|
||||
f"Created agent from template",
|
||||
extra={
|
||||
"agent_id": agent.id,
|
||||
"assistant_uuid": agent.uuid,
|
||||
"template_id": template_id,
|
||||
"created_by": user_identifier,
|
||||
}
|
||||
)
|
||||
|
||||
return str(agent.uuid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent from template: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def create_custom_assistant(self, config: Dict[str, Any], user_identifier: str) -> str:
|
||||
"""Create custom agent without template"""
|
||||
try:
|
||||
# Validate required fields
|
||||
if not config.get("name"):
|
||||
raise ValueError("Agent name is required")
|
||||
|
||||
# Create agent record
|
||||
agent = Agent(
|
||||
name=config["name"],
|
||||
description=config.get("description", "Custom AI agent"),
|
||||
template_id=None, # No template used
|
||||
created_by=user_identifier,
|
||||
user_name=config.get("user_name"),
|
||||
personality_config=config.get("personality_config", {}),
|
||||
resource_preferences=config.get("resource_preferences", {}),
|
||||
memory_settings=config.get("memory_settings", {}),
|
||||
tags=config.get("tags", []),
|
||||
)
|
||||
|
||||
# Initialize with placeholder paths first
|
||||
agent.config_file_path = "placeholder"
|
||||
agent.prompt_file_path = "placeholder"
|
||||
agent.capabilities_file_path = "placeholder"
|
||||
|
||||
# Save to database first to get ID and UUID
|
||||
self.db.add(agent)
|
||||
await self.db.flush() # Flush to get the generated UUID without committing
|
||||
|
||||
# Now we can initialize proper file paths with the UUID
|
||||
agent.initialize_file_paths()
|
||||
|
||||
# Create file system structure
|
||||
await self._setup_assistant_files(agent, config)
|
||||
|
||||
# Commit all changes
|
||||
await self.db.commit()
|
||||
await self.db.refresh(agent)
|
||||
|
||||
logger.info(
|
||||
f"Created custom agent",
|
||||
extra={
|
||||
"agent_id": agent.id,
|
||||
"assistant_uuid": agent.uuid,
|
||||
"created_by": user_identifier,
|
||||
}
|
||||
)
|
||||
|
||||
return str(agent.uuid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create custom agent: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_assistant_config(self, assistant_uuid: str, user_identifier: str) -> Dict[str, Any]:
|
||||
"""Get complete agent configuration including file-based data"""
|
||||
try:
|
||||
# Get agent from database
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
# Load complete configuration
|
||||
return agent.get_full_configuration()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get agent config: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def list_user_assistants(
|
||||
self,
|
||||
user_identifier: str,
|
||||
include_archived: bool = False,
|
||||
template_id: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List user's agents with filtering options"""
|
||||
try:
|
||||
# Build base query
|
||||
query = select(Agent).where(Agent.created_by == user_identifier)
|
||||
|
||||
# Apply filters
|
||||
if not include_archived:
|
||||
query = query.where(Agent.is_active == True)
|
||||
|
||||
if template_id:
|
||||
query = query.where(Agent.template_id == template_id)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.where(
|
||||
or_(
|
||||
Agent.name.ilike(search_term),
|
||||
Agent.description.ilike(search_term)
|
||||
)
|
||||
)
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(desc(Agent.last_used_at), desc(Agent.created_at))
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
agents = result.scalars().all()
|
||||
|
||||
return [agent.to_dict() for agent in agents]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list user agents: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def count_user_assistants(
|
||||
self,
|
||||
user_identifier: str,
|
||||
include_archived: bool = False,
|
||||
template_id: Optional[str] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""Count user's agents matching criteria"""
|
||||
try:
|
||||
# Build base query
|
||||
query = select(func.count(Agent.id)).where(Agent.created_by == user_identifier)
|
||||
|
||||
# Apply filters
|
||||
if not include_archived:
|
||||
query = query.where(Agent.is_active == True)
|
||||
|
||||
if template_id:
|
||||
query = query.where(Agent.template_id == template_id)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.where(
|
||||
or_(
|
||||
Agent.name.ilike(search_term),
|
||||
Agent.description.ilike(search_term)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to count user agents: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_assistant(self, agent_id: str, updates: Dict[str, Any], user_identifier: str) -> bool:
|
||||
"""Update agent configuration (renamed from update_configuration)"""
|
||||
return await self.update_configuration(agent_id, updates, user_identifier)
|
||||
|
||||
async def update_configuration(self, assistant_uuid: str, updates: Dict[str, Any], user_identifier: str) -> bool:
|
||||
"""Update agent configuration"""
|
||||
try:
|
||||
# Get agent
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
# Update database fields
|
||||
if "name" in updates:
|
||||
agent.name = updates["name"]
|
||||
if "description" in updates:
|
||||
agent.description = updates["description"]
|
||||
if "personality_config" in updates:
|
||||
agent.personality_config = updates["personality_config"]
|
||||
if "resource_preferences" in updates:
|
||||
agent.resource_preferences = updates["resource_preferences"]
|
||||
if "memory_settings" in updates:
|
||||
agent.memory_settings = updates["memory_settings"]
|
||||
if "tags" in updates:
|
||||
agent.tags = updates["tags"]
|
||||
|
||||
# Update file-based configurations
|
||||
if "config" in updates:
|
||||
agent.save_config_to_file(updates["config"])
|
||||
if "prompt" in updates:
|
||||
agent.save_prompt_to_file(updates["prompt"])
|
||||
if "capabilities" in updates:
|
||||
agent.save_capabilities_to_file(updates["capabilities"])
|
||||
|
||||
agent.updated_at = datetime.utcnow()
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Updated agent configuration",
|
||||
extra={
|
||||
"assistant_uuid": assistant_uuid,
|
||||
"updated_fields": list(updates.keys()),
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update agent configuration: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def clone_assistant(self, source_uuid: str, new_name: str, user_identifier: str, modifications: Dict[str, Any] = None) -> str:
|
||||
"""Clone existing agent with modifications"""
|
||||
try:
|
||||
# Get source agent
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == source_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
source_assistant = result.scalar_one_or_none()
|
||||
|
||||
if not source_assistant:
|
||||
raise ValueError(f"Source agent not found: {source_uuid}")
|
||||
|
||||
# Clone agent
|
||||
cloned_assistant = source_assistant.clone(new_name, user_identifier, modifications or {})
|
||||
|
||||
# Initialize with placeholder paths first
|
||||
cloned_assistant.config_file_path = "placeholder"
|
||||
cloned_assistant.prompt_file_path = "placeholder"
|
||||
cloned_assistant.capabilities_file_path = "placeholder"
|
||||
|
||||
# Save to database first to get UUID
|
||||
self.db.add(cloned_assistant)
|
||||
await self.db.flush() # Flush to get the generated UUID
|
||||
|
||||
# Initialize proper file paths with UUID
|
||||
cloned_assistant.initialize_file_paths()
|
||||
|
||||
# Copy and modify files
|
||||
await self._clone_assistant_files(source_assistant, cloned_assistant, modifications or {})
|
||||
|
||||
# Commit all changes
|
||||
await self.db.commit()
|
||||
await self.db.refresh(cloned_assistant)
|
||||
|
||||
logger.info(
|
||||
f"Cloned agent",
|
||||
extra={
|
||||
"source_uuid": source_uuid,
|
||||
"new_uuid": cloned_assistant.uuid,
|
||||
"new_name": new_name,
|
||||
}
|
||||
)
|
||||
|
||||
return str(cloned_assistant.uuid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clone agent: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def archive_assistant(self, assistant_uuid: str, user_identifier: str) -> bool:
|
||||
"""Archive agent (soft delete)"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
agent.archive()
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Archived agent",
|
||||
extra={"assistant_uuid": assistant_uuid}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to archive agent: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_assistant_statistics(self, assistant_uuid: str, user_identifier: str) -> Dict[str, Any]:
|
||||
"""Get usage statistics for agent"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
# Get conversation statistics
|
||||
conv_result = await self.db.execute(
|
||||
select(func.count(Conversation.id))
|
||||
.where(Conversation.agent_id == agent.id)
|
||||
)
|
||||
conversation_count = conv_result.scalar() or 0
|
||||
|
||||
# Get message statistics
|
||||
msg_result = await self.db.execute(
|
||||
select(
|
||||
func.count(Message.id),
|
||||
func.sum(Message.tokens_used),
|
||||
func.sum(Message.cost_cents)
|
||||
)
|
||||
.join(Conversation, Message.conversation_id == Conversation.id)
|
||||
.where(Conversation.agent_id == agent.id)
|
||||
)
|
||||
message_stats = msg_result.first()
|
||||
|
||||
return {
|
||||
"agent_id": assistant_uuid, # Use agent_id to match schema
|
||||
"name": agent.name,
|
||||
"created_at": agent.created_at, # Return datetime object, not ISO string
|
||||
"last_used_at": agent.last_used_at, # Return datetime object, not ISO string
|
||||
"conversation_count": conversation_count,
|
||||
"total_messages": message_stats[0] or 0,
|
||||
"total_tokens_used": message_stats[1] or 0,
|
||||
"total_cost_cents": message_stats[2] or 0,
|
||||
"total_cost_dollars": (message_stats[2] or 0) / 100.0,
|
||||
"average_tokens_per_message": (
|
||||
(message_stats[1] or 0) / max(1, message_stats[0] or 1)
|
||||
),
|
||||
"is_favorite": agent.is_favorite,
|
||||
"tags": agent.tags,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get agent statistics: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# Private helper methods
|
||||
|
||||
async def _load_template_config(self, template_id: str) -> Dict[str, Any]:
|
||||
"""Load template configuration from Resource Cluster or built-in templates"""
|
||||
# Built-in templates (as specified in CLAUDE.md)
|
||||
builtin_templates = {
|
||||
"research_assistant": {
|
||||
"name": "Research & Analysis Agent",
|
||||
"description": "Specialized in information synthesis and analysis",
|
||||
"prompt": """You are a research agent specialized in information synthesis and analysis.
|
||||
Focus on providing well-sourced, analytical responses with clear reasoning.""",
|
||||
"personality_config": {
|
||||
"tone": "balanced",
|
||||
"explanation_depth": "expert",
|
||||
"interaction_style": "collaborative"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
|
||||
{"resource": "rag:semantic_search", "actions": ["search"], "limits": {}},
|
||||
{"resource": "tools:web_search", "actions": ["search"], "limits": {"requests_per_hour": 50}},
|
||||
{"resource": "export:citations", "actions": ["create"], "limits": {}}
|
||||
]
|
||||
},
|
||||
"coding_assistant": {
|
||||
"name": "Software Development Agent",
|
||||
"description": "Focused on code quality and best practices",
|
||||
"prompt": """You are a software development agent focused on code quality and best practices.
|
||||
Provide clear explanations, suggest improvements, and help debug issues.""",
|
||||
"personality_config": {
|
||||
"tone": "direct",
|
||||
"explanation_depth": "intermediate",
|
||||
"interaction_style": "teaching"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 4000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
|
||||
{"resource": "tools:github_integration", "actions": ["read"], "limits": {}},
|
||||
{"resource": "resources:documentation", "actions": ["search"], "limits": {}},
|
||||
{"resource": "export:code_snippets", "actions": ["create"], "limits": {}}
|
||||
]
|
||||
},
|
||||
"cyber_analyst": {
|
||||
"name": "Cybersecurity Analysis Agent",
|
||||
"description": "For threat detection and response analysis",
|
||||
"prompt": """You are a cybersecurity analyst agent for threat detection and response.
|
||||
Prioritize security best practices and provide actionable recommendations.""",
|
||||
"personality_config": {
|
||||
"tone": "formal",
|
||||
"explanation_depth": "expert",
|
||||
"interaction_style": "direct"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.2,
|
||||
"max_tokens": 4000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
|
||||
{"resource": "tools:security_scanning", "actions": ["analyze"], "limits": {}},
|
||||
{"resource": "resources:threat_intelligence", "actions": ["search"], "limits": {}},
|
||||
{"resource": "export:security_reports", "actions": ["create"], "limits": {}}
|
||||
]
|
||||
},
|
||||
"educational_tutor": {
|
||||
"name": "AI Literacy Educational Agent",
|
||||
"description": "Develops critical thinking and AI literacy",
|
||||
"prompt": """You are an educational agent focused on developing critical thinking and AI literacy.
|
||||
Use socratic questioning and encourage deep analysis of problems.""",
|
||||
"personality_config": {
|
||||
"tone": "casual",
|
||||
"explanation_depth": "beginner",
|
||||
"interaction_style": "teaching"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 3000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 3000}},
|
||||
{"resource": "games:strategic_thinking", "actions": ["play"], "limits": {}},
|
||||
{"resource": "puzzles:logic_reasoning", "actions": ["present"], "limits": {}},
|
||||
{"resource": "analytics:learning_progress", "actions": ["track"], "limits": {}}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if template_id in builtin_templates:
|
||||
return builtin_templates[template_id]
|
||||
|
||||
# TODO: In the future, load from Resource Cluster Agent Library
|
||||
# For now, return empty config for unknown templates
|
||||
logger.warning(f"Unknown template ID: {template_id}")
|
||||
return {
|
||||
"name": f"Agent ({template_id})",
|
||||
"description": "Custom agent",
|
||||
"prompt": "You are a helpful AI agent.",
|
||||
"capabilities": []
|
||||
}
|
||||
|
||||
async def _setup_assistant_files(self, agent: Agent, config: Dict[str, Any]) -> None:
|
||||
"""Create file system structure for agent"""
|
||||
# Ensure directory exists
|
||||
agent.ensure_directory_exists()
|
||||
|
||||
# Save configuration files
|
||||
agent.save_config_to_file(config)
|
||||
agent.save_prompt_to_file(config.get("prompt", "You are a helpful AI agent."))
|
||||
agent.save_capabilities_to_file(config.get("capabilities", []))
|
||||
|
||||
logger.info(f"Created agent files for {agent.uuid}")
|
||||
|
||||
async def _clone_assistant_files(self, source: Agent, target: Agent, modifications: Dict[str, Any]) -> None:
|
||||
"""Clone agent files with modifications"""
|
||||
# Load source configurations
|
||||
source_config = source.load_config_from_file()
|
||||
source_prompt = source.load_prompt_from_file()
|
||||
source_capabilities = source.load_capabilities_from_file()
|
||||
|
||||
# Apply modifications
|
||||
target_config = {**source_config, **modifications.get("config", {})}
|
||||
target_prompt = modifications.get("prompt", source_prompt)
|
||||
target_capabilities = modifications.get("capabilities", source_capabilities)
|
||||
|
||||
# Create target files
|
||||
target.ensure_directory_exists()
|
||||
target.save_config_to_file(target_config)
|
||||
target.save_prompt_to_file(target_prompt)
|
||||
target.save_capabilities_to_file(target_capabilities)
|
||||
|
||||
logger.info(f"Cloned agent files from {source.uuid} to {target.uuid}")
|
||||
|
||||
|
||||
async def get_assistant_manager(db: AsyncSession) -> AssistantManager:
|
||||
"""Get AssistantManager instance"""
|
||||
return AssistantManager(db)
|
||||
632
apps/tenant-backend/app/services/automation_executor.py
Normal file
632
apps/tenant-backend/app/services/automation_executor.py
Normal file
@@ -0,0 +1,632 @@
|
||||
"""
|
||||
Automation Chain Executor
|
||||
|
||||
Executes automation chains with configurable depth, capability-based limits,
|
||||
and comprehensive error handling.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from app.services.event_bus import Event, Automation, TriggerType, TenantEventBus
|
||||
from app.core.security import verify_capability_token
|
||||
from app.core.path_security import sanitize_tenant_domain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChainDepthExceeded(Exception):
|
||||
"""Raised when automation chain depth exceeds limit"""
|
||||
pass
|
||||
|
||||
|
||||
class AutomationTimeout(Exception):
|
||||
"""Raised when automation execution times out"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""Context for automation execution"""
|
||||
automation_id: str
|
||||
chain_depth: int = 0
|
||||
parent_automation_id: Optional[str] = None
|
||||
start_time: datetime = None
|
||||
execution_history: List[Dict[str, Any]] = None
|
||||
variables: Dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.start_time is None:
|
||||
self.start_time = datetime.utcnow()
|
||||
if self.execution_history is None:
|
||||
self.execution_history = []
|
||||
if self.variables is None:
|
||||
self.variables = {}
|
||||
|
||||
def add_execution(self, action: str, result: Any, duration_ms: float):
|
||||
"""Add execution record to history"""
|
||||
self.execution_history.append({
|
||||
"action": action,
|
||||
"result": result,
|
||||
"duration_ms": duration_ms,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
def get_total_duration(self) -> float:
|
||||
"""Get total execution duration in milliseconds"""
|
||||
return (datetime.utcnow() - self.start_time).total_seconds() * 1000
|
||||
|
||||
|
||||
class AutomationChainExecutor:
|
||||
"""
|
||||
Execute automation chains with configurable depth and capability-based limits.
|
||||
|
||||
Features:
|
||||
- Configurable max chain depth per tenant
|
||||
- Retry logic with exponential backoff
|
||||
- Comprehensive error handling
|
||||
- Execution history tracking
|
||||
- Variable passing between chain steps
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_domain: str,
|
||||
event_bus: TenantEventBus,
|
||||
base_path: Optional[Path] = None
|
||||
):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.event_bus = event_bus
|
||||
# Sanitize tenant_domain to prevent path traversal
|
||||
safe_tenant = sanitize_tenant_domain(tenant_domain)
|
||||
self.base_path = base_path or (Path("/data") / safe_tenant / "automations")
|
||||
self.execution_path = self.base_path / "executions"
|
||||
self.running_chains: Dict[str, ExecutionContext] = {}
|
||||
|
||||
# Ensure directories exist
|
||||
self._ensure_directories()
|
||||
|
||||
logger.info(f"AutomationChainExecutor initialized for {tenant_domain}")
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure execution directories exist with proper permissions"""
|
||||
import os
|
||||
import stat
|
||||
|
||||
# codeql[py/path-injection] execution_path derived from sanitize_tenant_domain() at line 86
|
||||
self.execution_path.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(self.execution_path, stat.S_IRWXU) # 700 permissions
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
automation: Automation,
|
||||
event: Event,
|
||||
capability_token: str,
|
||||
current_depth: int = 0
|
||||
) -> Any:
|
||||
"""
|
||||
Execute automation chain with depth control.
|
||||
|
||||
Args:
|
||||
automation: Automation to execute
|
||||
event: Triggering event
|
||||
capability_token: JWT capability token
|
||||
current_depth: Current chain depth
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
|
||||
Raises:
|
||||
ChainDepthExceeded: If chain depth exceeds limit
|
||||
AutomationTimeout: If execution times out
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data:
|
||||
raise ValueError("Invalid capability token")
|
||||
|
||||
# Get max chain depth from capability token (tenant-specific)
|
||||
max_depth = self._get_constraint(token_data, "max_automation_chain_depth", 5)
|
||||
|
||||
# Check depth limit
|
||||
if current_depth >= max_depth:
|
||||
raise ChainDepthExceeded(
|
||||
f"Chain depth {current_depth} exceeds limit {max_depth}"
|
||||
)
|
||||
|
||||
# Create execution context
|
||||
context = ExecutionContext(
|
||||
automation_id=automation.id,
|
||||
chain_depth=current_depth,
|
||||
parent_automation_id=event.metadata.get("parent_automation_id")
|
||||
)
|
||||
|
||||
# Track running chain
|
||||
self.running_chains[automation.id] = context
|
||||
|
||||
try:
|
||||
# Execute automation with timeout
|
||||
timeout = self._get_constraint(token_data, "automation_timeout_seconds", 300)
|
||||
result = await asyncio.wait_for(
|
||||
self._execute_automation(automation, event, context, token_data),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# If this automation triggers chain
|
||||
if automation.triggers_chain:
|
||||
await self._trigger_chain_automations(
|
||||
automation,
|
||||
result,
|
||||
capability_token,
|
||||
current_depth
|
||||
)
|
||||
|
||||
# Store execution history
|
||||
await self._store_execution(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise AutomationTimeout(
|
||||
f"Automation {automation.id} timed out after {timeout} seconds"
|
||||
)
|
||||
finally:
|
||||
# Remove from running chains
|
||||
self.running_chains.pop(automation.id, None)
|
||||
|
||||
async def _execute_automation(
|
||||
self,
|
||||
automation: Automation,
|
||||
event: Event,
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""Execute automation with retry logic"""
|
||||
results = []
|
||||
retry_count = 0
|
||||
max_retries = min(automation.max_retries, 5) # Cap at 5 retries
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# Execute each action
|
||||
for action in automation.actions:
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Check if action is allowed by capabilities
|
||||
if not self._is_action_allowed(action, token_data):
|
||||
logger.warning(f"Action {action.get('type')} not allowed by capabilities")
|
||||
continue
|
||||
|
||||
# Execute action with context
|
||||
result = await self._execute_action(action, event, context, token_data)
|
||||
|
||||
# Track execution
|
||||
duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||
context.add_execution(action.get("type"), result, duration_ms)
|
||||
|
||||
results.append(result)
|
||||
|
||||
# Update variables for next action
|
||||
if isinstance(result, dict) and "variables" in result:
|
||||
context.variables.update(result["variables"])
|
||||
|
||||
# Success - break retry loop
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
logger.error(f"Automation {automation.id} failed after {max_retries} retries: {e}")
|
||||
raise
|
||||
|
||||
# Exponential backoff
|
||||
wait_time = min(2 ** retry_count, 30) # Max 30 seconds
|
||||
logger.info(f"Retrying automation {automation.id} in {wait_time} seconds...")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
return {
|
||||
"automation_id": automation.id,
|
||||
"results": results,
|
||||
"context": {
|
||||
"chain_depth": context.chain_depth,
|
||||
"total_duration_ms": context.get_total_duration(),
|
||||
"variables": context.variables
|
||||
}
|
||||
}
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""Execute a single action with capability constraints"""
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "api_call":
|
||||
return await self._execute_api_call(action, context, token_data)
|
||||
elif action_type == "data_transform":
|
||||
return await self._execute_data_transform(action, context)
|
||||
elif action_type == "conditional":
|
||||
return await self._execute_conditional(action, context)
|
||||
elif action_type == "loop":
|
||||
return await self._execute_loop(action, event, context, token_data)
|
||||
elif action_type == "wait":
|
||||
return await self._execute_wait(action)
|
||||
elif action_type == "variable_set":
|
||||
return await self._execute_variable_set(action, context)
|
||||
else:
|
||||
# Delegate to event bus for standard actions
|
||||
return await self.event_bus._execute_action(action, event, None)
|
||||
|
||||
async def _execute_api_call(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute API call action with rate limiting"""
|
||||
endpoint = action.get("endpoint")
|
||||
method = action.get("method", "GET")
|
||||
headers = action.get("headers", {})
|
||||
body = action.get("body")
|
||||
|
||||
# Apply variable substitution
|
||||
if body and context.variables:
|
||||
body = self._substitute_variables(body, context.variables)
|
||||
|
||||
# Check rate limits
|
||||
rate_limit = self._get_constraint(token_data, "api_calls_per_minute", 60)
|
||||
# In production, implement actual rate limiting
|
||||
|
||||
logger.info(f"Mock API call: {method} {endpoint}")
|
||||
|
||||
# Mock response
|
||||
return {
|
||||
"status": 200,
|
||||
"data": {"message": "Mock API response"},
|
||||
"headers": {"content-type": "application/json"}
|
||||
}
|
||||
|
||||
async def _execute_data_transform(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute data transformation action"""
|
||||
transform_type = action.get("transform_type")
|
||||
source = action.get("source")
|
||||
target = action.get("target")
|
||||
|
||||
# Get source data from context
|
||||
source_data = context.variables.get(source)
|
||||
|
||||
if transform_type == "json_parse":
|
||||
result = json.loads(source_data) if isinstance(source_data, str) else source_data
|
||||
elif transform_type == "json_stringify":
|
||||
result = json.dumps(source_data)
|
||||
elif transform_type == "extract":
|
||||
path = action.get("path", "")
|
||||
result = self._extract_path(source_data, path)
|
||||
elif transform_type == "map":
|
||||
mapping = action.get("mapping", {})
|
||||
result = {k: self._extract_path(source_data, v) for k, v in mapping.items()}
|
||||
else:
|
||||
result = source_data
|
||||
|
||||
# Store result in context
|
||||
context.variables[target] = result
|
||||
|
||||
return {
|
||||
"transform_type": transform_type,
|
||||
"target": target,
|
||||
"variables": {target: result}
|
||||
}
|
||||
|
||||
async def _execute_conditional(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute conditional action"""
|
||||
condition = action.get("condition")
|
||||
then_actions = action.get("then", [])
|
||||
else_actions = action.get("else", [])
|
||||
|
||||
# Evaluate condition
|
||||
if self._evaluate_condition(condition, context.variables):
|
||||
actions_to_execute = then_actions
|
||||
branch = "then"
|
||||
else:
|
||||
actions_to_execute = else_actions
|
||||
branch = "else"
|
||||
|
||||
# Execute branch actions
|
||||
results = []
|
||||
for sub_action in actions_to_execute:
|
||||
result = await self._execute_action(sub_action, None, context, {})
|
||||
results.append(result)
|
||||
|
||||
return {
|
||||
"condition": condition,
|
||||
"branch": branch,
|
||||
"results": results
|
||||
}
|
||||
|
||||
async def _execute_loop(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute loop action with iteration limit"""
|
||||
items = action.get("items", [])
|
||||
variable = action.get("variable", "item")
|
||||
loop_actions = action.get("actions", [])
|
||||
|
||||
# Get max iterations from capabilities
|
||||
max_iterations = self._get_constraint(token_data, "max_loop_iterations", 100)
|
||||
|
||||
# Resolve items from context if it's a variable reference
|
||||
if isinstance(items, str) and items.startswith("$"):
|
||||
items = context.variables.get(items[1:], [])
|
||||
|
||||
# Limit iterations
|
||||
items = items[:max_iterations]
|
||||
|
||||
results = []
|
||||
for item in items:
|
||||
# Set loop variable
|
||||
context.variables[variable] = item
|
||||
|
||||
# Execute loop actions
|
||||
for loop_action in loop_actions:
|
||||
result = await self._execute_action(loop_action, event, context, token_data)
|
||||
results.append(result)
|
||||
|
||||
return {
|
||||
"loop_count": len(items),
|
||||
"results": results
|
||||
}
|
||||
|
||||
    async def _execute_wait(self, action: Dict[str, Any]) -> Dict[str, Any]:
        """Execute wait action"""
        duration = action.get("duration", 1)
        max_wait = 60  # Maximum 60 seconds wait

        duration = min(duration, max_wait)
        await asyncio.sleep(duration)

        return {
            "waited": duration,
            "unit": "seconds"
        }

    async def _execute_variable_set(
        self,
        action: Dict[str, Any],
        context: ExecutionContext
    ) -> Dict[str, Any]:
        """Set variables in context"""
        variables = action.get("variables", {})

        for key, value in variables.items():
            # Substitute existing variables in value
            if isinstance(value, str):
                value = self._substitute_variables(value, context.variables)
            context.variables[key] = value

        return {
            "variables": variables
        }

    async def _trigger_chain_automations(
        self,
        automation: Automation,
        result: Any,
        capability_token: str,
        current_depth: int
    ):
        """Trigger chained automations"""
        for target_id in automation.chain_targets:
            # Load target automation
            target_automation = await self.event_bus.get_automation(target_id)

            if not target_automation:
                logger.warning(f"Chain target automation {target_id} not found")
                continue

            # Create chain event
            chain_event = Event(
                type=TriggerType.CHAIN.value,
                tenant=self.tenant_domain,
                user=automation.owner_id,
                data=result,
                metadata={
                    "parent_automation_id": automation.id,
                    "chain_depth": current_depth + 1
                }
            )

            # Execute chained automation
            try:
                await self.execute_chain(
                    target_automation,
                    chain_event,
                    capability_token,
                    current_depth + 1
                )
            except ChainDepthExceeded:
                logger.warning(f"Chain depth exceeded for automation {target_id}")
            except Exception as e:
                logger.error(f"Error executing chained automation {target_id}: {e}")

    def _get_constraint(
        self,
        token_data: Dict[str, Any],
        constraint_name: str,
        default: Any
    ) -> Any:
        """Get constraint value from capability token"""
        constraints = token_data.get("constraints", {})
        return constraints.get(constraint_name, default)

    def _is_action_allowed(
        self,
        action: Dict[str, Any],
        token_data: Dict[str, Any]
    ) -> bool:
        """Check if action is allowed by capabilities"""
        action_type = action.get("type")

        # Check specific action capabilities
        capabilities = token_data.get("capabilities", [])

        # Map action types to required capabilities
        required_capabilities = {
            "api_call": "automation:api_calls",
            "webhook": "automation:webhooks",
            "email": "automation:email",
            "data_transform": "automation:data_processing",
            "conditional": "automation:logic",
            "loop": "automation:logic"
        }

        required = required_capabilities.get(action_type)
        if not required:
            return True  # Allow unknown actions by default

        # Check if capability exists
        return any(
            cap.get("resource") == required
            for cap in capabilities
        )

    def _substitute_variables(
        self,
        template: Any,
        variables: Dict[str, Any]
    ) -> Any:
        """Substitute variables in template"""
        if not isinstance(template, str):
            return template

        # Simple variable substitution
        for key, value in variables.items():
            template = template.replace(f"${{{key}}}", str(value))
            template = template.replace(f"${key}", str(value))

        return template

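    # Example of the substitution performed by _substitute_variables (assumed
    # values, for illustration only): with variables {"name": "Ada"}, the
    # template "Hello ${name} ($name)" becomes "Hello Ada (Ada)", since both
    # the ${key} and $key spellings are replaced.
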
    def _extract_path(self, data: Any, path: str) -> Any:
        """Extract value from nested data using path"""
        if not path:
            return data

        parts = path.split(".")
        current = data

        for part in parts:
            if isinstance(current, dict):
                current = current.get(part)
            elif isinstance(current, list) and part.isdigit():
                index = int(part)
                if 0 <= index < len(current):
                    current = current[index]
                else:
                    return None
            else:
                return None

        return current

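    # Example of _extract_path traversal (illustrative data, not from the
    # source): with data {"user": {"tags": ["alpha", "beta"]}}, the path
    # "user.tags.1" yields "beta"; a missing key or out-of-range index
    # returns None.
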
    def _evaluate_condition(
        self,
        condition: Dict[str, Any],
        variables: Dict[str, Any]
    ) -> bool:
        """Evaluate condition against variables"""
        left = condition.get("left")
        operator = condition.get("operator")
        right = condition.get("right")

        # Resolve variables
        if isinstance(left, str) and left.startswith("$"):
            left = variables.get(left[1:])
        if isinstance(right, str) and right.startswith("$"):
            right = variables.get(right[1:])

        # Evaluate
        try:
            if operator == "equals":
                return left == right
            elif operator == "not_equals":
                return left != right
            elif operator == "greater_than":
                return float(left) > float(right)
            elif operator == "less_than":
                return float(left) < float(right)
            elif operator == "contains":
                return right in left
            elif operator == "exists":
                return left is not None
            elif operator == "not_exists":
                return left is None
            else:
                return False
        except (ValueError, TypeError):
            return False

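    # Illustrative condition as evaluated by _evaluate_condition (hypothetical
    # variable name): with variables {"status_code": 200}, the condition
    #   {"left": "$status_code", "operator": "equals", "right": 200}
    # resolves the "$status_code" reference and returns True.
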
    async def _store_execution(
        self,
        context: ExecutionContext,
        result: Any
    ):
        """Store execution history to file system"""
        execution_record = {
            "automation_id": context.automation_id,
            "chain_depth": context.chain_depth,
            "parent_automation_id": context.parent_automation_id,
            "start_time": context.start_time.isoformat(),
            "total_duration_ms": context.get_total_duration(),
            "execution_history": context.execution_history,
            "variables": context.variables,
            "result": result if isinstance(result, (dict, list, str, int, float, bool)) else str(result)
        }

        # Create execution file
        execution_file = self.execution_path / f"{context.automation_id}_{context.start_time.strftime('%Y%m%d_%H%M%S')}.json"

        with open(execution_file, "w") as f:
            json.dump(execution_record, f, indent=2)

    async def get_execution_history(
        self,
        automation_id: Optional[str] = None,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Get execution history for automations"""
        executions = []

        # Get all execution files
        pattern = f"{automation_id}_*.json" if automation_id else "*.json"

        for execution_file in sorted(
            self.execution_path.glob(pattern),
            key=lambda x: x.stat().st_mtime,
            reverse=True
        )[:limit]:
            try:
                with open(execution_file, "r") as f:
                    executions.append(json.load(f))
            except Exception as e:
                logger.error(f"Error loading execution {execution_file}: {e}")

        return executions
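    # Sketch of an execution record as written by _store_execution (values are
    # invented for illustration; the keys mirror the dict built above, and the
    # file name follows the "<automation_id>_<YYYYmmdd_HHMMSS>.json" pattern):
    #
    #   {
    #       "automation_id": "a1b2c3",
    #       "chain_depth": 0,
    #       "parent_automation_id": null,
    #       "start_time": "2025-01-01T12:00:00",
    #       "total_duration_ms": 42,
    #       "execution_history": [...],
    #       "variables": {...},
    #       "result": {...}
    #   }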
514
apps/tenant-backend/app/services/category_service.py
Normal file
@@ -0,0 +1,514 @@
"""
|
||||
Category Service for GT 2.0 Tenant Backend
|
||||
|
||||
Provides tenant-scoped agent category management with permission-based
|
||||
editing and deletion. Supports Issue #215 requirements.
|
||||
|
||||
Permission Model:
|
||||
- Admins/developers can edit/delete ANY category
|
||||
- Regular users can only edit/delete categories they created
|
||||
- All users can view and use all tenant categories
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.permissions import get_user_role
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Admin roles that can manage all categories
|
||||
ADMIN_ROLES = ["admin", "developer"]
|
||||
|
||||
|
||||
class CategoryService:
|
||||
"""GT 2.0 Category Management Service with Tenant Isolation"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
|
||||
"""Initialize with tenant and user isolation using PostgreSQL storage"""
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.user_email = user_email or user_id
|
||||
self.settings = get_settings()
|
||||
|
||||
logger.info(f"Category service initialized for {tenant_domain}/{user_id}")
|
||||
|
||||
def _generate_slug(self, name: str) -> str:
|
||||
"""Generate URL-safe slug from category name"""
|
||||
# Convert to lowercase, replace non-alphanumeric with hyphens
|
||||
slug = re.sub(r'[^a-zA-Z0-9]+', '-', name.lower())
|
||||
# Remove leading/trailing hyphens
|
||||
slug = slug.strip('-')
|
||||
return slug or 'category'
|
||||
|
||||
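    # Examples of _generate_slug output (derived from the regex above):
    #   "Data Engineering"  -> "data-engineering"
    #   "  C++ / Rust!  "   -> "c-rust"
    #   "###"               -> "category"   (fallback when nothing survives)
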
    async def _get_user_id(self, pg_client) -> str:
        """Get user UUID from email/username/uuid with tenant isolation"""
        identifier = self.user_email

        user_lookup_query = """
            SELECT id FROM users
            WHERE (email = $1 OR id::text = $1 OR username = $1)
            AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
            LIMIT 1
        """

        user_id = await pg_client.fetch_scalar(user_lookup_query, identifier, self.tenant_domain)
        if not user_id:
            user_id = await pg_client.fetch_scalar(user_lookup_query, self.user_id, self.tenant_domain)

        if not user_id:
            raise RuntimeError(f"User not found: {identifier} in tenant {self.tenant_domain}")

        return str(user_id)

    async def _get_tenant_id(self, pg_client) -> str:
        """Get tenant UUID from domain"""
        query = "SELECT id FROM tenants WHERE domain = $1 LIMIT 1"
        tenant_id = await pg_client.fetch_scalar(query, self.tenant_domain)
        if not tenant_id:
            raise RuntimeError(f"Tenant not found: {self.tenant_domain}")
        return str(tenant_id)

    async def _can_manage_category(self, pg_client, category: Dict) -> tuple:
        """
        Check if current user can manage (edit/delete) a category.
        Returns (can_edit, can_delete) tuple.
        """
        # Get user role
        user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
        is_admin = user_role in ADMIN_ROLES

        # Get current user ID
        current_user_id = await self._get_user_id(pg_client)

        # Admins can manage all categories
        if is_admin:
            return (True, True)

        # Check if user created this category
        created_by = category.get('created_by')
        if created_by and str(created_by) == current_user_id:
            return (True, True)

        # Regular users cannot manage other users' categories or defaults
        return (False, False)

    async def get_all_categories(self) -> List[Dict[str, Any]]:
        """
        Get all active categories for the tenant.
        Returns categories with permission flags for current user.
        """
        try:
            pg_client = await get_postgresql_client()
            user_id = await self._get_user_id(pg_client)
            user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
            is_admin = user_role in ADMIN_ROLES

            query = """
                SELECT
                    c.id, c.name, c.slug, c.description, c.icon,
                    c.is_default, c.created_by, c.sort_order,
                    c.created_at, c.updated_at,
                    u.full_name as created_by_name
                FROM categories c
                LEFT JOIN users u ON c.created_by = u.id
                WHERE c.tenant_id = (SELECT id FROM tenants WHERE domain = $1 LIMIT 1)
                AND c.is_deleted = FALSE
                ORDER BY c.sort_order ASC, c.name ASC
            """

            rows = await pg_client.execute_query(query, self.tenant_domain)

            categories = []
            for row in rows:
                # Determine permissions
                can_edit = False
                can_delete = False

                if is_admin:
                    can_edit = True
                    can_delete = True
                elif row.get('created_by') and str(row['created_by']) == user_id:
                    can_edit = True
                    can_delete = True

                categories.append({
                    "id": str(row["id"]),
                    "name": row["name"],
                    "slug": row["slug"],
                    "description": row.get("description"),
                    "icon": row.get("icon"),
                    "is_default": row.get("is_default", False),
                    "created_by": str(row["created_by"]) if row.get("created_by") else None,
                    "created_by_name": row.get("created_by_name"),
                    "can_edit": can_edit,
                    "can_delete": can_delete,
                    "sort_order": row.get("sort_order", 0),
                    "created_at": row["created_at"].isoformat() if row.get("created_at") else None,
                    "updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
                })

            logger.info(f"Retrieved {len(categories)} categories for tenant {self.tenant_domain}")
            return categories

        except Exception as e:
            logger.error(f"Error getting categories: {e}")
            raise

    async def get_category_by_id(self, category_id: str) -> Optional[Dict[str, Any]]:
        """Get a single category by ID"""
        try:
            pg_client = await get_postgresql_client()

            query = """
                SELECT
                    c.id, c.name, c.slug, c.description, c.icon,
                    c.is_default, c.created_by, c.sort_order,
                    c.created_at, c.updated_at,
                    u.full_name as created_by_name
                FROM categories c
                LEFT JOIN users u ON c.created_by = u.id
                WHERE c.id = $1::uuid
                AND c.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
                AND c.is_deleted = FALSE
            """

            row = await pg_client.fetch_one(query, category_id, self.tenant_domain)

            if not row:
                return None

            can_edit, can_delete = await self._can_manage_category(pg_client, dict(row))

            return {
                "id": str(row["id"]),
                "name": row["name"],
                "slug": row["slug"],
                "description": row.get("description"),
                "icon": row.get("icon"),
                "is_default": row.get("is_default", False),
                "created_by": str(row["created_by"]) if row.get("created_by") else None,
                "created_by_name": row.get("created_by_name"),
                "can_edit": can_edit,
                "can_delete": can_delete,
                "sort_order": row.get("sort_order", 0),
                "created_at": row["created_at"].isoformat() if row.get("created_at") else None,
                "updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
            }

        except Exception as e:
            logger.error(f"Error getting category {category_id}: {e}")
            raise

    async def get_category_by_slug(self, slug: str) -> Optional[Dict[str, Any]]:
        """Get a single category by slug"""
        try:
            pg_client = await get_postgresql_client()

            query = """
                SELECT
                    c.id, c.name, c.slug, c.description, c.icon,
                    c.is_default, c.created_by, c.sort_order,
                    c.created_at, c.updated_at,
                    u.full_name as created_by_name
                FROM categories c
                LEFT JOIN users u ON c.created_by = u.id
                WHERE c.slug = $1
                AND c.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
                AND c.is_deleted = FALSE
            """

            row = await pg_client.fetch_one(query, slug.lower(), self.tenant_domain)

            if not row:
                return None

            can_edit, can_delete = await self._can_manage_category(pg_client, dict(row))

            return {
                "id": str(row["id"]),
                "name": row["name"],
                "slug": row["slug"],
                "description": row.get("description"),
                "icon": row.get("icon"),
                "is_default": row.get("is_default", False),
                "created_by": str(row["created_by"]) if row.get("created_by") else None,
                "created_by_name": row.get("created_by_name"),
                "can_edit": can_edit,
                "can_delete": can_delete,
                "sort_order": row.get("sort_order", 0),
                "created_at": row["created_at"].isoformat() if row.get("created_at") else None,
                "updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
            }

        except Exception as e:
            logger.error(f"Error getting category by slug {slug}: {e}")
            raise

    async def create_category(
        self,
        name: str,
        description: Optional[str] = None,
        icon: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Create a new custom category.
        The creating user becomes the owner and can edit/delete it.
        """
        try:
            pg_client = await get_postgresql_client()
            user_id = await self._get_user_id(pg_client)
            tenant_id = await self._get_tenant_id(pg_client)

            # Generate slug
            slug = self._generate_slug(name)

            # Check if slug already exists
            existing = await self.get_category_by_slug(slug)
            if existing:
                raise ValueError(f"A category with name '{name}' already exists")

            # Generate category ID
            category_id = str(uuid.uuid4())

            # Get next sort_order (after all existing categories)
            sort_query = """
                SELECT COALESCE(MAX(sort_order), 0) + 10 as next_order
                FROM categories
                WHERE tenant_id = $1::uuid
            """
            next_order = await pg_client.fetch_scalar(sort_query, tenant_id)

            # Create category
            query = """
                INSERT INTO categories (
                    id, tenant_id, name, slug, description, icon,
                    is_default, created_by, sort_order, is_deleted,
                    created_at, updated_at
                ) VALUES (
                    $1::uuid, $2::uuid, $3, $4, $5, $6,
                    FALSE, $7::uuid, $8, FALSE,
                    NOW(), NOW()
                )
                RETURNING id, name, slug, description, icon, is_default,
                          created_by, sort_order, created_at, updated_at
            """

            row = await pg_client.fetch_one(
                query,
                category_id, tenant_id, name, slug, description, icon,
                user_id, next_order
            )

            if not row:
                raise RuntimeError("Failed to create category")

            logger.info(f"Created category {category_id}: {name} for user {user_id}")

            # Get creator name
            user_query = "SELECT full_name FROM users WHERE id = $1::uuid"
            created_by_name = await pg_client.fetch_scalar(user_query, user_id)

            return {
                "id": str(row["id"]),
                "name": row["name"],
                "slug": row["slug"],
                "description": row.get("description"),
                "icon": row.get("icon"),
                "is_default": False,
                "created_by": user_id,
                "created_by_name": created_by_name,
                "can_edit": True,
                "can_delete": True,
                "sort_order": row.get("sort_order", 0),
                "created_at": row["created_at"].isoformat() if row.get("created_at") else None,
                "updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
            }

        except ValueError:
            raise
        except Exception as e:
            logger.error(f"Error creating category: {e}")
            raise

    async def update_category(
        self,
        category_id: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
        icon: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Update a category.
        Requires permission (admin or category creator).
        """
        try:
            pg_client = await get_postgresql_client()

            # Get existing category
            existing = await self.get_category_by_id(category_id)
            if not existing:
                raise ValueError("Category not found")

            # Check permissions
            can_edit, _ = await self._can_manage_category(pg_client, existing)
            if not can_edit:
                raise PermissionError("You do not have permission to edit this category")

            # Build update fields
            updates = []
            params = [category_id, self.tenant_domain]
            param_idx = 3

            if name is not None:
                new_slug = self._generate_slug(name)
                # Check if new slug conflicts with another category
                slug_check = await self.get_category_by_slug(new_slug)
                if slug_check and slug_check["id"] != category_id:
                    raise ValueError(f"A category with name '{name}' already exists")
                updates.append(f"name = ${param_idx}")
                params.append(name)
                param_idx += 1
                updates.append(f"slug = ${param_idx}")
                params.append(new_slug)
                param_idx += 1

            if description is not None:
                updates.append(f"description = ${param_idx}")
                params.append(description)
                param_idx += 1

            if icon is not None:
                updates.append(f"icon = ${param_idx}")
                params.append(icon)
                param_idx += 1

            if not updates:
                return existing

            updates.append("updated_at = NOW()")

            query = f"""
                UPDATE categories
                SET {', '.join(updates)}
                WHERE id = $1::uuid
                AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
                AND is_deleted = FALSE
                RETURNING id
            """

            result = await pg_client.fetch_scalar(query, *params)
            if not result:
                raise RuntimeError("Failed to update category")

            logger.info(f"Updated category {category_id}")

            # Return updated category
            return await self.get_category_by_id(category_id)

        except (ValueError, PermissionError):
            raise
        except Exception as e:
            logger.error(f"Error updating category {category_id}: {e}")
            raise

    async def delete_category(self, category_id: str) -> bool:
        """
        Soft delete a category.
        Requires permission (admin or category creator).
        """
        try:
            pg_client = await get_postgresql_client()

            # Get existing category
            existing = await self.get_category_by_id(category_id)
            if not existing:
                raise ValueError("Category not found")

            # Check permissions
            _, can_delete = await self._can_manage_category(pg_client, existing)
            if not can_delete:
                raise PermissionError("You do not have permission to delete this category")

            # Soft delete
            query = """
                UPDATE categories
                SET is_deleted = TRUE, updated_at = NOW()
                WHERE id = $1::uuid
                AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
            """

            await pg_client.execute_command(query, category_id, self.tenant_domain)

            logger.info(f"Deleted category {category_id}")
            return True

        except (ValueError, PermissionError):
            raise
        except Exception as e:
            logger.error(f"Error deleting category {category_id}: {e}")
            raise

    async def get_or_create_category(
        self,
        slug: str,
        description: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get existing category by slug or create it if not exists.
        Used for agent import to auto-create missing categories.

        If the category was soft-deleted, it will be restored.

        Args:
            slug: Category slug (lowercase, hyphenated)
            description: Optional description for new/restored categories
        """
        try:
            # Try to get existing active category
            existing = await self.get_category_by_slug(slug)
            if existing:
                return existing

            # Check if there's a soft-deleted category with this slug
            pg_client = await get_postgresql_client()
            deleted_query = """
                SELECT id FROM categories
                WHERE slug = $1
                AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
                AND is_deleted = TRUE
            """
            deleted_id = await pg_client.fetch_scalar(deleted_query, slug.lower(), self.tenant_domain)

            if deleted_id:
                # Restore the soft-deleted category
                user_id = await self._get_user_id(pg_client)
                restore_query = """
                    UPDATE categories
                    SET is_deleted = FALSE,
                        updated_at = NOW(),
                        created_by = $3::uuid
                    WHERE id = $1::uuid
                    AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
                """
                await pg_client.execute_command(restore_query, str(deleted_id), self.tenant_domain, user_id)
                logger.info(f"Restored soft-deleted category: {slug}")

                # Return the restored category
                return await self.get_category_by_slug(slug)

            # Auto-create with importing user as creator
            name = slug.replace('-', ' ').title()
            return await self.create_category(
                name=name,
                description=description,  # Use provided description or None
                icon=None
            )

        except Exception as e:
            logger.error(f"Error in get_or_create_category for slug {slug}: {e}")
            raise
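
# Usage sketch (hedged, not part of the module): how a caller such as an API
# route might drive this service. The tenant domain, user email, and category
# name below are placeholders.
#
#   service = CategoryService("acme.example.com", "user-uuid", "user@example.com")
#   category = await service.get_or_create_category("data-engineering")
#   if category["can_edit"]:
#       await service.update_category(category["id"], description="Pipelines and ETL agents")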
563
apps/tenant-backend/app/services/conversation_file_service.py
Normal file
@@ -0,0 +1,563 @@
"""
|
||||
Conversation File Service for GT 2.0
|
||||
|
||||
Handles conversation-scoped file attachments as a simpler alternative to dataset-based uploads.
|
||||
Preserves all existing dataset infrastructure while providing direct conversation file storage.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import UploadFile, HTTPException
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.path_security import sanitize_tenant_domain
|
||||
from app.services.embedding_client import BGE_M3_EmbeddingClient
|
||||
from app.services.document_processor import DocumentProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationFileService:
|
||||
"""Service for managing conversation-scoped file attachments"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.settings = get_settings()
|
||||
self.schema_name = f"tenant_{tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
|
||||
# File storage configuration
|
||||
# Sanitize tenant_domain to prevent path traversal
|
||||
safe_tenant = sanitize_tenant_domain(tenant_domain)
|
||||
# codeql[py/path-injection] safe_tenant validated by sanitize_tenant_domain()
|
||||
self.storage_root = Path(self.settings.file_storage_path) / safe_tenant / "conversations"
|
||||
self.storage_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"ConversationFileService initialized for {tenant_domain}/{user_id}")
|
||||
|
||||
def _get_conversation_storage_path(self, conversation_id: str) -> Path:
|
||||
"""Get storage directory for conversation files"""
|
||||
conv_path = self.storage_root / conversation_id
|
||||
conv_path.mkdir(parents=True, exist_ok=True)
|
||||
return conv_path
|
||||
|
||||
def _generate_safe_filename(self, original_filename: str, file_id: str) -> str:
|
||||
"""Generate safe filename for storage"""
|
||||
# Sanitize filename and prepend file ID
|
||||
safe_name = "".join(c for c in original_filename if c.isalnum() or c in ".-_")
|
||||
return f"{file_id}-{safe_name}"
|
||||
|
||||
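    # Example of _generate_safe_filename (illustrative UUID): for file_id
    # "1f2e3d" and original name "Q3 report (final).pdf", characters outside
    # letters, digits, ".", "-", and "_" are dropped, giving
    # "1f2e3d-Q3reportfinal.pdf".
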
    async def upload_files(
        self,
        conversation_id: str,
        files: List[UploadFile],
        user_id: str
    ) -> List[Dict[str, Any]]:
        """Upload files directly to conversation"""
        try:
            # Validate conversation access
            await self._validate_conversation_access(conversation_id, user_id)

            uploaded_files = []

            for file in files:
                if not file.filename:
                    raise HTTPException(status_code=400, detail="File must have a filename")

                # Generate file metadata
                file_id = str(uuid.uuid4())
                safe_filename = self._generate_safe_filename(file.filename, file_id)
                conversation_path = self._get_conversation_storage_path(conversation_id)
                file_path = conversation_path / safe_filename

                # Store file to disk
                content = await file.read()
                with open(file_path, "wb") as f:
                    f.write(content)

                # Create database record
                file_record = await self._create_file_record(
                    file_id=file_id,
                    conversation_id=conversation_id,
                    original_filename=file.filename,
                    safe_filename=safe_filename,
                    content_type=file.content_type or "application/octet-stream",
                    file_size=len(content),
                    file_path=str(file_path.relative_to(Path(self.settings.file_storage_path))),
                    uploaded_by=user_id
                )

                uploaded_files.append(file_record)

                # Queue for background processing
                asyncio.create_task(self._process_file_embeddings(file_id))

                logger.info(f"Uploaded conversation file: {file.filename} -> {file_id}")

            return uploaded_files

        except Exception as e:
            logger.error(f"Failed to upload conversation files: {e}")
            raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")

    async def _get_user_uuid(self, user_email: str) -> str:
        """Resolve user email to UUID"""
        client = await get_postgresql_client()
        query = f"SELECT id FROM {self.schema_name}.users WHERE email = $1 LIMIT 1"
        result = await client.fetch_one(query, user_email)
        if not result:
            raise ValueError(f"User not found: {user_email}")
        return str(result['id'])

    async def _create_file_record(
        self,
        file_id: str,
        conversation_id: str,
        original_filename: str,
        safe_filename: str,
        content_type: str,
        file_size: int,
        file_path: str,
        uploaded_by: str
    ) -> Dict[str, Any]:
        """Create conversation_files database record"""

        client = await get_postgresql_client()

        # Resolve user email to UUID if needed
        user_uuid = uploaded_by
        if '@' in uploaded_by:  # Check if it's an email
            user_uuid = await self._get_user_uuid(uploaded_by)

        query = f"""
            INSERT INTO {self.schema_name}.conversation_files (
                id, conversation_id, filename, original_filename, content_type,
                file_size_bytes, file_path, uploaded_by, processing_status
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'pending')
            RETURNING id, filename, original_filename, content_type, file_size_bytes,
                      processing_status, uploaded_at
        """

        result = await client.fetch_one(
            query,
            file_id, conversation_id, safe_filename, original_filename,
            content_type, file_size, file_path, user_uuid
        )

        # Convert UUID fields to strings for JSON serialization
        result_dict = dict(result)
        if 'id' in result_dict and result_dict['id']:
            result_dict['id'] = str(result_dict['id'])

        return result_dict

    async def _process_file_embeddings(self, file_id: str):
        """Background task to process file content and generate embeddings"""
        try:
            # Update status to processing
            await self._update_processing_status(file_id, "processing")

            # Get file record
            file_record = await self._get_file_record(file_id)
            if not file_record:
                logger.error(f"File record not found: {file_id}")
                return

            # Read file content
            file_path = Path(self.settings.file_storage_path) / file_record['file_path']
            if not file_path.exists():
                logger.error(f"File not found on disk: {file_path}")
                await self._update_processing_status(file_id, "failed")
                return

            # Extract text content using DocumentProcessor public methods
            processor = DocumentProcessor()

            text_content = await processor.extract_text_from_path(
                file_path,
                file_record['content_type']
            )

            if not text_content:
                logger.warning(f"No text content extracted from {file_record['original_filename']}")
                await self._update_processing_status(file_id, "completed")
                return

            # Chunk content for RAG
            chunks = await processor.chunk_text_simple(text_content)

            # Generate embeddings for full document (single embedding for semantic search)
            embedding_client = BGE_M3_EmbeddingClient()
            embeddings = await embedding_client.generate_embeddings([text_content])

            if not embeddings:
                logger.error(f"Failed to generate embeddings for {file_id}")
                await self._update_processing_status(file_id, "failed")
                return

            # Update record with processed content (chunks as JSONB, embedding as vector)
            await self._update_file_processing_results(
                file_id, chunks, embeddings[0], "completed"
            )

            logger.info(f"Successfully processed file: {file_record['original_filename']}")

        except Exception as e:
            logger.error(f"Failed to process file {file_id}: {e}")
            await self._update_processing_status(file_id, "failed")

    async def _update_processing_status(self, file_id: str, status: str):
        """Update file processing status"""
        client = await get_postgresql_client()

        query = f"""
            UPDATE {self.schema_name}.conversation_files
            SET processing_status = $1,
                processed_at = CASE WHEN $1 IN ('completed', 'failed') THEN NOW() ELSE processed_at END
            WHERE id = $2
        """

        await client.execute_query(query, status, file_id)

    async def _update_file_processing_results(
        self,
        file_id: str,
        chunks: List[str],
        embedding: List[float],
        status: str
    ):
        """Update file with processing results"""
        import json
        client = await get_postgresql_client()

        # Sanitize chunks: remove null bytes and other control characters
        # that PostgreSQL can't handle in JSONB
        sanitized_chunks = [
            chunk.replace('\u0000', '').replace('\x00', '')
            for chunk in chunks
        ]

        # Convert chunks list to JSONB-compatible format
        chunks_json = json.dumps(sanitized_chunks)

        # Convert embedding to PostgreSQL vector format
        embedding_str = f"[{','.join(map(str, embedding))}]"

        query = f"""
            UPDATE {self.schema_name}.conversation_files
            SET processed_chunks = $1::jsonb,
                embeddings = $2::vector,
                processing_status = $3,
                processed_at = NOW()
            WHERE id = $4
        """

        await client.execute_query(query, chunks_json, embedding_str, status, file_id)

    async def _get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Get file record by ID"""
        client = await get_postgresql_client()

        query = f"""
            SELECT id, conversation_id, filename, original_filename, content_type,
                   file_size_bytes, file_path, processing_status, uploaded_at
            FROM {self.schema_name}.conversation_files
            WHERE id = $1
        """

        result = await client.fetch_one(query, file_id)
        return dict(result) if result else None

    async def list_files(self, conversation_id: str) -> List[Dict[str, Any]]:
        """List files attached to conversation"""
        try:
            client = await get_postgresql_client()

            query = f"""
                SELECT id, filename, original_filename, content_type, file_size_bytes,
                       processing_status, uploaded_at, processed_at
                FROM {self.schema_name}.conversation_files
                WHERE conversation_id = $1
                ORDER BY uploaded_at DESC
            """

            rows = await client.execute_query(query, conversation_id)
            return [dict(row) for row in rows]

        except Exception as e:
            logger.error(f"Failed to list conversation files: {e}")
            return []

    async def delete_file(self, conversation_id: str, file_id: str, user_id: str, allow_post_message_deletion: bool = False) -> bool:
        """Delete specific file from conversation

        Args:
            conversation_id: The conversation ID
            file_id: The file ID to delete
            user_id: The user requesting deletion
            allow_post_message_deletion: If False, prevents deletion after messages exist (default: False)
        """
        try:
            logger.info(f"DELETE FILE CALLED: file_id={file_id}, conversation_id={conversation_id}, user_id={user_id}")

            # Validate access
            await self._validate_conversation_access(conversation_id, user_id)
            logger.info("DELETE FILE: Access validated")

            # Check if conversation has messages (unless explicitly allowed to delete post-message)
            if not allow_post_message_deletion:
                client = await get_postgresql_client()
                message_check_query = f"""
                    SELECT COUNT(*) as count
                    FROM {self.schema_name}.messages
                    WHERE conversation_id = $1
                """
                message_count_result = await client.fetch_one(message_check_query, conversation_id)
                message_count = message_count_result['count'] if message_count_result else 0

                if message_count > 0:
                    from fastapi import HTTPException
                    raise HTTPException(
                        status_code=400,
                        detail="Cannot delete files after conversation has started. Files are part of the conversation context."
                    )

            # Get file record for cleanup
            file_record = await self._get_file_record(file_id)
            logger.info(f"DELETE FILE: file_record={file_record}")
            if not file_record or str(file_record['conversation_id']) != conversation_id:
                logger.warning(f"DELETE FILE FAILED: file not found or conversation mismatch. file_record={file_record}, expected_conv_id={conversation_id}")
                return False

            # Delete from database
            client = await get_postgresql_client()
            query = f"""
                DELETE FROM {self.schema_name}.conversation_files
                WHERE id = $1 AND conversation_id = $2
            """

            rows_deleted = await client.execute_command(query, file_id, conversation_id)

            if rows_deleted > 0:
                # Delete file from disk
                file_path = Path(self.settings.file_storage_path) / file_record['file_path']
                if file_path.exists():
                    file_path.unlink()

                logger.info(f"Deleted conversation file: {file_id}")
                return True

            return False

        except HTTPException:
            raise  # Re-raise HTTPException to preserve status code and message
        except Exception as e:
            logger.error(f"Failed to delete conversation file: {e}")
            return False

    async def search_conversation_files(
        self,
        conversation_id: str,
        query: str,
        max_results: int = 5
    ) -> List[Dict[str, Any]]:
        """Search files within a conversation using vector similarity"""
        try:
            # Generate query embedding
            embedding_client = BGE_M3_EmbeddingClient()
            embeddings = await embedding_client.generate_embeddings([query])

            if not embeddings:
                return []

            query_embedding = embeddings[0]

            # Convert embedding to PostgreSQL vector format
            embedding_str = '[' + ','.join(map(str, query_embedding)) + ']'

            # Vector search against conversation files
            client = await get_postgresql_client()

            search_query = f"""
                SELECT id, filename, original_filename, processed_chunks,
                       1 - (embeddings <=> $1::vector) as similarity_score
                FROM {self.schema_name}.conversation_files
                WHERE conversation_id = $2
                AND processing_status = 'completed'
                AND embeddings IS NOT NULL
                AND 1 - (embeddings <=> $1::vector) > 0.1
                ORDER BY embeddings <=> $1::vector
                LIMIT $3
            """

            rows = await client.execute_query(
                search_query, embedding_str, conversation_id, max_results
            )

            results = []

            for row in rows:
                processed_chunks = row.get('processed_chunks', [])

                if not processed_chunks:
                    continue

                # Handle case where processed_chunks might be returned as JSON string
                if isinstance(processed_chunks, str):
                    import json
                    processed_chunks = json.loads(processed_chunks)

                for idx, chunk_text in enumerate(processed_chunks):
                    results.append({
                        'id': f"{row['id']}_chunk_{idx}",
                        'document_id': row['id'],
                        'document_name': row['original_filename'],
                        'original_filename': row['original_filename'],
                        'chunk_index': idx,
                        'content': chunk_text,
                        'similarity_score': row['similarity_score'],
                        'source': 'conversation_file',
                        'source_type': 'conversation_file'
                    })

                if len(results) >= max_results:
                    results = results[:max_results]
                    break

            logger.info(f"Found {len(results)} chunks from {len(rows)} matching conversation files")
            return results

        except Exception as e:
            logger.error(f"Failed to search conversation files: {e}")
            return []

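    # Note on the vector search above (standard pgvector semantics, stated as
    # a reminder rather than taken from the source): "embeddings <=> $1::vector"
    # is the cosine distance, so "1 - distance" is cosine similarity; the
    # "> 0.1" filter only drops very weak matches, and ordering by the raw
    # distance returns the closest chunks first.
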
    async def get_all_chunks_for_conversation(
        self,
        conversation_id: str,
        max_chunks_per_file: int = 50,
        max_total_chunks: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Retrieve ALL chunks from files attached to conversation.
        Non-query-dependent - returns everything up to limits.

        Args:
            conversation_id: UUID of conversation
            max_chunks_per_file: Limit per file (enforces diversity)
            max_total_chunks: Total chunk limit across all files

        Returns:
            List of chunks with metadata, grouped by file
        """
        try:
            client = await get_postgresql_client()

            query = f"""
                SELECT id, filename, original_filename, processed_chunks,
                       file_size_bytes, uploaded_at
                FROM {self.schema_name}.conversation_files
                WHERE conversation_id = $1
                AND processing_status = 'completed'
                AND processed_chunks IS NOT NULL
                ORDER BY uploaded_at ASC
            """

            rows = await client.execute_query(query, conversation_id)

            results = []
            total_chunks = 0

            for row in rows:
                if total_chunks >= max_total_chunks:
                    break

                processed_chunks = row.get('processed_chunks', [])

                # Handle JSON string if needed
                if isinstance(processed_chunks, str):
                    import json
                    processed_chunks = json.loads(processed_chunks)

                # Limit chunks per file (diversity enforcement)
                chunks_from_this_file = 0

                for idx, chunk_text in enumerate(processed_chunks):
                    if chunks_from_this_file >= max_chunks_per_file:
                        break
                    if total_chunks >= max_total_chunks:
                        break

                    results.append({
                        'id': f"{row['id']}_chunk_{idx}",
                        'document_id': row['id'],
                        'document_name': row['original_filename'],
                        'original_filename': row['original_filename'],
                        'chunk_index': idx,
                        'total_chunks': len(processed_chunks),
                        'content': chunk_text,
                        'file_size_bytes': row['file_size_bytes'],
                        'source': 'conversation_file',
                        'source_type': 'conversation_file'
                    })

                    chunks_from_this_file += 1
                    total_chunks += 1

            logger.info(f"Retrieved {len(results)} total chunks from {len(rows)} conversation files")
            return results

        except Exception as e:
            logger.error(f"Failed to get all chunks for conversation: {e}")
            return []

    async def _validate_conversation_access(self, conversation_id: str, user_id: str):
        """Validate user has access to conversation"""
        client = await get_postgresql_client()

        query = f"""
            SELECT id FROM {self.schema_name}.conversations
            WHERE id = $1 AND user_id = (
                SELECT id FROM {self.schema_name}.users WHERE email = $2 LIMIT 1
            )
        """

        result = await client.fetch_one(query, conversation_id, user_id)
        if not result:
            raise HTTPException(
                status_code=403,
                detail="Access denied: conversation not found or not owned by this user"
            )

    async def get_file_content(self, file_id: str, user_id: str) -> Optional[bytes]:
        """Get file content for download"""
        try:
            file_record = await self._get_file_record(file_id)
            if not file_record:
                return None

            # Validate access to conversation
            await self._validate_conversation_access(file_record['conversation_id'], user_id)

            # Read file content
            file_path = Path(self.settings.file_storage_path) / file_record['file_path']
            if file_path.exists():
                with open(file_path, "rb") as f:
                    return f.read()

            return None

        except Exception as e:
            logger.error(f"Failed to get file content: {e}")
            return None


# Factory function for service instances
def get_conversation_file_service(tenant_domain: str, user_id: str) -> ConversationFileService:
    """Get conversation file service instance"""
    return ConversationFileService(tenant_domain, user_id)
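
# Usage sketch (hedged, illustrative identifiers only): a FastAPI route could
# wire the factory above into an upload flow roughly like this.
#
#   service = get_conversation_file_service("acme.example.com", "user@example.com")
#   records = await service.upload_files(conversation_id, files, "user@example.com")
#   chunks = await service.search_conversation_files(conversation_id, "quarterly revenue")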
959
apps/tenant-backend/app/services/conversation_service.py
Normal file
@@ -0,0 +1,959 @@
"""
|
||||
Conversation Service for GT 2.0 Tenant Backend - PostgreSQL + PGVector
|
||||
|
||||
Manages AI-powered conversations with Agent integration using PostgreSQL directly.
|
||||
Handles message persistence, context management, and LLM inference.
|
||||
Replaces SQLAlchemy with direct PostgreSQL operations for GT 2.0 principles.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional, AsyncIterator, AsyncGenerator
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.services.agent_service import AgentService
|
||||
from app.core.resource_client import ResourceClusterClient
|
||||
from app.services.conversation_summarizer import ConversationSummarizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
"""PostgreSQL-based service for managing AI conversations"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str):
|
||||
"""Initialize with tenant and user isolation using PostgreSQL"""
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.settings = get_settings()
|
||||
self.agent_service = AgentService(tenant_domain, user_id)
|
||||
self.resource_client = ResourceClusterClient()
|
||||
self._resolved_user_uuid = None # Cache for resolved user UUID
|
||||
|
||||
logger.info(f"Conversation service initialized with PostgreSQL for {tenant_domain}/{user_id}")
|
||||
|
||||
async def _get_resolved_user_uuid(self, user_identifier: Optional[str] = None) -> str:
|
||||
"""
|
||||
Resolve user identifier to UUID with caching for performance.
|
||||
|
||||
This optimization reduces repeated database lookups by caching the resolved UUID.
|
||||
Performance impact: ~50% reduction in query time for operations with multiple queries.
|
||||
"""
|
||||
identifier = user_identifier or self.user_id
|
||||
|
||||
# Return cached UUID if already resolved for this instance
|
||||
if self._resolved_user_uuid and identifier == self.user_id:
|
||||
return self._resolved_user_uuid
|
||||
|
||||
# Check if already a UUID
|
||||
if not "@" in identifier:
|
||||
try:
|
||||
# Validate it's a proper UUID format
|
||||
uuid.UUID(identifier)
|
||||
if identifier == self.user_id:
|
||||
self._resolved_user_uuid = identifier
|
||||
return identifier
|
||||
except ValueError:
|
||||
pass # Not a valid UUID, treat as email/username
|
||||
|
||||
# Resolve email to UUID
|
||||
pg_client = await get_postgresql_client()
|
||||
query = "SELECT id FROM users WHERE email = $1 LIMIT 1"
|
||||
result = await pg_client.fetch_one(query, identifier)
|
||||
|
||||
if not result:
|
||||
raise ValueError(f"User not found: {identifier}")
|
||||
|
||||
user_uuid = str(result["id"])
|
||||
|
||||
# Cache if this is the service's primary user
|
||||
if identifier == self.user_id:
|
||||
self._resolved_user_uuid = user_uuid
|
||||
|
||||
return user_uuid
|
||||
|
||||
    def _get_user_clause(self, param_num: int, user_identifier: str) -> str:
        """
        DEPRECATED: Get the appropriate SQL clause for user identification.
        Use _get_resolved_user_uuid() instead for better performance.
        """
        if "@" in user_identifier:
            # Email - do lookup
            return f"(SELECT id FROM users WHERE email = ${param_num} LIMIT 1)"
        else:
            # UUID - use directly
            return f"${param_num}::uuid"

    async def create_conversation(
        self,
        agent_id: str,
        title: Optional[str],
        user_identifier: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create a new conversation with an agent using PostgreSQL"""
        try:
            # Resolve user UUID with caching (performance optimization)
            user_uuid = await self._get_resolved_user_uuid(user_identifier)

            # Get agent configuration
            agent_data = await self.agent_service.get_agent(agent_id)
            if not agent_data:
                raise ValueError(f"Agent {agent_id} not found")

            # Validate tenant has access to the agent's model
            agent_model = agent_data.get("model")
            if agent_model:
                available_models = await self.get_available_models(self.tenant_domain)
                available_model_ids = [m["model_id"] for m in available_models]

                if agent_model not in available_model_ids:
                    raise ValueError(f"Agent model '{agent_model}' is not accessible to tenant '{self.tenant_domain}'. Available models: {', '.join(available_model_ids)}")

                logger.info(f"Validated tenant access to model '{agent_model}' for agent '{agent_data.get('name')}'")
            else:
                logger.warning(f"Agent {agent_id} has no model configured, will use default")

            # Get PostgreSQL client
            pg_client = await get_postgresql_client()

            # Generate conversation ID
            conversation_id = str(uuid.uuid4())

            # Create conversation in PostgreSQL (optimized: use resolved UUID directly)
            query = """
                INSERT INTO conversations (
                    id, title, tenant_id, user_id, agent_id, summary,
                    total_messages, total_tokens, metadata, is_archived,
                    created_at, updated_at
                ) VALUES (
                    $1, $2,
                    (SELECT id FROM tenants WHERE domain = $3 LIMIT 1),
                    $4::uuid,
                    $5, '', 0, 0, '{}', false, NOW(), NOW()
                )
                RETURNING id, title, tenant_id, user_id, agent_id, created_at, updated_at
            """

            conv_title = title or f"Conversation with {agent_data.get('name', 'Agent')}"

            conversation_data = await pg_client.fetch_one(
                query,
                conversation_id, conv_title, self.tenant_domain,
                user_uuid, agent_id
            )

            if not conversation_data:
                raise RuntimeError("Failed to create conversation - no data returned")

            # Note: conversation_settings and conversation_participants are now created automatically
            # by the auto_create_conversation_settings trigger, so we don't need to create them manually

            # Get the model_id from the auto-created settings or use agent's model
            settings_query = """
                SELECT model_id FROM conversation_settings WHERE conversation_id = $1
            """
            settings_data = await pg_client.fetch_one(settings_query, conversation_id)
            model_id = settings_data["model_id"] if settings_data else agent_model

            result = {
                "id": str(conversation_data["id"]),
                "title": conversation_data["title"],
                "agent_id": str(conversation_data["agent_id"]),
                "model_id": model_id,
                "created_at": conversation_data["created_at"].isoformat(),
                "user_id": user_uuid,
                "tenant_domain": self.tenant_domain
            }

            logger.info(f"Created conversation {conversation_id} in PostgreSQL for user {user_uuid}")
            return result

        except Exception as e:
            logger.error(f"Failed to create conversation: {e}")
            raise

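    # Usage sketch for create_conversation (hedged; the agent_id and returned
    # fields shown are illustrative, mirroring the dict built above):
    #
    #   svc = ConversationService("acme.example.com", "user@example.com")
    #   conv = await svc.create_conversation(agent_id="agent-uuid", title=None)
    #   # conv -> {"id": "...", "title": "Conversation with ...", "model_id": "...", ...}
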
    async def list_conversations(
        self,
        user_identifier: str,
        agent_id: Optional[str] = None,
        search: Optional[str] = None,
        time_filter: str = "all",
        limit: int = 20,
        offset: int = 0
    ) -> Dict[str, Any]:
        """List conversations for a user using PostgreSQL with server-side filtering"""
        try:
            # Resolve user UUID with caching (performance optimization)
            user_uuid = await self._get_resolved_user_uuid(user_identifier)

            pg_client = await get_postgresql_client()

            # Build query with optional filters - exclude archived conversations (optimized: use cached UUID)
            where_clause = "WHERE c.user_id = $1::uuid AND c.is_archived = false"
            params = [user_uuid]
            param_count = 1

            # Time filter
            if time_filter != "all":
                if time_filter == "today":
                    where_clause += " AND c.updated_at >= NOW() - INTERVAL '1 day'"
                elif time_filter == "week":
                    where_clause += " AND c.updated_at >= NOW() - INTERVAL '7 days'"
                elif time_filter == "month":
                    where_clause += " AND c.updated_at >= NOW() - INTERVAL '30 days'"

            # Agent filter
            if agent_id:
                param_count += 1
                where_clause += f" AND c.agent_id = ${param_count}"
                params.append(agent_id)

            # Search filter (case-insensitive title search)
            if search:
                param_count += 1
                where_clause += f" AND c.title ILIKE ${param_count}"
                params.append(f"%{search}%")

            # Get conversations with agent info and unread counts (optimized: use cached UUID)
            query = f"""
                SELECT
                    c.id, c.title, c.agent_id, c.created_at, c.updated_at,
                    c.total_messages, c.total_tokens, c.is_archived,
                    a.name as agent_name,
                    COUNT(m.id) FILTER (
                        WHERE m.created_at > COALESCE((c.metadata->>'last_read_at')::timestamptz, c.created_at)
                        AND m.user_id != $1::uuid
                    ) as unread_count
                FROM conversations c
                LEFT JOIN agents a ON c.agent_id = a.id
                LEFT JOIN messages m ON m.conversation_id = c.id
                {where_clause}
                GROUP BY c.id, c.title, c.agent_id, c.created_at, c.updated_at,
                         c.total_messages, c.total_tokens, c.is_archived, a.name
                ORDER BY
                    CASE WHEN COUNT(m.id) FILTER (
                        WHERE m.created_at > COALESCE((c.metadata->>'last_read_at')::timestamptz, c.created_at)
                        AND m.user_id != $1::uuid
                    ) > 0 THEN 0 ELSE 1 END,
                    c.updated_at DESC
                LIMIT ${param_count + 1} OFFSET ${param_count + 2}
            """
            params.extend([limit, offset])

            conversations = await pg_client.execute_query(query, *params)

            # Get total count
            count_query = f"""
                SELECT COUNT(*) as total
                FROM conversations c
                {where_clause}
            """
            count_result = await pg_client.fetch_one(count_query, *params[:-2])  # Exclude limit/offset
            total = count_result["total"] if count_result else 0

            # Format results with lightweight fields including unread count
            conversation_list = [
                {
                    "id": str(conv["id"]),
                    "title": conv["title"],
                    "agent_id": str(conv["agent_id"]) if conv["agent_id"] else None,
                    "agent_name": conv["agent_name"] or "AI Assistant",
                    "created_at": conv["created_at"].isoformat(),
                    "updated_at": conv["updated_at"].isoformat(),
                    "last_message_at": conv["updated_at"].isoformat(),  # Use updated_at as last activity
                    "message_count": conv["total_messages"] or 0,
                    "token_count": conv["total_tokens"] or 0,
                    "is_archived": conv["is_archived"],
                    "unread_count": conv.get("unread_count", 0) or 0  # Include unread count
                    # Removed preview field for performance
                }
                for conv in conversations
            ]

            return {
                "conversations": conversation_list,
                "total": total,
                "limit": limit,
                "offset": offset
            }

        except Exception as e:
            logger.error(f"Failed to list conversations: {e}")
            raise

    async def get_conversation(
        self,
        conversation_id: str,
        user_identifier: str
    ) -> Optional[Dict[str, Any]]:
        """Get a specific conversation with details"""
        try:
            # Resolve user UUID with caching (performance optimization)
            user_uuid = await self._get_resolved_user_uuid(user_identifier)

            pg_client = await get_postgresql_client()

            query = """
                SELECT
                    c.id, c.title, c.agent_id, c.created_at, c.updated_at,
                    c.total_messages, c.total_tokens, c.is_archived, c.summary,
                    a.name as agent_name,
                    cs.model_id, cs.temperature, cs.max_tokens, cs.system_prompt
                FROM conversations c
                LEFT JOIN agents a ON c.agent_id = a.id
                LEFT JOIN conversation_settings cs ON c.id = cs.conversation_id
                WHERE c.id = $1
                AND c.user_id = $2::uuid
                LIMIT 1
            """

            conversation = await pg_client.fetch_one(query, conversation_id, user_uuid)

            if not conversation:
                return None

            return {
                "id": conversation["id"],
                "title": conversation["title"],
                "agent_id": conversation["agent_id"],
                "agent_name": conversation["agent_name"],
                "model_id": conversation["model_id"],
                "temperature": float(conversation["temperature"]) if conversation["temperature"] else 0.7,
                "max_tokens": conversation["max_tokens"],
                "system_prompt": conversation["system_prompt"],
                "summary": conversation["summary"],
                "message_count": conversation["total_messages"],
                "token_count": conversation["total_tokens"],
                "is_archived": conversation["is_archived"],
                "created_at": conversation["created_at"].isoformat(),
                "updated_at": conversation["updated_at"].isoformat()
            }

        except Exception as e:
            logger.error(f"Failed to get conversation {conversation_id}: {e}")
            return None

async def add_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
user_identifier: str,
|
||||
model_used: Optional[str] = None,
|
||||
token_count: int = 0,
|
||||
metadata: Optional[Dict] = None,
|
||||
attachments: Optional[List] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a message to a conversation"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Insert message (optimized: use cached UUID)
|
||||
query = """
|
||||
INSERT INTO messages (
|
||||
id, conversation_id, user_id, role, content,
|
||||
content_type, token_count, model_used, metadata, attachments, created_at
|
||||
) VALUES (
|
||||
$1, $2, $3::uuid,
|
||||
$4, $5, 'text', $6, $7, $8, $9, NOW()
|
||||
)
|
||||
RETURNING id, created_at
|
||||
"""
|
||||
|
||||
message_data = await pg_client.fetch_one(
|
||||
query,
|
||||
message_id, conversation_id, user_uuid,
|
||||
role, content, token_count, model_used,
|
||||
json.dumps(metadata or {}), json.dumps(attachments or [])
|
||||
)
|
||||
|
||||
if not message_data:
|
||||
raise RuntimeError("Failed to add message - no data returned")
|
||||
|
||||
# Update conversation totals (optimized: use cached UUID)
|
||||
update_query = """
|
||||
UPDATE conversations
|
||||
SET total_messages = total_messages + 1,
|
||||
total_tokens = total_tokens + $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND user_id = $2::uuid
|
||||
"""
|
||||
|
||||
await pg_client.execute_command(update_query, conversation_id, user_uuid, token_count)
|
||||
|
||||
result = {
|
||||
"id": message_data["id"],
|
||||
"conversation_id": conversation_id,
|
||||
"role": role,
|
||||
"content": content,
|
||||
"token_count": token_count,
|
||||
"model_used": model_used,
|
||||
"metadata": metadata or {},
|
||||
"attachments": attachments or [],
|
||||
"created_at": message_data["created_at"].isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Added message {message_id} to conversation {conversation_id}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add message to conversation {conversation_id}: {e}")
|
||||
raise
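# Illustrative call sketch only: assumes `svc` is an already-initialized instance
# of this conversation service and that the conversation and user below exist.
async def _example_add_message(svc):
    message = await svc.add_message(
        conversation_id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
        role="user",
        content="Summarize the attached report.",
        user_identifier="user@example.com",
        metadata={"source": "web"},
        attachments=[{"name": "report.pdf"}],
    )
    return message["id"], message["created_at"]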
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get messages for a conversation"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
m.id, m.role, m.content, m.content_type, m.token_count,
|
||||
m.model_used, m.finish_reason, m.metadata, m.attachments, m.created_at
|
||||
FROM messages m
|
||||
JOIN conversations c ON m.conversation_id = c.id
|
||||
WHERE c.id = $1
|
||||
AND c.user_id = $2::uuid
|
||||
ORDER BY m.created_at ASC
|
||||
LIMIT $3 OFFSET $4
|
||||
"""
|
||||
|
||||
messages = await pg_client.execute_query(query, conversation_id, user_uuid, limit, offset)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": msg["id"],
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
"content_type": msg["content_type"],
|
||||
"token_count": msg["token_count"],
|
||||
"model_used": msg["model_used"],
|
||||
"finish_reason": msg["finish_reason"],
|
||||
"metadata": (
|
||||
json.loads(msg["metadata"]) if isinstance(msg["metadata"], str)
|
||||
else (msg["metadata"] if isinstance(msg["metadata"], dict) else {})
|
||||
),
|
||||
"attachments": (
|
||||
json.loads(msg["attachments"]) if isinstance(msg["attachments"], str)
|
||||
else (msg["attachments"] if isinstance(msg["attachments"], list) else [])
|
||||
),
|
||||
"context_sources": (
|
||||
(json.loads(msg["metadata"]) if isinstance(msg["metadata"], str) else msg["metadata"]).get("context_sources", [])
|
||||
if (isinstance(msg["metadata"], str) or isinstance(msg["metadata"], dict))
|
||||
else []
|
||||
),
|
||||
"created_at": msg["created_at"].isoformat()
|
||||
}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages for conversation {conversation_id}: {e}")
|
||||
return []
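# Small standalone helpers expressing the metadata/attachments normalization the
# list comprehension above repeats: accept a JSON string, an already-decoded
# value, or NULL, and always return the expected container type.
import json
from typing import Any

def _as_dict(value: Any) -> dict:
    if isinstance(value, str):
        return json.loads(value)
    return value if isinstance(value, dict) else {}

def _as_list(value: Any) -> list:
    if isinstance(value, str):
        return json.loads(value)
    return value if isinstance(value, list) else []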
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
content: str,
|
||||
user_identifier: Optional[str] = None,
|
||||
stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a message to conversation and get AI response"""
|
||||
user_id = user_identifier or self.user_id
|
||||
|
||||
# Check if this is the first message
|
||||
existing_messages = await self.get_messages(conversation_id, user_id)
|
||||
is_first_message = len(existing_messages) == 0
|
||||
|
||||
# Add user message
|
||||
user_message = await self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=content,
|
||||
user_identifier=user_id
|
||||
)
|
||||
|
||||
# Get conversation details for agent
|
||||
conversation = await self.get_conversation(conversation_id, user_id)
|
||||
agent_id = conversation.get("agent_id") if conversation else None
|
||||
|
||||
ai_message = None
|
||||
if agent_id:
|
||||
agent_data = await self.agent_service.get_agent(agent_id)
|
||||
|
||||
# Prepare messages for AI
|
||||
messages = [
|
||||
{"role": "system", "content": agent_data.get("prompt_template", "You are a helpful assistant.")},
|
||||
{"role": "user", "content": content}
|
||||
]
|
||||
|
||||
# Get AI response
|
||||
ai_response = await self.get_ai_response(
|
||||
model=agent_data.get("model", "llama-3.1-8b-instant"),
|
||||
messages=messages,
|
||||
tenant_id=self.tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
ai_content = ai_response["choices"][0]["message"]["content"]
|
||||
|
||||
# Add AI message
|
||||
ai_message = await self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="agent",
|
||||
content=ai_content,
|
||||
user_identifier=user_id,
|
||||
model_used=agent_data.get("model"),
|
||||
token_count=ai_response["usage"]["total_tokens"]
|
||||
)
|
||||
|
||||
return {
|
||||
"user_message": user_message,
|
||||
"ai_message": ai_message,
|
||||
"is_first_message": is_first_message,
|
||||
"conversation_id": conversation_id
|
||||
}
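# Illustrative caller sketch (names assumed, not part of this module): after the
# first exchange, a route handler can use the is_first_message flag returned
# above to trigger title generation.
async def _example_send_and_title(svc, conversation_id: str, text: str, user: str):
    result = await svc.send_message(conversation_id, text, user_identifier=user)
    if result["is_first_message"]:
        await svc.auto_generate_conversation_title(conversation_id, user)
    return result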
|
||||
|
||||
async def update_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
title: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Update conversation properties like title"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Build dynamic update query based on provided fields
|
||||
update_fields = []
|
||||
params = []
|
||||
param_count = 1
|
||||
|
||||
if title is not None:
|
||||
update_fields.append(f"title = ${param_count}")
|
||||
params.append(title)
|
||||
param_count += 1
|
||||
|
||||
if not update_fields:
|
||||
return True # Nothing to update
|
||||
|
||||
# Add updated_at timestamp
|
||||
update_fields.append(f"updated_at = NOW()")
|
||||
|
||||
query = f"""
|
||||
UPDATE conversations
|
||||
SET {', '.join(update_fields)}
|
||||
WHERE id = ${param_count}
|
||||
AND user_id = ${param_count + 1}::uuid
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
params.extend([conversation_id, user_uuid])
|
||||
|
||||
result = await pg_client.fetch_scalar(query, *params)
|
||||
|
||||
if result:
|
||||
logger.info(f"Updated conversation {conversation_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update conversation {conversation_id}: {e}")
|
||||
return False
|
||||
|
||||
async def auto_generate_conversation_title(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str
|
||||
) -> Optional[str]:
|
||||
"""Generate conversation title based on first user prompt and agent response pair"""
|
||||
try:
|
||||
# Get only the first few messages (first exchange)
|
||||
messages = await self.get_messages(conversation_id, user_identifier, limit=2)
|
||||
|
||||
if not messages or len(messages) < 2:
|
||||
return None # Need at least one user-agent exchange
|
||||
|
||||
# Only use first user message and first agent response for title
|
||||
first_exchange = messages[:2]
|
||||
|
||||
# Generate title using the summarization service
|
||||
from app.services.conversation_summarizer import generate_conversation_title
|
||||
new_title = await generate_conversation_title(first_exchange, self.tenant_domain, user_identifier)
|
||||
|
||||
# Update the conversation with the generated title
|
||||
success = await self.update_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=user_identifier,
|
||||
title=new_title
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Auto-generated title '{new_title}' for conversation {conversation_id} based on first exchange")
|
||||
return new_title
|
||||
else:
|
||||
logger.warning(f"Failed to update conversation {conversation_id} with generated title")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-generate title for conversation {conversation_id}: {e}")
|
||||
return None
|
||||
|
||||
async def delete_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str
|
||||
) -> bool:
|
||||
"""Soft delete a conversation (archive it)"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
query = """
|
||||
UPDATE conversations
|
||||
SET is_archived = true, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND user_id = $2::uuid
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
result = await pg_client.fetch_scalar(query, conversation_id, user_uuid)
|
||||
|
||||
if result:
|
||||
logger.info(f"Archived conversation {conversation_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to archive conversation {conversation_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_ai_response(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, str]],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
top_p: float = 1.0,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get AI response from Resource Cluster"""
|
||||
try:
|
||||
# Prepare request for Resource Cluster
|
||||
request_data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
request_data["tools"] = tools
|
||||
if tool_choice:
|
||||
request_data["tool_choice"] = tool_choice
|
||||
|
||||
# Call Resource Cluster AI inference endpoint
|
||||
response = await self.resource_client.call_inference_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint="chat/completions",
|
||||
data=request_data
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get AI response: {e}")
|
||||
raise
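# Sketch of the OpenAI-style chat/completions payload assumed above; values are
# placeholders, and the response shape mirrors how send_message() reads it.
EXAMPLE_CHAT_REQUEST = {
    "model": "llama-3.1-8b-instant",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
    "temperature": 0.7,
    "max_tokens": None,
    "top_p": 1.0,
}

def extract_reply(response):
    """Pull the assistant text and total token usage out of a completion response."""
    return (
        response["choices"][0]["message"]["content"],
        response["usage"]["total_tokens"],
    )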
|
||||
|
||||
# Streaming removed for reliability - using non-streaming only
|
||||
|
||||
async def get_available_models(self, tenant_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get available models for tenant from Resource Cluster"""
|
||||
try:
|
||||
# Get models dynamically from Resource Cluster
|
||||
import aiohttp
|
||||
|
||||
resource_cluster_url = self.resource_client.base_url
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Get capability token for model access
|
||||
token = await self.resource_client._get_capability_token(
|
||||
tenant_id=tenant_id,
|
||||
user_id=self.user_id,
|
||||
resources=['model_registry']
|
||||
)
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json',
|
||||
'X-Tenant-ID': tenant_id,
|
||||
'X-User-ID': self.user_id
|
||||
}
|
||||
|
||||
async with session.get(
|
||||
f"{resource_cluster_url}/api/v1/models/",
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
response_data = await response.json()
|
||||
models_data = response_data.get("models", [])
|
||||
|
||||
# Transform Resource Cluster model format to frontend format
|
||||
available_models = []
|
||||
for model in models_data:
|
||||
# Only include available models
|
||||
if model.get("status", {}).get("deployment") == "available":
|
||||
available_models.append({
|
||||
"id": model.get("uuid"), # Database UUID for unique identification
|
||||
"model_id": model["id"], # model_id string for API calls
|
||||
"name": model["name"],
|
||||
"provider": model["provider"],
|
||||
"model_type": model["model_type"],
|
||||
"context_window": model.get("performance", {}).get("context_window", 4000),
|
||||
"max_tokens": model.get("performance", {}).get("max_tokens", 4000),
|
||||
"performance": model.get("performance", {}), # Include full performance for chat.py
|
||||
"capabilities": {"chat": True} # All LLM models support chat
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(available_models)} models from Resource Cluster")
|
||||
return available_models
|
||||
else:
|
||||
logger.error(f"Resource Cluster returned {response.status}: {await response.text()}")
|
||||
raise RuntimeError(f"Resource Cluster API error: {response.status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get models from Resource Cluster: {e}")
|
||||
raise
|
||||
|
||||
async def get_conversation_datasets(self, conversation_id: str, user_identifier: str) -> List[str]:
|
||||
"""Get dataset IDs attached to a conversation"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Ensure proper schema qualification
|
||||
schema_name = f"tenant_{self.tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
|
||||
query = f"""
|
||||
SELECT cd.dataset_id
|
||||
FROM {schema_name}.conversations c
|
||||
JOIN {schema_name}.conversation_datasets cd ON cd.conversation_id = c.id
|
||||
WHERE c.id = $1
|
||||
AND c.user_id = (SELECT id FROM {schema_name}.users WHERE email = $2 LIMIT 1)
|
||||
AND cd.is_active = true
|
||||
ORDER BY cd.attached_at ASC
|
||||
"""
|
||||
|
||||
rows = await pg_client.execute_query(query, conversation_id, user_identifier)
|
||||
dataset_ids = [str(row['dataset_id']) for row in rows]
|
||||
|
||||
logger.info(f"Found {len(dataset_ids)} datasets for conversation {conversation_id}")
|
||||
return dataset_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation datasets: {e}")
|
||||
return []
|
||||
|
||||
async def add_datasets_to_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
dataset_ids: List[str],
|
||||
source: str = "user_selected"
|
||||
) -> bool:
|
||||
"""Add datasets to a conversation"""
|
||||
try:
|
||||
if not dataset_ids:
|
||||
return True
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Ensure proper schema qualification
|
||||
schema_name = f"tenant_{self.tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
|
||||
# Get user ID first
|
||||
user_query = f"SELECT id FROM {schema_name}.users WHERE email = $1 LIMIT 1"
|
||||
user_result = await pg_client.fetch_scalar(user_query, user_identifier)
|
||||
|
||||
if not user_result:
|
||||
logger.error(f"User not found: {user_identifier}")
|
||||
return False
|
||||
|
||||
user_id = user_result
|
||||
|
||||
# Insert dataset attachments (ON CONFLICT DO NOTHING to avoid duplicates)
|
||||
values_list = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
for dataset_id in dataset_ids:
|
||||
values_list.append(f"(${param_idx}, ${param_idx + 1}, ${param_idx + 2})")
|
||||
params.extend([conversation_id, dataset_id, user_id])
|
||||
param_idx += 3
|
||||
|
||||
query = f"""
|
||||
INSERT INTO {schema_name}.conversation_datasets (conversation_id, dataset_id, attached_by)
|
||||
VALUES {', '.join(values_list)}
|
||||
ON CONFLICT (conversation_id, dataset_id) DO UPDATE SET
|
||||
is_active = true,
|
||||
attached_at = NOW()
|
||||
"""
|
||||
|
||||
await pg_client.execute_query(query, *params)
|
||||
|
||||
logger.info(f"Added {len(dataset_ids)} datasets to conversation {conversation_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add datasets to conversation: {e}")
|
||||
return False
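# Standalone sketch of the multi-row VALUES placeholder numbering used above:
# three parameters per row, numbered $1.. consecutively.
def build_values_clause(rows):
    """rows: list of (conversation_id, dataset_id, attached_by) tuples."""
    values_list, params = [], []
    idx = 1
    for row in rows:
        values_list.append(f"(${idx}, ${idx + 1}, ${idx + 2})")
        params.extend(row)
        idx += 3
    return ", ".join(values_list), params

# build_values_clause([("c1", "d1", "u1"), ("c1", "d2", "u1")])
# -> ("($1, $2, $3), ($4, $5, $6)", ["c1", "d1", "u1", "c1", "d2", "u1"])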
|
||||
|
||||
async def copy_agent_datasets_to_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
agent_id: str
|
||||
) -> bool:
|
||||
"""Copy an agent's default datasets to a new conversation"""
|
||||
try:
|
||||
# Get agent's selected dataset IDs from config
|
||||
from app.services.agent_service import AgentService
|
||||
agent_service = AgentService(self.tenant_domain, user_identifier)
|
||||
agent_data = await agent_service.get_agent(agent_id)
|
||||
|
||||
if not agent_data:
|
||||
logger.warning(f"Agent {agent_id} not found")
|
||||
return False
|
||||
|
||||
# Get selected_dataset_ids from agent config
|
||||
selected_dataset_ids = agent_data.get('selected_dataset_ids', [])
|
||||
|
||||
if not selected_dataset_ids:
|
||||
logger.info(f"Agent {agent_id} has no default datasets")
|
||||
return True
|
||||
|
||||
# Add agent's datasets to conversation
|
||||
success = await self.add_datasets_to_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=user_identifier,
|
||||
dataset_ids=selected_dataset_ids,
|
||||
source="agent_default"
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Copied {len(selected_dataset_ids)} datasets from agent {agent_id} to conversation {conversation_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to copy agent datasets: {e}")
|
||||
return False
|
||||
|
||||
async def get_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""Get recent conversations ordered by last activity"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Handle both email and UUID formats using existing pattern
|
||||
user_clause = self._get_user_clause(1, user_id)
|
||||
|
||||
query = f"""
|
||||
SELECT c.id, c.title, c.created_at, c.updated_at,
|
||||
COUNT(m.id) as message_count,
|
||||
MAX(m.created_at) as last_message_at,
|
||||
a.name as agent_name
|
||||
FROM conversations c
|
||||
LEFT JOIN messages m ON m.conversation_id = c.id
|
||||
LEFT JOIN agents a ON a.id = c.agent_id
|
||||
WHERE c.user_id = {user_clause}
|
||||
AND c.is_archived = false
|
||||
GROUP BY c.id, c.title, c.created_at, c.updated_at, a.name
|
||||
ORDER BY COALESCE(MAX(m.created_at), c.created_at) DESC
|
||||
LIMIT $2
|
||||
"""
|
||||
|
||||
rows = await pg_client.execute_query(query, user_id, limit)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get recent conversations: {e}")
|
||||
return []
|
||||
|
||||
async def mark_conversation_read(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str
|
||||
) -> bool:
|
||||
"""
|
||||
Mark a conversation as read by updating last_read_at in metadata.
|
||||
|
||||
Args:
|
||||
conversation_id: UUID of the conversation
|
||||
user_identifier: User email or UUID
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Update last_read_at in conversation metadata
|
||||
query = """
|
||||
UPDATE conversations
|
||||
SET metadata = jsonb_set(
|
||||
COALESCE(metadata, '{}'::jsonb),
|
||||
'{last_read_at}',
|
||||
to_jsonb(NOW()::text)
|
||||
)
|
||||
WHERE id = $1
|
||||
AND user_id = $2::uuid
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
result = await pg_client.fetch_one(query, conversation_id, user_uuid)
|
||||
|
||||
if result:
|
||||
logger.info(f"Marked conversation {conversation_id} as read for user {user_identifier}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Conversation {conversation_id} not found or access denied for user {user_identifier}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark conversation as read: {e}")
|
||||
return False
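# Hedged sketch only: one way a caller could derive an unread count from the
# last_read_at marker written above. The exact expression list_conversations
# uses is defined elsewhere and may differ; this is illustrative, not canonical.
UNREAD_COUNT_SKETCH = """
SELECT COUNT(*) AS unread_count
FROM messages m
WHERE m.conversation_id = $1
  AND m.created_at > COALESCE(
        (SELECT (c.metadata->>'last_read_at')::timestamptz
         FROM conversations c
         WHERE c.id = $1),
        'epoch'::timestamptz
      )
"""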
|
||||
200
apps/tenant-backend/app/services/conversation_summarizer.py
Normal file
200
apps/tenant-backend/app/services/conversation_summarizer.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Conversation Summarization Service for GT 2.0
|
||||
|
||||
Automatically generates meaningful conversation titles using a specialized
|
||||
summarization agent with llama-3.1-8b-instant.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.resource_client import ResourceClusterClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class ConversationSummarizer:
|
||||
"""Service for generating conversation summaries and titles"""
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.resource_client = ResourceClusterClient()
|
||||
self.summarization_model = "llama-3.1-8b-instant"
|
||||
|
||||
async def generate_conversation_title(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Generate a concise conversation title based on the conversation content.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries from the conversation
|
||||
|
||||
Returns:
|
||||
Generated conversation title (3-6 words)
|
||||
"""
|
||||
try:
|
||||
# Extract conversation context for summarization
|
||||
conversation_text = self._prepare_conversation_context(messages)
|
||||
|
||||
if not conversation_text.strip():
|
||||
return "New Conversation"
|
||||
|
||||
# Generate title using specialized summarization prompt
|
||||
title = await self._call_summarization_agent(conversation_text)
|
||||
|
||||
# Validate and clean the generated title
|
||||
clean_title = self._clean_title(title)
|
||||
|
||||
logger.info(f"Generated conversation title: '{clean_title}' from {len(messages)} messages")
|
||||
return clean_title
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating conversation title: {e}")
|
||||
return self._fallback_title(messages)
|
||||
|
||||
def _prepare_conversation_context(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""Prepare conversation context for summarization"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Limit to first few exchanges for title generation
|
||||
context_messages = messages[:6] # First 3 user-agent exchanges
|
||||
|
||||
context_parts = []
|
||||
for msg in context_messages:
|
||||
role = "User" if msg.get("role") == "user" else "Agent"
|
||||
# Truncate very long messages for context
|
||||
content = msg.get("content", "")
|
||||
content = (content[:200] + "...") if len(content) > 200 else content
|
||||
context_parts.append(f"{role}: {content}")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
async def _call_summarization_agent(self, conversation_text: str) -> str:
|
||||
"""Call the resource cluster AI inference for summarization"""
|
||||
|
||||
summarization_prompt = f"""You are a conversation title generator. Your job is to create concise, descriptive titles for conversations.
|
||||
|
||||
Given this conversation:
|
||||
---
|
||||
{conversation_text}
|
||||
---
|
||||
|
||||
Generate a title that:
|
||||
- Is 3-6 words maximum
|
||||
- Captures the main topic or purpose
|
||||
- Is clear and descriptive
|
||||
- Uses title case
|
||||
- Does NOT include quotes or special characters
|
||||
|
||||
Examples of good titles:
|
||||
- "Python Code Review"
|
||||
- "Database Migration Help"
|
||||
- "React Component Design"
|
||||
- "System Architecture Discussion"
|
||||
|
||||
Title:"""
|
||||
|
||||
request_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": summarization_prompt
|
||||
}
|
||||
],
|
||||
"model": self.summarization_model,
|
||||
"temperature": 0.3, # Lower temperature for consistent, focused titles
|
||||
"max_tokens": 20, # Short response for title generation
|
||||
"stream": False
|
||||
}
|
||||
|
||||
try:
|
||||
# Use the resource client instead of direct HTTP calls
|
||||
result = await self.resource_client.call_inference_endpoint(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
endpoint="chat/completions",
|
||||
data=request_data
|
||||
)
|
||||
|
||||
if result and "choices" in result and len(result["choices"]) > 0:
|
||||
title = result["choices"][0]["message"]["content"].strip()
|
||||
return title
|
||||
else:
|
||||
logger.error("Invalid response format from summarization API")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling summarization agent: {e}")
|
||||
return ""
|
||||
|
||||
def _clean_title(self, raw_title: str) -> str:
|
||||
"""Clean and validate the generated title"""
|
||||
if not raw_title:
|
||||
return "New Conversation"
|
||||
|
||||
# Remove quotes, extra whitespace, and special characters
|
||||
cleaned = raw_title.strip().strip('"\'').strip()
|
||||
|
||||
# Remove common prefixes that AI might add
|
||||
prefixes_to_remove = [
|
||||
"Title:", "title:", "TITLE:",
|
||||
"Conversation:", "conversation:",
|
||||
"Topic:", "topic:",
|
||||
"Subject:", "subject:"
|
||||
]
|
||||
|
||||
for prefix in prefixes_to_remove:
|
||||
if cleaned.startswith(prefix):
|
||||
cleaned = cleaned[len(prefix):].strip()
|
||||
|
||||
# Limit length and ensure it's reasonable
|
||||
if len(cleaned) > 50:
|
||||
cleaned = cleaned[:47] + "..."
|
||||
|
||||
# Ensure it's not empty after cleaning
|
||||
if not cleaned or len(cleaned.split()) > 8:
|
||||
return "New Conversation"
|
||||
|
||||
return cleaned
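# Illustrative examples of the cleaning rules above, derived from this method:
#   _clean_title('"Database Migration Help"')  -> 'Database Migration Help'
#   _clean_title('Title: Python Code Review')  -> 'Python Code Review'
#   _clean_title('')                           -> 'New Conversation'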
|
||||
|
||||
def _fallback_title(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""Generate fallback title when AI summarization fails"""
|
||||
if not messages:
|
||||
return "New Conversation"
|
||||
|
||||
# Try to use the first user message for context
|
||||
first_user_msg = next((msg for msg in messages if msg.get("role") == "user"), None)
|
||||
|
||||
if first_user_msg and first_user_msg.get("content"):
|
||||
# Extract first few words from the user's message
|
||||
words = first_user_msg["content"].strip().split()[:4]
|
||||
if len(words) >= 2:
|
||||
fallback = " ".join(words).capitalize()
|
||||
# Remove common question words for cleaner titles
|
||||
for word in ["How", "What", "Can", "Could", "Please", "Help"]:
|
||||
if fallback.startswith(word + " "):
|
||||
fallback = fallback[len(word):].strip()
|
||||
break
|
||||
return fallback if fallback else "New Conversation"
|
||||
|
||||
return "New Conversation"
|
||||
|
||||
|
||||
async def generate_conversation_title(messages: List[Dict[str, Any]], tenant_id: str, user_id: str) -> str:
|
||||
"""
|
||||
Convenience function to generate a conversation title.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries from the conversation
|
||||
tenant_id: Tenant identifier
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
Generated conversation title
|
||||
"""
|
||||
summarizer = ConversationSummarizer(tenant_id, user_id)
|
||||
return await summarizer.generate_conversation_title(messages)
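# Example usage sketch; tenant and user values are placeholders, and the message
# dicts mirror what the conversation service's get_messages() returns.
async def _example_generate_title():
    first_exchange = [
        {"role": "user", "content": "How do I tune pgvector indexes?"},
        {"role": "agent", "content": "Start by checking the lists parameter..."},
    ]
    return await generate_conversation_title(
        first_exchange, tenant_id="acme.example.com", user_id="user@acme.example.com"
    )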
|
||||
1064
apps/tenant-backend/app/services/dataset_service.py
Normal file
1064
apps/tenant-backend/app/services/dataset_service.py
Normal file
File diff suppressed because it is too large
585
apps/tenant-backend/app/services/dataset_sharing.py
Normal file
585
apps/tenant-backend/app/services/dataset_sharing.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
Dataset Sharing Service for GT 2.0
|
||||
|
||||
Implements hierarchical dataset sharing with perfect tenant isolation.
|
||||
Enables secure data collaboration while maintaining ownership and access control.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from uuid import uuid4
|
||||
|
||||
from app.models.access_group import AccessGroup, Resource
|
||||
from app.services.access_controller import AccessController
|
||||
from app.core.security import verify_capability_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharingPermission(Enum):
|
||||
"""Sharing permission levels"""
|
||||
READ = "read" # Can view and search dataset
|
||||
WRITE = "write" # Can add documents
|
||||
ADMIN = "admin" # Can modify sharing settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetShare:
|
||||
"""Dataset sharing configuration"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
dataset_id: str = ""
|
||||
owner_id: str = ""
|
||||
access_group: AccessGroup = AccessGroup.INDIVIDUAL
|
||||
team_members: List[str] = field(default_factory=list)
|
||||
team_permissions: Dict[str, SharingPermission] = field(default_factory=dict)
|
||||
shared_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = None
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"dataset_id": self.dataset_id,
|
||||
"owner_id": self.owner_id,
|
||||
"access_group": self.access_group.value,
|
||||
"team_members": self.team_members,
|
||||
"team_permissions": {k: v.value for k, v in self.team_permissions.items()},
|
||||
"shared_at": self.shared_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"is_active": self.is_active
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "DatasetShare":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
id=data.get("id", str(uuid4())),
|
||||
dataset_id=data["dataset_id"],
|
||||
owner_id=data["owner_id"],
|
||||
access_group=AccessGroup(data["access_group"]),
|
||||
team_members=data.get("team_members", []),
|
||||
team_permissions={
|
||||
k: SharingPermission(v) for k, v in data.get("team_permissions", {}).items()
|
||||
},
|
||||
shared_at=datetime.fromisoformat(data["shared_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
|
||||
is_active=data.get("is_active", True)
|
||||
)
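# Round-trip sketch for the serialization above; the TEAM access group and the
# sample identifiers are illustrative placeholders.
def _example_share_roundtrip() -> bool:
    share = DatasetShare(
        dataset_id="ds-123",
        owner_id="owner@example.com",
        access_group=AccessGroup.TEAM,
        team_members=["alice@example.com"],
        team_permissions={"alice@example.com": SharingPermission.READ},
    )
    restored = DatasetShare.from_dict(share.to_dict())
    return restored.team_permissions["alice@example.com"] is SharingPermission.READ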
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
"""Dataset information for sharing"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
owner_id: str
|
||||
document_count: int
|
||||
size_bytes: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class DatasetSharingService:
|
||||
"""
|
||||
Service for hierarchical dataset sharing with capability-based access control.
|
||||
|
||||
Features:
|
||||
- Individual, Team, and Organization level sharing
|
||||
- Granular permission management (read, write, admin)
|
||||
- Time-based expiration of shares
|
||||
- Perfect tenant isolation through file-based storage
|
||||
- Event emission for sharing activities
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, access_controller: AccessController):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.access_controller = access_controller
|
||||
self.base_path = Path(f"/data/{tenant_domain}/dataset_sharing")
|
||||
self.shares_path = self.base_path / "shares"
|
||||
self.datasets_path = self.base_path / "datasets"
|
||||
|
||||
# Ensure directories exist with proper permissions
|
||||
self._ensure_directories()
|
||||
|
||||
logger.info(f"DatasetSharingService initialized for {tenant_domain}")
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure sharing directories exist with proper permissions"""
|
||||
for path in [self.shares_path, self.datasets_path]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
# Set permissions to 700 (owner only)
|
||||
os.chmod(path, stat.S_IRWXU)
|
||||
|
||||
async def share_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
owner_id: str,
|
||||
access_group: AccessGroup,
|
||||
team_members: Optional[List[str]] = None,
|
||||
team_permissions: Optional[Dict[str, SharingPermission]] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
capability_token: str = ""
|
||||
) -> DatasetShare:
|
||||
"""
|
||||
Share a dataset with specified access group.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset to share
|
||||
owner_id: Owner of the dataset
|
||||
access_group: Level of sharing (Individual, Team, Organization)
|
||||
team_members: List of team members (if Team access)
|
||||
team_permissions: Permissions for each team member
|
||||
expires_at: Optional expiration time
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
DatasetShare configuration
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Verify ownership
|
||||
dataset_resource = await self._load_dataset_resource(dataset_id)
|
||||
if not dataset_resource or dataset_resource.owner_id != owner_id:
|
||||
raise PermissionError("Only dataset owner can modify sharing")
|
||||
|
||||
# Validate team members for team sharing
|
||||
if access_group == AccessGroup.TEAM:
|
||||
if not team_members:
|
||||
raise ValueError("Team members required for team sharing")
|
||||
|
||||
# Ensure all team members are valid users in tenant
|
||||
for member in team_members:
|
||||
if not await self._is_valid_tenant_user(member):
|
||||
logger.warning(f"Invalid team member: {member}")
|
||||
|
||||
# Create sharing configuration
|
||||
share = DatasetShare(
|
||||
dataset_id=dataset_id,
|
||||
owner_id=owner_id,
|
||||
access_group=access_group,
|
||||
team_members=team_members or [],
|
||||
team_permissions=team_permissions or {},
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
# Set default permissions for team members
|
||||
if access_group == AccessGroup.TEAM:
|
||||
for member in share.team_members:
|
||||
if member not in share.team_permissions:
|
||||
share.team_permissions[member] = SharingPermission.READ
|
||||
|
||||
# Store sharing configuration
|
||||
await self._store_share(share)
|
||||
|
||||
# Update dataset resource access group
|
||||
await self.access_controller.update_resource_access(
|
||||
owner_id, dataset_id, access_group, team_members
|
||||
)
|
||||
|
||||
# Emit sharing event
|
||||
if hasattr(self.access_controller, 'event_bus'):
|
||||
await self.access_controller.event_bus.emit_event(
|
||||
"dataset.shared",
|
||||
owner_id,
|
||||
{
|
||||
"dataset_id": dataset_id,
|
||||
"access_group": access_group.value,
|
||||
"team_members": team_members or [],
|
||||
"expires_at": expires_at.isoformat() if expires_at else None
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Dataset {dataset_id} shared as {access_group.value} by {owner_id}")
|
||||
return share
|
||||
|
||||
async def get_dataset_sharing(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
capability_token: str
|
||||
) -> Optional[DatasetShare]:
|
||||
"""
|
||||
Get sharing configuration for a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
user_id: Requesting user
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
DatasetShare if user has access, None otherwise
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load sharing configuration
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share:
|
||||
return None
|
||||
|
||||
# Check if user has access to view sharing info
|
||||
if share.owner_id == user_id:
|
||||
return share # Owner can always see
|
||||
|
||||
if share.access_group == AccessGroup.TEAM and user_id in share.team_members:
|
||||
return share # Team member can see
|
||||
|
||||
if share.access_group == AccessGroup.ORGANIZATION:
|
||||
# All tenant users can see organization shares
|
||||
if await self._is_valid_tenant_user(user_id):
|
||||
return share
|
||||
|
||||
return None
|
||||
|
||||
async def check_dataset_access(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
permission: SharingPermission = SharingPermission.READ
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if user has specified permission on dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset to check
|
||||
user_id: User requesting access
|
||||
permission: Required permission level
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, reason)
|
||||
"""
|
||||
# Load sharing configuration
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share or not share.is_active:
|
||||
return False, "Dataset not shared or sharing inactive"
|
||||
|
||||
# Check expiration
|
||||
if share.expires_at and datetime.utcnow() > share.expires_at:
|
||||
return False, "Dataset sharing has expired"
|
||||
|
||||
# Owner has all permissions
|
||||
if share.owner_id == user_id:
|
||||
return True, "Owner access"
|
||||
|
||||
# Check access group permissions
|
||||
if share.access_group == AccessGroup.INDIVIDUAL:
|
||||
return False, "Private dataset"
|
||||
|
||||
elif share.access_group == AccessGroup.TEAM:
|
||||
if user_id not in share.team_members:
|
||||
return False, "Not a team member"
|
||||
|
||||
# Check specific permission
|
||||
user_permission = share.team_permissions.get(user_id, SharingPermission.READ)
|
||||
if self._has_permission(user_permission, permission):
|
||||
return True, f"Team member with {user_permission.value} permission"
|
||||
else:
|
||||
return False, f"Insufficient permission: has {user_permission.value}, needs {permission.value}"
|
||||
|
||||
elif share.access_group == AccessGroup.ORGANIZATION:
|
||||
# Organization sharing is typically read-only
|
||||
if permission == SharingPermission.READ:
|
||||
if await self._is_valid_tenant_user(user_id):
|
||||
return True, "Organization-wide read access"
|
||||
return False, "Organization access is read-only"
|
||||
|
||||
return False, "Unknown access configuration"
|
||||
|
||||
async def list_accessible_datasets(
|
||||
self,
|
||||
user_id: str,
|
||||
capability_token: str,
|
||||
include_owned: bool = True,
|
||||
include_shared: bool = True
|
||||
) -> List[DatasetInfo]:
|
||||
"""
|
||||
List datasets accessible to user.
|
||||
|
||||
Args:
|
||||
user_id: User requesting list
|
||||
capability_token: JWT capability token
|
||||
include_owned: Include user's own datasets
|
||||
include_shared: Include datasets shared with user
|
||||
|
||||
Returns:
|
||||
List of accessible datasets
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
accessible_datasets = []
|
||||
|
||||
# Get all dataset shares
|
||||
all_shares = await self._list_all_shares()
|
||||
|
||||
for share in all_shares:
|
||||
# Skip inactive or expired shares
|
||||
if not share.is_active:
|
||||
continue
|
||||
if share.expires_at and datetime.utcnow() > share.expires_at:
|
||||
continue
|
||||
|
||||
# Check if user has access
|
||||
has_access = False
|
||||
|
||||
if include_owned and share.owner_id == user_id:
|
||||
has_access = True
|
||||
elif include_shared:
|
||||
allowed, _ = await self.check_dataset_access(share.dataset_id, user_id)
|
||||
has_access = allowed
|
||||
|
||||
if has_access:
|
||||
dataset_info = await self._load_dataset_info(share.dataset_id)
|
||||
if dataset_info:
|
||||
accessible_datasets.append(dataset_info)
|
||||
|
||||
return accessible_datasets
|
||||
|
||||
async def revoke_dataset_sharing(
|
||||
self,
|
||||
dataset_id: str,
|
||||
owner_id: str,
|
||||
capability_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke dataset sharing (make it private).
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset to make private
|
||||
owner_id: Owner of the dataset
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
True if revoked successfully
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Verify ownership
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share or share.owner_id != owner_id:
|
||||
raise PermissionError("Only dataset owner can revoke sharing")
|
||||
|
||||
# Update sharing to individual (private)
|
||||
share.access_group = AccessGroup.INDIVIDUAL
|
||||
share.team_members = []
|
||||
share.team_permissions = {}
|
||||
share.is_active = False
|
||||
|
||||
# Store updated share
|
||||
await self._store_share(share)
|
||||
|
||||
# Update resource access
|
||||
await self.access_controller.update_resource_access(
|
||||
owner_id, dataset_id, AccessGroup.INDIVIDUAL, []
|
||||
)
|
||||
|
||||
# Emit revocation event
|
||||
if hasattr(self.access_controller, 'event_bus'):
|
||||
await self.access_controller.event_bus.emit_event(
|
||||
"dataset.sharing_revoked",
|
||||
owner_id,
|
||||
{"dataset_id": dataset_id}
|
||||
)
|
||||
|
||||
logger.info(f"Dataset {dataset_id} sharing revoked by {owner_id}")
|
||||
return True
|
||||
|
||||
async def update_team_permissions(
|
||||
self,
|
||||
dataset_id: str,
|
||||
owner_id: str,
|
||||
user_id: str,
|
||||
permission: SharingPermission,
|
||||
capability_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Update team member permissions for a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
owner_id: Owner of the dataset
|
||||
user_id: Team member to update
|
||||
permission: New permission level
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load and verify sharing
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share or share.owner_id != owner_id:
|
||||
raise PermissionError("Only dataset owner can update permissions")
|
||||
|
||||
if share.access_group != AccessGroup.TEAM:
|
||||
raise ValueError("Can only update permissions for team-shared datasets")
|
||||
|
||||
if user_id not in share.team_members:
|
||||
raise ValueError("User is not a team member")
|
||||
|
||||
# Update permission
|
||||
share.team_permissions[user_id] = permission
|
||||
|
||||
# Store updated share
|
||||
await self._store_share(share)
|
||||
|
||||
logger.info(f"Updated {user_id} permission to {permission.value} for dataset {dataset_id}")
|
||||
return True
|
||||
|
||||
async def get_sharing_statistics(
|
||||
self,
|
||||
user_id: str,
|
||||
capability_token: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get sharing statistics for user.
|
||||
|
||||
Args:
|
||||
user_id: User to get stats for
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
stats = {
|
||||
"owned_datasets": 0,
|
||||
"shared_with_me": 0,
|
||||
"sharing_breakdown": {
|
||||
AccessGroup.INDIVIDUAL: 0,
|
||||
AccessGroup.TEAM: 0,
|
||||
AccessGroup.ORGANIZATION: 0
|
||||
},
|
||||
"total_team_members": 0,
|
||||
"expired_shares": 0
|
||||
}
|
||||
|
||||
all_shares = await self._list_all_shares()
|
||||
|
||||
for share in all_shares:
|
||||
# Count owned datasets
|
||||
if share.owner_id == user_id:
|
||||
stats["owned_datasets"] += 1
|
||||
stats["sharing_breakdown"][share.access_group] += 1
|
||||
stats["total_team_members"] += len(share.team_members)
|
||||
|
||||
# Count expired shares
|
||||
if share.expires_at and datetime.utcnow() > share.expires_at:
|
||||
stats["expired_shares"] += 1
|
||||
|
||||
# Count datasets shared with user
|
||||
elif user_id in share.team_members or share.access_group == AccessGroup.ORGANIZATION:
|
||||
if share.is_active and (not share.expires_at or datetime.utcnow() <= share.expires_at):
|
||||
stats["shared_with_me"] += 1
|
||||
|
||||
return stats
|
||||
|
||||
def _has_permission(self, user_permission: SharingPermission, required: SharingPermission) -> bool:
|
||||
"""Check if user permission satisfies required permission"""
|
||||
permission_hierarchy = {
|
||||
SharingPermission.READ: 1,
|
||||
SharingPermission.WRITE: 2,
|
||||
SharingPermission.ADMIN: 3
|
||||
}
|
||||
|
||||
return permission_hierarchy[user_permission] >= permission_hierarchy[required]
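# Hierarchy above: READ (1) < WRITE (2) < ADMIN (3), so for example
# _has_permission(SharingPermission.ADMIN, SharingPermission.READ) is True and
# _has_permission(SharingPermission.READ, SharingPermission.WRITE) is False.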
|
||||
|
||||
async def _store_share(self, share: DatasetShare):
|
||||
"""Store sharing configuration to file system"""
|
||||
share_file = self.shares_path / f"{share.dataset_id}.json"
|
||||
|
||||
with open(share_file, "w") as f:
|
||||
json.dump(share.to_dict(), f, indent=2)
|
||||
|
||||
# Set secure permissions
|
||||
os.chmod(share_file, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
async def _load_share(self, dataset_id: str) -> Optional[DatasetShare]:
|
||||
"""Load sharing configuration from file system"""
|
||||
share_file = self.shares_path / f"{dataset_id}.json"
|
||||
|
||||
if not share_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(share_file, "r") as f:
|
||||
data = json.load(f)
|
||||
return DatasetShare.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading share for dataset {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _list_all_shares(self) -> List[DatasetShare]:
|
||||
"""List all sharing configurations"""
|
||||
shares = []
|
||||
|
||||
if self.shares_path.exists():
|
||||
for share_file in self.shares_path.glob("*.json"):
|
||||
try:
|
||||
with open(share_file, "r") as f:
|
||||
data = json.load(f)
|
||||
shares.append(DatasetShare.from_dict(data))
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading share file {share_file}: {e}")
|
||||
|
||||
return shares
|
||||
|
||||
async def _load_dataset_resource(self, dataset_id: str) -> Optional[Resource]:
|
||||
"""Load dataset resource (implementation would query storage)"""
|
||||
# Placeholder - would integrate with actual resource storage
|
||||
return Resource(
|
||||
id=dataset_id,
|
||||
name=f"Dataset {dataset_id}",
|
||||
resource_type="dataset",
|
||||
owner_id="mock_owner",
|
||||
tenant_domain=self.tenant_domain,
|
||||
access_group=AccessGroup.INDIVIDUAL
|
||||
)
|
||||
|
||||
async def _load_dataset_info(self, dataset_id: str) -> Optional[DatasetInfo]:
|
||||
"""Load dataset information (implementation would query storage)"""
|
||||
# Placeholder - would integrate with actual dataset storage
|
||||
return DatasetInfo(
|
||||
id=dataset_id,
|
||||
name=f"Dataset {dataset_id}",
|
||||
description="Mock dataset for testing",
|
||||
owner_id="mock_owner",
|
||||
document_count=10,
|
||||
size_bytes=1024000,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
tags=["test", "mock"]
|
||||
)
|
||||
|
||||
async def _is_valid_tenant_user(self, user_id: str) -> bool:
|
||||
"""Check if user is valid in tenant (implementation would query user store)"""
|
||||
# Placeholder - would integrate with actual user management
|
||||
return "@" in user_id and user_id.endswith((".com", ".org", ".edu"))
|
||||
445
apps/tenant-backend/app/services/dataset_summarizer.py
Normal file
445
apps/tenant-backend/app/services/dataset_summarizer.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Dataset Summarization Service for GT 2.0
|
||||
|
||||
Generates comprehensive summaries for datasets based on their constituent documents.
|
||||
Provides analytics, topic clustering, and overview generation for RAG optimization.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from collections import Counter
|
||||
|
||||
from app.core.database import get_db_session, execute_command, fetch_one, fetch_all
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetSummarizer:
|
||||
"""
|
||||
Service for generating dataset-level summaries and analytics.
|
||||
|
||||
Features:
|
||||
- Aggregate document summaries into dataset overview
|
||||
- Topic clustering and theme analysis
|
||||
- Dataset statistics and metrics
|
||||
- Search optimization recommendations
|
||||
- RAG performance insights
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.resource_cluster_url = "http://gentwo-resource-backend:8000"
|
||||
|
||||
async def generate_dataset_summary(
|
||||
self,
|
||||
dataset_id: str,
|
||||
tenant_domain: str,
|
||||
user_id: str,
|
||||
force_regenerate: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a comprehensive summary for a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID to summarize
|
||||
tenant_domain: Tenant domain for database context
|
||||
user_id: User requesting the summary
|
||||
force_regenerate: Force regeneration even if summary exists
|
||||
|
||||
Returns:
|
||||
Dictionary with dataset summary including overview, topics,
|
||||
statistics, and search optimization insights
|
||||
"""
|
||||
try:
|
||||
# Check if summary already exists and is recent
|
||||
if not force_regenerate:
|
||||
existing_summary = await self._get_existing_summary(dataset_id, tenant_domain)
|
||||
if existing_summary and self._is_summary_fresh(existing_summary):
|
||||
logger.info(f"Using cached dataset summary for {dataset_id}")
|
||||
return existing_summary
|
||||
|
||||
# Get dataset information and documents
|
||||
dataset_info = await self._get_dataset_info(dataset_id, tenant_domain)
|
||||
if not dataset_info:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
|
||||
documents = await self._get_dataset_documents(dataset_id, tenant_domain)
|
||||
document_summaries = await self._get_document_summaries(dataset_id, tenant_domain)
|
||||
|
||||
# Generate statistics
|
||||
stats = await self._calculate_dataset_statistics(dataset_id, tenant_domain)
|
||||
|
||||
# Analyze topics across all documents
|
||||
topics_analysis = await self._analyze_dataset_topics(document_summaries)
|
||||
|
||||
# Generate overall summary using LLM
|
||||
overview = await self._generate_dataset_overview(
|
||||
dataset_info, document_summaries, topics_analysis, stats
|
||||
)
|
||||
|
||||
# Create comprehensive summary
|
||||
summary_data = {
|
||||
"dataset_id": dataset_id,
|
||||
"overview": overview,
|
||||
"statistics": stats,
|
||||
"topics": topics_analysis,
|
||||
"recommendations": await self._generate_search_recommendations(stats, topics_analysis),
|
||||
"metadata": {
|
||||
"document_count": len(documents),
|
||||
"has_summaries": len(document_summaries),
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"generated_by": user_id
|
||||
}
|
||||
}
|
||||
|
||||
# Store summary in database
|
||||
await self._store_dataset_summary(dataset_id, summary_data, tenant_domain, user_id)
|
||||
|
||||
logger.info(f"Generated dataset summary for {dataset_id} with {len(documents)} documents")
|
||||
return summary_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate dataset summary for {dataset_id}: {e}")
|
||||
# Return basic fallback summary
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"overview": "Dataset summary generation failed",
|
||||
"statistics": {"error": str(e)},
|
||||
"topics": [],
|
||||
"recommendations": [],
|
||||
"metadata": {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
async def _get_dataset_info(self, dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get basic dataset information"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT id, dataset_name, description, chunking_strategy,
|
||||
chunk_size, chunk_overlap, created_at
|
||||
FROM datasets
|
||||
WHERE id = $1
|
||||
"""
|
||||
result = await fetch_one(session, query, dataset_id)
|
||||
return dict(result) if result else None
|
||||
|
||||
async def _get_dataset_documents(self, dataset_id: str, tenant_domain: str) -> List[Dict[str, Any]]:
|
||||
"""Get all documents in the dataset"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT id, filename, original_filename, file_type,
|
||||
file_size_bytes, chunk_count, created_at
|
||||
FROM documents
|
||||
WHERE dataset_id = $1 AND processing_status = 'completed'
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
results = await fetch_all(session, query, dataset_id)
|
||||
return [dict(row) for row in results]
|
||||
|
||||
async def _get_document_summaries(self, dataset_id: str, tenant_domain: str) -> List[Dict[str, Any]]:
|
||||
"""Get summaries for all documents in the dataset"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT ds.document_id, ds.quick_summary, ds.detailed_analysis,
|
||||
ds.topics, ds.metadata, ds.confidence,
|
||||
d.filename, d.original_filename
|
||||
FROM document_summaries ds
|
||||
JOIN documents d ON ds.document_id = d.id
|
||||
WHERE d.dataset_id = $1
|
||||
ORDER BY ds.created_at DESC
|
||||
"""
|
||||
results = await fetch_all(session, query, dataset_id)
|
||||
|
||||
summaries = []
|
||||
for row in results:
|
||||
summary = dict(row)
|
||||
# Parse JSON fields
|
||||
if summary["topics"]:
|
||||
summary["topics"] = json.loads(summary["topics"])
|
||||
if summary["metadata"]:
|
||||
summary["metadata"] = json.loads(summary["metadata"])
|
||||
summaries.append(summary)
|
||||
|
||||
return summaries
|
||||
|
||||
async def _calculate_dataset_statistics(self, dataset_id: str, tenant_domain: str) -> Dict[str, Any]:
|
||||
"""Calculate comprehensive dataset statistics"""
|
||||
async with get_db_session() as session:
|
||||
# Basic document statistics
|
||||
doc_stats_query = """
|
||||
SELECT
|
||||
COUNT(*) as total_documents,
|
||||
SUM(file_size_bytes) as total_size_bytes,
|
||||
SUM(chunk_count) as total_chunks,
|
||||
AVG(chunk_count) as avg_chunks_per_doc,
|
||||
COUNT(DISTINCT file_type) as unique_file_types
|
||||
FROM documents
|
||||
WHERE dataset_id = $1 AND processing_status = 'completed'
|
||||
"""
|
||||
doc_stats = await fetch_one(session, doc_stats_query, dataset_id)
|
||||
|
||||
# Chunk statistics
|
||||
chunk_stats_query = """
|
||||
SELECT
|
||||
COUNT(*) as total_vector_embeddings,
|
||||
AVG(token_count) as avg_tokens_per_chunk,
|
||||
MIN(token_count) as min_tokens,
|
||||
MAX(token_count) as max_tokens
|
||||
FROM document_chunks
|
||||
WHERE dataset_id = $1
|
||||
"""
|
||||
chunk_stats = await fetch_one(session, chunk_stats_query, dataset_id)
|
||||
|
||||
# File type distribution
|
||||
file_types_query = """
|
||||
SELECT file_type, COUNT(*) as count
|
||||
FROM documents
|
||||
WHERE dataset_id = $1 AND processing_status = 'completed'
|
||||
GROUP BY file_type
|
||||
ORDER BY count DESC
|
||||
"""
|
||||
file_types_results = await fetch_all(session, file_types_query, dataset_id)
|
||||
file_types = {row["file_type"]: row["count"] for row in file_types_results}
|
||||
|
||||
return {
|
||||
"documents": {
|
||||
"total": doc_stats["total_documents"] or 0,
|
||||
"total_size_mb": round((doc_stats["total_size_bytes"] or 0) / 1024 / 1024, 2),
|
||||
"avg_chunks_per_document": round(doc_stats["avg_chunks_per_doc"] or 0, 1),
|
||||
"unique_file_types": doc_stats["unique_file_types"] or 0,
|
||||
"file_type_distribution": file_types
|
||||
},
|
||||
"chunks": {
|
||||
"total": chunk_stats["total_vector_embeddings"] or 0,
|
||||
"avg_tokens": round(chunk_stats["avg_tokens_per_chunk"] or 0, 1),
|
||||
"token_range": {
|
||||
"min": chunk_stats["min_tokens"] or 0,
|
||||
"max": chunk_stats["max_tokens"] or 0
|
||||
}
|
||||
},
|
||||
"search_readiness": {
|
||||
"has_vectors": (chunk_stats["total_vector_embeddings"] or 0) > 0,
|
||||
"vector_coverage": 1.0 if (doc_stats["total_chunks"] or 0) == (chunk_stats["total_vector_embeddings"] or 0) else 0.0
|
||||
}
|
||||
}
|
||||
|
||||
async def _analyze_dataset_topics(self, document_summaries: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Analyze topics across all document summaries"""
|
||||
if not document_summaries:
|
||||
return {"main_topics": [], "topic_distribution": {}, "confidence": 0.0}
|
||||
|
||||
# Collect all topics from document summaries
|
||||
all_topics = []
|
||||
for summary in document_summaries:
|
||||
topics = summary.get("topics", [])
|
||||
if isinstance(topics, list):
|
||||
all_topics.extend(topics)
|
||||
|
||||
# Count topic frequencies
|
||||
topic_counts = Counter(all_topics)
|
||||
|
||||
# Get top topics
|
||||
main_topics = [topic for topic, count in topic_counts.most_common(10)]
|
||||
|
||||
# Calculate topic distribution
|
||||
total_topics = len(all_topics)
|
||||
topic_distribution = {}
|
||||
if total_topics > 0:
|
||||
for topic, count in topic_counts.items():
|
||||
topic_distribution[topic] = round(count / total_topics, 3)
|
||||
|
||||
# Calculate confidence based on number of summaries available
|
||||
confidence = min(1.0, len(document_summaries) / 5.0) # Full confidence with 5+ documents
|
||||
|
||||
return {
|
||||
"main_topics": main_topics,
|
||||
"topic_distribution": topic_distribution,
|
||||
"confidence": confidence,
|
||||
"total_unique_topics": len(topic_counts)
|
||||
}
|
||||
|
||||
async def _generate_dataset_overview(
|
||||
self,
|
||||
dataset_info: Dict[str, Any],
|
||||
document_summaries: List[Dict[str, Any]],
|
||||
topics_analysis: Dict[str, Any],
|
||||
stats: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Generate LLM-powered overview of the dataset"""
|
||||
|
||||
# Create context for LLM
|
||||
context = f"""Dataset: {dataset_info['dataset_name']}
|
||||
Description: {dataset_info.get('description', 'No description provided')}
|
||||
|
||||
Statistics:
|
||||
- {stats['documents']['total']} documents ({stats['documents']['total_size_mb']} MB)
|
||||
- {stats['chunks']['total']} text chunks for search
|
||||
- Average {stats['documents']['avg_chunks_per_document']} chunks per document
|
||||
|
||||
Main Topics: {', '.join(topics_analysis['main_topics'][:5])}
|
||||
|
||||
Document Summaries:
|
||||
"""
|
||||
|
||||
# Add sample document summaries
|
||||
for i, summary in enumerate(document_summaries[:3]): # First 3 documents
|
||||
context += f"\n{i+1}. {summary['filename']}: {summary['quick_summary']}"
|
||||
|
||||
prompt = f"""Analyze this dataset and provide a comprehensive 2-3 paragraph overview.
|
||||
|
||||
{context}
|
||||
|
||||
Focus on:
|
||||
1. What type of content this dataset contains
|
||||
2. The main themes and topics covered
|
||||
3. How useful this would be for AI-powered search and retrieval
|
||||
4. Any notable patterns or characteristics
|
||||
|
||||
Provide a professional, informative summary suitable for users exploring their datasets."""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
|
||||
json={
|
||||
"model": "llama-3.1-8b-instant",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a data analysis expert. Provide clear, insightful dataset summaries."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 500
|
||||
},
|
||||
headers={
|
||||
"X-Tenant-ID": "default",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
llm_response = response.json()
|
||||
return llm_response["choices"][0]["message"]["content"]
|
||||
else:
|
||||
raise Exception(f"LLM API error: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM overview generation failed: {e}")
|
||||
# Fallback to template-based overview
|
||||
return f"This dataset contains {stats['documents']['total']} documents covering topics such as {', '.join(topics_analysis['main_topics'][:3])}. The dataset includes {stats['chunks']['total']} searchable text chunks optimized for AI-powered retrieval and question answering."
|
||||
|
||||
async def _generate_search_recommendations(
|
||||
self,
|
||||
stats: Dict[str, Any],
|
||||
topics_analysis: Dict[str, Any]
|
||||
) -> List[str]:
|
||||
"""Generate recommendations for optimizing search performance"""
|
||||
recommendations = []
|
||||
|
||||
# Vector coverage recommendations
|
||||
if not stats["search_readiness"]["has_vectors"]:
|
||||
recommendations.append("Generate vector embeddings for all documents to enable semantic search")
|
||||
elif stats["search_readiness"]["vector_coverage"] < 1.0:
|
||||
recommendations.append("Complete vector embedding generation for optimal search performance")
|
||||
|
||||
# Chunk size recommendations
|
||||
avg_tokens = stats["chunks"]["avg_tokens"]
|
||||
if avg_tokens < 100:
|
||||
recommendations.append("Consider increasing chunk size for better context in search results")
|
||||
elif avg_tokens > 600:
|
||||
recommendations.append("Consider reducing chunk size for more precise search matches")
|
||||
|
||||
# Topic diversity recommendations
|
||||
if topics_analysis["total_unique_topics"] < 3:
|
||||
recommendations.append("Dataset may benefit from more diverse content for comprehensive coverage")
|
||||
elif topics_analysis["total_unique_topics"] > 50:
|
||||
recommendations.append("Consider organizing content into focused sub-datasets for better search precision")
|
||||
|
||||
# Document count recommendations
|
||||
doc_count = stats["documents"]["total"]
|
||||
if doc_count < 5:
|
||||
recommendations.append("Add more documents to improve search quality and coverage")
|
||||
elif doc_count > 100:
|
||||
recommendations.append("Consider implementing advanced filtering and categorization for better navigation")
|
||||
|
||||
return recommendations[:5] # Limit to top 5 recommendations
|
||||
|
||||
async def _store_dataset_summary(
|
||||
self,
|
||||
dataset_id: str,
|
||||
summary_data: Dict[str, Any],
|
||||
tenant_domain: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Store or update dataset summary in database"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
UPDATE datasets
|
||||
SET
|
||||
summary = $1,
|
||||
summary_generated_at = $2,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
"""
|
||||
|
||||
await execute_command(
|
||||
session,
|
||||
query,
|
||||
json.dumps(summary_data),
|
||||
datetime.utcnow(),
|
||||
dataset_id
|
||||
)
|
||||
|
||||
async def _get_existing_summary(self, dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get existing dataset summary if available"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT summary, summary_generated_at
|
||||
FROM datasets
|
||||
WHERE id = $1 AND summary IS NOT NULL
|
||||
"""
|
||||
result = await fetch_one(session, query, dataset_id)
|
||||
|
||||
if result and result["summary"]:
|
||||
return json.loads(result["summary"])
|
||||
return None
|
||||
|
||||
def _is_summary_fresh(self, summary: Dict[str, Any], max_age_hours: int = 24) -> bool:
|
||||
"""Check if summary is recent enough to avoid regeneration"""
|
||||
try:
|
||||
generated_at = datetime.fromisoformat(summary["metadata"]["generated_at"])
|
||||
age_hours = (datetime.utcnow() - generated_at).total_seconds() / 3600
|
||||
return age_hours < max_age_hours
|
||||
except (KeyError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
# Global instance
|
||||
dataset_summarizer = DatasetSummarizer()
|
||||
|
||||
|
||||
async def generate_dataset_summary(
|
||||
dataset_id: str,
|
||||
tenant_domain: str,
|
||||
user_id: str,
|
||||
force_regenerate: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Convenience function for dataset summary generation"""
|
||||
return await dataset_summarizer.generate_dataset_summary(
|
||||
dataset_id, tenant_domain, user_id, force_regenerate
|
||||
)
|
||||
|
||||
|
||||
async def get_dataset_summary(dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
|
||||
"""Convenience function for retrieving dataset summary"""
|
||||
return await dataset_summarizer._get_existing_summary(dataset_id, tenant_domain)
|
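# Illustrative usage sketch (not part of the original module). It shows how the
# convenience wrappers above could be called from an async context; the dataset
# UUID, tenant domain, and user ID below are hypothetical placeholders.
async def _example_dataset_summary_usage() -> None:
    summary = await generate_dataset_summary(
        dataset_id="00000000-0000-0000-0000-000000000000",  # hypothetical dataset UUID
        tenant_domain="example-tenant",
        user_id="user-123",
        force_regenerate=False,  # reuse a fresh cached summary when one exists
    )
    cached = await get_dataset_summary(
        "00000000-0000-0000-0000-000000000000", "example-tenant"
    )
    logger.debug("summary generated: %s, cached copy present: %s", bool(summary), bool(cached))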
||||
834
apps/tenant-backend/app/services/document_processor.py
Normal file
@@ -0,0 +1,834 @@
|
||||
"""
|
||||
Document Processing Service for GT 2.0
|
||||
|
||||
Handles file upload, text extraction, chunking, and embedding generation
|
||||
for RAG pipeline. Supports multiple file formats with intelligent chunking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
# Document processing libraries
|
||||
import pypdf as PyPDF2 # pypdf is the maintained successor to PyPDF2
|
||||
import docx
|
||||
import pandas as pd
|
||||
import json
|
||||
import csv
|
||||
from io import StringIO
|
||||
|
||||
# Database and core services
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
# Resource cluster client for embeddings
|
||||
import httpx
|
||||
from app.services.embedding_client import get_embedding_client
|
||||
|
||||
# Document summarization
|
||||
from app.services.summarization_service import SummarizationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""
|
||||
Comprehensive document processing service for RAG pipeline.
|
||||
|
||||
Features:
|
||||
- Multi-format support (PDF, DOCX, TXT, MD, CSV, JSON)
|
||||
- Intelligent chunking with overlap
|
||||
- Async embedding generation with batch processing
|
||||
- Progress tracking
|
||||
- Error handling and recovery
|
||||
"""
|
||||
|
||||
def __init__(self, db=None, tenant_domain=None):
|
||||
self.db = db
|
||||
self.tenant_domain = tenant_domain or "test" # Default fallback
|
||||
# Use configurable embedding client instead of hardcoded URL
|
||||
self.embedding_client = get_embedding_client()
|
||||
self.chunk_size = 512 # Default chunk size in tokens
|
||||
self.chunk_overlap = 128 # Default overlap
|
||||
self.max_file_size = 100 * 1024 * 1024 # 100MB limit
|
||||
|
||||
# Embedding batch processing configuration
|
||||
self.EMBEDDING_BATCH_SIZE = 15 # Process embeddings in batches of 15 (ARM64 optimized)
|
||||
self.MAX_CONCURRENT_BATCHES = 3 # Process up to 3 batches concurrently
|
||||
self.MAX_RETRIES = 3 # Maximum retries per batch
|
||||
self.INITIAL_RETRY_DELAY = 1.0 # Initial delay in seconds
|
||||
|
||||
# Supported file types
|
||||
self.supported_types = {
|
||||
'.pdf': 'application/pdf',
|
||||
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'.txt': 'text/plain',
|
||||
'.md': 'text/markdown',
|
||||
'.csv': 'text/csv',
|
||||
'.json': 'application/json'
|
||||
}
|
||||
|
||||
async def process_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
original_filename: str,
|
||||
document_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process an uploaded file through the complete RAG pipeline.
|
||||
|
||||
Args:
|
||||
file_path: Path to uploaded file
|
||||
dataset_id: Dataset UUID to attach to
|
||||
user_id: User who uploaded the file
|
||||
original_filename: Original filename
|
||||
document_id: Optional existing document ID to update instead of creating new
|
||||
|
||||
Returns:
|
||||
Dict: Document record with processing status
|
||||
"""
|
||||
logger.info(f"Processing file {original_filename} for dataset {dataset_id}")
|
||||
|
||||
# Process file directly (no session management needed with PostgreSQL client)
|
||||
return await self._process_file_internal(file_path, dataset_id, user_id, original_filename, document_id)
|
||||
|
||||
async def _process_file_internal(
|
||||
self,
|
||||
file_path: Path,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
original_filename: str,
|
||||
document_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal file processing method"""
|
||||
try:
|
||||
# 1. Validate file
|
||||
await self._validate_file(file_path)
|
||||
|
||||
# 2. Create or use existing document record
|
||||
if document_id:
|
||||
# Use existing document
|
||||
document = {"id": document_id}
|
||||
logger.info(f"Using existing document {document_id} for processing")
|
||||
else:
|
||||
# Create new document record
|
||||
document = await self._create_document_record(
|
||||
file_path, dataset_id, user_id, original_filename
|
||||
)
|
||||
|
||||
# 3. Get or extract text content
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Getting text content")
|
||||
|
||||
# Check if content already exists (e.g., from upload-time extraction)
|
||||
existing_content, storage_type = await self._get_existing_document_content(document["id"])
|
||||
|
||||
if existing_content and storage_type in ["pdf_extracted", "text"]:
|
||||
# Use existing extracted content
|
||||
text_content = existing_content
|
||||
logger.info(f"Using existing extracted content ({len(text_content)} chars, type: {storage_type})")
|
||||
else:
|
||||
# Extract text from file
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Extracting text")
|
||||
|
||||
# Determine file type for extraction
|
||||
if document_id:
|
||||
# For existing documents, determine file type from file extension
|
||||
file_ext = file_path.suffix.lower()
|
||||
file_type = self.supported_types.get(file_ext, 'text/plain')
|
||||
else:
|
||||
file_type = document["file_type"]
|
||||
|
||||
text_content = await self._extract_text(file_path, file_type)
|
||||
|
||||
# 4. Update document with extracted text
|
||||
await self._update_document_content(document["id"], text_content)
|
||||
|
||||
# 5. Generate document summary
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Generating summary")
|
||||
await self._generate_document_summary(document["id"], text_content, original_filename, user_id)
|
||||
|
||||
# 6. Chunk the document
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Creating chunks")
|
||||
chunks = await self._chunk_text(text_content, document["id"])
|
||||
|
||||
# Set expected chunk count for progress tracking
|
||||
await self._update_processing_status(
|
||||
document["id"], "processing",
|
||||
processing_stage="Preparing for embedding generation",
|
||||
total_chunks_expected=len(chunks)
|
||||
)
|
||||
|
||||
# 7. Generate embeddings
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Starting embedding generation")
|
||||
await self._generate_embeddings_for_chunks(chunks, dataset_id, user_id)
|
||||
|
||||
# 8. Update final status
|
||||
await self._update_processing_status(
|
||||
document["id"], "completed",
|
||||
processing_stage="Completed",
|
||||
chunks_processed=len(chunks),
|
||||
total_chunks_expected=len(chunks)
|
||||
)
|
||||
await self._update_chunk_count(document["id"], len(chunks))
|
||||
|
||||
# 9. Update dataset summary (after document is fully processed)
|
||||
await self._update_dataset_summary_after_document_change(dataset_id, user_id)
|
||||
|
||||
logger.info(f"Successfully processed {original_filename} with {len(chunks)} chunks")
|
||||
return document
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {original_filename}: {e}")
|
||||
if 'document' in locals():
|
||||
await self._update_processing_status(
|
||||
document["id"], "failed",
|
||||
error_message=str(e),
|
||||
processing_stage="Failed"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _validate_file(self, file_path: Path):
|
||||
"""Validate file size and type"""
|
||||
if not file_path.exists():
|
||||
raise ValueError("File does not exist")
|
||||
|
||||
file_size = file_path.stat().st_size
|
||||
if file_size > self.max_file_size:
|
||||
raise ValueError(f"File too large: {file_size} bytes (max: {self.max_file_size})")
|
||||
|
||||
file_ext = file_path.suffix.lower()
|
||||
if file_ext not in self.supported_types:
|
||||
raise ValueError(f"Unsupported file type: {file_ext}")
|
||||
|
||||
async def _create_document_record(
|
||||
self,
|
||||
file_path: Path,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
original_filename: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create document record in database"""
|
||||
|
||||
# Calculate file hash
|
||||
with open(file_path, 'rb') as f:
|
||||
file_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
file_ext = file_path.suffix.lower()
|
||||
file_size = file_path.stat().st_size
|
||||
document_id = str(uuid.uuid4())
|
||||
|
||||
# Insert document record using raw SQL
|
||||
# Note: tenant_id is nullable UUID, so we set it to NULL for individual documents
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"""INSERT INTO documents (
|
||||
id, user_id, dataset_id, filename, original_filename,
|
||||
file_type, file_size_bytes, file_hash, processing_status
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)""",
|
||||
document_id, str(user_id), dataset_id, str(file_path.name),
|
||||
original_filename, self.supported_types[file_ext], file_size, file_hash, "pending"
|
||||
)
|
||||
|
||||
return {
|
||||
"id": document_id,
|
||||
"user_id": user_id,
|
||||
"dataset_id": dataset_id,
|
||||
"filename": str(file_path.name),
|
||||
"original_filename": original_filename,
|
||||
"file_type": self.supported_types[file_ext],
|
||||
"file_size_bytes": file_size,
|
||||
"file_hash": file_hash,
|
||||
"processing_status": "pending",
|
||||
"chunk_count": 0
|
||||
}
|
||||
|
||||
async def _extract_text(self, file_path: Path, file_type: str) -> str:
|
||||
"""Extract text content from various file formats"""
|
||||
|
||||
try:
|
||||
if file_type == 'application/pdf':
|
||||
return await self._extract_pdf_text(file_path)
|
||||
elif 'wordprocessingml' in file_type:
|
||||
return await self._extract_docx_text(file_path)
|
||||
elif file_type == 'text/csv':
|
||||
return await self._extract_csv_text(file_path)
|
||||
elif file_type == 'application/json':
|
||||
return await self._extract_json_text(file_path)
|
||||
else: # text/plain, text/markdown
|
||||
return await self._extract_plain_text(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Text extraction failed for {file_path}: {e}")
|
||||
raise ValueError(f"Could not extract text from file: {e}")
|
||||
|
||||
async def _extract_pdf_text(self, file_path: Path) -> str:
|
||||
"""Extract text from PDF file"""
|
||||
text_parts = []
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = PyPDF2.PdfReader(file)
|
||||
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
try:
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num + 1} ---\n{page_text}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from page {page_num + 1}: {e}")
|
||||
|
||||
if not text_parts:
|
||||
raise ValueError("No text could be extracted from PDF")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
async def _extract_docx_text(self, file_path: Path) -> str:
|
||||
"""Extract text from DOCX file"""
|
||||
doc = docx.Document(file_path)
|
||||
text_parts = []
|
||||
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text_parts.append(paragraph.text)
|
||||
|
||||
if not text_parts:
|
||||
raise ValueError("No text could be extracted from DOCX")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
async def _extract_csv_text(self, file_path: Path) -> str:
|
||||
"""Extract and format text from CSV file"""
|
||||
try:
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# Create readable format
|
||||
text_parts = [f"CSV Data with {len(df)} rows and {len(df.columns)} columns"]
|
||||
text_parts.append(f"Columns: {', '.join(df.columns.tolist())}")
|
||||
text_parts.append("")
|
||||
|
||||
# Sample first few rows in readable format
|
||||
for idx, row in df.head(20).iterrows():
|
||||
row_text = []
|
||||
for col in df.columns:
|
||||
if pd.notna(row[col]):
|
||||
row_text.append(f"{col}: {row[col]}")
|
||||
text_parts.append(f"Row {idx + 1}: " + " | ".join(row_text))
|
||||
|
||||
return "\n".join(text_parts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CSV parsing error: {e}")
|
||||
# Fallback to reading as plain text
|
||||
return await self._extract_plain_text(file_path)
|
||||
|
||||
async def _extract_json_text(self, file_path: Path) -> str:
|
||||
"""Extract and format text from JSON file"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Convert JSON to readable text format
|
||||
def json_to_text(obj, prefix=""):
|
||||
text_parts = []
|
||||
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
text_parts.append(f"{prefix}{key}:")
|
||||
text_parts.extend(json_to_text(value, prefix + " "))
|
||||
else:
|
||||
text_parts.append(f"{prefix}{key}: {value}")
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
if isinstance(item, (dict, list)):
|
||||
text_parts.append(f"{prefix}Item {i + 1}:")
|
||||
text_parts.extend(json_to_text(item, prefix + " "))
|
||||
else:
|
||||
text_parts.append(f"{prefix}Item {i + 1}: {item}")
|
||||
else:
|
||||
text_parts.append(f"{prefix}{obj}")
|
||||
|
||||
return text_parts
|
||||
|
||||
return "\n".join(json_to_text(data))
|
||||
|
||||
async def _extract_plain_text(self, file_path: Path) -> str:
|
||||
"""Extract text from plain text files"""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except UnicodeDecodeError:
|
||||
# Try with latin-1 encoding
|
||||
with open(file_path, 'r', encoding='latin-1') as f:
|
||||
return f.read()
|
||||
|
||||
async def extract_text_from_path(self, file_path: Path, content_type: str) -> str:
|
||||
"""Public wrapper for text extraction from file path"""
|
||||
return await self._extract_text(file_path, content_type)
|
||||
|
||||
async def chunk_text_simple(self, text: str) -> List[str]:
|
||||
"""Public wrapper for simple text chunking without document_id"""
|
||||
chunks = []
|
||||
chunk_size = self.chunk_size * 4 # ~2048 chars
|
||||
overlap = self.chunk_overlap * 4 # ~512 chars
|
||||
|
||||
for i in range(0, len(text), chunk_size - overlap):
|
||||
chunk = text[i:i + chunk_size]
|
||||
if chunk.strip():
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
async def _chunk_text(self, text: str, document_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Split text into overlapping chunks optimized for embeddings.
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries with content and metadata
|
||||
"""
|
||||
# Simple sentence-aware chunking
|
||||
sentences = re.split(r'[.!?]+', text)
|
||||
sentences = [s.strip() for s in sentences if s.strip()]
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
current_tokens = 0
|
||||
chunk_index = 0
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_tokens = len(sentence.split())
|
||||
|
||||
# If adding this sentence would exceed chunk size, save current chunk
|
||||
if current_tokens + sentence_tokens > self.chunk_size and current_chunk:
|
||||
# Create chunk with overlap from previous chunk
|
||||
chunk_content = current_chunk.strip()
|
||||
if chunk_content:
|
||||
chunks.append({
|
||||
"document_id": document_id,
|
||||
"chunk_index": chunk_index,
|
||||
"content": chunk_content,
|
||||
"token_count": current_tokens,
|
||||
"content_hash": hashlib.md5(chunk_content.encode()).hexdigest()
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
# Start new chunk with overlap
|
||||
if self.chunk_overlap > 0 and chunks:
|
||||
# Take last few sentences for overlap
|
||||
overlap_sentences = current_chunk.split('.')[-2:] # Rough overlap
|
||||
current_chunk = '. '.join(s.strip() for s in overlap_sentences if s.strip())
|
||||
current_tokens = len(current_chunk.split())
|
||||
else:
|
||||
current_chunk = ""
|
||||
current_tokens = 0
|
||||
|
||||
# Add sentence to current chunk
|
||||
if current_chunk:
|
||||
current_chunk += ". " + sentence
|
||||
else:
|
||||
current_chunk = sentence
|
||||
current_tokens += sentence_tokens
|
||||
|
||||
# Add final chunk
|
||||
if current_chunk.strip():
|
||||
chunk_content = current_chunk.strip()
|
||||
chunks.append({
|
||||
"document_id": document_id,
|
||||
"chunk_index": chunk_index,
|
||||
"content": chunk_content,
|
||||
"token_count": current_tokens,
|
||||
"content_hash": hashlib.md5(chunk_content.encode()).hexdigest()
|
||||
})
|
||||
|
||||
logger.info(f"Created {len(chunks)} chunks from document {document_id}")
|
||||
return chunks
|
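# Worked example (illustrative): with chunk_size=512 and chunk_overlap=128, a
# ~1,300-word document typically yields three to four chunks, and each new chunk
# is seeded with the last couple of sentences of the previous one so that
# content spanning a chunk boundary remains retrievable.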
||||
|
||||
async def _generate_embeddings_for_chunks(
|
||||
self,
|
||||
chunks: List[Dict[str, Any]],
|
||||
dataset_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""
|
||||
Generate embeddings for all chunks using concurrent batch processing.
|
||||
|
||||
Processes chunks in batches with controlled concurrency to optimize performance
|
||||
while preventing system overload. Includes retry logic and progressive storage.
|
||||
"""
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
total_chunks = len(chunks)
|
||||
document_id = chunks[0]["document_id"]
|
||||
total_batches = (total_chunks + self.EMBEDDING_BATCH_SIZE - 1) // self.EMBEDDING_BATCH_SIZE
|
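# Ceiling division: e.g. 100 chunks with a batch size of 15 -> 7 batches.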
||||
|
||||
logger.info(f"Starting concurrent embedding generation for {total_chunks} chunks")
|
||||
logger.info(f"Batch size: {self.EMBEDDING_BATCH_SIZE}, Total batches: {total_batches}, Max concurrent: {self.MAX_CONCURRENT_BATCHES}")
|
||||
|
||||
# Create semaphore to limit concurrent batches
|
||||
semaphore = asyncio.Semaphore(self.MAX_CONCURRENT_BATCHES)
|
||||
|
||||
# Create batch data with metadata
|
||||
batch_tasks = []
|
||||
for batch_start in range(0, total_chunks, self.EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(batch_start + self.EMBEDDING_BATCH_SIZE, total_chunks)
|
||||
batch_chunks = chunks[batch_start:batch_end]
|
||||
batch_num = (batch_start // self.EMBEDDING_BATCH_SIZE) + 1
|
||||
|
||||
batch_metadata = {
|
||||
"chunks": batch_chunks,
|
||||
"batch_num": batch_num,
|
||||
"start_index": batch_start,
|
||||
"end_index": batch_end,
|
||||
"dataset_id": dataset_id,
|
||||
"user_id": user_id,
|
||||
"document_id": document_id
|
||||
}
|
||||
|
||||
# Create task for this batch
|
||||
task = self._process_batch_with_semaphore(semaphore, batch_metadata, total_batches, total_chunks)
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Process all batches concurrently
|
||||
logger.info(f"Starting concurrent processing of {len(batch_tasks)} batches")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
# Analyze results
|
||||
successful_batches = []
|
||||
failed_batches = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
batch_num = i + 1
|
||||
if isinstance(result, Exception):
|
||||
failed_batches.append({
|
||||
"batch_num": batch_num,
|
||||
"error": str(result)
|
||||
})
|
||||
logger.error(f"Batch {batch_num} failed: {result}")
|
||||
else:
|
||||
successful_batches.append(result)
|
||||
|
||||
successful_chunks = sum(len(batch["chunks"]) for batch in successful_batches)
|
||||
|
||||
logger.info(f"Concurrent processing completed in {processing_time:.2f} seconds")
|
||||
logger.info(f"Successfully processed {successful_chunks}/{total_chunks} chunks in {len(successful_batches)}/{total_batches} batches")
|
||||
|
||||
# Report final results
|
||||
if failed_batches:
|
||||
failed_chunk_count = total_chunks - successful_chunks
|
||||
error_details = "; ".join([f"Batch {b['batch_num']}: {b['error']}" for b in failed_batches[:3]])
|
||||
if len(failed_batches) > 3:
|
||||
error_details += f" (and {len(failed_batches) - 3} more failures)"
|
||||
|
||||
raise ValueError(f"Failed to generate embeddings for {failed_chunk_count}/{total_chunks} chunks. Errors: {error_details}")
|
||||
|
||||
logger.info(f"Successfully stored all {total_chunks} chunks with embeddings")
|
||||
|
||||
async def _process_batch_with_semaphore(
|
||||
self,
|
||||
semaphore: asyncio.Semaphore,
|
||||
batch_metadata: Dict[str, Any],
|
||||
total_batches: int,
|
||||
total_chunks: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a single batch with semaphore-controlled concurrency.
|
||||
|
||||
Args:
|
||||
semaphore: Concurrency control semaphore
|
||||
batch_metadata: Batch information including chunks and metadata
|
||||
total_batches: Total number of batches
|
||||
total_chunks: Total number of chunks
|
||||
|
||||
Returns:
|
||||
Dict with batch processing results
|
||||
"""
|
||||
async with semaphore:
|
||||
batch_chunks = batch_metadata["chunks"]
|
||||
batch_num = batch_metadata["batch_num"]
|
||||
dataset_id = batch_metadata["dataset_id"]
|
||||
user_id = batch_metadata["user_id"]
|
||||
document_id = batch_metadata["document_id"]
|
||||
|
||||
logger.info(f"Starting batch {batch_num}/{total_batches} ({len(batch_chunks)} chunks)")
|
||||
|
||||
try:
|
||||
# Generate embeddings for this batch (pass user_id for billing)
|
||||
embeddings = await self._generate_embedding_batch(batch_chunks, user_id=user_id)
|
||||
|
||||
# Store embeddings for this batch immediately
|
||||
await self._store_chunk_embeddings(batch_chunks, embeddings, dataset_id, user_id)
|
||||
|
||||
# Update progress in database
|
||||
progress_stage = f"Completed batch {batch_num}/{total_batches}"
|
||||
|
||||
# Calculate current progress (approximate since batches complete out of order)
|
||||
await self._update_processing_status(
|
||||
document_id, "processing",
|
||||
processing_stage=progress_stage,
|
||||
chunks_processed=batch_num * self.EMBEDDING_BATCH_SIZE, # Approximate
|
||||
total_chunks_expected=total_chunks
|
||||
)
|
||||
|
||||
logger.info(f"Successfully completed batch {batch_num}/{total_batches}")
|
||||
|
||||
return {
|
||||
"batch_num": batch_num,
|
||||
"chunks": batch_chunks,
|
||||
"success": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process batch {batch_num}/{total_batches}: {e}")
|
||||
raise ValueError(f"Batch {batch_num} failed: {str(e)}")
|
||||
|
||||
async def _generate_embedding_batch(
|
||||
self,
|
||||
batch_chunks: List[Dict[str, Any]],
|
||||
user_id: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for a single batch of chunks with retry logic.
|
||||
|
||||
Args:
|
||||
batch_chunks: List of chunk dictionaries
|
||||
user_id: User ID for usage tracking
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding generation fails after all retries
|
||||
"""
|
||||
texts = [chunk["content"] for chunk in batch_chunks]
|
||||
|
||||
for attempt in range(self.MAX_RETRIES + 1):
|
||||
try:
|
||||
# Use the configurable embedding client with tenant/user context for billing
|
||||
embeddings = await self.embedding_client.generate_embeddings(
|
||||
texts,
|
||||
tenant_id=self.tenant_domain,
|
||||
user_id=str(user_id) if user_id else None
|
||||
)
|
||||
|
||||
if len(embeddings) != len(texts):
|
||||
raise ValueError(f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}")
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.MAX_RETRIES:
|
||||
delay = self.INITIAL_RETRY_DELAY * (2 ** attempt) # Exponential backoff
|
||||
logger.warning(f"Embedding generation attempt {attempt + 1}/{self.MAX_RETRIES + 1} failed: {e}. Retrying in {delay}s...")
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.error(f"All {self.MAX_RETRIES + 1} embedding generation attempts failed. Final error: {e}")
|
||||
logger.error(f"Failed request details: URL=http://gentwo-vllm-embeddings:8000/v1/embeddings, texts_count={len(texts)}")
|
||||
raise ValueError(f"Embedding generation failed after {self.MAX_RETRIES + 1} attempts: {str(e)}")
|
||||
|
||||
async def _store_chunk_embeddings(
|
||||
self,
|
||||
batch_chunks: List[Dict[str, Any]],
|
||||
embeddings: List[List[float]],
|
||||
dataset_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Store chunk embeddings in database with proper error handling."""
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
for chunk_data, embedding in zip(batch_chunks, embeddings):
|
||||
chunk_id = str(uuid.uuid4())
|
||||
|
||||
# Convert embedding list to PostgreSQL array format
|
||||
embedding_array = f"[{','.join(map(str, embedding))}]" if embedding else None
|
||||
|
||||
await pg_client.execute_command(
|
||||
"""INSERT INTO document_chunks (
|
||||
id, document_id, user_id, dataset_id, chunk_index,
|
||||
content, content_hash, token_count, embedding
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector)""",
|
||||
chunk_id, chunk_data["document_id"], str(user_id),
|
||||
dataset_id, chunk_data["chunk_index"], chunk_data["content"],
|
||||
chunk_data["content_hash"], chunk_data["token_count"], embedding_array
|
||||
)
|
||||
|
||||
async def _update_processing_status(
|
||||
self,
|
||||
document_id: str,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
processing_stage: Optional[str] = None,
|
||||
chunks_processed: Optional[int] = None,
|
||||
total_chunks_expected: Optional[int] = None
|
||||
):
|
||||
"""Update document processing status with progress tracking via metadata JSONB"""
|
||||
|
||||
# Calculate progress percentage if we have the data
|
||||
processing_progress = None
|
||||
if chunks_processed is not None and total_chunks_expected is not None and total_chunks_expected > 0:
|
||||
processing_progress = min(100, int((chunks_processed / total_chunks_expected) * 100))
|
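# e.g. 30 of 120 chunks processed -> 25 (percent), capped at 100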
||||
|
||||
# Build progress metadata object
|
||||
import json
|
||||
progress_data = {}
|
||||
if processing_stage is not None:
|
||||
progress_data['processing_stage'] = processing_stage
|
||||
if chunks_processed is not None:
|
||||
progress_data['chunks_processed'] = chunks_processed
|
||||
if total_chunks_expected is not None:
|
||||
progress_data['total_chunks_expected'] = total_chunks_expected
|
||||
if processing_progress is not None:
|
||||
progress_data['processing_progress'] = processing_progress
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
if error_message:
|
||||
await pg_client.execute_command(
|
||||
"""UPDATE documents SET
|
||||
processing_status = $1,
|
||||
error_message = $2,
|
||||
metadata = COALESCE(metadata, '{}'::jsonb) || $3::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE id = $4""",
|
||||
status, error_message, json.dumps(progress_data), document_id
|
||||
)
|
||||
else:
|
||||
await pg_client.execute_command(
|
||||
"""UPDATE documents SET
|
||||
processing_status = $1,
|
||||
metadata = COALESCE(metadata, '{}'::jsonb) || $2::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3""",
|
||||
status, json.dumps(progress_data), document_id
|
||||
)
|
||||
|
||||
async def _get_existing_document_content(self, document_id: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Get existing document content and storage type"""
|
||||
pg_client = await get_postgresql_client()
|
||||
result = await pg_client.fetch_one(
|
||||
"SELECT content_text, metadata FROM documents WHERE id = $1",
|
||||
document_id
|
||||
)
|
||||
if result and result["content_text"]:
|
||||
# Handle metadata - might be JSON string or dict
|
||||
metadata_raw = result["metadata"] or "{}"
|
||||
if isinstance(metadata_raw, str):
|
||||
import json
|
||||
try:
|
||||
metadata = json.loads(metadata_raw)
|
||||
except json.JSONDecodeError:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = metadata_raw or {}
|
||||
storage_type = metadata.get("storage_type", "unknown")
|
||||
return result["content_text"], storage_type
|
||||
return None, None
|
||||
|
||||
async def _update_document_content(self, document_id: str, content: str):
|
||||
"""Update document with extracted text content"""
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"UPDATE documents SET content_text = $1, updated_at = NOW() WHERE id = $2",
|
||||
content, document_id
|
||||
)
|
||||
|
||||
async def _update_chunk_count(self, document_id: str, chunk_count: int):
|
||||
"""Update document with final chunk count"""
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"UPDATE documents SET chunk_count = $1, updated_at = NOW() WHERE id = $2",
|
||||
chunk_count, document_id
|
||||
)
|
||||
|
||||
async def _generate_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
content: str,
|
||||
filename: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Generate and store AI summary for the document"""
|
||||
try:
|
||||
# Use tenant_domain from instance context
|
||||
tenant_domain = self.tenant_domain
|
||||
|
||||
# Create summarization service instance
|
||||
summarization_service = SummarizationService(tenant_domain, user_id)
|
||||
|
||||
# Generate summary using our new service
|
||||
summary = await summarization_service.generate_document_summary(
|
||||
document_id=document_id,
|
||||
document_content=content,
|
||||
document_name=filename
|
||||
)
|
||||
|
||||
if summary:
|
||||
logger.info(f"Generated summary for document {document_id}: {summary[:100]}...")
|
||||
else:
|
||||
logger.warning(f"Failed to generate summary for document {document_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating document summary for {document_id}: {e}")
|
||||
# Don't fail the entire document processing if summarization fails
|
||||
|
||||
async def _update_dataset_summary_after_document_change(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Update dataset summary after a document is added or removed"""
|
||||
try:
|
||||
# Create summarization service instance
|
||||
summarization_service = SummarizationService(self.tenant_domain, user_id)
|
||||
|
||||
# Update dataset summary asynchronously (don't block document processing)
|
||||
asyncio.create_task(
|
||||
summarization_service.update_dataset_summary_on_change(dataset_id)
|
||||
)
|
||||
|
||||
logger.info(f"Triggered dataset summary update for dataset {dataset_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering dataset summary update for {dataset_id}: {e}")
|
||||
# Don't fail document processing if dataset summary update fails
|
||||
|
||||
async def get_processing_status(self, document_id: str) -> Dict[str, Any]:
|
||||
"""Get current processing status of a document with progress information from metadata"""
|
||||
pg_client = await get_postgresql_client()
|
||||
result = await pg_client.fetch_one(
|
||||
"""SELECT processing_status, error_message, chunk_count, metadata
|
||||
FROM documents WHERE id = $1""",
|
||||
document_id
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise ValueError("Document not found")
|
||||
|
||||
# Extract progress data from metadata JSONB
|
||||
metadata = result["metadata"] or {}
|
||||
|
||||
return {
|
||||
"status": result["processing_status"],
|
||||
"error_message": result["error_message"],
|
||||
"chunk_count": result["chunk_count"],
|
||||
"chunks_processed": metadata.get("chunks_processed"),
|
||||
"total_chunks_expected": metadata.get("total_chunks_expected"),
|
||||
"processing_progress": metadata.get("processing_progress"),
|
||||
"processing_stage": metadata.get("processing_stage")
|
||||
}
|
||||
|
||||
|
||||
# Factory function for document processor
|
||||
async def get_document_processor(tenant_domain=None):
|
||||
"""Get document processor instance (will create its own DB session when needed)"""
|
||||
return DocumentProcessor(tenant_domain=tenant_domain)
|
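# Illustrative usage sketch (not part of the original module). It runs an upload
# through the full pipeline; the file path, dataset UUID, and user ID are
# hypothetical placeholders.
async def _example_process_upload() -> None:
    processor = await get_document_processor(tenant_domain="example-tenant")
    document = await processor.process_file(
        file_path=Path("/tmp/uploads/report.pdf"),  # hypothetical upload location
        dataset_id="00000000-0000-0000-0000-000000000000",
        user_id="user-123",
        original_filename="report.pdf",
    )
    status = await processor.get_processing_status(document["id"])
    logger.debug("processing status: %s", status)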
||||
317
apps/tenant-backend/app/services/document_summarizer.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Document Summarization Service for GT 2.0
|
||||
|
||||
Generates AI-powered summaries for uploaded documents using the Resource Cluster.
|
||||
Provides both quick summaries and detailed analysis for RAG visualization.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import get_db_session, execute_command, fetch_one
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentSummarizer:
|
||||
"""
|
||||
Service for generating document summaries using Resource Cluster LLM.
|
||||
|
||||
Features:
|
||||
- Quick document summaries (2-3 sentences)
|
||||
- Detailed analysis with key topics and themes
|
||||
- Metadata extraction (document type, language, etc.)
|
||||
- Integration with document processor workflow
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.resource_cluster_url = "http://gentwo-resource-backend:8000"
|
||||
self.max_content_length = 4000 # Max chars to send for summarization
|
||||
|
||||
async def generate_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
content: str,
|
||||
filename: str,
|
||||
tenant_domain: str,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a comprehensive summary for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document ID in the database
|
||||
content: Document text content
|
||||
filename: Original filename
|
||||
tenant_domain: Tenant domain for context
|
||||
user_id: User who uploaded the document
|
||||
|
||||
Returns:
|
||||
Dictionary with summary data including quick_summary, detailed_analysis,
|
||||
topics, metadata, and confidence scores
|
||||
"""
|
||||
try:
|
||||
# Truncate content if too long
|
||||
truncated_content = content[:self.max_content_length]
|
||||
if len(content) > self.max_content_length:
|
||||
truncated_content += "... [content truncated]"
|
||||
|
||||
# Generate summary using Resource Cluster LLM
|
||||
summary_data = await self._call_llm_for_summary(
|
||||
content=truncated_content,
|
||||
filename=filename,
|
||||
document_type=self._detect_document_type(filename)
|
||||
)
|
||||
|
||||
# Store summary in database
|
||||
await self._store_document_summary(
|
||||
document_id=document_id,
|
||||
summary_data=summary_data,
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
logger.info(f"Generated summary for document {document_id}: {filename}")
|
||||
return summary_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for document {document_id}: {e}")
|
||||
# Return basic fallback summary
|
||||
return {
|
||||
"quick_summary": f"Document: {filename}",
|
||||
"detailed_analysis": "Summary generation failed",
|
||||
"topics": [],
|
||||
"metadata": {
|
||||
"document_type": self._detect_document_type(filename),
|
||||
"estimated_read_time": len(content) // 200, # ~200 words per minute
|
||||
"character_count": len(content),
|
||||
"language": "unknown"
|
||||
},
|
||||
"confidence": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _call_llm_for_summary(
|
||||
self,
|
||||
content: str,
|
||||
filename: str,
|
||||
document_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Call Resource Cluster LLM to generate document summary"""
|
||||
|
||||
prompt = f"""Analyze this {document_type} document and provide a comprehensive summary.
|
||||
|
||||
Document: {filename}
|
||||
Content:
|
||||
{content}
|
||||
|
||||
Please provide:
|
||||
1. A concise 2-3 sentence summary
|
||||
2. Key topics and themes (list)
|
||||
3. Document analysis including tone, purpose, and target audience
|
||||
4. Estimated language and reading level
|
||||
|
||||
Format your response as JSON with these keys:
|
||||
- quick_summary: Brief 2-3 sentence overview
|
||||
- detailed_analysis: Paragraph with deeper insights
|
||||
- topics: Array of key topics/themes
|
||||
- metadata: Object with language, tone, purpose, target_audience
|
||||
- confidence: Float 0-1 indicating analysis confidence"""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
|
||||
json={
|
||||
"model": "llama-3.1-8b-instant",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a document analysis expert. Provide accurate, concise summaries in valid JSON format."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 1000
|
||||
},
|
||||
headers={
|
||||
"X-Tenant-ID": "default",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
llm_response = response.json()
|
||||
content_text = llm_response["choices"][0]["message"]["content"]
|
||||
|
||||
# Try to parse JSON response
|
||||
try:
|
||||
import json
|
||||
summary_data = json.loads(content_text)
|
||||
|
||||
# Validate required fields and add defaults if missing
|
||||
return {
|
||||
"quick_summary": summary_data.get("quick_summary", f"Analysis of {filename}"),
|
||||
"detailed_analysis": summary_data.get("detailed_analysis", "Detailed analysis not available"),
|
||||
"topics": summary_data.get("topics", []),
|
||||
"metadata": {
|
||||
**summary_data.get("metadata", {}),
|
||||
"document_type": document_type,
|
||||
"generated_at": datetime.utcnow().isoformat()
|
||||
},
|
||||
"confidence": min(1.0, max(0.0, summary_data.get("confidence", 0.7)))
|
||||
}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON
|
||||
return {
|
||||
"quick_summary": content_text[:200] + "..." if len(content_text) > 200 else content_text,
|
||||
"detailed_analysis": content_text,
|
||||
"topics": [],
|
||||
"metadata": {
|
||||
"document_type": document_type,
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"note": "Summary extracted from free-form LLM response"
|
||||
},
|
||||
"confidence": 0.5
|
||||
}
|
||||
else:
|
||||
raise Exception(f"Resource Cluster API error: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM summarization failed: {e}")
|
||||
raise
|
||||
|
||||
async def _store_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
summary_data: Dict[str, Any],
|
||||
tenant_domain: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Store generated summary in database"""
|
||||
|
||||
# Use the same database session pattern as document processor
|
||||
async with get_db_session(tenant_domain) as session:
|
||||
try:
|
||||
# Insert or update document summary
|
||||
query = """
|
||||
INSERT INTO document_summaries (
|
||||
document_id, user_id, quick_summary, detailed_analysis,
|
||||
topics, metadata, confidence, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (document_id)
|
||||
DO UPDATE SET
|
||||
quick_summary = EXCLUDED.quick_summary,
|
||||
detailed_analysis = EXCLUDED.detailed_analysis,
|
||||
topics = EXCLUDED.topics,
|
||||
metadata = EXCLUDED.metadata,
|
||||
confidence = EXCLUDED.confidence,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
"""
|
||||
|
||||
import json
|
||||
await execute_command(
|
||||
session,
|
||||
query,
|
||||
document_id,
|
||||
user_id,
|
||||
summary_data["quick_summary"],
|
||||
summary_data["detailed_analysis"],
|
||||
json.dumps(summary_data["topics"]),
|
||||
json.dumps(summary_data["metadata"]),
|
||||
summary_data["confidence"],
|
||||
datetime.utcnow(),
|
||||
datetime.utcnow()
|
||||
)
|
||||
|
||||
logger.info(f"Stored summary for document {document_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store document summary: {e}")
|
||||
raise
|
||||
|
||||
def _detect_document_type(self, filename: str) -> str:
|
||||
"""Detect document type from filename extension"""
|
||||
import pathlib
|
||||
|
||||
extension = pathlib.Path(filename).suffix.lower()
|
||||
|
||||
type_mapping = {
|
||||
'.pdf': 'PDF document',
|
||||
'.docx': 'Word document',
|
||||
'.doc': 'Word document',
|
||||
'.txt': 'Text file',
|
||||
'.md': 'Markdown document',
|
||||
'.csv': 'CSV data file',
|
||||
'.json': 'JSON data file',
|
||||
'.html': 'HTML document',
|
||||
'.htm': 'HTML document',
|
||||
'.rtf': 'Rich text document'
|
||||
}
|
||||
|
||||
return type_mapping.get(extension, 'Unknown document type')
|
||||
|
||||
async def get_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
tenant_domain: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve stored document summary"""
|
||||
|
||||
async with get_db_session(tenant_domain) as session:
|
||||
try:
|
||||
query = """
|
||||
SELECT quick_summary, detailed_analysis, topics, metadata,
|
||||
confidence, created_at, updated_at
|
||||
FROM document_summaries
|
||||
WHERE document_id = $1
|
||||
"""
|
||||
|
||||
result = await fetch_one(session, query, document_id)
|
||||
|
||||
if result:
|
||||
import json
|
||||
return {
|
||||
"quick_summary": result["quick_summary"],
|
||||
"detailed_analysis": result["detailed_analysis"],
|
||||
"topics": json.loads(result["topics"]) if result["topics"] else [],
|
||||
"metadata": json.loads(result["metadata"]) if result["metadata"] else {},
|
||||
"confidence": result["confidence"],
|
||||
"created_at": result["created_at"].isoformat(),
|
||||
"updated_at": result["updated_at"].isoformat()
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve document summary: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global instance
|
||||
document_summarizer = DocumentSummarizer()
|
||||
|
||||
|
||||
async def generate_document_summary(
|
||||
document_id: str,
|
||||
content: str,
|
||||
filename: str,
|
||||
tenant_domain: str,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Convenience function for document summary generation"""
|
||||
return await document_summarizer.generate_document_summary(
|
||||
document_id, content, filename, tenant_domain, user_id
|
||||
)
|
||||
|
||||
|
||||
async def get_document_summary(document_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
|
||||
"""Convenience function for retrieving document summary"""
|
||||
return await document_summarizer.get_document_summary(document_id, tenant_domain)
|
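# Illustrative usage sketch (not part of the original module). It summarizes an
# already-extracted document and reads the stored result back; the IDs, filename,
# and content are hypothetical placeholders.
async def _example_document_summary_usage() -> None:
    summary = await generate_document_summary(
        document_id="00000000-0000-0000-0000-000000000000",
        content="Quarterly revenue grew 12%, driven primarily by services...",
        filename="q3-report.pdf",
        tenant_domain="example-tenant",
        user_id="user-123",
    )
    stored = await get_document_summary(
        "00000000-0000-0000-0000-000000000000", "example-tenant"
    )
    logger.debug("quick summary: %s, stored copy present: %s",
                 summary.get("quick_summary"), bool(stored))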
||||
286
apps/tenant-backend/app/services/embedding_client.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
BGE-M3 Embedding Client for GT 2.0
|
||||
|
||||
Simple client for the vLLM BGE-M3 embedding service running on port 8005.
|
||||
Provides text embedding generation for RAG pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BGE_M3_EmbeddingClient:
|
||||
"""
|
||||
Simple client for BGE-M3 embedding service via vLLM.
|
||||
|
||||
Features:
|
||||
- Async HTTP client for embeddings
|
||||
- Batch processing support
|
||||
- Error handling and retries
|
||||
- OpenAI-compatible API format
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: Optional[str] = None):
|
||||
# Determine base URL from environment or configuration
|
||||
if base_url is None:
|
||||
base_url = self._get_embedding_endpoint()
|
||||
|
||||
self.base_url = base_url
|
||||
self.model = "BAAI/bge-m3"
|
||||
self.embedding_dimensions = 1024
|
||||
self.max_batch_size = 32
|
||||
|
||||
# Initialize BGE-M3 tokenizer for accurate token counting
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
|
||||
logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
|
||||
self.tokenizer = None
|
||||
|
||||
def _get_embedding_endpoint(self) -> str:
|
||||
"""
|
||||
Get the BGE-M3 endpoint based on configuration.
|
||||
This should sync with the control panel configuration.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Check environment variables for BGE-M3 configuration
|
||||
is_local_mode = os.getenv('BGE_M3_LOCAL_MODE', 'true').lower() == 'true'
|
||||
external_endpoint = os.getenv('BGE_M3_EXTERNAL_ENDPOINT')
|
||||
|
||||
if not is_local_mode and external_endpoint:
|
||||
return external_endpoint
|
||||
|
||||
# Default to local endpoint
|
||||
return os.getenv('EMBEDDING_ENDPOINT', 'http://host.docker.internal:8005')
|
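# Configuration sketch (assumed environment variables, matching the lookups above):
#   BGE_M3_LOCAL_MODE=false plus BGE_M3_EXTERNAL_ENDPOINT=https://embeddings.example.com
#     routes requests to an external embedding service, while
#   BGE_M3_LOCAL_MODE=true (the default) uses EMBEDDING_ENDPOINT, falling back to
#     http://host.docker.internal:8005 when it is unset.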
||||
|
||||
def update_endpoint(self, new_endpoint: str):
|
||||
"""Update the embedding endpoint dynamically"""
|
||||
self.base_url = new_endpoint
|
||||
logger.info(f"BGE-M3 client endpoint updated to: {new_endpoint}")
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if BGE-M3 service is responding"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(f"{self.base_url}/v1/models")
|
||||
if response.status_code == 200:
|
||||
models = response.json()
|
||||
model_ids = [model['id'] for model in models.get('data', [])]
|
||||
return self.model in model_ids
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
async def generate_embeddings(
|
||||
self,
|
||||
texts: List[str],
|
||||
tenant_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of texts using BGE-M3.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
tenant_id: Tenant ID for usage tracking (optional)
|
||||
user_id: User ID for usage tracking (optional)
|
||||
request_id: Request ID for tracking (optional)
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each is a list of 1024 floats)
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding generation fails
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
if len(texts) > self.max_batch_size:
|
||||
# Process in batches
|
||||
all_embeddings = []
|
||||
for i in range(0, len(texts), self.max_batch_size):
|
||||
batch = texts[i:i + self.max_batch_size]
|
||||
batch_embeddings = await self._generate_batch(batch)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
embeddings = all_embeddings
|
||||
else:
|
||||
embeddings = await self._generate_batch(texts)
|
||||
|
||||
# Log usage if tenant context provided (fire and forget)
|
||||
if tenant_id and user_id:
|
||||
import asyncio
|
||||
tokens_used = self._count_tokens(texts)
|
||||
asyncio.create_task(
|
||||
self._log_embedding_usage(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
tokens_used=tokens_used,
|
||||
embedding_count=len(embeddings),
|
||||
request_id=request_id
|
||||
)
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
    async def _generate_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a single batch"""
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    f"{self.base_url}/v1/embeddings",
                    json={
                        "input": texts,
                        "model": self.model
                    }
                )

                if response.status_code == 200:
                    data = response.json()
                    # Extract embeddings from OpenAI-compatible response
                    embeddings = []
                    for item in data.get("data", []):
                        embedding = item.get("embedding", [])
                        if len(embedding) != self.embedding_dimensions:
                            raise ValueError(f"Invalid embedding dimensions: {len(embedding)} (expected {self.embedding_dimensions})")
                        embeddings.append(embedding)

                    logger.info(f"Generated {len(embeddings)} embeddings")
                    return embeddings
                else:
                    error_text = response.text
                    logger.error(f"Embedding generation failed: {response.status_code} - {error_text}")
                    raise ValueError(f"Embedding generation failed: {response.status_code}")

        except httpx.TimeoutException:
            logger.error("Embedding generation timed out")
            raise ValueError("Embedding generation timed out")
        except Exception as e:
            logger.error(f"Error calling embedding service: {e}")
            raise ValueError(f"Embedding service error: {str(e)}")

    def _count_tokens(self, texts: List[str]) -> int:
        """Count tokens using actual BGE-M3 tokenizer."""
        if self.tokenizer is not None:
            try:
                total_tokens = 0
                for text in texts:
                    tokens = self.tokenizer.encode(text, add_special_tokens=False)
                    total_tokens += len(tokens)
                return total_tokens
            except Exception as e:
                logger.warning(f"Tokenizer error, falling back to estimation: {e}")

        # Fallback: word count * 1.3
        total_words = sum(len(text.split()) for text in texts)
        return int(total_words * 1.3)

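    # Worked example of the fallback estimate above (illustrative): two texts with
    # 120 and 80 whitespace-separated words give total_words = 200, so the estimate
    # is int(200 * 1.3) = 260 tokens when the BGE-M3 tokenizer is unavailable.
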
    async def _log_embedding_usage(
        self,
        tenant_id: str,
        user_id: str,
        tokens_used: int,
        embedding_count: int,
        request_id: Optional[str] = None
    ) -> None:
        """Log embedding usage to control panel database for billing."""
        try:
            import asyncpg
            import os

            # Calculate cost: BGE-M3 pricing ~$0.10 per million tokens
            cost_cents = (tokens_used / 1_000_000) * 0.10 * 100

            db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
            if not db_password:
                logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
                return

            conn = await asyncpg.connect(
                host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
                database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
                user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
                password=db_password,
                timeout=5.0
            )

            try:
                await conn.execute("""
                    INSERT INTO public.embedding_usage_logs
                    (tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
                    VALUES ($1, $2, $3, $4, $5, $6, $7)
                """, tenant_id, user_id, tokens_used, embedding_count, self.model, cost_cents, request_id)

                logger.info(
                    f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
                    f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
                )
            finally:
                await conn.close()

        except Exception as e:
            logger.warning(f"Failed to log embedding usage: {e}")

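    # Worked example of the cost formula above (illustrative): logging 2,500,000 tokens
    # gives (2_500_000 / 1_000_000) * 0.10 * 100 = 25.0 cost_cents, i.e. $0.25 at the
    # assumed ~$0.10 per million tokens.
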
    async def generate_single_embedding(self, text: str) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text string to embed

        Returns:
            Embedding vector (list of 1024 floats)
        """
        embeddings = await self.generate_embeddings([text])
        return embeddings[0] if embeddings else []


# Global client instance
_embedding_client: Optional[BGE_M3_EmbeddingClient] = None


def get_embedding_client() -> BGE_M3_EmbeddingClient:
    """Get or create global embedding client instance"""
    global _embedding_client
    if _embedding_client is None:
        _embedding_client = BGE_M3_EmbeddingClient()
    else:
        # Always refresh the endpoint from current configuration
        current_endpoint = _embedding_client._get_embedding_endpoint()
        if _embedding_client.base_url != current_endpoint:
            _embedding_client.base_url = current_endpoint
            logger.info(f"BGE-M3 client endpoint refreshed to: {current_endpoint}")
    return _embedding_client


async def test_embedding_client():
    """Test function for the embedding client"""
    client = get_embedding_client()

    # Test health check
    is_healthy = await client.health_check()
    print(f"BGE-M3 service healthy: {is_healthy}")

    if is_healthy:
        # Test embedding generation
        test_texts = [
            "This is a test document about machine learning.",
            "GT 2.0 is an enterprise AI platform.",
            "Vector embeddings enable semantic search."
        ]

        embeddings = await client.generate_embeddings(test_texts)
        print(f"Generated {len(embeddings)} embeddings")
        print(f"Embedding dimensions: {len(embeddings[0]) if embeddings else 0}")


if __name__ == "__main__":
    asyncio.run(test_embedding_client())

722
apps/tenant-backend/app/services/enhanced_api_keys.py
Normal file
@@ -0,0 +1,722 @@
"""
Enhanced API Key Management Service for GT 2.0

Implements advanced API key management with capability-based permissions,
configurable constraints, and comprehensive audit logging.
"""

import os
import stat
import json
import secrets
import hashlib
import logging
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, field
from enum import Enum
from uuid import uuid4
import jwt

from app.core.security import verify_capability_token

logger = logging.getLogger(__name__)


class APIKeyStatus(Enum):
    """API key status states"""
    ACTIVE = "active"
    SUSPENDED = "suspended"
    EXPIRED = "expired"
    REVOKED = "revoked"


class APIKeyScope(Enum):
    """API key scope levels"""
    USER = "user"      # User-specific operations
    TENANT = "tenant"  # Tenant-wide operations
    ADMIN = "admin"    # Administrative operations

@dataclass
|
||||
class APIKeyUsage:
|
||||
"""API key usage tracking"""
|
||||
requests_count: int = 0
|
||||
last_used: Optional[datetime] = None
|
||||
bytes_transferred: int = 0
|
||||
errors_count: int = 0
|
||||
rate_limit_hits: int = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
"requests_count": self.requests_count,
|
||||
"last_used": self.last_used.isoformat() if self.last_used else None,
|
||||
"bytes_transferred": self.bytes_transferred,
|
||||
"errors_count": self.errors_count,
|
||||
"rate_limit_hits": self.rate_limit_hits
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "APIKeyUsage":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
requests_count=data.get("requests_count", 0),
|
||||
last_used=datetime.fromisoformat(data["last_used"]) if data.get("last_used") else None,
|
||||
bytes_transferred=data.get("bytes_transferred", 0),
|
||||
errors_count=data.get("errors_count", 0),
|
||||
rate_limit_hits=data.get("rate_limit_hits", 0)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIKeyConfig:
|
||||
"""Enhanced API key configuration"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
owner_id: str = ""
|
||||
key_hash: str = ""
|
||||
|
||||
# Capability and permissions
|
||||
capabilities: List[str] = field(default_factory=list)
|
||||
scope: APIKeyScope = APIKeyScope.USER
|
||||
tenant_constraints: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Rate limiting and quotas
|
||||
rate_limit_per_hour: int = 1000
|
||||
daily_quota: int = 10000
|
||||
monthly_quota: int = 300000
|
||||
cost_limit_cents: int = 1000
|
||||
|
||||
# Resource constraints
|
||||
max_tokens_per_request: int = 4000
|
||||
max_concurrent_requests: int = 10
|
||||
allowed_endpoints: List[str] = field(default_factory=list)
|
||||
blocked_endpoints: List[str] = field(default_factory=list)
|
||||
|
||||
# Network and security
|
||||
allowed_ips: List[str] = field(default_factory=list)
|
||||
allowed_domains: List[str] = field(default_factory=list)
|
||||
require_tls: bool = True
|
||||
|
||||
# Lifecycle management
|
||||
status: APIKeyStatus = APIKeyStatus.ACTIVE
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = None
|
||||
last_rotated: Optional[datetime] = None
|
||||
|
||||
# Usage tracking
|
||||
usage: APIKeyUsage = field(default_factory=APIKeyUsage)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"owner_id": self.owner_id,
|
||||
"key_hash": self.key_hash,
|
||||
"capabilities": self.capabilities,
|
||||
"scope": self.scope.value,
|
||||
"tenant_constraints": self.tenant_constraints,
|
||||
"rate_limit_per_hour": self.rate_limit_per_hour,
|
||||
"daily_quota": self.daily_quota,
|
||||
"monthly_quota": self.monthly_quota,
|
||||
"cost_limit_cents": self.cost_limit_cents,
|
||||
"max_tokens_per_request": self.max_tokens_per_request,
|
||||
"max_concurrent_requests": self.max_concurrent_requests,
|
||||
"allowed_endpoints": self.allowed_endpoints,
|
||||
"blocked_endpoints": self.blocked_endpoints,
|
||||
"allowed_ips": self.allowed_ips,
|
||||
"allowed_domains": self.allowed_domains,
|
||||
"require_tls": self.require_tls,
|
||||
"status": self.status.value,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"last_rotated": self.last_rotated.isoformat() if self.last_rotated else None,
|
||||
"usage": self.usage.to_dict()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "APIKeyConfig":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
description=data.get("description", ""),
|
||||
owner_id=data["owner_id"],
|
||||
key_hash=data["key_hash"],
|
||||
capabilities=data.get("capabilities", []),
|
||||
scope=APIKeyScope(data.get("scope", "user")),
|
||||
tenant_constraints=data.get("tenant_constraints", {}),
|
||||
rate_limit_per_hour=data.get("rate_limit_per_hour", 1000),
|
||||
daily_quota=data.get("daily_quota", 10000),
|
||||
monthly_quota=data.get("monthly_quota", 300000),
|
||||
cost_limit_cents=data.get("cost_limit_cents", 1000),
|
||||
max_tokens_per_request=data.get("max_tokens_per_request", 4000),
|
||||
max_concurrent_requests=data.get("max_concurrent_requests", 10),
|
||||
allowed_endpoints=data.get("allowed_endpoints", []),
|
||||
blocked_endpoints=data.get("blocked_endpoints", []),
|
||||
allowed_ips=data.get("allowed_ips", []),
|
||||
allowed_domains=data.get("allowed_domains", []),
|
||||
require_tls=data.get("require_tls", True),
|
||||
status=APIKeyStatus(data.get("status", "active")),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
|
||||
last_rotated=datetime.fromisoformat(data["last_rotated"]) if data.get("last_rotated") else None,
|
||||
usage=APIKeyUsage.from_dict(data.get("usage", {}))
|
||||
)


class EnhancedAPIKeyService:
    """
    Enhanced API Key management service with advanced capabilities.

    Features:
    - Capability-based permissions with tenant constraints
    - Granular rate limiting and quota management
    - Network-based access controls (IP, domain restrictions)
    - Comprehensive usage tracking and analytics
    - Automated key rotation and lifecycle management
    - Perfect tenant isolation through file-based storage
    """

    def __init__(self, tenant_domain: str, signing_key: str = ""):
        self.tenant_domain = tenant_domain
        self.signing_key = signing_key or self._generate_signing_key()
        self.base_path = Path(f"/data/{tenant_domain}/api_keys")
        self.keys_path = self.base_path / "keys"
        self.usage_path = self.base_path / "usage"
        self.audit_path = self.base_path / "audit"

        # Ensure directories exist with proper permissions
        self._ensure_directories()

        logger.info(f"EnhancedAPIKeyService initialized for {tenant_domain}")

    def _ensure_directories(self):
        """Ensure API key directories exist with proper permissions"""
        for path in [self.keys_path, self.usage_path, self.audit_path]:
            path.mkdir(parents=True, exist_ok=True)
            # Set permissions to 700 (owner only)
            os.chmod(path, stat.S_IRWXU)

    def _generate_signing_key(self) -> str:
        """Generate cryptographic signing key for JWT tokens"""
        return secrets.token_urlsafe(64)

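    # End-to-end usage sketch (illustrative only; the tenant domain, owner, capability
    # strings, and admin_token below are hypothetical, and a real caller must hold a
    # valid capability token):
    #
    #     service = EnhancedAPIKeyService(tenant_domain="acme.example")
    #     config, raw_key = await service.create_api_key(
    #         name="ci-pipeline",
    #         owner_id="user-123",
    #         capabilities=["rag:search"],
    #         scope=APIKeyScope.USER,
    #         capability_token=admin_token,
    #     )
    #     ok, key_cfg, err = await service.validate_api_key(raw_key, endpoint="/v1/rag/search")
    #     if ok:
    #         jwt_token = await service.generate_capability_token(key_cfg)
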
async def create_api_key(
|
||||
self,
|
||||
name: str,
|
||||
owner_id: str,
|
||||
capabilities: List[str],
|
||||
scope: APIKeyScope = APIKeyScope.USER,
|
||||
expires_in_days: int = 90,
|
||||
constraints: Optional[Dict[str, Any]] = None,
|
||||
capability_token: str = ""
|
||||
) -> Tuple[APIKeyConfig, str]:
|
||||
"""
|
||||
Create a new API key with specified capabilities and constraints.
|
||||
|
||||
Args:
|
||||
name: Human-readable name for the key
|
||||
owner_id: User who owns the key
|
||||
capabilities: List of capability strings
|
||||
scope: Key scope level
|
||||
expires_in_days: Expiration time in days
|
||||
constraints: Custom constraints for the key
|
||||
capability_token: Admin capability token
|
||||
|
||||
Returns:
|
||||
Tuple of (APIKeyConfig, raw_key)
|
||||
"""
|
||||
# Verify admin capability for key creation
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Generate secure API key
|
||||
raw_key = f"gt2_{self.tenant_domain}_{secrets.token_urlsafe(32)}"
|
||||
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
# Apply constraints with tenant-specific defaults
|
||||
final_constraints = self._apply_tenant_defaults(constraints or {})
|
||||
|
||||
# Create API key configuration
|
||||
api_key = APIKeyConfig(
|
||||
name=name,
|
||||
owner_id=owner_id,
|
||||
key_hash=key_hash,
|
||||
capabilities=capabilities,
|
||||
scope=scope,
|
||||
tenant_constraints=final_constraints,
|
||||
expires_at=datetime.utcnow() + timedelta(days=expires_in_days)
|
||||
)
|
||||
|
||||
# Apply scope-based defaults
|
||||
self._apply_scope_defaults(api_key, scope)
|
||||
|
||||
# Store API key
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log creation
|
||||
await self._audit_log("api_key_created", owner_id, {
|
||||
"key_id": api_key.id,
|
||||
"name": name,
|
||||
"scope": scope.value,
|
||||
"capabilities": capabilities
|
||||
})
|
||||
|
||||
logger.info(f"Created API key: {name} ({api_key.id}) for {owner_id}")
|
||||
return api_key, raw_key
|
||||
|
||||
async def validate_api_key(
|
||||
self,
|
||||
raw_key: str,
|
||||
endpoint: str = "",
|
||||
client_ip: str = "",
|
||||
user_agent: str = ""
|
||||
) -> Tuple[bool, Optional[APIKeyConfig], Optional[str]]:
|
||||
"""
|
||||
Validate API key and check constraints.
|
||||
|
||||
Args:
|
||||
raw_key: Raw API key from request
|
||||
endpoint: Requested endpoint
|
||||
client_ip: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Tuple of (valid, api_key_config, error_message)
|
||||
"""
|
||||
# Hash the key for lookup
|
||||
# Security Note: SHA256 is used here for API key lookup/indexing, not password storage.
|
||||
# API keys are high-entropy random strings, making them resistant to dictionary/rainbow attacks.
|
||||
# This is an acceptable security pattern similar to how GitHub and Stripe handle API keys.
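# Note (added for clarity): the raw key embeds secrets.token_urlsafe(32), i.e. 32 random
# bytes (~256 bits of entropy), which is what makes the unsalted SHA-256 lookup above a
# reasonable trade-off for indexing rather than password storage.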
|
||||
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
# Load API key configuration
|
||||
api_key = await self._load_api_key_by_hash(key_hash)
|
||||
if not api_key:
|
||||
return False, None, "Invalid API key"
|
||||
|
||||
# Check key status
|
||||
if api_key.status != APIKeyStatus.ACTIVE:
|
||||
return False, api_key, f"API key is {api_key.status.value}"
|
||||
|
||||
# Check expiration
|
||||
if api_key.expires_at and datetime.utcnow() > api_key.expires_at:
|
||||
# Auto-expire the key
|
||||
api_key.status = APIKeyStatus.EXPIRED
|
||||
await self._store_api_key(api_key)
|
||||
return False, api_key, "API key has expired"
|
||||
|
||||
# Check endpoint restrictions
|
||||
if api_key.allowed_endpoints:
|
||||
if endpoint not in api_key.allowed_endpoints:
|
||||
return False, api_key, f"Endpoint {endpoint} not allowed"
|
||||
|
||||
if endpoint in api_key.blocked_endpoints:
|
||||
return False, api_key, f"Endpoint {endpoint} is blocked"
|
||||
|
||||
# Check IP restrictions
|
||||
if api_key.allowed_ips and client_ip not in api_key.allowed_ips:
|
||||
return False, api_key, f"IP {client_ip} not allowed"
|
||||
|
||||
# Check rate limits
|
||||
rate_limit_ok, rate_error = await self._check_rate_limits(api_key)
|
||||
if not rate_limit_ok:
|
||||
return False, api_key, rate_error
|
||||
|
||||
# Update usage
|
||||
await self._update_usage(api_key, endpoint, client_ip)
|
||||
|
||||
return True, api_key, None
|
||||
|
||||
async def generate_capability_token(
|
||||
self,
|
||||
api_key: APIKeyConfig,
|
||||
additional_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate JWT capability token from API key.
|
||||
|
||||
Args:
|
||||
api_key: API key configuration
|
||||
additional_context: Additional context for the token
|
||||
|
||||
Returns:
|
||||
JWT capability token
|
||||
"""
|
||||
# Build capability payload
|
||||
capabilities = []
|
||||
for cap_string in api_key.capabilities:
|
||||
capability = {
|
||||
"resource": cap_string,
|
||||
"actions": ["*"], # API keys get full action access for their capabilities
|
||||
"constraints": api_key.tenant_constraints.get(cap_string, {})
|
||||
}
|
||||
capabilities.append(capability)
|
||||
|
||||
# Create JWT payload
|
||||
payload = {
|
||||
"sub": api_key.owner_id,
|
||||
"tenant_id": self.tenant_domain,
|
||||
"api_key_id": api_key.id,
|
||||
"scope": api_key.scope.value,
|
||||
"capabilities": capabilities,
|
||||
"constraints": api_key.tenant_constraints,
|
||||
"rate_limits": {
|
||||
"requests_per_hour": api_key.rate_limit_per_hour,
|
||||
"max_tokens_per_request": api_key.max_tokens_per_request,
|
||||
"cost_limit_cents": api_key.cost_limit_cents
|
||||
},
|
||||
"iat": int(datetime.utcnow().timestamp()),
|
||||
"exp": int((datetime.utcnow() + timedelta(hours=1)).timestamp())
|
||||
}
|
||||
|
||||
# Add additional context
|
||||
if additional_context:
|
||||
payload.update(additional_context)
|
||||
|
||||
# Sign and return token
|
||||
token = jwt.encode(payload, self.signing_key, algorithm="HS256")
|
||||
return token
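# Verification sketch (illustrative, not part of the original file): a consumer holding
# the same signing key can validate the issued token with PyJWT, e.g.
#
#     claims = jwt.decode(token, self.signing_key, algorithms=["HS256"])
#     assert claims["tenant_id"] == self.tenant_domain
#
# The "exp" claim set above is checked automatically by jwt.decode.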
|
||||
|
||||
async def rotate_api_key(
|
||||
self,
|
||||
key_id: str,
|
||||
owner_id: str,
|
||||
capability_token: str
|
||||
) -> Tuple[APIKeyConfig, str]:
|
||||
"""
|
||||
Rotate API key (generate new key value).
|
||||
|
||||
Args:
|
||||
key_id: API key ID to rotate
|
||||
owner_id: Owner of the key
|
||||
capability_token: Admin capability token
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_config, new_raw_key)
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load existing key
|
||||
api_key = await self._load_api_key(key_id)
|
||||
if not api_key:
|
||||
raise ValueError("API key not found")
|
||||
|
||||
# Verify ownership
|
||||
if api_key.owner_id != owner_id:
|
||||
raise PermissionError("Only key owner can rotate")
|
||||
|
||||
# Generate new key
|
||||
new_raw_key = f"gt2_{self.tenant_domain}_{secrets.token_urlsafe(32)}"
|
||||
new_key_hash = hashlib.sha256(new_raw_key.encode()).hexdigest()
|
||||
|
||||
# Update configuration
|
||||
api_key.key_hash = new_key_hash
|
||||
api_key.last_rotated = datetime.utcnow()
|
||||
api_key.updated_at = datetime.utcnow()
|
||||
|
||||
# Store updated key
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log rotation
|
||||
await self._audit_log("api_key_rotated", owner_id, {
|
||||
"key_id": key_id,
|
||||
"name": api_key.name
|
||||
})
|
||||
|
||||
logger.info(f"Rotated API key: {api_key.name} ({key_id})")
|
||||
return api_key, new_raw_key
|
||||
|
||||
async def revoke_api_key(
|
||||
self,
|
||||
key_id: str,
|
||||
owner_id: str,
|
||||
capability_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke API key (mark as revoked).
|
||||
|
||||
Args:
|
||||
key_id: API key ID to revoke
|
||||
owner_id: Owner of the key
|
||||
capability_token: Admin capability token
|
||||
|
||||
Returns:
|
||||
True if revoked successfully
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load and verify key
|
||||
api_key = await self._load_api_key(key_id)
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
if api_key.owner_id != owner_id:
|
||||
raise PermissionError("Only key owner can revoke")
|
||||
|
||||
# Revoke key
|
||||
api_key.status = APIKeyStatus.REVOKED
|
||||
api_key.updated_at = datetime.utcnow()
|
||||
|
||||
# Store updated key
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log revocation
|
||||
await self._audit_log("api_key_revoked", owner_id, {
|
||||
"key_id": key_id,
|
||||
"name": api_key.name
|
||||
})
|
||||
|
||||
logger.info(f"Revoked API key: {api_key.name} ({key_id})")
|
||||
return True
|
||||
|
||||
async def list_user_api_keys(
|
||||
self,
|
||||
owner_id: str,
|
||||
capability_token: str,
|
||||
include_usage: bool = True
|
||||
) -> List[APIKeyConfig]:
|
||||
"""
|
||||
List API keys for a user.
|
||||
|
||||
Args:
|
||||
owner_id: User to get keys for
|
||||
capability_token: User capability token
|
||||
include_usage: Include usage statistics
|
||||
|
||||
Returns:
|
||||
List of API key configurations
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
user_keys = []
|
||||
|
||||
# Load all keys and filter by owner
|
||||
if self.keys_path.exists():
|
||||
for key_file in self.keys_path.glob("*.json"):
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
data = json.load(f)
|
||||
if data.get("owner_id") == owner_id:
|
||||
api_key = APIKeyConfig.from_dict(data)
|
||||
|
||||
# Update usage if requested
|
||||
if include_usage:
|
||||
await self._update_key_usage_stats(api_key)
|
||||
|
||||
user_keys.append(api_key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading key file {key_file}: {e}")
|
||||
|
||||
return sorted(user_keys, key=lambda k: k.created_at, reverse=True)
|
||||
|
||||
async def get_usage_analytics(
|
||||
self,
|
||||
owner_id: str,
|
||||
key_id: Optional[str] = None,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get usage analytics for API keys.
|
||||
|
||||
Args:
|
||||
owner_id: Owner of the keys
|
||||
key_id: Specific key ID (optional)
|
||||
days: Number of days to analyze
|
||||
|
||||
Returns:
|
||||
Usage analytics data
|
||||
"""
|
||||
analytics = {
|
||||
"total_requests": 0,
|
||||
"total_errors": 0,
|
||||
"avg_requests_per_day": 0,
|
||||
"most_used_endpoints": [],
|
||||
"rate_limit_hits": 0,
|
||||
"keys_analyzed": 0,
|
||||
"date_range": {
|
||||
"start": (datetime.utcnow() - timedelta(days=days)).isoformat(),
|
||||
"end": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Get user's keys
|
||||
user_keys = await self.list_user_api_keys(owner_id, "", include_usage=True)
|
||||
|
||||
# Filter by specific key if requested
|
||||
if key_id:
|
||||
user_keys = [key for key in user_keys if key.id == key_id]
|
||||
|
||||
# Aggregate usage data
|
||||
for api_key in user_keys:
|
||||
analytics["total_requests"] += api_key.usage.requests_count
|
||||
analytics["total_errors"] += api_key.usage.errors_count
|
||||
analytics["rate_limit_hits"] += api_key.usage.rate_limit_hits
|
||||
analytics["keys_analyzed"] += 1
|
||||
|
||||
# Calculate averages
|
||||
if days > 0:
|
||||
analytics["avg_requests_per_day"] = analytics["total_requests"] / days
|
||||
|
||||
return analytics
|
||||
|
||||
def _apply_tenant_defaults(self, constraints: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply tenant-specific default constraints"""
|
||||
defaults = {
|
||||
"max_automation_chain_depth": 5,
|
||||
"mcp_memory_limit_mb": 512,
|
||||
"mcp_timeout_seconds": 30,
|
||||
"max_file_size_bytes": 10 * 1024 * 1024, # 10MB
|
||||
"allowed_file_types": [".pdf", ".txt", ".md", ".json", ".csv"],
|
||||
"enable_premium_features": False
|
||||
}
|
||||
|
||||
# Merge with provided constraints (provided values take precedence)
|
||||
final_constraints = defaults.copy()
|
||||
final_constraints.update(constraints)
|
||||
|
||||
return final_constraints
|
||||
|
||||
def _apply_scope_defaults(self, api_key: APIKeyConfig, scope: APIKeyScope):
|
||||
"""Apply scope-based default limits"""
|
||||
if scope == APIKeyScope.USER:
|
||||
api_key.rate_limit_per_hour = 1000
|
||||
api_key.daily_quota = 10000
|
||||
api_key.cost_limit_cents = 1000
|
||||
elif scope == APIKeyScope.TENANT:
|
||||
api_key.rate_limit_per_hour = 5000
|
||||
api_key.daily_quota = 50000
|
||||
api_key.cost_limit_cents = 5000
|
||||
elif scope == APIKeyScope.ADMIN:
|
||||
api_key.rate_limit_per_hour = 10000
|
||||
api_key.daily_quota = 100000
|
||||
api_key.cost_limit_cents = 10000
|
||||
|
||||
async def _check_rate_limits(self, api_key: APIKeyConfig) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if API key is within rate limits"""
|
||||
# For now, implement basic hourly check
|
||||
# In production, would check against usage tracking database
|
||||
|
||||
current_hour = datetime.utcnow().replace(minute=0, second=0, microsecond=0)
|
||||
|
||||
# Load hourly usage (mock implementation)
|
||||
hourly_usage = 0 # Would query actual usage data
|
||||
|
||||
if hourly_usage >= api_key.rate_limit_per_hour:
|
||||
api_key.usage.rate_limit_hits += 1
|
||||
await self._store_api_key(api_key)
|
||||
return False, f"Rate limit exceeded: {hourly_usage}/{api_key.rate_limit_per_hour} requests per hour"
|
||||
|
||||
return True, None
|
||||
|
||||
async def _update_usage(self, api_key: APIKeyConfig, endpoint: str, client_ip: str):
|
||||
"""Update API key usage statistics"""
|
||||
api_key.usage.requests_count += 1
|
||||
api_key.usage.last_used = datetime.utcnow()
|
||||
|
||||
# Store updated usage
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log detailed usage (for analytics)
|
||||
await self._log_usage(api_key.id, endpoint, client_ip)
|
||||
|
||||
async def _store_api_key(self, api_key: APIKeyConfig):
|
||||
"""Store API key configuration to file system"""
|
||||
key_file = self.keys_path / f"{api_key.id}.json"
|
||||
|
||||
with open(key_file, "w") as f:
|
||||
json.dump(api_key.to_dict(), f, indent=2)
|
||||
|
||||
# Set secure permissions
|
||||
os.chmod(key_file, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
async def _load_api_key(self, key_id: str) -> Optional[APIKeyConfig]:
|
||||
"""Load API key configuration by ID"""
|
||||
key_file = self.keys_path / f"{key_id}.json"
|
||||
|
||||
if not key_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
data = json.load(f)
|
||||
return APIKeyConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading API key {key_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _load_api_key_by_hash(self, key_hash: str) -> Optional[APIKeyConfig]:
|
||||
"""Load API key configuration by hash"""
|
||||
if not self.keys_path.exists():
|
||||
return None
|
||||
|
||||
for key_file in self.keys_path.glob("*.json"):
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
data = json.load(f)
|
||||
if data.get("key_hash") == key_hash:
|
||||
return APIKeyConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading key file {key_file}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _update_key_usage_stats(self, api_key: APIKeyConfig):
|
||||
"""Update comprehensive usage statistics for a key"""
|
||||
# In production, would aggregate from detailed usage logs
|
||||
# For now, use existing basic stats
|
||||
pass
|
||||
|
||||
async def _log_usage(self, key_id: str, endpoint: str, client_ip: str):
|
||||
"""Log detailed API key usage for analytics"""
|
||||
usage_record = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"key_id": key_id,
|
||||
"endpoint": endpoint,
|
||||
"client_ip": client_ip,
|
||||
"tenant": self.tenant_domain
|
||||
}
|
||||
|
||||
# Store in daily usage file
|
||||
date_str = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
usage_file = self.usage_path / f"usage_{date_str}.jsonl"
|
||||
|
||||
with open(usage_file, "a") as f:
|
||||
f.write(json.dumps(usage_record) + "\n")
|
||||
|
||||
async def _audit_log(self, action: str, user_id: str, details: Dict[str, Any]):
|
||||
"""Log API key management actions for audit"""
|
||||
audit_record = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"action": action,
|
||||
"user_id": user_id,
|
||||
"tenant": self.tenant_domain,
|
||||
"details": details
|
||||
}
|
||||
|
||||
# Store in daily audit file
|
||||
date_str = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
audit_file = self.audit_path / f"audit_{date_str}.jsonl"
|
||||
|
||||
with open(audit_file, "a") as f:
|
||||
f.write(json.dumps(audit_record) + "\n")
|
||||
635
apps/tenant-backend/app/services/event_bus.py
Normal file
@@ -0,0 +1,635 @@
"""
Tenant Event Bus System

Implements event-driven architecture for automation triggers with perfect
tenant isolation and capability-based execution.
"""

import asyncio
import logging
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from uuid import uuid4
from enum import Enum
import json
from pathlib import Path

from app.core.path_security import sanitize_tenant_domain

logger = logging.getLogger(__name__)


class TriggerType(Enum):
    """Types of automation triggers"""
    CRON = "cron"        # Time-based
    WEBHOOK = "webhook"  # External HTTP
    EVENT = "event"      # Internal events
    CHAIN = "chain"      # Triggered by other automations
    MANUAL = "manual"    # User-initiated


# Event catalog with required fields
EVENT_CATALOG = {
    "document.uploaded": ["document_id", "dataset_id", "filename"],
    "document.processed": ["document_id", "chunks_created"],
    "agent.created": ["agent_id", "name", "owner_id"],
    "chat.started": ["conversation_id", "agent_id"],
    "resource.shared": ["resource_id", "access_group", "shared_with"],
    "quota.warning": ["resource_type", "current_usage", "limit"],
    "automation.completed": ["automation_id", "result", "duration_ms"],
    "automation.failed": ["automation_id", "error", "retry_count"]
}

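# Usage sketch (illustrative; the bus instance and IDs below are hypothetical): emitting
# a catalogued event supplies the required fields listed above, e.g.
#
#     bus = TenantEventBus("acme.example")
#     await bus.emit_event(
#         "document.uploaded",
#         user_id="user-123",
#         data={"document_id": "doc-1", "dataset_id": "ds-1", "filename": "report.pdf"},
#     )
#
# Unknown event types are only logged as a warning by emit_event, not rejected.
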
@dataclass
|
||||
class Event:
|
||||
"""Event data structure"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
type: str = ""
|
||||
tenant: str = ""
|
||||
user: str = ""
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert event to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.type,
|
||||
"tenant": self.tenant,
|
||||
"user": self.user,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"data": self.data,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Event":
|
||||
"""Create event from dictionary"""
|
||||
return cls(
|
||||
id=data.get("id", str(uuid4())),
|
||||
type=data.get("type", ""),
|
||||
tenant=data.get("tenant", ""),
|
||||
user=data.get("user", ""),
|
||||
timestamp=datetime.fromisoformat(data.get("timestamp", datetime.utcnow().isoformat())),
|
||||
data=data.get("data", {}),
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Automation:
|
||||
"""Automation configuration"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
name: str = ""
|
||||
owner_id: str = ""
|
||||
trigger_type: TriggerType = TriggerType.MANUAL
|
||||
trigger_config: Dict[str, Any] = field(default_factory=dict)
|
||||
conditions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
actions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
triggers_chain: bool = False
|
||||
chain_targets: List[str] = field(default_factory=list)
|
||||
max_retries: int = 3
|
||||
timeout_seconds: int = 300
|
||||
is_active: bool = True
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def matches_event(self, event: Event) -> bool:
|
||||
"""Check if automation should trigger for event"""
|
||||
if not self.is_active:
|
||||
return False
|
||||
|
||||
if self.trigger_type != TriggerType.EVENT:
|
||||
return False
|
||||
|
||||
# Check event type matches
|
||||
event_types = self.trigger_config.get("event_types", [])
|
||||
if event.type not in event_types:
|
||||
return False
|
||||
|
||||
# Check conditions
|
||||
for condition in self.conditions:
|
||||
if not self._evaluate_condition(condition, event):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _evaluate_condition(self, condition: Dict[str, Any], event: Event) -> bool:
|
||||
"""Evaluate a single condition"""
|
||||
field = condition.get("field")
|
||||
operator = condition.get("operator")
|
||||
value = condition.get("value")
|
||||
|
||||
# Get field value from event
|
||||
if "." in field:
|
||||
parts = field.split(".")
|
||||
# Handle data.field paths by starting from the event object
|
||||
if parts[0] == "data":
|
||||
event_value = event.data
|
||||
parts = parts[1:] # Skip the "data" part
|
||||
else:
|
||||
event_value = event
|
||||
|
||||
for part in parts:
|
||||
if isinstance(event_value, dict):
|
||||
event_value = event_value.get(part)
|
||||
elif hasattr(event_value, part):
|
||||
event_value = getattr(event_value, part)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
event_value = getattr(event, field, None)
|
||||
|
||||
# Evaluate condition
|
||||
if operator == "equals":
|
||||
return event_value == value
|
||||
elif operator == "not_equals":
|
||||
return event_value != value
|
||||
elif operator == "contains":
|
||||
return value in str(event_value)
|
||||
elif operator == "greater_than":
|
||||
return float(event_value) > float(value)
|
||||
elif operator == "less_than":
|
||||
return float(event_value) < float(value)
|
||||
else:
|
||||
return False


class TenantEventBus:
    """
    Event system for automation triggers with tenant isolation.

    Features:
    - Perfect tenant isolation through file-based storage
    - Event persistence and replay capability
    - Automation matching and triggering
    - Access control for automation execution
    """

    def __init__(self, tenant_domain: str, base_path: Optional[Path] = None):
        self.tenant_domain = tenant_domain
        # Sanitize tenant_domain to prevent path traversal
        safe_tenant = sanitize_tenant_domain(tenant_domain)
        self.base_path = base_path or (Path("/data") / safe_tenant / "events")
        self.event_store_path = self.base_path / "store"
        self.automations_path = self.base_path / "automations"
        self.event_handlers: Dict[str, List[Callable]] = {}
        self.running_automations: Dict[str, asyncio.Task] = {}

        # Ensure directories exist with proper permissions
        self._ensure_directories()

        logger.info(f"TenantEventBus initialized for {tenant_domain}")

    def _ensure_directories(self):
        """Ensure event directories exist with proper permissions"""
        import os
        import stat

        for path in [self.event_store_path, self.automations_path]:
            path.mkdir(parents=True, exist_ok=True)
            # Set permissions to 700 (owner only)
            # codeql[py/path-injection] paths derived from sanitize_tenant_domain() at line 175
            os.chmod(path, stat.S_IRWXU)

async def emit_event(
|
||||
self,
|
||||
event_type: str,
|
||||
user_id: str,
|
||||
data: Dict[str, Any],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Event:
|
||||
"""
|
||||
Emit an event and trigger matching automations.
|
||||
|
||||
Args:
|
||||
event_type: Type of event from EVENT_CATALOG
|
||||
user_id: User who triggered the event
|
||||
data: Event data
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
Created event
|
||||
"""
|
||||
# Validate event type
|
||||
if event_type not in EVENT_CATALOG:
|
||||
logger.warning(f"Unknown event type: {event_type}")
|
||||
|
||||
# Create event
|
||||
event = Event(
|
||||
type=event_type,
|
||||
tenant=self.tenant_domain,
|
||||
user=user_id,
|
||||
data=data,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# Store event
|
||||
await self._store_event(event)
|
||||
|
||||
# Find matching automations
|
||||
automations = await self._find_matching_automations(event)
|
||||
|
||||
# Trigger automations with access control
|
||||
for automation in automations:
|
||||
if await self._can_trigger(user_id, automation):
|
||||
asyncio.create_task(self._trigger_automation(automation, event))
|
||||
|
||||
# Call registered handlers
|
||||
if event_type in self.event_handlers:
|
||||
for handler in self.event_handlers[event_type]:
|
||||
asyncio.create_task(handler(event))
|
||||
|
||||
logger.info(f"Event emitted: {event_type} by {user_id}")
|
||||
return event
|
||||
|
||||
async def _store_event(self, event: Event):
|
||||
"""Store event to file system"""
|
||||
# Create daily event file
|
||||
date_str = event.timestamp.strftime("%Y-%m-%d")
|
||||
event_file = self.event_store_path / f"events_{date_str}.jsonl"
|
||||
|
||||
# Append event as JSON line
|
||||
with open(event_file, "a") as f:
|
||||
f.write(json.dumps(event.to_dict()) + "\n")
|
||||
|
||||
async def _find_matching_automations(self, event: Event) -> List[Automation]:
|
||||
"""Find automations that match the event"""
|
||||
matching = []
|
||||
|
||||
# Load all automations from file system
|
||||
if self.automations_path.exists():
|
||||
for automation_file in self.automations_path.glob("*.json"):
|
||||
try:
|
||||
with open(automation_file, "r") as f:
|
||||
automation_data = json.load(f)
|
||||
# Convert stored strings back to enum/datetime before matching (mirrors get_automation)
automation_data["trigger_type"] = TriggerType(automation_data["trigger_type"])
automation_data["created_at"] = datetime.fromisoformat(automation_data["created_at"])
automation_data["updated_at"] = datetime.fromisoformat(automation_data["updated_at"])
automation = Automation(**automation_data)
|
||||
|
||||
if automation.matches_event(event):
|
||||
matching.append(automation)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading automation {automation_file}: {e}")
|
||||
|
||||
return matching
|
||||
|
||||
async def _can_trigger(self, user_id: str, automation: Automation) -> bool:
|
||||
"""Check if user can trigger automation"""
|
||||
# Owner can always trigger their automations
|
||||
if automation.owner_id == user_id:
|
||||
return True
|
||||
|
||||
# Check if automation is public or shared
|
||||
# This would integrate with AccessController
|
||||
# For now, only owner can trigger
|
||||
return False
|
||||
|
||||
async def _trigger_automation(self, automation: Automation, event: Event):
|
||||
"""Trigger automation execution"""
|
||||
try:
|
||||
# Check if automation is already running
|
||||
if automation.id in self.running_automations:
|
||||
logger.info(f"Automation {automation.id} already running, skipping")
|
||||
return
|
||||
|
||||
# Create task for automation execution
|
||||
task = asyncio.create_task(
|
||||
self._execute_automation(automation, event)
|
||||
)
|
||||
self.running_automations[automation.id] = task
|
||||
|
||||
# Wait for completion with timeout
|
||||
await asyncio.wait_for(task, timeout=automation.timeout_seconds)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Automation {automation.id} timed out")
|
||||
await self.emit_event(
|
||||
"automation.failed",
|
||||
automation.owner_id,
|
||||
{
|
||||
"automation_id": automation.id,
|
||||
"error": "Timeout",
|
||||
"retry_count": 0
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering automation {automation.id}: {e}")
|
||||
await self.emit_event(
|
||||
"automation.failed",
|
||||
automation.owner_id,
|
||||
{
|
||||
"automation_id": automation.id,
|
||||
"error": str(e),
|
||||
"retry_count": 0
|
||||
}
|
||||
)
|
||||
finally:
|
||||
# Remove from running automations
|
||||
self.running_automations.pop(automation.id, None)
|
||||
|
||||
async def _execute_automation(self, automation: Automation, event: Event) -> Any:
|
||||
"""Execute automation actions"""
|
||||
start_time = datetime.utcnow()
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Execute each action in sequence
|
||||
for action in automation.actions:
|
||||
result = await self._execute_action(action, event, automation)
|
||||
results.append(result)
|
||||
|
||||
# Calculate duration
|
||||
duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
# Emit completion event
|
||||
await self.emit_event(
|
||||
"automation.completed",
|
||||
automation.owner_id,
|
||||
{
|
||||
"automation_id": automation.id,
|
||||
"result": results,
|
||||
"duration_ms": duration_ms
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing automation {automation.id}: {e}")
|
||||
raise
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
automation: Automation
|
||||
) -> Any:
|
||||
"""Execute a single action"""
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "webhook":
|
||||
return await self._execute_webhook_action(action, event)
|
||||
elif action_type == "email":
|
||||
return await self._execute_email_action(action, event)
|
||||
elif action_type == "log":
|
||||
return await self._execute_log_action(action, event)
|
||||
elif action_type == "chain":
|
||||
return await self._execute_chain_action(action, event, automation)
|
||||
else:
|
||||
logger.warning(f"Unknown action type: {action_type}")
|
||||
return None
|
||||
|
||||
async def _execute_webhook_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute webhook action (mock implementation)"""
|
||||
url = action.get("url")
|
||||
method = action.get("method", "POST")
|
||||
headers = action.get("headers", {})
|
||||
body = action.get("body", event.to_dict())
|
||||
|
||||
logger.info(f"Mock webhook call to {url}")
|
||||
|
||||
# In production, use httpx or aiohttp to make actual HTTP request
|
||||
return {
|
||||
"status": "success",
|
||||
"url": url,
|
||||
"method": method,
|
||||
"mock": True
|
||||
}
|
||||
|
||||
async def _execute_email_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute email action (mock implementation)"""
|
||||
to = action.get("to")
|
||||
subject = action.get("subject")
|
||||
body = action.get("body")
|
||||
|
||||
logger.info(f"Mock email to {to}: {subject}")
|
||||
|
||||
# In production, integrate with email service
|
||||
return {
|
||||
"status": "success",
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"mock": True
|
||||
}
|
||||
|
||||
async def _execute_log_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute log action"""
|
||||
message = action.get("message", f"Event: {event.type}")
|
||||
level = action.get("level", "info")
|
||||
|
||||
if level == "debug":
|
||||
logger.debug(message)
|
||||
elif level == "info":
|
||||
logger.info(message)
|
||||
elif level == "warning":
|
||||
logger.warning(message)
|
||||
elif level == "error":
|
||||
logger.error(message)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": message,
|
||||
"level": level
|
||||
}
|
||||
|
||||
async def _execute_chain_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
automation: Automation
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute chain action to trigger other automations"""
|
||||
target_automation_id = action.get("target_automation_id")
|
||||
|
||||
if not target_automation_id:
|
||||
return {"status": "error", "message": "No target automation specified"}
|
||||
|
||||
# Emit chain event
|
||||
chain_event = await self.emit_event(
|
||||
"automation.chain",
|
||||
automation.owner_id,
|
||||
{
|
||||
"source_automation": automation.id,
|
||||
"target_automation": target_automation_id,
|
||||
"original_event": event.to_dict()
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"chain_event_id": chain_event.id,
|
||||
"target_automation": target_automation_id
|
||||
}
|
||||
|
||||
def register_handler(self, event_type: str, handler: Callable):
|
||||
"""Register an event handler"""
|
||||
if event_type not in self.event_handlers:
|
||||
self.event_handlers[event_type] = []
|
||||
self.event_handlers[event_type].append(handler)
|
||||
|
||||
async def create_automation(
|
||||
self,
|
||||
name: str,
|
||||
owner_id: str,
|
||||
trigger_type: TriggerType,
|
||||
trigger_config: Dict[str, Any],
|
||||
actions: List[Dict[str, Any]],
|
||||
conditions: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Automation:
|
||||
"""Create and save a new automation"""
|
||||
automation = Automation(
|
||||
name=name,
|
||||
owner_id=owner_id,
|
||||
trigger_type=trigger_type,
|
||||
trigger_config=trigger_config,
|
||||
actions=actions,
|
||||
conditions=conditions or []
|
||||
)
|
||||
|
||||
# Save to file system
|
||||
automation_file = self.automations_path / f"{automation.id}.json"
|
||||
with open(automation_file, "w") as f:
|
||||
json.dump({
|
||||
"id": automation.id,
|
||||
"name": automation.name,
|
||||
"owner_id": automation.owner_id,
|
||||
"trigger_type": automation.trigger_type.value,
|
||||
"trigger_config": automation.trigger_config,
|
||||
"conditions": automation.conditions,
|
||||
"actions": automation.actions,
|
||||
"triggers_chain": automation.triggers_chain,
|
||||
"chain_targets": automation.chain_targets,
|
||||
"max_retries": automation.max_retries,
|
||||
"timeout_seconds": automation.timeout_seconds,
|
||||
"is_active": automation.is_active,
|
||||
"created_at": automation.created_at.isoformat(),
|
||||
"updated_at": automation.updated_at.isoformat()
|
||||
}, f, indent=2)
|
||||
|
||||
logger.info(f"Created automation: {automation.name} ({automation.id})")
|
||||
return automation
|
||||
|
||||
async def get_automation(self, automation_id: str) -> Optional[Automation]:
|
||||
"""Get automation by ID"""
|
||||
automation_file = self.automations_path / f"{automation_id}.json"
|
||||
|
||||
if not automation_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(automation_file, "r") as f:
|
||||
data = json.load(f)
|
||||
data["trigger_type"] = TriggerType(data["trigger_type"])
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
|
||||
return Automation(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading automation {automation_id}: {e}")
|
||||
return None
|
||||
|
||||
async def list_automations(self, owner_id: Optional[str] = None) -> List[Automation]:
|
||||
"""List all automations, optionally filtered by owner"""
|
||||
automations = []
|
||||
|
||||
if self.automations_path.exists():
|
||||
for automation_file in self.automations_path.glob("*.json"):
|
||||
try:
|
||||
with open(automation_file, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Filter by owner if specified
|
||||
if owner_id and data.get("owner_id") != owner_id:
|
||||
continue
|
||||
|
||||
data["trigger_type"] = TriggerType(data["trigger_type"])
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
|
||||
automations.append(Automation(**data))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading automation {automation_file}: {e}")
|
||||
|
||||
return automations
|
||||
|
||||
async def delete_automation(self, automation_id: str, owner_id: str) -> bool:
|
||||
"""Delete an automation"""
|
||||
automation = await self.get_automation(automation_id)
|
||||
|
||||
if not automation:
|
||||
return False
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != owner_id:
|
||||
logger.warning(f"User {owner_id} attempted to delete automation owned by {automation.owner_id}")
|
||||
return False
|
||||
|
||||
# Delete file
|
||||
automation_file = self.automations_path / f"{automation_id}.json"
|
||||
automation_file.unlink()
|
||||
|
||||
logger.info(f"Deleted automation: {automation_id}")
|
||||
return True
|
||||
|
||||
async def get_event_history(
|
||||
self,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
event_type: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""Get event history with optional filters"""
|
||||
events = []
|
||||
|
||||
# Determine date range
|
||||
if not end_date:
|
||||
end_date = datetime.utcnow()
|
||||
if not start_date:
|
||||
start_date = end_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# Iterate through daily event files
|
||||
current_date = start_date
|
||||
while current_date <= end_date:
|
||||
date_str = current_date.strftime("%Y-%m-%d")
|
||||
event_file = self.event_store_path / f"events_{date_str}.jsonl"
|
||||
|
||||
if event_file.exists():
|
||||
with open(event_file, "r") as f:
|
||||
for line in f:
|
||||
try:
|
||||
event_data = json.loads(line)
|
||||
event = Event.from_dict(event_data)
|
||||
|
||||
# Apply filters
|
||||
if event_type and event.type != event_type:
|
||||
continue
|
||||
if user_id and event.user != user_id:
|
||||
continue
|
||||
|
||||
events.append(event)
|
||||
|
||||
if len(events) >= limit:
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing event: {e}")
|
||||
|
||||
# Move to next day
|
||||
current_date = current_date + timedelta(days=1)
|
||||
|
||||
return events
|
||||
869
apps/tenant-backend/app/services/event_service.py
Normal file
@@ -0,0 +1,869 @@
"""
Event Automation Service for GT 2.0 Tenant Backend

Handles event-driven automation workflows including:
- Document processing triggers
- Conversation events
- RAG pipeline automation
- Agent lifecycle events
- User activity tracking

Perfect tenant isolation with zero downtime compliance.
"""

import logging
import asyncio
import json
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Callable
from enum import Enum
from dataclasses import dataclass, asdict
import uuid

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, desc
from sqlalchemy.orm import selectinload

from app.core.database import get_db_session
from app.core.config import get_settings
from app.models.event import Event, EventTrigger, EventAction, EventSubscription
from app.services.rag_service import RAGService
from app.services.conversation_service import ConversationService
from app.services.assistant_manager import AssistantManager

logger = logging.getLogger(__name__)


class EventType(str, Enum):
|
||||
"""Event types for automation triggers"""
|
||||
DOCUMENT_UPLOADED = "document.uploaded"
|
||||
DOCUMENT_PROCESSED = "document.processed"
|
||||
DOCUMENT_FAILED = "document.failed"
|
||||
CONVERSATION_STARTED = "conversation.started"
|
||||
MESSAGE_SENT = "message.sent"
|
||||
ASSISTANT_CREATED = "agent.created"
|
||||
RAG_SEARCH_PERFORMED = "rag.search_performed"
|
||||
USER_LOGIN = "user.login"
|
||||
USER_ACTIVITY = "user.activity"
|
||||
SYSTEM_HEALTH_CHECK = "system.health_check"
|
||||
TEAM_INVITATION_CREATED = "team.invitation.created"
|
||||
TEAM_OBSERVABLE_REQUESTED = "team.observable_requested"
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
"""Action types for event responses"""
|
||||
PROCESS_DOCUMENT = "process_document"
|
||||
SEND_NOTIFICATION = "send_notification"
|
||||
UPDATE_STATISTICS = "update_statistics"
|
||||
TRIGGER_RAG_INDEXING = "trigger_rag_indexing"
|
||||
LOG_ANALYTICS = "log_analytics"
|
||||
EXECUTE_WEBHOOK = "execute_webhook"
|
||||
CREATE_ASSISTANT = "create_assistant"
|
||||
SCHEDULE_TASK = "schedule_task"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventPayload:
|
||||
"""Event payload structure"""
|
||||
event_id: str
|
||||
event_type: EventType
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
timestamp: datetime
|
||||
data: Dict[str, Any]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventActionConfig:
|
||||
"""Configuration for event actions"""
|
||||
action_type: ActionType
|
||||
config: Dict[str, Any]
|
||||
delay_seconds: int = 0
|
||||
retry_count: int = 3
|
||||
retry_delay: int = 60
|
||||
condition: Optional[str] = None # Python expression for conditional execution
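# Illustrative sketch (hypothetical values; the exact evaluation context for `condition`
# is defined by the service, so this expression is an assumption): a conditional action
# that only fires when a processed document produced at least one chunk.
#
#     action = EventActionConfig(
#         action_type=ActionType.SEND_NOTIFICATION,
#         config={"channel": "email"},
#         condition="data.get('chunks_created', 0) > 0",
#     )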
class EventService:
    """
    Event automation service with perfect tenant isolation.

    GT 2.0 Security Principles:
    - Perfect tenant isolation (all events user-scoped)
    - Zero downtime compliance (async processing)
    - Self-contained automation (no external dependencies)
    - Stateless event processing
    """

    def __init__(self, db: AsyncSession):
        self.db = db
        self.settings = get_settings()
        self.rag_service = RAGService(db)
        self.conversation_service = ConversationService(db)
        self.assistant_manager = AssistantManager(db)

        # Event handlers registry
        self.action_handlers: Dict[ActionType, Callable] = {
            ActionType.PROCESS_DOCUMENT: self._handle_process_document,
            ActionType.SEND_NOTIFICATION: self._handle_send_notification,
            ActionType.UPDATE_STATISTICS: self._handle_update_statistics,
            ActionType.TRIGGER_RAG_INDEXING: self._handle_trigger_rag_indexing,
            ActionType.LOG_ANALYTICS: self._handle_log_analytics,
            ActionType.EXECUTE_WEBHOOK: self._handle_execute_webhook,
            ActionType.CREATE_ASSISTANT: self._handle_create_assistant,
            ActionType.SCHEDULE_TASK: self._handle_schedule_task,
        }

        # Active event subscriptions cache
        self.subscriptions_cache: Dict[str, List[EventSubscription]] = {}
        self.cache_expiry: Optional[datetime] = None

        logger.info("Event automation service initialized with tenant isolation")
    async def emit_event(
        self,
        event_type: EventType,
        user_id: str,
        tenant_id: str,
        data: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Emit an event and trigger automated actions"""
        try:
            # Create event payload
            event_payload = EventPayload(
                event_id=str(uuid.uuid4()),
                event_type=event_type,
                user_id=user_id,
                tenant_id=tenant_id,
                timestamp=datetime.utcnow(),
                data=data,
                metadata=metadata or {}
            )

            # Store event in database
            event_record = Event(
                event_id=event_payload.event_id,
                event_type=event_type.value,
                user_id=user_id,
                tenant_id=tenant_id,
                payload=event_payload.to_dict(),
                status="processing"
            )

            self.db.add(event_record)
            await self.db.commit()

            # Process event asynchronously
            asyncio.create_task(self._process_event(event_payload))

            logger.info(f"Event emitted: {event_type.value} for user {user_id}")
            return event_payload.event_id

        except Exception as e:
            logger.error(f"Failed to emit event {event_type.value}: {e}")
            raise
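    # Illustrative call site (not part of this module): an upload endpoint might emit
    # DOCUMENT_UPLOADED so the default subscription kicks off processing. The names
    # event_service, current_user_uuid, and doc below are assumptions for this sketch.
    #
    #     event_id = await event_service.emit_event(
    #         EventType.DOCUMENT_UPLOADED,
    #         user_id=current_user_uuid,
    #         tenant_id=tenant_id,
    #         data={"document_id": doc.id, "filename": doc.filename},
    #     )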
    async def _process_event(self, event_payload: EventPayload) -> None:
        """Process event and execute matching actions"""
        try:
            # Get subscriptions for this event type
            subscriptions = await self._get_event_subscriptions(
                event_payload.event_type,
                event_payload.user_id,
                event_payload.tenant_id
            )

            if not subscriptions:
                logger.debug(f"No subscriptions found for event {event_payload.event_type}")
                return

            # Execute actions for each subscription
            for subscription in subscriptions:
                try:
                    await self._execute_subscription_actions(subscription, event_payload)
                except Exception as e:
                    logger.error(f"Failed to execute subscription {subscription.id}: {e}")
                    continue

            # Update event status
            await self._update_event_status(event_payload.event_id, "completed")

        except Exception as e:
            logger.error(f"Failed to process event {event_payload.event_id}: {e}")
            await self._update_event_status(event_payload.event_id, "failed", str(e))
    async def _get_event_subscriptions(
        self,
        event_type: EventType,
        user_id: str,
        tenant_id: str
    ) -> List[EventSubscription]:
        """Get active subscriptions for event type with tenant isolation"""
        try:
            # Check cache first
            cache_key = f"{tenant_id}:{user_id}:{event_type.value}"
            if (self.cache_expiry and datetime.utcnow() < self.cache_expiry and
                    cache_key in self.subscriptions_cache):
                return self.subscriptions_cache[cache_key]

            # Query database
            query = select(EventSubscription).where(
                and_(
                    EventSubscription.event_type == event_type.value,
                    EventSubscription.user_id == user_id,
                    EventSubscription.tenant_id == tenant_id,
                    EventSubscription.is_active == True
                )
            ).options(selectinload(EventSubscription.actions))

            result = await self.db.execute(query)
            subscriptions = result.scalars().all()

            # Cache results
            self.subscriptions_cache[cache_key] = list(subscriptions)
            self.cache_expiry = datetime.utcnow() + timedelta(minutes=5)

            return list(subscriptions)

        except Exception as e:
            logger.error(f"Failed to get event subscriptions: {e}")
            return []
    async def _execute_subscription_actions(
        self,
        subscription: EventSubscription,
        event_payload: EventPayload
    ) -> None:
        """Execute all actions for a subscription"""
        try:
            for action in subscription.actions:
                # Check if action should be executed
                if not await self._should_execute_action(action, event_payload):
                    continue

                # Add delay if specified
                if action.delay_seconds > 0:
                    await asyncio.sleep(action.delay_seconds)

                # Execute action with retry logic
                await self._execute_action_with_retry(action, event_payload)

        except Exception as e:
            logger.error(f"Failed to execute subscription actions: {e}")
            raise

    async def _should_execute_action(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> bool:
        """Check if action should be executed based on conditions"""
        try:
            if not action.condition:
                return True

            # Create evaluation context
            context = {
                'event': event_payload.to_dict(),
                'data': event_payload.data,
                'user_id': event_payload.user_id,
                'tenant_id': event_payload.tenant_id,
                'event_type': event_payload.event_type.value
            }

            # Evaluate the condition expression with builtins disabled
            try:
                result = eval(action.condition, {"__builtins__": {}}, context)
                return bool(result)
            except Exception as e:
                logger.warning(f"Failed to evaluate action condition: {e}")
                return False

        except Exception as e:
            logger.error(f"Error checking action condition: {e}")
            return False
    async def _execute_action_with_retry(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Execute action with retry logic"""
        last_error = None

        for attempt in range(action.retry_count + 1):
            try:
                await self._execute_action(action, event_payload)
                return  # Success

            except Exception as e:
                last_error = e
                logger.warning(f"Action execution attempt {attempt + 1} failed: {e}")

                if attempt < action.retry_count:
                    await asyncio.sleep(action.retry_delay)
                else:
                    logger.error(f"Action execution failed after {action.retry_count + 1} attempts")
                    raise last_error

    async def _execute_action(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Execute a specific action"""
        try:
            action_type = ActionType(action.action_type)
            handler = self.action_handlers.get(action_type)

            if not handler:
                raise ValueError(f"No handler for action type: {action_type}")

            await handler(action, event_payload)

            logger.debug(f"Action executed: {action_type.value} for event {event_payload.event_id}")

        except Exception as e:
            logger.error(f"Failed to execute action {action.action_type}: {e}")
            raise
    # Action Handlers

    async def _handle_process_document(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle document processing automation"""
        try:
            document_id = event_payload.data.get("document_id")
            if not document_id:
                raise ValueError("document_id required for process_document action")

            chunking_strategy = action.config.get("chunking_strategy", "hybrid")

            result = await self.rag_service.process_document(
                user_id=event_payload.user_id,
                document_id=document_id,
                tenant_id=event_payload.tenant_id,
                chunking_strategy=chunking_strategy
            )

            # Emit processing completed event
            await self.emit_event(
                EventType.DOCUMENT_PROCESSED,
                event_payload.user_id,
                event_payload.tenant_id,
                {
                    "document_id": document_id,
                    "chunk_count": result.get("chunk_count", 0),
                    "processing_result": result
                }
            )

        except Exception as e:
            # Emit processing failed event
            await self.emit_event(
                EventType.DOCUMENT_FAILED,
                event_payload.user_id,
                event_payload.tenant_id,
                {
                    "document_id": event_payload.data.get("document_id"),
                    "error": str(e)
                }
            )
            raise
    async def _handle_send_notification(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle notification sending"""
        try:
            notification_type = action.config.get("type", "system")
            message = action.config.get("message", "Event notification")

            # Format message with event data
            formatted_message = message.format(**event_payload.data)

            # Store notification (implement notification system later)
            notification_data = {
                "type": notification_type,
                "message": formatted_message,
                "user_id": event_payload.user_id,
                "event_id": event_payload.event_id,
                "created_at": datetime.utcnow().isoformat()
            }

            logger.info(f"Notification sent: {formatted_message} to user {event_payload.user_id}")

        except Exception as e:
            logger.error(f"Failed to send notification: {e}")
            raise

    async def _handle_update_statistics(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle statistics updates"""
        try:
            stat_type = action.config.get("type")
            increment = action.config.get("increment", 1)

            # Update user statistics (implement statistics system later)
            logger.info(f"Statistics updated: {stat_type} += {increment} for user {event_payload.user_id}")

        except Exception as e:
            logger.error(f"Failed to update statistics: {e}")
            raise
    async def _handle_trigger_rag_indexing(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle RAG reindexing automation"""
        try:
            dataset_ids = action.config.get("dataset_ids", [])

            if not dataset_ids:
                # Get all user datasets
                datasets = await self.rag_service.list_user_datasets(event_payload.user_id)
                dataset_ids = [d.id for d in datasets]

            for dataset_id in dataset_ids:
                # Trigger reindexing for dataset
                logger.info(f"RAG reindexing triggered for dataset {dataset_id}")

        except Exception as e:
            logger.error(f"Failed to trigger RAG indexing: {e}")
            raise

    async def _handle_log_analytics(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle analytics logging"""
        try:
            analytics_data = {
                "event_type": event_payload.event_type.value,
                "user_id": event_payload.user_id,
                "tenant_id": event_payload.tenant_id,
                "timestamp": event_payload.timestamp.isoformat(),
                "data": event_payload.data,
                "custom_properties": action.config.get("properties", {})
            }

            # Log analytics (implement analytics system later)
            logger.info(f"Analytics logged: {analytics_data}")

        except Exception as e:
            logger.error(f"Failed to log analytics: {e}")
            raise
    async def _handle_execute_webhook(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle webhook execution"""
        try:
            webhook_url = action.config.get("url")
            method = action.config.get("method", "POST")
            headers = action.config.get("headers", {})

            if not webhook_url:
                raise ValueError("webhook url required")

            # Prepare webhook payload
            webhook_payload = {
                "event": event_payload.to_dict(),
                "timestamp": datetime.utcnow().isoformat()
            }

            # Execute webhook (implement HTTP client later)
            logger.info(f"Webhook executed: {method} {webhook_url}")

        except Exception as e:
            logger.error(f"Failed to execute webhook: {e}")
            raise

    async def _handle_create_assistant(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle automatic agent creation"""
        try:
            template_id = action.config.get("template_id", "general_assistant")
            assistant_name = action.config.get("name", "Auto-created Agent")

            # Create agent
            agent_id = await self.assistant_manager.create_from_template(
                template_id=template_id,
                config={"name": assistant_name},
                user_identifier=event_payload.user_id
            )

            # Emit agent created event
            await self.emit_event(
                EventType.ASSISTANT_CREATED,
                event_payload.user_id,
                event_payload.tenant_id,
                {
                    "agent_id": agent_id,
                    "template_id": template_id,
                    "trigger_event": event_payload.event_id
                }
            )

        except Exception as e:
            logger.error(f"Failed to create agent: {e}")
            raise
    async def _handle_schedule_task(
        self,
        action: EventAction,
        event_payload: EventPayload
    ) -> None:
        """Handle task scheduling"""
        try:
            task_type = action.config.get("task_type")
            delay_minutes = action.config.get("delay_minutes", 0)

            # Schedule task for future execution
            scheduled_time = datetime.utcnow() + timedelta(minutes=delay_minutes)

            logger.info(f"Task scheduled: {task_type} for {scheduled_time}")

        except Exception as e:
            logger.error(f"Failed to schedule task: {e}")
            raise

    # Subscription Management
    async def create_subscription(
        self,
        user_id: str,
        tenant_id: str,
        event_type: EventType,
        actions: List[EventActionConfig],
        name: Optional[str] = None,
        description: Optional[str] = None
    ) -> str:
        """Create an event subscription"""
        try:
            subscription = EventSubscription(
                user_id=user_id,
                tenant_id=tenant_id,
                event_type=event_type.value,
                name=name or f"Auto-subscription for {event_type.value}",
                description=description,
                is_active=True
            )

            self.db.add(subscription)
            await self.db.flush()

            # Create actions
            for action_config in actions:
                action = EventAction(
                    subscription_id=subscription.id,
                    action_type=action_config.action_type.value,
                    config=action_config.config,
                    delay_seconds=action_config.delay_seconds,
                    retry_count=action_config.retry_count,
                    retry_delay=action_config.retry_delay,
                    condition=action_config.condition
                )
                self.db.add(action)

            await self.db.commit()

            # Clear subscriptions cache
            self._clear_subscriptions_cache()

            logger.info(f"Event subscription created: {subscription.id} for {event_type.value}")
            return subscription.id

        except Exception as e:
            await self.db.rollback()
            logger.error(f"Failed to create subscription: {e}")
            raise
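    # Illustrative use (assumed call site, not part of this module): subscribe a user to
    # DOCUMENT_PROCESSED with a webhook action. The URL and surrounding variables are
    # placeholders, not real endpoints.
    #
    #     subscription_id = await event_service.create_subscription(
    #         user_id=user_id,
    #         tenant_id=tenant_id,
    #         event_type=EventType.DOCUMENT_PROCESSED,
    #         actions=[
    #             EventActionConfig(
    #                 action_type=ActionType.EXECUTE_WEBHOOK,
    #                 config={"url": "https://example.invalid/hooks/docs", "method": "POST"},
    #             )
    #         ],
    #         name="Notify external system on document processing",
    #     )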
    async def get_user_subscriptions(
        self,
        user_id: str,
        tenant_id: str
    ) -> List[EventSubscription]:
        """Get all subscriptions for a user"""
        try:
            query = select(EventSubscription).where(
                and_(
                    EventSubscription.user_id == user_id,
                    EventSubscription.tenant_id == tenant_id
                )
            ).options(selectinload(EventSubscription.actions))

            result = await self.db.execute(query)
            return list(result.scalars().all())

        except Exception as e:
            logger.error(f"Failed to get user subscriptions: {e}")
            raise

    async def update_subscription_status(
        self,
        subscription_id: str,
        user_id: str,
        is_active: bool
    ) -> bool:
        """Update subscription status with ownership verification"""
        try:
            query = select(EventSubscription).where(
                and_(
                    EventSubscription.id == subscription_id,
                    EventSubscription.user_id == user_id
                )
            )

            result = await self.db.execute(query)
            subscription = result.scalar_one_or_none()

            if not subscription:
                return False

            subscription.is_active = is_active
            subscription.updated_at = datetime.utcnow()

            await self.db.commit()

            # Clear subscriptions cache
            self._clear_subscriptions_cache()

            return True

        except Exception as e:
            await self.db.rollback()
            logger.error(f"Failed to update subscription status: {e}")
            raise
    async def delete_subscription(
        self,
        subscription_id: str,
        user_id: str
    ) -> bool:
        """Delete subscription with ownership verification"""
        try:
            query = select(EventSubscription).where(
                and_(
                    EventSubscription.id == subscription_id,
                    EventSubscription.user_id == user_id
                )
            )

            result = await self.db.execute(query)
            subscription = result.scalar_one_or_none()

            if not subscription:
                return False

            await self.db.delete(subscription)
            await self.db.commit()

            # Clear subscriptions cache
            self._clear_subscriptions_cache()

            return True

        except Exception as e:
            await self.db.rollback()
            logger.error(f"Failed to delete subscription: {e}")
            raise

    # Utility Methods
    async def _update_event_status(
        self,
        event_id: str,
        status: str,
        error_message: Optional[str] = None
    ) -> None:
        """Update event processing status"""
        try:
            query = select(Event).where(Event.event_id == event_id)
            result = await self.db.execute(query)
            event = result.scalar_one_or_none()

            if event:
                event.status = status
                event.error_message = error_message
                event.completed_at = datetime.utcnow()
                await self.db.commit()

        except Exception as e:
            logger.error(f"Failed to update event status: {e}")

    def _clear_subscriptions_cache(self) -> None:
        """Clear subscriptions cache"""
        self.subscriptions_cache.clear()
        self.cache_expiry = None
    async def get_event_history(
        self,
        user_id: str,
        tenant_id: str,
        event_types: Optional[List[EventType]] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[Event]:
        """Get event history for user with filtering"""
        try:
            query = select(Event).where(
                and_(
                    Event.user_id == user_id,
                    Event.tenant_id == tenant_id
                )
            )

            if event_types:
                event_type_values = [et.value for et in event_types]
                query = query.where(Event.event_type.in_(event_type_values))

            query = query.order_by(desc(Event.created_at)).offset(offset).limit(limit)

            result = await self.db.execute(query)
            return list(result.scalars().all())

        except Exception as e:
            logger.error(f"Failed to get event history: {e}")
            raise
    async def get_event_statistics(
        self,
        user_id: str,
        tenant_id: str,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get event statistics for user"""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days)

            query = select(Event).where(
                and_(
                    Event.user_id == user_id,
                    Event.tenant_id == tenant_id,
                    Event.created_at >= cutoff_date
                )
            )

            result = await self.db.execute(query)
            events = result.scalars().all()

            # Calculate statistics
            stats = {
                "total_events": len(events),
                "events_by_type": {},
                "events_by_status": {},
                "average_events_per_day": 0
            }

            for event in events:
                # Count by type
                event_type = event.event_type
                stats["events_by_type"][event_type] = stats["events_by_type"].get(event_type, 0) + 1

                # Count by status
                status = event.status
                stats["events_by_status"][status] = stats["events_by_status"].get(status, 0) + 1

            # Calculate average per day
            if days > 0:
                stats["average_events_per_day"] = round(len(events) / days, 2)

            return stats

        except Exception as e:
            logger.error(f"Failed to get event statistics: {e}")
            raise
# Factory function for dependency injection
async def get_event_service(db: Optional[AsyncSession] = None) -> EventService:
    """Get event service instance"""
    if db is None:
        async with get_db_session() as session:
            return EventService(session)
    return EventService(db)


# Default event subscriptions for new users
DEFAULT_SUBSCRIPTIONS = [
    {
        "event_type": EventType.DOCUMENT_UPLOADED,
        "actions": [
            EventActionConfig(
                action_type=ActionType.PROCESS_DOCUMENT,
                config={"chunking_strategy": "hybrid"},
                delay_seconds=5  # Small delay to ensure file is fully uploaded
            )
        ]
    },
    {
        "event_type": EventType.DOCUMENT_PROCESSED,
        "actions": [
            EventActionConfig(
                action_type=ActionType.SEND_NOTIFICATION,
                config={
                    "type": "success",
                    "message": "Document '{filename}' has been processed successfully with {chunk_count} chunks."
                }
            ),
            EventActionConfig(
                action_type=ActionType.UPDATE_STATISTICS,
                config={"type": "documents_processed", "increment": 1}
            )
        ]
    },
    {
        "event_type": EventType.CONVERSATION_STARTED,
        "actions": [
            EventActionConfig(
                action_type=ActionType.LOG_ANALYTICS,
                config={"properties": {"conversation_start": True}}
            )
        ]
    }
]


async def setup_default_subscriptions(
    user_id: str,
    tenant_id: str,
    event_service: EventService
) -> None:
    """Setup default event subscriptions for new user"""
    try:
        for subscription_config in DEFAULT_SUBSCRIPTIONS:
            await event_service.create_subscription(
                user_id=user_id,
                tenant_id=tenant_id,
                event_type=subscription_config["event_type"],
                actions=subscription_config["actions"],
                name=f"Default: {subscription_config['event_type'].value}",
                description="Automatically created default subscription"
            )

        logger.info(f"Default event subscriptions created for user {user_id}")

    except Exception as e:
        logger.error(f"Failed to setup default subscriptions: {e}")
        raise