diff --git a/.flake8 b/.flake8 index 3198bbc9..d64dbd5c 100644 --- a/.flake8 +++ b/.flake8 @@ -3,6 +3,7 @@ exclude = xss_filter.py, */migrations/, *settings.py + */apps.py max-line-length = 180 inline-quotes = " no-accept-encodings = True diff --git a/account/__init__.py b/account/__init__.py index d35b9f07..28aa7693 100644 --- a/account/__init__.py +++ b/account/__init__.py @@ -1 +1 @@ -default_app_config = 'account.apps.ProfilesConfig' \ No newline at end of file +default_app_config = "account.apps.ProfilesConfig" diff --git a/account/signals.py b/account/signals.py index 4d14970e..0a15370c 100644 --- a/account/signals.py +++ b/account/signals.py @@ -5,8 +5,8 @@ from django.contrib.auth.signals import user_logged_in, user_logged_out @receiver(user_logged_in) def add_user_session(sender, request, user, **kwargs): - request.session["ip"] = request.META.get('REMOTE_ADDR', '') - request.session["user_agent"] = request.META.get('HTTP_USER_AGENT', '') + request.session["ip"] = request.META.get("REMOTE_ADDR", "") + request.session["user_agent"] = request.META.get("HTTP_USER_AGENT", "") request.session["last_login"] = now() if request.session.session_key not in user.session_keys: user.session_keys.append(request.session.session_key) diff --git a/account/tests.py b/account/tests.py index e37867dd..34da1066 100644 --- a/account/tests.py +++ b/account/tests.py @@ -1,7 +1,9 @@ import time from unittest import mock +from datetime import timedelta from django.contrib import auth +from django.utils.timezone import now from otpauth import OtpAuth from utils.api.tests import APIClient, APITestCase @@ -28,6 +30,40 @@ class PermissionDecoratorTest(APITestCase): pass +class DuplicateUserCheckAPITest(APITestCase): + def setUp(self): + self.create_user("test", "test123", login=False) + self.url = self.reverse("check_username_or_email") + + def test_duplicate_username(self): + resp = self.client.post(self.url, data={"username": "test"}) + data = resp.data["data"] + self.assertEqual(data["username"], True) + + def test_ok_username(self): + resp = self.client.post(self.url, data={"username": "test1"}) + data = resp.data["data"] + self.assertEqual(data["username"], False) + + +class TFARequiredCheckAPITest(APITestCase): + def setUp(self): + self.url = self.reverse("tfa_required_check") + self.create_user("test", "test123", login=False) + + def test_not_required_tfa(self): + resp = self.client.post(self.url, data={"username": "test"}) + self.assertSuccess(resp) + self.assertEqual(resp.data["data"]["result"], False) + + def test_required_tfa(self): + user = User.objects.first() + user.two_factor_auth = True + user.save() + resp = self.client.post(self.url, data={"username": "test"}) + self.assertEqual(resp.data["data"]["result"], True) + + class UserLoginAPITest(APITestCase): def setUp(self): self.username = self.password = "test" @@ -87,7 +123,7 @@ class UserLoginAPITest(APITestCase): response = self.client.post(self.login_url, data={"username": self.username, "password": self.password}) - self.assertDictEqual(response.data, {"error": None, "data": "tfa_required"}) + self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"}) user = auth.get_user(self.client) self.assertFalse(user.is_authenticated()) @@ -142,6 +178,160 @@ class UserRegisterAPITest(CaptchaTest): self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"}) +class SessionManagementAPITest(APITestCase): + def setUp(self): + self.create_user("test", "test123") + self.url = self.reverse("session_management_api") + + def test_get_sessions(self): + resp = self.client.get(self.url) + self.assertSuccess(resp) + data = resp.data["data"] + self.assertEqual(len(data), 1) + + def test_delete_session_key(self): + # resp = self.client.delete(self.url, data={"session_key": self.client.session.session_key}) + resp = self.client.delete(self.url + "?session_key=" + self.client.session.session_key) + self.assertSuccess(resp) + + def test_delete_session_with_invalid_key(self): + resp = self.client.delete(self.url + "?session_key=aaaaaaaaaa") + self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid session_key"}) + + +class UserProfileAPITest(APITestCase): + def setUp(self): + self.url = self.reverse("user_profile_api") + + def test_get_profile_without_login(self): + resp = self.client.get(self.url) + self.assertDictEqual(resp.data, {"error": None, "data": 0}) + + def test_get_profile(self): + self.create_user("test", "test123") + resp = self.client.get(self.url) + self.assertSuccess(resp) + + def test_update_profile(self): + self.create_user("test", "test123") + update_data = {"real_name": "zemal", "submission_number": 233} + resp = self.client.put(self.url, data=update_data) + self.assertSuccess(resp) + data = resp.data["data"] + self.assertEqual(data["real_name"], "zemal") + self.assertEqual(data["submission_number"], 0) + + +class TwoFactorAuthAPITest(APITestCase): + def setUp(self): + self.url = self.reverse("two_factor_auth_api") + self.create_user("test", "test123") + self.create_website_config() + + def _get_tfa_code(self): + user = User.objects.first() + code = OtpAuth(user.tfa_token).totp() + if len(str(code)) < 6: + code = (6 - len(str(code))) * "0" + str(code) + return code + + def test_get_image(self): + resp = self.client.get(self.url) + self.assertSuccess(resp) + + def test_open_tfa_with_invalid_code(self): + self.test_get_image() + resp = self.client.post(self.url, data={"code": "000000"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"}) + + def test_open_tfa_with_correct_code(self): + self.test_get_image() + code = self._get_tfa_code() + resp = self.client.post(self.url, data={"code": code}) + self.assertSuccess(resp) + user = User.objects.first() + self.assertEqual(user.two_factor_auth, True) + + def test_close_tfa_with_invalid_code(self): + self.test_open_tfa_with_correct_code() + resp = self.client.post(self.url, data={"code": "000000"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"}) + + def test_close_tfa_with_correct_code(self): + self.test_open_tfa_with_correct_code() + code = self._get_tfa_code() + resp = self.client.put(self.url, data={"code": code}) + self.assertSuccess(resp) + user = User.objects.first() + self.assertEqual(user.two_factor_auth, False) + + +@mock.patch("account.views.oj.send_email_async.delay") +class ApplyResetPasswordAPITest(CaptchaTest): + def setUp(self): + self.create_user("test", "test123", login=False) + user = User.objects.first() + user.email = "test@oj.com" + user.save() + self.url = self.reverse("apply_reset_password_api") + self.create_website_config() + self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)} + + def _refresh_captcha(self): + self.data["captcha"] = self._set_captcha(self.client.session) + + def test_apply_reset_password(self, send_email_delay): + resp = self.client.post(self.url, data=self.data) + self.assertSuccess(resp) + send_email_delay.assert_called() + + def test_apply_reset_password_twice_in_20_mins(self, send_email_delay): + self.test_apply_reset_password() + send_email_delay.reset_mock() + self._refresh_captcha() + resp = self.client.post(self.url, data=self.data) + self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"}) + send_email_delay.assert_not_called() + + def test_apply_reset_password_again_after_20_mins(self, send_email_delay): + self.test_apply_reset_password() + user = User.objects.first() + user.reset_password_token_expire_time = now() - timedelta(minutes=21) + user.save() + self._refresh_captcha() + self.test_apply_reset_password() + + +class ResetPasswordAPITest(CaptchaTest): + def setUp(self): + self.create_user("test", "test123", login=False) + self.url = self.reverse("reset_password_api") + user = User.objects.first() + user.reset_password_token = "online_judge?" + user.reset_password_token_expire_time = now() + timedelta(minutes=20) + user.save() + self.data = {"token": user.reset_password_token, + "captcha": self._set_captcha(self.client.session), + "password": "test456"} + + def test_reset_password_with_correct_token(self): + resp = self.client.post(self.url, data=self.data) + self.assertSuccess(resp) + self.assertTrue(self.client.login(username="test", password="test456")) + + def test_reset_password_with_invalid_token(self): + self.data["token"] = "aaaaaaaaaaa" + resp = self.client.post(self.url, data=self.data) + self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"}) + + def test_reset_password_with_expired_token(self): + user = User.objects.first() + user.reset_password_token_expire_time = now() - timedelta(seconds=30) + user.save() + resp = self.client.post(self.url, data=self.data) + self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"}) + + class UserChangePasswordAPITest(CaptchaTest): def setUp(self): self.client = APIClient() @@ -248,3 +438,37 @@ class AdminUserTest(APITestCase): # if `openapi_app_key` is not None, the value is not changed self.assertTrue(resp_data["open_api"]) self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key) + + +class UserRankAPITest(APITestCase): + def setUp(self): + self.url = self.reverse("user_rank_api") + self.create_user("test1", "test123", login=False) + self.create_user("test2", "test123", login=False) + test1 = User.objects.get(username="test1") + profile1 = test1.userprofile + profile1.submission_number = 10 + profile1.accepted_number = 10 + profile1.total_score = 240 + profile1.save() + + test2 = User.objects.get(username="test2") + profile2 = test2.userprofile + profile2.submission_number = 15 + profile2.accepted_number = 10 + profile2.total_score = 700 + profile2.save() + + def test_get_acm_rank(self): + resp = self.client.get(self.url, data={"rule": "acm"}) + self.assertSuccess(resp) + data = resp.data["data"] + self.assertEqual(data[0]["user"]["username"], "test1") + self.assertEqual(data[1]["user"]["username"], "test2") + + def test_get_oi_rank(self): + resp = self.client.get(self.url, data={"rule": "oi"}) + self.assertSuccess(resp) + data = resp.data["data"] + self.assertEqual(data[0]["user"]["username"], "test2") + self.assertEqual(data[1]["user"]["username"], "test1") diff --git a/account/urls/oj.py b/account/urls/oj.py index e31a81ec..a9bf6423 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -4,7 +4,7 @@ from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, UserChangePasswordAPI, UserRegisterAPI, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, - UserRankAPI) + UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI) from utils.captcha.views import CaptchaAPIView @@ -14,12 +14,14 @@ urlpatterns = [ url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"), url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"), url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"), - url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="apply_reset_password_api"), + url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"), url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"), url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"), url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"), url(r"^avatar/upload/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"), url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"), + url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"), url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"), url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"), + url(r"^sessions/?$", SessionManagementAPI.as_view(), name="session_management_api") ] diff --git a/account/views/oj.py b/account/views/oj.py index bc00ea11..e88ce9ed 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -14,7 +14,7 @@ from django.template.loader import render_to_string from conf.models import WebsiteConfig from utils.api import APIView, validate_serializer, CSRFExemptAPIView from utils.captcha import Captcha -from utils.shortcuts import rand_str, img2base64 +from utils.shortcuts import rand_str, img2base64, datetime2str from ..decorators import login_required from ..models import User, UserProfile @@ -77,7 +77,6 @@ class AvatarUploadAPI(CSRFExemptAPIView): with open(os.path.join(settings.IMAGE_UPLOAD_DIR, name), "wb") as img: for chunk in avatar: img.write(chunk) - print(os.path.join(settings.IMAGE_UPLOAD_DIR, name)) return self.success({"path": "/static/upload/" + name}) @@ -126,7 +125,7 @@ class TwoFactorAuthAPI(APIView): user.save() config = WebsiteConfig.objects.first() - label = f"{config.name_shortcut}:{user.username}@{config.base_url}" + label = f"{config.name_shortcut}:{user.username}" image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name)) return self.success(img2base64(image)) @@ -143,18 +142,38 @@ class TwoFactorAuthAPI(APIView): user.save() return self.success("Succeeded") else: - return self.error("Invalid captcha") + return self.error("Invalid code") @login_required @validate_serializer(TwoFactorAuthCodeSerializer) def put(self, request): code = request.data["code"] user = request.user + if not user.two_factor_auth: + return self.error("Other session have disabled TFA") if OtpAuth(user.tfa_token).valid_totp(code): user.two_factor_auth = False user.save() + return self.success("Succeeded") else: - return self.error("Invalid captcha") + return self.error("Invalid code") + + +class CheckTFARequiredAPI(APIView): + @validate_serializer(UsernameOrEmailCheckSerializer) + def post(self, request): + """ + Check TFA is required + """ + data = request.data + result = False + if data.get("username"): + try: + user = User.objects.get(username=data["username"]) + result = user.two_factor_auth + except User.DoesNotExist: + pass + return self.success({"result": result}) class UserLoginAPI(APIView): @@ -173,7 +192,7 @@ class UserLoginAPI(APIView): # `tfa_code` not in post data if user.two_factor_auth and "tfa_code" not in data: - return self.success("tfa_required") + return self.error("tfa_required") if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): auth.login(request, user) @@ -302,11 +321,55 @@ class ResetPasswordAPI(APIView): if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0: return self.error("Token have expired") user.reset_password_token = None + user.two_factor_auth = False user.set_password(data["password"]) user.save() return self.success("Succeeded") +class SessionManagementAPI(APIView): + @login_required + def get(self, request): + engine = import_module(settings.SESSION_ENGINE) + SessionStore = engine.SessionStore + current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME) + session_keys = request.user.session_keys + result = [] + modified = False + for key in session_keys[:]: + session = SessionStore(key) + # session does not exist or is expiry + if not session._session: + session_keys.remove(key) + modified = True + continue + + s = {} + if current_session == key: + s["current_session"] = True + s["ip"] = session["ip"] + s["user_agent"] = session["user_agent"] + s["last_login"] = datetime2str(session["last_login"]) + s["session_key"] = key + result.append(s) + if modified: + request.user.save() + return self.success(result) + + @login_required + def delete(self, request): + session_key = request.GET.get("session_key") + if not session_key: + return self.error("Parameter Error") + request.session.delete(session_key) + if session_key in request.user.session_keys: + request.user.session_keys.remove(session_key) + request.user.save() + return self.success("Succeeded") + else: + return self.error("Invalid session_key") + + class UserRankAPI(APIView): def get(self, request): rule_type = request.GET.get("rule") diff --git a/utils/api/_serializers.py b/utils/api/_serializers.py index 816845af..737a9656 100644 --- a/utils/api/_serializers.py +++ b/utils/api/_serializers.py @@ -1,11 +1,9 @@ -from django.utils import timezone from rest_framework import serializers class DateTimeTZField(serializers.DateTimeField): def to_representation(self, value): - # self.format = "%Y-%m-%d %H:%M:%S %Z" - value = timezone.localtime(value) + # value = timezone.localtime(value) return super(DateTimeTZField, self).to_representation(value) diff --git a/utils/api/tests.py b/utils/api/tests.py index 9f9d997c..3d9cc306 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -3,12 +3,14 @@ from django.test.testcases import TestCase from rest_framework.test import APIClient from account.models import AdminType, ProblemPermission, User, UserProfile +from conf.models import WebsiteConfig class APITestCase(TestCase): client_class = APIClient - def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, problem_permission=ProblemPermission.NONE): + def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, + problem_permission=ProblemPermission.NONE): user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission) user.set_password(password) UserProfile.objects.create(user=user) @@ -18,13 +20,17 @@ class APITestCase(TestCase): return user def create_admin(self, username="admin", password="admin", login=True): - return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, problem_permission=ProblemPermission.OWN, + return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, + problem_permission=ProblemPermission.OWN, login=login) def create_super_admin(self, username="root", password="root", login=True): return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, problem_permission=ProblemPermission.ALL, login=login) + def create_website_config(self): + return WebsiteConfig.objects.create() + def reverse(self, url_name): return reverse(url_name) diff --git a/utils/shortcuts.py b/utils/shortcuts.py index e2cc4707..01ddd801 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -69,3 +69,12 @@ def img2base64(img): img_prefix = "data:image/png;base64," b64_str = img_prefix + b64encode(buf_str).decode("utf-8") return b64_str + + +def datetime2str(value, format="iso-8601"): + if format.lower() == "iso-8601": + value = value.isoformat() + if value.endswith("+00:00"): + value = value[:-6] + "Z" + return value + return value.strftime(format)