add session management api; add more unit tests for account module

This commit is contained in:
zema1 2017-09-16 10:38:49 +08:00
parent a3ca8b2336
commit 1ee0596a3a
9 changed files with 320 additions and 17 deletions

View File

@ -3,6 +3,7 @@ exclude =
xss_filter.py,
*/migrations/,
*settings.py
*/apps.py
max-line-length = 180
inline-quotes = "
no-accept-encodings = True

View File

@ -1 +1 @@
default_app_config = 'account.apps.ProfilesConfig'
default_app_config = "account.apps.ProfilesConfig"

View File

@ -5,8 +5,8 @@ from django.contrib.auth.signals import user_logged_in, user_logged_out
@receiver(user_logged_in)
def add_user_session(sender, request, user, **kwargs):
request.session["ip"] = request.META.get('REMOTE_ADDR', '')
request.session["user_agent"] = request.META.get('HTTP_USER_AGENT', '')
request.session["ip"] = request.META.get("REMOTE_ADDR", "")
request.session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
request.session["last_login"] = now()
if request.session.session_key not in user.session_keys:
user.session_keys.append(request.session.session_key)

View File

@ -1,7 +1,9 @@
import time
from unittest import mock
from datetime import timedelta
from django.contrib import auth
from django.utils.timezone import now
from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase
@ -28,6 +30,40 @@ class PermissionDecoratorTest(APITestCase):
pass
class DuplicateUserCheckAPITest(APITestCase):
def setUp(self):
self.create_user("test", "test123", login=False)
self.url = self.reverse("check_username_or_email")
def test_duplicate_username(self):
resp = self.client.post(self.url, data={"username": "test"})
data = resp.data["data"]
self.assertEqual(data["username"], True)
def test_ok_username(self):
resp = self.client.post(self.url, data={"username": "test1"})
data = resp.data["data"]
self.assertEqual(data["username"], False)
class TFARequiredCheckAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("tfa_required_check")
self.create_user("test", "test123", login=False)
def test_not_required_tfa(self):
resp = self.client.post(self.url, data={"username": "test"})
self.assertSuccess(resp)
self.assertEqual(resp.data["data"]["result"], False)
def test_required_tfa(self):
user = User.objects.first()
user.two_factor_auth = True
user.save()
resp = self.client.post(self.url, data={"username": "test"})
self.assertEqual(resp.data["data"]["result"], True)
class UserLoginAPITest(APITestCase):
def setUp(self):
self.username = self.password = "test"
@ -87,7 +123,7 @@ class UserLoginAPITest(APITestCase):
response = self.client.post(self.login_url,
data={"username": self.username,
"password": self.password})
self.assertDictEqual(response.data, {"error": None, "data": "tfa_required"})
self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"})
user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated())
@ -142,6 +178,160 @@ class UserRegisterAPITest(CaptchaTest):
self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"})
class SessionManagementAPITest(APITestCase):
def setUp(self):
self.create_user("test", "test123")
self.url = self.reverse("session_management_api")
def test_get_sessions(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(len(data), 1)
def test_delete_session_key(self):
# resp = self.client.delete(self.url, data={"session_key": self.client.session.session_key})
resp = self.client.delete(self.url + "?session_key=" + self.client.session.session_key)
self.assertSuccess(resp)
def test_delete_session_with_invalid_key(self):
resp = self.client.delete(self.url + "?session_key=aaaaaaaaaa")
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid session_key"})
class UserProfileAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("user_profile_api")
def test_get_profile_without_login(self):
resp = self.client.get(self.url)
self.assertDictEqual(resp.data, {"error": None, "data": 0})
def test_get_profile(self):
self.create_user("test", "test123")
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_update_profile(self):
self.create_user("test", "test123")
update_data = {"real_name": "zemal", "submission_number": 233}
resp = self.client.put(self.url, data=update_data)
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["real_name"], "zemal")
self.assertEqual(data["submission_number"], 0)
class TwoFactorAuthAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("two_factor_auth_api")
self.create_user("test", "test123")
self.create_website_config()
def _get_tfa_code(self):
user = User.objects.first()
code = OtpAuth(user.tfa_token).totp()
if len(str(code)) < 6:
code = (6 - len(str(code))) * "0" + str(code)
return code
def test_get_image(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_open_tfa_with_invalid_code(self):
self.test_get_image()
resp = self.client.post(self.url, data={"code": "000000"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
def test_open_tfa_with_correct_code(self):
self.test_get_image()
code = self._get_tfa_code()
resp = self.client.post(self.url, data={"code": code})
self.assertSuccess(resp)
user = User.objects.first()
self.assertEqual(user.two_factor_auth, True)
def test_close_tfa_with_invalid_code(self):
self.test_open_tfa_with_correct_code()
resp = self.client.post(self.url, data={"code": "000000"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
def test_close_tfa_with_correct_code(self):
self.test_open_tfa_with_correct_code()
code = self._get_tfa_code()
resp = self.client.put(self.url, data={"code": code})
self.assertSuccess(resp)
user = User.objects.first()
self.assertEqual(user.two_factor_auth, False)
@mock.patch("account.views.oj.send_email_async.delay")
class ApplyResetPasswordAPITest(CaptchaTest):
def setUp(self):
self.create_user("test", "test123", login=False)
user = User.objects.first()
user.email = "test@oj.com"
user.save()
self.url = self.reverse("apply_reset_password_api")
self.create_website_config()
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
def _refresh_captcha(self):
self.data["captcha"] = self._set_captcha(self.client.session)
def test_apply_reset_password(self, send_email_delay):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
send_email_delay.assert_called()
def test_apply_reset_password_twice_in_20_mins(self, send_email_delay):
self.test_apply_reset_password()
send_email_delay.reset_mock()
self._refresh_captcha()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"})
send_email_delay.assert_not_called()
def test_apply_reset_password_again_after_20_mins(self, send_email_delay):
self.test_apply_reset_password()
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(minutes=21)
user.save()
self._refresh_captcha()
self.test_apply_reset_password()
class ResetPasswordAPITest(CaptchaTest):
def setUp(self):
self.create_user("test", "test123", login=False)
self.url = self.reverse("reset_password_api")
user = User.objects.first()
user.reset_password_token = "online_judge?"
user.reset_password_token_expire_time = now() + timedelta(minutes=20)
user.save()
self.data = {"token": user.reset_password_token,
"captcha": self._set_captcha(self.client.session),
"password": "test456"}
def test_reset_password_with_correct_token(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
self.assertTrue(self.client.login(username="test", password="test456"))
def test_reset_password_with_invalid_token(self):
self.data["token"] = "aaaaaaaaaaa"
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"})
def test_reset_password_with_expired_token(self):
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(seconds=30)
user.save()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"})
class UserChangePasswordAPITest(CaptchaTest):
def setUp(self):
self.client = APIClient()
@ -248,3 +438,37 @@ class AdminUserTest(APITestCase):
# if `openapi_app_key` is not None, the value is not changed
self.assertTrue(resp_data["open_api"])
self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key)
class UserRankAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("user_rank_api")
self.create_user("test1", "test123", login=False)
self.create_user("test2", "test123", login=False)
test1 = User.objects.get(username="test1")
profile1 = test1.userprofile
profile1.submission_number = 10
profile1.accepted_number = 10
profile1.total_score = 240
profile1.save()
test2 = User.objects.get(username="test2")
profile2 = test2.userprofile
profile2.submission_number = 15
profile2.accepted_number = 10
profile2.total_score = 700
profile2.save()
def test_get_acm_rank(self):
resp = self.client.get(self.url, data={"rule": "acm"})
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test1")
self.assertEqual(data[1]["user"]["username"], "test2")
def test_get_oi_rank(self):
resp = self.client.get(self.url, data={"rule": "oi"})
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test2")
self.assertEqual(data[1]["user"]["username"], "test1")

View File

@ -4,7 +4,7 @@ from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
UserChangePasswordAPI, UserRegisterAPI,
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
UserRankAPI)
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
from utils.captcha.views import CaptchaAPIView
@ -14,12 +14,14 @@ urlpatterns = [
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="apply_reset_password_api"),
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
url(r"^avatar/upload/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),
url(r"^sessions/?$", SessionManagementAPI.as_view(), name="session_management_api")
]

View File

@ -14,7 +14,7 @@ from django.template.loader import render_to_string
from conf.models import WebsiteConfig
from utils.api import APIView, validate_serializer, CSRFExemptAPIView
from utils.captcha import Captcha
from utils.shortcuts import rand_str, img2base64
from utils.shortcuts import rand_str, img2base64, datetime2str
from ..decorators import login_required
from ..models import User, UserProfile
@ -77,7 +77,6 @@ class AvatarUploadAPI(CSRFExemptAPIView):
with open(os.path.join(settings.IMAGE_UPLOAD_DIR, name), "wb") as img:
for chunk in avatar:
img.write(chunk)
print(os.path.join(settings.IMAGE_UPLOAD_DIR, name))
return self.success({"path": "/static/upload/" + name})
@ -126,7 +125,7 @@ class TwoFactorAuthAPI(APIView):
user.save()
config = WebsiteConfig.objects.first()
label = f"{config.name_shortcut}:{user.username}@{config.base_url}"
label = f"{config.name_shortcut}:{user.username}"
image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name))
return self.success(img2base64(image))
@ -143,18 +142,38 @@ class TwoFactorAuthAPI(APIView):
user.save()
return self.success("Succeeded")
else:
return self.error("Invalid captcha")
return self.error("Invalid code")
@login_required
@validate_serializer(TwoFactorAuthCodeSerializer)
def put(self, request):
code = request.data["code"]
user = request.user
if not user.two_factor_auth:
return self.error("Other session have disabled TFA")
if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = False
user.save()
return self.success("Succeeded")
else:
return self.error("Invalid captcha")
return self.error("Invalid code")
class CheckTFARequiredAPI(APIView):
@validate_serializer(UsernameOrEmailCheckSerializer)
def post(self, request):
"""
Check TFA is required
"""
data = request.data
result = False
if data.get("username"):
try:
user = User.objects.get(username=data["username"])
result = user.two_factor_auth
except User.DoesNotExist:
pass
return self.success({"result": result})
class UserLoginAPI(APIView):
@ -173,7 +192,7 @@ class UserLoginAPI(APIView):
# `tfa_code` not in post data
if user.two_factor_auth and "tfa_code" not in data:
return self.success("tfa_required")
return self.error("tfa_required")
if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
auth.login(request, user)
@ -302,11 +321,55 @@ class ResetPasswordAPI(APIView):
if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0:
return self.error("Token have expired")
user.reset_password_token = None
user.two_factor_auth = False
user.set_password(data["password"])
user.save()
return self.success("Succeeded")
class SessionManagementAPI(APIView):
@login_required
def get(self, request):
engine = import_module(settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
session_keys = request.user.session_keys
result = []
modified = False
for key in session_keys[:]:
session = SessionStore(key)
# session does not exist or is expiry
if not session._session:
session_keys.remove(key)
modified = True
continue
s = {}
if current_session == key:
s["current_session"] = True
s["ip"] = session["ip"]
s["user_agent"] = session["user_agent"]
s["last_login"] = datetime2str(session["last_login"])
s["session_key"] = key
result.append(s)
if modified:
request.user.save()
return self.success(result)
@login_required
def delete(self, request):
session_key = request.GET.get("session_key")
if not session_key:
return self.error("Parameter Error")
request.session.delete(session_key)
if session_key in request.user.session_keys:
request.user.session_keys.remove(session_key)
request.user.save()
return self.success("Succeeded")
else:
return self.error("Invalid session_key")
class UserRankAPI(APIView):
def get(self, request):
rule_type = request.GET.get("rule")

View File

@ -1,11 +1,9 @@
from django.utils import timezone
from rest_framework import serializers
class DateTimeTZField(serializers.DateTimeField):
def to_representation(self, value):
# self.format = "%Y-%m-%d %H:%M:%S %Z"
value = timezone.localtime(value)
# value = timezone.localtime(value)
return super(DateTimeTZField, self).to_representation(value)

View File

@ -3,12 +3,14 @@ from django.test.testcases import TestCase
from rest_framework.test import APIClient
from account.models import AdminType, ProblemPermission, User, UserProfile
from conf.models import WebsiteConfig
class APITestCase(TestCase):
client_class = APIClient
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, problem_permission=ProblemPermission.NONE):
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True,
problem_permission=ProblemPermission.NONE):
user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission)
user.set_password(password)
UserProfile.objects.create(user=user)
@ -18,13 +20,17 @@ class APITestCase(TestCase):
return user
def create_admin(self, username="admin", password="admin", login=True):
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, problem_permission=ProblemPermission.OWN,
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN,
problem_permission=ProblemPermission.OWN,
login=login)
def create_super_admin(self, username="root", password="root", login=True):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL, login=login)
def create_website_config(self):
return WebsiteConfig.objects.create()
def reverse(self, url_name):
return reverse(url_name)

View File

@ -69,3 +69,12 @@ def img2base64(img):
img_prefix = "data:image/png;base64,"
b64_str = img_prefix + b64encode(buf_str).decode("utf-8")
return b64_str
def datetime2str(value, format="iso-8601"):
if format.lower() == "iso-8601":
value = value.isoformat()
if value.endswith("+00:00"):
value = value[:-6] + "Z"
return value
return value.strftime(format)