"""This module contains the ``GateEncoder`` class which can be used for encoding gates
to integers and back.
Usage:
    >>> from qgym.utils import GateEncoder
    >>> encoder = GateEncoder().learn_gates(["x", "y", "z", "cnot", "measure"])
    >>> encoded_list = encoder.encode_gates(["x", "x", "measure", "z"])
    >>> print(encoded_list)
    [1, 1, 5, 3]
    >>> encoder.decode_gates(encoded_list)
    ['x', 'x', 'measure', 'z']
"""
from __future__ import annotations
import warnings
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, TypeVar, cast, overload
from qgym.custom_types import Gate
T = TypeVar("T")
[docs]
class GateEncoder:
    """Learns a set of gates and creates a mapping to integers and back."""
[docs]
    def __init__(self) -> None:
        """Initialize the ``GateEncoder``."""
        self.n_gates = 0
        self._encoding_dct: dict[str, int] = {}
        self._decoding_dct: dict[int, str] = {}
        self._longest_name = 0 
[docs]
    def learn_gates(self, gates: Iterable[str]) -> GateEncoder:
        """Learns the gates names from an ``Iterable`` and creates a mapping from unique
        gate names to integers and back.
        Args:
            gates: ``Iterable`` containing the names of the gates that should be
                learned. The ``Iterable`` can contain duplicate names.
        Returns:
            Self.
        """
        self.n_gates = 0
        for idx, gate_name in enumerate(gates, 1):
            if gate_name in self._encoding_dct:
                warnings.warn(f"'gates' contains multiple entries of {gate_name}")
            else:
                self._encoding_dct[gate_name] = idx
                self._decoding_dct[idx] = gate_name
                self._longest_name = max(self._longest_name, len(gate_name))
                self.n_gates += 1
        return self 
    @overload
    def encode_gates(self, gates: str) -> int: ...
    @overload
    def encode_gates(self, gates: Mapping[str, T]) -> dict[int, T]: ...
    @overload
    def encode_gates(self, gates: Sequence[Gate]) -> list[Gate]: ...
    @overload
    def encode_gates(self, gates: set[str]) -> set[int]: ...
    @overload
    def encode_gates(self, gates: list[str] | tuple[str, ...]) -> list[int]: ...
[docs]
    def encode_gates(
        self,
        gates: (
            str
            | Mapping[str, Any]
            | Sequence[Gate]
            | set[str]
            | list[str]
            | tuple[str, ...]
        ),
    ) -> int | dict[int, Any] | list[Gate] | set[int] | list[int]:
        """Encode the gate names (of type ``str``) in `gates` to integers, based on the
        gates seen in ``learn_gates``.
        Args:
            gates: Gates to encode. The input type determines the return type.
        Raises:
            TypeError: When an unsupported type is given.
        Returns: Integer encoded version of gates. The output structure should resemble
            the input structure. So a ``Mapping`` will return a ``Dict``, while a single
            ``str`` will return an ``int``.
        """
        if isinstance(gates, str):
            encoded_str = self._encoding_dct[gates]
            return encoded_str
        if isinstance(gates, Mapping):
            return self._encode_mapping(gates)
        if isinstance(gates, Sequence) and (
            len(gates) == 0 or isinstance(gates[0], Gate)
        ):
            # We assume that if the first element of gates is a Gate, then the whole
            # Sequence contains Gate objects.
            encoded_gates_list: list[Gate] = []
            for gate in gates:
                gate = cast(Gate, gate)
                encoded_name = self._encoding_dct[gate.name]
                encoded_gates_list.append(Gate(encoded_name, gate.q1, gate.q2))
            return encoded_gates_list
        if isinstance(gates, set):
            encoded_names_set: set[int] = set()
            for gate_name in gates:
                gate_encoding = self._encoding_dct[gate_name]
                encoded_names_set.add(gate_encoding)
            return encoded_names_set
        if isinstance(gates, (list, tuple)):
            encoded_names_list: list[int] = []
            for gate_name in gates:
                gate_encoding = self._encoding_dct[gate_name]
                encoded_names_list.append(gate_encoding)
            return encoded_names_list
        raise TypeError(
            f"gates type must be str, Mapping or Sequence, got {type(gates)}."
        ) 
    def _encode_mapping(self, mapping: Mapping[str, Any]) -> dict[int, Any]:
        """Encode a mapping with gate names.
        Raises:
            ValueError: For unknown mappings.
        """
        encoded_dict: dict[int, Any] = {}
        gate_name: str
        for gate_name, item in mapping.items():
            gate_encoding = self._encoding_dct[gate_name]
            if isinstance(item, int):
                encoded_dict[gate_encoding] = item
            elif isinstance(item, Iterable):
                item_encoded = []
                for i in item:
                    item_encoded.append(self._encoding_dct[i])
                encoded_dict[gate_encoding] = item_encoded
            else:
                raise ValueError("Unknown mapping")
        return encoded_dict
    @overload
    def decode_gates(self, encoded_gates: int) -> str: ...
    @overload
    def decode_gates(self, encoded_gates: Mapping[int, Any]) -> dict[str, Any]: ...
    @overload
    def decode_gates(self, encoded_gates: Sequence[Gate]) -> list[Gate]: ...
    @overload
    def decode_gates(self, encoded_gates: set[int]) -> set[str]: ...
    @overload
    def decode_gates(self, encoded_gates: list[int] | tuple[int, ...]) -> list[str]: ...
[docs]
    def decode_gates(
        self,
        encoded_gates: (
            int
            | Mapping[int, Any]
            | Sequence[Gate]
            | set[int]
            | list[int]
            | tuple[int, ...]
        ),
    ) -> str | dict[str, Any] | list[Gate] | set[str] | list[str]:
        """Decode integer encoded gate names to the original gate names based on the
        gates seen in ``learn_gates``.
        Args:
            encoded_gates: Encoded gates that are to be decoded. The input type
            determines the return type.
        Raises:
            TypeError: When an unsupported type is given
        Returns:
            Decoded version of encoded_gates. The output structure should resemble the
            input structure. So a ``Mapping`` will return a ``Dict``, while a single
            ``int`` will return a ``str``.
        """
        if isinstance(encoded_gates, int):
            decoded_int = self._decoding_dct[encoded_gates]
            return decoded_int
        if isinstance(encoded_gates, Mapping):
            decoded_dict: dict[str, Any] = {}
            gate_int: int
            for gate_int in encoded_gates:
                gate_name = self._decoding_dct[gate_int]
                decoded_dict[gate_name] = encoded_gates[gate_int]
            return decoded_dict
        if isinstance(encoded_gates, Sequence) and isinstance(encoded_gates[0], Gate):
            # We assume that if the first element of encoded_gates is a Gate, then the
            # whole Sequence contains Gate objects.
            decoded_gate_list: list[Gate] = []
            for gate in encoded_gates:
                gate = cast(Gate, gate)
                decoded_gate_name = self._decoding_dct[gate.name]
                decoded_gate_list.append(Gate(decoded_gate_name, gate.q1, gate.q2))
            return decoded_gate_list
        if isinstance(encoded_gates, set):
            decoded_name_set: set[str] = set()
            for gate_int in encoded_gates:
                decoded_gate = self._decoding_dct[gate_int]
                decoded_name_set.add(decoded_gate)
            return decoded_name_set
        if isinstance(encoded_gates, (list, tuple)):
            decoded_name_list: list[str] = []
            for gate_int in encoded_gates:
                decoded_gate = self._decoding_dct[gate_int]
                decoded_name_list.append(decoded_gate)
            return decoded_name_list
        raise TypeError(
            "encoded_gates must be int, Mapping or Sequence, got "
            f"{type(encoded_gates)}."
        ) 
[docs]
    def __repr__(self) -> str:
        """Make a string representation without endline characters."""
        return f"{self.__class__.__name__}(encoding={self._encoding_dct})"