Create a ctx global to replace request abuse

This commit is contained in:
Reid 'arrdem' McKenzie 2023-05-27 00:30:39 -06:00
parent bedad7d86b
commit 409d04a648
9 changed files with 75 additions and 44 deletions

View file

@ -12,6 +12,7 @@ from datetime import datetime
from tentacles.blueprints import user_ui, printer_ui, api from tentacles.blueprints import user_ui, printer_ui, api
from tentacles.store import Store from tentacles.store import Store
from tentacles.globals import _ctx, Ctx, ctx
@click.group() @click.group()
@ -19,39 +20,42 @@ def cli():
pass pass
def open_db(): def custom_ctx(app, wsgi_app):
request.db = Store(current_app.config.get("db", {}).get("uri")) def helper(environ, start_response):
request.db.connect() store = Store(app.config.get("db", {}).get("uri"))
store.connect()
token = _ctx.set(Ctx(store))
try:
return wsgi_app(environ, start_response)
finally:
_ctx.reset(token)
store.close()
return helper
def commit_db(resp):
request.db.close()
return resp
def create_j2_request_global(): def create_j2_request_global():
current_app.jinja_env.globals["ctx"] = ctx
current_app.jinja_env.globals["request"] = request current_app.jinja_env.globals["request"] = request
current_app.jinja_env.globals["datetime"] = datetime current_app.jinja_env.globals["datetime"] = datetime
def user_session(): def user_session():
if (session_id := request.cookies.get("sid", "")) and ( if (session_id := request.cookies.get("sid", "")) and (
uid := request.db.try_key(session_id) uid := ctx.db.try_key(session_id)
): ):
request.sid = session_id ctx.sid = session_id
request.uid = uid ctx.uid = uid
_id, gid, name, _email, _hash, _status, _verification = request.db.fetch_user( _id, gid, name, _email, _hash, _status, _verification = ctx.db.fetch_user(uid)
uid ctx.gid = gid
) ctx.username = name
request.gid = gid ctx.is_admin = gid == 0
request.username = name
request.is_admin = gid == 0
else: else:
request.sid = None ctx.sid = None
request.uid = None ctx.uid = None
request.gid = None ctx.gid = None
request.username = None ctx.username = None
request.is_admin = False ctx.is_admin = False
@cli.command() @cli.command()
@ -71,17 +75,17 @@ def serve(hostname: str, port: int, config: Path):
create_j2_request_global() create_j2_request_global()
# Before request # Before request
app.before_request(open_db)
app.before_request(user_session) app.before_request(user_session)
# After request
app.after_request(commit_db)
# Blueprints # Blueprints
app.register_blueprint(user_ui.BLUEPRINT) app.register_blueprint(user_ui.BLUEPRINT)
app.register_blueprint(printer_ui.BLUEPRINT) app.register_blueprint(printer_ui.BLUEPRINT)
app.register_blueprint(api.BLUEPRINT) app.register_blueprint(api.BLUEPRINT)
# Shove our middleware in there
app.wsgi_app = custom_ctx(app, app.wsgi_app)
# And run the blame thing
app.run(host=hostname, port=port) app.run(host=hostname, port=port)

View file

@ -20,6 +20,7 @@ from flask import (
flash, flash,
) )
from tentacles.globals import ctx
from .util import is_logged_in, requires_admin from .util import is_logged_in, requires_admin
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,7 +36,7 @@ def printers():
@requires_admin @requires_admin
@BLUEPRINT.route("/printers/add", methods=["get", "post"]) @BLUEPRINT.route("/printers/add", methods=["get", "post"])
def add_printer(): def add_printer():
if not is_logged_in(request): if not is_logged_in():
return redirect("/") return redirect("/")
elif request.method == "POST": elif request.method == "POST":
@ -43,7 +44,7 @@ def add_printer():
assert request.form["name"] assert request.form["name"]
assert request.form["url"] assert request.form["url"]
assert request.form["api_key"] assert request.form["api_key"]
request.db.try_create_printer( ctx.db.try_create_printer(
request.form["name"], request.form["name"],
request.form["url"], request.form["url"],
request.form["api_key"], request.form["api_key"],

View file

@ -8,6 +8,8 @@ from datetime import timedelta, datetime
from importlib.resources import files from importlib.resources import files
import re import re
from tentacles.globals import ctx
from click import group from click import group
from flask import ( from flask import (
Blueprint, Blueprint,
@ -39,11 +41,11 @@ def root():
@BLUEPRINT.route("/user/login", methods=["GET", "POST"]) @BLUEPRINT.route("/user/login", methods=["GET", "POST"])
def login(): def login():
if is_logged_in(request): if is_logged_in():
return redirect("/") return redirect("/")
elif request.method == "POST": elif request.method == "POST":
if sid := request.db.try_login( if sid := ctx.db.try_login(
username := request.form["username"], username := request.form["username"],
salt(request.form["password"]), salt(request.form["password"]),
timedelta(days=1), timedelta(days=1),
@ -63,7 +65,7 @@ def login():
@BLUEPRINT.route("/user/register", methods=["GET", "POST"]) @BLUEPRINT.route("/user/register", methods=["GET", "POST"])
def register(): def register():
if is_logged_in(request): if is_logged_in():
return redirect("/") return redirect("/")
elif request.method == "POST": elif request.method == "POST":
@ -83,7 +85,7 @@ def register():
break break
if res := request.db.try_create_user( if res := ctx.db.try_create_user(
username, email, salt(request.form["password"]), group_id, status_id username, email, salt(request.form["password"]), group_id, status_id
): ):
id, status = res id, status = res
@ -107,7 +109,7 @@ def register():
@BLUEPRINT.route("/user/logout") @BLUEPRINT.route("/user/logout")
def logout(): def logout():
# Invalidate the user's authorization # Invalidate the user's authorization
request.db.delete_key(request.sid) ctx.db.delete_key(ctx.uid, ctx.sid)
resp = redirect("/") resp = redirect("/")
resp.set_cookie("sid", "") resp.set_cookie("sid", "")
return resp return resp
@ -115,7 +117,7 @@ def logout():
@BLUEPRINT.route("/user", methods=["GET", "POST"]) @BLUEPRINT.route("/user", methods=["GET", "POST"])
def settings(): def settings():
if not is_logged_in(request): if not is_logged_in():
return redirect("/") return redirect("/")
elif request.method == "POST": elif request.method == "POST":
@ -129,11 +131,11 @@ def settings():
flash("Bad request", category="error") flash("Bad request", category="error")
return render_template("user.html.j2"), 400 return render_template("user.html.j2"), 400
request.db.create_key(request.sid, ttl, request.form.get("name")) ctx.db.create_key(ctx.sid, ttl, request.form.get("name"))
flash("Key created", category="success") flash("Key created", category="success")
elif request.form["action"] == "revoke": elif request.form["action"] == "revoke":
request.db.delete_key(request.uid, request.form.get("id")) ctx.db.delete_key(ctx.uid, request.form.get("id"))
flash("Key revoked", category="success") flash("Key revoked", category="success")
else: else:

View file

@ -18,11 +18,13 @@ from flask import (
flash, flash,
) )
from tentacles.globals import ctx
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def is_logged_in(request: Request) -> bool: def is_logged_in() -> bool:
return request.uid is not None return ctx.uid is not None
def salt(password: str) -> str: def salt(password: str) -> str:
@ -31,7 +33,7 @@ def salt(password: str) -> str:
def requires_admin(f): def requires_admin(f):
def _helper(*args, **kwargs): def _helper(*args, **kwargs):
if not request.is_admin: if not ctx.is_admin:
flash("Sorry, admins only", category="error") flash("Sorry, admins only", category="error")
redirect("/") redirect("/")
else: else:

View file

@ -0,0 +1,22 @@
#!/usr/bin/env python3
from contextvars import ContextVar
from typing import Optional
from attrs import define
from tentacles.store import Store
from werkzeug.local import LocalProxy
@define
class Ctx:
db: Store
uid: int = None
gid: int = None
sid: str = None
username: str = None
is_admin: bool = None
_ctx = ContextVar("tentacles.ctx")
ctx: Ctx = LocalProxy(_ctx)

View file

@ -25,11 +25,11 @@
</label> </label>
<ul class="menu"> <ul class="menu">
{% if not request.uid %} {% if not ctx.uid %}
<li><a href="/user/login">Log in</a></li> <li><a href="/user/login">Log in</a></li>
<li><a href="/user/register">Register</a></li> <li><a href="/user/register">Register</a></li>
{% else %} {% else %}
{% if request.is_admin %} {% if ctx.is_admin %}
<li><a href="/printers">Printers</a></li> <li><a href="/printers">Printers</a></li>
{% endif %} {% endif %}
<li><a href="/user">Settings</a></li> <li><a href="/user">Settings</a></li>

View file

@ -2,7 +2,7 @@
{% block content %} {% block content %}
<div class="panel queue"> <div class="panel queue">
<h2>Queue</h2> <h2>Queue</h2>
{% with jobs = request.db.list_jobs(uid=request.uid) %} {% with jobs = ctx.db.list_jobs(uid=request.uid) %}
{% if jobs %} {% if jobs %}
<ul> <ul>
{% for job in jobs %} {% for job in jobs %}
@ -18,7 +18,7 @@
{% if request.uid %} {% if request.uid %}
<div class="panel files"> <div class="panel files">
<h2>Files</h2> <h2>Files</h2>
{% with files = request.db.list_files(uid=request.uid) %} {% with files = ctx.db.list_files(uid=request.uid) %}
{% if files %} {% if files %}
<ul> <ul>
{% for file in files %} {% for file in files %}

View file

@ -2,7 +2,7 @@
{% block content %} {% block content %}
<div class="panel printers"> <div class="panel printers">
<h2>Printers</h2> <h2>Printers</h2>
{% with printers = request.db.list_printers() %} {% with printers = ctx.db.list_printers() %}
{% if printers %} {% if printers %}
<ul> <ul>
{% for printer in printers %} {% for printer in printers %}

View file

@ -3,7 +3,7 @@
<h1>User settings</h1> <h1>User settings</h1>
<div class=""> <div class="">
<h2>API keys</h2> <h2>API keys</h2>
{% with keys = request.db.list_keys(request.uid) %} {% with keys = ctx.db.list_keys(ctx.uid) %}
<ul> <ul>
{% for id, name, exp in keys if name != 'web session' %} {% for id, name, exp in keys if name != 'web session' %}
<li> <li>