From 072364497c8f71692ed627d420d692c8a68b5718 Mon Sep 17 00:00:00 2001 From: virusdefender Date: Sat, 23 Dec 2017 22:27:53 +0800 Subject: [PATCH] new throttling --- options/options.py | 11 ++++ submission/views/oj.py | 34 +++------- utils/throttling.py | 144 ++++++++++++++++++----------------------- 3 files changed, 84 insertions(+), 105 deletions(-) diff --git a/options/options.py b/options/options.py index 7d8b9a9e..57169bb7 100644 --- a/options/options.py +++ b/options/options.py @@ -21,6 +21,7 @@ class OptionKeys: submission_list_show_all = "submission_list_show_all" smtp_config = "smtp_config" judge_server_token = "judge_server_token" + throttling = "throttling" class OptionDefaultValue: @@ -32,6 +33,8 @@ class OptionDefaultValue: submission_list_show_all = True smtp_config = {} judge_server_token = default_token + throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50}, + "user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}} class _SysOptionsMeta(type): @@ -180,6 +183,14 @@ class _SysOptionsMeta(type): def judge_server_token(cls, value): cls._set_option(OptionKeys.judge_server_token, value) + @property + def throttling(cls): + return cls._get_option(OptionKeys.throttling) + + @throttling.setter + def throttling(cls, value): + cls._set_option(OptionKeys.throttling, value) + class SysOptions(metaclass=_SysOptionsMeta): pass diff --git a/submission/views/oj.py b/submission/views/oj.py index 5d6351b0..4fb077b6 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -1,6 +1,5 @@ import ipaddress -from django.conf import settings from account.decorators import login_required, check_contest_permission from judge.tasks import judge_task # from judge.dispatcher import JudgeDispatcher @@ -8,7 +7,7 @@ from problem.models import Problem, ProblemRuleType from contest.models import Contest, ContestStatus, ContestRuleType from options.options import SysOptions from utils.api import APIView, validate_serializer -from utils.throttling import TokenBucket, BucketController +from utils.throttling import TokenBucket from utils.captcha import Captcha from utils.cache import cache from ..models import Submission @@ -19,29 +18,16 @@ from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerialize class SubmissionAPI(APIView): def throttling(self, request): - user_controller = BucketController(factor=request.user.id, - redis_conn=cache, - default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY) - user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE, - capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY, - last_capacity=user_controller.last_capacity, - last_timestamp=user_controller.last_timestamp) - if user_bucket.consume(): - user_controller.last_capacity -= 1 - else: - return "Please wait %d seconds" % int(user_bucket.expected_time() + 1) + user_bucket = TokenBucket(key=str(request.user.id), + redis_conn=cache, **SysOptions.throttling["user"]) + can_consume, wait = user_bucket.consume() + if not can_consume: + return "Please wait %d seconds" % (int(wait)) - ip_controller = BucketController(factor=request.session["ip"], - redis_conn=cache, - default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3) - - ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3, - capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3, - last_capacity=ip_controller.last_capacity, - last_timestamp=ip_controller.last_timestamp) - if ip_bucket.consume(): - ip_controller.last_capacity -= 1 - else: + ip_bucket = TokenBucket(key=request.session["ip"], + redis_conn=cache, **SysOptions.throttling["ip"]) + can_consume, wait = ip_bucket.consume() + if not can_consume: return "Captcha is required" @validate_serializer(CreateSubmissionSerializer) diff --git a/utils/throttling.py b/utils/throttling.py index 7c5f54a9..bab1184a 100644 --- a/utils/throttling.py +++ b/utils/throttling.py @@ -1,90 +1,72 @@ -from __future__ import print_function import time class TokenBucket: - def __init__(self, fill_rate, capacity, last_capacity, last_timestamp): - self.capacity = float(capacity) - self._left_tokens = last_capacity - self.fill_rate = float(fill_rate) - self.timestamp = last_timestamp + """ + 注意:对于单个key的操作不是线程安全的 + """ + def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn): + """ + :param capacity: 最大容量 + :param fill_rate: 填充速度/每秒 + :param default_capacity: 初始容量 + :param redis_conn: redis connection + """ + self._key = key + self._capacity = capacity + self._fill_rate = fill_rate + self._default_capacity = default_capacity + self._redis_conn = redis_conn - def consume(self, tokens=1): - if tokens <= self.tokens: - self._left_tokens -= tokens - return True - return False + self._last_capacity_key = "last_capacity" + self._last_timestamp_key = "last_timestamp" - def expected_time(self, tokens=1): - _tokens = self.tokens - tokens = max(tokens, _tokens) - return (tokens - _tokens) / self.fill_rate * 60 + def _init_key(self): + self._last_capacity = self._default_capacity + now = time.time() + self._last_timestamp = now + return self._default_capacity, now @property - def tokens(self): - if self._left_tokens < self.capacity: + def _last_capacity(self): + last_capacity = self._redis_conn.hget(self._key, self._last_capacity_key) + if last_capacity is None: + return self._init_key()[0] + else: + return float(last_capacity) + + @_last_capacity.setter + def _last_capacity(self, value): + self._redis_conn.hset(self._key, self._last_capacity_key, value) + + @property + def _last_timestamp(self): + return float(self._redis_conn.hget(self._key, self._last_timestamp_key)) + + @_last_timestamp.setter + def _last_timestamp(self, value): + self._redis_conn.hset(self._key, self._last_timestamp_key, value) + + def _try_to_fill(self, now): + delta = self._fill_rate * (now - self._last_timestamp) + return min(self._last_capacity + delta, self._capacity) + + def consume(self, num=1): + """ + 消耗 num 个 token,返回是否成功 + :param num: + :return: result: bool, wait_time: float + """ + # print("capacity ", self.fill(time.time())) + if self._last_capacity >= num: + self._last_capacity -= num + return True, 0 + else: now = time.time() - delta = self.fill_rate * ((now - self.timestamp) / 60) - self._left_tokens = min(self.capacity, self._left_tokens + delta) - self.timestamp = now - return self._left_tokens - - -class BucketController: - def __init__(self, factor, redis_conn, default_capacity): - self.default_capacity = default_capacity - self.redis = redis_conn - self.key = "bucket_" + str(factor) - - @property - def last_capacity(self): - value = self.redis.hget(self.key, "last_capacity") - if value is None: - self.last_capacity = self.default_capacity - return self.default_capacity - return int(value) - - @last_capacity.setter - def last_capacity(self, value): - self.redis.hset(self.key, "last_capacity", value) - - @property - def last_timestamp(self): - value = self.redis.hget(self.key, "last_timestamp") - if value is None: - timestamp = int(time.time()) - self.last_timestamp = timestamp - return timestamp - return int(value) - - @last_timestamp.setter - def last_timestamp(self, value): - self.redis.hset(self.key, "last_timestamp", value) - - -""" -# # Token bucket, to limit submission rate -# # Demo - -success = failure = 0 -current_user_id = 1 -token_bucket_default_capacity = 50 -token_bucket_fill_rate = 10 -for i in range(5000): - controller = BucketController(user_id=current_user_id, - redis_conn=redis.Redis(), - default_capacity=token_bucket_default_capacity) - bucket = TokenBucket(fill_rate=token_bucket_fill_rate, - capacity=token_bucket_default_capacity, - last_capacity=controller.last_capacity, - last_timestamp=controller.last_timestamp) - time.sleep(0.05) - if bucket.consume(): - success += 1 - print(i, ": Accepted") - controller.last_capacity -= 1 - else: - failure += 1 - print(i, "Dropped, time left ", bucket.expected_time()) -print(success, failure) -""" + cur_num = self._try_to_fill(now) + if cur_num >= num: + self._last_capacity = cur_num - num + self._last_timestamp = now + return True, 0 + else: + return False, (num - cur_num) / self._fill_rate