"""Query evaluation unit tests."""

from datalog.easy import read, select
from datalog.types import (
    CachedDataset,
    Constant,
    Dataset,
    PartlyIndexedDataset,
    TableIndexedDataset,
)

import pytest


DBCLS = [Dataset, CachedDataset, TableIndexedDataset, PartlyIndexedDataset]


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_id_query(db_cls):
    """Querying for a constant in the dataset."""

    ab = (
        Constant("a"),
        Constant("b"),
    )
    assert not select(
        db_cls([], []),
        (
            "a",
            "b",
        ),
    )
    assert select(db_cls([ab], []), ("a", "b",)) == [
        (
            (("a", "b"),),
            {},
        )
    ]


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_lvar_query(db_cls):
    """Querying for a binding in the dataset."""

    d = read("""a(b). a(c).""", db_cls=db_cls)

    assert select(d, ("a", "X")) == [
        ((("a", "b"),), {"X": "b"}),
        ((("a", "c"),), {"X": "c"}),
    ]


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_lvar_unification(db_cls):
    """Querying for MATCHING bindings in the dataset."""

    d = read("""edge(b, c). edge(c, c).""", db_cls=db_cls)

    assert (
        select(
            d,
            (
                "edge",
                "X",
                "X",
            ),
        )
        == [((("edge", "c", "c"),), {"X": "c"})]
    )


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_rule_join(db_cls):
    """Test a basic join query - the parent -> grandparent relation."""

    d = read(
        """
child(a, b).
child(b, c).
child(b, d).
child(b, e).

grandchild(A, B) :-
  child(A, C),
  child(C, B).
""",
        db_cls=db_cls,
    )

    assert select(d, ("grandchild", "a", "X",)) == [
        ((("grandchild", "a", "c"),), {"X": "c"}),
        ((("grandchild", "a", "d"),), {"X": "d"}),
        ((("grandchild", "a", "e"),), {"X": "e"}),
    ]


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_antijoin(db_cls):
    """Test a query containing an antijoin."""

    d = read(
        """
a(foo, bar).
b(foo, bar).
a(baz, qux).
% matching b(baz, qux). is our antijoin test

no-b(X, Y) :-
  a(X, Y),
  ~b(X, Z).
""",
        db_cls=db_cls,
    )

    assert select(d, ("no-b", "X", "Y")) == [
        ((("no-b", "baz", "qux"),), {"X": "baz", "Y": "qux"})
    ]


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_nested_antijoin(db_cls):
    """Test a query which negates a subquery which uses an antijoin.

    Shouldn't exercise anything more than `test_antjoin` does, but it's an interesting case since you
    actually can't capture the same semantics using a single query. Antijoins can't propagate positive
    information (create lvar bindings) so I'm not sure you can express this another way without a
    different evaluation strategy.

    """

    d = read(
        """
a(foo, bar).
b(foo, bar).
a(baz, qux).
b(baz, quack).

b-not-quack(X, Y) :-
  b(X, Y),
  ~=(Y, quack).

a-no-nonquack(X, Y) :-
  a(X, Y),
  ~b-not-quack(X, Y).
""",
        db_cls=db_cls,
    )

    assert select(d, ("a-no-nonquack", "X", "Y")) == [
        ((("a-no-nonquack", "baz", "qux"),), {"X": "baz", "Y": "qux"})
    ]


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_alternate_rule(db_cls):
    """Testing that both recursion and alternation work."""

    d = read(
        """
edge(a, b).
edge(b, c).
edge(c, d).
edge(d, e).
edge(e, f).

path(A, B) :-
  edge(A, B).

path(A, B) :-
  edge(A, C),
  path(C, B).
""",
        db_cls=db_cls,
    )

    # Should be able to recurse to this one.
    assert select(d, ("path", "a", "f")) == [((("path", "a", "f"),), {})]


# FIXME (arrdem 2019-06-13):
#
#   This test is BROKEN for the simple dataset.  In order for left-recursive production rules to
#   work, they have to ground out somewhere.  Under the naive, cache-less datalog this is an
#   infinite left recursion.  Under the cached versions, the right-recursion becomes iteration over
#   an incrementally realized list which ... is weird but does work because the recursion grounds
#   out in iterating over an empty list on the 2nd round then falls through to the other production
#   rule which generates ground tuples and feeds everything.
#
#   It's not clear how to make this work with the current (lack of) query planner on the naive db as
#   really fixing this requires some amount of insight into the data dependency structure and may
#   involve reordering rules.
@pytest.mark.parametrize("db_cls,", DBCLS)
def test_alternate_rule_lrec(db_cls):
    """Testing that both recursion and alternation work."""

    if db_cls == Dataset:
        pytest.xfail(
            "left-recursive rules aren't supported with a trivial store and no planner"
        )

    d = read(
        """
edge(a, b).
edge(b, c).
edge(c, d).
edge(d, e).
edge(e, f).

path(A, B) :-
  edge(A, B).

path(A, B) :-
  path(A, C),
  edge(C, B).
""",
        db_cls=db_cls,
    )

    # Should be able to recurse to this one.
    assert select(d, ("path", "a", "f")) == [((("path", "a", "f"),), {})]


@pytest.mark.parametrize("db_cls,", DBCLS)
def test_cojoin(db_cls):
    """Tests that unification occurs correctly."""

    d = read(
        """
edge(a, b).
edge(b, c).
edge(c, d).
edge(d, e).
edge(e, f).
edge(c, q).

two_path(A, B, C) :- edge(A, B), edge(B, C).
""",
        db_cls=db_cls,
    )

    # Should be able to recurse to this one.
    assert [t for t, _ in select(d, ("two_path", "A", "B", "C"))] == [
        (("two_path", "a", "b", "c"),),
        (("two_path", "b", "c", "d"),),
        (("two_path", "b", "c", "q"),),
        (("two_path", "c", "d", "e"),),
        (("two_path", "d", "e", "f"),),
    ]