DATE_FORMAT = "%Y-%m-%dT%H:%M:%S+00:00"
from flask import Flask, Response, request, url_for, render_template, redirect
from flask import Blueprint, current_app, g, abort, session as flask_session
from enum import Enum
from srht.config import cfg, cfgi, cfgkeys, config, get_origin, get_global_domain
from srht.crypto import fernet
from srht.email import mail_exception
from srht.database import db
from srht.markdown import markdown
from srht.validation import Validation
from datetime import datetime, timedelta
from jinja2 import Markup, FileSystemLoader, ChoiceLoader, contextfunction
from jinja2 import escape
from prometheus_client import Counter, Summary, make_wsgi_app
from timeit import default_timer
from urllib.parse import urlparse, quote_plus
from werkzeug.local import LocalProxy
from werkzeug.routing import UnicodeConverter
from werkzeug.urls import url_quote
try:
    from werkzeug.wsgi import DispatcherMiddleware
except ImportError:
    from werkzeug.middleware.dispatcher import DispatcherMiddleware
import binascii
import hashlib
import hmac
import inspect
import humanize
import decimal
import bleach
import json
import locale
import sys
import os

# Encrypted user-info cookie; read in get_session_cookie and written/cleared
# in SrhtFlask.make_response on the global domain.
_LOGIN_COOKIE = "sr.ht.unified-login.v1"


class NamespacedSession:
    """Dict-like view of the Flask session whose keys are prefixed with
    ``current_app.site``, so entries written by different sites don't collide.
    """

    def _key(self, key):
        """Return the namespaced session key for *key*."""
        return f"{current_app.site}:{key}"

    def __getitem__(self, key):
        return flask_session[self._key(key)]

    def __setitem__(self, key, value):
        flask_session[self._key(key)] = value

    def __delitem__(self, key):
        del flask_session[self._key(key)]

    def __contains__(self, key):
        # Without this, ``key in session`` falls back to integer __getitem__
        # probing and raises KeyError.
        return self._key(key) in flask_session

    def get(self, key, *args, **kwargs):
        return flask_session.get(self._key(key), *args, **kwargs)

    def set(self, key, value):
        # The Flask session is a dict subclass with no ``set`` method; the
        # previous delegation to flask_session.set always raised.
        self[key] = value

    def setdefault(self, key, *args, **kwargs):
        return flask_session.setdefault(self._key(key), *args, **kwargs)

    def pop(self, key, *args, **kwargs):
        return flask_session.pop(self._key(key), *args, **kwargs)


_session = NamespacedSession()
session = LocalProxy(lambda: _session)

# Make humanize compute relative times against naive UTC "now".
humanize.time._now = lambda: datetime.utcnow()

try:
    locale.setlocale(locale.LC_ALL, 'en_US')
except locale.Error:
    # en_US may not be installed on this host; keep the default locale.
    pass


def date_handler(obj):
    """``json.dumps`` default hook: datetimes -> DATE_FORMAT strings,
    Decimals -> 2-decimal strings, Enums -> member name, else unchanged."""
    if hasattr(obj, 'strftime'):
        return obj.strftime(DATE_FORMAT)
    if isinstance(obj, decimal.Decimal):
        return "{:.2f}".format(obj)
    if isinstance(obj, Enum):
        return obj.name
    return obj


def _strip_ago(text):
    """Remove a trailing " ago" suffix (``str.rstrip`` strips characters,
    not a suffix)."""
    suffix = " ago"
    return text[:-len(suffix)] if text.endswith(suffix) else text


def datef(d):
    """Jinja ``date`` filter: render a datetime or timedelta; 'Never' if falsy.

    NOTE(review): the format template is just '{}', so only the first
    argument is rendered and the humanized text is discarded. It looks like
    an HTML wrapper was stripped from these templates; confirm.
    """
    if not d:
        return 'Never'
    if isinstance(d, timedelta):
        # total_seconds() includes the days component; .seconds does not.
        return Markup('{}'.format(
            f'{int(d.total_seconds())} seconds',
            _strip_ago(humanize.naturaltime(d))))
    return Markup('{}'.format(
        d.strftime('%Y-%m-%d %H:%M:%S UTC'),
        humanize.naturaltime(d)))


# Icon name -> SVG source, read once from <mod_path>/static/icons/.
icon_cache = {}


def icon(i, cls=""):
    """Return the SVG for icon *i* as Markup.

    The first time an icon is loaded from disk within a request context that
    has not yet emitted it, ``fa_license`` is appended once (tracked on ``g``).
    NOTE(review): ``cls`` is currently unused; the wrapper markup appears to
    have been stripped. Confirm the intended output.
    """
    if i in icon_cache:
        svg = icon_cache[i]
        return Markup(f'{svg}')
    fa_license = """"""
    path = os.path.join(current_app.mod_path, 'static', 'icons', i + '.svg')
    with open(path) as f:
        svg = f.read()
    icon_cache[i] = svg
    if g and "fa_license" not in g:
        svg += fa_license
        g.fa_license = True
    return Markup(f'{svg}')


@contextfunction
def pagination(context):
    """Render pagination.html with the calling template's context."""
    template = context.environment.get_template("pagination.html")
    return Markup(template.render(**context.parent))


def csrf_token():
    """Return a hidden form field carrying the session's CSRF token.

    The token is created on first use. The field name must match what
    ``_csrf_check`` reads from ``request.form``.
    """
    if '_csrf_token_v2' not in flask_session:
        flask_session['_csrf_token_v2'] = \
            binascii.hexlify(os.urandom(64)).decode()
    return Markup('<input type="hidden" name="_csrf_token" value="{}" />'
                  .format(escape(flask_session['_csrf_token_v2'])))


# Blueprint names and "module.function" view names exempt from the CSRF check.
_csrf_bypass_views = set()
_csrf_bypass_blueprints = set()


def csrf_bypass(f):
    """Decorator/marker exempting a view function or Blueprint from CSRF.

    Blueprints are recorded by name, because ``_csrf_check`` compares against
    ``request.blueprint``, which is the blueprint's name rather than the object.
    """
    if isinstance(f, Blueprint):
        _csrf_bypass_blueprints.add(f.name)
    else:
        _csrf_bypass_views.add('.'.join((f.__module__, f.__name__)))
    return f


def paginate_query(query, results_per_page=15):
    """Apply ``?page=N`` (1-based) pagination to *query* and execute it.

    Returns ``(results, info)`` where *info* has 1-based ``page``,
    ``total_pages`` and ``total_results``. A non-numeric page falls back to
    the first page; page numbers below 1 abort with 400.
    """
    raw_page = request.args.get("page")
    total_results = query.count()
    # Ceiling division: a trailing partial page still counts as a page.
    total_pages = (total_results + results_per_page - 1) // results_per_page
    page = 0
    if raw_page is not None:
        try:
            page = int(raw_page) - 1  # convert to 0-based
        except ValueError:
            page = 0
    if page < 0:
        abort(400)
    if page > 0:
        query = query.offset(page * results_per_page)
    query = query.limit(results_per_page).all()
    return query, {
        "total_pages": total_pages,
        "page": page + 1,
        "total_results": total_results
    }


class ModifiedUnicodeConverter(UnicodeConverter):
    """Added ~ and ^ to safe URL characters, otherwise no changes."""

    def to_url(self, value):
        return url_quote(value, charset=self.map.charset, safe='/:~^')


class SrhtFlask(Flask):
    """Flask application preconfigured for sr.ht services.

    Adds Prometheus metrics, CSRF checking, template helpers and loaders,
    JSON responses for dict/list return values, and the unified-login cookie.
    """

    def __init__(self, site, name, oauth_service=None, oauth_provider=None,
                 *args, **kwargs):
        super().__init__(name, *args, **kwargs)

        self.site = site

        self.wsgi_app = DispatcherMiddleware(self.wsgi_app, {
            "/metrics": make_wsgi_app(),
        })
        # Namespace object exposing each metric as an attribute named after it.
        self.metrics = type("metrics", tuple(), {
            m.describe()[0].name: m
            for m in [
                Counter("http_requests", "Number of HTTP requests", [
                    "method",
                    "route",
                    "status",
                ]),
                Summary("request_time", "Duration of HTTP requests", [
                    "method",
                    "route",
                ]),
            ]
        })

        self.url_map.converters['default'] = ModifiedUnicodeConverter
        self.url_map.converters['string'] = ModifiedUnicodeConverter

        # Template search order: ./templates, /etc/<site>/templates,
        # <package>/templates, then this module's own templates directory.
        choices = [FileSystemLoader("templates")]

        mod = __import__(name)
        if hasattr(mod, "__path__"):
            path = list(mod.__path__)[0]
            self.mod_path = path
            choices.append(FileSystemLoader(
                os.path.join("/etc", self.site, "templates")))
            choices.append(FileSystemLoader(os.path.join(path, "templates")))
            choices.append(FileSystemLoader(os.path.join(
                os.path.dirname(__file__),
                "templates"
            )))

        self.jinja_env.cache = None
        self.jinja_env.filters['date'] = datef
        self.jinja_env.globals['pagination'] = pagination
        self.jinja_env.globals['icon'] = icon
        self.jinja_env.globals['csrf_token'] = csrf_token
        self.jinja_loader = ChoiceLoader(choices)

        self.secret_key = cfg("sr.ht", "service-key", default=
                cfg("sr.ht", "secret-key", default=None))
        if self.secret_key is None:
            raise Exception("[sr.ht]service-key missing from config")

        self.oauth_service = oauth_service
        self.oauth_provider = oauth_provider

        if self.oauth_service:
            from srht.oauth import oauth_blueprint
            self.register_blueprint(oauth_blueprint)

            from srht.oauth.scope import set_client_id
            set_client_id(self.oauth_service.client_id)

        # TODO: Remove
        self.no_csrf_prefixes = ['/api']

        @self.before_request
        def _csrf_check():
            """Reject POSTs without a valid CSRF token (403) unless bypassed."""
            if request.method != 'POST':
                return
            if request.blueprint in _csrf_bypass_blueprints:
                return
            view = self.view_functions.get(request.endpoint)
            if not view:
                return
            view = "{0}.{1}".format(view.__module__, view.__name__)
            if view in _csrf_bypass_views:
                return
            # TODO: Remove
            for prefix in self.no_csrf_prefixes:
                if request.path.startswith(prefix):
                    return
            token = flask_session.get('_csrf_token_v2', None)
            submitted = request.form.get('_csrf_token')
            # Constant-time comparison; bytes so non-ASCII input cannot make
            # compare_digest raise TypeError.
            if not token or submitted is None or not hmac.compare_digest(
                    token.encode(), submitted.encode()):
                abort(403)

        @self.teardown_appcontext
        def expire_db(err):
            db.session.expire_all()

        @self.errorhandler(500)
        def handle_500(e):
            if self.debug:
                raise e
            # Roll back the failed transaction and report the error. Any
            # exception raised here propagates unchanged.
            if hasattr(db, 'session'):
                db.session.rollback()
                db.session.close()
            mail_exception(e)
            return render_template("internal_error.html"), 500

        @self.errorhandler(404)
        def handle_404(e):
            if request.path.startswith("/api"):
                return {
                    "errors": [
                        { "reason": "404 not found" }
                    ]
                }, 404
            return render_template("not_found.html"), 404

        @self.context_processor
        def inject():
            """Names made available to every template."""
            from srht.oauth import current_user
            user_class = (current_user._get_current_object().__class__
                    if current_user else None)
            root = get_origin(self.site, external=True)
            ctx = {
                'root': root,
                'domain': urlparse(root).netloc,
                'app': self,
                'len': len,
                'any': any,
                'str': str,
                'request': request,
                'url_for': url_for,
                'cfg': cfg,
                'cfgi': cfgi,
                'cfgkeys': cfgkeys,
                'get_origin': get_origin,
                'valid': Validation(request),
                'site': site,
                'site_name': cfg("sr.ht", "site-name", default=None),
                'environment': cfg("sr.ht", "environment",
                    default="production"),
                'network': self.get_network(),
                'current_user': (user_class.query
                    .filter(user_class.id == current_user.id)
                ).one_or_none() if current_user else None,
                'static_resource': self.static_resource,
            }
            if self.oauth_service:
                ctx.update({
                    "oauth_url": self.oauth_service.oauth_url(
                        request.full_path),
                    "logout_url": "{}/logout?return_to={}{}".format(
                        get_origin("meta.sr.ht", external=True),
                        root, quote_plus(request.full_path)),
                })
            return ctx

        @self.teardown_appcontext
        def shutdown_session(err):
            db.session.remove()

        @self.template_filter()
        def md(text):
            return markdown(text)

        @self.template_filter()
        def extended_md(text, baselevel=1):
            return markdown(text, ["h1", "h2", "h3", "h4", "h5"], baselevel)

        @self.before_request
        def get_session_cookie():
            # TODO: We could probably speed things up by skipping the
            # round-trip until we actually need any user info which isn't
            # present in the user's info cookie
            cookie = request.cookies.get(_LOGIN_COOKIE)
            # Without an OAuth service there is no way to resolve the user,
            # and dereferencing None would fail every request with the cookie.
            if not cookie or not self.oauth_service:
                return
            user_info = json.loads(fernet.decrypt(cookie.encode()).decode())
            g.current_user = self.oauth_service.get_user(user_info)

        @self.before_request
        def begin_track_request():
            request._srht_start_time = default_timer()

        @self.after_request
        def track_request(resp):
            if not hasattr(request, "_srht_start_time"):
                return resp
            self.metrics.http_requests.labels(
                method=request.method,
                route=request.endpoint,
                status=resp.status_code,
            ).inc()
            self.metrics.request_time.labels(
                method=request.method,
                route=request.endpoint,
            ).observe(max(default_timer() - request._srht_start_time, 0))
            return resp

    def make_response(self, rv):
        """Convert dict/list return values (bare or as the first element of a
        tuple) to JSON responses, then sync the unified-login cookie when
        ``g.set_current_user`` is set."""
        def jsonify_wrap(obj):
            jsonification = json.dumps(obj, default=date_handler)
            return Response(jsonification, mimetype='application/json')

        if isinstance(rv, tuple) and isinstance(rv[0], (dict, list)):
            # Keep every trailing element (status and/or headers).
            response = (jsonify_wrap(rv[0]),) + rv[1:]
        elif isinstance(rv, (dict, list)):
            response = jsonify_wrap(rv)
        else:
            response = rv

        response = super(SrhtFlask, self).make_response(response)
        global_domain = get_global_domain(self.site)
        if "set_current_user" in g and g.set_current_user:
            if not g.current_user:
                # Clear user info cookie
                response.set_cookie(_LOGIN_COOKIE, "",
                        domain=global_domain, max_age=0)
            else:
                # Set user info cookie
                user_info = g.current_user.to_dict(first_party=True)
                user_info = json.dumps(user_info)
                response.set_cookie(_LOGIN_COOKIE,
                        fernet.encrypt(user_info.encode()).decode(),
                        domain=global_domain, max_age=60 * 60 * 24 * 365)
        return response

    def static_resource(self, path):
        """
        Given /example.ext, hashes the file and returns /example.hash.ext

        Results are cached per input path for the life of the app.
        NOTE(review): os.path.join discards mod_path when *path* is absolute
        (starts with "/"); confirm what callers pass.
        """
        if not hasattr(self, "static_cache"):
            self.static_cache = dict()
        cached = self.static_cache.get(path)
        if cached is not None:
            return cached
        sha256 = hashlib.sha256()
        with open(os.path.join(self.mod_path, path), "rb") as f:
            sha256.update(f.read())
        # Use separate names so the cache is keyed by the requested path.
        stem, ext = os.path.splitext(path)
        hashed = f"{stem}.{sha256.hexdigest()[:8]}{ext}"
        self.static_cache[path] = hashed
        return hashed

    def get_network(self):
        """Config section names ending in ".sr.ht" (the service network)."""
        return [s for s in config if s.endswith(".sr.ht")]