diff --git a/fps/parser.py b/fps/parser.py index bca71b79..95f19da5 100644 --- a/fps/parser.py +++ b/fps/parser.py @@ -1,26 +1,33 @@ #!/usr/bin/env python3 + import base64 import copy import random import string +import hashlib +import json import xml.etree.ElementTree as ET class FPSParser(object): - def __init__(self, fps_path): - self.fps_path = fps_path - - @property - def _root(self): - root = ET.ElementTree(file=self.fps_path).getroot() - version = root.attrib.get("version", "No Version") + def __init__(self, fps_path=None, string_data=None): + if fps_path: + self._etree = ET.parse(fps_path).getroot() + elif string_data: + self._ertree = ET.fromstring(string_data).getroot() + else: + raise ValueError("You must tell me the file path or directly give me the data for the file") + version = self._etree.attrib.get("version", "No Version") if version not in ["1.1", "1.2"]: raise ValueError("Unsupported version '" + version + "'") - return root + + @property + def etree(self): + return self._etree def parse(self): ret = [] - for node in self._root: + for node in self._etree: if node.tag == "item": ret.append(self._parse_one_problem(node)) return ret @@ -112,20 +119,50 @@ class FPSHelper(object): _problem[item] = _problem[item].replace(img["src"], os.path.join(base_url, file_name)) return _problem - def save_test_case(self, problem, base_dir, input_preprocessor=None, output_preprocessor=None): + # { + # "spj": false, + # "test_cases": { + # "1": { + # "stripped_output_md5": "84f244e41d3c8fd4bdb43ed0e1f7a067", + # "input_size": 12, + # "output_size": 7, + # "input_name": "1.in", + # "output_name": "1.out" + # } + # } + # } + def save_test_case(self, problem, base_dir): + spj = problem.get("spj", {}) + test_cases = {} for index, item in enumerate(problem["test_cases"]): - with open(os.path.join(base_dir, str(index + 1) + ".in"), "w", encoding="utf-8") as f: - if input_preprocessor: - input_content = input_preprocessor(item["input"]) - else: - input_content = item["input"] - f.write(input_content) - with open(os.path.join(base_dir, str(index + 1) + ".out"), "w", encoding="utf-8") as f: - if output_preprocessor: - output_content = output_preprocessor(item["output"]) - else: - output_content = item["output"] - f.write(output_content) + input_content = item.get("input") + output_content = item.get("output") + if input_content: + with open(os.path.join(base_dir, str(index + 1) + ".in"), "w", encoding="utf-8") as f: + f.write(input_content) + if output_content: + with open(os.path.join(base_dir, str(index + 1) + ".out"), "w", encoding="utf-8") as f: + f.write(output_content) + if spj: + one_info = { + "input_size": len(input_content), + "input_name": f"{index}.in" + } + else: + one_info = { + "input_size": len(input_content), + "input_name": f"{index}.in", + "output_size": len(output_content), + "output_name": f"{index}.out", + "stripped_output_md5": hashlib.md5(output_content.rstrip()).hexdigest() + } + test_cases[index] = one_info + info = { + "spj": True if spj else False, + "test_cases": test_cases + } + with open(os.path.join(base_dir, "info"), "w", encoding="utf-8") as f: + f.write(json.dumps(info, indent=4)) if __name__ == "__main__": diff --git a/problem/serializers.py b/problem/serializers.py index 051b710b..103c100d 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -1,7 +1,9 @@ from django import forms +from options.options import SysOptions from judge.languages import language_names, spj_language_names from utils.api import UsernameSerializer, serializers +from utils.constants import Difficulty from .models import Problem, ProblemRuleType, ProblemTag from .utils import parse_problem_template @@ -27,12 +29,6 @@ class CreateProblemCodeTemplateSerializer(serializers.Serializer): pass -class Difficulty(object): - LOW = "Low" - MID = "Mid" - HIGH = "High" - - class CreateOrEditProblemSerializer(serializers.Serializer): _id = serializers.CharField(max_length=32, allow_blank=True, allow_null=True) title = serializers.CharField(max_length=128) @@ -41,7 +37,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer): output_description = serializers.CharField() samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False) test_case_id = serializers.CharField(max_length=32) - test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=False) + test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=True) time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60) memory_limit = serializers.IntegerField(min_value=1, max_value=1024) languages = serializers.MultipleChoiceField(choices=language_names) @@ -52,7 +48,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer): spj_code = serializers.CharField(allow_blank=True, allow_null=True) spj_compile_ok = serializers.BooleanField(default=False) visible = serializers.BooleanField() - difficulty = serializers.ChoiceField(choices=[Difficulty.LOW, Difficulty.MID, Difficulty.HIGH]) + difficulty = serializers.ChoiceField(choices=Difficulty.choices()) tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False) hint = serializers.CharField(allow_blank=True, allow_null=True) source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True) @@ -128,41 +124,42 @@ class ContestProblemMakePublicSerializer(serializers.Serializer): class ExportProblemSerializer(serializers.ModelSerializer): + display_id = serializers.SerializerMethodField() description = serializers.SerializerMethodField() input_description = serializers.SerializerMethodField() output_description = serializers.SerializerMethodField() test_case_score = serializers.SerializerMethodField() hint = serializers.SerializerMethodField() - time_limit = serializers.SerializerMethodField() - memory_limit = serializers.SerializerMethodField() spj = serializers.SerializerMethodField() template = serializers.SerializerMethodField() + source = serializers.SerializerMethodField() + tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True) + + def get_display_id(self, obj): + return obj._id + + def _html_format_value(self, value): + return {"format": "html", "value": value} def get_description(self, obj): - return {"format": "html", "value": obj.description} + return self._html_format_value(obj.description) def get_input_description(self, obj): - return {"format": "html", "value": obj.input_description} + return self._html_format_value(obj.input_description) def get_output_description(self, obj): - return {"format": "html", "value": obj.output_description} + return self._html_format_value(obj.output_description) def get_hint(self, obj): - return {"format": "html", "value": obj.hint} + return self._html_format_value(obj.hint) def get_test_case_score(self, obj): - return obj.test_case_score if obj.rule_type == ProblemRuleType.OI else [] - - def get_time_limit(self, obj): - return {"unit": "ms", "value": obj.time_limit} - - def get_memory_limit(self, obj): - return {"unit": "MB", "value": obj.memory_limit} + return [{"score": item["score"], "input_name": item["input_name"]} + for item in obj.test_case_score] if obj.rule_type == ProblemRuleType.OI else None def get_spj(self, obj): - return {"enabled": obj.spj, - "code": obj.spj_code if obj.spj else None, - "language": obj.spj_language if obj.spj else None} + return {"code": obj.spj_code, + "language": obj.spj_language} if obj.spj else None def get_template(self, obj): ret = {} @@ -170,9 +167,12 @@ class ExportProblemSerializer(serializers.ModelSerializer): ret[k] = parse_problem_template(v) return ret + def get_source(self, obj): + return obj.source or f"{SysOptions.website_name} {SysOptions.website_base_url}" + class Meta: model = Problem - fields = ("_id", "title", "description", + fields = ("display_id", "title", "description", "tags", "input_description", "output_description", "test_case_score", "hint", "time_limit", "memory_limit", "samples", "template", "spj", "rule_type", "source", "template") @@ -182,3 +182,76 @@ class AddContestProblemSerializer(serializers.Serializer): contest_id = serializers.IntegerField() problem_id = serializers.IntegerField() display_id = serializers.CharField() + + +class ExportProblemRequestSerialzier(serializers.Serializer): + problem_id = serializers.ListField(child=serializers.IntegerField(), allow_empty=False) + + +class UploadProblemForm(forms.Form): + file = forms.FileField() + + +class FormatValueSerializer(serializers.Serializer): + format = serializers.ChoiceField(choices=["html", "markdown"]) + value = serializers.CharField(allow_blank=True) + + +class TestCaseScoreSerializer(serializers.Serializer): + score = serializers.IntegerField(min_value=1) + input_name = serializers.CharField(max_length=32) + + +class TemplateSerializer(serializers.Serializer): + prepend = serializers.CharField() + template = serializers.CharField() + append = serializers.CharField() + + +class SPJSerializer(serializers.Serializer): + code = serializers.CharField() + language = serializers.ChoiceField(choices=spj_language_names) + + +class AnswerSerializer(serializers.Serializer): + code = serializers.CharField() + language = serializers.ChoiceField(choices=language_names) + + +class ImportProblemSerializer(serializers.Serializer): + display_id = serializers.CharField(max_length=128) + title = serializers.CharField(max_length=128) + description = FormatValueSerializer() + input_description = FormatValueSerializer() + output_description = FormatValueSerializer() + hint = FormatValueSerializer() + test_case_score = serializers.ListField(child=TestCaseScoreSerializer(), allow_null=True) + time_limit = serializers.IntegerField(min_value=1, max_value=60000) + memory_limit = serializers.IntegerField(min_value=1, max_value=10240) + samples = serializers.ListField(child=CreateSampleSerializer()) + template = serializers.DictField(child=TemplateSerializer()) + spj = SPJSerializer(allow_null=True) + rule_type = serializers.ChoiceField(choices=ProblemRuleType.choices()) + source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True) + answers = serializers.ListField(child=AnswerSerializer()) + tags = serializers.ListField(child=serializers.CharField()) + + +class FPSProblemSerializer(serializers.Serializer): + class UnitSerializer(serializers.Serializer): + unit = serializers.ChoiceField(choices=["MB", "s", "ms"]) + value = serializers.IntegerField(min_value=1, max_value=60000) + + title = serializers.CharField(max_length=128) + description = serializers.CharField() + input = serializers.CharField() + output = serializers.CharField() + hint = serializers.CharField(allow_blank=True, allow_null=True) + time_limit = UnitSerializer() + memory_limit = UnitSerializer() + samples = serializers.ListField(child=CreateSampleSerializer()) + source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True) + spj = SPJSerializer(allow_null=True) + template = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True) + append = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True) + prepend = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True) diff --git a/problem/urls/admin.py b/problem/urls/admin.py index 8e16a8a3..e3a921fc 100644 --- a/problem/urls/admin.py +++ b/problem/urls/admin.py @@ -1,7 +1,8 @@ from django.conf.urls import url -from ..views.admin import ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView -from ..views.admin import CompileSPJAPI, AddContestProblemAPI +from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView, + CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI, + FPSProblemImport) urlpatterns = [ url(r"^test_case/?$", TestCaseAPI.as_view(), name="test_case_api"), @@ -10,4 +11,7 @@ urlpatterns = [ url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_admin_api"), url(r"^contest_problem/make_public/?$", MakeContestProblemPublicAPIView.as_view(), name="make_public_api"), url(r"^contest/add_problem_from_public/?$", AddContestProblemAPI.as_view(), name="add_contest_problem_from_public_api"), + url(r"^export_problem/?$", ExportProblemAPI.as_view(), name="export_problem_api"), + url(r"^import_problem/?$", ImportProblemAPI.as_view(), name="import_problem_api"), + url(r"^import_fps/?$", FPSProblemImport.as_view(), name="fps_problem_api"), ] diff --git a/problem/utils.py b/problem/utils.py index f8243099..9e29cd67 100644 --- a/problem/utils.py +++ b/problem/utils.py @@ -1,5 +1,17 @@ import re +TEMPLATE_BASE = """//PREPEND BEGIN +{} +//PREPEND END + +//TEMPLATE BEGIN +{} +//TEMPLATE END + +//APPEND BEGIN +{} +//APPEND END""" + def parse_problem_template(template_str): prepend = re.findall("//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str) @@ -8,3 +20,7 @@ def parse_problem_template(template_str): return {"prepend": prepend[0] if prepend else "", "template": template[0] if template else "", "append": append[0] if append else ""} + + +def build_problem_template(prepend, template, append): + return TEMPLATE_BASE.format(prepend, template, append) diff --git a/problem/views/admin.py b/problem/views/admin.py index 7b79ce26..38cc94a0 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -3,35 +3,91 @@ import json import os import shutil import zipfile +import tempfile from wsgiref.util import FileWrapper from django.conf import settings -from django.http import StreamingHttpResponse, HttpResponse +from django.http import StreamingHttpResponse, HttpResponse, FileResponse +from django.db import transaction from account.decorators import problem_permission_required, ensure_created_by from judge.dispatcher import SPJCompiler +from judge.languages import language_names from contest.models import Contest, ContestStatus -from submission.models import Submission -from utils.api import APIView, CSRFExemptAPIView, validate_serializer +from submission.models import Submission, JudgeStatus +from fps.parser import FPSHelper, FPSParser +from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError from utils.shortcuts import rand_str, natural_sort_key +from utils.tasks import delete_files +from utils.constants import Difficulty +from ..utils import TEMPLATE_BASE, build_problem_template from ..models import Problem, ProblemRuleType, ProblemTag from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer, CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer, ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer, - AddContestProblemSerializer) + AddContestProblemSerializer, ExportProblemSerializer, + ExportProblemRequestSerialzier, UploadProblemForm, ImportProblemSerializer, + FPSProblemSerializer) -class TestCaseAPI(CSRFExemptAPIView): - request_parsers = () +class TestCaseZipProcessor(object): + def process_zip(self, uploaded_zip_file, spj, dir=""): + try: + zip_file = zipfile.ZipFile(uploaded_zip_file, "r") + except zipfile.BadZipFile: + raise APIError("Bad zip file") + name_list = zip_file.namelist() + test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir) + if not test_case_list: + raise APIError("Empty file") - def filter_name_list(self, name_list, spj): + test_case_id = rand_str() + test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) + os.mkdir(test_case_dir) + + size_cache = {} + md5_cache = {} + + for item in test_case_list: + with open(os.path.join(test_case_dir, item), "wb") as f: + content = zip_file.read(f"{dir}{item}").replace(b"\r\n", b"\n") + size_cache[item] = len(content) + if item.endswith(".out"): + md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest() + f.write(content) + test_case_info = {"spj": spj, "test_cases": {}} + + info = [] + + if spj: + for index, item in enumerate(test_case_list): + data = {"input_name": item, "input_size": size_cache[item]} + info.append(data) + test_case_info["test_cases"][str(index + 1)] = data + else: + # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")] + test_case_list = zip(*[test_case_list[i::2] for i in range(2)]) + for index, item in enumerate(test_case_list): + data = {"stripped_output_md5": md5_cache[item[1]], + "input_size": size_cache[item[0]], + "output_size": size_cache[item[1]], + "input_name": item[0], + "output_name": item[1]} + info.append(data) + test_case_info["test_cases"][str(index + 1)] = data + + with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f: + f.write(json.dumps(test_case_info, indent=4)) + return info, test_case_id + + def filter_name_list(self, name_list, spj, dir=""): ret = [] prefix = 1 if spj: while True: - in_name = str(prefix) + ".in" - if in_name in name_list: + in_name = f"{prefix}.in" + if f"{dir}{in_name}" in name_list: ret.append(in_name) prefix += 1 continue @@ -39,9 +95,9 @@ class TestCaseAPI(CSRFExemptAPIView): return sorted(ret, key=natural_sort_key) else: while True: - in_name = str(prefix) + ".in" - out_name = str(prefix) + ".out" - if in_name in name_list and out_name in name_list: + in_name = f"{prefix}.in" + out_name = f"{prefix}.out" + if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list: ret.append(in_name) ret.append(out_name) prefix += 1 @@ -49,6 +105,10 @@ class TestCaseAPI(CSRFExemptAPIView): else: return sorted(ret, key=natural_sort_key) + +class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor): + request_parsers = () + def get(self, request): problem_id = request.GET.get("problem_id") if not problem_id: @@ -90,62 +150,13 @@ class TestCaseAPI(CSRFExemptAPIView): file = form.cleaned_data["file"] else: return self.error("Upload failed") - tmp_file = os.path.join("/tmp", rand_str() + ".zip") - with open(tmp_file, "wb") as f: + zip_file = f"/tmp/{rand_str()}.zip" + with open(zip_file, "wb") as f: for chunk in file: f.write(chunk) - try: - zip_file = zipfile.ZipFile(tmp_file) - except zipfile.BadZipFile: - return self.error("Bad zip file") - name_list = zip_file.namelist() - test_case_list = self.filter_name_list(name_list, spj=spj) - if not test_case_list: - return self.error("Empty file") - - test_case_id = rand_str() - test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) - os.mkdir(test_case_dir) - - size_cache = {} - md5_cache = {} - - for item in test_case_list: - with open(os.path.join(test_case_dir, item), "wb") as f: - content = zip_file.read(item).replace(b"\r\n", b"\n") - size_cache[item] = len(content) - if item.endswith(".out"): - md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest() - f.write(content) - test_case_info = {"spj": spj, "test_cases": {}} - - hint = None - diff = set(name_list).difference(set(test_case_list)) - if diff: - hint = ", ".join(diff) + " are ignored" - - ret = [] - - if spj: - for index, item in enumerate(test_case_list): - data = {"input_name": item, "input_size": size_cache[item]} - ret.append(data) - test_case_info["test_cases"][str(index + 1)] = data - else: - # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")] - test_case_list = zip(*[test_case_list[i::2] for i in range(2)]) - for index, item in enumerate(test_case_list): - data = {"stripped_output_md5": md5_cache[item[1]], - "input_size": size_cache[item[0]], - "output_size": size_cache[item[1]], - "input_name": item[0], - "output_name": item[1]} - ret.append(data) - test_case_info["test_cases"][str(index + 1)] = data - - with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f: - f.write(json.dumps(test_case_info, indent=4)) - return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj}) + info, test_case_id = self.process_zip(zip_file, spj=spj) + os.remove(zip_file) + return self.success({"id": test_case_id, "info": info, "spj": spj}) class CompileSPJAPI(APIView): @@ -466,3 +477,204 @@ class AddContestProblemAPI(APIView): problem.save() problem.tags.set(tags) return self.success() + + +class ExportProblemAPI(APIView): + def choose_answers(self, user, problem): + ret = [] + for item in problem.languages: + submission = Submission.objects.filter(problem=problem, + user_id=user.id, + language=item, + result=JudgeStatus.ACCEPTED).order_by("-create_time").first() + if submission: + ret.append({"language": submission.language, "code": submission.code}) + return ret + + def process_one_problem(self, zip_file, user, problem, index): + info = ExportProblemSerializer(problem).data + info["answers"] = self.choose_answers(user, problem=problem) + compression = zipfile.ZIP_DEFLATED + zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json", + data=json.dumps(info, indent=4), + compress_type=compression) + problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) + with open(os.path.join(problem_test_case_dir, "info")) as f: + info = json.load(f) + for k, v in info["test_cases"].items(): + zip_file.write(filename=os.path.join(problem_test_case_dir, v["input_name"]), + arcname=f"{index}/testcase/{v['input_name']}", + compress_type=compression) + if not info["spj"]: + zip_file.write(filename=os.path.join(problem_test_case_dir, v["output_name"]), + arcname=f"{index}/testcase/{v['output_name']}", + compress_type=compression) + + @validate_serializer(ExportProblemRequestSerialzier) + def get(self, request): + problems = Problem.objects.filter(id__in=request.data["problem_id"]) + for problem in problems: + if problem.contest: + ensure_created_by(problem.contest, request.user) + else: + ensure_created_by(problem, request.user) + path = f"/tmp/{rand_str()}.zip" + with zipfile.ZipFile(path, "w") as zip_file: + for index, problem in enumerate(problems): + self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1) + delete_files.apply_async((path,), countdown=300) + resp = FileResponse(open(path, "rb")) + resp["Content-Type"] = "application/zip" + resp["Content-Disposition"] = f"attachment;filename=problem-export.zip" + return resp + + +class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor): + request_parsers = () + + def post(self, request): + form = UploadProblemForm(request.POST, request.FILES) + if form.is_valid(): + file = form.cleaned_data["file"] + tmp_file = f"/tmp/{rand_str()}.zip" + with open(tmp_file, "wb") as f: + for chunk in file: + f.write(chunk) + else: + return self.error("Upload failed") + + count = 0 + with zipfile.ZipFile(tmp_file, "r") as zip_file: + name_list = zip_file.namelist() + for item in name_list: + if "/problem.json" in item: + count += 1 + with transaction.atomic(): + for i in range(1, count + 1): + with zip_file.open(f"{i}/problem.json") as f: + problem_info = json.load(f) + serializer = ImportProblemSerializer(data=problem_info) + if not serializer.is_valid(): + return self.error(f"Invalid problem format, error is {serializer.errors}") + else: + problem_info = serializer.data + for item in problem_info["template"].keys(): + if item not in language_names: + return self.error(f"Unsupported language {item}") + + problem_info["display_id"] = problem_info["display_id"][:24] + for k, v in problem_info["template"].items(): + problem_info["template"][k] = build_problem_template(v["prepend"], v["template"], + v["append"]) + + spj = problem_info["spj"] is not None + rule_type = problem_info["rule_type"] + test_case_score = problem_info["test_case_score"] + + # process test case + _, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{i}/testcase/") + + problem_obj = Problem.objects.create(_id=problem_info["display_id"], + title=problem_info["title"], + description=problem_info["description"]["value"], + input_description=problem_info["input_description"][ + "value"], + output_description=problem_info["output_description"][ + "value"], + hint=problem_info["hint"]["value"], + test_case_score=test_case_score if test_case_score else [], + time_limit=problem_info["time_limit"], + memory_limit=problem_info["memory_limit"], + samples=problem_info["samples"], + template=problem_info["template"], + rule_type=problem_info["rule_type"], + source=problem_info["source"], + spj=spj, + spj_code=problem_info["spj"]["code"] if spj else None, + spj_language=problem_info["spj"][ + "language"] if spj else None, + spj_version=rand_str(8) if spj else "", + languages=language_names, + created_by=request.user, + visible=False, + difficulty=Difficulty.MID, + total_score=sum(item["score"] for item in test_case_score) + if rule_type == ProblemRuleType.OI else 0, + test_case_id=test_case_id + ) + for tag_name in problem_info["tags"]: + tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name) + problem_obj.tags.add(tag_obj) + return self.success({"import_count": count}) + + +class FPSProblemImport(CSRFExemptAPIView): + request_parsers = () + + def _create_problem(self, problem_data, creator): + if problem_data["time_limit"]["unit"] == "ms": + time_limit = problem_data["time_limit"]["value"] + else: + time_limit = problem_data["time_limit"]["value"] * 1000 + template = {} + prepend = {} + append = {} + for t in problem_data["prepend"]: + prepend[t["language"]] = t["code"] + for t in problem_data["append"]: + append[t["language"]] = t["code"] + for t in problem_data["template"]: + our_lang = lang = t["language"] + if lang == "Python": + our_lang = "Python3" + template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, "")) + spj = problem_data["spj"] is not None + Problem.objects.create(_id=f"fps-{rand_str(4)}", + title=problem_data["title"], + description=problem_data["description"], + input_description=problem_data["input"], + output_description=problem_data["output"], + hint=problem_data["hint"], + test_case_score=[], + time_limit=time_limit, + memory_limit=problem_data["memory_limit"]["value"], + samples=problem_data["samples"], + template=template, + rule_type=ProblemRuleType.ACM, + source=problem_data.get("source", ""), + spj=spj, + spj_code=problem_data["spj"]["code"] if spj else None, + spj_language=problem_data["spj"]["language"] if spj else None, + spj_version=rand_str(8) if spj else "", + visible=False, + languages=language_names, + created_by=creator, + difficulty=Difficulty.MID, + test_case_id=problem_data["test_case_id"]) + + def post(self, request): + form = UploadProblemForm(request.POST, request.FILES) + if form.is_valid(): + file = form.cleaned_data["file"] + with tempfile.NamedTemporaryFile("wb") as tf: + for chunk in file.chunks(4096): + tf.file.write(chunk) + problems = FPSParser(tf.name).parse() + else: + return self.error("Parse upload file error") + + helper = FPSHelper() + with transaction.atomic(): + for _problem in problems: + test_case_id = rand_str() + test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) + os.mkdir(test_case_dir) + helper.save_test_case(_problem, test_case_dir) + problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX) + s = FPSProblemSerializer(data=problem_data) + if not s.is_valid(): + return self.error(f"Parse FPS file error: {s.errors}") + problem_data = s.data + problem_data["test_case_id"] = test_case_id + self._create_problem(problem_data, request.user) + return self.success({"import_count": len(problems)}) diff --git a/utils/constants.py b/utils/constants.py index 390d5685..50068f19 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -26,3 +26,9 @@ class CacheKey: contest_rank_cache = "contest_rank_cache" website_config = "website_config" option = "option" + + +class Difficulty(Choices): + LOW = "Low" + MID = "Mid" + HIGH = "High" diff --git a/utils/migrate_data.py b/utils/migrate_data.py index 4cb70f07..b9b9bd30 100644 --- a/utils/migrate_data.py +++ b/utils/migrate_data.py @@ -5,6 +5,7 @@ import re import json import django import hashlib +from json.decoder import JSONDecodeError sys.path.append("../") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") @@ -59,8 +60,8 @@ def set_problem_display_id_prefix(): def get_stripped_output_md5(test_case_id, output_name): output_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, output_name) - with open(output_path, "r") as f: - return hashlib.md5(f.read().encode("utf-8").rstrip()).hexdigest() + with open(output_path, 'r') as f: + return hashlib.md5(f.read().rstrip().encode('utf-8')).hexdigest() def get_test_case_score(test_case_id): @@ -190,8 +191,12 @@ if __name__ == "__main__": print("Data file does not exist") exit(1) - with open(data_path, "r") as data_file: - old_data = json.load(data_file) + try: + with open(data_path, "r") as data_file: + old_data = json.load(data_file) + except JSONDecodeError: + print("Data file format error, ensure it's a valid json file!") + exit(1) print("Read old data successfully.\n") for obj in old_data: diff --git a/utils/tasks.py b/utils/tasks.py new file mode 100644 index 00000000..442b0bc4 --- /dev/null +++ b/utils/tasks.py @@ -0,0 +1,11 @@ +import os +from celery import shared_task + + +@shared_task +def delete_files(*args): + for item in args: + try: + os.remove(item) + except Exception: + pass