001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.wake.remote.transport.netty; 020 021import io.netty.bootstrap.Bootstrap; 022import io.netty.bootstrap.ServerBootstrap; 023import io.netty.channel.Channel; 024import io.netty.channel.ChannelFuture; 025import io.netty.channel.ChannelOption; 026import io.netty.channel.EventLoopGroup; 027import io.netty.channel.group.ChannelGroup; 028import io.netty.channel.group.DefaultChannelGroup; 029import io.netty.channel.nio.NioEventLoopGroup; 030import io.netty.channel.socket.nio.NioServerSocketChannel; 031import io.netty.channel.socket.nio.NioSocketChannel; 032import io.netty.util.concurrent.GlobalEventExecutor; 033import org.apache.reef.tang.annotations.Parameter; 034import org.apache.reef.wake.EStage; 035import org.apache.reef.wake.EventHandler; 036import org.apache.reef.wake.impl.DefaultThreadFactory; 037import org.apache.reef.wake.remote.Encoder; 038import org.apache.reef.wake.remote.RemoteConfiguration; 039import org.apache.reef.wake.remote.address.LocalAddressProvider; 040import org.apache.reef.wake.remote.exception.RemoteRuntimeException; 041import org.apache.reef.wake.remote.impl.TransportEvent; 042import org.apache.reef.wake.remote.ports.TcpPortProvider; 043import org.apache.reef.wake.remote.transport.Link; 044import org.apache.reef.wake.remote.transport.LinkListener; 045import org.apache.reef.wake.remote.transport.Transport; 046import org.apache.reef.wake.remote.transport.exception.TransportRuntimeException; 047 048import javax.inject.Inject; 049import java.io.IOException; 050import java.net.BindException; 051import java.net.ConnectException; 052import java.net.InetSocketAddress; 053import java.net.SocketAddress; 054import java.util.Iterator; 055import java.util.concurrent.ConcurrentHashMap; 056import java.util.concurrent.ConcurrentMap; 057import java.util.concurrent.atomic.AtomicInteger; 058import java.util.logging.Level; 059import java.util.logging.Logger; 060 061/** 062 * Messaging transport implementation with Netty. 063 */ 064public class NettyMessagingTransport implements Transport { 065 066 private static final String CLASS_NAME = NettyMessagingTransport.class.getName(); 067 private static final Logger LOG = Logger.getLogger(CLASS_NAME); 068 069 private static final int SERVER_BOSS_NUM_THREADS = 3; 070 private static final int SERVER_WORKER_NUM_THREADS = 20; 071 private static final int CLIENT_WORKER_NUM_THREADS = 10; 072 073 private final ConcurrentMap<SocketAddress, LinkReference> addrToLinkRefMap = new ConcurrentHashMap<>(); 074 075 private final EventLoopGroup clientWorkerGroup; 076 private final EventLoopGroup serverBossGroup; 077 private final EventLoopGroup serverWorkerGroup; 078 079 private final Bootstrap clientBootstrap; 080 private final ServerBootstrap serverBootstrap; 081 private final Channel acceptor; 082 083 private final ChannelGroup clientChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); 084 private final ChannelGroup serverChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); 085 086 private final int serverPort; 087 private final SocketAddress localAddress; 088 089 private final NettyClientEventListener clientEventListener; 090 private final NettyServerEventListener serverEventListener; 091 092 private final int numberOfTries; 093 private final int retryTimeout; 094 /** 095 * Indicates a hostname that isn't set or known. 096 */ 097 public static final String UNKNOWN_HOST_NAME = "##UNKNOWN##"; 098 099 /** 100 * Constructs a messaging transport. 101 * 102 * @param hostAddress the server host address 103 * @param port the server listening port; when it is 0, randomly assign a port number 104 * @param clientStage the client-side stage that handles transport events 105 * @param serverStage the server-side stage that handles transport events 106 * @param numberOfTries the number of tries of connection 107 * @param retryTimeout the timeout of reconnection 108 * @param tcpPortProvider gives an iterator that produces random tcp ports in a range 109 */ 110 @Inject 111 NettyMessagingTransport( 112 @Parameter(RemoteConfiguration.HostAddress.class) final String hostAddress, 113 @Parameter(RemoteConfiguration.Port.class) final int port, 114 @Parameter(RemoteConfiguration.RemoteClientStage.class) final EStage<TransportEvent> clientStage, 115 @Parameter(RemoteConfiguration.RemoteServerStage.class) final EStage<TransportEvent> serverStage, 116 @Parameter(RemoteConfiguration.NumberOfTries.class) final int numberOfTries, 117 @Parameter(RemoteConfiguration.RetryTimeout.class) final int retryTimeout, 118 final TcpPortProvider tcpPortProvider, 119 final LocalAddressProvider localAddressProvider) { 120 121 int p = port; 122 if (p < 0) { 123 throw new RemoteRuntimeException("Invalid server port: " + p); 124 } 125 126 final String host = UNKNOWN_HOST_NAME.equals(hostAddress) ? localAddressProvider.getLocalAddress() : hostAddress; 127 128 this.numberOfTries = numberOfTries; 129 this.retryTimeout = retryTimeout; 130 this.clientEventListener = new NettyClientEventListener(this.addrToLinkRefMap, clientStage); 131 this.serverEventListener = new NettyServerEventListener(this.addrToLinkRefMap, serverStage); 132 133 this.serverBossGroup = new NioEventLoopGroup(SERVER_BOSS_NUM_THREADS, 134 new DefaultThreadFactory(CLASS_NAME + "ServerBoss")); 135 this.serverWorkerGroup = new NioEventLoopGroup(SERVER_WORKER_NUM_THREADS, 136 new DefaultThreadFactory(CLASS_NAME + "ServerWorker")); 137 this.clientWorkerGroup = new NioEventLoopGroup(CLIENT_WORKER_NUM_THREADS, 138 new DefaultThreadFactory(CLASS_NAME + "ClientWorker")); 139 140 this.clientBootstrap = new Bootstrap(); 141 this.clientBootstrap.group(this.clientWorkerGroup) 142 .channel(NioSocketChannel.class) 143 .handler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("client", 144 this.clientChannelGroup, this.clientEventListener))) 145 .option(ChannelOption.SO_REUSEADDR, true) 146 .option(ChannelOption.SO_KEEPALIVE, true); 147 148 this.serverBootstrap = new ServerBootstrap(); 149 this.serverBootstrap.group(this.serverBossGroup, this.serverWorkerGroup) 150 .channel(NioServerSocketChannel.class) 151 .childHandler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("server", 152 this.serverChannelGroup, this.serverEventListener))) 153 .option(ChannelOption.SO_BACKLOG, 128) 154 .option(ChannelOption.SO_REUSEADDR, true) 155 .childOption(ChannelOption.SO_KEEPALIVE, true); 156 157 LOG.log(Level.FINE, "Binding to {0}", p); 158 159 Channel acceptorFound = null; 160 try { 161 if (p > 0) { 162 acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel(); 163 } else { 164 final Iterator<Integer> ports = tcpPortProvider.iterator(); 165 while (acceptorFound == null) { 166 if (!ports.hasNext()) { 167 throw new IllegalStateException("tcpPortProvider cannot find a free port."); 168 } 169 p = ports.next(); 170 LOG.log(Level.FINEST, "Try port {0}", p); 171 try { 172 acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel(); 173 } catch (final Exception ex) { 174 if (ex instanceof BindException) { 175 LOG.log(Level.FINEST, "The port {0} is already bound. Try again", p); 176 } else { 177 throw ex; 178 } 179 } 180 } 181 } 182 } catch (final IllegalStateException ex) { 183 final RuntimeException transportException = 184 new TransportRuntimeException("tcpPortProvider failed to return free ports.", ex); 185 LOG.log(Level.SEVERE, "Cannot find a free port with " + tcpPortProvider, transportException); 186 187 this.clientWorkerGroup.shutdownGracefully(); 188 this.serverBossGroup.shutdownGracefully(); 189 this.serverWorkerGroup.shutdownGracefully(); 190 throw transportException; 191 192 } catch (final Exception ex) { 193 final RuntimeException transportException = 194 new TransportRuntimeException("Cannot bind to port " + p, ex); 195 LOG.log(Level.SEVERE, "Cannot bind to port " + p, ex); 196 197 this.clientWorkerGroup.shutdownGracefully(); 198 this.serverBossGroup.shutdownGracefully(); 199 this.serverWorkerGroup.shutdownGracefully(); 200 throw transportException; 201 } 202 203 this.acceptor = acceptorFound; 204 this.serverPort = p; 205 this.localAddress = new InetSocketAddress(host, this.serverPort); 206 207 LOG.log(Level.FINE, "Starting netty transport socket address: {0}", this.localAddress); 208 } 209 210 /** 211 * Closes all channels and releases all resources. 212 */ 213 @Override 214 public void close() throws Exception { 215 216 LOG.log(Level.FINE, "Closing netty transport socket address: {0}", this.localAddress); 217 218 this.clientChannelGroup.close().awaitUninterruptibly(); 219 this.serverChannelGroup.close().awaitUninterruptibly(); 220 this.acceptor.close().sync(); 221 this.clientWorkerGroup.shutdownGracefully(); 222 this.serverBossGroup.shutdownGracefully(); 223 this.serverWorkerGroup.shutdownGracefully(); 224 225 LOG.log(Level.FINE, "Closing netty transport socket address: {0} done", this.localAddress); 226 } 227 228 /** 229 * Returns a link for the remote address if cached; otherwise opens, caches and returns. 230 * When it opens a link for the remote address, only one attempt for the address is made at a given time 231 * 232 * @param remoteAddr the remote socket address 233 * @param encoder the encoder 234 * @param listener the link listener 235 * @return a link associated with the address 236 */ 237 @Override 238 public <T> Link<T> open(final SocketAddress remoteAddr, final Encoder<? super T> encoder, 239 final LinkListener<? super T> listener) throws IOException { 240 241 Link<T> link = null; 242 243 for (int i = 0; i <= this.numberOfTries; ++i) { 244 LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr); 245 246 if (linkRef != null) { 247 link = (Link<T>) linkRef.getLink(); 248 if (LOG.isLoggable(Level.FINE)) { 249 LOG.log(Level.FINE, "Link {0} for {1} found", new Object[]{link, remoteAddr}); 250 } 251 if (link != null) { 252 return link; 253 } 254 } 255 256 if (i == this.numberOfTries) { 257 // Connection failure 258 throw new ConnectException("Connection to " + remoteAddr + " refused"); 259 } 260 261 LOG.log(Level.FINE, "No cached link for {0} thread {1}", 262 new Object[]{remoteAddr, Thread.currentThread()}); 263 264 // no linkRef 265 final LinkReference newLinkRef = new LinkReference(); 266 final LinkReference prior = this.addrToLinkRefMap.putIfAbsent(remoteAddr, newLinkRef); 267 final AtomicInteger flag = prior != null ? 268 prior.getConnectInProgress() : newLinkRef.getConnectInProgress(); 269 270 synchronized (flag) { 271 if (!flag.compareAndSet(0, 1)) { 272 while (flag.get() == 1) { 273 try { 274 flag.wait(); 275 } catch (final InterruptedException ex) { 276 LOG.log(Level.WARNING, "Wait interrupted", ex); 277 } 278 } 279 } 280 } 281 282 linkRef = this.addrToLinkRefMap.get(remoteAddr); 283 link = (Link<T>) linkRef.getLink(); 284 285 if (link != null) { 286 return link; 287 } 288 289 ChannelFuture connectFuture = null; 290 try { 291 connectFuture = this.clientBootstrap.connect(remoteAddr); 292 connectFuture.syncUninterruptibly(); 293 294 link = new NettyLink<>(connectFuture.channel(), encoder, listener); 295 linkRef.setLink(link); 296 297 synchronized (flag) { 298 flag.compareAndSet(1, 2); 299 flag.notifyAll(); 300 } 301 break; 302 } catch (final Exception e) { 303 if (e.getClass().getSimpleName().compareTo("ConnectException") == 0) { 304 LOG.log(Level.WARNING, "Connection refused. Retry {0} of {1}", 305 new Object[]{i + 1, this.numberOfTries}); 306 synchronized (flag) { 307 flag.compareAndSet(1, 0); 308 flag.notifyAll(); 309 } 310 311 if (i < this.numberOfTries) { 312 try { 313 Thread.sleep(retryTimeout); 314 } catch (final InterruptedException interrupt) { 315 LOG.log(Level.WARNING, "Thread {0} interrupted while sleeping", Thread.currentThread()); 316 } 317 } 318 } else { 319 throw e; 320 } 321 } 322 } 323 324 return link; 325 } 326 327 /** 328 * Returns a link for the remote address if already cached; otherwise, returns null. 329 * 330 * @param remoteAddr the remote address 331 * @return a link if already cached; otherwise, null 332 */ 333 public <T> Link<T> get(final SocketAddress remoteAddr) { 334 final LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr); 335 return linkRef != null ? (Link<T>) linkRef.getLink() : null; 336 } 337 338 /** 339 * Gets a server local socket address of this transport. 340 * 341 * @return a server local socket address 342 */ 343 @Override 344 public SocketAddress getLocalAddress() { 345 return this.localAddress; 346 } 347 348 /** 349 * Gets a server listening port of this transport. 350 * 351 * @return a listening port number 352 */ 353 @Override 354 public int getListeningPort() { 355 return this.serverPort; 356 } 357 358 /** 359 * Registers the exception event handler. 360 * 361 * @param handler the exception event handler 362 */ 363 @Override 364 public void registerErrorHandler(final EventHandler<Exception> handler) { 365 this.clientEventListener.registerErrorHandler(handler); 366 this.serverEventListener.registerErrorHandler(handler); 367 } 368}