~sirn/fanboi2

ref: 98796226802a26b82c888365ad5b9cd331006792 fanboi2/fanboi2/forms.py -rw-r--r-- 3.6 KiB
98796226Kridsada Thanabulpong Bump copyright year. 3 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import hmac
import os
from hashlib import sha1
from wtforms import TextField, TextAreaField, Form, BooleanField
from wtforms.ext.csrf.fields import CSRFTokenField
from wtforms.validators import Length as _Length
from wtforms.validators import Required, ValidationError


class Length(_Length):
    """Works just like :class:`wtforms.validators.Length` but treats a DOS
    (CRLF) newline as a single character instead of two.

    Form submissions commonly encode line breaks as CRLF, so without this
    normalization the measured length would be larger than what the user
    actually typed and the field could be rejected unexpectedly.
    """

    def __call__(self, form, field):
        """Validate the length of ``field.data``.

        :param form: The form being validated (not used).
        :param field: The field whose ``data`` is measured. Empty or ``None``
            data counts as length 0.
        :raises ValidationError: If the normalized length is below
            ``self.min``, or above ``self.max`` when ``self.max`` is not -1.
        """
        data = field.data
        # Collapse CRLF to LF so every line break counts as one character.
        length = len(data.replace('\r\n', '\n')) if data else 0
        if length < self.min or (self.max != -1 and length > self.max):
            message = self.message
            if message is None:
                if self.max == -1:
                    message = field.ngettext(
                        'Field must be at least %(min)d character long.',
                        'Field must be at least %(min)d characters long.',
                        self.min)
                elif self.min == -1:
                    message = field.ngettext(
                        'Field cannot be longer than %(max)d character.',
                        'Field cannot be longer than %(max)d characters.',
                        self.max)
                else:
                    message = field.gettext('Field must be between %(min)d '
                                            'and %(max)d characters long.')
            raise ValidationError(message % dict(min=self.min, max=self.max))


class SecureForm(Form):
    """Form with CSRF protection based on a randomly generated session key.

    The first time a form is built for a session, a random key is stored in
    ``request.session['csrf']``. The token rendered into the form is the
    hex HMAC-SHA1 of that key, keyed with the ``app.secret`` setting. On
    submission, the HMAC is recomputed and compared with the submitted value.
    """
    csrf_token = CSRFTokenField()

    def __init__(self, formdata=None, obj=None, prefix='', request=None):
        super(SecureForm, self).__init__(formdata, obj, prefix)
        # The request must be stored before generating the token because
        # token generation reads both the session and the app settings.
        self.request = request
        self.csrf_token.current_token = self.generate_csrf_token()

    def _generate_hmac(self, message):
        """Return the hex HMAC-SHA1 of *message* keyed with ``app.secret``."""
        secret = self.request.registry.settings['app.secret']
        return hmac.new(
            secret.encode('utf8'),
            message.encode('utf8'),
            digestmod=sha1,
        ).hexdigest()

    def generate_csrf_token(self):
        """Ensure the session holds a CSRF key and return its HMAC token.

        Also records the session key on the ``csrf_token`` field so that
        :meth:`validate_csrf_token` can recompute the expected token.
        """
        if 'csrf' not in self.request.session:
            self.request.session['csrf'] = sha1(os.urandom(64)).hexdigest()
        self.csrf_token.csrf_key = self.request.session['csrf']
        return self._generate_hmac(self.request.session['csrf'])

    def validate_csrf_token(self, field):
        """Inline validator checking the submitted CSRF token.

        :raises ValidationError: If the token is missing or does not match
            the HMAC of the session key.
        """
        if not field.data:
            raise ValidationError('CSRF token missing.')
        expected = self._generate_hmac(field.csrf_key)
        # compare_digest() raises TypeError for str arguments that contain
        # non-ASCII characters. The submitted token is client-controlled, so
        # compare UTF-8 bytes to make such input a plain mismatch, not a crash.
        if not hmac.compare_digest(field.data.encode('utf8'),
                                   expected.encode('utf8')):
            raise ValidationError('CSRF token mismatched.')

    @property
    def data(self):
        """Form data as a dict, without the ``csrf_token`` entry."""
        d = super(SecureForm, self).data
        d.pop('csrf_token', None)
        return d


class TopicForm(Form):
    """Form used to start a new topic.

    After validation, :attr:`title` is meant to be populated onto a
    :class:`Topic` and :attr:`body` onto a :class:`Post`. Both fields are
    required. The title must be 5-200 characters and the body 5-4000
    characters, with CRLF counted as one character.
    """
    title = TextField(
        label='Title',
        validators=[Required(), Length(min=5, max=200)],
    )
    body = TextAreaField(
        label='Body',
        validators=[Required(), Length(min=5, max=4000)],
    )


class SecureTopicForm(SecureForm, TopicForm):
    """:class:`TopicForm` with the CSRF protection of :class:`SecureForm`.

    ``Form`` is not listed explicitly: both parents already derive from it,
    and the method resolution order is identical without it.
    """


class PostForm(Form):
    """Form used to reply to an existing topic.

    After validation, :attr:`body` is meant to be populated onto a
    :class:`Post`. The body is required and must be 5-4000 characters, with
    CRLF counted as one character. :attr:`bumped` is a checkbox that is
    checked by default.
    """
    body = TextAreaField(
        label='Body',
        validators=[Required(), Length(min=5, max=4000)],
    )
    bumped = BooleanField(label='Bump this topic', default=True)


class SecurePostForm(SecureForm, PostForm):
    """:class:`PostForm` with the CSRF protection of :class:`SecureForm`.

    ``Form`` is not listed explicitly: both parents already derive from it,
    and the method resolution order is identical without it.
    """