Compare commits

...

2 Commits

SHA1 Message Date
6b2736b03b bugfix 2025-11-04 16:56:01 +08:00
98991877a1 bugfix 2025-11-04 16:34:32 +08:00
4 changed files with 162 additions and 48 deletions

README.md

@@ -1,2 +1,44 @@
# longtasks

## Usage
longtasks is designed to be used with ahserver: save an instance of LongTasks (or of a subclass) into ServerEnv, and implement the process_task method in the subclass with your business logic.

Save the LongTasks instance into ServerEnv:
```
from ahserver.serverenv import ServerEnv
from longtasks.longtasks import LongTasks
from appPublic.worker import schedule_once

class MyTasks(LongTasks):
    # subclass: put the task's business logic in process_task
    async def process_task(self, payload):
        # execute the task here
        pass

def load_longtasks():
    longtasks = MyTasks('redis://127.0.0.1:6379', 'example')
    env = ServerEnv()
    env.longtasks = longtasks
    # start the background job
    schedule_once(0.1, longtasks.run)
```

Submit a task in a dspy script:
```
payload = {
    'prompt': 'gagagag'
}
x = await longtasks.submit_task(payload)
# x is a dict with a 'task_id' key
return x
```

Query a task's status:
```
taskid = 'mytaskid'
task_status = await longtasks.get_status(taskid)
return task_status
```
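
get_status returns the raw task hash from Redis; its status field starts at PENDING, moves to RUNNING, and ends at SUCCEEDED or FAILED (see longtasks.py below). A minimal client-side polling sketch built on the API above (wait_for_result and the 2-second interval are illustrative, not part of the package):
```
import asyncio

async def wait_for_result(longtasks, task_id, interval=2):
    # poll the task hash until it reaches a terminal state
    while True:
        task = await longtasks.get_status(task_id)
        if task.get('status') in ('SUCCEEDED', 'FAILED'):
            return task
        await asyncio.sleep(interval)
```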

longtasks/__init__.py Normal file

@@ -0,0 +1,14 @@
from ahserver.serverenv import ServerEnv
from longtasks.longtasks import LongTasks
from appPublic.worker import schedule_once

class MyTasks(LongTasks):
    async def process_task(self, payload):
        ...

def load_longtasks():
    longtasks = MyTasks('redis://127.0.0.1:6379', 'example')
    env = ServerEnv()
    env.longtasks = longtasks
    schedule_once(0.1, longtasks.run)

longtasks/longtasks.py

@@ -1,56 +1,72 @@
 # -*- coding:utf-8 -*-
 import asyncio
 import aioredis
+from random import randint
 import json
 import time
 from typing import Any, Dict
+from appPublic.worker import get_event_loop, schedule_interval, schedule_once
 from appPublic.uniqueID import getID

-class LongTask:
-    def __init__(self, task: asyncio.Task):
-        self.id = getID()
-        self.status = 'pending',
-        self.result = None
-        self.task_status = None
-        self.start_time = time.time()
-    def start(self):
-        self.status = 'running'
-        self.task.run()
-    def status(self):
-        return {
-            'taskid': self.id,
-            'status': self.status,
-            'task_status': self.task_status,
-            'result': self.result
-        }
+def work1(x, y=1):
+    print(f'{x} * {y} = {x*y}')
+
+async def work2():
+    print('task2 ...')

 class LongTasks:
-    def __init__(self, redis_url, task_queue, processing_queue, worker_cnt=2, stuck_seconds=600):
+    def __init__(self, redis_url, taskname, worker_cnt=2, stuck_seconds=600, max_age_hours=3):
         self.redis_url = redis_url
         self.worker_cnt = worker_cnt
-        self.task_queue = task_queue
-        self.processing_queue = processing_queue
+        self.taskname = taskname
+        self.max_age_seconds = max_age_hours * 3600
+        self.task_queue = f'{taskname}_pending'
+        self.processing_queue = f'{taskname}_processing'
         self.stuck_seconds = stuck_seconds
-    async def init(self):
+    async def cleanup_expired_tasks(self):
+        """Delete tasks older than max_age_hours."""
+        now = time.time()
+        async for key in self.redis.scan_iter(match=f"{self.taskname}:task:*"):
+            task = await self.redis.hgetall(key)
+            if not task:
+                continue
+            created_at = task.get("created_at")
+            if not created_at:
+                continue
+            if float(created_at) + self.max_age_seconds < now:
+                print(f"🧹 deleting expired task: {key}")
+                await self.redis.delete(key)
+                await self.redis.lrem(self.task_queue, 0, key)  # also drop it from the pending queue (optional)
+    async def process_task(self, payload: dict):
+        # default implementation simulates some work; subclasses override this
+        sec = randint(0, 5)
+        await asyncio.sleep(sec)
+        print(f'{payload=} done')
+        return {
+            'result': 'OK'
+        }
+    async def run(self):
+        schedule_interval(3600, self.cleanup_expired_tasks)
         self.redis = await aioredis.from_url(self.redis_url, decode_responses=True)
         await self.recover_stuck_tasks()
-        workers = [asyncio.create_task(self.worker_loop(redis, i)) for i in range(self.worker_cnt)]
+        workers = [asyncio.create_task(self.worker_loop(i)) for i in range(self.worker_cnt)]
         try:
             await asyncio.gather(*workers)
         except asyncio.CancelledError:
             for w in workers:
                 w.cancel()
         finally:
-            await redis.close()
+            await self.redis.close()
     async def update_task_hash(self, task_id: str, mapping: Dict[str, Any]):
         # all values must be str
-        str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in mapping.items()}
-        await self.redis.hset(f"task:{task_id}", mapping=str_map)
+        await self.set_redis_task(task_id, mapping)
     async def recover_stuck_tasks(self):
         """
@@ -58,7 +74,7 @@ class LongTasks:
         If a task's task:{id}.started_at is older than self.stuck_seconds, treat it as stuck and re-enqueue it or mark it failed
         """
         # read the whole processing_queue (note: switch to paged reads when the queue gets very large)
-        items = await redis.lrange(self.processing_queue, 0, -1)
+        items = await self.redis.lrange(self.processing_queue, 0, -1)
         now = time.time()
         for item in items:
             try:
@@ -66,29 +82,28 @@ class LongTasks:
                 task_id = task_obj["task_id"]
             except Exception:
                 # malformed entry: just remove it
-                await redis.lrem(self.processing_queue, 1, item)
+                await self.redis.lrem(self.processing_queue, 1, item)
                 continue
-            task_key = f"task:{task_id}"
-            info = await redis.hgetall(task_key, encoding="utf-8")
+            info = await self.get_redis_task(task_id)
             if not info:
                 # if the task hash is missing we could delete the entry or re-enqueue it;
                 # here we re-enqueue it and drop the processing entry
-                await redis.rpush(self.task_queue, item)
-                await redis.lrem(self.processing_queue, 1, item)
+                await self.redis.rpush(self.task_queue, item)
+                await self.redis.lrem(self.processing_queue, 1, item)
                 print(f"[recover] requeued missing-hash {task_id}")
                 continue
             started_at = float(info.get("started_at") or 0)
             status = info.get("status")
-            if status == "running" and (now - started_at) > self.stuck_seconds:
+            if status == "RUNNING" and (now - started_at) > self.stuck_seconds:
                 # task is stuck -> re-enqueue it (updating attempts) or mark it failed directly
                 # example: re-enqueue and increment the attempts field
                 attempts = int(json.loads(info.get("attempts") or "0"))
                 attempts += 1
-                await self.update_task_hash(task_id, {"status": "queued", "attempts": attempts})
-                await redis.rpush(self.task_queue, item)
-                await redis.lrem(self.processing_queue, 1, item)
+                await self.update_task_hash(task_id, {"status": "PENDING", "attempts": attempts})
+                await self.redis.rpush(self.task_queue, item)
+                await self.redis.lrem(self.processing_queue, 1, item)
                 print(f"[recover] task {task_id} requeued due to stuck")
             # else: still running normally (or some other status); leave it alone
@@ -113,28 +128,32 @@ class LongTasks:
                 # remove malformed entries from the processing_queue
                 await self.redis.lrem(self.processing_queue, 1, item)
                 continue
+            now = time.time()
+            # skip entries already past max_age (entries without a created_at are processed)
+            created = float(task_obj.get('created_at') or now)
+            if created + self.max_age_seconds < now:
+                continue
             # 1) mark the task hash as running (important: this is the status clients read)
             started_at = time.time()
             await self.update_task_hash(task_id, {
-                "status": "running",
+                "status": "RUNNING",
                 "started_at": started_at,
                 # optional: increment attempts
             })
             # 2) run the task (catching exceptions)
             try:
-                result = await process_task(payload)
+                result = await self.process_task(payload)
             except asyncio.CancelledError:
                 # to support cancellation, status could be set to cancelling etc.
-                await self.update_task_hash(task_id, {"status": "failed", "error": "cancelled"})
+                await self.update_task_hash(task_id, {"status": "FAILED", "error": "cancelled"})
                 # remove the processing_queue entry (it has been handled)
                 await self.redis.lrem(self.processing_queue, 1, item)
                 continue
             except Exception as e:
                 # write back the failure info
                 await self.update_task_hash(task_id, {
-                    "status": "failed",
+                    "status": "FAILED",
                     "error": str(e),
                     "finished_at": time.time()
                 })
@@ -144,7 +163,7 @@ class LongTasks:
             # 3) write back the success result and remove the processing_queue entry
             await self.update_task_hash(task_id, {
-                "status": "success",
+                "status": "SUCCEEDED",
                 "result": result,
                 "finished_at": time.time()
             })
@@ -162,19 +181,58 @@ class LongTasks:
         taskid = getID()
         task_data = {
             "task_id": taskid,
-            "status": "pending",
+            "status": "PENDING",
+            "created_at": time.time(),
             "payload": json.dumps(payload)
         }
-        await self.redis.hset(f'task:{taskid}', mapping=task_data)
+        await self.set_redis_task(taskid, task_data)
         await self.redis.rpush(self.task_queue, json.dumps({
             "task_id": taskid,
             "payload": payload
         }))
         return {'task_id': taskid}
-    async def get_status(taskid:str):
-        task = self.redis.hgetall(f'task:{taskid}', encoding="utf-8")
+    async def set_redis_task(self, taskid, task_data):
+        str_map = {k: json.dumps(v) if not isinstance(v, str) else v for k, v in task_data.items()}
+        await self.redis.hset(f"{self.taskname}:task:{taskid}", mapping=str_map)
+    async def get_redis_task(self, taskid):
+        task = await self.redis.hgetall(f'{self.taskname}:task:{taskid}')
+        return task
+    async def get_status(self, taskid: str):
+        task = await self.get_redis_task(taskid)
         if not task:
             return {'error': 'no task'}
         return task
+
+if __name__ == '__main__':
+    async def main(lt):
+        x = schedule_interval(5, work1, 6)
+        print(f'interval worker {x}')
+        y = schedule_interval(3, work2)
+        print(f'interval worker {y}')
+        await asyncio.sleep(10000)
+        return
+        # NOTE: the early return above skips the submit/poll demo below
+        tasks = []
+        for i in range(0, 10):
+            payload = {
+                "task": f"task {i}"
+            }
+            x = await lt.submit_task(payload)
+            tasks.append(x)
+        while True:
+            if len(tasks) < 1:
+                break
+            tasks1 = [i for i in tasks]
+            for t in tasks1:
+                task = await lt.get_status(t['task_id'])
+                print(f'{task}')
+                if task['status'] in ['SUCCEEDED', 'FAILED']:
+                    tasks = [i for i in tasks if i['task_id'] != t['task_id']]
+            await asyncio.sleep(2)
+
+    lt = LongTasks('redis://127.0.0.1:6379', 'test')
+    loop = get_event_loop()
+    loop.create_task(lt.run())
+    loop.run_until_complete(main(lt))
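
For reference, all Redis keys are now namespaced by taskname: a pending list {taskname}_pending, a processing list {taskname}_processing, and one hash per task at {taskname}:task:{task_id}, with non-string values JSON-encoded by set_redis_task. A small inspection sketch that mirrors the aioredis calls used above (inspect_task is hypothetical; the taskname and task id are placeholders):
```
import asyncio
import json
import aioredis

async def inspect_task(taskname, task_id):
    redis = await aioredis.from_url('redis://127.0.0.1:6379', decode_responses=True)
    task = await redis.hgetall(f'{taskname}:task:{task_id}')  # per-task hash
    pending = await redis.llen(f'{taskname}_pending')         # queued entries
    print('status:', task.get('status'), '| pending:', pending)
    if task.get('result'):
        print('result:', json.loads(task['result']))          # result is stored JSON-encoded
    await redis.close()

asyncio.run(inspect_task('example', 'mytaskid'))
```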

pyproject.toml

@@ -8,7 +8,7 @@ version = "0.1.0"
 description = "long task engine with submit, get_status interface to client"
 authors = [{name = "Yu Moqing", email = "yumoqing@gmail.com"}]
 license = {text = "MIT"}
-dependencies = ["ahserver", "sqlor", "appPublic"]
+dependencies = ["ahserver", "sqlor", "appPublic", "aioredis"]

 [tool.setuptools.packages.find]
 where = ["."]