"""使用 Python、Redis 和 Lua 脚本实现的分布式锁。"""
import functools
import logging
import time
import uuid
from typing import Callable, Any, Optional

import redis
# Lua script: acquire the lock atomically.
# KEYS[1]: lock key
# ARGV[1]: owner token
# ARGV[2]: time-to-live in milliseconds
# Returns 1 if the key now holds ARGV[1], otherwise 0.
# The second branch makes a retried call idempotent: if an earlier SET
# succeeded but its reply was lost (e.g. a network timeout that the caller
# retries), the retry recognises its own token and refreshes the TTL instead
# of waiting for its own lock to expire.
_ACQUIRE_LOCK_LUA = """
if redis.call('set', KEYS[1], ARGV[1], 'NX', 'PX', ARGV[2]) then
    return 1
end
if redis.call('get', KEYS[1]) == ARGV[1] then
    redis.call('pexpire', KEYS[1], ARGV[2])
    return 1
end
return 0
"""
# Lua script: release the lock safely (compare-and-delete).
# KEYS[1]: lock key
# ARGV[1]: token of the caller
# The key is deleted only if it still holds ARGV[1]. A caller whose lock has
# already expired therefore cannot delete a lock now held by someone else.
# Returns the result of DEL when the token matches, otherwise 0.
_RELEASE_LOCK_LUA = """
if redis.call('get', KEYS[1]) == ARGV[1] then
return redis.call('del', KEYS[1])
else
return 0
end
"""
class RedisDistributedLock:
    """Distributed mutex backed by a single Redis key.

    The lock is taken with ``SET key token NX PX timeout`` (run through a Lua
    script) and released with a compare-and-delete script. An instance can
    therefore only delete a key that still holds the random token it wrote.

    Args:
        client: Redis client used to run the scripts.
        key: Redis key that represents the lock.
        timeout: Lock time-to-live in milliseconds. The key expires on its own
            if the holder never releases it. Must be positive.

    Can also be used as a context manager::

        with RedisDistributedLock(client, "order:lock"):
            ...

    Raises:
        ValueError: if ``timeout`` is not positive.
    """

    def __init__(self, client: redis.Redis, key: str, timeout: int = 10000):
        if timeout <= 0:
            # Redis rejects non-positive PX values. Fail fast here instead of
            # erroring on every acquire attempt.
            raise ValueError(
                f"timeout must be a positive number of milliseconds, got {timeout!r}"
            )
        self.client = client
        self.key = key
        self.timeout = timeout
        # Token written by the most recent acquire(); None when not held.
        self.token: Optional[str] = None
        self._acquire_script = self.client.register_script(_ACQUIRE_LOCK_LUA)
        self._release_script = self.client.register_script(_RELEASE_LOCK_LUA)

    def acquire(
        self,
        blocking: bool = True,
        wait_timeout: Optional[float] = None,
        retry_interval: float = 0.05,
    ) -> bool:
        """Try to take the lock.

        Args:
            blocking: If False, make a single attempt and return immediately.
            wait_timeout: Maximum seconds to keep retrying when ``blocking``.
                ``None`` retries indefinitely.
            retry_interval: Seconds to sleep between attempts.

        Returns:
            True if the lock was acquired, False otherwise.

        Raises:
            redis.RedisError: for errors other than connection/timeout errors
                (e.g. ``ResponseError``). Those are not transient, so retrying
                them would only spin.
        """
        self.token = str(uuid.uuid4())
        # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
        deadline = None if wait_timeout is None else time.monotonic() + wait_timeout
        while True:
            try:
                result = self._acquire_script(
                    keys=[self.key], args=[self.token, str(self.timeout)]
                )
            except (redis.ConnectionError, redis.TimeoutError):
                # Network jitter and similar: treat as "not acquired" and retry.
                result = 0
            if result == 1:
                return True
            if not blocking:
                return False
            if deadline is None:
                time.sleep(retry_interval)
                continue
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Never sleep past the deadline; one last attempt happens at it.
            time.sleep(min(retry_interval, remaining))

    def release(self) -> bool:
        """Release the lock if this instance still owns it.

        Returns:
            True if the key was deleted. False in any of these cases:
            - this instance holds no token;
            - the lock already expired or belongs to another owner;
            - Redis could not be reached.
        """
        if self.token is None:
            return False
        try:
            result = self._release_script(keys=[self.key], args=[self.token])
        except redis.RedisError:
            # Best effort: the key still expires after `timeout` ms. Keep the
            # token so the caller may retry release().
            return False
        # Whether we deleted it or it had expired, we no longer own the key.
        self.token = None
        return result == 1

    def __enter__(self) -> "RedisDistributedLock":
        if not self.acquire():
            raise TimeoutError(f"获取分布式锁失败: {self.key}")
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.release()
def redis_lock(
    key_prefix: str,
    timeout: int = 10000,
    blocking: bool = True,
    wait_timeout: Optional[float] = None,
    client: Optional[redis.Redis] = None,
) -> Callable:
    """Decorator that runs the wrapped function under a Redis distributed lock.

    The lock key is ``f"{key_prefix}:{func.__name__}"``. Every call of the
    decorated function is therefore serialized, whatever its arguments.

    Args:
        key_prefix: Prefix for the lock key.
        timeout: Lock time-to-live in milliseconds.
        blocking: Whether to wait when the lock is held elsewhere.
        wait_timeout: Maximum seconds to wait when ``blocking``. ``None`` waits
            indefinitely.
        client: Redis client to use. Defaults to a client for
            ``localhost:6379`` db 0, created once per decorated function and
            shared by all of its calls.

    Raises:
        TimeoutError: (from the decorated function) if the lock could not be
            acquired.
    """
    def decorator(func: Callable) -> Callable:
        lock_key = f"{key_prefix}:{func.__name__}"
        # Build the client once, not once per call: each redis.Redis owns its
        # own connection pool, so a new client per call leaks connections.
        lock_client = (
            client
            if client is not None
            else redis.Redis(host='localhost', port=6379, db=0)
        )

        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            lock = RedisDistributedLock(lock_client, lock_key, timeout)
            if not lock.acquire(blocking=blocking, wait_timeout=wait_timeout):
                raise TimeoutError(f"获取分布式锁失败: {lock_key}")
            try:
                return func(*args, **kwargs)
            finally:
                # A failed release is not fatal (the key expires after
                # `timeout` ms), but it is worth surfacing.
                if not lock.release():
                    logging.getLogger(__name__).warning(
                        "警告: 释放锁失败: %s", lock_key
                    )

        return wrapper

    return decorator
# Usage example: all calls share the lock key "order:process_order", so they
# run one at a time across every process using the same Redis server.
@redis_lock('order', timeout=5000, blocking=True, wait_timeout=2)
def process_order(order_id: int):
    """Simulate handling one order while holding the distributed lock.

    Prints a start line, pretends to work for one second, then prints a
    completion line.
    """
    print(f"处理订单 {order_id}...开始")
    time.sleep(1)  # stand-in for real work
    print(f"处理订单 {order_id}...完成")
if __name__ == '__main__':
    import threading

    def worker(order_id):
        """Process one order, printing (instead of raising) lock timeouts."""
        try:
            process_order(order_id)
        except TimeoutError as exc:
            print(exc)

    # Three concurrent callers contend for the same lock.
    pool = [threading.Thread(target=worker, args=(n,)) for n in range(3)]
    for thread in pool:
        thread.start()
    for thread in pool:
        thread.join()