设计思路
- 状态表示:使用一个整数变量(AQS 的 state)来表示锁的状态,低16位用于记录写锁的重入次数,高16位用于记录读锁的持有总数(所有读线程获取次数之和);每个读线程各自的重入次数另由线程本地计数器单独记录。
- 公平性实现:使用 AQS 内置的 FIFO 等待队列来保存等待获取锁的线程,根据公平性策略决定下一个获取锁的线程:公平模式下,只要队列中已有排在前面的线程(hasQueuedPredecessors),新来的线程就必须排队;非公平模式下,新来的线程可以直接尝试抢占锁。
- 可重入实现:记录持有锁的线程,当同一个线程再次获取锁时,增加重入次数。
关键代码实现
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
/**
 * A reentrant read-write lock built on {@link AbstractQueuedSynchronizer}.
 *
 * <p>The synchronizer state is a single {@code int}: the low 16 bits count
 * write-lock holds (reentrancy of the single owning thread) and the high
 * 16 bits count read-lock holds summed over all reader threads. Per-thread
 * read reentrancy is tracked separately so that a thread can only release
 * read holds it actually owns.
 *
 * <p>A thread holding the write lock may also acquire the read lock (lock
 * downgrading). Upgrading from read to write is not supported: a thread that
 * holds only the read lock cannot acquire the write lock.
 */
public class MyReadWriteLock {
    private final Sync sync;
    private final ReadLock readerLock;
    private final WriteLock writerLock;

    /**
     * Creates a new lock.
     *
     * @param fair {@code true} to make new acquirers queue behind threads that
     *             are already waiting; {@code false} to let them barge
     */
    public MyReadWriteLock(boolean fair) {
        sync = fair ? new FairSync() : new NonFairSync();
        readerLock = new ReadLock();
        writerLock = new WriteLock();
    }

    /**
     * Returns the shared (read) lock view of this lock.
     *
     * @return the read lock
     */
    public Lock readLock() {
        return readerLock;
    }

    /**
     * Returns the exclusive (write) lock view of this lock.
     *
     * @return the write lock
     */
    public Lock writeLock() {
        return writerLock;
    }

    /** Number of read holds owned by one thread. */
    private static final class HoldCounter {
        int count;
    }

    /** Thread-local that lazily creates a zeroed {@link HoldCounter} per thread. */
    private static final class ThreadLocalHoldCounter extends ThreadLocal<HoldCounter> {
        @Override
        protected HoldCounter initialValue() {
            return new HoldCounter();
        }
    }

    /** Shared synchronization logic; subclasses supply the fairness policy. */
    private abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1L;

        static final int SHARED_SHIFT = 16;
        static final int SHARED_UNIT = (1 << SHARED_SHIFT);
        static final int MAX_COUNT = (1 << SHARED_SHIFT) - 1;
        static final int EXCLUSIVE_MASK = (1 << SHARED_SHIFT) - 1;

        /** Read hold counts for every reader thread except {@link #firstReader}. */
        private final transient ThreadLocalHoldCounter readHolds = new ThreadLocalHoldCounter();

        /**
         * The thread that moved the shared count from 0 to 1, or {@code null}.
         * Its hold count lives in {@link #firstReaderHoldCount} instead of the
         * thread-local. Each thread only ever compares this field against
         * itself, and only that thread writes itself into the field.
         */
        private transient Thread firstReader;
        private transient int firstReaderHoldCount;

        /** Extracts the total number of read holds from a state value. */
        static int sharedCount(int c) {
            return c >>> SHARED_SHIFT;
        }

        /** Extracts the number of write holds from a state value. */
        static int exclusiveCount(int c) {
            return c & EXCLUSIVE_MASK;
        }

        /** Policy hook: should a thread trying to take the write lock queue instead? */
        abstract boolean writerShouldBlock();

        /** Policy hook: should a thread trying to take the read lock queue instead? */
        abstract boolean readerShouldBlock();

        /** Creates a condition bound to the write lock (uses the AQS-provided implementation). */
        final ConditionObject newCondition() {
            return new ConditionObject();
        }

        /**
         * Reports whether the calling thread owns the write lock. AQS's default
         * implementation throws, and both {@link #tryRelease} and the condition
         * implementation depend on this method.
         */
        @Override
        protected final boolean isHeldExclusively() {
            return getExclusiveOwnerThread() == Thread.currentThread();
        }

        /**
         * Attempts to take the write lock.
         *
         * @param acquires number of holds to add
         * @return {@code true} if the write lock is now held by the caller
         * @throws Error if the write hold count would exceed {@link #MAX_COUNT}
         */
        @Override
        protected final boolean tryAcquire(int acquires) {
            Thread current = Thread.currentThread();
            int c = getState();
            int w = exclusiveCount(c);
            if (c != 0) {
                // c != 0 with w == 0 means readers are active; otherwise some
                // thread owns the write lock and only that owner may re-enter.
                if (w == 0 || current != getExclusiveOwnerThread()) {
                    return false;
                }
                if (w + exclusiveCount(acquires) > MAX_COUNT) {
                    throw new Error("Maximum lock count exceeded");
                }
                // Only the owner reaches this point, so a plain write suffices.
                setState(c + acquires);
                return true;
            }
            if (writerShouldBlock() || !compareAndSetState(c, c + acquires)) {
                return false;
            }
            setExclusiveOwnerThread(current);
            return true;
        }

        /**
         * Releases write holds.
         *
         * @param releases number of holds to drop
         * @return {@code true} if the write lock became free
         * @throws IllegalMonitorStateException if the caller does not own the write lock
         */
        @Override
        protected final boolean tryRelease(int releases) {
            if (!isHeldExclusively()) {
                throw new IllegalMonitorStateException();
            }
            int nextc = getState() - releases;
            boolean free = exclusiveCount(nextc) == 0;
            if (free) {
                setExclusiveOwnerThread(null);
            }
            setState(nextc);
            return free;
        }

        /**
         * Attempts to take one read hold.
         *
         * @param unused ignored; exactly one hold is acquired
         * @return {@code 1} on success, {@code -1} if the caller must queue
         * @throws Error if the read hold count is already {@link #MAX_COUNT}
         */
        @Override
        protected final int tryAcquireShared(int unused) {
            Thread current = Thread.currentThread();
            for (;;) {
                int c = getState();
                if (exclusiveCount(c) != 0) {
                    if (getExclusiveOwnerThread() != current) {
                        return -1;
                    }
                    // The caller owns the write lock (downgrade). Queuing here
                    // would park it behind threads that are waiting for it.
                } else if (readerShouldBlock() && !holdsReadLock(current)) {
                    // Only fresh readers yield to queued threads; a reentrant
                    // reader must proceed, or it would deadlock with a queued
                    // writer that is waiting for this reader to finish.
                    return -1;
                }
                int r = sharedCount(c);
                if (r == MAX_COUNT) {
                    throw new Error("Maximum lock count exceeded");
                }
                if (compareAndSetState(c, c + SHARED_UNIT)) {
                    if (r == 0) {
                        firstReader = current;
                        firstReaderHoldCount = 1;
                    } else if (firstReader == current) {
                        firstReaderHoldCount++;
                    } else {
                        readHolds.get().count++;
                    }
                    return 1;
                }
                // CAS lost a race with another thread; re-read state and retry.
            }
        }

        /** Returns whether {@code current} already owns at least one read hold. */
        private boolean holdsReadLock(Thread current) {
            if (firstReader == current) {
                return true;
            }
            if (readHolds.get().count > 0) {
                return true;
            }
            // get() created an empty counter just for this query; drop it.
            readHolds.remove();
            return false;
        }

        /**
         * Releases one read hold owned by the calling thread.
         *
         * @param unused ignored; exactly one hold is released
         * @return {@code true} if the lock is now completely free
         * @throws IllegalMonitorStateException if the caller owns no read hold
         */
        @Override
        protected final boolean tryReleaseShared(int unused) {
            Thread current = Thread.currentThread();
            if (firstReader == current) {
                if (firstReaderHoldCount == 1) {
                    firstReader = null;
                } else {
                    firstReaderHoldCount--;
                }
            } else {
                HoldCounter rh = readHolds.get();
                int count = rh.count;
                if (count <= 1) {
                    // Last hold (or none at all): clear the thread-local slot.
                    readHolds.remove();
                    if (count <= 0) {
                        throw new IllegalMonitorStateException();
                    }
                }
                rh.count = count - 1;
            }
            for (;;) {
                int c = getState();
                int nextc = c - SHARED_UNIT;
                if (compareAndSetState(c, nextc)) {
                    return nextc == 0;
                }
            }
        }

        /** Returns whether any thread currently owns the write lock. */
        final boolean isWriteLocked() {
            return exclusiveCount(getState()) != 0;
        }

        /** Returns the total number of read holds across all threads. */
        final int getReadLockCount() {
            return sharedCount(getState());
        }
    }

    /** Barging policy: new acquirers never yield to queued threads. */
    private static class NonFairSync extends Sync {
        private static final long serialVersionUID = 1L;

        @Override
        final boolean writerShouldBlock() {
            return false;
        }

        @Override
        final boolean readerShouldBlock() {
            return false;
        }
    }

    /** FIFO policy: new acquirers yield whenever another thread is queued ahead of them. */
    private static class FairSync extends Sync {
        private static final long serialVersionUID = 1L;

        @Override
        final boolean writerShouldBlock() {
            return hasQueuedPredecessors();
        }

        @Override
        final boolean readerShouldBlock() {
            return hasQueuedPredecessors();
        }
    }

    /** Shared-mode view of the enclosing lock. */
    public class ReadLock implements Lock {
        /** Acquires a read hold, waiting if necessary. */
        @Override
        public void lock() {
            sync.acquireShared(1);
        }

        /**
         * Acquires a read hold unless the thread is interrupted.
         *
         * @throws InterruptedException if interrupted before or while waiting
         */
        @Override
        public void lockInterruptibly() throws InterruptedException {
            sync.acquireSharedInterruptibly(1);
        }

        /**
         * Acquires a read hold only if that is possible without waiting.
         * The configured fairness policy is applied to this attempt as well.
         *
         * @return {@code true} if a read hold was acquired
         */
        @Override
        public boolean tryLock() {
            return sync.tryAcquireShared(1) >= 0;
        }

        /**
         * Acquires a read hold, waiting up to the given time.
         *
         * @return {@code true} if a read hold was acquired before the timeout
         * @throws InterruptedException if interrupted before or while waiting
         */
        @Override
        public boolean tryLock(long timeout, TimeUnit unit) throws InterruptedException {
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        }

        /**
         * Releases one read hold.
         *
         * @throws IllegalMonitorStateException if the caller owns no read hold
         */
        @Override
        public void unlock() {
            sync.releaseShared(1);
        }

        /**
         * Conditions are not supported on the read lock.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public Condition newCondition() {
            throw new UnsupportedOperationException();
        }
    }

    /** Exclusive-mode view of the enclosing lock. */
    public class WriteLock implements Lock {
        /** Acquires a write hold, waiting if necessary. */
        @Override
        public void lock() {
            sync.acquire(1);
        }

        /**
         * Acquires a write hold unless the thread is interrupted.
         *
         * @throws InterruptedException if interrupted before or while waiting
         */
        @Override
        public void lockInterruptibly() throws InterruptedException {
            sync.acquireInterruptibly(1);
        }

        /**
         * Acquires a write hold only if that is possible without waiting.
         * The configured fairness policy is applied to this attempt as well.
         *
         * @return {@code true} if a write hold was acquired
         */
        @Override
        public boolean tryLock() {
            return sync.tryAcquire(1);
        }

        /**
         * Acquires a write hold, waiting up to the given time.
         *
         * @return {@code true} if a write hold was acquired before the timeout
         * @throws InterruptedException if interrupted before or while waiting
         */
        @Override
        public boolean tryLock(long timeout, TimeUnit unit) throws InterruptedException {
            return sync.tryAcquireNanos(1, unit.toNanos(timeout));
        }

        /**
         * Releases one write hold.
         *
         * @throws IllegalMonitorStateException if the caller does not own the write lock
         */
        @Override
        public void unlock() {
            sync.release(1);
        }

        /**
         * Creates a condition bound to the write lock.
         *
         * @return a new condition
         */
        @Override
        public Condition newCondition() {
            return sync.newCondition();
        }
    }
}