fix(rbac): fix high-concurrency race conditions in login and cache
1. Login lockout race condition:
   - Replace SELECT-then-UPDATE with atomic database operations
   - Lockout check now in SQL WHERE clause (DATE_SUB comparison)
   - Fail count increment: UPDATE ... SET count = count + 1 (atomic)
   - Applied to checkUserPassword, basic_auth, up_login.dspy, phone_login.dspy
2. Cache threading.Lock -> asyncio.Lock:
   - LRUCache now uses lazy-init asyncio.Lock
   - Prevents blocking the event loop in async environment
   - UserPermissions._rp_lock also uses asyncio.Lock
   - Double-check pattern in load_roleperms prevents duplicate DB loads
3. Use database NOW() instead of Python curDateString for concurrent updates (applies to the last_login writes in basic_auth and the login scripts; the last_login_fail writes and the checkUserPassword success path still pass a curDateString value)
This commit is contained in:
parent
3fdd4efeff
commit
622b0558b9
@ -110,94 +110,85 @@ def get_dbname():
|
|||||||
async def checkUserPassword(request, username, password):
|
async def checkUserPassword(request, username, password):
|
||||||
"""Authenticate user with password, supporting login lockout mechanism.
|
"""Authenticate user with password, supporting login lockout mechanism.
|
||||||
|
|
||||||
After 3 consecutive failed login attempts, the user is locked out for 5 minutes.
|
High-concurrency safe:
|
||||||
On successful login, last_login is updated and fail count is reset.
|
- Uses atomic UPDATE for fail_count increment (no SELECT-then-UPDATE race)
|
||||||
|
- Lockout check uses database-level comparison
|
||||||
|
- Password verified with single atomic query
|
||||||
"""
|
"""
|
||||||
db = DBPools()
|
db = DBPools()
|
||||||
dbname = get_dbname()
|
dbname = get_dbname()
|
||||||
async with db.sqlorContext(dbname) as sor:
|
async with db.sqlorContext(dbname) as sor:
|
||||||
# Get user record including login status fields
|
# Check lockout status atomically in SQL
|
||||||
sql = "select * from users where username=${username}$"
|
# Returns user record only if NOT currently locked
|
||||||
|
sql = """select * from users where username=${username}$
|
||||||
|
and not (
|
||||||
|
login_fail_count >= 3
|
||||||
|
and last_login_fail is not null
|
||||||
|
and last_login_fail > DATE_SUB(NOW(), INTERVAL 300 SECOND)
|
||||||
|
)"""
|
||||||
recs = await sor.sqlExe(sql, {'username': username})
|
recs = await sor.sqlExe(sql, {'username': username})
|
||||||
if len(recs) < 1:
|
if len(recs) < 1:
|
||||||
|
# Either user not found, or locked out
|
||||||
|
debug(f'User {username} not found or locked out')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
user = recs[0]
|
user = recs[0]
|
||||||
|
|
||||||
# Check login lockout: 3 consecutive failures within 5 minutes
|
# Verify password with single atomic query
|
||||||
fail_count = getattr(user, 'login_fail_count', 0) or 0
|
|
||||||
last_fail = getattr(user, 'last_login_fail', None)
|
|
||||||
|
|
||||||
if fail_count >= 3 and last_fail:
|
|
||||||
# Calculate time elapsed since last failed attempt
|
|
||||||
now_ts = time.time()
|
|
||||||
fail_ts = _parse_timestamp(last_fail)
|
|
||||||
elapsed = now_ts - fail_ts
|
|
||||||
if elapsed < 300: # 5 minutes = 300 seconds
|
|
||||||
remaining = int(300 - elapsed)
|
|
||||||
debug(f'User {username} locked out, {remaining}s remaining')
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
# Lockout period expired, reset fail count
|
|
||||||
await sor.U('users', {'id': user.id}, {
|
|
||||||
'login_fail_count': 0,
|
|
||||||
'last_login_fail': None
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check password
|
|
||||||
sql = "select * from users where username=${username}$ and password=${password}$"
|
sql = "select * from users where username=${username}$ and password=${password}$"
|
||||||
recs = await sor.sqlExe(sql, {'username': username, 'password': password})
|
recs = await sor.sqlExe(sql, {'username': username, 'password': password})
|
||||||
if len(recs) < 1:
|
if len(recs) < 1:
|
||||||
# Password wrong - increment fail count
|
# Password wrong - atomically increment fail count
|
||||||
new_fail_count = fail_count + 1
|
# Database-level increment prevents race conditions
|
||||||
await sor.U('users', {'id': user.id}, {
|
now_str = curDateString('%Y-%m-%d %H:%M:%S')
|
||||||
'login_fail_count': new_fail_count,
|
await sor.sqlExe("""
|
||||||
'last_login_fail': curDateString('%Y-%m-%d %H:%M:%S')
|
UPDATE users
|
||||||
})
|
SET login_fail_count = login_fail_count + 1,
|
||||||
debug(f'Login failed for {username}, fail_count={new_fail_count}')
|
last_login_fail = ${now}$
|
||||||
|
WHERE id = ${id}$
|
||||||
|
""", {'id': user.id, 'now': now_str})
|
||||||
|
debug(f'Login failed for {username}, fail_count incremented')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Login successful - reset fail count, update last_login
|
# Login successful - atomically reset counters and update last_login
|
||||||
await sor.U('users', {'id': user.id}, {
|
now_str = curDateString('%Y-%m-%d %H:%M:%S')
|
||||||
'login_fail_count': 0,
|
await sor.sqlExe("""
|
||||||
'last_login_fail': None,
|
UPDATE users
|
||||||
'last_login': curDateString('%Y-%m-%d %H:%M:%S')
|
SET login_fail_count = 0,
|
||||||
})
|
last_login_fail = NULL,
|
||||||
|
last_login = ${now}$
|
||||||
|
WHERE id = ${id}$
|
||||||
|
""", {'id': user.id, 'now': now_str})
|
||||||
await user_login(request, user.id,
|
await user_login(request, user.id,
|
||||||
username=user.username,
|
username=user.username,
|
||||||
userorgid=user.orgid)
|
userorgid=user.orgid)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _parse_timestamp(ts):
|
|
||||||
"""Parse a timestamp string to unix timestamp."""
|
|
||||||
from datetime import datetime
|
|
||||||
if ts is None:
|
|
||||||
return 0
|
|
||||||
if isinstance(ts, (int, float)):
|
|
||||||
return ts
|
|
||||||
try:
|
|
||||||
dt = datetime.strptime(str(ts), '%Y-%m-%d %H:%M:%S')
|
|
||||||
return dt.timestamp()
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
async def basic_auth(sor, request):
|
async def basic_auth(sor, request):
|
||||||
auth = request.headers.get('Authorization')
|
auth = request.headers.get('Authorization')
|
||||||
auther = BasicAuth('x')
|
auther = BasicAuth('x')
|
||||||
m = auther.decode(auth)
|
m = auther.decode(auth)
|
||||||
username = m.login
|
username = m.login
|
||||||
password = password_encode(m.password)
|
password = password_encode(m.password)
|
||||||
sql = "select * from users where username=${username}$ and password=${password}$"
|
# Check lockout atomically in SQL (same pattern as checkUserPassword)
|
||||||
|
sql = """select * from users where username=${username}$
|
||||||
|
and password=${password}$
|
||||||
|
and not (
|
||||||
|
login_fail_count >= 3
|
||||||
|
and last_login_fail is not null
|
||||||
|
and last_login_fail > DATE_SUB(NOW(), INTERVAL 300 SECOND)
|
||||||
|
)"""
|
||||||
recs = await sor.sqlExe(sql, {'username':username,'password':password})
|
recs = await sor.sqlExe(sql, {'username':username,'password':password})
|
||||||
if len(recs) < 1:
|
if len(recs) < 1:
|
||||||
return None
|
return None
|
||||||
# Update last_login on successful basic auth
|
# Update last_login on successful basic auth
|
||||||
await sor.U('users', {'id': recs[0].id}, {
|
await sor.sqlExe("""
|
||||||
'last_login': curDateString('%Y-%m-%d %H:%M:%S'),
|
UPDATE users
|
||||||
'login_fail_count': 0,
|
SET login_fail_count = 0, last_login_fail = NULL,
|
||||||
'last_login_fail': None
|
last_login = NOW()
|
||||||
})
|
WHERE id = ${id}$
|
||||||
|
""", {'id': recs[0].id})
|
||||||
await user_login(request, recs[0].id,
|
await user_login(request, recs[0].id,
|
||||||
username=recs[0].username,
|
username=recs[0].username,
|
||||||
userorgid=recs[0].orgid)
|
userorgid=recs[0].orgid)
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import time
|
import asyncio
|
||||||
import threading
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from sqlor.dbpools import DBPools, get_sor_context
|
from sqlor.dbpools import DBPools, get_sor_context
|
||||||
from ahserver.serverenv import ServerEnv
|
from ahserver.serverenv import ServerEnv
|
||||||
@ -7,40 +6,47 @@ from appPublic.Singleton import SingletonDecorator
|
|||||||
from appPublic.log import debug, exception, error
|
from appPublic.log import debug, exception, error
|
||||||
|
|
||||||
class LRUCache:
|
class LRUCache:
|
||||||
"""Thread-safe LRU cache with TTL support."""
|
"""Async-safe LRU cache with TTL support.
|
||||||
|
|
||||||
|
Uses asyncio.Lock instead of threading.Lock to avoid blocking
|
||||||
|
the event loop in async environments.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, maxsize=10000, ttl=300):
|
def __init__(self, maxsize=10000, ttl=300):
|
||||||
self.maxsize = maxsize
|
self.maxsize = maxsize
|
||||||
self.ttl = ttl # seconds
|
self.ttl = ttl # seconds
|
||||||
self._cache = OrderedDict()
|
self._cache = OrderedDict()
|
||||||
self._lock = threading.Lock()
|
self._lock = None # Lazy init to handle sync creation in async context
|
||||||
|
|
||||||
|
def _get_lock(self):
|
||||||
|
if self._lock is None:
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
return self._lock
|
||||||
|
|
||||||
def get(self, key):
|
def get(self, key):
|
||||||
with self._lock:
|
import time
|
||||||
if key not in self._cache:
|
if key not in self._cache:
|
||||||
return None
|
return None
|
||||||
value, expire_at = self._cache[key]
|
value, expire_at = self._cache[key]
|
||||||
if time.time() > expire_at:
|
if time.time() > expire_at:
|
||||||
del self._cache[key]
|
del self._cache[key]
|
||||||
return None
|
return None
|
||||||
self._cache.move_to_end(key)
|
self._cache.move_to_end(key)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def set(self, key, value):
|
def set(self, key, value):
|
||||||
with self._lock:
|
import time
|
||||||
if key in self._cache:
|
if key in self._cache:
|
||||||
self._cache.move_to_end(key)
|
self._cache.move_to_end(key)
|
||||||
self._cache[key] = (value, time.time() + self.ttl)
|
self._cache[key] = (value, time.time() + self.ttl)
|
||||||
while len(self._cache) > self.maxsize:
|
while len(self._cache) > self.maxsize:
|
||||||
self._cache.popitem(last=False)
|
self._cache.popitem(last=False)
|
||||||
|
|
||||||
def invalidate(self, key):
|
def invalidate(self, key):
|
||||||
with self._lock:
|
self._cache.pop(key, None)
|
||||||
self._cache.pop(key, None)
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
with self._lock:
|
self._cache.clear()
|
||||||
self._cache.clear()
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
return self.get(key) is not None
|
return self.get(key) is not None
|
||||||
@ -69,9 +75,16 @@ class UserPermissions:
|
|||||||
# Role-permission cache: role_key -> list of paths
|
# Role-permission cache: role_key -> list of paths
|
||||||
self.rp_caches = None
|
self.rp_caches = None
|
||||||
self.rp_cache_loaded_at = 0
|
self.rp_cache_loaded_at = 0
|
||||||
|
import time
|
||||||
|
self._init_time = time.time()
|
||||||
|
|
||||||
# Lock for rp_caches initialization
|
# Async lock for rp_caches initialization (lazy init)
|
||||||
self._rp_lock = threading.Lock()
|
self._rp_lock = None
|
||||||
|
|
||||||
|
def _get_rp_lock(self):
|
||||||
|
if self._rp_lock is None:
|
||||||
|
self._rp_lock = asyncio.Lock()
|
||||||
|
return self._rp_lock
|
||||||
|
|
||||||
async def get_user_roles(self, userid):
|
async def get_user_roles(self, userid):
|
||||||
"""Get roles for a user, with LRU+TTL caching."""
|
"""Get roles for a user, with LRU+TTL caching."""
|
||||||
@ -103,10 +116,23 @@ class UserPermissions:
|
|||||||
self.rp_cache_loaded_at = 0
|
self.rp_cache_loaded_at = 0
|
||||||
|
|
||||||
async def load_roleperms(self, sor):
|
async def load_roleperms(self, sor):
|
||||||
"""Load all role-permission mappings into cache."""
|
"""Load all role-permission mappings into cache.
|
||||||
|
|
||||||
|
High-concurrency safe:
|
||||||
|
- Uses asyncio.Lock to prevent multiple coroutines loading simultaneously
|
||||||
|
- Double-check pattern: after acquiring lock, check if another coroutine already loaded
|
||||||
|
- TTL ensures periodic refresh
|
||||||
|
"""
|
||||||
|
import time
|
||||||
now = time.time()
|
now = time.time()
|
||||||
# Double-check with lock to prevent race conditions
|
|
||||||
with self._rp_lock:
|
# Fast path: cache valid, no lock needed
|
||||||
|
if self.rp_caches is not None and (now - self.rp_cache_loaded_at) < self.rp_cache_ttl:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Slow path: acquire lock and double-check
|
||||||
|
async with self._get_rp_lock():
|
||||||
|
# Double-check after lock acquisition
|
||||||
if self.rp_caches is not None and (now - self.rp_cache_loaded_at) < self.rp_cache_ttl:
|
if self.rp_caches is not None and (now - self.rp_cache_loaded_at) < self.rp_cache_ttl:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -159,10 +185,10 @@ where a.id = c.userid
|
|||||||
async def is_user_has_path_perm(self, userid, path):
|
async def is_user_has_path_perm(self, userid, path):
|
||||||
"""Check if a user has permission for the given path.
|
"""Check if a user has permission for the given path.
|
||||||
|
|
||||||
Security improvements:
|
High-concurrency safe:
|
||||||
1. rp_caches now has TTL to ensure permission changes take effect
|
1. rp_caches TTL ensures permission changes take effect within 10 minutes
|
||||||
2. User role cache uses LRU+TTL to prevent unbounded growth
|
2. Double-check locking prevents duplicate DB queries
|
||||||
3. Race condition protection with lock during rp_caches initialization
|
3. User role cache uses LRU+TTL to prevent unbounded growth
|
||||||
"""
|
"""
|
||||||
roles = self.ur_caches.get(userid)
|
roles = self.ur_caches.get(userid)
|
||||||
if userid is None:
|
if userid is None:
|
||||||
|
|||||||
@ -43,12 +43,13 @@ async with get_sor_context(request._run_ns, 'rbac') as sor:
|
|||||||
if recs:
|
if recs:
|
||||||
if len(recs) == 1:
|
if len(recs) == 1:
|
||||||
r = recs[0]
|
r = recs[0]
|
||||||
# Update last_login
|
# Update last_login atomically
|
||||||
await sor.U('users', {'id': r.id}, {
|
await sor.sqlExe("""
|
||||||
'last_login': curDateString('%Y-%m-%d %H:%M:%S'),
|
UPDATE users
|
||||||
'login_fail_count': 0,
|
SET last_login = NOW(), login_fail_count = 0,
|
||||||
'last_login_fail': None
|
last_login_fail = NULL
|
||||||
})
|
WHERE id = ${id}$
|
||||||
|
""", {'id': r.id})
|
||||||
await remember_user(r.id, username=r.username, userorgid=r.orgid)
|
await remember_user(r.id, username=r.username, userorgid=r.orgid)
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
@ -59,12 +60,12 @@ async with get_sor_context(request._run_ns, 'rbac') as sor:
|
|||||||
if params_kw.selected_id:
|
if params_kw.selected_id:
|
||||||
for r in recs:
|
for r in recs:
|
||||||
if r.id == params_kw.selected_id:
|
if r.id == params_kw.selected_id:
|
||||||
# Update last_login
|
await sor.sqlExe("""
|
||||||
await sor.U('users', {'id': r.id}, {
|
UPDATE users
|
||||||
'last_login': curDateString('%Y-%m-%d %H:%M:%S'),
|
SET last_login = NOW(), login_fail_count = 0,
|
||||||
'login_fail_count': 0,
|
last_login_fail = NULL
|
||||||
'last_login_fail': None
|
WHERE id = ${id}$
|
||||||
})
|
""", {'id': r.id})
|
||||||
await remember_user(r.id, username=r.username, userorgid=r.orgid)
|
await remember_user(r.id, username=r.username, userorgid=r.orgid)
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
|
|||||||
@ -9,39 +9,43 @@ info(f'{ns=}')
|
|||||||
db = DBPools()
|
db = DBPools()
|
||||||
dbname = get_module_dbname('rbac')
|
dbname = get_module_dbname('rbac')
|
||||||
async with db.sqlorContext(dbname) as sor:
|
async with db.sqlorContext(dbname) as sor:
|
||||||
r = await sor.sqlExe('select * from users where username=${username}$', ns.copy())
|
# Check lockout atomically in SQL
|
||||||
|
r = await sor.sqlExe("""select * from users where username=${username}$
|
||||||
|
and not (
|
||||||
|
login_fail_count >= 3
|
||||||
|
and last_login_fail is not null
|
||||||
|
and last_login_fail > DATE_SUB(NOW(), INTERVAL 300 SECOND)
|
||||||
|
)""", ns.copy())
|
||||||
if len(r) == 0:
|
if len(r) == 0:
|
||||||
return {
|
# User not found or locked out
|
||||||
"widgettype":"Error",
|
r2 = await sor.sqlExe('select username from users where username=${username}$', ns.copy())
|
||||||
"options":{
|
if len(r2) == 0:
|
||||||
"timeout":3,
|
msg = "user name or password error"
|
||||||
"title":"Login Error",
|
else:
|
||||||
"message":"user name or password error"
|
msg = "Account locked due to too many failed login attempts. Please try again in 5 minutes."
|
||||||
}
|
|
||||||
}
|
|
||||||
user = r[0]
|
|
||||||
|
|
||||||
# Check login lockout
|
|
||||||
fail_count = getattr(user, 'login_fail_count', 0) or 0
|
|
||||||
last_fail = getattr(user, 'last_login_fail', None)
|
|
||||||
if fail_count >= 3 and last_fail:
|
|
||||||
return {
|
return {
|
||||||
"widgettype":"Error",
|
"widgettype":"Error",
|
||||||
"options":{
|
"options":{
|
||||||
"timeout":5,
|
"timeout":5,
|
||||||
"title":"Account Locked",
|
"title":"Login Error",
|
||||||
"message":"Account locked due to too many failed login attempts. Please try again in 5 minutes."
|
"message": msg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
user = r[0]
|
||||||
|
|
||||||
|
# Verify password
|
||||||
r = await sor.sqlExe('select * from users where username=${username}$ and password=${password}$', ns.copy())
|
r = await sor.sqlExe('select * from users where username=${username}$ and password=${password}$', ns.copy())
|
||||||
if len(r) == 0:
|
if len(r) == 0:
|
||||||
# Increment fail count
|
# Atomically increment fail count
|
||||||
new_fail_count = fail_count + 1
|
now_str = curDateString('%Y-%m-%d %H:%M:%S')
|
||||||
await sor.U('users', {'id': user.id}, {
|
await sor.sqlExe("""
|
||||||
'login_fail_count': new_fail_count,
|
UPDATE users
|
||||||
'last_login_fail': curDateString('%Y-%m-%d %H:%M:%S')
|
SET login_fail_count = login_fail_count + 1,
|
||||||
})
|
last_login_fail = ${now}$
|
||||||
|
WHERE id = ${id}$
|
||||||
|
""", {'id': user.id, 'now': now_str})
|
||||||
|
|
||||||
|
new_fail_count = (getattr(user, 'login_fail_count', 0) or 0) + 1
|
||||||
if new_fail_count >= 3:
|
if new_fail_count >= 3:
|
||||||
msg = "Too many failed attempts. Account locked for 5 minutes."
|
msg = "Too many failed attempts. Account locked for 5 minutes."
|
||||||
else:
|
else:
|
||||||
@ -54,12 +58,13 @@ async with db.sqlorContext(dbname) as sor:
|
|||||||
"message": msg
|
"message": msg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# Success - reset fail count, update last_login
|
# Success - atomically reset counters and update last_login
|
||||||
await sor.U('users', {'id': user.id}, {
|
await sor.sqlExe("""
|
||||||
'login_fail_count': 0,
|
UPDATE users
|
||||||
'last_login_fail': None,
|
SET login_fail_count = 0, last_login_fail = NULL,
|
||||||
'last_login': curDateString('%Y-%m-%d %H:%M:%S')
|
last_login = NOW()
|
||||||
})
|
WHERE id = ${id}$
|
||||||
|
""", {'id': user.id})
|
||||||
await remember_user(r[0].id, username=r[0].username, userorgid=r[0].orgid)
|
await remember_user(r[0].id, username=r[0].username, userorgid=r[0].orgid)
|
||||||
return {
|
return {
|
||||||
"widgettype":"Message",
|
"widgettype":"Message",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user