From 728373a5ff95b3a338bdbde70c607f50c810fee6 Mon Sep 17 00:00:00 2001 From: zema1 Date: Fri, 27 Oct 2017 18:36:29 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84contest=E6=9D=83=E9=99=90?= =?UTF-8?q?=E6=8E=A7=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- account/decorators.py | 81 ++++++++++++++++-------------- account/serializers.py | 9 +++- account/tests.py | 6 +-- account/urls/oj.py | 3 +- account/views/oj.py | 34 ++++++++++--- contest/models.py | 13 +++-- contest/tests.py | 29 +++++------ contest/urls/oj.py | 4 +- contest/views/oj.py | 21 ++++---- deploy/run.sh | 2 +- oj/settings.py | 2 +- problem/serializers.py | 9 +++- problem/tests.py | 22 ++++----- problem/views/oj.py | 22 +++++---- submission/serializers.py | 1 + submission/views/oj.py | 100 +++++++++++++++++++++++--------------- utils/api/api.py | 14 ++---- utils/api/tests.py | 4 +- utils/throttling.py | 5 +- 19 files changed, 219 insertions(+), 162 deletions(-) diff --git a/account/decorators.py b/account/decorators.py index b3523664..45215301 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -4,7 +4,7 @@ from utils.api import JSONResponse from .models import ProblemPermission -from contest.models import Contest, ContestType, ContestStatus +from contest.models import Contest, ContestType, ContestStatus, ContestRuleType class BasePermissionDecorator(object): @@ -25,7 +25,7 @@ class BasePermissionDecorator(object): return self.error("Your account is disabled") return self.func(*args, **kwargs) else: - return self.error("Please login in first") + return self.error("Please login first") def check_permission(self): raise NotImplementedError() @@ -57,45 +57,54 @@ class problem_permission_required(admin_role_required): return True -def check_contest_permission(func): +def check_contest_permission(check_type="details"): """ - 只供Class based view 使用,检查用户是否有权进入该contest, + 只供Class based view 使用,检查用户是否有权进入该contest, check_type 可选 details, problems, ranks, submissions 若通过验证,在view中可通过self.contest获得该contest """ - @functools.wraps(func) - def _check_permission(*args, **kwargs): - self = args[0] - request = args[1] - user = request.user - if kwargs.get("contest_id"): - contest_id = kwargs.pop("contest_id") - else: - contest_id = request.GET.get("contest_id") - if not contest_id: - return self.error("Parameter contest_id not exist.") - try: - # use self.contest to avoid query contest again in view. - self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) - except Contest.DoesNotExist: - return self.error("Contest %s doesn't exist" % contest_id) + def decorator(func): + def _check_permission(*args, **kwargs): + self = args[0] + request = args[1] + user = request.user + if kwargs.get("contest_id"): + contest_id = kwargs.pop("contest_id") + else: + contest_id = request.GET.get("contest_id") + if not contest_id: + return self.error("Parameter contest_id not exist.") + + try: + # use self.contest to avoid query contest again in view. + self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) + except Contest.DoesNotExist: + return self.error("Contest %s doesn't exist" % contest_id) + + # creator or owner + if self.contest.is_contest_admin(user): + return func(*args, **kwargs) + + if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST: + # Anonymous + if not user.is_authenticated(): + return self.error("Please login first.") + # password error + if ("accessible_contests" not in request.session) or \ + (self.contest.id not in request.session["accessible_contests"]): + return self.error("Password is required.") + + # regular use get contest problems, ranks etc. before contest started + if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details": + return self.error("Contest has not started yet.") + + # check is user have permission to get ranks, submissions OI Contest + if self.contest.status == ContestStatus.CONTEST_UNDERWAY and self.contest.rule_type == ContestRuleType.OI: + if not self.contest.real_time_rank and (check_type == "ranks" or check_type == "submissions"): + return self.error(f"No permission to get {check_type}") - # creator or owner - if self.contest.is_contest_admin(user): return func(*args, **kwargs) - if self.contest.status == ContestStatus.CONTEST_NOT_START: - return self.error("Contest has not started yet.") + return _check_permission - if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST: - # Anonymous - if not user.is_authenticated(): - return self.error("Please login in first.") - # password error - if ("accessible_contests" not in request.session) or \ - (self.contest.id not in request.session["accessible_contests"]): - return self.error("Password is required.") - - return func(*args, **kwargs) - - return _check_permission + return decorator diff --git a/account/serializers.py b/account/serializers.py index aa9675a3..9c2cd15c 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -8,7 +8,7 @@ from .models import AdminType, ProblemPermission, User, UserProfile class UserLoginSerializer(serializers.Serializer): username = serializers.CharField() password = serializers.CharField() - tfa_code = serializers.CharField(required=False, allow_null=True) + tfa_code = serializers.CharField(required=False, allow_blank=True) class UsernameOrEmailCheckSerializer(serializers.Serializer): @@ -26,6 +26,13 @@ class UserRegisterSerializer(serializers.Serializer): class UserChangePasswordSerializer(serializers.Serializer): old_password = serializers.CharField() new_password = serializers.CharField(min_length=6) + tfa_code = serializers.CharField(required=False, allow_blank=True) + + +class UserChangeEmailSerializer(serializers.Serializer): + password = serializers.CharField() + new_email = serializers.EmailField(max_length=64) + tfa_code = serializers.CharField(required=False, allow_blank=True) class UserSerializer(serializers.ModelSerializer): diff --git a/account/tests.py b/account/tests.py index 75515ced..ccc8b1a8 100644 --- a/account/tests.py +++ b/account/tests.py @@ -362,7 +362,7 @@ class UserChangePasswordAPITest(CaptchaTest): def test_login_required(self): response = self.client.post(self.url, data=self.data) - self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login in first"}) + self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login first"}) def test_valid_ola_password(self): self.assertTrue(self.client.login(username=self.username, password=self.old_password)) @@ -476,13 +476,13 @@ class UserRankAPITest(APITestCase): def test_get_acm_rank(self): resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) self.assertSuccess(resp) - data = resp.data["data"] + data = resp.data["data"]["results"] self.assertEqual(data[0]["user"]["username"], "test1") self.assertEqual(data[1]["user"]["username"], "test2") def test_get_oi_rank(self): resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) self.assertSuccess(resp) - data = resp.data["data"] + data = resp.data["data"]["results"] self.assertEqual(data[0]["user"]["username"], "test2") self.assertEqual(data[1]["user"]["username"], "test1") diff --git a/account/urls/oj.py b/account/urls/oj.py index aacbb59d..0af54a50 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -1,7 +1,7 @@ from django.conf.urls import url from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, - UserChangePasswordAPI, UserRegisterAPI, + UserChangePasswordAPI, UserRegisterAPI, UserChangeEmailAPI, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI) @@ -13,6 +13,7 @@ urlpatterns = [ url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"), url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"), url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"), + url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email"), url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"), url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"), url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"), diff --git a/account/views/oj.py b/account/views/oj.py index 3cccd362..385399c4 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -21,7 +21,7 @@ from ..models import User, UserProfile from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, UserChangePasswordSerializer, UserLoginSerializer, UserRegisterSerializer, UsernameOrEmailCheckSerializer, - RankInfoSerializer) + RankInfoSerializer, UserChangeEmailSerializer) from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer, EditUserProfileSerializer, AvatarUploadForm) from ..tasks import send_email_async @@ -176,11 +176,6 @@ class UserLoginAPI(APIView): else: return self.error("Invalid username or password") - # todo remove this, only for debug use - def get(self, request): - auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"])) - return self.success() - class UserLogoutAPI(APIView): def get(self, request): @@ -233,6 +228,27 @@ class UserRegisterAPI(APIView): return self.success("Succeeded") +class UserChangeEmailAPI(APIView): + @validate_serializer(UserChangeEmailSerializer) + @login_required + def post(self, request): + data = request.data + user = auth.authenticate(username=request.user.username, password=data["password"]) + if user: + if user.two_factor_auth: + if "tfa_code" not in data: + return self.error("tfa_required") + if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): + return self.error("Invalid two factor verification code") + if User.objects.filter(email=data["new_email"]).exists(): + return self.error("The email is owned by other account") + user.email = data["new_email"] + user.save() + return self.success("Succeeded") + else: + return self.error("Wrong password") + + class UserChangePasswordAPI(APIView): @validate_serializer(UserChangePasswordSerializer) @login_required @@ -244,7 +260,11 @@ class UserChangePasswordAPI(APIView): username = request.user.username user = auth.authenticate(username=username, password=data["old_password"]) if user: - # TODO: check tfa? + if user.two_factor_auth: + if "tfa_code" not in data: + return self.error("tfa_required") + if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): + return self.error("Invalid two factor verification code") user.set_password(data["new_password"]) user.save() return self.success("Succeeded") diff --git a/contest/models.py b/contest/models.py index 08a3babb..adfa3a7a 100644 --- a/contest/models.py +++ b/contest/models.py @@ -45,13 +45,12 @@ class Contest(models.Model): def is_contest_admin(self, user): return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN) - def check_oi_permission(self, user): - if self.status != ContestStatus.CONTEST_ENDED and not self.real_time_rank: - if self.is_contest_admin(user): - return True - else: - return False - return True + # 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等 + def problem_details_permission(self, user): + return self.rule_type == ContestRuleType.ACM or \ + self.status == ContestStatus.CONTEST_ENDED or \ + self.is_contest_admin(user) or \ + self.real_time_rank class Meta: db_table = "contest" diff --git a/contest/tests.py b/contest/tests.py index df05df87..635c3268 100644 --- a/contest/tests.py +++ b/contest/tests.py @@ -58,43 +58,40 @@ class ContestAdminAPITest(APITestCase): class ContestAPITest(APITestCase): def setUp(self): self.create_admin() - self.url = self.reverse("contest_api") - - def create_contest(self): url = self.reverse("contest_admin_api") - return self.client.post(url, data=DEFAULT_CONTEST_DATA) + self.contest = self.client.post(url, data=DEFAULT_CONTEST_DATA).data["data"] + self.url = self.reverse("contest_api") + "?contest_id=" + str(self.contest["id"]) def test_get_contest_list(self): - self.create_contest() - response = self.client.get(self.url) + url = self.reverse("contest_list_api") + response = self.client.get(url + "?limit=10") self.assertSuccess(response) + self.assertEqual(len(response.data["data"]["results"]), 1) def test_get_one_contest(self): - contest_id = self.create_contest().data["data"]["id"] - response = self.client.get("{}?id={}".format(self.url, contest_id)) - self.assertSuccess(response) + resp = self.client.get(self.url) + self.assertSuccess(resp) def test_regular_user_validate_contest_password(self): - contest_id = self.create_contest().data["data"]["id"] self.create_user("test", "test123") url = self.reverse("contest_password_api") - resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"}) + resp = self.client.post(url, {"contest_id": self.contest["id"], "password": "error_password"}) self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"}) - resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) + resp = self.client.post(url, {"contest_id": self.contest["id"], "password": DEFAULT_CONTEST_DATA["password"]}) self.assertSuccess(resp) def test_regular_user_access_contest(self): - contest_id = self.create_contest().data["data"]["id"] self.create_user("test", "test123") url = self.reverse("contest_access_api") - resp = self.client.get(url + "?contest_id=" + str(contest_id)) + resp = self.client.get(url + "?contest_id=" + str(self.contest["id"])) self.assertFalse(resp.data["data"]["access"]) password_url = self.reverse("contest_password_api") - resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) + resp = self.client.post(password_url, + {"contest_id": self.contest["id"], "password": DEFAULT_CONTEST_DATA["password"]}) self.assertSuccess(resp) - resp = self.client.get(url + "?contest_id=" + str(contest_id)) + resp = self.client.get(self.url) self.assertSuccess(resp) diff --git a/contest/urls/oj.py b/contest/urls/oj.py index cfa12f6a..9e94fa58 100644 --- a/contest/urls/oj.py +++ b/contest/urls/oj.py @@ -1,10 +1,12 @@ from django.conf.urls import url -from ..views.oj import ContestAnnouncementListAPI, ContestAPI +from ..views.oj import ContestAnnouncementListAPI from ..views.oj import ContestPasswordVerifyAPI, ContestAccessAPI +from ..views.oj import ContestListAPI, ContestAPI from ..views.oj import ContestRankAPI urlpatterns = [ + url(r"^contests/?$", ContestListAPI.as_view(), name="contest_list_api"), url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"), url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"), url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"), diff --git a/contest/views/oj.py b/contest/views/oj.py index 1e787ce2..342d30c4 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -12,6 +12,7 @@ from ..serializers import OIContestRankSerializer, ACMContestRankSerializer class ContestAnnouncementListAPI(APIView): + @check_contest_permission(check_type="announcements") def get(self, request): contest_id = request.GET.get("contest_id") if not contest_id: @@ -24,15 +25,13 @@ class ContestAnnouncementListAPI(APIView): class ContestAPI(APIView): + @check_contest_permission(check_type="details") def get(self, request): - contest_id = request.GET.get("id") - if contest_id: - try: - contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) - except Contest.DoesNotExist: - return self.error("Contest does not exist") - return self.success(ContestSerializer(contest).data) + return self.success(ContestSerializer(self.contest).data) + +class ContestListAPI(APIView): + def get(self, request): contests = Contest.objects.select_related("created_by").filter(visible=True) keyword = request.GET.get("keyword") rule_type = request.GET.get("rule_type") @@ -49,7 +48,8 @@ class ContestAPI(APIView): contests = contests.filter(end_time__lt=cur) else: contests = contests.filter(start_time__lte=cur, end_time__gte=cur) - return self.success(self.paginate_data(request, contests, ContestSerializer)) + data = self.paginate_data(request, contests, ContestSerializer) + return self.success(data) class ContestPasswordVerifyAPI(APIView): @@ -91,11 +91,9 @@ class ContestRankAPI(APIView): return OIContestRank.objects.filter(contest=self.contest). \ select_related("user").order_by("-total_score") - @check_contest_permission + @check_contest_permission(check_type="ranks") def get(self, request): if self.contest.rule_type == ContestRuleType.OI: - if not self.contest.check_oi_permission(request.user): - return self.error("You have no permission for ranks now") serializer = OIContestRankSerializer else: serializer = ACMContestRankSerializer @@ -105,5 +103,4 @@ class ContestRankAPI(APIView): if not qs: qs = self.get_rank() cache.set(cache_key, qs) - return self.success(self.paginate_data(request, qs, serializer)) diff --git a/deploy/run.sh b/deploy/run.sh index 34ef16ba..e992e966 100644 --- a/deploy/run.sh +++ b/deploy/run.sh @@ -14,7 +14,7 @@ cd $BASE find . -name "*.pyc" -delete # wait for postgresql start -sleep 5 +sleep 6 n=0 while [ $n -lt 3 ] diff --git a/oj/settings.py b/oj/settings.py index 81f55d38..c25a9b96 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -192,7 +192,7 @@ CELERY_ACCEPT_CONTENT = ["json"] CELERY_TASK_SERIALIZER = "json" # 用于限制用户恶意提交大量代码 -TOKEN_BUCKET_DEFAULT_CAPACITY = 50 +TOKEN_BUCKET_DEFAULT_CAPACITY = 20 # 单位:每分钟 TOKEN_BUCKET_FILL_RATE = 2 diff --git a/problem/serializers.py b/problem/serializers.py index 0e44710b..47e9a53e 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -107,4 +107,11 @@ class ProblemSerializer(BaseProblemSerializer): class ContestProblemSerializer(BaseProblemSerializer): class Meta: model = Problem - exclude = ("test_case_score", "test_case_id", "visible", "is_public") + exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty") + + +class ContestProblemSafeSerializer(BaseProblemSerializer): + class Meta: + model = Problem + exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty" + "submission_number", "accepted_number", "statistic_info") diff --git a/problem/tests.py b/problem/tests.py index 072abd48..8ff20175 100644 --- a/problem/tests.py +++ b/problem/tests.py @@ -196,30 +196,26 @@ class ContestProblemAdminTest(APITestCase): def setUp(self): self.url = self.reverse("contest_problem_admin_api") self.create_admin() - - def create_contest(self): - url = self.reverse("contest_admin_api") - return self.client.post(url, data=DEFAULT_CONTEST_DATA) + self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"] def test_create_contest_problem(self): - contest = self.create_contest() data = copy.deepcopy(DEFAULT_PROBLEM_DATA) - data["contest_id"] = contest.data["data"]["id"] + data["contest_id"] = self.contest["id"] resp = self.client.post(self.url, data=data) self.assertSuccess(resp) - return contest, resp + return resp.data["data"] def test_get_contest_problem(self): - contest, contest_problem = self.test_create_contest_problem() - contest_id = contest.data["data"]["id"] + self.test_create_contest_problem() + contest_id = self.contest["id"] resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]), 1) + self.assertEqual(len(resp.data["data"]["results"]), 1) def test_get_one_contest_problem(self): - contest, contest_problem = self.test_create_contest_problem() - contest_id = contest.data["data"]["id"] - problem_id = contest_problem.data["data"]["id"] + contest_problem = self.test_create_contest_problem() + contest_id = self.contest["id"] + problem_id = contest_problem["id"] resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}") self.assertSuccess(resp) diff --git a/problem/views/oj.py b/problem/views/oj.py index 2f54ed5d..0cc9abdf 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -4,7 +4,7 @@ from utils.api import APIView from account.decorators import check_contest_permission from ..models import ProblemTag, Problem, ProblemRuleType from ..serializers import ProblemSerializer, TagSerializer -from ..serializers import ContestProblemSerializer +from ..serializers import ContestProblemSerializer, ContestProblemSafeSerializer from contest.models import ContestRuleType @@ -81,8 +81,6 @@ class ProblemAPI(APIView): class ContestProblemAPI(APIView): def _add_problem_status(self, request, queryset_values): - if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user): - return if request.user.is_authenticated(): profile = request.user.userprofile if self.contest.rule_type == ContestRuleType.ACM: @@ -92,7 +90,7 @@ class ContestProblemAPI(APIView): for problem in queryset_values: problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") - @check_contest_permission + @check_contest_permission(check_type="problems") def get(self, request): problem_id = request.GET.get("problem_id") if problem_id: @@ -102,11 +100,17 @@ class ContestProblemAPI(APIView): visible=True) except Problem.DoesNotExist: return self.error("Problem does not exist.") - problem_data = ContestProblemSerializer(problem).data - self._add_problem_status(request, [problem_data, ]) + if self.contest.problem_details_permission(request.user): + problem_data = ContestProblemSerializer(problem).data + self._add_problem_status(request, [problem_data, ]) + else: + problem_data = ContestProblemSafeSerializer(problem).data return self.success(problem_data) + contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) - # 根据profile, 为做过的题目添加标记 - data = ContestProblemSerializer(contest_problems, many=True).data - self._add_problem_status(request, data) + if self.contest.problem_details_permission(request.user): + data = ContestProblemSerializer(contest_problems, many=True).data + self._add_problem_status(request, data) + else: + data = ContestProblemSafeSerializer(contest_problems, many=True).data return self.success(data) diff --git a/submission/serializers.py b/submission/serializers.py index 66a517bd..21bd75db 100644 --- a/submission/serializers.py +++ b/submission/serializers.py @@ -8,6 +8,7 @@ class CreateSubmissionSerializer(serializers.Serializer): language = serializers.ChoiceField(choices=language_names) code = serializers.CharField(max_length=20000) contest_id = serializers.IntegerField(required=False) + captcha = serializers.CharField(required=False) class ShareSubmissionSerializer(serializers.Serializer): diff --git a/submission/views/oj.py b/submission/views/oj.py index 613bd385..d29b3839 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -1,3 +1,4 @@ +from django.conf import settings from account.decorators import login_required, check_contest_permission from judge.tasks import judge_task # from judge.dispatcher import JudgeDispatcher @@ -5,6 +6,7 @@ from problem.models import Problem, ProblemRuleType from contest.models import Contest, ContestStatus, ContestRuleType from utils.api import APIView, validate_serializer from utils.throttling import TokenBucket, BucketController +from utils.captcha import Captcha from utils.cache import cache from ..models import Submission from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer, @@ -12,43 +14,38 @@ from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer -def _submit(response, user, problem_id, language, code, contest_id): - # TODO: 预设默认值,需修改 - controller = BucketController(user_id=user.id, - redis_conn=cache, - default_capacity=30) - bucket = TokenBucket(fill_rate=10, capacity=20, - last_capacity=controller.last_capacity, - last_timestamp=controller.last_timestamp) - if bucket.consume(): - controller.last_capacity -= 1 - else: - return response.error("Please wait %d seconds" % int(bucket.expected_time() + 1)) - - try: - problem = Problem.objects.get(id=problem_id, - contest_id=contest_id, - visible=True) - except Problem.DoesNotExist: - return response.error("Problem not exist") - - submission = Submission.objects.create(user_id=user.id, - username=user.username, - language=language, - code=code, - problem_id=problem.id, - contest_id=contest_id) - # use this for debug - # JudgeDispatcher(submission.id, problem.id).judge() - judge_task.delay(submission.id, problem.id) - return response.success({"submission_id": submission.id}) - - class SubmissionAPI(APIView): + def throttling(self, request): + user_controller = BucketController(factor=request.user.id, + redis_conn=cache, + default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY) + user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE, + capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY, + last_capacity=user_controller.last_capacity, + last_timestamp=user_controller.last_timestamp) + if user_bucket.consume(): + user_controller.last_capacity -= 1 + else: + return "Please wait %d seconds" % int(user_bucket.expected_time() + 1) + + ip_controller = BucketController(factor=request.session["ip"], + redis_conn=cache, + default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3) + + ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3, + capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3, + last_capacity=ip_controller.last_capacity, + last_timestamp=ip_controller.last_timestamp) + if ip_bucket.consume(): + ip_controller.last_capacity -= 1 + else: + return "Captcha is required" + @validate_serializer(CreateSubmissionSerializer) @login_required def post(self, request): data = request.data + hide_id = False if data.get("contest_id"): try: contest = Contest.objects.get(id=data["contest_id"]) @@ -56,9 +53,39 @@ class SubmissionAPI(APIView): return self.error("Contest doesn't exist.") if contest.status == ContestStatus.CONTEST_ENDED: return self.error("The contest have ended") - if contest.status == ContestStatus.CONTEST_NOT_START and request.user != contest.created_by: + if contest.status == ContestStatus.CONTEST_NOT_START and not contest.is_contest_admin(request.user): return self.error("Contest have not started") - return _submit(self, request.user, data["problem_id"], data["language"], data["code"], data.get("contest_id")) + if not contest.problem_details_permission(): + hide_id = True + + if data.get("captcha"): + if not Captcha(request).check(data["captcha"]): + return self.error("Invalid captcha") + + error = self.throttling(request) + if error: + return self.error(error) + + try: + problem = Problem.objects.get(id=data["problem_id"], + contest_id=data.get("contest_id"), + visible=True) + except Problem.DoesNotExist: + return self.error("Problem not exist") + + submission = Submission.objects.create(user_id=request.user.id, + username=request.user.username, + language=data["language"], + code=data["code"], + problem_id=problem.id, + contest_id=data.get("contest_id")) + # use this for debug + # JudgeDispatcher(submission.id, problem.id).judge() + judge_task.delay(submission.id, problem.id) + if hide_id: + return self.success() + else: + return self.success({"submission_id": submission.id}) @login_required def get(self, request): @@ -123,15 +150,12 @@ class SubmissionListAPI(APIView): class ContestSubmissionListAPI(APIView): - @check_contest_permission + @check_contest_permission(check_type="submissions") def get(self, request): if not request.GET.get("limit"): return self.error("Limit is needed") contest = self.contest - if not contest.check_oi_permission(request.user): - return self.error("No permission for OI contest submissions") - submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by") problem_id = request.GET.get("problem_id") myself = request.GET.get("myself") diff --git a/utils/api/api.py b/utils/api/api.py index f49b039a..09671d27 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -107,18 +107,12 @@ class APIView(View): :param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片 :return: """ - need_paginate = request.GET.get("limit", None) - if need_paginate is None: - if object_serializer: - return object_serializer(query_set, many=True).data - else: - return {"results": query_set, "total": query_set.count()} try: - limit = int(request.GET.get("limit", "100")) + limit = int(request.GET.get("limit", "10")) except ValueError: - limit = 100 - if limit < 0: - limit = 100 + limit = 10 + if limit < 0 or limit > 100: + limit = 10 try: offset = int(request.GET.get("offset", "0")) except ValueError: diff --git a/utils/api/tests.py b/utils/api/tests.py index 4b485c9c..b47ceae9 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -27,8 +27,8 @@ class APITestCase(TestCase): return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, problem_permission=ProblemPermission.ALL, login=login) - def reverse(self, url_name): - return reverse(url_name) + def reverse(self, url_name, *args, **kwargs): + return reverse(url_name, *args, **kwargs) def assertSuccess(self, response): if not response.data["error"] is None: diff --git a/utils/throttling.py b/utils/throttling.py index c1fe8dce..7c5f54a9 100644 --- a/utils/throttling.py +++ b/utils/throttling.py @@ -31,11 +31,10 @@ class TokenBucket: class BucketController: - def __init__(self, user_id, redis_conn, default_capacity): - self.user_id = user_id + def __init__(self, factor, redis_conn, default_capacity): self.default_capacity = default_capacity self.redis = redis_conn - self.key = "bucket_" + str(self.user_id) + self.key = "bucket_" + str(factor) @property def last_capacity(self):