GT AI OS Community Edition v2.0.33
Security hardening release addressing CodeQL and Dependabot alerts:

- Fix stack trace exposure in error responses
- Add SSRF protection with DNS resolution checking
- Implement proper URL hostname validation (replaces substring matching)
- Add centralized path sanitization to prevent path traversal
- Fix ReDoS vulnerability in email validation regex
- Improve HTML sanitization in validation utilities
- Fix capability wildcard matching in auth utilities
- Update glob dependency to address CVE
- Add CodeQL suppression comments for verified false positives

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
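For context, an SSRF check of the kind described above typically resolves a URL's hostname via DNS and rejects private, loopback, or otherwise non-routable addresses before any request is made. A minimal sketch (illustrative only, not the code in this commit; the `is_safe_url` helper name is hypothetical):

```python
import ipaddress
import socket
from urllib.parse import urlparse

def is_safe_url(url: str) -> bool:
    """Reject URLs whose hostname resolves to a private or loopback address."""
    hostname = urlparse(url).hostname
    if not hostname:
        return False
    try:
        # Resolve every address the hostname maps to (A and AAAA records)
        infos = socket.getaddrinfo(hostname, None)
    except socket.gaierror:
        return False
    for info in infos:
        # Strip an IPv6 scope suffix (e.g. "fe80::1%eth0") before parsing
        ip = ipaddress.ip_address(info[4][0].split('%')[0])
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            return False
    return True
```

Checking the resolved addresses, rather than the hostname string, is what defeats DNS-based bypasses such as a public hostname that resolves to 127.0.0.1.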
apps/tenant-backend/app/core/postgresql_client.py (new file, 498 lines)

@@ -0,0 +1,498 @@
"""
GT 2.0 PostgreSQL + PGVector Client for Tenant Backend

Replaces DuckDB service with direct PostgreSQL connections, providing:
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
- BionicGPT Row Level Security patterns for enterprise isolation
- MVCC concurrency solving DuckDB file locking issues
- Hybrid vector + full-text search in single queries
- Connection pooling sized for 100+ concurrent users
"""

import logging
from typing import Dict, List, Optional, Any, AsyncGenerator, Tuple
from contextlib import asynccontextmanager
import json
from datetime import datetime

import asyncpg
from asyncpg import Pool, Connection
from asyncpg.exceptions import PostgresError

from app.core.config import get_settings, get_tenant_schema_name

logger = logging.getLogger(__name__)


class PostgreSQLClient:
    """PostgreSQL + PGVector client for tenant backend operations"""

    def __init__(self, database_url: str, tenant_domain: str):
        self.database_url = database_url
        self.tenant_domain = tenant_domain
        self.schema_name = get_tenant_schema_name(tenant_domain)
        self._pool: Optional[Pool] = None
        self._initialized = False

    async def __aenter__(self):
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def initialize(self) -> None:
        """Initialize connection pool and verify schema"""
        if self._initialized:
            return

        logger.info(f"Initializing PostgreSQL connection pool for tenant: {self.tenant_domain}")
        # Log the schema only; the database URL may embed credentials and
        # must not be written to logs.
        logger.info(f"Schema: {self.schema_name}")

        try:
            # Create connection pool with resilient settings
            # Sized for 100+ concurrent users with RAG/vector search workloads
            self._pool = await asyncpg.create_pool(
                self.database_url,
                min_size=10,
                max_size=50,  # Increased from 20 to handle 100+ concurrent users
                command_timeout=120,  # Increased from 60s for queries under load
                timeout=10,  # Connection acquire timeout increased for high load
                max_inactive_connection_lifetime=3600,  # Recycle connections after 1 hour
                server_settings={
                    'application_name': f'gt2_tenant_{self.tenant_domain}'
                },
                # Enable prepared statements for direct postgres connection (performance gain)
                statement_cache_size=100
            )

            # Verify schema exists and has required tables
            await self._verify_schema()

            self._initialized = True
            logger.info(f"PostgreSQL client initialized successfully for tenant: {self.tenant_domain}")

        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL client: {e}")
            if self._pool:
                await self._pool.close()
                self._pool = None
            raise

    async def close(self) -> None:
        """Close connection pool"""
        if self._pool:
            await self._pool.close()
            self._pool = None
            self._initialized = False
            logger.info(f"PostgreSQL connection pool closed for tenant: {self.tenant_domain}")

    async def _verify_schema(self) -> None:
        """Verify tenant schema exists and has required tables"""
        async with self._pool.acquire() as conn:
            # Check if schema exists
            schema_exists = await conn.fetchval("""
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.schemata
                    WHERE schema_name = $1
                )
            """, self.schema_name)

            if not schema_exists:
                raise RuntimeError(
                    f"Tenant schema '{self.schema_name}' does not exist. "
                    f"Run schema initialization first."
                )

            # Check for required tables
            required_tables = [
                'tenants', 'users', 'agents', 'datasets',
                'conversations', 'messages', 'documents', 'document_chunks'
            ]

            for table in required_tables:
                table_exists = await conn.fetchval("""
                    SELECT EXISTS (
                        SELECT 1 FROM information_schema.tables
                        WHERE table_schema = $1 AND table_name = $2
                    )
                """, self.schema_name, table)

                if not table_exists:
                    logger.warning(f"Table '{table}' not found in schema '{self.schema_name}'")

            logger.info(f"Schema verification complete for tenant: {self.tenant_domain}")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator[Connection, None]:
        """Get a connection from the pool"""
        if not self._pool:
            raise RuntimeError("PostgreSQL client not initialized. Call initialize() first.")

        async with self._pool.acquire() as conn:
            try:
                # Set schema search path for this connection. The schema name is
                # derived internally via get_tenant_schema_name(), never from
                # user input, so interpolating it here is safe.
                await conn.execute(f"SET search_path TO {self.schema_name}, public")

                # Session variable logging removed - no longer using RLS

                yield conn
            except Exception as e:
                logger.error(f"Database connection error: {e}")
                raise

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Execute a SELECT query and return results"""
        async with self.get_connection() as conn:
            try:
                rows = await conn.fetch(query, *args)
                return [dict(row) for row in rows]
            except PostgresError as e:
                logger.error(f"Query execution failed: {e}, Query: {query}")
                raise

    async def execute_command(self, command: str, *args) -> int:
        """Execute an INSERT/UPDATE/DELETE command and return affected rows"""
        async with self.get_connection() as conn:
            try:
                result = await conn.execute(command, *args)
                # Parse the status tag (e.g. "INSERT 0 5" or "UPDATE 3") to get affected rows
                return int(result.split()[-1]) if result else 0
            except PostgresError as e:
                logger.error(f"Command execution failed: {e}, Command: {command}")
                raise

    async def fetch_one(self, query: str, *args) -> Optional[Dict[str, Any]]:
        """Execute query and return first row"""
        async with self.get_connection() as conn:
            try:
                row = await conn.fetchrow(query, *args)
                return dict(row) if row else None
            except PostgresError as e:
                logger.error(f"Fetch one failed: {e}, Query: {query}")
                raise

    async def fetch_scalar(self, query: str, *args) -> Any:
        """Execute query and return single value"""
        async with self.get_connection() as conn:
            try:
                return await conn.fetchval(query, *args)
            except PostgresError as e:
                logger.error(f"Fetch scalar failed: {e}, Query: {query}")
                raise

    async def execute_transaction(self, commands: List[Tuple[str, tuple]]) -> List[int]:
        """Execute multiple commands in a transaction"""
        async with self.get_connection() as conn:
            async with conn.transaction():
                results = []
                for command, args in commands:
                    try:
                        result = await conn.execute(command, *args)
                        results.append(int(result.split()[-1]) if result else 0)
                    except PostgresError as e:
                        logger.error(f"Transaction command failed: {e}, Command: {command}")
                        raise
                return results
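
    # Usage sketch for execute_transaction (illustrative only; the column
    # names below are hypothetical, not defined in this file). Both commands
    # commit or roll back together:
    #
    #     affected = await client.execute_transaction([
    #         ("UPDATE agents SET status = 'archived' WHERE id = $1", (agent_id,)),
    #         ("DELETE FROM conversations WHERE agent_id = $1", (agent_id,)),
    #     ])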

    # Vector Search Operations (PGVector)

    async def vector_similarity_search(
        self,
        query_vector: List[float],
        table: str = "document_chunks",
        limit: int = 10,
        similarity_threshold: float = 0.3,
        user_id: Optional[str] = None,
        dataset_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Perform vector similarity search using PGVector"""

        # The table name is interpolated into the query text, so restrict it
        # to a plain identifier to rule out SQL injection.
        if not table.isidentifier():
            raise ValueError(f"Invalid table name: {table!r}")

        # Convert Python list to PostgreSQL vector literal format
        vector_str = '[' + ','.join(map(str, query_vector)) + ']'

        # `<=>` is pgvector's cosine distance operator, so
        # `1 - (a <=> b)` yields cosine similarity.
        query = f"""
            SELECT
                id,
                content,
                1 - (embedding <=> $1::vector) as similarity_score,
                metadata
            FROM {table}
            WHERE embedding IS NOT NULL
            AND 1 - (embedding <=> $1::vector) > $2
        """

        params = [vector_str, similarity_threshold]
        param_idx = 3

        # Add user isolation if specified
        if user_id:
            query += f" AND user_id = ${param_idx}"
            params.append(user_id)
            param_idx += 1

        # Add dataset filtering if specified
        if dataset_id:
            query += f" AND dataset_id = ${param_idx}"
            params.append(dataset_id)
            param_idx += 1

        query += f" ORDER BY embedding <=> $1::vector LIMIT ${param_idx}"
        params.append(limit)

        return await self.execute_query(query, *params)
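
    # Usage sketch (illustrative only; assumes an initialized client and an
    # embedding vector whose dimensionality matches the stored embeddings):
    #
    #     results = await client.vector_similarity_search(
    #         query_vector=embedding,
    #         limit=5,
    #         similarity_threshold=0.5,
    #         user_id=str(current_user_id),  # hypothetical caller-side variable
    #     )
    #     for row in results:
    #         print(row["similarity_score"], row["content"][:80])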

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: List[float],
        user_id: str,
        limit: int = 10,
        similarity_threshold: float = 0.3,
        text_weight: float = 0.3,
        vector_weight: float = 0.7,
        dataset_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Perform hybrid search combining vector similarity and full-text search"""

        vector_str = '[' + ','.join(map(str, query_vector)) + ']'

        # Delegate to the enhanced_hybrid_search_chunks function from the
        # BionicGPT integration, which combines vector and text ranking in SQL
        query = """
            SELECT
                id,
                document_id,
                content,
                similarity_score,
                text_rank,
                combined_score,
                metadata,
                access_verified
            FROM enhanced_hybrid_search_chunks($1, $2::vector, $3::uuid, $4, $5, $6, $7, $8)
        """

        return await self.execute_query(
            query,
            query_text,
            vector_str,
            user_id,
            dataset_id,
            limit,
            similarity_threshold,
            text_weight,
            vector_weight
        )
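
    # Usage sketch (illustrative only): the weights bias the combined ranking
    # toward semantic similarity (vector_weight) or keyword match (text_weight)
    # and are passed straight through to the SQL function above.
    #
    #     results = await client.hybrid_search(
    #         query_text="quarterly revenue forecast",
    #         query_vector=embedding,
    #         user_id=str(current_user_id),  # hypothetical caller-side variable
    #         text_weight=0.4,
    #         vector_weight=0.6,
    #     )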

    async def insert_document_chunk(
        self,
        document_id: str,
        tenant_id: int,
        user_id: str,
        chunk_index: int,
        content: str,
        content_hash: str,
        embedding: List[float],
        dataset_id: Optional[str] = None,
        token_count: int = 0,
        metadata: Optional[Dict] = None
    ) -> str:
        """Insert a document chunk with vector embedding"""

        vector_str = '[' + ','.join(map(str, embedding)) + ']'
        metadata_json = json.dumps(metadata or {})

        query = """
            INSERT INTO document_chunks (
                document_id, tenant_id, user_id, dataset_id, chunk_index,
                content, content_hash, token_count, embedding, metadata
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10::jsonb)
            RETURNING id
        """

        return await self.fetch_scalar(
            query,
            document_id, tenant_id, user_id, dataset_id, chunk_index,
            content, content_hash, token_count, vector_str, metadata_json
        )
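
    # Usage sketch (illustrative only; SHA-256 content hashing is an
    # assumption, not dictated by this file):
    #
    #     import hashlib
    #     chunk_id = await client.insert_document_chunk(
    #         document_id=doc_id,
    #         tenant_id=1,
    #         user_id=str(user_id),
    #         chunk_index=0,
    #         content=text,
    #         content_hash=hashlib.sha256(text.encode()).hexdigest(),
    #         embedding=embedding,
    #         token_count=token_count,
    #     )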

    # Health Check and Statistics

    async def health_check(self) -> Dict[str, Any]:
        """Perform health check on PostgreSQL connection"""
        try:
            if not self._pool:
                return {"status": "unhealthy", "reason": "Connection pool not initialized"}

            # Test basic connectivity
            test_result = await self.fetch_scalar("SELECT 1")

            # Get pool statistics
            pool_stats = {
                "size": self._pool.get_size(),
                "min_size": self._pool.get_min_size(),
                "max_size": self._pool.get_max_size(),
                "idle_size": self._pool.get_idle_size()
            }

            # Test schema access
            schema_test = await self.fetch_scalar("""
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.schemata
                    WHERE schema_name = $1
                )
            """, self.schema_name)

            return {
                "status": "healthy" if test_result == 1 and schema_test else "degraded",
                "connectivity": "ok" if test_result == 1 else "failed",
                "schema_access": "ok" if schema_test else "failed",
                "tenant_domain": self.tenant_domain,
                "schema_name": self.schema_name,
                "pool_stats": pool_stats,
                "database_type": "postgresql_pgvector"
            }
        except Exception as e:
            logger.error(f"PostgreSQL health check failed: {e}")
            return {"status": "unhealthy", "reason": str(e)}

    async def get_database_stats(self) -> Dict[str, Any]:
        """Get database statistics for monitoring"""
        try:
            # Get table counts and sizes
            stats_query = """
                SELECT
                    schemaname,
                    tablename,
                    n_tup_ins as inserts,
                    n_tup_upd as updates,
                    n_tup_del as deletes,
                    n_live_tup as live_tuples,
                    n_dead_tup as dead_tuples
                FROM pg_stat_user_tables
                WHERE schemaname = $1
            """

            table_stats = await self.execute_query(stats_query, self.schema_name)

            # Get total schema size
            size_query = """
                SELECT pg_size_pretty(
                    SUM(pg_total_relation_size(quote_ident(schemaname)||'.'||quote_ident(tablename)))
                ) as schema_size
                FROM pg_tables
                WHERE schemaname = $1
            """

            schema_size = await self.fetch_scalar(size_query, self.schema_name)

            # Get vector index statistics if available
            vector_stats_query = """
                SELECT
                    COUNT(*) as vector_count,
                    AVG(vector_dims(embedding)) as avg_dimensions
                FROM document_chunks
                WHERE embedding IS NOT NULL
            """

            try:
                vector_stats = await self.fetch_one(vector_stats_query)
            except PostgresError:
                # Table or pgvector extension may be missing; report empty stats
                vector_stats = {"vector_count": 0, "avg_dimensions": 0}

            return {
                "tenant_domain": self.tenant_domain,
                "schema_name": self.schema_name,
                "schema_size": schema_size,
                "table_stats": table_stats,
                "vector_stats": vector_stats,
                "engine_type": "PostgreSQL + PGVector",
                "mvcc_enabled": True,
                "timestamp": datetime.utcnow().isoformat()
            }

        except Exception as e:
            logger.error(f"Failed to get database statistics: {e}")
            return {"error": str(e)}


# Global client instance (singleton pattern for tenant backend)
_pg_client: Optional[PostgreSQLClient] = None


async def get_postgresql_client() -> PostgreSQLClient:
    """Get or create PostgreSQL client instance"""
    global _pg_client

    if not _pg_client:
        settings = get_settings()
        _pg_client = PostgreSQLClient(
            database_url=settings.database_url,
            tenant_domain=settings.tenant_domain
        )
        await _pg_client.initialize()

    return _pg_client


async def init_postgresql() -> None:
    """Initialize PostgreSQL client during startup"""
    logger.info("Initializing PostgreSQL client...")
    await get_postgresql_client()
    logger.info("PostgreSQL client initialized successfully")


async def close_postgresql() -> None:
    """Close PostgreSQL client during shutdown"""
    global _pg_client

    if _pg_client:
        await _pg_client.close()
        _pg_client = None
        logger.info("PostgreSQL client closed")


# Context manager for database operations
@asynccontextmanager
async def get_db_session():
    """Async context manager for database operations"""
    client = await get_postgresql_client()
    async with client.get_connection() as conn:
        yield conn


# Convenience functions for common operations
async def execute_query(query: str, *args) -> List[Dict[str, Any]]:
    """Execute a SELECT query"""
    client = await get_postgresql_client()
    return await client.execute_query(query, *args)


async def execute_command(command: str, *args) -> int:
    """Execute an INSERT/UPDATE/DELETE command"""
    client = await get_postgresql_client()
    return await client.execute_command(command, *args)


async def fetch_one(query: str, *args) -> Optional[Dict[str, Any]]:
    """Execute query and return first row"""
    client = await get_postgresql_client()
    return await client.fetch_one(query, *args)


async def fetch_scalar(query: str, *args) -> Any:
    """Execute query and return single value"""
    client = await get_postgresql_client()
    return await client.fetch_scalar(query, *args)


async def health_check() -> Dict[str, Any]:
    """Perform database health check"""
    try:
        client = await get_postgresql_client()
        return await client.health_check()
    except Exception as e:
        return {"status": "unhealthy", "reason": str(e)}


async def get_database_info() -> Dict[str, Any]:
    """Get database information and statistics"""
    try:
        client = await get_postgresql_client()
        return await client.get_database_stats()
    except Exception as e:
        return {"error": str(e)}
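
Typical wiring for the module-level lifecycle helpers above, as a minimal sketch; the FastAPI lifespan integration shown is an assumption for illustration, not part of this commit:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.postgresql_client import close_postgresql, health_check, init_postgresql


@asynccontextmanager
async def lifespan(app: FastAPI):
    await init_postgresql()   # create the pool and verify the tenant schema
    yield
    await close_postgresql()  # drain the pool on shutdown


app = FastAPI(lifespan=lifespan)


@app.get("/health/db")
async def db_health():
    # Returns the status/pool_stats payload built by PostgreSQLClient.health_check()
    return await health_check()
```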