#2794: connection throttle race condition

This commit is contained in:
md_5 2020-05-10 09:44:44 +10:00
parent eeb3c6d3bf
commit 67c2dfd884

View File

@ -9,11 +9,12 @@ import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
public class ConnectionThrottle public class ConnectionThrottle
{ {
private final LoadingCache<InetAddress, Integer> throttle; private final LoadingCache<InetAddress, AtomicInteger> throttle;
private final int throttleLimit; private final int throttleLimit;
public ConnectionThrottle(int throttleTime, int throttleLimit) public ConnectionThrottle(int throttleTime, int throttleLimit)
@ -29,12 +30,12 @@ public class ConnectionThrottle
.concurrencyLevel( Runtime.getRuntime().availableProcessors() ) .concurrencyLevel( Runtime.getRuntime().availableProcessors() )
.initialCapacity( 100 ) .initialCapacity( 100 )
.expireAfterWrite( throttleTime, TimeUnit.MILLISECONDS ) .expireAfterWrite( throttleTime, TimeUnit.MILLISECONDS )
.build( new CacheLoader<InetAddress, Integer>() .build( new CacheLoader<InetAddress, AtomicInteger>()
{ {
@Override @Override
public Integer load(InetAddress key) throws Exception public AtomicInteger load(InetAddress key) throws Exception
{ {
return 0; return new AtomicInteger();
} }
} ); } );
this.throttleLimit = throttleLimit; this.throttleLimit = throttleLimit;
@ -48,8 +49,11 @@ public class ConnectionThrottle
} }
InetAddress address = ( (InetSocketAddress) socketAddress ).getAddress(); InetAddress address = ( (InetSocketAddress) socketAddress ).getAddress();
int throttleCount = throttle.getUnchecked( address ) - 1; AtomicInteger throttleCount = throttle.getIfPresent( address );
throttle.put( address, throttleCount ); if ( throttleCount != null )
{
throttleCount.decrementAndGet();
}
} }
public boolean throttle(SocketAddress socketAddress) public boolean throttle(SocketAddress socketAddress)
@ -60,8 +64,7 @@ public class ConnectionThrottle
} }
InetAddress address = ( (InetSocketAddress) socketAddress ).getAddress(); InetAddress address = ( (InetSocketAddress) socketAddress ).getAddress();
int throttleCount = throttle.getUnchecked( address ) + 1; int throttleCount = throttle.getUnchecked( address ).incrementAndGet();
throttle.put( address, throttleCount );
return throttleCount > throttleLimit; return throttleCount > throttleLimit;
} }