bugfix
This commit is contained in:
parent
3cb0dc22f2
commit
592188db0d
@ -14,28 +14,54 @@ class LongTasks:
|
|||||||
self.redis_url = redis_url
|
self.redis_url = redis_url
|
||||||
self.worker_cnt = worker_cnt
|
self.worker_cnt = worker_cnt
|
||||||
self.taskname = taskname
|
self.taskname = taskname
|
||||||
self.max_age_secords = max_age_hours * 3600
|
self.max_age_seconds = max_age_hours * 3600
|
||||||
self.task_queue = f'{taskname}_pending'
|
self.task_queue = f'{taskname}_pending'
|
||||||
self.processing_queue = f'{taskname}_processing'
|
self.processing_queue = f'{taskname}_processing'
|
||||||
self.stuck_seconds = stuck_seconds
|
self.stuck_seconds = stuck_seconds
|
||||||
|
self.redis = None
|
||||||
|
|
||||||
|
async def repush_pending_task(self):
|
||||||
|
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
|
||||||
|
task = await self.redis.hgetall(key)
|
||||||
|
if task['status'] == 'PENDING':
|
||||||
|
jstr = json.dumps({
|
||||||
|
"task_id": task['task_id'],
|
||||||
|
"payload": json.loads(task['payload'])
|
||||||
|
})
|
||||||
|
await self.redis.rpush(self.task_queue, jstr)
|
||||||
|
debug(f'{jstr=}, {self.task_queue=}')
|
||||||
|
|
||||||
async def cleanup_expired_tasks(self):
|
async def cleanup_expired_tasks(self):
|
||||||
"""清理超过 max_age_hours 的任务"""
|
"""清理超过 max_age_hours 的任务"""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
debug('cleanup_expired_tasks() called ...')
|
||||||
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
|
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
|
||||||
|
taskid = key.split(':')[-1]
|
||||||
task = await self.redis.hgetall(key)
|
task = await self.redis.hgetall(key)
|
||||||
if not task:
|
if not task:
|
||||||
|
debug(f'{key} task not found')
|
||||||
|
await self.redis.delete(key)
|
||||||
|
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
created_at = task.get("created_at")
|
created_at = task.get("created_at")
|
||||||
if not created_at:
|
if not created_at:
|
||||||
|
await self.redis.delete(key)
|
||||||
|
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
|
||||||
|
await self.delete_redis_task(taskid)
|
||||||
|
debug(f'{key}, {task} no created_at key')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if created + self.max_age_seconds < now:
|
created_at = float(created_at)
|
||||||
debug(f"🧹 删除过期任务: {key}")
|
if created_at + self.max_age_seconds < now:
|
||||||
|
debug(f"🧹 删除过期任务: {key}, {task}")
|
||||||
await self.redis.delete(key)
|
await self.redis.delete(key)
|
||||||
await self.redis.lrem("task_queue", 0, key) # 从任务队列中移除(可选)
|
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
|
||||||
|
status = task.get('status')
|
||||||
|
if status not in ['SUCCEEDED', 'FAILED', 'RUNNING', 'PENDING']:
|
||||||
|
debug(f"🧹 删除任务: {key}, {task}")
|
||||||
|
await self.redis.delete(key)
|
||||||
|
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
|
||||||
|
|
||||||
async def process_task(self, payload:dict, workid:int=None):
|
async def process_task(self, payload:dict, workid:int=None):
|
||||||
sec = randint(0,5)
|
sec = randint(0,5)
|
||||||
@ -45,10 +71,14 @@ class LongTasks:
|
|||||||
'result': 'OK'
|
'result': 'OK'
|
||||||
}
|
}
|
||||||
|
|
||||||
async def run(self):
|
async def start_redis(self):
|
||||||
schedule_interval(3600, self.cleanup_expired_tasks)
|
|
||||||
self.redis = await aioredis.from_url(self.redis_url, decode_responses=True)
|
self.redis = await aioredis.from_url(self.redis_url, decode_responses=True)
|
||||||
await self.recover_stuck_tasks()
|
|
||||||
|
async def run(self):
|
||||||
|
await self.start_redis()
|
||||||
|
await self.cleanup_expired_tasks()
|
||||||
|
schedule_interval(3600, self.cleanup_expired_tasks)
|
||||||
|
schedule_interval(300, self.recover_stuck_tasks)
|
||||||
workers = [asyncio.create_task(self.worker_loop(i)) for i in range(self.worker_cnt)]
|
workers = [asyncio.create_task(self.worker_loop(i)) for i in range(self.worker_cnt)]
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*workers)
|
await asyncio.gather(*workers)
|
||||||
@ -69,23 +99,16 @@ class LongTasks:
|
|||||||
如果某任务的 task:{id}.started_at 距现在 > self.stuck_seconds,则认为卡住并重新入队或标记为 failed。
|
如果某任务的 task:{id}.started_at 距现在 > self.stuck_seconds,则认为卡住并重新入队或标记为 failed。
|
||||||
"""
|
"""
|
||||||
# 读取整个 processing_queue(注意:当队列非常大时需改成分页)
|
# 读取整个 processing_queue(注意:当队列非常大时需改成分页)
|
||||||
|
debug('recover_stuck_tasks() called')
|
||||||
items = await self.redis.lrange(self.processing_queue, 0, -1)
|
items = await self.redis.lrange(self.processing_queue, 0, -1)
|
||||||
now = time.time()
|
now = time.time()
|
||||||
for item in items:
|
for task_id in items:
|
||||||
try:
|
|
||||||
task_obj = json.loads(item)
|
|
||||||
task_id = task_obj["task_id"]
|
|
||||||
except Exception:
|
|
||||||
# 非法项直接清理
|
|
||||||
await self.redis.lrem(self.processing_queue, 1, item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
info = await self.get_redis_task(task_id)
|
info = await self.get_redis_task(task_id)
|
||||||
if not info:
|
if not info:
|
||||||
# 如果 task hash 不存在,可选择直接删除或重新 enqueue
|
# 如果 task hash 不存在,可选择直接删除或重新 enqueue
|
||||||
# 这里我们选择重新入队并删除 processing entry
|
# 这里我们选择重新入队并删除 processing entry
|
||||||
await self.redis.rpush(self.task_queue, item)
|
await self.redis.rpush(self.task_queue, task_id)
|
||||||
await self.redis.lrem(self.processing_queue, 1, item)
|
await self.redis.lrem(self.processing_queue, 1, task_id)
|
||||||
debug(f"[recover] requeued missing-hash {task_id}")
|
debug(f"[recover] requeued missing-hash {task_id}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -97,8 +120,8 @@ class LongTasks:
|
|||||||
attempts = int(json.loads(info.get("attempts") or "0"))
|
attempts = int(json.loads(info.get("attempts") or "0"))
|
||||||
attempts += 1
|
attempts += 1
|
||||||
await self.update_task_hash(task_id, {"status": "PENDING", "attempts": attempts})
|
await self.update_task_hash(task_id, {"status": "PENDING", "attempts": attempts})
|
||||||
await self.redis.rpush(self.task_queue, item)
|
await self.redis.rpush(self.task_queue, task_id)
|
||||||
await self.redis.lrem(self.processing_queue, 1, item)
|
await self.redis.lrem(self.processing_queue, 1, task_id)
|
||||||
debug(f"[recover] task {task_id} requeued due to stuck")
|
debug(f"[recover] task {task_id} requeued due to stuck")
|
||||||
# else: 正常 running 或其他状态,不处理
|
# else: 正常 running 或其他状态,不处理
|
||||||
|
|
||||||
@ -108,26 +131,19 @@ class LongTasks:
|
|||||||
try:
|
try:
|
||||||
# BRPOPLPUSH: 从 task_queue 弹出(阻塞),并 push 到 processing_queue(原子)
|
# BRPOPLPUSH: 从 task_queue 弹出(阻塞),并 push 到 processing_queue(原子)
|
||||||
# aioredis: brpoplpush(source, destination, timeout)
|
# aioredis: brpoplpush(source, destination, timeout)
|
||||||
item = await self.redis.brpoplpush(self.task_queue, self.processing_queue, timeout=5)
|
# debug(f"Before BRPOPLPUSH: {self.task_queue} length = {await self.redis.llen(self.task_queue)}")
|
||||||
if not item:
|
task_id = await self.redis.brpoplpush(self.task_queue, self.processing_queue, timeout=5)
|
||||||
|
if not task_id:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
# debug(f'No task in task queue {self.task_queue=}, {self.processing_queue=}')
|
||||||
|
# await self.repush_pending_task()
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
debug(f'get task_id={task_id}')
|
||||||
|
|
||||||
|
task_obj = await self.get_redis_task(task_id)
|
||||||
|
payload = task_obj["payload"]
|
||||||
|
|
||||||
# item 是字符串 JSON
|
|
||||||
try:
|
|
||||||
task_obj = json.loads(item)
|
|
||||||
task_id = task_obj["task_id"]
|
|
||||||
payload = task_obj["payload"]
|
|
||||||
except Exception as e:
|
|
||||||
exception(f"[worker {worker_id}] bad item in queue, removing: {e}")
|
|
||||||
# 异常数据从 processing_queue 中移除
|
|
||||||
await self.redis.lrem(self.processing_queue, 1, item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
created = task_obj.get('created_at', 0)
|
|
||||||
if created + self.max_age_seconds < now:
|
|
||||||
continue
|
|
||||||
# 1) 更新 task hash 为 running(这一步很重要:客户端读取到状态)
|
# 1) 更新 task hash 为 running(这一步很重要:客户端读取到状态)
|
||||||
started_at = time.time()
|
started_at = time.time()
|
||||||
await self.update_task_hash(task_id, {
|
await self.update_task_hash(task_id, {
|
||||||
@ -143,7 +159,7 @@ class LongTasks:
|
|||||||
# 若希望支持取消,可把 status 设为 cancelling 等
|
# 若希望支持取消,可把 status 设为 cancelling 等
|
||||||
await self.update_task_hash(task_id, {"status": "FAILED", "error": "cancelled"})
|
await self.update_task_hash(task_id, {"status": "FAILED", "error": "cancelled"})
|
||||||
# 移除 processing_queue 项(已处理)
|
# 移除 processing_queue 项(已处理)
|
||||||
await self.redis.lrem(self.processing_queue, 1, item)
|
await self.redis.lrem(self.processing_queue, 1, task_id)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 写回失败信息
|
# 写回失败信息
|
||||||
@ -153,7 +169,7 @@ class LongTasks:
|
|||||||
"finished_at": time.time()
|
"finished_at": time.time()
|
||||||
})
|
})
|
||||||
# 从 processing_queue 移除该项
|
# 从 processing_queue 移除该项
|
||||||
await self.redis.lrem(self.processing_queue, 1, item)
|
await self.redis.lrem(self.processing_queue, 1, task_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 3) 写回成功结果并移除 processing_queue 项
|
# 3) 写回成功结果并移除 processing_queue 项
|
||||||
@ -163,7 +179,7 @@ class LongTasks:
|
|||||||
"finished_at": time.time()
|
"finished_at": time.time()
|
||||||
})
|
})
|
||||||
# 最后一步:从 processing_queue 中移除任务项(LREM)
|
# 最后一步:从 processing_queue 中移除任务项(LREM)
|
||||||
await self.redis.lrem(self.processing_queue, 1, item)
|
await self.redis.lrem(self.processing_queue, 1, task_id)
|
||||||
debug(f"[worker {worker_id}] finished {task_id}")
|
debug(f"[worker {worker_id}] finished {task_id}")
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -181,12 +197,18 @@ class LongTasks:
|
|||||||
"payload": json.dumps(payload)
|
"payload": json.dumps(payload)
|
||||||
}
|
}
|
||||||
await self.set_redis_task(taskid, task_data)
|
await self.set_redis_task(taskid, task_data)
|
||||||
await self.redis.rpush(self.task_queue, json.dumps({
|
await self.redis.rpush(self.task_queue, taskid)
|
||||||
"task_id": taskid,
|
|
||||||
"payload": payload
|
|
||||||
}))
|
|
||||||
return {'task_id': taskid}
|
return {'task_id': taskid}
|
||||||
|
|
||||||
|
async def all_taks_status(self):
|
||||||
|
x = False
|
||||||
|
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
|
||||||
|
taskid = key.split(':')[-1]
|
||||||
|
task = await self.get_redis_task(taskid)
|
||||||
|
print(f'{task}')
|
||||||
|
x = True
|
||||||
|
return x
|
||||||
|
|
||||||
async def set_redis_task(self, taskid, task_data):
|
async def set_redis_task(self, taskid, task_data):
|
||||||
str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in task_data.items()}
|
str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in task_data.items()}
|
||||||
await self.redis.hset(f"{self.taskname}:task:{taskid}", mapping=str_map)
|
await self.redis.hset(f"{self.taskname}:task:{taskid}", mapping=str_map)
|
||||||
@ -195,6 +217,9 @@ class LongTasks:
|
|||||||
task = await self.redis.hgetall(f'{self.taskname}:task:{taskid}')
|
task = await self.redis.hgetall(f'{self.taskname}:task:{taskid}')
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
async def delete_redis_task(self, taskid):
|
||||||
|
await self.redis.delete(f'{self.taskname}:task:{taskid}')
|
||||||
|
|
||||||
async def get_status(self, taskid:str):
|
async def get_status(self, taskid:str):
|
||||||
task = await self.get_redis_task(taskid)
|
task = await self.get_redis_task(taskid)
|
||||||
if not task:
|
if not task:
|
||||||
@ -203,23 +228,17 @@ class LongTasks:
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
async def main(lt):
|
async def main(lt):
|
||||||
tasks = []
|
|
||||||
for i in range(0, 10):
|
for i in range(0, 10):
|
||||||
payload = {
|
payload = {
|
||||||
"task": f"task {i}"
|
"task": f"task {i}"
|
||||||
}
|
}
|
||||||
x = await lt.submit_task(payload)
|
x = await lt.submit_task(payload)
|
||||||
tasks.append(x)
|
|
||||||
while True:
|
while True:
|
||||||
if len(tasks) < 1:
|
x = await lt.all_taks_status()
|
||||||
|
if not x:
|
||||||
break
|
break
|
||||||
tasks1 = [i for i in tasks]
|
print('\n')
|
||||||
for t in tasks1:
|
await asyncio.sleep(10)
|
||||||
task = await lt.get_status(t['task_id'])
|
|
||||||
print(f'{task}')
|
|
||||||
if task['status'] in ['SUCCEEDED', 'FAILED']:
|
|
||||||
tasks = [i for i in tasks if i['task_id'] != t['task_id']]
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
|
|
||||||
lt = LongTasks('redis://127.0.0.1:6379', 'test')
|
lt = LongTasks('redis://127.0.0.1:6379', 'test')
|
||||||
loop = get_event_loop()
|
loop = get_event_loop()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user