Create a Symbol value class

This commit is contained in:
Reid 'arrdem' McKenzie 2023-05-08 17:27:04 -06:00
parent 10b143f2fa
commit bcd50fe57f
4 changed files with 48 additions and 24 deletions

View file

@ -1,6 +1,7 @@
py_project(
name = "milkshake",
lib_deps = [
py_requirement("attrs"),
py_requirement("lark"),
],
)

View file

@ -3,6 +3,7 @@
from importlib.resources import files
from attrs import define, field
from lark import Lark, Tree, Token, Transformer, v_args
@ -10,6 +11,19 @@ with files(__package__).joinpath("grammar.lark").open("r", encoding="utf-8") as
GRAMMAR = fp.read()
@define
class Symbol:
name: str
def __eq__(self, other):
if isinstance(other, str):
return self.name == other
elif isinstance(other, type(self)):
return self.name == other.name
else:
return False
@v_args(tree=True)
class T(Transformer):
"""A prepackaged transformer that cleans up the quoting details."""
@ -25,7 +39,9 @@ class T(Transformer):
qq_map = un_qq
qq_set = un_qq
qq_atom = un_qq
qq_symbol = un_qq
def qq_symbol(self, obj):
return self.symbol(self.un_qq(obj))
def qq_quote(self, obj):
return self.quote(self.un_qq(obj))
@ -37,6 +53,9 @@ class T(Transformer):
unquote = quote
unquote_splicing = quote
def symbol(self, obj):
return Symbol(obj.children[0].value)
PARSER = Lark(GRAMMAR, start=["module", "expr"])

View file

@ -56,7 +56,11 @@ string: /"([^"]|\\")+"/
pattern: /\/([^\/]|\\\/)+\//
number: /[+-]?(\d+r)?(\d[\d,_\.]*)([\.,][\d,_\.]*)?(e[+-]?\d+)?/
number: US_FORMAT_NUMBER | EU_FORMAT_NUMBER | UNDERSCORE_NUMBER
US_FORMAT_NUMBER: /[+-]?(\d+r)?(\d+(,\d{3})*)(\.\d+(,\d{3})*)?(e[+-]?\d+)?/
EU_FORMAT_NUMBER: /[+-]?(\d+r)?(\d+(\.\d{3})*)(,\d+(\.\d{3})*)?(e[+-]?\d+)?/
UNDERSCORE_NUMBER: /[+-]?(\d+r)?(\d+(_\d{3})*)([,\.]\d+(_\d{3})*)?(e[+-]?\d+)?/
// Note that we're demoting Symbol from the default parse priority of 0 to -1
// This because _anything more specific_ should be chosen over symbol

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
from milkshake import slurp
from milkshake import slurp, Symbol
from lark import Tree, Token
import pytest
@ -10,26 +10,26 @@ import pytest
"input, val",
[
("()", Tree("list", [])),
("nil", nil := Tree("symbol", ["nil"])),
("nil", nil := Symbol("nil")),
("(nil nil nil)", Tree("list", [nil, nil, nil])),
(
"(/ + - * % ^ \\ & # @ ! = |)",
Tree(
"list",
[
Tree("symbol", ["/"]),
Tree("symbol", ["+"]),
Tree("symbol", ["-"]),
Tree("symbol", ["*"]),
Tree("symbol", ["%"]),
Tree("symbol", ["^"]),
Tree("symbol", ["\\"]),
Tree("symbol", ["&"]),
Tree("symbol", ["#"]),
Tree("symbol", ["@"]),
Tree("symbol", ["!"]),
Tree("symbol", ["="]),
Tree("symbol", ["|"]),
Symbol("/"),
Symbol("+"),
Symbol("-"),
Symbol("*"),
Symbol("%"),
Symbol("^"),
Symbol("\\"),
Symbol("&"),
Symbol("#"),
Symbol("@"),
Symbol("!"),
Symbol("="),
Symbol("|"),
],
),
),
@ -48,12 +48,12 @@ import pytest
Tree(
"list",
[
Tree("symbol", ["+inf"]),
Tree("symbol", ["-inf"]),
Tree("symbol", ["inf"]),
Tree("symbol", ["nan"]),
Tree("symbol", ["+nan"]),
Tree("symbol", ["-nan"]),
Symbol("+inf"),
Symbol("-inf"),
Symbol("inf"),
Symbol("nan"),
Symbol("+nan"),
Symbol("-nan"),
],
),
),
@ -66,7 +66,7 @@ import pytest
Tree(
"list",
[
Tree("symbol", ["nil"]),
Symbol("nil"),
Tree(
"unquote",
[