Files
gt-ai-os-community/apps/control-panel-backend/app/services/groq_service.py
HackWeasel 310491a557 GT AI OS Community v2.0.33 - Add NVIDIA NIM and Nemotron agents
- Updated python_coding_microproject.csv to use NVIDIA NIM Kimi K2
- Updated kali_linux_shell_simulator.csv to use NVIDIA NIM Kimi K2
  - Made more general-purpose (flexible targets, expanded tools)
- Added nemotron-mini-agent.csv for fast local inference via Ollama
- Added nemotron-agent.csv for advanced reasoning via Ollama
- Added wiki page: Projects for NVIDIA NIMs and Nemotron
2025-12-12 17:47:14 -05:00

307 lines
12 KiB
Python

"""
Groq LLM integration service with high availability and failover support
"""
import asyncio
import time
from typing import Dict, Any, List, Optional, AsyncGenerator
from datetime import datetime, timedelta
import httpx
import json
import logging
from contextlib import asynccontextmanager
from app.models.ai_resource import AIResource
from app.models.usage import UsageRecord
logger = logging.getLogger(__name__)
class GroqAPIError(Exception):
"""Custom exception for Groq API errors"""
def __init__(self, message: str, status_code: Optional[int] = None, response_body: Optional[str] = None):
self.message = message
self.status_code = status_code
self.response_body = response_body
super().__init__(self.message)
class GroqClient:
"""High-availability Groq API client with automatic failover"""
def __init__(self, resource: AIResource, api_key: str):
self.resource = resource
self.api_key = api_key
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
)
self._current_endpoint_index = 0
self._endpoint_failures = {}
self._rate_limit_reset = None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.client.aclose()
def _get_next_endpoint(self) -> Optional[str]:
"""Get next available endpoint with circuit breaker logic"""
endpoints = self.resource.get_available_endpoints()
if not endpoints:
return None
# Try current endpoint first if not in failure state
current_endpoint = endpoints[self._current_endpoint_index % len(endpoints)]
failure_info = self._endpoint_failures.get(current_endpoint)
if not failure_info or failure_info["reset_time"] < datetime.utcnow():
return current_endpoint
# Find next healthy endpoint
for i in range(len(endpoints)):
endpoint = endpoints[(self._current_endpoint_index + i + 1) % len(endpoints)]
failure_info = self._endpoint_failures.get(endpoint)
if not failure_info or failure_info["reset_time"] < datetime.utcnow():
self._current_endpoint_index = (self._current_endpoint_index + i + 1) % len(endpoints)
return endpoint
return None
def _mark_endpoint_failed(self, endpoint: str, backoff_minutes: int = 5):
"""Mark endpoint as failed with exponential backoff"""
current_failures = self._endpoint_failures.get(endpoint, {"count": 0})
current_failures["count"] += 1
# Exponential backoff: 5min, 10min, 20min, 40min, max 60min
backoff_time = min(backoff_minutes * (2 ** (current_failures["count"] - 1)), 60)
current_failures["reset_time"] = datetime.utcnow() + timedelta(minutes=backoff_time)
self._endpoint_failures[endpoint] = current_failures
logger.warning(f"Marked endpoint {endpoint} as failed for {backoff_time} minutes (failure #{current_failures['count']})")
def _reset_endpoint_failures(self, endpoint: str):
"""Reset failure count for successful endpoint"""
if endpoint in self._endpoint_failures:
del self._endpoint_failures[endpoint]
async def _make_request(self, method: str, path: str, **kwargs) -> Dict[str, Any]:
"""Make HTTP request with automatic failover"""
last_error = None
for attempt in range(len(self.resource.get_available_endpoints()) + 1):
endpoint = self._get_next_endpoint()
if not endpoint:
raise GroqAPIError("No healthy endpoints available")
url = f"{endpoint.rstrip('/')}/{path.lstrip('/')}"
try:
logger.debug(f"Making {method} request to {url}")
response = await self.client.request(method, url, **kwargs)
# Handle rate limiting
if response.status_code == 429:
retry_after = int(response.headers.get("retry-after", "60"))
self._rate_limit_reset = datetime.utcnow() + timedelta(seconds=retry_after)
raise GroqAPIError(f"Rate limited, retry after {retry_after} seconds", 429)
# Handle server errors with failover
if response.status_code >= 500:
self._mark_endpoint_failed(endpoint)
last_error = GroqAPIError(f"Server error: {response.status_code}", response.status_code, response.text)
continue
# Handle client errors (don't retry)
if response.status_code >= 400:
raise GroqAPIError(f"Client error: {response.status_code}", response.status_code, response.text)
# Success - reset failures for this endpoint
self._reset_endpoint_failures(endpoint)
return response.json()
except httpx.RequestError as e:
logger.warning(f"Request failed for endpoint {endpoint}: {e}")
self._mark_endpoint_failed(endpoint)
last_error = GroqAPIError(f"Request failed: {str(e)}")
continue
# All endpoints failed
raise last_error or GroqAPIError("All endpoints failed")
async def health_check(self) -> bool:
"""Check if the Groq API is healthy"""
try:
await self._make_request("GET", "models")
return True
except Exception as e:
logger.error(f"Health check failed: {e}")
return False
async def list_models(self) -> List[Dict[str, Any]]:
"""List available models"""
response = await self._make_request("GET", "models")
return response.get("data", [])
async def chat_completion(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
stream: bool = False,
**kwargs
) -> Dict[str, Any]:
"""Create chat completion"""
config = self.resource.merge_config(kwargs)
payload = {
"model": model or self.resource.model_name,
"messages": messages,
"stream": stream,
**config
}
# Remove None values
payload = {k: v for k, v in payload.items() if v is not None}
start_time = time.time()
response = await self._make_request("POST", "chat/completions", json=payload)
latency_ms = int((time.time() - start_time) * 1000)
# Log performance metrics
if latency_ms > self.resource.latency_sla_ms:
logger.warning(f"Request exceeded SLA: {latency_ms}ms > {self.resource.latency_sla_ms}ms")
return {
**response,
"_metadata": {
"latency_ms": latency_ms,
"model_used": payload["model"],
"endpoint_used": self._get_next_endpoint()
}
}
async def chat_completion_stream(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
**kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
"""Create streaming chat completion"""
config = self.resource.merge_config(kwargs)
payload = {
"model": model or self.resource.model_name,
"messages": messages,
"stream": True,
**config
}
# Remove None values
payload = {k: v for k, v in payload.items() if v is not None}
endpoint = self._get_next_endpoint()
if not endpoint:
raise GroqAPIError("No healthy endpoints available")
url = f"{endpoint.rstrip('/')}/chat/completions"
async with self.client.stream("POST", url, json=payload) as response:
if response.status_code >= 400:
error_text = await response.aread()
raise GroqAPIError(f"Stream error: {response.status_code}", response.status_code, error_text.decode())
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:] # Remove "data: " prefix
if data.strip() == "[DONE]":
break
try:
yield json.loads(data)
except json.JSONDecodeError:
continue
class GroqService:
"""Service for managing Groq resources and API interactions"""
def __init__(self):
self._clients: Dict[int, GroqClient] = {}
@asynccontextmanager
async def get_client(self, resource: AIResource, api_key: str):
"""Get or create a Groq client for the resource"""
if resource.id not in self._clients:
self._clients[resource.id] = GroqClient(resource, api_key)
try:
yield self._clients[resource.id]
finally:
# Keep clients alive for reuse, cleanup handled separately
pass
async def health_check_resource(self, resource: AIResource, api_key: str) -> bool:
"""Perform health check on a Groq resource"""
try:
async with self.get_client(resource, api_key) as client:
is_healthy = await client.health_check()
resource.update_health_status("healthy" if is_healthy else "unhealthy")
return is_healthy
except Exception as e:
logger.error(f"Health check failed for resource {resource.id}: {e}")
resource.update_health_status("unhealthy")
return False
async def chat_completion(
self,
resource: AIResource,
api_key: str,
messages: List[Dict[str, str]],
user_email: str,
tenant_id: int,
**kwargs
) -> Dict[str, Any]:
"""Create chat completion with usage tracking"""
async with self.get_client(resource, api_key) as client:
response = await client.chat_completion(messages, **kwargs)
# Extract usage information
usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", 0)
# Calculate cost
cost_cents = resource.calculate_cost(total_tokens)
# Create usage record (would be saved to database)
usage_record = {
"tenant_id": tenant_id,
"resource_id": resource.id,
"user_email": user_email,
"request_type": "chat_completion",
"tokens_used": total_tokens,
"cost_cents": cost_cents,
"model_used": response.get("_metadata", {}).get("model_used", resource.model_name),
"latency_ms": response.get("_metadata", {}).get("latency_ms", 0)
}
logger.info(f"Chat completion: {total_tokens} tokens, ${cost_cents/100:.4f} cost")
return {
**response,
"_usage_record": usage_record
}
async def cleanup_clients(self):
"""Cleanup inactive clients"""
for resource_id, client in list(self._clients.items()):
try:
await client.client.aclose()
except Exception:
pass
self._clients.clear()
# Global service instance
groq_service = GroqService()