Compare commits

..

3 commits

17 changed files with 1104 additions and 168 deletions

View file

@ -62,18 +62,18 @@ load("@arrdem_source_pypi//:requirements.bzl", "install_deps")
# Call it to define repos for your requirements.
install_deps()
git_repository(
name = "rules_zapp",
remote = "https://git.arrdem.com/arrdem/rules_zapp.git",
commit = "72f82e0ace184fe862f1b19c4f71c3bc36cf335b",
# tag = "0.1.2",
)
# local_repository(
# name = "rules_zapp",
# path = "/home/arrdem/Documents/hobby/programming/lang/python/rules_zapp",
# git_repository(
# name = "rules_zapp",
# remote = "https://git.arrdem.com/arrdem/rules_zapp.git",
# commit = "72f82e0ace184fe862f1b19c4f71c3bc36cf335b",
# # tag = "0.1.2",
# )
local_repository(
name = "rules_zapp",
path = "/home/arrdem/Documents/hobby/programming/lang/python/rules_zapp",
)
####################################################################################################
# Docker support
####################################################################################################

View file

@ -0,0 +1,3 @@
py_project(
name = "gcode-interpreter",
)

View file

@ -0,0 +1,15 @@
# Gcode Interpreter
Extracted [from
octoprint](https://raw.githubusercontent.com/OctoPrint/OctoPrint/master/src/octoprint/util/gcodeInterpreter.py), this
package provides static analysis (ahem. abstract interpretation) of GCODE scripts for 3d printers to provide key data
such as the bounding box through which the tool(s) move, estimated net movement time and the net amount of material
extruded.
## License
This artifact is licensed under the GNU Affero General Public License http://www.gnu.org/licenses/agpl.html
Copyright © Reid McKenzie <me@arrdem.com>
Copyright © Gina Häußge <osd@foosel.net>
Copyright © David Braam

View file

@ -0,0 +1,868 @@
__author__ = "Gina Häußge <osd@foosel.net> based on work by David Braam"
__license__ = "GNU Affero General Public License http://www.gnu.org/licenses/agpl.html"
__copyright__ = "Copyright (C) 2013 David Braam, Gina Häußge - Released under terms of the AGPLv3 License"
import base64
import codecs
import io
import logging
import math
import os
import re
import zlib
class Vector3D:
"""
3D vector value
Supports addition, subtraction and multiplication with a scalar value (float, int) as well as calculating the
length of the vector.
Examples:
>>> a = Vector3D(1.0, 1.0, 1.0)
>>> b = Vector3D(4.0, 4.0, 4.0)
>>> a + b == Vector3D(5.0, 5.0, 5.0)
True
>>> b - a == Vector3D(3.0, 3.0, 3.0)
True
>>> abs(a - b) == Vector3D(3.0, 3.0, 3.0)
True
>>> a * 2 == Vector3D(2.0, 2.0, 2.0)
True
>>> a * 2 == 2 * a
True
>>> a.length == math.sqrt(a.x ** 2 + a.y ** 2 + a.z ** 2)
True
>>> copied_a = Vector3D(a)
>>> a == copied_a
True
>>> copied_a.x == a.x and copied_a.y == a.y and copied_a.z == a.z
True
"""
def __init__(self, *args):
if len(args) == 3:
(self.x, self.y, self.z) = args
elif len(args) == 1:
# copy constructor
other = args[0]
if not isinstance(other, Vector3D):
raise ValueError("Object to copy must be a Vector3D instance")
self.x = other.x
self.y = other.y
self.z = other.z
@property
def length(self):
return math.sqrt(self.x * self.x + self.y * self.y + self.z * self.z)
def __add__(self, other):
try:
if len(other) == 3:
return Vector3D(self.x + other[0], self.y + other[1], self.z + other[2])
except TypeError:
# doesn't look like a 3-tuple
pass
try:
return Vector3D(self.x + other.x, self.y + other.y, self.z + other.z)
except AttributeError:
# also doesn't look like a Vector3D
pass
raise TypeError(
"other must be a Vector3D instance or a list or tuple of length 3"
)
def __sub__(self, other):
try:
if len(other) == 3:
return Vector3D(self.x - other[0], self.y - other[1], self.z - other[2])
except TypeError:
# doesn't look like a 3-tuple
pass
try:
return Vector3D(self.x - other.x, self.y - other.y, self.z - other.z)
except AttributeError:
# also doesn't look like a Vector3D
pass
raise TypeError(
"other must be a Vector3D instance or a list or tuple of length 3"
)
def __mul__(self, other):
try:
return Vector3D(self.x * other, self.y * other, self.z * other)
except TypeError:
# doesn't look like a scalar
pass
raise ValueError("other must be a float or int value")
def __rmul__(self, other):
return self.__mul__(other)
def __abs__(self):
return Vector3D(abs(self.x), abs(self.y), abs(self.z))
def __eq__(self, other):
if not isinstance(other, Vector3D):
return False
return self.x == other.x and self.y == other.y and self.z == other.z
def __str__(self):
return "Vector3D(x={}, y={}, z={}, length={})".format(
self.x, self.y, self.z, self.length
)
class MinMax3D:
"""
Tracks minimum and maximum of recorded values
Examples:
>>> minmax = MinMax3D()
>>> minmax.record(Vector3D(2.0, 2.0, 2.0))
>>> minmax.min.x == 2.0 == minmax.max.x and minmax.min.y == 2.0 == minmax.max.y and minmax.min.z == 2.0 == minmax.max.z
True
>>> minmax.record(Vector3D(1.0, 2.0, 3.0))
>>> minmax.min.x == 1.0 and minmax.min.y == 2.0 and minmax.min.z == 2.0
True
>>> minmax.max.x == 2.0 and minmax.max.y == 2.0 and minmax.max.z == 3.0
True
>>> minmax.size == Vector3D(1.0, 0.0, 1.0)
True
>>> empty = MinMax3D()
>>> empty.size == Vector3D(0.0, 0.0, 0.0)
True
>>> weird = MinMax3D(min_z=-1.0)
>>> weird.record(Vector3D(2.0, 2.0, 2.0))
>>> weird.record(Vector3D(1.0, 2.0, 3.0))
>>> weird.min.z == -1.0
True
>>> weird.size == Vector3D(1.0, 0.0, 4.0)
True
"""
def __init__(
self,
min_x=None,
min_y=None,
min_z=None,
max_x=None,
max_y=None,
max_z=None,
):
min_x = min_x if min_x is not None else float("inf")
min_y = min_y if min_y is not None else float("inf")
min_z = min_z if min_z is not None else float("inf")
max_x = max_x if max_x is not None else -float("inf")
max_y = max_y if max_y is not None else -float("inf")
max_z = max_z if max_z is not None else -float("inf")
self.min = Vector3D(min_x, min_y, min_z)
self.max = Vector3D(max_x, max_y, max_z)
def record(self, coordinate):
"""
Records the coordinate, storing the min and max values.
The input vector components must not be None.
"""
self.min.x = min(self.min.x, coordinate.x)
self.min.y = min(self.min.y, coordinate.y)
self.min.z = min(self.min.z, coordinate.z)
self.max.x = max(self.max.x, coordinate.x)
self.max.y = max(self.max.y, coordinate.y)
self.max.z = max(self.max.z, coordinate.z)
@property
def size(self):
result = Vector3D()
for c in "xyz":
min = getattr(self.min, c)
max = getattr(self.max, c)
value = abs(max - min) if max >= min else 0.0
setattr(result, c, value)
return result
@property
def dimensions(self):
size = self.size
return {"width": size.x, "depth": size.y, "height": size.z}
@property
def area(self):
return {
"minX": None if math.isinf(self.min.x) else self.min.x,
"minY": None if math.isinf(self.min.y) else self.min.y,
"minZ": None if math.isinf(self.min.z) else self.min.z,
"maxX": None if math.isinf(self.max.x) else self.max.x,
"maxY": None if math.isinf(self.max.y) else self.max.y,
"maxZ": None if math.isinf(self.max.z) else self.max.z,
}
class AnalysisAborted(Exception):
def __init__(self, reenqueue=True, *args, **kwargs):
self.reenqueue = reenqueue
Exception.__init__(self, *args, **kwargs)
regex_command = re.compile(
r"^\s*((?P<codeGM>[GM]\d+)(\.(?P<subcode>\d+))?|(?P<codeT>T)(?P<tool>\d+))"
)
"""Regex for a GCODE command."""
class gcode:
def __init__(self, incl_layers=False, progress_callback=None):
self._logger = logging.getLogger(__name__)
self.extrusionAmount = [0]
self.extrusionVolume = [0]
self.totalMoveTimeMinute = 0
self.filename = None
self._abort = False
self._reenqueue = True
self._filamentDiameter = 0
self._print_minMax = MinMax3D()
self._travel_minMax = MinMax3D()
self._progress_callback = progress_callback
self._incl_layers = incl_layers
self._layers = []
self._current_layer = None
def _track_layer(self, pos, arc=None):
if not self._incl_layers:
return
if self._current_layer is None or self._current_layer["z"] != pos.z:
self._current_layer = {"z": pos.z, "minmax": MinMax3D(), "commands": 1}
self._layers.append(self._current_layer)
elif self._current_layer:
self._current_layer["minmax"].record(pos)
if arc is not None:
self._addArcMinMax(
self._current_layer["minmax"],
arc["startAngle"],
arc["endAngle"],
arc["center"],
arc["radius"],
)
def _track_command(self):
if self._current_layer:
self._current_layer["commands"] += 1
@property
def dimensions(self):
return self._print_minMax.dimensions
@property
def travel_dimensions(self):
return self._travel_minMax.dimensions
@property
def printing_area(self):
return self._print_minMax.area
@property
def travel_area(self):
return self._travel_minMax.area
@property
def layers(self):
return [
{
"num": num + 1,
"z": layer["z"],
"commands": layer["commands"],
"bounds": {
"minX": layer["minmax"].min.x,
"maxX": layer["minmax"].max.x,
"minY": layer["minmax"].min.y,
"maxY": layer["minmax"].max.y,
},
}
for num, layer in enumerate(self._layers)
]
def load(
self,
filename,
throttle=None,
speedx=6000,
speedy=6000,
offsets=None,
max_extruders=10,
g90_extruder=False,
bed_z=0.0,
):
self._print_minMax.min.z = self._travel_minMax.min.z = bed_z
if os.path.isfile(filename):
self.filename = filename
self._fileSize = os.stat(filename).st_size
with codecs.open(filename, encoding="utf-8", errors="replace") as f:
self._load(
f,
throttle=throttle,
speedx=speedx,
speedy=speedy,
offsets=offsets,
max_extruders=max_extruders,
g90_extruder=g90_extruder,
)
def abort(self, reenqueue=True):
self._abort = True
self._reenqueue = reenqueue
def _load(
self,
gcodeFile,
throttle=None,
speedx=6000,
speedy=6000,
offsets=None,
max_extruders=10,
g90_extruder=False,
):
lineNo = 0
readBytes = 0
pos = Vector3D(0.0, 0.0, 0.0)
currentE = [0.0]
totalExtrusion = [0.0]
maxExtrusion = [0.0]
currentExtruder = 0
totalMoveTimeMinute = 0.0
relativeE = False
relativeMode = False
duplicationMode = False
scale = 1.0
fwretractTime = 0
fwretractDist = 0
fwrecoverTime = 0
feedrate = min(speedx, speedy)
if feedrate == 0:
# some somewhat sane default if axes speeds are insane...
feedrate = 2000
if offsets is None or not isinstance(offsets, (list, tuple)):
offsets = []
if len(offsets) < max_extruders:
offsets += [(0, 0)] * (max_extruders - len(offsets))
for line in gcodeFile:
if self._abort:
raise AnalysisAborted(reenqueue=self._reenqueue)
lineNo += 1
readBytes += len(line.encode("utf-8"))
if isinstance(gcodeFile, (io.IOBase, codecs.StreamReaderWriter)):
percentage = readBytes / self._fileSize
elif isinstance(gcodeFile, (list)):
percentage = lineNo / len(gcodeFile)
else:
percentage = None
try:
if (
self._progress_callback is not None
and (lineNo % 1000 == 0)
and percentage is not None
):
self._progress_callback(percentage)
except Exception as exc:
self._logger.debug(
"Progress callback %r error: %s", self._progress_callback, exc
)
if ";" in line:
comment = line[line.find(";") + 1 :].strip()
if comment.startswith("filament_diameter"):
# Slic3r
filamentValue = comment.split("=", 1)[1].strip()
try:
self._filamentDiameter = float(filamentValue)
except ValueError:
try:
self._filamentDiameter = float(
filamentValue.split(",")[0].strip()
)
except ValueError:
self._filamentDiameter = 0.0
elif comment.startswith("CURA_PROFILE_STRING") or comment.startswith(
"CURA_OCTO_PROFILE_STRING"
):
# Cura 15.04.* & OctoPrint Cura plugin
if comment.startswith("CURA_PROFILE_STRING"):
prefix = "CURA_PROFILE_STRING:"
else:
prefix = "CURA_OCTO_PROFILE_STRING:"
curaOptions = self._parseCuraProfileString(comment, prefix)
if "filament_diameter" in curaOptions:
try:
self._filamentDiameter = float(
curaOptions["filament_diameter"]
)
except ValueError:
self._filamentDiameter = 0.0
elif comment.startswith("filamentDiameter,"):
# Simplify3D
filamentValue = comment.split(",", 1)[1].strip()
try:
self._filamentDiameter = float(filamentValue)
except ValueError:
self._filamentDiameter = 0.0
line = line[0 : line.find(";")]
match = regex_command.search(line)
gcode = tool = None
if match:
values = match.groupdict()
if "codeGM" in values and values["codeGM"]:
gcode = values["codeGM"]
elif "codeT" in values and values["codeT"]:
gcode = values["codeT"]
tool = int(values["tool"])
# G codes
if gcode in ("G0", "G1", "G00", "G01"): # Move
x = getCodeFloat(line, "X")
y = getCodeFloat(line, "Y")
z = getCodeFloat(line, "Z")
e = getCodeFloat(line, "E")
f = getCodeFloat(line, "F")
if x is not None or y is not None or z is not None:
# this is a move
move = True
else:
# print head stays on position
move = False
oldPos = pos
# Use new coordinates if provided. If not provided, use prior coordinates (minus tool offset)
# in absolute and 0.0 in relative mode.
newPos = Vector3D(
x * scale if x is not None else (0.0 if relativeMode else pos.x),
y * scale if y is not None else (0.0 if relativeMode else pos.y),
z * scale if z is not None else (0.0 if relativeMode else pos.z),
)
if relativeMode:
# Relative mode: add to current position
pos += newPos
else:
# Absolute mode: apply tool offsets
pos = newPos
if f is not None and f != 0:
feedrate = f
if e is not None:
if relativeMode or relativeE:
# e is already relative, nothing to do
pass
else:
e -= currentE[currentExtruder]
totalExtrusion[currentExtruder] += e
currentE[currentExtruder] += e
maxExtrusion[currentExtruder] = max(
maxExtrusion[currentExtruder], totalExtrusion[currentExtruder]
)
if currentExtruder == 0 and len(currentE) > 1 and duplicationMode:
# Copy first extruder length to other extruders
for i in range(1, len(currentE)):
totalExtrusion[i] += e
currentE[i] += e
maxExtrusion[i] = max(maxExtrusion[i], totalExtrusion[i])
else:
e = 0
# If move, calculate new min/max coordinates
if move:
self._travel_minMax.record(oldPos)
self._travel_minMax.record(pos)
if e > 0:
# store as print move if extrusion is > 0
self._print_minMax.record(oldPos)
self._print_minMax.record(pos)
# move time in x, y, z, will be 0 if no movement happened
moveTimeXYZ = abs((oldPos - pos).length / feedrate)
# time needed for extruding, will be 0 if no extrusion happened
extrudeTime = abs(e / feedrate)
# time to add is maximum of both
totalMoveTimeMinute += max(moveTimeXYZ, extrudeTime)
# process layers if there's extrusion
if e:
self._track_layer(pos)
if gcode in ("G2", "G3", "G02", "G03"): # Arc Move
x = getCodeFloat(line, "X")
y = getCodeFloat(line, "Y")
z = getCodeFloat(line, "Z")
e = getCodeFloat(line, "E")
i = getCodeFloat(line, "I")
j = getCodeFloat(line, "J")
r = getCodeFloat(line, "R")
f = getCodeFloat(line, "F")
# this is a move or print head stays on position
move = (
x is not None
or y is not None
or z is not None
or i is not None
or j is not None
or r is not None
)
oldPos = pos
# Use new coordinates if provided. If not provided, use prior coordinates (minus tool offset)
# in absolute and 0.0 in relative mode.
newPos = Vector3D(
x * scale if x is not None else (0.0 if relativeMode else pos.x),
y * scale if y is not None else (0.0 if relativeMode else pos.y),
z * scale if z is not None else (0.0 if relativeMode else pos.z),
)
if relativeMode:
# Relative mode: add to current position
pos += newPos
else:
# Absolute mode: apply tool offsets
pos = newPos
if f is not None and f != 0:
feedrate = f
# get radius and offset
i = 0 if i is None else i
j = 0 if j is None else j
r = math.sqrt(i * i + j * j) if r is None else r
# calculate angles
centerArc = Vector3D(oldPos.x + i, oldPos.y + j, oldPos.z)
startAngle = math.atan2(oldPos.y - centerArc.y, oldPos.x - centerArc.x)
endAngle = math.atan2(pos.y - centerArc.y, pos.x - centerArc.x)
arcAngle = endAngle - startAngle
if gcode in ("G2", "G02"):
startAngle, endAngle = endAngle, startAngle
arcAngle = -arcAngle
if startAngle < 0:
startAngle += math.pi * 2
if endAngle < 0:
endAngle += math.pi * 2
if arcAngle < 0:
arcAngle += math.pi * 2
# from now on we only think in counter-clockwise direction
if e is not None:
if relativeMode or relativeE:
# e is already relative, nothing to do
pass
else:
e -= currentE[currentExtruder]
totalExtrusion[currentExtruder] += e
currentE[currentExtruder] += e
maxExtrusion[currentExtruder] = max(
maxExtrusion[currentExtruder], totalExtrusion[currentExtruder]
)
if currentExtruder == 0 and len(currentE) > 1 and duplicationMode:
# Copy first extruder length to other extruders
for i in range(1, len(currentE)):
totalExtrusion[i] += e
currentE[i] += e
maxExtrusion[i] = max(maxExtrusion[i], totalExtrusion[i])
else:
e = 0
# If move, calculate new min/max coordinates
if move:
self._travel_minMax.record(oldPos)
self._travel_minMax.record(pos)
self._addArcMinMax(
self._travel_minMax, startAngle, endAngle, centerArc, r
)
if e > 0:
# store as print move if extrusion is > 0
self._print_minMax.record(oldPos)
self._print_minMax.record(pos)
self._addArcMinMax(
self._print_minMax, startAngle, endAngle, centerArc, r
)
# calculate 3d arc length
arcLengthXYZ = math.sqrt((oldPos.z - pos.z) ** 2 + (arcAngle * r) ** 2)
# move time in x, y, z, will be 0 if no movement happened
moveTimeXYZ = abs(arcLengthXYZ / feedrate)
# time needed for extruding, will be 0 if no extrusion happened
extrudeTime = abs(e / feedrate)
# time to add is maximum of both
totalMoveTimeMinute += max(moveTimeXYZ, extrudeTime)
# process layers if there's extrusion
if e:
self._track_layer(
pos,
{
"startAngle": startAngle,
"endAngle": endAngle,
"center": centerArc,
"radius": r,
},
)
elif gcode == "G4": # Delay
S = getCodeFloat(line, "S")
if S is not None:
totalMoveTimeMinute += S / 60
P = getCodeFloat(line, "P")
if P is not None:
totalMoveTimeMinute += P / 60 / 1000
elif gcode == "G10": # Firmware retract
totalMoveTimeMinute += fwretractTime
elif gcode == "G11": # Firmware retract recover
totalMoveTimeMinute += fwrecoverTime
elif gcode == "G20": # Units are inches
scale = 25.4
elif gcode == "G21": # Units are mm
scale = 1.0
elif gcode == "G28": # Home
x = getCodeFloat(line, "X")
y = getCodeFloat(line, "Y")
z = getCodeFloat(line, "Z")
origin = Vector3D(0.0, 0.0, 0.0)
if x is None and y is None and z is None:
pos = origin
else:
pos = Vector3D(pos)
if x is not None:
pos.x = origin.x
if y is not None:
pos.y = origin.y
if z is not None:
pos.z = origin.z
elif gcode == "G90": # Absolute position
relativeMode = False
if g90_extruder:
relativeE = False
elif gcode == "G91": # Relative position
relativeMode = True
if g90_extruder:
relativeE = True
elif gcode == "G92":
x = getCodeFloat(line, "X")
y = getCodeFloat(line, "Y")
z = getCodeFloat(line, "Z")
e = getCodeFloat(line, "E")
if e is None and x is None and y is None and z is None:
# no parameters, set all axis to 0
currentE[currentExtruder] = 0.0
pos.x = 0.0
pos.y = 0.0
pos.z = 0.0
else:
# some parameters set, only set provided axes
if e is not None:
currentE[currentExtruder] = e
if x is not None:
pos.x = x
if y is not None:
pos.y = y
if z is not None:
pos.z = z
# M codes
elif gcode == "M82": # Absolute E
relativeE = False
elif gcode == "M83": # Relative E
relativeE = True
elif gcode in ("M207", "M208"): # Firmware retract settings
s = getCodeFloat(line, "S")
f = getCodeFloat(line, "F")
if s is not None and f is not None:
if gcode == "M207":
# Ensure division is valid
if f > 0:
fwretractTime = s / f
else:
fwretractTime = 0
fwretractDist = s
else:
if f > 0:
fwrecoverTime = (fwretractDist + s) / f
else:
fwrecoverTime = 0
elif gcode == "M605": # Duplication/Mirroring mode
s = getCodeInt(line, "S")
if s in [2, 4, 5, 6]:
# Duplication / Mirroring mode selected. Printer firmware copies extrusion commands
# from first extruder to all other extruders
duplicationMode = True
else:
duplicationMode = False
# T codes
elif tool is not None:
if tool > max_extruders:
self._logger.warning(
"GCODE tried to select tool %d, that looks wrong, ignoring for GCODE analysis"
% tool
)
elif tool == currentExtruder:
pass
else:
pos.x -= (
offsets[currentExtruder][0]
if currentExtruder < len(offsets)
else 0
)
pos.y -= (
offsets[currentExtruder][1]
if currentExtruder < len(offsets)
else 0
)
currentExtruder = tool
pos.x += (
offsets[currentExtruder][0]
if currentExtruder < len(offsets)
else 0
)
pos.y += (
offsets[currentExtruder][1]
if currentExtruder < len(offsets)
else 0
)
if len(currentE) <= currentExtruder:
for _ in range(len(currentE), currentExtruder + 1):
currentE.append(0.0)
if len(maxExtrusion) <= currentExtruder:
for _ in range(len(maxExtrusion), currentExtruder + 1):
maxExtrusion.append(0.0)
if len(totalExtrusion) <= currentExtruder:
for _ in range(len(totalExtrusion), currentExtruder + 1):
totalExtrusion.append(0.0)
if gcode or tool:
self._track_command()
if throttle is not None:
throttle(lineNo, readBytes)
if self._progress_callback is not None:
self._progress_callback(100.0)
self.extrusionAmount = maxExtrusion
self.extrusionVolume = [0] * len(maxExtrusion)
for i in range(len(maxExtrusion)):
radius = self._filamentDiameter / 2
self.extrusionVolume[i] = (
self.extrusionAmount[i] * (math.pi * radius * radius)
) / 1000
self.totalMoveTimeMinute = totalMoveTimeMinute
def _parseCuraProfileString(self, comment, prefix):
return {
key: value
for (key, value) in map(
lambda x: x.split(b"=", 1),
zlib.decompress(base64.b64decode(comment[len(prefix) :])).split(b"\b"),
)
}
def _intersectsAngle(self, start, end, angle):
if end < start and angle == 0:
# angle crosses 0 degrees
return True
else:
return start <= angle <= end
def _addArcMinMax(self, minmax, startAngle, endAngle, centerArc, radius):
startDeg = math.degrees(startAngle)
endDeg = math.degrees(endAngle)
if self._intersectsAngle(startDeg, endDeg, 0):
# arc crosses positive x
minmax.max.x = max(minmax.max.x, centerArc.x + radius)
if self._intersectsAngle(startDeg, endDeg, 90):
# arc crosses positive y
minmax.max.y = max(minmax.max.y, centerArc.y + radius)
if self._intersectsAngle(startDeg, endDeg, 180):
# arc crosses negative x
minmax.min.x = min(minmax.min.x, centerArc.x - radius)
if self._intersectsAngle(startDeg, endDeg, 270):
# arc crosses negative y
minmax.min.y = min(minmax.min.y, centerArc.y - radius)
def get_result(self):
result = {
"total_time": self.totalMoveTimeMinute,
"extrusion_length": self.extrusionAmount,
"extrusion_volume": self.extrusionVolume,
"dimensions": self.dimensions,
"printing_area": self.printing_area,
"travel_dimensions": self.travel_dimensions,
"travel_area": self.travel_area,
}
if self._incl_layers:
result["layers"] = self.layers
return result
def getCodeInt(line, code):
return getCode(line, code, int)
def getCodeFloat(line, code):
return getCode(line, code, float)
def getCode(line, code, c):
n = line.find(code) + 1
if n < 1:
return None
m = line.find(" ", n)
try:
if m < 0:
result = c(line[n:])
else:
result = c(line[n:m])
except ValueError:
return None
if math.isnan(result) or math.isinf(result):
return None
return result

View file

@ -17,9 +17,10 @@ from tentacles.blueprints import (
printer_ui,
user_ui,
)
from tentacles.db import Db
from tentacles.globals import _ctx, Ctx, ctx
from tentacles.store import Store
from tentacles.workers import Worker
from tentacles.workers import *
from tentacles.workers import assign_jobs, Worker
@click.group()
@ -28,7 +29,7 @@ def cli():
def db_factory(app):
store = Store(app.config.get("db", {}).get("uri"))
store = Db(app.config.get("db", {}).get("uri"))
store.connect()
return store
@ -38,10 +39,11 @@ def custom_ctx(app, wsgi_app):
store = db_factory(app)
token = _ctx.set(Ctx(store))
try:
return wsgi_app(environ, start_response)
with store.savepoint():
return wsgi_app(environ, start_response)
finally:
_ctx.reset(token)
store.close()
_ctx.reset(token)
return helper
@ -56,21 +58,21 @@ def user_session():
if (
(
(session_id := request.cookies.get("sid", ""))
and (uid := ctx.db.try_key(session_id))
and (row := ctx.db.try_key(kid=session_id))
)
or (
request.authorization
and request.authorization.token
and (uid := ctx.db.try_key(request.authorization.token))
and (row := ctx.db.try_key(kid=request.authorization.token))
)
or (
(api_key := request.headers.get("x-api-key"))
and (uid := ctx.db.try_key(api_key))
and (row := ctx.db.try_key(kid=api_key))
)
):
ctx.sid = session_id
ctx.uid = uid
user = ctx.db.fetch_user(uid)
ctx.sid = row.id
ctx.uid = row.user_id
user = ctx.db.fetch_user(row.user_id)
ctx.gid = user.group_id
ctx.username = user.name
ctx.is_admin = user.group_id == 0
@ -101,14 +103,17 @@ def make_app():
@cli.command()
@click.option("--hostname", "hostname", type=str, default="0.0.0.0")
@click.option("--port", "port", type=int, default=8080)
@click.option("--trace/--no-trace", "trace", default=False)
@click.option("--config", type=Path)
def serve(hostname: str, port: int, config: Path):
def serve(hostname: str, port: int, config: Path, trace: bool):
logging.basicConfig(
format="%(asctime)s %(relativeCreated)6d %(threadName)s - %(name)s - %(levelname)s - %(message)s",
format="%(asctime)s %(threadName)s - %(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)
logging.getLogger("tentacles").setLevel(logging.DEBUG)
if trace:
logging.getLogger("tentacles.db").setLevel(logging.DEBUG - 1)
app = make_app()
@ -133,8 +138,14 @@ def serve(hostname: str, port: int, config: Path):
server.shutdown_timeout = 1
server.subscribe()
# Spawn the worker thread
Worker(cherrypy.engine, app, db_factory, frequency=5).start()
# Spawn the worker thread(s)
Worker(cherrypy.engine, app, db_factory, poll_printers, frequency=5).start()
Worker(cherrypy.engine, app, db_factory, assign_jobs, frequency=5).start()
Worker(cherrypy.engine, app, db_factory, push_jobs, frequency=5).start()
Worker(cherrypy.engine, app, db_factory, revoke_jobs, frequency=5).start()
Worker(cherrypy.engine, app, db_factory, pull_jobs, frequency=5).start()
Worker(cherrypy.engine, app, db_factory, send_emails, frequency=5).start()
# Run the server
cherrypy.engine.start()

View file

@ -98,10 +98,12 @@ def create_file(location: Optional[str] = None):
return {"error": "file exists already"}, 409
file.save(sanitized_path)
fid = ctx.db.create_file(ctx.uid, file.filename, sanitized_path)
fid = ctx.db.create_file(
uid=ctx.uid, filename=file.filename, path=sanitized_path
)
if request.form.get("print", "").lower() == "true":
ctx.db.create_job(ctx.uid, fid)
ctx.db.create_job(uid=ctx.uid, fid=fid)
return {"status": "ok"}, 202
@ -119,7 +121,7 @@ def get_files():
"owner": ctx.uid,
"upload_date": f.upload_date,
}
for f in ctx.db.list_files(ctx.uid)
for f in ctx.db.list_files(uid=ctx.uid)
]
}, 200
@ -150,7 +152,7 @@ def get_jobs():
"finished_at": j.finished_at,
"printer_id": j.printer_id,
}
for j in ctx.db.list_jobs()
for j in ctx.db.list_jobs(uid=ctx.uid)
]
}, 200

View file

@ -42,9 +42,10 @@ def manipulate_files():
return render_template("files.html.j2"), code
case "delete":
file = ctx.db.fetch_file(ctx.uid, int(request.form.get("file_id")))
file = ctx.db.fetch_file(uid=ctx.uid, fid=int(request.form.get("file_id")))
if any(
job.finished_at is None for job in ctx.db.list_jobs_by_file(file.id)
job.finished_at is None
for job in ctx.db.list_jobs_by_file(uid=ctx.uid, fid=file.id)
):
flash("File is in use", category="error")
return render_template("files.html.j2"), 400
@ -52,7 +53,7 @@ def manipulate_files():
if os.path.exists(file.path):
os.unlink(file.path)
ctx.db.delete_file(ctx.uid, file.id)
ctx.db.delete_file(uid=ctx.uid, fid=file.id)
flash("File deleted", category="info")
case _:

View file

@ -29,22 +29,24 @@ def list_jobs():
def manipulate_jobs():
match request.form.get("action"):
case "enqueue":
ctx.db.create_job(ctx.uid, int(request.form.get("file_id")))
ctx.db.create_job(uid=ctx.uid, fid=int(request.form.get("file_id")))
flash("Job created!", category="info")
case "duplicate":
if job := ctx.db.fetch_job(ctx.uid, int(request.form.get("job_id"))):
ctx.db.create_job(ctx.uid, job.file_id)
if job := ctx.db.fetch_job(
uid=ctx.uid, jid=int(request.form.get("job_id"))
):
ctx.db.create_job(uid=ctx.uid, fid=job.file_id)
flash("Job created!", category="info")
else:
flash("Could not duplicate", category="error")
case "cancel":
ctx.db.cancel_job(ctx.uid, int(request.form.get("job_id")))
ctx.db.cancel_job(uid=ctx.uid, jid=int(request.form.get("job_id")))
flash("Cancellation reqested", category="info")
case "delete":
ctx.db.delete_job(ctx.uid, int(request.form.get("job_id")))
ctx.db.delete_job(uid=ctx.uid, jid=int(request.form.get("job_id")))
flash("Job deleted", category="info")
case _:

View file

@ -36,9 +36,10 @@ def add_printer():
assert request.form["url"]
assert request.form["api_key"]
ctx.db.try_create_printer(
request.form["name"],
request.form["url"],
request.form["api_key"],
name=request.form["name"],
url=request.form["url"],
api_key=request.form["api_key"],
sid=0, # Disconnected
)
flash("Printer created")
return redirect("/printers")

View file

@ -42,13 +42,13 @@ def get_login():
@BLUEPRINT.route("/user/login", methods=["POST"])
def post_login():
if sid := ctx.db.try_login(
username := request.form["username"],
salt(request.form["password"]),
timedelta(days=1),
if row := ctx.db.try_login(
username=(username := request.form["username"]),
password=salt(request.form["password"]),
ttl=timedelta(days=1),
):
resp = redirect("/")
resp.set_cookie("sid", sid)
resp.set_cookie("sid", row.id)
flash(f"Welcome, {username}", category="success")
return resp
@ -72,7 +72,7 @@ def post_register():
username = request.form["username"]
email = request.form["email"]
group_id = 1 # Normal users
status_id = -2 # Unverified
status_id = -3 # Unverified
for user_config in current_app.config.get("users", []):
if user_config["email"] == email:
@ -85,13 +85,17 @@ def post_register():
break
if user := ctx.db.try_create_user(
username, email, salt(request.form["password"]), group_id, status_id
username=username,
email=email,
password=salt(request.form["password"]),
gid=group_id,
sid=status_id,
):
if user.status_id == -2:
if user.status_id == -3:
ctx.db.create_email(
user.id,
"Tentacles email confirmation",
render_template(
uid=user.id,
subject="Tentacles email confirmation",
body=render_template(
"verification_email.html.j2",
username=user.name,
token_id=user.verification_token,
@ -103,6 +107,10 @@ def post_register():
"Please check your email for a verification request",
category="success",
)
elif user.status_id == 1:
flash("Welcome, please log in", category="success")
return render_template("register.html.j2")
except Exception as e:
@ -115,7 +123,7 @@ def post_register():
@BLUEPRINT.route("/user/logout")
def logout():
# Invalidate the user's authorization
ctx.db.delete_key(ctx.uid, ctx.sid)
ctx.db.delete_key(uid=ctx.uid, kid=ctx.sid)
resp = redirect("/")
resp.set_cookie("sid", "")
return resp
@ -132,7 +140,7 @@ def get_settings():
@BLUEPRINT.route("/user", methods=["POST"])
def post_settings():
if request.form["action"] == "add":
ttl_spec = request.form.get("ttl")
ttl_spec = request.form.get("ttl", "")
if ttl_spec == "forever":
ttl = None
elif m := re.fullmatch(r"(\d+)d", ttl_spec):
@ -141,7 +149,9 @@ def post_settings():
flash("Bad request", category="error")
return render_template("user.html.j2"), 400
ctx.db.create_key(ctx.sid, ttl, request.form.get("name"))
ctx.db.create_key(
uid=ctx.uid, ttl=ttl, name=request.form.get("name", "anonymous")
)
flash("Key created", category="success")
elif request.form["action"] == "revoke":

View file

@ -8,6 +8,7 @@ from importlib.resources import files
from inspect import signature
import logging
import sqlite3
from time import sleep
from types import GeneratorType, new_class
from typing import Optional
@ -27,7 +28,7 @@ def qfn(name, f):
# Force lazy values for convenience
if isinstance(res, GeneratorType):
res = list(res)
print("%s -> %r" % (name, res))
log.log(logging.DEBUG - 1, "%s (%r) -> %r", name, kwargs, res)
return res
_helper.__name__ = f.__name__
@ -64,7 +65,7 @@ class LoginError(StoreError):
pass
class Store(Queries):
class Db(Queries):
def __init__(self, path):
self._path = path
self._conn: sqlite3.Connection = None
@ -106,9 +107,24 @@ class Store(Queries):
try:
self.begin()
yield self
self.commit()
exc = None
for attempt in range(5):
try:
self.commit()
break
except sqlite3.OperationalError as e:
exc = e
if e.sqlite_errorcode == 6:
sleep(0.1 * attempt)
continue
else:
raise e
else:
raise exc
except sqlite3.Error:
self.rollback()
log.exception("Forced to roll back!")
return _helper()
@ -122,9 +138,11 @@ class Store(Queries):
################################################################################
# Wrappers for doing Python type mapping
def create_key(self, *, uid: int, name: str, ttl: timedelta):
def create_key(self, *, uid: int, name: str, ttl: Optional[timedelta]):
return super().create_key(
uid=uid, name=name, expiration=(datetime.now() + ttl).isoformat()
uid=uid,
name=name,
expiration=((datetime.now() + ttl).isoformat() if ttl else None),
)
def try_login(
@ -171,3 +189,6 @@ class Store(Queries):
"""
super().refresh_key(kid=kid, expiration=(datetime.now() + ttl).isoformat())
def finish_job(self, *, jid: int, state: str, message: Optional[str] = None):
super().finish_job(jid=jid, state=state, message=message)

View file

@ -3,13 +3,13 @@
from contextvars import ContextVar
from attrs import define
from tentacles.store import Store
from tentacles.db import Db
from werkzeug.local import LocalProxy
@define
class Ctx:
db: Store
db: Db
uid: int = None
gid: int = None
sid: str = None

View file

@ -93,7 +93,6 @@ CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT
, user_id INTEGER NOT NULL
, file_id INTEGER NOT NULL
, priority INTEGER CHECK(priority IS NOT NULL AND 0 <= priority)
, started_at TEXT
, cancelled_at TEXT
, finished_at TEXT
@ -198,7 +197,8 @@ INSERT INTO user_keys (
)
VALUES (:uid, :name, :expiration)
RETURNING
id
id
, user_id
;
-- name: try-login^
@ -232,7 +232,8 @@ WHERE
-- name: try-key^
SELECT
user_id
id
, user_id
FROM user_keys
WHERE
(expiration IS NULL OR unixepoch(expiration) > unixepoch('now'))
@ -265,7 +266,7 @@ INSERT INTO printers (
, api_key
, status_id
)
VALUES (:name, :url, :api_key, :status_id)
VALUES (:name, :url, :api_key, :sid)
RETURNING
id
;
@ -310,10 +311,10 @@ WHERE
-- name: update-printer-status!
UPDATE printers
SET
status_id = (SELECT id FROM printer_statuses WHERE name = :status)
status_id = (SELECT id FROM printer_statuses WHERE name = :status or id = :status)
, last_poll_date = datetime('now')
WHERE
id = :uid
id = :pid
;
----------------------------------------------------------------------------------------------------
@ -364,16 +365,10 @@ WHERE
INSERT INTO jobs (
user_id
, file_id
, priority
)
VALUES (
:uid
, :fid,
, (
SELECT priority + :priority
FROM users
WHERE uid = :uid
)
, :fid
)
RETURNING
id
@ -384,7 +379,7 @@ SELECT
FROM jobs
WHERE
user_id = :uid
AND id = :fid
AND id = :jid
;
-- name: list-jobs
@ -401,6 +396,7 @@ SELECT
FROM jobs
WHERE
file_id = :fid
, uid = :uid
;
-- name: list-job-queue
@ -410,11 +406,9 @@ FROM jobs
WHERE
finished_at IS NULL
AND (:uid IS NULL OR user_id = :uid)
ORDER BY
priority DESC
;
-- name: poll-job-queue
-- name: poll-job-queue^
SELECT
*
FROM jobs
@ -422,8 +416,6 @@ WHERE
started_at IS NULL
AND finished_at IS NULL
AND printer_id IS NULL
ORDER BY
priority DESC
LIMIT 1
;

View file

@ -35,7 +35,7 @@
{% macro job_state(job) %}
{{ 'queued' if (not job.finished_at and not job.printer_id and not job.cancelled_at) else
'running' if (not job.finished_at and job.printer_id and not job.cancelled_at) else
'cancelling' if (not job.finished_at and job.printer_id and job.cancelled_at) else
'cancelling' if (not job.finished_at and job.cancelled_at) else
job.state }}
{% endmacro %}

View file

@ -8,11 +8,10 @@ Mostly related to monitoring and managing Printer state.
"""
from contextlib import closing
from datetime import datetime, timedelta
import logging
from pathlib import Path
from threading import Event
from time import sleep
from typing import Callable
from urllib import parse as urlparse
from cherrypy.process.plugins import Monitor
@ -25,7 +24,7 @@ from requests.exceptions import (
HTTPError,
Timeout,
)
from tentacles.store import Store
from tentacles.db import Db
class OctoRest(_OR):
@ -46,36 +45,16 @@ log = logging.getLogger(__name__)
SHUTDOWN = Event()
def corn_job(every: timedelta):
def _decorator(f):
def _helper(*args, **kwargs):
last = None
while not SHUTDOWN.is_set():
if not last or (datetime.now() - last) > every:
log.debug(f"Ticking job {f.__name__}")
try:
last = datetime.now()
f(*args, **kwargs)
except Exception:
log.exception(f"Error while procesing task {f.__name__}")
else:
sleep(1)
return _helper
return _decorator
def poll_printers(app: App, store: Store) -> None:
def poll_printers(app: App, db: Db) -> None:
"""Poll printers for their status."""
for printer in store.list_printers():
mapped_job = store.fetch_job_by_printer(printer.id)
for printer in db.list_printers():
mapped_job = db.fetch_job_by_printer(pid=printer.id)
def _set_status(status: str):
if printer.status != status:
print(f"Printer {printer.id} {printer.status} -> {status}")
store.update_printer_status(printer.id, status)
log.info(f"Printer {printer.id} {printer.status} -> {status}")
db.update_printer_status(pid=printer.id, status=status)
try:
client = OctoRest(url=printer.url, apikey=printer.api_key)
@ -91,7 +70,7 @@ def poll_printers(app: App, store: Store) -> None:
# polling tasks. This violates separation of concerns a bit,
# but appears required for correctness.
if mapped_job:
store.finish_job(mapped_job.id, "error")
db.finish_job(jid=mapped_job.id, state="error")
_set_status("error")
@ -129,24 +108,24 @@ def poll_printers(app: App, store: Store) -> None:
)
def assign_jobs(app: App, store: Store) -> None:
def assign_jobs(app: App, db: Db) -> None:
"""Assign jobs to printers. Uploading files and job state management is handled separately."""
for printer_id in store.list_idle_printers():
if job_id := store.poll_job_queue():
store.assign_job(job_id, printer_id)
print(f"Mapped job {job_id} to printer {printer_id}")
for printer in db.list_idle_printers():
if job := db.poll_job_queue():
db.assign_job(jid=job.id, pid=printer.id)
log.info(f"Mapped job {job.id} to printer {printer.id}")
def push_jobs(app: App, store: Store) -> None:
def push_jobs(app: App, db: Db) -> None:
"""Ensure that job files are uploaded and started to the assigned printer."""
for job in store.list_mapped_jobs():
printer = store.fetch_printer(job.printer_id)
file = store.fetch_file(job.user_id, job.file_id)
for job in db.list_mapped_jobs():
printer = db.fetch_printer(pid=job.printer_id)
file = db.fetch_file(uid=job.user_id, fid=job.file_id)
if not file:
log.error(f"Job {job.id} no longer maps to a file")
store.delete_job(job.user_id, job.id)
db.delete_job(job.user_id, job.id)
try:
client = OctoRest(url=printer.url, apikey=printer.api_key)
@ -157,11 +136,15 @@ def push_jobs(app: App, store: Store) -> None:
printer_state = {"error": printer_job.get("error")}
if printer_state.get("error"):
print(f"Printer {printer.id} is in error, can't push")
log.warn(f"Printer {printer.id} is in error, can't push")
continue
try:
client.upload(file.path)
if not client.files_info("local", Path(file.path).name):
client.upload(file.path)
else:
log.info("Don't need to upload the job!")
except HTTPError as e:
if e.response.status_code == 409:
pass
@ -170,7 +153,7 @@ def push_jobs(app: App, store: Store) -> None:
client.select(Path(file.path).name)
client.start()
store.start_job(job.id)
db.start_job(job.id)
except TimeoutError:
pass
@ -178,19 +161,19 @@ def push_jobs(app: App, store: Store) -> None:
log.exception("Oop")
def revoke_jobs(app: App, store: Store) -> None:
def revoke_jobs(app: App, db: Db) -> None:
"""Ensure that job files are uploaded and started to the assigned printer.
Note that this will ALSO cancel jobs out of the print queue.
"""
for job in store.list_cancelled_jobs():
for job in db.list_canceling_jobs():
if job.printer_id:
printer = store.fetch_printer(job.printer_id)
printer = db.fetch_printer(pid=job.printer_id)
try:
print(f"Cancelling running job {job.id}")
log.info(f"Cancelling running job {job.id}")
client = OctoRest(url=printer.url, apikey=printer.api_key)
try:
client.cancel()
@ -200,8 +183,8 @@ def revoke_jobs(app: App, store: Store) -> None:
else:
raise
print(f"Job {job.id} -> cancelled")
store.finish_job(job.id, "cancelled")
log.info(f"Job {job.id} -> cancelled")
db.finish_job(jid=job.id, state="cancelled")
except TimeoutError:
pass
@ -210,15 +193,15 @@ def revoke_jobs(app: App, store: Store) -> None:
log.exception("Oop")
else:
print(f"Unmapped job {job.id} became cancelled")
store.finish_job(job.id, "cancelled")
log.info(f"Unmapped job {job.id} became cancelled")
db.finish_job(jid=job.id, state="cancelled")
def pull_jobs(app: App, store: Store) -> None:
def pull_jobs(app: App, db: Db) -> None:
"""Poll the state of mapped printers to control jobs."""
for job in store.list_running_jobs():
printer = store.fetch_printer(job.printer_id)
for job in db.list_running_jobs():
printer = db.fetch_printer(pid=job.printer_id)
try:
client = OctoRest(url=printer.url, apikey=printer.api_key)
job_state = client.job_info()
@ -231,19 +214,19 @@ def pull_jobs(app: App, store: Store) -> None:
pass
elif job_state.get("progress", {}).get("completion", 0.0) == 100.0:
print(f"Job {job.id} has succeeded")
store.finish_job(job.id, "success")
log.info(f"Job {job.id} has succeeded")
db.finish_job(jid=job.id, state="success")
elif printer_state.get("error"):
print(f"Job {job.id} has failed")
store.finish_job(job.id, "failed")
log.warn(f"Job {job.id} has failed")
db.finish_job(jid=job.id, state="failed")
elif printer_state.get("cancelling"):
print(f"Job {job.id} has been acknowledged as cancelled")
store.finish_job(job.id, "cancelled")
log.info(f"Job {job.id} has been acknowledged as cancelled")
db.finish_job(jid=job.id, state="cancelled")
else:
print(
log.warn(
f"Job {job.id} is in a weird state {job_state.get('progress')!r} {printer_state!r}"
)
@ -254,35 +237,60 @@ def pull_jobs(app: App, store: Store) -> None:
log.exception("Oop")
def send_emails(app, store: Store):
def send_emails(app, db: Db):
with closing(
FastMailSMTP(
app.config.get("fastmail", {}).get("username"),
app.config.get("fastmail", {}).get("key"),
)
) as fm:
for message in store.poll_spool():
for message in db.poll_email_queue():
fm.send_message(
from_addr="root@tirefireind.us",
to_addrs=[message.to],
subject=message.subject,
msg=message.body,
)
store.send_email(message.id)
db.send_email(message.id)
def once(f):
val = uninitialized = object()
def _helper(*args, **kwargs):
nonlocal val
if val is uninitialized:
val = f(*args, **kwargs)
return val
return _helper
def toil(*fs):
def _helper(*args, **kwargs):
for f in fs:
f(*args, **kwargs)
_helper.__name__ = "toil"
return _helper
class Worker(Monitor):
def __init__(self, bus, app, db_factory, **kwargs):
def __init__(
self,
bus,
app: App,
db_factory: Callable[[App], Db],
callback: Callable[[App, Db], None],
**kwargs,
):
self._app = app
self._db_factory = db_factory
super().__init__(bus, self.callback, **kwargs)
self._callback = callback
super().__init__(
bus, self.callback, **kwargs, name=f"Async {callback.__name__}"
)
def callback(self):
log.debug("Tick")
with self._app.app_context(), closing(self._db_factory(self._app)) as store:
poll_printers(self._app, store)
assign_jobs(self._app, store)
push_jobs(self._app, store)
revoke_jobs(self._app, store)
pull_jobs(self._app, store)
send_emails(self._app, store)
with closing(self._db_factory(self._app)) as db:
self._callback(self._app, db)

View file

@ -3,12 +3,12 @@
from datetime import timedelta
import pytest
import tentacles.store as s
from tentacles.db import Db
@pytest.yield_fixture
def store():
conn = s.Store(":memory:")
def db():
conn = Db(":memory:")
conn.connect()
yield conn
conn.close()
@ -25,9 +25,9 @@ def password_testy():
@pytest.fixture
def uid_testy(store: s.Store, username_testy, password_testy):
with store.savepoint():
return store.try_create_user(
def uid_testy(db: Db, username_testy, password_testy):
with db.savepoint():
return db.try_create_user(
username=username_testy,
email=username_testy,
password=password_testy,
@ -41,8 +41,10 @@ def login_ttl():
@pytest.fixture
def sid_testy(store, uid_testy, username_testy, password_testy, login_ttl):
with store.savepoint():
return store.try_login(
def sid_testy(db: Db, uid_testy, username_testy, password_testy, login_ttl):
with db.savepoint():
res = db.try_login(
username=username_testy, password=password_testy, ttl=login_ttl
).id
)
assert res.user_id == uid_testy
return res.id

View file

@ -1,13 +1,13 @@
#!/usr/bin/env python3
from tentacles.store import Store
from tentacles.db import Db
def test_store_initializes(store: Store):
assert isinstance(store, Store)
def test_db_initializes(store: Db):
assert isinstance(store, Db)
def test_store_savepoint(store: Store):
def test_db_savepoint(store: Db):
obj = store.savepoint()
assert hasattr(obj, "__enter__")