"""
WebSocket OCPP Server - /ws/evse endpoint
Handles multiple concurrent EVSE connections with token-based authentication
"""
import asyncio
import json
import logging
import websockets
from websockets.server import WebSocketServerProtocol
from typing import Dict, Set, Optional, Any
from datetime import datetime
import urllib.parse
import uuid

from .config import get_config
from .events import handle_event
from .auth import validate_token

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class EVSEConnectionManager:
    """Manages EVSE WebSocket connections with in-memory storage.

    Two dictionaries keyed by charger ID are kept in step: ``connections``
    holds the per-charger record (websocket, token, timestamps, message
    counter, status) and ``websockets`` holds only the socket object.

    The public coroutines serialise access through ``self._lock``.
    ``asyncio.Lock`` is NOT reentrant, so code that already holds the lock
    must call the ``*_unsafe`` helpers and never another public coroutine of
    this class (doing so would wait on itself forever).
    """

    def __init__(self):
        # Global in-memory structure keyed by charger ID
        self.connections: Dict[str, Dict[str, Any]] = {}
        self.websockets: Dict[str, WebSocketServerProtocol] = {}
        self._lock = asyncio.Lock()

    async def add_connection(self, charger_id: str, websocket: WebSocketServerProtocol, token: str) -> bool:
        """Register *websocket* as the connection for *charger_id*.

        If the charger is already registered, the previous connection is
        closed and replaced.

        Returns:
            Always ``True``.
        """
        async with self._lock:
            if charger_id in self.connections:
                logger.warning(f"Charger {charger_id} already connected, replacing connection")
                await self._remove_connection_unsafe(charger_id)

            now = datetime.utcnow().isoformat()
            self.connections[charger_id] = {
                "websocket": websocket,
                "token": token,
                "connected_at": now,
                "last_activity": now,
                "message_count": 0,
                "status": "connected",
            }
            self.websockets[charger_id] = websocket

            logger.info(f"EVSE {charger_id} connected successfully")
            return True

    async def remove_connection(self, charger_id: str) -> bool:
        """Close and forget the connection registered for *charger_id*.

        Returns:
            ``True`` if a connection was registered and removed, else ``False``.
        """
        async with self._lock:
            return await self._remove_connection_unsafe(charger_id)

    async def _remove_connection_unsafe(self, charger_id: str) -> bool:
        """Remove a connection; the caller must already hold ``self._lock``."""
        if charger_id not in self.connections:
            return False

        websocket = self.websockets.get(charger_id)
        if websocket is not None:
            # close() is attempted unconditionally; a failure here (for
            # example on an already broken socket) is logged and must not
            # prevent the registry entry from being dropped below.
            try:
                await websocket.close()
            except Exception as e:
                logger.warning(f"Error closing websocket for {charger_id}: {e}")

        self.connections.pop(charger_id, None)
        self.websockets.pop(charger_id, None)

        logger.info(f"EVSE {charger_id} disconnected")
        return True

    async def update_activity(self, charger_id: str) -> None:
        """Stamp ``last_activity`` and bump ``message_count`` for a charger.

        Unknown charger IDs are ignored.
        """
        async with self._lock:
            record = self.connections.get(charger_id)
            if record is not None:
                record["last_activity"] = datetime.utcnow().isoformat()
                record["message_count"] += 1

    async def get_connection(self, charger_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored record for *charger_id*, or ``None`` if unknown.

        The returned dict is the live record (it includes the websocket
        object), not a copy.
        """
        async with self._lock:
            return self.connections.get(charger_id)

    async def get_all_connections(self) -> Dict[str, Dict[str, Any]]:
        """Return a serialisable snapshot of all registered connections.

        Websocket objects and tokens are left out of the snapshot.
        """
        exported_fields = ("connected_at", "last_activity", "message_count", "status")
        async with self._lock:
            return {
                charger_id: {field: record[field] for field in exported_fields}
                for charger_id, record in self.connections.items()
            }

    @staticmethod
    def _encode_message(message: Dict[str, Any]) -> Optional[str]:
        """Serialise *message* to JSON, or return ``None`` if it cannot be encoded.

        Encoding happens before any socket is touched so that a bad payload
        supplied by the caller is not mistaken for a dead connection.
        """
        try:
            return json.dumps(message)
        except (TypeError, ValueError) as e:
            logger.error(f"Cannot serialize outgoing message: {e}")
            return None

    async def _send_unsafe(self, charger_id: str, message: Dict[str, Any], encoded: str) -> bool:
        """Send pre-encoded JSON to one charger; caller must hold ``self._lock``.

        A connection that reports itself closed, or whose ``send`` raises, is
        removed from the registry.

        Returns:
            ``True`` if the frame was handed to the websocket, else ``False``.
        """
        websocket = self.websockets.get(charger_id)
        if not websocket:
            return False

        try:
            # Different websockets versions expose either `closed` or
            # `close_code`; probe whichever attribute is present.
            if hasattr(websocket, 'closed') and websocket.closed:
                logger.warning(f"WebSocket for {charger_id} is closed")
                await self._remove_connection_unsafe(charger_id)
                return False
            if hasattr(websocket, 'close_code') and websocket.close_code is not None:
                logger.warning(f"WebSocket for {charger_id} is closed (close_code: {websocket.close_code})")
                await self._remove_connection_unsafe(charger_id)
                return False

            await websocket.send(encoded)
            logger.debug(f"Message sent to {charger_id}: {message}")
            return True
        except Exception as e:
            logger.error(f"Failed to send message to {charger_id}: {e}")
            # Remove dead connection
            await self._remove_connection_unsafe(charger_id)
            return False

    async def send_message(self, charger_id: str, message: Dict[str, Any]) -> bool:
        """Send *message* (JSON-encoded) to a specific charger.

        Returns:
            ``True`` on success; ``False`` if the charger is not connected,
            the message cannot be serialised, or the send failed.
        """
        encoded = self._encode_message(message)
        if encoded is None:
            return False
        async with self._lock:
            return await self._send_unsafe(charger_id, message, encoded)

    async def broadcast_message(self, message: Dict[str, Any]) -> int:
        """Send *message* to every connected charger.

        Returns:
            The number of chargers the message was successfully sent to.
        """
        encoded = self._encode_message(message)  # serialise once for all targets
        if encoded is None:
            return 0

        sent_count = 0
        async with self._lock:
            # Iterate over a copy: a failed send removes entries from the dict.
            for charger_id in list(self.websockets):
                # Must not call send_message() here: it would try to take the
                # lock this coroutine already holds and never return.
                if await self._send_unsafe(charger_id, message, encoded):
                    sent_count += 1
        return sent_count

    def get_connection_count(self) -> int:
        """Get total number of registered connections (read without the lock)."""
        return len(self.connections)


# Global connection manager instance, created at import time and used by the
# server handlers defined in this module.
connection_manager = EVSEConnectionManager()


class WebSocketOCPPServer:
    """WebSocket OCPP Server for EVSE connections.

    Each accepted connection is authenticated with a token, mapped to a
    charger ID, registered with the module-level ``connection_manager`` and
    then served by a receive loop that forwards every JSON message to
    ``handle_event``.
    """

    def __init__(self):
        self.config = get_config()
        self.server = None

    @staticmethod
    def _get_header(headers: Dict[str, str], name: str) -> Optional[str]:
        """Look up a header by name, ignoring case (HTTP header names are case-insensitive)."""
        wanted = name.lower()
        for key, value in headers.items():
            if str(key).lower() == wanted:
                return value
        return None

    @staticmethod
    def _loggable_path(path: str) -> str:
        """Return *path* without its query string so tokens are not written to logs."""
        return urllib.parse.urlparse(path).path

    def _extract_auth_token(self, path: str, headers: Dict[str, str]) -> Optional[str]:
        """Extract the authentication token from query parameters or headers.

        Sources are tried in this order: ``?token=`` query parameter, an
        ``Authorization: Bearer <token>`` header, then ``X-Auth-Token``.

        Returns:
            The token, or ``None`` when none of the sources provides one.
        """
        query_params = urllib.parse.parse_qs(urllib.parse.urlparse(path).query)
        if 'token' in query_params:
            return query_params['token'][0]

        auth_header = self._get_header(headers, 'authorization') or ''
        scheme, separator, credentials = auth_header.partition(' ')
        # The auth scheme name is matched case-insensitively.
        if separator and scheme.lower() == 'bearer':
            return credentials

        return self._get_header(headers, 'x-auth-token') or None

    def _extract_charger_id(self, path: str, headers: Dict[str, str]) -> Optional[str]:
        """Extract the charger ID from the path, query string or headers.

        Sources are tried in this order: last path segment of a path with at
        least three segments (e.g. ``/ws/evse/CHARGER_001``), the
        ``charger_id`` query parameter, then the ``X-Charger-ID`` header.

        Returns:
            The charger ID, or ``None`` when none of the sources provides one.
        """
        parsed_url = urllib.parse.urlparse(path)

        # Split only the path component: splitting the raw request target
        # would glue "?token=..." onto the ID and would treat a "/" inside a
        # query value as a path separator.
        path_parts = parsed_url.path.split('/')
        if len(path_parts) >= 4 and path_parts[-1]:
            return path_parts[-1]

        query_params = urllib.parse.parse_qs(parsed_url.query)
        if 'charger_id' in query_params:
            return query_params['charger_id'][0]

        return self._get_header(headers, 'x-charger-id') or None

    async def websocket_handler(self, websocket):
        """Entry point passed to ``websockets.serve``; resolves the request path.

        The attribute carrying the request target is probed on the connection
        object; when nothing is found, ``/ws/evse`` is assumed.
        """
        path = '/ws/evse'  # default

        request = getattr(websocket, 'request', None)
        if hasattr(websocket, 'request'):
            if hasattr(request, 'raw_path'):
                raw_path = request.raw_path
                path = raw_path.decode() if isinstance(raw_path, bytes) else str(raw_path)
            elif hasattr(request, 'path'):
                path = str(request.path)
                query_string = getattr(request, 'query_string', None)
                if query_string:
                    query = query_string.decode() if isinstance(query_string, bytes) else str(query_string)
                    path = f"{path}?{query}"
        elif hasattr(websocket, 'uri'):
            path = str(websocket.uri)

        logger.debug(f"WebSocket connection to: {self._loggable_path(path)}")
        await self.handle_evse_connection(websocket, path)

    async def _release_connection(self, charger_id: str, websocket) -> None:
        """Unregister *charger_id* only if *websocket* is still its registered socket.

        When a charger reconnects, ``add_connection`` closes the old socket and
        registers the new one under the same ID. The old handler then reaches
        its cleanup; removing by ID alone would close the replacement. The
        identity check has to happen while holding the manager's lock, because
        the replacement may still be in progress when this coroutine starts,
        hence the direct use of the manager's internals here.
        """
        async with connection_manager._lock:
            if connection_manager.websockets.get(charger_id) is websocket:
                await connection_manager._remove_connection_unsafe(charger_id)

    async def handle_evse_connection(self, websocket: WebSocketServerProtocol, path: str = "/ws/evse"):
        """Authenticate, register and serve a single EVSE WebSocket connection.

        The connection is closed with an application close code when it cannot
        be accepted: 4001 (no token), 4002 (invalid token), 4003 (no charger
        ID). Otherwise a welcome message is sent and every incoming JSON
        message is dispatched to ``handle_event``; the result is sent back as
        a ``"response"`` message. Per-message failures produce an ``"error"``
        message and do not end the connection.
        """
        charger_id = None

        try:
            log_path = self._loggable_path(path)

            # Debug: Check websocket attributes
            logger.debug(f"WebSocket attributes: {dir(websocket)}")
            if hasattr(websocket, 'request'):
                logger.debug(f"Request attributes: {dir(websocket.request)}")

            # Collect request headers - try different attribute names
            headers = {}
            if hasattr(websocket, 'request') and hasattr(websocket.request, 'headers'):
                headers = dict(websocket.request.headers)
            elif hasattr(websocket, 'request_headers'):
                headers = dict(websocket.request_headers)
            elif hasattr(websocket, 'headers'):
                headers = dict(websocket.headers)
            elif hasattr(websocket, 'extra_headers'):
                headers = dict(websocket.extra_headers)

            # Header names only: values may carry credentials.
            logger.debug(f"Headers found: {sorted(headers)}")
            token = self._extract_auth_token(path, headers)

            if not token:
                logger.warning(f"Connection rejected: No auth token provided for path {log_path}")
                await websocket.close(code=4001, reason="Authentication token required")
                return

            # Validate token
            if not validate_token(token):
                logger.warning(f"Connection rejected: Invalid token for path {log_path}")
                await websocket.close(code=4002, reason="Invalid authentication token")
                return

            # Extract charger ID
            charger_id = self._extract_charger_id(path, headers)
            if not charger_id:
                logger.warning(f"Connection rejected: No charger ID provided for path {log_path}")
                await websocket.close(code=4003, reason="Charger ID required")
                return

            # Add connection to manager
            await connection_manager.add_connection(charger_id, websocket, token)

            logger.info(f"EVSE {charger_id} authenticated and connected")

            # Send welcome message
            welcome_message = {
                "type": "welcome",
                "charger_id": charger_id,
                "server_time": datetime.utcnow().isoformat(),
                "session_id": str(uuid.uuid4())
            }
            await websocket.send(json.dumps(welcome_message))

            # Handle incoming messages
            async for raw_message in websocket:
                try:
                    # Update activity
                    await connection_manager.update_activity(charger_id)

                    # Parse JSON message
                    message = json.loads(raw_message)
                    logger.debug(f"Received from {charger_id}: {message}")

                    # Extract action and payload from the message
                    action = message.get("action", "Unknown")
                    payload = message.get("payload", {})

                    # Add charge point ID to payload for logging
                    payload["_charge_point_id"] = charger_id

                    # Handle the OCPP event directly
                    response_payload = await handle_event(action, payload)

                    response = {
                        "type": "response",
                        "action": action,
                        "payload": response_payload,
                        "timestamp": datetime.utcnow().isoformat() + "Z"
                    }
                    await websocket.send(json.dumps(response))
                    logger.debug(f"Response sent to {charger_id}: {response}")

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid JSON from {charger_id}: {e}")
                    error_response = {
                        "type": "error",
                        "error": "Invalid JSON format",
                        "timestamp": datetime.utcnow().isoformat()
                    }
                    await websocket.send(json.dumps(error_response))

                except Exception as e:
                    # Also reached for well-formed JSON of the wrong shape
                    # (non-object message or payload).
                    logger.error(f"Error processing message from {charger_id}: {e}")
                    error_response = {
                        "type": "error",
                        "error": "Message processing failed",
                        "timestamp": datetime.utcnow().isoformat()
                    }
                    try:
                        await websocket.send(json.dumps(error_response))
                    except Exception:
                        # Best effort only: the connection might be closed.
                        # `Exception` (not a bare except) lets task
                        # cancellation propagate.
                        pass

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"EVSE {charger_id} connection closed normally")

        except Exception as e:
            logger.error(f"Unexpected error in EVSE connection {charger_id}: {e}")

        finally:
            # Clean disconnect - remove from connection list, unless a newer
            # connection has already taken over this charger ID.
            if charger_id:
                await self._release_connection(charger_id, websocket)

    async def start(self):
        """Start the WebSocket server (alias for start_server)"""
        await self.start_server()

    async def stop(self):
        """Stop the WebSocket server (alias for stop_server)"""
        await self.stop_server()

    async def wait_closed(self):
        """Wait for the server to be closed (no-op if it was never started)"""
        if self.server:
            await self.server.wait_closed()

    async def start_server(self):
        """Start listening on the configured host and port.

        Returns once the listener is set up; it does not block for the
        lifetime of the server (use ``wait_closed`` for that).

        Raises:
            Exception: whatever ``websockets.serve`` raised, after logging it.
        """
        try:
            # Start WebSocket server on /ws/evse endpoint
            self.server = await websockets.serve(
                self.websocket_handler,
                self.config.host,
                self.config.port,
                ping_interval=self.config.ping_interval,
                ping_timeout=self.config.ping_timeout,
                close_timeout=self.config.close_timeout,
                max_size=2**20,  # 1MB max message size
                max_queue=32     # Max queued messages per connection
            )

            logger.info(f"WebSocket OCPP server started on ws://{self.config.host}:{self.config.port}/ws/evse")
            logger.info("Server accepting EVSE connections with token authentication")

        except Exception as e:
            logger.error(f"Failed to start WebSocket server: {e}")
            raise

    async def stop_server(self):
        """Stop the WebSocket server and wait until it has closed (no-op if never started)"""
        if self.server:
            self.server.close()
            await self.server.wait_closed()
            logger.info("WebSocket OCPP server stopped")

    async def get_server_stats(self) -> Dict[str, Any]:
        """Return a dict of server statistics built from the connection manager."""
        connections = await connection_manager.get_all_connections()

        return {
            "total_connections": connection_manager.get_connection_count(),
            "active_chargers": list(connections.keys()),
            "server_uptime": "N/A",  # Would need start time tracking
            "endpoint": f"ws://{self.config.host}:{self.config.port}/ws/evse",
            "connections": connections
        }


# Names exported by this module: the server class, the shared connection
# manager instance, and the manager class itself.
__all__ = ["WebSocketOCPPServer", "connection_manager", "EVSEConnectionManager"]


async def main():
    """Run the WebSocket server until a shutdown signal arrives.

    ``start_server`` returns as soon as the listener is up, so this coroutine
    then waits on a shutdown event; without that wait the ``finally`` clause
    would stop the server right after starting it. SIGINT/SIGTERM set the
    event where the event loop supports signal handlers; elsewhere shutdown
    relies on the task being cancelled (e.g. Ctrl+C under ``asyncio.run``),
    which still goes through the ``finally`` clause.

    Raises:
        Exception: any error raised while starting or running the server,
            after it has been logged.
    """
    import signal

    logger.info("Starting Voltie OCPP WebSocket Server...")

    server = WebSocketOCPPServer()
    shutdown_requested = asyncio.Event()

    # Setup signal handlers for graceful shutdown
    def signal_handler():
        logger.info("Received shutdown signal, stopping server...")
        shutdown_requested.set()

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, signal_handler)
        except NotImplementedError:
            # Some event loops (notably on Windows) cannot install signal
            # handlers; fall back to the default interrupt behaviour.
            logger.debug(f"Signal handler for {sig.name} not supported on this event loop")

    try:
        await server.start_server()
        await shutdown_requested.wait()
    except KeyboardInterrupt:
        logger.info("Received Ctrl+C, shutting down...")
    except Exception as e:
        logger.error(f"Server error: {e}")
        raise
    finally:
        await server.stop_server()


if __name__ == "__main__":
    # Script entry point: run the server; exit with status 1 on failure.
    import sys

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C that reached the top level: report and exit normally.
        print("\nGraceful shutdown complete.")
    except Exception as exc:
        logger.error(f"Server failed: {exc}")
        sys.exit(1)
