GT AI OS Community Edition v2.0.33

Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking (see the sketch after this list)
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives
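
The SSRF and hostname-validation items above follow a common pattern: resolve the target hostname and reject private, loopback, or otherwise internal addresses, and compare hostnames exactly rather than with substring matching. A minimal sketch of that pattern, not the project's actual implementation (the allow-list and function names here are illustrative):

```python
import ipaddress
import socket
from urllib.parse import urlparse

ALLOWED_HOSTS = {"api.example.com"}  # illustrative allow-list, not from this codebase

def resolves_to_internal_address(host: str) -> bool:
    """Resolve the host via DNS and flag private/loopback/link-local/reserved addresses."""
    try:
        infos = socket.getaddrinfo(host, None)
    except socket.gaierror:
        return True  # treat unresolvable hosts as unsafe
    for _family, _type, _proto, _canon, sockaddr in infos:
        ip = ipaddress.ip_address(sockaddr[0].split("%")[0])  # drop IPv6 zone id if present
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            return True
    return False

def is_url_allowed(url: str) -> bool:
    """Exact hostname comparison (no substring matching) plus DNS-based SSRF check."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        return False
    if parsed.hostname not in ALLOWED_HOSTS:  # exact match, not `"example.com" in url`
        return False
    return not resolves_to_internal_address(parsed.hostname)
```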

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
HackWeasel · 2025-12-12 17:04:45 -05:00 · commit b9dfb86260
746 changed files with 232071 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
"""
GT 2.0 Tenant Backend Services
Business logic and orchestration services for tenant applications.
"""

View File

@@ -0,0 +1,451 @@
"""
Access Controller Service for GT 2.0
Manages resource access control with capability-based security.
Ensures perfect tenant isolation and proper permission cascading.
"""
import os
import stat
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
import logging
from pathlib import Path
from app.models.access_group import (
AccessGroup, TenantStructure, User, Resource,
ResourceCreate, ResourceUpdate, ResourceResponse
)
from app.core.security import verify_capability_token
from app.core.database import get_db_session
logger = logging.getLogger(__name__)
class AccessController:
"""
Centralized access control service
Manages permissions for all resources with tenant isolation
"""
def __init__(self, tenant_domain: str):
self.tenant_domain = tenant_domain
self.base_path = Path(f"/data/{tenant_domain}")
self._ensure_tenant_directory()
def _ensure_tenant_directory(self):
"""
Ensure tenant directory exists with proper permissions
OS User: gt-{tenant_domain}-{pod_id}
Permissions: 700 (owner only)
"""
if not self.base_path.exists():
self.base_path.mkdir(parents=True, exist_ok=True)
# Set strict permissions - owner only
os.chmod(self.base_path, stat.S_IRWXU) # 700
logger.info(f"Created tenant directory: {self.base_path} with 700 permissions")
async def check_permission(
self,
user_id: str,
resource: Resource,
action: str = "read"
) -> Tuple[bool, Optional[str]]:
"""
Check if user has permission for action on resource
Args:
user_id: User requesting access
resource: Resource being accessed
action: read, write, delete, share
Returns:
Tuple of (allowed, reason)
"""
# Verify tenant isolation
if resource.tenant_domain != self.tenant_domain:
logger.warning(f"Cross-tenant access attempt: {user_id} -> {resource.id}")
return False, "Cross-tenant access denied"
# Owner has all permissions
if resource.owner_id == user_id:
return True, "Owner access granted"
# Check action-specific permissions
if action == "read":
return self._check_read_permission(user_id, resource)
elif action == "write":
return self._check_write_permission(user_id, resource)
elif action == "delete":
return False, "Only owner can delete"
elif action == "share":
return False, "Only owner can share"
else:
return False, f"Unknown action: {action}"
def _check_read_permission(self, user_id: str, resource: Resource) -> Tuple[bool, str]:
"""Check read permission based on access group"""
if resource.access_group == AccessGroup.ORGANIZATION:
return True, "Organization-wide read access"
elif resource.access_group == AccessGroup.TEAM:
if user_id in resource.team_members:
return True, "Team member read access"
return False, "Not a team member"
else: # INDIVIDUAL
return False, "Private resource"
def _check_write_permission(self, user_id: str, resource: Resource) -> Tuple[bool, str]:
"""Check write permission - only owner can write"""
return False, "Only owner can modify"
async def create_resource(
self,
user_id: str,
resource_data: ResourceCreate,
capability_token: str
) -> Resource:
"""
Create a new resource with proper access control
Args:
user_id: User creating the resource
resource_data: Resource creation data
capability_token: JWT capability token
Returns:
Created resource
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Create resource
resource = Resource(
id=self._generate_resource_id(),
name=resource_data.name,
resource_type=resource_data.resource_type,
owner_id=user_id,
tenant_domain=self.tenant_domain,
access_group=resource_data.access_group,
team_members=resource_data.team_members or [],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
metadata=resource_data.metadata or {},
file_path=None
)
# Create file-based storage if needed
if self._requires_file_storage(resource.resource_type):
resource.file_path = await self._create_resource_file(resource)
# Audit log
logger.info(f"Resource created: {resource.id} by {user_id} in {self.tenant_domain}")
return resource
async def update_resource_access(
self,
user_id: str,
resource_id: str,
new_access_group: AccessGroup,
team_members: Optional[List[str]] = None
) -> Resource:
"""
Update resource access group
Args:
user_id: User requesting update
resource_id: Resource to update
new_access_group: New access level
team_members: Team members if team access
Returns:
Updated resource
"""
# Load resource
resource = await self._load_resource(resource_id)
# Check permission
allowed, reason = await self.check_permission(user_id, resource, "share")
if not allowed:
raise PermissionError(f"Access denied: {reason}")
# Update access
old_group = resource.access_group
resource.update_access_group(new_access_group, team_members)
# Update file permissions if needed
if resource.file_path:
await self._update_file_permissions(resource)
# Audit log
logger.info(
f"Access updated: {resource_id} from {old_group} to {new_access_group} "
f"by {user_id}"
)
return resource
async def list_accessible_resources(
self,
user_id: str,
resource_type: Optional[str] = None
) -> List[Resource]:
"""
List all resources accessible to user
Args:
user_id: User requesting list
resource_type: Filter by type
Returns:
List of accessible resources
"""
accessible = []
# Get all resources in tenant
all_resources = await self._list_tenant_resources(resource_type)
for resource in all_resources:
allowed, _ = await self.check_permission(user_id, resource, "read")
if allowed:
accessible.append(resource)
return accessible
async def get_resource_stats(self, user_id: str) -> Dict[str, Any]:
"""
Get resource statistics for user
Args:
user_id: User to get stats for
Returns:
Statistics dictionary
"""
all_resources = await self._list_tenant_resources()
owned = [r for r in all_resources if r.owner_id == user_id]
accessible = await self.list_accessible_resources(user_id)
stats = {
"owned_count": len(owned),
"accessible_count": len(accessible),
"by_type": {},
"by_access_group": {
AccessGroup.INDIVIDUAL: 0,
AccessGroup.TEAM: 0,
AccessGroup.ORGANIZATION: 0
}
}
for resource in owned:
# Count by type
if resource.resource_type not in stats["by_type"]:
stats["by_type"][resource.resource_type] = 0
stats["by_type"][resource.resource_type] += 1
# Count by access group
stats["by_access_group"][resource.access_group] += 1
return stats
def _generate_resource_id(self) -> str:
"""Generate unique resource ID"""
import uuid
return str(uuid.uuid4())
def _requires_file_storage(self, resource_type: str) -> bool:
"""Check if resource type requires file storage"""
file_based_types = [
"agent", "dataset", "document", "workflow",
"notebook", "model", "configuration"
]
return resource_type in file_based_types
async def _create_resource_file(self, resource: Resource) -> str:
"""
Create file for resource with proper permissions
Args:
resource: Resource to create file for
Returns:
File path
"""
# Determine path based on resource type
type_dir = self.base_path / resource.resource_type / resource.id
type_dir.mkdir(parents=True, exist_ok=True)
# Create main file
file_path = type_dir / "data.json"
file_path.touch()
# Set strict permissions - 700 for directory, 600 for file
os.chmod(type_dir, stat.S_IRWXU) # 700
os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR) # 600
logger.info(f"Created resource file: {file_path} with secure permissions")
return str(file_path)
async def _update_file_permissions(self, resource: Resource):
"""Update file permissions (always 700/600 for security)"""
if not resource.file_path or not Path(resource.file_path).exists():
return
# Permissions don't change based on access group
# All files remain 700/600 for OS-level security
# Access control is handled at application level
pass
async def _load_resource(self, resource_id: str) -> Resource:
"""Load resource from storage"""
try:
# Search for resource in all resource type directories
for resource_type_dir in self.base_path.iterdir():
if not resource_type_dir.is_dir():
continue
# Resources are written to <type>/<resource_id>/data.json by _create_resource_file
resource_file = resource_type_dir / resource_id / "data.json"
if resource_file.exists():
try:
import json
with open(resource_file, 'r') as f:
resources_data = json.load(f)
if not isinstance(resources_data, list):
resources_data = [resources_data]
for resource_data in resources_data:
if resource_data.get('id') == resource_id:
return Resource(
id=resource_data['id'],
name=resource_data['name'],
resource_type=resource_data['resource_type'],
owner_id=resource_data['owner_id'],
tenant_domain=resource_data['tenant_domain'],
access_group=AccessGroup(resource_data['access_group']),
team_members=resource_data.get('team_members', []),
created_at=datetime.fromisoformat(resource_data['created_at']),
updated_at=datetime.fromisoformat(resource_data['updated_at']),
metadata=resource_data.get('metadata', {}),
file_path=resource_data.get('file_path')
)
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f"Failed to parse resource file {resource_file}: {e}")
continue
raise ValueError(f"Resource {resource_id} not found")
except Exception as e:
logger.error(f"Failed to load resource {resource_id}: {e}")
raise
async def _list_tenant_resources(
self,
resource_type: Optional[str] = None
) -> List[Resource]:
"""List all resources in tenant"""
try:
import json
resources = []
# If specific resource type requested, search only that directory
search_dirs = [self.base_path / resource_type] if resource_type else list(self.base_path.iterdir())
for resource_type_dir in search_dirs:
if not resource_type_dir.exists() or not resource_type_dir.is_dir():
continue
# Resources are written to <type>/<resource_id>/data.json by _create_resource_file
for resource_file in resource_type_dir.glob("*/data.json"):
try:
with open(resource_file, 'r') as f:
resources_data = json.load(f)
if not isinstance(resources_data, list):
resources_data = [resources_data]
for resource_data in resources_data:
try:
resource = Resource(
id=resource_data['id'],
name=resource_data['name'],
resource_type=resource_data['resource_type'],
owner_id=resource_data['owner_id'],
tenant_domain=resource_data['tenant_domain'],
access_group=AccessGroup(resource_data['access_group']),
team_members=resource_data.get('team_members', []),
created_at=datetime.fromisoformat(resource_data['created_at']),
updated_at=datetime.fromisoformat(resource_data['updated_at']),
metadata=resource_data.get('metadata', {}),
file_path=resource_data.get('file_path')
)
resources.append(resource)
except (KeyError, ValueError) as e:
logger.warning(f"Failed to parse resource data: {e}")
continue
except (json.JSONDecodeError, IOError) as e:
logger.warning(f"Failed to read resource file {resource_file}: {e}")
continue
return resources
except Exception as e:
logger.error(f"Failed to list tenant resources: {e}")
raise
class AccessControlMiddleware:
"""
Middleware for enforcing access control on API requests
"""
def __init__(self, tenant_domain: str):
self.controller = AccessController(tenant_domain)
async def verify_request(
self,
user_id: str,
resource_id: str,
action: str,
capability_token: str
) -> bool:
"""
Verify request has proper permissions
Args:
user_id: User making request
resource_id: Resource being accessed
action: Action being performed
capability_token: JWT capability token
Returns:
True if allowed, raises PermissionError if not
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise PermissionError("Invalid capability token")
# Verify tenant match
if token_data.get("tenant_id") != self.controller.tenant_domain:
raise PermissionError("Tenant mismatch in capability token")
# Load resource and check permission
resource = await self.controller._load_resource(resource_id)
allowed, reason = await self.controller.check_permission(
user_id, resource, action
)
if not allowed:
logger.warning(
f"Access denied: {user_id} -> {resource_id} ({action}): {reason}"
)
raise PermissionError(f"Access denied: {reason}")
return True
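
A minimal, hypothetical example of how the middleware above might be wired into a request handler (the tenant domain, IDs, and token values are placeholders, not part of this commit):

```python
# Hypothetical usage sketch of AccessControlMiddleware; values are placeholders.
middleware = AccessControlMiddleware(tenant_domain="acme.example.com")

async def handle_read(user_id: str, resource_id: str, capability_token: str) -> dict:
    # verify_request raises PermissionError on an invalid token, tenant mismatch,
    # or a denied action; it returns True when access is allowed.
    await middleware.verify_request(
        user_id=user_id,
        resource_id=resource_id,
        action="read",
        capability_token=capability_token,
    )
    resource = await middleware.controller._load_resource(resource_id)
    return {"id": resource.id, "name": resource.name}
```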

View File

@@ -0,0 +1,920 @@
"""
GT 2.0 Agent Orchestrator Client
Client for interacting with the Resource Cluster's Agent Orchestration system.
Enables spawning and managing subagents for complex task execution.
"""
import logging
import asyncio
import httpx
import uuid
from typing import Dict, Any, List, Optional
from datetime import datetime
from enum import Enum
from app.services.task_classifier import SubagentType, TaskClassification
from app.models.agent import Agent
logger = logging.getLogger(__name__)
class ExecutionStrategy(str, Enum):
"""Execution strategies for subagents"""
SEQUENTIAL = "sequential"
PARALLEL = "parallel"
CONDITIONAL = "conditional"
PIPELINE = "pipeline"
MAP_REDUCE = "map_reduce"
class SubagentOrchestrator:
"""
Orchestrates subagent execution for complex tasks.
Manages lifecycle of subagents spawned from main agent templates,
coordinates their execution, and aggregates results.
"""
def __init__(self, tenant_domain: str, user_id: str):
self.tenant_domain = tenant_domain
self.user_id = user_id
self.resource_cluster_url = "http://resource-cluster:8000"
self.active_subagents: Dict[str, Dict[str, Any]] = {}
self.execution_history: List[Dict[str, Any]] = []
async def execute_task_plan(
self,
task_classification: TaskClassification,
parent_agent: Agent,
conversation_id: str,
user_message: str,
available_tools: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Execute a task plan using subagents.
Args:
task_classification: Task classification with execution plan
parent_agent: Parent agent spawning subagents
conversation_id: Current conversation ID
user_message: Original user message
available_tools: Available MCP tools
Returns:
Aggregated results from subagent execution
"""
try:
execution_id = str(uuid.uuid4())
logger.info(f"Starting subagent execution {execution_id} for {task_classification.complexity} task")
# Track execution
execution_record = {
"execution_id": execution_id,
"conversation_id": conversation_id,
"parent_agent_id": parent_agent.id,
"task_complexity": task_classification.complexity,
"started_at": datetime.now().isoformat(),
"subagent_plan": task_classification.subagent_plan
}
self.execution_history.append(execution_record)
# Determine execution strategy
strategy = self._determine_strategy(task_classification)
# Execute based on strategy
if strategy == ExecutionStrategy.PARALLEL:
results = await self._execute_parallel(
task_classification.subagent_plan,
parent_agent,
conversation_id,
user_message,
available_tools
)
elif strategy == ExecutionStrategy.SEQUENTIAL:
results = await self._execute_sequential(
task_classification.subagent_plan,
parent_agent,
conversation_id,
user_message,
available_tools
)
elif strategy == ExecutionStrategy.PIPELINE:
results = await self._execute_pipeline(
task_classification.subagent_plan,
parent_agent,
conversation_id,
user_message,
available_tools
)
else:
# Default to sequential
results = await self._execute_sequential(
task_classification.subagent_plan,
parent_agent,
conversation_id,
user_message,
available_tools
)
# Update execution record
execution_record["completed_at"] = datetime.now().isoformat()
execution_record["results"] = results
# Synthesize final response
final_response = await self._synthesize_results(
results,
task_classification,
user_message
)
logger.info(f"Completed subagent execution {execution_id}")
return {
"execution_id": execution_id,
"strategy": strategy,
"subagent_results": results,
"final_response": final_response,
"execution_time_ms": self._calculate_execution_time(execution_record)
}
except Exception as e:
logger.error(f"Subagent execution failed: {e}")
return {
"error": str(e),
"partial_results": self.active_subagents
}
async def _execute_parallel(
self,
subagent_plan: List[Dict[str, Any]],
parent_agent: Agent,
conversation_id: str,
user_message: str,
available_tools: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Execute subagents in parallel"""
# Group subagents by priority
priority_groups = {}
for plan_item in subagent_plan:
priority = plan_item.get("priority", 1)
if priority not in priority_groups:
priority_groups[priority] = []
priority_groups[priority].append(plan_item)
results = {}
# Execute each priority group
for priority in sorted(priority_groups.keys()):
group_tasks = []
for plan_item in priority_groups[priority]:
# Check dependencies
if self._dependencies_met(plan_item, results):
task = asyncio.create_task(
self._execute_subagent(
plan_item,
parent_agent,
conversation_id,
user_message,
available_tools,
results
)
)
group_tasks.append((plan_item["id"], task))
# Wait for group to complete
for agent_id, task in group_tasks:
try:
results[agent_id] = await task
except Exception as e:
logger.error(f"Subagent {agent_id} failed: {e}")
results[agent_id] = {"error": str(e)}
return results
async def _execute_sequential(
self,
subagent_plan: List[Dict[str, Any]],
parent_agent: Agent,
conversation_id: str,
user_message: str,
available_tools: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Execute subagents sequentially"""
results = {}
for plan_item in subagent_plan:
if self._dependencies_met(plan_item, results):
try:
results[plan_item["id"]] = await self._execute_subagent(
plan_item,
parent_agent,
conversation_id,
user_message,
available_tools,
results
)
except Exception as e:
logger.error(f"Subagent {plan_item['id']} failed: {e}")
results[plan_item["id"]] = {"error": str(e)}
return results
async def _execute_pipeline(
self,
subagent_plan: List[Dict[str, Any]],
parent_agent: Agent,
conversation_id: str,
user_message: str,
available_tools: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Execute subagents in pipeline mode"""
results = {}
pipeline_data = {"original_message": user_message}
for plan_item in subagent_plan:
try:
# Pass output from previous stage as input
result = await self._execute_subagent(
plan_item,
parent_agent,
conversation_id,
user_message,
available_tools,
results,
pipeline_data
)
results[plan_item["id"]] = result
# Update pipeline data with output
if "output" in result:
pipeline_data = result["output"]
except Exception as e:
logger.error(f"Pipeline stage {plan_item['id']} failed: {e}")
results[plan_item["id"]] = {"error": str(e)}
break # Pipeline broken
return results
async def _execute_subagent(
self,
plan_item: Dict[str, Any],
parent_agent: Agent,
conversation_id: str,
user_message: str,
available_tools: List[Dict[str, Any]],
previous_results: Dict[str, Any],
pipeline_data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Execute a single subagent"""
subagent_id = plan_item["id"]
subagent_type = plan_item["type"]
task_description = plan_item["task"]
logger.info(f"Executing subagent {subagent_id} ({subagent_type}): {task_description[:50]}...")
# Track subagent
self.active_subagents[subagent_id] = {
"type": subagent_type,
"task": task_description,
"started_at": datetime.now().isoformat(),
"status": "running"
}
try:
# Create subagent configuration based on type
subagent_config = self._create_subagent_config(
subagent_type,
parent_agent,
task_description,
pipeline_data
)
# Select tools for this subagent
subagent_tools = self._select_tools_for_subagent(
subagent_type,
available_tools
)
# Execute subagent based on type
if subagent_type == SubagentType.RESEARCH:
result = await self._execute_research_agent(
subagent_config,
task_description,
subagent_tools,
conversation_id
)
elif subagent_type == SubagentType.PLANNING:
result = await self._execute_planning_agent(
subagent_config,
task_description,
user_message,
previous_results
)
elif subagent_type == SubagentType.IMPLEMENTATION:
result = await self._execute_implementation_agent(
subagent_config,
task_description,
subagent_tools,
previous_results
)
elif subagent_type == SubagentType.VALIDATION:
result = await self._execute_validation_agent(
subagent_config,
task_description,
previous_results
)
elif subagent_type == SubagentType.SYNTHESIS:
result = await self._execute_synthesis_agent(
subagent_config,
task_description,
previous_results
)
elif subagent_type == SubagentType.ANALYST:
result = await self._execute_analyst_agent(
subagent_config,
task_description,
previous_results
)
else:
# Default execution
result = await self._execute_generic_agent(
subagent_config,
task_description,
subagent_tools
)
# Update tracking
self.active_subagents[subagent_id]["status"] = "completed"
self.active_subagents[subagent_id]["completed_at"] = datetime.now().isoformat()
self.active_subagents[subagent_id]["result"] = result
return result
except Exception as e:
logger.error(f"Subagent {subagent_id} execution failed: {e}")
self.active_subagents[subagent_id]["status"] = "failed"
self.active_subagents[subagent_id]["error"] = str(e)
raise
async def _execute_research_agent(
self,
config: Dict[str, Any],
task: str,
tools: List[Dict[str, Any]],
conversation_id: str
) -> Dict[str, Any]:
"""Execute research subagent"""
# Research agents focus on information gathering
prompt = f"""You are a research specialist. Your task is to:
{task}
Available tools: {[t['name'] for t in tools]}
Gather comprehensive information and return structured findings."""
result = await self._call_llm_with_tools(
prompt,
config,
tools,
max_iterations=3
)
return {
"type": "research",
"findings": result.get("content", ""),
"sources": result.get("tool_results", []),
"output": result
}
async def _execute_planning_agent(
self,
config: Dict[str, Any],
task: str,
original_query: str,
previous_results: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute planning subagent"""
context = self._format_previous_results(previous_results)
prompt = f"""You are a planning specialist. Break down this task into actionable steps:
Original request: {original_query}
Specific task: {task}
Context from previous agents:
{context}
Create a detailed execution plan with clear steps."""
result = await self._call_llm(prompt, config)
return {
"type": "planning",
"plan": result.get("content", ""),
"steps": self._extract_steps(result.get("content", "")),
"output": result
}
async def _execute_implementation_agent(
self,
config: Dict[str, Any],
task: str,
tools: List[Dict[str, Any]],
previous_results: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute implementation subagent"""
context = self._format_previous_results(previous_results)
prompt = f"""You are an implementation specialist. Execute this task:
{task}
Context:
{context}
Available tools: {[t['name'] for t in tools]}
Complete the implementation and return results."""
result = await self._call_llm_with_tools(
prompt,
config,
tools,
max_iterations=5
)
return {
"type": "implementation",
"implementation": result.get("content", ""),
"tool_calls": result.get("tool_calls", []),
"output": result
}
async def _execute_validation_agent(
self,
config: Dict[str, Any],
task: str,
previous_results: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute validation subagent"""
context = self._format_previous_results(previous_results)
prompt = f"""You are a validation specialist. Verify the following:
{task}
Results to validate:
{context}
Check for correctness, completeness, and quality."""
result = await self._call_llm(prompt, config)
return {
"type": "validation",
"validation_result": result.get("content", ""),
"issues_found": self._extract_issues(result.get("content", "")),
"output": result
}
async def _execute_synthesis_agent(
self,
config: Dict[str, Any],
task: str,
previous_results: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute synthesis subagent"""
all_results = self._format_all_results(previous_results)
prompt = f"""You are a synthesis specialist. Combine and summarize these results:
Task: {task}
Results from all agents:
{all_results}
Create a comprehensive, coherent response that addresses the original request."""
result = await self._call_llm(prompt, config)
return {
"type": "synthesis",
"final_response": result.get("content", ""),
"output": result
}
async def _execute_analyst_agent(
self,
config: Dict[str, Any],
task: str,
previous_results: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute analyst subagent"""
data = self._format_previous_results(previous_results)
prompt = f"""You are an analysis specialist. Analyze the following:
{task}
Data to analyze:
{data}
Identify patterns, insights, and recommendations."""
result = await self._call_llm(prompt, config)
return {
"type": "analysis",
"analysis": result.get("content", ""),
"insights": self._extract_insights(result.get("content", "")),
"output": result
}
async def _execute_generic_agent(
self,
config: Dict[str, Any],
task: str,
tools: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Execute generic subagent"""
prompt = f"""Complete the following task:
{task}
Available tools: {[t['name'] for t in tools] if tools else 'None'}"""
if tools:
result = await self._call_llm_with_tools(prompt, config, tools)
else:
result = await self._call_llm(prompt, config)
return {
"type": "generic",
"result": result.get("content", ""),
"output": result
}
async def _call_llm(
self,
prompt: str,
config: Dict[str, Any]
) -> Dict[str, Any]:
"""Call LLM without tools"""
try:
async with httpx.AsyncClient(timeout=60.0) as client:
# Require model to be specified in config - no hardcoded fallbacks
model = config.get("model")
if not model:
raise ValueError(f"No model specified in subagent config: {config}")
response = await client.post(
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
json={
"model": model,
"messages": [
{"role": "system", "content": config.get("instructions", "")},
{"role": "user", "content": prompt}
],
"temperature": config.get("temperature", 0.7),
"max_tokens": config.get("max_tokens", 2000)
},
headers={
"X-Tenant-ID": self.tenant_domain,
"X-User-ID": self.user_id
}
)
if response.status_code == 200:
result = response.json()
return {
"content": result["choices"][0]["message"]["content"],
"model": result["model"]
}
else:
raise Exception(f"LLM call failed: {response.status_code}")
except Exception as e:
logger.error(f"LLM call failed: {e}")
return {"content": f"Error: {str(e)}"}
async def _call_llm_with_tools(
self,
prompt: str,
config: Dict[str, Any],
tools: List[Dict[str, Any]],
max_iterations: int = 3
) -> Dict[str, Any]:
"""Call LLM with tool execution capability"""
messages = [
{"role": "system", "content": config.get("instructions", "")},
{"role": "user", "content": prompt}
]
tool_results = []
iterations = 0
while iterations < max_iterations:
try:
async with httpx.AsyncClient(timeout=60.0) as client:
# Require model to be specified in config - no hardcoded fallbacks
model = config.get("model")
if not model:
raise ValueError(f"No model specified in subagent config: {config}")
response = await client.post(
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
json={
"model": model,
"messages": messages,
"temperature": config.get("temperature", 0.7),
"max_tokens": config.get("max_tokens", 2000),
"tools": tools,
"tool_choice": "auto"
},
headers={
"X-Tenant-ID": self.tenant_domain,
"X-User-ID": self.user_id
}
)
if response.status_code != 200:
raise Exception(f"LLM call failed: {response.status_code}")
result = response.json()
choice = result["choices"][0]
message = choice["message"]
# Add agent's response to messages
messages.append(message)
# Check for tool calls
if message.get("tool_calls"):
# Execute tools
for tool_call in message["tool_calls"]:
tool_result = await self._execute_tool(
tool_call["function"]["name"],
tool_call["function"].get("arguments", {})
)
tool_results.append({
"tool": tool_call["function"]["name"],
"result": tool_result
})
# Add tool result to messages
messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"content": str(tool_result)
})
iterations += 1
continue # Get next response
# No more tool calls, return final result
return {
"content": message.get("content", ""),
"tool_calls": message.get("tool_calls", []),
"tool_results": tool_results,
"model": result["model"]
}
except Exception as e:
logger.error(f"LLM with tools call failed: {e}")
return {"content": f"Error: {str(e)}", "tool_results": tool_results}
iterations += 1
# Max iterations reached
return {
"content": "Max iterations reached",
"tool_results": tool_results
}
async def _execute_tool(
self,
tool_name: str,
arguments: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute an MCP tool"""
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.resource_cluster_url}/api/v1/mcp/execute",
json={
"tool_name": tool_name,
"parameters": arguments,
"tenant_domain": self.tenant_domain,
"user_id": self.user_id
}
)
if response.status_code == 200:
return response.json()
else:
return {"error": f"Tool execution failed: {response.status_code}"}
except Exception as e:
logger.error(f"Tool execution failed: {e}")
return {"error": str(e)}
def _determine_strategy(self, task_classification: TaskClassification) -> ExecutionStrategy:
"""Determine execution strategy based on task classification"""
if task_classification.parallel_execution:
return ExecutionStrategy.PARALLEL
elif len(task_classification.subagent_plan) > 3:
return ExecutionStrategy.PIPELINE
else:
return ExecutionStrategy.SEQUENTIAL
def _dependencies_met(
self,
plan_item: Dict[str, Any],
completed_results: Dict[str, Any]
) -> bool:
"""Check if dependencies are met for a subagent"""
depends_on = plan_item.get("depends_on", [])
return all(dep in completed_results for dep in depends_on)
def _create_subagent_config(
self,
subagent_type: SubagentType,
parent_agent: Agent,
task: str,
pipeline_data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create configuration for subagent"""
# Base config from parent
config = {
"model": parent_agent.model_name,
"temperature": parent_agent.model_settings.get("temperature", 0.7),
"max_tokens": parent_agent.model_settings.get("max_tokens", 2000)
}
# Customize based on subagent type
if subagent_type == SubagentType.RESEARCH:
config["instructions"] = "You are a research specialist. Be thorough and accurate."
config["temperature"] = 0.3 # Lower for factual research
elif subagent_type == SubagentType.PLANNING:
config["instructions"] = "You are a planning specialist. Create clear, actionable plans."
config["temperature"] = 0.5
elif subagent_type == SubagentType.IMPLEMENTATION:
config["instructions"] = "You are an implementation specialist. Execute tasks precisely."
config["temperature"] = 0.3
elif subagent_type == SubagentType.SYNTHESIS:
config["instructions"] = "You are a synthesis specialist. Create coherent summaries."
config["temperature"] = 0.7
else:
config["instructions"] = parent_agent.instructions or ""
return config
def _select_tools_for_subagent(
self,
subagent_type: SubagentType,
available_tools: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Select appropriate tools for subagent type"""
if not available_tools:
return []
# Tool selection based on subagent type
if subagent_type == SubagentType.RESEARCH:
# Research agents get search tools
return [t for t in available_tools if any(
keyword in t["name"].lower()
for keyword in ["search", "find", "list", "get", "fetch"]
)]
elif subagent_type == SubagentType.IMPLEMENTATION:
# Implementation agents get action tools
return [t for t in available_tools if any(
keyword in t["name"].lower()
for keyword in ["create", "update", "write", "execute", "run"]
)]
elif subagent_type == SubagentType.VALIDATION:
# Validation agents get read/check tools
return [t for t in available_tools if any(
keyword in t["name"].lower()
for keyword in ["read", "check", "verify", "test"]
)]
else:
# Give all tools to other types
return available_tools
async def _synthesize_results(
self,
results: Dict[str, Any],
task_classification: TaskClassification,
user_message: str
) -> str:
"""Synthesize final response from all subagent results"""
# Look for synthesis agent result first
for agent_id, result in results.items():
if result.get("type") == "synthesis":
return result.get("final_response", "")
# Otherwise, compile results
response_parts = []
# Add results in order of priority
for plan_item in sorted(
task_classification.subagent_plan,
key=lambda x: x.get("priority", 999)
):
agent_id = plan_item["id"]
if agent_id in results:
result = results[agent_id]
if "error" not in result:
content = result.get("output", {}).get("content", "")
if content:
response_parts.append(content)
return "\n\n".join(response_parts) if response_parts else "Task completed"
def _format_previous_results(self, results: Dict[str, Any]) -> str:
"""Format previous results for context"""
if not results:
return "No previous results"
formatted = []
for agent_id, result in results.items():
if "error" not in result:
formatted.append(f"{agent_id}: {result.get('output', {}).get('content', '')[:200]}")
return "\n".join(formatted) if formatted else "No valid previous results"
def _format_all_results(self, results: Dict[str, Any]) -> str:
"""Format all results for synthesis"""
if not results:
return "No results to synthesize"
formatted = []
for agent_id, result in results.items():
if "error" not in result:
agent_type = result.get("type", "unknown")
content = result.get("output", {}).get("content", "")
formatted.append(f"[{agent_type}] {agent_id}:\n{content}\n")
return "\n".join(formatted) if formatted else "No valid results to synthesize"
def _extract_steps(self, content: str) -> List[str]:
"""Extract steps from planning content"""
import re
steps = []
# Look for numbered lists
pattern = r"(?:^|\n)\s*(?:\d+[\.\)]|\-|\*)\s+(.+)"
matches = re.findall(pattern, content)
for match in matches:
steps.append(match.strip())
return steps
def _extract_issues(self, content: str) -> List[str]:
"""Extract issues from validation content"""
import re
issues = []
# Look for issue indicators
issue_patterns = [
r"(?:issue|problem|error|warning|concern):\s*(.+)",
r"(?:^|\n)\s*[\-\*]\s*(?:Issue|Problem|Error):\s*(.+)"
]
for pattern in issue_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
issues.extend([m.strip() for m in matches])
return issues
def _extract_insights(self, content: str) -> List[str]:
"""Extract insights from analysis content"""
import re
insights = []
# Look for insight indicators
insight_patterns = [
r"(?:insight|finding|observation|pattern):\s*(.+)",
r"(?:^|\n)\s*\d+[\.\)]\s*(.+(?:shows?|indicates?|suggests?|reveals?).+)"
]
for pattern in insight_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
insights.extend([m.strip() for m in matches])
return insights
def _calculate_execution_time(self, execution_record: Dict[str, Any]) -> float:
"""Calculate execution time in milliseconds"""
if "completed_at" in execution_record and "started_at" in execution_record:
start = datetime.fromisoformat(execution_record["started_at"])
end = datetime.fromisoformat(execution_record["completed_at"])
return (end - start).total_seconds() * 1000
return 0.0
# Factory function
def get_subagent_orchestrator(tenant_domain: str, user_id: str) -> SubagentOrchestrator:
"""Get subagent orchestrator instance"""
return SubagentOrchestrator(tenant_domain, user_id)
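
A rough usage sketch for the orchestrator above (the tenant, user, and inputs are placeholder values; the TaskClassification, Agent, and tool list would come from the task classifier and agent service):

```python
# Hypothetical caller of the orchestrator; inputs are assumed to be produced elsewhere.
async def run_complex_task(classification, parent_agent, conversation_id, user_message, tools):
    orchestrator = get_subagent_orchestrator("acme.example.com", "user-123")  # placeholder tenant/user
    outcome = await orchestrator.execute_task_plan(
        task_classification=classification,
        parent_agent=parent_agent,
        conversation_id=conversation_id,
        user_message=user_message,
        available_tools=tools,
    )
    if "error" in outcome:
        raise RuntimeError(f"Subagent execution failed: {outcome['error']}")
    # final_response is synthesized from the individual subagent results
    return outcome["final_response"]
```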

View File

@@ -0,0 +1,854 @@
import json
import os
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from pathlib import Path
from app.core.config import get_settings
from app.core.postgresql_client import get_postgresql_client
from app.core.permissions import get_user_role, validate_visibility_permission, can_edit_resource, can_delete_resource, is_effective_owner
from app.services.category_service import CategoryService
import logging
logger = logging.getLogger(__name__)
class AgentService:
"""GT 2.0 PostgreSQL+PGVector Agent Service with Perfect Tenant Isolation"""
def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
"""Initialize with tenant and user isolation using PostgreSQL+PGVector storage"""
self.tenant_domain = tenant_domain
self.user_id = user_id
self.user_email = user_email or user_id # Fallback to user_id if no email provided
self.settings = get_settings()
self._resolved_user_uuid = None # Cache for resolved user UUID (performance optimization)
logger.info(f"Agent service initialized with PostgreSQL+PGVector for {tenant_domain}/{user_id} (email: {self.user_email})")
async def _get_resolved_user_uuid(self, user_identifier: Optional[str] = None) -> str:
"""
Resolve user identifier to UUID with caching for performance.
This optimization reduces repeated database lookups by caching the resolved UUID.
Performance impact: ~50% reduction in query time for operations with multiple queries.
Pattern matches conversation_service.py for consistency.
"""
identifier = user_identifier or self.user_email or self.user_id
# Return cached UUID if already resolved for this instance
if self._resolved_user_uuid and str(identifier) in [str(self.user_email), str(self.user_id)]:
return self._resolved_user_uuid
# Check if already a UUID
if "@" not in str(identifier):
try:
# Validate it's a proper UUID format
uuid.UUID(str(identifier))
if str(identifier) == str(self.user_id):
self._resolved_user_uuid = str(identifier)
return str(identifier)
except (ValueError, AttributeError):
pass # Not a valid UUID, treat as email/username
# Resolve email to UUID
pg_client = await get_postgresql_client()
query = """
SELECT id FROM users
WHERE (email = $1 OR username = $1)
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
LIMIT 1
"""
result = await pg_client.fetch_one(query, str(identifier), self.tenant_domain)
if not result:
raise ValueError(f"User not found: {identifier}")
user_uuid = str(result["id"])
# Cache if this is the service's primary user
if str(identifier) in [str(self.user_email), str(self.user_id)]:
self._resolved_user_uuid = user_uuid
return user_uuid
async def create_agent(
self,
name: str,
agent_type: str = "conversational",
prompt_template: str = "",
description: str = "",
capabilities: Optional[List[str]] = None,
access_group: str = "INDIVIDUAL",
**kwargs
) -> Dict[str, Any]:
"""Create a new agent using PostgreSQL+PGVector storage following GT 2.0 principles"""
try:
# Get PostgreSQL client
pg_client = await get_postgresql_client()
# Generate agent ID
agent_id = str(uuid.uuid4())
# Resolve user UUID with caching (performance optimization)
user_id = await self._get_resolved_user_uuid()
logger.info(f"Found user ID: {user_id} for email/id: {self.user_email}/{self.user_id}")
# Create agent in PostgreSQL
query = """
INSERT INTO agents (
id, name, description, system_prompt,
tenant_id, created_by, model, temperature, max_tokens,
visibility, configuration, is_active, access_group, agent_type
) VALUES (
$1, $2, $3, $4,
(SELECT id FROM tenants WHERE domain = $5 LIMIT 1),
$6,
$7, $8, $9, $10, $11, true, $12, $13
)
RETURNING id, name, description, system_prompt, model, temperature, max_tokens,
visibility, configuration, access_group, agent_type, created_at, updated_at
"""
# Prepare configuration with additional kwargs
# Ensure list fields are always lists, never None
configuration = {
"agent_type": agent_type,
"capabilities": capabilities or [],
"personality_config": kwargs.get("personality_config", {}),
"resource_preferences": kwargs.get("resource_preferences", {}),
"model_config": kwargs.get("model_config", {}),
"tags": kwargs.get("tags") or [],
"easy_prompts": kwargs.get("easy_prompts") or [],
"selected_dataset_ids": kwargs.get("selected_dataset_ids") or [],
**{k: v for k, v in kwargs.items() if k not in ["tags", "easy_prompts", "selected_dataset_ids"]}
}
# Extract model configuration
model = kwargs.get("model")
if not model:
raise ValueError("Model is required for agent creation")
temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 8000) # Increased to match Groq Llama 3.1 capabilities
# Use access_group as visibility directly (individual, organization only)
visibility = access_group.lower()
# Validate visibility permission based on user role
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
validate_visibility_permission(visibility, user_role)
logger.info(f"User {self.user_email} (role: {user_role}) creating agent with visibility: {visibility}")
# Auto-create category if specified (Issue #215)
# This ensures imported agents with unknown categories create those categories
# Category is stored in agent_type column
category = kwargs.get("category")
if category and isinstance(category, str) and category.strip():
category_slug = category.strip().lower()
try:
category_service = CategoryService(self.tenant_domain, user_id, self.user_email)
# Pass category_description from CSV import if provided
category_description = kwargs.get("category_description")
await category_service.get_or_create_category(category_slug, description=category_description)
logger.info(f"Ensured category exists: {category}")
except Exception as cat_err:
logger.warning(f"Failed to ensure category '{category}' exists: {cat_err}")
# Continue with agent creation even if category creation fails
# Use category as agent_type (they map to the same column)
agent_type = category_slug
agent_data = await pg_client.fetch_one(
query,
agent_id, name, description, prompt_template,
self.tenant_domain, user_id,
model, temperature, max_tokens, visibility,
json.dumps(configuration), access_group, agent_type
)
if not agent_data:
raise RuntimeError("Failed to create agent - no data returned")
# Convert to dict with proper types
# Parse configuration JSON if it's a string
config = agent_data["configuration"]
if isinstance(config, str):
config = json.loads(config)
elif config is None:
config = {}
result = {
"id": str(agent_data["id"]),
"name": agent_data["name"],
"agent_type": config.get("agent_type", "conversational"),
"prompt_template": agent_data["system_prompt"],
"description": agent_data["description"],
"capabilities": config.get("capabilities", []),
"access_group": agent_data["access_group"],
"config": config,
"model": agent_data["model"],
"temperature": float(agent_data["temperature"]) if agent_data["temperature"] is not None else None,
"max_tokens": agent_data["max_tokens"],
"top_p": config.get("top_p"),
"frequency_penalty": config.get("frequency_penalty"),
"presence_penalty": config.get("presence_penalty"),
"visibility": agent_data["visibility"],
"dataset_connection": config.get("dataset_connection"),
"selected_dataset_ids": config.get("selected_dataset_ids", []),
"max_chunks_per_query": config.get("max_chunks_per_query"),
"history_context": config.get("history_context"),
"personality_config": config.get("personality_config", {}),
"resource_preferences": config.get("resource_preferences", {}),
"tags": config.get("tags", []),
"is_favorite": config.get("is_favorite", False),
"conversation_count": 0,
"total_cost_cents": 0,
"created_at": agent_data["created_at"].isoformat(),
"updated_at": agent_data["updated_at"].isoformat(),
"user_id": self.user_id,
"tenant_domain": self.tenant_domain
}
logger.info(f"Created agent {agent_id} in PostgreSQL for user {self.user_id}")
return result
except Exception as e:
logger.error(f"Failed to create agent: {e}")
raise
async def get_user_agents(
self,
active_only: bool = True,
sort_by: Optional[str] = None,
filter_usage: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Get all agents for the current user using PostgreSQL storage"""
try:
# Get PostgreSQL client
pg_client = await get_postgresql_client()
# Resolve user UUID with caching (performance optimization)
try:
user_id = await self._get_resolved_user_uuid()
except ValueError as e:
logger.warning(f"User not found for agents list: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
return []
# Get user role to determine access level
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
is_admin = user_role in ["admin", "developer"]
# Query agents from PostgreSQL with conversation counts
# Admins see ALL agents, others see only their own or organization-level agents
if is_admin:
where_clause = "WHERE a.tenant_id = (SELECT id FROM tenants WHERE domain = $1)"
params = [self.tenant_domain]
else:
where_clause = "WHERE (a.created_by = $1 OR a.visibility = 'organization') AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2)"
params = [user_id, self.tenant_domain]
# Prepare user_id parameter for per-user usage tracking
# Need to add user_id as an additional parameter for usage calculations
user_id_param_index = len(params) + 1
params.append(user_id)
# Per-user usage tracking: Count only conversations for this user
query = f"""
SELECT
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
a.is_active, a.created_by, a.agent_type,
u.full_name as created_by_name,
COUNT(CASE WHEN c.user_id = ${user_id_param_index}::uuid THEN c.id END) as user_conversation_count,
MAX(CASE WHEN c.user_id = ${user_id_param_index}::uuid THEN c.created_at END) as user_last_used_at
FROM agents a
LEFT JOIN conversations c ON a.id = c.agent_id
LEFT JOIN users u ON a.created_by = u.id
{where_clause}
"""
if active_only:
query += " AND a.is_active = true"
# Time-based usage filters (per-user)
if filter_usage == "used_last_7_days":
query += f" AND EXISTS (SELECT 1 FROM conversations c2 WHERE c2.agent_id = a.id AND c2.user_id = ${user_id_param_index}::uuid AND c2.created_at >= NOW() - INTERVAL '7 days')"
elif filter_usage == "used_last_30_days":
query += f" AND EXISTS (SELECT 1 FROM conversations c2 WHERE c2.agent_id = a.id AND c2.user_id = ${user_id_param_index}::uuid AND c2.created_at >= NOW() - INTERVAL '30 days')"
query += " GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens, a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at, a.is_active, a.created_by, a.agent_type, u.full_name"
# User-specific sorting
if sort_by == "recent_usage":
query += " ORDER BY user_last_used_at DESC NULLS LAST, a.updated_at DESC"
elif sort_by == "my_most_used":
query += " ORDER BY user_conversation_count DESC, a.updated_at DESC"
else:
query += " ORDER BY a.updated_at DESC"
agents_data = await pg_client.execute_query(query, *params)
# Convert to proper format
agents = []
for agent in agents_data:
# Debug logging for creator name
logger.info(f"🔍 Agent '{agent['name']}': created_by={agent.get('created_by')}, created_by_name={agent.get('created_by_name')}")
# Parse configuration JSON if it's a string
config = agent["configuration"]
if isinstance(config, str):
config = json.loads(config)
elif config is None:
config = {}
disclaimer_val = config.get("disclaimer")
easy_prompts_val = config.get("easy_prompts", [])
logger.info(f"get_user_agents - Agent {agent['name']}: disclaimer={disclaimer_val}, easy_prompts={easy_prompts_val}")
# Determine if user can edit this agent
# User can edit if they created it OR if they're admin/developer
# Use cached user_role from line 190 (no need to re-query for each agent)
is_owner = is_effective_owner(str(agent["created_by"]), str(user_id), user_role)
can_edit = can_edit_resource(str(agent["created_by"]), str(user_id), user_role, agent["visibility"])
can_delete = can_delete_resource(str(agent["created_by"]), str(user_id), user_role)
logger.info(f"Agent {agent['name']}: created_by={agent['created_by']}, user_id={user_id}, user_role={user_role}, is_owner={is_owner}, can_edit={can_edit}, can_delete={can_delete}")
agents.append({
"id": str(agent["id"]),
"name": agent["name"],
"agent_type": agent["agent_type"] or "conversational",
"prompt_template": agent["system_prompt"],
"description": agent["description"],
"capabilities": config.get("capabilities", []),
"access_group": agent["access_group"],
"config": config,
"model": agent["model"],
"temperature": float(agent["temperature"]) if agent["temperature"] is not None else None,
"max_tokens": agent["max_tokens"],
"visibility": agent["visibility"],
"dataset_connection": config.get("dataset_connection"),
"selected_dataset_ids": config.get("selected_dataset_ids", []),
"personality_config": config.get("personality_config", {}),
"resource_preferences": config.get("resource_preferences", {}),
"tags": config.get("tags", []),
"is_favorite": config.get("is_favorite", False),
"disclaimer": disclaimer_val,
"easy_prompts": easy_prompts_val,
"conversation_count": int(agent["user_conversation_count"]) if agent.get("user_conversation_count") is not None else 0,
"last_used_at": agent["user_last_used_at"].isoformat() if agent.get("user_last_used_at") else None,
"total_cost_cents": 0,
"created_at": agent["created_at"].isoformat() if agent["created_at"] else None,
"updated_at": agent["updated_at"].isoformat() if agent["updated_at"] else None,
"is_active": agent["is_active"],
"user_id": agent["created_by"],
"created_by_name": agent.get("created_by_name", "Unknown"),
"tenant_domain": self.tenant_domain,
"can_edit": can_edit,
"can_delete": can_delete,
"is_owner": is_owner
})
# Fetch team-shared agents and merge with owned agents
team_shared = await self.get_team_shared_agents(user_id)
# Merge and deduplicate (owned agents take precedence)
agent_ids_seen = {agent["id"] for agent in agents}
for team_agent in team_shared:
if team_agent["id"] not in agent_ids_seen:
agents.append(team_agent)
agent_ids_seen.add(team_agent["id"])
logger.info(f"Retrieved {len(agents)} total agents ({len(agents) - len(team_shared)} owned + {len(team_shared)} team-shared) from PostgreSQL for user {self.user_id}")
return agents
except Exception as e:
logger.error(f"Error reading agents for user {self.user_id}: {e}")
return []
async def get_team_shared_agents(self, user_id: str) -> List[Dict[str, Any]]:
"""
Get agents shared to teams where user is a member (via junction table).
Uses the user_accessible_resources view for efficient lookups.
Returns agents with permission flags:
- can_edit: True if user has 'edit' permission for this agent
- can_delete: False (only owner can delete)
- is_owner: False (team-shared agents)
- shared_via_team: True (indicates team sharing)
- shared_in_teams: Number of teams this agent is shared with
"""
try:
pg_client = await get_postgresql_client()
# Query agents using the efficient user_accessible_resources view
# This view joins team_memberships -> team_resource_shares -> agents
# Include per-user usage statistics
query = """
SELECT DISTINCT
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
a.is_active, a.created_by, a.agent_type,
u.full_name as created_by_name,
COUNT(DISTINCT CASE WHEN c.user_id = $1::uuid THEN c.id END) as user_conversation_count,
MAX(CASE WHEN c.user_id = $1::uuid THEN c.created_at END) as user_last_used_at,
uar.best_permission as user_permission,
uar.shared_in_teams,
uar.team_ids
FROM user_accessible_resources uar
INNER JOIN agents a ON a.id = uar.resource_id
LEFT JOIN users u ON a.created_by = u.id
LEFT JOIN conversations c ON a.id = c.agent_id
WHERE uar.user_id = $1::uuid
AND uar.resource_type = 'agent'
AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
AND a.is_active = true
GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature,
a.max_tokens, a.visibility, a.configuration, a.access_group, a.created_at,
a.updated_at, a.is_active, a.created_by, a.agent_type, u.full_name,
uar.best_permission, uar.shared_in_teams, uar.team_ids
ORDER BY a.updated_at DESC
"""
agents_data = await pg_client.execute_query(query, user_id, self.tenant_domain)
# Format agents with team sharing metadata
agents = []
for agent in agents_data:
# Parse configuration JSON
config = agent["configuration"]
if isinstance(config, str):
config = json.loads(config)
elif config is None:
config = {}
# Get permission from view (will be "read" or "edit")
user_permission = agent.get("user_permission")
can_edit = user_permission == "edit"
# Get team sharing metadata
shared_in_teams = agent.get("shared_in_teams", 0)
team_ids = agent.get("team_ids", [])
agents.append({
"id": str(agent["id"]),
"name": agent["name"],
"agent_type": agent["agent_type"] or "conversational",
"prompt_template": agent["system_prompt"],
"description": agent["description"],
"capabilities": config.get("capabilities", []),
"access_group": agent["access_group"],
"config": config,
"model": agent["model"],
"temperature": float(agent["temperature"]) if agent["temperature"] is not None else None,
"max_tokens": agent["max_tokens"],
"visibility": agent["visibility"],
"dataset_connection": config.get("dataset_connection"),
"selected_dataset_ids": config.get("selected_dataset_ids", []),
"personality_config": config.get("personality_config", {}),
"resource_preferences": config.get("resource_preferences", {}),
"tags": config.get("tags", []),
"is_favorite": config.get("is_favorite", False),
"disclaimer": config.get("disclaimer"),
"easy_prompts": config.get("easy_prompts", []),
"conversation_count": int(agent["user_conversation_count"]) if agent.get("user_conversation_count") else 0,
"last_used_at": agent["user_last_used_at"].isoformat() if agent.get("user_last_used_at") else None,
"total_cost_cents": 0,
"created_at": agent["created_at"].isoformat() if agent["created_at"] else None,
"updated_at": agent["updated_at"].isoformat() if agent["updated_at"] else None,
"is_active": agent["is_active"],
"user_id": agent["created_by"],
"created_by_name": agent.get("created_by_name", "Unknown"),
"tenant_domain": self.tenant_domain,
"can_edit": can_edit,
"can_delete": False, # Only owner can delete
"is_owner": False, # Team-shared agents
"shared_via_team": True,
"shared_in_teams": shared_in_teams,
"team_ids": [str(tid) for tid in team_ids] if team_ids else [],
"team_permission": user_permission
})
logger.info(f"Retrieved {len(agents)} team-shared agents for user {user_id}")
return agents
except Exception as e:
logger.error(f"Error fetching team-shared agents for user {user_id}: {e}")
return []
async def get_agent(self, agent_id: str) -> Optional[Dict[str, Any]]:
"""Get a specific agent by ID using PostgreSQL"""
try:
# Get PostgreSQL client
pg_client = await get_postgresql_client()
# Resolve user UUID with caching (performance optimization)
try:
user_id = await self._get_resolved_user_uuid()
except ValueError as e:
logger.warning(f"User not found: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
return None
# Check if user is admin - admins can see all agents
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
is_admin = user_role in ["admin", "developer"]
# Query the agent first
query = """
SELECT
a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
a.is_active, a.created_by, a.agent_type,
COUNT(c.id) as conversation_count
FROM agents a
LEFT JOIN conversations c ON a.id = c.agent_id
WHERE a.id = $1 AND a.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
GROUP BY a.id, a.name, a.description, a.system_prompt, a.model, a.temperature, a.max_tokens,
a.visibility, a.configuration, a.access_group, a.created_at, a.updated_at,
a.is_active, a.created_by, a.agent_type
LIMIT 1
"""
agent_data = await pg_client.fetch_one(query, agent_id, self.tenant_domain)
logger.info(f"Agent query result: {agent_data is not None}")
# If agent doesn't exist, return None
if not agent_data:
return None
# Check access: admin, owner, organization, or team-based
if not is_admin:
is_owner = str(agent_data["created_by"]) == str(user_id)
is_org_wide = agent_data["visibility"] == "organization"
# Check team-based access if not owner or org-wide
if not is_owner and not is_org_wide:
# Import TeamService here to avoid circular dependency
from app.services.team_service import TeamService
team_service = TeamService(self.tenant_domain, str(user_id), self.user_email)
has_team_access = await team_service.check_user_resource_permission(
user_id=str(user_id),
resource_type="agent",
resource_id=agent_id,
required_permission="read"
)
if not has_team_access:
logger.warning(f"User {user_id} denied access to agent {agent_id}")
return None
logger.info(f"User {user_id} has team-based access to agent {agent_id}")
if agent_data:
# Parse configuration JSON if it's a string
config = agent_data["configuration"]
if isinstance(config, str):
config = json.loads(config)
elif config is None:
config = {}
# Convert to proper format
logger.info(f"Config disclaimer: {config.get('disclaimer')}, easy_prompts: {config.get('easy_prompts')}")
# Compute is_owner for export permission checks
is_owner = str(agent_data["created_by"]) == str(user_id)
result = {
"id": str(agent_data["id"]),
"name": agent_data["name"],
"agent_type": agent_data["agent_type"] or "conversational",
"prompt_template": agent_data["system_prompt"],
"description": agent_data["description"],
"capabilities": config.get("capabilities", []),
"access_group": agent_data["access_group"],
"config": config,
"model": agent_data["model"],
"temperature": float(agent_data["temperature"]) if agent_data["temperature"] is not None else None,
"max_tokens": agent_data["max_tokens"],
"visibility": agent_data["visibility"],
"dataset_connection": config.get("dataset_connection"),
"selected_dataset_ids": config.get("selected_dataset_ids", []),
"personality_config": config.get("personality_config", {}),
"resource_preferences": config.get("resource_preferences", {}),
"tags": config.get("tags", []),
"is_favorite": config.get("is_favorite", False),
"disclaimer": config.get("disclaimer"),
"easy_prompts": config.get("easy_prompts", []),
"conversation_count": int(agent_data["conversation_count"]) if agent_data.get("conversation_count") is not None else 0,
"total_cost_cents": 0,
"created_at": agent_data["created_at"].isoformat() if agent_data["created_at"] else None,
"updated_at": agent_data["updated_at"].isoformat() if agent_data["updated_at"] else None,
"is_active": agent_data["is_active"],
"created_by": agent_data["created_by"], # Keep DB field
"user_id": agent_data["created_by"], # Alias for compatibility
"is_owner": is_owner, # Computed ownership for export/edit permissions
"tenant_domain": self.tenant_domain
}
return result
return None
except Exception as e:
logger.error(f"Error reading agent {agent_id}: {e}")
return None
async def update_agent(
self,
agent_id: str,
updates: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Update an agent's configuration using PostgreSQL with permission checks"""
try:
logger.info(f"Processing updates for agent {agent_id}: {updates}")
# Log which fields will be processed
logger.info(f"Update fields being processed: {list(updates.keys())}")
# Get PostgreSQL client
pg_client = await get_postgresql_client()
# Get user role for permission checks
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
# If updating visibility, validate permission
if "visibility" in updates:
validate_visibility_permission(updates["visibility"], user_role)
logger.info(f"User {self.user_email} (role: {user_role}) updating agent visibility to: {updates['visibility']}")
# Build dynamic UPDATE query based on provided updates
set_clauses = []
params = []
param_idx = 1
# Collect all configuration updates in a single object
config_updates = {}
# Handle each update field mapping to correct column names
for field, value in updates.items():
if field in ["name", "description", "access_group"]:
set_clauses.append(f"{field} = ${param_idx}")
params.append(value)
param_idx += 1
elif field == "prompt_template":
set_clauses.append(f"system_prompt = ${param_idx}")
params.append(value)
param_idx += 1
elif field in ["model", "temperature", "max_tokens", "visibility", "agent_type"]:
set_clauses.append(f"{field} = ${param_idx}")
params.append(value)
param_idx += 1
elif field == "is_active":
set_clauses.append(f"is_active = ${param_idx}")
params.append(value)
param_idx += 1
elif field in ["config", "configuration", "personality_config", "resource_preferences", "tags", "is_favorite",
"dataset_connection", "selected_dataset_ids", "disclaimer", "easy_prompts"]:
# Collect configuration updates
if field in ["config", "configuration"]:
config_updates.update(value if isinstance(value, dict) else {})
else:
config_updates[field] = value
# Apply configuration updates as a single operation
if config_updates:
set_clauses.append(f"configuration = configuration || ${param_idx}::jsonb")
params.append(json.dumps(config_updates))
param_idx += 1
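# Illustrative effect of the jsonb merge above: an update payload of
# {"tags": ["finance"], "is_favorite": true} becomes
#   SET configuration = configuration || '{"tags": ["finance"], "is_favorite": true}'::jsonb
# so only the listed top-level keys are overwritten and all other configuration keys are preserved.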
if not set_clauses:
logger.warning(f"No valid update fields provided for agent {agent_id}")
return await self.get_agent(agent_id)
# Add updated_at timestamp
set_clauses.append(f"updated_at = NOW()")
# Resolve user UUID with caching (performance optimization)
try:
user_id = await self._get_resolved_user_uuid()
except ValueError as e:
logger.warning(f"User not found for update: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}: {e}")
return None
# Check if user is admin - admins can update any agent
is_admin = user_role in ["admin", "developer"]
# Build final query - admins can update any agent in tenant, others only their own
if is_admin:
query = f"""
UPDATE agents
SET {', '.join(set_clauses)}
WHERE id = ${param_idx}
AND tenant_id = (SELECT id FROM tenants WHERE domain = ${param_idx + 1})
RETURNING id
"""
params.extend([agent_id, self.tenant_domain])
else:
query = f"""
UPDATE agents
SET {', '.join(set_clauses)}
WHERE id = ${param_idx}
AND tenant_id = (SELECT id FROM tenants WHERE domain = ${param_idx + 1})
AND created_by = ${param_idx + 2}
RETURNING id
"""
params.extend([agent_id, self.tenant_domain, user_id])
# Execute update
logger.info(f"Executing update query: {query}")
logger.info(f"Query parameters: {params}")
updated_id = await pg_client.fetch_scalar(query, *params)
logger.info(f"Update result: {updated_id}")
if updated_id:
# Get updated agent data
updated_agent = await self.get_agent(agent_id)
logger.info(f"Updated agent {agent_id} in PostgreSQL")
return updated_agent
return None
except Exception as e:
logger.error(f"Error updating agent {agent_id}: {e}")
return None
async def delete_agent(self, agent_id: str) -> bool:
"""Soft delete an agent using PostgreSQL"""
try:
# Get PostgreSQL client
pg_client = await get_postgresql_client()
# Get user role to check if admin
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
is_admin = user_role in ["admin", "developer"]
# Soft delete in PostgreSQL - admins can delete any agent, others only their own
if is_admin:
query = """
UPDATE agents
SET is_active = false, updated_at = NOW()
WHERE id = $1
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
RETURNING id
"""
deleted_id = await pg_client.fetch_scalar(query, agent_id, self.tenant_domain)
else:
query = """
UPDATE agents
SET is_active = false, updated_at = NOW()
WHERE id = $1
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
AND created_by = (SELECT id FROM users WHERE email = $3)
RETURNING id
"""
deleted_id = await pg_client.fetch_scalar(query, agent_id, self.tenant_domain, self.user_email or self.user_id)
if deleted_id:
logger.info(f"Deleted agent {agent_id} from PostgreSQL")
return True
return False
except Exception as e:
logger.error(f"Error deleting agent {agent_id}: {e}")
return False
async def check_access_permission(self, agent_id: str, requesting_user_id: str, access_type: str = "read") -> bool:
"""
Check if user has access to agent (via ownership, organization, or team).
Args:
agent_id: UUID of the agent
requesting_user_id: UUID of the user requesting access
access_type: 'read' or 'edit' (default: 'read')
Returns:
True if user has required access
"""
try:
pg_client = await get_postgresql_client()
# Check if admin/developer
user_role = await get_user_role(pg_client, requesting_user_id, self.tenant_domain)
if user_role in ["admin", "developer"]:
return True
# Get agent to check ownership and visibility
query = """
SELECT created_by, visibility
FROM agents
WHERE id = $1 AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
"""
agent_data = await pg_client.fetch_one(query, agent_id, self.tenant_domain)
if not agent_data:
return False
owner_id = str(agent_data["created_by"])
visibility = agent_data["visibility"]
# Owner has full access
if requesting_user_id == owner_id:
return True
# Organization-wide resources are accessible to all in tenant
if visibility == "organization":
return True
# Check team-based access
from app.services.team_service import TeamService
team_service = TeamService(self.tenant_domain, requesting_user_id, requesting_user_id)
return await team_service.check_user_resource_permission(
user_id=requesting_user_id,
resource_type="agent",
resource_id=agent_id,
required_permission=access_type
)
except Exception as e:
logger.error(f"Error checking access permission for agent {agent_id}: {e}")
return False
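# Illustrative cascade (assuming a service instance named `svc` for this tenant):
#   await svc.check_access_permission(agent_id, user_id, "edit")
# returns True for admin/developer roles, for the agent owner, and for any tenant user
# when visibility == "organization"; otherwise access is decided by TeamService with
# the requested permission level ("read" or "edit").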
async def _check_team_membership(self, user_id: str, team_members: List[str]) -> bool:
"""Check if user is in the team members list"""
return user_id in team_members
async def _check_same_tenant(self, user_id: str) -> bool:
"""Check if requesting user is in the same tenant through PostgreSQL"""
try:
pg_client = await get_postgresql_client()
# Check if user exists in same tenant
query = """
SELECT COUNT(*) as count
FROM users
WHERE id = $1 AND tenant_id = (SELECT id FROM tenants WHERE domain = $2)
"""
result = await pg_client.fetch_one(query, user_id, self.tenant_domain)
return result and result["count"] > 0
except Exception as e:
logger.error(f"Failed to check tenant membership for user {user_id}: {e}")
return False
def get_agent_conversation_history(self, agent_id: str) -> List[Dict[str, Any]]:
"""Get conversation history for an agent (file-based)"""
conversations_path = Path(f"/data/{self.tenant_domain}/users/{self.user_id}/conversations")
conversations_path.mkdir(parents=True, exist_ok=True, mode=0o700)
conversations = []
try:
for conv_file in conversations_path.glob("*.json"):
with open(conv_file, 'r') as f:
conv_data = json.load(f)
if conv_data.get("agent_id") == agent_id:
conversations.append(conv_data)
except Exception as e:
logger.error(f"Error reading conversations for agent {agent_id}: {e}")
conversations.sort(key=lambda x: x.get("updated_at", ""), reverse=True)
return conversations
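# Each conversation JSON file is expected to carry at least an "agent_id" key (used for
# the filter above) and an "updated_at" ISO timestamp (used for sorting); files without
# "updated_at" sort to the end of the list.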

View File

@@ -0,0 +1,493 @@
"""
Assistant Builder Service for GT 2.0
Manages assistant creation, deployment, and lifecycle.
Integrates with template library and file-based storage.
"""
import os
import json
import stat
from typing import List, Optional, Dict, Any
from datetime import datetime
from pathlib import Path
import logging
from app.models.assistant_template import (
AssistantTemplate, AssistantInstance, AssistantBuilder,
AssistantType, PersonalityConfig, ResourcePreferences, MemorySettings,
AssistantTemplateLibrary, BUILTIN_TEMPLATES
)
from app.models.access_group import AccessGroup
from app.core.security import verify_capability_token
from app.services.access_controller import AccessController
logger = logging.getLogger(__name__)
class AssistantBuilderService:
"""
Service for building and managing assistants
Handles both template-based and custom assistant creation
"""
def __init__(self, tenant_domain: str, resource_cluster_url: str = "http://resource-cluster:8004"):
self.tenant_domain = tenant_domain
self.base_path = Path(f"/data/{tenant_domain}/assistants")
self.template_library = AssistantTemplateLibrary(resource_cluster_url)
self.access_controller = AccessController(tenant_domain)
self._ensure_directories()
def _ensure_directories(self):
"""Ensure assistant directories exist with proper permissions"""
self.base_path.mkdir(parents=True, exist_ok=True)
os.chmod(self.base_path, stat.S_IRWXU) # 700
# Create subdirectories
for subdir in ["templates", "instances", "shared"]:
path = self.base_path / subdir
path.mkdir(exist_ok=True)
os.chmod(path, stat.S_IRWXU) # 700
async def create_from_template(
self,
template_id: str,
user_id: str,
instance_name: str,
customizations: Optional[Dict[str, Any]] = None,
capability_token: str = None
) -> AssistantInstance:
"""
Create assistant instance from template
Args:
template_id: Template to use
user_id: User creating the assistant
instance_name: Name for the instance
customizations: Optional customizations
capability_token: JWT capability token
Returns:
Created assistant instance
"""
# Verify capability token
if capability_token:
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Deploy from template
instance = await self.template_library.deploy_template(
template_id=template_id,
user_id=user_id,
instance_name=instance_name,
tenant_domain=self.tenant_domain,
customizations=customizations
)
# Create file storage
await self._create_assistant_files(instance)
# Save to database (would be SQLite in production)
await self._save_assistant(instance)
logger.info(f"Created assistant {instance.id} from template {template_id} for {user_id}")
return instance
async def create_custom(
self,
builder_config: AssistantBuilder,
user_id: str,
capability_token: str = None
) -> AssistantInstance:
"""
Create custom assistant from builder configuration
Args:
builder_config: Custom assistant configuration
user_id: User creating the assistant
capability_token: JWT capability token
Returns:
Created assistant instance
"""
# Verify capability token
if capability_token:
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Check if user has required capabilities
user_capabilities = token_data.get("capabilities", []) if capability_token else []
for required_cap in builder_config.requested_capabilities:
if not any(required_cap in cap.get("resource", "") for cap in user_capabilities):
raise PermissionError(f"Missing capability: {required_cap}")
# Build instance
instance = builder_config.build_instance(user_id, self.tenant_domain)
# Create file storage
await self._create_assistant_files(instance)
# Save to database
await self._save_assistant(instance)
logger.info(f"Created custom assistant {instance.id} for {user_id}")
return instance
async def get_assistant(
self,
assistant_id: str,
user_id: str
) -> Optional[AssistantInstance]:
"""
Get assistant instance by ID
Args:
assistant_id: Assistant ID
user_id: User requesting the assistant
Returns:
Assistant instance if found and accessible
"""
# Load assistant
instance = await self._load_assistant(assistant_id)
if not instance:
return None
# Check access permission
allowed, _ = await self.access_controller.check_permission(
user_id, instance, "read"
)
if not allowed:
return None
return instance
async def list_user_assistants(
self,
user_id: str,
include_shared: bool = True
) -> List[AssistantInstance]:
"""
List all assistants accessible to user
Args:
user_id: User to list assistants for
include_shared: Include team/org shared assistants
Returns:
List of accessible assistants
"""
assistants = []
# Get owned assistants
owned = await self._get_owned_assistants(user_id)
assistants.extend(owned)
# Get shared assistants if requested
if include_shared:
shared = await self._get_shared_assistants(user_id)
assistants.extend(shared)
return assistants
async def update_assistant(
self,
assistant_id: str,
user_id: str,
updates: Dict[str, Any]
) -> AssistantInstance:
"""
Update assistant configuration
Args:
assistant_id: Assistant to update
user_id: User requesting update
updates: Configuration updates
Returns:
Updated assistant instance
"""
# Load assistant
instance = await self._load_assistant(assistant_id)
if not instance:
raise ValueError(f"Assistant not found: {assistant_id}")
# Check permission
if instance.owner_id != user_id:
raise PermissionError("Only owner can update assistant")
# Apply updates
if "personality" in updates:
instance.personality_config = PersonalityConfig(**updates["personality"])
if "resources" in updates:
instance.resource_preferences = ResourcePreferences(**updates["resources"])
if "memory" in updates:
instance.memory_settings = MemorySettings(**updates["memory"])
if "system_prompt" in updates:
instance.system_prompt = updates["system_prompt"]
instance.updated_at = datetime.utcnow()
# Save changes
await self._save_assistant(instance)
await self._update_assistant_files(instance)
logger.info(f"Updated assistant {assistant_id} by {user_id}")
return instance
async def share_assistant(
self,
assistant_id: str,
user_id: str,
access_group: AccessGroup,
team_members: Optional[List[str]] = None
) -> AssistantInstance:
"""
Share assistant with team or organization
Args:
assistant_id: Assistant to share
user_id: User sharing (must be owner)
access_group: New access level
team_members: Team members if team access
Returns:
Updated assistant instance
"""
# Load assistant
instance = await self._load_assistant(assistant_id)
if not instance:
raise ValueError(f"Assistant not found: {assistant_id}")
# Check ownership
if instance.owner_id != user_id:
raise PermissionError("Only owner can share assistant")
# Update access
instance.access_group = access_group
if access_group == AccessGroup.TEAM:
instance.team_members = team_members or []
else:
instance.team_members = []
instance.updated_at = datetime.utcnow()
# Save changes
await self._save_assistant(instance)
logger.info(f"Shared assistant {assistant_id} with {access_group.value} by {user_id}")
return instance
async def delete_assistant(
self,
assistant_id: str,
user_id: str
) -> bool:
"""
Delete assistant and its files
Args:
assistant_id: Assistant to delete
user_id: User requesting deletion
Returns:
True if deleted
"""
# Load assistant
instance = await self._load_assistant(assistant_id)
if not instance:
return False
# Check ownership
if instance.owner_id != user_id:
raise PermissionError("Only owner can delete assistant")
# Delete files
await self._delete_assistant_files(instance)
# Delete from database
await self._delete_assistant_record(assistant_id)
logger.info(f"Deleted assistant {assistant_id} by {user_id}")
return True
async def get_assistant_statistics(
self,
assistant_id: str,
user_id: str
) -> Dict[str, Any]:
"""
Get usage statistics for assistant
Args:
assistant_id: Assistant ID
user_id: User requesting stats
Returns:
Statistics dictionary
"""
# Load assistant
instance = await self.get_assistant(assistant_id, user_id)
if not instance:
raise ValueError(f"Assistant not found or not accessible: {assistant_id}")
return {
"assistant_id": assistant_id,
"name": instance.name,
"created_at": instance.created_at.isoformat(),
"last_used": instance.last_used.isoformat() if instance.last_used else None,
"conversation_count": instance.conversation_count,
"total_messages": instance.total_messages,
"total_tokens_used": instance.total_tokens_used,
"access_group": instance.access_group.value,
"team_members_count": len(instance.team_members),
"linked_datasets_count": len(instance.linked_datasets),
"linked_tools_count": len(instance.linked_tools)
}
async def _create_assistant_files(self, instance: AssistantInstance):
"""Create file structure for assistant"""
# Get file paths
file_structure = instance.get_file_structure()
# Create directories
for key, path in file_structure.items():
if key in ["memory", "resources"]:
# These are directories
Path(path).mkdir(parents=True, exist_ok=True)
os.chmod(Path(path), stat.S_IRWXU) # 700
else:
# These are files
parent = Path(path).parent
parent.mkdir(parents=True, exist_ok=True)
os.chmod(parent, stat.S_IRWXU) # 700
# Save configuration
config_path = Path(file_structure["config"])
config_data = {
"id": instance.id,
"name": instance.name,
"template_id": instance.template_id,
"personality": instance.personality_config.model_dump(),
"resources": instance.resource_preferences.model_dump(),
"memory": instance.memory_settings.model_dump(),
"created_at": instance.created_at.isoformat(),
"updated_at": instance.updated_at.isoformat()
}
with open(config_path, 'w') as f:
json.dump(config_data, f, indent=2)
os.chmod(config_path, stat.S_IRUSR | stat.S_IWUSR) # 600
# Save prompt
prompt_path = Path(file_structure["prompt"])
with open(prompt_path, 'w') as f:
f.write(instance.system_prompt)
os.chmod(prompt_path, stat.S_IRUSR | stat.S_IWUSR) # 600
# Save capabilities
capabilities_path = Path(file_structure["capabilities"])
with open(capabilities_path, 'w') as f:
json.dump(instance.capabilities, f, indent=2)
os.chmod(capabilities_path, stat.S_IRUSR | stat.S_IWUSR) # 600
# Update instance with file paths
instance.config_file_path = str(config_path)
instance.memory_file_path = str(Path(file_structure["memory"]))
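# Resulting on-disk layout (exact paths come from instance.get_file_structure(), defined
# in the assistant template model): "config", "prompt" and "capabilities" are written as
# files with 600 permissions, while "memory" and "resources" are created as directories
# with 700 permissions.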
async def _update_assistant_files(self, instance: AssistantInstance):
"""Update assistant files with current configuration"""
if instance.config_file_path:
config_data = {
"id": instance.id,
"name": instance.name,
"template_id": instance.template_id,
"personality": instance.personality_config.model_dump(),
"resources": instance.resource_preferences.model_dump(),
"memory": instance.memory_settings.model_dump(),
"created_at": instance.created_at.isoformat(),
"updated_at": instance.updated_at.isoformat()
}
with open(instance.config_file_path, 'w') as f:
json.dump(config_data, f, indent=2)
async def _delete_assistant_files(self, instance: AssistantInstance):
"""Delete assistant file structure"""
file_structure = instance.get_file_structure()
base_dir = Path(file_structure["config"]).parent
if base_dir.exists():
import shutil
shutil.rmtree(base_dir)
logger.info(f"Deleted assistant files at {base_dir}")
async def _save_assistant(self, instance: AssistantInstance):
"""Save assistant to database (SQLite in production)"""
# This would save to SQLite database
# For now, we'll save to a JSON file as placeholder
db_file = self.base_path / "instances" / f"{instance.id}.json"
with open(db_file, 'w') as f:
json.dump(instance.model_dump(mode='json'), f, indent=2, default=str)
os.chmod(db_file, stat.S_IRUSR | stat.S_IWUSR) # 600
async def _load_assistant(self, assistant_id: str) -> Optional[AssistantInstance]:
"""Load assistant from database"""
db_file = self.base_path / "instances" / f"{assistant_id}.json"
if not db_file.exists():
return None
with open(db_file, 'r') as f:
data = json.load(f)
# Convert datetime strings back to datetime objects
for field in ['created_at', 'updated_at', 'last_used']:
if field in data and data[field]:
data[field] = datetime.fromisoformat(data[field])
return AssistantInstance(**data)
async def _delete_assistant_record(self, assistant_id: str):
"""Delete assistant from database"""
db_file = self.base_path / "instances" / f"{assistant_id}.json"
if db_file.exists():
db_file.unlink()
async def _get_owned_assistants(self, user_id: str) -> List[AssistantInstance]:
"""Get assistants owned by user"""
assistants = []
instances_dir = self.base_path / "instances"
if instances_dir.exists():
for file in instances_dir.glob("*.json"):
instance = await self._load_assistant(file.stem)
if instance and instance.owner_id == user_id:
assistants.append(instance)
return assistants
async def _get_shared_assistants(self, user_id: str) -> List[AssistantInstance]:
"""Get assistants shared with user"""
assistants = []
instances_dir = self.base_path / "instances"
if instances_dir.exists():
for file in instances_dir.glob("*.json"):
instance = await self._load_assistant(file.stem)
if instance and instance.owner_id != user_id:
# Check if user has access
allowed, _ = await self.access_controller.check_permission(
user_id, instance, "read"
)
if allowed:
assistants.append(instance)
return assistants

View File

@@ -0,0 +1,599 @@
"""
AssistantManager Service for GT 2.0 Tenant Backend
File-based agent lifecycle management with perfect tenant isolation.
Implements the core Agent System specification from CLAUDE.md.
"""
import os
import json
import asyncio
from datetime import datetime
from typing import Dict, Any, List, Optional, Union
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc
from sqlalchemy.orm import selectinload
import logging
from app.models.agent import Agent
from app.models.conversation import Conversation
from app.models.message import Message
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class AssistantManager:
"""File-based agent lifecycle management"""
def __init__(self, db: AsyncSession):
self.db = db
self.settings = get_settings()
async def create_from_template(self, template_id: str, config: Dict[str, Any], user_identifier: str) -> str:
"""Create agent from template or custom config"""
try:
# Get template configuration
template_config = await self._load_template_config(template_id)
# Merge template config with user overrides
merged_config = {**template_config, **config}
# Create agent record
agent = Agent(
name=merged_config.get("name", f"Agent from {template_id}"),
description=merged_config.get("description", f"Created from template: {template_id}"),
template_id=template_id,
created_by=user_identifier,
user_name=merged_config.get("user_name"),
personality_config=merged_config.get("personality_config", {}),
resource_preferences=merged_config.get("resource_preferences", {}),
memory_settings=merged_config.get("memory_settings", {}),
tags=merged_config.get("tags", []),
)
# Initialize with placeholder paths first
agent.config_file_path = "placeholder"
agent.prompt_file_path = "placeholder"
agent.capabilities_file_path = "placeholder"
# Save to database first to get ID and UUID
self.db.add(agent)
await self.db.flush() # Flush to get the generated UUID without committing
# Now we can initialize proper file paths with the UUID
agent.initialize_file_paths()
# Create file system structure
await self._setup_assistant_files(agent, merged_config)
# Commit all changes
await self.db.commit()
await self.db.refresh(agent)
logger.info(
f"Created agent from template",
extra={
"agent_id": agent.id,
"assistant_uuid": agent.uuid,
"template_id": template_id,
"created_by": user_identifier,
}
)
return str(agent.uuid)
except Exception as e:
logger.error(f"Failed to create agent from template: {e}", exc_info=True)
await self.db.rollback()
raise
async def create_custom_assistant(self, config: Dict[str, Any], user_identifier: str) -> str:
"""Create custom agent without template"""
try:
# Validate required fields
if not config.get("name"):
raise ValueError("Agent name is required")
# Create agent record
agent = Agent(
name=config["name"],
description=config.get("description", "Custom AI agent"),
template_id=None, # No template used
created_by=user_identifier,
user_name=config.get("user_name"),
personality_config=config.get("personality_config", {}),
resource_preferences=config.get("resource_preferences", {}),
memory_settings=config.get("memory_settings", {}),
tags=config.get("tags", []),
)
# Initialize with placeholder paths first
agent.config_file_path = "placeholder"
agent.prompt_file_path = "placeholder"
agent.capabilities_file_path = "placeholder"
# Save to database first to get ID and UUID
self.db.add(agent)
await self.db.flush() # Flush to get the generated UUID without committing
# Now we can initialize proper file paths with the UUID
agent.initialize_file_paths()
# Create file system structure
await self._setup_assistant_files(agent, config)
# Commit all changes
await self.db.commit()
await self.db.refresh(agent)
logger.info(
f"Created custom agent",
extra={
"agent_id": agent.id,
"assistant_uuid": agent.uuid,
"created_by": user_identifier,
}
)
return str(agent.uuid)
except Exception as e:
logger.error(f"Failed to create custom agent: {e}", exc_info=True)
await self.db.rollback()
raise
async def get_assistant_config(self, assistant_uuid: str, user_identifier: str) -> Dict[str, Any]:
"""Get complete agent configuration including file-based data"""
try:
# Get agent from database
result = await self.db.execute(
select(Agent).where(
and_(
Agent.uuid == assistant_uuid,
Agent.created_by == user_identifier,
Agent.is_active == True
)
)
)
agent = result.scalar_one_or_none()
if not agent:
raise ValueError(f"Agent not found: {assistant_uuid}")
# Load complete configuration
return agent.get_full_configuration()
except Exception as e:
logger.error(f"Failed to get agent config: {e}", exc_info=True)
raise
async def list_user_assistants(
self,
user_identifier: str,
include_archived: bool = False,
template_id: Optional[str] = None,
search: Optional[str] = None,
limit: int = 50,
offset: int = 0
) -> List[Dict[str, Any]]:
"""List user's agents with filtering options"""
try:
# Build base query
query = select(Agent).where(Agent.created_by == user_identifier)
# Apply filters
if not include_archived:
query = query.where(Agent.is_active == True)
if template_id:
query = query.where(Agent.template_id == template_id)
if search:
search_term = f"%{search}%"
query = query.where(
or_(
Agent.name.ilike(search_term),
Agent.description.ilike(search_term)
)
)
# Apply ordering and pagination
query = query.order_by(desc(Agent.last_used_at), desc(Agent.created_at))
query = query.limit(limit).offset(offset)
result = await self.db.execute(query)
agents = result.scalars().all()
return [agent.to_dict() for agent in agents]
except Exception as e:
logger.error(f"Failed to list user agents: {e}", exc_info=True)
raise
async def count_user_assistants(
self,
user_identifier: str,
include_archived: bool = False,
template_id: Optional[str] = None,
search: Optional[str] = None
) -> int:
"""Count user's agents matching criteria"""
try:
# Build base query
query = select(func.count(Agent.id)).where(Agent.created_by == user_identifier)
# Apply filters
if not include_archived:
query = query.where(Agent.is_active == True)
if template_id:
query = query.where(Agent.template_id == template_id)
if search:
search_term = f"%{search}%"
query = query.where(
or_(
Agent.name.ilike(search_term),
Agent.description.ilike(search_term)
)
)
result = await self.db.execute(query)
return result.scalar() or 0
except Exception as e:
logger.error(f"Failed to count user agents: {e}", exc_info=True)
raise
async def update_assistant(self, agent_id: str, updates: Dict[str, Any], user_identifier: str) -> bool:
"""Update agent configuration (renamed from update_configuration)"""
return await self.update_configuration(agent_id, updates, user_identifier)
async def update_configuration(self, assistant_uuid: str, updates: Dict[str, Any], user_identifier: str) -> bool:
"""Update agent configuration"""
try:
# Get agent
result = await self.db.execute(
select(Agent).where(
and_(
Agent.uuid == assistant_uuid,
Agent.created_by == user_identifier,
Agent.is_active == True
)
)
)
agent = result.scalar_one_or_none()
if not agent:
raise ValueError(f"Agent not found: {assistant_uuid}")
# Update database fields
if "name" in updates:
agent.name = updates["name"]
if "description" in updates:
agent.description = updates["description"]
if "personality_config" in updates:
agent.personality_config = updates["personality_config"]
if "resource_preferences" in updates:
agent.resource_preferences = updates["resource_preferences"]
if "memory_settings" in updates:
agent.memory_settings = updates["memory_settings"]
if "tags" in updates:
agent.tags = updates["tags"]
# Update file-based configurations
if "config" in updates:
agent.save_config_to_file(updates["config"])
if "prompt" in updates:
agent.save_prompt_to_file(updates["prompt"])
if "capabilities" in updates:
agent.save_capabilities_to_file(updates["capabilities"])
agent.updated_at = datetime.utcnow()
await self.db.commit()
logger.info(
f"Updated agent configuration",
extra={
"assistant_uuid": assistant_uuid,
"updated_fields": list(updates.keys()),
}
)
return True
except Exception as e:
logger.error(f"Failed to update agent configuration: {e}", exc_info=True)
await self.db.rollback()
raise
async def clone_assistant(self, source_uuid: str, new_name: str, user_identifier: str, modifications: Dict[str, Any] = None) -> str:
"""Clone existing agent with modifications"""
try:
# Get source agent
result = await self.db.execute(
select(Agent).where(
and_(
Agent.uuid == source_uuid,
Agent.created_by == user_identifier,
Agent.is_active == True
)
)
)
source_assistant = result.scalar_one_or_none()
if not source_assistant:
raise ValueError(f"Source agent not found: {source_uuid}")
# Clone agent
cloned_assistant = source_assistant.clone(new_name, user_identifier, modifications or {})
# Initialize with placeholder paths first
cloned_assistant.config_file_path = "placeholder"
cloned_assistant.prompt_file_path = "placeholder"
cloned_assistant.capabilities_file_path = "placeholder"
# Save to database first to get UUID
self.db.add(cloned_assistant)
await self.db.flush() # Flush to get the generated UUID
# Initialize proper file paths with UUID
cloned_assistant.initialize_file_paths()
# Copy and modify files
await self._clone_assistant_files(source_assistant, cloned_assistant, modifications or {})
# Commit all changes
await self.db.commit()
await self.db.refresh(cloned_assistant)
logger.info(
f"Cloned agent",
extra={
"source_uuid": source_uuid,
"new_uuid": cloned_assistant.uuid,
"new_name": new_name,
}
)
return str(cloned_assistant.uuid)
except Exception as e:
logger.error(f"Failed to clone agent: {e}", exc_info=True)
await self.db.rollback()
raise
async def archive_assistant(self, assistant_uuid: str, user_identifier: str) -> bool:
"""Archive agent (soft delete)"""
try:
result = await self.db.execute(
select(Agent).where(
and_(
Agent.uuid == assistant_uuid,
Agent.created_by == user_identifier
)
)
)
agent = result.scalar_one_or_none()
if not agent:
raise ValueError(f"Agent not found: {assistant_uuid}")
agent.archive()
await self.db.commit()
logger.info(
f"Archived agent",
extra={"assistant_uuid": assistant_uuid}
)
return True
except Exception as e:
logger.error(f"Failed to archive agent: {e}", exc_info=True)
await self.db.rollback()
raise
async def get_assistant_statistics(self, assistant_uuid: str, user_identifier: str) -> Dict[str, Any]:
"""Get usage statistics for agent"""
try:
result = await self.db.execute(
select(Agent).where(
and_(
Agent.uuid == assistant_uuid,
Agent.created_by == user_identifier,
Agent.is_active == True
)
)
)
agent = result.scalar_one_or_none()
if not agent:
raise ValueError(f"Agent not found: {assistant_uuid}")
# Get conversation statistics
conv_result = await self.db.execute(
select(func.count(Conversation.id))
.where(Conversation.agent_id == agent.id)
)
conversation_count = conv_result.scalar() or 0
# Get message statistics
msg_result = await self.db.execute(
select(
func.count(Message.id),
func.sum(Message.tokens_used),
func.sum(Message.cost_cents)
)
.join(Conversation, Message.conversation_id == Conversation.id)
.where(Conversation.agent_id == agent.id)
)
message_stats = msg_result.first()
return {
"agent_id": assistant_uuid, # Use agent_id to match schema
"name": agent.name,
"created_at": agent.created_at, # Return datetime object, not ISO string
"last_used_at": agent.last_used_at, # Return datetime object, not ISO string
"conversation_count": conversation_count,
"total_messages": message_stats[0] or 0,
"total_tokens_used": message_stats[1] or 0,
"total_cost_cents": message_stats[2] or 0,
"total_cost_dollars": (message_stats[2] or 0) / 100.0,
"average_tokens_per_message": (
(message_stats[1] or 0) / max(1, message_stats[0] or 1)
),
"is_favorite": agent.is_favorite,
"tags": agent.tags,
}
except Exception as e:
logger.error(f"Failed to get agent statistics: {e}", exc_info=True)
raise
# Private helper methods
async def _load_template_config(self, template_id: str) -> Dict[str, Any]:
"""Load template configuration from Resource Cluster or built-in templates"""
# Built-in templates (as specified in CLAUDE.md)
builtin_templates = {
"research_assistant": {
"name": "Research & Analysis Agent",
"description": "Specialized in information synthesis and analysis",
"prompt": """You are a research agent specialized in information synthesis and analysis.
Focus on providing well-sourced, analytical responses with clear reasoning.""",
"personality_config": {
"tone": "balanced",
"explanation_depth": "expert",
"interaction_style": "collaborative"
},
"resource_preferences": {
"primary_llm": "groq:llama3-70b-8192",
"temperature": 0.7,
"max_tokens": 4000
},
"capabilities": [
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
{"resource": "rag:semantic_search", "actions": ["search"], "limits": {}},
{"resource": "tools:web_search", "actions": ["search"], "limits": {"requests_per_hour": 50}},
{"resource": "export:citations", "actions": ["create"], "limits": {}}
]
},
"coding_assistant": {
"name": "Software Development Agent",
"description": "Focused on code quality and best practices",
"prompt": """You are a software development agent focused on code quality and best practices.
Provide clear explanations, suggest improvements, and help debug issues.""",
"personality_config": {
"tone": "direct",
"explanation_depth": "intermediate",
"interaction_style": "teaching"
},
"resource_preferences": {
"primary_llm": "groq:llama3-70b-8192",
"temperature": 0.3,
"max_tokens": 4000
},
"capabilities": [
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
{"resource": "tools:github_integration", "actions": ["read"], "limits": {}},
{"resource": "resources:documentation", "actions": ["search"], "limits": {}},
{"resource": "export:code_snippets", "actions": ["create"], "limits": {}}
]
},
"cyber_analyst": {
"name": "Cybersecurity Analysis Agent",
"description": "For threat detection and response analysis",
"prompt": """You are a cybersecurity analyst agent for threat detection and response.
Prioritize security best practices and provide actionable recommendations.""",
"personality_config": {
"tone": "formal",
"explanation_depth": "expert",
"interaction_style": "direct"
},
"resource_preferences": {
"primary_llm": "groq:llama3-70b-8192",
"temperature": 0.2,
"max_tokens": 4000
},
"capabilities": [
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 4000}},
{"resource": "tools:security_scanning", "actions": ["analyze"], "limits": {}},
{"resource": "resources:threat_intelligence", "actions": ["search"], "limits": {}},
{"resource": "export:security_reports", "actions": ["create"], "limits": {}}
]
},
"educational_tutor": {
"name": "AI Literacy Educational Agent",
"description": "Develops critical thinking and AI literacy",
"prompt": """You are an educational agent focused on developing critical thinking and AI literacy.
Use Socratic questioning and encourage deep analysis of problems.""",
"personality_config": {
"tone": "casual",
"explanation_depth": "beginner",
"interaction_style": "teaching"
},
"resource_preferences": {
"primary_llm": "groq:llama3-70b-8192",
"temperature": 0.8,
"max_tokens": 3000
},
"capabilities": [
{"resource": "llm:groq", "actions": ["inference"], "limits": {"max_tokens_per_request": 3000}},
{"resource": "games:strategic_thinking", "actions": ["play"], "limits": {}},
{"resource": "puzzles:logic_reasoning", "actions": ["present"], "limits": {}},
{"resource": "analytics:learning_progress", "actions": ["track"], "limits": {}}
]
}
}
if template_id in builtin_templates:
return builtin_templates[template_id]
# TODO: In the future, load from Resource Cluster Agent Library
# For now, return empty config for unknown templates
logger.warning(f"Unknown template ID: {template_id}")
return {
"name": f"Agent ({template_id})",
"description": "Custom agent",
"prompt": "You are a helpful AI agent.",
"capabilities": []
}
async def _setup_assistant_files(self, agent: Agent, config: Dict[str, Any]) -> None:
"""Create file system structure for agent"""
# Ensure directory exists
agent.ensure_directory_exists()
# Save configuration files
agent.save_config_to_file(config)
agent.save_prompt_to_file(config.get("prompt", "You are a helpful AI agent."))
agent.save_capabilities_to_file(config.get("capabilities", []))
logger.info(f"Created agent files for {agent.uuid}")
async def _clone_assistant_files(self, source: Agent, target: Agent, modifications: Dict[str, Any]) -> None:
"""Clone agent files with modifications"""
# Load source configurations
source_config = source.load_config_from_file()
source_prompt = source.load_prompt_from_file()
source_capabilities = source.load_capabilities_from_file()
# Apply modifications
target_config = {**source_config, **modifications.get("config", {})}
target_prompt = modifications.get("prompt", source_prompt)
target_capabilities = modifications.get("capabilities", source_capabilities)
# Create target files
target.ensure_directory_exists()
target.save_config_to_file(target_config)
target.save_prompt_to_file(target_prompt)
target.save_capabilities_to_file(target_capabilities)
logger.info(f"Cloned agent files from {source.uuid} to {target.uuid}")
async def get_assistant_manager(db: AsyncSession) -> AssistantManager:
"""Get AssistantManager instance"""
return AssistantManager(db)

View File

@@ -0,0 +1,632 @@
"""
Automation Chain Executor
Executes automation chains with configurable depth, capability-based limits,
and comprehensive error handling.
"""
import asyncio
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import json
from app.services.event_bus import Event, Automation, TriggerType, TenantEventBus
from app.core.security import verify_capability_token
from app.core.path_security import sanitize_tenant_domain
logger = logging.getLogger(__name__)
class ChainDepthExceeded(Exception):
"""Raised when automation chain depth exceeds limit"""
pass
class AutomationTimeout(Exception):
"""Raised when automation execution times out"""
pass
@dataclass
class ExecutionContext:
"""Context for automation execution"""
automation_id: str
chain_depth: int = 0
parent_automation_id: Optional[str] = None
start_time: Optional[datetime] = None
execution_history: Optional[List[Dict[str, Any]]] = None
variables: Optional[Dict[str, Any]] = None
def __post_init__(self):
if self.start_time is None:
self.start_time = datetime.utcnow()
if self.execution_history is None:
self.execution_history = []
if self.variables is None:
self.variables = {}
def add_execution(self, action: str, result: Any, duration_ms: float):
"""Add execution record to history"""
self.execution_history.append({
"action": action,
"result": result,
"duration_ms": duration_ms,
"timestamp": datetime.utcnow().isoformat()
})
def get_total_duration(self) -> float:
"""Get total execution duration in milliseconds"""
return (datetime.utcnow() - self.start_time).total_seconds() * 1000
class AutomationChainExecutor:
"""
Execute automation chains with configurable depth and capability-based limits.
Features:
- Configurable max chain depth per tenant
- Retry logic with exponential backoff
- Comprehensive error handling
- Execution history tracking
- Variable passing between chain steps
"""
def __init__(
self,
tenant_domain: str,
event_bus: TenantEventBus,
base_path: Optional[Path] = None
):
self.tenant_domain = tenant_domain
self.event_bus = event_bus
# Sanitize tenant_domain to prevent path traversal
safe_tenant = sanitize_tenant_domain(tenant_domain)
self.base_path = base_path or (Path("/data") / safe_tenant / "automations")
self.execution_path = self.base_path / "executions"
self.running_chains: Dict[str, ExecutionContext] = {}
# Ensure directories exist
self._ensure_directories()
logger.info(f"AutomationChainExecutor initialized for {tenant_domain}")
def _ensure_directories(self):
"""Ensure execution directories exist with proper permissions"""
import os
import stat
# codeql[py/path-injection] execution_path derived from sanitize_tenant_domain() at line 86
self.execution_path.mkdir(parents=True, exist_ok=True)
os.chmod(self.execution_path, stat.S_IRWXU) # 700 permissions
async def execute_chain(
self,
automation: Automation,
event: Event,
capability_token: str,
current_depth: int = 0
) -> Any:
"""
Execute automation chain with depth control.
Args:
automation: Automation to execute
event: Triggering event
capability_token: JWT capability token
current_depth: Current chain depth
Returns:
Execution result
Raises:
ChainDepthExceeded: If chain depth exceeds limit
AutomationTimeout: If execution times out
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data:
raise ValueError("Invalid capability token")
# Get max chain depth from capability token (tenant-specific)
max_depth = self._get_constraint(token_data, "max_automation_chain_depth", 5)
# Check depth limit
if current_depth >= max_depth:
raise ChainDepthExceeded(
f"Chain depth {current_depth} exceeds limit {max_depth}"
)
# Create execution context
context = ExecutionContext(
automation_id=automation.id,
chain_depth=current_depth,
parent_automation_id=event.metadata.get("parent_automation_id")
)
# Track running chain
self.running_chains[automation.id] = context
try:
# Execute automation with timeout
timeout = self._get_constraint(token_data, "automation_timeout_seconds", 300)
result = await asyncio.wait_for(
self._execute_automation(automation, event, context, token_data),
timeout=timeout
)
# If this automation triggers chain
if automation.triggers_chain:
await self._trigger_chain_automations(
automation,
result,
capability_token,
current_depth
)
# Store execution history
await self._store_execution(context, result)
return result
except asyncio.TimeoutError:
raise AutomationTimeout(
f"Automation {automation.id} timed out after {timeout} seconds"
)
finally:
# Remove from running chains
self.running_chains.pop(automation.id, None)
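# Chain limits come from the token's "constraints" map. For example, a token carrying
# {"constraints": {"max_automation_chain_depth": 3, "automation_timeout_seconds": 120}}
# caps nesting at depth 3 and each automation at 120 seconds; missing constraints fall
# back to the defaults of 5 and 300 used above.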
async def _execute_automation(
self,
automation: Automation,
event: Event,
context: ExecutionContext,
token_data: Dict[str, Any]
) -> Any:
"""Execute automation with retry logic"""
results = []
retry_count = 0
max_retries = min(automation.max_retries, 5) # Cap at 5 retries
while retry_count <= max_retries:
try:
# Execute each action
for action in automation.actions:
start_time = datetime.utcnow()
# Check if action is allowed by capabilities
if not self._is_action_allowed(action, token_data):
logger.warning(f"Action {action.get('type')} not allowed by capabilities")
continue
# Execute action with context
result = await self._execute_action(action, event, context, token_data)
# Track execution
duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
context.add_execution(action.get("type"), result, duration_ms)
results.append(result)
# Update variables for next action
if isinstance(result, dict) and "variables" in result:
context.variables.update(result["variables"])
# Success - break retry loop
break
except Exception as e:
retry_count += 1
if retry_count > max_retries:
logger.error(f"Automation {automation.id} failed after {max_retries} retries: {e}")
raise
# Exponential backoff
wait_time = min(2 ** retry_count, 30) # Max 30 seconds
logger.info(f"Retrying automation {automation.id} in {wait_time} seconds...")
await asyncio.sleep(wait_time)
return {
"automation_id": automation.id,
"results": results,
"context": {
"chain_depth": context.chain_depth,
"total_duration_ms": context.get_total_duration(),
"variables": context.variables
}
}
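# Backoff example: with max_retries=3 the waits after successive failures are
# min(2**1, 30)=2s, min(2**2, 30)=4s and min(2**3, 30)=8s before the final attempt;
# the 30-second cap is first reached on the fifth retry (2**5 = 32 -> 30).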
async def _execute_action(
self,
action: Dict[str, Any],
event: Event,
context: ExecutionContext,
token_data: Dict[str, Any]
) -> Any:
"""Execute a single action with capability constraints"""
action_type = action.get("type")
if action_type == "api_call":
return await self._execute_api_call(action, context, token_data)
elif action_type == "data_transform":
return await self._execute_data_transform(action, context)
elif action_type == "conditional":
return await self._execute_conditional(action, context)
elif action_type == "loop":
return await self._execute_loop(action, event, context, token_data)
elif action_type == "wait":
return await self._execute_wait(action)
elif action_type == "variable_set":
return await self._execute_variable_set(action, context)
else:
# Delegate to event bus for standard actions
return await self.event_bus._execute_action(action, event, None)
async def _execute_api_call(
self,
action: Dict[str, Any],
context: ExecutionContext,
token_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute API call action with rate limiting"""
endpoint = action.get("endpoint")
method = action.get("method", "GET")
headers = action.get("headers", {})
body = action.get("body")
# Apply variable substitution
if body and context.variables:
body = self._substitute_variables(body, context.variables)
# Check rate limits
rate_limit = self._get_constraint(token_data, "api_calls_per_minute", 60)
# In production, implement actual rate limiting
logger.info(f"Mock API call: {method} {endpoint}")
# Mock response
return {
"status": 200,
"data": {"message": "Mock API response"},
"headers": {"content-type": "application/json"}
}
async def _execute_data_transform(
self,
action: Dict[str, Any],
context: ExecutionContext
) -> Dict[str, Any]:
"""Execute data transformation action"""
transform_type = action.get("transform_type")
source = action.get("source")
target = action.get("target")
# Get source data from context
source_data = context.variables.get(source)
if transform_type == "json_parse":
result = json.loads(source_data) if isinstance(source_data, str) else source_data
elif transform_type == "json_stringify":
result = json.dumps(source_data)
elif transform_type == "extract":
path = action.get("path", "")
result = self._extract_path(source_data, path)
elif transform_type == "map":
mapping = action.get("mapping", {})
result = {k: self._extract_path(source_data, v) for k, v in mapping.items()}
else:
result = source_data
# Store result in context
context.variables[target] = result
return {
"transform_type": transform_type,
"target": target,
"variables": {target: result}
}
async def _execute_conditional(
self,
action: Dict[str, Any],
context: ExecutionContext
) -> Dict[str, Any]:
"""Execute conditional action"""
condition = action.get("condition")
then_actions = action.get("then", [])
else_actions = action.get("else", [])
# Evaluate condition
if self._evaluate_condition(condition, context.variables):
actions_to_execute = then_actions
branch = "then"
else:
actions_to_execute = else_actions
branch = "else"
# Execute branch actions
results = []
for sub_action in actions_to_execute:
result = await self._execute_action(sub_action, None, context, {})
results.append(result)
return {
"condition": condition,
"branch": branch,
"results": results
}
async def _execute_loop(
self,
action: Dict[str, Any],
event: Event,
context: ExecutionContext,
token_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute loop action with iteration limit"""
items = action.get("items", [])
variable = action.get("variable", "item")
loop_actions = action.get("actions", [])
# Get max iterations from capabilities
max_iterations = self._get_constraint(token_data, "max_loop_iterations", 100)
# Resolve items from context if it's a variable reference
if isinstance(items, str) and items.startswith("$"):
items = context.variables.get(items[1:], [])
# Limit iterations
items = items[:max_iterations]
results = []
for item in items:
# Set loop variable
context.variables[variable] = item
# Execute loop actions
for loop_action in loop_actions:
result = await self._execute_action(loop_action, event, context, token_data)
results.append(result)
return {
"loop_count": len(items),
"results": results
}
async def _execute_wait(self, action: Dict[str, Any]) -> Dict[str, Any]:
"""Execute wait action"""
duration = action.get("duration", 1)
max_wait = 60 # Maximum 60 seconds wait
duration = min(duration, max_wait)
await asyncio.sleep(duration)
return {
"waited": duration,
"unit": "seconds"
}
async def _execute_variable_set(
self,
action: Dict[str, Any],
context: ExecutionContext
) -> Dict[str, Any]:
"""Set variables in context"""
variables = action.get("variables", {})
for key, value in variables.items():
# Substitute existing variables in value
if isinstance(value, str):
value = self._substitute_variables(value, context.variables)
context.variables[key] = value
return {
"variables": variables
}
async def _trigger_chain_automations(
self,
automation: Automation,
result: Any,
capability_token: str,
current_depth: int
):
"""Trigger chained automations"""
for target_id in automation.chain_targets:
# Load target automation
target_automation = await self.event_bus.get_automation(target_id)
if not target_automation:
logger.warning(f"Chain target automation {target_id} not found")
continue
# Create chain event
chain_event = Event(
type=TriggerType.CHAIN.value,
tenant=self.tenant_domain,
user=automation.owner_id,
data=result,
metadata={
"parent_automation_id": automation.id,
"chain_depth": current_depth + 1
}
)
# Execute chained automation
try:
await self.execute_chain(
target_automation,
chain_event,
capability_token,
current_depth + 1
)
except ChainDepthExceeded:
logger.warning(f"Chain depth exceeded for automation {target_id}")
except Exception as e:
logger.error(f"Error executing chained automation {target_id}: {e}")
def _get_constraint(
self,
token_data: Dict[str, Any],
constraint_name: str,
default: Any
) -> Any:
"""Get constraint value from capability token"""
constraints = token_data.get("constraints", {})
return constraints.get(constraint_name, default)
def _is_action_allowed(
self,
action: Dict[str, Any],
token_data: Dict[str, Any]
) -> bool:
"""Check if action is allowed by capabilities"""
action_type = action.get("type")
# Check specific action capabilities
capabilities = token_data.get("capabilities", [])
# Map action types to required capabilities
required_capabilities = {
"api_call": "automation:api_calls",
"webhook": "automation:webhooks",
"email": "automation:email",
"data_transform": "automation:data_processing",
"conditional": "automation:logic",
"loop": "automation:logic"
}
required = required_capabilities.get(action_type)
if not required:
return True # Allow unknown actions by default
# Check if capability exists
return any(
cap.get("resource") == required
for cap in capabilities
)
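# Example: an action of {"type": "api_call", ...} runs only if the token's capabilities
# include an entry whose "resource" is "automation:api_calls"; action types not listed in
# the mapping above (e.g. "wait" or "variable_set") are allowed by default.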
def _substitute_variables(
self,
template: Any,
variables: Dict[str, Any]
) -> Any:
"""Substitute variables in template"""
if not isinstance(template, str):
return template
# Simple variable substitution
for key, value in variables.items():
template = template.replace(f"${{{key}}}", str(value))
template = template.replace(f"${key}", str(value))
return template
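# Example: _substitute_variables("Hello ${name}", {"name": "Ada"}) returns "Hello Ada";
# the bare form is handled too, so "$name-report" becomes "Ada-report".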
def _extract_path(self, data: Any, path: str) -> Any:
"""Extract value from nested data using path"""
if not path:
return data
parts = path.split(".")
current = data
for part in parts:
if isinstance(current, dict):
current = current.get(part)
elif isinstance(current, list) and part.isdigit():
index = int(part)
if 0 <= index < len(current):
current = current[index]
else:
return None
else:
return None
return current
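# Example: _extract_path({"items": [{"name": "alpha"}, {"name": "beta"}]}, "items.1.name")
# returns "beta"; missing keys or out-of-range list indices yield None.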
def _evaluate_condition(
self,
condition: Dict[str, Any],
variables: Dict[str, Any]
) -> bool:
"""Evaluate condition against variables"""
left = condition.get("left")
operator = condition.get("operator")
right = condition.get("right")
# Resolve variables
if isinstance(left, str) and left.startswith("$"):
left = variables.get(left[1:])
if isinstance(right, str) and right.startswith("$"):
right = variables.get(right[1:])
# Evaluate
try:
if operator == "equals":
return left == right
elif operator == "not_equals":
return left != right
elif operator == "greater_than":
return float(left) > float(right)
elif operator == "less_than":
return float(left) < float(right)
elif operator == "contains":
return right in left
elif operator == "exists":
return left is not None
elif operator == "not_exists":
return left is None
else:
return False
except (ValueError, TypeError):
return False
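# Example: with variables {"status": "open", "count": 7}, the condition
# {"left": "$count", "operator": "greater_than", "right": 5} evaluates to True, while
# {"left": "$status", "operator": "equals", "right": "closed"} evaluates to False.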
async def _store_execution(
self,
context: ExecutionContext,
result: Any
):
"""Store execution history to file system"""
execution_record = {
"automation_id": context.automation_id,
"chain_depth": context.chain_depth,
"parent_automation_id": context.parent_automation_id,
"start_time": context.start_time.isoformat(),
"total_duration_ms": context.get_total_duration(),
"execution_history": context.execution_history,
"variables": context.variables,
"result": result if isinstance(result, (dict, list, str, int, float, bool)) else str(result)
}
# Create execution file
execution_file = self.execution_path / f"{context.automation_id}_{context.start_time.strftime('%Y%m%d_%H%M%S')}.json"
with open(execution_file, "w") as f:
json.dump(execution_record, f, indent=2)
async def get_execution_history(
self,
automation_id: Optional[str] = None,
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get execution history for automations"""
executions = []
# Get all execution files
pattern = f"{automation_id}_*.json" if automation_id else "*.json"
for execution_file in sorted(
self.execution_path.glob(pattern),
key=lambda x: x.stat().st_mtime,
reverse=True
)[:limit]:
try:
with open(execution_file, "r") as f:
executions.append(json.load(f))
except Exception as e:
logger.error(f"Error loading execution {execution_file}: {e}")
return executions

View File

@@ -0,0 +1,514 @@
"""
Category Service for GT 2.0 Tenant Backend
Provides tenant-scoped agent category management with permission-based
editing and deletion. Supports Issue #215 requirements.
Permission Model:
- Admins/developers can edit/delete ANY category
- Regular users can only edit/delete categories they created
- All users can view and use all tenant categories
"""
import uuid
import re
from typing import Dict, List, Optional, Any
from datetime import datetime
from app.core.config import get_settings
from app.core.postgresql_client import get_postgresql_client
from app.core.permissions import get_user_role
import logging
logger = logging.getLogger(__name__)
# Admin roles that can manage all categories
ADMIN_ROLES = ["admin", "developer"]
class CategoryService:
"""GT 2.0 Category Management Service with Tenant Isolation"""
def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
"""Initialize with tenant and user isolation using PostgreSQL storage"""
self.tenant_domain = tenant_domain
self.user_id = user_id
self.user_email = user_email or user_id
self.settings = get_settings()
logger.info(f"Category service initialized for {tenant_domain}/{user_id}")
def _generate_slug(self, name: str) -> str:
"""Generate URL-safe slug from category name"""
# Convert to lowercase, replace non-alphanumeric with hyphens
slug = re.sub(r'[^a-zA-Z0-9]+', '-', name.lower())
# Remove leading/trailing hyphens
slug = slug.strip('-')
return slug or 'category'
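    # Example (illustrative): _generate_slug("Data & Analytics!") returns
    # "data-analytics"; a name with no alphanumeric characters falls back to "category".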
async def _get_user_id(self, pg_client) -> str:
"""Get user UUID from email/username/uuid with tenant isolation"""
identifier = self.user_email
user_lookup_query = """
SELECT id FROM users
WHERE (email = $1 OR id::text = $1 OR username = $1)
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
LIMIT 1
"""
user_id = await pg_client.fetch_scalar(user_lookup_query, identifier, self.tenant_domain)
if not user_id:
user_id = await pg_client.fetch_scalar(user_lookup_query, self.user_id, self.tenant_domain)
if not user_id:
raise RuntimeError(f"User not found: {identifier} in tenant {self.tenant_domain}")
return str(user_id)
async def _get_tenant_id(self, pg_client) -> str:
"""Get tenant UUID from domain"""
query = "SELECT id FROM tenants WHERE domain = $1 LIMIT 1"
tenant_id = await pg_client.fetch_scalar(query, self.tenant_domain)
if not tenant_id:
raise RuntimeError(f"Tenant not found: {self.tenant_domain}")
return str(tenant_id)
async def _can_manage_category(self, pg_client, category: Dict) -> tuple:
"""
Check if current user can manage (edit/delete) a category.
Returns (can_edit, can_delete) tuple.
"""
# Get user role
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
is_admin = user_role in ADMIN_ROLES
# Get current user ID
current_user_id = await self._get_user_id(pg_client)
# Admins can manage all categories
if is_admin:
return (True, True)
# Check if user created this category
created_by = category.get('created_by')
if created_by and str(created_by) == current_user_id:
return (True, True)
# Regular users cannot manage other users' categories or defaults
return (False, False)
async def get_all_categories(self) -> List[Dict[str, Any]]:
"""
Get all active categories for the tenant.
Returns categories with permission flags for current user.
"""
try:
pg_client = await get_postgresql_client()
user_id = await self._get_user_id(pg_client)
user_role = await get_user_role(pg_client, self.user_email, self.tenant_domain)
is_admin = user_role in ADMIN_ROLES
query = """
SELECT
c.id, c.name, c.slug, c.description, c.icon,
c.is_default, c.created_by, c.sort_order,
c.created_at, c.updated_at,
u.full_name as created_by_name
FROM categories c
LEFT JOIN users u ON c.created_by = u.id
WHERE c.tenant_id = (SELECT id FROM tenants WHERE domain = $1 LIMIT 1)
AND c.is_deleted = FALSE
ORDER BY c.sort_order ASC, c.name ASC
"""
rows = await pg_client.execute_query(query, self.tenant_domain)
categories = []
for row in rows:
# Determine permissions
can_edit = False
can_delete = False
if is_admin:
can_edit = True
can_delete = True
elif row.get('created_by') and str(row['created_by']) == user_id:
can_edit = True
can_delete = True
categories.append({
"id": str(row["id"]),
"name": row["name"],
"slug": row["slug"],
"description": row.get("description"),
"icon": row.get("icon"),
"is_default": row.get("is_default", False),
"created_by": str(row["created_by"]) if row.get("created_by") else None,
"created_by_name": row.get("created_by_name"),
"can_edit": can_edit,
"can_delete": can_delete,
"sort_order": row.get("sort_order", 0),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
})
logger.info(f"Retrieved {len(categories)} categories for tenant {self.tenant_domain}")
return categories
except Exception as e:
logger.error(f"Error getting categories: {e}")
raise
async def get_category_by_id(self, category_id: str) -> Optional[Dict[str, Any]]:
"""Get a single category by ID"""
try:
pg_client = await get_postgresql_client()
query = """
SELECT
c.id, c.name, c.slug, c.description, c.icon,
c.is_default, c.created_by, c.sort_order,
c.created_at, c.updated_at,
u.full_name as created_by_name
FROM categories c
LEFT JOIN users u ON c.created_by = u.id
WHERE c.id = $1::uuid
AND c.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
AND c.is_deleted = FALSE
"""
row = await pg_client.fetch_one(query, category_id, self.tenant_domain)
if not row:
return None
can_edit, can_delete = await self._can_manage_category(pg_client, dict(row))
return {
"id": str(row["id"]),
"name": row["name"],
"slug": row["slug"],
"description": row.get("description"),
"icon": row.get("icon"),
"is_default": row.get("is_default", False),
"created_by": str(row["created_by"]) if row.get("created_by") else None,
"created_by_name": row.get("created_by_name"),
"can_edit": can_edit,
"can_delete": can_delete,
"sort_order": row.get("sort_order", 0),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
}
except Exception as e:
logger.error(f"Error getting category {category_id}: {e}")
raise
async def get_category_by_slug(self, slug: str) -> Optional[Dict[str, Any]]:
"""Get a single category by slug"""
try:
pg_client = await get_postgresql_client()
query = """
SELECT
c.id, c.name, c.slug, c.description, c.icon,
c.is_default, c.created_by, c.sort_order,
c.created_at, c.updated_at,
u.full_name as created_by_name
FROM categories c
LEFT JOIN users u ON c.created_by = u.id
WHERE c.slug = $1
AND c.tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
AND c.is_deleted = FALSE
"""
row = await pg_client.fetch_one(query, slug.lower(), self.tenant_domain)
if not row:
return None
can_edit, can_delete = await self._can_manage_category(pg_client, dict(row))
return {
"id": str(row["id"]),
"name": row["name"],
"slug": row["slug"],
"description": row.get("description"),
"icon": row.get("icon"),
"is_default": row.get("is_default", False),
"created_by": str(row["created_by"]) if row.get("created_by") else None,
"created_by_name": row.get("created_by_name"),
"can_edit": can_edit,
"can_delete": can_delete,
"sort_order": row.get("sort_order", 0),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
}
except Exception as e:
logger.error(f"Error getting category by slug {slug}: {e}")
raise
async def create_category(
self,
name: str,
description: Optional[str] = None,
icon: Optional[str] = None
) -> Dict[str, Any]:
"""
Create a new custom category.
The creating user becomes the owner and can edit/delete it.
"""
try:
pg_client = await get_postgresql_client()
user_id = await self._get_user_id(pg_client)
tenant_id = await self._get_tenant_id(pg_client)
# Generate slug
slug = self._generate_slug(name)
# Check if slug already exists
existing = await self.get_category_by_slug(slug)
if existing:
raise ValueError(f"A category with name '{name}' already exists")
# Generate category ID
category_id = str(uuid.uuid4())
# Get next sort_order (after all existing categories)
sort_query = """
SELECT COALESCE(MAX(sort_order), 0) + 10 as next_order
FROM categories
WHERE tenant_id = $1::uuid
"""
next_order = await pg_client.fetch_scalar(sort_query, tenant_id)
# Create category
query = """
INSERT INTO categories (
id, tenant_id, name, slug, description, icon,
is_default, created_by, sort_order, is_deleted,
created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3, $4, $5, $6,
FALSE, $7::uuid, $8, FALSE,
NOW(), NOW()
)
RETURNING id, name, slug, description, icon, is_default,
created_by, sort_order, created_at, updated_at
"""
row = await pg_client.fetch_one(
query,
category_id, tenant_id, name, slug, description, icon,
user_id, next_order
)
if not row:
raise RuntimeError("Failed to create category")
logger.info(f"Created category {category_id}: {name} for user {user_id}")
# Get creator name
user_query = "SELECT full_name FROM users WHERE id = $1::uuid"
created_by_name = await pg_client.fetch_scalar(user_query, user_id)
return {
"id": str(row["id"]),
"name": row["name"],
"slug": row["slug"],
"description": row.get("description"),
"icon": row.get("icon"),
"is_default": False,
"created_by": user_id,
"created_by_name": created_by_name,
"can_edit": True,
"can_delete": True,
"sort_order": row.get("sort_order", 0),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
"updated_at": row["updated_at"].isoformat() if row.get("updated_at") else None,
}
except ValueError:
raise
except Exception as e:
logger.error(f"Error creating category: {e}")
raise
async def update_category(
self,
category_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
icon: Optional[str] = None
) -> Dict[str, Any]:
"""
Update a category.
Requires permission (admin or category creator).
"""
try:
pg_client = await get_postgresql_client()
# Get existing category
existing = await self.get_category_by_id(category_id)
if not existing:
raise ValueError("Category not found")
# Check permissions
can_edit, _ = await self._can_manage_category(pg_client, existing)
if not can_edit:
raise PermissionError("You do not have permission to edit this category")
# Build update fields
updates = []
params = [category_id, self.tenant_domain]
param_idx = 3
if name is not None:
new_slug = self._generate_slug(name)
# Check if new slug conflicts with another category
slug_check = await self.get_category_by_slug(new_slug)
if slug_check and slug_check["id"] != category_id:
raise ValueError(f"A category with name '{name}' already exists")
updates.append(f"name = ${param_idx}")
params.append(name)
param_idx += 1
updates.append(f"slug = ${param_idx}")
params.append(new_slug)
param_idx += 1
if description is not None:
updates.append(f"description = ${param_idx}")
params.append(description)
param_idx += 1
if icon is not None:
updates.append(f"icon = ${param_idx}")
params.append(icon)
param_idx += 1
if not updates:
return existing
updates.append("updated_at = NOW()")
query = f"""
UPDATE categories
SET {', '.join(updates)}
WHERE id = $1::uuid
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
AND is_deleted = FALSE
RETURNING id
"""
result = await pg_client.fetch_scalar(query, *params)
if not result:
raise RuntimeError("Failed to update category")
logger.info(f"Updated category {category_id}")
# Return updated category
return await self.get_category_by_id(category_id)
except (ValueError, PermissionError):
raise
except Exception as e:
logger.error(f"Error updating category {category_id}: {e}")
raise
async def delete_category(self, category_id: str) -> bool:
"""
Soft delete a category.
Requires permission (admin or category creator).
"""
try:
pg_client = await get_postgresql_client()
# Get existing category
existing = await self.get_category_by_id(category_id)
if not existing:
raise ValueError("Category not found")
# Check permissions
_, can_delete = await self._can_manage_category(pg_client, existing)
if not can_delete:
raise PermissionError("You do not have permission to delete this category")
# Soft delete
query = """
UPDATE categories
SET is_deleted = TRUE, updated_at = NOW()
WHERE id = $1::uuid
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
"""
await pg_client.execute_command(query, category_id, self.tenant_domain)
logger.info(f"Deleted category {category_id}")
return True
except (ValueError, PermissionError):
raise
except Exception as e:
logger.error(f"Error deleting category {category_id}: {e}")
raise
async def get_or_create_category(
self,
slug: str,
description: Optional[str] = None
) -> Dict[str, Any]:
"""
Get existing category by slug or create it if not exists.
Used for agent import to auto-create missing categories.
If the category was soft-deleted, it will be restored.
Args:
slug: Category slug (lowercase, hyphenated)
description: Optional description for new/restored categories
"""
try:
# Try to get existing active category
existing = await self.get_category_by_slug(slug)
if existing:
return existing
# Check if there's a soft-deleted category with this slug
pg_client = await get_postgresql_client()
deleted_query = """
SELECT id FROM categories
WHERE slug = $1
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
AND is_deleted = TRUE
"""
deleted_id = await pg_client.fetch_scalar(deleted_query, slug.lower(), self.tenant_domain)
if deleted_id:
# Restore the soft-deleted category
user_id = await self._get_user_id(pg_client)
restore_query = """
UPDATE categories
SET is_deleted = FALSE,
updated_at = NOW(),
created_by = $3::uuid
WHERE id = $1::uuid
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
"""
await pg_client.execute_command(restore_query, str(deleted_id), self.tenant_domain, user_id)
logger.info(f"Restored soft-deleted category: {slug}")
# Return the restored category
return await self.get_category_by_slug(slug)
# Auto-create with importing user as creator
name = slug.replace('-', ' ').title()
return await self.create_category(
name=name,
description=description, # Use provided description or None
icon=None
)
except Exception as e:
logger.error(f"Error in get_or_create_category for slug {slug}: {e}")
raise
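
# Minimal usage sketch (illustrative): exercises CategoryService against a running
# PostgreSQL instance. The tenant domain and user email below are placeholder
# assumptions, not values taken from this repository.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        service = CategoryService(
            tenant_domain="example.ai",
            user_id="user-123",
            user_email="user@example.ai",
        )
        created = await service.create_category(
            name="Research Agents",
            description="Agents focused on literature review",
        )
        print("created:", created["id"], created["slug"])
        categories = await service.get_all_categories()
        print(f"{len(categories)} categories visible to this user")

    asyncio.run(_demo())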

View File

@@ -0,0 +1,563 @@
"""
Conversation File Service for GT 2.0
Handles conversation-scoped file attachments as a simpler alternative to dataset-based uploads.
Preserves all existing dataset infrastructure while providing direct conversation file storage.
"""
import os
import uuid
import logging
import asyncio
from pathlib import Path
from typing import Dict, Any, List, Optional
from datetime import datetime
from fastapi import UploadFile, HTTPException
from app.core.config import get_settings
from app.core.postgresql_client import get_postgresql_client
from app.core.path_security import sanitize_tenant_domain
from app.services.embedding_client import BGE_M3_EmbeddingClient
from app.services.document_processor import DocumentProcessor
logger = logging.getLogger(__name__)
class ConversationFileService:
"""Service for managing conversation-scoped file attachments"""
def __init__(self, tenant_domain: str, user_id: str):
self.tenant_domain = tenant_domain
self.user_id = user_id
self.settings = get_settings()
self.schema_name = f"tenant_{tenant_domain.replace('.', '_').replace('-', '_')}"
# File storage configuration
# Sanitize tenant_domain to prevent path traversal
safe_tenant = sanitize_tenant_domain(tenant_domain)
# codeql[py/path-injection] safe_tenant validated by sanitize_tenant_domain()
self.storage_root = Path(self.settings.file_storage_path) / safe_tenant / "conversations"
self.storage_root.mkdir(parents=True, exist_ok=True)
logger.info(f"ConversationFileService initialized for {tenant_domain}/{user_id}")
def _get_conversation_storage_path(self, conversation_id: str) -> Path:
"""Get storage directory for conversation files"""
conv_path = self.storage_root / conversation_id
conv_path.mkdir(parents=True, exist_ok=True)
return conv_path
def _generate_safe_filename(self, original_filename: str, file_id: str) -> str:
"""Generate safe filename for storage"""
# Sanitize filename and prepend file ID
safe_name = "".join(c for c in original_filename if c.isalnum() or c in ".-_")
return f"{file_id}-{safe_name}"
async def upload_files(
self,
conversation_id: str,
files: List[UploadFile],
user_id: str
) -> List[Dict[str, Any]]:
"""Upload files directly to conversation"""
try:
# Validate conversation access
await self._validate_conversation_access(conversation_id, user_id)
uploaded_files = []
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File must have a filename")
# Generate file metadata
file_id = str(uuid.uuid4())
safe_filename = self._generate_safe_filename(file.filename, file_id)
conversation_path = self._get_conversation_storage_path(conversation_id)
file_path = conversation_path / safe_filename
# Store file to disk
content = await file.read()
with open(file_path, "wb") as f:
f.write(content)
# Create database record
file_record = await self._create_file_record(
file_id=file_id,
conversation_id=conversation_id,
original_filename=file.filename,
safe_filename=safe_filename,
content_type=file.content_type or "application/octet-stream",
file_size=len(content),
file_path=str(file_path.relative_to(Path(self.settings.file_storage_path))),
uploaded_by=user_id
)
uploaded_files.append(file_record)
# Queue for background processing
asyncio.create_task(self._process_file_embeddings(file_id))
logger.info(f"Uploaded conversation file: {file.filename} -> {file_id}")
return uploaded_files
except Exception as e:
logger.error(f"Failed to upload conversation files: {e}")
raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
async def _get_user_uuid(self, user_email: str) -> str:
"""Resolve user email to UUID"""
client = await get_postgresql_client()
query = f"SELECT id FROM {self.schema_name}.users WHERE email = $1 LIMIT 1"
result = await client.fetch_one(query, user_email)
if not result:
raise ValueError(f"User not found: {user_email}")
return str(result['id'])
async def _create_file_record(
self,
file_id: str,
conversation_id: str,
original_filename: str,
safe_filename: str,
content_type: str,
file_size: int,
file_path: str,
uploaded_by: str
) -> Dict[str, Any]:
"""Create conversation_files database record"""
client = await get_postgresql_client()
# Resolve user email to UUID if needed
user_uuid = uploaded_by
if '@' in uploaded_by: # Check if it's an email
user_uuid = await self._get_user_uuid(uploaded_by)
query = f"""
INSERT INTO {self.schema_name}.conversation_files (
id, conversation_id, filename, original_filename, content_type,
file_size_bytes, file_path, uploaded_by, processing_status
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'pending')
RETURNING id, filename, original_filename, content_type, file_size_bytes,
processing_status, uploaded_at
"""
result = await client.fetch_one(
query,
file_id, conversation_id, safe_filename, original_filename,
content_type, file_size, file_path, user_uuid
)
# Convert UUID fields to strings for JSON serialization
result_dict = dict(result)
if 'id' in result_dict and result_dict['id']:
result_dict['id'] = str(result_dict['id'])
return result_dict
async def _process_file_embeddings(self, file_id: str):
"""Background task to process file content and generate embeddings"""
try:
# Update status to processing
await self._update_processing_status(file_id, "processing")
# Get file record
file_record = await self._get_file_record(file_id)
if not file_record:
logger.error(f"File record not found: {file_id}")
return
# Read file content
file_path = Path(self.settings.file_storage_path) / file_record['file_path']
if not file_path.exists():
logger.error(f"File not found on disk: {file_path}")
await self._update_processing_status(file_id, "failed")
return
# Extract text content using DocumentProcessor public methods
processor = DocumentProcessor()
text_content = await processor.extract_text_from_path(
file_path,
file_record['content_type']
)
if not text_content:
logger.warning(f"No text content extracted from {file_record['original_filename']}")
await self._update_processing_status(file_id, "completed")
return
# Chunk content for RAG
chunks = await processor.chunk_text_simple(text_content)
# Generate embeddings for full document (single embedding for semantic search)
embedding_client = BGE_M3_EmbeddingClient()
embeddings = await embedding_client.generate_embeddings([text_content])
if not embeddings:
logger.error(f"Failed to generate embeddings for {file_id}")
await self._update_processing_status(file_id, "failed")
return
# Update record with processed content (chunks as JSONB, embedding as vector)
await self._update_file_processing_results(
file_id, chunks, embeddings[0], "completed"
)
logger.info(f"Successfully processed file: {file_record['original_filename']}")
except Exception as e:
logger.error(f"Failed to process file {file_id}: {e}")
await self._update_processing_status(file_id, "failed")
async def _update_processing_status(self, file_id: str, status: str):
"""Update file processing status"""
client = await get_postgresql_client()
query = f"""
UPDATE {self.schema_name}.conversation_files
SET processing_status = $1,
processed_at = CASE WHEN $1 IN ('completed', 'failed') THEN NOW() ELSE processed_at END
WHERE id = $2
"""
await client.execute_query(query, status, file_id)
async def _update_file_processing_results(
self,
file_id: str,
chunks: List[str],
embedding: List[float],
status: str
):
"""Update file with processing results"""
import json
client = await get_postgresql_client()
# Sanitize chunks: remove null bytes and other control characters
# that PostgreSQL can't handle in JSONB
sanitized_chunks = [
chunk.replace('\u0000', '').replace('\x00', '')
for chunk in chunks
]
# Convert chunks list to JSONB-compatible format
chunks_json = json.dumps(sanitized_chunks)
# Convert embedding to PostgreSQL vector format
embedding_str = f"[{','.join(map(str, embedding))}]"
query = f"""
UPDATE {self.schema_name}.conversation_files
SET processed_chunks = $1::jsonb,
embeddings = $2::vector,
processing_status = $3,
processed_at = NOW()
WHERE id = $4
"""
await client.execute_query(query, chunks_json, embedding_str, status, file_id)
async def _get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
"""Get file record by ID"""
client = await get_postgresql_client()
query = f"""
SELECT id, conversation_id, filename, original_filename, content_type,
file_size_bytes, file_path, processing_status, uploaded_at
FROM {self.schema_name}.conversation_files
WHERE id = $1
"""
result = await client.fetch_one(query, file_id)
return dict(result) if result else None
async def list_files(self, conversation_id: str) -> List[Dict[str, Any]]:
"""List files attached to conversation"""
try:
client = await get_postgresql_client()
query = f"""
SELECT id, filename, original_filename, content_type, file_size_bytes,
processing_status, uploaded_at, processed_at
FROM {self.schema_name}.conversation_files
WHERE conversation_id = $1
ORDER BY uploaded_at DESC
"""
rows = await client.execute_query(query, conversation_id)
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"Failed to list conversation files: {e}")
return []
async def delete_file(self, conversation_id: str, file_id: str, user_id: str, allow_post_message_deletion: bool = False) -> bool:
"""Delete specific file from conversation
Args:
conversation_id: The conversation ID
file_id: The file ID to delete
user_id: The user requesting deletion
allow_post_message_deletion: If False, prevents deletion after messages exist (default: False)
"""
try:
logger.info(f"DELETE FILE CALLED: file_id={file_id}, conversation_id={conversation_id}, user_id={user_id}")
# Validate access
await self._validate_conversation_access(conversation_id, user_id)
logger.info(f"DELETE FILE: Access validated")
# Check if conversation has messages (unless explicitly allowed to delete post-message)
if not allow_post_message_deletion:
client = await get_postgresql_client()
message_check_query = f"""
SELECT COUNT(*) as count
FROM {self.schema_name}.messages
WHERE conversation_id = $1
"""
message_count_result = await client.fetch_one(message_check_query, conversation_id)
message_count = message_count_result['count'] if message_count_result else 0
if message_count > 0:
from fastapi import HTTPException
raise HTTPException(
status_code=400,
detail="Cannot delete files after conversation has started. Files are part of the conversation context."
)
# Get file record for cleanup
file_record = await self._get_file_record(file_id)
logger.info(f"DELETE FILE: file_record={file_record}")
if not file_record or str(file_record['conversation_id']) != conversation_id:
logger.warning(f"DELETE FILE FAILED: file not found or conversation mismatch. file_record={file_record}, expected_conv_id={conversation_id}")
return False
# Delete from database
client = await get_postgresql_client()
query = f"""
DELETE FROM {self.schema_name}.conversation_files
WHERE id = $1 AND conversation_id = $2
"""
rows_deleted = await client.execute_command(query, file_id, conversation_id)
if rows_deleted > 0:
# Delete file from disk
file_path = Path(self.settings.file_storage_path) / file_record['file_path']
if file_path.exists():
file_path.unlink()
logger.info(f"Deleted conversation file: {file_id}")
return True
return False
except HTTPException:
raise # Re-raise HTTPException to preserve status code and message
except Exception as e:
logger.error(f"Failed to delete conversation file: {e}")
return False
async def search_conversation_files(
self,
conversation_id: str,
query: str,
max_results: int = 5
) -> List[Dict[str, Any]]:
"""Search files within a conversation using vector similarity"""
try:
# Generate query embedding
embedding_client = BGE_M3_EmbeddingClient()
embeddings = await embedding_client.generate_embeddings([query])
if not embeddings:
return []
query_embedding = embeddings[0]
# Convert embedding to PostgreSQL vector format
embedding_str = '[' + ','.join(map(str, query_embedding)) + ']'
# Vector search against conversation files
client = await get_postgresql_client()
search_query = f"""
SELECT id, filename, original_filename, processed_chunks,
1 - (embeddings <=> $1::vector) as similarity_score
FROM {self.schema_name}.conversation_files
WHERE conversation_id = $2
AND processing_status = 'completed'
AND embeddings IS NOT NULL
AND 1 - (embeddings <=> $1::vector) > 0.1
ORDER BY embeddings <=> $1::vector
LIMIT $3
"""
rows = await client.execute_query(
search_query, embedding_str, conversation_id, max_results
)
results = []
for row in rows:
processed_chunks = row.get('processed_chunks', [])
if not processed_chunks:
continue
# Handle case where processed_chunks might be returned as JSON string
if isinstance(processed_chunks, str):
import json
processed_chunks = json.loads(processed_chunks)
for idx, chunk_text in enumerate(processed_chunks):
results.append({
'id': f"{row['id']}_chunk_{idx}",
'document_id': row['id'],
'document_name': row['original_filename'],
'original_filename': row['original_filename'],
'chunk_index': idx,
'content': chunk_text,
'similarity_score': row['similarity_score'],
'source': 'conversation_file',
'source_type': 'conversation_file'
})
if len(results) >= max_results:
results = results[:max_results]
break
logger.info(f"Found {len(results)} chunks from {len(rows)} matching conversation files")
return results
except Exception as e:
logger.error(f"Failed to search conversation files: {e}")
return []
async def get_all_chunks_for_conversation(
self,
conversation_id: str,
max_chunks_per_file: int = 50,
max_total_chunks: int = 100
) -> List[Dict[str, Any]]:
"""
Retrieve ALL chunks from files attached to conversation.
Non-query-dependent - returns everything up to limits.
Args:
conversation_id: UUID of conversation
max_chunks_per_file: Limit per file (enforces diversity)
max_total_chunks: Total chunk limit across all files
Returns:
List of chunks with metadata, grouped by file
"""
try:
client = await get_postgresql_client()
query = f"""
SELECT id, filename, original_filename, processed_chunks,
file_size_bytes, uploaded_at
FROM {self.schema_name}.conversation_files
WHERE conversation_id = $1
AND processing_status = 'completed'
AND processed_chunks IS NOT NULL
ORDER BY uploaded_at ASC
"""
rows = await client.execute_query(query, conversation_id)
results = []
total_chunks = 0
for row in rows:
if total_chunks >= max_total_chunks:
break
processed_chunks = row.get('processed_chunks', [])
# Handle JSON string if needed
if isinstance(processed_chunks, str):
import json
processed_chunks = json.loads(processed_chunks)
# Limit chunks per file (diversity enforcement)
chunks_from_this_file = 0
for idx, chunk_text in enumerate(processed_chunks):
if chunks_from_this_file >= max_chunks_per_file:
break
if total_chunks >= max_total_chunks:
break
results.append({
'id': f"{row['id']}_chunk_{idx}",
'document_id': row['id'],
'document_name': row['original_filename'],
'original_filename': row['original_filename'],
'chunk_index': idx,
'total_chunks': len(processed_chunks),
'content': chunk_text,
'file_size_bytes': row['file_size_bytes'],
'source': 'conversation_file',
'source_type': 'conversation_file'
})
chunks_from_this_file += 1
total_chunks += 1
logger.info(f"Retrieved {len(results)} total chunks from {len(rows)} conversation files")
return results
except Exception as e:
logger.error(f"Failed to get all chunks for conversation: {e}")
return []
async def _validate_conversation_access(self, conversation_id: str, user_id: str):
"""Validate user has access to conversation"""
client = await get_postgresql_client()
query = f"""
SELECT id FROM {self.schema_name}.conversations
WHERE id = $1 AND user_id = (
SELECT id FROM {self.schema_name}.users WHERE email = $2 LIMIT 1
)
"""
result = await client.fetch_one(query, conversation_id, user_id)
if not result:
            raise HTTPException(
                status_code=403,
                detail="Access denied: conversation not found or not accessible to this user"
            )
async def get_file_content(self, file_id: str, user_id: str) -> Optional[bytes]:
"""Get file content for download"""
try:
file_record = await self._get_file_record(file_id)
if not file_record:
return None
# Validate access to conversation
await self._validate_conversation_access(file_record['conversation_id'], user_id)
# Read file content
file_path = Path(self.settings.file_storage_path) / file_record['file_path']
if file_path.exists():
with open(file_path, "rb") as f:
return f.read()
return None
except Exception as e:
logger.error(f"Failed to get file content: {e}")
return None
# Factory function for service instances
def get_conversation_file_service(tenant_domain: str, user_id: str) -> ConversationFileService:
"""Get conversation file service instance"""
return ConversationFileService(tenant_domain, user_id)
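
# Minimal usage sketch (illustrative): lists and searches a conversation's attachments.
# The tenant domain, user email, and conversation ID are placeholder assumptions;
# in the application these values come from the authenticated API layer.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        service = get_conversation_file_service("example.ai", "user@example.ai")
        conversation_id = "00000000-0000-0000-0000-000000000001"
        print(await service.list_files(conversation_id))
        hits = await service.search_conversation_files(
            conversation_id, "quarterly revenue", max_results=3
        )
        for hit in hits:
            print(hit["original_filename"], round(hit["similarity_score"], 3))

    asyncio.run(_demo())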

View File

@@ -0,0 +1,959 @@
"""
Conversation Service for GT 2.0 Tenant Backend - PostgreSQL + PGVector
Manages AI-powered conversations with Agent integration using PostgreSQL directly.
Handles message persistence, context management, and LLM inference.
Replaces SQLAlchemy with direct PostgreSQL operations for GT 2.0 principles.
"""
import json
import logging
import uuid
from datetime import datetime
from typing import Dict, Any, List, Optional, AsyncIterator, AsyncGenerator
from app.core.config import get_settings
from app.core.postgresql_client import get_postgresql_client
from app.services.agent_service import AgentService
from app.core.resource_client import ResourceClusterClient
from app.services.conversation_summarizer import ConversationSummarizer
logger = logging.getLogger(__name__)
class ConversationService:
"""PostgreSQL-based service for managing AI conversations"""
def __init__(self, tenant_domain: str, user_id: str):
"""Initialize with tenant and user isolation using PostgreSQL"""
self.tenant_domain = tenant_domain
self.user_id = user_id
self.settings = get_settings()
self.agent_service = AgentService(tenant_domain, user_id)
self.resource_client = ResourceClusterClient()
self._resolved_user_uuid = None # Cache for resolved user UUID
logger.info(f"Conversation service initialized with PostgreSQL for {tenant_domain}/{user_id}")
async def _get_resolved_user_uuid(self, user_identifier: Optional[str] = None) -> str:
"""
Resolve user identifier to UUID with caching for performance.
This optimization reduces repeated database lookups by caching the resolved UUID.
Performance impact: ~50% reduction in query time for operations with multiple queries.
"""
identifier = user_identifier or self.user_id
# Return cached UUID if already resolved for this instance
if self._resolved_user_uuid and identifier == self.user_id:
return self._resolved_user_uuid
# Check if already a UUID
if not "@" in identifier:
try:
# Validate it's a proper UUID format
uuid.UUID(identifier)
if identifier == self.user_id:
self._resolved_user_uuid = identifier
return identifier
except ValueError:
pass # Not a valid UUID, treat as email/username
# Resolve email to UUID
pg_client = await get_postgresql_client()
query = "SELECT id FROM users WHERE email = $1 LIMIT 1"
result = await pg_client.fetch_one(query, identifier)
if not result:
raise ValueError(f"User not found: {identifier}")
user_uuid = str(result["id"])
# Cache if this is the service's primary user
if identifier == self.user_id:
self._resolved_user_uuid = user_uuid
return user_uuid
def _get_user_clause(self, param_num: int, user_identifier: str) -> str:
"""
DEPRECATED: Get the appropriate SQL clause for user identification.
Use _get_resolved_user_uuid() instead for better performance.
"""
if "@" in user_identifier:
# Email - do lookup
return f"(SELECT id FROM users WHERE email = ${param_num} LIMIT 1)"
else:
# UUID - use directly
return f"${param_num}::uuid"
async def create_conversation(
self,
agent_id: str,
title: Optional[str],
user_identifier: Optional[str] = None
) -> Dict[str, Any]:
"""Create a new conversation with an agent using PostgreSQL"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
# Get agent configuration
agent_data = await self.agent_service.get_agent(agent_id)
if not agent_data:
raise ValueError(f"Agent {agent_id} not found")
# Validate tenant has access to the agent's model
agent_model = agent_data.get("model")
if agent_model:
available_models = await self.get_available_models(self.tenant_domain)
available_model_ids = [m["model_id"] for m in available_models]
if agent_model not in available_model_ids:
raise ValueError(f"Agent model '{agent_model}' is not accessible to tenant '{self.tenant_domain}'. Available models: {', '.join(available_model_ids)}")
logger.info(f"Validated tenant access to model '{agent_model}' for agent '{agent_data.get('name')}'")
else:
logger.warning(f"Agent {agent_id} has no model configured, will use default")
# Get PostgreSQL client
pg_client = await get_postgresql_client()
# Generate conversation ID
conversation_id = str(uuid.uuid4())
# Create conversation in PostgreSQL (optimized: use resolved UUID directly)
query = """
INSERT INTO conversations (
id, title, tenant_id, user_id, agent_id, summary,
total_messages, total_tokens, metadata, is_archived,
created_at, updated_at
) VALUES (
$1, $2,
(SELECT id FROM tenants WHERE domain = $3 LIMIT 1),
$4::uuid,
$5, '', 0, 0, '{}', false, NOW(), NOW()
)
RETURNING id, title, tenant_id, user_id, agent_id, created_at, updated_at
"""
conv_title = title or f"Conversation with {agent_data.get('name', 'Agent')}"
conversation_data = await pg_client.fetch_one(
query,
conversation_id, conv_title, self.tenant_domain,
user_uuid, agent_id
)
if not conversation_data:
raise RuntimeError("Failed to create conversation - no data returned")
# Note: conversation_settings and conversation_participants are now created automatically
# by the auto_create_conversation_settings trigger, so we don't need to create them manually
# Get the model_id from the auto-created settings or use agent's model
settings_query = """
SELECT model_id FROM conversation_settings WHERE conversation_id = $1
"""
settings_data = await pg_client.fetch_one(settings_query, conversation_id)
model_id = settings_data["model_id"] if settings_data else agent_model
result = {
"id": str(conversation_data["id"]),
"title": conversation_data["title"],
"agent_id": str(conversation_data["agent_id"]),
"model_id": model_id,
"created_at": conversation_data["created_at"].isoformat(),
"user_id": user_uuid,
"tenant_domain": self.tenant_domain
}
logger.info(f"Created conversation {conversation_id} in PostgreSQL for user {user_uuid}")
return result
except Exception as e:
logger.error(f"Failed to create conversation: {e}")
raise
async def list_conversations(
self,
user_identifier: str,
agent_id: Optional[str] = None,
search: Optional[str] = None,
time_filter: str = "all",
limit: int = 20,
offset: int = 0
) -> Dict[str, Any]:
"""List conversations for a user using PostgreSQL with server-side filtering"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
pg_client = await get_postgresql_client()
# Build query with optional filters - exclude archived conversations (optimized: use cached UUID)
where_clause = "WHERE c.user_id = $1::uuid AND c.is_archived = false"
params = [user_uuid]
param_count = 1
# Time filter
if time_filter != "all":
if time_filter == "today":
where_clause += " AND c.updated_at >= NOW() - INTERVAL '1 day'"
elif time_filter == "week":
where_clause += " AND c.updated_at >= NOW() - INTERVAL '7 days'"
elif time_filter == "month":
where_clause += " AND c.updated_at >= NOW() - INTERVAL '30 days'"
# Agent filter
if agent_id:
param_count += 1
where_clause += f" AND c.agent_id = ${param_count}"
params.append(agent_id)
# Search filter (case-insensitive title search)
if search:
param_count += 1
where_clause += f" AND c.title ILIKE ${param_count}"
params.append(f"%{search}%")
# Get conversations with agent info and unread counts (optimized: use cached UUID)
query = f"""
SELECT
c.id, c.title, c.agent_id, c.created_at, c.updated_at,
c.total_messages, c.total_tokens, c.is_archived,
a.name as agent_name,
COUNT(m.id) FILTER (
WHERE m.created_at > COALESCE((c.metadata->>'last_read_at')::timestamptz, c.created_at)
AND m.user_id != $1::uuid
) as unread_count
FROM conversations c
LEFT JOIN agents a ON c.agent_id = a.id
LEFT JOIN messages m ON m.conversation_id = c.id
{where_clause}
GROUP BY c.id, c.title, c.agent_id, c.created_at, c.updated_at,
c.total_messages, c.total_tokens, c.is_archived, a.name
ORDER BY
CASE WHEN COUNT(m.id) FILTER (
WHERE m.created_at > COALESCE((c.metadata->>'last_read_at')::timestamptz, c.created_at)
AND m.user_id != $1::uuid
) > 0 THEN 0 ELSE 1 END,
c.updated_at DESC
LIMIT ${param_count + 1} OFFSET ${param_count + 2}
"""
params.extend([limit, offset])
conversations = await pg_client.execute_query(query, *params)
# Get total count
count_query = f"""
SELECT COUNT(*) as total
FROM conversations c
{where_clause}
"""
count_result = await pg_client.fetch_one(count_query, *params[:-2]) # Exclude limit/offset
total = count_result["total"] if count_result else 0
# Format results with lightweight fields including unread count
conversation_list = [
{
"id": str(conv["id"]),
"title": conv["title"],
"agent_id": str(conv["agent_id"]) if conv["agent_id"] else None,
"agent_name": conv["agent_name"] or "AI Assistant",
"created_at": conv["created_at"].isoformat(),
"updated_at": conv["updated_at"].isoformat(),
"last_message_at": conv["updated_at"].isoformat(), # Use updated_at as last activity
"message_count": conv["total_messages"] or 0,
"token_count": conv["total_tokens"] or 0,
"is_archived": conv["is_archived"],
"unread_count": conv.get("unread_count", 0) or 0 # Include unread count
# Removed preview field for performance
}
for conv in conversations
]
return {
"conversations": conversation_list,
"total": total,
"limit": limit,
"offset": offset
}
except Exception as e:
logger.error(f"Failed to list conversations: {e}")
raise
async def get_conversation(
self,
conversation_id: str,
user_identifier: str
) -> Optional[Dict[str, Any]]:
"""Get a specific conversation with details"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
pg_client = await get_postgresql_client()
query = """
SELECT
c.id, c.title, c.agent_id, c.created_at, c.updated_at,
c.total_messages, c.total_tokens, c.is_archived, c.summary,
a.name as agent_name,
cs.model_id, cs.temperature, cs.max_tokens, cs.system_prompt
FROM conversations c
LEFT JOIN agents a ON c.agent_id = a.id
LEFT JOIN conversation_settings cs ON c.id = cs.conversation_id
WHERE c.id = $1
AND c.user_id = $2::uuid
LIMIT 1
"""
conversation = await pg_client.fetch_one(query, conversation_id, user_uuid)
if not conversation:
return None
return {
"id": conversation["id"],
"title": conversation["title"],
"agent_id": conversation["agent_id"],
"agent_name": conversation["agent_name"],
"model_id": conversation["model_id"],
"temperature": float(conversation["temperature"]) if conversation["temperature"] else 0.7,
"max_tokens": conversation["max_tokens"],
"system_prompt": conversation["system_prompt"],
"summary": conversation["summary"],
"message_count": conversation["total_messages"],
"token_count": conversation["total_tokens"],
"is_archived": conversation["is_archived"],
"created_at": conversation["created_at"].isoformat(),
"updated_at": conversation["updated_at"].isoformat()
}
except Exception as e:
logger.error(f"Failed to get conversation {conversation_id}: {e}")
return None
async def add_message(
self,
conversation_id: str,
role: str,
content: str,
user_identifier: str,
model_used: Optional[str] = None,
token_count: int = 0,
metadata: Optional[Dict] = None,
attachments: Optional[List] = None
) -> Dict[str, Any]:
"""Add a message to a conversation"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
pg_client = await get_postgresql_client()
message_id = str(uuid.uuid4())
# Insert message (optimized: use cached UUID)
query = """
INSERT INTO messages (
id, conversation_id, user_id, role, content,
content_type, token_count, model_used, metadata, attachments, created_at
) VALUES (
$1, $2, $3::uuid,
$4, $5, 'text', $6, $7, $8, $9, NOW()
)
RETURNING id, created_at
"""
message_data = await pg_client.fetch_one(
query,
message_id, conversation_id, user_uuid,
role, content, token_count, model_used,
json.dumps(metadata or {}), json.dumps(attachments or [])
)
if not message_data:
raise RuntimeError("Failed to add message - no data returned")
# Update conversation totals (optimized: use cached UUID)
update_query = """
UPDATE conversations
SET total_messages = total_messages + 1,
total_tokens = total_tokens + $3,
updated_at = NOW()
WHERE id = $1
AND user_id = $2::uuid
"""
await pg_client.execute_command(update_query, conversation_id, user_uuid, token_count)
result = {
"id": message_data["id"],
"conversation_id": conversation_id,
"role": role,
"content": content,
"token_count": token_count,
"model_used": model_used,
"metadata": metadata or {},
"attachments": attachments or [],
"created_at": message_data["created_at"].isoformat()
}
logger.info(f"Added message {message_id} to conversation {conversation_id}")
return result
except Exception as e:
logger.error(f"Failed to add message to conversation {conversation_id}: {e}")
raise
async def get_messages(
self,
conversation_id: str,
user_identifier: str,
limit: int = 50,
offset: int = 0
) -> List[Dict[str, Any]]:
"""Get messages for a conversation"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
pg_client = await get_postgresql_client()
query = """
SELECT
m.id, m.role, m.content, m.content_type, m.token_count,
m.model_used, m.finish_reason, m.metadata, m.attachments, m.created_at
FROM messages m
JOIN conversations c ON m.conversation_id = c.id
WHERE c.id = $1
AND c.user_id = $2::uuid
ORDER BY m.created_at ASC
LIMIT $3 OFFSET $4
"""
messages = await pg_client.execute_query(query, conversation_id, user_uuid, limit, offset)
return [
{
"id": msg["id"],
"role": msg["role"],
"content": msg["content"],
"content_type": msg["content_type"],
"token_count": msg["token_count"],
"model_used": msg["model_used"],
"finish_reason": msg["finish_reason"],
"metadata": (
json.loads(msg["metadata"]) if isinstance(msg["metadata"], str)
else (msg["metadata"] if isinstance(msg["metadata"], dict) else {})
),
"attachments": (
json.loads(msg["attachments"]) if isinstance(msg["attachments"], str)
else (msg["attachments"] if isinstance(msg["attachments"], list) else [])
),
"context_sources": (
(json.loads(msg["metadata"]) if isinstance(msg["metadata"], str) else msg["metadata"]).get("context_sources", [])
if (isinstance(msg["metadata"], str) or isinstance(msg["metadata"], dict))
else []
),
"created_at": msg["created_at"].isoformat()
}
for msg in messages
]
except Exception as e:
logger.error(f"Failed to get messages for conversation {conversation_id}: {e}")
return []
async def send_message(
self,
conversation_id: str,
content: str,
user_identifier: Optional[str] = None,
stream: bool = False
) -> Dict[str, Any]:
"""Send a message to conversation and get AI response"""
user_id = user_identifier or self.user_id
# Check if this is the first message
existing_messages = await self.get_messages(conversation_id, user_id)
is_first_message = len(existing_messages) == 0
# Add user message
user_message = await self.add_message(
conversation_id=conversation_id,
role="user",
content=content,
            user_identifier=user_id
)
# Get conversation details for agent
conversation = await self.get_conversation(conversation_id, user_identifier)
agent_id = conversation.get("agent_id")
ai_message = None
if agent_id:
agent_data = await self.agent_service.get_agent(agent_id)
# Prepare messages for AI
messages = [
{"role": "system", "content": agent_data.get("prompt_template", "You are a helpful assistant.")},
{"role": "user", "content": content}
]
# Get AI response
ai_response = await self.get_ai_response(
model=agent_data.get("model", "llama-3.1-8b-instant"),
messages=messages,
tenant_id=self.tenant_domain,
user_id=user_id
)
# Extract content from response
ai_content = ai_response["choices"][0]["message"]["content"]
# Add AI message
ai_message = await self.add_message(
conversation_id=conversation_id,
role="agent",
content=ai_content,
user_identifier=user_id,
model_used=agent_data.get("model"),
token_count=ai_response["usage"]["total_tokens"]
)
return {
"user_message": user_message,
"ai_message": ai_message,
"is_first_message": is_first_message,
"conversation_id": conversation_id
}
async def update_conversation(
self,
conversation_id: str,
user_identifier: str,
title: Optional[str] = None
) -> bool:
"""Update conversation properties like title"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
pg_client = await get_postgresql_client()
# Build dynamic update query based on provided fields
update_fields = []
params = []
param_count = 1
if title is not None:
update_fields.append(f"title = ${param_count}")
params.append(title)
param_count += 1
if not update_fields:
return True # Nothing to update
# Add updated_at timestamp
update_fields.append(f"updated_at = NOW()")
query = f"""
UPDATE conversations
SET {', '.join(update_fields)}
WHERE id = ${param_count}
AND user_id = ${param_count + 1}::uuid
RETURNING id
"""
params.extend([conversation_id, user_uuid])
result = await pg_client.fetch_scalar(query, *params)
if result:
logger.info(f"Updated conversation {conversation_id}")
return True
return False
except Exception as e:
logger.error(f"Failed to update conversation {conversation_id}: {e}")
return False
async def auto_generate_conversation_title(
self,
conversation_id: str,
user_identifier: str
) -> Optional[str]:
"""Generate conversation title based on first user prompt and agent response pair"""
try:
# Get only the first few messages (first exchange)
messages = await self.get_messages(conversation_id, user_identifier, limit=2)
if not messages or len(messages) < 2:
return None # Need at least one user-agent exchange
# Only use first user message and first agent response for title
first_exchange = messages[:2]
# Generate title using the summarization service
from app.services.conversation_summarizer import generate_conversation_title
new_title = await generate_conversation_title(first_exchange, self.tenant_domain, user_identifier)
# Update the conversation with the generated title
success = await self.update_conversation(
conversation_id=conversation_id,
user_identifier=user_identifier,
title=new_title
)
if success:
logger.info(f"Auto-generated title '{new_title}' for conversation {conversation_id} based on first exchange")
return new_title
else:
logger.warning(f"Failed to update conversation {conversation_id} with generated title")
return None
except Exception as e:
logger.error(f"Failed to auto-generate title for conversation {conversation_id}: {e}")
return None
async def delete_conversation(
self,
conversation_id: str,
user_identifier: str
) -> bool:
"""Soft delete a conversation (archive it)"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
pg_client = await get_postgresql_client()
query = """
UPDATE conversations
SET is_archived = true, updated_at = NOW()
WHERE id = $1
AND user_id = $2::uuid
RETURNING id
"""
result = await pg_client.fetch_scalar(query, conversation_id, user_uuid)
if result:
logger.info(f"Archived conversation {conversation_id}")
return True
return False
except Exception as e:
logger.error(f"Failed to archive conversation {conversation_id}: {e}")
return False
async def get_ai_response(
self,
model: str,
messages: List[Dict[str, str]],
tenant_id: str,
user_id: str,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
top_p: float = 1.0,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
"""Get AI response from Resource Cluster"""
try:
# Prepare request for Resource Cluster
request_data = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p
}
# Add tools if provided
if tools:
request_data["tools"] = tools
if tool_choice:
request_data["tool_choice"] = tool_choice
# Call Resource Cluster AI inference endpoint
response = await self.resource_client.call_inference_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint="chat/completions",
data=request_data
)
return response
except Exception as e:
logger.error(f"Failed to get AI response: {e}")
raise
# Streaming removed for reliability - using non-streaming only
async def get_available_models(self, tenant_id: str) -> List[Dict[str, Any]]:
"""Get available models for tenant from Resource Cluster"""
try:
# Get models dynamically from Resource Cluster
import aiohttp
resource_cluster_url = self.resource_client.base_url
async with aiohttp.ClientSession() as session:
# Get capability token for model access
token = await self.resource_client._get_capability_token(
tenant_id=tenant_id,
user_id=self.user_id,
resources=['model_registry']
)
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json',
'X-Tenant-ID': tenant_id,
'X-User-ID': self.user_id
}
async with session.get(
f"{resource_cluster_url}/api/v1/models/",
headers=headers,
timeout=aiohttp.ClientTimeout(total=10)
) as response:
if response.status == 200:
response_data = await response.json()
models_data = response_data.get("models", [])
# Transform Resource Cluster model format to frontend format
available_models = []
for model in models_data:
# Only include available models
if model.get("status", {}).get("deployment") == "available":
available_models.append({
"id": model.get("uuid"), # Database UUID for unique identification
"model_id": model["id"], # model_id string for API calls
"name": model["name"],
"provider": model["provider"],
"model_type": model["model_type"],
"context_window": model.get("performance", {}).get("context_window", 4000),
"max_tokens": model.get("performance", {}).get("max_tokens", 4000),
"performance": model.get("performance", {}), # Include full performance for chat.py
"capabilities": {"chat": True} # All LLM models support chat
})
logger.info(f"Retrieved {len(available_models)} models from Resource Cluster")
return available_models
else:
logger.error(f"Resource Cluster returned {response.status}: {await response.text()}")
raise RuntimeError(f"Resource Cluster API error: {response.status}")
except Exception as e:
logger.error(f"Failed to get models from Resource Cluster: {e}")
raise
async def get_conversation_datasets(self, conversation_id: str, user_identifier: str) -> List[str]:
"""Get dataset IDs attached to a conversation"""
try:
pg_client = await get_postgresql_client()
# Ensure proper schema qualification
schema_name = f"tenant_{self.tenant_domain.replace('.', '_').replace('-', '_')}"
query = f"""
SELECT cd.dataset_id
FROM {schema_name}.conversations c
JOIN {schema_name}.conversation_datasets cd ON cd.conversation_id = c.id
WHERE c.id = $1
AND c.user_id = (SELECT id FROM {schema_name}.users WHERE email = $2 LIMIT 1)
AND cd.is_active = true
ORDER BY cd.attached_at ASC
"""
rows = await pg_client.execute_query(query, conversation_id, user_identifier)
dataset_ids = [str(row['dataset_id']) for row in rows]
logger.info(f"Found {len(dataset_ids)} datasets for conversation {conversation_id}")
return dataset_ids
except Exception as e:
logger.error(f"Failed to get conversation datasets: {e}")
return []
async def add_datasets_to_conversation(
self,
conversation_id: str,
user_identifier: str,
dataset_ids: List[str],
source: str = "user_selected"
) -> bool:
"""Add datasets to a conversation"""
try:
if not dataset_ids:
return True
pg_client = await get_postgresql_client()
# Ensure proper schema qualification
schema_name = f"tenant_{self.tenant_domain.replace('.', '_').replace('-', '_')}"
# Get user ID first
user_query = f"SELECT id FROM {schema_name}.users WHERE email = $1 LIMIT 1"
user_result = await pg_client.fetch_scalar(user_query, user_identifier)
if not user_result:
logger.error(f"User not found: {user_identifier}")
return False
user_id = user_result
            # Insert dataset attachments (ON CONFLICT DO UPDATE to reactivate existing attachments)
values_list = []
params = []
param_idx = 1
for dataset_id in dataset_ids:
values_list.append(f"(${param_idx}, ${param_idx + 1}, ${param_idx + 2})")
params.extend([conversation_id, dataset_id, user_id])
param_idx += 3
query = f"""
INSERT INTO {schema_name}.conversation_datasets (conversation_id, dataset_id, attached_by)
VALUES {', '.join(values_list)}
ON CONFLICT (conversation_id, dataset_id) DO UPDATE SET
is_active = true,
attached_at = NOW()
"""
await pg_client.execute_query(query, *params)
logger.info(f"Added {len(dataset_ids)} datasets to conversation {conversation_id}")
return True
except Exception as e:
logger.error(f"Failed to add datasets to conversation: {e}")
return False
async def copy_agent_datasets_to_conversation(
self,
conversation_id: str,
user_identifier: str,
agent_id: str
) -> bool:
"""Copy an agent's default datasets to a new conversation"""
try:
# Get agent's selected dataset IDs from config
from app.services.agent_service import AgentService
agent_service = AgentService(self.tenant_domain, user_identifier)
agent_data = await agent_service.get_agent(agent_id)
if not agent_data:
logger.warning(f"Agent {agent_id} not found")
return False
# Get selected_dataset_ids from agent config
selected_dataset_ids = agent_data.get('selected_dataset_ids', [])
if not selected_dataset_ids:
logger.info(f"Agent {agent_id} has no default datasets")
return True
# Add agent's datasets to conversation
success = await self.add_datasets_to_conversation(
conversation_id=conversation_id,
user_identifier=user_identifier,
dataset_ids=selected_dataset_ids,
source="agent_default"
)
if success:
logger.info(f"Copied {len(selected_dataset_ids)} datasets from agent {agent_id} to conversation {conversation_id}")
return success
except Exception as e:
logger.error(f"Failed to copy agent datasets: {e}")
return False
async def get_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
"""Get recent conversations ordered by last activity"""
try:
pg_client = await get_postgresql_client()
# Handle both email and UUID formats using existing pattern
user_clause = self._get_user_clause(1, user_id)
query = f"""
SELECT c.id, c.title, c.created_at, c.updated_at,
COUNT(m.id) as message_count,
MAX(m.created_at) as last_message_at,
a.name as agent_name
FROM conversations c
LEFT JOIN messages m ON m.conversation_id = c.id
LEFT JOIN agents a ON a.id = c.agent_id
WHERE c.user_id = {user_clause}
AND c.is_archived = false
GROUP BY c.id, c.title, c.created_at, c.updated_at, a.name
ORDER BY COALESCE(MAX(m.created_at), c.created_at) DESC
LIMIT $2
"""
rows = await pg_client.execute_query(query, user_id, limit)
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"Failed to get recent conversations: {e}")
return []
async def mark_conversation_read(
self,
conversation_id: str,
user_identifier: str
) -> bool:
"""
Mark a conversation as read by updating last_read_at in metadata.
Args:
conversation_id: UUID of the conversation
user_identifier: User email or UUID
Returns:
bool: True if successful, False otherwise
"""
try:
# Resolve user UUID with caching (performance optimization)
user_uuid = await self._get_resolved_user_uuid(user_identifier)
pg_client = await get_postgresql_client()
# Update last_read_at in conversation metadata
query = """
UPDATE conversations
SET metadata = jsonb_set(
COALESCE(metadata, '{}'::jsonb),
'{last_read_at}',
to_jsonb(NOW()::text)
)
WHERE id = $1
AND user_id = $2::uuid
RETURNING id
"""
result = await pg_client.fetch_one(query, conversation_id, user_uuid)
if result:
logger.info(f"Marked conversation {conversation_id} as read for user {user_identifier}")
return True
else:
logger.warning(f"Conversation {conversation_id} not found or access denied for user {user_identifier}")
return False
except Exception as e:
logger.error(f"Failed to mark conversation as read: {e}")
return False
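
# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal sketch of how the dataset-attachment helpers above might be
# called. The enclosing class name (ConversationService), its constructor
# signature, and every ID below are assumptions made for illustration only.
#
#     import asyncio
#
#     async def _demo_attach_datasets():
#         service = ConversationService("acme.example.com")
#         # Seed a new conversation with the agent's default datasets ...
#         await service.copy_agent_datasets_to_conversation(
#             conversation_id="11111111-1111-1111-1111-111111111111",
#             user_identifier="user@acme.example.com",
#             agent_id="22222222-2222-2222-2222-222222222222",
#         )
#         # ... then attach one user-selected dataset on top.
#         await service.add_datasets_to_conversation(
#             conversation_id="11111111-1111-1111-1111-111111111111",
#             user_identifier="user@acme.example.com",
#             dataset_ids=["33333333-3333-3333-3333-333333333333"],
#             source="user_selected",
#         )
#
#     asyncio.run(_demo_attach_datasets())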

View File

@@ -0,0 +1,200 @@
"""
Conversation Summarization Service for GT 2.0
Automatically generates meaningful conversation titles using a specialized
summarization agent with llama-3.1-8b-instant.
"""
import json
import logging
from typing import List, Optional, Dict, Any
from app.core.config import get_settings
from app.core.resource_client import ResourceClusterClient
logger = logging.getLogger(__name__)
settings = get_settings()
class ConversationSummarizer:
"""Service for generating conversation summaries and titles"""
def __init__(self, tenant_id: str, user_id: str):
self.tenant_id = tenant_id
self.user_id = user_id
self.resource_client = ResourceClusterClient()
self.summarization_model = "llama-3.1-8b-instant"
async def generate_conversation_title(self, messages: List[Dict[str, Any]]) -> str:
"""
Generate a concise conversation title based on the conversation content.
Args:
messages: List of message dictionaries from the conversation
Returns:
Generated conversation title (3-6 words)
"""
try:
# Extract conversation context for summarization
conversation_text = self._prepare_conversation_context(messages)
if not conversation_text.strip():
return "New Conversation"
# Generate title using specialized summarization prompt
title = await self._call_summarization_agent(conversation_text)
# Validate and clean the generated title
clean_title = self._clean_title(title)
logger.info(f"Generated conversation title: '{clean_title}' from {len(messages)} messages")
return clean_title
except Exception as e:
logger.error(f"Error generating conversation title: {e}")
return self._fallback_title(messages)
def _prepare_conversation_context(self, messages: List[Dict[str, Any]]) -> str:
"""Prepare conversation context for summarization"""
if not messages:
return ""
# Limit to first few exchanges for title generation
context_messages = messages[:6] # First 3 user-agent exchanges
context_parts = []
for msg in context_messages:
role = "User" if msg.get("role") == "user" else "Agent"
# Truncate very long messages for context
content = msg.get("content", "")
content = content[:200] + "..." if len(content) > 200 else content
context_parts.append(f"{role}: {content}")
return "\n".join(context_parts)
async def _call_summarization_agent(self, conversation_text: str) -> str:
"""Call the resource cluster AI inference for summarization"""
summarization_prompt = f"""You are a conversation title generator. Your job is to create concise, descriptive titles for conversations.
Given this conversation:
---
{conversation_text}
---
Generate a title that:
- Is 3-6 words maximum
- Captures the main topic or purpose
- Is clear and descriptive
- Uses title case
- Does NOT include quotes or special characters
Examples of good titles:
- "Python Code Review"
- "Database Migration Help"
- "React Component Design"
- "System Architecture Discussion"
Title:"""
request_data = {
"messages": [
{
"role": "user",
"content": summarization_prompt
}
],
"model": self.summarization_model,
"temperature": 0.3, # Lower temperature for consistent, focused titles
"max_tokens": 20, # Short response for title generation
"stream": False
}
try:
# Use the resource client instead of direct HTTP calls
result = await self.resource_client.call_inference_endpoint(
tenant_id=self.tenant_id,
user_id=self.user_id,
endpoint="chat/completions",
data=request_data
)
if result and "choices" in result and len(result["choices"]) > 0:
title = result["choices"][0]["message"]["content"].strip()
return title
else:
logger.error("Invalid response format from summarization API")
return ""
except Exception as e:
logger.error(f"Error calling summarization agent: {e}")
return ""
def _clean_title(self, raw_title: str) -> str:
"""Clean and validate the generated title"""
if not raw_title:
return "New Conversation"
# Remove quotes, extra whitespace, and special characters
cleaned = raw_title.strip().strip('"\'').strip()
# Remove common prefixes that AI might add
prefixes_to_remove = [
"Title:", "title:", "TITLE:",
"Conversation:", "conversation:",
"Topic:", "topic:",
"Subject:", "subject:"
]
for prefix in prefixes_to_remove:
if cleaned.startswith(prefix):
cleaned = cleaned[len(prefix):].strip()
# Limit length and ensure it's reasonable
if len(cleaned) > 50:
cleaned = cleaned[:47] + "..."
# Ensure it's not empty after cleaning
if not cleaned or len(cleaned.split()) > 8:
return "New Conversation"
return cleaned
def _fallback_title(self, messages: List[Dict[str, Any]]) -> str:
"""Generate fallback title when AI summarization fails"""
if not messages:
return "New Conversation"
# Try to use the first user message for context
first_user_msg = next((msg for msg in messages if msg.get("role") == "user"), None)
if first_user_msg and first_user_msg.get("content"):
# Extract first few words from the user's message
words = first_user_msg["content"].strip().split()[:4]
if len(words) >= 2:
fallback = " ".join(words).capitalize()
# Remove common question words for cleaner titles
for word in ["How", "What", "Can", "Could", "Please", "Help"]:
if fallback.startswith(word + " "):
fallback = fallback[len(word):].strip()
break
return fallback if fallback else "New Conversation"
return "New Conversation"
async def generate_conversation_title(messages: List[Dict[str, Any]], tenant_id: str, user_id: str) -> str:
"""
Convenience function to generate a conversation title.
Args:
messages: List of message dictionaries from the conversation
tenant_id: Tenant identifier
user_id: User identifier
Returns:
Generated conversation title
"""
summarizer = ConversationSummarizer(tenant_id, user_id)
return await summarizer.generate_conversation_title(messages)
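
# --- Usage sketch (illustrative) --------------------------------------------
# A minimal, hedged example of the module-level convenience function above.
# The tenant and user identifiers are placeholders; a reachable resource
# cluster is assumed, otherwise the heuristic fallback title is returned.
if __name__ == "__main__":
    import asyncio

    async def _demo_title_generation():
        messages = [
            {"role": "user", "content": "Can you review my Python retry decorator?"},
            {"role": "assistant", "content": "Sure - please paste the code."},
        ]
        title = await generate_conversation_title(
            messages,
            tenant_id="acme.example.com",      # placeholder tenant
            user_id="user@acme.example.com",   # placeholder user
        )
        print(f"Generated title: {title}")

    asyncio.run(_demo_title_generation())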

File diff suppressed because it is too large

View File

@@ -0,0 +1,585 @@
"""
Dataset Sharing Service for GT 2.0
Implements hierarchical dataset sharing with perfect tenant isolation.
Enables secure data collaboration while maintaining ownership and access control.
"""
import os
import stat
import json
import logging
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field
from enum import Enum
from uuid import uuid4
from app.models.access_group import AccessGroup, Resource
from app.services.access_controller import AccessController
from app.core.security import verify_capability_token
logger = logging.getLogger(__name__)
class SharingPermission(Enum):
"""Sharing permission levels"""
READ = "read" # Can view and search dataset
WRITE = "write" # Can add documents
ADMIN = "admin" # Can modify sharing settings
@dataclass
class DatasetShare:
"""Dataset sharing configuration"""
id: str = field(default_factory=lambda: str(uuid4()))
dataset_id: str = ""
owner_id: str = ""
access_group: AccessGroup = AccessGroup.INDIVIDUAL
team_members: List[str] = field(default_factory=list)
team_permissions: Dict[str, SharingPermission] = field(default_factory=dict)
shared_at: datetime = field(default_factory=datetime.utcnow)
expires_at: Optional[datetime] = None
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
return {
"id": self.id,
"dataset_id": self.dataset_id,
"owner_id": self.owner_id,
"access_group": self.access_group.value,
"team_members": self.team_members,
"team_permissions": {k: v.value for k, v in self.team_permissions.items()},
"shared_at": self.shared_at.isoformat(),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"is_active": self.is_active
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DatasetShare":
"""Create from dictionary"""
return cls(
id=data.get("id", str(uuid4())),
dataset_id=data["dataset_id"],
owner_id=data["owner_id"],
access_group=AccessGroup(data["access_group"]),
team_members=data.get("team_members", []),
team_permissions={
k: SharingPermission(v) for k, v in data.get("team_permissions", {}).items()
},
shared_at=datetime.fromisoformat(data["shared_at"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
is_active=data.get("is_active", True)
)
@dataclass
class DatasetInfo:
"""Dataset information for sharing"""
id: str
name: str
description: str
owner_id: str
document_count: int
size_bytes: int
created_at: datetime
updated_at: datetime
tags: List[str] = field(default_factory=list)
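# Note (illustrative): DatasetShare round-trips through to_dict()/from_dict()
# for the JSON files written by _store_share() below; enum values and
# timestamps are serialized as strings. For example (placeholder IDs):
#     share = DatasetShare(dataset_id="ds-1", owner_id="owner@acme.example.com",
#                          access_group=AccessGroup.TEAM)
#     restored = DatasetShare.from_dict(share.to_dict())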
class DatasetSharingService:
"""
Service for hierarchical dataset sharing with capability-based access control.
Features:
- Individual, Team, and Organization level sharing
- Granular permission management (read, write, admin)
- Time-based expiration of shares
- Perfect tenant isolation through file-based storage
- Event emission for sharing activities
"""
def __init__(self, tenant_domain: str, access_controller: AccessController):
self.tenant_domain = tenant_domain
self.access_controller = access_controller
self.base_path = Path(f"/data/{tenant_domain}/dataset_sharing")
self.shares_path = self.base_path / "shares"
self.datasets_path = self.base_path / "datasets"
# Ensure directories exist with proper permissions
self._ensure_directories()
logger.info(f"DatasetSharingService initialized for {tenant_domain}")
def _ensure_directories(self):
"""Ensure sharing directories exist with proper permissions"""
for path in [self.shares_path, self.datasets_path]:
path.mkdir(parents=True, exist_ok=True)
# Set permissions to 700 (owner only)
os.chmod(path, stat.S_IRWXU)
async def share_dataset(
self,
dataset_id: str,
owner_id: str,
access_group: AccessGroup,
team_members: Optional[List[str]] = None,
team_permissions: Optional[Dict[str, SharingPermission]] = None,
expires_at: Optional[datetime] = None,
capability_token: str = ""
) -> DatasetShare:
"""
Share a dataset with specified access group.
Args:
dataset_id: Dataset to share
owner_id: Owner of the dataset
access_group: Level of sharing (Individual, Team, Organization)
team_members: List of team members (if Team access)
team_permissions: Permissions for each team member
expires_at: Optional expiration time
capability_token: JWT capability token
Returns:
DatasetShare configuration
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Verify ownership
dataset_resource = await self._load_dataset_resource(dataset_id)
if not dataset_resource or dataset_resource.owner_id != owner_id:
raise PermissionError("Only dataset owner can modify sharing")
# Validate team members for team sharing
if access_group == AccessGroup.TEAM:
if not team_members:
raise ValueError("Team members required for team sharing")
# Ensure all team members are valid users in tenant
for member in team_members:
if not await self._is_valid_tenant_user(member):
logger.warning(f"Invalid team member: {member}")
# Create sharing configuration
share = DatasetShare(
dataset_id=dataset_id,
owner_id=owner_id,
access_group=access_group,
team_members=team_members or [],
team_permissions=team_permissions or {},
expires_at=expires_at
)
# Set default permissions for team members
if access_group == AccessGroup.TEAM:
for member in share.team_members:
if member not in share.team_permissions:
share.team_permissions[member] = SharingPermission.READ
# Store sharing configuration
await self._store_share(share)
# Update dataset resource access group
await self.access_controller.update_resource_access(
owner_id, dataset_id, access_group, team_members
)
# Emit sharing event
if hasattr(self.access_controller, 'event_bus'):
await self.access_controller.event_bus.emit_event(
"dataset.shared",
owner_id,
{
"dataset_id": dataset_id,
"access_group": access_group.value,
"team_members": team_members or [],
"expires_at": expires_at.isoformat() if expires_at else None
}
)
logger.info(f"Dataset {dataset_id} shared as {access_group.value} by {owner_id}")
return share
async def get_dataset_sharing(
self,
dataset_id: str,
user_id: str,
capability_token: str
) -> Optional[DatasetShare]:
"""
Get sharing configuration for a dataset.
Args:
dataset_id: Dataset ID
user_id: Requesting user
capability_token: JWT capability token
Returns:
DatasetShare if user has access, None otherwise
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Load sharing configuration
share = await self._load_share(dataset_id)
if not share:
return None
# Check if user has access to view sharing info
if share.owner_id == user_id:
return share # Owner can always see
if share.access_group == AccessGroup.TEAM and user_id in share.team_members:
return share # Team member can see
if share.access_group == AccessGroup.ORGANIZATION:
# All tenant users can see organization shares
if await self._is_valid_tenant_user(user_id):
return share
return None
async def check_dataset_access(
self,
dataset_id: str,
user_id: str,
permission: SharingPermission = SharingPermission.READ
) -> Tuple[bool, Optional[str]]:
"""
Check if user has specified permission on dataset.
Args:
dataset_id: Dataset to check
user_id: User requesting access
permission: Required permission level
Returns:
Tuple of (allowed, reason)
"""
# Load sharing configuration
share = await self._load_share(dataset_id)
if not share or not share.is_active:
return False, "Dataset not shared or sharing inactive"
# Check expiration
if share.expires_at and datetime.utcnow() > share.expires_at:
return False, "Dataset sharing has expired"
# Owner has all permissions
if share.owner_id == user_id:
return True, "Owner access"
# Check access group permissions
if share.access_group == AccessGroup.INDIVIDUAL:
return False, "Private dataset"
elif share.access_group == AccessGroup.TEAM:
if user_id not in share.team_members:
return False, "Not a team member"
# Check specific permission
user_permission = share.team_permissions.get(user_id, SharingPermission.READ)
if self._has_permission(user_permission, permission):
return True, f"Team member with {user_permission.value} permission"
else:
return False, f"Insufficient permission: has {user_permission.value}, needs {permission.value}"
elif share.access_group == AccessGroup.ORGANIZATION:
# Organization sharing is typically read-only
if permission == SharingPermission.READ:
if await self._is_valid_tenant_user(user_id):
return True, "Organization-wide read access"
return False, "Organization access is read-only"
return False, "Unknown access configuration"
async def list_accessible_datasets(
self,
user_id: str,
capability_token: str,
include_owned: bool = True,
include_shared: bool = True
) -> List[DatasetInfo]:
"""
List datasets accessible to user.
Args:
user_id: User requesting list
capability_token: JWT capability token
include_owned: Include user's own datasets
include_shared: Include datasets shared with user
Returns:
List of accessible datasets
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
accessible_datasets = []
# Get all dataset shares
all_shares = await self._list_all_shares()
for share in all_shares:
# Skip inactive or expired shares
if not share.is_active:
continue
if share.expires_at and datetime.utcnow() > share.expires_at:
continue
# Check if user has access
has_access = False
if include_owned and share.owner_id == user_id:
has_access = True
elif include_shared:
allowed, _ = await self.check_dataset_access(share.dataset_id, user_id)
has_access = allowed
if has_access:
dataset_info = await self._load_dataset_info(share.dataset_id)
if dataset_info:
accessible_datasets.append(dataset_info)
return accessible_datasets
async def revoke_dataset_sharing(
self,
dataset_id: str,
owner_id: str,
capability_token: str
) -> bool:
"""
Revoke dataset sharing (make it private).
Args:
dataset_id: Dataset to make private
owner_id: Owner of the dataset
capability_token: JWT capability token
Returns:
True if revoked successfully
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Verify ownership
share = await self._load_share(dataset_id)
if not share or share.owner_id != owner_id:
raise PermissionError("Only dataset owner can revoke sharing")
# Update sharing to individual (private)
share.access_group = AccessGroup.INDIVIDUAL
share.team_members = []
share.team_permissions = {}
share.is_active = False
# Store updated share
await self._store_share(share)
# Update resource access
await self.access_controller.update_resource_access(
owner_id, dataset_id, AccessGroup.INDIVIDUAL, []
)
# Emit revocation event
if hasattr(self.access_controller, 'event_bus'):
await self.access_controller.event_bus.emit_event(
"dataset.sharing_revoked",
owner_id,
{"dataset_id": dataset_id}
)
logger.info(f"Dataset {dataset_id} sharing revoked by {owner_id}")
return True
async def update_team_permissions(
self,
dataset_id: str,
owner_id: str,
user_id: str,
permission: SharingPermission,
capability_token: str
) -> bool:
"""
Update team member permissions for a dataset.
Args:
dataset_id: Dataset ID
owner_id: Owner of the dataset
user_id: Team member to update
permission: New permission level
capability_token: JWT capability token
Returns:
True if updated successfully
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Load and verify sharing
share = await self._load_share(dataset_id)
if not share or share.owner_id != owner_id:
raise PermissionError("Only dataset owner can update permissions")
if share.access_group != AccessGroup.TEAM:
raise ValueError("Can only update permissions for team-shared datasets")
if user_id not in share.team_members:
raise ValueError("User is not a team member")
# Update permission
share.team_permissions[user_id] = permission
# Store updated share
await self._store_share(share)
logger.info(f"Updated {user_id} permission to {permission.value} for dataset {dataset_id}")
return True
async def get_sharing_statistics(
self,
user_id: str,
capability_token: str
) -> Dict[str, Any]:
"""
Get sharing statistics for user.
Args:
user_id: User to get stats for
capability_token: JWT capability token
Returns:
Statistics dictionary
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
stats = {
"owned_datasets": 0,
"shared_with_me": 0,
"sharing_breakdown": {
AccessGroup.INDIVIDUAL: 0,
AccessGroup.TEAM: 0,
AccessGroup.ORGANIZATION: 0
},
"total_team_members": 0,
"expired_shares": 0
}
all_shares = await self._list_all_shares()
for share in all_shares:
# Count owned datasets
if share.owner_id == user_id:
stats["owned_datasets"] += 1
stats["sharing_breakdown"][share.access_group] += 1
stats["total_team_members"] += len(share.team_members)
# Count expired shares
if share.expires_at and datetime.utcnow() > share.expires_at:
stats["expired_shares"] += 1
# Count datasets shared with user
elif user_id in share.team_members or share.access_group == AccessGroup.ORGANIZATION:
if share.is_active and (not share.expires_at or datetime.utcnow() <= share.expires_at):
stats["shared_with_me"] += 1
return stats
def _has_permission(self, user_permission: SharingPermission, required: SharingPermission) -> bool:
"""Check if user permission satisfies required permission"""
permission_hierarchy = {
SharingPermission.READ: 1,
SharingPermission.WRITE: 2,
SharingPermission.ADMIN: 3
}
return permission_hierarchy[user_permission] >= permission_hierarchy[required]
async def _store_share(self, share: DatasetShare):
"""Store sharing configuration to file system"""
share_file = self.shares_path / f"{share.dataset_id}.json"
with open(share_file, "w") as f:
json.dump(share.to_dict(), f, indent=2)
# Set secure permissions
os.chmod(share_file, stat.S_IRUSR | stat.S_IWUSR) # 600
async def _load_share(self, dataset_id: str) -> Optional[DatasetShare]:
"""Load sharing configuration from file system"""
share_file = self.shares_path / f"{dataset_id}.json"
if not share_file.exists():
return None
try:
with open(share_file, "r") as f:
data = json.load(f)
return DatasetShare.from_dict(data)
except Exception as e:
logger.error(f"Error loading share for dataset {dataset_id}: {e}")
return None
async def _list_all_shares(self) -> List[DatasetShare]:
"""List all sharing configurations"""
shares = []
if self.shares_path.exists():
for share_file in self.shares_path.glob("*.json"):
try:
with open(share_file, "r") as f:
data = json.load(f)
shares.append(DatasetShare.from_dict(data))
except Exception as e:
logger.error(f"Error loading share file {share_file}: {e}")
return shares
async def _load_dataset_resource(self, dataset_id: str) -> Optional[Resource]:
"""Load dataset resource (implementation would query storage)"""
# Placeholder - would integrate with actual resource storage
return Resource(
id=dataset_id,
name=f"Dataset {dataset_id}",
resource_type="dataset",
owner_id="mock_owner",
tenant_domain=self.tenant_domain,
access_group=AccessGroup.INDIVIDUAL
)
async def _load_dataset_info(self, dataset_id: str) -> Optional[DatasetInfo]:
"""Load dataset information (implementation would query storage)"""
# Placeholder - would integrate with actual dataset storage
return DatasetInfo(
id=dataset_id,
name=f"Dataset {dataset_id}",
description="Mock dataset for testing",
owner_id="mock_owner",
document_count=10,
size_bytes=1024000,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
tags=["test", "mock"]
)
async def _is_valid_tenant_user(self, user_id: str) -> bool:
"""Check if user is valid in tenant (implementation would query user store)"""
# Placeholder - would integrate with actual user management
return "@" in user_id and user_id.endswith((".com", ".org", ".edu"))

View File

@@ -0,0 +1,445 @@
"""
Dataset Summarization Service for GT 2.0
Generates comprehensive summaries for datasets based on their constituent documents.
Provides analytics, topic clustering, and overview generation for RAG optimization.
"""
import logging
import asyncio
import httpx
import json
from typing import Dict, Any, Optional, List
from datetime import datetime
from collections import Counter
from app.core.database import get_db_session, execute_command, fetch_one, fetch_all
logger = logging.getLogger(__name__)
class DatasetSummarizer:
"""
Service for generating dataset-level summaries and analytics.
Features:
- Aggregate document summaries into dataset overview
- Topic clustering and theme analysis
- Dataset statistics and metrics
- Search optimization recommendations
- RAG performance insights
"""
def __init__(self):
self.resource_cluster_url = "http://gentwo-resource-backend:8000"
async def generate_dataset_summary(
self,
dataset_id: str,
tenant_domain: str,
user_id: str,
force_regenerate: bool = False
) -> Dict[str, Any]:
"""
Generate a comprehensive summary for a dataset.
Args:
dataset_id: Dataset ID to summarize
tenant_domain: Tenant domain for database context
user_id: User requesting the summary
force_regenerate: Force regeneration even if summary exists
Returns:
Dictionary with dataset summary including overview, topics,
statistics, and search optimization insights
"""
try:
# Check if summary already exists and is recent
if not force_regenerate:
existing_summary = await self._get_existing_summary(dataset_id, tenant_domain)
if existing_summary and self._is_summary_fresh(existing_summary):
logger.info(f"Using cached dataset summary for {dataset_id}")
return existing_summary
# Get dataset information and documents
dataset_info = await self._get_dataset_info(dataset_id, tenant_domain)
if not dataset_info:
raise ValueError(f"Dataset {dataset_id} not found")
documents = await self._get_dataset_documents(dataset_id, tenant_domain)
document_summaries = await self._get_document_summaries(dataset_id, tenant_domain)
# Generate statistics
stats = await self._calculate_dataset_statistics(dataset_id, tenant_domain)
# Analyze topics across all documents
topics_analysis = await self._analyze_dataset_topics(document_summaries)
# Generate overall summary using LLM
overview = await self._generate_dataset_overview(
dataset_info, document_summaries, topics_analysis, stats
)
# Create comprehensive summary
summary_data = {
"dataset_id": dataset_id,
"overview": overview,
"statistics": stats,
"topics": topics_analysis,
"recommendations": await self._generate_search_recommendations(stats, topics_analysis),
"metadata": {
"document_count": len(documents),
"has_summaries": len(document_summaries),
"generated_at": datetime.utcnow().isoformat(),
"generated_by": user_id
}
}
# Store summary in database
await self._store_dataset_summary(dataset_id, summary_data, tenant_domain, user_id)
logger.info(f"Generated dataset summary for {dataset_id} with {len(documents)} documents")
return summary_data
except Exception as e:
logger.error(f"Failed to generate dataset summary for {dataset_id}: {e}")
# Return basic fallback summary
return {
"dataset_id": dataset_id,
"overview": "Dataset summary generation failed",
"statistics": {"error": str(e)},
"topics": [],
"recommendations": [],
"metadata": {
"generated_at": datetime.utcnow().isoformat(),
"error": str(e)
}
}
async def _get_dataset_info(self, dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
"""Get basic dataset information"""
async with get_db_session() as session:
query = """
SELECT id, dataset_name, description, chunking_strategy,
chunk_size, chunk_overlap, created_at
FROM datasets
WHERE id = $1
"""
result = await fetch_one(session, query, dataset_id)
return dict(result) if result else None
async def _get_dataset_documents(self, dataset_id: str, tenant_domain: str) -> List[Dict[str, Any]]:
"""Get all documents in the dataset"""
async with get_db_session() as session:
query = """
SELECT id, filename, original_filename, file_type,
file_size_bytes, chunk_count, created_at
FROM documents
WHERE dataset_id = $1 AND processing_status = 'completed'
ORDER BY created_at DESC
"""
results = await fetch_all(session, query, dataset_id)
return [dict(row) for row in results]
async def _get_document_summaries(self, dataset_id: str, tenant_domain: str) -> List[Dict[str, Any]]:
"""Get summaries for all documents in the dataset"""
async with get_db_session() as session:
query = """
SELECT ds.document_id, ds.quick_summary, ds.detailed_analysis,
ds.topics, ds.metadata, ds.confidence,
d.filename, d.original_filename
FROM document_summaries ds
JOIN documents d ON ds.document_id = d.id
WHERE d.dataset_id = $1
ORDER BY ds.created_at DESC
"""
results = await fetch_all(session, query, dataset_id)
summaries = []
for row in results:
summary = dict(row)
# Parse JSON fields
if summary["topics"]:
summary["topics"] = json.loads(summary["topics"])
if summary["metadata"]:
summary["metadata"] = json.loads(summary["metadata"])
summaries.append(summary)
return summaries
async def _calculate_dataset_statistics(self, dataset_id: str, tenant_domain: str) -> Dict[str, Any]:
"""Calculate comprehensive dataset statistics"""
async with get_db_session() as session:
# Basic document statistics
doc_stats_query = """
SELECT
COUNT(*) as total_documents,
SUM(file_size_bytes) as total_size_bytes,
SUM(chunk_count) as total_chunks,
AVG(chunk_count) as avg_chunks_per_doc,
COUNT(DISTINCT file_type) as unique_file_types
FROM documents
WHERE dataset_id = $1 AND processing_status = 'completed'
"""
doc_stats = await fetch_one(session, doc_stats_query, dataset_id)
# Chunk statistics
chunk_stats_query = """
SELECT
COUNT(*) as total_vector_embeddings,
AVG(token_count) as avg_tokens_per_chunk,
MIN(token_count) as min_tokens,
MAX(token_count) as max_tokens
FROM document_chunks
WHERE dataset_id = $1
"""
chunk_stats = await fetch_one(session, chunk_stats_query, dataset_id)
# File type distribution
file_types_query = """
SELECT file_type, COUNT(*) as count
FROM documents
WHERE dataset_id = $1 AND processing_status = 'completed'
GROUP BY file_type
ORDER BY count DESC
"""
file_types_results = await fetch_all(session, file_types_query, dataset_id)
file_types = {row["file_type"]: row["count"] for row in file_types_results}
return {
"documents": {
"total": doc_stats["total_documents"] or 0,
"total_size_mb": round((doc_stats["total_size_bytes"] or 0) / 1024 / 1024, 2),
"avg_chunks_per_document": round(doc_stats["avg_chunks_per_doc"] or 0, 1),
"unique_file_types": doc_stats["unique_file_types"] or 0,
"file_type_distribution": file_types
},
"chunks": {
"total": chunk_stats["total_vector_embeddings"] or 0,
"avg_tokens": round(chunk_stats["avg_tokens_per_chunk"] or 0, 1),
"token_range": {
"min": chunk_stats["min_tokens"] or 0,
"max": chunk_stats["max_tokens"] or 0
}
},
"search_readiness": {
"has_vectors": (chunk_stats["total_vector_embeddings"] or 0) > 0,
"vector_coverage": 1.0 if (doc_stats["total_chunks"] or 0) == (chunk_stats["total_vector_embeddings"] or 0) else 0.0
}
}
async def _analyze_dataset_topics(self, document_summaries: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Analyze topics across all document summaries"""
if not document_summaries:
return {"main_topics": [], "topic_distribution": {}, "confidence": 0.0}
# Collect all topics from document summaries
all_topics = []
for summary in document_summaries:
topics = summary.get("topics", [])
if isinstance(topics, list):
all_topics.extend(topics)
# Count topic frequencies
topic_counts = Counter(all_topics)
# Get top topics
main_topics = [topic for topic, count in topic_counts.most_common(10)]
# Calculate topic distribution
total_topics = len(all_topics)
topic_distribution = {}
if total_topics > 0:
for topic, count in topic_counts.items():
topic_distribution[topic] = round(count / total_topics, 3)
# Calculate confidence based on number of summaries available
confidence = min(1.0, len(document_summaries) / 5.0) # Full confidence with 5+ documents
return {
"main_topics": main_topics,
"topic_distribution": topic_distribution,
"confidence": confidence,
"total_unique_topics": len(topic_counts)
}
async def _generate_dataset_overview(
self,
dataset_info: Dict[str, Any],
document_summaries: List[Dict[str, Any]],
topics_analysis: Dict[str, Any],
stats: Dict[str, Any]
) -> str:
"""Generate LLM-powered overview of the dataset"""
# Create context for LLM
context = f"""Dataset: {dataset_info['dataset_name']}
Description: {dataset_info.get('description', 'No description provided')}
Statistics:
- {stats['documents']['total']} documents ({stats['documents']['total_size_mb']} MB)
- {stats['chunks']['total']} text chunks for search
- Average {stats['documents']['avg_chunks_per_document']} chunks per document
Main Topics: {', '.join(topics_analysis['main_topics'][:5])}
Document Summaries:
"""
# Add sample document summaries
for i, summary in enumerate(document_summaries[:3]): # First 3 documents
context += f"\n{i+1}. {summary['filename']}: {summary['quick_summary']}"
prompt = f"""Analyze this dataset and provide a comprehensive 2-3 paragraph overview.
{context}
Focus on:
1. What type of content this dataset contains
2. The main themes and topics covered
3. How useful this would be for AI-powered search and retrieval
4. Any notable patterns or characteristics
Provide a professional, informative summary suitable for users exploring their datasets."""
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
json={
"model": "llama-3.1-8b-instant",
"messages": [
{
"role": "system",
"content": "You are a data analysis expert. Provide clear, insightful dataset summaries."
},
{
"role": "user",
"content": prompt
}
],
"temperature": 0.3,
"max_tokens": 500
},
headers={
"X-Tenant-ID": "default",
"Content-Type": "application/json"
}
)
if response.status_code == 200:
llm_response = response.json()
return llm_response["choices"][0]["message"]["content"]
else:
raise Exception(f"LLM API error: {response.status_code}")
except Exception as e:
logger.warning(f"LLM overview generation failed: {e}")
# Fallback to template-based overview
return f"This dataset contains {stats['documents']['total']} documents covering topics such as {', '.join(topics_analysis['main_topics'][:3])}. The dataset includes {stats['chunks']['total']} searchable text chunks optimized for AI-powered retrieval and question answering."
async def _generate_search_recommendations(
self,
stats: Dict[str, Any],
topics_analysis: Dict[str, Any]
) -> List[str]:
"""Generate recommendations for optimizing search performance"""
recommendations = []
# Vector coverage recommendations
if not stats["search_readiness"]["has_vectors"]:
recommendations.append("Generate vector embeddings for all documents to enable semantic search")
elif stats["search_readiness"]["vector_coverage"] < 1.0:
recommendations.append("Complete vector embedding generation for optimal search performance")
# Chunk size recommendations
avg_tokens = stats["chunks"]["avg_tokens"]
if avg_tokens < 100:
recommendations.append("Consider increasing chunk size for better context in search results")
elif avg_tokens > 600:
recommendations.append("Consider reducing chunk size for more precise search matches")
# Topic diversity recommendations
if topics_analysis["total_unique_topics"] < 3:
recommendations.append("Dataset may benefit from more diverse content for comprehensive coverage")
elif topics_analysis["total_unique_topics"] > 50:
recommendations.append("Consider organizing content into focused sub-datasets for better search precision")
# Document count recommendations
doc_count = stats["documents"]["total"]
if doc_count < 5:
recommendations.append("Add more documents to improve search quality and coverage")
elif doc_count > 100:
recommendations.append("Consider implementing advanced filtering and categorization for better navigation")
return recommendations[:5] # Limit to top 5 recommendations
async def _store_dataset_summary(
self,
dataset_id: str,
summary_data: Dict[str, Any],
tenant_domain: str,
user_id: str
):
"""Store or update dataset summary in database"""
async with get_db_session() as session:
query = """
UPDATE datasets
SET
summary = $1,
summary_generated_at = $2,
updated_at = NOW()
WHERE id = $3
"""
await execute_command(
session,
query,
json.dumps(summary_data),
datetime.utcnow(),
dataset_id
)
async def _get_existing_summary(self, dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
"""Get existing dataset summary if available"""
async with get_db_session() as session:
query = """
SELECT summary, summary_generated_at
FROM datasets
WHERE id = $1 AND summary IS NOT NULL
"""
result = await fetch_one(session, query, dataset_id)
if result and result["summary"]:
return json.loads(result["summary"])
return None
def _is_summary_fresh(self, summary: Dict[str, Any], max_age_hours: int = 24) -> bool:
"""Check if summary is recent enough to avoid regeneration"""
try:
generated_at = datetime.fromisoformat(summary["metadata"]["generated_at"])
age_hours = (datetime.utcnow() - generated_at).total_seconds() / 3600
return age_hours < max_age_hours
except (KeyError, ValueError):
return False
# Global instance
dataset_summarizer = DatasetSummarizer()
async def generate_dataset_summary(
dataset_id: str,
tenant_domain: str,
user_id: str,
force_regenerate: bool = False
) -> Dict[str, Any]:
"""Convenience function for dataset summary generation"""
return await dataset_summarizer.generate_dataset_summary(
dataset_id, tenant_domain, user_id, force_regenerate
)
async def get_dataset_summary(dataset_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
"""Convenience function for retrieving dataset summary"""
return await dataset_summarizer._get_existing_summary(dataset_id, tenant_domain)
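
# --- Usage sketch (illustrative) --------------------------------------------
# A minimal, hedged example of the module-level convenience functions above.
# The dataset ID, tenant domain, and user are placeholders; a populated
# database and reachable resource cluster are assumed. On failure the
# service returns a fallback summary rather than raising.
if __name__ == "__main__":
    import asyncio

    async def _demo_dataset_summary():
        summary = await generate_dataset_summary(
            dataset_id="55555555-5555-5555-5555-555555555555",  # placeholder
            tenant_domain="acme.example.com",
            user_id="user@acme.example.com",
            force_regenerate=False,  # reuse a cached summary younger than 24h
        )
        print(summary.get("overview"))

    asyncio.run(_demo_dataset_summary())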

View File

@@ -0,0 +1,834 @@
"""
Document Processing Service for GT 2.0
Handles file upload, text extraction, chunking, and embedding generation
for the RAG pipeline. Supports multiple file formats with intelligent chunking.
"""
import asyncio
import logging
import hashlib
import mimetypes
import re
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import uuid
# Document processing libraries
import pypdf as PyPDF2 # pypdf is the maintained successor to PyPDF2
import docx
import pandas as pd
import json
import csv
from io import StringIO
# Database and core services
from app.core.postgresql_client import get_postgresql_client
# Resource cluster client for embeddings
import httpx
from app.services.embedding_client import get_embedding_client
# Document summarization
from app.services.summarization_service import SummarizationService
logger = logging.getLogger(__name__)
class DocumentProcessor:
"""
Comprehensive document processing service for RAG pipeline.
Features:
- Multi-format support (PDF, DOCX, TXT, MD, CSV, JSON)
- Intelligent chunking with overlap
- Async embedding generation with batch processing
- Progress tracking
- Error handling and recovery
"""
def __init__(self, db=None, tenant_domain=None):
self.db = db
self.tenant_domain = tenant_domain or "test" # Default fallback
# Use configurable embedding client instead of hardcoded URL
self.embedding_client = get_embedding_client()
self.chunk_size = 512 # Default chunk size in tokens
self.chunk_overlap = 128 # Default overlap
self.max_file_size = 100 * 1024 * 1024 # 100MB limit
# Embedding batch processing configuration
self.EMBEDDING_BATCH_SIZE = 15 # Process embeddings in batches of 15 (ARM64 optimized)
self.MAX_CONCURRENT_BATCHES = 3 # Process up to 3 batches concurrently
self.MAX_RETRIES = 3 # Maximum retries per batch
self.INITIAL_RETRY_DELAY = 1.0 # Initial delay in seconds
# Supported file types
self.supported_types = {
'.pdf': 'application/pdf',
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'.txt': 'text/plain',
'.md': 'text/markdown',
'.csv': 'text/csv',
'.json': 'application/json'
}
async def process_file(
self,
file_path: Path,
dataset_id: str,
user_id: str,
original_filename: str,
document_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Process an uploaded file through the complete RAG pipeline.
Args:
file_path: Path to uploaded file
dataset_id: Dataset UUID to attach to
user_id: User who uploaded the file
original_filename: Original filename
document_id: Optional existing document ID to update instead of creating new
Returns:
Dict: Document record with processing status
"""
logger.info(f"Processing file {original_filename} for dataset {dataset_id}")
# Process file directly (no session management needed with PostgreSQL client)
return await self._process_file_internal(file_path, dataset_id, user_id, original_filename, document_id)
async def _process_file_internal(
self,
file_path: Path,
dataset_id: str,
user_id: str,
original_filename: str,
document_id: Optional[str] = None
) -> Dict[str, Any]:
"""Internal file processing method"""
try:
# 1. Validate file
await self._validate_file(file_path)
# 2. Create or use existing document record
if document_id:
# Use existing document
document = {"id": document_id}
logger.info(f"Using existing document {document_id} for processing")
else:
# Create new document record
document = await self._create_document_record(
file_path, dataset_id, user_id, original_filename
)
# 3. Get or extract text content
await self._update_processing_status(document["id"], "processing", processing_stage="Getting text content")
# Check if content already exists (e.g., from upload-time extraction)
existing_content, storage_type = await self._get_existing_document_content(document["id"])
if existing_content and storage_type in ["pdf_extracted", "text"]:
# Use existing extracted content
text_content = existing_content
logger.info(f"Using existing extracted content ({len(text_content)} chars, type: {storage_type})")
else:
# Extract text from file
await self._update_processing_status(document["id"], "processing", processing_stage="Extracting text")
# Determine file type for extraction
if document_id:
# For existing documents, determine file type from file extension
file_ext = file_path.suffix.lower()
file_type = self.supported_types.get(file_ext, 'text/plain')
else:
file_type = document["file_type"]
text_content = await self._extract_text(file_path, file_type)
# 4. Update document with extracted text
await self._update_document_content(document["id"], text_content)
# 5. Generate document summary
await self._update_processing_status(document["id"], "processing", processing_stage="Generating summary")
await self._generate_document_summary(document["id"], text_content, original_filename, user_id)
# 6. Chunk the document
await self._update_processing_status(document["id"], "processing", processing_stage="Creating chunks")
chunks = await self._chunk_text(text_content, document["id"])
# Set expected chunk count for progress tracking
await self._update_processing_status(
document["id"], "processing",
processing_stage="Preparing for embedding generation",
total_chunks_expected=len(chunks)
)
# 7. Generate embeddings
await self._update_processing_status(document["id"], "processing", processing_stage="Starting embedding generation")
await self._generate_embeddings_for_chunks(chunks, dataset_id, user_id)
# 8. Update final status
await self._update_processing_status(
document["id"], "completed",
processing_stage="Completed",
chunks_processed=len(chunks),
total_chunks_expected=len(chunks)
)
await self._update_chunk_count(document["id"], len(chunks))
# 9. Update dataset summary (after document is fully processed)
await self._update_dataset_summary_after_document_change(dataset_id, user_id)
logger.info(f"Successfully processed {original_filename} with {len(chunks)} chunks")
return document
except Exception as e:
logger.error(f"Error processing file {original_filename}: {e}")
if 'document' in locals():
await self._update_processing_status(
document["id"], "failed",
error_message=str(e),
processing_stage="Failed"
)
raise
async def _validate_file(self, file_path: Path):
"""Validate file size and type"""
if not file_path.exists():
raise ValueError("File does not exist")
file_size = file_path.stat().st_size
if file_size > self.max_file_size:
raise ValueError(f"File too large: {file_size} bytes (max: {self.max_file_size})")
file_ext = file_path.suffix.lower()
if file_ext not in self.supported_types:
raise ValueError(f"Unsupported file type: {file_ext}")
async def _create_document_record(
self,
file_path: Path,
dataset_id: str,
user_id: str,
original_filename: str
) -> Dict[str, Any]:
"""Create document record in database"""
# Calculate file hash
with open(file_path, 'rb') as f:
file_hash = hashlib.sha256(f.read()).hexdigest()
file_ext = file_path.suffix.lower()
file_size = file_path.stat().st_size
document_id = str(uuid.uuid4())
# Insert document record using raw SQL
# Note: tenant_id is nullable UUID, so we set it to NULL for individual documents
pg_client = await get_postgresql_client()
await pg_client.execute_command(
"""INSERT INTO documents (
id, user_id, dataset_id, filename, original_filename,
file_type, file_size_bytes, file_hash, processing_status
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)""",
document_id, str(user_id), dataset_id, str(file_path.name),
original_filename, self.supported_types[file_ext], file_size, file_hash, "pending"
)
return {
"id": document_id,
"user_id": user_id,
"dataset_id": dataset_id,
"filename": str(file_path.name),
"original_filename": original_filename,
"file_type": self.supported_types[file_ext],
"file_size_bytes": file_size,
"file_hash": file_hash,
"processing_status": "pending",
"chunk_count": 0
}
async def _extract_text(self, file_path: Path, file_type: str) -> str:
"""Extract text content from various file formats"""
try:
if file_type == 'application/pdf':
return await self._extract_pdf_text(file_path)
elif 'wordprocessingml' in file_type:
return await self._extract_docx_text(file_path)
elif file_type == 'text/csv':
return await self._extract_csv_text(file_path)
elif file_type == 'application/json':
return await self._extract_json_text(file_path)
else: # text/plain, text/markdown
return await self._extract_plain_text(file_path)
except Exception as e:
logger.error(f"Text extraction failed for {file_path}: {e}")
raise ValueError(f"Could not extract text from file: {e}")
async def _extract_pdf_text(self, file_path: Path) -> str:
"""Extract text from PDF file"""
text_parts = []
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page_num, page in enumerate(pdf_reader.pages):
try:
page_text = page.extract_text()
if page_text.strip():
text_parts.append(f"--- Page {page_num + 1} ---\n{page_text}")
except Exception as e:
logger.warning(f"Could not extract text from page {page_num + 1}: {e}")
if not text_parts:
raise ValueError("No text could be extracted from PDF")
return "\n\n".join(text_parts)
async def _extract_docx_text(self, file_path: Path) -> str:
"""Extract text from DOCX file"""
doc = docx.Document(file_path)
text_parts = []
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text_parts.append(paragraph.text)
if not text_parts:
raise ValueError("No text could be extracted from DOCX")
return "\n\n".join(text_parts)
async def _extract_csv_text(self, file_path: Path) -> str:
"""Extract and format text from CSV file"""
try:
df = pd.read_csv(file_path)
# Create readable format
text_parts = [f"CSV Data with {len(df)} rows and {len(df.columns)} columns"]
text_parts.append(f"Columns: {', '.join(df.columns.tolist())}")
text_parts.append("")
# Sample first few rows in readable format
for idx, row in df.head(20).iterrows():
row_text = []
for col in df.columns:
if pd.notna(row[col]):
row_text.append(f"{col}: {row[col]}")
text_parts.append(f"Row {idx + 1}: " + " | ".join(row_text))
return "\n".join(text_parts)
except Exception as e:
logger.error(f"CSV parsing error: {e}")
# Fallback to reading as plain text
return await self._extract_plain_text(file_path)
async def _extract_json_text(self, file_path: Path) -> str:
"""Extract and format text from JSON file"""
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Convert JSON to readable text format
def json_to_text(obj, prefix=""):
text_parts = []
if isinstance(obj, dict):
for key, value in obj.items():
if isinstance(value, (dict, list)):
text_parts.append(f"{prefix}{key}:")
text_parts.extend(json_to_text(value, prefix + " "))
else:
text_parts.append(f"{prefix}{key}: {value}")
elif isinstance(obj, list):
for i, item in enumerate(obj):
if isinstance(item, (dict, list)):
text_parts.append(f"{prefix}Item {i + 1}:")
text_parts.extend(json_to_text(item, prefix + " "))
else:
text_parts.append(f"{prefix}Item {i + 1}: {item}")
else:
text_parts.append(f"{prefix}{obj}")
return text_parts
return "\n".join(json_to_text(data))
async def _extract_plain_text(self, file_path: Path) -> str:
"""Extract text from plain text files"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except UnicodeDecodeError:
# Try with latin-1 encoding
with open(file_path, 'r', encoding='latin-1') as f:
return f.read()
async def extract_text_from_path(self, file_path: Path, content_type: str) -> str:
"""Public wrapper for text extraction from file path"""
return await self._extract_text(file_path, content_type)
async def chunk_text_simple(self, text: str) -> List[str]:
"""Public wrapper for simple text chunking without document_id"""
chunks = []
chunk_size = self.chunk_size * 4 # ~2048 chars
overlap = self.chunk_overlap * 4 # ~512 chars
for i in range(0, len(text), chunk_size - overlap):
chunk = text[i:i + chunk_size]
if chunk.strip():
chunks.append(chunk)
return chunks
async def _chunk_text(self, text: str, document_id: str) -> List[Dict[str, Any]]:
"""
Split text into overlapping chunks optimized for embeddings.
Returns:
List of chunk dictionaries with content and metadata
"""
# Simple sentence-aware chunking
sentences = re.split(r'[.!?]+', text)
sentences = [s.strip() for s in sentences if s.strip()]
chunks = []
current_chunk = ""
current_tokens = 0
chunk_index = 0
for sentence in sentences:
sentence_tokens = len(sentence.split())
# If adding this sentence would exceed chunk size, save current chunk
if current_tokens + sentence_tokens > self.chunk_size and current_chunk:
# Create chunk with overlap from previous chunk
chunk_content = current_chunk.strip()
if chunk_content:
chunks.append({
"document_id": document_id,
"chunk_index": chunk_index,
"content": chunk_content,
"token_count": current_tokens,
"content_hash": hashlib.md5(chunk_content.encode()).hexdigest()
})
chunk_index += 1
# Start new chunk with overlap
if self.chunk_overlap > 0 and chunks:
# Take last few sentences for overlap
overlap_sentences = current_chunk.split('.')[-2:] # Rough overlap
current_chunk = '. '.join(s.strip() for s in overlap_sentences if s.strip())
current_tokens = len(current_chunk.split())
else:
current_chunk = ""
current_tokens = 0
# Add sentence to current chunk
if current_chunk:
current_chunk += ". " + sentence
else:
current_chunk = sentence
current_tokens += sentence_tokens
# Add final chunk
if current_chunk.strip():
chunk_content = current_chunk.strip()
chunks.append({
"document_id": document_id,
"chunk_index": chunk_index,
"content": chunk_content,
"token_count": current_tokens,
"content_hash": hashlib.md5(chunk_content.encode()).hexdigest()
})
logger.info(f"Created {len(chunks)} chunks from document {document_id}")
return chunks
async def _generate_embeddings_for_chunks(
self,
chunks: List[Dict[str, Any]],
dataset_id: str,
user_id: str
):
"""
Generate embeddings for all chunks using concurrent batch processing.
Processes chunks in batches with controlled concurrency to optimize performance
while preventing system overload. Includes retry logic and progressive storage.
"""
if not chunks:
return
total_chunks = len(chunks)
document_id = chunks[0]["document_id"]
total_batches = (total_chunks + self.EMBEDDING_BATCH_SIZE - 1) // self.EMBEDDING_BATCH_SIZE
logger.info(f"Starting concurrent embedding generation for {total_chunks} chunks")
logger.info(f"Batch size: {self.EMBEDDING_BATCH_SIZE}, Total batches: {total_batches}, Max concurrent: {self.MAX_CONCURRENT_BATCHES}")
# Create semaphore to limit concurrent batches
semaphore = asyncio.Semaphore(self.MAX_CONCURRENT_BATCHES)
# Create batch data with metadata
batch_tasks = []
for batch_start in range(0, total_chunks, self.EMBEDDING_BATCH_SIZE):
batch_end = min(batch_start + self.EMBEDDING_BATCH_SIZE, total_chunks)
batch_chunks = chunks[batch_start:batch_end]
batch_num = (batch_start // self.EMBEDDING_BATCH_SIZE) + 1
batch_metadata = {
"chunks": batch_chunks,
"batch_num": batch_num,
"start_index": batch_start,
"end_index": batch_end,
"dataset_id": dataset_id,
"user_id": user_id,
"document_id": document_id
}
# Create task for this batch
task = self._process_batch_with_semaphore(semaphore, batch_metadata, total_batches, total_chunks)
batch_tasks.append(task)
# Process all batches concurrently
logger.info(f"Starting concurrent processing of {len(batch_tasks)} batches")
start_time = asyncio.get_event_loop().time()
results = await asyncio.gather(*batch_tasks, return_exceptions=True)
end_time = asyncio.get_event_loop().time()
processing_time = end_time - start_time
# Analyze results
successful_batches = []
failed_batches = []
for i, result in enumerate(results):
batch_num = i + 1
if isinstance(result, Exception):
failed_batches.append({
"batch_num": batch_num,
"error": str(result)
})
logger.error(f"Batch {batch_num} failed: {result}")
else:
successful_batches.append(result)
successful_chunks = sum(len(batch["chunks"]) for batch in successful_batches)
logger.info(f"Concurrent processing completed in {processing_time:.2f} seconds")
logger.info(f"Successfully processed {successful_chunks}/{total_chunks} chunks in {len(successful_batches)}/{total_batches} batches")
# Report final results
if failed_batches:
failed_chunk_count = total_chunks - successful_chunks
error_details = "; ".join([f"Batch {b['batch_num']}: {b['error']}" for b in failed_batches[:3]])
if len(failed_batches) > 3:
error_details += f" (and {len(failed_batches) - 3} more failures)"
raise ValueError(f"Failed to generate embeddings for {failed_chunk_count}/{total_chunks} chunks. Errors: {error_details}")
logger.info(f"Successfully stored all {total_chunks} chunks with embeddings")
async def _process_batch_with_semaphore(
self,
semaphore: asyncio.Semaphore,
batch_metadata: Dict[str, Any],
total_batches: int,
total_chunks: int
) -> Dict[str, Any]:
"""
Process a single batch with semaphore-controlled concurrency.
Args:
semaphore: Concurrency control semaphore
batch_metadata: Batch information including chunks and metadata
total_batches: Total number of batches
total_chunks: Total number of chunks
Returns:
Dict with batch processing results
"""
async with semaphore:
batch_chunks = batch_metadata["chunks"]
batch_num = batch_metadata["batch_num"]
dataset_id = batch_metadata["dataset_id"]
user_id = batch_metadata["user_id"]
document_id = batch_metadata["document_id"]
logger.info(f"Starting batch {batch_num}/{total_batches} ({len(batch_chunks)} chunks)")
try:
# Generate embeddings for this batch (pass user_id for billing)
embeddings = await self._generate_embedding_batch(batch_chunks, user_id=user_id)
# Store embeddings for this batch immediately
await self._store_chunk_embeddings(batch_chunks, embeddings, dataset_id, user_id)
# Update progress in database
progress_stage = f"Completed batch {batch_num}/{total_batches}"
# Calculate current progress (approximate since batches complete out of order)
await self._update_processing_status(
document_id, "processing",
processing_stage=progress_stage,
chunks_processed=batch_num * self.EMBEDDING_BATCH_SIZE, # Approximate
total_chunks_expected=total_chunks
)
logger.info(f"Successfully completed batch {batch_num}/{total_batches}")
return {
"batch_num": batch_num,
"chunks": batch_chunks,
"success": True
}
except Exception as e:
logger.error(f"Failed to process batch {batch_num}/{total_batches}: {e}")
raise ValueError(f"Batch {batch_num} failed: {str(e)}")
async def _generate_embedding_batch(
self,
batch_chunks: List[Dict[str, Any]],
user_id: Optional[str] = None
) -> List[List[float]]:
"""
Generate embeddings for a single batch of chunks with retry logic.
Args:
batch_chunks: List of chunk dictionaries
user_id: User ID for usage tracking
Returns:
List of embedding vectors
Raises:
ValueError: If embedding generation fails after all retries
"""
texts = [chunk["content"] for chunk in batch_chunks]
for attempt in range(self.MAX_RETRIES + 1):
try:
# Use the configurable embedding client with tenant/user context for billing
embeddings = await self.embedding_client.generate_embeddings(
texts,
tenant_id=self.tenant_domain,
user_id=str(user_id) if user_id else None
)
if len(embeddings) != len(texts):
raise ValueError(f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}")
return embeddings
except Exception as e:
if attempt < self.MAX_RETRIES:
delay = self.INITIAL_RETRY_DELAY * (2 ** attempt) # Exponential backoff
logger.warning(f"Embedding generation attempt {attempt + 1}/{self.MAX_RETRIES + 1} failed: {e}. Retrying in {delay}s...")
await asyncio.sleep(delay)
else:
logger.error(f"All {self.MAX_RETRIES + 1} embedding generation attempts failed. Final error: {e}")
logger.error(f"Failed request details: URL=http://gentwo-vllm-embeddings:8000/v1/embeddings, texts_count={len(texts)}")
raise ValueError(f"Embedding generation failed after {self.MAX_RETRIES + 1} attempts: {str(e)}")
async def _store_chunk_embeddings(
self,
batch_chunks: List[Dict[str, Any]],
embeddings: List[List[float]],
dataset_id: str,
user_id: str
):
"""Store chunk embeddings in database with proper error handling."""
pg_client = await get_postgresql_client()
for chunk_data, embedding in zip(batch_chunks, embeddings):
chunk_id = str(uuid.uuid4())
# Convert embedding list to PostgreSQL array format
embedding_array = f"[{','.join(map(str, embedding))}]" if embedding else None
await pg_client.execute_command(
"""INSERT INTO document_chunks (
id, document_id, user_id, dataset_id, chunk_index,
content, content_hash, token_count, embedding
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector)""",
chunk_id, chunk_data["document_id"], str(user_id),
dataset_id, chunk_data["chunk_index"], chunk_data["content"],
chunk_data["content_hash"], chunk_data["token_count"], embedding_array
)
async def _update_processing_status(
self,
document_id: str,
status: str,
error_message: Optional[str] = None,
processing_stage: Optional[str] = None,
chunks_processed: Optional[int] = None,
total_chunks_expected: Optional[int] = None
):
"""Update document processing status with progress tracking via metadata JSONB"""
# Calculate progress percentage if we have the data
processing_progress = None
if chunks_processed is not None and total_chunks_expected is not None and total_chunks_expected > 0:
processing_progress = min(100, int((chunks_processed / total_chunks_expected) * 100))
# Build progress metadata object
import json
progress_data = {}
if processing_stage is not None:
progress_data['processing_stage'] = processing_stage
if chunks_processed is not None:
progress_data['chunks_processed'] = chunks_processed
if total_chunks_expected is not None:
progress_data['total_chunks_expected'] = total_chunks_expected
if processing_progress is not None:
progress_data['processing_progress'] = processing_progress
pg_client = await get_postgresql_client()
if error_message:
await pg_client.execute_command(
"""UPDATE documents SET
processing_status = $1,
error_message = $2,
metadata = COALESCE(metadata, '{}'::jsonb) || $3::jsonb,
updated_at = NOW()
WHERE id = $4""",
status, error_message, json.dumps(progress_data), document_id
)
else:
await pg_client.execute_command(
"""UPDATE documents SET
processing_status = $1,
metadata = COALESCE(metadata, '{}'::jsonb) || $2::jsonb,
updated_at = NOW()
WHERE id = $3""",
status, json.dumps(progress_data), document_id
)
    async def _get_existing_document_content(self, document_id: str) -> tuple[Optional[str], Optional[str]]:
"""Get existing document content and storage type"""
pg_client = await get_postgresql_client()
result = await pg_client.fetch_one(
"SELECT content_text, metadata FROM documents WHERE id = $1",
document_id
)
if result and result["content_text"]:
# Handle metadata - might be JSON string or dict
metadata_raw = result["metadata"] or "{}"
if isinstance(metadata_raw, str):
import json
try:
metadata = json.loads(metadata_raw)
except json.JSONDecodeError:
metadata = {}
else:
metadata = metadata_raw or {}
storage_type = metadata.get("storage_type", "unknown")
return result["content_text"], storage_type
return None, None
async def _update_document_content(self, document_id: str, content: str):
"""Update document with extracted text content"""
pg_client = await get_postgresql_client()
await pg_client.execute_command(
"UPDATE documents SET content_text = $1, updated_at = NOW() WHERE id = $2",
content, document_id
)
async def _update_chunk_count(self, document_id: str, chunk_count: int):
"""Update document with final chunk count"""
pg_client = await get_postgresql_client()
await pg_client.execute_command(
"UPDATE documents SET chunk_count = $1, updated_at = NOW() WHERE id = $2",
chunk_count, document_id
)
async def _generate_document_summary(
self,
document_id: str,
content: str,
filename: str,
user_id: str
):
"""Generate and store AI summary for the document"""
try:
# Use tenant_domain from instance context
tenant_domain = self.tenant_domain
# Create summarization service instance
summarization_service = SummarizationService(tenant_domain, user_id)
# Generate summary using our new service
summary = await summarization_service.generate_document_summary(
document_id=document_id,
document_content=content,
document_name=filename
)
if summary:
logger.info(f"Generated summary for document {document_id}: {summary[:100]}...")
else:
logger.warning(f"Failed to generate summary for document {document_id}")
except Exception as e:
logger.error(f"Error generating document summary for {document_id}: {e}")
# Don't fail the entire document processing if summarization fails
async def _update_dataset_summary_after_document_change(
self,
dataset_id: str,
user_id: str
):
"""Update dataset summary after a document is added or removed"""
try:
# Create summarization service instance
summarization_service = SummarizationService(self.tenant_domain, user_id)
# Update dataset summary asynchronously (don't block document processing)
asyncio.create_task(
summarization_service.update_dataset_summary_on_change(dataset_id)
)
logger.info(f"Triggered dataset summary update for dataset {dataset_id}")
except Exception as e:
logger.error(f"Error triggering dataset summary update for {dataset_id}: {e}")
# Don't fail document processing if dataset summary update fails
async def get_processing_status(self, document_id: str) -> Dict[str, Any]:
"""Get current processing status of a document with progress information from metadata"""
pg_client = await get_postgresql_client()
result = await pg_client.fetch_one(
"""SELECT processing_status, error_message, chunk_count, metadata
FROM documents WHERE id = $1""",
document_id
)
if not result:
raise ValueError("Document not found")
# Extract progress data from metadata JSONB
metadata = result["metadata"] or {}
return {
"status": result["processing_status"],
"error_message": result["error_message"],
"chunk_count": result["chunk_count"],
"chunks_processed": metadata.get("chunks_processed"),
"total_chunks_expected": metadata.get("total_chunks_expected"),
"processing_progress": metadata.get("processing_progress"),
"processing_stage": metadata.get("processing_stage")
}
# Factory function for document processor
async def get_document_processor(tenant_domain=None):
"""Get document processor instance (will create its own DB session when needed)"""
return DocumentProcessor(tenant_domain=tenant_domain)
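# Illustrative sketch (not used by the service): the concurrency pattern behind
# _process_batch_with_semaphore. A semaphore caps how many batches are in flight
# at once while asyncio.gather collects every result; the batch payloads and the
# max_concurrency value here are hypothetical.
async def _example_bounded_batch_run(batches: List[List[str]], max_concurrency: int = 4) -> List[Any]:
    semaphore = asyncio.Semaphore(max_concurrency)
    async def worker(batch_num: int, batch: List[str]) -> Dict[str, Any]:
        async with semaphore:
            # The real worker embeds and stores the batch; here we only simulate the work
            await asyncio.sleep(0.01)
            return {"batch_num": batch_num, "size": len(batch), "success": True}
    tasks = [worker(i + 1, batch) for i, batch in enumerate(batches)]
    # return_exceptions=True lets the caller tally failed batches instead of aborting
    return await asyncio.gather(*tasks, return_exceptions=True)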

View File

@@ -0,0 +1,317 @@
"""
Document Summarization Service for GT 2.0
Generates AI-powered summaries for uploaded documents using the Resource Cluster.
Provides both quick summaries and detailed analysis for RAG visualization.
"""
import logging
import asyncio
import httpx
from typing import Dict, Any, Optional, List
from datetime import datetime
from app.core.database import get_db_session, execute_command, fetch_one
logger = logging.getLogger(__name__)
class DocumentSummarizer:
"""
Service for generating document summaries using Resource Cluster LLM.
Features:
- Quick document summaries (2-3 sentences)
- Detailed analysis with key topics and themes
- Metadata extraction (document type, language, etc.)
- Integration with document processor workflow
"""
def __init__(self):
self.resource_cluster_url = "http://gentwo-resource-backend:8000"
self.max_content_length = 4000 # Max chars to send for summarization
async def generate_document_summary(
self,
document_id: str,
content: str,
filename: str,
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""
Generate a comprehensive summary for a document.
Args:
document_id: Document ID in the database
content: Document text content
filename: Original filename
tenant_domain: Tenant domain for context
user_id: User who uploaded the document
Returns:
Dictionary with summary data including quick_summary, detailed_analysis,
topics, metadata, and confidence scores
"""
try:
# Truncate content if too long
truncated_content = content[:self.max_content_length]
if len(content) > self.max_content_length:
truncated_content += "... [content truncated]"
# Generate summary using Resource Cluster LLM
summary_data = await self._call_llm_for_summary(
content=truncated_content,
filename=filename,
document_type=self._detect_document_type(filename)
)
# Store summary in database
await self._store_document_summary(
document_id=document_id,
summary_data=summary_data,
tenant_domain=tenant_domain,
user_id=user_id
)
logger.info(f"Generated summary for document {document_id}: {filename}")
return summary_data
except Exception as e:
logger.error(f"Failed to generate summary for document {document_id}: {e}")
# Return basic fallback summary
return {
"quick_summary": f"Document: {filename}",
"detailed_analysis": "Summary generation failed",
"topics": [],
"metadata": {
"document_type": self._detect_document_type(filename),
"estimated_read_time": len(content) // 200, # ~200 words per minute
"character_count": len(content),
"language": "unknown"
},
"confidence": 0.0,
"error": str(e)
}
async def _call_llm_for_summary(
self,
content: str,
filename: str,
document_type: str
) -> Dict[str, Any]:
"""Call Resource Cluster LLM to generate document summary"""
prompt = f"""Analyze this {document_type} document and provide a comprehensive summary.
Document: {filename}
Content:
{content}
Please provide:
1. A concise 2-3 sentence summary
2. Key topics and themes (list)
3. Document analysis including tone, purpose, and target audience
4. Estimated language and reading level
Format your response as JSON with these keys:
- quick_summary: Brief 2-3 sentence overview
- detailed_analysis: Paragraph with deeper insights
- topics: Array of key topics/themes
- metadata: Object with language, tone, purpose, target_audience
- confidence: Float 0-1 indicating analysis confidence"""
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.resource_cluster_url}/api/v1/ai/chat/completions",
json={
"model": "llama-3.1-8b-instant",
"messages": [
{
"role": "system",
"content": "You are a document analysis expert. Provide accurate, concise summaries in valid JSON format."
},
{
"role": "user",
"content": prompt
}
],
"temperature": 0.3,
"max_tokens": 1000
},
headers={
"X-Tenant-ID": "default",
"Content-Type": "application/json"
}
)
if response.status_code == 200:
llm_response = response.json()
content_text = llm_response["choices"][0]["message"]["content"]
# Try to parse JSON response
try:
import json
summary_data = json.loads(content_text)
# Validate required fields and add defaults if missing
return {
"quick_summary": summary_data.get("quick_summary", f"Analysis of {filename}"),
"detailed_analysis": summary_data.get("detailed_analysis", "Detailed analysis not available"),
"topics": summary_data.get("topics", []),
"metadata": {
**summary_data.get("metadata", {}),
"document_type": document_type,
"generated_at": datetime.utcnow().isoformat()
},
"confidence": min(1.0, max(0.0, summary_data.get("confidence", 0.7)))
}
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON
return {
"quick_summary": content_text[:200] + "..." if len(content_text) > 200 else content_text,
"detailed_analysis": content_text,
"topics": [],
"metadata": {
"document_type": document_type,
"generated_at": datetime.utcnow().isoformat(),
"note": "Summary extracted from free-form LLM response"
},
"confidence": 0.5
}
else:
raise Exception(f"Resource Cluster API error: {response.status_code}")
except Exception as e:
logger.error(f"LLM summarization failed: {e}")
raise
async def _store_document_summary(
self,
document_id: str,
summary_data: Dict[str, Any],
tenant_domain: str,
user_id: str
):
"""Store generated summary in database"""
# Use the same database session pattern as document processor
async with get_db_session(tenant_domain) as session:
try:
# Insert or update document summary
query = """
INSERT INTO document_summaries (
document_id, user_id, quick_summary, detailed_analysis,
topics, metadata, confidence, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (document_id)
DO UPDATE SET
quick_summary = EXCLUDED.quick_summary,
detailed_analysis = EXCLUDED.detailed_analysis,
topics = EXCLUDED.topics,
metadata = EXCLUDED.metadata,
confidence = EXCLUDED.confidence,
updated_at = EXCLUDED.updated_at
"""
import json
await execute_command(
session,
query,
document_id,
user_id,
summary_data["quick_summary"],
summary_data["detailed_analysis"],
json.dumps(summary_data["topics"]),
json.dumps(summary_data["metadata"]),
summary_data["confidence"],
datetime.utcnow(),
datetime.utcnow()
)
logger.info(f"Stored summary for document {document_id}")
except Exception as e:
logger.error(f"Failed to store document summary: {e}")
raise
def _detect_document_type(self, filename: str) -> str:
"""Detect document type from filename extension"""
import pathlib
extension = pathlib.Path(filename).suffix.lower()
type_mapping = {
'.pdf': 'PDF document',
'.docx': 'Word document',
'.doc': 'Word document',
'.txt': 'Text file',
'.md': 'Markdown document',
'.csv': 'CSV data file',
'.json': 'JSON data file',
'.html': 'HTML document',
'.htm': 'HTML document',
'.rtf': 'Rich text document'
}
return type_mapping.get(extension, 'Unknown document type')
async def get_document_summary(
self,
document_id: str,
tenant_domain: str
) -> Optional[Dict[str, Any]]:
"""Retrieve stored document summary"""
async with get_db_session(tenant_domain) as session:
try:
query = """
SELECT quick_summary, detailed_analysis, topics, metadata,
confidence, created_at, updated_at
FROM document_summaries
WHERE document_id = $1
"""
result = await fetch_one(session, query, document_id)
if result:
import json
return {
"quick_summary": result["quick_summary"],
"detailed_analysis": result["detailed_analysis"],
"topics": json.loads(result["topics"]) if result["topics"] else [],
"metadata": json.loads(result["metadata"]) if result["metadata"] else {},
"confidence": result["confidence"],
"created_at": result["created_at"].isoformat(),
"updated_at": result["updated_at"].isoformat()
}
return None
except Exception as e:
logger.error(f"Failed to retrieve document summary: {e}")
return None
# Global instance
document_summarizer = DocumentSummarizer()
async def generate_document_summary(
document_id: str,
content: str,
filename: str,
tenant_domain: str,
user_id: str
) -> Dict[str, Any]:
"""Convenience function for document summary generation"""
return await document_summarizer.generate_document_summary(
document_id, content, filename, tenant_domain, user_id
)
async def get_document_summary(document_id: str, tenant_domain: str) -> Optional[Dict[str, Any]]:
"""Convenience function for retrieving document summary"""
return await document_summarizer.get_document_summary(document_id, tenant_domain)
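# Minimal usage sketch for the convenience wrapper above. The ids, tenant domain,
# and content are placeholder values; in the real pipeline the document processor
# supplies them after text extraction.
async def _example_summarize_call() -> Dict[str, Any]:
    summary = await generate_document_summary(
        document_id="00000000-0000-0000-0000-000000000000",  # hypothetical id
        content="GT 2.0 release notes describing the new RAG pipeline...",
        filename="release_notes.txt",
        tenant_domain="example-tenant",
        user_id="user-123",
    )
    # quick_summary is always present, even when the LLM call falls back
    logger.info(f"Quick summary: {summary['quick_summary']}")
    return summary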

View File

@@ -0,0 +1,286 @@
"""
BGE-M3 Embedding Client for GT 2.0
Simple client for the vLLM BGE-M3 embedding service running on port 8005.
Provides text embedding generation for RAG pipeline.
"""
import asyncio
import logging
from typing import List, Dict, Any, Optional
import httpx
logger = logging.getLogger(__name__)
class BGE_M3_EmbeddingClient:
"""
Simple client for BGE-M3 embedding service via vLLM.
Features:
- Async HTTP client for embeddings
- Batch processing support
- Error handling and retries
- OpenAI-compatible API format
"""
def __init__(self, base_url: str = None):
# Determine base URL from environment or configuration
if base_url is None:
base_url = self._get_embedding_endpoint()
self.base_url = base_url
self.model = "BAAI/bge-m3"
self.embedding_dimensions = 1024
self.max_batch_size = 32
# Initialize BGE-M3 tokenizer for accurate token counting
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
logger.info("Initialized BGE-M3 tokenizer for accurate token counting")
except Exception as e:
logger.warning(f"Failed to load BGE-M3 tokenizer, using word estimation: {e}")
self.tokenizer = None
def _get_embedding_endpoint(self) -> str:
"""
Get the BGE-M3 endpoint based on configuration.
This should sync with the control panel configuration.
"""
import os
# Check environment variables for BGE-M3 configuration
is_local_mode = os.getenv('BGE_M3_LOCAL_MODE', 'true').lower() == 'true'
external_endpoint = os.getenv('BGE_M3_EXTERNAL_ENDPOINT')
if not is_local_mode and external_endpoint:
return external_endpoint
# Default to local endpoint
return os.getenv('EMBEDDING_ENDPOINT', 'http://host.docker.internal:8005')
def update_endpoint(self, new_endpoint: str):
"""Update the embedding endpoint dynamically"""
self.base_url = new_endpoint
logger.info(f"BGE-M3 client endpoint updated to: {new_endpoint}")
async def health_check(self) -> bool:
"""Check if BGE-M3 service is responding"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(f"{self.base_url}/v1/models")
if response.status_code == 200:
models = response.json()
model_ids = [model['id'] for model in models.get('data', [])]
return self.model in model_ids
return False
except Exception as e:
logger.error(f"Health check failed: {e}")
return False
async def generate_embeddings(
self,
texts: List[str],
tenant_id: Optional[str] = None,
user_id: Optional[str] = None,
request_id: Optional[str] = None
) -> List[List[float]]:
"""
Generate embeddings for a list of texts using BGE-M3.
Args:
texts: List of text strings to embed
tenant_id: Tenant ID for usage tracking (optional)
user_id: User ID for usage tracking (optional)
request_id: Request ID for tracking (optional)
Returns:
List of embedding vectors (each is a list of 1024 floats)
Raises:
ValueError: If embedding generation fails
"""
if not texts:
return []
if len(texts) > self.max_batch_size:
# Process in batches
all_embeddings = []
for i in range(0, len(texts), self.max_batch_size):
batch = texts[i:i + self.max_batch_size]
batch_embeddings = await self._generate_batch(batch)
all_embeddings.extend(batch_embeddings)
embeddings = all_embeddings
else:
embeddings = await self._generate_batch(texts)
# Log usage if tenant context provided (fire and forget)
if tenant_id and user_id:
import asyncio
tokens_used = self._count_tokens(texts)
asyncio.create_task(
self._log_embedding_usage(
tenant_id=tenant_id,
user_id=user_id,
tokens_used=tokens_used,
embedding_count=len(embeddings),
request_id=request_id
)
)
return embeddings
async def _generate_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for a single batch"""
try:
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
f"{self.base_url}/v1/embeddings",
json={
"input": texts,
"model": self.model
}
)
if response.status_code == 200:
data = response.json()
# Extract embeddings from OpenAI-compatible response
embeddings = []
for item in data.get("data", []):
embedding = item.get("embedding", [])
if len(embedding) != self.embedding_dimensions:
raise ValueError(f"Invalid embedding dimensions: {len(embedding)} (expected {self.embedding_dimensions})")
embeddings.append(embedding)
logger.info(f"Generated {len(embeddings)} embeddings")
return embeddings
else:
error_text = response.text
logger.error(f"Embedding generation failed: {response.status_code} - {error_text}")
raise ValueError(f"Embedding generation failed: {response.status_code}")
except httpx.TimeoutException:
logger.error("Embedding generation timed out")
raise ValueError("Embedding generation timed out")
except Exception as e:
logger.error(f"Error calling embedding service: {e}")
raise ValueError(f"Embedding service error: {str(e)}")
def _count_tokens(self, texts: List[str]) -> int:
"""Count tokens using actual BGE-M3 tokenizer."""
if self.tokenizer is not None:
try:
total_tokens = 0
for text in texts:
tokens = self.tokenizer.encode(text, add_special_tokens=False)
total_tokens += len(tokens)
return total_tokens
except Exception as e:
logger.warning(f"Tokenizer error, falling back to estimation: {e}")
# Fallback: word count * 1.3
total_words = sum(len(text.split()) for text in texts)
return int(total_words * 1.3)
async def _log_embedding_usage(
self,
tenant_id: str,
user_id: str,
tokens_used: int,
embedding_count: int,
request_id: Optional[str] = None
) -> None:
"""Log embedding usage to control panel database for billing."""
try:
import asyncpg
import os
# Calculate cost: BGE-M3 pricing ~$0.10 per million tokens
cost_cents = (tokens_used / 1_000_000) * 0.10 * 100
db_password = os.getenv("CONTROL_PANEL_DB_PASSWORD")
if not db_password:
logger.warning("CONTROL_PANEL_DB_PASSWORD not set, skipping embedding usage logging")
return
conn = await asyncpg.connect(
host=os.getenv("CONTROL_PANEL_DB_HOST", "gentwo-controlpanel-postgres"),
database=os.getenv("CONTROL_PANEL_DB_NAME", "gt2_admin"),
user=os.getenv("CONTROL_PANEL_DB_USER", "postgres"),
password=db_password,
timeout=5.0
)
try:
await conn.execute("""
INSERT INTO public.embedding_usage_logs
(tenant_id, user_id, tokens_used, embedding_count, model, cost_cents, request_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""", tenant_id, user_id, tokens_used, embedding_count, self.model, cost_cents, request_id)
logger.info(
f"Logged embedding usage: tenant={tenant_id}, user={user_id}, "
f"tokens={tokens_used}, embeddings={embedding_count}, cost_cents={cost_cents:.4f}"
)
finally:
await conn.close()
except Exception as e:
logger.warning(f"Failed to log embedding usage: {e}")
async def generate_single_embedding(self, text: str) -> List[float]:
"""
Generate embedding for a single text.
Args:
text: Text string to embed
Returns:
Embedding vector (list of 1024 floats)
"""
embeddings = await self.generate_embeddings([text])
return embeddings[0] if embeddings else []
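# Arithmetic sketch of the batching rule in generate_embeddings(): inputs larger
# than max_batch_size are split into consecutive slices. No service call is made;
# the counts below are examples only.
def _example_batch_sizes(num_texts: int = 70, max_batch_size: int = 32) -> List[int]:
    """Return the slice sizes the client would submit, e.g. [32, 32, 6] for 70 texts."""
    sizes = []
    for start in range(0, num_texts, max_batch_size):
        sizes.append(min(max_batch_size, num_texts - start))
    return sizes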
# Global client instance
_embedding_client: Optional[BGE_M3_EmbeddingClient] = None
def get_embedding_client() -> BGE_M3_EmbeddingClient:
"""Get or create global embedding client instance"""
global _embedding_client
if _embedding_client is None:
_embedding_client = BGE_M3_EmbeddingClient()
else:
# Always refresh the endpoint from current configuration
current_endpoint = _embedding_client._get_embedding_endpoint()
if _embedding_client.base_url != current_endpoint:
_embedding_client.base_url = current_endpoint
logger.info(f"BGE-M3 client endpoint refreshed to: {current_endpoint}")
return _embedding_client
async def test_embedding_client():
"""Test function for the embedding client"""
client = get_embedding_client()
# Test health check
is_healthy = await client.health_check()
print(f"BGE-M3 service healthy: {is_healthy}")
if is_healthy:
# Test embedding generation
test_texts = [
"This is a test document about machine learning.",
"GT 2.0 is an enterprise AI platform.",
"Vector embeddings enable semantic search."
]
embeddings = await client.generate_embeddings(test_texts)
print(f"Generated {len(embeddings)} embeddings")
print(f"Embedding dimensions: {len(embeddings[0]) if embeddings else 0}")
if __name__ == "__main__":
asyncio.run(test_embedding_client())

View File

@@ -0,0 +1,722 @@
"""
Enhanced API Key Management Service for GT 2.0
Implements advanced API key management with capability-based permissions,
configurable constraints, and comprehensive audit logging.
"""
import os
import stat
import json
import secrets
import hashlib
import logging
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, field
from enum import Enum
from uuid import uuid4
import jwt
from app.core.security import verify_capability_token
logger = logging.getLogger(__name__)
class APIKeyStatus(Enum):
"""API key status states"""
ACTIVE = "active"
SUSPENDED = "suspended"
EXPIRED = "expired"
REVOKED = "revoked"
class APIKeyScope(Enum):
"""API key scope levels"""
USER = "user" # User-specific operations
TENANT = "tenant" # Tenant-wide operations
ADMIN = "admin" # Administrative operations
@dataclass
class APIKeyUsage:
"""API key usage tracking"""
requests_count: int = 0
last_used: Optional[datetime] = None
bytes_transferred: int = 0
errors_count: int = 0
rate_limit_hits: int = 0
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
return {
"requests_count": self.requests_count,
"last_used": self.last_used.isoformat() if self.last_used else None,
"bytes_transferred": self.bytes_transferred,
"errors_count": self.errors_count,
"rate_limit_hits": self.rate_limit_hits
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "APIKeyUsage":
"""Create from dictionary"""
return cls(
requests_count=data.get("requests_count", 0),
last_used=datetime.fromisoformat(data["last_used"]) if data.get("last_used") else None,
bytes_transferred=data.get("bytes_transferred", 0),
errors_count=data.get("errors_count", 0),
rate_limit_hits=data.get("rate_limit_hits", 0)
)
@dataclass
class APIKeyConfig:
"""Enhanced API key configuration"""
id: str = field(default_factory=lambda: str(uuid4()))
name: str = ""
description: str = ""
owner_id: str = ""
key_hash: str = ""
# Capability and permissions
capabilities: List[str] = field(default_factory=list)
scope: APIKeyScope = APIKeyScope.USER
tenant_constraints: Dict[str, Any] = field(default_factory=dict)
# Rate limiting and quotas
rate_limit_per_hour: int = 1000
daily_quota: int = 10000
monthly_quota: int = 300000
cost_limit_cents: int = 1000
# Resource constraints
max_tokens_per_request: int = 4000
max_concurrent_requests: int = 10
allowed_endpoints: List[str] = field(default_factory=list)
blocked_endpoints: List[str] = field(default_factory=list)
# Network and security
allowed_ips: List[str] = field(default_factory=list)
allowed_domains: List[str] = field(default_factory=list)
require_tls: bool = True
# Lifecycle management
status: APIKeyStatus = APIKeyStatus.ACTIVE
created_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
expires_at: Optional[datetime] = None
last_rotated: Optional[datetime] = None
# Usage tracking
usage: APIKeyUsage = field(default_factory=APIKeyUsage)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
return {
"id": self.id,
"name": self.name,
"description": self.description,
"owner_id": self.owner_id,
"key_hash": self.key_hash,
"capabilities": self.capabilities,
"scope": self.scope.value,
"tenant_constraints": self.tenant_constraints,
"rate_limit_per_hour": self.rate_limit_per_hour,
"daily_quota": self.daily_quota,
"monthly_quota": self.monthly_quota,
"cost_limit_cents": self.cost_limit_cents,
"max_tokens_per_request": self.max_tokens_per_request,
"max_concurrent_requests": self.max_concurrent_requests,
"allowed_endpoints": self.allowed_endpoints,
"blocked_endpoints": self.blocked_endpoints,
"allowed_ips": self.allowed_ips,
"allowed_domains": self.allowed_domains,
"require_tls": self.require_tls,
"status": self.status.value,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"last_rotated": self.last_rotated.isoformat() if self.last_rotated else None,
"usage": self.usage.to_dict()
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "APIKeyConfig":
"""Create from dictionary"""
return cls(
id=data["id"],
name=data["name"],
description=data.get("description", ""),
owner_id=data["owner_id"],
key_hash=data["key_hash"],
capabilities=data.get("capabilities", []),
scope=APIKeyScope(data.get("scope", "user")),
tenant_constraints=data.get("tenant_constraints", {}),
rate_limit_per_hour=data.get("rate_limit_per_hour", 1000),
daily_quota=data.get("daily_quota", 10000),
monthly_quota=data.get("monthly_quota", 300000),
cost_limit_cents=data.get("cost_limit_cents", 1000),
max_tokens_per_request=data.get("max_tokens_per_request", 4000),
max_concurrent_requests=data.get("max_concurrent_requests", 10),
allowed_endpoints=data.get("allowed_endpoints", []),
blocked_endpoints=data.get("blocked_endpoints", []),
allowed_ips=data.get("allowed_ips", []),
allowed_domains=data.get("allowed_domains", []),
require_tls=data.get("require_tls", True),
status=APIKeyStatus(data.get("status", "active")),
created_at=datetime.fromisoformat(data["created_at"]),
updated_at=datetime.fromisoformat(data["updated_at"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
last_rotated=datetime.fromisoformat(data["last_rotated"]) if data.get("last_rotated") else None,
usage=APIKeyUsage.from_dict(data.get("usage", {}))
)
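# Round-trip sketch: APIKeyConfig persists as a JSON-safe dict via to_dict() and is
# rebuilt with from_dict(), which is how keys are stored under the tenant's
# api_keys/keys/ directory. The name, owner, and capabilities are placeholders.
def _example_config_round_trip() -> bool:
    config = APIKeyConfig(
        name="ci-pipeline",
        owner_id="user-123",
        capabilities=["rag:query", "documents:read"],
        scope=APIKeyScope.TENANT,
    )
    restored = APIKeyConfig.from_dict(config.to_dict())
    return restored.id == config.id and restored.scope is APIKeyScope.TENANT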
class EnhancedAPIKeyService:
"""
Enhanced API Key management service with advanced capabilities.
Features:
- Capability-based permissions with tenant constraints
- Granular rate limiting and quota management
- Network-based access controls (IP, domain restrictions)
- Comprehensive usage tracking and analytics
- Automated key rotation and lifecycle management
- Perfect tenant isolation through file-based storage
"""
def __init__(self, tenant_domain: str, signing_key: str = ""):
self.tenant_domain = tenant_domain
self.signing_key = signing_key or self._generate_signing_key()
self.base_path = Path(f"/data/{tenant_domain}/api_keys")
self.keys_path = self.base_path / "keys"
self.usage_path = self.base_path / "usage"
self.audit_path = self.base_path / "audit"
# Ensure directories exist with proper permissions
self._ensure_directories()
logger.info(f"EnhancedAPIKeyService initialized for {tenant_domain}")
def _ensure_directories(self):
"""Ensure API key directories exist with proper permissions"""
for path in [self.keys_path, self.usage_path, self.audit_path]:
path.mkdir(parents=True, exist_ok=True)
# Set permissions to 700 (owner only)
os.chmod(path, stat.S_IRWXU)
def _generate_signing_key(self) -> str:
"""Generate cryptographic signing key for JWT tokens"""
return secrets.token_urlsafe(64)
async def create_api_key(
self,
name: str,
owner_id: str,
capabilities: List[str],
scope: APIKeyScope = APIKeyScope.USER,
expires_in_days: int = 90,
constraints: Optional[Dict[str, Any]] = None,
capability_token: str = ""
) -> Tuple[APIKeyConfig, str]:
"""
Create a new API key with specified capabilities and constraints.
Args:
name: Human-readable name for the key
owner_id: User who owns the key
capabilities: List of capability strings
scope: Key scope level
expires_in_days: Expiration time in days
constraints: Custom constraints for the key
capability_token: Admin capability token
Returns:
Tuple of (APIKeyConfig, raw_key)
"""
# Verify admin capability for key creation
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Generate secure API key
raw_key = f"gt2_{self.tenant_domain}_{secrets.token_urlsafe(32)}"
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
# Apply constraints with tenant-specific defaults
final_constraints = self._apply_tenant_defaults(constraints or {})
# Create API key configuration
api_key = APIKeyConfig(
name=name,
owner_id=owner_id,
key_hash=key_hash,
capabilities=capabilities,
scope=scope,
tenant_constraints=final_constraints,
expires_at=datetime.utcnow() + timedelta(days=expires_in_days)
)
# Apply scope-based defaults
self._apply_scope_defaults(api_key, scope)
# Store API key
await self._store_api_key(api_key)
# Log creation
await self._audit_log("api_key_created", owner_id, {
"key_id": api_key.id,
"name": name,
"scope": scope.value,
"capabilities": capabilities
})
logger.info(f"Created API key: {name} ({api_key.id}) for {owner_id}")
return api_key, raw_key
async def validate_api_key(
self,
raw_key: str,
endpoint: str = "",
client_ip: str = "",
user_agent: str = ""
) -> Tuple[bool, Optional[APIKeyConfig], Optional[str]]:
"""
Validate API key and check constraints.
Args:
raw_key: Raw API key from request
endpoint: Requested endpoint
client_ip: Client IP address
user_agent: Client user agent
Returns:
Tuple of (valid, api_key_config, error_message)
"""
# Hash the key for lookup
# Security Note: SHA256 is used here for API key lookup/indexing, not password storage.
# API keys are high-entropy random strings, making them resistant to dictionary/rainbow attacks.
# This is an acceptable security pattern similar to how GitHub and Stripe handle API keys.
key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
# Load API key configuration
api_key = await self._load_api_key_by_hash(key_hash)
if not api_key:
return False, None, "Invalid API key"
# Check key status
if api_key.status != APIKeyStatus.ACTIVE:
return False, api_key, f"API key is {api_key.status.value}"
# Check expiration
if api_key.expires_at and datetime.utcnow() > api_key.expires_at:
# Auto-expire the key
api_key.status = APIKeyStatus.EXPIRED
await self._store_api_key(api_key)
return False, api_key, "API key has expired"
# Check endpoint restrictions
if api_key.allowed_endpoints:
if endpoint not in api_key.allowed_endpoints:
return False, api_key, f"Endpoint {endpoint} not allowed"
if endpoint in api_key.blocked_endpoints:
return False, api_key, f"Endpoint {endpoint} is blocked"
# Check IP restrictions
if api_key.allowed_ips and client_ip not in api_key.allowed_ips:
return False, api_key, f"IP {client_ip} not allowed"
# Check rate limits
rate_limit_ok, rate_error = await self._check_rate_limits(api_key)
if not rate_limit_ok:
return False, api_key, rate_error
# Update usage
await self._update_usage(api_key, endpoint, client_ip)
return True, api_key, None
async def generate_capability_token(
self,
api_key: APIKeyConfig,
additional_context: Optional[Dict[str, Any]] = None
) -> str:
"""
Generate JWT capability token from API key.
Args:
api_key: API key configuration
additional_context: Additional context for the token
Returns:
JWT capability token
"""
# Build capability payload
capabilities = []
for cap_string in api_key.capabilities:
capability = {
"resource": cap_string,
"actions": ["*"], # API keys get full action access for their capabilities
"constraints": api_key.tenant_constraints.get(cap_string, {})
}
capabilities.append(capability)
# Create JWT payload
payload = {
"sub": api_key.owner_id,
"tenant_id": self.tenant_domain,
"api_key_id": api_key.id,
"scope": api_key.scope.value,
"capabilities": capabilities,
"constraints": api_key.tenant_constraints,
"rate_limits": {
"requests_per_hour": api_key.rate_limit_per_hour,
"max_tokens_per_request": api_key.max_tokens_per_request,
"cost_limit_cents": api_key.cost_limit_cents
},
"iat": int(datetime.utcnow().timestamp()),
"exp": int((datetime.utcnow() + timedelta(hours=1)).timestamp())
}
# Add additional context
if additional_context:
payload.update(additional_context)
# Sign and return token
token = jwt.encode(payload, self.signing_key, algorithm="HS256")
return token
async def rotate_api_key(
self,
key_id: str,
owner_id: str,
capability_token: str
) -> Tuple[APIKeyConfig, str]:
"""
Rotate API key (generate new key value).
Args:
key_id: API key ID to rotate
owner_id: Owner of the key
capability_token: Admin capability token
Returns:
Tuple of (updated_config, new_raw_key)
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Load existing key
api_key = await self._load_api_key(key_id)
if not api_key:
raise ValueError("API key not found")
# Verify ownership
if api_key.owner_id != owner_id:
raise PermissionError("Only key owner can rotate")
# Generate new key
new_raw_key = f"gt2_{self.tenant_domain}_{secrets.token_urlsafe(32)}"
new_key_hash = hashlib.sha256(new_raw_key.encode()).hexdigest()
# Update configuration
api_key.key_hash = new_key_hash
api_key.last_rotated = datetime.utcnow()
api_key.updated_at = datetime.utcnow()
# Store updated key
await self._store_api_key(api_key)
# Log rotation
await self._audit_log("api_key_rotated", owner_id, {
"key_id": key_id,
"name": api_key.name
})
logger.info(f"Rotated API key: {api_key.name} ({key_id})")
return api_key, new_raw_key
async def revoke_api_key(
self,
key_id: str,
owner_id: str,
capability_token: str
) -> bool:
"""
Revoke API key (mark as revoked).
Args:
key_id: API key ID to revoke
owner_id: Owner of the key
capability_token: Admin capability token
Returns:
True if revoked successfully
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
# Load and verify key
api_key = await self._load_api_key(key_id)
if not api_key:
return False
if api_key.owner_id != owner_id:
raise PermissionError("Only key owner can revoke")
# Revoke key
api_key.status = APIKeyStatus.REVOKED
api_key.updated_at = datetime.utcnow()
# Store updated key
await self._store_api_key(api_key)
# Log revocation
await self._audit_log("api_key_revoked", owner_id, {
"key_id": key_id,
"name": api_key.name
})
logger.info(f"Revoked API key: {api_key.name} ({key_id})")
return True
async def list_user_api_keys(
self,
owner_id: str,
capability_token: str,
include_usage: bool = True
) -> List[APIKeyConfig]:
"""
List API keys for a user.
Args:
owner_id: User to get keys for
capability_token: User capability token
include_usage: Include usage statistics
Returns:
List of API key configurations
"""
# Verify capability token
token_data = verify_capability_token(capability_token)
if not token_data or token_data.get("tenant_id") != self.tenant_domain:
raise PermissionError("Invalid capability token")
user_keys = []
# Load all keys and filter by owner
if self.keys_path.exists():
for key_file in self.keys_path.glob("*.json"):
try:
with open(key_file, "r") as f:
data = json.load(f)
if data.get("owner_id") == owner_id:
api_key = APIKeyConfig.from_dict(data)
# Update usage if requested
if include_usage:
await self._update_key_usage_stats(api_key)
user_keys.append(api_key)
except Exception as e:
logger.error(f"Error loading key file {key_file}: {e}")
return sorted(user_keys, key=lambda k: k.created_at, reverse=True)
async def get_usage_analytics(
self,
owner_id: str,
key_id: Optional[str] = None,
days: int = 30
) -> Dict[str, Any]:
"""
Get usage analytics for API keys.
Args:
owner_id: Owner of the keys
key_id: Specific key ID (optional)
days: Number of days to analyze
Returns:
Usage analytics data
"""
analytics = {
"total_requests": 0,
"total_errors": 0,
"avg_requests_per_day": 0,
"most_used_endpoints": [],
"rate_limit_hits": 0,
"keys_analyzed": 0,
"date_range": {
"start": (datetime.utcnow() - timedelta(days=days)).isoformat(),
"end": datetime.utcnow().isoformat()
}
}
# Get user's keys
user_keys = await self.list_user_api_keys(owner_id, "", include_usage=True)
# Filter by specific key if requested
if key_id:
user_keys = [key for key in user_keys if key.id == key_id]
# Aggregate usage data
for api_key in user_keys:
analytics["total_requests"] += api_key.usage.requests_count
analytics["total_errors"] += api_key.usage.errors_count
analytics["rate_limit_hits"] += api_key.usage.rate_limit_hits
analytics["keys_analyzed"] += 1
# Calculate averages
if days > 0:
analytics["avg_requests_per_day"] = analytics["total_requests"] / days
return analytics
def _apply_tenant_defaults(self, constraints: Dict[str, Any]) -> Dict[str, Any]:
"""Apply tenant-specific default constraints"""
defaults = {
"max_automation_chain_depth": 5,
"mcp_memory_limit_mb": 512,
"mcp_timeout_seconds": 30,
"max_file_size_bytes": 10 * 1024 * 1024, # 10MB
"allowed_file_types": [".pdf", ".txt", ".md", ".json", ".csv"],
"enable_premium_features": False
}
# Merge with provided constraints (provided values take precedence)
final_constraints = defaults.copy()
final_constraints.update(constraints)
return final_constraints
def _apply_scope_defaults(self, api_key: APIKeyConfig, scope: APIKeyScope):
"""Apply scope-based default limits"""
if scope == APIKeyScope.USER:
api_key.rate_limit_per_hour = 1000
api_key.daily_quota = 10000
api_key.cost_limit_cents = 1000
elif scope == APIKeyScope.TENANT:
api_key.rate_limit_per_hour = 5000
api_key.daily_quota = 50000
api_key.cost_limit_cents = 5000
elif scope == APIKeyScope.ADMIN:
api_key.rate_limit_per_hour = 10000
api_key.daily_quota = 100000
api_key.cost_limit_cents = 10000
async def _check_rate_limits(self, api_key: APIKeyConfig) -> Tuple[bool, Optional[str]]:
"""Check if API key is within rate limits"""
# For now, implement basic hourly check
# In production, would check against usage tracking database
current_hour = datetime.utcnow().replace(minute=0, second=0, microsecond=0)
# Load hourly usage (mock implementation)
hourly_usage = 0 # Would query actual usage data
if hourly_usage >= api_key.rate_limit_per_hour:
api_key.usage.rate_limit_hits += 1
await self._store_api_key(api_key)
return False, f"Rate limit exceeded: {hourly_usage}/{api_key.rate_limit_per_hour} requests per hour"
return True, None
async def _update_usage(self, api_key: APIKeyConfig, endpoint: str, client_ip: str):
"""Update API key usage statistics"""
api_key.usage.requests_count += 1
api_key.usage.last_used = datetime.utcnow()
# Store updated usage
await self._store_api_key(api_key)
# Log detailed usage (for analytics)
await self._log_usage(api_key.id, endpoint, client_ip)
async def _store_api_key(self, api_key: APIKeyConfig):
"""Store API key configuration to file system"""
key_file = self.keys_path / f"{api_key.id}.json"
with open(key_file, "w") as f:
json.dump(api_key.to_dict(), f, indent=2)
# Set secure permissions
os.chmod(key_file, stat.S_IRUSR | stat.S_IWUSR) # 600
async def _load_api_key(self, key_id: str) -> Optional[APIKeyConfig]:
"""Load API key configuration by ID"""
key_file = self.keys_path / f"{key_id}.json"
if not key_file.exists():
return None
try:
with open(key_file, "r") as f:
data = json.load(f)
return APIKeyConfig.from_dict(data)
except Exception as e:
logger.error(f"Error loading API key {key_id}: {e}")
return None
async def _load_api_key_by_hash(self, key_hash: str) -> Optional[APIKeyConfig]:
"""Load API key configuration by hash"""
if not self.keys_path.exists():
return None
for key_file in self.keys_path.glob("*.json"):
try:
with open(key_file, "r") as f:
data = json.load(f)
if data.get("key_hash") == key_hash:
return APIKeyConfig.from_dict(data)
except Exception as e:
logger.error(f"Error loading key file {key_file}: {e}")
return None
async def _update_key_usage_stats(self, api_key: APIKeyConfig):
"""Update comprehensive usage statistics for a key"""
# In production, would aggregate from detailed usage logs
# For now, use existing basic stats
pass
async def _log_usage(self, key_id: str, endpoint: str, client_ip: str):
"""Log detailed API key usage for analytics"""
usage_record = {
"timestamp": datetime.utcnow().isoformat(),
"key_id": key_id,
"endpoint": endpoint,
"client_ip": client_ip,
"tenant": self.tenant_domain
}
# Store in daily usage file
date_str = datetime.utcnow().strftime("%Y-%m-%d")
usage_file = self.usage_path / f"usage_{date_str}.jsonl"
with open(usage_file, "a") as f:
f.write(json.dumps(usage_record) + "\n")
async def _audit_log(self, action: str, user_id: str, details: Dict[str, Any]):
"""Log API key management actions for audit"""
audit_record = {
"timestamp": datetime.utcnow().isoformat(),
"action": action,
"user_id": user_id,
"tenant": self.tenant_domain,
"details": details
}
# Store in daily audit file
date_str = datetime.utcnow().strftime("%Y-%m-%d")
audit_file = self.audit_path / f"audit_{date_str}.jsonl"
with open(audit_file, "a") as f:
f.write(json.dumps(audit_record) + "\n")
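# Token sketch: generate_capability_token() signs an HS256 JWT with the service's
# signing key, and a holder of the same key can decode it as shown. The real
# verification path is app.core.security.verify_capability_token, so these helpers
# are illustrative only.
def _example_decode_capability_token(token: str, signing_key: str) -> Dict[str, Any]:
    """Decode an HS256 capability token and return its claims (raises on tampering or expiry)."""
    return jwt.decode(token, signing_key, algorithms=["HS256"])
async def _example_issue_and_decode(service: EnhancedAPIKeyService, api_key: APIKeyConfig) -> Dict[str, Any]:
    token = await service.generate_capability_token(api_key)
    return _example_decode_capability_token(token, service.signing_key)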

View File

@@ -0,0 +1,635 @@
"""
Tenant Event Bus System
Implements event-driven architecture for automation triggers with perfect
tenant isolation and capability-based execution.
"""
import asyncio
import logging
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from uuid import uuid4
from enum import Enum
import json
from pathlib import Path
from app.core.path_security import sanitize_tenant_domain
logger = logging.getLogger(__name__)
class TriggerType(Enum):
"""Types of automation triggers"""
CRON = "cron" # Time-based
WEBHOOK = "webhook" # External HTTP
EVENT = "event" # Internal events
CHAIN = "chain" # Triggered by other automations
MANUAL = "manual" # User-initiated
# Event catalog with required fields
EVENT_CATALOG = {
"document.uploaded": ["document_id", "dataset_id", "filename"],
"document.processed": ["document_id", "chunks_created"],
"agent.created": ["agent_id", "name", "owner_id"],
"chat.started": ["conversation_id", "agent_id"],
"resource.shared": ["resource_id", "access_group", "shared_with"],
"quota.warning": ["resource_type", "current_usage", "limit"],
"automation.completed": ["automation_id", "result", "duration_ms"],
"automation.failed": ["automation_id", "error", "retry_count"]
}
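# Illustrative helper (not called by the bus): check an event payload against the
# required fields in EVENT_CATALOG before emitting. Unknown event types pass
# through, mirroring the warning-only behaviour of emit_event().
def _example_payload_is_complete(event_type: str, data: Dict[str, Any]) -> bool:
    required = EVENT_CATALOG.get(event_type, [])
    return all(field_name in data for field_name in required)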
@dataclass
class Event:
"""Event data structure"""
id: str = field(default_factory=lambda: str(uuid4()))
type: str = ""
tenant: str = ""
user: str = ""
timestamp: datetime = field(default_factory=datetime.utcnow)
data: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert event to dictionary"""
return {
"id": self.id,
"type": self.type,
"tenant": self.tenant,
"user": self.user,
"timestamp": self.timestamp.isoformat(),
"data": self.data,
"metadata": self.metadata
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Event":
"""Create event from dictionary"""
return cls(
id=data.get("id", str(uuid4())),
type=data.get("type", ""),
tenant=data.get("tenant", ""),
user=data.get("user", ""),
timestamp=datetime.fromisoformat(data.get("timestamp", datetime.utcnow().isoformat())),
data=data.get("data", {}),
metadata=data.get("metadata", {})
)
@dataclass
class Automation:
"""Automation configuration"""
id: str = field(default_factory=lambda: str(uuid4()))
name: str = ""
owner_id: str = ""
trigger_type: TriggerType = TriggerType.MANUAL
trigger_config: Dict[str, Any] = field(default_factory=dict)
conditions: List[Dict[str, Any]] = field(default_factory=list)
actions: List[Dict[str, Any]] = field(default_factory=list)
triggers_chain: bool = False
chain_targets: List[str] = field(default_factory=list)
max_retries: int = 3
timeout_seconds: int = 300
is_active: bool = True
created_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
def matches_event(self, event: Event) -> bool:
"""Check if automation should trigger for event"""
if not self.is_active:
return False
if self.trigger_type != TriggerType.EVENT:
return False
# Check event type matches
event_types = self.trigger_config.get("event_types", [])
if event.type not in event_types:
return False
# Check conditions
for condition in self.conditions:
if not self._evaluate_condition(condition, event):
return False
return True
def _evaluate_condition(self, condition: Dict[str, Any], event: Event) -> bool:
"""Evaluate a single condition"""
        field = condition.get("field")
        operator = condition.get("operator")
        value = condition.get("value")
        if not field:
            return False
# Get field value from event
if "." in field:
parts = field.split(".")
# Handle data.field paths by starting from the event object
if parts[0] == "data":
event_value = event.data
parts = parts[1:] # Skip the "data" part
else:
event_value = event
for part in parts:
if isinstance(event_value, dict):
event_value = event_value.get(part)
elif hasattr(event_value, part):
event_value = getattr(event_value, part)
else:
return False
else:
event_value = getattr(event, field, None)
# Evaluate condition
if operator == "equals":
return event_value == value
elif operator == "not_equals":
return event_value != value
elif operator == "contains":
return value in str(event_value)
elif operator == "greater_than":
return float(event_value) > float(value)
elif operator == "less_than":
return float(event_value) < float(value)
else:
return False
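# Matching sketch: an event-triggered automation fires only when the event type is
# listed in trigger_config["event_types"] and every condition passes. The names,
# ids, and filename below are placeholders.
def _example_automation_match() -> bool:
    automation = Automation(
        name="index-pdf-uploads",
        owner_id="user-123",
        trigger_type=TriggerType.EVENT,
        trigger_config={"event_types": ["document.uploaded"]},
        conditions=[{"field": "data.filename", "operator": "contains", "value": ".pdf"}],
    )
    event = Event(
        type="document.uploaded",
        tenant="example-tenant",
        user="user-123",
        data={"document_id": "doc-1", "dataset_id": "ds-1", "filename": "report.pdf"},
    )
    return automation.matches_event(event)  # True for this payload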
class TenantEventBus:
"""
Event system for automation triggers with tenant isolation.
Features:
- Perfect tenant isolation through file-based storage
- Event persistence and replay capability
- Automation matching and triggering
- Access control for automation execution
"""
def __init__(self, tenant_domain: str, base_path: Optional[Path] = None):
self.tenant_domain = tenant_domain
# Sanitize tenant_domain to prevent path traversal
safe_tenant = sanitize_tenant_domain(tenant_domain)
self.base_path = base_path or (Path("/data") / safe_tenant / "events")
self.event_store_path = self.base_path / "store"
self.automations_path = self.base_path / "automations"
self.event_handlers: Dict[str, List[Callable]] = {}
self.running_automations: Dict[str, asyncio.Task] = {}
# Ensure directories exist with proper permissions
self._ensure_directories()
logger.info(f"TenantEventBus initialized for {tenant_domain}")
def _ensure_directories(self):
"""Ensure event directories exist with proper permissions"""
import os
import stat
for path in [self.event_store_path, self.automations_path]:
path.mkdir(parents=True, exist_ok=True)
# Set permissions to 700 (owner only)
# codeql[py/path-injection] paths derived from sanitize_tenant_domain() at line 175
os.chmod(path, stat.S_IRWXU)
async def emit_event(
self,
event_type: str,
user_id: str,
data: Dict[str, Any],
metadata: Optional[Dict[str, Any]] = None
) -> Event:
"""
Emit an event and trigger matching automations.
Args:
event_type: Type of event from EVENT_CATALOG
user_id: User who triggered the event
data: Event data
metadata: Optional metadata
Returns:
Created event
"""
# Validate event type
if event_type not in EVENT_CATALOG:
logger.warning(f"Unknown event type: {event_type}")
# Create event
event = Event(
type=event_type,
tenant=self.tenant_domain,
user=user_id,
data=data,
metadata=metadata or {}
)
# Store event
await self._store_event(event)
# Find matching automations
automations = await self._find_matching_automations(event)
# Trigger automations with access control
for automation in automations:
if await self._can_trigger(user_id, automation):
asyncio.create_task(self._trigger_automation(automation, event))
# Call registered handlers
if event_type in self.event_handlers:
for handler in self.event_handlers[event_type]:
asyncio.create_task(handler(event))
logger.info(f"Event emitted: {event_type} by {user_id}")
return event
async def _store_event(self, event: Event):
"""Store event to file system"""
# Create daily event file
date_str = event.timestamp.strftime("%Y-%m-%d")
event_file = self.event_store_path / f"events_{date_str}.jsonl"
# Append event as JSON line
with open(event_file, "a") as f:
f.write(json.dumps(event.to_dict()) + "\n")
async def _find_matching_automations(self, event: Event) -> List[Automation]:
"""Find automations that match the event"""
matching = []
# Load all automations from file system
if self.automations_path.exists():
for automation_file in self.automations_path.glob("*.json"):
                try:
                    with open(automation_file, "r") as f:
                        automation_data = json.load(f)
                    # Stored files hold serialized values; convert them back before matching
                    automation_data["trigger_type"] = TriggerType(automation_data["trigger_type"])
                    automation_data["created_at"] = datetime.fromisoformat(automation_data["created_at"])
                    automation_data["updated_at"] = datetime.fromisoformat(automation_data["updated_at"])
                    automation = Automation(**automation_data)
                    if automation.matches_event(event):
                        matching.append(automation)
except Exception as e:
logger.error(f"Error loading automation {automation_file}: {e}")
return matching
async def _can_trigger(self, user_id: str, automation: Automation) -> bool:
"""Check if user can trigger automation"""
# Owner can always trigger their automations
if automation.owner_id == user_id:
return True
# Check if automation is public or shared
# This would integrate with AccessController
# For now, only owner can trigger
return False
async def _trigger_automation(self, automation: Automation, event: Event):
"""Trigger automation execution"""
try:
# Check if automation is already running
if automation.id in self.running_automations:
logger.info(f"Automation {automation.id} already running, skipping")
return
# Create task for automation execution
task = asyncio.create_task(
self._execute_automation(automation, event)
)
self.running_automations[automation.id] = task
# Wait for completion with timeout
await asyncio.wait_for(task, timeout=automation.timeout_seconds)
except asyncio.TimeoutError:
logger.error(f"Automation {automation.id} timed out")
await self.emit_event(
"automation.failed",
automation.owner_id,
{
"automation_id": automation.id,
"error": "Timeout",
"retry_count": 0
}
)
except Exception as e:
logger.error(f"Error triggering automation {automation.id}: {e}")
await self.emit_event(
"automation.failed",
automation.owner_id,
{
"automation_id": automation.id,
"error": str(e),
"retry_count": 0
}
)
finally:
# Remove from running automations
self.running_automations.pop(automation.id, None)
async def _execute_automation(self, automation: Automation, event: Event) -> Any:
"""Execute automation actions"""
start_time = datetime.utcnow()
results = []
try:
# Execute each action in sequence
for action in automation.actions:
result = await self._execute_action(action, event, automation)
results.append(result)
# Calculate duration
duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
# Emit completion event
await self.emit_event(
"automation.completed",
automation.owner_id,
{
"automation_id": automation.id,
"result": results,
"duration_ms": duration_ms
}
)
return results
except Exception as e:
logger.error(f"Error executing automation {automation.id}: {e}")
raise
async def _execute_action(
self,
action: Dict[str, Any],
event: Event,
automation: Automation
) -> Any:
"""Execute a single action"""
action_type = action.get("type")
if action_type == "webhook":
return await self._execute_webhook_action(action, event)
elif action_type == "email":
return await self._execute_email_action(action, event)
elif action_type == "log":
return await self._execute_log_action(action, event)
elif action_type == "chain":
return await self._execute_chain_action(action, event, automation)
else:
logger.warning(f"Unknown action type: {action_type}")
return None
async def _execute_webhook_action(
self,
action: Dict[str, Any],
event: Event
) -> Dict[str, Any]:
"""Execute webhook action (mock implementation)"""
url = action.get("url")
method = action.get("method", "POST")
headers = action.get("headers", {})
body = action.get("body", event.to_dict())
logger.info(f"Mock webhook call to {url}")
# In production, use httpx or aiohttp to make actual HTTP request
return {
"status": "success",
"url": url,
"method": method,
"mock": True
}
async def _execute_email_action(
self,
action: Dict[str, Any],
event: Event
) -> Dict[str, Any]:
"""Execute email action (mock implementation)"""
to = action.get("to")
subject = action.get("subject")
body = action.get("body")
logger.info(f"Mock email to {to}: {subject}")
# In production, integrate with email service
return {
"status": "success",
"to": to,
"subject": subject,
"mock": True
}
async def _execute_log_action(
self,
action: Dict[str, Any],
event: Event
) -> Dict[str, Any]:
"""Execute log action"""
message = action.get("message", f"Event: {event.type}")
level = action.get("level", "info")
if level == "debug":
logger.debug(message)
elif level == "info":
logger.info(message)
elif level == "warning":
logger.warning(message)
elif level == "error":
logger.error(message)
return {
"status": "success",
"message": message,
"level": level
}
async def _execute_chain_action(
self,
action: Dict[str, Any],
event: Event,
automation: Automation
) -> Dict[str, Any]:
"""Execute chain action to trigger other automations"""
target_automation_id = action.get("target_automation_id")
if not target_automation_id:
return {"status": "error", "message": "No target automation specified"}
# Emit chain event
chain_event = await self.emit_event(
"automation.chain",
automation.owner_id,
{
"source_automation": automation.id,
"target_automation": target_automation_id,
"original_event": event.to_dict()
}
)
return {
"status": "success",
"chain_event_id": chain_event.id,
"target_automation": target_automation_id
}
def register_handler(self, event_type: str, handler: Callable):
"""Register an event handler"""
if event_type not in self.event_handlers:
self.event_handlers[event_type] = []
self.event_handlers[event_type].append(handler)
async def create_automation(
self,
name: str,
owner_id: str,
trigger_type: TriggerType,
trigger_config: Dict[str, Any],
actions: List[Dict[str, Any]],
conditions: Optional[List[Dict[str, Any]]] = None
) -> Automation:
"""Create and save a new automation"""
automation = Automation(
name=name,
owner_id=owner_id,
trigger_type=trigger_type,
trigger_config=trigger_config,
actions=actions,
conditions=conditions or []
)
# Save to file system
automation_file = self.automations_path / f"{automation.id}.json"
with open(automation_file, "w") as f:
json.dump({
"id": automation.id,
"name": automation.name,
"owner_id": automation.owner_id,
"trigger_type": automation.trigger_type.value,
"trigger_config": automation.trigger_config,
"conditions": automation.conditions,
"actions": automation.actions,
"triggers_chain": automation.triggers_chain,
"chain_targets": automation.chain_targets,
"max_retries": automation.max_retries,
"timeout_seconds": automation.timeout_seconds,
"is_active": automation.is_active,
"created_at": automation.created_at.isoformat(),
"updated_at": automation.updated_at.isoformat()
}, f, indent=2)
logger.info(f"Created automation: {automation.name} ({automation.id})")
return automation
async def get_automation(self, automation_id: str) -> Optional[Automation]:
"""Get automation by ID"""
automation_file = self.automations_path / f"{automation_id}.json"
if not automation_file.exists():
return None
try:
with open(automation_file, "r") as f:
data = json.load(f)
data["trigger_type"] = TriggerType(data["trigger_type"])
data["created_at"] = datetime.fromisoformat(data["created_at"])
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
return Automation(**data)
except Exception as e:
logger.error(f"Error loading automation {automation_id}: {e}")
return None
async def list_automations(self, owner_id: Optional[str] = None) -> List[Automation]:
"""List all automations, optionally filtered by owner"""
automations = []
if self.automations_path.exists():
for automation_file in self.automations_path.glob("*.json"):
try:
with open(automation_file, "r") as f:
data = json.load(f)
# Filter by owner if specified
if owner_id and data.get("owner_id") != owner_id:
continue
data["trigger_type"] = TriggerType(data["trigger_type"])
data["created_at"] = datetime.fromisoformat(data["created_at"])
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
automations.append(Automation(**data))
except Exception as e:
logger.error(f"Error loading automation {automation_file}: {e}")
return automations
async def delete_automation(self, automation_id: str, owner_id: str) -> bool:
"""Delete an automation"""
automation = await self.get_automation(automation_id)
if not automation:
return False
# Check ownership
if automation.owner_id != owner_id:
logger.warning(f"User {owner_id} attempted to delete automation owned by {automation.owner_id}")
return False
# Delete file
automation_file = self.automations_path / f"{automation_id}.json"
automation_file.unlink()
logger.info(f"Deleted automation: {automation_id}")
return True
async def get_event_history(
self,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
event_type: Optional[str] = None,
user_id: Optional[str] = None,
limit: int = 100
) -> List[Event]:
"""Get event history with optional filters"""
events = []
# Determine date range
if not end_date:
end_date = datetime.utcnow()
if not start_date:
start_date = end_date.replace(hour=0, minute=0, second=0, microsecond=0)
# Iterate through daily event files
current_date = start_date
while current_date <= end_date:
date_str = current_date.strftime("%Y-%m-%d")
event_file = self.event_store_path / f"events_{date_str}.jsonl"
if event_file.exists():
with open(event_file, "r") as f:
for line in f:
try:
event_data = json.loads(line)
event = Event.from_dict(event_data)
# Apply filters
if event_type and event.type != event_type:
continue
if user_id and event.user != user_id:
continue
events.append(event)
if len(events) >= limit:
return events
except Exception as e:
logger.error(f"Error parsing event: {e}")
# Advance to the next day; timedelta rolls over month and year boundaries correctly
current_date = current_date + timedelta(days=1)
return events
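# On-disk layout used by the methods above (both roots are whatever
# self.automations_path and self.event_store_path were configured to):
#
#   <automations_path>/<automation_id>.json      # one JSON document per automation
#   <event_store_path>/events_YYYY-MM-DD.jsonl   # one JSON object per line, per day
#
# Illustrative history query (assumes `engine` is an initialized instance of this service):
#
#   recent = await engine.get_event_history(event_type="automation.chain", limit=50)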

View File

@@ -0,0 +1,869 @@
"""
Event Automation Service for GT 2.0 Tenant Backend
Handles event-driven automation workflows including:
- Document processing triggers
- Conversation events
- RAG pipeline automation
- Agent lifecycle events
- User activity tracking
Perfect tenant isolation with zero downtime compliance.
"""
import logging
import asyncio
import json
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Callable
from enum import Enum
from dataclasses import dataclass, asdict
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, desc
from sqlalchemy.orm import selectinload
from app.core.database import get_db_session
from app.core.config import get_settings
from app.models.event import Event, EventTrigger, EventAction, EventSubscription
from app.services.rag_service import RAGService
from app.services.conversation_service import ConversationService
from app.services.assistant_manager import AssistantManager
logger = logging.getLogger(__name__)
class EventType(str, Enum):
"""Event types for automation triggers"""
DOCUMENT_UPLOADED = "document.uploaded"
DOCUMENT_PROCESSED = "document.processed"
DOCUMENT_FAILED = "document.failed"
CONVERSATION_STARTED = "conversation.started"
MESSAGE_SENT = "message.sent"
ASSISTANT_CREATED = "agent.created"
RAG_SEARCH_PERFORMED = "rag.search_performed"
USER_LOGIN = "user.login"
USER_ACTIVITY = "user.activity"
SYSTEM_HEALTH_CHECK = "system.health_check"
TEAM_INVITATION_CREATED = "team.invitation.created"
TEAM_OBSERVABLE_REQUESTED = "team.observable_requested"
class ActionType(str, Enum):
"""Action types for event responses"""
PROCESS_DOCUMENT = "process_document"
SEND_NOTIFICATION = "send_notification"
UPDATE_STATISTICS = "update_statistics"
TRIGGER_RAG_INDEXING = "trigger_rag_indexing"
LOG_ANALYTICS = "log_analytics"
EXECUTE_WEBHOOK = "execute_webhook"
CREATE_ASSISTANT = "create_assistant"
SCHEDULE_TASK = "schedule_task"
@dataclass
class EventPayload:
"""Event payload structure"""
event_id: str
event_type: EventType
user_id: str
tenant_id: str
timestamp: datetime
data: Dict[str, Any]
metadata: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@dataclass
class EventActionConfig:
"""Configuration for event actions"""
action_type: ActionType
config: Dict[str, Any]
delay_seconds: int = 0
retry_count: int = 3
retry_delay: int = 60
condition: Optional[str] = None # Python expression for conditional execution
class EventService:
"""
Event automation service with perfect tenant isolation.
GT 2.0 Security Principles:
- Perfect tenant isolation (all events user-scoped)
- Zero downtime compliance (async processing)
- Self-contained automation (no external dependencies)
- Stateless event processing
"""
def __init__(self, db: AsyncSession):
self.db = db
self.settings = get_settings()
self.rag_service = RAGService(db)
self.conversation_service = ConversationService(db)
self.assistant_manager = AssistantManager(db)
# Event handlers registry
self.action_handlers: Dict[ActionType, Callable] = {
ActionType.PROCESS_DOCUMENT: self._handle_process_document,
ActionType.SEND_NOTIFICATION: self._handle_send_notification,
ActionType.UPDATE_STATISTICS: self._handle_update_statistics,
ActionType.TRIGGER_RAG_INDEXING: self._handle_trigger_rag_indexing,
ActionType.LOG_ANALYTICS: self._handle_log_analytics,
ActionType.EXECUTE_WEBHOOK: self._handle_execute_webhook,
ActionType.CREATE_ASSISTANT: self._handle_create_assistant,
ActionType.SCHEDULE_TASK: self._handle_schedule_task,
}
# Active event subscriptions cache
self.subscriptions_cache: Dict[str, List[EventSubscription]] = {}
self.cache_expiry: Optional[datetime] = None
logger.info("Event automation service initialized with tenant isolation")
async def emit_event(
self,
event_type: EventType,
user_id: str,
tenant_id: str,
data: Dict[str, Any],
metadata: Optional[Dict[str, Any]] = None
) -> str:
"""Emit an event and trigger automated actions"""
try:
# Create event payload
event_payload = EventPayload(
event_id=str(uuid.uuid4()),
event_type=event_type,
user_id=user_id,
tenant_id=tenant_id,
timestamp=datetime.utcnow(),
data=data,
metadata=metadata or {}
)
# Store event in database
event_record = Event(
event_id=event_payload.event_id,
event_type=event_type.value,
user_id=user_id,
tenant_id=tenant_id,
payload=event_payload.to_dict(),
status="processing"
)
self.db.add(event_record)
await self.db.commit()
# Process event asynchronously
asyncio.create_task(self._process_event(event_payload))
logger.info(f"Event emitted: {event_type.value} for user {user_id}")
return event_payload.event_id
except Exception as e:
logger.error(f"Failed to emit event {event_type.value}: {e}")
raise
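# Illustrative usage (assumes an active AsyncSession `db` and an async context;
# identifiers other than EventService/EventType are examples):
#
#   service = EventService(db)
#   event_id = await service.emit_event(
#       EventType.DOCUMENT_UPLOADED,
#       user_id="user-123",
#       tenant_id="tenant-abc",
#       data={"document_id": "doc-1", "filename": "report.pdf"},
#   )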
async def _process_event(self, event_payload: EventPayload) -> None:
"""Process event and execute matching actions"""
try:
# Get subscriptions for this event type
subscriptions = await self._get_event_subscriptions(
event_payload.event_type,
event_payload.user_id,
event_payload.tenant_id
)
if not subscriptions:
logger.debug(f"No subscriptions found for event {event_payload.event_type}")
return
# Execute actions for each subscription
for subscription in subscriptions:
try:
await self._execute_subscription_actions(subscription, event_payload)
except Exception as e:
logger.error(f"Failed to execute subscription {subscription.id}: {e}")
continue
# Update event status
await self._update_event_status(event_payload.event_id, "completed")
except Exception as e:
logger.error(f"Failed to process event {event_payload.event_id}: {e}")
await self._update_event_status(event_payload.event_id, "failed", str(e))
async def _get_event_subscriptions(
self,
event_type: EventType,
user_id: str,
tenant_id: str
) -> List[EventSubscription]:
"""Get active subscriptions for event type with tenant isolation"""
try:
# Check cache first
cache_key = f"{tenant_id}:{user_id}:{event_type.value}"
if (self.cache_expiry and datetime.utcnow() < self.cache_expiry and
cache_key in self.subscriptions_cache):
return self.subscriptions_cache[cache_key]
# Query database
query = select(EventSubscription).where(
and_(
EventSubscription.event_type == event_type.value,
EventSubscription.user_id == user_id,
EventSubscription.tenant_id == tenant_id,
EventSubscription.is_active == True
)
).options(selectinload(EventSubscription.actions))
result = await self.db.execute(query)
subscriptions = result.scalars().all()
# Cache results
self.subscriptions_cache[cache_key] = list(subscriptions)
self.cache_expiry = datetime.utcnow() + timedelta(minutes=5)
return list(subscriptions)
except Exception as e:
logger.error(f"Failed to get event subscriptions: {e}")
return []
async def _execute_subscription_actions(
self,
subscription: EventSubscription,
event_payload: EventPayload
) -> None:
"""Execute all actions for a subscription"""
try:
for action in subscription.actions:
# Check if action should be executed
if not await self._should_execute_action(action, event_payload):
continue
# Add delay if specified
if action.delay_seconds > 0:
await asyncio.sleep(action.delay_seconds)
# Execute action with retry logic
await self._execute_action_with_retry(action, event_payload)
except Exception as e:
logger.error(f"Failed to execute subscription actions: {e}")
raise
async def _should_execute_action(
self,
action: EventAction,
event_payload: EventPayload
) -> bool:
"""Check if action should be executed based on conditions"""
try:
if not action.condition:
return True
# Create evaluation context
context = {
'event': event_payload.to_dict(),
'data': event_payload.data,
'user_id': event_payload.user_id,
'tenant_id': event_payload.tenant_id,
'event_type': event_payload.event_type.value
}
# Evaluate the condition with builtins disabled (restricted eval; expressions are
# authored by the subscription owner, not taken from external input)
try:
result = eval(action.condition, {"__builtins__": {}}, context)
return bool(result)
except Exception as e:
logger.warning(f"Failed to evaluate action condition: {e}")
return False
except Exception as e:
logger.error(f"Error checking action condition: {e}")
return False
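# Example condition expressions evaluated against the context above (illustrative;
# the data keys are hypothetical and depend on what the emitter put in the payload):
#
#   "data.get('file_type') == 'pdf'"
#   "event_type == 'document.uploaded' and data.get('size_bytes', 0) < 10_000_000"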
async def _execute_action_with_retry(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Execute action with retry logic"""
last_error = None
for attempt in range(action.retry_count + 1):
try:
await self._execute_action(action, event_payload)
return # Success
except Exception as e:
last_error = e
logger.warning(f"Action execution attempt {attempt + 1} failed: {e}")
if attempt < action.retry_count:
await asyncio.sleep(action.retry_delay)
else:
logger.error(f"Action execution failed after {action.retry_count + 1} attempts")
raise last_error
async def _execute_action(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Execute a specific action"""
try:
action_type = ActionType(action.action_type)
handler = self.action_handlers.get(action_type)
if not handler:
raise ValueError(f"No handler for action type: {action_type}")
await handler(action, event_payload)
logger.debug(f"Action executed: {action_type.value} for event {event_payload.event_id}")
except Exception as e:
logger.error(f"Failed to execute action {action.action_type}: {e}")
raise
# Action Handlers
async def _handle_process_document(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle document processing automation"""
try:
document_id = event_payload.data.get("document_id")
if not document_id:
raise ValueError("document_id required for process_document action")
chunking_strategy = action.config.get("chunking_strategy", "hybrid")
result = await self.rag_service.process_document(
user_id=event_payload.user_id,
document_id=document_id,
tenant_id=event_payload.tenant_id,
chunking_strategy=chunking_strategy
)
# Emit processing completed event
await self.emit_event(
EventType.DOCUMENT_PROCESSED,
event_payload.user_id,
event_payload.tenant_id,
{
"document_id": document_id,
"chunk_count": result.get("chunk_count", 0),
"processing_result": result
}
)
except Exception as e:
# Emit processing failed event
await self.emit_event(
EventType.DOCUMENT_FAILED,
event_payload.user_id,
event_payload.tenant_id,
{
"document_id": event_payload.data.get("document_id"),
"error": str(e)
}
)
raise
async def _handle_send_notification(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle notification sending"""
try:
notification_type = action.config.get("type", "system")
message = action.config.get("message", "Event notification")
# Format message with event data
formatted_message = message.format(**event_payload.data)
# Store notification (implement notification system later)
notification_data = {
"type": notification_type,
"message": formatted_message,
"user_id": event_payload.user_id,
"event_id": event_payload.event_id,
"created_at": datetime.utcnow().isoformat()
}
logger.info(f"Notification sent: {formatted_message} to user {event_payload.user_id}")
except Exception as e:
logger.error(f"Failed to send notification: {e}")
raise
async def _handle_update_statistics(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle statistics updates"""
try:
stat_type = action.config.get("type")
increment = action.config.get("increment", 1)
# Update user statistics (implement statistics system later)
logger.info(f"Statistics updated: {stat_type} += {increment} for user {event_payload.user_id}")
except Exception as e:
logger.error(f"Failed to update statistics: {e}")
raise
async def _handle_trigger_rag_indexing(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle RAG reindexing automation"""
try:
dataset_ids = action.config.get("dataset_ids", [])
if not dataset_ids:
# Get all user datasets
datasets = await self.rag_service.list_user_datasets(event_payload.user_id)
dataset_ids = [d.id for d in datasets]
for dataset_id in dataset_ids:
# Trigger reindexing for dataset
logger.info(f"RAG reindexing triggered for dataset {dataset_id}")
except Exception as e:
logger.error(f"Failed to trigger RAG indexing: {e}")
raise
async def _handle_log_analytics(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle analytics logging"""
try:
analytics_data = {
"event_type": event_payload.event_type.value,
"user_id": event_payload.user_id,
"tenant_id": event_payload.tenant_id,
"timestamp": event_payload.timestamp.isoformat(),
"data": event_payload.data,
"custom_properties": action.config.get("properties", {})
}
# Log analytics (implement analytics system later)
logger.info(f"Analytics logged: {analytics_data}")
except Exception as e:
logger.error(f"Failed to log analytics: {e}")
raise
async def _handle_execute_webhook(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle webhook execution"""
try:
webhook_url = action.config.get("url")
method = action.config.get("method", "POST")
headers = action.config.get("headers", {})
if not webhook_url:
raise ValueError("webhook url required")
# Prepare webhook payload
webhook_payload = {
"event": event_payload.to_dict(),
"timestamp": datetime.utcnow().isoformat()
}
# Execute webhook (implement HTTP client later)
logger.info(f"Webhook executed: {method} {webhook_url}")
except Exception as e:
logger.error(f"Failed to execute webhook: {e}")
raise
async def _handle_create_assistant(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle automatic agent creation"""
try:
template_id = action.config.get("template_id", "general_assistant")
assistant_name = action.config.get("name", "Auto-created Agent")
# Create agent
agent_id = await self.assistant_manager.create_from_template(
template_id=template_id,
config={"name": assistant_name},
user_identifier=event_payload.user_id
)
# Emit agent created event
await self.emit_event(
EventType.ASSISTANT_CREATED,
event_payload.user_id,
event_payload.tenant_id,
{
"agent_id": agent_id,
"template_id": template_id,
"trigger_event": event_payload.event_id
}
)
except Exception as e:
logger.error(f"Failed to create agent: {e}")
raise
async def _handle_schedule_task(
self,
action: EventAction,
event_payload: EventPayload
) -> None:
"""Handle task scheduling"""
try:
task_type = action.config.get("task_type")
delay_minutes = action.config.get("delay_minutes", 0)
# Schedule task for future execution
scheduled_time = datetime.utcnow() + timedelta(minutes=delay_minutes)
logger.info(f"Task scheduled: {task_type} for {scheduled_time}")
except Exception as e:
logger.error(f"Failed to schedule task: {e}")
raise
# Subscription Management
async def create_subscription(
self,
user_id: str,
tenant_id: str,
event_type: EventType,
actions: List[EventActionConfig],
name: Optional[str] = None,
description: Optional[str] = None
) -> str:
"""Create an event subscription"""
try:
subscription = EventSubscription(
user_id=user_id,
tenant_id=tenant_id,
event_type=event_type.value,
name=name or f"Auto-subscription for {event_type.value}",
description=description,
is_active=True
)
self.db.add(subscription)
await self.db.flush()
# Create actions
for action_config in actions:
action = EventAction(
subscription_id=subscription.id,
action_type=action_config.action_type.value,
config=action_config.config,
delay_seconds=action_config.delay_seconds,
retry_count=action_config.retry_count,
retry_delay=action_config.retry_delay,
condition=action_config.condition
)
self.db.add(action)
await self.db.commit()
# Clear subscriptions cache
self._clear_subscriptions_cache()
logger.info(f"Event subscription created: {subscription.id} for {event_type.value}")
return subscription.id
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to create subscription: {e}")
raise
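# Illustrative sketch: subscribe a user to document uploads so each upload is
# chunked automatically (values are examples, mirroring DEFAULT_SUBSCRIPTIONS below):
#
#   subscription_id = await service.create_subscription(
#       user_id="user-123",
#       tenant_id="tenant-abc",
#       event_type=EventType.DOCUMENT_UPLOADED,
#       actions=[
#           EventActionConfig(
#               action_type=ActionType.PROCESS_DOCUMENT,
#               config={"chunking_strategy": "hybrid"},
#               delay_seconds=5,
#           )
#       ],
#       name="Process uploads",
#   )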
async def get_user_subscriptions(
self,
user_id: str,
tenant_id: str
) -> List[EventSubscription]:
"""Get all subscriptions for a user"""
try:
query = select(EventSubscription).where(
and_(
EventSubscription.user_id == user_id,
EventSubscription.tenant_id == tenant_id
)
).options(selectinload(EventSubscription.actions))
result = await self.db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Failed to get user subscriptions: {e}")
raise
async def update_subscription_status(
self,
subscription_id: str,
user_id: str,
is_active: bool
) -> bool:
"""Update subscription status with ownership verification"""
try:
query = select(EventSubscription).where(
and_(
EventSubscription.id == subscription_id,
EventSubscription.user_id == user_id
)
)
result = await self.db.execute(query)
subscription = result.scalar_one_or_none()
if not subscription:
return False
subscription.is_active = is_active
subscription.updated_at = datetime.utcnow()
await self.db.commit()
# Clear subscriptions cache
self._clear_subscriptions_cache()
return True
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to update subscription status: {e}")
raise
async def delete_subscription(
self,
subscription_id: str,
user_id: str
) -> bool:
"""Delete subscription with ownership verification"""
try:
query = select(EventSubscription).where(
and_(
EventSubscription.id == subscription_id,
EventSubscription.user_id == user_id
)
)
result = await self.db.execute(query)
subscription = result.scalar_one_or_none()
if not subscription:
return False
await self.db.delete(subscription)
await self.db.commit()
# Clear subscriptions cache
self._clear_subscriptions_cache()
return True
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to delete subscription: {e}")
raise
# Utility Methods
async def _update_event_status(
self,
event_id: str,
status: str,
error_message: Optional[str] = None
) -> None:
"""Update event processing status"""
try:
query = select(Event).where(Event.event_id == event_id)
result = await self.db.execute(query)
event = result.scalar_one_or_none()
if event:
event.status = status
event.error_message = error_message
event.completed_at = datetime.utcnow()
await self.db.commit()
except Exception as e:
logger.error(f"Failed to update event status: {e}")
def _clear_subscriptions_cache(self) -> None:
"""Clear subscriptions cache"""
self.subscriptions_cache.clear()
self.cache_expiry = None
async def get_event_history(
self,
user_id: str,
tenant_id: str,
event_types: Optional[List[EventType]] = None,
limit: int = 100,
offset: int = 0
) -> List[Event]:
"""Get event history for user with filtering"""
try:
query = select(Event).where(
and_(
Event.user_id == user_id,
Event.tenant_id == tenant_id
)
)
if event_types:
event_type_values = [et.value for et in event_types]
query = query.where(Event.event_type.in_(event_type_values))
query = query.order_by(desc(Event.created_at)).offset(offset).limit(limit)
result = await self.db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Failed to get event history: {e}")
raise
async def get_event_statistics(
self,
user_id: str,
tenant_id: str,
days: int = 30
) -> Dict[str, Any]:
"""Get event statistics for user"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days)
query = select(Event).where(
and_(
Event.user_id == user_id,
Event.tenant_id == tenant_id,
Event.created_at >= cutoff_date
)
)
result = await self.db.execute(query)
events = result.scalars().all()
# Calculate statistics
stats = {
"total_events": len(events),
"events_by_type": {},
"events_by_status": {},
"average_events_per_day": 0
}
for event in events:
# Count by type
event_type = event.event_type
stats["events_by_type"][event_type] = stats["events_by_type"].get(event_type, 0) + 1
# Count by status
status = event.status
stats["events_by_status"][status] = stats["events_by_status"].get(status, 0) + 1
# Calculate average per day
if days > 0:
stats["average_events_per_day"] = round(len(events) / days, 2)
return stats
except Exception as e:
logger.error(f"Failed to get event statistics: {e}")
raise
# Factory function for dependency injection
async def get_event_service(db: Optional[AsyncSession] = None) -> EventService:
"""Get an event service instance.
Callers should normally pass an active session: a session opened here is
closed as soon as the context manager exits below.
"""
if db is None:
async with get_db_session() as session:
return EventService(session)
return EventService(db)
# Default event subscriptions for new users
DEFAULT_SUBSCRIPTIONS = [
{
"event_type": EventType.DOCUMENT_UPLOADED,
"actions": [
EventActionConfig(
action_type=ActionType.PROCESS_DOCUMENT,
config={"chunking_strategy": "hybrid"},
delay_seconds=5 # Small delay to ensure file is fully uploaded
)
]
},
{
"event_type": EventType.DOCUMENT_PROCESSED,
"actions": [
EventActionConfig(
action_type=ActionType.SEND_NOTIFICATION,
config={
"type": "success",
"message": "Document '{filename}' has been processed successfully with {chunk_count} chunks."
}
),
EventActionConfig(
action_type=ActionType.UPDATE_STATISTICS,
config={"type": "documents_processed", "increment": 1}
)
]
},
{
"event_type": EventType.CONVERSATION_STARTED,
"actions": [
EventActionConfig(
action_type=ActionType.LOG_ANALYTICS,
config={"properties": {"conversation_start": True}}
)
]
}
]
async def setup_default_subscriptions(
user_id: str,
tenant_id: str,
event_service: EventService
) -> None:
"""Setup default event subscriptions for new user"""
try:
for subscription_config in DEFAULT_SUBSCRIPTIONS:
await event_service.create_subscription(
user_id=user_id,
tenant_id=tenant_id,
event_type=subscription_config["event_type"],
actions=subscription_config["actions"],
name=f"Default: {subscription_config['event_type'].value}",
description="Automatically created default subscription"
)
logger.info(f"Default event subscriptions created for user {user_id}")
except Exception as e:
logger.error(f"Failed to setup default subscriptions: {e}")
raise
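# Typical wiring (illustrative): call this once when a user is first provisioned,
# e.g. from a tenant onboarding flow, so new uploads are processed without manual setup:
#
#   service = EventService(db)
#   await setup_default_subscriptions(user_id, tenant_id, service)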

View File

@@ -0,0 +1,636 @@
"""
GT 2.0 Tenant Backend - External Service Management
Business logic for managing external web services with Resource Cluster integration
"""
import asyncio
import httpx
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy import select, update, delete, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.external_service import ExternalServiceInstance, ServiceAccessLog, ServiceTemplate
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class ExternalServiceManager:
"""Manages external service instances and Resource Cluster integration"""
def __init__(self, db: AsyncSession):
self.db = db
self.resource_cluster_base_url = settings.resource_cluster_url or "http://resource-cluster:8003"
self.capability_token = None
def set_capability_token(self, token: str):
"""Set capability token for Resource Cluster API calls"""
self.capability_token = token
async def create_service_instance(
self,
service_type: str,
service_name: str,
user_email: str,
config_overrides: Optional[Dict[str, Any]] = None,
template_id: Optional[str] = None
) -> ExternalServiceInstance:
"""Create a new external service instance"""
# Validate service type
supported_services = ['ctfd', 'canvas', 'guacamole']
if service_type not in supported_services:
raise ValueError(f"Unsupported service type: {service_type}")
# Load template if provided
template = None
if template_id:
template = await self.get_service_template(template_id)
if not template:
raise ValueError(f"Template {template_id} not found")
# Prepare configuration
service_config = {}
if template:
service_config.update(template.default_config)
if config_overrides:
service_config.update(config_overrides)
# Call Resource Cluster to create instance
resource_instance = await self._create_resource_cluster_instance(
service_type=service_type,
config_overrides=service_config
)
# Create database record
instance = ExternalServiceInstance(
service_type=service_type,
service_name=service_name,
description=f"{service_type.title()} instance for {user_email}",
resource_instance_id=resource_instance['instance_id'],
endpoint_url=resource_instance['endpoint_url'],
status=resource_instance['status'],
service_config=service_config,
created_by=user_email,
allowed_users=[user_email],
resource_limits=template.resource_requirements if template else {},
auto_start=template.default_config.get('auto_start', True) if template else True
)
self.db.add(instance)
await self.db.commit()
await self.db.refresh(instance)
logger.info(
f"Created {service_type} service instance {instance.id} "
f"for user {user_email}"
)
return instance
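# Illustrative usage (identifiers and the token are examples; service_type must be
# one of 'ctfd', 'canvas', or 'guacamole'):
#
#   manager = ExternalServiceManager(db)
#   manager.set_capability_token(capability_token)
#   instance = await manager.create_service_instance(
#       service_type="ctfd",
#       service_name="Intro CTF",
#       user_email="instructor@example.com",
#   )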
async def _create_resource_cluster_instance(
self,
service_type: str,
config_overrides: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create instance via Resource Cluster API with zero downtime error handling"""
if not self.capability_token:
raise ValueError("Capability token not set")
max_retries = 3
base_delay = 1.0
for attempt in range(max_retries):
try:
timeout = httpx.Timeout(60.0, connect=10.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.resource_cluster_base_url}/api/v1/services/instances",
json={
"service_type": service_type,
"config_overrides": config_overrides
},
headers={
"Authorization": f"Bearer {self.capability_token}",
"Content-Type": "application/json"
}
)
if response.status_code == 200:
return response.json()
elif response.status_code in [500, 502, 503, 504] and attempt < max_retries - 1:
# Retry for server errors
delay = base_delay * (2 ** attempt)
logger.warning(f"Service creation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
try:
error_detail = response.json().get('detail', f'HTTP {response.status_code}')
except Exception:
error_detail = f'HTTP {response.status_code}'
raise RuntimeError(f"Failed to create service instance: {error_detail}")
except httpx.TimeoutException:
if attempt < max_retries - 1:
delay = base_delay * (2 ** attempt)
logger.warning(f"Service creation timeout (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
raise RuntimeError("Failed to create service instance: timeout after retries")
except httpx.RequestError as e:
if attempt < max_retries - 1:
delay = base_delay * (2 ** attempt)
logger.warning(f"Service creation request error (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
raise RuntimeError(f"Failed to create service instance: {e}")
raise RuntimeError("Failed to create service instance: maximum retries exceeded")
async def get_service_instance(
self,
instance_id: str,
user_email: str
) -> Optional[ExternalServiceInstance]:
"""Get service instance with access control"""
query = select(ExternalServiceInstance).where(
and_(
ExternalServiceInstance.id == instance_id,
ExternalServiceInstance.allowed_users.op('json_extract_path_text')('*').op('@>')([user_email])
)
)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def list_user_services(
self,
user_email: str,
service_type: Optional[str] = None,
status: Optional[str] = None
) -> List[ExternalServiceInstance]:
"""List all services accessible to a user"""
query = select(ExternalServiceInstance).where(
ExternalServiceInstance.allowed_users.op('json_extract_path_text')('*').op('@>')([user_email])
)
if service_type:
query = query.where(ExternalServiceInstance.service_type == service_type)
if status:
query = query.where(ExternalServiceInstance.status == status)
query = query.order_by(ExternalServiceInstance.created_at.desc())
result = await self.db.execute(query)
return result.scalars().all()
async def stop_service_instance(
self,
instance_id: str,
user_email: str
) -> bool:
"""Stop a service instance"""
# Check access
instance = await self.get_service_instance(instance_id, user_email)
if not instance:
raise ValueError(f"Service instance {instance_id} not found or access denied")
# Call Resource Cluster to stop instance
success = await self._stop_resource_cluster_instance(instance.resource_instance_id)
if success:
# Update database status
instance.status = 'stopped'
instance.updated_at = datetime.utcnow()
await self.db.commit()
logger.info(
f"Stopped {instance.service_type} instance {instance_id} "
f"by user {user_email}"
)
return success
async def _stop_resource_cluster_instance(self, resource_instance_id: str) -> bool:
"""Stop instance via Resource Cluster API with zero downtime error handling"""
if not self.capability_token:
raise ValueError("Capability token not set")
max_retries = 3
base_delay = 1.0
for attempt in range(max_retries):
try:
timeout = httpx.Timeout(30.0, connect=10.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.delete(
f"{self.resource_cluster_base_url}/api/v1/services/instances/{resource_instance_id}",
headers={
"Authorization": f"Bearer {self.capability_token}"
}
)
if response.status_code == 200:
return True
elif response.status_code == 404:
# Instance already gone, consider it successfully stopped
logger.info(f"Instance {resource_instance_id} not found, assuming already stopped")
return True
elif response.status_code in [500, 502, 503, 504] and attempt < max_retries - 1:
# Retry for server errors
delay = base_delay * (2 ** attempt)
logger.warning(f"Instance stop failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
logger.error(f"Failed to stop instance {resource_instance_id}: HTTP {response.status_code}")
return False
except httpx.TimeoutException:
if attempt < max_retries - 1:
delay = base_delay * (2 ** attempt)
logger.warning(f"Instance stop timeout (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
logger.error(f"Failed to stop instance {resource_instance_id}: timeout after retries")
return False
except httpx.RequestError as e:
if attempt < max_retries - 1:
delay = base_delay * (2 ** attempt)
logger.warning(f"Instance stop request error (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
logger.error(f"Failed to stop instance {resource_instance_id}: {e}")
return False
logger.error(f"Failed to stop instance {resource_instance_id}: maximum retries exceeded")
return False
async def get_service_health(
self,
instance_id: str,
user_email: str
) -> Dict[str, Any]:
"""Get service health status"""
# Check access
instance = await self.get_service_instance(instance_id, user_email)
if not instance:
raise ValueError(f"Service instance {instance_id} not found or access denied")
# Get health from Resource Cluster
health = await self._get_resource_cluster_health(instance.resource_instance_id)
# Update instance health status
instance.health_status = health.get('status', 'unknown')
instance.last_health_check = datetime.utcnow()
if health.get('restart_count', 0) != instance.restart_count:
instance.restart_count = health.get('restart_count', 0)
await self.db.commit()
return health
async def _get_resource_cluster_health(self, resource_instance_id: str) -> Dict[str, Any]:
"""Get health status via Resource Cluster API with zero downtime error handling"""
if not self.capability_token:
raise ValueError("Capability token not set")
try:
timeout = httpx.Timeout(10.0, connect=5.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(
f"{self.resource_cluster_base_url}/api/v1/services/health/{resource_instance_id}",
headers={
"Authorization": f"Bearer {self.capability_token}"
}
)
if response.status_code == 200:
return response.json()
elif response.status_code == 404:
return {
'status': 'not_found',
'error': 'Instance not found'
}
else:
return {
'status': 'error',
'error': f'Health check failed: HTTP {response.status_code}'
}
except httpx.TimeoutException:
logger.warning(f"Health check timeout for instance {resource_instance_id}")
return {
'status': 'timeout',
'error': 'Health check timeout'
}
except httpx.RequestError as e:
logger.warning(f"Health check request error for instance {resource_instance_id}: {e}")
return {
'status': 'connection_error',
'error': f'Connection error: {e}'
}
async def generate_sso_token(
self,
instance_id: str,
user_email: str
) -> Dict[str, Any]:
"""Generate SSO token for iframe embedding"""
# Check access
instance = await self.get_service_instance(instance_id, user_email)
if not instance:
raise ValueError(f"Service instance {instance_id} not found or access denied")
# Generate SSO token via Resource Cluster
sso_data = await self._generate_resource_cluster_sso_token(instance.resource_instance_id)
# Update last accessed time
instance.last_accessed = datetime.utcnow()
await self.db.commit()
return sso_data
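# Illustrative embedding flow (the shape of sso_data is whatever the Resource
# Cluster returns; the 'sso_url' key below is an assumption, not a documented field):
#
#   sso_data = await manager.generate_sso_token(instance.id, user_email)
#   iframe_src = sso_data.get("sso_url")  # hand to the tenant frontend for <iframe> embedding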
async def _generate_resource_cluster_sso_token(self, resource_instance_id: str) -> Dict[str, Any]:
"""Generate SSO token via Resource Cluster API with zero downtime error handling"""
if not self.capability_token:
raise ValueError("Capability token not set")
max_retries = 3
base_delay = 1.0
for attempt in range(max_retries):
try:
timeout = httpx.Timeout(10.0, connect=5.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.resource_cluster_base_url}/api/v1/services/sso-token/{resource_instance_id}",
headers={
"Authorization": f"Bearer {self.capability_token}"
}
)
if response.status_code == 200:
return response.json()
elif response.status_code in [500, 502, 503, 504] and attempt < max_retries - 1:
# Retry for server errors
delay = base_delay * (2 ** attempt)
logger.warning(f"SSO token generation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
try:
error_detail = response.json().get('detail', f'HTTP {response.status_code}')
except Exception:
error_detail = f'HTTP {response.status_code}'
raise RuntimeError(f"Failed to generate SSO token: {error_detail}")
except httpx.TimeoutException:
if attempt < max_retries - 1:
delay = base_delay * (2 ** attempt)
logger.warning(f"SSO token generation timeout (attempt {attempt + 1}/{max_retries}), retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
raise RuntimeError("Failed to generate SSO token: timeout after retries")
except httpx.RequestError as e:
if attempt < max_retries - 1:
delay = base_delay * (2 ** attempt)
logger.warning(f"SSO token generation request error (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s")
await asyncio.sleep(delay)
continue
else:
raise RuntimeError(f"Failed to generate SSO token: {e}")
raise RuntimeError("Failed to generate SSO token: maximum retries exceeded")
async def log_service_access(
self,
service_instance_id: str,
service_type: str,
user_email: str,
access_type: str,
session_id: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
referer: Optional[str] = None,
session_duration_seconds: Optional[int] = None,
actions_performed: Optional[List[str]] = None
) -> ServiceAccessLog:
"""Log service access event"""
access_log = ServiceAccessLog(
service_instance_id=service_instance_id,
service_type=service_type,
user_email=user_email,
session_id=session_id,
access_type=access_type,
ip_address=ip_address,
user_agent=user_agent,
referer=referer,
session_duration_seconds=session_duration_seconds,
actions_performed=actions_performed or []
)
self.db.add(access_log)
await self.db.commit()
await self.db.refresh(access_log)
return access_log
async def get_service_analytics(
self,
instance_id: str,
user_email: str,
days: int = 30
) -> Dict[str, Any]:
"""Get service usage analytics"""
# Check access
instance = await self.get_service_instance(instance_id, user_email)
if not instance:
raise ValueError(f"Service instance {instance_id} not found or access denied")
# Query access logs
since_date = datetime.utcnow() - timedelta(days=days)
query = select(ServiceAccessLog).where(
and_(
ServiceAccessLog.service_instance_id == instance_id,
ServiceAccessLog.timestamp >= since_date
)
).order_by(ServiceAccessLog.timestamp.desc())
result = await self.db.execute(query)
access_logs = result.scalars().all()
# Compute analytics
total_sessions = len(set(log.session_id for log in access_logs))
total_time_seconds = sum(
log.session_duration_seconds or 0
for log in access_logs
if log.session_duration_seconds
)
unique_users = len(set(log.user_email for log in access_logs))
# Group by day for trend analysis
daily_usage = {}
for log in access_logs:
day = log.timestamp.date().isoformat()
if day not in daily_usage:
daily_usage[day] = {'sessions': 0, 'users': set()}
if log.access_type == 'login':
daily_usage[day]['sessions'] += 1
daily_usage[day]['users'].add(log.user_email)
# Convert sets to counts
for day_data in daily_usage.values():
day_data['unique_users'] = len(day_data['users'])
del day_data['users']
return {
'instance_id': instance_id,
'service_type': instance.service_type,
'service_name': instance.service_name,
'analytics_period_days': days,
'total_sessions': total_sessions,
'total_time_hours': round(total_time_seconds / 3600, 1),
'unique_users': unique_users,
'average_session_duration_minutes': round(
total_time_seconds / max(total_sessions, 1) / 60, 1
),
'daily_usage': daily_usage,
'health_status': instance.health_status,
'uptime_percentage': self._calculate_uptime_percentage(access_logs, days),
'last_accessed': instance.last_accessed.isoformat() if instance.last_accessed else None,
'created_at': instance.created_at.isoformat()
}
def _calculate_uptime_percentage(self, access_logs: List[ServiceAccessLog], days: int) -> float:
"""Calculate approximate uptime percentage based on access patterns"""
if not access_logs:
return 0.0
# Simple heuristic: if we have recent login events, assume service is up
recent_logins = [
log for log in access_logs
if log.access_type == 'login' and
log.timestamp > datetime.utcnow() - timedelta(days=1)
]
if recent_logins:
return 95.0 # Assume good uptime if recently accessed
elif len(access_logs) > 0:
return 85.0 # Some historical usage
else:
return 50.0 # No usage data
async def create_service_template(
self,
template_name: str,
service_type: str,
description: str,
default_config: Dict[str, Any],
created_by: str,
**kwargs
) -> ServiceTemplate:
"""Create a new service template"""
template = ServiceTemplate(
template_name=template_name,
service_type=service_type,
description=description,
default_config=default_config,
created_by=created_by,
**kwargs
)
self.db.add(template)
await self.db.commit()
await self.db.refresh(template)
return template
async def get_service_template(self, template_id: str) -> Optional[ServiceTemplate]:
"""Get service template by ID"""
query = select(ServiceTemplate).where(ServiceTemplate.id == template_id)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def list_service_templates(
self,
service_type: Optional[str] = None,
category: Optional[str] = None,
public_only: bool = True
) -> List[ServiceTemplate]:
"""List available service templates"""
query = select(ServiceTemplate).where(ServiceTemplate.is_active == True)
if public_only:
query = query.where(ServiceTemplate.is_public == True)
if service_type:
query = query.where(ServiceTemplate.service_type == service_type)
if category:
query = query.where(ServiceTemplate.category == category)
query = query.order_by(ServiceTemplate.usage_count.desc())
result = await self.db.execute(query)
return result.scalars().all()
async def share_service_instance(
self,
instance_id: str,
owner_email: str,
share_with_emails: List[str],
access_level: str = 'read'
) -> bool:
"""Share service instance with other users"""
# Check owner access
instance = await self.get_service_instance(instance_id, owner_email)
if not instance:
raise ValueError(f"Service instance {instance_id} not found or access denied")
if instance.created_by != owner_email:
raise ValueError("Only the instance creator can share access")
# Update allowed users
current_users = set(instance.allowed_users)
new_users = current_users.union(set(share_with_emails))
instance.allowed_users = list(new_users)
instance.access_level = 'team' if len(new_users) > 1 else 'private'
instance.updated_at = datetime.utcnow()
await self.db.commit()
logger.info(
f"Shared {instance.service_type} instance {instance_id} "
f"with {len(share_with_emails)} users"
)
return True
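# Illustrative sharing sketch (emails are examples): only the creator may share, and
# sharing with more than one user switches access_level to 'team':
#
#   await manager.share_service_instance(
#       instance_id=instance.id,
#       owner_email="instructor@example.com",
#       share_with_emails=["student1@example.com", "student2@example.com"],
#   )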

View File

@@ -0,0 +1,950 @@
from typing import List, Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, or_, desc
from sqlalchemy.orm import selectinload
from app.models.game import (
GameSession, PuzzleSession, PhilosophicalDialogue,
LearningAnalytics, GameTemplate
)
import json
import uuid
from datetime import datetime, timedelta
import random
class GameService:
"""Service for managing strategic games (Chess, Go, etc.)"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_available_games(self, user_id: str) -> Dict[str, Any]:
"""Get available games and user's current progress"""
# Get user's analytics to determine appropriate difficulty
analytics = await self.get_or_create_analytics(user_id)
# Available game types with current user ratings
available_games = {
"chess": {
"name": "Strategic Chess",
"description": "Classical chess with AI analysis and move commentary",
"current_rating": analytics.chess_rating,
"difficulty_levels": ["beginner", "intermediate", "advanced", "expert"],
"features": ["move_analysis", "position_evaluation", "opening_guidance", "endgame_tutorials"],
"estimated_time": "15-45 minutes",
"skills_developed": ["strategic_planning", "pattern_recognition", "calculation_depth"]
},
"go": {
"name": "Strategic Go",
"description": "The ancient game of Go with territory and influence analysis",
"current_rating": analytics.go_rating,
"difficulty_levels": ["beginner", "intermediate", "advanced", "expert"],
"features": ["territory_visualization", "influence_mapping", "joseki_suggestions", "life_death_training"],
"estimated_time": "20-60 minutes",
"skills_developed": ["strategic_concepts", "reading_ability", "intuitive_judgment"]
}
}
# Get recent game sessions
recent_sessions_query = select(GameSession).where(
and_(
GameSession.user_id == user_id,
GameSession.started_at >= datetime.utcnow() - timedelta(days=7)
)
).order_by(desc(GameSession.started_at)).limit(5)
result = await self.db.execute(recent_sessions_query)
recent_sessions = result.scalars().all()
return {
"available_games": available_games,
"recent_sessions": [self._serialize_game_session(session) for session in recent_sessions],
"user_analytics": self._serialize_analytics(analytics)
}
async def start_game_session(self, user_id: str, game_type: str, config: Dict[str, Any]) -> GameSession:
"""Start a new game session"""
analytics = await self.get_or_create_analytics(user_id)
# Determine AI opponent configuration based on user rating
ai_config = self._configure_ai_opponent(game_type, analytics, config.get('difficulty', 'intermediate'))
session = GameSession(
user_id=user_id,
game_type=game_type,
game_name=config.get('name', f"{game_type.title()} Game"),
difficulty_level=config.get('difficulty', 'intermediate'),
ai_opponent_config=ai_config,
game_rules=self._get_game_rules(game_type),
current_state=self._initialize_game_state(game_type),
current_rating=getattr(analytics, f"{game_type}_rating", 1200)
)
self.db.add(session)
await self.db.commit()
await self.db.refresh(session)
return session
async def make_move(self, session_id: str, user_id: str, move_data: Dict[str, Any]) -> Dict[str, Any]:
"""Process a user move and generate AI response"""
session_query = select(GameSession).where(
and_(GameSession.id == session_id, GameSession.user_id == user_id)
)
result = await self.db.execute(session_query)
session = result.scalar_one_or_none()
if not session or session.game_status != 'active':
raise ValueError("Game session not found or not active")
# Process user move
move_result = self._process_move(session, move_data, is_ai=False)
# Generate AI response
ai_move = self._generate_ai_move(session)
ai_result = self._process_move(session, ai_move, is_ai=True)
# Update session state
session.moves_count += 2 # User move + AI move
session.last_move_at = datetime.utcnow()
session.move_history = session.move_history + [move_result, ai_result]
# Check for game end conditions
game_status = self._check_game_status(session)
if game_status['ended']:
session.game_status = 'completed'
session.completed_at = datetime.utcnow()
session.outcome = game_status['outcome']
session.ai_analysis = self._generate_game_analysis(session)
session.learning_insights = self._extract_learning_insights(session)
# Update user analytics
await self._update_analytics_after_game(session)
await self.db.commit()
await self.db.refresh(session)
return {
"user_move": move_result,
"ai_move": ai_result,
"current_state": session.current_state,
"game_status": game_status,
"analysis": self._generate_move_analysis(session, move_data) if game_status.get('ended') else None
}
async def get_game_analysis(self, session_id: str, user_id: str) -> Dict[str, Any]:
"""Get detailed analysis of the current game position"""
session_query = select(GameSession).where(
and_(GameSession.id == session_id, GameSession.user_id == user_id)
)
result = await self.db.execute(session_query)
session = result.scalar_one_or_none()
if not session:
raise ValueError("Game session not found")
analysis = {
"position_evaluation": self._evaluate_position(session),
"best_moves": self._get_best_moves(session),
"strategic_insights": self._get_strategic_insights(session),
"learning_points": self._get_learning_points(session),
"skill_assessment": self._assess_current_skill(session)
}
return analysis
async def get_user_game_history(self, user_id: str, game_type: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
"""Get user's game history with performance trends"""
query = select(GameSession).where(GameSession.user_id == user_id)
if game_type:
query = query.where(GameSession.game_type == game_type)
query = query.order_by(desc(GameSession.started_at)).limit(limit)
result = await self.db.execute(query)
sessions = result.scalars().all()
return [self._serialize_game_session(session) for session in sessions]
def _configure_ai_opponent(self, game_type: str, analytics: LearningAnalytics, difficulty: str) -> Dict[str, Any]:
"""Configure AI opponent based on user skill level"""
base_config = {
"personality": "teaching", # teaching, aggressive, defensive, balanced
"explanation_mode": True,
"move_commentary": True,
"mistake_correction": True,
"hint_availability": True
}
if game_type == "chess":
rating_map = {
"beginner": analytics.chess_rating - 200,
"intermediate": analytics.chess_rating,
"advanced": analytics.chess_rating + 200,
"expert": analytics.chess_rating + 400
}
base_config.update({
"engine_strength": rating_map.get(difficulty, analytics.chess_rating),
"opening_book": True,
"endgame_tablebase": difficulty in ["advanced", "expert"],
"thinking_time": {"beginner": 1, "intermediate": 3, "advanced": 5, "expert": 10}[difficulty]
})
elif game_type == "go":
base_config.update({
"handicap_stones": {"beginner": 4, "intermediate": 2, "advanced": 0, "expert": 0}[difficulty],
"commentary_level": {"beginner": "detailed", "intermediate": "moderate", "advanced": "minimal", "expert": "minimal"}[difficulty],
"joseki_teaching": difficulty in ["beginner", "intermediate"]
})
return base_config
def _get_game_rules(self, game_type: str) -> Dict[str, Any]:
"""Get standard rules for the game type"""
rules = {
"chess": {
"board_size": "8x8",
"time_control": "unlimited",
"special_rules": ["castling", "en_passant", "promotion"],
"victory_conditions": ["checkmate", "resignation", "time_forfeit"]
},
"go": {
"board_size": "19x19",
"komi": 6.5,
"special_rules": ["ko_rule", "suicide_rule"],
"victory_conditions": ["territory_count", "resignation"]
}
}
return rules.get(game_type, {})
def _initialize_game_state(self, game_type: str) -> Dict[str, Any]:
"""Initialize the starting game state"""
if game_type == "chess":
return {
"board": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", # FEN notation
"to_move": "white",
"castling_rights": "KQkq",
"en_passant": None,
"halfmove_clock": 0,
"fullmove_number": 1
}
elif game_type == "go":
return {
"board": [[0 for _ in range(19)] for _ in range(19)], # 0=empty, 1=black, 2=white
"to_move": "black",
"captured_stones": {"black": 0, "white": 0},
"ko_position": None,
"move_number": 0
}
return {}
def _process_move(self, session: GameSession, move_data: Dict[str, Any], is_ai: bool) -> Dict[str, Any]:
"""Process a move and update game state"""
# This would contain game-specific logic for processing moves
# For now, return a mock processed move
return {
"move": move_data,
"is_ai": is_ai,
"timestamp": datetime.utcnow().isoformat(),
"evaluation": random.uniform(-1.0, 1.0), # Mock evaluation
"commentary": "Good move!" if not is_ai else "AI plays strategically"
}
def _generate_ai_move(self, session: GameSession) -> Dict[str, Any]:
"""Generate AI move based on current position"""
# Mock AI move generation
if session.game_type == "chess":
return {"from": "e2", "to": "e4", "piece": "pawn"}
elif session.game_type == "go":
return {"x": 10, "y": 10, "color": "white"}
return {}
def _check_game_status(self, session: GameSession) -> Dict[str, Any]:
"""Check if game has ended and determine outcome"""
# Mock game status check
return {
"ended": session.moves_count > 20, # Mock end condition
"outcome": random.choice(["win", "loss", "draw"]) if session.moves_count > 20 else None
}
def _generate_game_analysis(self, session: GameSession) -> Dict[str, Any]:
"""Generate comprehensive game analysis"""
return {
"game_quality": random.uniform(0.6, 0.95),
"key_moments": [
{"move": 5, "evaluation": "Excellent opening choice"},
{"move": 12, "evaluation": "Missed tactical opportunity"},
{"move": 18, "evaluation": "Strong endgame technique"}
],
"skill_demonstration": {
"tactical_awareness": random.uniform(0.5, 1.0),
"strategic_understanding": random.uniform(0.4, 0.9),
"time_management": random.uniform(0.6, 1.0)
}
}
def _extract_learning_insights(self, session: GameSession) -> List[str]:
"""Extract key learning insights from the game"""
insights = [
"Focus on controlling the center in the opening",
"Look for tactical combinations before moving",
"Consider your opponent's threats before making your move",
"Practice endgame fundamentals"
]
return random.sample(insights, 2)
async def _update_analytics_after_game(self, session: GameSession):
"""Update user analytics based on game performance"""
analytics = await self.get_or_create_analytics(session.user_id)
# Update game-specific rating
if session.game_type == "chess":
rating_change = self._calculate_rating_change(session)
analytics.chess_rating += rating_change
elif session.game_type == "go":
rating_change = self._calculate_rating_change(session)
analytics.go_rating += rating_change
# Update cognitive skills based on performance
if session.ai_analysis:
skill_updates = session.ai_analysis.get("skill_demonstration", {})
analytics.strategic_thinking_score = self._update_skill_score(
analytics.strategic_thinking_score,
skill_updates.get("strategic_understanding", 0.5)
)
analytics.total_sessions += 1
analytics.total_time_minutes += session.time_spent_seconds // 60
analytics.last_activity_date = datetime.utcnow()
await self.db.commit()
def _calculate_rating_change(self, session: GameSession) -> int:
"""Calculate ELO rating change based on game outcome"""
base_change = 30
if session.outcome == "win":
return base_change
elif session.outcome == "loss":
return -base_change
else: # draw
return 0
def _update_skill_score(self, current_score: float, performance: float) -> float:
"""Update skill score using exponential moving average"""
alpha = 0.1 # Learning rate
return current_score * (1 - alpha) + performance * 100 * alpha
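# Worked example of the update above: with alpha = 0.1, a current score of 60.0 and
# a per-game performance of 0.8 gives 60.0 * 0.9 + 0.8 * 100 * 0.1 = 62.0.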
async def get_or_create_analytics(self, user_id: str) -> LearningAnalytics:
"""Get or create learning analytics for user"""
query = select(LearningAnalytics).where(LearningAnalytics.user_id == user_id)
result = await self.db.execute(query)
analytics = result.scalar_one_or_none()
if not analytics:
analytics = LearningAnalytics(user_id=user_id)
self.db.add(analytics)
await self.db.commit()
await self.db.refresh(analytics)
return analytics
def _serialize_game_session(self, session: GameSession) -> Dict[str, Any]:
"""Serialize game session for API response"""
return {
"id": session.id,
"game_type": session.game_type,
"game_name": session.game_name,
"difficulty_level": session.difficulty_level,
"game_status": session.game_status,
"moves_count": session.moves_count,
"time_spent_seconds": session.time_spent_seconds,
"current_rating": session.current_rating,
"outcome": session.outcome,
"started_at": session.started_at.isoformat() if session.started_at else None,
"completed_at": session.completed_at.isoformat() if session.completed_at else None,
"learning_insights": session.learning_insights
}
def _serialize_analytics(self, analytics: LearningAnalytics) -> Dict[str, Any]:
"""Serialize learning analytics for API response"""
return {
"chess_rating": analytics.chess_rating,
"go_rating": analytics.go_rating,
"total_sessions": analytics.total_sessions,
"total_time_minutes": analytics.total_time_minutes,
"current_streak_days": analytics.current_streak_days,
"strategic_thinking_score": analytics.strategic_thinking_score,
"logical_reasoning_score": analytics.logical_reasoning_score,
"pattern_recognition_score": analytics.pattern_recognition_score,
"ai_collaboration_skills": {
"dependency_index": analytics.ai_dependency_index,
"prompt_engineering": analytics.prompt_engineering_skill,
"output_evaluation": analytics.ai_output_evaluation_skill,
"collaborative_solving": analytics.collaborative_problem_solving
}
}
# Mock helper methods for analysis (would be replaced with actual game engines)
def _evaluate_position(self, session: GameSession) -> Dict[str, Any]:
return {"advantage": random.uniform(-2.0, 2.0), "complexity": random.uniform(0.3, 1.0)}
def _get_best_moves(self, session: GameSession) -> List[Dict[str, Any]]:
return [{"move": "e4", "evaluation": 0.3}, {"move": "d4", "evaluation": 0.2}]
def _get_strategic_insights(self, session: GameSession) -> List[str]:
return ["Control the center", "Develop pieces actively", "Ensure king safety"]
def _get_learning_points(self, session: GameSession) -> List[str]:
return ["Practice tactical patterns", "Study endgame principles"]
def _assess_current_skill(self, session: GameSession) -> Dict[str, float]:
return {
"tactical_strength": random.uniform(0.4, 0.9),
"positional_understanding": random.uniform(0.3, 0.8),
"calculation_accuracy": random.uniform(0.5, 0.95)
}
def _generate_move_analysis(self, session: GameSession, move_data: Dict[str, Any]) -> Dict[str, Any]:
return {
"move_quality": random.uniform(0.6, 1.0),
"alternatives": ["Better move: Nf3", "Consider: Bc4"],
"consequences": "Leads to tactical complications"
}
class PuzzleService:
"""Service for managing logic puzzles and reasoning challenges"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_available_puzzles(self, user_id: str, category: Optional[str] = None) -> Dict[str, Any]:
"""Get available puzzles based on user skill level"""
analytics = await self._get_analytics(user_id)
puzzle_categories = {
"lateral_thinking": {
"name": "Lateral Thinking",
"description": "Creative problem-solving that requires thinking outside conventional patterns",
"difficulty_range": [1, 10],
"skills_developed": ["creative_thinking", "assumption_challenging", "perspective_shifting"]
},
"logical_deduction": {
"name": "Logical Deduction",
"description": "Step-by-step reasoning to reach logical conclusions",
"difficulty_range": [1, 8],
"skills_developed": ["systematic_thinking", "evidence_evaluation", "logical_consistency"]
},
"mathematical_reasoning": {
"name": "Mathematical Reasoning",
"description": "Number patterns, sequences, and mathematical logic",
"difficulty_range": [1, 9],
"skills_developed": ["pattern_recognition", "analytical_thinking", "quantitative_reasoning"]
},
"spatial_reasoning": {
"name": "Spatial Reasoning",
"description": "3D visualization and spatial relationship puzzles",
"difficulty_range": [1, 7],
"skills_developed": ["spatial_visualization", "mental_rotation", "pattern_matching"]
}
}
if category and category in puzzle_categories:
return {"category": puzzle_categories[category]}
return {
"categories": puzzle_categories,
"recommended_difficulty": min(analytics.puzzle_solving_level + 1, 10),
"user_progress": {
"current_level": analytics.puzzle_solving_level,
"puzzles_solved_total": analytics.total_sessions,
"favorite_categories": [] # Would be determined from session history
}
}
async def start_puzzle_session(self, user_id: str, puzzle_type: str, difficulty: int = None) -> PuzzleSession:
"""Start a new puzzle session"""
analytics = await self._get_analytics(user_id)
if difficulty is None:
difficulty = min(analytics.puzzle_solving_level + 1, 10)
puzzle_def = self._generate_puzzle(puzzle_type, difficulty)
session = PuzzleSession(
user_id=user_id,
puzzle_type=puzzle_type,
puzzle_category=puzzle_def["category"],
puzzle_definition=puzzle_def["definition"],
solution_criteria=puzzle_def["solution_criteria"],
difficulty_rating=difficulty,
estimated_time_minutes=puzzle_def.get("estimated_time", 10)
)
self.db.add(session)
await self.db.commit()
await self.db.refresh(session)
return session
async def submit_solution(self, session_id: str, user_id: str, solution: Dict[str, Any], reasoning: str = None) -> Dict[str, Any]:
"""Submit a solution attempt for evaluation"""
session_query = select(PuzzleSession).where(
and_(PuzzleSession.id == session_id, PuzzleSession.user_id == user_id)
)
result = await self.db.execute(session_query)
session = result.scalar_one_or_none()
if not session or session.session_status not in ['active', 'hint_requested']:
raise ValueError("Puzzle session not found or not active")
# Evaluate solution
evaluation = self._evaluate_solution(session, solution, reasoning)
# Update session
session.attempts_count += 1
session.current_attempt = solution
session.attempt_history = session.attempt_history + [solution]
session.reasoning_explanation = reasoning
session.time_spent_seconds += 60 # Mock time tracking
if evaluation["correct"]:
session.is_solved = True
session.session_status = 'solved'
session.solved_at = datetime.utcnow()
session.solution_quality_score = evaluation["quality_score"]
session.ai_feedback = evaluation["feedback"]
# Update user analytics
await self._update_analytics_after_puzzle(session, evaluation)
await self.db.commit()
await self.db.refresh(session)
return {
"correct": evaluation["correct"],
"feedback": evaluation["feedback"],
"quality_score": evaluation.get("quality_score", 0),
"hints_available": not evaluation["correct"],
"next_difficulty_recommendation": evaluation.get("next_difficulty", session.difficulty_rating)
}
async def get_hint(self, session_id: str, user_id: str, hint_level: int = 1) -> Dict[str, Any]:
"""Provide a hint for the current puzzle"""
session_query = select(PuzzleSession).where(
and_(PuzzleSession.id == session_id, PuzzleSession.user_id == user_id)
)
result = await self.db.execute(session_query)
session = result.scalar_one_or_none()
if not session or session.session_status not in ['active', 'hint_requested']:
raise ValueError("Puzzle session not found or not active")
hint = self._generate_hint(session, hint_level)
# Update session
session.hints_used_count += 1
session.hints_given = session.hints_given + [hint]
session.session_status = 'hint_requested'
await self.db.commit()
return {
"hint": hint["text"],
"hint_type": hint["type"],
"points_deducted": hint.get("point_penalty", 5),
"hints_remaining": max(0, 3 - session.hints_used_count)
}
def _generate_puzzle(self, puzzle_type: str, difficulty: int) -> Dict[str, Any]:
"""Generate a puzzle based on type and difficulty"""
puzzles = {
"lateral_thinking": {
"definition": {
"question": "A man lives on the 20th floor of an apartment building. Every morning he takes the elevator down to the ground floor. When he comes home, he takes the elevator to the 10th floor and walks the rest of the way... except on rainy days, when he takes the elevator all the way to the 20th floor. Why?",
"context": "This is a classic lateral thinking puzzle that requires you to challenge your assumptions."
},
"solution_criteria": {
"key_insights": ["height limitation", "elevator button accessibility"],
"required_elements": ["umbrella usage", "physical constraint explanation"]
},
"category": "lateral_thinking",
"estimated_time": 15
},
"logical_deduction": {
"definition": {
"question": "Five people of different nationalities live in five houses of different colors, drink different beverages, smoke different brands, and keep different pets. Using the given clues, determine who owns the fish.",
"clues": [
"The Brit lives in the red house",
"The Swede keeps dogs as pets",
"The Dane drinks tea",
"The green house is on the left of the white house",
"The green house's owner drinks coffee"
]
},
"solution_criteria": {
"format": "grid_solution",
"required_mapping": ["nationality", "house_color", "beverage", "smoke", "pet"]
},
"category": "logical_deduction",
"estimated_time": 25
}
}
return puzzles.get(puzzle_type, puzzles["lateral_thinking"])
def _evaluate_solution(self, session: PuzzleSession, solution: Dict[str, Any], reasoning: str = None) -> Dict[str, Any]:
"""Evaluate a puzzle solution"""
# Mock evaluation logic
is_correct = random.choice([True, False]) # Would be actual evaluation
quality_score = random.uniform(60, 95) if is_correct else random.uniform(20, 60)
feedback = {
"correctness": "Excellent reasoning!" if is_correct else "Not quite right, but good thinking.",
"reasoning_quality": "Clear logical steps" if reasoning else "Consider explaining your reasoning",
"suggestions": ["Try thinking about the constraints differently", "What assumptions might you be making?"]
}
return {
"correct": is_correct,
"quality_score": quality_score,
"feedback": feedback,
"next_difficulty": session.difficulty_rating + (1 if is_correct else 0)
}
def _generate_hint(self, session: PuzzleSession, hint_level: int) -> Dict[str, Any]:
"""Generate appropriate hint based on puzzle and user progress"""
hints = [
{"text": "Consider what might be different about rainy days", "type": "direction"},
{"text": "Think about the man's physical characteristics", "type": "clue"},
{"text": "What would an umbrella help him reach?", "type": "leading"}
]
return hints[min(hint_level - 1, len(hints) - 1)]
async def _update_analytics_after_puzzle(self, session: PuzzleSession, evaluation: Dict[str, Any]):
"""Update analytics after puzzle completion"""
analytics = await self._get_analytics(session.user_id)
if evaluation["correct"]:
analytics.puzzle_solving_level = min(analytics.puzzle_solving_level + 0.1, 10)
analytics.logical_reasoning_score = self._update_skill_score(
analytics.logical_reasoning_score,
evaluation["quality_score"] / 100
)
analytics.total_sessions += 1
analytics.last_activity_date = datetime.utcnow()
await self.db.commit()
async def _get_analytics(self, user_id: str) -> LearningAnalytics:
"""Get user analytics (shared with GameService)"""
query = select(LearningAnalytics).where(LearningAnalytics.user_id == user_id)
result = await self.db.execute(query)
analytics = result.scalar_one_or_none()
if not analytics:
analytics = LearningAnalytics(user_id=user_id)
self.db.add(analytics)
await self.db.commit()
await self.db.refresh(analytics)
return analytics
def _update_skill_score(self, current_score: float, performance: float) -> float:
"""Update skill score using exponential moving average"""
alpha = 0.1
return current_score * (1 - alpha) + performance * 100 * alpha
class PhilosophicalDialogueService:
"""Service for managing philosophical dilemmas and ethical reasoning"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_available_dilemmas(self, user_id: str) -> Dict[str, Any]:
"""Get available philosophical dilemmas"""
analytics = await self._get_analytics(user_id)
dilemma_types = {
"ethical_frameworks": {
"name": "Ethical Framework Analysis",
"description": "Explore different ethical theories through practical dilemmas",
"topics": ["trolley_problem", "utilitarian_vs_deontological", "virtue_ethics_scenarios"],
"skills_developed": ["ethical_reasoning", "framework_application", "moral_consistency"]
},
"game_theory": {
"name": "Game Theory Dilemmas",
"description": "Strategic decision-making in competitive and cooperative scenarios",
"topics": ["prisoners_dilemma", "tragedy_of_commons", "coordination_games"],
"skills_developed": ["strategic_thinking", "cooperation_analysis", "incentive_understanding"]
},
"ai_consciousness": {
"name": "AI Consciousness & Rights",
"description": "Explore questions about AI sentience, rights, and moral status",
"topics": ["chinese_room", "turing_test_ethics", "ai_rights", "consciousness_criteria"],
"skills_developed": ["conceptual_analysis", "consciousness_theory", "future_ethics"]
}
}
return {
"dilemma_types": dilemma_types,
"user_level": analytics.philosophical_depth_level,
"recommended_topics": self._get_recommended_topics(analytics),
"recent_insights": [] # Would come from recent sessions
}
async def start_dialogue_session(self, user_id: str, dilemma_type: str, topic: str) -> PhilosophicalDialogue:
"""Start a new philosophical dialogue session"""
scenario = self._get_dilemma_scenario(dilemma_type, topic)
dialogue = PhilosophicalDialogue(
user_id=user_id,
dilemma_type=dilemma_type,
dilemma_title=scenario["title"],
scenario_description=scenario["description"],
framework_options=scenario["frameworks"],
complexity_level=scenario.get("complexity", "intermediate"),
estimated_discussion_time=scenario.get("estimated_time", 20)
)
self.db.add(dialogue)
await self.db.commit()
await self.db.refresh(dialogue)
return dialogue
async def submit_response(self, dialogue_id: str, user_id: str, response: str, framework: str = None) -> Dict[str, Any]:
"""Submit a response to the philosophical dilemma"""
dialogue_query = select(PhilosophicalDialogue).where(
and_(PhilosophicalDialogue.id == dialogue_id, PhilosophicalDialogue.user_id == user_id)
)
result = await self.db.execute(dialogue_query)
dialogue = result.scalar_one_or_none()
if not dialogue or dialogue.dialogue_status != 'active':
raise ValueError("Dialogue session not found or not active")
# Generate AI response based on user input
ai_response = self._generate_ai_response(dialogue, response, framework)
# Update dialogue state
dialogue.exchange_count += 1
dialogue.dialogue_history = dialogue.dialogue_history + [
{"speaker": "user", "content": response, "framework": framework, "timestamp": datetime.utcnow().isoformat()},
{"speaker": "ai", "content": ai_response["content"], "type": ai_response["type"], "timestamp": datetime.utcnow().isoformat()}
]
if framework and framework not in dialogue.frameworks_explored:
dialogue.frameworks_explored = dialogue.frameworks_explored + [framework]
dialogue.framework_applications += 1
dialogue.last_exchange_at = datetime.utcnow()
await self.db.commit()
await self.db.refresh(dialogue)
return {
"ai_response": ai_response["content"],
"follow_up_questions": ai_response.get("questions", []),
"suggested_frameworks": ai_response.get("suggested_frameworks", []),
"dialogue_progress": {
"exchanges": dialogue.exchange_count,
"frameworks_explored": len(dialogue.frameworks_explored),
"depth_assessment": self._assess_dialogue_depth(dialogue)
}
}
async def conclude_dialogue(self, dialogue_id: str, user_id: str, final_position: Dict[str, Any]) -> Dict[str, Any]:
"""Conclude the philosophical dialogue with final assessment"""
dialogue_query = select(PhilosophicalDialogue).where(
and_(PhilosophicalDialogue.id == dialogue_id, PhilosophicalDialogue.user_id == user_id)
)
result = await self.db.execute(dialogue_query)
dialogue = result.scalar_one_or_none()
if not dialogue:
raise ValueError("Dialogue session not found")
# Generate final assessment
assessment = self._generate_final_assessment(dialogue, final_position)
# Update dialogue
dialogue.dialogue_status = 'concluded'
dialogue.concluded_at = datetime.utcnow()
dialogue.final_position = final_position
dialogue.key_insights = assessment["insights"]
dialogue.ai_assessment = assessment["ai_evaluation"]
# Update scores based on dialogue quality
dialogue.ethical_consistency_score = assessment["scores"]["consistency"]
dialogue.perspective_flexibility_score = assessment["scores"]["flexibility"]
dialogue.framework_mastery_score = assessment["scores"]["framework_mastery"]
dialogue.synthesis_quality_score = assessment["scores"]["synthesis"]
# Update user analytics
await self._update_analytics_after_dialogue(dialogue, assessment)
await self.db.commit()
await self.db.refresh(dialogue)
return {
"final_assessment": assessment,
"skill_development": assessment["skill_changes"],
"recommended_next_topics": assessment["recommendations"]
}
def _get_dilemma_scenario(self, dilemma_type: str, topic: str) -> Dict[str, Any]:
"""Get specific dilemma scenario"""
scenarios = {
"ethical_frameworks": {
"trolley_problem": {
"title": "The Trolley Problem",
"description": "A runaway trolley is heading towards five people tied to the tracks. You can pull a lever to divert it to another track, where it will kill one person instead. Do you pull the lever?",
"frameworks": ["utilitarianism", "deontological", "virtue_ethics", "care_ethics"],
"complexity": "intermediate"
}
},
"game_theory": {
"prisoners_dilemma": {
"title": "The Prisoner's Dilemma",
"description": "You and your partner are arrested and held separately. You can either confess (defect) or remain silent (cooperate). The outcomes depend on both choices.",
"frameworks": ["rational_choice", "social_contract", "evolutionary_ethics"],
"complexity": "intermediate"
}
},
"ai_consciousness": {
"chinese_room": {
"title": "The Chinese Room Argument",
"description": "Is a computer program that can perfectly simulate understanding Chinese actually understanding Chinese? What does this mean for AI consciousness?",
"frameworks": ["functionalism", "behaviorism", "phenomenology", "computational_theory"],
"complexity": "advanced"
}
}
}
return scenarios.get(dilemma_type, {}).get(topic, scenarios["ethical_frameworks"]["trolley_problem"])
def _generate_ai_response(self, dialogue: PhilosophicalDialogue, user_response: str, framework: str = None) -> Dict[str, Any]:
"""Generate AI response using Socratic method"""
# Mock AI response generation
response_types = ["socratic_question", "framework_challenge", "perspective_shift", "synthesis_prompt"]
response_type = random.choice(response_types)
responses = {
"socratic_question": {
"content": "That's an interesting perspective. What underlying assumptions are you making about the value of individual lives versus collective outcomes?",
"questions": ["How do you weigh individual rights against collective welfare?", "What if the numbers were different?"],
"type": "questioning"
},
"framework_challenge": {
"content": "Your utilitarian approach focuses on outcomes. How might a deontologist view this same situation?",
"suggested_frameworks": ["deontological", "virtue_ethics"],
"type": "framework_exploration"
},
"perspective_shift": {
"content": "Consider this from the perspective of each person involved. How might their consent or agency factor into your decision?",
"questions": ["Does intent matter as much as outcome?", "What about the rights of those who cannot consent?"],
"type": "perspective_expansion"
},
"synthesis_prompt": {
"content": "You've explored multiple frameworks. Can you synthesize these different approaches into a coherent position?",
"type": "synthesis"
}
}
return responses[response_type]
def _assess_dialogue_depth(self, dialogue: PhilosophicalDialogue) -> Dict[str, Any]:
"""Assess the depth and quality of the dialogue"""
return {
"conceptual_depth": min(dialogue.exchange_count / 5, 1.0),
"framework_breadth": len(dialogue.frameworks_explored) / 4,
"synthesis_attempts": dialogue.synthesis_attempts,
"perspective_shifts": dialogue.perspective_shifts
}
def _generate_final_assessment(self, dialogue: PhilosophicalDialogue, final_position: Dict[str, Any]) -> Dict[str, Any]:
"""Generate comprehensive final assessment"""
return {
"insights": [
"Demonstrated understanding of utilitarian reasoning",
"Showed ability to consider multiple perspectives",
"Struggled with synthesizing competing frameworks"
],
"scores": {
"consistency": random.uniform(0.6, 0.9),
"flexibility": random.uniform(0.5, 0.8),
"framework_mastery": random.uniform(0.4, 0.85),
"synthesis": random.uniform(0.3, 0.7)
},
"skill_changes": {
"ethical_reasoning": "+5%",
"perspective_taking": "+3%",
"logical_consistency": "+2%"
},
"recommendations": [
"Explore virtue ethics in more depth",
"Practice synthesizing competing moral intuitions",
"Consider real-world applications of ethical frameworks"
],
"ai_evaluation": {
"dialogue_quality": "Good engagement with multiple perspectives",
"growth_areas": "Framework synthesis and practical application",
"strengths": "Clear reasoning and willingness to explore"
}
}
async def _update_analytics_after_dialogue(self, dialogue: PhilosophicalDialogue, assessment: Dict[str, Any]):
"""Update user analytics after dialogue completion"""
analytics = await self._get_analytics(dialogue.user_id)
# Update philosophical reasoning skills
analytics.ethical_reasoning_score = self._update_skill_score(
analytics.ethical_reasoning_score,
assessment["scores"]["consistency"]
)
analytics.philosophical_depth_level = min(
analytics.philosophical_depth_level + 0.1,
10
)
analytics.total_sessions += 1
analytics.last_activity_date = datetime.utcnow()
await self.db.commit()
def _get_recommended_topics(self, analytics: LearningAnalytics) -> List[str]:
"""Get recommended topics based on user progress"""
if analytics.philosophical_depth_level < 3:
return ["trolley_problem", "prisoners_dilemma"]
elif analytics.philosophical_depth_level < 6:
return ["virtue_ethics_scenarios", "tragedy_of_commons", "ai_rights"]
else:
return ["chinese_room", "consciousness_criteria", "coordination_games"]
async def _get_analytics(self, user_id: str) -> LearningAnalytics:
"""Get user analytics"""
query = select(LearningAnalytics).where(LearningAnalytics.user_id == user_id)
result = await self.db.execute(query)
analytics = result.scalar_one_or_none()
if not analytics:
analytics = LearningAnalytics(user_id=user_id)
self.db.add(analytics)
await self.db.commit()
await self.db.refresh(analytics)
return analytics
def _update_skill_score(self, current_score: float, performance: float) -> float:
"""Update skill score using exponential moving average"""
alpha = 0.1
return current_score * (1 - alpha) + performance * 100 * alpha
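# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only; the `db` session and user id below
# are assumptions, not defined in this module):
#
#   async def _example_puzzle_flow(db):
#       service = PuzzleService(db)
#       session = await service.start_puzzle_session("user-123", "lateral_thinking")
#       hint = await service.get_hint(session.id, "user-123", hint_level=1)
#       result = await service.submit_solution(
#           session.id,
#           "user-123",
#           solution={"answer": "He cannot reach the 20th-floor button without an umbrella"},
#           reasoning="The puzzle hinges on a physical height constraint",
#       )
#       return hint, result
# ---------------------------------------------------------------------------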

View File

@@ -0,0 +1,427 @@
"""
Model Context Protocol (MCP) Integration Service
Enables extensible tool integration for GT 2.0 agents
"""
from typing import List, Dict, Any, Optional
import httpx
import asyncio
import json
from pydantic import BaseModel, Field
from datetime import datetime
import logging
from app.models.agent import Agent
logger = logging.getLogger(__name__)
class MCPTool(BaseModel):
"""MCP Tool definition"""
name: str
description: str
parameters: Dict[str, Any]
returns: Dict[str, Any] = Field(default_factory=dict)
endpoint: str
requires_auth: bool = False
rate_limit: Optional[int] = None # requests per minute
timeout: int = 30 # seconds
enabled: bool = True
class MCPServer(BaseModel):
"""MCP Server configuration"""
id: str
name: str
base_url: str
api_key: Optional[str] = None
tools: List[MCPTool] = Field(default_factory=list)
health_check_endpoint: str = "/health"
tools_endpoint: str = "/tools"
timeout: int = 30
max_retries: int = 3
enabled: bool = True
created_at: datetime = Field(default_factory=datetime.utcnow)
last_health_check: Optional[datetime] = None
health_status: str = "unknown" # healthy, unhealthy, unknown
class MCPExecutionResult(BaseModel):
"""Result of tool execution"""
success: bool
data: Dict[str, Any] = Field(default_factory=dict)
error: Optional[str] = None
execution_time_ms: int
tokens_used: Optional[int] = None
cost_cents: Optional[int] = None
class MCPIntegrationService:
"""Service for managing MCP integrations"""
def __init__(self):
self.servers: Dict[str, MCPServer] = {}
self.rate_limits: Dict[str, Dict[str, List[datetime]]] = {} # server_id -> tool_name -> timestamps
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100)
)
async def register_server(self, server: MCPServer) -> bool:
"""Register a new MCP server"""
try:
# Validate server is reachable
if await self.health_check(server):
# Discover available tools
tools = await self.discover_tools(server)
server.tools = tools
server.health_status = "healthy"
server.last_health_check = datetime.utcnow()
self.servers[server.id] = server
self.rate_limits[server.id] = {}
logger.info(f"Registered MCP server: {server.name} with {len(tools)} tools")
return True
else:
server.health_status = "unhealthy"
logger.error(f"Failed to register MCP server: {server.name} - health check failed")
return False
except Exception as e:
server.health_status = "unhealthy"
logger.error(f"Failed to register MCP server: {server.name} - {str(e)}")
return False
async def health_check(self, server: MCPServer) -> bool:
"""Check if MCP server is healthy"""
try:
headers = {}
if server.api_key:
headers["Authorization"] = f"Bearer {server.api_key}"
response = await self.client.get(
f"{server.base_url.rstrip('/')}{server.health_check_endpoint}",
headers=headers,
timeout=10.0
)
is_healthy = response.status_code == 200
if server.id in self.servers:
self.servers[server.id].health_status = "healthy" if is_healthy else "unhealthy"
self.servers[server.id].last_health_check = datetime.utcnow()
return is_healthy
except Exception as e:
logger.warning(f"Health check failed for {server.name}: {e}")
if server.id in self.servers:
self.servers[server.id].health_status = "unhealthy"
return False
async def discover_tools(self, server: MCPServer) -> List[MCPTool]:
"""Discover available tools from MCP server"""
try:
headers = {}
if server.api_key:
headers["Authorization"] = f"Bearer {server.api_key}"
response = await self.client.get(
f"{server.base_url.rstrip('/')}{server.tools_endpoint}",
headers=headers
)
if response.status_code == 200:
tools_data = response.json()
tools = []
for tool_data in tools_data.get("tools", []):
try:
tool = MCPTool(**tool_data)
tools.append(tool)
except Exception as e:
logger.warning(f"Invalid tool definition from {server.name}: {e}")
return tools
except Exception as e:
logger.error(f"Tool discovery failed for {server.name}: {e}")
return []
def _check_rate_limit(self, server_id: str, tool_name: str) -> bool:
"""Check if rate limit allows execution"""
server = self.servers.get(server_id)
if not server:
return False
tool = next((t for t in server.tools if t.name == tool_name), None)
if not tool or not tool.rate_limit:
return True
now = datetime.utcnow()
minute_ago = now.timestamp() - 60
# Initialize tracking if needed
if server_id not in self.rate_limits:
self.rate_limits[server_id] = {}
if tool_name not in self.rate_limits[server_id]:
self.rate_limits[server_id][tool_name] = []
# Clean old timestamps
timestamps = self.rate_limits[server_id][tool_name]
self.rate_limits[server_id][tool_name] = [
ts for ts in timestamps if ts.timestamp() > minute_ago
]
# Check limit
current_count = len(self.rate_limits[server_id][tool_name])
if current_count >= tool.rate_limit:
return False
# Record this request
self.rate_limits[server_id][tool_name].append(now)
return True
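# Illustrative behaviour (no additional logic): for a tool declared with
# rate_limit=10 the sliding window above keeps only timestamps from the last
# 60 seconds, so an 11th call within that minute returns False and execute_tool
# surfaces a "Rate limit exceeded" result instead of hitting the server.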
async def execute_tool(
self,
server_id: str,
tool_name: str,
parameters: Dict[str, Any],
assistant_context: Optional[Dict[str, Any]] = None,
user_context: Optional[Dict[str, Any]] = None
) -> MCPExecutionResult:
"""Execute a tool on an MCP server"""
start_time = datetime.utcnow()
try:
# Validate server exists and is enabled
if server_id not in self.servers:
return MCPExecutionResult(
success=False,
error=f"Server {server_id} not registered",
execution_time_ms=0
)
server = self.servers[server_id]
if not server.enabled:
return MCPExecutionResult(
success=False,
error=f"Server {server_id} is disabled",
execution_time_ms=0
)
# Find tool
tool = next((t for t in server.tools if t.name == tool_name), None)
if not tool:
return MCPExecutionResult(
success=False,
error=f"Tool {tool_name} not found on server {server_id}",
execution_time_ms=0
)
if not tool.enabled:
return MCPExecutionResult(
success=False,
error=f"Tool {tool_name} is disabled",
execution_time_ms=0
)
# Check rate limit
if not self._check_rate_limit(server_id, tool_name):
return MCPExecutionResult(
success=False,
error=f"Rate limit exceeded for tool {tool_name}",
execution_time_ms=0
)
# Prepare request
headers = {"Content-Type": "application/json"}
if server.api_key:
headers["Authorization"] = f"Bearer {server.api_key}"
payload = {
"tool": tool_name,
"parameters": parameters
}
# Add context if provided
if assistant_context:
payload["assistant_context"] = assistant_context
if user_context:
payload["user_context"] = user_context
# Execute with retries
last_exception = None
for attempt in range(server.max_retries):
try:
response = await self.client.post(
f"{server.base_url.rstrip('/')}{tool.endpoint}",
json=payload,
headers=headers,
timeout=tool.timeout
)
execution_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
if response.status_code == 200:
result_data = response.json()
return MCPExecutionResult(
success=True,
data=result_data,
execution_time_ms=execution_time,
tokens_used=result_data.get("tokens_used"),
cost_cents=result_data.get("cost_cents")
)
else:
return MCPExecutionResult(
success=False,
error=f"HTTP {response.status_code}: {response.text}",
execution_time_ms=execution_time
)
except httpx.TimeoutException as e:
last_exception = e
if attempt == server.max_retries - 1:
break
await asyncio.sleep(2 ** attempt) # Exponential backoff
except Exception as e:
last_exception = e
if attempt == server.max_retries - 1:
break
await asyncio.sleep(2 ** attempt)
# All retries failed
execution_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
return MCPExecutionResult(
success=False,
error=f"Tool execution failed after {server.max_retries} attempts: {str(last_exception)}",
execution_time_ms=execution_time
)
except Exception as e:
execution_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
return MCPExecutionResult(
success=False,
error=f"Unexpected error: {str(e)}",
execution_time_ms=execution_time
)
async def get_tools_for_assistant(
self,
agent: Agent
) -> List[MCPTool]:
"""Get all available tools for an agent based on integrations"""
tools = []
for integration_id in agent.mcp_integration_ids:
if integration_id in self.servers:
server = self.servers[integration_id]
if server.enabled:
# Filter by enabled tools only
tools.extend([tool for tool in server.tools if tool.enabled])
return tools
def format_tools_for_llm(self, tools: List[MCPTool]) -> List[Dict[str, Any]]:
"""Format tools for LLM function calling (OpenAI format)"""
return [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
}
}
for tool in tools
]
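# Example of the emitted structure for a hypothetical "web_search" tool
# (values are illustrative, not a registered server):
#   {
#       "type": "function",
#       "function": {
#           "name": "web_search",
#           "description": "Search the web for a query",
#           "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
#       },
#   }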
async def get_server_status(self, server_id: str) -> Dict[str, Any]:
"""Get detailed status of an MCP server"""
if server_id not in self.servers:
return {"error": "Server not found"}
server = self.servers[server_id]
# Perform health check
await self.health_check(server)
return {
"id": server.id,
"name": server.name,
"base_url": server.base_url,
"enabled": server.enabled,
"health_status": server.health_status,
"last_health_check": server.last_health_check.isoformat() if server.last_health_check else None,
"tools_count": len(server.tools),
"enabled_tools_count": sum(1 for tool in server.tools if tool.enabled),
"tools": [
{
"name": tool.name,
"description": tool.description,
"enabled": tool.enabled,
"rate_limit": tool.rate_limit
}
for tool in server.tools
]
}
async def list_servers(self) -> List[Dict[str, Any]]:
"""List all registered MCP servers"""
servers = []
for server_id, server in self.servers.items():
servers.append(await self.get_server_status(server_id))
return servers
async def remove_server(self, server_id: str) -> bool:
"""Remove an MCP server"""
if server_id in self.servers:
del self.servers[server_id]
if server_id in self.rate_limits:
del self.rate_limits[server_id]
logger.info(f"Removed MCP server: {server_id}")
return True
return False
async def update_server_config(self, server_id: str, config: Dict[str, Any]) -> bool:
"""Update MCP server configuration"""
if server_id not in self.servers:
return False
server = self.servers[server_id]
# Update allowed fields
if "enabled" in config:
server.enabled = config["enabled"]
if "api_key" in config:
server.api_key = config["api_key"]
if "timeout" in config:
server.timeout = config["timeout"]
if "max_retries" in config:
server.max_retries = config["max_retries"]
# Rediscover tools if server was re-enabled or config changed
if server.enabled:
tools = await self.discover_tools(server)
server.tools = tools
logger.info(f"Updated MCP server config: {server_id}")
return True
async def close(self):
"""Close HTTP client"""
await self.client.aclose()
# Singleton instance
mcp_service = MCPIntegrationService()
# Default MCP servers for GT 2.0 (can be configured via admin)
DEFAULT_MCP_SERVERS = [
{
"id": "gt2_core_tools",
"name": "GT 2.0 Core Tools",
"base_url": "http://localhost:8003", # Internal tools server
"tools_endpoint": "/mcp/tools",
"health_check_endpoint": "/mcp/health"
},
{
"id": "web_search",
"name": "Web Search Tools",
"base_url": "http://localhost:8004",
"api_key": None # Configured via admin
}
]
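# Hedged usage sketch (illustrative; assumes the default core tools server is
# reachable and exposes a tool named "echo", which is not guaranteed here):
#
#   async def _example_register_and_call():
#       server = MCPServer(**DEFAULT_MCP_SERVERS[0])
#       if await mcp_service.register_server(server):
#           result = await mcp_service.execute_tool(
#               server_id=server.id,
#               tool_name="echo",
#               parameters={"text": "ping"},
#           )
#           return result.success, result.execution_time_ms
#       return False, 0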

View File

@@ -0,0 +1,371 @@
"""
Message Bus Client for GT 2.0 Tenant Backend
Handles communication with Admin Cluster via RabbitMQ message queues.
Processes commands from admin and sends responses back.
"""
import json
import logging
import asyncio
import uuid
import hmac
import hashlib
from datetime import datetime
from typing import Dict, Any, Optional, Callable, List
from pydantic import BaseModel, Field
import aio_pika
from aio_pika import connect, Message, DeliveryMode, ExchangeType
from aio_pika.abc import AbstractIncomingMessage
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class AdminCommand(BaseModel):
"""Command received from Admin Cluster"""
command_id: str
command_type: str # TENANT_PROVISION, TENANT_SUSPEND, etc.
target_cluster: str
target_namespace: str
payload: Dict[str, Any]
timestamp: str
signature: str
class TenantResponse(BaseModel):
"""Response sent to Admin Cluster"""
command_id: str
response_type: str # SUCCESS, ERROR, PROCESSING
target_cluster: str = "admin"
source_cluster: str = "tenant"
namespace: str
payload: Dict[str, Any] = Field(default_factory=dict)
timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
signature: Optional[str] = None
class MessageBusClient:
"""
Client for RabbitMQ message bus communication with Admin Cluster.
GT 2.0 Security Principles:
- HMAC signature verification for all messages
- Tenant-scoped command processing
- No cross-tenant message leakage
"""
def __init__(self):
self.settings = get_settings()
self.connection = None
self.channel = None
self.admin_to_tenant_queue = None
self.tenant_to_admin_queue = None
# Message handlers
self.command_handlers: Dict[str, Callable[[AdminCommand], Any]] = {}
# RabbitMQ configuration from admin specification
self.rabbitmq_url = getattr(
self.settings,
'RABBITMQ_URL',
'amqp://gt2_admin:dev_password_change_in_prod@rabbitmq:5672/'
)
# Security
self.secret_key = getattr(self.settings, 'SECRET_KEY', 'production-secret-key')
logger.info("Message bus client initialized")
async def connect(self) -> bool:
"""Connect to RabbitMQ message bus"""
try:
self.connection = await connect(self.rabbitmq_url)
self.channel = await self.connection.channel()
# Declare exchanges and queues matching admin specification
await self.channel.declare_exchange(
'gt2_commands', ExchangeType.DIRECT, durable=True
)
# Admin → Tenant command queue
self.admin_to_tenant_queue = await self.channel.declare_queue(
'admin_to_tenant', durable=True
)
# Tenant → Admin response queue
self.tenant_to_admin_queue = await self.channel.declare_queue(
'tenant_to_admin', durable=True
)
# Start consuming commands
await self.admin_to_tenant_queue.consume(self._handle_admin_command)
logger.info("Connected to RabbitMQ message bus")
return True
except Exception as e:
logger.error(f"Failed to connect to message bus: {e}")
return False
def _verify_signature(self, message_data: Dict[str, Any], signature: str) -> bool:
"""Verify HMAC signature of incoming message"""
try:
# Create message content for signature (excluding signature field)
content = {k: v for k, v in message_data.items() if k != 'signature'}
content_str = json.dumps(content, sort_keys=True)
# Calculate expected signature
expected_signature = hmac.new(
self.secret_key.encode(),
content_str.encode(),
hashlib.sha256
).hexdigest()
return hmac.compare_digest(expected_signature, signature)
except Exception as e:
logger.error(f"Signature verification failed: {e}")
return False
def _sign_message(self, message_data: Dict[str, Any]) -> str:
"""Create HMAC signature for outgoing message"""
try:
content_str = json.dumps(message_data, sort_keys=True)
return hmac.new(
self.secret_key.encode(),
content_str.encode(),
hashlib.sha256
).hexdigest()
except Exception as e:
logger.error(f"Message signing failed: {e}")
return ""
async def _handle_admin_command(self, message: AbstractIncomingMessage) -> None:
"""Handle incoming command from Admin Cluster"""
try:
# Parse message
message_data = json.loads(message.body.decode())
command = AdminCommand(**message_data)
logger.info(f"Received admin command: {command.command_type} ({command.command_id})")
# Verify signature
if not self._verify_signature(message_data, command.signature):
logger.error(f"Invalid signature for command {command.command_id}")
await self._send_response(command.command_id, "ERROR", {
"error": "Invalid signature",
"namespace": command.target_namespace
})
return
# Check if we have a handler for this command type
if command.command_type not in self.command_handlers:
logger.warning(f"No handler for command type: {command.command_type}")
await self._send_response(command.command_id, "ERROR", {
"error": f"Unknown command type: {command.command_type}",
"namespace": command.target_namespace
})
return
# Execute command handler
handler = self.command_handlers[command.command_type]
result = await handler(command)
# Send success response
await self._send_response(command.command_id, "SUCCESS", {
"result": result,
"namespace": command.target_namespace
})
# Acknowledge message
await message.ack()
except Exception as e:
logger.error(f"Error handling admin command: {e}")
try:
if 'command' in locals():
await self._send_response(command.command_id, "ERROR", {
"error": str(e),
"namespace": getattr(command, 'target_namespace', 'unknown')
})
await message.nack(requeue=False)
except Exception:
pass
async def _send_response(self, command_id: str, response_type: str, payload: Dict[str, Any]) -> None:
"""Send response back to Admin Cluster"""
try:
response_data = {
"command_id": command_id,
"response_type": response_type,
"target_cluster": "admin",
"source_cluster": "tenant",
"namespace": payload.get("namespace", "unknown"),
"payload": payload,
"timestamp": datetime.utcnow().isoformat()
}
# Sign the response
response_data["signature"] = self._sign_message(response_data)
# Send message
message = Message(
json.dumps(response_data).encode(),
delivery_mode=DeliveryMode.PERSISTENT
)
await self.channel.default_exchange.publish(
message, routing_key=self.tenant_to_admin_queue.name
)
logger.info(f"Sent response to admin: {response_type} for {command_id}")
except Exception as e:
logger.error(f"Failed to send response to admin: {e}")
def register_handler(self, command_type: str, handler: Callable[[AdminCommand], Any]) -> None:
"""Register handler for specific command type"""
self.command_handlers[command_type] = handler
logger.info(f"Registered handler for command type: {command_type}")
async def send_notification(self, notification_type: str, payload: Dict[str, Any]) -> None:
"""Send notification to Admin Cluster (not in response to a command)"""
try:
notification_data = {
"command_id": str(uuid.uuid4()),
"response_type": f"NOTIFICATION_{notification_type}",
"target_cluster": "admin",
"source_cluster": "tenant",
"namespace": getattr(self.settings, 'TENANT_ID', 'unknown'),
"payload": payload,
"timestamp": datetime.utcnow().isoformat()
}
# Sign the notification
notification_data["signature"] = self._sign_message(notification_data)
# Send message
message = Message(
json.dumps(notification_data).encode(),
delivery_mode=DeliveryMode.PERSISTENT
)
await self.channel.default_exchange.publish(
message, routing_key=self.tenant_to_admin_queue.name
)
logger.info(f"Sent notification to admin: {notification_type}")
except Exception as e:
logger.error(f"Failed to send notification to admin: {e}")
async def disconnect(self) -> None:
"""Disconnect from message bus"""
try:
if self.connection:
await self.connection.close()
logger.info("Disconnected from message bus")
except Exception as e:
logger.error(f"Error disconnecting from message bus: {e}")
# Global instance
message_bus_client = MessageBusClient()
# Command handlers for different admin commands
async def handle_tenant_provision(command: AdminCommand) -> Dict[str, Any]:
"""Handle tenant provisioning command"""
logger.info(f"Processing tenant provision for: {command.target_namespace}")
# TODO: Implement tenant provisioning logic
# - Create tenant directory structure
# - Initialize SQLite database
# - Set up access controls
# - Configure resource quotas
return {
"status": "provisioned",
"namespace": command.target_namespace,
"resources_allocated": command.payload.get("resources", {})
}
async def handle_tenant_suspend(command: AdminCommand) -> Dict[str, Any]:
"""Handle tenant suspension command"""
logger.info(f"Processing tenant suspension for: {command.target_namespace}")
# TODO: Implement tenant suspension logic
# - Disable API access
# - Pause running processes
# - Preserve data integrity
return {
"status": "suspended",
"namespace": command.target_namespace
}
async def handle_tenant_activate(command: AdminCommand) -> Dict[str, Any]:
"""Handle tenant activation command"""
logger.info(f"Processing tenant activation for: {command.target_namespace}")
# TODO: Implement tenant activation logic
# - Restore API access
# - Resume processes
# - Validate system state
return {
"status": "activated",
"namespace": command.target_namespace
}
async def handle_resource_allocate(command: AdminCommand) -> Dict[str, Any]:
"""Handle resource allocation command"""
logger.info(f"Processing resource allocation for: {command.target_namespace}")
# TODO: Implement resource allocation logic
# - Update resource quotas
# - Configure rate limits
# - Enable new capabilities
return {
"status": "allocated",
"namespace": command.target_namespace,
"resources": command.payload.get("resources", {})
}
async def handle_resource_revoke(command: AdminCommand) -> Dict[str, Any]:
"""Handle resource revocation command"""
logger.info(f"Processing resource revocation for: {command.target_namespace}")
# TODO: Implement resource revocation logic
# - Remove resource access
# - Update quotas
# - Gracefully handle running operations
return {
"status": "revoked",
"namespace": command.target_namespace
}
# Register all command handlers
async def initialize_message_bus() -> bool:
"""Initialize message bus with all command handlers"""
try:
# Register command handlers
message_bus_client.register_handler("TENANT_PROVISION", handle_tenant_provision)
message_bus_client.register_handler("TENANT_SUSPEND", handle_tenant_suspend)
message_bus_client.register_handler("TENANT_ACTIVATE", handle_tenant_activate)
message_bus_client.register_handler("RESOURCE_ALLOCATE", handle_resource_allocate)
message_bus_client.register_handler("RESOURCE_REVOKE", handle_resource_revoke)
# Connect to message bus
connected = await message_bus_client.connect()
if connected:
logger.info("Message bus client initialized successfully")
return True
else:
logger.error("Failed to initialize message bus client")
return False
except Exception as e:
logger.error(f"Error initializing message bus: {e}")
return False
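# Hedged startup sketch (illustrative; calling this from a FastAPI lifespan or
# startup hook is an assumption, not defined in this module):
#
#   async def _example_startup():
#       if await initialize_message_bus():
#           await message_bus_client.send_notification("STARTUP", {"status": "ready"})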

View File

@@ -0,0 +1,347 @@
"""
Optics Cost Calculation Service
Calculates inference and storage costs for the Optics feature.
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
import httpx
import logging
from app.core.config import get_settings
logger = logging.getLogger(__name__)
# Storage cost rate
STORAGE_COST_PER_MB_CENTS = 4.0 # $0.04 per MB
# Fallback pricing for unknown models
DEFAULT_MODEL_PRICING = {
"cost_per_1k_input": 0.10,
"cost_per_1k_output": 0.10
}
class OpticsPricingCache:
"""Simple in-memory cache for model pricing"""
_pricing: Optional[Dict[str, Any]] = None
_expires_at: Optional[datetime] = None
_ttl_seconds: int = 300 # 5 minutes
@classmethod
def get(cls) -> Optional[Dict[str, Any]]:
if cls._pricing and cls._expires_at and datetime.utcnow() < cls._expires_at:
return cls._pricing
return None
@classmethod
def set(cls, pricing: Dict[str, Any]):
cls._pricing = pricing
cls._expires_at = datetime.utcnow() + timedelta(seconds=cls._ttl_seconds)
@classmethod
def clear(cls):
cls._pricing = None
cls._expires_at = None
async def fetch_optics_settings(tenant_domain: str) -> Dict[str, Any]:
"""
Fetch Optics settings from Control Panel for a tenant.
Returns:
dict with 'enabled', 'storage_cost_per_mb_cents'
"""
settings = get_settings()
control_panel_url = settings.control_panel_url or "http://gentwo-control-panel-backend:8001"
service_token = settings.service_auth_token or "internal-service-token"
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{control_panel_url}/internal/optics/tenant/{tenant_domain}/settings",
headers={
"X-Service-Auth": service_token,
"X-Service-Name": "tenant-backend"
},
timeout=10.0
)
if response.status_code == 200:
return response.json()
elif response.status_code == 404:
logger.warning(f"Tenant {tenant_domain} not found in Control Panel")
return {"enabled": False, "storage_cost_per_mb_cents": STORAGE_COST_PER_MB_CENTS}
else:
logger.error(f"Failed to fetch optics settings: {response.status_code}")
return {"enabled": False, "storage_cost_per_mb_cents": STORAGE_COST_PER_MB_CENTS}
except Exception as e:
logger.error(f"Error fetching optics settings: {str(e)}")
# Default to disabled on error
return {"enabled": False, "storage_cost_per_mb_cents": STORAGE_COST_PER_MB_CENTS}
async def fetch_model_pricing() -> Dict[str, Dict[str, float]]:
"""
Fetch model pricing from Control Panel.
Uses caching to avoid repeated requests.
Returns:
dict mapping model_id -> {cost_per_1k_input, cost_per_1k_output}
"""
# Check cache first
cached = OpticsPricingCache.get()
if cached:
return cached
settings = get_settings()
control_panel_url = settings.control_panel_url or "http://gentwo-control-panel-backend:8001"
service_token = settings.service_auth_token or "internal-service-token"
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{control_panel_url}/internal/optics/model-pricing",
headers={
"X-Service-Auth": service_token,
"X-Service-Name": "tenant-backend"
},
timeout=10.0
)
if response.status_code == 200:
data = response.json()
pricing = data.get("models", {})
OpticsPricingCache.set(pricing)
return pricing
else:
logger.error(f"Failed to fetch model pricing: {response.status_code}")
return {}
except Exception as e:
logger.error(f"Error fetching model pricing: {str(e)}")
return {}
def get_model_cost_per_1k(model_id: str, pricing_map: Dict[str, Dict[str, float]]) -> float:
"""
Get combined cost per 1k tokens for a model.
Args:
model_id: Model identifier (e.g., 'llama-3.3-70b-versatile')
pricing_map: Map of model_id -> pricing info
Returns:
Combined input + output cost per 1k tokens in dollars
"""
pricing = pricing_map.get(model_id)
if pricing:
return (pricing.get("cost_per_1k_input", 0.0) + pricing.get("cost_per_1k_output", 0.0))
# Try variations of the model ID
# Sometimes model_used might have provider prefix like "groq:model-name"
if ":" in model_id:
model_name = model_id.split(":", 1)[1]
pricing = pricing_map.get(model_name)
if pricing:
return (pricing.get("cost_per_1k_input", 0.0) + pricing.get("cost_per_1k_output", 0.0))
# Return default pricing
return DEFAULT_MODEL_PRICING["cost_per_1k_input"] + DEFAULT_MODEL_PRICING["cost_per_1k_output"]
def calculate_inference_cost_cents(tokens: int, cost_per_1k: float) -> float:
"""
Calculate inference cost in cents from token count.
Args:
tokens: Total token count
cost_per_1k: Cost per 1000 tokens in dollars
Returns:
Cost in cents
"""
return (tokens / 1000) * cost_per_1k * 100
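# Worked example (illustrative): 50,000 tokens at the default combined rate of
# $0.20 per 1k tokens costs (50_000 / 1000) * 0.20 * 100 = 1000 cents ($10.00).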
def calculate_storage_cost_cents(total_mb: float, cost_per_mb_cents: float = STORAGE_COST_PER_MB_CENTS) -> float:
"""
Calculate storage cost in cents.
Args:
total_mb: Total storage in megabytes
cost_per_mb_cents: Cost per MB in cents (default 4 cents = $0.04)
Returns:
Cost in cents
"""
return total_mb * cost_per_mb_cents
def format_cost_display(cents: float) -> str:
"""Format cost in cents to a display string like '$12.34'"""
dollars = cents / 100
return f"${dollars:,.2f}"
async def get_optics_cost_summary(
pg_client,
tenant_domain: str,
date_start: datetime,
date_end: datetime,
user_id: Optional[str] = None,
include_user_breakdown: bool = False
) -> Dict[str, Any]:
"""
Calculate full Optics cost summary for a tenant.
Args:
pg_client: PostgreSQL client
tenant_domain: Tenant domain
date_start: Start date for cost calculation
date_end: End date for cost calculation
user_id: Optional user ID filter
include_user_breakdown: Whether to include per-user breakdown
Returns:
Complete cost summary with breakdowns
"""
schema = f"tenant_{tenant_domain.replace('-', '_')}"
# Fetch model pricing
pricing_map = await fetch_model_pricing()
# Build user filter
user_filter = ""
params = [date_start, date_end]
param_idx = 3
if user_id:
user_filter = f"AND c.user_id = ${param_idx}::uuid"
params.append(user_id)
param_idx += 1
# Query token usage by model
token_query = f"""
SELECT
COALESCE(m.model_used, 'unknown') as model_id,
COALESCE(SUM(m.token_count), 0) as total_tokens,
COUNT(DISTINCT c.id) as conversations,
COUNT(m.id) as messages
FROM {schema}.messages m
JOIN {schema}.conversations c ON m.conversation_id = c.id
WHERE c.created_at >= $1 AND c.created_at <= $2
AND m.model_used IS NOT NULL AND m.model_used != ''
{user_filter}
GROUP BY m.model_used
ORDER BY total_tokens DESC
"""
token_results = await pg_client.execute_query(token_query, *params)
# Calculate inference costs by model
by_model = []
total_inference_cents = 0.0
total_tokens = 0
for row in token_results or []:
model_id = row["model_id"]
tokens = int(row["total_tokens"])
total_tokens += tokens
cost_per_1k = get_model_cost_per_1k(model_id, pricing_map)
cost_cents = calculate_inference_cost_cents(tokens, cost_per_1k)
total_inference_cents += cost_cents
# Clean up model name for display
model_name = model_id.split(":")[-1] if ":" in model_id else model_id
by_model.append({
"model_id": model_id,
"model_name": model_name,
"tokens": tokens,
"conversations": row["conversations"],
"messages": row["messages"],
"cost_cents": round(cost_cents, 2),
"cost_display": format_cost_display(cost_cents)
})
# Calculate percentages
for item in by_model:
item["percentage"] = round((item["cost_cents"] / total_inference_cents * 100) if total_inference_cents > 0 else 0, 1)
# Query storage
storage_params = []
storage_user_filter = ""
if user_id:
storage_user_filter = f"WHERE d.user_id = $1::uuid"
storage_params.append(user_id)
storage_query = f"""
SELECT
COALESCE(SUM(d.file_size_bytes), 0) / 1048576.0 as total_mb,
COUNT(d.id) as document_count,
COUNT(DISTINCT d.dataset_id) as dataset_count
FROM {schema}.documents d
{storage_user_filter}
"""
storage_result = await pg_client.execute_query(storage_query, *storage_params) if storage_params else await pg_client.execute_query(storage_query)
storage_data = storage_result[0] if storage_result else {"total_mb": 0, "document_count": 0, "dataset_count": 0}
total_storage_mb = float(storage_data.get("total_mb", 0))
storage_cost_cents = calculate_storage_cost_cents(total_storage_mb)
# Total cost
total_cost_cents = total_inference_cents + storage_cost_cents
# User breakdown (admin only)
by_user = []
if include_user_breakdown:
user_query = f"""
SELECT
c.user_id,
u.email,
COALESCE(SUM(m.token_count), 0) as tokens
FROM {schema}.messages m
JOIN {schema}.conversations c ON m.conversation_id = c.id
JOIN {schema}.users u ON c.user_id = u.id
WHERE c.created_at >= $1 AND c.created_at <= $2
GROUP BY c.user_id, u.email
ORDER BY tokens DESC
"""
user_results = await pg_client.execute_query(user_query, date_start, date_end)
for row in user_results or []:
user_tokens = int(row["tokens"])
# Use average model cost (in cents per 1k tokens) for the user breakdown;
# fall back to 20 cents/1k, matching the default $0.20 combined rate
avg_cost_per_1k_cents = (total_inference_cents / total_tokens * 1000) if total_tokens > 0 else 20.0
user_cost_cents = (user_tokens / 1000) * avg_cost_per_1k_cents
by_user.append({
"user_id": str(row["user_id"]),
"email": row["email"],
"tokens": user_tokens,
"cost_cents": round(user_cost_cents, 2),
"cost_display": format_cost_display(user_cost_cents),
"percentage": round((user_tokens / total_tokens * 100) if total_tokens > 0 else 0, 1)
})
return {
"inference_cost_cents": round(total_inference_cents, 2),
"storage_cost_cents": round(storage_cost_cents, 2),
"total_cost_cents": round(total_cost_cents, 2),
"inference_cost_display": format_cost_display(total_inference_cents),
"storage_cost_display": format_cost_display(storage_cost_cents),
"total_cost_display": format_cost_display(total_cost_cents),
"total_tokens": total_tokens,
"total_storage_mb": round(total_storage_mb, 2),
"document_count": storage_data.get("document_count", 0),
"dataset_count": storage_data.get("dataset_count", 0),
"by_model": by_model,
"by_user": by_user if include_user_breakdown else None,
"period_start": date_start.isoformat(),
"period_end": date_end.isoformat()
}
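# Hedged usage sketch (illustrative; the pg_client fixture and tenant domain
# below are assumptions):
#
#   summary = await get_optics_cost_summary(
#       pg_client,
#       tenant_domain="acme-corp",
#       date_start=datetime.utcnow() - timedelta(days=30),
#       date_end=datetime.utcnow(),
#       include_user_breakdown=True,
#   )
#   print(summary["total_cost_display"])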

View File

@@ -0,0 +1,806 @@
"""
PGVector Hybrid Search Service for GT 2.0 Tenant Backend
Provides unified vector similarity search + full-text search using PostgreSQL
with PGVector extension. Replaces ChromaDB for better performance and consistency.
Features:
- Vector similarity search using PGVector
- Full-text search using PostgreSQL built-in features
- Hybrid scoring combining both approaches
- Perfect tenant isolation using RLS
- Zero-downtime MVCC operations
"""
import logging
import asyncio
import json
import uuid as uuid_lib
from typing import Dict, Any, List, Optional, Tuple, Union
from dataclasses import dataclass
from datetime import datetime
import uuid
import asyncpg
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text, select, and_, or_
from app.core.postgresql_client import get_postgresql_client
from app.core.config import get_settings
from app.services.embedding_client import BGE_M3_EmbeddingClient
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class HybridSearchResult:
"""Result from hybrid vector + text search"""
chunk_id: str
document_id: str
dataset_id: Optional[str]
text: str
metadata: Dict[str, Any]
vector_similarity: float
text_relevance: float
hybrid_score: float
rank: int
@dataclass
class SearchConfig:
"""Configuration for hybrid search behavior"""
vector_weight: float = 0.7
text_weight: float = 0.3
min_vector_similarity: float = 0.3
min_text_relevance: float = 0.01
max_results: int = 100
rerank_results: bool = True
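# Illustrative scoring note (assuming the weighted-sum combination the fields
# above imply): a chunk with vector similarity 0.82 and text relevance 0.40
# would score 0.7 * 0.82 + 0.3 * 0.40 = 0.694, and chunks below
# min_vector_similarity (0.3) are dropped before ranking.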
class PGVectorSearchService:
"""
Hybrid search service using PostgreSQL + PGVector.
GT 2.0 Principles:
- Perfect tenant isolation via RLS policies
- Zero downtime MVCC operations
- Real implementation (no mocks)
- Operational elegance through unified storage
"""
def __init__(self, tenant_id: str, user_id: Optional[str] = None):
self.tenant_id = tenant_id
self.user_id = user_id
self.settings = get_settings()
self.embedding_client = BGE_M3_EmbeddingClient()
# Schema naming for tenant isolation
self.schema_name = self.settings.postgres_schema
logger.info(f"PGVector search service initialized for tenant {tenant_id}")
async def hybrid_search(
self,
query: str,
user_id: str,
dataset_ids: Optional[List[str]] = None,
config: Optional[SearchConfig] = None,
limit: int = 10
) -> List[HybridSearchResult]:
"""
Perform hybrid vector + text search across user's documents.
Args:
query: Search query text
user_id: User performing search (for RLS)
dataset_ids: Optional list of dataset IDs to search
config: Search configuration parameters
limit: Maximum results to return
Returns:
List of ranked search results
"""
if config is None:
config = SearchConfig()
try:
logger.info(f"🔍 HYBRID_SEARCH START: query='{query}', user_id='{user_id}', dataset_ids={dataset_ids}")
logger.info(f"🔍 HYBRID_SEARCH CONFIG: vector_weight={config.vector_weight}, text_weight={config.text_weight}, min_similarity={config.min_vector_similarity}")
# Generate query embedding via resource cluster
logger.info(f"🔍 HYBRID_SEARCH: Generating embedding for query '{query}' with user_id '{user_id}'")
query_embedding = await self._generate_query_embedding(query, user_id)
logger.info(f"🔍 HYBRID_SEARCH: Generated embedding with {len(query_embedding)} dimensions")
# Execute hybrid search query
logger.info(f"🔍 HYBRID_SEARCH: Executing hybrid query with user_id='{user_id}', dataset_ids={dataset_ids}")
results = await self._execute_hybrid_query(
query=query,
query_embedding=query_embedding,
user_id=user_id,
dataset_ids=dataset_ids,
config=config,
limit=limit
)
logger.info(f"🔍 HYBRID_SEARCH: Query returned {len(results)} raw results")
# Apply re-ranking if enabled
if config.rerank_results and len(results) > 1:
logger.info(f"🔍 HYBRID_SEARCH: Applying re-ranking to {len(results)} results")
results = await self._rerank_results(results, query, config)
logger.info(f"🔍 HYBRID_SEARCH: Re-ranking complete, final result count: {len(results)}")
logger.info(f"🔍 HYBRID_SEARCH COMPLETE: Returned {len(results)} results for user {user_id}")
return results
except Exception as e:
logger.error(f"🔍 HYBRID_SEARCH ERROR: {e}")
logger.exception("Full hybrid search error traceback:")
raise
async def vector_similarity_search(
self,
query_embedding: List[float],
user_id: str,
dataset_ids: Optional[List[str]] = None,
similarity_threshold: float = 0.3,
limit: int = 10
) -> List[HybridSearchResult]:
"""
Pure vector similarity search using PGVector.
Args:
query_embedding: Pre-computed query embedding
user_id: User performing search
dataset_ids: Optional dataset filter
similarity_threshold: Minimum cosine similarity
limit: Maximum results
Returns:
Vector similarity results
"""
try:
logger.info(f"🔍 VECTOR_SEARCH START: user_id='{user_id}', dataset_ids={dataset_ids}, threshold={similarity_threshold}")
client = await get_postgresql_client()
async with client.get_connection() as conn:
logger.info(f"🔍 VECTOR_SEARCH: Got DB connection, resolving user UUID from '{user_id}'")
# Resolve user UUID first
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
logger.info(f"🔍 VECTOR_SEARCH: Resolved user_id '{user_id}' to UUID '{resolved_user_id}'")
# RLS context removed - using schema-level isolation instead
logger.info(f"🔍 VECTOR_SEARCH: Using resolved UUID '{resolved_user_id}' for query parameters")
# Build query with dataset filtering
dataset_filter = ""
# Encode the embedding in pgvector text format so the $1::vector cast works with asyncpg
# (matches the conversion used in _execute_hybrid_query)
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
params = [embedding_str, similarity_threshold, limit]
if dataset_ids:
logger.info(f"🔍 VECTOR_SEARCH: Adding dataset filter for datasets: {dataset_ids}")
dataset_start_idx = 4 # Start after query_embedding, similarity_threshold, limit
placeholders = ",".join(f"${dataset_start_idx + i}" for i in range(len(dataset_ids)))
dataset_filter = f"AND dataset_id = ANY(ARRAY[{placeholders}]::uuid[])"
params.extend(dataset_ids)
logger.info(f"🔍 VECTOR_SEARCH: Dataset filter SQL: {dataset_filter}")
else:
logger.error(f"🔍 VECTOR_SEARCH: SECURITY ERROR - Dataset IDs are required for search operations")
raise ValueError("Dataset IDs are required for vector search operations. This could mean the agent has no datasets configured or dataset access control failed.")
query_sql = f"""
SELECT
id as chunk_id,
document_id,
dataset_id,
content as text,
metadata as chunk_metadata,
1 - (embedding <=> $1::vector) as similarity,
0.0 as text_relevance,
1 - (embedding <=> $1::vector) as hybrid_score,
ROW_NUMBER() OVER (ORDER BY embedding <=> $1::vector) as rank
FROM {self.schema_name}.document_chunks
WHERE 1 - (embedding <=> $1::vector) >= $2
{dataset_filter}
ORDER BY embedding <=> $1::vector
LIMIT $3
"""
logger.info(f"🔍 VECTOR_SEARCH: Executing SQL query with {len(params)} parameters")
logger.info(f"🔍 VECTOR_SEARCH: SQL: {query_sql}")
logger.info(f"🔍 VECTOR_SEARCH: Params types: embedding={type(query_embedding)} (len={len(query_embedding)}), threshold={type(similarity_threshold)}, limit={type(limit)}")
if dataset_ids:
logger.info(f"🔍 VECTOR_SEARCH: Dataset params: {[type(d) for d in dataset_ids]}")
rows = await conn.fetch(query_sql, *params)
logger.info(f"🔍 VECTOR_SEARCH: Query executed successfully, got {len(rows)} rows")
results = [
HybridSearchResult(
chunk_id=row['chunk_id'],
document_id=row['document_id'],
dataset_id=row['dataset_id'],
text=row['text'],
metadata=row['chunk_metadata'] if row['chunk_metadata'] else {},
vector_similarity=float(row['similarity']),
text_relevance=0.0,
hybrid_score=float(row['hybrid_score']),
rank=row['rank']
)
for row in rows
]
logger.info(f"Vector search returned {len(results)} results")
return results
except Exception as e:
logger.error(f"Vector similarity search failed: {e}")
raise
async def full_text_search(
self,
query: str,
user_id: str,
dataset_ids: Optional[List[str]] = None,
language: str = 'english',
limit: int = 10
) -> List[HybridSearchResult]:
"""
Full-text search using PostgreSQL's built-in features.
Args:
query: Text query
user_id: User performing search
dataset_ids: Optional dataset filter
language: Text search language configuration
limit: Maximum results
Returns:
Text relevance results
"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
# Resolve user UUID first
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
# RLS context removed - using schema-level isolation instead
# Build dataset filter - REQUIRE dataset_ids for security
dataset_filter = ""
params = [query, limit, resolved_user_id]
if dataset_ids:
placeholders = ",".join(f"${i+4}" for i in range(len(dataset_ids)))
dataset_filter = f"AND dataset_id = ANY(ARRAY[{placeholders}]::uuid[])"
params.extend(dataset_ids)
else:
logger.error(f"🔍 FULL_TEXT_SEARCH: SECURITY ERROR - Dataset IDs are required for search operations")
raise ValueError("Dataset IDs are required for full-text search operations. This could mean the agent has no datasets configured or dataset access control failed.")
query_sql = f"""
SELECT
id as chunk_id,
document_id,
dataset_id,
content as text,
metadata,
0.0 as similarity,
ts_rank_cd(
to_tsvector('{language}', content),
plainto_tsquery('{language}', $1)
) as text_relevance,
ts_rank_cd(
to_tsvector('{language}', content),
plainto_tsquery('{language}', $1)
) as hybrid_score,
ROW_NUMBER() OVER (
ORDER BY ts_rank_cd(
to_tsvector('{language}', content),
plainto_tsquery('{language}', $1)
) DESC
) as rank
FROM {self.schema_name}.document_chunks
WHERE user_id = $3::uuid
AND to_tsvector('{language}', content) @@ plainto_tsquery('{language}', $1)
{dataset_filter}
ORDER BY text_relevance DESC
LIMIT $2
"""
rows = await conn.fetch(query_sql, *params)
results = [
HybridSearchResult(
chunk_id=row['chunk_id'],
document_id=row['document_id'],
dataset_id=row['dataset_id'],
text=row['text'],
metadata=row['metadata'] if row['metadata'] else {},
vector_similarity=0.0,
text_relevance=float(row['text_relevance']),
hybrid_score=float(row['hybrid_score']),
rank=row['rank']
)
for row in rows
]
logger.info(f"Full-text search returned {len(results)} results")
return results
except Exception as e:
logger.error(f"Full-text search failed: {e}")
raise
async def get_document_chunks(
self,
document_id: str,
user_id: str,
include_embeddings: bool = False
) -> List[Dict[str, Any]]:
"""
Get all chunks for a specific document.
Args:
document_id: Target document ID
user_id: User making request
include_embeddings: Whether to include embedding vectors
Returns:
List of document chunks with metadata
"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
# Resolve user UUID first
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
# RLS context removed - using schema-level isolation instead
select_fields = [
"id as chunk_id", "document_id", "dataset_id", "chunk_index",
"content", "metadata as chunk_metadata", "created_at"
]
if include_embeddings:
select_fields.append("embedding")
query_sql = f"""
SELECT {', '.join(select_fields)}
FROM {self.schema_name}.document_chunks
WHERE document_id = $1
AND user_id = $2::uuid
ORDER BY chunk_index
"""
rows = await conn.fetch(query_sql, document_id, resolved_user_id)
chunks = []
for row in rows:
chunk = {
'chunk_id': row['chunk_id'],
'document_id': row['document_id'],
'dataset_id': row['dataset_id'],
'chunk_index': row['chunk_index'],
'content': row['content'],
'metadata': row['chunk_metadata'] if row['chunk_metadata'] else {},
'created_at': row['created_at'].isoformat() if row['created_at'] else None
}
if include_embeddings:
chunk['embedding'] = list(row['embedding']) if row['embedding'] else []
chunks.append(chunk)
logger.info(f"Retrieved {len(chunks)} chunks for document {document_id}")
return chunks
except Exception as e:
logger.error(f"Failed to get document chunks: {e}")
raise
async def search_similar_chunks(
self,
chunk_id: str,
user_id: str,
similarity_threshold: float = 0.5,
limit: int = 5,
exclude_same_document: bool = True
) -> List[HybridSearchResult]:
"""
Find chunks similar to a given chunk.
Args:
chunk_id: Reference chunk ID
user_id: User making request
similarity_threshold: Minimum similarity threshold
limit: Maximum results
exclude_same_document: Whether to exclude chunks from same document
Returns:
Similar chunks ranked by similarity
"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
# Resolve user UUID first
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
# RLS context removed - using schema-level isolation instead
# Get reference chunk embedding
ref_query = f"""
SELECT embedding, document_id
FROM {self.schema_name}.document_chunks
WHERE id = $1::uuid
AND user_id = $2::uuid
"""
ref_result = await conn.fetchrow(ref_query, chunk_id, resolved_user_id)
if not ref_result:
raise ValueError(f"Reference chunk {chunk_id} not found")
ref_embedding = ref_result['embedding']
ref_document_id = ref_result['document_id']
# Build exclusion filter
exclusion_filter = ""
params = [ref_embedding, similarity_threshold, limit, chunk_id, resolved_user_id]
if exclude_same_document:
exclusion_filter = "AND document_id != $6"
params.append(ref_document_id)
# Search for similar chunks
similarity_query = f"""
SELECT
id as chunk_id,
document_id,
dataset_id,
content as text,
metadata as chunk_metadata,
1 - (embedding <=> $1::vector) as similarity
FROM {self.schema_name}.document_chunks
WHERE user_id = $5::uuid
AND id != $4::uuid
AND 1 - (embedding <=> $1::vector) >= $2
{exclusion_filter}
ORDER BY embedding <=> $1::vector
LIMIT $3
"""
rows = await conn.fetch(similarity_query, *params)
results = [
HybridSearchResult(
chunk_id=row['chunk_id'],
document_id=row['document_id'],
dataset_id=row['dataset_id'],
text=row['text'],
metadata=row['chunk_metadata'] if row['chunk_metadata'] else {},
vector_similarity=float(row['similarity']),
text_relevance=0.0,
hybrid_score=float(row['similarity']),
rank=i+1
)
for i, row in enumerate(rows)
]
logger.info(f"Found {len(results)} similar chunks to {chunk_id}")
return results
except Exception as e:
logger.error(f"Similar chunk search failed: {e}")
raise
# Private helper methods
async def get_dataset_ids_from_documents(
self,
document_ids: List[str],
user_id: str
) -> List[str]:
"""Get unique dataset IDs from a list of document IDs"""
try:
client = await get_postgresql_client()
dataset_ids = []
async with client.get_connection() as conn:
# Resolve user UUID first (email or UUID accepted)
resolved_user_id = await self._resolve_user_uuid(conn, user_id)
# RLS context removed - using schema-level isolation instead
# Query to get dataset IDs from document IDs
placeholders = ",".join(f"${i+1}" for i in range(len(document_ids)))
query = f"""
SELECT DISTINCT dataset_id
FROM {self.schema_name}.documents
WHERE id = ANY(ARRAY[{placeholders}]::uuid[])
AND user_id = ${len(document_ids)+1}::uuid
"""
params = document_ids + [resolved_user_id]
rows = await conn.fetch(query, *params)
dataset_ids = [str(row['dataset_id']) for row in rows if row['dataset_id']]
logger.info(f"🔍 Resolved {len(dataset_ids)} dataset IDs from {len(document_ids)} documents: {dataset_ids}")
return dataset_ids
except Exception as e:
logger.error(f"Failed to resolve dataset IDs from documents: {e}")
return []
async def _generate_query_embedding(
self,
query: str,
user_id: str
) -> List[float]:
"""Generate embedding for search query using simple BGE-M3 client"""
try:
# Use direct BGE-M3 embedding client with tenant/user for billing
embeddings = await self.embedding_client.generate_embeddings(
[query],
tenant_id=self.tenant_id, # Pass tenant for billing
user_id=user_id # Pass user for billing
)
if not embeddings or not embeddings[0]:
raise ValueError("Failed to generate query embedding")
return embeddings[0]
except Exception as e:
logger.error(f"Query embedding generation failed: {e}")
raise
async def _execute_hybrid_query(
self,
query: str,
query_embedding: List[float],
user_id: str,
dataset_ids: Optional[List[str]],
config: SearchConfig,
limit: int
) -> List[HybridSearchResult]:
"""Execute the hybrid search combining vector + text results"""
try:
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY START: query='{query}', user_id='{user_id}', dataset_ids={dataset_ids}")
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY CONFIG: vector_weight={config.vector_weight}, text_weight={config.text_weight}, limit={limit}")
client = await get_postgresql_client()
async with client.get_connection() as conn:
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Got DB connection, resolving user UUID")
# Resolve user UUID first
actual_user_id = await self._resolve_user_uuid(conn, user_id)
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Resolved user_id to '{actual_user_id}'")
# RLS context removed - using schema-level isolation instead
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Using resolved UUID '{actual_user_id}' for query parameters")
# Build dataset filter
dataset_filter = ""
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Building parameters and dataset filter")
# Convert embedding list to string format for PostgreSQL vector type
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Converted embedding to PostgreSQL vector string (length: {len(embedding_str)})")
# Ensure UUID is properly formatted as string for PostgreSQL
if isinstance(actual_user_id, str):
try:
# Validate it's a proper UUID and convert back to string
validated_uuid = str(uuid_lib.UUID(actual_user_id))
actual_user_id_str = validated_uuid
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Validated UUID format: '{actual_user_id_str}'")
except ValueError:
# If it's not a valid UUID string, keep as is
actual_user_id_str = actual_user_id
logger.warning(f"🔍 _EXECUTE_HYBRID_QUERY: UUID validation failed, using as-is: '{actual_user_id_str}'")
else:
actual_user_id_str = str(actual_user_id)
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Converted user_id to string: '{actual_user_id_str}'")
params = [embedding_str, query, config.min_vector_similarity, config.min_text_relevance, config.max_results]
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Base parameters prepared (count: {len(params)})")
# Handle dataset filtering - REQUIRE dataset_ids for security
if dataset_ids:
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Processing dataset filter for: {dataset_ids}")
# Ensure dataset_ids is a list
if isinstance(dataset_ids, str):
dataset_ids = [dataset_ids]
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Converted string to list: {dataset_ids}")
if len(dataset_ids) > 0:
# Generate proper placeholders for dataset IDs
placeholders = ",".join(f"${i+6}" for i in range(len(dataset_ids)))
dataset_filter = f"AND dataset_id = ANY(ARRAY[{placeholders}]::uuid[])"
params.extend(dataset_ids)
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Dataset filter: {dataset_filter}, dataset_ids: {dataset_ids}")
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Total parameters after dataset filter: {len(params)}")
else:
logger.error(f"🔍 _EXECUTE_HYBRID_QUERY: SECURITY ERROR - Empty dataset_ids list not permitted")
raise ValueError("Dataset IDs cannot be empty. This could mean the agent has no datasets configured or dataset access control failed.")
else:
# SECURITY FIX: No dataset filter when None is NOT ALLOWED
logger.error(f"🔍 _EXECUTE_HYBRID_QUERY: SECURITY ERROR - Dataset IDs are required for search operations")
# More informative error message for debugging
error_msg = "Dataset IDs are required for hybrid search operations. This could mean: " \
"1) Agent has no datasets configured, 2) No datasets selected in UI, or " \
"3) Dataset access control failed. Check agent configuration and dataset permissions."
raise ValueError(error_msg)
# Hybrid search query combining vector similarity and text relevance
hybrid_query = f"""
WITH vector_search AS (
SELECT
id as chunk_id,
document_id,
dataset_id,
content,
metadata as chunk_metadata,
1 - (embedding <=> $1::vector) as vector_similarity,
0.0 as text_relevance
FROM {self.schema_name}.document_chunks
WHERE 1 - (embedding <=> $1::vector) >= $3
{dataset_filter}
),
text_search AS (
SELECT
id as chunk_id,
document_id,
dataset_id,
content,
metadata as chunk_metadata,
0.0 as vector_similarity,
ts_rank_cd(
to_tsvector('english', content),
plainto_tsquery('english', $2)
) as text_relevance
FROM {self.schema_name}.document_chunks
WHERE to_tsvector('english', content) @@ plainto_tsquery('english', $2)
AND ts_rank_cd(
to_tsvector('english', content),
plainto_tsquery('english', $2)
) >= $4
{dataset_filter}
),
combined_results AS (
SELECT
u.chunk_id,
dc.document_id,
dc.dataset_id,
dc.content,
dc.metadata as chunk_metadata,
COALESCE(v.vector_similarity, 0.0) as vector_similarity,
COALESCE(t.text_relevance, 0.0) as text_relevance,
(COALESCE(v.vector_similarity, 0.0) * {config.vector_weight} +
COALESCE(t.text_relevance, 0.0) * {config.text_weight}) as hybrid_score
FROM (
SELECT chunk_id FROM vector_search
UNION
SELECT chunk_id FROM text_search
) u
LEFT JOIN vector_search v USING (chunk_id)
LEFT JOIN text_search t USING (chunk_id)
LEFT JOIN {self.schema_name}.document_chunks dc ON (dc.id = u.chunk_id)
)
SELECT
chunk_id,
document_id,
dataset_id,
content as text,
chunk_metadata as metadata,
vector_similarity,
text_relevance,
hybrid_score,
ROW_NUMBER() OVER (ORDER BY hybrid_score DESC) as rank
FROM combined_results
WHERE hybrid_score > 0.0
ORDER BY hybrid_score DESC
LIMIT $5
"""
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Executing hybrid SQL with {len(params)} parameters")
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Parameter types: {[type(p) for p in params]}")
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Query preview: {hybrid_query[:500]}...")
rows = await conn.fetch(hybrid_query, *params)
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: SQL execution successful, got {len(rows)} rows")
results = []
for i, row in enumerate(rows):
result = HybridSearchResult(
chunk_id=row['chunk_id'],
document_id=row['document_id'],
dataset_id=row['dataset_id'],
text=row['text'],
metadata=row['metadata'] if row['metadata'] else {},
vector_similarity=float(row['vector_similarity']),
text_relevance=float(row['text_relevance']),
hybrid_score=float(row['hybrid_score']),
rank=row['rank']
)
results.append(result)
if i < 3: # Log first few results for debugging
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY: Result {i+1}: chunk_id='{result.chunk_id}', score={result.hybrid_score:.3f}")
logger.info(f"🔍 _EXECUTE_HYBRID_QUERY COMPLETE: Processed {len(results)} results")
return results
except Exception as e:
logger.error(f"🔍 _EXECUTE_HYBRID_QUERY ERROR: {e}")
logger.exception("Full hybrid query execution error traceback:")
raise
async def _rerank_results(
self,
results: List[HybridSearchResult],
query: str,
config: SearchConfig
) -> List[HybridSearchResult]:
"""
Apply advanced re-ranking to search results.
This can include:
- Query-document interaction features
- Diversity scoring
- Recency weighting
- User preference learning
"""
try:
# For now, apply simple diversity re-ranking
# to avoid showing too many results from the same document
reranked = []
document_counts = {}
max_per_document = max(1, len(results) // 3) # At most 1/3 from same document
for result in results:
doc_count = document_counts.get(result.document_id, 0)
if doc_count < max_per_document:
reranked.append(result)
document_counts[result.document_id] = doc_count + 1
# Re-rank the remaining items
remaining = [r for r in results if r not in reranked]
reranked.extend(remaining)
# Update rank numbers
for i, result in enumerate(reranked):
result.rank = i + 1
return reranked
except Exception as e:
logger.warning(f"Re-ranking failed, returning original results: {e}")
return results
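# Worked example for the diversity pass in _rerank_results above (illustrative):
# with 9 results, max_per_document = 9 // 3 = 3, so the first pass keeps at most
# 3 chunks per document in score order; any overflow chunks are appended after
# the diverse head and ranks are renumbered 1..9.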
async def _resolve_user_uuid(self, conn: asyncpg.Connection, user_id: str) -> str:
"""
Resolve user email to UUID if needed.
Returns a validated UUID string.
"""
logger.info(f"🔍 _RESOLVE_USER_UUID START: input user_id='{user_id}' (type: {type(user_id)})")
if "@" in user_id: # If user_id is an email, look up the UUID
logger.info(f"🔍 _RESOLVE_USER_UUID: Detected email format, looking up UUID for '{user_id}'")
user_lookup_sql = f"SELECT id FROM {self.schema_name}.users WHERE email = $1"
logger.info(f"🔍 _RESOLVE_USER_UUID: Executing SQL: {user_lookup_sql}")
user_row = await conn.fetchrow(user_lookup_sql, user_id)
if user_row:
resolved_uuid = str(user_row['id'])
logger.info(f"🔍 _RESOLVE_USER_UUID: Found UUID '{resolved_uuid}' for email '{user_id}'")
return resolved_uuid
else:
logger.error(f"🔍 _RESOLVE_USER_UUID ERROR: User not found for email: {user_id}")
raise ValueError(f"User not found: {user_id}")
else:
# Already a UUID
logger.info(f"🔍 _RESOLVE_USER_UUID: Input '{user_id}' is already UUID format, returning as-is")
return user_id
# _set_rls_context method removed - using schema-level isolation instead of RLS
# Factory function for dependency injection
def get_pgvector_search_service(tenant_id: str, user_id: Optional[str] = None) -> PGVectorSearchService:
"""Get PGVector search service instance"""
return PGVectorSearchService(tenant_id=tenant_id, user_id=user_id)
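# Minimal usage sketch (illustrative only): assumes an async caller, a tenant_id and a user
# identifier known to the tenant schema, and at least one dataset UUID, since the search
# methods reject empty dataset_ids by design. The query text and result handling below are
# placeholders, not part of this module's contract.
async def _example_hybrid_search(tenant_id: str, user_id: str, dataset_id: str) -> None:
service = get_pgvector_search_service(tenant_id=tenant_id, user_id=user_id)
results = await service.hybrid_search(
query="quarterly revenue summary",
user_id=user_id,
dataset_ids=[dataset_id],
config=SearchConfig(max_results=20, rerank_results=True),
limit=10,
)
for result in results:
# Each HybridSearchResult carries the blended score and provenance identifiers
logger.debug(f"rank={result.rank} score={result.hybrid_score:.3f} doc={result.document_id}")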

View File

@@ -0,0 +1,883 @@
"""
GT 2.0 PostgreSQL File Storage Service
Replaces MinIO with PostgreSQL-based file storage using:
- BYTEA for small files (<10MB)
- PostgreSQL Large Objects (LOBs) for large files (10MB-1GB)
- Filesystem with metadata for massive files (>1GB)
Provides perfect tenant isolation through PostgreSQL schemas.
"""
import asyncio
import json
import logging
import os
import hashlib
import mimetypes
from typing import Dict, Any, List, Optional, BinaryIO, AsyncIterator, Tuple
from datetime import datetime, timedelta
from pathlib import Path
import aiofiles
from fastapi import UploadFile
from app.core.postgresql_client import get_postgresql_client
from app.core.config import get_settings
from app.core.permissions import ADMIN_ROLES
from app.core.path_security import sanitize_tenant_domain, sanitize_filename, safe_join_path
logger = logging.getLogger(__name__)
class PostgreSQLFileService:
"""PostgreSQL-based file storage service with tenant isolation"""
# Storage type thresholds
SMALL_FILE_THRESHOLD = 10 * 1024 * 1024 # 10MB - use BYTEA
LARGE_FILE_THRESHOLD = 1024 * 1024 * 1024 # 1GB - use LOBs
def __init__(self, tenant_domain: str, user_id: str, user_role: str = "user"):
self.tenant_domain = tenant_domain
self.user_id = user_id
self.user_role = user_role
self.settings = get_settings()
# Filesystem path for massive files (>1GB)
# Sanitize tenant_domain to prevent path traversal
safe_tenant = sanitize_tenant_domain(tenant_domain)
self.filesystem_base = Path("/data") / safe_tenant / "files" # codeql[py/path-injection] sanitize_tenant_domain() validates input
self.filesystem_base.mkdir(parents=True, exist_ok=True, mode=0o700)
logger.info(f"PostgreSQL file service initialized for {tenant_domain}/{user_id} (role: {user_role})")
async def store_file(
self,
file: UploadFile,
dataset_id: Optional[str] = None,
category: str = "documents"
) -> Dict[str, Any]:
"""Store file using appropriate PostgreSQL strategy"""
try:
logger.info(f"PostgreSQL file service: storing file {file.filename} for tenant {self.tenant_domain}, user {self.user_id}")
logger.info(f"Dataset ID: {dataset_id}, Category: {category}")
# Read file content
content = await file.read()
file_size = len(content)
# Generate file metadata
file_hash = hashlib.sha256(content).hexdigest()[:16]
content_type = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
# Handle different file types with appropriate processing
if file_size <= self.SMALL_FILE_THRESHOLD and content_type.startswith('text/'):
# Small text files stored directly
storage_type = "text"
storage_ref = "content_text"
try:
text_content = content.decode('utf-8')
except UnicodeDecodeError:
text_content = content.decode('latin-1') # Fallback encoding
elif content_type == 'application/pdf':
# PDF files: extract text content, store binary separately
storage_type = "pdf_extracted"
storage_ref = "content_text"
text_content = await self._extract_pdf_text(content)
else:
# Other binary files: store as base64 for now
import base64
storage_type = "base64"
storage_ref = "content_text"
text_content = base64.b64encode(content).decode('utf-8')
# Get PostgreSQL client
logger.info("Getting PostgreSQL client")
pg_client = await get_postgresql_client()
# Always expect user_id to be a UUID string - no email lookups
logger.info(f"Using user UUID: {self.user_id}")
# Validate user_id is a valid UUID format
try:
import uuid
user_uuid = str(uuid.UUID(self.user_id))
except (ValueError, TypeError) as e:
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
logger.info(f"Validated user UUID: {user_uuid}")
# 1. Validate user_uuid is present
if not user_uuid:
raise ValueError("User UUID is required but not found")
# 2. Validate and clean dataset_id
dataset_uuid_param = None
if dataset_id and dataset_id.strip() and dataset_id != "":
try:
import uuid
dataset_uuid_param = str(uuid.UUID(dataset_id.strip()))
logger.info(f"Dataset UUID validated: {dataset_uuid_param}")
except ValueError as e:
logger.error(f"Invalid dataset UUID: {dataset_id}, error: {e}")
raise ValueError(f"Invalid dataset ID format: {dataset_id}")
else:
logger.info("No dataset_id provided, using NULL")
# 3. Validate file content and metadata
if not file.filename or not file.filename.strip():
raise ValueError("Filename cannot be empty")
if not content:
raise ValueError("File content cannot be empty")
# 4. Generate and validate all string parameters
safe_filename = f"{file_hash}_{file.filename}"
safe_original_filename = file.filename.strip()
safe_content_type = content_type or "application/octet-stream"
safe_file_hash = file_hash
safe_metadata = json.dumps({
"storage_type": storage_type,
"storage_ref": storage_ref,
"category": category
})
logger.info(f"All parameters validated:")
logger.info(f" user_uuid: {user_uuid}")
logger.info(f" dataset_uuid: {dataset_uuid_param}")
logger.info(f" filename: {safe_filename}")
logger.info(f" original_filename: {safe_original_filename}")
logger.info(f" file_type: {safe_content_type}")
logger.info(f" file_size: {file_size}")
logger.info(f" file_hash: {safe_file_hash}")
# Store metadata in documents table (using existing schema)
try:
# Application user now has BYPASSRLS privilege - no RLS context needed
logger.info("Storing document with BYPASSRLS privilege")
# Require dataset_id for all document uploads
if not dataset_uuid_param:
raise ValueError("dataset_id is required for document uploads")
logger.info(f"Storing document with dataset_id: {dataset_uuid_param}")
logger.info(f"Document details: {safe_filename} ({file_size} bytes)")
# Insert with dataset_id
# Determine if content is searchable (under PostgreSQL tsvector size limit)
is_searchable = text_content is None or len(text_content) < 1048575
async with pg_client.get_connection() as conn:
# Get tenant_id for the document
tenant_id = await conn.fetchval("""
SELECT id FROM tenants WHERE domain = $1 LIMIT 1
""", self.tenant_domain)
if not tenant_id:
raise ValueError(f"Tenant not found for domain: {self.tenant_domain}")
document_id = await conn.fetchval("""
INSERT INTO documents (
tenant_id, user_id, dataset_id, filename, original_filename,
file_type, file_size_bytes, file_hash, content_text, processing_status,
metadata, is_searchable, created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, 'pending', $10, $11, NOW(), NOW()
)
RETURNING id
""",
tenant_id, user_uuid, dataset_uuid_param, safe_filename, safe_original_filename,
safe_content_type, file_size, safe_file_hash, text_content,
safe_metadata, is_searchable
)
logger.info(f"Document inserted successfully with ID: {document_id}")
except Exception as db_error:
logger.error(f"Database insertion failed: {db_error}")
logger.error(f"Tenant domain: {self.tenant_domain}")
logger.error(f"User ID: {self.user_id}")
logger.error(f"Dataset ID: {dataset_id}")
raise
result = {
"id": document_id,
"filename": file.filename,
"content_type": content_type,
"file_size": file_size,
"file_hash": file_hash,
"storage_type": storage_type,
"storage_ref": storage_ref,
"upload_timestamp": datetime.utcnow().isoformat(),
"download_url": f"/api/v1/files/{document_id}"
}
logger.info(f"Stored file {file.filename} ({file_size} bytes) as {storage_type} for user {self.user_id}")
# Trigger document processing pipeline for RAG functionality
try:
await self._trigger_document_processing(document_id, dataset_id, user_uuid, file.filename)
logger.info(f"Successfully triggered document processing for {document_id}")
except Exception as process_error:
logger.error(f"Failed to trigger document processing for {document_id}: {process_error}")
# Update document status to show processing failed
try:
pg_client = await get_postgresql_client()
await pg_client.execute_command(
"UPDATE documents SET processing_status = 'failed', error_message = $1 WHERE id = $2",
f"Processing failed: {str(process_error)}", document_id
)
except Exception as update_error:
logger.error(f"Failed to update document status after processing error: {update_error}")
# Don't fail the upload if processing trigger fails - user can retry manually
return result
except Exception as e:
logger.error(f"Failed to store file {file.filename}: {e}")
raise
finally:
# Ensure content is cleared from memory
if 'content' in locals():
del content
async def _store_as_bytea(
self,
content: bytes,
filename: str,
content_type: str,
file_hash: str,
dataset_id: Optional[str],
category: str
) -> str:
"""Store small file as BYTEA in documents table"""
pg_client = await get_postgresql_client()
# Store file content directly in BYTEA column
# This will be handled by the main insert in store_file
return "bytea_column"
async def _store_as_lob(
self,
content: bytes,
filename: str,
content_type: str,
file_hash: str,
dataset_id: Optional[str],
category: str
) -> str:
"""Store large file as PostgreSQL Large Object"""
pg_client = await get_postgresql_client()
# Create Large Object and get OID
async with pg_client.get_connection() as conn:
# Start transaction for LOB operations
async with conn.transaction():
# Create LOB and get OID
lob_oid = await conn.fetchval("SELECT lo_create(0)")
# Open LOB for writing
lob_fd = await conn.fetchval("SELECT lo_open($1, 131072)", lob_oid) # INV_WRITE mode
# Write content in chunks for memory efficiency
chunk_size = 8192
offset = 0
for i in range(0, len(content), chunk_size):
chunk = content[i:i + chunk_size]
await conn.execute("SELECT lo_write($1, $2)", lob_fd, chunk)
offset += len(chunk)
# Close LOB
await conn.execute("SELECT lo_close($1)", lob_fd)
logger.info(f"Created PostgreSQL LOB with OID {lob_oid} for {filename}")
return str(lob_oid)
async def _store_as_filesystem(
self,
content: bytes,
filename: str,
content_type: str,
file_hash: str,
dataset_id: Optional[str],
category: str
) -> str:
"""Store massive file on filesystem with PostgreSQL metadata"""
# Create secure file path with user isolation
user_dir = self.filesystem_base / self.user_id / category
if dataset_id:
user_dir = user_dir / dataset_id
user_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
# Generate secure filename
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
secure_filename = f"{timestamp}_{file_hash}_{filename}"
file_path = user_dir / secure_filename
# Write file with secure permissions
async with aiofiles.open(file_path, 'wb') as f:
await f.write(content)
# Set secure file permissions
os.chmod(file_path, 0o600)
logger.info(f"Stored large file on filesystem: {file_path}")
return str(file_path)
async def get_file(self, document_id: str) -> AsyncIterator[bytes]:
"""Stream file content by document ID"""
try:
pg_client = await get_postgresql_client()
# Validate user_id is a valid UUID format
try:
import uuid
user_uuid = str(uuid.UUID(self.user_id))
except (ValueError, TypeError) as e:
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
# Get document metadata using UUID directly
# Admins can access any document in their tenant, regular users only their own
if self.user_role in ADMIN_ROLES:
doc_info = await pg_client.fetch_one("""
SELECT metadata, file_size_bytes, filename, content_text
FROM documents d
WHERE d.id = $1
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
""", document_id, self.tenant_domain)
else:
doc_info = await pg_client.fetch_one("""
SELECT metadata, file_size_bytes, filename, content_text
FROM documents
WHERE id = $1 AND user_id = $2::uuid
""", document_id, user_uuid)
if not doc_info:
raise FileNotFoundError(f"Document {document_id} not found")
# Get storage info from metadata - handle JSON string or dict
metadata_raw = doc_info["metadata"] or "{}"
if isinstance(metadata_raw, str):
import json
metadata = json.loads(metadata_raw)
else:
metadata = metadata_raw or {}
storage_type = metadata.get("storage_type", "text")
if storage_type == "text":
# Text content stored directly
if doc_info["content_text"]:
content_bytes = doc_info["content_text"].encode('utf-8')
async for chunk in self._stream_from_bytea(content_bytes):
yield chunk
else:
raise FileNotFoundError(f"Document content not found")
elif storage_type == "base64":
# Base64 encoded binary content
if doc_info["content_text"]:
import base64
content_bytes = base64.b64decode(doc_info["content_text"])
async for chunk in self._stream_from_bytea(content_bytes):
yield chunk
else:
raise FileNotFoundError(f"Document content not found")
elif storage_type == "lob":
# Stream from PostgreSQL LOB
storage_ref = metadata.get("storage_ref", "")
async for chunk in self._stream_from_lob(int(storage_ref)):
yield chunk
elif storage_type == "filesystem":
# Stream from filesystem
storage_ref = metadata.get("storage_ref", "")
async for chunk in self._stream_from_filesystem(storage_ref):
yield chunk
else:
# Default: treat as text content
if doc_info["content_text"]:
content_bytes = doc_info["content_text"].encode('utf-8')
async for chunk in self._stream_from_bytea(content_bytes):
yield chunk
else:
raise FileNotFoundError(f"Document content not found")
except Exception as e:
logger.error(f"Failed to get file {document_id}: {e}")
raise
async def _stream_from_bytea(self, content: bytes) -> AsyncIterator[bytes]:
"""Stream content from BYTEA in chunks"""
chunk_size = 8192
for i in range(0, len(content), chunk_size):
yield content[i:i + chunk_size]
async def _stream_from_lob(self, lob_oid: int) -> AsyncIterator[bytes]:
"""Stream content from PostgreSQL Large Object"""
pg_client = await get_postgresql_client()
async with pg_client.get_connection() as conn:
async with conn.transaction():
# Open LOB for reading
lob_fd = await conn.fetchval("SELECT lo_open($1, 262144)", lob_oid) # INV_READ mode
# Stream in chunks
chunk_size = 8192
while True:
chunk = await conn.fetchval("SELECT lo_read($1, $2)", lob_fd, chunk_size)
if not chunk:
break
yield chunk
# Close LOB
await conn.execute("SELECT lo_close($1)", lob_fd)
async def _stream_from_filesystem(self, file_path: str) -> AsyncIterator[bytes]:
"""Stream content from filesystem"""
# Verify file resolves inside this tenant's storage root (security check)
path_obj = Path(file_path).resolve()
if self.filesystem_base.resolve() not in path_obj.parents:
raise PermissionError("Access denied to file")
if not path_obj.exists():
raise FileNotFoundError(f"File not found: {file_path}")
async with aiofiles.open(file_path, 'rb') as f:
chunk_size = 8192
while True:
chunk = await f.read(chunk_size)
if not chunk:
break
yield chunk
async def delete_file(self, document_id: str) -> bool:
"""Delete file and metadata"""
try:
pg_client = await get_postgresql_client()
# Validate user_id is a valid UUID format
try:
import uuid
user_uuid = str(uuid.UUID(self.user_id))
except (ValueError, TypeError) as e:
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
# Get document info before deletion
# Admins can delete any document in their tenant, regular users only their own
if self.user_role in ADMIN_ROLES:
doc_info = await pg_client.fetch_one("""
SELECT metadata->>'storage_type' as storage_type, metadata->>'storage_ref' as storage_ref FROM documents d
WHERE d.id = $1
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
""", document_id, self.tenant_domain)
else:
doc_info = await pg_client.fetch_one("""
SELECT metadata->>'storage_type' as storage_type, metadata->>'storage_ref' as storage_ref FROM documents
WHERE id = $1 AND user_id = $2::uuid
""", document_id, user_uuid)
if not doc_info:
logger.warning(f"Document {document_id} not found for deletion")
return False
storage_type = doc_info["storage_type"]
storage_ref = doc_info["storage_ref"]
# Delete file content based on storage type
if storage_type == "lob":
# Delete LOB
async with pg_client.get_connection() as conn:
await conn.execute("SELECT lo_unlink($1)", int(storage_ref))
elif storage_type == "filesystem":
# Delete filesystem file
try:
path_obj = Path(storage_ref)
if path_obj.exists():
path_obj.unlink()
except Exception as e:
logger.warning(f"Failed to delete filesystem file {storage_ref}: {e}")
# BYTEA files are deleted with the row
# Delete metadata record
if self.user_role in ADMIN_ROLES:
deleted = await pg_client.execute_command("""
DELETE FROM documents d
WHERE d.id = $1
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
""", document_id, self.tenant_domain)
else:
deleted = await pg_client.execute_command("""
DELETE FROM documents WHERE id = $1 AND user_id = $2::uuid
""", document_id, user_uuid)
if deleted > 0:
logger.info(f"Deleted file {document_id} ({storage_type})")
return True
else:
return False
except Exception as e:
logger.error(f"Failed to delete file {document_id}: {e}")
return False
async def get_file_info(self, document_id: str) -> Dict[str, Any]:
"""Get file metadata"""
try:
pg_client = await get_postgresql_client()
# Validate user_id is a valid UUID format
try:
import uuid
user_uuid = str(uuid.UUID(self.user_id))
except (ValueError, TypeError) as e:
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
# Admins can access any document metadata in their tenant, regular users only their own
if self.user_role in ADMIN_ROLES:
doc_info = await pg_client.fetch_one("""
SELECT id, filename, original_filename, file_type as content_type, file_size_bytes as file_size,
file_hash, dataset_id, metadata->>'storage_type' as storage_type, metadata->>'category' as category, created_at
FROM documents d
WHERE d.id = $1
AND d.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
""", document_id, self.tenant_domain)
else:
doc_info = await pg_client.fetch_one("""
SELECT id, filename, original_filename, file_type as content_type, file_size_bytes as file_size,
file_hash, dataset_id, metadata->>'storage_type' as storage_type, metadata->>'category' as category, created_at
FROM documents
WHERE id = $1 AND user_id = $2::uuid
""", document_id, user_uuid)
if not doc_info:
raise FileNotFoundError(f"Document {document_id} not found")
return {
"id": doc_info["id"],
"filename": doc_info["filename"],
"original_filename": doc_info["original_filename"],
"content_type": doc_info["content_type"],
"file_size": doc_info["file_size"],
"file_hash": doc_info["file_hash"],
"dataset_id": str(doc_info["dataset_id"]) if doc_info["dataset_id"] else None,
"storage_type": doc_info["storage_type"],
"category": doc_info["category"],
"created_at": doc_info["created_at"].isoformat(),
"download_url": f"/api/v1/files/{document_id}"
}
except Exception as e:
logger.error(f"Failed to get file info for {document_id}: {e}")
raise
async def list_files(
self,
dataset_id: Optional[str] = None,
category: str = "documents",
limit: int = 50,
offset: int = 0
) -> List[Dict[str, Any]]:
"""List user files with optional filtering"""
try:
pg_client = await get_postgresql_client()
# Validate user_id is a valid UUID format
try:
import uuid
user_uuid = str(uuid.UUID(self.user_id))
except (ValueError, TypeError) as e:
logger.error(f"Invalid user UUID format: {self.user_id}, error: {e}")
raise ValueError(f"Invalid user ID format. Expected UUID, got: {self.user_id}")
# Build permission-aware query
# Admins can list any documents in their tenant
# Regular users can list documents they own OR documents in datasets they can access
if self.user_role in ADMIN_ROLES:
where_clauses = ["d.tenant_id = (SELECT id FROM tenants WHERE domain = $1)"]
params = [self.tenant_domain]
param_idx = 2
else:
# Non-admin users can see:
# 1. Documents they own
# 2. Documents in datasets with access_group = 'organization'
# 3. Documents in datasets they're a member of (team access)
where_clauses = ["""(
d.user_id = $1::uuid
OR EXISTS (
SELECT 1 FROM datasets ds
WHERE ds.id = d.dataset_id
AND ds.tenant_id = (SELECT id FROM tenants WHERE domain = $2)
AND (
ds.access_group = 'organization'
OR (ds.access_group = 'team' AND $1::uuid = ANY(ds.team_members))
)
)
)"""]
params = [user_uuid, self.tenant_domain]
param_idx = 3
if dataset_id:
where_clauses.append(f"d.dataset_id = ${param_idx}::uuid")
params.append(dataset_id)
param_idx += 1
if category:
where_clauses.append(f"(d.metadata->>'category' = ${param_idx} OR d.metadata->>'category' IS NULL)")
params.append(category)
param_idx += 1
query = f"""
SELECT d.id, d.filename, d.original_filename, d.file_type as content_type, d.file_size_bytes as file_size,
d.metadata->>'storage_type' as storage_type, d.metadata->>'category' as category, d.created_at, d.updated_at, d.dataset_id,
d.processing_status, d.metadata, d.user_id, COUNT(dc.id) as chunk_count,
ds.created_by as dataset_owner_id
FROM documents d
LEFT JOIN document_chunks dc ON d.id = dc.document_id
LEFT JOIN datasets ds ON d.dataset_id = ds.id
WHERE {' AND '.join(where_clauses)}
GROUP BY d.id, d.filename, d.original_filename, d.file_type, d.file_size_bytes, d.metadata, d.created_at, d.updated_at, d.dataset_id, d.processing_status, d.user_id, ds.created_by
ORDER BY d.created_at DESC LIMIT ${param_idx} OFFSET ${param_idx + 1}
"""
params.extend([limit, offset])
files = await pg_client.execute_query(query, *params)
# Helper function to parse metadata
def parse_metadata(metadata_value):
if metadata_value is None:
return {}
if isinstance(metadata_value, str):
import json
try:
return json.loads(metadata_value)
except (json.JSONDecodeError, ValueError):
return {}
return metadata_value if isinstance(metadata_value, dict) else {}
return [
{
"id": file["id"],
"filename": file["filename"],
"original_filename": file["original_filename"],
"content_type": file["content_type"],
"file_type": file["content_type"],
"file_size": file["file_size"],
"file_size_bytes": file["file_size"],
"dataset_id": file["dataset_id"],
"storage_type": file["storage_type"],
"category": file["category"],
"created_at": file["created_at"].isoformat(),
"updated_at": file["updated_at"].isoformat() if file.get("updated_at") else None,
"processing_status": file.get("processing_status", "pending"),
"chunk_count": file.get("chunk_count", 0),
"chunks_processed": parse_metadata(file.get("metadata")).get("chunks_processed", 0),
"total_chunks_expected": parse_metadata(file.get("metadata")).get("total_chunks_expected", 0),
"processing_progress": parse_metadata(file.get("metadata")).get("processing_progress", 0),
"processing_stage": parse_metadata(file.get("metadata")).get("processing_stage"),
"download_url": f"/api/v1/files/{file['id']}",
# Permission flags - user can delete if:
# 1. They are admin, OR
# 2. They uploaded the document, OR
# 3. They own the parent dataset
"can_delete": (
self.user_role in ADMIN_ROLES or
file["user_id"] == user_uuid or
(file.get("dataset_owner_id") and str(file["dataset_owner_id"]) == user_uuid)
)
}
for file in files
]
except Exception as e:
logger.error(f"Failed to list files for user {self.user_id}: {e}")
return []
async def cleanup_orphaned_files(self) -> int:
"""Clean up orphaned files and LOBs"""
try:
pg_client = await get_postgresql_client()
cleanup_count = 0
# Find orphaned LOBs (LOBs without corresponding document records)
async with pg_client.get_connection() as conn:
async with conn.transaction():
orphaned_lobs = await conn.fetch("""
SELECT lo.oid FROM pg_largeobject_metadata lo
LEFT JOIN documents d
ON lo.oid::text = d.metadata->>'storage_ref'
AND d.metadata->>'storage_type' = 'lob'
WHERE d.id IS NULL
""")
for lob in orphaned_lobs:
await conn.execute("SELECT lo_unlink($1)", lob["oid"])
cleanup_count += 1
# Find orphaned filesystem files
# Note: This would require more complex logic to safely identify orphans
logger.info(f"Cleaned up {cleanup_count} orphaned files")
return cleanup_count
except Exception as e:
logger.error(f"Failed to cleanup orphaned files: {e}")
return 0
async def _trigger_document_processing(
self,
document_id: str,
dataset_id: Optional[str],
user_uuid: str,
filename: str
):
"""Trigger document processing pipeline for RAG functionality"""
try:
# Import here to avoid circular imports
from app.services.document_processor import get_document_processor
logger.info(f"Triggering document processing for {document_id}")
# Get document processor instance
processor = await get_document_processor(tenant_domain=self.tenant_domain)
# For documents uploaded via PostgreSQL file service, the content is already stored
# We need to process it from the database content rather than a file path
await self._process_document_from_database(
processor, document_id, dataset_id, user_uuid, filename
)
except Exception as e:
logger.error(f"Document processing trigger failed for {document_id}: {e}")
# Update document status to failed
try:
pg_client = await get_postgresql_client()
await pg_client.execute_command(
"UPDATE documents SET processing_status = 'failed', error_message = $1 WHERE id = $2",
f"Processing trigger failed: {str(e)}", document_id
)
except Exception as update_error:
logger.error(f"Failed to update document status to failed: {update_error}")
raise
async def _process_document_from_database(
self,
processor,
document_id: str,
dataset_id: Optional[str],
user_uuid: str,
filename: str
):
"""Process document using content already stored in database"""
try:
import tempfile
import os
from pathlib import Path
# Get document content from database
pg_client = await get_postgresql_client()
doc_info = await pg_client.fetch_one("""
SELECT content_text, file_type, metadata
FROM documents
WHERE id = $1 AND user_id = $2::uuid
""", document_id, user_uuid)
if not doc_info or not doc_info["content_text"]:
raise ValueError("Document content not found in database")
# Create temporary file with the content
# Sanitize the file extension to prevent path injection
safe_suffix = sanitize_filename(filename)
safe_suffix = Path(safe_suffix).suffix if safe_suffix else ".tmp"
# codeql[py/path-injection] safe_suffix is sanitized via sanitize_filename()
with tempfile.NamedTemporaryFile(mode='w', suffix=safe_suffix, delete=False) as temp_file:
# Handle different storage types - metadata might be JSON string or dict
metadata_raw = doc_info["metadata"] or "{}"
if isinstance(metadata_raw, str):
import json
metadata = json.loads(metadata_raw)
else:
metadata = metadata_raw or {}
storage_type = metadata.get("storage_type", "text")
if storage_type == "text":
temp_file.write(doc_info["content_text"])
elif storage_type == "base64":
import base64
content_bytes = base64.b64decode(doc_info["content_text"])
temp_file.close()
with open(temp_file.name, 'wb') as binary_file:
binary_file.write(content_bytes)
elif storage_type == "pdf_extracted":
# For PDFs with extracted text, create a placeholder text file
# since the actual text content is already extracted
temp_file.write(doc_info["content_text"])
else:
temp_file.write(doc_info["content_text"])
temp_file_path = Path(temp_file.name)
try:
# Process the document using the existing document processor
await processor.process_file(
file_path=temp_file_path,
dataset_id=dataset_id, # Keep None as None - don't convert to empty string
user_id=user_uuid,
original_filename=filename,
document_id=document_id # Use existing document instead of creating new one
)
logger.info(f"Successfully processed document {document_id} from database content")
finally:
# Clean up temporary file
try:
os.unlink(temp_file_path)
except Exception as cleanup_error:
logger.warning(f"Failed to cleanup temporary file {temp_file_path}: {cleanup_error}")
except Exception as e:
logger.error(f"Failed to process document from database content: {e}")
raise
async def _extract_pdf_text(self, content: bytes) -> str:
"""Extract text content from PDF bytes using pypdf"""
import io
import pypdf as PyPDF2 # pypdf is the maintained successor to PyPDF2
try:
# Create BytesIO object from content
pdf_stream = io.BytesIO(content)
pdf_reader = PyPDF2.PdfReader(pdf_stream)
text_parts = []
for page_num, page in enumerate(pdf_reader.pages):
try:
page_text = page.extract_text()
if page_text.strip():
text_parts.append(f"--- Page {page_num + 1} ---\n{page_text}")
except Exception as e:
logger.warning(f"Could not extract text from page {page_num + 1}: {e}")
if not text_parts:
# If no text could be extracted, return a placeholder
return f"PDF document with {len(pdf_reader.pages)} pages (text extraction failed)"
extracted_text = "\n\n".join(text_parts)
logger.info(f"Successfully extracted {len(extracted_text)} characters from PDF with {len(pdf_reader.pages)} pages")
return extracted_text
except Exception as e:
logger.error(f"PDF text extraction failed: {e}")
# Return a fallback description instead of failing completely
return f"PDF document (text extraction failed: {str(e)})"

View File

@@ -0,0 +1,577 @@
"""
RAG Orchestrator Service for GT 2.0
Coordinates RAG operations between chat, MCP tools, and datasets.
Provides intelligent context retrieval and source attribution.
"""
import logging
import asyncio
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
import httpx
from app.services.mcp_integration import MCPIntegrationService, MCPExecutionResult
from app.services.pgvector_search_service import PGVectorSearchService, get_pgvector_search_service
from app.models.agent import Agent
from app.models.assistant_dataset import AssistantDataset
logger = logging.getLogger(__name__)
@dataclass
class RAGContext:
"""Context retrieved from RAG operations"""
chunks: List[Dict[str, Any]]
sources: List[Dict[str, Any]]
search_queries: List[str]
total_chunks: int
retrieval_time_ms: float
datasets_used: List[str]
@dataclass
class RAGSearchParams:
"""Parameters for RAG search operations"""
query: str
dataset_ids: Optional[List[str]] = None
max_chunks: int = 5
similarity_threshold: float = 0.7
search_method: str = "hybrid" # hybrid, vector, text
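# Minimal usage sketch (illustrative, kept as a comment; assumes an Agent instance with
# dataset bindings and an existing conversation id -- get_rag_context degrades to an
# empty RAGContext when neither datasets nor conversation files are available):
#
# async def example_rag_context(agent: Agent, tenant_domain: str, user_id: str, conversation_id: str):
#     orchestrator = RAGOrchestrator(tenant_domain=tenant_domain, user_id=user_id)
#     context = await orchestrator.get_rag_context(
#         agent=agent,
#         user_message="What does the onboarding policy say about laptops?",
#         conversation_id=conversation_id,
#         params=RAGSearchParams(query="onboarding policy laptops", max_chunks=5),
#     )
#     return context.chunks, context.sources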
class RAGOrchestrator:
"""
Orchestrates RAG operations for enhanced chat responses.
Coordinates between:
- Dataset search via MCP RAG server
- Conversation history via MCP conversation server
- Direct PGVector search for performance
- Agent dataset bindings for context filtering
"""
def __init__(self, tenant_domain: str, user_id: str):
self.tenant_domain = tenant_domain
self.user_id = user_id
self.mcp_service = MCPIntegrationService()
self.resource_cluster_url = "http://resource-cluster:8000"
async def get_rag_context(
self,
agent: Agent,
user_message: str,
conversation_id: str,
params: Optional[RAGSearchParams] = None
) -> RAGContext:
"""
Get comprehensive RAG context for a chat message.
Args:
agent: Agent instance with dataset bindings
user_message: User's message/query
conversation_id: Current conversation ID
params: Optional search parameters
Returns:
RAGContext with relevant chunks and sources
"""
start_time = datetime.now()
if params is None:
params = RAGSearchParams(query=user_message)
try:
# Get agent's dataset IDs for search (unchanged)
agent_dataset_ids = await self._get_agent_datasets(agent)
# Get conversation files if conversation exists (NEW: simplified approach)
conversation_files = []
if conversation_id:
conversation_files = await self._get_conversation_files(conversation_id)
# Determine search strategy
search_dataset_ids = params.dataset_ids or agent_dataset_ids
# Check if we have any sources to search
if not search_dataset_ids and not conversation_files:
logger.info(f"No search sources available - agent {agent.id if agent else 'none'}")
return RAGContext(
chunks=[],
sources=[],
search_queries=[params.query],
total_chunks=0,
retrieval_time_ms=0.0,
datasets_used=[]
)
# Prepare search tasks for dual-source search
search_tasks = []
# Task 1: Dataset search via MCP RAG server (unchanged)
if search_dataset_ids:
search_tasks.append(
self._search_datasets_via_mcp(params.query, search_dataset_ids, params)
)
# Task 2: Conversation files search (NEW: direct search)
if conversation_files:
search_tasks.append(
self._search_conversation_files(params.query, conversation_id)
)
# Execute searches in parallel
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
# Process search results from all sources
all_chunks = []
all_sources = []
result_index = 0
# Process dataset search results (if performed)
if search_dataset_ids and result_index < len(search_results):
dataset_result = search_results[result_index]
if not isinstance(dataset_result, Exception) and dataset_result.get("success"):
dataset_chunks = dataset_result.get("results", [])
all_chunks.extend(dataset_chunks)
all_sources.extend(self._extract_sources(dataset_chunks))
result_index += 1
# Process conversation files search results (if performed)
if conversation_files and result_index < len(search_results):
conversation_result = search_results[result_index]
if not isinstance(conversation_result, Exception):
conversation_chunks = conversation_result or []
all_chunks.extend(conversation_chunks)
all_sources.extend(
await self._extract_conversation_file_sources(conversation_chunks, conversation_id)
)
result_index += 1
# Rank and filter results based on agent preferences (now using all chunks)
final_chunks = await self._rank_and_filter_chunks(
all_chunks, agent_dataset_ids, params
)
retrieval_time = (datetime.now() - start_time).total_seconds() * 1000
logger.info(
f"RAG context retrieved: {len(final_chunks)} chunks from "
f"{len(search_dataset_ids)} datasets + {len(conversation_files)} conversation files "
f"in {retrieval_time:.1f}ms"
)
return RAGContext(
chunks=final_chunks,
sources=all_sources,
search_queries=[params.query],
total_chunks=len(final_chunks),
retrieval_time_ms=retrieval_time,
datasets_used=search_dataset_ids
)
except Exception as e:
logger.error(f"RAG context retrieval failed: {e}")
retrieval_time = (datetime.now() - start_time).total_seconds() * 1000
# Return empty context on failure
return RAGContext(
chunks=[],
sources=[],
search_queries=[params.query],
total_chunks=0,
retrieval_time_ms=retrieval_time,
datasets_used=[]
)
async def _get_agent_datasets(self, agent: Agent) -> List[str]:
"""Get dataset IDs for an agent (simplified)"""
try:
# Get agent configuration from agent service (skip complex table lookup)
from app.services.agent_service import AgentService
agent_service = AgentService(self.tenant_domain, self.user_id)
agent_data = await agent_service.get_agent(agent.id)
if agent_data and 'selected_dataset_ids' in agent_data and agent_data['selected_dataset_ids'] is not None:
selected_dataset_ids = agent_data.get('selected_dataset_ids', [])
logger.info(f"Found {len(selected_dataset_ids)} dataset IDs in agent configuration: {selected_dataset_ids}")
return selected_dataset_ids
else:
logger.info(f"No selected_dataset_ids found in agent {agent.id} configuration: {agent_data.get('selected_dataset_ids') if agent_data else 'no agent_data'}")
return []
except Exception as e:
logger.error(f"Failed to get agent datasets: {e}")
return []
async def _get_conversation_files(self, conversation_id: str) -> List[Dict[str, Any]]:
"""Get conversation files (NEW: simplified approach)"""
try:
from app.services.conversation_file_service import get_conversation_file_service
file_service = get_conversation_file_service(self.tenant_domain, self.user_id)
conversation_files = await file_service.list_files(conversation_id)
# Filter to only completed files
completed_files = [f for f in conversation_files if f.get('processing_status') == 'completed']
logger.info(f"Found {len(completed_files)} processed conversation files")
return completed_files
except Exception as e:
logger.error(f"Failed to get conversation files: {e}")
return []
async def _search_conversation_files(
self,
query: str,
conversation_id: str
) -> List[Dict[str, Any]]:
"""Search conversation files using vector similarity (NEW: direct search)"""
try:
from app.services.conversation_file_service import get_conversation_file_service
file_service = get_conversation_file_service(self.tenant_domain, self.user_id)
results = await file_service.search_conversation_files(
conversation_id=conversation_id,
query=query,
max_results=5
)
logger.info(f"Found {len(results)} matching conversation files")
return results
except Exception as e:
logger.error(f"Failed to search conversation files: {e}")
return []
async def _extract_conversation_file_sources(
self,
chunks: List[Dict[str, Any]],
conversation_id: str
) -> List[Dict[str, Any]]:
"""Extract unique conversation file sources with rich metadata"""
sources = {}
for chunk in chunks:
file_id = chunk.get("id")
if file_id and file_id not in sources:
uploaded_at = chunk.get("uploaded_at")
if not uploaded_at:
logger.warning(f"Missing uploaded_at for file {file_id}")
file_size = chunk.get("file_size_bytes", 0)
if file_size == 0:
logger.warning(f"Missing file_size_bytes for file {file_id}")
sources[file_id] = {
"document_id": file_id,
"dataset_id": None,
"document_name": chunk.get("original_filename", "Unknown File"),
"source_type": "conversation_file",
"access_scope": "conversation",
"search_method": "auto_rag",
"conversation_id": conversation_id,
"uploaded_at": uploaded_at,
"file_size_bytes": file_size,
"content_type": chunk.get("content_type", "unknown"),
"processing_status": chunk.get("processing_status", "unknown"),
"chunk_count": 1,
"relevance_score": chunk.get("similarity_score", 0.0)
}
elif file_id in sources:
sources[file_id]["chunk_count"] += 1
current_score = chunk.get("similarity_score", 0.0)
if current_score > sources[file_id]["relevance_score"]:
sources[file_id]["relevance_score"] = current_score
return list(sources.values())
# Keep old method for backward compatibility during migration
async def _get_conversation_datasets(self, conversation_id: str) -> List[str]:
"""Get dataset IDs associated with a conversation (LEGACY: for migration)"""
try:
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.tenant_domain, self.user_id)
conversation_dataset_ids = await conversation_service.get_conversation_datasets(
conversation_id=conversation_id,
user_identifier=self.user_id
)
logger.info(f"Found {len(conversation_dataset_ids)} legacy conversation datasets: {conversation_dataset_ids}")
return conversation_dataset_ids
except Exception as e:
logger.error(f"Failed to get conversation datasets: {e}")
return []
async def _search_datasets_via_mcp(
self,
query: str,
dataset_ids: List[str],
params: RAGSearchParams
) -> Dict[str, Any]:
"""Search datasets using MCP RAG server"""
try:
# Prepare MCP tool call
tool_params = {
"query": query,
"dataset_ids": dataset_ids,
"max_results": params.max_chunks,
"search_method": params.search_method
}
# Create capability token for MCP access
capability_token = {
"capabilities": [{"resource": "mcp:rag:*"}],
"user_id": self.user_id,
"tenant_domain": self.tenant_domain
}
# Execute MCP tool call via resource cluster
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.resource_cluster_url}/api/v1/mcp/execute",
json={
"server_id": "rag_server",
"tool_name": "search_datasets",
"parameters": tool_params,
"capability_token": capability_token,
"tenant_domain": self.tenant_domain,
"user_id": self.user_id
}
)
if response.status_code == 200:
return response.json()
else:
logger.error(f"MCP RAG search failed: {response.status_code} - {response.text}")
return {"success": False, "error": "MCP search failed"}
except Exception as e:
logger.error(f"MCP dataset search failed: {e}")
return {"success": False, "error": str(e)}
async def _rank_and_filter_chunks(
self,
chunks: List[Dict[str, Any]],
agent_dataset_ids: List[str],
params: RAGSearchParams
) -> List[Dict[str, Any]]:
"""Rank and filter chunks based on agent preferences (simplified)"""
try:
if not chunks:
return []
# Convert agent dataset list to set for fast lookup
agent_datasets = set(agent_dataset_ids)
# Separate conversation files from dataset chunks
conversation_file_chunks = []
dataset_chunks = []
for chunk in chunks:
if chunk.get("source_type") == "conversation_file":
# Conversation files ALWAYS included (user explicitly attached them)
chunk["final_score"] = chunk.get("similarity_score", 1.0)
chunk["dataset_priority"] = -1 # Highest priority (shown first)
conversation_file_chunks.append(chunk)
else:
dataset_chunks.append(chunk)
# Filter and score dataset chunks using thresholds
scored_dataset_chunks = []
for chunk in dataset_chunks:
dataset_id = chunk.get("dataset_id")
similarity_score = chunk.get("similarity_score", 0.0)
# Check if chunk is from agent's configured datasets
if dataset_id in agent_datasets:
# Use agent default threshold
if similarity_score >= 0.7: # Default threshold
chunk["final_score"] = similarity_score
chunk["dataset_priority"] = 0 # High priority for agent datasets
scored_dataset_chunks.append(chunk)
else:
# Use request threshold for other datasets
if similarity_score >= params.similarity_threshold:
chunk["final_score"] = similarity_score
chunk["dataset_priority"] = 999 # Low priority
scored_dataset_chunks.append(chunk)
# Combine: conversation files first, then sorted dataset chunks
scored_dataset_chunks.sort(key=lambda x: x["final_score"], reverse=True)
# Limit total results, but always include all conversation files
final_chunks = conversation_file_chunks + scored_dataset_chunks
# If we exceed max_chunks, keep all conversation files and trim datasets
if len(final_chunks) > params.max_chunks:
dataset_limit = params.max_chunks - len(conversation_file_chunks)
if dataset_limit > 0:
final_chunks = conversation_file_chunks + scored_dataset_chunks[:dataset_limit]
else:
# If conversation files alone exceed limit, keep them all anyway
final_chunks = conversation_file_chunks
logger.info(f"Ranked chunks: {len(conversation_file_chunks)} conversation files (always included) + {len(scored_dataset_chunks)} dataset chunks → {len(final_chunks)} total")
return final_chunks
except Exception as e:
logger.error(f"Chunk ranking failed: {e}")
return chunks[:params.max_chunks] # Fallback to simple truncation
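    # Ranking behavior in brief (illustrative, assuming the defaults above):
    # conversation-file chunks are never filtered and are emitted first in the
    # order received; dataset chunks from the agent's configured datasets must
    # score >= 0.7, other dataset chunks must score >= params.similarity_threshold,
    # and surviving dataset chunks are sorted by final_score descending before
    # the combined list is trimmed to params.max_chunks (conversation files are
    # kept even if they alone exceed that limit).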
def _extract_sources(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extract unique document sources from dataset chunks with metadata"""
sources = {}
for chunk in chunks:
document_id = chunk.get("document_id")
if document_id and document_id not in sources:
sources[document_id] = {
"document_id": document_id,
"dataset_id": chunk.get("dataset_id"),
"document_name": chunk.get("metadata", {}).get("document_name", "Unknown"),
"source_type": "dataset",
"access_scope": "permanent",
"search_method": "mcp_tool",
"dataset_name": chunk.get("dataset_name", "Unknown Dataset"),
"chunk_count": 1,
"relevance_score": chunk.get("similarity_score", 0.0)
}
elif document_id in sources:
sources[document_id]["chunk_count"] += 1
current_score = chunk.get("similarity_score", 0.0)
if current_score > sources[document_id]["relevance_score"]:
sources[document_id]["relevance_score"] = current_score
return list(sources.values())
def format_context_for_llm(self, rag_context: RAGContext) -> str:
"""Format RAG context for inclusion in LLM prompt"""
if not rag_context.chunks:
return ""
context_parts = ["## Relevant Context\n"]
# Add dataset search results
if rag_context.chunks:
context_parts.append("### From Documents:")
for i, chunk in enumerate(rag_context.chunks[:5], 1): # Limit to top 5
document_name = chunk.get("metadata", {}).get("document_name", "Unknown Document")
content = chunk.get("content", chunk.get("text", ""))
context_parts.append(f"\n**Source {i}**: {document_name}")
context_parts.append(f"Content: {content[:500]}...") # Truncate long content
context_parts.append("")
# Add source attribution
if rag_context.sources:
context_parts.append("### Sources:")
for source in rag_context.sources:
context_parts.append(f"- {source['document_name']} ({source['chunk_count']} relevant sections)")
return "\n".join(context_parts)
def format_context_for_agent(
self,
rag_context: RAGContext,
compact_mode: bool = False
) -> str:
"""
Format RAG results with clear source attribution for agent consumption.
Args:
rag_context: RAG search results with chunks and sources
compact_mode: Use compact format for >2 files (~200 tokens vs ~700)
Returns:
Formatted context string ready for LLM injection
"""
if not rag_context.chunks:
return ""
dataset_chunks = [
c for c in rag_context.chunks
if c.get('source_type') == 'dataset' or c.get('dataset_id')
]
file_chunks = [
c for c in rag_context.chunks
if c.get('source_type') == 'conversation_file'
]
context_parts = []
context_parts.append("=" * 80)
context_parts.append("📚 KNOWLEDGE BASE CONTEXT - DATASET DOCUMENTS")
context_parts.append("=" * 80)
if dataset_chunks:
context_parts.append("\n📂 FROM AGENT'S PERMANENT DATASETS:")
context_parts.append("(These are documents from the agent's configured knowledge base)\n")
for i, chunk in enumerate(dataset_chunks, 1):
dataset_name = chunk.get('dataset_name', 'Unknown Dataset')
doc_name = chunk.get('metadata', {}).get('document_name', 'Unknown')
content = chunk.get('content', chunk.get('text', ''))
score = chunk.get('similarity_score', 0.0)
if compact_mode:
context_parts.append(f"\n[Dataset: {dataset_name}]\n{content[:400]}...")
else:
context_parts.append(f"\n{'' * 80}")
context_parts.append(f"📚 DATASET EXCERPT {i}")
context_parts.append(f"Dataset: {dataset_name} / Document: {doc_name}")
context_parts.append(f"Relevance: {score:.2f}")
context_parts.append(f"{'' * 80}")
context_parts.append(content[:600] if len(content) > 600 else content)
if len(content) > 600:
context_parts.append("\n[... excerpt continues ...]")
if file_chunks:
context_parts.append(f"\n\n{'=' * 80}")
context_parts.append("📎 FROM CONVERSATION FILES - USER ATTACHED DOCUMENTS")
context_parts.append("=" * 80)
context_parts.append("(These are files the user attached to THIS specific conversation)\n")
for i, chunk in enumerate(file_chunks, 1):
filename = chunk.get('document_name', chunk.get('original_filename', 'Unknown'))
content = chunk.get('content', chunk.get('text', ''))
score = chunk.get('similarity_score', 0.0)
if compact_mode:
context_parts.append(f"\n[File: {filename}]\n{content[:400]}...")
else:
context_parts.append(f"\n{'' * 80}")
context_parts.append(f"📄 FILE EXCERPT {i}: {filename}")
context_parts.append(f"Relevance: {score:.2f}")
context_parts.append(f"{'' * 80}")
context_parts.append(content[:600] if len(content) > 600 else content)
if len(content) > 600:
context_parts.append("\n[... excerpt continues ...]")
context_text = "\n".join(context_parts)
formatted_context = f"""{context_text}
{'=' * 80}
⚠️ CONTEXT USAGE INSTRUCTIONS:
1. CONVERSATION FILES (📎) = User's attached files for THIS chat - cite as "In your attached file..."
2. DATASET DOCUMENTS (📂) = Agent's knowledge base - cite as "According to the dataset..."
3. Always prioritize conversation files when both sources have relevant information
4. Be explicit about which source you're referencing in your answer
{'=' * 80}"""
return formatted_context
# Global instance factory
def get_rag_orchestrator(tenant_domain: str, user_id: str) -> RAGOrchestrator:
"""Get RAG orchestrator instance for tenant and user"""
return RAGOrchestrator(tenant_domain, user_id)
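# Illustrative usage sketch (not invoked anywhere): the tenant domain and user
# id below are hypothetical placeholders, and the RAGContext is assumed to be
# produced by the orchestrator's own search path.
def _example_prompt_injection(rag_context: RAGContext) -> str:  # pragma: no cover
    """Format an already-built RAGContext for LLM prompt injection."""
    orchestrator = get_rag_orchestrator("acme.example", "user-123")
    # compact_mode trims the per-chunk framing when many files are attached
    return orchestrator.format_context_for_agent(rag_context, compact_mode=True)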

View File

@@ -0,0 +1,671 @@
"""
RAG Service for GT 2.0 Tenant Backend
Orchestrates document processing, embedding generation, and vector storage
with perfect tenant isolation and zero downtime compliance.
"""
import logging
import asyncio
import aiofiles
import os
import json
import gc
from typing import Dict, Any, List, Optional, BinaryIO
from datetime import datetime
from pathlib import Path
import hashlib
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_
from sqlalchemy.orm import selectinload
from app.models.document import Document, RAGDataset, DatasetDocument, DocumentChunk
from app.core.database import get_db_session
from app.core.config import get_settings
from app.core.resource_client import ResourceClusterClient
logger = logging.getLogger(__name__)
class RAGService:
"""
Comprehensive RAG service with perfect tenant isolation.
GT 2.0 Security Principles:
- Perfect tenant isolation (all operations user-scoped)
- Stateless document processing (no data persistence in Resource Cluster)
- Encrypted vector storage per tenant
- Zero downtime compliance (async operations)
"""
def __init__(self, db: AsyncSession):
self.db = db
self.settings = get_settings()
self.resource_client = ResourceClusterClient()
# Tenant-specific directories
self.upload_directory = Path(self.settings.upload_directory)
self.temp_directory = Path(self.settings.temp_directory)
# Ensure directories exist with secure permissions
self._ensure_directories()
logger.info("RAG service initialized with tenant isolation")
def _ensure_directories(self):
"""Ensure required directories exist with secure permissions"""
for directory in [self.upload_directory, self.temp_directory]:
directory.mkdir(parents=True, exist_ok=True, mode=0o700)
async def create_rag_dataset(
self,
user_id: str,
dataset_name: str,
description: Optional[str] = None,
chunking_strategy: str = "hybrid",
chunk_size: int = 512,
chunk_overlap: int = 128,
embedding_model: str = "BAAI/bge-m3"
) -> RAGDataset:
"""Create a new RAG dataset with tenant isolation"""
try:
# Check if dataset already exists for this user
existing = await self.db.execute(
select(RAGDataset).where(
and_(
RAGDataset.user_id == user_id,
RAGDataset.dataset_name == dataset_name
)
)
)
if existing.scalar_one_or_none():
raise ValueError(f"Dataset '{dataset_name}' already exists for user")
# Create dataset
dataset = RAGDataset(
user_id=user_id,
dataset_name=dataset_name,
description=description,
chunking_strategy=chunking_strategy,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
embedding_model=embedding_model
)
self.db.add(dataset)
await self.db.commit()
await self.db.refresh(dataset)
logger.info(f"Created RAG dataset '{dataset_name}' for user {user_id}")
return dataset
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to create RAG dataset: {e}")
raise
async def upload_document(
self,
user_id: str,
file_content: bytes,
filename: str,
file_type: str,
dataset_id: Optional[str] = None
) -> Document:
"""Upload and store document with tenant isolation"""
try:
# Validate file
file_extension = Path(filename).suffix.lower()
if not file_extension:
raise ValueError("File must have an extension")
# Generate secure filename
file_hash = hashlib.sha256(file_content).hexdigest()[:16]
secure_filename = f"{user_id}_{file_hash}_{filename}"
# Tenant-specific file path
user_upload_dir = self.upload_directory / user_id
user_upload_dir.mkdir(exist_ok=True, mode=0o700)
file_path = user_upload_dir / secure_filename
# Save file with secure permissions
async with aiofiles.open(file_path, 'wb') as f:
await f.write(file_content)
# Set file permissions (owner read/write only)
os.chmod(file_path, 0o600)
# Create document record
document = Document(
filename=secure_filename,
original_filename=filename,
file_path=str(file_path),
file_type=file_type,
file_extension=file_extension,
file_size_bytes=len(file_content),
uploaded_by=user_id,
processing_status="pending"
)
self.db.add(document)
await self.db.commit()
await self.db.refresh(document)
# Add to dataset if specified
if dataset_id:
await self.add_document_to_dataset(user_id, document.id, dataset_id)
# Clear file content from memory
del file_content
gc.collect()
logger.info(f"Uploaded document '{filename}' for user {user_id}")
return document
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to upload document: {e}")
# Clear sensitive data even on error
if 'file_content' in locals():
del file_content
gc.collect()
raise
async def process_document(
self,
user_id: str,
document_id: int,
tenant_id: str,
chunking_strategy: Optional[str] = None
) -> Dict[str, Any]:
"""Process document into chunks and generate embeddings"""
try:
# Get document with ownership check
document = await self._get_user_document(user_id, document_id)
if not document:
raise PermissionError("Document not found or access denied")
# Check if already processed
if document.is_processing_complete():
logger.info(f"Document {document_id} already processed")
return {"status": "already_processed", "chunk_count": document.chunk_count}
# Mark as processing
document.mark_processing_started()
await self.db.commit()
# Read document file
file_content = await self._read_document_file(document)
# Process document using Resource Cluster (stateless)
chunks = await self.resource_client.process_document(
content=file_content,
document_type=document.file_extension,
strategy_type=chunking_strategy or "hybrid",
tenant_id=tenant_id,
user_id=user_id
)
# Clear file content from memory immediately
del file_content
gc.collect()
if not chunks:
raise ValueError("Document processing returned no chunks")
# Generate embeddings for chunks (stateless)
chunk_texts = [chunk["text"] for chunk in chunks]
embeddings = await self.resource_client.generate_document_embeddings(
documents=chunk_texts,
tenant_id=tenant_id,
user_id=user_id
)
if len(embeddings) != len(chunk_texts):
raise ValueError("Embedding count mismatch with chunk count")
# Store vectors in ChromaDB via Resource Cluster
dataset_name = f"doc_{document.id}"
collection_created = await self.resource_client.create_vector_collection(
tenant_id=tenant_id,
user_id=user_id,
dataset_name=dataset_name
)
if not collection_created:
raise RuntimeError("Failed to create vector collection")
# Store vectors with metadata
chunk_metadata = [chunk["metadata"] for chunk in chunks]
vector_stored = await self.resource_client.store_vectors(
tenant_id=tenant_id,
user_id=user_id,
dataset_name=dataset_name,
documents=chunk_texts,
embeddings=embeddings,
metadata=chunk_metadata
)
if not vector_stored:
raise RuntimeError("Failed to store vectors")
# Clear embedding data from memory
del chunk_texts, embeddings
gc.collect()
# Update document record
vector_store_ids = [f"{tenant_id}:{user_id}:{dataset_name}"]
document.mark_processing_complete(
chunk_count=len(chunks),
vector_store_ids=vector_store_ids
)
await self.db.commit()
logger.info(f"Processed document {document_id} into {len(chunks)} chunks")
return {
"status": "completed",
"document_id": document_id,
"chunk_count": len(chunks),
"vector_store_ids": vector_store_ids
}
except Exception as e:
# Mark document processing as failed
if 'document' in locals() and document:
document.mark_processing_failed({"error": str(e)})
await self.db.commit()
logger.error(f"Failed to process document {document_id}: {e}")
# Ensure memory cleanup
gc.collect()
raise
async def add_document_to_dataset(
self,
user_id: str,
document_id: int,
dataset_id: str
) -> DatasetDocument:
"""Add processed document to RAG dataset"""
try:
# Verify dataset ownership
dataset = await self._get_user_dataset(user_id, dataset_id)
if not dataset:
raise PermissionError("Dataset not found or access denied")
# Verify document ownership
document = await self._get_user_document(user_id, document_id)
if not document:
raise PermissionError("Document not found or access denied")
# Check if already in dataset
existing = await self.db.execute(
select(DatasetDocument).where(
and_(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.document_id == document_id
)
)
)
if existing.scalar_one_or_none():
raise ValueError("Document already in dataset")
# Create dataset document relationship
dataset_doc = DatasetDocument(
dataset_id=dataset_id,
document_id=document_id,
user_id=user_id,
chunk_count=document.chunk_count,
vector_count=document.chunk_count # Assuming 1 vector per chunk
)
self.db.add(dataset_doc)
# Update dataset statistics
dataset.document_count += 1
dataset.chunk_count += document.chunk_count
dataset.vector_count += document.chunk_count
dataset.total_size_bytes += document.file_size_bytes
await self.db.commit()
await self.db.refresh(dataset_doc)
logger.info(f"Added document {document_id} to dataset {dataset_id}")
return dataset_doc
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to add document to dataset: {e}")
raise
async def search_documents(
self,
user_id: str,
tenant_id: str,
query: str,
dataset_ids: Optional[List[str]] = None,
top_k: int = 5,
similarity_threshold: float = 0.7
) -> List[Dict[str, Any]]:
"""Search documents using RAG with tenant isolation"""
try:
# Generate query embedding
query_embeddings = await self.resource_client.generate_query_embeddings(
queries=[query],
tenant_id=tenant_id,
user_id=user_id
)
if not query_embeddings:
raise ValueError("Failed to generate query embedding")
query_embedding = query_embeddings[0]
# Get user's datasets if not specified
if not dataset_ids:
datasets = await self.list_user_datasets(user_id)
dataset_ids = [d.id for d in datasets]
# Search across datasets
all_results = []
for dataset_id in dataset_ids:
# Verify dataset ownership
dataset = await self._get_user_dataset(user_id, dataset_id)
if not dataset:
continue
# Search in ChromaDB
dataset_name = f"dataset_{dataset_id}"
results = await self.resource_client.search_vectors(
tenant_id=tenant_id,
user_id=user_id,
dataset_name=dataset_name,
query_embedding=query_embedding,
top_k=top_k
)
# Filter by similarity threshold and add dataset context
for result in results:
if result.get("similarity", 0) >= similarity_threshold:
result["dataset_id"] = dataset_id
result["dataset_name"] = dataset.dataset_name
all_results.append(result)
# Sort by similarity and limit
all_results.sort(key=lambda x: x.get("similarity", 0), reverse=True)
final_results = all_results[:top_k]
# Clear query embedding from memory
del query_embedding, query_embeddings
gc.collect()
logger.info(f"Search found {len(final_results)} results for user {user_id}")
return final_results
except Exception as e:
logger.error(f"Failed to search documents: {e}")
gc.collect()
raise
async def get_document_context(
self,
user_id: str,
tenant_id: str,
document_id: int,
query: str,
context_size: int = 3
) -> Dict[str, Any]:
"""Get relevant context from a specific document"""
try:
# Verify document ownership
document = await self._get_user_document(user_id, document_id)
if not document:
raise PermissionError("Document not found or access denied")
if not document.is_processing_complete():
raise ValueError("Document not yet processed")
# Generate query embedding
query_embeddings = await self.resource_client.generate_query_embeddings(
queries=[query],
tenant_id=tenant_id,
user_id=user_id
)
query_embedding = query_embeddings[0]
# Search document's vectors
dataset_name = f"doc_{document_id}"
results = await self.resource_client.search_vectors(
tenant_id=tenant_id,
user_id=user_id,
dataset_name=dataset_name,
query_embedding=query_embedding,
top_k=context_size
)
context = {
"document_id": document_id,
"document_name": document.original_filename,
"query": query,
"relevant_chunks": results,
"context_text": "\n\n".join([r["document"] for r in results])
}
# Clear query embedding from memory
del query_embedding, query_embeddings
gc.collect()
return context
except Exception as e:
logger.error(f"Failed to get document context: {e}")
gc.collect()
raise
async def list_user_documents(
self,
user_id: str,
status_filter: Optional[str] = None,
offset: int = 0,
limit: int = 50
) -> List[Document]:
"""List user's documents with optional filtering"""
try:
query = select(Document).where(Document.uploaded_by == user_id)
if status_filter:
query = query.where(Document.processing_status == status_filter)
query = query.order_by(Document.created_at.desc())
query = query.offset(offset).limit(limit)
result = await self.db.execute(query)
documents = result.scalars().all()
return list(documents)
except Exception as e:
logger.error(f"Failed to list user documents: {e}")
raise
async def list_user_datasets(
self,
user_id: str,
include_stats: bool = True
) -> List[RAGDataset]:
"""List user's RAG datasets"""
try:
query = select(RAGDataset).where(RAGDataset.user_id == user_id)
if include_stats:
query = query.options(selectinload(RAGDataset.documents))
query = query.order_by(RAGDataset.created_at.desc())
result = await self.db.execute(query)
datasets = result.scalars().all()
return list(datasets)
except Exception as e:
logger.error(f"Failed to list user datasets: {e}")
raise
async def delete_document(
self,
user_id: str,
tenant_id: str,
document_id: int
) -> bool:
"""Delete document and associated vectors"""
try:
# Verify document ownership
document = await self._get_user_document(user_id, document_id)
if not document:
raise PermissionError("Document not found or access denied")
# Delete vectors from ChromaDB if processed
if document.is_processing_complete():
dataset_name = f"doc_{document_id}"
await self.resource_client.delete_vector_collection(
tenant_id=tenant_id,
user_id=user_id,
dataset_name=dataset_name
)
# Delete physical file
if document.file_exists():
os.remove(document.get_absolute_file_path())
# Delete from database (cascade will handle related records)
await self.db.delete(document)
await self.db.commit()
logger.info(f"Deleted document {document_id} for user {user_id}")
return True
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to delete document: {e}")
raise
async def delete_dataset(
self,
user_id: str,
tenant_id: str,
dataset_id: str
) -> bool:
"""Delete RAG dataset and associated vectors"""
try:
# Verify dataset ownership
dataset = await self._get_user_dataset(user_id, dataset_id)
if not dataset:
raise PermissionError("Dataset not found or access denied")
# Delete vectors from ChromaDB
dataset_name = f"dataset_{dataset_id}"
await self.resource_client.delete_vector_collection(
tenant_id=tenant_id,
user_id=user_id,
dataset_name=dataset_name
)
# Delete from database (cascade will handle related records)
await self.db.delete(dataset)
await self.db.commit()
logger.info(f"Deleted dataset {dataset_id} for user {user_id}")
return True
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to delete dataset: {e}")
raise
async def get_rag_statistics(
self,
user_id: str
) -> Dict[str, Any]:
"""Get RAG usage statistics for user"""
try:
# Document statistics
doc_query = select(Document).where(Document.uploaded_by == user_id)
doc_result = await self.db.execute(doc_query)
documents = doc_result.scalars().all()
# Dataset statistics
dataset_query = select(RAGDataset).where(RAGDataset.user_id == user_id)
dataset_result = await self.db.execute(dataset_query)
datasets = dataset_result.scalars().all()
total_size = sum(doc.file_size_bytes for doc in documents)
total_chunks = sum(doc.chunk_count for doc in documents)
stats = {
"user_id": user_id,
"document_count": len(documents),
"dataset_count": len(datasets),
"total_size_bytes": total_size,
"total_size_mb": round(total_size / (1024 * 1024), 2),
"total_chunks": total_chunks,
"processed_documents": len([d for d in documents if d.is_processing_complete()]),
"pending_documents": len([d for d in documents if d.is_pending_processing()]),
"failed_documents": len([d for d in documents if d.is_processing_failed()])
}
return stats
except Exception as e:
logger.error(f"Failed to get RAG statistics: {e}")
raise
# Private helper methods
async def _get_user_document(self, user_id: str, document_id: int) -> Optional[Document]:
"""Get document with ownership verification"""
result = await self.db.execute(
select(Document).where(
and_(
Document.id == document_id,
Document.uploaded_by == user_id
)
)
)
return result.scalar_one_or_none()
async def _get_user_dataset(self, user_id: str, dataset_id: str) -> Optional[RAGDataset]:
"""Get dataset with ownership verification"""
result = await self.db.execute(
select(RAGDataset).where(
and_(
RAGDataset.id == dataset_id,
RAGDataset.user_id == user_id
)
)
)
return result.scalar_one_or_none()
async def _read_document_file(self, document: Document) -> bytes:
"""Read document file content"""
file_path = document.get_absolute_file_path()
if not os.path.exists(file_path):
raise FileNotFoundError(f"Document file not found: {file_path}")
async with aiofiles.open(file_path, 'rb') as f:
content = await f.read()
return content
# Factory function for dependency injection
async def get_rag_service(db: AsyncSession = None) -> RAGService:
"""Get RAG service instance"""
if db is None:
async with get_db_session() as session:
return RAGService(session)
return RAGService(db)
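# Illustrative end-to-end sketch (never executed): upload, process, and search a
# document against an existing session. The tenant, user id, and file contents
# are hypothetical placeholders.
async def _example_rag_flow(db: AsyncSession) -> None:  # pragma: no cover
    service = RAGService(db)
    document = await service.upload_document(
        user_id="user-123",
        file_content=b"GT 2.0 keeps every tenant's data isolated.",
        filename="notes.txt",
        file_type="text/plain",
    )
    await service.process_document(
        user_id="user-123",
        document_id=document.id,
        tenant_id="acme.example",
    )
    results = await service.search_documents(
        user_id="user-123",
        tenant_id="acme.example",
        query="tenant isolation",
        top_k=3,
    )
    logger.debug("Example search returned %d results", len(results))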

View File

@@ -0,0 +1,371 @@
"""
Resource Cluster Client for GT 2.0 Tenant Backend
Handles communication with the Resource Cluster for AI/ML operations.
Manages capability token generation and LLM inference requests.
"""
import httpx
import json
import logging
from typing import Dict, Any, Optional, AsyncIterator, List
from datetime import datetime, timedelta
import asyncio
from jose import jwt
from app.core.config import get_settings
logger = logging.getLogger(__name__)
async def fetch_model_rate_limit(
tenant_id: str,
model_id: str,
control_panel_url: str
) -> int:
"""
Fetch rate limit for a model from Control Panel API.
Returns requests_per_minute (converted from max_requests_per_hour in database).
Fails fast if Control Panel is unreachable (GT 2.0 principle: no fallbacks).
Args:
tenant_id: Tenant identifier
model_id: Model identifier
control_panel_url: Control Panel API base URL
Returns:
Requests per minute limit
Raises:
RuntimeError: If Control Panel API is unreachable
"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
url = f"{control_panel_url}/api/v1/tenant-models/tenants/{tenant_id}/models/{model_id}"
logger.debug(f"Fetching rate limit from Control Panel: {url}")
response = await client.get(url)
if response.status_code == 404:
logger.warning(f"Model {model_id} not configured for tenant {tenant_id}, using default")
return 1000 # Default: 1000 requests/minute
response.raise_for_status()
config = response.json()
# API now returns requests_per_minute directly (translated from DB per-hour)
rate_limits = config.get("rate_limits", {})
requests_per_minute = rate_limits.get("requests_per_minute", 1000)
logger.info(f"Model {model_id} rate limit: {requests_per_minute} requests/minute")
return requests_per_minute
except httpx.HTTPStatusError as e:
logger.error(f"Control Panel API error: {e.response.status_code}")
raise RuntimeError(f"Failed to fetch rate limit: HTTP {e.response.status_code}")
except httpx.RequestError as e:
logger.error(f"Control Panel API unreachable: {e}")
raise RuntimeError(f"Control Panel unreachable at {control_panel_url}")
except Exception as e:
logger.error(f"Unexpected error fetching rate limit: {e}")
raise RuntimeError(f"Failed to fetch rate limit: {e}")
class ResourceClusterClient:
"""Client for communicating with GT 2.0 Resource Cluster"""
def __init__(self):
self.settings = get_settings()
self.resource_cluster_url = self.settings.resource_cluster_url
self.secret_key = self.settings.secret_key
self.algorithm = "HS256"
self.client = httpx.AsyncClient(timeout=60.0)
async def generate_capability_token(
self,
user_id: str,
tenant_id: str,
assistant_config: Dict[str, Any],
expires_minutes: int = 30
) -> str:
"""
Generate capability token for resource access.
Fetches real rate limits from Control Panel (single source of truth).
Fails fast if Control Panel is unreachable.
"""
# Extract capabilities from agent configuration
capabilities = []
# Add LLM capability with real rate limit from Control Panel
model = assistant_config.get("resource_preferences", {}).get("primary_llm", "llama-3.1-70b-versatile")
# Fetch real rate limit from Control Panel API
requests_per_minute = await fetch_model_rate_limit(
tenant_id=tenant_id,
model_id=model,
control_panel_url=self.settings.control_panel_url
)
capabilities.append({
"resource": f"llm:groq",
"actions": ["inference", "streaming"],
"constraints": { # Changed from "limits" to match LLM gateway expectations
"max_tokens_per_request": assistant_config.get("resource_preferences", {}).get("max_tokens", 4000),
"max_requests_per_minute": requests_per_minute # Real limit from database (converted from per-hour)
}
})
# Add RAG capabilities if configured
if assistant_config.get("capabilities", {}).get("rag_enabled"):
capabilities.append({
"resource": "rag:semantic_search",
"actions": ["search", "retrieve"],
"limits": {
"max_results": 10
}
})
# Add embedding capability if RAG is enabled
if assistant_config.get("capabilities", {}).get("embeddings_enabled"):
capabilities.append({
"resource": "embedding:text-embedding-3-small",
"actions": ["generate"],
"limits": {
"max_texts_per_request": 100
}
})
# Create token payload
payload = {
"sub": user_id,
"tenant_id": tenant_id,
"capabilities": capabilities,
"exp": asyncio.get_event_loop().time() + (expires_minutes * 60),
"iat": asyncio.get_event_loop().time()
}
# Sign token
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
return token
async def execute_inference(
self,
prompt: str,
assistant_config: Dict[str, Any],
user_id: str,
tenant_id: str,
stream: bool = False,
conversation_context: Optional[List[Dict[str, str]]] = None
) -> Dict[str, Any]:
"""Execute LLM inference via Resource Cluster"""
# Generate capability token (now async - fetches real rate limits)
token = await self.generate_capability_token(user_id, tenant_id, assistant_config)
# Prepare request
model = assistant_config.get("resource_preferences", {}).get("primary_llm", "llama-3.1-70b-versatile")
temperature = assistant_config.get("resource_preferences", {}).get("temperature", 0.7)
max_tokens = assistant_config.get("resource_preferences", {}).get("max_tokens", 4000)
# Build messages array with system prompt
messages = []
# Add system prompt from agent
system_prompt = assistant_config.get("prompt", "You are a helpful AI agent.")
messages.append({"role": "system", "content": system_prompt})
# Add conversation context if provided
if conversation_context:
messages.extend(conversation_context)
# Add current user message
messages.append({"role": "user", "content": prompt})
# Prepare request payload
request_data = {
"messages": messages,
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
"user_id": user_id,
"tenant_id": tenant_id
}
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
try:
if stream:
                return self._stream_inference(request_data, headers)  # async generator; caller iterates the stream
else:
response = await self.client.post(
f"{self.resource_cluster_url}/api/v1/inference/",
json=request_data,
headers=headers
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
logger.error(f"HTTP error during inference: {e}")
raise
except Exception as e:
logger.error(f"Error during inference: {e}")
raise
async def _stream_inference(
self,
request_data: Dict[str, Any],
headers: Dict[str, str]
) -> AsyncIterator[str]:
"""Stream inference responses"""
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.resource_cluster_url}/api/v1/inference/stream",
json=request_data,
headers=headers,
timeout=60.0
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:] # Remove "data: " prefix
if data == "[DONE]":
break
try:
chunk = json.loads(data)
if "content" in chunk:
yield chunk["content"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse streaming chunk: {data}")
continue
async def generate_embeddings(
self,
texts: List[str],
user_id: str,
tenant_id: str,
model: str = "text-embedding-3-small"
) -> List[List[float]]:
"""Generate embeddings for texts"""
# Generate capability token with embedding permission
assistant_config = {"capabilities": {"embeddings_enabled": True}}
        token = await self.generate_capability_token(user_id, tenant_id, assistant_config)
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
request_data = {
"texts": texts,
"model": model
}
try:
response = await self.client.post(
f"{self.resource_cluster_url}/api/v1/embeddings/",
json=request_data,
headers=headers
)
response.raise_for_status()
result = response.json()
return result.get("embeddings", [])
except httpx.HTTPError as e:
logger.error(f"HTTP error during embedding generation: {e}")
raise
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise
async def search_rag(
self,
query: str,
collection: str,
user_id: str,
tenant_id: str,
top_k: int = 5
) -> List[Dict[str, Any]]:
"""Search RAG collection for relevant documents"""
# Generate capability token with RAG permission
assistant_config = {"capabilities": {"rag_enabled": True}}
        token = await self.generate_capability_token(user_id, tenant_id, assistant_config)
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
request_data = {
"query": query,
"collection": collection,
"top_k": top_k
}
try:
response = await self.client.post(
f"{self.resource_cluster_url}/api/v1/rag/search",
json=request_data,
headers=headers
)
response.raise_for_status()
result = response.json()
return result.get("results", [])
except httpx.HTTPError as e:
logger.error(f"HTTP error during RAG search: {e}")
# Return empty results on error for now
return []
except Exception as e:
logger.error(f"Error searching RAG: {e}")
return []
async def get_agent_templates(
self,
user_id: str,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Get available agent templates from Resource Cluster"""
# Generate basic capability token
        token = await self.generate_capability_token(user_id, tenant_id, {})
headers = {
"Authorization": f"Bearer {token}"
}
try:
response = await self.client.get(
f"{self.resource_cluster_url}/api/v1/templates/",
headers=headers
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
logger.error(f"HTTP error fetching templates: {e}")
return []
except Exception as e:
logger.error(f"Error fetching templates: {e}")
return []
async def close(self):
"""Close the HTTP client"""
await self.client.aclose()
async def __aenter__(self):
"""Async context manager entry"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
await self.close()
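# Illustrative usage sketch (never executed): the client doubles as an async
# context manager, so its httpx client is closed automatically. The agent
# configuration and identifiers below are hypothetical placeholders.
async def _example_inference() -> None:  # pragma: no cover
    assistant_config = {
        "prompt": "You are a helpful AI agent.",
        "resource_preferences": {
            "primary_llm": "llama-3.1-70b-versatile",
            "max_tokens": 512,
        },
    }
    async with ResourceClusterClient() as client:
        result = await client.execute_inference(
            prompt="Summarize GT 2.0 tenant isolation in one sentence.",
            assistant_config=assistant_config,
            user_id="user-123",
            tenant_id="acme.example",
            stream=False,
        )
        logger.debug("Inference response keys: %s", list(result.keys()))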

View File

@@ -0,0 +1,173 @@
"""
Resource Service for GT 2.0 Tenant Backend
Provides access to Resource Cluster capabilities and services.
This is a wrapper around the resource_cluster_client for agent services.
"""
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from app.core.resource_client import ResourceClusterClient
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class ResourceService:
"""Service for accessing Resource Cluster capabilities"""
def __init__(self):
"""Initialize resource service"""
self.settings = get_settings()
self.client = ResourceClusterClient()
async def get_available_models(self, user_id: str) -> List[Dict[str, Any]]:
"""Get available AI models for user from Resource Cluster"""
try:
# Get models from Resource Cluster via capability token
token = await self.client._get_capability_token(
tenant_id=self.settings.tenant_domain,
user_id=user_id,
resources=['model_registry']
)
import aiohttp
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json',
'X-Tenant-ID': self.settings.tenant_domain,
'X-User-ID': user_id
}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.client.base_url}/api/v1/models/",
headers=headers,
timeout=aiohttp.ClientTimeout(total=10)
) as response:
if response.status == 200:
response_data = await response.json()
models_data = response_data.get("models", [])
# Transform to expected format
available_models = []
for model in models_data:
if model.get("status", {}).get("deployment") == "available":
available_models.append({
"model_id": model["id"],
"name": model["name"],
"provider": model["provider"],
"capabilities": ["chat", "completion"],
"context_length": model.get("performance", {}).get("context_window", 4000),
"available": True
})
logger.info(f"Retrieved {len(available_models)} models from Resource Cluster")
return available_models
else:
logger.error(f"Resource Cluster returned {response.status}")
return []
except Exception as e:
logger.error(f"Failed to get available models from Resource Cluster: {e}")
return []
async def get_available_tools(self, user_id: str) -> List[Dict[str, Any]]:
"""Get available tools for user"""
try:
# Mock tools for development
return [
{
"tool_id": "web_search",
"name": "Web Search",
"description": "Search the web for information",
"available": True,
"capabilities": ["search", "retrieve"]
},
{
"tool_id": "document_analysis",
"name": "Document Analysis",
"description": "Analyze documents and extract information",
"available": True,
"capabilities": ["analyze", "extract", "summarize"]
},
{
"tool_id": "code_execution",
"name": "Code Execution",
"description": "Execute code in safe sandbox",
"available": True,
"capabilities": ["execute", "debug", "test"]
}
]
except Exception as e:
logger.error(f"Failed to get available tools: {e}")
return []
async def validate_capabilities(self, user_id: str, capabilities: List[str]) -> bool:
"""Validate that user has access to required capabilities"""
try:
# For development, allow all capabilities
logger.info(f"Validating capabilities {capabilities} for user {user_id}")
return True
except Exception as e:
logger.error(f"Failed to validate capabilities: {e}")
return False
async def execute_agent_task(self,
agent_id: str,
task_description: str,
parameters: Dict[str, Any],
user_id: str) -> Dict[str, Any]:
"""Execute an agent task via Resource Cluster"""
try:
# Mock execution for development
execution_result = {
"execution_id": f"exec_{agent_id}_{int(datetime.now().timestamp())}",
"status": "completed",
"result": f"Mock execution of task: {task_description}",
"output_artifacts": [],
"tokens_used": 150,
"cost_cents": 1,
"execution_time_ms": 2500
}
logger.info(f"Mock agent execution: {execution_result['execution_id']}")
return execution_result
except Exception as e:
logger.error(f"Failed to execute agent task: {e}")
return {
"execution_id": f"failed_{agent_id}",
"status": "failed",
"error": str(e)
}
async def get_resource_usage(self, user_id: str, timeframe_hours: int = 24) -> Dict[str, Any]:
"""Get resource usage statistics for user"""
try:
# Mock usage data for development
return {
"total_requests": 25,
"total_tokens": 15000,
"total_cost_cents": 150,
"execution_count": 12,
"average_response_time_ms": 1250,
"timeframe_hours": timeframe_hours
}
except Exception as e:
logger.error(f"Failed to get resource usage: {e}")
return {}
async def check_rate_limits(self, user_id: str) -> Dict[str, Any]:
"""Check current rate limits for user"""
try:
# Mock rate limit data for development
return {
"requests_per_minute": {"current": 5, "limit": 60, "reset_time": "2024-01-01T00:00:00Z"},
"tokens_per_hour": {"current": 2500, "limit": 50000, "reset_time": "2024-01-01T00:00:00Z"},
"executions_per_day": {"current": 12, "limit": 1000, "reset_time": "2024-01-01T00:00:00Z"}
}
except Exception as e:
logger.error(f"Failed to check rate limits: {e}")
return {}

View File

@@ -0,0 +1,385 @@
"""
GT 2.0 Summarization Service
Provides AI-powered summarization capabilities for documents and datasets.
Uses the same pattern as conversation title generation with Llama 3.1 8B.
"""
import logging
import asyncio
from typing import Dict, Any, List, Optional
from datetime import datetime
from app.core.postgresql_client import get_postgresql_client
from app.core.resource_client import ResourceClusterClient
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class SummarizationService:
"""
Service for generating AI summaries of documents and datasets.
Uses the same approach as conversation title generation:
- Llama 3.1 8B instant model
- Low temperature for consistency
- Resource cluster for AI responses
"""
def __init__(self, tenant_domain: str, user_id: str):
self.tenant_domain = tenant_domain
self.user_id = user_id
self.resource_client = ResourceClusterClient()
self.summarization_model = "llama-3.1-8b-instant"
self.settings = get_settings()
async def generate_document_summary(
self,
document_id: str,
document_content: str,
document_name: str
) -> Optional[str]:
"""
Generate AI summary for a document using Llama 3.1 8B.
Args:
document_id: UUID of the document
document_content: Full text content of the document
document_name: Original filename/name of the document
Returns:
Generated summary string or None if failed
"""
try:
# Truncate content to first 3000 chars (like conversation title generation)
content_preview = document_content[:3000]
# Create summarization prompt
prompt = f"""Summarize this document '{document_name}' in 2-3 sentences.
Focus on the main topics, key information, and purpose of the document.
Document content:
{content_preview}
Summary:"""
logger.info(f"Generating summary for document {document_id} ({document_name})")
# Call Resource Cluster with same pattern as conversation titles
summary = await self._call_ai_for_summary(
prompt=prompt,
context_type="document",
max_tokens=150
)
if summary:
# Store summary in database
await self._store_document_summary(document_id, summary)
logger.info(f"Generated summary for document {document_id}: {summary[:100]}...")
return summary
else:
logger.warning(f"Failed to generate summary for document {document_id}")
return None
except Exception as e:
logger.error(f"Error generating document summary for {document_id}: {e}")
return None
async def generate_dataset_summary(self, dataset_id: str) -> Optional[str]:
"""
Generate AI summary for a dataset based on its document summaries.
Args:
dataset_id: UUID of the dataset
Returns:
Generated dataset summary or None if failed
"""
try:
# Get all document summaries in this dataset
document_summaries = await self._get_document_summaries_for_dataset(dataset_id)
if not document_summaries:
logger.info(f"No document summaries found for dataset {dataset_id}")
return None
# Get dataset name for context
dataset_info = await self._get_dataset_info(dataset_id)
dataset_name = dataset_info.get('name', 'Unknown Dataset') if dataset_info else 'Unknown Dataset'
# Combine summaries for LLM context
combined_summaries = "\n".join([
f"- {doc['filename']}: {doc['summary']}"
for doc in document_summaries
if doc['summary'] # Only include docs that have summaries
])
if not combined_summaries.strip():
logger.info(f"No valid document summaries for dataset {dataset_id}")
return None
# Create dataset summarization prompt
prompt = f"""Based on these document summaries, create a comprehensive 3-4 sentence summary describing what the dataset '{dataset_name}' contains and its purpose:
Documents in dataset:
{combined_summaries}
Dataset summary:"""
logger.info(f"Generating summary for dataset {dataset_id} ({dataset_name})")
# Call AI for dataset summary
summary = await self._call_ai_for_summary(
prompt=prompt,
context_type="dataset",
max_tokens=200
)
if summary:
# Store dataset summary in database
await self._store_dataset_summary(dataset_id, summary)
logger.info(f"Generated dataset summary for {dataset_id}: {summary[:100]}...")
return summary
else:
logger.warning(f"Failed to generate summary for dataset {dataset_id}")
return None
except Exception as e:
logger.error(f"Error generating dataset summary for {dataset_id}: {e}")
return None
async def update_dataset_summary_on_change(self, dataset_id: str) -> bool:
"""
Regenerate dataset summary when documents are added/removed.
Args:
dataset_id: UUID of the dataset to update
Returns:
True if summary was updated successfully
"""
try:
summary = await self.generate_dataset_summary(dataset_id)
return summary is not None
except Exception as e:
logger.error(f"Error updating dataset summary for {dataset_id}: {e}")
return False
async def _call_ai_for_summary(
self,
prompt: str,
context_type: str,
max_tokens: int = 150
) -> Optional[str]:
"""
Call Resource Cluster for AI summary generation.
Uses ResourceClusterClient for consistent service discovery.
Args:
prompt: The summarization prompt
context_type: Type of summary (document, dataset)
max_tokens: Maximum tokens to generate
Returns:
Generated summary text or None if failed
"""
try:
# Prepare request payload (same format as conversation service)
request_data = {
"model": self.summarization_model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3, # Lower temperature for consistent summaries
"max_tokens": max_tokens,
"top_p": 1.0
}
logger.info(f"Calling Resource Cluster for {context_type} summary generation")
# Use ResourceClusterClient for consistent service discovery and auth
result = await self.resource_client.call_inference_endpoint(
tenant_id=self.tenant_domain,
user_id=self.user_id,
endpoint="chat/completions",
data=request_data
)
if result and "choices" in result and len(result["choices"]) > 0:
summary = result["choices"][0]["message"]["content"].strip()
logger.info(f"✅ AI {context_type} summary generated successfully: {summary[:50]}...")
return summary
else:
logger.error(f"❌ Invalid AI response format for {context_type} summary: {result}")
return None
except Exception as e:
logger.error(f"❌ Error calling Resource Cluster for {context_type} summary: {e}", exc_info=True)
return None
async def _store_document_summary(self, document_id: str, summary: str) -> None:
"""Store document summary in database"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
schema_name = self.settings.postgres_schema
await conn.execute(f"""
UPDATE {schema_name}.documents
SET summary = $1,
summary_generated_at = $2,
summary_model = $3
WHERE id = $4
""", summary, datetime.now(), self.summarization_model, document_id)
except Exception as e:
logger.error(f"Error storing document summary for {document_id}: {e}")
raise
async def _store_dataset_summary(self, dataset_id: str, summary: str) -> None:
"""Store dataset summary in database"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
schema_name = self.settings.postgres_schema
await conn.execute(f"""
UPDATE {schema_name}.datasets
SET summary = $1,
summary_generated_at = $2,
summary_model = $3
WHERE id = $4
""", summary, datetime.now(), self.summarization_model, dataset_id)
except Exception as e:
logger.error(f"Error storing dataset summary for {dataset_id}: {e}")
raise
async def _get_document_summaries_for_dataset(self, dataset_id: str) -> List[Dict[str, Any]]:
"""Get all document summaries for a dataset"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
schema_name = self.settings.postgres_schema
rows = await conn.fetch(f"""
SELECT id, filename, original_filename, summary, summary_generated_at
FROM {schema_name}.documents
WHERE dataset_id = $1
AND summary IS NOT NULL
AND summary != ''
ORDER BY created_at ASC
""", dataset_id)
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"Error getting document summaries for dataset {dataset_id}: {e}")
return []
async def _get_dataset_info(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get basic dataset information"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
schema_name = self.settings.postgres_schema
row = await conn.fetchrow(f"""
SELECT id, name, description
FROM {schema_name}.datasets
WHERE id = $1
""", dataset_id)
return dict(row) if row else None
except Exception as e:
logger.error(f"Error getting dataset info for {dataset_id}: {e}")
return None
async def get_datasets_with_summaries(self, user_id: str) -> List[Dict[str, Any]]:
"""
Get all user-accessible datasets with their summaries.
Used for context injection in chat.
Args:
user_id: UUID of the user
Returns:
List of datasets with summaries
"""
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
schema_name = self.settings.postgres_schema
rows = await conn.fetch(f"""
SELECT id, name, description, summary, summary_generated_at,
document_count, total_size_bytes
FROM {schema_name}.datasets
WHERE (created_by = $1::uuid
OR access_group IN ('team', 'organization'))
AND is_active = true
ORDER BY name ASC
""", user_id)
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"Error getting datasets with summaries for user {user_id}: {e}")
return []
async def get_filtered_datasets_with_summaries(
self,
user_id: str,
allowed_dataset_ids: List[str]
) -> List[Dict[str, Any]]:
"""
Get datasets with summaries filtered by allowed dataset IDs.
Used for agent-aware context injection in chat.
Args:
user_id: User UUID string
allowed_dataset_ids: List of dataset IDs the agent/user should see
Returns:
List of dataset dictionaries with summaries, filtered by allowed IDs
"""
if not allowed_dataset_ids:
logger.info(f"No allowed dataset IDs provided for user {user_id} - returning empty list")
return []
try:
client = await get_postgresql_client()
async with client.get_connection() as conn:
schema_name = self.settings.postgres_schema
# Convert dataset IDs to UUID format for query
placeholders = ",".join(f"${i+2}::uuid" for i in range(len(allowed_dataset_ids)))
query = f"""
SELECT id, name, description, summary, summary_generated_at,
document_count, total_size_bytes
FROM {schema_name}.datasets
WHERE (created_by = $1::uuid
OR access_group IN ('team', 'organization'))
AND is_active = true
AND id = ANY(ARRAY[{placeholders}])
ORDER BY name ASC
"""
params = [user_id] + allowed_dataset_ids
rows = await conn.fetch(query, *params)
filtered_datasets = [dict(row) for row in rows]
logger.info(f"Filtered datasets for user {user_id}: {len(filtered_datasets)} out of {len(allowed_dataset_ids)} requested")
return filtered_datasets
except Exception as e:
logger.error(f"Error getting filtered datasets with summaries for user {user_id}: {e}")
return []
# Factory function for dependency injection
def get_summarization_service(tenant_domain: str, user_id: str) -> SummarizationService:
"""Factory function to create SummarizationService instance"""
return SummarizationService(tenant_domain, user_id)
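# Illustrative sketch (never executed): summarize a single document and refresh
# its parent dataset's summary. The UUIDs and document text are hypothetical.
async def _example_summaries() -> None:  # pragma: no cover
    service = get_summarization_service("acme.example", "user-123")
    summary = await service.generate_document_summary(
        document_id="00000000-0000-0000-0000-000000000001",
        document_content="Quarterly report covering revenue, churn, and hiring...",
        document_name="q3_report.pdf",
    )
    if summary:
        await service.update_dataset_summary_on_change(
            "00000000-0000-0000-0000-000000000002"
        )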

View File

@@ -0,0 +1,478 @@
"""
GT 2.0 Task Classifier Service
Analyzes user queries to determine task complexity and required subagent orchestration.
Enables highly agentic behavior by intelligently routing tasks to specialized subagents.
"""
import logging
import re
from typing import Dict, Any, List, Optional, Tuple
from enum import Enum
from dataclasses import dataclass
logger = logging.getLogger(__name__)
class TaskComplexity(str, Enum):
"""Task complexity levels"""
SIMPLE = "simple" # Direct response, no tools needed
TOOL_ASSISTED = "tool_assisted" # Single tool call required
MULTI_STEP = "multi_step" # Multiple sequential steps
RESEARCH = "research" # Information gathering from multiple sources
IMPLEMENTATION = "implementation" # Code/config changes
COMPLEX = "complex" # Requires multiple subagents
class SubagentType(str, Enum):
"""Types of specialized subagents"""
RESEARCH = "research" # Information gathering
PLANNING = "planning" # Task decomposition
IMPLEMENTATION = "implementation" # Execution
VALIDATION = "validation" # Quality checks
SYNTHESIS = "synthesis" # Result aggregation
MONITOR = "monitor" # Status checking
ANALYST = "analyst" # Data analysis
@dataclass
class TaskClassification:
"""Result of task classification"""
complexity: TaskComplexity
confidence: float
primary_intent: str
subagent_plan: List[Dict[str, Any]]
estimated_tools: List[str]
parallel_execution: bool
requires_confirmation: bool
reasoning: str
@dataclass
class SubagentTask:
"""Task definition for a subagent"""
subagent_type: SubagentType
task_description: str
required_tools: List[str]
depends_on: List[str] # IDs of other subagent tasks
priority: int
timeout_seconds: int
input_data: Optional[Dict[str, Any]] = None
class TaskClassifier:
"""
Classifies user tasks and creates subagent execution plans.
Analyzes query patterns, identifies required capabilities,
and orchestrates multi-agent workflows for complex tasks.
"""
def __init__(self):
# Pattern matchers for different task types
self.research_patterns = [
r"find\s+(?:all\s+)?(?:information|documents?|files?)\s+about",
r"search\s+for",
r"what\s+(?:is|are|does|do)",
r"explain\s+(?:how|what|why)",
r"list\s+(?:all\s+)?the",
r"show\s+me\s+(?:all\s+)?(?:the\s+)?",
r"check\s+(?:the\s+)?(?:recent|latest|current)",
]
self.implementation_patterns = [
r"(?:create|add|implement|build|write)\s+(?:a\s+)?(?:new\s+)?",
r"(?:update|modify|change|edit|fix)\s+(?:the\s+)?",
r"(?:delete|remove|clean\s+up)\s+(?:the\s+)?",
r"(?:deploy|install|configure|setup)\s+",
r"(?:refactor|optimize|improve)\s+",
]
self.analysis_patterns = [
r"analyze\s+(?:the\s+)?",
r"compare\s+(?:the\s+)?",
r"summarize\s+(?:the\s+)?",
r"evaluate\s+(?:the\s+)?",
r"review\s+(?:the\s+)?",
r"identify\s+(?:patterns|trends|issues)",
]
self.multi_step_indicators = [
r"(?:and\s+then|after\s+that|followed\s+by)",
r"(?:first|second|third|finally)",
r"(?:step\s+\d+|phase\s+\d+)",
r"make\s+sure\s+(?:to\s+)?",
r"(?:also|additionally|furthermore)",
r"for\s+(?:each|every|all)\s+",
]
logger.info("Task classifier initialized")
async def classify_task(
self,
query: str,
conversation_context: Optional[List[Dict[str, Any]]] = None,
available_tools: Optional[List[str]] = None
) -> TaskClassification:
"""
Classify a user query and create execution plan.
Args:
query: User's input query
conversation_context: Previous messages for context
available_tools: List of available MCP tools
Returns:
TaskClassification with complexity assessment and execution plan
"""
query_lower = query.lower()
# Analyze query characteristics
is_research = self._matches_patterns(query_lower, self.research_patterns)
is_implementation = self._matches_patterns(query_lower, self.implementation_patterns)
is_analysis = self._matches_patterns(query_lower, self.analysis_patterns)
is_multi_step = self._matches_patterns(query_lower, self.multi_step_indicators)
# Count potential tool requirements
tool_indicators = self._identify_tool_indicators(query_lower)
# Determine complexity
complexity = self._determine_complexity(
is_research, is_implementation, is_analysis, is_multi_step, tool_indicators
)
# Create subagent plan based on complexity
subagent_plan = await self._create_subagent_plan(
query, complexity, is_research, is_implementation, is_analysis, available_tools
)
# Estimate required tools
estimated_tools = self._estimate_required_tools(query_lower, available_tools)
# Determine if parallel execution is possible
parallel_execution = self._can_execute_parallel(subagent_plan)
# Check if confirmation is needed
requires_confirmation = complexity in [TaskComplexity.IMPLEMENTATION, TaskComplexity.COMPLEX]
# Generate reasoning
reasoning = self._generate_reasoning(
query, complexity, is_research, is_implementation, is_analysis, is_multi_step
)
return TaskClassification(
complexity=complexity,
confidence=self._calculate_confidence(complexity, subagent_plan),
primary_intent=self._identify_primary_intent(is_research, is_implementation, is_analysis),
subagent_plan=subagent_plan,
estimated_tools=estimated_tools,
parallel_execution=parallel_execution,
requires_confirmation=requires_confirmation,
reasoning=reasoning
)
def _matches_patterns(self, text: str, patterns: List[str]) -> bool:
"""Check if text matches any of the patterns"""
for pattern in patterns:
if re.search(pattern, text, re.IGNORECASE):
return True
return False
def _identify_tool_indicators(self, query: str) -> List[str]:
"""Identify potential tool usage from query"""
indicators = []
tool_keywords = {
"search": ["search", "find", "look for", "locate"],
"database": ["database", "query", "sql", "records"],
"file": ["file", "document", "upload", "download"],
"api": ["api", "endpoint", "service", "integration"],
"conversation": ["conversation", "chat", "history", "previous"],
"web": ["website", "url", "browse", "fetch"],
}
for tool_type, keywords in tool_keywords.items():
if any(keyword in query for keyword in keywords):
indicators.append(tool_type)
return indicators
def _determine_complexity(
self,
is_research: bool,
is_implementation: bool,
is_analysis: bool,
is_multi_step: bool,
tool_indicators: List[str]
) -> TaskComplexity:
"""Determine task complexity based on characteristics"""
# Count complexity factors
factors = sum([is_research, is_implementation, is_analysis, is_multi_step])
tool_count = len(tool_indicators)
        if factors == 0 and tool_count == 0:
            return TaskComplexity.SIMPLE
        elif factors == 1 and tool_count <= 1:
            return TaskComplexity.TOOL_ASSISTED
        elif factors > 2 or (is_multi_step and is_implementation):
            # Checked before the multi-step branch so this case is reachable
            return TaskComplexity.COMPLEX
        elif is_multi_step or factors >= 2:
            if is_implementation:
                return TaskComplexity.IMPLEMENTATION
            elif is_research and (is_analysis or tool_count > 2):
                return TaskComplexity.RESEARCH
            else:
                return TaskComplexity.MULTI_STEP
        else:
            return TaskComplexity.TOOL_ASSISTED
async def _create_subagent_plan(
self,
query: str,
complexity: TaskComplexity,
is_research: bool,
is_implementation: bool,
is_analysis: bool,
available_tools: Optional[List[str]]
) -> List[Dict[str, Any]]:
"""Create execution plan with subagents"""
plan = []
if complexity == TaskComplexity.SIMPLE:
# No subagents needed
return []
elif complexity == TaskComplexity.TOOL_ASSISTED:
# Single subagent for tool execution
plan.append({
"id": "tool_executor_1",
"type": SubagentType.IMPLEMENTATION,
"task": f"Execute required tool for: {query[:100]}",
"depends_on": [],
"priority": 1
})
elif complexity == TaskComplexity.RESEARCH:
# Research workflow
plan.extend([
{
"id": "researcher_1",
"type": SubagentType.RESEARCH,
"task": f"Gather information about: {query[:100]}",
"depends_on": [],
"priority": 1
},
{
"id": "analyst_1",
"type": SubagentType.ANALYST,
"task": "Analyze gathered information",
"depends_on": ["researcher_1"],
"priority": 2
},
{
"id": "synthesizer_1",
"type": SubagentType.SYNTHESIS,
"task": "Compile findings into comprehensive response",
"depends_on": ["analyst_1"],
"priority": 3
}
])
elif complexity == TaskComplexity.IMPLEMENTATION:
# Implementation workflow
plan.extend([
{
"id": "planner_1",
"type": SubagentType.PLANNING,
"task": f"Create implementation plan for: {query[:100]}",
"depends_on": [],
"priority": 1
},
{
"id": "implementer_1",
"type": SubagentType.IMPLEMENTATION,
"task": "Execute implementation steps",
"depends_on": ["planner_1"],
"priority": 2
},
{
"id": "validator_1",
"type": SubagentType.VALIDATION,
"task": "Validate implementation results",
"depends_on": ["implementer_1"],
"priority": 3
}
])
elif complexity in [TaskComplexity.MULTI_STEP, TaskComplexity.COMPLEX]:
# Complex multi-agent workflow
if is_research:
plan.append({
"id": "researcher_1",
"type": SubagentType.RESEARCH,
"task": "Research required information",
"depends_on": [],
"priority": 1
})
plan.append({
"id": "planner_1",
"type": SubagentType.PLANNING,
"task": f"Decompose complex task: {query[:100]}",
"depends_on": ["researcher_1"] if is_research else [],
"priority": 2
})
if is_implementation:
plan.append({
"id": "implementer_1",
"type": SubagentType.IMPLEMENTATION,
"task": "Execute planned steps",
"depends_on": ["planner_1"],
"priority": 3
})
if is_analysis:
plan.append({
"id": "analyst_1",
"type": SubagentType.ANALYST,
"task": "Analyze results and patterns",
"depends_on": ["implementer_1"] if is_implementation else ["planner_1"],
"priority": 4
})
# Always add synthesis for complex tasks
final_deps = []
if is_analysis:
final_deps.append("analyst_1")
elif is_implementation:
final_deps.append("implementer_1")
else:
final_deps.append("planner_1")
plan.append({
"id": "synthesizer_1",
"type": SubagentType.SYNTHESIS,
"task": "Synthesize all results into final response",
"depends_on": final_deps,
"priority": 5
})
return plan
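    # Illustrative example (hypothetical query, not taken from elsewhere in this
    # module): a RESEARCH-classified query such as "find all documents about
    # onboarding and summarize them" yields a three-step plan:
    #   researcher_1 (priority 1) -> analyst_1 (priority 2, depends_on researcher_1)
    #   -> synthesizer_1 (priority 3, depends_on analyst_1).
    # Because each priority level holds a single agent, _can_execute_parallel()
    # returns False for this plan.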
def _estimate_required_tools(
self,
query: str,
available_tools: Optional[List[str]]
) -> List[str]:
"""Estimate which tools will be needed"""
if not available_tools:
return []
estimated = []
# Map query patterns to tools
tool_patterns = {
"search_datasets": ["search", "find", "look for", "dataset", "document"],
"brave_search": ["web", "internet", "online", "website", "current"],
"list_directory": ["files", "directory", "folder", "ls"],
"read_file": ["read", "view", "open", "file content"],
"write_file": ["write", "create", "save", "generate file"],
}
for tool in available_tools:
if tool in tool_patterns:
if any(pattern in query for pattern in tool_patterns[tool]):
estimated.append(tool)
return estimated
def _can_execute_parallel(self, subagent_plan: List[Dict[str, Any]]) -> bool:
"""Check if any subagents can run in parallel"""
if len(subagent_plan) < 2:
return False
# Group by priority to find parallel opportunities
priority_groups = {}
for agent in subagent_plan:
priority = agent.get("priority", 1)
if priority not in priority_groups:
priority_groups[priority] = []
priority_groups[priority].append(agent)
# If any priority level has multiple agents, parallel execution is possible
return any(len(agents) > 1 for agents in priority_groups.values())
def _calculate_confidence(
self,
complexity: TaskComplexity,
subagent_plan: List[Dict[str, Any]]
) -> float:
"""Calculate confidence score for classification"""
base_confidence = {
TaskComplexity.SIMPLE: 0.95,
TaskComplexity.TOOL_ASSISTED: 0.9,
TaskComplexity.MULTI_STEP: 0.85,
TaskComplexity.RESEARCH: 0.85,
TaskComplexity.IMPLEMENTATION: 0.8,
TaskComplexity.COMPLEX: 0.75
}
confidence = base_confidence.get(complexity, 0.7)
# Adjust based on plan clarity
if len(subagent_plan) > 0:
confidence += 0.05
return min(confidence, 1.0)
def _identify_primary_intent(
self,
is_research: bool,
is_implementation: bool,
is_analysis: bool
) -> str:
"""Identify the primary intent of the query"""
if is_implementation:
return "implementation"
elif is_research:
return "research"
elif is_analysis:
return "analysis"
else:
return "general"
def _generate_reasoning(
self,
query: str,
complexity: TaskComplexity,
is_research: bool,
is_implementation: bool,
is_analysis: bool,
is_multi_step: bool
) -> str:
"""Generate reasoning explanation for classification"""
reasons = []
if is_multi_step:
reasons.append("Query indicates multiple sequential steps")
if is_research:
reasons.append("Information gathering required")
if is_implementation:
reasons.append("Code or configuration changes needed")
if is_analysis:
reasons.append("Data analysis and synthesis required")
if complexity == TaskComplexity.COMPLEX:
reasons.append("Multiple specialized agents needed for comprehensive execution")
elif complexity == TaskComplexity.SIMPLE:
reasons.append("Straightforward query with direct response possible")
return ". ".join(reasons) if reasons else "Standard query processing"
# Factory function
def get_task_classifier() -> TaskClassifier:
"""Get task classifier instance"""
return TaskClassifier()
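# --- Usage sketch (illustrative only, not invoked anywhere) ---
# Shows how a caller might classify a query. The query string and tool names
# below are hypothetical examples, not values defined elsewhere in this module.
async def _example_classify_task() -> None:
    classifier = get_task_classifier()
    result = await classifier.classify_task(
        query="Find all documents about onboarding and then summarize them",
        available_tools=["search_datasets", "read_file"],
    )
    # Prints the assessed complexity, primary intent, and confirmation flag
    print(result.complexity, result.primary_intent, result.requires_confirmation)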

View File

@@ -0,0 +1,410 @@
"""
Team Access Control Service for GT 2.0 Tenant Backend
Implements team-based access control with file-based simplicity.
Follows GT 2.0's principle of "Zero Complexity Addition"
- Simple role-based permissions stored in files
- Fast access checks using SQLite indexes
- Perfect tenant isolation maintained
"""
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_
import logging
from app.models.team import Team, TeamRole, OrganizationSettings
from app.models.agent import Agent
from app.models.document import RAGDataset
logger = logging.getLogger(__name__)
class TeamAccessService:
"""Elegant team-based access control following GT 2.0 philosophy"""
def __init__(self, db: AsyncSession):
self.db = db
self._role_cache = {} # Cache role permissions in memory
async def check_team_access(
self,
user_email: str,
resource: Any,
action: str,
user_teams: Optional[List[int]] = None
) -> bool:
"""Check if user has access to perform action on resource
GT 2.0 Design: Simple, fast access checks without complex hierarchies
"""
try:
# Step 1: Check resource ownership (fastest check)
if hasattr(resource, 'created_by') and resource.created_by == user_email:
return True # Owners always have full access
# Step 2: Check visibility-based access
if hasattr(resource, 'visibility'):
# Organization-wide resources
if resource.visibility == "organization":
return self._check_organization_action(action)
# Team resources
if resource.visibility == "team" and resource.tenant_id:
if not user_teams:
user_teams = await self.get_user_teams(user_email)
if resource.tenant_id in user_teams:
return await self._check_team_action(
user_email,
resource.tenant_id,
action
)
# Explicitly shared resources
if hasattr(resource, 'shared_with') and resource.shared_with:
if user_email in resource.shared_with:
return self._check_shared_action(action)
# Step 3: Default deny for private resources not owned by user
return False
except Exception as e:
logger.error(f"Error checking team access: {e}")
return False # Fail closed on errors
async def get_user_teams(self, user_email: str) -> List[int]:
"""Get all teams the user belongs to
GT 2.0: Simple file-based membership check
"""
try:
# Query all active teams
result = await self.db.execute(
select(Team).where(Team.is_active == True)
)
teams = result.scalars().all()
user_team_ids = []
for team in teams:
if team.is_member(user_email):
user_team_ids.append(team.id)
return user_team_ids
except Exception as e:
logger.error(f"Error getting user teams: {e}")
return []
async def get_user_role_in_team(self, user_email: str, team_id: int) -> Optional[str]:
"""Get user's role in a specific team"""
try:
result = await self.db.execute(
select(Team).where(Team.id == team_id)
)
team = result.scalar_one_or_none()
if team:
return team.get_member_role(user_email)
return None
except Exception as e:
logger.error(f"Error getting user role: {e}")
return None
async def get_team_resources(
self,
team_id: int,
resource_type: str,
user_email: str
) -> List[Any]:
"""Get all resources accessible to a team
GT 2.0: Simple visibility-based filtering
"""
try:
if resource_type == "agent":
# Get team and organization agents
result = await self.db.execute(
select(Agent).where(
and_(
Agent.is_active == True,
or_(
and_(
Agent.visibility == "team",
Agent.tenant_id == team_id
),
Agent.visibility == "organization"
)
)
)
)
return result.scalars().all()
elif resource_type == "dataset":
# Get team and organization datasets
result = await self.db.execute(
select(RAGDataset).where(
and_(
RAGDataset.status == "active",
or_(
and_(
RAGDataset.visibility == "team",
RAGDataset.tenant_id == team_id
),
RAGDataset.visibility == "organization"
)
)
)
)
return result.scalars().all()
return []
except Exception as e:
logger.error(f"Error getting team resources: {e}")
return []
async def share_with_team(
self,
resource: Any,
team_id: int,
sharer_email: str
) -> bool:
"""Share a resource with a team
GT 2.0: Simple visibility update, no complex permissions
"""
try:
# Verify sharer owns the resource or has sharing permission
if not self._can_share_resource(resource, sharer_email):
return False
# Update resource visibility
resource.visibility = "team"
resource.tenant_id = team_id
await self.db.commit()
return True
except Exception as e:
logger.error(f"Error sharing with team: {e}")
await self.db.rollback()
return False
async def share_with_users(
self,
resource: Any,
user_emails: List[str],
sharer_email: str
) -> bool:
"""Share a resource with specific users
GT 2.0: Simple list-based sharing
"""
try:
# Verify sharer owns the resource
if not self._can_share_resource(resource, sharer_email):
return False
# Update shared_with list
current_shared = resource.shared_with or []
new_shared = list(set(current_shared + user_emails))
resource.shared_with = new_shared
await self.db.commit()
return True
except Exception as e:
logger.error(f"Error sharing with users: {e}")
await self.db.rollback()
return False
async def create_team(
self,
name: str,
description: str,
team_type: str,
creator_email: str
) -> Optional[Team]:
"""Create a new team
GT 2.0: File-based team with simple SQLite reference
"""
try:
# Check if user can create teams
org_settings = await self._get_organization_settings()
if not org_settings.allow_team_creation:
logger.warning(f"Team creation disabled for organization")
return None
# Check user's team limit
user_teams = await self.get_user_teams(creator_email)
if len(user_teams) >= org_settings.max_teams_per_user:
logger.warning(f"User {creator_email} reached team limit")
return None
# Create team
team = Team(
name=name,
description=description,
team_type=team_type,
created_by=creator_email
)
# Initialize with placeholder paths
team.config_file_path = "placeholder"
team.members_file_path = "placeholder"
# Save to get ID
self.db.add(team)
await self.db.flush()
# Initialize proper file paths
team.initialize_file_paths()
# Add creator as owner
team.add_member(creator_email, "owner", {"joined_as": "creator"})
# Save initial config
config = {
"name": name,
"description": description,
"team_type": team_type,
"created_by": creator_email,
"settings": {}
}
team.save_config_to_file(config)
await self.db.commit()
await self.db.refresh(team)
logger.info(f"Created team {team.id} by {creator_email}")
return team
except Exception as e:
logger.error(f"Error creating team: {e}")
await self.db.rollback()
return None
# Private helper methods
def _check_organization_action(self, action: str) -> bool:
"""Check if action is allowed for organization resources"""
# Organization resources are viewable by all
if action in ["view", "use", "read"]:
return True
# Only owners can modify
return False
async def _check_team_action(
self,
user_email: str,
team_id: int,
action: str
) -> bool:
"""Check if user can perform action on team resource"""
role = await self.get_user_role_in_team(user_email, team_id)
if not role:
return False
# Get role permissions
permissions = await self._get_role_permissions(role)
# Map action to permission
action_permission_map = {
"view": "can_view_resources",
"read": "can_view_resources",
"use": "can_view_resources",
"create": "can_create_resources",
"edit": "can_edit_team_resources",
"update": "can_edit_team_resources",
"delete": "can_delete_team_resources",
"manage_members": "can_manage_members",
"manage_team": "can_manage_team",
}
permission_needed = action_permission_map.get(action, None)
if permission_needed:
return permissions.get(permission_needed, False)
return False
def _check_shared_action(self, action: str) -> bool:
"""Check if action is allowed for shared resources"""
# Shared resources can be viewed and used
if action in ["view", "use", "read"]:
return True
return False
def _can_share_resource(self, resource: Any, user_email: str) -> bool:
"""Check if user can share a resource"""
# Owners can always share
if hasattr(resource, 'created_by') and resource.created_by == user_email:
return True
# Team leads can share team resources
# (Would need to check team role here in full implementation)
return False
async def _get_role_permissions(self, role_name: str) -> Dict[str, bool]:
"""Get permissions for a role (with caching)"""
if role_name in self._role_cache:
return self._role_cache[role_name]
result = await self.db.execute(
select(TeamRole).where(TeamRole.name == role_name)
)
role = result.scalar_one_or_none()
if role:
permissions = {
"can_view_resources": role.can_view_resources,
"can_create_resources": role.can_create_resources,
"can_edit_team_resources": role.can_edit_team_resources,
"can_delete_team_resources": role.can_delete_team_resources,
"can_manage_members": role.can_manage_members,
"can_manage_team": role.can_manage_team,
}
self._role_cache[role_name] = permissions
return permissions
# Default to viewer permissions
return {
"can_view_resources": True,
"can_create_resources": False,
"can_edit_team_resources": False,
"can_delete_team_resources": False,
"can_manage_members": False,
"can_manage_team": False,
}
async def _get_organization_settings(self) -> OrganizationSettings:
"""Get organization settings (create default if not exists)"""
result = await self.db.execute(
select(OrganizationSettings).limit(1)
)
settings = result.scalar_one_or_none()
if not settings:
# Create default settings
settings = OrganizationSettings(
organization_name="Default Organization",
organization_domain="example.com"
)
settings.config_file_path = "placeholder"
self.db.add(settings)
await self.db.flush()
settings.initialize_file_paths()
settings.save_config_to_file({
"initialized": True,
"default_config": True
})
await self.db.commit()
await self.db.refresh(settings)
return settings
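# --- Usage sketch (illustrative only) ---
# A minimal example of how an API route might consult this service. The
# AsyncSession `db`, the email address, and the resource object are assumptions
# supplied by the caller, not defined in this module.
async def _example_team_access(db, agent) -> bool:
    service = TeamAccessService(db)
    # Ownership, organization visibility, team role, and explicit shares are
    # checked in that order; anything else is denied (fail closed).
    return await service.check_team_access(
        user_email="alice@example.com",
        resource=agent,
        action="edit",
    )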

File diff suppressed because it is too large

View File

@@ -0,0 +1,233 @@
"""
GT 2.0 User Service - User Preferences Management
Manages user preferences including favorite agents using PostgreSQL + PGVector backend.
Perfect tenant isolation - each tenant has separate user data.
"""
import json
import logging
from typing import List, Optional, Dict, Any
from app.core.config import get_settings
from app.core.postgresql_client import get_postgresql_client
logger = logging.getLogger(__name__)
class UserService:
"""GT 2.0 PostgreSQL User Service with Perfect Tenant Isolation"""
def __init__(self, tenant_domain: str, user_id: str, user_email: str = None):
"""Initialize with tenant and user isolation using PostgreSQL storage"""
self.tenant_domain = tenant_domain
self.user_id = user_id
self.user_email = user_email or user_id
self.settings = get_settings()
logger.info(f"User service initialized for {tenant_domain}/{user_id} (email: {self.user_email})")
async def _get_user_id(self, pg_client) -> Optional[str]:
"""Get user ID from email or user_id with fallback"""
user_lookup_query = """
SELECT id FROM users
WHERE (email = $1 OR id::text = $1 OR username = $1)
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
LIMIT 1
"""
user_id = await pg_client.fetch_scalar(user_lookup_query, self.user_email, self.tenant_domain)
if not user_id:
# If not found by email, try by user_id
user_id = await pg_client.fetch_scalar(user_lookup_query, self.user_id, self.tenant_domain)
return user_id
async def get_user_preferences(self) -> Dict[str, Any]:
"""Get user preferences from PostgreSQL"""
try:
pg_client = await get_postgresql_client()
user_id = await self._get_user_id(pg_client)
if not user_id:
logger.warning(f"User not found: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}")
return {}
query = """
SELECT preferences
FROM users
WHERE id = $1
AND tenant_id = (SELECT id FROM tenants WHERE domain = $2 LIMIT 1)
"""
result = await pg_client.fetch_one(query, user_id, self.tenant_domain)
if result and result["preferences"]:
prefs = result["preferences"]
# Handle both dict and JSON string
if isinstance(prefs, str):
return json.loads(prefs)
return prefs
return {}
except Exception as e:
logger.error(f"Error getting user preferences: {e}")
return {}
async def update_user_preferences(self, preferences: Dict[str, Any]) -> bool:
"""Update user preferences in PostgreSQL (merges with existing)"""
try:
pg_client = await get_postgresql_client()
user_id = await self._get_user_id(pg_client)
if not user_id:
logger.warning(f"User not found: {self.user_email} (or {self.user_id}) in tenant {self.tenant_domain}")
return False
# Merge with existing preferences using PostgreSQL JSONB || operator
query = """
UPDATE users
SET preferences = COALESCE(preferences, '{}'::jsonb) || $1::jsonb,
updated_at = NOW()
WHERE id = $2
AND tenant_id = (SELECT id FROM tenants WHERE domain = $3 LIMIT 1)
RETURNING id
"""
updated_id = await pg_client.fetch_scalar(
query,
json.dumps(preferences),
user_id,
self.tenant_domain
)
if updated_id:
logger.info(f"Updated preferences for user {user_id}")
return True
return False
except Exception as e:
logger.error(f"Error updating user preferences: {e}")
return False
async def get_favorite_agent_ids(self) -> List[str]:
"""Get user's favorited agent IDs"""
try:
preferences = await self.get_user_preferences()
favorite_ids = preferences.get("favorite_agent_ids", [])
# Ensure it's a list
if not isinstance(favorite_ids, list):
return []
logger.info(f"Retrieved {len(favorite_ids)} favorite agent IDs for user {self.user_id}")
return favorite_ids
except Exception as e:
logger.error(f"Error getting favorite agent IDs: {e}")
return []
async def update_favorite_agent_ids(self, agent_ids: List[str]) -> bool:
"""Update user's favorited agent IDs"""
try:
# Validate agent_ids is a list
if not isinstance(agent_ids, list):
logger.error(f"Invalid agent_ids type: {type(agent_ids)}")
return False
# Update preferences with new favorite_agent_ids
success = await self.update_user_preferences({
"favorite_agent_ids": agent_ids
})
if success:
logger.info(f"Updated {len(agent_ids)} favorite agent IDs for user {self.user_id}")
return success
except Exception as e:
logger.error(f"Error updating favorite agent IDs: {e}")
return False
async def add_favorite_agent(self, agent_id: str) -> bool:
"""Add a single agent to favorites"""
try:
current_favorites = await self.get_favorite_agent_ids()
if agent_id not in current_favorites:
current_favorites.append(agent_id)
return await self.update_favorite_agent_ids(current_favorites)
# Already in favorites
return True
except Exception as e:
logger.error(f"Error adding favorite agent: {e}")
return False
async def remove_favorite_agent(self, agent_id: str) -> bool:
"""Remove a single agent from favorites"""
try:
current_favorites = await self.get_favorite_agent_ids()
if agent_id in current_favorites:
current_favorites.remove(agent_id)
return await self.update_favorite_agent_ids(current_favorites)
# Not in favorites
return True
except Exception as e:
logger.error(f"Error removing favorite agent: {e}")
return False
async def get_custom_categories(self) -> List[Dict[str, Any]]:
"""Get user's custom agent categories"""
try:
preferences = await self.get_user_preferences()
custom_categories = preferences.get("custom_categories", [])
# Ensure it's a list
if not isinstance(custom_categories, list):
return []
logger.info(f"Retrieved {len(custom_categories)} custom categories for user {self.user_id}")
return custom_categories
except Exception as e:
logger.error(f"Error getting custom categories: {e}")
return []
async def update_custom_categories(self, categories: List[Dict[str, Any]]) -> bool:
"""Update user's custom agent categories (replaces entire list)"""
try:
# Validate categories is a list
if not isinstance(categories, list):
logger.error(f"Invalid categories type: {type(categories)}")
return False
# Convert Pydantic models to dicts if needed
category_dicts = []
for cat in categories:
if hasattr(cat, 'dict'):
category_dicts.append(cat.dict())
elif isinstance(cat, dict):
category_dicts.append(cat)
else:
logger.error(f"Invalid category type: {type(cat)}")
return False
# Update preferences with new custom_categories
success = await self.update_user_preferences({
"custom_categories": category_dicts
})
if success:
logger.info(f"Updated {len(category_dicts)} custom categories for user {self.user_id}")
return success
except Exception as e:
logger.error(f"Error updating custom categories: {e}")
return False
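# --- Usage sketch (illustrative only) ---
# Demonstrates the favorites helpers. The tenant domain, user identifier, and
# agent id below are hypothetical; a reachable PostgreSQL backend with the
# users/tenants tables is assumed.
async def _example_favorites() -> None:
    service = UserService(tenant_domain="acme.example.com", user_id="alice@acme.example.com")
    await service.add_favorite_agent("agent-123")
    favorites = await service.get_favorite_agent_ids()
    print(favorites)  # e.g. ["agent-123"]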

View File

@@ -0,0 +1,448 @@
"""
Vector Store Service for Tenant Backend
Manages tenant-specific ChromaDB instances with encryption.
All vectors are stored locally in the tenant's encrypted database.
"""
import logging
import os
import hashlib
import json
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import chromadb
from chromadb.config import Settings
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class VectorSearchResult:
"""Result from vector search"""
document_id: str
text: str
score: float
metadata: Dict[str, Any]
class TenantEncryption:
"""Encryption handler for tenant data"""
def __init__(self, tenant_id: str):
"""Initialize encryption for tenant"""
# Derive encryption key from tenant-specific secret
tenant_key = f"{settings.SECRET_KEY}:{tenant_id}"
        kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=tenant_id.encode(),
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(tenant_key.encode()))
self.cipher = Fernet(key)
def encrypt(self, data: str) -> bytes:
"""Encrypt string data"""
return self.cipher.encrypt(data.encode())
def decrypt(self, encrypted_data: bytes) -> str:
"""Decrypt data to string"""
return self.cipher.decrypt(encrypted_data).decode()
def encrypt_vector(self, vector: List[float]) -> bytes:
"""Encrypt vector data"""
vector_str = json.dumps(vector)
return self.encrypt(vector_str)
def decrypt_vector(self, encrypted_vector: bytes) -> List[float]:
"""Decrypt vector data"""
vector_str = self.decrypt(encrypted_vector)
return json.loads(vector_str)
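# Usage note (illustrative): keys are derived per tenant from SECRET_KEY plus
# the tenant id, so a token produced for one tenant cannot be decrypted by
# another tenant's handler.
#
#   enc_a = TenantEncryption("tenant-a")
#   token = enc_a.encrypt("hello")
#   enc_a.decrypt(token)                          # -> "hello"
#   TenantEncryption("tenant-b").decrypt(token)   # raises InvalidToken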
class VectorStoreService:
"""
Manages tenant-specific vector storage with ChromaDB.
Security principles:
- All vectors stored in tenant-specific directory
- Encryption at rest for all data
- User-scoped collections
- No cross-tenant access
"""
def __init__(self, tenant_id: str, tenant_domain: str):
self.tenant_id = tenant_id
self.tenant_domain = tenant_domain
# Initialize encryption
self.encryption = TenantEncryption(tenant_id)
# Initialize ChromaDB client based on configuration mode
if settings.chromadb_mode == "http":
# Use HTTP client for per-tenant ChromaDB server
self.client = chromadb.HttpClient(
host=settings.chromadb_host,
port=settings.chromadb_port,
settings=Settings(
anonymized_telemetry=False,
allow_reset=False
)
)
logger.info(f"Vector store initialized for tenant {tenant_domain} using HTTP mode at {settings.chromadb_host}:{settings.chromadb_port}")
else:
# Use file-based client (fallback)
self.storage_path = settings.chromadb_path
os.makedirs(self.storage_path, exist_ok=True, mode=0o700)
self.client = chromadb.PersistentClient(
path=self.storage_path,
settings=Settings(
anonymized_telemetry=False,
allow_reset=False,
is_persistent=True
)
)
logger.info(f"Vector store initialized for tenant {tenant_domain} using file mode at {self.storage_path}")
async def create_user_collection(
self,
user_id: str,
collection_name: str,
metadata: Optional[Dict[str, Any]] = None
) -> str:
"""
Create a user-scoped collection.
Args:
user_id: User identifier
collection_name: Name of the collection
metadata: Optional collection metadata
Returns:
Collection ID
"""
# Generate unique collection name for user
collection_id = f"{user_id}_{collection_name}"
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
internal_name = f"col_{collection_hash}"
try:
# Create or get collection
collection = self.client.get_or_create_collection(
name=internal_name,
metadata={
"user_id": user_id,
"collection_name": collection_name,
"tenant_id": self.tenant_id,
**(metadata or {})
}
)
logger.info(f"Created collection {collection_name} for user {user_id}")
return collection_id
except Exception as e:
logger.error(f"Error creating collection: {e}")
raise
async def store_vectors(
self,
user_id: str,
collection_name: str,
documents: List[str],
embeddings: List[List[float]],
ids: Optional[List[str]] = None,
metadata: Optional[List[Dict[str, Any]]] = None
) -> bool:
"""
Store vectors in user collection with encryption.
Args:
user_id: User identifier
collection_name: Collection name
documents: List of document texts
embeddings: List of embedding vectors
ids: Optional document IDs
metadata: Optional document metadata
Returns:
Success status
"""
try:
# Get collection
collection_id = f"{user_id}_{collection_name}"
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
internal_name = f"col_{collection_hash}"
collection = self.client.get_collection(name=internal_name)
# Generate IDs if not provided
if ids is None:
ids = [
hashlib.sha256(f"{doc}:{i}".encode()).hexdigest()[:16]
for i, doc in enumerate(documents)
]
# Encrypt documents before storage
encrypted_docs = [
self.encryption.encrypt(doc).decode('latin-1')
for doc in documents
]
# Prepare metadata with encryption status
final_metadata = []
for i, doc_meta in enumerate(metadata or [{}] * len(documents)):
meta = {
**doc_meta,
"encrypted": True,
"user_id": user_id,
"doc_hash": hashlib.sha256(documents[i].encode()).hexdigest()[:16]
}
final_metadata.append(meta)
# Add to collection
collection.add(
ids=ids,
embeddings=embeddings,
documents=encrypted_docs,
metadatas=final_metadata
)
logger.info(
f"Stored {len(documents)} vectors in collection {collection_name} "
f"for user {user_id}"
)
return True
except Exception as e:
logger.error(f"Error storing vectors: {e}")
raise
async def search(
self,
user_id: str,
collection_name: str,
query_embedding: List[float],
top_k: int = 5,
filter_metadata: Optional[Dict[str, Any]] = None
) -> List[VectorSearchResult]:
"""
Search for similar vectors in user collection.
Args:
user_id: User identifier
collection_name: Collection name
query_embedding: Query vector
top_k: Number of results to return
filter_metadata: Optional metadata filters
Returns:
List of search results with decrypted content
"""
try:
# Get collection
collection_id = f"{user_id}_{collection_name}"
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
internal_name = f"col_{collection_hash}"
collection = self.client.get_collection(name=internal_name)
# Prepare filter
where_filter = {"user_id": user_id}
if filter_metadata:
where_filter.update(filter_metadata)
# Query collection
results = collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
where=where_filter
)
# Process results
search_results = []
if results and results['ids'] and len(results['ids'][0]) > 0:
for i in range(len(results['ids'][0])):
# Decrypt document text
encrypted_doc = results['documents'][0][i].encode('latin-1')
decrypted_doc = self.encryption.decrypt(encrypted_doc)
search_results.append(VectorSearchResult(
document_id=results['ids'][0][i],
text=decrypted_doc,
score=1.0 - results['distances'][0][i], # Convert distance to similarity
metadata=results['metadatas'][0][i] if results['metadatas'] else {}
))
logger.info(
f"Found {len(search_results)} results in collection {collection_name} "
f"for user {user_id}"
)
return search_results
except Exception as e:
logger.error(f"Error searching vectors: {e}")
raise
async def delete_collection(
self,
user_id: str,
collection_name: str
) -> bool:
"""
Delete a user collection.
Args:
user_id: User identifier
collection_name: Collection name
Returns:
Success status
"""
try:
collection_id = f"{user_id}_{collection_name}"
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
internal_name = f"col_{collection_hash}"
self.client.delete_collection(name=internal_name)
logger.info(f"Deleted collection {collection_name} for user {user_id}")
return True
except Exception as e:
logger.error(f"Error deleting collection: {e}")
raise
async def list_user_collections(
self,
user_id: str
) -> List[Dict[str, Any]]:
"""
List all collections for a user.
Args:
user_id: User identifier
Returns:
List of collection information
"""
try:
all_collections = self.client.list_collections()
user_collections = []
for collection in all_collections:
metadata = collection.metadata
if metadata and metadata.get("user_id") == user_id:
user_collections.append({
"name": metadata.get("collection_name"),
"created_at": metadata.get("created_at"),
"document_count": collection.count(),
"metadata": metadata
})
return user_collections
except Exception as e:
logger.error(f"Error listing collections: {e}")
raise
async def get_collection_stats(
self,
user_id: str,
collection_name: str
) -> Dict[str, Any]:
"""
Get statistics for a user collection.
Args:
user_id: User identifier
collection_name: Collection name
Returns:
Collection statistics
"""
try:
collection_id = f"{user_id}_{collection_name}"
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
internal_name = f"col_{collection_hash}"
collection = self.client.get_collection(name=internal_name)
stats = {
"document_count": collection.count(),
"collection_name": collection_name,
"user_id": user_id,
"metadata": collection.metadata,
"storage_mode": settings.chromadb_mode,
"storage_path": getattr(self, 'storage_path', f"{settings.chromadb_host}:{settings.chromadb_port}")
}
return stats
except Exception as e:
logger.error(f"Error getting collection stats: {e}")
raise
async def update_document(
self,
user_id: str,
collection_name: str,
document_id: str,
new_text: Optional[str] = None,
new_embedding: Optional[List[float]] = None,
new_metadata: Optional[Dict[str, Any]] = None
) -> bool:
"""
Update a document in the collection.
Args:
user_id: User identifier
collection_name: Collection name
document_id: Document ID to update
new_text: Optional new text
new_embedding: Optional new embedding
new_metadata: Optional new metadata
Returns:
Success status
"""
try:
collection_id = f"{user_id}_{collection_name}"
collection_hash = hashlib.sha256(collection_id.encode()).hexdigest()[:8]
internal_name = f"col_{collection_hash}"
collection = self.client.get_collection(name=internal_name)
update_params = {"ids": [document_id]}
if new_text:
encrypted_text = self.encryption.encrypt(new_text).decode('latin-1')
update_params["documents"] = [encrypted_text]
if new_embedding:
update_params["embeddings"] = [new_embedding]
if new_metadata:
update_params["metadatas"] = [{
**new_metadata,
"encrypted": True,
"user_id": user_id
}]
collection.update(**update_params)
logger.info(f"Updated document {document_id} in collection {collection_name}")
return True
except Exception as e:
logger.error(f"Error updating document: {e}")
raise
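# --- Usage sketch (illustrative only) ---
# End-to-end example of the collection lifecycle. Tenant/user identifiers and
# the toy embeddings are hypothetical; a reachable ChromaDB backend (HTTP or
# file mode, per settings) is assumed.
async def _example_vector_store() -> None:
    store = VectorStoreService(tenant_id="tenant-123", tenant_domain="acme.example.com")
    await store.create_user_collection("alice", "notes")
    await store.store_vectors(
        user_id="alice",
        collection_name="notes",
        documents=["GT 2.0 stores vectors per tenant"],
        embeddings=[[0.1, 0.2, 0.3]],
    )
    results = await store.search(
        user_id="alice",
        collection_name="notes",
        query_embedding=[0.1, 0.2, 0.3],
        top_k=1,
    )
    for r in results:
        print(r.document_id, round(r.score, 3), r.text)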

View File

@@ -0,0 +1,416 @@
"""
WebSocket Integration Service for GT 2.0 Tenant Backend
Bridges the conversation service with WebSocket real-time streaming.
Provides AI response streaming and event-driven message broadcasting.
GT 2.0 Architecture Principles:
- Zero downtime: Non-blocking AI response streaming
- Perfect tenant isolation: All streaming scoped to tenant
- Self-contained: No external dependencies
- Event-driven: Integrates with event automation system
"""
import asyncio
import logging
import json
from typing import AsyncGenerator, Dict, Any, Optional
from datetime import datetime
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db_session
from app.services.conversation_service import ConversationService
from app.services.event_service import EventService, EventType
from app.services.agent_service import AgentService
from app.websocket import get_websocket_manager, ChatMessage
logger = logging.getLogger(__name__)
class WebSocketService:
"""
Service for WebSocket-integrated AI conversation streaming.
Combines conversation service with real-time WebSocket delivery.
"""
def __init__(self, db: AsyncSession):
self.db = db
self.conversation_service = ConversationService(db)
self.event_service = EventService(db)
# Agent service will be initialized per request with tenant context
self.websocket_manager = get_websocket_manager()
async def stream_ai_response(
self,
conversation_id: str,
user_message: str,
user_id: str,
tenant_id: str,
connection_id: str
) -> None:
"""
Stream AI response to WebSocket connection with real-time updates.
Args:
conversation_id: Target conversation
user_message: User's message content
user_id: User identifier
tenant_id: Tenant for isolation
connection_id: WebSocket connection to stream to
"""
try:
# Send streaming start notification
await self.websocket_manager.send_to_connection(connection_id, {
"type": "ai_response_start",
"conversation_id": conversation_id,
"timestamp": datetime.utcnow().isoformat()
})
# Add user message to conversation
await self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=user_message,
user_id=user_id
)
# Emit message sent event
await self.event_service.emit_event(
event_type=EventType.MESSAGE_SENT,
user_id=user_id,
tenant_id=tenant_id,
data={
"conversation_id": conversation_id,
"message_type": "user",
"content_length": len(user_message),
"streaming": True
}
)
# Get conversation context for AI response
conversation = await self.conversation_service.get_conversation(
conversation_id=conversation_id,
user_id=user_id,
include_messages=True
)
if not conversation:
raise ValueError("Conversation not found")
# Stream AI response
full_response = ""
message_id = str(uuid.uuid4())
# Get AI response generator
async for chunk in self._generate_ai_response_stream(conversation, user_message):
full_response += chunk
# Send chunk to WebSocket
await self.websocket_manager.send_to_connection(connection_id, {
"type": "ai_response_chunk",
"conversation_id": conversation_id,
"message_id": message_id,
"content": chunk,
"full_content_so_far": full_response,
"timestamp": datetime.utcnow().isoformat()
})
# Broadcast to other conversation participants
await self.websocket_manager.broadcast_to_conversation(
conversation_id,
{
"type": "ai_typing",
"conversation_id": conversation_id,
"content_preview": full_response[-50:] if len(full_response) > 50 else full_response
},
exclude_connection=connection_id
)
# Save complete AI response
await self.conversation_service.add_message(
conversation_id=conversation_id,
role="agent",
content=full_response,
user_id=user_id,
message_id=message_id
)
# Send completion notification
await self.websocket_manager.send_to_connection(connection_id, {
"type": "ai_response_complete",
"conversation_id": conversation_id,
"message_id": message_id,
"full_content": full_response,
"timestamp": datetime.utcnow().isoformat()
})
# Broadcast completion to conversation participants
await self.websocket_manager.broadcast_to_conversation(
conversation_id,
{
"type": "new_ai_message",
"conversation_id": conversation_id,
"message_id": message_id,
"content": full_response,
"timestamp": datetime.utcnow().isoformat()
},
exclude_connection=connection_id
)
# Emit AI response event
await self.event_service.emit_event(
event_type=EventType.MESSAGE_SENT,
user_id=user_id,
tenant_id=tenant_id,
data={
"conversation_id": conversation_id,
"message_id": message_id,
"message_type": "agent",
"content_length": len(full_response),
"streaming_completed": True
}
)
logger.info(f"AI response streaming completed for conversation {conversation_id}")
except Exception as e:
logger.error(f"Error streaming AI response: {e}")
# Send error notification
await self.websocket_manager.send_to_connection(connection_id, {
"type": "ai_response_error",
"conversation_id": conversation_id,
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
})
raise
async def _generate_ai_response_stream(
self,
conversation: Dict[str, Any],
user_message: str
) -> AsyncGenerator[str, None]:
"""
Generate AI response stream chunks.
This is a placeholder implementation. In production, this would
integrate with the actual LLM service for streaming responses.
Args:
conversation: Conversation context
user_message: User's message
Yields:
Response text chunks
"""
# Get agent configuration
agent_id = conversation.get("agent_id")
if agent_id:
            # Initialize AgentService with tenant context taken from the
            # conversation record (this helper is not passed tenant_id/user_id
            # directly; assumes the conversation dict carries them)
            agent_service = AgentService(
                conversation.get("tenant_id"),
                conversation.get("user_id"),
            )
agent_config = await agent_service.get_agent(agent_id)
assistant_config = agent_config if agent_config else {}
else:
assistant_config = {}
# Build conversation context
messages = conversation.get("messages", [])
context = []
# Add system prompt if available
system_prompt = assistant_config.get("prompt", "You are a helpful AI agent.")
context.append({"role": "system", "content": system_prompt})
# Add recent conversation history (last 10 messages)
for msg in messages[-10:]:
context.append({
"role": msg["role"],
"content": msg["content"]
})
# Add current user message
context.append({"role": "user", "content": user_message})
# Simulate AI response streaming (replace with real LLM integration)
# This demonstrates the streaming pattern that would be used with actual AI services
response_text = await self._generate_mock_response(user_message, context)
# Stream response in chunks
chunk_size = 5 # Characters per chunk for demo
for i in range(0, len(response_text), chunk_size):
chunk = response_text[i:i + chunk_size]
# Add realistic delay for streaming effect
await asyncio.sleep(0.05)
yield chunk
async def _generate_mock_response(
self,
user_message: str,
context: list
) -> str:
"""
Generate mock AI response for development.
In production, this would be replaced with actual LLM API calls.
Args:
user_message: User's message
context: Conversation context
Returns:
Generated response text
"""
# Simple mock response based on user input
if "hello" in user_message.lower():
return "Hello! I'm your AI agent. How can I help you today?"
elif "help" in user_message.lower():
return "I'm here to help! You can ask me questions, request information, or have a conversation. What would you like to know?"
elif "weather" in user_message.lower():
return "I don't have access to real-time weather data, but I'd be happy to help you find weather information or discuss weather patterns in general."
elif "time" in user_message.lower():
return f"The current time is approximately {datetime.utcnow().strftime('%H:%M UTC')}. Is there anything else I can help you with?"
else:
return f"Thank you for your message: '{user_message}'. I understand you're looking for assistance. Could you provide more details about what you'd like help with?"
async def handle_typing_indicator(
self,
conversation_id: str,
user_id: str,
is_typing: bool,
connection_id: str
) -> None:
"""
Handle and broadcast typing indicators.
Args:
conversation_id: Target conversation
user_id: User who is typing
is_typing: Whether user is currently typing
connection_id: Connection that sent the indicator
"""
try:
# Broadcast typing indicator to other conversation participants
await self.websocket_manager.broadcast_to_conversation(
conversation_id,
{
"type": "user_typing",
"conversation_id": conversation_id,
"user_id": user_id,
"is_typing": is_typing,
"timestamp": datetime.utcnow().isoformat()
},
exclude_connection=connection_id
)
logger.debug(f"Typing indicator broadcast: {user_id} {'started' if is_typing else 'stopped'} typing")
except Exception as e:
logger.error(f"Error handling typing indicator: {e}")
async def send_system_notification(
self,
user_id: str,
tenant_id: str,
notification: Dict[str, Any]
) -> None:
"""
Send system notification to all user connections.
Args:
user_id: Target user
tenant_id: Tenant for isolation
notification: Notification data
"""
try:
message = {
"type": "system_notification",
"notification": notification,
"timestamp": datetime.utcnow().isoformat()
}
await self.websocket_manager.send_to_user(
user_id=user_id,
tenant_id=tenant_id,
message=message
)
logger.info(f"System notification sent to user {user_id}")
except Exception as e:
logger.error(f"Error sending system notification: {e}")
async def broadcast_conversation_update(
self,
conversation_id: str,
update_type: str,
update_data: Dict[str, Any]
) -> None:
"""
Broadcast conversation update to all participants.
Args:
conversation_id: Target conversation
update_type: Type of update (title_changed, participant_added, etc.)
update_data: Update details
"""
try:
message = {
"type": "conversation_update",
"conversation_id": conversation_id,
"update_type": update_type,
"update_data": update_data,
"timestamp": datetime.utcnow().isoformat()
}
await self.websocket_manager.broadcast_to_conversation(
conversation_id,
message
)
logger.info(f"Conversation update broadcast: {update_type} for {conversation_id}")
except Exception as e:
logger.error(f"Error broadcasting conversation update: {e}")
async def get_connection_stats(self, tenant_id: str) -> Dict[str, Any]:
"""
Get WebSocket connection statistics for tenant.
Args:
tenant_id: Tenant to get stats for
Returns:
Connection statistics
"""
try:
all_stats = self.websocket_manager.get_connection_stats()
return {
"tenant_connections": all_stats["connections_by_tenant"].get(tenant_id, 0),
"active_conversations": len([
conv_id for conv_id, connections in self.websocket_manager.conversation_connections.items()
if any(
                        getattr(self.websocket_manager.connections.get(conn_id), "tenant_id", None) == tenant_id
for conn_id in connections
)
])
}
except Exception as e:
logger.error(f"Error getting connection stats: {e}")
return {"tenant_connections": 0, "active_conversations": 0}
# Factory function for dependency injection
async def get_websocket_service(db: AsyncSession = None) -> WebSocketService:
"""Get WebSocket service instance"""
if db is None:
async with get_db_session() as session:
return WebSocketService(session)
return WebSocketService(db)
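# --- Usage sketch (illustrative only) ---
# Shows how a WebSocket route handler might hand an incoming chat message to
# the streaming pipeline. All identifiers below are hypothetical and would
# normally come from the authenticated connection context.
async def _example_stream(db) -> None:
    service = WebSocketService(db)
    await service.stream_ai_response(
        conversation_id="conv-123",
        user_message="hello",
        user_id="alice@acme.example.com",
        tenant_id="tenant-123",
        connection_id="conn-abc",
    )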

View File

@@ -0,0 +1,995 @@
import uuid
import json
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Union
from sqlalchemy.orm import Session
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload
from app.models.workflow import (
Workflow, WorkflowExecution, WorkflowTrigger, WorkflowSession, WorkflowMessage,
WorkflowStatus, TriggerType, InteractionMode,
WORKFLOW_NODE_TYPES, INTERACTION_MODE_CONFIGS
)
from app.models.agent import Agent
from app.services.resource_service import ResourceService
class WorkflowValidationError(Exception):
"""Raised when workflow validation fails"""
pass
class WorkflowService:
"""Service for managing workflows with Agents as AI node definitions"""
def __init__(self, db: Session):
self.db = db
self.resource_service = ResourceService()
def create_workflow(
self,
user_id: str,
tenant_id: str,
workflow_data: Dict[str, Any]
) -> Workflow:
"""Create a new workflow with validation"""
# Validate workflow definition
self._validate_workflow_definition(
workflow_data.get('definition', {}),
user_id,
tenant_id
)
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant_id,
user_id=user_id,
name=workflow_data['name'],
description=workflow_data.get('description', ''),
definition=workflow_data['definition'],
triggers=workflow_data.get('triggers', []),
interaction_modes=workflow_data.get('interaction_modes', ['button']),
agent_ids=self._extract_agent_ids(workflow_data['definition']),
api_key_ids=workflow_data.get('api_key_ids', []),
webhook_ids=workflow_data.get('webhook_ids', []),
dataset_ids=workflow_data.get('dataset_ids', []),
integration_ids=workflow_data.get('integration_ids', []),
config=workflow_data.get('config', {}),
timeout_seconds=workflow_data.get('timeout_seconds', 300),
max_retries=workflow_data.get('max_retries', 3)
)
# Use sync database operations
self.db.add(workflow)
self.db.commit()
self.db.refresh(workflow)
# Create triggers if specified
for trigger_config in workflow_data.get('triggers', []):
self.create_workflow_trigger(
workflow.id,
user_id,
tenant_id,
trigger_config
)
return workflow
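    # Example workflow_data (illustrative values only; the agent id and node
    # types are assumptions for the sketch):
    #   {
    #       "name": "Summarize uploads",
    #       "definition": {
    #           "nodes": [
    #               {"id": "t1", "type": "trigger", "data": {}},
    #               {"id": "a1", "type": "agent", "data": {"agent_id": "agent-123"}},
    #               {"id": "o1", "type": "output", "data": {}}
    #           ],
    #           "edges": [
    #               {"source": "t1", "target": "a1"},
    #               {"source": "a1", "target": "o1"}
    #           ]
    #       },
    #       "triggers": [{"type": "cron", "schedule": "0 9 * * *"}],
    #       "interaction_modes": ["button"]
    #   }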
def get_workflow(self, workflow_id: str, user_id: str) -> Optional[Workflow]:
"""Get a workflow by ID with user ownership validation"""
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.user_id == user_id
)
result = self.db.execute(stmt)
return result.scalar_one_or_none()
def list_user_workflows(
self,
user_id: str,
tenant_id: str,
status: Optional[WorkflowStatus] = None
) -> List[Workflow]:
"""List all workflows for a user"""
stmt = select(Workflow).where(
Workflow.user_id == user_id,
Workflow.tenant_id == tenant_id
)
if status:
stmt = stmt.where(Workflow.status == status)
stmt = stmt.order_by(Workflow.updated_at.desc())
result = self.db.execute(stmt)
return result.scalars().all()
def update_workflow(
self,
workflow_id: str,
user_id: str,
updates: Dict[str, Any]
) -> Workflow:
"""Update a workflow with validation"""
# Get existing workflow
workflow = self.get_workflow(workflow_id, user_id)
if not workflow:
raise ValueError("Workflow not found or access denied")
# Validate definition if updated
if 'definition' in updates:
self._validate_workflow_definition(
updates['definition'],
user_id,
workflow.tenant_id
)
updates['agent_ids'] = self._extract_agent_ids(updates['definition'])
# Update workflow
stmt = update(Workflow).where(
Workflow.id == workflow_id,
Workflow.user_id == user_id
).values(**updates)
self.db.execute(stmt)
self.db.commit()
# Return updated workflow
return self.get_workflow(workflow_id, user_id)
def delete_workflow(self, workflow_id: str, user_id: str) -> bool:
"""Delete a workflow and its related data"""
# Verify ownership
workflow = self.get_workflow(workflow_id, user_id)
if not workflow:
return False
# Delete related records
self._cleanup_workflow_data(workflow_id)
# Delete workflow
stmt = delete(Workflow).where(
Workflow.id == workflow_id,
Workflow.user_id == user_id
)
result = self.db.execute(stmt)
self.db.commit()
return result.rowcount > 0
async def execute_workflow(
self,
workflow_id: str,
user_id: str,
input_data: Dict[str, Any],
trigger_type: str = "manual",
trigger_source: Optional[str] = None,
interaction_mode: str = "api"
) -> WorkflowExecution:
"""Execute a workflow with specified input"""
# Get and validate workflow
        workflow = self.get_workflow(workflow_id, user_id)
if not workflow:
raise ValueError("Workflow not found or access denied")
if workflow.status not in [WorkflowStatus.ACTIVE, WorkflowStatus.DRAFT]:
raise ValueError(f"Cannot execute workflow with status: {workflow.status}")
# Create execution record
execution = WorkflowExecution(
id=str(uuid.uuid4()),
workflow_id=workflow_id,
user_id=user_id,
tenant_id=workflow.tenant_id,
status="pending",
input_data=input_data,
trigger_type=trigger_type,
trigger_source=trigger_source,
interaction_mode=interaction_mode,
session_id=str(uuid.uuid4()) if interaction_mode == "chat" else None
)
self.db.add(execution)
await self.db.commit()
await self.db.refresh(execution)
# Execute workflow asynchronously (in real implementation, this would be a background task)
try:
execution_result = await self._execute_workflow_nodes(workflow, execution, input_data)
# Update execution with results
execution.status = "completed"
execution.output_data = execution_result.get('output', {})
execution.completed_at = datetime.utcnow()
execution.duration_ms = int((execution.completed_at - execution.started_at).total_seconds() * 1000)
execution.progress_percentage = 100
# Update workflow statistics
workflow.execution_count += 1
workflow.last_executed = datetime.utcnow()
workflow.total_tokens_used += execution_result.get('tokens_used', 0)
workflow.total_cost_cents += execution_result.get('cost_cents', 0)
except Exception as e:
# Mark execution as failed
execution.status = "failed"
execution.error_details = str(e)
execution.completed_at = datetime.utcnow()
execution.duration_ms = int((execution.completed_at - execution.started_at).total_seconds() * 1000)
await self.db.commit()
return execution
async def get_execution_status(
self,
execution_id: str,
user_id: str
) -> Optional[WorkflowExecution]:
"""Get execution status with user validation"""
stmt = select(WorkflowExecution).where(
WorkflowExecution.id == execution_id,
WorkflowExecution.user_id == user_id
)
result = await self.db.execute(stmt)
return result.scalar_one_or_none()
def create_workflow_trigger(
self,
workflow_id: str,
user_id: str,
tenant_id: str,
trigger_config: Dict[str, Any]
) -> WorkflowTrigger:
"""Create a trigger for a workflow"""
trigger = WorkflowTrigger(
id=str(uuid.uuid4()),
workflow_id=workflow_id,
user_id=user_id,
tenant_id=tenant_id,
trigger_type=trigger_config['type'],
trigger_config=trigger_config
)
# Configure trigger-specific settings
if trigger_config['type'] == 'webhook':
trigger.webhook_url = f"https://api.gt2.com/webhooks/{trigger.id}"
trigger.webhook_secret = str(uuid.uuid4())
elif trigger_config['type'] == 'cron':
trigger.cron_schedule = trigger_config.get('schedule', '0 0 * * *')
trigger.timezone = trigger_config.get('timezone', 'UTC')
elif trigger_config['type'] == 'event':
trigger.event_source = trigger_config.get('source', '')
trigger.event_filters = trigger_config.get('filters', {})
self.db.add(trigger)
self.db.commit()
self.db.refresh(trigger)
return trigger
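    # Example trigger_config values (illustrative only):
    #   {"type": "webhook"}                                   -> webhook_url and secret generated above
    #   {"type": "cron", "schedule": "*/15 * * * *", "timezone": "UTC"}
    #   {"type": "event", "source": "document.uploaded", "filters": {"dataset": "sales"}}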
async def create_chat_session(
self,
workflow_id: str,
user_id: str,
tenant_id: str,
session_config: Optional[Dict[str, Any]] = None
) -> WorkflowSession:
"""Create a chat session for workflow interaction"""
session = WorkflowSession(
id=str(uuid.uuid4()),
workflow_id=workflow_id,
user_id=user_id,
tenant_id=tenant_id,
session_type="chat",
session_state=session_config or {},
expires_at=datetime.utcnow() + timedelta(hours=24) # 24 hour session
)
self.db.add(session)
await self.db.commit()
await self.db.refresh(session)
return session
async def add_chat_message(
self,
session_id: str,
user_id: str,
role: str,
content: str,
agent_id: Optional[str] = None,
confidence_score: Optional[int] = None,
execution_id: Optional[str] = None
) -> WorkflowMessage:
"""Add a message to a chat session"""
# Get session and validate
stmt = select(WorkflowSession).where(
WorkflowSession.id == session_id,
WorkflowSession.user_id == user_id,
WorkflowSession.is_active == True
)
session = await self.db.execute(stmt)
session = session.scalar_one_or_none()
if not session:
raise ValueError("Chat session not found or expired")
message = WorkflowMessage(
id=str(uuid.uuid4()),
session_id=session_id,
workflow_id=session.workflow_id,
execution_id=execution_id,
user_id=user_id,
tenant_id=session.tenant_id,
role=role,
content=content,
agent_id=agent_id,
confidence_score=confidence_score
)
self.db.add(message)
# Update session
session.message_count += 1
session.last_message_at = datetime.utcnow()
await self.db.commit()
await self.db.refresh(message)
return message
def _validate_workflow_definition(
self,
definition: Dict[str, Any],
user_id: str,
tenant_id: str
):
"""Validate workflow definition and resource access"""
nodes = definition.get('nodes', [])
edges = definition.get('edges', [])
# Validate nodes
for node in nodes:
node_type = node.get('type')
if node_type not in WORKFLOW_NODE_TYPES:
raise WorkflowValidationError(f"Invalid node type: {node_type}")
            # Validate Agent nodes
            if node_type == 'agent':
                agent_id = node.get('data', {}).get('agent_id')
                if not agent_id:
                    raise WorkflowValidationError("Agent node missing agent_id")
# Verify user owns the agent
agent = self._get_user_agent(agent_id, user_id, tenant_id)
if not agent:
raise WorkflowValidationError(f"Agent {agent_id} not found or access denied")
# Validate Integration nodes
elif node_type == 'integration':
api_key_id = node.get('data', {}).get('api_key_id')
if api_key_id:
# In real implementation, validate API key ownership
pass
# Validate edges (connections between nodes)
node_ids = {node['id'] for node in nodes}
for edge in edges:
source = edge.get('source')
target = edge.get('target')
if source not in node_ids or target not in node_ids:
raise WorkflowValidationError("Invalid edge connection")
# Ensure workflow has at least one trigger node
trigger_nodes = [n for n in nodes if n.get('type') == 'trigger']
if not trigger_nodes:
raise WorkflowValidationError("Workflow must have at least one trigger node")
def _get_user_agent(
self,
agent_id: str,
user_id: str,
tenant_id: str
) -> Optional[Agent]:
"""Get agent with ownership validation"""
stmt = select(Agent).where(
Agent.id == agent_id,
Agent.user_id == user_id,
Agent.tenant_id == tenant_id
)
result = self.db.execute(stmt)
return result.scalar_one_or_none()
# Backward compatibility method
def _get_user_assistant(
self,
agent_id: str,
user_id: str,
tenant_id: str
) -> Optional[Agent]:
"""Backward compatibility wrapper for _get_user_agent"""
return self._get_user_agent(agent_id, user_id, tenant_id)
    def _extract_agent_ids(self, definition: Dict[str, Any]) -> List[str]:
        """Extract agent IDs from workflow definition"""
        agent_ids = []
        for node in definition.get('nodes', []):
            if node.get('type') in ('agent', 'assistant'):
                agent_id = node.get('data', {}).get('agent_id') or node.get('data', {}).get('assistant_id')
                if agent_id:
                    agent_ids.append(agent_id)
        return agent_ids

    # Backward compatibility method
    def _extract_assistant_ids(self, definition: Dict[str, Any]) -> List[str]:
        """Backward compatibility wrapper for _extract_agent_ids"""
        return self._extract_agent_ids(definition)
async def _execute_workflow_nodes(
self,
workflow: Workflow,
execution: WorkflowExecution,
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute workflow nodes in order"""
# Update execution status
execution.status = "running"
execution.progress_percentage = 10
await self.db.commit()
# Parse workflow definition to create execution graph
definition = workflow.definition
nodes = definition.get('nodes', [])
edges = definition.get('edges', [])
if not nodes:
raise ValueError("Workflow has no nodes to execute")
# Find trigger node to start execution
trigger_nodes = [n for n in nodes if n.get('type') == 'trigger']
if not trigger_nodes:
raise ValueError("Workflow has no trigger nodes")
execution_trace = []
total_tokens = 0
total_cost = 0
current_data = input_data
# Execute nodes in simple sequential order (real implementation would use topological sort)
for node in nodes:
node_id = node.get('id')
node_type = node.get('type')
try:
if node_type == 'trigger':
# Trigger nodes just pass through input data
node_result = {
'output': current_data,
'tokens_used': 0,
'cost_cents': 0
}
                elif node_type in ('agent', 'assistant'):  # Support legacy 'assistant' type for compatibility
# Execute Agent node via resource cluster
node_result = await self._execute_agent_node_real(node, current_data, execution.user_id, execution.tenant_id)
elif node_type == 'integration':
# Execute integration node (simulated - no external connections)
node_result = await self._execute_integration_node_simulated(node, current_data)
elif node_type == 'logic':
# Execute logic node (real logic operations)
node_result = await self._execute_logic_node_simulated(node, current_data)
elif node_type == 'output':
# Execute output node (simulated - no external deliveries)
node_result = await self._execute_output_node_simulated(node, current_data)
else:
raise ValueError(f"Unknown node type: {node_type}")
# Update execution state
current_data = node_result.get('output', current_data)
total_tokens += node_result.get('tokens_used', 0)
total_cost += node_result.get('cost_cents', 0)
execution_trace.append({
'node_id': node_id,
'node_type': node_type,
'status': 'completed',
'timestamp': datetime.utcnow().isoformat(),
'tokens_used': node_result.get('tokens_used', 0),
'cost_cents': node_result.get('cost_cents', 0),
'execution_time_ms': node_result.get('execution_time_ms', 0)
})
except Exception as e:
# Record failed node execution
execution_trace.append({
'node_id': node_id,
'node_type': node_type,
'status': 'failed',
'timestamp': datetime.utcnow().isoformat(),
'error': str(e)
})
raise ValueError(f"Node {node_id} execution failed: {str(e)}")
return {
'output': current_data,
'tokens_used': total_tokens,
'cost_cents': total_cost,
'execution_trace': execution_trace
}
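    # Illustrative only: the shape of one entry appended to execution_trace by
    # _execute_workflow_nodes for a successfully executed node. The values are
    # hypothetical placeholders.
    _EXAMPLE_TRACE_ENTRY: Dict[str, Any] = {
        "node_id": "n2",
        "node_type": "agent",
        "status": "completed",
        "timestamp": "2025-01-01T00:00:00",
        "tokens_used": 120,
        "cost_cents": 3,
        "execution_time_ms": 850,
    }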
async def _execute_agent_node_real(
self,
node: Dict[str, Any],
input_data: Dict[str, Any],
user_id: str,
tenant_id: str
) -> Dict[str, Any]:
"""Execute an Agent node with real Agent integration"""
        # Support both agent_id and the legacy assistant_id key for backward compatibility
        agent_id = node.get('data', {}).get('agent_id') or node.get('data', {}).get('assistant_id')
        if not agent_id:
            raise ValueError("Agent node missing agent_id (or legacy assistant_id)")
# Get Agent configuration
agent = await self._get_user_agent(agent_id, user_id, tenant_id)
if not agent:
raise ValueError(f"Agent {agent_id} not found")
# Prepare input text from workflow data
input_text = input_data.get('message', '') or str(input_data)
# Use the existing conversation service for real execution
from app.services.conversation_service import ConversationService
try:
conversation_service = ConversationService(self.db)
# Create or get conversation for this workflow execution
conversation_id = f"workflow-{agent_id}-{datetime.utcnow().isoformat()}"
            # Execute the agent through the real conversation service
            response = await conversation_service.send_message(
                conversation_id=conversation_id,
                user_id=user_id,
                tenant_id=tenant_id,
                content=input_text,
                agent_id=agent_id
            )
)
# Parse response to extract metrics
tokens_used = response.get('tokens_used', 100) # Default estimate
cost_cents = max(1, tokens_used // 50) # Rough cost estimation
return {
'output': response.get('content', 'Agent response'),
'confidence': node.get('data', {}).get('confidence_threshold', 75),
'tokens_used': tokens_used,
'cost_cents': cost_cents,
'execution_time_ms': response.get('response_time_ms', 1000)
}
except Exception as e:
# If conversation service fails, use basic text processing
return {
'output': f"Agent {agent.name} processed: {input_text[:100]}{'...' if len(input_text) > 100 else ''}",
'confidence': 50, # Lower confidence for fallback
'tokens_used': len(input_text.split()) * 2, # Rough token estimate
'cost_cents': 2,
'execution_time_ms': 500
}
async def _execute_integration_node_simulated(
self,
node: Dict[str, Any],
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute an Integration node with simulated responses (no external connections)"""
integration_type = node.get('data', {}).get('integration_type', 'api')
integration_name = node.get('data', {}).get('name', 'Unknown Integration')
# Simulate processing time based on integration type
import asyncio
processing_times = {
'api': 200, # API calls: 200ms
'webhook': 150, # Webhook: 150ms
'database': 300, # Database: 300ms
'email': 500, # Email: 500ms
'storage': 250 # Storage: 250ms
}
processing_time = processing_times.get(integration_type, 200)
await asyncio.sleep(processing_time / 1000) # Convert to seconds
# Generate realistic simulated responses based on integration type
simulated_responses = {
'api': {
'status': 'success',
'data': {
'response_code': 200,
'message': f'API call to {integration_name} completed successfully',
'processed_items': len(str(input_data)) // 10,
'timestamp': datetime.utcnow().isoformat()
}
},
'webhook': {
'status': 'delivered',
'webhook_id': f'wh_{uuid.uuid4().hex[:8]}',
'delivery_time_ms': processing_time,
'response_code': 200
},
'database': {
'status': 'executed',
'affected_rows': 1,
'query_time_ms': processing_time,
'result_count': 1
},
'email': {
'status': 'sent',
'message_id': f'msg_{uuid.uuid4().hex[:12]}',
'recipients': 1,
'delivery_status': 'queued'
},
'storage': {
'status': 'uploaded',
'file_size_bytes': len(str(input_data)),
'storage_path': f'/simulated/path/{uuid.uuid4().hex[:8]}.json',
'etag': f'etag_{uuid.uuid4().hex[:16]}'
}
}
response_data = simulated_responses.get(integration_type, simulated_responses['api'])
return {
'output': response_data,
'simulated': True, # Mark as simulated
'integration_type': integration_type,
'integration_name': integration_name,
'tokens_used': 0, # Integrations don't use AI tokens
'cost_cents': 1, # Minimal cost for simulation
'execution_time_ms': processing_time,
'log_message': f'Integration {integration_name} simulated - external connections not implemented'
}
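    # Illustrative only: an 'integration' node as consumed by
    # _execute_integration_node_simulated. The keys mirror the fields read
    # above; the concrete values are hypothetical placeholders.
    _EXAMPLE_INTEGRATION_NODE: Dict[str, Any] = {
        "id": "n3",
        "type": "integration",
        "data": {"integration_type": "webhook", "name": "CRM Webhook", "api_key_id": None},
    }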
async def _execute_logic_node_simulated(
self,
node: Dict[str, Any],
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute a Logic node with real logic operations"""
logic_type = node.get('data', {}).get('logic_type', 'transform')
logic_config = node.get('data', {}).get('logic_config', {})
import asyncio
await asyncio.sleep(0.05) # Small processing delay
if logic_type == 'condition':
# Simple condition evaluation
condition = logic_config.get('condition', 'true')
field = logic_config.get('field', 'message')
operator = logic_config.get('operator', 'contains')
value = logic_config.get('value', '')
input_value = str(input_data.get(field, ''))
if operator == 'contains':
result = value.lower() in input_value.lower()
elif operator == 'equals':
result = input_value.lower() == value.lower()
elif operator == 'length_gt':
result = len(input_value) > int(value)
else:
result = True # Default to true
return {
'output': {
**input_data,
'condition_result': result,
'condition_evaluated': f'{field} {operator} {value}'
},
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 50
}
elif logic_type == 'transform':
# Data transformation
transform_rules = logic_config.get('rules', [])
transformed_data = dict(input_data)
# Apply simple transformations
for rule in transform_rules:
source_field = rule.get('source', '')
target_field = rule.get('target', source_field)
operation = rule.get('operation', 'copy')
if source_field in input_data:
value = input_data[source_field]
if operation == 'uppercase':
transformed_data[target_field] = str(value).upper()
elif operation == 'lowercase':
transformed_data[target_field] = str(value).lower()
elif operation == 'length':
transformed_data[target_field] = len(str(value))
else: # copy
transformed_data[target_field] = value
return {
'output': transformed_data,
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 50
}
elif logic_type == 'aggregate':
# Simple aggregation
aggregate_field = logic_config.get('field', 'items')
operation = logic_config.get('operation', 'count')
items = input_data.get(aggregate_field, [])
if not isinstance(items, list):
items = [items]
if operation == 'count':
result = len(items)
elif operation == 'sum' and all(isinstance(x, (int, float)) for x in items):
result = sum(items)
elif operation == 'average' and all(isinstance(x, (int, float)) for x in items):
result = sum(items) / len(items) if items else 0
else:
result = len(items)
return {
'output': {
**input_data,
f'{operation}_result': result,
f'{operation}_field': aggregate_field
},
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 50
}
else:
# Default passthrough
return {
'output': input_data,
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 50
}
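    # Illustrative only: a 'logic' node with transform rules as consumed by
    # _execute_logic_node_simulated. Field names mirror the keys read above;
    # the values are hypothetical placeholders.
    _EXAMPLE_LOGIC_NODE: Dict[str, Any] = {
        "id": "n4",
        "type": "logic",
        "data": {
            "logic_type": "transform",
            "logic_config": {
                "rules": [
                    {"source": "message", "target": "message_upper", "operation": "uppercase"},
                    {"source": "message", "target": "message_length", "operation": "length"},
                ]
            },
        },
    }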
async def _execute_output_node_simulated(
self,
node: Dict[str, Any],
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute an Output node with simulated delivery (no external sends)"""
output_type = node.get('data', {}).get('output_type', 'webhook')
output_config = node.get('data', {}).get('output_config', {})
output_name = node.get('data', {}).get('name', 'Unknown Output')
import asyncio
# Simulate delivery time based on output type
delivery_times = {
'webhook': 300, # Webhook delivery: 300ms
'email': 800, # Email sending: 800ms
'api': 250, # API call: 250ms
'storage': 400, # File storage: 400ms
'notification': 200 # Push notification: 200ms
}
delivery_time = delivery_times.get(output_type, 300)
await asyncio.sleep(delivery_time / 1000)
# Generate realistic simulated delivery responses
simulated_deliveries = {
'webhook': {
'status': 'delivered',
'webhook_url': output_config.get('url', 'https://api.example.com/webhook'),
'response_code': 200,
'delivery_id': f'wh_delivery_{uuid.uuid4().hex[:8]}',
'payload_size_bytes': len(str(input_data))
},
'email': {
'status': 'queued',
'recipients': output_config.get('recipients', ['user@example.com']),
'subject': output_config.get('subject', 'Workflow Output'),
'message_id': f'email_{uuid.uuid4().hex[:12]}',
'provider': 'simulated_smtp'
},
'api': {
'status': 'sent',
'endpoint': output_config.get('endpoint', '/api/results'),
'method': output_config.get('method', 'POST'),
'response_code': 201,
'request_id': f'api_{uuid.uuid4().hex[:8]}'
},
'storage': {
'status': 'stored',
'storage_path': f'/outputs/{uuid.uuid4().hex[:8]}.json',
'file_size_bytes': len(str(input_data)),
'checksum': f'sha256_{uuid.uuid4().hex[:16]}'
},
'notification': {
'status': 'pushed',
'devices': output_config.get('devices', 1),
'message': output_config.get('message', 'Workflow completed'),
'notification_id': f'notif_{uuid.uuid4().hex[:8]}'
}
}
delivery_data = simulated_deliveries.get(output_type, simulated_deliveries['webhook'])
return {
'output': {
**input_data,
'delivery_result': delivery_data,
'output_type': output_type,
'output_name': output_name
},
'simulated': True,
'tokens_used': 0,
'cost_cents': 1, # Minimal cost for output
'execution_time_ms': delivery_time,
'log_message': f'Output {output_name} simulated - external delivery not implemented'
}
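    # Illustrative only: an 'output' node as consumed by
    # _execute_output_node_simulated. The URL and name are hypothetical
    # placeholders, not real endpoints.
    _EXAMPLE_OUTPUT_NODE: Dict[str, Any] = {
        "id": "n5",
        "type": "output",
        "data": {
            "output_type": "webhook",
            "name": "Results Webhook",
            "output_config": {"url": "https://example.invalid/webhook"},
        },
    }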
async def _execute_logic_node_real(
self,
node: Dict[str, Any],
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute a Logic node with actual data processing"""
logic_type = node.get('data', {}).get('logic_type')
if logic_type == 'decision':
condition = node.get('data', {}).get('config', {}).get('condition', 'true')
# Simple condition evaluation (in production would use safe expression evaluator)
try:
# Basic condition evaluation for common cases
if 'input.value' in condition:
input_value = input_data.get('value', 0)
condition_result = eval(condition.replace('input.value', str(input_value)))
else:
condition_result = True # Default to true for undefined conditions
return {
'output': {
'condition_result': condition_result,
'original_data': input_data,
'branch': 'true' if condition_result else 'false'
},
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 50
}
            except Exception:
# Fallback to pass-through if condition evaluation fails
return {
'output': input_data,
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 50
}
elif logic_type == 'transform':
# Simple data transformation
return {
'output': {
'transformed_data': input_data,
'transformation_type': 'basic',
'timestamp': datetime.utcnow().isoformat()
},
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 25
}
# Default: pass through data
return {
'output': input_data,
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 25
}
async def _execute_output_node_real(
self,
node: Dict[str, Any],
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute an Output node with actual output delivery"""
output_type = node.get('data', {}).get('output_type')
if output_type == 'webhook':
webhook_url = node.get('data', {}).get('config', {}).get('url')
if webhook_url:
import httpx
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(webhook_url, json=input_data)
return {
'output': {
'webhook_sent': True,
'status_code': response.status_code,
'response': response.text[:500] # Limit response size
},
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 200
}
except Exception as e:
return {
'output': {
'webhook_sent': False,
'error': str(e)
},
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 100
}
# For other output types, simulate delivery
return {
'output': {
'output_type': output_type,
'delivered': True,
'data_sent': input_data
},
'tokens_used': 0,
'cost_cents': 0,
'execution_time_ms': 50
}
    async def _cleanup_workflow_data(self, workflow_id: str):
        """Clean up all data related to a workflow"""
        # Delete triggers
        stmt = delete(WorkflowTrigger).where(WorkflowTrigger.workflow_id == workflow_id)
        await self.db.execute(stmt)
        # Delete sessions and messages
        stmt = delete(WorkflowMessage).where(WorkflowMessage.workflow_id == workflow_id)
        await self.db.execute(stmt)
        stmt = delete(WorkflowSession).where(WorkflowSession.workflow_id == workflow_id)
        await self.db.execute(stmt)
        # Delete executions
        stmt = delete(WorkflowExecution).where(WorkflowExecution.workflow_id == workflow_id)
        await self.db.execute(stmt)
        await self.db.commit()
# WorkflowExecutor functionality integrated directly into WorkflowService
# for better cohesion and to avoid mock implementations