mirror of
https://github.com/QingdaoU/OnlineJudge.git
synced 2024-12-28 16:12:13 +00:00
add session management api; add more unit tests for account module
This commit is contained in:
parent
a3ca8b2336
commit
1ee0596a3a
1
.flake8
1
.flake8
@ -3,6 +3,7 @@ exclude =
|
||||
xss_filter.py,
|
||||
*/migrations/,
|
||||
*settings.py
|
||||
*/apps.py
|
||||
max-line-length = 180
|
||||
inline-quotes = "
|
||||
no-accept-encodings = True
|
||||
|
@ -1 +1 @@
|
||||
default_app_config = 'account.apps.ProfilesConfig'
|
||||
default_app_config = "account.apps.ProfilesConfig"
|
||||
|
@ -5,8 +5,8 @@ from django.contrib.auth.signals import user_logged_in, user_logged_out
|
||||
|
||||
@receiver(user_logged_in)
|
||||
def add_user_session(sender, request, user, **kwargs):
|
||||
request.session["ip"] = request.META.get('REMOTE_ADDR', '')
|
||||
request.session["user_agent"] = request.META.get('HTTP_USER_AGENT', '')
|
||||
request.session["ip"] = request.META.get("REMOTE_ADDR", "")
|
||||
request.session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
|
||||
request.session["last_login"] = now()
|
||||
if request.session.session_key not in user.session_keys:
|
||||
user.session_keys.append(request.session.session_key)
|
||||
|
226
account/tests.py
226
account/tests.py
@ -1,7 +1,9 @@
|
||||
import time
|
||||
from unittest import mock
|
||||
from datetime import timedelta
|
||||
|
||||
from django.contrib import auth
|
||||
from django.utils.timezone import now
|
||||
from otpauth import OtpAuth
|
||||
|
||||
from utils.api.tests import APIClient, APITestCase
|
||||
@ -28,6 +30,40 @@ class PermissionDecoratorTest(APITestCase):
|
||||
pass
|
||||
|
||||
|
||||
class DuplicateUserCheckAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.create_user("test", "test123", login=False)
|
||||
self.url = self.reverse("check_username_or_email")
|
||||
|
||||
def test_duplicate_username(self):
|
||||
resp = self.client.post(self.url, data={"username": "test"})
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(data["username"], True)
|
||||
|
||||
def test_ok_username(self):
|
||||
resp = self.client.post(self.url, data={"username": "test1"})
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(data["username"], False)
|
||||
|
||||
|
||||
class TFARequiredCheckAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("tfa_required_check")
|
||||
self.create_user("test", "test123", login=False)
|
||||
|
||||
def test_not_required_tfa(self):
|
||||
resp = self.client.post(self.url, data={"username": "test"})
|
||||
self.assertSuccess(resp)
|
||||
self.assertEqual(resp.data["data"]["result"], False)
|
||||
|
||||
def test_required_tfa(self):
|
||||
user = User.objects.first()
|
||||
user.two_factor_auth = True
|
||||
user.save()
|
||||
resp = self.client.post(self.url, data={"username": "test"})
|
||||
self.assertEqual(resp.data["data"]["result"], True)
|
||||
|
||||
|
||||
class UserLoginAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.username = self.password = "test"
|
||||
@ -87,7 +123,7 @@ class UserLoginAPITest(APITestCase):
|
||||
response = self.client.post(self.login_url,
|
||||
data={"username": self.username,
|
||||
"password": self.password})
|
||||
self.assertDictEqual(response.data, {"error": None, "data": "tfa_required"})
|
||||
self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"})
|
||||
|
||||
user = auth.get_user(self.client)
|
||||
self.assertFalse(user.is_authenticated())
|
||||
@ -142,6 +178,160 @@ class UserRegisterAPITest(CaptchaTest):
|
||||
self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"})
|
||||
|
||||
|
||||
class SessionManagementAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.create_user("test", "test123")
|
||||
self.url = self.reverse("session_management_api")
|
||||
|
||||
def test_get_sessions(self):
|
||||
resp = self.client.get(self.url)
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(len(data), 1)
|
||||
|
||||
def test_delete_session_key(self):
|
||||
# resp = self.client.delete(self.url, data={"session_key": self.client.session.session_key})
|
||||
resp = self.client.delete(self.url + "?session_key=" + self.client.session.session_key)
|
||||
self.assertSuccess(resp)
|
||||
|
||||
def test_delete_session_with_invalid_key(self):
|
||||
resp = self.client.delete(self.url + "?session_key=aaaaaaaaaa")
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid session_key"})
|
||||
|
||||
|
||||
class UserProfileAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("user_profile_api")
|
||||
|
||||
def test_get_profile_without_login(self):
|
||||
resp = self.client.get(self.url)
|
||||
self.assertDictEqual(resp.data, {"error": None, "data": 0})
|
||||
|
||||
def test_get_profile(self):
|
||||
self.create_user("test", "test123")
|
||||
resp = self.client.get(self.url)
|
||||
self.assertSuccess(resp)
|
||||
|
||||
def test_update_profile(self):
|
||||
self.create_user("test", "test123")
|
||||
update_data = {"real_name": "zemal", "submission_number": 233}
|
||||
resp = self.client.put(self.url, data=update_data)
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(data["real_name"], "zemal")
|
||||
self.assertEqual(data["submission_number"], 0)
|
||||
|
||||
|
||||
class TwoFactorAuthAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("two_factor_auth_api")
|
||||
self.create_user("test", "test123")
|
||||
self.create_website_config()
|
||||
|
||||
def _get_tfa_code(self):
|
||||
user = User.objects.first()
|
||||
code = OtpAuth(user.tfa_token).totp()
|
||||
if len(str(code)) < 6:
|
||||
code = (6 - len(str(code))) * "0" + str(code)
|
||||
return code
|
||||
|
||||
def test_get_image(self):
|
||||
resp = self.client.get(self.url)
|
||||
self.assertSuccess(resp)
|
||||
|
||||
def test_open_tfa_with_invalid_code(self):
|
||||
self.test_get_image()
|
||||
resp = self.client.post(self.url, data={"code": "000000"})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
|
||||
|
||||
def test_open_tfa_with_correct_code(self):
|
||||
self.test_get_image()
|
||||
code = self._get_tfa_code()
|
||||
resp = self.client.post(self.url, data={"code": code})
|
||||
self.assertSuccess(resp)
|
||||
user = User.objects.first()
|
||||
self.assertEqual(user.two_factor_auth, True)
|
||||
|
||||
def test_close_tfa_with_invalid_code(self):
|
||||
self.test_open_tfa_with_correct_code()
|
||||
resp = self.client.post(self.url, data={"code": "000000"})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
|
||||
|
||||
def test_close_tfa_with_correct_code(self):
|
||||
self.test_open_tfa_with_correct_code()
|
||||
code = self._get_tfa_code()
|
||||
resp = self.client.put(self.url, data={"code": code})
|
||||
self.assertSuccess(resp)
|
||||
user = User.objects.first()
|
||||
self.assertEqual(user.two_factor_auth, False)
|
||||
|
||||
|
||||
@mock.patch("account.views.oj.send_email_async.delay")
|
||||
class ApplyResetPasswordAPITest(CaptchaTest):
|
||||
def setUp(self):
|
||||
self.create_user("test", "test123", login=False)
|
||||
user = User.objects.first()
|
||||
user.email = "test@oj.com"
|
||||
user.save()
|
||||
self.url = self.reverse("apply_reset_password_api")
|
||||
self.create_website_config()
|
||||
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
|
||||
|
||||
def _refresh_captcha(self):
|
||||
self.data["captcha"] = self._set_captcha(self.client.session)
|
||||
|
||||
def test_apply_reset_password(self, send_email_delay):
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
self.assertSuccess(resp)
|
||||
send_email_delay.assert_called()
|
||||
|
||||
def test_apply_reset_password_twice_in_20_mins(self, send_email_delay):
|
||||
self.test_apply_reset_password()
|
||||
send_email_delay.reset_mock()
|
||||
self._refresh_captcha()
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"})
|
||||
send_email_delay.assert_not_called()
|
||||
|
||||
def test_apply_reset_password_again_after_20_mins(self, send_email_delay):
|
||||
self.test_apply_reset_password()
|
||||
user = User.objects.first()
|
||||
user.reset_password_token_expire_time = now() - timedelta(minutes=21)
|
||||
user.save()
|
||||
self._refresh_captcha()
|
||||
self.test_apply_reset_password()
|
||||
|
||||
|
||||
class ResetPasswordAPITest(CaptchaTest):
|
||||
def setUp(self):
|
||||
self.create_user("test", "test123", login=False)
|
||||
self.url = self.reverse("reset_password_api")
|
||||
user = User.objects.first()
|
||||
user.reset_password_token = "online_judge?"
|
||||
user.reset_password_token_expire_time = now() + timedelta(minutes=20)
|
||||
user.save()
|
||||
self.data = {"token": user.reset_password_token,
|
||||
"captcha": self._set_captcha(self.client.session),
|
||||
"password": "test456"}
|
||||
|
||||
def test_reset_password_with_correct_token(self):
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
self.assertSuccess(resp)
|
||||
self.assertTrue(self.client.login(username="test", password="test456"))
|
||||
|
||||
def test_reset_password_with_invalid_token(self):
|
||||
self.data["token"] = "aaaaaaaaaaa"
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"})
|
||||
|
||||
def test_reset_password_with_expired_token(self):
|
||||
user = User.objects.first()
|
||||
user.reset_password_token_expire_time = now() - timedelta(seconds=30)
|
||||
user.save()
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"})
|
||||
|
||||
|
||||
class UserChangePasswordAPITest(CaptchaTest):
|
||||
def setUp(self):
|
||||
self.client = APIClient()
|
||||
@ -248,3 +438,37 @@ class AdminUserTest(APITestCase):
|
||||
# if `openapi_app_key` is not None, the value is not changed
|
||||
self.assertTrue(resp_data["open_api"])
|
||||
self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key)
|
||||
|
||||
|
||||
class UserRankAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("user_rank_api")
|
||||
self.create_user("test1", "test123", login=False)
|
||||
self.create_user("test2", "test123", login=False)
|
||||
test1 = User.objects.get(username="test1")
|
||||
profile1 = test1.userprofile
|
||||
profile1.submission_number = 10
|
||||
profile1.accepted_number = 10
|
||||
profile1.total_score = 240
|
||||
profile1.save()
|
||||
|
||||
test2 = User.objects.get(username="test2")
|
||||
profile2 = test2.userprofile
|
||||
profile2.submission_number = 15
|
||||
profile2.accepted_number = 10
|
||||
profile2.total_score = 700
|
||||
profile2.save()
|
||||
|
||||
def test_get_acm_rank(self):
|
||||
resp = self.client.get(self.url, data={"rule": "acm"})
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(data[0]["user"]["username"], "test1")
|
||||
self.assertEqual(data[1]["user"]["username"], "test2")
|
||||
|
||||
def test_get_oi_rank(self):
|
||||
resp = self.client.get(self.url, data={"rule": "oi"})
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(data[0]["user"]["username"], "test2")
|
||||
self.assertEqual(data[1]["user"]["username"], "test1")
|
||||
|
@ -4,7 +4,7 @@ from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
|
||||
UserChangePasswordAPI, UserRegisterAPI,
|
||||
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
|
||||
SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
|
||||
UserRankAPI)
|
||||
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
|
||||
|
||||
from utils.captcha.views import CaptchaAPIView
|
||||
|
||||
@ -14,12 +14,14 @@ urlpatterns = [
|
||||
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
|
||||
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
|
||||
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
|
||||
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="apply_reset_password_api"),
|
||||
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
|
||||
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
|
||||
url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
|
||||
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
|
||||
url(r"^avatar/upload/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
|
||||
url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
|
||||
url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
|
||||
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
|
||||
url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),
|
||||
url(r"^sessions/?$", SessionManagementAPI.as_view(), name="session_management_api")
|
||||
]
|
||||
|
@ -14,7 +14,7 @@ from django.template.loader import render_to_string
|
||||
from conf.models import WebsiteConfig
|
||||
from utils.api import APIView, validate_serializer, CSRFExemptAPIView
|
||||
from utils.captcha import Captcha
|
||||
from utils.shortcuts import rand_str, img2base64
|
||||
from utils.shortcuts import rand_str, img2base64, datetime2str
|
||||
|
||||
from ..decorators import login_required
|
||||
from ..models import User, UserProfile
|
||||
@ -77,7 +77,6 @@ class AvatarUploadAPI(CSRFExemptAPIView):
|
||||
with open(os.path.join(settings.IMAGE_UPLOAD_DIR, name), "wb") as img:
|
||||
for chunk in avatar:
|
||||
img.write(chunk)
|
||||
print(os.path.join(settings.IMAGE_UPLOAD_DIR, name))
|
||||
return self.success({"path": "/static/upload/" + name})
|
||||
|
||||
|
||||
@ -126,7 +125,7 @@ class TwoFactorAuthAPI(APIView):
|
||||
user.save()
|
||||
|
||||
config = WebsiteConfig.objects.first()
|
||||
label = f"{config.name_shortcut}:{user.username}@{config.base_url}"
|
||||
label = f"{config.name_shortcut}:{user.username}"
|
||||
image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name))
|
||||
return self.success(img2base64(image))
|
||||
|
||||
@ -143,18 +142,38 @@ class TwoFactorAuthAPI(APIView):
|
||||
user.save()
|
||||
return self.success("Succeeded")
|
||||
else:
|
||||
return self.error("Invalid captcha")
|
||||
return self.error("Invalid code")
|
||||
|
||||
@login_required
|
||||
@validate_serializer(TwoFactorAuthCodeSerializer)
|
||||
def put(self, request):
|
||||
code = request.data["code"]
|
||||
user = request.user
|
||||
if not user.two_factor_auth:
|
||||
return self.error("Other session have disabled TFA")
|
||||
if OtpAuth(user.tfa_token).valid_totp(code):
|
||||
user.two_factor_auth = False
|
||||
user.save()
|
||||
return self.success("Succeeded")
|
||||
else:
|
||||
return self.error("Invalid captcha")
|
||||
return self.error("Invalid code")
|
||||
|
||||
|
||||
class CheckTFARequiredAPI(APIView):
|
||||
@validate_serializer(UsernameOrEmailCheckSerializer)
|
||||
def post(self, request):
|
||||
"""
|
||||
Check TFA is required
|
||||
"""
|
||||
data = request.data
|
||||
result = False
|
||||
if data.get("username"):
|
||||
try:
|
||||
user = User.objects.get(username=data["username"])
|
||||
result = user.two_factor_auth
|
||||
except User.DoesNotExist:
|
||||
pass
|
||||
return self.success({"result": result})
|
||||
|
||||
|
||||
class UserLoginAPI(APIView):
|
||||
@ -173,7 +192,7 @@ class UserLoginAPI(APIView):
|
||||
|
||||
# `tfa_code` not in post data
|
||||
if user.two_factor_auth and "tfa_code" not in data:
|
||||
return self.success("tfa_required")
|
||||
return self.error("tfa_required")
|
||||
|
||||
if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
|
||||
auth.login(request, user)
|
||||
@ -302,11 +321,55 @@ class ResetPasswordAPI(APIView):
|
||||
if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0:
|
||||
return self.error("Token have expired")
|
||||
user.reset_password_token = None
|
||||
user.two_factor_auth = False
|
||||
user.set_password(data["password"])
|
||||
user.save()
|
||||
return self.success("Succeeded")
|
||||
|
||||
|
||||
class SessionManagementAPI(APIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
SessionStore = engine.SessionStore
|
||||
current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
|
||||
session_keys = request.user.session_keys
|
||||
result = []
|
||||
modified = False
|
||||
for key in session_keys[:]:
|
||||
session = SessionStore(key)
|
||||
# session does not exist or is expiry
|
||||
if not session._session:
|
||||
session_keys.remove(key)
|
||||
modified = True
|
||||
continue
|
||||
|
||||
s = {}
|
||||
if current_session == key:
|
||||
s["current_session"] = True
|
||||
s["ip"] = session["ip"]
|
||||
s["user_agent"] = session["user_agent"]
|
||||
s["last_login"] = datetime2str(session["last_login"])
|
||||
s["session_key"] = key
|
||||
result.append(s)
|
||||
if modified:
|
||||
request.user.save()
|
||||
return self.success(result)
|
||||
|
||||
@login_required
|
||||
def delete(self, request):
|
||||
session_key = request.GET.get("session_key")
|
||||
if not session_key:
|
||||
return self.error("Parameter Error")
|
||||
request.session.delete(session_key)
|
||||
if session_key in request.user.session_keys:
|
||||
request.user.session_keys.remove(session_key)
|
||||
request.user.save()
|
||||
return self.success("Succeeded")
|
||||
else:
|
||||
return self.error("Invalid session_key")
|
||||
|
||||
|
||||
class UserRankAPI(APIView):
|
||||
def get(self, request):
|
||||
rule_type = request.GET.get("rule")
|
||||
|
@ -1,11 +1,9 @@
|
||||
from django.utils import timezone
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class DateTimeTZField(serializers.DateTimeField):
|
||||
def to_representation(self, value):
|
||||
# self.format = "%Y-%m-%d %H:%M:%S %Z"
|
||||
value = timezone.localtime(value)
|
||||
# value = timezone.localtime(value)
|
||||
return super(DateTimeTZField, self).to_representation(value)
|
||||
|
||||
|
||||
|
@ -3,12 +3,14 @@ from django.test.testcases import TestCase
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from account.models import AdminType, ProblemPermission, User, UserProfile
|
||||
from conf.models import WebsiteConfig
|
||||
|
||||
|
||||
class APITestCase(TestCase):
|
||||
client_class = APIClient
|
||||
|
||||
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, problem_permission=ProblemPermission.NONE):
|
||||
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True,
|
||||
problem_permission=ProblemPermission.NONE):
|
||||
user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission)
|
||||
user.set_password(password)
|
||||
UserProfile.objects.create(user=user)
|
||||
@ -18,13 +20,17 @@ class APITestCase(TestCase):
|
||||
return user
|
||||
|
||||
def create_admin(self, username="admin", password="admin", login=True):
|
||||
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, problem_permission=ProblemPermission.OWN,
|
||||
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN,
|
||||
problem_permission=ProblemPermission.OWN,
|
||||
login=login)
|
||||
|
||||
def create_super_admin(self, username="root", password="root", login=True):
|
||||
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
|
||||
problem_permission=ProblemPermission.ALL, login=login)
|
||||
|
||||
def create_website_config(self):
|
||||
return WebsiteConfig.objects.create()
|
||||
|
||||
def reverse(self, url_name):
|
||||
return reverse(url_name)
|
||||
|
||||
|
@ -69,3 +69,12 @@ def img2base64(img):
|
||||
img_prefix = "data:image/png;base64,"
|
||||
b64_str = img_prefix + b64encode(buf_str).decode("utf-8")
|
||||
return b64_str
|
||||
|
||||
|
||||
def datetime2str(value, format="iso-8601"):
|
||||
if format.lower() == "iso-8601":
|
||||
value = value.isoformat()
|
||||
if value.endswith("+00:00"):
|
||||
value = value[:-6] + "Z"
|
||||
return value
|
||||
return value.strftime(format)
|
||||
|
Loading…
Reference in New Issue
Block a user