GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
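Note (illustrative only, not code from this commit): the hostname-validation and SSRF bullets above describe a common pattern, namely parsing the URL and comparing the hostname exactly instead of substring matching, then resolving it and rejecting private or loopback targets. A minimal sketch of that idea, with a hypothetical allow-list and function name:

import ipaddress
import socket
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.example.com"}  # hypothetical allow-list, not from this release

def is_url_allowed(url: str) -> bool:
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return False
    hostname = parsed.hostname
    # Exact hostname comparison, not `"example.com" in url`
    if hostname is None or hostname not in ALLOWED_HOSTS:
        return False
    # Resolve and reject private, loopback, or link-local addresses (SSRF check)
    try:
        infos = socket.getaddrinfo(hostname, parsed.port or 443)
    except socket.gaierror:
        return False
    for info in infos:
        ip = ipaddress.ip_address(info[4][0].split("%")[0])
        if ip.is_private or ip.is_loopback or ip.is_link_local:
            return False
    return True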
apps/tenant-backend/app/services/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""
GT 2.0 Tenant Backend Services

Business logic and orchestration services for tenant applications.
"""
apps/tenant-backend/app/services/access_controller.py (new file, 451 lines)
@@ -0,0 +1,451 @@
"""
Access Controller Service for GT 2.0

Manages resource access control with capability-based security.
Ensures perfect tenant isolation and proper permission cascading.
"""

import os
import stat
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
import logging
from pathlib import Path

from app.models.access_group import (
    AccessGroup, TenantStructure, User, Resource,
    ResourceCreate, ResourceUpdate, ResourceResponse
)
from app.core.security import verify_capability_token
from app.core.database import get_db_session


logger = logging.getLogger(__name__)


class AccessController:
    """
    Centralized access control service
    Manages permissions for all resources with tenant isolation
    """

    def __init__(self, tenant_domain: str):
        self.tenant_domain = tenant_domain
        self.base_path = Path(f"/data/{tenant_domain}")
        self._ensure_tenant_directory()

    def _ensure_tenant_directory(self):
        """
        Ensure tenant directory exists with proper permissions
        OS User: gt-{tenant_domain}-{pod_id}
        Permissions: 700 (owner only)
        """
        if not self.base_path.exists():
            self.base_path.mkdir(parents=True, exist_ok=True)
            # Set strict permissions - owner only
            os.chmod(self.base_path, stat.S_IRWXU)  # 700
            logger.info(f"Created tenant directory: {self.base_path} with 700 permissions")

    async def check_permission(
        self,
        user_id: str,
        resource: Resource,
        action: str = "read"
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if user has permission for action on resource

        Args:
            user_id: User requesting access
            resource: Resource being accessed
            action: read, write, delete, share

        Returns:
            Tuple of (allowed, reason)
        """
        # Verify tenant isolation
        if resource.tenant_domain != self.tenant_domain:
            logger.warning(f"Cross-tenant access attempt: {user_id} -> {resource.id}")
            return False, "Cross-tenant access denied"

        # Owner has all permissions
        if resource.owner_id == user_id:
            return True, "Owner access granted"

        # Check action-specific permissions
        if action == "read":
            return self._check_read_permission(user_id, resource)
        elif action == "write":
            return self._check_write_permission(user_id, resource)
        elif action == "delete":
            return False, "Only owner can delete"
        elif action == "share":
            return False, "Only owner can share"
        else:
            return False, f"Unknown action: {action}"

    def _check_read_permission(self, user_id: str, resource: Resource) -> Tuple[bool, str]:
        """Check read permission based on access group"""
        if resource.access_group == AccessGroup.ORGANIZATION:
            return True, "Organization-wide read access"
        elif resource.access_group == AccessGroup.TEAM:
            if user_id in resource.team_members:
                return True, "Team member read access"
            return False, "Not a team member"
        else:  # INDIVIDUAL
            return False, "Private resource"

    def _check_write_permission(self, user_id: str, resource: Resource) -> Tuple[bool, str]:
        """Check write permission - only owner can write"""
        return False, "Only owner can modify"

    async def create_resource(
        self,
        user_id: str,
        resource_data: ResourceCreate,
        capability_token: str
    ) -> Resource:
        """
        Create a new resource with proper access control

        Args:
            user_id: User creating the resource
            resource_data: Resource creation data
            capability_token: JWT capability token

        Returns:
            Created resource
        """
        # Verify capability token
        token_data = verify_capability_token(capability_token)
        if not token_data or token_data.get("tenant_id") != self.tenant_domain:
            raise PermissionError("Invalid capability token")

        # Create resource
        resource = Resource(
            id=self._generate_resource_id(),
            name=resource_data.name,
            resource_type=resource_data.resource_type,
            owner_id=user_id,
            tenant_domain=self.tenant_domain,
            access_group=resource_data.access_group,
            team_members=resource_data.team_members or [],
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
            metadata=resource_data.metadata or {},
            file_path=None
        )

        # Create file-based storage if needed
        if self._requires_file_storage(resource.resource_type):
            resource.file_path = await self._create_resource_file(resource)

        # Audit log
        logger.info(f"Resource created: {resource.id} by {user_id} in {self.tenant_domain}")

        return resource

    async def update_resource_access(
        self,
        user_id: str,
        resource_id: str,
        new_access_group: AccessGroup,
        team_members: Optional[List[str]] = None
    ) -> Resource:
        """
        Update resource access group

        Args:
            user_id: User requesting update
            resource_id: Resource to update
            new_access_group: New access level
            team_members: Team members if team access

        Returns:
            Updated resource
        """
        # Load resource
        resource = await self._load_resource(resource_id)

        # Check permission
        allowed, reason = await self.check_permission(user_id, resource, "share")
        if not allowed:
            raise PermissionError(f"Access denied: {reason}")

        # Update access
        old_group = resource.access_group
        resource.update_access_group(new_access_group, team_members)

        # Update file permissions if needed
        if resource.file_path:
            await self._update_file_permissions(resource)

        # Audit log
        logger.info(
            f"Access updated: {resource_id} from {old_group} to {new_access_group} "
            f"by {user_id}"
        )

        return resource

    async def list_accessible_resources(
        self,
        user_id: str,
        resource_type: Optional[str] = None
    ) -> List[Resource]:
        """
        List all resources accessible to user

        Args:
            user_id: User requesting list
            resource_type: Filter by type

        Returns:
            List of accessible resources
        """
        accessible = []

        # Get all resources in tenant
        all_resources = await self._list_tenant_resources(resource_type)

        for resource in all_resources:
            allowed, _ = await self.check_permission(user_id, resource, "read")
            if allowed:
                accessible.append(resource)

        return accessible

    async def get_resource_stats(self, user_id: str) -> Dict[str, Any]:
        """
        Get resource statistics for user

        Args:
            user_id: User to get stats for

        Returns:
            Statistics dictionary
        """
        all_resources = await self._list_tenant_resources()

        owned = [r for r in all_resources if r.owner_id == user_id]
        accessible = await self.list_accessible_resources(user_id)

        stats = {
            "owned_count": len(owned),
            "accessible_count": len(accessible),
            "by_type": {},
            "by_access_group": {
                AccessGroup.INDIVIDUAL: 0,
                AccessGroup.TEAM: 0,
                AccessGroup.ORGANIZATION: 0
            }
        }

        for resource in owned:
            # Count by type
            if resource.resource_type not in stats["by_type"]:
                stats["by_type"][resource.resource_type] = 0
            stats["by_type"][resource.resource_type] += 1

            # Count by access group
            stats["by_access_group"][resource.access_group] += 1

        return stats

    def _generate_resource_id(self) -> str:
        """Generate unique resource ID"""
        import uuid
        return str(uuid.uuid4())

    def _requires_file_storage(self, resource_type: str) -> bool:
        """Check if resource type requires file storage"""
        file_based_types = [
            "agent", "dataset", "document", "workflow",
            "notebook", "model", "configuration"
        ]
        return resource_type in file_based_types

    async def _create_resource_file(self, resource: Resource) -> str:
        """
        Create file for resource with proper permissions

        Args:
            resource: Resource to create file for

        Returns:
            File path
        """
        # Determine path based on resource type
        type_dir = self.base_path / resource.resource_type / resource.id
        type_dir.mkdir(parents=True, exist_ok=True)

        # Create main file
        file_path = type_dir / "data.json"
        file_path.touch()

        # Set strict permissions - 700 for directory, 600 for file
        os.chmod(type_dir, stat.S_IRWXU)  # 700
        os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR)  # 600

        logger.info(f"Created resource file: {file_path} with secure permissions")

        return str(file_path)

    async def _update_file_permissions(self, resource: Resource):
        """Update file permissions (always 700/600 for security)"""
        if not resource.file_path or not Path(resource.file_path).exists():
            return

        # Permissions don't change based on access group
        # All files remain 700/600 for OS-level security
        # Access control is handled at application level
        pass

    async def _load_resource(self, resource_id: str) -> Resource:
        """Load resource from storage"""
        try:
            # Search for resource in all resource type directories
            for resource_type_dir in self.base_path.iterdir():
                if not resource_type_dir.is_dir():
                    continue

                resource_file = resource_type_dir / "data.json"
                if resource_file.exists():
                    try:
                        import json
                        with open(resource_file, 'r') as f:
                            resources_data = json.load(f)

                        if not isinstance(resources_data, list):
                            resources_data = [resources_data]

                        for resource_data in resources_data:
                            if resource_data.get('id') == resource_id:
                                return Resource(
                                    id=resource_data['id'],
                                    name=resource_data['name'],
                                    resource_type=resource_data['resource_type'],
                                    owner_id=resource_data['owner_id'],
                                    tenant_domain=resource_data['tenant_domain'],
                                    access_group=AccessGroup(resource_data['access_group']),
                                    team_members=resource_data.get('team_members', []),
                                    created_at=datetime.fromisoformat(resource_data['created_at']),
                                    updated_at=datetime.fromisoformat(resource_data['updated_at']),
                                    metadata=resource_data.get('metadata', {}),
                                    file_path=resource_data.get('file_path')
                                )
                    except (json.JSONDecodeError, KeyError, ValueError) as e:
                        logger.warning(f"Failed to parse resource file {resource_file}: {e}")
                        continue

            raise ValueError(f"Resource {resource_id} not found")

        except Exception as e:
            logger.error(f"Failed to load resource {resource_id}: {e}")
            raise

    async def _list_tenant_resources(
        self,
        resource_type: Optional[str] = None
    ) -> List[Resource]:
        """List all resources in tenant"""
        try:
            import json
            resources = []

            # If specific resource type requested, search only that directory
            search_dirs = [self.base_path / resource_type] if resource_type else list(self.base_path.iterdir())

            for resource_type_dir in search_dirs:
                if not resource_type_dir.exists() or not resource_type_dir.is_dir():
                    continue

                resource_file = resource_type_dir / "data.json"
                if resource_file.exists():
                    try:
                        with open(resource_file, 'r') as f:
                            resources_data = json.load(f)

                        if not isinstance(resources_data, list):
                            resources_data = [resources_data]

                        for resource_data in resources_data:
                            try:
                                resource = Resource(
                                    id=resource_data['id'],
                                    name=resource_data['name'],
                                    resource_type=resource_data['resource_type'],
                                    owner_id=resource_data['owner_id'],
                                    tenant_domain=resource_data['tenant_domain'],
                                    access_group=AccessGroup(resource_data['access_group']),
                                    team_members=resource_data.get('team_members', []),
                                    created_at=datetime.fromisoformat(resource_data['created_at']),
                                    updated_at=datetime.fromisoformat(resource_data['updated_at']),
                                    metadata=resource_data.get('metadata', {}),
                                    file_path=resource_data.get('file_path')
                                )
                                resources.append(resource)
                            except (KeyError, ValueError) as e:
                                logger.warning(f"Failed to parse resource data: {e}")
                                continue

                    except (json.JSONDecodeError, IOError) as e:
                        logger.warning(f"Failed to read resource file {resource_file}: {e}")
                        continue

            return resources

        except Exception as e:
            logger.error(f"Failed to list tenant resources: {e}")
            raise


class AccessControlMiddleware:
    """
    Middleware for enforcing access control on API requests
    """

    def __init__(self, tenant_domain: str):
        self.controller = AccessController(tenant_domain)

    async def verify_request(
        self,
        user_id: str,
        resource_id: str,
        action: str,
        capability_token: str
    ) -> bool:
        """
        Verify request has proper permissions

        Args:
            user_id: User making request
            resource_id: Resource being accessed
            action: Action being performed
            capability_token: JWT capability token

        Returns:
            True if allowed, raises PermissionError if not
        """
        # Verify capability token
        token_data = verify_capability_token(capability_token)
        if not token_data:
            raise PermissionError("Invalid capability token")

        # Verify tenant match
        if token_data.get("tenant_id") != self.controller.tenant_domain:
            raise PermissionError("Tenant mismatch in capability token")

        # Load resource and check permission
        resource = await self.controller._load_resource(resource_id)
        allowed, reason = await self.controller.check_permission(
            user_id, resource, action
        )

        if not allowed:
            logger.warning(
                f"Access denied: {user_id} -> {resource_id} ({action}): {reason}"
            )
            raise PermissionError(f"Access denied: {reason}")

        return True
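A minimal usage sketch of the middleware above (illustrative only, not part of the diff; the tenant domain, IDs, and token are placeholders):

import asyncio
from app.services.access_controller import AccessControlMiddleware

async def main():
    # Placeholder values; a real caller gets these from the request context.
    middleware = AccessControlMiddleware(tenant_domain="acme.example")
    try:
        await middleware.verify_request(
            user_id="user-123",
            resource_id="resource-456",
            action="read",
            capability_token="<jwt-capability-token>",
        )
        print("access granted")
    except PermissionError as exc:
        print(f"access denied: {exc}")

asyncio.run(main())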
apps/tenant-backend/app/services/agent_orchestrator_client.py (new file, 920 lines)
@@ -0,0 +1,920 @@
"""
GT 2.0 Agent Orchestrator Client

Client for interacting with the Resource Cluster's Agent Orchestration system.
Enables spawning and managing subagents for complex task execution.
"""

import logging
import asyncio
import httpx
import uuid
from typing import Dict, Any, List, Optional
from datetime import datetime
from enum import Enum

from app.services.task_classifier import SubagentType, TaskClassification
from app.models.agent import Agent

logger = logging.getLogger(__name__)


class ExecutionStrategy(str, Enum):
    """Execution strategies for subagents"""
    SEQUENTIAL = "sequential"
    PARALLEL = "parallel"
    CONDITIONAL = "conditional"
    PIPELINE = "pipeline"
    MAP_REDUCE = "map_reduce"


class SubagentOrchestrator:
    """
    Orchestrates subagent execution for complex tasks.

    Manages lifecycle of subagents spawned from main agent templates,
    coordinates their execution, and aggregates results.
    """

    def __init__(self, tenant_domain: str, user_id: str):
        self.tenant_domain = tenant_domain
        self.user_id = user_id
        self.resource_cluster_url = "http://resource-cluster:8000"
        self.active_subagents: Dict[str, Dict[str, Any]] = {}
        self.execution_history: List[Dict[str, Any]] = []

    async def execute_task_plan(
        self,
        task_classification: TaskClassification,
        parent_agent: Agent,
        conversation_id: str,
        user_message: str,
        available_tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Execute a task plan using subagents.

        Args:
            task_classification: Task classification with execution plan
            parent_agent: Parent agent spawning subagents
            conversation_id: Current conversation ID
            user_message: Original user message
            available_tools: Available MCP tools

        Returns:
            Aggregated results from subagent execution
        """
        try:
            execution_id = str(uuid.uuid4())
            logger.info(f"Starting subagent execution {execution_id} for {task_classification.complexity} task")

            # Track execution
            execution_record = {
                "execution_id": execution_id,
                "conversation_id": conversation_id,
                "parent_agent_id": parent_agent.id,
                "task_complexity": task_classification.complexity,
                "started_at": datetime.now().isoformat(),
                "subagent_plan": task_classification.subagent_plan
            }
            self.execution_history.append(execution_record)

            # Determine execution strategy
            strategy = self._determine_strategy(task_classification)

            # Execute based on strategy
            if strategy == ExecutionStrategy.PARALLEL:
                results = await self._execute_parallel(
                    task_classification.subagent_plan,
                    parent_agent,
                    conversation_id,
                    user_message,
                    available_tools
                )
            elif strategy == ExecutionStrategy.SEQUENTIAL:
                results = await self._execute_sequential(
                    task_classification.subagent_plan,
                    parent_agent,
                    conversation_id,
                    user_message,
                    available_tools
                )
            elif strategy == ExecutionStrategy.PIPELINE:
                results = await self._execute_pipeline(
                    task_classification.subagent_plan,
                    parent_agent,
                    conversation_id,
                    user_message,
                    available_tools
                )
            else:
                # Default to sequential
                results = await self._execute_sequential(
                    task_classification.subagent_plan,
                    parent_agent,
                    conversation_id,
                    user_message,
                    available_tools
                )

            # Update execution record
            execution_record["completed_at"] = datetime.now().isoformat()
            execution_record["results"] = results

            # Synthesize final response
            final_response = await self._synthesize_results(
                results,
                task_classification,
                user_message
            )

            logger.info(f"Completed subagent execution {execution_id}")

            return {
                "execution_id": execution_id,
                "strategy": strategy,
                "subagent_results": results,
                "final_response": final_response,
                "execution_time_ms": self._calculate_execution_time(execution_record)
            }

        except Exception as e:
            logger.error(f"Subagent execution failed: {e}")
            return {
                "error": str(e),
                "partial_results": self.active_subagents
            }

    async def _execute_parallel(
        self,
        subagent_plan: List[Dict[str, Any]],
        parent_agent: Agent,
        conversation_id: str,
        user_message: str,
        available_tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Execute subagents in parallel"""
        # Group subagents by priority
        priority_groups = {}
        for plan_item in subagent_plan:
            priority = plan_item.get("priority", 1)
            if priority not in priority_groups:
                priority_groups[priority] = []
            priority_groups[priority].append(plan_item)

        results = {}

        # Execute each priority group
        for priority in sorted(priority_groups.keys()):
            group_tasks = []

            for plan_item in priority_groups[priority]:
                # Check dependencies
                if self._dependencies_met(plan_item, results):
                    task = asyncio.create_task(
                        self._execute_subagent(
                            plan_item,
                            parent_agent,
                            conversation_id,
                            user_message,
                            available_tools,
                            results
                        )
                    )
                    group_tasks.append((plan_item["id"], task))

            # Wait for group to complete
            for agent_id, task in group_tasks:
                try:
                    results[agent_id] = await task
                except Exception as e:
                    logger.error(f"Subagent {agent_id} failed: {e}")
                    results[agent_id] = {"error": str(e)}

        return results

    async def _execute_sequential(
        self,
        subagent_plan: List[Dict[str, Any]],
        parent_agent: Agent,
        conversation_id: str,
        user_message: str,
        available_tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Execute subagents sequentially"""
        results = {}

        for plan_item in subagent_plan:
            if self._dependencies_met(plan_item, results):
                try:
                    results[plan_item["id"]] = await self._execute_subagent(
                        plan_item,
                        parent_agent,
                        conversation_id,
                        user_message,
                        available_tools,
                        results
                    )
                except Exception as e:
                    logger.error(f"Subagent {plan_item['id']} failed: {e}")
                    results[plan_item["id"]] = {"error": str(e)}

        return results

    async def _execute_pipeline(
        self,
        subagent_plan: List[Dict[str, Any]],
        parent_agent: Agent,
        conversation_id: str,
        user_message: str,
        available_tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Execute subagents in pipeline mode"""
        results = {}
        pipeline_data = {"original_message": user_message}

        for plan_item in subagent_plan:
            try:
                # Pass output from previous stage as input
                result = await self._execute_subagent(
                    plan_item,
                    parent_agent,
                    conversation_id,
                    user_message,
                    available_tools,
                    results,
                    pipeline_data
                )

                results[plan_item["id"]] = result

                # Update pipeline data with output
                if "output" in result:
                    pipeline_data = result["output"]

            except Exception as e:
                logger.error(f"Pipeline stage {plan_item['id']} failed: {e}")
                results[plan_item["id"]] = {"error": str(e)}
                break  # Pipeline broken

        return results

    async def _execute_subagent(
        self,
        plan_item: Dict[str, Any],
        parent_agent: Agent,
        conversation_id: str,
        user_message: str,
        available_tools: List[Dict[str, Any]],
        previous_results: Dict[str, Any],
        pipeline_data: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Execute a single subagent"""
        subagent_id = plan_item["id"]
        subagent_type = plan_item["type"]
        task_description = plan_item["task"]

        logger.info(f"Executing subagent {subagent_id} ({subagent_type}): {task_description[:50]}...")

        # Track subagent
        self.active_subagents[subagent_id] = {
            "type": subagent_type,
            "task": task_description,
            "started_at": datetime.now().isoformat(),
            "status": "running"
        }

        try:
            # Create subagent configuration based on type
            subagent_config = self._create_subagent_config(
                subagent_type,
                parent_agent,
                task_description,
                pipeline_data
            )

            # Select tools for this subagent
            subagent_tools = self._select_tools_for_subagent(
                subagent_type,
                available_tools
            )

            # Execute subagent based on type
            if subagent_type == SubagentType.RESEARCH:
                result = await self._execute_research_agent(
                    subagent_config,
                    task_description,
                    subagent_tools,
                    conversation_id
                )
            elif subagent_type == SubagentType.PLANNING:
                result = await self._execute_planning_agent(
                    subagent_config,
                    task_description,
                    user_message,
                    previous_results
                )
            elif subagent_type == SubagentType.IMPLEMENTATION:
                result = await self._execute_implementation_agent(
                    subagent_config,
                    task_description,
                    subagent_tools,
                    previous_results
                )
            elif subagent_type == SubagentType.VALIDATION:
                result = await self._execute_validation_agent(
                    subagent_config,
                    task_description,
                    previous_results
                )
            elif subagent_type == SubagentType.SYNTHESIS:
                result = await self._execute_synthesis_agent(
                    subagent_config,
                    task_description,
                    previous_results
                )
            elif subagent_type == SubagentType.ANALYST:
                result = await self._execute_analyst_agent(
                    subagent_config,
                    task_description,
                    previous_results
                )
            else:
                # Default execution
                result = await self._execute_generic_agent(
                    subagent_config,
                    task_description,
                    subagent_tools
                )

            # Update tracking
            self.active_subagents[subagent_id]["status"] = "completed"
            self.active_subagents[subagent_id]["completed_at"] = datetime.now().isoformat()
            self.active_subagents[subagent_id]["result"] = result

            return result

        except Exception as e:
            logger.error(f"Subagent {subagent_id} execution failed: {e}")
            self.active_subagents[subagent_id]["status"] = "failed"
            self.active_subagents[subagent_id]["error"] = str(e)
            raise

    async def _execute_research_agent(
        self,
        config: Dict[str, Any],
        task: str,
        tools: List[Dict[str, Any]],
        conversation_id: str
    ) -> Dict[str, Any]:
        """Execute research subagent"""
        # Research agents focus on information gathering
        prompt = f"""You are a research specialist. Your task is to:
{task}

Available tools: {[t['name'] for t in tools]}

Gather comprehensive information and return structured findings."""

        result = await self._call_llm_with_tools(
            prompt,
            config,
            tools,
            max_iterations=3
        )

        return {
            "type": "research",
            "findings": result.get("content", ""),
            "sources": result.get("tool_results", []),
            "output": result
        }

    async def _execute_planning_agent(
        self,
        config: Dict[str, Any],
        task: str,
        original_query: str,
        previous_results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute planning subagent"""
        context = self._format_previous_results(previous_results)

        prompt = f"""You are a planning specialist. Break down this task into actionable steps:

Original request: {original_query}
Specific task: {task}

Context from previous agents:
{context}

Create a detailed execution plan with clear steps."""

        result = await self._call_llm(prompt, config)

        return {
            "type": "planning",
            "plan": result.get("content", ""),
            "steps": self._extract_steps(result.get("content", "")),
            "output": result
        }

    async def _execute_implementation_agent(
        self,
        config: Dict[str, Any],
        task: str,
        tools: List[Dict[str, Any]],
        previous_results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute implementation subagent"""
        context = self._format_previous_results(previous_results)

        prompt = f"""You are an implementation specialist. Execute this task:
{task}

Context:
{context}

Available tools: {[t['name'] for t in tools]}

Complete the implementation and return results."""

        result = await self._call_llm_with_tools(
            prompt,
            config,
            tools,
            max_iterations=5
        )

        return {
            "type": "implementation",
            "implementation": result.get("content", ""),
            "tool_calls": result.get("tool_calls", []),
            "output": result
        }

    async def _execute_validation_agent(
        self,
        config: Dict[str, Any],
        task: str,
        previous_results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute validation subagent"""
        context = self._format_previous_results(previous_results)

        prompt = f"""You are a validation specialist. Verify the following:
{task}

Results to validate:
{context}

Check for correctness, completeness, and quality."""

        result = await self._call_llm(prompt, config)

        return {
            "type": "validation",
            "validation_result": result.get("content", ""),
            "issues_found": self._extract_issues(result.get("content", "")),
            "output": result
        }

    async def _execute_synthesis_agent(
        self,
        config: Dict[str, Any],
        task: str,
        previous_results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute synthesis subagent"""
        all_results = self._format_all_results(previous_results)

        prompt = f"""You are a synthesis specialist. Combine and summarize these results:

Task: {task}

Results from all agents:
{all_results}

Create a comprehensive, coherent response that addresses the original request."""

        result = await self._call_llm(prompt, config)

        return {
            "type": "synthesis",
            "final_response": result.get("content", ""),
            "output": result
        }

    async def _execute_analyst_agent(
        self,
        config: Dict[str, Any],
        task: str,
        previous_results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute analyst subagent"""
        data = self._format_previous_results(previous_results)

        prompt = f"""You are an analysis specialist. Analyze the following:
{task}

Data to analyze:
{data}

Identify patterns, insights, and recommendations."""

        result = await self._call_llm(prompt, config)

        return {
            "type": "analysis",
            "analysis": result.get("content", ""),
            "insights": self._extract_insights(result.get("content", "")),
            "output": result
        }

    async def _execute_generic_agent(
        self,
        config: Dict[str, Any],
        task: str,
        tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Execute generic subagent"""
        prompt = f"""Complete the following task:
{task}

Available tools: {[t['name'] for t in tools] if tools else 'None'}"""

        if tools:
            result = await self._call_llm_with_tools(prompt, config, tools)
        else:
            result = await self._call_llm(prompt, config)

        return {
            "type": "generic",
            "result": result.get("content", ""),
            "output": result
        }

    async def _call_llm(
        self,
        prompt: str,
        config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Call LLM without tools"""
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                # Require model to be specified in config - no hardcoded fallbacks
                model = config.get("model")
                if not model:
                    raise ValueError(f"No model specified in subagent config: {config}")

                response = await client.post(
                    f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
                    json={
                        "model": model,
                        "messages": [
                            {"role": "system", "content": config.get("instructions", "")},
                            {"role": "user", "content": prompt}
                        ],
                        "temperature": config.get("temperature", 0.7),
                        "max_tokens": config.get("max_tokens", 2000)
                    },
                    headers={
                        "X-Tenant-ID": self.tenant_domain,
                        "X-User-ID": self.user_id
                    }
                )

                if response.status_code == 200:
                    result = response.json()
                    return {
                        "content": result["choices"][0]["message"]["content"],
                        "model": result["model"]
                    }
                else:
                    raise Exception(f"LLM call failed: {response.status_code}")

        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            return {"content": f"Error: {str(e)}"}

    async def _call_llm_with_tools(
        self,
        prompt: str,
        config: Dict[str, Any],
        tools: List[Dict[str, Any]],
        max_iterations: int = 3
    ) -> Dict[str, Any]:
        """Call LLM with tool execution capability"""
        messages = [
            {"role": "system", "content": config.get("instructions", "")},
            {"role": "user", "content": prompt}
        ]

        tool_results = []
        iterations = 0

        while iterations < max_iterations:
            try:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    # Require model to be specified in config - no hardcoded fallbacks
                    model = config.get("model")
                    if not model:
                        raise ValueError(f"No model specified in subagent config: {config}")

                    response = await client.post(
                        f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
                        json={
                            "model": model,
                            "messages": messages,
                            "temperature": config.get("temperature", 0.7),
                            "max_tokens": config.get("max_tokens", 2000),
                            "tools": tools,
                            "tool_choice": "auto"
                        },
                        headers={
                            "X-Tenant-ID": self.tenant_domain,
                            "X-User-ID": self.user_id
                        }
                    )

                    if response.status_code != 200:
                        raise Exception(f"LLM call failed: {response.status_code}")

                    result = response.json()
                    choice = result["choices"][0]
                    message = choice["message"]

                    # Add agent's response to messages
                    messages.append(message)

                    # Check for tool calls
                    if message.get("tool_calls"):
                        # Execute tools
                        for tool_call in message["tool_calls"]:
                            tool_result = await self._execute_tool(
                                tool_call["function"]["name"],
                                tool_call["function"].get("arguments", {})
                            )

                            tool_results.append({
                                "tool": tool_call["function"]["name"],
                                "result": tool_result
                            })

                            # Add tool result to messages
                            messages.append({
                                "role": "tool",
                                "tool_call_id": tool_call["id"],
                                "content": str(tool_result)
                            })

                        iterations += 1
                        continue  # Get next response

                    # No more tool calls, return final result
                    return {
                        "content": message.get("content", ""),
                        "tool_calls": message.get("tool_calls", []),
                        "tool_results": tool_results,
                        "model": result["model"]
                    }

            except Exception as e:
                logger.error(f"LLM with tools call failed: {e}")
                return {"content": f"Error: {str(e)}", "tool_results": tool_results}

            iterations += 1

        # Max iterations reached
        return {
            "content": "Max iterations reached",
            "tool_results": tool_results
        }

    async def _execute_tool(
        self,
        tool_name: str,
        arguments: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute an MCP tool"""
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.resource_cluster_url}/api/v1/mcp/execute",
                    json={
                        "tool_name": tool_name,
                        "parameters": arguments,
                        "tenant_domain": self.tenant_domain,
                        "user_id": self.user_id
                    }
                )

                if response.status_code == 200:
                    return response.json()
                else:
                    return {"error": f"Tool execution failed: {response.status_code}"}

        except Exception as e:
            logger.error(f"Tool execution failed: {e}")
            return {"error": str(e)}

    def _determine_strategy(self, task_classification: TaskClassification) -> ExecutionStrategy:
        """Determine execution strategy based on task classification"""
        if task_classification.parallel_execution:
            return ExecutionStrategy.PARALLEL
        elif len(task_classification.subagent_plan) > 3:
            return ExecutionStrategy.PIPELINE
        else:
            return ExecutionStrategy.SEQUENTIAL

    def _dependencies_met(
        self,
        plan_item: Dict[str, Any],
        completed_results: Dict[str, Any]
    ) -> bool:
        """Check if dependencies are met for a subagent"""
        depends_on = plan_item.get("depends_on", [])
        return all(dep in completed_results for dep in depends_on)

    def _create_subagent_config(
        self,
        subagent_type: SubagentType,
        parent_agent: Agent,
        task: str,
        pipeline_data: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Create configuration for subagent"""
        # Base config from parent
        config = {
            "model": parent_agent.model_name,
            "temperature": parent_agent.model_settings.get("temperature", 0.7),
            "max_tokens": parent_agent.model_settings.get("max_tokens", 2000)
        }

        # Customize based on subagent type
        if subagent_type == SubagentType.RESEARCH:
            config["instructions"] = "You are a research specialist. Be thorough and accurate."
            config["temperature"] = 0.3  # Lower for factual research
        elif subagent_type == SubagentType.PLANNING:
            config["instructions"] = "You are a planning specialist. Create clear, actionable plans."
            config["temperature"] = 0.5
        elif subagent_type == SubagentType.IMPLEMENTATION:
            config["instructions"] = "You are an implementation specialist. Execute tasks precisely."
            config["temperature"] = 0.3
        elif subagent_type == SubagentType.SYNTHESIS:
            config["instructions"] = "You are a synthesis specialist. Create coherent summaries."
            config["temperature"] = 0.7
        else:
            config["instructions"] = parent_agent.instructions or ""

        return config

    def _select_tools_for_subagent(
        self,
        subagent_type: SubagentType,
        available_tools: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Select appropriate tools for subagent type"""
        if not available_tools:
            return []

        # Tool selection based on subagent type
        if subagent_type == SubagentType.RESEARCH:
            # Research agents get search tools
            return [t for t in available_tools if any(
                keyword in t["name"].lower()
                for keyword in ["search", "find", "list", "get", "fetch"]
            )]
        elif subagent_type == SubagentType.IMPLEMENTATION:
            # Implementation agents get action tools
            return [t for t in available_tools if any(
                keyword in t["name"].lower()
                for keyword in ["create", "update", "write", "execute", "run"]
            )]
        elif subagent_type == SubagentType.VALIDATION:
            # Validation agents get read/check tools
            return [t for t in available_tools if any(
                keyword in t["name"].lower()
                for keyword in ["read", "check", "verify", "test"]
            )]
        else:
            # Give all tools to other types
            return available_tools

    async def _synthesize_results(
        self,
        results: Dict[str, Any],
        task_classification: TaskClassification,
        user_message: str
    ) -> str:
        """Synthesize final response from all subagent results"""
        # Look for synthesis agent result first
        for agent_id, result in results.items():
            if result.get("type") == "synthesis":
                return result.get("final_response", "")

        # Otherwise, compile results
        response_parts = []

        # Add results in order of priority
        for plan_item in sorted(
            task_classification.subagent_plan,
            key=lambda x: x.get("priority", 999)
        ):
            agent_id = plan_item["id"]
            if agent_id in results:
                result = results[agent_id]
                if "error" not in result:
                    content = result.get("output", {}).get("content", "")
                    if content:
                        response_parts.append(content)

        return "\n\n".join(response_parts) if response_parts else "Task completed"

    def _format_previous_results(self, results: Dict[str, Any]) -> str:
        """Format previous results for context"""
        if not results:
            return "No previous results"

        formatted = []
        for agent_id, result in results.items():
            if "error" not in result:
                formatted.append(f"{agent_id}: {result.get('output', {}).get('content', '')[:200]}")

        return "\n".join(formatted) if formatted else "No valid previous results"

    def _format_all_results(self, results: Dict[str, Any]) -> str:
        """Format all results for synthesis"""
        if not results:
            return "No results to synthesize"

        formatted = []
        for agent_id, result in results.items():
            if "error" not in result:
                agent_type = result.get("type", "unknown")
                content = result.get("output", {}).get("content", "")
                formatted.append(f"[{agent_type}] {agent_id}:\n{content}\n")

        return "\n".join(formatted) if formatted else "No valid results to synthesize"

    def _extract_steps(self, content: str) -> List[str]:
        """Extract steps from planning content"""
        import re
        steps = []

        # Look for numbered lists
        pattern = r"(?:^|\n)\s*(?:\d+[\.\)]|\-|\*)\s+(.+)"
        matches = re.findall(pattern, content)

        for match in matches:
            steps.append(match.strip())

        return steps

    def _extract_issues(self, content: str) -> List[str]:
        """Extract issues from validation content"""
        import re
        issues = []

        # Look for issue indicators
        issue_patterns = [
            r"(?:issue|problem|error|warning|concern):\s*(.+)",
            r"(?:^|\n)\s*[\-\*]\s*(?:Issue|Problem|Error):\s*(.+)"
        ]

        for pattern in issue_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            issues.extend([m.strip() for m in matches])

        return issues

    def _extract_insights(self, content: str) -> List[str]:
        """Extract insights from analysis content"""
        import re
        insights = []

        # Look for insight indicators
        insight_patterns = [
            r"(?:insight|finding|observation|pattern):\s*(.+)",
            r"(?:^|\n)\s*\d+[\.\)]\s*(.+(?:shows?|indicates?|suggests?|reveals?).+)"
        ]

        for pattern in insight_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            insights.extend([m.strip() for m in matches])

        return insights

    def _calculate_execution_time(self, execution_record: Dict[str, Any]) -> float:
        """Calculate execution time in milliseconds"""
        if "completed_at" in execution_record and "started_at" in execution_record:
            start = datetime.fromisoformat(execution_record["started_at"])
            end = datetime.fromisoformat(execution_record["completed_at"])
            return (end - start).total_seconds() * 1000
        return 0.0


# Factory function
def get_subagent_orchestrator(tenant_domain: str, user_id: str) -> SubagentOrchestrator:
    """Get subagent orchestrator instance"""
    return SubagentOrchestrator(tenant_domain, user_id)
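A sketch of how a caller might wire the factory function above into request handling (illustrative only, not part of the diff; the classifier output, agent record, and identifiers are placeholders supplied by the surrounding application):

from app.services.agent_orchestrator_client import get_subagent_orchestrator

async def run_plan(task_classification, parent_agent, tools):
    # Placeholder tenant and user identifiers; real values come from the request context.
    orchestrator = get_subagent_orchestrator("acme.example", "user-123")
    outcome = await orchestrator.execute_task_plan(
        task_classification=task_classification,
        parent_agent=parent_agent,
        conversation_id="conv-789",
        user_message="Summarize last quarter's incidents",
        available_tools=tools,
    )
    return outcome["final_response"]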
apps/tenant-backend/app/services/agent_service.py (new file, 854 lines)
@@ -0,0 +1,854 @@
import json
import os
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from pathlib import Path
from app.core.config import get_settings
from app.core.postgresql_client import get_postgresql_client
from app.core.permissions import get_user_role, validate_visibility_permission, can_edit_resource, can_delete_resource, is_effective_owner
from app.services.category_service import CategoryService
import logging

logger = logging.getLogger(__name__)


class AgentService:
    """GT 2.0 PostgreSQL+PGVector Agent Service with Perfect Tenant Isolation"""

    def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
        """Initialize with tenant and user isolation using PostgreSQL+PGVector storage"""
        self.tenant_domain = tenant_domain
        self.user_id = user_id
        self.user_email = user_email or user_id  # Fallback to user_id if no email provided
        self.settings = get_settings()
        self._resolved_user_uuid = None  # Cache for resolved user UUID (performance optimization)

        logger.info(f"Agent service initialized with PostgreSQL+PGVector for {tenant_domain}/{user_id} (email: {self.user_email})")

    async def _get_resolved_user_uuid(self, user_identifier: Optional[str] = None) -> str:
        """
        Resolve user identifier to UUID with caching for performance.

        This optimization reduces repeated database lookups by caching the resolved UUID.
        Performance impact: ~50% reduction in query time for operations with multiple queries.
        Pattern matches conversation_service.py for consistency.
        """
        identifier = user_identifier or self.user_email or self.user_id

        # Return cached UUID if already resolved for this instance
        if self._resolved_user_uuid and str(identifier) in [str(self.user_email), str(self.user_id)]:
            return self._resolved_user_uuid

        # Check if already a UUID
        if "@" not in str(identifier):
            try:
                # Validate it's a proper UUID format
                uuid.UUID(str(identifier))
                if str(identifier) == str(self.user_id):
                    self._resolved_user_uuid = str(identifier)
                return str(identifier)
            except (ValueError, AttributeError):
                pass  # Not a valid UUID, treat as email/username

        # Resolve email to UUID
        pg_client = await get_postgresql_client()
        query = """
            SELECT id FROM users
            WHERE (email = $1 OR username = $1)
            AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
            LIMIT 1
        """
        result = await pg_client.fetch_one(query, str(identifier), self.tenant_domain)

        if not result:
            raise ValueError(f"User not found: {identifier}")

        user_uuid = str(result["id"])

        # Cache if this is the service's primary user
        if str(identifier) in [str(self.user_email), str(self.user_id)]:
            self._resolved_user_uuid = user_uuid

        return user_uuid

    async def create_agent(
        self,
        name: str,
        agent_type: str = "conversational",
        prompt_template: str = "",
        description: str = "",
        capabilities: Optional[List[str]] = None,
        access_group: str = "INDIVIDUAL",
        **kwargs
    ) -> Dict[str, Any]:
        """Create a new agent using PostgreSQL+PGVector storage following GT 2.0 principles"""

        try:
            # Get PostgreSQL client
            pg_client = await get_postgresql_client()

            # Generate agent ID
            agent_id = str(uuid.uuid4())

            # Resolve user UUID with caching (performance optimization)
            user_id = await self._get_resolved_user_uuid()

            logger.info(f"Found user ID: {user_id} for email/id: {self.user_email}/{self.user_id}")

            # Create agent in PostgreSQL
            query = """
                INSERT INTO agents (
                    id, name, description, system_prompt,
                    tenant_id, created_by, model, temperature, max_tokens,
                    visibility, configuration, is_active, access_group, agent_type
                ) VALUES (
                    $1, $2, $3, $4,
                    (SELECT id FROM tenants WHERE domain = $5 LIMIT 1),
                    $6,
                    $7, $8, $9, $10, $11, true, $12, $13
                )
                RETURNING id, name, description, system_prompt, model, temperature, max_tokens,
                          visibility, configuration, access_group, agent_type, created_at, updated_at
            """

            # Prepare configuration with additional kwargs
            # Ensure list fields are always lists, never None
            configuration = {
                "agent_type": agent_type,
                "capabilities": capabilities or [],
                "personality_config": kwargs.get("personality_config", {}),
                "resource_preferences": kwargs.get("resource_preferences", {}),
                "model_config": kwargs.get("model_config", {}),
                "tags": kwargs.get("tags") or [],
                "easy_prompts": kwargs.get("easy_prompts") or [],
                "selected_dataset_ids": kwargs.get("selected_dataset_ids") or [],
                **{k: v for k, v in kwargs.items() if k not in ["tags", "easy_prompts", "selected_dataset_ids"]}
            }

            # Extract model configuration
            model = kwargs.get("model")
            if not model:
                raise ValueError("Model is required for agent creation")
            temperature = kwargs.get("temperature", 0.7)
            max_tokens = kwargs.get("max_tokens", 8000)  # Increased to match Groq Llama 3.1 capabilities

            # Use access_group as visibility directly (individual, organization only)
            visibility = access_group.lower()

            # Validate visibility permission based on user role
            user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
            validate_visibility_permission(visibility, user_role)
            logger.info(f"User {self.user_email} (role: {user_role}) creating agent with visibility: {visibility}")

            # Auto-create category if specified (Issue #215)
            # This ensures imported agents with unknown categories create those categories
            # Category is stored in agent_type column
            category = kwargs.get("category")
            if category and isinstance(category, str) and category.strip():
                category_slug = category.strip().lower()
                try:
                    category_service = CategoryService(self.tenant_domain, user_id, self.user_email)
                    # Pass category_description from CSV import if provided
                    category_description = kwargs.get("category_description")
                    await category_service.get_or_create_category(category_slug, description=category_description)
                    logger.info(f"Ensured category exists: {category}")
                except Exception as cat_err:
                    logger.warning(f"Failed to ensure category '{category}' exists: {cat_err}")
                    # Continue with agent creation even if category creation fails
                # Use category as agent_type (they map to the same column)
                agent_type = category_slug

            agent_data = await pg_client.fetch_one(
                query,
                agent_id, name, description, prompt_template,
                self.tenant_domain, user_id,
                model, temperature, max_tokens, visibility,
                json.dumps(configuration), access_group, agent_type
            )

            if not agent_data:
                raise RuntimeError("Failed to create agent - no data returned")

            # Convert to dict with proper types
            # Parse configuration JSON if it's a string
            config = agent_data["configuration"]
            if isinstance(config, str):
                config = json.loads(config)
            elif config is None:
                config = {}

            result = {
                "id": str(agent_data["id"]),
                "name": agent_data["name"],
                "agent_type": config.get("agent_type", "conversational"),
                "prompt_template": agent_data["system_prompt"],
                "description": agent_data["description"],
                "capabilities": config.get("capabilities", []),
                "access_group": agent_data["access_group"],
                "config": config,
                "model": agent_data["model"],
                "temperature": float(agent_data["temperature"]) if agent_data["temperature"] is not None else None,
                "max_tokens": agent_data["max_tokens"],
                "top_p": config.get("top_p"),
                "frequency_penalty": config.get("frequency_penalty"),
                "presence_penalty": config.get("presence_penalty"),
                "visibility": agent_data["visibility"],
                "dataset_connection": config.get("dataset_connection"),
                "selected_dataset_ids": config.get("selected_dataset_ids", []),
                "max_chunks_per_query": config.get("max_chunks_per_query"),
                "history_context": config.get("history_context"),
                "personality_config": config.get("personality_config", {}),
                "resource_preferences": config.get("resource_preferences", {}),
                "tags": config.get("tags", []),
                "is_favorite": config.get("is_favorite", False),
                "conversation_count": 0,
                "total_cost_cents": 0,
                "created_at": agent_data["created_at"].isoformat(),
                "updated_at": agent_data["updated_at"].isoformat(),
                "user_id": self.user_id,
                "tenant_domain": self.tenant_domain
            }

            logger.info(f"Created agent {agent_id} in PostgreSQL for user {self.user_id}")
            return result

        except Exception as e:
            logger.error(f"Failed to create agent: {e}")
            raise
async def get_user_agents(
|
||||
self,
|
||||
active_only: bool = True,
|
||||
sort_by: Optional[str] = None,
|
||||
filter_usage: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all agents for the current user using PostgreSQL storage"""
|
||||
try:
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
try:
|
||||
user_id = await self._get_resolved_user_uuid()
|
||||
except ValueError as e:
|
||||
logger.warning(f"User not found for agents list: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
|
||||
return []
|
||||
|
||||
# Get user role to determine access level
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Query agents from PostgreSQL with conversation counts
|
||||
# Admins see ALL agents, others see only their own or organization-level agents
|
||||
if is_admin:
|
||||
where_clause = "WHERE a.tenant_id = (SELECT id FROM tenants WHERE domain = $1)"
|
||||
params = [self.tenant_domain]
|
||||
else:
|
||||
where_clause = "WHERE (a.created_by = $1 OR a.visibility = 'organization') AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2)"
|
||||
params = [user_id, self.tenant_domain]
|
||||
|
||||
# Prepare user_id parameter for per-user usage tracking
|
||||
# Need to add user_id as an additional parameter for usage calculations
|
||||
user_id_param_index = len(params) + 1
|
||||
params.append(user_id)
|
||||
|
||||
# Per-user usage tracking: Count only conversations for this user
|
||||
query = f"""
|
||||
SELECT
|
||||
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type,
|
||||
u.full_name as created_by_name,
|
||||
COUNT(CASE WHEN c.user_id = ${user_id_param_index}::uuid THEN c.id END) as user_conversation_count,
|
||||
MAX(CASE WHEN c.user_id = ${user_id_param_index}::uuid THEN c.created_at END) as user_last_used_at
|
||||
FROM agents a
|
||||
LEFT JOIN conversations c ON a.id = c.agent_id
|
||||
LEFT JOIN users u ON a.created_by = u.id
|
||||
{where_clause}
|
||||
"""
|
||||
|
||||
if active_only:
|
||||
query += " AND a.is_active = true"
|
||||
|
||||
# Time-based usage filters (per-user)
|
||||
if filter_usage == "used_last_7_days":
|
||||
query += f" AND EXISTS (SELECT 1 FROM conversations c2 WHERE c2.agent_id = a.id AND c2.user_id = ${user_id_param_index}::uuid AND c2.created_at >= NOW() - INTERVAL '7 days')"
|
||||
elif filter_usage == "used_last_30_days":
|
||||
query += f" AND EXISTS (SELECT 1 FROM conversations c2 WHERE c2.agent_id = a.id AND c2.user_id = ${user_id_param_index}::uuid AND c2.created_at >= NOW() - INTERVAL '30 days')"
|
||||
|
||||
query += " GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens, a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at, a.is_active, a.created_by, a.agent_type, u.full_name"
|
||||
|
||||
# User-specific sorting
|
||||
if sort_by == "recent_usage":
|
||||
query += " ORDER BY user_last_used_at DESC NULLS LAST, a.updated_at DESC"
|
||||
elif sort_by == "my_most_used":
|
||||
query += " ORDER BY user_conversation_count DESC, a.updated_at DESC"
|
||||
else:
|
||||
query += " ORDER BY a.updated_at DESC"
|
||||
|
||||
agents_data = await pg_client.execute_query(query, *params)
|
||||
|
||||
# Convert to proper format
|
||||
agents = []
|
||||
for agent in agents_data:
|
||||
# Debug logging for creator name
|
||||
logger.info(f"🔍 Agent '{agent['name']}': created_by={agent.get('created_by')}, created_by_name={agent.get('created_by_name')}")
|
||||
|
||||
# Parse configuration JSON if it's a string
|
||||
config = agent["configuration"]
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
elif config is None:
|
||||
config = {}
|
||||
|
||||
disclaimer_val = config.get("disclaimer")
|
||||
easy_prompts_val = config.get("easy_prompts", [])
|
||||
logger.info(f"get_user_agents - Agent {agent['name']}: disclaimer={disclaimer_val}, easy_prompts={easy_prompts_val}")
|
||||
|
||||
# Determine if user can edit this agent
|
||||
# User can edit if they created it OR if they're admin/developer
|
||||
# Use the cached user_role fetched above (no need to re-query for each agent)
|
||||
is_owner = is_effective_owner(str(agent["created_by"]), str(user_id), user_role)
|
||||
can_edit = can_edit_resource(str(agent["created_by"]), str(user_id), user_role, agent["visibility"])
|
||||
can_delete = can_delete_resource(str(agent["created_by"]), str(user_id), user_role)
|
||||
|
||||
logger.info(f"Agent {agent['name']}: created_by={agent['created_by']}, user_id={user_id}, user_role={user_role}, is_owner={is_owner}, can_edit={can_edit}, can_delete={can_delete}")
|
||||
|
||||
agents.append({
|
||||
"id": str(agent["id"]),
|
||||
"name": agent["name"],
|
||||
"agent_type": agent["agent_type"] or "conversational",
|
||||
"prompt_template": agent["system_prompt"],
|
||||
"description": agent["description"],
|
||||
"capabilities": config.get("capabilities", []),
|
||||
"access_group": agent["access_group"],
|
||||
"config": config,
|
||||
"model": agent["model"],
|
||||
"temperature": float(agent["temperature"]) if agent["temperature"] is not None else None,
|
||||
"max_tokens": agent["max_tokens"],
|
||||
"visibility": agent["visibility"],
|
||||
"dataset_connection": config.get("dataset_connection"),
|
||||
"selected_dataset_ids": config.get("selected_dataset_ids", []),
|
||||
"personality_config": config.get("personality_config", {}),
|
||||
"resource_preferences": config.get("resource_preferences", {}),
|
||||
"tags": config.get("tags", []),
|
||||
"is_favorite": config.get("is_favorite", False),
|
||||
"disclaimer": disclaimer_val,
|
||||
"easy_prompts": easy_prompts_val,
|
||||
"conversation_count": int(agent["user_conversation_count"]) if agent.get("user_conversation_count") is not None else 0,
|
||||
"last_used_at": agent["user_last_used_at"].isoformat() if agent.get("user_last_used_at") else None,
|
||||
"total_cost_cents": 0,
|
||||
"created_at": agent["created_at"].isoformat() if agent["created_at"] else None,
|
||||
"updated_at": agent["updated_at"].isoformat() if agent["updated_at"] else None,
|
||||
"is_active": agent["is_active"],
|
||||
"user_id": agent["created_by"],
|
||||
"created_by_name": agent.get("created_by_name", "Unknown"),
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"can_edit": can_edit,
|
||||
"can_delete": can_delete,
|
||||
"is_owner": is_owner
|
||||
})
|
||||
|
||||
# Fetch team-shared agents and merge with owned agents
|
||||
team_shared = await self.get_team_shared_agents(user_id)
|
||||
|
||||
# Merge and deduplicate (owned agents take precedence)
|
||||
agent_ids_seen = {agent["id"] for agent in agents}
|
||||
for team_agent in team_shared:
|
||||
if team_agent["id"] not in agent_ids_seen:
|
||||
agents.append(team_agent)
|
||||
agent_ids_seen.add(team_agent["id"])
|
||||
|
||||
logger.info(f"Retrieved {len(agents)} total agents ({len(agents) - len(team_shared)} owned + {len(team_shared)} team-shared) from PostgreSQL for user {self.user_id}")
|
||||
return agents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading agents for user {self.user_id}: {e}")
|
||||
return []
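    # Illustrative note (not part of the original commit): for a non-admin caller
    # the positional-parameter bookkeeping above works out as follows:
    #
    #     params = [user_id, tenant_domain]        # $1, $2 used by where_clause
    #     user_id_param_index = len(params) + 1    # -> 3
    #     params.append(user_id)                   # $3 used for per-user usage stats
    #
    # so the SELECT references ${user_id_param_index} (i.e. $3) while the WHERE
    # clause keeps its original $1/$2 placeholders.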
|
||||
|
||||
async def get_team_shared_agents(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get agents shared to teams where user is a member (via junction table).
|
||||
|
||||
Uses the user_accessible_resources view for efficient lookups.
|
||||
|
||||
Returns agents with permission flags:
|
||||
- can_edit: True if user has 'edit' permission for this agent
|
||||
- can_delete: False (only owner can delete)
|
||||
- is_owner: False (team-shared agents)
|
||||
- shared_via_team: True (indicates team sharing)
|
||||
- shared_in_teams: Number of teams this agent is shared with
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Query agents using the efficient user_accessible_resources view
|
||||
# This view joins team_memberships -> team_resource_shares -> agents
|
||||
# Include per-user usage statistics
|
||||
query = """
|
||||
SELECT DISTINCT
|
||||
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type,
|
||||
u.full_name as created_by_name,
|
||||
COUNT(DISTINCT CASE WHEN c.user_id = $1::uuid THEN c.id END) as user_conversation_count,
|
||||
MAX(CASE WHEN c.user_id = $1::uuid THEN c.created_at END) as user_last_used_at,
|
||||
uar.best_permission as user_permission,
|
||||
uar.shared_in_teams,
|
||||
uar.team_ids
|
||||
FROM user_accessible_resources uar
|
||||
INNER JOIN agents a ON a.id = uar.resource_id
|
||||
LEFT JOIN users u ON a.created_by = u.id
|
||||
LEFT JOIN conversations c ON a.id = c.agent_id
|
||||
WHERE uar.user_id = $1::uuid
|
||||
AND uar.resource_type = 'agent'
|
||||
AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
AND a.is_active = true
|
||||
GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature,
|
||||
a.max_tokens, a.visibility, a.configuration, a.access_group, a.created_at,
|
||||
a.updated_at, a.is_active, a.created_by, a.agent_type, u.full_name,
|
||||
uar.best_permission, uar.shared_in_teams, uar.team_ids
|
||||
ORDER BY a.updated_at DESC
|
||||
"""
|
||||
|
||||
agents_data = await pg_client.execute_query(query, user_id, self.tenant_domain)
|
||||
|
||||
# Format agents with team sharing metadata
|
||||
agents = []
|
||||
for agent in agents_data:
|
||||
# Parse configuration JSON
|
||||
config = agent["configuration"]
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
elif config is None:
|
||||
config = {}
|
||||
|
||||
# Get permission from view (will be "read" or "edit")
|
||||
user_permission = agent.get("user_permission")
|
||||
can_edit = user_permission == "edit"
|
||||
|
||||
# Get team sharing metadata
|
||||
shared_in_teams = agent.get("shared_in_teams", 0)
|
||||
team_ids = agent.get("team_ids", [])
|
||||
|
||||
agents.append({
|
||||
"id": str(agent["id"]),
|
||||
"name": agent["name"],
|
||||
"agent_type": agent["agent_type"] or "conversational",
|
||||
"prompt_template": agent["system_prompt"],
|
||||
"description": agent["description"],
|
||||
"capabilities": config.get("capabilities", []),
|
||||
"access_group": agent["access_group"],
|
||||
"config": config,
|
||||
"model": agent["model"],
|
||||
"temperature": float(agent["temperature"]) if agent["temperature"] is not None else None,
|
||||
"max_tokens": agent["max_tokens"],
|
||||
"visibility": agent["visibility"],
|
||||
"dataset_connection": config.get("dataset_connection"),
|
||||
"selected_dataset_ids": config.get("selected_dataset_ids", []),
|
||||
"personality_config": config.get("personality_config", {}),
|
||||
"resource_preferences": config.get("resource_preferences", {}),
|
||||
"tags": config.get("tags", []),
|
||||
"is_favorite": config.get("is_favorite", False),
|
||||
"disclaimer": config.get("disclaimer"),
|
||||
"easy_prompts": config.get("easy_prompts", []),
|
||||
"conversation_count": int(agent["user_conversation_count"]) if agent.get("user_conversation_count") else 0,
|
||||
"last_used_at": agent["user_last_used_at"].isoformat() if agent.get("user_last_used_at") else None,
|
||||
"total_cost_cents": 0,
|
||||
"created_at": agent["created_at"].isoformat() if agent["created_at"] else None,
|
||||
"updated_at": agent["updated_at"].isoformat() if agent["updated_at"] else None,
|
||||
"is_active": agent["is_active"],
|
||||
"user_id": agent["created_by"],
|
||||
"created_by_name": agent.get("created_by_name", "Unknown"),
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"can_edit": can_edit,
|
||||
"can_delete": False, # Only owner can delete
|
||||
"is_owner": False, # Team-shared agents
|
||||
"shared_via_team": True,
|
||||
"shared_in_teams": shared_in_teams,
|
||||
"team_ids": [str(tid) for tid in team_ids] if team_ids else [],
|
||||
"team_permission": user_permission
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(agents)} team-shared agents for user {user_id}")
|
||||
return agents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching team-shared agents for user {user_id}: {e}")
|
||||
return []
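    # Illustrative note (not part of the original commit): the query above only
    # relies on these columns of the user_accessible_resources view; the full view
    # definition lives in the database migrations and is assumed here:
    #
    #     user_id          -- member whose access is being resolved
    #     resource_type    -- e.g. 'agent'
    #     resource_id      -- id of the shared resource
    #     best_permission  -- best permission across the user's teams ('read' or 'edit')
    #     shared_in_teams  -- number of teams sharing the resource
    #     team_ids         -- ids of those teams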
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a specific agent by ID using PostgreSQL"""
|
||||
try:
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
try:
|
||||
user_id = await self._get_resolved_user_uuid()
|
||||
except ValueError as e:
|
||||
logger.warning(f"User not found: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
|
||||
return None
|
||||
|
||||
# Check if user is admin - admins can see all agents
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Query the agent first
|
||||
query = """
|
||||
SELECT
|
||||
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type,
|
||||
COUNT(c.id) as conversation_count
|
||||
FROM agents a
|
||||
LEFT JOIN conversations c ON a.id = c.agent_id
|
||||
WHERE a.id = $1 AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
|
||||
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
|
||||
a.is_active, a.created_by, a.agent_type
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
agent_data = await pg_client.fetch_one(query, agent_id, self.tenant_domain)
|
||||
logger.info(f"Agent query result: {agent_data is not None}")
|
||||
|
||||
# If agent doesn't exist, return None
|
||||
if not agent_data:
|
||||
return None
|
||||
|
||||
# Check access: admin, owner, organization, or team-based
|
||||
if not is_admin:
|
||||
is_owner = str(agent_data["created_by"]) == str(user_id)
|
||||
is_org_wide = agent_data["visibility"] == "organization"
|
||||
|
||||
# Check team-based access if not owner or org-wide
|
||||
if not is_owner and not is_org_wide:
|
||||
# Import TeamService here to avoid circular dependency
|
||||
from app.services.team_service import TeamService
|
||||
team_service = TeamService(self.tenant_domain, str(user_id), self.user_email)
|
||||
|
||||
has_team_access = await team_service.check_user_resource_permission(
|
||||
user_id=str(user_id),
|
||||
resource_type="agent",
|
||||
resource_id=agent_id,
|
||||
required_permission="read"
|
||||
)
|
||||
|
||||
if not has_team_access:
|
||||
logger.warning(f"User {user_id} denied access to agent {agent_id}")
|
||||
return None
|
||||
|
||||
logger.info(f"User {user_id} has team-based access to agent {agent_id}")
|
||||
|
||||
if agent_data:
|
||||
# Parse configuration JSON if it's a string
|
||||
config = agent_data["configuration"]
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
elif config is None:
|
||||
config = {}
|
||||
|
||||
# Convert to proper format
|
||||
logger.info(f"Config disclaimer: {config.get('disclaimer')}, easy_prompts: {config.get('easy_prompts')}")
|
||||
|
||||
# Compute is_owner for export permission checks
|
||||
is_owner = str(agent_data["created_by"]) == str(user_id)
|
||||
|
||||
result = {
|
||||
"id": str(agent_data["id"]),
|
||||
"name": agent_data["name"],
|
||||
"agent_type": agent_data["agent_type"] or "conversational",
|
||||
"prompt_template": agent_data["system_prompt"],
|
||||
"description": agent_data["description"],
|
||||
"capabilities": config.get("capabilities", []),
|
||||
"access_group": agent_data["access_group"],
|
||||
"config": config,
|
||||
"model": agent_data["model"],
|
||||
"temperature": float(agent_data["temperature"]) if agent_data["temperature"] is not None else None,
|
||||
"max_tokens": agent_data["max_tokens"],
|
||||
"visibility": agent_data["visibility"],
|
||||
"dataset_connection": config.get("dataset_connection"),
|
||||
"selected_dataset_ids": config.get("selected_dataset_ids", []),
|
||||
"personality_config": config.get("personality_config", {}),
|
||||
"resource_preferences": config.get("resource_preferences", {}),
|
||||
"tags": config.get("tags", []),
|
||||
"is_favorite": config.get("is_favorite", False),
|
||||
"disclaimer": config.get("disclaimer"),
|
||||
"easy_prompts": config.get("easy_prompts", []),
|
||||
"conversation_count": int(agent_data["conversation_count"]) if agent_data.get("conversation_count") is not None else 0,
|
||||
"total_cost_cents": 0,
|
||||
"created_at": agent_data["created_at"].isoformat() if agent_data["created_at"] else None,
|
||||
"updated_at": agent_data["updated_at"].isoformat() if agent_data["updated_at"] else None,
|
||||
"is_active": agent_data["is_active"],
|
||||
"created_by": agent_data["created_by"], # Keep DB field
|
||||
"user_id": agent_data["created_by"], # Alias for compatibility
|
||||
"is_owner": is_owner, # Computed ownership for export/edit permissions
|
||||
"tenant_domain": self.tenant_domain
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading agent {agent_id}: {e}")
|
||||
return None
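    # Hedged usage sketch (not part of the original commit); the service variable
    # name is an assumption for illustration:
    #
    #     agent = await agent_service.get_agent(agent_id)
    #     if agent is None:
    #         # either the agent does not exist in this tenant, or the caller
    #         # lacks owner / organization / team access
    #         ...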
|
||||
|
||||
async def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
updates: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Update an agent's configuration using PostgreSQL with permission checks"""
|
||||
try:
|
||||
logger.info(f"Processing updates for agent {agent_id}: {updates}")
|
||||
|
||||
# Log which fields will be processed
|
||||
logger.info(f"Update fields being processed: {list(updates.keys())}")
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get user role for permission checks
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
|
||||
# If updating visibility, validate permission
|
||||
if "visibility" in updates:
|
||||
validate_visibility_permission(updates["visibility"], user_role)
|
||||
logger.info(f"User {self.user_email} (role: {user_role}) updating agent visibility to: {updates['visibility']}")
|
||||
|
||||
# Build dynamic UPDATE query based on provided updates
|
||||
set_clauses = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
# Collect all configuration updates in a single object
|
||||
config_updates = {}
|
||||
|
||||
# Handle each update field mapping to correct column names
|
||||
for field, value in updates.items():
|
||||
if field in ["name", "description", "access_group"]:
|
||||
set_clauses.append(f"{field} = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field == "prompt_template":
|
||||
set_clauses.append(f"system_prompt = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field in ["model", "temperature", "max_tokens", "visibility", "agent_type"]:
|
||||
set_clauses.append(f"{field} = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field == "is_active":
|
||||
set_clauses.append(f"is_active = ${param_idx}")
|
||||
params.append(value)
|
||||
param_idx += 1
|
||||
elif field in ["config", "configuration", "personality_config", "resource_preferences", "tags", "is_favorite",
|
||||
"dataset_connection", "selected_dataset_ids", "disclaimer", "easy_prompts"]:
|
||||
# Collect configuration updates
|
||||
if field in ["config", "configuration"]:
|
||||
config_updates.update(value if isinstance(value, dict) else {})
|
||||
else:
|
||||
config_updates[field] = value
|
||||
|
||||
# Apply configuration updates as a single operation
|
||||
if config_updates:
|
||||
set_clauses.append(f"configuration = configuration || ${param_idx}::jsonb")
|
||||
params.append(json.dumps(config_updates))
|
||||
param_idx += 1
|
||||
|
||||
if not set_clauses:
|
||||
logger.warning(f"No valid update fields provided for agent {agent_id}")
|
||||
return await self.get_agent(agent_id)
|
||||
|
||||
# Add updated_at timestamp
|
||||
set_clauses.append(f"updated_at = NOW()")
|
||||
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
try:
|
||||
user_id = await self._get_resolved_user_uuid()
|
||||
except ValueError as e:
|
||||
logger.warning(f"User not found for update: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
|
||||
return None
|
||||
|
||||
# Check if user is admin - admins can update any agent
|
||||
is_admin = user_role in ["admin", "developer"]
|
||||
|
||||
# Build final query - admins can update any agent in tenant, others only their own
|
||||
if is_admin:
|
||||
query = f"""
|
||||
UPDATE agents
|
||||
SET {', '.join(set_clauses)}
|
||||
WHERE id = ${param_idx}
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = ${param_idx + 1})
|
||||
RETURNING id
|
||||
"""
|
||||
params.extend([agent_id, self.tenant_domain])
|
||||
else:
|
||||
query = f"""
|
||||
UPDATE agents
|
||||
SET {', '.join(set_clauses)}
|
||||
WHERE id = ${param_idx}
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = ${param_idx + 1})
|
||||
AND created_by = ${param_idx + 2}
|
||||
RETURNING id
|
||||
"""
|
||||
params.extend([agent_id, self.tenant_domain, user_id])
|
||||
|
||||
# Execute update
|
||||
logger.info(f"Executing update query: {query}")
|
||||
logger.info(f"Query parameters: {params}")
|
||||
updated_id = await pg_client.fetch_scalar(query, *params)
|
||||
logger.info(f"Update result: {updated_id}")
|
||||
|
||||
if updated_id:
|
||||
# Get updated agent data
|
||||
updated_agent = await self.get_agent(agent_id)
|
||||
|
||||
logger.info(f"Updated agent {agent_id} in PostgreSQL")
|
||||
return updated_agent
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent {agent_id}: {e}")
|
||||
return None
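    # Illustrative note (not part of the original commit): PostgreSQL's
    # `jsonb || jsonb` operator performs a shallow merge, so the
    # "configuration = configuration || $N::jsonb" clause above overwrites
    # matching top-level keys and leaves the rest untouched, e.g.:
    #
    #     configuration          = {"tags": ["a"], "is_favorite": false}
    #     config_updates ($N)    = {"is_favorite": true}
    #     merged result          = {"tags": ["a"], "is_favorite": true}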
|
||||
|
||||
    async def delete_agent(self, agent_id: str) -> bool:
        """Soft delete an agent using PostgreSQL"""
        try:
            # Get PostgreSQL client
            pg_client = await get_postgresql_client()

            # Get user role to check if admin
            user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
            is_admin = user_role in ["admin", "developer"]

            # Soft delete in PostgreSQL - admins can delete any agent, others only their own
            if is_admin:
                query = """
                    UPDATE agents
                    SET is_active = false, updated_at = NOW()
                    WHERE id = $1
                    AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
                    RETURNING id
                """
                deleted_id = await pg_client.fetch_scalar(query, agent_id, self.tenant_domain)
            else:
                query = """
                    UPDATE agents
                    SET is_active = false, updated_at = NOW()
                    WHERE id = $1
                    AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
                    AND created_by = (SELECT id FROM users WHERE email = $3)
                    RETURNING id
                """
                deleted_id = await pg_client.fetch_scalar(query, agent_id, self.tenant_domain, self.user_email or self.user_id)

            if deleted_id:
                logger.info(f"Deleted agent {agent_id} from PostgreSQL")
                return True

            return False

        except Exception as e:
            logger.error(f"Error deleting agent {agent_id}: {e}")
            return False

    async def check_access_permission(self, agent_id: str, requesting_user_id: str, access_type: str = "read") -> bool:
        """
        Check if user has access to agent (via ownership, organization, or team).

        Args:
            agent_id: UUID of the agent
            requesting_user_id: UUID of the user requesting access
            access_type: 'read' or 'edit' (default: 'read')

        Returns:
            True if user has required access
        """
        try:
            pg_client = await get_postgresql_client()

            # Check if admin/developer
            user_role = await get_user_role(pg_client, requesting_user_id, self.tenant_domain)
            if user_role in ["admin", "developer"]:
                return True

            # Get agent to check ownership and visibility
            query = """
                SELECT created_by, visibility
                FROM agents
                WHERE id = $1 AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
            """
            agent_data = await pg_client.fetch_one(query, agent_id, self.tenant_domain)

            if not agent_data:
                return False

            owner_id = str(agent_data["created_by"])
            visibility = agent_data["visibility"]

            # Owner has full access
            if requesting_user_id == owner_id:
                return True

            # Organization-wide resources are accessible to all in tenant
            if visibility == "organization":
                return True

            # Check team-based access
            from app.services.team_service import TeamService
            team_service = TeamService(self.tenant_domain, requesting_user_id, requesting_user_id)

            return await team_service.check_user_resource_permission(
                user_id=requesting_user_id,
                resource_type="agent",
                resource_id=agent_id,
                required_permission=access_type
            )

        except Exception as e:
            logger.error(f"Error checking access permission for agent {agent_id}: {e}")
            return False
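    # Hedged usage sketch (not part of the original commit), assuming a FastAPI
    # route handler; the service variable name is illustrative:
    #
    #     if not await agent_service.check_access_permission(agent_id, requesting_user_id, "edit"):
    #         raise HTTPException(status_code=403, detail="Access denied")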
|
||||
|
||||
    async def _check_team_membership(self, user_id: str, team_members: List[str]) -> bool:
        """Check if user is in the team members list"""
        return user_id in team_members

    async def _check_same_tenant(self, user_id: str) -> bool:
        """Check if requesting user is in the same tenant through PostgreSQL"""
        try:
            pg_client = await get_postgresql_client()

            # Check if user exists in same tenant
            query = """
                SELECT COUNT(*) as count
                FROM users
                WHERE id = $1 AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
            """

            result = await pg_client.fetch_one(query, user_id, self.tenant_domain)
            return result and result["count"] > 0

        except Exception as e:
            logger.error(f"Failed to check tenant membership for user {user_id}: {e}")
            return False

    def get_agent_conversation_history(self, agent_id: str) -> List[Dict[str, Any]]:
        """Get conversation history for an agent (file-based)"""
        conversations_path = Path(f"/data/{self.tenant_domain}/users/{self.user_id}/conversations")
        conversations_path.mkdir(parents=True, exist_ok=True, mode=0o700)

        conversations = []
        try:
            for conv_file in conversations_path.glob("*.json"):
                with open(conv_file, 'r') as f:
                    conv_data = json.load(f)
                    if conv_data.get("agent_id") == agent_id:
                        conversations.append(conv_data)
        except Exception as e:
            logger.error(f"Error reading conversations for agent {agent_id}: {e}")

        conversations.sort(key=lambda x: x.get("updated_at", ""), reverse=True)
        return conversations
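    # Illustrative note (not part of the original commit): each *.json file in the
    # per-user conversations directory is expected to be a dict containing at
    # least the keys read above; any other fields shown here are assumptions:
    #
    #     {"agent_id": "<agent uuid>", "updated_at": "2025-01-01T00:00:00", "messages": []}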
|
||||
493
apps/tenant-backend/app/services/assistant_builder.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""
|
||||
Assistant Builder Service for GT 2.0
|
||||
|
||||
Manages assistant creation, deployment, and lifecycle.
|
||||
Integrates with template library and file-based storage.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import stat
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from app.models.assistant_template import (
|
||||
AssistantTemplate, AssistantInstance, AssistantBuilder,
|
||||
AssistantType, PersonalityConfig, ResourcePreferences, MemorySettings,
|
||||
AssistantTemplateLibrary, BUILTIN_TEMPLATES
|
||||
)
|
||||
from app.models.access_group import AccessGroup
|
||||
from app.core.security import verify_capability_token
|
||||
from app.services.access_controller import AccessController
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssistantBuilderService:
|
||||
"""
|
||||
Service for building and managing assistants
|
||||
Handles both template-based and custom assistant creation
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, resource_cluster_url: str = "http://resource-cluster:8004"):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.base_path = Path(f"/data/{tenant_domain}/assistants")
|
||||
self.template_library = AssistantTemplateLibrary(resource_cluster_url)
|
||||
self.access_controller = AccessController(tenant_domain)
|
||||
self._ensure_directories()
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure assistant directories exist with proper permissions"""
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(self.base_path, stat.S_IRWXU) # 700
|
||||
|
||||
# Create subdirectories
|
||||
for subdir in ["templates", "instances", "shared"]:
|
||||
path = self.base_path / subdir
|
||||
path.mkdir(exist_ok=True)
|
||||
os.chmod(path, stat.S_IRWXU) # 700
|
||||
|
||||
async def create_from_template(
|
||||
self,
|
||||
template_id: str,
|
||||
user_id: str,
|
||||
instance_name: str,
|
||||
customizations: Optional[Dict[str, Any]] = None,
|
||||
capability_token: str = None
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Create assistant instance from template
|
||||
|
||||
Args:
|
||||
template_id: Template to use
|
||||
user_id: User creating the assistant
|
||||
instance_name: Name for the instance
|
||||
customizations: Optional customizations
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Created assistant instance
|
||||
"""
|
||||
# Verify capability token
|
||||
if capability_token:
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Deploy from template
|
||||
instance = await self.template_library.deploy_template(
|
||||
template_id=template_id,
|
||||
user_id=user_id,
|
||||
instance_name=instance_name,
|
||||
tenant_domain=self.tenant_domain,
|
||||
customizations=customizations
|
||||
)
|
||||
|
||||
# Create file storage
|
||||
await self._create_assistant_files(instance)
|
||||
|
||||
# Save to database (would be SQLite in production)
|
||||
await self._save_assistant(instance)
|
||||
|
||||
logger.info(f"Created assistant {instance.id} from template {template_id} for {user_id}")
|
||||
|
||||
return instance
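        # Hedged usage sketch (not part of the original commit); identifiers and the
        # template id are placeholders, and capability-token issuance happens elsewhere:
        #
        #     builder = AssistantBuilderService(tenant_domain="acme.example")
        #     instance = await builder.create_from_template(
        #         template_id="research_assistant",
        #         user_id="user-123",
        #         instance_name="My Research Helper",
        #         capability_token=token,
        #     )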
|
||||
|
||||
async def create_custom(
|
||||
self,
|
||||
builder_config: AssistantBuilder,
|
||||
user_id: str,
|
||||
capability_token: str = None
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Create custom assistant from builder configuration
|
||||
|
||||
Args:
|
||||
builder_config: Custom assistant configuration
|
||||
user_id: User creating the assistant
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Created assistant instance
|
||||
"""
|
||||
# Verify capability token
|
||||
if capability_token:
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Check if user has required capabilities
|
||||
user_capabilities = token_data.get("capabilities", [])
|
||||
for required_cap in builder_config.requested_capabilities:
|
||||
if not any(required_cap in cap.get("resource", "") for cap in user_capabilities):
|
||||
raise PermissionError(f"Missing capability: {required_cap}")
|
||||
|
||||
# Build instance
|
||||
instance = builder_config.build_instance(user_id, self.tenant_domain)
|
||||
|
||||
# Create file storage
|
||||
await self._create_assistant_files(instance)
|
||||
|
||||
# Save to database
|
||||
await self._save_assistant(instance)
|
||||
|
||||
logger.info(f"Created custom assistant {instance.id} for {user_id}")
|
||||
|
||||
return instance
|
||||
|
||||
async def get_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str
|
||||
) -> Optional[AssistantInstance]:
|
||||
"""
|
||||
Get assistant instance by ID
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant ID
|
||||
user_id: User requesting the assistant
|
||||
|
||||
Returns:
|
||||
Assistant instance if found and accessible
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Check access permission
|
||||
allowed, _ = await self.access_controller.check_permission(
|
||||
user_id, instance, "read"
|
||||
)
|
||||
if not allowed:
|
||||
return None
|
||||
|
||||
return instance
|
||||
|
||||
async def list_user_assistants(
|
||||
self,
|
||||
user_id: str,
|
||||
include_shared: bool = True
|
||||
) -> List[AssistantInstance]:
|
||||
"""
|
||||
List all assistants accessible to user
|
||||
|
||||
Args:
|
||||
user_id: User to list assistants for
|
||||
include_shared: Include team/org shared assistants
|
||||
|
||||
Returns:
|
||||
List of accessible assistants
|
||||
"""
|
||||
assistants = []
|
||||
|
||||
# Get owned assistants
|
||||
owned = await self._get_owned_assistants(user_id)
|
||||
assistants.extend(owned)
|
||||
|
||||
# Get shared assistants if requested
|
||||
if include_shared:
|
||||
shared = await self._get_shared_assistants(user_id)
|
||||
assistants.extend(shared)
|
||||
|
||||
return assistants
|
||||
|
||||
async def update_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str,
|
||||
updates: Dict[str, Any]
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Update assistant configuration
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant to update
|
||||
user_id: User requesting update
|
||||
updates: Configuration updates
|
||||
|
||||
Returns:
|
||||
Updated assistant instance
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Assistant not found: {assistant_id}")
|
||||
|
||||
# Check permission
|
||||
if instance.owner_id != user_id:
|
||||
raise PermissionError("Only owner can update assistant")
|
||||
|
||||
# Apply updates
|
||||
if "personality" in updates:
|
||||
instance.personality_config = PersonalityConfig(**updates["personality"])
|
||||
if "resources" in updates:
|
||||
instance.resource_preferences = ResourcePreferences(**updates["resources"])
|
||||
if "memory" in updates:
|
||||
instance.memory_settings = MemorySettings(**updates["memory"])
|
||||
if "system_prompt" in updates:
|
||||
instance.system_prompt = updates["system_prompt"]
|
||||
|
||||
instance.updated_at = datetime.utcnow()
|
||||
|
||||
# Save changes
|
||||
await self._save_assistant(instance)
|
||||
await self._update_assistant_files(instance)
|
||||
|
||||
logger.info(f"Updated assistant {assistant_id} by {user_id}")
|
||||
|
||||
return instance
|
||||
|
||||
async def share_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str,
|
||||
access_group: AccessGroup,
|
||||
team_members: Optional[List[str]] = None
|
||||
) -> AssistantInstance:
|
||||
"""
|
||||
Share assistant with team or organization
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant to share
|
||||
user_id: User sharing (must be owner)
|
||||
access_group: New access level
|
||||
team_members: Team members if team access
|
||||
|
||||
Returns:
|
||||
Updated assistant instance
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Assistant not found: {assistant_id}")
|
||||
|
||||
# Check ownership
|
||||
if instance.owner_id != user_id:
|
||||
raise PermissionError("Only owner can share assistant")
|
||||
|
||||
# Update access
|
||||
instance.access_group = access_group
|
||||
if access_group == AccessGroup.TEAM:
|
||||
instance.team_members = team_members or []
|
||||
else:
|
||||
instance.team_members = []
|
||||
|
||||
instance.updated_at = datetime.utcnow()
|
||||
|
||||
# Save changes
|
||||
await self._save_assistant(instance)
|
||||
|
||||
logger.info(f"Shared assistant {assistant_id} with {access_group.value} by {user_id}")
|
||||
|
||||
return instance
|
||||
|
||||
async def delete_assistant(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete assistant and its files
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant to delete
|
||||
user_id: User requesting deletion
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self._load_assistant(assistant_id)
|
||||
if not instance:
|
||||
return False
|
||||
|
||||
# Check ownership
|
||||
if instance.owner_id != user_id:
|
||||
raise PermissionError("Only owner can delete assistant")
|
||||
|
||||
# Delete files
|
||||
await self._delete_assistant_files(instance)
|
||||
|
||||
# Delete from database
|
||||
await self._delete_assistant_record(assistant_id)
|
||||
|
||||
logger.info(f"Deleted assistant {assistant_id} by {user_id}")
|
||||
|
||||
return True
|
||||
|
||||
async def get_assistant_statistics(
|
||||
self,
|
||||
assistant_id: str,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get usage statistics for assistant
|
||||
|
||||
Args:
|
||||
assistant_id: Assistant ID
|
||||
user_id: User requesting stats
|
||||
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
# Load assistant
|
||||
instance = await self.get_assistant(assistant_id, user_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Assistant not found or not accessible: {assistant_id}")
|
||||
|
||||
return {
|
||||
"assistant_id": assistant_id,
|
||||
"name": instance.name,
|
||||
"created_at": instance.created_at.isoformat(),
|
||||
"last_used": instance.last_used.isoformat() if instance.last_used else None,
|
||||
"conversation_count": instance.conversation_count,
|
||||
"total_messages": instance.total_messages,
|
||||
"total_tokens_used": instance.total_tokens_used,
|
||||
"access_group": instance.access_group.value,
|
||||
"team_members_count": len(instance.team_members),
|
||||
"linked_datasets_count": len(instance.linked_datasets),
|
||||
"linked_tools_count": len(instance.linked_tools)
|
||||
}
|
||||
|
||||
async def _create_assistant_files(self, instance: AssistantInstance):
|
||||
"""Create file structure for assistant"""
|
||||
# Get file paths
|
||||
file_structure = instance.get_file_structure()
|
||||
|
||||
# Create directories
|
||||
for key, path in file_structure.items():
|
||||
if key in ["memory", "resources"]:
|
||||
# These are directories
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(Path(path), stat.S_IRWXU) # 700
|
||||
else:
|
||||
# These are files
|
||||
parent = Path(path).parent
|
||||
parent.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(parent, stat.S_IRWXU) # 700
|
||||
|
||||
# Save configuration
|
||||
config_path = Path(file_structure["config"])
|
||||
config_data = {
|
||||
"id": instance.id,
|
||||
"name": instance.name,
|
||||
"template_id": instance.template_id,
|
||||
"personality": instance.personality_config.model_dump(),
|
||||
"resources": instance.resource_preferences.model_dump(),
|
||||
"memory": instance.memory_settings.model_dump(),
|
||||
"created_at": instance.created_at.isoformat(),
|
||||
"updated_at": instance.updated_at.isoformat()
|
||||
}
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
os.chmod(config_path, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
# Save prompt
|
||||
prompt_path = Path(file_structure["prompt"])
|
||||
with open(prompt_path, 'w') as f:
|
||||
f.write(instance.system_prompt)
|
||||
os.chmod(prompt_path, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
# Save capabilities
|
||||
capabilities_path = Path(file_structure["capabilities"])
|
||||
with open(capabilities_path, 'w') as f:
|
||||
json.dump(instance.capabilities, f, indent=2)
|
||||
os.chmod(capabilities_path, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
# Update instance with file paths
|
||||
instance.config_file_path = str(config_path)
|
||||
instance.memory_file_path = str(Path(file_structure["memory"]))
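        # Illustrative note (not part of the original commit): the keys of
        # get_file_structure() used above map to this on-disk layout (the concrete
        # paths and file names are defined by the AssistantInstance model, not here):
        #
        #     "config"        -> JSON file, chmod 600, holds config_data
        #     "prompt"        -> text file, chmod 600, holds instance.system_prompt
        #     "capabilities"  -> JSON file, chmod 600, holds instance.capabilities
        #     "memory"        -> directory, chmod 700
        #     "resources"     -> directory, chmod 700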
|
||||
|
||||
async def _update_assistant_files(self, instance: AssistantInstance):
|
||||
"""Update assistant files with current configuration"""
|
||||
if instance.config_file_path:
|
||||
config_data = {
|
||||
"id": instance.id,
|
||||
"name": instance.name,
|
||||
"template_id": instance.template_id,
|
||||
"personality": instance.personality_config.model_dump(),
|
||||
"resources": instance.resource_preferences.model_dump(),
|
||||
"memory": instance.memory_settings.model_dump(),
|
||||
"created_at": instance.created_at.isoformat(),
|
||||
"updated_at": instance.updated_at.isoformat()
|
||||
}
|
||||
|
||||
with open(instance.config_file_path, 'w') as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
async def _delete_assistant_files(self, instance: AssistantInstance):
|
||||
"""Delete assistant file structure"""
|
||||
file_structure = instance.get_file_structure()
|
||||
base_dir = Path(file_structure["config"]).parent
|
||||
|
||||
if base_dir.exists():
|
||||
import shutil
|
||||
shutil.rmtree(base_dir)
|
||||
logger.info(f"Deleted assistant files at {base_dir}")
|
||||
|
||||
async def _save_assistant(self, instance: AssistantInstance):
|
||||
"""Save assistant to database (SQLite in production)"""
|
||||
# This would save to SQLite database
|
||||
# For now, we'll save to a JSON file as placeholder
|
||||
db_file = self.base_path / "instances" / f"{instance.id}.json"
|
||||
with open(db_file, 'w') as f:
|
||||
json.dump(instance.model_dump(mode='json'), f, indent=2, default=str)
|
||||
os.chmod(db_file, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
async def _load_assistant(self, assistant_id: str) -> Optional[AssistantInstance]:
|
||||
"""Load assistant from database"""
|
||||
db_file = self.base_path / "instances" / f"{assistant_id}.json"
|
||||
if not db_file.exists():
|
||||
return None
|
||||
|
||||
with open(db_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Convert datetime strings back to datetime objects
|
||||
for field in ['created_at', 'updated_at', 'last_used']:
|
||||
if field in data and data[field]:
|
||||
data[field] = datetime.fromisoformat(data[field])
|
||||
|
||||
return AssistantInstance(**data)
|
||||
|
||||
async def _delete_assistant_record(self, assistant_id: str):
|
||||
"""Delete assistant from database"""
|
||||
db_file = self.base_path / "instances" / f"{assistant_id}.json"
|
||||
if db_file.exists():
|
||||
db_file.unlink()
|
||||
|
||||
async def _get_owned_assistants(self, user_id: str) -> List[AssistantInstance]:
|
||||
"""Get assistants owned by user"""
|
||||
assistants = []
|
||||
instances_dir = self.base_path / "instances"
|
||||
|
||||
if instances_dir.exists():
|
||||
for file in instances_dir.glob("*.json"):
|
||||
instance = await self._load_assistant(file.stem)
|
||||
if instance and instance.owner_id == user_id:
|
||||
assistants.append(instance)
|
||||
|
||||
return assistants
|
||||
|
||||
async def _get_shared_assistants(self, user_id: str) -> List[AssistantInstance]:
|
||||
"""Get assistants shared with user"""
|
||||
assistants = []
|
||||
instances_dir = self.base_path / "instances"
|
||||
|
||||
if instances_dir.exists():
|
||||
for file in instances_dir.glob("*.json"):
|
||||
instance = await self._load_assistant(file.stem)
|
||||
if instance and instance.owner_id != user_id:
|
||||
# Check if user has access
|
||||
allowed, _ = await self.access_controller.check_permission(
|
||||
user_id, instance, "read"
|
||||
)
|
||||
if allowed:
|
||||
assistants.append(instance)
|
||||
|
||||
return assistants
|
||||
599
apps/tenant-backend/app/services/assistant_manager.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""
|
||||
AssistantManager Service for GT 2.0 Tenant Backend
|
||||
|
||||
File-based agent lifecycle management with perfect tenant isolation.
|
||||
Implements the core Agent System specification from CLAUDE.md.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func, desc
|
||||
from sqlalchemy.orm import selectinload
|
||||
import logging
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssistantManager:
|
||||
"""File-based agent lifecycle management"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.settings = get_settings()
|
||||
|
||||
async def create_from_template(self, template_id: str, config: Dict[str, Any], user_identifier: str) -> str:
|
||||
"""Create agent from template or custom config"""
|
||||
try:
|
||||
# Get template configuration
|
||||
template_config = await self._load_template_config(template_id)
|
||||
|
||||
# Merge template config with user overrides
|
||||
merged_config = {**template_config, **config}
|
||||
|
||||
# Create agent record
|
||||
agent = Agent(
|
||||
name=merged_config.get("name", f"Agent from {template_id}"),
|
||||
description=merged_config.get("description", f"Created from template: {template_id}"),
|
||||
template_id=template_id,
|
||||
created_by=user_identifier,
|
||||
user_name=merged_config.get("user_name"),
|
||||
personality_config=merged_config.get("personality_config", {}),
|
||||
resource_preferences=merged_config.get("resource_preferences", {}),
|
||||
memory_settings=merged_config.get("memory_settings", {}),
|
||||
tags=merged_config.get("tags", []),
|
||||
)
|
||||
|
||||
# Initialize with placeholder paths first
|
||||
agent.config_file_path = "placeholder"
|
||||
agent.prompt_file_path = "placeholder"
|
||||
agent.capabilities_file_path = "placeholder"
|
||||
|
||||
# Save to database first to get ID and UUID
|
||||
self.db.add(agent)
|
||||
await self.db.flush() # Flush to get the generated UUID without committing
|
||||
|
||||
# Now we can initialize proper file paths with the UUID
|
||||
agent.initialize_file_paths()
|
||||
|
||||
# Create file system structure
|
||||
await self._setup_assistant_files(agent, merged_config)
|
||||
|
||||
# Commit all changes
|
||||
await self.db.commit()
|
||||
await self.db.refresh(agent)
|
||||
|
||||
logger.info(
|
||||
f"Created agent from template",
|
||||
extra={
|
||||
"agent_id": agent.id,
|
||||
"assistant_uuid": agent.uuid,
|
||||
"template_id": template_id,
|
||||
"created_by": user_identifier,
|
||||
}
|
||||
)
|
||||
|
||||
return str(agent.uuid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent from template: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
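    # Hedged usage sketch (not part of the original commit); the session, config
    # values, and user identifier are assumptions for illustration:
    #
    #     manager = AssistantManager(db_session)
    #     assistant_uuid = await manager.create_from_template(
    #         template_id="coding_assistant",
    #         config={"name": "Repo Helper", "tags": ["dev"]},
    #         user_identifier="user@acme.example",
    #     )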
|
||||
|
||||
async def create_custom_assistant(self, config: Dict[str, Any], user_identifier: str) -> str:
|
||||
"""Create custom agent without template"""
|
||||
try:
|
||||
# Validate required fields
|
||||
if not config.get("name"):
|
||||
raise ValueError("Agent name is required")
|
||||
|
||||
# Create agent record
|
||||
agent = Agent(
|
||||
name=config["name"],
|
||||
description=config.get("description", "Custom AI agent"),
|
||||
template_id=None, # No template used
|
||||
created_by=user_identifier,
|
||||
user_name=config.get("user_name"),
|
||||
personality_config=config.get("personality_config", {}),
|
||||
resource_preferences=config.get("resource_preferences", {}),
|
||||
memory_settings=config.get("memory_settings", {}),
|
||||
tags=config.get("tags", []),
|
||||
)
|
||||
|
||||
# Initialize with placeholder paths first
|
||||
agent.config_file_path = "placeholder"
|
||||
agent.prompt_file_path = "placeholder"
|
||||
agent.capabilities_file_path = "placeholder"
|
||||
|
||||
# Save to database first to get ID and UUID
|
||||
self.db.add(agent)
|
||||
await self.db.flush() # Flush to get the generated UUID without committing
|
||||
|
||||
# Now we can initialize proper file paths with the UUID
|
||||
agent.initialize_file_paths()
|
||||
|
||||
# Create file system structure
|
||||
await self._setup_assistant_files(agent, config)
|
||||
|
||||
# Commit all changes
|
||||
await self.db.commit()
|
||||
await self.db.refresh(agent)
|
||||
|
||||
logger.info(
|
||||
f"Created custom agent",
|
||||
extra={
|
||||
"agent_id": agent.id,
|
||||
"assistant_uuid": agent.uuid,
|
||||
"created_by": user_identifier,
|
||||
}
|
||||
)
|
||||
|
||||
return str(agent.uuid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create custom agent: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_assistant_config(self, assistant_uuid: str, user_identifier: str) -> Dict[str, Any]:
|
||||
"""Get complete agent configuration including file-based data"""
|
||||
try:
|
||||
# Get agent from database
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
# Load complete configuration
|
||||
return agent.get_full_configuration()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get agent config: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def list_user_assistants(
|
||||
self,
|
||||
user_identifier: str,
|
||||
include_archived: bool = False,
|
||||
template_id: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List user's agents with filtering options"""
|
||||
try:
|
||||
# Build base query
|
||||
query = select(Agent).where(Agent.created_by == user_identifier)
|
||||
|
||||
# Apply filters
|
||||
if not include_archived:
|
||||
query = query.where(Agent.is_active == True)
|
||||
|
||||
if template_id:
|
||||
query = query.where(Agent.template_id == template_id)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.where(
|
||||
or_(
|
||||
Agent.name.ilike(search_term),
|
||||
Agent.description.ilike(search_term)
|
||||
)
|
||||
)
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(desc(Agent.last_used_at), desc(Agent.created_at))
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
agents = result.scalars().all()
|
||||
|
||||
return [agent.to_dict() for agent in agents]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list user agents: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def count_user_assistants(
|
||||
self,
|
||||
user_identifier: str,
|
||||
include_archived: bool = False,
|
||||
template_id: Optional[str] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""Count user's agents matching criteria"""
|
||||
try:
|
||||
# Build base query
|
||||
query = select(func.count(Agent.id)).where(Agent.created_by == user_identifier)
|
||||
|
||||
# Apply filters
|
||||
if not include_archived:
|
||||
query = query.where(Agent.is_active == True)
|
||||
|
||||
if template_id:
|
||||
query = query.where(Agent.template_id == template_id)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.where(
|
||||
or_(
|
||||
Agent.name.ilike(search_term),
|
||||
Agent.description.ilike(search_term)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to count user agents: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_assistant(self, agent_id: str, updates: Dict[str, Any], user_identifier: str) -> bool:
|
||||
"""Update agent configuration (renamed from update_configuration)"""
|
||||
return await self.update_configuration(agent_id, updates, user_identifier)
|
||||
|
||||
async def update_configuration(self, assistant_uuid: str, updates: Dict[str, Any], user_identifier: str) -> bool:
|
||||
"""Update agent configuration"""
|
||||
try:
|
||||
# Get agent
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
# Update database fields
|
||||
if "name" in updates:
|
||||
agent.name = updates["name"]
|
||||
if "description" in updates:
|
||||
agent.description = updates["description"]
|
||||
if "personality_config" in updates:
|
||||
agent.personality_config = updates["personality_config"]
|
||||
if "resource_preferences" in updates:
|
||||
agent.resource_preferences = updates["resource_preferences"]
|
||||
if "memory_settings" in updates:
|
||||
agent.memory_settings = updates["memory_settings"]
|
||||
if "tags" in updates:
|
||||
agent.tags = updates["tags"]
|
||||
|
||||
# Update file-based configurations
|
||||
if "config" in updates:
|
||||
agent.save_config_to_file(updates["config"])
|
||||
if "prompt" in updates:
|
||||
agent.save_prompt_to_file(updates["prompt"])
|
||||
if "capabilities" in updates:
|
||||
agent.save_capabilities_to_file(updates["capabilities"])
|
||||
|
||||
agent.updated_at = datetime.utcnow()
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Updated agent configuration",
|
||||
extra={
|
||||
"assistant_uuid": assistant_uuid,
|
||||
"updated_fields": list(updates.keys()),
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update agent configuration: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def clone_assistant(self, source_uuid: str, new_name: str, user_identifier: str, modifications: Dict[str, Any] = None) -> str:
|
||||
"""Clone existing agent with modifications"""
|
||||
try:
|
||||
# Get source agent
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == source_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
source_assistant = result.scalar_one_or_none()
|
||||
|
||||
if not source_assistant:
|
||||
raise ValueError(f"Source agent not found: {source_uuid}")
|
||||
|
||||
# Clone agent
|
||||
cloned_assistant = source_assistant.clone(new_name, user_identifier, modifications or {})
|
||||
|
||||
# Initialize with placeholder paths first
|
||||
cloned_assistant.config_file_path = "placeholder"
|
||||
cloned_assistant.prompt_file_path = "placeholder"
|
||||
cloned_assistant.capabilities_file_path = "placeholder"
|
||||
|
||||
# Save to database first to get UUID
|
||||
self.db.add(cloned_assistant)
|
||||
await self.db.flush() # Flush to get the generated UUID
|
||||
|
||||
# Initialize proper file paths with UUID
|
||||
cloned_assistant.initialize_file_paths()
|
||||
|
||||
# Copy and modify files
|
||||
await self._clone_assistant_files(source_assistant, cloned_assistant, modifications or {})
|
||||
|
||||
# Commit all changes
|
||||
await self.db.commit()
|
||||
await self.db.refresh(cloned_assistant)
|
||||
|
||||
logger.info(
|
||||
f"Cloned agent",
|
||||
extra={
|
||||
"source_uuid": source_uuid,
|
||||
"new_uuid": cloned_assistant.uuid,
|
||||
"new_name": new_name,
|
||||
}
|
||||
)
|
||||
|
||||
return str(cloned_assistant.uuid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clone agent: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def archive_assistant(self, assistant_uuid: str, user_identifier: str) -> bool:
|
||||
"""Archive agent (soft delete)"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
agent.archive()
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Archived agent",
|
||||
extra={"assistant_uuid": assistant_uuid}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to archive agent: {e}", exc_info=True)
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_assistant_statistics(self, assistant_uuid: str, user_identifier: str) -> Dict[str, Any]:
|
||||
"""Get usage statistics for agent"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.uuid == assistant_uuid,
|
||||
Agent.created_by == user_identifier,
|
||||
Agent.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
raise ValueError(f"Agent not found: {assistant_uuid}")
|
||||
|
||||
# Get conversation statistics
|
||||
conv_result = await self.db.execute(
|
||||
select(func.count(Conversation.id))
|
||||
.where(Conversation.agent_id == agent.id)
|
||||
)
|
||||
conversation_count = conv_result.scalar() or 0
|
||||
|
||||
# Get message statistics
|
||||
msg_result = await self.db.execute(
|
||||
select(
|
||||
func.count(Message.id),
|
||||
func.sum(Message.tokens_used),
|
||||
func.sum(Message.cost_cents)
|
||||
)
|
||||
.join(Conversation, Message.conversation_id == Conversation.id)
|
||||
.where(Conversation.agent_id == agent.id)
|
||||
)
|
||||
message_stats = msg_result.first()
|
||||
|
||||
return {
|
||||
"agent_id": assistant_uuid, # Use agent_id to match schema
|
||||
"name": agent.name,
|
||||
"created_at": agent.created_at, # Return datetime object, not ISO string
|
||||
"last_used_at": agent.last_used_at, # Return datetime object, not ISO string
|
||||
"conversation_count": conversation_count,
|
||||
"total_messages": message_stats[0] or 0,
|
||||
"total_tokens_used": message_stats[1] or 0,
|
||||
"total_cost_cents": message_stats[2] or 0,
|
||||
"total_cost_dollars": (message_stats[2] or 0) / 100.0,
|
||||
"average_tokens_per_message": (
|
||||
(message_stats[1] or 0) / max(1, message_stats[0] or 1)
|
||||
),
|
||||
"is_favorite": agent.is_favorite,
|
||||
"tags": agent.tags,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get agent statistics: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# Private helper methods
|
||||
|
||||
async def _load_template_config(self, template_id: str) -> Dict[str, Any]:
|
||||
"""Load template configuration from Resource Cluster or built-in templates"""
|
||||
# Built-in templates (as specified in CLAUDE.md)
|
||||
builtin_templates = {
|
||||
"research_assistant": {
|
||||
"name": "Research & Analysis Agent",
|
||||
"description": "Specialized in information synthesis and analysis",
|
||||
"prompt": """You are a research agent specialized in information synthesis and analysis.
|
||||
Focus on providing well-sourced, analytical responses with clear reasoning.""",
|
||||
"personality_config": {
|
||||
"tone": "balanced",
|
||||
"explanation_depth": "expert",
|
||||
"interaction_style": "collaborative"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
|
||||
{"resource": "rag:semantic_search", "actions": ["search"], "limits": {}},
|
||||
{"resource": "tools:web_search", "actions": ["search"], "limits": {"requests_per_hour": 50}},
|
||||
{"resource": "export:citations", "actions": ["create"], "limits": {}}
|
||||
]
|
||||
},
|
||||
"coding_assistant": {
|
||||
"name": "Software Development Agent",
|
||||
"description": "Focused on code quality and best practices",
|
||||
"prompt": """You are a software development agent focused on code quality and best practices.
|
||||
Provide clear explanations, suggest improvements, and help debug issues.""",
|
||||
"personality_config": {
|
||||
"tone": "direct",
|
||||
"explanation_depth": "intermediate",
|
||||
"interaction_style": "teaching"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 4000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
|
||||
{"resource": "tools:github_integration", "actions": ["read"], "limits": {}},
|
||||
{"resource": "resources:documentation", "actions": ["search"], "limits": {}},
|
||||
{"resource": "export:code_snippets", "actions": ["create"], "limits": {}}
|
||||
]
|
||||
},
|
||||
"cyber_analyst": {
|
||||
"name": "Cybersecurity Analysis Agent",
|
||||
"description": "For threat detection and response analysis",
|
||||
"prompt": """You are a cybersecurity analyst agent for threat detection and response.
|
||||
Prioritize security best practices and provide actionable recommendations.""",
|
||||
"personality_config": {
|
||||
"tone": "formal",
|
||||
"explanation_depth": "expert",
|
||||
"interaction_style": "direct"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.2,
|
||||
"max_tokens": 4000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
|
||||
{"resource": "tools:security_scanning", "actions": ["analyze"], "limits": {}},
|
||||
{"resource": "resources:threat_intelligence", "actions": ["search"], "limits": {}},
|
||||
{"resource": "export:security_reports", "actions": ["create"], "limits": {}}
|
||||
]
|
||||
},
|
||||
"educational_tutor": {
|
||||
"name": "AI Literacy Educational Agent",
|
||||
"description": "Develops critical thinking and AI literacy",
|
||||
"prompt": """You are an educational agent focused on developing critical thinking and AI literacy.
|
||||
Use Socratic questioning and encourage deep analysis of problems.""",
|
||||
"personality_config": {
|
||||
"tone": "casual",
|
||||
"explanation_depth": "beginner",
|
||||
"interaction_style": "teaching"
|
||||
},
|
||||
"resource_preferences": {
|
||||
"primary_llm": "groq:llama3-70b-8192",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 3000
|
||||
},
|
||||
"capabilities": [
|
||||
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 3000}},
|
||||
{"resource": "games:strategic_thinking", "actions": ["play"], "limits": {}},
|
||||
{"resource": "puzzles:logic_reasoning", "actions": ["present"], "limits": {}},
|
||||
{"resource": "analytics:learning_progress", "actions": ["track"], "limits": {}}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if template_id in builtin_templates:
|
||||
return builtin_templates[template_id]
|
||||
|
||||
# TODO: In the future, load from Resource Cluster Agent Library
|
||||
# For now, return empty config for unknown templates
|
||||
logger.warning(f"Unknown template ID: {template_id}")
|
||||
return {
|
||||
"name": f"Agent ({template_id})",
|
||||
"description": "Custom agent",
|
||||
"prompt": "You are a helpful AI agent.",
|
||||
"capabilities": []
|
||||
}
|
||||
|
||||
async def _setup_assistant_files(self, agent: Agent, config: Dict[str, Any]) -> None:
|
||||
"""Create file system structure for agent"""
|
||||
# Ensure directory exists
|
||||
agent.ensure_directory_exists()
|
||||
|
||||
# Save configuration files
|
||||
agent.save_config_to_file(config)
|
||||
agent.save_prompt_to_file(config.get("prompt", "You are a helpful AI agent."))
|
||||
agent.save_capabilities_to_file(config.get("capabilities", []))
|
||||
|
||||
logger.info(f"Created agent files for {agent.uuid}")
|
||||
|
||||
async def _clone_assistant_files(self, source: Agent, target: Agent, modifications: Dict[str, Any]) -> None:
|
||||
"""Clone agent files with modifications"""
|
||||
# Load source configurations
|
||||
source_config = source.load_config_from_file()
|
||||
source_prompt = source.load_prompt_from_file()
|
||||
source_capabilities = source.load_capabilities_from_file()
|
||||
|
||||
# Apply modifications
|
||||
target_config = {**source_config, **modifications.get("config", {})}
|
||||
target_prompt = modifications.get("prompt", source_prompt)
|
||||
target_capabilities = modifications.get("capabilities", source_capabilities)
|
||||
|
||||
# Create target files
|
||||
target.ensure_directory_exists()
|
||||
target.save_config_to_file(target_config)
|
||||
target.save_prompt_to_file(target_prompt)
|
||||
target.save_capabilities_to_file(target_capabilities)
|
||||
|
||||
logger.info(f"Cloned agent files from {source.uuid} to {target.uuid}")
|
||||
|
||||
|
||||
async def get_assistant_manager(db: AsyncSession) -> AssistantManager:
|
||||
"""Get AssistantManager instance"""
|
||||
return AssistantManager(db)
|
||||
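For callers, the factory above is the intended entry point. A minimal usage sketch (the session argument and return handling are illustrative; only functions and methods defined in this file are called):

from sqlalchemy.ext.asyncio import AsyncSession

async def example_agent_stats(db: AsyncSession, assistant_uuid: str, user_identifier: str) -> float:
    manager = await get_assistant_manager(db)
    # Raises ValueError if the agent does not exist or is not owned by the caller
    stats = await manager.get_assistant_statistics(assistant_uuid, user_identifier)
    return stats["total_cost_dollars"]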
632
apps/tenant-backend/app/services/automation_executor.py
Normal file
@@ -0,0 +1,632 @@
|
||||
"""
|
||||
Automation Chain Executor
|
||||
|
||||
Executes automation chains with configurable depth, capability-based limits,
|
||||
and comprehensive error handling.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from app.services.event_bus import Event, Automation, TriggerType, TenantEventBus
|
||||
from app.core.security import verify_capability_token
|
||||
from app.core.path_security import sanitize_tenant_domain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChainDepthExceeded(Exception):
|
||||
"""Raised when automation chain depth exceeds limit"""
|
||||
pass
|
||||
|
||||
|
||||
class AutomationTimeout(Exception):
|
||||
"""Raised when automation execution times out"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""Context for automation execution"""
|
||||
automation_id: str
|
||||
chain_depth: int = 0
|
||||
parent_automation_id: Optional[str] = None
|
||||
start_time: Optional[datetime] = None
|
||||
execution_history: Optional[List[Dict[str, Any]]] = None
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.start_time is None:
|
||||
self.start_time = datetime.utcnow()
|
||||
if self.execution_history is None:
|
||||
self.execution_history = []
|
||||
if self.variables is None:
|
||||
self.variables = {}
|
||||
|
||||
def add_execution(self, action: str, result: Any, duration_ms: float):
|
||||
"""Add execution record to history"""
|
||||
self.execution_history.append({
|
||||
"action": action,
|
||||
"result": result,
|
||||
"duration_ms": duration_ms,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
def get_total_duration(self) -> float:
|
||||
"""Get total execution duration in milliseconds"""
|
||||
return (datetime.utcnow() - self.start_time).total_seconds() * 1000
|
||||
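# Example (illustrative): after add_execution("api_call", {"status": 200}, 12.5) the
# execution_history holds {"action": "api_call", "result": {...}, "duration_ms": 12.5,
# "timestamp": "..."}, and get_total_duration() reports wall-clock milliseconds since start_time.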
|
||||
|
||||
class AutomationChainExecutor:
|
||||
"""
|
||||
Execute automation chains with configurable depth and capability-based limits.
|
||||
|
||||
Features:
|
||||
- Configurable max chain depth per tenant
|
||||
- Retry logic with exponential backoff
|
||||
- Comprehensive error handling
|
||||
- Execution history tracking
|
||||
- Variable passing between chain steps
|
||||
"""
|
||||
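# The depth, timeout, loop and rate limits above are read from the capability token's
# "constraints" map via _get_constraint(); an illustrative payload (values are examples only):
#   {"constraints": {"max_automation_chain_depth": 3,
#                    "automation_timeout_seconds": 120,
#                    "max_loop_iterations": 50,
#                    "api_calls_per_minute": 60}}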
|
||||
def __init__(
|
||||
self,
|
||||
tenant_domain: str,
|
||||
event_bus: TenantEventBus,
|
||||
base_path: Optional[Path] = None
|
||||
):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.event_bus = event_bus
|
||||
# Sanitize tenant_domain to prevent path traversal
|
||||
safe_tenant = sanitize_tenant_domain(tenant_domain)
|
||||
self.base_path = base_path or (Path("/data") / safe_tenant / "automations")
|
||||
self.execution_path = self.base_path / "executions"
|
||||
self.running_chains: Dict[str, ExecutionContext] = {}
|
||||
|
||||
# Ensure directories exist
|
||||
self._ensure_directories()
|
||||
|
||||
logger.info(f"AutomationChainExecutor initialized for {tenant_domain}")
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure execution directories exist with proper permissions"""
|
||||
import os
|
||||
import stat
|
||||
|
||||
# codeql[py/path-injection] execution_path derived from sanitize_tenant_domain() at line 86
|
||||
self.execution_path.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(self.execution_path, stat.S_IRWXU) # 700 permissions
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
automation: Automation,
|
||||
event: Event,
|
||||
capability_token: str,
|
||||
current_depth: int = 0
|
||||
) -> Any:
|
||||
"""
|
||||
Execute automation chain with depth control.
|
||||
|
||||
Args:
|
||||
automation: Automation to execute
|
||||
event: Triggering event
|
||||
capability_token: JWT capability token
|
||||
current_depth: Current chain depth
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
|
||||
Raises:
|
||||
ChainDepthExceeded: If chain depth exceeds limit
|
||||
AutomationTimeout: If execution times out
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data:
|
||||
raise ValueError("Invalid capability token")
|
||||
|
||||
# Get max chain depth from capability token (tenant-specific)
|
||||
max_depth = self._get_constraint(token_data, "max_automation_chain_depth", 5)
|
||||
|
||||
# Check depth limit
|
||||
if current_depth >= max_depth:
|
||||
raise ChainDepthExceeded(
|
||||
f"Chain depth {current_depth} exceeds limit {max_depth}"
|
||||
)
|
||||
|
||||
# Create execution context
|
||||
context = ExecutionContext(
|
||||
automation_id=automation.id,
|
||||
chain_depth=current_depth,
|
||||
parent_automation_id=event.metadata.get("parent_automation_id")
|
||||
)
|
||||
|
||||
# Track running chain
|
||||
self.running_chains[automation.id] = context
|
||||
|
||||
try:
|
||||
# Execute automation with timeout
|
||||
timeout = self._get_constraint(token_data, "automation_timeout_seconds", 300)
|
||||
result = await asyncio.wait_for(
|
||||
self._execute_automation(automation, event, context, token_data),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# If this automation triggers chain
|
||||
if automation.triggers_chain:
|
||||
await self._trigger_chain_automations(
|
||||
automation,
|
||||
result,
|
||||
capability_token,
|
||||
current_depth
|
||||
)
|
||||
|
||||
# Store execution history
|
||||
await self._store_execution(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise AutomationTimeout(
|
||||
f"Automation {automation.id} timed out after {timeout} seconds"
|
||||
)
|
||||
finally:
|
||||
# Remove from running chains
|
||||
self.running_chains.pop(automation.id, None)
|
||||
|
||||
async def _execute_automation(
|
||||
self,
|
||||
automation: Automation,
|
||||
event: Event,
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""Execute automation with retry logic"""
|
||||
results = []
|
||||
retry_count = 0
|
||||
max_retries = min(automation.max_retries, 5) # Cap at 5 retries
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
# Reset per-attempt results so a retried run does not keep partial results from a failed attempt
results = []
|
||||
# Execute each action
|
||||
for action in automation.actions:
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Check if action is allowed by capabilities
|
||||
if not self._is_action_allowed(action, token_data):
|
||||
logger.warning(f"Action {action.get('type')} not allowed by capabilities")
|
||||
continue
|
||||
|
||||
# Execute action with context
|
||||
result = await self._execute_action(action, event, context, token_data)
|
||||
|
||||
# Track execution
|
||||
duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||
context.add_execution(action.get("type"), result, duration_ms)
|
||||
|
||||
results.append(result)
|
||||
|
||||
# Update variables for next action
|
||||
if isinstance(result, dict) and "variables" in result:
|
||||
context.variables.update(result["variables"])
|
||||
|
||||
# Success - break retry loop
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
logger.error(f"Automation {automation.id} failed after {max_retries} retries: {e}")
|
||||
raise
|
||||
|
||||
# Exponential backoff
|
||||
wait_time = min(2 ** retry_count, 30) # Max 30 seconds
|
||||
logger.info(f"Retrying automation {automation.id} in {wait_time} seconds...")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
return {
|
||||
"automation_id": automation.id,
|
||||
"results": results,
|
||||
"context": {
|
||||
"chain_depth": context.chain_depth,
|
||||
"total_duration_ms": context.get_total_duration(),
|
||||
"variables": context.variables
|
||||
}
|
||||
}
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""Execute a single action with capability constraints"""
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "api_call":
|
||||
return await self._execute_api_call(action, context, token_data)
|
||||
elif action_type == "data_transform":
|
||||
return await self._execute_data_transform(action, context)
|
||||
elif action_type == "conditional":
|
||||
return await self._execute_conditional(action, context)
|
||||
elif action_type == "loop":
|
||||
return await self._execute_loop(action, event, context, token_data)
|
||||
elif action_type == "wait":
|
||||
return await self._execute_wait(action)
|
||||
elif action_type == "variable_set":
|
||||
return await self._execute_variable_set(action, context)
|
||||
else:
|
||||
# Delegate to event bus for standard actions
|
||||
return await self.event_bus._execute_action(action, event, None)
|
||||
|
||||
async def _execute_api_call(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute API call action with rate limiting"""
|
||||
endpoint = action.get("endpoint")
|
||||
method = action.get("method", "GET")
|
||||
headers = action.get("headers", {})
|
||||
body = action.get("body")
|
||||
|
||||
# Apply variable substitution
|
||||
if body and context.variables:
|
||||
body = self._substitute_variables(body, context.variables)
|
||||
|
||||
# Check rate limits
|
||||
rate_limit = self._get_constraint(token_data, "api_calls_per_minute", 60)
|
||||
# In production, implement actual rate limiting
|
||||
|
||||
logger.info(f"Mock API call: {method} {endpoint}")
|
||||
|
||||
# Mock response
|
||||
return {
|
||||
"status": 200,
|
||||
"data": {"message": "Mock API response"},
|
||||
"headers": {"content-type": "application/json"}
|
||||
}
|
||||
|
||||
async def _execute_data_transform(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute data transformation action"""
|
||||
transform_type = action.get("transform_type")
|
||||
source = action.get("source")
|
||||
target = action.get("target")
|
||||
|
||||
# Get source data from context
|
||||
source_data = context.variables.get(source)
|
||||
|
||||
if transform_type == "json_parse":
|
||||
result = json.loads(source_data) if isinstance(source_data, str) else source_data
|
||||
elif transform_type == "json_stringify":
|
||||
result = json.dumps(source_data)
|
||||
elif transform_type == "extract":
|
||||
path = action.get("path", "")
|
||||
result = self._extract_path(source_data, path)
|
||||
elif transform_type == "map":
|
||||
mapping = action.get("mapping", {})
|
||||
result = {k: self._extract_path(source_data, v) for k, v in mapping.items()}
|
||||
else:
|
||||
result = source_data
|
||||
|
||||
# Store result in context
|
||||
context.variables[target] = result
|
||||
|
||||
return {
|
||||
"transform_type": transform_type,
|
||||
"target": target,
|
||||
"variables": {target: result}
|
||||
}
|
||||
|
||||
async def _execute_conditional(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute conditional action"""
|
||||
condition = action.get("condition")
|
||||
then_actions = action.get("then", [])
|
||||
else_actions = action.get("else", [])
|
||||
|
||||
# Evaluate condition
|
||||
if self._evaluate_condition(condition, context.variables):
|
||||
actions_to_execute = then_actions
|
||||
branch = "then"
|
||||
else:
|
||||
actions_to_execute = else_actions
|
||||
branch = "else"
|
||||
|
||||
# Execute branch actions
|
||||
results = []
|
||||
for sub_action in actions_to_execute:
|
||||
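# Note: branch sub-actions run without the triggering event and with empty token
# data, so _get_constraint() falls back to its defaults for these nested calls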
result = await self._execute_action(sub_action, None, context, {})
|
||||
results.append(result)
|
||||
|
||||
return {
|
||||
"condition": condition,
|
||||
"branch": branch,
|
||||
"results": results
|
||||
}
|
||||
|
||||
async def _execute_loop(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
context: ExecutionContext,
|
||||
token_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute loop action with iteration limit"""
|
||||
items = action.get("items", [])
|
||||
variable = action.get("variable", "item")
|
||||
loop_actions = action.get("actions", [])
|
||||
|
||||
# Get max iterations from capabilities
|
||||
max_iterations = self._get_constraint(token_data, "max_loop_iterations", 100)
|
||||
|
||||
# Resolve items from context if it's a variable reference
|
||||
if isinstance(items, str) and items.startswith("$"):
|
||||
items = context.variables.get(items[1:], [])
|
||||
|
||||
# Limit iterations
|
||||
items = items[:max_iterations]
|
||||
|
||||
results = []
|
||||
for item in items:
|
||||
# Set loop variable
|
||||
context.variables[variable] = item
|
||||
|
||||
# Execute loop actions
|
||||
for loop_action in loop_actions:
|
||||
result = await self._execute_action(loop_action, event, context, token_data)
|
||||
results.append(result)
|
||||
|
||||
return {
|
||||
"loop_count": len(items),
|
||||
"results": results
|
||||
}
|
||||
|
||||
async def _execute_wait(self, action: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Execute wait action"""
|
||||
duration = action.get("duration", 1)
|
||||
max_wait = 60 # Maximum 60 seconds wait
|
||||
|
||||
duration = min(duration, max_wait)
|
||||
await asyncio.sleep(duration)
|
||||
|
||||
return {
|
||||
"waited": duration,
|
||||
"unit": "seconds"
|
||||
}
|
||||
|
||||
async def _execute_variable_set(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""Set variables in context"""
|
||||
variables = action.get("variables", {})
|
||||
|
||||
for key, value in variables.items():
|
||||
# Substitute existing variables in value
|
||||
if isinstance(value, str):
|
||||
value = self._substitute_variables(value, context.variables)
|
||||
context.variables[key] = value
|
||||
|
||||
return {
|
||||
"variables": variables
|
||||
}
|
||||
|
||||
async def _trigger_chain_automations(
|
||||
self,
|
||||
automation: Automation,
|
||||
result: Any,
|
||||
capability_token: str,
|
||||
current_depth: int
|
||||
):
|
||||
"""Trigger chained automations"""
|
||||
for target_id in automation.chain_targets:
|
||||
# Load target automation
|
||||
target_automation = await self.event_bus.get_automation(target_id)
|
||||
|
||||
if not target_automation:
|
||||
logger.warning(f"Chain target automation {target_id} not found")
|
||||
continue
|
||||
|
||||
# Create chain event
|
||||
chain_event = Event(
|
||||
type=TriggerType.CHAIN.value,
|
||||
tenant=self.tenant_domain,
|
||||
user=automation.owner_id,
|
||||
data=result,
|
||||
metadata={
|
||||
"parent_automation_id": automation.id,
|
||||
"chain_depth": current_depth + 1
|
||||
}
|
||||
)
|
||||
|
||||
# Execute chained automation
|
||||
try:
|
||||
await self.execute_chain(
|
||||
target_automation,
|
||||
chain_event,
|
||||
capability_token,
|
||||
current_depth + 1
|
||||
)
|
||||
except ChainDepthExceeded:
|
||||
logger.warning(f"Chain depth exceeded for automation {target_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing chained automation {target_id}: {e}")
|
||||
|
||||
def _get_constraint(
|
||||
self,
|
||||
token_data: Dict[str, Any],
|
||||
constraint_name: str,
|
||||
default: Any
|
||||
) -> Any:
|
||||
"""Get constraint value from capability token"""
|
||||
constraints = token_data.get("constraints", {})
|
||||
return constraints.get(constraint_name, default)
|
||||
|
||||
def _is_action_allowed(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
token_data: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if action is allowed by capabilities"""
|
||||
action_type = action.get("type")
|
||||
|
||||
# Check specific action capabilities
|
||||
capabilities = token_data.get("capabilities", [])
|
||||
|
||||
# Map action types to required capabilities
|
||||
required_capabilities = {
|
||||
"api_call": "automation:api_calls",
|
||||
"webhook": "automation:webhooks",
|
||||
"email": "automation:email",
|
||||
"data_transform": "automation:data_processing",
|
||||
"conditional": "automation:logic",
|
||||
"loop": "automation:logic"
|
||||
}
|
||||
|
||||
required = required_capabilities.get(action_type)
|
||||
if not required:
|
||||
return True # Allow unknown actions by default
|
||||
|
||||
# Check if capability exists
|
||||
return any(
|
||||
cap.get("resource") == required
|
||||
for cap in capabilities
|
||||
)
|
||||
|
||||
def _substitute_variables(
|
||||
self,
|
||||
template: Any,
|
||||
variables: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""Substitute variables in template"""
|
||||
if not isinstance(template, str):
|
||||
return template
|
||||
|
||||
# Simple variable substitution
|
||||
for key, value in variables.items():
|
||||
template = template.replace(f"${{{key}}}", str(value))
|
||||
template = template.replace(f"${key}", str(value))
|
||||
|
||||
return template
|
||||
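# Example (illustrative): with variables {"user": "alice"}, both "${user}" and "$user"
# are replaced, so the template "Hello ${user}" becomes "Hello alice".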
|
||||
def _extract_path(self, data: Any, path: str) -> Any:
|
||||
"""Extract value from nested data using path"""
|
||||
if not path:
|
||||
return data
|
||||
|
||||
parts = path.split(".")
|
||||
current = data
|
||||
|
||||
for part in parts:
|
||||
if isinstance(current, dict):
|
||||
current = current.get(part)
|
||||
elif isinstance(current, list) and part.isdigit():
|
||||
index = int(part)
|
||||
if 0 <= index < len(current):
|
||||
current = current[index]
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
return current
|
||||
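# Example (illustrative): _extract_path({"items": [{"name": "a"}]}, "items.0.name")
# returns "a"; a missing key or out-of-range index yields None.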
|
||||
def _evaluate_condition(
|
||||
self,
|
||||
condition: Dict[str, Any],
|
||||
variables: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Evaluate condition against variables"""
|
||||
left = condition.get("left")
|
||||
operator = condition.get("operator")
|
||||
right = condition.get("right")
|
||||
|
||||
# Resolve variables
|
||||
if isinstance(left, str) and left.startswith("$"):
|
||||
left = variables.get(left[1:])
|
||||
if isinstance(right, str) and right.startswith("$"):
|
||||
right = variables.get(right[1:])
|
||||
|
||||
# Evaluate
|
||||
try:
|
||||
if operator == "equals":
|
||||
return left == right
|
||||
elif operator == "not_equals":
|
||||
return left != right
|
||||
elif operator == "greater_than":
|
||||
return float(left) > float(right)
|
||||
elif operator == "less_than":
|
||||
return float(left) < float(right)
|
||||
elif operator == "contains":
|
||||
return right in left
|
||||
elif operator == "exists":
|
||||
return left is not None
|
||||
elif operator == "not_exists":
|
||||
return left is None
|
||||
else:
|
||||
return False
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
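# Example (illustrative): with variables {"status": 200}, the condition
# {"left": "$status", "operator": "equals", "right": 200} evaluates to True,
# while an unknown operator or a failed numeric cast returns False.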
|
||||
async def _store_execution(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
result: Any
|
||||
):
|
||||
"""Store execution history to file system"""
|
||||
execution_record = {
|
||||
"automation_id": context.automation_id,
|
||||
"chain_depth": context.chain_depth,
|
||||
"parent_automation_id": context.parent_automation_id,
|
||||
"start_time": context.start_time.isoformat(),
|
||||
"total_duration_ms": context.get_total_duration(),
|
||||
"execution_history": context.execution_history,
|
||||
"variables": context.variables,
|
||||
"result": result if isinstance(result, (dict, list, str, int, float, bool)) else str(result)
|
||||
}
|
||||
|
||||
# Create execution file
|
||||
execution_file = self.execution_path / f"{context.automation_id}_{context.start_time.strftime('%Y%m%d_%H%M%S')}.json"
|
||||
|
||||
with open(execution_file, "w") as f:
|
||||
json.dump(execution_record, f, indent=2)
|
||||
|
||||
async def get_execution_history(
|
||||
self,
|
||||
automation_id: Optional[str] = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get execution history for automations"""
|
||||
executions = []
|
||||
|
||||
# Get all execution files
|
||||
pattern = f"{automation_id}_*.json" if automation_id else "*.json"
|
||||
|
||||
for execution_file in sorted(
|
||||
self.execution_path.glob(pattern),
|
||||
key=lambda x: x.stat().st_mtime,
|
||||
reverse=True
|
||||
)[:limit]:
|
||||
try:
|
||||
with open(execution_file, "r") as f:
|
||||
executions.append(json.load(f))
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading execution {execution_file}: {e}")
|
||||
|
||||
return executions
|
||||
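A typical entry point is an event handler that already holds an Automation and the triggering Event. A minimal sketch of driving the executor (the tenant domain, automation id, and error handling are placeholders; only classes and functions defined or imported in this file are used):

async def run_on_event(event_bus: TenantEventBus, event: Event, capability_token: str):
    executor = AutomationChainExecutor("acme.example.com", event_bus)
    automation = await event_bus.get_automation("automation-123")  # placeholder id
    if automation is None:
        return None
    try:
        return await executor.execute_chain(automation, event, capability_token)
    except (ChainDepthExceeded, AutomationTimeout) as exc:
        logger.warning(f"Chain aborted: {exc}")
        return None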
514
apps/tenant-backend/app/services/category_service.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""
|
||||
Category Service for GT 2.0 Tenant Backend
|
||||
|
||||
Provides tenant-scoped agent category management with permission-based
|
||||
editing and deletion. Supports Issue #215 requirements.
|
||||
|
||||
Permission Model:
|
||||
- Admins/developers can edit/delete ANY category
|
||||
- Regular users can only edit/delete categories they created
|
||||
- All users can view and use all tenant categories
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.permissions import get_user_role
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Admin roles that can manage all categories
|
||||
ADMIN_ROLES = ["admin", "developer"]
|
||||
|
||||
|
||||
class CategoryService:
|
||||
"""GT 2.0 Category Management Service with Tenant Isolation"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
|
||||
"""Initialize with tenant and user isolation using PostgreSQL storage"""
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.user_email = user_email or user_id
|
||||
self.settings = get_settings()
|
||||
|
||||
logger.info(f"Category service initialized for {tenant_domain}/{user_id}")
|
||||
|
||||
def _generate_slug(self, name: str) -> str:
|
||||
"""Generate URL-safe slug from category name"""
|
||||
# Convert to lowercase, replace non-alphanumeric with hyphens
|
||||
slug = re.sub(r'[^a-zA-Z0-9]+', '-', name.lower())
|
||||
# Remove leading/trailing hyphens
|
||||
slug = slug.strip('-')
|
||||
return slug or 'category'
|
||||
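# Examples (illustrative): "Research & Analysis" -> "research-analysis";
# "  ++  " -> "category" (fallback when nothing alphanumeric remains).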
|
||||
async def _get_user_id(self, pg_client) -> str:
|
||||
"""Get user UUID from email/username/uuid with tenant isolation"""
|
||||
identifier = self.user_email
|
||||
|
||||
user_lookup_query = """
|
||||
SELECT id FROM users
|
||||
WHERE (email = $1 OR id::text = $1 OR username = $1)
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
user_id = await pg_client.fetch_scalar(user_lookup_query, identifier, self.tenant_domain)
|
||||
if not user_id:
|
||||
user_id = await pg_client.fetch_scalar(user_lookup_query, self.user_id, self.tenant_domain)
|
||||
|
||||
if not user_id:
|
||||
raise RuntimeError(f"User not found: {identifier} in tenant {self.tenant_domain}")
|
||||
|
||||
return str(user_id)
|
||||
|
||||
async def _get_tenant_id(self, pg_client) -> str:
|
||||
"""Get tenant UUID from domain"""
|
||||
query = "SELECT id FROM tenants WHERE domain = $1 LIMIT 1"
|
||||
tenant_id = await pg_client.fetch_scalar(query, self.tenant_domain)
|
||||
if not tenant_id:
|
||||
raise RuntimeError(f"Tenant not found: {self.tenant_domain}")
|
||||
return str(tenant_id)
|
||||
|
||||
async def _can_manage_category(self, pg_client, category: Dict) -> tuple:
|
||||
"""
|
||||
Check if current user can manage (edit/delete) a category.
|
||||
Returns (can_edit, can_delete) tuple.
|
||||
"""
|
||||
# Get user role
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
is_admin = user_role in ADMIN_ROLES
|
||||
|
||||
# Get current user ID
|
||||
current_user_id = await self._get_user_id(pg_client)
|
||||
|
||||
# Admins can manage all categories
|
||||
if is_admin:
|
||||
return (True, True)
|
||||
|
||||
# Check if user created this category
|
||||
created_by = category.get('created_by')
|
||||
if created_by and str(created_by) == current_user_id:
|
||||
return (True, True)
|
||||
|
||||
# Regular users cannot manage other users' categories or defaults
|
||||
return (False, False)
|
||||
|
||||
async def get_all_categories(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all active categories for the tenant.
|
||||
Returns categories with permission flags for current user.
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
user_id = await self._get_user_id(pg_client)
|
||||
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
|
||||
is_admin = user_role in ADMIN_ROLES
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
c.id, c.name, c.slug, c.description, c.icon,
|
||||
c.is_default, c.created_by, c.sort_order,
|
||||
c.created_at, c.updated_at,
|
||||
u.full_name as created_by_name
|
||||
FROM categories c
|
||||
LEFT JOIN users u ON c.created_by = u.id
|
||||
WHERE c.tenant_id = (SELECT id FROM tenants WHERE domain = $1 LIMIT 1)
|
||||
AND c.is_deleted = FALSE
|
||||
ORDER BY c.sort_order ASC, c.name ASC
|
||||
"""
|
||||
|
||||
rows = await pg_client.execute_query(query, self.tenant_domain)
|
||||
|
||||
categories = []
|
||||
for row in rows:
|
||||
# Determine permissions
|
||||
can_edit = False
|
||||
can_delete = False
|
||||
|
||||
if is_admin:
|
||||
can_edit = True
|
||||
can_delete = True
|
||||
elif row.get('created_by') and str(row['created_by']) == user_id:
|
||||
can_edit = True
|
||||
can_delete = True
|
||||
|
||||
categories.append({
|
||||
"id": str(row["id"]),
|
||||
"name": row["name"],
|
||||
"slug": row["slug"],
|
||||
"description": row.get("description"),
|
||||
"icon": row.get("icon"),
|
||||
"is_default": row.get("is_default", False),
|
||||
"created_by": str(row["created_by"]) if row.get("created_by") else None,
|
||||
"created_by_name": row.get("created_by_name"),
|
||||
"can_edit": can_edit,
|
||||
"can_delete": can_delete,
|
||||
"sort_order": row.get("sort_order", 0),
|
||||
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
|
||||
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(categories)} categories for tenant {self.tenant_domain}")
|
||||
return categories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting categories: {e}")
|
||||
raise
|
||||
|
||||
async def get_category_by_id(self, category_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single category by ID"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
c.id, c.name, c.slug, c.description, c.icon,
|
||||
c.is_default, c.created_by, c.sort_order,
|
||||
c.created_at, c.updated_at,
|
||||
u.full_name as created_by_name
|
||||
FROM categories c
|
||||
LEFT JOIN users u ON c.created_by = u.id
|
||||
WHERE c.id = $1::uuid
|
||||
AND c.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
AND c.is_deleted = FALSE
|
||||
"""
|
||||
|
||||
row = await pg_client.fetch_one(query, category_id, self.tenant_domain)
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
can_edit, can_delete = await self._can_manage_category(pg_client, dict(row))
|
||||
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"name": row["name"],
|
||||
"slug": row["slug"],
|
||||
"description": row.get("description"),
|
||||
"icon": row.get("icon"),
|
||||
"is_default": row.get("is_default", False),
|
||||
"created_by": str(row["created_by"]) if row.get("created_by") else None,
|
||||
"created_by_name": row.get("created_by_name"),
|
||||
"can_edit": can_edit,
|
||||
"can_delete": can_delete,
|
||||
"sort_order": row.get("sort_order", 0),
|
||||
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
|
||||
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting category {category_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_category_by_slug(self, slug: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single category by slug"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
c.id, c.name, c.slug, c.description, c.icon,
|
||||
c.is_default, c.created_by, c.sort_order,
|
||||
c.created_at, c.updated_at,
|
||||
u.full_name as created_by_name
|
||||
FROM categories c
|
||||
LEFT JOIN users u ON c.created_by = u.id
|
||||
WHERE c.slug = $1
|
||||
AND c.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
AND c.is_deleted = FALSE
|
||||
"""
|
||||
|
||||
row = await pg_client.fetch_one(query, slug.lower(), self.tenant_domain)
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
can_edit, can_delete = await self._can_manage_category(pg_client, dict(row))
|
||||
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"name": row["name"],
|
||||
"slug": row["slug"],
|
||||
"description": row.get("description"),
|
||||
"icon": row.get("icon"),
|
||||
"is_default": row.get("is_default", False),
|
||||
"created_by": str(row["created_by"]) if row.get("created_by") else None,
|
||||
"created_by_name": row.get("created_by_name"),
|
||||
"can_edit": can_edit,
|
||||
"can_delete": can_delete,
|
||||
"sort_order": row.get("sort_order", 0),
|
||||
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
|
||||
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting category by slug {slug}: {e}")
|
||||
raise
|
||||
|
||||
async def create_category(
|
||||
self,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
icon: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new custom category.
|
||||
The creating user becomes the owner and can edit/delete it.
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
user_id = await self._get_user_id(pg_client)
|
||||
tenant_id = await self._get_tenant_id(pg_client)
|
||||
|
||||
# Generate slug
|
||||
slug = self._generate_slug(name)
|
||||
|
||||
# Check if slug already exists
|
||||
existing = await self.get_category_by_slug(slug)
|
||||
if existing:
|
||||
raise ValueError(f"A category with name '{name}' already exists")
|
||||
|
||||
# Generate category ID
|
||||
category_id = str(uuid.uuid4())
|
||||
|
||||
# Get next sort_order (after all existing categories)
|
||||
sort_query = """
|
||||
SELECT COALESCE(MAX(sort_order), 0) + 10 as next_order
|
||||
FROM categories
|
||||
WHERE tenant_id = $1::uuid
|
||||
"""
|
||||
next_order = await pg_client.fetch_scalar(sort_query, tenant_id)
|
||||
|
||||
# Create category
|
||||
query = """
|
||||
INSERT INTO categories (
|
||||
id, tenant_id, name, slug, description, icon,
|
||||
is_default, created_by, sort_order, is_deleted,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3, $4, $5, $6,
|
||||
FALSE, $7::uuid, $8, FALSE,
|
||||
NOW(), NOW()
|
||||
)
|
||||
RETURNING id, name, slug, description, icon, is_default,
|
||||
created_by, sort_order, created_at, updated_at
|
||||
"""
|
||||
|
||||
row = await pg_client.fetch_one(
|
||||
query,
|
||||
category_id, tenant_id, name, slug, description, icon,
|
||||
user_id, next_order
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise RuntimeError("Failed to create category")
|
||||
|
||||
logger.info(f"Created category {category_id}: {name} for user {user_id}")
|
||||
|
||||
# Get creator name
|
||||
user_query = "SELECT full_name FROM users WHERE id = $1::uuid"
|
||||
created_by_name = await pg_client.fetch_scalar(user_query, user_id)
|
||||
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"name": row["name"],
|
||||
"slug": row["slug"],
|
||||
"description": row.get("description"),
|
||||
"icon": row.get("icon"),
|
||||
"is_default": False,
|
||||
"created_by": user_id,
|
||||
"created_by_name": created_by_name,
|
||||
"can_edit": True,
|
||||
"can_delete": True,
|
||||
"sort_order": row.get("sort_order", 0),
|
||||
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
|
||||
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating category: {e}")
|
||||
raise
|
||||
|
||||
async def update_category(
|
||||
self,
|
||||
category_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
icon: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update a category.
|
||||
Requires permission (admin or category creator).
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get existing category
|
||||
existing = await self.get_category_by_id(category_id)
|
||||
if not existing:
|
||||
raise ValueError("Category not found")
|
||||
|
||||
# Check permissions
|
||||
can_edit, _ = await self._can_manage_category(pg_client, existing)
|
||||
if not can_edit:
|
||||
raise PermissionError("You do not have permission to edit this category")
|
||||
|
||||
# Build update fields
|
||||
updates = []
|
||||
params = [category_id, self.tenant_domain]
|
||||
param_idx = 3
|
||||
|
||||
if name is not None:
|
||||
new_slug = self._generate_slug(name)
|
||||
# Check if new slug conflicts with another category
|
||||
slug_check = await self.get_category_by_slug(new_slug)
|
||||
if slug_check and slug_check["id"] != category_id:
|
||||
raise ValueError(f"A category with name '{name}' already exists")
|
||||
updates.append(f"name = ${param_idx}")
|
||||
params.append(name)
|
||||
param_idx += 1
|
||||
updates.append(f"slug = ${param_idx}")
|
||||
params.append(new_slug)
|
||||
param_idx += 1
|
||||
|
||||
if description is not None:
|
||||
updates.append(f"description = ${param_idx}")
|
||||
params.append(description)
|
||||
param_idx += 1
|
||||
|
||||
if icon is not None:
|
||||
updates.append(f"icon = ${param_idx}")
|
||||
params.append(icon)
|
||||
param_idx += 1
|
||||
|
||||
if not updates:
|
||||
return existing
|
||||
|
||||
updates.append("updated_at = NOW()")
|
||||
|
||||
query = f"""
|
||||
UPDATE categories
|
||||
SET {', '.join(updates)}
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
AND is_deleted = FALSE
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
result = await pg_client.fetch_scalar(query, *params)
|
||||
if not result:
|
||||
raise RuntimeError("Failed to update category")
|
||||
|
||||
logger.info(f"Updated category {category_id}")
|
||||
|
||||
# Return updated category
|
||||
return await self.get_category_by_id(category_id)
|
||||
|
||||
except (ValueError, PermissionError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating category {category_id}: {e}")
|
||||
raise
|
||||
|
||||
async def delete_category(self, category_id: str) -> bool:
|
||||
"""
|
||||
Soft delete a category.
|
||||
Requires permission (admin or category creator).
|
||||
"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Get existing category
|
||||
existing = await self.get_category_by_id(category_id)
|
||||
if not existing:
|
||||
raise ValueError("Category not found")
|
||||
|
||||
# Check permissions
|
||||
_, can_delete = await self._can_manage_category(pg_client, existing)
|
||||
if not can_delete:
|
||||
raise PermissionError("You do not have permission to delete this category")
|
||||
|
||||
# Soft delete
|
||||
query = """
|
||||
UPDATE categories
|
||||
SET is_deleted = TRUE, updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
|
||||
await pg_client.execute_command(query, category_id, self.tenant_domain)
|
||||
|
||||
logger.info(f"Deleted category {category_id}")
|
||||
return True
|
||||
|
||||
except (ValueError, PermissionError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting category {category_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_or_create_category(
|
||||
self,
|
||||
slug: str,
|
||||
description: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get existing category by slug or create it if not exists.
|
||||
Used for agent import to auto-create missing categories.
|
||||
|
||||
If the category was soft-deleted, it will be restored.
|
||||
|
||||
Args:
|
||||
slug: Category slug (lowercase, hyphenated)
|
||||
description: Optional description for new/restored categories
|
||||
"""
|
||||
try:
|
||||
# Try to get existing active category
|
||||
existing = await self.get_category_by_slug(slug)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
# Check if there's a soft-deleted category with this slug
|
||||
pg_client = await get_postgresql_client()
|
||||
deleted_query = """
|
||||
SELECT id FROM categories
|
||||
WHERE slug = $1
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
AND is_deleted = TRUE
|
||||
"""
|
||||
deleted_id = await pg_client.fetch_scalar(deleted_query, slug.lower(), self.tenant_domain)
|
||||
|
||||
if deleted_id:
|
||||
# Restore the soft-deleted category
|
||||
user_id = await self._get_user_id(pg_client)
|
||||
restore_query = """
|
||||
UPDATE categories
|
||||
SET is_deleted = FALSE,
|
||||
updated_at = NOW(),
|
||||
created_by = $3::uuid
|
||||
WHERE id = $1::uuid
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
await pg_client.execute_command(restore_query, str(deleted_id), self.tenant_domain, user_id)
|
||||
logger.info(f"Restored soft-deleted category: {slug}")
|
||||
|
||||
# Return the restored category
|
||||
return await self.get_category_by_slug(slug)
|
||||
|
||||
# Auto-create with importing user as creator
|
||||
name = slug.replace('-', ' ').title()
|
||||
return await self.create_category(
|
||||
name=name,
|
||||
description=description, # Use provided description or None
|
||||
icon=None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_or_create_category for slug {slug}: {e}")
|
||||
raise
|
||||
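For callers, the service is constructed per request with the tenant and user already resolved. A minimal usage sketch (tenant, user, and slug values are placeholders; only methods defined above are used):

async def ensure_import_category():
    service = CategoryService(
        tenant_domain="acme.example.com",
        user_id="user-123",
        user_email="analyst@acme.example.com",
    )
    # Auto-creates (or restores) the category if it is missing, then returns it
    category = await service.get_or_create_category(
        slug="threat-intel",
        description="Imported with agent bundle",
    )
    return category["id"], category["can_edit"]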
563
apps/tenant-backend/app/services/conversation_file_service.py
Normal file
@@ -0,0 +1,563 @@
|
||||
"""
|
||||
Conversation File Service for GT 2.0
|
||||
|
||||
Handles conversation-scoped file attachments as a simpler alternative to dataset-based uploads.
|
||||
Preserves all existing dataset infrastructure while providing direct conversation file storage.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import UploadFile, HTTPException
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.path_security import sanitize_tenant_domain
|
||||
from app.services.embedding_client import BGE_M3_EmbeddingClient
|
||||
from app.services.document_processor import DocumentProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationFileService:
|
||||
"""Service for managing conversation-scoped file attachments"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.settings = get_settings()
|
||||
self.schema_name = f"tenant_{tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
|
||||
# File storage configuration
|
||||
# Sanitize tenant_domain to prevent path traversal
|
||||
safe_tenant = sanitize_tenant_domain(tenant_domain)
|
||||
# codeql[py/path-injection] safe_tenant validated by sanitize_tenant_domain()
|
||||
self.storage_root = Path(self.settings.file_storage_path) / safe_tenant / "conversations"
|
||||
self.storage_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"ConversationFileService initialized for {tenant_domain}/{user_id}")
|
||||
|
||||
def _get_conversation_storage_path(self, conversation_id: str) -> Path:
|
||||
"""Get storage directory for conversation files"""
|
||||
conv_path = self.storage_root / conversation_id
|
||||
conv_path.mkdir(parents=True, exist_ok=True)
|
||||
return conv_path
|
||||
|
||||
def _generate_safe_filename(self, original_filename: str, file_id: str) -> str:
|
||||
"""Generate safe filename for storage"""
|
||||
# Sanitize filename and prepend file ID
|
||||
safe_name = "".join(c for c in original_filename if c.isalnum() or c in ".-_")
|
||||
return f"{file_id}-{safe_name}"
|
||||
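# Example (illustrative): ("Q3 report (final).pdf", "1234") -> "1234-Q3reportfinal.pdf";
# spaces and parentheses are stripped, keeping only alphanumerics and ".-_".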
|
||||
async def upload_files(
|
||||
self,
|
||||
conversation_id: str,
|
||||
files: List[UploadFile],
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Upload files directly to conversation"""
|
||||
try:
|
||||
# Validate conversation access
|
||||
await self._validate_conversation_access(conversation_id, user_id)
|
||||
|
||||
uploaded_files = []
|
||||
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File must have a filename")
|
||||
|
||||
# Generate file metadata
|
||||
file_id = str(uuid.uuid4())
|
||||
safe_filename = self._generate_safe_filename(file.filename, file_id)
|
||||
conversation_path = self._get_conversation_storage_path(conversation_id)
|
||||
file_path = conversation_path / safe_filename
|
||||
|
||||
# Store file to disk
|
||||
content = await file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Create database record
|
||||
file_record = await self._create_file_record(
|
||||
file_id=file_id,
|
||||
conversation_id=conversation_id,
|
||||
original_filename=file.filename,
|
||||
safe_filename=safe_filename,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_size=len(content),
|
||||
file_path=str(file_path.relative_to(Path(self.settings.file_storage_path))),
|
||||
uploaded_by=user_id
|
||||
)
|
||||
|
||||
uploaded_files.append(file_record)
|
||||
|
||||
# Queue for background processing
|
||||
asyncio.create_task(self._process_file_embeddings(file_id))
|
||||
|
||||
logger.info(f"Uploaded conversation file: {file.filename} -> {file_id}")
|
||||
|
||||
return uploaded_files
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload conversation files: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
|
||||
|
||||
async def _get_user_uuid(self, user_email: str) -> str:
|
||||
"""Resolve user email to UUID"""
|
||||
client = await get_postgresql_client()
|
||||
query = f"SELECT id FROM {self.schema_name}.users WHERE email = $1 LIMIT 1"
|
||||
result = await client.fetch_one(query, user_email)
|
||||
if not result:
|
||||
raise ValueError(f"User not found: {user_email}")
|
||||
return str(result['id'])
|
||||
|
||||
async def _create_file_record(
|
||||
self,
|
||||
file_id: str,
|
||||
conversation_id: str,
|
||||
original_filename: str,
|
||||
safe_filename: str,
|
||||
content_type: str,
|
||||
file_size: int,
|
||||
file_path: str,
|
||||
uploaded_by: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create conversation_files database record"""
|
||||
|
||||
client = await get_postgresql_client()
|
||||
|
||||
# Resolve user email to UUID if needed
|
||||
user_uuid = uploaded_by
|
||||
if '@' in uploaded_by: # Check if it's an email
|
||||
user_uuid = await self._get_user_uuid(uploaded_by)
|
||||
|
||||
query = f"""
|
||||
INSERT INTO {self.schema_name}.conversation_files (
|
||||
id, conversation_id, filename, original_filename, content_type,
|
||||
file_size_bytes, file_path, uploaded_by, processing_status
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'pending')
|
||||
RETURNING id, filename, original_filename, content_type, file_size_bytes,
|
||||
processing_status, uploaded_at
|
||||
"""
|
||||
|
||||
result = await client.fetch_one(
|
||||
query,
|
||||
file_id, conversation_id, safe_filename, original_filename,
|
||||
content_type, file_size, file_path, user_uuid
|
||||
)
|
||||
|
||||
# Convert UUID fields to strings for JSON serialization
|
||||
result_dict = dict(result)
|
||||
if 'id' in result_dict and result_dict['id']:
|
||||
result_dict['id'] = str(result_dict['id'])
|
||||
|
||||
return result_dict
|
||||
|
||||
async def _process_file_embeddings(self, file_id: str):
|
||||
"""Background task to process file content and generate embeddings"""
|
||||
try:
|
||||
# Update status to processing
|
||||
await self._update_processing_status(file_id, "processing")
|
||||
|
||||
# Get file record
|
||||
file_record = await self._get_file_record(file_id)
|
||||
if not file_record:
|
||||
logger.error(f"File record not found: {file_id}")
|
||||
return
|
||||
|
||||
# Read file content
|
||||
file_path = Path(self.settings.file_storage_path) / file_record['file_path']
|
||||
if not file_path.exists():
|
||||
logger.error(f"File not found on disk: {file_path}")
|
||||
await self._update_processing_status(file_id, "failed")
|
||||
return
|
||||
|
||||
# Extract text content using DocumentProcessor public methods
|
||||
processor = DocumentProcessor()
|
||||
|
||||
text_content = await processor.extract_text_from_path(
|
||||
file_path,
|
||||
file_record['content_type']
|
||||
)
|
||||
|
||||
if not text_content:
|
||||
logger.warning(f"No text content extracted from {file_record['original_filename']}")
|
||||
await self._update_processing_status(file_id, "completed")
|
||||
return
|
||||
|
||||
# Chunk content for RAG
|
||||
chunks = await processor.chunk_text_simple(text_content)
|
||||
|
||||
# Generate embeddings for full document (single embedding for semantic search)
|
||||
embedding_client = BGE_M3_EmbeddingClient()
|
||||
embeddings = await embedding_client.generate_embeddings([text_content])
|
||||
|
||||
if not embeddings:
|
||||
logger.error(f"Failed to generate embeddings for {file_id}")
|
||||
await self._update_processing_status(file_id, "failed")
|
||||
return
|
||||
|
||||
# Update record with processed content (chunks as JSONB, embedding as vector)
|
||||
await self._update_file_processing_results(
|
||||
file_id, chunks, embeddings[0], "completed"
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed file: {file_record['original_filename']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process file {file_id}: {e}")
|
||||
await self._update_processing_status(file_id, "failed")
|
||||
|
||||
async def _update_processing_status(self, file_id: str, status: str):
|
||||
"""Update file processing status"""
|
||||
client = await get_postgresql_client()
|
||||
|
||||
query = f"""
|
||||
UPDATE {self.schema_name}.conversation_files
|
||||
SET processing_status = $1,
|
||||
processed_at = CASE WHEN $1 IN ('completed', 'failed') THEN NOW() ELSE processed_at END
|
||||
WHERE id = $2
|
||||
"""
|
||||
|
||||
await client.execute_query(query, status, file_id)
|
||||
|
||||
async def _update_file_processing_results(
|
||||
self,
|
||||
file_id: str,
|
||||
chunks: List[str],
|
||||
embedding: List[float],
|
||||
status: str
|
||||
):
|
||||
"""Update file with processing results"""
|
||||
import json
|
||||
client = await get_postgresql_client()
|
||||
|
||||
# Sanitize chunks: remove null bytes and other control characters
|
||||
# that PostgreSQL can't handle in JSONB
|
||||
sanitized_chunks = [
|
||||
chunk.replace('\u0000', '').replace('\x00', '')
|
||||
for chunk in chunks
|
||||
]
|
||||
|
||||
# Convert chunks list to JSONB-compatible format
|
||||
chunks_json = json.dumps(sanitized_chunks)
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = f"[{','.join(map(str, embedding))}]"
|
||||
|
||||
query = f"""
|
||||
UPDATE {self.schema_name}.conversation_files
|
||||
SET processed_chunks = $1::jsonb,
|
||||
embeddings = $2::vector,
|
||||
processing_status = $3,
|
||||
processed_at = NOW()
|
||||
WHERE id = $4
|
||||
"""
|
||||
|
||||
await client.execute_query(query, chunks_json, embedding_str, status, file_id)
|
||||
|
||||
async def _get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get file record by ID"""
|
||||
client = await get_postgresql_client()
|
||||
|
||||
query = f"""
|
||||
SELECT id, conversation_id, filename, original_filename, content_type,
|
||||
file_size_bytes, file_path, processing_status, uploaded_at
|
||||
FROM {self.schema_name}.conversation_files
|
||||
WHERE id = $1
|
||||
"""
|
||||
|
||||
result = await client.fetch_one(query, file_id)
|
||||
return dict(result) if result else None
|
||||
|
||||
async def list_files(self, conversation_id: str) -> List[Dict[str, Any]]:
|
||||
"""List files attached to conversation"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
|
||||
query = f"""
|
||||
SELECT id, filename, original_filename, content_type, file_size_bytes,
|
||||
processing_status, uploaded_at, processed_at
|
||||
FROM {self.schema_name}.conversation_files
|
||||
WHERE conversation_id = $1
|
||||
ORDER BY uploaded_at DESC
|
||||
"""
|
||||
|
||||
rows = await client.execute_query(query, conversation_id)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list conversation files: {e}")
|
||||
return []
|
||||
|
||||
async def delete_file(self, conversation_id: str, file_id: str, user_id: str, allow_post_message_deletion: bool = False) -> bool:
|
||||
"""Delete specific file from conversation
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID
|
||||
file_id: The file ID to delete
|
||||
user_id: The user requesting deletion
|
||||
allow_post_message_deletion: If False, prevents deletion after messages exist (default: False)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"DELETE FILE CALLED: file_id={file_id}, conversation_id={conversation_id}, user_id={user_id}")
|
||||
|
||||
# Validate access
|
||||
await self._validate_conversation_access(conversation_id, user_id)
|
||||
logger.info(f"DELETE FILE: Access validated")
|
||||
|
||||
# Check if conversation has messages (unless explicitly allowed to delete post-message)
|
||||
if not allow_post_message_deletion:
|
||||
client = await get_postgresql_client()
|
||||
message_check_query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {self.schema_name}.messages
|
||||
WHERE conversation_id = $1
|
||||
"""
|
||||
message_count_result = await client.fetch_one(message_check_query, conversation_id)
|
||||
message_count = message_count_result['count'] if message_count_result else 0
|
||||
|
||||
if message_count > 0:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete files after conversation has started. Files are part of the conversation context."
|
||||
)
|
||||
|
||||
# Get file record for cleanup
|
||||
file_record = await self._get_file_record(file_id)
|
||||
logger.info(f"DELETE FILE: file_record={file_record}")
|
||||
if not file_record or str(file_record['conversation_id']) != conversation_id:
|
||||
logger.warning(f"DELETE FILE FAILED: file not found or conversation mismatch. file_record={file_record}, expected_conv_id={conversation_id}")
|
||||
return False
|
||||
|
||||
# Delete from database
|
||||
client = await get_postgresql_client()
|
||||
query = f"""
|
||||
DELETE FROM {self.schema_name}.conversation_files
|
||||
WHERE id = $1 AND conversation_id = $2
|
||||
"""
|
||||
|
||||
rows_deleted = await client.execute_command(query, file_id, conversation_id)
|
||||
|
||||
if rows_deleted > 0:
|
||||
# Delete file from disk
|
||||
file_path = Path(self.settings.file_storage_path) / file_record['file_path']
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
logger.info(f"Deleted conversation file: {file_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except HTTPException:
|
||||
raise # Re-raise HTTPException to preserve status code and message
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete conversation file: {e}")
|
||||
return False
|
||||
|
||||
async def search_conversation_files(
|
||||
self,
|
||||
conversation_id: str,
|
||||
query: str,
|
||||
max_results: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search files within a conversation using vector similarity"""
|
||||
try:
|
||||
# Generate query embedding
|
||||
embedding_client = BGE_M3_EmbeddingClient()
|
||||
embeddings = await embedding_client.generate_embeddings([query])
|
||||
|
||||
if not embeddings:
|
||||
return []
|
||||
|
||||
query_embedding = embeddings[0]
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = '[' + ','.join(map(str, query_embedding)) + ']'
|
||||
|
||||
# Vector search against conversation files
|
||||
client = await get_postgresql_client()
|
||||
|
||||
search_query = f"""
|
||||
SELECT id, filename, original_filename, processed_chunks,
|
||||
1 - (embeddings <=> $1::vector) as similarity_score
|
||||
FROM {self.schema_name}.conversation_files
|
||||
WHERE conversation_id = $2
|
||||
AND processing_status = 'completed'
|
||||
AND embeddings IS NOT NULL
|
||||
AND 1 - (embeddings <=> $1::vector) > 0.1
|
||||
ORDER BY embeddings <=> $1::vector
|
||||
LIMIT $3
|
||||
"""
|
||||
|
||||
rows = await client.execute_query(
|
||||
search_query, embedding_str, conversation_id, max_results
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for row in rows:
|
||||
processed_chunks = row.get('processed_chunks', [])
|
||||
|
||||
if not processed_chunks:
|
||||
continue
|
||||
|
||||
# Handle case where processed_chunks might be returned as JSON string
|
||||
if isinstance(processed_chunks, str):
|
||||
import json
|
||||
processed_chunks = json.loads(processed_chunks)
|
||||
|
||||
for idx, chunk_text in enumerate(processed_chunks):
|
||||
results.append({
|
||||
'id': f"{row['id']}_chunk_{idx}",
|
||||
'document_id': row['id'],
|
||||
'document_name': row['original_filename'],
|
||||
'original_filename': row['original_filename'],
|
||||
'chunk_index': idx,
|
||||
'content': chunk_text,
|
||||
'similarity_score': row['similarity_score'],
|
||||
'source': 'conversation_file',
|
||||
'source_type': 'conversation_file'
|
||||
})
|
||||
|
||||
if len(results) >= max_results:
|
||||
results = results[:max_results]
|
||||
break
|
||||
|
||||
logger.info(f"Found {len(results)} chunks from {len(rows)} matching conversation files")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search conversation files: {e}")
|
||||
return []
|
||||
|
||||
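    # Illustrative note (not part of the original source): pgvector's `<=>` operator is cosine
    # distance, so `1 - (embeddings <=> $1::vector)` is cosine similarity. The 0.1 floor simply
    # discards chunks with essentially no semantic overlap, e.g. distance 0.25 gives similarity
    # 0.75 (kept) while distance 0.95 gives similarity 0.05 (filtered out).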
async def get_all_chunks_for_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
max_chunks_per_file: int = 50,
|
||||
max_total_chunks: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve ALL chunks from files attached to conversation.
|
||||
Non-query-dependent - returns everything up to limits.
|
||||
|
||||
Args:
|
||||
conversation_id: UUID of conversation
|
||||
max_chunks_per_file: Limit per file (enforces diversity)
|
||||
max_total_chunks: Total chunk limit across all files
|
||||
|
||||
Returns:
|
||||
List of chunks with metadata, grouped by file
|
||||
"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
|
||||
query = f"""
|
||||
SELECT id, filename, original_filename, processed_chunks,
|
||||
file_size_bytes, uploaded_at
|
||||
FROM {self.schema_name}.conversation_files
|
||||
WHERE conversation_id = $1
|
||||
AND processing_status = 'completed'
|
||||
AND processed_chunks IS NOT NULL
|
||||
ORDER BY uploaded_at ASC
|
||||
"""
|
||||
|
||||
rows = await client.execute_query(query, conversation_id)
|
||||
|
||||
results = []
|
||||
total_chunks = 0
|
||||
|
||||
for row in rows:
|
||||
if total_chunks >= max_total_chunks:
|
||||
break
|
||||
|
||||
processed_chunks = row.get('processed_chunks', [])
|
||||
|
||||
# Handle JSON string if needed
|
||||
if isinstance(processed_chunks, str):
|
||||
import json
|
||||
processed_chunks = json.loads(processed_chunks)
|
||||
|
||||
# Limit chunks per file (diversity enforcement)
|
||||
chunks_from_this_file = 0
|
||||
|
||||
for idx, chunk_text in enumerate(processed_chunks):
|
||||
if chunks_from_this_file >= max_chunks_per_file:
|
||||
break
|
||||
if total_chunks >= max_total_chunks:
|
||||
break
|
||||
|
||||
results.append({
|
||||
'id': f"{row['id']}_chunk_{idx}",
|
||||
'document_id': row['id'],
|
||||
'document_name': row['original_filename'],
|
||||
'original_filename': row['original_filename'],
|
||||
'chunk_index': idx,
|
||||
'total_chunks': len(processed_chunks),
|
||||
'content': chunk_text,
|
||||
'file_size_bytes': row['file_size_bytes'],
|
||||
'source': 'conversation_file',
|
||||
'source_type': 'conversation_file'
|
||||
})
|
||||
|
||||
chunks_from_this_file += 1
|
||||
total_chunks += 1
|
||||
|
||||
logger.info(f"Retrieved {len(results)} total chunks from {len(rows)} conversation files")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get all chunks for conversation: {e}")
|
||||
return []
|
||||
|
||||
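    # Illustrative note (not part of the original source): with the defaults above, three files
    # of 60 chunks each yield 50 chunks from the first file, 50 from the second, and none from
    # the third, because max_total_chunks (100) is exhausted before the third file is reached.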
async def _validate_conversation_access(self, conversation_id: str, user_id: str):
|
||||
"""Validate user has access to conversation"""
|
||||
client = await get_postgresql_client()
|
||||
|
||||
query = f"""
|
||||
SELECT id FROM {self.schema_name}.conversations
|
||||
WHERE id = $1 AND user_id = (
|
||||
SELECT id FROM {self.schema_name}.users WHERE email = $2 LIMIT 1
|
||||
)
|
||||
"""
|
||||
|
||||
result = await client.fetch_one(query, conversation_id, user_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied: conversation not found or access denied"
|
||||
)
|
||||
|
||||
async def get_file_content(self, file_id: str, user_id: str) -> Optional[bytes]:
|
||||
"""Get file content for download"""
|
||||
try:
|
||||
file_record = await self._get_file_record(file_id)
|
||||
if not file_record:
|
||||
return None
|
||||
|
||||
# Validate access to conversation
|
||||
await self._validate_conversation_access(file_record['conversation_id'], user_id)
|
||||
|
||||
# Read file content
|
||||
file_path = Path(self.settings.file_storage_path) / file_record['file_path']
|
||||
if file_path.exists():
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file content: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Factory function for service instances
def get_conversation_file_service(tenant_domain: str, user_id: str) -> ConversationFileService:
    """Get conversation file service instance"""
    return ConversationFileService(tenant_domain, user_id)
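# Illustrative usage sketch (not part of the original source): how a caller might wire the
# factory into a list-then-search flow. The tenant domain, user id, and conversation id are
# placeholder values.
async def _example_conversation_file_flow() -> None:
    service = get_conversation_file_service("acme.example.com", "alice@acme.example.com")
    conversation_id = "00000000-0000-0000-0000-000000000002"
    files = await service.list_files(conversation_id)
    print(f"{len(files)} attached files")
    hits = await service.search_conversation_files(
        conversation_id=conversation_id,
        query="quarterly revenue",
        max_results=3,
    )
    for hit in hits:
        print(hit["original_filename"], round(hit["similarity_score"], 3))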
959
apps/tenant-backend/app/services/conversation_service.py
Normal file
959
apps/tenant-backend/app/services/conversation_service.py
Normal file
@@ -0,0 +1,959 @@
|
||||
"""
|
||||
Conversation Service for GT 2.0 Tenant Backend - PostgreSQL + PGVector
|
||||
|
||||
Manages AI-powered conversations with Agent integration using PostgreSQL directly.
|
||||
Handles message persistence, context management, and LLM inference.
|
||||
Replaces SQLAlchemy with direct PostgreSQL operations for GT 2.0 principles.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional, AsyncIterator, AsyncGenerator
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.services.agent_service import AgentService
|
||||
from app.core.resource_client import ResourceClusterClient
|
||||
from app.services.conversation_summarizer import ConversationSummarizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
"""PostgreSQL-based service for managing AI conversations"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str):
|
||||
"""Initialize with tenant and user isolation using PostgreSQL"""
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.settings = get_settings()
|
||||
self.agent_service = AgentService(tenant_domain, user_id)
|
||||
self.resource_client = ResourceClusterClient()
|
||||
self._resolved_user_uuid = None # Cache for resolved user UUID
|
||||
|
||||
logger.info(f"Conversation service initialized with PostgreSQL for {tenant_domain}/{user_id}")
|
||||
|
||||
async def _get_resolved_user_uuid(self, user_identifier: Optional[str] = None) -> str:
|
||||
"""
|
||||
Resolve user identifier to UUID with caching for performance.
|
||||
|
||||
This optimization reduces repeated database lookups by caching the resolved UUID.
|
||||
Performance impact: ~50% reduction in query time for operations with multiple queries.
|
||||
"""
|
||||
identifier = user_identifier or self.user_id
|
||||
|
||||
# Return cached UUID if already resolved for this instance
|
||||
if self._resolved_user_uuid and identifier == self.user_id:
|
||||
return self._resolved_user_uuid
|
||||
|
||||
# Check if already a UUID
|
||||
if not "@" in identifier:
|
||||
try:
|
||||
# Validate it's a proper UUID format
|
||||
uuid.UUID(identifier)
|
||||
if identifier == self.user_id:
|
||||
self._resolved_user_uuid = identifier
|
||||
return identifier
|
||||
except ValueError:
|
||||
pass # Not a valid UUID, treat as email/username
|
||||
|
||||
# Resolve email to UUID
|
||||
pg_client = await get_postgresql_client()
|
||||
query = "SELECT id FROM users WHERE email = $1 LIMIT 1"
|
||||
result = await pg_client.fetch_one(query, identifier)
|
||||
|
||||
if not result:
|
||||
raise ValueError(f"User not found: {identifier}")
|
||||
|
||||
user_uuid = str(result["id"])
|
||||
|
||||
# Cache if this is the service's primary user
|
||||
if identifier == self.user_id:
|
||||
self._resolved_user_uuid = user_uuid
|
||||
|
||||
return user_uuid
|
||||
|
||||
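    # Illustrative note (not part of the original source): the cache only applies to the
    # service's own user. A request for "alice@acme.example.com" is resolved against the users
    # table once, then later queries on the same ConversationService instance reuse the cached
    # UUID; resolving any other identifier always goes back to the database.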
def _get_user_clause(self, param_num: int, user_identifier: str) -> str:
|
||||
"""
|
||||
DEPRECATED: Get the appropriate SQL clause for user identification.
|
||||
Use _get_resolved_user_uuid() instead for better performance.
|
||||
"""
|
||||
if "@" in user_identifier:
|
||||
# Email - do lookup
|
||||
return f"(SELECT id FROM users WHERE email = ${param_num} LIMIT 1)"
|
||||
else:
|
||||
# UUID - use directly
|
||||
return f"${param_num}::uuid"
|
||||
|
||||
async def create_conversation(
|
||||
self,
|
||||
agent_id: str,
|
||||
title: Optional[str],
|
||||
user_identifier: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new conversation with an agent using PostgreSQL"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
# Get agent configuration
|
||||
agent_data = await self.agent_service.get_agent(agent_id)
|
||||
if not agent_data:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
# Validate tenant has access to the agent's model
|
||||
agent_model = agent_data.get("model")
|
||||
if agent_model:
|
||||
available_models = await self.get_available_models(self.tenant_domain)
|
||||
available_model_ids = [m["model_id"] for m in available_models]
|
||||
|
||||
if agent_model not in available_model_ids:
|
||||
raise ValueError(f"Agent model '{agent_model}' is not accessible to tenant '{self.tenant_domain}'. Available models: {', '.join(available_model_ids)}")
|
||||
|
||||
logger.info(f"Validated tenant access to model '{agent_model}' for agent '{agent_data.get('name')}'")
|
||||
else:
|
||||
logger.warning(f"Agent {agent_id} has no model configured, will use default")
|
||||
|
||||
# Get PostgreSQL client
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Generate conversation ID
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
# Create conversation in PostgreSQL (optimized: use resolved UUID directly)
|
||||
query = """
|
||||
INSERT INTO conversations (
|
||||
id, title, tenant_id, user_id, agent_id, summary,
|
||||
total_messages, total_tokens, metadata, is_archived,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
$1, $2,
|
||||
(SELECT id FROM tenants WHERE domain = $3 LIMIT 1),
|
||||
$4::uuid,
|
||||
$5, '', 0, 0, '{}', false, NOW(), NOW()
|
||||
)
|
||||
RETURNING id, title, tenant_id, user_id, agent_id, created_at, updated_at
|
||||
"""
|
||||
|
||||
conv_title = title or f"Conversation with {agent_data.get('name', 'Agent')}"
|
||||
|
||||
conversation_data = await pg_client.fetch_one(
|
||||
query,
|
||||
conversation_id, conv_title, self.tenant_domain,
|
||||
user_uuid, agent_id
|
||||
)
|
||||
|
||||
if not conversation_data:
|
||||
raise RuntimeError("Failed to create conversation - no data returned")
|
||||
|
||||
# Note: conversation_settings and conversation_participants are now created automatically
|
||||
# by the auto_create_conversation_settings trigger, so we don't need to create them manually
|
||||
|
||||
# Get the model_id from the auto-created settings or use agent's model
|
||||
settings_query = """
|
||||
SELECT model_id FROM conversation_settings WHERE conversation_id = $1
|
||||
"""
|
||||
settings_data = await pg_client.fetch_one(settings_query, conversation_id)
|
||||
model_id = settings_data["model_id"] if settings_data else agent_model
|
||||
|
||||
result = {
|
||||
"id": str(conversation_data["id"]),
|
||||
"title": conversation_data["title"],
|
||||
"agent_id": str(conversation_data["agent_id"]),
|
||||
"model_id": model_id,
|
||||
"created_at": conversation_data["created_at"].isoformat(),
|
||||
"user_id": user_uuid,
|
||||
"tenant_domain": self.tenant_domain
|
||||
}
|
||||
|
||||
logger.info(f"Created conversation {conversation_id} in PostgreSQL for user {user_uuid}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create conversation: {e}")
|
||||
raise
|
||||
|
||||
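    # Illustrative call sequence (not part of the original source), assuming the agent id
    # exists and its model is enabled for this tenant:
    #     svc = ConversationService("acme.example.com", "alice@acme.example.com")
    #     conv = await svc.create_conversation(agent_id="<agent-uuid>", title=None)
    #     conv["model_id"]  # from the auto-created conversation_settings row, else the agent's model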
async def list_conversations(
|
||||
self,
|
||||
user_identifier: str,
|
||||
agent_id: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
time_filter: str = "all",
|
||||
limit: int = 20,
|
||||
offset: int = 0
|
||||
) -> Dict[str, Any]:
|
||||
"""List conversations for a user using PostgreSQL with server-side filtering"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Build query with optional filters - exclude archived conversations (optimized: use cached UUID)
|
||||
where_clause = "WHERE c.user_id = $1::uuid AND c.is_archived = false"
|
||||
params = [user_uuid]
|
||||
param_count = 1
|
||||
|
||||
# Time filter
|
||||
if time_filter != "all":
|
||||
if time_filter == "today":
|
||||
where_clause += " AND c.updated_at >= NOW() - INTERVAL '1 day'"
|
||||
elif time_filter == "week":
|
||||
where_clause += " AND c.updated_at >= NOW() - INTERVAL '7 days'"
|
||||
elif time_filter == "month":
|
||||
where_clause += " AND c.updated_at >= NOW() - INTERVAL '30 days'"
|
||||
|
||||
# Agent filter
|
||||
if agent_id:
|
||||
param_count += 1
|
||||
where_clause += f" AND c.agent_id = ${param_count}"
|
||||
params.append(agent_id)
|
||||
|
||||
# Search filter (case-insensitive title search)
|
||||
if search:
|
||||
param_count += 1
|
||||
where_clause += f" AND c.title ILIKE ${param_count}"
|
||||
params.append(f"%{search}%")
|
||||
|
||||
# Get conversations with agent info and unread counts (optimized: use cached UUID)
|
||||
query = f"""
|
||||
SELECT
|
||||
c.id, c.title, c.agent_id, c.created_at, c.updated_at,
|
||||
c.total_messages, c.total_tokens, c.is_archived,
|
||||
a.name as agent_name,
|
||||
COUNT(m.id) FILTER (
|
||||
WHERE m.created_at > COALESCE((c.metadata->>'last_read_at')::timestamptz, c.created_at)
|
||||
AND m.user_id != $1::uuid
|
||||
) as unread_count
|
||||
FROM conversations c
|
||||
LEFT JOIN agents a ON c.agent_id = a.id
|
||||
LEFT JOIN messages m ON m.conversation_id = c.id
|
||||
{where_clause}
|
||||
GROUP BY c.id, c.title, c.agent_id, c.created_at, c.updated_at,
|
||||
c.total_messages, c.total_tokens, c.is_archived, a.name
|
||||
ORDER BY
|
||||
CASE WHEN COUNT(m.id) FILTER (
|
||||
WHERE m.created_at > COALESCE((c.metadata->>'last_read_at')::timestamptz, c.created_at)
|
||||
AND m.user_id != $1::uuid
|
||||
) > 0 THEN 0 ELSE 1 END,
|
||||
c.updated_at DESC
|
||||
LIMIT ${param_count + 1} OFFSET ${param_count + 2}
|
||||
"""
|
||||
params.extend([limit, offset])
|
||||
|
||||
conversations = await pg_client.execute_query(query, *params)
|
||||
|
||||
# Get total count
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as total
|
||||
FROM conversations c
|
||||
{where_clause}
|
||||
"""
|
||||
count_result = await pg_client.fetch_one(count_query, *params[:-2]) # Exclude limit/offset
|
||||
total = count_result["total"] if count_result else 0
|
||||
|
||||
# Format results with lightweight fields including unread count
|
||||
conversation_list = [
|
||||
{
|
||||
"id": str(conv["id"]),
|
||||
"title": conv["title"],
|
||||
"agent_id": str(conv["agent_id"]) if conv["agent_id"] else None,
|
||||
"agent_name": conv["agent_name"] or "AI Assistant",
|
||||
"created_at": conv["created_at"].isoformat(),
|
||||
"updated_at": conv["updated_at"].isoformat(),
|
||||
"last_message_at": conv["updated_at"].isoformat(), # Use updated_at as last activity
|
||||
"message_count": conv["total_messages"] or 0,
|
||||
"token_count": conv["total_tokens"] or 0,
|
||||
"is_archived": conv["is_archived"],
|
||||
"unread_count": conv.get("unread_count", 0) or 0 # Include unread count
|
||||
# Removed preview field for performance
|
||||
}
|
||||
for conv in conversations
|
||||
]
|
||||
|
||||
return {
|
||||
"conversations": conversation_list,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list conversations: {e}")
|
||||
raise
|
||||
|
||||
async def get_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get a specific conversation with details"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
c.id, c.title, c.agent_id, c.created_at, c.updated_at,
|
||||
c.total_messages, c.total_tokens, c.is_archived, c.summary,
|
||||
a.name as agent_name,
|
||||
cs.model_id, cs.temperature, cs.max_tokens, cs.system_prompt
|
||||
FROM conversations c
|
||||
LEFT JOIN agents a ON c.agent_id = a.id
|
||||
LEFT JOIN conversation_settings cs ON c.id = cs.conversation_id
|
||||
WHERE c.id = $1
|
||||
AND c.user_id = $2::uuid
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
conversation = await pg_client.fetch_one(query, conversation_id, user_uuid)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": conversation["id"],
|
||||
"title": conversation["title"],
|
||||
"agent_id": conversation["agent_id"],
|
||||
"agent_name": conversation["agent_name"],
|
||||
"model_id": conversation["model_id"],
|
||||
"temperature": float(conversation["temperature"]) if conversation["temperature"] else 0.7,
|
||||
"max_tokens": conversation["max_tokens"],
|
||||
"system_prompt": conversation["system_prompt"],
|
||||
"summary": conversation["summary"],
|
||||
"message_count": conversation["total_messages"],
|
||||
"token_count": conversation["total_tokens"],
|
||||
"is_archived": conversation["is_archived"],
|
||||
"created_at": conversation["created_at"].isoformat(),
|
||||
"updated_at": conversation["updated_at"].isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation {conversation_id}: {e}")
|
||||
return None
|
||||
|
||||
async def add_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
user_identifier: str,
|
||||
model_used: Optional[str] = None,
|
||||
token_count: int = 0,
|
||||
metadata: Optional[Dict] = None,
|
||||
attachments: Optional[List] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a message to a conversation"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Insert message (optimized: use cached UUID)
|
||||
query = """
|
||||
INSERT INTO messages (
|
||||
id, conversation_id, user_id, role, content,
|
||||
content_type, token_count, model_used, metadata, attachments, created_at
|
||||
) VALUES (
|
||||
$1, $2, $3::uuid,
|
||||
$4, $5, 'text', $6, $7, $8, $9, NOW()
|
||||
)
|
||||
RETURNING id, created_at
|
||||
"""
|
||||
|
||||
message_data = await pg_client.fetch_one(
|
||||
query,
|
||||
message_id, conversation_id, user_uuid,
|
||||
role, content, token_count, model_used,
|
||||
json.dumps(metadata or {}), json.dumps(attachments or [])
|
||||
)
|
||||
|
||||
if not message_data:
|
||||
raise RuntimeError("Failed to add message - no data returned")
|
||||
|
||||
# Update conversation totals (optimized: use cached UUID)
|
||||
update_query = """
|
||||
UPDATE conversations
|
||||
SET total_messages = total_messages + 1,
|
||||
total_tokens = total_tokens + $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND user_id = $2::uuid
|
||||
"""
|
||||
|
||||
await pg_client.execute_command(update_query, conversation_id, user_uuid, token_count)
|
||||
|
||||
result = {
|
||||
"id": message_data["id"],
|
||||
"conversation_id": conversation_id,
|
||||
"role": role,
|
||||
"content": content,
|
||||
"token_count": token_count,
|
||||
"model_used": model_used,
|
||||
"metadata": metadata or {},
|
||||
"attachments": attachments or [],
|
||||
"created_at": message_data["created_at"].isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Added message {message_id} to conversation {conversation_id}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add message to conversation {conversation_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get messages for a conversation"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
m.id, m.role, m.content, m.content_type, m.token_count,
|
||||
m.model_used, m.finish_reason, m.metadata, m.attachments, m.created_at
|
||||
FROM messages m
|
||||
JOIN conversations c ON m.conversation_id = c.id
|
||||
WHERE c.id = $1
|
||||
AND c.user_id = $2::uuid
|
||||
ORDER BY m.created_at ASC
|
||||
LIMIT $3 OFFSET $4
|
||||
"""
|
||||
|
||||
messages = await pg_client.execute_query(query, conversation_id, user_uuid, limit, offset)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": msg["id"],
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
"content_type": msg["content_type"],
|
||||
"token_count": msg["token_count"],
|
||||
"model_used": msg["model_used"],
|
||||
"finish_reason": msg["finish_reason"],
|
||||
"metadata": (
|
||||
json.loads(msg["metadata"]) if isinstance(msg["metadata"], str)
|
||||
else (msg["metadata"] if isinstance(msg["metadata"], dict) else {})
|
||||
),
|
||||
"attachments": (
|
||||
json.loads(msg["attachments"]) if isinstance(msg["attachments"], str)
|
||||
else (msg["attachments"] if isinstance(msg["attachments"], list) else [])
|
||||
),
|
||||
"context_sources": (
|
||||
(json.loads(msg["metadata"]) if isinstance(msg["metadata"], str) else msg["metadata"]).get("context_sources", [])
|
||||
if (isinstance(msg["metadata"], str) or isinstance(msg["metadata"], dict))
|
||||
else []
|
||||
),
|
||||
"created_at": msg["created_at"].isoformat()
|
||||
}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages for conversation {conversation_id}: {e}")
|
||||
return []
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
content: str,
|
||||
user_identifier: Optional[str] = None,
|
||||
stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a message to conversation and get AI response"""
|
||||
user_id = user_identifier or self.user_id
|
||||
|
||||
# Check if this is the first message
|
||||
existing_messages = await self.get_messages(conversation_id, user_id)
|
||||
is_first_message = len(existing_messages) == 0
|
||||
|
||||
# Add user message
|
||||
user_message = await self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=content,
|
||||
user_identifier=user_identifier
|
||||
)
|
||||
|
||||
# Get conversation details for agent
|
||||
conversation = await self.get_conversation(conversation_id, user_identifier)
|
||||
agent_id = conversation.get("agent_id")
|
||||
|
||||
ai_message = None
|
||||
if agent_id:
|
||||
agent_data = await self.agent_service.get_agent(agent_id)
|
||||
|
||||
# Prepare messages for AI
|
||||
messages = [
|
||||
{"role": "system", "content": agent_data.get("prompt_template", "You are a helpful assistant.")},
|
||||
{"role": "user", "content": content}
|
||||
]
|
||||
|
||||
# Get AI response
|
||||
ai_response = await self.get_ai_response(
|
||||
model=agent_data.get("model", "llama-3.1-8b-instant"),
|
||||
messages=messages,
|
||||
tenant_id=self.tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
ai_content = ai_response["choices"][0]["message"]["content"]
|
||||
|
||||
# Add AI message
|
||||
ai_message = await self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="agent",
|
||||
content=ai_content,
|
||||
user_identifier=user_id,
|
||||
model_used=agent_data.get("model"),
|
||||
token_count=ai_response["usage"]["total_tokens"]
|
||||
)
|
||||
|
||||
return {
|
||||
"user_message": user_message,
|
||||
"ai_message": ai_message,
|
||||
"is_first_message": is_first_message,
|
||||
"conversation_id": conversation_id
|
||||
}
|
||||
|
||||
async def update_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
title: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Update conversation properties like title"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Build dynamic update query based on provided fields
|
||||
update_fields = []
|
||||
params = []
|
||||
param_count = 1
|
||||
|
||||
if title is not None:
|
||||
update_fields.append(f"title = ${param_count}")
|
||||
params.append(title)
|
||||
param_count += 1
|
||||
|
||||
if not update_fields:
|
||||
return True # Nothing to update
|
||||
|
||||
# Add updated_at timestamp
|
||||
update_fields.append(f"updated_at = NOW()")
|
||||
|
||||
query = f"""
|
||||
UPDATE conversations
|
||||
SET {', '.join(update_fields)}
|
||||
WHERE id = ${param_count}
|
||||
AND user_id = ${param_count + 1}::uuid
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
params.extend([conversation_id, user_uuid])
|
||||
|
||||
result = await pg_client.fetch_scalar(query, *params)
|
||||
|
||||
if result:
|
||||
logger.info(f"Updated conversation {conversation_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update conversation {conversation_id}: {e}")
|
||||
return False
|
||||
|
||||
async def auto_generate_conversation_title(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str
|
||||
) -> Optional[str]:
|
||||
"""Generate conversation title based on first user prompt and agent response pair"""
|
||||
try:
|
||||
# Get only the first few messages (first exchange)
|
||||
messages = await self.get_messages(conversation_id, user_identifier, limit=2)
|
||||
|
||||
if not messages or len(messages) < 2:
|
||||
return None # Need at least one user-agent exchange
|
||||
|
||||
# Only use first user message and first agent response for title
|
||||
first_exchange = messages[:2]
|
||||
|
||||
# Generate title using the summarization service
|
||||
from app.services.conversation_summarizer import generate_conversation_title
|
||||
new_title = await generate_conversation_title(first_exchange, self.tenant_domain, user_identifier)
|
||||
|
||||
# Update the conversation with the generated title
|
||||
success = await self.update_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=user_identifier,
|
||||
title=new_title
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Auto-generated title '{new_title}' for conversation {conversation_id} based on first exchange")
|
||||
return new_title
|
||||
else:
|
||||
logger.warning(f"Failed to update conversation {conversation_id} with generated title")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-generate title for conversation {conversation_id}: {e}")
|
||||
return None
|
||||
|
||||
async def delete_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str
|
||||
) -> bool:
|
||||
"""Soft delete a conversation (archive it)"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
query = """
|
||||
UPDATE conversations
|
||||
SET is_archived = true, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND user_id = $2::uuid
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
result = await pg_client.fetch_scalar(query, conversation_id, user_uuid)
|
||||
|
||||
if result:
|
||||
logger.info(f"Archived conversation {conversation_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to archive conversation {conversation_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_ai_response(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, str]],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
top_p: float = 1.0,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get AI response from Resource Cluster"""
|
||||
try:
|
||||
# Prepare request for Resource Cluster
|
||||
request_data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
request_data["tools"] = tools
|
||||
if tool_choice:
|
||||
request_data["tool_choice"] = tool_choice
|
||||
|
||||
# Call Resource Cluster AI inference endpoint
|
||||
response = await self.resource_client.call_inference_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint="chat/completions",
|
||||
data=request_data
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get AI response: {e}")
|
||||
raise
|
||||
|
||||
# Streaming removed for reliability - using non-streaming only
|
||||
|
||||
async def get_available_models(self, tenant_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get available models for tenant from Resource Cluster"""
|
||||
try:
|
||||
# Get models dynamically from Resource Cluster
|
||||
import aiohttp
|
||||
|
||||
resource_cluster_url = self.resource_client.base_url
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Get capability token for model access
|
||||
token = await self.resource_client._get_capability_token(
|
||||
tenant_id=tenant_id,
|
||||
user_id=self.user_id,
|
||||
resources=['model_registry']
|
||||
)
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json',
|
||||
'X-Tenant-ID': tenant_id,
|
||||
'X-User-ID': self.user_id
|
||||
}
|
||||
|
||||
async with session.get(
|
||||
f"{resource_cluster_url}/api/v1/models/",
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
response_data = await response.json()
|
||||
models_data = response_data.get("models", [])
|
||||
|
||||
# Transform Resource Cluster model format to frontend format
|
||||
available_models = []
|
||||
for model in models_data:
|
||||
# Only include available models
|
||||
if model.get("status", {}).get("deployment") == "available":
|
||||
available_models.append({
|
||||
"id": model.get("uuid"), # Database UUID for unique identification
|
||||
"model_id": model["id"], # model_id string for API calls
|
||||
"name": model["name"],
|
||||
"provider": model["provider"],
|
||||
"model_type": model["model_type"],
|
||||
"context_window": model.get("performance", {}).get("context_window", 4000),
|
||||
"max_tokens": model.get("performance", {}).get("max_tokens", 4000),
|
||||
"performance": model.get("performance", {}), # Include full performance for chat.py
|
||||
"capabilities": {"chat": True} # All LLM models support chat
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(available_models)} models from Resource Cluster")
|
||||
return available_models
|
||||
else:
|
||||
logger.error(f"Resource Cluster returned {response.status}: {await response.text()}")
|
||||
raise RuntimeError(f"Resource Cluster API error: {response.status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get models from Resource Cluster: {e}")
|
||||
raise
|
||||
|
||||
async def get_conversation_datasets(self, conversation_id: str, user_identifier: str) -> List[str]:
|
||||
"""Get dataset IDs attached to a conversation"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Ensure proper schema qualification
|
||||
schema_name = f"tenant_{self.tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
|
||||
query = f"""
|
||||
SELECT cd.dataset_id
|
||||
FROM {schema_name}.conversations c
|
||||
JOIN {schema_name}.conversation_datasets cd ON cd.conversation_id = c.id
|
||||
WHERE c.id = $1
|
||||
AND c.user_id = (SELECT id FROM {schema_name}.users WHERE email = $2 LIMIT 1)
|
||||
AND cd.is_active = true
|
||||
ORDER BY cd.attached_at ASC
|
||||
"""
|
||||
|
||||
rows = await pg_client.execute_query(query, conversation_id, user_identifier)
|
||||
dataset_ids = [str(row['dataset_id']) for row in rows]
|
||||
|
||||
logger.info(f"Found {len(dataset_ids)} datasets for conversation {conversation_id}")
|
||||
return dataset_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation datasets: {e}")
|
||||
return []
|
||||
|
||||
async def add_datasets_to_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
dataset_ids: List[str],
|
||||
source: str = "user_selected"
|
||||
) -> bool:
|
||||
"""Add datasets to a conversation"""
|
||||
try:
|
||||
if not dataset_ids:
|
||||
return True
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Ensure proper schema qualification
|
||||
schema_name = f"tenant_{self.tenant_domain.replace('.', '_').replace('-', '_')}"
|
||||
|
||||
# Get user ID first
|
||||
user_query = f"SELECT id FROM {schema_name}.users WHERE email = $1 LIMIT 1"
|
||||
user_result = await pg_client.fetch_scalar(user_query, user_identifier)
|
||||
|
||||
if not user_result:
|
||||
logger.error(f"User not found: {user_identifier}")
|
||||
return False
|
||||
|
||||
user_id = user_result
|
||||
|
||||
# Insert dataset attachments (ON CONFLICT DO NOTHING to avoid duplicates)
|
||||
values_list = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
for dataset_id in dataset_ids:
|
||||
values_list.append(f"(${param_idx}, ${param_idx + 1}, ${param_idx + 2})")
|
||||
params.extend([conversation_id, dataset_id, user_id])
|
||||
param_idx += 3
|
||||
|
||||
query = f"""
|
||||
INSERT INTO {schema_name}.conversation_datasets (conversation_id, dataset_id, attached_by)
|
||||
VALUES {', '.join(values_list)}
|
||||
ON CONFLICT (conversation_id, dataset_id) DO UPDATE SET
|
||||
is_active = true,
|
||||
attached_at = NOW()
|
||||
"""
|
||||
|
||||
await pg_client.execute_query(query, *params)
|
||||
|
||||
logger.info(f"Added {len(dataset_ids)} datasets to conversation {conversation_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add datasets to conversation: {e}")
|
||||
return False
|
||||
|
||||
async def copy_agent_datasets_to_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str,
|
||||
agent_id: str
|
||||
) -> bool:
|
||||
"""Copy an agent's default datasets to a new conversation"""
|
||||
try:
|
||||
# Get agent's selected dataset IDs from config
|
||||
from app.services.agent_service import AgentService
|
||||
agent_service = AgentService(self.tenant_domain, user_identifier)
|
||||
agent_data = await agent_service.get_agent(agent_id)
|
||||
|
||||
if not agent_data:
|
||||
logger.warning(f"Agent {agent_id} not found")
|
||||
return False
|
||||
|
||||
# Get selected_dataset_ids from agent config
|
||||
selected_dataset_ids = agent_data.get('selected_dataset_ids', [])
|
||||
|
||||
if not selected_dataset_ids:
|
||||
logger.info(f"Agent {agent_id} has no default datasets")
|
||||
return True
|
||||
|
||||
# Add agent's datasets to conversation
|
||||
success = await self.add_datasets_to_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=user_identifier,
|
||||
dataset_ids=selected_dataset_ids,
|
||||
source="agent_default"
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Copied {len(selected_dataset_ids)} datasets from agent {agent_id} to conversation {conversation_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to copy agent datasets: {e}")
|
||||
return False
|
||||
|
||||
async def get_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""Get recent conversations ordered by last activity"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Handle both email and UUID formats using existing pattern
|
||||
user_clause = self._get_user_clause(1, user_id)
|
||||
|
||||
query = f"""
|
||||
SELECT c.id, c.title, c.created_at, c.updated_at,
|
||||
COUNT(m.id) as message_count,
|
||||
MAX(m.created_at) as last_message_at,
|
||||
a.name as agent_name
|
||||
FROM conversations c
|
||||
LEFT JOIN messages m ON m.conversation_id = c.id
|
||||
LEFT JOIN agents a ON a.id = c.agent_id
|
||||
WHERE c.user_id = {user_clause}
|
||||
AND c.is_archived = false
|
||||
GROUP BY c.id, c.title, c.created_at, c.updated_at, a.name
|
||||
ORDER BY COALESCE(MAX(m.created_at), c.created_at) DESC
|
||||
LIMIT $2
|
||||
"""
|
||||
|
||||
rows = await pg_client.execute_query(query, user_id, limit)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get recent conversations: {e}")
|
||||
return []
|
||||
|
||||
async def mark_conversation_read(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_identifier: str
|
||||
) -> bool:
|
||||
"""
|
||||
Mark a conversation as read by updating last_read_at in metadata.
|
||||
|
||||
Args:
|
||||
conversation_id: UUID of the conversation
|
||||
user_identifier: User email or UUID
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Resolve user UUID with caching (performance optimization)
|
||||
user_uuid = await self._get_resolved_user_uuid(user_identifier)
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Update last_read_at in conversation metadata
|
||||
query = """
|
||||
UPDATE conversations
|
||||
SET metadata = jsonb_set(
|
||||
COALESCE(metadata, '{}'::jsonb),
|
||||
'{last_read_at}',
|
||||
to_jsonb(NOW()::text)
|
||||
)
|
||||
WHERE id = $1
|
||||
AND user_id = $2::uuid
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
result = await pg_client.fetch_one(query, conversation_id, user_uuid)
|
||||
|
||||
if result:
|
||||
logger.info(f"Marked conversation {conversation_id} as read for user {user_identifier}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Conversation {conversation_id} not found or access denied for user {user_identifier}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark conversation as read: {e}")
|
||||
return False
|
||||
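# Illustrative usage sketch (not part of the original source): marking a conversation read
# resets the unread_count reported by list_conversations, because that count only includes
# messages newer than metadata->>'last_read_at'.
async def _example_mark_read(service: ConversationService, conversation_id: str) -> int:
    await service.mark_conversation_read(conversation_id, service.user_id)
    page = await service.list_conversations(service.user_id, limit=5)
    match = next((c for c in page["conversations"] if c["id"] == conversation_id), None)
    return match["unread_count"] if match else 0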
200
apps/tenant-backend/app/services/conversation_summarizer.py
Normal file
200
apps/tenant-backend/app/services/conversation_summarizer.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Conversation Summarization Service for GT 2.0
|
||||
|
||||
Automatically generates meaningful conversation titles using a specialized
|
||||
summarization agent with llama-3.1-8b-instant.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.resource_client import ResourceClusterClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class ConversationSummarizer:
|
||||
"""Service for generating conversation summaries and titles"""
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.resource_client = ResourceClusterClient()
|
||||
self.summarization_model = "llama-3.1-8b-instant"
|
||||
|
||||
async def generate_conversation_title(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Generate a concise conversation title based on the conversation content.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries from the conversation
|
||||
|
||||
Returns:
|
||||
Generated conversation title (3-6 words)
|
||||
"""
|
||||
try:
|
||||
# Extract conversation context for summarization
|
||||
conversation_text = self._prepare_conversation_context(messages)
|
||||
|
||||
if not conversation_text.strip():
|
||||
return "New Conversation"
|
||||
|
||||
# Generate title using specialized summarization prompt
|
||||
title = await self._call_summarization_agent(conversation_text)
|
||||
|
||||
# Validate and clean the generated title
|
||||
clean_title = self._clean_title(title)
|
||||
|
||||
logger.info(f"Generated conversation title: '{clean_title}' from {len(messages)} messages")
|
||||
return clean_title
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating conversation title: {e}")
|
||||
return self._fallback_title(messages)
|
||||
|
||||
def _prepare_conversation_context(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""Prepare conversation context for summarization"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Limit to first few exchanges for title generation
|
||||
context_messages = messages[:6] # First 3 user-agent exchanges
|
||||
|
||||
context_parts = []
|
||||
for msg in context_messages:
|
||||
role = "User" if msg.get("role") == "user" else "Agent"
|
||||
# Truncate very long messages for context
|
||||
content = msg.get("content", "")
|
||||
content = content[:200] + "..." if len(content) > 200 else content
|
||||
context_parts.append(f"{role}: {content}")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
async def _call_summarization_agent(self, conversation_text: str) -> str:
|
||||
"""Call the resource cluster AI inference for summarization"""
|
||||
|
||||
summarization_prompt = f"""You are a conversation title generator. Your job is to create concise, descriptive titles for conversations.
|
||||
|
||||
Given this conversation:
|
||||
---
|
||||
{conversation_text}
|
||||
---
|
||||
|
||||
Generate a title that:
|
||||
- Is 3-6 words maximum
|
||||
- Captures the main topic or purpose
|
||||
- Is clear and descriptive
|
||||
- Uses title case
|
||||
- Does NOT include quotes or special characters
|
||||
|
||||
Examples of good titles:
|
||||
- "Python Code Review"
|
||||
- "Database Migration Help"
|
||||
- "React Component Design"
|
||||
- "System Architecture Discussion"
|
||||
|
||||
Title:"""
|
||||
|
||||
request_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": summarization_prompt
|
||||
}
|
||||
],
|
||||
"model": self.summarization_model,
|
||||
"temperature": 0.3, # Lower temperature for consistent, focused titles
|
||||
"max_tokens": 20, # Short response for title generation
|
||||
"stream": False
|
||||
}
|
||||
|
||||
try:
|
||||
# Use the resource client instead of direct HTTP calls
|
||||
result = await self.resource_client.call_inference_endpoint(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
endpoint="chat/completions",
|
||||
data=request_data
|
||||
)
|
||||
|
||||
if result and "choices" in result and len(result["choices"]) > 0:
|
||||
title = result["choices"][0]["message"]["content"].strip()
|
||||
return title
|
||||
else:
|
||||
logger.error("Invalid response format from summarization API")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling summarization agent: {e}")
|
||||
return ""
|
||||
|
||||
def _clean_title(self, raw_title: str) -> str:
|
||||
"""Clean and validate the generated title"""
|
||||
if not raw_title:
|
||||
return "New Conversation"
|
||||
|
||||
# Remove quotes, extra whitespace, and special characters
|
||||
cleaned = raw_title.strip().strip('"\'').strip()
|
||||
|
||||
# Remove common prefixes that AI might add
|
||||
prefixes_to_remove = [
|
||||
"Title:", "title:", "TITLE:",
|
||||
"Conversation:", "conversation:",
|
||||
"Topic:", "topic:",
|
||||
"Subject:", "subject:"
|
||||
]
|
||||
|
||||
for prefix in prefixes_to_remove:
|
||||
if cleaned.startswith(prefix):
|
||||
cleaned = cleaned[len(prefix):].strip()
|
||||
|
||||
# Limit length and ensure it's reasonable
|
||||
if len(cleaned) > 50:
|
||||
cleaned = cleaned[:47] + "..."
|
||||
|
||||
# Ensure it's not empty after cleaning
|
||||
if not cleaned or len(cleaned.split()) > 8:
|
||||
return "New Conversation"
|
||||
|
||||
return cleaned
|
||||
|
||||
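    # Illustrative examples (not part of the original source) of the cleaning rules above:
    #     '"Python Code Review"'               -> 'Python Code Review'
    #     'Title: Database Migration Help'     -> 'Database Migration Help'
    #     'how do i fix my app it is so slow'  -> 'New Conversation'  (more than 8 words)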
def _fallback_title(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""Generate fallback title when AI summarization fails"""
|
||||
if not messages:
|
||||
return "New Conversation"
|
||||
|
||||
# Try to use the first user message for context
|
||||
first_user_msg = next((msg for msg in messages if msg.get("role") == "user"), None)
|
||||
|
||||
if first_user_msg and first_user_msg.get("content"):
|
||||
# Extract first few words from the user's message
|
||||
words = first_user_msg["content"].strip().split()[:4]
|
||||
if len(words) >= 2:
|
||||
fallback = " ".join(words).capitalize()
|
||||
# Remove common question words for cleaner titles
|
||||
for word in ["How", "What", "Can", "Could", "Please", "Help"]:
|
||||
if fallback.startswith(word + " "):
|
||||
fallback = fallback[len(word):].strip()
|
||||
break
|
||||
return fallback if fallback else "New Conversation"
|
||||
|
||||
return "New Conversation"
|
||||
|
||||
|
||||
async def generate_conversation_title(messages: List[Dict[str, Any]], tenant_id: str, user_id: str) -> str:
    """
    Convenience function to generate a conversation title.

    Args:
        messages: List of message dictionaries from the conversation
        tenant_id: Tenant identifier
        user_id: User identifier

    Returns:
        Generated conversation title
    """
    summarizer = ConversationSummarizer(tenant_id, user_id)
    return await summarizer.generate_conversation_title(messages)
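# Illustrative usage sketch (not part of the original source). Message dicts only need
# "role" and "content" keys for title generation; the tenant and user ids are placeholders.
async def _example_title_generation() -> str:
    messages = [
        {"role": "user", "content": "Can you review my Python caching decorator?"},
        {"role": "agent", "content": "Sure, it works, but functools.lru_cache would simplify it."},
    ]
    return await generate_conversation_title(messages, "acme.example.com", "alice@acme.example.com")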
1064
apps/tenant-backend/app/services/dataset_service.py
Normal file
1064
apps/tenant-backend/app/services/dataset_service.py
Normal file
File diff suppressed because it is too large
585
apps/tenant-backend/app/services/dataset_sharing.py
Normal file
585
apps/tenant-backend/app/services/dataset_sharing.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
Dataset Sharing Service for GT 2.0
|
||||
|
||||
Implements hierarchical dataset sharing with perfect tenant isolation.
|
||||
Enables secure data collaboration while maintaining ownership and access control.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from uuid import uuid4
|
||||
|
||||
from app.models.access_group import AccessGroup, Resource
|
||||
from app.services.access_controller import AccessController
|
||||
from app.core.security import verify_capability_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharingPermission(Enum):
|
||||
"""Sharing permission levels"""
|
||||
READ = "read" # Can view and search dataset
|
||||
WRITE = "write" # Can add documents
|
||||
ADMIN = "admin" # Can modify sharing settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetShare:
|
||||
"""Dataset sharing configuration"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
dataset_id: str = ""
|
||||
owner_id: str = ""
|
||||
access_group: AccessGroup = AccessGroup.INDIVIDUAL
|
||||
team_members: List[str] = field(default_factory=list)
|
||||
team_permissions: Dict[str, SharingPermission] = field(default_factory=dict)
|
||||
shared_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = None
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"dataset_id": self.dataset_id,
|
||||
"owner_id": self.owner_id,
|
||||
"access_group": self.access_group.value,
|
||||
"team_members": self.team_members,
|
||||
"team_permissions": {k: v.value for k, v in self.team_permissions.items()},
|
||||
"shared_at": self.shared_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"is_active": self.is_active
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "DatasetShare":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
id=data.get("id", str(uuid4())),
|
||||
dataset_id=data["dataset_id"],
|
||||
owner_id=data["owner_id"],
|
||||
access_group=AccessGroup(data["access_group"]),
|
||||
team_members=data.get("team_members", []),
|
||||
team_permissions={
|
||||
k: SharingPermission(v) for k, v in data.get("team_permissions", {}).items()
|
||||
},
|
||||
shared_at=datetime.fromisoformat(data["shared_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
|
||||
is_active=data.get("is_active", True)
|
||||
)
|
||||
|
||||
|
||||
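# Illustrative round-trip sketch (not part of the original source): shares serialize to plain
# dicts for file-based storage and reconstruct losslessly via from_dict. All ids are placeholders.
def _example_share_round_trip() -> bool:
    share = DatasetShare(
        dataset_id="ds-001",
        owner_id="alice@acme.example.com",
        access_group=AccessGroup.TEAM,
        team_members=["bob@acme.example.com"],
        team_permissions={"bob@acme.example.com": SharingPermission.READ},
    )
    restored = DatasetShare.from_dict(share.to_dict())
    return restored.team_permissions["bob@acme.example.com"] is SharingPermission.READ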
@dataclass
|
||||
class DatasetInfo:
|
||||
"""Dataset information for sharing"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
owner_id: str
|
||||
document_count: int
|
||||
size_bytes: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class DatasetSharingService:
|
||||
"""
|
||||
Service for hierarchical dataset sharing with capability-based access control.
|
||||
|
||||
Features:
|
||||
- Individual, Team, and Organization level sharing
|
||||
- Granular permission management (read, write, admin)
|
||||
- Time-based expiration of shares
|
||||
- Perfect tenant isolation through file-based storage
|
||||
- Event emission for sharing activities
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, access_controller: AccessController):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.access_controller = access_controller
|
||||
self.base_path = Path(f"/data/{tenant_domain}/dataset_sharing")
|
||||
self.shares_path = self.base_path / "shares"
|
||||
self.datasets_path = self.base_path / "datasets"
|
||||
|
||||
# Ensure directories exist with proper permissions
|
||||
self._ensure_directories()
|
||||
|
||||
logger.info(f"DatasetSharingService initialized for {tenant_domain}")
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure sharing directories exist with proper permissions"""
|
||||
for path in [self.shares_path, self.datasets_path]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
# Set permissions to 700 (owner only)
|
||||
os.chmod(path, stat.S_IRWXU)
|
||||
|
||||
async def share_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
owner_id: str,
|
||||
access_group: AccessGroup,
|
||||
team_members: Optional[List[str]] = None,
|
||||
team_permissions: Optional[Dict[str, SharingPermission]] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
capability_token: str = ""
|
||||
) -> DatasetShare:
|
||||
"""
|
||||
Share a dataset with specified access group.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset to share
|
||||
owner_id: Owner of the dataset
|
||||
access_group: Level of sharing (Individual, Team, Organization)
|
||||
team_members: List of team members (if Team access)
|
||||
team_permissions: Permissions for each team member
|
||||
expires_at: Optional expiration time
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
DatasetShare configuration
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Verify ownership
|
||||
dataset_resource = await self._load_dataset_resource(dataset_id)
|
||||
if not dataset_resource or dataset_resource.owner_id != owner_id:
|
||||
raise PermissionError("Only dataset owner can modify sharing")
|
||||
|
||||
# Validate team members for team sharing
|
||||
if access_group == AccessGroup.TEAM:
|
||||
if not team_members:
|
||||
raise ValueError("Team members required for team sharing")
|
||||
|
||||
# Ensure all team members are valid users in tenant
|
||||
for member in team_members:
|
||||
if not await self._is_valid_tenant_user(member):
|
||||
logger.warning(f"Invalid team member: {member}")
|
||||
|
||||
# Create sharing configuration
|
||||
share = DatasetShare(
|
||||
dataset_id=dataset_id,
|
||||
owner_id=owner_id,
|
||||
access_group=access_group,
|
||||
team_members=team_members or [],
|
||||
team_permissions=team_permissions or {},
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
# Set default permissions for team members
|
||||
if access_group == AccessGroup.TEAM:
|
||||
for member in share.team_members:
|
||||
if member not in share.team_permissions:
|
||||
share.team_permissions[member] = SharingPermission.READ
|
||||
|
||||
# Store sharing configuration
|
||||
await self._store_share(share)
|
||||
|
||||
# Update dataset resource access group
|
||||
await self.access_controller.update_resource_access(
|
||||
owner_id, dataset_id, access_group, team_members
|
||||
)
|
||||
|
||||
# Emit sharing event
|
||||
if hasattr(self.access_controller, 'event_bus'):
|
||||
await self.access_controller.event_bus.emit_event(
|
||||
"dataset.shared",
|
||||
owner_id,
|
||||
{
|
||||
"dataset_id": dataset_id,
|
||||
"access_group": access_group.value,
|
||||
"team_members": team_members or [],
|
||||
"expires_at": expires_at.isoformat() if expires_at else None
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Dataset {dataset_id} shared as {access_group.value} by {owner_id}")
|
||||
return share
|
||||
|
||||
async def get_dataset_sharing(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
capability_token: str
|
||||
) -> Optional[DatasetShare]:
|
||||
"""
|
||||
Get sharing configuration for a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
user_id: Requesting user
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
DatasetShare if user has access, None otherwise
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load sharing configuration
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share:
|
||||
return None
|
||||
|
||||
# Check if user has access to view sharing info
|
||||
if share.owner_id == user_id:
|
||||
return share # Owner can always see
|
||||
|
||||
if share.access_group == AccessGroup.TEAM and user_id in share.team_members:
|
||||
return share # Team member can see
|
||||
|
||||
if share.access_group == AccessGroup.ORGANIZATION:
|
||||
# All tenant users can see organization shares
|
||||
if await self._is_valid_tenant_user(user_id):
|
||||
return share
|
||||
|
||||
return None
|
||||
|
||||
async def check_dataset_access(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
permission: SharingPermission = SharingPermission.READ
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if user has specified permission on dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset to check
|
||||
user_id: User requesting access
|
||||
permission: Required permission level
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, reason)
|
||||
"""
|
||||
# Load sharing configuration
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share or not share.is_active:
|
||||
return False, "Dataset not shared or sharing inactive"
|
||||
|
||||
# Check expiration
|
||||
if share.expires_at and datetime.utcnow() > share.expires_at:
|
||||
return False, "Dataset sharing has expired"
|
||||
|
||||
# Owner has all permissions
|
||||
if share.owner_id == user_id:
|
||||
return True, "Owner access"
|
||||
|
||||
# Check access group permissions
|
||||
if share.access_group == AccessGroup.INDIVIDUAL:
|
||||
return False, "Private dataset"
|
||||
|
||||
elif share.access_group == AccessGroup.TEAM:
|
||||
if user_id not in share.team_members:
|
||||
return False, "Not a team member"
|
||||
|
||||
# Check specific permission
|
||||
user_permission = share.team_permissions.get(user_id, SharingPermission.READ)
|
||||
if self._has_permission(user_permission, permission):
|
||||
return True, f"Team member with {user_permission.value} permission"
|
||||
else:
|
||||
return False, f"Insufficient permission: has {user_permission.value}, needs {permission.value}"
|
||||
|
||||
elif share.access_group == AccessGroup.ORGANIZATION:
|
||||
# Organization sharing is typically read-only
|
||||
if permission == SharingPermission.READ:
|
||||
if await self._is_valid_tenant_user(user_id):
|
||||
return True, "Organization-wide read access"
|
||||
return False, "Organization access is read-only"
|
||||
|
||||
return False, "Unknown access configuration"
|
||||
|
||||
async def list_accessible_datasets(
|
||||
self,
|
||||
user_id: str,
|
||||
capability_token: str,
|
||||
include_owned: bool = True,
|
||||
include_shared: bool = True
|
||||
) -> List[DatasetInfo]:
|
||||
"""
|
||||
List datasets accessible to user.
|
||||
|
||||
Args:
|
||||
user_id: User requesting list
|
||||
capability_token: JWT capability token
|
||||
include_owned: Include user's own datasets
|
||||
include_shared: Include datasets shared with user
|
||||
|
||||
Returns:
|
||||
List of accessible datasets
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
accessible_datasets = []
|
||||
|
||||
# Get all dataset shares
|
||||
all_shares = await self._list_all_shares()
|
||||
|
||||
for share in all_shares:
|
||||
# Skip inactive or expired shares
|
||||
if not share.is_active:
|
||||
continue
|
||||
if share.expires_at and datetime.utcnow() > share.expires_at:
|
||||
continue
|
||||
|
||||
# Check if user has access
|
||||
has_access = False
|
||||
|
||||
if include_owned and share.owner_id == user_id:
|
||||
has_access = True
|
||||
elif include_shared:
|
||||
allowed, _ = await self.check_dataset_access(share.dataset_id, user_id)
|
||||
has_access = allowed
|
||||
|
||||
if has_access:
|
||||
dataset_info = await self._load_dataset_info(share.dataset_id)
|
||||
if dataset_info:
|
||||
accessible_datasets.append(dataset_info)
|
||||
|
||||
return accessible_datasets
|
||||
|
||||
async def revoke_dataset_sharing(
|
||||
self,
|
||||
dataset_id: str,
|
||||
owner_id: str,
|
||||
capability_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke dataset sharing (make it private).
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset to make private
|
||||
owner_id: Owner of the dataset
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
True if revoked successfully
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Verify ownership
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share or share.owner_id != owner_id:
|
||||
raise PermissionError("Only dataset owner can revoke sharing")
|
||||
|
||||
# Update sharing to individual (private)
|
||||
share.access_group = AccessGroup.INDIVIDUAL
|
||||
share.team_members = []
|
||||
share.team_permissions = {}
|
||||
share.is_active = False
|
||||
|
||||
# Store updated share
|
||||
await self._store_share(share)
|
||||
|
||||
# Update resource access
|
||||
await self.access_controller.update_resource_access(
|
||||
owner_id, dataset_id, AccessGroup.INDIVIDUAL, []
|
||||
)
|
||||
|
||||
# Emit revocation event
|
||||
if hasattr(self.access_controller, 'event_bus'):
|
||||
await self.access_controller.event_bus.emit_event(
|
||||
"dataset.sharing_revoked",
|
||||
owner_id,
|
||||
{"dataset_id": dataset_id}
|
||||
)
|
||||
|
||||
logger.info(f"Dataset {dataset_id} sharing revoked by {owner_id}")
|
||||
return True
|
||||
|
||||
async def update_team_permissions(
|
||||
self,
|
||||
dataset_id: str,
|
||||
owner_id: str,
|
||||
user_id: str,
|
||||
permission: SharingPermission,
|
||||
capability_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Update team member permissions for a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
owner_id: Owner of the dataset
|
||||
user_id: Team member to update
|
||||
permission: New permission level
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load and verify sharing
|
||||
share = await self._load_share(dataset_id)
|
||||
if not share or share.owner_id != owner_id:
|
||||
raise PermissionError("Only dataset owner can update permissions")
|
||||
|
||||
if share.access_group != AccessGroup.TEAM:
|
||||
raise ValueError("Can only update permissions for team-shared datasets")
|
||||
|
||||
if user_id not in share.team_members:
|
||||
raise ValueError("User is not a team member")
|
||||
|
||||
# Update permission
|
||||
share.team_permissions[user_id] = permission
|
||||
|
||||
# Store updated share
|
||||
await self._store_share(share)
|
||||
|
||||
logger.info(f"Updated {user_id} permission to {permission.value} for dataset {dataset_id}")
|
||||
return True
|
||||
|
||||
async def get_sharing_statistics(
|
||||
self,
|
||||
user_id: str,
|
||||
capability_token: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get sharing statistics for user.
|
||||
|
||||
Args:
|
||||
user_id: User to get stats for
|
||||
capability_token: JWT capability token
|
||||
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
stats = {
|
||||
"owned_datasets": 0,
|
||||
"shared_with_me": 0,
|
||||
"sharing_breakdown": {
|
||||
AccessGroup.INDIVIDUAL: 0,
|
||||
AccessGroup.TEAM: 0,
|
||||
AccessGroup.ORGANIZATION: 0
|
||||
},
|
||||
"total_team_members": 0,
|
||||
"expired_shares": 0
|
||||
}
|
||||
|
||||
all_shares = await self._list_all_shares()
|
||||
|
||||
for share in all_shares:
|
||||
# Count owned datasets
|
||||
if share.owner_id == user_id:
|
||||
stats["owned_datasets"] += 1
|
||||
stats["sharing_breakdown"][share.access_group] += 1
|
||||
stats["total_team_members"] += len(share.team_members)
|
||||
|
||||
# Count expired shares
|
||||
if share.expires_at and datetime.utcnow() > share.expires_at:
|
||||
stats["expired_shares"] += 1
|
||||
|
||||
# Count datasets shared with user
|
||||
elif user_id in share.team_members or share.access_group == AccessGroup.ORGANIZATION:
|
||||
if share.is_active and (not share.expires_at or datetime.utcnow() <= share.expires_at):
|
||||
stats["shared_with_me"] += 1
|
||||
|
||||
return stats
|
||||
|
||||
    def _has_permission(self, user_permission: SharingPermission, required: SharingPermission) -> bool:
        """Check if user permission satisfies required permission"""
        permission_hierarchy = {
            SharingPermission.READ: 1,
            SharingPermission.WRITE: 2,
            SharingPermission.ADMIN: 3
        }

        return permission_hierarchy[user_permission] >= permission_hierarchy[required]

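The hierarchy is strictly ordered, so a higher-level grant satisfies any lower requirement. A minimal illustration that mirrors the comparison in the method body (assuming the SharingPermission enum shown earlier in this service):

    # Illustrative only: WRITE satisfies READ, but READ does not satisfy ADMIN
    hierarchy = {SharingPermission.READ: 1, SharingPermission.WRITE: 2, SharingPermission.ADMIN: 3}
    assert hierarchy[SharingPermission.WRITE] >= hierarchy[SharingPermission.READ]
    assert not (hierarchy[SharingPermission.READ] >= hierarchy[SharingPermission.ADMIN])
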
    async def _store_share(self, share: DatasetShare):
        """Store sharing configuration to file system"""
        share_file = self.shares_path / f"{share.dataset_id}.json"

        with open(share_file, "w") as f:
            json.dump(share.to_dict(), f, indent=2)

        # Set secure permissions
        os.chmod(share_file, stat.S_IRUSR | stat.S_IWUSR)  # 600

    async def _load_share(self, dataset_id: str) -> Optional[DatasetShare]:
        """Load sharing configuration from file system"""
        share_file = self.shares_path / f"{dataset_id}.json"

        if not share_file.exists():
            return None

        try:
            with open(share_file, "r") as f:
                data = json.load(f)
                return DatasetShare.from_dict(data)
        except Exception as e:
            logger.error(f"Error loading share for dataset {dataset_id}: {e}")
            return None

    async def _list_all_shares(self) -> List[DatasetShare]:
        """List all sharing configurations"""
        shares = []

        if self.shares_path.exists():
            for share_file in self.shares_path.glob("*.json"):
                try:
                    with open(share_file, "r") as f:
                        data = json.load(f)
                        shares.append(DatasetShare.from_dict(data))
                except Exception as e:
                    logger.error(f"Error loading share file {share_file}: {e}")

        return shares

    async def _load_dataset_resource(self, dataset_id: str) -> Optional[Resource]:
        """Load dataset resource (implementation would query storage)"""
        # Placeholder - would integrate with actual resource storage
        return Resource(
            id=dataset_id,
            name=f"Dataset {dataset_id}",
            resource_type="dataset",
            owner_id="mock_owner",
            tenant_domain=self.tenant_domain,
            access_group=AccessGroup.INDIVIDUAL
        )

    async def _load_dataset_info(self, dataset_id: str) -> Optional[DatasetInfo]:
        """Load dataset information (implementation would query storage)"""
        # Placeholder - would integrate with actual dataset storage
        return DatasetInfo(
            id=dataset_id,
            name=f"Dataset {dataset_id}",
            description="Mock dataset for testing",
            owner_id="mock_owner",
            document_count=10,
            size_bytes=1024000,
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
            tags=["test", "mock"]
        )

    async def _is_valid_tenant_user(self, user_id: str) -> bool:
        """Check if user is valid in tenant (implementation would query user store)"""
        # Placeholder - would integrate with actual user management
        return "@" in user_id and user_id.endswith((".com", ".org", ".edu"))

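Taken together, a typical sharing flow looks roughly like the sketch below. This is illustrative only: it assumes a tenant pod where /data/{tenant} is writable, and the tenant domain, user IDs, dataset ID, and capability token are placeholders.

    # Hypothetical usage sketch of the service above (inside an async context)
    controller = AccessController("acme.example")
    sharing = DatasetSharingService("acme.example", controller)

    share = await sharing.share_dataset(
        dataset_id="ds-123",
        owner_id="owner@acme.example",
        access_group=AccessGroup.TEAM,
        team_members=["analyst@acme.example"],
        team_permissions={"analyst@acme.example": SharingPermission.WRITE},
        capability_token=token,  # JWT issued by the auth layer
    )

    allowed, reason = await sharing.check_dataset_access(
        "ds-123", "analyst@acme.example", SharingPermission.READ
    )
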
445
apps/tenant-backend/app/services/dataset_summarizer.py
Normal file
@@ -0,0 +1,445 @@
"""
Dataset Summarization Service for GT 2.0

Generates comprehensive summaries for datasets based on their constituent documents.
Provides analytics, topic clustering, and overview generation for RAG optimization.
"""

import logging
import asyncio
import httpx
import json
from typing import Dict, Any, Optional, List
from datetime import datetime
from collections import Counter

from app.core.database import get_db_session, execute_command, fetch_one, fetch_all

logger = logging.getLogger(__name__)


class DatasetSummarizer:
    """
    Service for generating dataset-level summaries and analytics.

    Features:
    - Aggregate document summaries into dataset overview
    - Topic clustering and theme analysis
    - Dataset statistics and metrics
    - Search optimization recommendations
    - RAG performance insights
    """

    def __init__(self):
        self.resource_cluster_url = "http://gentwo-resource-backend:8000"

async def generate_dataset_summary(
|
||||
self,
|
||||
dataset_id: str,
|
||||
tenant_domain: str,
|
||||
user_id: str,
|
||||
force_regenerate: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a comprehensive summary for a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID to summarize
|
||||
tenant_domain: Tenant domain for database context
|
||||
user_id: User requesting the summary
|
||||
force_regenerate: Force regeneration even if summary exists
|
||||
|
||||
Returns:
|
||||
Dictionary with dataset summary including overview, topics,
|
||||
statistics, and search optimization insights
|
||||
"""
|
||||
try:
|
||||
# Check if summary already exists and is recent
|
||||
if not force_regenerate:
|
||||
existing_summary = await self._get_existing_summary(dataset_id, tenant_domain)
|
||||
if existing_summary and self._is_summary_fresh(existing_summary):
|
||||
logger.info(f"Using cached dataset summary for {dataset_id}")
|
||||
return existing_summary
|
||||
|
||||
# Get dataset information and documents
|
||||
dataset_info = await self._get_dataset_info(dataset_id, tenant_domain)
|
||||
if not dataset_info:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
|
||||
documents = await self._get_dataset_documents(dataset_id, tenant_domain)
|
||||
document_summaries = await self._get_document_summaries(dataset_id, tenant_domain)
|
||||
|
||||
# Generate statistics
|
||||
stats = await self._calculate_dataset_statistics(dataset_id, tenant_domain)
|
||||
|
||||
# Analyze topics across all documents
|
||||
topics_analysis = await self._analyze_dataset_topics(document_summaries)
|
||||
|
||||
# Generate overall summary using LLM
|
||||
overview = await self._generate_dataset_overview(
|
||||
dataset_info, document_summaries, topics_analysis, stats
|
||||
)
|
||||
|
||||
# Create comprehensive summary
|
||||
summary_data = {
|
||||
"dataset_id": dataset_id,
|
||||
"overview": overview,
|
||||
"statistics": stats,
|
||||
"topics": topics_analysis,
|
||||
"recommendations": await self._generate_search_recommendations(stats, topics_analysis),
|
||||
"metadata": {
|
||||
"document_count": len(documents),
|
||||
"has_summaries": len(document_summaries),
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"generated_by": user_id
|
||||
}
|
||||
}
|
||||
|
||||
# Store summary in database
|
||||
await self._store_dataset_summary(dataset_id, summary_data, tenant_domain, user_id)
|
||||
|
||||
logger.info(f"Generated dataset summary for {dataset_id} with {len(documents)} documents")
|
||||
return summary_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate dataset summary for {dataset_id}: {e}")
|
||||
# Return basic fallback summary
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"overview": "Dataset summary generation failed",
|
||||
"statistics": {"error": str(e)},
|
||||
"topics": [],
|
||||
"recommendations": [],
|
||||
"metadata": {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
async def _get_dataset_info(self, dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get basic dataset information"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT id, dataset_name, description, chunking_strategy,
|
||||
chunk_size, chunk_overlap, created_at
|
||||
FROM datasets
|
||||
WHERE id = $1
|
||||
"""
|
||||
result = await fetch_one(session, query, dataset_id)
|
||||
return dict(result) if result else None
|
||||
|
||||
async def _get_dataset_documents(self, dataset_id: str, tenant_domain: str) -> List[Dict[str, Any]]:
|
||||
"""Get all documents in the dataset"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT id, filename, original_filename, file_type,
|
||||
file_size_bytes, chunk_count, created_at
|
||||
FROM documents
|
||||
WHERE dataset_id = $1 AND processing_status = 'completed'
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
results = await fetch_all(session, query, dataset_id)
|
||||
return [dict(row) for row in results]
|
||||
|
||||
async def _get_document_summaries(self, dataset_id: str, tenant_domain: str) -> List[Dict[str, Any]]:
|
||||
"""Get summaries for all documents in the dataset"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT ds.document_id, ds.quick_summary, ds.detailed_analysis,
|
||||
ds.topics, ds.metadata, ds.confidence,
|
||||
d.filename, d.original_filename
|
||||
FROM document_summaries ds
|
||||
JOIN documents d ON ds.document_id = d.id
|
||||
WHERE d.dataset_id = $1
|
||||
ORDER BY ds.created_at DESC
|
||||
"""
|
||||
results = await fetch_all(session, query, dataset_id)
|
||||
|
||||
summaries = []
|
||||
for row in results:
|
||||
summary = dict(row)
|
||||
# Parse JSON fields
|
||||
if summary["topics"]:
|
||||
summary["topics"] = json.loads(summary["topics"])
|
||||
if summary["metadata"]:
|
||||
summary["metadata"] = json.loads(summary["metadata"])
|
||||
summaries.append(summary)
|
||||
|
||||
return summaries
|
||||
|
||||
async def _calculate_dataset_statistics(self, dataset_id: str, tenant_domain: str) -> Dict[str, Any]:
|
||||
"""Calculate comprehensive dataset statistics"""
|
||||
async with get_db_session() as session:
|
||||
# Basic document statistics
|
||||
doc_stats_query = """
|
||||
SELECT
|
||||
COUNT(*) as total_documents,
|
||||
SUM(file_size_bytes) as total_size_bytes,
|
||||
SUM(chunk_count) as total_chunks,
|
||||
AVG(chunk_count) as avg_chunks_per_doc,
|
||||
COUNT(DISTINCT file_type) as unique_file_types
|
||||
FROM documents
|
||||
WHERE dataset_id = $1 AND processing_status = 'completed'
|
||||
"""
|
||||
doc_stats = await fetch_one(session, doc_stats_query, dataset_id)
|
||||
|
||||
# Chunk statistics
|
||||
chunk_stats_query = """
|
||||
SELECT
|
||||
COUNT(*) as total_vector_embeddings,
|
||||
AVG(token_count) as avg_tokens_per_chunk,
|
||||
MIN(token_count) as min_tokens,
|
||||
MAX(token_count) as max_tokens
|
||||
FROM document_chunks
|
||||
WHERE dataset_id = $1
|
||||
"""
|
||||
chunk_stats = await fetch_one(session, chunk_stats_query, dataset_id)
|
||||
|
||||
# File type distribution
|
||||
file_types_query = """
|
||||
SELECT file_type, COUNT(*) as count
|
||||
FROM documents
|
||||
WHERE dataset_id = $1 AND processing_status = 'completed'
|
||||
GROUP BY file_type
|
||||
ORDER BY count DESC
|
||||
"""
|
||||
file_types_results = await fetch_all(session, file_types_query, dataset_id)
|
||||
file_types = {row["file_type"]: row["count"] for row in file_types_results}
|
||||
|
||||
return {
|
||||
"documents": {
|
||||
"total": doc_stats["total_documents"] or 0,
|
||||
"total_size_mb": round((doc_stats["total_size_bytes"] or 0) / 1024 / 1024, 2),
|
||||
"avg_chunks_per_document": round(doc_stats["avg_chunks_per_doc"] or 0, 1),
|
||||
"unique_file_types": doc_stats["unique_file_types"] or 0,
|
||||
"file_type_distribution": file_types
|
||||
},
|
||||
"chunks": {
|
||||
"total": chunk_stats["total_vector_embeddings"] or 0,
|
||||
"avg_tokens": round(chunk_stats["avg_tokens_per_chunk"] or 0, 1),
|
||||
"token_range": {
|
||||
"min": chunk_stats["min_tokens"] or 0,
|
||||
"max": chunk_stats["max_tokens"] or 0
|
||||
}
|
||||
},
|
||||
"search_readiness": {
|
||||
"has_vectors": (chunk_stats["total_vector_embeddings"] or 0) > 0,
|
||||
"vector_coverage": 1.0 if (doc_stats["total_chunks"] or 0) == (chunk_stats["total_vector_embeddings"] or 0) else 0.0
|
||||
}
|
||||
}
|
||||
|
||||
async def _analyze_dataset_topics(self, document_summaries: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Analyze topics across all document summaries"""
|
||||
if not document_summaries:
|
||||
return {"main_topics": [], "topic_distribution": {}, "confidence": 0.0}
|
||||
|
||||
# Collect all topics from document summaries
|
||||
all_topics = []
|
||||
for summary in document_summaries:
|
||||
topics = summary.get("topics", [])
|
||||
if isinstance(topics, list):
|
||||
all_topics.extend(topics)
|
||||
|
||||
# Count topic frequencies
|
||||
topic_counts = Counter(all_topics)
|
||||
|
||||
# Get top topics
|
||||
main_topics = [topic for topic, count in topic_counts.most_common(10)]
|
||||
|
||||
# Calculate topic distribution
|
||||
total_topics = len(all_topics)
|
||||
topic_distribution = {}
|
||||
if total_topics > 0:
|
||||
for topic, count in topic_counts.items():
|
||||
topic_distribution[topic] = round(count / total_topics, 3)
|
||||
|
||||
# Calculate confidence based on number of summaries available
|
||||
confidence = min(1.0, len(document_summaries) / 5.0) # Full confidence with 5+ documents
|
||||
|
||||
return {
|
||||
"main_topics": main_topics,
|
||||
"topic_distribution": topic_distribution,
|
||||
"confidence": confidence,
|
||||
"total_unique_topics": len(topic_counts)
|
||||
}
|
||||
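Topic aggregation above is a plain frequency count over whatever topic lists the per-document summaries carry. A minimal sketch of the same Counter logic (the topic names are made up):

    from collections import Counter

    all_topics = ["security", "auth", "security", "rag"]
    topic_counts = Counter(all_topics)
    main_topics = [t for t, _ in topic_counts.most_common(10)]  # ["security", "auth", "rag"]
    distribution = {t: round(c / len(all_topics), 3) for t, c in topic_counts.items()}
    # {"security": 0.5, "auth": 0.25, "rag": 0.25}
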
|
||||
async def _generate_dataset_overview(
|
||||
self,
|
||||
dataset_info: Dict[str, Any],
|
||||
document_summaries: List[Dict[str, Any]],
|
||||
topics_analysis: Dict[str, Any],
|
||||
stats: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Generate LLM-powered overview of the dataset"""
|
||||
|
||||
# Create context for LLM
|
||||
context = f"""Dataset: {dataset_info['dataset_name']}
|
||||
Description: {dataset_info.get('description', 'No description provided')}
|
||||
|
||||
Statistics:
|
||||
- {stats['documents']['total']} documents ({stats['documents']['total_size_mb']} MB)
|
||||
- {stats['chunks']['total']} text chunks for search
|
||||
- Average {stats['documents']['avg_chunks_per_document']} chunks per document
|
||||
|
||||
Main Topics: {', '.join(topics_analysis['main_topics'][:5])}
|
||||
|
||||
Document Summaries:
|
||||
"""
|
||||
|
||||
# Add sample document summaries
|
||||
for i, summary in enumerate(document_summaries[:3]): # First 3 documents
|
||||
context += f"\n{i+1}. {summary['filename']}: {summary['quick_summary']}"
|
||||
|
||||
prompt = f"""Analyze this dataset and provide a comprehensive 2-3 paragraph overview.
|
||||
|
||||
{context}
|
||||
|
||||
Focus on:
|
||||
1. What type of content this dataset contains
|
||||
2. The main themes and topics covered
|
||||
3. How useful this would be for AI-powered search and retrieval
|
||||
4. Any notable patterns or characteristics
|
||||
|
||||
Provide a professional, informative summary suitable for users exploring their datasets."""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
|
||||
json={
|
||||
"model": "llama-3.1-8b-instant",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a data analysis expert. Provide clear, insightful dataset summaries."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 500
|
||||
},
|
||||
headers={
|
||||
"X-Tenant-ID": "default",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
llm_response = response.json()
|
||||
return llm_response["choices"][0]["message"]["content"]
|
||||
else:
|
||||
raise Exception(f"LLM API error: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM overview generation failed: {e}")
|
||||
# Fallback to template-based overview
|
||||
return f"This dataset contains {stats['documents']['total']} documents covering topics such as {', '.join(topics_analysis['main_topics'][:3])}. The dataset includes {stats['chunks']['total']} searchable text chunks optimized for AI-powered retrieval and question answering."
|
||||
|
||||
async def _generate_search_recommendations(
|
||||
self,
|
||||
stats: Dict[str, Any],
|
||||
topics_analysis: Dict[str, Any]
|
||||
) -> List[str]:
|
||||
"""Generate recommendations for optimizing search performance"""
|
||||
recommendations = []
|
||||
|
||||
# Vector coverage recommendations
|
||||
if not stats["search_readiness"]["has_vectors"]:
|
||||
recommendations.append("Generate vector embeddings for all documents to enable semantic search")
|
||||
elif stats["search_readiness"]["vector_coverage"] < 1.0:
|
||||
recommendations.append("Complete vector embedding generation for optimal search performance")
|
||||
|
||||
# Chunk size recommendations
|
||||
avg_tokens = stats["chunks"]["avg_tokens"]
|
||||
if avg_tokens < 100:
|
||||
recommendations.append("Consider increasing chunk size for better context in search results")
|
||||
elif avg_tokens > 600:
|
||||
recommendations.append("Consider reducing chunk size for more precise search matches")
|
||||
|
||||
# Topic diversity recommendations
|
||||
if topics_analysis["total_unique_topics"] < 3:
|
||||
recommendations.append("Dataset may benefit from more diverse content for comprehensive coverage")
|
||||
elif topics_analysis["total_unique_topics"] > 50:
|
||||
recommendations.append("Consider organizing content into focused sub-datasets for better search precision")
|
||||
|
||||
# Document count recommendations
|
||||
doc_count = stats["documents"]["total"]
|
||||
if doc_count < 5:
|
||||
recommendations.append("Add more documents to improve search quality and coverage")
|
||||
elif doc_count > 100:
|
||||
recommendations.append("Consider implementing advanced filtering and categorization for better navigation")
|
||||
|
||||
return recommendations[:5] # Limit to top 5 recommendations
|
||||
|
||||
async def _store_dataset_summary(
|
||||
self,
|
||||
dataset_id: str,
|
||||
summary_data: Dict[str, Any],
|
||||
tenant_domain: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Store or update dataset summary in database"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
UPDATE datasets
|
||||
SET
|
||||
summary = $1,
|
||||
summary_generated_at = $2,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
"""
|
||||
|
||||
await execute_command(
|
||||
session,
|
||||
query,
|
||||
json.dumps(summary_data),
|
||||
datetime.utcnow(),
|
||||
dataset_id
|
||||
)
|
||||
|
||||
async def _get_existing_summary(self, dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get existing dataset summary if available"""
|
||||
async with get_db_session() as session:
|
||||
query = """
|
||||
SELECT summary, summary_generated_at
|
||||
FROM datasets
|
||||
WHERE id = $1 AND summary IS NOT NULL
|
||||
"""
|
||||
result = await fetch_one(session, query, dataset_id)
|
||||
|
||||
if result and result["summary"]:
|
||||
return json.loads(result["summary"])
|
||||
return None
|
||||
|
||||
def _is_summary_fresh(self, summary: Dict[str, Any], max_age_hours: int = 24) -> bool:
|
||||
"""Check if summary is recent enough to avoid regeneration"""
|
||||
try:
|
||||
generated_at = datetime.fromisoformat(summary["metadata"]["generated_at"])
|
||||
age_hours = (datetime.utcnow() - generated_at).total_seconds() / 3600
|
||||
return age_hours < max_age_hours
|
||||
except (KeyError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
# Global instance
dataset_summarizer = DatasetSummarizer()


async def generate_dataset_summary(
    dataset_id: str,
    tenant_domain: str,
    user_id: str,
    force_regenerate: bool = False
) -> Dict[str, Any]:
    """Convenience function for dataset summary generation"""
    return await dataset_summarizer.generate_dataset_summary(
        dataset_id, tenant_domain, user_id, force_regenerate
    )


async def get_dataset_summary(dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
    """Convenience function for retrieving dataset summary"""
    return await dataset_summarizer._get_existing_summary(dataset_id, tenant_domain)

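Callers would typically go through the module-level helper rather than instantiating the class; a minimal sketch, with the dataset ID, tenant domain, and user ID as placeholders:

    # Hypothetical caller (inside an async context, e.g. a FastAPI route)
    summary = await generate_dataset_summary(
        dataset_id="ds-123",
        tenant_domain="acme.example",
        user_id="owner@acme.example",
        force_regenerate=False,
    )
    print(summary["overview"])
    print(summary["statistics"]["documents"]["total"])
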
834
apps/tenant-backend/app/services/document_processor.py
Normal file
@@ -0,0 +1,834 @@
"""
Document Processing Service for GT 2.0

Handles file upload, text extraction, chunking, and embedding generation
for RAG pipeline. Supports multiple file formats with intelligent chunking.
"""

import asyncio
import logging
import hashlib
import mimetypes
import re
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import uuid

# Document processing libraries
import pypdf as PyPDF2  # pypdf is the maintained successor to PyPDF2
import docx
import pandas as pd
import json
import csv
from io import StringIO

# Database and core services
from app.core.postgresql_client import get_postgresql_client

# Resource cluster client for embeddings
import httpx
from app.services.embedding_client import get_embedding_client

# Document summarization
from app.services.summarization_service import SummarizationService

logger = logging.getLogger(__name__)


class DocumentProcessor:
    """
    Comprehensive document processing service for RAG pipeline.

    Features:
    - Multi-format support (PDF, DOCX, TXT, MD, CSV, JSON)
    - Intelligent chunking with overlap
    - Async embedding generation with batch processing
    - Progress tracking
    - Error handling and recovery
    """

    def __init__(self, db=None, tenant_domain=None):
        self.db = db
        self.tenant_domain = tenant_domain or "test"  # Default fallback
        # Use configurable embedding client instead of hardcoded URL
        self.embedding_client = get_embedding_client()
        self.chunk_size = 512  # Default chunk size in tokens
        self.chunk_overlap = 128  # Default overlap
        self.max_file_size = 100 * 1024 * 1024  # 100MB limit

        # Embedding batch processing configuration
        self.EMBEDDING_BATCH_SIZE = 15  # Process embeddings in batches of 15 (ARM64 optimized)
        self.MAX_CONCURRENT_BATCHES = 3  # Process up to 3 batches concurrently
        self.MAX_RETRIES = 3  # Maximum retries per batch
        self.INITIAL_RETRY_DELAY = 1.0  # Initial delay in seconds

        # Supported file types
        self.supported_types = {
            '.pdf': 'application/pdf',
            '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
            '.txt': 'text/plain',
            '.md': 'text/markdown',
            '.csv': 'text/csv',
            '.json': 'application/json'
        }

async def process_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
original_filename: str,
|
||||
document_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process an uploaded file through the complete RAG pipeline.
|
||||
|
||||
Args:
|
||||
file_path: Path to uploaded file
|
||||
dataset_id: Dataset UUID to attach to
|
||||
user_id: User who uploaded the file
|
||||
original_filename: Original filename
|
||||
document_id: Optional existing document ID to update instead of creating new
|
||||
|
||||
Returns:
|
||||
Dict: Document record with processing status
|
||||
"""
|
||||
logger.info(f"Processing file {original_filename} for dataset {dataset_id}")
|
||||
|
||||
# Process file directly (no session management needed with PostgreSQL client)
|
||||
return await self._process_file_internal(file_path, dataset_id, user_id, original_filename, document_id)
|
||||
|
||||
async def _process_file_internal(
|
||||
self,
|
||||
file_path: Path,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
original_filename: str,
|
||||
document_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal file processing method"""
|
||||
try:
|
||||
# 1. Validate file
|
||||
await self._validate_file(file_path)
|
||||
|
||||
# 2. Create or use existing document record
|
||||
if document_id:
|
||||
# Use existing document
|
||||
document = {"id": document_id}
|
||||
logger.info(f"Using existing document {document_id} for processing")
|
||||
else:
|
||||
# Create new document record
|
||||
document = await self._create_document_record(
|
||||
file_path, dataset_id, user_id, original_filename
|
||||
)
|
||||
|
||||
# 3. Get or extract text content
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Getting text content")
|
||||
|
||||
# Check if content already exists (e.g., from upload-time extraction)
|
||||
existing_content, storage_type = await self._get_existing_document_content(document["id"])
|
||||
|
||||
if existing_content and storage_type in ["pdf_extracted", "text"]:
|
||||
# Use existing extracted content
|
||||
text_content = existing_content
|
||||
logger.info(f"Using existing extracted content ({len(text_content)} chars, type: {storage_type})")
|
||||
else:
|
||||
# Extract text from file
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Extracting text")
|
||||
|
||||
# Determine file type for extraction
|
||||
if document_id:
|
||||
# For existing documents, determine file type from file extension
|
||||
file_ext = file_path.suffix.lower()
|
||||
file_type = self.supported_types.get(file_ext, 'text/plain')
|
||||
else:
|
||||
file_type = document["file_type"]
|
||||
|
||||
text_content = await self._extract_text(file_path, file_type)
|
||||
|
||||
# 4. Update document with extracted text
|
||||
await self._update_document_content(document["id"], text_content)
|
||||
|
||||
# 5. Generate document summary
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Generating summary")
|
||||
await self._generate_document_summary(document["id"], text_content, original_filename, user_id)
|
||||
|
||||
# 6. Chunk the document
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Creating chunks")
|
||||
chunks = await self._chunk_text(text_content, document["id"])
|
||||
|
||||
# Set expected chunk count for progress tracking
|
||||
await self._update_processing_status(
|
||||
document["id"], "processing",
|
||||
processing_stage="Preparing for embedding generation",
|
||||
total_chunks_expected=len(chunks)
|
||||
)
|
||||
|
||||
# 7. Generate embeddings
|
||||
await self._update_processing_status(document["id"], "processing", processing_stage="Starting embedding generation")
|
||||
await self._generate_embeddings_for_chunks(chunks, dataset_id, user_id)
|
||||
|
||||
# 8. Update final status
|
||||
await self._update_processing_status(
|
||||
document["id"], "completed",
|
||||
processing_stage="Completed",
|
||||
chunks_processed=len(chunks),
|
||||
total_chunks_expected=len(chunks)
|
||||
)
|
||||
await self._update_chunk_count(document["id"], len(chunks))
|
||||
|
||||
# 9. Update dataset summary (after document is fully processed)
|
||||
await self._update_dataset_summary_after_document_change(dataset_id, user_id)
|
||||
|
||||
logger.info(f"Successfully processed {original_filename} with {len(chunks)} chunks")
|
||||
return document
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {original_filename}: {e}")
|
||||
if 'document' in locals():
|
||||
await self._update_processing_status(
|
||||
document["id"], "failed",
|
||||
error_message=str(e),
|
||||
processing_stage="Failed"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _validate_file(self, file_path: Path):
|
||||
"""Validate file size and type"""
|
||||
if not file_path.exists():
|
||||
raise ValueError("File does not exist")
|
||||
|
||||
file_size = file_path.stat().st_size
|
||||
if file_size > self.max_file_size:
|
||||
raise ValueError(f"File too large: {file_size} bytes (max: {self.max_file_size})")
|
||||
|
||||
file_ext = file_path.suffix.lower()
|
||||
if file_ext not in self.supported_types:
|
||||
raise ValueError(f"Unsupported file type: {file_ext}")
|
||||
|
||||
async def _create_document_record(
|
||||
self,
|
||||
file_path: Path,
|
||||
dataset_id: str,
|
||||
user_id: str,
|
||||
original_filename: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create document record in database"""
|
||||
|
||||
# Calculate file hash
|
||||
with open(file_path, 'rb') as f:
|
||||
file_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
file_ext = file_path.suffix.lower()
|
||||
file_size = file_path.stat().st_size
|
||||
document_id = str(uuid.uuid4())
|
||||
|
||||
# Insert document record using raw SQL
|
||||
# Note: tenant_id is nullable UUID, so we set it to NULL for individual documents
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"""INSERT INTO documents (
|
||||
id, user_id, dataset_id, filename, original_filename,
|
||||
file_type, file_size_bytes, file_hash, processing_status
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)""",
|
||||
document_id, str(user_id), dataset_id, str(file_path.name),
|
||||
original_filename, self.supported_types[file_ext], file_size, file_hash, "pending"
|
||||
)
|
||||
|
||||
return {
|
||||
"id": document_id,
|
||||
"user_id": user_id,
|
||||
"dataset_id": dataset_id,
|
||||
"filename": str(file_path.name),
|
||||
"original_filename": original_filename,
|
||||
"file_type": self.supported_types[file_ext],
|
||||
"file_size_bytes": file_size,
|
||||
"file_hash": file_hash,
|
||||
"processing_status": "pending",
|
||||
"chunk_count": 0
|
||||
}
|
||||
|
||||
async def _extract_text(self, file_path: Path, file_type: str) -> str:
|
||||
"""Extract text content from various file formats"""
|
||||
|
||||
try:
|
||||
if file_type == 'application/pdf':
|
||||
return await self._extract_pdf_text(file_path)
|
||||
elif 'wordprocessingml' in file_type:
|
||||
return await self._extract_docx_text(file_path)
|
||||
elif file_type == 'text/csv':
|
||||
return await self._extract_csv_text(file_path)
|
||||
elif file_type == 'application/json':
|
||||
return await self._extract_json_text(file_path)
|
||||
else: # text/plain, text/markdown
|
||||
return await self._extract_plain_text(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Text extraction failed for {file_path}: {e}")
|
||||
raise ValueError(f"Could not extract text from file: {e}")
|
||||
|
||||
async def _extract_pdf_text(self, file_path: Path) -> str:
|
||||
"""Extract text from PDF file"""
|
||||
text_parts = []
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = PyPDF2.PdfReader(file)
|
||||
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
try:
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num + 1} ---\n{page_text}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from page {page_num + 1}: {e}")
|
||||
|
||||
if not text_parts:
|
||||
raise ValueError("No text could be extracted from PDF")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
async def _extract_docx_text(self, file_path: Path) -> str:
|
||||
"""Extract text from DOCX file"""
|
||||
doc = docx.Document(file_path)
|
||||
text_parts = []
|
||||
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text_parts.append(paragraph.text)
|
||||
|
||||
if not text_parts:
|
||||
raise ValueError("No text could be extracted from DOCX")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
async def _extract_csv_text(self, file_path: Path) -> str:
|
||||
"""Extract and format text from CSV file"""
|
||||
try:
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# Create readable format
|
||||
text_parts = [f"CSV Data with {len(df)} rows and {len(df.columns)} columns"]
|
||||
text_parts.append(f"Columns: {', '.join(df.columns.tolist())}")
|
||||
text_parts.append("")
|
||||
|
||||
# Sample first few rows in readable format
|
||||
for idx, row in df.head(20).iterrows():
|
||||
row_text = []
|
||||
for col in df.columns:
|
||||
if pd.notna(row[col]):
|
||||
row_text.append(f"{col}: {row[col]}")
|
||||
text_parts.append(f"Row {idx + 1}: " + " | ".join(row_text))
|
||||
|
||||
return "\n".join(text_parts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CSV parsing error: {e}")
|
||||
# Fallback to reading as plain text
|
||||
return await self._extract_plain_text(file_path)
|
||||
|
||||
async def _extract_json_text(self, file_path: Path) -> str:
|
||||
"""Extract and format text from JSON file"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Convert JSON to readable text format
|
||||
def json_to_text(obj, prefix=""):
|
||||
text_parts = []
|
||||
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
text_parts.append(f"{prefix}{key}:")
|
||||
text_parts.extend(json_to_text(value, prefix + " "))
|
||||
else:
|
||||
text_parts.append(f"{prefix}{key}: {value}")
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
if isinstance(item, (dict, list)):
|
||||
text_parts.append(f"{prefix}Item {i + 1}:")
|
||||
text_parts.extend(json_to_text(item, prefix + " "))
|
||||
else:
|
||||
text_parts.append(f"{prefix}Item {i + 1}: {item}")
|
||||
else:
|
||||
text_parts.append(f"{prefix}{obj}")
|
||||
|
||||
return text_parts
|
||||
|
||||
return "\n".join(json_to_text(data))
|
||||
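The recursive json_to_text flattener above turns nested JSON into indented "key: value" lines; a small illustrative trace (hypothetical input):

    # {"title": "Report", "authors": ["Ada", "Bo"]} is flattened to:
    #   title: Report
    #   authors:
    #     Item 1: Ada
    #     Item 2: Bo
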
|
||||
async def _extract_plain_text(self, file_path: Path) -> str:
|
||||
"""Extract text from plain text files"""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except UnicodeDecodeError:
|
||||
# Try with latin-1 encoding
|
||||
with open(file_path, 'r', encoding='latin-1') as f:
|
||||
return f.read()
|
||||
|
||||
async def extract_text_from_path(self, file_path: Path, content_type: str) -> str:
|
||||
"""Public wrapper for text extraction from file path"""
|
||||
return await self._extract_text(file_path, content_type)
|
||||
|
||||
    async def chunk_text_simple(self, text: str) -> List[str]:
        """Public wrapper for simple text chunking without document_id"""
        chunks = []
        chunk_size = self.chunk_size * 4  # ~2048 chars
        overlap = self.chunk_overlap * 4  # ~512 chars

        for i in range(0, len(text), chunk_size - overlap):
            chunk = text[i:i + chunk_size]
            if chunk.strip():
                chunks.append(chunk)

        return chunks

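The simple chunker is a fixed character window with overlap, so windows start every chunk_size - overlap characters. A self-contained sketch of that arithmetic using the defaults above (the input text is synthetic):

    # chunk_size=512 tokens -> ~2048 chars, overlap=128 tokens -> ~512 chars,
    # so each window starts 2048 - 512 = 1536 characters after the previous one.
    text = "x" * 5000
    chunk_size, overlap = 2048, 512
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
    assert [len(c) for c in chunks] == [2048, 2048, 1928, 392]
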
async def _chunk_text(self, text: str, document_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Split text into overlapping chunks optimized for embeddings.
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries with content and metadata
|
||||
"""
|
||||
# Simple sentence-aware chunking
|
||||
sentences = re.split(r'[.!?]+', text)
|
||||
sentences = [s.strip() for s in sentences if s.strip()]
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
current_tokens = 0
|
||||
chunk_index = 0
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_tokens = len(sentence.split())
|
||||
|
||||
# If adding this sentence would exceed chunk size, save current chunk
|
||||
if current_tokens + sentence_tokens > self.chunk_size and current_chunk:
|
||||
# Create chunk with overlap from previous chunk
|
||||
chunk_content = current_chunk.strip()
|
||||
if chunk_content:
|
||||
chunks.append({
|
||||
"document_id": document_id,
|
||||
"chunk_index": chunk_index,
|
||||
"content": chunk_content,
|
||||
"token_count": current_tokens,
|
||||
"content_hash": hashlib.md5(chunk_content.encode()).hexdigest()
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
# Start new chunk with overlap
|
||||
if self.chunk_overlap > 0 and chunks:
|
||||
# Take last few sentences for overlap
|
||||
overlap_sentences = current_chunk.split('.')[-2:] # Rough overlap
|
||||
current_chunk = '. '.join(s.strip() for s in overlap_sentences if s.strip())
|
||||
current_tokens = len(current_chunk.split())
|
||||
else:
|
||||
current_chunk = ""
|
||||
current_tokens = 0
|
||||
|
||||
# Add sentence to current chunk
|
||||
if current_chunk:
|
||||
current_chunk += ". " + sentence
|
||||
else:
|
||||
current_chunk = sentence
|
||||
current_tokens += sentence_tokens
|
||||
|
||||
# Add final chunk
|
||||
if current_chunk.strip():
|
||||
chunk_content = current_chunk.strip()
|
||||
chunks.append({
|
||||
"document_id": document_id,
|
||||
"chunk_index": chunk_index,
|
||||
"content": chunk_content,
|
||||
"token_count": current_tokens,
|
||||
"content_hash": hashlib.md5(chunk_content.encode()).hexdigest()
|
||||
})
|
||||
|
||||
logger.info(f"Created {len(chunks)} chunks from document {document_id}")
|
||||
return chunks
|
||||
|
||||
async def _generate_embeddings_for_chunks(
|
||||
self,
|
||||
chunks: List[Dict[str, Any]],
|
||||
dataset_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""
|
||||
Generate embeddings for all chunks using concurrent batch processing.
|
||||
|
||||
Processes chunks in batches with controlled concurrency to optimize performance
|
||||
while preventing system overload. Includes retry logic and progressive storage.
|
||||
"""
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
total_chunks = len(chunks)
|
||||
document_id = chunks[0]["document_id"]
|
||||
total_batches = (total_chunks + self.EMBEDDING_BATCH_SIZE - 1) // self.EMBEDDING_BATCH_SIZE
|
||||
|
||||
logger.info(f"Starting concurrent embedding generation for {total_chunks} chunks")
|
||||
logger.info(f"Batch size: {self.EMBEDDING_BATCH_SIZE}, Total batches: {total_batches}, Max concurrent: {self.MAX_CONCURRENT_BATCHES}")
|
||||
|
||||
# Create semaphore to limit concurrent batches
|
||||
semaphore = asyncio.Semaphore(self.MAX_CONCURRENT_BATCHES)
|
||||
|
||||
# Create batch data with metadata
|
||||
batch_tasks = []
|
||||
for batch_start in range(0, total_chunks, self.EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(batch_start + self.EMBEDDING_BATCH_SIZE, total_chunks)
|
||||
batch_chunks = chunks[batch_start:batch_end]
|
||||
batch_num = (batch_start // self.EMBEDDING_BATCH_SIZE) + 1
|
||||
|
||||
batch_metadata = {
|
||||
"chunks": batch_chunks,
|
||||
"batch_num": batch_num,
|
||||
"start_index": batch_start,
|
||||
"end_index": batch_end,
|
||||
"dataset_id": dataset_id,
|
||||
"user_id": user_id,
|
||||
"document_id": document_id
|
||||
}
|
||||
|
||||
# Create task for this batch
|
||||
task = self._process_batch_with_semaphore(semaphore, batch_metadata, total_batches, total_chunks)
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Process all batches concurrently
|
||||
logger.info(f"Starting concurrent processing of {len(batch_tasks)} batches")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
# Analyze results
|
||||
successful_batches = []
|
||||
failed_batches = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
batch_num = i + 1
|
||||
if isinstance(result, Exception):
|
||||
failed_batches.append({
|
||||
"batch_num": batch_num,
|
||||
"error": str(result)
|
||||
})
|
||||
logger.error(f"Batch {batch_num} failed: {result}")
|
||||
else:
|
||||
successful_batches.append(result)
|
||||
|
||||
successful_chunks = sum(len(batch["chunks"]) for batch in successful_batches)
|
||||
|
||||
logger.info(f"Concurrent processing completed in {processing_time:.2f} seconds")
|
||||
logger.info(f"Successfully processed {successful_chunks}/{total_chunks} chunks in {len(successful_batches)}/{total_batches} batches")
|
||||
|
||||
# Report final results
|
||||
if failed_batches:
|
||||
failed_chunk_count = total_chunks - successful_chunks
|
||||
error_details = "; ".join([f"Batch {b['batch_num']}: {b['error']}" for b in failed_batches[:3]])
|
||||
if len(failed_batches) > 3:
|
||||
error_details += f" (and {len(failed_batches) - 3} more failures)"
|
||||
|
||||
raise ValueError(f"Failed to generate embeddings for {failed_chunk_count}/{total_chunks} chunks. Errors: {error_details}")
|
||||
|
||||
logger.info(f"Successfully stored all {total_chunks} chunks with embeddings")
|
||||
|
||||
async def _process_batch_with_semaphore(
|
||||
self,
|
||||
semaphore: asyncio.Semaphore,
|
||||
batch_metadata: Dict[str, Any],
|
||||
total_batches: int,
|
||||
total_chunks: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a single batch with semaphore-controlled concurrency.
|
||||
|
||||
Args:
|
||||
semaphore: Concurrency control semaphore
|
||||
batch_metadata: Batch information including chunks and metadata
|
||||
total_batches: Total number of batches
|
||||
total_chunks: Total number of chunks
|
||||
|
||||
Returns:
|
||||
Dict with batch processing results
|
||||
"""
|
||||
async with semaphore:
|
||||
batch_chunks = batch_metadata["chunks"]
|
||||
batch_num = batch_metadata["batch_num"]
|
||||
dataset_id = batch_metadata["dataset_id"]
|
||||
user_id = batch_metadata["user_id"]
|
||||
document_id = batch_metadata["document_id"]
|
||||
|
||||
logger.info(f"Starting batch {batch_num}/{total_batches} ({len(batch_chunks)} chunks)")
|
||||
|
||||
try:
|
||||
# Generate embeddings for this batch (pass user_id for billing)
|
||||
embeddings = await self._generate_embedding_batch(batch_chunks, user_id=user_id)
|
||||
|
||||
# Store embeddings for this batch immediately
|
||||
await self._store_chunk_embeddings(batch_chunks, embeddings, dataset_id, user_id)
|
||||
|
||||
# Update progress in database
|
||||
progress_stage = f"Completed batch {batch_num}/{total_batches}"
|
||||
|
||||
# Calculate current progress (approximate since batches complete out of order)
|
||||
await self._update_processing_status(
|
||||
document_id, "processing",
|
||||
processing_stage=progress_stage,
|
||||
chunks_processed=batch_num * self.EMBEDDING_BATCH_SIZE, # Approximate
|
||||
total_chunks_expected=total_chunks
|
||||
)
|
||||
|
||||
logger.info(f"Successfully completed batch {batch_num}/{total_batches}")
|
||||
|
||||
return {
|
||||
"batch_num": batch_num,
|
||||
"chunks": batch_chunks,
|
||||
"success": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process batch {batch_num}/{total_batches}: {e}")
|
||||
raise ValueError(f"Batch {batch_num} failed: {str(e)}")
|
||||
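The concurrency model here is the standard semaphore-plus-gather pattern: every batch task is created up front, but at most MAX_CONCURRENT_BATCHES run at once. A stripped-down sketch of that pattern, independent of the document-processing specifics (process is a placeholder for the real batch work):

    import asyncio

    async def run_batches(batches, max_concurrent=3):
        semaphore = asyncio.Semaphore(max_concurrent)

        async def run_one(batch):
            async with semaphore:  # blocks once max_concurrent batches are in flight
                return await process(batch)  # placeholder coroutine

        return await asyncio.gather(*(run_one(b) for b in batches), return_exceptions=True)
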
|
||||
async def _generate_embedding_batch(
|
||||
self,
|
||||
batch_chunks: List[Dict[str, Any]],
|
||||
user_id: str = None
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for a single batch of chunks with retry logic.
|
||||
|
||||
Args:
|
||||
batch_chunks: List of chunk dictionaries
|
||||
user_id: User ID for usage tracking
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding generation fails after all retries
|
||||
"""
|
||||
texts = [chunk["content"] for chunk in batch_chunks]
|
||||
|
||||
for attempt in range(self.MAX_RETRIES + 1):
|
||||
try:
|
||||
# Use the configurable embedding client with tenant/user context for billing
|
||||
embeddings = await self.embedding_client.generate_embeddings(
|
||||
texts,
|
||||
tenant_id=self.tenant_domain,
|
||||
user_id=str(user_id) if user_id else None
|
||||
)
|
||||
|
||||
if len(embeddings) != len(texts):
|
||||
raise ValueError(f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}")
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.MAX_RETRIES:
|
||||
delay = self.INITIAL_RETRY_DELAY * (2 ** attempt) # Exponential backoff
|
||||
logger.warning(f"Embedding generation attempt {attempt + 1}/{self.MAX_RETRIES + 1} failed: {e}. Retrying in {delay}s...")
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.error(f"All {self.MAX_RETRIES + 1} embedding generation attempts failed. Final error: {e}")
|
||||
logger.error(f"Failed request details: URL=http://gentwo-vllm-embeddings:8000/v1/embeddings, texts_count={len(texts)}")
|
||||
raise ValueError(f"Embedding generation failed after {self.MAX_RETRIES + 1} attempts: {str(e)}")
|
||||
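The retry delays grow exponentially from INITIAL_RETRY_DELAY, so with the defaults above (1.0s, 3 retries) the waits between attempts are 1s, 2s, and 4s:

    INITIAL_RETRY_DELAY = 1.0
    MAX_RETRIES = 3
    delays = [INITIAL_RETRY_DELAY * (2 ** attempt) for attempt in range(MAX_RETRIES)]
    assert delays == [1.0, 2.0, 4.0]  # waits before retry attempts 2, 3, and 4
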
|
||||
async def _store_chunk_embeddings(
|
||||
self,
|
||||
batch_chunks: List[Dict[str, Any]],
|
||||
embeddings: List[List[float]],
|
||||
dataset_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Store chunk embeddings in database with proper error handling."""
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
for chunk_data, embedding in zip(batch_chunks, embeddings):
|
||||
chunk_id = str(uuid.uuid4())
|
||||
|
||||
# Convert embedding list to PostgreSQL array format
|
||||
embedding_array = f"[{','.join(map(str, embedding))}]" if embedding else None
|
||||
|
||||
await pg_client.execute_command(
|
||||
"""INSERT INTO document_chunks (
|
||||
id, document_id, user_id, dataset_id, chunk_index,
|
||||
content, content_hash, token_count, embedding
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector)""",
|
||||
chunk_id, chunk_data["document_id"], str(user_id),
|
||||
dataset_id, chunk_data["chunk_index"], chunk_data["content"],
|
||||
chunk_data["content_hash"], chunk_data["token_count"], embedding_array
|
||||
)
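# Note on the embedding literal above (an assumption based on common pgvector usage,
# not something stated elsewhere in this commit): the Python list is serialized to a
# bracketed string and cast with ::vector, e.g.
#
#     embedding = [0.12, -0.03, 0.98]
#     embedding_array = f"[{','.join(map(str, embedding))}]"   # "[0.12,-0.03,0.98]"
#
# which PostgreSQL parses into a vector value of the expected dimensionality.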
|
||||
|
||||
async def _update_processing_status(
|
||||
self,
|
||||
document_id: str,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
processing_stage: Optional[str] = None,
|
||||
chunks_processed: Optional[int] = None,
|
||||
total_chunks_expected: Optional[int] = None
|
||||
):
|
||||
"""Update document processing status with progress tracking via metadata JSONB"""
|
||||
|
||||
# Calculate progress percentage if we have the data
|
||||
processing_progress = None
|
||||
if chunks_processed is not None and total_chunks_expected is not None and total_chunks_expected > 0:
|
||||
processing_progress = min(100, int((chunks_processed / total_chunks_expected) * 100))
|
||||
|
||||
# Build progress metadata object
|
||||
import json
|
||||
progress_data = {}
|
||||
if processing_stage is not None:
|
||||
progress_data['processing_stage'] = processing_stage
|
||||
if chunks_processed is not None:
|
||||
progress_data['chunks_processed'] = chunks_processed
|
||||
if total_chunks_expected is not None:
|
||||
progress_data['total_chunks_expected'] = total_chunks_expected
|
||||
if processing_progress is not None:
|
||||
progress_data['processing_progress'] = processing_progress
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
if error_message:
|
||||
await pg_client.execute_command(
|
||||
"""UPDATE documents SET
|
||||
processing_status = $1,
|
||||
error_message = $2,
|
||||
metadata = COALESCE(metadata, '{}'::jsonb) || $3::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE id = $4""",
|
||||
status, error_message, json.dumps(progress_data), document_id
|
||||
)
|
||||
else:
|
||||
await pg_client.execute_command(
|
||||
"""UPDATE documents SET
|
||||
processing_status = $1,
|
||||
metadata = COALESCE(metadata, '{}'::jsonb) || $2::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3""",
|
||||
status, json.dumps(progress_data), document_id
|
||||
)
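# Minimal sketch of the metadata merge semantics (illustrative only, values are
# hypothetical): the `COALESCE(metadata, '{}'::jsonb) || $N::jsonb` expression
# behaves like a shallow dict update, so existing keys survive unless overwritten:
#
#     existing = {"storage_type": "inline", "chunks_processed": 100}
#     progress = {"chunks_processed": 200, "processing_progress": 40}
#     merged = {**existing, **progress}
#     # -> {"storage_type": "inline", "chunks_processed": 200, "processing_progress": 40}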
|
||||
|
||||
async def _get_existing_document_content(self, document_id: str) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Get existing document content and storage type"""
|
||||
pg_client = await get_postgresql_client()
|
||||
result = await pg_client.fetch_one(
|
||||
"SELECT content_text, metadata FROM documents WHERE id = $1",
|
||||
document_id
|
||||
)
|
||||
if result and result["content_text"]:
|
||||
# Handle metadata - might be JSON string or dict
|
||||
metadata_raw = result["metadata"] or "{}"
|
||||
if isinstance(metadata_raw, str):
|
||||
import json
|
||||
try:
|
||||
metadata = json.loads(metadata_raw)
|
||||
except json.JSONDecodeError:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = metadata_raw or {}
|
||||
storage_type = metadata.get("storage_type", "unknown")
|
||||
return result["content_text"], storage_type
|
||||
return None, None
|
||||
|
||||
async def _update_document_content(self, document_id: str, content: str):
|
||||
"""Update document with extracted text content"""
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"UPDATE documents SET content_text = $1, updated_at = NOW() WHERE id = $2",
|
||||
content, document_id
|
||||
)
|
||||
|
||||
async def _update_chunk_count(self, document_id: str, chunk_count: int):
|
||||
"""Update document with final chunk count"""
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"UPDATE documents SET chunk_count = $1, updated_at = NOW() WHERE id = $2",
|
||||
chunk_count, document_id
|
||||
)
|
||||
|
||||
async def _generate_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
content: str,
|
||||
filename: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Generate and store AI summary for the document"""
|
||||
try:
|
||||
# Use tenant_domain from instance context
|
||||
tenant_domain = self.tenant_domain
|
||||
|
||||
# Create summarization service instance
|
||||
summarization_service = SummarizationService(tenant_domain, user_id)
|
||||
|
||||
# Generate summary using our new service
|
||||
summary = await summarization_service.generate_document_summary(
|
||||
document_id=document_id,
|
||||
document_content=content,
|
||||
document_name=filename
|
||||
)
|
||||
|
||||
if summary:
|
||||
logger.info(f"Generated summary for document {document_id}: {summary[:100]}...")
|
||||
else:
|
||||
logger.warning(f"Failed to generate summary for document {document_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating document summary for {document_id}: {e}")
|
||||
# Don't fail the entire document processing if summarization fails
|
||||
|
||||
async def _update_dataset_summary_after_document_change(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Update dataset summary after a document is added or removed"""
|
||||
try:
|
||||
# Create summarization service instance
|
||||
summarization_service = SummarizationService(self.tenant_domain, user_id)
|
||||
|
||||
# Update dataset summary asynchronously (don't block document processing)
|
||||
asyncio.create_task(
|
||||
summarization_service.update_dataset_summary_on_change(dataset_id)
|
||||
)
|
||||
|
||||
logger.info(f"Triggered dataset summary update for dataset {dataset_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering dataset summary update for {dataset_id}: {e}")
|
||||
# Don't fail document processing if dataset summary update fails
|
||||
|
||||
async def get_processing_status(self, document_id: str) -> Dict[str, Any]:
|
||||
"""Get current processing status of a document with progress information from metadata"""
|
||||
pg_client = await get_postgresql_client()
|
||||
result = await pg_client.fetch_one(
|
||||
"""SELECT processing_status, error_message, chunk_count, metadata
|
||||
FROM documents WHERE id = $1""",
|
||||
document_id
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise ValueError("Document not found")
|
||||
|
||||
# Extract progress data from metadata JSONB
|
||||
metadata = result["metadata"] or {}
|
||||
|
||||
return {
|
||||
"status": result["processing_status"],
|
||||
"error_message": result["error_message"],
|
||||
"chunk_count": result["chunk_count"],
|
||||
"chunks_processed": metadata.get("chunks_processed"),
|
||||
"total_chunks_expected": metadata.get("total_chunks_expected"),
|
||||
"processing_progress": metadata.get("processing_progress"),
|
||||
"processing_stage": metadata.get("processing_stage")
|
||||
}
|
||||
|
||||
|
||||
# Factory function for document processor
|
||||
async def get_document_processor(tenant_domain=None):
|
||||
"""Get document processor instance (will create its own DB session when needed)"""
|
||||
return DocumentProcessor(tenant_domain=tenant_domain)
|
||||
317
apps/tenant-backend/app/services/document_summarizer.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Document Summarization Service for GT 2.0
|
||||
|
||||
Generates AI-powered summaries for uploaded documents using the Resource Cluster.
|
||||
Provides both quick summaries and detailed analysis for RAG visualization.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import get_db_session, execute_command, fetch_one
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentSummarizer:
|
||||
"""
|
||||
Service for generating document summaries using Resource Cluster LLM.
|
||||
|
||||
Features:
|
||||
- Quick document summaries (2-3 sentences)
|
||||
- Detailed analysis with key topics and themes
|
||||
- Metadata extraction (document type, language, etc.)
|
||||
- Integration with document processor workflow
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.resource_cluster_url = "http://gentwo-resource-backend:8000"
|
||||
self.max_content_length = 4000 # Max chars to send for summarization
|
||||
|
||||
async def generate_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
content: str,
|
||||
filename: str,
|
||||
tenant_domain: str,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a comprehensive summary for a document.
|
||||
|
||||
Args:
|
||||
document_id: Document ID in the database
|
||||
content: Document text content
|
||||
filename: Original filename
|
||||
tenant_domain: Tenant domain for context
|
||||
user_id: User who uploaded the document
|
||||
|
||||
Returns:
|
||||
Dictionary with summary data including quick_summary, detailed_analysis,
|
||||
topics, metadata, and confidence scores
|
||||
"""
|
||||
try:
|
||||
# Truncate content if too long
|
||||
truncated_content = content[:self.max_content_length]
|
||||
if len(content) > self.max_content_length:
|
||||
truncated_content += "... [content truncated]"
|
||||
|
||||
# Generate summary using Resource Cluster LLM
|
||||
summary_data = await self._call_llm_for_summary(
|
||||
content=truncated_content,
|
||||
filename=filename,
|
||||
document_type=self._detect_document_type(filename)
|
||||
)
|
||||
|
||||
# Store summary in database
|
||||
await self._store_document_summary(
|
||||
document_id=document_id,
|
||||
summary_data=summary_data,
|
||||
tenant_domain=tenant_domain,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
logger.info(f"Generated summary for document {document_id}: {filename}")
|
||||
return summary_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary for document {document_id}: {e}")
|
||||
# Return basic fallback summary
|
||||
return {
|
||||
"quick_summary": f"Document: {filename}",
|
||||
"detailed_analysis": "Summary generation failed",
|
||||
"topics": [],
|
||||
"metadata": {
|
||||
"document_type": self._detect_document_type(filename),
|
||||
"estimated_read_time": len(content) // 200, # ~200 words per minute
|
||||
"character_count": len(content),
|
||||
"language": "unknown"
|
||||
},
|
||||
"confidence": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _call_llm_for_summary(
|
||||
self,
|
||||
content: str,
|
||||
filename: str,
|
||||
document_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Call Resource Cluster LLM to generate document summary"""
|
||||
|
||||
prompt = f"""Analyze this {document_type} document and provide a comprehensive summary.
|
||||
|
||||
Document: {filename}
|
||||
Content:
|
||||
{content}
|
||||
|
||||
Please provide:
|
||||
1. A concise 2-3 sentence summary
|
||||
2. Key topics and themes (list)
|
||||
3. Document analysis including tone, purpose, and target audience
|
||||
4. Estimated language and reading level
|
||||
|
||||
Format your response as JSON with these keys:
|
||||
- quick_summary: Brief 2-3 sentence overview
|
||||
- detailed_analysis: Paragraph with deeper insights
|
||||
- topics: Array of key topics/themes
|
||||
- metadata: Object with language, tone, purpose, target_audience
|
||||
- confidence: Float 0-1 indicating analysis confidence"""
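# Hypothetical example of the JSON the prompt above asks the model to return
# (illustrative only; real model output may vary and is handled by the
# JSONDecodeError fallback below):
#
#     {
#         "quick_summary": "A two-page onboarding guide for new engineers.",
#         "detailed_analysis": "The document walks through account setup ...",
#         "topics": ["onboarding", "tooling", "security policy"],
#         "metadata": {"language": "en", "tone": "instructional",
#                      "purpose": "training", "target_audience": "new hires"},
#         "confidence": 0.82
#     }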
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
|
||||
json={
|
||||
"model": "llama-3.1-8b-instant",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a document analysis expert. Provide accurate, concise summaries in valid JSON format."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 1000
|
||||
},
|
||||
headers={
|
||||
"X-Tenant-ID": "default",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
llm_response = response.json()
|
||||
content_text = llm_response["choices"][0]["message"]["content"]
|
||||
|
||||
# Try to parse JSON response
|
||||
try:
|
||||
import json
|
||||
summary_data = json.loads(content_text)
|
||||
|
||||
# Validate required fields and add defaults if missing
|
||||
return {
|
||||
"quick_summary": summary_data.get("quick_summary", f"Analysis of {filename}"),
|
||||
"detailed_analysis": summary_data.get("detailed_analysis", "Detailed analysis not available"),
|
||||
"topics": summary_data.get("topics", []),
|
||||
"metadata": {
|
||||
**summary_data.get("metadata", {}),
|
||||
"document_type": document_type,
|
||||
"generated_at": datetime.utcnow().isoformat()
|
||||
},
|
||||
"confidence": min(1.0, max(0.0, summary_data.get("confidence", 0.7)))
|
||||
}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON
|
||||
return {
|
||||
"quick_summary": content_text[:200] + "..." if len(content_text) > 200 else content_text,
|
||||
"detailed_analysis": content_text,
|
||||
"topics": [],
|
||||
"metadata": {
|
||||
"document_type": document_type,
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"note": "Summary extracted from free-form LLM response"
|
||||
},
|
||||
"confidence": 0.5
|
||||
}
|
||||
else:
|
||||
raise Exception(f"Resource Cluster API error: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM summarization failed: {e}")
|
||||
raise
|
||||
|
||||
async def _store_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
summary_data: Dict[str, Any],
|
||||
tenant_domain: str,
|
||||
user_id: str
|
||||
):
|
||||
"""Store generated summary in database"""
|
||||
|
||||
# Use the same database session pattern as document processor
|
||||
async with get_db_session(tenant_domain) as session:
|
||||
try:
|
||||
# Insert or update document summary
|
||||
query = """
|
||||
INSERT INTO document_summaries (
|
||||
document_id, user_id, quick_summary, detailed_analysis,
|
||||
topics, metadata, confidence, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (document_id)
|
||||
DO UPDATE SET
|
||||
quick_summary = EXCLUDED.quick_summary,
|
||||
detailed_analysis = EXCLUDED.detailed_analysis,
|
||||
topics = EXCLUDED.topics,
|
||||
metadata = EXCLUDED.metadata,
|
||||
confidence = EXCLUDED.confidence,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
"""
|
||||
|
||||
import json
|
||||
await execute_command(
|
||||
session,
|
||||
query,
|
||||
document_id,
|
||||
user_id,
|
||||
summary_data["quick_summary"],
|
||||
summary_data["detailed_analysis"],
|
||||
json.dumps(summary_data["topics"]),
|
||||
json.dumps(summary_data["metadata"]),
|
||||
summary_data["confidence"],
|
||||
datetime.utcnow(),
|
||||
datetime.utcnow()
|
||||
)
|
||||
|
||||
logger.info(f"Stored summary for document {document_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store document summary: {e}")
|
||||
raise
|
||||
|
||||
def _detect_document_type(self, filename: str) -> str:
|
||||
"""Detect document type from filename extension"""
|
||||
import pathlib
|
||||
|
||||
extension = pathlib.Path(filename).suffix.lower()
|
||||
|
||||
type_mapping = {
|
||||
'.pdf': 'PDF document',
|
||||
'.docx': 'Word document',
|
||||
'.doc': 'Word document',
|
||||
'.txt': 'Text file',
|
||||
'.md': 'Markdown document',
|
||||
'.csv': 'CSV data file',
|
||||
'.json': 'JSON data file',
|
||||
'.html': 'HTML document',
|
||||
'.htm': 'HTML document',
|
||||
'.rtf': 'Rich text document'
|
||||
}
|
||||
|
||||
return type_mapping.get(extension, 'Unknown document type')
|
||||
|
||||
async def get_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
tenant_domain: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve stored document summary"""
|
||||
|
||||
async with get_db_session(tenant_domain) as session:
|
||||
try:
|
||||
query = """
|
||||
SELECT quick_summary, detailed_analysis, topics, metadata,
|
||||
confidence, created_at, updated_at
|
||||
FROM document_summaries
|
||||
WHERE document_id = $1
|
||||
"""
|
||||
|
||||
result = await fetch_one(session, query, document_id)
|
||||
|
||||
if result:
|
||||
import json
|
||||
return {
|
||||
"quick_summary": result["quick_summary"],
|
||||
"detailed_analysis": result["detailed_analysis"],
|
||||
"topics": json.loads(result["topics"]) if result["topics"] else [],
|
||||
"metadata": json.loads(result["metadata"]) if result["metadata"] else {},
|
||||
"confidence": result["confidence"],
|
||||
"created_at": result["created_at"].isoformat(),
|
||||
"updated_at": result["updated_at"].isoformat()
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve document summary: {e}")
|
||||
return None
|


# Global instance
document_summarizer = DocumentSummarizer()


async def generate_document_summary(
    document_id: str,
    content: str,
    filename: str,
    tenant_domain: str,
    user_id: str
) -> Dict[str, Any]:
    """Convenience function for document summary generation"""
    return await document_summarizer.generate_document_summary(
        document_id, content, filename, tenant_domain, user_id
    )


async def get_document_summary(document_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
    """Convenience function for retrieving document summary"""
    return await document_summarizer.get_document_summary(document_id, tenant_domain)
286
apps/tenant-backend/app/services/embedding_client.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
BGE-M3 Embedding Client for GT 2.0
|
||||
|
||||
Simple client for the vLLM BGE-M3 embedding service running on port 8005.
|
||||
Provides text embedding generation for RAG pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BGE_M3_EmbeddingClient:
|
||||
"""
|
||||
Simple client for BGE-M3 embedding service via vLLM.
|
||||
|
||||
Features:
|
||||
- Async HTTP client for embeddings
|
||||
- Batch processing support
|
||||
- Error handling and retries
|
||||
- OpenAI-compatible API format
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str = None):
|
||||
# Determine base URL from environment or configuration
|
||||
if base_url is None:
|
||||
base_url = self._get_embedding_endpoint()
|
||||
|
||||
self.base_url = base_url
|
||||
self.model = "BAAI/bge-m3"
|
||||
self.embedding_dimensions = 1024
|
||||
self.max_batch_size = 32
|
||||
|
||||
# Initialize BGE-M3 tokenizer for accurate token counting
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
|
||||
logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
|
||||
self.tokenizer = None
|
||||
|
||||
def _get_embedding_endpoint(self) -> str:
|
||||
"""
|
||||
Get the BGE-M3 endpoint based on configuration.
|
||||
This should sync with the control panel configuration.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Check environment variables for BGE-M3 configuration
|
||||
is_local_mode = os.getenv('BGE_M3_LOCAL_MODE', 'true').lower() == 'true'
|
||||
external_endpoint = os.getenv('BGE_M3_EXTERNAL_ENDPOINT')
|
||||
|
||||
if not is_local_mode and external_endpoint:
|
||||
return external_endpoint
|
||||
|
||||
# Default to local endpoint
|
||||
return os.getenv('EMBEDDING_ENDPOINT', 'http://host.docker.internal:8005')
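# Example configuration (values are illustrative assumptions, not shipped defaults):
# set these in the tenant backend environment to switch the embedding client
# between the bundled local service and an external endpoint.
#
#     BGE_M3_LOCAL_MODE=false
#     BGE_M3_EXTERNAL_ENDPOINT=https://embeddings.example.internal:8443
#
# With BGE_M3_LOCAL_MODE left at "true", EMBEDDING_ENDPOINT (or the
# host.docker.internal default above) is used instead.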
|
||||
|
||||
def update_endpoint(self, new_endpoint: str):
|
||||
"""Update the embedding endpoint dynamically"""
|
||||
self.base_url = new_endpoint
|
||||
logger.info(f"BGE-M3 client endpoint updated to: {new_endpoint}")
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if BGE-M3 service is responding"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(f"{self.base_url}/v1/models")
|
||||
if response.status_code == 200:
|
||||
models = response.json()
|
||||
model_ids = [model['id'] for model in models.get('data', [])]
|
||||
return self.model in model_ids
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
async def generate_embeddings(
|
||||
self,
|
||||
texts: List[str],
|
||||
tenant_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of texts using BGE-M3.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
tenant_id: Tenant ID for usage tracking (optional)
|
||||
user_id: User ID for usage tracking (optional)
|
||||
request_id: Request ID for tracking (optional)
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each is a list of 1024 floats)
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding generation fails
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
if len(texts) > self.max_batch_size:
|
||||
# Process in batches
|
||||
all_embeddings = []
|
||||
for i in range(0, len(texts), self.max_batch_size):
|
||||
batch = texts[i:i + self.max_batch_size]
|
||||
batch_embeddings = await self._generate_batch(batch)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
embeddings = all_embeddings
|
||||
else:
|
||||
embeddings = await self._generate_batch(texts)
|
||||
|
||||
# Log usage if tenant context provided (fire and forget)
|
||||
if tenant_id and user_id:
|
||||
import asyncio
|
||||
tokens_used = self._count_tokens(texts)
|
||||
asyncio.create_task(
|
||||
self._log_embedding_usage(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
tokens_used=tokens_used,
|
||||
embedding_count=len(embeddings),
|
||||
request_id=request_id
|
||||
)
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def _generate_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for a single batch"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/v1/embeddings",
|
||||
json={
|
||||
"input": texts,
|
||||
"model": self.model
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
# Extract embeddings from OpenAI-compatible response
|
||||
embeddings = []
|
||||
for item in data.get("data", []):
|
||||
embedding = item.get("embedding", [])
|
||||
if len(embedding) != self.embedding_dimensions:
|
||||
raise ValueError(f"Invalid embedding dimensions: {len(embedding)} (expected {self.embedding_dimensions})")
|
||||
embeddings.append(embedding)
|
||||
|
||||
logger.info(f"Generated {len(embeddings)} embeddings")
|
||||
return embeddings
|
||||
else:
|
||||
error_text = response.text
|
||||
logger.error(f"Embedding generation failed: {response.status_code} - {error_text}")
|
||||
raise ValueError(f"Embedding generation failed: {response.status_code}")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Embedding generation timed out")
|
||||
raise ValueError("Embedding generation timed out")
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling embedding service: {e}")
|
||||
raise ValueError(f"Embedding service error: {str(e)}")
|
||||
|
||||
def _count_tokens(self, texts: List[str]) -> int:
|
||||
"""Count tokens using actual BGE-M3 tokenizer."""
|
||||
if self.tokenizer is not None:
|
||||
try:
|
||||
total_tokens = 0
|
||||
for text in texts:
|
||||
tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
total_tokens += len(tokens)
|
||||
return total_tokens
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer error, falling back to estimation: {e}")
|
||||
|
||||
# Fallback: word count * 1.3
|
||||
total_words = sum(len(text.split()) for text in texts)
|
||||
return int(total_words * 1.3)
|
||||
|
||||
async def _log_embedding_usage(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
tokens_used: int,
|
||||
embedding_count: int,
|
||||
request_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""Log embedding usage to control panel database for billing."""
|
||||
try:
|
||||
import asyncpg
|
||||
import os
|
||||
|
||||
# Calculate cost: BGE-M3 pricing ~$0.10 per million tokens
|
||||
cost_cents = (tokens_used / 1_000_000) * 0.10 * 100
|
||||
|
||||
db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
|
||||
if not db_password:
|
||||
logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
|
||||
return
|
||||
|
||||
conn = await asyncpg.connect(
|
||||
host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
|
||||
database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
|
||||
user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
|
||||
password=db_password,
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
try:
|
||||
await conn.execute("""
|
||||
INSERT INTO public.embedding_usage_logs
|
||||
(tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
""", tenant_id, user_id, tokens_used, embedding_count, self.model, cost_cents, request_id)
|
||||
|
||||
logger.info(
|
||||
f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
|
||||
f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
|
||||
)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to log embedding usage: {e}")
|
||||
|
||||
async def generate_single_embedding(self, text: str) -> List[float]:
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text string to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector (list of 1024 floats)
|
||||
"""
|
||||
embeddings = await self.generate_embeddings([text])
|
||||
return embeddings[0] if embeddings else []
|
||||
|
||||
|
||||
# Global client instance
|
||||
_embedding_client: Optional[BGE_M3_EmbeddingClient] = None
|
||||
|
||||
|
||||
def get_embedding_client() -> BGE_M3_EmbeddingClient:
|
||||
"""Get or create global embedding client instance"""
|
||||
global _embedding_client
|
||||
if _embedding_client is None:
|
||||
_embedding_client = BGE_M3_EmbeddingClient()
|
||||
else:
|
||||
# Always refresh the endpoint from current configuration
|
||||
current_endpoint = _embedding_client._get_embedding_endpoint()
|
||||
if _embedding_client.base_url != current_endpoint:
|
||||
_embedding_client.base_url = current_endpoint
|
||||
logger.info(f"BGE-M3 client endpoint refreshed to: {current_endpoint}")
|
||||
return _embedding_client
|
||||
|
||||
|
||||
async def test_embedding_client():
|
||||
"""Test function for the embedding client"""
|
||||
client = get_embedding_client()
|
||||
|
||||
# Test health check
|
||||
is_healthy = await client.health_check()
|
||||
print(f"BGE-M3 service healthy: {is_healthy}")
|
||||
|
||||
if is_healthy:
|
||||
# Test embedding generation
|
||||
test_texts = [
|
||||
"This is a test document about machine learning.",
|
||||
"GT 2.0 is an enterprise AI platform.",
|
||||
"Vector embeddings enable semantic search."
|
||||
]
|
||||
|
||||
embeddings = await client.generate_embeddings(test_texts)
|
||||
print(f"Generated {len(embeddings)} embeddings")
|
||||
print(f"Embedding dimensions: {len(embeddings[0]) if embeddings else 0}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_embedding_client())
|
||||
722
apps/tenant-backend/app/services/enhanced_api_keys.py
Normal file
@@ -0,0 +1,722 @@
|
||||
"""
|
||||
Enhanced API Key Management Service for GT 2.0
|
||||
|
||||
Implements advanced API key management with capability-based permissions,
|
||||
configurable constraints, and comprehensive audit logging.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
import json
|
||||
import secrets
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from uuid import uuid4
|
||||
import jwt
|
||||
|
||||
from app.core.security import verify_capability_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIKeyStatus(Enum):
|
||||
"""API key status states"""
|
||||
ACTIVE = "active"
|
||||
SUSPENDED = "suspended"
|
||||
EXPIRED = "expired"
|
||||
REVOKED = "revoked"
|
||||
|
||||
|
||||
class APIKeyScope(Enum):
|
||||
"""API key scope levels"""
|
||||
USER = "user" # User-specific operations
|
||||
TENANT = "tenant" # Tenant-wide operations
|
||||
ADMIN = "admin" # Administrative operations
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIKeyUsage:
|
||||
"""API key usage tracking"""
|
||||
requests_count: int = 0
|
||||
last_used: Optional[datetime] = None
|
||||
bytes_transferred: int = 0
|
||||
errors_count: int = 0
|
||||
rate_limit_hits: int = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
"requests_count": self.requests_count,
|
||||
"last_used": self.last_used.isoformat() if self.last_used else None,
|
||||
"bytes_transferred": self.bytes_transferred,
|
||||
"errors_count": self.errors_count,
|
||||
"rate_limit_hits": self.rate_limit_hits
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "APIKeyUsage":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
requests_count=data.get("requests_count", 0),
|
||||
last_used=datetime.fromisoformat(data["last_used"]) if data.get("last_used") else None,
|
||||
bytes_transferred=data.get("bytes_transferred", 0),
|
||||
errors_count=data.get("errors_count", 0),
|
||||
rate_limit_hits=data.get("rate_limit_hits", 0)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIKeyConfig:
|
||||
"""Enhanced API key configuration"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
owner_id: str = ""
|
||||
key_hash: str = ""
|
||||
|
||||
# Capability and permissions
|
||||
capabilities: List[str] = field(default_factory=list)
|
||||
scope: APIKeyScope = APIKeyScope.USER
|
||||
tenant_constraints: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Rate limiting and quotas
|
||||
rate_limit_per_hour: int = 1000
|
||||
daily_quota: int = 10000
|
||||
monthly_quota: int = 300000
|
||||
cost_limit_cents: int = 1000
|
||||
|
||||
# Resource constraints
|
||||
max_tokens_per_request: int = 4000
|
||||
max_concurrent_requests: int = 10
|
||||
allowed_endpoints: List[str] = field(default_factory=list)
|
||||
blocked_endpoints: List[str] = field(default_factory=list)
|
||||
|
||||
# Network and security
|
||||
allowed_ips: List[str] = field(default_factory=list)
|
||||
allowed_domains: List[str] = field(default_factory=list)
|
||||
require_tls: bool = True
|
||||
|
||||
# Lifecycle management
|
||||
status: APIKeyStatus = APIKeyStatus.ACTIVE
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = None
|
||||
last_rotated: Optional[datetime] = None
|
||||
|
||||
# Usage tracking
|
||||
usage: APIKeyUsage = field(default_factory=APIKeyUsage)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"owner_id": self.owner_id,
|
||||
"key_hash": self.key_hash,
|
||||
"capabilities": self.capabilities,
|
||||
"scope": self.scope.value,
|
||||
"tenant_constraints": self.tenant_constraints,
|
||||
"rate_limit_per_hour": self.rate_limit_per_hour,
|
||||
"daily_quota": self.daily_quota,
|
||||
"monthly_quota": self.monthly_quota,
|
||||
"cost_limit_cents": self.cost_limit_cents,
|
||||
"max_tokens_per_request": self.max_tokens_per_request,
|
||||
"max_concurrent_requests": self.max_concurrent_requests,
|
||||
"allowed_endpoints": self.allowed_endpoints,
|
||||
"blocked_endpoints": self.blocked_endpoints,
|
||||
"allowed_ips": self.allowed_ips,
|
||||
"allowed_domains": self.allowed_domains,
|
||||
"require_tls": self.require_tls,
|
||||
"status": self.status.value,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"last_rotated": self.last_rotated.isoformat() if self.last_rotated else None,
|
||||
"usage": self.usage.to_dict()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "APIKeyConfig":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
description=data.get("description", ""),
|
||||
owner_id=data["owner_id"],
|
||||
key_hash=data["key_hash"],
|
||||
capabilities=data.get("capabilities", []),
|
||||
scope=APIKeyScope(data.get("scope", "user")),
|
||||
tenant_constraints=data.get("tenant_constraints", {}),
|
||||
rate_limit_per_hour=data.get("rate_limit_per_hour", 1000),
|
||||
daily_quota=data.get("daily_quota", 10000),
|
||||
monthly_quota=data.get("monthly_quota", 300000),
|
||||
cost_limit_cents=data.get("cost_limit_cents", 1000),
|
||||
max_tokens_per_request=data.get("max_tokens_per_request", 4000),
|
||||
max_concurrent_requests=data.get("max_concurrent_requests", 10),
|
||||
allowed_endpoints=data.get("allowed_endpoints", []),
|
||||
blocked_endpoints=data.get("blocked_endpoints", []),
|
||||
allowed_ips=data.get("allowed_ips", []),
|
||||
allowed_domains=data.get("allowed_domains", []),
|
||||
require_tls=data.get("require_tls", True),
|
||||
status=APIKeyStatus(data.get("status", "active")),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
|
||||
last_rotated=datetime.fromisoformat(data["last_rotated"]) if data.get("last_rotated") else None,
|
||||
usage=APIKeyUsage.from_dict(data.get("usage", {}))
|
||||
)
|
||||
|
||||
|
||||
class EnhancedAPIKeyService:
|
||||
"""
|
||||
Enhanced API Key management service with advanced capabilities.
|
||||
|
||||
Features:
|
||||
- Capability-based permissions with tenant constraints
|
||||
- Granular rate limiting and quota management
|
||||
- Network-based access controls (IP, domain restrictions)
|
||||
- Comprehensive usage tracking and analytics
|
||||
- Automated key rotation and lifecycle management
|
||||
- Perfect tenant isolation through file-based storage
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, signing_key: str = ""):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.signing_key = signing_key or self._generate_signing_key()
|
||||
self.base_path = Path(f"/data/{tenant_domain}/api_keys")
|
||||
self.keys_path = self.base_path / "keys"
|
||||
self.usage_path = self.base_path / "usage"
|
||||
self.audit_path = self.base_path / "audit"
|
||||
|
||||
# Ensure directories exist with proper permissions
|
||||
self._ensure_directories()
|
||||
|
||||
logger.info(f"EnhancedAPIKeyService initialized for {tenant_domain}")
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure API key directories exist with proper permissions"""
|
||||
for path in [self.keys_path, self.usage_path, self.audit_path]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
# Set permissions to 700 (owner only)
|
||||
os.chmod(path, stat.S_IRWXU)
|
||||
|
||||
def _generate_signing_key(self) -> str:
|
||||
"""Generate cryptographic signing key for JWT tokens"""
|
||||
return secrets.token_urlsafe(64)
|
||||
|
||||
async def create_api_key(
|
||||
self,
|
||||
name: str,
|
||||
owner_id: str,
|
||||
capabilities: List[str],
|
||||
scope: APIKeyScope = APIKeyScope.USER,
|
||||
expires_in_days: int = 90,
|
||||
constraints: Optional[Dict[str, Any]] = None,
|
||||
capability_token: str = ""
|
||||
) -> Tuple[APIKeyConfig, str]:
|
||||
"""
|
||||
Create a new API key with specified capabilities and constraints.
|
||||
|
||||
Args:
|
||||
name: Human-readable name for the key
|
||||
owner_id: User who owns the key
|
||||
capabilities: List of capability strings
|
||||
scope: Key scope level
|
||||
expires_in_days: Expiration time in days
|
||||
constraints: Custom constraints for the key
|
||||
capability_token: Admin capability token
|
||||
|
||||
Returns:
|
||||
Tuple of (APIKeyConfig, raw_key)
|
||||
"""
|
||||
# Verify admin capability for key creation
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Generate secure API key
|
||||
raw_key = f"gt2_{self.tenant_domain}_{secrets.token_urlsafe(32)}"
|
||||
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
# Apply constraints with tenant-specific defaults
|
||||
final_constraints = self._apply_tenant_defaults(constraints or {})
|
||||
|
||||
# Create API key configuration
|
||||
api_key = APIKeyConfig(
|
||||
name=name,
|
||||
owner_id=owner_id,
|
||||
key_hash=key_hash,
|
||||
capabilities=capabilities,
|
||||
scope=scope,
|
||||
tenant_constraints=final_constraints,
|
||||
expires_at=datetime.utcnow() + timedelta(days=expires_in_days)
|
||||
)
|
||||
|
||||
# Apply scope-based defaults
|
||||
self._apply_scope_defaults(api_key, scope)
|
||||
|
||||
# Store API key
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log creation
|
||||
await self._audit_log("api_key_created", owner_id, {
|
||||
"key_id": api_key.id,
|
||||
"name": name,
|
||||
"scope": scope.value,
|
||||
"capabilities": capabilities
|
||||
})
|
||||
|
||||
logger.info(f"Created API key: {name} ({api_key.id}) for {owner_id}")
|
||||
return api_key, raw_key
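# Minimal usage sketch (illustrative only; the tenant, capability strings and the
# admin token below are placeholders, not values defined by this module):
#
#     service = EnhancedAPIKeyService("acme.example.com")
#     config, raw_key = await service.create_api_key(
#         name="ci-pipeline",
#         owner_id="user-123",
#         capabilities=["rag:query", "documents:read"],
#         scope=APIKeyScope.USER,
#         expires_in_days=30,
#         capability_token=admin_capability_token,
#     )
#     # raw_key is shown to the caller exactly once; only its SHA-256 hash is stored.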
|
||||
|
||||
async def validate_api_key(
|
||||
self,
|
||||
raw_key: str,
|
||||
endpoint: str = "",
|
||||
client_ip: str = "",
|
||||
user_agent: str = ""
|
||||
) -> Tuple[bool, Optional[APIKeyConfig], Optional[str]]:
|
||||
"""
|
||||
Validate API key and check constraints.
|
||||
|
||||
Args:
|
||||
raw_key: Raw API key from request
|
||||
endpoint: Requested endpoint
|
||||
client_ip: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
Tuple of (valid, api_key_config, error_message)
|
||||
"""
|
||||
# Hash the key for lookup
|
||||
# Security Note: SHA256 is used here for API key lookup/indexing, not password storage.
|
||||
# API keys are high-entropy random strings, making them resistant to dictionary/rainbow attacks.
|
||||
# This is an acceptable security pattern similar to how GitHub and Stripe handle API keys.
|
||||
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
# Load API key configuration
|
||||
api_key = await self._load_api_key_by_hash(key_hash)
|
||||
if not api_key:
|
||||
return False, None, "Invalid API key"
|
||||
|
||||
# Check key status
|
||||
if api_key.status != APIKeyStatus.ACTIVE:
|
||||
return False, api_key, f"API key is {api_key.status.value}"
|
||||
|
||||
# Check expiration
|
||||
if api_key.expires_at and datetime.utcnow() > api_key.expires_at:
|
||||
# Auto-expire the key
|
||||
api_key.status = APIKeyStatus.EXPIRED
|
||||
await self._store_api_key(api_key)
|
||||
return False, api_key, "API key has expired"
|
||||
|
||||
# Check endpoint restrictions
|
||||
if api_key.allowed_endpoints:
|
||||
if endpoint not in api_key.allowed_endpoints:
|
||||
return False, api_key, f"Endpoint {endpoint} not allowed"
|
||||
|
||||
if endpoint in api_key.blocked_endpoints:
|
||||
return False, api_key, f"Endpoint {endpoint} is blocked"
|
||||
|
||||
# Check IP restrictions
|
||||
if api_key.allowed_ips and client_ip not in api_key.allowed_ips:
|
||||
return False, api_key, f"IP {client_ip} not allowed"
|
||||
|
||||
# Check rate limits
|
||||
rate_limit_ok, rate_error = await self._check_rate_limits(api_key)
|
||||
if not rate_limit_ok:
|
||||
return False, api_key, rate_error
|
||||
|
||||
# Update usage
|
||||
await self._update_usage(api_key, endpoint, client_ip)
|
||||
|
||||
return True, api_key, None
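# Minimal request-handling sketch (illustrative; the header name and the
# FastAPI-style request object are assumptions, not part of this module):
#
#     valid, key_config, error = await service.validate_api_key(
#         raw_key=request.headers.get("X-API-Key", ""),
#         endpoint="/api/v1/rag/query",
#         client_ip=request.client.host,
#     )
#     if not valid:
#         raise PermissionError(error or "Invalid API key")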
|
||||
|
||||
async def generate_capability_token(
|
||||
self,
|
||||
api_key: APIKeyConfig,
|
||||
additional_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate JWT capability token from API key.
|
||||
|
||||
Args:
|
||||
api_key: API key configuration
|
||||
additional_context: Additional context for the token
|
||||
|
||||
Returns:
|
||||
JWT capability token
|
||||
"""
|
||||
# Build capability payload
|
||||
capabilities = []
|
||||
for cap_string in api_key.capabilities:
|
||||
capability = {
|
||||
"resource": cap_string,
|
||||
"actions": ["*"], # API keys get full action access for their capabilities
|
||||
"constraints": api_key.tenant_constraints.get(cap_string, {})
|
||||
}
|
||||
capabilities.append(capability)
|
||||
|
||||
# Create JWT payload
|
||||
payload = {
|
||||
"sub": api_key.owner_id,
|
||||
"tenant_id": self.tenant_domain,
|
||||
"api_key_id": api_key.id,
|
||||
"scope": api_key.scope.value,
|
||||
"capabilities": capabilities,
|
||||
"constraints": api_key.tenant_constraints,
|
||||
"rate_limits": {
|
||||
"requests_per_hour": api_key.rate_limit_per_hour,
|
||||
"max_tokens_per_request": api_key.max_tokens_per_request,
|
||||
"cost_limit_cents": api_key.cost_limit_cents
|
||||
},
|
||||
"iat": int(datetime.utcnow().timestamp()),
|
||||
"exp": int((datetime.utcnow() + timedelta(hours=1)).timestamp())
|
||||
}
|
||||
|
||||
# Add additional context
|
||||
if additional_context:
|
||||
payload.update(additional_context)
|
||||
|
||||
# Sign and return token
|
||||
token = jwt.encode(payload, self.signing_key, algorithm="HS256")
|
||||
return token
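# Verification sketch (illustrative): because the token is signed with HS256 and
# the service's own signing key, downstream code can verify and inspect it with
# the same PyJWT calls used above:
#
#     claims = jwt.decode(token, self.signing_key, algorithms=["HS256"])
#     assert claims["tenant_id"] == self.tenant_domain
#     # claims["capabilities"] carries the resource/action/constraint entries.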
|
||||
|
||||
async def rotate_api_key(
|
||||
self,
|
||||
key_id: str,
|
||||
owner_id: str,
|
||||
capability_token: str
|
||||
) -> Tuple[APIKeyConfig, str]:
|
||||
"""
|
||||
Rotate API key (generate new key value).
|
||||
|
||||
Args:
|
||||
key_id: API key ID to rotate
|
||||
owner_id: Owner of the key
|
||||
capability_token: Admin capability token
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_config, new_raw_key)
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load existing key
|
||||
api_key = await self._load_api_key(key_id)
|
||||
if not api_key:
|
||||
raise ValueError("API key not found")
|
||||
|
||||
# Verify ownership
|
||||
if api_key.owner_id != owner_id:
|
||||
raise PermissionError("Only key owner can rotate")
|
||||
|
||||
# Generate new key
|
||||
new_raw_key = f"gt2_{self.tenant_domain}_{secrets.token_urlsafe(32)}"
|
||||
new_key_hash = hashlib.sha256(new_raw_key.encode()).hexdigest()
|
||||
|
||||
# Update configuration
|
||||
api_key.key_hash = new_key_hash
|
||||
api_key.last_rotated = datetime.utcnow()
|
||||
api_key.updated_at = datetime.utcnow()
|
||||
|
||||
# Store updated key
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log rotation
|
||||
await self._audit_log("api_key_rotated", owner_id, {
|
||||
"key_id": key_id,
|
||||
"name": api_key.name
|
||||
})
|
||||
|
||||
logger.info(f"Rotated API key: {api_key.name} ({key_id})")
|
||||
return api_key, new_raw_key
|
||||
|
||||
async def revoke_api_key(
|
||||
self,
|
||||
key_id: str,
|
||||
owner_id: str,
|
||||
capability_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke API key (mark as revoked).
|
||||
|
||||
Args:
|
||||
key_id: API key ID to revoke
|
||||
owner_id: Owner of the key
|
||||
capability_token: Admin capability token
|
||||
|
||||
Returns:
|
||||
True if revoked successfully
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
# Load and verify key
|
||||
api_key = await self._load_api_key(key_id)
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
if api_key.owner_id != owner_id:
|
||||
raise PermissionError("Only key owner can revoke")
|
||||
|
||||
# Revoke key
|
||||
api_key.status = APIKeyStatus.REVOKED
|
||||
api_key.updated_at = datetime.utcnow()
|
||||
|
||||
# Store updated key
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log revocation
|
||||
await self._audit_log("api_key_revoked", owner_id, {
|
||||
"key_id": key_id,
|
||||
"name": api_key.name
|
||||
})
|
||||
|
||||
logger.info(f"Revoked API key: {api_key.name} ({key_id})")
|
||||
return True
|
||||
|
||||
async def list_user_api_keys(
|
||||
self,
|
||||
owner_id: str,
|
||||
capability_token: str,
|
||||
include_usage: bool = True
|
||||
) -> List[APIKeyConfig]:
|
||||
"""
|
||||
List API keys for a user.
|
||||
|
||||
Args:
|
||||
owner_id: User to get keys for
|
||||
capability_token: User capability token
|
||||
include_usage: Include usage statistics
|
||||
|
||||
Returns:
|
||||
List of API key configurations
|
||||
"""
|
||||
# Verify capability token
|
||||
token_data = verify_capability_token(capability_token)
|
||||
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
|
||||
raise PermissionError("Invalid capability token")
|
||||
|
||||
user_keys = []
|
||||
|
||||
# Load all keys and filter by owner
|
||||
if self.keys_path.exists():
|
||||
for key_file in self.keys_path.glob("*.json"):
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
data = json.load(f)
|
||||
if data.get("owner_id") == owner_id:
|
||||
api_key = APIKeyConfig.from_dict(data)
|
||||
|
||||
# Update usage if requested
|
||||
if include_usage:
|
||||
await self._update_key_usage_stats(api_key)
|
||||
|
||||
user_keys.append(api_key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading key file {key_file}: {e}")
|
||||
|
||||
return sorted(user_keys, key=lambda k: k.created_at, reverse=True)
|
||||
|
||||
async def get_usage_analytics(
|
||||
self,
|
||||
owner_id: str,
|
||||
key_id: Optional[str] = None,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get usage analytics for API keys.
|
||||
|
||||
Args:
|
||||
owner_id: Owner of the keys
|
||||
key_id: Specific key ID (optional)
|
||||
days: Number of days to analyze
|
||||
|
||||
Returns:
|
||||
Usage analytics data
|
||||
"""
|
||||
analytics = {
|
||||
"total_requests": 0,
|
||||
"total_errors": 0,
|
||||
"avg_requests_per_day": 0,
|
||||
"most_used_endpoints": [],
|
||||
"rate_limit_hits": 0,
|
||||
"keys_analyzed": 0,
|
||||
"date_range": {
|
||||
"start": (datetime.utcnow() - timedelta(days=days)).isoformat(),
|
||||
"end": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Get user's keys
|
||||
user_keys = await self.list_user_api_keys(owner_id, "", include_usage=True)
|
||||
|
||||
# Filter by specific key if requested
|
||||
if key_id:
|
||||
user_keys = [key for key in user_keys if key.id == key_id]
|
||||
|
||||
# Aggregate usage data
|
||||
for api_key in user_keys:
|
||||
analytics["total_requests"] += api_key.usage.requests_count
|
||||
analytics["total_errors"] += api_key.usage.errors_count
|
||||
analytics["rate_limit_hits"] += api_key.usage.rate_limit_hits
|
||||
analytics["keys_analyzed"] += 1
|
||||
|
||||
# Calculate averages
|
||||
if days > 0:
|
||||
analytics["avg_requests_per_day"] = analytics["total_requests"] / days
|
||||
|
||||
return analytics
|
||||
|
||||
def _apply_tenant_defaults(self, constraints: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply tenant-specific default constraints"""
|
||||
defaults = {
|
||||
"max_automation_chain_depth": 5,
|
||||
"mcp_memory_limit_mb": 512,
|
||||
"mcp_timeout_seconds": 30,
|
||||
"max_file_size_bytes": 10 * 1024 * 1024, # 10MB
|
||||
"allowed_file_types": [".pdf", ".txt", ".md", ".json", ".csv"],
|
||||
"enable_premium_features": False
|
||||
}
|
||||
|
||||
# Merge with provided constraints (provided values take precedence)
|
||||
final_constraints = defaults.copy()
|
||||
final_constraints.update(constraints)
|
||||
|
||||
return final_constraints
|
||||
|
||||
def _apply_scope_defaults(self, api_key: APIKeyConfig, scope: APIKeyScope):
|
||||
"""Apply scope-based default limits"""
|
||||
if scope == APIKeyScope.USER:
|
||||
api_key.rate_limit_per_hour = 1000
|
||||
api_key.daily_quota = 10000
|
||||
api_key.cost_limit_cents = 1000
|
||||
elif scope == APIKeyScope.TENANT:
|
||||
api_key.rate_limit_per_hour = 5000
|
||||
api_key.daily_quota = 50000
|
||||
api_key.cost_limit_cents = 5000
|
||||
elif scope == APIKeyScope.ADMIN:
|
||||
api_key.rate_limit_per_hour = 10000
|
||||
api_key.daily_quota = 100000
|
||||
api_key.cost_limit_cents = 10000
|
||||
|
||||
async def _check_rate_limits(self, api_key: APIKeyConfig) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if API key is within rate limits"""
|
||||
# For now, implement basic hourly check
|
||||
# In production, would check against usage tracking database
|
||||
|
||||
current_hour = datetime.utcnow().replace(minute=0, second=0, microsecond=0)
|
||||
|
||||
# Load hourly usage (mock implementation)
|
||||
hourly_usage = 0 # Would query actual usage data
|
||||
|
||||
if hourly_usage >= api_key.rate_limit_per_hour:
|
||||
api_key.usage.rate_limit_hits += 1
|
||||
await self._store_api_key(api_key)
|
||||
return False, f"Rate limit exceeded: {hourly_usage}/{api_key.rate_limit_per_hour} requests per hour"
|
||||
|
||||
return True, None
|
||||
|
||||
async def _update_usage(self, api_key: APIKeyConfig, endpoint: str, client_ip: str):
|
||||
"""Update API key usage statistics"""
|
||||
api_key.usage.requests_count += 1
|
||||
api_key.usage.last_used = datetime.utcnow()
|
||||
|
||||
# Store updated usage
|
||||
await self._store_api_key(api_key)
|
||||
|
||||
# Log detailed usage (for analytics)
|
||||
await self._log_usage(api_key.id, endpoint, client_ip)
|
||||
|
||||
async def _store_api_key(self, api_key: APIKeyConfig):
|
||||
"""Store API key configuration to file system"""
|
||||
key_file = self.keys_path / f"{api_key.id}.json"
|
||||
|
||||
with open(key_file, "w") as f:
|
||||
json.dump(api_key.to_dict(), f, indent=2)
|
||||
|
||||
# Set secure permissions
|
||||
os.chmod(key_file, stat.S_IRUSR | stat.S_IWUSR) # 600
|
||||
|
||||
async def _load_api_key(self, key_id: str) -> Optional[APIKeyConfig]:
|
||||
"""Load API key configuration by ID"""
|
||||
key_file = self.keys_path / f"{key_id}.json"
|
||||
|
||||
if not key_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
data = json.load(f)
|
||||
return APIKeyConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading API key {key_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _load_api_key_by_hash(self, key_hash: str) -> Optional[APIKeyConfig]:
|
||||
"""Load API key configuration by hash"""
|
||||
if not self.keys_path.exists():
|
||||
return None
|
||||
|
||||
for key_file in self.keys_path.glob("*.json"):
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
data = json.load(f)
|
||||
if data.get("key_hash") == key_hash:
|
||||
return APIKeyConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading key file {key_file}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _update_key_usage_stats(self, api_key: APIKeyConfig):
|
||||
"""Update comprehensive usage statistics for a key"""
|
||||
# In production, would aggregate from detailed usage logs
|
||||
# For now, use existing basic stats
|
||||
pass
|
||||
|
||||
async def _log_usage(self, key_id: str, endpoint: str, client_ip: str):
|
||||
"""Log detailed API key usage for analytics"""
|
||||
usage_record = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"key_id": key_id,
|
||||
"endpoint": endpoint,
|
||||
"client_ip": client_ip,
|
||||
"tenant": self.tenant_domain
|
||||
}
|
||||
|
||||
# Store in daily usage file
|
||||
date_str = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
usage_file = self.usage_path / f"usage_{date_str}.jsonl"
|
||||
|
||||
with open(usage_file, "a") as f:
|
||||
f.write(json.dumps(usage_record) + "\n")
|
||||
|
||||
async def _audit_log(self, action: str, user_id: str, details: Dict[str, Any]):
|
||||
"""Log API key management actions for audit"""
|
||||
audit_record = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"action": action,
|
||||
"user_id": user_id,
|
||||
"tenant": self.tenant_domain,
|
||||
"details": details
|
||||
}
|
||||
|
||||
# Store in daily audit file
|
||||
date_str = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
audit_file = self.audit_path / f"audit_{date_str}.jsonl"
|
||||
|
||||
with open(audit_file, "a") as f:
|
||||
f.write(json.dumps(audit_record) + "\n")
|
||||
635
apps/tenant-backend/app/services/event_bus.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""
|
||||
Tenant Event Bus System
|
||||
|
||||
Implements event-driven architecture for automation triggers with perfect
|
||||
tenant isolation and capability-based execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
from enum import Enum
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.path_security import sanitize_tenant_domain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerType(Enum):
|
||||
"""Types of automation triggers"""
|
||||
CRON = "cron" # Time-based
|
||||
WEBHOOK = "webhook" # External HTTP
|
||||
EVENT = "event" # Internal events
|
||||
CHAIN = "chain" # Triggered by other automations
|
||||
MANUAL = "manual" # User-initiated
|
||||
|
||||
|
# Event catalog with required fields
EVENT_CATALOG = {
    "document.uploaded": ["document_id", "dataset_id", "filename"],
    "document.processed": ["document_id", "chunks_created"],
    "agent.created": ["agent_id", "name", "owner_id"],
    "chat.started": ["conversation_id", "agent_id"],
    "resource.shared": ["resource_id", "access_group", "shared_with"],
    "quota.warning": ["resource_type", "current_usage", "limit"],
    "automation.completed": ["automation_id", "result", "duration_ms"],
    "automation.failed": ["automation_id", "error", "retry_count"]
}
|
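
# Illustrative helper (a minimal sketch, not part of the original commit): shows
# how EVENT_CATALOG can be used to check that an event payload carries the
# required fields before it is published.
def validate_event_payload(event_type: str, data: Dict[str, Any]) -> List[str]:
    """Return the list of required fields missing from an event payload."""
    required = EVENT_CATALOG.get(event_type, [])
    return [field_name for field_name in required if field_name not in data]

# Example (hypothetical values):
#     validate_event_payload("document.uploaded", {"document_id": "d1"})
#     # -> ["dataset_id", "filename"]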
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
"""Event data structure"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
type: str = ""
|
||||
tenant: str = ""
|
||||
user: str = ""
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert event to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.type,
|
||||
"tenant": self.tenant,
|
||||
"user": self.user,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"data": self.data,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Event":
|
||||
"""Create event from dictionary"""
|
||||
return cls(
|
||||
id=data.get("id", str(uuid4())),
|
||||
type=data.get("type", ""),
|
||||
tenant=data.get("tenant", ""),
|
||||
user=data.get("user", ""),
|
||||
timestamp=datetime.fromisoformat(data.get("timestamp", datetime.utcnow().isoformat())),
|
||||
data=data.get("data", {}),
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Automation:
|
||||
"""Automation configuration"""
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
name: str = ""
|
||||
owner_id: str = ""
|
||||
trigger_type: TriggerType = TriggerType.MANUAL
|
||||
trigger_config: Dict[str, Any] = field(default_factory=dict)
|
||||
conditions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
actions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
triggers_chain: bool = False
|
||||
chain_targets: List[str] = field(default_factory=list)
|
||||
max_retries: int = 3
|
||||
timeout_seconds: int = 300
|
||||
is_active: bool = True
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def matches_event(self, event: Event) -> bool:
|
||||
"""Check if automation should trigger for event"""
|
||||
if not self.is_active:
|
||||
return False
|
||||
|
||||
if self.trigger_type != TriggerType.EVENT:
|
||||
return False
|
||||
|
||||
# Check event type matches
|
||||
event_types = self.trigger_config.get("event_types", [])
|
||||
if event.type not in event_types:
|
||||
return False
|
||||
|
||||
# Check conditions
|
||||
for condition in self.conditions:
|
||||
if not self._evaluate_condition(condition, event):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _evaluate_condition(self, condition: Dict[str, Any], event: Event) -> bool:
|
||||
"""Evaluate a single condition"""
|
||||
field = condition.get("field")
|
||||
operator = condition.get("operator")
|
||||
value = condition.get("value")
|
||||
|
||||
# Get field value from event
|
||||
if "." in field:
|
||||
parts = field.split(".")
|
||||
# Handle data.field paths by starting from the event object
|
||||
if parts[0] == "data":
|
||||
event_value = event.data
|
||||
parts = parts[1:] # Skip the "data" part
|
||||
else:
|
||||
event_value = event
|
||||
|
||||
for part in parts:
|
||||
if isinstance(event_value, dict):
|
||||
event_value = event_value.get(part)
|
||||
elif hasattr(event_value, part):
|
||||
event_value = getattr(event_value, part)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
event_value = getattr(event, field, None)
|
||||
|
||||
# Evaluate condition
|
||||
if operator == "equals":
|
||||
return event_value == value
|
||||
elif operator == "not_equals":
|
||||
return event_value != value
|
||||
elif operator == "contains":
|
||||
return value in str(event_value)
|
||||
elif operator == "greater_than":
|
||||
return float(event_value) > float(value)
|
||||
elif operator == "less_than":
|
||||
return float(event_value) < float(value)
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class TenantEventBus:
|
||||
"""
|
||||
Event system for automation triggers with tenant isolation.
|
||||
|
||||
Features:
|
||||
- Perfect tenant isolation through file-based storage
|
||||
- Event persistence and replay capability
|
||||
- Automation matching and triggering
|
||||
- Access control for automation execution
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, base_path: Optional[Path] = None):
|
||||
self.tenant_domain = tenant_domain
|
||||
# Sanitize tenant_domain to prevent path traversal
|
||||
safe_tenant = sanitize_tenant_domain(tenant_domain)
|
||||
self.base_path = base_path or (Path("/data") / safe_tenant / "events")
|
||||
self.event_store_path = self.base_path / "store"
|
||||
self.automations_path = self.base_path / "automations"
|
||||
self.event_handlers: Dict[str, List[Callable]] = {}
|
||||
self.running_automations: Dict[str, asyncio.Task] = {}
|
||||
|
||||
# Ensure directories exist with proper permissions
|
||||
self._ensure_directories()
|
||||
|
||||
logger.info(f"TenantEventBus initialized for {tenant_domain}")
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure event directories exist with proper permissions"""
|
||||
import os
|
||||
import stat
|
||||
|
||||
for path in [self.event_store_path, self.automations_path]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
# Set permissions to 700 (owner only)
|
||||
# codeql[py/path-injection] paths derived from sanitize_tenant_domain() at line 175
|
||||
os.chmod(path, stat.S_IRWXU)
|
||||
|
||||
async def emit_event(
|
||||
self,
|
||||
event_type: str,
|
||||
user_id: str,
|
||||
data: Dict[str, Any],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Event:
|
||||
"""
|
||||
Emit an event and trigger matching automations.
|
||||
|
||||
Args:
|
||||
event_type: Type of event from EVENT_CATALOG
|
||||
user_id: User who triggered the event
|
||||
data: Event data
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
Created event
|
||||
"""
|
||||
# Validate event type
|
||||
if event_type not in EVENT_CATALOG:
|
||||
logger.warning(f"Unknown event type: {event_type}")
|
||||
|
||||
# Create event
|
||||
event = Event(
|
||||
type=event_type,
|
||||
tenant=self.tenant_domain,
|
||||
user=user_id,
|
||||
data=data,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# Store event
|
||||
await self._store_event(event)
|
||||
|
||||
# Find matching automations
|
||||
automations = await self._find_matching_automations(event)
|
||||
|
||||
# Trigger automations with access control
|
||||
for automation in automations:
|
||||
if await self._can_trigger(user_id, automation):
|
||||
asyncio.create_task(self._trigger_automation(automation, event))
|
||||
|
||||
# Call registered handlers
|
||||
if event_type in self.event_handlers:
|
||||
for handler in self.event_handlers[event_type]:
|
||||
asyncio.create_task(handler(event))
|
||||
|
||||
logger.info(f"Event emitted: {event_type} by {user_id}")
|
||||
return event
|
||||
|
||||
async def _store_event(self, event: Event):
|
||||
"""Store event to file system"""
|
||||
# Create daily event file
|
||||
date_str = event.timestamp.strftime("%Y-%m-%d")
|
||||
event_file = self.event_store_path / f"events_{date_str}.jsonl"
|
||||
|
||||
# Append event as JSON line
|
||||
with open(event_file, "a") as f:
|
||||
f.write(json.dumps(event.to_dict()) + "\n")
|
||||
|
||||
async def _find_matching_automations(self, event: Event) -> List[Automation]:
|
||||
"""Find automations that match the event"""
|
||||
matching = []
|
||||
|
||||
# Load all automations from file system
|
||||
if self.automations_path.exists():
|
||||
for automation_file in self.automations_path.glob("*.json"):
|
||||
try:
|
||||
with open(automation_file, "r") as f:
|
||||
automation_data = json.load(f)
# trigger_type is serialized as a string; convert it back so the
# enum comparison in matches_event() works
automation_data["trigger_type"] = TriggerType(automation_data["trigger_type"])
automation = Automation(**automation_data)
|
||||
|
||||
if automation.matches_event(event):
|
||||
matching.append(automation)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading automation {automation_file}: {e}")
|
||||
|
||||
return matching
|
||||
|
||||
async def _can_trigger(self, user_id: str, automation: Automation) -> bool:
|
||||
"""Check if user can trigger automation"""
|
||||
# Owner can always trigger their automations
|
||||
if automation.owner_id == user_id:
|
||||
return True
|
||||
|
||||
# Check if automation is public or shared
|
||||
# This would integrate with AccessController
|
||||
# For now, only owner can trigger
|
||||
return False
|
||||
|
||||
async def _trigger_automation(self, automation: Automation, event: Event):
|
||||
"""Trigger automation execution"""
|
||||
try:
|
||||
# Check if automation is already running
|
||||
if automation.id in self.running_automations:
|
||||
logger.info(f"Automation {automation.id} already running, skipping")
|
||||
return
|
||||
|
||||
# Create task for automation execution
|
||||
task = asyncio.create_task(
|
||||
self._execute_automation(automation, event)
|
||||
)
|
||||
self.running_automations[automation.id] = task
|
||||
|
||||
# Wait for completion with timeout
|
||||
await asyncio.wait_for(task, timeout=automation.timeout_seconds)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Automation {automation.id} timed out")
|
||||
await self.emit_event(
|
||||
"automation.failed",
|
||||
automation.owner_id,
|
||||
{
|
||||
"automation_id": automation.id,
|
||||
"error": "Timeout",
|
||||
"retry_count": 0
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering automation {automation.id}: {e}")
|
||||
await self.emit_event(
|
||||
"automation.failed",
|
||||
automation.owner_id,
|
||||
{
|
||||
"automation_id": automation.id,
|
||||
"error": str(e),
|
||||
"retry_count": 0
|
||||
}
|
||||
)
|
||||
finally:
|
||||
# Remove from running automations
|
||||
self.running_automations.pop(automation.id, None)
|
||||
|
||||
async def _execute_automation(self, automation: Automation, event: Event) -> Any:
|
||||
"""Execute automation actions"""
|
||||
start_time = datetime.utcnow()
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Execute each action in sequence
|
||||
for action in automation.actions:
|
||||
result = await self._execute_action(action, event, automation)
|
||||
results.append(result)
|
||||
|
||||
# Calculate duration
|
||||
duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
# Emit completion event
|
||||
await self.emit_event(
|
||||
"automation.completed",
|
||||
automation.owner_id,
|
||||
{
|
||||
"automation_id": automation.id,
|
||||
"result": results,
|
||||
"duration_ms": duration_ms
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing automation {automation.id}: {e}")
|
||||
raise
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
automation: Automation
|
||||
) -> Any:
|
||||
"""Execute a single action"""
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "webhook":
|
||||
return await self._execute_webhook_action(action, event)
|
||||
elif action_type == "email":
|
||||
return await self._execute_email_action(action, event)
|
||||
elif action_type == "log":
|
||||
return await self._execute_log_action(action, event)
|
||||
elif action_type == "chain":
|
||||
return await self._execute_chain_action(action, event, automation)
|
||||
else:
|
||||
logger.warning(f"Unknown action type: {action_type}")
|
||||
return None
|
||||
|
||||
async def _execute_webhook_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute webhook action (mock implementation)"""
|
||||
url = action.get("url")
|
||||
method = action.get("method", "POST")
|
||||
headers = action.get("headers", {})
|
||||
body = action.get("body", event.to_dict())
|
||||
|
||||
logger.info(f"Mock webhook call to {url}")
|
||||
|
||||
# In production, use httpx or aiohttp to make actual HTTP request
|
||||
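# A minimal sketch of what the real call might look like with httpx (assumed
# dependency, not wired in here); the mock below simply echoes the request:
#
#   async with httpx.AsyncClient(timeout=10.0) as client:
#       response = await client.request(method, url, headers=headers, json=body)
#       return {"status": "success", "status_code": response.status_code}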
return {
|
||||
"status": "success",
|
||||
"url": url,
|
||||
"method": method,
|
||||
"mock": True
|
||||
}
|
||||
|
||||
async def _execute_email_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute email action (mock implementation)"""
|
||||
to = action.get("to")
|
||||
subject = action.get("subject")
|
||||
body = action.get("body")
|
||||
|
||||
logger.info(f"Mock email to {to}: {subject}")
|
||||
|
||||
# In production, integrate with email service
|
||||
return {
|
||||
"status": "success",
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"mock": True
|
||||
}
|
||||
|
||||
async def _execute_log_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute log action"""
|
||||
message = action.get("message", f"Event: {event.type}")
|
||||
level = action.get("level", "info")
|
||||
|
||||
if level == "debug":
|
||||
logger.debug(message)
|
||||
elif level == "info":
|
||||
logger.info(message)
|
||||
elif level == "warning":
|
||||
logger.warning(message)
|
||||
elif level == "error":
|
||||
logger.error(message)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": message,
|
||||
"level": level
|
||||
}
|
||||
|
||||
async def _execute_chain_action(
|
||||
self,
|
||||
action: Dict[str, Any],
|
||||
event: Event,
|
||||
automation: Automation
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute chain action to trigger other automations"""
|
||||
target_automation_id = action.get("target_automation_id")
|
||||
|
||||
if not target_automation_id:
|
||||
return {"status": "error", "message": "No target automation specified"}
|
||||
|
||||
# Emit chain event
|
||||
chain_event = await self.emit_event(
|
||||
"automation.chain",
|
||||
automation.owner_id,
|
||||
{
|
||||
"source_automation": automation.id,
|
||||
"target_automation": target_automation_id,
|
||||
"original_event": event.to_dict()
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"chain_event_id": chain_event.id,
|
||||
"target_automation": target_automation_id
|
||||
}
|
||||
|
||||
def register_handler(self, event_type: str, handler: Callable):
|
||||
"""Register an event handler"""
|
||||
if event_type not in self.event_handlers:
|
||||
self.event_handlers[event_type] = []
|
||||
self.event_handlers[event_type].append(handler)
|
||||
|
||||
async def create_automation(
|
||||
self,
|
||||
name: str,
|
||||
owner_id: str,
|
||||
trigger_type: TriggerType,
|
||||
trigger_config: Dict[str, Any],
|
||||
actions: List[Dict[str, Any]],
|
||||
conditions: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Automation:
|
||||
"""Create and save a new automation"""
|
||||
automation = Automation(
|
||||
name=name,
|
||||
owner_id=owner_id,
|
||||
trigger_type=trigger_type,
|
||||
trigger_config=trigger_config,
|
||||
actions=actions,
|
||||
conditions=conditions or []
|
||||
)
|
||||
|
||||
# Save to file system
|
||||
automation_file = self.automations_path / f"{automation.id}.json"
|
||||
with open(automation_file, "w") as f:
|
||||
json.dump({
|
||||
"id": automation.id,
|
||||
"name": automation.name,
|
||||
"owner_id": automation.owner_id,
|
||||
"trigger_type": automation.trigger_type.value,
|
||||
"trigger_config": automation.trigger_config,
|
||||
"conditions": automation.conditions,
|
||||
"actions": automation.actions,
|
||||
"triggers_chain": automation.triggers_chain,
|
||||
"chain_targets": automation.chain_targets,
|
||||
"max_retries": automation.max_retries,
|
||||
"timeout_seconds": automation.timeout_seconds,
|
||||
"is_active": automation.is_active,
|
||||
"created_at": automation.created_at.isoformat(),
|
||||
"updated_at": automation.updated_at.isoformat()
|
||||
}, f, indent=2)
|
||||
|
||||
logger.info(f"Created automation: {automation.name} ({automation.id})")
|
||||
return automation
|
||||
|
||||
async def get_automation(self, automation_id: str) -> Optional[Automation]:
|
||||
"""Get automation by ID"""
|
||||
automation_file = self.automations_path / f"{automation_id}.json"
|
||||
|
||||
if not automation_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(automation_file, "r") as f:
|
||||
data = json.load(f)
|
||||
data["trigger_type"] = TriggerType(data["trigger_type"])
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
|
||||
return Automation(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading automation {automation_id}: {e}")
|
||||
return None
|
||||
|
||||
async def list_automations(self, owner_id: Optional[str] = None) -> List[Automation]:
|
||||
"""List all automations, optionally filtered by owner"""
|
||||
automations = []
|
||||
|
||||
if self.automations_path.exists():
|
||||
for automation_file in self.automations_path.glob("*.json"):
|
||||
try:
|
||||
with open(automation_file, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Filter by owner if specified
|
||||
if owner_id and data.get("owner_id") != owner_id:
|
||||
continue
|
||||
|
||||
data["trigger_type"] = TriggerType(data["trigger_type"])
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
|
||||
automations.append(Automation(**data))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading automation {automation_file}: {e}")
|
||||
|
||||
return automations
|
||||
|
||||
async def delete_automation(self, automation_id: str, owner_id: str) -> bool:
|
||||
"""Delete an automation"""
|
||||
automation = await self.get_automation(automation_id)
|
||||
|
||||
if not automation:
|
||||
return False
|
||||
|
||||
# Check ownership
|
||||
if automation.owner_id != owner_id:
|
||||
logger.warning(f"User {owner_id} attempted to delete automation owned by {automation.owner_id}")
|
||||
return False
|
||||
|
||||
# Delete file
|
||||
automation_file = self.automations_path / f"{automation_id}.json"
|
||||
automation_file.unlink()
|
||||
|
||||
logger.info(f"Deleted automation: {automation_id}")
|
||||
return True
|
||||
|
||||
async def get_event_history(
|
||||
self,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
event_type: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""Get event history with optional filters"""
|
||||
events = []
|
||||
|
||||
# Determine date range
|
||||
if not end_date:
|
||||
end_date = datetime.utcnow()
|
||||
if not start_date:
|
||||
start_date = end_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# Iterate through daily event files
|
||||
current_date = start_date
|
||||
while current_date <= end_date:
|
||||
date_str = current_date.strftime("%Y-%m-%d")
|
||||
event_file = self.event_store_path / f"events_{date_str}.jsonl"
|
||||
|
||||
if event_file.exists():
|
||||
with open(event_file, "r") as f:
|
||||
for line in f:
|
||||
try:
|
||||
event_data = json.loads(line)
|
||||
event = Event.from_dict(event_data)
|
||||
|
||||
# Apply filters
|
||||
if event_type and event.type != event_type:
|
||||
continue
|
||||
if user_id and event.user != user_id:
|
||||
continue
|
||||
|
||||
events.append(event)
|
||||
|
||||
if len(events) >= limit:
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing event: {e}")
|
||||
|
||||
# Move to the next day (day arithmetic via replace() breaks at month boundaries)
current_date = current_date + timedelta(days=1)
|
||||
|
||||
return events
|
||||
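To show how the pieces above fit together, a minimal usage sketch (illustrative identifiers; assumes an async context and an already-provisioned tenant data directory):

async def example_event_bus_usage():
    bus = TenantEventBus("acme.example.com")

    # Automation that fires whenever a document finishes processing
    await bus.create_automation(
        name="log-processed-docs",
        owner_id="user-1",
        trigger_type=TriggerType.EVENT,
        trigger_config={"event_types": ["document.processed"]},
        actions=[{"type": "log", "message": "Document processed", "level": "info"}],
    )

    # Emitting a matching event stores it and triggers the automation asynchronously
    await bus.emit_event(
        event_type="document.processed",
        user_id="user-1",
        data={"document_id": "doc-123", "chunks_created": 42},
    )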
869
apps/tenant-backend/app/services/event_service.py
Normal file
@@ -0,0 +1,869 @@
|
||||
"""
|
||||
Event Automation Service for GT 2.0 Tenant Backend
|
||||
|
||||
Handles event-driven automation workflows including:
|
||||
- Document processing triggers
|
||||
- Conversation events
|
||||
- RAG pipeline automation
|
||||
- Agent lifecycle events
|
||||
- User activity tracking
|
||||
|
||||
Perfect tenant isolation with zero downtime compliance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, asdict
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, desc
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.database import get_db_session
|
||||
from app.core.config import get_settings
|
||||
from app.models.event import Event, EventTrigger, EventAction, EventSubscription
|
||||
from app.services.rag_service import RAGService
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.assistant_manager import AssistantManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
"""Event types for automation triggers"""
|
||||
DOCUMENT_UPLOADED = "document.uploaded"
|
||||
DOCUMENT_PROCESSED = "document.processed"
|
||||
DOCUMENT_FAILED = "document.failed"
|
||||
CONVERSATION_STARTED = "conversation.started"
|
||||
MESSAGE_SENT = "message.sent"
|
||||
ASSISTANT_CREATED = "agent.created"
|
||||
RAG_SEARCH_PERFORMED = "rag.search_performed"
|
||||
USER_LOGIN = "user.login"
|
||||
USER_ACTIVITY = "user.activity"
|
||||
SYSTEM_HEALTH_CHECK = "system.health_check"
|
||||
TEAM_INVITATION_CREATED = "team.invitation.created"
|
||||
TEAM_OBSERVABLE_REQUESTED = "team.observable_requested"
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
"""Action types for event responses"""
|
||||
PROCESS_DOCUMENT = "process_document"
|
||||
SEND_NOTIFICATION = "send_notification"
|
||||
UPDATE_STATISTICS = "update_statistics"
|
||||
TRIGGER_RAG_INDEXING = "trigger_rag_indexing"
|
||||
LOG_ANALYTICS = "log_analytics"
|
||||
EXECUTE_WEBHOOK = "execute_webhook"
|
||||
CREATE_ASSISTANT = "create_assistant"
|
||||
SCHEDULE_TASK = "schedule_task"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventPayload:
|
||||
"""Event payload structure"""
|
||||
event_id: str
|
||||
event_type: EventType
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
timestamp: datetime
|
||||
data: Dict[str, Any]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventActionConfig:
|
||||
"""Configuration for event actions"""
|
||||
action_type: ActionType
|
||||
config: Dict[str, Any]
|
||||
delay_seconds: int = 0
|
||||
retry_count: int = 3
|
||||
retry_delay: int = 60
|
||||
condition: Optional[str] = None # Python expression for conditional execution
|
||||
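# Example condition (illustrative): "data.get('filename', '').endswith('.pdf')",
# evaluated against the context built in _should_execute_action below.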
|
||||
|
||||
class EventService:
|
||||
"""
|
||||
Event automation service with perfect tenant isolation.
|
||||
|
||||
GT 2.0 Security Principles:
|
||||
- Perfect tenant isolation (all events user-scoped)
|
||||
- Zero downtime compliance (async processing)
|
||||
- Self-contained automation (no external dependencies)
|
||||
- Stateless event processing
|
||||
"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.settings = get_settings()
|
||||
self.rag_service = RAGService(db)
|
||||
self.conversation_service = ConversationService(db)
|
||||
self.assistant_manager = AssistantManager(db)
|
||||
|
||||
# Event handlers registry
|
||||
self.action_handlers: Dict[ActionType, Callable] = {
|
||||
ActionType.PROCESS_DOCUMENT: self._handle_process_document,
|
||||
ActionType.SEND_NOTIFICATION: self._handle_send_notification,
|
||||
ActionType.UPDATE_STATISTICS: self._handle_update_statistics,
|
||||
ActionType.TRIGGER_RAG_INDEXING: self._handle_trigger_rag_indexing,
|
||||
ActionType.LOG_ANALYTICS: self._handle_log_analytics,
|
||||
ActionType.EXECUTE_WEBHOOK: self._handle_execute_webhook,
|
||||
ActionType.CREATE_ASSISTANT: self._handle_create_assistant,
|
||||
ActionType.SCHEDULE_TASK: self._handle_schedule_task,
|
||||
}
|
||||
|
||||
# Active event subscriptions cache
|
||||
self.subscriptions_cache: Dict[str, List[EventSubscription]] = {}
|
||||
self.cache_expiry: Optional[datetime] = None
|
||||
|
||||
logger.info("Event automation service initialized with tenant isolation")
|
||||
|
||||
async def emit_event(
|
||||
self,
|
||||
event_type: EventType,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
data: Dict[str, Any],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""Emit an event and trigger automated actions"""
|
||||
try:
|
||||
# Create event payload
|
||||
event_payload = EventPayload(
|
||||
event_id=str(uuid.uuid4()),
|
||||
event_type=event_type,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
timestamp=datetime.utcnow(),
|
||||
data=data,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# Store event in database
|
||||
event_record = Event(
|
||||
event_id=event_payload.event_id,
|
||||
event_type=event_type.value,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
payload=event_payload.to_dict(),
|
||||
status="processing"
|
||||
)
|
||||
|
||||
self.db.add(event_record)
|
||||
await self.db.commit()
|
||||
|
||||
# Process event asynchronously
|
||||
asyncio.create_task(self._process_event(event_payload))
|
||||
|
||||
logger.info(f"Event emitted: {event_type.value} for user {user_id}")
|
||||
return event_payload.event_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to emit event {event_type.value}: {e}")
|
||||
raise
|
||||
|
||||
async def _process_event(self, event_payload: EventPayload) -> None:
|
||||
"""Process event and execute matching actions"""
|
||||
try:
|
||||
# Get subscriptions for this event type
|
||||
subscriptions = await self._get_event_subscriptions(
|
||||
event_payload.event_type,
|
||||
event_payload.user_id,
|
||||
event_payload.tenant_id
|
||||
)
|
||||
|
||||
if not subscriptions:
|
||||
logger.debug(f"No subscriptions found for event {event_payload.event_type}")
|
||||
return
|
||||
|
||||
# Execute actions for each subscription
|
||||
for subscription in subscriptions:
|
||||
try:
|
||||
await self._execute_subscription_actions(subscription, event_payload)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute subscription {subscription.id}: {e}")
|
||||
continue
|
||||
|
||||
# Update event status
|
||||
await self._update_event_status(event_payload.event_id, "completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process event {event_payload.event_id}: {e}")
|
||||
await self._update_event_status(event_payload.event_id, "failed", str(e))
|
||||
|
||||
async def _get_event_subscriptions(
|
||||
self,
|
||||
event_type: EventType,
|
||||
user_id: str,
|
||||
tenant_id: str
|
||||
) -> List[EventSubscription]:
|
||||
"""Get active subscriptions for event type with tenant isolation"""
|
||||
try:
|
||||
# Check cache first
|
||||
cache_key = f"{tenant_id}:{user_id}:{event_type.value}"
|
||||
if (self.cache_expiry and datetime.utcnow() < self.cache_expiry and
|
||||
cache_key in self.subscriptions_cache):
|
||||
return self.subscriptions_cache[cache_key]
|
||||
|
||||
# Query database
|
||||
query = select(EventSubscription).where(
|
||||
and_(
|
||||
EventSubscription.event_type == event_type.value,
|
||||
EventSubscription.user_id == user_id,
|
||||
EventSubscription.tenant_id == tenant_id,
|
||||
EventSubscription.is_active == True
|
||||
)
|
||||
).options(selectinload(EventSubscription.actions))
|
||||
|
||||
result = await self.db.execute(query)
|
||||
subscriptions = result.scalars().all()
|
||||
|
||||
# Cache results
|
||||
self.subscriptions_cache[cache_key] = list(subscriptions)
|
||||
self.cache_expiry = datetime.utcnow() + timedelta(minutes=5)
|
||||
|
||||
return list(subscriptions)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get event subscriptions: {e}")
|
||||
return []
|
||||
|
||||
async def _execute_subscription_actions(
|
||||
self,
|
||||
subscription: EventSubscription,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Execute all actions for a subscription"""
|
||||
try:
|
||||
for action in subscription.actions:
|
||||
# Check if action should be executed
|
||||
if not await self._should_execute_action(action, event_payload):
|
||||
continue
|
||||
|
||||
# Add delay if specified
|
||||
if action.delay_seconds > 0:
|
||||
await asyncio.sleep(action.delay_seconds)
|
||||
|
||||
# Execute action with retry logic
|
||||
await self._execute_action_with_retry(action, event_payload)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute subscription actions: {e}")
|
||||
raise
|
||||
|
||||
async def _should_execute_action(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> bool:
|
||||
"""Check if action should be executed based on conditions"""
|
||||
try:
|
||||
if not action.condition:
|
||||
return True
|
||||
|
||||
# Create evaluation context
|
||||
context = {
|
||||
'event': event_payload.to_dict(),
|
||||
'data': event_payload.data,
|
||||
'user_id': event_payload.user_id,
|
||||
'tenant_id': event_payload.tenant_id,
|
||||
'event_type': event_payload.event_type.value
|
||||
}
|
||||
|
||||
# Evaluate the condition expression with builtins removed (restricted eval)
|
||||
try:
|
||||
result = eval(action.condition, {"__builtins__": {}}, context)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to evaluate action condition: {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking action condition: {e}")
|
||||
return False
|
||||
|
||||
async def _execute_action_with_retry(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Execute action with retry logic"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(action.retry_count + 1):
|
||||
try:
|
||||
await self._execute_action(action, event_payload)
|
||||
return # Success
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(f"Action execution attempt {attempt + 1} failed: {e}")
|
||||
|
||||
if attempt < action.retry_count:
|
||||
await asyncio.sleep(action.retry_delay)
|
||||
else:
|
||||
logger.error(f"Action execution failed after {action.retry_count + 1} attempts")
|
||||
raise last_error
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Execute a specific action"""
|
||||
try:
|
||||
action_type = ActionType(action.action_type)
|
||||
handler = self.action_handlers.get(action_type)
|
||||
|
||||
if not handler:
|
||||
raise ValueError(f"No handler for action type: {action_type}")
|
||||
|
||||
await handler(action, event_payload)
|
||||
|
||||
logger.debug(f"Action executed: {action_type.value} for event {event_payload.event_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute action {action.action_type}: {e}")
|
||||
raise
|
||||
|
||||
# Action Handlers
|
||||
|
||||
async def _handle_process_document(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle document processing automation"""
|
||||
try:
|
||||
document_id = event_payload.data.get("document_id")
|
||||
if not document_id:
|
||||
raise ValueError("document_id required for process_document action")
|
||||
|
||||
chunking_strategy = action.config.get("chunking_strategy", "hybrid")
|
||||
|
||||
result = await self.rag_service.process_document(
|
||||
user_id=event_payload.user_id,
|
||||
document_id=document_id,
|
||||
tenant_id=event_payload.tenant_id,
|
||||
chunking_strategy=chunking_strategy
|
||||
)
|
||||
|
||||
# Emit processing completed event
|
||||
await self.emit_event(
|
||||
EventType.DOCUMENT_PROCESSED,
|
||||
event_payload.user_id,
|
||||
event_payload.tenant_id,
|
||||
{
|
||||
"document_id": document_id,
|
||||
"chunk_count": result.get("chunk_count", 0),
|
||||
"processing_result": result
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Emit processing failed event
|
||||
await self.emit_event(
|
||||
EventType.DOCUMENT_FAILED,
|
||||
event_payload.user_id,
|
||||
event_payload.tenant_id,
|
||||
{
|
||||
"document_id": event_payload.data.get("document_id"),
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
|
||||
async def _handle_send_notification(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle notification sending"""
|
||||
try:
|
||||
notification_type = action.config.get("type", "system")
|
||||
message = action.config.get("message", "Event notification")
|
||||
|
||||
# Format message with event data
|
||||
formatted_message = message.format(**event_payload.data)
|
||||
|
||||
# Store notification (implement notification system later)
|
||||
notification_data = {
|
||||
"type": notification_type,
|
||||
"message": formatted_message,
|
||||
"user_id": event_payload.user_id,
|
||||
"event_id": event_payload.event_id,
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Notification sent: {formatted_message} to user {event_payload.user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send notification: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_update_statistics(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle statistics updates"""
|
||||
try:
|
||||
stat_type = action.config.get("type")
|
||||
increment = action.config.get("increment", 1)
|
||||
|
||||
# Update user statistics (implement statistics system later)
|
||||
logger.info(f"Statistics updated: {stat_type} += {increment} for user {event_payload.user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update statistics: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_trigger_rag_indexing(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle RAG reindexing automation"""
|
||||
try:
|
||||
dataset_ids = action.config.get("dataset_ids", [])
|
||||
|
||||
if not dataset_ids:
|
||||
# Get all user datasets
|
||||
datasets = await self.rag_service.list_user_datasets(event_payload.user_id)
|
||||
dataset_ids = [d.id for d in datasets]
|
||||
|
||||
for dataset_id in dataset_ids:
|
||||
# Trigger reindexing for dataset
|
||||
logger.info(f"RAG reindexing triggered for dataset {dataset_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger RAG indexing: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_log_analytics(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle analytics logging"""
|
||||
try:
|
||||
analytics_data = {
|
||||
"event_type": event_payload.event_type.value,
|
||||
"user_id": event_payload.user_id,
|
||||
"tenant_id": event_payload.tenant_id,
|
||||
"timestamp": event_payload.timestamp.isoformat(),
|
||||
"data": event_payload.data,
|
||||
"custom_properties": action.config.get("properties", {})
|
||||
}
|
||||
|
||||
# Log analytics (implement analytics system later)
|
||||
logger.info(f"Analytics logged: {analytics_data}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log analytics: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_execute_webhook(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle webhook execution"""
|
||||
try:
|
||||
webhook_url = action.config.get("url")
|
||||
method = action.config.get("method", "POST")
|
||||
headers = action.config.get("headers", {})
|
||||
|
||||
if not webhook_url:
|
||||
raise ValueError("webhook url required")
|
||||
|
||||
# Prepare webhook payload
|
||||
webhook_payload = {
|
||||
"event": event_payload.to_dict(),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Execute webhook (implement HTTP client later)
|
||||
logger.info(f"Webhook executed: {method} {webhook_url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute webhook: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_create_assistant(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle automatic agent creation"""
|
||||
try:
|
||||
template_id = action.config.get("template_id", "general_assistant")
|
||||
assistant_name = action.config.get("name", "Auto-created Agent")
|
||||
|
||||
# Create agent
|
||||
agent_id = await self.assistant_manager.create_from_template(
|
||||
template_id=template_id,
|
||||
config={"name": assistant_name},
|
||||
user_identifier=event_payload.user_id
|
||||
)
|
||||
|
||||
# Emit agent created event
|
||||
await self.emit_event(
|
||||
EventType.ASSISTANT_CREATED,
|
||||
event_payload.user_id,
|
||||
event_payload.tenant_id,
|
||||
{
|
||||
"agent_id": agent_id,
|
||||
"template_id": template_id,
|
||||
"trigger_event": event_payload.event_id
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_schedule_task(
|
||||
self,
|
||||
action: EventAction,
|
||||
event_payload: EventPayload
|
||||
) -> None:
|
||||
"""Handle task scheduling"""
|
||||
try:
|
||||
task_type = action.config.get("task_type")
|
||||
delay_minutes = action.config.get("delay_minutes", 0)
|
||||
|
||||
# Schedule task for future execution
|
||||
scheduled_time = datetime.utcnow() + timedelta(minutes=delay_minutes)
|
||||
|
||||
logger.info(f"Task scheduled: {task_type} for {scheduled_time}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to schedule task: {e}")
|
||||
raise
|
||||
|
||||
# Subscription Management
|
||||
|
||||
async def create_subscription(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
event_type: EventType,
|
||||
actions: List[EventActionConfig],
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None
|
||||
) -> str:
|
||||
"""Create an event subscription"""
|
||||
try:
|
||||
subscription = EventSubscription(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
event_type=event_type.value,
|
||||
name=name or f"Auto-subscription for {event_type.value}",
|
||||
description=description,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
self.db.add(subscription)
|
||||
await self.db.flush()
|
||||
|
||||
# Create actions
|
||||
for action_config in actions:
|
||||
action = EventAction(
|
||||
subscription_id=subscription.id,
|
||||
action_type=action_config.action_type.value,
|
||||
config=action_config.config,
|
||||
delay_seconds=action_config.delay_seconds,
|
||||
retry_count=action_config.retry_count,
|
||||
retry_delay=action_config.retry_delay,
|
||||
condition=action_config.condition
|
||||
)
|
||||
self.db.add(action)
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
# Clear subscriptions cache
|
||||
self._clear_subscriptions_cache()
|
||||
|
||||
logger.info(f"Event subscription created: {subscription.id} for {event_type.value}")
|
||||
return subscription.id
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to create subscription: {e}")
|
||||
raise
|
||||
|
||||
async def get_user_subscriptions(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str
|
||||
) -> List[EventSubscription]:
|
||||
"""Get all subscriptions for a user"""
|
||||
try:
|
||||
query = select(EventSubscription).where(
|
||||
and_(
|
||||
EventSubscription.user_id == user_id,
|
||||
EventSubscription.tenant_id == tenant_id
|
||||
)
|
||||
).options(selectinload(EventSubscription.actions))
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user subscriptions: {e}")
|
||||
raise
|
||||
|
||||
async def update_subscription_status(
|
||||
self,
|
||||
subscription_id: str,
|
||||
user_id: str,
|
||||
is_active: bool
|
||||
) -> bool:
|
||||
"""Update subscription status with ownership verification"""
|
||||
try:
|
||||
query = select(EventSubscription).where(
|
||||
and_(
|
||||
EventSubscription.id == subscription_id,
|
||||
EventSubscription.user_id == user_id
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
subscription = result.scalar_one_or_none()
|
||||
|
||||
if not subscription:
|
||||
return False
|
||||
|
||||
subscription.is_active = is_active
|
||||
subscription.updated_at = datetime.utcnow()
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
# Clear subscriptions cache
|
||||
self._clear_subscriptions_cache()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to update subscription status: {e}")
|
||||
raise
|
||||
|
||||
async def delete_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""Delete subscription with ownership verification"""
|
||||
try:
|
||||
query = select(EventSubscription).where(
|
||||
and_(
|
||||
EventSubscription.id == subscription_id,
|
||||
EventSubscription.user_id == user_id
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
subscription = result.scalar_one_or_none()
|
||||
|
||||
if not subscription:
|
||||
return False
|
||||
|
||||
await self.db.delete(subscription)
|
||||
await self.db.commit()
|
||||
|
||||
# Clear subscriptions cache
|
||||
self._clear_subscriptions_cache()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to delete subscription: {e}")
|
||||
raise
|
||||
|
||||
# Utility Methods
|
||||
|
||||
async def _update_event_status(
|
||||
self,
|
||||
event_id: str,
|
||||
status: str,
|
||||
error_message: Optional[str] = None
|
||||
) -> None:
|
||||
"""Update event processing status"""
|
||||
try:
|
||||
query = select(Event).where(Event.event_id == event_id)
|
||||
result = await self.db.execute(query)
|
||||
event = result.scalar_one_or_none()
|
||||
|
||||
if event:
|
||||
event.status = status
|
||||
event.error_message = error_message
|
||||
event.completed_at = datetime.utcnow()
|
||||
await self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update event status: {e}")
|
||||
|
||||
def _clear_subscriptions_cache(self) -> None:
|
||||
"""Clear subscriptions cache"""
|
||||
self.subscriptions_cache.clear()
|
||||
self.cache_expiry = None
|
||||
|
||||
async def get_event_history(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
event_types: Optional[List[EventType]] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[Event]:
|
||||
"""Get event history for user with filtering"""
|
||||
try:
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.user_id == user_id,
|
||||
Event.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
|
||||
if event_types:
|
||||
event_type_values = [et.value for et in event_types]
|
||||
query = query.where(Event.event_type.in_(event_type_values))
|
||||
|
||||
query = query.order_by(desc(Event.created_at)).offset(offset).limit(limit)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get event history: {e}")
|
||||
raise
|
||||
|
||||
async def get_event_statistics(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get event statistics for user"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.user_id == user_id,
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.created_at >= cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
# Calculate statistics
|
||||
stats = {
|
||||
"total_events": len(events),
|
||||
"events_by_type": {},
|
||||
"events_by_status": {},
|
||||
"average_events_per_day": 0
|
||||
}
|
||||
|
||||
for event in events:
|
||||
# Count by type
|
||||
event_type = event.event_type
|
||||
stats["events_by_type"][event_type] = stats["events_by_type"].get(event_type, 0) + 1
|
||||
|
||||
# Count by status
|
||||
status = event.status
|
||||
stats["events_by_status"][status] = stats["events_by_status"].get(status, 0) + 1
|
||||
|
||||
# Calculate average per day
|
||||
if days > 0:
|
||||
stats["average_events_per_day"] = round(len(events) / days, 2)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get event statistics: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Factory function for dependency injection
|
||||
async def get_event_service(db: AsyncSession = None) -> EventService:
|
||||
"""Get event service instance"""
|
||||
if db is None:
|
||||
async with get_db_session() as session:
|
||||
return EventService(session)
|
||||
return EventService(db)
|
||||
|
||||
|
||||
# Default event subscriptions for new users
|
||||
DEFAULT_SUBSCRIPTIONS = [
|
||||
{
|
||||
"event_type": EventType.DOCUMENT_UPLOADED,
|
||||
"actions": [
|
||||
EventActionConfig(
|
||||
action_type=ActionType.PROCESS_DOCUMENT,
|
||||
config={"chunking_strategy": "hybrid"},
|
||||
delay_seconds=5 # Small delay to ensure file is fully uploaded
|
||||
)
|
||||
]
|
||||
},
|
||||
{
|
||||
"event_type": EventType.DOCUMENT_PROCESSED,
|
||||
"actions": [
|
||||
EventActionConfig(
|
||||
action_type=ActionType.SEND_NOTIFICATION,
|
||||
config={
|
||||
"type": "success",
|
||||
"message": "Document '{filename}' has been processed successfully with {chunk_count} chunks."
|
||||
}
|
||||
),
|
||||
EventActionConfig(
|
||||
action_type=ActionType.UPDATE_STATISTICS,
|
||||
config={"type": "documents_processed", "increment": 1}
|
||||
)
|
||||
]
|
||||
},
|
||||
{
|
||||
"event_type": EventType.CONVERSATION_STARTED,
|
||||
"actions": [
|
||||
EventActionConfig(
|
||||
action_type=ActionType.LOG_ANALYTICS,
|
||||
config={"properties": {"conversation_start": True}}
|
||||
)
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def setup_default_subscriptions(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
event_service: EventService
|
||||
) -> None:
|
||||
"""Setup default event subscriptions for new user"""
|
||||
try:
|
||||
for subscription_config in DEFAULT_SUBSCRIPTIONS:
|
||||
await event_service.create_subscription(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
event_type=subscription_config["event_type"],
|
||||
actions=subscription_config["actions"],
|
||||
name=f"Default: {subscription_config['event_type'].value}",
|
||||
description="Automatically created default subscription"
|
||||
)
|
||||
|
||||
logger.info(f"Default event subscriptions created for user {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup default subscriptions: {e}")
|
||||
raise
|
||||
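A minimal end-to-end sketch of how this service is intended to be used (illustrative identifiers; assumes an active AsyncSession):

async def example_event_service_flow(db):
    service = EventService(db)

    # Install the default automations for a new user
    await setup_default_subscriptions("user-1", "tenant-1", service)

    # Uploading a document emits an event; the default subscription then
    # runs the process_document action after a short delay.
    await service.emit_event(
        EventType.DOCUMENT_UPLOADED,
        user_id="user-1",
        tenant_id="tenant-1",
        data={"document_id": "doc-123", "dataset_id": "ds-1", "filename": "report.pdf"},
    )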
636
apps/tenant-backend/app/services/external_service.py
Normal file
@@ -0,0 +1,636 @@
|
||||
"""
|
||||
GT 2.0 Tenant Backend - External Service Management
|
||||
Business logic for managing external web services with Resource Cluster integration
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy import select, update, delete, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.external_service import ExternalServiceInstance, ServiceAccessLog, ServiceTemplate
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class ExternalServiceManager:
|
||||
"""Manages external service instances and Resource Cluster integration"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.resource_cluster_base_url = settings.resource_cluster_url or "http://resource-cluster:8003"
|
||||
self.capability_token = None
|
||||
|
||||
def set_capability_token(self, token: str):
|
||||
"""Set capability token for Resource Cluster API calls"""
|
||||
self.capability_token = token
|
||||
|
||||
async def create_service_instance(
|
||||
self,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
user_email: str,
|
||||
config_overrides: Optional[Dict[str, Any]] = None,
|
||||
template_id: Optional[str] = None
|
||||
) -> ExternalServiceInstance:
|
||||
"""Create a new external service instance"""
|
||||
|
||||
# Validate service type
|
||||
supported_services = ['ctfd', 'canvas', 'guacamole']
|
||||
if service_type not in supported_services:
|
||||
raise ValueError(f"Unsupported service type: {service_type}")
|
||||
|
||||
# Load template if provided
|
||||
template = None
|
||||
if template_id:
|
||||
template = await self.get_service_template(template_id)
|
||||
if not template:
|
||||
raise ValueError(f"Template {template_id} not found")
|
||||
|
||||
# Prepare configuration
|
||||
service_config = {}
|
||||
if template:
|
||||
service_config.update(template.default_config)
|
||||
if config_overrides:
|
||||
service_config.update(config_overrides)
|
||||
|
||||
# Call Resource Cluster to create instance
|
||||
resource_instance = await self._create_resource_cluster_instance(
|
||||
service_type=service_type,
|
||||
config_overrides=service_config
|
||||
)
|
||||
|
||||
# Create database record
|
||||
instance = ExternalServiceInstance(
|
||||
service_type=service_type,
|
||||
service_name=service_name,
|
||||
description=f"{service_type.title()} instance for {user_email}",
|
||||
resource_instance_id=resource_instance['instance_id'],
|
||||
endpoint_url=resource_instance['endpoint_url'],
|
||||
status=resource_instance['status'],
|
||||
service_config=service_config,
|
||||
created_by=user_email,
|
||||
allowed_users=[user_email],
|
||||
resource_limits=template.resource_requirements if template else {},
|
||||
auto_start=template.default_config.get('auto_start', True) if template else True
|
||||
)
|
||||
|
||||
self.db.add(instance)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(instance)
|
||||
|
||||
logger.info(
|
||||
f"Created {service_type} service instance {instance.id} "
|
||||
f"for user {user_email}"
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
async def _create_resource_cluster_instance(
|
||||
self,
|
||||
service_type: str,
|
||||
config_overrides: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create instance via Resource Cluster API with zero downtime error handling"""
|
||||
|
||||
if not self.capability_token:
|
||||
raise ValueError("Capability token not set")
|
||||
|
||||
max_retries = 3
|
||||
base_delay = 1.0
|
||||
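# With base_delay = 1.0 and max_retries = 3, retryable failures wait
# base_delay * 2**attempt: 1s after the first attempt and 2s after the second,
# before the request is abandoned.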
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
timeout = httpx.Timeout(60.0, connect=10.0)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_base_url}/api/v1/services/instances",
|
||||
json={
|
||||
"service_type": service_type,
|
||||
"config_overrides": config_overrides
|
||||
},
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.capability_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code in [500, 502, 503, 504] and attempt < max_retries - 1:
|
||||
# Retry for server errors
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Service creation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
error_detail = response.json().get('detail', f'HTTP {response.status_code}')
|
||||
except Exception:
|
||||
error_detail = f'HTTP {response.status_code}'
|
||||
raise RuntimeError(f"Failed to create service instance: {error_detail}")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
if attempt < max_retries - 1:
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Service creation timeout (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError("Failed to create service instance: timeout after retries")
|
||||
except httpx.RequestError as e:
|
||||
if attempt < max_retries - 1:
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Service creation request error (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Failed to create service instance: {e}")
|
||||
|
||||
raise RuntimeError("Failed to create service instance: maximum retries exceeded")
|
||||
|
||||
async def get_service_instance(
|
||||
self,
|
||||
instance_id: str,
|
||||
user_email: str
|
||||
) -> Optional[ExternalServiceInstance]:
|
||||
"""Get service instance with access control"""
|
||||
|
||||
query = select(ExternalServiceInstance).where(
|
||||
and_(
|
||||
ExternalServiceInstance.id == instance_id,
|
||||
ExternalServiceInstance.allowed_users.op('json_extract_path_text')('*').op('@>')([user_email])
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_user_services(
|
||||
self,
|
||||
user_email: str,
|
||||
service_type: Optional[str] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[ExternalServiceInstance]:
|
||||
"""List all services accessible to a user"""
|
||||
|
||||
query = select(ExternalServiceInstance).where(
|
||||
ExternalServiceInstance.allowed_users.op('json_extract_path_text')('*').op('@>')([user_email])
|
||||
)
|
||||
|
||||
if service_type:
|
||||
query = query.where(ExternalServiceInstance.service_type == service_type)
|
||||
|
||||
if status:
|
||||
query = query.where(ExternalServiceInstance.status == status)
|
||||
|
||||
query = query.order_by(ExternalServiceInstance.created_at.desc())
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def stop_service_instance(
|
||||
self,
|
||||
instance_id: str,
|
||||
user_email: str
|
||||
) -> bool:
|
||||
"""Stop a service instance"""
|
||||
|
||||
# Check access
|
||||
instance = await self.get_service_instance(instance_id, user_email)
|
||||
if not instance:
|
||||
raise ValueError(f"Service instance {instance_id} not found or access denied")
|
||||
|
||||
# Call Resource Cluster to stop instance
|
||||
success = await self._stop_resource_cluster_instance(instance.resource_instance_id)
|
||||
|
||||
if success:
|
||||
# Update database status
|
||||
instance.status = 'stopped'
|
||||
instance.updated_at = datetime.utcnow()
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Stopped {instance.service_type} instance {instance_id} "
|
||||
f"by user {user_email}"
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
async def _stop_resource_cluster_instance(self, resource_instance_id: str) -> bool:
|
||||
"""Stop instance via Resource Cluster API with zero downtime error handling"""
|
||||
|
||||
if not self.capability_token:
|
||||
raise ValueError("Capability token not set")
|
||||
|
||||
max_retries = 3
|
||||
base_delay = 1.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
timeout = httpx.Timeout(30.0, connect=10.0)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.delete(
|
||||
f"{self.resource_cluster_base_url}/api/v1/services/instances/{resource_instance_id}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.capability_token}"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
elif response.status_code == 404:
|
||||
# Instance already gone, consider it successfully stopped
|
||||
logger.info(f"Instance {resource_instance_id} not found, assuming already stopped")
|
||||
return True
|
||||
elif response.status_code in [500, 502, 503, 504] and attempt < max_retries - 1:
|
||||
# Retry for server errors
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Instance stop failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Failed to stop instance {resource_instance_id}: HTTP {response.status_code}")
|
||||
return False
|
||||
|
||||
except httpx.TimeoutException:
|
||||
if attempt < max_retries - 1:
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Instance stop timeout (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Failed to stop instance {resource_instance_id}: timeout after retries")
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
if attempt < max_retries - 1:
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Instance stop request error (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Failed to stop instance {resource_instance_id}: {e}")
|
||||
return False
|
||||
|
||||
logger.error(f"Failed to stop instance {resource_instance_id}: maximum retries exceeded")
|
||||
return False
|
||||
|
||||
async def get_service_health(
|
||||
self,
|
||||
instance_id: str,
|
||||
user_email: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get service health status"""
|
||||
|
||||
# Check access
|
||||
instance = await self.get_service_instance(instance_id, user_email)
|
||||
if not instance:
|
||||
raise ValueError(f"Service instance {instance_id} not found or access denied")
|
||||
|
||||
# Get health from Resource Cluster
|
||||
health = await self._get_resource_cluster_health(instance.resource_instance_id)
|
||||
|
||||
# Update instance health status
|
||||
instance.health_status = health.get('status', 'unknown')
|
||||
instance.last_health_check = datetime.utcnow()
|
||||
if health.get('restart_count', 0) != instance.restart_count:
|
||||
instance.restart_count = health.get('restart_count', 0)
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
return health
|
||||
|
||||
async def _get_resource_cluster_health(self, resource_instance_id: str) -> Dict[str, Any]:
|
||||
"""Get health status via Resource Cluster API with zero downtime error handling"""
|
||||
|
||||
if not self.capability_token:
|
||||
raise ValueError("Capability token not set")
|
||||
|
||||
try:
|
||||
timeout = httpx.Timeout(10.0, connect=5.0)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.resource_cluster_base_url}/api/v1/services/health/{resource_instance_id}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.capability_token}"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
return {
|
||||
'status': 'not_found',
|
||||
'error': 'Instance not found'
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'status': 'error',
|
||||
'error': f'Health check failed: HTTP {response.status_code}'
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning(f"Health check timeout for instance {resource_instance_id}")
|
||||
return {
|
||||
'status': 'timeout',
|
||||
'error': 'Health check timeout'
|
||||
}
|
||||
except httpx.RequestError as e:
|
||||
logger.warning(f"Health check request error for instance {resource_instance_id}: {e}")
|
||||
return {
|
||||
'status': 'connection_error',
|
||||
'error': f'Connection error: {e}'
|
||||
}
|
||||
|
||||
async def generate_sso_token(
|
||||
self,
|
||||
instance_id: str,
|
||||
user_email: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate SSO token for iframe embedding"""
|
||||
|
||||
# Check access
|
||||
instance = await self.get_service_instance(instance_id, user_email)
|
||||
if not instance:
|
||||
raise ValueError(f"Service instance {instance_id} not found or access denied")
|
||||
|
||||
# Generate SSO token via Resource Cluster
|
||||
sso_data = await self._generate_resource_cluster_sso_token(instance.resource_instance_id)
|
||||
|
||||
# Update last accessed time
|
||||
instance.last_accessed = datetime.utcnow()
|
||||
await self.db.commit()
|
||||
|
||||
return sso_data
|
||||
|
||||
async def _generate_resource_cluster_sso_token(self, resource_instance_id: str) -> Dict[str, Any]:
|
||||
"""Generate SSO token via Resource Cluster API with zero downtime error handling"""
|
||||
|
||||
if not self.capability_token:
|
||||
raise ValueError("Capability token not set")
|
||||
|
||||
max_retries = 3
|
||||
base_delay = 1.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
timeout = httpx.Timeout(10.0, connect=5.0)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_base_url}/api/v1/services/sso-token/{resource_instance_id}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.capability_token}"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code in [500, 502, 503, 504] and attempt < max_retries - 1:
|
||||
# Retry for server errors
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"SSO token generation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
                        try:
                            error_detail = response.json().get('detail', f'HTTP {response.status_code}')
                        except Exception:
                            # Non-JSON or unexpected error body; fall back to the bare status code
                            error_detail = f'HTTP {response.status_code}'
                        raise RuntimeError(f"Failed to generate SSO token: {error_detail}")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
if attempt < max_retries - 1:
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"SSO token generation timeout (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError("Failed to generate SSO token: timeout after retries")
|
||||
except httpx.RequestError as e:
|
||||
if attempt < max_retries - 1:
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"SSO token generation request error (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Failed to generate SSO token: {e}")
|
||||
|
||||
raise RuntimeError("Failed to generate SSO token: maximum retries exceeded")
|
||||
|
||||
async def log_service_access(
|
||||
self,
|
||||
service_instance_id: str,
|
||||
service_type: str,
|
||||
user_email: str,
|
||||
access_type: str,
|
||||
session_id: str,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
referer: Optional[str] = None,
|
||||
session_duration_seconds: Optional[int] = None,
|
||||
actions_performed: Optional[List[str]] = None
|
||||
) -> ServiceAccessLog:
|
||||
"""Log service access event"""
|
||||
|
||||
access_log = ServiceAccessLog(
|
||||
service_instance_id=service_instance_id,
|
||||
service_type=service_type,
|
||||
user_email=user_email,
|
||||
session_id=session_id,
|
||||
access_type=access_type,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
session_duration_seconds=session_duration_seconds,
|
||||
actions_performed=actions_performed or []
|
||||
)
|
||||
|
||||
self.db.add(access_log)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(access_log)
|
||||
|
||||
return access_log
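    # Illustrative usage sketch (service_manager, request, and the literal values
    # below are hypothetical, not part of the original code): a login event for an
    # embedded service might be recorded from a request handler roughly like this.
    #
    #     await service_manager.log_service_access(
    #         service_instance_id=instance.id,
    #         service_type=instance.service_type,
    #         user_email="user@example.com",
    #         access_type="login",
    #         session_id="sess-1234",
    #         ip_address=request.client.host,
    #         user_agent=request.headers.get("user-agent"),
    #     )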
|
||||
|
||||
async def get_service_analytics(
|
||||
self,
|
||||
instance_id: str,
|
||||
user_email: str,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get service usage analytics"""
|
||||
|
||||
# Check access
|
||||
instance = await self.get_service_instance(instance_id, user_email)
|
||||
if not instance:
|
||||
raise ValueError(f"Service instance {instance_id} not found or access denied")
|
||||
|
||||
# Query access logs
|
||||
since_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
query = select(ServiceAccessLog).where(
|
||||
and_(
|
||||
ServiceAccessLog.service_instance_id == instance_id,
|
||||
ServiceAccessLog.timestamp >= since_date
|
||||
)
|
||||
).order_by(ServiceAccessLog.timestamp.desc())
|
||||
|
||||
result = await self.db.execute(query)
|
||||
access_logs = result.scalars().all()
|
||||
|
||||
# Compute analytics
|
||||
total_sessions = len(set(log.session_id for log in access_logs))
|
||||
total_time_seconds = sum(
|
||||
log.session_duration_seconds or 0
|
||||
for log in access_logs
|
||||
if log.session_duration_seconds
|
||||
)
|
||||
unique_users = len(set(log.user_email for log in access_logs))
|
||||
|
||||
# Group by day for trend analysis
|
||||
daily_usage = {}
|
||||
for log in access_logs:
|
||||
day = log.timestamp.date().isoformat()
|
||||
if day not in daily_usage:
|
||||
daily_usage[day] = {'sessions': 0, 'users': set()}
|
||||
if log.access_type == 'login':
|
||||
daily_usage[day]['sessions'] += 1
|
||||
daily_usage[day]['users'].add(log.user_email)
|
||||
|
||||
# Convert sets to counts
|
||||
for day_data in daily_usage.values():
|
||||
day_data['unique_users'] = len(day_data['users'])
|
||||
del day_data['users']
|
||||
|
||||
return {
|
||||
'instance_id': instance_id,
|
||||
'service_type': instance.service_type,
|
||||
'service_name': instance.service_name,
|
||||
'analytics_period_days': days,
|
||||
'total_sessions': total_sessions,
|
||||
'total_time_hours': round(total_time_seconds / 3600, 1),
|
||||
'unique_users': unique_users,
|
||||
'average_session_duration_minutes': round(
|
||||
total_time_seconds / max(total_sessions, 1) / 60, 1
|
||||
),
|
||||
'daily_usage': daily_usage,
|
||||
'health_status': instance.health_status,
|
||||
'uptime_percentage': self._calculate_uptime_percentage(access_logs, days),
|
||||
'last_accessed': instance.last_accessed.isoformat() if instance.last_accessed else None,
|
||||
'created_at': instance.created_at.isoformat()
|
||||
}
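    # Worked example (illustrative, hypothetical numbers): four 'login' logs over
    # two days, e.g. 2024-01-01 (alice, bob) and 2024-01-02 (alice twice), yield
    # daily_usage = {"2024-01-01": {"sessions": 2, "unique_users": 2},
    #                "2024-01-02": {"sessions": 2, "unique_users": 1}}.
    # With total_time_seconds=5400 across total_sessions=3 distinct session ids,
    # average_session_duration_minutes = round(5400 / 3 / 60, 1) = 30.0.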
|
||||
|
||||
def _calculate_uptime_percentage(self, access_logs: List[ServiceAccessLog], days: int) -> float:
|
||||
"""Calculate approximate uptime percentage based on access patterns"""
|
||||
if not access_logs:
|
||||
return 0.0
|
||||
|
||||
# Simple heuristic: if we have recent login events, assume service is up
|
||||
recent_logins = [
|
||||
log for log in access_logs
|
||||
if log.access_type == 'login' and
|
||||
log.timestamp > datetime.utcnow() - timedelta(days=1)
|
||||
]
|
||||
|
||||
if recent_logins:
|
||||
return 95.0 # Assume good uptime if recently accessed
|
||||
elif len(access_logs) > 0:
|
||||
return 85.0 # Some historical usage
|
||||
else:
|
||||
return 50.0 # No usage data
|
||||
|
||||
async def create_service_template(
|
||||
self,
|
||||
template_name: str,
|
||||
service_type: str,
|
||||
description: str,
|
||||
default_config: Dict[str, Any],
|
||||
created_by: str,
|
||||
**kwargs
|
||||
) -> ServiceTemplate:
|
||||
"""Create a new service template"""
|
||||
|
||||
template = ServiceTemplate(
|
||||
template_name=template_name,
|
||||
service_type=service_type,
|
||||
description=description,
|
||||
default_config=default_config,
|
||||
created_by=created_by,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.db.add(template)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(template)
|
||||
|
||||
return template
|
||||
|
||||
async def get_service_template(self, template_id: str) -> Optional[ServiceTemplate]:
|
||||
"""Get service template by ID"""
|
||||
|
||||
query = select(ServiceTemplate).where(ServiceTemplate.id == template_id)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_service_templates(
|
||||
self,
|
||||
service_type: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
public_only: bool = True
|
||||
) -> List[ServiceTemplate]:
|
||||
"""List available service templates"""
|
||||
|
||||
query = select(ServiceTemplate).where(ServiceTemplate.is_active == True)
|
||||
|
||||
if public_only:
|
||||
query = query.where(ServiceTemplate.is_public == True)
|
||||
|
||||
if service_type:
|
||||
query = query.where(ServiceTemplate.service_type == service_type)
|
||||
|
||||
if category:
|
||||
query = query.where(ServiceTemplate.category == category)
|
||||
|
||||
query = query.order_by(ServiceTemplate.usage_count.desc())
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def share_service_instance(
|
||||
self,
|
||||
instance_id: str,
|
||||
owner_email: str,
|
||||
share_with_emails: List[str],
|
||||
access_level: str = 'read'
|
||||
) -> bool:
|
||||
"""Share service instance with other users"""
|
||||
|
||||
# Check owner access
|
||||
instance = await self.get_service_instance(instance_id, owner_email)
|
||||
if not instance:
|
||||
raise ValueError(f"Service instance {instance_id} not found or access denied")
|
||||
|
||||
if instance.created_by != owner_email:
|
||||
raise ValueError("Only the instance creator can share access")
|
||||
|
||||
# Update allowed users
|
||||
current_users = set(instance.allowed_users)
|
||||
new_users = current_users.union(set(share_with_emails))
|
||||
|
||||
        instance.allowed_users = list(new_users)
        # Any non-empty allowed list means the instance is shared beyond the owner
        instance.access_level = 'team' if new_users else 'private'
        instance.updated_at = datetime.utcnow()
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Shared {instance.service_type} instance {instance_id} "
|
||||
f"with {len(share_with_emails)} users"
|
||||
)
|
||||
|
||||
return True
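    # Illustrative usage sketch (service_manager and the addresses are hypothetical,
    # not part of the original code). Only the instance creator may share; any other
    # caller raises ValueError above.
    #
    #     shared = await service_manager.share_service_instance(
    #         instance_id=instance.id,
    #         owner_email="owner@example.com",
    #         share_with_emails=["teammate@example.com"],
    #         access_level="read",
    #     )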
|
||||
950
apps/tenant-backend/app/services/game_service.py
Normal file
@@ -0,0 +1,950 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, or_, desc
|
||||
from sqlalchemy.orm import selectinload
|
||||
from app.models.game import (
|
||||
GameSession, PuzzleSession, PhilosophicalDialogue,
|
||||
LearningAnalytics, GameTemplate
|
||||
)
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
|
||||
|
||||
class GameService:
|
||||
"""Service for managing strategic games (Chess, Go, etc.)"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_available_games(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get available games and user's current progress"""
|
||||
# Get user's analytics to determine appropriate difficulty
|
||||
analytics = await self.get_or_create_analytics(user_id)
|
||||
|
||||
# Available game types with current user ratings
|
||||
available_games = {
|
||||
"chess": {
|
||||
"name": "Strategic Chess",
|
||||
"description": "Classical chess with AI analysis and move commentary",
|
||||
"current_rating": analytics.chess_rating,
|
||||
"difficulty_levels": ["beginner", "intermediate", "advanced", "expert"],
|
||||
"features": ["move_analysis", "position_evaluation", "opening_guidance", "endgame_tutorials"],
|
||||
"estimated_time": "15-45 minutes",
|
||||
"skills_developed": ["strategic_planning", "pattern_recognition", "calculation_depth"]
|
||||
},
|
||||
"go": {
|
||||
"name": "Strategic Go",
|
||||
"description": "The ancient game of Go with territory and influence analysis",
|
||||
"current_rating": analytics.go_rating,
|
||||
"difficulty_levels": ["beginner", "intermediate", "advanced", "expert"],
|
||||
"features": ["territory_visualization", "influence_mapping", "joseki_suggestions", "life_death_training"],
|
||||
"estimated_time": "20-60 minutes",
|
||||
"skills_developed": ["strategic_concepts", "reading_ability", "intuitive_judgment"]
|
||||
}
|
||||
}
|
||||
|
||||
# Get recent game sessions
|
||||
recent_sessions_query = select(GameSession).where(
|
||||
and_(
|
||||
GameSession.user_id == user_id,
|
||||
GameSession.started_at >= datetime.utcnow() - timedelta(days=7)
|
||||
)
|
||||
).order_by(desc(GameSession.started_at)).limit(5)
|
||||
|
||||
result = await self.db.execute(recent_sessions_query)
|
||||
recent_sessions = result.scalars().all()
|
||||
|
||||
return {
|
||||
"available_games": available_games,
|
||||
"recent_sessions": [self._serialize_game_session(session) for session in recent_sessions],
|
||||
"user_analytics": self._serialize_analytics(analytics)
|
||||
}
|
||||
|
||||
async def start_game_session(self, user_id: str, game_type: str, config: Dict[str, Any]) -> GameSession:
|
||||
"""Start a new game session"""
|
||||
analytics = await self.get_or_create_analytics(user_id)
|
||||
|
||||
# Determine AI opponent configuration based on user rating
|
||||
ai_config = self._configure_ai_opponent(game_type, analytics, config.get('difficulty', 'intermediate'))
|
||||
|
||||
session = GameSession(
|
||||
user_id=user_id,
|
||||
game_type=game_type,
|
||||
game_name=config.get('name', f"{game_type.title()} Game"),
|
||||
difficulty_level=config.get('difficulty', 'intermediate'),
|
||||
ai_opponent_config=ai_config,
|
||||
game_rules=self._get_game_rules(game_type),
|
||||
current_state=self._initialize_game_state(game_type),
|
||||
current_rating=getattr(analytics, f"{game_type}_rating", 1200)
|
||||
)
|
||||
|
||||
self.db.add(session)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
|
||||
return session
|
||||
|
||||
async def make_move(self, session_id: str, user_id: str, move_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process a user move and generate AI response"""
|
||||
session_query = select(GameSession).where(
|
||||
and_(GameSession.id == session_id, GameSession.user_id == user_id)
|
||||
)
|
||||
result = await self.db.execute(session_query)
|
||||
session = result.scalar_one_or_none()
|
||||
|
||||
if not session or session.game_status != 'active':
|
||||
raise ValueError("Game session not found or not active")
|
||||
|
||||
# Process user move
|
||||
move_result = self._process_move(session, move_data, is_ai=False)
|
||||
|
||||
# Generate AI response
|
||||
ai_move = self._generate_ai_move(session)
|
||||
ai_result = self._process_move(session, ai_move, is_ai=True)
|
||||
|
||||
# Update session state
|
||||
session.moves_count += 2 # User move + AI move
|
||||
session.last_move_at = datetime.utcnow()
|
||||
session.move_history = session.move_history + [move_result, ai_result]
|
||||
|
||||
# Check for game end conditions
|
||||
game_status = self._check_game_status(session)
|
||||
if game_status['ended']:
|
||||
session.game_status = 'completed'
|
||||
session.completed_at = datetime.utcnow()
|
||||
session.outcome = game_status['outcome']
|
||||
session.ai_analysis = self._generate_game_analysis(session)
|
||||
session.learning_insights = self._extract_learning_insights(session)
|
||||
|
||||
# Update user analytics
|
||||
await self._update_analytics_after_game(session)
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
|
||||
return {
|
||||
"user_move": move_result,
|
||||
"ai_move": ai_result,
|
||||
"current_state": session.current_state,
|
||||
"game_status": game_status,
|
||||
"analysis": self._generate_move_analysis(session, move_data) if game_status.get('ended') else None
|
||||
}
|
||||
|
||||
async def get_game_analysis(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""Get detailed analysis of the current game position"""
|
||||
session_query = select(GameSession).where(
|
||||
and_(GameSession.id == session_id, GameSession.user_id == user_id)
|
||||
)
|
||||
result = await self.db.execute(session_query)
|
||||
session = result.scalar_one_or_none()
|
||||
|
||||
if not session:
|
||||
raise ValueError("Game session not found")
|
||||
|
||||
analysis = {
|
||||
"position_evaluation": self._evaluate_position(session),
|
||||
"best_moves": self._get_best_moves(session),
|
||||
"strategic_insights": self._get_strategic_insights(session),
|
||||
"learning_points": self._get_learning_points(session),
|
||||
"skill_assessment": self._assess_current_skill(session)
|
||||
}
|
||||
|
||||
return analysis
|
||||
|
||||
async def get_user_game_history(self, user_id: str, game_type: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
|
||||
"""Get user's game history with performance trends"""
|
||||
query = select(GameSession).where(GameSession.user_id == user_id)
|
||||
|
||||
if game_type:
|
||||
query = query.where(GameSession.game_type == game_type)
|
||||
|
||||
query = query.order_by(desc(GameSession.started_at)).limit(limit)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
sessions = result.scalars().all()
|
||||
|
||||
return [self._serialize_game_session(session) for session in sessions]
|
||||
|
||||
def _configure_ai_opponent(self, game_type: str, analytics: LearningAnalytics, difficulty: str) -> Dict[str, Any]:
|
||||
"""Configure AI opponent based on user skill level"""
|
||||
base_config = {
|
||||
"personality": "teaching", # teaching, aggressive, defensive, balanced
|
||||
"explanation_mode": True,
|
||||
"move_commentary": True,
|
||||
"mistake_correction": True,
|
||||
"hint_availability": True
|
||||
}
|
||||
|
||||
if game_type == "chess":
|
||||
rating_map = {
|
||||
"beginner": analytics.chess_rating - 200,
|
||||
"intermediate": analytics.chess_rating,
|
||||
"advanced": analytics.chess_rating + 200,
|
||||
"expert": analytics.chess_rating + 400
|
||||
}
|
||||
base_config.update({
|
||||
"engine_strength": rating_map.get(difficulty, analytics.chess_rating),
|
||||
"opening_book": True,
|
||||
"endgame_tablebase": difficulty in ["advanced", "expert"],
|
||||
"thinking_time": {"beginner": 1, "intermediate": 3, "advanced": 5, "expert": 10}[difficulty]
|
||||
})
|
||||
|
||||
elif game_type == "go":
|
||||
base_config.update({
|
||||
"handicap_stones": {"beginner": 4, "intermediate": 2, "advanced": 0, "expert": 0}[difficulty],
|
||||
"commentary_level": {"beginner": "detailed", "intermediate": "moderate", "advanced": "minimal", "expert": "minimal"}[difficulty],
|
||||
"joseki_teaching": difficulty in ["beginner", "intermediate"]
|
||||
})
|
||||
|
||||
return base_config
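    # Worked example (illustrative): for a user with chess_rating=1200 choosing
    # difficulty="advanced", the config above becomes roughly
    #     {"personality": "teaching", ..., "engine_strength": 1400,
    #      "opening_book": True, "endgame_tablebase": True, "thinking_time": 5}
    # while a Go "beginner" would instead get handicap_stones=4 with detailed
    # commentary.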
|
||||
|
||||
def _get_game_rules(self, game_type: str) -> Dict[str, Any]:
|
||||
"""Get standard rules for the game type"""
|
||||
rules = {
|
||||
"chess": {
|
||||
"board_size": "8x8",
|
||||
"time_control": "unlimited",
|
||||
"special_rules": ["castling", "en_passant", "promotion"],
|
||||
"victory_conditions": ["checkmate", "resignation", "time_forfeit"]
|
||||
},
|
||||
"go": {
|
||||
"board_size": "19x19",
|
||||
"komi": 6.5,
|
||||
"special_rules": ["ko_rule", "suicide_rule"],
|
||||
"victory_conditions": ["territory_count", "resignation"]
|
||||
}
|
||||
}
|
||||
return rules.get(game_type, {})
|
||||
|
||||
def _initialize_game_state(self, game_type: str) -> Dict[str, Any]:
|
||||
"""Initialize the starting game state"""
|
||||
if game_type == "chess":
|
||||
return {
|
||||
"board": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", # FEN notation
|
||||
"to_move": "white",
|
||||
"castling_rights": "KQkq",
|
||||
"en_passant": None,
|
||||
"halfmove_clock": 0,
|
||||
"fullmove_number": 1
|
||||
}
|
||||
elif game_type == "go":
|
||||
return {
|
||||
"board": [[0 for _ in range(19)] for _ in range(19)], # 0=empty, 1=black, 2=white
|
||||
"to_move": "black",
|
||||
"captured_stones": {"black": 0, "white": 0},
|
||||
"ko_position": None,
|
||||
"move_number": 0
|
||||
}
|
||||
|
||||
return {}
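    # Note on the chess state above: the FEN string encodes, in order, piece
    # placement (rank 8 down to rank 1), side to move ("w"), castling rights
    # ("KQkq"), the en passant target ("-"), the halfmove clock (0), and the
    # fullmove number (1); the separate dict keys mirror those last five fields
    # for convenient access.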
|
||||
|
||||
def _process_move(self, session: GameSession, move_data: Dict[str, Any], is_ai: bool) -> Dict[str, Any]:
|
||||
"""Process a move and update game state"""
|
||||
# This would contain game-specific logic for processing moves
|
||||
# For now, return a mock processed move
|
||||
return {
|
||||
"move": move_data,
|
||||
"is_ai": is_ai,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"evaluation": random.uniform(-1.0, 1.0), # Mock evaluation
|
||||
"commentary": "Good move!" if not is_ai else "AI plays strategically"
|
||||
}
|
||||
|
||||
def _generate_ai_move(self, session: GameSession) -> Dict[str, Any]:
|
||||
"""Generate AI move based on current position"""
|
||||
# Mock AI move generation
|
||||
if session.game_type == "chess":
|
||||
return {"from": "e2", "to": "e4", "piece": "pawn"}
|
||||
elif session.game_type == "go":
|
||||
return {"x": 10, "y": 10, "color": "white"}
|
||||
|
||||
return {}
|
||||
|
||||
def _check_game_status(self, session: GameSession) -> Dict[str, Any]:
|
||||
"""Check if game has ended and determine outcome"""
|
||||
# Mock game status check
|
||||
return {
|
||||
"ended": session.moves_count > 20, # Mock end condition
|
||||
"outcome": random.choice(["win", "loss", "draw"]) if session.moves_count > 20 else None
|
||||
}
|
||||
|
||||
def _generate_game_analysis(self, session: GameSession) -> Dict[str, Any]:
|
||||
"""Generate comprehensive game analysis"""
|
||||
return {
|
||||
"game_quality": random.uniform(0.6, 0.95),
|
||||
"key_moments": [
|
||||
{"move": 5, "evaluation": "Excellent opening choice"},
|
||||
{"move": 12, "evaluation": "Missed tactical opportunity"},
|
||||
{"move": 18, "evaluation": "Strong endgame technique"}
|
||||
],
|
||||
"skill_demonstration": {
|
||||
"tactical_awareness": random.uniform(0.5, 1.0),
|
||||
"strategic_understanding": random.uniform(0.4, 0.9),
|
||||
"time_management": random.uniform(0.6, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_learning_insights(self, session: GameSession) -> List[str]:
|
||||
"""Extract key learning insights from the game"""
|
||||
insights = [
|
||||
"Focus on controlling the center in the opening",
|
||||
"Look for tactical combinations before moving",
|
||||
"Consider your opponent's threats before making your move",
|
||||
"Practice endgame fundamentals"
|
||||
]
|
||||
return random.sample(insights, 2)
|
||||
|
||||
async def _update_analytics_after_game(self, session: GameSession):
|
||||
"""Update user analytics based on game performance"""
|
||||
analytics = await self.get_or_create_analytics(session.user_id)
|
||||
|
||||
# Update game-specific rating
|
||||
if session.game_type == "chess":
|
||||
rating_change = self._calculate_rating_change(session)
|
||||
analytics.chess_rating += rating_change
|
||||
elif session.game_type == "go":
|
||||
rating_change = self._calculate_rating_change(session)
|
||||
analytics.go_rating += rating_change
|
||||
|
||||
# Update cognitive skills based on performance
|
||||
if session.ai_analysis:
|
||||
skill_updates = session.ai_analysis.get("skill_demonstration", {})
|
||||
analytics.strategic_thinking_score = self._update_skill_score(
|
||||
analytics.strategic_thinking_score,
|
||||
skill_updates.get("strategic_understanding", 0.5)
|
||||
)
|
||||
|
||||
analytics.total_sessions += 1
|
||||
analytics.total_time_minutes += session.time_spent_seconds // 60
|
||||
analytics.last_activity_date = datetime.utcnow()
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
    def _calculate_rating_change(self, session: GameSession) -> int:
        """Calculate a simplified, fixed-step rating change based on game outcome"""
        base_change = 30
        if session.outcome == "win":
            return base_change
        elif session.outcome == "loss":
            return -base_change
        else:  # draw
            return 0
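    # Illustrative sketch only: the method above applies a flat +/-30 step rather
    # than a full Elo update. A standard Elo adjustment, if one were wanted, would
    # weight the result by the expected score against the opponent:
    #
    #     def elo_delta(player: float, opponent: float, score: float, k: int = 32) -> int:
    #         """score is 1.0 for a win, 0.5 for a draw, 0.0 for a loss."""
    #         expected = 1.0 / (1.0 + 10 ** ((opponent - player) / 400))
    #         return round(k * (score - expected))
    #
    # This is not part of the original implementation.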
|
||||
|
||||
    def _update_skill_score(self, current_score: float, performance: float) -> float:
        """Update skill score using an exponential moving average"""
        alpha = 0.1  # Learning rate
        return current_score * (1 - alpha) + performance * 100 * alpha
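    # Worked example (illustrative): with alpha=0.1, current_score=70.0 and
    # performance=0.8 (scaled to 80), the update yields
    # 70.0 * 0.9 + 80.0 * 0.1 = 71.0, so scores drift gradually toward recent
    # performance instead of jumping on a single game.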
|
||||
|
||||
async def get_or_create_analytics(self, user_id: str) -> LearningAnalytics:
|
||||
"""Get or create learning analytics for user"""
|
||||
query = select(LearningAnalytics).where(LearningAnalytics.user_id == user_id)
|
||||
result = await self.db.execute(query)
|
||||
analytics = result.scalar_one_or_none()
|
||||
|
||||
if not analytics:
|
||||
analytics = LearningAnalytics(user_id=user_id)
|
||||
self.db.add(analytics)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(analytics)
|
||||
|
||||
return analytics
|
||||
|
||||
def _serialize_game_session(self, session: GameSession) -> Dict[str, Any]:
|
||||
"""Serialize game session for API response"""
|
||||
return {
|
||||
"id": session.id,
|
||||
"game_type": session.game_type,
|
||||
"game_name": session.game_name,
|
||||
"difficulty_level": session.difficulty_level,
|
||||
"game_status": session.game_status,
|
||||
"moves_count": session.moves_count,
|
||||
"time_spent_seconds": session.time_spent_seconds,
|
||||
"current_rating": session.current_rating,
|
||||
"outcome": session.outcome,
|
||||
"started_at": session.started_at.isoformat() if session.started_at else None,
|
||||
"completed_at": session.completed_at.isoformat() if session.completed_at else None,
|
||||
"learning_insights": session.learning_insights
|
||||
}
|
||||
|
||||
def _serialize_analytics(self, analytics: LearningAnalytics) -> Dict[str, Any]:
|
||||
"""Serialize learning analytics for API response"""
|
||||
return {
|
||||
"chess_rating": analytics.chess_rating,
|
||||
"go_rating": analytics.go_rating,
|
||||
"total_sessions": analytics.total_sessions,
|
||||
"total_time_minutes": analytics.total_time_minutes,
|
||||
"current_streak_days": analytics.current_streak_days,
|
||||
"strategic_thinking_score": analytics.strategic_thinking_score,
|
||||
"logical_reasoning_score": analytics.logical_reasoning_score,
|
||||
"pattern_recognition_score": analytics.pattern_recognition_score,
|
||||
"ai_collaboration_skills": {
|
||||
"dependency_index": analytics.ai_dependency_index,
|
||||
"prompt_engineering": analytics.prompt_engineering_skill,
|
||||
"output_evaluation": analytics.ai_output_evaluation_skill,
|
||||
"collaborative_solving": analytics.collaborative_problem_solving
|
||||
}
|
||||
}
|
||||
|
||||
# Mock helper methods for analysis (would be replaced with actual game engines)
|
||||
def _evaluate_position(self, session: GameSession) -> Dict[str, Any]:
|
||||
return {"advantage": random.uniform(-2.0, 2.0), "complexity": random.uniform(0.3, 1.0)}
|
||||
|
||||
def _get_best_moves(self, session: GameSession) -> List[Dict[str, Any]]:
|
||||
return [{"move": "e4", "evaluation": 0.3}, {"move": "d4", "evaluation": 0.2}]
|
||||
|
||||
def _get_strategic_insights(self, session: GameSession) -> List[str]:
|
||||
return ["Control the center", "Develop pieces actively", "Ensure king safety"]
|
||||
|
||||
def _get_learning_points(self, session: GameSession) -> List[str]:
|
||||
return ["Practice tactical patterns", "Study endgame principles"]
|
||||
|
||||
def _assess_current_skill(self, session: GameSession) -> Dict[str, float]:
|
||||
return {
|
||||
"tactical_strength": random.uniform(0.4, 0.9),
|
||||
"positional_understanding": random.uniform(0.3, 0.8),
|
||||
"calculation_accuracy": random.uniform(0.5, 0.95)
|
||||
}
|
||||
|
||||
def _generate_move_analysis(self, session: GameSession, move_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"move_quality": random.uniform(0.6, 1.0),
|
||||
"alternatives": ["Better move: Nf3", "Consider: Bc4"],
|
||||
"consequences": "Leads to tactical complications"
|
||||
}
|
||||
|
||||
|
||||
class PuzzleService:
|
||||
"""Service for managing logic puzzles and reasoning challenges"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_available_puzzles(self, user_id: str, category: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get available puzzles based on user skill level"""
|
||||
analytics = await self._get_analytics(user_id)
|
||||
|
||||
puzzle_categories = {
|
||||
"lateral_thinking": {
|
||||
"name": "Lateral Thinking",
|
||||
"description": "Creative problem-solving that requires thinking outside conventional patterns",
|
||||
"difficulty_range": [1, 10],
|
||||
"skills_developed": ["creative_thinking", "assumption_challenging", "perspective_shifting"]
|
||||
},
|
||||
"logical_deduction": {
|
||||
"name": "Logical Deduction",
|
||||
"description": "Step-by-step reasoning to reach logical conclusions",
|
||||
"difficulty_range": [1, 8],
|
||||
"skills_developed": ["systematic_thinking", "evidence_evaluation", "logical_consistency"]
|
||||
},
|
||||
"mathematical_reasoning": {
|
||||
"name": "Mathematical Reasoning",
|
||||
"description": "Number patterns, sequences, and mathematical logic",
|
||||
"difficulty_range": [1, 9],
|
||||
"skills_developed": ["pattern_recognition", "analytical_thinking", "quantitative_reasoning"]
|
||||
},
|
||||
"spatial_reasoning": {
|
||||
"name": "Spatial Reasoning",
|
||||
"description": "3D visualization and spatial relationship puzzles",
|
||||
"difficulty_range": [1, 7],
|
||||
"skills_developed": ["spatial_visualization", "mental_rotation", "pattern_matching"]
|
||||
}
|
||||
}
|
||||
|
||||
if category and category in puzzle_categories:
|
||||
return {"category": puzzle_categories[category]}
|
||||
|
||||
return {
|
||||
"categories": puzzle_categories,
|
||||
"recommended_difficulty": min(analytics.puzzle_solving_level + 1, 10),
|
||||
"user_progress": {
|
||||
"current_level": analytics.puzzle_solving_level,
|
||||
"puzzles_solved_total": analytics.total_sessions,
|
||||
"favorite_categories": [] # Would be determined from session history
|
||||
}
|
||||
}
|
||||
|
||||
async def start_puzzle_session(self, user_id: str, puzzle_type: str, difficulty: int = None) -> PuzzleSession:
|
||||
"""Start a new puzzle session"""
|
||||
analytics = await self._get_analytics(user_id)
|
||||
|
||||
if difficulty is None:
|
||||
difficulty = min(analytics.puzzle_solving_level + 1, 10)
|
||||
|
||||
puzzle_def = self._generate_puzzle(puzzle_type, difficulty)
|
||||
|
||||
session = PuzzleSession(
|
||||
user_id=user_id,
|
||||
puzzle_type=puzzle_type,
|
||||
puzzle_category=puzzle_def["category"],
|
||||
puzzle_definition=puzzle_def["definition"],
|
||||
solution_criteria=puzzle_def["solution_criteria"],
|
||||
difficulty_rating=difficulty,
|
||||
estimated_time_minutes=puzzle_def.get("estimated_time", 10)
|
||||
)
|
||||
|
||||
self.db.add(session)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
|
||||
return session
|
||||
|
||||
async def submit_solution(self, session_id: str, user_id: str, solution: Dict[str, Any], reasoning: str = None) -> Dict[str, Any]:
|
||||
"""Submit a solution attempt for evaluation"""
|
||||
session_query = select(PuzzleSession).where(
|
||||
and_(PuzzleSession.id == session_id, PuzzleSession.user_id == user_id)
|
||||
)
|
||||
result = await self.db.execute(session_query)
|
||||
session = result.scalar_one_or_none()
|
||||
|
||||
if not session or session.session_status not in ['active', 'hint_requested']:
|
||||
raise ValueError("Puzzle session not found or not active")
|
||||
|
||||
# Evaluate solution
|
||||
evaluation = self._evaluate_solution(session, solution, reasoning)
|
||||
|
||||
# Update session
|
||||
session.attempts_count += 1
|
||||
session.current_attempt = solution
|
||||
session.attempt_history = session.attempt_history + [solution]
|
||||
session.reasoning_explanation = reasoning
|
||||
session.time_spent_seconds += 60 # Mock time tracking
|
||||
|
||||
if evaluation["correct"]:
|
||||
session.is_solved = True
|
||||
session.session_status = 'solved'
|
||||
session.solved_at = datetime.utcnow()
|
||||
session.solution_quality_score = evaluation["quality_score"]
|
||||
session.ai_feedback = evaluation["feedback"]
|
||||
|
||||
# Update user analytics
|
||||
await self._update_analytics_after_puzzle(session, evaluation)
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
|
||||
return {
|
||||
"correct": evaluation["correct"],
|
||||
"feedback": evaluation["feedback"],
|
||||
"quality_score": evaluation.get("quality_score", 0),
|
||||
"hints_available": not evaluation["correct"],
|
||||
"next_difficulty_recommendation": evaluation.get("next_difficulty", session.difficulty_rating)
|
||||
}
|
||||
|
||||
async def get_hint(self, session_id: str, user_id: str, hint_level: int = 1) -> Dict[str, Any]:
|
||||
"""Provide a hint for the current puzzle"""
|
||||
session_query = select(PuzzleSession).where(
|
||||
and_(PuzzleSession.id == session_id, PuzzleSession.user_id == user_id)
|
||||
)
|
||||
result = await self.db.execute(session_query)
|
||||
session = result.scalar_one_or_none()
|
||||
|
||||
if not session or session.session_status not in ['active', 'hint_requested']:
|
||||
raise ValueError("Puzzle session not found or not active")
|
||||
|
||||
hint = self._generate_hint(session, hint_level)
|
||||
|
||||
# Update session
|
||||
session.hints_used_count += 1
|
||||
session.hints_given = session.hints_given + [hint]
|
||||
session.session_status = 'hint_requested'
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
return {
|
||||
"hint": hint["text"],
|
||||
"hint_type": hint["type"],
|
||||
"points_deducted": hint.get("point_penalty", 5),
|
||||
"hints_remaining": max(0, 3 - session.hints_used_count)
|
||||
}
|
||||
|
||||
def _generate_puzzle(self, puzzle_type: str, difficulty: int) -> Dict[str, Any]:
|
||||
"""Generate a puzzle based on type and difficulty"""
|
||||
puzzles = {
|
||||
"lateral_thinking": {
|
||||
"definition": {
|
||||
"question": "A man lives on the 20th floor of an apartment building. Every morning he takes the elevator down to the ground floor. When he comes home, he takes the elevator to the 10th floor and walks the rest of the way... except on rainy days, when he takes the elevator all the way to the 20th floor. Why?",
|
||||
"context": "This is a classic lateral thinking puzzle that requires you to challenge your assumptions."
|
||||
},
|
||||
"solution_criteria": {
|
||||
"key_insights": ["height limitation", "elevator button accessibility"],
|
||||
"required_elements": ["umbrella usage", "physical constraint explanation"]
|
||||
},
|
||||
"category": "lateral_thinking",
|
||||
"estimated_time": 15
|
||||
},
|
||||
"logical_deduction": {
|
||||
"definition": {
|
||||
"question": "Five people of different nationalities live in five houses of different colors, drink different beverages, smoke different brands, and keep different pets. Using the given clues, determine who owns the fish.",
|
||||
"clues": [
|
||||
"The Brit lives in the red house",
|
||||
"The Swede keeps dogs as pets",
|
||||
"The Dane drinks tea",
|
||||
"The green house is on the left of the white house",
|
||||
"The green house's owner drinks coffee"
|
||||
]
|
||||
},
|
||||
"solution_criteria": {
|
||||
"format": "grid_solution",
|
||||
"required_mapping": ["nationality", "house_color", "beverage", "smoke", "pet"]
|
||||
},
|
||||
"category": "logical_deduction",
|
||||
"estimated_time": 25
|
||||
}
|
||||
}
|
||||
|
||||
return puzzles.get(puzzle_type, puzzles["lateral_thinking"])
|
||||
|
||||
def _evaluate_solution(self, session: PuzzleSession, solution: Dict[str, Any], reasoning: str = None) -> Dict[str, Any]:
|
||||
"""Evaluate a puzzle solution"""
|
||||
# Mock evaluation logic
|
||||
is_correct = random.choice([True, False]) # Would be actual evaluation
|
||||
|
||||
quality_score = random.uniform(60, 95) if is_correct else random.uniform(20, 60)
|
||||
|
||||
feedback = {
|
||||
"correctness": "Excellent reasoning!" if is_correct else "Not quite right, but good thinking.",
|
||||
"reasoning_quality": "Clear logical steps" if reasoning else "Consider explaining your reasoning",
|
||||
"suggestions": ["Try thinking about the constraints differently", "What assumptions might you be making?"]
|
||||
}
|
||||
|
||||
return {
|
||||
"correct": is_correct,
|
||||
"quality_score": quality_score,
|
||||
"feedback": feedback,
|
||||
"next_difficulty": session.difficulty_rating + (1 if is_correct else 0)
|
||||
}
|
||||
|
||||
def _generate_hint(self, session: PuzzleSession, hint_level: int) -> Dict[str, Any]:
|
||||
"""Generate appropriate hint based on puzzle and user progress"""
|
||||
hints = [
|
||||
{"text": "Consider what might be different about rainy days", "type": "direction"},
|
||||
{"text": "Think about the man's physical characteristics", "type": "clue"},
|
||||
{"text": "What would an umbrella help him reach?", "type": "leading"}
|
||||
]
|
||||
|
||||
return hints[min(hint_level - 1, len(hints) - 1)]
|
||||
|
||||
async def _update_analytics_after_puzzle(self, session: PuzzleSession, evaluation: Dict[str, Any]):
|
||||
"""Update analytics after puzzle completion"""
|
||||
analytics = await self._get_analytics(session.user_id)
|
||||
|
||||
if evaluation["correct"]:
|
||||
analytics.puzzle_solving_level = min(analytics.puzzle_solving_level + 0.1, 10)
|
||||
analytics.logical_reasoning_score = self._update_skill_score(
|
||||
analytics.logical_reasoning_score,
|
||||
evaluation["quality_score"] / 100
|
||||
)
|
||||
|
||||
analytics.total_sessions += 1
|
||||
analytics.last_activity_date = datetime.utcnow()
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
async def _get_analytics(self, user_id: str) -> LearningAnalytics:
|
||||
"""Get user analytics (shared with GameService)"""
|
||||
query = select(LearningAnalytics).where(LearningAnalytics.user_id == user_id)
|
||||
result = await self.db.execute(query)
|
||||
analytics = result.scalar_one_or_none()
|
||||
|
||||
if not analytics:
|
||||
analytics = LearningAnalytics(user_id=user_id)
|
||||
self.db.add(analytics)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(analytics)
|
||||
|
||||
return analytics
|
||||
|
||||
def _update_skill_score(self, current_score: float, performance: float) -> float:
|
||||
"""Update skill score using exponential moving average"""
|
||||
alpha = 0.1
|
||||
return current_score * (1 - alpha) + performance * 100 * alpha
|
||||
|
||||
|
||||
class PhilosophicalDialogueService:
|
||||
"""Service for managing philosophical dilemmas and ethical reasoning"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_available_dilemmas(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get available philosophical dilemmas"""
|
||||
analytics = await self._get_analytics(user_id)
|
||||
|
||||
dilemma_types = {
|
||||
"ethical_frameworks": {
|
||||
"name": "Ethical Framework Analysis",
|
||||
"description": "Explore different ethical theories through practical dilemmas",
|
||||
"topics": ["trolley_problem", "utilitarian_vs_deontological", "virtue_ethics_scenarios"],
|
||||
"skills_developed": ["ethical_reasoning", "framework_application", "moral_consistency"]
|
||||
},
|
||||
"game_theory": {
|
||||
"name": "Game Theory Dilemmas",
|
||||
"description": "Strategic decision-making in competitive and cooperative scenarios",
|
||||
"topics": ["prisoners_dilemma", "tragedy_of_commons", "coordination_games"],
|
||||
"skills_developed": ["strategic_thinking", "cooperation_analysis", "incentive_understanding"]
|
||||
},
|
||||
"ai_consciousness": {
|
||||
"name": "AI Consciousness & Rights",
|
||||
"description": "Explore questions about AI sentience, rights, and moral status",
|
||||
"topics": ["chinese_room", "turing_test_ethics", "ai_rights", "consciousness_criteria"],
|
||||
"skills_developed": ["conceptual_analysis", "consciousness_theory", "future_ethics"]
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"dilemma_types": dilemma_types,
|
||||
"user_level": analytics.philosophical_depth_level,
|
||||
"recommended_topics": self._get_recommended_topics(analytics),
|
||||
"recent_insights": [] # Would come from recent sessions
|
||||
}
|
||||
|
||||
async def start_dialogue_session(self, user_id: str, dilemma_type: str, topic: str) -> PhilosophicalDialogue:
|
||||
"""Start a new philosophical dialogue session"""
|
||||
scenario = self._get_dilemma_scenario(dilemma_type, topic)
|
||||
|
||||
dialogue = PhilosophicalDialogue(
|
||||
user_id=user_id,
|
||||
dilemma_type=dilemma_type,
|
||||
dilemma_title=scenario["title"],
|
||||
scenario_description=scenario["description"],
|
||||
framework_options=scenario["frameworks"],
|
||||
complexity_level=scenario.get("complexity", "intermediate"),
|
||||
estimated_discussion_time=scenario.get("estimated_time", 20)
|
||||
)
|
||||
|
||||
self.db.add(dialogue)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(dialogue)
|
||||
|
||||
return dialogue
|
||||
|
||||
async def submit_response(self, dialogue_id: str, user_id: str, response: str, framework: str = None) -> Dict[str, Any]:
|
||||
"""Submit a response to the philosophical dilemma"""
|
||||
dialogue_query = select(PhilosophicalDialogue).where(
|
||||
and_(PhilosophicalDialogue.id == dialogue_id, PhilosophicalDialogue.user_id == user_id)
|
||||
)
|
||||
result = await self.db.execute(dialogue_query)
|
||||
dialogue = result.scalar_one_or_none()
|
||||
|
||||
if not dialogue or dialogue.dialogue_status != 'active':
|
||||
raise ValueError("Dialogue session not found or not active")
|
||||
|
||||
# Generate AI response based on user input
|
||||
ai_response = self._generate_ai_response(dialogue, response, framework)
|
||||
|
||||
# Update dialogue state
|
||||
dialogue.exchange_count += 1
|
||||
dialogue.dialogue_history = dialogue.dialogue_history + [
|
||||
{"speaker": "user", "content": response, "framework": framework, "timestamp": datetime.utcnow().isoformat()},
|
||||
{"speaker": "ai", "content": ai_response["content"], "type": ai_response["type"], "timestamp": datetime.utcnow().isoformat()}
|
||||
]
|
||||
|
||||
if framework and framework not in dialogue.frameworks_explored:
|
||||
dialogue.frameworks_explored = dialogue.frameworks_explored + [framework]
|
||||
dialogue.framework_applications += 1
|
||||
|
||||
dialogue.last_exchange_at = datetime.utcnow()
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(dialogue)
|
||||
|
||||
return {
|
||||
"ai_response": ai_response["content"],
|
||||
"follow_up_questions": ai_response.get("questions", []),
|
||||
"suggested_frameworks": ai_response.get("suggested_frameworks", []),
|
||||
"dialogue_progress": {
|
||||
"exchanges": dialogue.exchange_count,
|
||||
"frameworks_explored": len(dialogue.frameworks_explored),
|
||||
"depth_assessment": self._assess_dialogue_depth(dialogue)
|
||||
}
|
||||
}
|
||||
|
||||
async def conclude_dialogue(self, dialogue_id: str, user_id: str, final_position: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Conclude the philosophical dialogue with final assessment"""
|
||||
dialogue_query = select(PhilosophicalDialogue).where(
|
||||
and_(PhilosophicalDialogue.id == dialogue_id, PhilosophicalDialogue.user_id == user_id)
|
||||
)
|
||||
result = await self.db.execute(dialogue_query)
|
||||
dialogue = result.scalar_one_or_none()
|
||||
|
||||
if not dialogue:
|
||||
raise ValueError("Dialogue session not found")
|
||||
|
||||
# Generate final assessment
|
||||
assessment = self._generate_final_assessment(dialogue, final_position)
|
||||
|
||||
# Update dialogue
|
||||
dialogue.dialogue_status = 'concluded'
|
||||
dialogue.concluded_at = datetime.utcnow()
|
||||
dialogue.final_position = final_position
|
||||
dialogue.key_insights = assessment["insights"]
|
||||
dialogue.ai_assessment = assessment["ai_evaluation"]
|
||||
|
||||
# Update scores based on dialogue quality
|
||||
dialogue.ethical_consistency_score = assessment["scores"]["consistency"]
|
||||
dialogue.perspective_flexibility_score = assessment["scores"]["flexibility"]
|
||||
dialogue.framework_mastery_score = assessment["scores"]["framework_mastery"]
|
||||
dialogue.synthesis_quality_score = assessment["scores"]["synthesis"]
|
||||
|
||||
# Update user analytics
|
||||
await self._update_analytics_after_dialogue(dialogue, assessment)
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(dialogue)
|
||||
|
||||
return {
|
||||
"final_assessment": assessment,
|
||||
"skill_development": assessment["skill_changes"],
|
||||
"recommended_next_topics": assessment["recommendations"]
|
||||
}
|
||||
|
||||
def _get_dilemma_scenario(self, dilemma_type: str, topic: str) -> Dict[str, Any]:
|
||||
"""Get specific dilemma scenario"""
|
||||
scenarios = {
|
||||
"ethical_frameworks": {
|
||||
"trolley_problem": {
|
||||
"title": "The Trolley Problem",
|
||||
"description": "A runaway trolley is heading towards five people tied to the tracks. You can pull a lever to divert it to another track, where it will kill one person instead. Do you pull the lever?",
|
||||
"frameworks": ["utilitarianism", "deontological", "virtue_ethics", "care_ethics"],
|
||||
"complexity": "intermediate"
|
||||
}
|
||||
},
|
||||
"game_theory": {
|
||||
"prisoners_dilemma": {
|
||||
"title": "The Prisoner's Dilemma",
|
||||
"description": "You and your partner are arrested and held separately. You can either confess (defect) or remain silent (cooperate). The outcomes depend on both choices.",
|
||||
"frameworks": ["rational_choice", "social_contract", "evolutionary_ethics"],
|
||||
"complexity": "intermediate"
|
||||
}
|
||||
},
|
||||
"ai_consciousness": {
|
||||
"chinese_room": {
|
||||
"title": "The Chinese Room Argument",
|
||||
"description": "Is a computer program that can perfectly simulate understanding Chinese actually understanding Chinese? What does this mean for AI consciousness?",
|
||||
"frameworks": ["functionalism", "behaviorism", "phenomenology", "computational_theory"],
|
||||
"complexity": "advanced"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return scenarios.get(dilemma_type, {}).get(topic, scenarios["ethical_frameworks"]["trolley_problem"])
|
||||
|
||||
def _generate_ai_response(self, dialogue: PhilosophicalDialogue, user_response: str, framework: str = None) -> Dict[str, Any]:
|
||||
"""Generate AI response using Socratic method"""
|
||||
# Mock AI response generation
|
||||
response_types = ["socratic_question", "framework_challenge", "perspective_shift", "synthesis_prompt"]
|
||||
response_type = random.choice(response_types)
|
||||
|
||||
responses = {
|
||||
"socratic_question": {
|
||||
"content": "That's an interesting perspective. What underlying assumptions are you making about the value of individual lives versus collective outcomes?",
|
||||
"questions": ["How do you weigh individual rights against collective welfare?", "What if the numbers were different?"],
|
||||
"type": "questioning"
|
||||
},
|
||||
"framework_challenge": {
|
||||
"content": "Your utilitarian approach focuses on outcomes. How might a deontologist view this same situation?",
|
||||
"suggested_frameworks": ["deontological", "virtue_ethics"],
|
||||
"type": "framework_exploration"
|
||||
},
|
||||
"perspective_shift": {
|
||||
"content": "Consider this from the perspective of each person involved. How might their consent or agency factor into your decision?",
|
||||
"questions": ["Does intent matter as much as outcome?", "What about the rights of those who cannot consent?"],
|
||||
"type": "perspective_expansion"
|
||||
},
|
||||
"synthesis_prompt": {
|
||||
"content": "You've explored multiple frameworks. Can you synthesize these different approaches into a coherent position?",
|
||||
"type": "synthesis"
|
||||
}
|
||||
}
|
||||
|
||||
return responses[response_type]
|
||||
|
||||
def _assess_dialogue_depth(self, dialogue: PhilosophicalDialogue) -> Dict[str, Any]:
|
||||
"""Assess the depth and quality of the dialogue"""
|
||||
return {
|
||||
"conceptual_depth": min(dialogue.exchange_count / 5, 1.0),
|
||||
"framework_breadth": len(dialogue.frameworks_explored) / 4,
|
||||
"synthesis_attempts": dialogue.synthesis_attempts,
|
||||
"perspective_shifts": dialogue.perspective_shifts
|
||||
}
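    # Worked example (illustrative): after 3 exchanges exploring 2 of the 4
    # offered frameworks, the assessment above reports
    # conceptual_depth = min(3/5, 1.0) = 0.6 and framework_breadth = 2/4 = 0.5,
    # alongside the raw synthesis_attempts and perspective_shifts counters.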
|
||||
|
||||
def _generate_final_assessment(self, dialogue: PhilosophicalDialogue, final_position: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate comprehensive final assessment"""
|
||||
return {
|
||||
"insights": [
|
||||
"Demonstrated understanding of utilitarian reasoning",
|
||||
"Showed ability to consider multiple perspectives",
|
||||
"Struggled with synthesizing competing frameworks"
|
||||
],
|
||||
"scores": {
|
||||
"consistency": random.uniform(0.6, 0.9),
|
||||
"flexibility": random.uniform(0.5, 0.8),
|
||||
"framework_mastery": random.uniform(0.4, 0.85),
|
||||
"synthesis": random.uniform(0.3, 0.7)
|
||||
},
|
||||
"skill_changes": {
|
||||
"ethical_reasoning": "+5%",
|
||||
"perspective_taking": "+3%",
|
||||
"logical_consistency": "+2%"
|
||||
},
|
||||
"recommendations": [
|
||||
"Explore virtue ethics in more depth",
|
||||
"Practice synthesizing competing moral intuitions",
|
||||
"Consider real-world applications of ethical frameworks"
|
||||
],
|
||||
"ai_evaluation": {
|
||||
"dialogue_quality": "Good engagement with multiple perspectives",
|
||||
"growth_areas": "Framework synthesis and practical application",
|
||||
"strengths": "Clear reasoning and willingness to explore"
|
||||
}
|
||||
}
|
||||
|
||||
async def _update_analytics_after_dialogue(self, dialogue: PhilosophicalDialogue, assessment: Dict[str, Any]):
|
||||
"""Update user analytics after dialogue completion"""
|
||||
analytics = await self._get_analytics(dialogue.user_id)
|
||||
|
||||
# Update philosophical reasoning skills
|
||||
analytics.ethical_reasoning_score = self._update_skill_score(
|
||||
analytics.ethical_reasoning_score,
|
||||
assessment["scores"]["consistency"]
|
||||
)
|
||||
|
||||
analytics.philosophical_depth_level = min(
|
||||
analytics.philosophical_depth_level + 0.1,
|
||||
10
|
||||
)
|
||||
|
||||
analytics.total_sessions += 1
|
||||
analytics.last_activity_date = datetime.utcnow()
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
def _get_recommended_topics(self, analytics: LearningAnalytics) -> List[str]:
|
||||
"""Get recommended topics based on user progress"""
|
||||
if analytics.philosophical_depth_level < 3:
|
||||
return ["trolley_problem", "prisoners_dilemma"]
|
||||
elif analytics.philosophical_depth_level < 6:
|
||||
return ["virtue_ethics_scenarios", "tragedy_of_commons", "ai_rights"]
|
||||
else:
|
||||
return ["chinese_room", "consciousness_criteria", "coordination_games"]
|
||||
|
||||
async def _get_analytics(self, user_id: str) -> LearningAnalytics:
|
||||
"""Get user analytics"""
|
||||
query = select(LearningAnalytics).where(LearningAnalytics.user_id == user_id)
|
||||
result = await self.db.execute(query)
|
||||
analytics = result.scalar_one_or_none()
|
||||
|
||||
if not analytics:
|
||||
analytics = LearningAnalytics(user_id=user_id)
|
||||
self.db.add(analytics)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(analytics)
|
||||
|
||||
return analytics
|
||||
|
||||
def _update_skill_score(self, current_score: float, performance: float) -> float:
|
||||
"""Update skill score using exponential moving average"""
|
||||
alpha = 0.1
|
||||
return current_score * (1 - alpha) + performance * 100 * alpha
|
||||
427
apps/tenant-backend/app/services/mcp_integration.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
Model Context Protocol (MCP) Integration Service
|
||||
Enables extensible tool integration for GT 2.0 agents
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
import httpx
|
||||
import asyncio
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from app.models.agent import Agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MCPTool(BaseModel):
|
||||
"""MCP Tool definition"""
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
returns: Dict[str, Any] = Field(default_factory=dict)
|
||||
endpoint: str
|
||||
requires_auth: bool = False
|
||||
rate_limit: Optional[int] = None # requests per minute
|
||||
timeout: int = 30 # seconds
|
||||
enabled: bool = True
|
||||
|
||||
class MCPServer(BaseModel):
|
||||
"""MCP Server configuration"""
|
||||
id: str
|
||||
name: str
|
||||
base_url: str
|
||||
api_key: Optional[str] = None
|
||||
tools: List[MCPTool] = Field(default_factory=list)
|
||||
health_check_endpoint: str = "/health"
|
||||
tools_endpoint: str = "/tools"
|
||||
timeout: int = 30
|
||||
max_retries: int = 3
|
||||
enabled: bool = True
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_health_check: Optional[datetime] = None
|
||||
health_status: str = "unknown" # healthy, unhealthy, unknown
|
||||
|
||||
class MCPExecutionResult(BaseModel):
|
||||
"""Result of tool execution"""
|
||||
success: bool
|
||||
data: Dict[str, Any] = Field(default_factory=dict)
|
||||
error: Optional[str] = None
|
||||
execution_time_ms: int
|
||||
tokens_used: Optional[int] = None
|
||||
cost_cents: Optional[int] = None
|
||||
|
||||
class MCPIntegrationService:
|
||||
"""Service for managing MCP integrations"""
|
||||
|
||||
def __init__(self):
|
||||
self.servers: Dict[str, MCPServer] = {}
|
||||
self.rate_limits: Dict[str, Dict[str, List[datetime]]] = {} # server_id -> tool_name -> timestamps
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(30.0),
|
||||
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100)
|
||||
)
|
||||
|
||||
async def register_server(self, server: MCPServer) -> bool:
|
||||
"""Register a new MCP server"""
|
||||
try:
|
||||
# Validate server is reachable
|
||||
if await self.health_check(server):
|
||||
# Discover available tools
|
||||
tools = await self.discover_tools(server)
|
||||
server.tools = tools
|
||||
server.health_status = "healthy"
|
||||
server.last_health_check = datetime.utcnow()
|
||||
|
||||
self.servers[server.id] = server
|
||||
self.rate_limits[server.id] = {}
|
||||
|
||||
logger.info(f"Registered MCP server: {server.name} with {len(tools)} tools")
|
||||
return True
|
||||
else:
|
||||
server.health_status = "unhealthy"
|
||||
logger.error(f"Failed to register MCP server: {server.name} - health check failed")
|
||||
return False
|
||||
except Exception as e:
|
||||
server.health_status = "unhealthy"
|
||||
logger.error(f"Failed to register MCP server: {server.name} - {str(e)}")
|
||||
return False
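    # Illustrative usage sketch (the server id, name, and URL are hypothetical,
    # not part of the original code). A failed health check leaves the server
    # unregistered and marks it unhealthy.
    #
    #     mcp = MCPIntegrationService()
    #     ok = await mcp.register_server(MCPServer(
    #         id="search-tools",
    #         name="Search Tools",
    #         base_url="https://mcp-search.internal.example",
    #         api_key=None,
    #     ))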
|
||||
|
||||
async def health_check(self, server: MCPServer) -> bool:
|
||||
"""Check if MCP server is healthy"""
|
||||
try:
|
||||
headers = {}
|
||||
if server.api_key:
|
||||
headers["Authorization"] = f"Bearer {server.api_key}"
|
||||
|
||||
response = await self.client.get(
|
||||
f"{server.base_url.rstrip('/')}{server.health_check_endpoint}",
|
||||
headers=headers,
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
is_healthy = response.status_code == 200
|
||||
|
||||
if server.id in self.servers:
|
||||
self.servers[server.id].health_status = "healthy" if is_healthy else "unhealthy"
|
||||
self.servers[server.id].last_health_check = datetime.utcnow()
|
||||
|
||||
return is_healthy
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for {server.name}: {e}")
|
||||
if server.id in self.servers:
|
||||
self.servers[server.id].health_status = "unhealthy"
|
||||
return False
|
||||
|
||||
async def discover_tools(self, server: MCPServer) -> List[MCPTool]:
|
||||
"""Discover available tools from MCP server"""
|
||||
try:
|
||||
headers = {}
|
||||
if server.api_key:
|
||||
headers["Authorization"] = f"Bearer {server.api_key}"
|
||||
|
||||
response = await self.client.get(
|
||||
f"{server.base_url.rstrip('/')}{server.tools_endpoint}",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
tools_data = response.json()
|
||||
tools = []
|
||||
|
||||
for tool_data in tools_data.get("tools", []):
|
||||
try:
|
||||
tool = MCPTool(**tool_data)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid tool definition from {server.name}: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool discovery failed for {server.name}: {e}")
|
||||
return []
|
||||
|
||||
def _check_rate_limit(self, server_id: str, tool_name: str) -> bool:
|
||||
"""Check if rate limit allows execution"""
|
||||
server = self.servers.get(server_id)
|
||||
if not server:
|
||||
return False
|
||||
|
||||
tool = next((t for t in server.tools if t.name == tool_name), None)
|
||||
if not tool or not tool.rate_limit:
|
||||
return True
|
||||
|
||||
now = datetime.utcnow()
|
||||
minute_ago = now.timestamp() - 60
|
||||
|
||||
# Initialize tracking if needed
|
||||
if server_id not in self.rate_limits:
|
||||
self.rate_limits[server_id] = {}
|
||||
if tool_name not in self.rate_limits[server_id]:
|
||||
self.rate_limits[server_id][tool_name] = []
|
||||
|
||||
# Clean old timestamps
|
||||
timestamps = self.rate_limits[server_id][tool_name]
|
||||
self.rate_limits[server_id][tool_name] = [
|
||||
ts for ts in timestamps if ts.timestamp() > minute_ago
|
||||
]
|
||||
|
||||
# Check limit
|
||||
current_count = len(self.rate_limits[server_id][tool_name])
|
||||
if current_count >= tool.rate_limit:
|
||||
return False
|
||||
|
||||
# Record this request
|
||||
self.rate_limits[server_id][tool_name].append(now)
|
||||
return True
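# Illustrative example of the sliding-window check above (not part of the service):
# with rate_limit=3 and successful calls recorded at t=0s, 10s and 20s, a fourth
# call at t=30s is rejected because three timestamps still fall inside the last
# 60 seconds; a call at t=65s succeeds because the t=0s entry has been pruned.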
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
server_id: str,
|
||||
tool_name: str,
|
||||
parameters: Dict[str, Any],
|
||||
assistant_context: Optional[Dict[str, Any]] = None,
|
||||
user_context: Optional[Dict[str, Any]] = None
|
||||
) -> MCPExecutionResult:
|
||||
"""Execute a tool on an MCP server"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
# Validate server exists and is enabled
|
||||
if server_id not in self.servers:
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"Server {server_id} not registered",
|
||||
execution_time_ms=0
|
||||
)
|
||||
|
||||
server = self.servers[server_id]
|
||||
if not server.enabled:
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"Server {server_id} is disabled",
|
||||
execution_time_ms=0
|
||||
)
|
||||
|
||||
# Find tool
|
||||
tool = next((t for t in server.tools if t.name == tool_name), None)
|
||||
if not tool:
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool {tool_name} not found on server {server_id}",
|
||||
execution_time_ms=0
|
||||
)
|
||||
|
||||
if not tool.enabled:
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool {tool_name} is disabled",
|
||||
execution_time_ms=0
|
||||
)
|
||||
|
||||
# Check rate limit
|
||||
if not self._check_rate_limit(server_id, tool_name):
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"Rate limit exceeded for tool {tool_name}",
|
||||
execution_time_ms=0
|
||||
)
|
||||
|
||||
# Prepare request
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if server.api_key:
|
||||
headers["Authorization"] = f"Bearer {server.api_key}"
|
||||
|
||||
payload = {
|
||||
"tool": tool_name,
|
||||
"parameters": parameters
|
||||
}
|
||||
|
||||
# Add context if provided
|
||||
if assistant_context:
|
||||
payload["assistant_context"] = assistant_context
|
||||
if user_context:
|
||||
payload["user_context"] = user_context
|
||||
|
||||
# Execute with retries
|
||||
last_exception = None
|
||||
for attempt in range(server.max_retries):
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{server.base_url.rstrip('/')}{tool.endpoint}",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=tool.timeout
|
||||
)
|
||||
|
||||
execution_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
|
||||
|
||||
if response.status_code == 200:
|
||||
result_data = response.json()
|
||||
return MCPExecutionResult(
|
||||
success=True,
|
||||
data=result_data,
|
||||
execution_time_ms=execution_time,
|
||||
tokens_used=result_data.get("tokens_used"),
|
||||
cost_cents=result_data.get("cost_cents")
|
||||
)
|
||||
else:
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"HTTP {response.status_code}: {response.text}",
|
||||
execution_time_ms=execution_time
|
||||
)
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
last_exception = e
|
||||
if attempt == server.max_retries - 1:
|
||||
break
|
||||
await asyncio.sleep(2 ** attempt) # Exponential backoff
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt == server.max_retries - 1:
|
||||
break
|
||||
await asyncio.sleep(2 ** attempt)
|
||||
|
||||
# All retries failed
|
||||
execution_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool execution failed after {server.max_retries} attempts: {str(last_exception)}",
|
||||
execution_time_ms=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
|
||||
return MCPExecutionResult(
|
||||
success=False,
|
||||
error=f"Unexpected error: {str(e)}",
|
||||
execution_time_ms=execution_time
|
||||
)
|
||||
|
||||
async def get_tools_for_assistant(
|
||||
self,
|
||||
agent: Agent
|
||||
) -> List[MCPTool]:
|
||||
"""Get all available tools for an agent based on integrations"""
|
||||
tools = []
|
||||
for integration_id in agent.mcp_integration_ids:
|
||||
if integration_id in self.servers:
|
||||
server = self.servers[integration_id]
|
||||
if server.enabled:
|
||||
# Filter by enabled tools only
|
||||
tools.extend([tool for tool in server.tools if tool.enabled])
|
||||
return tools
|
||||
|
||||
def format_tools_for_llm(self, tools: List[MCPTool]) -> List[Dict[str, Any]]:
|
||||
"""Format tools for LLM function calling (OpenAI format)"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters
|
||||
}
|
||||
}
|
||||
for tool in tools
|
||||
]
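# Example of the structure produced above for a single tool (illustrative values):
# [
#   {
#     "type": "function",
#     "function": {
#       "name": "web_search",
#       "description": "Search the public web",
#       "parameters": {"type": "object", "properties": {"query": {"type": "string"}}}
#     }
#   }
# ]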
|
||||
|
||||
async def get_server_status(self, server_id: str) -> Dict[str, Any]:
|
||||
"""Get detailed status of an MCP server"""
|
||||
if server_id not in self.servers:
|
||||
return {"error": "Server not found"}
|
||||
|
||||
server = self.servers[server_id]
|
||||
|
||||
# Perform health check
|
||||
await self.health_check(server)
|
||||
|
||||
return {
|
||||
"id": server.id,
|
||||
"name": server.name,
|
||||
"base_url": server.base_url,
|
||||
"enabled": server.enabled,
|
||||
"health_status": server.health_status,
|
||||
"last_health_check": server.last_health_check.isoformat() if server.last_health_check else None,
|
||||
"tools_count": len(server.tools),
|
||||
"enabled_tools_count": sum(1 for tool in server.tools if tool.enabled),
|
||||
"tools": [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"enabled": tool.enabled,
|
||||
"rate_limit": tool.rate_limit
|
||||
}
|
||||
for tool in server.tools
|
||||
]
|
||||
}
|
||||
|
||||
async def list_servers(self) -> List[Dict[str, Any]]:
|
||||
"""List all registered MCP servers"""
|
||||
servers = []
|
||||
for server_id, server in self.servers.items():
|
||||
servers.append(await self.get_server_status(server_id))
|
||||
return servers
|
||||
|
||||
async def remove_server(self, server_id: str) -> bool:
|
||||
"""Remove an MCP server"""
|
||||
if server_id in self.servers:
|
||||
del self.servers[server_id]
|
||||
if server_id in self.rate_limits:
|
||||
del self.rate_limits[server_id]
|
||||
logger.info(f"Removed MCP server: {server_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def update_server_config(self, server_id: str, config: Dict[str, Any]) -> bool:
|
||||
"""Update MCP server configuration"""
|
||||
if server_id not in self.servers:
|
||||
return False
|
||||
|
||||
server = self.servers[server_id]
|
||||
|
||||
# Update allowed fields
|
||||
if "enabled" in config:
|
||||
server.enabled = config["enabled"]
|
||||
if "api_key" in config:
|
||||
server.api_key = config["api_key"]
|
||||
if "timeout" in config:
|
||||
server.timeout = config["timeout"]
|
||||
if "max_retries" in config:
|
||||
server.max_retries = config["max_retries"]
|
||||
|
||||
# Rediscover tools if server was re-enabled or config changed
|
||||
if server.enabled:
|
||||
tools = await self.discover_tools(server)
|
||||
server.tools = tools
|
||||
|
||||
logger.info(f"Updated MCP server config: {server_id}")
|
||||
return True
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP client"""
|
||||
await self.client.aclose()
|
||||
|
||||
# Singleton instance
|
||||
mcp_service = MCPIntegrationService()
|
||||
|
||||
# Default MCP servers for GT 2.0 (can be configured via admin)
|
||||
DEFAULT_MCP_SERVERS = [
|
||||
{
|
||||
"id": "gt2_core_tools",
|
||||
"name": "GT 2.0 Core Tools",
|
||||
"base_url": "http://localhost:8003", # Internal tools server
|
||||
"tools_endpoint": "/mcp/tools",
|
||||
"health_check_endpoint": "/mcp/health"
|
||||
},
|
||||
{
|
||||
"id": "web_search",
|
||||
"name": "Web Search Tools",
|
||||
"base_url": "http://localhost:8004",
|
||||
"api_key": None # Configured via admin
|
||||
}
|
||||
]
|
||||
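A minimal usage sketch for the service above, assuming an MCP server is reachable at the configured base_url and exposes the /health and /tools endpoints; the tool name and parameters below are illustrative.

    import asyncio
    from app.services.mcp_integration import MCPServer, mcp_service

    async def demo():
        server = MCPServer(id="web_search", name="Web Search Tools", base_url="http://localhost:8004")
        if await mcp_service.register_server(server):
            result = await mcp_service.execute_tool(
                server_id="web_search",
                tool_name="search",          # hypothetical tool discovered from /tools
                parameters={"query": "GT 2.0 release notes"},
            )
            print(result.success, result.execution_time_ms, result.error)
        await mcp_service.close()

    asyncio.run(demo())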
371
apps/tenant-backend/app/services/message_bus_client.py
Normal file
@@ -0,0 +1,371 @@
"""
|
||||
Message Bus Client for GT 2.0 Tenant Backend
|
||||
|
||||
Handles communication with Admin Cluster via RabbitMQ message queues.
|
||||
Processes commands from admin and sends responses back.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import uuid
|
||||
import hmac
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Callable, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import aio_pika
|
||||
from aio_pika import connect, Message, DeliveryMode, ExchangeType
|
||||
from aio_pika.abc import AbstractIncomingMessage
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AdminCommand(BaseModel):
|
||||
"""Command received from Admin Cluster"""
|
||||
command_id: str
|
||||
command_type: str # TENANT_PROVISION, TENANT_SUSPEND, etc.
|
||||
target_cluster: str
|
||||
target_namespace: str
|
||||
payload: Dict[str, Any]
|
||||
timestamp: str
|
||||
signature: str
|
||||
|
||||
class TenantResponse(BaseModel):
|
||||
"""Response sent to Admin Cluster"""
|
||||
command_id: str
|
||||
response_type: str # SUCCESS, ERROR, PROCESSING
|
||||
target_cluster: str = "admin"
|
||||
source_cluster: str = "tenant"
|
||||
namespace: str
|
||||
payload: Dict[str, Any] = Field(default_factory=dict)
|
||||
timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
|
||||
signature: Optional[str] = None
|
||||
|
||||
class MessageBusClient:
|
||||
"""
|
||||
Client for RabbitMQ message bus communication with Admin Cluster.
|
||||
|
||||
GT 2.0 Security Principles:
|
||||
- HMAC signature verification for all messages
|
||||
- Tenant-scoped command processing
|
||||
- No cross-tenant message leakage
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.connection = None
|
||||
self.channel = None
|
||||
self.admin_to_tenant_queue = None
|
||||
self.tenant_to_admin_queue = None
|
||||
|
||||
# Message handlers
|
||||
self.command_handlers: Dict[str, Callable[[AdminCommand], Any]] = {}
|
||||
|
||||
# RabbitMQ configuration from admin specification
|
||||
self.rabbitmq_url = getattr(
|
||||
self.settings,
|
||||
'RABBITMQ_URL',
|
||||
'amqp://gt2_admin:dev_password_change_in_prod@rabbitmq:5672/'
|
||||
)
|
||||
|
||||
# Security
|
||||
self.secret_key = getattr(self.settings, 'SECRET_KEY', 'production-secret-key')
|
||||
|
||||
logger.info("Message bus client initialized")
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to RabbitMQ message bus"""
|
||||
try:
|
||||
self.connection = await connect(self.rabbitmq_url)
|
||||
self.channel = await self.connection.channel()
|
||||
|
||||
# Declare exchanges and queues matching admin specification
|
||||
await self.channel.declare_exchange(
|
||||
'gt2_commands', ExchangeType.DIRECT, durable=True
|
||||
)
|
||||
|
||||
# Admin → Tenant command queue
|
||||
self.admin_to_tenant_queue = await self.channel.declare_queue(
|
||||
'admin_to_tenant', durable=True
|
||||
)
|
||||
|
||||
# Tenant → Admin response queue
|
||||
self.tenant_to_admin_queue = await self.channel.declare_queue(
|
||||
'tenant_to_admin', durable=True
|
||||
)
|
||||
|
||||
# Start consuming commands
|
||||
await self.admin_to_tenant_queue.consume(self._handle_admin_command)
|
||||
|
||||
logger.info("Connected to RabbitMQ message bus")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to message bus: {e}")
|
||||
return False
|
||||
|
||||
def _verify_signature(self, message_data: Dict[str, Any], signature: str) -> bool:
|
||||
"""Verify HMAC signature of incoming message"""
|
||||
try:
|
||||
# Create message content for signature (excluding signature field)
|
||||
content = {k: v for k, v in message_data.items() if k != 'signature'}
|
||||
content_str = json.dumps(content, sort_keys=True)
|
||||
|
||||
# Calculate expected signature
|
||||
expected_signature = hmac.new(
|
||||
self.secret_key.encode(),
|
||||
content_str.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
return hmac.compare_digest(expected_signature, signature)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Signature verification failed: {e}")
|
||||
return False
|
||||
|
||||
def _sign_message(self, message_data: Dict[str, Any]) -> str:
|
||||
"""Create HMAC signature for outgoing message"""
|
||||
try:
|
||||
content_str = json.dumps(message_data, sort_keys=True)
|
||||
return hmac.new(
|
||||
self.secret_key.encode(),
|
||||
content_str.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
except Exception as e:
|
||||
logger.error(f"Message signing failed: {e}")
|
||||
return ""
|
||||
|
||||
async def _handle_admin_command(self, message: AbstractIncomingMessage) -> None:
|
||||
"""Handle incoming command from Admin Cluster"""
|
||||
try:
|
||||
# Parse message
|
||||
message_data = json.loads(message.body.decode())
|
||||
command = AdminCommand(**message_data)
|
||||
|
||||
logger.info(f"Received admin command: {command.command_type} ({command.command_id})")
|
||||
|
||||
# Verify signature
|
||||
if not self._verify_signature(message_data, command.signature):
|
||||
logger.error(f"Invalid signature for command {command.command_id}")
|
||||
await self._send_response(command.command_id, "ERROR", {
|
||||
"error": "Invalid signature",
|
||||
"namespace": command.target_namespace
|
||||
})
|
||||
return
|
||||
|
||||
# Check if we have a handler for this command type
|
||||
if command.command_type not in self.command_handlers:
|
||||
logger.warning(f"No handler for command type: {command.command_type}")
|
||||
await self._send_response(command.command_id, "ERROR", {
|
||||
"error": f"Unknown command type: {command.command_type}",
|
||||
"namespace": command.target_namespace
|
||||
})
|
||||
return
|
||||
|
||||
# Execute command handler
|
||||
handler = self.command_handlers[command.command_type]
|
||||
result = await handler(command)
|
||||
|
||||
# Send success response
|
||||
await self._send_response(command.command_id, "SUCCESS", {
|
||||
"result": result,
|
||||
"namespace": command.target_namespace
|
||||
})
|
||||
|
||||
# Acknowledge message
|
||||
await message.ack()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling admin command: {e}")
|
||||
try:
|
||||
if 'command' in locals():
|
||||
await self._send_response(command.command_id, "ERROR", {
|
||||
"error": str(e),
|
||||
"namespace": getattr(command, 'target_namespace', 'unknown')
|
||||
})
|
||||
await message.nack(requeue=False)
except Exception:
pass
|
||||
|
||||
async def _send_response(self, command_id: str, response_type: str, payload: Dict[str, Any]) -> None:
|
||||
"""Send response back to Admin Cluster"""
|
||||
try:
|
||||
response_data = {
|
||||
"command_id": command_id,
|
||||
"response_type": response_type,
|
||||
"target_cluster": "admin",
|
||||
"source_cluster": "tenant",
|
||||
"namespace": payload.get("namespace", "unknown"),
|
||||
"payload": payload,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Sign the response
|
||||
response_data["signature"] = self._sign_message(response_data)
|
||||
|
||||
# Send message
|
||||
message = Message(
|
||||
json.dumps(response_data).encode(),
|
||||
delivery_mode=DeliveryMode.PERSISTENT
|
||||
)
|
||||
|
||||
await self.channel.default_exchange.publish(
|
||||
message, routing_key=self.tenant_to_admin_queue.name
|
||||
)
|
||||
|
||||
logger.info(f"Sent response to admin: {response_type} for {command_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send response to admin: {e}")
|
||||
|
||||
def register_handler(self, command_type: str, handler: Callable[[AdminCommand], Any]) -> None:
|
||||
"""Register handler for specific command type"""
|
||||
self.command_handlers[command_type] = handler
|
||||
logger.info(f"Registered handler for command type: {command_type}")
|
||||
|
||||
async def send_notification(self, notification_type: str, payload: Dict[str, Any]) -> None:
|
||||
"""Send notification to Admin Cluster (not in response to a command)"""
|
||||
try:
|
||||
notification_data = {
|
||||
"command_id": str(uuid.uuid4()),
|
||||
"response_type": f"NOTIFICATION_{notification_type}",
|
||||
"target_cluster": "admin",
|
||||
"source_cluster": "tenant",
|
||||
"namespace": getattr(self.settings, 'TENANT_ID', 'unknown'),
|
||||
"payload": payload,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Sign the notification
|
||||
notification_data["signature"] = self._sign_message(notification_data)
|
||||
|
||||
# Send message
|
||||
message = Message(
|
||||
json.dumps(notification_data).encode(),
|
||||
delivery_mode=DeliveryMode.PERSISTENT
|
||||
)
|
||||
|
||||
await self.channel.default_exchange.publish(
|
||||
message, routing_key=self.tenant_to_admin_queue.name
|
||||
)
|
||||
|
||||
logger.info(f"Sent notification to admin: {notification_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send notification to admin: {e}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from message bus"""
|
||||
try:
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
logger.info("Disconnected from message bus")
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting from message bus: {e}")
|
||||
|
||||
# Global instance
|
||||
message_bus_client = MessageBusClient()
|
||||
|
||||
# Command handlers for different admin commands
|
||||
async def handle_tenant_provision(command: AdminCommand) -> Dict[str, Any]:
|
||||
"""Handle tenant provisioning command"""
|
||||
logger.info(f"Processing tenant provision for: {command.target_namespace}")
|
||||
|
||||
# TODO: Implement tenant provisioning logic
|
||||
# - Create tenant directory structure
|
||||
# - Initialize SQLite database
|
||||
# - Set up access controls
|
||||
# - Configure resource quotas
|
||||
|
||||
return {
|
||||
"status": "provisioned",
|
||||
"namespace": command.target_namespace,
|
||||
"resources_allocated": command.payload.get("resources", {})
|
||||
}
|
||||
|
||||
async def handle_tenant_suspend(command: AdminCommand) -> Dict[str, Any]:
|
||||
"""Handle tenant suspension command"""
|
||||
logger.info(f"Processing tenant suspension for: {command.target_namespace}")
|
||||
|
||||
# TODO: Implement tenant suspension logic
|
||||
# - Disable API access
|
||||
# - Pause running processes
|
||||
# - Preserve data integrity
|
||||
|
||||
return {
|
||||
"status": "suspended",
|
||||
"namespace": command.target_namespace
|
||||
}
|
||||
|
||||
async def handle_tenant_activate(command: AdminCommand) -> Dict[str, Any]:
|
||||
"""Handle tenant activation command"""
|
||||
logger.info(f"Processing tenant activation for: {command.target_namespace}")
|
||||
|
||||
# TODO: Implement tenant activation logic
|
||||
# - Restore API access
|
||||
# - Resume processes
|
||||
# - Validate system state
|
||||
|
||||
return {
|
||||
"status": "activated",
|
||||
"namespace": command.target_namespace
|
||||
}
|
||||
|
||||
async def handle_resource_allocate(command: AdminCommand) -> Dict[str, Any]:
|
||||
"""Handle resource allocation command"""
|
||||
logger.info(f"Processing resource allocation for: {command.target_namespace}")
|
||||
|
||||
# TODO: Implement resource allocation logic
|
||||
# - Update resource quotas
|
||||
# - Configure rate limits
|
||||
# - Enable new capabilities
|
||||
|
||||
return {
|
||||
"status": "allocated",
|
||||
"namespace": command.target_namespace,
|
||||
"resources": command.payload.get("resources", {})
|
||||
}
|
||||
|
||||
async def handle_resource_revoke(command: AdminCommand) -> Dict[str, Any]:
|
||||
"""Handle resource revocation command"""
|
||||
logger.info(f"Processing resource revocation for: {command.target_namespace}")
|
||||
|
||||
# TODO: Implement resource revocation logic
|
||||
# - Remove resource access
|
||||
# - Update quotas
|
||||
# - Gracefully handle running operations
|
||||
|
||||
return {
|
||||
"status": "revoked",
|
||||
"namespace": command.target_namespace
|
||||
}
|
||||
|
||||
# Register all command handlers
|
||||
async def initialize_message_bus() -> bool:
|
||||
"""Initialize message bus with all command handlers"""
|
||||
try:
|
||||
# Register command handlers
|
||||
message_bus_client.register_handler("TENANT_PROVISION", handle_tenant_provision)
|
||||
message_bus_client.register_handler("TENANT_SUSPEND", handle_tenant_suspend)
|
||||
message_bus_client.register_handler("TENANT_ACTIVATE", handle_tenant_activate)
|
||||
message_bus_client.register_handler("RESOURCE_ALLOCATE", handle_resource_allocate)
|
||||
message_bus_client.register_handler("RESOURCE_REVOKE", handle_resource_revoke)
|
||||
|
||||
# Connect to message bus
|
||||
connected = await message_bus_client.connect()
|
||||
|
||||
if connected:
|
||||
logger.info("Message bus client initialized successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to initialize message bus client")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing message bus: {e}")
|
||||
return False
|
||||
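Two small sketches for the module above: how it is wired in at startup, and a standalone check that the HMAC scheme used by _sign_message and _verify_signature round-trips. The startup hook is an assumption; only initialize_message_bus and the hashing scheme come from this module.

    # At application startup (e.g., a FastAPI startup/lifespan hook):
    #     if not await initialize_message_bus():
    #         logger.warning("Running without admin message bus")

    import hashlib, hmac, json

    secret = "production-secret-key"  # placeholder; the real key comes from settings
    payload = {"command_id": "abc", "response_type": "SUCCESS", "namespace": "tenant-a"}
    body = json.dumps(payload, sort_keys=True)
    signature = hmac.new(secret.encode(), body.encode(), hashlib.sha256).hexdigest()
    assert hmac.compare_digest(
        signature,
        hmac.new(secret.encode(), body.encode(), hashlib.sha256).hexdigest(),
    )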
347
apps/tenant-backend/app/services/optics_service.py
Normal file
@@ -0,0 +1,347 @@
"""
|
||||
Optics Cost Calculation Service
|
||||
|
||||
Calculates inference and storage costs for the Optics feature.
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage cost rate
|
||||
STORAGE_COST_PER_MB_CENTS = 4.0 # $0.04 per MB
|
||||
|
||||
# Fallback pricing for unknown models
|
||||
DEFAULT_MODEL_PRICING = {
|
||||
"cost_per_1k_input": 0.10,
|
||||
"cost_per_1k_output": 0.10
|
||||
}
|
||||
|
||||
|
||||
class OpticsPricingCache:
|
||||
"""Simple in-memory cache for model pricing"""
|
||||
_pricing: Optional[Dict[str, Any]] = None
|
||||
_expires_at: Optional[datetime] = None
|
||||
_ttl_seconds: int = 300 # 5 minutes
|
||||
|
||||
@classmethod
|
||||
def get(cls) -> Optional[Dict[str, Any]]:
|
||||
if cls._pricing and cls._expires_at and datetime.utcnow() < cls._expires_at:
|
||||
return cls._pricing
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def set(cls, pricing: Dict[str, Any]):
|
||||
cls._pricing = pricing
|
||||
cls._expires_at = datetime.utcnow() + timedelta(seconds=cls._ttl_seconds)
|
||||
|
||||
@classmethod
|
||||
def clear(cls):
|
||||
cls._pricing = None
|
||||
cls._expires_at = None
|
||||
|
||||
|
||||
async def fetch_optics_settings(tenant_domain: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch Optics settings from Control Panel for a tenant.
|
||||
|
||||
Returns:
|
||||
dict with 'enabled', 'storage_cost_per_mb_cents'
|
||||
"""
|
||||
settings = get_settings()
|
||||
control_panel_url = settings.control_panel_url or "http://gentwo-control-panel-backend:8001"
|
||||
service_token = settings.service_auth_token or "internal-service-token"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{control_panel_url}/internal/optics/tenant/{tenant_domain}/settings",
|
||||
headers={
|
||||
"X-Service-Auth": service_token,
|
||||
"X-Service-Name": "tenant-backend"
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
logger.warning(f"Tenant {tenant_domain} not found in Control Panel")
|
||||
return {"enabled": False, "storage_cost_per_mb_cents": STORAGE_COST_PER_MB_CENTS}
|
||||
else:
|
||||
logger.error(f"Failed to fetch optics settings: {response.status_code}")
|
||||
return {"enabled": False, "storage_cost_per_mb_cents": STORAGE_COST_PER_MB_CENTS}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching optics settings: {str(e)}")
|
||||
# Default to disabled on error
|
||||
return {"enabled": False, "storage_cost_per_mb_cents": STORAGE_COST_PER_MB_CENTS}
|
||||
|
||||
|
||||
async def fetch_model_pricing() -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Fetch model pricing from Control Panel.
|
||||
Uses caching to avoid repeated requests.
|
||||
|
||||
Returns:
|
||||
dict mapping model_id -> {cost_per_1k_input, cost_per_1k_output}
|
||||
"""
|
||||
# Check cache first
|
||||
cached = OpticsPricingCache.get()
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
settings = get_settings()
|
||||
control_panel_url = settings.control_panel_url or "http://gentwo-control-panel-backend:8001"
|
||||
service_token = settings.service_auth_token or "internal-service-token"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{control_panel_url}/internal/optics/model-pricing",
|
||||
headers={
|
||||
"X-Service-Auth": service_token,
|
||||
"X-Service-Name": "tenant-backend"
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
pricing = data.get("models", {})
|
||||
OpticsPricingCache.set(pricing)
|
||||
return pricing
|
||||
else:
|
||||
logger.error(f"Failed to fetch model pricing: {response.status_code}")
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model pricing: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_model_cost_per_1k(model_id: str, pricing_map: Dict[str, Dict[str, float]]) -> float:
|
||||
"""
|
||||
Get combined cost per 1k tokens for a model.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier (e.g., 'llama-3.3-70b-versatile')
|
||||
pricing_map: Map of model_id -> pricing info
|
||||
|
||||
Returns:
|
||||
Combined input + output cost per 1k tokens in dollars
|
||||
"""
|
||||
pricing = pricing_map.get(model_id)
|
||||
if pricing:
|
||||
return (pricing.get("cost_per_1k_input", 0.0) + pricing.get("cost_per_1k_output", 0.0))
|
||||
|
||||
# Try variations of the model ID
|
||||
# Sometimes model_used might have provider prefix like "groq:model-name"
|
||||
if ":" in model_id:
|
||||
model_name = model_id.split(":", 1)[1]
|
||||
pricing = pricing_map.get(model_name)
|
||||
if pricing:
|
||||
return (pricing.get("cost_per_1k_input", 0.0) + pricing.get("cost_per_1k_output", 0.0))
|
||||
|
||||
# Return default pricing
|
||||
return DEFAULT_MODEL_PRICING["cost_per_1k_input"] + DEFAULT_MODEL_PRICING["cost_per_1k_output"]
|
||||
|
||||
|
||||
def calculate_inference_cost_cents(tokens: int, cost_per_1k: float) -> float:
|
||||
"""
|
||||
Calculate inference cost in cents from token count.
|
||||
|
||||
Args:
|
||||
tokens: Total token count
|
||||
cost_per_1k: Cost per 1000 tokens in dollars
|
||||
|
||||
Returns:
|
||||
Cost in cents
|
||||
"""
|
||||
return (tokens / 1000) * cost_per_1k * 100
|
||||
|
||||
|
||||
def calculate_storage_cost_cents(total_mb: float, cost_per_mb_cents: float = STORAGE_COST_PER_MB_CENTS) -> float:
|
||||
"""
|
||||
Calculate storage cost in cents.
|
||||
|
||||
Args:
|
||||
total_mb: Total storage in megabytes
|
||||
cost_per_mb_cents: Cost per MB in cents (default 4 cents = $0.04)
|
||||
|
||||
Returns:
|
||||
Cost in cents
|
||||
"""
|
||||
return total_mb * cost_per_mb_cents
|
||||
|
||||
|
||||
def format_cost_display(cents: float) -> str:
|
||||
"""Format cost in cents to a display string like '$12.34'"""
|
||||
dollars = cents / 100
|
||||
return f"${dollars:,.2f}"
|
||||
|
||||
|
||||
async def get_optics_cost_summary(
|
||||
pg_client,
|
||||
tenant_domain: str,
|
||||
date_start: datetime,
|
||||
date_end: datetime,
|
||||
user_id: Optional[str] = None,
|
||||
include_user_breakdown: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate full Optics cost summary for a tenant.
|
||||
|
||||
Args:
|
||||
pg_client: PostgreSQL client
|
||||
tenant_domain: Tenant domain
|
||||
date_start: Start date for cost calculation
|
||||
date_end: End date for cost calculation
|
||||
user_id: Optional user ID filter
|
||||
include_user_breakdown: Whether to include per-user breakdown
|
||||
|
||||
Returns:
|
||||
Complete cost summary with breakdowns
|
||||
"""
|
||||
schema = f"tenant_{tenant_domain.replace('-', '_')}"
|
||||
|
||||
# Fetch model pricing
|
||||
pricing_map = await fetch_model_pricing()
|
||||
|
||||
# Build user filter
|
||||
user_filter = ""
|
||||
params = [date_start, date_end]
|
||||
param_idx = 3
|
||||
|
||||
if user_id:
|
||||
user_filter = f"AND c.user_id = ${param_idx}::uuid"
|
||||
params.append(user_id)
|
||||
param_idx += 1
|
||||
|
||||
# Query token usage by model
|
||||
token_query = f"""
|
||||
SELECT
|
||||
COALESCE(m.model_used, 'unknown') as model_id,
|
||||
COALESCE(SUM(m.token_count), 0) as total_tokens,
|
||||
COUNT(DISTINCT c.id) as conversations,
|
||||
COUNT(m.id) as messages
|
||||
FROM {schema}.messages m
|
||||
JOIN {schema}.conversations c ON m.conversation_id = c.id
|
||||
WHERE c.created_at >= $1 AND c.created_at <= $2
|
||||
AND m.model_used IS NOT NULL AND m.model_used != ''
|
||||
{user_filter}
|
||||
GROUP BY m.model_used
|
||||
ORDER BY total_tokens DESC
|
||||
"""
|
||||
|
||||
token_results = await pg_client.execute_query(token_query, *params)
|
||||
|
||||
# Calculate inference costs by model
|
||||
by_model = []
|
||||
total_inference_cents = 0.0
|
||||
total_tokens = 0
|
||||
|
||||
for row in token_results or []:
|
||||
model_id = row["model_id"]
|
||||
tokens = int(row["total_tokens"])
|
||||
total_tokens += tokens
|
||||
|
||||
cost_per_1k = get_model_cost_per_1k(model_id, pricing_map)
|
||||
cost_cents = calculate_inference_cost_cents(tokens, cost_per_1k)
|
||||
total_inference_cents += cost_cents
|
||||
|
||||
# Clean up model name for display
|
||||
model_name = model_id.split(":")[-1] if ":" in model_id else model_id
|
||||
|
||||
by_model.append({
|
||||
"model_id": model_id,
|
||||
"model_name": model_name,
|
||||
"tokens": tokens,
|
||||
"conversations": row["conversations"],
|
||||
"messages": row["messages"],
|
||||
"cost_cents": round(cost_cents, 2),
|
||||
"cost_display": format_cost_display(cost_cents)
|
||||
})
|
||||
|
||||
# Calculate percentages
|
||||
for item in by_model:
|
||||
item["percentage"] = round((item["cost_cents"] / total_inference_cents * 100) if total_inference_cents > 0 else 0, 1)
|
||||
|
||||
# Query storage
|
||||
storage_params = []
|
||||
storage_user_filter = ""
|
||||
if user_id:
|
||||
storage_user_filter = f"WHERE d.user_id = $1::uuid"
|
||||
storage_params.append(user_id)
|
||||
|
||||
storage_query = f"""
|
||||
SELECT
|
||||
COALESCE(SUM(d.file_size_bytes), 0) / 1048576.0 as total_mb,
|
||||
COUNT(d.id) as document_count,
|
||||
COUNT(DISTINCT d.dataset_id) as dataset_count
|
||||
FROM {schema}.documents d
|
||||
{storage_user_filter}
|
||||
"""
|
||||
|
||||
storage_result = await pg_client.execute_query(storage_query, *storage_params) if storage_params else await pg_client.execute_query(storage_query)
|
||||
storage_data = storage_result[0] if storage_result else {"total_mb": 0, "document_count": 0, "dataset_count": 0}
|
||||
|
||||
total_storage_mb = float(storage_data.get("total_mb", 0))
|
||||
storage_cost_cents = calculate_storage_cost_cents(total_storage_mb)
|
||||
|
||||
# Total cost
|
||||
total_cost_cents = total_inference_cents + storage_cost_cents
|
||||
|
||||
# User breakdown (admin only)
|
||||
by_user = []
|
||||
if include_user_breakdown:
|
||||
user_query = f"""
|
||||
SELECT
|
||||
c.user_id,
|
||||
u.email,
|
||||
COALESCE(SUM(m.token_count), 0) as tokens
|
||||
FROM {schema}.messages m
|
||||
JOIN {schema}.conversations c ON m.conversation_id = c.id
|
||||
JOIN {schema}.users u ON c.user_id = u.id
|
||||
WHERE c.created_at >= $1 AND c.created_at <= $2
|
||||
GROUP BY c.user_id, u.email
|
||||
ORDER BY tokens DESC
|
||||
"""
|
||||
|
||||
user_results = await pg_client.execute_query(user_query, date_start, date_end)
|
||||
|
||||
for row in user_results or []:
|
||||
user_tokens = int(row["tokens"])
|
||||
# Use average model cost for user breakdown
|
||||
avg_cost_per_1k = (total_inference_cents / total_tokens * 10) if total_tokens > 0 else 0.2
|
||||
user_cost_cents = calculate_inference_cost_cents(user_tokens, avg_cost_per_1k)
|
||||
|
||||
by_user.append({
|
||||
"user_id": str(row["user_id"]),
|
||||
"email": row["email"],
|
||||
"tokens": user_tokens,
|
||||
"cost_cents": round(user_cost_cents, 2),
|
||||
"cost_display": format_cost_display(user_cost_cents),
|
||||
"percentage": round((user_tokens / total_tokens * 100) if total_tokens > 0 else 0, 1)
|
||||
})
|
||||
|
||||
return {
|
||||
"inference_cost_cents": round(total_inference_cents, 2),
|
||||
"storage_cost_cents": round(storage_cost_cents, 2),
|
||||
"total_cost_cents": round(total_cost_cents, 2),
|
||||
"inference_cost_display": format_cost_display(total_inference_cents),
|
||||
"storage_cost_display": format_cost_display(storage_cost_cents),
|
||||
"total_cost_display": format_cost_display(total_cost_cents),
|
||||
"total_tokens": total_tokens,
|
||||
"total_storage_mb": round(total_storage_mb, 2),
|
||||
"document_count": storage_data.get("document_count", 0),
|
||||
"dataset_count": storage_data.get("dataset_count", 0),
|
||||
"by_model": by_model,
|
||||
"by_user": by_user if include_user_breakdown else None,
|
||||
"period_start": date_start.isoformat(),
|
||||
"period_end": date_end.isoformat()
|
||||
}
|
||||
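A worked example of the cost arithmetic above, using the default storage rate and an assumed model price of $0.13 per 1k tokens (input + output combined); the token and storage figures are illustrative.

    from app.services.optics_service import (
        calculate_inference_cost_cents,
        calculate_storage_cost_cents,
        format_cost_display,
    )

    inference_cents = calculate_inference_cost_cents(250_000, 0.13)  # (250000/1000) * 0.13 * 100 = 3250.0
    storage_cents = calculate_storage_cost_cents(512.0)              # 512 MB * 4.0 cents = 2048.0
    print(format_cost_display(inference_cents + storage_cents))      # "$52.98"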
806
apps/tenant-backend/app/services/pgvector_search_service.py
Normal file
@@ -0,0 +1,806 @@
"""
|
||||
PGVector Hybrid Search Service for GT 2.0 Tenant Backend
|
||||
|
||||
Provides unified vector similarity search + full-text search using PostgreSQL
|
||||
with PGVector extension. Replaces ChromaDB for better performance and consistency.
|
||||
|
||||
Features:
|
||||
- Vector similarity search using PGVector
|
||||
- Full-text search using PostgreSQL built-in features
|
||||
- Hybrid scoring combining both approaches
|
||||
- Perfect tenant isolation using RLS
|
||||
- Zero-downtime MVCC operations
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
import uuid as uuid_lib
|
||||
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
import asyncpg
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text, select, and_, or_
|
||||
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.config import get_settings
|
||||
from app.services.embedding_client import BGE_M3_EmbeddingClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchResult:
|
||||
"""Result from hybrid vector + text search"""
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
dataset_id: Optional[str]
|
||||
text: str
|
||||
metadata: Dict[str, Any]
|
||||
vector_similarity: float
|
||||
text_relevance: float
|
||||
hybrid_score: float
|
||||
rank: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchConfig:
|
||||
"""Configuration for hybrid search behavior"""
|
||||
vector_weight: float = 0.7
|
||||
text_weight: float = 0.3
|
||||
min_vector_similarity: float = 0.3
|
||||
min_text_relevance: float = 0.01
|
||||
max_results: int = 100
|
||||
rerank_results: bool = True
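# Illustrative note on the defaults above (the weighted blend itself is applied in
# the hybrid SQL later in this file, so the exact combination is an assumption):
# a chunk with vector similarity 0.82 and text relevance 0.40 would score roughly
# 0.7 * 0.82 + 0.3 * 0.40 = 0.574 + 0.120 = 0.694.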
|
||||
|
||||
|
||||
class PGVectorSearchService:
|
||||
"""
|
||||
Hybrid search service using PostgreSQL + PGVector.
|
||||
|
||||
GT 2.0 Principles:
|
||||
- Perfect tenant isolation via RLS policies
|
||||
- Zero downtime MVCC operations
|
||||
- Real implementation (no mocks)
|
||||
- Operational elegance through unified storage
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: Optional[str] = None):
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.settings = get_settings()
|
||||
self.embedding_client = BGE_M3_EmbeddingClient()
|
||||
|
||||
# Schema naming for tenant isolation
|
||||
self.schema_name = self.settings.postgres_schema
|
||||
|
||||
logger.info(f"PGVector search service initialized for tenant {tenant_id}")
|
||||
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
dataset_ids: Optional[List[str]] = None,
|
||||
config: Optional[SearchConfig] = None,
|
||||
limit: int = 10
|
||||
) -> List[HybridSearchResult]:
|
||||
"""
|
||||
Perform hybrid vector + text search across user's documents.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
user_id: User performing search (for RLS)
|
||||
dataset_ids: Optional list of dataset IDs to search
|
||||
config: Search configuration parameters
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
List of ranked search results
|
||||
"""
|
||||
if config is None:
|
||||
config = SearchConfig()
|
||||
|
||||
try:
|
||||
logger.info(f"🔍 HYBRID_SEARCH START: query='{query}', user_id='{user_id}', dataset_ids={dataset_ids}")
|
||||
logger.info(f"🔍 HYBRID_SEARCH CONFIG: vector_weight={config.vector_weight}, text_weight={config.text_weight}, min_similarity={config.min_vector_similarity}")
|
||||
|
||||
# Generate query embedding via resource cluster
|
||||
logger.info(f"🔍 HYBRID_SEARCH: Generating embedding for query '{query}' with user_id '{user_id}'")
|
||||
query_embedding = await self._generate_query_embedding(query, user_id)
|
||||
logger.info(f"🔍 HYBRID_SEARCH: Generated embedding with {len(query_embedding)} dimensions")
|
||||
|
||||
# Execute hybrid search query
|
||||
logger.info(f"🔍 HYBRID_SEARCH: Executing hybrid query with user_id='{user_id}', dataset_ids={dataset_ids}")
|
||||
results = await self._execute_hybrid_query(
|
||||
query=query,
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
dataset_ids=dataset_ids,
|
||||
config=config,
|
||||
limit=limit
|
||||
)
|
||||
logger.info(f"🔍 HYBRID_SEARCH: Query returned {len(results)} raw results")
|
||||
|
||||
# Apply re-ranking if enabled
|
||||
if config.rerank_results and len(results) > 1:
|
||||
logger.info(f"🔍 HYBRID_SEARCH: Applying re-ranking to {len(results)} results")
|
||||
results = await self._rerank_results(results, query, config)
|
||||
logger.info(f"🔍 HYBRID_SEARCH: Re-ranking complete, final result count: {len(results)}")
|
||||
|
||||
logger.info(f"🔍 HYBRID_SEARCH COMPLETE: Returned {len(results)} results for user {user_id}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"🔍 HYBRID_SEARCH ERROR: {e}")
|
||||
logger.exception("Full hybrid search error traceback:")
|
||||
raise
|
||||
|
||||
async def vector_similarity_search(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
user_id: str,
|
||||
dataset_ids: Optional[List[str]] = None,
|
||||
similarity_threshold: float = 0.3,
|
||||
limit: int = 10
|
||||
) -> List[HybridSearchResult]:
|
||||
"""
|
||||
Pure vector similarity search using PGVector.
|
||||
|
||||
Args:
|
||||
query_embedding: Pre-computed query embedding
|
||||
user_id: User performing search
|
||||
dataset_ids: Optional dataset filter
|
||||
similarity_threshold: Minimum cosine similarity
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
Vector similarity results
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔍 VECTOR_SEARCH START: user_id='{user_id}', dataset_ids={dataset_ids}, threshold={similarity_threshold}")
|
||||
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Got DB connection, resolving user UUID from '{user_id}'")
|
||||
|
||||
# Resolve user UUID first
|
||||
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Resolved user_id '{user_id}' to UUID '{resolved_user_id}'")
|
||||
|
||||
# RLS context removed - using schema-level isolation instead
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Using resolved UUID '{resolved_user_id}' for query parameters")
|
||||
|
||||
# Build query with dataset filtering
|
||||
dataset_filter = ""
|
||||
params = [query_embedding, similarity_threshold, limit]
|
||||
|
||||
if dataset_ids:
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Adding dataset filter for datasets: {dataset_ids}")
|
||||
dataset_start_idx = 4 # Start after query_embedding, similarity_threshold, limit
|
||||
placeholders = ",".join(f"${dataset_start_idx + i}" for i in range(len(dataset_ids)))
|
||||
dataset_filter = f"AND dataset_id = ANY(ARRAY[{placeholders}]::uuid[])"
|
||||
params.extend(dataset_ids)
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Dataset filter SQL: {dataset_filter}")
|
||||
else:
|
||||
logger.error(f"🔍 VECTOR_SEARCH: SECURITY ERROR - Dataset IDs are required for search operations")
|
||||
raise ValueError("Dataset IDs are required for vector search operations. This could mean the agent has no datasets configured or dataset access control failed.")
|
||||
|
||||
query_sql = f"""
|
||||
SELECT
|
||||
id as chunk_id,
|
||||
document_id,
|
||||
dataset_id,
|
||||
content as text,
|
||||
metadata as chunk_metadata,
|
||||
1 - (embedding <=> $1::vector) as similarity,
|
||||
0.0 as text_relevance,
|
||||
1 - (embedding <=> $1::vector) as hybrid_score,
|
||||
ROW_NUMBER() OVER (ORDER BY embedding <=> $1::vector) as rank
|
||||
FROM {self.schema_name}.document_chunks
|
||||
WHERE 1 - (embedding <=> $1::vector) >= $2
|
||||
{dataset_filter}
|
||||
ORDER BY embedding <=> $1::vector
|
||||
LIMIT $3
|
||||
"""
|
||||
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Executing SQL query with {len(params)} parameters")
|
||||
logger.info(f"🔍 VECTOR_SEARCH: SQL: {query_sql}")
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Params types: embedding={type(query_embedding)} (len={len(query_embedding)}), threshold={type(similarity_threshold)}, limit={type(limit)}")
|
||||
if dataset_ids:
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Dataset params: {[type(d) for d in dataset_ids]}")
|
||||
|
||||
rows = await conn.fetch(query_sql, *params)
|
||||
logger.info(f"🔍 VECTOR_SEARCH: Query executed successfully, got {len(rows)} rows")
|
||||
|
||||
results = [
|
||||
HybridSearchResult(
|
||||
chunk_id=row['chunk_id'],
|
||||
document_id=row['document_id'],
|
||||
dataset_id=row['dataset_id'],
|
||||
text=row['text'],
|
||||
metadata=row['chunk_metadata'] if row['chunk_metadata'] else {},
|
||||
vector_similarity=float(row['similarity']),
|
||||
text_relevance=0.0,
|
||||
hybrid_score=float(row['hybrid_score']),
|
||||
rank=row['rank']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
logger.info(f"Vector search returned {len(results)} results")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector similarity search failed: {e}")
|
||||
raise
|
||||
|
||||
async def full_text_search(
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
dataset_ids: Optional[List[str]] = None,
|
||||
language: str = 'english',
|
||||
limit: int = 10
|
||||
) -> List[HybridSearchResult]:
|
||||
"""
|
||||
Full-text search using PostgreSQL's built-in features.
|
||||
|
||||
Args:
|
||||
query: Text query
|
||||
user_id: User performing search
|
||||
dataset_ids: Optional dataset filter
|
||||
language: Text search language configuration
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
Text relevance results
|
||||
"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
# Resolve user UUID first
|
||||
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
|
||||
# RLS context removed - using schema-level isolation instead
|
||||
|
||||
# Build dataset filter - REQUIRE dataset_ids for security
|
||||
dataset_filter = ""
|
||||
params = [query, limit, resolved_user_id]
|
||||
if dataset_ids:
|
||||
placeholders = ",".join(f"${i+4}" for i in range(len(dataset_ids)))
|
||||
dataset_filter = f"AND dataset_id = ANY(ARRAY[{placeholders}]::uuid[])"
|
||||
params.extend(dataset_ids)
|
||||
else:
|
||||
logger.error(f"🔍 FULL_TEXT_SEARCH: SECURITY ERROR - Dataset IDs are required for search operations")
|
||||
raise ValueError("Dataset IDs are required for full-text search operations. This could mean the agent has no datasets configured or dataset access control failed.")
|
||||
|
||||
query_sql = f"""
|
||||
SELECT
|
||||
id as chunk_id,
|
||||
document_id,
|
||||
dataset_id,
|
||||
content as text,
|
||||
metadata,
|
||||
0.0 as similarity,
|
||||
ts_rank_cd(
|
||||
to_tsvector('{language}', content),
|
||||
plainto_tsquery('{language}', $1)
|
||||
) as text_relevance,
|
||||
ts_rank_cd(
|
||||
to_tsvector('{language}', content),
|
||||
plainto_tsquery('{language}', $1)
|
||||
) as hybrid_score,
|
||||
ROW_NUMBER() OVER (
|
||||
ORDER BY ts_rank_cd(
|
||||
to_tsvector('{language}', content),
|
||||
plainto_tsquery('{language}', $1)
|
||||
) DESC
|
||||
) as rank
|
||||
FROM {self.schema_name}.document_chunks
|
||||
WHERE user_id = $3::uuid
|
||||
AND to_tsvector('{language}', content) @@ plainto_tsquery('{language}', $1)
|
||||
{dataset_filter}
|
||||
ORDER BY text_relevance DESC
|
||||
LIMIT $2
|
||||
"""
|
||||
|
||||
rows = await conn.fetch(query_sql, *params)
|
||||
|
||||
results = [
|
||||
HybridSearchResult(
|
||||
chunk_id=row['chunk_id'],
|
||||
document_id=row['document_id'],
|
||||
dataset_id=row['dataset_id'],
|
||||
text=row['text'],
|
||||
metadata=row['metadata'] if row['metadata'] else {},
|
||||
vector_similarity=0.0,
|
||||
text_relevance=float(row['text_relevance']),
|
||||
hybrid_score=float(row['hybrid_score']),
|
||||
rank=row['rank']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
logger.info(f"Full-text search returned {len(results)} results")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Full-text search failed: {e}")
|
||||
raise
|
||||
|
||||
async def get_document_chunks(
|
||||
self,
|
||||
document_id: str,
|
||||
user_id: str,
|
||||
include_embeddings: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all chunks for a specific document.
|
||||
|
||||
Args:
|
||||
document_id: Target document ID
|
||||
user_id: User making request
|
||||
include_embeddings: Whether to include embedding vectors
|
||||
|
||||
Returns:
|
||||
List of document chunks with metadata
|
||||
"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
# Resolve user UUID first
|
||||
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
|
||||
# RLS context removed - using schema-level isolation instead
|
||||
|
||||
select_fields = [
|
||||
"id as chunk_id", "document_id", "dataset_id", "chunk_index",
|
||||
"content", "metadata as chunk_metadata", "created_at"
|
||||
]
|
||||
|
||||
if include_embeddings:
|
||||
select_fields.append("embedding")
|
||||
|
||||
query_sql = f"""
|
||||
SELECT {', '.join(select_fields)}
|
||||
FROM {self.schema_name}.document_chunks
|
||||
WHERE document_id = $1
|
||||
AND user_id = $2::uuid
|
||||
ORDER BY chunk_index
|
||||
"""
|
||||
|
||||
rows = await conn.fetch(query_sql, document_id, resolved_user_id)
|
||||
|
||||
chunks = []
|
||||
for row in rows:
|
||||
chunk = {
|
||||
'chunk_id': row['chunk_id'],
|
||||
'document_id': row['document_id'],
|
||||
'dataset_id': row['dataset_id'],
|
||||
'chunk_index': row['chunk_index'],
|
||||
'content': row['content'],
|
||||
'metadata': row['chunk_metadata'] if row['chunk_metadata'] else {},
|
||||
'created_at': row['created_at'].isoformat() if row['created_at'] else None
|
||||
}
|
||||
|
||||
if include_embeddings:
|
||||
chunk['embedding'] = list(row['embedding']) if row['embedding'] else []
|
||||
|
||||
chunks.append(chunk)
|
||||
|
||||
logger.info(f"Retrieved {len(chunks)} chunks for document {document_id}")
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get document chunks: {e}")
|
||||
raise
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
chunk_id: str,
|
||||
user_id: str,
|
||||
similarity_threshold: float = 0.5,
|
||||
limit: int = 5,
|
||||
exclude_same_document: bool = True
|
||||
) -> List[HybridSearchResult]:
|
||||
"""
|
||||
Find chunks similar to a given chunk.
|
||||
|
||||
Args:
|
||||
chunk_id: Reference chunk ID
|
||||
user_id: User making request
|
||||
similarity_threshold: Minimum similarity threshold
|
||||
limit: Maximum results
|
||||
exclude_same_document: Whether to exclude chunks from same document
|
||||
|
||||
Returns:
|
||||
Similar chunks ranked by similarity
|
||||
"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
# Resolve user UUID first
|
||||
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
|
||||
# RLS context removed - using schema-level isolation instead
|
||||
|
||||
# Get reference chunk embedding
|
||||
ref_query = f"""
|
||||
SELECT embedding, document_id
|
||||
FROM {self.schema_name}.document_chunks
|
||||
WHERE id = $1::uuid
|
||||
AND user_id = $2::uuid
|
||||
"""
|
||||
|
||||
ref_result = await conn.fetchrow(ref_query, chunk_id, resolved_user_id)
|
||||
if not ref_result:
|
||||
raise ValueError(f"Reference chunk {chunk_id} not found")
|
||||
|
||||
ref_embedding = ref_result['embedding']
|
||||
ref_document_id = ref_result['document_id']
|
||||
|
||||
# Build exclusion filter
|
||||
exclusion_filter = ""
|
||||
params = [ref_embedding, similarity_threshold, limit, chunk_id, resolved_user_id]
|
||||
if exclude_same_document:
|
||||
exclusion_filter = "AND document_id != $6"
|
||||
params.append(ref_document_id)
|
||||
|
||||
# Search for similar chunks
|
||||
similarity_query = f"""
|
||||
SELECT
|
||||
id as chunk_id,
|
||||
document_id,
|
||||
dataset_id,
|
||||
content as text,
|
||||
metadata as chunk_metadata,
|
||||
1 - (embedding <=> $1::vector) as similarity
|
||||
FROM {self.schema_name}.document_chunks
|
||||
WHERE user_id = $5::uuid
|
||||
AND id != $4::uuid
|
||||
AND 1 - (embedding <=> $1::vector) >= $2
|
||||
{exclusion_filter}
|
||||
ORDER BY embedding <=> $1::vector
|
||||
LIMIT $3
|
||||
"""
|
||||
|
||||
rows = await conn.fetch(similarity_query, *params)
|
||||
|
||||
results = [
|
||||
HybridSearchResult(
|
||||
chunk_id=row['chunk_id'],
|
||||
document_id=row['document_id'],
|
||||
dataset_id=row['dataset_id'],
|
||||
text=row['text'],
|
||||
metadata=row['chunk_metadata'] if row['chunk_metadata'] else {},
|
||||
vector_similarity=float(row['similarity']),
|
||||
text_relevance=0.0,
|
||||
hybrid_score=float(row['similarity']),
|
||||
rank=i+1
|
||||
)
|
||||
for i, row in enumerate(rows)
|
||||
]
|
||||
|
||||
logger.info(f"Found {len(results)} similar chunks to {chunk_id}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Similar chunk search failed: {e}")
|
||||
raise
|
||||
|
||||
# Private helper methods
|
||||
|
||||
async def get_dataset_ids_from_documents(
|
||||
self,
|
||||
document_ids: List[str],
|
||||
user_id: str
|
||||
) -> List[str]:
|
||||
"""Get unique dataset IDs from a list of document IDs"""
|
||||
try:
|
||||
dataset_ids = []

client = await get_postgresql_client()
async with client.get_connection() as conn:
# Resolve user UUID first (same helper the other methods use)
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
# RLS context removed - using schema-level isolation instead
|
||||
|
||||
# Query to get dataset IDs from document IDs
|
||||
placeholders = ",".join(f"${i+1}" for i in range(len(document_ids)))
|
||||
query = f"""
|
||||
SELECT DISTINCT dataset_id
|
||||
FROM {self.schema_name}.documents
|
||||
WHERE id = ANY(ARRAY[{placeholders}]::uuid[])
|
||||
AND user_id = ${len(document_ids)+1}::uuid
|
||||
"""
|
||||
|
||||
params = document_ids + [resolved_user_id]
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
dataset_ids = [str(row['dataset_id']) for row in rows if row['dataset_id']]
|
||||
logger.info(f"🔍 Resolved {len(dataset_ids)} dataset IDs from {len(document_ids)} documents: {dataset_ids}")
|
||||
|
||||
return dataset_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resolve dataset IDs from documents: {e}")
|
||||
return []
|
||||
|
||||
async def _generate_query_embedding(
|
||||
self,
|
||||
query: str,
|
||||
user_id: str
|
||||
) -> List[float]:
|
||||
"""Generate embedding for search query using simple BGE-M3 client"""
|
||||
try:
|
||||
# Use direct BGE-M3 embedding client with tenant/user for billing
|
||||
embeddings = await self.embedding_client.generate_embeddings(
|
||||
[query],
|
||||
tenant_id=self.tenant_id, # Pass tenant for billing
|
||||
user_id=user_id # Pass user for billing
|
||||
)
|
||||
|
||||
if not embeddings or not embeddings[0]:
|
||||
raise ValueError("Failed to generate query embedding")
|
||||
|
||||
return embeddings[0]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Query embedding generation failed: {e}")
|
||||
raise
|
||||
|
||||
async def _execute_hybrid_query(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: List[float],
|
||||
user_id: str,
|
||||
dataset_ids: Optional[List[str]],
|
||||
config: SearchConfig,
|
||||
limit: int
|
||||
) -> List[HybridSearchResult]:
|
||||
"""Execute the hybrid search combining vector + text results"""
|
||||
try:
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY START: query='{query}', user_id='{user_id}', dataset_ids={dataset_ids}")
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY CONFIG: vector_weight={config.vector_weight}, text_weight={config.text_weight}, limit={limit}")
|
||||
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Got DB connection, resolving user UUID")
|
||||
|
||||
# Resolve user UUID first
|
||||
actual_user_id = await self._resolve_user_uuid(conn, user_id)
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Resolved user_id to '{actual_user_id}'")
|
||||
|
||||
# RLS context removed - using schema-level isolation instead
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Using resolved UUID '{actual_user_id}' for query parameters")
|
||||
|
||||
# Build dataset filter
|
||||
dataset_filter = ""
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Building parameters and dataset filter")
|
||||
|
||||
# Convert embedding list to string format for PostgreSQL vector type
|
||||
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Converted embedding to PostgreSQL vector string (length: {len(embedding_str)})")
|
||||
|
||||
# Ensure UUID is properly formatted as string for PostgreSQL
|
||||
if isinstance(actual_user_id, str):
|
||||
try:
|
||||
# Validate it's a proper UUID and convert back to string
|
||||
validated_uuid = str(uuid_lib.UUID(actual_user_id))
|
||||
actual_user_id_str = validated_uuid
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Validated UUID format: '{actual_user_id_str}'")
|
||||
except ValueError:
|
||||
# If it's not a valid UUID string, keep as is
|
||||
actual_user_id_str = actual_user_id
|
||||
logger.warning(f"🔍 _EXECUTE_HYBRID_QUERY: UUID validation failed, using as-is: '{actual_user_id_str}'")
|
||||
else:
|
||||
actual_user_id_str = str(actual_user_id)
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Converted user_id to string: '{actual_user_id_str}'")
|
||||
|
||||
params = [embedding_str, query, config.min_vector_similarity, config.min_text_relevance, config.max_results]
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Base parameters prepared (count: {len(params)})")
|
||||
|
||||
# Handle dataset filtering - REQUIRE dataset_ids for security
|
||||
if dataset_ids:
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Processing dataset filter for: {dataset_ids}")
|
||||
# Ensure dataset_ids is a list
|
||||
if isinstance(dataset_ids, str):
|
||||
dataset_ids = [dataset_ids]
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Converted string to list: {dataset_ids}")
|
||||
|
||||
if len(dataset_ids) > 0:
|
||||
# Generate proper placeholders for dataset IDs
|
||||
placeholders = ",".join(f"${i+6}" for i in range(len(dataset_ids)))
|
||||
dataset_filter = f"AND dataset_id = ANY(ARRAY[{placeholders}]::uuid[])"
|
||||
params.extend(dataset_ids)
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Dataset filter: {dataset_filter}, dataset_ids: {dataset_ids}")
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Total parameters after dataset filter: {len(params)}")
|
||||
else:
|
||||
logger.error(f"🔍 _EXECUTE_HYBRID_QUERY: SECURITY ERROR - Empty dataset_ids list not permitted")
|
||||
raise ValueError("Dataset IDs cannot be empty. This could mean the agent has no datasets configured or dataset access control failed.")
|
||||
else:
|
||||
# SECURITY FIX: No dataset filter when None is NOT ALLOWED
|
||||
logger.error(f"🔍 _EXECUTE_HYBRID_QUERY: SECURITY ERROR - Dataset IDs are required for search operations")
|
||||
|
||||
# More informative error message for debugging
|
||||
error_msg = "Dataset IDs are required for hybrid search operations. This could mean: " \
|
||||
"1) Agent has no datasets configured, 2) No datasets selected in UI, or " \
|
||||
"3) Dataset access control failed. Check agent configuration and dataset permissions."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Hybrid search query combining vector similarity and text relevance
|
||||
hybrid_query = f"""
|
||||
WITH vector_search AS (
|
||||
SELECT
|
||||
id as chunk_id,
|
||||
document_id,
|
||||
dataset_id,
|
||||
content,
|
||||
metadata as chunk_metadata,
|
||||
1 - (embedding <=> $1::vector) as vector_similarity,
|
||||
0.0 as text_relevance
|
||||
FROM {self.schema_name}.document_chunks
|
||||
WHERE 1 - (embedding <=> $1::vector) >= $3
|
||||
{dataset_filter}
|
||||
),
|
||||
text_search AS (
|
||||
SELECT
|
||||
id as chunk_id,
|
||||
document_id,
|
||||
dataset_id,
|
||||
content,
|
||||
metadata as chunk_metadata,
|
||||
0.0 as vector_similarity,
|
||||
ts_rank_cd(
|
||||
to_tsvector('english', content),
|
||||
plainto_tsquery('english', $2)
|
||||
) as text_relevance
|
||||
FROM {self.schema_name}.document_chunks
|
||||
WHERE to_tsvector('english', content) @@ plainto_tsquery('english', $2)
|
||||
AND ts_rank_cd(
|
||||
to_tsvector('english', content),
|
||||
plainto_tsquery('english', $2)
|
||||
) >= $4
|
||||
{dataset_filter}
|
||||
),
|
||||
combined_results AS (
|
||||
SELECT
|
||||
u.chunk_id,
|
||||
dc.document_id,
|
||||
dc.dataset_id,
|
||||
dc.content,
|
||||
dc.metadata as chunk_metadata,
|
||||
COALESCE(v.vector_similarity, 0.0) as vector_similarity,
|
||||
COALESCE(t.text_relevance, 0.0) as text_relevance,
|
||||
(COALESCE(v.vector_similarity, 0.0) * {config.vector_weight} +
|
||||
COALESCE(t.text_relevance, 0.0) * {config.text_weight}) as hybrid_score
|
||||
FROM (
|
||||
SELECT chunk_id FROM vector_search
|
||||
UNION
|
||||
SELECT chunk_id FROM text_search
|
||||
) u
|
||||
LEFT JOIN vector_search v USING (chunk_id)
|
||||
LEFT JOIN text_search t USING (chunk_id)
|
||||
LEFT JOIN {self.schema_name}.document_chunks dc ON (dc.id = u.chunk_id)
|
||||
)
|
||||
SELECT
|
||||
chunk_id,
|
||||
document_id,
|
||||
dataset_id,
|
||||
content as text,
|
||||
chunk_metadata as metadata,
|
||||
vector_similarity,
|
||||
text_relevance,
|
||||
hybrid_score,
|
||||
ROW_NUMBER() OVER (ORDER BY hybrid_score DESC) as rank
|
||||
FROM combined_results
|
||||
WHERE hybrid_score > 0.0
|
||||
ORDER BY hybrid_score DESC
|
||||
LIMIT $5
|
||||
"""
|
||||
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Executing hybrid SQL with {len(params)} parameters")
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Parameter types: {[type(p) for p in params]}")
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Query preview: {hybrid_query[:500]}...")
|
||||
|
||||
rows = await conn.fetch(hybrid_query, *params)
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: SQL execution successful, got {len(rows)} rows")
|
||||
|
||||
results = []
|
||||
for i, row in enumerate(rows):
|
||||
result = HybridSearchResult(
|
||||
chunk_id=row['chunk_id'],
|
||||
document_id=row['document_id'],
|
||||
dataset_id=row['dataset_id'],
|
||||
text=row['text'],
|
||||
metadata=row['metadata'] if row['metadata'] else {},
|
||||
vector_similarity=float(row['vector_similarity']),
|
||||
text_relevance=float(row['text_relevance']),
|
||||
hybrid_score=float(row['hybrid_score']),
|
||||
rank=row['rank']
|
||||
)
|
||||
results.append(result)
|
||||
if i < 3: # Log first few results for debugging
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Result {i+1}: chunk_id='{result.chunk_id}', score={result.hybrid_score:.3f}")
|
||||
|
||||
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY COMPLETE: Processed {len(results)} results")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"🔍 _EXECUTE_HYBRID_QUERY ERROR: {e}")
|
||||
logger.exception("Full hybrid query execution error traceback:")
|
||||
raise
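# Illustrative sketch (not part of the original service): the hybrid_score
# column computed in the SQL above is a weighted sum of the two retrieval
# signals. The helper below mirrors that arithmetic in Python; the 0.7/0.3
# defaults are assumed values for illustration, not the service's actual config.
def _example_hybrid_score(
    vector_similarity: float,
    text_relevance: float,
    vector_weight: float = 0.7,
    text_weight: float = 0.3,
) -> float:
    """Hypothetical helper mirroring the SQL weighting of vector + text scores."""
    return vector_similarity * vector_weight + text_relevance * text_weight
# Example: _example_hybrid_score(0.82, 0.10) = 0.82 * 0.7 + 0.10 * 0.3 ≈ 0.604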
|
||||
|
||||
async def _rerank_results(
|
||||
self,
|
||||
results: List[HybridSearchResult],
|
||||
query: str,
|
||||
config: SearchConfig
|
||||
) -> List[HybridSearchResult]:
|
||||
"""
|
||||
Apply advanced re-ranking to search results.
|
||||
|
||||
This can include:
|
||||
- Query-document interaction features
|
||||
- Diversity scoring
|
||||
- Recency weighting
|
||||
- User preference learning
|
||||
"""
|
||||
try:
|
||||
# For now, apply simple diversity re-ranking
|
||||
# to avoid showing too many results from the same document
|
||||
|
||||
reranked = []
|
||||
document_counts = {}
|
||||
max_per_document = max(1, len(results) // 3) # At most 1/3 from same document
|
||||
|
||||
for result in results:
|
||||
doc_count = document_counts.get(result.document_id, 0)
|
||||
if doc_count < max_per_document:
|
||||
reranked.append(result)
|
||||
document_counts[result.document_id] = doc_count + 1
|
||||
|
||||
# Append the items that exceeded the per-document cap, preserving their score order
|
||||
remaining = [r for r in results if r not in reranked]
|
||||
reranked.extend(remaining)
|
||||
|
||||
# Update rank numbers
|
||||
for i, result in enumerate(reranked):
|
||||
result.rank = i + 1
|
||||
|
||||
return reranked
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Re-ranking failed, returning original results: {e}")
|
||||
return results
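# Worked example (illustrative, assumed numbers): with 9 results the cap is
# max(1, 9 // 3) == 3 chunks per document. If 5 of the 9 top-scored chunks come
# from the same document, only the first 3 survive the first pass; the other 2
# are appended after the diversity-limited set, and ranks are then renumbered.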
|
||||
|
||||
async def _resolve_user_uuid(self, conn: asyncpg.Connection, user_id: str) -> str:
|
||||
"""
|
||||
Resolve user email to UUID if needed.
|
||||
Returns a validated UUID string.
|
||||
"""
|
||||
logger.info(f"🔍 _RESOLVE_USER_UUID START: input user_id='{user_id}' (type: {type(user_id)})")
|
||||
|
||||
if "@" in user_id: # If user_id is an email, look up the UUID
|
||||
logger.info(f"🔍 _RESOLVE_USER_UUID: Detected email format, looking up UUID for '{user_id}'")
|
||||
user_lookup_sql = f"SELECT id FROM {self.schema_name}.users WHERE email = $1"
|
||||
logger.info(f"🔍 _RESOLVE_USER_UUID: Executing SQL: {user_lookup_sql}")
|
||||
user_row = await conn.fetchrow(user_lookup_sql, user_id)
|
||||
if user_row:
|
||||
resolved_uuid = str(user_row['id'])
|
||||
logger.info(f"🔍 _RESOLVE_USER_UUID: Found UUID '{resolved_uuid}' for email '{user_id}'")
|
||||
return resolved_uuid
|
||||
else:
|
||||
logger.error(f"🔍 _RESOLVE_USER_UUID ERROR: User not found for email: {user_id}")
|
||||
raise ValueError(f"User not found: {user_id}")
|
||||
else:
|
||||
# Already a UUID
|
||||
logger.info(f"🔍 _RESOLVE_USER_UUID: Input '{user_id}' is already UUID format, returning as-is")
|
||||
return user_id
|
||||
|
||||
# _set_rls_context method removed - using schema-level isolation instead of RLS
|
||||
|
||||
|
||||
# Factory function for dependency injection
|
||||
def get_pgvector_search_service(tenant_id: str, user_id: Optional[str] = None) -> PGVectorSearchService:
|
||||
"""Get PGVector search service instance"""
|
||||
return PGVectorSearchService(tenant_id=tenant_id, user_id=user_id)
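# Illustrative usage sketch (hypothetical caller; the public search method name
# below is an assumption, not taken from this file):
#
#   service = get_pgvector_search_service(tenant_id="acme", user_id=str(current_user.id))
#   results = await service.hybrid_search(query="quarterly churn drivers",
#                                          dataset_ids=[str(dataset_id)], limit=10)
#   for r in results[:3]:
#       print(r.rank, r.hybrid_score, r.chunk_id)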
|
||||
883
apps/tenant-backend/app/services/postgresql_file_service.py
Normal file
883
apps/tenant-backend/app/services/postgresql_file_service.py
Normal file
@@ -0,0 +1,883 @@
|
||||
"""
|
||||
GT 2.0 PostgreSQL File Storage Service
|
||||
|
||||
Replaces MinIO with PostgreSQL-based file storage using:
|
||||
- BYTEA for small files (<10MB)
|
||||
- PostgreSQL Large Objects (LOBs) for large files (10MB-1GB)
|
||||
- Filesystem with metadata for massive files (>1GB)
|
||||
|
||||
Provides perfect tenant isolation through PostgreSQL schemas.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import mimetypes
|
||||
from typing import Dict, Any, List, Optional, BinaryIO, AsyncIterator, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
import aiofiles
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.config import get_settings
|
||||
from app.core.permissions import ADMIN_ROLES
|
||||
from app.core.path_security import sanitize_tenant_domain, sanitize_filename, safe_join_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PostgreSQLFileService:
|
||||
"""PostgreSQL-based file storage service with tenant isolation"""
|
||||
|
||||
# Storage type thresholds
|
||||
SMALL_FILE_THRESHOLD = 10 * 1024 * 1024 # 10MB - use BYTEA
|
||||
LARGE_FILE_THRESHOLD = 1024 * 1024 * 1024 # 1GB - use LOBs
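# Illustrative sketch (not part of the original class): how a file size maps to
# the storage tiers described in the module docstring. The tier labels here are
# descriptive only; store_file() below also branches on content type, not just size.
@staticmethod
def _example_storage_tier(file_size_bytes: int) -> str:
    """Hypothetical helper mapping a file size to a storage strategy."""
    if file_size_bytes <= PostgreSQLFileService.SMALL_FILE_THRESHOLD:
        return "bytea"       # <= 10MB: inline column storage
    if file_size_bytes <= PostgreSQLFileService.LARGE_FILE_THRESHOLD:
        return "lob"         # 10MB - 1GB: PostgreSQL Large Objects
    return "filesystem"      # > 1GB: filesystem path + metadata row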
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str, user_role: str = "user"):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.user_role = user_role
|
||||
self.settings = get_settings()
|
||||
|
||||
# Filesystem path for massive files (>1GB)
|
||||
# Sanitize tenant_domain to prevent path traversal
|
||||
safe_tenant = sanitize_tenant_domain(tenant_domain)
|
||||
self.filesystem_base = Path("/data") / safe_tenant / "files" # codeql[py/path-injection] sanitize_tenant_domain() validates input
|
||||
self.filesystem_base.mkdir(parents=True, exist_ok=True, mode=0o700)
|
||||
|
||||
logger.info(f"PostgreSQL file service initialized for {tenant_domain}/{user_id} (role: {user_role})")
|
||||
|
||||
async def store_file(
|
||||
self,
|
||||
file: UploadFile,
|
||||
dataset_id: Optional[str] = None,
|
||||
category: str = "documents"
|
||||
) -> Dict[str, Any]:
|
||||
"""Store file using appropriate PostgreSQL strategy"""
|
||||
|
||||
try:
|
||||
logger.info(f"PostgreSQL file service: storing file {file.filename} for tenant {self.tenant_domain}, user {self.user_id}")
|
||||
logger.info(f"Dataset ID: {dataset_id}, Category: {category}")
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
# Generate file metadata
|
||||
file_hash = hashlib.sha256(content).hexdigest()[:16]
|
||||
content_type = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
|
||||
|
||||
# Handle different file types with appropriate processing
|
||||
if file_size <= self.SMALL_FILE_THRESHOLD and content_type.startswith('text/'):
|
||||
# Small text files stored directly
|
||||
storage_type = "text"
|
||||
storage_ref = "content_text"
|
||||
try:
|
||||
text_content = content.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
text_content = content.decode('latin-1') # Fallback encoding
|
||||
elif content_type == 'application/pdf':
|
||||
# PDF files: extract text content, store binary separately
|
||||
storage_type = "pdf_extracted"
|
||||
storage_ref = "content_text"
|
||||
text_content = await self._extract_pdf_text(content)
|
||||
else:
|
||||
# Other binary files: store as base64 for now
|
||||
import base64
|
||||
storage_type = "base64"
|
||||
storage_ref = "content_text"
|
||||
text_content = base64.b64encode(content).decode('utf-8')
|
||||
|
||||
# Get PostgreSQL client
|
||||
logger.info("Getting PostgreSQL client")
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Always expect user_id to be a UUID string - no email lookups
|
||||
logger.info(f"Using user UUID: {self.user_id}")
|
||||
|
||||
# Validate user_id is a valid UUID format
|
||||
try:
|
||||
import uuid
|
||||
user_uuid = str(uuid.UUID(self.user_id))
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
|
||||
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
|
||||
|
||||
logger.info(f"Validated user UUID: {user_uuid}")
|
||||
|
||||
# 1. Validate user_uuid is present
|
||||
if not user_uuid:
|
||||
raise ValueError("User UUID is required but not found")
|
||||
|
||||
# 2. Validate and clean dataset_id
|
||||
dataset_uuid_param = None
|
||||
if dataset_id and dataset_id.strip():  # empty or whitespace-only dataset_id is treated as absent
|
||||
try:
|
||||
import uuid
|
||||
dataset_uuid_param = str(uuid.UUID(dataset_id.strip()))
|
||||
logger.info(f"Dataset UUID validated: {dataset_uuid_param}")
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid dataset UUID: {dataset_id}, error: {e}")
|
||||
raise ValueError(f"Invalid dataset ID format: {dataset_id}")
|
||||
else:
|
||||
logger.info("No dataset_id provided, using NULL")
|
||||
|
||||
# 3. Validate file content and metadata
|
||||
if not file.filename or not file.filename.strip():
|
||||
raise ValueError("Filename cannot be empty")
|
||||
|
||||
if not content:
|
||||
raise ValueError("File content cannot be empty")
|
||||
|
||||
# 4. Generate and validate all string parameters
|
||||
safe_filename = f"{file_hash}_{file.filename}"
|
||||
safe_original_filename = file.filename.strip()
|
||||
safe_content_type = content_type or "application/octet-stream"
|
||||
safe_file_hash = file_hash
|
||||
safe_metadata = json.dumps({
|
||||
"storage_type": storage_type,
|
||||
"storage_ref": storage_ref,
|
||||
"category": category
|
||||
})
|
||||
|
||||
logger.info(f"All parameters validated:")
|
||||
logger.info(f" user_uuid: {user_uuid}")
|
||||
logger.info(f" dataset_uuid: {dataset_uuid_param}")
|
||||
logger.info(f" filename: {safe_filename}")
|
||||
logger.info(f" original_filename: {safe_original_filename}")
|
||||
logger.info(f" file_type: {safe_content_type}")
|
||||
logger.info(f" file_size: {file_size}")
|
||||
logger.info(f" file_hash: {safe_file_hash}")
|
||||
|
||||
# Store metadata in documents table (using existing schema)
|
||||
try:
|
||||
# Application user now has BYPASSRLS privilege - no RLS context needed
|
||||
logger.info("Storing document with BYPASSRLS privilege")
|
||||
|
||||
# Require dataset_id for all document uploads
|
||||
if not dataset_uuid_param:
|
||||
raise ValueError("dataset_id is required for document uploads")
|
||||
|
||||
logger.info(f"Storing document with dataset_id: {dataset_uuid_param}")
|
||||
logger.info(f"Document details: {safe_filename} ({file_size} bytes)")
|
||||
|
||||
# Insert with dataset_id
|
||||
# Determine if content is searchable (under PostgreSQL tsvector size limit)
|
||||
is_searchable = text_content is None or len(text_content) < 1048575
|
||||
|
||||
async with pg_client.get_connection() as conn:
|
||||
# Get tenant_id for the document
|
||||
tenant_id = await conn.fetchval("""
|
||||
SELECT id FROM tenants WHERE domain = $1 LIMIT 1
|
||||
""", self.tenant_domain)
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError(f"Tenant not found for domain: {self.tenant_domain}")
|
||||
|
||||
document_id = await conn.fetchval("""
|
||||
INSERT INTO documents (
|
||||
tenant_id, user_id, dataset_id, filename, original_filename,
|
||||
file_type, file_size_bytes, file_hash, content_text, processing_status,
|
||||
metadata, is_searchable, created_at, updated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, 'pending', $10, $11, NOW(), NOW()
|
||||
)
|
||||
RETURNING id
|
||||
""",
|
||||
tenant_id, user_uuid, dataset_uuid_param, safe_filename, safe_original_filename,
|
||||
safe_content_type, file_size, safe_file_hash, text_content,
|
||||
safe_metadata, is_searchable
|
||||
)
|
||||
logger.info(f"Document inserted successfully with ID: {document_id}")
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"Database insertion failed: {db_error}")
|
||||
logger.error(f"Tenant domain: {self.tenant_domain}")
|
||||
logger.error(f"User ID: {self.user_id}")
|
||||
logger.error(f"Dataset ID: {dataset_id}")
|
||||
raise
|
||||
|
||||
result = {
|
||||
"id": document_id,
|
||||
"filename": file.filename,
|
||||
"content_type": content_type,
|
||||
"file_size": file_size,
|
||||
"file_hash": file_hash,
|
||||
"storage_type": storage_type,
|
||||
"storage_ref": storage_ref,
|
||||
"upload_timestamp": datetime.utcnow().isoformat(),
|
||||
"download_url": f"/api/v1/files/{document_id}"
|
||||
}
|
||||
|
||||
logger.info(f"Stored file {file.filename} ({file_size} bytes) as {storage_type} for user {self.user_id}")
|
||||
|
||||
# Trigger document processing pipeline for RAG functionality
|
||||
try:
|
||||
await self._trigger_document_processing(document_id, dataset_id, user_uuid, file.filename)
|
||||
logger.info(f"Successfully triggered document processing for {document_id}")
|
||||
except Exception as process_error:
|
||||
logger.error(f"Failed to trigger document processing for {document_id}: {process_error}")
|
||||
# Update document status to show processing failed
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"UPDATE documents SET processing_status = 'failed', error_message = $1 WHERE id = $2",
|
||||
f"Processing failed: {str(process_error)}", document_id
|
||||
)
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update document status after processing error: {update_error}")
|
||||
# Don't fail the upload if processing trigger fails - user can retry manually
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store file {file.filename}: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Ensure content is cleared from memory
|
||||
if 'content' in locals():
|
||||
del content
|
||||
|
||||
async def _store_as_bytea(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
file_hash: str,
|
||||
dataset_id: Optional[str],
|
||||
category: str
|
||||
) -> str:
|
||||
"""Store small file as BYTEA in documents table"""
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Store file content directly in BYTEA column
|
||||
# This will be handled by the main insert in store_file
|
||||
return "bytea_column"
|
||||
|
||||
async def _store_as_lob(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
file_hash: str,
|
||||
dataset_id: Optional[str],
|
||||
category: str
|
||||
) -> str:
|
||||
"""Store large file as PostgreSQL Large Object"""
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Create Large Object and get OID
|
||||
async with pg_client.get_connection() as conn:
|
||||
# Start transaction for LOB operations
|
||||
async with conn.transaction():
|
||||
# Create LOB and get OID
|
||||
lob_oid = await conn.fetchval("SELECT lo_create(0)")
|
||||
|
||||
# Open LOB for writing
|
||||
lob_fd = await conn.fetchval("SELECT lo_open($1, 131072)", lob_oid) # INV_WRITE mode
|
||||
|
||||
# Write content in chunks for memory efficiency
|
||||
chunk_size = 8192
|
||||
offset = 0
|
||||
for i in range(0, len(content), chunk_size):
|
||||
chunk = content[i:i + chunk_size]
|
||||
await conn.execute("SELECT lo_write($1, $2)", lob_fd, chunk)
|
||||
offset += len(chunk)
|
||||
|
||||
# Close LOB
|
||||
await conn.execute("SELECT lo_close($1)", lob_fd)
|
||||
|
||||
logger.info(f"Created PostgreSQL LOB with OID {lob_oid} for {filename}")
|
||||
return str(lob_oid)
|
||||
|
||||
async def _store_as_filesystem(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
file_hash: str,
|
||||
dataset_id: Optional[str],
|
||||
category: str
|
||||
) -> str:
|
||||
"""Store massive file on filesystem with PostgreSQL metadata"""
|
||||
|
||||
# Create secure file path with user isolation
|
||||
user_dir = self.filesystem_base / self.user_id / category
|
||||
if dataset_id:
|
||||
user_dir = user_dir / dataset_id
|
||||
|
||||
user_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
|
||||
|
||||
# Generate secure filename
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
secure_filename = f"{timestamp}_{file_hash}_{filename}"
|
||||
file_path = user_dir / secure_filename
|
||||
|
||||
# Write file with secure permissions
|
||||
async with aiofiles.open(file_path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
# Set secure file permissions
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
logger.info(f"Stored large file on filesystem: {file_path}")
|
||||
return str(file_path)
|
||||
|
||||
async def get_file(self, document_id: str) -> AsyncIterator[bytes]:
|
||||
"""Stream file content by document ID"""
|
||||
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Validate user_id is a valid UUID format
|
||||
try:
|
||||
import uuid
|
||||
user_uuid = str(uuid.UUID(self.user_id))
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
|
||||
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
|
||||
|
||||
# Get document metadata using UUID directly
|
||||
# Admins can access any document in their tenant, regular users only their own
|
||||
if self.user_role in ADMIN_ROLES:
|
||||
doc_info = await pg_client.fetch_one("""
|
||||
SELECT metadata, file_size_bytes, filename, content_text
|
||||
FROM documents d
|
||||
WHERE d.id = $1
|
||||
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
""", document_id, self.tenant_domain)
|
||||
else:
|
||||
doc_info = await pg_client.fetch_one("""
|
||||
SELECT metadata, file_size_bytes, filename, content_text
|
||||
FROM documents
|
||||
WHERE id = $1 AND user_id = $2::uuid
|
||||
""", document_id, user_uuid)
|
||||
|
||||
if not doc_info:
|
||||
raise FileNotFoundError(f"Document {document_id} not found")
|
||||
|
||||
# Get storage info from metadata - handle JSON string or dict
|
||||
metadata_raw = doc_info["metadata"] or "{}"
|
||||
if isinstance(metadata_raw, str):
|
||||
import json
|
||||
metadata = json.loads(metadata_raw)
|
||||
else:
|
||||
metadata = metadata_raw or {}
|
||||
storage_type = metadata.get("storage_type", "text")
|
||||
|
||||
if storage_type == "text":
|
||||
# Text content stored directly
|
||||
if doc_info["content_text"]:
|
||||
content_bytes = doc_info["content_text"].encode('utf-8')
|
||||
async for chunk in self._stream_from_bytea(content_bytes):
|
||||
yield chunk
|
||||
else:
|
||||
raise FileNotFoundError(f"Document content not found")
|
||||
|
||||
elif storage_type == "base64":
|
||||
# Base64 encoded binary content
|
||||
if doc_info["content_text"]:
|
||||
import base64
|
||||
content_bytes = base64.b64decode(doc_info["content_text"])
|
||||
async for chunk in self._stream_from_bytea(content_bytes):
|
||||
yield chunk
|
||||
else:
|
||||
raise FileNotFoundError(f"Document content not found")
|
||||
|
||||
elif storage_type == "lob":
|
||||
# Stream from PostgreSQL LOB
|
||||
storage_ref = metadata.get("storage_ref", "")
|
||||
async for chunk in self._stream_from_lob(int(storage_ref)):
|
||||
yield chunk
|
||||
|
||||
elif storage_type == "filesystem":
|
||||
# Stream from filesystem
|
||||
storage_ref = metadata.get("storage_ref", "")
|
||||
async for chunk in self._stream_from_filesystem(storage_ref):
|
||||
yield chunk
|
||||
else:
|
||||
# Default: treat as text content
|
||||
if doc_info["content_text"]:
|
||||
content_bytes = doc_info["content_text"].encode('utf-8')
|
||||
async for chunk in self._stream_from_bytea(content_bytes):
|
||||
yield chunk
|
||||
else:
|
||||
raise FileNotFoundError(f"Document content not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file {document_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _stream_from_bytea(self, content: bytes) -> AsyncIterator[bytes]:
|
||||
"""Stream content from BYTEA in chunks"""
|
||||
chunk_size = 8192
|
||||
for i in range(0, len(content), chunk_size):
|
||||
yield content[i:i + chunk_size]
|
||||
|
||||
async def _stream_from_lob(self, lob_oid: int) -> AsyncIterator[bytes]:
|
||||
"""Stream content from PostgreSQL Large Object"""
|
||||
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
async with pg_client.get_connection() as conn:
|
||||
async with conn.transaction():
|
||||
# Open LOB for reading
|
||||
lob_fd = await conn.fetchval("SELECT lo_open($1, 262144)", lob_oid) # INV_READ mode
|
||||
|
||||
# Stream in chunks
|
||||
chunk_size = 8192
|
||||
while True:
|
||||
chunk = await conn.fetchval("SELECT lo_read($1, $2)", lob_fd, chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
# Close LOB
|
||||
await conn.execute("SELECT lo_close($1)", lob_fd)
|
||||
|
||||
async def _stream_from_filesystem(self, file_path: str) -> AsyncIterator[bytes]:
|
||||
"""Stream content from filesystem"""
|
||||
|
||||
# Verify file belongs to tenant (security check)
|
||||
path_obj = Path(file_path)
|
||||
if not str(path_obj).startswith(str(self.filesystem_base)):
|
||||
raise PermissionError("Access denied to file")
|
||||
|
||||
if not path_obj.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
async with aiofiles.open(file_path, 'rb') as f:
|
||||
chunk_size = 8192
|
||||
while True:
|
||||
chunk = await f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
async def delete_file(self, document_id: str) -> bool:
|
||||
"""Delete file and metadata"""
|
||||
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Validate user_id is a valid UUID format
|
||||
try:
|
||||
import uuid
|
||||
user_uuid = str(uuid.UUID(self.user_id))
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
|
||||
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
|
||||
|
||||
# Get document info before deletion
|
||||
# Admins can delete any document in their tenant, regular users only their own
|
||||
if self.user_role in ADMIN_ROLES:
|
||||
doc_info = await pg_client.fetch_one("""
|
||||
SELECT storage_type, storage_ref FROM documents d
|
||||
WHERE d.id = $1
|
||||
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
""", document_id, self.tenant_domain)
|
||||
else:
|
||||
doc_info = await pg_client.fetch_one("""
|
||||
SELECT storage_type, storage_ref FROM documents
|
||||
WHERE id = $1 AND user_id = $2::uuid
|
||||
""", document_id, user_uuid)
|
||||
|
||||
if not doc_info:
|
||||
logger.warning(f"Document {document_id} not found for deletion")
|
||||
return False
|
||||
|
||||
storage_type = doc_info["storage_type"]
|
||||
storage_ref = doc_info["storage_ref"]
|
||||
|
||||
# Delete file content based on storage type
|
||||
if storage_type == "lob":
|
||||
# Delete LOB
|
||||
async with pg_client.get_connection() as conn:
|
||||
await conn.execute("SELECT lo_unlink($1)", int(storage_ref))
|
||||
elif storage_type == "filesystem":
|
||||
# Delete filesystem file
|
||||
try:
|
||||
path_obj = Path(storage_ref)
|
||||
if path_obj.exists():
|
||||
path_obj.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete filesystem file {storage_ref}: {e}")
|
||||
# BYTEA files are deleted with the row
|
||||
|
||||
# Delete metadata record
|
||||
if self.user_role in ADMIN_ROLES:
|
||||
deleted = await pg_client.execute_command("""
|
||||
DELETE FROM documents d
|
||||
WHERE d.id = $1
|
||||
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
""", document_id, self.tenant_domain)
|
||||
else:
|
||||
deleted = await pg_client.execute_command("""
|
||||
DELETE FROM documents WHERE id = $1 AND user_id = $2::uuid
|
||||
""", document_id, user_uuid)
|
||||
|
||||
if deleted > 0:
|
||||
logger.info(f"Deleted file {document_id} ({storage_type})")
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {document_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_file_info(self, document_id: str) -> Dict[str, Any]:
|
||||
"""Get file metadata"""
|
||||
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Validate user_id is a valid UUID format
|
||||
try:
|
||||
import uuid
|
||||
user_uuid = str(uuid.UUID(self.user_id))
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
|
||||
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
|
||||
|
||||
# Admins can access any document metadata in their tenant, regular users only their own
|
||||
if self.user_role in ADMIN_ROLES:
|
||||
doc_info = await pg_client.fetch_one("""
|
||||
SELECT id, filename, original_filename, file_type as content_type, file_size_bytes as file_size,
|
||||
file_hash, dataset_id, metadata->>'storage_type' as storage_type, metadata->>'category' as category, created_at
|
||||
FROM documents d
|
||||
WHERE d.id = $1
|
||||
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
""", document_id, self.tenant_domain)
|
||||
else:
|
||||
doc_info = await pg_client.fetch_one("""
|
||||
SELECT id, filename, original_filename, file_type as content_type, file_size_bytes as file_size,
|
||||
file_hash, dataset_id, metadata->>'storage_type' as storage_type, metadata->>'category' as category, created_at
|
||||
FROM documents
|
||||
WHERE id = $1 AND user_id = $2::uuid
|
||||
""", document_id, user_uuid)
|
||||
|
||||
if not doc_info:
|
||||
raise FileNotFoundError(f"Document {document_id} not found")
|
||||
|
||||
return {
|
||||
"id": doc_info["id"],
|
||||
"filename": doc_info["filename"],
|
||||
"original_filename": doc_info["original_filename"],
|
||||
"content_type": doc_info["content_type"],
|
||||
"file_size": doc_info["file_size"],
|
||||
"file_hash": doc_info["file_hash"],
|
||||
"dataset_id": str(doc_info["dataset_id"]) if doc_info["dataset_id"] else None,
|
||||
"storage_type": doc_info["storage_type"],
|
||||
"category": doc_info["category"],
|
||||
"created_at": doc_info["created_at"].isoformat(),
|
||||
"download_url": f"/api/v1/files/{document_id}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file info for {document_id}: {e}")
|
||||
raise
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
dataset_id: Optional[str] = None,
|
||||
category: str = "documents",
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List user files with optional filtering"""
|
||||
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
|
||||
# Validate user_id is a valid UUID format
|
||||
try:
|
||||
import uuid
|
||||
user_uuid = str(uuid.UUID(self.user_id))
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
|
||||
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
|
||||
|
||||
# Build permission-aware query
|
||||
# Admins can list any documents in their tenant
|
||||
# Regular users can list documents they own OR documents in datasets they can access
|
||||
if self.user_role in ADMIN_ROLES:
|
||||
where_clauses = ["d.tenant_id = (SELECT id FROM tenants WHERE domain = $1)"]
|
||||
params = [self.tenant_domain]
|
||||
param_idx = 2
|
||||
else:
|
||||
# Non-admin users can see:
|
||||
# 1. Documents they own
|
||||
# 2. Documents in datasets with access_group = 'organization'
|
||||
# 3. Documents in datasets they're a member of (team access)
|
||||
where_clauses = ["""(
|
||||
d.user_id = $1::uuid
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM datasets ds
|
||||
WHERE ds.id = d.dataset_id
|
||||
AND ds.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
|
||||
AND (
|
||||
ds.access_group = 'organization'
|
||||
OR (ds.access_group = 'team' AND $1::uuid = ANY(ds.team_members))
|
||||
)
|
||||
)
|
||||
)"""]
|
||||
params = [user_uuid, self.tenant_domain]
|
||||
param_idx = 3
|
||||
|
||||
if dataset_id:
|
||||
where_clauses.append(f"d.dataset_id = ${param_idx}::uuid")
|
||||
params.append(dataset_id)
|
||||
param_idx += 1
|
||||
|
||||
if category:
|
||||
where_clauses.append(f"(d.metadata->>'category' = ${param_idx} OR d.metadata->>'category' IS NULL)")
|
||||
params.append(category)
|
||||
param_idx += 1
|
||||
|
||||
query = f"""
|
||||
SELECT d.id, d.filename, d.original_filename, d.file_type as content_type, d.file_size_bytes as file_size,
|
||||
d.metadata->>'storage_type' as storage_type, d.metadata->>'category' as category, d.created_at, d.updated_at, d.dataset_id,
|
||||
d.processing_status, d.metadata, d.user_id, COUNT(dc.id) as chunk_count,
|
||||
ds.created_by as dataset_owner_id
|
||||
FROM documents d
|
||||
LEFT JOIN document_chunks dc ON d.id = dc.document_id
|
||||
LEFT JOIN datasets ds ON d.dataset_id = ds.id
|
||||
WHERE {' AND '.join(where_clauses)}
|
||||
GROUP BY d.id, d.filename, d.original_filename, d.file_type, d.file_size_bytes, d.metadata, d.created_at, d.updated_at, d.dataset_id, d.processing_status, d.user_id, ds.created_by
|
||||
ORDER BY d.created_at DESC LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
||||
"""
|
||||
params.extend([limit, offset])
|
||||
|
||||
files = await pg_client.execute_query(query, *params)
|
||||
|
||||
# Helper function to parse metadata
|
||||
def parse_metadata(metadata_value):
|
||||
if metadata_value is None:
|
||||
return {}
|
||||
if isinstance(metadata_value, str):
|
||||
import json
|
||||
try:
|
||||
return json.loads(metadata_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return {}
|
||||
return metadata_value if isinstance(metadata_value, dict) else {}
|
||||
|
||||
return [
|
||||
{
|
||||
"id": file["id"],
|
||||
"filename": file["filename"],
|
||||
"original_filename": file["original_filename"],
|
||||
"content_type": file["content_type"],
|
||||
"file_type": file["content_type"],
|
||||
"file_size": file["file_size"],
|
||||
"file_size_bytes": file["file_size"],
|
||||
"dataset_id": file["dataset_id"],
|
||||
"storage_type": file["storage_type"],
|
||||
"category": file["category"],
|
||||
"created_at": file["created_at"].isoformat(),
|
||||
"updated_at": file["updated_at"].isoformat() if file.get("updated_at") else None,
|
||||
"processing_status": file.get("processing_status", "pending"),
|
||||
"chunk_count": file.get("chunk_count", 0),
|
||||
"chunks_processed": parse_metadata(file.get("metadata")).get("chunks_processed", 0),
|
||||
"total_chunks_expected": parse_metadata(file.get("metadata")).get("total_chunks_expected", 0),
|
||||
"processing_progress": parse_metadata(file.get("metadata")).get("processing_progress", 0),
|
||||
"processing_stage": parse_metadata(file.get("metadata")).get("processing_stage"),
|
||||
"download_url": f"/api/v1/files/{file['id']}",
|
||||
# Permission flags - user can delete if:
|
||||
# 1. They are admin, OR
|
||||
# 2. They uploaded the document, OR
|
||||
# 3. They own the parent dataset
|
||||
"can_delete": (
|
||||
self.user_role in ADMIN_ROLES or
|
||||
file["user_id"] == user_uuid or
|
||||
(file.get("dataset_owner_id") and str(file["dataset_owner_id"]) == user_uuid)
|
||||
)
|
||||
}
|
||||
for file in files
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list files for user {self.user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def cleanup_orphaned_files(self) -> int:
|
||||
"""Clean up orphaned files and LOBs"""
|
||||
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
cleanup_count = 0
|
||||
|
||||
# Find orphaned LOBs (LOBs without corresponding document records)
|
||||
async with pg_client.get_connection() as conn:
|
||||
async with conn.transaction():
|
||||
orphaned_lobs = await conn.fetch("""
|
||||
SELECT lo.oid FROM pg_largeobject_metadata lo
|
||||
LEFT JOIN documents d ON lo.oid::text = d.storage_ref
|
||||
WHERE d.id IS NULL  -- no matching document row means the LOB is orphaned
|
||||
""")
|
||||
|
||||
for lob in orphaned_lobs:
|
||||
await conn.execute("SELECT lo_unlink($1)", lob["oid"])
|
||||
cleanup_count += 1
|
||||
|
||||
# Find orphaned filesystem files
|
||||
# Note: This would require more complex logic to safely identify orphans
|
||||
|
||||
logger.info(f"Cleaned up {cleanup_count} orphaned files")
|
||||
return cleanup_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup orphaned files: {e}")
|
||||
return 0
|
||||
|
||||
async def _trigger_document_processing(
|
||||
self,
|
||||
document_id: str,
|
||||
dataset_id: Optional[str],
|
||||
user_uuid: str,
|
||||
filename: str
|
||||
):
|
||||
"""Trigger document processing pipeline for RAG functionality"""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from app.services.document_processor import get_document_processor
|
||||
|
||||
logger.info(f"Triggering document processing for {document_id}")
|
||||
|
||||
# Get document processor instance
|
||||
processor = await get_document_processor(tenant_domain=self.tenant_domain)
|
||||
|
||||
# For documents uploaded via PostgreSQL file service, the content is already stored
|
||||
# We need to process it from the database content rather than a file path
|
||||
await self._process_document_from_database(
|
||||
processor, document_id, dataset_id, user_uuid, filename
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Document processing trigger failed for {document_id}: {e}")
|
||||
# Update document status to failed
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
await pg_client.execute_command(
|
||||
"UPDATE documents SET processing_status = 'failed', error_message = $1 WHERE id = $2",
|
||||
f"Processing trigger failed: {str(e)}", document_id
|
||||
)
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update document status to failed: {update_error}")
|
||||
raise
|
||||
|
||||
async def _process_document_from_database(
|
||||
self,
|
||||
processor,
|
||||
document_id: str,
|
||||
dataset_id: Optional[str],
|
||||
user_uuid: str,
|
||||
filename: str
|
||||
):
|
||||
"""Process document using content already stored in database"""
|
||||
try:
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Get document content from database
|
||||
pg_client = await get_postgresql_client()
|
||||
doc_info = await pg_client.fetch_one("""
|
||||
SELECT content_text, file_type, metadata
|
||||
FROM documents
|
||||
WHERE id = $1 AND user_id = $2::uuid
|
||||
""", document_id, user_uuid)
|
||||
|
||||
if not doc_info or not doc_info["content_text"]:
|
||||
raise ValueError("Document content not found in database")
|
||||
|
||||
# Create temporary file with the content
|
||||
# Sanitize the file extension to prevent path injection
|
||||
safe_suffix = sanitize_filename(filename)
|
||||
safe_suffix = Path(safe_suffix).suffix if safe_suffix else ".tmp"
|
||||
# codeql[py/path-injection] safe_suffix is sanitized via sanitize_filename()
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix=safe_suffix, delete=False) as temp_file:
|
||||
# Handle different storage types - metadata might be JSON string or dict
|
||||
metadata_raw = doc_info["metadata"] or "{}"
|
||||
if isinstance(metadata_raw, str):
|
||||
import json
|
||||
metadata = json.loads(metadata_raw)
|
||||
else:
|
||||
metadata = metadata_raw or {}
|
||||
storage_type = metadata.get("storage_type", "text")
|
||||
|
||||
if storage_type == "text":
|
||||
temp_file.write(doc_info["content_text"])
|
||||
elif storage_type == "base64":
|
||||
import base64
|
||||
content_bytes = base64.b64decode(doc_info["content_text"])
|
||||
temp_file.close()
|
||||
with open(temp_file.name, 'wb') as binary_file:
|
||||
binary_file.write(content_bytes)
|
||||
elif storage_type == "pdf_extracted":
|
||||
# For PDFs with extracted text, create a placeholder text file
|
||||
# since the actual text content is already extracted
|
||||
temp_file.write(doc_info["content_text"])
|
||||
else:
|
||||
temp_file.write(doc_info["content_text"])
|
||||
|
||||
temp_file_path = Path(temp_file.name)
|
||||
|
||||
try:
|
||||
# Process the document using the existing document processor
|
||||
await processor.process_file(
|
||||
file_path=temp_file_path,
|
||||
dataset_id=dataset_id, # Keep None as None - don't convert to empty string
|
||||
user_id=user_uuid,
|
||||
original_filename=filename,
|
||||
document_id=document_id # Use existing document instead of creating new one
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document {document_id} from database content")
|
||||
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.unlink(temp_file_path)
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup temporary file {temp_file_path}: {cleanup_error}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process document from database content: {e}")
|
||||
raise
|
||||
|
||||
async def _extract_pdf_text(self, content: bytes) -> str:
|
||||
"""Extract text content from PDF bytes using pypdf"""
|
||||
import io
|
||||
import pypdf as PyPDF2 # pypdf is the maintained successor to PyPDF2
|
||||
|
||||
try:
|
||||
# Create BytesIO object from content
|
||||
pdf_stream = io.BytesIO(content)
|
||||
pdf_reader = PyPDF2.PdfReader(pdf_stream)
|
||||
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
try:
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num + 1} ---\n{page_text}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from page {page_num + 1}: {e}")
|
||||
|
||||
if not text_parts:
|
||||
# If no text could be extracted, return a placeholder
|
||||
return f"PDF document with {len(pdf_reader.pages)} pages (text extraction failed)"
|
||||
|
||||
extracted_text = "\n\n".join(text_parts)
|
||||
logger.info(f"Successfully extracted {len(extracted_text)} characters from PDF with {len(pdf_reader.pages)} pages")
|
||||
return extracted_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"PDF text extraction failed: {e}")
|
||||
# Return a fallback description instead of failing completely
|
||||
return f"PDF document (text extraction failed: {str(e)})"
|
||||
577
apps/tenant-backend/app/services/rag_orchestrator.py
Normal file
577
apps/tenant-backend/app/services/rag_orchestrator.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
RAG Orchestrator Service for GT 2.0
|
||||
|
||||
Coordinates RAG operations between chat, MCP tools, and datasets.
|
||||
Provides intelligent context retrieval and source attribution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
|
||||
from app.services.mcp_integration import MCPIntegrationService, MCPExecutionResult
|
||||
from app.services.pgvector_search_service import PGVectorSearchService, get_pgvector_search_service
|
||||
from app.models.agent import Agent
|
||||
from app.models.assistant_dataset import AssistantDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGContext:
|
||||
"""Context retrieved from RAG operations"""
|
||||
chunks: List[Dict[str, Any]]
|
||||
sources: List[Dict[str, Any]]
|
||||
search_queries: List[str]
|
||||
total_chunks: int
|
||||
retrieval_time_ms: float
|
||||
datasets_used: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGSearchParams:
|
||||
"""Parameters for RAG search operations"""
|
||||
query: str
|
||||
dataset_ids: Optional[List[str]] = None
|
||||
max_chunks: int = 5
|
||||
similarity_threshold: float = 0.7
|
||||
search_method: str = "hybrid" # hybrid, vector, text
|
||||
|
||||
|
||||
class RAGOrchestrator:
|
||||
"""
|
||||
Orchestrates RAG operations for enhanced chat responses.
|
||||
|
||||
Coordinates between:
|
||||
- Dataset search via MCP RAG server
|
||||
- Conversation history via MCP conversation server
|
||||
- Direct PGVector search for performance
|
||||
- Agent dataset bindings for context filtering
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.mcp_service = MCPIntegrationService()
|
||||
self.resource_cluster_url = "http://resource-cluster:8000"
|
||||
|
||||
async def get_rag_context(
|
||||
self,
|
||||
agent: Agent,
|
||||
user_message: str,
|
||||
conversation_id: str,
|
||||
params: Optional[RAGSearchParams] = None
|
||||
) -> RAGContext:
|
||||
"""
|
||||
Get comprehensive RAG context for a chat message.
|
||||
|
||||
Args:
|
||||
agent: Agent instance with dataset bindings
|
||||
user_message: User's message/query
|
||||
conversation_id: Current conversation ID
|
||||
params: Optional search parameters
|
||||
|
||||
Returns:
|
||||
RAGContext with relevant chunks and sources
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
if params is None:
|
||||
params = RAGSearchParams(query=user_message)
|
||||
|
||||
try:
|
||||
# Get agent's dataset IDs for search (unchanged)
|
||||
agent_dataset_ids = await self._get_agent_datasets(agent)
|
||||
|
||||
# Get conversation files if conversation exists (NEW: simplified approach)
|
||||
conversation_files = []
|
||||
if conversation_id:
|
||||
conversation_files = await self._get_conversation_files(conversation_id)
|
||||
|
||||
# Determine search strategy
|
||||
search_dataset_ids = params.dataset_ids or agent_dataset_ids
|
||||
|
||||
# Check if we have any sources to search
|
||||
if not search_dataset_ids and not conversation_files:
|
||||
logger.info(f"No search sources available - agent {agent.id if agent else 'none'}")
|
||||
return RAGContext(
|
||||
chunks=[],
|
||||
sources=[],
|
||||
search_queries=[params.query],
|
||||
total_chunks=0,
|
||||
retrieval_time_ms=0.0,
|
||||
datasets_used=[]
|
||||
)
|
||||
|
||||
# Prepare search tasks for dual-source search
|
||||
search_tasks = []
|
||||
|
||||
# Task 1: Dataset search via MCP RAG server (unchanged)
|
||||
if search_dataset_ids:
|
||||
search_tasks.append(
|
||||
self._search_datasets_via_mcp(params.query, search_dataset_ids, params)
|
||||
)
|
||||
|
||||
# Task 2: Conversation files search (NEW: direct search)
|
||||
if conversation_files:
|
||||
search_tasks.append(
|
||||
self._search_conversation_files(params.query, conversation_id)
|
||||
)
|
||||
|
||||
# Execute searches in parallel
|
||||
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
# Process search results from all sources
|
||||
all_chunks = []
|
||||
all_sources = []
|
||||
|
||||
result_index = 0
|
||||
|
||||
# Process dataset search results (if performed)
|
||||
if search_dataset_ids and result_index < len(search_results):
|
||||
dataset_result = search_results[result_index]
|
||||
if not isinstance(dataset_result, Exception) and dataset_result.get("success"):
|
||||
dataset_chunks = dataset_result.get("results", [])
|
||||
all_chunks.extend(dataset_chunks)
|
||||
all_sources.extend(self._extract_sources(dataset_chunks))
|
||||
result_index += 1
|
||||
|
||||
# Process conversation files search results (if performed)
|
||||
if conversation_files and result_index < len(search_results):
|
||||
conversation_result = search_results[result_index]
|
||||
if not isinstance(conversation_result, Exception):
|
||||
conversation_chunks = conversation_result or []
|
||||
all_chunks.extend(conversation_chunks)
|
||||
all_sources.extend(
|
||||
await self._extract_conversation_file_sources(conversation_chunks, conversation_id)
|
||||
)
|
||||
result_index += 1
|
||||
|
||||
# Rank and filter results based on agent preferences (now using all chunks)
|
||||
final_chunks = await self._rank_and_filter_chunks(
|
||||
all_chunks, agent_dataset_ids, params
|
||||
)
|
||||
|
||||
retrieval_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"RAG context retrieved: {len(final_chunks)} chunks from "
|
||||
f"{len(search_dataset_ids)} datasets + {len(conversation_files)} conversation files "
|
||||
f"in {retrieval_time:.1f}ms"
|
||||
)
|
||||
|
||||
return RAGContext(
|
||||
chunks=final_chunks,
|
||||
sources=all_sources,
|
||||
search_queries=[params.query],
|
||||
total_chunks=len(final_chunks),
|
||||
retrieval_time_ms=retrieval_time,
|
||||
datasets_used=search_dataset_ids
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAG context retrieval failed: {e}")
|
||||
retrieval_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
# Return empty context on failure
|
||||
return RAGContext(
|
||||
chunks=[],
|
||||
sources=[],
|
||||
search_queries=[params.query],
|
||||
total_chunks=0,
|
||||
retrieval_time_ms=retrieval_time,
|
||||
datasets_used=[]
|
||||
)
|
||||
|
||||
async def _get_agent_datasets(self, agent: Agent) -> List[str]:
|
||||
"""Get dataset IDs for an agent (simplified)"""
|
||||
try:
|
||||
# Get agent configuration from agent service (skip complex table lookup)
|
||||
from app.services.agent_service import AgentService
|
||||
agent_service = AgentService(self.tenant_domain, self.user_id)
|
||||
agent_data = await agent_service.get_agent(agent.id)
|
||||
|
||||
if agent_data and 'selected_dataset_ids' in agent_data and agent_data['selected_dataset_ids'] is not None:
|
||||
selected_dataset_ids = agent_data.get('selected_dataset_ids', [])
|
||||
logger.info(f"Found {len(selected_dataset_ids)} dataset IDs in agent configuration: {selected_dataset_ids}")
|
||||
return selected_dataset_ids
|
||||
else:
|
||||
logger.info(f"No selected_dataset_ids found in agent {agent.id} configuration: {agent_data.get('selected_dataset_ids') if agent_data else 'no agent_data'}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get agent datasets: {e}")
|
||||
return []
|
||||
|
||||
async def _get_conversation_files(self, conversation_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get conversation files (NEW: simplified approach)"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
file_service = get_conversation_file_service(self.tenant_domain, self.user_id)
|
||||
conversation_files = await file_service.list_files(conversation_id)
|
||||
|
||||
# Filter to only completed files
|
||||
completed_files = [f for f in conversation_files if f.get('processing_status') == 'completed']
|
||||
|
||||
logger.info(f"Found {len(completed_files)} processed conversation files")
|
||||
return completed_files
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation files: {e}")
|
||||
return []
|
||||
|
||||
async def _search_conversation_files(
|
||||
self,
|
||||
query: str,
|
||||
conversation_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search conversation files using vector similarity (NEW: direct search)"""
|
||||
try:
|
||||
from app.services.conversation_file_service import get_conversation_file_service
|
||||
file_service = get_conversation_file_service(self.tenant_domain, self.user_id)
|
||||
|
||||
results = await file_service.search_conversation_files(
|
||||
conversation_id=conversation_id,
|
||||
query=query,
|
||||
max_results=5
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(results)} matching conversation files")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search conversation files: {e}")
|
||||
return []
|
||||
|
||||
async def _extract_conversation_file_sources(
|
||||
self,
|
||||
chunks: List[Dict[str, Any]],
|
||||
conversation_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract unique conversation file sources with rich metadata"""
|
||||
sources = {}
|
||||
|
||||
for chunk in chunks:
|
||||
file_id = chunk.get("id")
|
||||
if file_id and file_id not in sources:
|
||||
uploaded_at = chunk.get("uploaded_at")
|
||||
if not uploaded_at:
|
||||
logger.warning(f"Missing uploaded_at for file {file_id}")
|
||||
|
||||
file_size = chunk.get("file_size_bytes", 0)
|
||||
if file_size == 0:
|
||||
logger.warning(f"Missing file_size_bytes for file {file_id}")
|
||||
|
||||
sources[file_id] = {
|
||||
"document_id": file_id,
|
||||
"dataset_id": None,
|
||||
"document_name": chunk.get("original_filename", "Unknown File"),
|
||||
|
||||
"source_type": "conversation_file",
|
||||
"access_scope": "conversation",
|
||||
"search_method": "auto_rag",
|
||||
|
||||
"conversation_id": conversation_id,
|
||||
|
||||
"uploaded_at": uploaded_at,
|
||||
"file_size_bytes": file_size,
|
||||
"content_type": chunk.get("content_type", "unknown"),
|
||||
"processing_status": chunk.get("processing_status", "unknown"),
|
||||
|
||||
"chunk_count": 1,
|
||||
"relevance_score": chunk.get("similarity_score", 0.0)
|
||||
}
|
||||
elif file_id in sources:
|
||||
sources[file_id]["chunk_count"] += 1
|
||||
current_score = chunk.get("similarity_score", 0.0)
|
||||
if current_score > sources[file_id]["relevance_score"]:
|
||||
sources[file_id]["relevance_score"] = current_score
|
||||
|
||||
return list(sources.values())
|
||||
|
||||
# Keep old method for backward compatibility during migration
|
||||
async def _get_conversation_datasets(self, conversation_id: str) -> List[str]:
|
||||
"""Get dataset IDs associated with a conversation (LEGACY: for migration)"""
|
||||
try:
|
||||
from app.services.conversation_service import ConversationService
|
||||
conversation_service = ConversationService(self.tenant_domain, self.user_id)
|
||||
conversation_dataset_ids = await conversation_service.get_conversation_datasets(
|
||||
conversation_id=conversation_id,
|
||||
user_identifier=self.user_id
|
||||
)
|
||||
logger.info(f"Found {len(conversation_dataset_ids)} legacy conversation datasets: {conversation_dataset_ids}")
|
||||
return conversation_dataset_ids
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation datasets: {e}")
|
||||
return []
|
||||
|
||||
async def _search_datasets_via_mcp(
|
||||
self,
|
||||
query: str,
|
||||
dataset_ids: List[str],
|
||||
params: RAGSearchParams
|
||||
) -> Dict[str, Any]:
|
||||
"""Search datasets using MCP RAG server"""
|
||||
try:
|
||||
# Prepare MCP tool call
|
||||
tool_params = {
|
||||
"query": query,
|
||||
"dataset_ids": dataset_ids,
|
||||
"max_results": params.max_chunks,
|
||||
"search_method": params.search_method
|
||||
}
|
||||
|
||||
# Create capability token for MCP access
|
||||
capability_token = {
|
||||
"capabilities": [{"resource": "mcp:rag:*"}],
|
||||
"user_id": self.user_id,
|
||||
"tenant_domain": self.tenant_domain
|
||||
}
|
||||
|
||||
# Execute MCP tool call via resource cluster
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/mcp/execute",
|
||||
json={
|
||||
"server_id": "rag_server",
|
||||
"tool_name": "search_datasets",
|
||||
"parameters": tool_params,
|
||||
"capability_token": capability_token,
|
||||
"tenant_domain": self.tenant_domain,
|
||||
"user_id": self.user_id
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.error(f"MCP RAG search failed: {response.status_code} - {response.text}")
|
||||
return {"success": False, "error": "MCP search failed"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP dataset search failed: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
async def _rank_and_filter_chunks(
|
||||
self,
|
||||
chunks: List[Dict[str, Any]],
|
||||
agent_dataset_ids: List[str],
|
||||
params: RAGSearchParams
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Rank and filter chunks based on agent preferences (simplified)"""
|
||||
try:
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
# Convert agent dataset list to set for fast lookup
|
||||
agent_datasets = set(agent_dataset_ids)
|
||||
|
||||
# Separate conversation files from dataset chunks
|
||||
conversation_file_chunks = []
|
||||
dataset_chunks = []
|
||||
|
||||
for chunk in chunks:
|
||||
if chunk.get("source_type") == "conversation_file":
|
||||
# Conversation files ALWAYS included (user explicitly attached them)
|
||||
chunk["final_score"] = chunk.get("similarity_score", 1.0)
|
||||
chunk["dataset_priority"] = -1 # Highest priority (shown first)
|
||||
conversation_file_chunks.append(chunk)
|
||||
else:
|
||||
dataset_chunks.append(chunk)
|
||||
|
||||
# Filter and score dataset chunks using thresholds
|
||||
scored_dataset_chunks = []
|
||||
for chunk in dataset_chunks:
|
||||
dataset_id = chunk.get("dataset_id")
|
||||
similarity_score = chunk.get("similarity_score", 0.0)
|
||||
|
||||
# Check if chunk is from agent's configured datasets
|
||||
if dataset_id in agent_datasets:
|
||||
# Use agent default threshold
|
||||
if similarity_score >= 0.7: # Default threshold
|
||||
chunk["final_score"] = similarity_score
|
||||
chunk["dataset_priority"] = 0 # High priority for agent datasets
|
||||
scored_dataset_chunks.append(chunk)
|
||||
else:
|
||||
# Use request threshold for other datasets
|
||||
if similarity_score >= params.similarity_threshold:
|
||||
chunk["final_score"] = similarity_score
|
||||
chunk["dataset_priority"] = 999 # Low priority
|
||||
scored_dataset_chunks.append(chunk)
|
||||
|
||||
# Combine: conversation files first, then sorted dataset chunks
|
||||
scored_dataset_chunks.sort(key=lambda x: x["final_score"], reverse=True)
|
||||
|
||||
# Limit total results, but always include all conversation files
|
||||
final_chunks = conversation_file_chunks + scored_dataset_chunks
|
||||
|
||||
# If we exceed max_chunks, keep all conversation files and trim datasets
|
||||
if len(final_chunks) > params.max_chunks:
|
||||
dataset_limit = params.max_chunks - len(conversation_file_chunks)
|
||||
if dataset_limit > 0:
|
||||
final_chunks = conversation_file_chunks + scored_dataset_chunks[:dataset_limit]
|
||||
else:
|
||||
# If conversation files alone exceed limit, keep them all anyway
|
||||
final_chunks = conversation_file_chunks
|
||||
|
||||
logger.info(f"Ranked chunks: {len(conversation_file_chunks)} conversation files (always included) + {len(scored_dataset_chunks)} dataset chunks → {len(final_chunks)} total")
|
||||
|
||||
return final_chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chunk ranking failed: {e}")
|
||||
return chunks[:params.max_chunks] # Fallback to simple truncation
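The policy above is easier to see in isolation. A minimal, self-contained sketch of the same ranking rules on plain dicts (the 0.7 agent-dataset threshold mirrors the hard-coded default above; every other value is invented):

AGENT_DATASET_THRESHOLD = 0.7  # mirrors the hard-coded default above

def rank_chunks_sketch(chunks, agent_dataset_ids, request_threshold=0.5, max_chunks=4):
    agent_datasets = set(agent_dataset_ids)
    files = [c for c in chunks if c.get("source_type") == "conversation_file"]
    kept = []
    for c in chunks:
        if c.get("source_type") == "conversation_file":
            continue  # conversation files are kept unconditionally
        threshold = AGENT_DATASET_THRESHOLD if c.get("dataset_id") in agent_datasets else request_threshold
        if c.get("similarity_score", 0.0) >= threshold:
            kept.append(c)
    kept.sort(key=lambda c: c.get("similarity_score", 0.0), reverse=True)
    # Conversation files are never trimmed; only dataset chunks compete for the remaining slots.
    return files + kept[:max(0, max_chunks - len(files))]

# One attached file plus two dataset hits; with max_chunks=2 the file and the best dataset chunk survive.
print(rank_chunks_sketch(
    [{"source_type": "conversation_file", "similarity_score": 0.40},
     {"source_type": "dataset", "dataset_id": "a", "similarity_score": 0.90},
     {"source_type": "dataset", "dataset_id": "b", "similarity_score": 0.60}],
    agent_dataset_ids=["a"],
    max_chunks=2,
))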
|
||||
|
||||
def _extract_sources(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Extract unique document sources from dataset chunks with metadata"""
|
||||
sources = {}
|
||||
|
||||
for chunk in chunks:
|
||||
document_id = chunk.get("document_id")
|
||||
if document_id and document_id not in sources:
|
||||
sources[document_id] = {
|
||||
"document_id": document_id,
|
||||
"dataset_id": chunk.get("dataset_id"),
|
||||
"document_name": chunk.get("metadata", {}).get("document_name", "Unknown"),
|
||||
|
||||
"source_type": "dataset",
|
||||
"access_scope": "permanent",
|
||||
"search_method": "mcp_tool",
|
||||
|
||||
"dataset_name": chunk.get("dataset_name", "Unknown Dataset"),
|
||||
|
||||
"chunk_count": 1,
|
||||
"relevance_score": chunk.get("similarity_score", 0.0)
|
||||
}
|
||||
elif document_id in sources:
|
||||
sources[document_id]["chunk_count"] += 1
|
||||
current_score = chunk.get("similarity_score", 0.0)
|
||||
if current_score > sources[document_id]["relevance_score"]:
|
||||
sources[document_id]["relevance_score"] = current_score
|
||||
|
||||
return list(sources.values())
|
||||
|
||||
def format_context_for_llm(self, rag_context: RAGContext) -> str:
|
||||
"""Format RAG context for inclusion in LLM prompt"""
|
||||
if not rag_context.chunks:
|
||||
return ""
|
||||
|
||||
context_parts = ["## Relevant Context\n"]
|
||||
|
||||
# Add dataset search results
|
||||
if rag_context.chunks:
|
||||
context_parts.append("### From Documents:")
|
||||
for i, chunk in enumerate(rag_context.chunks[:5], 1): # Limit to top 5
|
||||
document_name = chunk.get("metadata", {}).get("document_name", "Unknown Document")
|
||||
content = chunk.get("content", chunk.get("text", ""))
|
||||
|
||||
context_parts.append(f"\n**Source {i}**: {document_name}")
|
||||
context_parts.append(f"Content: {content[:500]}...") # Truncate long content
|
||||
context_parts.append("")
|
||||
|
||||
|
||||
# Add source attribution
|
||||
if rag_context.sources:
|
||||
context_parts.append("### Sources:")
|
||||
for source in rag_context.sources:
|
||||
context_parts.append(f"- {source['document_name']} ({source['chunk_count']} relevant sections)")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def format_context_for_agent(
|
||||
self,
|
||||
rag_context: RAGContext,
|
||||
compact_mode: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Format RAG results with clear source attribution for agent consumption.
|
||||
|
||||
Args:
|
||||
rag_context: RAG search results with chunks and sources
|
||||
compact_mode: Use compact format for >2 files (~200 tokens vs ~700)
|
||||
|
||||
Returns:
|
||||
Formatted context string ready for LLM injection
|
||||
"""
|
||||
if not rag_context.chunks:
|
||||
return ""
|
||||
|
||||
dataset_chunks = [
|
||||
c for c in rag_context.chunks
|
||||
if c.get('source_type') == 'dataset' or c.get('dataset_id')
|
||||
]
|
||||
file_chunks = [
|
||||
c for c in rag_context.chunks
|
||||
if c.get('source_type') == 'conversation_file'
|
||||
]
|
||||
|
||||
context_parts = []
|
||||
context_parts.append("=" * 80)
|
||||
context_parts.append("📚 KNOWLEDGE BASE CONTEXT - DATASET DOCUMENTS")
|
||||
context_parts.append("=" * 80)
|
||||
|
||||
if dataset_chunks:
|
||||
context_parts.append("\n📂 FROM AGENT'S PERMANENT DATASETS:")
|
||||
context_parts.append("(These are documents from the agent's configured knowledge base)\n")
|
||||
|
||||
for i, chunk in enumerate(dataset_chunks, 1):
|
||||
dataset_name = chunk.get('dataset_name', 'Unknown Dataset')
|
||||
doc_name = chunk.get('metadata', {}).get('document_name', 'Unknown')
|
||||
content = chunk.get('content', chunk.get('text', ''))
|
||||
score = chunk.get('similarity_score', 0.0)
|
||||
|
||||
if compact_mode:
|
||||
context_parts.append(f"\n[Dataset: {dataset_name}]\n{content[:400]}...")
|
||||
else:
|
||||
context_parts.append(f"\n{'─' * 80}")
|
||||
context_parts.append(f"📚 DATASET EXCERPT {i}")
|
||||
context_parts.append(f"Dataset: {dataset_name} / Document: {doc_name}")
|
||||
context_parts.append(f"Relevance: {score:.2f}")
|
||||
context_parts.append(f"{'─' * 80}")
|
||||
context_parts.append(content[:600] if len(content) > 600 else content)
|
||||
if len(content) > 600:
|
||||
context_parts.append("\n[... excerpt continues ...]")
|
||||
|
||||
if file_chunks:
|
||||
context_parts.append(f"\n\n{'=' * 80}")
|
||||
context_parts.append("📎 FROM CONVERSATION FILES - USER ATTACHED DOCUMENTS")
|
||||
context_parts.append("=" * 80)
|
||||
context_parts.append("(These are files the user attached to THIS specific conversation)\n")
|
||||
|
||||
for i, chunk in enumerate(file_chunks, 1):
|
||||
filename = chunk.get('document_name', chunk.get('original_filename', 'Unknown'))
|
||||
content = chunk.get('content', chunk.get('text', ''))
|
||||
score = chunk.get('similarity_score', 0.0)
|
||||
|
||||
if compact_mode:
|
||||
context_parts.append(f"\n[File: {filename}]\n{content[:400]}...")
|
||||
else:
|
||||
context_parts.append(f"\n{'─' * 80}")
|
||||
context_parts.append(f"📄 FILE EXCERPT {i}: {filename}")
|
||||
context_parts.append(f"Relevance: {score:.2f}")
|
||||
context_parts.append(f"{'─' * 80}")
|
||||
context_parts.append(content[:600] if len(content) > 600 else content)
|
||||
if len(content) > 600:
|
||||
context_parts.append("\n[... excerpt continues ...]")
|
||||
|
||||
context_text = "\n".join(context_parts)
|
||||
|
||||
formatted_context = f"""{context_text}
|
||||
|
||||
{'=' * 80}
|
||||
⚠️ CONTEXT USAGE INSTRUCTIONS:
|
||||
1. CONVERSATION FILES (📎) = User's attached files for THIS chat - cite as "In your attached file..."
|
||||
2. DATASET DOCUMENTS (📂) = Agent's knowledge base - cite as "According to the dataset..."
|
||||
3. Always prioritize conversation files when both sources have relevant information
|
||||
4. Be explicit about which source you're referencing in your answer
|
||||
{'=' * 80}"""
|
||||
|
||||
return formatted_context
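A hypothetical caller-side sketch for choosing the compact format, following the ">2 files" guidance in the docstring above; orchestrator and rag_context are assumed to come from the surrounding service code and are not defined here:

def format_for_prompt(orchestrator, rag_context) -> str:
    attached = [c for c in rag_context.chunks if c.get("source_type") == "conversation_file"]
    return orchestrator.format_context_for_agent(rag_context, compact_mode=len(attached) > 2)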
|
||||
|
||||
|
||||
# Global instance factory
|
||||
def get_rag_orchestrator(tenant_domain: str, user_id: str) -> RAGOrchestrator:
|
||||
"""Get RAG orchestrator instance for tenant and user"""
|
||||
return RAGOrchestrator(tenant_domain, user_id)
|
||||
671
apps/tenant-backend/app/services/rag_service.py
Normal file
671
apps/tenant-backend/app/services/rag_service.py
Normal file
@@ -0,0 +1,671 @@
|
||||
"""
|
||||
RAG Service for GT 2.0 Tenant Backend
|
||||
|
||||
Orchestrates document processing, embedding generation, and vector storage
|
||||
with perfect tenant isolation and zero downtime compliance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import aiofiles
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
from typing import Dict, Any, List, Optional, BinaryIO
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.document import Document, RAGDataset, DatasetDocument, DocumentChunk
|
||||
from app.core.database import get_db_session
|
||||
from app.core.config import get_settings
|
||||
from app.core.resource_client import ResourceClusterClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGService:
|
||||
"""
|
||||
Comprehensive RAG service with perfect tenant isolation.
|
||||
|
||||
GT 2.0 Security Principles:
|
||||
- Perfect tenant isolation (all operations user-scoped)
|
||||
- Stateless document processing (no data persistence in Resource Cluster)
|
||||
- Encrypted vector storage per tenant
|
||||
- Zero downtime compliance (async operations)
|
||||
"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.settings = get_settings()
|
||||
self.resource_client = ResourceClusterClient()
|
||||
|
||||
# Tenant-specific directories
|
||||
self.upload_directory = Path(self.settings.upload_directory)
|
||||
self.temp_directory = Path(self.settings.temp_directory)
|
||||
|
||||
# Ensure directories exist with secure permissions
|
||||
self._ensure_directories()
|
||||
|
||||
logger.info("RAG service initialized with tenant isolation")
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""Ensure required directories exist with secure permissions"""
|
||||
for directory in [self.upload_directory, self.temp_directory]:
|
||||
directory.mkdir(parents=True, exist_ok=True, mode=0o700)
|
||||
|
||||
async def create_rag_dataset(
|
||||
self,
|
||||
user_id: str,
|
||||
dataset_name: str,
|
||||
description: Optional[str] = None,
|
||||
chunking_strategy: str = "hybrid",
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 128,
|
||||
embedding_model: str = "BAAI/bge-m3"
|
||||
) -> RAGDataset:
|
||||
"""Create a new RAG dataset with tenant isolation"""
|
||||
try:
|
||||
# Check if dataset already exists for this user
|
||||
existing = await self.db.execute(
|
||||
select(RAGDataset).where(
|
||||
and_(
|
||||
RAGDataset.user_id == user_id,
|
||||
RAGDataset.dataset_name == dataset_name
|
||||
)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError(f"Dataset '{dataset_name}' already exists for user")
|
||||
|
||||
# Create dataset
|
||||
dataset = RAGDataset(
|
||||
user_id=user_id,
|
||||
dataset_name=dataset_name,
|
||||
description=description,
|
||||
chunking_strategy=chunking_strategy,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
embedding_model=embedding_model
|
||||
)
|
||||
|
||||
self.db.add(dataset)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(dataset)
|
||||
|
||||
logger.info(f"Created RAG dataset '{dataset_name}' for user {user_id}")
|
||||
return dataset
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to create RAG dataset: {e}")
|
||||
raise
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
user_id: str,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
file_type: str,
|
||||
dataset_id: Optional[str] = None
|
||||
) -> Document:
|
||||
"""Upload and store document with tenant isolation"""
|
||||
try:
|
||||
# Validate file
|
||||
file_extension = Path(filename).suffix.lower()
|
||||
if not file_extension:
|
||||
raise ValueError("File must have an extension")
|
||||
|
||||
# Generate secure filename
|
||||
file_hash = hashlib.sha256(file_content).hexdigest()[:16]
|
||||
secure_filename = f"{user_id}_{file_hash}_{filename}"
|
||||
|
||||
# Tenant-specific file path
|
||||
user_upload_dir = self.upload_directory / user_id
|
||||
user_upload_dir.mkdir(exist_ok=True, mode=0o700)
|
||||
|
||||
file_path = user_upload_dir / secure_filename
|
||||
|
||||
# Save file with secure permissions
|
||||
async with aiofiles.open(file_path, 'wb') as f:
|
||||
await f.write(file_content)
|
||||
|
||||
# Set file permissions (owner read/write only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
# Create document record
|
||||
document = Document(
|
||||
filename=secure_filename,
|
||||
original_filename=filename,
|
||||
file_path=str(file_path),
|
||||
file_type=file_type,
|
||||
file_extension=file_extension,
|
||||
file_size_bytes=len(file_content),
|
||||
uploaded_by=user_id,
|
||||
processing_status="pending"
|
||||
)
|
||||
|
||||
self.db.add(document)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(document)
|
||||
|
||||
# Add to dataset if specified
|
||||
if dataset_id:
|
||||
await self.add_document_to_dataset(user_id, document.id, dataset_id)
|
||||
|
||||
# Clear file content from memory
|
||||
del file_content
|
||||
gc.collect()
|
||||
|
||||
logger.info(f"Uploaded document '{filename}' for user {user_id}")
|
||||
return document
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to upload document: {e}")
|
||||
# Clear sensitive data even on error
|
||||
if 'file_content' in locals():
|
||||
del file_content
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
user_id: str,
|
||||
document_id: int,
|
||||
tenant_id: str,
|
||||
chunking_strategy: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Process document into chunks and generate embeddings"""
|
||||
try:
|
||||
# Get document with ownership check
|
||||
document = await self._get_user_document(user_id, document_id)
|
||||
if not document:
|
||||
raise PermissionError("Document not found or access denied")
|
||||
|
||||
# Check if already processed
|
||||
if document.is_processing_complete():
|
||||
logger.info(f"Document {document_id} already processed")
|
||||
return {"status": "already_processed", "chunk_count": document.chunk_count}
|
||||
|
||||
# Mark as processing
|
||||
document.mark_processing_started()
|
||||
await self.db.commit()
|
||||
|
||||
# Read document file
|
||||
file_content = await self._read_document_file(document)
|
||||
|
||||
# Process document using Resource Cluster (stateless)
|
||||
chunks = await self.resource_client.process_document(
|
||||
content=file_content,
|
||||
document_type=document.file_extension,
|
||||
strategy_type=chunking_strategy or "hybrid",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Clear file content from memory immediately
|
||||
del file_content
|
||||
gc.collect()
|
||||
|
||||
if not chunks:
|
||||
raise ValueError("Document processing returned no chunks")
|
||||
|
||||
# Generate embeddings for chunks (stateless)
|
||||
chunk_texts = [chunk["text"] for chunk in chunks]
|
||||
embeddings = await self.resource_client.generate_document_embeddings(
|
||||
documents=chunk_texts,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if len(embeddings) != len(chunk_texts):
|
||||
raise ValueError("Embedding count mismatch with chunk count")
|
||||
|
||||
# Store vectors in ChromaDB via Resource Cluster
|
||||
dataset_name = f"doc_{document.id}"
|
||||
collection_created = await self.resource_client.create_vector_collection(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
dataset_name=dataset_name
|
||||
)
|
||||
|
||||
if not collection_created:
|
||||
raise RuntimeError("Failed to create vector collection")
|
||||
|
||||
# Store vectors with metadata
|
||||
chunk_metadata = [chunk["metadata"] for chunk in chunks]
|
||||
vector_stored = await self.resource_client.store_vectors(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
dataset_name=dataset_name,
|
||||
documents=chunk_texts,
|
||||
embeddings=embeddings,
|
||||
metadata=chunk_metadata
|
||||
)
|
||||
|
||||
if not vector_stored:
|
||||
raise RuntimeError("Failed to store vectors")
|
||||
|
||||
# Clear embedding data from memory
|
||||
del chunk_texts, embeddings
|
||||
gc.collect()
|
||||
|
||||
# Update document record
|
||||
vector_store_ids = [f"{tenant_id}:{user_id}:{dataset_name}"]
|
||||
document.mark_processing_complete(
|
||||
chunk_count=len(chunks),
|
||||
vector_store_ids=vector_store_ids
|
||||
)
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(f"Processed document {document_id} into {len(chunks)} chunks")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"document_id": document_id,
|
||||
"chunk_count": len(chunks),
|
||||
"vector_store_ids": vector_store_ids
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Mark document processing as failed
|
||||
if 'document' in locals() and document:
|
||||
document.mark_processing_failed({"error": str(e)})
|
||||
await self.db.commit()
|
||||
|
||||
logger.error(f"Failed to process document {document_id}: {e}")
|
||||
# Ensure memory cleanup
|
||||
gc.collect()
|
||||
raise
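For clarity, a hedged sketch of the shapes process_document assumes from the Resource Cluster calls above (the metadata field names are illustrative, not guaranteed by the API):

example_chunks = [
    {"text": "First passage of the document...",
     "metadata": {"document_name": "report.pdf", "chunk_index": 0}},
    {"text": "Second passage...",
     "metadata": {"document_name": "report.pdf", "chunk_index": 1}},
]
example_embeddings = [[0.01, -0.42, 0.33], [0.07, 0.11, -0.25]]  # one vector per chunk, same order
assert len(example_embeddings) == len(example_chunks)  # mirrors the mismatch check above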
|
||||
|
||||
async def add_document_to_dataset(
|
||||
self,
|
||||
user_id: str,
|
||||
document_id: int,
|
||||
dataset_id: str
|
||||
) -> DatasetDocument:
|
||||
"""Add processed document to RAG dataset"""
|
||||
try:
|
||||
# Verify dataset ownership
|
||||
dataset = await self._get_user_dataset(user_id, dataset_id)
|
||||
if not dataset:
|
||||
raise PermissionError("Dataset not found or access denied")
|
||||
|
||||
# Verify document ownership
|
||||
document = await self._get_user_document(user_id, document_id)
|
||||
if not document:
|
||||
raise PermissionError("Document not found or access denied")
|
||||
|
||||
# Check if already in dataset
|
||||
existing = await self.db.execute(
|
||||
select(DatasetDocument).where(
|
||||
and_(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.document_id == document_id
|
||||
)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Document already in dataset")
|
||||
|
||||
# Create dataset document relationship
|
||||
dataset_doc = DatasetDocument(
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
user_id=user_id,
|
||||
chunk_count=document.chunk_count,
|
||||
vector_count=document.chunk_count # Assuming 1 vector per chunk
|
||||
)
|
||||
|
||||
self.db.add(dataset_doc)
|
||||
|
||||
# Update dataset statistics
|
||||
dataset.document_count += 1
|
||||
dataset.chunk_count += document.chunk_count
|
||||
dataset.vector_count += document.chunk_count
|
||||
dataset.total_size_bytes += document.file_size_bytes
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(dataset_doc)
|
||||
|
||||
logger.info(f"Added document {document_id} to dataset {dataset_id}")
|
||||
return dataset_doc
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to add document to dataset: {e}")
|
||||
raise
|
||||
|
||||
async def search_documents(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
dataset_ids: Optional[List[str]] = None,
|
||||
top_k: int = 5,
|
||||
similarity_threshold: float = 0.7
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search documents using RAG with tenant isolation"""
|
||||
try:
|
||||
# Generate query embedding
|
||||
query_embeddings = await self.resource_client.generate_query_embeddings(
|
||||
queries=[query],
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if not query_embeddings:
|
||||
raise ValueError("Failed to generate query embedding")
|
||||
|
||||
query_embedding = query_embeddings[0]
|
||||
|
||||
# Get user's datasets if not specified
|
||||
if not dataset_ids:
|
||||
datasets = await self.list_user_datasets(user_id)
|
||||
dataset_ids = [d.id for d in datasets]
|
||||
|
||||
# Search across datasets
|
||||
all_results = []
|
||||
for dataset_id in dataset_ids:
|
||||
# Verify dataset ownership
|
||||
dataset = await self._get_user_dataset(user_id, dataset_id)
|
||||
if not dataset:
|
||||
continue
|
||||
|
||||
# Search in ChromaDB
|
||||
dataset_name = f"dataset_{dataset_id}"
|
||||
results = await self.resource_client.search_vectors(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
dataset_name=dataset_name,
|
||||
query_embedding=query_embedding,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
# Filter by similarity threshold and add dataset context
|
||||
for result in results:
|
||||
if result.get("similarity", 0) >= similarity_threshold:
|
||||
result["dataset_id"] = dataset_id
|
||||
result["dataset_name"] = dataset.dataset_name
|
||||
all_results.append(result)
|
||||
|
||||
# Sort by similarity and limit
|
||||
all_results.sort(key=lambda x: x.get("similarity", 0), reverse=True)
|
||||
final_results = all_results[:top_k]
|
||||
|
||||
# Clear query embedding from memory
|
||||
del query_embedding, query_embeddings
|
||||
gc.collect()
|
||||
|
||||
logger.info(f"Search found {len(final_results)} results for user {user_id}")
|
||||
return final_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search documents: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def get_document_context(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
document_id: int,
|
||||
query: str,
|
||||
context_size: int = 3
|
||||
) -> Dict[str, Any]:
|
||||
"""Get relevant context from a specific document"""
|
||||
try:
|
||||
# Verify document ownership
|
||||
document = await self._get_user_document(user_id, document_id)
|
||||
if not document:
|
||||
raise PermissionError("Document not found or access denied")
|
||||
|
||||
if not document.is_processing_complete():
|
||||
raise ValueError("Document not yet processed")
|
||||
|
||||
# Generate query embedding
|
||||
query_embeddings = await self.resource_client.generate_query_embeddings(
|
||||
queries=[query],
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
query_embedding = query_embeddings[0]
|
||||
|
||||
# Search document's vectors
|
||||
dataset_name = f"doc_{document_id}"
|
||||
results = await self.resource_client.search_vectors(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
dataset_name=dataset_name,
|
||||
query_embedding=query_embedding,
|
||||
top_k=context_size
|
||||
)
|
||||
|
||||
context = {
|
||||
"document_id": document_id,
|
||||
"document_name": document.original_filename,
|
||||
"query": query,
|
||||
"relevant_chunks": results,
|
||||
"context_text": "\n\n".join([r["document"] for r in results])
|
||||
}
|
||||
|
||||
# Clear query embedding from memory
|
||||
del query_embedding, query_embeddings
|
||||
gc.collect()
|
||||
|
||||
return context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get document context: {e}")
|
||||
gc.collect()
|
||||
raise
|
||||
|
||||
async def list_user_documents(
|
||||
self,
|
||||
user_id: str,
|
||||
status_filter: Optional[str] = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50
|
||||
) -> List[Document]:
|
||||
"""List user's documents with optional filtering"""
|
||||
try:
|
||||
query = select(Document).where(Document.uploaded_by == user_id)
|
||||
|
||||
if status_filter:
|
||||
query = query.where(Document.processing_status == status_filter)
|
||||
|
||||
query = query.order_by(Document.created_at.desc())
|
||||
query = query.offset(offset).limit(limit)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
return list(documents)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list user documents: {e}")
|
||||
raise
|
||||
|
||||
async def list_user_datasets(
|
||||
self,
|
||||
user_id: str,
|
||||
include_stats: bool = True
|
||||
) -> List[RAGDataset]:
|
||||
"""List user's RAG datasets"""
|
||||
try:
|
||||
query = select(RAGDataset).where(RAGDataset.user_id == user_id)
|
||||
|
||||
if include_stats:
|
||||
query = query.options(selectinload(RAGDataset.documents))
|
||||
|
||||
query = query.order_by(RAGDataset.created_at.desc())
|
||||
|
||||
result = await self.db.execute(query)
|
||||
datasets = result.scalars().all()
|
||||
|
||||
return list(datasets)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list user datasets: {e}")
|
||||
raise
|
||||
|
||||
async def delete_document(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
document_id: int
|
||||
) -> bool:
|
||||
"""Delete document and associated vectors"""
|
||||
try:
|
||||
# Verify document ownership
|
||||
document = await self._get_user_document(user_id, document_id)
|
||||
if not document:
|
||||
raise PermissionError("Document not found or access denied")
|
||||
|
||||
# Delete vectors from ChromaDB if processed
|
||||
if document.is_processing_complete():
|
||||
dataset_name = f"doc_{document_id}"
|
||||
await self.resource_client.delete_vector_collection(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
dataset_name=dataset_name
|
||||
)
|
||||
|
||||
# Delete physical file
|
||||
if document.file_exists():
|
||||
os.remove(document.get_absolute_file_path())
|
||||
|
||||
# Delete from database (cascade will handle related records)
|
||||
await self.db.delete(document)
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(f"Deleted document {document_id} for user {user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to delete document: {e}")
|
||||
raise
|
||||
|
||||
async def delete_dataset(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
dataset_id: str
|
||||
) -> bool:
|
||||
"""Delete RAG dataset and associated vectors"""
|
||||
try:
|
||||
# Verify dataset ownership
|
||||
dataset = await self._get_user_dataset(user_id, dataset_id)
|
||||
if not dataset:
|
||||
raise PermissionError("Dataset not found or access denied")
|
||||
|
||||
# Delete vectors from ChromaDB
|
||||
dataset_name = f"dataset_{dataset_id}"
|
||||
await self.resource_client.delete_vector_collection(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
dataset_name=dataset_name
|
||||
)
|
||||
|
||||
# Delete from database (cascade will handle related records)
|
||||
await self.db.delete(dataset)
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(f"Deleted dataset {dataset_id} for user {user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
logger.error(f"Failed to delete dataset: {e}")
|
||||
raise
|
||||
|
||||
async def get_rag_statistics(
|
||||
self,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get RAG usage statistics for user"""
|
||||
try:
|
||||
# Document statistics
|
||||
doc_query = select(Document).where(Document.uploaded_by == user_id)
|
||||
doc_result = await self.db.execute(doc_query)
|
||||
documents = doc_result.scalars().all()
|
||||
|
||||
# Dataset statistics
|
||||
dataset_query = select(RAGDataset).where(RAGDataset.user_id == user_id)
|
||||
dataset_result = await self.db.execute(dataset_query)
|
||||
datasets = dataset_result.scalars().all()
|
||||
|
||||
total_size = sum(doc.file_size_bytes for doc in documents)
|
||||
total_chunks = sum(doc.chunk_count for doc in documents)
|
||||
|
||||
stats = {
|
||||
"user_id": user_id,
|
||||
"document_count": len(documents),
|
||||
"dataset_count": len(datasets),
|
||||
"total_size_bytes": total_size,
|
||||
"total_size_mb": round(total_size / (1024 * 1024), 2),
|
||||
"total_chunks": total_chunks,
|
||||
"processed_documents": len([d for d in documents if d.is_processing_complete()]),
|
||||
"pending_documents": len([d for d in documents if d.is_pending_processing()]),
|
||||
"failed_documents": len([d for d in documents if d.is_processing_failed()])
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get RAG statistics: {e}")
|
||||
raise
|
||||
|
||||
# Private helper methods
|
||||
|
||||
async def _get_user_document(self, user_id: str, document_id: int) -> Optional[Document]:
|
||||
"""Get document with ownership verification"""
|
||||
result = await self.db.execute(
|
||||
select(Document).where(
|
||||
and_(
|
||||
Document.id == document_id,
|
||||
Document.uploaded_by == user_id
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _get_user_dataset(self, user_id: str, dataset_id: str) -> Optional[RAGDataset]:
|
||||
"""Get dataset with ownership verification"""
|
||||
result = await self.db.execute(
|
||||
select(RAGDataset).where(
|
||||
and_(
|
||||
RAGDataset.id == dataset_id,
|
||||
RAGDataset.user_id == user_id
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _read_document_file(self, document: Document) -> bytes:
|
||||
"""Read document file content"""
|
||||
file_path = document.get_absolute_file_path()
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Document file not found: {file_path}")
|
||||
|
||||
async with aiofiles.open(file_path, 'rb') as f:
|
||||
content = await f.read()
|
||||
|
||||
return content
|
||||
|
||||
|
||||
# Factory function for dependency injection
|
||||
async def get_rag_service(db: AsyncSession = None) -> RAGService:
|
||||
"""Get RAG service instance"""
|
||||
if db is None:
|
||||
async with get_db_session() as session:
|
||||
return RAGService(session)
|
||||
return RAGService(db)
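A hypothetical end-to-end usage sketch for this service from an async caller; the tenant, user, and query values are placeholders, and a live database session plus Resource Cluster are assumed:

async def demo_rag_search(db_session):
    service = RAGService(db_session)
    results = await service.search_documents(
        user_id="user-123",
        tenant_id="acme.example",
        query="quarterly revenue highlights",
        top_k=5,
        similarity_threshold=0.7,
    )
    for hit in results:
        print(hit["dataset_name"], round(hit.get("similarity", 0.0), 3))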
|
||||
371
apps/tenant-backend/app/services/resource_cluster_client.py
Normal file
371
apps/tenant-backend/app/services/resource_cluster_client.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
Resource Cluster Client for GT 2.0 Tenant Backend
|
||||
|
||||
Handles communication with the Resource Cluster for AI/ML operations.
|
||||
Manages capability token generation and LLM inference requests.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, AsyncIterator, List
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import asyncio
|
||||
from jose import jwt
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def fetch_model_rate_limit(
|
||||
tenant_id: str,
|
||||
model_id: str,
|
||||
control_panel_url: str
|
||||
) -> int:
|
||||
"""
|
||||
Fetch rate limit for a model from Control Panel API.
|
||||
|
||||
Returns requests_per_minute (converted from max_requests_per_hour in database).
|
||||
Fails fast if Control Panel is unreachable (GT 2.0 principle: no fallbacks).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
model_id: Model identifier
|
||||
control_panel_url: Control Panel API base URL
|
||||
|
||||
Returns:
|
||||
Requests per minute limit
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Control Panel API is unreachable
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
url = f"{control_panel_url}/api/v1/tenant-models/tenants/{tenant_id}/models/{model_id}"
|
||||
logger.debug(f"Fetching rate limit from Control Panel: {url}")
|
||||
|
||||
response = await client.get(url)
|
||||
|
||||
if response.status_code == 404:
|
||||
logger.warning(f"Model {model_id} not configured for tenant {tenant_id}, using default")
|
||||
return 1000 # Default: 1000 requests/minute
|
||||
|
||||
response.raise_for_status()
|
||||
config = response.json()
|
||||
|
||||
# API now returns requests_per_minute directly (translated from DB per-hour)
|
||||
rate_limits = config.get("rate_limits", {})
|
||||
requests_per_minute = rate_limits.get("requests_per_minute", 1000)
|
||||
|
||||
logger.info(f"Model {model_id} rate limit: {requests_per_minute} requests/minute")
|
||||
return requests_per_minute
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Control Panel API error: {e.response.status_code}")
|
||||
raise RuntimeError(f"Failed to fetch rate limit: HTTP {e.response.status_code}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Control Panel API unreachable: {e}")
|
||||
raise RuntimeError(f"Control Panel unreachable at {control_panel_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error fetching rate limit: {e}")
|
||||
raise RuntimeError(f"Failed to fetch rate limit: {e}")
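An illustrative payload shape this helper expects back from the Control Panel endpoint (values are made up; only rate_limits.requests_per_minute is actually read):

example_config = {
    "model_id": "llama-3.1-70b-versatile",
    "rate_limits": {"requests_per_minute": 600},
}
assert example_config.get("rate_limits", {}).get("requests_per_minute", 1000) == 600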
|
||||
|
||||
|
||||
class ResourceClusterClient:
|
||||
"""Client for communicating with GT 2.0 Resource Cluster"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.resource_cluster_url = self.settings.resource_cluster_url
|
||||
self.secret_key = self.settings.secret_key
|
||||
self.algorithm = "HS256"
|
||||
self.client = httpx.AsyncClient(timeout=60.0)
|
||||
|
||||
async def generate_capability_token(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
assistant_config: Dict[str, Any],
|
||||
expires_minutes: int = 30
|
||||
) -> str:
|
||||
"""
|
||||
Generate capability token for resource access.
|
||||
|
||||
Fetches real rate limits from Control Panel (single source of truth).
|
||||
Fails fast if Control Panel is unreachable.
|
||||
"""
|
||||
|
||||
# Extract capabilities from agent configuration
|
||||
capabilities = []
|
||||
|
||||
# Add LLM capability with real rate limit from Control Panel
|
||||
model = assistant_config.get("resource_preferences", {}).get("primary_llm", "llama-3.1-70b-versatile")
|
||||
|
||||
# Fetch real rate limit from Control Panel API
|
||||
requests_per_minute = await fetch_model_rate_limit(
|
||||
tenant_id=tenant_id,
|
||||
model_id=model,
|
||||
control_panel_url=self.settings.control_panel_url
|
||||
)
|
||||
|
||||
capabilities.append({
|
||||
"resource": f"llm:groq",
|
||||
"actions": ["inference", "streaming"],
|
||||
"constraints": { # Changed from "limits" to match LLM gateway expectations
|
||||
"max_tokens_per_request": assistant_config.get("resource_preferences", {}).get("max_tokens", 4000),
|
||||
"max_requests_per_minute": requests_per_minute # Real limit from database (converted from per-hour)
|
||||
}
|
||||
})
|
||||
|
||||
# Add RAG capabilities if configured
|
||||
if assistant_config.get("capabilities", {}).get("rag_enabled"):
|
||||
capabilities.append({
|
||||
"resource": "rag:semantic_search",
|
||||
"actions": ["search", "retrieve"],
|
||||
"limits": {
|
||||
"max_results": 10
|
||||
}
|
||||
})
|
||||
|
||||
# Add embedding capability if RAG is enabled
|
||||
if assistant_config.get("capabilities", {}).get("embeddings_enabled"):
|
||||
capabilities.append({
|
||||
"resource": "embedding:text-embedding-3-small",
|
||||
"actions": ["generate"],
|
||||
"limits": {
|
||||
"max_texts_per_request": 100
|
||||
}
|
||||
})
|
||||
|
||||
# Create token payload
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"tenant_id": tenant_id,
|
||||
"capabilities": capabilities,
|
||||
"exp": asyncio.get_event_loop().time() + (expires_minutes * 60),
|
||||
"iat": asyncio.get_event_loop().time()
|
||||
}
|
||||
|
||||
# Sign token
|
||||
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
return token
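A minimal verification sketch for tokens minted above; the Resource Cluster side is not shown here, so this is illustrative only and assumes the same shared secret and HS256 algorithm:

from jose import jwt as _jwt

def decode_capability_claims(token: str, secret_key: str) -> dict:
    # Raises jose.JWTError (including ExpiredSignatureError) on a bad signature or expired "exp".
    return _jwt.decode(token, secret_key, algorithms=["HS256"])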
|
||||
|
||||
async def execute_inference(
|
||||
self,
|
||||
prompt: str,
|
||||
assistant_config: Dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
stream: bool = False,
|
||||
conversation_context: Optional[List[Dict[str, str]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute LLM inference via Resource Cluster"""
|
||||
|
||||
# Generate capability token (now async - fetches real rate limits)
|
||||
token = await self.generate_capability_token(user_id, tenant_id, assistant_config)
|
||||
|
||||
# Prepare request
|
||||
model = assistant_config.get("resource_preferences", {}).get("primary_llm", "llama-3.1-70b-versatile")
|
||||
temperature = assistant_config.get("resource_preferences", {}).get("temperature", 0.7)
|
||||
max_tokens = assistant_config.get("resource_preferences", {}).get("max_tokens", 4000)
|
||||
|
||||
# Build messages array with system prompt
|
||||
messages = []
|
||||
|
||||
# Add system prompt from agent
|
||||
system_prompt = assistant_config.get("prompt", "You are a helpful AI agent.")
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# Add conversation context if provided
|
||||
if conversation_context:
|
||||
messages.extend(conversation_context)
|
||||
|
||||
# Add current user message
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Prepare request payload
|
||||
request_data = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
"user_id": user_id,
|
||||
"tenant_id": tenant_id
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
if stream:
|
||||
return self._stream_inference(request_data, headers)  # async generator: return it unawaited so callers can iterate
|
||||
else:
|
||||
response = await self.client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/inference/",
|
||||
json=request_data,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error during inference: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
|
||||
async def _stream_inference(
|
||||
self,
|
||||
request_data: Dict[str, Any],
|
||||
headers: Dict[str, str]
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream inference responses"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.resource_cluster_url}/api/v1/inference/stream",
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=60.0
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = line[6:] # Remove "data: " prefix
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
if "content" in chunk:
|
||||
yield chunk["content"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse streaming chunk: {data}")
|
||||
continue
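A hypothetical consumer of the streaming path, assuming the streaming branch of execute_inference returns this async generator without awaiting it (as corrected above):

async def print_stream(client: "ResourceClusterClient", prompt: str,
                       assistant_config: dict, user_id: str, tenant_id: str) -> None:
    stream = await client.execute_inference(
        prompt=prompt,
        assistant_config=assistant_config,
        user_id=user_id,
        tenant_id=tenant_id,
        stream=True,
    )
    async for piece in stream:
        print(piece, end="", flush=True)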
|
||||
|
||||
async def generate_embeddings(
|
||||
self,
|
||||
texts: List[str],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
model: str = "text-embedding-3-small"
|
||||
) -> List[List[float]]:
|
||||
"""Generate embeddings for texts"""
|
||||
|
||||
# Generate capability token with embedding permission
|
||||
assistant_config = {"capabilities": {"embeddings_enabled": True}}
|
||||
token = await self.generate_capability_token(user_id, tenant_id, assistant_config)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
request_data = {
|
||||
"texts": texts,
|
||||
"model": model
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/embeddings/",
|
||||
json=request_data,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("embeddings", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error during embedding generation: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embeddings: {e}")
|
||||
raise
|
||||
|
||||
async def search_rag(
|
||||
self,
|
||||
query: str,
|
||||
collection: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
top_k: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search RAG collection for relevant documents"""
|
||||
|
||||
# Generate capability token with RAG permission
|
||||
assistant_config = {"capabilities": {"rag_enabled": True}}
|
||||
token = await self.generate_capability_token(user_id, tenant_id, assistant_config)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
request_data = {
|
||||
"query": query,
|
||||
"collection": collection,
|
||||
"top_k": top_k
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.resource_cluster_url}/api/v1/rag/search",
|
||||
json=request_data,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("results", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error during RAG search: {e}")
|
||||
# Return empty results on error for now
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching RAG: {e}")
|
||||
return []
|
||||
|
||||
async def get_agent_templates(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get available agent templates from Resource Cluster"""
|
||||
|
||||
# Generate basic capability token
|
||||
token = await self.generate_capability_token(user_id, tenant_id, {})
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self.client.get(
|
||||
f"{self.resource_cluster_url}/api/v1/templates/",
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error fetching templates: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching templates: {e}")
|
||||
return []
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client"""
|
||||
await self.client.aclose()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
await self.close()
|
||||
173
apps/tenant-backend/app/services/resource_service.py
Normal file
173
apps/tenant-backend/app/services/resource_service.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Resource Service for GT 2.0 Tenant Backend
|
||||
|
||||
Provides access to Resource Cluster capabilities and services.
|
||||
This is a wrapper around the resource_cluster_client for agent services.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
from datetime import datetime
|
||||
from app.core.resource_client import ResourceClusterClient
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ResourceService:
|
||||
"""Service for accessing Resource Cluster capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize resource service"""
|
||||
self.settings = get_settings()
|
||||
self.client = ResourceClusterClient()
|
||||
|
||||
async def get_available_models(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get available AI models for user from Resource Cluster"""
|
||||
try:
|
||||
# Get models from Resource Cluster via capability token
|
||||
token = await self.client._get_capability_token(
|
||||
tenant_id=self.settings.tenant_domain,
|
||||
user_id=user_id,
|
||||
resources=['model_registry']
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json',
|
||||
'X-Tenant-ID': self.settings.tenant_domain,
|
||||
'X-User-ID': user_id
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.client.base_url}/api/v1/models/",
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
response_data = await response.json()
|
||||
models_data = response_data.get("models", [])
|
||||
|
||||
# Transform to expected format
|
||||
available_models = []
|
||||
for model in models_data:
|
||||
if model.get("status", {}).get("deployment") == "available":
|
||||
available_models.append({
|
||||
"model_id": model["id"],
|
||||
"name": model["name"],
|
||||
"provider": model["provider"],
|
||||
"capabilities": ["chat", "completion"],
|
||||
"context_length": model.get("performance", {}).get("context_window", 4000),
|
||||
"available": True
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(available_models)} models from Resource Cluster")
|
||||
return available_models
|
||||
else:
|
||||
logger.error(f"Resource Cluster returned {response.status}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get available models from Resource Cluster: {e}")
|
||||
return []
|
||||
|
||||
async def get_available_tools(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get available tools for user"""
|
||||
try:
|
||||
# Mock tools for development
|
||||
return [
|
||||
{
|
||||
"tool_id": "web_search",
|
||||
"name": "Web Search",
|
||||
"description": "Search the web for information",
|
||||
"available": True,
|
||||
"capabilities": ["search", "retrieve"]
|
||||
},
|
||||
{
|
||||
"tool_id": "document_analysis",
|
||||
"name": "Document Analysis",
|
||||
"description": "Analyze documents and extract information",
|
||||
"available": True,
|
||||
"capabilities": ["analyze", "extract", "summarize"]
|
||||
},
|
||||
{
|
||||
"tool_id": "code_execution",
|
||||
"name": "Code Execution",
|
||||
"description": "Execute code in safe sandbox",
|
||||
"available": True,
|
||||
"capabilities": ["execute", "debug", "test"]
|
||||
}
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get available tools: {e}")
|
||||
return []
|
||||
|
||||
async def validate_capabilities(self, user_id: str, capabilities: List[str]) -> bool:
|
||||
"""Validate that user has access to required capabilities"""
|
||||
try:
|
||||
# For development, allow all capabilities
|
||||
logger.info(f"Validating capabilities {capabilities} for user {user_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate capabilities: {e}")
|
||||
return False
|
||||
|
||||
async def execute_agent_task(self,
|
||||
agent_id: str,
|
||||
task_description: str,
|
||||
parameters: Dict[str, Any],
|
||||
user_id: str) -> Dict[str, Any]:
|
||||
"""Execute an agent task via Resource Cluster"""
|
||||
try:
|
||||
# Mock execution for development
|
||||
execution_result = {
|
||||
"execution_id": f"exec_{agent_id}_{int(datetime.now().timestamp())}",
|
||||
"status": "completed",
|
||||
"result": f"Mock execution of task: {task_description}",
|
||||
"output_artifacts": [],
|
||||
"tokens_used": 150,
|
||||
"cost_cents": 1,
|
||||
"execution_time_ms": 2500
|
||||
}
|
||||
|
||||
logger.info(f"Mock agent execution: {execution_result['execution_id']}")
|
||||
return execution_result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute agent task: {e}")
|
||||
return {
|
||||
"execution_id": f"failed_{agent_id}",
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get_resource_usage(self, user_id: str, timeframe_hours: int = 24) -> Dict[str, Any]:
|
||||
"""Get resource usage statistics for user"""
|
||||
try:
|
||||
# Mock usage data for development
|
||||
return {
|
||||
"total_requests": 25,
|
||||
"total_tokens": 15000,
|
||||
"total_cost_cents": 150,
|
||||
"execution_count": 12,
|
||||
"average_response_time_ms": 1250,
|
||||
"timeframe_hours": timeframe_hours
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get resource usage: {e}")
|
||||
return {}
|
||||
|
||||
async def check_rate_limits(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Check current rate limits for user"""
|
||||
try:
|
||||
# Mock rate limit data for development
|
||||
return {
|
||||
"requests_per_minute": {"current": 5, "limit": 60, "reset_time": "2024-01-01T00:00:00Z"},
|
||||
"tokens_per_hour": {"current": 2500, "limit": 50000, "reset_time": "2024-01-01T00:00:00Z"},
|
||||
"executions_per_day": {"current": 12, "limit": 1000, "reset_time": "2024-01-01T00:00:00Z"}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check rate limits: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
385
apps/tenant-backend/app/services/summarization_service.py
Normal file
385
apps/tenant-backend/app/services/summarization_service.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
GT 2.0 Summarization Service
|
||||
|
||||
Provides AI-powered summarization capabilities for documents and datasets.
|
||||
Uses the same pattern as conversation title generation with Llama 3.1 8B.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
from app.core.resource_client import ResourceClusterClient
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SummarizationService:
|
||||
"""
|
||||
Service for generating AI summaries of documents and datasets.
|
||||
|
||||
Uses the same approach as conversation title generation:
|
||||
- Llama 3.1 8B instant model
|
||||
- Low temperature for consistency
|
||||
- Resource cluster for AI responses
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str):
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.resource_client = ResourceClusterClient()
|
||||
self.summarization_model = "llama-3.1-8b-instant"
|
||||
self.settings = get_settings()
|
||||
|
||||
async def generate_document_summary(
|
||||
self,
|
||||
document_id: str,
|
||||
document_content: str,
|
||||
document_name: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generate AI summary for a document using Llama 3.1 8B.
|
||||
|
||||
Args:
|
||||
document_id: UUID of the document
|
||||
document_content: Full text content of the document
|
||||
document_name: Original filename/name of the document
|
||||
|
||||
Returns:
|
||||
Generated summary string or None if failed
|
||||
"""
|
||||
try:
|
||||
# Truncate content to first 3000 chars (like conversation title generation)
|
||||
content_preview = document_content[:3000]
|
||||
|
||||
# Create summarization prompt
|
||||
prompt = f"""Summarize this document '{document_name}' in 2-3 sentences.
|
||||
Focus on the main topics, key information, and purpose of the document.
|
||||
|
||||
Document content:
|
||||
{content_preview}
|
||||
|
||||
Summary:"""
|
||||
|
||||
logger.info(f"Generating summary for document {document_id} ({document_name})")
|
||||
|
||||
# Call Resource Cluster with same pattern as conversation titles
|
||||
summary = await self._call_ai_for_summary(
|
||||
prompt=prompt,
|
||||
context_type="document",
|
||||
max_tokens=150
|
||||
)
|
||||
|
||||
if summary:
|
||||
# Store summary in database
|
||||
await self._store_document_summary(document_id, summary)
|
||||
logger.info(f"Generated summary for document {document_id}: {summary[:100]}...")
|
||||
return summary
|
||||
else:
|
||||
logger.warning(f"Failed to generate summary for document {document_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating document summary for {document_id}: {e}")
|
||||
return None
|
||||
|
||||
async def generate_dataset_summary(self, dataset_id: str) -> Optional[str]:
|
||||
"""
|
||||
Generate AI summary for a dataset based on its document summaries.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset
|
||||
|
||||
Returns:
|
||||
Generated dataset summary or None if failed
|
||||
"""
|
||||
try:
|
||||
# Get all document summaries in this dataset
|
||||
document_summaries = await self._get_document_summaries_for_dataset(dataset_id)
|
||||
|
||||
if not document_summaries:
|
||||
logger.info(f"No document summaries found for dataset {dataset_id}")
|
||||
return None
|
||||
|
||||
# Get dataset name for context
|
||||
dataset_info = await self._get_dataset_info(dataset_id)
|
||||
dataset_name = dataset_info.get('name', 'Unknown Dataset') if dataset_info else 'Unknown Dataset'
|
||||
|
||||
# Combine summaries for LLM context
|
||||
combined_summaries = "\n".join([
|
||||
f"- {doc['filename']}: {doc['summary']}"
|
||||
for doc in document_summaries
|
||||
if doc['summary'] # Only include docs that have summaries
|
||||
])
|
||||
|
||||
if not combined_summaries.strip():
|
||||
logger.info(f"No valid document summaries for dataset {dataset_id}")
|
||||
return None
|
||||
|
||||
# Create dataset summarization prompt
|
||||
prompt = f"""Based on these document summaries, create a comprehensive 3-4 sentence summary describing what the dataset '{dataset_name}' contains and its purpose:
|
||||
|
||||
Documents in dataset:
|
||||
{combined_summaries}
|
||||
|
||||
Dataset summary:"""
|
||||
|
||||
logger.info(f"Generating summary for dataset {dataset_id} ({dataset_name})")
|
||||
|
||||
# Call AI for dataset summary
|
||||
summary = await self._call_ai_for_summary(
|
||||
prompt=prompt,
|
||||
context_type="dataset",
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if summary:
|
||||
# Store dataset summary in database
|
||||
await self._store_dataset_summary(dataset_id, summary)
|
||||
logger.info(f"Generated dataset summary for {dataset_id}: {summary[:100]}...")
|
||||
return summary
|
||||
else:
|
||||
logger.warning(f"Failed to generate summary for dataset {dataset_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating dataset summary for {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
async def update_dataset_summary_on_change(self, dataset_id: str) -> bool:
|
||||
"""
|
||||
Regenerate dataset summary when documents are added/removed.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset to update
|
||||
|
||||
Returns:
|
||||
True if summary was updated successfully
|
||||
"""
|
||||
try:
|
||||
summary = await self.generate_dataset_summary(dataset_id)
|
||||
return summary is not None
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating dataset summary for {dataset_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _call_ai_for_summary(
|
||||
self,
|
||||
prompt: str,
|
||||
context_type: str,
|
||||
max_tokens: int = 150
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Call Resource Cluster for AI summary generation.
|
||||
Uses ResourceClusterClient for consistent service discovery.
|
||||
|
||||
Args:
|
||||
prompt: The summarization prompt
|
||||
context_type: Type of summary (document, dataset)
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Returns:
|
||||
Generated summary text or None if failed
|
||||
"""
|
||||
try:
|
||||
# Prepare request payload (same format as conversation service)
|
||||
request_data = {
|
||||
"model": self.summarization_model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.3, # Lower temperature for consistent summaries
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": 1.0
|
||||
}
|
||||
|
||||
logger.info(f"Calling Resource Cluster for {context_type} summary generation")
|
||||
|
||||
# Use ResourceClusterClient for consistent service discovery and auth
|
||||
result = await self.resource_client.call_inference_endpoint(
|
||||
tenant_id=self.tenant_domain,
|
||||
user_id=self.user_id,
|
||||
endpoint="chat/completions",
|
||||
data=request_data
|
||||
)
|
||||
|
||||
if result and "choices" in result and len(result["choices"]) > 0:
|
||||
summary = result["choices"][0]["message"]["content"].strip()
|
||||
logger.info(f"✅ AI {context_type} summary generated successfully: {summary[:50]}...")
|
||||
return summary
|
||||
else:
|
||||
logger.error(f"❌ Invalid AI response format for {context_type} summary: {result}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error calling Resource Cluster for {context_type} summary: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _store_document_summary(self, document_id: str, summary: str) -> None:
|
||||
"""Store document summary in database"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
schema_name = self.settings.postgres_schema
|
||||
|
||||
await conn.execute(f"""
|
||||
UPDATE {schema_name}.documents
|
||||
SET summary = $1,
|
||||
summary_generated_at = $2,
|
||||
summary_model = $3
|
||||
WHERE id = $4
|
||||
""", summary, datetime.now(), self.summarization_model, document_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing document summary for {document_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _store_dataset_summary(self, dataset_id: str, summary: str) -> None:
|
||||
"""Store dataset summary in database"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
schema_name = self.settings.postgres_schema
|
||||
|
||||
await conn.execute(f"""
|
||||
UPDATE {schema_name}.datasets
|
||||
SET summary = $1,
|
||||
summary_generated_at = $2,
|
||||
summary_model = $3
|
||||
WHERE id = $4
|
||||
""", summary, datetime.now(), self.summarization_model, dataset_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing dataset summary for {dataset_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _get_document_summaries_for_dataset(self, dataset_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all document summaries for a dataset"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
schema_name = self.settings.postgres_schema
|
||||
|
||||
rows = await conn.fetch(f"""
|
||||
SELECT id, filename, original_filename, summary, summary_generated_at
|
||||
FROM {schema_name}.documents
|
||||
WHERE dataset_id = $1
|
||||
AND summary IS NOT NULL
|
||||
AND summary != ''
|
||||
ORDER BY created_at ASC
|
||||
""", dataset_id)
|
||||
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting document summaries for dataset {dataset_id}: {e}")
|
||||
return []
|
||||
|
||||
async def _get_dataset_info(self, dataset_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get basic dataset information"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
schema_name = self.settings.postgres_schema
|
||||
|
||||
row = await conn.fetchrow(f"""
|
||||
SELECT id, name, description
|
||||
FROM {schema_name}.datasets
|
||||
WHERE id = $1
|
||||
""", dataset_id)
|
||||
|
||||
return dict(row) if row else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting dataset info for {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_datasets_with_summaries(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all user-accessible datasets with their summaries.
|
||||
Used for context injection in chat.
|
||||
|
||||
Args:
|
||||
user_id: UUID of the user
|
||||
|
||||
Returns:
|
||||
List of datasets with summaries
|
||||
"""
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
schema_name = self.settings.postgres_schema
|
||||
|
||||
rows = await conn.fetch(f"""
|
||||
SELECT id, name, description, summary, summary_generated_at,
|
||||
document_count, total_size_bytes
|
||||
FROM {schema_name}.datasets
|
||||
WHERE (created_by = $1::uuid
|
||||
OR access_group IN ('team', 'organization'))
|
||||
AND is_active = true
|
||||
ORDER BY name ASC
|
||||
""", user_id)
|
||||
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting datasets with summaries for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_filtered_datasets_with_summaries(
|
||||
self,
|
||||
user_id: str,
|
||||
allowed_dataset_ids: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get datasets with summaries filtered by allowed dataset IDs.
|
||||
Used for agent-aware context injection in chat.
|
||||
|
||||
Args:
|
||||
user_id: User UUID string
|
||||
allowed_dataset_ids: List of dataset IDs the agent/user should see
|
||||
|
||||
Returns:
|
||||
List of dataset dictionaries with summaries, filtered by allowed IDs
|
||||
"""
|
||||
if not allowed_dataset_ids:
|
||||
logger.info(f"No allowed dataset IDs provided for user {user_id} - returning empty list")
|
||||
return []
|
||||
|
||||
try:
|
||||
client = await get_postgresql_client()
|
||||
async with client.get_connection() as conn:
|
||||
schema_name = self.settings.postgres_schema
|
||||
|
||||
# Convert dataset IDs to UUID format for query
|
||||
placeholders = ",".join(f"${i+2}::uuid" for i in range(len(allowed_dataset_ids)))
|
||||
|
||||
query = f"""
|
||||
SELECT id, name, description, summary, summary_generated_at,
|
||||
document_count, total_size_bytes
|
||||
FROM {schema_name}.datasets
|
||||
WHERE (created_by = $1::uuid
|
||||
OR access_group IN ('team', 'organization'))
|
||||
AND is_active = true
|
||||
AND id = ANY(ARRAY[{placeholders}])
|
||||
ORDER BY name ASC
|
||||
"""
|
||||
|
||||
params = [user_id] + allowed_dataset_ids
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
filtered_datasets = [dict(row) for row in rows]
|
||||
logger.info(f"Filtered datasets for user {user_id}: {len(filtered_datasets)} out of {len(allowed_dataset_ids)} requested")
|
||||
|
||||
return filtered_datasets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filtered datasets with summaries for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# Factory function for dependency injection
|
||||
def get_summarization_service(tenant_domain: str, user_id: str) -> SummarizationService:
|
||||
"""Factory function to create SummarizationService instance"""
|
||||
return SummarizationService(tenant_domain, user_id)
|
||||
478
apps/tenant-backend/app/services/task_classifier.py
Normal file
478
apps/tenant-backend/app/services/task_classifier.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
GT 2.0 Task Classifier Service
|
||||
|
||||
Analyzes user queries to determine task complexity and required subagent orchestration.
|
||||
Enables highly agentic behavior by intelligently routing tasks to specialized subagents.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskComplexity(str, Enum):
|
||||
"""Task complexity levels"""
|
||||
SIMPLE = "simple" # Direct response, no tools needed
|
||||
TOOL_ASSISTED = "tool_assisted" # Single tool call required
|
||||
MULTI_STEP = "multi_step" # Multiple sequential steps
|
||||
RESEARCH = "research" # Information gathering from multiple sources
|
||||
IMPLEMENTATION = "implementation" # Code/config changes
|
||||
COMPLEX = "complex" # Requires multiple subagents
|
||||
|
||||
|
||||
class SubagentType(str, Enum):
|
||||
"""Types of specialized subagents"""
|
||||
RESEARCH = "research" # Information gathering
|
||||
PLANNING = "planning" # Task decomposition
|
||||
IMPLEMENTATION = "implementation" # Execution
|
||||
VALIDATION = "validation" # Quality checks
|
||||
SYNTHESIS = "synthesis" # Result aggregation
|
||||
MONITOR = "monitor" # Status checking
|
||||
ANALYST = "analyst" # Data analysis
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskClassification:
|
||||
"""Result of task classification"""
|
||||
complexity: TaskComplexity
|
||||
confidence: float
|
||||
primary_intent: str
|
||||
subagent_plan: List[Dict[str, Any]]
|
||||
estimated_tools: List[str]
|
||||
parallel_execution: bool
|
||||
requires_confirmation: bool
|
||||
reasoning: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubagentTask:
|
||||
"""Task definition for a subagent"""
|
||||
subagent_type: SubagentType
|
||||
task_description: str
|
||||
required_tools: List[str]
|
||||
depends_on: List[str] # IDs of other subagent tasks
|
||||
priority: int
|
||||
timeout_seconds: int
|
||||
input_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskClassifier:
|
||||
"""
|
||||
Classifies user tasks and creates subagent execution plans.
|
||||
|
||||
Analyzes query patterns, identifies required capabilities,
|
||||
and orchestrates multi-agent workflows for complex tasks.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Pattern matchers for different task types
|
||||
self.research_patterns = [
|
||||
r"find\s+(?:all\s+)?(?:information|documents?|files?)\s+about",
|
||||
r"search\s+for",
|
||||
r"what\s+(?:is|are|does|do)",
|
||||
r"explain\s+(?:how|what|why)",
|
||||
r"list\s+(?:all\s+)?the",
|
||||
r"show\s+me\s+(?:all\s+)?(?:the\s+)?",
|
||||
r"check\s+(?:the\s+)?(?:recent|latest|current)",
|
||||
]
|
||||
|
||||
self.implementation_patterns = [
|
||||
r"(?:create|add|implement|build|write)\s+(?:a\s+)?(?:new\s+)?",
|
||||
r"(?:update|modify|change|edit|fix)\s+(?:the\s+)?",
|
||||
r"(?:delete|remove|clean\s+up)\s+(?:the\s+)?",
|
||||
r"(?:deploy|install|configure|setup)\s+",
|
||||
r"(?:refactor|optimize|improve)\s+",
|
||||
]
|
||||
|
||||
self.analysis_patterns = [
|
||||
r"analyze\s+(?:the\s+)?",
|
||||
r"compare\s+(?:the\s+)?",
|
||||
r"summarize\s+(?:the\s+)?",
|
||||
r"evaluate\s+(?:the\s+)?",
|
||||
r"review\s+(?:the\s+)?",
|
||||
r"identify\s+(?:patterns|trends|issues)",
|
||||
]
|
||||
|
||||
self.multi_step_indicators = [
|
||||
r"(?:and\s+then|after\s+that|followed\s+by)",
|
||||
r"(?:first|second|third|finally)",
|
||||
r"(?:step\s+\d+|phase\s+\d+)",
|
||||
r"make\s+sure\s+(?:to\s+)?",
|
||||
r"(?:also|additionally|furthermore)",
|
||||
r"for\s+(?:each|every|all)\s+",
|
||||
]
|
||||
|
||||
logger.info("Task classifier initialized")
|
||||
|
||||
async def classify_task(
|
||||
self,
|
||||
query: str,
|
||||
conversation_context: Optional[List[Dict[str, Any]]] = None,
|
||||
available_tools: Optional[List[str]] = None
|
||||
) -> TaskClassification:
|
||||
"""
|
||||
Classify a user query and create execution plan.
|
||||
|
||||
Args:
|
||||
query: User's input query
|
||||
conversation_context: Previous messages for context
|
||||
available_tools: List of available MCP tools
|
||||
|
||||
Returns:
|
||||
TaskClassification with complexity assessment and execution plan
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
|
||||
# Analyze query characteristics
|
||||
is_research = self._matches_patterns(query_lower, self.research_patterns)
|
||||
is_implementation = self._matches_patterns(query_lower, self.implementation_patterns)
|
||||
is_analysis = self._matches_patterns(query_lower, self.analysis_patterns)
|
||||
is_multi_step = self._matches_patterns(query_lower, self.multi_step_indicators)
|
||||
|
||||
# Count potential tool requirements
|
||||
tool_indicators = self._identify_tool_indicators(query_lower)
|
||||
|
||||
# Determine complexity
|
||||
complexity = self._determine_complexity(
|
||||
is_research, is_implementation, is_analysis, is_multi_step, tool_indicators
|
||||
)
|
||||
|
||||
# Create subagent plan based on complexity
|
||||
subagent_plan = await self._create_subagent_plan(
|
||||
query, complexity, is_research, is_implementation, is_analysis, available_tools
|
||||
)
|
||||
|
||||
# Estimate required tools
|
||||
estimated_tools = self._estimate_required_tools(query_lower, available_tools)
|
||||
|
||||
# Determine if parallel execution is possible
|
||||
parallel_execution = self._can_execute_parallel(subagent_plan)
|
||||
|
||||
# Check if confirmation is needed
|
||||
requires_confirmation = complexity in [TaskComplexity.IMPLEMENTATION, TaskComplexity.COMPLEX]
|
||||
|
||||
# Generate reasoning
|
||||
reasoning = self._generate_reasoning(
|
||||
query, complexity, is_research, is_implementation, is_analysis, is_multi_step
|
||||
)
|
||||
|
||||
return TaskClassification(
|
||||
complexity=complexity,
|
||||
confidence=self._calculate_confidence(complexity, subagent_plan),
|
||||
primary_intent=self._identify_primary_intent(is_research, is_implementation, is_analysis),
|
||||
subagent_plan=subagent_plan,
|
||||
estimated_tools=estimated_tools,
|
||||
parallel_execution=parallel_execution,
|
||||
requires_confirmation=requires_confirmation,
|
||||
reasoning=reasoning
|
||||
)
|
||||
|
||||
def _matches_patterns(self, text: str, patterns: List[str]) -> bool:
|
||||
"""Check if text matches any of the patterns"""
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _identify_tool_indicators(self, query: str) -> List[str]:
|
||||
"""Identify potential tool usage from query"""
|
||||
indicators = []
|
||||
|
||||
tool_keywords = {
|
||||
"search": ["search", "find", "look for", "locate"],
|
||||
"database": ["database", "query", "sql", "records"],
|
||||
"file": ["file", "document", "upload", "download"],
|
||||
"api": ["api", "endpoint", "service", "integration"],
|
||||
"conversation": ["conversation", "chat", "history", "previous"],
|
||||
"web": ["website", "url", "browse", "fetch"],
|
||||
}
|
||||
|
||||
for tool_type, keywords in tool_keywords.items():
|
||||
if any(keyword in query for keyword in keywords):
|
||||
indicators.append(tool_type)
|
||||
|
||||
return indicators
|
||||
|
||||
def _determine_complexity(
|
||||
self,
|
||||
is_research: bool,
|
||||
is_implementation: bool,
|
||||
is_analysis: bool,
|
||||
is_multi_step: bool,
|
||||
tool_indicators: List[str]
|
||||
) -> TaskComplexity:
|
||||
"""Determine task complexity based on characteristics"""
|
||||
|
||||
# Count complexity factors
|
||||
factors = sum([is_research, is_implementation, is_analysis, is_multi_step])
|
||||
tool_count = len(tool_indicators)
|
||||
|
||||
if factors == 0 and tool_count == 0:
|
||||
return TaskComplexity.SIMPLE
|
||||
elif factors == 1 and tool_count <= 1:
|
||||
return TaskComplexity.TOOL_ASSISTED
|
||||
elif is_multi_step or factors >= 2:
|
||||
if is_implementation:
|
||||
return TaskComplexity.IMPLEMENTATION
|
||||
elif is_research and (is_analysis or tool_count > 2):
|
||||
return TaskComplexity.RESEARCH
|
||||
else:
|
||||
return TaskComplexity.MULTI_STEP
|
||||
elif factors > 2 or (is_multi_step and is_implementation):
|
||||
return TaskComplexity.COMPLEX
|
||||
else:
|
||||
return TaskComplexity.TOOL_ASSISTED
|
||||
|
||||
async def _create_subagent_plan(
|
||||
self,
|
||||
query: str,
|
||||
complexity: TaskComplexity,
|
||||
is_research: bool,
|
||||
is_implementation: bool,
|
||||
is_analysis: bool,
|
||||
available_tools: Optional[List[str]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Create execution plan with subagents"""
|
||||
plan = []
|
||||
|
||||
if complexity == TaskComplexity.SIMPLE:
|
||||
# No subagents needed
|
||||
return []
|
||||
|
||||
elif complexity == TaskComplexity.TOOL_ASSISTED:
|
||||
# Single subagent for tool execution
|
||||
plan.append({
|
||||
"id": "tool_executor_1",
|
||||
"type": SubagentType.IMPLEMENTATION,
|
||||
"task": f"Execute required tool for: {query[:100]}",
|
||||
"depends_on": [],
|
||||
"priority": 1
|
||||
})
|
||||
|
||||
elif complexity == TaskComplexity.RESEARCH:
|
||||
# Research workflow
|
||||
plan.extend([
|
||||
{
|
||||
"id": "researcher_1",
|
||||
"type": SubagentType.RESEARCH,
|
||||
"task": f"Gather information about: {query[:100]}",
|
||||
"depends_on": [],
|
||||
"priority": 1
|
||||
},
|
||||
{
|
||||
"id": "analyst_1",
|
||||
"type": SubagentType.ANALYST,
|
||||
"task": "Analyze gathered information",
|
||||
"depends_on": ["researcher_1"],
|
||||
"priority": 2
|
||||
},
|
||||
{
|
||||
"id": "synthesizer_1",
|
||||
"type": SubagentType.SYNTHESIS,
|
||||
"task": "Compile findings into comprehensive response",
|
||||
"depends_on": ["analyst_1"],
|
||||
"priority": 3
|
||||
}
|
||||
])
|
||||
|
||||
elif complexity == TaskComplexity.IMPLEMENTATION:
|
||||
# Implementation workflow
|
||||
plan.extend([
|
||||
{
|
||||
"id": "planner_1",
|
||||
"type": SubagentType.PLANNING,
|
||||
"task": f"Create implementation plan for: {query[:100]}",
|
||||
"depends_on": [],
|
||||
"priority": 1
|
||||
},
|
||||
{
|
||||
"id": "implementer_1",
|
||||
"type": SubagentType.IMPLEMENTATION,
|
||||
"task": "Execute implementation steps",
|
||||
"depends_on": ["planner_1"],
|
||||
"priority": 2
|
||||
},
|
||||
{
|
||||
"id": "validator_1",
|
||||
"type": SubagentType.VALIDATION,
|
||||
"task": "Validate implementation results",
|
||||
"depends_on": ["implementer_1"],
|
||||
"priority": 3
|
||||
}
|
||||
])
|
||||
|
||||
elif complexity in [TaskComplexity.MULTI_STEP, TaskComplexity.COMPLEX]:
|
||||
# Complex multi-agent workflow
|
||||
if is_research:
|
||||
plan.append({
|
||||
"id": "researcher_1",
|
||||
"type": SubagentType.RESEARCH,
|
||||
"task": "Research required information",
|
||||
"depends_on": [],
|
||||
"priority": 1
|
||||
})
|
||||
|
||||
plan.append({
|
||||
"id": "planner_1",
|
||||
"type": SubagentType.PLANNING,
|
||||
"task": f"Decompose complex task: {query[:100]}",
|
||||
"depends_on": ["researcher_1"] if is_research else [],
|
||||
"priority": 2
|
||||
})
|
||||
|
||||
if is_implementation:
|
||||
plan.append({
|
||||
"id": "implementer_1",
|
||||
"type": SubagentType.IMPLEMENTATION,
|
||||
"task": "Execute planned steps",
|
||||
"depends_on": ["planner_1"],
|
||||
"priority": 3
|
||||
})
|
||||
|
||||
if is_analysis:
|
||||
plan.append({
|
||||
"id": "analyst_1",
|
||||
"type": SubagentType.ANALYST,
|
||||
"task": "Analyze results and patterns",
|
||||
"depends_on": ["implementer_1"] if is_implementation else ["planner_1"],
|
||||
"priority": 4
|
||||
})
|
||||
|
||||
# Always add synthesis for complex tasks
|
||||
final_deps = []
|
||||
if is_analysis:
|
||||
final_deps.append("analyst_1")
|
||||
elif is_implementation:
|
||||
final_deps.append("implementer_1")
|
||||
else:
|
||||
final_deps.append("planner_1")
|
||||
|
||||
plan.append({
|
||||
"id": "synthesizer_1",
|
||||
"type": SubagentType.SYNTHESIS,
|
||||
"task": "Synthesize all results into final response",
|
||||
"depends_on": final_deps,
|
||||
"priority": 5
|
||||
})
|
||||
|
||||
return plan
|
||||
|
||||
def _estimate_required_tools(
|
||||
self,
|
||||
query: str,
|
||||
available_tools: Optional[List[str]]
|
||||
) -> List[str]:
|
||||
"""Estimate which tools will be needed"""
|
||||
if not available_tools:
|
||||
return []
|
||||
|
||||
estimated = []
|
||||
|
||||
# Map query patterns to tools
|
||||
tool_patterns = {
|
||||
"search_datasets": ["search", "find", "look for", "dataset", "document"],
|
||||
"brave_search": ["web", "internet", "online", "website", "current"],
|
||||
"list_directory": ["files", "directory", "folder", "ls"],
|
||||
"read_file": ["read", "view", "open", "file content"],
|
||||
"write_file": ["write", "create", "save", "generate file"],
|
||||
}
|
||||
|
||||
for tool in available_tools:
|
||||
if tool in tool_patterns:
|
||||
if any(pattern in query for pattern in tool_patterns[tool]):
|
||||
estimated.append(tool)
|
||||
|
||||
return estimated
|
||||
|
||||
def _can_execute_parallel(self, subagent_plan: List[Dict[str, Any]]) -> bool:
|
||||
"""Check if any subagents can run in parallel"""
|
||||
if len(subagent_plan) < 2:
|
||||
return False
|
||||
|
||||
# Group by priority to find parallel opportunities
|
||||
priority_groups = {}
|
||||
for agent in subagent_plan:
|
||||
priority = agent.get("priority", 1)
|
||||
if priority not in priority_groups:
|
||||
priority_groups[priority] = []
|
||||
priority_groups[priority].append(agent)
|
||||
|
||||
# If any priority level has multiple agents, parallel execution is possible
|
||||
return any(len(agents) > 1 for agents in priority_groups.values())
|
||||
|
||||
def _calculate_confidence(
|
||||
self,
|
||||
complexity: TaskComplexity,
|
||||
subagent_plan: List[Dict[str, Any]]
|
||||
) -> float:
|
||||
"""Calculate confidence score for classification"""
|
||||
base_confidence = {
|
||||
TaskComplexity.SIMPLE: 0.95,
|
||||
TaskComplexity.TOOL_ASSISTED: 0.9,
|
||||
TaskComplexity.MULTI_STEP: 0.85,
|
||||
TaskComplexity.RESEARCH: 0.85,
|
||||
TaskComplexity.IMPLEMENTATION: 0.8,
|
||||
TaskComplexity.COMPLEX: 0.75
|
||||
}
|
||||
|
||||
confidence = base_confidence.get(complexity, 0.7)
|
||||
|
||||
# Adjust based on plan clarity
|
||||
if len(subagent_plan) > 0:
|
||||
confidence += 0.05
|
||||
|
||||
return min(confidence, 1.0)
|
||||
|
||||
def _identify_primary_intent(
|
||||
self,
|
||||
is_research: bool,
|
||||
is_implementation: bool,
|
||||
is_analysis: bool
|
||||
) -> str:
|
||||
"""Identify the primary intent of the query"""
|
||||
if is_implementation:
|
||||
return "implementation"
|
||||
elif is_research:
|
||||
return "research"
|
||||
elif is_analysis:
|
||||
return "analysis"
|
||||
else:
|
||||
return "general"
|
||||
|
||||
def _generate_reasoning(
|
||||
self,
|
||||
query: str,
|
||||
complexity: TaskComplexity,
|
||||
is_research: bool,
|
||||
is_implementation: bool,
|
||||
is_analysis: bool,
|
||||
is_multi_step: bool
|
||||
) -> str:
|
||||
"""Generate reasoning explanation for classification"""
|
||||
reasons = []
|
||||
|
||||
if is_multi_step:
|
||||
reasons.append("Query indicates multiple sequential steps")
|
||||
if is_research:
|
||||
reasons.append("Information gathering required")
|
||||
if is_implementation:
|
||||
reasons.append("Code or configuration changes needed")
|
||||
if is_analysis:
|
||||
reasons.append("Data analysis and synthesis required")
|
||||
|
||||
if complexity == TaskComplexity.COMPLEX:
|
||||
reasons.append("Multiple specialized agents needed for comprehensive execution")
|
||||
elif complexity == TaskComplexity.SIMPLE:
|
||||
reasons.append("Straightforward query with direct response possible")
|
||||
|
||||
return ". ".join(reasons) if reasons else "Standard query processing"
|
||||
|
||||
|
||||
# Factory function
|
||||
def get_task_classifier() -> TaskClassifier:
|
||||
"""Get task classifier instance"""
|
||||
return TaskClassifier()
|
||||
410
apps/tenant-backend/app/services/team_access_service.py
Normal file
410
apps/tenant-backend/app/services/team_access_service.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Team Access Control Service for GT 2.0 Tenant Backend
|
||||
|
||||
Implements team-based access control with file-based simplicity.
|
||||
Follows GT 2.0's principle of "Zero Complexity Addition"
|
||||
- Simple role-based permissions stored in files
|
||||
- Fast access checks using SQLite indexes
|
||||
- Perfect tenant isolation maintained
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_
|
||||
import logging
|
||||
|
||||
from app.models.team import Team, TeamRole, OrganizationSettings
|
||||
from app.models.agent import Agent
|
||||
from app.models.document import RAGDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamAccessService:
|
||||
"""Elegant team-based access control following GT 2.0 philosophy"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self._role_cache = {} # Cache role permissions in memory
|
||||
|
||||
async def check_team_access(
|
||||
self,
|
||||
user_email: str,
|
||||
resource: Any,
|
||||
action: str,
|
||||
user_teams: Optional[List[int]] = None
|
||||
) -> bool:
|
||||
"""Check if user has access to perform action on resource
|
||||
|
||||
GT 2.0 Design: Simple, fast access checks without complex hierarchies
|
||||
"""
|
||||
try:
|
||||
# Step 1: Check resource ownership (fastest check)
|
||||
if hasattr(resource, 'created_by') and resource.created_by == user_email:
|
||||
return True # Owners always have full access
|
||||
|
||||
# Step 2: Check visibility-based access
|
||||
if hasattr(resource, 'visibility'):
|
||||
# Organization-wide resources
|
||||
if resource.visibility == "organization":
|
||||
return self._check_organization_action(action)
|
||||
|
||||
# Team resources
|
||||
if resource.visibility == "team" and resource.tenant_id:
|
||||
if not user_teams:
|
||||
user_teams = await self.get_user_teams(user_email)
|
||||
|
||||
if resource.tenant_id in user_teams:
|
||||
return await self._check_team_action(
|
||||
user_email,
|
||||
resource.tenant_id,
|
||||
action
|
||||
)
|
||||
|
||||
# Explicitly shared resources
|
||||
if hasattr(resource, 'shared_with') and resource.shared_with:
|
||||
if user_email in resource.shared_with:
|
||||
return self._check_shared_action(action)
|
||||
|
||||
# Step 3: Default deny for private resources not owned by user
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking team access: {e}")
|
||||
return False # Fail closed on errors
|
||||
|
||||
async def get_user_teams(self, user_email: str) -> List[int]:
|
||||
"""Get all teams the user belongs to
|
||||
|
||||
GT 2.0: Simple file-based membership check
|
||||
"""
|
||||
try:
|
||||
# Query all active teams
|
||||
result = await self.db.execute(
|
||||
select(Team).where(Team.is_active == True)
|
||||
)
|
||||
teams = result.scalars().all()
|
||||
|
||||
user_team_ids = []
|
||||
for team in teams:
|
||||
if team.is_member(user_email):
|
||||
user_team_ids.append(team.id)
|
||||
|
||||
return user_team_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user teams: {e}")
|
||||
return []
|
||||
|
||||
async def get_user_role_in_team(self, user_email: str, team_id: int) -> Optional[str]:
|
||||
"""Get user's role in a specific team"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(Team).where(Team.id == team_id)
|
||||
)
|
||||
team = result.scalar_one_or_none()
|
||||
|
||||
if team:
|
||||
return team.get_member_role(user_email)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role: {e}")
|
||||
return None
|
||||
|
||||
async def get_team_resources(
|
||||
self,
|
||||
team_id: int,
|
||||
resource_type: str,
|
||||
user_email: str
|
||||
) -> List[Any]:
|
||||
"""Get all resources accessible to a team
|
||||
|
||||
GT 2.0: Simple visibility-based filtering
|
||||
"""
|
||||
try:
|
||||
if resource_type == "agent":
|
||||
# Get team and organization agents
|
||||
result = await self.db.execute(
|
||||
select(Agent).where(
|
||||
and_(
|
||||
Agent.is_active == True,
|
||||
or_(
|
||||
and_(
|
||||
Agent.visibility == "team",
|
||||
Agent.tenant_id == team_id
|
||||
),
|
||||
Agent.visibility == "organization"
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
elif resource_type == "dataset":
|
||||
# Get team and organization datasets
|
||||
result = await self.db.execute(
|
||||
select(RAGDataset).where(
|
||||
and_(
|
||||
RAGDataset.status == "active",
|
||||
or_(
|
||||
and_(
|
||||
RAGDataset.visibility == "team",
|
||||
RAGDataset.tenant_id == team_id
|
||||
),
|
||||
RAGDataset.visibility == "organization"
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting team resources: {e}")
|
||||
return []
|
||||
|
||||
async def share_with_team(
|
||||
self,
|
||||
resource: Any,
|
||||
team_id: int,
|
||||
sharer_email: str
|
||||
) -> bool:
|
||||
"""Share a resource with a team
|
||||
|
||||
GT 2.0: Simple visibility update, no complex permissions
|
||||
"""
|
||||
try:
|
||||
# Verify sharer owns the resource or has sharing permission
|
||||
if not self._can_share_resource(resource, sharer_email):
|
||||
return False
|
||||
|
||||
# Update resource visibility
|
||||
resource.visibility = "team"
|
||||
resource.tenant_id = team_id
|
||||
|
||||
await self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sharing with team: {e}")
|
||||
await self.db.rollback()
|
||||
return False
|
||||
|
||||
async def share_with_users(
|
||||
self,
|
||||
resource: Any,
|
||||
user_emails: List[str],
|
||||
sharer_email: str
|
||||
) -> bool:
|
||||
"""Share a resource with specific users
|
||||
|
||||
GT 2.0: Simple list-based sharing
|
||||
"""
|
||||
try:
|
||||
# Verify sharer owns the resource
|
||||
if not self._can_share_resource(resource, sharer_email):
|
||||
return False
|
||||
|
||||
# Update shared_with list
|
||||
current_shared = resource.shared_with or []
|
||||
new_shared = list(set(current_shared + user_emails))
|
||||
resource.shared_with = new_shared
|
||||
|
||||
await self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sharing with users: {e}")
|
||||
await self.db.rollback()
|
||||
return False
|
||||
|
||||
async def create_team(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
team_type: str,
|
||||
creator_email: str
|
||||
) -> Optional[Team]:
|
||||
"""Create a new team
|
||||
|
||||
GT 2.0: File-based team with simple SQLite reference
|
||||
"""
|
||||
try:
|
||||
# Check if user can create teams
|
||||
org_settings = await self._get_organization_settings()
|
||||
if not org_settings.allow_team_creation:
|
||||
logger.warning(f"Team creation disabled for organization")
|
||||
return None
|
||||
|
||||
# Check user's team limit
|
||||
user_teams = await self.get_user_teams(creator_email)
|
||||
if len(user_teams) >= org_settings.max_teams_per_user:
|
||||
logger.warning(f"User {creator_email} reached team limit")
|
||||
return None
|
||||
|
||||
# Create team
|
||||
team = Team(
|
||||
name=name,
|
||||
description=description,
|
||||
team_type=team_type,
|
||||
created_by=creator_email
|
||||
)
|
||||
|
||||
# Initialize with placeholder paths
|
||||
team.config_file_path = "placeholder"
|
||||
team.members_file_path = "placeholder"
|
||||
|
||||
# Save to get ID
|
||||
self.db.add(team)
|
||||
await self.db.flush()
|
||||
|
||||
# Initialize proper file paths
|
||||
team.initialize_file_paths()
|
||||
|
||||
# Add creator as owner
|
||||
team.add_member(creator_email, "owner", {"joined_as": "creator"})
|
||||
|
||||
# Save initial config
|
||||
config = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"team_type": team_type,
|
||||
"created_by": creator_email,
|
||||
"settings": {}
|
||||
}
|
||||
team.save_config_to_file(config)
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(team)
|
||||
|
||||
logger.info(f"Created team {team.id} by {creator_email}")
|
||||
return team
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating team: {e}")
|
||||
await self.db.rollback()
|
||||
return None
|
||||
|
||||
# Private helper methods
|
||||
|
||||
def _check_organization_action(self, action: str) -> bool:
|
||||
"""Check if action is allowed for organization resources"""
|
||||
# Organization resources are viewable by all
|
||||
if action in ["view", "use", "read"]:
|
||||
return True
|
||||
# Only owners can modify
|
||||
return False
|
||||
|
||||
async def _check_team_action(
|
||||
self,
|
||||
user_email: str,
|
||||
team_id: int,
|
||||
action: str
|
||||
) -> bool:
|
||||
"""Check if user can perform action on team resource"""
|
||||
role = await self.get_user_role_in_team(user_email, team_id)
|
||||
if not role:
|
||||
return False
|
||||
|
||||
# Get role permissions
|
||||
permissions = await self._get_role_permissions(role)
|
||||
|
||||
# Map action to permission
|
||||
action_permission_map = {
|
||||
"view": "can_view_resources",
|
||||
"read": "can_view_resources",
|
||||
"use": "can_view_resources",
|
||||
"create": "can_create_resources",
|
||||
"edit": "can_edit_team_resources",
|
||||
"update": "can_edit_team_resources",
|
||||
"delete": "can_delete_team_resources",
|
||||
"manage_members": "can_manage_members",
|
||||
"manage_team": "can_manage_team",
|
||||
}
|
||||
|
||||
permission_needed = action_permission_map.get(action, None)
|
||||
if permission_needed:
|
||||
return permissions.get(permission_needed, False)
|
||||
|
||||
return False
|
||||
|
||||
def _check_shared_action(self, action: str) -> bool:
|
||||
"""Check if action is allowed for shared resources"""
|
||||
# Shared resources can be viewed and used
|
||||
if action in ["view", "use", "read"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _can_share_resource(self, resource: Any, user_email: str) -> bool:
|
||||
"""Check if user can share a resource"""
|
||||
# Owners can always share
|
||||
if hasattr(resource, 'created_by') and resource.created_by == user_email:
|
||||
return True
|
||||
|
||||
# Team leads can share team resources
|
||||
# (Would need to check team role here in full implementation)
|
||||
|
||||
return False
|
||||
|
||||
async def _get_role_permissions(self, role_name: str) -> Dict[str, bool]:
|
||||
"""Get permissions for a role (with caching)"""
|
||||
if role_name in self._role_cache:
|
||||
return self._role_cache[role_name]
|
||||
|
||||
result = await self.db.execute(
|
||||
select(TeamRole).where(TeamRole.name == role_name)
|
||||
)
|
||||
role = result.scalar_one_or_none()
|
||||
|
||||
if role:
|
||||
permissions = {
|
||||
"can_view_resources": role.can_view_resources,
|
||||
"can_create_resources": role.can_create_resources,
|
||||
"can_edit_team_resources": role.can_edit_team_resources,
|
||||
"can_delete_team_resources": role.can_delete_team_resources,
|
||||
"can_manage_members": role.can_manage_members,
|
||||
"can_manage_team": role.can_manage_team,
|
||||
}
|
||||
self._role_cache[role_name] = permissions
|
||||
return permissions
|
||||
|
||||
# Default to viewer permissions
|
||||
return {
|
||||
"can_view_resources": True,
|
||||
"can_create_resources": False,
|
||||
"can_edit_team_resources": False,
|
||||
"can_delete_team_resources": False,
|
||||
"can_manage_members": False,
|
||||
"can_manage_team": False,
|
||||
}
|
||||
|
||||
async def _get_organization_settings(self) -> OrganizationSettings:
|
||||
"""Get organization settings (create default if not exists)"""
|
||||
result = await self.db.execute(
|
||||
select(OrganizationSettings).limit(1)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
# Create default settings
|
||||
settings = OrganizationSettings(
|
||||
organization_name="Default Organization",
|
||||
organization_domain="example.com"
|
||||
)
|
||||
settings.config_file_path = "placeholder"
|
||||
self.db.add(settings)
|
||||
await self.db.flush()
|
||||
|
||||
settings.initialize_file_paths()
|
||||
settings.save_config_to_file({
|
||||
"initialized": True,
|
||||
"default_config": True
|
||||
})
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(settings)
|
||||
|
||||
return settings
|
||||
2361
apps/tenant-backend/app/services/team_service.py
Normal file
2361
apps/tenant-backend/app/services/team_service.py
Normal file
File diff suppressed because it is too large
Load Diff
233
apps/tenant-backend/app/services/user_service.py
Normal file
233
apps/tenant-backend/app/services/user_service.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
GT 2.0 User Service - User Preferences Management
|
||||
|
||||
Manages user preferences including favorite agents using PostgreSQL + PGVector backend.
|
||||
Perfect tenant isolation - each tenant has separate user data.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from app.core.config import get_settings
|
||||
from app.core.postgresql_client import get_postgresql_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserService:
|
||||
"""GT 2.0 PostgreSQL User Service with Perfect Tenant Isolation"""
|
||||
|
||||
def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
|
||||
"""Initialize with tenant and user isolation using PostgreSQL storage"""
|
||||
self.tenant_domain = tenant_domain
|
||||
self.user_id = user_id
|
||||
self.user_email = user_email or user_id
|
||||
self.settings = get_settings()
|
||||
|
||||
logger.info(f"User service initialized for {tenant_domain}/{user_id} (email: {self.user_email})")
|
||||
|
||||
async def _get_user_id(self, pg_client) -> Optional[str]:
|
||||
"""Get user ID from email or user_id with fallback"""
|
||||
user_lookup_query = """
|
||||
SELECT id FROM users
|
||||
WHERE (email = $1 OR id::text = $1 OR username = $1)
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
user_id = await pg_client.fetch_scalar(user_lookup_query, self.user_email, self.tenant_domain)
|
||||
if not user_id:
|
||||
# If not found by email, try by user_id
|
||||
user_id = await pg_client.fetch_scalar(user_lookup_query, self.user_id, self.tenant_domain)
|
||||
|
||||
return user_id
|
||||
|
||||
async def get_user_preferences(self) -> Dict[str, Any]:
|
||||
"""Get user preferences from PostgreSQL"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
user_id = await self._get_user_id(pg_client)
|
||||
|
||||
if not user_id:
|
||||
logger.warning(f"User not found: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}")
|
||||
return {}
|
||||
|
||||
query = """
|
||||
SELECT preferences
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
|
||||
"""
|
||||
|
||||
result = await pg_client.fetch_one(query, user_id, self.tenant_domain)
|
||||
|
||||
if result and result["preferences"]:
|
||||
prefs = result["preferences"]
|
||||
# Handle both dict and JSON string
|
||||
if isinstance(prefs, str):
|
||||
return json.loads(prefs)
|
||||
return prefs
|
||||
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user preferences: {e}")
|
||||
return {}
|
||||
|
||||
async def update_user_preferences(self, preferences: Dict[str, Any]) -> bool:
|
||||
"""Update user preferences in PostgreSQL (merges with existing)"""
|
||||
try:
|
||||
pg_client = await get_postgresql_client()
|
||||
user_id = await self._get_user_id(pg_client)
|
||||
|
||||
if not user_id:
|
||||
logger.warning(f"User not found: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}")
|
||||
return False
|
||||
|
||||
# Merge with existing preferences using PostgreSQL JSONB || operator
|
||||
query = """
|
||||
UPDATE users
|
||||
SET preferences = COALESCE(preferences, '{}'::jsonb) || $1::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2
|
||||
AND tenant_id = (SELECT id FROM tenants WHERE domain = $3 LIMIT 1)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
updated_id = await pg_client.fetch_scalar(
|
||||
query,
|
||||
json.dumps(preferences),
|
||||
user_id,
|
||||
self.tenant_domain
|
||||
)
|
||||
|
||||
if updated_id:
|
||||
logger.info(f"Updated preferences for user {user_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user preferences: {e}")
|
||||
return False
|
||||
|
||||
async def get_favorite_agent_ids(self) -> List[str]:
|
||||
"""Get user's favorited agent IDs"""
|
||||
try:
|
||||
preferences = await self.get_user_preferences()
|
||||
favorite_ids = preferences.get("favorite_agent_ids", [])
|
||||
|
||||
# Ensure it's a list
|
||||
if not isinstance(favorite_ids, list):
|
||||
return []
|
||||
|
||||
logger.info(f"Retrieved {len(favorite_ids)} favorite agent IDs for user {self.user_id}")
|
||||
return favorite_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting favorite agent IDs: {e}")
|
||||
return []
|
||||
|
||||
async def update_favorite_agent_ids(self, agent_ids: List[str]) -> bool:
|
||||
"""Update user's favorited agent IDs"""
|
||||
try:
|
||||
# Validate agent_ids is a list
|
||||
if not isinstance(agent_ids, list):
|
||||
logger.error(f"Invalid agent_ids type: {type(agent_ids)}")
|
||||
return False
|
||||
|
||||
# Update preferences with new favorite_agent_ids
|
||||
success = await self.update_user_preferences({
|
||||
"favorite_agent_ids": agent_ids
|
||||
})
|
||||
|
||||
if success:
|
||||
logger.info(f"Updated {len(agent_ids)} favorite agent IDs for user {self.user_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating favorite agent IDs: {e}")
|
||||
return False
|
||||
|
||||
async def add_favorite_agent(self, agent_id: str) -> bool:
|
||||
"""Add a single agent to favorites"""
|
||||
try:
|
||||
current_favorites = await self.get_favorite_agent_ids()
|
||||
|
||||
if agent_id not in current_favorites:
|
||||
current_favorites.append(agent_id)
|
||||
return await self.update_favorite_agent_ids(current_favorites)
|
||||
|
||||
# Already in favorites
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding favorite agent: {e}")
|
||||
return False
|
||||
|
||||
async def remove_favorite_agent(self, agent_id: str) -> bool:
|
||||
"""Remove a single agent from favorites"""
|
||||
try:
|
||||
current_favorites = await self.get_favorite_agent_ids()
|
||||
|
||||
if agent_id in current_favorites:
|
||||
current_favorites.remove(agent_id)
|
||||
return await self.update_favorite_agent_ids(current_favorites)
|
||||
|
||||
# Not in favorites
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing favorite agent: {e}")
|
||||
return False
|
||||
|
||||
async def get_custom_categories(self) -> List[Dict[str, Any]]:
|
||||
"""Get user's custom agent categories"""
|
||||
try:
|
||||
preferences = await self.get_user_preferences()
|
||||
custom_categories = preferences.get("custom_categories", [])
|
||||
|
||||
# Ensure it's a list
|
||||
if not isinstance(custom_categories, list):
|
||||
return []
|
||||
|
||||
logger.info(f"Retrieved {len(custom_categories)} custom categories for user {self.user_id}")
|
||||
return custom_categories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting custom categories: {e}")
|
||||
return []
|
||||
|
||||
async def update_custom_categories(self, categories: List[Dict[str, Any]]) -> bool:
|
||||
"""Update user's custom agent categories (replaces entire list)"""
|
||||
try:
|
||||
# Validate categories is a list
|
||||
if not isinstance(categories, list):
|
||||
logger.error(f"Invalid categories type: {type(categories)}")
|
||||
return False
|
||||
|
||||
# Convert Pydantic models to dicts if needed
|
||||
category_dicts = []
|
||||
for cat in categories:
|
||||
if hasattr(cat, 'dict'):
|
||||
category_dicts.append(cat.dict())
|
||||
elif isinstance(cat, dict):
|
||||
category_dicts.append(cat)
|
||||
else:
|
||||
logger.error(f"Invalid category type: {type(cat)}")
|
||||
return False
|
||||
|
||||
# Update preferences with new custom_categories
|
||||
success = await self.update_user_preferences({
|
||||
"custom_categories": category_dicts
|
||||
})
|
||||
|
||||
if success:
|
||||
logger.info(f"Updated {len(category_dicts)} custom categories for user {self.user_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating custom categories: {e}")
|
||||
return False
|
||||
448
apps/tenant-backend/app/services/vector_store.py
Normal file
448
apps/tenant-backend/app/services/vector_store.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
Vector Store Service for Tenant Backend
|
||||
|
||||
Manages tenant-specific ChromaDB instances with encryption.
|
||||
All vectors are stored locally in the tenant's encrypted database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2
|
||||
import base64
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorSearchResult:
|
||||
"""Result from vector search"""
|
||||
document_id: str
|
||||
text: str
|
||||
score: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class TenantEncryption:
|
||||
"""Encryption handler for tenant data"""
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
"""Initialize encryption for tenant"""
|
||||
# Derive encryption key from tenant-specific secret
|
||||
tenant_key = f"{settings.SECRET_KEY}:{tenant_id}"
|
||||
kdf = PBKDF2(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=tenant_id.encode(),
|
||||
iterations=100000,
|
||||
)
|
||||
key = base64.urlsafe_b64encode(kdf.derive(tenant_key.encode()))
|
||||
self.cipher = Fernet(key)
|
||||
|
||||
def encrypt(self, data: str) -> bytes:
|
||||
"""Encrypt string data"""
|
||||
return self.cipher.encrypt(data.encode())
|
||||
|
||||
def decrypt(self, encrypted_data: bytes) -> str:
|
||||
"""Decrypt data to string"""
|
||||
return self.cipher.decrypt(encrypted_data).decode()
|
||||
|
||||
def encrypt_vector(self, vector: List[float]) -> bytes:
|
||||
"""Encrypt vector data"""
|
||||
vector_str = json.dumps(vector)
|
||||
return self.encrypt(vector_str)
|
||||
|
||||
def decrypt_vector(self, encrypted_vector: bytes) -> List[float]:
|
||||
"""Decrypt vector data"""
|
||||
vector_str = self.decrypt(encrypted_vector)
|
||||
return json.loads(vector_str)
|
||||
|
||||
|
||||
class VectorStoreService:
|
||||
"""
|
||||
Manages tenant-specific vector storage with ChromaDB.
|
||||
|
||||
Security principles:
|
||||
- All vectors stored in tenant-specific directory
|
||||
- Encryption at rest for all data
|
||||
- User-scoped collections
|
||||
- No cross-tenant access
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, tenant_domain: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.tenant_domain = tenant_domain
|
||||
|
||||
# Initialize encryption
|
||||
self.encryption = TenantEncryption(tenant_id)
|
||||
|
||||
# Initialize ChromaDB client based on configuration mode
|
||||
if settings.chromadb_mode == "http":
|
||||
# Use HTTP client for per-tenant ChromaDB server
|
||||
self.client = chromadb.HttpClient(
|
||||
host=settings.chromadb_host,
|
||||
port=settings.chromadb_port,
|
||||
settings=Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=False
|
||||
)
|
||||
)
|
||||
logger.info(f"Vector store initialized for tenant {tenant_domain} using HTTP mode at {settings.chromadb_host}:{settings.chromadb_port}")
|
||||
else:
|
||||
# Use file-based client (fallback)
|
||||
self.storage_path = settings.chromadb_path
|
||||
os.makedirs(self.storage_path, exist_ok=True, mode=0o700)
|
||||
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=self.storage_path,
|
||||
settings=Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=False,
|
||||
is_persistent=True
|
||||
)
|
||||
)
|
||||
logger.info(f"Vector store initialized for tenant {tenant_domain} using file mode at {self.storage_path}")
|
||||
|
||||
async def create_user_collection(
|
||||
self,
|
||||
user_id: str,
|
||||
collection_name: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a user-scoped collection.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
collection_name: Name of the collection
|
||||
metadata: Optional collection metadata
|
||||
|
||||
Returns:
|
||||
Collection ID
|
||||
"""
|
||||
# Generate unique collection name for user
|
||||
collection_id = f"{user_id}_{collection_name}"
|
||||
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
|
||||
internal_name = f"col_{collection_hash}"
|
||||
|
||||
try:
|
||||
# Create or get collection
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=internal_name,
|
||||
metadata={
|
||||
"user_id": user_id,
|
||||
"collection_name": collection_name,
|
||||
"tenant_id": self.tenant_id,
|
||||
**(metadata or {})
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Created collection {collection_name} for user {user_id}")
|
||||
return collection_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating collection: {e}")
|
||||
raise
|
||||
|
||||
async def store_vectors(
|
||||
self,
|
||||
user_id: str,
|
||||
collection_name: str,
|
||||
documents: List[str],
|
||||
embeddings: List[List[float]],
|
||||
ids: Optional[List[str]] = None,
|
||||
metadata: Optional[List[Dict[str, Any]]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Store vectors in user collection with encryption.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
collection_name: Collection name
|
||||
documents: List of document texts
|
||||
embeddings: List of embedding vectors
|
||||
ids: Optional document IDs
|
||||
metadata: Optional document metadata
|
||||
|
||||
Returns:
|
||||
Success status
|
||||
"""
|
||||
try:
|
||||
# Get collection
|
||||
collection_id = f"{user_id}_{collection_name}"
|
||||
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
|
||||
internal_name = f"col_{collection_hash}"
|
||||
|
||||
collection = self.client.get_collection(name=internal_name)
|
||||
|
||||
# Generate IDs if not provided
|
||||
if ids is None:
|
||||
ids = [
|
||||
hashlib.sha256(f"{doc}:{i}".encode()).hexdigest()[:16]
|
||||
for i, doc in enumerate(documents)
|
||||
]
|
||||
|
||||
# Encrypt documents before storage
|
||||
encrypted_docs = [
|
||||
self.encryption.encrypt(doc).decode('latin-1')
|
||||
for doc in documents
|
||||
]
|
||||
|
||||
# Prepare metadata with encryption status
|
||||
final_metadata = []
|
||||
for i, doc_meta in enumerate(metadata or [{}] * len(documents)):
|
||||
meta = {
|
||||
**doc_meta,
|
||||
"encrypted": True,
|
||||
"user_id": user_id,
|
||||
"doc_hash": hashlib.sha256(documents[i].encode()).hexdigest()[:16]
|
||||
}
|
||||
final_metadata.append(meta)
|
||||
|
||||
# Add to collection
|
||||
collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=encrypted_docs,
|
||||
metadatas=final_metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Stored {len(documents)} vectors in collection {collection_name} "
|
||||
f"for user {user_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing vectors: {e}")
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self,
|
||||
user_id: str,
|
||||
collection_name: str,
|
||||
query_embedding: List[float],
|
||||
top_k: int = 5,
|
||||
filter_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> List[VectorSearchResult]:
|
||||
"""
|
||||
Search for similar vectors in user collection.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
collection_name: Collection name
|
||||
query_embedding: Query vector
|
||||
top_k: Number of results to return
|
||||
filter_metadata: Optional metadata filters
|
||||
|
||||
Returns:
|
||||
List of search results with decrypted content
|
||||
"""
|
||||
try:
|
||||
# Get collection
|
||||
collection_id = f"{user_id}_{collection_name}"
|
||||
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
|
            internal_name = f"col_{collection_hash}"

            collection = self.client.get_collection(name=internal_name)

            # Prepare filter
            where_filter = {"user_id": user_id}
            if filter_metadata:
                where_filter.update(filter_metadata)

            # Query collection
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k,
                where=where_filter
            )

            # Process results
            search_results = []
            if results and results['ids'] and len(results['ids'][0]) > 0:
                for i in range(len(results['ids'][0])):
                    # Decrypt document text
                    encrypted_doc = results['documents'][0][i].encode('latin-1')
                    decrypted_doc = self.encryption.decrypt(encrypted_doc)

                    search_results.append(VectorSearchResult(
                        document_id=results['ids'][0][i],
                        text=decrypted_doc,
                        score=1.0 - results['distances'][0][i],  # Convert distance to similarity
                        metadata=results['metadatas'][0][i] if results['metadatas'] else {}
                    ))

            logger.info(
                f"Found {len(search_results)} results in collection {collection_name} "
                f"for user {user_id}"
            )
            return search_results

        except Exception as e:
            logger.error(f"Error searching vectors: {e}")
            raise

    async def delete_collection(
        self,
        user_id: str,
        collection_name: str
    ) -> bool:
        """
        Delete a user collection.

        Args:
            user_id: User identifier
            collection_name: Collection name

        Returns:
            Success status
        """
        try:
            collection_id = f"{user_id}_{collection_name}"
            collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
            internal_name = f"col_{collection_hash}"

            self.client.delete_collection(name=internal_name)

            logger.info(f"Deleted collection {collection_name} for user {user_id}")
            return True

        except Exception as e:
            logger.error(f"Error deleting collection: {e}")
            raise

    async def list_user_collections(
        self,
        user_id: str
    ) -> List[Dict[str, Any]]:
        """
        List all collections for a user.

        Args:
            user_id: User identifier

        Returns:
            List of collection information
        """
        try:
            all_collections = self.client.list_collections()
            user_collections = []

            for collection in all_collections:
                metadata = collection.metadata
                if metadata and metadata.get("user_id") == user_id:
                    user_collections.append({
                        "name": metadata.get("collection_name"),
                        "created_at": metadata.get("created_at"),
                        "document_count": collection.count(),
                        "metadata": metadata
                    })

            return user_collections

        except Exception as e:
            logger.error(f"Error listing collections: {e}")
            raise

    async def get_collection_stats(
        self,
        user_id: str,
        collection_name: str
    ) -> Dict[str, Any]:
        """
        Get statistics for a user collection.

        Args:
            user_id: User identifier
            collection_name: Collection name

        Returns:
            Collection statistics
        """
        try:
            collection_id = f"{user_id}_{collection_name}"
            collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
            internal_name = f"col_{collection_hash}"

            collection = self.client.get_collection(name=internal_name)

            stats = {
                "document_count": collection.count(),
                "collection_name": collection_name,
                "user_id": user_id,
                "metadata": collection.metadata,
                "storage_mode": settings.chromadb_mode,
                "storage_path": getattr(self, 'storage_path', f"{settings.chromadb_host}:{settings.chromadb_port}")
            }

            return stats

        except Exception as e:
            logger.error(f"Error getting collection stats: {e}")
            raise

    async def update_document(
        self,
        user_id: str,
        collection_name: str,
        document_id: str,
        new_text: Optional[str] = None,
        new_embedding: Optional[List[float]] = None,
        new_metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update a document in the collection.

        Args:
            user_id: User identifier
            collection_name: Collection name
            document_id: Document ID to update
            new_text: Optional new text
            new_embedding: Optional new embedding
            new_metadata: Optional new metadata

        Returns:
            Success status
        """
        try:
            collection_id = f"{user_id}_{collection_name}"
            collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
            internal_name = f"col_{collection_hash}"

            collection = self.client.get_collection(name=internal_name)

            update_params = {"ids": [document_id]}

            if new_text:
                encrypted_text = self.encryption.encrypt(new_text).decode('latin-1')
                update_params["documents"] = [encrypted_text]

            if new_embedding:
                update_params["embeddings"] = [new_embedding]

            if new_metadata:
                update_params["metadatas"] = [{
                    **new_metadata,
                    "encrypted": True,
                    "user_id": user_id
                }]

            collection.update(**update_params)

            logger.info(f"Updated document {document_id} in collection {collection_name}")
            return True

        except Exception as e:
            logger.error(f"Error updating document: {e}")
            raise
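For reference, a minimal usage sketch of the collection helpers above from an async caller. It only calls methods shown in this file with their shown signatures; the `store` variable, user id, collection name, and document id are illustrative placeholders, not part of this diff.

# Hypothetical usage sketch; `store` is an already-constructed instance of the vector
# store service above, and all literal ids/names are placeholders.
import asyncio

async def demo(store):
    # Fetch per-user collection statistics
    stats = await store.get_collection_stats(user_id="user-123", collection_name="notes")
    print(stats["document_count"])

    # Update a stored document's text and metadata (text is re-encrypted internally)
    await store.update_document(
        user_id="user-123",
        collection_name="notes",
        document_id="doc-1",
        new_text="updated contents",
        new_metadata={"source": "manual-edit"},
    )

    # Drop the collection when it is no longer needed
    await store.delete_collection(user_id="user-123", collection_name="notes")

# asyncio.run(demo(store)) would drive the sketch once `store` exists.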
416
apps/tenant-backend/app/services/websocket_service.py
Normal file
@@ -0,0 +1,416 @@
"""
WebSocket Integration Service for GT 2.0 Tenant Backend

Bridges the conversation service with WebSocket real-time streaming.
Provides AI response streaming and event-driven message broadcasting.

GT 2.0 Architecture Principles:
- Zero downtime: Non-blocking AI response streaming
- Perfect tenant isolation: All streaming scoped to tenant
- Self-contained: No external dependencies
- Event-driven: Integrates with event automation system
"""

import asyncio
import logging
import json
from typing import AsyncGenerator, Dict, Any, Optional
from datetime import datetime
import uuid

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db_session
from app.services.conversation_service import ConversationService
from app.services.event_service import EventService, EventType
from app.services.agent_service import AgentService
from app.websocket import get_websocket_manager, ChatMessage

logger = logging.getLogger(__name__)


class WebSocketService:
    """
    Service for WebSocket-integrated AI conversation streaming.

    Combines conversation service with real-time WebSocket delivery.
    """

    def __init__(self, db: AsyncSession):
        self.db = db
        self.conversation_service = ConversationService(db)
        self.event_service = EventService(db)
        # Agent service will be initialized per request with tenant context
        self.websocket_manager = get_websocket_manager()

    async def stream_ai_response(
        self,
        conversation_id: str,
        user_message: str,
        user_id: str,
        tenant_id: str,
        connection_id: str
    ) -> None:
        """
        Stream AI response to WebSocket connection with real-time updates.

        Args:
            conversation_id: Target conversation
            user_message: User's message content
            user_id: User identifier
            tenant_id: Tenant for isolation
            connection_id: WebSocket connection to stream to
        """
        try:
            # Send streaming start notification
            await self.websocket_manager.send_to_connection(connection_id, {
                "type": "ai_response_start",
                "conversation_id": conversation_id,
                "timestamp": datetime.utcnow().isoformat()
            })

            # Add user message to conversation
            await self.conversation_service.add_message(
                conversation_id=conversation_id,
                role="user",
                content=user_message,
                user_id=user_id
            )

            # Emit message sent event
            await self.event_service.emit_event(
                event_type=EventType.MESSAGE_SENT,
                user_id=user_id,
                tenant_id=tenant_id,
                data={
                    "conversation_id": conversation_id,
                    "message_type": "user",
                    "content_length": len(user_message),
                    "streaming": True
                }
            )

            # Get conversation context for AI response
            conversation = await self.conversation_service.get_conversation(
                conversation_id=conversation_id,
                user_id=user_id,
                include_messages=True
            )

            if not conversation:
                raise ValueError("Conversation not found")

            # Stream AI response
            full_response = ""
            message_id = str(uuid.uuid4())

            # Get AI response generator
            async for chunk in self._generate_ai_response_stream(conversation, user_message):
                full_response += chunk

                # Send chunk to WebSocket
                await self.websocket_manager.send_to_connection(connection_id, {
                    "type": "ai_response_chunk",
                    "conversation_id": conversation_id,
                    "message_id": message_id,
                    "content": chunk,
                    "full_content_so_far": full_response,
                    "timestamp": datetime.utcnow().isoformat()
                })

                # Broadcast to other conversation participants
                await self.websocket_manager.broadcast_to_conversation(
                    conversation_id,
                    {
                        "type": "ai_typing",
                        "conversation_id": conversation_id,
                        "content_preview": full_response[-50:] if len(full_response) > 50 else full_response
                    },
                    exclude_connection=connection_id
                )

            # Save complete AI response
            await self.conversation_service.add_message(
                conversation_id=conversation_id,
                role="agent",
                content=full_response,
                user_id=user_id,
                message_id=message_id
            )

            # Send completion notification
            await self.websocket_manager.send_to_connection(connection_id, {
                "type": "ai_response_complete",
                "conversation_id": conversation_id,
                "message_id": message_id,
                "full_content": full_response,
                "timestamp": datetime.utcnow().isoformat()
            })

            # Broadcast completion to conversation participants
            await self.websocket_manager.broadcast_to_conversation(
                conversation_id,
                {
                    "type": "new_ai_message",
                    "conversation_id": conversation_id,
                    "message_id": message_id,
                    "content": full_response,
                    "timestamp": datetime.utcnow().isoformat()
                },
                exclude_connection=connection_id
            )

            # Emit AI response event
            await self.event_service.emit_event(
                event_type=EventType.MESSAGE_SENT,
                user_id=user_id,
                tenant_id=tenant_id,
                data={
                    "conversation_id": conversation_id,
                    "message_id": message_id,
                    "message_type": "agent",
                    "content_length": len(full_response),
                    "streaming_completed": True
                }
            )

            logger.info(f"AI response streaming completed for conversation {conversation_id}")

        except Exception as e:
            logger.error(f"Error streaming AI response: {e}")

            # Send error notification
            await self.websocket_manager.send_to_connection(connection_id, {
                "type": "ai_response_error",
                "conversation_id": conversation_id,
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat()
            })

            raise

    async def _generate_ai_response_stream(
        self,
        conversation: Dict[str, Any],
        user_message: str
    ) -> AsyncGenerator[str, None]:
        """
        Generate AI response stream chunks.

        This is a placeholder implementation. In production, this would
        integrate with the actual LLM service for streaming responses.

        Args:
            conversation: Conversation context
            user_message: User's message

        Yields:
            Response text chunks
        """
        # Get agent configuration
        agent_id = conversation.get("agent_id")
        if agent_id:
            # Initialize AgentService with the tenant/user context carried on the conversation record
            agent_service = AgentService(conversation.get("tenant_id"), conversation.get("user_id"))
            agent_config = await agent_service.get_agent(agent_id)
            assistant_config = agent_config if agent_config else {}
        else:
            assistant_config = {}

        # Build conversation context
        messages = conversation.get("messages", [])
        context = []

        # Add system prompt if available
        system_prompt = assistant_config.get("prompt", "You are a helpful AI agent.")
        context.append({"role": "system", "content": system_prompt})

        # Add recent conversation history (last 10 messages)
        for msg in messages[-10:]:
            context.append({
                "role": msg["role"],
                "content": msg["content"]
            })

        # Add current user message
        context.append({"role": "user", "content": user_message})

        # Simulate AI response streaming (replace with real LLM integration)
        # This demonstrates the streaming pattern that would be used with actual AI services
        response_text = await self._generate_mock_response(user_message, context)

        # Stream response in chunks
        chunk_size = 5  # Characters per chunk for demo
        for i in range(0, len(response_text), chunk_size):
            chunk = response_text[i:i + chunk_size]

            # Add realistic delay for streaming effect
            await asyncio.sleep(0.05)

            yield chunk

    async def _generate_mock_response(
        self,
        user_message: str,
        context: list
    ) -> str:
        """
        Generate mock AI response for development.

        In production, this would be replaced with actual LLM API calls.

        Args:
            user_message: User's message
            context: Conversation context

        Returns:
            Generated response text
        """
        # Simple mock response based on user input
        if "hello" in user_message.lower():
            return "Hello! I'm your AI agent. How can I help you today?"
        elif "help" in user_message.lower():
            return "I'm here to help! You can ask me questions, request information, or have a conversation. What would you like to know?"
        elif "weather" in user_message.lower():
            return "I don't have access to real-time weather data, but I'd be happy to help you find weather information or discuss weather patterns in general."
        elif "time" in user_message.lower():
            return f"The current time is approximately {datetime.utcnow().strftime('%H:%M UTC')}. Is there anything else I can help you with?"
        else:
            return f"Thank you for your message: '{user_message}'. I understand you're looking for assistance. Could you provide more details about what you'd like help with?"

    async def handle_typing_indicator(
        self,
        conversation_id: str,
        user_id: str,
        is_typing: bool,
        connection_id: str
    ) -> None:
        """
        Handle and broadcast typing indicators.

        Args:
            conversation_id: Target conversation
            user_id: User who is typing
            is_typing: Whether user is currently typing
            connection_id: Connection that sent the indicator
        """
        try:
            # Broadcast typing indicator to other conversation participants
            await self.websocket_manager.broadcast_to_conversation(
                conversation_id,
                {
                    "type": "user_typing",
                    "conversation_id": conversation_id,
                    "user_id": user_id,
                    "is_typing": is_typing,
                    "timestamp": datetime.utcnow().isoformat()
                },
                exclude_connection=connection_id
            )

            logger.debug(f"Typing indicator broadcast: {user_id} {'started' if is_typing else 'stopped'} typing")

        except Exception as e:
            logger.error(f"Error handling typing indicator: {e}")

    async def send_system_notification(
        self,
        user_id: str,
        tenant_id: str,
        notification: Dict[str, Any]
    ) -> None:
        """
        Send system notification to all user connections.

        Args:
            user_id: Target user
            tenant_id: Tenant for isolation
            notification: Notification data
        """
        try:
            message = {
                "type": "system_notification",
                "notification": notification,
                "timestamp": datetime.utcnow().isoformat()
            }

            await self.websocket_manager.send_to_user(
                user_id=user_id,
                tenant_id=tenant_id,
                message=message
            )

            logger.info(f"System notification sent to user {user_id}")

        except Exception as e:
            logger.error(f"Error sending system notification: {e}")

    async def broadcast_conversation_update(
        self,
        conversation_id: str,
        update_type: str,
        update_data: Dict[str, Any]
    ) -> None:
        """
        Broadcast conversation update to all participants.

        Args:
            conversation_id: Target conversation
            update_type: Type of update (title_changed, participant_added, etc.)
            update_data: Update details
        """
        try:
            message = {
                "type": "conversation_update",
                "conversation_id": conversation_id,
                "update_type": update_type,
                "update_data": update_data,
                "timestamp": datetime.utcnow().isoformat()
            }

            await self.websocket_manager.broadcast_to_conversation(
                conversation_id,
                message
            )

            logger.info(f"Conversation update broadcast: {update_type} for {conversation_id}")

        except Exception as e:
            logger.error(f"Error broadcasting conversation update: {e}")

    async def get_connection_stats(self, tenant_id: str) -> Dict[str, Any]:
        """
        Get WebSocket connection statistics for tenant.

        Args:
            tenant_id: Tenant to get stats for

        Returns:
            Connection statistics
        """
        try:
            all_stats = self.websocket_manager.get_connection_stats()

            return {
                "tenant_connections": all_stats["connections_by_tenant"].get(tenant_id, 0),
                "active_conversations": len([
                    conv_id for conv_id, connections in self.websocket_manager.conversation_connections.items()
                    if any(
                        getattr(self.websocket_manager.connections.get(conn_id), "tenant_id", None) == tenant_id
                        for conn_id in connections
                    )
                ])
            }

        except Exception as e:
            logger.error(f"Error getting connection stats: {e}")
            return {"tenant_connections": 0, "active_conversations": 0}


# Factory function for dependency injection
async def get_websocket_service(db: Optional[AsyncSession] = None) -> WebSocketService:
    """Get WebSocket service instance"""
    if db is None:
        async with get_db_session() as session:
            return WebSocketService(session)
    return WebSocketService(db)
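A minimal sketch of how the `get_websocket_service` factory and `stream_ai_response` above might be wired into a FastAPI WebSocket route. The route path, the incoming JSON payload shape, and the way `connection_id` is derived are assumptions for illustration; the real endpoint registration and connection bookkeeping live in the application's WebSocket layer, not in this diff.

# Hypothetical wiring sketch; path, payload shape, and connection_id handling are assumptions.
from fastapi import APIRouter, WebSocket

router = APIRouter()

@router.websocket("/ws/conversations/{conversation_id}")
async def conversation_socket(websocket: WebSocket, conversation_id: str):
    await websocket.accept()
    service = await get_websocket_service()      # factory defined above
    connection_id = str(id(websocket))           # placeholder; real id comes from the manager

    while True:
        payload = await websocket.receive_json() # expected keys: content, user_id, tenant_id
        await service.stream_ai_response(
            conversation_id=conversation_id,
            user_message=payload["content"],
            user_id=payload["user_id"],
            tenant_id=payload["tenant_id"],
            connection_id=connection_id,
        )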
995
apps/tenant-backend/app/services/workflow_service.py
Normal file
@@ -0,0 +1,995 @@
import asyncio
import uuid
import json
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Union
from sqlalchemy.orm import Session
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload

from app.models.workflow import (
    Workflow, WorkflowExecution, WorkflowTrigger, WorkflowSession, WorkflowMessage,
    WorkflowStatus, TriggerType, InteractionMode,
    WORKFLOW_NODE_TYPES, INTERACTION_MODE_CONFIGS
)
from app.models.agent import Agent
from app.services.resource_service import ResourceService


class WorkflowValidationError(Exception):
    """Raised when workflow validation fails"""
    pass


class WorkflowService:
    """Service for managing workflows with Agents as AI node definitions"""

    def __init__(self, db: Session):
        # Synchronous Session; the async methods below only await non-database work
        self.db = db
        self.resource_service = ResourceService()

    def create_workflow(
        self,
        user_id: str,
        tenant_id: str,
        workflow_data: Dict[str, Any]
    ) -> Workflow:
        """Create a new workflow with validation"""

        # Validate workflow definition
        self._validate_workflow_definition(
            workflow_data.get('definition', {}),
            user_id,
            tenant_id
        )

        # Create workflow
        workflow = Workflow(
            id=str(uuid.uuid4()),
            tenant_id=tenant_id,
            user_id=user_id,
            name=workflow_data['name'],
            description=workflow_data.get('description', ''),
            definition=workflow_data['definition'],
            triggers=workflow_data.get('triggers', []),
            interaction_modes=workflow_data.get('interaction_modes', ['button']),
            agent_ids=self._extract_agent_ids(workflow_data['definition']),
            api_key_ids=workflow_data.get('api_key_ids', []),
            webhook_ids=workflow_data.get('webhook_ids', []),
            dataset_ids=workflow_data.get('dataset_ids', []),
            integration_ids=workflow_data.get('integration_ids', []),
            config=workflow_data.get('config', {}),
            timeout_seconds=workflow_data.get('timeout_seconds', 300),
            max_retries=workflow_data.get('max_retries', 3)
        )

        # Use sync database operations
        self.db.add(workflow)
        self.db.commit()
        self.db.refresh(workflow)

        # Create triggers if specified
        for trigger_config in workflow_data.get('triggers', []):
            self.create_workflow_trigger(
                workflow.id,
                user_id,
                tenant_id,
                trigger_config
            )

        return workflow

    def get_workflow(self, workflow_id: str, user_id: str) -> Optional[Workflow]:
        """Get a workflow by ID with user ownership validation"""
        stmt = select(Workflow).where(
            Workflow.id == workflow_id,
            Workflow.user_id == user_id
        )
        result = self.db.execute(stmt)
        return result.scalar_one_or_none()

    def list_user_workflows(
        self,
        user_id: str,
        tenant_id: str,
        status: Optional[WorkflowStatus] = None
    ) -> List[Workflow]:
        """List all workflows for a user"""
        stmt = select(Workflow).where(
            Workflow.user_id == user_id,
            Workflow.tenant_id == tenant_id
        )

        if status:
            stmt = stmt.where(Workflow.status == status)

        stmt = stmt.order_by(Workflow.updated_at.desc())
        result = self.db.execute(stmt)
        return result.scalars().all()

    def update_workflow(
        self,
        workflow_id: str,
        user_id: str,
        updates: Dict[str, Any]
    ) -> Workflow:
        """Update a workflow with validation"""

        # Get existing workflow
        workflow = self.get_workflow(workflow_id, user_id)
        if not workflow:
            raise ValueError("Workflow not found or access denied")

        # Validate definition if updated
        if 'definition' in updates:
            self._validate_workflow_definition(
                updates['definition'],
                user_id,
                workflow.tenant_id
            )
            updates['agent_ids'] = self._extract_agent_ids(updates['definition'])

        # Update workflow
        stmt = update(Workflow).where(
            Workflow.id == workflow_id,
            Workflow.user_id == user_id
        ).values(**updates)

        self.db.execute(stmt)
        self.db.commit()

        # Return updated workflow
        return self.get_workflow(workflow_id, user_id)

    def delete_workflow(self, workflow_id: str, user_id: str) -> bool:
        """Delete a workflow and its related data"""

        # Verify ownership
        workflow = self.get_workflow(workflow_id, user_id)
        if not workflow:
            return False

        # Delete related records
        self._cleanup_workflow_data(workflow_id)

        # Delete workflow
        stmt = delete(Workflow).where(
            Workflow.id == workflow_id,
            Workflow.user_id == user_id
        )
        result = self.db.execute(stmt)
        self.db.commit()

        return result.rowcount > 0

    async def execute_workflow(
        self,
        workflow_id: str,
        user_id: str,
        input_data: Dict[str, Any],
        trigger_type: str = "manual",
        trigger_source: Optional[str] = None,
        interaction_mode: str = "api"
    ) -> WorkflowExecution:
        """Execute a workflow with specified input"""

        # Get and validate workflow (get_workflow is synchronous)
        workflow = self.get_workflow(workflow_id, user_id)
        if not workflow:
            raise ValueError("Workflow not found or access denied")

        if workflow.status not in [WorkflowStatus.ACTIVE, WorkflowStatus.DRAFT]:
            raise ValueError(f"Cannot execute workflow with status: {workflow.status}")

        # Create execution record
        execution = WorkflowExecution(
            id=str(uuid.uuid4()),
            workflow_id=workflow_id,
            user_id=user_id,
            tenant_id=workflow.tenant_id,
            status="pending",
            input_data=input_data,
            trigger_type=trigger_type,
            trigger_source=trigger_source,
            interaction_mode=interaction_mode,
            started_at=datetime.utcnow(),  # set explicitly so duration_ms can be computed below
            session_id=str(uuid.uuid4()) if interaction_mode == "chat" else None
        )

        self.db.add(execution)
        self.db.commit()
        self.db.refresh(execution)

        # Execute workflow asynchronously (in real implementation, this would be a background task)
        try:
            execution_result = await self._execute_workflow_nodes(workflow, execution, input_data)

            # Update execution with results
            execution.status = "completed"
            execution.output_data = execution_result.get('output', {})
            execution.completed_at = datetime.utcnow()
            execution.duration_ms = int((execution.completed_at - execution.started_at).total_seconds() * 1000)
            execution.progress_percentage = 100

            # Update workflow statistics
            workflow.execution_count += 1
            workflow.last_executed = datetime.utcnow()
            workflow.total_tokens_used += execution_result.get('tokens_used', 0)
            workflow.total_cost_cents += execution_result.get('cost_cents', 0)

        except Exception as e:
            # Mark execution as failed
            execution.status = "failed"
            execution.error_details = str(e)
            execution.completed_at = datetime.utcnow()
            execution.duration_ms = int((execution.completed_at - execution.started_at).total_seconds() * 1000)

        self.db.commit()
        return execution

    async def get_execution_status(
        self,
        execution_id: str,
        user_id: str
    ) -> Optional[WorkflowExecution]:
        """Get execution status with user validation"""
        stmt = select(WorkflowExecution).where(
            WorkflowExecution.id == execution_id,
            WorkflowExecution.user_id == user_id
        )
        result = self.db.execute(stmt)
        return result.scalar_one_or_none()

    def create_workflow_trigger(
        self,
        workflow_id: str,
        user_id: str,
        tenant_id: str,
        trigger_config: Dict[str, Any]
    ) -> WorkflowTrigger:
        """Create a trigger for a workflow"""

        trigger = WorkflowTrigger(
            id=str(uuid.uuid4()),
            workflow_id=workflow_id,
            user_id=user_id,
            tenant_id=tenant_id,
            trigger_type=trigger_config['type'],
            trigger_config=trigger_config
        )

        # Configure trigger-specific settings
        if trigger_config['type'] == 'webhook':
            trigger.webhook_url = f"https://api.gt2.com/webhooks/{trigger.id}"
            trigger.webhook_secret = str(uuid.uuid4())
        elif trigger_config['type'] == 'cron':
            trigger.cron_schedule = trigger_config.get('schedule', '0 0 * * *')
            trigger.timezone = trigger_config.get('timezone', 'UTC')
        elif trigger_config['type'] == 'event':
            trigger.event_source = trigger_config.get('source', '')
            trigger.event_filters = trigger_config.get('filters', {})

        self.db.add(trigger)
        self.db.commit()
        self.db.refresh(trigger)

        return trigger

    async def create_chat_session(
        self,
        workflow_id: str,
        user_id: str,
        tenant_id: str,
        session_config: Optional[Dict[str, Any]] = None
    ) -> WorkflowSession:
        """Create a chat session for workflow interaction"""

        session = WorkflowSession(
            id=str(uuid.uuid4()),
            workflow_id=workflow_id,
            user_id=user_id,
            tenant_id=tenant_id,
            session_type="chat",
            session_state=session_config or {},
            expires_at=datetime.utcnow() + timedelta(hours=24)  # 24 hour session
        )

        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)

        return session

    async def add_chat_message(
        self,
        session_id: str,
        user_id: str,
        role: str,
        content: str,
        agent_id: Optional[str] = None,
        confidence_score: Optional[int] = None,
        execution_id: Optional[str] = None
    ) -> WorkflowMessage:
        """Add a message to a chat session"""

        # Get session and validate
        stmt = select(WorkflowSession).where(
            WorkflowSession.id == session_id,
            WorkflowSession.user_id == user_id,
            WorkflowSession.is_active == True
        )
        session = self.db.execute(stmt)
        session = session.scalar_one_or_none()

        if not session:
            raise ValueError("Chat session not found or expired")

        message = WorkflowMessage(
            id=str(uuid.uuid4()),
            session_id=session_id,
            workflow_id=session.workflow_id,
            execution_id=execution_id,
            user_id=user_id,
            tenant_id=session.tenant_id,
            role=role,
            content=content,
            agent_id=agent_id,
            confidence_score=confidence_score
        )

        self.db.add(message)

        # Update session
        session.message_count += 1
        session.last_message_at = datetime.utcnow()

        self.db.commit()
        self.db.refresh(message)

        return message

    def _validate_workflow_definition(
        self,
        definition: Dict[str, Any],
        user_id: str,
        tenant_id: str
    ):
        """Validate workflow definition and resource access"""

        nodes = definition.get('nodes', [])
        edges = definition.get('edges', [])

        # Validate nodes
        for node in nodes:
            node_type = node.get('type')
            if node_type not in WORKFLOW_NODE_TYPES:
                raise WorkflowValidationError(f"Invalid node type: {node_type}")

            # Validate Agent nodes
            if node_type == 'agent':
                agent_id = node.get('data', {}).get('agent_id')
                if not agent_id:
                    raise WorkflowValidationError("Agent node missing agent_id")

                # Verify user owns the agent
                agent = self._get_user_agent(agent_id, user_id, tenant_id)
                if not agent:
                    raise WorkflowValidationError(f"Agent {agent_id} not found or access denied")

            # Validate Integration nodes
            elif node_type == 'integration':
                api_key_id = node.get('data', {}).get('api_key_id')
                if api_key_id:
                    # In real implementation, validate API key ownership
                    pass

        # Validate edges (connections between nodes)
        node_ids = {node['id'] for node in nodes}
        for edge in edges:
            source = edge.get('source')
            target = edge.get('target')

            if source not in node_ids or target not in node_ids:
                raise WorkflowValidationError("Invalid edge connection")

        # Ensure workflow has at least one trigger node
        trigger_nodes = [n for n in nodes if n.get('type') == 'trigger']
        if not trigger_nodes:
            raise WorkflowValidationError("Workflow must have at least one trigger node")

    def _get_user_agent(
        self,
        agent_id: str,
        user_id: str,
        tenant_id: str
    ) -> Optional[Agent]:
        """Get agent with ownership validation"""
        stmt = select(Agent).where(
            Agent.id == agent_id,
            Agent.user_id == user_id,
            Agent.tenant_id == tenant_id
        )
        result = self.db.execute(stmt)
        return result.scalar_one_or_none()

    # Backward compatibility method
    def _get_user_assistant(
        self,
        agent_id: str,
        user_id: str,
        tenant_id: str
    ) -> Optional[Agent]:
        """Backward compatibility wrapper for _get_user_agent"""
        return self._get_user_agent(agent_id, user_id, tenant_id)

    def _extract_agent_ids(self, definition: Dict[str, Any]) -> List[str]:
        """Extract agent IDs from workflow definition"""
        agent_ids = []

        for node in definition.get('nodes', []):
            if node.get('type') == 'agent':
                agent_id = node.get('data', {}).get('agent_id')
                if agent_id:
                    agent_ids.append(agent_id)

        return agent_ids

    async def _execute_workflow_nodes(
        self,
        workflow: Workflow,
        execution: WorkflowExecution,
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute workflow nodes in order"""

        # Update execution status
        execution.status = "running"
        execution.progress_percentage = 10
        self.db.commit()

        # Parse workflow definition to create execution graph
        definition = workflow.definition
        nodes = definition.get('nodes', [])
        edges = definition.get('edges', [])

        if not nodes:
            raise ValueError("Workflow has no nodes to execute")

        # Find trigger node to start execution
        trigger_nodes = [n for n in nodes if n.get('type') == 'trigger']
        if not trigger_nodes:
            raise ValueError("Workflow has no trigger nodes")

        execution_trace = []
        total_tokens = 0
        total_cost = 0
        current_data = input_data

        # Execute nodes in simple sequential order (real implementation would use topological sort)
        for node in nodes:
            node_id = node.get('id')
            node_type = node.get('type')

            try:
                if node_type == 'trigger':
                    # Trigger nodes just pass through input data
                    node_result = {
                        'output': current_data,
                        'tokens_used': 0,
                        'cost_cents': 0
                    }

                elif node_type == 'agent':
                    # Execute Agent node via resource cluster
                    node_result = await self._execute_agent_node_real(node, current_data, execution.user_id, execution.tenant_id)

                elif node_type == 'integration':
                    # Execute integration node (simulated - no external connections)
                    node_result = await self._execute_integration_node_simulated(node, current_data)

                elif node_type == 'logic':
                    # Execute logic node (real logic operations)
                    node_result = await self._execute_logic_node_simulated(node, current_data)

                elif node_type == 'output':
                    # Execute output node (simulated - no external deliveries)
                    node_result = await self._execute_output_node_simulated(node, current_data)

                else:
                    raise ValueError(f"Unknown node type: {node_type}")

                # Update execution state
                current_data = node_result.get('output', current_data)
                total_tokens += node_result.get('tokens_used', 0)
                total_cost += node_result.get('cost_cents', 0)

                execution_trace.append({
                    'node_id': node_id,
                    'node_type': node_type,
                    'status': 'completed',
                    'timestamp': datetime.utcnow().isoformat(),
                    'tokens_used': node_result.get('tokens_used', 0),
                    'cost_cents': node_result.get('cost_cents', 0),
                    'execution_time_ms': node_result.get('execution_time_ms', 0)
                })

            except Exception as e:
                # Record failed node execution
                execution_trace.append({
                    'node_id': node_id,
                    'node_type': node_type,
                    'status': 'failed',
                    'timestamp': datetime.utcnow().isoformat(),
                    'error': str(e)
                })
                raise ValueError(f"Node {node_id} execution failed: {str(e)}")

        return {
            'output': current_data,
            'tokens_used': total_tokens,
            'cost_cents': total_cost,
            'execution_trace': execution_trace
        }

    async def _execute_agent_node_real(
        self,
        node: Dict[str, Any],
        input_data: Dict[str, Any],
        user_id: str,
        tenant_id: str
    ) -> Dict[str, Any]:
        """Execute an Agent node with real Agent integration"""

        agent_id = node.get('data', {}).get('agent_id')
        if not agent_id:
            raise ValueError("Agent node missing agent_id")

        # Get Agent configuration (_get_user_agent is synchronous)
        agent = self._get_user_agent(agent_id, user_id, tenant_id)
        if not agent:
            raise ValueError(f"Agent {agent_id} not found")

        # Prepare input text from workflow data
        input_text = input_data.get('message', '') or str(input_data)

        # Use the existing conversation service for real execution
        from app.services.conversation_service import ConversationService

        try:
            conversation_service = ConversationService(self.db)

            # Create or get conversation for this workflow execution
            conversation_id = f"workflow-{agent_id}-{datetime.utcnow().isoformat()}"

            # Execute agent with real conversation service
            response = await conversation_service.send_message(
                conversation_id=conversation_id,
                user_id=user_id,
                tenant_id=tenant_id,
                content=input_text,
                agent_id=agent_id
            )

            # Parse response to extract metrics
            tokens_used = response.get('tokens_used', 100)  # Default estimate
            cost_cents = max(1, tokens_used // 50)  # Rough cost estimation

            return {
                'output': response.get('content', 'Agent response'),
                'confidence': node.get('data', {}).get('confidence_threshold', 75),
                'tokens_used': tokens_used,
                'cost_cents': cost_cents,
                'execution_time_ms': response.get('response_time_ms', 1000)
            }

        except Exception as e:
            # If conversation service fails, use basic text processing
            return {
                'output': f"Agent {agent.name} processed: {input_text[:100]}{'...' if len(input_text) > 100 else ''}",
                'confidence': 50,  # Lower confidence for fallback
                'tokens_used': len(input_text.split()) * 2,  # Rough token estimate
                'cost_cents': 2,
                'execution_time_ms': 500
            }

    async def _execute_integration_node_simulated(
        self,
        node: Dict[str, Any],
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute an Integration node with simulated responses (no external connections)"""

        integration_type = node.get('data', {}).get('integration_type', 'api')
        integration_name = node.get('data', {}).get('name', 'Unknown Integration')

        # Simulate processing time based on integration type
        processing_times = {
            'api': 200,       # API calls: 200ms
            'webhook': 150,   # Webhook: 150ms
            'database': 300,  # Database: 300ms
            'email': 500,     # Email: 500ms
            'storage': 250    # Storage: 250ms
        }

        processing_time = processing_times.get(integration_type, 200)
        await asyncio.sleep(processing_time / 1000)  # Convert to seconds

        # Generate realistic simulated responses based on integration type
        simulated_responses = {
            'api': {
                'status': 'success',
                'data': {
                    'response_code': 200,
                    'message': f'API call to {integration_name} completed successfully',
                    'processed_items': len(str(input_data)) // 10,
                    'timestamp': datetime.utcnow().isoformat()
                }
            },
            'webhook': {
                'status': 'delivered',
                'webhook_id': f'wh_{uuid.uuid4().hex[:8]}',
                'delivery_time_ms': processing_time,
                'response_code': 200
            },
            'database': {
                'status': 'executed',
                'affected_rows': 1,
                'query_time_ms': processing_time,
                'result_count': 1
            },
            'email': {
                'status': 'sent',
                'message_id': f'msg_{uuid.uuid4().hex[:12]}',
                'recipients': 1,
                'delivery_status': 'queued'
            },
            'storage': {
                'status': 'uploaded',
                'file_size_bytes': len(str(input_data)),
                'storage_path': f'/simulated/path/{uuid.uuid4().hex[:8]}.json',
                'etag': f'etag_{uuid.uuid4().hex[:16]}'
            }
        }

        response_data = simulated_responses.get(integration_type, simulated_responses['api'])

        return {
            'output': response_data,
            'simulated': True,  # Mark as simulated
            'integration_type': integration_type,
            'integration_name': integration_name,
            'tokens_used': 0,  # Integrations don't use AI tokens
            'cost_cents': 1,  # Minimal cost for simulation
            'execution_time_ms': processing_time,
            'log_message': f'Integration {integration_name} simulated - external connections not implemented'
        }

    async def _execute_logic_node_simulated(
        self,
        node: Dict[str, Any],
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute a Logic node with real logic operations"""

        logic_type = node.get('data', {}).get('logic_type', 'transform')
        logic_config = node.get('data', {}).get('logic_config', {})

        await asyncio.sleep(0.05)  # Small processing delay

        if logic_type == 'condition':
            # Simple condition evaluation
            field = logic_config.get('field', 'message')
            operator = logic_config.get('operator', 'contains')
            value = logic_config.get('value', '')

            input_value = str(input_data.get(field, ''))

            if operator == 'contains':
                result = value.lower() in input_value.lower()
            elif operator == 'equals':
                result = input_value.lower() == value.lower()
            elif operator == 'length_gt':
                result = len(input_value) > int(value)
            else:
                result = True  # Default to true

            return {
                'output': {
                    **input_data,
                    'condition_result': result,
                    'condition_evaluated': f'{field} {operator} {value}'
                },
                'tokens_used': 0,
                'cost_cents': 0,
                'execution_time_ms': 50
            }

        elif logic_type == 'transform':
            # Data transformation
            transform_rules = logic_config.get('rules', [])
            transformed_data = dict(input_data)

            # Apply simple transformations
            for rule in transform_rules:
                source_field = rule.get('source', '')
                target_field = rule.get('target', source_field)
                operation = rule.get('operation', 'copy')

                if source_field in input_data:
                    value = input_data[source_field]

                    if operation == 'uppercase':
                        transformed_data[target_field] = str(value).upper()
                    elif operation == 'lowercase':
                        transformed_data[target_field] = str(value).lower()
                    elif operation == 'length':
                        transformed_data[target_field] = len(str(value))
                    else:  # copy
                        transformed_data[target_field] = value

            return {
                'output': transformed_data,
                'tokens_used': 0,
                'cost_cents': 0,
                'execution_time_ms': 50
            }

        elif logic_type == 'aggregate':
            # Simple aggregation
            aggregate_field = logic_config.get('field', 'items')
            operation = logic_config.get('operation', 'count')

            items = input_data.get(aggregate_field, [])
            if not isinstance(items, list):
                items = [items]

            if operation == 'count':
                result = len(items)
            elif operation == 'sum' and all(isinstance(x, (int, float)) for x in items):
                result = sum(items)
            elif operation == 'average' and all(isinstance(x, (int, float)) for x in items):
                result = sum(items) / len(items) if items else 0
            else:
                result = len(items)

            return {
                'output': {
                    **input_data,
                    f'{operation}_result': result,
                    f'{operation}_field': aggregate_field
                },
                'tokens_used': 0,
                'cost_cents': 0,
                'execution_time_ms': 50
            }

        else:
            # Default passthrough
            return {
                'output': input_data,
                'tokens_used': 0,
                'cost_cents': 0,
                'execution_time_ms': 50
            }

    async def _execute_output_node_simulated(
        self,
        node: Dict[str, Any],
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute an Output node with simulated delivery (no external sends)"""

        output_type = node.get('data', {}).get('output_type', 'webhook')
        output_config = node.get('data', {}).get('output_config', {})
        output_name = node.get('data', {}).get('name', 'Unknown Output')

        # Simulate delivery time based on output type
        delivery_times = {
            'webhook': 300,      # Webhook delivery: 300ms
            'email': 800,        # Email sending: 800ms
            'api': 250,          # API call: 250ms
            'storage': 400,      # File storage: 400ms
            'notification': 200  # Push notification: 200ms
        }

        delivery_time = delivery_times.get(output_type, 300)
        await asyncio.sleep(delivery_time / 1000)

        # Generate realistic simulated delivery responses
        simulated_deliveries = {
            'webhook': {
                'status': 'delivered',
                'webhook_url': output_config.get('url', 'https://api.example.com/webhook'),
                'response_code': 200,
                'delivery_id': f'wh_delivery_{uuid.uuid4().hex[:8]}',
                'payload_size_bytes': len(str(input_data))
            },
            'email': {
                'status': 'queued',
                'recipients': output_config.get('recipients', ['user@example.com']),
                'subject': output_config.get('subject', 'Workflow Output'),
                'message_id': f'email_{uuid.uuid4().hex[:12]}',
                'provider': 'simulated_smtp'
            },
            'api': {
                'status': 'sent',
                'endpoint': output_config.get('endpoint', '/api/results'),
                'method': output_config.get('method', 'POST'),
                'response_code': 201,
                'request_id': f'api_{uuid.uuid4().hex[:8]}'
            },
            'storage': {
                'status': 'stored',
                'storage_path': f'/outputs/{uuid.uuid4().hex[:8]}.json',
                'file_size_bytes': len(str(input_data)),
                'checksum': f'sha256_{uuid.uuid4().hex[:16]}'
            },
            'notification': {
                'status': 'pushed',
                'devices': output_config.get('devices', 1),
                'message': output_config.get('message', 'Workflow completed'),
                'notification_id': f'notif_{uuid.uuid4().hex[:8]}'
            }
        }

        delivery_data = simulated_deliveries.get(output_type, simulated_deliveries['webhook'])

        return {
            'output': {
                **input_data,
                'delivery_result': delivery_data,
                'output_type': output_type,
                'output_name': output_name
            },
            'simulated': True,
            'tokens_used': 0,
            'cost_cents': 1,  # Minimal cost for output
            'execution_time_ms': delivery_time,
            'log_message': f'Output {output_name} simulated - external delivery not implemented'
        }

    async def _execute_logic_node_real(
        self,
        node: Dict[str, Any],
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute a Logic node with actual data processing"""

        logic_type = node.get('data', {}).get('logic_type')

        if logic_type == 'decision':
            condition = node.get('data', {}).get('config', {}).get('condition', 'true')

            # Simple condition evaluation (in production would use safe expression evaluator)
            try:
                # Basic condition evaluation for common cases
                if 'input.value' in condition:
                    input_value = input_data.get('value', 0)
                    condition_result = eval(condition.replace('input.value', str(input_value)))
                else:
                    condition_result = True  # Default to true for undefined conditions

                return {
                    'output': {
                        'condition_result': condition_result,
                        'original_data': input_data,
                        'branch': 'true' if condition_result else 'false'
                    },
                    'tokens_used': 0,
                    'cost_cents': 0,
                    'execution_time_ms': 50
                }
            except Exception:
                # Fallback to pass-through if condition evaluation fails
                return {
                    'output': input_data,
                    'tokens_used': 0,
                    'cost_cents': 0,
                    'execution_time_ms': 50
                }

        elif logic_type == 'transform':
            # Simple data transformation
            return {
                'output': {
                    'transformed_data': input_data,
                    'transformation_type': 'basic',
                    'timestamp': datetime.utcnow().isoformat()
                },
                'tokens_used': 0,
                'cost_cents': 0,
                'execution_time_ms': 25
            }

        # Default: pass through data
        return {
            'output': input_data,
            'tokens_used': 0,
            'cost_cents': 0,
            'execution_time_ms': 25
        }

    async def _execute_output_node_real(
        self,
        node: Dict[str, Any],
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute an Output node with actual output delivery"""

        output_type = node.get('data', {}).get('output_type')

        if output_type == 'webhook':
            webhook_url = node.get('data', {}).get('config', {}).get('url')

            if webhook_url:
                import httpx
                try:
                    async with httpx.AsyncClient(timeout=30.0) as client:
                        response = await client.post(webhook_url, json=input_data)

                    return {
                        'output': {
                            'webhook_sent': True,
                            'status_code': response.status_code,
                            'response': response.text[:500]  # Limit response size
                        },
                        'tokens_used': 0,
                        'cost_cents': 0,
                        'execution_time_ms': 200
                    }
                except Exception as e:
                    return {
                        'output': {
                            'webhook_sent': False,
                            'error': str(e)
                        },
                        'tokens_used': 0,
                        'cost_cents': 0,
                        'execution_time_ms': 100
                    }

        # For other output types, simulate delivery
        return {
            'output': {
                'output_type': output_type,
                'delivered': True,
                'data_sent': input_data
            },
            'tokens_used': 0,
            'cost_cents': 0,
            'execution_time_ms': 50
        }

    def _cleanup_workflow_data(self, workflow_id: str):
        """Clean up all data related to a workflow"""

        # Delete triggers
        stmt = delete(WorkflowTrigger).where(WorkflowTrigger.workflow_id == workflow_id)
        self.db.execute(stmt)

        # Delete sessions and messages
        stmt = delete(WorkflowMessage).where(WorkflowMessage.workflow_id == workflow_id)
        self.db.execute(stmt)

        stmt = delete(WorkflowSession).where(WorkflowSession.workflow_id == workflow_id)
        self.db.execute(stmt)

        # Delete executions
        stmt = delete(WorkflowExecution).where(WorkflowExecution.workflow_id == workflow_id)
        self.db.execute(stmt)

        self.db.commit()


# WorkflowExecutor functionality integrated directly into WorkflowService
# for better cohesion and to avoid mock implementations
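To make the nodes/edges structure that _validate_workflow_definition expects concrete, here is a minimal sketch of creating and running a two-node workflow. It only calls WorkflowService methods shown above; the database session, the agent id (which is assumed to exist and be owned by the user), and all literal values are placeholders.

# Hypothetical usage sketch; db_session, "agent-123", and all literals are placeholders.
# The definition shape (trigger node, agent node, edge) mirrors what the validator checks.
import asyncio

async def run_example(db_session):
    service = WorkflowService(db_session)

    definition = {
        "nodes": [
            {"id": "n1", "type": "trigger", "data": {}},
            {"id": "n2", "type": "agent", "data": {"agent_id": "agent-123"}},
        ],
        "edges": [{"source": "n1", "target": "n2"}],
    }

    # Validation runs inside create_workflow (node types, edge endpoints, agent ownership)
    workflow = service.create_workflow(
        user_id="user-123",
        tenant_id="tenant-abc",
        workflow_data={"name": "Demo workflow", "definition": definition},
    )

    # Nodes execute sequentially; the returned record carries status, trace, and cost fields
    execution = await service.execute_workflow(
        workflow_id=workflow.id,
        user_id="user-123",
        input_data={"message": "Summarize today's tickets"},
    )
    print(execution.status, execution.duration_ms)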