gt-ai-os-community/apps/tenant-backend/app/core/postgresql_client.py
Commit b9dfb86260 by HackWeasel: GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching; see the sketch after this list)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives
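The hostname-validation fix is worth sketching, since substring matching is the classic failure mode here. The following is a minimal illustration of the pattern the bullet above describes, not the repository's actual code; the allowlist and function name are hypothetical:

```python
from urllib.parse import urlparse

# Hypothetical allowlist, for illustration only
ALLOWED_HOSTS = {"api.example.com"}

def is_allowed_url(url: str) -> bool:
    """Accept a URL only if its exact, parsed hostname is allowlisted."""
    host = urlparse(url).hostname  # None when the URL carries no host
    if host is None:
        return False
    # Substring checks like `"api.example.com" in url` would also accept
    # https://evil.net/?q=api.example.com or https://api.example.com.evil.net;
    # comparing the parsed hostname closes both holes.
    return host.lower() in ALLOWED_HOSTS
```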

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 17:04:45 -05:00


"""
GT 2.0 PostgreSQL + PGVector Client for Tenant Backend
Replaces DuckDB service with direct PostgreSQL connections, providing:
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
- BionicGPT Row Level Security patterns for enterprise isolation
- MVCC concurrency solving DuckDB file locking issues
- Hybrid vector + full-text search in single queries
- Connection pooling (10-50 connections per pool) sized for 100+ concurrent users

A minimal usage sketch appears at the end of this module.
"""
import logging
from typing import Dict, List, Optional, Any, AsyncGenerator, Tuple
from contextlib import asynccontextmanager
import json
from datetime import datetime, timezone
import asyncpg
from asyncpg import Pool, Connection
from asyncpg.exceptions import PostgresError
from app.core.config import get_settings, get_tenant_schema_name
logger = logging.getLogger(__name__)
class PostgreSQLClient:
"""PostgreSQL + PGVector client for tenant backend operations"""
def __init__(self, database_url: str, tenant_domain: str):
self.database_url = database_url
self.tenant_domain = tenant_domain
self.schema_name = get_tenant_schema_name(tenant_domain)
self._pool: Optional[Pool] = None
self._initialized = False
async def __aenter__(self):
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async def initialize(self) -> None:
"""Initialize connection pool and verify schema"""
if self._initialized:
return
logger.info(f"Initializing PostgreSQL connection pool for tenant: {self.tenant_domain}")
logger.info(f"Schema: {self.schema_name}, URL: {self.database_url}")
try:
# Create connection pool with resilient settings
# Sized for 100+ concurrent users with RAG/vector search workloads
self._pool = await asyncpg.create_pool(
self.database_url,
min_size=10,
max_size=50, # Increased from 20 to handle 100+ concurrent users
command_timeout=120, # Increased from 60s for queries under load
timeout=10, # Connection acquire timeout increased for high load
max_inactive_connection_lifetime=3600, # Recycle connections after 1 hour
server_settings={
'application_name': f'gt2_tenant_{self.tenant_domain}'
},
# Enable prepared statements for direct postgres connection (performance gain)
statement_cache_size=100
)
# Verify schema exists and has required tables
await self._verify_schema()
self._initialized = True
logger.info(f"PostgreSQL client initialized successfully for tenant: {self.tenant_domain}")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL client: {e}")
if self._pool:
await self._pool.close()
self._pool = None
raise
async def close(self) -> None:
"""Close connection pool"""
if self._pool:
await self._pool.close()
self._pool = None
self._initialized = False
logger.info(f"PostgreSQL connection pool closed for tenant: {self.tenant_domain}")
async def _verify_schema(self) -> None:
"""Verify tenant schema exists and has required tables"""
async with self._pool.acquire() as conn:
# Check if schema exists
schema_exists = await conn.fetchval("""
SELECT EXISTS (
SELECT 1 FROM information_schema.schemata
WHERE schema_name = $1
)
""", self.schema_name)
if not schema_exists:
raise RuntimeError(f"Tenant schema '{self.schema_name}' does not exist. Run schema initialization first.")
# Check for required tables
required_tables = ['tenants', 'users', 'agents', 'datasets', 'conversations', 'messages', 'documents', 'document_chunks']
for table in required_tables:
table_exists = await conn.fetchval(f"""
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = $1 AND table_name = $2
)
""", self.schema_name, table)
if not table_exists:
logger.warning(f"Table '{table}' not found in schema '{self.schema_name}'")
logger.info(f"Schema verification complete for tenant: {self.tenant_domain}")
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator[Connection, None]:
"""Get a connection from the pool"""
if not self._pool:
raise RuntimeError("PostgreSQL client not initialized. Call initialize() first.")
async with self._pool.acquire() as conn:
try:
# Set schema search path for this connection
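                # schema_name is derived via get_tenant_schema_name() from the
                # tenant domain fixed at construction time (settings, not request
                # input), and identifiers cannot be bound as query parameters,
                # hence the f-string here rather than a $1 placeholder.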
await conn.execute(f"SET search_path TO {self.schema_name}, public")
# Session variable logging removed - no longer using RLS
yield conn
except Exception as e:
logger.error(f"Database connection error: {e}")
raise
async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
"""Execute a SELECT query and return results"""
async with self.get_connection() as conn:
try:
rows = await conn.fetch(query, *args)
return [dict(row) for row in rows]
except PostgresError as e:
logger.error(f"Query execution failed: {e}, Query: {query}")
raise
async def execute_command(self, command: str, *args) -> int:
"""Execute an INSERT/UPDATE/DELETE command and return affected rows"""
async with self.get_connection() as conn:
try:
result = await conn.execute(command, *args)
# Parse result like "INSERT 0 5" to get affected rows
return int(result.split()[-1]) if result else 0
except PostgresError as e:
logger.error(f"Command execution failed: {e}, Command: {command}")
raise
async def fetch_one(self, query: str, *args) -> Optional[Dict[str, Any]]:
"""Execute query and return first row"""
async with self.get_connection() as conn:
try:
row = await conn.fetchrow(query, *args)
return dict(row) if row else None
except PostgresError as e:
logger.error(f"Fetch one failed: {e}, Query: {query}")
raise
async def fetch_scalar(self, query: str, *args) -> Any:
"""Execute query and return single value"""
async with self.get_connection() as conn:
try:
return await conn.fetchval(query, *args)
except PostgresError as e:
logger.error(f"Fetch scalar failed: {e}, Query: {query}")
raise
async def execute_transaction(self, commands: List[Tuple[str, tuple]]) -> List[int]:
"""Execute multiple commands in a transaction"""
async with self.get_connection() as conn:
async with conn.transaction():
results = []
for command, args in commands:
try:
result = await conn.execute(command, *args)
results.append(int(result.split()[-1]) if result else 0)
except PostgresError as e:
logger.error(f"Transaction command failed: {e}, Command: {command}")
raise
return results
# Vector Search Operations (PGVector)
async def vector_similarity_search(
self,
query_vector: List[float],
table: str = "document_chunks",
limit: int = 10,
similarity_threshold: float = 0.3,
user_id: Optional[str] = None,
dataset_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Perform vector similarity search using PGVector"""
# Convert Python list to PostgreSQL array format
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
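        # pgvector's <=> operator is cosine distance, so 1 - distance yields
        # cosine similarity; ordering by <=> directly (see below) lets a vector
        # index serve the nearest-neighbor scan.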
query = f"""
SELECT
id,
content,
1 - (embedding <=> $1::vector) as similarity_score,
metadata
FROM {table}
WHERE embedding IS NOT NULL
AND 1 - (embedding <=> $1::vector) > $2
"""
params = [vector_str, similarity_threshold]
param_idx = 3
# Add user isolation if specified
if user_id:
query += f" AND user_id = ${param_idx}"
params.append(user_id)
param_idx += 1
# Add dataset filtering if specified
if dataset_id:
query += f" AND dataset_id = ${param_idx}"
params.append(dataset_id)
param_idx += 1
query += f" ORDER BY embedding <=> $1::vector LIMIT ${param_idx}"
params.append(limit)
return await self.execute_query(query, *params)
async def hybrid_search(
self,
query_text: str,
query_vector: List[float],
user_id: str,
limit: int = 10,
similarity_threshold: float = 0.3,
text_weight: float = 0.3,
vector_weight: float = 0.7,
dataset_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Perform hybrid search combining vector similarity and full-text search"""
vector_str = '[' + ','.join(map(str, query_vector)) + ']'
# Use the enhanced_hybrid_search_chunks function from BionicGPT integration
query = """
SELECT
id,
document_id,
content,
similarity_score,
text_rank,
combined_score,
metadata,
access_verified
FROM enhanced_hybrid_search_chunks($1, $2::vector, $3::uuid, $4, $5, $6, $7, $8)
"""
return await self.execute_query(
query,
query_text,
vector_str,
user_id,
dataset_id,
limit,
similarity_threshold,
text_weight,
vector_weight
)
async def insert_document_chunk(
self,
document_id: str,
tenant_id: int,
user_id: str,
chunk_index: int,
content: str,
content_hash: str,
embedding: List[float],
dataset_id: Optional[str] = None,
token_count: int = 0,
metadata: Optional[Dict] = None
) -> str:
"""Insert a document chunk with vector embedding"""
vector_str = '[' + ','.join(map(str, embedding)) + ']'
metadata_json = json.dumps(metadata or {})
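        # pgvector accepts the '[x,y,...]' text literal through the ::vector
        # cast; metadata is serialized here and cast to ::jsonb in the INSERT.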
query = """
INSERT INTO document_chunks (
document_id, tenant_id, user_id, dataset_id, chunk_index,
content, content_hash, token_count, embedding, metadata
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10::jsonb)
RETURNING id
"""
return await self.fetch_scalar(
query,
document_id, tenant_id, user_id, dataset_id, chunk_index,
content, content_hash, token_count, vector_str, metadata_json
)
# Health Check and Statistics
async def health_check(self) -> Dict[str, Any]:
"""Perform health check on PostgreSQL connection"""
try:
if not self._pool:
return {"status": "unhealthy", "reason": "Connection pool not initialized"}
# Test basic connectivity
test_result = await self.fetch_scalar("SELECT 1")
# Get pool statistics
pool_stats = {
"size": self._pool.get_size(),
"min_size": self._pool.get_min_size(),
"max_size": self._pool.get_max_size(),
"idle_size": self._pool.get_idle_size()
}
# Test schema access
schema_test = await self.fetch_scalar("""
SELECT EXISTS (
SELECT 1 FROM information_schema.schemata
WHERE schema_name = $1
)
""", self.schema_name)
return {
"status": "healthy" if test_result == 1 and schema_test else "degraded",
"connectivity": "ok" if test_result == 1 else "failed",
"schema_access": "ok" if schema_test else "failed",
"tenant_domain": self.tenant_domain,
"schema_name": self.schema_name,
"pool_stats": pool_stats,
"database_type": "postgresql_pgvector"
}
except Exception as e:
logger.error(f"PostgreSQL health check failed: {e}")
return {"status": "unhealthy", "reason": str(e)}
async def get_database_stats(self) -> Dict[str, Any]:
"""Get database statistics for monitoring"""
try:
# Get table counts and sizes
stats_query = """
SELECT
schemaname,
tablename,
n_tup_ins as inserts,
n_tup_upd as updates,
n_tup_del as deletes,
n_live_tup as live_tuples,
n_dead_tup as dead_tuples
FROM pg_stat_user_tables
WHERE schemaname = $1
"""
table_stats = await self.execute_query(stats_query, self.schema_name)
# Get total schema size
size_query = """
SELECT pg_size_pretty(
SUM(pg_total_relation_size(quote_ident(schemaname)||'.'||quote_ident(tablename)))
) as schema_size
FROM pg_tables
WHERE schemaname = $1
"""
schema_size = await self.fetch_scalar(size_query, self.schema_name)
# Get vector index statistics if available
vector_stats_query = """
SELECT
COUNT(*) as vector_count,
AVG(vector_dims(embedding)) as avg_dimensions
FROM document_chunks
WHERE embedding IS NOT NULL
"""
try:
vector_stats = await self.fetch_one(vector_stats_query)
            except PostgresError:
                # document_chunks may be missing or empty in a fresh schema
vector_stats = {"vector_count": 0, "avg_dimensions": 0}
return {
"tenant_domain": self.tenant_domain,
"schema_name": self.schema_name,
"schema_size": schema_size,
"table_stats": table_stats,
"vector_stats": vector_stats,
"engine_type": "PostgreSQL + PGVector",
"mvcc_enabled": True,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Failed to get database statistics: {e}")
return {"error": str(e)}
# Global client instance (singleton pattern for tenant backend)
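# Each tenant backend process serves exactly one tenant domain (see
# get_postgresql_client below), so a single module-level client suffices;
# a multi-tenant process would need one client per schema instead.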
_pg_client: Optional[PostgreSQLClient] = None
async def get_postgresql_client() -> PostgreSQLClient:
"""Get or create PostgreSQL client instance"""
global _pg_client
if not _pg_client:
settings = get_settings()
_pg_client = PostgreSQLClient(
database_url=settings.database_url,
tenant_domain=settings.tenant_domain
)
await _pg_client.initialize()
return _pg_client
async def init_postgresql() -> None:
"""Initialize PostgreSQL client during startup"""
logger.info("Initializing PostgreSQL client...")
await get_postgresql_client()
logger.info("PostgreSQL client initialized successfully")
async def close_postgresql() -> None:
"""Close PostgreSQL client during shutdown"""
global _pg_client
if _pg_client:
await _pg_client.close()
_pg_client = None
logger.info("PostgreSQL client closed")
# Context manager for database operations
@asynccontextmanager
async def get_db_session():
"""Async context manager for database operations"""
client = await get_postgresql_client()
async with client.get_connection() as conn:
yield conn
# Convenience functions for common operations
async def execute_query(query: str, *args) -> List[Dict[str, Any]]:
"""Execute a SELECT query"""
client = await get_postgresql_client()
return await client.execute_query(query, *args)
async def execute_command(command: str, *args) -> int:
"""Execute an INSERT/UPDATE/DELETE command"""
client = await get_postgresql_client()
return await client.execute_command(command, *args)
async def fetch_one(query: str, *args) -> Optional[Dict[str, Any]]:
"""Execute query and return first row"""
client = await get_postgresql_client()
return await client.fetch_one(query, *args)
async def fetch_scalar(query: str, *args) -> Any:
"""Execute query and return single value"""
client = await get_postgresql_client()
return await client.fetch_scalar(query, *args)
async def health_check() -> Dict[str, Any]:
"""Perform database health check"""
try:
client = await get_postgresql_client()
return await client.health_check()
except Exception as e:
return {"status": "unhealthy", "reason": str(e)}
async def get_database_info() -> Dict[str, Any]:
"""Get database information and statistics"""
try:
client = await get_postgresql_client()
return await client.get_database_stats()
except Exception as e:
return {"error": str(e)}