add session management api; add more unit tests for account module

This commit is contained in:
zema1 2017-09-16 10:38:49 +08:00
parent a3ca8b2336
commit 1ee0596a3a
9 changed files with 320 additions and 17 deletions

View File

@ -3,6 +3,7 @@ exclude =
xss_filter.py, xss_filter.py,
*/migrations/, */migrations/,
*settings.py *settings.py
*/apps.py
max-line-length = 180 max-line-length = 180
inline-quotes = " inline-quotes = "
no-accept-encodings = True no-accept-encodings = True

View File

@ -1 +1 @@
default_app_config = 'account.apps.ProfilesConfig' default_app_config = "account.apps.ProfilesConfig"

View File

@ -5,8 +5,8 @@ from django.contrib.auth.signals import user_logged_in, user_logged_out
@receiver(user_logged_in) @receiver(user_logged_in)
def add_user_session(sender, request, user, **kwargs): def add_user_session(sender, request, user, **kwargs):
request.session["ip"] = request.META.get('REMOTE_ADDR', '') request.session["ip"] = request.META.get("REMOTE_ADDR", "")
request.session["user_agent"] = request.META.get('HTTP_USER_AGENT', '') request.session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
request.session["last_login"] = now() request.session["last_login"] = now()
if request.session.session_key not in user.session_keys: if request.session.session_key not in user.session_keys:
user.session_keys.append(request.session.session_key) user.session_keys.append(request.session.session_key)

View File

@ -1,7 +1,9 @@
import time import time
from unittest import mock from unittest import mock
from datetime import timedelta
from django.contrib import auth from django.contrib import auth
from django.utils.timezone import now
from otpauth import OtpAuth from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase from utils.api.tests import APIClient, APITestCase
@ -28,6 +30,40 @@ class PermissionDecoratorTest(APITestCase):
pass pass
class DuplicateUserCheckAPITest(APITestCase):
def setUp(self):
self.create_user("test", "test123", login=False)
self.url = self.reverse("check_username_or_email")
def test_duplicate_username(self):
resp = self.client.post(self.url, data={"username": "test"})
data = resp.data["data"]
self.assertEqual(data["username"], True)
def test_ok_username(self):
resp = self.client.post(self.url, data={"username": "test1"})
data = resp.data["data"]
self.assertEqual(data["username"], False)
class TFARequiredCheckAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("tfa_required_check")
self.create_user("test", "test123", login=False)
def test_not_required_tfa(self):
resp = self.client.post(self.url, data={"username": "test"})
self.assertSuccess(resp)
self.assertEqual(resp.data["data"]["result"], False)
def test_required_tfa(self):
user = User.objects.first()
user.two_factor_auth = True
user.save()
resp = self.client.post(self.url, data={"username": "test"})
self.assertEqual(resp.data["data"]["result"], True)
class UserLoginAPITest(APITestCase): class UserLoginAPITest(APITestCase):
def setUp(self): def setUp(self):
self.username = self.password = "test" self.username = self.password = "test"
@ -87,7 +123,7 @@ class UserLoginAPITest(APITestCase):
response = self.client.post(self.login_url, response = self.client.post(self.login_url,
data={"username": self.username, data={"username": self.username,
"password": self.password}) "password": self.password})
self.assertDictEqual(response.data, {"error": None, "data": "tfa_required"}) self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated()) self.assertFalse(user.is_authenticated())
@ -142,6 +178,160 @@ class UserRegisterAPITest(CaptchaTest):
self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"}) self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"})
class SessionManagementAPITest(APITestCase):
def setUp(self):
self.create_user("test", "test123")
self.url = self.reverse("session_management_api")
def test_get_sessions(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(len(data), 1)
def test_delete_session_key(self):
# resp = self.client.delete(self.url, data={"session_key": self.client.session.session_key})
resp = self.client.delete(self.url + "?session_key=" + self.client.session.session_key)
self.assertSuccess(resp)
def test_delete_session_with_invalid_key(self):
resp = self.client.delete(self.url + "?session_key=aaaaaaaaaa")
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid session_key"})
class UserProfileAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("user_profile_api")
def test_get_profile_without_login(self):
resp = self.client.get(self.url)
self.assertDictEqual(resp.data, {"error": None, "data": 0})
def test_get_profile(self):
self.create_user("test", "test123")
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_update_profile(self):
self.create_user("test", "test123")
update_data = {"real_name": "zemal", "submission_number": 233}
resp = self.client.put(self.url, data=update_data)
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["real_name"], "zemal")
self.assertEqual(data["submission_number"], 0)
class TwoFactorAuthAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("two_factor_auth_api")
self.create_user("test", "test123")
self.create_website_config()
def _get_tfa_code(self):
user = User.objects.first()
code = OtpAuth(user.tfa_token).totp()
if len(str(code)) < 6:
code = (6 - len(str(code))) * "0" + str(code)
return code
def test_get_image(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_open_tfa_with_invalid_code(self):
self.test_get_image()
resp = self.client.post(self.url, data={"code": "000000"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
def test_open_tfa_with_correct_code(self):
self.test_get_image()
code = self._get_tfa_code()
resp = self.client.post(self.url, data={"code": code})
self.assertSuccess(resp)
user = User.objects.first()
self.assertEqual(user.two_factor_auth, True)
def test_close_tfa_with_invalid_code(self):
self.test_open_tfa_with_correct_code()
resp = self.client.post(self.url, data={"code": "000000"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
def test_close_tfa_with_correct_code(self):
self.test_open_tfa_with_correct_code()
code = self._get_tfa_code()
resp = self.client.put(self.url, data={"code": code})
self.assertSuccess(resp)
user = User.objects.first()
self.assertEqual(user.two_factor_auth, False)
@mock.patch("account.views.oj.send_email_async.delay")
class ApplyResetPasswordAPITest(CaptchaTest):
def setUp(self):
self.create_user("test", "test123", login=False)
user = User.objects.first()
user.email = "test@oj.com"
user.save()
self.url = self.reverse("apply_reset_password_api")
self.create_website_config()
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
def _refresh_captcha(self):
self.data["captcha"] = self._set_captcha(self.client.session)
def test_apply_reset_password(self, send_email_delay):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
send_email_delay.assert_called()
def test_apply_reset_password_twice_in_20_mins(self, send_email_delay):
self.test_apply_reset_password()
send_email_delay.reset_mock()
self._refresh_captcha()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"})
send_email_delay.assert_not_called()
def test_apply_reset_password_again_after_20_mins(self, send_email_delay):
self.test_apply_reset_password()
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(minutes=21)
user.save()
self._refresh_captcha()
self.test_apply_reset_password()
class ResetPasswordAPITest(CaptchaTest):
def setUp(self):
self.create_user("test", "test123", login=False)
self.url = self.reverse("reset_password_api")
user = User.objects.first()
user.reset_password_token = "online_judge?"
user.reset_password_token_expire_time = now() + timedelta(minutes=20)
user.save()
self.data = {"token": user.reset_password_token,
"captcha": self._set_captcha(self.client.session),
"password": "test456"}
def test_reset_password_with_correct_token(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
self.assertTrue(self.client.login(username="test", password="test456"))
def test_reset_password_with_invalid_token(self):
self.data["token"] = "aaaaaaaaaaa"
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"})
def test_reset_password_with_expired_token(self):
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(seconds=30)
user.save()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"})
class UserChangePasswordAPITest(CaptchaTest): class UserChangePasswordAPITest(CaptchaTest):
def setUp(self): def setUp(self):
self.client = APIClient() self.client = APIClient()
@ -248,3 +438,37 @@ class AdminUserTest(APITestCase):
# if `openapi_app_key` is not None, the value is not changed # if `openapi_app_key` is not None, the value is not changed
self.assertTrue(resp_data["open_api"]) self.assertTrue(resp_data["open_api"])
self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key) self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key)
class UserRankAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("user_rank_api")
self.create_user("test1", "test123", login=False)
self.create_user("test2", "test123", login=False)
test1 = User.objects.get(username="test1")
profile1 = test1.userprofile
profile1.submission_number = 10
profile1.accepted_number = 10
profile1.total_score = 240
profile1.save()
test2 = User.objects.get(username="test2")
profile2 = test2.userprofile
profile2.submission_number = 15
profile2.accepted_number = 10
profile2.total_score = 700
profile2.save()
def test_get_acm_rank(self):
resp = self.client.get(self.url, data={"rule": "acm"})
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test1")
self.assertEqual(data[1]["user"]["username"], "test2")
def test_get_oi_rank(self):
resp = self.client.get(self.url, data={"rule": "oi"})
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test2")
self.assertEqual(data[1]["user"]["username"], "test1")

View File

@ -4,7 +4,7 @@ from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
UserChangePasswordAPI, UserRegisterAPI, UserChangePasswordAPI, UserRegisterAPI,
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
UserRankAPI) UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
from utils.captcha.views import CaptchaAPIView from utils.captcha.views import CaptchaAPIView
@ -14,12 +14,14 @@ urlpatterns = [
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"), url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"), url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"), url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="apply_reset_password_api"), url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"), url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"), url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"), url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
url(r"^avatar/upload/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"), url(r"^avatar/upload/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"), url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"), url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"), url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),
url(r"^sessions/?$", SessionManagementAPI.as_view(), name="session_management_api")
] ]

View File

@ -14,7 +14,7 @@ from django.template.loader import render_to_string
from conf.models import WebsiteConfig from conf.models import WebsiteConfig
from utils.api import APIView, validate_serializer, CSRFExemptAPIView from utils.api import APIView, validate_serializer, CSRFExemptAPIView
from utils.captcha import Captcha from utils.captcha import Captcha
from utils.shortcuts import rand_str, img2base64 from utils.shortcuts import rand_str, img2base64, datetime2str
from ..decorators import login_required from ..decorators import login_required
from ..models import User, UserProfile from ..models import User, UserProfile
@ -77,7 +77,6 @@ class AvatarUploadAPI(CSRFExemptAPIView):
with open(os.path.join(settings.IMAGE_UPLOAD_DIR, name), "wb") as img: with open(os.path.join(settings.IMAGE_UPLOAD_DIR, name), "wb") as img:
for chunk in avatar: for chunk in avatar:
img.write(chunk) img.write(chunk)
print(os.path.join(settings.IMAGE_UPLOAD_DIR, name))
return self.success({"path": "/static/upload/" + name}) return self.success({"path": "/static/upload/" + name})
@ -126,7 +125,7 @@ class TwoFactorAuthAPI(APIView):
user.save() user.save()
config = WebsiteConfig.objects.first() config = WebsiteConfig.objects.first()
label = f"{config.name_shortcut}:{user.username}@{config.base_url}" label = f"{config.name_shortcut}:{user.username}"
image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name)) image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name))
return self.success(img2base64(image)) return self.success(img2base64(image))
@ -143,18 +142,38 @@ class TwoFactorAuthAPI(APIView):
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")
else: else:
return self.error("Invalid captcha") return self.error("Invalid code")
@login_required @login_required
@validate_serializer(TwoFactorAuthCodeSerializer) @validate_serializer(TwoFactorAuthCodeSerializer)
def put(self, request): def put(self, request):
code = request.data["code"] code = request.data["code"]
user = request.user user = request.user
if not user.two_factor_auth:
return self.error("Other session have disabled TFA")
if OtpAuth(user.tfa_token).valid_totp(code): if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = False user.two_factor_auth = False
user.save() user.save()
return self.success("Succeeded")
else: else:
return self.error("Invalid captcha") return self.error("Invalid code")
class CheckTFARequiredAPI(APIView):
@validate_serializer(UsernameOrEmailCheckSerializer)
def post(self, request):
"""
Check TFA is required
"""
data = request.data
result = False
if data.get("username"):
try:
user = User.objects.get(username=data["username"])
result = user.two_factor_auth
except User.DoesNotExist:
pass
return self.success({"result": result})
class UserLoginAPI(APIView): class UserLoginAPI(APIView):
@ -173,7 +192,7 @@ class UserLoginAPI(APIView):
# `tfa_code` not in post data # `tfa_code` not in post data
if user.two_factor_auth and "tfa_code" not in data: if user.two_factor_auth and "tfa_code" not in data:
return self.success("tfa_required") return self.error("tfa_required")
if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
auth.login(request, user) auth.login(request, user)
@ -302,11 +321,55 @@ class ResetPasswordAPI(APIView):
if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0: if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0:
return self.error("Token have expired") return self.error("Token have expired")
user.reset_password_token = None user.reset_password_token = None
user.two_factor_auth = False
user.set_password(data["password"]) user.set_password(data["password"])
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")
class SessionManagementAPI(APIView):
@login_required
def get(self, request):
engine = import_module(settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
session_keys = request.user.session_keys
result = []
modified = False
for key in session_keys[:]:
session = SessionStore(key)
# session does not exist or is expiry
if not session._session:
session_keys.remove(key)
modified = True
continue
s = {}
if current_session == key:
s["current_session"] = True
s["ip"] = session["ip"]
s["user_agent"] = session["user_agent"]
s["last_login"] = datetime2str(session["last_login"])
s["session_key"] = key
result.append(s)
if modified:
request.user.save()
return self.success(result)
@login_required
def delete(self, request):
session_key = request.GET.get("session_key")
if not session_key:
return self.error("Parameter Error")
request.session.delete(session_key)
if session_key in request.user.session_keys:
request.user.session_keys.remove(session_key)
request.user.save()
return self.success("Succeeded")
else:
return self.error("Invalid session_key")
class UserRankAPI(APIView): class UserRankAPI(APIView):
def get(self, request): def get(self, request):
rule_type = request.GET.get("rule") rule_type = request.GET.get("rule")

View File

@ -1,11 +1,9 @@
from django.utils import timezone
from rest_framework import serializers from rest_framework import serializers
class DateTimeTZField(serializers.DateTimeField): class DateTimeTZField(serializers.DateTimeField):
def to_representation(self, value): def to_representation(self, value):
# self.format = "%Y-%m-%d %H:%M:%S %Z" # value = timezone.localtime(value)
value = timezone.localtime(value)
return super(DateTimeTZField, self).to_representation(value) return super(DateTimeTZField, self).to_representation(value)

View File

@ -3,12 +3,14 @@ from django.test.testcases import TestCase
from rest_framework.test import APIClient from rest_framework.test import APIClient
from account.models import AdminType, ProblemPermission, User, UserProfile from account.models import AdminType, ProblemPermission, User, UserProfile
from conf.models import WebsiteConfig
class APITestCase(TestCase): class APITestCase(TestCase):
client_class = APIClient client_class = APIClient
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, problem_permission=ProblemPermission.NONE): def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True,
problem_permission=ProblemPermission.NONE):
user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission) user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission)
user.set_password(password) user.set_password(password)
UserProfile.objects.create(user=user) UserProfile.objects.create(user=user)
@ -18,13 +20,17 @@ class APITestCase(TestCase):
return user return user
def create_admin(self, username="admin", password="admin", login=True): def create_admin(self, username="admin", password="admin", login=True):
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, problem_permission=ProblemPermission.OWN, return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN,
problem_permission=ProblemPermission.OWN,
login=login) login=login)
def create_super_admin(self, username="root", password="root", login=True): def create_super_admin(self, username="root", password="root", login=True):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL, login=login) problem_permission=ProblemPermission.ALL, login=login)
def create_website_config(self):
return WebsiteConfig.objects.create()
def reverse(self, url_name): def reverse(self, url_name):
return reverse(url_name) return reverse(url_name)

View File

@ -69,3 +69,12 @@ def img2base64(img):
img_prefix = "data:image/png;base64," img_prefix = "data:image/png;base64,"
b64_str = img_prefix + b64encode(buf_str).decode("utf-8") b64_str = img_prefix + b64encode(buf_str).decode("utf-8")
return b64_str return b64_str
def datetime2str(value, format="iso-8601"):
if format.lower() == "iso-8601":
value = value.isoformat()
if value.endswith("+00:00"):
value = value[:-6] + "Z"
return value
return value.strftime(format)