Merge pull request #233 from QingdaoU/feature/django20

Feature/django20
This commit is contained in:
李扬 2019-03-26 10:52:34 +08:00 committed by GitHub
commit 9869f6294c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 521 additions and 245 deletions

4
.dockerignore Normal file
View File

@ -0,0 +1,4 @@
venv
.idea
.git
.DS_Store

View File

@ -4,6 +4,7 @@ exclude =
*/migrations/,
*settings.py
*/apps.py
venv/
max-line-length = 180
inline-quotes = "
no-accept-encodings = True

View File

@ -1,4 +1,4 @@
FROM python:3.6-alpine3.6
FROM python:3.7-alpine3.9
ENV OJ_ENV production

View File

@ -31,19 +31,19 @@ class BasePermissionDecorator(object):
class login_required(BasePermissionDecorator):
def check_permission(self):
return self.request.user.is_authenticated()
return self.request.user.is_authenticated
class super_admin_required(BasePermissionDecorator):
def check_permission(self):
user = self.request.user
return user.is_authenticated() and user.is_super_admin()
return user.is_authenticated and user.is_super_admin()
class admin_role_required(BasePermissionDecorator):
def check_permission(self):
user = self.request.user
return user.is_authenticated() and user.is_admin_role()
return user.is_authenticated and user.is_admin_role()
class problem_permission_required(admin_role_required):
@ -80,7 +80,7 @@ def check_contest_permission(check_type="details"):
return self.error("Contest %s doesn't exist" % contest_id)
# Anonymous
if not user.is_authenticated():
if not user.is_authenticated:
return self.error("Please login first.")
# creator or owner

View File

@ -22,7 +22,7 @@ class APITokenAuthMiddleware(MiddlewareMixin):
class SessionRecordMiddleware(MiddlewareMixin):
def process_request(self, request):
request.ip = request.META.get(settings.IP_HEADER, request.META.get("REMOTE_ADDR"))
if request.user.is_authenticated():
if request.user.is_authenticated:
session = request.session
session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
session["ip"] = request.ip
@ -37,7 +37,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin):
def process_request(self, request):
path = request.path_info
if path.startswith("/admin/") or path.startswith("/api/admin/"):
if not (request.user.is_authenticated() and request.user.is_admin_role()):
if not (request.user.is_authenticated and request.user.is_admin_role()):
return JSONResponse.response({"error": "login-required", "data": "Please login in first"})

View File

@ -60,7 +60,7 @@ class User(AbstractBaseUser):
return self.problem_permission == ProblemPermission.ALL
def is_contest_admin(self, contest):
return self.is_authenticated() and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN)
return self.is_authenticated and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN)
class Meta:
db_table = "user"

View File

@ -131,6 +131,10 @@ class ImageUploadForm(forms.Form):
image = forms.FileField()
class FileUploadForm(forms.Form):
file = forms.FileField()
class RankInfoSerializer(serializers.ModelSerializer):
user = UsernameSerializer()

View File

@ -1,14 +1,13 @@
import logging
from celery import shared_task
import dramatiq
from options.options import SysOptions
from utils.shortcuts import send_email
from utils.shortcuts import send_email, DRAMATIQ_WORKER_ARGS
logger = logging.getLogger(__name__)
@shared_task
@dramatiq.actor(**DRAMATIQ_WORKER_ARGS(max_retries=3))
def send_email_async(from_name, to_email, to_name, subject, content):
if not SysOptions.smtp_config:
return

View File

@ -101,13 +101,13 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"})
user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated())
self.assertTrue(user.is_authenticated)
def test_login_with_correct_info_upper_username(self):
resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password})
self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"})
user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated())
self.assertTrue(user.is_authenticated)
def test_login_with_wrong_info(self):
response = self.client.post(self.login_url,
@ -115,7 +115,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": "error", "data": "Invalid username or password"})
user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated())
self.assertFalse(user.is_authenticated)
def test_tfa_login(self):
token = self._set_tfa()
@ -129,7 +129,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"})
user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated())
self.assertTrue(user.is_authenticated)
def test_tfa_login_wrong_code(self):
self._set_tfa()
@ -140,7 +140,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": "error", "data": "Invalid two factor verification code"})
user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated())
self.assertFalse(user.is_authenticated)
def test_tfa_login_without_code(self):
self._set_tfa()
@ -150,7 +150,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"})
user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated())
self.assertFalse(user.is_authenticated)
def test_user_disabled(self):
self.user.is_disabled = True
@ -304,7 +304,7 @@ class TwoFactorAuthAPITest(APITestCase):
self.assertEqual(user.two_factor_auth, False)
@mock.patch("account.views.oj.send_email_async.delay")
@mock.patch("account.views.oj.send_email_async.send")
class ApplyResetPasswordAPITest(CaptchaTest):
def setUp(self):
self.create_user("test", "test123", login=False)
@ -317,20 +317,20 @@ class ApplyResetPasswordAPITest(CaptchaTest):
def _refresh_captcha(self):
self.data["captcha"] = self._set_captcha(self.client.session)
def test_apply_reset_password(self, send_email_delay):
def test_apply_reset_password(self, send_email_send):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
send_email_delay.assert_called()
send_email_send.assert_called()
def test_apply_reset_password_twice_in_20_mins(self, send_email_delay):
def test_apply_reset_password_twice_in_20_mins(self, send_email_send):
self.test_apply_reset_password()
send_email_delay.reset_mock()
send_email_send.reset_mock()
self._refresh_captcha()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"})
send_email_delay.assert_not_called()
send_email_send.assert_not_called()
def test_apply_reset_password_again_after_20_mins(self, send_email_delay):
def test_apply_reset_password_again_after_20_mins(self, send_email_send):
self.test_apply_reset_password()
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(minutes=21)

View File

@ -121,25 +121,15 @@ class UserAdminAPI(APIView):
Q(email__icontains=keyword))
return self.success(self.paginate_data(request, user, UserAdminSerializer))
def delete_one(self, user_id):
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist:
return f"User {user_id} does not exist"
if Submission.objects.filter(user_id=user_id).exists():
return f"Can't delete the user {user_id} as he/she has submissions"
user.delete()
@super_admin_required
def delete(self, request):
id = request.GET.get("id")
if not id:
return self.error("Invalid Parameter, id is required")
for user_id in id.split(","):
if user_id:
error = self.delete_one(user_id)
if error:
return self.error(error)
ids = id.split(",")
if str(request.user.id) in ids:
return self.error("Current user can not be deleted")
User.objects.filter(id__in=ids).delete()
return self.success()

View File

@ -35,7 +35,7 @@ class UserProfileAPI(APIView):
判断是否登录 若登录返回用户信息
"""
user = request.user
if not user.is_authenticated():
if not user.is_authenticated:
return self.success()
show_real_name = False
username = request.GET.get("username")
@ -280,7 +280,7 @@ class UserChangePasswordAPI(APIView):
class ApplyResetPasswordAPI(APIView):
@validate_serializer(ApplyResetPasswordSerializer)
def post(self, request):
if request.user.is_authenticated():
if request.user.is_authenticated:
return self.error("You have already logged in, are you kidding me? ")
data = request.data
captcha = Captcha(request)
@ -302,11 +302,11 @@ class ApplyResetPasswordAPI(APIView):
"link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
}
email_html = render_to_string("reset_password_email.html", render_data)
send_email_async.delay(from_name=SysOptions.website_name_shortcut,
to_email=user.email,
to_name=user.username,
subject=f"Reset your password",
content=email_html)
send_email_async.send(from_name=SysOptions.website_name_shortcut,
to_email=user.email,
to_name=user.username,
subject=f"Reset your password",
content=email_html)
return self.success("Succeeded")

View File

@ -9,7 +9,7 @@ class Announcement(models.Model):
# HTML
content = RichTextField()
create_time = models.DateTimeField(auto_now_add=True)
created_by = models.ForeignKey(User)
created_by = models.ForeignKey(User, on_delete=models.CASCADE)
last_update_time = models.DateTimeField(auto_now=True)
visible = models.BooleanField(default=True)

View File

@ -0,0 +1,23 @@
# Generated by Django 2.1.7 on 2019-03-26 02:01
from django.conf import settings
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('contest', '0009_auto_20180501_0436'),
]
operations = [
migrations.AlterUniqueTogether(
name='acmcontestrank',
unique_together={('user', 'contest')},
),
migrations.AlterUniqueTogether(
name='oicontestrank',
unique_together={('user', 'contest')},
),
]

View File

@ -20,7 +20,7 @@ class Contest(models.Model):
end_time = models.DateTimeField()
create_time = models.DateTimeField(auto_now_add=True)
last_update_time = models.DateTimeField(auto_now=True)
created_by = models.ForeignKey(User)
created_by = models.ForeignKey(User, on_delete=models.CASCADE)
# 是否可见 false的话相当于删除
visible = models.BooleanField(default=True)
allowed_ip_ranges = JSONField(default=list)
@ -47,7 +47,7 @@ class Contest(models.Model):
def problem_details_permission(self, user):
return self.rule_type == ContestRuleType.ACM or \
self.status == ContestStatus.CONTEST_ENDED or \
user.is_authenticated() and user.is_contest_admin(self) or \
user.is_authenticated and user.is_contest_admin(self) or \
self.real_time_rank
class Meta:
@ -56,8 +56,8 @@ class Contest(models.Model):
class AbstractContestRank(models.Model):
user = models.ForeignKey(User)
contest = models.ForeignKey(Contest)
user = models.ForeignKey(User, on_delete=models.CASCADE)
contest = models.ForeignKey(Contest, on_delete=models.CASCADE)
submission_number = models.IntegerField(default=0)
class Meta:
@ -74,6 +74,7 @@ class ACMContestRank(AbstractContestRank):
class Meta:
db_table = "acm_contest_rank"
unique_together = (("user", "contest"),)
class OIContestRank(AbstractContestRank):
@ -84,13 +85,14 @@ class OIContestRank(AbstractContestRank):
class Meta:
db_table = "oi_contest_rank"
unique_together = (("user", "contest"),)
class ContestAnnouncement(models.Model):
contest = models.ForeignKey(Contest)
contest = models.ForeignKey(Contest, on_delete=models.CASCADE)
title = models.TextField()
content = RichTextField()
created_by = models.ForeignKey(User)
created_by = models.ForeignKey(User, on_delete=models.CASCADE)
visible = models.BooleanField(default=True)
create_time = models.DateTimeField(auto_now_add=True)

View File

@ -45,7 +45,7 @@ class ContestAdminAPITest(APITestCase):
response_data = response.data["data"]
for k in data.keys():
if isinstance(data[k], datetime):
continue
continue
self.assertEqual(response_data[k], data[k])
def test_get_contests(self):

View File

@ -234,7 +234,7 @@ class DownloadContestSubmissions(APIView):
exclude_admin = request.GET.get("exclude_admin") == "1"
zip_path = self._dump_submissions(contest, exclude_admin)
delete_files.apply_async((zip_path,), countdown=300)
delete_files.send_with_options(args=(zip_path,), delay=300_000)
resp = FileResponse(open(zip_path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = f"attachment;filename={os.path.basename(zip_path)}"

View File

@ -8,7 +8,7 @@ from django.core.cache import cache
from problem.models import Problem
from utils.api import APIView, validate_serializer
from utils.constants import CacheKey
from utils.shortcuts import datetime2str
from utils.shortcuts import datetime2str, check_is_id
from account.models import AdminType
from account.decorators import login_required, check_contest_permission
@ -35,7 +35,7 @@ class ContestAnnouncementListAPI(APIView):
class ContestAPI(APIView):
def get(self, request):
id = request.GET.get("id")
if not id:
if not id or not check_is_id(id):
return self.error("Invalid parameter, id is required")
try:
contest = Contest.objects.get(id=id, visible=True)
@ -121,7 +121,7 @@ class ContestRankAPI(APIView):
def get(self, request):
download_csv = request.GET.get("download_csv")
force_refresh = request.GET.get("force_refresh")
is_contest_admin = request.user.is_authenticated() and request.user.is_contest_admin(self.contest)
is_contest_admin = request.user.is_authenticated and request.user.is_contest_admin(self.contest)
if self.contest.rule_type == ContestRuleType.OI:
serializer = OIContestRankSerializer
else:

View File

@ -1,19 +1,32 @@
django==1.11.4
djangorestframework==3.4.0
pillow
otpauth
flake8-quotes
pytz
coverage
python-dateutil
celery
Envelopes
qrcode
flake8-coding
requests
django-redis
psycopg2-binary
gunicorn
jsonfield
XlsxWriter
raven
certifi==2019.3.9
chardet==3.0.4
coverage==4.5.3
Django==2.1.7
django-redis==4.10.0
djangorestframework==3.8.2
entrypoints==0.3
Envelopes==0.4
flake8==3.7.7
flake8-coding==1.3.1
flake8-quotes==1.0.0
gunicorn==19.9.0
idna==2.8
jsonfield==2.0.2
mccabe==0.6.1
otpauth==1.0.1
Pillow==5.4.1
psycopg2-binary==2.7.7
pycodestyle==2.5.0
pyflakes==2.1.1
python-dateutil==2.8.0
pytz==2018.9
qrcode==6.1
raven==6.10.0
redis==3.2.0
requests==2.21.0
six==1.12.0
urllib3==1.24.1
XlsxWriter==1.1.5
django-dramatiq==0.5.0
dramatiq==1.3.0
django-dbconn-retry==0.1.5

View File

@ -38,12 +38,12 @@ startsecs=5
stopwaitsecs = 5
killasgroup=true
[program:celery]
command=celery -A oj worker -l warning --autoscale 2,%(ENV_MAX_WORKER_NUM)s
[program:dramatiq]
command=python3 manage.py rundramatiq --no-reload --processes %(ENV_MAX_WORKER_NUM)s --threads 4
directory=/app/
user=nobody
stdout_logfile=/data/log/celery.log
stderr_logfile=/data/log/celery.log
stdout_logfile=/data/log/dramatiq.log
stderr_logfile=/data/log/dramatiq.log
autostart=true
autorestart=true
startsecs=5

View File

@ -1,5 +1,19 @@
{
"update": [
{
"version": "2019-03-25",
"level": "Recommend",
"title": "2019-03-25",
"details": [
"Update Django to version 2.1 and Python to version 3.7",
"Replace celery with dramatiq",
"Add problem file IO Mode",
"You can add attachments in all editor",
"You can upload source code file in submission editor",
"Frontend and UI improvements",
"Fixed a lot of bugs"
]
},
{
"version": "2018-12-15",
"level": "Recommend",

View File

@ -4,7 +4,7 @@ import logging
from urllib.parse import urljoin
import requests
from django.db import transaction
from django.db import transaction, IntegrityError
from django.db.models import F
from account.models import User
@ -26,7 +26,28 @@ def process_pending_task():
# 防止循环引入
from judge.tasks import judge_task
data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
judge_task.delay(**data)
judge_task.send(**data)
class ChooseJudgeServer:
def __init__(self):
self.server = None
def __enter__(self) -> [JudgeServer, None]:
with transaction.atomic():
servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
for server in servers:
if server.task_number <= server.cpu_core * 2:
server.task_number = F("task_number") + 1
server.save()
self.server = server
return server
return None
def __exit__(self, exc_type, exc_val, exc_tb):
if self.server:
JudgeServer.objects.filter(id=self.server.id).update(task_number=F("task_number") - 1)
class DispatcherBase(object):
@ -42,25 +63,6 @@ class DispatcherBase(object):
except Exception as e:
logger.exception(e)
@staticmethod
def choose_judge_server():
with transaction.atomic():
servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
for server in servers:
if server.task_number <= server.cpu_core * 2:
server.task_number = F("task_number") + 1
server.save()
return server
@staticmethod
def release_judge_server(judge_server_id):
with transaction.atomic():
# 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
server = JudgeServer.objects.get(id=judge_server_id)
server.task_number = F("task_number") - 1
server.save()
class SPJCompiler(DispatcherBase):
def __init__(self, spj_code, spj_version, spj_language):
@ -74,13 +76,14 @@ class SPJCompiler(DispatcherBase):
}
def compile_spj(self):
server = self.choose_judge_server()
if not server:
return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
self.release_judge_server(server.id)
if result["err"]:
return result["data"]
with ChooseJudgeServer() as server:
if not server:
return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
if not result:
return "Failed to call judge server"
if result["err"]:
return result["data"]
class JudgeDispatcher(DispatcherBase):
@ -118,12 +121,6 @@ class JudgeDispatcher(DispatcherBase):
self.submission.statistic_info["score"] = score
def judge(self):
server = self.choose_judge_server()
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
language = self.submission.language
sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0]
spj_config = {}
@ -149,12 +146,22 @@ class JudgeDispatcher(DispatcherBase):
"spj_version": self.problem.spj_version,
"spj_config": spj_config.get("config"),
"spj_compile_config": spj_config.get("compile"),
"spj_src": self.problem.spj_code
"spj_src": self.problem.spj_code,
"io_mode": self.problem.io_mode
}
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
with ChooseJudgeServer() as server:
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if not resp:
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR)
return
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if resp["err"]:
self.submission.result = JudgeStatus.COMPILE_ERROR
self.submission.statistic_info["err_info"] = resp["data"]
@ -173,7 +180,6 @@ class JudgeDispatcher(DispatcherBase):
else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save()
self.release_judge_server(server.id)
if self.contest_id:
if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \
@ -181,8 +187,9 @@ class JudgeDispatcher(DispatcherBase):
logger.info(
"Contest debug mode, id: " + str(self.contest_id) + ", submission id: " + self.submission.id)
return
self.update_contest_problem_status()
self.update_contest_rank()
with transaction.atomic():
self.update_contest_problem_status()
self.update_contest_rank()
else:
if self.last_result:
self.update_problem_status_rejudge()
@ -322,20 +329,31 @@ class JudgeDispatcher(DispatcherBase):
def update_contest_rank(self):
if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")
with transaction.atomic():
if self.contest.rule_type == ContestRuleType.ACM:
acm_rank, _ = ACMContestRank.objects.select_for_update(). \
get_or_create(user_id=self.submission.user_id, contest=self.contest)
self._update_acm_contest_rank(acm_rank)
else:
oi_rank, _ = OIContestRank.objects.select_for_update(). \
get_or_create(user_id=self.submission.user_id, contest=self.contest)
self._update_oi_contest_rank(oi_rank)
def get_rank(model):
return model.objects.select_for_update().get(user_id=self.submission.user_id, contest=self.contest)
if self.contest.rule_type == ContestRuleType.ACM:
model = ACMContestRank
func = self._update_acm_contest_rank
else:
model = OIContestRank
func = self._update_oi_contest_rank
try:
rank = get_rank(model)
except model.DoesNotExist:
try:
model.objects.create(user_id=self.submission.user_id, contest=self.contest)
rank = get_rank(model)
except IntegrityError:
rank = get_rank(model)
func(rank)
def _update_acm_contest_rank(self, rank):
info = rank.submission_info.get(str(self.submission.problem_id))
# 因前面更改过,这里需要重新获取
problem = Problem.objects.get(contest_id=self.contest_id, id=self.problem.id)
problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
# 此题提交过
if info:
if info["is_ac"]:

View File

@ -1,3 +1,6 @@
from problem.models import ProblemIOMode
default_env = ["LANG=en_US.UTF-8", "LANGUAGE=en_US:en", "LC_ALL=en_US.UTF-8"]
_c_lang_config = {
@ -28,7 +31,7 @@ int main() {
},
"run": {
"command": "{exe_path}",
"seccomp_rule": "c_cpp",
"seccomp_rule": {ProblemIOMode.standard: "c_cpp", ProblemIOMode.file: "c_cpp_file_io"},
"env": default_env
}
}

View File

@ -1,12 +1,12 @@
from __future__ import absolute_import, unicode_literals
from celery import shared_task
import dramatiq
from account.models import User
from submission.models import Submission
from judge.dispatcher import JudgeDispatcher
from utils.shortcuts import DRAMATIQ_WORKER_ARGS
@shared_task
@dramatiq.actor(**DRAMATIQ_WORKER_ARGS())
def judge_task(submission_id, problem_id):
uid = Submission.objects.get(id=submission_id).user_id
if User.objects.get(id=uid).is_disabled:

View File

@ -1,6 +0,0 @@
from __future__ import absolute_import, unicode_literals
# Django starts so that shared_task will use this app.
from .celery import app as celery_app
__all__ = ["celery_app"]

View File

@ -1,18 +0,0 @@
from __future__ import absolute_import, unicode_literals
import os
from celery import Celery
from django.conf import settings
# set the default Django settings module for the "celery" program.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
app = Celery("oj")
# Using a string here means the worker will not have to
# pickle the object when using Windows.
app.config_from_object("django.conf:settings")
# load task modules from all registered Django app configs.
app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
# app.autodiscover_tasks()

View File

@ -26,16 +26,22 @@ with open(os.path.join(DATA_DIR, "config", "secret.key"), "r") as f:
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Applications
VENDOR_APPS = (
VENDOR_APPS = [
'django.contrib.auth',
'django.contrib.sessions',
'django.contrib.contenttypes',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'raven.contrib.django.raven_compat'
)
LOCAL_APPS = (
'django_dramatiq',
'django_dbconn_retry',
]
if production_env:
VENDOR_APPS.append('raven.contrib.django.raven_compat')
LOCAL_APPS = [
'account',
'announcement',
'conf',
@ -45,11 +51,11 @@ LOCAL_APPS = (
'submission',
'options',
'judge',
)
]
INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS
MIDDLEWARE_CLASSES = (
MIDDLEWARE = (
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
@ -164,6 +170,11 @@ LOGGING = {
'level': 'ERROR',
'propagate': True,
},
'dramatiq': {
'handlers': LOGGING_HANDLERS,
'level': 'DEBUG',
'propagate': False,
},
'': {
'handlers': LOGGING_HANDLERS,
'level': 'WARNING',
@ -202,11 +213,32 @@ CACHES = {
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default"
CELERY_RESULT_BACKEND = f"{REDIS_URL}/2"
BROKER_URL = f"{REDIS_URL}/3"
CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180
CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json"
DRAMATIQ_BROKER = {
"BROKER": "dramatiq.brokers.redis.RedisBroker",
"OPTIONS": {
"url": f"{REDIS_URL}/4",
},
"MIDDLEWARE": [
# "dramatiq.middleware.Prometheus",
"dramatiq.middleware.AgeLimit",
"dramatiq.middleware.TimeLimit",
"dramatiq.middleware.Callbacks",
"dramatiq.middleware.Retries",
# "django_dramatiq.middleware.AdminMiddleware",
"django_dramatiq.middleware.DbConnectionsMiddleware"
]
}
DRAMATIQ_RESULT_BACKEND = {
"BACKEND": "dramatiq.results.backends.redis.RedisBackend",
"BACKEND_OPTIONS": {
"url": f"{REDIS_URL}/4",
},
"MIDDLEWARE_OPTIONS": {
"result_ttl": None
}
}
RAVEN_CONFIG = {
'dsn': 'https://b200023b8aed4d708fb593c5e0a6ad3d:1fddaba168f84fcf97e0d549faaeaff0@sentry.io/263057'
}

View File

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2018-05-01 04:36
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('options', '0002_auto_20180501_0436'),
]
operations = [
migrations.RunSQL("""
DELETE FROM options_sysoptions WHERE key = 'languages';
""")
]

View File

@ -1,13 +1,92 @@
import functools
import os
from django.core.cache import cache
import threading
import time
from django.db import transaction, IntegrityError
from utils.constants import CacheKey
from utils.shortcuts import rand_str
from judge.languages import languages
from .models import SysOptions as SysOptionsModel
class my_property:
"""
metaclass 中使用以实现
1. ttl = None不缓存
2. ttl is callable条件缓存
3. 缓存 ttl
"""
def __init__(self, func=None, fset=None, ttl=None):
self.fset = fset
self.local = threading.local()
self.ttl = ttl
self._check_ttl(ttl)
self.func = func
functools.update_wrapper(self, func)
def _check_ttl(self, value):
if value is None or callable(value):
return
return self._check_timeout(value)
def _check_timeout(self, value):
if not isinstance(value, int):
raise ValueError(f"Invalid timeout type: {type(value)}")
if value < 0:
raise ValueError("Invalid timeout value, it must >= 0")
def __get__(self, obj, cls):
if obj is None:
return self
now = time.time()
if self.ttl:
if hasattr(self.local, "value"):
value, expire_at = self.local.value
if now < expire_at:
return value
value = self.func(obj)
# 如果定义了条件缓存, ttl 是一个函数,返回要缓存多久;返回 0 代表不要缓存
if callable(self.ttl):
# 而且条件缓存说不要缓存,那就直接返回,不要设置 local
timeout = self.ttl(value)
self._check_timeout(timeout)
if timeout == 0:
return value
elif timeout > 0:
self.local.value = (value, now + timeout)
else:
# ttl 是一个数字
self.local.value = (value, now + self.ttl)
return value
else:
return self.func(obj)
def __set__(self, obj, value):
if not self.fset:
raise AttributeError("can't set attribute")
self.fset(obj, value)
if hasattr(self.local, "value"):
del self.local.value
def setter(self, func):
self.fset = func
return self
def __call__(self, func, *args, **kwargs) -> "my_property":
if self.func is None:
self.func = func
functools.update_wrapper(self, func)
return self
DEFAULT_SHORT_TTL = 2
def default_token():
token = os.environ.get("JUDGE_SERVER_TOKEN")
return token if token else rand_str()
@ -41,23 +120,10 @@ class OptionDefaultValue:
class _SysOptionsMeta(type):
@classmethod
def _set_cache(mcs, option_key, option_value):
cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60)
@classmethod
def _del_cache(mcs, option_key):
cache.delete(f"{CacheKey.option}:{option_key}")
@classmethod
def _get_keys(cls):
return [key for key in OptionKeys.__dict__ if not key.startswith("__")]
def rebuild_cache(cls):
for key in cls._get_keys():
# get option 的时候会写 cache 的
cls._get_option(key, use_cache=False)
@classmethod
def _init_option(mcs):
for item in mcs._get_keys():
@ -71,19 +137,14 @@ class _SysOptionsMeta(type):
pass
@classmethod
def _get_option(mcs, option_key, use_cache=True):
def _get_option(mcs, option_key):
try:
if use_cache:
option = cache.get(f"{CacheKey.option}:{option_key}")
if option:
return option
option = SysOptionsModel.objects.get(key=option_key)
value = option.value
mcs._set_cache(option_key, value)
return value
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._get_option(option_key, use_cache=use_cache)
return mcs._get_option(option_key)
@classmethod
def _set_option(mcs, option_key: str, option_value):
@ -92,7 +153,6 @@ class _SysOptionsMeta(type):
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
option.value = option_value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
mcs._set_option(option_key, option_value)
@ -105,7 +165,6 @@ class _SysOptionsMeta(type):
value = option.value + 1
option.value = value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._increment(option_key)
@ -122,7 +181,7 @@ class _SysOptionsMeta(type):
result[key] = mcs._get_option(key)
return result
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def website_base_url(cls):
return cls._get_option(OptionKeys.website_base_url)
@ -130,7 +189,7 @@ class _SysOptionsMeta(type):
def website_base_url(cls, value):
cls._set_option(OptionKeys.website_base_url, value)
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def website_name(cls):
return cls._get_option(OptionKeys.website_name)
@ -138,7 +197,7 @@ class _SysOptionsMeta(type):
def website_name(cls, value):
cls._set_option(OptionKeys.website_name, value)
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def website_name_shortcut(cls):
return cls._get_option(OptionKeys.website_name_shortcut)
@ -146,7 +205,7 @@ class _SysOptionsMeta(type):
def website_name_shortcut(cls, value):
cls._set_option(OptionKeys.website_name_shortcut, value)
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def website_footer(cls):
return cls._get_option(OptionKeys.website_footer)
@ -154,7 +213,7 @@ class _SysOptionsMeta(type):
def website_footer(cls, value):
cls._set_option(OptionKeys.website_footer, value)
@property
@my_property
def allow_register(cls):
return cls._get_option(OptionKeys.allow_register)
@ -162,7 +221,7 @@ class _SysOptionsMeta(type):
def allow_register(cls, value):
cls._set_option(OptionKeys.allow_register, value)
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def submission_list_show_all(cls):
return cls._get_option(OptionKeys.submission_list_show_all)
@ -170,7 +229,7 @@ class _SysOptionsMeta(type):
def submission_list_show_all(cls, value):
cls._set_option(OptionKeys.submission_list_show_all, value)
@property
@my_property
def smtp_config(cls):
return cls._get_option(OptionKeys.smtp_config)
@ -178,7 +237,7 @@ class _SysOptionsMeta(type):
def smtp_config(cls, value):
cls._set_option(OptionKeys.smtp_config, value)
@property
@my_property
def judge_server_token(cls):
return cls._get_option(OptionKeys.judge_server_token)
@ -186,7 +245,7 @@ class _SysOptionsMeta(type):
def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value)
@property
@my_property
def throttling(cls):
return cls._get_option(OptionKeys.throttling)
@ -194,7 +253,7 @@ class _SysOptionsMeta(type):
def throttling(cls, value):
cls._set_option(OptionKeys.throttling, value)
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def languages(cls):
return cls._get_option(OptionKeys.languages)
@ -202,15 +261,15 @@ class _SysOptionsMeta(type):
def languages(cls, value):
cls._set_option(OptionKeys.languages, value)
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def spj_languages(cls):
return [item for item in cls.languages if "spj" in item]
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def language_names(cls):
return [item["name"] for item in languages]
@property
@my_property(ttl=DEFAULT_SHORT_TTL)
def spj_language_names(cls):
return [item["name"] for item in cls.languages if "spj" in item]

View File

@ -0,0 +1,20 @@
# Generated by Django 2.1.7 on 2019-03-12 07:13
import django.contrib.postgres.fields.jsonb
from django.db import migrations
import problem.models
class Migration(migrations.Migration):
dependencies = [
('problem', '0012_auto_20180501_0436'),
]
operations = [
migrations.AddField(
model_name='problem',
name='io_mode',
field=django.contrib.postgres.fields.jsonb.JSONField(default=problem.models._default_io_mode),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 2.1.7 on 2019-03-13 09:38
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0013_problem_io_mode'),
]
operations = [
migrations.AddField(
model_name='problem',
name='share_submission',
field=models.BooleanField(default=False),
),
]

View File

@ -25,10 +25,19 @@ class ProblemDifficulty(object):
Low = "Low"
class ProblemIOMode(Choices):
standard = "Standard IO"
file = "File IO"
def _default_io_mode():
return {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"}
class Problem(models.Model):
# display ID
_id = models.TextField(db_index=True)
contest = models.ForeignKey(Contest, null=True)
contest = models.ForeignKey(Contest, null=True, on_delete=models.CASCADE)
# for contest problem
is_public = models.BooleanField(default=False)
title = models.TextField()
@ -47,11 +56,13 @@ class Problem(models.Model):
create_time = models.DateTimeField(auto_now_add=True)
# we can not use auto_now here
last_update_time = models.DateTimeField(null=True)
created_by = models.ForeignKey(User)
created_by = models.ForeignKey(User, on_delete=models.CASCADE)
# ms
time_limit = models.IntegerField()
# MB
memory_limit = models.IntegerField()
# io mode
io_mode = JSONField(default=_default_io_mode)
# special judge related
spj = models.BooleanField(default=False)
spj_language = models.TextField(null=True)
@ -69,6 +80,7 @@ class Problem(models.Model):
accepted_number = models.BigIntegerField(default=0)
# {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
statistic_info = JSONField(default=dict)
share_submission = models.BooleanField(default=False)
class Meta:
db_table = "problem"

View File

@ -1,3 +1,5 @@
import re
from django import forms
from options.options import SysOptions
@ -5,7 +7,7 @@ from utils.api import UsernameSerializer, serializers
from utils.constants import Difficulty
from utils.serializers import LanguageNameMultiChoiceField, SPJLanguageNameChoiceField, LanguageNameChoiceField
from .models import Problem, ProblemRuleType, ProblemTag
from .models import Problem, ProblemRuleType, ProblemTag, ProblemIOMode
from .utils import parse_problem_template
@ -29,6 +31,20 @@ class CreateProblemCodeTemplateSerializer(serializers.Serializer):
pass
class ProblemIOModeSerializer(serializers.Serializer):
io_mode = serializers.ChoiceField(choices=ProblemIOMode.choices())
input = serializers.CharField()
output = serializers.CharField()
def validate(self, attrs):
if attrs["input"] == attrs["output"]:
raise serializers.ValidationError("Invalid io mode")
for item in (attrs["input"], attrs["output"]):
if not re.match("^[a-zA-Z0-9.]+$", item):
raise serializers.ValidationError("Invalid io file name format")
return attrs
class CreateOrEditProblemSerializer(serializers.Serializer):
_id = serializers.CharField(max_length=32, allow_blank=True, allow_null=True)
title = serializers.CharField(max_length=1024)
@ -43,6 +59,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
languages = LanguageNameMultiChoiceField()
template = serializers.DictField(child=serializers.CharField(min_length=1))
rule_type = serializers.ChoiceField(choices=[ProblemRuleType.ACM, ProblemRuleType.OI])
io_mode = ProblemIOModeSerializer()
spj = serializers.BooleanField()
spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True)
@ -52,6 +69,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False)
hint = serializers.CharField(allow_blank=True, allow_null=True)
source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True)
share_submission = serializers.BooleanField()
class CreateProblemSerializer(CreateOrEditProblemSerializer):

View File

@ -9,7 +9,7 @@ from django.conf import settings
from utils.api.tests import APITestCase
from .models import ProblemTag
from .models import ProblemTag, ProblemIOMode
from .models import Problem, ProblemRuleType
from contest.models import Contest
from contest.tests import DEFAULT_CONTEST_DATA
@ -25,6 +25,8 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"io_mode": {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"},
"share_submission": False,
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}

View File

@ -1,4 +1,6 @@
import re
from functools import lru_cache
TEMPLATE_BASE = """//PREPEND BEGIN
{}
@ -13,6 +15,7 @@ TEMPLATE_BASE = """//PREPEND BEGIN
//APPEND END"""
@lru_cache(maxsize=100)
def parse_problem_template(template_str):
prepend = re.findall(r"//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str)
template = re.findall(r"//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str)
@ -22,5 +25,6 @@ def parse_problem_template(template_str):
"append": append[0] if append else ""}
@lru_cache(maxsize=100)
def build_problem_template(prepend, template, append):
return TEMPLATE_BASE.format(prepend, template, append)

View File

@ -300,8 +300,6 @@ class ProblemAPI(ProblemBase):
except Problem.DoesNotExist:
return self.error("Problem does not exists")
ensure_created_by(problem, request.user)
if Submission.objects.filter(problem=problem).exists():
return self.error("Can't delete the problem as it has submissions")
d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if os.path.isdir(d):
shutil.rmtree(d, ignore_errors=True)
@ -541,7 +539,7 @@ class ExportProblemAPI(APIView):
with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems):
self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1)
delete_files.apply_async((path,), countdown=300)
delete_files.send_with_options(args=(path,), delay=300_000)
resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = f"attachment;filename=problem-export.zip"

View File

@ -25,7 +25,7 @@ class PickOneAPI(APIView):
class ProblemAPI(APIView):
@staticmethod
def _add_problem_status(request, queryset_values):
if request.user.is_authenticated():
if request.user.is_authenticated:
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
@ -81,7 +81,7 @@ class ProblemAPI(APIView):
class ContestProblemAPI(APIView):
def _add_problem_status(self, request, queryset_values):
if request.user.is_authenticated():
if request.user.is_authenticated:
profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get("contest_problems", {})

View File

@ -22,8 +22,8 @@ class JudgeStatus:
class Submission(models.Model):
id = models.TextField(default=rand_str, primary_key=True, db_index=True)
contest = models.ForeignKey(Contest, null=True)
problem = models.ForeignKey(Problem)
contest = models.ForeignKey(Contest, null=True, on_delete=models.CASCADE)
problem = models.ForeignKey(Problem, on_delete=models.CASCADE)
create_time = models.DateTimeField(auto_now_add=True)
user_id = models.IntegerField(db_index=True)
username = models.TextField()
@ -41,6 +41,7 @@ class Submission(models.Model):
def check_user_permission(self, user, check_share=True):
return self.user_id == user.id or \
(check_share and self.shared is True) or \
(check_share and self.problem.share_submission) or \
user.is_super_admin() or \
user.can_mgmt_all_problem() or \
self.problem.created_by_id == user.id

View File

@ -46,6 +46,6 @@ class SubmissionListSerializer(serializers.ModelSerializer):
def get_show_link(self, obj):
# 没传user或为匿名user
if self.user is None or not self.user.is_authenticated():
if self.user is None or not self.user.is_authenticated:
return False
return obj.check_user_permission(self.user)

View File

@ -57,7 +57,7 @@ class SubmissionListTest(SubmissionPrepare):
self.assertSuccess(resp)
@mock.patch("submission.views.oj.judge_task.delay")
@mock.patch("submission.views.oj.judge_task.send")
class SubmissionAPITest(SubmissionPrepare):
def setUp(self):
self._create_problem_and_submission()

View File

@ -18,5 +18,5 @@ class SubmissionRejudgeAPI(APIView):
submission.statistic_info = {}
submission.save()
judge_task.delay(submission.id, submission.problem.id)
judge_task.send(submission.id, submission.problem.id)
return self.success()

View File

@ -80,7 +80,7 @@ class SubmissionAPI(APIView):
contest_id=data.get("contest_id"))
# use this for debug
# JudgeDispatcher(submission.id, problem.id).judge()
judge_task.delay(submission.id, problem.id)
judge_task.send(submission.id, problem.id)
if hide_id:
return self.success()
else:
@ -198,6 +198,6 @@ class SubmissionExistsAPI(APIView):
def get(self, request):
if not request.GET.get("problem_id"):
return self.error("Parameter error, problem_id is required")
return self.success(request.user.is_authenticated() and
return self.success(request.user.is_authenticated and
Submission.objects.filter(problem_id=request.GET["problem_id"],
user_id=request.user.id).exists())

View File

@ -1,7 +1,6 @@
import functools
import json
import logging
from collections import OrderedDict
from django.http import HttpResponse, QueryDict
from django.utils.decorators import method_decorator
@ -89,20 +88,24 @@ class APIView(View):
def error(self, msg="error", err="error"):
return self.response({"error": err, "data": msg})
def _serializer_error_to_str(self, errors):
for k, v in errors.items():
if isinstance(v, list):
return k, v[0]
elif isinstance(v, OrderedDict):
for _k, _v in v.items():
return self._serializer_error_to_str({_k: _v})
def extract_errors(self, errors, key="field"):
if isinstance(errors, dict):
if not errors:
return key, "Invalid field"
key = list(errors.keys())[0]
return self.extract_errors(errors.pop(key), key)
elif isinstance(errors, list):
return self.extract_errors(errors[0], key)
return key, errors
def invalid_serializer(self, serializer):
k, v = self._serializer_error_to_str(serializer.errors)
if k != "non_field_errors":
return self.error(err="invalid-" + k, msg=k + ": " + v)
key, error = self.extract_errors(serializer.errors)
if key == "non_field_errors":
msg = error
else:
return self.error(err="invalid-field", msg=v)
msg = f"{key}: {error}"
return self.error(err=f"invalid-{key}", msg=msg)
def server_error(self):
return self.error(err="server-error", msg="server error")

View File

@ -1,4 +1,4 @@
from django.core.urlresolvers import reverse
from django.urls import reverse
from django.test.testcases import TestCase
from rest_framework.test import APIClient

View File

@ -25,7 +25,6 @@ class CacheKey:
waiting_queue = "waiting_queue"
contest_rank_cache = "contest_rank_cache"
website_config = "website_config"
option = "option"
class Difficulty(Choices):

View File

@ -81,3 +81,14 @@ def send_email(smtp_config, from_name, to_email, to_name, subject, content):
def get_env(name, default=""):
return os.environ.get(name, default)
def DRAMATIQ_WORKER_ARGS(time_limit=3600_000, max_retries=0, max_age=7200_000):
return {"max_retries": max_retries, "time_limit": time_limit, "max_age": max_age}
def check_is_id(value):
try:
return int(value) > 0
except Exception:
return False

View File

@ -1,8 +1,10 @@
import os
from celery import shared_task
import dramatiq
from utils.shortcuts import DRAMATIQ_WORKER_ARGS
@shared_task
@dramatiq.actor(**DRAMATIQ_WORKER_ARGS())
def delete_files(*args):
for item in args:
try:

View File

@ -1,7 +1,8 @@
from django.conf.urls import url
from .views import SimditorImageUploadAPIView
from .views import SimditorImageUploadAPIView, SimditorFileUploadAPIView
urlpatterns = [
url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image")
url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image"),
url(r"^upload_file/?$", SimditorFileUploadAPIView.as_view(), name="upload_file")
]

View File

@ -1,6 +1,6 @@
import os
from django.conf import settings
from account.serializers import ImageUploadForm
from account.serializers import ImageUploadForm, FileUploadForm
from utils.shortcuts import rand_str
from utils.api import CSRFExemptAPIView
import logging
@ -35,10 +35,41 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView):
except IOError as e:
logger.error(e)
return self.response({
"success": True,
"success": False,
"msg": "Upload Error",
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
"file_path": ""})
return self.response({
"success": True,
"msg": "Success",
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
class SimditorFileUploadAPIView(CSRFExemptAPIView):
request_parsers = ()
def post(self, request):
form = FileUploadForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
else:
return self.response({
"success": False,
"msg": "Upload failed"
})
suffix = os.path.splitext(file.name)[-1].lower()
file_name = rand_str(10) + suffix
try:
with open(os.path.join(settings.UPLOAD_DIR, file_name), "wb") as f:
for chunk in file:
f.write(chunk)
except IOError as e:
logger.error(e)
return self.response({
"success": False,
"msg": "Upload Error"})
return self.response({
"success": True,
"msg": "Success",
"file_path": f"{settings.UPLOAD_PREFIX}/{file_name}",
"file_name": file.name})

View File

@ -142,7 +142,7 @@ class XSSHtml(HTMLParser):
return attrs
def _true_url(self, url):
prog = re.compile(r"^(http|https|ftp)://.+", re.I | re.S)
prog = re.compile(r"(^(http|https|ftp)://.+)|(^/)", re.I | re.S)
if prog.match(url):
return url
else: