Start trimming some of the __main__ crud

This commit is contained in:
Reid 'arrdem' McKenzie 2022-11-20 22:34:07 -07:00
parent ef339ef916
commit 53bf982217
5 changed files with 102 additions and 120 deletions

View file

@ -10,14 +10,13 @@ USER app
WORKDIR /app
ENV PATH="/app/.local/bin:${PATH}"
ENV PYTHONPATH="/app:${PYTHONPATH}"
ENV DOCKER_RUNNING=true # Trivialize detecting dockerization
# Trivialize detecting dockerization
ENV DOCKER_RUNNING=true
### App specific crap
# Deps vary least so do them first
RUN pip install --user install aiohttp aiohttp_basicauth async_lru cachetools click pycryptodome pyyaml retry
COPY --chown=app:app src/python .
COPY --chown=app:app relay.yaml .
COPY --chown=app:app relay.jsonld .
COPY --chown=app:app src/python relay.yaml relay.jsonld .
CMD ["python3", "relay/__main__.py", "-c", "relay.yaml"]
CMD ["python3", "relay/__main__.py", "-c", "relay.yaml", "run"]

View file

@ -4,4 +4,4 @@ cd "$(realpath $(dirname $0))"
bazel build ...
exec ../../bazel-bin/projects/activitypub_relay/activitypub_relay -c $(realpath ./relay.yaml)
exec ../../bazel-bin/projects/activitypub_relay/activitypub_relay -c $(realpath ./relay.yaml) run

View file

@ -1,3 +1 @@
__version__ = "0.2.2"
from . import logger

View file

@ -2,6 +2,8 @@ import Crypto
import asyncio
import click
import platform
import logging
import os
from urllib.parse import urlparse
@ -10,9 +12,6 @@ from relay.application import Application, request_id_middleware
from relay.config import relay_software_names
app = None
@click.group(
"cli", context_settings={"show_default": True}, invoke_without_command=True
)
@ -20,40 +19,47 @@ app = None
@click.version_option(version=__version__, prog_name="ActivityRelay")
@click.pass_context
def cli(ctx, config):
global app
app = Application(config, middlewares=[request_id_middleware])
ctx.obj = Application(config, middlewares=[request_id_middleware])
if not ctx.invoked_subcommand:
if app.config.host.endswith("example.com"):
relay_setup.callback()
level = {
"debug": logging.DEBUG,
"info": logging.INFO,
"error": logging.ERROR,
"critical": logging.CRITICAL,
"fatal": logging.FATAL,
}.get(os.getenv("LOG_LEVE", "INFO").lower(), logging.INFO)
else:
relay_run.callback()
logging.basicConfig(
level=level,
format="[%(asctime)s] %(levelname)s: %(message)s",
)
@cli.group("inbox")
@click.pass_context
def cli_inbox(ctx):
@click.pass_obj
def cli_inbox(ctx: Application):
"Manage the inboxes in the database"
pass
@cli_inbox.command("list")
def cli_inbox_list():
@click.pass_obj
def cli_inbox_list(obj: Application):
"List the connected instances or relays"
click.echo("Connected to the following instances or relays:")
for inbox in app.database.inboxes:
for inbox in obj.database.inboxes:
click.echo(f"- {inbox}")
@cli_inbox.command("follow")
@click.argument("actor")
def cli_inbox_follow(actor):
@click.pass_obj
def cli_inbox_follow(obj: Application, actor):
"Follow an actor (Relay must be running)"
if app.config.is_banned(actor):
if obj.config.is_banned(actor):
return click.echo(f"Error: Refusing to follow banned actor: {actor}")
if not actor.startswith("http"):
@ -64,14 +70,14 @@ def cli_inbox_follow(actor):
domain = urlparse(actor).hostname
try:
inbox_data = app.database["relay-list"][domain]
inbox_data = obj.database["relay-list"][domain]
inbox = inbox_data["inbox"]
except KeyError:
actor_data = asyncio.run(misc.request(actor))
inbox = actor_data.shared_inbox
message = misc.Message.new_follow(host=app.config.host, actor=actor.id)
message = misc.Message.new_follow(host=obj.config.host, actor=actor.id)
asyncio.run(misc.request(inbox, message))
click.echo(f"Sent follow message to actor: {actor}")
@ -79,7 +85,8 @@ def cli_inbox_follow(actor):
@cli_inbox.command("unfollow")
@click.argument("actor")
def cli_inbox_unfollow(actor):
@click.pass_obj
def cli_inbox_unfollow(obj: Application, actor):
"Unfollow an actor (Relay must be running)"
if not actor.startswith("http"):
@ -90,22 +97,22 @@ def cli_inbox_unfollow(actor):
domain = urlparse(actor).hostname
try:
inbox_data = app.database["relay-list"][domain]
inbox_data = obj.database["relay-list"][domain]
inbox = inbox_data["inbox"]
message = misc.Message.new_unfollow(
host=app.config.host, actor=actor, follow=inbox_data["followid"]
host=obj.config.host, actor=actor, follow=inbox_data["followid"]
)
except KeyError:
actor_data = asyncio.run(misc.request(actor))
inbox = actor_data.shared_inbox
message = misc.Message.new_unfollow(
host=app.config.host,
host=obj.config.host,
actor=actor,
follow={
"type": "Follow",
"object": actor,
"actor": f"https://{app.config.host}/actor",
"actor": f"https://{obj.config.host}/actor",
},
)
@ -115,17 +122,18 @@ def cli_inbox_unfollow(actor):
@cli_inbox.command("add")
@click.argument("inbox")
def cli_inbox_add(inbox):
@click.pass_obj
def cli_inbox_add(obj: Application, inbox):
"Add an inbox to the database"
if not inbox.startswith("http"):
inbox = f"https://{inbox}/inbox"
if app.config.is_banned(inbox):
if obj.config.is_banned(inbox):
return click.echo(f"Error: Refusing to add banned inbox: {inbox}")
if app.database.add_inbox(inbox):
app.database.save()
if obj.database.add_inbox(inbox):
obj.database.save()
return click.echo(f"Added inbox to the database: {inbox}")
click.echo(f"Error: Inbox already in database: {inbox}")
@ -133,18 +141,19 @@ def cli_inbox_add(inbox):
@cli_inbox.command("remove")
@click.argument("inbox")
def cli_inbox_remove(inbox):
@click.pass_obj
def cli_inbox_remove(obj: Application, inbox):
"Remove an inbox from the database"
try:
dbinbox = app.database.get_inbox(inbox, fail=True)
dbinbox = obj.database.get_inbox(inbox, fail=True)
except KeyError:
click.echo(f"Error: Inbox does not exist: {inbox}")
return
app.database.del_inbox(dbinbox["domain"])
app.database.save()
obj.database.del_inbox(dbinbox["domain"])
obj.database.save()
click.echo(f"Removed inbox from the database: {inbox}")
@ -156,28 +165,30 @@ def cli_instance():
@cli_instance.command("list")
def cli_instance_list():
@click.pass_obj
def cli_instance_list(obj: Application):
"List all banned instances"
click.echo("Banned instances or relays:")
for domain in app.config.blocked_instances:
for domain in obj.config.blocked_instances:
click.echo(f"- {domain}")
@cli_instance.command("ban")
@click.argument("target")
def cli_instance_ban(target):
@click.pass_obj
def cli_instance_ban(obj: Application, target):
"Ban an instance and remove the associated inbox if it exists"
if target.startswith("http"):
target = urlparse(target).hostname
if app.config.ban_instance(target):
app.config.save()
if obj.config.ban_instance(target):
obj.config.save()
if app.database.del_inbox(target):
app.database.save()
if obj.database.del_inbox(target):
obj.database.save()
click.echo(f"Banned instance: {target}")
return
@ -187,11 +198,12 @@ def cli_instance_ban(target):
@cli_instance.command("unban")
@click.argument("target")
def cli_instance_unban(target):
@click.pass_obj
def cli_instance_unban(obj: Application, target):
"Unban an instance"
if app.config.unban_instance(target):
app.config.save()
if obj.config.unban_instance(target):
obj.config.save()
click.echo(f"Unbanned instance: {target}")
return
@ -206,12 +218,13 @@ def cli_software():
@cli_software.command("list")
def cli_software_list():
@click.pass_obj
def cli_software_list(obj: Application):
"List all banned software"
click.echo("Banned software:")
for software in app.config.blocked_software:
for software in obj.config.blocked_software:
click.echo(f"- {software}")
@ -224,14 +237,15 @@ def cli_software_list():
help="Treat NAME like a domain and try to fet the software name from nodeinfo",
)
@click.argument("name")
def cli_software_ban(name, fetch_nodeinfo):
@click.pass_obj
def cli_software_ban(obj: Application, name, fetch_nodeinfo):
"Ban software. Use RELAYS for NAME to ban relays"
if name == "RELAYS":
for name in relay_software_names:
app.config.ban_software(name)
obj.config.ban_software(name)
app.config.save()
obj.config.save()
return click.echo("Banned all relay software")
if fetch_nodeinfo:
@ -243,7 +257,7 @@ def cli_software_ban(name, fetch_nodeinfo):
name = software
if config.ban_software(name):
app.config.save()
obj.config.save()
return click.echo(f"Banned software: {name}")
click.echo(f"Software already banned: {name}")
@ -258,12 +272,13 @@ def cli_software_ban(name, fetch_nodeinfo):
help="Treat NAME like a domain and try to fet the software name from nodeinfo",
)
@click.argument("name")
def cli_software_unban(name, fetch_nodeinfo):
@click.pass_obj
def cli_software_unban(obj: Application, name, fetch_nodeinfo):
"Ban software. Use RELAYS for NAME to unban relays"
if name == "RELAYS":
for name in relay_software_names:
app.config.unban_software(name)
obj.config.unban_software(name)
config.save()
return click.echo("Unbanned all relay software")
@ -276,8 +291,8 @@ def cli_software_unban(name, fetch_nodeinfo):
name = software
if app.config.unban_software(name):
app.config.save()
if obj.config.unban_software(name):
obj.config.save()
return click.echo(f"Unbanned software: {name}")
click.echo(f"Software wasn't banned: {name}")
@ -290,81 +305,86 @@ def cli_whitelist():
@cli_whitelist.command("list")
def cli_whitelist_list():
@click.pass_obj
def cli_whitelist_list(obj: Application):
click.echo("Current whitelisted domains")
for domain in app.config.whitelist:
for domain in obj.config.whitelist:
click.echo(f"- {domain}")
@cli_whitelist.command("add")
@click.argument("instance")
def cli_whitelist_add(instance):
@click.pass_obj
def cli_whitelist_add(obj: Application, instance):
"Add an instance to the whitelist"
if not app.config.add_whitelist(instance):
if not obj.config.add_whitelist(instance):
return click.echo(f"Instance already in the whitelist: {instance}")
app.config.save()
obj.config.save()
click.echo(f"Instance added to the whitelist: {instance}")
@cli_whitelist.command("remove")
@click.argument("instance")
def cli_whitelist_remove(instance):
@click.pass_obj
def cli_whitelist_remove(obj: Application, instance):
"Remove an instance from the whitelist"
if not app.config.del_whitelist(instance):
if not obj.config.del_whitelist(instance):
return click.echo(f"Instance not in the whitelist: {instance}")
app.config.save()
obj.config.save()
if app.config.whitelist_enabled:
if app.database.del_inbox(inbox):
app.database.save()
if obj.config.whitelist_enabled:
if obj.database.del_inbox(inbox):
obj.database.save()
click.echo(f"Removed instance from the whitelist: {instance}")
@cli.command("setup")
def relay_setup():
@click.pass_obj
def relay_setup(obj: Application):
"Generate a new config"
while True:
app.config.host = click.prompt(
"What domain will the relay be hosted on?", default=app.config.host
obj.config.host = click.prompt(
"What domain will the relay be hosted on?", default=obj.config.host
)
if not app.config.host.endswith("example.com"):
if not obj.config.host.endswith("example.com"):
break
click.echo("The domain must not be example.com")
app.config.listen = click.prompt(
"Which address should the relay listen on?", default=app.config.listen
obj.config.listen = click.prompt(
"Which address should the relay listen on?", default=obj.config.listen
)
while True:
app.config.port = click.prompt(
obj.config.port = click.prompt(
"What TCP port should the relay listen on?",
default=app.config.port,
default=obj.config.port,
type=int,
)
break
app.config.save()
obj.config.save()
if not app["is_docker"] and click.confirm(
if not obj["is_docker"] and click.confirm(
"Relay all setup! Would you like to run it now?"
):
relay_run.callback()
@cli.command("run")
def relay_run():
@click.pass_obj
def relay_run(obj: Application):
"Run the relay"
if app.config.host.endswith("example.com"):
if obj.config.host.endswith("example.com"):
return click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
@ -385,12 +405,12 @@ def relay_run():
)
return click.echo(pip_command)
if not misc.check_open_port(app.config.listen, app.config.port):
if not misc.check_open_port(obj.config.listen, obj.config.port):
return click.echo(
f"Error: A server is already running on port {app.config.port}"
f"Error: A server is already running on port {obj.config.port}"
)
app.run()
obj.run()
if __name__ == "__main__":

View file

@ -1,35 +0,0 @@
import logging
import os
from pathlib import Path
## Get log level and file from environment if possible
env_log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
try:
env_log_file = Path(os.environ.get("LOG_FILE")).expanduser().resolve()
except TypeError:
env_log_file = None
## Make sure the level from the environment is valid
try:
log_level = getattr(logging, env_log_level)
except AttributeError:
log_level = logging.INFO
## Set logging config
handlers = [logging.StreamHandler()]
if env_log_file:
handlers.append(logging.FileHandler(env_log_file))
logging.basicConfig(
level=log_level,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=handlers,
)