"""
Connection manager for WebSocket OCPP connections
"""
import asyncio
import logging
import websockets
from typing import Dict, Optional, Set
from datetime import datetime, timezone

# Module-level logger named after this module, so its level and handlers can be
# configured independently through the standard logging hierarchy.
logger = logging.getLogger(__name__)


class ConnectionManager:
    """Manages active WebSocket connections to charge points.

    Keeps at most one connection per charge point ID in ``connections``, plus
    per-connection metadata in ``connection_info``. Registering a new
    connection for an ID that is already connected replaces the old socket
    and closes it.

    Changes to the two registries happen under ``self._lock``. Network I/O on
    individual sockets (send, ping, close) is awaited *outside* that lock.
    This way one slow or stalled peer cannot block registration, sends, or
    pings for every other charge point.
    """

    def __init__(self):
        # charge_point_id -> the websocket currently registered for it
        self.connections: Dict[str, websockets.WebSocketServerProtocol] = {}
        # charge_point_id -> metadata dict (connected_at, remote_address,
        # subprotocol, last_seen, message_count); same keys as `connections`
        self.connection_info: Dict[str, Dict] = {}
        # Serializes mutations of the two dicts above
        self._lock = asyncio.Lock()
        self.start_time = datetime.now(timezone.utc)
        # Number of registrations performed since this manager was created
        self.total_connections_ever = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _pop_connection(self, charge_point_id, websocket=None):
        """Remove and return the socket registered for ``charge_point_id``.

        If ``websocket`` is given, the entry is removed only when it is that
        exact object. This stops a stale caller from evicting a newer
        connection that replaced it.

        Returns the removed socket, or ``None`` if nothing was removed. Never
        awaits; callers should hold ``self._lock``.
        """
        current = self.connections.get(charge_point_id)
        if current is None or (websocket is not None and current is not websocket):
            return None
        del self.connections[charge_point_id]
        self.connection_info.pop(charge_point_id, None)
        return current

    @staticmethod
    async def _close_websocket(charge_point_id, websocket):
        """Close ``websocket``, logging (not raising) any error."""
        try:
            await websocket.close()
        except Exception as e:
            # Best effort: the peer may already be gone.
            logger.debug("Error closing websocket for %s: %s", charge_point_id, e)

    async def _close_connection(self, charge_point_id: str):
        """Remove and close the connection for ``charge_point_id`` (internal).

        Kept for backward compatibility: pops the registry entry (if any) and
        closes that socket.
        """
        websocket = self._pop_connection(charge_point_id)
        if websocket is not None:
            await self._close_websocket(charge_point_id, websocket)

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    async def register_connection(self, charge_point_id: str, websocket: websockets.WebSocketServerProtocol):
        """Register ``websocket`` as the active connection for ``charge_point_id``.

        If a *different* socket was already registered under the same ID, it
        is replaced and then closed. The close happens after the lock is
        released, so its closing handshake does not stall other callers.
        Re-registering the socket that is already active does not close it.
        """
        # Build metadata before touching the registry, so a failure reading
        # socket attributes leaves the registry unchanged.
        now = datetime.now(timezone.utc).isoformat()
        info = {
            "connected_at": now,
            "remote_address": websocket.remote_address,
            "subprotocol": websocket.subprotocol,
            "last_seen": now,
            "message_count": 0,
        }

        async with self._lock:
            previous = self.connections.get(charge_point_id)
            if previous is websocket:
                previous = None  # same socket re-registered: nothing to close
            if previous is not None:
                logger.warning("Closing existing connection for %s", charge_point_id)

            self.connections[charge_point_id] = websocket
            self.connection_info[charge_point_id] = info
            self.total_connections_ever += 1

            logger.info("Registered connection for charge point: %s", charge_point_id)
            logger.info("Active connections: %d", len(self.connections))

        if previous is not None:
            await self._close_websocket(charge_point_id, previous)

    async def unregister_connection(
        self,
        charge_point_id: str,
        websocket: Optional[websockets.WebSocketServerProtocol] = None,
    ):
        """Remove and close the connection registered for ``charge_point_id``.

        Args:
            charge_point_id: ID of the charge point to unregister.
            websocket: Optional. When given, unregister only if this exact
                socket is still the registered one. A connection handler
                should pass its own socket here. Otherwise, once a newer
                connection for the same ID has replaced it, the handler's
                cleanup would tear down that replacement.
        """
        async with self._lock:
            removed = self._pop_connection(charge_point_id, websocket)
            if removed is None:
                return
            remaining = len(self.connections)

        await self._close_websocket(charge_point_id, removed)
        logger.info("Unregistered connection for charge point: %s", charge_point_id)
        logger.info("Active connections: %d", remaining)

    # ------------------------------------------------------------------
    # Messaging
    # ------------------------------------------------------------------

    async def send_message(self, charge_point_id: str, message: str) -> bool:
        """Send ``message`` to one charge point.

        Returns ``True`` on success, ``False`` if the charge point is not
        connected or the send failed. On ``ConnectionClosed``, the socket is
        dropped from the registry if it is still the registered one.
        """
        async with self._lock:
            websocket = self.connections.get(charge_point_id)

        if websocket is None:
            logger.warning("No active connection for charge point: %s", charge_point_id)
            return False

        # The send happens outside the lock so a slow peer cannot block others.
        try:
            await websocket.send(message)
        except websockets.exceptions.ConnectionClosed:
            logger.warning("Connection closed for %s", charge_point_id)
            async with self._lock:
                # A new connection may have been registered while we were sending;
                # only drop the socket that actually failed.
                stale = self._pop_connection(charge_point_id, websocket)
            if stale is not None:
                await self._close_websocket(charge_point_id, stale)
            return False
        except Exception as e:
            logger.error("Error sending message to %s: %s", charge_point_id, e)
            return False

        # Update bookkeeping only if this socket is still the registered one.
        if self.connections.get(charge_point_id) is websocket:
            info = self.connection_info.get(charge_point_id)
            if info is not None:
                info["last_seen"] = datetime.now(timezone.utc).isoformat()
                info["message_count"] += 1

        logger.debug("Sent message to %s: %s", charge_point_id, message)
        return True

    async def broadcast_message(self, message: str, exclude: Optional[Set[str]] = None) -> int:
        """Send ``message`` concurrently to every connected charge point.

        Args:
            message: Payload to send.
            exclude: Charge point IDs to skip.

        Returns:
            Number of charge points the message was sent to successfully.
        """
        exclude = exclude or set()
        # Snapshot the IDs so registry changes during the sends are harmless.
        targets = [cp_id for cp_id in list(self.connections) if cp_id not in exclude]

        results = await asyncio.gather(
            *(self.send_message(cp_id, message) for cp_id in targets),
            return_exceptions=True,
        )
        successful_sends = sum(1 for result in results if result is True)

        logger.info("Broadcast message to %d charge points", successful_sends)
        return successful_sends

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def is_connected(self, charge_point_id: str) -> bool:
        """Return whether a connection is registered for ``charge_point_id``."""
        return charge_point_id in self.connections

    def get_connected_charge_points(self) -> Dict[str, Dict]:
        """Return ``{charge_point_id: info}`` for every registered connection.

        Each info dict is copied, so callers cannot mutate internal state.
        """
        return {cp_id: dict(info) for cp_id, info in self.connection_info.items()}

    def get_connection_count(self) -> int:
        """Return the number of currently registered connections."""
        return len(self.connections)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def close_all_connections(self):
        """Unregister every connection, then close the sockets concurrently."""
        logger.info("Closing all WebSocket connections...")

        async with self._lock:
            doomed = list(self.connections.items())
            self.connections.clear()
            self.connection_info.clear()

        # Closing handshakes run outside the lock and in parallel.
        await asyncio.gather(
            *(self._close_websocket(cp_id, ws) for cp_id, ws in doomed)
        )

        logger.info("All connections closed")

    async def ping_all_connections(self, timeout: float = 10) -> Dict[str, bool]:
        """Ping every registered connection concurrently and drop dead ones.

        Args:
            timeout: Seconds to wait for each pong (defaults to 10).

        Returns:
            Mapping of charge point ID to ``True`` if its pong arrived in time,
            otherwise ``False``. A connection whose ping fails is unregistered,
            but only if that same socket is still the registered one.
        """
        async with self._lock:
            snapshot = list(self.connections.items())

        if not snapshot:
            return {}

        async def ping_one(charge_point_id, websocket):
            try:
                pong_waiter = await websocket.ping()
                await asyncio.wait_for(pong_waiter, timeout=timeout)
            except Exception as e:
                logger.warning("Ping failed for %s: %s", charge_point_id, e)
                # Identity-checked removal: never evict a replacement connection.
                await self.unregister_connection(charge_point_id, websocket)
                return False
            logger.debug("Ping successful for %s", charge_point_id)
            return True

        outcomes = await asyncio.gather(
            *(ping_one(cp_id, ws) for cp_id, ws in snapshot)
        )
        return {cp_id: ok for (cp_id, _), ok in zip(snapshot, outcomes)}

    def get_total_connections(self) -> int:
        """Return how many registrations have happened since this manager was created."""
        return self.total_connections_ever

    def get_uptime(self) -> str:
        """Return time elapsed since this manager was created, e.g. ``"1d 2h 3m 4s"``.

        Leading zero units are omitted (``"5m 7s"``, ``"42s"``).
        """
        uptime_delta = datetime.now(timezone.utc) - self.start_time
        days = uptime_delta.days
        hours, remainder = divmod(uptime_delta.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        if days > 0:
            return f"{days}d {hours}h {minutes}m {seconds}s"
        elif hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        elif minutes > 0:
            return f"{minutes}m {seconds}s"
        else:
            return f"{seconds}s"
