解决 dispatcher 中部分数据库锁的问题

This commit is contained in:
virusdefender 2019-03-12 14:57:44 +08:00
parent c192304fd8
commit 1adfd35615
3 changed files with 62 additions and 43 deletions

4
.dockerignore Normal file
View File

@ -0,0 +1,4 @@
venv
.idea
.git
.DS_Store

View File

@ -4,7 +4,7 @@ import logging
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from django.db import transaction from django.db import transaction, IntegrityError
from django.db.models import F from django.db.models import F
from account.models import User from account.models import User
@ -29,6 +29,27 @@ def process_pending_task():
judge_task.send(**data) judge_task.send(**data)
class ChooseJudgeServer:
    """Context manager that reserves one judge server for the ``with`` body.

    Entering claims a slot on a server by incrementing its ``task_number``;
    leaving gives the slot back by decrementing it, whether or not the body
    raised. The value bound by ``with ... as server`` is the claimed
    ``JudgeServer``, or ``None`` when no server qualifies, so callers must
    check it before use.
    """

    def __init__(self):
        # Server claimed by __enter__; stays None when nothing was claimed,
        # which tells __exit__ there is nothing to release.
        self.server = None

    def __enter__(self) -> "JudgeServer | None":
        """Claim a server and return it, or return ``None`` if none is free.

        Candidates are the enabled servers, ordered by ascending
        ``task_number``, whose ``status`` is ``"normal"``. The first one whose
        ``task_number`` does not exceed twice its ``cpu_core`` is claimed.
        """
        # The row locks taken by select_for_update() are held only for this
        # short transaction, not for the duration of the ``with`` body.
        with transaction.atomic():
            servers = (
                JudgeServer.objects.select_for_update()
                .filter(is_disabled=False)
                .order_by("task_number")
            )
            servers = [s for s in servers if s.status == "normal"]
            for server in servers:
                if server.task_number <= server.cpu_core * 2:
                    # F() performs the increment in the database instead of
                    # writing back the value read into Python.
                    server.task_number = F("task_number") + 1
                    server.save()
                    self.server = server
                    return server
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the claimed slot, if any; exceptions are not suppressed."""
        if self.server:
            # Single UPDATE with F(): no need to re-read the row first.
            JudgeServer.objects.filter(id=self.server.id).update(task_number=F("task_number") - 1)
class DispatcherBase(object): class DispatcherBase(object):
def __init__(self): def __init__(self):
self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest()
@ -42,25 +63,6 @@ class DispatcherBase(object):
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
# Claim a judge server: inside one transaction, lock the enabled servers
# (select_for_update) ordered by ascending task_number, keep those whose
# status is "normal", and take the first one whose task_number does not
# exceed twice its cpu_core. The chosen server's task_number is incremented
# and the server is returned; when no server qualifies the function falls
# off the end and returns None implicitly.
@staticmethod
def choose_judge_server():
with transaction.atomic():
servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
for server in servers:
if server.task_number <= server.cpu_core * 2:
# F() makes the increment happen in the database rather than on the value read into Python.
server.task_number = F("task_number") + 1
server.save()
return server
# Give back a slot on the judge server with the given id by decrementing its task_number.
# JudgeServer.objects.get() raises if no row has that id.
@staticmethod
def release_judge_server(judge_server_id):
with transaction.atomic():
# Use an atomic operation; because the judging run happens between claiming and releasing, the row has to be queried again here.
server = JudgeServer.objects.get(id=judge_server_id)
server.task_number = F("task_number") - 1
server.save()
class SPJCompiler(DispatcherBase): class SPJCompiler(DispatcherBase):
def __init__(self, spj_code, spj_version, spj_language): def __init__(self, spj_code, spj_version, spj_language):
@ -74,13 +76,14 @@ class SPJCompiler(DispatcherBase):
} }
# NOTE(review): each row below appears to hold two renderings of this method side by side
# (presumably the pre-commit line on the left and the post-commit line on the right) — confirm against the repository.
# Both renderings post self.data to the "compile_spj" endpoint of a judge server, return
# "No available judge_server" when no server was obtained, and return result["data"] when result["err"] is truthy.
# The right-hand rendering obtains the server through the ChooseJudgeServer context manager instead of
# choose_judge_server()/release_judge_server(), and additionally returns "Failed to call judge server"
# when the request yields a falsy result.
def compile_spj(self): def compile_spj(self):
server = self.choose_judge_server() with ChooseJudgeServer() as server:
if not server: if not server:
return "No available judge_server" return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data) result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
self.release_judge_server(server.id) if not result:
if result["err"]: return "Failed to call judge server"
return result["data"] if result["err"]:
return result["data"]
class JudgeDispatcher(DispatcherBase): class JudgeDispatcher(DispatcherBase):
@ -118,12 +121,6 @@ class JudgeDispatcher(DispatcherBase):
self.submission.statistic_info["score"] = score self.submission.statistic_info["score"] = score
def judge(self): def judge(self):
server = self.choose_judge_server()
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
language = self.submission.language language = self.submission.language
sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0] sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0]
spj_config = {} spj_config = {}
@ -152,9 +149,18 @@ class JudgeDispatcher(DispatcherBase):
"spj_src": self.problem.spj_code "spj_src": self.problem.spj_code
} }
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING) with ChooseJudgeServer() as server:
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if not resp:
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR)
return
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if resp["err"]: if resp["err"]:
self.submission.result = JudgeStatus.COMPILE_ERROR self.submission.result = JudgeStatus.COMPILE_ERROR
self.submission.statistic_info["err_info"] = resp["data"] self.submission.statistic_info["err_info"] = resp["data"]
@ -173,7 +179,6 @@ class JudgeDispatcher(DispatcherBase):
else: else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save() self.submission.save()
self.release_judge_server(server.id)
if self.contest_id: if self.contest_id:
if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \ if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \
@ -322,15 +327,25 @@ class JudgeDispatcher(DispatcherBase):
def update_contest_rank(self):
    """Update the contest rank row of the submitting user.

    Deletes the cached rank list for the contest when the contest uses the
    OI rule or has real-time ranking enabled, then fetches (creating it if
    missing) the rank row for ``(submission.user_id, contest)`` and passes it
    to the rule-specific updater: ``_update_acm_contest_rank`` for ACM
    contests, ``_update_oi_contest_rank`` otherwise.
    """
    if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
        cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")

    if self.contest.rule_type == ContestRuleType.ACM:
        model = ACMContestRank
        func = self._update_acm_contest_rank
    else:
        model = OIContestRank
        func = self._update_oi_contest_rank

    lookup = {"user_id": self.submission.user_id, "contest": self.contest}
    with transaction.atomic():
        try:
            # TODO: add a unique index on (user, contest); without it the
            # IntegrityError fallback below can never trigger.
            # NOTE: func itself is not safe against concurrent updates either.
            rank = model.objects.get(**lookup)
        except model.DoesNotExist:
            # Each model has its own DoesNotExist class, so the exception of
            # the model actually queried must be caught here; catching only
            # ACMContestRank.DoesNotExist misses the OI case.
            try:
                # Nested atomic() creates a savepoint. Without it, a failed
                # INSERT would mark the outer transaction as broken and the
                # fallback get() below would raise TransactionManagementError.
                with transaction.atomic():
                    rank = model.objects.create(**lookup)
            except IntegrityError:
                # Another worker created the row first; use theirs.
                rank = model.objects.get(**lookup)
        func(rank)
def _update_acm_contest_rank(self, rank): def _update_acm_contest_rank(self, rank):
info = rank.submission_info.get(str(self.submission.problem_id)) info = rank.submission_info.get(str(self.submission.problem_id))