longtasks/longtasks/longtasks.py
2025-11-06 14:41:55 +08:00

262 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding:utf-8 -*-
from traceback import format_exc
import asyncio
import aioredis
from random import randint
import json
import time
from typing import Any, Dict
from appPublic.worker import get_event_loop, schedule_interval, schedule_once
from appPublic.dictObject import DictObject
from appPublic.uniqueID import getID
from appPublic.log import debug, exception
class LongTasks:
    """Redis-backed manager for long-running background tasks.

    Each task lives in a Redis hash keyed ``{taskname}:task:{task_id}``.
    Pending task ids wait on the list ``{taskname}_pending``; ids currently
    being worked on sit on ``{taskname}_processing``.
    """
    def __init__(self, redis_url, taskname, worker_cnt=2, stuck_seconds=600, max_age_hours=3):
        """Record configuration; the Redis connection is opened later by start_redis().

        redis_url:     connection URL passed to aioredis.from_url()
        taskname:      namespace prefix for queues and task hashes
        worker_cnt:    number of concurrent worker_loop() coroutines
        stuck_seconds: RUNNING tasks older than this are considered stuck
        max_age_hours: task hashes older than this are purged
        """
        self.redis_url = redis_url
        self.taskname = taskname
        self.worker_cnt = worker_cnt
        self.stuck_seconds = stuck_seconds
        self.max_age_seconds = max_age_hours * 3600
        self.task_queue = f'{taskname}_pending'
        self.processing_queue = f'{taskname}_processing'
        # set by start_redis(); None until then
        self.redis = None
async def repush_pending_task(self):
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
task = await self.redis.hgetall(key)
if task['status'] == 'PENDING':
jstr = json.dumps({
"task_id": task['task_id'],
"payload": json.loads(task['payload'])
})
await self.redis.rpush(self.task_queue, jstr)
debug(f'{jstr=}, {self.task_queue=}')
async def cleanup_expired_tasks(self):
"""清理超过 max_age_hours 的任务"""
now = time.time()
debug('cleanup_expired_tasks() called ...')
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
taskid = key.split(':')[-1]
task = await self.redis.hgetall(key)
if not task:
debug(f'{key} task not found')
await self.redis.delete(key)
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
continue
created_at = task.get("created_at")
if not created_at:
await self.redis.delete(key)
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
await self.delete_redis_task(taskid)
debug(f'{key}, {task} no created_at key')
continue
created_at = float(created_at)
if created_at + self.max_age_seconds < now:
debug(f"🧹 删除过期任务: {key}, {task}")
await self.redis.delete(key)
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
status = task.get('status')
if status not in ['SUCCEEDED', 'FAILED', 'RUNNING', 'PENDING']:
debug(f"🧹 删除任务: {key}, {task}")
await self.redis.delete(key)
await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选)
async def process_task(self, payload:dict, workid:int=None):
sec = randint(0,5)
await asyncio.sleep(sec)
debug(f'{payload=} done')
return {
'result': 'OK'
}
async def start_redis(self):
self.redis = await aioredis.from_url(self.redis_url, decode_responses=True)
async def run(self, *args, **kw):
await self.start_redis()
await self.cleanup_expired_tasks()
schedule_interval(3600, self.cleanup_expired_tasks)
schedule_interval(300, self.recover_stuck_tasks)
workers = [asyncio.create_task(self.worker_loop(i)) for i in range(self.worker_cnt)]
try:
await asyncio.gather(*workers)
except asyncio.CancelledError:
for w in workers:
w.cancel()
finally:
await self.redis.close()
async def update_task_hash(self, task_id: str, mapping: Dict[str, Any]):
# all values must be str
# str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in mapping.items()}
await self.set_redis_task(task_id, mapping)
    async def recover_stuck_tasks(self):
        """
        Called at startup or periodically: inspect the processing_queue for
        tasks that may be stuck. If a task's started_at is more than
        self.stuck_seconds in the past, treat it as stuck and re-enqueue it
        (alternatively it could be marked FAILED).
        """
        # Reads the whole processing_queue — switch to paging if it gets very large.
        debug('recover_stuck_tasks() called')
        items = await self.redis.lrange(self.processing_queue, 0, -1)
        now = time.time()
        for task_id in items:
            info = await self.get_redis_task(task_id)
            if not info:
                # Task hash is gone: could delete the entry or re-enqueue;
                # here we requeue the id and drop the processing entry.
                # NOTE(review): get_redis_task() indexes fixed fields, so a
                # missing hash may raise KeyError before this branch is
                # reached — verify.
                await self.redis.rpush(self.task_queue, task_id)
                await self.redis.lrem(self.processing_queue, 1, task_id)
                debug(f"[recover] requeued missing-hash {task_id}")
                continue
            # started_at defaults to 0 when absent, which always looks stuck
            started_at = float(info.get("started_at") or 0)
            status = info.get("status")
            if status == "RUNNING" and (now - started_at) > self.stuck_seconds:
                # Stuck task -> re-enqueue and bump attempts (could also mark failed).
                attempts = int(json.loads(info.get("attempts") or "0"))
                attempts += 1
                await self.update_task_hash(task_id, {"status": "PENDING", "attempts": attempts})
                await self.redis.rpush(self.task_queue, task_id)
                await self.redis.lrem(self.processing_queue, 1, task_id)
                debug(f"[recover] task {task_id} requeued due to stuck")
            # else: still RUNNING within the window, or some other state — leave it alone.
async def worker_loop(self, worker_id: int):
debug(f"[worker {worker_id}] start")
while True:
try:
# BRPOPLPUSH: 从 task_queue 弹出(阻塞),并 push 到 processing_queue原子
# aioredis: brpoplpush(source, destination, timeout)
# debug(f"Before BRPOPLPUSH: {self.task_queue} length = {await self.redis.llen(self.task_queue)}")
task_id = await self.redis.brpoplpush(self.task_queue, self.processing_queue, timeout=5)
if not task_id:
await asyncio.sleep(0.1)
# debug(f'No task in task queue {self.task_queue=}, {self.processing_queue=}')
# await self.repush_pending_task()
continue
else:
debug(f'get task_id={task_id}')
task_obj = await self.get_redis_task(task_id)
payload = task_obj["payload"]
debug(f'task={task_obj}')
payload = json.loads(payload)
# 1) 更新 task hash 为 running这一步很重要客户端读取到状态
started_at = time.time()
await self.update_task_hash(task_id, {
"status": "RUNNING",
"started_at": started_at,
# optional: increment attempts
})
# 2) 执行任务catch exceptions
try:
result = await self.process_task(payload, worker_id)
except asyncio.CancelledError:
# 若希望支持取消,可把 status 设为 cancelling 等
await self.update_task_hash(task_id, {"status": "FAILED", "error": "cancelled"})
# 移除 processing_queue 项(已处理)
await self.redis.lrem(self.processing_queue, 1, task_id)
continue
except Exception as e:
# 写回失败信息
await self.update_task_hash(task_id, {
"status": "FAILED",
"error": str(e),
"finished_at": time.time()
})
# 从 processing_queue 移除该项
await self.redis.lrem(self.processing_queue, 1, task_id)
debug(f'{e=}\n{format_exc()}')
continue
# 3) 写回成功结果并移除 processing_queue 项
await self.update_task_hash(task_id, {
"status": "SUCCEEDED",
"result": result,
"finished_at": time.time()
})
# 最后一步:从 processing_queue 中移除任务项LREM
await self.redis.lrem(self.processing_queue, 1, task_id)
debug(f"[worker {worker_id}] finished {task_id}")
except asyncio.CancelledError:
break
except Exception as e:
exception(f"[worker {worker_id}] loop error: {e}")
await asyncio.sleep(1)
async def submit_task(self, payload):
taskid = getID()
task_data = {
"task_id": taskid,
"status": "PENDING",
"created_at": time.time(),
"payload": payload
}
await self.set_redis_task(taskid, task_data)
await self.redis.rpush(self.task_queue, taskid)
return {'task_id': taskid}
async def all_taks_status(self):
x = False
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
taskid = key.split(':')[-1]
task = await self.get_redis_task(taskid)
print(f'{task}')
x = True
return x
async def set_redis_task(self, taskid, task_data):
str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in task_data.items()}
await self.redis.hset(f"{self.taskname}:task:{taskid}", mapping=str_map)
async def get_redis_task(self, taskid):
task = await self.redis.hgetall(f'{self.taskname}:task:{taskid}')
if task['created_at']:
task['created_at'] = float(task['created_at'])
if task['started_at']:
task['started_at'] = float(task['started_at'])
if task['finished_at']:
task['finished_at'] = float(task['finished_at'])
if task['payload']:
task['payload'] = json.loads(task['payload'])
if task['status'] == 'SUCCEEDED' and task['result']:
task['result'] = json.loads(task['result'])
return task
async def delete_redis_task(self, taskid):
await self.redis.delete(f'{self.taskname}:task:{taskid}')
async def get_status(self, taskid:str):
task = await self.get_redis_task(taskid)
if not task:
return {'error': 'no task'}
return task
if __name__ == '__main__':
    async def main(lt):
        """Demo driver: submit 10 tasks, then poll until every task hash is gone."""
        for n in range(10):
            await lt.submit_task({"task": f"task {n}"})
        # all_taks_status() returns False once no task hashes remain
        while await lt.all_taks_status():
            print('\n')
            await asyncio.sleep(10)
    lt = LongTasks('redis://127.0.0.1:6379', 'test')
    loop = get_event_loop()
    loop.create_task(lt.run())
    loop.run_until_complete(main(lt))