- execute_sql.dspy: SELECT-only SQL execution with pagination
- read_log.dspy: read the last N lines of a whitelisted log file
- tail_log.dspy: incremental log monitoring from the last read position
- RBAC: developer role only
- Security: SQL validation, log file whitelist
201 lines
5.6 KiB
Python
201 lines
5.6 KiB
Python
"""bugfix 模块 - 开发者调试工具"""
|
||
import os
|
||
import re
|
||
from ahserver.serverenv import ServerEnv
|
||
from sqlor.dbpools import get_sor_context
|
||
from appPublic.log import debug, exception
|
||
|
||
|
||
MODULE_NAME = "bugfix"

# Log file names the log tools are allowed to open (exact name match only).
ALLOWED_LOGS = [
    'sage.log',
    'backend_accounting.log',
]

# Read-position bookkeeping for tail_log_file, keyed by log file name.
_log_tails = dict()
|
||
|
||
|
||
def _validate_select_only(sql):
|
||
"""验证 SQL 只能是 SELECT 语句"""
|
||
if not sql:
|
||
return False, "SQL 不能为空"
|
||
|
||
# 去除前后空格和注释
|
||
cleaned = sql.strip()
|
||
# 移除 -- 注释
|
||
cleaned = re.sub(r'--.*$', '', cleaned, flags=re.MULTILINE)
|
||
# 移除 /* */ 注释
|
||
cleaned = re.sub(r'/\*.*?\*/', '', cleaned, flags=re.DOTALL)
|
||
cleaned = cleaned.strip()
|
||
|
||
if not cleaned:
|
||
return False, "SQL 不能为空"
|
||
|
||
# 检查是否以 SELECT 开头
|
||
if not cleaned.upper().startswith('SELECT'):
|
||
return False, "仅允许执行 SELECT 语句"
|
||
|
||
# 黑名单检查 - 禁止危险操作
|
||
blacklist = [
|
||
'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER', 'CREATE',
|
||
'TRUNCATE', 'REPLACE', 'MERGE', 'GRANT', 'REVOKE',
|
||
'EXEC', 'EXECUTE', 'CALL', 'INTO', 'LOAD_FILE', 'INTO OUTFILE',
|
||
'INTO DUMPFILE'
|
||
]
|
||
upper = cleaned.upper()
|
||
for kw in blacklist:
|
||
# 用 \b 匹配完整单词
|
||
if re.search(r'\b' + kw + r'\b', upper):
|
||
return False, f"禁止使用 {kw} 语句"
|
||
|
||
return True, "OK"
|
||
|
||
|
||
async def execute_select_sql(sql, page=1, rows=20):
|
||
"""执行 SELECT SQL 查询(sqlor 标准分页)
|
||
|
||
Args:
|
||
sql: SQL 语句(仅允许 SELECT)
|
||
page: 页码,从 1 开始
|
||
rows: 每页条数,默认 20
|
||
|
||
Returns:
|
||
dict: {status, total, rows}
|
||
"""
|
||
valid, msg = _validate_select_only(sql)
|
||
if not valid:
|
||
return {'status': 'error', 'error': msg}
|
||
|
||
try:
|
||
env = ServerEnv()
|
||
dbname = env.get_module_dbname('bugfix')
|
||
async with get_sor_context(env, dbname) as sor:
|
||
ns = {'page': page, 'rows': rows}
|
||
result = await sor.sqlExe(sql, ns)
|
||
return {
|
||
'status': 'ok',
|
||
'total': result.get('total', 0),
|
||
'rows': result.get('rows', [])
|
||
}
|
||
except Exception as e:
|
||
exception(f'execute_select_sql error: {e}')
|
||
return {'status': 'error', 'error': str(e)}
|
||
|
||
|
||
async def read_log_file(filename, lines=500):
|
||
"""读取日志文件
|
||
|
||
Args:
|
||
filename: 日志文件名(必须在白名单中)
|
||
lines: 读取最后 N 行,默认 500
|
||
|
||
Returns:
|
||
dict: {status, data, filename}
|
||
"""
|
||
if filename not in ALLOWED_LOGS:
|
||
return {'status': 'error', 'error': f'不允许读取 {filename},仅允许: {", ".join(ALLOWED_LOGS)}'}
|
||
|
||
try:
|
||
env = ServerEnv()
|
||
# 日志目录在 sage/logs/ 下
|
||
log_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'logs')
|
||
# 备用路径
|
||
if not os.path.isdir(log_dir):
|
||
log_dir = os.path.expanduser('~/repos/sage/logs')
|
||
|
||
log_path = os.path.join(log_dir, filename)
|
||
if not os.path.isfile(log_path):
|
||
return {'status': 'error', 'error': f'日志文件不存在: {log_path}'}
|
||
|
||
# 读取最后 N 行
|
||
with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
|
||
all_lines = f.readlines()
|
||
|
||
start = max(0, len(all_lines) - lines)
|
||
content = ''.join(all_lines[start:])
|
||
|
||
return {
|
||
'status': 'ok',
|
||
'filename': filename,
|
||
'total_lines': len(all_lines),
|
||
'returned_lines': len(all_lines) - start,
|
||
'content': content
|
||
}
|
||
except Exception as e:
|
||
exception(f'read_log_file error: {e}')
|
||
return {'status': 'error', 'error': str(e)}
|
||
|
||
|
||
async def tail_log_file(filename, reset=False):
|
||
"""日志监控 - 从上次读取位置继续读新增内容
|
||
|
||
Args:
|
||
filename: 日志文件名(必须在白名单中)
|
||
reset: True=重置位置到文件末尾,下次从末尾开始监控
|
||
|
||
Returns:
|
||
dict: {status, filename, new_lines, content, total_lines}
|
||
"""
|
||
if filename not in ALLOWED_LOGS:
|
||
return {'status': 'error', 'error': f'不允许读取 {filename},仅允许: {", ".join(ALLOWED_LOGS)}'}
|
||
|
||
try:
|
||
env = ServerEnv()
|
||
log_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'logs')
|
||
if not os.path.isdir(log_dir):
|
||
log_dir = os.path.expanduser('~/repos/sage/logs')
|
||
|
||
log_path = os.path.join(log_dir, filename)
|
||
if not os.path.isfile(log_path):
|
||
return {'status': 'error', 'error': f'日志文件不存在: {log_path}'}
|
||
|
||
# 获取文件修改时间,如果文件被替换则重置位置
|
||
mtime = os.path.getmtime(log_path)
|
||
last = _log_tails.get(filename)
|
||
|
||
with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
|
||
all_lines = f.readlines()
|
||
|
||
total = len(all_lines)
|
||
|
||
# reset=True 或文件被替换(mtime变了且行数变少)时,跳到末尾
|
||
if reset or (last and last.get('mtime') != mtime and total < last.get('lines', 0)):
|
||
_log_tails[filename] = {'lines': total, 'mtime': mtime}
|
||
return {
|
||
'status': 'ok',
|
||
'filename': filename,
|
||
'new_lines': 0,
|
||
'content': '',
|
||
'total_lines': total,
|
||
'reset': True
|
||
}
|
||
|
||
# 从上次位置继续读
|
||
start = last.get('lines', 0) if last else max(0, total - 100) # 首次读最后100行
|
||
new_content = ''.join(all_lines[start:])
|
||
new_count = total - start
|
||
|
||
# 更新位置
|
||
_log_tails[filename] = {'lines': total, 'mtime': mtime}
|
||
|
||
return {
|
||
'status': 'ok',
|
||
'filename': filename,
|
||
'new_lines': new_count,
|
||
'content': new_content,
|
||
'total_lines': total
|
||
}
|
||
except Exception as e:
|
||
exception(f'tail_log_file error: {e}')
|
||
return {'status': 'error', 'error': str(e)}
|
||
|
||
|
||
def load_bugfix():
    """Attach this module's debugging helpers to the shared ServerEnv.

    Sets ``execute_select_sql``, ``read_log_file`` and ``tail_log_file`` as
    attributes on the ``ServerEnv()`` instance, then logs a debug message.

    Returns:
        bool: always True.
    """
    server_env = ServerEnv()
    exported = (
        ('execute_select_sql', execute_select_sql),
        ('read_log_file', read_log_file),
        ('tail_log_file', tail_log_file),
    )
    for attr_name, handler in exported:
        setattr(server_env, attr_name, handler)
    debug('[bugfix] module loaded')
    return True
|