"""
|
|
GT 2.0 PostgreSQL + PGVector Client for Tenant Backend
|
|
|
|
Replaces DuckDB service with direct PostgreSQL connections, providing:
|
|
- PostgreSQL + PGVector unified storage (replaces DuckDB + ChromaDB)
|
|
- BionicGPT Row Level Security patterns for enterprise isolation
|
|
- MVCC concurrency solving DuckDB file locking issues
|
|
- Hybrid vector + full-text search in single queries
|
|
- Connection pooling for 10,000+ concurrent connections
|
|
"""

import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any, AsyncGenerator, Tuple

import asyncpg
from asyncpg import Pool, Connection
from asyncpg.exceptions import PostgresError

from app.core.config import get_settings, get_tenant_schema_name

logger = logging.getLogger(__name__)


class PostgreSQLClient:
    """PostgreSQL + PGVector client for tenant backend operations"""

    def __init__(self, database_url: str, tenant_domain: str):
        self.database_url = database_url
        self.tenant_domain = tenant_domain
        self.schema_name = get_tenant_schema_name(tenant_domain)
        self._pool: Optional[Pool] = None
        self._initialized = False

    async def __aenter__(self):
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
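
    # Usage sketch (illustrative only; the DSN and tenant domain below are
    # placeholders): the async context-manager methods above let callers scope
    # the pool's lifetime to a block.
    #
    #     async with PostgreSQLClient("postgresql://user:pass@db/gt2", "acme.example.com") as client:
    #         rows = await client.execute_query("SELECT 1 AS ok")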

    async def initialize(self) -> None:
        """Initialize connection pool and verify schema"""
        if self._initialized:
            return

        logger.info(f"Initializing PostgreSQL connection pool for tenant: {self.tenant_domain}")
        # Log the schema only; the DSN may embed credentials and must not be logged
        logger.info(f"Schema: {self.schema_name}")

        try:
            # Create connection pool with resilient settings
            # Sized for 100+ concurrent users with RAG/vector search workloads
            self._pool = await asyncpg.create_pool(
                self.database_url,
                min_size=10,
                max_size=50,  # Increased from 20 to handle 100+ concurrent users
                command_timeout=120,  # Increased from 60s for queries under load
                timeout=10,  # Connection acquire timeout increased for high load
                max_inactive_connection_lifetime=3600,  # Recycle connections after 1 hour
                server_settings={
                    'application_name': f'gt2_tenant_{self.tenant_domain}'
                },
                # Enable prepared statements for direct postgres connection (performance gain)
                statement_cache_size=100
            )

            # Verify schema exists and has required tables
            await self._verify_schema()

            self._initialized = True
            logger.info(f"PostgreSQL client initialized successfully for tenant: {self.tenant_domain}")

        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL client: {e}")
            if self._pool:
                await self._pool.close()
                self._pool = None
            raise

    async def close(self) -> None:
        """Close connection pool"""
        if self._pool:
            await self._pool.close()
            self._pool = None
            self._initialized = False
            logger.info(f"PostgreSQL connection pool closed for tenant: {self.tenant_domain}")

    async def _verify_schema(self) -> None:
        """Verify tenant schema exists and has required tables"""
        async with self._pool.acquire() as conn:
            # Check if schema exists
            schema_exists = await conn.fetchval("""
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.schemata
                    WHERE schema_name = $1
                )
            """, self.schema_name)

            if not schema_exists:
                raise RuntimeError(f"Tenant schema '{self.schema_name}' does not exist. Run schema initialization first.")

            # Check for required tables
            required_tables = ['tenants', 'users', 'agents', 'datasets', 'conversations', 'messages', 'documents', 'document_chunks']

            for table in required_tables:
                table_exists = await conn.fetchval("""
                    SELECT EXISTS (
                        SELECT 1 FROM information_schema.tables
                        WHERE table_schema = $1 AND table_name = $2
                    )
                """, self.schema_name, table)

                if not table_exists:
                    logger.warning(f"Table '{table}' not found in schema '{self.schema_name}'")

        logger.info(f"Schema verification complete for tenant: {self.tenant_domain}")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator[Connection, None]:
        """Get a connection from the pool"""
        if not self._pool:
            raise RuntimeError("PostgreSQL client not initialized. Call initialize() first.")

        async with self._pool.acquire() as conn:
            try:
                # Set schema search path for this connection; schema_name is
                # derived from trusted config, never from user input
                await conn.execute(f"SET search_path TO {self.schema_name}, public")

                # Session variable logging removed - no longer using RLS

                yield conn
            except Exception as e:
                logger.error(f"Database connection error: {e}")
                raise

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Execute a SELECT query and return results"""
        async with self.get_connection() as conn:
            try:
                rows = await conn.fetch(query, *args)
                return [dict(row) for row in rows]
            except PostgresError as e:
                logger.error(f"Query execution failed: {e}, Query: {query}")
                raise

    async def execute_command(self, command: str, *args) -> int:
        """Execute an INSERT/UPDATE/DELETE command and return affected rows"""
        async with self.get_connection() as conn:
            try:
                result = await conn.execute(command, *args)
                # Parse a status tag like "INSERT 0 5" or "UPDATE 3" to get affected rows
                return int(result.split()[-1]) if result else 0
            except PostgresError as e:
                logger.error(f"Command execution failed: {e}, Command: {command}")
                raise

    async def fetch_one(self, query: str, *args) -> Optional[Dict[str, Any]]:
        """Execute query and return first row"""
        async with self.get_connection() as conn:
            try:
                row = await conn.fetchrow(query, *args)
                return dict(row) if row else None
            except PostgresError as e:
                logger.error(f"Fetch one failed: {e}, Query: {query}")
                raise

    async def fetch_scalar(self, query: str, *args) -> Any:
        """Execute query and return single value"""
        async with self.get_connection() as conn:
            try:
                return await conn.fetchval(query, *args)
            except PostgresError as e:
                logger.error(f"Fetch scalar failed: {e}, Query: {query}")
                raise
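
    # Illustrative calls for the read helpers above ($1-style placeholders are
    # asyncpg positional parameters; the agents table is one of the required
    # tables, but the column names here are assumptions):
    #
    #     rows = await client.execute_query("SELECT id, name FROM agents WHERE user_id = $1", user_id)
    #     agent = await client.fetch_one("SELECT * FROM agents WHERE id = $1", agent_id)
    #     total = await client.fetch_scalar("SELECT COUNT(*) FROM agents")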

    async def execute_transaction(self, commands: List[Tuple[str, tuple]]) -> List[int]:
        """Execute multiple commands in a transaction"""
        async with self.get_connection() as conn:
            async with conn.transaction():
                results = []
                for command, args in commands:
                    try:
                        result = await conn.execute(command, *args)
                        results.append(int(result.split()[-1]) if result else 0)
                    except PostgresError as e:
                        logger.error(f"Transaction command failed: {e}, Command: {command}")
                        raise

                return results
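
    # Sketch of a transactional batch (the commands and columns are
    # illustrative): both statements commit together or not at all, because
    # they run inside a single conn.transaction() block.
    #
    #     affected = await client.execute_transaction([
    #         ("UPDATE documents SET status = $1 WHERE id = $2", ("processed", doc_id)),
    #         ("DELETE FROM document_chunks WHERE document_id = $1", (doc_id,)),
    #     ])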

    # Vector Search Operations (PGVector)

    async def vector_similarity_search(
        self,
        query_vector: List[float],
        table: str = "document_chunks",
        limit: int = 10,
        similarity_threshold: float = 0.3,
        user_id: Optional[str] = None,
        dataset_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Perform vector similarity search using PGVector.

        Note: `table` is interpolated into the SQL directly, so it must come
        from trusted code, never from user input.
        """

        # Render the Python list as a pgvector literal, e.g. '[0.1,0.2,...]'
        vector_str = '[' + ','.join(map(str, query_vector)) + ']'

        query = f"""
            SELECT
                id,
                content,
                1 - (embedding <=> $1::vector) as similarity_score,
                metadata
            FROM {table}
            WHERE embedding IS NOT NULL
            AND 1 - (embedding <=> $1::vector) > $2
        """

        params = [vector_str, similarity_threshold]
        param_idx = 3

        # Add user isolation if specified
        if user_id:
            query += f" AND user_id = ${param_idx}"
            params.append(user_id)
            param_idx += 1

        # Add dataset filtering if specified
        if dataset_id:
            query += f" AND dataset_id = ${param_idx}"
            params.append(dataset_id)
            param_idx += 1

        query += f" ORDER BY embedding <=> $1::vector LIMIT ${param_idx}"
        params.append(limit)

        return await self.execute_query(query, *params)
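
    # Example call (the embedding and IDs are placeholders; the vector must
    # match the dimensionality of the embedding column):
    #
    #     hits = await client.vector_similarity_search(
    #         query_vector=embedding,            # List[float] from an embedding model
    #         limit=5,
    #         similarity_threshold=0.5,
    #         user_id=str(current_user_id),      # optional per-user isolation
    #     )
    #     # Each hit carries id, content, similarity_score, and metadata.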

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: List[float],
        user_id: str,
        limit: int = 10,
        similarity_threshold: float = 0.3,
        text_weight: float = 0.3,
        vector_weight: float = 0.7,
        dataset_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Perform hybrid search combining vector similarity and full-text search"""

        vector_str = '[' + ','.join(map(str, query_vector)) + ']'

        # Use the enhanced_hybrid_search_chunks function from BionicGPT integration
        query = """
            SELECT
                id,
                document_id,
                content,
                similarity_score,
                text_rank,
                combined_score,
                metadata,
                access_verified
            FROM enhanced_hybrid_search_chunks($1, $2::vector, $3::uuid, $4, $5, $6, $7, $8)
        """

        return await self.execute_query(
            query,
            query_text,
            vector_str,
            user_id,
            dataset_id,
            limit,
            similarity_threshold,
            text_weight,
            vector_weight
        )
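
    # Example call weighting relevance 70% vector / 30% keyword (the defaults
    # above); this assumes the enhanced_hybrid_search_chunks SQL function has
    # been installed in the tenant schema:
    #
    #     results = await client.hybrid_search(
    #         query_text="quarterly revenue",
    #         query_vector=embedding,
    #         user_id=str(current_user_id),
    #         dataset_id=None,                   # None searches across all datasets
    #     )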

    async def insert_document_chunk(
        self,
        document_id: str,
        tenant_id: int,
        user_id: str,
        chunk_index: int,
        content: str,
        content_hash: str,
        embedding: List[float],
        dataset_id: Optional[str] = None,
        token_count: int = 0,
        metadata: Optional[Dict] = None
    ) -> str:
        """Insert a document chunk with vector embedding"""

        vector_str = '[' + ','.join(map(str, embedding)) + ']'
        metadata_json = json.dumps(metadata or {})

        query = """
            INSERT INTO document_chunks (
                document_id, tenant_id, user_id, dataset_id, chunk_index,
                content, content_hash, token_count, embedding, metadata
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10::jsonb)
            RETURNING id
        """

        return await self.fetch_scalar(
            query,
            document_id, tenant_id, user_id, dataset_id, chunk_index,
            content, content_hash, token_count, vector_str, metadata_json
        )
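
    # Ingestion sketch (values are illustrative; the caller is responsible for
    # chunking, hashing, and embedding the text):
    #
    #     import hashlib
    #     chunk_id = await client.insert_document_chunk(
    #         document_id=doc_id,
    #         tenant_id=1,
    #         user_id=str(current_user_id),
    #         chunk_index=0,
    #         content=chunk_text,
    #         content_hash=hashlib.sha256(chunk_text.encode()).hexdigest(),
    #         embedding=embedding,
    #         token_count=token_count,
    #     )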

    # Health Check and Statistics

    async def health_check(self) -> Dict[str, Any]:
        """Perform health check on PostgreSQL connection"""
        try:
            if not self._pool:
                return {"status": "unhealthy", "reason": "Connection pool not initialized"}

            # Test basic connectivity
            test_result = await self.fetch_scalar("SELECT 1")

            # Get pool statistics
            pool_stats = {
                "size": self._pool.get_size(),
                "min_size": self._pool.get_min_size(),
                "max_size": self._pool.get_max_size(),
                "idle_size": self._pool.get_idle_size()
            }

            # Test schema access
            schema_test = await self.fetch_scalar("""
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.schemata
                    WHERE schema_name = $1
                )
            """, self.schema_name)

            return {
                "status": "healthy" if test_result == 1 and schema_test else "degraded",
                "connectivity": "ok" if test_result == 1 else "failed",
                "schema_access": "ok" if schema_test else "failed",
                "tenant_domain": self.tenant_domain,
                "schema_name": self.schema_name,
                "pool_stats": pool_stats,
                "database_type": "postgresql_pgvector"
            }
        except Exception as e:
            logger.error(f"PostgreSQL health check failed: {e}")
            return {"status": "unhealthy", "reason": str(e)}

    async def get_database_stats(self) -> Dict[str, Any]:
        """Get database statistics for monitoring"""
        try:
            # Get per-table tuple activity
            stats_query = """
                SELECT
                    schemaname,
                    tablename,
                    n_tup_ins as inserts,
                    n_tup_upd as updates,
                    n_tup_del as deletes,
                    n_live_tup as live_tuples,
                    n_dead_tup as dead_tuples
                FROM pg_stat_user_tables
                WHERE schemaname = $1
            """

            table_stats = await self.execute_query(stats_query, self.schema_name)

            # Get total schema size
            size_query = """
                SELECT pg_size_pretty(
                    SUM(pg_total_relation_size(quote_ident(schemaname)||'.'||quote_ident(tablename)))
                ) as schema_size
                FROM pg_tables
                WHERE schemaname = $1
            """

            schema_size = await self.fetch_scalar(size_query, self.schema_name)

            # Get vector statistics if available
            vector_stats_query = """
                SELECT
                    COUNT(*) as vector_count,
                    AVG(vector_dims(embedding)) as avg_dimensions
                FROM document_chunks
                WHERE embedding IS NOT NULL
            """

            try:
                vector_stats = await self.fetch_one(vector_stats_query)
            except PostgresError:
                # document_chunks may not exist yet in a fresh schema
                vector_stats = {"vector_count": 0, "avg_dimensions": 0}

            return {
                "tenant_domain": self.tenant_domain,
                "schema_name": self.schema_name,
                "schema_size": schema_size,
                "table_stats": table_stats,
                "vector_stats": vector_stats,
                "engine_type": "PostgreSQL + PGVector",
                "mvcc_enabled": True,
                "timestamp": datetime.now(timezone.utc).isoformat()
            }

        except Exception as e:
            logger.error(f"Failed to get database statistics: {e}")
            return {"error": str(e)}


# Global client instance (singleton pattern for tenant backend)
_pg_client: Optional[PostgreSQLClient] = None


async def get_postgresql_client() -> PostgreSQLClient:
    """Get or create PostgreSQL client instance"""
    global _pg_client

    if not _pg_client:
        settings = get_settings()
        _pg_client = PostgreSQLClient(
            database_url=settings.database_url,
            tenant_domain=settings.tenant_domain
        )
        await _pg_client.initialize()

    return _pg_client


async def init_postgresql() -> None:
    """Initialize PostgreSQL client during startup"""
    logger.info("Initializing PostgreSQL client...")
    await get_postgresql_client()
    logger.info("PostgreSQL client initialized successfully")


async def close_postgresql() -> None:
    """Close PostgreSQL client during shutdown"""
    global _pg_client

    if _pg_client:
        await _pg_client.close()
        _pg_client = None
        logger.info("PostgreSQL client closed")


# Context manager for database operations
@asynccontextmanager
async def get_db_session():
    """Async context manager for database operations"""
    client = await get_postgresql_client()
    async with client.get_connection() as conn:
        yield conn
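

# Sketch of a raw asyncpg call through the session context manager (the query
# is illustrative; conn is a plain asyncpg.Connection with search_path set):
#
#     async with get_db_session() as conn:
#         doc_count = await conn.fetchval("SELECT COUNT(*) FROM documents")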


# Convenience functions for common operations
async def execute_query(query: str, *args) -> List[Dict[str, Any]]:
    """Execute a SELECT query"""
    client = await get_postgresql_client()
    return await client.execute_query(query, *args)


async def execute_command(command: str, *args) -> int:
    """Execute an INSERT/UPDATE/DELETE command"""
    client = await get_postgresql_client()
    return await client.execute_command(command, *args)


async def fetch_one(query: str, *args) -> Optional[Dict[str, Any]]:
    """Execute query and return first row"""
    client = await get_postgresql_client()
    return await client.fetch_one(query, *args)


async def fetch_scalar(query: str, *args) -> Any:
    """Execute query and return single value"""
    client = await get_postgresql_client()
    return await client.fetch_scalar(query, *args)


async def health_check() -> Dict[str, Any]:
    """Perform database health check"""
    try:
        client = await get_postgresql_client()
        return await client.health_check()
    except Exception as e:
        return {"status": "unhealthy", "reason": str(e)}


async def get_database_info() -> Dict[str, Any]:
    """Get database information and statistics"""
    try:
        client = await get_postgresql_client()
        return await client.get_database_stats()
    except Exception as e:
        return {"error": str(e)}