OnlineJudge/problem/views/admin.py
2018-10-14 17:49:20 +08:00

700 lines
30 KiB
Python

import hashlib
import json
import os
import shutil
import tempfile
import zipfile
from wsgiref.util import FileWrapper
from django.conf import settings
from django.db import transaction
from django.db.models import Q
from django.http import StreamingHttpResponse, FileResponse
from account.decorators import problem_permission_required, ensure_created_by
from contest.models import Contest, ContestStatus
from fps.parser import FPSHelper, FPSParser
from judge.dispatcher import SPJCompiler
from judge.languages import language_names
from submission.models import Submission, JudgeStatus
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.constants import Difficulty
from utils.shortcuts import rand_str, natural_sort_key
from utils.tasks import delete_files
from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer,
CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer,
ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer,
AddContestProblemSerializer, ExportProblemSerializer,
ExportProblemRequestSerialzier, UploadProblemForm, ImportProblemSerializer,
FPSProblemSerializer)
from ..utils import TEMPLATE_BASE, build_problem_template
class TestCaseZipProcessor(object):
    """Unpack a zip of numbered test cases into ``settings.TEST_CASE_DIR``."""

    def process_zip(self, uploaded_zip_file, spj, dir=""):
        """Extract the test cases of an archive into a new test-case directory.

        Args:
            uploaded_zip_file: Path or file object accepted by ``zipfile.ZipFile``.
            spj: When true only ``N.in`` files are collected; otherwise
                ``N.in``/``N.out`` pairs are collected.
            dir: Prefix prepended to each member name when it is looked up
                inside the archive (for example ``"1/testcase/"``).

        Returns:
            A ``(info, test_case_id)`` tuple. ``info`` is the list of per-case
            dicts that is also written to the ``info`` file, ``test_case_id``
            is the name of the directory created under
            ``settings.TEST_CASE_DIR``.

        Raises:
            APIError: The file is not a valid zip, or it holds no usable
                test case.
        """
        try:
            zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
        except zipfile.BadZipFile:
            raise APIError("Bad zip file")

        # The context manager closes the archive handle on every exit path,
        # including the "Empty file" error below.
        with zip_file:
            test_case_list = self.filter_name_list(zip_file.namelist(), spj=spj, dir=dir)
            if not test_case_list:
                raise APIError("Empty file")

            test_case_id = rand_str()
            test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
            os.mkdir(test_case_dir)
            os.chmod(test_case_dir, 0o710)

            size_cache = {}
            md5_cache = {}
            for item in test_case_list:
                # Normalise Windows line endings before sizes/hashes are taken.
                content = zip_file.read(f"{dir}{item}").replace(b"\r\n", b"\n")
                size_cache[item] = len(content)
                if item.endswith(".out"):
                    # The hash ignores trailing whitespace of the expected output.
                    md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
                with open(os.path.join(test_case_dir, item), "wb") as f:
                    f.write(content)

        test_case_info = {"spj": spj, "test_cases": {}}
        info = []
        if spj:
            for index, item in enumerate(test_case_list):
                data = {"input_name": item, "input_size": size_cache[item]}
                info.append(data)
                test_case_info["test_cases"][str(index + 1)] = data
        else:
            # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
            pairs = zip(test_case_list[0::2], test_case_list[1::2])
            for index, (input_name, output_name) in enumerate(pairs):
                data = {"stripped_output_md5": md5_cache[output_name],
                        "input_size": size_cache[input_name],
                        "output_size": size_cache[output_name],
                        "input_name": input_name,
                        "output_name": output_name}
                info.append(data)
                test_case_info["test_cases"][str(index + 1)] = data

        with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
            f.write(json.dumps(test_case_info, indent=4))

        for item in os.listdir(test_case_dir):
            os.chmod(os.path.join(test_case_dir, item), 0o640)

        return info, test_case_id

    def filter_name_list(self, name_list, spj, dir=""):
        """Return the consecutive test-case file names present in ``name_list``.

        Starting at ``1``, names are collected while ``{dir}N.in`` (and, when
        ``spj`` is false, also ``{dir}N.out``) is present; the scan stops at
        the first missing number. Returned names do not carry the ``dir``
        prefix and are sorted with ``natural_sort_key``.
        """
        # A set makes each membership test O(1) instead of scanning the list.
        available = set(name_list)
        ret = []
        prefix = 1
        while True:
            wanted = [f"{prefix}.in"] if spj else [f"{prefix}.in", f"{prefix}.out"]
            if not all(f"{dir}{name}" in available for name in wanted):
                return sorted(ret, key=natural_sort_key)
            ret.extend(wanted)
            prefix += 1
class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
    """Download (GET) or upload (POST) the test-case archive of a problem."""

    request_parsers = ()

    def get(self, request):
        """Zip the stored test cases of ``problem_id`` and stream them back.

        Access is checked against the contest for contest problems and
        against the problem itself otherwise.
        """
        problem_id = request.GET.get("problem_id")
        if not problem_id:
            return self.error("Parameter error, problem_id is required")
        try:
            problem = Problem.objects.get(id=problem_id)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")

        if problem.contest:
            ensure_created_by(problem.contest, request.user)
        else:
            ensure_created_by(problem, request.user)

        test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
        if not os.path.isdir(test_case_dir):
            return self.error("Test case does not exists")
        name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
        name_list.append("info")

        # The archive is written next to the test cases it contains.
        file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
        with zipfile.ZipFile(file_name, "w") as archive:
            for test_case in name_list:
                archive.write(f"{test_case_dir}/{test_case}", test_case)

        response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")),
                                         content_type="application/octet-stream")
        response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip"
        response["Content-Length"] = os.path.getsize(file_name)
        return response

    def post(self, request):
        """Store an uploaded test-case zip and return its id and case info."""
        form = TestCaseUploadForm(request.POST, request.FILES)
        if not form.is_valid():
            return self.error("Upload failed")
        # The form delivers ``spj`` as the string "true"/"false".
        spj = form.cleaned_data["spj"] == "true"
        file = form.cleaned_data["file"]

        zip_file = f"/tmp/{rand_str()}.zip"
        try:
            with open(zip_file, "wb") as f:
                for chunk in file:
                    f.write(chunk)
            info, test_case_id = self.process_zip(zip_file, spj=spj)
        finally:
            # process_zip raises APIError for bad or empty archives; the
            # temporary copy must be removed in that case as well.
            if os.path.exists(zip_file):
                os.remove(zip_file)
        return self.success({"id": test_case_id, "info": info, "spj": spj})
class CompileSPJAPI(APIView):
    """Endpoint that compiles submitted special-judge code."""

    @validate_serializer(CompileSPJSerializer)
    def post(self, request):
        """Compile the posted SPJ code; reply with the compiler error, if any."""
        payload = request.data
        # A fresh random version is used for every compile attempt.
        version = rand_str(8)
        compiler = SPJCompiler(payload["spj_code"], version, payload["spj_language"])
        compile_error = compiler.compile_spj()
        if compile_error:
            return self.error(compile_error)
        return self.success()
class ProblemBase(APIView):
    """Base view holding the validation shared by the problem endpoints."""

    def common_checks(self, request):
        """Validate and normalise ``request.data`` in place.

        Returns an error message string when validation fails, otherwise
        ``None``. On success the SPJ fields, ``total_score`` (OI rule only)
        and ``languages`` have been normalised on ``request.data``.
        """
        payload = request.data

        if not payload["spj"]:
            # Without a special judge the SPJ fields are cleared.
            payload["spj_language"] = None
            payload["spj_code"] = None
        else:
            if not (payload["spj_language"] and payload["spj_code"]):
                return "Invalid spj"
            if not payload["spj_compile_ok"]:
                return "SPJ code must be compiled successfully"
            # The version is derived from language and code.
            version_source = payload["spj_language"] + ":" + payload["spj_code"]
            payload["spj_version"] = hashlib.md5(version_source.encode("utf-8")).hexdigest()

        if payload["rule_type"] == ProblemRuleType.OI:
            scores = payload["test_case_score"]
            # Every test case must carry a strictly positive score.
            if any(entry["score"] <= 0 for entry in scores):
                return "Invalid score"
            payload["total_score"] = sum(entry["score"] for entry in scores)

        payload["languages"] = list(payload["languages"])
class ProblemAPI(ProblemBase):
    """Create, fetch/list, edit and delete problems.

    Listing, display-id uniqueness checks and deletion only consider
    problems that are not attached to a contest.
    """

    @problem_permission_required
    @validate_serializer(CreateProblemSerializer)
    def post(self, request):
        """Create a problem from the request data and attach its tags."""
        data = request.data
        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists():
            return self.error("Display ID already exists")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        tags = data.pop("tags")
        data["created_by"] = request.user
        problem = Problem.objects.create(**data)

        for item in tags:
            # get_or_create replaces the get / except DoesNotExist / create
            # sequence, matching how ImportProblemAPI resolves tags.
            tag, _ = ProblemTag.objects.get_or_create(name=item)
            problem.tags.add(tag)
        return self.success(ProblemAdminSerializer(problem).data)

    @problem_permission_required
    def get(self, request):
        """Return one problem (``id`` given) or a paginated, filtered list.

        Optional query parameters for the list: ``rule_type`` and
        ``keyword`` (matched against title and display id).
        """
        problem_id = request.GET.get("id")
        rule_type = request.GET.get("rule_type")
        user = request.user
        if problem_id:
            try:
                problem = Problem.objects.get(id=problem_id)
                ensure_created_by(problem, request.user)
                return self.success(ProblemAdminSerializer(problem).data)
            except Problem.DoesNotExist:
                return self.error("Problem does not exist")

        problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time")
        if rule_type:
            if rule_type not in ProblemRuleType.choices():
                return self.error("Invalid rule_type")
            problems = problems.filter(rule_type=rule_type)

        keyword = request.GET.get("keyword", "").strip()
        if keyword:
            problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
        # Users without the manage-all permission only see their own problems.
        if not user.can_mgmt_all_problem():
            problems = problems.filter(created_by=user)
        return self.success(self.paginate_data(request, problems, ProblemAdminSerializer))

    @problem_permission_required
    @validate_serializer(EditProblemSerializer)
    def put(self, request):
        """Update an existing problem and replace its tag set."""
        data = request.data
        problem_id = data.pop("id")

        try:
            problem = Problem.objects.get(id=problem_id)
            ensure_created_by(problem, request.user)
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists():
            return self.error("Display ID already exists")

        # common_checks also converts data["languages"] to a list, so no
        # second conversion is needed here.
        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        tags = data.pop("tags")

        for k, v in data.items():
            setattr(problem, k, v)
        problem.save()

        problem.tags.remove(*problem.tags.all())
        for item in tags:
            tag, _ = ProblemTag.objects.get_or_create(name=item)
            problem.tags.add(tag)
        return self.success()

    @problem_permission_required
    def delete(self, request):
        """Delete a non-contest problem that has no submissions.

        The problem's test-case directory is removed as well.
        """
        problem_id = request.GET.get("id")
        if not problem_id:
            return self.error("Invalid parameter, id is required")
        try:
            problem = Problem.objects.get(id=problem_id, contest_id__isnull=True)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")
        ensure_created_by(problem, request.user)
        if Submission.objects.filter(problem=problem).exists():
            return self.error("Can't delete the problem as it has submissions")
        test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
        if os.path.isdir(test_case_dir):
            shutil.rmtree(test_case_dir, ignore_errors=True)
        problem.delete()
        return self.success()
class ContestProblemAPI(ProblemBase):
    """Create, fetch/list, edit and delete problems that belong to a contest."""

    @validate_serializer(CreateContestProblemSerializer)
    def post(self, request):
        """Create a problem inside the contest given by ``contest_id``."""
        data = request.data
        try:
            contest = Contest.objects.get(id=data.pop("contest_id"))
            ensure_created_by(contest, request.user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")

        # A contest problem must use the rule type of its contest.
        if data["rule_type"] != contest.rule_type:
            return self.error("Invalid rule type")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.filter(_id=_id, contest=contest).exists():
            return self.error("Duplicate Display id")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        data["contest"] = contest
        tags = data.pop("tags")
        data["created_by"] = request.user
        problem = Problem.objects.create(**data)

        for item in tags:
            # get_or_create replaces the get / except DoesNotExist / create
            # sequence, matching how ImportProblemAPI resolves tags.
            tag, _ = ProblemTag.objects.get_or_create(name=item)
            problem.tags.add(tag)
        return self.success(ProblemAdminSerializer(problem).data)

    def get(self, request):
        """Return one problem (``id`` given) or the problems of ``contest_id``."""
        problem_id = request.GET.get("id")
        contest_id = request.GET.get("contest_id")
        user = request.user
        if problem_id:
            try:
                problem = Problem.objects.get(id=problem_id)
                ensure_created_by(problem.contest, user)
            except Problem.DoesNotExist:
                return self.error("Problem does not exist")
            return self.success(ProblemAdminSerializer(problem).data)

        if not contest_id:
            return self.error("Contest id is required")
        try:
            contest = Contest.objects.get(id=contest_id)
            ensure_created_by(contest, user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")

        problems = Problem.objects.filter(contest=contest).order_by("-create_time")
        if user.is_admin():
            problems = problems.filter(contest__created_by=user)
        keyword = request.GET.get("keyword")
        if keyword:
            problems = problems.filter(title__contains=keyword)
        return self.success(self.paginate_data(request, problems, ProblemAdminSerializer))

    @validate_serializer(EditContestProblemSerializer)
    def put(self, request):
        """Update a problem of the given contest and replace its tag set."""
        data = request.data
        user = request.user

        try:
            contest = Contest.objects.get(id=data.pop("contest_id"))
            ensure_created_by(contest, user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")

        if data["rule_type"] != contest.rule_type:
            return self.error("Invalid rule type")

        problem_id = data.pop("id")
        try:
            problem = Problem.objects.get(id=problem_id, contest=contest)
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists():
            return self.error("Display ID already exists")

        # common_checks also converts data["languages"] to a list, so no
        # second conversion is needed here.
        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        tags = data.pop("tags")

        for k, v in data.items():
            setattr(problem, k, v)
        problem.save()

        problem.tags.remove(*problem.tags.all())
        for item in tags:
            tag, _ = ProblemTag.objects.get_or_create(name=item)
            problem.tags.add(tag)
        return self.success()

    def delete(self, request):
        """Delete a contest problem that has no submissions.

        The problem's test-case directory is removed as well.
        """
        problem_id = request.GET.get("id")
        if not problem_id:
            return self.error("Invalid parameter, id is required")
        try:
            problem = Problem.objects.get(id=problem_id, contest_id__isnull=False)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")
        ensure_created_by(problem.contest, request.user)
        if Submission.objects.filter(problem=problem).exists():
            return self.error("Can't delete the problem as it has submissions")
        test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
        if os.path.isdir(test_case_dir):
            shutil.rmtree(test_case_dir, ignore_errors=True)
        problem.delete()
        return self.success()
class MakeContestProblemPublicAPIView(APIView):
    """Publish a contest problem by copying it into the non-contest problem set."""

    @validate_serializer(ContestProblemMakePublicSerializer)
    @problem_permission_required
    def post(self, request):
        """Mark the contest problem public and create an invisible public copy.

        The copy gets ``display_id`` as its display id, no contest, reset
        counters and statistics, and the same tags as the original.
        """
        data = request.data
        display_id = data.get("display_id")
        if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists():
            return self.error("Duplicate display ID")

        try:
            problem = Problem.objects.get(id=data["id"])
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        if not problem.contest or problem.is_public:
            return self.error("Already be a public problem")

        # Flagging the original and creating the copy must succeed or fail
        # together; otherwise a failed copy would leave the contest problem
        # marked public with no public counterpart.
        with transaction.atomic():
            problem.is_public = True
            problem.save()
            # Materialise the tags while the instance still refers to the
            # original row.
            tags = list(problem.tags.all())
            # https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances
            problem.pk = None
            problem.contest = None
            problem._id = display_id
            problem.visible = False
            problem.submission_number = problem.accepted_number = 0
            problem.statistic_info = {}
            problem.save()
            problem.tags.set(tags)
        return self.success()
class AddContestProblemAPI(APIView):
    """Copy an existing problem into a contest under a new display id."""

    @validate_serializer(AddContestProblemSerializer)
    def post(self, request):
        """Clone ``problem_id`` into ``contest_id`` as ``display_id``.

        The clone is public and visible, has reset counters and statistics,
        and carries the tags of the source problem.
        """
        payload = request.data
        try:
            contest = Contest.objects.get(id=payload["contest_id"])
            problem = Problem.objects.get(id=payload["problem_id"])
        except (Contest.DoesNotExist, Problem.DoesNotExist):
            return self.error("Contest or Problem does not exist")

        if contest.status == ContestStatus.CONTEST_ENDED:
            return self.error("Contest has ended")

        duplicates = Problem.objects.filter(contest=contest, _id=payload["display_id"])
        if duplicates.exists():
            return self.error("Duplicate display id in this contest")

        source_tags = problem.tags.all()
        # Clearing the primary key makes the following save() insert a new row.
        problem.pk = None
        problem.contest = contest
        problem.is_public = True
        problem.visible = True
        problem._id = payload["display_id"]
        problem.submission_number = 0
        problem.accepted_number = 0
        problem.statistic_info = {}
        problem.save()
        problem.tags.set(source_tags)
        return self.success()
class ExportProblemAPI(APIView):
    """Export problems, their test cases and sample answers as one zip."""

    def choose_answers(self, user, problem):
        """Return the user's latest accepted submission per problem language.

        Each entry is ``{"language": ..., "code": ...}``; languages without
        an accepted submission by ``user`` are skipped.
        """
        ret = []
        for item in problem.languages:
            submission = Submission.objects.filter(problem=problem,
                                                   user_id=user.id,
                                                   language=item,
                                                   result=JudgeStatus.ACCEPTED).order_by("-create_time").first()
            if submission:
                ret.append({"language": submission.language, "code": submission.code})
        return ret

    def process_one_problem(self, zip_file, user, problem, index):
        """Write ``problem`` into ``zip_file`` under the ``{index}/`` folder.

        Adds ``{index}/problem.json`` plus every input file (and, for
        non-SPJ problems, every output file) listed in the problem's test
        case ``info`` file under ``{index}/testcase/``.
        """
        problem_json = ExportProblemSerializer(problem).data
        problem_json["answers"] = self.choose_answers(user, problem=problem)
        compression = zipfile.ZIP_DEFLATED
        zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json",
                          data=json.dumps(problem_json, indent=4),
                          compress_type=compression)

        problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
        # The info file is written as UTF-8 by TestCaseZipProcessor; read it
        # back with the same encoding instead of the platform default.
        with open(os.path.join(problem_test_case_dir, "info"), encoding="utf-8") as f:
            info = json.load(f)
        for case in info["test_cases"].values():
            zip_file.write(filename=os.path.join(problem_test_case_dir, case["input_name"]),
                           arcname=f"{index}/testcase/{case['input_name']}",
                           compress_type=compression)
            if not info["spj"]:
                zip_file.write(filename=os.path.join(problem_test_case_dir, case["output_name"]),
                               arcname=f"{index}/testcase/{case['output_name']}",
                               compress_type=compression)

    @validate_serializer(ExportProblemRequestSerialzier)
    def get(self, request):
        """Build the export archive for ``problem_id`` and return it as a download.

        Access to every requested problem is checked first. The archive is
        scheduled for deletion 300 seconds after it has been built.
        """
        problems = Problem.objects.filter(id__in=request.data["problem_id"])
        for problem in problems:
            if problem.contest:
                ensure_created_by(problem.contest, request.user)
            else:
                ensure_created_by(problem, request.user)

        path = f"/tmp/{rand_str()}.zip"
        try:
            with zipfile.ZipFile(path, "w") as zip_file:
                for index, problem in enumerate(problems, start=1):
                    self.process_one_problem(zip_file=zip_file, user=request.user,
                                             problem=problem, index=index)
        except Exception:
            # The cleanup task is only scheduled after a successful build, so
            # a half-written archive has to be removed here.
            if os.path.exists(path):
                os.remove(path)
            raise

        delete_files.apply_async((path,), countdown=300)
        resp = FileResponse(open(path, "rb"))
        resp["Content-Type"] = "application/zip"
        resp["Content-Disposition"] = "attachment;filename=problem-export.zip"
        return resp
class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor):
    """Import problems from a zip laid out as ``N/problem.json`` + ``N/testcase/``."""

    request_parsers = ()

    def post(self, request):
        """Import every problem of the uploaded archive in one transaction.

        Returns ``{"import_count": n}`` on success. On any failure nothing
        is committed, the test-case directories created so far are removed,
        and the uploaded temporary file is always deleted.

        Raises:
            APIError: Bad zip file, invalid ``problem.json``, unsupported
                template language, or an error raised by ``process_zip``.
                NOTE(review): this assumes the base view turns APIError into
                an error response, as the existing ``process_zip`` call
                already does - confirm in utils.api.
        """
        form = UploadProblemForm(request.POST, request.FILES)
        if not form.is_valid():
            return self.error("Upload failed")
        file = form.cleaned_data["file"]

        tmp_file = f"/tmp/{rand_str()}.zip"
        created_dirs = []
        try:
            with open(tmp_file, "wb") as f:
                for chunk in file:
                    f.write(chunk)

            try:
                zip_file = zipfile.ZipFile(tmp_file, "r")
            except zipfile.BadZipFile:
                raise APIError("Bad zip file")

            with zip_file:
                count = sum(1 for item in zip_file.namelist() if "/problem.json" in item)
                with transaction.atomic():
                    for i in range(1, count + 1):
                        self._import_one_problem(zip_file, tmp_file, i, request.user, created_dirs)
        except Exception:
            # The database work is rolled back by atomic(); the extracted
            # test cases live on disk and have to be removed by hand.
            for test_case_dir in created_dirs:
                shutil.rmtree(test_case_dir, ignore_errors=True)
            raise
        finally:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
        return self.success({"import_count": count})

    def _import_one_problem(self, zip_file, tmp_file, index, user, created_dirs):
        """Create the problem stored under ``{index}/`` of the archive.

        Validation failures are raised (not returned) so that the enclosing
        ``transaction.atomic()`` block rolls back instead of committing the
        problems imported before the failure. The directory created by
        ``process_zip`` is appended to ``created_dirs`` for later cleanup.
        """
        with zip_file.open(f"{index}/problem.json") as f:
            problem_info = json.load(f)

        serializer = ImportProblemSerializer(data=problem_info)
        if not serializer.is_valid():
            raise APIError(f"Invalid problem format, error is {serializer.errors}")
        problem_info = serializer.data
        for item in problem_info["template"]:
            if item not in language_names:
                raise APIError(f"Unsupported language {item}")

        # Display ids are cut to 24 characters.
        problem_info["display_id"] = problem_info["display_id"][:24]
        for k, v in problem_info["template"].items():
            problem_info["template"][k] = build_problem_template(v["prepend"], v["template"],
                                                                 v["append"])

        spj = problem_info["spj"] is not None
        rule_type = problem_info["rule_type"]
        test_case_score = problem_info["test_case_score"]

        # process test case
        _, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{index}/testcase/")
        created_dirs.append(os.path.join(settings.TEST_CASE_DIR, test_case_id))

        problem_obj = Problem.objects.create(
            _id=problem_info["display_id"],
            title=problem_info["title"],
            description=problem_info["description"]["value"],
            input_description=problem_info["input_description"]["value"],
            output_description=problem_info["output_description"]["value"],
            hint=problem_info["hint"]["value"],
            test_case_score=test_case_score if test_case_score else [],
            time_limit=problem_info["time_limit"],
            memory_limit=problem_info["memory_limit"],
            samples=problem_info["samples"],
            template=problem_info["template"],
            rule_type=rule_type,
            source=problem_info["source"],
            spj=spj,
            spj_code=problem_info["spj"]["code"] if spj else None,
            spj_language=problem_info["spj"]["language"] if spj else None,
            spj_version=rand_str(8) if spj else "",
            languages=language_names,
            created_by=user,
            visible=False,
            difficulty=Difficulty.MID,
            total_score=sum(item["score"] for item in test_case_score)
            if rule_type == ProblemRuleType.OI else 0,
            test_case_id=test_case_id
        )
        for tag_name in problem_info["tags"]:
            tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name)
            problem_obj.tags.add(tag_obj)
class FPSProblemImport(CSRFExemptAPIView):
    """Import problems from an uploaded FPS file."""

    request_parsers = ()

    def _create_problem(self, problem_data, creator):
        """Create one invisible ACM problem from serialized FPS data.

        ``problem_data`` must already contain ``test_case_id``. The time
        limit is stored in milliseconds: values whose unit is not ``"ms"``
        are multiplied by 1000.
        """
        if problem_data["time_limit"]["unit"] == "ms":
            time_limit = problem_data["time_limit"]["value"]
        else:
            time_limit = problem_data["time_limit"]["value"] * 1000

        prepend = {t["language"]: t["code"] for t in problem_data["prepend"]}
        append = {t["language"]: t["code"] for t in problem_data["append"]}
        template = {}
        for t in problem_data["template"]:
            our_lang = lang = t["language"]
            # FPS "Python" is stored under this project's "Python3" key;
            # prepend/append are still looked up by the FPS name.
            if lang == "Python":
                our_lang = "Python3"
            template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, ""))

        spj = problem_data["spj"] is not None
        Problem.objects.create(_id=f"fps-{rand_str(4)}",
                               title=problem_data["title"],
                               description=problem_data["description"],
                               input_description=problem_data["input"],
                               output_description=problem_data["output"],
                               hint=problem_data["hint"],
                               test_case_score=[],
                               time_limit=time_limit,
                               memory_limit=problem_data["memory_limit"]["value"],
                               samples=problem_data["samples"],
                               template=template,
                               rule_type=ProblemRuleType.ACM,
                               source=problem_data.get("source", ""),
                               spj=spj,
                               spj_code=problem_data["spj"]["code"] if spj else None,
                               spj_language=problem_data["spj"]["language"] if spj else None,
                               spj_version=rand_str(8) if spj else "",
                               visible=False,
                               languages=language_names,
                               created_by=creator,
                               difficulty=Difficulty.MID,
                               test_case_id=problem_data["test_case_id"])

    def post(self, request):
        """Parse the uploaded FPS file and import all problems in one transaction.

        Returns ``{"import_count": n}`` on success. If any problem fails
        validation nothing is committed and the test-case directories
        created so far are removed.

        Raises:
            APIError: A parsed problem does not pass ``FPSProblemSerializer``.
                NOTE(review): this assumes the base view turns APIError into
                an error response - confirm in utils.api.
        """
        form = UploadProblemForm(request.POST, request.FILES)
        if not form.is_valid():
            return self.error("Parse upload file error")
        file = form.cleaned_data["file"]

        with tempfile.NamedTemporaryFile("wb") as tf:
            for chunk in file.chunks(4096):
                tf.write(chunk)
            # FPSParser re-opens the file by name, so buffered data must be
            # flushed first or the parser may read a truncated file.
            tf.flush()
            problems = FPSParser(tf.name).parse()

        helper = FPSHelper()
        created_dirs = []
        try:
            with transaction.atomic():
                for _problem in problems:
                    test_case_id = rand_str()
                    test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
                    os.mkdir(test_case_dir)
                    created_dirs.append(test_case_dir)
                    helper.save_test_case(_problem, test_case_dir)
                    problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX)
                    s = FPSProblemSerializer(data=problem_data)
                    if not s.is_valid():
                        # Raise instead of returning: a plain return would
                        # leave atomic() normally and commit the problems
                        # created before this one.
                        raise APIError(f"Parse FPS file error: {s.errors}")
                    problem_data = s.data
                    problem_data["test_case_id"] = test_case_id
                    self._create_problem(problem_data, request.user)
        except Exception:
            # Database rows are rolled back by atomic(); the directories on
            # disk have to be removed explicitly.
            for test_case_dir in created_dirs:
                shutil.rmtree(test_case_dir, ignore_errors=True)
            raise
        return self.success({"import_count": len(problems)})