"""
Unified Message Dispatcher for OCPP events
Handles both UDP and WebSocket OCPP messages through a common interface
"""
import json
import logging
from datetime import datetime, timezone
from typing import Dict, Callable, Any, Optional, Tuple, Union

logger = logging.getLogger(__name__)


class UnifiedOCPPDispatcher:
    """
    Unified dispatcher that handles OCPP messages for both UDP and WebSocket transports.

    Handlers live in ``self.handlers`` keyed by OCPP action name and are awaited
    as ``handler(payload)``.  If an optional ``"handle_unknown_event"`` handler is
    registered, it is awaited as ``handler(action, payload)`` for actions that
    have no handler of their own.

    Before dispatch, dict payloads are annotated in place with underscore-prefixed
    transport metadata (``_source_addr``, ``_transport``, ``_charge_point_id`` and,
    for WebSocket JSON-object messages, ``_websocket``).
    """

    # OCPP-J message type ids (first element of an array message).
    CALL = 2
    CALLRESULT = 3
    CALLERROR = 4

    def __init__(self):
        self.handlers: Dict[str, Callable] = {}
        self._register_default_handlers()

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time in ISO-8601 form with a trailing ``Z``."""
        # datetime.utcnow() is deprecated since Python 3.12.  Use an aware UTC
        # time and rewrite its "+00:00" offset to keep the "...Z" wire format.
        return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

    @classmethod
    def _error_response(cls, error: str) -> Dict[str, Any]:
        """Build the error envelope returned to a WebSocket charger."""
        return {
            "type": "error",
            "error": error,
            "timestamp": cls._utc_timestamp(),
        }

    @staticmethod
    def _add_udp_metadata(payload: Dict[str, Any], addr: Tuple[str, int]) -> None:
        """Annotate a UDP payload dict in place with source/transport metadata."""
        payload["_source_addr"] = f"{addr[0]}:{addr[1]}"
        payload["_transport"] = "udp"
        # Keep a charge point id supplied by the sender; otherwise derive one from the address.
        payload.setdefault("_charge_point_id", f"udp-{addr[0]}-{addr[1]}")

    def _register_default_handlers(self):
        """Register handlers from ``client.events``; fall back to built-ins if unavailable."""
        try:
            # Imported lazily so the dispatcher still works without the client package.
            from client.events import get_event_handlers
            event_handlers = get_event_handlers()
        except ImportError as e:
            logger.warning("Could not import client event handlers: %s", e)
            self._register_fallback_handlers()
            return

        for action, handler in event_handlers.items():
            self.register_handler(action, handler)

    def _register_fallback_handlers(self):
        """Register minimal BootNotification/Heartbeat handlers."""
        async def fallback_boot_notification(payload: Dict[str, Any]) -> Dict[str, Any]:
            return {
                "status": "Accepted",
                "currentTime": self._utc_timestamp(),
                "interval": 300,
            }

        async def fallback_heartbeat(payload: Dict[str, Any]) -> Dict[str, Any]:
            return {
                "currentTime": self._utc_timestamp(),
            }

        basic_handlers = {
            "BootNotification": fallback_boot_notification,
            "Heartbeat": fallback_heartbeat,
        }

        for action, handler in basic_handlers.items():
            self.register_handler(action, handler)

    def register_handler(self, message_type: str, handler: Callable):
        """Register (or replace) the handler for an OCPP action name."""
        self.handlers[message_type] = handler
        logger.info("Registered unified handler for OCPP action: %s", message_type)

    async def handle_udp_packet(self, data: bytes, addr: Tuple[str, int]) -> Optional[bytes]:
        """
        Handle an incoming UDP packet carrying an OCPP message (legacy interface).

        Two formats are tried in order:

        1. A JSON object ``{"action": ..., "payload": ...}``.  The reply is a JSON
           object with ``action``, ``payload`` and ``timestamp`` keys.
        2. An OCPP array message parsed by ``client.utils.parse_ocpp_message``.
           The reply is built with ``format_ocpp_message("CALLRESULT", ...)``.

        Args:
            data: Raw UDP packet data, decoded as UTF-8.
            addr: Source address (host, port).

        Returns:
            Response bytes to send back, or None.  None is returned when the
            handler returned None, the packet could not be parsed, or an error
            occurred (errors are logged).
        """
        try:
            message_str = data.decode('utf-8').strip()

            # Format 1: simple {"action": ..., "payload": ...} object.
            try:
                simple_message = json.loads(message_str)
            except json.JSONDecodeError:
                simple_message = None

            if (
                isinstance(simple_message, dict)
                and "action" in simple_message
                and "payload" in simple_message
            ):
                action = simple_message["action"]
                payload = simple_message["payload"]
                if isinstance(payload, dict):
                    self._add_udp_metadata(payload, addr)

                response_payload = await self.dispatch_action(action, payload)
                if response_payload is None:
                    # Handled, nothing to reply; don't fall through to the array parser.
                    return None

                response = {
                    "action": action,
                    "payload": response_payload,
                    "timestamp": self._utc_timestamp(),
                }
                return json.dumps(response).encode('utf-8')

            # Format 2: OCPP array message [MessageType, MessageId, Action, Payload].
            try:
                from client.utils import parse_ocpp_message, format_ocpp_message
                parsed = parse_ocpp_message(message_str)
            except Exception:
                logger.exception("Error parsing OCPP array format")
                parsed = None

            if parsed:
                _message_type, message_id, action, payload = parsed
                if isinstance(payload, dict):
                    self._add_udp_metadata(payload, addr)

                response_payload = await self.dispatch_action(action, payload)
                if response_payload is None:
                    return None

                response_message = format_ocpp_message("CALLRESULT", message_id, response_payload)
                return response_message.encode('utf-8')

            logger.warning("Could not parse UDP message from %s: %s", addr, message_str)
            return None

        except Exception:
            logger.exception("Error handling UDP packet from %s", addr)
            return None

    async def handle_websocket_message(self, message_context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Handle a WebSocket message context from the EVSE client.

        Args:
            message_context: Dict read for ``charger_id``, ``message`` and
                ``websocket``.  ``message`` is either an OCPP array (list) or a
                JSON object naming the action under ``action``, ``messageType`` or
                ``type``, with the payload under ``payload`` or ``data``.

        Returns:
            The reply for the charger, or None.  This is a ``"ocpp_response"``
            envelope wrapping a CALLRESULT/CALLERROR array, a ``"response"``
            envelope with the handler payload, or a ``"type": "error"`` envelope.
            None means there is nothing to send.
        """
        try:
            charger_id = message_context.get("charger_id", "unknown")
            message = message_context.get("message", {})
            websocket = message_context.get("websocket")

            # Standard OCPP-J array format.
            if isinstance(message, list) and len(message) >= 2:
                response_array = await self._handle_ocpp_array_message(charger_id, message)
                if not response_array:
                    return None
                return {
                    "type": "ocpp_response",
                    "message": response_array,
                    "timestamp": self._utc_timestamp(),
                }

            # Anything else must be a JSON object; reject strings, None, short lists, etc.
            if not isinstance(message, dict):
                logger.warning("Unsupported WebSocket message format from %s: %r", charger_id, message)
                return self._error_response("Unsupported message format")

            action = message.get("action") or message.get("messageType") or message.get("type")
            if not action:
                logger.warning("No action found in WebSocket message from %s: %s", charger_id, message)
                return self._error_response("No action specified in message")

            logger.debug("Dispatching WebSocket action '%s' from charger %s", action, charger_id)

            payload = message.get("payload", message.get("data", {}))
            if isinstance(payload, dict):
                payload['_source_addr'] = charger_id
                payload['_charge_point_id'] = charger_id
                payload['_websocket'] = websocket
                payload['_transport'] = 'websocket'

            response_payload = await self.dispatch_action(action, payload)
            if response_payload is None:
                return None

            return {
                "type": "response",
                "action": action,
                "payload": response_payload,
                "timestamp": self._utc_timestamp(),
            }

        except Exception:
            logger.exception("Error dispatching WebSocket message")
            return self._error_response("Internal server error")

    async def _handle_ocpp_array_message(self, charge_point_id: str, message: list) -> Optional[list]:
        """
        Handle an OCPP-J array message ``[MessageType, MessageId, Action, Payload]``.

        Only CALL messages (type 2) are dispatched; other types are logged and ignored.

        Returns:
            ``[3, message_id, response_payload]`` (CALLRESULT) when the handler
            returns a value.  ``[4, message_id, "InternalError", description, {}]``
            (CALLERROR) if processing raises.  None otherwise.
        """
        try:
            if len(message) < 3:
                logger.warning("Invalid OCPP message format from %s: %s", charge_point_id, message)
                return None

            message_type = message[0]
            message_id = message[1]

            if message_type != self.CALL:
                logger.warning("Unsupported OCPP message type %s from %s", message_type, charge_point_id)
                return None

            action = message[2]
            # OCPP-J CALLs carry a payload element; tolerate its absence as an empty payload.
            payload = message[3] if len(message) > 3 else {}

            if isinstance(payload, dict):
                payload['_source_addr'] = charge_point_id
                payload['_charge_point_id'] = charge_point_id
                payload['_transport'] = 'websocket'

            response_payload = await self.dispatch_action(action, payload)
            if response_payload is None:
                return None
            return [self.CALLRESULT, message_id, response_payload]

        except Exception as e:
            logger.exception("Error handling OCPP array message from %s", charge_point_id)
            # `message` is a sequence, not a dict: index it and fall back if the id is missing.
            try:
                error_id = message[1]
            except (IndexError, KeyError, TypeError):
                error_id = "unknown"
            return [self.CALLERROR, error_id, "InternalError", str(e), {}]

    async def dispatch_action(self, action: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Dispatch an OCPP action to its registered handler.

        Args:
            action: OCPP action name (e.g. "BootNotification").
            payload: Message payload, usually a dict annotated with transport
                metadata.  Non-dict payloads are passed to the handler unchanged.

        Returns:
            The handler's response.  For an unregistered action, returns the
            result of ``handle_unknown_event`` if registered, otherwise a
            ``NotSupported`` dict.  If the handler raises, returns an
            ``InternalError`` dict.
        """
        try:
            # Metadata is only available on dict payloads; don't fail on lists/None.
            meta = payload if isinstance(payload, dict) else {}
            transport = str(meta.get('_transport', 'unknown'))
            source = meta.get('_source_addr', 'unknown')

            logger.debug("Dispatching %s action '%s' from %s", transport.upper(), action, source)

            if action in self.handlers:
                response = await self.handlers[action](payload)
                logger.debug("Action '%s' handled successfully by unified dispatcher", action)
                return response

            logger.warning("No handler registered for action: %s", action)

            if "handle_unknown_event" in self.handlers:
                return await self.handlers["handle_unknown_event"](action, payload)

            return {
                "status": "NotSupported",
                "error": f"Action '{action}' is not supported",
            }

        except Exception as e:
            logger.exception("Error dispatching action '%s'", action)
            return {
                "status": "InternalError",
                "error": f"Failed to process {action}: {e}",
            }

    def get_registered_actions(self) -> list:
        """Return the names of all registered OCPP actions."""
        return list(self.handlers.keys())


# Process-wide dispatcher shared by every caller; built lazily on first request.
_unified_dispatcher = None


def get_unified_dispatcher() -> UnifiedOCPPDispatcher:
    """
    Return the shared UnifiedOCPPDispatcher, constructing it on first use.

    Every subsequent call hands back that same instance.
    """
    global _unified_dispatcher
    if _unified_dispatcher is not None:
        return _unified_dispatcher
    instance = UnifiedOCPPDispatcher()
    _unified_dispatcher = instance
    return instance


__all__ = ["UnifiedOCPPDispatcher", "get_unified_dispatcher"]
