diff --git a/protocol/src/main/java/net/md_5/bungee/protocol/Protocol.java b/protocol/src/main/java/net/md_5/bungee/protocol/Protocol.java index 2378754e..e1db3eae 100644 --- a/protocol/src/main/java/net/md_5/bungee/protocol/Protocol.java +++ b/protocol/src/main/java/net/md_5/bungee/protocol/Protocol.java @@ -1,6 +1,7 @@ package net.md_5.bungee.protocol; import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; import gnu.trove.map.TIntObjectMap; import gnu.trove.map.TObjectIntMap; import gnu.trove.map.hash.TIntObjectHashMap; @@ -217,8 +218,8 @@ public enum Protocol /*========================================================================*/ public static final int MAX_PACKET_ID = 0xFF; /*========================================================================*/ - public final DirectionData TO_SERVER = new DirectionData( ProtocolConstants.Direction.TO_SERVER ); - public final DirectionData TO_CLIENT = new DirectionData( ProtocolConstants.Direction.TO_CLIENT ); + public final DirectionData TO_SERVER = new DirectionData(this, ProtocolConstants.Direction.TO_SERVER ); + public final DirectionData TO_CLIENT = new DirectionData(this, ProtocolConstants.Direction.TO_CLIENT ); @RequiredArgsConstructor private static class ProtocolData { @@ -242,6 +243,7 @@ public enum Protocol public static class DirectionData { + private final Protocol protocolPhase; private final TIntObjectMap protocols = new TIntObjectHashMap<>(); { for ( int protocol : ProtocolConstants.SUPPORTED_VERSION_IDS ) @@ -263,12 +265,22 @@ public enum Protocol @Getter private final ProtocolConstants.Direction direction; - public final DefinedPacket createPacket(int id, int protocol) + private ProtocolData getProtocolData(int version) { - ProtocolData protocolData = protocols.get( protocol ); + ProtocolData protocol = protocols.get( version ); + if ( protocol == null && ( protocolPhase == Protocol.HANDSHAKE || protocolPhase == Protocol.STATUS ) ) + { + protocol = Iterables.getFirst( protocols.valueCollection(), null ); + } + return protocol; + } + + public final DefinedPacket createPacket(int id, int version) + { + ProtocolData protocolData = getProtocolData( version ); if (protocolData == null) { - throw new BadPacketException( "Unsupported protocol" ); + throw new BadPacketException( "Unsupported protocol version" ); } if ( id > MAX_PACKET_ID ) { @@ -319,13 +331,13 @@ public enum Protocol } } - final int getId(Class packet, int protocol) + final int getId(Class packet, int version) { - ProtocolData protocolData = protocols.get( protocol ); + ProtocolData protocolData = getProtocolData( version ); if (protocolData == null) { - throw new BadPacketException( "Unsupported protocol" ); + throw new BadPacketException( "Unsupported protocol version" ); } Preconditions.checkArgument( protocolData.packetMap.containsKey( packet ), "Cannot get ID for packet " + packet ); diff --git a/proxy/src/main/java/net/md_5/bungee/connection/InitialHandler.java b/proxy/src/main/java/net/md_5/bungee/connection/InitialHandler.java index d714836c..dab6ef2d 100644 --- a/proxy/src/main/java/net/md_5/bungee/connection/InitialHandler.java +++ b/proxy/src/main/java/net/md_5/bungee/connection/InitialHandler.java @@ -293,6 +293,12 @@ public class InitialHandler extends PacketHandler implements PendingConnection thisState = State.USERNAME; ch.setProtocol( Protocol.LOGIN ); + if ( !ProtocolConstants.SUPPORTED_VERSION_IDS.contains( handshake.getProtocolVersion() ) ) + { + disconnect( bungee.getTranslation( "outdated_server" ) ); + return; + } + if ( bungee.getConnectionThrottle() != null && bungee.getConnectionThrottle().throttle( getAddress().getAddress() ) ) { disconnect( bungee.getTranslation( "join_throttle_kick", TimeUnit.MILLISECONDS.toSeconds( bungee.getConfig().getThrottle() ) ) ); @@ -309,12 +315,6 @@ public class InitialHandler extends PacketHandler implements PendingConnection Preconditions.checkState( thisState == State.USERNAME, "Not expecting USERNAME" ); this.loginRequest = loginRequest; - if ( !ProtocolConstants.SUPPORTED_VERSION_IDS.contains( handshake.getProtocolVersion() ) ) - { - disconnect( bungee.getTranslation( "outdated_server" ) ); - return; - } - if ( getName().contains( "." ) ) { disconnect( bungee.getTranslation( "name_invalid" ) );