From 98991877a17e6d5dca1dd7a2a6e0beeb217c3ed9 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Tue, 4 Nov 2025 16:34:32 +0800 Subject: [PATCH] bugfix --- longtasks/longtasks.py | 152 ++++++++++++++++++++++++++++------------- pyproject.toml | 2 +- 2 files changed, 106 insertions(+), 48 deletions(-) diff --git a/longtasks/longtasks.py b/longtasks/longtasks.py index d72afbe..35ee1ca 100644 --- a/longtasks/longtasks.py +++ b/longtasks/longtasks.py @@ -1,56 +1,72 @@ # -*- coding:utf-8 -*- import asyncio import aioredis +from random import randint import json import time from typing import Any, Dict +from appPublic.worker import get_event_loop, schedule_interval, schedule_once from appPublic.uniqueID import getID +def work1(x, y=1): + print(f'{x} * {y} = {x*y}') -class LongTask: - def __init__(self, task: asyncio.Task): - self.id = getID() - self.status = 'pending', - self.result = None - self.task_status = None - self.start_time = time.time() +async def work2(): + print('task2 ...') - def start(self): - self.status = 'running' - self.task.run() - - def status(self): - return { - 'taskid': self.id, - 'status':self.status, - 'task_status': self.task_status, - 'result': self.result - } - class LongTasks: - def __init__(self, redis_url, task_queue, processing_queue, worker_cnt=2 stuck_seconds=600): + def __init__(self, redis_url, taskname, worker_cnt=2, stuck_seconds=600, max_age_hours=3): self.redis_url = redis_url self.worker_cnt = worker_cnt - self.task_queue = task_queue - self.processing_queue = processing_queue + self.taskname = taskname + self.max_age_seconds = max_age_hours * 3600 + self.task_queue = f'{taskname}_pending' + self.processing_queue = f'{taskname}_processing' self.stuck_seconds = stuck_seconds - async def init(self): + async def cleanup_expired_tasks(self): + """清理超过 max_age_hours 的任务""" + now = time.time() + async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"): + task = await self.redis.hgetall(key) + if not task: + continue + + 
created_at = task.get("created_at") + if not created_at: + continue + + if float(created_at) + self.max_age_seconds < now: + print(f"🧹 删除过期任务: {key}") + await self.redis.delete(key) + await self.redis.lrem(self.task_queue, 0, key) # 从任务队列中移除(可选) + + + async def process_task(self, payload:dict): + sec = randint(0,5) + await asyncio.sleep(sec) + print(f'{payload=} done') + return { + 'result': 'OK' + } + + async def run(self): + schedule_interval(3600, self.cleanup_expired_tasks) self.redis = await aioredis.from_url(self.redis_url, decode_responses=True) await self.recover_stuck_tasks() - workers = [asyncio.create_task(self.worker_loop(redis, i)) for i in range(self.worker_cnt)] + workers = [asyncio.create_task(self.worker_loop(i)) for i in range(self.worker_cnt)] try: await asyncio.gather(*workers) except asyncio.CancelledError: for w in workers: w.cancel() finally: - await redis.close() + await self.redis.close() async def update_task_hash(self, task_id: str, mapping: Dict[str, Any]): # all values must be str - str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in mapping.items()} - await self.redis.hset(f"task:{task_id}", mapping=str_map) + await self.set_redis_task(task_id, mapping) async def recover_stuck_tasks(self): """ @@ -58,7 +74,7 @@ class LongTasks: 如果某任务的 task:{id}.started_at 距现在 > self.stuck_seconds,则认为卡住并重新入队或标记为 failed。 """ # 读取整个 processing_queue(注意:当队列非常大时需改成分页) - items = await redis.lrange(self.processing_queue, 0, -1) + items = await self.redis.lrange(self.processing_queue, 0, -1) now = time.time() for item in items: try: @@ -66,29 +82,28 @@ class LongTasks: task_id = task_obj["task_id"] except Exception: # 非法项直接清理 - await redis.lrem(self.processing_queue, 1, item) + await self.redis.lrem(self.processing_queue, 1, item) continue - task_key = f"task:{task_id}" - info = await redis.hgetall(task_key, encoding="utf-8") + info = await 
self.get_redis_task(task_id) if not info: # 如果 task hash 不存在,可选择直接删除或重新 enqueue # 这里我们选择重新入队并删除 processing entry - await redis.rpush(self.task_queue, item) - await redis.lrem(self.processing_queue, 1, item) + await self.redis.rpush(self.task_queue, item) + await self.redis.lrem(self.processing_queue, 1, item) print(f"[recover] requeued missing-hash {task_id}") continue started_at = float(info.get("started_at") or 0) status = info.get("status") - if status == "running" and (now - started_at) > self.stuck_seconds: + if status == "RUNNING" and (now - started_at) > self.stuck_seconds: # 任务卡住 -> 重新入队并更新 attempts 或直接标记失败 # 示例:重新入队并增加 attempts 字段 attempts = int(json.loads(info.get("attempts") or "0")) attempts += 1 - await self.update_task_hash(task_id, {"status": "queued", "attempts": attempts}) - await redis.rpush(self.task_queue, item) - await redis.lrem(self.processing_queue, 1, item) + await self.update_task_hash(task_id, {"status": "PENDING", "attempts": attempts}) + await self.redis.rpush(self.task_queue, item) + await self.redis.lrem(self.processing_queue, 1, item) print(f"[recover] task {task_id} requeued due to stuck") # else: 正常 running 或其他状态,不处理 @@ -113,28 +128,32 @@ class LongTasks: # 异常数据从 processing_queue 中移除 await self.redis.lrem(self.processing_queue, 1, item) continue - + + now = time.time() + created = task_obj.get('created_at', 0) + if created and created + self.max_age_seconds < now: + continue # 1) 更新 task hash 为 running(这一步很重要:客户端读取到状态) started_at = time.time() await self.update_task_hash(task_id, { - "status": "running", + "status": "RUNNING", "started_at": started_at, # optional: increment attempts }) # 2) 执行任务(catch exceptions) try: - result = await process_task(payload) + result = await self.process_task(payload) except asyncio.CancelledError: # 若希望支持取消,可把 status 设为 cancelling 等 - await self.update_task_hash(task_id, {"status": "failed", "error": "cancelled"}) + await self.update_task_hash(task_id, {"status": "FAILED", "error": "cancelled"}) # 移除 
processing_queue 项(已处理) await self.redis.lrem(self.processing_queue, 1, item) continue except Exception as e: # 写回失败信息 await self.update_task_hash(task_id, { - "status": "failed", + "status": "FAILED", "error": str(e), "finished_at": time.time() }) @@ -144,7 +163,7 @@ class LongTasks: # 3) 写回成功结果并移除 processing_queue 项 await self.update_task_hash(task_id, { - "status": "success", + "status": "SUCCEEDED", "result": result, "finished_at": time.time() }) @@ -162,19 +181,58 @@ class LongTasks: taskid = getID() task_data = { "task_id": taskid, - "status": "pending", + "status": "PENDING", + "created_at": time.time(), "payload": json.dumps(payload) } - await self.redis.hset(f'task:{taskid}', mapping=task_data) + await self.set_redis_task(taskid, task_data) await self.redis.rpush(self.task_queue, json.dumps({ "task_id": taskid, "payload": payload })) return {'task_id': taskid} - async def get_status(taskid:str): - task = self.redis.hgetall(f'task:{taskid}', encoding="utf-8") + async def set_redis_task(self, taskid, task_data): + str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in task_data.items()} + await self.redis.hset(f"{self.taskname}:task:{taskid}", mapping=str_map) + + async def get_redis_task(self, taskid): + task = await self.redis.hgetall(f'{self.taskname}:task:{taskid}') + return task + + async def get_status(self, taskid:str): + task = await self.get_redis_task(taskid) if not task: return {'error': 'no task'} return task +if __name__ == '__main__': + async def main(lt): + x = schedule_interval(5, work1, 6) + print(f'interval worker {x}') + y = schedule_interval(3, work2) + print(f'interval worker {y}') + await asyncio.sleep(10000) + return + tasks = [] + for i in range(0, 10): + payload = { + "task": f"task {i}" + } + x = await lt.submit_task(payload) + tasks.append(x) + while True: + if len(tasks) < 1: + break + tasks1 = [i for i in tasks] + for t in tasks1: + task = await lt.get_status(t['task_id']) + print(f'{task}') + if task['status'] 
in ['SUCCEEDED', 'FAILED']: + tasks = [i for i in tasks if i['task_id'] != t['task_id']] + await asyncio.sleep(2) + + lt = LongTasks('redis://127.0.0.1:6379', 'test') + loop = get_event_loop() + loop.create_task(lt.run()) + loop.run_until_complete(main(lt)) diff --git a/pyproject.toml b/pyproject.toml index 7828a09..c382ccc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.0" description = "long task engine with submit, get_status interface to client" authors = [{name = "Yu Moqing", email = "yumoqing@gmail.com"}] license = {text = "MIT"} -dependencies = ["ahserver", "sqlor", "appPublic"] +dependencies = ["ahserver", "sqlor", "appPublic", "aioredis" ] [tool.setuptools.packages.find] where = ["."]