# coding=utf-8
from __future__ import unicode_literals

import json
import os
from multiprocessing import Pool

import _judger
import psutil

from config import TEST_CASE_DIR, JUDGER_RUN_LOG_PATH, LOW_PRIVILEDGE_GID, LOW_PRIVILEDGE_UID
from exception import JudgeClientError


def _run(instance, test_case_file_id):
    """Judge one test case on *instance*; this is the task submitted to the pool.

    A plain module-level function is handed to ``Pool.apply_async`` together
    with the client object (made picklable by ``JudgeClient.__getstate__``)
    instead of a bound method, which Python 2's pickle cannot serialise.
    """
    return instance._judge_one(test_case_file_id)


class JudgeClient(object):
    """Run a compiled submission against every test case of a problem.

    Each test case is executed through ``_judger.run`` inside a
    ``multiprocessing.Pool`` worker, one worker per CPU reported by psutil.
    """

    def __init__(self, run_config, exe_path, max_cpu_time, max_memory, test_case_id, submission_dir):
        """
        :param run_config: language run configuration. The keys read here are
            ``command`` (a format string using ``{exe_path}``, ``{exe_dir}``
            and ``{max_memory}``), ``max_process_number`` and ``seccomp_rule``.
        :param exe_path: path of the compiled executable.
        :param max_cpu_time: CPU time limit forwarded to ``_judger.run``
            (units are whatever ``_judger`` expects -- presumably ms, confirm).
        :param max_memory: memory limit forwarded to ``_judger.run``; it is
            divided by 1024 for the ``{max_memory}`` command placeholder
            (presumably bytes -> KB, confirm against the run configs).
        :param test_case_id: name of the directory under ``TEST_CASE_DIR``.
        :param submission_dir: directory where ``<n>.out`` files are written.
        :raises JudgeClientError: if the test case ``info`` file is missing or
            is not valid JSON.
        """
        self._run_config = run_config
        self._exe_path = exe_path
        self._max_cpu_time = max_cpu_time
        self._max_memory = max_memory
        # Wall-clock limit is fixed at three times the CPU time limit.
        self._max_real_time = self._max_cpu_time * 3
        self._test_case_id = test_case_id
        self._test_case_dir = os.path.join(TEST_CASE_DIR, test_case_id)
        self._submission_dir = submission_dir
        # Load and validate the test case metadata *before* spawning worker
        # processes, so a missing/bad test case does not leave a pool of live
        # workers behind that nobody will ever close.
        self._test_case_info = self._load_test_case_info()
        self._pool = Pool(processes=psutil.cpu_count())

    def _load_test_case_info(self):
        """Read and decode the JSON ``info`` file of the test case directory.

        :returns: the decoded JSON document (``run`` reads its
            ``test_case_number`` key).
        :raises JudgeClientError: "Test case not found" when the file cannot be
            opened or read, "Bad test case config" when it is not valid JSON.
        """
        try:
            with open(os.path.join(self._test_case_dir, "info")) as f:
                return json.load(f)
        except IOError:
            raise JudgeClientError("Test case not found")
        except ValueError:
            raise JudgeClientError("Bad test case config")

    def _seccomp_rule_path(self, rule_name):
        """Return the UTF-8 encoded path of the seccomp rule library for *rule_name*.

        Returns ``None`` when *rule_name* is empty or ``None`` (presumably
        meaning that ``_judger`` applies no seccomp rule -- confirm).
        """
        if not rule_name:
            return None
        return "/usr/lib/judger/librule_{rule_name}.so".format(rule_name=rule_name).encode("utf-8")

    def _judge_one(self, test_case_file_id):
        """Run the executable on test case number *test_case_file_id*.

        stdin is read from ``<test_case_dir>/<id>.in``; stdout and stderr are
        both written to ``<submission_dir>/<id>.out``.

        :returns: the result object from ``_judger.run`` with a ``test_case``
            key set to *test_case_file_id*.
        """
        in_file = os.path.join(self._test_case_dir, str(test_case_file_id) + ".in")
        out_file = os.path.join(self._submission_dir, str(test_case_file_id) + ".out")

        # Floor division keeps the placeholder an integer on Python 3 as well
        # (``/`` would render e.g. "262144.0" into the command line).
        # NOTE: split(" ") yields empty arguments if the template contains
        # consecutive spaces.
        command = self._run_config["command"].format(
            exe_path=self._exe_path,
            exe_dir=os.path.dirname(self._exe_path),
            max_memory=self._max_memory // 1024,
        ).split(" ")

        run_result = _judger.run(max_cpu_time=self._max_cpu_time,
                                 max_real_time=self._max_real_time,
                                 max_memory=self._max_memory,
                                 max_output_size=1024 * 1024 * 1024,
                                 max_process_number=self._run_config["max_process_number"],
                                 exe_path=command[0].encode("utf-8"),
                                 input_path=in_file,
                                 output_path=out_file,
                                 error_path=out_file,
                                 args=[item.encode("utf-8") for item in command[1:]],
                                 # Default to "" so an unset PATH does not raise TypeError.
                                 env=[("PATH=" + os.getenv("PATH", "")).encode("utf-8")],
                                 log_path=JUDGER_RUN_LOG_PATH,
                                 seccomp_rule_so_path=self._seccomp_rule_path(self._run_config["seccomp_rule"]),
                                 uid=LOW_PRIVILEDGE_UID,
                                 gid=LOW_PRIVILEDGE_GID)
        run_result["test_case"] = test_case_file_id
        return run_result

    def run(self):
        """Judge all test cases in parallel and return their results.

        Test cases are numbered ``1 .. test_case_number`` (from the ``info``
        file). The pool is closed and joined here, so ``run`` can only be
        called once per instance.

        :returns: list of per-test-case results, in test case order.
        :raises: any exception raised inside a worker is re-raised by ``get()``.
        """
        test_case_number = self._test_case_info["test_case_number"]
        pending = [self._pool.apply_async(_run, (self, test_case_file_id))
                   for test_case_file_id in range(1, test_case_number + 1)]
        self._pool.close()
        self._pool.join()
        # An exception raised in a worker is re-raised here when get() is called:
        # http://stackoverflow.com/questions/22094852/how-to-catch-exceptions-in-workers-in-multiprocessing
        return [item.get() for item in pending]

    def __getstate__(self):
        """Pickle support: drop the pool before ``self`` is sent to a worker.

        ``Pool`` objects cannot be pickled, see
        http://stackoverflow.com/questions/25382455/python-notimplementederror-pool-objects-cannot-be-passed-between-processes
        """
        self_dict = self.__dict__.copy()
        del self_dict["_pool"]
        return self_dict