chore: automatic commit 2025-04-30 12:48

This commit is contained in:
2025-04-30 12:48:06 +02:00
parent f69356473b
commit e4ab1e1bb5
5284 changed files with 868438 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
from .csrf import CSRFProtect
from .form import FlaskForm
from .form import Form
from .recaptcha import Recaptcha
from .recaptcha import RecaptchaField
from .recaptcha import RecaptchaWidget
__version__ = "1.2.2"
__all__ = [
"CSRFProtect",
"FlaskForm",
"Form",
"Recaptcha",
"RecaptchaField",
"RecaptchaWidget",
]

View File

@@ -0,0 +1,11 @@
import warnings
class FlaskWTFDeprecationWarning(DeprecationWarning):
pass
warnings.simplefilter("always", FlaskWTFDeprecationWarning)
warnings.filterwarnings(
"ignore", category=FlaskWTFDeprecationWarning, module="wtforms|flask_wtf"
)

View File

@@ -0,0 +1,329 @@
import hashlib
import hmac
import logging
import os
from urllib.parse import urlparse
from flask import Blueprint
from flask import current_app
from flask import g
from flask import request
from flask import session
from itsdangerous import BadData
from itsdangerous import SignatureExpired
from itsdangerous import URLSafeTimedSerializer
from werkzeug.exceptions import BadRequest
from wtforms import ValidationError
from wtforms.csrf.core import CSRF
__all__ = ("generate_csrf", "validate_csrf", "CSRFProtect")
logger = logging.getLogger(__name__)
def generate_csrf(secret_key=None, token_key=None):
"""Generate a CSRF token. The token is cached for a request, so multiple
calls to this function will generate the same token.
During testing, it might be useful to access the signed token in
``g.csrf_token`` and the raw token in ``session['csrf_token']``.
:param secret_key: Used to securely sign the token. Default is
``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
:param token_key: Key where token is stored in session for comparison.
Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
"""
secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
if field_name not in g:
s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")
if field_name not in session:
session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
try:
token = s.dumps(session[field_name])
except TypeError:
session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
token = s.dumps(session[field_name])
setattr(g, field_name, token)
return g.get(field_name)
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
"""Check if the given data is a valid CSRF token. This compares the given
signed token to the one stored in the session.
:param data: The signed CSRF token to be checked.
:param secret_key: Used to securely sign the token. Default is
``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
:param time_limit: Number of seconds that the token is valid. Default is
``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
:param token_key: Key where token is stored in session for comparison.
Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
:raises ValidationError: Contains the reason that validation failed.
.. versionchanged:: 0.14
Raises ``ValidationError`` with a specific error message rather than
returning ``True`` or ``False``.
"""
secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)
if not data:
raise ValidationError("The CSRF token is missing.")
if field_name not in session:
raise ValidationError("The CSRF session token is missing.")
s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")
try:
token = s.loads(data, max_age=time_limit)
except SignatureExpired as e:
raise ValidationError("The CSRF token has expired.") from e
except BadData as e:
raise ValidationError("The CSRF token is invalid.") from e
if not hmac.compare_digest(session[field_name], token):
raise ValidationError("The CSRF tokens do not match.")
def _get_config(
value, config_name, default=None, required=True, message="CSRF is not configured."
):
"""Find config value based on provided value, Flask config, and default
value.
:param value: already provided config value
:param config_name: Flask ``config`` key
:param default: default value if not provided or configured
:param required: whether the value must not be ``None``
:param message: error message if required config is not found
:raises KeyError: if required config is not found
"""
if value is None:
value = current_app.config.get(config_name, default)
if required and value is None:
raise RuntimeError(message)
return value
class _FlaskFormCSRF(CSRF):
def setup_form(self, form):
self.meta = form.meta
return super().setup_form(form)
def generate_csrf_token(self, csrf_token_field):
return generate_csrf(
secret_key=self.meta.csrf_secret, token_key=self.meta.csrf_field_name
)
def validate_csrf_token(self, form, field):
if g.get("csrf_valid", False):
# already validated by CSRFProtect
return
try:
validate_csrf(
field.data,
self.meta.csrf_secret,
self.meta.csrf_time_limit,
self.meta.csrf_field_name,
)
except ValidationError as e:
logger.info(e.args[0])
raise
class CSRFProtect:
"""Enable CSRF protection globally for a Flask app.
::
app = Flask(__name__)
csrf = CSRFProtect(app)
Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
header sent with JavaScript requests. Render the token in templates using
``{{ csrf_token() }}``.
See the :ref:`csrf` documentation.
"""
def __init__(self, app=None):
self._exempt_views = set()
self._exempt_blueprints = set()
if app:
self.init_app(app)
def init_app(self, app):
app.extensions["csrf"] = self
app.config.setdefault("WTF_CSRF_ENABLED", True)
app.config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
app.config["WTF_CSRF_METHODS"] = set(
app.config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
)
app.config.setdefault("WTF_CSRF_FIELD_NAME", "csrf_token")
app.config.setdefault("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"])
app.config.setdefault("WTF_CSRF_TIME_LIMIT", 3600)
app.config.setdefault("WTF_CSRF_SSL_STRICT", True)
app.jinja_env.globals["csrf_token"] = generate_csrf
app.context_processor(lambda: {"csrf_token": generate_csrf})
@app.before_request
def csrf_protect():
if not app.config["WTF_CSRF_ENABLED"]:
return
if not app.config["WTF_CSRF_CHECK_DEFAULT"]:
return
if request.method not in app.config["WTF_CSRF_METHODS"]:
return
if not request.endpoint:
return
if app.blueprints.get(request.blueprint) in self._exempt_blueprints:
return
view = app.view_functions.get(request.endpoint)
dest = f"{view.__module__}.{view.__name__}"
if dest in self._exempt_views:
return
self.protect()
def _get_csrf_token(self):
# find the token in the form data
field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
base_token = request.form.get(field_name)
if base_token:
return base_token
# if the form has a prefix, the name will be {prefix}-csrf_token
for key in request.form:
if key.endswith(field_name):
csrf_token = request.form[key]
if csrf_token:
return csrf_token
# find the token in the headers
for header_name in current_app.config["WTF_CSRF_HEADERS"]:
csrf_token = request.headers.get(header_name)
if csrf_token:
return csrf_token
return None
def protect(self):
if request.method not in current_app.config["WTF_CSRF_METHODS"]:
return
try:
validate_csrf(self._get_csrf_token())
except ValidationError as e:
logger.info(e.args[0])
self._error_response(e.args[0])
if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
if not request.referrer:
self._error_response("The referrer header is missing.")
good_referrer = f"https://{request.host}/"
if not same_origin(request.referrer, good_referrer):
self._error_response("The referrer does not match the host.")
g.csrf_valid = True # mark this request as CSRF valid
def exempt(self, view):
"""Mark a view or blueprint to be excluded from CSRF protection.
::
@app.route('/some-view', methods=['POST'])
@csrf.exempt
def some_view():
...
::
bp = Blueprint(...)
csrf.exempt(bp)
"""
if isinstance(view, Blueprint):
self._exempt_blueprints.add(view)
return view
if isinstance(view, str):
view_location = view
else:
view_location = ".".join((view.__module__, view.__name__))
self._exempt_views.add(view_location)
return view
def _error_response(self, reason):
raise CSRFError(reason)
class CSRFError(BadRequest):
"""Raise if the client sends invalid CSRF data with the request.
Generates a 400 Bad Request response with the failure reason by default.
Customize the response by registering a handler with
:meth:`flask.Flask.errorhandler`.
"""
description = "CSRF validation failed."
def same_origin(current_uri, compare_uri):
current = urlparse(current_uri)
compare = urlparse(compare_uri)
return (
current.scheme == compare.scheme
and current.hostname == compare.hostname
and current.port == compare.port
)

View File

@@ -0,0 +1,146 @@
from collections import abc
from werkzeug.datastructures import FileStorage
from wtforms import FileField as _FileField
from wtforms import MultipleFileField as _MultipleFileField
from wtforms.validators import DataRequired
from wtforms.validators import StopValidation
from wtforms.validators import ValidationError
class FileField(_FileField):
"""Werkzeug-aware subclass of :class:`wtforms.fields.FileField`."""
def process_formdata(self, valuelist):
valuelist = (x for x in valuelist if isinstance(x, FileStorage) and x)
data = next(valuelist, None)
if data is not None:
self.data = data
else:
self.raw_data = ()
class MultipleFileField(_MultipleFileField):
"""Werkzeug-aware subclass of :class:`wtforms.fields.MultipleFileField`.
.. versionadded:: 1.2.0
"""
def process_formdata(self, valuelist):
valuelist = (x for x in valuelist if isinstance(x, FileStorage) and x)
data = list(valuelist) or None
if data is not None:
self.data = data
else:
self.raw_data = ()
class FileRequired(DataRequired):
"""Validates that the uploaded files(s) is a Werkzeug
:class:`~werkzeug.datastructures.FileStorage` object.
:param message: error message
You can also use the synonym ``file_required``.
"""
def __call__(self, form, field):
field_data = [field.data] if not isinstance(field.data, list) else field.data
if not (
all(isinstance(x, FileStorage) and x for x in field_data) and field_data
):
raise StopValidation(
self.message or field.gettext("This field is required.")
)
file_required = FileRequired
class FileAllowed:
"""Validates that the uploaded file(s) is allowed by a given list of
extensions or a Flask-Uploads :class:`~flaskext.uploads.UploadSet`.
:param upload_set: A list of extensions or an
:class:`~flaskext.uploads.UploadSet`
:param message: error message
You can also use the synonym ``file_allowed``.
"""
def __init__(self, upload_set, message=None):
self.upload_set = upload_set
self.message = message
def __call__(self, form, field):
field_data = [field.data] if not isinstance(field.data, list) else field.data
if not (
all(isinstance(x, FileStorage) and x for x in field_data) and field_data
):
return
filenames = [f.filename.lower() for f in field_data]
for filename in filenames:
if isinstance(self.upload_set, abc.Iterable):
if any(filename.endswith("." + x) for x in self.upload_set):
continue
raise StopValidation(
self.message
or field.gettext(
"File does not have an approved extension: {extensions}"
).format(extensions=", ".join(self.upload_set))
)
if not self.upload_set.file_allowed(field_data, filename):
raise StopValidation(
self.message
or field.gettext("File does not have an approved extension.")
)
file_allowed = FileAllowed
class FileSize:
"""Validates that the uploaded file(s) is within a minimum and maximum
file size (set in bytes).
:param min_size: minimum allowed file size (in bytes). Defaults to 0 bytes.
:param max_size: maximum allowed file size (in bytes).
:param message: error message
You can also use the synonym ``file_size``.
"""
def __init__(self, max_size, min_size=0, message=None):
self.min_size = min_size
self.max_size = max_size
self.message = message
def __call__(self, form, field):
field_data = [field.data] if not isinstance(field.data, list) else field.data
if not (
all(isinstance(x, FileStorage) and x for x in field_data) and field_data
):
return
for f in field_data:
file_size = len(f.read())
f.seek(0) # reset cursor position to beginning of file
if (file_size < self.min_size) or (file_size > self.max_size):
# the file is too small or too big => validation failure
raise ValidationError(
self.message
or field.gettext(
f"File must be between {self.min_size}"
f" and {self.max_size} bytes."
)
)
file_size = FileSize

View File

@@ -0,0 +1,127 @@
from flask import current_app
from flask import request
from flask import session
from markupsafe import Markup
from werkzeug.datastructures import CombinedMultiDict
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.utils import cached_property
from wtforms import Form
from wtforms.meta import DefaultMeta
from wtforms.widgets import HiddenInput
from .csrf import _FlaskFormCSRF
try:
from .i18n import translations
except ImportError:
translations = None # babel not installed
SUBMIT_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
_Auto = object()
class FlaskForm(Form):
"""Flask-specific subclass of WTForms :class:`~wtforms.form.Form`.
If ``formdata`` is not specified, this will use :attr:`flask.request.form`
and :attr:`flask.request.files`. Explicitly pass ``formdata=None`` to
prevent this.
"""
class Meta(DefaultMeta):
csrf_class = _FlaskFormCSRF
csrf_context = session # not used, provided for custom csrf_class
@cached_property
def csrf(self):
return current_app.config.get("WTF_CSRF_ENABLED", True)
@cached_property
def csrf_secret(self):
return current_app.config.get("WTF_CSRF_SECRET_KEY", current_app.secret_key)
@cached_property
def csrf_field_name(self):
return current_app.config.get("WTF_CSRF_FIELD_NAME", "csrf_token")
@cached_property
def csrf_time_limit(self):
return current_app.config.get("WTF_CSRF_TIME_LIMIT", 3600)
def wrap_formdata(self, form, formdata):
if formdata is _Auto:
if _is_submitted():
if request.files:
return CombinedMultiDict((request.files, request.form))
elif request.form:
return request.form
elif request.is_json:
return ImmutableMultiDict(request.get_json())
return None
return formdata
def get_translations(self, form):
if not current_app.config.get("WTF_I18N_ENABLED", True):
return super().get_translations(form)
return translations
def __init__(self, formdata=_Auto, **kwargs):
super().__init__(formdata=formdata, **kwargs)
def is_submitted(self):
"""Consider the form submitted if there is an active request and
the method is ``POST``, ``PUT``, ``PATCH``, or ``DELETE``.
"""
return _is_submitted()
def validate_on_submit(self, extra_validators=None):
"""Call :meth:`validate` only if the form is submitted.
This is a shortcut for ``form.is_submitted() and form.validate()``.
"""
return self.is_submitted() and self.validate(extra_validators=extra_validators)
def hidden_tag(self, *fields):
"""Render the form's hidden fields in one call.
A field is considered hidden if it uses the
:class:`~wtforms.widgets.HiddenInput` widget.
If ``fields`` are given, only render the given fields that
are hidden. If a string is passed, render the field with that
name if it exists.
.. versionchanged:: 0.13
No longer wraps inputs in hidden div.
This is valid HTML 5.
.. versionchanged:: 0.13
Skip passed fields that aren't hidden.
Skip passed names that don't exist.
"""
def hidden_fields(fields):
for f in fields:
if isinstance(f, str):
f = getattr(self, f, None)
if f is None or not isinstance(f.widget, HiddenInput):
continue
yield f
return Markup("\n".join(str(f) for f in hidden_fields(fields or self)))
def _is_submitted():
"""Consider the form submitted if there is an active request and
the method is ``POST``, ``PUT``, ``PATCH``, or ``DELETE``.
"""
return bool(request) and request.method in SUBMIT_METHODS

View File

@@ -0,0 +1,47 @@
from babel import support
from flask import current_app
from flask import request
from flask_babel import get_locale
from wtforms.i18n import messages_path
__all__ = ("Translations", "translations")
def _get_translations():
"""Returns the correct gettext translations.
Copy from flask-babel with some modifications.
"""
if not request:
return None
# babel should be in extensions for get_locale
if "babel" not in current_app.extensions:
return None
translations = getattr(request, "wtforms_translations", None)
if translations is None:
translations = support.Translations.load(
messages_path(), [get_locale()], domain="wtforms"
)
request.wtforms_translations = translations
return translations
class Translations:
def gettext(self, string):
t = _get_translations()
return string if t is None else t.ugettext(string)
def ngettext(self, singular, plural, n):
t = _get_translations()
if t is None:
return singular if n == 1 else plural
return t.ungettext(singular, plural, n)
translations = Translations()

View File

@@ -0,0 +1,5 @@
from .fields import RecaptchaField
from .validators import Recaptcha
from .widgets import RecaptchaWidget
__all__ = ["RecaptchaField", "RecaptchaWidget", "Recaptcha"]

View File

@@ -0,0 +1,17 @@
from wtforms.fields import Field
from . import widgets
from .validators import Recaptcha
__all__ = ["RecaptchaField"]
class RecaptchaField(Field):
widget = widgets.RecaptchaWidget()
# error message if recaptcha validation fails
recaptcha_error = None
def __init__(self, label="", validators=None, **kwargs):
validators = validators or [Recaptcha()]
super().__init__(label, validators, **kwargs)

View File

@@ -0,0 +1,75 @@
import json
from urllib import request as http
from urllib.parse import urlencode
from flask import current_app
from flask import request
from wtforms import ValidationError
RECAPTCHA_VERIFY_SERVER_DEFAULT = "https://www.google.com/recaptcha/api/siteverify"
RECAPTCHA_ERROR_CODES = {
"missing-input-secret": "The secret parameter is missing.",
"invalid-input-secret": "The secret parameter is invalid or malformed.",
"missing-input-response": "The response parameter is missing.",
"invalid-input-response": "The response parameter is invalid or malformed.",
}
__all__ = ["Recaptcha"]
class Recaptcha:
"""Validates a ReCaptcha."""
def __init__(self, message=None):
if message is None:
message = RECAPTCHA_ERROR_CODES["missing-input-response"]
self.message = message
def __call__(self, form, field):
if current_app.testing:
return True
if request.is_json:
response = request.json.get("g-recaptcha-response", "")
else:
response = request.form.get("g-recaptcha-response", "")
remote_ip = request.remote_addr
if not response:
raise ValidationError(field.gettext(self.message))
if not self._validate_recaptcha(response, remote_ip):
field.recaptcha_error = "incorrect-captcha-sol"
raise ValidationError(field.gettext(self.message))
def _validate_recaptcha(self, response, remote_addr):
"""Performs the actual validation."""
try:
private_key = current_app.config["RECAPTCHA_PRIVATE_KEY"]
except KeyError:
raise RuntimeError("No RECAPTCHA_PRIVATE_KEY config set") from None
verify_server = current_app.config.get("RECAPTCHA_VERIFY_SERVER")
if not verify_server:
verify_server = RECAPTCHA_VERIFY_SERVER_DEFAULT
data = urlencode(
{"secret": private_key, "remoteip": remote_addr, "response": response}
)
http_response = http.urlopen(verify_server, data.encode("utf-8"))
if http_response.code != 200:
return False
json_resp = json.loads(http_response.read())
if json_resp["success"]:
return True
for error in json_resp.get("error-codes", []):
if error in RECAPTCHA_ERROR_CODES:
raise ValidationError(RECAPTCHA_ERROR_CODES[error])
return False

View File

@@ -0,0 +1,43 @@
from urllib.parse import urlencode
from flask import current_app
from markupsafe import Markup
RECAPTCHA_SCRIPT_DEFAULT = "https://www.google.com/recaptcha/api.js"
RECAPTCHA_DIV_CLASS_DEFAULT = "g-recaptcha"
RECAPTCHA_TEMPLATE = """
<script src='%s' async defer></script>
<div class="%s" %s></div>
"""
__all__ = ["RecaptchaWidget"]
class RecaptchaWidget:
def recaptcha_html(self, public_key):
html = current_app.config.get("RECAPTCHA_HTML")
if html:
return Markup(html)
params = current_app.config.get("RECAPTCHA_PARAMETERS")
script = current_app.config.get("RECAPTCHA_SCRIPT")
if not script:
script = RECAPTCHA_SCRIPT_DEFAULT
if params:
script += "?" + urlencode(params)
attrs = current_app.config.get("RECAPTCHA_DATA_ATTRS", {})
attrs["sitekey"] = public_key
snippet = " ".join(f'data-{k}="{attrs[k]}"' for k in attrs) # noqa: B028, B907
div_class = current_app.config.get("RECAPTCHA_DIV_CLASS")
if not div_class:
div_class = RECAPTCHA_DIV_CLASS_DEFAULT
return Markup(RECAPTCHA_TEMPLATE % (script, div_class, snippet))
def __call__(self, field, error=None, **kwargs):
"""Returns the recaptcha input HTML."""
try:
public_key = current_app.config["RECAPTCHA_PUBLIC_KEY"]
except KeyError:
raise RuntimeError("RECAPTCHA_PUBLIC_KEY config not set") from None
return self.recaptcha_html(public_key)