修正dispatcher,使用redis存任务队列,修正submission的post,修改部分settings

This commit is contained in:
zemal 2017-05-10 17:20:52 +08:00
parent 4733eecef9
commit 08bd591bfb
13 changed files with 133 additions and 107 deletions

View File

@ -41,7 +41,8 @@ class JudgeServer(models.Model):
@property
def status(self):
if (timezone.now() - self.last_heartbeat).total_seconds() > 5:
# 增加一秒延时,提高对网络环境的适应性
if (timezone.now() - self.last_heartbeat).total_seconds() > 6:
return "abnormal"
return "normal"

View File

@ -1,9 +1,11 @@
import hashlib
from django.utils import timezone
from django_redis import get_redis_connection
from account.decorators import super_admin_required
from judge.languages import languages, spj_languages
from judge.dispatcher import process_pending_task
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import rand_str
@ -126,6 +128,10 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
service_url=service_url,
last_heartbeat=timezone.now(),
)
# 新server上线 处理队列中的防止没有新的提交而导致一直waiting
conn = get_redis_connection("JudgeQueue")
process_pending_task(conn)
return self.success()

View File

@ -9,23 +9,32 @@ from django.db.models import F
from django_redis import get_redis_connection
from judge.languages import languages
from account.models import User, UserProfile
from account.models import User
from conf.models import JudgeServer, JudgeServerToken
from problem.models import Problem, ProblemRuleType
from submission.models import JudgeStatus
from submission.models import JudgeStatus, Submission
logger = logging.getLogger(__name__)
WAITING_QUEUE = "waiting_queue"
# 继续处理在队列中的问题
def process_pending_task(redis_conn):
if redis_conn.llen(WAITING_QUEUE):
# 防止循环引入
from submission.tasks import judge_task
data = json.loads(redis_conn.rpop(WAITING_QUEUE))
judge_task.delay(**data)
class JudgeDispatcher(object):
def __init__(self, submission_obj, problem_obj):
def __init__(self, submission_id, problem_id):
token = JudgeServerToken.objects.first().token
self.token = hashlib.sha256(token.encode("utf-8")).hexdigest()
self.redis_conn = get_redis_connection("JudgeQueue")
self.submission_obj = submission_obj
self.problem_obj = problem_obj
self.submission_obj = Submission.objects.get(pk=submission_id)
self.problem_obj = Problem.objects.get(pk=problem_id)
def _request(self, url, data=None):
kwargs = {"headers": {"X-Judge-Server-Token": self.token,
@ -41,10 +50,10 @@ class JudgeDispatcher(object):
def choose_judge_server():
with transaction.atomic():
# TODO: use more reasonable way
servers = JudgeServer.objects.select_for_update().filter(
status="normal").order_by("task_number")
if servers.exists():
server = servers.first()
servers = JudgeServer.objects.select_for_update().all().order_by('task_number')
servers = [s for s in servers if s.status == "normal"]
if servers:
server = servers[0]
server.used_instance_number = F("task_number") + 1
server.save()
return server
@ -60,28 +69,31 @@ class JudgeDispatcher(object):
def judge(self, output=False):
server = self.choose_judge_server()
if not server:
self.redis_conn.lpush(WAITING_QUEUE, self.submission_obj.id)
data = {'submission_id': self.submission_obj.id, 'problem_id': self.problem_obj.id}
self.redis_conn.lpush(WAITING_QUEUE, json.dumps(data))
return
language = list(filter(lambda item: self.submission_obj.language == item['name'], languages))[0]
data = {"language_config": language['config'],
"src": self.submission_obj.code,
"max_cpu_time": self.problem_obj.time_limit,
"max_memory": self.problem_obj.memory_limit,
"test_case_id": self.problem_obj.test_case_id,
"output": output}
data = {
"language_config": language['config'],
"src": self.submission_obj.code,
"max_cpu_time": self.problem_obj.time_limit,
"max_memory": 1024 * 1024 * self.problem_obj.memory_limit,
"test_case_id": self.problem_obj.test_case_id,
"output": output
}
# TODO: try catch
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
self.submission_obj.info = resp
if resp['err']:
self.submission_obj.result = JudgeStatus.COMPILE_ERROR
else:
error_test_case = list(filter(lambda case: case['result'] != 0, resp))
error_test_case = list(filter(lambda case: case['result'] != 0, resp['data']))
# 多个测试点全部正确AC否则ACM模式下取第一个测试点状态
if not error_test_case:
self.submission_obj.result = JudgeStatus.ACCEPTED
elif self.problem_obj.rule_tyle == ProblemRuleType.ACM:
self.submission_obj.result = error_test_case[0].result
elif self.problem_obj.rule_type == ProblemRuleType.ACM:
self.submission_obj.result = error_test_case[0]['result']
else:
self.submission_obj.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission_obj.save()
@ -92,37 +104,36 @@ class JudgeDispatcher(object):
pass
else:
self.update_problem_status()
# 取redis中等待中的提交
if self.redis_conn.llen(WAITING_QUEUE):
pass
process_pending_task(self.redis_conn)
def compile_spj(self, service_url, src, spj_version, spj_compile_config, test_case_id):
data = {"src": src, "spj_version": spj_version,
"spj_compile_config": spj_compile_config, "test_case_id": test_case_id}
return self._request(service_url + "/compile_spj", data=data)
"spj_compile_config": spj_compile_config,
"test_case_id": test_case_id}
return self._request(urljoin(service_url, "compile_spj"), data=data)
def update_problem_status(self):
with transaction.atomic():
problem = Problem.objects.select_for_update().get(id=self.problem_obj.problem_id)
# 更新普通题目的计数器
problem.add_submission_number()
# 更新用户做题状态
problem = Problem.objects.select_for_update().get(id=self.problem_obj.id)
user = User.objects.select_for_update().get(id=self.submission_obj.user_id)
problems_status = UserProfile.objects.get(user=user).problem_status
# 更新提交计数器
problem.add_submission_number()
user_profile = user.userprofile
user_profile.add_submission_number()
if self.submission_obj.result == JudgeStatus.ACCEPTED:
problem.add_ac_number()
problems_status = user_profile.problems_status
if "problems" not in problems_status:
problems_status["problems"] = {}
# 增加用户提交计数器
user.userprofile.add_submission_number()
# 之前状态不是ac, 现在是ac了 需要更新用户ac题目数量计数器,这里需要判重
if problems_status["problems"].get(str(problem.id), JudgeStatus.WRONG_ANSWER) != JudgeStatus.ACCEPTED:
if self.submission_obj.result == JudgeStatus.ACCEPTED:
user.userprofile.add_accepted_problem_number()
user_profile.add_accepted_problem_number()
problems_status["problems"][str(problem.id)] = JudgeStatus.ACCEPTED
else:
problems_status["problems"][str(problem.id)] = JudgeStatus.WRONG_ANSWER
user.problems_status = problems_status
user.save(update_fields=["problems_status"])
user_profile.problems_status = problems_status
user_profile.save(update_fields=["problems_status"])

View File

@ -0,0 +1,6 @@
from __future__ import absolute_import, unicode_literals
# Django starts so that shared_task will use this app.
from .celery import app as celery_app
__all__ = ['celery_app']

18
oj/celery.py Normal file
View File

@ -0,0 +1,18 @@
from __future__ import absolute_import, unicode_literals
import os
from celery import Celery
from django.conf import settings
# set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'oj.settings')
app = Celery('oj')
# Using a string here means the worker will not have to
# pickle the object when using Windows.
app.config_from_object('django.conf:settings')
# load task modules from all registered Django app configs.
app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
# app.autodiscover_tasks()

View File

@ -1,19 +0,0 @@
class DBRouter(object):
def db_for_read(self, model, **hints):
if model._meta.app_label == "submission":
return "submission"
return "default"
def db_for_write(self, model, **hints):
if model._meta.app_label == "submission":
return "submission"
return "default"
def allow_relation(self, obj1, obj2, **hints):
return True
def allow_migrate(self, db, app_label, model=None, **hints):
if app_label == "submission":
return db == app_label
else:
return db == "default"

View File

@ -34,16 +34,11 @@ CACHES = {
}
}
REDIS_CACHE = {
"host": "127.0.0.1",
"port": 6379,
"db": 1
}
# For celery
REDIS_QUEUE = {
"host": "127.0.0.1",
"port": 6379,
"db": 2
"db": 4
}
DEBUG = True

View File

@ -164,8 +164,6 @@ BROKER_URL = 'redis://%s:%s/%s' % (REDIS_QUEUE["host"], str(REDIS_QUEUE["port"])
CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json"
DATABASE_ROUTERS = ['oj.db_router.DBRouter']
IMAGE_UPLOAD_DIR = os.path.join(BASE_DIR, 'upload/')
# 用于限制用户恶意提交大量代码

View File

@ -64,11 +64,11 @@ class AbstractProblem(models.Model):
def add_submission_number(self):
self.total_submit_number = models.F("total_submit_number") + 1
self.save()
self.save(update_fields=['total_submit_number'])
def add_ac_number(self):
self.total_accepted_number = models.F("total_accepted_number") + 1
self.save()
self.save(update_fields=['total_accepted_number'])
class Problem(AbstractProblem):

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-05-09 12:03
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('submission', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='submission',
name='code',
field=models.TextField(),
),
]

View File

@ -1,7 +1,6 @@
from django.db import models
from jsonfield import JSONField
from utils.models import RichTextField
from utils.shortcuts import rand_str
@ -25,7 +24,7 @@ class Submission(models.Model):
problem_id = models.IntegerField(db_index=True)
created_time = models.DateTimeField(auto_now_add=True)
user_id = models.IntegerField(db_index=True)
code = RichTextField()
code = models.TextField()
result = models.IntegerField(default=JudgeStatus.PENDING)
# 判题结果的详细信息
info = JSONField(default={})

View File

@ -1,7 +1,8 @@
from __future__ import absolute_import, unicode_literals
from celery import shared_task
from judge.tasks import JudgeDispatcher
from judge.dispatcher import JudgeDispatcher
@shared_task
def _judge(submission_obj, problem_obj):
return JudgeDispatcher(submission_obj, problem_obj).judge()
def judge_task(submission_id, problem_id):
JudgeDispatcher(submission_id, problem_id).judge()

View File

@ -4,53 +4,43 @@ from django_redis import get_redis_connection
from account.decorators import login_required
from account.models import AdminType, User
from problem.models import Problem
from submission.tasks import judge_task
from utils.api import APIView, validate_serializer
from utils.shortcuts import build_query_string
from utils.throttling import TokenBucket, BucketController
from ..models import Submission
from ..serializers import CreateSubmissionSerializer
from ..tasks import _judge
def _submit_code(response, user, problem_id, language, code):
controller = BucketController(user_id=user.id,
redis_conn=get_redis_connection("Throttling"),
default_capacity=30)
bucket = TokenBucket(fill_rate=10,
capacity=20,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
if bucket.consume():
controller.last_capacity -= 1
else:
return response.error("Please wait %d seconds" % int(bucket.expected_time() + 1))
try:
problem = Problem.objects.get(id=problem_id)
except Problem.DoesNotExist:
return response.error("Problem not exist")
submission = Submission.objects.create(user_id=user.id,
language=language,
code=code,
problem_id=problem.id)
try:
_judge.delay(submission, problem)
except Exception:
return response.error("Failed")
return response.success({"submission_id": submission.id})
class SubmissionAPI(APIView):
@validate_serializer(CreateSubmissionSerializer)
# TODO: login
# @login_required
def post(self, request):
controller = BucketController(user_id=request.user.id,
redis_conn=get_redis_connection("Throttling"),
default_capacity=30)
bucket = TokenBucket(fill_rate=10, capacity=20,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
if bucket.consume():
controller.last_capacity -= 1
else:
return self.error("Please wait %d seconds" % int(bucket.expected_time() + 1))
data = request.data
return _submit_code(self, request.user, data["problem_id"], data["language"], data["code"])
try:
problem = Problem.objects.get(id=data['problem_id'])
except Problem.DoesNotExist:
return self.error("Problem not exist")
# TODO: user_id
submission = Submission.objects.create(user_id=1,
language=data['language'],
code=data['code'],
problem_id=problem.id)
judge_task.delay(submission.id, problem.id)
# JudgeDispatcher(submission.id, problem.id).judge()
return self.success({"submission_id": submission.id})
@login_required
def get(self, request):