diff --git a/account/decorators.py b/account/decorators.py new file mode 100644 index 00000000..6b108960 --- /dev/null +++ b/account/decorators.py @@ -0,0 +1,24 @@ +# coding=utf-8 +from django.http import HttpResponse +from django.shortcuts import render + +from utils.shortcuts import error_response +from .models import User + + +def login_required(func): + def check(*args, **kwargs): + # 在class based views 里面,args 有两个元素,一个是self, 第二个才是request, + # 在function based views 里面,args 只有request 一个参数 + request = args[-1] + if request.user.is_authenticated(): + return func(*args, **kwargs) + if request.is_ajax(): + return error_response(u"请先登录") + else: + return render(request, "utils/error.html", {"error": u"请先登录"}) + return check + + +def admin_required(): + pass diff --git a/account/serializers.py b/account/serializers.py index ffa6b3b1..41634b7a 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -5,3 +5,20 @@ from rest_framework import serializers class UserLoginSerializer(serializers.Serializer): username = serializers.CharField(max_length=30) password = serializers.CharField(max_length=30) + + +class UsernameCheckSerializer(serializers.Serializer): + username = serializers.CharField(max_length=30) + + +class UserRegisterSerializer(serializers.Serializer): + username = serializers.CharField(max_length=30) + real_name = serializers.CharField(max_length=30) + password = serializers.CharField(max_length=30, min_length=6) + + +class UserChangePasswordSerializer(serializers.Serializer): + username = serializers.CharField(max_length=30) + old_password = serializers.CharField(max_length=30, min_length=6) + new_password = serializers.CharField(max_length=30, min_length=6) + diff --git a/account/test_urls.py b/account/test_urls.py new file mode 100644 index 00000000..40667215 --- /dev/null +++ b/account/test_urls.py @@ -0,0 +1,12 @@ +# coding=utf-8 +from django.conf.urls import include, url + +from .tests import LoginRequiredCBVTestWithArgs, LoginRequiredCBVTestWithoutArgs + + +urlpatterns = [ + url(r'^test/fbv/1/$', "account.tests.login_required_FBV_test_without_args"), + url(r'^test/fbv/(?P\d+)/$', "account.tests.login_required_FBC_test_with_args"), + url(r'^test/cbv/1/$', LoginRequiredCBVTestWithoutArgs.as_view()), + url(r'^test/cbv/(?P\d+)/$', LoginRequiredCBVTestWithArgs.as_view()), +] diff --git a/account/tests.py b/account/tests.py index 12d087a8..83226b70 100644 --- a/account/tests.py +++ b/account/tests.py @@ -1,10 +1,16 @@ # coding=utf-8 +import json + from django.core.urlresolvers import reverse from django.test import TestCase, Client +from django.http import HttpResponse from rest_framework.test import APITestCase, APIClient +from rest_framework.views import APIView +from rest_framework.response import Response from .models import User +from .decorators import login_required class UserLoginTest(TestCase): @@ -36,3 +42,137 @@ class UserLoginAPITest(APITestCase): data = {"username": "test", "password": "test"} response = self.client.post(self.url, data=data) self.assertEqual(response.data, {"code": 0, "data": u"登录成功"}) + + +class UsernameCheckTest(APITestCase): + def setUp(self): + self.client = APIClient() + self.url = reverse("username_check_api") + User.objects.create(username="testtest") + + def test_invalid_data(self): + response = self.client.post(self.url, data={"username111": "testtest"}) + self.assertEqual(response.data["code"], 1) + + def test_username_exists(self): + response = self.client.post(self.url, data={"username": "testtest"}) + self.assertEqual(response.data, {"code": 0, "data": True}) + + def test_username_does_not_exist(self): + response = self.client.post(self.url, data={"username": "testtest123"}) + self.assertEqual(response.data, {"code": 0, "data": False}) + + +class UserRegisterAPITest(APITestCase): + def setUp(self): + self.client = APIClient() + self.url = reverse("user_register_api") + + def test_invalid_data(self): + data = {"username": "test", "real_name": "TT"} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data["code"], 1) + + def test_short_password(self): + data = {"username": "test", "real_name": "TT", "password": "qq"} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data["code"], 1) + + def test_same_username(self): + User.objects.create(username="aa", real_name="ww") + data = {"username": "aa", "real_name": "ww", "password": "zzzzzzz"} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data, {"code": 1, "data": u"用户名已存在"}) + + +class UserChangePasswordAPITest(APITestCase): + def setUp(self): + self.client = APIClient() + self.url = reverse("user_change_password_api") + User.objects.create(username="test", password="aaabbb") + + def test_error_old_password(self): + data = {"username": "test", "old_password": "aaaccc", "new_password": "aaaddd"} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data, {"code": 1, "data": u"密码不正确,请重新修改!"}) + + def test_invalid_data_format(self): + data = {"username": "test", "old_password": "aaa", "new_password": "aaaddd"} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data["code"], 1) + + def test_username_does_not_exist(self): + data = {"username": "test1", "old_password": "aaabbb", "new_password": "aaaddd"} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data["code"], 1) + + +@login_required +def login_required_FBV_test_without_args(request): + return HttpResponse("function based view test1") + + +@login_required +def login_required_FBC_test_with_args(request, problem_id): + return HttpResponse(problem_id) + + +class LoginRequiredCBVTestWithoutArgs(APIView): + @login_required + def get(self, request): + return HttpResponse("class based view login required test1") + +class LoginRequiredCBVTestWithArgs(APIView): + @login_required + def get(self, request, problem_id): + return HttpResponse(problem_id) + + +class LoginRequiredDecoratorTest(TestCase): + urls = 'account.test_urls' + + def setUp(self): + self.client = Client() + user = User.objects.create(username="test") + user.set_password("test") + user.save() + + def test_fbv_without_args(self): + # 没登陆 + response = self.client.get("/test/fbv/1/") + self.assertTemplateUsed(response, "utils/error.html") + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/test/fbv/1/") + self.assertEqual(response.content, "function based view test1") + + def test_fbv_with_args(self): + # 没登陆 + response = self.client.get("/test/fbv/1024/") + self.assertTemplateUsed(response, "utils/error.html") + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/test/fbv/1024/") + self.assertEqual(response.content, "1024") + + def test_cbv_without_args(self): + # 没登陆 + response = self.client.get("/test/cbv/1/") + self.assertTemplateUsed(response, "utils/error.html") + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/test/cbv/1/") + self.assertEqual(response.content, "class based view login required test1") + + def test_cbv_with_args(self): + # 没登陆 + response = self.client.get("/test/cbv/1024/", HTTP_X_REQUESTED_WITH='XMLHttpRequest') + self.assertEqual(json.loads(response.content), {"code": 1, "data": u"请先登录"}) + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/test/cbv/1024/") + self.assertEqual(response.content, "1024") diff --git a/account/views.py b/account/views.py index 8c973204..dd05874b 100644 --- a/account/views.py +++ b/account/views.py @@ -6,7 +6,8 @@ from rest_framework.views import APIView from utils.shortcuts import serializer_invalid_response, error_response, success_response from .models import User -from .serializers import UserLoginSerializer +from .serializers import UserLoginSerializer, UsernameCheckSerializer, UserRegisterSerializer, \ + UserChangePasswordSerializer class UserLoginAPIView(APIView): @@ -30,17 +31,62 @@ class UserLoginAPIView(APIView): return serializer_invalid_response(serializer) -class UserRegisterView(APIView): - def get(self, request): - pass - +class UserRegisterAPIView(APIView): def post(self, request): - pass + """ + 用户注册json api接口 + --- + request_serializer: UserRegisterSerializer + """ + serializer = UserRegisterSerializer(data=request.DATA) + if serializer.is_valid(): + data = serializer.data + try: + User.objects.get(username=data["username"]) + return error_response(u"用户名已存在") + except User.DoesNotExist: + user = User.objects.create(username=data["username"], real_name=data["real_name"]) + user.set_password(data["password"]) + user.save() + return success_response(u"注册成功!") + else: + return serializer_invalid_response(serializer) -class UserChangePasswordView(APIView): - def get(self, request): - pass - +class UserChangePasswordAPIView(APIView): def post(self, request): - pass \ No newline at end of file + """ + 用户修改密码json api接口 + --- + request_serializer: UserChangePasswordSerializer + """ + serializer = UserChangePasswordSerializer(data=request.DATA) + if serializer.is_valid(): + data = serializer.data + user = auth.authenticate(username=data["username"], password=data["old_password"]) + if user: + user.set_password(data["new_password"]) + user.save() + return success_response(u"用户密码修改成功!") + else: + return error_response(u"密码不正确,请重新修改!") + else: + return serializer_invalid_response(serializer) + + +class UsernameCheckAPIView(APIView): + def post(self, request): + """ + 检测用户名是否存在,存在返回True,不存在返回False + --- + request_serializer: UsernameCheckSerializer + """ + serializer = UsernameCheckSerializer(data=request.DATA) + if serializer.is_valid(): + try: + User.objects.get(username=serializer.data["username"]) + return success_response(True) + except User.DoesNotExist: + return success_response(False) + else: + return serializer_invalid_response(serializer) \ No newline at end of file diff --git a/oj/urls.py b/oj/urls.py index f9d4cbf7..fc652b32 100644 --- a/oj/urls.py +++ b/oj/urls.py @@ -3,17 +3,17 @@ from django.conf.urls import include, url from django.contrib import admin from django.views.generic import TemplateView -from account.views import UserLoginAPIView +from account.views import UserLoginAPIView, UsernameCheckAPIView, UserRegisterAPIView, UserChangePasswordAPIView urlpatterns = [ url("^$", TemplateView.as_view(template_name="oj/index.html"), name="index_page"), url(r'^docs/', include('rest_framework_swagger.urls')), url(r'^admin/$', TemplateView.as_view(template_name="admin/index.html"), name="admin_index_page"), url(r'^login/$', TemplateView.as_view(template_name="oj/account/login.html"), name="user_login_page"), - url(r'^register/$', TemplateView.as_view(template_name="oj/account/register.html"), name="user_register_page"), - url(r'^change_password/$', TemplateView.as_view(template_name="oj/account/change_password.html"), name="user_change_password_page"), - url(r'^api/login/$', UserLoginAPIView.as_view(), name="login_api"), url(r'^api/login/$', UserLoginAPIView.as_view(), name="user_login_api"), + url(r'^api/register/$', UserRegisterAPIView.as_view(), name="user_register_api"), + url(r'^api/change_password/$', UserChangePasswordAPIView.as_view(), name="user_change_password_api"), + url(r'^api/username_check/$', UsernameCheckAPIView.as_view(), name="username_check_api"), url(r'^problem/(?P\d+)/$', "problem.views.problem_page", name="problem_page"), url(r'^admin/contest/$', TemplateView.as_view(template_name="admin/contest/add_contest.html"), name="add_contest_page"),