bugfix
This commit is contained in:
parent
fd7fce82ec
commit
98991877a1
@ -1,56 +1,72 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import asyncio
|
||||
import aioredis
|
||||
from random import randint
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from appPublic.worker import get_event_loop, schedule_interval, schedule_once
|
||||
from appPublic.uniqueID import getID
|
||||
|
||||
def work1(x, y=1):
|
||||
print(f'{x} * {y} = {x*y}')
|
||||
|
||||
class LongTask:
|
||||
def __init__(self, task: asyncio.Task):
|
||||
self.id = getID()
|
||||
self.status = 'pending',
|
||||
self.result = None
|
||||
self.task_status = None
|
||||
self.start_time = time.time()
|
||||
|
||||
def start(self):
|
||||
self.status = 'running'
|
||||
self.task.run()
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
'taskid': self.id,
|
||||
'status':self.status,
|
||||
'task_status': self.task_status,
|
||||
'result': self.result
|
||||
}
|
||||
async def work2():
|
||||
print('task2 ...')
|
||||
|
||||
class LongTasks:
|
||||
def __init__(self, redis_url, task_queue, processing_queue, worker_cnt=2 stuck_seconds=600):
|
||||
def __init__(self, redis_url, taskname, worker_cnt=2, stuck_seconds=600, max_age_hours=3):
|
||||
self.redis_url = redis_url
|
||||
self.worker_cnt = worker_cnt
|
||||
self.task_queue = task_queue
|
||||
self.processing_queue = processing_queue
|
||||
self.taskname = taskname
|
||||
self.max_age_secords = max_age_hours * 3600
|
||||
self.task_queue = f'{taskname}_pending'
|
||||
self.processing_queue = f'{taskname}_processing'
|
||||
self.stuck_seconds = stuck_seconds
|
||||
|
||||
async def init(self):
|
||||
async def cleanup_expired_tasks(self):
|
||||
"""清理超过 max_age_hours 的任务"""
|
||||
now = time.time()
|
||||
async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
|
||||
task = await self.redis.hgetall(key)
|
||||
if not task:
|
||||
continue
|
||||
|
||||
created_at = task.get("created_at")
|
||||
if not created_at:
|
||||
continue
|
||||
|
||||
if created + self.max_age_seconds < now:
|
||||
print(f"🧹 删除过期任务: {key}")
|
||||
await self.redis.delete(key)
|
||||
await self.redis.lrem("task_queue", 0, key) # 从任务队列中移除(可选)
|
||||
|
||||
|
||||
async def process_task(self, payload:dict):
|
||||
sec = randint(0,5)
|
||||
await asyncio.sleep(sec)
|
||||
print(f'{payload=} done')
|
||||
return {
|
||||
'result': 'OK'
|
||||
}
|
||||
|
||||
async def run(self):
|
||||
schedule_interval(3600, self.cleanup_expired_tasks)
|
||||
self.redis = await aioredis.from_url(self.redis_url, decode_responses=True)
|
||||
await self.recover_stuck_tasks()
|
||||
workers = [asyncio.create_task(self.worker_loop(redis, i)) for i in range(self.worker_cnt)]
|
||||
workers = [asyncio.create_task(self.worker_loop(i)) for i in range(self.worker_cnt)]
|
||||
try:
|
||||
await asyncio.gather(*workers)
|
||||
except asyncio.CancelledError:
|
||||
for w in workers:
|
||||
w.cancel()
|
||||
finally:
|
||||
await redis.close()
|
||||
await self.redis.close()
|
||||
|
||||
async def update_task_hash(self, task_id: str, mapping: Dict[str, Any]):
|
||||
# all values must be str
|
||||
str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in mapping.items()}
|
||||
await self.redis.hset(f"task:{task_id}", mapping=str_map)
|
||||
# str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in mapping.items()}
|
||||
await self.set_redis_task(task_id, mapping)
|
||||
|
||||
async def recover_stuck_tasks(self):
|
||||
"""
|
||||
@ -58,7 +74,7 @@ class LongTasks:
|
||||
如果某任务的 task:{id}.started_at 距现在 > self.stuck_seconds,则认为卡住并重新入队或标记为 failed。
|
||||
"""
|
||||
# 读取整个 processing_queue(注意:当队列非常大时需改成分页)
|
||||
items = await redis.lrange(self.processing_queue, 0, -1)
|
||||
items = await self.redis.lrange(self.processing_queue, 0, -1)
|
||||
now = time.time()
|
||||
for item in items:
|
||||
try:
|
||||
@ -66,29 +82,28 @@ class LongTasks:
|
||||
task_id = task_obj["task_id"]
|
||||
except Exception:
|
||||
# 非法项直接清理
|
||||
await redis.lrem(self.processing_queue, 1, item)
|
||||
await self.redis.lrem(self.processing_queue, 1, item)
|
||||
continue
|
||||
|
||||
task_key = f"task:{task_id}"
|
||||
info = await redis.hgetall(task_key, encoding="utf-8")
|
||||
info = await self.get_redis_task(task_id)
|
||||
if not info:
|
||||
# 如果 task hash 不存在,可选择直接删除或重新 enqueue
|
||||
# 这里我们选择重新入队并删除 processing entry
|
||||
await redis.rpush(self.task_queue, item)
|
||||
await redis.lrem(self.processing_queue, 1, item)
|
||||
await self.redis.rpush(self.task_queue, item)
|
||||
await self.redis.lrem(self.processing_queue, 1, item)
|
||||
print(f"[recover] requeued missing-hash {task_id}")
|
||||
continue
|
||||
|
||||
started_at = float(info.get("started_at") or 0)
|
||||
status = info.get("status")
|
||||
if status == "running" and (now - started_at) > self.stuck_seconds:
|
||||
if status == "RUNNING" and (now - started_at) > self.stuck_seconds:
|
||||
# 任务卡住 -> 重新入队并更新 attempts 或直接标记失败
|
||||
# 示例:重新入队并增加 attempts 字段
|
||||
attempts = int(json.loads(info.get("attempts") or "0"))
|
||||
attempts += 1
|
||||
await self.update_task_hash(task_id, {"status": "queued", "attempts": attempts})
|
||||
await redis.rpush(self.task_queue, item)
|
||||
await redis.lrem(self.processing_queue, 1, item)
|
||||
await self.update_task_hash(task_id, {"status": "PENDING", "attempts": attempts})
|
||||
await self.redis.rpush(self.task_queue, item)
|
||||
await self.redis.lrem(self.processing_queue, 1, item)
|
||||
print(f"[recover] task {task_id} requeued due to stuck")
|
||||
# else: 正常 running 或其他状态,不处理
|
||||
|
||||
@ -114,27 +129,31 @@ class LongTasks:
|
||||
await self.redis.lrem(self.processing_queue, 1, item)
|
||||
continue
|
||||
|
||||
now = time.time()
|
||||
created = task_obj.get('created_at', 0)
|
||||
if created + self.max_age_seconds < now:
|
||||
continue
|
||||
# 1) 更新 task hash 为 running(这一步很重要:客户端读取到状态)
|
||||
started_at = time.time()
|
||||
await self.update_task_hash(task_id, {
|
||||
"status": "running",
|
||||
"status": "RUNNING",
|
||||
"started_at": started_at,
|
||||
# optional: increment attempts
|
||||
})
|
||||
|
||||
# 2) 执行任务(catch exceptions)
|
||||
try:
|
||||
result = await process_task(payload)
|
||||
result = await self.process_task(payload)
|
||||
except asyncio.CancelledError:
|
||||
# 若希望支持取消,可把 status 设为 cancelling 等
|
||||
await self.update_task_hash(task_id, {"status": "failed", "error": "cancelled"})
|
||||
await self.update_task_hash(task_id, {"status": "FAILED", "error": "cancelled"})
|
||||
# 移除 processing_queue 项(已处理)
|
||||
await self.redis.lrem(self.processing_queue, 1, item)
|
||||
continue
|
||||
except Exception as e:
|
||||
# 写回失败信息
|
||||
await self.update_task_hash(task_id, {
|
||||
"status": "failed",
|
||||
"status": "FAILED",
|
||||
"error": str(e),
|
||||
"finished_at": time.time()
|
||||
})
|
||||
@ -144,7 +163,7 @@ class LongTasks:
|
||||
|
||||
# 3) 写回成功结果并移除 processing_queue 项
|
||||
await self.update_task_hash(task_id, {
|
||||
"status": "success",
|
||||
"status": "SUCCEEDED",
|
||||
"result": result,
|
||||
"finished_at": time.time()
|
||||
})
|
||||
@ -162,19 +181,58 @@ class LongTasks:
|
||||
taskid = getID()
|
||||
task_data = {
|
||||
"task_id": taskid,
|
||||
"status": "pending",
|
||||
"status": "PENDING",
|
||||
"created_at": time.time(),
|
||||
"payload": json.dumps(payload)
|
||||
}
|
||||
await self.redis.hset(f'task:{taskid}', mapping=task_data)
|
||||
await self.set_redis_task(taskid, task_data)
|
||||
await self.redis.rpush(self.task_queue, json.dumps({
|
||||
"task_id": taskid,
|
||||
"payload": payload
|
||||
}))
|
||||
return {'task_id': taskid}
|
||||
|
||||
async def get_status(taskid:str):
|
||||
task = self.redis.hgetall(f'task:{taskid}', encoding="utf-8")
|
||||
async def set_redis_task(self, taskid, task_data):
|
||||
str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in task_data.items()}
|
||||
await self.redis.hset(f"{self.taskname}:task:{taskid}", mapping=str_map)
|
||||
|
||||
async def get_redis_task(self, taskid):
|
||||
task = await self.redis.hgetall(f'{self.taskname}:task:{taskid}')
|
||||
return task
|
||||
|
||||
async def get_status(self, taskid:str):
|
||||
task = await self.get_redis_task(taskid)
|
||||
if not task:
|
||||
return {'error': 'no task'}
|
||||
return task
|
||||
|
||||
if __name__ == '__main__':
|
||||
async def main(lt):
|
||||
x = schedule_interval(5, work1, 6)
|
||||
print(f'interval worker {x}')
|
||||
y = schedule_interval(3, work2)
|
||||
print(f'interval worker {y}')
|
||||
await asyncio.sleep(10000)
|
||||
return
|
||||
tasks = []
|
||||
for i in range(0, 10):
|
||||
payload = {
|
||||
"task": f"task {i}"
|
||||
}
|
||||
x = await lt.submit_task(payload)
|
||||
tasks.append(x)
|
||||
while True:
|
||||
if len(tasks) < 1:
|
||||
break
|
||||
tasks1 = [i for i in tasks]
|
||||
for t in tasks1:
|
||||
task = await lt.get_status(t['task_id'])
|
||||
print(f'{task}')
|
||||
if task['status'] in ['SUCCEEDED', 'FAILED']:
|
||||
tasks = [i for i in tasks if i['task_id'] != t['task_id']]
|
||||
await asyncio.sleep(2)
|
||||
|
||||
lt = LongTasks('redis://127.0.0.1:6379', 'test')
|
||||
loop = get_event_loop()
|
||||
loop.create_task(lt.run())
|
||||
loop.run_until_complete(main(lt))
|
||||
|
||||
@ -8,7 +8,7 @@ version = "0.1.0"
|
||||
description = "long task engine with submit, get_status interface to client"
|
||||
authors = [{name = "Yu Moqing", email = "yumoqing@gmail.com"}]
|
||||
license = {text = "MIT"}
|
||||
dependencies = ["ahserver", "sqlor", "appPublic"]
|
||||
dependencies = ["ahserver", "sqlor", "appPublic", "aioredis" ]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user