"""
Message dispatcher for handling OCPP messages in UDP client
"""
import inspect
import json
import logging
from typing import Dict, Callable, Any, Optional, Tuple

from .drivers import get_driver
from .utils import format_ocpp_message, MessageBuilder, parse_ocpp_message

# Module logger, named after this module's dotted path so that it sits under
# the package's logger in the logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)


class MessageDispatcher:
    """Handles routing and dispatching of OCPP messages.

    Incoming UDP datagrams are accepted in two formats (see ``handle_packet``):

    * a "simple" JSON object ``{"action": "<Name>", "payload": {...}}``;
    * a full OCPP-J message decoded by ``parse_ocpp_message`` into a dict that
      this class reads through the ``messageTypeId``, ``uniqueId``, ``action``,
      ``payload``, ``errorCode`` and ``errorDescription`` keys.

    Handlers live in ``self.handlers`` keyed by action name and are called as
    ``handler(payload)``; they may be coroutine functions or plain callables.
    A handler registered under ``"unknown"`` is the fallback for unregistered
    actions and is called as ``handler(action, payload)`` instead.
    """

    # Registry key of the fallback handler used for unregistered actions.
    _FALLBACK_ACTION = "unknown"

    def __init__(self):
        # Action name -> handler callable.
        self.handlers: Dict[str, Callable] = {}
        self._register_default_handlers()

    def _register_default_handlers(self):
        """Register every handler returned by ``.events.get_event_handlers()``."""
        # Function-level import, presumably to avoid a circular import with
        # the .events module -- TODO confirm.
        from .events import get_event_handlers

        for action, handler in get_event_handlers().items():
            self.register_handler(action, handler)

    def register_handler(self, message_type: str, handler: Callable):
        """Register ``handler`` for ``message_type``, replacing any previous one."""
        self.handlers[message_type] = handler
        logger.info(f"Registered handler for message type: {message_type}")

    async def handle_packet(self, data: bytes, addr: Tuple[str, int]) -> Optional[bytes]:
        """
        Handle incoming UDP packet with OCPP message

        The packet is first tried as the simple ``{"action": ..., "payload": ...}``
        format and, failing that, as a full OCPP-J message.

        Args:
            data: Raw UDP packet data (expected to be UTF-8 encoded JSON)
            addr: Source address (host, port)

        Returns:
            Response bytes to send back, or None when there is nothing to send
            (no response, a non-UTF-8 packet, or an unexpected error, which is
            logged rather than raised).
        """
        try:
            message_str = data.decode('utf-8').strip()

            simple_message = self._parse_simple_message(message_str)
            if simple_message is not None:
                return await self._handle_simple_message(simple_message, addr)

            parsed_message = parse_ocpp_message(message_str)
            if not parsed_message:
                logger.warning(f"Failed to parse message from {addr[0]}:{addr[1]}: {message_str}")
                error_response = {"status": "error", "message": "Invalid message format"}
                return format_ocpp_message(error_response).encode('utf-8')

            response_str = await self.dispatch_message(parsed_message, addr)
            return response_str.encode('utf-8') if response_str else None

        except UnicodeDecodeError:
            logger.error(f"Failed to decode message from {addr[0]}:{addr[1]} - not UTF-8")
            return None
        except Exception as e:
            # Top-level boundary of packet handling: log with traceback and
            # never let one bad packet propagate to the transport.
            logger.exception(f"Error processing packet from {addr[0]}:{addr[1]}: {e}")
            return None

    @staticmethod
    def _parse_simple_message(message_str: str) -> Optional[Dict[str, Any]]:
        """Return the decoded JSON object if it is a simple-format message, else None.

        Only a JSON *object* containing an ``"action"`` key qualifies. Invalid
        JSON, JSON arrays (e.g. OCPP-J frames) and scalars return None so the
        caller can fall back to the full OCPP parser.
        """
        # Keep the try minimal: only JSON decoding may legitimately fail here.
        try:
            decoded = json.loads(message_str)
        except json.JSONDecodeError:
            return None
        # The isinstance guard matters: ``'action' in`` on a list or string
        # tests elements/substrings, and ``decoded['action']`` would then raise.
        if isinstance(decoded, dict) and 'action' in decoded:
            return decoded
        return None

    @staticmethod
    def _with_source_addr(payload: Any, addr: Tuple[str, int]) -> Dict[str, Any]:
        """Return a copy of ``payload`` with ``_source_addr`` set to ``"host:port"``.

        A missing or JSON ``null`` payload is treated as ``{}``. A payload that
        is not a mapping raises ``TypeError``; callers turn that into their
        usual error outcome. The input dict is not mutated.
        """
        base = {} if payload is None else payload
        return {**base, '_source_addr': f"{addr[0]}:{addr[1]}"}

    async def _handle_simple_message(self, message: Dict[str, Any],
                                     addr: Tuple[str, int]) -> Optional[bytes]:
        """Dispatch a simple-format message and encode the handler's response."""
        action = message['action']
        payload = self._with_source_addr(message.get('payload'), addr)

        logger.info(f"Processing {action} message from {addr[0]}:{addr[1]}")

        response = await self.dispatch(action, payload)
        # ``is None`` (not falsiness) so an empty-dict acknowledgement is still
        # sent, matching how dispatch_message treats handler results.
        if response is None:
            return None
        return format_ocpp_message(response).encode('utf-8')

    def _resolve(self, message_type: Any,
                 message_data: Dict[str, Any]) -> Optional[Tuple[Callable, Tuple[Any, ...]]]:
        """Find the handler for ``message_type`` and the arguments to call it with.

        Returns ``(handler, (message_data,))`` for a registered action,
        ``(fallback, (message_type, message_data))`` when only the ``"unknown"``
        fallback exists, or None when neither is registered.
        """
        if message_type in self.handlers:
            return self.handlers[message_type], (message_data,)

        logger.warning(f"Unknown OCPP action: {message_type}")
        if self._FALLBACK_ACTION in self.handlers:
            return self.handlers[self._FALLBACK_ACTION], (message_type, message_data)
        return None

    @staticmethod
    async def _call(handler: Callable, *args: Any) -> Any:
        """Call ``handler(*args)``, awaiting the result when it is awaitable.

        This lets both coroutine functions and plain callables be registered.
        """
        result = handler(*args)
        if inspect.isawaitable(result):
            result = await result
        return result

    async def dispatch(self, message_type: str, message_data: Dict[str, Any]) -> Any:
        """Dispatch a message to the appropriate handler.

        Args:
            message_type: Action name used to look up the handler.
            message_data: Payload passed to the handler.

        Returns:
            The handler's result. Problems are reported as values, not raised:
            an unknown action with no fallback handler yields
            ``{"status": "error", "message": "Unknown action: ..."}``, and an
            exception from any handler (including the fallback) yields
            ``{"status": "error", "message": "Handler error: ..."}``.
        """
        target = self._resolve(message_type, message_data)
        if target is None:
            return {"status": "error", "message": f"Unknown action: {message_type}"}

        handler, args = target
        try:
            result = await self._call(handler, *args)
        except Exception as e:
            logger.exception(f"Error handling message {message_type}: {e}")
            return {"status": "error", "message": f"Handler error: {str(e)}"}

        logger.debug(f"Message {message_type} handled successfully")
        return result

    async def dispatch_message(self, parsed_message: Dict[str, Any], addr: Tuple[str, int]) -> Optional[str]:
        """
        Dispatch a parsed OCPP message and return response string

        Only CALL messages (messageTypeId 2) are answered:

        * with a CALLRESULT carrying the handler's return value;
        * with CALLERROR ``NotImplemented`` when the action has no handler (and
          no fallback) or the handler returns None;
        * with CALLERROR ``InternalError`` when the handler raises.

        CALLRESULT (3) and CALLERROR (4) messages are only logged.

        Args:
            parsed_message: Parsed OCPP message dict
            addr: Source address (host, port)

        Returns:
            Response message as string, or None
        """
        message_type_id = parsed_message.get("messageTypeId")
        unique_id = parsed_message.get("uniqueId")

        try:
            if message_type_id == 2:  # CALL
                return await self._handle_call(parsed_message, unique_id, addr)

            if message_type_id == 3:  # CALLRESULT
                logger.info(f"Received CALLRESULT from {addr[0]}:{addr[1]} - ID: {unique_id}")
                # Reply to one of our own outgoing calls; nothing to send back.
                return None

            if message_type_id == 4:  # CALLERROR
                error_code = parsed_message.get("errorCode")
                error_description = parsed_message.get("errorDescription")
                logger.warning(f"Received CALLERROR from {addr[0]}:{addr[1]} - {error_code}: {error_description}")
                return None

            logger.warning(f"Unknown message type ID: {message_type_id}")
            return None

        except Exception as e:
            logger.exception(f"Error dispatching message from {addr[0]}:{addr[1]}: {e}")
            # Only a CALL may be answered; replying to a CALLRESULT/CALLERROR
            # with a CALLERROR would not be a valid OCPP exchange.
            if message_type_id == 2 and unique_id:
                return self._call_error(unique_id, "InternalError",
                                        f"Error processing message: {str(e)}")
            return None

    async def _handle_call(self, parsed_message: Dict[str, Any], unique_id: Any,
                           addr: Tuple[str, int]) -> str:
        """Run the handler for a CALL and build the serialized reply.

        Handler exceptions are deliberately NOT caught here, so that
        dispatch_message can answer with CALLERROR ``InternalError`` instead of
        a CALLRESULT that carries an error body.
        """
        action = parsed_message.get("action")
        payload = self._with_source_addr(parsed_message.get("payload"), addr)

        logger.info(f"Processing CALL message from {addr[0]}:{addr[1]} - Action: {action}")

        target = self._resolve(action, payload)
        if target is None:
            # An action the receiver does not know is reported as a CALLERROR,
            # not as a successful CALLRESULT.
            return self._call_error(unique_id, "NotImplemented", f"Unknown action: {action}")

        handler, args = target
        response_payload = await self._call(handler, *args)
        logger.debug(f"Message {action} handled successfully")

        if response_payload is None:
            return self._call_error(unique_id, "NotImplemented",
                                    f"Handler for {action} not implemented")

        response_msg = MessageBuilder.create_call_result(unique_id, response_payload)
        return format_ocpp_message(response_msg)

    @staticmethod
    def _call_error(unique_id: Any, error_code: str, description: str) -> str:
        """Build and serialize an OCPP CALLERROR with empty error details."""
        error_msg = MessageBuilder.create_call_error(unique_id, error_code, description, {})
        return format_ocpp_message(error_msg)
