Create a ctx global to replace request abuse
This commit is contained in:
parent
bedad7d86b
commit
409d04a648
9 changed files with 75 additions and 44 deletions
|
@ -12,6 +12,7 @@ from datetime import datetime
|
|||
|
||||
from tentacles.blueprints import user_ui, printer_ui, api
|
||||
from tentacles.store import Store
|
||||
from tentacles.globals import _ctx, Ctx, ctx
|
||||
|
||||
|
||||
@click.group()
|
||||
|
@ -19,39 +20,42 @@ def cli():
|
|||
pass
|
||||
|
||||
|
||||
def open_db():
|
||||
request.db = Store(current_app.config.get("db", {}).get("uri"))
|
||||
request.db.connect()
|
||||
def custom_ctx(app, wsgi_app):
|
||||
def helper(environ, start_response):
|
||||
store = Store(app.config.get("db", {}).get("uri"))
|
||||
store.connect()
|
||||
token = _ctx.set(Ctx(store))
|
||||
try:
|
||||
return wsgi_app(environ, start_response)
|
||||
finally:
|
||||
_ctx.reset(token)
|
||||
store.close()
|
||||
|
||||
|
||||
def commit_db(resp):
|
||||
request.db.close()
|
||||
return resp
|
||||
return helper
|
||||
|
||||
|
||||
def create_j2_request_global():
|
||||
current_app.jinja_env.globals["ctx"] = ctx
|
||||
current_app.jinja_env.globals["request"] = request
|
||||
current_app.jinja_env.globals["datetime"] = datetime
|
||||
|
||||
|
||||
def user_session():
|
||||
if (session_id := request.cookies.get("sid", "")) and (
|
||||
uid := request.db.try_key(session_id)
|
||||
uid := ctx.db.try_key(session_id)
|
||||
):
|
||||
request.sid = session_id
|
||||
request.uid = uid
|
||||
_id, gid, name, _email, _hash, _status, _verification = request.db.fetch_user(
|
||||
uid
|
||||
)
|
||||
request.gid = gid
|
||||
request.username = name
|
||||
request.is_admin = gid == 0
|
||||
ctx.sid = session_id
|
||||
ctx.uid = uid
|
||||
_id, gid, name, _email, _hash, _status, _verification = ctx.db.fetch_user(uid)
|
||||
ctx.gid = gid
|
||||
ctx.username = name
|
||||
ctx.is_admin = gid == 0
|
||||
else:
|
||||
request.sid = None
|
||||
request.uid = None
|
||||
request.gid = None
|
||||
request.username = None
|
||||
request.is_admin = False
|
||||
ctx.sid = None
|
||||
ctx.uid = None
|
||||
ctx.gid = None
|
||||
ctx.username = None
|
||||
ctx.is_admin = False
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
@ -71,17 +75,17 @@ def serve(hostname: str, port: int, config: Path):
|
|||
create_j2_request_global()
|
||||
|
||||
# Before request
|
||||
app.before_request(open_db)
|
||||
app.before_request(user_session)
|
||||
|
||||
# After request
|
||||
app.after_request(commit_db)
|
||||
|
||||
# Blueprints
|
||||
app.register_blueprint(user_ui.BLUEPRINT)
|
||||
app.register_blueprint(printer_ui.BLUEPRINT)
|
||||
app.register_blueprint(api.BLUEPRINT)
|
||||
|
||||
# Shove our middleware in there
|
||||
app.wsgi_app = custom_ctx(app, app.wsgi_app)
|
||||
|
||||
# And run the blame thing
|
||||
app.run(host=hostname, port=port)
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from flask import (
|
|||
flash,
|
||||
)
|
||||
|
||||
from tentacles.globals import ctx
|
||||
from .util import is_logged_in, requires_admin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -35,7 +36,7 @@ def printers():
|
|||
@requires_admin
|
||||
@BLUEPRINT.route("/printers/add", methods=["get", "post"])
|
||||
def add_printer():
|
||||
if not is_logged_in(request):
|
||||
if not is_logged_in():
|
||||
return redirect("/")
|
||||
|
||||
elif request.method == "POST":
|
||||
|
@ -43,7 +44,7 @@ def add_printer():
|
|||
assert request.form["name"]
|
||||
assert request.form["url"]
|
||||
assert request.form["api_key"]
|
||||
request.db.try_create_printer(
|
||||
ctx.db.try_create_printer(
|
||||
request.form["name"],
|
||||
request.form["url"],
|
||||
request.form["api_key"],
|
||||
|
|
|
@ -8,6 +8,8 @@ from datetime import timedelta, datetime
|
|||
from importlib.resources import files
|
||||
import re
|
||||
|
||||
from tentacles.globals import ctx
|
||||
|
||||
from click import group
|
||||
from flask import (
|
||||
Blueprint,
|
||||
|
@ -39,11 +41,11 @@ def root():
|
|||
|
||||
@BLUEPRINT.route("/user/login", methods=["GET", "POST"])
|
||||
def login():
|
||||
if is_logged_in(request):
|
||||
if is_logged_in():
|
||||
return redirect("/")
|
||||
|
||||
elif request.method == "POST":
|
||||
if sid := request.db.try_login(
|
||||
if sid := ctx.db.try_login(
|
||||
username := request.form["username"],
|
||||
salt(request.form["password"]),
|
||||
timedelta(days=1),
|
||||
|
@ -63,7 +65,7 @@ def login():
|
|||
|
||||
@BLUEPRINT.route("/user/register", methods=["GET", "POST"])
|
||||
def register():
|
||||
if is_logged_in(request):
|
||||
if is_logged_in():
|
||||
return redirect("/")
|
||||
|
||||
elif request.method == "POST":
|
||||
|
@ -83,7 +85,7 @@ def register():
|
|||
|
||||
break
|
||||
|
||||
if res := request.db.try_create_user(
|
||||
if res := ctx.db.try_create_user(
|
||||
username, email, salt(request.form["password"]), group_id, status_id
|
||||
):
|
||||
id, status = res
|
||||
|
@ -107,7 +109,7 @@ def register():
|
|||
@BLUEPRINT.route("/user/logout")
|
||||
def logout():
|
||||
# Invalidate the user's authorization
|
||||
request.db.delete_key(request.sid)
|
||||
ctx.db.delete_key(ctx.uid, ctx.sid)
|
||||
resp = redirect("/")
|
||||
resp.set_cookie("sid", "")
|
||||
return resp
|
||||
|
@ -115,7 +117,7 @@ def logout():
|
|||
|
||||
@BLUEPRINT.route("/user", methods=["GET", "POST"])
|
||||
def settings():
|
||||
if not is_logged_in(request):
|
||||
if not is_logged_in():
|
||||
return redirect("/")
|
||||
|
||||
elif request.method == "POST":
|
||||
|
@ -129,11 +131,11 @@ def settings():
|
|||
flash("Bad request", category="error")
|
||||
return render_template("user.html.j2"), 400
|
||||
|
||||
request.db.create_key(request.sid, ttl, request.form.get("name"))
|
||||
ctx.db.create_key(ctx.sid, ttl, request.form.get("name"))
|
||||
flash("Key created", category="success")
|
||||
|
||||
elif request.form["action"] == "revoke":
|
||||
request.db.delete_key(request.uid, request.form.get("id"))
|
||||
ctx.db.delete_key(ctx.uid, request.form.get("id"))
|
||||
flash("Key revoked", category="success")
|
||||
|
||||
else:
|
||||
|
|
|
@ -18,11 +18,13 @@ from flask import (
|
|||
flash,
|
||||
)
|
||||
|
||||
from tentacles.globals import ctx
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_logged_in(request: Request) -> bool:
|
||||
return request.uid is not None
|
||||
def is_logged_in() -> bool:
|
||||
return ctx.uid is not None
|
||||
|
||||
|
||||
def salt(password: str) -> str:
|
||||
|
@ -31,7 +33,7 @@ def salt(password: str) -> str:
|
|||
|
||||
def requires_admin(f):
|
||||
def _helper(*args, **kwargs):
|
||||
if not request.is_admin:
|
||||
if not ctx.is_admin:
|
||||
flash("Sorry, admins only", category="error")
|
||||
redirect("/")
|
||||
else:
|
||||
|
|
22
projects/tentacles/src/python/tentacles/globals.py
Normal file
22
projects/tentacles/src/python/tentacles/globals.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from attrs import define
|
||||
from tentacles.store import Store
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
|
||||
@define
|
||||
class Ctx:
|
||||
db: Store
|
||||
uid: int = None
|
||||
gid: int = None
|
||||
sid: str = None
|
||||
username: str = None
|
||||
is_admin: bool = None
|
||||
|
||||
|
||||
_ctx = ContextVar("tentacles.ctx")
|
||||
ctx: Ctx = LocalProxy(_ctx)
|
|
@ -25,11 +25,11 @@
|
|||
</label>
|
||||
|
||||
<ul class="menu">
|
||||
{% if not request.uid %}
|
||||
{% if not ctx.uid %}
|
||||
<li><a href="/user/login">Log in</a></li>
|
||||
<li><a href="/user/register">Register</a></li>
|
||||
{% else %}
|
||||
{% if request.is_admin %}
|
||||
{% if ctx.is_admin %}
|
||||
<li><a href="/printers">Printers</a></li>
|
||||
{% endif %}
|
||||
<li><a href="/user">Settings</a></li>
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
{% block content %}
|
||||
<div class="panel queue">
|
||||
<h2>Queue</h2>
|
||||
{% with jobs = request.db.list_jobs(uid=request.uid) %}
|
||||
{% with jobs = ctx.db.list_jobs(uid=request.uid) %}
|
||||
{% if jobs %}
|
||||
<ul>
|
||||
{% for job in jobs %}
|
||||
|
@ -18,7 +18,7 @@
|
|||
{% if request.uid %}
|
||||
<div class="panel files">
|
||||
<h2>Files</h2>
|
||||
{% with files = request.db.list_files(uid=request.uid) %}
|
||||
{% with files = ctx.db.list_files(uid=request.uid) %}
|
||||
{% if files %}
|
||||
<ul>
|
||||
{% for file in files %}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
{% block content %}
|
||||
<div class="panel printers">
|
||||
<h2>Printers</h2>
|
||||
{% with printers = request.db.list_printers() %}
|
||||
{% with printers = ctx.db.list_printers() %}
|
||||
{% if printers %}
|
||||
<ul>
|
||||
{% for printer in printers %}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
<h1>User settings</h1>
|
||||
<div class="">
|
||||
<h2>API keys</h2>
|
||||
{% with keys = request.db.list_keys(request.uid) %}
|
||||
{% with keys = ctx.db.list_keys(ctx.uid) %}
|
||||
<ul>
|
||||
{% for id, name, exp in keys if name != 'web session' %}
|
||||
<li>
|
||||
|
|
Loading…
Reference in a new issue